Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Artist Style Embedding - Configuration | |
| Maximum Performance Settings for RTX 5090 | |
| """ | |
| from dataclasses import dataclass, field | |
| from typing import Optional | |
| import torch | |
@dataclass
class DataConfig:
    """Dataset locations, image size, split ratios and data-loading options.

    Declared as a dataclass so instances can be built with keyword
    overrides, e.g. ``DataConfig(image_size=384)``, and so it can serve as a
    ``default_factory`` for an enclosing config object.
    """

    # Dataset paths
    dataset_root: str = "./dataset"
    dataset_face_root: str = "./dataset_face"
    dataset_eyes_root: str = "./dataset_eyes"

    # Image settings
    image_size: int = 224
    min_images_per_artist: int = 3

    # Data split (the three default ratios add up to 1.0)
    train_ratio: float = 0.8
    val_ratio: float = 0.1
    test_ratio: float = 0.1

    # Data loading
    num_workers: int = 12
    pin_memory: bool = True
@dataclass
class ModelConfig:
    """Backbone, embedding, multi-branch, fusion and dropout settings.

    Declared as a dataclass so individual fields can be overridden at
    construction time, e.g. ``ModelConfig(embedding_dim=256)``.
    """

    # Backbone - EVA02-Large (best performance)
    backbone: str = "eva02_large_patch14_clip_224"
    backbone_pretrained: bool = True
    freeze_backbone_epochs: int = 0  # unfrozen from the very first epoch

    # Embedding settings
    embedding_dim: int = 512
    hidden_dim: int = 1024

    # Multi-branch settings - every branch enabled, separate backbones
    use_face_branch: bool = True
    use_eye_branch: bool = True
    share_backbone_weights: bool = False  # separate backbones for best performance

    # Fusion settings
    fusion_type: str = "gated"
    num_attention_heads: int = 8

    # Dropout
    dropout: float = 0.2  # raised slightly
@dataclass
class LossConfig:
    """Hyper-parameters and weights for the loss terms.

    Declared as a dataclass so individual fields can be overridden at
    construction time, e.g. ``LossConfig(arcface_margin=0.3)``.
    """

    # ArcFace settings
    arcface_scale: float = 64.0
    arcface_margin: float = 0.5
    arcface_weight: float = 0.2

    # Multi-Similarity Loss weight
    ms_loss_weight: float = 3.0

    # Center Loss weight
    center_loss_weight: float = 0.01
@dataclass
class TrainConfig:
    """Training-loop settings: optimizer, schedule, logging and device.

    Declared as a dataclass so individual fields can be overridden at
    construction time, e.g. ``TrainConfig(epochs=10, device="cpu")``.
    """

    # Training settings
    epochs: int = 100
    batch_size: int = 128

    # Optimizer - higher learning rate
    learning_rate: float = 5e-4  # 1e-4 -> 5e-4
    backbone_lr_multiplier: float = 0.2  # 0.1 -> 0.2 (train the backbone more as well)
    weight_decay: float = 0.01  # 0.05 -> 0.01 (less regularization)

    # Scheduler
    warmup_epochs: int = 3  # 5 -> 3
    min_lr: float = 1e-6

    # Gradient
    max_grad_norm: float = 1.0

    # Mixed precision
    use_amp: bool = True

    # Checkpoints
    save_dir: str = "./checkpoints"
    save_every_n_epochs: int = 5

    # Logging
    log_every_n_steps: int = 50
    wandb_project: Optional[str] = "artist-style-embedding"
    wandb_run_name: Optional[str] = None

    # Sampling
    samples_per_class: int = 4

    # Early stopping
    patience: int = 20  # wait longer before stopping

    # Device - the default is evaluated once, when the class body executes
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Random seed
    seed: int = 42
@dataclass
class Config:
    """Top-level configuration bundling the data, model, loss and train sections.

    The ``@dataclass`` decorator is required here: without it the
    ``field(default_factory=...)`` declarations stay as raw ``Field`` objects
    on the class and ``__post_init__`` is never invoked. With it, every
    ``Config()`` gets its own freshly built sub-config instances.
    """

    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    loss: LossConfig = field(default_factory=LossConfig)
    train: TrainConfig = field(default_factory=TrainConfig)

    def __post_init__(self):
        """Fill in a default W&B run name when none was supplied."""
        if self.train.wandb_run_name is None:
            # NOTE(review): the "eva02_large" prefix is fixed text; it does not
            # follow self.model.backbone if that field is overridden.
            self.train.wandb_run_name = f"eva02_large_emb{self.model.embedding_dim}"
def get_config():
    """Build and return a ``Config`` constructed with no arguments."""
    default_config = Config()
    return default_config