# services/training/src/train.py
1"""
2Training Entry Point Script: Orchestrate complete training pipeline.
4This script orchestrates the complete RF source localization training pipeline:
51. Load training/validation sessions from MinIO and PostgreSQL
62. Create PyTorch DataLoaders with proper batching
73. Initialize PyTorch Lightning trainer with checkpoint callbacks
84. Train LocalizationNet with Gaussian NLL loss
95. Export best checkpoint to ONNX format
106. Register model with MLflow Model Registry
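
Loss note (a sketch of the standard per-coordinate Gaussian negative
log-likelihood, assuming the network predicts a mean mu and a variance
sigma^2 for each output coordinate):

    NLL = 0.5 * (log(sigma^2) + (y - mu)^2 / sigma^2) + const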

Pipeline:
    Sessions (MinIO) → DataLoaders → Lightning Trainer → Best Checkpoint
         ↓                                    ↓
    PostgreSQL (metadata)          ONNX Export + MLflow Registry

Performance targets:
- Training throughput: 32 samples/batch (configurable)
- Validation frequency: every epoch
- Best model selection: lowest validation loss (early stopping)
- Checkpoint saving: top 3 models (by validation loss)
- ONNX inference speedup: 1.5-2.5x vs PyTorch

Usage:
    python train.py --epochs 100 --batch_size 32 --lr 1e-3 --val_split 0.2
    python train.py --checkpoint /path/to/checkpoint.ckpt --resume_training
    python train.py --export_only --checkpoint /path/to/best.ckpt

Configuration:
    All parameters are configurable via CLI arguments or a .env file.
    MLflow tracking is automatic (experiment: heimdall-localization).
    Checkpoints are saved to: /tmp/heimdall_checkpoints/
    ONNX models are exported to: MinIO (heimdall-models bucket)
"""

import sys
import argparse
import json
import logging
from pathlib import Path
from typing import Optional, Tuple, Dict, Any
from datetime import datetime

# PyTorch & Lightning
import torch
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    LearningRateMonitor,
)

# Project imports
import structlog
from config import settings  # Imports from the config/ package
from mlflow_setup import MLflowTracker
from onnx_export import export_and_register_model, ONNXExporter
from models.localization_net import LocalizationNet, LocalizationLightningModule
from data.dataset import HeimdallDataset

# Configure logging
logger = structlog.get_logger(__name__)
logging.basicConfig(level=logging.INFO)
pl_logger = logging.getLogger("pytorch_lightning")
pl_logger.setLevel(logging.WARNING)  # Reduce Lightning verbosity
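
# Minimal programmatic use (a sketch; the argument names match TrainingPipeline
# below, but the data path and layout are deployment-specific assumptions):
#     pipeline = TrainingPipeline(epochs=10, batch_size=16, accelerator="cpu")
#     pipeline.run(data_dir="/tmp/heimdall_training_data")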


class TrainingPipeline:
    """
    Orchestrates the complete training workflow.

    Responsibilities:
    - Load data from MinIO and PostgreSQL
    - Create data loaders with proper batching
    - Initialize Lightning trainer with callbacks
    - Execute training loop
    - Export and register best model
    """

    def __init__(
        self,
        epochs: int = 100,
        batch_size: int = 32,
        learning_rate: float = 1e-3,
        validation_split: float = 0.2,
        num_workers: int = 4,
        accelerator: str = "gpu",
        devices: int = 1,
        checkpoint_dir: Optional[Path] = None,
        experiment_name: str = "heimdall-localization",
        run_name_prefix: str = "rf-localization",
    ):
        """
        Initialize training pipeline.

        Args:
            epochs (int): Number of training epochs
            batch_size (int): Batch size for data loaders
            learning_rate (float): Learning rate for optimizer
            validation_split (float): Fraction of data for validation
            num_workers (int): Number of PyTorch DataLoader workers
            accelerator (str): Training accelerator ("cpu", "gpu", "auto")
            devices (int): Number of GPUs (if accelerator="gpu")
            checkpoint_dir (Path): Directory for saving checkpoints
            experiment_name (str): MLflow experiment name
            run_name_prefix (str): Prefix for MLflow run name
        """
        self.epochs = epochs
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.validation_split = validation_split
        self.num_workers = num_workers
        # Fall back to CPU (single device) when CUDA is unavailable,
        # regardless of the requested accelerator.
        self.accelerator = accelerator if torch.cuda.is_available() else "cpu"
        self.devices = devices if torch.cuda.is_available() else 1

        # Setup checkpoint directory
        self.checkpoint_dir = checkpoint_dir or Path("/tmp/heimdall_checkpoints")
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Initialize MLflow tracker
        self.mlflow_tracker = self._init_mlflow(
            experiment_name=experiment_name,
            run_name_prefix=run_name_prefix,
        )

        # Initialize boto3 S3 client for MinIO
        import boto3

        self.s3_client = boto3.client(
            "s3",
            endpoint_url=settings.mlflow_s3_endpoint_url,
            aws_access_key_id=settings.mlflow_s3_access_key_id,
            aws_secret_access_key=settings.mlflow_s3_secret_access_key,
        )

        # Initialize ONNX exporter
        self.onnx_exporter = ONNXExporter(self.s3_client, self.mlflow_tracker)

        logger.info(
            "training_pipeline_initialized",
            epochs=epochs,
            batch_size=batch_size,
            learning_rate=learning_rate,
            accelerator=self.accelerator,
            devices=self.devices,
            checkpoint_dir=str(self.checkpoint_dir),
        )
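
        # Note: constructing TrainingPipeline starts an MLflow run as a side
        # effect (via _init_mlflow above), so instantiate it only when a
        # tracked run is actually intended.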

    def _init_mlflow(
        self,
        experiment_name: str,
        run_name_prefix: str,
    ) -> MLflowTracker:
        """
        Initialize MLflow tracker.

        Args:
            experiment_name (str): MLflow experiment name
            run_name_prefix (str): Prefix for run name

        Returns:
            MLflowTracker instance
        """
        tracker = MLflowTracker(
            tracking_uri=settings.mlflow_tracking_uri,
            artifact_uri=settings.mlflow_artifact_uri,
            backend_store_uri=settings.mlflow_backend_store_uri,
            registry_uri=settings.mlflow_registry_uri,
            s3_endpoint_url=settings.mlflow_s3_endpoint_url,
            s3_access_key_id=settings.mlflow_s3_access_key_id,
            s3_secret_access_key=settings.mlflow_s3_secret_access_key,
            experiment_name=experiment_name,
        )

        # Create run name with timestamp
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        run_name = f"{run_name_prefix}_{timestamp}"

        # Start new run
        tracker.start_run(run_name)

        # Log hyperparameters
        tracker.log_params({
            "epochs": self.epochs,
            "batch_size": self.batch_size,
            "learning_rate": self.learning_rate,
            "validation_split": self.validation_split,
            "num_workers": self.num_workers,
            "accelerator": self.accelerator,
            "optimizer": "AdamW",
            "loss_function": "GaussianNLL",
            "model": "ConvNeXt-Large",
        })

        return tracker

    def load_data(
        self,
        data_dir: str = "/tmp/heimdall_training_data",
    ) -> Tuple[DataLoader, DataLoader]:
        """
        Load training and validation data.

        Loads IQ recordings from MinIO and ground truth from PostgreSQL,
        creates a train/val split, and returns PyTorch DataLoaders.

        Args:
            data_dir (str): Directory containing preprocessed data

        Returns:
            Tuple of (train_loader, val_loader)
        """
        logger.info("loading_dataset", data_dir=data_dir)

        # Create dataset. Note: augmentation=True is applied to the full
        # dataset, so the validation subset produced below is augmented too;
        # disabling augmentation for validation would require separate
        # Dataset instances per split.
        dataset = HeimdallDataset(
            data_dir=data_dir,
            split="all",
            augmentation=True,
        )

        dataset_size = len(dataset)
        val_size = int(dataset_size * self.validation_split)
        train_size = dataset_size - val_size
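        # e.g. 10,000 samples with validation_split=0.2 -> 2,000 val / 8,000 train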

        logger.info(
            "dataset_loaded",
            total_samples=dataset_size,
            train_samples=train_size,
            val_samples=val_size,
        )

        # Split into train/val
        train_dataset, val_dataset = random_split(
            dataset,
            [train_size, val_size],
            generator=torch.Generator().manual_seed(42),  # Reproducible split
        )

        # Create data loaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
        )

        logger.info(
            "dataloaders_created",
            train_batches=len(train_loader),
            val_batches=len(val_loader),
            batch_size=self.batch_size,
        )

        return train_loader, val_loader

    def create_lightning_module(self) -> LocalizationLightningModule:
        """
        Create PyTorch Lightning module for training.

        Returns:
            LocalizationLightningModule instance configured for training
        """
        logger.info(
            "creating_lightning_module",
            learning_rate=self.learning_rate,
        )

        # Initialize model
        model = LocalizationNet()

        # Wrap in Lightning module
        lightning_module = LocalizationLightningModule(
            model=model,
            learning_rate=self.learning_rate,
        )

        return lightning_module

    def create_trainer(self) -> pl.Trainer:
        """
        Create PyTorch Lightning trainer with callbacks.

        Callbacks:
        - ModelCheckpoint: Save top 3 models by validation loss
        - EarlyStopping: Stop if val_loss doesn't improve for 10 epochs
        - LearningRateMonitor: Track learning rate in MLflow

        Returns:
            Configured Trainer instance
        """
        logger.info(
            "creating_trainer",
            accelerator=self.accelerator,
            devices=self.devices,
            epochs=self.epochs,
        )

        # Callbacks
        checkpoint_callback = ModelCheckpoint(
            dirpath=self.checkpoint_dir,
            filename="localization-{epoch:02d}-{val_loss:.4f}",
            monitor="val_loss",
            mode="min",
            save_top_k=3,  # Keep top 3 models
            verbose=True,
        )
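        # With this template, Lightning writes checkpoints named like
        # "localization-epoch=07-val_loss=0.1234.ckpt".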

        early_stopping_callback = EarlyStopping(
            monitor="val_loss",
            mode="min",
            patience=10,
            verbose=True,
        )
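        # Both callbacks monitor "val_loss", which the LightningModule must log
        # during validation for checkpoint selection and early stopping to work.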

        lr_monitor_callback = LearningRateMonitor(logging_interval="epoch")

        # Create trainer
        trainer = pl.Trainer(
            max_epochs=self.epochs,
            accelerator=self.accelerator,
            devices=self.devices,
            callbacks=[
                checkpoint_callback,
                early_stopping_callback,
                lr_monitor_callback,
            ],
            log_every_n_steps=10,
            enable_checkpointing=True,
            enable_model_summary=True,
        )

        return trainer

    def train(
        self,
        train_loader: DataLoader,
        val_loader: DataLoader,
        resume_checkpoint: Optional[Path] = None,
    ) -> Path:
        """
        Execute training loop.

        Args:
            train_loader (DataLoader): Training data loader
            val_loader (DataLoader): Validation data loader
            resume_checkpoint (Optional[Path]): Checkpoint to resume from

        Returns:
            Path to best checkpoint
        """
        logger.info("starting_training", epochs=self.epochs)

        # Create Lightning module and trainer
        lightning_module = self.create_lightning_module()
        trainer = self.create_trainer()

        # Train (optionally resuming optimizer/epoch state from a checkpoint)
        trainer.fit(
            lightning_module,
            train_dataloaders=train_loader,
            val_dataloaders=val_loader,
            ckpt_path=str(resume_checkpoint) if resume_checkpoint else None,
        )

        # Get best checkpoint path
        best_checkpoint_path = trainer.checkpoint_callback.best_model_path

        logger.info(
            "training_complete",
            best_checkpoint=best_checkpoint_path,
            best_val_loss=trainer.checkpoint_callback.best_model_score,
        )

        # Log final metrics to MLflow
        self.mlflow_tracker.log_metric(
            "best_val_loss", float(trainer.checkpoint_callback.best_model_score)
        )
        self.mlflow_tracker.log_metric("final_epoch", trainer.current_epoch)

        return Path(best_checkpoint_path)

    def export_and_register(
        self,
        best_checkpoint_path: Path,
        model_name: str = "heimdall-localization-onnx",
    ) -> Dict[str, Any]:
        """
        Export best model to ONNX and register with MLflow.

        Pipeline:
        1. Load best checkpoint from training
        2. Export to ONNX format
        3. Validate ONNX (shape, accuracy)
        4. Upload to MinIO
        5. Register with MLflow Model Registry

        Args:
            best_checkpoint_path (Path): Path to best checkpoint from training
            model_name (str): Name for ONNX model in MLflow registry

        Returns:
            Dict with export results (ONNX path, S3 URI, model version, etc.)
        """
        logger.info(
            "exporting_model",
            checkpoint=str(best_checkpoint_path),
            model_name=model_name,
        )

        # Load checkpoint. Lightning checkpoints contain more than raw tensors,
        # so weights_only=False is required on PyTorch >= 2.6 (where the
        # default flipped to True); only load checkpoints you trust.
        checkpoint = torch.load(
            best_checkpoint_path, map_location="cpu", weights_only=False
        )

        # Create model and load state dict
        model = LocalizationNet()

        # Handle Lightning checkpoint format
        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
            # Strip the "model." prefix that the LightningModule wrapper adds,
            # e.g. "model.backbone.0.weight" -> "backbone.0.weight"
            state_dict = {
                (k[len("model."):] if k.startswith("model.") else k): v
                for k, v in state_dict.items()
            }
            model.load_state_dict(state_dict)
        else:
            model.load_state_dict(checkpoint)

        # Export and register
        result = export_and_register_model(
            pytorch_model=model,
            run_id=self.mlflow_tracker.active_run_id,
            s3_client=self.s3_client,
            mlflow_tracker=self.mlflow_tracker,
            model_name=model_name,
        )

        # Log export results to MLflow
        self.mlflow_tracker.log_params({
            "onnx_model_name": result.get("model_name", "unknown"),
            "onnx_model_version": result.get("model_version", "unknown"),
            "onnx_file_size_mb": result.get("metadata", {}).get("file_size_mb", 0),
        })

        logger.info(
            "model_exported_and_registered",
            model_name=result.get("model_name"),
            model_version=result.get("model_version"),
            s3_uri=result.get("s3_uri"),
        )

        return result
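
    # Quick local sanity check for an exported model (a sketch, not wired into
    # the pipeline; assumes onnxruntime is installed and that the exporter
    # names the graph input "input" -- adjust to the actual exported names):
    #     import onnxruntime as ort
    #     sess = ort.InferenceSession("model.onnx")
    #     outputs = sess.run(None, {"input": dummy_batch.numpy()})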

    def run(
        self,
        data_dir: str = "/tmp/heimdall_training_data",
        export_only: bool = False,
        checkpoint_path: Optional[Path] = None,
        resume_training: bool = False,
    ) -> Dict[str, Any]:
        """
        Execute complete training pipeline.

        Pipeline:
        1. Load data (if not export_only)
        2. Train model (if not export_only)
        3. Export best checkpoint to ONNX
        4. Register with MLflow

        Args:
            data_dir (str): Directory containing training data
            export_only (bool): Skip training, only export from checkpoint
            checkpoint_path (Optional[Path]): Path to checkpoint
                (for export_only or resume_training)
            resume_training (bool): Resume training from checkpoint_path

        Returns:
            Dict with pipeline results
        """
        start_time = datetime.now()

        try:
            if export_only and checkpoint_path:
                # Only export and register existing checkpoint
                logger.info("running_export_only_mode", checkpoint=str(checkpoint_path))
                result = self.export_and_register(
                    best_checkpoint_path=checkpoint_path,
                )
            else:
                # Full training pipeline
                logger.info("running_full_training_pipeline", data_dir=data_dir)

                # 1. Load data
                train_loader, val_loader = self.load_data(data_dir=data_dir)

                # 2. Train (optionally resuming from an existing checkpoint)
                best_checkpoint = self.train(
                    train_loader,
                    val_loader,
                    resume_checkpoint=checkpoint_path if resume_training else None,
                )

                # 3. Export and register
                result = self.export_and_register(best_checkpoint_path=best_checkpoint)

            # Calculate elapsed time
            elapsed_time = (datetime.now() - start_time).total_seconds()

            # Log final status
            logger.info(
                "pipeline_complete",
                elapsed_seconds=elapsed_time,
                success=result.get("success", False),
            )

            # Finalize MLflow run
            self.mlflow_tracker.end_run()

            return {
                "success": True,
                "elapsed_time": elapsed_time,
                "export_result": result,
            }

        except Exception as e:
            logger.error(
                "pipeline_error",
                error=str(e),
                exc_info=True,
            )

            # End MLflow run with failed status
            self.mlflow_tracker.end_run(status="FAILED")

            raise


def parse_arguments() -> argparse.Namespace:
    """
    Parse command-line arguments.

    Returns:
        Parsed arguments
    """
    parser = argparse.ArgumentParser(
        description="Training entry point for Heimdall RF localization pipeline",
    )

    # Training parameters
    parser.add_argument(
        "--epochs",
        type=int,
        default=100,
        help="Number of training epochs (default: 100)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size for training (default: 32)",
    )
    parser.add_argument(
        "--learning_rate",
        "--lr",
        type=float,
        default=1e-3,
        help="Learning rate for optimizer (default: 1e-3)",
    )
    parser.add_argument(
        "--validation_split",
        "--val_split",
        type=float,
        default=0.2,
        help="Fraction of data for validation (default: 0.2)",
    )

    # Data parameters
    parser.add_argument(
        "--data_dir",
        type=str,
        default="/tmp/heimdall_training_data",
        help="Directory containing training data",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="Number of DataLoader workers (default: 4)",
    )

    # Device parameters
    parser.add_argument(
        "--accelerator",
        type=str,
        default="gpu",
        choices=["cpu", "gpu", "auto"],
        help="Training accelerator (default: gpu)",
    )
    parser.add_argument(
        "--devices",
        type=int,
        default=1,
        help="Number of GPUs (default: 1)",
    )

    # Checkpoint parameters
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default="/tmp/heimdall_checkpoints",
        help="Directory for saving checkpoints",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="Path to existing checkpoint (for resume or export)",
    )

    # Mode parameters
    parser.add_argument(
        "--export_only",
        action="store_true",
        help="Skip training, only export and register checkpoint",
    )
    parser.add_argument(
        "--resume_training",
        action="store_true",
        help="Resume training from checkpoint",
    )

    # MLflow parameters
    parser.add_argument(
        "--experiment_name",
        type=str,
        default="heimdall-localization",
        help="MLflow experiment name",
    )
    parser.add_argument(
        "--run_name_prefix",
        type=str,
        default="rf-localization",
        help="Prefix for MLflow run name",
    )

    return parser.parse_args()


def main():
    """
    Main entry point for training pipeline.
    """
    # Parse arguments
    args = parse_arguments()

    logger.info(
        "training_pipeline_started",
        epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        accelerator=args.accelerator,
    )

    try:
        # Create pipeline
        pipeline = TrainingPipeline(
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.learning_rate,
            validation_split=args.validation_split,
            num_workers=args.num_workers,
            accelerator=args.accelerator,
            devices=args.devices,
            checkpoint_dir=Path(args.checkpoint_dir) if args.checkpoint_dir else None,
            experiment_name=args.experiment_name,
            run_name_prefix=args.run_name_prefix,
        )

        # Run pipeline
        result = pipeline.run(
            data_dir=args.data_dir,
            export_only=args.export_only,
            checkpoint_path=Path(args.checkpoint) if args.checkpoint else None,
            resume_training=args.resume_training,
        )

        # Print summary
        print("\n" + "=" * 80)
        print("TRAINING PIPELINE COMPLETE")
        print("=" * 80)
        print(f"Success: {result['success']}")
        print(f"Elapsed Time: {result['elapsed_time']:.2f} seconds")
        print(f"Export Result: {json.dumps(result['export_result'], indent=2)}")
        print("=" * 80 + "\n")

        sys.exit(0)

    except Exception as e:
        logger.error(
            "training_pipeline_failed",
            error=str(e),
            exc_info=True,
        )
        print(f"\nTraining pipeline failed: {str(e)}\n", file=sys.stderr)
        sys.exit(1)
710if __name__ == "__main__":
711 main()