"""
Artist Style Embedding - Model Architecture
EVA02-Large based Multi-branch Style Encoder
"""
from typing import Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm


class EVA02Encoder(nn.Module):
    """
    EVA02-Large backbone encoder.
    CLIP-pretrained weights, which transfer well to style features.
    """
def __init__(
self,
pretrained: bool = True,
output_dim: int = 1024,
):
super().__init__()
self.backbone = timm.create_model(
"eva02_large_patch14_clip_224",
pretrained=pretrained,
num_classes=0,
)
        # Feature width reported by timm (1024 for EVA02-Large)
        self.feature_dim = self.backbone.num_features
if self.feature_dim != output_dim:
self.proj = nn.Sequential(
nn.Linear(self.feature_dim, output_dim),
nn.LayerNorm(output_dim),
nn.GELU(),
)
else:
self.proj = nn.Identity()
self.output_dim = output_dim
def forward(self, x: torch.Tensor) -> torch.Tensor:
features = self.backbone(x)
return self.proj(features)
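
# Usage sketch (illustrative only; pretrained=False here is just to avoid the
# checkpoint download during a quick shape check):
#
#   enc = EVA02Encoder(pretrained=False, output_dim=1024)
#   out = enc(torch.randn(2, 3, 224, 224))  # -> [2, 1024]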


class GatedFusion(nn.Module):
"""Gated attention fusion for multi-branch features"""
def __init__(self, input_dim: int, num_branches: int = 3):
super().__init__()
self.num_branches = num_branches
self.gate = nn.Sequential(
nn.Linear(input_dim * num_branches, input_dim),
nn.ReLU(),
nn.Linear(input_dim, num_branches),
nn.Softmax(dim=-1),
)
    def forward(
        self,
        features: torch.Tensor,  # [B, num_branches, dim]
        mask: Optional[torch.Tensor] = None,  # [B, num_branches]
    ) -> torch.Tensor:
        B = features.shape[0]
        concat_features = features.reshape(B, -1)
        gates = self.gate(concat_features)  # [B, num_branches]
        if mask is not None:
            # Zero out gates of missing branches, then renormalize so the
            # surviving gates still sum to 1.
            gates = gates * mask.float()
            gates = gates / (gates.sum(dim=-1, keepdim=True) + 1e-8)
        gates = gates.unsqueeze(-1)  # [B, num_branches, 1]
        fused = (features * gates).sum(dim=1)
        return fused
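
# Mask semantics (a hedged sketch): with num_branches=3 and a row mask of
# [1, 1, 0], the third gate is zeroed and the first two are renormalized to
# sum to 1 before the weighted sum over branches:
#
#   fusion = GatedFusion(input_dim=8, num_branches=3)
#   feats = torch.randn(4, 3, 8)
#   mask = torch.tensor([[1, 1, 0]] * 4)
#   fused = fusion(feats, mask)  # -> [4, 8]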


class StyleEmbeddingHead(nn.Module):
"""Final embedding projection head"""
def __init__(
self,
input_dim: int,
embedding_dim: int = 512,
hidden_dim: int = 1024,
dropout: float = 0.1,
):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, embedding_dim),
)
self.final_norm = nn.LayerNorm(embedding_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.mlp(x)
x = self.final_norm(x)
x = F.normalize(x, p=2, dim=-1)
return x
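
# Because of the final F.normalize, embeddings lie on the unit hypersphere,
# so cosine similarity between two embeddings reduces to a dot product:
#
#   head = StyleEmbeddingHead(input_dim=1024, embedding_dim=512)
#   z = head(torch.randn(2, 1024))
#   sim = z[0] @ z[1]  # cosine similarity in [-1, 1]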


class MultiBranchStyleEncoder(nn.Module):
    """
    Multi-branch style encoder with a single EVA02-Large backbone shared
    across three branches:
    - Full image branch
    - Face crop branch
    - Eye crop branch
    """
def __init__(
self,
embedding_dim: int = 512,
hidden_dim: int = 1024,
dropout: float = 0.1,
):
super().__init__()
        # One backbone shared by all three branches (full / face / eye)
        self.shared_backbone = EVA02Encoder(pretrained=True, output_dim=hidden_dim)
# Gated Fusion
self.fusion = GatedFusion(hidden_dim, num_branches=3)
# Embedding head
self.embedding_head = StyleEmbeddingHead(
hidden_dim, embedding_dim, hidden_dim, dropout
)
self.embedding_dim = embedding_dim
self.hidden_dim = hidden_dim
def forward(
self,
full: torch.Tensor,
face: torch.Tensor,
eye: torch.Tensor,
has_face: torch.Tensor,
has_eye: torch.Tensor,
) -> torch.Tensor:
        # Encode all branches with the shared backbone; zero out features of
        # missing crops so they cannot contribute to fusion.
        full_features = self.shared_backbone(full)
        face_features = self.shared_backbone(face) * has_face.float().unsqueeze(-1)
        eye_features = self.shared_backbone(eye) * has_eye.float().unsqueeze(-1)
        # Stack branches and build the presence mask; ones_like keeps the mask
        # dtype consistent with has_face/has_eye so torch.stack does not mix
        # bool and float tensors.
        stacked = torch.stack([full_features, face_features, eye_features], dim=1)
        mask = torch.stack([
            torch.ones_like(has_face),
            has_face,
            has_eye,
        ], dim=1)
# Fusion
fused = self.fusion(stacked, mask)
# Embedding
embeddings = self.embedding_head(fused)
return embeddings

    def get_backbone_params(self):
        """Returns parameters of the single shared backbone"""
        return self.shared_backbone.parameters()
def get_head_params(self):
"""Returns parameters of all heads and fusion layers"""
params = []
# Add the fusion layer
params.extend(self.fusion.parameters())
# Add the final embedding MLP
params.extend(self.embedding_head.parameters())
        # Note: the shared backbone's projection (EVA02Encoder.proj) is
        # trained with the backbone, so it is deliberately not added here.
        return params
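
    # Typical optimizer wiring (a sketch; AdamW and the learning rates are
    # illustrative assumptions, not taken from this repo):
    #
    #   optimizer = torch.optim.AdamW([
    #       {"params": model.get_backbone_params(), "lr": 1e-5},
    #       {"params": model.get_head_params(), "lr": 1e-4},
    #   ])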
def freeze_backbone(self):
"""Freezes the single shared backbone"""
for param in self.get_backbone_params():
param.requires_grad = False
        self.shared_backbone.eval()  # optional: EVA02 has no BatchNorm, but this disables dropout
def unfreeze_backbone(self):
"""Unfreezes the single shared backbone"""
for param in self.get_backbone_params():
param.requires_grad = True
self.shared_backbone.train()


class ArtistStyleModel(nn.Module):
"""
Complete model: Multi-branch Encoder + ArcFace Head
"""
def __init__(
self,
num_classes: int,
embedding_dim: int = 512,
hidden_dim: int = 1024,
dropout: float = 0.1,
):
super().__init__()
self.num_classes = num_classes
self.embedding_dim = embedding_dim
# Style encoder
self.encoder = MultiBranchStyleEncoder(
embedding_dim=embedding_dim,
hidden_dim=hidden_dim,
dropout=dropout,
)
# ArcFace weight
self.arcface_weight = nn.Parameter(
torch.FloatTensor(num_classes, embedding_dim)
)
nn.init.xavier_uniform_(self.arcface_weight)
def forward(
self,
full: torch.Tensor,
face: torch.Tensor,
eye: torch.Tensor,
has_face: torch.Tensor,
has_eye: torch.Tensor,
) -> Dict[str, torch.Tensor]:
embeddings = self.encoder(full, face, eye, has_face, has_eye)
        # Embeddings are already L2-normalized by the head, so a product with
        # L2-normalized class weights yields cosine similarity
        normalized_weights = F.normalize(self.arcface_weight, p=2, dim=1)
        cosine = F.linear(embeddings, normalized_weights)
return {
'embeddings': embeddings,
'cosine': cosine,
}
def get_embeddings(
self,
full: torch.Tensor,
face: torch.Tensor,
eye: torch.Tensor,
has_face: torch.Tensor,
has_eye: torch.Tensor,
) -> torch.Tensor:
return self.encoder(full, face, eye, has_face, has_eye)
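
# Consuming the cosine logits (a hedged sketch, not part of this module): an
# ArcFace-style loss adds an angular margin m to the target class and scales
# by s before cross-entropy, e.g.
#
#   theta = torch.acos(out["cosine"].clamp(-1 + 1e-7, 1 - 1e-7))
#   one_hot = F.one_hot(labels, num_classes).bool()
#   logits = s * torch.where(one_hot, torch.cos(theta + m), out["cosine"])
#   loss = F.cross_entropy(logits, labels)
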
def create_model(config, num_classes: int) -> ArtistStyleModel:
"""Create model from config"""
return ArtistStyleModel(
num_classes=num_classes,
embedding_dim=config.model.embedding_dim,
hidden_dim=config.model.hidden_dim,
dropout=config.model.dropout,
)
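

if __name__ == "__main__":
    # Smoke test (a minimal sketch; MultiBranchStyleEncoder hard-codes
    # pretrained=True, so the first run downloads the EVA02 checkpoint).
    model = MultiBranchStyleEncoder(embedding_dim=512, hidden_dim=1024)
    B = 2
    full = torch.randn(B, 3, 224, 224)
    face = torch.randn(B, 3, 224, 224)
    eye = torch.randn(B, 3, 224, 224)
    has_face = torch.tensor([1.0, 0.0])
    has_eye = torch.tensor([1.0, 1.0])
    emb = model(full, face, eye, has_face, has_eye)
    print(emb.shape)         # torch.Size([2, 512])
    print(emb.norm(dim=-1))  # ~1.0 per row (L2-normalized)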