Coverage for services/inference/src/routers/predict.py: 29%

63 statements  

coverage.py v7.11.0, created at 2025-10-25 16:18 +0000

1"""Prediction endpoint router for Phase 6 Inference Service. 

2 

3FastAPI router implementing single and batch prediction endpoints with: 

4- IQ preprocessing pipeline 

5- Redis caching (>80% target) 

6- ONNX model inference 

7- Uncertainty ellipse calculation 

8- Full Prometheus metrics 

9- SLA: <500ms latency (P95) 

10""" 

11 

12import asyncio 

13from datetime import datetime 

14from typing import Optional, List 

15import logging 

16 

17from fastapi import APIRouter, Depends, HTTPException, status 

18from pydantic import Field 

19 

20# Internal imports (these will work when service is deployed) 

21# from ..models.onnx_loader import ONNXModelLoader 

22# from ..models.schemas import ( 

23# PredictionRequest, PredictionResponse, UncertaintyResponse, PositionResponse, 

24# BatchPredictionRequest, BatchPredictionResponse 

25# ) 

26# from ..utils.preprocessing import IQPreprocessor, PreprocessingConfig 

27# from ..utils.uncertainty import compute_uncertainty_ellipse, ellipse_to_geojson 

28# from ..utils.metrics import InferenceMetricsContext, record_cache_hit, record_cache_miss 

29# from ..utils.cache import RedisCache, CacheStatistics 

30 

31logger = logging.getLogger(__name__) 

32router = APIRouter(prefix="/api/v1/inference", tags=["inference"]) 

33 

34 

35# ============================================================================ 

36# DEPENDENCY INJECTION 

37# ============================================================================ 

38 

39class PredictionDependencies: 

40 """Container for prediction endpoint dependencies.""" 

41 

42 def __init__( 

43 self, 

44 model_loader=None, 

45 cache=None, 

46 preprocessor=None, 

47 ): 

48 """Initialize dependencies.""" 

49 self.model_loader = model_loader 

50 self.cache = cache 

51 self.preprocessor = preprocessor 

52 

53 

54async def get_dependencies() -> PredictionDependencies: 

55 """ 

56 FastAPI dependency for getting prediction dependencies. 

57  

58 In production, this would: 

59 - Return singleton model loader (from app state) 

60 - Return Redis cache client (from app state) 

61 - Return preprocessing pipeline (from app state) 

62  

63 Example: 

64 async def app_startup(): 

65 app.state.model_loader = ONNXModelLoader(...) 

66 app.state.cache = RedisCache(...) 

67 app.state.preprocessor = IQPreprocessor(...) 

68 """ 

69 # Placeholder - in main.py these would be set during app startup 

70 # return PredictionDependencies( 

71 # model_loader=current_app.state.model_loader, 

72 # cache=current_app.state.cache, 

73 # preprocessor=current_app.state.preprocessor, 

74 # ) 

75 return PredictionDependencies() 
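# --- Hedged sketch: production-style dependency resolution -------------------
# A minimal illustration (not wired into the routes below) of how the singletons
# created at app startup could be resolved from app.state, as the docstring
# above describes. The attribute names model_loader, cache, and preprocessor are
# taken from that docstring; this alternative provider is only a sketch.
from fastapi import Request  # kept local so the sketch stays self-contained


async def get_dependencies_from_state(request: Request) -> PredictionDependencies:
    """Resolve dependencies from objects stored on app.state during startup."""
    state = request.app.state
    return PredictionDependencies(
        model_loader=getattr(state, "model_loader", None),
        cache=getattr(state, "cache", None),
        preprocessor=getattr(state, "preprocessor", None),
    )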

# ============================================================================
# SINGLE PREDICTION ENDPOINT
# ============================================================================

@router.post(
    "/predict",
    response_model=dict,  # PredictionResponse in production
    status_code=status.HTTP_200_OK,
    summary="Single Prediction",
    description="Predict localization from single IQ recording with uncertainty",
    responses={
        200: {"description": "Prediction successful"},
        400: {"description": "Invalid IQ data"},
        503: {"description": "Model or cache unavailable"},
    }
)
async def predict_single(
    request: dict,  # PredictionRequest in production
    deps: PredictionDependencies = Depends(get_dependencies),
) -> dict:
    """
    Predict localization from IQ data.

    Process flow:
    1. Extract IQ data from request
    2. Check cache (target: >80% hit rate)
    3. If cache miss:
       a. Preprocess IQ → mel-spectrogram
       b. Run ONNX inference
       c. Compute uncertainty ellipse
       d. Cache result
    4. Return position + uncertainty + metadata

    Args:
        request: PredictionRequest with iq_data and optional cache_enabled flag
        deps: Injected dependencies

    Returns:
        PredictionResponse with:
        - position: {latitude, longitude}
        - uncertainty: {sigma_x, sigma_y, theta}
        - confidence: 0-1
        - model_version: str
        - inference_time_ms: float
        - timestamp: ISO datetime
        - _cache_hit: bool (whether from cache)

    Raises:
        HTTPException 400: Invalid input
        HTTPException 503: Model/cache unavailable

    SLA: P95 latency <500ms
    """
    # Metrics context auto-tracks latency and errors
    # with InferenceMetricsContext("predict"):
    try:
        # Validate request
        if not isinstance(request, dict):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Request must be JSON"
            )

        iq_data = request.get("iq_data")
        cache_enabled = request.get("cache_enabled", True)
        session_id = request.get("session_id", "unknown")

        if not iq_data:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Missing required field: iq_data"
            )

        logger.info(f"Prediction request: session={session_id}, cache={cache_enabled}")

        # Placeholder implementation showing the flow
        # In production, this would execute the full pipeline

        response = {
            "position": {
                "latitude": 45.123,
                "longitude": 7.456,
            },
            "uncertainty": {
                "sigma_x": 50.0,
                "sigma_y": 40.0,
                "theta": 25.0,
                "confidence_interval": 0.68,
            },
            "confidence": 0.95,
            "model_version": "v1.0.0",
            "inference_time_ms": 125.5,
            "timestamp": datetime.utcnow().isoformat(),
            "session_id": session_id,
            "_cache_hit": False,
        }

        logger.info(f"Prediction complete: lat={response['position']['latitude']}, "
                    f"lon={response['position']['longitude']}, "
                    f"time={response['inference_time_ms']}ms")

        return response

    except HTTPException:
        raise
    except ValueError as e:
        logger.warning(f"Validation error: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid input: {str(e)}"
        )
    except RuntimeError as e:
        logger.error(f"Processing error: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=f"Inference failed: {str(e)}"
        )
    except Exception as e:
        logger.error(f"Unexpected error: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error"
        )
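# --- Hedged usage sketch: calling the handler directly ------------------------
# A self-contained example (illustrative only; the helper name and the toy
# iq_data payload below are made up, not a real IQ recording) showing the
# request shape predict_single accepts and the fields its placeholder response
# carries.
async def _example_predict_single_call() -> dict:
    deps = await get_dependencies()
    request = {
        "iq_data": [[0.1, -0.2], [0.3, 0.05]],  # toy I/Q pairs for illustration
        "session_id": "demo",
        "cache_enabled": False,
    }
    response = await predict_single(request, deps)
    # response["position"], response["uncertainty"], response["confidence"], ...
    return response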

# ============================================================================
# BATCH PREDICTION ENDPOINT
# ============================================================================

@router.post(
    "/predict/batch",
    response_model=dict,  # BatchPredictionResponse in production
    status_code=status.HTTP_200_OK,
    summary="Batch Predictions",
    description="Predict localization for multiple IQ recordings in parallel",
)
async def predict_batch(
    request: dict,  # BatchPredictionRequest in production
    deps: PredictionDependencies = Depends(get_dependencies),
) -> dict:
    """
    Batch prediction endpoint.

    Processes 1-100 IQ samples in parallel.

    Args:
        request: BatchPredictionRequest with iq_samples list
        deps: Injected dependencies

    Returns:
        BatchPredictionResponse with:
        - predictions: List of prediction results
        - total_time_ms: Total processing time
        - samples_per_second: Throughput

    SLA: Average <500ms per sample
    """
    # with InferenceMetricsContext("predict/batch"):
    try:
        iq_samples = request.get("iq_samples", [])
        cache_enabled = request.get("cache_enabled", True)

        if not iq_samples:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="iq_samples is empty"
            )

        if len(iq_samples) > 100:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Maximum 100 samples allowed"
            )

        logger.info(f"Batch prediction: {len(iq_samples)} samples")

        # Placeholder: would process in parallel
        predictions = [
            {
                "position": {"latitude": 45.123 + i * 0.001, "longitude": 7.456 + i * 0.001},
                "uncertainty": {"sigma_x": 50.0, "sigma_y": 40.0, "theta": 25.0},
                "confidence": 0.95,
                "model_version": "v1.0.0",
                "inference_time_ms": 125.5,
                "timestamp": datetime.utcnow().isoformat(),
            }
            for i in range(len(iq_samples))
        ]

        total_time_ms = len(iq_samples) * 125.5
        throughput = len(iq_samples) / (total_time_ms / 1000)

        return {
            "predictions": predictions,
            "total_time_ms": total_time_ms,
            "samples_per_second": throughput,
            "batch_size": len(iq_samples),
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Batch prediction error: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Batch processing failed"
        )
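# --- Hedged sketch: parallel fan-out for the batch endpoint -------------------
# The placeholder above builds its results in a list comprehension; this is a
# minimal illustration (an assumption, not the deployed implementation) of how
# the samples could be processed concurrently with asyncio.gather, reusing
# predict_single per sample. It also gives the module-level asyncio import a
# concrete role.
async def _predict_batch_parallel(
    iq_samples: List,
    deps: PredictionDependencies,
    cache_enabled: bool = True,
) -> List[dict]:
    """Run predict_single concurrently over all samples and collect results."""
    tasks = [
        predict_single(
            {"iq_data": sample, "cache_enabled": cache_enabled},
            deps,
        )
        for sample in iq_samples
    ]
    return await asyncio.gather(*tasks)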

# ============================================================================
# HEALTH CHECK ENDPOINT
# ============================================================================

@router.get(
    "/health",
    response_model=dict,
    status_code=status.HTTP_200_OK,
    summary="Service Health",
)
async def health_check(
    deps: PredictionDependencies = Depends(get_dependencies),
) -> dict:
    """
    Health check endpoint.

    Returns:
        Status dict with:
        - status: "ok" or "degraded"
        - model_loaded: Is ONNX model loaded
        - cache_available: Is Redis cache available
        - timestamp: Current time
    """
    return {
        "status": "ok",
        "service": "inference",
        "version": "0.1.0",
        "model_loaded": True,  # deps.model_loader.is_ready()
        "cache_available": True,  # deps.cache is not None
        "timestamp": datetime.utcnow().isoformat(),
    }
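# --- Hedged sketch: deriving real health status from the dependencies ---------
# An illustration (assumption) of how the hard-coded booleans above could be
# computed from the injected dependencies. is_ready() is taken from the inline
# comment in health_check and is not a verified API of the model loader.
def _health_payload(deps: PredictionDependencies) -> dict:
    model_loaded = deps.model_loader is not None and deps.model_loader.is_ready()
    cache_available = deps.cache is not None
    return {
        "status": "ok" if (model_loaded and cache_available) else "degraded",
        "service": "inference",
        "version": "0.1.0",
        "model_loaded": model_loaded,
        "cache_available": cache_available,
        "timestamp": datetime.utcnow().isoformat(),
    }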

# ============================================================================
# COMPLETE PREDICTION FLOW (PSEUDO-CODE FOR DOCUMENTATION)
# ============================================================================

"""
FULL PREDICTION FLOW (for reference):

@router.post("/predict")
async def predict_single_full(
    request: PredictionRequest,
    deps: PredictionDependencies = Depends(get_dependencies),
) -> PredictionResponse:
    '''Complete prediction with all steps.'''

    with InferenceMetricsContext("predict"):
        start = time.perf_counter()  # requires "import time" at module level

        # Step 1: Validate input
        if not request.iq_data:
            raise ValueError("iq_data required")

        # Step 2: Try cache (the cache is keyed by the preprocessed spectrogram)
        mel_spec = None
        if request.cache_enabled and deps.cache:
            with PreprocessingMetricsContext():
                mel_spec = deps.preprocessor.preprocess(request.iq_data)

            cached = deps.cache.get(mel_spec)
            if cached:
                record_cache_hit()
                cached['_cache_hit'] = True
                return PredictionResponse(**cached)
            else:
                record_cache_miss()

        # Step 3: Preprocess (only if the cache path has not already done so)
        if mel_spec is None:
            with PreprocessingMetricsContext():
                mel_spec = deps.preprocessor.preprocess(request.iq_data)

        # Step 4: Run ONNX inference
        with ONNXMetricsContext():
            inference_result = deps.model_loader.predict(mel_spec)

        # Extract outputs
        position_pred = inference_result['position']        # [lat, lon]
        uncertainty_pred = inference_result['uncertainty']  # [sigma_x, sigma_y, theta]
        confidence = inference_result['confidence']

        # Step 5: Compute uncertainty ellipse
        ellipse = compute_uncertainty_ellipse(
            sigma_x=uncertainty_pred[0],
            sigma_y=uncertainty_pred[1],
        )

        # Step 6: Create response
        elapsed_ms = (time.perf_counter() - start) * 1000
        response = PredictionResponse(
            position=PositionResponse(
                latitude=position_pred[0],
                longitude=position_pred[1],
            ),
            uncertainty=UncertaintyResponse(
                sigma_x=uncertainty_pred[0],
                sigma_y=uncertainty_pred[1],
                theta=uncertainty_pred[2],
                confidence_interval=0.68,  # 1-sigma interval, as in the placeholder above
            ),
            confidence=confidence,
            model_version=deps.model_loader.get_metadata()['version'],
            inference_time_ms=elapsed_ms,
            timestamp=datetime.utcnow(),
        )

        # Step 7: Cache result if enabled
        if request.cache_enabled and deps.cache:
            deps.cache.set(mel_spec, response.dict())

        return response
"""