Coverage for services/inference/src/models/onnx_loader.py: 91% (90 statements)

1"""ONNX Model Loader for Phase 6 Inference Service."""
2import logging
3from typing import Dict, Optional
4import numpy as np
5import onnxruntime as ort
6import mlflow
7from mlflow.tracking import MlflowClient
9logger = logging.getLogger(__name__)
12class ONNXModelLoader:
13 """Load and manage ONNX model from MLflow registry."""
15 def __init__(
16 self,
17 mlflow_uri: str,
18 model_name: str = "localization_model",
19 stage: str = "Production",
20 ):
21 """
22 Initialize ONNX Model Loader.
24 Args:
25 mlflow_uri: MLflow tracking URI (e.g., "http://mlflow:5000")
26 model_name: Registered model name in MLflow
27 stage: Model stage ("Production", "Staging", "None")
29 Raises:
30 ValueError: If model not found in registry or stage
31 RuntimeError: If ONNX session initialization fails
32 """
33 self.mlflow_uri = mlflow_uri
34 self.model_name = model_name
35 self.stage = stage
36 self.session = None
37 self.model_metadata = None
38 self.reload_count = 0
40 # Initialize MLflow client
41 mlflow.set_tracking_uri(mlflow_uri)
42 self.client = MlflowClient(tracking_uri=mlflow_uri)
44 # Load model on initialization
45 self._load_model()
    def _load_model(self) -> None:
        """
        Load the ONNX model from the MLflow registry.

        Raises:
            ValueError: If the model is not found or the stage is invalid
            RuntimeError: If ONNX session initialization fails
        """
        try:
            logger.info(
                f"Loading model '{self.model_name}' from stage '{self.stage}' "
                f"(MLflow URI: {self.mlflow_uri})"
            )

            # Get all versions of the registered model
            model_versions = self.client.search_model_versions(
                f"name='{self.model_name}'"
            )

            if not model_versions:
                raise ValueError(
                    f"Model '{self.model_name}' not found in MLflow registry"
                )

            # Find the version in the requested stage
            version = None
            for mv in model_versions:
                if mv.current_stage == self.stage:
                    version = mv
                    break

            if version is None:
                available_stages = {mv.current_stage for mv in model_versions}
                raise ValueError(
                    f"Model '{self.model_name}' not found in stage '{self.stage}'. "
                    f"Available stages: {available_stages}"
                )

            # Download the model artifact from MLflow
            model_uri = f"models:/{self.model_name}/{self.stage}"
            logger.info(f"Downloading model from {model_uri}")
            local_path = mlflow.artifacts.download_artifacts(model_uri)

            # Initialize the ONNX Runtime session with optimizations
            sess_options = ort.SessionOptions()
            sess_options.graph_optimization_level = (
                ort.GraphOptimizationLevel.ORT_ENABLE_ALL
            )
            sess_options.log_severity_level = 3  # Suppress ONNX warnings

            model_path = f"{local_path}/model.onnx"
            logger.info(f"Creating ONNX Runtime session from {model_path}")

            self.session = ort.InferenceSession(
                model_path,
                sess_options,
                providers=["CPUExecutionProvider"],
            )

            # Get model input/output info
            input_info = self.session.get_inputs()
            output_info = self.session.get_outputs()

            logger.info(
                f"ONNX model loaded successfully. "
                f"Inputs: {len(input_info)}, Outputs: {len(output_info)}"
            )

            # Store metadata
            self.model_metadata = {
                "model_name": self.model_name,
                "version": version.version,
                "stage": self.stage,
                "run_id": version.run_id,
                "created_at": str(version.creation_timestamp),
                "status": version.status,
                "input_name": input_info[0].name,
                "input_shape": input_info[0].shape,
                "output_names": [out.name for out in output_info],
                "output_shapes": [out.shape for out in output_info],
                "reload_count": self.reload_count,
            }

            self.reload_count += 1
            logger.info(f"Model metadata: {self.model_metadata}")

        except Exception as e:
            logger.error(f"Failed to load ONNX model: {e}", exc_info=True)
            raise
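
    # The stored metadata is a plain dict; its shape looks roughly like the
    # sketch below (all values are hypothetical, not from a real run):
    #   {"model_name": "localization_model", "version": "3",
    #    "stage": "Production", "input_name": "input",
    #    "input_shape": [1, 64], "output_names": ["position", ...], ...}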
    def predict(self, features: np.ndarray) -> Dict:
        """
        Run ONNX inference.

        Args:
            features: Input features (numpy array). Shape depends on the
                model (typically [batch, feature_dim] or [feature_dim]).

        Returns:
            Dict with keys:
                - position: {latitude: float, longitude: float}
                - uncertainty: {sigma_x: float, sigma_y: float, theta: float}
                - confidence: float (0-1)

        Raises:
            RuntimeError: If the model is not loaded
            ValueError: If input validation fails
        """
        if self.session is None:
            raise RuntimeError("Model not loaded. Call _load_model() first.")

        try:
            # Validate and reshape input
            if not isinstance(features, np.ndarray):
                features = np.array(features, dtype=np.float32)

            if features.ndim == 1:
                features = features[np.newaxis, ...]  # Add batch dimension

            features = features.astype(np.float32)

            # Get model input/output names
            input_name = self.session.get_inputs()[0].name
            output_names = [out.name for out in self.session.get_outputs()]

            logger.debug(f"Running inference with input shape {features.shape}")

            # Run inference
            outputs = self.session.run(
                output_names,
                {input_name: features},
            )

            logger.debug(f"Inference complete. Outputs: {len(outputs)}")

            # Parse outputs (assuming Phase 5 model outputs):
            #   Output 0: position    [batch, 2] -> (lat, lon)
            #   Output 1: uncertainty [batch, 3] -> (sigma_x, sigma_y, theta)
            #   Output 2: confidence  [batch, 1] -> probability
            position = outputs[0][0]  # First batch element
            uncertainty = (
                outputs[1][0] if len(outputs) > 1 else np.array([0.0, 0.0, 0.0])
            )
            confidence = outputs[2][0] if len(outputs) > 2 else np.array([1.0])

            # Ensure the uncertainty array has enough elements
            if len(uncertainty) < 3:
                uncertainty = np.pad(
                    uncertainty, (0, 3 - len(uncertainty)), "constant"
                )

            result = {
                "position": {
                    "latitude": float(position[0]),
                    "longitude": float(position[1]),
                },
                "uncertainty": {
                    "sigma_x": float(uncertainty[0]),
                    "sigma_y": float(uncertainty[1]),
                    "theta": float(uncertainty[2]),
                },
                "confidence": float(confidence[0]),
            }

            logger.debug(f"Prediction result: {result}")
            return result

        except Exception as e:
            logger.error(f"Inference failed: {e}", exc_info=True)
            raise
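
    # Usage sketch: predict() accepts either a batched or a 1-D feature
    # vector, since it adds the batch dimension itself. The feature
    # dimension below (64) is an assumption; the real value is the
    # model's input_shape as reported by get_metadata().
    #
    #   result = loader.predict(np.zeros(64, dtype=np.float32))
    #   lat = result["position"]["latitude"]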
    def get_metadata(self) -> Dict:
        """
        Return model metadata.

        Returns:
            Dict with model information from the MLflow registry
        """
        if self.model_metadata is None:
            logger.warning("Model metadata not available")
            return {}

        return self.model_metadata.copy()

    def reload(self) -> None:
        """
        Reload the model from the MLflow registry (for graceful updates).

        Useful for updating the model without restarting the service.
        """
        logger.info(f"Reloading model '{self.model_name}' from MLflow...")
        try:
            self._load_model()
            logger.info("Model reloaded successfully")
        except Exception as e:
            logger.error(f"Model reload failed: {e}", exc_info=True)
            raise
    def is_ready(self) -> bool:
        """
        Check whether the model is ready for inference.

        Returns:
            True if the model is loaded and the session is active,
            False otherwise
        """
        return self.session is not None and self.model_metadata is not None
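

# Minimal smoke-test sketch, runnable only against a live MLflow server.
# The tracking URI matches the example in the class docstring and the model
# name is the constructor default; both are assumptions about the actual
# deployment, as is the 64-element feature vector, which must match the
# model's real input shape.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    loader = ONNXModelLoader(
        mlflow_uri="http://mlflow:5000",
        model_name="localization_model",
        stage="Production",
    )
    if loader.is_ready():
        dummy = np.random.rand(1, 64).astype(np.float32)
        print(loader.predict(dummy))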