""" Artist Style Embedding - Model Architecture EVA02-Large based Multi-branch Style Encoder """ from typing import Dict, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F import timm class EVA02Encoder(nn.Module): """ EVA02-Large backbone encoder Pre-trained on CLIP, excellent for style features """ def __init__( self, pretrained: bool = True, output_dim: int = 1024, ): super().__init__() self.backbone = timm.create_model( "eva02_large_patch14_clip_224", pretrained=pretrained, num_classes=0, ) # EVA02-Large output: 1024 self.feature_dim = 1024 if self.feature_dim != output_dim: self.proj = nn.Sequential( nn.Linear(self.feature_dim, output_dim), nn.LayerNorm(output_dim), nn.GELU(), ) else: self.proj = nn.Identity() self.output_dim = output_dim def forward(self, x: torch.Tensor) -> torch.Tensor: features = self.backbone(x) return self.proj(features) class GatedFusion(nn.Module): """Gated attention fusion for multi-branch features""" def __init__(self, input_dim: int, num_branches: int = 3): super().__init__() self.num_branches = num_branches self.gate = nn.Sequential( nn.Linear(input_dim * num_branches, input_dim), nn.ReLU(), nn.Linear(input_dim, num_branches), nn.Softmax(dim=-1), ) def forward( self, features: torch.Tensor, # [B, num_branches, dim] mask: Optional[torch.Tensor] = None, # [B, num_branches] ) -> torch.Tensor: B, N, D = features.shape concat_features = features.view(B, -1) gates = self.gate(concat_features) if mask is not None: gates = gates * mask.float() gates = gates / (gates.sum(dim=-1, keepdim=True) + 1e-8) gates = gates.unsqueeze(-1) fused = (features * gates).sum(dim=1) return fused class StyleEmbeddingHead(nn.Module): """Final embedding projection head""" def __init__( self, input_dim: int, embedding_dim: int = 512, hidden_dim: int = 1024, dropout: float = 0.1, ): super().__init__() self.mlp = nn.Sequential( nn.Linear(input_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim, embedding_dim), ) self.final_norm = nn.LayerNorm(embedding_dim) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.mlp(x) x = self.final_norm(x) x = F.normalize(x, p=2, dim=-1) return x class MultiBranchStyleEncoder(nn.Module): """ Multi-branch style encoder with separate EVA02-Large backbones - Full image branch - Face crop branch - Eye crop branch """ def __init__( self, embedding_dim: int = 512, hidden_dim: int = 1024, dropout: float = 0.1, ): super().__init__() self.shared_backbone = EVA02Encoder(pretrained=True, output_dim=hidden_dim) # Gated Fusion self.fusion = GatedFusion(hidden_dim, num_branches=3) # Embedding head self.embedding_head = StyleEmbeddingHead( hidden_dim, embedding_dim, hidden_dim, dropout ) self.embedding_dim = embedding_dim self.hidden_dim = hidden_dim def forward( self, full: torch.Tensor, face: torch.Tensor, eye: torch.Tensor, has_face: torch.Tensor, has_eye: torch.Tensor, ) -> torch.Tensor: B = full.shape[0] device = full.device # Encode all branches full_features = self.shared_backbone(full) face_features = self.shared_backbone(face) * has_face.unsqueeze(-1) eye_features = self.shared_backbone(eye) * has_eye.unsqueeze(-1) # Stack and create mask stacked = torch.stack([full_features, face_features, eye_features], dim=1) mask = torch.stack([ torch.ones(B, device=device, dtype=torch.bool), has_face, has_eye, ], dim=1) # Fusion fused = self.fusion(stacked, mask) # Embedding embeddings = self.embedding_head(fused) return embeddings def get_backbone_params(self): """Returns parameters of the single shared backbone""" # Assuming you named your shared backbone 'self.shared_encoder' return self.shared_backbone.parameters() def get_head_params(self): """Returns parameters of all heads and fusion layers""" params = [] # Add the fusion layer params.extend(self.fusion.parameters()) # Add the final embedding MLP params.extend(self.embedding_head.parameters()) # Note: If your shared_encoder has a projection head (proj), # that is usually trained with the backbone, so we don't add it here. return params def freeze_backbone(self): """Freezes the single shared backbone""" for param in self.get_backbone_params(): param.requires_grad = False self.shared_backbone.eval() # Optional: keep BN stats frozen def unfreeze_backbone(self): """Unfreezes the single shared backbone""" for param in self.get_backbone_params(): param.requires_grad = True self.shared_backbone.train() class ArtistStyleModel(nn.Module): """ Complete model: Multi-branch Encoder + ArcFace Head """ def __init__( self, num_classes: int, embedding_dim: int = 512, hidden_dim: int = 1024, dropout: float = 0.1, ): super().__init__() self.num_classes = num_classes self.embedding_dim = embedding_dim # Style encoder self.encoder = MultiBranchStyleEncoder( embedding_dim=embedding_dim, hidden_dim=hidden_dim, dropout=dropout, ) # ArcFace weight self.arcface_weight = nn.Parameter( torch.FloatTensor(num_classes, embedding_dim) ) nn.init.xavier_uniform_(self.arcface_weight) def forward( self, full: torch.Tensor, face: torch.Tensor, eye: torch.Tensor, has_face: torch.Tensor, has_eye: torch.Tensor, ) -> Dict[str, torch.Tensor]: embeddings = self.encoder(full, face, eye, has_face, has_eye) # Cosine similarity with normalized weights normalized_weights = F.normalize(self.arcface_weight, p=2, dim=1) cosine = F.linear(embeddings, normalized_weights) return { 'embeddings': embeddings, 'cosine': cosine, } def get_embeddings( self, full: torch.Tensor, face: torch.Tensor, eye: torch.Tensor, has_face: torch.Tensor, has_eye: torch.Tensor, ) -> torch.Tensor: return self.encoder(full, face, eye, has_face, has_eye) def create_model(config, num_classes: int) -> ArtistStyleModel: """Create model from config""" return ArtistStyleModel( num_classes=num_classes, embedding_dim=config.model.embedding_dim, hidden_dim=config.model.hidden_dim, dropout=config.model.dropout, )