Coverage for services/training/src/data/dataset.py: 17%
142 statements
coverage.py v7.11.0, created at 2025-10-25 16:18 +0000
1"""
2HeimdallDataset: PyTorch Dataset for loading training data.
4Loads IQ recordings from MinIO and ground truth labels from PostgreSQL.
6Data flow:
71. Query PostgreSQL for recording sessions with known source locations
82. Download IQ data (.npy files) from MinIO
93. Extract features (mel-spectrogram) using features.py
104. Return (features, ground_truth_label) pairs for training
11"""
13import numpy as np
14import torch
15from torch.utils.data import Dataset
16import structlog
17from typing import Tuple, Optional, Dict, List
18import os
19from pathlib import Path
20import pickle
21from datetime import datetime
23logger = structlog.get_logger(__name__)


class HeimdallDataset(Dataset):
    """
    PyTorch Dataset for RF source localization training.

    Each sample returns:
    - features: Mel-spectrogram (3, 128, 32) - 3 channels from multi-receiver IQ
    - label: Ground truth position [latitude, longitude]
    - uncertainty: Reference uncertainty (optional) - used for uncertainty-aware loss

    Data sources:
    - PostgreSQL: Recording sessions with ground truth coordinates
    - MinIO: IQ data files (.npy format)

    Features:
    - Lazy loading (only loads on access to avoid memory bloat)
    - Optional data augmentation
    - Caching of processed features
    - Statistical normalization per sample
    """

    def __init__(
        self,
        data_dir: str,
        split: str = 'train',
        augmentation: bool = False,
        cache_dir: Optional[str] = None,
        normalize: bool = True,
        n_mels: int = 128,
        n_frames: int = 32,
    ):
        """
        Initialize HeimdallDataset.

        Args:
            data_dir (str): Directory containing preprocessed training data
            split (str): 'train', 'val', or 'test'
            augmentation (bool): Apply data augmentation
            cache_dir (Optional[str]): Cache processed features to disk
            normalize (bool): Normalize features (zero mean, unit variance)
            n_mels (int): Number of mel frequency bins
            n_frames (int): Number of time frames per spectrogram
        """
        self.data_dir = Path(data_dir)
        self.split = split
        self.augmentation = augmentation
        self.cache_dir = Path(cache_dir) if cache_dir else None
        self.normalize = normalize
        self.n_mels = n_mels
        self.n_frames = n_frames

        # Create cache directory if needed
        if self.cache_dir:
            self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Load dataset metadata
        self.samples = self._load_samples()

        logger.info(
            "heimdall_dataset_initialized",
            split=split,
            num_samples=len(self.samples),
            augmentation=augmentation,
            normalize=normalize,
        )

    def _load_samples(self) -> List[Dict]:
        """
        Load sample metadata from disk.

        Expected structure:
        data_dir/
        ├── train/
        │   ├── session_001_iq.npy (IQ data)
        │   ├── session_001_label.npy (ground truth [lat, lon])
        │   ├── session_001_meta.pkl (metadata)
        │   └── ...
        ├── val/
        └── test/
        """
        split_dir = self.data_dir / self.split
        if not split_dir.exists():
            raise FileNotFoundError(f"Split directory not found: {split_dir}")

        samples = []

        # Find all sessions in this split
        session_files = sorted(split_dir.glob("*_iq.npy"))

        for iq_file in session_files:
            session_id = iq_file.stem.replace('_iq', '')

            label_file = split_dir / f"{session_id}_label.npy"
            meta_file = split_dir / f"{session_id}_meta.pkl"

            if label_file.exists():
                samples.append({
                    'session_id': session_id,
                    'iq_file': str(iq_file),
                    'label_file': str(label_file),
                    'meta_file': str(meta_file) if meta_file.exists() else None,
                })

        logger.info(
            "samples_loaded",
            split=self.split,
            count=len(samples),
        )

        return samples

    def _get_cache_path(self, session_id: str) -> Optional[Path]:
        """Get cache file path for a session, or None if caching is disabled."""
        if not self.cache_dir:
            return None
        return self.cache_dir / f"{session_id}_features.pt"

    def _load_iq_data(self, iq_file: str) -> np.ndarray:
        """Load IQ data from .npy file."""
        iq_data = np.load(iq_file)

        # Expected shape: (3, n_samples) for 3-receiver IQ data
        if iq_data.ndim == 1:
            # Single receiver - replicate to 3 channels
            iq_data = np.tile(iq_data[np.newaxis, :], (3, 1))
        elif iq_data.ndim == 2 and iq_data.shape[0] != 3:
            # Wrong number of channels
            logger.warning("unexpected_iq_shape", shape=iq_data.shape)

        return iq_data

    def _extract_features(self, iq_data: np.ndarray) -> np.ndarray:
        """
        Extract mel-spectrogram features from IQ data.

        Returns:
            np.ndarray: Features of shape (3, n_mels, n_frames)
        """
        from src.data.features import iq_to_mel_spectrogram, normalize_features

        # Process each channel
        mel_specs = []
        for ch in range(iq_data.shape[0]):
            mel_spec = iq_to_mel_spectrogram(
                iq_data[ch],
                n_mels=self.n_mels,
            )

            # Normalize per-channel
            if self.normalize:
                mel_spec, _ = normalize_features(mel_spec)

            # Resize to fixed shape
            mel_spec = self._resize_to_fixed_shape(mel_spec)
            mel_specs.append(mel_spec)

        features = np.stack(mel_specs, axis=0)  # (3, n_mels, n_frames)
        return features

    def _resize_to_fixed_shape(self, mel_spec: np.ndarray) -> np.ndarray:
        """
        Resize mel-spectrogram to fixed shape (n_mels, n_frames).

        Uses padding or truncation as needed.
        """
        current_frames = mel_spec.shape[1]

        if current_frames == self.n_frames:
            return mel_spec
        elif current_frames > self.n_frames:
            # Truncate to center
            start = (current_frames - self.n_frames) // 2
            return mel_spec[:, start:start + self.n_frames]
        else:
            # Pad with reflection
            pad_total = self.n_frames - current_frames
            pad_left = pad_total // 2
            pad_right = pad_total - pad_left
            return np.pad(mel_spec, ((0, 0), (pad_left, pad_right)), mode='reflect')

    def __len__(self) -> int:
        """Return number of samples in dataset."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
        """
        Get a single sample.

        Args:
            idx (int): Sample index

        Returns:
            Tuple containing:
            - features (torch.Tensor): Mel-spectrogram features (3, 128, 32)
            - label (torch.Tensor): Ground truth [latitude, longitude]
            - metadata (Dict): Additional metadata (session_id, etc.)
        """
        sample = self.samples[idx]
        session_id = sample['session_id']

        # Check cache first
        cache_path = self._get_cache_path(session_id)
        if cache_path and cache_path.exists():
            data = torch.load(cache_path)
            return data['features'], data['label'], {'session_id': session_id, 'from_cache': True}

        # Load IQ data
        iq_data = self._load_iq_data(sample['iq_file'])

        # Extract features
        features = self._extract_features(iq_data)

        # Load ground truth label
        label = np.load(sample['label_file'])  # [lat, lon]

        # Apply augmentation if training
        if self.augmentation and self.split == 'train':
            features = self._augment_features(features)

        # Convert to tensors
        features_tensor = torch.from_numpy(features).float()
        label_tensor = torch.from_numpy(label).float()

        # Cache if enabled
        if cache_path:
            torch.save({
                'features': features_tensor,
                'label': label_tensor,
            }, cache_path)

        metadata = {
            'session_id': session_id,
            'from_cache': False,
        }

        if sample['meta_file'] and os.path.exists(sample['meta_file']):
            with open(sample['meta_file'], 'rb') as f:
                metadata.update(pickle.load(f))

        return features_tensor, label_tensor, metadata

    def _augment_features(self, features: np.ndarray) -> np.ndarray:
        """Apply data augmentation (optional)."""
        # Random noise
        if np.random.rand() < 0.3:
            noise = np.random.randn(*features.shape) * 0.1
            features = features + noise

        # Random time shift
        if np.random.rand() < 0.3:
            shift = np.random.randint(-2, 3)
            if shift != 0:
                features = np.roll(features, shift, axis=2)

        return features

    def get_statistics(self) -> Dict:
        """
        Compute dataset statistics (mean, std, min, max across all samples).

        Useful for understanding data distribution and debugging.
        """
        all_features = []

        for i in range(min(100, len(self))):  # Sample first 100
            features, _, _ = self[i]
            all_features.append(features.numpy())

        all_features = np.concatenate([f.flatten() for f in all_features])

        stats = {
            'mean': float(np.mean(all_features)),
            'std': float(np.std(all_features)),
            'min': float(np.min(all_features)),
            'max': float(np.max(all_features)),
            'num_samples_analyzed': min(100, len(self)),
        }

        logger.info("dataset_statistics_computed", stats=stats)

        return stats


def create_dummy_dataset(output_dir: str, n_train: int = 100, n_val: int = 20):
    """
    Create a dummy dataset for testing.

    Useful for development and debugging before training on real data.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    for split, n_samples in [('train', n_train), ('val', n_val)]:
        split_dir = output_path / split
        split_dir.mkdir(exist_ok=True)

        for i in range(n_samples):
            session_id = f"session_{i:06d}"

            # Create synthetic IQ data
            iq_data = np.random.randn(3, 192000).astype(np.float32)  # 3 channels, 1 second at 192 kHz
            np.save(split_dir / f"{session_id}_iq.npy", iq_data)

            # Create synthetic label (lat, lon)
            label = np.array([45.5 + np.random.randn() * 0.1, 8.5 + np.random.randn() * 0.1]).astype(np.float32)
            np.save(split_dir / f"{session_id}_label.npy", label)

            # Create metadata
            meta = {'receiver_ids': ['r1', 'r2', 'r3'], 'timestamp': datetime.now().isoformat()}
            with open(split_dir / f"{session_id}_meta.pkl", 'wb') as f:
                pickle.dump(meta, f)

        logger.info("dummy_dataset_created", split=split, count=n_samples)


def verify_dataset():
    """Verification function for dataset."""
    logger.info("Starting dataset verification...")

    # Create dummy dataset
    dummy_dir = "/tmp/heimdall_dummy_dataset"
    create_dummy_dataset(dummy_dir, n_train=10, n_val=2)

    # Create dataset
    dataset = HeimdallDataset(dummy_dir, split='train', augmentation=False)

    # Get a sample
    features, label, metadata = dataset[0]

    # Verify shapes
    assert features.shape == (3, 128, 32), f"Expected shape (3, 128, 32), got {features.shape}"
    assert label.shape == (2,), f"Expected label shape (2,), got {label.shape}"

    logger.info(
        "✅ Dataset verification passed!",
        sample_features_shape=tuple(features.shape),
        sample_label_shape=tuple(label.shape),
        sample_metadata=metadata,
    )

    return dataset
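

# Sketch (not part of the original module): minimal DataLoader usage for
# smoke-testing batching. Assumes create_dummy_dataset() has already populated
# data_dir; the batch size and the custom collate function are illustrative
# choices, not values taken from the training configuration.
def _demo_dataloader(data_dir: str = "/tmp/heimdall_dummy_dataset"):
    from torch.utils.data import DataLoader

    def collate(batch):
        # Stack the feature and label tensors, but keep the per-sample
        # metadata dicts as a plain list rather than collating them.
        feats, labels, metas = zip(*batch)
        return torch.stack(feats), torch.stack(labels), list(metas)

    dataset = HeimdallDataset(data_dir, split='train', augmentation=False)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate)

    features, labels, metas = next(iter(loader))
    logger.info(
        "demo_batch_loaded",
        features_shape=tuple(features.shape),  # expected (4, 3, 128, 32)
        labels_shape=tuple(labels.shape),      # expected (4, 2)
        num_metadata=len(metas),
    )
    return features, labels, metas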


if __name__ == "__main__":
    import logging
    logging.basicConfig(level=logging.INFO)
    dataset = verify_dataset()
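
# To run the verification standalone (assumption: invoked from the
# services/training directory so that the `src` package imported by
# _extract_features is resolvable):
#
#   python -m src.data.dataset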