Coverage for services/training/src/mlflow_setup.py: 0%

125 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-25 16:18 +0000

1""" 

MLflow tracking integration for Heimdall training pipeline.

Provides:
- MLflow client initialization and configuration
- Experiment and run management
- Parameter and metric logging
- Artifact management and registration
- Model registry operations

10""" 

11 

12import os 

13import json 

14from typing import Dict, Any, Optional 

15from pathlib import Path 

16import structlog 

17import mlflow 

18from mlflow.tracking import MlflowClient 

19from mlflow.entities import RunStatus 

20 

# Module-scoped structured logger used by every method and function below.
logger = structlog.get_logger(__name__)

22 

23 

class MLflowTracker:
    """
    Centralized MLflow tracking manager for the training pipeline.

    Responsibilities:
    - Point MLflow (fluent API and an ``MlflowClient``) at the tracking server
      and, when configured, a separate model registry
    - Create or reuse the named experiment (restoring it if soft-deleted)
    - Log training runs with parameters and metrics
    - Upload artifacts to the S3-compatible (MinIO) artifact store
    - Register models to the MLflow Model Registry

    Note: construction mutates process-wide state (``os.environ`` and the
    global MLflow tracking/registry URIs), so every tracker in one process
    shares that configuration.
    """

    def __init__(
        self,
        tracking_uri: str,
        artifact_uri: str,
        backend_store_uri: str,
        registry_uri: str,
        s3_endpoint_url: str,
        s3_access_key_id: str,
        s3_secret_access_key: str,
        experiment_name: str = "heimdall-localization",
    ):
        """
        Initialize MLflow tracker.

        Args:
            tracking_uri (str): PostgreSQL URI for MLflow tracking server
                Format: postgresql://user:pass@host:port/dbname
            artifact_uri (str): S3 bucket URI for artifacts, used as the
                artifact location when the experiment is created.
                Format: s3://bucket-name
            backend_store_uri (str): PostgreSQL URI for backend store. Stored
                on the instance only; this class does not otherwise use it.
            registry_uri (str): URI for the model registry. When empty, MLflow
                falls back to its default registry (the tracking URI).
            s3_endpoint_url (str): MinIO endpoint URL (e.g., http://minio:9000)
            s3_access_key_id (str): MinIO access key
            s3_secret_access_key (str): MinIO secret key
            experiment_name (str): MLflow experiment name

        Raises:
            Exception: Propagated from experiment creation when the experiment
                can neither be found nor created.
        """
        # Store configuration
        self.tracking_uri = tracking_uri
        self.artifact_uri = artifact_uri
        self.backend_store_uri = backend_store_uri
        self.registry_uri = registry_uri
        self.experiment_name = experiment_name

        # Credentials for the S3/MinIO artifact store are exported through the
        # standard AWS environment variables.
        os.environ['AWS_ACCESS_KEY_ID'] = s3_access_key_id
        os.environ['AWS_SECRET_ACCESS_KEY'] = s3_secret_access_key

        self._configure_mlflow(
            tracking_uri,
            artifact_uri,
            backend_store_uri,
            registry_uri,
            s3_endpoint_url,
        )

        # Hand the registry URI to the client explicitly so registry calls
        # (stage transitions, version updates) hit the configured registry.
        self.client = MlflowClient(
            tracking_uri=tracking_uri,
            registry_uri=registry_uri or None,
        )

        self.experiment_id = self._get_or_create_experiment(experiment_name)

        logger.info(
            "mlflow_tracker_initialized",
            tracking_uri=tracking_uri,
            artifact_uri=artifact_uri,
            experiment_name=experiment_name,
            experiment_id=self.experiment_id,
        )

    def _configure_mlflow(
        self,
        tracking_uri: str,
        artifact_uri: str,
        backend_store_uri: str,
        registry_uri: str,
        s3_endpoint_url: str,
    ):
        """
        Configure process-wide MLflow connection parameters.

        Args:
            tracking_uri (str): Tracking URI set on the global MLflow API.
            artifact_uri (str): S3 artifact root (only logged here).
            backend_store_uri (str): Backend store URI (unused here).
            registry_uri (str): Model registry URI; applied only if non-empty.
            s3_endpoint_url (str): MinIO endpoint exported for MLflow's S3 client.
        """
        mlflow.set_tracking_uri(tracking_uri)
        if registry_uri:
            # Without this, module-level calls such as mlflow.register_model
            # would use the default registry instead of the configured one.
            mlflow.set_registry_uri(registry_uri)

        # Configure S3/MinIO environment
        os.environ['MLFLOW_S3_ENDPOINT_URL'] = s3_endpoint_url
        # NOTE(review): this disables TLS certificate verification for every
        # S3 call in the process, including https endpoints. Confirm this is
        # intended outside local/dev deployments.
        os.environ['MLFLOW_S3_IGNORE_TLS'] = 'true'

        logger.debug(
            "mlflow_configured",
            tracking_uri=tracking_uri,
            artifact_uri=artifact_uri,
            registry_uri=registry_uri,
            s3_endpoint_url=s3_endpoint_url,
        )

    def _get_or_create_experiment(self, experiment_name: str) -> str:
        """
        Get an existing experiment by name or create a new one.

        A soft-deleted experiment with the same name is restored rather than
        returned as-is: its name stays reserved, so creating a new one would
        fail, and runs cannot be started inside a deleted experiment.

        Args:
            experiment_name (str): Name of experiment

        Returns:
            str: Experiment ID

        Raises:
            Exception: If the lookup fails or finds nothing and creation also
                fails.
        """
        try:
            experiment = self.client.get_experiment_by_name(experiment_name)
            if experiment:
                if experiment.lifecycle_stage == "deleted":
                    self.client.restore_experiment(experiment.experiment_id)
                    logger.info(
                        "experiment_restored",
                        experiment_name=experiment_name,
                        experiment_id=experiment.experiment_id,
                    )
                logger.info(
                    "experiment_found",
                    experiment_name=experiment_name,
                    experiment_id=experiment.experiment_id,
                )
                return experiment.experiment_id
        except Exception as e:
            logger.warning(
                "experiment_lookup_failed",
                experiment_name=experiment_name,
                error=str(e),
            )

        # Create through the same client used for the lookup, for consistency.
        try:
            experiment_id = self.client.create_experiment(
                name=experiment_name,
                artifact_location=self.artifact_uri,
            )
            logger.info(
                "experiment_created",
                experiment_name=experiment_name,
                experiment_id=experiment_id,
            )
            return experiment_id
        except Exception as e:
            logger.error(
                "experiment_creation_failed",
                experiment_name=experiment_name,
                error=str(e),
            )
            raise

    def start_run(
        self,
        run_name: str,
        tags: Optional[Dict[str, str]] = None,
    ) -> str:
        """
        Start a new MLflow run in the tracker's experiment.

        Default tags (phase/service/model) are applied first; caller-supplied
        ``tags`` override them key by key.

        Args:
            run_name (str): Name for the run
            tags (dict, optional): Dictionary of tags to set

        Returns:
            str: Run ID
        """
        run_tags = {
            'phase': 'training',
            'service': 'training',
            'model': 'LocalizationNet',
        }
        if tags:
            run_tags.update(tags)

        # Tags are attached at creation so a tagging failure cannot leave an
        # untagged run active in the fluent API.
        run = mlflow.start_run(
            experiment_id=self.experiment_id,
            run_name=run_name,
            tags=run_tags,
        )
        run_id = run.info.run_id

        logger.info(
            "mlflow_run_started",
            run_id=run_id,
            run_name=run_name,
            experiment_id=self.experiment_id,
            tags=run_tags,
        )

        return run_id

    def end_run(self, status: str = "FINISHED"):
        """
        End the current MLflow run.

        Args:
            status (str): Final run status (FINISHED, FAILED, KILLED)
        """
        mlflow.end_run(status=status)

        logger.info(
            "mlflow_run_ended",
            status=status,
        )

    def log_params(self, params: Dict[str, Any]):
        """
        Log training parameters to the active run.

        Lists and dicts are JSON-encoded first; other values are passed
        through. Each parameter is logged independently: a failure is logged
        as a warning and does not stop the remaining parameters.

        Args:
            params (dict): Dictionary of parameters

        Example:
            tracker.log_params({
                'learning_rate': 1e-3,
                'batch_size': 32,
                'epochs': 100,
                'backbone': 'ConvNeXt-Large',
            })
        """
        for key, value in params.items():
            try:
                if isinstance(value, (list, dict)):
                    value = json.dumps(value)

                mlflow.log_param(key, value)
            except Exception as e:
                logger.warning(
                    "parameter_logging_failed",
                    param_name=key,
                    param_value=str(value)[:100],
                    error=str(e),
                )

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
        """
        Log training metrics to the active run.

        Each metric is logged independently; failures are logged as warnings
        rather than raised.

        Args:
            metrics (dict): Dictionary of metric names and values
            step (int, optional): Step number (epoch)

        Example:
            tracker.log_metrics({
                'train_loss': 0.523,
                'val_loss': 0.487,
                'train_mae': 12.3,
            }, step=epoch)
        """
        for metric_name, metric_value in metrics.items():
            try:
                mlflow.log_metric(metric_name, metric_value, step=step)
            except Exception as e:
                logger.warning(
                    "metric_logging_failed",
                    metric_name=metric_name,
                    metric_value=metric_value,
                    error=str(e),
                )

    def log_artifact(self, local_path: str, artifact_path: str = "artifacts"):
        """
        Log a local artifact file to MLflow.

        A missing path is logged as a warning and skipped; upload errors are
        logged as errors and not raised.

        Args:
            local_path (str): Local file path
            artifact_path (str): Destination path in artifact store

        Example:
            tracker.log_artifact('model.onnx', 'models')
        """
        local_path = Path(local_path)

        if not local_path.exists():
            logger.warning(
                "artifact_not_found",
                local_path=str(local_path),
            )
            return

        try:
            mlflow.log_artifact(str(local_path), artifact_path=artifact_path)
            logger.info(
                "artifact_logged",
                local_path=str(local_path),
                artifact_path=artifact_path,
            )
        except Exception as e:
            logger.error(
                "artifact_logging_failed",
                local_path=str(local_path),
                artifact_path=artifact_path,
                error=str(e),
            )

    @staticmethod
    def _count_files(directory: Path) -> int:
        """Return the number of regular files under ``directory``, recursively."""
        return sum(1 for entry in directory.rglob("*") if entry.is_file())

    def log_artifacts_dir(self, local_dir: str, artifact_path: str = "artifacts"):
        """
        Log an entire directory of artifacts.

        A non-directory path is logged as a warning and skipped; upload errors
        are logged as errors and not raised.

        Args:
            local_dir (str): Local directory path
            artifact_path (str): Destination path in artifact store
        """
        local_dir = Path(local_dir)

        if not local_dir.is_dir():
            logger.warning(
                "artifact_dir_not_found",
                local_dir=str(local_dir),
            )
            return

        try:
            mlflow.log_artifacts(str(local_dir), artifact_path=artifact_path)
            logger.info(
                "artifact_dir_logged",
                local_dir=str(local_dir),
                artifact_path=artifact_path,
                # Count files only; subdirectories are not artifacts.
                num_files=self._count_files(local_dir),
            )
        except Exception as e:
            logger.error(
                "artifact_dir_logging_failed",
                local_dir=str(local_dir),
                artifact_path=artifact_path,
                error=str(e),
            )

    def register_model(
        self,
        model_name: str,
        model_uri: str,
        description: str = "",
        tags: Optional[Dict[str, str]] = None,
    ) -> str:
        """
        Register model to MLflow Model Registry.

        Waits up to 300 seconds for registration to complete. A non-empty
        ``description`` is then written to the new model version; a failure
        there is logged as a warning and does not undo the registration.

        Args:
            model_name (str): Name for the model in registry
            model_uri (str): URI of model artifacts (runs:/<run_id>/path/to/model)
            description (str): Model version description
            tags (dict, optional): Model version tags

        Returns:
            str: Model version

        Raises:
            Exception: Re-raised if registration itself fails.

        Example:
            version = tracker.register_model(
                model_name="heimdall-localization-v1",
                model_uri="runs:/abc123def/models/model",
                description="ConvNeXt-Large with uncertainty",
                tags={'stage': 'production'},
            )
        """
        try:
            model_version = mlflow.register_model(
                model_uri=model_uri,
                name=model_name,
                tags=tags or {},
                await_registration_for=300,
            )
        except Exception as e:
            logger.error(
                "model_registration_failed",
                model_name=model_name,
                model_uri=model_uri,
                error=str(e),
            )
            raise

        if description:
            # register_model does not take a description, so set it on the
            # freshly created version afterwards.
            try:
                self.client.update_model_version(
                    name=model_name,
                    version=model_version.version,
                    description=description,
                )
            except Exception as e:
                logger.warning(
                    "model_description_update_failed",
                    model_name=model_name,
                    model_version=model_version.version,
                    error=str(e),
                )

        logger.info(
            "model_registered",
            model_name=model_name,
            model_version=model_version.version,
            model_uri=model_uri,
            description=description,
        )

        return model_version.version

    def transition_model_stage(
        self,
        model_name: str,
        version: str,
        stage: str,
        archive_existing_versions: bool = False,
    ):
        """
        Transition a registered model version to a new stage.

        Best-effort: failures are logged as errors and not raised.

        NOTE(review): recent MLflow releases deprecate model stages in favour
        of aliases; confirm against the MLflow version pinned for this service.

        Args:
            model_name (str): Name of registered model
            version (str): Version number
            stage (str): Target stage (None, Staging, Production, Archived)
            archive_existing_versions (bool): If True, versions currently in
                ``stage`` are moved to Archived. Defaults to False, which keeps
                the previous behaviour.

        Example:
            tracker.transition_model_stage(
                model_name="heimdall-localization-v1",
                version="1",
                stage="Production",
            )
        """
        try:
            self.client.transition_model_version_stage(
                name=model_name,
                version=version,
                stage=stage,
                archive_existing_versions=archive_existing_versions,
            )

            logger.info(
                "model_stage_transitioned",
                model_name=model_name,
                version=version,
                stage=stage,
                archive_existing_versions=archive_existing_versions,
            )
        except Exception as e:
            logger.error(
                "model_transition_failed",
                model_name=model_name,
                version=version,
                stage=stage,
                error=str(e),
            )

    def get_run_info(self, run_id: str) -> Dict[str, Any]:
        """
        Get information about a specific run.

        Args:
            run_id (str): Run ID

        Returns:
            dict: Run ID, experiment ID, status, start/end times, and the
            run's parameters, metrics and tags.
        """
        run = self.client.get_run(run_id)

        return {
            'run_id': run.info.run_id,
            'experiment_id': run.info.experiment_id,
            'status': run.info.status,
            'start_time': run.info.start_time,
            'end_time': run.info.end_time,
            'parameters': run.data.params,
            'metrics': run.data.metrics,
            'tags': run.data.tags,
        }

    @staticmethod
    def _order_by_clause(metric: str, compare_fn=min) -> str:
        """
        Build an MLflow ``order_by`` clause that ranks runs by ``metric``.

        The metric key is backtick-quoted because MLflow's search grammar
        requires quoting for keys containing characters such as ``/`` (e.g.
        the default ``val/loss``). ``compare_fn is min`` sorts ascending;
        any other function sorts descending.
        """
        direction = "ASC" if compare_fn is min else "DESC"
        return f"metrics.`{metric}` {direction}"

    def get_best_run(
        self,
        metric: str = "val/loss",
        compare_fn=min,
    ) -> Optional[Dict[str, Any]]:
        """
        Get the best run from the current experiment based on a metric.

        Args:
            metric (str): Metric name to compare
            compare_fn: ``min`` for lowest-is-best; anything else is treated
                as highest-is-best.

        Returns:
            dict: ``run_id``, ``metric_value`` and ``params`` (the row's
            ``params.*`` columns) of the best run, or None if there are no
            runs, the top run lacks the metric, or the search fails.

        Example:
            best = tracker.get_best_run(metric="val/loss", compare_fn=min)
        """
        metric_column = f"metrics.{metric}"

        try:
            runs = mlflow.search_runs(
                experiment_ids=[self.experiment_id],
                order_by=[self._order_by_clause(metric, compare_fn)],
                max_results=1,
            )

            if runs.empty:
                logger.info("no_runs_found", experiment_id=self.experiment_id)
                return None

            if metric_column not in runs.columns:
                # The top-ranked run has no value for this metric (presumably
                # no run in the experiment has logged it).
                logger.info(
                    "metric_not_found",
                    experiment_id=self.experiment_id,
                    metric=metric,
                )
                return None

            best_run = runs.iloc[0]

            logger.info(
                "best_run_found",
                run_id=best_run['run_id'],
                metric=metric,
                value=best_run[metric_column],
            )

            return {
                'run_id': best_run['run_id'],
                'metric_value': best_run[metric_column],
                'params': best_run[[col for col in best_run.index if col.startswith('params.')]],
            }
        except Exception as e:
            logger.error(
                "best_run_lookup_failed",
                experiment_id=self.experiment_id,
                metric=metric,
                error=str(e),
            )
            return None

540 

541 

def initialize_mlflow(settings) -> MLflowTracker:
    """
    Build an MLflowTracker from an application settings object.

    Every ``mlflow_*`` attribute read below is forwarded unchanged to the
    constructor argument of the same suffix.

    Args:
        settings: Settings object exposing the ``mlflow_*`` attributes read
            below (described elsewhere as a Pydantic Settings instance).

    Returns:
        MLflowTracker: A tracker whose constructor has already configured
        MLflow and resolved the experiment.
    """
    tracker_kwargs = {
        "tracking_uri": settings.mlflow_tracking_uri,
        "artifact_uri": settings.mlflow_artifact_uri,
        "backend_store_uri": settings.mlflow_backend_store_uri,
        "registry_uri": settings.mlflow_registry_uri,
        "s3_endpoint_url": settings.mlflow_s3_endpoint_url,
        "s3_access_key_id": settings.mlflow_s3_access_key_id,
        "s3_secret_access_key": settings.mlflow_s3_secret_access_key,
        "experiment_name": settings.mlflow_experiment_name,
    }
    return MLflowTracker(**tracker_kwargs)

565 

566 

if __name__ == "__main__":
    # Manual smoke check: constructing the tracker configures MLflow and looks
    # up (or creates) the configured experiment, so reaching the log line
    # below means the tracking backend answered.
    from src.config import settings

    tracker = initialize_mlflow(settings)
    logger.info("✅ MLflow setup complete!")