Coverage for services/training/src/onnx_export.py: 0% (156 statements)

coverage.py v7.11.0, created at 2025-10-25 16:18 +0000

1""" 

2ONNX Export Module: Export LocalizationNet to ONNX format and upload to MinIO. 

3 

4This module handles: 

51. Converting PyTorch LocalizationNet to ONNX format 

62. Input/output validation and shape verification 

73. Quantization (optional, for inference optimization) 

84. Upload to MinIO artifact storage 

95. MLflow integration (register model, track versions) 

106. Batch prediction for verification 

11 

12ONNX (Open Neural Network Exchange) provides: 

13- Platform-independent model format 

14- Hardware acceleration support (CPU, GPU, mobile, edge devices) 

15- Inference optimization for production 

16- Model interoperability (use in any framework) 

17 

18Features: 

19- Dynamic batch size support for inference flexibility 

20- Input shape: (batch_size, 3, 128, 32) - mel-spectrogram 

21- Output shapes: 

22 - positions: (batch_size, 2) - [lat, lon] 

23 - uncertainties: (batch_size, 2) - [sigma_x, sigma_y] 

24 

25Performance: 

26- ONNX inference: ~20-30ms on CPU, <5ms on GPU (vs PyTorch ~50ms) 

27- Quantization: 4x smaller model (~100MB → ~25MB) 

28- Supports batching for throughput optimization 
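
Typical usage (a sketch; assumes an already-configured boto3 client for MinIO
and an MLflowTracker instance supplied elsewhere in the service):

    result = export_and_register_model(
        pytorch_model=trained_model,
        run_id=mlflow_run_id,
        s3_client=s3_client,
        mlflow_tracker=tracker,
    )
    if result['success']:
        print(result['s3_uri'])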

29""" 

30 

31import torch 

32import torch.nn as nn 

33import onnx 

34import onnxruntime as ort 

35import numpy as np 

36import structlog 

37from pathlib import Path 

38from typing import Dict, Tuple, Optional, List 

39import json 

40import hashlib 

41from datetime import datetime 

42import io 

43 

44logger = structlog.get_logger(__name__) 

45 

46 

class ONNXExporter:
    """
    Export LocalizationNet to ONNX format with validation and optimization.

    Workflow:
    1. Load trained PyTorch Lightning checkpoint
    2. Convert to ONNX
    3. Validate ONNX model
    4. (Optional) Quantize for inference optimization
    5. Upload to MinIO
    6. Register with MLflow Model Registry
    """

    def __init__(self, s3_client, mlflow_tracker):
        """
        Initialize ONNX exporter.

        Args:
            s3_client: boto3 S3 client (for MinIO)
            mlflow_tracker: MLflowTracker instance (for model registration)
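
        Example (a sketch; the MinIO endpoint is a placeholder for however
        the service configures S3 access):

            import boto3
            s3 = boto3.client('s3', endpoint_url='http://minio:9000')
            exporter = ONNXExporter(s3, mlflow_tracker)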

67 """ 

68 self.s3_client = s3_client 

69 self.mlflow_tracker = mlflow_tracker 

70 self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 

71 

72 logger.info( 

73 "onnx_exporter_initialized", 

74 device=str(self.device), 

75 onnxruntime_version=ort.__version__, 

76 ) 

77 

78 def export_to_onnx( 

79 self, 

80 model: nn.Module, 

81 output_path: Path, 

82 opset_version: int = 14, 

83 do_constant_folding: bool = True, 

84 ) -> Path: 

85 """ 

86 Export PyTorch model to ONNX format. 

87  

88 Args: 

89 model (nn.Module): LocalizationNet instance (eval mode) 

90 output_path (Path): Where to save ONNX file 

91 opset_version (int): ONNX opset version (14 = good CPU support, 18 = latest GPU) 

92 do_constant_folding (bool): Optimize constant computations 

93  

94 Returns: 

95 Path to exported ONNX file 

96  

97 Raises: 

98 RuntimeError: If export fails 
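
        Example:
            The batch axis is exported as dynamic (see dynamic_axes below),
            so the resulting file accepts any leading batch size, e.g.:

                sess = ort.InferenceSession('model.onnx')
                batch = np.zeros((16, 3, 128, 32), dtype=np.float32)
                positions, uncertainties = sess.run(None, {'mel_spectrogram': batch})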

99 """ 

100 model.eval() 

101 

102 # Create dummy input matching expected shape 

103 # (batch_size=1, channels=3, height=128, width=32) 

104 dummy_input = torch.randn(1, 3, 128, 32, device=self.device) 

105 

106 # Input/output names (required by ONNX) 

107 input_names = ['mel_spectrogram'] 

108 output_names = ['positions', 'uncertainties'] 

109 

110 # Dynamic axes for variable batch size 

111 dynamic_axes = { 

112 'mel_spectrogram': {0: 'batch_size'}, 

113 'positions': {0: 'batch_size'}, 

114 'uncertainties': {0: 'batch_size'}, 

115 } 

116 

117 try: 

118 logger.info( 

119 "exporting_to_onnx", 

120 opset_version=opset_version, 

121 output_path=str(output_path), 

122 ) 

123 

124 torch.onnx.export( 

125 model, 

126 dummy_input, 

127 str(output_path), 

128 input_names=input_names, 

129 output_names=output_names, 

130 dynamic_axes=dynamic_axes, 

131 opset_version=opset_version, 

132 do_constant_folding=do_constant_folding, 

133 verbose=False, 

134 ) 

135 

136 file_size_mb = output_path.stat().st_size / (1024 * 1024) 

137 logger.info( 

138 "onnx_export_successful", 

139 output_path=str(output_path), 

140 file_size_mb=f"{file_size_mb:.2f}", 

141 ) 

142 

143 return output_path 

144 

145 except Exception as e: 

146 logger.error( 

147 "onnx_export_failed", 

148 error=str(e), 

149 output_path=str(output_path), 

150 ) 

151 raise RuntimeError(f"ONNX export failed: {e}") from e 

    def validate_onnx_model(self, onnx_path: Path) -> Dict[str, Any]:
        """
        Validate ONNX model structure and shapes.

        Args:
            onnx_path (Path): Path to ONNX file

        Returns:
            Dict with model info:
            - inputs: List of input specifications
            - outputs: List of output specifications
            - opset_version: ONNX opset version
            - producer_name: Framework that exported the model
            - ir_version: ONNX IR version

        Raises:
            ValueError: If model validation fails
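
        Example (illustrative; for the export above, the first input entry
        looks roughly like):

            {'name': 'mel_spectrogram', 'shape': ['batch_size', 3, 128, 32], 'dtype': '1'}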

170 """ 

171 try: 

172 # Load and validate ONNX model 

173 onnx_model = onnx.load(str(onnx_path)) 

174 onnx.checker.check_model(onnx_model) 

175 

176 # Extract model info 

177 graph = onnx_model.graph 

178 

179 inputs_info = [] 

180 for input_tensor in graph.input: 

181 shape = [dim.dim_value for dim in input_tensor.type.tensor_type.shape.dim] 

182 inputs_info.append({ 

183 'name': input_tensor.name, 

184 'shape': shape, 

185 'dtype': str(input_tensor.type.tensor_type.data_type), 

186 }) 

187 

188 outputs_info = [] 

189 for output_tensor in graph.output: 

190 shape = [dim.dim_value for dim in output_tensor.type.tensor_type.shape.dim] 

191 outputs_info.append({ 

192 'name': output_tensor.name, 

193 'shape': shape, 

194 'dtype': str(output_tensor.type.tensor_type.data_type), 

195 }) 

196 

197 model_info = { 

198 'inputs': inputs_info, 

199 'outputs': outputs_info, 

200 'opset_version': onnx_model.opset_import[0].version, 

201 'producer_name': onnx_model.producer_name, 

202 'ir_version': onnx_model.ir_version, 

203 } 

204 

205 logger.info( 

206 "onnx_model_validated", 

207 onnx_path=str(onnx_path), 

208 inputs=len(inputs_info), 

209 outputs=len(outputs_info), 

210 ) 

211 

212 return model_info 

213 

214 except Exception as e: 

215 logger.error( 

216 "onnx_validation_failed", 

217 error=str(e), 

218 onnx_path=str(onnx_path), 

219 ) 

220 raise ValueError(f"ONNX model validation failed: {e}") from e 

    def test_onnx_inference(
        self,
        onnx_path: Path,
        pytorch_model: nn.Module,
        num_batches: int = 5,
        batch_size: int = 8,
        tolerance: float = 1e-5,
    ) -> Dict[str, Any]:
        """
        Test ONNX inference against PyTorch for accuracy verification.

        Compares the outputs of the PyTorch model and the ONNX runtime to
        ensure numerical equivalence after export.

        Args:
            onnx_path (Path): Path to ONNX file
            pytorch_model (nn.Module): Original PyTorch model
            num_batches (int): Number of test batches
            batch_size (int): Batch size for testing
            tolerance (float): Maximum allowed difference (MAE)

        Returns:
            Dict with comparison metrics:
            - positions_mae: Mean Absolute Error for positions
            - uncertainties_mae: Mean Absolute Error for uncertainties
            - inference_time_onnx_ms: ONNX inference time (ms)
            - inference_time_pytorch_ms: PyTorch inference time (ms)
            - speedup: ONNX vs PyTorch speedup factor
            - passed: Boolean, True if tolerance met

        Raises:
            AssertionError: If accuracy tolerance exceeded
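
        Example:
            metrics = exporter.test_onnx_inference(onnx_path, model)
            # metrics['speedup'] around 2x on CPU per the module docstring
            # (illustrative, not a measured guarantee)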

254 """ 

255 import time 

256 

257 pytorch_model.eval() 

258 sess = ort.InferenceSession(str(onnx_path)) 

259 

260 positions_diffs = [] 

261 uncertainties_diffs = [] 

262 times_onnx = [] 

263 times_pytorch = [] 

264 

265 with torch.no_grad(): 

266 for _ in range(num_batches): 

267 # Create test batch 

268 test_input = torch.randn(batch_size, 3, 128, 32, device=self.device) 

269 test_input_np = test_input.cpu().numpy().astype(np.float32) 

270 

271 # PyTorch inference 

272 t0 = time.time() 

273 py_positions, py_uncertainties = pytorch_model(test_input) 

274 t_pytorch = (time.time() - t0) * 1000 # ms 

275 times_pytorch.append(t_pytorch) 

276 

277 py_positions_np = py_positions.cpu().numpy() 

278 py_uncertainties_np = py_uncertainties.cpu().numpy() 

279 

280 # ONNX inference 

281 t0 = time.time() 

282 onnx_outputs = sess.run( 

283 None, # Output names = all outputs 

284 {'mel_spectrogram': test_input_np} 

285 ) 

286 t_onnx = (time.time() - t0) * 1000 # ms 

287 times_onnx.append(t_onnx) 

288 

289 onnx_positions = onnx_outputs[0] 

290 onnx_uncertainties = onnx_outputs[1] 

291 

292 # Compare outputs 

293 pos_diff = np.abs(py_positions_np - onnx_positions).mean() 

294 unc_diff = np.abs(py_uncertainties_np - onnx_uncertainties).mean() 

295 

296 positions_diffs.append(pos_diff) 

297 uncertainties_diffs.append(unc_diff) 

298 

299 positions_mae = np.mean(positions_diffs) 

300 uncertainties_mae = np.mean(uncertainties_diffs) 

301 mean_time_onnx = np.mean(times_onnx) 

302 mean_time_pytorch = np.mean(times_pytorch) 

303 speedup = mean_time_pytorch / mean_time_onnx 

304 

305 passed = (positions_mae < tolerance) and (uncertainties_mae < tolerance) 

306 

307 results = { 

308 'positions_mae': float(positions_mae), 

309 'uncertainties_mae': float(uncertainties_mae), 

310 'inference_time_onnx_ms': float(mean_time_onnx), 

311 'inference_time_pytorch_ms': float(mean_time_pytorch), 

312 'speedup': float(speedup), 

313 'passed': passed, 

314 } 

315 

316 logger.info( 

317 "onnx_inference_test_complete", 

318 positions_mae=f"{positions_mae:.2e}", 

319 uncertainties_mae=f"{uncertainties_mae:.2e}", 

320 speedup=f"{speedup:.2f}x", 

321 passed=passed, 

322 ) 

323 

324 if not passed: 

325 raise AssertionError( 

326 f"ONNX inference accuracy check failed: " 

327 f"positions_mae={positions_mae}, uncertainties_mae={uncertainties_mae}" 

328 ) 

329 

330 return results 

    def upload_to_minio(
        self,
        onnx_path: Path,
        bucket_name: str = 'heimdall-models',
        object_name: Optional[str] = None,
    ) -> str:
        """
        Upload ONNX model to MinIO (S3-compatible storage).

        Args:
            onnx_path (Path): Local path to ONNX file
            bucket_name (str): MinIO bucket (default: heimdall-models)
            object_name (str): S3 object path (default: models/localization/v{timestamp}.onnx)

        Returns:
            S3 URI (s3://bucket/path)

        Raises:
            RuntimeError: If the upload fails
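
        Example (the timestamped key is illustrative):
            s3_uri = exporter.upload_to_minio(Path('/tmp/model.onnx'))
            # -> 's3://heimdall-models/models/localization/v20251025_161800.onnx'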

351 """ 

352 try: 

353 if object_name is None: 

354 timestamp = datetime.utcnow().strftime('%Y%m%d_%H%M%S') 

355 object_name = f'models/localization/v{timestamp}.onnx' 

356 

357 # Read file 

358 with open(onnx_path, 'rb') as f: 

359 file_data = f.read() 

360 

361 # Upload to MinIO 

362 self.s3_client.put_object( 

363 Bucket=bucket_name, 

364 Key=object_name, 

365 Body=file_data, 

366 ContentType='application/octet-stream', 

367 Metadata={ 

368 'export-date': datetime.utcnow().isoformat(), 

369 'model-type': 'localization-net', 

370 'format': 'onnx', 

371 'file-size': str(len(file_data)), 

372 }, 

373 ) 

374 

375 s3_uri = f's3://{bucket_name}/{object_name}' 

376 

377 logger.info( 

378 "onnx_uploaded_to_minio", 

379 bucket=bucket_name, 

380 object=object_name, 

381 file_size_mb=f"{len(file_data) / (1024*1024):.2f}", 

382 s3_uri=s3_uri, 

383 ) 

384 

385 return s3_uri 

386 

387 except Exception as e: 

388 logger.error( 

389 "onnx_upload_failed", 

390 error=str(e), 

391 bucket=bucket_name, 

392 ) 

393 raise RuntimeError(f"Failed to upload ONNX to MinIO: {e}") from e 

    def get_model_metadata(
        self,
        onnx_path: Path,
        pytorch_model: nn.Module,
        run_id: str,
        inference_metrics: Optional[Dict] = None,
    ) -> Dict:
        """
        Generate comprehensive metadata for the exported model.

        Args:
            onnx_path (Path): Path to ONNX file
            pytorch_model (nn.Module): Original PyTorch model
            run_id (str): MLflow run ID
            inference_metrics (Dict): Inference test results

        Returns:
            Dict with model metadata
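
        Example:
            A consumer can re-verify file integrity against the recorded
            checksum (sketch; `p` is the downloaded file's path):

                digest = hashlib.sha256(Path(p).read_bytes()).hexdigest()
                assert digest == metadata['onnx_file_sha256']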

413 """ 

414 file_size = onnx_path.stat().st_size 

415 file_hash = hashlib.sha256() 

416 

417 with open(onnx_path, 'rb') as f: 

418 for chunk in iter(lambda: f.read(4096), b''): 

419 file_hash.update(chunk) 

420 

421 metadata = { 

422 'model_type': 'LocalizationNet', 

423 'backbone': 'ConvNeXt-Large', 

424 'input_shape': [1, 3, 128, 32], 

425 'output_names': ['positions', 'uncertainties'], 

426 'output_shapes': { 

427 'positions': [1, 2], 

428 'uncertainties': [1, 2], 

429 }, 

430 'export_date': datetime.utcnow().isoformat(), 

431 'onnx_file_size_bytes': file_size, 

432 'onnx_file_size_mb': file_size / (1024 * 1024), 

433 'onnx_file_sha256': file_hash.hexdigest(), 

434 'mlflow_run_id': run_id, 

435 'pytorch_params': pytorch_model.get_params_count(), 

436 'inference_metrics': inference_metrics or {}, 

437 } 

438 

439 return metadata 

    def register_with_mlflow(
        self,
        model_name: str,
        s3_uri: str,
        metadata: Dict,
        stage: str = 'Staging',
    ) -> Dict:
        """
        Register ONNX model with MLflow Model Registry.

        Args:
            model_name (str): Model registry name
            s3_uri (str): S3 URI to ONNX model
            metadata (Dict): Model metadata
            stage (str): Initial stage ('Staging' or 'Production')

        Returns:
            Dict with registration details
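
        Example (illustrative return value; the version is whatever the
        registry assigns):

            {'model_name': 'heimdall-localization-onnx', 'model_version': '3',
             'stage': 'Staging', 's3_uri': 's3://heimdall-models/...'}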

459 """ 

460 try: 

461 # Register model via MLflowTracker 

462 model_version = self.mlflow_tracker.register_model( 

463 model_name=model_name, 

464 model_uri=f's3://{s3_uri}', # MLflow expects s3:// URI 

465 tags={ 

466 'framework': 'pytorch', 

467 'format': 'onnx', 

468 'input_shape': '1,3,128,32', 

469 'output_count': '2', 

470 }, 

471 ) 

472 

473 # Transition to staging 

474 if stage in ['Staging', 'Production']: 

475 self.mlflow_tracker.transition_model_stage( 

476 model_name=model_name, 

477 version=model_version, 

478 stage=stage, 

479 ) 

480 

481 # Log metadata as artifact 

482 metadata_json = json.dumps(metadata, indent=2) 

483 self.mlflow_tracker.log_artifact( 

484 metadata_json, 

485 artifact_path=f'models/{model_name}/metadata.json' 

486 ) 

487 

488 result = { 

489 'model_name': model_name, 

490 'model_version': model_version, 

491 'stage': stage, 

492 's3_uri': s3_uri, 

493 } 

494 

495 logger.info( 

496 "model_registered_with_mlflow", 

497 model_name=model_name, 

498 model_version=model_version, 

499 stage=stage, 

500 ) 

501 

502 return result 

503 

504 except Exception as e: 

505 logger.error( 

506 "mlflow_registration_failed", 

507 error=str(e), 

508 model_name=model_name, 

509 ) 

510 raise RuntimeError(f"Failed to register model with MLflow: {e}") from e 

511 

512 
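

# The module docstring lists quantization (feature 3) as an optional step that
# is not implemented in the class above. Below is a minimal sketch using
# onnxruntime's dynamic quantization; the function name and the int8 recipe
# are illustrative assumptions, not the project's confirmed quantization path.
def quantize_onnx_model(onnx_path: Path, output_path: Path) -> Path:
    """Quantize ONNX weights to int8 (roughly 4x smaller on disk, per the module docstring)."""
    from onnxruntime.quantization import quantize_dynamic, QuantType

    quantize_dynamic(
        model_input=str(onnx_path),
        model_output=str(output_path),
        weight_type=QuantType.QInt8,  # int8 weights; activations stay float at runtime
    )
    return output_path
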

def export_and_register_model(
    pytorch_model: nn.Module,
    run_id: str,
    s3_client,
    mlflow_tracker,
    output_dir: Path = Path('/tmp/onnx_exports'),
    model_name: str = 'heimdall-localization-onnx',
) -> Dict:
    """
    Complete workflow: export PyTorch → ONNX → validate → upload → register.

    Args:
        pytorch_model (nn.Module): Trained LocalizationNet model
        run_id (str): MLflow run ID (for tracking)
        s3_client: boto3 S3 client
        mlflow_tracker: MLflowTracker instance
        output_dir (Path): Directory for temporary ONNX files
        model_name (str): MLflow model registry name

    Returns:
        Dict with complete export and registration details

    Workflow:
    1. Export to ONNX
    2. Validate ONNX structure
    3. Test inference accuracy
    4. Upload to MinIO
    5. Register with MLflow
    6. Log metadata
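
    Example:
        result = export_and_register_model(model, run_id, s3_client, tracker)
        if not result['success']:
            raise RuntimeError(result['error'])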

542 """ 

543 output_dir.mkdir(parents=True, exist_ok=True) 

544 exporter = ONNXExporter(s3_client, mlflow_tracker) 

545 

546 try: 

547 # Step 1: Export to ONNX 

548 onnx_path = output_dir / f'{model_name}_v{run_id[:8]}.onnx' 

549 exporter.export_to_onnx(pytorch_model, onnx_path) 

550 

551 # Step 2: Validate ONNX 

552 model_info = exporter.validate_onnx_model(onnx_path) 

553 

554 # Step 3: Test inference 

555 inference_metrics = exporter.test_onnx_inference(onnx_path, pytorch_model) 

556 

557 # Step 4: Upload to MinIO 

558 s3_uri = exporter.upload_to_minio(onnx_path) 

559 

560 # Step 5: Get metadata 

561 metadata = exporter.get_model_metadata( 

562 onnx_path, 

563 pytorch_model, 

564 run_id, 

565 inference_metrics, 

566 ) 

567 

568 # Step 6: Register with MLflow 

569 registration = exporter.register_with_mlflow( 

570 model_name, 

571 s3_uri, 

572 metadata, 

573 stage='Staging', 

574 ) 

575 

576 logger.info( 

577 "onnx_export_complete", 

578 model_name=model_name, 

579 onnx_file_size_mb=f"{metadata['onnx_file_size_mb']:.2f}", 

580 s3_uri=s3_uri, 

581 mlflow_version=registration['model_version'], 

582 speedup=f"{inference_metrics['speedup']:.2f}x", 

583 ) 

584 

585 return { 

586 'success': True, 

587 'model_name': model_name, 

588 'run_id': run_id, 

589 'onnx_path': str(onnx_path), 

590 's3_uri': s3_uri, 

591 'model_info': model_info, 

592 'metadata': metadata, 

593 'inference_metrics': inference_metrics, 

594 'registration': registration, 

595 } 

596 

597 except Exception as e: 

598 logger.error( 

599 "onnx_export_workflow_failed", 

600 error=str(e), 

601 model_name=model_name, 

602 ) 

603 return { 

604 'success': False, 

605 'error': str(e), 

606 } 


if __name__ == "__main__":
    # Quick test: verify that ONNX export works with a dummy model.
    import logging
    logging.basicConfig(level=logging.INFO)

    from src.models.localization_net import LocalizationNet

    logger.info("Testing ONNX export...")

    # Create dummy model
    model = LocalizationNet(pretrained=False)
    model.eval()

    # Export (no S3 client or MLflow tracker needed for a local smoke test)
    exporter = ONNXExporter(None, None)
    output_path = Path('/tmp/test_model.onnx')
    exporter.export_to_onnx(model, output_path)

    # Validate
    model_info = exporter.validate_onnx_model(output_path)
    logger.info(f"✅ ONNX export successful! Inputs: {model_info['inputs']}, Outputs: {model_info['outputs']}")