Coverage for services/training/src/models/lightning_module.py: 0%
121 statements
coverage.py v7.11.0, created at 2025-10-25 16:18 +0000
1"""
2LocalizationLitModule: PyTorch Lightning training module.
4Encapsulates the entire training loop:
5- forward pass (LocalizationNet)
6- loss computation (GaussianNLLLoss)
7- optimization (Adam)
8- learning rate scheduling
9- validation and logging
10- model checkpointing via MLflow
12This module orchestrates training with Lightning, handling:
13- Distributed training (multi-GPU ready)
14- Gradient accumulation
15- Mixed precision training (optional)
16- Experiment tracking via MLflow
17"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR
import pytorch_lightning as pl

# MLFlowLogger import location changed in pytorch-lightning 2.0+
try:
    from pytorch_lightning.loggers import MLFlowLogger
except ImportError:
    try:
        from pytorch_lightning.loggers.mlflow import MLFlowLogger
    except ImportError:
        # Fallback: MLFlowLogger not available, use None
        MLFlowLogger = None

import structlog
from typing import Dict, Tuple, Optional
import numpy as np
from pathlib import Path

from src.models.localization_net import LocalizationNet
from src.utils.losses import GaussianNLLLoss

logger = structlog.get_logger(__name__)
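
# How this module is typically driven (a minimal sketch; the dataloaders,
# max_steps value, and MLflow tracking URI are placeholders, not part of this
# file):
#
#   import pytorch_lightning as pl
#   module = LocalizationLitModule(learning_rate=1e-3, num_training_steps=10_000)
#   mlf_logger = MLFlowLogger(experiment_name="localization", tracking_uri="http://mlflow:5000")
#   trainer = pl.Trainer(max_steps=10_000, accelerator="auto", logger=mlf_logger)
#   trainer.fit(module, train_dataloaders=train_loader, val_dataloaders=val_loader)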
class LocalizationLitModule(pl.LightningModule):
    """
    PyTorch Lightning module for RF source localization training.

    Architecture:
    - Backbone: LocalizationNet (ConvNeXt; size set by backbone_size, default 'large')
    - Loss: Gaussian Negative Log-Likelihood
    - Optimizer: AdamW
    - LR Scheduler: Cosine annealing (stepped per batch)

    Metrics tracked:
    - Training: MSE, NLL, MAE
    - Validation: MSE, NLL, MAE, calibration metrics
    - Test: Final evaluation
    """
    def __init__(
        self,
        learning_rate: float = 1e-3,
        weight_decay: float = 1e-5,
        warmup_steps: int = 0,
        num_training_steps: int = 10000,
        pretrained_backbone: bool = True,
        freeze_backbone: bool = False,
        uncertainty_bounds: Tuple[float, float] = (0.01, 1.0),
        backbone_size: str = 'large',
    ):
        """
        Initialize LocalizationLitModule.

        Args:
            learning_rate (float): Initial learning rate
            weight_decay (float): L2 regularization weight
            warmup_steps (int): Number of warmup steps
            num_training_steps (int): Total training steps for cosine scheduling
            pretrained_backbone (bool): Use ImageNet pretrained ConvNeXt-Large
            freeze_backbone (bool): Freeze backbone during training
            uncertainty_bounds (Tuple[float, float]): (min_sigma, max_sigma)
            backbone_size (str): ConvNeXt size ('tiny', 'small', 'medium', 'large')
        """
        super().__init__()

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.warmup_steps = warmup_steps
        self.num_training_steps = num_training_steps
        self.backbone_size = backbone_size

        # Model - Now using ConvNeXt-Large (200M params, 88.6% ImageNet accuracy)
        # vs previous ResNet-18 (11M params, 69.8% accuracy)
        self.model = LocalizationNet(
            pretrained=pretrained_backbone,
            freeze_backbone=freeze_backbone,
            uncertainty_min=uncertainty_bounds[0],
            uncertainty_max=uncertainty_bounds[1],
            backbone_size=backbone_size,
        )

        # Loss
        self.loss_fn = GaussianNLLLoss(reduction='mean')
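
        # GaussianNLLLoss above is the heteroscedastic Gaussian negative
        # log-likelihood; per coordinate it has the standard form
        # (up to an additive constant):
        #   nll = 0.5 * ((y - mu)**2 / sigma**2 + log(sigma**2))
        # The 'mse_term' and 'log_term' stats logged in training_step presumably
        # correspond to these two summands; the exact definition lives in
        # src.utils.losses.GaussianNLLLoss.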
        # Metrics storage
        self.train_losses = []
        self.val_losses = []
        self.val_maes = []

        # Hyperparameter saving
        self.save_hyperparameters()

        logger.info(
            "lightning_module_initialized",
            backbone="ConvNeXt-Large",
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            num_training_steps=num_training_steps,
            pretrained_backbone=pretrained_backbone,
            backbone_size=backbone_size,
        )
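
    # Because save_hyperparameters() records the constructor arguments above,
    # checkpoints written during training can be restored with Lightning's
    # standard API. A minimal sketch (the checkpoint path is a placeholder):
    #
    #   module = LocalizationLitModule.load_from_checkpoint("checkpoints/last.ckpt")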
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass through the network."""
        return self.model(x)

    def training_step(self, batch, batch_idx: int) -> torch.Tensor:
        """
        Training step called for each batch.

        Args:
            batch: Tuple of (features, labels, metadata)
            batch_idx: Batch index

        Returns:
            Loss tensor for backward pass
        """
        features, labels, metadata = batch

        # Forward pass
        positions, uncertainties = self(features)

        # Compute loss
        loss, stats = self.loss_fn.forward_with_stats(
            positions, uncertainties, labels
        )

        # Log metrics
        self.log('train/loss', loss, prog_bar=True, on_step=True, on_epoch=True)
        self.log('train/mae', stats['mae'], prog_bar=True, on_step=False, on_epoch=True)
        self.log('train/sigma_mean', stats['sigma_mean'], on_step=False, on_epoch=True)
        self.log('train/mse_term', stats['mse_term'], on_step=False, on_epoch=True)
        self.log('train/log_term', stats['log_term'], on_step=False, on_epoch=True)

        # Store for epoch-level statistics
        self.train_losses.append(loss.item())

        if batch_idx % 100 == 0:
            logger.debug(
                "training_step",
                batch_idx=batch_idx,
                loss=loss.item(),
                mae=stats['mae'],
                sigma_mean=stats['sigma_mean'],
            )

        return loss
    def validation_step(self, batch, batch_idx: int):
        """Validation step called for each validation batch."""

        features, labels, metadata = batch

        # Forward pass
        with torch.no_grad():
            positions, uncertainties = self(features)

        # Compute loss
        loss, stats = self.loss_fn.forward_with_stats(
            positions, uncertainties, labels
        )

        # Compute additional metrics
        mae = torch.abs(positions - labels).mean()
        position_error = torch.norm(positions - labels, dim=1)  # Euclidean distance

        # Log metrics
        self.log('val/loss', loss, prog_bar=True, on_epoch=True)
        self.log('val/mae', mae, on_epoch=True)
        self.log('val/position_error_mean', position_error.mean(), on_epoch=True)
        self.log('val/position_error_std', position_error.std(), on_epoch=True)
        self.log('val/sigma_mean', stats['sigma_mean'], on_epoch=True)

        # Store for averaging
        self.val_losses.append(loss.item())
        self.val_maes.append(mae.item())

        return {
            'loss': loss,
            'mae': mae,
            'position_error': position_error,
        }
    def test_step(self, batch, batch_idx: int):
        """Test step (evaluation on test set)."""

        features, labels, metadata = batch

        with torch.no_grad():
            positions, uncertainties = self(features)

        loss, stats = self.loss_fn.forward_with_stats(
            positions, uncertainties, labels
        )

        mae = torch.abs(positions - labels).mean()
        position_error = torch.norm(positions - labels, dim=1)

        self.log('test/loss', loss)
        self.log('test/mae', mae)
        self.log('test/position_error_mean', position_error.mean())
        self.log('test/position_error_std', position_error.std())

        return {
            'loss': loss,
            'mae': mae,
            'position_error': position_error,
            'positions': positions,
            'labels': labels,
            'uncertainties': uncertainties,
        }
    def configure_optimizers(self):
        """
        Configure optimizer and learning rate scheduler.

        Uses:
        - Optimizer: AdamW (Adam with decoupled weight decay)
        - Scheduler: CosineAnnealingLR stepped per batch; warmup_steps only
          shortens T_max, no separate warmup phase is applied here
        """

        # Optimizer
        optimizer = optim.AdamW(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )

        # Learning rate scheduler
        scheduler = CosineAnnealingLR(
            optimizer,
            T_max=max(self.num_training_steps - self.warmup_steps, 1),
            eta_min=1e-6,
        )

        scheduler_config = {
            'scheduler': scheduler,
            'interval': 'step',
            'frequency': 1,
        }

        logger.info(
            "optimizer_configured",
            optimizer="AdamW",
            learning_rate=self.learning_rate,
            weight_decay=self.weight_decay,
            scheduler="CosineAnnealing",
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler_config,
        }
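
    # The sketch below shows one way warmup_steps could be honoured with a real
    # warmup phase, using the standard torch.optim.lr_scheduler classes. This is
    # an assumption about intent, not what configure_optimizers currently does
    # (today warmup_steps only shortens T_max above):
    #
    #   from torch.optim.lr_scheduler import LinearLR, SequentialLR
    #   warmup = LinearLR(optimizer, start_factor=0.01, total_iters=self.warmup_steps)
    #   cosine = CosineAnnealingLR(optimizer, T_max=self.num_training_steps - self.warmup_steps, eta_min=1e-6)
    #   scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[self.warmup_steps])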
    def on_train_epoch_end(self):
        """Called at the end of each training epoch."""
        # Note: the old `on_epoch_end` hook was removed from LightningModule in
        # recent pytorch-lightning releases; `on_train_epoch_end` is the hook
        # that is actually invoked here.

        if len(self.train_losses) > 0:
            avg_train_loss = np.mean(self.train_losses[-100:])
            logger.debug("epoch_end", avg_train_loss=avg_train_loss)

        # Clear metrics
        self.train_losses = []
        self.val_losses = []
        self.val_maes = []
    def get_model_for_export(self) -> LocalizationNet:
        """
        Get the underlying model for export.

        Used for ONNX export and inference deployment.
        """
        return self.model
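
# A minimal export sketch for the deployment path mentioned above. The output
# filename and tensor names are assumptions; the input shape (1, 3, 128, 32)
# mirrors the dummy batch used in verify_lightning_module below:
#
#   net = module.get_model_for_export().eval()
#   dummy = torch.randn(1, 3, 128, 32)
#   torch.onnx.export(
#       net, dummy, "localization_net.onnx",
#       input_names=["features"],
#       output_names=["position", "uncertainty"],
#       dynamic_axes={"features": {0: "batch"}},
#   )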
# Verification function
def verify_lightning_module():
    """Verification function for Lightning module."""

    logger.info("Starting Lightning module verification...")

    # Create module
    module = LocalizationLitModule(
        learning_rate=1e-3,
        num_training_steps=1000,
        pretrained_backbone=False,
    )

    # Create dummy batch
    features = torch.randn(8, 3, 128, 32)
    labels = torch.randn(8, 2)
    metadata = {'session_id': ['s1'] * 8}
    batch = (features, labels, metadata)

    # Test training step
    loss = module.training_step(batch, 0)
    assert not torch.isnan(loss), "Loss should not be NaN"
    assert loss.item() > 0, "Loss should be positive"

    logger.info("✅ Training step passed!")

    # Test validation step
    val_outputs = module.validation_step(batch, 0)
    assert 'loss' in val_outputs
    assert 'mae' in val_outputs

    logger.info("✅ Validation step passed!")

    # Test forward pass
    positions, uncertainties = module(features)
    assert positions.shape == (8, 2)
    assert uncertainties.shape == (8, 2)
    assert (uncertainties > 0).all()

    logger.info("✅ Forward pass passed!")

    # Test optimizer configuration
    optimizer_config = module.configure_optimizers()
    assert 'optimizer' in optimizer_config
    assert 'lr_scheduler' in optimizer_config

    logger.info("✅ Optimizer configuration passed!")

    logger.info("✅ Lightning module verification complete!")

    return module


if __name__ == "__main__":
    import logging
    logging.basicConfig(level=logging.INFO)
    module = verify_lightning_module()