Coverage for services/rf-acquisition/src/storage/db_manager.py: 56%
169 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-25 16:18 +0000
1"""Database management utilities for TimescaleDB operations."""
3import logging
4from contextlib import contextmanager
5from typing import List, Dict, Any, Optional, Generator, Tuple
6from datetime import datetime, timedelta
8from sqlalchemy import create_engine, select, func, and_, desc
9from sqlalchemy.orm import Session, sessionmaker
10from sqlalchemy.exc import IntegrityError
11from sqlalchemy.pool import NullPool
13try:
14 from ..config import settings
15except ImportError:
16 # For testing
17 from config import settings
19try:
20 from ..models.db import Measurement, Base
21except ImportError:
22 # For testing
23 from models.db import Measurement, Base
25logger = logging.getLogger(__name__)
class DatabaseManager:
    """Manages database connections and operations for TimescaleDB.

    Wraps a SQLAlchemy engine and session factory for the ``Measurement``
    model. All query/insert helpers log failures and return a neutral value
    (``None``, ``False``, empty collection, or a zeroed tuple) instead of
    raising, so callers must inspect the return value.
    """

    def __init__(self, database_url: Optional[str] = None):
        """Initialize database manager.

        Args:
            database_url: SQLAlchemy database URL. Falls back to
                ``settings.database_url`` when not provided.
        """
        self.database_url = database_url or settings.database_url
        self.engine = None        # created by _initialize_engine()
        self.SessionLocal = None  # session factory, created by _initialize_engine()
        self._initialize_engine()

    def _initialize_engine(self) -> None:
        """Initialize SQLAlchemy engine and session factory.

        Raises:
            Exception: re-raised after logging if engine creation fails.
        """
        try:
            # Build connect_args based on database URL
            connect_args = {}
            poolclass = NullPool  # Default: no connection pooling

            if "postgresql" in self.database_url or "postgres" in self.database_url:
                # Force UTC on every PostgreSQL connection.
                connect_args = {"options": "-c timezone=utc"}
            elif "sqlite" in self.database_url:
                connect_args = {"check_same_thread": False}
                # SQLite in-memory needs StaticPool so all sessions share the
                # single connection holding the in-memory database.
                if ":memory:" in self.database_url:
                    from sqlalchemy.pool import StaticPool
                    poolclass = StaticPool

            self.engine = create_engine(
                self.database_url,
                echo=False,
                poolclass=poolclass,
                connect_args=connect_args,
            )
            self.SessionLocal = sessionmaker(
                bind=self.engine,
                expire_on_commit=False,  # keep ORM objects readable after commit/close
                autoflush=False,
            )
            logger.info("Database engine initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize database engine: {e}")
            raise

    def create_tables(self) -> bool:
        """Create all ORM-mapped tables in the database.

        Returns:
            True on success, False if table creation failed (error is logged).
        """
        try:
            Base.metadata.create_all(self.engine)
            logger.info("Database tables created successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to create database tables: {e}")
            return False

    def check_connection(self) -> bool:
        """Check if database connection is working via a trivial SELECT.

        Returns:
            True if the probe query succeeded, False otherwise.
        """
        try:
            with self.engine.connect() as conn:
                result = conn.execute(select(1))
                result.close()
            logger.debug("Database connection check successful")
            return True
        except Exception as e:
            logger.error(f"Database connection check failed: {e}")
            return False

    @contextmanager
    def get_session(self) -> Generator[Session, None, None]:
        """Context manager for a database session.

        Commits on normal exit, rolls back and re-raises on exception,
        and always closes the session.
        """
        session = self.SessionLocal()
        try:
            yield session
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"Session error, rolling back: {e}")
            raise
        finally:
            session.close()

    def insert_measurement(
        self,
        task_id: str,
        measurement_dict: Dict[str, Any],
        s3_path: Optional[str] = None
    ) -> Optional[int]:
        """Insert a single measurement into the database.

        Args:
            task_id: Acquisition task the measurement belongs to.
            measurement_dict: Raw measurement payload, passed to
                ``Measurement.from_measurement_dict``.
            s3_path: Optional object-store path of the raw capture.

        Returns:
            The new row's primary key, or None on failure (logged).
        """
        try:
            with self.get_session() as session:
                measurement = Measurement.from_measurement_dict(
                    task_id=task_id,
                    measurement_dict=measurement_dict,
                    s3_path=s3_path,
                )
                session.add(measurement)
                session.flush()  # assign the primary key before the context commits
                meas_id = measurement.id
                logger.debug(f"Inserted measurement {meas_id} for task {task_id}")
                return meas_id
        except IntegrityError as e:
            logger.warning(f"Integrity error inserting measurement: {e}")
            return None
        except Exception as e:
            logger.error(f"Error inserting measurement: {e}")
            return None

    def insert_measurements_bulk(
        self,
        task_id: str,
        measurements_list: List[Dict[str, Any]],
        s3_paths: Optional[Dict[int, str]] = None
    ) -> Tuple[int, int]:
        """Bulk insert measurements into the database in one transaction.

        Args:
            task_id: Acquisition task the measurements belong to.
            measurements_list: Raw measurement payloads.
            s3_paths: Optional mapping of websdr_id -> object-store path.

        Returns:
            Tuple ``(successful, failed)``. Invalid payloads are skipped and
            counted as failed; if the transaction itself fails, nothing is
            persisted and all inputs are reported as failed.
        """
        successful = 0
        failed = 0

        try:
            with self.get_session() as session:
                for measurement_dict in measurements_list:
                    try:
                        websdr_id = measurement_dict.get("websdr_id")
                        s3_path = s3_paths.get(websdr_id) if s3_paths else None

                        measurement = Measurement.from_measurement_dict(
                            task_id=task_id,
                            measurement_dict=measurement_dict,
                            s3_path=s3_path,
                        )
                        session.add(measurement)
                        successful += 1
                    except (ValueError, TypeError) as e:
                        logger.warning(
                            f"Skipping invalid measurement for WebSDR "
                            f"{measurement_dict.get('websdr_id')}: {e}"
                        )
                        failed += 1

                session.commit()
                logger.info(
                    f"Bulk insert completed: {successful} successful, {failed} failed"
                )
        except Exception as e:
            logger.error(f"Bulk insert error: {e}")
            # The transaction rolled back, so none of the rows counted as
            # "successful" were actually persisted. (Previously this kept the
            # stale success count and double-counted failures.)
            successful = 0
            failed = len(measurements_list)

        return successful, failed

    def get_recent_measurements(
        self,
        task_id: Optional[str] = None,
        websdr_id: Optional[int] = None,
        limit: int = 100,
        hours_back: int = 24
    ) -> List[Measurement]:
        """Get recent measurements from the database, newest first.

        Args:
            task_id: Optional task filter.
            websdr_id: Optional WebSDR filter (0 is a valid id).
            limit: Maximum number of rows returned.
            hours_back: Look-back window in hours.

        Returns:
            List of Measurement rows; empty list on error (logged).
        """
        try:
            with self.get_session() as session:
                # NOTE(review): naive UTC timestamp; assumes timestamp_utc
                # column is stored naive-UTC — confirm against the model.
                cutoff_time = datetime.utcnow() - timedelta(hours=hours_back)

                query = select(Measurement).where(
                    Measurement.timestamp_utc >= cutoff_time
                )

                if task_id:
                    query = query.where(Measurement.task_id == task_id)
                # Explicit None check so websdr_id == 0 still filters
                # (previously `if websdr_id:` silently ignored id 0).
                if websdr_id is not None:
                    query = query.where(Measurement.websdr_id == websdr_id)

                query = query.order_by(desc(Measurement.timestamp_utc)).limit(limit)
                results = session.execute(query).scalars().all()
                logger.debug(f"Retrieved {len(results)} recent measurements")
                return results
        except Exception as e:
            logger.error(f"Error retrieving recent measurements: {e}")
            return []

    def get_session_measurements(self, task_id: str) -> Dict[int, List[Measurement]]:
        """Get all measurements for a specific session/task, grouped by WebSDR.

        Returns:
            Mapping of websdr_id -> measurements (newest first within each
            group); empty dict on error (logged).
        """
        try:
            with self.get_session() as session:
                query = select(Measurement).where(
                    Measurement.task_id == task_id
                ).order_by(Measurement.websdr_id, desc(Measurement.timestamp_utc))

                results = session.execute(query).scalars().all()

                # Group by websdr_id
                grouped = {}
                for measurement in results:
                    grouped.setdefault(measurement.websdr_id, []).append(measurement)

                logger.debug(
                    f"Retrieved {len(results)} measurements for task {task_id} "
                    f"from {len(grouped)} WebSDRs"
                )
                return grouped
        except Exception as e:
            logger.error(f"Error retrieving session measurements: {e}")
            return {}

    def get_snr_statistics(
        self,
        task_id: Optional[str] = None,
        hours_back: int = 24
    ) -> Dict[int, Dict[str, float]]:
        """Get SNR statistics (avg/min/max/count) grouped by WebSDR.

        Args:
            task_id: Optional task filter.
            hours_back: Look-back window in hours.

        Returns:
            Mapping of websdr_id -> stats dict; SNR values may be None when
            the aggregate is NULL. Empty dict on error (logged).
        """
        try:
            with self.get_session() as session:
                cutoff_time = datetime.utcnow() - timedelta(hours=hours_back)

                query = select(
                    Measurement.websdr_id,
                    func.avg(Measurement.snr_db).label("avg_snr"),
                    func.min(Measurement.snr_db).label("min_snr"),
                    func.max(Measurement.snr_db).label("max_snr"),
                    func.count(Measurement.id).label("count")
                ).where(Measurement.timestamp_utc >= cutoff_time)

                if task_id:
                    query = query.where(Measurement.task_id == task_id)

                query = query.group_by(Measurement.websdr_id)
                results = session.execute(query).all()

                stats = {}
                for row in results:
                    # Explicit None checks: an SNR of exactly 0.0 is valid and
                    # must not be reported as None (previous truthiness bug).
                    stats[row.websdr_id] = {
                        "avg_snr_db": float(row.avg_snr) if row.avg_snr is not None else None,
                        "min_snr_db": float(row.min_snr) if row.min_snr is not None else None,
                        "max_snr_db": float(row.max_snr) if row.max_snr is not None else None,
                        "count": row.count
                    }

                logger.debug(f"Retrieved SNR statistics for {len(stats)} WebSDRs")
                return stats
        except Exception as e:
            logger.error(f"Error retrieving SNR statistics: {e}")
            return {}

    def close(self) -> None:
        """Close database engine and cleanup resources. Never raises."""
        try:
            if self.engine:
                self.engine.dispose()
                logger.info("Database engine closed")
        except Exception as e:
            logger.error(f"Error closing database engine: {e}")
# Global database manager instance
# (lazily created by get_db_manager(); disposed and cleared by reset_db_manager())
_db_manager: Optional[DatabaseManager] = None
def get_db_manager() -> DatabaseManager:
    """Return the process-wide DatabaseManager, creating it lazily on first call."""
    global _db_manager
    manager = _db_manager
    if manager is None:
        manager = DatabaseManager()
        _db_manager = manager
    return manager
def reset_db_manager() -> None:
    """Drop the global database manager, closing its engine first (useful for testing)."""
    global _db_manager
    # Detach the instance before closing so the global is cleared either way.
    manager, _db_manager = _db_manager, None
    if manager is not None:
        manager.close()