Coverage for services/training/src/mlflow_setup.py: 0%
125 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-25 16:18 +0000
1"""
2MLflow tracking integration for Heimdall training pipeline.
4Provides:
5- MLflow client initialization and configuration
6- Experiment and run management
7- Parameter and metric logging
8- Artifact management and registration
9- Model registry operations
10"""
12import os
13import json
14from typing import Dict, Any, Optional
15from pathlib import Path
16import structlog
17import mlflow
18from mlflow.tracking import MlflowClient
19from mlflow.entities import RunStatus
21logger = structlog.get_logger(__name__)
class MLflowTracker:
    """
    Centralized MLflow tracking manager for the training pipeline.

    Responsibilities:
    - Point the MLflow client at the tracking server
    - Create/look up the experiment used by the pipeline
    - Log training runs with parameters and metrics
    - Upload artifacts (MinIO/S3 is configured through environment variables)
    - Register models in the MLflow Model Registry and move them between stages

    Note:
        Constructing a tracker mutates process-wide state: it sets
        ``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY``,
        ``MLFLOW_S3_ENDPOINT_URL`` and ``MLFLOW_S3_IGNORE_TLS`` in
        ``os.environ`` and calls ``mlflow.set_tracking_uri``.
    """

    def __init__(
        self,
        tracking_uri: str,
        artifact_uri: str,
        backend_store_uri: str,
        registry_uri: str,
        s3_endpoint_url: str,
        s3_access_key_id: str,
        s3_secret_access_key: str,
        experiment_name: str = "heimdall-localization",
    ):
        """
        Initialize MLflow tracker.

        Args:
            tracking_uri (str): URI handed to ``mlflow.set_tracking_uri`` and
                ``MlflowClient``.
                Format: postgresql://user:pass@host:port/dbname
            artifact_uri (str): S3 bucket URI used as the artifact location
                when the experiment has to be created.
                Format: s3://bucket-name
            backend_store_uri (str): PostgreSQL URI for backend store.
                Stored on the instance only; this class does not apply it.
            registry_uri (str): URI for model registry.
                Stored on the instance only; this class does not apply it.
            s3_endpoint_url (str): MinIO endpoint URL (e.g., http://minio:9000)
            s3_access_key_id (str): MinIO access key
            s3_secret_access_key (str): MinIO secret key
            experiment_name (str): MLflow experiment name

        Raises:
            Exception: Whatever ``mlflow.create_experiment`` raises when the
                experiment neither exists nor can be created.
        """
        # Store configuration
        self.tracking_uri = tracking_uri
        self.artifact_uri = artifact_uri
        self.backend_store_uri = backend_store_uri
        self.registry_uri = registry_uri
        self.experiment_name = experiment_name

        # Set environment variables for S3/MinIO (process-wide side effect)
        os.environ['AWS_ACCESS_KEY_ID'] = s3_access_key_id
        os.environ['AWS_SECRET_ACCESS_KEY'] = s3_secret_access_key

        # Configure MLflow
        self._configure_mlflow(
            tracking_uri,
            artifact_uri,
            backend_store_uri,
            registry_uri,
            s3_endpoint_url,
        )

        # Create client
        self.client = MlflowClient(tracking_uri=tracking_uri)

        # Initialize experiment
        self.experiment_id = self._get_or_create_experiment(experiment_name)

        logger.info(
            "mlflow_tracker_initialized",
            # The URI may embed a database password; never log it verbatim.
            tracking_uri=self._redact_uri(tracking_uri),
            artifact_uri=artifact_uri,
            experiment_name=experiment_name,
            experiment_id=self.experiment_id,
        )

    @staticmethod
    def _redact_uri(uri: str) -> str:
        """Return ``uri`` with any embedded password replaced by ``***``."""
        from urllib.parse import urlsplit, urlunsplit

        if not isinstance(uri, str):
            return uri
        try:
            parts = urlsplit(uri)
        except ValueError:
            # Not parseable as a URL; nothing we can safely rewrite.
            return uri
        if not parts.password:
            return uri

        # rpartition: a password may itself contain '@'.
        userinfo, _, hostport = parts.netloc.rpartition("@")
        username = userinfo.partition(":")[0]
        return urlunsplit(parts._replace(netloc=f"{username}:***@{hostport}"))

    @staticmethod
    def _metric_order_clause(metric: str, compare_fn) -> str:
        """Build the ``order_by`` clause that ranks runs by ``metric``."""
        direction = "ASC" if compare_fn is min else "DESC"
        # Backticks let keys with special characters (e.g. "val/loss") through
        # MLflow's order_by parser.
        return f"metrics.`{metric}` {direction}"

    def _configure_mlflow(
        self,
        tracking_uri: str,
        artifact_uri: str,
        backend_store_uri: str,
        registry_uri: str,
        s3_endpoint_url: str,
    ):
        """
        Configure MLflow connection parameters.

        Only ``tracking_uri`` and ``s3_endpoint_url`` are applied here;
        ``artifact_uri`` is logged, and ``backend_store_uri`` /
        ``registry_uri`` are accepted but unused.

        Args:
            tracking_uri (str): URI passed to ``mlflow.set_tracking_uri``
            artifact_uri (str): S3 artifact root (logged only)
            backend_store_uri (str): Backend store URI (unused)
            registry_uri (str): Model registry URI (unused)
            s3_endpoint_url (str): MinIO endpoint
        """
        # Set tracking URI
        mlflow.set_tracking_uri(tracking_uri)

        # Configure S3/MinIO environment
        os.environ['MLFLOW_S3_ENDPOINT_URL'] = s3_endpoint_url
        # NOTE(review): hard-coded; presumably relaxes TLS checks for a local
        # MinIO endpoint - confirm this is acceptable for every deployment.
        os.environ['MLFLOW_S3_IGNORE_TLS'] = 'true'

        logger.debug(
            "mlflow_configured",
            tracking_uri=self._redact_uri(tracking_uri),
            artifact_uri=artifact_uri,
            s3_endpoint_url=s3_endpoint_url,
        )

    def _get_or_create_experiment(self, experiment_name: str) -> str:
        """
        Get existing experiment by name or create new one.

        A failed lookup is logged as a warning and then treated like
        "not found", so creation is still attempted.

        Args:
            experiment_name (str): Name of experiment

        Returns:
            str: Experiment ID

        Raises:
            Exception: Re-raised from ``mlflow.create_experiment`` after
                logging when creation fails.
        """
        try:
            experiment = self.client.get_experiment_by_name(experiment_name)
            if experiment:
                logger.info(
                    "experiment_found",
                    experiment_name=experiment_name,
                    experiment_id=experiment.experiment_id,
                )
                return experiment.experiment_id
        except Exception as e:
            logger.warning(
                "experiment_lookup_failed",
                experiment_name=experiment_name,
                error=str(e),
            )

        # Create new experiment
        try:
            experiment_id = mlflow.create_experiment(
                name=experiment_name,
                artifact_location=self.artifact_uri,
            )
            logger.info(
                "experiment_created",
                experiment_name=experiment_name,
                experiment_id=experiment_id,
            )
            return experiment_id
        except Exception as e:
            logger.error(
                "experiment_creation_failed",
                experiment_name=experiment_name,
                error=str(e),
            )
            raise

    def start_run(
        self,
        run_name: str,
        tags: Optional[Dict[str, str]] = None,
    ) -> str:
        """
        Start a new MLflow run in this tracker's experiment.

        Args:
            run_name (str): Name for the run
            tags (dict, optional): Tags merged over the defaults
                (``phase``, ``service``, ``model``); caller values win.

        Returns:
            str: Run ID
        """
        # Start run
        run = mlflow.start_run(
            experiment_id=self.experiment_id,
            run_name=run_name,
        )

        run_id = run.info.run_id

        # Set default tags
        default_tags = {
            'phase': 'training',
            'service': 'training',
            'model': 'LocalizationNet',
        }

        if tags:
            default_tags.update(tags)

        mlflow.set_tags(default_tags)

        logger.info(
            "mlflow_run_started",
            run_id=run_id,
            run_name=run_name,
            experiment_id=self.experiment_id,
            tags=default_tags,
        )

        return run_id

    def end_run(self, status: str = "FINISHED"):
        """
        End the current MLflow run.

        Args:
            status (str): Final run status (FINISHED, FAILED, KILLED)
        """
        mlflow.end_run(status=status)

        logger.info(
            "mlflow_run_ended",
            status=status,
        )

    def log_params(self, params: Dict[str, Any]):
        """
        Log training parameters, one at a time.

        Lists and dicts are JSON-encoded first. A parameter that fails to log
        produces a warning and does not stop the remaining ones.

        Args:
            params (dict): Dictionary of parameters

        Example:
            tracker.log_params({
                'learning_rate': 1e-3,
                'batch_size': 32,
                'epochs': 100,
                'backbone': 'ConvNeXt-Large',
            })
        """
        for key, value in params.items():
            try:
                # Convert non-string types
                if isinstance(value, (list, dict)):
                    value = json.dumps(value)

                mlflow.log_param(key, value)
            except Exception as e:
                logger.warning(
                    "parameter_logging_failed",
                    param_name=key,
                    # Truncated so a huge value cannot flood the log.
                    param_value=str(value)[:100],
                    error=str(e),
                )

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
        """
        Log training metrics, one at a time.

        A metric that fails to log produces a warning and does not stop the
        remaining ones.

        Args:
            metrics (dict): Dictionary of metric names and values
            step (int, optional): Step number (epoch)

        Example:
            tracker.log_metrics({
                'train_loss': 0.523,
                'val_loss': 0.487,
                'train_mae': 12.3,
            }, step=epoch)
        """
        for metric_name, metric_value in metrics.items():
            try:
                mlflow.log_metric(metric_name, metric_value, step=step)
            except Exception as e:
                logger.warning(
                    "metric_logging_failed",
                    metric_name=metric_name,
                    metric_value=metric_value,
                    error=str(e),
                )

    def log_artifact(self, local_path: str, artifact_path: str = "artifacts"):
        """
        Log a local artifact file to MLflow.

        Best-effort: a missing path logs a warning and an upload failure logs
        an error; neither raises.

        Args:
            local_path (str): Local file path
            artifact_path (str): Destination path in artifact store

        Example:
            tracker.log_artifact('model.onnx', 'models')
        """
        local_path = Path(local_path)

        if not local_path.exists():
            logger.warning(
                "artifact_not_found",
                local_path=str(local_path),
            )
            return

        try:
            mlflow.log_artifact(str(local_path), artifact_path=artifact_path)
            logger.info(
                "artifact_logged",
                local_path=str(local_path),
                artifact_path=artifact_path,
            )
        except Exception as e:
            logger.error(
                "artifact_logging_failed",
                local_path=str(local_path),
                artifact_path=artifact_path,
                error=str(e),
            )

    def log_artifacts_dir(self, local_dir: str, artifact_path: str = "artifacts"):
        """
        Log an entire directory of artifacts.

        Best-effort: a missing directory logs a warning and an upload failure
        logs an error; neither raises.

        Args:
            local_dir (str): Local directory path
            artifact_path (str): Destination path in artifact store
        """
        local_dir = Path(local_dir)

        if not local_dir.is_dir():
            logger.warning(
                "artifact_dir_not_found",
                local_dir=str(local_dir),
            )
            return

        try:
            mlflow.log_artifacts(str(local_dir), artifact_path=artifact_path)
            logger.info(
                "artifact_dir_logged",
                local_dir=str(local_dir),
                artifact_path=artifact_path,
                # Count regular files only; rglob("*") also yields directories.
                num_files=sum(1 for entry in local_dir.rglob("*") if entry.is_file()),
            )
        except Exception as e:
            logger.error(
                "artifact_dir_logging_failed",
                local_dir=str(local_dir),
                artifact_path=artifact_path,
                error=str(e),
            )

    def register_model(
        self,
        model_name: str,
        model_uri: str,
        description: str = "",
        tags: Optional[Dict[str, str]] = None,
    ) -> str:
        """
        Register model to MLflow Model Registry.

        Args:
            model_name (str): Name for the model in registry
            model_uri (str): URI of model artifacts (runs:/<run_id>/path/to/model)
            description (str): Model description. When non-empty it is written
                to the new model version; a failure to write it is logged as a
                warning and does not undo the registration.
            tags (dict, optional): Model tags

        Returns:
            str: Model version

        Raises:
            Exception: Re-raised from ``mlflow.register_model`` after logging.

        Example:
            version = tracker.register_model(
                model_name="heimdall-localization-v1",
                model_uri="runs:/abc123def/models/model",
                description="ConvNeXt-Large with uncertainty",
                tags={'stage': 'production'},
            )
        """
        try:
            model_version = mlflow.register_model(
                model_uri=model_uri,
                name=model_name,
                tags=tags or {},
                await_registration_for=300,
            )
        except Exception as e:
            logger.error(
                "model_registration_failed",
                model_name=model_name,
                model_uri=model_uri,
                error=str(e),
            )
            raise

        if description:
            # register_model() takes no description, so attach it afterwards.
            # Kept best-effort: the version already exists at this point.
            try:
                self.client.update_model_version(
                    name=model_name,
                    version=model_version.version,
                    description=description,
                )
            except Exception as e:
                logger.warning(
                    "model_description_update_failed",
                    model_name=model_name,
                    model_version=model_version.version,
                    error=str(e),
                )

        logger.info(
            "model_registered",
            model_name=model_name,
            model_version=model_version.version,
            model_uri=model_uri,
            description=description,
        )

        return model_version.version

    def transition_model_stage(
        self,
        model_name: str,
        version: str,
        stage: str,
    ):
        """
        Transition a registered model to a new stage.

        Best-effort: a failed transition is logged as an error and is not
        re-raised, so callers cannot detect it from the return value.

        Args:
            model_name (str): Name of registered model
            version (str): Version number
            stage (str): Target stage (None, Staging, Production, Archived)

        Example:
            tracker.transition_model_stage(
                model_name="heimdall-localization-v1",
                version="1",
                stage="Production",
            )
        """
        try:
            self.client.transition_model_version_stage(
                name=model_name,
                version=version,
                stage=stage,
            )

            logger.info(
                "model_stage_transitioned",
                model_name=model_name,
                version=version,
                stage=stage,
            )
        except Exception as e:
            logger.error(
                "model_transition_failed",
                model_name=model_name,
                version=version,
                stage=stage,
                error=str(e),
            )

    def get_run_info(self, run_id: str) -> Dict[str, Any]:
        """
        Get information about a specific run.

        Args:
            run_id (str): Run ID

        Returns:
            dict: ``run_id``, ``experiment_id``, ``status``, ``start_time``,
            ``end_time``, ``parameters``, ``metrics`` and ``tags`` of the run.

        Raises:
            Exception: Whatever ``MlflowClient.get_run`` raises; nothing is
                caught here.
        """
        run = self.client.get_run(run_id)

        return {
            'run_id': run.info.run_id,
            'experiment_id': run.info.experiment_id,
            'status': run.info.status,
            'start_time': run.info.start_time,
            'end_time': run.info.end_time,
            'parameters': run.data.params,
            'metrics': run.data.metrics,
            'tags': run.data.tags,
        }

    def get_best_run(
        self,
        metric: str = "val/loss",
        compare_fn=min,
    ) -> Optional[Dict[str, Any]]:
        """
        Get the best run from current experiment based on a metric.

        Args:
            metric (str): Metric name to compare
            compare_fn: ``min`` ranks ascending (lowest value is best);
                anything else ranks descending.

        Returns:
            dict: ``run_id``, ``metric_value`` and ``params`` (the ``params.*``
            entries selected from the result row) of the best run, or None
            when the experiment has no runs or the lookup fails.

        Example:
            best = tracker.get_best_run(metric="val/loss", compare_fn=min)
        """
        try:
            runs = mlflow.search_runs(
                experiment_ids=[self.experiment_id],
                order_by=[self._metric_order_clause(metric, compare_fn)],
                max_results=1,
            )

            if runs.empty:
                logger.info("no_runs_found", experiment_id=self.experiment_id)
                return None

            best_run = runs.iloc[0]
            # Result columns use the plain metric name, without backticks.
            metric_column = f'metrics.{metric}'

            logger.info(
                "best_run_found",
                run_id=best_run['run_id'],
                metric=metric,
                value=best_run[metric_column],
            )

            return {
                'run_id': best_run['run_id'],
                'metric_value': best_run[metric_column],
                'params': best_run[[col for col in best_run.index if col.startswith('params.')]],
            }
        except Exception as e:
            logger.error(
                "best_run_lookup_failed",
                experiment_id=self.experiment_id,
                metric=metric,
                error=str(e),
            )
            return None
def initialize_mlflow(settings) -> MLflowTracker:
    """
    Build an MLflowTracker from application settings.

    Args:
        settings: Settings object exposing the ``mlflow_*`` attributes read
            below (tracking/artifact/backend/registry URIs, S3 endpoint and
            credentials, experiment name).

    Returns:
        MLflowTracker: Initialized tracker instance
    """
    # Map each constructor argument to the settings attribute that feeds it.
    tracker_kwargs = {
        "tracking_uri": settings.mlflow_tracking_uri,
        "artifact_uri": settings.mlflow_artifact_uri,
        "backend_store_uri": settings.mlflow_backend_store_uri,
        "registry_uri": settings.mlflow_registry_uri,
        "s3_endpoint_url": settings.mlflow_s3_endpoint_url,
        "s3_access_key_id": settings.mlflow_s3_access_key_id,
        "s3_secret_access_key": settings.mlflow_s3_secret_access_key,
        "experiment_name": settings.mlflow_experiment_name,
    }
    return MLflowTracker(**tracker_kwargs)
if __name__ == "__main__":
    # Manual smoke test: build a tracker from the project settings and
    # report success. The import is local so it only runs as a script.
    from src.config import settings

    tracker = initialize_mlflow(settings)
    logger.info("✅ MLflow setup complete!")