# open_clip/cls_model.py
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from .transformer import (
    LayerNormFp32,
    LayerNorm,
    QuickGELU,
    MultimodalTransformer,
    MixClsHead,
)
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower


@dataclass
class ClassHeadCfg(CLIPTextCfg):
    # Classification-head options layered on top of the standard CLIP text config.
    cls_mlp_ratio: int = 4
    cls_layers: int = 1


def _build_cls_head(
    width,
    embed_dim,
    clshead_cfg,
    quick_gelu: bool = False,
    cast_dtype: Optional[torch.dtype] = None,
):
    """Build the MixClsHead that maps visual-tower features of `width` to
    `embed_dim` features plus per-token vocabulary logits."""
    clshead_cfg = ClassHeadCfg(**clshead_cfg) if isinstance(clshead_cfg, dict) else clshead_cfg
    act_layer = QuickGELU if quick_gelu else nn.GELU
    # Run LayerNorm in fp32 when the rest of the model is half precision.
    norm_layer = (
        LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
    )
    head = MixClsHead(
        width=width,
        embed_dim=embed_dim,
        layers=clshead_cfg.cls_layers,
        mlp_ratio=clshead_cfg.cls_mlp_ratio,
        act_layer=act_layer,
        norm_layer=norm_layer,
        output_dim=clshead_cfg.vocab_size,  # one logit per vocabulary token
    )
    return head
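

# Example (hypothetical numbers, not from this file): with a ViT-B/16-style
# vision tower (width=768) and the default 49408-token vocabulary, the head
# projects 768-dim pooled features to `embed_dim` and emits 49408 logits.
#
#     head = _build_cls_head(width=768, embed_dim=512,
#                            clshead_cfg=ClassHeadCfg(vocab_size=49408))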


class Classifier(nn.Module):
    def __init__(
        self,
        embed_dim,
        text_cfg: CLIPTextCfg,
        vision_cfg: CLIPVisionCfg,
        init_logit_scale: float = np.log(1 / 0.07),
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        text_cfg = ClassHeadCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
        vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg

        # The vision tower is built with embed_dim=0; its native-width features
        # are projected to embed_dim by the classification head below instead.
        self.visual = _build_vision_tower(0, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.context_length = self.text.context_length
        self.vocab_size = self.text.vocab_size
        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)

        self.head = _build_cls_head(
            vision_cfg.width,
            embed_dim,
            clshead_cfg=text_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        # Token-frequency statistics (`cap_fq`) and running sample count,
        # registered as buffers so they persist in checkpoints without being trained.
        self.register_buffer("cap_fq", torch.zeros([1, self.vocab_size], dtype=torch.float64))
        self.register_buffer("num_samples", torch.zeros([1, 1], dtype=torch.float64))

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)
        # self.text_decoder.set_grad_checkpointing(enable)

    def encode_image(self, images, normalize=False, return_all=False):
        image_features = self.visual(images)
        # The head returns projected features together with vocabulary logits.
        image_features, logits = self.head(image_features)
        image_features = F.normalize(image_features, dim=-1) if normalize else image_features
        if return_all:
            return image_features, logits
        return image_features

    def encode_text(self, text, normalize=False):
        features = self.text(text)
        return F.normalize(features, dim=-1) if normalize else features

    def forward(self, image=None, text=None):
        # With return_all=True, image_features is a (features, logits) tuple.
        image_features = self.encode_image(image, normalize=True, return_all=True) if image is not None else None
        text_features = self.encode_text(text, normalize=True) if text is not None else None
        # Guard the clone so calls without text do not crash on text.clone().
        labels = text.clone() if text is not None else None
        return {
            "cap_fq": self.cap_fq,
            "num_samples": self.num_samples,
            "image_features": image_features,
            "text_features": text_features,
            "labels": labels,
            "logit_scale": self.logit_scale.exp(),
        }
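

# Minimal usage sketch (an assumption for illustration; the cfg field names
# follow open_clip's CLIPTextCfg/CLIPVisionCfg defaults and may differ in this
# fork):
#
#     model = Classifier(
#         embed_dim=512,
#         text_cfg=dict(context_length=77, vocab_size=49408, width=512,
#                       heads=8, layers=12, cls_layers=1, cls_mlp_ratio=4),
#         vision_cfg=dict(image_size=224, patch_size=16, width=768, layers=12),
#     )
#     out = model(image=torch.randn(2, 3, 224, 224),
#                 text=torch.randint(0, 49408, (2, 77)))
#     feats, logits = out["image_features"]  # tuple from encode_image(return_all=True)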