Coverage for services/inference/src/utils/model_versioning.py: 46%

203 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-25 16:18 +0000

1""" 

2Model Versioning & A/B Testing Framework 

3========================================= 

4 

5Manages multiple model versions with support for: 

6- Loading different model versions from MLflow registry 

7- Dynamic version switching without service restart (graceful reload) 

8- A/B testing with traffic allocation 

9- Version metadata and performance tracking 

10- Fallback to previous version on failure 

11 

12T6.5: Model Versioning & A/B Testing 

13""" 

14 

15import json 

16import logging 

17from dataclasses import dataclass, field, asdict 

18from datetime import datetime 

19from enum import Enum 

20from typing import Dict, List, Optional, Tuple 

21from contextlib import asynccontextmanager 

22import asyncio 

23 

24import numpy as np 

25import onnxruntime as ort 

26 

27from ..config import settings 

28 

29 

30logger = logging.getLogger(__name__) 

31 

32 

class ModelStage(str, Enum):
    """Lifecycle stage labels for a registered model (MLflow stage names)."""

    # Each member's value is the stage name string, so members compare
    # equal to plain strings thanks to the ``str`` mixin.
    PRODUCTION = "Production"
    STAGING = "Staging"
    ARCHIVED = "Archived"
    NONE = "None"

39 

40 

class VersionStatus(str, Enum):
    """States a model version can be in over its lifetime in the registry."""

    # Values are lowercase strings; the ``str`` mixin lets members compare
    # equal to those raw strings.
    LOADING = "loading"
    READY = "ready"
    STALE = "stale"
    ERROR = "error"
    DEPRECATED = "deprecated"

48 

49 

@dataclass
class ModelVersion:
    """Descriptive record kept for one model version."""

    # Identifier of the version (MLflow version ID).
    version_id: str
    stage: ModelStage = ModelStage.PRODUCTION
    # Both timestamps default to the (naive) UTC time of construction.
    created_at: datetime = field(default_factory=datetime.utcnow)
    updated_at: datetime = field(default_factory=datetime.utcnow)
    status: VersionStatus = VersionStatus.READY
    # Metric name -> value, for example:
    #   {"accuracy": 0.95, "latency_ms": 150.5,
    #    "throughput": 45.2, "cache_hit_rate": 0.82}
    performance_metrics: Dict[str, float] = field(default_factory=dict)
    notes: str = ""
    is_active: bool = False
    error_message: Optional[str] = None

    def to_dict(self) -> Dict:
        """Return the record as a plain dict.

        Datetimes are rendered as ISO-8601 strings and the two enum fields
        are replaced by their underlying string values.
        """
        # Overriding keys that already exist in the asdict() result keeps
        # their original position, so field order is unchanged.
        return {
            **asdict(self),
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
            "stage": self.stage.value,
            "status": self.status.value,
        }

77 

78 

@dataclass
class ABTestConfig:
    """Settings describing an A/B experiment between two model versions."""

    enabled: bool = False
    # Primary (control) version ID.
    version_a: str = ""
    # Experimental (treatment) version ID.
    version_b: str = ""
    # 0.5 means a 50/50 split.
    traffic_split: float = 0.5
    # Lower bound for the split: 1%.
    min_traffic_split: float = 0.01
    # Upper bound for the split: 99%.
    max_traffic_split: float = 0.99
    # Whether a winner should be promoted automatically.
    auto_winner: bool = False
    # Confidence threshold for auto-promotion.
    winner_threshold: float = 0.95
    started_at: datetime = field(default_factory=datetime.utcnow)
    # None means the test has no time limit.
    duration_seconds: Optional[int] = None
    metrics_to_compare: List[str] = field(
        default_factory=lambda: ["accuracy", "latency_ms"]
    )

    def is_active(self) -> bool:
        """Tell whether the experiment is currently running.

        A disabled test is never active. An enabled test without a duration
        is always active; otherwise it is active until ``duration_seconds``
        have elapsed since ``started_at``.
        """
        if not self.enabled:
            return False

        limit = self.duration_seconds
        if limit is None:
            return True

        running_for = datetime.utcnow() - self.started_at
        return running_for.total_seconds() < limit

    def to_dict(self) -> Dict:
        """Return the configuration as a dict with ``started_at`` as an ISO string."""
        payload = asdict(self)
        payload.update(started_at=self.started_at.isoformat())
        return payload

108 

109 

class ModelVersionRegistry:
    """
    Manages model versions with versioning and A/B testing support.

    Architecture:
    - Maintains registry of loaded model versions
    - Supports multiple concurrent versions (memory overhead vs flexibility)
    - Routes predictions based on A/B test config
    - Provides graceful fallback on version failure

    Usage:
        registry = ModelVersionRegistry(mlflow_tracking_uri)
        await registry.load_version("v1", ModelStage.PRODUCTION)
        await registry.set_active_version("v1")
        prediction, version_used = await registry.predict(features)
    """

    def __init__(
        self,
        mlflow_tracking_uri: Optional[str] = None,
        max_versions: int = 5,
        session_options: Optional[ort.SessionOptions] = None
    ):
        """
        Initialize version registry.

        Args:
            mlflow_tracking_uri: MLflow server endpoint (defaults to settings.mlflow_tracking_uri)
            max_versions: Maximum concurrent loaded versions (soft limit, see load_version)
            session_options: ONNX Runtime session configuration
        """
        self.mlflow_uri = mlflow_tracking_uri or settings.mlflow_tracking_uri
        self.max_versions = max_versions
        self.session_options = session_options or ort.SessionOptions()

        # Core registry: version_id -> (session, metadata).
        # The session is None for a version whose load failed.
        self.versions: Dict[str, Tuple[Optional[ort.InferenceSession], ModelVersion]] = {}
        self.active_version_id: Optional[str] = None
        self.previous_version_id: Optional[str] = None  # For fallback

        # A/B Testing
        self.ab_test_config = ABTestConfig(enabled=False)
        self.ab_test_stats = {"routed_to_a": 0, "routed_to_b": 0}

        # Metrics
        self.version_metrics: Dict[str, Dict] = {}
        # Serializes the async mutators. No code path awaits while holding
        # it, so the synchronous methods can mutate state safely when they
        # run on the same event loop.
        self._lock = asyncio.Lock()

        # Log the resolved URI; the raw argument may be None.
        logger.info(
            f"ModelVersionRegistry initialized (max_versions={max_versions}, "
            f"mlflow={self.mlflow_uri})"
        )

    async def load_version(
        self,
        version_id: str,
        stage: ModelStage = ModelStage.PRODUCTION,
        model_path: Optional[str] = None,
        metadata: Optional[Dict] = None
    ) -> bool:
        """
        Load a model version from a local ONNX path.

        A version whose earlier load failed is retried instead of being
        reported as already loaded. Exceeding max_versions only logs a
        warning; nothing is evicted automatically.

        Args:
            version_id: Unique version identifier (e.g., "v1", "prod-2025-10-22")
            stage: Model stage, as a ModelStage member or its string value
                (e.g. "Production")
            model_path: Path to ONNX model (defaults to models/<version_id>/model.onnx)
            metadata: Optional performance metrics

        Returns:
            True if the version is loaded and usable, False on error
        """
        # Accept the plain string form ("Production") as well as the enum.
        try:
            stage = ModelStage(stage)
        except ValueError:
            logger.error(f"Unknown model stage {stage!r} for version {version_id}")
            return False

        async with self._lock:
            existing = self.versions.get(version_id)
            if existing is not None and existing[0] is not None:
                logger.warning(f"Version {version_id} already loaded, skipping")
                return True
            # Otherwise the version is new or is a failed placeholder
            # (session is None); in both cases attempt the load.

            # Check capacity (soft limit)
            if len(self.versions) >= self.max_versions:
                logger.warning(
                    f"Version registry full (max={self.max_versions}), "
                    f"consider unloading unused versions"
                )
                # Could implement automatic LRU unload here

            try:
                # In production: load from MLflow registry
                # For now: simulate loading from local path
                if model_path is None:
                    model_path = f"models/{version_id}/model.onnx"

                # Load ONNX session
                logger.info(f"Loading model version {version_id} from {model_path}")
                session = ort.InferenceSession(
                    model_path,
                    self.session_options,
                    providers=['CPUExecutionProvider']
                )

                # Copy the metrics so later caller-side mutation does not
                # leak into the registry.
                version_meta = ModelVersion(
                    version_id=version_id,
                    stage=stage,
                    status=VersionStatus.READY,
                    performance_metrics=dict(metadata or {})
                )

                self.versions[version_id] = (session, version_meta)
                logger.info(f"✅ Version {version_id} loaded successfully (stage={stage.value})")
                return True

            except FileNotFoundError as e:
                logger.error(f"❌ Model file not found for {version_id}: {e}")
                return False
            except Exception as e:
                logger.error(f"❌ Error loading version {version_id}: {e}")
                # Keep a placeholder so the failure shows up in listings.
                version_meta = ModelVersion(
                    version_id=version_id,
                    stage=stage,
                    status=VersionStatus.ERROR,
                    error_message=str(e)
                )
                self.versions[version_id] = (None, version_meta)
                return False

    def _activate(self, version_id: str) -> bool:
        """Synchronously make version_id the active version.

        Shared by set_active_version, end_ab_test and model_context. Does
        not take the lock; async callers acquire it first.
        """
        entry = self.versions.get(version_id)
        if entry is None:
            logger.error(f"Version {version_id} not loaded")
            return False

        session, meta = entry
        if session is None:
            logger.error(f"Version {version_id} has error, cannot activate")
            return False

        current_id = self.active_version_id
        if current_id and current_id != version_id:
            # Remember the outgoing version for fallback and clear its
            # flag. Re-activating the same version leaves the fallback
            # pointer untouched.
            self.previous_version_id = current_id
            outgoing = self.versions.get(current_id)
            if outgoing is not None:
                outgoing[1].is_active = False

        self.active_version_id = version_id
        meta.is_active = True
        meta.updated_at = datetime.utcnow()

        logger.info(f"🔄 Active version switched to {version_id}")
        return True

    async def set_active_version(self, version_id: str) -> bool:
        """
        Set active prediction version.

        Args:
            version_id: Version to activate

        Returns:
            True if successful, False if the version is not loaded or
            failed to load
        """
        async with self._lock:
            return self._activate(version_id)

    async def unload_version(self, version_id: str) -> bool:
        """
        Unload a model version to free memory.

        The active version and the participants of a running A/B test
        cannot be unloaded.

        Args:
            version_id: Version to unload

        Returns:
            True if successful
        """
        async with self._lock:
            if version_id not in self.versions:
                return False

            if version_id == self.active_version_id:
                logger.warning(f"Cannot unload active version {version_id}")
                return False

            # predict() would otherwise route traffic to a missing version.
            ab = self.ab_test_config
            if ab.is_active() and version_id in (ab.version_a, ab.version_b):
                logger.warning(
                    f"Cannot unload version {version_id} while it is part of an active A/B test"
                )
                return False

            del self.versions[version_id]
            # Do not leave the fallback pointer dangling.
            if self.previous_version_id == version_id:
                self.previous_version_id = None

            logger.info(f"Version {version_id} unloaded (memory freed)")
            return True

    def start_ab_test(
        self,
        version_a: str,
        version_b: str,
        traffic_split: float = 0.5,
        duration_seconds: Optional[int] = None
    ) -> bool:
        """
        Start A/B test between two model versions.

        Args:
            version_a: Primary/control version
            version_b: Experimental/treatment version
            traffic_split: Proportion to route to version_b (0.5 = 50/50),
                clamped to [0.01, 0.99]
            duration_seconds: Test duration (None = unlimited)

        Returns:
            True if A/B test started, False if either version is missing
            or failed to load
        """
        if version_a not in self.versions or version_b not in self.versions:
            logger.error(f"One or both versions not loaded (A={version_a}, B={version_b})")
            return False

        # A failed load leaves a placeholder with no session.
        if self.versions[version_a][0] is None or self.versions[version_b][0] is None:
            logger.error(
                f"One or both versions failed to load (A={version_a}, B={version_b})"
            )
            return False

        traffic_split = max(0.01, min(0.99, traffic_split))  # Clamp to [0.01, 0.99]

        self.ab_test_config = ABTestConfig(
            enabled=True,
            version_a=version_a,
            version_b=version_b,
            traffic_split=traffic_split,
            started_at=datetime.utcnow(),
            duration_seconds=duration_seconds
        )
        self.ab_test_stats = {"routed_to_a": 0, "routed_to_b": 0}

        logger.info(
            f"🧪 A/B Test started: {version_a} vs {version_b} "
            f"(split={traffic_split:.1%}, duration={duration_seconds}s)"
        )
        return True

    def end_ab_test(self, winner: Optional[str] = None) -> bool:
        """
        End A/B test and optionally promote winner.

        Args:
            winner: Version to promote (None = no promotion)

        Returns:
            True if successful, False if the winner could not be activated
        """
        self.ab_test_config.enabled = False

        logger.info(
            f"🏁 A/B Test ended. Stats: "
            f"A={self.ab_test_stats['routed_to_a']}, "
            f"B={self.ab_test_stats['routed_to_b']}"
        )

        if winner and winner in self.versions:
            logger.info(f"✨ Promoting {winner} to active version")
            # Activate synchronously: asyncio.run() raises RuntimeError
            # when an event loop is already running.
            return self._activate(winner)

        return True

    @staticmethod
    def _run_inference(session, features: np.ndarray) -> np.ndarray:
        """Run one session on features and return its first output."""
        # ONNX input/output names
        input_name = session.get_inputs()[0].name
        output_names = [o.name for o in session.get_outputs()]

        # Ensure proper shape (add batch dimension if needed)
        if features.ndim == 2:
            features = np.expand_dims(features, axis=0)

        outputs = session.run(output_names, {input_name: features.astype(np.float32)})
        return outputs[0]

    def _fallback_candidates(self, failed_version_id: str) -> List[str]:
        """List versions to try after failed_version_id fails, in order."""
        candidates: List[str] = []
        # The active version comes first (relevant when an A/B-routed
        # version failed), then the previously active one.
        for candidate in (self.active_version_id, self.previous_version_id):
            if candidate and candidate != failed_version_id and candidate not in candidates:
                candidates.append(candidate)
        return candidates

    async def predict(
        self,
        features: np.ndarray,
        use_ab_routing: bool = True
    ) -> Tuple[np.ndarray, str]:
        """
        Run prediction on active version (or routed via A/B test).

        If the chosen version fails, the active version (when different)
        and then the previously active version are tried.

        Args:
            features: Input features; a 2-D array gets a leading batch
                dimension, and the input is cast to float32
            use_ab_routing: Route through A/B test if enabled

        Returns:
            Tuple of (prediction, version_id_used)

        Raises:
            RuntimeError: If no active version or prediction fails
        """
        if not self.active_version_id:
            raise RuntimeError("No active model version loaded")

        # Determine which version to use
        version_id = self.active_version_id

        if use_ab_routing and self.ab_test_config.is_active():
            # Route based on traffic split
            if np.random.random() < self.ab_test_config.traffic_split:
                version_id = self.ab_test_config.version_b
                self.ab_test_stats["routed_to_b"] += 1
            else:
                version_id = self.ab_test_config.version_a
                self.ab_test_stats["routed_to_a"] += 1

        # Run inference with fallback
        try:
            session, _ = self.versions[version_id]
            if session is None:
                raise RuntimeError(f"Version {version_id} has error")

            output = self._run_inference(session, features)

            logger.debug(f"Prediction via version {version_id}: {output.shape}")
            return output, version_id

        except Exception as e:
            logger.error(f"Prediction failed on {version_id}: {e}")

            for fallback_id in self._fallback_candidates(version_id):
                logger.warning(f"Falling back to version {fallback_id}")
                try:
                    session, _ = self.versions[fallback_id]
                    if session is not None:
                        return self._run_inference(session, features), fallback_id
                except Exception as fallback_error:
                    logger.error(f"Fallback failed: {fallback_error}")

            raise RuntimeError(f"Prediction failed with no fallback: {e}") from e

    def get_version_info(self, version_id: str) -> Optional[Dict]:
        """Get metadata for a specific version (None if it is not in the registry)"""
        if version_id not in self.versions:
            return None

        _, meta = self.versions[version_id]
        return meta.to_dict()

    def list_versions(self) -> List[Dict]:
        """List all registered versions with metadata (including failed loads)"""
        return [meta.to_dict() for _, meta in self.versions.values()]

    def get_registry_status(self) -> Dict:
        """Get overall registry status"""
        return {
            "active_version": self.active_version_id,
            "previous_version": self.previous_version_id,
            "loaded_versions": len(self.versions),
            "max_versions": self.max_versions,
            "versions": self.list_versions(),
            "ab_test": self.ab_test_config.to_dict() if self.ab_test_config.enabled else None,
            "ab_test_stats": self.ab_test_stats
        }

    @asynccontextmanager
    async def model_context(self, version_id: str):
        """
        Context manager for temporary version switching.

        On exit the original active version and the original fallback
        pointer are restored. If no version was active on entry, the
        temporary version stays active.

        Usage:
            async with registry.model_context("v2"):
                result = await registry.predict(features)
                # Uses v2 temporarily
            # Back to original active version

        Raises:
            RuntimeError: If version_id cannot be activated
        """
        original_version = self.active_version_id
        original_previous = self.previous_version_id

        if not await self.set_active_version(version_id):
            raise RuntimeError(f"Failed to switch to {version_id}")

        try:
            yield
        finally:
            if original_version:
                async with self._lock:
                    restored = self._activate(original_version)
                    # A temporary switch must not clobber the real
                    # fallback target; restore it if it still exists.
                    if restored and (
                        original_previous is None
                        or original_previous in self.versions
                    ):
                        self.previous_version_id = original_previous

472 

473 

async def create_version_registry(
    mlflow_uri: str = None,
    initial_versions: Optional[Dict[str, str]] = None
) -> ModelVersionRegistry:
    """
    Build a ModelVersionRegistry and optionally preload versions into it.

    Args:
        mlflow_uri: MLflow tracking server, passed to the registry as
            ``mlflow_tracking_uri``
        initial_versions: Mapping of {version_id: model_path} to preload;
            each entry is loaded with stage ``ModelStage.PRODUCTION``

    Returns:
        The new registry. The result of each individual load is not
        checked, and no version is activated.

    Example:
        registry = await create_version_registry(
            initial_versions={"v1": "models/v1.onnx"}
        )
        await registry.set_active_version("v1")
    """
    registry = ModelVersionRegistry(mlflow_tracking_uri=mlflow_uri)

    # Nothing to preload: hand back the empty registry.
    if not initial_versions:
        return registry

    for preload_id, preload_path in initial_versions.items():
        await registry.load_version(
            preload_id,
            stage=ModelStage.PRODUCTION,
            model_path=preload_path,
        )

    return registry