iljung1106
Initial commit
546ff88
"""
Artist Style Embedding - Configuration
Maximum Performance Settings for RTX 5090
"""
from dataclasses import dataclass, field
from typing import Optional
import torch
@dataclass
class DataConfig:
    """Dataset locations, image size, split ratios and data-loader settings.

    ``__post_init__`` checks that the train/val/test ratios form a valid
    partition. A mistyped override then fails fast instead of silently
    producing an inconsistent split.

    Raises:
        ValueError: If any split ratio lies outside ``[0, 1]`` (or is NaN), or
            the three ratios do not sum to 1 within a small float tolerance.
    """

    # Dataset root directories (full images, face subset, eye subset).
    dataset_root: str = "./dataset"
    dataset_face_root: str = "./dataset_face"
    dataset_eyes_root: str = "./dataset_eyes"

    # Image settings.
    image_size: int = 224
    # Presumably artists with fewer images than this are skipped — confirm in the dataset code.
    min_images_per_artist: int = 3

    # Train / validation / test split fractions; validated to sum to 1.
    train_ratio: float = 0.8
    val_ratio: float = 0.1
    test_ratio: float = 0.1

    # Data loading.
    num_workers: int = 12
    pin_memory: bool = True

    def __post_init__(self) -> None:
        """Validate that the split ratios form a partition of the data."""
        ratios = {
            "train_ratio": self.train_ratio,
            "val_ratio": self.val_ratio,
            "test_ratio": self.test_ratio,
        }
        for name, value in ratios.items():
            # Written as `not (lo <= v <= hi)` so NaN is rejected too.
            if not 0.0 <= value <= 1.0:
                raise ValueError(f"{name} must be in [0, 1], got {value!r}")
        total = sum(ratios.values())
        # The tolerance absorbs binary floating-point rounding (e.g. 0.7 + 0.15 + 0.15).
        if abs(total - 1.0) > 1e-6:
            raise ValueError(
                "train_ratio + val_ratio + test_ratio must sum to 1, "
                f"got {total!r} ({ratios})"
            )
@dataclass
class ModelConfig:
    """Model architecture hyperparameters: backbone, embedding head, branches and fusion.

    NOTE: field order defines the positional ``__init__`` signature and the
    ``repr``/``astuple`` order of this dataclass, so do not reorder fields.
    """

    # Backbone - EVA02-Large (chosen for best performance).
    # NOTE(review): the name looks like a timm model identifier — confirm against the model factory.
    backbone: str = "eva02_large_patch14_clip_224"
    backbone_pretrained: bool = True
    freeze_backbone_epochs: int = 0  # 0 = backbone is unfrozen from the very first epoch
    # Embedding settings.
    embedding_dim: int = 512
    hidden_dim: int = 1024
    # Multi-branch settings - all branches enabled, each with its own (separate) backbone.
    use_face_branch: bool = True
    use_eye_branch: bool = True
    share_backbone_weights: bool = False  # separate backbone per branch, for best performance
    # Fusion settings.
    # NOTE(review): "gated" is the only value visible here; the accepted set is defined by the fusion code — confirm.
    fusion_type: str = "gated"
    # Presumably only used by attention-based fusion — confirm.
    num_attention_heads: int = 8
    # Dropout.
    dropout: float = 0.2  # raised slightly
@dataclass
class LossConfig:
    """Loss hyperparameters: ArcFace parameters and per-term loss weights.

    NOTE: field order defines the positional ``__init__`` signature of this
    dataclass, so do not reorder fields.
    """

    # ArcFace settings: logit scale (s) and angular margin (m).
    # NOTE(review): margin is presumably in radians, as in the ArcFace paper — confirm in the loss code.
    arcface_scale: float = 64.0
    arcface_margin: float = 0.5
    # Presumably the total loss is a weighted sum of the terms below — confirm in the training loop.
    arcface_weight: float = 0.2
    # Multi-Similarity loss weight.
    ms_loss_weight: float = 3.0
    # Center loss weight.
    center_loss_weight: float = 0.01
@dataclass
class TrainConfig:
    """Training-loop hyperparameters: optimisation, scheduling, checkpointing, logging and device.

    NOTE: field order defines the positional ``__init__`` signature of this
    dataclass, so do not reorder fields.
    """

    # Training settings.
    epochs: int = 100
    batch_size: int = 128
    # Optimizer - higher learning rate than before.
    learning_rate: float = 5e-4  # raised from 1e-4
    # Presumably the backbone LR is learning_rate * this multiplier — confirm in the optimizer setup.
    backbone_lr_multiplier: float = 0.2  # raised from 0.1 so the backbone learns more
    weight_decay: float = 0.01  # lowered from 0.05 (less regularization)
    # Scheduler.
    warmup_epochs: int = 3  # lowered from 5
    min_lr: float = 1e-6
    # Gradient clipping.
    max_grad_norm: float = 1.0
    # Mixed precision.
    use_amp: bool = True
    # Checkpointing.
    save_dir: str = "./checkpoints"
    save_every_n_epochs: int = 5
    # Logging.
    log_every_n_steps: int = 50
    wandb_project: Optional[str] = "artist-style-embedding"
    # When left as None, Config.__post_init__ fills in a name derived from the model settings.
    wandb_run_name: Optional[str] = None
    # Sampling.
    # NOTE(review): presumably images drawn per artist within a batch (P x K sampling) — confirm in the sampler.
    samples_per_class: int = 4
    # Early stopping.
    patience: int = 20  # wait longer before stopping
    # Device. The default expression runs once, when this class body is executed
    # (i.e. at import time), not per instance.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    # Random seed.
    seed: int = 42
@dataclass
class Config:
    """Top-level configuration bundling the data, model, loss and training settings.

    Each sub-config is created through ``default_factory``, so separately
    constructed ``Config`` objects do not share sub-config instances by default.
    """

    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    loss: LossConfig = field(default_factory=LossConfig)
    train: TrainConfig = field(default_factory=TrainConfig)

    def __post_init__(self):
        """Fill in a default W&B run name when none was provided.

        The name has the form ``<backbone family>_emb<embedding_dim>``. The
        backbone family is ``model.backbone`` with any ``_patch...`` suffix
        removed (e.g. ``"eva02_large_patch14_clip_224"`` -> ``"eva02_large"``),
        so the default configuration still yields ``"eva02_large_emb512"``.

        Note that this mutates ``self.train`` in place, including a
        ``TrainConfig`` instance passed in by the caller.
        """
        if self.train.wandb_run_name is None:
            # Derive the prefix from the configured backbone rather than hard-coding
            # "eva02_large", so runs with other backbones are labelled correctly.
            backbone_family = self.model.backbone.split("_patch", 1)[0]
            self.train.wandb_run_name = f"{backbone_family}_emb{self.model.embedding_dim}"
def get_config():
    """Build the default configuration.

    Returns:
        A freshly constructed ``Config`` with every sub-config at its default
        values. A new object is built on each call, so callers may modify the
        result without affecting other callers.
    """
    default_config = Config()
    return default_config