Coverage for services/inference/src/utils/model_versioning.py: 46%
203 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-25 16:18 +0000
1"""
2Model Versioning & A/B Testing Framework
3=========================================
5Manages multiple model versions with support for:
6- Loading different model versions from MLflow registry
7- Dynamic version switching without service restart (graceful reload)
8- A/B testing with traffic allocation
9- Version metadata and performance tracking
10- Fallback to previous version on failure
12T6.5: Model Versioning & A/B Testing
13"""
15import json
16import logging
17from dataclasses import dataclass, field, asdict
18from datetime import datetime
19from enum import Enum
20from typing import Dict, List, Optional, Tuple
21from contextlib import asynccontextmanager
22import asyncio
24import numpy as np
25import onnxruntime as ort
27from ..config import settings
# Module-scoped logger, named after this module so handlers/levels can be tuned per module.
logger = logging.getLogger(name=__name__)
class ModelStage(str, Enum):
    """Stage labels used for MLflow model versions.

    Because the enum mixes in ``str``, each member compares equal to its raw
    string value, e.g. ``ModelStage.STAGING == "Staging"``.
    """

    PRODUCTION = "Production"
    STAGING = "Staging"
    ARCHIVED = "Archived"
    NONE = "None"  # the literal string "None", not Python's None
class VersionStatus(str, Enum):
    """Lifecycle state of a model version inside the registry.

    Members are ``str`` subclasses, so they compare equal to (and serialize as)
    their lowercase string values.
    """

    LOADING = "loading"
    READY = "ready"
    STALE = "stale"
    ERROR = "error"
    DEPRECATED = "deprecated"
@dataclass
class ModelVersion:
    """Bookkeeping record for one model version held by the registry.

    ``to_dict`` returns a plain dict in which the two datetimes are ISO-8601
    strings and the two enum fields are their raw string values.
    """

    version_id: str  # unique version identifier (e.g. an MLflow version ID)
    stage: ModelStage = ModelStage.PRODUCTION
    created_at: datetime = field(default_factory=datetime.utcnow)  # naive UTC
    updated_at: datetime = field(default_factory=datetime.utcnow)  # naive UTC
    status: VersionStatus = VersionStatus.READY
    # Metric name -> value, e.g. "accuracy", "latency_ms", "throughput",
    # "cache_hit_rate".
    performance_metrics: Dict[str, float] = field(default_factory=dict)
    notes: str = ""
    is_active: bool = False
    error_message: Optional[str] = None  # populated for versions that failed to load

    def to_dict(self) -> Dict:
        """Serialize to a dict (datetimes -> ISO strings, enums -> values)."""
        serialized = asdict(self)
        # Overwrite the non-primitive fields in place; key order stays field order.
        serialized.update(
            created_at=self.created_at.isoformat(),
            updated_at=self.updated_at.isoformat(),
            stage=self.stage.value,
            status=self.status.value,
        )
        return serialized
@dataclass
class ABTestConfig:
    """Configuration of a two-arm A/B traffic split between model versions."""

    enabled: bool = False
    version_a: str = ""  # primary / control version ID
    version_b: str = ""  # experimental / treatment version ID
    traffic_split: float = 0.5  # share of traffic for version_b (0.5 -> 50/50)
    min_traffic_split: float = 0.01  # lower bound for traffic_split (1%); not validated here
    max_traffic_split: float = 0.99  # upper bound for traffic_split (99%); not validated here
    auto_winner: bool = False  # auto-promote the winner if enabled
    winner_threshold: float = 0.95  # confidence threshold for auto-promotion
    started_at: datetime = field(default_factory=datetime.utcnow)  # naive UTC
    duration_seconds: Optional[int] = None  # None -> no time limit
    metrics_to_compare: List[str] = field(default_factory=lambda: ["accuracy", "latency_ms"])

    def is_active(self) -> bool:
        """Return True while the test is enabled and its duration (if any) has not elapsed."""
        if self.enabled and self.duration_seconds is not None:
            age_seconds = (datetime.utcnow() - self.started_at).total_seconds()
            return age_seconds < self.duration_seconds
        # Either disabled (False) or enabled with no time limit (True).
        return self.enabled

    def to_dict(self) -> Dict:
        """Serialize to a dict, rendering ``started_at`` as an ISO-8601 string."""
        payload = asdict(self)
        payload.update(started_at=self.started_at.isoformat())
        return payload
class ModelVersionRegistry:
    """
    Manages model versions with versioning and A/B testing support.

    Architecture:
    - Maintains a registry of model versions: version_id -> (ONNX Runtime
      session, ModelVersion metadata). Versions whose load failed are kept
      with a ``None`` session and status ERROR.
    - Supports multiple concurrently loaded versions (memory overhead vs flexibility)
    - Routes predictions based on the A/B test config
    - Falls back to another loaded version (active, then previous) when a
      prediction fails

    Concurrency:
        Async mutators serialize on ``self._lock``. No critical section awaits
        while holding the lock, so the synchronous state helpers
        (``_activate``, ``_restore_active``) cannot interleave with another
        coroutine's critical section (assuming a single event-loop thread) and
        can safely be called from synchronous methods such as ``end_ab_test``.

    Usage:
        registry = ModelVersionRegistry(mlflow_tracking_uri)
        await registry.load_version("v1", "Production")
        await registry.set_active_version("v1")
        prediction, version_used = await registry.predict(features)
    """

    def __init__(
        self,
        mlflow_tracking_uri: Optional[str] = None,
        max_versions: int = 5,
        session_options: Optional[ort.SessionOptions] = None
    ):
        """
        Initialize version registry.

        Args:
            mlflow_tracking_uri: MLflow server endpoint (defaults to settings.mlflow_tracking_uri)
            max_versions: Advisory limit on loaded versions (exceeding it only logs a warning)
            session_options: ONNX Runtime session configuration
        """
        self.mlflow_uri = mlflow_tracking_uri or settings.mlflow_tracking_uri
        self.max_versions = max_versions
        self.session_options = session_options or ort.SessionOptions()

        # Core registry: version_id -> (session or None if load failed, metadata)
        self.versions: Dict[str, Tuple[Optional[ort.InferenceSession], ModelVersion]] = {}
        self.active_version_id: Optional[str] = None
        self.previous_version_id: Optional[str] = None  # For fallback

        # A/B Testing
        self.ab_test_config = ABTestConfig(enabled=False)
        self.ab_test_stats = {"routed_to_a": 0, "routed_to_b": 0}

        # Metrics
        self.version_metrics: Dict[str, Dict] = {}
        self._lock = asyncio.Lock()  # Serializes the async mutators

        # Log the resolved URI, not the (possibly None) argument.
        logger.info(
            f"ModelVersionRegistry initialized (max_versions={max_versions}, "
            f"mlflow={self.mlflow_uri})"
        )

    # ------------------------------------------------------------------
    # Internal helpers (synchronous, no awaits)
    # ------------------------------------------------------------------

    def _is_ready(self, version_id: Optional[str]) -> bool:
        """Return True if *version_id* is registered with a usable session."""
        entry = self.versions.get(version_id)
        return entry is not None and entry[0] is not None

    def _activate(self, version_id: str) -> bool:
        """Make *version_id* the active version; return False if it cannot be used.

        The outgoing active version becomes ``previous_version_id`` (fallback
        target) and its ``is_active`` flag is cleared, so at most one version
        reports itself as active.
        """
        if version_id not in self.versions:
            logger.error(f"Version {version_id} not loaded")
            return False

        session, meta = self.versions[version_id]
        if session is None:
            logger.error(f"Version {version_id} has error, cannot activate")
            return False

        now = datetime.utcnow()
        if version_id == self.active_version_id:
            # Re-activating the current version must not overwrite the fallback.
            meta.is_active = True
            meta.updated_at = now
            return True

        outgoing = self.active_version_id
        if outgoing is not None:
            self.previous_version_id = outgoing
            if outgoing in self.versions:
                _, outgoing_meta = self.versions[outgoing]
                outgoing_meta.is_active = False
                outgoing_meta.updated_at = now

        self.active_version_id = version_id
        meta.is_active = True
        meta.updated_at = now

        logger.info(f"🔄 Active version switched to {version_id}")
        return True

    def _restore_active(self, active_id: Optional[str], previous_id: Optional[str]) -> None:
        """Reinstate a saved (active, previous) pair after a temporary switch."""
        current = self.active_version_id
        if active_id != current:
            if active_id is None:
                # There was no active version before the switch: deactivate again.
                if current in self.versions:
                    self.versions[current][1].is_active = False
                self.active_version_id = None
            elif self._is_ready(active_id):
                self._activate(active_id)
            else:
                logger.warning(
                    f"Original version {active_id} is no longer available; "
                    f"keeping {current} active"
                )
        # _activate recorded the temporary version as "previous"; put the real one back
        # (only if it is still registered).
        self.previous_version_id = previous_id if previous_id in self.versions else None

    def _run_session(self, version_id: str, model_input: np.ndarray) -> np.ndarray:
        """Run one inference on *version_id* and return its first output.

        Raises:
            RuntimeError: If the version is not registered or failed to load.
        """
        entry = self.versions.get(version_id)
        if entry is None:
            raise RuntimeError(f"Version {version_id} is not loaded")
        session = entry[0]
        if session is None:
            raise RuntimeError(f"Version {version_id} has error")

        input_name = session.get_inputs()[0].name
        output_names = [o.name for o in session.get_outputs()]
        outputs = session.run(output_names, {input_name: model_input})

        logger.debug(f"Prediction via version {version_id}: {outputs[0].shape}")
        return outputs[0]

    def _fallback_candidates(self, failed_id: str) -> List[str]:
        """Versions to try after *failed_id* fails: the active one, then the previous one."""
        candidates: List[str] = []
        for candidate in (self.active_version_id, self.previous_version_id):
            if candidate and candidate != failed_id and candidate not in candidates:
                candidates.append(candidate)
        return candidates

    # ------------------------------------------------------------------
    # Version lifecycle
    # ------------------------------------------------------------------

    async def load_version(
        self,
        version_id: str,
        stage: ModelStage = ModelStage.PRODUCTION,
        model_path: Optional[str] = None,
        metadata: Optional[Dict] = None
    ) -> bool:
        """
        Load a model version from a local ONNX path.

        Args:
            version_id: Unique version identifier (e.g., "v1", "prod-2025-10-22")
            stage: MLflow stage, as a ModelStage or its string value ("Production", ...)
            model_path: Path to ONNX model (defaults to models/<version_id>/model.onnx)
            metadata: Optional performance metrics (copied into the version metadata)

        Returns:
            True if the version is loaded (or already was), False on error.
            Load errors other than FileNotFoundError are recorded as an ERROR
            entry; calling this method again retries the load.

        Raises:
            ValueError: If ``stage`` is not a valid ModelStage value.

        Note:
            ``max_versions`` is advisory: exceeding it logs a warning, but the
            version is still loaded.
        """
        # Accept plain strings (as in the class usage example); previously a str
        # stage crashed on ``stage.value`` after a successful load.
        stage = ModelStage(stage)

        async with self._lock:
            existing = self.versions.get(version_id)
            if existing is not None and existing[0] is not None:
                logger.warning(f"Version {version_id} already loaded, skipping")
                return True
            if existing is not None:
                logger.info(f"Retrying load of previously failed version {version_id}")

            # A failed entry that is about to be replaced does not count toward capacity.
            occupied = len(self.versions) - (1 if existing is not None else 0)
            if occupied >= self.max_versions:
                logger.warning(
                    f"Version registry full (max={self.max_versions}), "
                    f"consider unloading unused versions"
                )

            # In production: load from MLflow registry. For now: local path.
            if model_path is None:
                model_path = f"models/{version_id}/model.onnx"

            try:
                logger.info(f"Loading model version {version_id} from {model_path}")
                session = ort.InferenceSession(
                    model_path,
                    self.session_options,
                    providers=['CPUExecutionProvider']
                )
            except FileNotFoundError as e:
                logger.error(f"❌ Model file not found for {version_id}: {e}")
                return False
            except Exception as e:
                logger.error(f"❌ Error loading version {version_id}: {e}")
                self.versions[version_id] = (
                    None,
                    ModelVersion(
                        version_id=version_id,
                        stage=stage,
                        status=VersionStatus.ERROR,
                        error_message=str(e),
                    ),
                )
                return False

            self.versions[version_id] = (
                session,
                ModelVersion(
                    version_id=version_id,
                    stage=stage,
                    status=VersionStatus.READY,
                    performance_metrics=dict(metadata or {}),
                ),
            )
            logger.info(f"✅ Version {version_id} loaded successfully (stage={stage.value})")
            return True

    async def set_active_version(self, version_id: str) -> bool:
        """
        Set active prediction version.

        Args:
            version_id: Version to activate

        Returns:
            True if the version is now active, False if it is not loaded or
            failed to load.
        """
        async with self._lock:
            return self._activate(version_id)

    async def unload_version(self, version_id: str) -> bool:
        """
        Unload a model version to free memory.

        Refuses to unload the active version or a version that is part of a
        running A/B test.

        Args:
            version_id: Version to unload

        Returns:
            True if the version was removed
        """
        async with self._lock:
            if version_id not in self.versions:
                return False

            if version_id == self.active_version_id:
                logger.warning(f"Cannot unload active version {version_id}")
                return False

            cfg = self.ab_test_config
            if cfg.is_active() and version_id in (cfg.version_a, cfg.version_b):
                logger.warning(f"Cannot unload version {version_id} while it is in an active A/B test")
                return False

            del self.versions[version_id]
            if self.previous_version_id == version_id:
                # Don't leave the fallback pointing at a version that no longer exists.
                self.previous_version_id = None

            logger.info(f"Version {version_id} unloaded (memory freed)")
            return True

    # ------------------------------------------------------------------
    # A/B testing
    # ------------------------------------------------------------------

    def start_ab_test(
        self,
        version_a: str,
        version_b: str,
        traffic_split: float = 0.5,
        duration_seconds: Optional[int] = None
    ) -> bool:
        """
        Start A/B test between two model versions.

        Args:
            version_a: Primary/control version
            version_b: Experimental/treatment version
            traffic_split: Proportion to route to version_b (0.5 = 50/50), clamped
                to the config's [min_traffic_split, max_traffic_split]
            duration_seconds: Test duration (None = unlimited)

        Returns:
            True if A/B test started; False if either version is not loaded or
            is in error state.
        """
        if version_a not in self.versions or version_b not in self.versions:
            logger.error(f"One or both versions not loaded (A={version_a}, B={version_b})")
            return False

        broken = [v for v in (version_a, version_b) if not self._is_ready(v)]
        if broken:
            logger.error(f"Cannot A/B test versions in error state: {broken}")
            return False

        config = ABTestConfig(
            enabled=True,
            version_a=version_a,
            version_b=version_b,
            started_at=datetime.utcnow(),
            duration_seconds=duration_seconds
        )
        # Clamp using the config's own bounds instead of duplicated literals.
        config.traffic_split = max(
            config.min_traffic_split, min(config.max_traffic_split, traffic_split)
        )
        self.ab_test_config = config
        self.ab_test_stats = {"routed_to_a": 0, "routed_to_b": 0}

        logger.info(
            f"🧪 A/B Test started: {version_a} vs {version_b} "
            f"(split={config.traffic_split:.1%}, duration={duration_seconds}s)"
        )
        return True

    def end_ab_test(self, winner: Optional[str] = None) -> bool:
        """
        End A/B test and optionally promote winner.

        Safe to call from inside a running event loop: promotion is done
        synchronously instead of via ``asyncio.run`` (which raises RuntimeError
        when a loop is already running).

        Args:
            winner: Version to promote (None = no promotion)

        Returns:
            True if the test ended and, when requested, the winner was promoted;
            False if the winner is not loaded or cannot be activated.
        """
        self.ab_test_config.enabled = False

        logger.info(
            f"🏁 A/B Test ended. Stats: "
            f"A={self.ab_test_stats['routed_to_a']}, "
            f"B={self.ab_test_stats['routed_to_b']}"
        )

        if not winner:
            return True

        if winner not in self.versions:
            logger.warning(f"Winner {winner} is not loaded; no promotion performed")
            return False

        logger.info(f"✨ Promoting {winner} to active version")
        return self._activate(winner)

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------

    async def predict(
        self,
        features: np.ndarray,
        use_ab_routing: bool = True
    ) -> Tuple[np.ndarray, str]:
        """
        Run prediction on active version (or routed via A/B test).

        Args:
            features: Input features (array-like; a 2-D input gets a leading
                batch dimension added). Shape otherwise must match the model.
            use_ab_routing: Route through A/B test if enabled

        Returns:
            Tuple of (first model output, version_id_used)

        Raises:
            RuntimeError: If no active version, or the prediction fails on the
                chosen version and on every fallback (active, then previous).
        """
        if not self.active_version_id:
            raise RuntimeError("No active model version loaded")

        # Determine which version to use
        version_id = self.active_version_id
        if use_ab_routing and self.ab_test_config.is_active():
            # Route based on traffic split: traffic_split is version_b's share.
            if np.random.random() < self.ab_test_config.traffic_split:
                version_id = self.ab_test_config.version_b
                self.ab_test_stats["routed_to_b"] += 1
            else:
                version_id = self.ab_test_config.version_a
                self.ab_test_stats["routed_to_a"] += 1

        # Normalize the input once so primary and fallback runs see the same tensor.
        features = np.asarray(features)
        if features.ndim == 2:
            features = np.expand_dims(features, axis=0)
        model_input = features.astype(np.float32)

        try:
            return self._run_session(version_id, model_input), version_id
        except Exception as e:
            primary_error = e  # `e` is unbound after the except block
            logger.error(f"Prediction failed on {version_id}: {e}")

        for fallback_id in self._fallback_candidates(version_id):
            logger.warning(f"Falling back to version {fallback_id}")
            try:
                return self._run_session(fallback_id, model_input), fallback_id
            except Exception as fallback_error:
                logger.error(f"Fallback failed: {fallback_error}")

        raise RuntimeError(
            f"Prediction failed with no fallback: {primary_error}"
        ) from primary_error

    # ------------------------------------------------------------------
    # Introspection
    # ------------------------------------------------------------------

    def get_version_info(self, version_id: str) -> Optional[Dict]:
        """Return metadata for *version_id* as a dict, or None if not registered."""
        if version_id not in self.versions:
            return None

        _, meta = self.versions[version_id]
        return meta.to_dict()

    def list_versions(self) -> List[Dict]:
        """List all registered versions (including failed loads) with metadata."""
        return [meta.to_dict() for _, meta in self.versions.values()]

    def get_registry_status(self) -> Dict:
        """Return a snapshot of the registry state (stats are copied, not shared)."""
        return {
            "active_version": self.active_version_id,
            "previous_version": self.previous_version_id,
            "loaded_versions": len(self.versions),
            "max_versions": self.max_versions,
            "versions": self.list_versions(),
            "ab_test": self.ab_test_config.to_dict() if self.ab_test_config.enabled else None,
            "ab_test_stats": dict(self.ab_test_stats)
        }

    @asynccontextmanager
    async def model_context(self, version_id: str):
        """
        Context manager for temporary version switching.

        On exit, both the active and the previous (fallback) version are
        restored to what they were on entry. If there was no active version
        before, none is active afterwards.

        Usage:
            async with registry.model_context("v2"):
                result = await registry.predict(features)
                # Uses v2 temporarily
            # Back to original active version

        Raises:
            RuntimeError: If *version_id* cannot be activated.
        """
        async with self._lock:
            original_active = self.active_version_id
            original_previous = self.previous_version_id
            if not self._activate(version_id):
                raise RuntimeError(f"Failed to switch to {version_id}")
        try:
            yield
        finally:
            async with self._lock:
                self._restore_active(original_active, original_previous)
async def create_version_registry(
    mlflow_uri: str = None,
    initial_versions: Optional[Dict[str, str]] = None
) -> ModelVersionRegistry:
    """
    Build a ModelVersionRegistry and optionally preload model versions.

    Every entry in ``initial_versions`` is loaded with stage PRODUCTION. A failed
    preload does not abort the remaining ones: ``load_version`` reports the
    failure through its return value and logs, and that return value is not
    inspected here. No version is activated.

    Args:
        mlflow_uri: MLflow tracking server (defaults to settings.mlflow_tracking_uri)
        initial_versions: Mapping of {version_id: model_path} to preload

    Returns:
        The new ModelVersionRegistry

    Example:
        registry = await create_version_registry(
            initial_versions={"v1": "models/v1.onnx"}
        )
        await registry.set_active_version("v1")
    """
    registry = ModelVersionRegistry(mlflow_tracking_uri=mlflow_uri)

    preload = initial_versions or {}
    for vid, path in preload.items():
        await registry.load_version(
            vid,
            stage=ModelStage.PRODUCTION,
            model_path=path
        )

    return registry