# services/training/src/train.py

1""" 

2Training Entry Point Script: Orchestrate complete training pipeline. 

3 

4This script orchestrates the complete RF source localization training pipeline: 

51. Load training/validation sessions from MinIO and PostgreSQL 

62. Create PyTorch DataLoaders with proper batching 

73. Initialize PyTorch Lightning trainer with checkpoint callbacks 

84. Train LocalizationNet with Gaussian NLL loss 

95. Export best checkpoint to ONNX format 

106. Register model with MLflow Model Registry 

11 

Pipeline:
    Sessions (MinIO) → DataLoaders → Lightning Trainer → Best Checkpoint
           ↓                                                    ↓
    PostgreSQL (metadata)                       ONNX Export + MLflow Registry

Performance targets:
- Training throughput: 32 samples/batch (configurable)
- Validation frequency: Every epoch
- Best model selection: Lowest validation loss (early stopping)
- Checkpoint saving: Top 3 models (by validation loss)
- ONNX inference speedup: 1.5-2.5x vs PyTorch

Usage:
    python train.py --epochs 100 --batch_size 32 --lr 1e-3 --val_split 0.2
    python train.py --checkpoint /path/to/checkpoint.ckpt --resume_training
    python train.py --export_only --checkpoint /path/to/best.ckpt

Configuration:
    All parameters configurable via CLI arguments or .env file
    MLflow tracking automatic (experiment: heimdall-localization)
    Checkpoints saved to: /tmp/heimdall_checkpoints/
    ONNX export to: MinIO (heimdall-models bucket)
"""

import os
import sys
import argparse
import json
from pathlib import Path
from typing import Optional, Tuple, Dict, Any
from datetime import datetime
import logging

# PyTorch & Lightning
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    LearningRateMonitor,
)

# Project imports
import structlog
from config import settings  # Now imports from config/ package
from mlflow_setup import MLflowTracker
from onnx_export import export_and_register_model, ONNXExporter
from models.localization_net import LocalizationNet, LocalizationLightningModule
from data.dataset import HeimdallDataset
from data.features import MEL_SPECTROGRAM_SHAPE

# Configure logging
logger = structlog.get_logger(__name__)
logging.basicConfig(level=logging.INFO)
pl_logger = logging.getLogger("pytorch_lightning")
pl_logger.setLevel(logging.WARNING)  # Reduce Lightning verbosity


class TrainingPipeline:
    """
    Orchestrates the complete training workflow.

    Responsibilities:
    - Load data from MinIO and PostgreSQL
    - Create data loaders with proper batching
    - Initialize Lightning trainer with callbacks
    - Execute training loop
    - Export and register best model
    """

    def __init__(
        self,
        epochs: int = 100,
        batch_size: int = 32,
        learning_rate: float = 1e-3,
        validation_split: float = 0.2,
        num_workers: int = 4,
        accelerator: str = "gpu",
        devices: int = 1,
        checkpoint_dir: Optional[Path] = None,
        experiment_name: str = "heimdall-localization",
        run_name_prefix: str = "rf-localization",
    ):
        """
        Initialize training pipeline.

        Args:
            epochs (int): Number of training epochs
            batch_size (int): Batch size for data loaders
            learning_rate (float): Learning rate for optimizer
            validation_split (float): Fraction of data for validation
            num_workers (int): Number of PyTorch DataLoader workers
            accelerator (str): Training accelerator ("cpu", "gpu", "auto")
            devices (int): Number of GPUs (if accelerator="gpu")
            checkpoint_dir (Path): Directory for saving checkpoints
            experiment_name (str): MLflow experiment name
            run_name_prefix (str): Prefix for MLflow run name
        """

        self.epochs = epochs
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.validation_split = validation_split
        self.num_workers = num_workers
        self.accelerator = accelerator if torch.cuda.is_available() else "cpu"
        self.devices = devices if torch.cuda.is_available() else 1
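        # Descriptive note (added): if CUDA is unavailable, the two lines above
        # fall back to CPU with a single device, regardless of the requested
        # accelerator or device count.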

        # Setup checkpoint directory
        self.checkpoint_dir = checkpoint_dir or Path("/tmp/heimdall_checkpoints")
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Initialize MLflow tracker
        self.mlflow_tracker = self._init_mlflow(
            experiment_name=experiment_name,
            run_name_prefix=run_name_prefix,
        )

        # Initialize boto3 S3 client for MinIO
        import boto3
        self.s3_client = boto3.client(
            "s3",
            endpoint_url=settings.mlflow_s3_endpoint_url,
            aws_access_key_id=settings.mlflow_s3_access_key_id,
            aws_secret_access_key=settings.mlflow_s3_secret_access_key,
        )

        # Initialize ONNX exporter
        self.onnx_exporter = ONNXExporter(self.s3_client, self.mlflow_tracker)

        logger.info(
            "training_pipeline_initialized",
            epochs=epochs,
            batch_size=batch_size,
            learning_rate=learning_rate,
            accelerator=self.accelerator,
            devices=self.devices,
            checkpoint_dir=str(self.checkpoint_dir),
        )

    def _init_mlflow(
        self,
        experiment_name: str,
        run_name_prefix: str,
    ) -> MLflowTracker:
        """
        Initialize MLflow tracker.

        Args:
            experiment_name (str): MLflow experiment name
            run_name_prefix (str): Prefix for run name

        Returns:
            MLflowTracker instance
        """
        tracker = MLflowTracker(
            tracking_uri=settings.mlflow_tracking_uri,
            artifact_uri=settings.mlflow_artifact_uri,
            backend_store_uri=settings.mlflow_backend_store_uri,
            registry_uri=settings.mlflow_registry_uri,
            s3_endpoint_url=settings.mlflow_s3_endpoint_url,
            s3_access_key_id=settings.mlflow_s3_access_key_id,
            s3_secret_access_key=settings.mlflow_s3_secret_access_key,
            experiment_name=experiment_name,
        )

        # Create run name with timestamp
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        run_name = f"{run_name_prefix}_{timestamp}"

        # Start new run
        tracker.start_run(run_name)

        # Log hyperparameters
        tracker.log_params({
            "epochs": self.epochs,
            "batch_size": self.batch_size,
            "learning_rate": self.learning_rate,
            "validation_split": self.validation_split,
            "num_workers": self.num_workers,
            "accelerator": self.accelerator,
            "optimizer": "AdamW",
            "loss_function": "GaussianNLL",
            "model": "ConvNeXt-Large",
        })

        return tracker

    def load_data(
        self,
        data_dir: str = "/tmp/heimdall_training_data",
    ) -> Tuple[DataLoader, DataLoader]:
        """
        Load training and validation data.

        Loads IQ recordings from MinIO and ground truth from PostgreSQL,
        creates train/val split, and returns PyTorch DataLoaders.

        Args:
            data_dir (str): Directory containing preprocessed data

        Returns:
            Tuple of (train_loader, val_loader)
        """
        logger.info("loading_dataset", data_dir=data_dir)

        # Create dataset
        dataset = HeimdallDataset(
            data_dir=data_dir,
            split="all",
            augmentation=True,
        )
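        # Descriptive note (added): augmentation is enabled on the full dataset
        # above, so the random_split() below yields augmented samples in both
        # the train and validation subsets.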

        dataset_size = len(dataset)
        val_size = int(dataset_size * self.validation_split)
        train_size = dataset_size - val_size

        logger.info(
            "dataset_loaded",
            total_samples=dataset_size,
            train_samples=train_size,
            val_samples=val_size,
        )

        # Split into train/val
        train_dataset, val_dataset = random_split(
            dataset,
            [train_size, val_size],
            generator=torch.Generator().manual_seed(42),  # Reproducible split
        )

        # Create data loaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
        )

        logger.info(
            "dataloaders_created",
            train_batches=len(train_loader),
            val_batches=len(val_loader),
            batch_size=self.batch_size,
        )

        return train_loader, val_loader

    def create_lightning_module(self) -> LocalizationLightningModule:
        """
        Create PyTorch Lightning module for training.

        Returns:
            LocalizationLightningModule instance configured for training
        """
        logger.info(
            "creating_lightning_module",
            learning_rate=self.learning_rate,
        )

        # Initialize model
        model = LocalizationNet()

        # Wrap in Lightning module
        lightning_module = LocalizationLightningModule(
            model=model,
            learning_rate=self.learning_rate,
        )

        return lightning_module

    def create_trainer(
        self,
    ) -> pl.Trainer:
        """
        Create PyTorch Lightning trainer with callbacks.

        Callbacks:
        - ModelCheckpoint: Save top 3 models by validation loss
        - EarlyStopping: Stop if val_loss doesn't improve for 10 epochs
        - LearningRateMonitor: Track learning rate in MLflow

        Returns:
            Configured Trainer instance
        """
        logger.info(
            "creating_trainer",
            accelerator=self.accelerator,
            devices=self.devices,
            epochs=self.epochs,
        )

        # Callbacks
        checkpoint_callback = ModelCheckpoint(
            dirpath=self.checkpoint_dir,
            filename="localization-{epoch:02d}-{val_loss:.4f}",
            monitor="val_loss",
            mode="min",
            save_top_k=3,  # Keep top 3 models
            verbose=True,
        )
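        # Descriptive note (added): the filename template above typically yields
        # checkpoints named like "localization-epoch=07-val_loss=0.1234.ckpt";
        # Lightning expands each "{metric}" placeholder as "metric=value".
        # (Illustrative file name only.)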

        early_stopping_callback = EarlyStopping(
            monitor="val_loss",
            mode="min",
            patience=10,
            verbose=True,
        )

        lr_monitor_callback = LearningRateMonitor(logging_interval="epoch")

        # Create trainer
        trainer = pl.Trainer(
            max_epochs=self.epochs,
            accelerator=self.accelerator,
            devices=self.devices,
            callbacks=[
                checkpoint_callback,
                early_stopping_callback,
                lr_monitor_callback,
            ],
            log_every_n_steps=10,
            enable_checkpointing=True,
            enable_model_summary=True,
        )

        return trainer

    def train(
        self,
        train_loader: DataLoader,
        val_loader: DataLoader,
    ) -> Path:
        """
        Execute training loop.

        Args:
            train_loader (DataLoader): Training data loader
            val_loader (DataLoader): Validation data loader

        Returns:
            Path to best checkpoint
        """
        logger.info("starting_training", epochs=self.epochs)

        # Create Lightning module and trainer
        lightning_module = self.create_lightning_module()
        trainer = self.create_trainer()

        # Train
        trainer.fit(
            lightning_module,
            train_dataloaders=train_loader,
            val_dataloaders=val_loader,
        )

        # Get best checkpoint path
        best_checkpoint_path = trainer.checkpoint_callback.best_model_path

        logger.info(
            "training_complete",
            best_checkpoint=best_checkpoint_path,
            best_val_loss=trainer.checkpoint_callback.best_model_score,
        )

        # Log final metrics to MLflow
        self.mlflow_tracker.log_metric("best_val_loss", float(trainer.checkpoint_callback.best_model_score))
        self.mlflow_tracker.log_metric("final_epoch", trainer.current_epoch)
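        # Descriptive note (added): best_model_score is a torch.Tensor, hence
        # the float() conversion above before logging the value to MLflow.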

        return Path(best_checkpoint_path)

    def export_and_register(
        self,
        best_checkpoint_path: Path,
        model_name: str = "heimdall-localization-onnx",
    ) -> Dict[str, Any]:
        """
        Export best model to ONNX and register with MLflow.

        Pipeline:
        1. Load best checkpoint from training
        2. Export to ONNX format
        3. Validate ONNX (shape, accuracy)
        4. Upload to MinIO
        5. Register with MLflow Model Registry

        Args:
            best_checkpoint_path (Path): Path to best checkpoint from training
            model_name (str): Name for ONNX model in MLflow registry

        Returns:
            Dict with export results (ONNX path, S3 URI, model version, etc.)
        """
        logger.info(
            "exporting_model",
            checkpoint=str(best_checkpoint_path),
            model_name=model_name,
        )

        # Load checkpoint
        checkpoint = torch.load(best_checkpoint_path, map_location="cpu")

        # Create model and load state dict
        model = LocalizationNet()

        # Handle Lightning checkpoint format
        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
            # Remove "model." prefix added by Lightning
            state_dict = {
                k.replace("model.", "", 1): v
                for k, v in state_dict.items()
            }
            model.load_state_dict(state_dict)
        else:
            model.load_state_dict(checkpoint)
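        # Illustration of the prefix handling above (added; hypothetical key
        # names): Lightning stores keys such as "model.backbone.0.weight",
        # which become "backbone.0.weight" after stripping, matching
        # LocalizationNet's own state_dict keys.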

        # Export and register
        result = export_and_register_model(
            pytorch_model=model,
            run_id=self.mlflow_tracker.active_run_id,
            s3_client=self.s3_client,
            mlflow_tracker=self.mlflow_tracker,
            model_name=model_name,
        )

        # Log export results to MLflow
        self.mlflow_tracker.log_params({
            "onnx_model_name": result.get("model_name", "unknown"),
            "onnx_model_version": result.get("model_version", "unknown"),
            "onnx_file_size_mb": result.get("metadata", {}).get("file_size_mb", 0),
        })

        logger.info(
            "model_exported_and_registered",
            model_name=result.get("model_name"),
            model_version=result.get("model_version"),
            s3_uri=result.get("s3_uri"),
        )

        return result

    def run(
        self,
        data_dir: str = "/tmp/heimdall_training_data",
        export_only: bool = False,
        checkpoint_path: Optional[Path] = None,
    ) -> Dict[str, Any]:
        """
        Execute complete training pipeline.

        Pipeline:
        1. Load data (if not export_only)
        2. Train model (if not export_only)
        3. Export best checkpoint to ONNX
        4. Register with MLflow

        Args:
            data_dir (str): Directory containing training data
            export_only (bool): Skip training, only export from checkpoint
            checkpoint_path (Optional[Path]): Path to checkpoint (for export_only=True)

        Returns:
            Dict with pipeline results
        """
        start_time = datetime.now()

        try:
            if export_only and checkpoint_path:
                # Only export and register existing checkpoint
                logger.info("running_export_only_mode", checkpoint=str(checkpoint_path))
                result = self.export_and_register(
                    best_checkpoint_path=checkpoint_path,
                )
            else:
                # Full training pipeline
                logger.info("running_full_training_pipeline", data_dir=data_dir)

                # 1. Load data
                train_loader, val_loader = self.load_data(data_dir=data_dir)

                # 2. Train
                best_checkpoint = self.train(train_loader, val_loader)

                # 3. Export and register
                result = self.export_and_register(best_checkpoint_path=best_checkpoint)

            # Calculate elapsed time
            elapsed_time = (datetime.now() - start_time).total_seconds()

            # Log final status
            logger.info(
                "pipeline_complete",
                elapsed_seconds=elapsed_time,
                success=result.get("success", False),
            )

            # Finalize MLflow run
            self.mlflow_tracker.end_run()

            return {
                "success": True,
                "elapsed_time": elapsed_time,
                "export_result": result,
            }

        except Exception as e:
            logger.error(
                "pipeline_error",
                error=str(e),
                exc_info=True,
            )

            # End MLflow run with failed status
            self.mlflow_tracker.end_run(status="FAILED")

            raise


def parse_arguments() -> argparse.Namespace:
    """
    Parse command-line arguments.

    Returns:
        Parsed arguments
    """
    parser = argparse.ArgumentParser(
        description="Training entry point for Heimdall RF localization pipeline",
    )

    # Training parameters
    parser.add_argument(
        "--epochs",
        type=int,
        default=100,
        help="Number of training epochs (default: 100)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size for training (default: 32)",
    )
    parser.add_argument(
        "--learning_rate",
        "--lr",
        type=float,
        default=1e-3,
        help="Learning rate for optimizer (default: 1e-3)",
    )
    parser.add_argument(
        "--validation_split",
        "--val_split",
        type=float,
        default=0.2,
        help="Fraction of data for validation (default: 0.2)",
    )
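    # Descriptive note (added): argparse derives the destination from the first
    # long option string, so the "--lr" and "--val_split" aliases above still
    # populate args.learning_rate and args.validation_split.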

    # Data parameters
    parser.add_argument(
        "--data_dir",
        type=str,
        default="/tmp/heimdall_training_data",
        help="Directory containing training data",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="Number of DataLoader workers (default: 4)",
    )

    # Device parameters
    parser.add_argument(
        "--accelerator",
        type=str,
        default="gpu",
        choices=["cpu", "gpu", "auto"],
        help="Training accelerator (default: gpu)",
    )
    parser.add_argument(
        "--devices",
        type=int,
        default=1,
        help="Number of GPUs (default: 1)",
    )

    # Checkpoint parameters
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default="/tmp/heimdall_checkpoints",
        help="Directory for saving checkpoints",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="Path to existing checkpoint (for resume or export)",
    )

    # Mode parameters
    parser.add_argument(
        "--export_only",
        action="store_true",
        help="Skip training, only export and register checkpoint",
    )
    parser.add_argument(
        "--resume_training",
        action="store_true",
        help="Resume training from checkpoint",
    )

    # MLflow parameters
    parser.add_argument(
        "--experiment_name",
        type=str,
        default="heimdall-localization",
        help="MLflow experiment name",
    )
    parser.add_argument(
        "--run_name_prefix",
        type=str,
        default="rf-localization",
        help="Prefix for MLflow run name",
    )

    return parser.parse_args()


def main():
    """
    Main entry point for training pipeline.
    """
    # Parse arguments
    args = parse_arguments()

    logger.info(
        "training_pipeline_started",
        epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        accelerator=args.accelerator,
    )

    try:
        # Create pipeline
        pipeline = TrainingPipeline(
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.learning_rate,
            validation_split=args.validation_split,
            num_workers=args.num_workers,
            accelerator=args.accelerator,
            devices=args.devices,
            checkpoint_dir=Path(args.checkpoint_dir) if args.checkpoint_dir else None,
            experiment_name=args.experiment_name,
            run_name_prefix=args.run_name_prefix,
        )

        # Run pipeline
        result = pipeline.run(
            data_dir=args.data_dir,
            export_only=args.export_only,
            checkpoint_path=Path(args.checkpoint) if args.checkpoint else None,
        )

        # Print summary
        print("\n" + "="*80)
        print("TRAINING PIPELINE COMPLETE")
        print("="*80)
        print(f"Success: {result['success']}")
        print(f"Elapsed Time: {result['elapsed_time']:.2f} seconds")
        print(f"Export Result: {json.dumps(result['export_result'], indent=2)}")
        print("="*80 + "\n")

        sys.exit(0)

    except Exception as e:
        logger.error(
            "training_pipeline_failed",
            error=str(e),
            exc_info=True,
        )
        print(f"\nTraining pipeline failed: {str(e)}\n", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()