Coverage for services/training/src/data/dataset.py: 17%

142 statements  

coverage.py v7.11.0, created at 2025-10-25 16:18 +0000

1""" 

2HeimdallDataset: PyTorch Dataset for loading training data. 

3 

4Loads IQ recordings from MinIO and ground truth labels from PostgreSQL. 

5 

6Data flow: 

71. Query PostgreSQL for recording sessions with known source locations 

82. Download IQ data (.npy files) from MinIO 

93. Extract features (mel-spectrogram) using features.py 

104. Return (features, ground_truth_label) pairs for training 

11""" 

import numpy as np
import torch
from torch.utils.data import Dataset
import structlog
from typing import Tuple, Optional, Dict, List
import os
from pathlib import Path
import pickle
from datetime import datetime

logger = structlog.get_logger(__name__)


class HeimdallDataset(Dataset):
    """
    PyTorch Dataset for RF source localization training.

    Each sample returns:
    - features: Mel-spectrogram (3, 128, 32) - 3 channels from multi-receiver IQ
    - label: Ground truth position [latitude, longitude]
    - metadata: Dict with the session_id plus any per-session metadata stored on disk

    Data sources:
    - PostgreSQL: Recording sessions with ground truth coordinates
    - MinIO: IQ data files (.npy format)

    Features:
    - Lazy loading (only loads on access to avoid memory bloat)
    - Optional data augmentation
    - Caching of processed features
    - Statistical normalization per sample
    """

    def __init__(
        self,
        data_dir: str,
        split: str = 'train',
        augmentation: bool = False,
        cache_dir: Optional[str] = None,
        normalize: bool = True,
        n_mels: int = 128,
        n_frames: int = 32,
    ):
        """
        Initialize HeimdallDataset.

        Args:
            data_dir (str): Directory containing preprocessed training data
            split (str): 'train', 'val', or 'test'
            augmentation (bool): Apply data augmentation
            cache_dir (Optional[str]): Cache processed features to disk
            normalize (bool): Normalize features (zero mean, unit variance)
            n_mels (int): Number of mel frequency bins
            n_frames (int): Number of time frames per spectrogram
        """
        self.data_dir = Path(data_dir)
        self.split = split
        self.augmentation = augmentation
        self.cache_dir = Path(cache_dir) if cache_dir else None
        self.normalize = normalize
        self.n_mels = n_mels
        self.n_frames = n_frames

        # Create cache directory if needed
        if self.cache_dir:
            self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Load dataset metadata
        self.samples = self._load_samples()

        logger.info(
            "heimdall_dataset_initialized",
            split=split,
            num_samples=len(self.samples),
            augmentation=augmentation,
            normalize=normalize,
        )

    def _load_samples(self) -> List[Dict]:
        """
        Load sample metadata from disk.

        Expected structure:
            data_dir/
            ├── train/
            │   ├── session_001_iq.npy     (IQ data)
            │   ├── session_001_label.npy  (ground truth [lat, lon])
            │   ├── session_001_meta.pkl   (metadata)
            │   └── ...
            ├── val/
            └── test/
        """
        split_dir = self.data_dir / self.split
        if not split_dir.exists():
            raise FileNotFoundError(f"Split directory not found: {split_dir}")

        samples = []

        # Find all sessions in this split
        session_files = sorted(split_dir.glob("*_iq.npy"))

        for iq_file in session_files:
            session_id = iq_file.stem.replace('_iq', '')

            label_file = split_dir / f"{session_id}_label.npy"
            meta_file = split_dir / f"{session_id}_meta.pkl"

            if label_file.exists():
                samples.append({
                    'session_id': session_id,
                    'iq_file': str(iq_file),
                    'label_file': str(label_file),
                    'meta_file': str(meta_file) if meta_file.exists() else None,
                })

        logger.info(
            "samples_loaded",
            split=self.split,
            count=len(samples),
        )

        return samples

    def _get_cache_path(self, session_id: str) -> Optional[Path]:
        """Get the cache file path for a session (None if caching is disabled)."""
        if not self.cache_dir:
            return None
        return self.cache_dir / f"{session_id}_features.pt"

    def _load_iq_data(self, iq_file: str) -> np.ndarray:
        """Load IQ data from .npy file."""

        iq_data = np.load(iq_file)

        # Expected shape: (3, n_samples) for 3-receiver IQ data
        if iq_data.ndim == 1:
            # Single receiver - replicate to 3 channels
            iq_data = np.tile(iq_data[np.newaxis, :], (3, 1))
        elif iq_data.ndim == 2 and iq_data.shape[0] != 3:
            # Wrong number of channels
            logger.warning("unexpected_iq_shape", shape=iq_data.shape)

        return iq_data

    def _extract_features(self, iq_data: np.ndarray) -> np.ndarray:
        """
        Extract mel-spectrogram features from IQ data.

        Returns:
            np.ndarray: Features of shape (3, n_mels, n_frames)
        """

        from src.data.features import iq_to_mel_spectrogram, normalize_features

        # Process each channel
        mel_specs = []
        for ch in range(iq_data.shape[0]):
            mel_spec = iq_to_mel_spectrogram(
                iq_data[ch],
                n_mels=self.n_mels,
            )

            # Normalize per-channel
            if self.normalize:
                mel_spec, _ = normalize_features(mel_spec)

            # Resize to fixed shape
            mel_spec = self._resize_to_fixed_shape(mel_spec)
            mel_specs.append(mel_spec)

        features = np.stack(mel_specs, axis=0)  # (3, n_mels, n_frames)
        return features

    def _resize_to_fixed_shape(self, mel_spec: np.ndarray) -> np.ndarray:
        """
        Resize mel-spectrogram to fixed shape (n_mels, n_frames).

        Uses padding or truncation as needed.
        """

        current_frames = mel_spec.shape[1]

        if current_frames == self.n_frames:
            return mel_spec
        elif current_frames > self.n_frames:
            # Truncate around the center
            start = (current_frames - self.n_frames) // 2
            return mel_spec[:, start:start + self.n_frames]
        else:
            # Pad with reflection
            pad_total = self.n_frames - current_frames
            pad_left = pad_total // 2
            pad_right = pad_total - pad_left
            return np.pad(mel_spec, ((0, 0), (pad_left, pad_right)), mode='reflect')

    def __len__(self) -> int:
        """Return number of samples in dataset."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
        """
        Get a single sample.

        Args:
            idx (int): Sample index

        Returns:
            Tuple containing:
            - features (torch.Tensor): Mel-spectrogram features (3, 128, 32)
            - label (torch.Tensor): Ground truth [latitude, longitude]
            - metadata (Dict): Additional metadata (session_id, etc.)
        """

        sample = self.samples[idx]
        session_id = sample['session_id']

        # Check cache first
        cache_path = self._get_cache_path(session_id)
        if cache_path and cache_path.exists():
            data = torch.load(cache_path)
            return data['features'], data['label'], {'session_id': session_id, 'from_cache': True}

        # Load IQ data
        iq_data = self._load_iq_data(sample['iq_file'])

        # Extract features
        features = self._extract_features(iq_data)

        # Load ground truth label
        label = np.load(sample['label_file'])  # [lat, lon]

        # Apply augmentation if training
        if self.augmentation and self.split == 'train':
            features = self._augment_features(features)

        # Convert to tensors
        features_tensor = torch.from_numpy(features).float()
        label_tensor = torch.from_numpy(label).float()

        # Cache if enabled
        if cache_path:
            torch.save({
                'features': features_tensor,
                'label': label_tensor,
            }, cache_path)

        metadata = {
            'session_id': session_id,
            'from_cache': False,
        }

        if sample['meta_file'] and os.path.exists(sample['meta_file']):
            with open(sample['meta_file'], 'rb') as f:
                metadata.update(pickle.load(f))

        return features_tensor, label_tensor, metadata

    def _augment_features(self, features: np.ndarray) -> np.ndarray:
        """Apply data augmentation (optional)."""

        # Random noise
        if np.random.rand() < 0.3:
            noise = np.random.randn(*features.shape) * 0.1
            features = features + noise

        # Random time shift
        if np.random.rand() < 0.3:
            shift = np.random.randint(-2, 3)
            if shift != 0:
                features = np.roll(features, shift, axis=2)

        return features

    def get_statistics(self) -> Dict:
        """
        Compute dataset statistics (mean, std, min, max across all samples).

        Useful for understanding data distribution and debugging.
        """
        all_features = []

        for i in range(min(100, len(self))):  # Sample first 100
            features, _, _ = self[i]
            all_features.append(features.numpy())

        all_features = np.concatenate([f.flatten() for f in all_features])

        stats = {
            'mean': float(np.mean(all_features)),
            'std': float(np.std(all_features)),
            'min': float(np.min(all_features)),
            'max': float(np.max(all_features)),
            'num_samples_analyzed': min(100, len(self)),
        }

        logger.info("dataset_statistics_computed", stats=stats)

        return stats


def create_dummy_dataset(output_dir: str, n_train: int = 100, n_val: int = 20):
    """
    Create a dummy dataset for testing.

    Useful for development and debugging before training on real data.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    for split, n_samples in [('train', n_train), ('val', n_val)]:
        split_dir = output_path / split
        split_dir.mkdir(exist_ok=True)

        for i in range(n_samples):
            session_id = f"session_{i:06d}"

            # Create synthetic IQ data
            iq_data = np.random.randn(3, 192000).astype(np.float32)  # 3 channels, 1 sec at 192kHz
            np.save(split_dir / f"{session_id}_iq.npy", iq_data)

            # Create synthetic label (lat, lon)
            label = np.array([45.5 + np.random.randn() * 0.1, 8.5 + np.random.randn() * 0.1]).astype(np.float32)
            np.save(split_dir / f"{session_id}_label.npy", label)

            # Create metadata
            meta = {'receiver_ids': ['r1', 'r2', 'r3'], 'timestamp': datetime.now().isoformat()}
            with open(split_dir / f"{session_id}_meta.pkl", 'wb') as f:
                pickle.dump(meta, f)

        logger.info("dummy_dataset_created", split=split, count=n_samples)


def verify_dataset():
    """Verification function for dataset."""
    logger.info("Starting dataset verification...")

    # Create dummy dataset
    dummy_dir = "/tmp/heimdall_dummy_dataset"
    create_dummy_dataset(dummy_dir, n_train=10, n_val=2)

    # Create dataset
    dataset = HeimdallDataset(dummy_dir, split='train', augmentation=False)

    # Get a sample
    features, label, metadata = dataset[0]

    # Verify shapes
    assert features.shape == (3, 128, 32), f"Expected shape (3, 128, 32), got {features.shape}"
    assert label.shape == (2,), f"Expected label shape (2,), got {label.shape}"

    logger.info(
        "✅ Dataset verification passed!",
        sample_features_shape=tuple(features.shape),
        sample_label_shape=tuple(label.shape),
        sample_metadata=metadata,
    )

    return dataset

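# Manual check: running this file directly builds a small dummy dataset under
# /tmp/heimdall_dummy_dataset and verifies sample shapes. (Assumed invocation, e.g.
# from services/training/: `python -m src.data.dataset`.)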

if __name__ == "__main__":
    import logging
    logging.basicConfig(level=logging.INFO)
    dataset = verify_dataset()