Coverage for services/training/src/config/model_config.py: 55%

67 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-25 16:18 +0000

"""
Model configuration and selection utilities.

Provides convenient factory functions and configs for different backbone architectures.
Allows easy experimentation with different models without changing core code.
"""

7 

from dataclasses import dataclass, replace
from enum import Enum
from typing import Dict

import structlog

12 

# Module-level structured logger used by get_model_config and the script entry point.
logger = structlog.get_logger(__name__)

14 

15 

class BackboneArchitecture(Enum):
    """Available backbone architectures for LocalizationNet.

    Each member's value is the backbone's string identifier. The trailing
    comment on each member notes its parameter count and intended trade-off.
    Definition order is significant: iterating the enum (as the comparison
    table does) yields members in the order written here.
    """

    # ConvNeXt variants (RECOMMENDED - Modern 2022 architecture)
    CONVNEXT_TINY = "convnext_tiny"  # 29M params, lightweight
    CONVNEXT_SMALL = "convnext_small"  # 50M params, balanced
    # NOTE: the member is named MEDIUM, but its identifier is "convnext_base".
    CONVNEXT_MEDIUM = "convnext_base"  # 89M params, good accuracy
    CONVNEXT_LARGE = "convnext_large"  # 200M params, best accuracy ⭐

    # ResNet variants (Traditional, well-tested)
    RESNET_50 = "resnet50"  # 26M params, conservative upgrade
    RESNET_101 = "resnet101"  # 45M params, heavier

    # Vision Transformers (Experimental, better long-range dependencies)
    VIT_BASE = "vit_b_16"  # 86M params, transformer-based
    VIT_LARGE = "vit_l_16"  # 306M params, very large

    # EfficientNet (Balanced, good for edge deployment)
    EFFICIENTNET_B3 = "efficientnet_b3"  # 12M params, lightweight
    EFFICIENTNET_B4 = "efficientnet_b4"  # 19M params, balanced

36 

37 

@dataclass
class ModelConfig:
    """Configuration for LocalizationNet.

    Bundles the backbone choice with uncertainty bounds, training
    hyperparameters and input dimensions. All numeric fields are validated
    when an instance is constructed (see ``__post_init__``).

    Raises:
        ValueError: On construction, if any field fails validation.
    """

    backbone: BackboneArchitecture = BackboneArchitecture.CONVNEXT_LARGE
    pretrained: bool = True
    freeze_backbone: bool = False
    uncertainty_min: float = 0.01
    uncertainty_max: float = 1.0

    # Training hyperparameters
    learning_rate: float = 1e-3
    weight_decay: float = 1e-5
    warmup_steps: int = 500
    num_training_steps: int = 10000

    # Data parameters
    # NOTE(review): presumably mel bins / time frames of the model input —
    # confirm against the data pipeline.
    n_mels: int = 128
    n_frames: int = 32

    def __post_init__(self) -> None:
        """Validate configuration.

        Raises:
            ValueError: If ``uncertainty_min`` is negative or exceeds
                ``uncertainty_max``; if ``learning_rate`` is not positive or
                ``weight_decay`` is negative; if ``warmup_steps`` is negative
                or ``num_training_steps`` is not positive; or if ``n_mels`` /
                ``n_frames`` is not positive.
        """
        if self.uncertainty_min < 0 or self.uncertainty_max < self.uncertainty_min:
            raise ValueError("Invalid uncertainty bounds")
        if self.learning_rate <= 0 or self.weight_decay < 0:
            raise ValueError("Invalid learning rate or weight decay")
        # Step counts and input dimensions were previously unchecked; reject
        # values that cannot describe a real schedule or input here, rather
        # than failing later inside training.
        if self.warmup_steps < 0 or self.num_training_steps <= 0:
            raise ValueError("Invalid warmup or training step count")
        if self.n_mels <= 0 or self.n_frames <= 0:
            raise ValueError("Invalid n_mels or n_frames")

64 

65 

# Predefined configurations for common use cases.
# Keys are the names accepted by get_model_config(); fields not listed in an
# entry fall back to the ModelConfig defaults.
CONFIGS: Dict[str, ModelConfig] = {
    # Production configs
    'production_high_accuracy': ModelConfig(
        backbone=BackboneArchitecture.CONVNEXT_LARGE,
        pretrained=True,
        freeze_backbone=False,
        learning_rate=1e-3,
        num_training_steps=50000,
    ),

    'production_balanced': ModelConfig(
        backbone=BackboneArchitecture.CONVNEXT_MEDIUM,
        pretrained=True,
        freeze_backbone=False,
        learning_rate=1e-3,
        num_training_steps=30000,
    ),

    'production_lightweight': ModelConfig(
        backbone=BackboneArchitecture.EFFICIENTNET_B4,
        pretrained=True,
        freeze_backbone=False,
        learning_rate=1e-3,
        num_training_steps=20000,
    ),

    # Development/testing configs
    'dev_fast': ModelConfig(
        backbone=BackboneArchitecture.CONVNEXT_SMALL,
        pretrained=True,
        freeze_backbone=True,  # Frozen backbone for faster training
        learning_rate=5e-4,
        num_training_steps=5000,
    ),

    'dev_test': ModelConfig(
        backbone=BackboneArchitecture.EFFICIENTNET_B3,
        pretrained=False,  # No pretraining for quick tests
        freeze_backbone=False,
        learning_rate=1e-3,
        num_training_steps=1000,
    ),

    # Experimental configs
    'experimental_vit': ModelConfig(
        backbone=BackboneArchitecture.VIT_BASE,
        pretrained=True,
        freeze_backbone=False,
        learning_rate=1e-4,  # ViT typically needs lower LR
        num_training_steps=50000,
    ),

    'experimental_resnet': ModelConfig(
        backbone=BackboneArchitecture.RESNET_101,
        pretrained=True,
        freeze_backbone=False,
        learning_rate=1e-3,
        num_training_steps=30000,
    ),
}

127 

128 

def get_model_config(config_name: str) -> ModelConfig:
    """Return a predefined model configuration by name.

    The result is an independent copy of the entry registered in ``CONFIGS``,
    so callers may adjust it without altering the shared preset. A
    ``model_config_loaded`` event is logged with the config name, backbone
    value and learning rate.

    Args:
        config_name: Key of the preset in ``CONFIGS``.

    Returns:
        A new ``ModelConfig`` equal to the registered preset.

    Raises:
        KeyError: If ``config_name`` is not a key of ``CONFIGS``; the message
            lists the available names.
    """
    try:
        preset = CONFIGS[config_name]
    except KeyError:
        available = ', '.join(CONFIGS)
        raise KeyError(f"Unknown config '{config_name}'. Available: {available}") from None

    # ModelConfig is a mutable dataclass: returning the registry instance
    # itself would let one caller's edits leak into every later lookup.
    config = replace(preset)
    logger.info(
        "model_config_loaded",
        config_name=config_name,
        backbone=config.backbone.value,
        learning_rate=config.learning_rate,
    )

    return config

155 

156 

def get_backbone_info(backbone: BackboneArchitecture) -> Dict:
    """Get information about a backbone architecture.

    The lookup table is rebuilt on every call, so the returned dict is not
    shared between callers.

    Args:
        backbone: The architecture to describe.

    Returns:
        Dict with keys ``params_millions`` (int), ``imagenet_top1`` (percent
        string), ``inference_ms`` (int), ``vram_gb`` (int) and
        ``description`` (str).
        NOTE(review): the conditions under which the accuracy, latency and
        memory figures were measured are not recorded here — confirm before
        relying on them.

    Raises:
        ValueError: If ``backbone`` has no entry in the table.
    """
    info_map = {
        # ConvNeXt (Modern 2022 architecture)
        BackboneArchitecture.CONVNEXT_TINY: {
            'params_millions': 29,
            'imagenet_top1': '81.9%',
            'inference_ms': 15,
            'vram_gb': 4,
            'description': 'Lightweight, fast',
        },
        BackboneArchitecture.CONVNEXT_SMALL: {
            'params_millions': 50,
            'imagenet_top1': '83.6%',
            'inference_ms': 20,
            'vram_gb': 6,
            'description': 'Balanced speed/accuracy',
        },
        BackboneArchitecture.CONVNEXT_MEDIUM: {
            'params_millions': 89,
            'imagenet_top1': '86.2%',
            'inference_ms': 30,
            'vram_gb': 8,
            'description': 'Very good accuracy',
        },
        BackboneArchitecture.CONVNEXT_LARGE: {
            'params_millions': 200,
            'imagenet_top1': '88.6%',
            'inference_ms': 45,
            'vram_gb': 12,
            'description': 'Best accuracy ⭐ RECOMMENDED',
        },

        # ResNet (Traditional, well-tested)
        BackboneArchitecture.RESNET_50: {
            'params_millions': 26,
            'imagenet_top1': '76.1%',
            'inference_ms': 25,
            'vram_gb': 8,
            'description': 'Well-tested, conservative',
        },
        BackboneArchitecture.RESNET_101: {
            'params_millions': 45,
            'imagenet_top1': '77.4%',
            'inference_ms': 35,
            'vram_gb': 10,
            'description': 'Larger ResNet',
        },

        # Vision Transformers (Experimental)
        BackboneArchitecture.VIT_BASE: {
            'params_millions': 86,
            'imagenet_top1': '84.1%',
            'inference_ms': 55,
            'vram_gb': 10,
            'description': 'Transformer-based, good long-range',
        },
        BackboneArchitecture.VIT_LARGE: {
            'params_millions': 306,
            'imagenet_top1': '85.9%',
            'inference_ms': 90,
            'vram_gb': 16,
            'description': 'Very large, max accuracy',
        },

        # EfficientNet (Balanced)
        BackboneArchitecture.EFFICIENTNET_B3: {
            'params_millions': 12,
            'imagenet_top1': '81.6%',
            'inference_ms': 15,
            'vram_gb': 4,
            'description': 'Lightweight, good balance',
        },
        BackboneArchitecture.EFFICIENTNET_B4: {
            'params_millions': 19,
            'imagenet_top1': '83.4%',
            'inference_ms': 20,
            'vram_gb': 6,
            'description': 'Balanced efficiency',
        },
    }

    try:
        return info_map[backbone]
    except KeyError:
        # Keep the ValueError contract callers already rely on.
        raise ValueError(f"Unknown backbone: {backbone}") from None

248 

249 

def print_backbone_comparison() -> None:
    """Print a comparison table of all available backbones.

    Writes one row per ``BackboneArchitecture`` member to stdout, in enum
    definition order, followed by a fixed recommendation footer.
    """
    rule = "=" * 100

    print("\n" + rule)
    print("BACKBONE ARCHITECTURE COMPARISON")
    print(rule)
    print(f"\n{'Architecture':<20} {'Params (M)':<12} {'ImageNet':<12} {'Inference':<12} {'VRAM':<8} {'Description':<30}")
    print("-" * 100)

    for backbone in BackboneArchitecture:
        info = get_backbone_info(backbone)
        # Attach the unit first, then pad the whole cell to the header width.
        # Padding after the unit with a fixed number of spaces made the cell
        # width depend on the digit count, which shifted later columns.
        inference = f"{info['inference_ms']}ms"
        vram = f"{info['vram_gb']}GB"
        print(f"{backbone.value:<20} {info['params_millions']:<12} {info['imagenet_top1']:<12} "
              f"{inference:<12} {vram:<8} {info['description']:<30}")

    print(rule)
    print("\n⭐ RECOMMENDED for your hardware: ConvNeXt-Large")
    print(" - 200M params, 88.6% ImageNet, 45ms inference")
    print(" - Best accuracy with acceptable speed")
    print(" - Fits comfortably in RTX 3090 (12GB VRAM needed)\n")

269 

270 

if __name__ == "__main__":
    # Manual smoke run: print the comparison table, then load one preset.
    import logging
    logging.basicConfig(level=logging.INFO)

    print_backbone_comparison()

    # Test loading a config
    config = get_model_config('production_high_accuracy')
    logger.info("Config loaded successfully", config=config)