# green_ai_bench.base_inference
"""Base class for model inference on various hardware platforms."""

from typing import Any, Dict, List, Optional

import numpy as np
import onnxruntime

from green_ai_bench.generate_data import generate_data


class BaseInference:
    """Base class for model inference on various hardware platforms.

    Wraps an ONNX Runtime session plus synthetic input-data generation so
    that benchmarks can run the same model on CPU or GPU by toggling a flag.
    Typical usage: construct, then call ``setup_model()``, then
    ``generate_data()``, then ``infer()`` (subclasses override ``infer``).
    """

    def __init__(
        self, model_path: str, use_GPU: bool = False, **kwargs: Any
    ) -> None:
        """Initialize the BaseInference class.

        Args:
            model_path: Path to the ONNX model file to load in setup_model().
            use_GPU: If True, request the CUDA execution provider;
                otherwise use the CPU execution provider.
            **kwargs: Optional data-generation settings forwarded to
                generate_data():
                low (int): Lower bound for random pixel values (default 2).
                high (int): Upper bound for random pixel values (default 255).
                num_of_images (int): Number of images to generate (default 1).
                image_height (int): Image height in pixels (default 224).
                image_width (int): Image width in pixels (default 224).
                channels (int): Number of color channels (default 3).
                channel_first (bool): If True, NCHW layout; otherwise NHWC
                    (default True).
        """
        self.model_path: str = model_path
        # Populated lazily: session/input_name by setup_model(),
        # input_data/dataset by generate_data().
        self.session: Optional[onnxruntime.InferenceSession] = None
        self.low: int = kwargs.get("low", 2)
        self.high: int = kwargs.get("high", 255)
        self.num_of_images: int = kwargs.get("num_of_images", 1)
        self.image_height: int = kwargs.get("image_height", 224)
        self.image_width: int = kwargs.get("image_width", 224)
        self.channels: int = kwargs.get("channels", 3)
        self.channel_first: bool = kwargs.get("channel_first", True)
        self.input_data: Optional[Dict[str, np.ndarray]] = None
        self.dataset: Optional[np.ndarray] = None
        self.input_name: Optional[str] = None
        self.use_GPU: bool = use_GPU

    def setup_model(self) -> None:
        """Initialize the ONNX model for inference.

        Creates the ONNX Runtime session with the provider selected by
        ``use_GPU`` and caches the name of the model's first input.

        Raises:
            RuntimeError: If the session could not be initialized.
        """
        if self.use_GPU:
            providers = ["CUDAExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]
        self.session = onnxruntime.InferenceSession(
            self.model_path, providers=providers
        )
        # Defensive guard: InferenceSession normally raises on failure
        # rather than returning None; kept so subclass overrides that
        # return None still fail loudly here.
        if self.session is None:
            raise RuntimeError("Failed to initialize ONNX session")
        self.input_name = self.session.get_inputs()[0].name

    def generate_data(self) -> None:
        """Generate input data for the model.

        Uses the data-generation settings captured in ``__init__`` and the
        input name discovered by ``setup_model()``; stores the results in
        ``self.input_data`` (feed dict) and ``self.dataset``.

        Raises:
            RuntimeError: If setup_model() has not been called first.
        """
        if self.input_name is None:
            raise RuntimeError("Model not properly initialized")
        # Note: this resolves to the module-level generate_data() helper,
        # not a recursive call — method names are class attributes, not
        # in the function's enclosing scope.
        self.input_data, self.dataset = generate_data(
            name=self.input_name,
            low=self.low,
            high=self.high,
            num_of_images=self.num_of_images,
            image_height=self.image_height,
            image_width=self.image_width,
            channels=self.channels,
            channel_first=self.channel_first,
        )

    def infer(self) -> Optional[List[np.ndarray]]:
        """Run inference; the base implementation is a no-op returning None.

        Subclasses override this to run the actual session and return the
        list of output arrays.
        """
        return None