File size: 2,568 Bytes
5284c75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import torch.nn as nn
import torch.nn.functional as F


class RanTSelecor(nn.Module):
    """Keep a fixed number of visual tokens per frame, chosen by learned scores.

    Per frame, each token is projected to a single scalar by ``pre_proj``.
    The row of scalars is mixed by a learned ``in_token_num x in_token_num``
    matrix plus a bias. The ``out_token_num`` highest-scoring token positions
    are kept, in their original relative order.

    Args:
        num_features: Channel size of each input token (input size of
            ``pre_proj``).
        in_token_num: Number of tokens per frame in the input.
        out_token_num: Number of tokens kept per frame.
        min_out_token_num: Stored on the instance; not used by ``forward``.
        max_video_tokens: Stored on the instance; not used by ``forward``.
        fix_random: Stored on the instance; not used by ``forward``.
        **kwargs: Extra config keys; accepted and ignored.
    """

    def __init__(
        self,
        num_features=1408,
        in_token_num=576,
        out_token_num=144,
        min_out_token_num=16,
        max_video_tokens=2048,
        fix_random=False,
        **kwargs,
    ):
        super().__init__()
        self.in_token_num = in_token_num
        self.out_token_num = out_token_num
        self.min_out_token_num = min_out_token_num
        self.max_video_tokens = max_video_tokens
        self.hidden_size = num_features
        self.fix_random = fix_random

        # Token -> one scalar score per token.
        self.pre_proj = nn.Linear(num_features, 1)
        # NOTE(review): registered (keeps checkpoints loadable) but not used
        # by forward().
        self.score_proj = nn.Linear(in_token_num, out_token_num)

        # With a zero weight, the initial logits equal trans_bias, so the
        # initial selection is exactly the positions set by _init_bias().
        self.trans_weight = nn.Parameter(
            torch.zeros(in_token_num, in_token_num), requires_grad=True
        )
        self.trans_bias = nn.Parameter(self._init_bias(), requires_grad=True)

    def _init_bias(self):
        """Return the initial ``trans_bias`` vector of length ``in_token_num``.

        ``min(out_token_num, in_token_num)`` entries are set to 0.02 and the
        rest stay 0. Positions are visited with stride
        ``s = max(4, in_token_num // out_token_num)``: first offset 0
        (0, s, 2s, ...), then offset 1, and so on, until enough entries
        are set. When nothing is to be selected, an all-zero vector is
        returned.
        """
        tensor = torch.zeros(self.in_token_num)
        n = tensor.numel()
        m = min(self.out_token_num, n)
        if m <= 0:
            # Avoids ZeroDivisionError from ``n // m`` below.
            return tensor

        interval = max(4, n // m)
        count = 0
        for offset in range(interval):
            for i in range(offset, n, interval):
                tensor[i] = 0.02
                count += 1
                if count >= m:
                    return tensor
        return tensor

    def forward(self, image_embeds, n_frames, noise_epsilon=0.001):
        """Keep the ``out_token_num`` top-scoring tokens of every frame.

        Args:
            image_embeds: Frame features of all videos concatenated along
                dim 0. Expected shape is
                ``(sum(n_frames), in_token_num, num_features)``.
            n_frames: Passed to ``Tensor.split`` along dim 0 (an int chunk
                size or a list of per-video frame counts).
            noise_epsilon: Accepted for interface compatibility; unused.

        Returns:
            A list with one tensor per video. Each tensor has shape
            ``(frames, min(out_token_num, in_token_num), num_features)``.
            Kept tokens preserve their original relative order.
        """
        ret_tokens = []
        for video_raw_token in image_embeds.split(n_frames, dim=0):
            # (T, N, C) -> (T, N, 1) -> (T, 1, N): one scalar per token.
            video_raw_token_trans = self.pre_proj(video_raw_token).mT
            video_token_logits = (
                video_raw_token_trans @ self.trans_weight + self.trans_bias
            ).squeeze(1)  # (T, N)

            # Softmax over the token axis. For this 2-D tensor, dim=-1 is the
            # same axis the deprecated implicit choice picks, without the
            # warning.
            video_token_scores = F.softmax(video_token_logits, dim=-1)
            # NOTE(review): argsort is non-differentiable, so pre_proj,
            # trans_weight and trans_bias receive no gradient through this
            # output -- confirm this is intended.
            topk_indices = torch.argsort(
                video_token_scores, dim=-1, descending=True
            )[:, : self.out_token_num]
            # Restore the original token order among the kept tokens.
            topk_indices, _ = torch.sort(topk_indices, dim=-1)

            # Frame index on the same device as topk_indices so the advanced
            # index below never mixes CPU and accelerator index tensors.
            frame_idx = torch.arange(
                video_raw_token.size(0), device=topk_indices.device
            ).unsqueeze(1)
            ret_tokens.append(video_raw_token[frame_idx, topk_indices])
        return ret_tokens


def build_adapter(config):
    """Construct the token-selector adapter from a config mapping.

    Every key in ``config`` is forwarded as a keyword argument to
    ``RanTSelecor``, and the new instance is returned.
    """
    adapter = RanTSelecor(**config)
    return adapter