Coverage for services/inference/src/utils/cache.py: 28% (133 statements)
1"""Redis cache manager for Phase 6 Inference Service.
3Implements caching strategy for prediction results to achieve >80% cache hit rate
4and reduce latency for repeated queries.
5"""
7import redis
8import json
9import logging
10import hashlib
11from typing import Optional, Dict, Any, List
12from datetime import datetime
13import numpy as np
15logger = logging.getLogger(__name__)
18class RedisCache:
19 """
20 Redis-based cache for inference predictions.
22 Caching strategy:
23 - Key: hash(preprocessed_features) - stable and deterministic
24 - Value: JSON-serialized prediction result
25 - TTL: Configurable (default 3600 seconds = 1 hour)
26 - Hit rate target: >80%
27 """
29 def __init__(
30 self,
31 host: str = "localhost",
32 port: int = 6379,
33 db: int = 0,
34 ttl_seconds: int = 3600,
35 password: Optional[str] = None,
36 ):
37 """
38 Initialize Redis cache connection.
40 Args:
41 host: Redis server hostname
42 port: Redis server port
43 db: Redis database number (0-15)
44 ttl_seconds: Cache entry TTL in seconds (default 1 hour)
45 password: Redis password (optional)
46 """
47 self.ttl_seconds = ttl_seconds
48 self.host = host
49 self.port = port
50 self.db = db
52 try:
53 self.client = redis.Redis(
54 host=host,
55 port=port,
56 db=db,
57 password=password,
58 decode_responses=True, # Return strings instead of bytes
59 socket_connect_timeout=5,
60 socket_keepalive=True,
61 )
62 # Test connection
63 self.client.ping()
64 logger.info(f"Connected to Redis at {host}:{port}/{db}")
65 except redis.ConnectionError as e:
66 logger.error(f"Failed to connect to Redis: {e}")
67 raise
69 def _generate_cache_key(self, features: np.ndarray) -> str:
70 """
71 Generate stable cache key from features.
73 Args:
74 features: Preprocessed features (mel-spectrogram)
76 Returns:
77 Cache key string
78 """
79 try:
80 # Convert to bytes for hashing
81 features_bytes = features.astype(np.float32).tobytes()
83 # Create SHA256 hash
84 hash_obj = hashlib.sha256(features_bytes)
85 cache_key = f"pred:{hash_obj.hexdigest()}"
87 logger.debug(f"Generated cache key: {cache_key}")
88 return cache_key
90 except Exception as e:
91 logger.error(f"Failed to generate cache key: {e}")
92 raise
94 def get(self, features: np.ndarray) -> Optional[Dict[str, Any]]:
95 """
96 Retrieve cached prediction if available.
98 Args:
99 features: Preprocessed features
101 Returns:
102 Cached prediction dict, or None if not found
103 """
104 try:
105 cache_key = self._generate_cache_key(features)
107 # Retrieve from Redis
108 cached_value = self.client.get(cache_key)
110 if cached_value is None:
111 logger.debug(f"Cache miss for key: {cache_key}")
112 return None
114 # Deserialize JSON
115 result = json.loads(cached_value)
116 result['_cache_hit'] = True
118 logger.debug(f"Cache hit: {cache_key}")
119 return result
121 except json.JSONDecodeError as e:
122 logger.warning(f"Failed to deserialize cached value: {e}")
123 return None
124 except Exception as e:
125 logger.error(f"Redis get error: {e}")
126 return None
128 def set(self, features: np.ndarray, prediction: Dict[str, Any]) -> bool:
129 """
130 Cache prediction result.
132 Args:
133 features: Preprocessed features
134 prediction: Prediction result dict
136 Returns:
137 True if cached successfully, False otherwise
138 """
139 try:
140 cache_key = self._generate_cache_key(features)
142 # Prepare value for storage (make it JSON-serializable)
143 cache_value = self._prepare_for_cache(prediction)
145 # Serialize to JSON
146 json_value = json.dumps(cache_value)
148 # Store in Redis with TTL
149 success = self.client.setex(
150 cache_key,
151 self.ttl_seconds,
152 json_value
153 )
155 if success:
156 logger.debug(f"Cached prediction: {cache_key} (TTL: {self.ttl_seconds}s)")
158 return success
160 except Exception as e:
161 logger.error(f"Redis set error: {e}")
162 return False
164 def _prepare_for_cache(self, obj: Any) -> Any:
165 """
166 Prepare object for JSON serialization.
168 Converts numpy types to native Python types.
170 Args:
171 obj: Object to prepare
173 Returns:
174 JSON-serializable object
175 """
176 if isinstance(obj, dict):
177 return {k: self._prepare_for_cache(v) for k, v in obj.items()}
178 elif isinstance(obj, (list, tuple)):
179 return [self._prepare_for_cache(item) for item in obj]
180 elif isinstance(obj, np.ndarray):
181 return obj.tolist()
182 elif isinstance(obj, (np.integer, np.floating)):
183 return obj.item()
184 elif isinstance(obj, datetime):
185 return obj.isoformat()
186 else:
187 return obj
189 def delete(self, features: np.ndarray) -> bool:
190 """
191 Delete cached prediction.
193 Args:
194 features: Preprocessed features
196 Returns:
197 True if deleted, False if not found
198 """
199 try:
200 cache_key = self._generate_cache_key(features)
201 deleted = self.client.delete(cache_key)
202 logger.debug(f"Deleted cache entry: {cache_key}")
203 return deleted > 0
204 except Exception as e:
205 logger.error(f"Redis delete error: {e}")
206 return False
208 def clear(self) -> bool:
209 """
210 Clear all cache entries in current database.
212 WARNING: This clears the entire Redis database!
214 Returns:
215 True if successful
216 """
217 try:
218 self.client.flushdb()
219 logger.warning(f"Cleared all cache entries in database {self.db}")
220 return True
221 except Exception as e:
222 logger.error(f"Redis flush error: {e}")
223 return False
225 def get_stats(self) -> Dict[str, Any]:
226 """
227 Get cache statistics from Redis.
229 Returns:
230 Statistics dict with cache info
231 """
232 try:
233 info = self.client.info(section='memory')
235 stats = {
236 'used_memory_bytes': info.get('used_memory', 0),
237 'used_memory_human': info.get('used_memory_human', 'N/A'),
238 'used_memory_peak': info.get('used_memory_peak', 0),
239 'total_keys': self.client.dbsize(),
240 'connection_host': self.host,
241 'connection_port': self.port,
242 'connection_db': self.db,
243 'ttl_seconds': self.ttl_seconds,
244 }
246 logger.debug(f"Cache stats: {stats}")
247 return stats
249 except Exception as e:
250 logger.error(f"Failed to get cache stats: {e}")
251 return {}
253 def close(self):
254 """Close Redis connection."""
255 try:
256 self.client.close()
257 logger.info("Closed Redis connection")
258 except Exception as e:
259 logger.error(f"Error closing Redis: {e}")
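
# Illustrative only: a minimal sketch of the cache-aside flow described in the
# RedisCache docstring (check the cache, compute on a miss, then populate).
# `predict_fn` is a hypothetical stand-in for the real inference call and is
# not part of this module.
def example_cache_aside(cache: RedisCache, features: np.ndarray, predict_fn) -> Dict[str, Any]:
    """Return a cached prediction, computing and caching it on a miss."""
    result = cache.get(features)
    if result is not None:
        # Served from Redis; get() has set '_cache_hit' = True on the result.
        return result

    # Miss: run the (assumed) model and cache the result for ttl_seconds.
    result = predict_fn(features)
    cache.set(features, result)
    return result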
class CacheStatistics:
    """Track cache hit/miss statistics."""

    def __init__(self):
        """Initialize statistics."""
        self.hits = 0
        self.misses = 0

    @property
    def total(self) -> int:
        """Total accesses."""
        return self.hits + self.misses

    @property
    def hit_rate(self) -> float:
        """Cache hit rate (0-1)."""
        if self.total == 0:
            return 0.0
        return self.hits / self.total

    def record_hit(self):
        """Record a cache hit."""
        self.hits += 1

    def record_miss(self):
        """Record a cache miss."""
        self.misses += 1

    def reset(self):
        """Reset statistics."""
        self.hits = 0
        self.misses = 0

    def __str__(self) -> str:
        """String representation."""
        return (
            f"CacheStats(hits={self.hits}, misses={self.misses}, "
            f"total={self.total}, hit_rate={self.hit_rate:.1%})"
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return statistics as a dictionary."""
        return {
            'hits': self.hits,
            'misses': self.misses,
            'total': self.total,
            'hit_rate': self.hit_rate,
        }
def create_cache(
    host: str = "localhost",
    port: int = 6379,
    ttl_seconds: int = 3600,
) -> Optional[RedisCache]:
    """
    Factory function to create a Redis cache with error handling.

    Args:
        host: Redis host
        port: Redis port
        ttl_seconds: Cache TTL

    Returns:
        RedisCache instance, or None if the connection failed
    """
    try:
        return RedisCache(host=host, port=port, ttl_seconds=ttl_seconds)
    except Exception as e:
        logger.error(f"Failed to create cache: {e}")
        return None
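
# Illustrative only: a hedged end-to-end usage sketch, assuming a Redis server
# is reachable on localhost:6379. The feature array and prediction dict below
# are made-up sample data, not values produced by the real pipeline.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    cache = create_cache(host="localhost", port=6379, ttl_seconds=60)
    if cache is None:
        raise SystemExit("Redis unavailable; nothing to demonstrate")

    stats = CacheStatistics()
    features = np.zeros((64, 128), dtype=np.float32)  # placeholder mel-spectrogram
    prediction = {"label": "example", "score": np.float32(0.97)}  # fabricated result

    # First lookup misses, so we cache the (fabricated) prediction;
    # _prepare_for_cache converts the numpy scalar for JSON serialization.
    if cache.get(features) is None:
        stats.record_miss()
        cache.set(features, prediction)

    # Second lookup with identical features hashes to the same key and hits.
    if cache.get(features) is not None:
        stats.record_hit()

    print(stats)  # CacheStats(hits=1, misses=1, total=2, hit_rate=50.0%)
    print(cache.get_stats())
    cache.close()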