from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from .transformer import (
    LayerNormFp32,
    LayerNorm,
    QuickGELU,
    MixClsHead,
)
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower


@dataclass
class ClassHeadCfg(CLIPTextCfg):
    """Extends CLIPTextCfg with options for the MLP classification head."""

    cls_mlp_ratio: int = 4  # hidden width of the head MLP, as a multiple of its input width
    cls_layers: int = 1  # number of layers in the classification head
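
# Configs may arrive as plain dicts; the builders below promote them to
# ClassHeadCfg. Illustrative sketch (field values are assumptions, not shipped
# defaults):
#   cfg = ClassHeadCfg(vocab_size=49408, width=768, cls_layers=2, cls_mlp_ratio=4)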


def _build_cls_head(
    width,
    embed_dim,
    clshead_cfg,
    quick_gelu: bool = False,
    cast_dtype: Optional[torch.dtype] = None,
):
    """Build a MixClsHead mapping `width`-dim vision features to vocab-size logits."""
    clshead_cfg = ClassHeadCfg(**clshead_cfg) if isinstance(clshead_cfg, dict) else clshead_cfg
    act_layer = QuickGELU if quick_gelu else nn.GELU
    # Keep LayerNorm computation in fp32 when weights are cast to half precision.
    norm_layer = (
        LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
    )

    head = MixClsHead(
        width=width,
        embed_dim=embed_dim,
        layers=clshead_cfg.cls_layers,
        mlp_ratio=clshead_cfg.cls_mlp_ratio,
        act_layer=act_layer,
        norm_layer=norm_layer,
        output_dim=clshead_cfg.vocab_size,  # one logit per vocabulary entry
    )
    return head
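
# Hedged usage sketch for the builder above; the sizes are illustrative
# assumptions, not values from any shipped config:
#   head = _build_cls_head(
#       width=768,       # vision tower width
#       embed_dim=512,   # shared embedding size
#       clshead_cfg={"vocab_size": 49408, "cls_layers": 1, "cls_mlp_ratio": 4},
#       cast_dtype=torch.bfloat16,  # selects LayerNormFp32
#   )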


class Classifier(nn.Module):
    def __init__(
        self,
        embed_dim,
        text_cfg: ClassHeadCfg,
        vision_cfg: CLIPVisionCfg,
        init_logit_scale: float = np.log(1 / 0.07),
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        text_cfg = ClassHeadCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
        vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg

        # embed_dim=0: the classification head built below, not the vision tower,
        # projects image features from `vision_cfg.width` to `embed_dim`.
        self.visual = _build_vision_tower(0, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.context_length = self.text.context_length
        self.vocab_size = self.text.vocab_size
        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)

        self.head = _build_cls_head(
            vision_cfg.width,
            embed_dim,
            clshead_cfg=text_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        # Running token-frequency and sample-count statistics, accumulated in
        # float64 for numerical stability over long training runs.
        self.register_buffer("cap_fq", torch.zeros([1, self.vocab_size], dtype=torch.float64))
        self.register_buffer("num_samples", torch.zeros([1, 1], dtype=torch.float64))

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)

    def encode_image(self, images, normalize=False, return_all=False):
        image_features = self.visual(images)
        image_features, logits = self.head(image_features)
        image_features = F.normalize(image_features, dim=-1) if normalize else image_features
        if return_all:
            return image_features, logits
        return image_features

    def encode_text(self, text, normalize=False):
        features = self.text(text)
        return F.normalize(features, dim=-1) if normalize else features
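
    # Shape sketch for the encoders above (assumed, not enforced here):
    #   encode_image(imgs)                  -> (B, embed_dim)
    #   encode_image(imgs, return_all=True) -> ((B, embed_dim), (B, vocab_size))
    #   encode_text(token_ids)              -> (B, embed_dim)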
    def forward(self, image=None, text=None):
        # A (features, logits) pair when an image is given, matching return_all=True.
        image_features = self.encode_image(image, normalize=True, return_all=True) if image is not None else None

        text_features = self.encode_text(text, normalize=True) if text is not None else None
        # The token ids double as classification targets for the head's logits.
        labels = text.clone() if text is not None else None

        return {
            "cap_fq": self.cap_fq,
            "num_samples": self.num_samples,
            "image_features": image_features,
            "text_features": text_features,
            "labels": labels,
            "logit_scale": self.logit_scale.exp(),
        }
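

# Minimal end-to-end sketch (hedged): config fields mirror open_clip-style model
# configs, and all values below are illustrative assumptions:
#   model = Classifier(
#       embed_dim=512,
#       text_cfg={"context_length": 77, "vocab_size": 49408, "width": 512,
#                 "heads": 8, "layers": 12, "cls_layers": 1, "cls_mlp_ratio": 4},
#       vision_cfg={"image_size": 224, "patch_size": 16, "width": 768, "layers": 12},
#   )
#   out = model(image=images, text=token_ids)  # dict with features, labels, logit_scale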