# services/training/src/utils/losses.py

1""" 

2Gaussian Negative Log-Likelihood Loss for uncertainty-aware regression. 

3 

4This custom loss function is designed for the localization task where we want to: 

51. Predict accurate position estimates (latitude, longitude) 

62. Estimate uncertainty for each prediction (sigma_x, sigma_y) 

73. Penalize overconfidence (small sigma with large error) 

8 

9The loss combines position error with uncertainty calibration. 

10 

11Loss formula: 

12L = (y - mu)^2 / (2 * sigma^2) + log(sigma) 

13 

14Where: 

15- y: ground truth position 

16- mu: predicted position 

17- sigma: predicted uncertainty (standard deviation) 

18 

19Interpretation: 

20- First term: MSE weighted by inverse of uncertainty 

21- Second term: Regularization that prevents collapse of sigma to zero 

22 

23This encourages the model to: 

24- Make accurate predictions (minimize position error) 

25- Produce well-calibrated uncertainty (not too overconfident) 

26""" 

import torch
import torch.nn as nn
import torch.nn.functional as F
import structlog
from typing import Tuple, Dict

logger = structlog.get_logger(__name__)


class GaussianNLLLoss(nn.Module):
    """
    Gaussian Negative Log-Likelihood Loss for uncertainty-aware regression.

    Suitable for tasks where the model should output both predictions and
    uncertainty estimates.

    Args:
        reduction (str): 'mean', 'sum', or 'none' for batch reduction
        eps (float): Small value to avoid numerical issues
    """

    def __init__(self, reduction: str = 'mean', eps: float = 1e-6):
        super().__init__()

        if reduction not in ['mean', 'sum', 'none']:
            raise ValueError(f"Invalid reduction: {reduction}")

        self.reduction = reduction
        self.eps = eps

    def forward(
        self,
        predictions: torch.Tensor,
        uncertainties: torch.Tensor,
        targets: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the Gaussian NLL loss.

        Args:
            predictions (torch.Tensor): Predicted positions, shape (batch_size, 2)
            uncertainties (torch.Tensor): Predicted uncertainties (sigmas),
                shape (batch_size, 2). Must be positive.
            targets (torch.Tensor): Ground truth positions, shape (batch_size, 2)

        Returns:
            torch.Tensor: Loss value (scalar if reduction='mean'/'sum',
                else shape (batch_size, 2))

        Example:
            >>> loss_fn = GaussianNLLLoss()
            >>> pred = torch.randn(8, 2)
            >>> sigma = torch.abs(torch.randn(8, 2)) + 0.1  # Ensure positive
            >>> target = torch.randn(8, 2)
            >>> loss = loss_fn(pred, sigma, target)
        """
        # Ensure uncertainties stay strictly positive
        uncertainties = torch.clamp(uncertainties, min=self.eps)

        # Compute residuals
        residuals = targets - predictions  # (batch_size, 2)

        # Gaussian NLL (up to an additive constant):
        # (residual^2) / (2 * sigma^2) + log(sigma)
        # First term: squared error weighted by the inverse variance
        mse_term = (residuals ** 2) / (2 * uncertainties ** 2)

        # Second term: log of the uncertainty (regularization)
        log_term = torch.log(uncertainties)

        # Total loss per element
        loss_per_element = mse_term + log_term

        # Aggregate based on the reduction strategy
        if self.reduction == 'mean':
            return loss_per_element.mean()
        elif self.reduction == 'sum':
            return loss_per_element.sum()
        else:  # 'none'
            return loss_per_element

    def forward_with_stats(
        self,
        predictions: torch.Tensor,
        uncertainties: torch.Tensor,
        targets: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Compute the loss and return detailed statistics.

        Useful for monitoring and debugging during training.

        Returns:
            Tuple[torch.Tensor, Dict]:
                - Loss value
                - Statistics dict with the loss components
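
        Example (usage sketch; `pred`, `sigma`, and `target` as in forward's
        Example above):
            >>> loss_fn = GaussianNLLLoss()
            >>> loss, stats = loss_fn.forward_with_stats(pred, sigma, target)
            >>> stats['mae']  # mean absolute position error, e.g. for dashboards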

        """
        # Compute the loss
        loss = self.forward(predictions, uncertainties, targets)

        # Compute the components separately
        uncertainties_safe = torch.clamp(uncertainties, min=self.eps)
        residuals = targets - predictions
        mse_term = (residuals ** 2) / (2 * uncertainties_safe ** 2)
        log_term = torch.log(uncertainties_safe)

        # Statistics
        stats = {
            'loss': loss.item() if loss.dim() == 0 else loss.mean().item(),
            'mse_term': mse_term.mean().item(),
            'log_term': log_term.mean().item(),
            'mae': torch.abs(residuals).mean().item(),
            'sigma_mean': uncertainties.mean().item(),
            'sigma_min': uncertainties.min().item(),
            'sigma_max': uncertainties.max().item(),
            'residual_mean': residuals.mean().item(),
            'residual_std': residuals.std().item(),
        }

        return loss, stats


class HuberNLLLoss(nn.Module):
    """
    Huber loss variant for uncertainty-aware regression.

    More robust to outliers than pure Gaussian NLL.

    Args:
        delta (float): Huber loss delta parameter (transition point)
        reduction (str): 'mean', 'sum', or 'none'
    """

    def __init__(self, delta: float = 1.0, reduction: str = 'mean'):
        super().__init__()
        self.delta = delta
        self.reduction = reduction
        self.eps = 1e-6

    def forward(
        self,
        predictions: torch.Tensor,
        uncertainties: torch.Tensor,
        targets: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the Huber NLL loss.

        More robust than Gaussian NLL for data with outliers.
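
        Example (usage sketch mirroring GaussianNLLLoss above):
            >>> loss_fn = HuberNLLLoss(delta=1.0)
            >>> pred = torch.randn(8, 2)
            >>> sigma = torch.abs(torch.randn(8, 2)) + 0.1  # Ensure positive
            >>> target = torch.randn(8, 2)
            >>> loss = loss_fn(pred, sigma, target)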

        """
        uncertainties = torch.clamp(uncertainties, min=self.eps)
        residuals = targets - predictions

        # Normalize residuals by their predicted uncertainty
        normalized_residuals = residuals / uncertainties

        # Huber loss on the normalized residuals
        huber_loss = F.huber_loss(
            normalized_residuals,
            torch.zeros_like(normalized_residuals),
            delta=self.delta,
            reduction='none',
        )

        # Add the log term, as in the Gaussian NLL
        log_term = torch.log(uncertainties)
        loss_per_element = huber_loss + log_term

        if self.reduction == 'mean':
            return loss_per_element.mean()
        elif self.reduction == 'sum':
            return loss_per_element.sum()
        else:
            return loss_per_element


class QuantileLoss(nn.Module):
    """
    Quantile (pinball) loss for predicting confidence intervals.

    Can be used alongside Gaussian NLL for more flexible uncertainty modeling.

    Args:
        quantiles (tuple): Quantiles to predict (e.g., (0.1, 0.5, 0.9))
    """

    def __init__(self, quantiles: Tuple[float, ...] = (0.1, 0.5, 0.9)):
        super().__init__()
        self.quantiles = list(quantiles)
        self.eps = 1e-6

    def forward(
        self,
        predictions: torch.Tensor,
        targets: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the quantile loss.

        Args:
            predictions (torch.Tensor): Predicted quantiles, shape (batch, len(quantiles), 2)
            targets (torch.Tensor): Ground truth, shape (batch, 2)

        Returns:
            torch.Tensor: Quantile loss
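
        Example (usage sketch; three quantiles over 2-D positions):
            >>> loss_fn = QuantileLoss(quantiles=(0.1, 0.5, 0.9))
            >>> pred = torch.randn(8, 3, 2)  # (batch, len(quantiles), 2)
            >>> target = torch.randn(8, 2)
            >>> loss = loss_fn(pred, target)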

        """
        total_loss = 0.0
        for i, q in enumerate(self.quantiles):
            residuals = targets - predictions[:, i, :]
            # Pinball loss: max(q * residual, (q - 1) * residual)
            loss = torch.max(
                q * residuals,
                (q - 1) * residuals,
            ).mean()
            total_loss = total_loss + loss

        return total_loss / len(self.quantiles)


def verify_gaussian_nll_loss():
    """Verification function for the Gaussian NLL loss."""
    logger.info("Starting Gaussian NLL loss verification...")

    # Create the loss function
    loss_fn = GaussianNLLLoss(reduction='mean')

    # Create dummy data
    batch_size = 8
    predictions = torch.randn(batch_size, 2)
    uncertainties = torch.abs(torch.randn(batch_size, 2)) + 0.1  # Ensure positive
    targets = torch.randn(batch_size, 2)

    # Compute the loss
    loss = loss_fn(predictions, uncertainties, targets)

    # The NLL is not bounded below by zero (the log term can be negative),
    # so check finiteness rather than positivity.
    assert torch.isfinite(loss), "Loss should be finite"

    logger.info("✅ Basic loss computation passed!")

    # Test with stats
    loss_with_stats, stats = loss_fn.forward_with_stats(predictions, uncertainties, targets)

    logger.info("loss_statistics", stats=stats)

    # Test edge cases
    # Case 1: Perfect prediction with low uncertainty
    perfect_pred = torch.tensor([[0.0, 0.0]])
    perfect_sigma = torch.tensor([[0.1, 0.1]])
    perfect_target = torch.tensor([[0.0, 0.0]])

    loss_perfect = loss_fn(perfect_pred, perfect_sigma, perfect_target)
    logger.info("loss_perfect_prediction", value=loss_perfect.item())

    # Case 2: Bad prediction with high uncertainty (should score better than
    # the same error with low uncertainty)
    bad_pred = torch.tensor([[1.0, 1.0]])
    high_sigma = torch.tensor([[1.0, 1.0]])
    bad_target = torch.tensor([[0.0, 0.0]])

    loss_high_sigma = loss_fn(bad_pred, high_sigma, bad_target)
    logger.info("loss_high_uncertainty", value=loss_high_sigma.item())

    # Case 3: Same error with low uncertainty (should be worse)
    low_sigma = torch.tensor([[0.1, 0.1]])
    loss_low_sigma = loss_fn(bad_pred, low_sigma, bad_target)
    logger.info("loss_low_uncertainty", value=loss_low_sigma.item())

    assert loss_low_sigma > loss_high_sigma, \
        "Low uncertainty should yield a higher loss for the same error"

    logger.info("✅ Gaussian NLL loss verification complete!")

    return loss_fn


if __name__ == "__main__":
    import logging
    logging.basicConfig(level=logging.INFO)
    loss_fn = verify_gaussian_nll_loss()
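# A sketch of wiring the loss into a training step. `model`, `optimizer`, and
# `batch` are hypothetical names, not defined in this module:
#
#     out = model(batch.features)                    # (batch_size, 4)
#     mu, log_sigma = out[:, :2], out[:, 2:]
#     sigma = torch.exp(log_sigma)                   # keeps sigma positive
#     loss = GaussianNLLLoss()(mu, sigma, batch.positions)
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()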