Coverage for services/training/src/utils/losses.py: 0%
96 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-25 16:18 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-25 16:18 +0000
"""
Gaussian Negative Log-Likelihood Loss for uncertainty-aware regression.

This custom loss function is designed for the localization task where we want to:
1. Predict accurate position estimates (latitude, longitude)
2. Estimate uncertainty for each prediction (sigma_x, sigma_y)
3. Penalize overconfidence (small sigma with large error)

The loss combines position error with uncertainty calibration.

Loss formula:
L = (y - mu)^2 / (2 * sigma^2) + log(sigma)

Where:
- y: ground truth position
- mu: predicted position
- sigma: predicted uncertainty (standard deviation)

Interpretation:
- First term: MSE weighted by inverse of uncertainty
- Second term: Regularization that prevents collapse of sigma to zero

Note: the constant 0.5 * log(2 * pi) of the full Gaussian NLL is omitted.
It does not affect gradients, but it means the loss is not bounded below by
zero. For example, an exact prediction with sigma = 0.1 gives log(0.1) < 0.

This encourages the model to:
- Make accurate predictions (minimize position error)
- Produce well-calibrated uncertainty (not too overconfident)

The module also provides HuberNLLLoss (Huber penalty on sigma-normalized
residuals, more robust to outliers) and QuantileLoss (pinball loss for
predicting quantiles / confidence intervals).
"""
from typing import Dict, Sequence, Tuple

import numpy as np
import structlog
import torch
import torch.nn as nn
import torch.nn.functional as F
# Module-level structlog logger; used by verify_gaussian_nll_loss() below.
logger = structlog.get_logger(__name__)
class GaussianNLLLoss(nn.Module):
    """
    Gaussian Negative Log-Likelihood Loss for uncertainty-aware regression.

    Suitable for tasks where the model should output both predictions and
    uncertainty estimates. Per element the loss is

        (y - mu)^2 / (2 * sigma^2) + log(sigma)

    i.e. the Gaussian NLL without the constant 0.5 * log(2 * pi) term, so
    values can be negative (e.g. a near-exact prediction with sigma < 1).

    Args:
        reduction (str): 'mean', 'sum' or 'none' for batch reduction
        eps (float): Lower bound applied to sigma before the division and
            the log, to avoid division by zero and log(0)

    Raises:
        ValueError: If ``reduction`` is not 'mean', 'sum' or 'none'.
    """

    def __init__(self, reduction: str = 'mean', eps: float = 1e-6):
        super().__init__()

        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(f"Invalid reduction: {reduction}")

        self.reduction = reduction
        self.eps = eps

    def _loss_terms(
        self,
        predictions: torch.Tensor,
        uncertainties: torch.Tensor,
        targets: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the element-wise ``(residuals, mse_term, log_term)``."""
        # Clamp so a sigma of ~0 cannot produce inf/NaN. torch.clamp passes no
        # gradient to sigma for values below eps.
        sigma = torch.clamp(uncertainties, min=self.eps)
        residuals = targets - predictions
        # Squared error weighted by inverse variance.
        mse_term = (residuals ** 2) / (2 * sigma ** 2)
        # log(sigma) keeps the model from inflating sigma to hide errors.
        log_term = torch.log(sigma)
        return residuals, mse_term, log_term

    def _reduce(self, loss_per_element: torch.Tensor) -> torch.Tensor:
        """Apply the configured reduction to the element-wise loss."""
        if self.reduction == 'mean':
            return loss_per_element.mean()
        if self.reduction == 'sum':
            return loss_per_element.sum()
        return loss_per_element  # 'none'

    def forward(
        self,
        predictions: torch.Tensor,
        uncertainties: torch.Tensor,
        targets: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute Gaussian NLL loss.

        Args:
            predictions (torch.Tensor): Predicted positions, shape (batch_size, 2)
            uncertainties (torch.Tensor): Predicted uncertainties (sigmas), shape
                (batch_size, 2). Should be positive; values below ``eps`` are
                clamped to ``eps``.
            targets (torch.Tensor): Ground truth positions, shape (batch_size, 2)

        Returns:
            torch.Tensor: Loss value (scalar if reduction='mean'/'sum', else the
            element-wise loss with the broadcast input shape)

        Example:
            >>> loss_fn = GaussianNLLLoss()
            >>> pred = torch.randn(8, 2)
            >>> sigma = torch.abs(torch.randn(8, 2)) + 0.1  # Ensure positive
            >>> target = torch.randn(8, 2)
            >>> loss = loss_fn(pred, sigma, target)
        """
        _, mse_term, log_term = self._loss_terms(predictions, uncertainties, targets)
        return self._reduce(mse_term + log_term)

    def forward_with_stats(
        self,
        predictions: torch.Tensor,
        uncertainties: torch.Tensor,
        targets: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Compute loss and return detailed statistics.

        Useful for monitoring and debugging during training. The loss keeps
        its autograd graph; the statistics are plain Python floats.

        Returns:
            Tuple[torch.Tensor, Dict]:
                - Loss value (same as ``forward``)
                - Statistics dict with the loss components. The ``sigma_*``
                  entries describe the raw, unclamped uncertainties.
        """
        # Compute the terms once and reuse them for both the loss and the stats.
        residuals, mse_term, log_term = self._loss_terms(predictions, uncertainties, targets)
        loss = self._reduce(mse_term + log_term)

        # Stats are for monitoring only; keep them out of the autograd graph.
        with torch.no_grad():
            stats = {
                'loss': loss.item() if loss.dim() == 0 else loss.mean().item(),
                'mse_term': mse_term.mean().item(),
                'log_term': log_term.mean().item(),
                'mae': torch.abs(residuals).mean().item(),
                'sigma_mean': uncertainties.mean().item(),
                'sigma_min': uncertainties.min().item(),
                'sigma_max': uncertainties.max().item(),
                'residual_mean': residuals.mean().item(),
                # Unbiased std is undefined (NaN) for a single element.
                'residual_std': residuals.std().item() if residuals.numel() > 1 else 0.0,
            }

        return loss, stats
class HuberNLLLoss(nn.Module):
    """
    Huber loss variant for uncertainty-aware regression.

    Applies the Huber loss to the sigma-normalized residual
    r = (y - mu) / sigma and adds log(sigma). For |r| <= delta the Huber term
    is 0.5 * r^2, the same as the GaussianNLLLoss term; beyond delta it grows
    linearly, so it is more robust to outliers than pure Gaussian NLL.

    Args:
        delta (float): Huber loss delta parameter (transition point)
        reduction (str): 'mean', 'sum', or 'none'
        eps (float): Lower bound applied to sigma before the division and
            the log

    Raises:
        ValueError: If ``reduction`` is not 'mean', 'sum' or 'none'.
    """

    def __init__(self, delta: float = 1.0, reduction: str = 'mean', eps: float = 1e-6):
        super().__init__()

        # Validate like GaussianNLLLoss; an unknown value previously fell
        # through silently to the 'none' branch.
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(f"Invalid reduction: {reduction}")

        self.delta = delta
        self.reduction = reduction
        self.eps = eps

    def forward(
        self,
        predictions: torch.Tensor,
        uncertainties: torch.Tensor,
        targets: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute Huber NLL loss.

        More robust than Gaussian NLL for data with outliers.

        Args:
            predictions (torch.Tensor): Predicted values
            uncertainties (torch.Tensor): Predicted sigmas; values below
                ``eps`` are clamped to ``eps``
            targets (torch.Tensor): Ground truth values

        Returns:
            torch.Tensor: Scalar for 'mean'/'sum', element-wise loss for 'none'.
        """
        uncertainties = torch.clamp(uncertainties, min=self.eps)
        residuals = targets - predictions

        # Normalize by uncertainty so delta is measured in units of sigma.
        normalized_residuals = residuals / uncertainties

        # Huber loss on the normalized residuals (distance from zero).
        huber_loss = F.huber_loss(
            normalized_residuals,
            torch.zeros_like(normalized_residuals),
            delta=self.delta,
            reduction='none',
        )

        # log(sigma) keeps sigma from growing without bound.
        loss_per_element = huber_loss + torch.log(uncertainties)

        if self.reduction == 'mean':
            return loss_per_element.mean()
        if self.reduction == 'sum':
            return loss_per_element.sum()
        return loss_per_element  # 'none'
class QuantileLoss(nn.Module):
    """
    Quantile (pinball) loss for predicting confidence intervals.

    Can be used alongside Gaussian NLL for more flexible uncertainty modeling.
    For quantile q and residual r = y - y_hat the per-element loss is
    max(q * r, (q - 1) * r); the result is averaged over the batch, the
    quantiles and the target dimensions.

    Args:
        quantiles (Sequence[float]): Quantiles to predict (e.g., [0.1, 0.5, 0.9]),
            each in [0, 1]. The values are copied into a new list, so later
            changes to the caller's sequence do not affect the module.

    Raises:
        ValueError: If ``quantiles`` is empty or contains a value outside [0, 1].
    """

    def __init__(self, quantiles: Sequence[float] = (0.1, 0.5, 0.9)):
        super().__init__()

        # Copy: the previous list default was shared (and mutable) across
        # every instance constructed without arguments.
        quantiles = list(quantiles)
        if not quantiles:
            raise ValueError("quantiles must not be empty")
        out_of_range = [q for q in quantiles if not 0.0 <= q <= 1.0]
        if out_of_range:
            raise ValueError(f"quantiles must lie in [0, 1], got {out_of_range}")

        self.quantiles = quantiles
        self.eps = 1e-6  # retained for backward compatibility; unused by forward

    def forward(
        self,
        predictions: torch.Tensor,
        targets: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute quantile loss.

        Args:
            predictions (torch.Tensor): Predicted quantiles, shape
                (batch, len(quantiles), *target_dims), e.g. (batch, Q, 2)
            targets (torch.Tensor): Ground truth, shape (batch, *target_dims),
                e.g. (batch, 2)

        Returns:
            torch.Tensor: Scalar quantile loss (mean over all elements and
            quantiles).

        Raises:
            ValueError: If dimension 1 of ``predictions`` does not match the
                number of quantiles.
        """
        n_quantiles = len(self.quantiles)
        if predictions.dim() < 2 or predictions.shape[1] != n_quantiles:
            raise ValueError(
                f"predictions must have shape (batch, {n_quantiles}, ...), "
                f"got {tuple(predictions.shape)}"
            )

        # Shape q as (1, Q, 1, ..., 1) so it broadcasts against predictions.
        q = torch.as_tensor(self.quantiles, dtype=predictions.dtype, device=predictions.device)
        q = q.view(1, n_quantiles, *([1] * (predictions.dim() - 2)))

        # Insert the quantile axis into targets: (batch, 1, ...).
        residuals = targets.unsqueeze(1) - predictions

        # Pinball loss: max(q * r, (q - 1) * r). Every quantile contributes the
        # same number of elements, so the global mean equals the average of the
        # per-quantile means.
        return torch.max(q * residuals, (q - 1) * residuals).mean()
def verify_gaussian_nll_loss(seed: int = 0):
    """
    Sanity-check GaussianNLLLoss on random and hand-picked inputs.

    The loss omits the constant 0.5 * log(2 * pi) term, and log(sigma) < 0
    for sigma < 1, so the loss is NOT guaranteed to be positive. The checks
    therefore test finiteness and relative ordering rather than sign.

    Args:
        seed (int): Seed for a private torch.Generator, so the random batch is
            reproducible without touching the global RNG state.

    Returns:
        GaussianNLLLoss: The verified loss instance (reduction='mean').

    Raises:
        AssertionError: If any check fails.
    """
    logger.info("Starting Gaussian NLL loss verification...")

    # Create loss function
    loss_fn = GaussianNLLLoss(reduction='mean')

    # Create dummy data from a seeded, private generator.
    gen = torch.Generator().manual_seed(seed)
    batch_size = 8
    predictions = torch.randn(batch_size, 2, generator=gen)
    uncertainties = torch.abs(torch.randn(batch_size, 2, generator=gen)) + 0.1  # Ensure positive
    targets = torch.randn(batch_size, 2, generator=gen)

    # Compute loss
    loss = loss_fn(predictions, uncertainties, targets)

    # A negative value is legitimate (see docstring); require a finite scalar.
    assert loss.dim() == 0, "Loss should be a scalar with reduction='mean'"
    assert torch.isfinite(loss), "Loss should be finite"

    logger.info("✅ Basic loss computation passed!")

    # Test with stats: the returned loss must agree with forward().
    loss_with_stats, stats = loss_fn.forward_with_stats(predictions, uncertainties, targets)
    assert torch.allclose(loss_with_stats, loss), "forward_with_stats should match forward"

    logger.info("loss_statistics", stats=stats)

    # Test edge cases
    # Case 1: Perfect prediction with low uncertainty
    perfect_pred = torch.tensor([[0.0, 0.0]])
    perfect_sigma = torch.tensor([[0.1, 0.1]])
    perfect_target = torch.tensor([[0.0, 0.0]])

    loss_perfect = loss_fn(perfect_pred, perfect_sigma, perfect_target)
    logger.info("loss_perfect_prediction", value=loss_perfect.item())
    # Zero error leaves only log(sigma) = log(0.1), which is negative.
    assert torch.isclose(loss_perfect, torch.log(torch.tensor(0.1))), \
        "Zero-error loss should equal log(sigma)"

    # Case 2: Bad prediction with high uncertainty (should be better than low uncertainty)
    bad_pred = torch.tensor([[1.0, 1.0]])
    high_sigma = torch.tensor([[1.0, 1.0]])
    bad_target = torch.tensor([[0.0, 0.0]])

    loss_high_sigma = loss_fn(bad_pred, high_sigma, bad_target)
    logger.info("loss_high_uncertainty", value=loss_high_sigma.item())

    # Case 3: Same error with low uncertainty (should be worse)
    low_sigma = torch.tensor([[0.1, 0.1]])
    loss_low_sigma = loss_fn(bad_pred, low_sigma, bad_target)
    logger.info("loss_low_uncertainty", value=loss_low_sigma.item())

    assert loss_low_sigma > loss_high_sigma, \
        "Low uncertainty should have higher loss for same error"

    logger.info("✅ Gaussian NLL loss verification complete!")

    return loss_fn
if __name__ == "__main__":
    # Script entry point: turn on INFO-level stdlib logging, then run the
    # Gaussian NLL self-check.
    import logging

    logging.basicConfig(level="INFO")
    loss_fn = verify_gaussian_nll_loss()