| """ | |
| Artist Style Embedding - Model Architecture | |
| EVA02-Large based Multi-branch Style Encoder | |
| """ | |
| from typing import Dict, Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import timm | |


class EVA02Encoder(nn.Module):
    """
    EVA02-Large backbone encoder.
    CLIP-pretrained weights, excellent for style features.
    """

    def __init__(
        self,
        pretrained: bool = True,
        output_dim: int = 1024,
    ):
        super().__init__()
        self.backbone = timm.create_model(
            "eva02_large_patch14_clip_224",
            pretrained=pretrained,
            num_classes=0,
        )
        # EVA02-Large feature dim: 1024
        self.feature_dim = 1024
        if self.feature_dim != output_dim:
            self.proj = nn.Sequential(
                nn.Linear(self.feature_dim, output_dim),
                nn.LayerNorm(output_dim),
                nn.GELU(),
            )
        else:
            self.proj = nn.Identity()
        self.output_dim = output_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.backbone(x)
        return self.proj(features)
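
# Shape note: eva02_large_patch14_clip_224 expects 224x224 inputs, and with
# num_classes=0 timm returns the pooled features, so EVA02Encoder maps an
# image batch of shape [B, 3, 224, 224] to features of shape [B, output_dim].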


class GatedFusion(nn.Module):
    """Gated attention fusion for multi-branch features"""

    def __init__(self, input_dim: int, num_branches: int = 3):
        super().__init__()
        self.num_branches = num_branches
        self.gate = nn.Sequential(
            nn.Linear(input_dim * num_branches, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, num_branches),
            nn.Softmax(dim=-1),
        )

    def forward(
        self,
        features: torch.Tensor,  # [B, num_branches, dim]
        mask: Optional[torch.Tensor] = None,  # [B, num_branches]
    ) -> torch.Tensor:
        B, N, D = features.shape
        concat_features = features.view(B, -1)
        gates = self.gate(concat_features)
        if mask is not None:
            gates = gates * mask.float()
            gates = gates / (gates.sum(dim=-1, keepdim=True) + 1e-8)
        gates = gates.unsqueeze(-1)
        fused = (features * gates).sum(dim=1)
        return fused
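
# Usage sketch: GatedFusion(input_dim=1024, num_branches=3) takes stacked branch
# features of shape [B, 3, 1024] and an optional [B, 3] validity mask, and returns
# a gate-weighted sum of shape [B, 1024]; when a mask is given, the softmax gates
# are re-normalized over only the branches that are present.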


class StyleEmbeddingHead(nn.Module):
    """Final embedding projection head"""

    def __init__(
        self,
        input_dim: int,
        embedding_dim: int = 512,
        hidden_dim: int = 1024,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embedding_dim),
        )
        self.final_norm = nn.LayerNorm(embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.mlp(x)
        x = self.final_norm(x)
        x = F.normalize(x, p=2, dim=-1)
        return x
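
# Since StyleEmbeddingHead L2-normalizes its output, cosine similarity between two
# embeddings reduces to a plain dot product (e.g. query @ gallery.T for retrieval).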


class MultiBranchStyleEncoder(nn.Module):
    """
    Multi-branch style encoder with a single shared EVA02-Large backbone
    - Full image branch
    - Face crop branch
    - Eye crop branch
    """

    def __init__(
        self,
        embedding_dim: int = 512,
        hidden_dim: int = 1024,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.shared_backbone = EVA02Encoder(pretrained=True, output_dim=hidden_dim)
        # Gated fusion
        self.fusion = GatedFusion(hidden_dim, num_branches=3)
        # Embedding head
        self.embedding_head = StyleEmbeddingHead(
            hidden_dim, embedding_dim, hidden_dim, dropout
        )
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

    def forward(
        self,
        full: torch.Tensor,
        face: torch.Tensor,
        eye: torch.Tensor,
        has_face: torch.Tensor,
        has_eye: torch.Tensor,
    ) -> torch.Tensor:
        B = full.shape[0]
        device = full.device
        # Encode all branches; missing face/eye crops are zeroed out via the flags
        full_features = self.shared_backbone(full)
        face_features = self.shared_backbone(face) * has_face.unsqueeze(-1)
        eye_features = self.shared_backbone(eye) * has_eye.unsqueeze(-1)
        # Stack branch features and build the [B, 3] validity mask
        stacked = torch.stack([full_features, face_features, eye_features], dim=1)
        mask = torch.stack([
            # The full-image branch is always present; match the dtype of the
            # crop flags so the stacked mask has a single consistent dtype.
            torch.ones(B, device=device, dtype=has_face.dtype),
            has_face,
            has_eye,
        ], dim=1)
        # Fusion
        fused = self.fusion(stacked, mask)
        # Embedding
        embeddings = self.embedding_head(fused)
        return embeddings
    def get_backbone_params(self):
        """Returns parameters of the single shared backbone"""
        return self.shared_backbone.parameters()

    def get_head_params(self):
        """Returns parameters of the fusion and embedding head"""
        params = []
        # Gated fusion layer
        params.extend(self.fusion.parameters())
        # Final embedding MLP
        params.extend(self.embedding_head.parameters())
        # Note: the backbone's projection layer (EVA02Encoder.proj) is trained
        # together with the backbone, so it is intentionally not included here.
        return params

    def freeze_backbone(self):
        """Freezes the single shared backbone"""
        for param in self.get_backbone_params():
            param.requires_grad = False
        self.shared_backbone.eval()  # keep norm/dropout layers in eval mode while frozen

    def unfreeze_backbone(self):
        """Unfreezes the single shared backbone"""
        for param in self.get_backbone_params():
            param.requires_grad = True
        self.shared_backbone.train()


class ArtistStyleModel(nn.Module):
    """
    Complete model: Multi-branch Encoder + ArcFace Head
    """

    def __init__(
        self,
        num_classes: int,
        embedding_dim: int = 512,
        hidden_dim: int = 1024,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.embedding_dim = embedding_dim
        # Style encoder
        self.encoder = MultiBranchStyleEncoder(
            embedding_dim=embedding_dim,
            hidden_dim=hidden_dim,
            dropout=dropout,
        )
        # ArcFace class-center weights
        self.arcface_weight = nn.Parameter(
            torch.FloatTensor(num_classes, embedding_dim)
        )
        nn.init.xavier_uniform_(self.arcface_weight)

    def forward(
        self,
        full: torch.Tensor,
        face: torch.Tensor,
        eye: torch.Tensor,
        has_face: torch.Tensor,
        has_eye: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        embeddings = self.encoder(full, face, eye, has_face, has_eye)
        # Cosine similarity with L2-normalized weights
        normalized_weights = F.normalize(self.arcface_weight, p=2, dim=1)
        cosine = F.linear(embeddings, normalized_weights)
        return {
            'embeddings': embeddings,
            'cosine': cosine,
        }

    def get_embeddings(
        self,
        full: torch.Tensor,
        face: torch.Tensor,
        eye: torch.Tensor,
        has_face: torch.Tensor,
        has_eye: torch.Tensor,
    ) -> torch.Tensor:
        return self.encoder(full, face, eye, has_face, has_eye)
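

# The model above returns raw cosine logits; the angular margin is expected to be
# applied in the training loss, which lives outside this file. Below is a minimal
# sketch of the standard ArcFace margin; it is illustrative only, and the scale /
# margin defaults are assumptions rather than the project's actual hyperparameters.
def arcface_logits(
    cosine: torch.Tensor,
    labels: torch.Tensor,
    scale: float = 30.0,
    margin: float = 0.5,
) -> torch.Tensor:
    """Illustrative ArcFace margin: add `margin` to the target-class angle, then scale."""
    theta = torch.acos(cosine.clamp(-1.0 + 1e-7, 1.0 - 1e-7))
    target_logits = torch.cos(theta + margin)
    one_hot = F.one_hot(labels, num_classes=cosine.shape[-1]).bool()
    return scale * torch.where(one_hot, target_logits, cosine)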


def create_model(config, num_classes: int) -> ArtistStyleModel:
    """Create model from config"""
    return ArtistStyleModel(
        num_classes=num_classes,
        embedding_dim=config.model.embedding_dim,
        hidden_dim=config.model.hidden_dim,
        dropout=config.model.dropout,
    )
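

# Minimal smoke test (assumptions: a SimpleNamespace stands in for the real config
# object, and random 224x224 tensors stand in for the full/face/eye crops; running
# this downloads the EVA02 weights via timm on first use).
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        model=SimpleNamespace(embedding_dim=512, hidden_dim=1024, dropout=0.1)
    )
    model = create_model(cfg, num_classes=100).eval()

    B = 2
    full = torch.randn(B, 3, 224, 224)
    face = torch.randn(B, 3, 224, 224)
    eye = torch.randn(B, 3, 224, 224)
    has_face = torch.tensor([1.0, 0.0])  # second sample has no detected face
    has_eye = torch.tensor([1.0, 0.0])

    with torch.no_grad():
        out = model(full, face, eye, has_face, has_eye)
    print(out["embeddings"].shape)  # expected: torch.Size([2, 512])
    print(out["cosine"].shape)      # expected: torch.Size([2, 100])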