"""Base class for model inference on various hardware platforms."""
from typing import Any, Dict, List, Optional
import numpy as np
import onnxruntime
from green_ai_bench.generate_data import generate_data
class BaseInference:
    """Base class for ONNX model inference on various hardware platforms.

    Handles ONNX Runtime session creation (``setup_model``) and synthetic
    input generation (``generate_data``). ``infer`` is a placeholder that
    returns None here.
    """

    def __init__(
        self, model_path: str, use_GPU: bool = False, **kwargs: Any
    ) -> None:
        """Initialize the BaseInference class.

        Args:
            model_path: Path to the ONNX model, passed to
                ``onnxruntime.InferenceSession`` by ``setup_model``.
            use_GPU: If True, ``setup_model`` requests the CUDA execution
                provider; otherwise it requests the CPU execution provider.
            **kwargs: Optional data-generation settings later forwarded to
                the ``generate_data`` helper: ``low`` (default 2), ``high``
                (default 255), ``num_of_images`` (default 1),
                ``image_height`` (default 224), ``image_width`` (default 224),
                ``channels`` (default 3) and ``channel_first`` (default True).
                Any other keys are ignored.
        """
        self.model_path: str = model_path
        # Created lazily by setup_model().
        self.session: Optional[onnxruntime.InferenceSession] = None
        # NOTE(review): low/high presumably bound the generated values;
        # confirm against green_ai_bench.generate_data.
        self.low: int = kwargs.get("low", 2)
        self.high: int = kwargs.get("high", 255)
        self.num_of_images: int = kwargs.get("num_of_images", 1)
        self.image_height: int = kwargs.get("image_height", 224)
        self.image_width: int = kwargs.get("image_width", 224)
        self.channels: int = kwargs.get("channels", 3)
        self.channel_first: bool = kwargs.get("channel_first", True)
        # Both populated by generate_data().
        self.input_data: Optional[Dict[str, np.ndarray]] = None
        self.dataset: Optional[np.ndarray] = None
        # Name of the model's first input; set by setup_model().
        self.input_name: Optional[str] = None
        self.use_GPU: bool = use_GPU

    def setup_model(self) -> None:
        """Create the ONNX Runtime session and record the model's input name.

        Uses the CUDA execution provider when ``use_GPU`` is True, otherwise
        the CPU execution provider. Exceptions raised by
        ``onnxruntime.InferenceSession`` propagate unchanged.
        """
        if self.use_GPU:
            providers = ["CUDAExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]
        # The constructor either returns a session or raises, so no
        # None check is needed afterwards.
        self.session = onnxruntime.InferenceSession(
            self.model_path, providers=providers
        )
        # Only the first model input is used.
        self.input_name = self.session.get_inputs()[0].name

    def generate_data(self) -> None:
        """Generate input data for the model.

        Stores the two values returned by the module-level ``generate_data``
        helper in ``self.input_data`` and ``self.dataset``, using the
        settings captured in ``__init__`` and the input name found by
        ``setup_model``.

        Raises:
            RuntimeError: If ``input_name`` is unset, i.e. ``setup_model``
                has not run successfully.
        """
        if self.input_name is None:
            raise RuntimeError("Model not properly initialized")
        # This name resolves to the module-level function imported from
        # green_ai_bench, not to this method: method names live on the
        # class, not in the function's global scope.
        self.input_data, self.dataset = generate_data(
            name=self.input_name,
            low=self.low,
            high=self.high,
            num_of_images=self.num_of_images,
            image_height=self.image_height,
            image_width=self.image_width,
            channels=self.channels,
            channel_first=self.channel_first,
        )

    def infer(self) -> Optional[List[np.ndarray]]:
        """Run inference; this base implementation does nothing.

        Returns:
            Always None in the base class.
            NOTE(review): presumably overridden by platform-specific
            subclasses.
        """
        return None