Coverage for services/inference/src/models/onnx_loader.py: 91%

90 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-25 16:18 +0000

1"""ONNX Model Loader for Phase 6 Inference Service.""" 

2import logging 

3from typing import Dict, Optional 

4import numpy as np 

5import onnxruntime as ort 

6import mlflow 

7from mlflow.tracking import MlflowClient 

8 

9logger = logging.getLogger(__name__) 

10 

11 

12class ONNXModelLoader: 

13 """Load and manage ONNX model from MLflow registry.""" 

14 

15 def __init__( 

16 self, 

17 mlflow_uri: str, 

18 model_name: str = "localization_model", 

19 stage: str = "Production", 

20 ): 

21 """ 

22 Initialize ONNX Model Loader. 

23  

24 Args: 

25 mlflow_uri: MLflow tracking URI (e.g., "http://mlflow:5000") 

26 model_name: Registered model name in MLflow 

27 stage: Model stage ("Production", "Staging", "None") 

28  

29 Raises: 

30 ValueError: If model not found in registry or stage 

31 RuntimeError: If ONNX session initialization fails 

32 """ 

33 self.mlflow_uri = mlflow_uri 

34 self.model_name = model_name 

35 self.stage = stage 

36 self.session = None 

37 self.model_metadata = None 

38 self.reload_count = 0 

39 

40 # Initialize MLflow client 

41 mlflow.set_tracking_uri(mlflow_uri) 

42 self.client = MlflowClient(tracking_uri=mlflow_uri) 

43 

44 # Load model on initialization 

45 self._load_model() 

46 

47 def _load_model(self) -> None: 

48 """ 

49 Load ONNX model from MLflow registry. 

50  

51 Raises: 

52 ValueError: If model not found or stage invalid 

53 RuntimeError: If ONNX session init fails 

54 """ 

55 try: 

56 logger.info( 

57 f"Loading model '{self.model_name}' from stage '{self.stage}' " 

58 f"(MLflow URI: {self.mlflow_uri})" 

59 ) 

60 

61 # Get all model versions 

62 model_versions = self.client.search_model_versions( 

63 f"name='{self.model_name}'" 

64 ) 

65 

66 if not model_versions: 

67 raise ValueError( 

68 f"Model '{self.model_name}' not found in MLflow registry" 

69 ) 

70 

71 # Find version in requested stage 

72 version = None 

73 for mv in model_versions: 

74 if mv.current_stage == self.stage: 

75 version = mv 

76 break 

77 

78 if version is None: 

79 available_stages = {mv.current_stage for mv in model_versions} 

80 raise ValueError( 

81 f"Model '{self.model_name}' not found in stage '{self.stage}'. " 

82 f"Available stages: {available_stages}" 

83 ) 

84 

85 # Download model artifact from MLflow 

86 model_uri = f"models:/{self.model_name}/{self.stage}" 

87 logger.info(f"Downloading model from {model_uri}") 

88 local_path = mlflow.artifacts.download_artifacts(model_uri) 

89 

90 # Initialize ONNX Runtime session with optimizations 

91 sess_options = ort.SessionOptions() 

92 sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL 

93 sess_options.log_severity_level = 3 # Suppress ONNX warnings 

94 

95 model_path = f"{local_path}/model.onnx" 

96 logger.info(f"Creating ONNX Runtime session from {model_path}") 

97 

98 self.session = ort.InferenceSession( 

99 model_path, 

100 sess_options, 

101 providers=["CPUExecutionProvider"], 

102 ) 

103 

104 # Get model input/output info 

105 input_info = self.session.get_inputs() 

106 output_info = self.session.get_outputs() 

107 

108 logger.info( 

109 f"ONNX Model loaded successfully. " 

110 f"Inputs: {len(input_info)}, Outputs: {len(output_info)}" 

111 ) 

112 

113 # Store metadata 

114 self.model_metadata = { 

115 "model_name": self.model_name, 

116 "version": version.version, 

117 "stage": self.stage, 

118 "run_id": version.run_id, 

119 "created_at": str(version.creation_timestamp), 

120 "status": version.status, 

121 "input_name": input_info[0].name, 

122 "input_shape": input_info[0].shape, 

123 "output_names": [out.name for out in output_info], 

124 "output_shapes": [out.shape for out in output_info], 

125 "reload_count": self.reload_count, 

126 } 

127 

128 self.reload_count += 1 

129 logger.info(f"Model metadata: {self.model_metadata}") 

130 

131 except Exception as e: 

132 logger.error(f"Failed to load ONNX model: {e}", exc_info=True) 

133 raise 

134 

135 def predict(self, features: np.ndarray) -> Dict: 

136 """ 

137 Run ONNX inference. 

138  

139 Args: 

140 features: Input features (numpy array). 

141 Shape depends on model (typically [batch, feature_dim] or [feature_dim]) 

142  

143 Returns: 

144 Dict with keys: 

145 - position: {latitude: float, longitude: float} 

146 - uncertainty: {sigma_x: float, sigma_y: float, theta: float} 

147 - confidence: float (0-1) 

148  

149 Raises: 

150 RuntimeError: If model not loaded 

151 ValueError: If input validation fails 

152 """ 

153 if self.session is None: 

154 raise RuntimeError("Model not loaded. Call _load_model() first.") 

155 

156 try: 

157 # Validate and reshape input 

158 if not isinstance(features, np.ndarray): 

159 features = np.array(features, dtype=np.float32) 

160 

161 if features.ndim == 1: 

162 features = features[np.newaxis, ...] # Add batch dimension 

163 

164 features = features.astype(np.float32) 

165 

166 # Get model input/output names 

167 input_name = self.session.get_inputs()[0].name 

168 output_names = [out.name for out in self.session.get_outputs()] 

169 

170 logger.debug(f"Running inference with input shape {features.shape}") 

171 

172 # Run inference 

173 outputs = self.session.run( 

174 output_names, 

175 {input_name: features}, 

176 ) 

177 

178 logger.debug(f"Inference complete. Outputs: {len(outputs)}") 

179 

180 # Parse outputs (assuming Phase 5 model outputs): 

181 # Output 0: position [batch, 2] -> (lat, lon) 

182 # Output 1: uncertainty [batch, 3] -> (sigma_x, sigma_y, theta) 

183 # Output 2: confidence [batch, 1] -> probability 

184 

185 position = outputs[0][0] # First batch, position 

186 uncertainty = outputs[1][0] if len(outputs) > 1 else np.array([0.0, 0.0, 0.0]) 

187 confidence = outputs[2][0] if len(outputs) > 2 else np.array([1.0]) 

188 

189 # Ensure arrays have enough elements 

190 if len(uncertainty) < 3: 

191 uncertainty = np.pad(uncertainty, (0, 3 - len(uncertainty)), 'constant') 

192 

193 result = { 

194 "position": { 

195 "latitude": float(position[0]), 

196 "longitude": float(position[1]), 

197 }, 

198 "uncertainty": { 

199 "sigma_x": float(uncertainty[0]), 

200 "sigma_y": float(uncertainty[1]), 

201 "theta": float(uncertainty[2]), 

202 }, 

203 "confidence": float(confidence[0]), 

204 } 

205 

206 logger.debug(f"Prediction result: {result}") 

207 return result 

208 

209 except Exception as e: 

210 logger.error(f"Inference failed: {e}", exc_info=True) 

211 raise 

212 

213 def get_metadata(self) -> Dict: 

214 """ 

215 Return model metadata. 

216  

217 Returns: 

218 Dict with model information from MLflow registry 

219 """ 

220 if self.model_metadata is None: 

221 logger.warning("Model metadata not available") 

222 return {} 

223 

224 return self.model_metadata.copy() 

225 

226 def reload(self) -> None: 

227 """ 

228 Reload model from MLflow registry (for graceful updates). 

229  

230 Useful for updating model without restarting service. 

231 """ 

232 logger.info(f"Reloading model '{self.model_name}' from MLflow...") 

233 try: 

234 self._load_model() 

235 logger.info("Model reloaded successfully") 

236 except Exception as e: 

237 logger.error(f"Model reload failed: {e}", exc_info=True) 

238 raise 

239 

240 def is_ready(self) -> bool: 

241 """ 

242 Check if model is ready for inference. 

243  

244 Returns: 

245 True if model loaded and session active, False otherwise 

246 """ 

247 return self.session is not None and self.model_metadata is not None