Coverage for services/inference/src/utils/cache.py: 28%

133 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-25 16:18 +0000

1"""Redis cache manager for Phase 6 Inference Service. 

2 

3Implements caching strategy for prediction results to achieve >80% cache hit rate 

4and reduce latency for repeated queries. 

5""" 

6 

7import redis 

8import json 

9import logging 

10import hashlib 

11from typing import Optional, Dict, Any, List 

12from datetime import datetime 

13import numpy as np 

14 

15logger = logging.getLogger(__name__) 

16 

17 

class RedisCache:
    """
    Redis-based cache for inference predictions.

    Caching strategy:
    - Key: hash(preprocessed_features) - stable and deterministic
    - Value: JSON-serialized prediction result
    - TTL: Configurable (default 3600 seconds = 1 hour)
    - Hit rate target: >80%
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        ttl_seconds: int = 3600,
        password: Optional[str] = None,
    ):
        """
        Initialize Redis cache connection.

        Args:
            host: Redis server hostname
            port: Redis server port
            db: Redis database number (0-15)
            ttl_seconds: Cache entry TTL in seconds (default 1 hour)
            password: Redis password (optional)

        Raises:
            redis.ConnectionError: If the server cannot be reached.
        """
        self.ttl_seconds = ttl_seconds
        self.host = host
        self.port = port
        self.db = db

        try:
            self.client = redis.Redis(
                host=host,
                port=port,
                db=db,
                password=password,
                decode_responses=True,  # Return strings instead of bytes
                socket_connect_timeout=5,
                socket_keepalive=True,
            )
            # Fail fast: verify the server is actually reachable now,
            # not on the first get/set.
            self.client.ping()
            logger.info("Connected to Redis at %s:%s/%s", host, port, db)
        except redis.ConnectionError:
            logger.exception("Failed to connect to Redis")
            raise

    def _generate_cache_key(self, features: np.ndarray) -> str:
        """
        Generate stable cache key from features.

        Args:
            features: Preprocessed features (mel-spectrogram)

        Returns:
            Cache key string of the form "pred:<sha256-hex>"
        """
        try:
            # Normalize dtype so the same logical features always hash
            # identically regardless of the caller's array dtype.
            features_bytes = features.astype(np.float32).tobytes()

            hash_obj = hashlib.sha256(features_bytes)
            cache_key = f"pred:{hash_obj.hexdigest()}"

            logger.debug("Generated cache key: %s", cache_key)
            return cache_key

        except Exception:
            logger.exception("Failed to generate cache key")
            raise

    def get(self, features: np.ndarray) -> Optional[Dict[str, Any]]:
        """
        Retrieve cached prediction if available.

        Args:
            features: Preprocessed features

        Returns:
            Cached prediction dict (with '_cache_hit': True added),
            or None if not found or on any Redis/deserialization error.
        """
        try:
            cache_key = self._generate_cache_key(features)

            cached_value = self.client.get(cache_key)

            if cached_value is None:
                logger.debug("Cache miss for key: %s", cache_key)
                return None

            result = json.loads(cached_value)
            # Mark so callers/metrics can distinguish hits from fresh inference.
            result['_cache_hit'] = True

            logger.debug("Cache hit: %s", cache_key)
            return result

        except json.JSONDecodeError as e:
            # Corrupt entry: treat as a miss rather than failing the request.
            logger.warning(f"Failed to deserialize cached value: {e}")
            return None
        except Exception:
            # Best-effort cache: any Redis failure degrades to a miss.
            logger.exception("Redis get error")
            return None

    def set(self, features: np.ndarray, prediction: Dict[str, Any]) -> bool:
        """
        Cache prediction result.

        Args:
            features: Preprocessed features
            prediction: Prediction result dict

        Returns:
            True if cached successfully, False otherwise
        """
        try:
            cache_key = self._generate_cache_key(features)

            # Convert numpy/datetime values so json.dumps cannot fail.
            cache_value = self._prepare_for_cache(prediction)

            json_value = json.dumps(cache_value)

            # SETEX stores the value and its TTL atomically.
            success = self.client.setex(
                cache_key,
                self.ttl_seconds,
                json_value
            )

            if success:
                logger.debug(
                    "Cached prediction: %s (TTL: %ss)", cache_key, self.ttl_seconds
                )

            return success

        except Exception:
            # Best-effort cache: never let a cache failure break inference.
            logger.exception("Redis set error")
            return False

    def _prepare_for_cache(self, obj: Any) -> Any:
        """
        Prepare object for JSON serialization.

        Converts numpy types (arrays, ints, floats, bools) and datetimes
        to native Python types; recurses through dicts/lists/tuples.

        Args:
            obj: Object to prepare

        Returns:
            JSON-serializable object
        """
        if isinstance(obj, dict):
            return {k: self._prepare_for_cache(v) for k, v in obj.items()}
        elif isinstance(obj, (list, tuple)):
            return [self._prepare_for_cache(item) for item in obj]
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        # BUGFIX: include np.bool_ — previously json.dumps raised TypeError
        # for predictions containing numpy booleans, so set() always failed.
        elif isinstance(obj, (np.integer, np.floating, np.bool_)):
            return obj.item()
        elif isinstance(obj, datetime):
            return obj.isoformat()
        else:
            return obj

    def delete(self, features: np.ndarray) -> bool:
        """
        Delete cached prediction.

        Args:
            features: Preprocessed features

        Returns:
            True if deleted, False if not found or on error
        """
        try:
            cache_key = self._generate_cache_key(features)
            deleted = self.client.delete(cache_key)
            logger.debug("Deleted cache entry: %s", cache_key)
            return deleted > 0
        except Exception:
            logger.exception("Redis delete error")
            return False

    def clear(self) -> bool:
        """
        Clear all cache entries in current database.

        WARNING: This clears the entire Redis database!

        Returns:
            True if successful
        """
        try:
            self.client.flushdb()
            logger.warning("Cleared all cache entries in database %s", self.db)
            return True
        except Exception:
            logger.exception("Redis flush error")
            return False

    def get_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics from Redis.

        Returns:
            Statistics dict with memory usage, key count and connection
            info, or {} on error.
        """
        try:
            info = self.client.info(section='memory')

            stats = {
                'used_memory_bytes': info.get('used_memory', 0),
                'used_memory_human': info.get('used_memory_human', 'N/A'),
                'used_memory_peak': info.get('used_memory_peak', 0),
                'total_keys': self.client.dbsize(),
                'connection_host': self.host,
                'connection_port': self.port,
                'connection_db': self.db,
                'ttl_seconds': self.ttl_seconds,
            }

            logger.debug("Cache stats: %s", stats)
            return stats

        except Exception:
            logger.exception("Failed to get cache stats")
            return {}

    def close(self):
        """Close Redis connection."""
        try:
            self.client.close()
            logger.info("Closed Redis connection")
        except Exception:
            logger.exception("Error closing Redis")

260 

261 

class CacheStatistics:
    """In-memory counter of cache hits and misses with derived metrics."""

    def __init__(self):
        """Start with zero recorded accesses."""
        self.hits = 0
        self.misses = 0

    @property
    def total(self) -> int:
        """Number of recorded accesses (hits plus misses)."""
        return self.hits + self.misses

    @property
    def hit_rate(self) -> float:
        """Fraction of accesses that were hits, in [0, 1]; 0.0 when empty."""
        accesses = self.total
        return self.hits / accesses if accesses else 0.0

    def record_hit(self):
        """Count one cache hit."""
        self.hits += 1

    def record_miss(self):
        """Count one cache miss."""
        self.misses += 1

    def reset(self):
        """Discard all recorded counts."""
        self.hits = 0
        self.misses = 0

    def __str__(self) -> str:
        """Human-readable summary of the current counters."""
        template = (
            "CacheStats(hits={}, misses={}, "
            "total={}, hit_rate={:.1%})"
        )
        return template.format(self.hits, self.misses, self.total, self.hit_rate)

    def to_dict(self) -> Dict[str, Any]:
        """Export counters and derived metrics as a plain dict."""
        return dict(
            hits=self.hits,
            misses=self.misses,
            total=self.total,
            hit_rate=self.hit_rate,
        )

310 

311 

def create_cache(
    host: str = "localhost",
    port: int = 6379,
    ttl_seconds: int = 3600,
) -> Optional[RedisCache]:
    """
    Factory function to create Redis cache with error handling.

    Args:
        host: Redis host
        port: Redis port
        ttl_seconds: Cache TTL

    Returns:
        RedisCache instance, or None if connection failed
    """
    try:
        cache = RedisCache(host=host, port=port, ttl_seconds=ttl_seconds)
    except Exception as e:
        logger.error(f"Failed to create cache: {e}")
        return None
    return cache