Coverage for services/inference/src/routers/model_metadata.py: 0%
233 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-25 16:18 +0000
"""
Model Metadata & Graceful Reload Endpoints
==========================================

T6.8: Model metadata endpoint exposing version, stage, performance metrics
T6.9: Graceful reload functionality with signal handlers

Features:
- Model info endpoint: /model/info
- Version info endpoint: /model/versions
- Performance endpoint: /model/performance
- Reload endpoint: /model/reload (POST)
- Signal handler for SIGHUP (Unix) and graceful shutdown
"""
16import logging
17import signal
18import os
19import json
20from datetime import datetime
21from typing import Dict, List, Optional, Callable
22from dataclasses import dataclass, field, asdict
23import asyncio
24from enum import Enum
26from pydantic import BaseModel, Field
27from fastapi import APIRouter, HTTPException, status, Depends
28from contextlib import asynccontextmanager
30logger = logging.getLogger(__name__)
33# ============================================================================
34# SCHEMAS - Model Metadata
35# ============================================================================
class ModelInfoResponse(BaseModel):
    """Complete model information returned by GET /model/info.

    Groups model identity, headline performance metrics, lifecycle
    counters, and health status into a single response payload.

    NOTE(review): the field name ``model_name`` falls inside pydantic v2's
    protected ``model_`` namespace and may emit a warning — confirm the
    pydantic version in use.
    """
    # --- Identity ---
    active_version: str = Field(..., description="Currently active model version")
    stage: str = Field(..., description="Model stage (Production, Staging, Archived)")
    model_name: str = Field(default="heimdall-inference", description="Model name")

    # Performance metrics (optional: may be absent from loader metadata)
    accuracy: Optional[float] = Field(None, description="Model accuracy (0-1)")
    latency_p95_ms: Optional[float] = Field(None, description="95th percentile latency")
    cache_hit_rate: Optional[float] = Field(None, description="Cache hit rate (0-1)")

    # Lifecycle counters
    loaded_at: datetime = Field(..., description="When model was loaded")
    uptime_seconds: float = Field(..., description="Time since loading")
    last_prediction_at: Optional[datetime] = Field(None, description="Last prediction time")
    predictions_total: int = Field(default=0, description="Total predictions served")
    predictions_successful: int = Field(default=0, description="Successful predictions")
    predictions_failed: int = Field(default=0, description="Failed predictions")

    # Status flags
    is_ready: bool = Field(default=True, description="Ready to serve predictions")
    health_status: str = Field(default="healthy", description="healthy|degraded|unhealthy")
    error_message: Optional[str] = Field(None, description="Error details if unhealthy")
class ModelVersionInfo(BaseModel):
    """Information about a single model version.

    One entry in the GET /model/versions listing; ``performance_metrics``
    is a free-form name -> value map populated by the registry.
    """
    version_id: str              # registry version identifier
    stage: str                   # e.g. Production / Staging / Archived
    status: str                  # registry status string
    created_at: datetime         # when the version was registered
    updated_at: datetime         # last registry update
    is_active: bool              # True if this is the currently served version
    performance_metrics: Dict[str, float] = Field(default_factory=dict)
    notes: Optional[str] = None  # optional human-written annotation
class ModelVersionListResponse(BaseModel):
    """List of all available model versions (GET /model/versions)."""
    active_version: str     # version currently being served
    total_versions: int     # len(versions)
    versions: List[ModelVersionInfo] = Field(default_factory=list)
    # NOTE(review): datetime.utcnow is naive and deprecated since Python
    # 3.12 — consider datetime.now(timezone.utc); changing it would alter
    # serialized output, so left as-is here.
    timestamp: datetime = Field(default_factory=datetime.utcnow)
class ModelPerformanceMetrics(BaseModel):
    """Model performance metrics (GET /model/performance).

    All latency fields are in milliseconds; rates are fractions in [0, 1].
    """
    # Latency distribution
    inference_latency_ms: float = Field(..., description="Mean inference latency")
    p50_latency_ms: float = Field(..., description="50th percentile latency")
    p95_latency_ms: float = Field(..., description="95th percentile latency")
    p99_latency_ms: float = Field(..., description="99th percentile latency")

    # Throughput / quality rates
    throughput_samples_per_second: float = Field(..., description="Prediction throughput")
    cache_hit_rate: float = Field(..., description="Cache hit rate (0-1)")
    success_rate: float = Field(..., description="Prediction success rate (0-1)")

    # Cumulative counters
    predictions_total: int = Field(..., description="Total predictions served")
    requests_total: int = Field(..., description="Total requests received")
    errors_total: int = Field(..., description="Total errors")

    uptime_seconds: float = Field(..., description="Service uptime")
    timestamp: datetime = Field(default_factory=datetime.utcnow)
class ModelReloadRequest(BaseModel):
    """Request body for POST /model/reload.

    ``force=True`` skips request draining — active requests may observe the
    model swap mid-flight.
    """
    version_id: Optional[str] = Field(None, description="Specific version to load")
    stage: Optional[str] = Field(default="Production", description="Model stage")
    force: bool = Field(default=False, description="Force reload without draining")
class ModelReloadResponse(BaseModel):
    """Response body for POST /model/reload."""
    success: bool                             # True when the swap completed
    message: str                              # human-readable summary
    previous_version: Optional[str] = None    # version before the reload
    new_version: Optional[str] = None         # version after the reload
    reload_time_ms: float                     # wall-clock reload duration
    requests_drained: int = Field(default=0, description="Requests gracefully completed")
    timestamp: datetime = Field(default_factory=datetime.utcnow)
119# ============================================================================
120# GRACEFUL RELOAD MANAGER
121# ============================================================================
@dataclass
class ReloadState:
    """Mutable bookkeeping for an in-flight reload operation."""
    is_reloading: bool = False                    # reload currently running?
    reload_start_time: Optional[datetime] = None  # when the reload/drain began
    active_requests: int = 0                      # snapshot taken at reload start
    drained_requests: int = 0                     # requests finished during drain
    drain_timeout_seconds: float = 30.0           # max time to wait for drain

    @property
    def drain_remaining_seconds(self) -> float:
        """Seconds left in the drain window (full window before a reload starts)."""
        if self.reload_start_time is None:
            return self.drain_timeout_seconds
        spent = (datetime.utcnow() - self.reload_start_time).total_seconds()
        return max(0, self.drain_timeout_seconds - spent)

    @property
    def is_drain_timeout(self) -> bool:
        """True once the drain window has fully elapsed."""
        return self.drain_remaining_seconds <= 0
class ModelReloadManager:
    """
    Manages graceful model reload with request draining.

    Architecture:
    - Tracks active requests during reload
    - Drains requests gracefully (waits for completion)
    - Replaces model without dropping connections
    - Uses signal handlers (SIGHUP for Unix)
    - Configurable drain timeout

    Usage:
        reload_manager = ModelReloadManager(
            model_loader=model_loader,
            drain_timeout_seconds=30.0
        )

        # Register signal handler
        reload_manager.setup_signal_handlers()

        # On reload request
        success = await reload_manager.reload_model(version_id="v2")
    """

    def __init__(
        self,
        model_loader,  # ONNXModelLoader instance
        drain_timeout_seconds: float = 30.0,
        on_reload_complete: Optional[Callable] = None
    ):
        """
        Initialize reload manager.

        Args:
            model_loader: ONNX model loader instance
            drain_timeout_seconds: Max time to drain requests
            on_reload_complete: Callback after successful reload; invoked with
                the new version id. Exceptions from it are logged, not raised.
        """
        self.model_loader = model_loader
        self.reload_state = ReloadState(drain_timeout_seconds=drain_timeout_seconds)
        self.on_reload_complete = on_reload_complete

        self._request_lock = asyncio.Lock()
        self._active_request_count = 0
        # Strong reference to a signal-triggered reload task: the event loop
        # only keeps weak refs to tasks, so an unreferenced task could be
        # garbage-collected before it finishes.
        self._reload_task: Optional[asyncio.Task] = None

        logger.info(f"ModelReloadManager initialized (drain_timeout={drain_timeout_seconds}s)")

    def setup_signal_handlers(self):
        """
        Setup OS signal handlers for reload.

        Unix signals:
            SIGHUP (1): Trigger reload
            SIGTERM (15): Graceful shutdown
            SIGINT (2): Immediate shutdown

        Windows: No signal support, use HTTP endpoint instead

        NOTE: signal.signal() must run in the main thread; a ValueError from
        another thread is caught and logged below.
        """
        if os.name == "nt":  # Windows
            logger.info("Signal handlers not available on Windows, use HTTP endpoints")
            return

        try:
            signal.signal(signal.SIGHUP, self._handle_sighup)
            signal.signal(signal.SIGTERM, self._handle_sigterm)
            logger.info("Signal handlers registered (SIGHUP, SIGTERM)")
        except Exception as e:
            logger.error(f"Failed to setup signal handlers: {e}")

    def _handle_sighup(self, signum, frame):
        """Handle SIGHUP (reload model).

        FIX: asyncio.create_task() requires a *running* event loop and raises
        RuntimeError from a plain signal handler when none is running in this
        thread. Look the loop up explicitly and degrade gracefully if absent.
        (Registering via loop.add_signal_handler() is the fully loop-safe
        alternative when the loop is known at setup time.)
        """
        logger.info("🔄 SIGHUP received, triggering model reload")
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            logger.error("SIGHUP ignored: no running event loop to schedule reload")
            return
        # Retain the task so it is not garbage-collected mid-reload.
        self._reload_task = loop.create_task(self.reload_model())

    def _handle_sigterm(self, signum, frame):
        """Handle SIGTERM (graceful shutdown)"""
        logger.info("⛔ SIGTERM received, initiating graceful shutdown")
        # In production: would trigger application shutdown
        raise KeyboardInterrupt("SIGTERM received")

    async def increment_active_requests(self):
        """Increment active request counter.

        Raises:
            RuntimeError: if a reload is in progress (new work is rejected
                so the drain can complete).
        """
        async with self._request_lock:
            if self.reload_state.is_reloading:
                raise RuntimeError("Model reload in progress, no new requests accepted")
            self._active_request_count += 1

    async def decrement_active_requests(self):
        """Decrement active request counter.

        FIX: only count a completing request as "drained" while a reload is
        actually in progress; previously every completion inflated the
        drained_requests counter even during normal operation.
        """
        async with self._request_lock:
            self._active_request_count = max(0, self._active_request_count - 1)
            if self.reload_state.is_reloading:
                self.reload_state.drained_requests += 1

    async def reload_model(
        self,
        version_id: Optional[str] = None,
        force: bool = False
    ) -> bool:
        """
        Reload model with graceful request draining.

        Process:
        1. Mark reload_state.is_reloading = True
        2. Stop accepting new requests
        3. Wait for active requests to complete (with timeout)
        4. Load new model version
        5. Mark reload_state.is_reloading = False
        6. Resume accepting requests

        Args:
            version_id: Version to load (None = auto-select Production)
            force: Force reload without draining (use with caution)

        Returns:
            True if successful
        """
        start_time = datetime.utcnow()

        try:
            # Step 1: Begin reload (refuse overlapping reloads)
            async with self._request_lock:
                if self.reload_state.is_reloading:
                    logger.warning("Reload already in progress")
                    return False

                self.reload_state.is_reloading = True
                self.reload_state.reload_start_time = start_time
                self.reload_state.active_requests = self._active_request_count
                self.reload_state.drained_requests = 0

            logger.info(f"🔄 Starting model reload (version={version_id}, force={force})")

            # Step 2: Drain requests (unless force=True)
            if not force:
                await self._drain_requests()

            # Step 3: Load new model
            previous_version = None
            try:
                previous_version = self.model_loader.get_current_version()

                if version_id:
                    success = await self.model_loader.reload(version_id)
                else:
                    # Auto-select Production stage
                    success = await self.model_loader.reload()

                if not success:
                    raise RuntimeError("Model loading failed")

                new_version = self.model_loader.get_current_version()

            except Exception as e:
                logger.error(f"❌ Model reload failed: {e}")
                raise  # outer handler clears is_reloading and returns False

            # Step 4: Complete reload
            reload_time_ms = (datetime.utcnow() - start_time).total_seconds() * 1000

            async with self._request_lock:
                self.reload_state.is_reloading = False

            logger.info(
                f"✅ Model reload complete: {previous_version} → {new_version} "
                f"({reload_time_ms:.1f}ms, drained={self.reload_state.drained_requests})"
            )

            # Step 5: Callback (best-effort; failures never fail the reload)
            if self.on_reload_complete:
                try:
                    self.on_reload_complete(new_version)
                except Exception as callback_err:
                    logger.error(f"Reload callback error: {callback_err}")

            return True

        except Exception as e:
            logger.error(f"❌ Reload failed: {e}")
            # Always clear the flag so the service resumes accepting requests.
            async with self._request_lock:
                self.reload_state.is_reloading = False
            return False

    async def _drain_requests(self):
        """
        Wait for active requests to complete.

        Polls active request count until:
        - All requests complete, OR
        - Drain timeout expires

        Returns immediately if no active requests.
        """
        logger.info(f"Draining {self._active_request_count} active requests...")

        start_time = datetime.utcnow()
        last_logged_bucket = 0  # 5-second buckets already reported

        while self._active_request_count > 0:
            # Check timeout
            if self.reload_state.is_drain_timeout:
                logger.warning(
                    f"Drain timeout after {self.reload_state.drain_timeout_seconds}s "
                    f"({self._active_request_count} requests still active)"
                )
                break

            # Log progress once per 5s window.
            # FIX: the old check (int(elapsed) % 5 == 0) fired on every 0.1s
            # poll within each matching second (~10 duplicate log lines).
            elapsed = (datetime.utcnow() - start_time).total_seconds()
            bucket = int(elapsed) // 5
            if bucket > last_logged_bucket:
                last_logged_bucket = bucket
                logger.info(f"Draining... {self._active_request_count} active requests")

            # Wait before retry
            await asyncio.sleep(0.1)

        drain_time_ms = (datetime.utcnow() - start_time).total_seconds() * 1000
        logger.info(f"Request drain complete: {drain_time_ms:.1f}ms")

    @asynccontextmanager
    async def request_context(self):
        """
        Context manager for tracking active requests.

        Usage:
            async with reload_manager.request_context():
                # Do prediction work
                result = await model_loader.predict(features)
                return result

        Prevents reload while request is active.
        """
        try:
            await self.increment_active_requests()
            yield
        finally:
            await self.decrement_active_requests()

    def get_reload_status(self) -> Dict:
        """Get current reload status as a plain dict (for health endpoints)."""
        return {
            "is_reloading": self.reload_state.is_reloading,
            "active_requests": self._active_request_count,
            "drained_requests": self.reload_state.drained_requests,
            "drain_timeout_seconds": self.reload_state.drain_timeout_seconds,
            "drain_remaining_seconds": self.reload_state.drain_remaining_seconds
        }
391# ============================================================================
392# FASTAPI ROUTER - METADATA & RELOAD ENDPOINTS
393# ============================================================================
class ModelMetadataRouter:
    """
    FastAPI router bundling model metadata and reload endpoints.

    Endpoints:
        GET  /model/info        - Current model information
        GET  /model/versions    - Available versions
        GET  /model/performance - Performance metrics
        POST /model/reload      - Trigger graceful reload
    """

    def __init__(
        self,
        model_loader,
        reload_manager: ModelReloadManager,
        metrics_manager = None
    ):
        """
        Initialize router.

        Args:
            model_loader: ONNX model loader
            reload_manager: Graceful reload manager
            metrics_manager: Prometheus metrics
        """
        self.model_loader = model_loader
        self.reload_manager = reload_manager
        self.metrics_manager = metrics_manager

        self.router = APIRouter(prefix="/model", tags=["model"])
        self._register_routes()

    def _register_routes(self):
        """Attach every endpoint to the router from a small route table."""
        route_table = (
            ("/info", self.get_model_info, "GET", ModelInfoResponse),
            ("/versions", self.get_model_versions, "GET", ModelVersionListResponse),
            ("/performance", self.get_performance_metrics, "GET", ModelPerformanceMetrics),
            ("/reload", self.reload_model, "POST", ModelReloadResponse),
        )
        for path, endpoint, method, schema in route_table:
            self.router.add_api_route(
                path,
                endpoint,
                methods=[method],
                response_model=schema
            )

    async def get_model_info(self) -> ModelInfoResponse:
        """
        GET /model/info

        Get current model information and status.
        """
        try:
            meta = self.model_loader.get_metadata()
            current_version = self.model_loader.get_current_version()
            uptime = getattr(self.model_loader, 'uptime_seconds', 0)

            return ModelInfoResponse(
                active_version=current_version,
                stage="Production",
                model_name="heimdall-inference",
                # Metadata values fall back to static defaults when absent.
                accuracy=meta.get("accuracy", 0.95),
                latency_p95_ms=meta.get("latency_p95_ms", 150.0),
                cache_hit_rate=meta.get("cache_hit_rate", 0.82),
                # NOTE(review): reports "now" rather than the actual load
                # timestamp — confirm whether the loader exposes one.
                loaded_at=datetime.utcnow(),
                uptime_seconds=uptime,
                predictions_total=getattr(self.model_loader, 'predictions_total', 0),
                predictions_successful=getattr(self.model_loader, 'predictions_successful', 0),
                is_ready=True,
                health_status="healthy"
            )
        except Exception as e:
            logger.error(f"Error getting model info: {e}")
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail=f"Model info unavailable: {str(e)}"
            )

    async def get_model_versions(self) -> ModelVersionListResponse:
        """
        GET /model/versions

        List all available model versions.
        """
        try:
            # In real implementation: fetch from MLflow registry
            known_versions: List[ModelVersionInfo] = []
            current = self.model_loader.get_current_version()

            return ModelVersionListResponse(
                active_version=current,
                total_versions=len(known_versions),
                versions=known_versions
            )
        except Exception as e:
            logger.error(f"Error getting model versions: {e}")
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail=f"Version list unavailable: {str(e)}"
            )

    async def get_performance_metrics(self) -> ModelPerformanceMetrics:
        """
        GET /model/performance

        Get model performance metrics from Prometheus.
        """
        try:
            if not self.metrics_manager:
                raise RuntimeError("Metrics manager not configured")

            # Fetch metrics from prometheus_client
            # (static placeholders except the latency read off the manager)
            return ModelPerformanceMetrics(
                inference_latency_ms=getattr(self.metrics_manager, 'inference_latency', 150.0),
                p50_latency_ms=100.0,
                p95_latency_ms=200.0,
                p99_latency_ms=250.0,
                throughput_samples_per_second=6.5,
                cache_hit_rate=0.82,
                success_rate=0.999,
                predictions_total=10000,
                requests_total=10050,
                errors_total=50,
                uptime_seconds=3600.0
            )
        except Exception as e:
            logger.error(f"Error getting performance metrics: {e}")
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail=f"Metrics unavailable: {str(e)}"
            )

    async def reload_model(self, request: ModelReloadRequest) -> ModelReloadResponse:
        """
        POST /model/reload

        Trigger graceful model reload.

        Query:
            version_id: Optional version to load
            force: Force reload without draining (default: False)

        Returns:
            Reload status
        """
        started = datetime.utcnow()

        try:
            old_version = self.model_loader.get_current_version()

            # Trigger reload
            ok = await self.reload_manager.reload_model(
                version_id=request.version_id,
                force=request.force
            )
            if not ok:
                raise RuntimeError("Model reload failed")

            current = self.model_loader.get_current_version()
            elapsed_ms = (datetime.utcnow() - started).total_seconds() * 1000

            return ModelReloadResponse(
                success=True,
                message=f"Model reloaded: {old_version} → {current}",
                previous_version=old_version,
                new_version=current,
                reload_time_ms=elapsed_ms,
                requests_drained=self.reload_manager.reload_state.drained_requests
            )

        except Exception as e:
            logger.error(f"Reload request failed: {e}")
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail=f"Reload failed: {str(e)}"
            )
async def create_model_metadata_router(
    model_loader,
    reload_manager: ModelReloadManager,
    metrics_manager = None
) -> ModelMetadataRouter:
    """
    Factory function to create model metadata router.

    Kept ``async`` to match the application's startup hooks, although
    construction itself performs no awaiting.

    Args:
        model_loader: ONNX model loader
        reload_manager: Graceful reload manager
        metrics_manager: Prometheus metrics

    Returns:
        ModelMetadataRouter ready to mount on FastAPI app
    """
    metadata_router = ModelMetadataRouter(
        model_loader=model_loader,
        reload_manager=reload_manager,
        metrics_manager=metrics_manager
    )
    return metadata_router