"""Hailo-specific implementation of BaseInference for model inference on Hailo hardware."""
from typing import Any, Optional
from codecarbon import track_emissions
from hailo_platform import (
HEF,
ConfigureParams,
FormatType,
HailoSchedulingAlgorithm,
HailoStreamInterface,
InferVStreams,
InputVStreamParams,
OutputVStreamParams,
VDevice,
)
from green_ai_bench.base_inference import BaseInference
class HailoInference(BaseInference):
    """Hailo-specific implementation of BaseInference for Hailo hardware.

    Loads a compiled Hailo model (``.hef``), reads the input dimensions from
    the model's first input virtual stream, and forwards inference requests
    to a caller-supplied pipeline object.

    Args:
        model_path (str): Path to the Hailo model file (.hef)
        **kwargs: Additional arguments passed to BaseInference
    """

    def __init__(self, model_path: str, **kwargs: Any) -> None:
        """Initialize the HailoInference class."""
        # These attributes are created before delegating to the base class so
        # that they already exist while BaseInference.__init__ runs.
        # NOTE(review): presumably the base initializer invokes setup_model()
        # and/or generate_data() -- confirm against BaseInference.
        self.model: Optional[HEF] = None
        self._input_info: Optional[Any] = None
        super().__init__(model_path, **kwargs)

    def setup_model(self) -> None:
        """Load the HEF file at ``self.model_path`` into ``self.model``.

        The loaded model is stored on the instance; nothing is returned.

        Raises:
            Exception: Any error raised by the ``HEF`` constructor propagates
                unchanged.
        """
        print(self.model_path)
        # A constructor call either yields an instance or raises, so no
        # follow-up ``is None`` check is needed here.
        self.model = HEF(self.model_path)

    def generate_data(self) -> Any:
        """Generate input data sized from the model's first input stream.

        Reads the first input virtual-stream description of the loaded model,
        stores its shape as ``image_height``, ``image_width`` and ``channels``
        and its name as ``input_name``, then delegates to
        ``BaseInference.generate_data``.

        Returns:
            Any: Whatever ``BaseInference.generate_data`` returns.

        Raises:
            RuntimeError: If no model has been loaded yet.
        """
        if self.model is None:
            # Fail with an actionable message rather than an AttributeError
            # on ``None`` when the call order is wrong.
            raise RuntimeError(
                "Hailo model is not loaded; call setup_model() before "
                "generate_data()"
            )

        self._input_info = self.model.get_input_vstream_infos()[0]
        # The shape must unpack into exactly three values.
        # NOTE(review): assumes (height, width, channels) order -- confirm
        # against the HEF input format.
        self.image_height, self.image_width, self.channels = self._input_info.shape
        self.input_name = self._input_info.name
        return super().generate_data()

    def infer(self, infer_pipeline: Any) -> Any:
        """Run inference on ``self.input_data`` through the given pipeline.

        Args:
            infer_pipeline: Object exposing an ``infer(data)`` method.
                NOTE(review): presumably an activated ``InferVStreams``
                context -- confirm against the caller.

        Returns:
            Any: Model inference results as returned by
                ``infer_pipeline.infer``.
        """
        return infer_pipeline.infer(self.input_data)