from typing import Optional

import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from dataclasses import dataclass

from .transformer import (
    LayerNormFp32,
    LayerNorm,
    QuickGELU,
    MultimodalTransformer,
    MixClsHead,
)
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower


@dataclass
class ClassHeadCfg(CLIPTextCfg):
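    """CLIPTextCfg extended with the depth and MLP width ratio of the classification head."""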
    cls_mlp_ratio: int = 4
    cls_layers: int = 1


def _build_cls_head(
        width,
        embed_dim,
        clshead_cfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
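    """Build the MixClsHead: maps `width`-dim vision features to an `embed_dim`
    embedding together with logits over the text vocabulary
    (output_dim=clshead_cfg.vocab_size)."""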
    clshead_cfg = ClassHeadCfg(**clshead_cfg) if isinstance(clshead_cfg, dict) else clshead_cfg
    act_layer = QuickGELU if quick_gelu else nn.GELU
    norm_layer = (
        LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
    )

    head = MixClsHead(
        width=width,
        embed_dim=embed_dim,
        layers=clshead_cfg.cls_layers,
        mlp_ratio=clshead_cfg.cls_mlp_ratio,
        act_layer=act_layer,
        norm_layer=norm_layer,
        output_dim=clshead_cfg.vocab_size,
    )

    return head


class Classifier(nn.Module):
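    """CLIP-style two-tower model whose vision tower feeds a MixClsHead that
    returns an embedding alongside vocabulary logits; forward() only gathers
    the tensors a downstream loss needs rather than computing a loss itself.
    """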
    def __init__(
            self,
            embed_dim,
            text_cfg: CLIPTextCfg,
            vision_cfg: CLIPVisionCfg,
            init_logit_scale: float = np.log(1 / 0.07),
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        text_cfg = ClassHeadCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
        vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg

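        # The vision tower is built with embed_dim 0; the MixClsHead below maps
        # its width-dim features to embed_dim, while the text tower projects to
        # embed_dim directly.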
        self.visual = _build_vision_tower(0, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.context_length = self.text.context_length
        self.vocab_size = self.text.vocab_size
        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)

        self.head = _build_cls_head(
            vision_cfg.width,
            embed_dim,
            clshead_cfg=text_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

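        # Running statistics returned by forward(): one float64 slot per
        # vocabulary token (cap_fq) plus a scalar sample counter.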
        self.register_buffer("cap_fq", torch.zeros([1, self.vocab_size], dtype=torch.float64))
        self.register_buffer("num_samples", torch.zeros([1, 1], dtype=torch.float64))

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)
        # self.text_decoder.set_grad_checkpointing(enable)

    def encode_image(self, images, normalize=False, return_all=False):
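        # The head returns (projected features, vocabulary logits); only the
        # features are optionally normalized, and the logits are exposed when
        # return_all=True.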
        image_features = self.visual(images)
        image_features, logits = self.head(image_features)
        image_features = F.normalize(image_features, dim=-1) if normalize else image_features
        if return_all:
            return image_features, logits
        return image_features

    def encode_text(self, text, normalize=False):
        features = self.text(text)
        return F.normalize(features, dim=-1) if normalize else features

    def forward(self, image=None, text=None):
        # encode_image is called with return_all=True, so image_features is a
        # (features, logits) tuple whenever an image is provided.
        image_features = self.encode_image(image, normalize=True, return_all=True) if image is not None else None
        text_features = self.encode_text(text, normalize=True) if text is not None else None
        # Guard against text=None so the text-dependent fields stay consistent.
        labels = text.clone() if text is not None else None

        return {
            "cap_fq": self.cap_fq,
            "num_samples": self.num_samples,
            "image_features": image_features,
            "text_features": text_features,
            "labels": labels,
            "logit_scale": self.logit_scale.exp(),
        }