# services/training/src/onnx_export.py
1"""
2ONNX Export Module: Export LocalizationNet to ONNX format and upload to MinIO.
4This module handles:
51. Converting PyTorch LocalizationNet to ONNX format
62. Input/output validation and shape verification
73. Quantization (optional, for inference optimization)
84. Upload to MinIO artifact storage
95. MLflow integration (register model, track versions)
106. Batch prediction for verification
12ONNX (Open Neural Network Exchange) provides:
13- Platform-independent model format
14- Hardware acceleration support (CPU, GPU, mobile, edge devices)
15- Inference optimization for production
16- Model interoperability (use in any framework)
18Features:
19- Dynamic batch size support for inference flexibility
20- Input shape: (batch_size, 3, 128, 32) - mel-spectrogram
21- Output shapes:
22 - positions: (batch_size, 2) - [lat, lon]
23 - uncertainties: (batch_size, 2) - [sigma_x, sigma_y]
25Performance:
26- ONNX inference: ~20-30ms on CPU, <5ms on GPU (vs PyTorch ~50ms)
27- Quantization: 4x smaller model (~100MB → ~25MB)
28- Supports batching for throughput optimization
29"""
import torch
import torch.nn as nn
import onnx
import onnxruntime as ort
import numpy as np
import structlog
from pathlib import Path
from typing import Any, Dict, Optional
import json
import hashlib
from datetime import datetime

logger = structlog.get_logger(__name__)

class ONNXExporter:
    """
    Export LocalizationNet to ONNX format with validation and optimization.

    Workflow:
    1. Load trained PyTorch Lightning checkpoint
    2. Convert to ONNX
    3. Validate ONNX model
    4. (Optional) Quantize for inference optimization
    5. Upload to MinIO
    6. Register with MLflow Model Registry
    """

    def __init__(self, s3_client, mlflow_tracker):
        """
        Initialize ONNX exporter.

        Args:
            s3_client: boto3 S3 client (for MinIO)
            mlflow_tracker: MLflowTracker instance (for model registration)
        """
        self.s3_client = s3_client
        self.mlflow_tracker = mlflow_tracker
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        logger.info(
            "onnx_exporter_initialized",
            device=str(self.device),
            onnxruntime_version=ort.__version__,
        )

    def export_to_onnx(
        self,
        model: nn.Module,
        output_path: Path,
        opset_version: int = 14,
        do_constant_folding: bool = True,
    ) -> Path:
        """
        Export PyTorch model to ONNX format.

        Args:
            model (nn.Module): LocalizationNet instance (eval mode)
            output_path (Path): Where to save ONNX file
            opset_version (int): ONNX opset version (14 is widely supported
                across runtimes; newer opsets mainly add operators)
            do_constant_folding (bool): Optimize constant computations

        Returns:
            Path to exported ONNX file

        Raises:
            RuntimeError: If export fails
        """
        # Model and dummy input must live on the same device for tracing
        model = model.to(self.device)
        model.eval()

        # Create dummy input matching expected shape
        # (batch_size=1, channels=3, height=128, width=32)
        dummy_input = torch.randn(1, 3, 128, 32, device=self.device)

        # Input/output names (required by ONNX)
        input_names = ['mel_spectrogram']
        output_names = ['positions', 'uncertainties']

        # Dynamic axes for variable batch size; only the batch dimension is
        # dynamic, the remaining dims (3, 128, 32) stay fixed
        dynamic_axes = {
            'mel_spectrogram': {0: 'batch_size'},
            'positions': {0: 'batch_size'},
            'uncertainties': {0: 'batch_size'},
        }

        try:
            logger.info(
                "exporting_to_onnx",
                opset_version=opset_version,
                output_path=str(output_path),
            )

            torch.onnx.export(
                model,
                dummy_input,
                str(output_path),
                input_names=input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                opset_version=opset_version,
                do_constant_folding=do_constant_folding,
                verbose=False,
            )

            file_size_mb = output_path.stat().st_size / (1024 * 1024)
            logger.info(
                "onnx_export_successful",
                output_path=str(output_path),
                file_size_mb=f"{file_size_mb:.2f}",
            )

            return output_path

        except Exception as e:
            logger.error(
                "onnx_export_failed",
                error=str(e),
                output_path=str(output_path),
            )
            raise RuntimeError(f"ONNX export failed: {e}") from e

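    # The module docstring lists optional quantization, but no code for it
    # appears in this file. Below is a minimal sketch using onnxruntime's
    # dynamic quantization (int8 weights, float activations), which is the
    # kind of step the "~4x smaller" claim above typically refers to. The
    # method name and its place in the workflow are assumptions.
    def quantize_model(self, onnx_path: Path, output_path: Path) -> Path:
        """Sketch: dynamically quantize an exported model's weights to int8."""
        from onnxruntime.quantization import quantize_dynamic, QuantType

        quantize_dynamic(
            model_input=str(onnx_path),
            model_output=str(output_path),
            weight_type=QuantType.QInt8,  # int8 weights, float activations
        )
        logger.info("onnx_model_quantized", output_path=str(output_path))
        return output_path
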
    def validate_onnx_model(self, onnx_path: Path) -> Dict[str, Any]:
        """
        Validate ONNX model structure and shapes.

        Args:
            onnx_path (Path): Path to ONNX file

        Returns:
            Dict with model info:
            - inputs: List of input specifications
            - outputs: List of output specifications
            - opset_version: ONNX opset version
            - producer_name: Framework that exported the model
            - ir_version: ONNX IR version

        Raises:
            ValueError: If model validation fails
        """
        try:
            # Load and validate ONNX model
            onnx_model = onnx.load(str(onnx_path))
            onnx.checker.check_model(onnx_model)

            # Extract model info
            graph = onnx_model.graph

            inputs_info = []
            for input_tensor in graph.input:
                # dim_value is 0 for dynamic axes (the batch dim is exported
                # with dim_param='batch_size' rather than a fixed size)
                shape = [dim.dim_value for dim in input_tensor.type.tensor_type.shape.dim]
                inputs_info.append({
                    'name': input_tensor.name,
                    'shape': shape,
                    'dtype': str(input_tensor.type.tensor_type.data_type),
                })

            outputs_info = []
            for output_tensor in graph.output:
                shape = [dim.dim_value for dim in output_tensor.type.tensor_type.shape.dim]
                outputs_info.append({
                    'name': output_tensor.name,
                    'shape': shape,
                    'dtype': str(output_tensor.type.tensor_type.data_type),
                })

            model_info = {
                'inputs': inputs_info,
                'outputs': outputs_info,
                'opset_version': onnx_model.opset_import[0].version,
                'producer_name': onnx_model.producer_name,
                'ir_version': onnx_model.ir_version,
            }

            logger.info(
                "onnx_model_validated",
                onnx_path=str(onnx_path),
                inputs=len(inputs_info),
                outputs=len(outputs_info),
            )

            return model_info

        except Exception as e:
            logger.error(
                "onnx_validation_failed",
                error=str(e),
                onnx_path=str(onnx_path),
            )
            raise ValueError(f"ONNX model validation failed: {e}") from e

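    # Illustrative shape of the returned dict for this network (derived from
    # the export settings above, not captured output): the dynamic batch axis
    # is reported as 0, and dtype '1' is the TensorProto code for FLOAT.
    #
    #     {'inputs': [{'name': 'mel_spectrogram', 'shape': [0, 3, 128, 32],
    #                  'dtype': '1'}],
    #      'outputs': [{'name': 'positions', 'shape': [0, 2], 'dtype': '1'},
    #                  {'name': 'uncertainties', 'shape': [0, 2], 'dtype': '1'}],
    #      'opset_version': 14, ...}
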
    def test_onnx_inference(
        self,
        onnx_path: Path,
        pytorch_model: nn.Module,
        num_batches: int = 5,
        batch_size: int = 8,
        tolerance: float = 1e-5,
    ) -> Dict[str, float]:
        """
        Test ONNX inference against PyTorch for accuracy verification.

        Compares outputs of the PyTorch model vs ONNX Runtime to ensure
        numerical equivalence after export.

        Args:
            onnx_path (Path): Path to ONNX file
            pytorch_model (nn.Module): Original PyTorch model
            num_batches (int): Number of test batches
            batch_size (int): Batch size for testing
            tolerance (float): Maximum allowed difference (MAE)

        Returns:
            Dict with comparison metrics:
            - positions_mae: Mean Absolute Error for positions
            - uncertainties_mae: Mean Absolute Error for uncertainties
            - inference_time_onnx_ms: ONNX inference time (ms)
            - inference_time_pytorch_ms: PyTorch inference time (ms)
            - speedup: ONNX vs PyTorch speedup factor
            - passed: Boolean, True if tolerance met

        Raises:
            AssertionError: If accuracy tolerance exceeded
        """
        import time

        pytorch_model = pytorch_model.to(self.device)
        pytorch_model.eval()
        # Explicit providers avoid onnxruntime's "no providers" error on GPU
        # builds; fall back to whatever this installation offers
        sess = ort.InferenceSession(
            str(onnx_path),
            providers=ort.get_available_providers(),
        )

        positions_diffs = []
        uncertainties_diffs = []
        times_onnx = []
        times_pytorch = []

        with torch.no_grad():
            for _ in range(num_batches):
                # Create test batch
                test_input = torch.randn(batch_size, 3, 128, 32, device=self.device)
                test_input_np = test_input.cpu().numpy().astype(np.float32)

                # PyTorch inference (on CUDA these wall-clock timings are
                # approximate unless torch.cuda.synchronize() is called)
                t0 = time.perf_counter()
                py_positions, py_uncertainties = pytorch_model(test_input)
                t_pytorch = (time.perf_counter() - t0) * 1000  # ms
                times_pytorch.append(t_pytorch)

                py_positions_np = py_positions.cpu().numpy()
                py_uncertainties_np = py_uncertainties.cpu().numpy()

                # ONNX inference
                t0 = time.perf_counter()
                onnx_outputs = sess.run(
                    None,  # Output names = all outputs
                    {'mel_spectrogram': test_input_np}
                )
                t_onnx = (time.perf_counter() - t0) * 1000  # ms
                times_onnx.append(t_onnx)

                onnx_positions = onnx_outputs[0]
                onnx_uncertainties = onnx_outputs[1]

                # Compare outputs
                pos_diff = np.abs(py_positions_np - onnx_positions).mean()
                unc_diff = np.abs(py_uncertainties_np - onnx_uncertainties).mean()

                positions_diffs.append(pos_diff)
                uncertainties_diffs.append(unc_diff)

        positions_mae = np.mean(positions_diffs)
        uncertainties_mae = np.mean(uncertainties_diffs)
        mean_time_onnx = np.mean(times_onnx)
        mean_time_pytorch = np.mean(times_pytorch)
        speedup = mean_time_pytorch / mean_time_onnx

        passed = (positions_mae < tolerance) and (uncertainties_mae < tolerance)

        results = {
            'positions_mae': float(positions_mae),
            'uncertainties_mae': float(uncertainties_mae),
            'inference_time_onnx_ms': float(mean_time_onnx),
            'inference_time_pytorch_ms': float(mean_time_pytorch),
            'speedup': float(speedup),
            'passed': passed,
        }

        logger.info(
            "onnx_inference_test_complete",
            positions_mae=f"{positions_mae:.2e}",
            uncertainties_mae=f"{uncertainties_mae:.2e}",
            speedup=f"{speedup:.2f}x",
            passed=passed,
        )

        if not passed:
            raise AssertionError(
                f"ONNX inference accuracy check failed: "
                f"positions_mae={positions_mae}, uncertainties_mae={uncertainties_mae}"
            )

        return results

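    # Reading the results: with identical float32 inputs, an exported graph
    # and its PyTorch source typically agree to well below the 1e-5 default
    # tolerance on CPU, so a failure here usually points at a real export
    # problem rather than benign float noise.
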
    def upload_to_minio(
        self,
        onnx_path: Path,
        bucket_name: str = 'heimdall-models',
        object_name: Optional[str] = None,
    ) -> str:
        """
        Upload ONNX model to MinIO (S3-compatible storage).

        Args:
            onnx_path (Path): Local path to ONNX file
            bucket_name (str): MinIO bucket (default: heimdall-models)
            object_name (str): S3 object path (default: models/localization/v{timestamp}.onnx)

        Returns:
            S3 URI (s3://bucket/path)

        Raises:
            RuntimeError: If upload fails
        """
        try:
            if object_name is None:
                timestamp = datetime.utcnow().strftime('%Y%m%d_%H%M%S')
                object_name = f'models/localization/v{timestamp}.onnx'

            # Read file
            with open(onnx_path, 'rb') as f:
                file_data = f.read()

            # Upload to MinIO
            self.s3_client.put_object(
                Bucket=bucket_name,
                Key=object_name,
                Body=file_data,
                ContentType='application/octet-stream',
                Metadata={
                    'export-date': datetime.utcnow().isoformat(),
                    'model-type': 'localization-net',
                    'format': 'onnx',
                    'file-size': str(len(file_data)),
                },
            )

            s3_uri = f's3://{bucket_name}/{object_name}'

            logger.info(
                "onnx_uploaded_to_minio",
                bucket=bucket_name,
                object=object_name,
                file_size_mb=f"{len(file_data) / (1024*1024):.2f}",
                s3_uri=s3_uri,
            )

            return s3_uri

        except Exception as e:
            logger.error(
                "onnx_upload_failed",
                error=str(e),
                bucket=bucket_name,
            )
            raise RuntimeError(f"Failed to upload ONNX to MinIO: {e}") from e

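    # Consumer-side sketch (an assumed usage pattern, not part of this
    # module): the uploaded model can be fetched back with the same boto3
    # client. The object key below is a placeholder following the default
    # naming scheme.
    #
    #     s3_client.download_file(
    #         'heimdall-models',
    #         'models/localization/v20251025_161800.onnx',
    #         '/tmp/model.onnx',
    #     )
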
    def get_model_metadata(
        self,
        onnx_path: Path,
        pytorch_model: nn.Module,
        run_id: str,
        inference_metrics: Optional[Dict] = None,
    ) -> Dict:
        """
        Generate comprehensive metadata for the exported model.

        Args:
            onnx_path (Path): Path to ONNX file
            pytorch_model (nn.Module): Original PyTorch model
            run_id (str): MLflow run ID
            inference_metrics (Dict): Inference test results

        Returns:
            Dict with model metadata
        """
        file_size = onnx_path.stat().st_size
        file_hash = hashlib.sha256()

        # Hash in 4 KiB chunks to avoid loading the whole file into memory
        with open(onnx_path, 'rb') as f:
            for chunk in iter(lambda: f.read(4096), b''):
                file_hash.update(chunk)

        metadata = {
            'model_type': 'LocalizationNet',
            'backbone': 'ConvNeXt-Large',
            'input_shape': [1, 3, 128, 32],
            'output_names': ['positions', 'uncertainties'],
            'output_shapes': {
                'positions': [1, 2],
                'uncertainties': [1, 2],
            },
            'export_date': datetime.utcnow().isoformat(),
            'onnx_file_size_bytes': file_size,
            'onnx_file_size_mb': file_size / (1024 * 1024),
            'onnx_file_sha256': file_hash.hexdigest(),
            'mlflow_run_id': run_id,
            'pytorch_params': pytorch_model.get_params_count(),
            'inference_metrics': inference_metrics or {},
        }

        return metadata

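    # Integrity-check sketch (an assumption about how the sha256 field might
    # be consumed): recompute the digest after download and compare it with
    # metadata['onnx_file_sha256'] before loading the model.
    #
    #     digest = hashlib.sha256(Path('/tmp/model.onnx').read_bytes()).hexdigest()
    #     assert digest == metadata['onnx_file_sha256'], "ONNX file corrupted"
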
    def register_with_mlflow(
        self,
        model_name: str,
        s3_uri: str,
        metadata: Dict,
        stage: str = 'Staging',
    ) -> Dict:
        """
        Register ONNX model with MLflow Model Registry.

        Args:
            model_name (str): Model registry name
            s3_uri (str): S3 URI to ONNX model
            metadata (Dict): Model metadata
            stage (str): Initial stage ('Staging' or 'Production')

        Returns:
            Dict with registration details
        """
        try:
            # Register model via MLflowTracker; s3_uri already carries the
            # s3:// scheme from upload_to_minio, so pass it through as-is
            model_version = self.mlflow_tracker.register_model(
                model_name=model_name,
                model_uri=s3_uri,
                tags={
                    'framework': 'pytorch',
                    'format': 'onnx',
                    'input_shape': '1,3,128,32',
                    'output_count': '2',
                },
            )

            # Transition to the requested stage
            if stage in ['Staging', 'Production']:
                self.mlflow_tracker.transition_model_stage(
                    model_name=model_name,
                    version=model_version,
                    stage=stage,
                )

            # Log metadata as artifact
            metadata_json = json.dumps(metadata, indent=2)
            self.mlflow_tracker.log_artifact(
                metadata_json,
                artifact_path=f'models/{model_name}/metadata.json'
            )

            result = {
                'model_name': model_name,
                'model_version': model_version,
                'stage': stage,
                's3_uri': s3_uri,
            }

            logger.info(
                "model_registered_with_mlflow",
                model_name=model_name,
                model_version=model_version,
                stage=stage,
            )

            return result

        except Exception as e:
            logger.error(
                "mlflow_registration_failed",
                error=str(e),
                model_name=model_name,
            )
            raise RuntimeError(f"Failed to register model with MLflow: {e}") from e

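    # Note: MLflow's classic model registry defines the stages None, Staging,
    # Production, and Archived; this exporter only ever promotes to Staging
    # or Production via the tracker wrapper above.
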
def export_and_register_model(
    pytorch_model: nn.Module,
    run_id: str,
    s3_client,
    mlflow_tracker,
    output_dir: Path = Path('/tmp/onnx_exports'),
    model_name: str = 'heimdall-localization-onnx',
) -> Dict:
    """
    Complete workflow: export PyTorch → ONNX → validate → upload → register.

    Args:
        pytorch_model (nn.Module): Trained LocalizationNet model
        run_id (str): MLflow run ID (for tracking)
        s3_client: boto3 S3 client
        mlflow_tracker: MLflowTracker instance
        output_dir (Path): Directory for temporary ONNX files
        model_name (str): MLflow model registry name

    Returns:
        Dict with complete export and registration details

    Workflow:
    1. Export to ONNX
    2. Validate ONNX structure
    3. Test inference accuracy
    4. Upload to MinIO
    5. Register with MLflow
    6. Log metadata
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    exporter = ONNXExporter(s3_client, mlflow_tracker)

    try:
        # Step 1: Export to ONNX
        onnx_path = output_dir / f'{model_name}_v{run_id[:8]}.onnx'
        exporter.export_to_onnx(pytorch_model, onnx_path)

        # Step 2: Validate ONNX
        model_info = exporter.validate_onnx_model(onnx_path)

        # Step 3: Test inference
        inference_metrics = exporter.test_onnx_inference(onnx_path, pytorch_model)

        # Step 4: Upload to MinIO
        s3_uri = exporter.upload_to_minio(onnx_path)

        # Step 5: Get metadata
        metadata = exporter.get_model_metadata(
            onnx_path,
            pytorch_model,
            run_id,
            inference_metrics,
        )

        # Step 6: Register with MLflow
        registration = exporter.register_with_mlflow(
            model_name,
            s3_uri,
            metadata,
            stage='Staging',
        )

        logger.info(
            "onnx_export_complete",
            model_name=model_name,
            onnx_file_size_mb=f"{metadata['onnx_file_size_mb']:.2f}",
            s3_uri=s3_uri,
            mlflow_version=registration['model_version'],
            speedup=f"{inference_metrics['speedup']:.2f}x",
        )

        return {
            'success': True,
            'model_name': model_name,
            'run_id': run_id,
            'onnx_path': str(onnx_path),
            's3_uri': s3_uri,
            'model_info': model_info,
            'metadata': metadata,
            'inference_metrics': inference_metrics,
            'registration': registration,
        }

    except Exception as e:
        logger.error(
            "onnx_export_workflow_failed",
            error=str(e),
            model_name=model_name,
        )
        return {
            'success': False,
            'error': str(e),
        }

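# Minimal driver sketch for the workflow above. The MinIO endpoint,
# credentials, and tracker construction are placeholders, not this
# project's real configuration:
#
#     import boto3
#     s3 = boto3.client(
#         's3',
#         endpoint_url='http://minio:9000',        # placeholder endpoint
#         aws_access_key_id='minioadmin',          # placeholder credentials
#         aws_secret_access_key='minioadmin',
#     )
#     result = export_and_register_model(
#         pytorch_model=trained_model,             # a trained LocalizationNet
#         run_id=mlflow_run_id,                    # active MLflow run ID
#         s3_client=s3,
#         mlflow_tracker=tracker,                  # project MLflowTracker
#     )
#     assert result['success'], result.get('error')
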
if __name__ == "__main__":
    # Quick test - verify ONNX export works with a dummy model.
    import logging
    logging.basicConfig(level=logging.INFO)

    from src.models.localization_net import LocalizationNet

    logger.info("Testing ONNX export...")

    # Create dummy model
    model = LocalizationNet(pretrained=False)
    model.eval()

    # Export (no S3 client or MLflow tracker needed for a local export)
    exporter = ONNXExporter(None, None)
    output_path = Path('/tmp/test_model.onnx')
    exporter.export_to_onnx(model, output_path)

    # Validate
    model_info = exporter.validate_onnx_model(output_path)
    logger.info(f"✅ ONNX export successful! Inputs: {model_info['inputs']}, Outputs: {model_info['outputs']}")