Coverage for services/inference/src/utils/metrics.py: 45%

83 statements  

coverage.py v7.11.0, created at 2025-10-25 16:18 +0000

1"""Prometheus metrics for Phase 6 Inference Service.""" 

2from prometheus_client import Counter, Histogram, Gauge 

3import time 

4import logging 

5from contextlib import contextmanager 

6 

7logger = logging.getLogger(__name__) 

8 

9 

# ============================================================================
# INFERENCE METRICS
# ============================================================================

# Histogram: inference latency in milliseconds
inference_latency = Histogram(
    "inference_latency_ms",
    "Inference latency in milliseconds (end-to-end)",
    buckets=[10, 25, 50, 75, 100, 150, 200, 300, 400, 500, 750, 1000],
)

# Histogram: preprocessing latency
preprocessing_latency = Histogram(
    "preprocessing_latency_ms",
    "IQ preprocessing latency in milliseconds",
    buckets=[1, 5, 10, 25, 50, 100],
)

# Histogram: ONNX runtime latency (pure inference)
onnx_latency = Histogram(
    "onnx_latency_ms",
    "Pure ONNX runtime latency in milliseconds",
    buckets=[5, 10, 25, 50, 75, 100, 150, 200],
)

# ============================================================================
# CACHE METRICS
# ============================================================================

cache_hits = Counter(
    "cache_hits_total",
    "Total cache hits",
)

cache_misses = Counter(
    "cache_misses_total",
    "Total cache misses",
)

# Gauge: current cache hit rate (0-1)
cache_hit_rate = Gauge(
    "cache_hit_rate",
    "Current cache hit rate (0-1)",
)

# Gauge: Redis memory usage (bytes)
redis_memory_bytes = Gauge(
    "redis_memory_bytes",
    "Redis memory usage in bytes",
)

# ============================================================================
# REQUEST METRICS
# ============================================================================

requests_total = Counter(
    "inference_requests_total",
    "Total inference requests by endpoint",
    ["endpoint"],
)

errors_total = Counter(
    "inference_errors_total",
    "Total inference errors by type",
    ["error_type"],
)

# Gauge: active concurrent requests
active_requests = Gauge(
    "inference_active_requests",
    "Number of active/concurrent inference requests",
)

# ============================================================================
# MODEL METRICS
# ============================================================================

model_reloads = Counter(
    "model_reloads_total",
    "Total model reloads from MLflow",
)

model_loads = Gauge(
    "model_loaded",
    "Is model currently loaded (1=yes, 0=no)",
)

model_inference_count = Counter(
    "model_inference_count_total",
    "Total inferences performed with this model",
)

model_accuracy_meters = Gauge(
    "model_accuracy_meters",
    "Model accuracy from training metadata (sigma in meters)",
)

# ============================================================================
# METRIC HELPER FUNCTIONS
# ============================================================================

def record_inference_latency(duration_ms: float):
    """Record end-to-end inference latency."""
    inference_latency.observe(duration_ms)


def record_preprocessing_latency(duration_ms: float):
    """Record IQ preprocessing latency."""
    preprocessing_latency.observe(duration_ms)


def record_onnx_latency(duration_ms: float):
    """Record pure ONNX runtime latency."""
    onnx_latency.observe(duration_ms)


def record_cache_hit():
    """Record cache hit and update hit rate."""
    cache_hits.inc()
    _update_cache_hit_rate()


def record_cache_miss():
    """Record cache miss and update hit rate."""
    cache_misses.inc()
    _update_cache_hit_rate()


def _update_cache_hit_rate():
    """Update cache hit rate gauge."""
    try:
        # Get current hit/miss counts
        total_hits = cache_hits._value.get() if hasattr(cache_hits, '_value') else 0
        total_misses = cache_misses._value.get() if hasattr(cache_misses, '_value') else 0
        total = total_hits + total_misses

        if total > 0:
            rate = total_hits / total
            cache_hit_rate.set(rate)
            logger.debug(f"Cache hit rate: {rate:.2%} ({total_hits}/{total})")
    except Exception as e:
        logger.warning(f"Could not update cache hit rate: {e}")
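
# Example (not part of the original module): a hypothetical Redis-backed cache
# lookup would call record_cache_hit()/record_cache_miss(), which in turn keep
# the cache_hit_rate gauge in sync. Names like `redis_client`, `cache_key`, and
# `run_inference` below are placeholders, not real symbols from this service.
#
#     cached = redis_client.get(cache_key)
#     if cached is not None:
#         record_cache_hit()
#         return json.loads(cached)
#     record_cache_miss()
#     result = run_inference(payload)
#     redis_client.set(cache_key, json.dumps(result))
#     return result
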

def record_request_error(error_type: str):
    """Record inference error."""
    errors_total.labels(error_type=error_type).inc()


def set_redis_memory(bytes_value: float):
    """Set Redis memory usage gauge."""
    redis_memory_bytes.set(bytes_value)


def record_model_reload():
    """Record model reload event."""
    model_reloads.inc()


def set_model_loaded(loaded: bool):
    """Set model loaded status (1=loaded, 0=not loaded)."""
    model_loads.set(1 if loaded else 0)


def record_model_inference():
    """Record successful model inference."""
    model_inference_count.inc()


def set_model_accuracy(accuracy_m: float):
    """Set model accuracy from metadata."""
    model_accuracy_meters.set(accuracy_m)
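
# Example (not part of the original module): a model loader would typically
# drive the model metrics together on a (re)load. `load_onnx_model`, `model_uri`,
# and `metadata` are hypothetical names used only for illustration.
#
#     session = load_onnx_model(model_uri)
#     record_model_reload()
#     set_model_loaded(True)
#     set_model_accuracy(metadata["sigma_m"])
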

# ============================================================================
# CONTEXT MANAGERS
# ============================================================================

@contextmanager
def InferenceMetricsContext(endpoint: str):
    """
    Context manager for recording inference metrics.

    Usage:
        with InferenceMetricsContext("predict"):
            # Run inference
            result = model.predict(data)
            # Metrics automatically recorded

    Args:
        endpoint: API endpoint name for labeling
    """
    start_time = time.time()
    active_requests.inc()
    requests_total.labels(endpoint=endpoint).inc()

    try:
        yield
    except Exception as e:
        error_type = type(e).__name__
        record_request_error(error_type)
        raise
    finally:
        duration_ms = (time.time() - start_time) * 1000
        record_inference_latency(duration_ms)
        active_requests.dec()
        logger.debug(f"Endpoint {endpoint} completed in {duration_ms:.2f}ms")


@contextmanager
def PreprocessingMetricsContext():
    """Context manager for preprocessing latency."""
    start_time = time.time()
    try:
        yield
    finally:
        duration_ms = (time.time() - start_time) * 1000
        record_preprocessing_latency(duration_ms)


@contextmanager
def ONNXMetricsContext():
    """Context manager for ONNX runtime latency."""
    start_time = time.time()
    try:
        yield
    finally:
        duration_ms = (time.time() - start_time) * 1000
        record_onnx_latency(duration_ms)
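

# ============================================================================
# USAGE SKETCH (not part of the original module)
# ============================================================================

# A minimal, self-contained demo of how a caller might compose the helpers and
# context managers above around one request. The simulate_* functions are
# hypothetical stand-ins for real preprocessing and ONNX calls; the __main__
# guard keeps this from running on import.
if __name__ == "__main__":
    import random

    from prometheus_client import generate_latest

    def simulate_preprocessing() -> None:
        """Stand-in for IQ preprocessing work."""
        time.sleep(random.uniform(0.001, 0.005))

    def simulate_onnx() -> None:
        """Stand-in for an ONNX runtime forward pass."""
        time.sleep(random.uniform(0.010, 0.050))

    set_model_loaded(True)
    record_cache_miss()  # pretend the prediction was not cached

    with InferenceMetricsContext("predict"):
        with PreprocessingMetricsContext():
            simulate_preprocessing()
        with ONNXMetricsContext():
            simulate_onnx()
        record_model_inference()

    # Dump the default registry so the recorded samples can be inspected.
    print(generate_latest().decode("utf-8"))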