Coverage for services/training/src/config/model_config.py: 55%
67 statements
coverage.py v7.11.0, created at 2025-10-25 16:18 +0000
1"""
2Model configuration and selection utilities.
4Provides convenient factory functions and configs for different backbone architectures.
5Allows easy experimentation with different models without changing core code.
6"""
8from enum import Enum
9from dataclasses import dataclass
10from typing import Dict
11import structlog
13logger = structlog.get_logger(__name__)


class BackboneArchitecture(Enum):
    """Available backbone architectures for LocalizationNet."""

    # ConvNeXt variants (RECOMMENDED - Modern 2022 architecture)
    CONVNEXT_TINY = "convnext_tiny"      # 29M params, lightweight
    CONVNEXT_SMALL = "convnext_small"    # 50M params, balanced
    CONVNEXT_MEDIUM = "convnext_base"    # 89M params, good accuracy
    CONVNEXT_LARGE = "convnext_large"    # 200M params, best accuracy ⭐

    # ResNet variants (Traditional, well-tested)
    RESNET_50 = "resnet50"               # 26M params, conservative upgrade
    RESNET_101 = "resnet101"             # 45M params, heavier

    # Vision Transformers (Experimental, better long-range dependencies)
    VIT_BASE = "vit_b_16"                # 86M params, transformer-based
    VIT_LARGE = "vit_l_16"               # 306M params, very large

    # EfficientNet (Balanced, good for edge deployment)
    EFFICIENTNET_B3 = "efficientnet_b3"  # 12M params, lightweight
    EFFICIENTNET_B4 = "efficientnet_b4"  # 19M params, balanced


@dataclass
class ModelConfig:
    """Configuration for LocalizationNet."""

    backbone: BackboneArchitecture = BackboneArchitecture.CONVNEXT_LARGE
    pretrained: bool = True
    freeze_backbone: bool = False
    uncertainty_min: float = 0.01
    uncertainty_max: float = 1.0

    # Training hyperparameters
    learning_rate: float = 1e-3
    weight_decay: float = 1e-5
    warmup_steps: int = 500
    num_training_steps: int = 10000

    # Data parameters
    n_mels: int = 128
    n_frames: int = 32

    def __post_init__(self):
        """Validate configuration."""
        if self.uncertainty_min < 0 or self.uncertainty_max < self.uncertainty_min:
            raise ValueError("Invalid uncertainty bounds")
        if self.learning_rate <= 0 or self.weight_decay < 0:
            raise ValueError("Invalid learning rate or weight decay")


# Predefined configurations for common use cases
CONFIGS = {
    # Production configs
    'production_high_accuracy': ModelConfig(
        backbone=BackboneArchitecture.CONVNEXT_LARGE,
        pretrained=True,
        freeze_backbone=False,
        learning_rate=1e-3,
        num_training_steps=50000,
    ),
    'production_balanced': ModelConfig(
        backbone=BackboneArchitecture.CONVNEXT_MEDIUM,
        pretrained=True,
        freeze_backbone=False,
        learning_rate=1e-3,
        num_training_steps=30000,
    ),
    'production_lightweight': ModelConfig(
        backbone=BackboneArchitecture.EFFICIENTNET_B4,
        pretrained=True,
        freeze_backbone=False,
        learning_rate=1e-3,
        num_training_steps=20000,
    ),

    # Development/testing configs
    'dev_fast': ModelConfig(
        backbone=BackboneArchitecture.CONVNEXT_SMALL,
        pretrained=True,
        freeze_backbone=True,  # Frozen backbone for faster training
        learning_rate=5e-4,
        num_training_steps=5000,
    ),
    'dev_test': ModelConfig(
        backbone=BackboneArchitecture.EFFICIENTNET_B3,
        pretrained=False,  # No pretraining for quick tests
        freeze_backbone=False,
        learning_rate=1e-3,
        num_training_steps=1000,
    ),

    # Experimental configs
    'experimental_vit': ModelConfig(
        backbone=BackboneArchitecture.VIT_BASE,
        pretrained=True,
        freeze_backbone=False,
        learning_rate=1e-4,  # ViT typically needs lower LR
        num_training_steps=50000,
    ),
    'experimental_resnet': ModelConfig(
        backbone=BackboneArchitecture.RESNET_101,
        pretrained=True,
        freeze_backbone=False,
        learning_rate=1e-3,
        num_training_steps=30000,
    ),
}
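

# Illustration only (not part of the original module): ModelConfig can also be
# instantiated directly for one-off experiments instead of registering a new
# entry in CONFIGS. Every field below exists on the dataclass above; the values
# themselves are arbitrary example choices.
EXAMPLE_CUSTOM_CONFIG = ModelConfig(
    backbone=BackboneArchitecture.CONVNEXT_TINY,
    pretrained=True,
    freeze_backbone=True,  # frozen backbone for a cheap smoke-test run
    learning_rate=5e-4,
    num_training_steps=2000,
)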


def get_model_config(config_name: str) -> ModelConfig:
    """
    Get a predefined model configuration by name.

    Args:
        config_name (str): Name of the configuration

    Returns:
        ModelConfig: Configuration object

    Raises:
        KeyError: If config_name not found
    """
    if config_name not in CONFIGS:
        available = ', '.join(CONFIGS.keys())
        raise KeyError(f"Unknown config '{config_name}'. Available: {available}")

    config = CONFIGS[config_name]
    logger.info(
        "model_config_loaded",
        config_name=config_name,
        backbone=config.backbone.value,
        learning_rate=config.learning_rate,
    )

    return config


def get_backbone_info(backbone: BackboneArchitecture) -> Dict:
    """
    Get information about a backbone architecture.

    Returns:
        Dict with: params_millions, imagenet_top1, inference_ms, vram_gb, description
    """
    info_map = {
        # ConvNeXt (Modern 2022 architecture)
        BackboneArchitecture.CONVNEXT_TINY: {
            'params_millions': 29,
            'imagenet_top1': '81.9%',
            'inference_ms': 15,
            'vram_gb': 4,
            'description': 'Lightweight, fast',
        },
        BackboneArchitecture.CONVNEXT_SMALL: {
            'params_millions': 50,
            'imagenet_top1': '83.6%',
            'inference_ms': 20,
            'vram_gb': 6,
            'description': 'Balanced speed/accuracy',
        },
        BackboneArchitecture.CONVNEXT_MEDIUM: {
            'params_millions': 89,
            'imagenet_top1': '86.2%',
            'inference_ms': 30,
            'vram_gb': 8,
            'description': 'Very good accuracy',
        },
        BackboneArchitecture.CONVNEXT_LARGE: {
            'params_millions': 200,
            'imagenet_top1': '88.6%',
            'inference_ms': 45,
            'vram_gb': 12,
            'description': 'Best accuracy ⭐ RECOMMENDED',
        },

        # ResNet (Traditional, well-tested)
        BackboneArchitecture.RESNET_50: {
            'params_millions': 26,
            'imagenet_top1': '76.1%',
            'inference_ms': 25,
            'vram_gb': 8,
            'description': 'Well-tested, conservative',
        },
        BackboneArchitecture.RESNET_101: {
            'params_millions': 45,
            'imagenet_top1': '77.4%',
            'inference_ms': 35,
            'vram_gb': 10,
            'description': 'Larger ResNet',
        },

        # Vision Transformers (Experimental)
        BackboneArchitecture.VIT_BASE: {
            'params_millions': 86,
            'imagenet_top1': '84.1%',
            'inference_ms': 55,
            'vram_gb': 10,
            'description': 'Transformer-based, good long-range',
        },
        BackboneArchitecture.VIT_LARGE: {
            'params_millions': 306,
            'imagenet_top1': '85.9%',
            'inference_ms': 90,
            'vram_gb': 16,
            'description': 'Very large, max accuracy',
        },

        # EfficientNet (Balanced)
        BackboneArchitecture.EFFICIENTNET_B3: {
            'params_millions': 12,
            'imagenet_top1': '81.6%',
            'inference_ms': 15,
            'vram_gb': 4,
            'description': 'Lightweight, good balance',
        },
        BackboneArchitecture.EFFICIENTNET_B4: {
            'params_millions': 19,
            'imagenet_top1': '83.4%',
            'inference_ms': 20,
            'vram_gb': 6,
            'description': 'Balanced efficiency',
        },
    }

    if backbone not in info_map:
        raise ValueError(f"Unknown backbone: {backbone}")

    return info_map[backbone]
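

# Illustration only (helper not in the original module): one way to use
# get_backbone_info() is to pick the largest backbone that fits a given VRAM
# budget. The function name and the 8 GB default are assumptions for this sketch.
def example_backbone_for_vram(budget_gb: int = 8) -> BackboneArchitecture:
    """Return the highest-parameter backbone whose reported VRAM fits the budget."""
    candidates = [
        b for b in BackboneArchitecture
        if get_backbone_info(b)['vram_gb'] <= budget_gb
    ]
    if not candidates:
        raise ValueError(f"No backbone fits within {budget_gb}GB of VRAM")
    return max(candidates, key=lambda b: get_backbone_info(b)['params_millions'])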


def print_backbone_comparison():
    """Print a comparison table of all available backbones."""

    print("\n" + "="*100)
    print("BACKBONE ARCHITECTURE COMPARISON")
    print("="*100)
    print(f"\n{'Architecture':<20} {'Params (M)':<12} {'ImageNet':<12} {'Inference':<12} {'VRAM':<8} {'Description':<30}")
    print("-"*100)

    for backbone in BackboneArchitecture:
        info = get_backbone_info(backbone)
        print(
            f"{backbone.value:<20} {info['params_millions']:<12} {info['imagenet_top1']:<12} "
            f"{info['inference_ms']}ms{'':<8} {info['vram_gb']}GB{'':<4} {info['description']:<30}"
        )

    print("="*100)
    print("\n⭐ RECOMMENDED for your hardware: ConvNeXt-Large")
    print(" - 200M params, 88.6% ImageNet, 45ms inference")
    print(" - Best accuracy with acceptable speed")
    print(" - Fits comfortably on an RTX 3090 (12GB VRAM needed)\n")


if __name__ == "__main__":
    import logging
    logging.basicConfig(level=logging.INFO)

    print_backbone_comparison()

    # Test loading a config
    config = get_model_config('production_high_accuracy')
    logger.info("Config loaded successfully", config=config)
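
    # Illustration only (not part of the original module): a minimal sketch of how
    # a training script might consume these hyperparameters, assuming a PyTorch
    # stack is available. The Linear layer is a stand-in for the real
    # LocalizationNet so the snippet stays self-contained, and the linear warmup
    # lambda is an example interpretation of warmup_steps, not the project's
    # actual schedule.
    import torch

    model = torch.nn.Linear(config.n_mels, 3)  # placeholder for LocalizationNet
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda step: min(1.0, (step + 1) / max(1, config.warmup_steps)),
    )
    logger.info(
        "example_optimizer_built",
        current_lr=scheduler.get_last_lr()[0],
        warmup_steps=config.warmup_steps,
        total_steps=config.num_training_steps,
    )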