|
|
|
|
|
|
|
|
from abc import ABCMeta, abstractmethod |
|
|
from copy import deepcopy |
|
|
from functools import wraps |
|
|
from math import prod, sqrt |
|
|
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union |
|
|
|
|
|
import einops |
|
|
import torch |
|
|
import torch.backends.cuda |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as f |
|
|
from torch.nn.attention import SDPBackend, sdpa_kernel |
|
|
from transformers import PretrainedConfig |
|
|
from transformers.activations import ACT2FN |
|
|
from transformers.cache_utils import Cache |
|
|
from transformers.integrations import use_kernel_forward_from_hub |
|
|
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
|
|
from transformers.modeling_layers import GradientCheckpointingLayer |
|
|
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
|
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS |
|
|
from transformers.processing_utils import Unpack |
|
|
|
|
|
from .configuration_jvlm import ( |
|
|
ImagePaddingEmbedType, |
|
|
ImagePooling2DType, |
|
|
ImageProjectionType, |
|
|
JinaAttentionConfig, |
|
|
JinaFFNConfig, |
|
|
JinaLNormConfig, |
|
|
JinaTransformerBlockConfig, |
|
|
JinaVLConnectorConfig, |
|
|
LayerNormType, |
|
|
) |
|
|
|
|
|
|
|
|
class Dropout(nn.Dropout): |
|
|
def __init__( |
|
|
self, |
|
|
p: float = 0.5, |
|
|
inplace: bool = False, |
|
|
mask_p: float = 0.0, |
|
|
broadcast_dims: Sequence[int] = (), |
|
|
) -> None: |
|
|
super().__init__(p, inplace) |
|
|
self.mask_p = mask_p |
|
|
self.broadcast_dims = broadcast_dims |
|
|
|
|
|
def forward( |
|
|
self, _input: torch.Tensor, drop_mask: Optional[torch.Tensor] = None |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
:param _input: A tensor of shape `(batch_size, seq_len, embed_dim)` |
|
|
:param drop_mask: A tensor of shape `(batch_size, seq_len)` with values of zero |
|
|
or one |
|
|
""" |
|
|
if self.p == 0.0 and (self.mask_p is None or self.mask_p == 0.0): |
|
|
return _input |
|
|
else: |
|
|
if self.mask_p > 0.0 and self.training: |
|
|
assert drop_mask is not None |
|
|
drop_mask = drop_mask.to(_input.dtype) |
|
|
keep_prob = 1.0 - self.p |
|
|
keep_prob2 = 1.0 - self.mask_p |
|
|
keep_prob = drop_mask * keep_prob2 + (1 - drop_mask) * keep_prob |
|
|
keep_prob = keep_prob.unsqueeze(-1) |
|
|
dropout_shape = list(_input.shape) |
|
|
keep_prob = keep_prob.broadcast_to(dropout_shape) |
|
|
multiplier = _input.new_empty(dropout_shape).bernoulli_(keep_prob) |
|
|
multiplier.div_(keep_prob) |
|
|
return _input * multiplier |
|
|
elif self.p > 0.0 and len(self.broadcast_dims) > 0 and self.training: |
|
|
keep_prob = 1.0 - self.p |
|
|
dropout_shape = list(_input.shape) |
|
|
for dim in self.broadcast_dims: |
|
|
dropout_shape[dim] = 1 |
|
|
keep = _input.new_empty(dropout_shape).bernoulli_(keep_prob) |
|
|
multiplier = keep.broadcast_to(_input.shape) |
|
|
multiplier.div_(keep_prob) |
|
|
return _input * multiplier |
|
|
else: |
|
|
return f.dropout(_input, self.p, self.training, self.inplace) |
|
|
|
|
|
|
|
|
class ResidualPathDropout(nn.Module): |
|
|
"""Drops paths (Stochastic Depth) per sample (when applied in main path of residual |
|
|
blocks). |
|
|
|
|
|
Taken from |
|
|
https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py |
|
|
""" |
|
|
|
|
|
def __init__(self, p: float = 0.5, scale_by_keep: bool = True) -> None: |
|
|
super(ResidualPathDropout, self).__init__() |
|
|
assert 0 <= p < 1.0 |
|
|
self.p = p |
|
|
self.scale_by_keep = scale_by_keep |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of |
|
|
residual blocks). |
|
|
|
|
|
This is the same as the DropConnect impl I created for EfficientNet, etc |
|
|
networks, however, the original name is misleading as 'Drop Connect' is a |
|
|
different form of dropout in a separate paper... |
|
|
See discussion: |
|
|
https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 |
|
|
|
|
|
I've opted for changing the layer and argument names to 'drop path' rather |
|
|
than mix DropConnect as a layer name and use 'survival rate' as the argument. |
|
|
""" |
|
|
if self.p == 0.0 or not self.training: |
|
|
return x |
|
|
keep_prob = 1 - self.p |
|
|
|
|
|
shape = (x.shape[0],) + (1,) * (x.ndim - 1) |
|
|
random_tensor = x.new_empty(shape).bernoulli_(keep_prob) |
|
|
if keep_prob > 0.0 and self.scale_by_keep: |
|
|
random_tensor.div_(keep_prob) |
|
|
return x * random_tensor |
|
|
|
|
|
|
|
|
class PatchDropout(nn.Module): |
|
|
""" |
|
|
https://arxiv.org/abs/2212.00794 |
|
|
""" |
|
|
|
|
|
def __init__(self, p: float = 0.5, exclude_first_token: bool = True): |
|
|
super().__init__() |
|
|
assert 0 <= p < 1.0 |
|
|
self.p = p |
|
|
self.exclude_first_token = exclude_first_token |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: |
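        # Keeps max(1, int(num_tokens * (1 - p))) randomly selected patch
        # tokens per sample (optionally re-attaching the CLS token) and
        # returns the indices of the kept patches alongside the tensor.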
|
|
if not self.training or self.p == 0.0: |
|
|
return x, None |
|
|
|
|
|
if self.exclude_first_token: |
|
|
_cls_tokens, x = x[:, :1], x[:, 1:] |
|
|
else: |
|
|
_cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) |
|
|
|
|
|
        batch, ntokens = x.shape[:2]
|
|
batch_indices = torch.arange(batch) |
|
|
batch_indices = batch_indices[..., None] |
|
|
|
|
|
keep_prob = 1 - self.p |
|
|
num_patches_keep = max(1, int(ntokens * keep_prob)) |
|
|
rand = torch.randn(batch, ntokens) |
|
|
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices |
|
|
|
|
|
x = x[batch_indices, patch_indices_keep] |
|
|
if self.exclude_first_token: |
|
|
x = torch.cat((_cls_tokens, x), dim=1) |
|
|
|
|
|
return x, patch_indices_keep |
|
|
|
|
|
|
|
|
""" |
|
|
Embedding layers. Adapted from AllenAI Molmo |
|
|
https://github.com/allenai/molmo |
|
|
""" |
|
|
|
|
|
|
|
|
class ExtendedEmbedding(nn.Module): |
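    # Stores the original vocabulary embeddings and the newly added embeddings
    # as separate parameters and concatenates them at lookup time, so extra
    # tokens can be added without resizing the base embedding matrix.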
|
|
def __init__( |
|
|
self, |
|
|
num_embeddings: int, |
|
|
num_new_embeddings: int, |
|
|
num_features: int, |
|
|
): |
|
|
super().__init__() |
|
|
self.embedding = nn.Parameter( |
|
|
torch.zeros(num_embeddings, num_features), |
|
|
) |
|
|
self.new_embedding = nn.Parameter( |
|
|
torch.zeros(num_new_embeddings, num_features), |
|
|
) |
|
|
|
|
|
@property |
|
|
def weight(self): |
|
|
return self.embedding |
|
|
|
|
|
@weight.setter |
|
|
def weight(self, w): |
|
|
self.embedding = w |
|
|
|
|
|
@property |
|
|
def embedding_table(self) -> torch.Tensor: |
|
|
return torch.cat([self.embedding, self.new_embedding], dim=0) |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
return f.embedding(x, self.embedding_table) |
|
|
|
|
|
|
|
|
class PatchEmbedding(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
dim: int = 768, |
|
|
patch_size: int = 16, |
|
|
num_channels: int = 3, |
|
|
input_size: Optional[Tuple[int, int]] = None, |
|
|
bias: bool = True, |
|
|
use_linear: bool = False, |
|
|
): |
|
|
super().__init__() |
|
|
self._input_size = input_size |
|
|
self._patch_size = (patch_size, patch_size) |
|
|
if input_size is not None: |
|
|
self._patch_shape = ( |
|
|
self._input_size[0] // self._patch_size[0], |
|
|
self._input_size[1] // self._patch_size[1], |
|
|
) |
|
|
self._num_patches = prod(self._patch_shape) |
|
|
else: |
|
|
assert not use_linear, 'Linear patch embedding requires a fixed input size!' |
|
|
self._patch_shape = None |
|
|
self._num_patches = None |
|
|
self._num_channels = num_channels |
|
|
self._dim = dim |
|
|
self._bias = bias |
|
|
if use_linear: |
|
|
self.proj = nn.Linear( |
|
|
self._num_channels * self._patch_size[0] * self._patch_size[1], |
|
|
self._dim, |
|
|
bias=self._bias, |
|
|
) |
|
|
self._proj_impl = 'linear' |
|
|
else: |
|
|
self.proj = nn.Conv2d( |
|
|
self._num_channels, |
|
|
self._dim, |
|
|
kernel_size=self._patch_size, |
|
|
stride=self._patch_size, |
|
|
bias=self._bias, |
|
|
|
|
|
) |
|
|
self._proj_impl = 'conv2d' |
|
|
|
|
|
def _linear_pre_projection(self, x: torch.Tensor) -> torch.Tensor: |
|
|
b, c, *_ = x.shape |
|
|
p1, p2 = self._patch_size |
|
|
patches = x.unfold(2, p1, p1).unfold(3, p2, p2) |
|
|
patches = patches.permute(0, 2, 3, 4, 5, 1) |
|
|
return patches.reshape(b, -1, c * p1 * p2) |
|
|
|
|
|
@staticmethod |
|
|
def _conv2d_pre_projection(x: torch.Tensor) -> torch.Tensor: |
|
|
return x |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]: |
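        # Accepts either pixel input of shape (batch, channels, height, width)
        # or pre-extracted patches of shape (batch, seq_len, n_pixels) where
        # n_pixels == channels * patch_h * patch_w; returns the projected
        # patch tokens of shape (batch, n_patches, dim) together with the
        # patch grid shape (rows, cols).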
|
|
|
|
|
if len(x.shape) == 4: |
|
|
bs, ch, h, w = x.shape |
|
|
p1, p2 = self._patch_size |
|
|
assert ch == self._num_channels, ( |
|
|
f'Input tensor has {ch} channels, but model expects ' |
|
|
f'{self._num_channels} channels' |
|
|
) |
|
|
if self._input_size is not None: |
|
|
assert (h, w) == self._input_size, ( |
|
|
f"Input image shape {(h, w)} doesn't match model's " |
|
|
f'{self._input_size}' |
|
|
) |
|
|
            if self._proj_impl == 'linear':
                x = self._linear_pre_projection(x)
            else:
                assert h % p1 == 0 and w % p2 == 0, (
                    f'Input image shape {(h, w)} is not divisible by patch size '
                    f'{self._patch_size}'
                )
            # The patch grid shape is needed by both projection variants.
            shape = (h // p1, w // p2)
|
|
|
|
|
|
|
|
elif len(x.shape) == 3: |
|
|
bs, sl, np = x.shape |
|
|
h = int(sqrt(sl)) |
|
|
shape = (h, h) |
|
|
if self._input_size is not None: |
|
|
assert self._num_patches == sl, ( |
|
|
f"Input sequence length ({sl}) doesn't match model's patch shape " |
|
|
f'({self._patch_shape})' |
|
|
) |
|
|
else: |
|
|
assert h * h == sl, ( |
|
|
f'Input sequence length {sl} is not a perfect square. Please ' |
|
|
f'provide a square sequence length, from which the shape can be ' |
|
|
f'inferred. For non-square inputs, use a 4D tensor with shape ' |
|
|
f'(batch_size, n_channels, height, width)' |
|
|
) |
|
|
p1, p2 = self._patch_size |
|
|
assert np == self._num_channels * p1 * p2, ( |
|
|
f'The input number of pixels ({np}) does not match the expected number ' |
|
|
f'n_channels * patch_size_horizontal * patch_size_vertical ' |
|
|
f'({self._num_channels * p1 * p2})' |
|
|
) |
|
|
if self._proj_impl == 'conv2d': |
|
|
|
|
|
x = ( |
|
|
x.unfold(1, h, h) |
|
|
.reshape(bs, h, h, p1, p2, self._num_channels) |
|
|
.permute(0, 5, 1, 3, 2, 4) |
|
|
.reshape(bs, self._num_channels, h * p1, h * p2) |
|
|
) |
|
|
else: |
|
|
raise ValueError( |
|
|
f'Input tensor must be 3D or 4D, got {len(x.shape)}D tensor with shape ' |
|
|
f'{x.shape}. Accepted shapes are (batch_size, n_channels, height, ' |
|
|
f'width) or (batch_size, seq_len, n_pixels)' |
|
|
) |
|
|
out = self.proj(x.to(dtype=self.proj.weight.dtype)) |
|
|
if self._proj_impl == 'conv2d': |
|
|
out = out.flatten(2).permute(0, 2, 1) |
|
|
return out, shape |
|
|
|
|
|
|
|
|
""" |
|
|
Rotary Positional Embeddings. Compatible with HuggingFace transformers |
|
|
https://github.com/huggingface/transformers/blob/main/src/transformers/ |
|
|
modeling_rope_utils.py |
|
|
""" |
|
|
|
|
|
|
|
|
def inv_freq_to_device(rope_forward): |
|
|
"""Sometimes the inv_freq is calculated on the wrong device, or ends up in lower |
|
|
precision than float32. |
|
|
|
|
|
This wrapper ensures that inv_freq is always on the right device and in float32 |
|
|
precision. |
|
|
""" |
|
|
|
|
|
@wraps(rope_forward) |
|
|
def wrapper(self, x, position_ids): |
|
|
if self.inv_freq.dtype != torch.float32 or self.rope_init_device != x.device: |
|
|
invfreq, self.attention_scaling = self.rope_init_fn( |
|
|
self.config, x.device, self.max_seq_len_cached |
|
|
) |
|
|
self.register_buffer('inv_freq', invfreq, persistent=False) |
|
|
self.original_inv_freq = self.inv_freq |
|
|
self.rope_init_device = x.device |
|
|
return rope_forward(self, x, position_ids) |
|
|
|
|
|
return wrapper |
|
|
|
|
|
|
|
|
class RotaryEmbedding(nn.Module): |
|
|
inv_freq: torch.Tensor |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
config: PretrainedConfig, |
|
|
theta: float, |
|
|
head_dim: int, |
|
|
hidden_size: int, |
|
|
partial_rotary_factor: float, |
|
|
device: Optional[torch.device] = None, |
|
|
scaling: Optional[Dict[str, Any]] = None, |
|
|
): |
|
|
super().__init__() |
|
|
assert hasattr(config, 'rope_theta') |
|
|
self.config = deepcopy(config) |
|
|
|
|
|
|
|
|
setattr(self.config, 'rope_theta', theta) |
|
|
setattr(self.config, 'partial_rotary_factor', partial_rotary_factor) |
|
|
setattr(self.config, 'head_dim', head_dim) |
|
|
setattr(self.config, 'hidden_size', hidden_size) |
|
|
setattr(self.config, 'rope_scaling', scaling or {}) |
|
|
|
|
|
self.rope_type = 'default' |
|
|
if hasattr(config, 'rope_scaling') and config.rope_scaling is not None: |
|
|
self.rope_type = config.rope_scaling.get('rope_type', 'default') |
|
|
|
|
|
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] |
|
|
device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
|
seqlen = config.max_position_embeddings or config.max_sequence_length |
|
|
invfreq, self.attention_scaling = self.rope_init_fn(self.config, device, seqlen) |
|
|
self.rope_init_device = device |
|
|
self.register_buffer('inv_freq', invfreq, persistent=False) |
|
|
self.original_inv_freq = self.inv_freq |
|
|
self.max_seq_len_cached = seqlen |
|
|
self.original_max_seq_len = self.max_seq_len_cached |
|
|
|
|
|
@torch.no_grad() |
|
|
@inv_freq_to_device |
|
|
@dynamic_rope_update |
|
|
def forward(self, x: torch.Tensor, position_ids: torch.Tensor): |
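        # Computed in float32 with autocast disabled for numerical stability;
        # returns cos and sin tables of shape (batch, seq_len, rotary_dim),
        # already scaled by the attention_scaling factor of the RoPE variant.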
|
|
device_type = ( |
|
|
x.device.type |
|
|
if isinstance(x.device.type, str) and x.device.type != 'mps' |
|
|
else 'cpu' |
|
|
) |
|
|
with torch.autocast(device_type=device_type, enabled=False): |
|
|
inv_freq_expanded = self.inv_freq[None, :, None].expand( |
|
|
position_ids.shape[0], -1, 1 |
|
|
) |
|
|
position_ids_expanded = position_ids[:, None, :].float() |
|
|
freqs = inv_freq_expanded * position_ids_expanded |
|
|
freqs = freqs.transpose(1, 2) |
|
|
emb = torch.cat((freqs, freqs), dim=-1) |
|
|
cos = emb.cos() * self.attention_scaling |
|
|
sin = emb.sin() * self.attention_scaling |
|
|
|
|
|
return cos, sin |
|
|
|
|
|
|
|
|
""" |
|
|
Residual wrapper. Adapted from AllenAI Molmo |
|
|
https://github.com/allenai/molmo |
|
|
""" |
|
|
|
|
|
|
|
|
class Residual(nn.Module): |
|
|
def __init__(self, submodule: nn.Module): |
|
|
super().__init__() |
|
|
self.submodule = submodule |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
return x + self.submodule(x) |
|
|
|
|
|
|
|
|
""" |
|
|
Layer scaling. Adapted from |
|
|
https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/layer_scale.py |
|
|
""" |
|
|
|
|
|
|
|
|
class LayerScale(nn.Module): |
|
|
""" |
|
|
    LayerScale as used in DINOv2.
|
|
From |
|
|
https://github.com/facebookresearch/dinov2/blob/main/dinov2/layers/layer_scale.py |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
dim: int, |
|
|
init_value: float = 1e-5, |
|
|
inplace: bool = False, |
|
|
) -> None: |
|
|
super().__init__() |
|
|
self.init_value = init_value |
|
|
self.inplace = inplace |
|
|
self.gamma = nn.Parameter(init_value * torch.ones((dim,)), requires_grad=True) |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
return x.mul_(self.gamma) if self.inplace else x * self.gamma |
|
|
|
|
|
|
|
|
""" |
|
|
Layer normalization. Adapted from AllenAI Molmo |
|
|
https://github.com/allenai/molmo |
|
|
""" |
|
|
|
|
|
|
|
|
class _LayerNorm(nn.Module, metaclass=ABCMeta): |
|
|
def __init__( |
|
|
self, |
|
|
config: JinaLNormConfig, |
|
|
size: int, |
|
|
elementwise_affine: Optional[bool] = True, |
|
|
eps: float = 1e-05, |
|
|
weight_initializer: Optional[Callable] = torch.ones, |
|
|
bias_initializer: Optional[Callable] = torch.zeros, |
|
|
): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.eps = self.config.eps or eps |
|
|
self.normalized_shape = (size,) |
|
|
if elementwise_affine or ( |
|
|
elementwise_affine is None and self.config.with_affine |
|
|
): |
|
|
self.weight = nn.Parameter(weight_initializer(self.normalized_shape)) |
|
|
use_bias = self.config.bias |
|
|
if use_bias: |
|
|
self.bias = nn.Parameter(bias_initializer(self.normalized_shape)) |
|
|
else: |
|
|
self.register_parameter('bias', None) |
|
|
else: |
|
|
self.register_parameter('bias', None) |
|
|
self.register_parameter('weight', None) |
|
|
|
|
|
@abstractmethod |
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
raise NotImplementedError |
|
|
|
|
|
@staticmethod |
|
|
def _cast_if_autocast_enabled( |
|
|
tensor: torch.Tensor, dtype: Optional[torch.dtype] = None |
|
|
) -> torch.Tensor: |
|
|
|
|
|
|
|
|
|
|
|
if tensor.device.type == 'cuda' and torch.is_autocast_enabled(): |
|
|
return tensor.to( |
|
|
dtype=dtype if dtype is not None else torch.get_autocast_gpu_dtype() |
|
|
) |
|
|
elif tensor.device.type == 'cpu' and torch.is_autocast_cpu_enabled(): |
|
|
return tensor.to( |
|
|
dtype=dtype if dtype is not None else torch.get_autocast_cpu_dtype() |
|
|
) |
|
|
else: |
|
|
return tensor |
|
|
|
|
|
|
|
|
class LayerNorm(_LayerNorm): |
|
|
"""The default :class:`LayerNorm` implementation which can optionally run in low |
|
|
precision.""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
config: JinaLNormConfig, |
|
|
size: int, |
|
|
low_precision: bool = False, |
|
|
elementwise_affine: Optional[bool] = None, |
|
|
eps: float = 1e-05, |
|
|
): |
|
|
super().__init__( |
|
|
config, size=size, elementwise_affine=elementwise_affine, eps=eps |
|
|
) |
|
|
self.low_precision = low_precision |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
if self.low_precision: |
|
|
module_device = x.device |
|
|
downcast_x = self._cast_if_autocast_enabled(x) |
|
|
downcast_weight = ( |
|
|
self._cast_if_autocast_enabled(self.weight) |
|
|
if self.weight is not None |
|
|
else self.weight |
|
|
) |
|
|
downcast_bias = ( |
|
|
self._cast_if_autocast_enabled(self.bias) |
|
|
if self.bias is not None |
|
|
else self.bias |
|
|
) |
|
|
with torch.autocast(enabled=False, device_type=module_device.type): |
|
|
return f.layer_norm( |
|
|
downcast_x, |
|
|
self.normalized_shape, |
|
|
weight=downcast_weight, |
|
|
bias=downcast_bias, |
|
|
eps=self.eps, |
|
|
) |
|
|
else: |
|
|
return f.layer_norm( |
|
|
x, |
|
|
self.normalized_shape, |
|
|
weight=self.weight, |
|
|
bias=self.bias, |
|
|
eps=self.eps, |
|
|
) |
|
|
|
|
|
|
|
|
@use_kernel_forward_from_hub('RMSNorm') |
|
|
class RMSLayerNorm(_LayerNorm): |
|
|
"""RMS layer norm, a simplified :class:`LayerNorm` implementation.""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
config: JinaLNormConfig, |
|
|
size: int, |
|
|
elementwise_affine: Optional[bool] = None, |
|
|
eps: float = 1e-5, |
|
|
): |
|
|
super().__init__( |
|
|
config, size=size, elementwise_affine=elementwise_affine, eps=eps |
|
|
) |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
with torch.autocast(enabled=False, device_type=x.device.type): |
|
|
og_dtype = x.dtype |
|
|
x = x.to(torch.float32) |
|
|
variance = x.pow(2).mean(-1, keepdim=True) |
|
|
x = x * torch.rsqrt(variance + self.eps) |
|
|
x = x.to(og_dtype) |
|
|
|
|
|
if self.weight is not None: |
|
|
if self.bias is not None: |
|
|
return self.weight * x + self.bias |
|
|
else: |
|
|
return self.weight * x |
|
|
else: |
|
|
return x |
|
|
|
|
|
|
|
|
def build_layer_norm(config: JinaLNormConfig, size: int, **kwargs) -> _LayerNorm: |
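    # Dispatches on the configured norm type: the default and low-precision
    # variants use LayerNorm, anything else falls back to RMSLayerNorm.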
|
|
if config.type == LayerNormType.default: |
|
|
return LayerNorm(config, size=size, low_precision=False, **kwargs) |
|
|
elif config.type == LayerNormType.low_precision: |
|
|
return LayerNorm(config, size=size, low_precision=True, **kwargs) |
|
|
return RMSLayerNorm(config, size=size, **kwargs) |
|
|
|
|
|
|
|
|
""" |
|
|
Multi Head Scaled Dot Product Attention module and utilities. Adapted from AllenAI Molmo |
|
|
https://github.com/allenai/molmo |
|
|
""" |
|
|
|
|
|
|
|
|
def _create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor: |
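    # Builds an additive causal bias of shape (1, 1, seq_len, seq_len) with 0
    # on and below the diagonal and the dtype minimum above it, matching the
    # additive-mask convention used by the attention code below.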
|
|
with torch.autocast(device.type, enabled=False): |
|
|
causal_mask = torch.triu( |
|
|
torch.ones(seq_len, seq_len, device=device, dtype=torch.float), |
|
|
diagonal=1, |
|
|
) |
|
|
causal_mask.masked_fill_(causal_mask == 1, torch.finfo(causal_mask.dtype).min) |
|
|
causal_mask = causal_mask.view(1, 1, seq_len, seq_len) |
|
|
return causal_mask |
|
|
|
|
|
|
|
|
def _ensure_finite( |
|
|
x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False |
|
|
): |
|
|
"""Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the |
|
|
dtype when ``check_neg_inf`` is ``True`` and replace ``float("inf")`` with the |
|
|
    maximum value of the dtype when ``check_pos_inf`` is ``True``."""
|
|
if check_neg_inf: |
|
|
x.masked_fill_(x == float('-inf'), torch.finfo(x.dtype).min) |
|
|
if check_pos_inf: |
|
|
x.masked_fill_(x == float('inf'), torch.finfo(x.dtype).max) |
|
|
|
|
|
|
|
|
def resolve_causal_mask( |
|
|
attention_mask: Optional[torch.Tensor], |
|
|
causal_mask: Optional[torch.Tensor], |
|
|
past_key_values: Optional[Cache], |
|
|
batch_size: int, |
|
|
seq_len: int, |
|
|
past_length: int, |
|
|
device, |
|
|
): |
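    # Converts a 0/1 attention mask (per-token or per-query) into an additive
    # float bias and merges it with a causal bias covering
    # past_length + seq_len; returns None when no masking is required.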
|
|
if attention_mask is not None: |
|
|
|
|
|
if len(attention_mask.shape) == 2: |
|
|
attention_mask = attention_mask[:, : past_length + seq_len] |
|
|
attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[ |
|
|
:, None, None, : |
|
|
] |
|
|
else: |
|
|
attention_mask = attention_mask.unsqueeze(1).to(dtype=torch.float) |
|
|
attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if ( |
|
|
causal_mask is not None |
|
|
or attention_mask is not None |
|
|
or past_key_values is not None |
|
|
): |
|
|
if causal_mask is None: |
|
|
causal_mask = _create_causal_mask(past_length + seq_len, device) |
|
|
elif causal_mask.dtype in (torch.int8, torch.bool): |
|
|
causal_mask = causal_mask.to(dtype=torch.float) |
|
|
causal_mask.masked_fill_( |
|
|
causal_mask == 0.0, torch.finfo(causal_mask.dtype).min |
|
|
) |
|
|
mask_len = seq_len |
|
|
if attention_mask is not None: |
|
|
mask_len = attention_mask.shape[-1] |
|
|
elif past_key_values is not None: |
|
|
mask_len = past_length + seq_len |
|
|
causal_mask = causal_mask[:, :, :mask_len, :mask_len].to(dtype=torch.float) |
|
|
|
|
|
if attention_mask is not None: |
|
|
causal_mask = causal_mask + attention_mask |
|
|
|
|
|
|
|
|
|
|
|
_ensure_finite(causal_mask, check_neg_inf=True, check_pos_inf=False) |
|
|
return causal_mask |
|
|
|
|
|
|
|
|
def cast_attention_mask(bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor: |
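    # Re-casts the additive attention bias to the active autocast dtype (GPU
    # or CPU) when autocast is enabled, clamping -inf entries to the dtype
    # minimum after the cast so the softmax stays finite.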
|
|
target_dtype = input_dtype |
|
|
|
|
|
|
|
|
|
|
|
if bias.device.type == 'cuda' and torch.is_autocast_enabled(): |
|
|
target_dtype = torch.get_autocast_gpu_dtype() |
|
|
elif bias.device.type == 'cpu' and torch.is_autocast_cpu_enabled(): |
|
|
target_dtype = torch.get_autocast_cpu_dtype() |
|
|
if bias.dtype != target_dtype: |
|
|
bias = bias.to(target_dtype) |
|
|
_ensure_finite(bias, check_neg_inf=True, check_pos_inf=False) |
|
|
return bias |
|
|
|
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n: int) -> torch.Tensor: |
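    # Expands grouped key/value heads so they match the number of query
    # heads: (batch, kv_heads, seq, head_dim) -> (batch, kv_heads * n, seq,
    # head_dim), i.e. each KV head is repeated n times in place.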
|
|
batch, kvheads, slen, hdim = hidden_states.shape |
|
|
if n == 1: |
|
|
return hidden_states |
|
|
hidden_states = hidden_states[:, :, None, :, :].expand( |
|
|
batch, kvheads, n, slen, hdim |
|
|
) |
|
|
return hidden_states.reshape(batch, kvheads * n, slen, hdim) |
|
|
|
|
|
|
|
|
def eager_attention_forward( |
|
|
module: nn.Module, |
|
|
query: torch.Tensor, |
|
|
key: torch.Tensor, |
|
|
value: torch.Tensor, |
|
|
attention_mask: Optional[torch.Tensor], |
|
|
scaling: float, |
|
|
dropout: float = 0.0, |
|
|
**_, |
|
|
): |
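    # Reference attention path: scaled QK^T plus the (already additive) mask,
    # softmax in float32, optional dropout, then the value aggregation. The
    # output is returned as (batch, q_len, n_heads, head_dim) together with
    # the attention weights.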
|
|
key_states = repeat_kv(key, module.num_key_value_groups) |
|
|
value_states = repeat_kv(value, module.num_key_value_groups) |
|
|
|
|
|
weights = torch.matmul(query * scaling, key_states.transpose(2, 3)) |
|
|
if attention_mask is not None: |
|
|
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] |
|
|
weights = weights + causal_mask |
|
|
|
|
|
weights = f.softmax(weights, dim=-1, dtype=torch.float32).to(query.dtype) |
|
|
weights = f.dropout(weights, p=dropout, training=module.training).to( |
|
|
value_states.dtype |
|
|
) |
|
|
out = torch.matmul(weights, value_states).to(query.dtype) |
|
|
out = out.transpose(1, 2).contiguous() |
|
|
|
|
|
return out, weights |
|
|
|
|
|
|
|
|
def rotate_half(x: torch.Tensor): |
|
|
b, nh, t, hs = x.size() |
|
|
x = x.view(b, nh, t, 2, hs // 2) |
|
|
x1, x2 = x.unbind(dim=-2) |
|
|
return torch.cat((-x2, x1), dim=-1) |
|
|
|
|
|
|
|
|
def apply_rotary_positional_embeddings( |
|
|
x: torch.Tensor, |
|
|
cos: torch.Tensor, |
|
|
sin: torch.Tensor, |
|
|
) -> torch.Tensor: |
|
|
return (x * cos + rotate_half(x) * sin).to(x.dtype) |
|
|
|
|
|
|
|
|
def apply_rope_to_qk(q, k, cos, sin): |
|
|
q_, k_ = q.float(), k.float() |
|
|
with torch.autocast(q.device.type, enabled=False): |
|
|
q_ = apply_rotary_positional_embeddings(q_, cos, sin) |
|
|
k_ = apply_rotary_positional_embeddings(k_, cos, sin) |
|
|
q = q_.type_as(q) |
|
|
k = k_.type_as(k) |
|
|
return q, k |
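
# Illustrative call pattern (mirroring MHSDPA.forward below): cos/sin from
# RotaryEmbedding are unsqueezed so they broadcast over the head dimension:
#   cos, sin = rope(x, position_ids)
#   q, k = apply_rope_to_qk(q, k, cos.unsqueeze(1), sin.unsqueeze(1))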
|
|
|
|
|
|
|
|
class MHSDPA(nn.Module): |
|
|
"""Multi Head Scaled Dot Product Attention.""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
config: JinaAttentionConfig, |
|
|
hidden_size: int, |
|
|
output_size: Optional[int] = None, |
|
|
self_attn: bool = True, |
|
|
is_causal: bool = False, |
|
|
layer_idx: int = 0, |
|
|
attn_implementation: Optional[str] = None, |
|
|
): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
|
|
|
self.hidden_size = hidden_size |
|
|
self.n_heads = config.n_heads |
|
|
self.n_kv_heads = config.n_kv_heads or self.n_heads |
|
|
self.n_kv_groups = self.n_heads // self.n_kv_heads |
|
|
self.output_size = output_size or self.hidden_size |
|
|
|
|
|
|
|
|
self.num_key_value_groups = self.n_kv_groups |
|
|
self.is_causal = is_causal |
|
|
self.layer_idx = layer_idx |
|
|
self.sliding_window = config.sliding_window |
|
|
|
|
|
head_dim = config.head_dim |
|
|
if head_dim is None: |
|
|
assert self.hidden_size % self.n_heads == 0 |
|
|
head_dim = self.hidden_size // self.n_heads |
|
|
self.head_dim = head_dim |
|
|
|
|
|
self.scale = config.softmax_scale or self.head_dim**-0.5 |
|
|
self.scaling = self.scale |
|
|
|
|
|
|
|
|
        self.clip_qkv = config.clip_qkv
        if self.clip_qkv is not None:
            assert self.clip_qkv > 0
|
|
|
|
|
self.fp32_attn = config.fp32 |
|
|
self.self_attn = self_attn |
|
|
self.fused_dims = ( |
|
|
self.n_heads * self.head_dim, |
|
|
self.n_kv_heads * self.head_dim, |
|
|
self.n_kv_heads * self.head_dim, |
|
|
) |
|
|
if self.self_attn: |
|
|
self.qkv_w = nn.Linear(self.hidden_size, sum(self.fused_dims), bias=False) |
|
|
else: |
|
|
self.q_w = nn.Linear( |
|
|
self.hidden_size, |
|
|
self.n_heads * self.head_dim, |
|
|
bias=False, |
|
|
) |
|
|
self.kv_w = nn.Linear( |
|
|
self.hidden_size, |
|
|
sum(self.fused_dims) - self.n_heads * self.head_dim, |
|
|
bias=False, |
|
|
) |
|
|
self.out = nn.Linear( |
|
|
self.n_heads * self.head_dim, |
|
|
self.output_size, |
|
|
bias=config.o_bias, |
|
|
) |
|
|
self.q_b = nn.Parameter( |
|
|
torch.zeros(self.n_heads * self.head_dim), |
|
|
requires_grad=config.q_bias, |
|
|
) |
|
|
self.k_b = nn.Parameter( |
|
|
torch.zeros(self.n_kv_heads * self.head_dim), |
|
|
requires_grad=config.k_bias, |
|
|
) |
|
|
self.v_b = nn.Parameter( |
|
|
torch.zeros(self.n_kv_heads * self.head_dim), |
|
|
requires_grad=config.v_bias, |
|
|
) |
|
|
self.q_lnorm = nn.Identity() |
|
|
self.k_lnorm = nn.Identity() |
|
|
self.v_lnorm = nn.Identity() |
|
|
self.inner_lnorm = nn.Identity() |
|
|
self.add_q_lnorm = config.q_lnorm |
|
|
self.add_k_lnorm = config.k_lnorm |
|
|
self.add_v_lnorm = config.v_lnorm |
|
|
self.qkv_lnorm_on_heads = config.qkv_lnorm_on_heads |
|
|
q_lnorm_size = ( |
|
|
self.head_dim if self.qkv_lnorm_on_heads else self.n_heads * self.head_dim |
|
|
) |
|
|
kv_lnorm_size = ( |
|
|
self.head_dim |
|
|
if self.qkv_lnorm_on_heads |
|
|
else self.n_kv_heads * self.head_dim |
|
|
) |
|
|
if self.add_q_lnorm: |
|
|
self.q_lnorm = build_layer_norm( |
|
|
config.lnorm_config, |
|
|
size=q_lnorm_size, |
|
|
elementwise_affine=config.lnorm_config.with_affine, |
|
|
) |
|
|
if self.add_k_lnorm: |
|
|
self.k_lnorm = build_layer_norm( |
|
|
config.lnorm_config, |
|
|
size=kv_lnorm_size, |
|
|
elementwise_affine=config.lnorm_config.with_affine, |
|
|
) |
|
|
if self.add_v_lnorm: |
|
|
self.v_lnorm = build_layer_norm( |
|
|
config.lnorm_config, |
|
|
size=kv_lnorm_size, |
|
|
elementwise_affine=config.lnorm_config.with_affine, |
|
|
) |
|
|
if config.inner_lnorm: |
|
|
self.inner_lnorm = build_layer_norm( |
|
|
config.lnorm_config, |
|
|
size=(self.n_heads * self.head_dim), |
|
|
elementwise_affine=config.lnorm_config.with_affine, |
|
|
) |
|
|
self.drop_p = config.dropout |
|
|
self.attn_interface, *_ = self._get_attention_interface( |
|
|
attn_implementation or 'eager', None, None |
|
|
) |
|
|
|
|
|
def _get_attention_interface( |
|
|
self, |
|
|
attn_implementation: str, |
|
|
attn_mask: Optional[torch.Tensor] = None, |
|
|
is_causal: Optional[bool] = None, |
|
|
) -> Tuple[Callable, Optional[torch.Tensor], Optional[bool]]: |
|
|
if 'flash' in attn_implementation and self.fp32_attn: |
|
|
raise ValueError('Flash attention does not support fp32 attention') |
|
|
if self.sliding_window != -1 and 'flash' not in attn_implementation: |
|
|
raise ValueError('Sliding window attention requires flash attention') |
|
|
|
|
|
attn_interface: Callable = eager_attention_forward |
|
|
if attn_implementation != 'eager': |
|
|
attn_interface = ALL_ATTENTION_FUNCTIONS[attn_implementation] |
|
|
|
|
|
setattr(self.config, '_attn_implementation', attn_implementation) |
|
|
|
|
|
if 'flash' in attn_implementation: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if attn_mask is not None: |
|
|
|
|
|
attn_mask = (attn_mask > -1).to(torch.int32) |
|
|
|
|
|
attn_mask = attn_mask.squeeze(1).max(dim=1)[0] |
|
|
elif 'sdpa' in attn_implementation: |
|
|
if attn_mask is not None and is_causal is not None: |
|
|
is_causal = False |
|
|
|
|
|
elif attn_implementation == 'eager': |
|
|
if is_causal: |
|
|
assert attn_mask is not None |
|
|
assert attn_mask.ndim == 4 |
|
|
|
|
|
return attn_interface, attn_mask, is_causal |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
xq: torch.Tensor, |
|
|
xk: Optional[torch.Tensor] = None, |
|
|
rope_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
past_key_values: Optional[Cache] = None, |
|
|
cache_position: Optional[torch.LongTensor] = None, |
|
|
attn_implementation: Optional[str] = None, |
|
|
**kwargs: Unpack[FlashAttentionKwargs], |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
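        # Projects the fused (or separate) q/k/v, applies optional q/k/v layer
        # norm and rotary embeddings, updates the KV cache if present, and
        # dispatches to the configured attention implementation.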
|
|
if self.self_attn: |
|
|
|
|
|
qkv_b = torch.cat((self.q_b, self.k_b, self.v_b)) |
|
|
|
|
|
qkv = f.linear(xq, self.qkv_w.weight, qkv_b) |
|
|
if self.clip_qkv is not None: |
|
|
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) |
|
|
q, k, v = qkv.split(self.fused_dims, dim=-1) |
|
|
else: |
|
|
assert xk is not None |
|
|
q = f.linear(xq, self.q_w.weight, self.q_b) |
|
|
kv_b = torch.cat((self.k_b, self.v_b)) |
|
|
kv = f.linear(xk, self.kv_w.weight, kv_b) |
|
|
if self.clip_qkv is not None: |
|
|
q.clamp_(min=-self.clip_qkv, max=self.clip_qkv) |
|
|
kv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) |
|
|
k, v = kv.split(self.fused_dims[1:], dim=-1) |
|
|
|
|
|
b, tq, _ = q.size() |
|
|
_, tk, __ = k.size() |
|
|
og_dtype = k.dtype |
|
|
if self.fp32_attn: |
|
|
dtype = torch.float32 |
|
|
q = q.to(torch.float) |
|
|
k = k.to(torch.float) |
|
|
else: |
|
|
dtype = og_dtype |
|
|
|
|
|
|
|
|
if not self.qkv_lnorm_on_heads: |
|
|
q = self.q_lnorm(q).to(dtype=dtype) |
|
|
k = self.k_lnorm(k).to(dtype=dtype) |
|
|
v = self.v_lnorm(v).to(dtype=dtype) |
|
|
|
|
|
|
|
|
|
|
|
q = q.view(b, tq, self.n_heads, -1).transpose(1, 2) |
|
|
|
|
|
k = k.view(b, tk, self.n_kv_heads, -1).transpose(1, 2) |
|
|
|
|
|
v = v.view(b, tk, self.n_kv_heads, -1).transpose(1, 2) |
|
|
|
|
|
|
|
|
if self.qkv_lnorm_on_heads: |
|
|
q = self.q_lnorm(q).to(dtype=dtype) |
|
|
k = self.k_lnorm(k).to(dtype=dtype) |
|
|
v = self.v_lnorm(v).to(dtype=dtype) |
|
|
|
|
|
cache_kwargs: Dict[str, torch.Tensor] = {'cache_position': cache_position} |
|
|
if rope_embeddings is not None: |
|
|
cos, sin = rope_embeddings |
|
|
cache_kwargs['cos'] = cos |
|
|
cache_kwargs['sin'] = sin |
|
|
cos = cos.unsqueeze(1) |
|
|
sin = sin.unsqueeze(1) |
|
|
q, k = apply_rope_to_qk(q, k, cos, sin) |
|
|
|
|
|
if past_key_values is not None: |
|
|
k, v = past_key_values.update(k, v, self.layer_idx, cache_kwargs) |
|
|
|
|
|
if attention_mask is not None: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qlen, klen = q.shape[-2], k.shape[-2] |
|
|
attention_mask = cast_attention_mask( |
|
|
attention_mask[:, :, klen - qlen : klen, :klen], dtype |
|
|
) |
|
|
|
|
|
attention_interface = self.attn_interface |
|
|
is_causal = self.is_causal |
|
|
if attn_implementation is not None: |
|
|
attention_interface, attention_mask, is_causal = ( |
|
|
self._get_attention_interface( |
|
|
attn_implementation, |
|
|
attention_mask, |
|
|
self.is_causal, |
|
|
) |
|
|
) |
|
|
|
|
|
if self.sliding_window != -1: |
|
|
kwargs['sliding_window'] = self.sliding_window |
|
|
if is_causal is not None: |
|
|
kwargs['is_causal'] = is_causal |
|
|
|
|
|
attn, weights = attention_interface( |
|
|
self, |
|
|
q, |
|
|
k, |
|
|
v, |
|
|
attention_mask, |
|
|
dropout=0.0 if not self.training else self.drop_p, |
|
|
scaling=self.scaling, |
|
|
**kwargs, |
|
|
) |
|
|
attn = attn.to(og_dtype) |
|
|
attn = attn.view(b, tq, -1) |
|
|
out = self.inner_lnorm(attn) |
|
|
out = self.out(out) |
|
|
|
|
|
return out, weights |
|
|
|
|
|
|
|
|
""" |
|
|
FFN module. Adapted from AllenAI Molmo https://github.com/allenai/molmo |
|
|
""" |
|
|
|
|
|
|
|
|
class FFN(nn.Module): |
|
|
"""Feed-Forward Network.""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
config: JinaFFNConfig, |
|
|
hidden_size: int, |
|
|
output_size: Optional[int] = None, |
|
|
layer_idx: int = 0, |
|
|
): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.hidden_size = hidden_size |
|
|
self.output_size = output_size or hidden_size |
|
|
self.intermediate_size = config.size or config.ratio * hidden_size |
|
|
self.layer_idx = layer_idx |
|
|
self.gated_activation = config.gated_activation |
|
|
self.use_bias = config.bias |
|
|
|
|
|
activation_type = config.activation_type.lower() |
|
|
self.act = ACT2FN[activation_type] |
|
|
|
|
|
intermediate_size = self.intermediate_size |
|
|
if self.gated_activation: |
|
|
intermediate_size = 2 * self.intermediate_size |
|
|
|
|
|
self.up = nn.Linear(self.hidden_size, intermediate_size, bias=self.use_bias) |
|
|
self.down = nn.Linear( |
|
|
self.intermediate_size, self.output_size, bias=self.use_bias |
|
|
) |
|
|
self.inner_lnorm = ( |
|
|
build_layer_norm(self.config.lnorm_config, self.intermediate_size) |
|
|
if config.inner_lnorm |
|
|
else nn.Identity() |
|
|
) |
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
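        # Gated variant (GLU-style): the up projection is split into a value
        # and a gate half, and the activated gate multiplies the value before
        # the down projection.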
|
|
if self.gated_activation: |
|
|
x = self.up(x) |
|
|
x, gate = x.chunk(2, dim=-1) |
|
|
return self.down(self.inner_lnorm(self.act(gate) * x)) |
|
|
return self.down(self.inner_lnorm(self.act(self.up(x)))) |
|
|
|
|
|
|
|
|
""" |
|
|
Transformer block. Adapted from AllenAI Molmo https://github.com/allenai/molmo |
|
|
""" |
|
|
|
|
|
|
|
|
class TransformerBlock(GradientCheckpointingLayer): |
|
|
def __init__( |
|
|
self, |
|
|
config: JinaTransformerBlockConfig, |
|
|
hidden_size: int, |
|
|
is_causal: bool = True, |
|
|
layer_idx: int = 0, |
|
|
attn_implementation: Optional[str] = None, |
|
|
): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.hidden_size = hidden_size |
|
|
self.is_causal = is_causal |
|
|
self.layer_idx = layer_idx |
|
|
self.drop_path = config.residual_path_dropout |
|
|
self.attn_lscale_init = config.attn_lscale_init |
|
|
self.ffn_lscale_init = config.ffn_lscale_init |
|
|
self.postnorm = config.postnorm |
|
|
|
|
|
self.attn = MHSDPA( |
|
|
config.attn_config, |
|
|
hidden_size=self.hidden_size, |
|
|
is_causal=is_causal, |
|
|
self_attn=True, |
|
|
layer_idx=layer_idx, |
|
|
attn_implementation=attn_implementation, |
|
|
) |
|
|
self.ffn = FFN( |
|
|
config.ffn_config, hidden_size=self.hidden_size, layer_idx=layer_idx |
|
|
) |
|
|
self.attn_drop = Dropout( |
|
|
config.residual_dropout, mask_p=config.residual_response_dropout |
|
|
) |
|
|
self.ffn_drop = Dropout( |
|
|
config.residual_dropout, mask_p=config.residual_response_dropout |
|
|
) |
|
|
self.path_drop = ( |
|
|
ResidualPathDropout(self.drop_path) |
|
|
if self.drop_path > 0.0 |
|
|
else nn.Identity() |
|
|
) |
|
|
self.attn_lnorm = build_layer_norm(config.lnorm_config, size=hidden_size) |
|
|
self.ffn_lnorm = build_layer_norm(config.lnorm_config, size=hidden_size) |
|
|
self.attn_lscale = nn.Identity() |
|
|
self.ffn_lscale = nn.Identity() |
|
|
if self.attn_lscale_init is not None: |
|
|
self.attn_lscale = LayerScale(self.hidden_size, self.attn_lscale_init) |
|
|
if self.ffn_lscale_init is not None: |
|
|
self.ffn_lscale = LayerScale(self.hidden_size, self.ffn_lscale_init) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
x: torch.Tensor, |
|
|
rope_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
past_key_values: Optional[Cache] = None, |
|
|
cache_position: Optional[torch.LongTensor] = None, |
|
|
drop_mask: Optional[torch.Tensor] = None, |
|
|
attn_implementation: Optional[str] = None, |
|
|
**kwargs: Unpack[FlashAttentionKwargs], |
|
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: |
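        # Pre-norm (default) or post-norm residual block: attention sublayer
        # followed by the FFN sublayer, each wrapped with optional LayerScale,
        # stochastic depth (path_drop), and residual dropout.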
|
|
if not self.postnorm: |
|
|
x_norm = self.attn_lnorm(x) |
|
|
else: |
|
|
x_norm = x |
|
|
|
|
|
x_attn, x_attn_weights = self.attn( |
|
|
x_norm, |
|
|
rope_embeddings=rope_embeddings, |
|
|
attention_mask=attention_mask, |
|
|
past_key_values=past_key_values, |
|
|
cache_position=cache_position, |
|
|
attn_implementation=attn_implementation, |
|
|
**kwargs, |
|
|
) |
|
|
if self.postnorm: |
|
|
x_attn = self.attn_lnorm(x_attn) |
|
|
|
|
|
x_attn = self.path_drop(self.attn_lscale(x_attn)) |
|
|
x = x + self.attn_drop(x_attn, drop_mask=drop_mask) |
|
|
|
|
|
if not self.postnorm: |
|
|
x_norm = self.ffn_lnorm(x) |
|
|
else: |
|
|
x_norm = x |
|
|
|
|
|
x_ffn = self.ffn(x_norm) |
|
|
if self.postnorm: |
|
|
            x_ffn = self.ffn_lnorm(x_ffn)
|
|
|
|
|
x_ffn = self.path_drop(self.ffn_lscale(x_ffn)) |
|
|
x = x + self.ffn_drop(x_ffn, drop_mask=drop_mask) |
|
|
|
|
|
return x, x_attn_weights |
|
|
|
|
|
|
|
|
""" |
|
|
Vision Language Connector. Adapted from AllenAI Molmo https://github.com/allenai/molmo |
|
|
""" |
|
|
|
|
|
|
|
|
class VisionLanguageConnector(GradientCheckpointingLayer): |
|
|
"""Vision-Language Connector.""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
config: JinaVLConnectorConfig, |
|
|
input_size: int, |
|
|
intermediate_size: int, |
|
|
output_size: int, |
|
|
n_patches: Tuple[int, int], |
|
|
attn_implementation: Optional[str] = None, |
|
|
): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.input_size = input_size |
|
|
self.intermediate_size = intermediate_size |
|
|
self.output_size = output_size |
|
|
self.n_patches = n_patches |
|
|
|
|
|
self.padding_embed_type = config.padding_embed_type |
|
|
self.pooling_type = config.pooling_type |
|
|
self.projector_type = config.projector_type |
|
|
self.spatial_merge_size = config.spatial_merge_size |
|
|
self.pooling_h = config.pooling_h |
|
|
self.pooling_w = config.pooling_w |
|
|
self.pad_embed = None |
|
|
self.pooling = None |
|
|
self.projector: Union[nn.Linear, nn.ModuleList, FFN] |
|
|
|
|
|
if config.padding_embed_type is not None: |
|
|
if config.padding_embed_type in { |
|
|
ImagePaddingEmbedType.regress, |
|
|
ImagePaddingEmbedType.pad_embed, |
|
|
}: |
|
|
self.pad_embed = nn.Parameter(torch.zeros((self.input_size,))) |
|
|
else: |
|
|
self.pad_embed = nn.Parameter(torch.zeros((2, self.input_size))) |
|
|
|
|
|
pooling_input_size = self.input_size |
|
|
projector_input_size = self.intermediate_size |
|
|
if config.pooling_type in { |
|
|
ImagePooling2DType.attention, |
|
|
ImagePooling2DType.attention_meanq, |
|
|
ImagePooling2DType.attention_2wide, |
|
|
}: |
|
|
assert config.attn_pooling_config is not None |
|
|
if config.pooling_type == ImagePooling2DType.attention_2wide: |
|
|
pooling_input_size *= 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
attn_implementation = attn_implementation or 'eager' |
|
|
if attn_implementation.startswith('flash'): |
|
|
attn_implementation = 'sdpa' |
|
|
self.pooling = MHSDPA( |
|
|
config.attn_pooling_config, |
|
|
hidden_size=pooling_input_size, |
|
|
is_causal=False, |
|
|
self_attn=False, |
|
|
output_size=projector_input_size, |
|
|
attn_implementation=attn_implementation, |
|
|
) |
|
|
elif config.pooling_type in [ |
|
|
ImagePooling2DType.stack, |
|
|
ImagePooling2DType.token_merger, |
|
|
]: |
|
|
projector_input_size *= config.pooling_h * config.pooling_w |
|
|
|
|
|
if config.projector_type in { |
|
|
ImageProjectionType.mlpx2, |
|
|
ImageProjectionType.mlp, |
|
|
}: |
|
|
assert config.mlp_projector_config is not None |
|
|
mlp_projector_kwargs = dict( |
|
|
config=config.mlp_projector_config, |
|
|
hidden_size=projector_input_size, |
|
|
output_size=output_size, |
|
|
) |
|
|
if config.projector_type == ImageProjectionType.mlpx2: |
|
|
|
|
|
                self.projector = nn.ModuleList(
                    [
                        FFN(**mlp_projector_kwargs),
                        # The residual block operates on already-projected
                        # features, so its input and output sizes are both the
                        # connector output size.
                        Residual(
                            FFN(
                                config=config.mlp_projector_config,
                                hidden_size=output_size,
                                output_size=output_size,
                            )
                        ),
                    ]
                )
|
|
else: |
|
|
self.projector = FFN(**mlp_projector_kwargs) |
|
|
else: |
|
|
self.projector = nn.Linear( |
|
|
projector_input_size, |
|
|
output_size, |
|
|
bias=False, |
|
|
) |
|
|
self.projector_dropout = Dropout(config.projector_dropout) |
|
|
self.feature_dropout = Dropout(config.feature_dropout) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
image_features: torch.Tensor, |
|
|
image_masks: Optional[torch.Tensor] = None, |
|
|
attn_implementation: Optional[str] = None, |
|
|
**kwargs: Unpack[FlashAttentionKwargs], |
|
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: |
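        # Adds padding embeddings to masked patches, pools the per-crop patch
        # grid according to the configured pooling type, and projects the
        # pooled features to the connector output size.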
|
|
|
|
|
|
|
|
bs, ncrops = image_features.shape[:2] |
|
|
ogtype = image_features.dtype |
|
|
|
|
|
if self.padding_embed_type is not None: |
|
|
assert image_masks is not None |
|
|
if self.padding_embed_type == ImagePaddingEmbedType.pad_embed: |
|
|
all_pad = (image_masks == 0).to(dtype=torch.float32) |
|
|
pad_embed = self.pad_embed[None, None, None, :] |
|
|
image_features = image_features + pad_embed * torch.unsqueeze( |
|
|
all_pad, -1 |
|
|
) |
|
|
elif self.padding_embed_type == ImagePaddingEmbedType.regress: |
|
|
pad_embed = self.pad_embed[None, None, None, :] |
|
|
image_features = image_features + pad_embed * torch.unsqueeze( |
|
|
torch.maximum(image_masks, torch.zeros_like(image_masks)), -1 |
|
|
) |
|
|
else: |
|
|
pad_embed = self.pad_embed[:, None, None, None, :] |
|
|
all_pad = image_masks == 0 |
|
|
partial_pad = torch.logical_and( |
|
|
image_masks < 1, torch.logical_not(all_pad) |
|
|
).to(dtype=torch.float32) |
|
|
all_pad = all_pad.to(dtype=torch.float32) |
|
|
image_features = image_features + pad_embed[0] * torch.unsqueeze( |
|
|
all_pad, -1 |
|
|
) |
|
|
image_features = image_features + pad_embed[1] * torch.unsqueeze( |
|
|
partial_pad, -1 |
|
|
) |
|
|
|
|
|
image_features = image_features.to(dtype=ogtype) |
|
|
image_features = self.feature_dropout(image_features) |
|
|
image_features = image_features.reshape((bs, ncrops) + self.n_patches + (-1,)) |
|
|
pad_h = self.n_patches[0] % self.pooling_h |
|
|
pad_w = self.n_patches[1] % self.pooling_w |
|
|
if pad_h != 0 or pad_w != 0: |
|
|
|
|
|
image_features = f.pad( |
|
|
image_features, |
|
|
(0, 0, 0, pad_w, 0, pad_h, 0, 0, 0, 0), |
|
|
) |
|
|
if self.pooling_type == ImagePooling2DType.token_merger: |
|
|
context_dim = image_features.shape[-1] |
|
|
hidden_size = context_dim * (self.spatial_merge_size**2) |
|
|
image_features = image_features.view([-1, hidden_size]) |
|
|
else: |
|
|
image_features = einops.rearrange( |
|
|
image_features, |
|
|
'b n (h dh) (w dw) c -> (b n h w) (dh dw) c', |
|
|
dh=self.pooling_h, |
|
|
dw=self.pooling_w, |
|
|
) |
|
|
image_features = image_features.contiguous() |
|
|
if self.pooling_type == ImagePooling2DType.attention_meanq: |
|
|
query = image_features.mean(-2, keepdim=True) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
attn_implementation = attn_implementation or 'eager' |
|
|
if attn_implementation.startswith('flash'): |
|
|
attn_implementation = 'sdpa' |
|
|
if attn_implementation == 'sdpa': |
|
|
with sdpa_kernel(backends=[SDPBackend.MATH]): |
|
|
image_features, _ = self.pooling( |
|
|
xq=query, |
|
|
xk=image_features, |
|
|
attn_implementation='sdpa', |
|
|
**kwargs, |
|
|
) |
|
|
else: |
|
|
image_features, _ = self.pooling( |
|
|
xq=query, |
|
|
xk=image_features, |
|
|
attn_implementation=attn_implementation, |
|
|
**kwargs, |
|
|
) |
|
|
elif self.pooling_type not in { |
|
|
ImagePooling2DType.none, |
|
|
ImagePooling2DType.stack, |
|
|
}: |
|
|
            image_features, _ = self.pooling(
                image_features[:, :1, :], image_features
            )
|
|
|
|
|
h = self.n_patches[0] // self.pooling_h + pad_h |
|
|
w = self.n_patches[1] // self.pooling_w + pad_w |
|
|
|
|
|
image_features = image_features.reshape(bs, ncrops, h * w, -1) |
|
|
|
|
|
        if isinstance(self.projector, nn.ModuleList):
            # mlpx2 projector: apply the FFN, then the residual FFN, in order.
            for module in self.projector:
                image_features = module(image_features)
            return image_features
        return self.projector(image_features)
|
|
|