Coverage for services/training/src/models/lightning_module.py: 0%

121 statements  

coverage.py v7.11.0, created at 2025-10-25 16:18 +0000

1""" 

2LocalizationLitModule: PyTorch Lightning training module. 

3 

4Encapsulates the entire training loop: 

5- forward pass (LocalizationNet) 

6- loss computation (GaussianNLLLoss) 

7- optimization (Adam) 

8- learning rate scheduling 

9- validation and logging 

10- model checkpointing via MLflow 

11 

12This module orchestrates training with Lightning, handling: 

13- Distributed training (multi-GPU ready) 

14- Gradient accumulation 

15- Mixed precision training (optional) 

16- Experiment tracking via MLflow 

17""" 

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR
import pytorch_lightning as pl

# MLFlowLogger location changed in pytorch-lightning 2.0+
try:
    from pytorch_lightning.loggers import MLFlowLogger
except ImportError:
    try:
        from pytorch_lightning.loggers.mlflow import MLFlowLogger
    except ImportError:
        # Fallback: MLFlowLogger not available, use None
        MLFlowLogger = None

import structlog
from typing import Dict, Tuple, Optional
import numpy as np
from pathlib import Path

from src.models.localization_net import LocalizationNet
from src.utils.losses import GaussianNLLLoss

logger = structlog.get_logger(__name__)


class LocalizationLitModule(pl.LightningModule):
    """
    PyTorch Lightning module for RF source localization training.

    Architecture:
    - Backbone: LocalizationNet (ConvNeXt, size set by `backbone_size`)
    - Loss: Gaussian Negative Log-Likelihood
    - Optimizer: AdamW
    - LR Scheduler: CosineAnnealingLR

    Metrics tracked:
    - Training: MSE, NLL, MAE
    - Validation: MSE, NLL, MAE, calibration metrics
    - Test: Final evaluation
    """

    def __init__(
        self,
        learning_rate: float = 1e-3,
        weight_decay: float = 1e-5,
        warmup_steps: int = 0,
        num_training_steps: int = 10000,
        pretrained_backbone: bool = True,
        freeze_backbone: bool = False,
        uncertainty_bounds: Tuple[float, float] = (0.01, 1.0),
        backbone_size: str = 'large',
    ):
        """
        Initialize LocalizationLitModule.

        Args:
            learning_rate (float): Initial learning rate
            weight_decay (float): L2 regularization weight
            warmup_steps (int): Number of warmup steps
            num_training_steps (int): Total training steps for cosine scheduling
            pretrained_backbone (bool): Use ImageNet pretrained ConvNeXt-Large
            freeze_backbone (bool): Freeze backbone during training
            uncertainty_bounds (Tuple[float, float]): (min_sigma, max_sigma)
            backbone_size (str): ConvNeXt size ('tiny', 'small', 'medium', 'large')
        """
        super().__init__()

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.warmup_steps = warmup_steps
        self.num_training_steps = num_training_steps
        self.backbone_size = backbone_size

        # Model - Now using ConvNeXt-Large (200M params, 88.6% ImageNet accuracy)
        # vs previous ResNet-18 (11M params, 69.8% accuracy)
        self.model = LocalizationNet(
            pretrained=pretrained_backbone,
            freeze_backbone=freeze_backbone,
            uncertainty_min=uncertainty_bounds[0],
            uncertainty_max=uncertainty_bounds[1],
            backbone_size=backbone_size,
        )

        # Loss
        self.loss_fn = GaussianNLLLoss(reduction='mean')
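        # Assumed form of GaussianNLLLoss (see src/utils/losses): per output
        # coordinate, nll = 0.5 * (log(sigma^2) + (mu - y)^2 / sigma^2) up to an
        # additive constant, averaged over the batch with reduction='mean'.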

        # Metrics storage
        self.train_losses = []
        self.val_losses = []
        self.val_maes = []

        # Hyperparameter saving
        self.save_hyperparameters()

        logger.info(
            "lightning_module_initialized",
            backbone="ConvNeXt-Large",
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            num_training_steps=num_training_steps,
            pretrained_backbone=pretrained_backbone,
            backbone_size=backbone_size,
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass through the network."""
        return self.model(x)

    def training_step(self, batch, batch_idx: int) -> torch.Tensor:
        """
        Training step called for each batch.

        Args:
            batch: Tuple of (features, labels, metadata)
            batch_idx: Batch index

        Returns:
            Loss tensor for backward pass
        """
        features, labels, metadata = batch

        # Forward pass
        positions, uncertainties = self(features)

        # Compute loss
        loss, stats = self.loss_fn.forward_with_stats(
            positions, uncertainties, labels
        )

        # Log metrics
        self.log('train/loss', loss, prog_bar=True, on_step=True, on_epoch=True)
        self.log('train/mae', stats['mae'], prog_bar=True, on_step=False, on_epoch=True)
        self.log('train/sigma_mean', stats['sigma_mean'], on_step=False, on_epoch=True)
        self.log('train/mse_term', stats['mse_term'], on_step=False, on_epoch=True)
        self.log('train/log_term', stats['log_term'], on_step=False, on_epoch=True)
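        # Logging semantics: on_step=True records the value every batch, while
        # on_epoch=True accumulates it and logs the epoch-level mean as well.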

        # Store for epoch-level statistics
        self.train_losses.append(loss.item())

        if batch_idx % 100 == 0:
            logger.debug(
                "training_step",
                batch_idx=batch_idx,
                loss=loss.item(),
                mae=stats['mae'],
                sigma_mean=stats['sigma_mean'],
            )

        return loss

    def validation_step(self, batch, batch_idx: int):
        """Validation step called for each validation batch."""

        features, labels, metadata = batch

        # Forward pass
        with torch.no_grad():
            positions, uncertainties = self(features)

        # Compute loss
        loss, stats = self.loss_fn.forward_with_stats(
            positions, uncertainties, labels
        )

        # Compute additional metrics
        mae = torch.abs(positions - labels).mean()
        position_error = torch.norm(positions - labels, dim=1)  # Euclidean distance

        # Log metrics
        self.log('val/loss', loss, prog_bar=True, on_epoch=True)
        self.log('val/mae', mae, on_epoch=True)
        self.log('val/position_error_mean', position_error.mean(), on_epoch=True)
        self.log('val/position_error_std', position_error.std(), on_epoch=True)
        self.log('val/sigma_mean', stats['sigma_mean'], on_epoch=True)

        # Store for averaging
        self.val_losses.append(loss.item())
        self.val_maes.append(mae.item())

        return {
            'loss': loss,
            'mae': mae,
            'position_error': position_error,
        }

    def test_step(self, batch, batch_idx: int):
        """Test step (evaluation on test set)."""

        features, labels, metadata = batch

        with torch.no_grad():
            positions, uncertainties = self(features)

        loss, stats = self.loss_fn.forward_with_stats(
            positions, uncertainties, labels
        )

        mae = torch.abs(positions - labels).mean()
        position_error = torch.norm(positions - labels, dim=1)

        self.log('test/loss', loss)
        self.log('test/mae', mae)
        self.log('test/position_error_mean', position_error.mean())
        self.log('test/position_error_std', position_error.std())

        return {
            'loss': loss,
            'mae': mae,
            'position_error': position_error,
            'positions': positions,
            'labels': labels,
            'uncertainties': uncertainties,
        }

235 def configure_optimizers(self): 

236 """ 

237 Configure optimizer and learning rate scheduler. 

238  

239 Uses: 

240 - Optimizer: Adam with weight decay (AdamW equivalent) 

241 - Scheduler: Cosine annealing with warmup 

242 """ 

243 

244 # Optimizer 

245 optimizer = optim.AdamW( 

246 self.parameters(), 

247 lr=self.learning_rate, 

248 weight_decay=self.weight_decay, 

249 ) 

250 

251 # Learning rate scheduler 

252 scheduler = CosineAnnealingLR( 

253 optimizer, 

254 T_max=max(self.num_training_steps - self.warmup_steps, 1), 

255 eta_min=1e-6, 

256 ) 
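        # CosineAnnealingLR decays the LR as
        #   lr(t) = eta_min + 0.5 * (lr_0 - eta_min) * (1 + cos(pi * t / T_max)),
        # falling from learning_rate to eta_min over T_max steps. Note that
        # warmup_steps only shortens T_max here; no warmup phase is scheduled.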

        scheduler_config = {
            'scheduler': scheduler,
            'interval': 'step',
            'frequency': 1,
        }
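        # interval='step' makes Lightning call scheduler.step() after every
        # optimizer step (every `frequency` steps) instead of once per epoch.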

        logger.info(
            "optimizer_configured",
            optimizer="AdamW",
            learning_rate=self.learning_rate,
            weight_decay=self.weight_decay,
            scheduler="CosineAnnealing",
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler_config,
        }


    def on_train_epoch_end(self):
        """Called at the end of each training epoch."""

        if len(self.train_losses) > 0:
            avg_train_loss = np.mean(self.train_losses[-100:])
            logger.debug("epoch_end", avg_train_loss=avg_train_loss)

        # Clear metrics
        self.train_losses = []
        self.val_losses = []
        self.val_maes = []

    def get_model_for_export(self) -> LocalizationNet:
        """
        Get the underlying model for export.

        Used for ONNX export and inference deployment.
        """
        return self.model
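
# Example (sketch, not executed): exporting the trained backbone to ONNX for
# inference deployment. The input shape (1, 3, 128, 32) is assumed from the
# verification batch below; adjust it to the real spectrogram dimensions.
#
#   net = lit_module.get_model_for_export().eval()   # lit_module: a trained instance
#   dummy = torch.randn(1, 3, 128, 32)
#   torch.onnx.export(
#       net, dummy, "localization_net.onnx",
#       input_names=["features"],
#       output_names=["positions", "uncertainties"],
#   )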

# Verification function
def verify_lightning_module():
    """Verification function for Lightning module."""

    logger.info("Starting Lightning module verification...")

    # Create module
    module = LocalizationLitModule(
        learning_rate=1e-3,
        num_training_steps=1000,
        pretrained_backbone=False,
    )

    # Create dummy batch
    features = torch.randn(8, 3, 128, 32)
    labels = torch.randn(8, 2)
    metadata = {'session_id': ['s1'] * 8}
    batch = (features, labels, metadata)

    # Test training step
    loss = module.training_step(batch, 0)
    assert not torch.isnan(loss), "Loss should not be NaN"
    assert loss.item() > 0, "Loss should be positive"

    logger.info("✅ Training step passed!")

    # Test validation step
    val_outputs = module.validation_step(batch, 0)
    assert 'loss' in val_outputs
    assert 'mae' in val_outputs

    logger.info("✅ Validation step passed!")

    # Test forward pass
    positions, uncertainties = module(features)
    assert positions.shape == (8, 2)
    assert uncertainties.shape == (8, 2)
    assert (uncertainties > 0).all()

    logger.info("✅ Forward pass passed!")

    # Test optimizer configuration
    optimizer_config = module.configure_optimizers()
    assert 'optimizer' in optimizer_config
    assert 'lr_scheduler' in optimizer_config

    logger.info("✅ Optimizer configuration passed!")

    logger.info("✅ Lightning module verification complete!")

    return module


if __name__ == "__main__":
    import logging
    logging.basicConfig(level=logging.INFO)
    module = verify_lightning_module()
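
# Example (sketch): wiring this module into a Lightning Trainer with MLflow
# tracking; `RFDataModule` and the tracking URI are placeholders, not part of
# this file's verified API.
#
#   mlf_logger = (MLFlowLogger(experiment_name="localization",
#                              tracking_uri="http://mlflow:5000")
#                 if MLFlowLogger else None)
#   trainer = pl.Trainer(
#       max_steps=10_000,
#       accelerator="auto",
#       precision="16-mixed",        # optional mixed precision
#       accumulate_grad_batches=4,   # gradient accumulation
#       logger=mlf_logger,
#   )
#   trainer.fit(module, datamodule=RFDataModule(...))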