|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
class RanTSelecor(nn.Module):
    """Keep a fixed number of tokens per frame, ranked by a learned score.

    Each frame's tokens are projected to one scalar per token, mixed through
    a learnable ``(in_token_num, in_token_num)`` matrix plus a bias, and the
    ``out_token_num`` highest-scoring token positions are kept in their
    original positional order.

    ``min_out_token_num``, ``max_video_tokens`` and ``fix_random`` are stored
    on the instance, and ``score_proj`` is registered as a sub-module, but
    none of them is read by ``forward`` in this class. They are kept so the
    constructor signature and the parameter names in the state dict stay
    unchanged.
    """

    def __init__(
        self,
        num_features=1408,
        in_token_num=576,
        out_token_num=144,
        min_out_token_num=16,
        max_video_tokens=2048,
        fix_random=False,
        **kwargs,
    ):
        """Build the selector.

        Args:
            num_features: Size of the last dimension of the input tokens.
            in_token_num: Number of tokens per frame expected by ``forward``.
            out_token_num: Number of tokens kept per frame.
            min_out_token_num: Stored only; not used by this class.
            max_video_tokens: Stored only; not used by this class.
            fix_random: Stored only; not used by this class.
            **kwargs: Accepted and ignored, so a config dict carrying extra
                keys can be splatted into the constructor.
        """
        super().__init__()
        self.in_token_num = in_token_num
        self.out_token_num = out_token_num
        self.min_out_token_num = min_out_token_num
        self.max_video_tokens = max_video_tokens
        self.hidden_size = num_features
        self.fix_random = fix_random

        # Creation order of the two Linear layers is kept as-is so random
        # initialisation consumes the RNG in the same sequence as before.
        self.pre_proj = nn.Linear(num_features, 1)
        self.score_proj = nn.Linear(in_token_num, out_token_num)

        # The mixing matrix starts at zero, so the initial logits equal the
        # bias and the initial selection is the pattern from `_init_bias`.
        self.trans_weight = nn.Parameter(
            torch.zeros(in_token_num, in_token_num), requires_grad=True
        )
        self.trans_bias = nn.Parameter(self._init_bias(), requires_grad=True)

    def _init_bias(self):
        """Return the initial bias: 0.02 at the initially selected positions.

        Positions are visited in strided order (offset 0 with stride
        ``interval``, then offset 1, and so on) and the first
        ``min(out_token_num, in_token_num)`` visited positions are set to
        0.02; all others stay 0.

        Returns:
            A 1-D float tensor of length ``in_token_num``.
        """
        bias = torch.zeros(self.in_token_num)
        total = bias.numel()
        keep = min(self.out_token_num, total)
        if keep <= 0:
            # Nothing to mark; also avoids dividing by zero below.
            return bias

        interval = max(4, total // keep)
        visit_order = [
            pos
            for offset in range(interval)
            for pos in range(offset, total, interval)
        ]
        marked = torch.tensor(visit_order[:keep], dtype=torch.long)
        bias[marked] = 0.02
        return bias

    def forward(self, image_embeds, n_frames, noise_epsilon=0.001):
        """Select the top-scoring tokens of every frame, video by video.

        Args:
            image_embeds: Frame tokens of all videos concatenated along
                dim 0. The operations below require a 3-D tensor of shape
                ``(total_frames, in_token_num, num_features)``.
            n_frames: Split sizes passed to ``Tensor.split`` along dim 0
                (per-video frame counts, or a single chunk size).
            noise_epsilon: Unused; kept for interface compatibility.

        Returns:
            A list with one tensor per video, each of shape
            ``(frames_in_video, k, num_features)`` where
            ``k = min(out_token_num, in_token_num)``.
        """
        selected_per_video = []
        for video_raw_token in image_embeds.split(n_frames, dim=0):
            # (frames, tokens, features) -> (frames, tokens, 1) -> (frames, 1, tokens)
            per_token_scalar = self.pre_proj(video_raw_token).mT
            video_token_logits = (
                per_token_scalar @ self.trans_weight + self.trans_bias
            )
            # Drop the singleton dimension: (frames, tokens).
            video_token_logits = video_token_logits.squeeze(1)

            # Explicit dim: normalise over the token axis of each frame.
            video_token_scores = F.softmax(video_token_logits, dim=-1)

            ranked = torch.argsort(video_token_scores, descending=True)
            topk_indices = ranked[:, : self.out_token_num]
            # Restore positional order among the kept tokens.
            topk_indices, _ = torch.sort(topk_indices)

            # Row indices are created on the input's device so the gather
            # does not mix index tensors from different devices.
            frame_index = torch.arange(
                video_raw_token.size(0), device=video_raw_token.device
            ).unsqueeze(1)
            selected_per_video.append(video_raw_token[frame_index, topk_indices])
        return selected_per_video
|
|
|
|
|
|
|
|
def build_adapter(config):
    """Create a ``RanTSelecor`` from a mapping of constructor keyword arguments.

    Args:
        config: Mapping whose items are passed to ``RanTSelecor`` as
            keyword arguments.

    Returns:
        The constructed ``RanTSelecor`` instance.
    """
    adapter = RanTSelecor(**config)
    return adapter
|
|
|