# omnivinci/media_encoder.py
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from math import pi
from typing import Any, Dict, List, Optional, Tuple
import math
import numpy as np
import torch
from beartype import beartype
from beartype.typing import Literal
from einops import rearrange, repeat
from torch import Tensor, broadcast_tensors, einsum, nn
from torch.nn import Module
class CacheFeatures(object):
    """Lightweight wrapper around pre-computed (cached) media features."""
    def __init__(self, value, type):
        self.value = value  # dict expected to hold a 'features' tensor plus metadata
        self.type = type    # media type tag for the cached entry
    def my_to(self, device, dtype):
        # Move the cached feature tensor, if present, to the target device/dtype.
        if self.value.get('features') is not None:
            self.value['features'] = self.value['features'].to(device, dtype)
        return self
    def __call__(self):
        return self.value
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
# broadcat, as tortoise-tts was using it
def broadcat(tensors, dim = -1):
    broadcasted_tensors = broadcast_tensors(*tensors)
    return torch.cat(broadcasted_tensors, dim = dim)
def pool(x: torch.Tensor, size: int, dim: int) -> torch.Tensor:
    # Group elements along `dim` into chunks of `size` and average each chunk.
    # If the dimension is not divisible by the pool size, pad with mean values first.
    if x.shape[dim] % size != 0:
        print(f"Warning: dimension {dim} with size {x.shape[dim]} is not divisible by pool size {size}, padding with mean values")
remainder = x.shape[dim] % size
pad_len = size - remainder
# Get the mean of the last few elements along the dimension to be pooled
last_elements = x.narrow(dim, x.shape[dim] - remainder, remainder)
mean_value = last_elements.mean()
# Create padding tensor with the same shape as x except for the dimension being pooled
pad_shape = list(x.shape)
pad_shape[dim] = pad_len
padding = torch.ones(pad_shape, device=x.device, dtype=x.dtype) * mean_value
# Concatenate the original tensor with the padding along the specified dimension
x = torch.cat([x, padding], dim=dim)
shape_before = x.shape[:dim]
shape_after = x.shape[dim + 1 :]
new_shape = shape_before + (-1, size) + shape_after
x_reshaped = x.view(new_shape)
return x_reshaped.mean(dim + 1)
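# A minimal usage sketch for `pool` (illustrative only, not part of the original pipeline).
# Pooling 10 tokens with size 4 pads the trailing remainder of 2 with its own mean, so the
# output has ceil(10 / 4) = 3 chunks. Defined but never called, so importing stays side-effect free.
def _pool_usage_example():
    x = torch.arange(10, dtype=torch.float32).view(1, 10, 1)  # [batch=1, tokens=10, dim=1]
    y = pool(x, size=4, dim=1)  # emits the divisibility warning, then mean-pools each chunk
    return y.shape  # torch.Size([1, 3, 1])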
def rotate_half(x):
x = rearrange(x, '... (d r) -> ... d r', r = 2)
x1, x2 = x.unbind(dim = -1)
x = torch.stack((-x2, x1), dim = -1)
return rearrange(x, '... d r -> ... (d r)')
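# Worked example for `rotate_half` (illustrative): interleaved pairs (x1, x2) map to
# (-x2, x1), i.e. each 2D pair is rotated by 90 degrees.
def _rotate_half_example():
    x = torch.tensor([1., 2., 3., 4.])
    return rotate_half(x)  # tensor([-2., 1., -4., 3.])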
def apply_rotary_emb(freqs, t, start_index = 0, scale = 1., seq_dim = -2):
with torch.amp.autocast(device_type='cuda', enabled=False):
ori_dtype = t.dtype
embed_dtype = torch.float64
t = t.to(embed_dtype)
if t.ndim == 3:
seq_len = t.shape[seq_dim]
freqs = freqs[-seq_len:].to(t)
rot_dim = freqs.shape[-1]
end_index = start_index + rot_dim
assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
return torch.cat((t_left, t, t_right), dim = -1).to(ori_dtype)
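# Shape-level sketch of `apply_rotary_emb` (illustrative, made-up sizes): only the first
# rot_dim feature channels are rotated; the remaining channels pass through unchanged.
def _apply_rotary_emb_example():
    t = torch.randn(2, 8, 64)         # [batch, seq, dim]
    freqs = torch.randn(8, 32)        # rot_dim = 32 <= 64
    out = apply_rotary_emb(freqs, t)  # channels 32:64 are returned untouched
    return out.shape                  # torch.Size([2, 8, 64])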
class MaxTimeContinuousTimeRotaryEmbedding(nn.Module):
def __init__(self, dim, max_time, period_mode="shortest", device=None):
super().__init__()
assert dim % 2 == 0, "RoPE embedding dimension must be even"
        # Tie the rotation periods to max_time
        if period_mode == "shortest":  # shortest period (the fastest frequency) is max_time
            base = 5
            inv_freq = 2 * math.pi / (max_time * (base ** (torch.arange(0, dim // 2).float() / (dim // 2))))
        elif period_mode == "longest":  # longest period (the slowest frequency) is max_time
            theta = max_time ** ((dim // 2) / (dim // 2 - 1))
            inv_freq = 2 * math.pi / (theta ** (torch.arange(0, dim // 2).float() / (dim // 2)))
else:
raise ValueError(f"Invalid period mode: {period_mode}")
self.register_buffer("inv_freq", inv_freq, persistent=False)
    def forward(self, time_values: torch.Tensor):
        """
        time_values: [batch_size, seq_len], in seconds (or any continuous unit)
        Returns:
            freqs: [batch_size, seq_len, dim // 2] rotation angles for apply_rotary_emb
        """
        batch_size, seq_len = time_values.shape
        time_values_exp = time_values[:, None, :]  # [batch, 1, seq_len]
        freqs = (self.inv_freq[None, :, None] @ time_values_exp).transpose(1, 2)  # [batch, seq_len, dim//2]
        return freqs
    def get_axial_freqs(self, *dims):
        # Axial variant mirroring RotaryEmbedding.get_axial_freqs; positions are integer indices.
        Colon = slice(None)
        all_freqs = []
        for ind, dim in enumerate(dims):
            pos = torch.arange(dim, device=self.inv_freq.device, dtype=torch.float32)
            freqs = self.forward(pos[None, :]).squeeze(0)  # [dim, self.dim // 2]
            all_axis = [None] * len(dims)
            all_axis[ind] = Colon
            new_axis_slice = (Ellipsis, *all_axis, Colon)
            all_freqs.append(freqs[new_axis_slice])
        all_freqs = broadcast_tensors(*all_freqs)
        return torch.cat(all_freqs, dim = -1)
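# Usage sketch for MaxTimeContinuousTimeRotaryEmbedding (illustrative sizes): continuous
# timestamps in, per-position rotation angles out.
def _mtct_usage_example():
    rope = MaxTimeContinuousTimeRotaryEmbedding(dim=8, max_time=30.0)
    times = torch.tensor([[0.0, 0.5, 1.0, 1.5]])  # [batch=1, seq_len=4], seconds
    freqs = rope(times)
    return freqs.shape  # torch.Size([1, 4, 4]) == [batch, seq_len, dim // 2]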
class RotaryEmbedding(Module):
@beartype
def __init__(
self,
dim,
custom_freqs: Optional[Tensor] = None,
        freqs_for: Literal['lang', 'pixel', 'constant'] = 'lang',
theta = 10000,
max_freq = 10,
num_freqs = 1,
learned_freq = False,
use_xpos = False,
xpos_scale_base = 512,
interpolate_factor = 1.,
theta_rescale_factor = 1.,
seq_before_head_dim = False,
cache_if_possible = True,
max_time = None
):
super().__init__()
self.dim = dim
self.freqs_for = freqs_for
self.max_freq = max_freq
self.num_freqs = num_freqs
self.learned_freq = learned_freq
self.use_xpos = use_xpos
self.xpos_scale_base = xpos_scale_base
self.interpolate_factor = interpolate_factor
self.theta_rescale_factor = theta_rescale_factor
self.cache_if_possible = cache_if_possible
self.max_time = max_time
self.tmp_store('cached_freqs', None)
self.tmp_store('cached_scales', None)
        # Adjust theta to avoid angle wrapping after large times
        if exists(max_time) and freqs_for == 'lang':
            # Make sure the lowest frequency (approximately 1/theta) completes one full
            # rotation over max_time: max_time * (1/theta) = 2*pi  =>  theta = max_time / (2*pi)
            theta = max_time / (2 * pi)
        theta *= theta_rescale_factor ** (dim / (dim - 2))
self.theta = theta
if exists(custom_freqs):
freqs = custom_freqs
elif freqs_for == 'lang':
freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
elif freqs_for == 'pixel':
freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
elif freqs_for == 'constant':
freqs = torch.ones(num_freqs).float()
self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)
self.learned_freq = learned_freq
# dummy for device
self.tmp_store('dummy', torch.tensor(0))
# default sequence dimension
self.seq_before_head_dim = seq_before_head_dim
self.default_seq_dim = -3 if seq_before_head_dim else -2
# interpolation factors
assert interpolate_factor >= 1.
self.interpolate_factor = interpolate_factor
        # add apply_rotary_emb as static method, before the xpos early-return below so the
        # attribute exists in both cases (staticmethod objects are callable on Python >= 3.10)
        self.apply_rotary_emb = staticmethod(apply_rotary_emb)
        # xpos
        if not use_xpos:
            self.tmp_store('scale', None)
            return
        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.scale_base = xpos_scale_base
        self.tmp_store('scale', scale)
@property
def device(self):
return self.dummy.device
def tmp_store(self, key, value):
self.register_buffer(key, value, persistent = False)
def get_seq_pos(self, seq_len, device, dtype, offset = 0):
return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor
def rotate_queries_or_keys(self, t, seq_dim = None, offset = 0):
seq_dim = default(seq_dim, self.default_seq_dim)
assert not self.use_xpos, 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings'
device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
freqs = self.forward(self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset), seq_len = seq_len, offset = offset)
if seq_dim == -3:
freqs = rearrange(freqs, 'n d -> n 1 d')
return apply_rotary_emb(freqs, t, seq_dim = seq_dim)
def rotate_queries_with_cached_keys(self, q, k, seq_dim = None, offset = 0):
seq_dim = default(seq_dim, self.default_seq_dim)
q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
assert q_len <= k_len
rotated_q = self.rotate_queries_or_keys(q, seq_dim = seq_dim, offset = k_len - q_len + offset)
rotated_k = self.rotate_queries_or_keys(k, seq_dim = seq_dim, offset = offset)
rotated_q = rotated_q.type(q.dtype)
rotated_k = rotated_k.type(k.dtype)
return rotated_q, rotated_k
def rotate_queries_and_keys(self, q, k, seq_dim = None):
seq_dim = default(seq_dim, self.default_seq_dim)
assert self.use_xpos
device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
seq = self.get_seq_pos(seq_len, dtype = dtype, device = device)
freqs = self.forward(seq, seq_len = seq_len)
scale = self.get_scale(seq, seq_len = seq_len).to(dtype)
if seq_dim == -3:
freqs = rearrange(freqs, 'n d -> n 1 d')
scale = rearrange(scale, 'n d -> n 1 d')
rotated_q = apply_rotary_emb(freqs, q, scale = scale, seq_dim = seq_dim)
rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1, seq_dim = seq_dim)
rotated_q = rotated_q.type(q.dtype)
rotated_k = rotated_k.type(k.dtype)
return rotated_q, rotated_k
@beartype
def get_scale(
self,
t: Tensor,
seq_len: Optional[int] = None,
offset = 0
):
assert self.use_xpos
should_cache = (
self.cache_if_possible and
exists(seq_len)
)
if (
should_cache and \
exists(self.cached_scales) and \
(seq_len + offset) <= self.cached_scales.shape[0]
):
return self.cached_scales[offset:(offset + seq_len)]
scale = 1.
if self.use_xpos:
power = (t - len(t) // 2) / self.scale_base
scale = self.scale ** rearrange(power, 'n -> n 1')
scale = torch.cat((scale, scale), dim = -1)
if should_cache:
self.tmp_store('cached_scales', scale)
return scale
def get_axial_freqs(self, *dims):
Colon = slice(None)
all_freqs = []
for ind, dim in enumerate(dims):
if self.freqs_for == 'pixel':
pos = torch.linspace(-1, 1, steps = dim, device = self.device)
else:
pos = torch.arange(dim, device = self.device)
freqs = self.forward(pos, seq_len = dim)
all_axis = [None] * len(dims)
all_axis[ind] = Colon
new_axis_slice = (Ellipsis, *all_axis, Colon)
all_freqs.append(freqs[new_axis_slice])
all_freqs = broadcast_tensors(*all_freqs)
return torch.cat(all_freqs, dim = -1)
def forward(
self,
t: Tensor,
seq_len = None,
offset = 0
):
should_cache = (
self.cache_if_possible and \
not self.learned_freq and \
exists(seq_len) and \
self.freqs_for != 'pixel'
)
if (
should_cache and \
exists(self.cached_freqs) and \
(offset + seq_len) <= self.cached_freqs.shape[0]
):
return self.cached_freqs[offset:(offset + seq_len)].detach()
freqs = self.freqs
# Scale time to keep t * freq <= 2pi
if hasattr(self, 'max_time') and self.max_time is not None:
t = t / self.max_time * (2 * pi)
freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
if should_cache:
self.tmp_store('cached_freqs', freqs.detach())
return freqs
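# End-to-end usage sketch for RotaryEmbedding (illustrative sizes, default settings):
# rotate queries and keys before attention.
def _rotary_embedding_usage_example():
    rope = RotaryEmbedding(dim=32)  # rotates the first 32 channels of each head
    q = torch.randn(1, 8, 128, 64)  # [batch, heads, seq, head_dim]
    k = torch.randn(1, 8, 128, 64)
    q = rope.rotate_queries_or_keys(q)
    k = rope.rotate_queries_or_keys(k)
    return q.shape, k.shape  # both torch.Size([1, 8, 128, 64])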
class BaseEncoder(nn.Module):
    def __init__(self, parent: nn.Module) -> None:
        super().__init__()
        # Hold the parent in a plain list so it is not registered as a child module.
        self._parent = [parent]
@property
def parent(self) -> nn.Module:
return self._parent[0]
class BasicImageEncoder(BaseEncoder):
def __init__(
self,
parent: torch.nn.Module,
start_tokens: Optional[str] = None,
end_tokens: Optional[str] = "\n",
) -> None:
super().__init__(parent)
end_tokens = None if end_tokens == "None" else end_tokens
self.start_tokens = start_tokens
self.end_tokens = end_tokens
def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
if tokens is None:
return None
token_ids = self.parent.tokenizer(tokens).input_ids
token_ids = torch.tensor(token_ids, device=self.parent.device)
return self.parent.llm_model_embed_tokens(token_ids)
def _process_features(
self,
features: torch.Tensor,
start_token_embeds: Optional[torch.Tensor],
end_token_embeds: Optional[torch.Tensor],
) -> torch.Tensor:
if start_token_embeds is not None:
features = torch.cat([start_token_embeds, features], dim=0)
if end_token_embeds is not None:
features = torch.cat([features, end_token_embeds], dim=0)
return features
def forward(self, images: List[torch.Tensor], config: Dict[str, Any], mm_info: dict) -> List[torch.Tensor]:
images = torch.stack(images, dim=0)
features = self.parent.encode_images(images, block_sizes=config.get("block_sizes"))
process_features = partial(
self._process_features,
start_token_embeds=self.embed_tokens(self.start_tokens),
end_token_embeds=self.embed_tokens(self.end_tokens),
)
return [process_features(f) for f in features]
class BasicVideoEncoder(BaseEncoder):
def __init__(
self,
parent: torch.nn.Module,
start_tokens: Optional[str] = None,
end_tokens: Optional[str] = "\n",
) -> None:
super().__init__(parent)
end_tokens = None if end_tokens == "None" else end_tokens
self.start_tokens = start_tokens
self.end_tokens = end_tokens
def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
if tokens is None:
return None
token_ids = self.parent.tokenizer(tokens).input_ids
token_ids = torch.tensor(token_ids, device=self.parent.device)
return self.parent.llm_model_embed_tokens(token_ids)
def _process_features(
self,
features: torch.Tensor,
start_token_embeds: Optional[torch.Tensor],
end_token_embeds: Optional[torch.Tensor],
) -> torch.Tensor:
if start_token_embeds is not None:
start_embeds = torch.stack([start_token_embeds] * features.shape[0], dim=0)
features = torch.cat([start_embeds, features], dim=1)
if end_token_embeds is not None:
end_embeds = torch.stack([end_token_embeds] * features.shape[0], dim=0)
features = torch.cat([features, end_embeds], dim=1)
return features.flatten(0, 1)
def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
num_frames = [video.shape[0] for video in videos]
images = torch.cat(videos, dim=0)
features = self.parent.encode_images(images)
features = torch.split(features, num_frames)
process_features = partial(
self._process_features,
start_token_embeds=self.embed_tokens(self.start_tokens),
end_token_embeds=self.embed_tokens(self.end_tokens),
)
return [process_features(f) for f in features]
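# Shape-level sketch of BasicVideoEncoder._process_features (illustrative sizes): a
# start/end token embedding is prepended/appended per frame, then frames are flattened
# into one token sequence.
def _video_wrap_example():
    features = torch.randn(4, 196, 1024)  # [frames, tokens, dim]
    end = torch.randn(1, 1024)            # e.g. the "\n" token embedding
    ends = torch.stack([end] * features.shape[0], dim=0)    # [4, 1, 1024]
    out = torch.cat([features, ends], dim=1).flatten(0, 1)  # [4 * 197, 1024]
    return out.shape  # torch.Size([788, 1024])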
class BasicSoundEncoder(BaseEncoder):
def __init__(
self,
parent: torch.nn.Module,
start_tokens: Optional[str] = None,
end_tokens: Optional[str] = "\n",
embed_time = "True",
trope_theta = 50000,
trope_dim = 128,
max_time = None,
time_embed_type = "pixel",
period_fix = False,
) -> None:
super().__init__(parent)
        end_tokens = None if end_tokens == "None" else end_tokens
        self.start_tokens = start_tokens
        self.end_tokens = end_tokens
        # embed_time may arrive as a bool or as the strings "True"/"False"
        self.embed_time = embed_time not in ("False", False)
self.time_embed_type = time_embed_type
period_mode = None
        if isinstance(period_fix, str):
if period_fix == "shortest":
period_fix = "MTCT"
period_mode = "shortest"
elif period_fix == "longest":
period_fix = "MTCT"
period_mode = "longest"
self.period_fix = period_fix
self.max_time = max_time
if period_fix == "MTCT":
if period_mode is None:
self.pos_emb = MaxTimeContinuousTimeRotaryEmbedding(
dim = trope_dim,
max_time = max_time,
)
else:
self.pos_emb = MaxTimeContinuousTimeRotaryEmbedding(
dim = trope_dim,
max_time = max_time,
period_mode = period_mode,
)
elif time_embed_type in ["pixel", "lang"]:
if trope_dim is None and max_time is None:
raise ValueError("trope_dim or max_time is required when embed_time is True")
self.pos_emb = RotaryEmbedding(
dim = trope_dim,
freqs_for = time_embed_type,
max_freq = 256,
max_time = max_time,
)
elif time_embed_type == "learned_embed":
self.time_embed = parent.sound_mm_projector.time_embed
else:
raise ValueError(f"Invalid time_embed_type: {time_embed_type}")
def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
if tokens is None:
return None
token_ids = self.parent.tokenizer(tokens).input_ids
token_ids = torch.tensor(token_ids, device=self.parent.device)
        return self.parent.llm_model_embed_tokens(token_ids)
def _process_features(
self,
features: torch.Tensor,
start_token_embeds: Optional[torch.Tensor],
end_token_embeds: Optional[torch.Tensor],
times: Optional[torch.Tensor] = None,
time_embed: Optional[torch.Tensor] = None,
) -> torch.Tensor:
        features = features.to(self.parent.device)
        device = features.device
        dtype = features.dtype
        if self.embed_time:
# Handle different embedding types
if self.time_embed_type in ["pixel", "lang"]:
times = times.unsqueeze(0)
new_times = times
pos_emb = self.pos_emb.to(device)
if self.period_fix == "True":
if self.max_time is not None:
angle = new_times.to(device) / self.max_time * 2 * np.pi
else:
angle = new_times.to(device)
elif self.period_fix == "MTCT":
freqs = self.pos_emb(new_times.float())
freqs = freqs.squeeze(0)
features = apply_rotary_emb(freqs, features)
else:
angle = (-new_times * 2 * np.pi).to(device)
                if self.period_fix != "MTCT":
freqs = pos_emb.get_axial_freqs(new_times.shape[0], features.shape[-2]).to(device)
angle_expanded = angle.unsqueeze(2)
angle_expanded = angle_expanded.expand(new_times.shape[0], features.shape[-2], freqs.shape[-1])
freqs = freqs * angle_expanded
freqs = freqs.squeeze(0)
                    features = apply_rotary_emb(freqs, features)
elif self.time_embed_type == "learned_embed": # Learned embedding
# Add time embeddings to features
features = features + time_embed
else:
raise ValueError(f"Invalid time_embed_type: {self.time_embed_type}")
if start_token_embeds is not None:
features = torch.cat([start_token_embeds, features], dim=0)
if end_token_embeds is not None:
features = torch.cat([features, end_token_embeds], dim=0)
return features
def forward(self, sounds: List[torch.Tensor], config: Dict[str, Any], mm_info: dict) -> List[torch.Tensor]:
features = self.parent.encode_sound(sounds, mm_info=mm_info)
process_features = partial(
self._process_features,
start_token_embeds=self.embed_tokens(self.start_tokens),
end_token_embeds=self.embed_tokens(self.end_tokens),
)
if self.embed_time:
new_features = []
device = features[0].device
fea_count = len(features)
aud_idx = 0
bs = len(mm_info["audio_info"])
if self.time_embed_type == "learned_embed": # Learned embedding, we need to first collect all times and only do time embedding once
times_list = []
for i in range(bs):
_audio_info = mm_info["audio_info"][i]
if _audio_info is not None:
for j in range(len(_audio_info)):
_feature = features[aud_idx]
if _audio_info[j] == "dummy":
times = torch.zeros(_feature.shape[0], device=device, dtype=_feature.dtype)
else:
audio_chunk_length = _audio_info[j]["new_audio_chunk_length"]
sec_per_embed = audio_chunk_length / _feature.shape[0]
audio_start_sec = _audio_info[j]["audio_start_sec"]
                                times = [audio_start_sec + k * sec_per_embed + sec_per_embed / 2 for k in range(_feature.shape[0])]  # center of each embedding window
times = torch.tensor(times).to(device)
times_list.append(times)
aud_idx += 1
times = torch.stack(times_list, dim=0)
time_embeds = self.time_embed(times, dtype=features[0].dtype)
aud_idx = 0
for i in range(bs):
_audio_info = mm_info["audio_info"][i]
if _audio_info is not None:
for j in range(len(_audio_info)):
                        try:
                            _feature = features[aud_idx]
                        except Exception as e:
                            print(f"Error: {e}. Length of features: {len(features)}. Length of _audio_info: {len(_audio_info)}. aud_idx: {aud_idx}")
                            raise e
if _audio_info[j] == "dummy":
times = torch.zeros(_feature.shape[0], device=device, dtype=_feature.dtype)
else:
audio_chunk_length = _audio_info[j]["new_audio_chunk_length"]
sec_per_embed = audio_chunk_length / _feature.shape[0]
audio_start_sec = _audio_info[j]["audio_start_sec"]
                            times = [audio_start_sec + k * sec_per_embed + sec_per_embed / 2 for k in range(_feature.shape[0])]  # center of each embedding window
times = torch.tensor(times).to(device)
if self.time_embed_type == "learned_embed":
_feature = process_features(_feature, time_embed=time_embeds[aud_idx])
else:
_feature = process_features(_feature, times=times)
new_features.append(_feature)
aud_idx += 1
            assert aud_idx == fea_count, "aud_idx: {}, fea_count: {}".format(aud_idx, fea_count)
features = new_features
else:
features = [process_features(f) for f in features]
return features
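# Worked example of the timestamp arithmetic in BasicSoundEncoder.forward (made-up
# numbers): each embedding is stamped with the center of the audio window it covers.
def _audio_time_example():
    audio_chunk_length = 10.0  # seconds covered by one audio chunk
    n_embeds = 5               # feature rows produced for that chunk
    audio_start_sec = 2.0
    sec_per_embed = audio_chunk_length / n_embeds  # 2.0 s per embedding
    times = [audio_start_sec + k * sec_per_embed + sec_per_embed / 2 for k in range(n_embeds)]
    return times  # [3.0, 5.0, 7.0, 9.0, 11.0]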
class TSPVideoEncoder(BasicVideoEncoder):
def __init__(
self,
parent: torch.nn.Module,
pool_sizes: List[Tuple[int, int, int]],
start_tokens: Optional[str] = None,
end_tokens: Optional[str] = "\n",
sep_tokens: Optional[str] = None,
embed_time: str = "False",
trope_theta = 50000,
trope_dim = 128,
max_time = None,
time_embed_type = "pixel",
period_fix = False,
) -> None:
super().__init__(parent, start_tokens=start_tokens, end_tokens=end_tokens)
self.pool_sizes = pool_sizes
self.sep_tokens = sep_tokens
if embed_time == "False":
self.embed_time = False
else:
self.embed_time = True
self.time_embed_type = time_embed_type
period_mode = None
        if isinstance(period_fix, str):
if period_fix == "shortest":
period_fix = "MTCT"
period_mode = "shortest"
elif period_fix == "longest":
period_fix = "MTCT"
period_mode = "longest"
self.period_fix = period_fix
self.max_time = max_time
if period_fix == "MTCT":
if period_mode is None:
self.pos_emb = MaxTimeContinuousTimeRotaryEmbedding(
dim = trope_dim,
max_time = max_time,
)
else:
self.pos_emb = MaxTimeContinuousTimeRotaryEmbedding(
dim = trope_dim,
max_time = max_time,
period_mode = period_mode,
)
elif time_embed_type in ["pixel", "lang"]:
if trope_dim is None and max_time is None:
raise ValueError("trope_dim or max_time is required when embed_time is True")
if time_embed_type == "lang":
self.pos_emb = RotaryEmbedding(
dim = trope_dim,
freqs_for = 'lang',
theta = trope_theta,
max_time = max_time,
)
elif time_embed_type == "pixel":
self.pos_emb = RotaryEmbedding(
dim = trope_dim,
freqs_for = time_embed_type,
max_freq = 256
)
elif time_embed_type == "learned_embed":
self.time_embed = parent.mm_projector.time_embed
else:
raise ValueError(f"Invalid time_embed_type: {time_embed_type}")
def _process_features(
self,
inputs: torch.Tensor,
start_token_embeds: Optional[torch.Tensor],
end_token_embeds: Optional[torch.Tensor],
sep_token_embeds: Optional[torch.Tensor],
times: Optional[torch.Tensor] = None,
time_embed: Optional[torch.Tensor] = None,
) -> torch.Tensor:
        nt, ns = inputs.shape[:2]
        nl = int(ns ** 0.5)  # assumes a square spatial token grid (ns == nl * nl)
outputs = []
for pool_size in self.pool_sizes:
features = inputs.view(nt, nl, nl, -1)
for dim, p in enumerate(pool_size):
try:
features = pool(features, p, dim=dim)
except Exception as e:
print(f"Error: Pooling failed: {e}")
print(f"inputs.shape: {inputs.shape}, features.shape: {features.shape}, pool_size: {p}, dim: {dim}")
raise e
features = features.flatten(1, 2)
if self.embed_time:
device = features.device
dtype = features.dtype
if self.time_embed_type in ["pixel", "lang"]:
# consider the pooling in self.pool_sizes
temporal_pool_size = pool_size[0]
if temporal_pool_size != 1:
if len(times) % temporal_pool_size != 0:
# pad
print(f"Warning: length of times: {len(times)} is not a multiple of temporal_pool_size: {temporal_pool_size}")
remainder = len(times) % temporal_pool_size
pad_len = temporal_pool_size - remainder
last_window_mean_times = times[-remainder:].mean()
times = torch.cat([times, torch.ones(pad_len).to(times.device) * last_window_mean_times])
new_times = pool(times, temporal_pool_size, 0)
else:
new_times = times
pos_emb = self.pos_emb.to(device)
if self.period_fix == "True":
if self.max_time is not None:
angle = new_times.to(device) / self.max_time * 2 * np.pi
else:
angle = new_times.to(device)
elif self.period_fix == "MTCT":
if new_times.ndim == 1:
new_times = new_times.unsqueeze(0)
freqs = self.pos_emb(new_times.float())
freqs = freqs.squeeze(0)
freqs = freqs.unsqueeze(1)
features = apply_rotary_emb(freqs, features, seq_dim=0)
else:
angle = (-new_times * 2 * np.pi).to(device)
                    if self.period_fix != "MTCT":
freqs = pos_emb.get_axial_freqs(new_times.shape[0], features.shape[-2]).to(device)
angle_expanded = angle.unsqueeze(1).unsqueeze(2)
angle_expanded = angle_expanded.expand(new_times.shape[0], features.shape[-2], freqs.shape[-1])
freqs = freqs * angle_expanded
                        features = apply_rotary_emb(freqs, features)
elif self.time_embed_type == "learned_embed": # Learned embedding
# Add time embeddings to features
features = features + time_embed
else:
raise ValueError(f"Invalid time_embed_type: {self.time_embed_type}")
features = super()._process_features(
features,
start_token_embeds=start_token_embeds,
end_token_embeds=end_token_embeds,
)
if sep_token_embeds is not None:
features = torch.cat([features, sep_token_embeds], dim=0)
outputs.append(features)
return torch.cat(outputs, dim=0)
def forward(self, videos: List[torch.Tensor], config: Dict[str, Any], mm_info: dict) -> List[torch.Tensor]:
        # Collect pre-computed CacheFeatures entries and their indices.
        cache_feas = []
        cache_feas_index = []
        for _idx in range(len(videos)):
            if isinstance(videos[_idx], CacheFeatures):
                cache_feas.append(videos[_idx])
                cache_feas_index.append(_idx)
num_frames = [
_.value['features'].shape[0] if isinstance(_, CacheFeatures) else _.shape[0]
for _ in videos
]
features = self.parent.encode_video(videos, mm_info=mm_info, num_frames=num_frames)
features = torch.split(features, num_frames)
process_features = partial(
self._process_features,
start_token_embeds=self.embed_tokens(self.start_tokens),
end_token_embeds=self.embed_tokens(self.end_tokens),
sep_token_embeds=self.embed_tokens(self.sep_tokens),
)
if self.embed_time:
bs = len(mm_info["video_info"])
vid_idx = 0
device = features[0].device
if self.time_embed_type == "learned_embed":
# Learned embedding, we need to first collect all times from all videos and only do time embedding once
times_list = []
for i in range(bs):
_video_info = mm_info["video_info"][i]
if _video_info is not None:
for j in range(len(_video_info)):
_feature = features[vid_idx]
if _video_info[j] == "dummy":
times = torch.zeros(_feature.shape[0], device=device, dtype=_feature.dtype)
else:
times = _video_info[j]["video_frame_times"]
times = torch.tensor(times).to(device)
for pool_size in self.pool_sizes:
temporal_pool_size = pool_size[0]
if temporal_pool_size != 1:
if len(times) % temporal_pool_size != 0:
# pad
print(f"Warning: length of times: {len(times)} is not a multiple of temporal_pool_size: {temporal_pool_size}")
remainder = len(times) % temporal_pool_size
pad_len = temporal_pool_size - remainder
last_window_mean_times = times[-remainder:].mean()
times = torch.cat([times, torch.ones(pad_len).to(times.device) * last_window_mean_times])
times = pool(times, temporal_pool_size, 0)
times_list.append(times)
vid_idx += 1
# pad the times to the same length
ori_lens = [len(times) for times in times_list]
max_len = max(ori_lens)
for i in range(len(times_list)):
if len(times_list[i]) < max_len:
times_list[i] = torch.cat([times_list[i], torch.zeros(max_len - len(times_list[i])).to(times_list[i].device)])
times = torch.stack(times_list, dim=0)
time_embeds = self.time_embed(times, dtype=features[0].dtype)
# remove the padding for each embed
new_time_embeds = []
for i in range(len(times_list)):
new_time_embeds.append(time_embeds[i][:ori_lens[i]].unsqueeze(1).expand(-1, features[0].shape[1], -1))
            # Add a zero-valued term that touches the full (padded) time_embeds tensor, so
            # every embedding stays in the autograd graph even after the padding is sliced off.
            new_time_embeds[0] = new_time_embeds[0] + 0 * time_embeds.mean()
new_features = []
fea_count = len(features)
vid_idx = 0
for i in range(bs):
_video_info = mm_info["video_info"][i]
if _video_info is not None:
for j in range(len(_video_info)):
_feature = features[vid_idx]
if _video_info[j] == "dummy":
times = torch.zeros(_feature.shape[0], device=device, dtype=_feature.dtype)
else:
times = _video_info[j]["video_frame_times"]
times = torch.tensor(times).to(device)
if self.time_embed_type == "learned_embed":
_feature = process_features(_feature, time_embed=new_time_embeds[vid_idx])
else:
_feature = process_features(_feature, times=times)
new_features.append(_feature)
vid_idx += 1
assert vid_idx == fea_count, "vid_idx: {}, fea_count: {}".format(vid_idx, fea_count)
features = new_features
else:
features = [process_features(f) for f in features]
return features
def _encode_video_frames(self, video_frames: torch.Tensor) -> torch.Tensor:
"""Helper method to encode video frames when cached features are not available."""
features = self.parent.encode_images(video_frames.unsqueeze(0))
return features.squeeze(0)
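# Shape-level sketch of the multi-scale pooling in TSPVideoEncoder._process_features
# (illustrative sizes; the (2, 2, 2) pool_size entry is hypothetical).
def _tsp_pool_example():
    inputs = torch.randn(8, 256, 1024)  # [frames, tokens (16 x 16 grid), dim]
    nt, ns = inputs.shape[:2]
    nl = int(ns ** 0.5)                 # 16
    features = inputs.view(nt, nl, nl, -1)
    for dim_idx, p in enumerate((2, 2, 2)):  # pool (time, height, width) each by 2
        features = pool(features, p, dim=dim_idx)
    return features.flatten(1, 2).shape  # torch.Size([4, 64, 1024])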