# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Qwen2-VL vision encoder model."""

import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss, LayerNorm
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)

from .configuration_qwen2vl_encoder import Qwen2VLVisionConfig

if is_flash_attn_2_available():
    from flash_attn import flash_attn_varlen_func
    from transformers.modeling_flash_attention_utils import _flash_attention_forward
else:
    flash_attn_varlen_func = None

logger = logging.get_logger(__name__)

# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

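# Applies rotary position embeddings to the flattened vision tokens: `freqs` holds the
# per-token rotation angles, cos/sin are duplicated across both halves of the head
# dimension, and the rotation runs in float32 before casting back to the input dtype.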
def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    orig_dtype = tensor.dtype
    tensor = tensor.float()
    cos = freqs.cos()
    sin = freqs.sin()
    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    output = (tensor * cos) + (rotate_half(tensor) * sin)
    output = output.to(orig_dtype)
    return output

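# Rotary embedding table for the vision tower: precomputes inverse frequencies and, given a
# sequence length, returns the outer product of positions and inverse frequencies (the angles).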
class VisionRotaryEmbedding(nn.Module):
    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(seq, self.inv_freq)
        return freqs

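# Non-overlapping 3D patchification: pixel values are reshaped into
# (temporal_patch_size, patch_size, patch_size) patches and projected to `embed_dim`
# with a Conv3d whose stride equals its kernel size.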
class PatchEmbed(nn.Module):
    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        kernel_size = [temporal_patch_size, patch_size, patch_size]
        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        target_dtype = self.proj.weight.dtype
        hidden_states = hidden_states.view(
            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
        )
        hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
        return hidden_states

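# Merges spatial_merge_size**2 neighbouring patch embeddings into a single token:
# LayerNorm, then a two-layer MLP that maps the concatenated features to `dim`.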
class PatchMerger(nn.Module):
    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.ln_q = LayerNorm(context_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
        return x

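# Feed-forward block of the vision transformer: Linear -> activation -> Linear.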
class VisionMlp(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = ACT2FN[hidden_act]
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))

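# Eager (matmul + softmax) self-attention over the packed patch sequence. `cu_seqlens`
# holds cumulative sequence lengths of the concatenated images/frames; the additive mask
# built from it is block-diagonal, so tokens only attend within their own image.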
class VisionAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(
        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

        attention_mask = torch.full(
            [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
        )
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)
        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
        attn_weights = attn_weights + attention_mask
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(0, 1)
        attn_output = attn_output.reshape(seq_length, -1)
        attn_output = self.proj(attn_output)
        return attn_output

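# Same attention computed with flash_attn_varlen_func, which consumes `cu_seqlens`
# directly instead of materialising the block-diagonal mask.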
class VisionFlashAttention2(nn.Module):
    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(
        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
            seq_length, -1
        )
        attn_output = self.proj(attn_output)
        return attn_output

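# SDPA variant: builds a boolean block-diagonal mask from `cu_seqlens` and delegates to
# torch.nn.functional.scaled_dot_product_attention.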
class VisionSdpaAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(
        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

        attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True

        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)
        attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
        attn_output = attn_output.transpose(0, 1)
        attn_output = attn_output.reshape(seq_length, -1)
        attn_output = self.proj(attn_output)
        return attn_output

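# Maps config._attn_implementation to the attention module used in each vision block.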
QWEN2_VL_VISION_ATTENTION_CLASSES = {
    "eager": VisionAttention,
    "flash_attention_2": VisionFlashAttention2,
    "sdpa": VisionSdpaAttention,
}

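# One pre-norm transformer block of the vision tower: LayerNorm -> attention (with rotary
# position embeddings) -> residual, then LayerNorm -> MLP -> residual.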
class Qwen2VLVisionBlock(nn.Module):
    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
        super().__init__()
        self.norm1 = LayerNorm(config.embed_dim, eps=1e-6)
        self.norm2 = LayerNorm(config.embed_dim, eps=1e-6)
        mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)

        self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation](
            config.embed_dim, num_heads=config.num_heads
        )
        self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)

    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
        )
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states

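# Hooks the encoder into the Hugging Face PreTrainedModel machinery: weight initialization,
# device placement, gradient checkpointing, and the supported attention backends.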
class Qwen2VLPreTrainedModel(PreTrainedModel):
    config_class = Qwen2VLVisionConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen2VLVisionBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_static_cache = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv3d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

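# The full vision tower: Conv3d patch embedding, rotary position embeddings derived from each
# image's (t, h, w) patch grid, and `config.depth` transformer blocks. Spatial patch merging
# (PatchMerger) is kept commented out here and left to downstream code.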
class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
    config_class = Qwen2VLVisionConfig
    _no_split_modules = ["Qwen2VLVisionBlock"]

    def __init__(self, config) -> None:
        super().__init__(config)
        self.spatial_merge_size = config.spatial_merge_size
        self.gradient_checkpointing = False

        self.patch_embed = PatchEmbed(
            patch_size=config.patch_size,
            temporal_patch_size=config.temporal_patch_size,
            in_channels=config.in_channels,
            embed_dim=config.embed_dim,
        )

        head_dim = config.embed_dim // config.num_heads
        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList(
            [Qwen2VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
        )

        # if self.spatial_merge_size > 1:
        #     self.merger = PatchMerger(dim=config.hidden_size, context_dim=config.embed_dim)

    def get_dtype(self) -> torch.dtype:
        return self.blocks[0].mlp.fc2.weight.dtype

    def get_device(self) -> torch.device:
        return self.blocks[0].mlp.fc2.weight.device

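    # Builds per-token (h, w) position ids for each image in the batch, reordering them into
    # stride x stride windows (window-major) so tokens that will later be merged stay contiguous,
    # then gathers the corresponding rotary angles from the precomputed table.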
    def rot_pos_emb(self, grid_thw, strides):
        pos_ids = []
        for (t, h, w), stride in zip(grid_thw, strides):
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // stride,
                stride,
                w // stride,
                stride,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // stride,
                stride,
                w // stride,
                stride,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()

            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))

        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

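    # Forward over packed patches from one or more images/videos: patch-embed, compute rotary
    # angles from grid_thws/strides, derive cumulative per-frame sequence lengths for the
    # attention mask, then run the blocks (optionally under gradient checkpointing).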
    def forward(self, hidden_states, grid_thws, strides) -> torch.Tensor:
        hidden_states = self.patch_embed(hidden_states)

        # BUG: the per-sample loop below causes a DeepSpeed issue
        # (`RuntimeError: disagreement between rank0 and rankx`):
        # rotary_pos_emb = []
        # for thw in grid_thws:
        #     rotary_pos_emb.append(self.rot_pos_emb(thw).unsqueeze(0))
        # rotary_pos_emb1 = torch.cat(rotary_pos_emb, dim=1).squeeze(0)
        # grid_thws = torch.cat(grid_thws, dim=0)
        #
        # New version of creating the rotary position embedding:
        # grid_thws has shape [batch_flatten_image_num, 3];
        # grid_thws = torch.cat(grid_thws, dim=0) is done in `encoder.py`.
        rotary_pos_emb = self.rot_pos_emb(grid_thws, strides)
        cu_seqlens = torch.repeat_interleave(grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]).cumsum(
            dim=0, dtype=torch.int32
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for blk in self.blocks:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    blk.__call__,
                    hidden_states,
                    cu_seqlens,
                    rotary_pos_emb,
                )
            else:
                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)

        # if self.spatial_merge_size > 1:
        #     hidden_states = self.merger(hidden_states)
        return hidden_states
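
# Minimal usage sketch, not part of the original file: it assumes Qwen2VLVisionConfig exposes
# the fields read above (depth, embed_dim, num_heads, mlp_ratio, hidden_act, patch_size,
# temporal_patch_size, in_channels, spatial_merge_size) and accepts them as keyword arguments.
#
#     config = Qwen2VLVisionConfig(...)  # hypothetical instantiation
#     model = Qwen2VisionTransformerPretrainedModel(config)
#     # One image whose patch grid is (t=1, h=32, w=32); each row is one flattened patch.
#     pixel_values = torch.randn(
#         1 * 32 * 32, config.in_channels * config.temporal_patch_size * config.patch_size**2
#     )
#     grid_thws = torch.tensor([[1, 32, 32]])
#     strides = [config.spatial_merge_size]
#     features = model(pixel_values, grid_thws, strides)  # shape (1024, config.embed_dim)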