Coverage for services/rf-acquisition/src/storage/db_manager.py: 56%

169 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-25 16:18 +0000

1"""Database management utilities for TimescaleDB operations.""" 

2 

3import logging 

4from contextlib import contextmanager 

5from typing import List, Dict, Any, Optional, Generator, Tuple 

6from datetime import datetime, timedelta 

7 

8from sqlalchemy import create_engine, select, func, and_, desc 

9from sqlalchemy.orm import Session, sessionmaker 

10from sqlalchemy.exc import IntegrityError 

11from sqlalchemy.pool import NullPool 

12 

13try: 

14 from ..config import settings 

15except ImportError: 

16 # For testing 

17 from config import settings 

18 

19try: 

20 from ..models.db import Measurement, Base 

21except ImportError: 

22 # For testing 

23 from models.db import Measurement, Base 

24 

25logger = logging.getLogger(__name__) 

26 

27 

class DatabaseManager:
    """Manages database connections and operations for TimescaleDB.

    Wraps an SQLAlchemy engine and session factory and exposes insert and
    query helpers for ``Measurement`` rows. SQLite URLs (including
    in-memory) are also supported for testing.
    """

    def __init__(self, database_url: Optional[str] = None):
        """Initialize database manager.

        Args:
            database_url: SQLAlchemy database URL. Falls back to
                ``settings.database_url`` when omitted.
        """
        self.database_url = database_url or settings.database_url
        self.engine = None
        self.SessionLocal = None
        self._initialize_engine()

    def _initialize_engine(self) -> None:
        """Create the SQLAlchemy engine and session factory.

        Raises:
            Exception: re-raised (after logging) if engine creation fails.
        """
        try:
            # Build connect_args based on the database dialect.
            connect_args: Dict[str, Any] = {}
            poolclass = NullPool  # default: no connection pooling

            if "postgresql" in self.database_url or "postgres" in self.database_url:
                # Force UTC on every PostgreSQL connection.
                connect_args = {"options": "-c timezone=utc"}
            elif "sqlite" in self.database_url:
                connect_args = {"check_same_thread": False}
                # SQLite in-memory needs StaticPool so all sessions share
                # the single connection that holds the database.
                if ":memory:" in self.database_url:
                    from sqlalchemy.pool import StaticPool
                    poolclass = StaticPool

            self.engine = create_engine(
                self.database_url,
                echo=False,
                poolclass=poolclass,
                connect_args=connect_args,
            )
            self.SessionLocal = sessionmaker(
                bind=self.engine,
                expire_on_commit=False,
                autoflush=False,
            )
            logger.info("Database engine initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize database engine: %s", e)
            raise

    def create_tables(self) -> bool:
        """Create all ORM tables in the database.

        Returns:
            True on success, False (with the error logged) on failure.
        """
        try:
            Base.metadata.create_all(self.engine)
            logger.info("Database tables created successfully")
            return True
        except Exception as e:
            logger.error("Failed to create database tables: %s", e)
            return False

    def check_connection(self) -> bool:
        """Check if database connection is working.

        Returns:
            True when a trivial ``SELECT 1`` succeeds, False otherwise.
        """
        try:
            with self.engine.connect() as conn:
                result = conn.execute(select(1))
                result.close()
            logger.debug("Database connection check successful")
            return True
        except Exception as e:
            logger.error("Database connection check failed: %s", e)
            return False

    @contextmanager
    def get_session(self) -> Generator[Session, None, None]:
        """Context manager for a database session.

        Commits on normal exit, rolls back and re-raises on error, and
        always closes the session.
        """
        session = self.SessionLocal()
        try:
            yield session
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error("Session error, rolling back: %s", e)
            raise
        finally:
            session.close()

    def insert_measurement(
        self,
        task_id: str,
        measurement_dict: Dict[str, Any],
        s3_path: Optional[str] = None
    ) -> Optional[int]:
        """Insert a single measurement into the database.

        Args:
            task_id: Acquisition task identifier the measurement belongs to.
            measurement_dict: Raw payload accepted by
                ``Measurement.from_measurement_dict``.
            s3_path: Optional object-storage path associated with the row.

        Returns:
            The new row's id, or None on any error (errors are logged).
        """
        try:
            with self.get_session() as session:
                measurement = Measurement.from_measurement_dict(
                    task_id=task_id,
                    measurement_dict=measurement_dict,
                    s3_path=s3_path
                )
                session.add(measurement)
                # Flush so the autogenerated primary key is populated
                # before the context manager commits.
                session.flush()
                meas_id = measurement.id
                logger.debug("Inserted measurement %s for task %s", meas_id, task_id)
                return meas_id
        except IntegrityError as e:
            logger.warning("Integrity error inserting measurement: %s", e)
            return None
        except Exception as e:
            logger.error("Error inserting measurement: %s", e)
            return None

    def insert_measurements_bulk(
        self,
        task_id: str,
        measurements_list: List[Dict[str, Any]],
        s3_paths: Optional[Dict[int, str]] = None
    ) -> Tuple[int, int]:
        """Bulk insert measurements into the database.

        Invalid entries are skipped and counted as failures. A
        session-level failure rolls back the whole transaction, so in that
        case every entry is reported as failed.

        Args:
            task_id: Task identifier applied to every row.
            measurements_list: Measurement payloads to insert.
            s3_paths: Optional mapping of websdr_id -> S3 path.

        Returns:
            Tuple of (successful, failed) row counts; the two always sum
            to ``len(measurements_list)``.
        """
        successful = 0
        failed = 0

        try:
            with self.get_session() as session:
                for measurement_dict in measurements_list:
                    try:
                        websdr_id = measurement_dict.get("websdr_id")
                        s3_path = s3_paths.get(websdr_id) if s3_paths else None

                        measurement = Measurement.from_measurement_dict(
                            task_id=task_id,
                            measurement_dict=measurement_dict,
                            s3_path=s3_path
                        )
                        session.add(measurement)
                        successful += 1
                    except (ValueError, TypeError) as e:
                        logger.warning(
                            "Skipping invalid measurement for WebSDR %s: %s",
                            measurement_dict.get("websdr_id"), e
                        )
                        failed += 1

                session.commit()
                logger.info(
                    "Bulk insert completed: %d successful, %d failed",
                    successful, failed
                )
        except Exception as e:
            logger.error("Bulk insert error: %s", e)
            # The transaction was rolled back, so none of the staged rows
            # persisted: report everything as failed rather than claiming
            # rolled-back rows as successes (previous behavior could also
            # make the counts exceed len(measurements_list)).
            successful = 0
            failed = len(measurements_list)

        return successful, failed

    def get_recent_measurements(
        self,
        task_id: Optional[str] = None,
        websdr_id: Optional[int] = None,
        limit: int = 100,
        hours_back: int = 24
    ) -> List[Measurement]:
        """Get recent measurements from the database.

        Args:
            task_id: Optional task filter.
            websdr_id: Optional WebSDR filter (0 is treated as a valid id).
            limit: Maximum number of rows returned.
            hours_back: Look-back window in hours.

        Returns:
            Measurements newest-first, or [] on error.
        """
        try:
            with self.get_session() as session:
                # Naive UTC cutoff compared against timestamp_utc; assumes
                # the column stores naive UTC — TODO confirm before moving
                # to timezone-aware datetimes (utcnow() is deprecated).
                cutoff_time = datetime.utcnow() - timedelta(hours=hours_back)

                query = select(Measurement).where(
                    Measurement.timestamp_utc >= cutoff_time
                )

                if task_id:
                    query = query.where(Measurement.task_id == task_id)
                # Explicit None check: a websdr_id of 0 must still filter
                # (plain truthiness silently dropped the filter for 0).
                if websdr_id is not None:
                    query = query.where(Measurement.websdr_id == websdr_id)

                query = query.order_by(desc(Measurement.timestamp_utc)).limit(limit)
                results = session.execute(query).scalars().all()
                logger.debug("Retrieved %d recent measurements", len(results))
                return results
        except Exception as e:
            logger.error("Error retrieving recent measurements: %s", e)
            return []

    def get_session_measurements(self, task_id: str) -> Dict[int, List[Measurement]]:
        """Get all measurements for a specific session/task.

        Args:
            task_id: Task identifier whose measurements are fetched.

        Returns:
            Mapping of websdr_id -> measurements (newest first within each
            WebSDR), or {} on error.
        """
        try:
            with self.get_session() as session:
                query = select(Measurement).where(
                    Measurement.task_id == task_id
                ).order_by(Measurement.websdr_id, desc(Measurement.timestamp_utc))

                results = session.execute(query).scalars().all()

                # Group by websdr_id, preserving query order within groups.
                grouped: Dict[int, List[Measurement]] = {}
                for measurement in results:
                    grouped.setdefault(measurement.websdr_id, []).append(measurement)

                logger.debug(
                    "Retrieved %d measurements for task %s from %d WebSDRs",
                    len(results), task_id, len(grouped)
                )
                return grouped
        except Exception as e:
            logger.error("Error retrieving session measurements: %s", e)
            return {}

    def get_snr_statistics(
        self,
        task_id: Optional[str] = None,
        hours_back: int = 24
    ) -> Dict[int, Dict[str, float]]:
        """Get SNR statistics grouped by WebSDR.

        Args:
            task_id: Optional task filter.
            hours_back: Look-back window in hours.

        Returns:
            Mapping of websdr_id -> {"avg_snr_db", "min_snr_db",
            "max_snr_db", "count"}, or {} on error.
        """
        try:
            with self.get_session() as session:
                # Naive UTC cutoff; see note in get_recent_measurements.
                cutoff_time = datetime.utcnow() - timedelta(hours=hours_back)

                query = select(
                    Measurement.websdr_id,
                    func.avg(Measurement.snr_db).label("avg_snr"),
                    func.min(Measurement.snr_db).label("min_snr"),
                    func.max(Measurement.snr_db).label("max_snr"),
                    func.count(Measurement.id).label("count")
                ).where(Measurement.timestamp_utc >= cutoff_time)

                if task_id:
                    query = query.where(Measurement.task_id == task_id)

                query = query.group_by(Measurement.websdr_id)
                results = session.execute(query).all()

                stats = {}
                for row in results:
                    # `is not None` checks: an aggregate of exactly 0.0 dB
                    # is legitimate and must not be reported as None
                    # (plain truthiness collapsed 0.0 to None).
                    stats[row.websdr_id] = {
                        "avg_snr_db": float(row.avg_snr) if row.avg_snr is not None else None,
                        "min_snr_db": float(row.min_snr) if row.min_snr is not None else None,
                        "max_snr_db": float(row.max_snr) if row.max_snr is not None else None,
                        "count": row.count
                    }

                logger.debug("Retrieved SNR statistics for %d WebSDRs", len(stats))
                return stats
        except Exception as e:
            logger.error("Error retrieving SNR statistics: %s", e)
            return {}

    def close(self) -> None:
        """Close database engine and cleanup resources."""
        try:
            if self.engine:
                self.engine.dispose()
                logger.info("Database engine closed")
        except Exception as e:
            logger.error("Error closing database engine: %s", e)

275 

276 

# Module-level singleton slot for the shared DatabaseManager.
_db_manager: Optional[DatabaseManager] = None


def get_db_manager() -> DatabaseManager:
    """Return the process-wide DatabaseManager, creating it lazily on first use."""
    global _db_manager
    if _db_manager is not None:
        return _db_manager
    _db_manager = DatabaseManager()
    return _db_manager

287 

288 

def reset_db_manager() -> None:
    """Tear down the global database manager so the next get_db_manager()
    call builds a fresh one (useful for testing)."""
    global _db_manager
    if _db_manager is not None:
        _db_manager.close()
    _db_manager = None