Coverage for services/inference/src/routers/predict.py: 29%
63 statements
coverage.py v7.11.0, created at 2025-10-25 16:18 +0000
1"""Prediction endpoint router for Phase 6 Inference Service.
3FastAPI router implementing single and batch prediction endpoints with:
4- IQ preprocessing pipeline
5- Redis caching (>80% target)
6- ONNX model inference
7- Uncertainty ellipse calculation
8- Full Prometheus metrics
9- SLA: <500ms latency (P95)
10"""
import asyncio
from datetime import datetime
from typing import Optional, List
import logging

from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import Field

# Internal imports (these will work when service is deployed)
# from ..models.onnx_loader import ONNXModelLoader
# from ..models.schemas import (
#     PredictionRequest, PredictionResponse, UncertaintyResponse, PositionResponse,
#     BatchPredictionRequest, BatchPredictionResponse
# )
# from ..utils.preprocessing import IQPreprocessor, PreprocessingConfig
# from ..utils.uncertainty import compute_uncertainty_ellipse, ellipse_to_geojson
# from ..utils.metrics import InferenceMetricsContext, record_cache_hit, record_cache_miss
# from ..utils.cache import RedisCache, CacheStatistics

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/inference", tags=["inference"])
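
# Example request against this router (illustrative only; the host/port and the
# exact encoding of "iq_data" are assumptions, since the placeholder handlers
# below only check that the field is present):
#
#   import httpx
#
#   payload = {
#       "iq_data": [[0.1, -0.2], [0.3, 0.05]],  # assumed shape: list of I/Q pairs
#       "cache_enabled": True,
#       "session_id": "demo-session",
#   }
#   resp = httpx.post("http://localhost:8000/api/v1/inference/predict", json=payload)
#   resp.raise_for_status()
#   print(resp.json()["position"], resp.json()["uncertainty"])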

# ============================================================================
# DEPENDENCY INJECTION
# ============================================================================

class PredictionDependencies:
    """Container for prediction endpoint dependencies."""

    def __init__(
        self,
        model_loader=None,
        cache=None,
        preprocessor=None,
    ):
        """Initialize dependencies."""
        self.model_loader = model_loader
        self.cache = cache
        self.preprocessor = preprocessor

async def get_dependencies() -> PredictionDependencies:
    """
    FastAPI dependency for getting prediction dependencies.

    In production, this would:
    - Return singleton model loader (from app state)
    - Return Redis cache client (from app state)
    - Return preprocessing pipeline (from app state)

    Example:
        async def app_startup():
            app.state.model_loader = ONNXModelLoader(...)
            app.state.cache = RedisCache(...)
            app.state.preprocessor = IQPreprocessor(...)
    """
    # Placeholder - in main.py these singletons would be set during app startup
    # (see the lifespan sketch after this function) and resolved here via the
    # Request object, since FastAPI has no Flask-style current_app:
    #
    # async def get_dependencies(request: Request) -> PredictionDependencies:
    #     return PredictionDependencies(
    #         model_loader=request.app.state.model_loader,
    #         cache=request.app.state.cache,
    #         preprocessor=request.app.state.preprocessor,
    #     )
    return PredictionDependencies()
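
# A minimal sketch of the startup wiring referenced in the docstring above,
# using FastAPI's lifespan hook. It belongs in main.py, not in this router
# module; the constructor arguments and paths shown here are placeholders.
#
#   from contextlib import asynccontextmanager
#   from fastapi import FastAPI
#
#   @asynccontextmanager
#   async def lifespan(app: FastAPI):
#       app.state.model_loader = ONNXModelLoader(model_path="models/localizer.onnx")
#       app.state.cache = RedisCache(url="redis://localhost:6379/0")
#       app.state.preprocessor = IQPreprocessor(PreprocessingConfig())
#       yield
#       # teardown: close the Redis connection, release the ONNX session, etc.
#
#   app = FastAPI(lifespan=lifespan)
#   app.include_router(router)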

# ============================================================================
# SINGLE PREDICTION ENDPOINT
# ============================================================================

@router.post(
    "/predict",
    response_model=dict,  # PredictionResponse in production
    status_code=status.HTTP_200_OK,
    summary="Single Prediction",
    description="Predict localization from single IQ recording with uncertainty",
    responses={
        200: {"description": "Prediction successful"},
        400: {"description": "Invalid IQ data"},
        503: {"description": "Model or cache unavailable"},
    }
)
async def predict_single(
    request: dict,  # PredictionRequest in production
    deps: PredictionDependencies = Depends(get_dependencies),
) -> dict:
    """
    Predict localization from IQ data.

    Process flow:
    1. Extract IQ data from request
    2. Check cache (target: >80% hit rate)
    3. If cache miss:
       a. Preprocess IQ → mel-spectrogram
       b. Run ONNX inference
       c. Compute uncertainty ellipse
       d. Cache result
    4. Return position + uncertainty + metadata

    Args:
        request: PredictionRequest with iq_data and optional cache_enabled flag
        deps: Injected dependencies

    Returns:
        PredictionResponse with:
        - position: {latitude, longitude}
        - uncertainty: {sigma_x, sigma_y, theta}
        - confidence: 0-1
        - model_version: str
        - inference_time_ms: float
        - timestamp: ISO datetime
        - _cache_hit: bool (whether from cache)

    Raises:
        HTTPException 400: Invalid input
        HTTPException 503: Model/cache unavailable

    SLA: P95 latency <500ms
    """
    # Metrics context auto-tracks latency and errors
    # with InferenceMetricsContext("predict"):
    try:
        # Validate request
        if not isinstance(request, dict):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Request must be JSON"
            )

        iq_data = request.get("iq_data")
        cache_enabled = request.get("cache_enabled", True)
        session_id = request.get("session_id", "unknown")

        if not iq_data:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Missing required field: iq_data"
            )

        logger.info(f"Prediction request: session={session_id}, cache={cache_enabled}")

        # Placeholder implementation showing the flow; in production this would
        # execute the full pipeline (see the cache-key sketch after this endpoint
        # and the complete flow at the end of the file)

        response = {
            "position": {
                "latitude": 45.123,
                "longitude": 7.456,
            },
            "uncertainty": {
                "sigma_x": 50.0,
                "sigma_y": 40.0,
                "theta": 25.0,
                "confidence_interval": 0.68,
            },
            "confidence": 0.95,
            "model_version": "v1.0.0",
            "inference_time_ms": 125.5,
            "timestamp": datetime.utcnow().isoformat(),
            "session_id": session_id,
            "_cache_hit": False,
        }

        logger.info(f"Prediction complete: lat={response['position']['latitude']}, "
                    f"lon={response['position']['longitude']}, "
                    f"time={response['inference_time_ms']}ms")

        return response
    except HTTPException:
        raise
    except ValueError as e:
        logger.warning(f"Validation error: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid input: {str(e)}"
        )
    except RuntimeError as e:
        logger.error(f"Processing error: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=f"Inference failed: {str(e)}"
        )
    except Exception as e:
        logger.error(f"Unexpected error: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error"
        )
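
# Illustrative sketch of how a deterministic cache key could be derived from the
# preprocessed features for the "check cache" step above. The real key scheme
# lives in ..utils.cache.RedisCache and may differ; the helper below is
# hypothetical and assumes the mel-spectrogram arrives as a NumPy array.
#
#   import hashlib
#   import numpy as np
#
#   def cache_key(mel_spec: np.ndarray, model_version: str) -> str:
#       """Stable key: SHA-256 of the feature bytes, namespaced by model version."""
#       digest = hashlib.sha256(np.ascontiguousarray(mel_spec).tobytes()).hexdigest()
#       return f"inference:{model_version}:{digest}"
#
#   # usage with a standard redis client:
#   #   key = cache_key(mel_spec, "v1.0.0")
#   #   cached = redis_client.get(key)
#   #   ...
#   #   redis_client.set(key, serialized_response, ex=3600)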

# ============================================================================
# BATCH PREDICTION ENDPOINT
# ============================================================================

@router.post(
    "/predict/batch",
    response_model=dict,  # BatchPredictionResponse in production
    status_code=status.HTTP_200_OK,
    summary="Batch Predictions",
    description="Predict localization for multiple IQ recordings in parallel",
)
async def predict_batch(
    request: dict,  # BatchPredictionRequest in production
    deps: PredictionDependencies = Depends(get_dependencies),
) -> dict:
    """
    Batch prediction endpoint.

    Processes 1-100 IQ samples in parallel.

    Args:
        request: BatchPredictionRequest with iq_samples list
        deps: Injected dependencies

    Returns:
        BatchPredictionResponse with:
        - predictions: List of prediction results
        - total_time_ms: Total processing time
        - samples_per_second: Throughput

    SLA: Average <500ms per sample
    """
    # with InferenceMetricsContext("predict/batch"):
    try:
        iq_samples = request.get("iq_samples", [])
        cache_enabled = request.get("cache_enabled", True)

        if not iq_samples:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="iq_samples is empty"
            )

        if len(iq_samples) > 100:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Maximum 100 samples allowed"
            )

        logger.info(f"Batch prediction: {len(iq_samples)} samples")

        # Placeholder: would process samples in parallel (see the asyncio sketch
        # after this endpoint)
        predictions = [
            {
                "position": {"latitude": 45.123 + i * 0.001, "longitude": 7.456 + i * 0.001},
                "uncertainty": {"sigma_x": 50.0, "sigma_y": 40.0, "theta": 25.0},
                "confidence": 0.95,
                "model_version": "v1.0.0",
                "inference_time_ms": 125.5,
                "timestamp": datetime.utcnow().isoformat(),
            }
            for i in range(len(iq_samples))
        ]

        total_time_ms = len(iq_samples) * 125.5
        throughput = len(iq_samples) / (total_time_ms / 1000)

        return {
            "predictions": predictions,
            "total_time_ms": total_time_ms,
            "samples_per_second": throughput,
            "batch_size": len(iq_samples),
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Batch prediction error: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Batch processing failed"
        )
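
# Illustrative sketch of the parallel processing the placeholder above stands in
# for, using the already-imported asyncio. "_predict_one" is a hypothetical
# per-sample routine (it would run the same pipeline as /predict); the
# concurrency limit of 8 is an assumption to tune per deployment.
#
#   async def _predict_one(iq_sample, deps) -> dict:
#       ...  # preprocess -> cache lookup -> ONNX inference -> uncertainty
#
#   async def _predict_many(iq_samples, deps) -> list:
#       sem = asyncio.Semaphore(8)  # bound concurrency so a 100-sample batch
#                                   # does not overwhelm the ONNX session
#
#       async def guarded(sample):
#           async with sem:
#               return await _predict_one(sample, deps)
#
#       return await asyncio.gather(*(guarded(s) for s in iq_samples))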

# ============================================================================
# HEALTH CHECK ENDPOINT
# ============================================================================

@router.get(
    "/health",
    response_model=dict,
    status_code=status.HTTP_200_OK,
    summary="Service Health",
)
async def health_check(
    deps: PredictionDependencies = Depends(get_dependencies),
) -> dict:
    """
    Health check endpoint.

    Returns:
        Status dict with:
        - status: "ok" or "degraded"
        - model_loaded: Is ONNX model loaded
        - cache_available: Is Redis cache available
        - timestamp: Current time
    """
    return {
        "status": "ok",
        "service": "inference",
        "version": "0.1.0",
        "model_loaded": True,  # in production: deps.model_loader.is_ready() (see sketch below)
        "cache_available": True,  # in production: deps.cache is not None
        "timestamp": datetime.utcnow().isoformat(),
    }
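
# Sketch of how the hard-coded values above could become a real "ok"/"degraded"
# signal once dependencies are wired in. "is_ready()" matches the comment above;
# "ping()" on the cache wrapper is an assumed method name.
#
#   model_ok = deps.model_loader is not None and deps.model_loader.is_ready()
#   cache_ok = deps.cache is not None and await deps.cache.ping()
#   return {
#       "status": "ok" if (model_ok and cache_ok) else "degraded",
#       "service": "inference",
#       "version": "0.1.0",
#       "model_loaded": model_ok,
#       "cache_available": cache_ok,
#       "timestamp": datetime.utcnow().isoformat(),
#   }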

# ============================================================================
# COMPLETE PREDICTION FLOW (PSEUDO-CODE FOR DOCUMENTATION)
# ============================================================================
324"""
325FULL PREDICTION FLOW (for reference):
327@router.post("/predict")
328async def predict_single_full(
329 request: PredictionRequest,
330 deps: PredictionDependencies = Depends(get_dependencies),
331) -> PredictionResponse:
332 '''Complete prediction with all steps.'''
334 with InferenceMetricsContext("predict"):
335 # Step 1: Validate input
336 if not request.iq_data:
337 raise ValueError("iq_data required")
339 # Step 2: Try cache
340 if request.cache_enabled and deps.cache:
341 with PreprocessingMetricsContext():
342 mel_spec = deps.preprocessor.preprocess(request.iq_data)
344 cached = deps.cache.get(mel_spec)
345 if cached:
346 record_cache_hit()
347 cached['_cache_hit'] = True
348 return PredictionResponse(**cached)
349 else:
350 record_cache_miss()
352 # Step 3: Preprocess
353 with PreprocessingMetricsContext():
354 mel_spec = deps.preprocessor.preprocess(request.iq_data)
356 # Step 4: Run ONNX inference
357 with ONNXMetricsContext():
358 inference_result = deps.model_loader.predict(mel_spec)
360 # Extract outputs
361 position_pred = inference_result['position'] # [lat, lon]
362 uncertainty_pred = inference_result['uncertainty'] # [sigma_x, sigma_y, theta]
363 confidence = inference_result['confidence']
365 # Step 5: Compute uncertainty ellipse
366 ellipse = compute_uncertainty_ellipse(
367 sigma_x=uncertainty_pred[0],
368 sigma_y=uncertainty_pred[1],
369 )
371 # Step 6: Create response
372 response = PredictionResponse(
373 position=PositionResponse(
374 latitude=position_pred[0],
375 longitude=position_pred[1],
376 ),
377 uncertainty=UncertaintyResponse(
378 sigma_x=uncertainty_pred[0],
379 sigma_y=uncertainty_pred[1],
380 theta=uncertainty_pred[2],
381 confidence_interval=ellipse.get('theta', 0),
382 ),
383 confidence=confidence,
384 model_version=deps.model_loader.get_metadata()['version'],
385 inference_time_ms=elapsed_ms,
386 timestamp=datetime.utcnow(),
387 )
389 # Step 7: Cache result if enabled
390 if request.cache_enabled and deps.cache:
391 deps.cache.set(mel_spec, response.dict())
393 return response
394"""