Commit
·
fd01e7c
0
Parent(s):
Initial commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +39 -0
- README.md +95 -0
- asset/NVIDIA_OneWay_Noncommercial_License.docx +0 -0
- asset/omni_benchmarks.png +3 -0
- asset/omni_benchmarks2.png +3 -0
- asset/performance.png +3 -0
- audio_encoder.py +70 -0
- auto_processor.py +476 -0
- base_projector.py +236 -0
- builder.py +253 -0
- config.json +0 -0
- configuration_vila.py +127 -0
- constants.py +83 -0
- conversation.py +189 -0
- distributed.py +89 -0
- environment_setup.sh +60 -0
- example_infer.py +335 -0
- example_mini_audio.py +89 -0
- example_mini_image.py +76 -0
- example_mini_video.py +101 -0
- llm/added_tokens.json +20 -0
- llm/config.json +0 -0
- llm/generation_config.json +14 -0
- llm/merges.txt +0 -0
- llm/model-00001-of-00004.safetensors +3 -0
- llm/model-00002-of-00004.safetensors +3 -0
- llm/model-00003-of-00004.safetensors +3 -0
- llm/model-00004-of-00004.safetensors +3 -0
- llm/model.safetensors.index.json +346 -0
- llm/special_tokens_map.json +39 -0
- llm/tokenizer.json +3 -0
- llm/tokenizer_config.json +165 -0
- llm/vocab.json +0 -0
- media.py +555 -0
- media_encoder.py +955 -0
- mm_projector/config.json +10 -0
- mm_projector/model.safetensors +3 -0
- mm_utils.py +567 -0
- model_utils_packing.py +50 -0
- modeling_vila.py +1834 -0
- preprocessor_config.json +13 -0
- pyproject.toml +59 -0
- qwen2.jinja +11 -0
- qwen_audio_encoder.py +89 -0
- siglip_encoder.py +293 -0
- sound_base_projector.py +126 -0
- sound_mm_projector/config.json +10 -0
- sound_mm_projector/model.safetensors +3 -0
- sound_tower/config.json +50 -0
- sound_tower/model.safetensors +3 -0
.gitattributes
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
asset/omni_benchmarks.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
asset/omni_benchmarks2.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
llm/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
asset/performance.png filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# <span style="background: linear-gradient(45deg, #667eea 0%, #764ba2 25%, #f093fb 50%, #f5576c 75%, #4facfe 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; background-clip: text; font-weight: bold; font-size: 1.1em;">**OmniVinci: Enhancing Architecture and Data for Omni-Modal Understanding LLM**</span> <br />
|
| 2 |
+
|
| 3 |
+
[](https://arxiv.org/)
|
| 4 |
+
[](https://github.com/NVlabs)
|
| 5 |
+
[](https://huggingface.co/nvidia/omnivinci)
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
## Introduction
|
| 9 |
+
OmniVinci is an NVIDIA research project focused on exploring omni-modal LLMs that can not only see and read but also listen, speak, and reason.
|
| 10 |
+
|
| 11 |
+
We are among the best omni-modality understanding models. Check out our performance on some of the most popular omni-modality, audio, and vision benchmarks:
|
| 12 |
+
<p align="center">
|
| 13 |
+
<img src="./asset/performance.png" width="80%"/>
|
| 14 |
+
<p>
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
## Quickstart
|
| 18 |
+
|
| 19 |
+
Below, we provide simple examples to show how to use our model with Transformers.
|
| 20 |
+
|
| 21 |
+
### Environment Setup
|
| 22 |
+
|
| 23 |
+
1. Download and navigate to the HuggingFace repository:
|
| 24 |
+
```
|
| 25 |
+
huggingface-cli download nvidia/omnivinci --local-dir ./omnivinci --local-dir-use-symlinks False
|
| 26 |
+
cd ./omnivinci
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
2. Install Python environment (based on NVILA codebase):
|
| 30 |
+
```
|
| 31 |
+
bash ./environment_setup.sh omnivinci
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
### 🤗 Transformers Usage
|
| 35 |
+
|
| 36 |
+
#### Video (with Audio) Inference Example
|
| 37 |
+
```python
|
| 38 |
+
from transformers import AutoProcessor, AutoModel, AutoConfig,AutoModelForCausalLM
|
| 39 |
+
import torch
|
| 40 |
+
import os
|
| 41 |
+
|
| 42 |
+
# default: Load the model on the available device(s)
|
| 43 |
+
model_path = "./"
|
| 44 |
+
video_path = "xxx.mp4"
|
| 45 |
+
generation_kwargs = {"max_new_tokens": 1024, "max_length": 99999999}
|
| 46 |
+
load_audio_in_video = True
|
| 47 |
+
num_video_frames = 128
|
| 48 |
+
audio_length = "max_3600"
|
| 49 |
+
|
| 50 |
+
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
| 51 |
+
|
| 52 |
+
model = AutoModel.from_pretrained(model_path,
|
| 53 |
+
trust_remote_code=True,
|
| 54 |
+
torch_dtype="torch.float16",
|
| 55 |
+
device_map="auto")
|
| 56 |
+
|
| 57 |
+
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
|
| 58 |
+
generation_config = model.default_generation_config
|
| 59 |
+
generation_config.update(**generation_kwargs)
|
| 60 |
+
|
| 61 |
+
model.config.load_audio_in_video = load_audio_in_video
|
| 62 |
+
processor.config.load_audio_in_video = load_audio_in_video
|
| 63 |
+
if num_video_frames > 0:
|
| 64 |
+
model.config.num_video_frames = num_video_frames
|
| 65 |
+
processor.config.num_video_frames = num_video_frames
|
| 66 |
+
if audio_length != -1:
|
| 67 |
+
model.config.audio_chunk_length = audio_length
|
| 68 |
+
processor.config.audio_chunk_length = audio_length
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
conversation = [{
|
| 72 |
+
"role": "user",
|
| 73 |
+
"content": [
|
| 74 |
+
{"type": "video", "video":video_path},
|
| 75 |
+
{"type": "text", "text": "Assess the video, followed by a detailed description of its video and audio contents."}
|
| 76 |
+
]
|
| 77 |
+
}]
|
| 78 |
+
text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
|
| 79 |
+
|
| 80 |
+
inputs = processor([text])
|
| 81 |
+
|
| 82 |
+
output_ids = model.generate(
|
| 83 |
+
input_ids=inputs.input_ids,
|
| 84 |
+
media=getattr(inputs, 'media', None),
|
| 85 |
+
media_config=getattr(inputs, 'media_config', None),
|
| 86 |
+
generation_config=generation_config,
|
| 87 |
+
)
|
| 88 |
+
print(processor.tokenizer.batch_decode(output_ids, skip_special_tokens=True))
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
- **For audio and image inference examples, please refer to `example_mini_audio.py` and `example_mini_image.py`.**
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
## License / Terms of Use
|
| 95 |
+
The model is released under the [NVIDIA OneWay Noncommercial License](asset/NVIDIA_OneWay_Noncommercial_License.docx).
|
asset/NVIDIA_OneWay_Noncommercial_License.docx
ADDED
|
Binary file (20.6 kB). View file
|
|
|
asset/omni_benchmarks.png
ADDED
|
Git LFS Details
|
asset/omni_benchmarks2.png
ADDED
|
Git LFS Details
|
asset/performance.png
ADDED
|
Git LFS Details
|
audio_encoder.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class AudioTower(nn.Module):
|
| 23 |
+
def __init__(self, audio_tower, args, delay_load=False):
|
| 24 |
+
super().__init__()
|
| 25 |
+
|
| 26 |
+
self.is_loaded = False
|
| 27 |
+
|
| 28 |
+
self.audio_tower_name = audio_tower
|
| 29 |
+
self.cfg_only = None
|
| 30 |
+
|
| 31 |
+
def forward(self, sounds):
|
| 32 |
+
if type(sounds) is list:
|
| 33 |
+
sound_features = []
|
| 34 |
+
audio_output_lengths = []
|
| 35 |
+
for sound in sounds:
|
| 36 |
+
if hasattr(sound, "input_features"):
|
| 37 |
+
sound = sound["input_features"]
|
| 38 |
+
sound_feature = self.audio_tower(sound)
|
| 39 |
+
sound_feature = sound_feature.last_hidden_state
|
| 40 |
+
sound_feature = sound_feature.to(sound.dtype)
|
| 41 |
+
sound_features.append(sound_feature)
|
| 42 |
+
audio_output_lengths.append(sound_feature.shape[1])
|
| 43 |
+
sound_features = torch.cat(sound_features, dim=1).squeeze(0)
|
| 44 |
+
else:
|
| 45 |
+
raise NotImplementedError("Not implemented for this encoder")
|
| 46 |
+
|
| 47 |
+
return sound_features, audio_output_lengths
|
| 48 |
+
|
| 49 |
+
@property
|
| 50 |
+
def dummy_feature(self):
|
| 51 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
| 52 |
+
|
| 53 |
+
@property
|
| 54 |
+
def dtype(self):
|
| 55 |
+
return self.audio_tower.dtype
|
| 56 |
+
|
| 57 |
+
@property
|
| 58 |
+
def config(self):
|
| 59 |
+
if self.is_loaded:
|
| 60 |
+
return self.audio_tower.config
|
| 61 |
+
else:
|
| 62 |
+
return self.cfg_only
|
| 63 |
+
|
| 64 |
+
@property
|
| 65 |
+
def device(self):
|
| 66 |
+
return self.audio_tower.device
|
| 67 |
+
|
| 68 |
+
@property
|
| 69 |
+
def hidden_size(self):
|
| 70 |
+
return self.config.hidden_size
|
auto_processor.py
ADDED
|
@@ -0,0 +1,476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import copy
|
| 17 |
+
import os
|
| 18 |
+
import os.path as osp
|
| 19 |
+
import warnings
|
| 20 |
+
from collections import defaultdict
|
| 21 |
+
from io import BytesIO
|
| 22 |
+
from typing import List, Optional, Union
|
| 23 |
+
|
| 24 |
+
import PIL.Image
|
| 25 |
+
import requests
|
| 26 |
+
import torch
|
| 27 |
+
from transformers import AutoConfig, AutoImageProcessor, AutoModel, AutoProcessor, AutoTokenizer
|
| 28 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 29 |
+
from transformers.image_utils import ImageInput
|
| 30 |
+
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
| 31 |
+
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
| 32 |
+
from transformers.utils import logging
|
| 33 |
+
|
| 34 |
+
from .constants import DEFAULT_IMAGE_TOKEN, MEDIA_TOKENS
|
| 35 |
+
from .media import Image, Video, extract_media, Sound
|
| 36 |
+
from .mm_utils import process_image, process_images
|
| 37 |
+
from .tokenizer_utils import tokenize_conversation
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def to_rgb(pil_image: PIL.Image.Image) -> PIL.Image.Image:
|
| 41 |
+
"""Convert PIL image to RGB format."""
|
| 42 |
+
if pil_image.mode == "RGBA":
|
| 43 |
+
white_background = PIL.Image.new("RGB", pil_image.size, (255, 255, 255))
|
| 44 |
+
white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
|
| 45 |
+
return white_background
|
| 46 |
+
else:
|
| 47 |
+
return pil_image.convert("RGB")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def fetch_image(ele: dict[str, str | PIL.Image.Image], size_factor=None) -> PIL.Image.Image:
|
| 51 |
+
"""Fetch and load image from various sources (local path, URL, base64, PIL.Image)."""
|
| 52 |
+
if "image" in ele:
|
| 53 |
+
image = ele["image"]
|
| 54 |
+
else:
|
| 55 |
+
image = ele["image_url"]
|
| 56 |
+
image_obj = None
|
| 57 |
+
if isinstance(image, PIL.Image.Image):
|
| 58 |
+
image_obj = image
|
| 59 |
+
elif image.startswith("http://") or image.startswith("https://"):
|
| 60 |
+
response = requests.get(image, stream=True)
|
| 61 |
+
image_obj = PIL.Image.open(BytesIO(response.content))
|
| 62 |
+
elif image.startswith("file://"):
|
| 63 |
+
image_obj = PIL.Image.open(image[7:])
|
| 64 |
+
elif image.startswith("data:image"):
|
| 65 |
+
if "base64," in image:
|
| 66 |
+
_, base64_data = image.split("base64,", 1)
|
| 67 |
+
data = base64.b64decode(base64_data)
|
| 68 |
+
image_obj = PIL.Image.open(BytesIO(data))
|
| 69 |
+
else:
|
| 70 |
+
image_obj = PIL.Image.open(image)
|
| 71 |
+
if image_obj is None:
|
| 72 |
+
raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
|
| 73 |
+
image = to_rgb(image_obj)
|
| 74 |
+
|
| 75 |
+
return image
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def fetch_image_url_or_fpath(url_or_fpath):
|
| 79 |
+
"""Fetch image from URL or local file path, returns local file path."""
|
| 80 |
+
if url_or_fpath.startswith("http") or url_or_fpath.startswith("https"):
|
| 81 |
+
import tempfile
|
| 82 |
+
|
| 83 |
+
import requests
|
| 84 |
+
|
| 85 |
+
# Download the image to a temporary file
|
| 86 |
+
temp_dir = tempfile.mkdtemp()
|
| 87 |
+
temp_file = os.path.join(temp_dir, os.path.basename(url_or_fpath))
|
| 88 |
+
|
| 89 |
+
response = requests.get(url_or_fpath, stream=True)
|
| 90 |
+
response.raise_for_status()
|
| 91 |
+
|
| 92 |
+
with open(temp_file, "wb") as f:
|
| 93 |
+
for chunk in response.iter_content(chunk_size=8192):
|
| 94 |
+
f.write(chunk)
|
| 95 |
+
|
| 96 |
+
return temp_file
|
| 97 |
+
elif url_or_fpath.startswith("file://"):
|
| 98 |
+
fpath = url_or_fpath.replace("file://", "")
|
| 99 |
+
assert osp.exists(fpath), f"File {fpath} does not exist"
|
| 100 |
+
return fpath
|
| 101 |
+
elif osp.exists(url_or_fpath):
|
| 102 |
+
assert osp.isfile(url_or_fpath), f"File {url_or_fpath} does not exist"
|
| 103 |
+
return url_or_fpath
|
| 104 |
+
else:
|
| 105 |
+
raise ValueError(f"Unsupported image path: {url_or_fpath}")
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def pad_fn(input_ids_list: List[torch.Tensor], padding_value=0, target_len=None, padding_side="left") -> torch.Tensor:
|
| 109 |
+
# tensor shape is (batch_size, seq_len)
|
| 110 |
+
max_len = max([ids.shape[1] for ids in input_ids_list])
|
| 111 |
+
if target_len is not None:
|
| 112 |
+
assert target_len >= max_len, "target_len must be greater than or equal to max_len"
|
| 113 |
+
max_len = target_len
|
| 114 |
+
|
| 115 |
+
new_input_ids_list = []
|
| 116 |
+
for i, input_ids in enumerate(input_ids_list):
|
| 117 |
+
pad_tensor = torch.ones_like(input_ids) * padding_value
|
| 118 |
+
curr_len = input_ids.shape[1]
|
| 119 |
+
pad_tensor = pad_tensor[:, : max_len - curr_len]
|
| 120 |
+
if padding_side == "right":
|
| 121 |
+
input_ids = torch.cat((input_ids, pad_tensor), dim=1)
|
| 122 |
+
else:
|
| 123 |
+
input_ids = torch.cat((pad_tensor, input_ids), dim=1)
|
| 124 |
+
new_input_ids_list.append(input_ids)
|
| 125 |
+
return torch.cat(new_input_ids_list, dim=0)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def extract_value_from_conv(chat):
|
| 129 |
+
value = []
|
| 130 |
+
if isinstance(chat["content"], str):
|
| 131 |
+
value.append(chat["content"])
|
| 132 |
+
return value
|
| 133 |
+
|
| 134 |
+
# otherwise, it's a list of content
|
| 135 |
+
for content in chat["content"]:
|
| 136 |
+
if content["type"] == "image":
|
| 137 |
+
if "path" in content:
|
| 138 |
+
# VILA style, can be either filepath or http url
|
| 139 |
+
value.append(Image(fetch_image_url_or_fpath(content["path"])))
|
| 140 |
+
elif "image" in content:
|
| 141 |
+
# Qwen style
|
| 142 |
+
value.append(Image(fetch_image_url_or_fpath(content["image"])))
|
| 143 |
+
elif "image_pil" in content:
|
| 144 |
+
# Qwen style
|
| 145 |
+
assert isinstance(content["image_pil"], PIL.Image.Image), f"Type of image_pil must be PIL.Image.Image"
|
| 146 |
+
value.append(content["image_pil"])
|
| 147 |
+
else:
|
| 148 |
+
raise ValueError(f"Type = `image` , but no `path` or `image` in {chat['content']}")
|
| 149 |
+
elif content["type"] == "video":
|
| 150 |
+
if "video" in content:
|
| 151 |
+
# Qwen style
|
| 152 |
+
value.append(Video(fetch_image_url_or_fpath(content["video"])))
|
| 153 |
+
else:
|
| 154 |
+
raise ValueError(f"Type = `video` , but no `video` in {chat['content']}")
|
| 155 |
+
elif content["type"] == "text":
|
| 156 |
+
value.append(content["text"])
|
| 157 |
+
elif content["type"] == "audio":
|
| 158 |
+
value.append(Sound(fetch_image_url_or_fpath(content["audio"])))
|
| 159 |
+
elif content["type"] == "sound":
|
| 160 |
+
value.append(Sound(fetch_image_url_or_fpath(content["sound"])))
|
| 161 |
+
elif content["type"] == "speech":
|
| 162 |
+
value.append(Sound(fetch_image_url_or_fpath(content["speech"])))
|
| 163 |
+
else:
|
| 164 |
+
raise ValueError(f"Unsupported content type: {content['type']}")
|
| 165 |
+
return value
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class VILAProcessorKwargs(ProcessingKwargs, total=False):
|
| 169 |
+
_defaults = {
|
| 170 |
+
"text_kwargs": {
|
| 171 |
+
"padding": False,
|
| 172 |
+
},
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class VILAProcessor(ProcessorMixin):
|
| 177 |
+
attributes = []
|
| 178 |
+
valid_kwargs = []
|
| 179 |
+
|
| 180 |
+
def __init__(
|
| 181 |
+
self, image_processor=None, tokenizer=None, chat_template=None, config=None, padding_side="left", **kwargs
|
| 182 |
+
):
|
| 183 |
+
self.image_token = MEDIA_TOKENS["image"]
|
| 184 |
+
self.video_token = MEDIA_TOKENS["video"]
|
| 185 |
+
self.speech_token = MEDIA_TOKENS["speech"]
|
| 186 |
+
self.sound_token = MEDIA_TOKENS["sound"]
|
| 187 |
+
self.config = config
|
| 188 |
+
self.image_processor = image_processor
|
| 189 |
+
self.tokenizer = tokenizer
|
| 190 |
+
self.padding_side = padding_side
|
| 191 |
+
|
| 192 |
+
# Use <|endoftext|> token as padding token for Qwen models
|
| 193 |
+
self.pad_token_id = self.tokenizer("<|endoftext|>").input_ids[0]
|
| 194 |
+
self.eos_token_id = self.tokenizer.eos_token_id
|
| 195 |
+
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
| 196 |
+
|
| 197 |
+
@staticmethod
|
| 198 |
+
def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
|
| 199 |
+
"""
|
| 200 |
+
Extract vision information from conversations.
|
| 201 |
+
Reference: qwen_vl_utils
|
| 202 |
+
"""
|
| 203 |
+
vision_infos = []
|
| 204 |
+
if isinstance(conversations[0], dict):
|
| 205 |
+
conversations = [conversations]
|
| 206 |
+
for conversation in conversations:
|
| 207 |
+
for message in conversation:
|
| 208 |
+
if isinstance(message["content"], list):
|
| 209 |
+
for ele in message["content"]:
|
| 210 |
+
if (
|
| 211 |
+
"image" in ele
|
| 212 |
+
or "image_url" in ele
|
| 213 |
+
or "video" in ele
|
| 214 |
+
or ele["type"] in ("image", "image_url", "video")
|
| 215 |
+
):
|
| 216 |
+
vision_infos.append(ele)
|
| 217 |
+
return vision_infos
|
| 218 |
+
|
| 219 |
+
@staticmethod
|
| 220 |
+
def process_vision_info(
|
| 221 |
+
conversations: list[dict] | list[list[dict]],
|
| 222 |
+
return_video_kwargs: bool = False,
|
| 223 |
+
) -> tuple[list[PIL.Image.Image] | None, list[torch.Tensor | list[PIL.Image.Image]] | None, Optional[dict]]:
|
| 224 |
+
"""
|
| 225 |
+
Process vision information from conversations.
|
| 226 |
+
Reference: qwen_vl_utils
|
| 227 |
+
|
| 228 |
+
Note: NVILA does not depend on this function, but maintains the same interface.
|
| 229 |
+
"""
|
| 230 |
+
vision_infos = extract_vision_info(conversations)
|
| 231 |
+
# Read images or videos
|
| 232 |
+
image_inputs = []
|
| 233 |
+
video_inputs = []
|
| 234 |
+
video_sample_fps_list = []
|
| 235 |
+
for vision_info in vision_infos:
|
| 236 |
+
if "image" in vision_info or "image_url" in vision_info:
|
| 237 |
+
image_inputs.append(fetch_image(vision_info))
|
| 238 |
+
elif "video" in vision_info:
|
| 239 |
+
video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True)
|
| 240 |
+
video_sample_fps_list.append(video_sample_fps)
|
| 241 |
+
video_inputs.append(video_input)
|
| 242 |
+
else:
|
| 243 |
+
raise ValueError("image, image_url or video should in content.")
|
| 244 |
+
if len(image_inputs) == 0:
|
| 245 |
+
image_inputs = None
|
| 246 |
+
if len(video_inputs) == 0:
|
| 247 |
+
video_inputs = None
|
| 248 |
+
if return_video_kwargs:
|
| 249 |
+
return image_inputs, video_inputs, {"fps": video_sample_fps_list}
|
| 250 |
+
return image_inputs, video_inputs
|
| 251 |
+
|
| 252 |
+
@staticmethod
|
| 253 |
+
def move_data_to_device(cls, prompt_inputs):
|
| 254 |
+
def _move_data_to_device(item):
|
| 255 |
+
# wrap function grpo trainer _prepare_input
|
| 256 |
+
kwargs = {"device": cls.args.device}
|
| 257 |
+
if cls.is_deepspeed_enabled and (torch.is_floating_point(item) or torch.is_complex(item)):
|
| 258 |
+
kwargs.update({"dtype": cls.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
|
| 259 |
+
return item.to(**kwargs)
|
| 260 |
+
|
| 261 |
+
prompt_inputs.input_ids = _move_data_to_device(prompt_inputs.input_ids)
|
| 262 |
+
prompt_inputs.attention_mask = _move_data_to_device(prompt_inputs.attention_mask)
|
| 263 |
+
if "image" in prompt_inputs.media:
|
| 264 |
+
prompt_inputs.media["image"] = [_move_data_to_device(img) for img in prompt_inputs.media["image"]]
|
| 265 |
+
return prompt_inputs
|
| 266 |
+
|
| 267 |
+
@classmethod
|
| 268 |
+
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
| 269 |
+
padding_side = kwargs.get("padding_side", "left")
|
| 270 |
+
if os.path.isdir(pretrained_model_name_or_path):
|
| 271 |
+
pretrained_model_name_or_path = pretrained_model_name_or_path
|
| 272 |
+
else:
|
| 273 |
+
print(f"pretrained_model_name_or_path {pretrained_model_name_or_path} is not a directory, downloading")
|
| 274 |
+
from huggingface_hub import snapshot_download
|
| 275 |
+
|
| 276 |
+
pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path)
|
| 277 |
+
|
| 278 |
+
image_processor = AutoImageProcessor.from_pretrained(
|
| 279 |
+
osp.join(pretrained_model_name_or_path, "vision_tower"), trust_remote_code=True
|
| 280 |
+
)
|
| 281 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 282 |
+
osp.join(pretrained_model_name_or_path, "llm"), trust_remote_code=True
|
| 283 |
+
)
|
| 284 |
+
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
|
| 285 |
+
|
| 286 |
+
return cls(image_processor=image_processor, tokenizer=tokenizer, config=config, padding_side=padding_side)
|
| 287 |
+
|
| 288 |
+
def __repr__(self):
|
| 289 |
+
return f"VILAProcessor(image_processor=SigLip, tokenizer={self.tokenizer}, config={self.config})"
|
| 290 |
+
|
| 291 |
+
def __call__(
|
| 292 |
+
self,
|
| 293 |
+
conversation=None,
|
| 294 |
+
**kwargs: Unpack[VILAProcessorKwargs],
|
| 295 |
+
) -> BatchFeature:
|
| 296 |
+
"""
|
| 297 |
+
The `conv` will be look like
|
| 298 |
+
[
|
| 299 |
+
{
|
| 300 |
+
'from': 'human',
|
| 301 |
+
'value': [
|
| 302 |
+
<transformers_modules.NVILA-Lite-2B-hf-preview.media.Image object at 0x154e68e4c460>,
|
| 303 |
+
'What are the common elements in these pictures?'
|
| 304 |
+
]
|
| 305 |
+
}
|
| 306 |
+
]
|
| 307 |
+
and `conversation` will be a list of such `conv`s
|
| 308 |
+
"""
|
| 309 |
+
if kwargs.get("text", None) is not None:
|
| 310 |
+
conversation = kwargs.get("text")
|
| 311 |
+
assert conversation is not None, "`conversation` or `text` is required"
|
| 312 |
+
padding_side = kwargs.get("padding_side", self.padding_side)
|
| 313 |
+
|
| 314 |
+
input_ids_list = []
|
| 315 |
+
attention_mask = []
|
| 316 |
+
media = defaultdict(list)
|
| 317 |
+
media_config = defaultdict(dict)
|
| 318 |
+
for conv in conversation:
|
| 319 |
+
feat = self.__single_call__(conv, **kwargs)
|
| 320 |
+
input_ids_list.append(feat.input_ids)
|
| 321 |
+
attention_mask.append(feat.attention_mask)
|
| 322 |
+
for name in feat.media:
|
| 323 |
+
media[name] += feat.media[name]
|
| 324 |
+
for name in feat.media_config:
|
| 325 |
+
media_config[name].update(feat.media_config[name])
|
| 326 |
+
|
| 327 |
+
# pad the input_ids to batchfy
|
| 328 |
+
input_ids = pad_fn(
|
| 329 |
+
input_ids_list,
|
| 330 |
+
padding_value=self.pad_token_id,
|
| 331 |
+
padding_side=padding_side,
|
| 332 |
+
)
|
| 333 |
+
# Ignore the pad token in the attention mask
|
| 334 |
+
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
|
| 335 |
+
attention_mask[input_ids == self.pad_token_id] = False
|
| 336 |
+
input_texts = self.tokenizer.batch_decode(input_ids)
|
| 337 |
+
bdata = BatchFeature(
|
| 338 |
+
data={
|
| 339 |
+
# "input_texts": input_texts,
|
| 340 |
+
"input_ids": input_ids,
|
| 341 |
+
"attention_mask": attention_mask,
|
| 342 |
+
"media": media,
|
| 343 |
+
"media_config": media_config,
|
| 344 |
+
}
|
| 345 |
+
)
|
| 346 |
+
return bdata
|
| 347 |
+
|
| 348 |
+
def __single_call__(
|
| 349 |
+
self,
|
| 350 |
+
conversation,
|
| 351 |
+
images: ImageInput = None,
|
| 352 |
+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
| 353 |
+
videos = None,
|
| 354 |
+
**kwargs: Unpack[VILAProcessorKwargs],
|
| 355 |
+
) -> BatchFeature:
|
| 356 |
+
conversation = copy.deepcopy(conversation)
|
| 357 |
+
media = extract_media(conversation, self.config)
|
| 358 |
+
# Process media
|
| 359 |
+
media_config = defaultdict(dict)
|
| 360 |
+
for name in media:
|
| 361 |
+
if name == "image":
|
| 362 |
+
if len(media["image"]) == 1 and self.config.image_aspect_ratio in ["dynamic", "dynamic_s2"]:
|
| 363 |
+
self.config.image_processor = self.image_processor
|
| 364 |
+
if self.config.image_aspect_ratio == "dynamic":
|
| 365 |
+
images = process_image(media["image"][0], self.config, None, enable_dynamic_res=True).half()
|
| 366 |
+
# Note: This assumes images appear at the first conversation position
|
| 367 |
+
conversation[0]["value"] = conversation[0]["value"].replace(
|
| 368 |
+
DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0]
|
| 369 |
+
)
|
| 370 |
+
else:
|
| 371 |
+
if type(self.config.s2_scales) is str:
|
| 372 |
+
self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
|
| 373 |
+
images, block_sizes = process_image(
|
| 374 |
+
media["image"][0], self.config, None, enable_dynamic_s2=True
|
| 375 |
+
)
|
| 376 |
+
images = images.half()
|
| 377 |
+
media_config[name]["block_sizes"] = [block_sizes]
|
| 378 |
+
else:
|
| 379 |
+
images = process_images(media["image"], self.image_processor, self.config).half()
|
| 380 |
+
media[name] = [image for image in images]
|
| 381 |
+
elif name == "video":
|
| 382 |
+
media[name] = [
|
| 383 |
+
process_images(images, self.image_processor, self.config).half() for images in media[name]
|
| 384 |
+
]
|
| 385 |
+
elif name == "speech":
|
| 386 |
+
speeches = media["speech"]
|
| 387 |
+
media[name] = [speech for speech in speeches]
|
| 388 |
+
elif name == "sound":
|
| 389 |
+
sounds = media["sound"]
|
| 390 |
+
for sound in sounds:
|
| 391 |
+
if type(sound) is dict:
|
| 392 |
+
for k, v in sound.items():
|
| 393 |
+
sound[k] = v.half()
|
| 394 |
+
media[name] = [sound for sound in sounds]
|
| 395 |
+
elif name == "video_info":
|
| 396 |
+
media[name] = [media["video_info"]]
|
| 397 |
+
elif name == "audio_info":
|
| 398 |
+
media[name] = [media["audio_info"]]
|
| 399 |
+
else:
|
| 400 |
+
raise ValueError(f"Unsupported media type: {name}")
|
| 401 |
+
|
| 402 |
+
inputs = tokenize_conversation(
|
| 403 |
+
conversation,
|
| 404 |
+
self.tokenizer,
|
| 405 |
+
mm_use_bos_eos_tokens=self.config.mm_use_bos_eos_tokens,
|
| 406 |
+
unified_audio_encoder=self.config.unified_audio_encoder,
|
| 407 |
+
add_generation_prompt=True,
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
input_ids = inputs.unsqueeze(0)
|
| 411 |
+
|
| 412 |
+
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
|
| 413 |
+
return BatchFeature(
|
| 414 |
+
data={
|
| 415 |
+
"input_ids": input_ids,
|
| 416 |
+
"attention_mask": attention_mask,
|
| 417 |
+
"media": media,
|
| 418 |
+
"media_config": media_config,
|
| 419 |
+
}
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
def batch_decode(self, *args, **kwargs):
|
| 423 |
+
"""
|
| 424 |
+
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
| 425 |
+
refer to the docstring of this method for more information.
|
| 426 |
+
"""
|
| 427 |
+
return self.tokenizer.batch_decode(*args, **kwargs)
|
| 428 |
+
|
| 429 |
+
def decode(self, *args, **kwargs):
|
| 430 |
+
"""
|
| 431 |
+
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
| 432 |
+
the docstring of this method for more information.
|
| 433 |
+
"""
|
| 434 |
+
return self.tokenizer.decode(*args, **kwargs)
|
| 435 |
+
|
| 436 |
+
def post_process_image_text_to_text(self, generated_outputs):
|
| 437 |
+
"""
|
| 438 |
+
Post-process the output of the model to decode the text.
|
| 439 |
+
|
| 440 |
+
Args:
|
| 441 |
+
generated_outputs (`torch.Tensor` or `np.ndarray`):
|
| 442 |
+
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
|
| 443 |
+
or `(sequence_length,)`.
|
| 444 |
+
|
| 445 |
+
Returns:
|
| 446 |
+
`List[str]`: The decoded text.
|
| 447 |
+
"""
|
| 448 |
+
return self.tokenizer.batch_decode(
|
| 449 |
+
generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
@property
|
| 453 |
+
def model_input_names(self):
|
| 454 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
| 455 |
+
image_processor_input_names = self.image_processor.model_input_names
|
| 456 |
+
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
| 457 |
+
|
| 458 |
+
def convert_gpt_conv_to_vila_conv(self, conversation):
|
| 459 |
+
vila_conv = []
|
| 460 |
+
for chat in conversation:
|
| 461 |
+
vila_chat = {"from": "", "value": []}
|
| 462 |
+
if chat["role"] in ("user", "system"):
|
| 463 |
+
# user allows to input image and text
|
| 464 |
+
vila_chat["from"] = "human" if chat["role"] == "user" else "system"
|
| 465 |
+
vila_chat["value"] = extract_value_from_conv(chat)
|
| 466 |
+
elif chat["role"] == "assistant":
|
| 467 |
+
vila_chat["from"] = "gpt"
|
| 468 |
+
vila_chat["value"] = extract_value_from_conv(chat)
|
| 469 |
+
else:
|
| 470 |
+
raise ValueError(f"Unsupported role: {chat['role']} in chat {chat}")
|
| 471 |
+
vila_conv.append(vila_chat)
|
| 472 |
+
|
| 473 |
+
return vila_conv
|
| 474 |
+
|
| 475 |
+
def apply_chat_template(self, conversation, add_generation_prompt=True, **kwargs):
|
| 476 |
+
return self.convert_gpt_conv_to_vila_conv(conversation)
|
base_projector.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import re
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class IdentityMap(nn.Module):
|
| 25 |
+
"""Identity mapping that returns input unchanged."""
|
| 26 |
+
def __init__(self):
|
| 27 |
+
super().__init__()
|
| 28 |
+
|
| 29 |
+
def forward(self, x, *args, **kwargs):
|
| 30 |
+
return x
|
| 31 |
+
|
| 32 |
+
@property
|
| 33 |
+
def config(self):
|
| 34 |
+
return {"mm_projector_type": "identity"}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class SimpleResBlock(nn.Module):
|
| 38 |
+
"""Simple residual block with layer normalization."""
|
| 39 |
+
|
| 40 |
+
def __init__(self, channels):
|
| 41 |
+
super().__init__()
|
| 42 |
+
self.pre_norm = nn.LayerNorm(channels)
|
| 43 |
+
self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels))
|
| 44 |
+
|
| 45 |
+
def forward(self, x):
|
| 46 |
+
x = self.pre_norm(x)
|
| 47 |
+
return x + self.proj(x)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class DownSampleBlock(nn.Module):
|
| 51 |
+
"""Downsample 2D feature maps by rearranging into 2x2 blocks."""
|
| 52 |
+
def forward(self, x):
|
| 53 |
+
vit_embeds = x
|
| 54 |
+
h = w = int(vit_embeds.shape[1] ** 0.5)
|
| 55 |
+
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
| 56 |
+
vit_embeds = self.flat_square(vit_embeds)
|
| 57 |
+
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
|
| 58 |
+
return vit_embeds
|
| 59 |
+
|
| 60 |
+
def flat_square(self, x):
|
| 61 |
+
n, w, h, c = x.size()
|
| 62 |
+
if w % 2 == 1:
|
| 63 |
+
x = torch.concat([x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
|
| 64 |
+
n, w, h, c = x.size()
|
| 65 |
+
if h % 2 == 1:
|
| 66 |
+
x = torch.concat([x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
|
| 67 |
+
n, w, h, c = x.size()
|
| 68 |
+
x = x.contiguous()
|
| 69 |
+
x = x.view(n, w, int(h / 2), int(c * 2))
|
| 70 |
+
x = x.permute(0, 2, 1, 3).contiguous()
|
| 71 |
+
x = x.view(n, int(h / 2), int(w / 2), int(c * 4))
|
| 72 |
+
x = x.permute(0, 2, 1, 3).contiguous()
|
| 73 |
+
return x
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class DownSample2x2BlockFix(nn.Module):
|
| 77 |
+
"""Downsample 2D feature maps by rearranging into 2x2 blocks (fixed version)."""
|
| 78 |
+
|
| 79 |
+
def forward(self, x):
|
| 80 |
+
vit_embeds = x
|
| 81 |
+
h = w = int(vit_embeds.shape[1] ** 0.5)
|
| 82 |
+
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
| 83 |
+
vit_embeds = flat_square_2x2(vit_embeds)
|
| 84 |
+
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
|
| 85 |
+
return vit_embeds
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def flat_square_2x2(x):
|
| 89 |
+
"""Rearrange feature map into 2x2 blocks."""
|
| 90 |
+
n, w, h, c = x.size()
|
| 91 |
+
if w % 2 == 1:
|
| 92 |
+
x = torch.concat([x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
|
| 93 |
+
n, w, h, c = x.size()
|
| 94 |
+
x = x.contiguous()
|
| 95 |
+
if h % 2 == 1:
|
| 96 |
+
x = torch.concat([x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
|
| 97 |
+
n, w, h, c = x.size()
|
| 98 |
+
x = x.view(n, w, int(h / 2), int(c * 2))
|
| 99 |
+
x = x.permute(0, 2, 1, 3).contiguous()
|
| 100 |
+
x = x.view(n, int(h / 2), int(w / 2), int(c * 4))
|
| 101 |
+
x = x.permute(0, 2, 1, 3).contiguous()
|
| 102 |
+
return x
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class DownSample3x3BlockFix(nn.Module):
|
| 106 |
+
"""Downsample 2D feature maps by rearranging into 3x3 blocks (fixed version)."""
|
| 107 |
+
|
| 108 |
+
def forward(self, x):
|
| 109 |
+
vit_embeds = x
|
| 110 |
+
h = w = int(vit_embeds.shape[1] ** 0.5)
|
| 111 |
+
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
| 112 |
+
vit_embeds = flat_square_3x3(vit_embeds)
|
| 113 |
+
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
|
| 114 |
+
return vit_embeds
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def flat_square_3x3(x):
|
| 118 |
+
"""Rearrange feature map into 3x3 blocks."""
|
| 119 |
+
n, w, h, c = x.size()
|
| 120 |
+
if w % 3 != 0:
|
| 121 |
+
x = torch.concat([x, torch.zeros((n, 3 - (w % 3), h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
|
| 122 |
+
n, w, h, c = x.size()
|
| 123 |
+
x = x.contiguous()
|
| 124 |
+
if h % 3 != 0:
|
| 125 |
+
x = torch.concat([x, torch.zeros((n, w, 3 - (h % 3), c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
|
| 126 |
+
n, w, h, c = x.size()
|
| 127 |
+
x = x.view(n, w, int(h / 3), int(c * 3))
|
| 128 |
+
x = x.permute(0, 2, 1, 3).contiguous()
|
| 129 |
+
x = x.view(n, int(h / 3), int(w / 3), int(c * 9))
|
| 130 |
+
x = x.permute(0, 2, 1, 3).contiguous()
|
| 131 |
+
return x
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class MultimodalProjectorConfig(PretrainedConfig):
|
| 135 |
+
"""Configuration for vision-to-language projector."""
|
| 136 |
+
|
| 137 |
+
model_type = "v2l_projector"
|
| 138 |
+
|
| 139 |
+
def __init__(self, mm_projector_type: str = None, **kwargs):
|
| 140 |
+
super().__init__()
|
| 141 |
+
self.mm_projector_type = mm_projector_type
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class MultimodalProjector(PreTrainedModel):
|
| 145 |
+
"""Multimodal projector for mapping vision features to LLM space."""
|
| 146 |
+
config_class = MultimodalProjectorConfig
|
| 147 |
+
|
| 148 |
+
def __init__(self, mm_projector_cfg: MultimodalProjectorConfig, config: PretrainedConfig):
|
| 149 |
+
super().__init__(mm_projector_cfg)
|
| 150 |
+
mm_projector_type = mm_projector_cfg.mm_projector_type
|
| 151 |
+
self.downsample_rate = 1
|
| 152 |
+
if mm_projector_type == "identity":
|
| 153 |
+
self.layers = IdentityMap()
|
| 154 |
+
elif mm_projector_type == "linear":
|
| 155 |
+
self.layers = nn.Linear(config.mm_hidden_size, config.hidden_size)
|
| 156 |
+
elif mm_projector_type == "mlp_downsample":
|
| 157 |
+
self.layers = nn.Sequential(
|
| 158 |
+
DownSampleBlock(),
|
| 159 |
+
nn.LayerNorm(config.mm_hidden_size * 4),
|
| 160 |
+
nn.Linear(config.mm_hidden_size * 4, config.hidden_size),
|
| 161 |
+
nn.GELU(),
|
| 162 |
+
nn.Linear(config.hidden_size, config.hidden_size),
|
| 163 |
+
)
|
| 164 |
+
self.downsample_rate = 2
|
| 165 |
+
elif mm_projector_type == "mlp_downsample_2x2_fix":
|
| 166 |
+
self.layers = nn.Sequential(
|
| 167 |
+
DownSample2x2BlockFix(),
|
| 168 |
+
nn.LayerNorm(config.mm_hidden_size * 4),
|
| 169 |
+
nn.Linear(config.mm_hidden_size * 4, config.hidden_size),
|
| 170 |
+
nn.GELU(),
|
| 171 |
+
nn.Linear(config.hidden_size, config.hidden_size),
|
| 172 |
+
)
|
| 173 |
+
self.downsample_rate = 2
|
| 174 |
+
elif mm_projector_type == "mlp_downsample_3x3_fix":
|
| 175 |
+
self.layers = nn.Sequential(
|
| 176 |
+
DownSample3x3BlockFix(),
|
| 177 |
+
nn.LayerNorm(config.mm_hidden_size * 9),
|
| 178 |
+
nn.Linear(config.mm_hidden_size * 9, config.mm_hidden_size * 3),
|
| 179 |
+
nn.GELU(),
|
| 180 |
+
nn.LayerNorm(config.mm_hidden_size * 3),
|
| 181 |
+
nn.Linear(config.mm_hidden_size * 3, config.hidden_size),
|
| 182 |
+
nn.GELU(),
|
| 183 |
+
nn.Linear(config.hidden_size, config.hidden_size),
|
| 184 |
+
)
|
| 185 |
+
self.downsample_rate = 3
|
| 186 |
+
elif mm_projector_type == "mlp_downsample_3x3_s2":
|
| 187 |
+
self.layers = nn.Sequential(
|
| 188 |
+
DownSample3x3BlockFix(),
|
| 189 |
+
nn.LayerNorm(config.mm_hidden_size * 9),
|
| 190 |
+
nn.Linear(config.mm_hidden_size * 9, config.mm_hidden_size * 3),
|
| 191 |
+
nn.GELU(),
|
| 192 |
+
nn.LayerNorm(config.mm_hidden_size * 3),
|
| 193 |
+
nn.Linear(config.mm_hidden_size * 3, config.mm_hidden_size),
|
| 194 |
+
nn.GELU(),
|
| 195 |
+
nn.LayerNorm(config.mm_hidden_size),
|
| 196 |
+
nn.Linear(config.mm_hidden_size, config.mm_hidden_size // 3),
|
| 197 |
+
nn.GELU(),
|
| 198 |
+
nn.LayerNorm(config.mm_hidden_size // 3),
|
| 199 |
+
nn.Linear(config.mm_hidden_size // 3, config.hidden_size),
|
| 200 |
+
nn.GELU(),
|
| 201 |
+
nn.Linear(config.hidden_size, config.hidden_size),
|
| 202 |
+
)
|
| 203 |
+
elif mm_projector_type == "mlp_downsample_3x3_s2_new":
|
| 204 |
+
self.layers = nn.Sequential(
|
| 205 |
+
DownSample3x3BlockFix(),
|
| 206 |
+
nn.LayerNorm(config.mm_hidden_size * 9),
|
| 207 |
+
nn.Linear(config.mm_hidden_size * 9, config.mm_hidden_size * 4),
|
| 208 |
+
nn.GELU(),
|
| 209 |
+
nn.LayerNorm(config.mm_hidden_size * 4),
|
| 210 |
+
nn.Linear(config.mm_hidden_size * 4, config.mm_hidden_size * 2),
|
| 211 |
+
nn.GELU(),
|
| 212 |
+
nn.LayerNorm(config.mm_hidden_size * 2),
|
| 213 |
+
nn.Linear(config.mm_hidden_size * 2, config.mm_hidden_size),
|
| 214 |
+
nn.GELU(),
|
| 215 |
+
nn.LayerNorm(config.mm_hidden_size),
|
| 216 |
+
nn.Linear(config.mm_hidden_size, config.mm_hidden_size // 3),
|
| 217 |
+
nn.GELU(),
|
| 218 |
+
nn.LayerNorm(config.mm_hidden_size // 3),
|
| 219 |
+
nn.Linear(config.mm_hidden_size // 3, config.hidden_size),
|
| 220 |
+
nn.GELU(),
|
| 221 |
+
nn.Linear(config.hidden_size, config.hidden_size),
|
| 222 |
+
)
|
| 223 |
+
else:
|
| 224 |
+
mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", mm_projector_type)
|
| 225 |
+
if mlp_gelu_match:
|
| 226 |
+
mlp_depth = int(mlp_gelu_match.group(1))
|
| 227 |
+
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
|
| 228 |
+
for _ in range(1, mlp_depth):
|
| 229 |
+
modules.append(nn.GELU())
|
| 230 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
| 231 |
+
self.layers = nn.Sequential(*modules)
|
| 232 |
+
else:
|
| 233 |
+
raise ValueError(f"Unknown projector type: {mm_projector_type}")
|
| 234 |
+
|
| 235 |
+
def forward(self, x, *args, **kwargs):
|
| 236 |
+
return self.layers(x)
|
builder.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
import os
|
| 19 |
+
import os.path as osp
|
| 20 |
+
import warnings
|
| 21 |
+
from dataclasses import asdict
|
| 22 |
+
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
import transformers
|
| 26 |
+
from huggingface_hub import file_exists, repo_exists
|
| 27 |
+
from huggingface_hub.utils import HFValidationError
|
| 28 |
+
from transformers import (
|
| 29 |
+
AutoConfig,
|
| 30 |
+
AutoModelForCausalLM,
|
| 31 |
+
AutoTokenizer,
|
| 32 |
+
PretrainedConfig,
|
| 33 |
+
PreTrainedModel,
|
| 34 |
+
PreTrainedTokenizer,
|
| 35 |
+
)
|
| 36 |
+
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
| 37 |
+
|
| 38 |
+
from .constants import MEDIA_TOKENS, SENTINEL_TOKEN
|
| 39 |
+
from .conversation import SeparatorStyle, default_conversation
|
| 40 |
+
|
| 41 |
+
DUMMY_CONVERSATION = [
|
| 42 |
+
{"from": "human", "value": "question"},
|
| 43 |
+
{"from": "gpt", "value": "answer"},
|
| 44 |
+
] * 10
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def tokenizer_image_token(prompt, tokenizer, return_tensors=None):
|
| 48 |
+
"""Tokenize a prompt and return input IDs."""
|
| 49 |
+
return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def has_tokenizer(repo_id_or_path: str) -> bool:
|
| 53 |
+
"""Check if a tokenizer exists at the given path or repository."""
|
| 54 |
+
# Check if the tokenizer is in a local directory
|
| 55 |
+
if osp.exists(osp.join(repo_id_or_path, "tokenizer_config.json")):
|
| 56 |
+
return True
|
| 57 |
+
|
| 58 |
+
# Check if the tokenizer is in a Hugging Face Hub repo
|
| 59 |
+
try:
|
| 60 |
+
return repo_exists(repo_id_or_path) and file_exists(repo_id_or_path, "tokenizer_config.json")
|
| 61 |
+
except HFValidationError:
|
| 62 |
+
return False
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
|
| 66 |
+
"""Add sentinel token to tokenizer if not already present."""
|
| 67 |
+
if not hasattr(tokenizer, "sentinel_token"):
|
| 68 |
+
tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
|
| 69 |
+
tokenizer.sentinel_token = SENTINEL_TOKEN
|
| 70 |
+
tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def tokenize_conversation_legacy(
|
| 74 |
+
messages: Sequence[Dict[str, str]],
|
| 75 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 76 |
+
add_generation_prompt: bool = False,
|
| 77 |
+
overrides: Optional[Dict[str, str]] = None,
|
| 78 |
+
no_system_prompt: bool = False,
|
| 79 |
+
) -> torch.Tensor:
|
| 80 |
+
"""Tokenize conversation using legacy format."""
|
| 81 |
+
conv = default_conversation.copy()
|
| 82 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
| 83 |
+
|
| 84 |
+
if no_system_prompt:
|
| 85 |
+
conv.system = ""
|
| 86 |
+
|
| 87 |
+
# Skip the first message if it is not from human
|
| 88 |
+
if messages[0]["from"] != "human":
|
| 89 |
+
messages = messages[1:]
|
| 90 |
+
|
| 91 |
+
# Add a generation prompt if needed
|
| 92 |
+
if add_generation_prompt:
|
| 93 |
+
messages.append({"from": "gpt", "value": None})
|
| 94 |
+
|
| 95 |
+
conv.messages = []
|
| 96 |
+
for turn, message in enumerate(messages):
|
| 97 |
+
role = roles[message["from"]]
|
| 98 |
+
assert role == conv.roles[turn % 2]
|
| 99 |
+
if overrides is not None and message["from"] in overrides:
|
| 100 |
+
conv.append_message(role, overrides[message["from"]])
|
| 101 |
+
else:
|
| 102 |
+
conv.append_message(role, message["value"])
|
| 103 |
+
|
| 104 |
+
return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def tokenize_conversation(
|
| 108 |
+
messages: Sequence[Dict[str, str]],
|
| 109 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 110 |
+
add_generation_prompt: bool = False,
|
| 111 |
+
overrides: Optional[Dict[str, str]] = None,
|
| 112 |
+
no_system_prompt: bool = False,
|
| 113 |
+
) -> torch.Tensor:
|
| 114 |
+
"""Tokenize conversation using modern chat template format."""
|
| 115 |
+
# Normalize the conversation before tokenization
|
| 116 |
+
for message in messages:
|
| 117 |
+
message["value"] = message["value"].strip()
|
| 118 |
+
|
| 119 |
+
if default_conversation.sep_style != SeparatorStyle.AUTO:
|
| 120 |
+
return tokenize_conversation_legacy(
|
| 121 |
+
messages,
|
| 122 |
+
tokenizer,
|
| 123 |
+
add_generation_prompt=add_generation_prompt,
|
| 124 |
+
overrides=overrides,
|
| 125 |
+
no_system_prompt=no_system_prompt,
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
conversation = []
|
| 129 |
+
for m in messages:
|
| 130 |
+
message = {}
|
| 131 |
+
if m["from"] == "human":
|
| 132 |
+
message["role"] = "user"
|
| 133 |
+
elif m["from"] == "gpt":
|
| 134 |
+
message["role"] = "assistant"
|
| 135 |
+
else:
|
| 136 |
+
raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")
|
| 137 |
+
|
| 138 |
+
message["content"] = m["value"]
|
| 139 |
+
if overrides is not None and m["from"] in overrides:
|
| 140 |
+
message["content"] = overrides[m["from"]]
|
| 141 |
+
conversation.append(message)
|
| 142 |
+
|
| 143 |
+
if no_system_prompt:
|
| 144 |
+
conversation = [{"role": "system", "content": ""}] + conversation
|
| 145 |
+
|
| 146 |
+
text = tokenizer.apply_chat_template(
|
| 147 |
+
conversation,
|
| 148 |
+
add_generation_prompt=add_generation_prompt,
|
| 149 |
+
tokenize=False,
|
| 150 |
+
)
|
| 151 |
+
return tokenizer_image_token(text, tokenizer, return_tensors="pt")
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
|
| 155 |
+
"""Infer stop tokens from tokenizer by analyzing dummy conversation."""
|
| 156 |
+
_maybe_add_sentinel_token(tokenizer)
|
| 157 |
+
template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})
|
| 158 |
+
|
| 159 |
+
stop_tokens = {tokenizer.eos_token}
|
| 160 |
+
for k in range(template.size(0) - 1):
|
| 161 |
+
if template[k] == tokenizer.sentinel_token_id:
|
| 162 |
+
stop_token = tokenizer.decode(template[k + 1])
|
| 163 |
+
stop_tokens.add(stop_token)
|
| 164 |
+
return list(stop_tokens)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def context_length_extension(config):
|
| 168 |
+
"""Extend context length using RoPE scaling if needed."""
|
| 169 |
+
orig_ctx_len = getattr(config, "max_position_embeddings", None)
|
| 170 |
+
model_max_length = getattr(config, "model_max_length", None)
|
| 171 |
+
if orig_ctx_len and model_max_length > orig_ctx_len:
|
| 172 |
+
print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
|
| 173 |
+
scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
|
| 174 |
+
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
|
| 175 |
+
return config
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def build_llm_and_tokenizer(
|
| 179 |
+
model_name_or_path: str,
|
| 180 |
+
config: PretrainedConfig,
|
| 181 |
+
attn_implementation=None,
|
| 182 |
+
model_max_length=None,
|
| 183 |
+
*args,
|
| 184 |
+
**kwargs,
|
| 185 |
+
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
| 186 |
+
"""Build language model and tokenizer from pretrained checkpoint."""
|
| 187 |
+
llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
|
| 188 |
+
llm_cfg._attn_implementation = attn_implementation
|
| 189 |
+
llm_cfg.model_max_length = model_max_length
|
| 190 |
+
if model_max_length is not None:
|
| 191 |
+
context_length_extension(llm_cfg)
|
| 192 |
+
|
| 193 |
+
# Quantization related
|
| 194 |
+
quantization_restore_from_checkpoint = False
|
| 195 |
+
|
| 196 |
+
if type(config.model_dtype) == str:
|
| 197 |
+
model_dtype = eval(config.model_dtype)
|
| 198 |
+
else:
|
| 199 |
+
model_dtype = config.model_dtype
|
| 200 |
+
|
| 201 |
+
if quantization_restore_from_checkpoint:
|
| 202 |
+
fp8_model_name_or_path = kwargs.pop("fp8_llm_cfg", None)
|
| 203 |
+
|
| 204 |
+
llm = AutoModelForCausalLM.from_pretrained(
|
| 205 |
+
fp8_model_name_or_path, config=llm_cfg, torch_dtype=model_dtype, *args, **kwargs
|
| 206 |
+
)
|
| 207 |
+
else:
|
| 208 |
+
if is_deepspeed_zero3_enabled():
|
| 209 |
+
kwargs.pop("device_map")
|
| 210 |
+
llm = AutoModelForCausalLM.from_pretrained(
|
| 211 |
+
model_name_or_path, config=llm_cfg, torch_dtype=model_dtype, *args, **kwargs
|
| 212 |
+
)
|
| 213 |
+
print(f"Loaded model from {model_name_or_path} with dtype {model_dtype}")
|
| 214 |
+
|
| 215 |
+
# Locate the tokenizer.
|
| 216 |
+
llm_path = model_name_or_path
|
| 217 |
+
if not has_tokenizer(llm_path):
|
| 218 |
+
llm_path = osp.join(llm_path, "llm")
|
| 219 |
+
if not has_tokenizer(llm_path):
|
| 220 |
+
raise ValueError(f"Cannot find tokenizer in {llm_path}.")
|
| 221 |
+
|
| 222 |
+
tokenizer = AutoTokenizer.from_pretrained(llm_path, padding_side="right", use_fast=True, legacy=False)
|
| 223 |
+
if model_max_length is not None:
|
| 224 |
+
tokenizer.model_max_length = model_max_length
|
| 225 |
+
|
| 226 |
+
# Load chat template if specified.
|
| 227 |
+
if getattr(config, "chat_template", None) is not None:
|
| 228 |
+
print(f"Using chat template: {config.chat_template}")
|
| 229 |
+
fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja")
|
| 230 |
+
if not os.path.exists(fpath):
|
| 231 |
+
fpath = os.path.join(os.path.dirname(model_name_or_path), f"{config.chat_template}.jinja")
|
| 232 |
+
with open(fpath) as fd:
|
| 233 |
+
chat_template = fd.read()
|
| 234 |
+
tokenizer.chat_template = chat_template.replace(" ", "").replace("\n", "")
|
| 235 |
+
|
| 236 |
+
# Set stop tokens for the tokenizer
|
| 237 |
+
tokenizer.stop_tokens = infer_stop_tokens(tokenizer)
|
| 238 |
+
tokenizer.stop_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.stop_tokens)
|
| 239 |
+
|
| 240 |
+
# Add media tokens to the tokenizer
|
| 241 |
+
tokenizer.media_tokens = MEDIA_TOKENS
|
| 242 |
+
tokenizer.media_token_ids = {}
|
| 243 |
+
for name, token in MEDIA_TOKENS.items():
|
| 244 |
+
if config.speech_tower_cfg is None and name == "speech":
|
| 245 |
+
continue
|
| 246 |
+
if config.sound_tower_cfg is None and name == "sound":
|
| 247 |
+
continue
|
| 248 |
+
tokenizer.add_tokens([token], special_tokens=True)
|
| 249 |
+
tokenizer.media_token_ids[name] = tokenizer.convert_tokens_to_ids(token)
|
| 250 |
+
tokenizer.media_tokens[name] = token
|
| 251 |
+
|
| 252 |
+
config.hidden_size = llm.config.hidden_size
|
| 253 |
+
return llm, tokenizer
|
config.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
configuration_vila.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
import math
|
| 19 |
+
import os
|
| 20 |
+
import os.path as osp
|
| 21 |
+
from copy import deepcopy
|
| 22 |
+
from threading import Thread
|
| 23 |
+
from typing import List, Optional
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
import torchvision
|
| 27 |
+
from PIL import Image
|
| 28 |
+
from transformers import (
|
| 29 |
+
AutoProcessor,
|
| 30 |
+
PretrainedConfig,
|
| 31 |
+
PreTrainedModel,
|
| 32 |
+
Qwen2Config,
|
| 33 |
+
Qwen2ForCausalLM,
|
| 34 |
+
Qwen2PreTrainedModel,
|
| 35 |
+
TextIteratorStreamer,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class VILAConfig(PretrainedConfig):
|
| 40 |
+
model_type = "vila"
|
| 41 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 42 |
+
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
llm_cfg=None,
|
| 46 |
+
vision_tower_cfg=None,
|
| 47 |
+
mm_projector_cfg=None,
|
| 48 |
+
speech_tower_cfg=None,
|
| 49 |
+
sound_tower_cfg=None,
|
| 50 |
+
speech_mm_projector_cfg=None,
|
| 51 |
+
sound_mm_projector_cfg=None,
|
| 52 |
+
architectures=None,
|
| 53 |
+
resume_path=None,
|
| 54 |
+
hidden_size=None,
|
| 55 |
+
mm_hidden_size=None,
|
| 56 |
+
image_aspect_ratio=None,
|
| 57 |
+
num_video_frames=None,
|
| 58 |
+
fps=None,
|
| 59 |
+
mm_vision_select_layer=None,
|
| 60 |
+
mm_vision_select_feature=None,
|
| 61 |
+
mm_use_im_start_end=False,
|
| 62 |
+
mm_use_im_patch_token=False,
|
| 63 |
+
mm_projector_lr=None,
|
| 64 |
+
vision_tower_lr=None,
|
| 65 |
+
vision_resolution=None,
|
| 66 |
+
interpolate_mode=None,
|
| 67 |
+
s2=None,
|
| 68 |
+
dynamic_s2=None,
|
| 69 |
+
s2_scales=None,
|
| 70 |
+
s2_max_split_size=None,
|
| 71 |
+
s2_resize_output_to_scale_idx=0,
|
| 72 |
+
min_tiles: Optional[int] = 1,
|
| 73 |
+
max_tiles: Optional[int] = 12,
|
| 74 |
+
num_time_tokens=None,
|
| 75 |
+
time_token_format=None,
|
| 76 |
+
image_encoder: str = '{"_target_": "llava.model.encoders.BasicImageEncoder"}',
|
| 77 |
+
video_encoder: str = '{"_target_": "llava.model.encoders.TSPVideoEncoder"}',
|
| 78 |
+
sound_encoder: str = '{"_target_": "llava.model.encoders.BasicSoundEncoder"}',
|
| 79 |
+
speech_encoder: str = '{"_target_": "llava.model.encoders.BasicSpeechEncoder"}',
|
| 80 |
+
**kwargs,
|
| 81 |
+
):
|
| 82 |
+
super().__init__(**kwargs)
|
| 83 |
+
|
| 84 |
+
self.architectures = architectures
|
| 85 |
+
self.llm_cfg = llm_cfg
|
| 86 |
+
self.vision_tower_cfg = vision_tower_cfg
|
| 87 |
+
self.mm_projector_cfg = mm_projector_cfg
|
| 88 |
+
self.speech_tower_cfg = speech_tower_cfg
|
| 89 |
+
self.sound_tower_cfg = sound_tower_cfg
|
| 90 |
+
self.speech_mm_projector_cfg = speech_mm_projector_cfg
|
| 91 |
+
self.sound_mm_projector_cfg = sound_mm_projector_cfg
|
| 92 |
+
self.resume_path = resume_path
|
| 93 |
+
|
| 94 |
+
self.hidden_size = hidden_size
|
| 95 |
+
self.mm_hidden_size = mm_hidden_size
|
| 96 |
+
self.image_aspect_ratio = image_aspect_ratio
|
| 97 |
+
self.num_video_frames = num_video_frames
|
| 98 |
+
self.fps = fps
|
| 99 |
+
self.mm_vision_select_layer = mm_vision_select_layer
|
| 100 |
+
self.mm_vision_select_feature = mm_vision_select_feature
|
| 101 |
+
self.mm_use_im_start_end = mm_use_im_start_end
|
| 102 |
+
self.mm_use_im_patch_token = mm_use_im_patch_token
|
| 103 |
+
self.mm_projector_lr = mm_projector_lr
|
| 104 |
+
self.vision_tower_lr = vision_tower_lr
|
| 105 |
+
self.vision_resolution = vision_resolution
|
| 106 |
+
self.interpolate_mode = interpolate_mode
|
| 107 |
+
self.s2 = s2
|
| 108 |
+
self.dynamic_s2 = dynamic_s2
|
| 109 |
+
self.s2_scales = s2_scales
|
| 110 |
+
self.s2_max_split_size = s2_max_split_size
|
| 111 |
+
self.s2_resize_output_to_scale_idx = s2_resize_output_to_scale_idx
|
| 112 |
+
self.min_tiles = min_tiles
|
| 113 |
+
self.max_tiles = max_tiles
|
| 114 |
+
self.num_time_tokens = num_time_tokens
|
| 115 |
+
self.time_token_format = time_token_format
|
| 116 |
+
|
| 117 |
+
self.image_encoder = image_encoder
|
| 118 |
+
self.video_encoder = video_encoder
|
| 119 |
+
self.sound_encoder = sound_encoder
|
| 120 |
+
self.speech_encoder = speech_encoder
|
| 121 |
+
self.audio_sampling_rate = 16000
|
| 122 |
+
self.audio_chunk_length = 120
|
| 123 |
+
self.interleaved_vis_aud_in_video = True
|
| 124 |
+
self.interleaved_video_segment_duration = 30
|
| 125 |
+
self.audio_hop_length = 60
|
| 126 |
+
|
| 127 |
+
super().__init__(**kwargs)
|
constants.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
# This file is modified from https://github.com/haotian-liu/LLaVA/
|
| 18 |
+
|
| 19 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
| 20 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
| 21 |
+
|
| 22 |
+
LOGDIR = "."
|
| 23 |
+
|
| 24 |
+
# Model Constants
|
| 25 |
+
IGNORE_INDEX = -100
|
| 26 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
| 27 |
+
DEFAULT_SOUND_TOKEN = "<sound>"
|
| 28 |
+
DEFAULT_SPEECH_TOKEN = "<speech>"
|
| 29 |
+
SENTINEL_TOKEN = "<vila/sentinel>"
|
| 30 |
+
DEFAULT_IM_START_TOKEN = "<im_start>"
|
| 31 |
+
DEFAULT_IM_END_TOKEN = "<im_end>"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
SENTINEL_TOKEN = "<vila/sentinel>"
|
| 35 |
+
|
| 36 |
+
MEDIA_TOKENS = {
|
| 37 |
+
"image": "<image>",
|
| 38 |
+
"video": "<vila/video>",
|
| 39 |
+
"speech": "<speech>",
|
| 40 |
+
"sound": "<sound>",
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
# Token IDs for different model variants:
|
| 44 |
+
"""
|
| 45 |
+
vila:
|
| 46 |
+
151643: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 47 |
+
151644: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 48 |
+
151645: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 49 |
+
151646: AddedToken("[BOS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 50 |
+
151647: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 51 |
+
151648: AddedToken("<vila/sentinel>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 52 |
+
151649: AddedToken("<image>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 53 |
+
151650: AddedToken("<vila/video>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 54 |
+
|
| 55 |
+
xvila:
|
| 56 |
+
151643: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 57 |
+
151644: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 58 |
+
151645: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 59 |
+
151646: AddedToken("[BOS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 60 |
+
151647: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 61 |
+
151648: AddedToken("<vila/sentinel>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 62 |
+
151649: AddedToken("<image>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 63 |
+
151650: AddedToken("<vila/video>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 64 |
+
151651: AddedToken("<speech>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 65 |
+
151652: AddedToken("<sound>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 66 |
+
151653: AddedToken("<|image_bos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 67 |
+
151654: AddedToken("<|image_eos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 68 |
+
151655: AddedToken("<|video_bos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 69 |
+
151656: AddedToken("<|video_eos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 70 |
+
151657: AddedToken("<|speech_bos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 71 |
+
151658: AddedToken("<|speech_eos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 72 |
+
151659: AddedToken("<|sound_bos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 73 |
+
151660: AddedToken("<|sound_eos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 74 |
+
"""
|
| 75 |
+
MM_BOS_EOS_TOKENS = {
|
| 76 |
+
"image": ["<|image_bos|>", "<|image_eos|>"],
|
| 77 |
+
"video": ["<|video_bos|>", "<|video_eos|>"],
|
| 78 |
+
"speech": ["<|speech_bos|>", "<|speech_eos|>"],
|
| 79 |
+
"sound": ["<|sound_bos|>", "<|sound_eos|>"],
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
NUM_EXTRA_TOKENS_VILA = 8
|
| 83 |
+
NUM_EXTRA_TOKENS_XVILA = 10
|
conversation.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
# This file is modified from https://github.com/haotian-liu/LLaVA/
|
| 17 |
+
|
| 18 |
+
import dataclasses
|
| 19 |
+
from enum import Enum, auto
|
| 20 |
+
from typing import List
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class SeparatorStyle(Enum):
|
| 24 |
+
"""Different separator style."""
|
| 25 |
+
|
| 26 |
+
AUTO = auto()
|
| 27 |
+
TWO = auto()
|
| 28 |
+
MPT = auto()
|
| 29 |
+
PLAIN = auto()
|
| 30 |
+
LLAMA_3 = auto()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclasses.dataclass
|
| 34 |
+
class Conversation:
|
| 35 |
+
"""A class that keeps all conversation history."""
|
| 36 |
+
|
| 37 |
+
system: str
|
| 38 |
+
roles: List[str]
|
| 39 |
+
messages: List[List[str]]
|
| 40 |
+
sep_style: SeparatorStyle = SeparatorStyle.AUTO
|
| 41 |
+
sep: str = "###"
|
| 42 |
+
sep2: str = None
|
| 43 |
+
version: str = "Unknown"
|
| 44 |
+
|
| 45 |
+
def get_prompt(self):
|
| 46 |
+
messages = self.messages
|
| 47 |
+
if len(messages) > 0 and type(messages[0][1]) is tuple:
|
| 48 |
+
messages = self.messages.copy()
|
| 49 |
+
init_role, init_msg = messages[0].copy()
|
| 50 |
+
init_msg = init_msg[0].replace("<image>", "").strip()
|
| 51 |
+
messages[0] = (init_role, "<image>\n" + init_msg)
|
| 52 |
+
|
| 53 |
+
if self.sep_style == SeparatorStyle.TWO:
|
| 54 |
+
seps = [self.sep, self.sep2]
|
| 55 |
+
ret = self.system + seps[0]
|
| 56 |
+
for i, (role, message) in enumerate(messages):
|
| 57 |
+
if message:
|
| 58 |
+
if type(message) is tuple:
|
| 59 |
+
message, _, _ = message
|
| 60 |
+
ret += role + ": " + message + seps[i % 2]
|
| 61 |
+
else:
|
| 62 |
+
ret += role + ":"
|
| 63 |
+
elif self.sep_style == SeparatorStyle.LLAMA_3:
|
| 64 |
+
ret = self.system + self.sep
|
| 65 |
+
for rid, (role, message) in enumerate(messages):
|
| 66 |
+
if message:
|
| 67 |
+
if type(message) is tuple:
|
| 68 |
+
message = message[0]
|
| 69 |
+
sep = self.sep if rid < len(messages) - 1 else self.sep2
|
| 70 |
+
ret += role + message + sep
|
| 71 |
+
else:
|
| 72 |
+
ret += role
|
| 73 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
| 74 |
+
ret = self.system + self.sep
|
| 75 |
+
for role, message in messages:
|
| 76 |
+
if message:
|
| 77 |
+
if type(message) is tuple:
|
| 78 |
+
message, _, _ = message
|
| 79 |
+
ret += role + message + self.sep
|
| 80 |
+
else:
|
| 81 |
+
ret += role
|
| 82 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
| 83 |
+
seps = [self.sep, self.sep2]
|
| 84 |
+
ret = self.system
|
| 85 |
+
for i, (role, message) in enumerate(messages):
|
| 86 |
+
if message:
|
| 87 |
+
if type(message) is tuple:
|
| 88 |
+
message, _, _ = message
|
| 89 |
+
ret += message + seps[i % 2]
|
| 90 |
+
else:
|
| 91 |
+
ret += ""
|
| 92 |
+
else:
|
| 93 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
| 94 |
+
|
| 95 |
+
return ret
|
| 96 |
+
|
| 97 |
+
def append_message(self, role, message):
|
| 98 |
+
self.messages.append([role, message])
|
| 99 |
+
|
| 100 |
+
def copy(self):
|
| 101 |
+
return Conversation(
|
| 102 |
+
system=self.system,
|
| 103 |
+
roles=self.roles,
|
| 104 |
+
messages=[[x, y] for x, y in self.messages],
|
| 105 |
+
sep_style=self.sep_style,
|
| 106 |
+
sep=self.sep,
|
| 107 |
+
sep2=self.sep2,
|
| 108 |
+
version=self.version,
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
conv_auto = Conversation(
|
| 113 |
+
system="",
|
| 114 |
+
roles=("", ""),
|
| 115 |
+
messages=(),
|
| 116 |
+
sep_style=SeparatorStyle.AUTO,
|
| 117 |
+
sep="\n",
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
conv_vicuna_v1 = Conversation(
|
| 121 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
| 122 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
| 123 |
+
roles=("USER", "ASSISTANT"),
|
| 124 |
+
version="v1",
|
| 125 |
+
messages=(),
|
| 126 |
+
sep_style=SeparatorStyle.TWO,
|
| 127 |
+
sep=" ",
|
| 128 |
+
sep2="</s>",
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
conv_llava_plain = Conversation(
|
| 132 |
+
system="",
|
| 133 |
+
roles=("", ""),
|
| 134 |
+
messages=(),
|
| 135 |
+
sep_style=SeparatorStyle.PLAIN,
|
| 136 |
+
sep="\n",
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
hermes_2 = Conversation(
|
| 140 |
+
system="<|im_start|>system\nAnswer the questions.",
|
| 141 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
| 142 |
+
sep_style=SeparatorStyle.MPT,
|
| 143 |
+
sep="<|im_end|>",
|
| 144 |
+
messages=(),
|
| 145 |
+
version="hermes-2",
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
llama_3_chat = Conversation(
|
| 149 |
+
system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. "
|
| 150 |
+
"You are able to understand the visual content that the user provides, "
|
| 151 |
+
"and assist the user with a variety of tasks using natural language.",
|
| 152 |
+
roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
|
| 153 |
+
version="llama_v3",
|
| 154 |
+
messages=(),
|
| 155 |
+
sep_style=SeparatorStyle.LLAMA_3,
|
| 156 |
+
sep="<|eot_id|>",
|
| 157 |
+
sep2="<|end_of_text|>",
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
default_conversation = conv_auto
|
| 162 |
+
conv_templates = {
|
| 163 |
+
"auto": conv_auto,
|
| 164 |
+
"hermes-2": hermes_2,
|
| 165 |
+
"llama_3": llama_3_chat,
|
| 166 |
+
"v1": conv_vicuna_v1,
|
| 167 |
+
"vicuna_v1": conv_vicuna_v1,
|
| 168 |
+
"plain": conv_llava_plain,
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
CONVERSATION_MODE_MAPPING = {
|
| 173 |
+
"vila1.5-3b": "vicuna_v1",
|
| 174 |
+
"vila1.5-8b": "llama_3",
|
| 175 |
+
"vila1.5-13b": "vicuna_v1",
|
| 176 |
+
"vila1.5-40b": "hermes-2",
|
| 177 |
+
"llama-3": "llama_3",
|
| 178 |
+
"llama3": "llama_3",
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def auto_set_conversation_mode(model_name_or_path: str) -> str:
|
| 183 |
+
"""Automatically set conversation mode based on model name/path."""
|
| 184 |
+
global default_conversation
|
| 185 |
+
for k, v in CONVERSATION_MODE_MAPPING.items():
|
| 186 |
+
if k in model_name_or_path.lower():
|
| 187 |
+
print(f"Setting conversation mode to `{v}` based on model name/path `{model_name_or_path}`.")
|
| 188 |
+
default_conversation = conv_templates[v]
|
| 189 |
+
return
|
distributed.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import warnings
|
| 19 |
+
from typing import Any, List, Optional
|
| 20 |
+
|
| 21 |
+
from torch import distributed as dist
|
| 22 |
+
|
| 23 |
+
__all__ = [
|
| 24 |
+
"init",
|
| 25 |
+
"is_initialized",
|
| 26 |
+
"size",
|
| 27 |
+
"rank",
|
| 28 |
+
"local_size",
|
| 29 |
+
"local_rank",
|
| 30 |
+
"is_main",
|
| 31 |
+
"barrier",
|
| 32 |
+
"gather",
|
| 33 |
+
"all_gather",
|
| 34 |
+
]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def init() -> None:
|
| 38 |
+
if "RANK" not in os.environ:
|
| 39 |
+
warnings.warn("Environment variable `RANK` is not set. Skipping distributed initialization.")
|
| 40 |
+
return
|
| 41 |
+
dist.init_process_group(backend="nccl", init_method="env://")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def is_initialized() -> bool:
|
| 45 |
+
return dist.is_initialized()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def size() -> int:
|
| 49 |
+
return int(os.environ.get("WORLD_SIZE", 1))
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def rank() -> int:
|
| 53 |
+
return int(os.environ.get("RANK", 0))
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def local_size() -> int:
|
| 57 |
+
return int(os.environ.get("LOCAL_WORLD_SIZE", 1))
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def local_rank() -> int:
|
| 61 |
+
return int(os.environ.get("LOCAL_RANK", 0))
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def is_main() -> bool:
|
| 65 |
+
return rank() == 0
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def barrier() -> None:
|
| 69 |
+
dist.barrier()
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def gather(obj: Any, dst: int = 0) -> Optional[List[Any]]:
|
| 73 |
+
if not is_initialized():
|
| 74 |
+
return [obj]
|
| 75 |
+
if is_main():
|
| 76 |
+
objs = [None for _ in range(size())]
|
| 77 |
+
dist.gather_object(obj, objs, dst=dst)
|
| 78 |
+
return objs
|
| 79 |
+
else:
|
| 80 |
+
dist.gather_object(obj, dst=dst)
|
| 81 |
+
return None
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def all_gather(obj: Any) -> List[Any]:
|
| 85 |
+
if not is_initialized():
|
| 86 |
+
return [obj]
|
| 87 |
+
objs = [None for _ in range(size())]
|
| 88 |
+
dist.all_gather_object(objs, obj)
|
| 89 |
+
return objs
|
environment_setup.sh
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -e
|
| 3 |
+
|
| 4 |
+
CONDA_ENV=${1:-""}
|
| 5 |
+
if [ -n "$CONDA_ENV" ]; then
|
| 6 |
+
# This is required to activate conda environment
|
| 7 |
+
eval "$(conda shell.bash hook)"
|
| 8 |
+
|
| 9 |
+
conda create -n $CONDA_ENV python=3.10.14 -y
|
| 10 |
+
conda activate $CONDA_ENV
|
| 11 |
+
# This is optional if you prefer to use built-in nvcc
|
| 12 |
+
conda install -c nvidia cuda-toolkit=12.2 -y
|
| 13 |
+
else
|
| 14 |
+
echo "Skipping conda environment creation. Make sure you have the correct environment activated."
|
| 15 |
+
fi
|
| 16 |
+
|
| 17 |
+
# Using uv to speedup installations
|
| 18 |
+
pip install uv
|
| 19 |
+
alias uvp="uv pip"
|
| 20 |
+
|
| 21 |
+
echo "[INFO] Using python $(which python)"
|
| 22 |
+
echo "[INFO] Using pip $(which pip)"
|
| 23 |
+
echo "[INFO] Using uv $(which uv)"
|
| 24 |
+
|
| 25 |
+
# This is required to enable PEP 660 support
|
| 26 |
+
uv pip install --upgrade pip setuptools
|
| 27 |
+
|
| 28 |
+
# Install FlashAttention2
|
| 29 |
+
uv pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.8/flash_attn-2.5.8+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
| 30 |
+
|
| 31 |
+
# Install VILA
|
| 32 |
+
uv pip install -e ".[train,eval]"
|
| 33 |
+
|
| 34 |
+
# numpy introduce a lot dependencies issues, separate from pyproject.yaml
|
| 35 |
+
pip install numpy==1.26.4
|
| 36 |
+
|
| 37 |
+
# audio
|
| 38 |
+
uv pip install soundfile librosa openai-whisper ftfy
|
| 39 |
+
conda install -c conda-forge ffmpeg
|
| 40 |
+
uv pip install jiwer
|
| 41 |
+
|
| 42 |
+
# Downgrade protobuf to 3.20 for backward compatibility
|
| 43 |
+
uv pip install protobuf==3.20.*
|
| 44 |
+
|
| 45 |
+
# Replace transformers and deepspeed files
|
| 46 |
+
site_pkg_path=$(python -c 'import site; print(site.getsitepackages()[0])')
|
| 47 |
+
cp -rv ./transformers/modeling_utils.py $site_pkg_path/transformers/modeling_utils.py # for using qwen 2.5 omni checkpoint
|
| 48 |
+
|
| 49 |
+
# for benchmark adoption
|
| 50 |
+
uv pip install faiss-gpu-cu12
|
| 51 |
+
|
| 52 |
+
# Quantization requires the newest triton version, and introduce dependency issue
|
| 53 |
+
uv pip install triton==3.1.0 # we don't need this version if we do not use FP8LinearQwen2Config, QLlavaLlamaConfig, etc. It is not compatible with mamba-ssm.
|
| 54 |
+
|
| 55 |
+
uv pip install kaldiio
|
| 56 |
+
|
| 57 |
+
# for rotary embedding
|
| 58 |
+
uv pip install beartype
|
| 59 |
+
|
| 60 |
+
uv pip install pydantic==1.10.22
|
example_infer.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
from transformers import AutoProcessor, AutoModel, AutoConfig, GenerationConfig
|
| 18 |
+
import torch
|
| 19 |
+
import os
|
| 20 |
+
import time
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import List, Dict, Any, Optional, Union
|
| 23 |
+
import logging
|
| 24 |
+
import sys
|
| 25 |
+
os.environ["HF_HUB_OFFLINE"] = "1" # Use local cache for models
|
| 26 |
+
|
| 27 |
+
# Set up logging
|
| 28 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 29 |
+
logger = logging.getLogger(__name__)
|
| 30 |
+
|
| 31 |
+
def add_to_sys_path_direct(model_path):
|
| 32 |
+
"""Add model path directly to sys.path"""
|
| 33 |
+
if model_path not in sys.path:
|
| 34 |
+
sys.path.insert(0, model_path) # Insert at beginning for priority
|
| 35 |
+
print(f"✓ Added to sys.path: {model_path}")
|
| 36 |
+
else:
|
| 37 |
+
print(f"Already in sys.path: {model_path}")
|
| 38 |
+
|
| 39 |
+
class NVOmniVideoInference:
|
| 40 |
+
"""A class to handle NVOmni video model inference with improved error handling and flexibility."""
|
| 41 |
+
|
| 42 |
+
def __init__(self, model_path: str, torch_dtype="torch.float16", device_map="auto"):
|
| 43 |
+
"""
|
| 44 |
+
Initialize the NVOmni model for video inference.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
model_path (str): Path to the model directory
|
| 48 |
+
torch_dtype: PyTorch data type for model weights
|
| 49 |
+
device_map (str): Device mapping strategy for model loading
|
| 50 |
+
"""
|
| 51 |
+
self.model_path = model_path
|
| 52 |
+
self.torch_dtype = torch_dtype
|
| 53 |
+
self.device_map = device_map
|
| 54 |
+
self.model = None
|
| 55 |
+
self.processor = None
|
| 56 |
+
self.config = None
|
| 57 |
+
self.device = None
|
| 58 |
+
|
| 59 |
+
self.load_model()
|
| 60 |
+
|
| 61 |
+
def validate_paths(self, model_path: str, video_path: str = None) -> bool:
|
| 62 |
+
"""Validate that required paths exist."""
|
| 63 |
+
if not Path(model_path).exists():
|
| 64 |
+
logger.error(f"Model path does not exist: {model_path}")
|
| 65 |
+
return False
|
| 66 |
+
|
| 67 |
+
if video_path and not Path(video_path).exists():
|
| 68 |
+
logger.error(f"Video path does not exist: {video_path}")
|
| 69 |
+
return False
|
| 70 |
+
|
| 71 |
+
return True
|
| 72 |
+
|
| 73 |
+
def load_model(self) -> bool:
|
| 74 |
+
"""Load the model, processor, and config with error handling."""
|
| 75 |
+
if not self.validate_paths(self.model_path):
|
| 76 |
+
return False
|
| 77 |
+
|
| 78 |
+
if True:
|
| 79 |
+
logger.info("Loading model configuration...")
|
| 80 |
+
self.config = AutoConfig.from_pretrained(self.model_path, trust_remote_code=True)
|
| 81 |
+
|
| 82 |
+
logger.info("Loading model...")
|
| 83 |
+
start_time = time.time()
|
| 84 |
+
self.model = AutoModel.from_pretrained(
|
| 85 |
+
self.model_path,
|
| 86 |
+
trust_remote_code=True,
|
| 87 |
+
torch_dtype=self.torch_dtype,
|
| 88 |
+
device_map=self.device_map,
|
| 89 |
+
low_cpu_mem_usage=True # More memory efficient loading
|
| 90 |
+
)#.to(eval(self.torch_dtype))
|
| 91 |
+
load_time = time.time() - start_time
|
| 92 |
+
logger.info(f"Model loaded in {load_time:.2f} seconds")
|
| 93 |
+
|
| 94 |
+
logger.info("Loading processor...")
|
| 95 |
+
self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)
|
| 96 |
+
|
| 97 |
+
# Set device for single-device setups
|
| 98 |
+
if hasattr(self.model, 'device'):
|
| 99 |
+
self.device = self.model.device
|
| 100 |
+
else:
|
| 101 |
+
self.device = next(self.model.parameters()).device if self.model.parameters() else torch.device('cpu')
|
| 102 |
+
|
| 103 |
+
logger.info(f"Model successfully loaded on device: {self.device}")
|
| 104 |
+
self._print_model_info()
|
| 105 |
+
return True
|
| 106 |
+
|
| 107 |
+
def _print_model_info(self):
|
| 108 |
+
"""Print useful information about the loaded model."""
|
| 109 |
+
logger.info("=" * 50)
|
| 110 |
+
logger.info("MODEL INFORMATION")
|
| 111 |
+
logger.info("=" * 50)
|
| 112 |
+
|
| 113 |
+
if self.config:
|
| 114 |
+
logger.info(f"Model type: {getattr(self.config, 'model_type', 'Unknown')}")
|
| 115 |
+
logger.info(f"Hidden size: {getattr(self.config, 'hidden_size', 'Unknown')}")
|
| 116 |
+
|
| 117 |
+
if self.model and torch.cuda.is_available():
|
| 118 |
+
logger.info(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
|
| 119 |
+
logger.info(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
|
| 120 |
+
|
| 121 |
+
def create_conversation(self, video_path: str, text_prompt: str) -> List[Dict[str, Any]]:
|
| 122 |
+
"""
|
| 123 |
+
Create a conversation format for the model.
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
video_path (str): Path to the video file
|
| 127 |
+
text_prompt (str): Text prompt for the model
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
List[Dict]: Conversation in the expected format
|
| 131 |
+
"""
|
| 132 |
+
return [{
|
| 133 |
+
"role": "user",
|
| 134 |
+
"content": [
|
| 135 |
+
{"type": "video", "video": video_path},
|
| 136 |
+
{"type": "text", "text": text_prompt}
|
| 137 |
+
]
|
| 138 |
+
}]
|
| 139 |
+
|
| 140 |
+
@torch.inference_mode()
|
| 141 |
+
def generate_response(
|
| 142 |
+
self,
|
| 143 |
+
video_path: str,
|
| 144 |
+
text_prompt: str,
|
| 145 |
+
max_new_tokens: int = 256,
|
| 146 |
+
temperature: float = None,
|
| 147 |
+
top_p: float = None,
|
| 148 |
+
do_sample: bool = None,
|
| 149 |
+
num_video_frames: int = -1,
|
| 150 |
+
load_audio_in_video: bool = True,
|
| 151 |
+
audio_length: Union[int, str] = "max_3600",
|
| 152 |
+
) -> Optional[str]:
|
| 153 |
+
"""
|
| 154 |
+
Generate a response from the model given a video and text prompt.
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
video_path (str): Path to the video file
|
| 158 |
+
text_prompt (str): Text prompt for the model
|
| 159 |
+
max_new_tokens (int): Maximum number of new tokens to generate
|
| 160 |
+
temperature (float): Sampling temperature
|
| 161 |
+
top_p (float): Top-p sampling parameter
|
| 162 |
+
do_sample (bool): Whether to use sampling
|
| 163 |
+
custom_generation_config (GenerationConfig): Custom generation configuration
|
| 164 |
+
|
| 165 |
+
Returns:
|
| 166 |
+
Optional[str]: Generated response or None if failed
|
| 167 |
+
"""
|
| 168 |
+
if not self.model or not self.processor:
|
| 169 |
+
logger.error("Model or processor not loaded. Please initialize the model first.")
|
| 170 |
+
return None
|
| 171 |
+
|
| 172 |
+
if not self.validate_paths(self.model_path, video_path):
|
| 173 |
+
return None
|
| 174 |
+
|
| 175 |
+
# try:
|
| 176 |
+
if True:
|
| 177 |
+
|
| 178 |
+
logger.info(f"Processing video: {video_path}")
|
| 179 |
+
logger.info(f"Text prompt: {text_prompt}")
|
| 180 |
+
|
| 181 |
+
# Create conversation
|
| 182 |
+
conversation = self.create_conversation(video_path, text_prompt)
|
| 183 |
+
|
| 184 |
+
# Apply chat template
|
| 185 |
+
text = self.processor.apply_chat_template(
|
| 186 |
+
conversation,
|
| 187 |
+
tokenize=False,
|
| 188 |
+
add_generation_prompt=True
|
| 189 |
+
)
|
| 190 |
+
logger.info(f"Chat template applied")
|
| 191 |
+
|
| 192 |
+
# set model params
|
| 193 |
+
self.model.config.load_audio_in_video = load_audio_in_video
|
| 194 |
+
self.processor.config.load_audio_in_video = load_audio_in_video
|
| 195 |
+
if num_video_frames > 0:
|
| 196 |
+
self.model.config.num_video_frames = num_video_frames
|
| 197 |
+
self.processor.config.num_video_frames = num_video_frames
|
| 198 |
+
if audio_length != -1:
|
| 199 |
+
self.model.config.audio_chunk_length = audio_length
|
| 200 |
+
self.processor.config.audio_chunk_length = audio_length
|
| 201 |
+
logger.info(f"Model config - load_audio_in_video: {self.model.config.load_audio_in_video}, num_video_frames: {self.model.config.num_video_frames}, audio_chunk_length: {self.model.config.audio_chunk_length}")
|
| 202 |
+
|
| 203 |
+
# Process inputs
|
| 204 |
+
start_time = time.time()
|
| 205 |
+
inputs = self.processor([text])
|
| 206 |
+
|
| 207 |
+
# Move inputs to the correct device if needed
|
| 208 |
+
if hasattr(inputs, 'input_ids') and inputs.input_ids is not None:
|
| 209 |
+
inputs.input_ids = inputs.input_ids.to(self.device)
|
| 210 |
+
|
| 211 |
+
processing_time = time.time() - start_time
|
| 212 |
+
logger.info(f"Input processing completed in {processing_time:.2f} seconds")
|
| 213 |
+
|
| 214 |
+
logger.info("Generating response...")
|
| 215 |
+
start_time = time.time()
|
| 216 |
+
|
| 217 |
+
generation_kwargs = {"max_new_tokens": max_new_tokens, "max_length": 99999999}
|
| 218 |
+
if top_p is not None:
|
| 219 |
+
generation_kwargs["top_p"] = top_p
|
| 220 |
+
if do_sample is not None:
|
| 221 |
+
generation_kwargs["do_sample"] = do_sample
|
| 222 |
+
if temperature is not None:
|
| 223 |
+
generation_kwargs["temperature"] = temperature
|
| 224 |
+
|
| 225 |
+
generation_config = self.model.default_generation_config
|
| 226 |
+
generation_config.update(**generation_kwargs)
|
| 227 |
+
|
| 228 |
+
logger.info(f"Generation config: {generation_config.to_dict()}")
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
with torch.no_grad():
|
| 232 |
+
output_ids = self.model.generate(
|
| 233 |
+
input_ids=inputs.input_ids,
|
| 234 |
+
media=getattr(inputs, 'media', None),
|
| 235 |
+
media_config=getattr(inputs, 'media_config', None),
|
| 236 |
+
generation_config=generation_config,
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
generation_time = time.time() - start_time
|
| 240 |
+
logger.info(f"Generation completed in {generation_time:.2f} seconds")
|
| 241 |
+
|
| 242 |
+
# Decode response
|
| 243 |
+
response = self.processor.tokenizer.batch_decode(
|
| 244 |
+
output_ids,
|
| 245 |
+
skip_special_tokens=True
|
| 246 |
+
)[0]
|
| 247 |
+
|
| 248 |
+
return response
|
| 249 |
+
|
| 250 |
+
def batch_generate(
|
| 251 |
+
self,
|
| 252 |
+
video_text_pairs: List[tuple],
|
| 253 |
+
**generation_kwargs
|
| 254 |
+
) -> List[Optional[str]]:
|
| 255 |
+
"""
|
| 256 |
+
Generate responses for multiple video-text pairs.
|
| 257 |
+
|
| 258 |
+
Args:
|
| 259 |
+
video_text_pairs (List[tuple]): List of (video_path, text_prompt) tuples
|
| 260 |
+
**generation_kwargs: Arguments passed to generate_response
|
| 261 |
+
|
| 262 |
+
Returns:
|
| 263 |
+
List[Optional[str]]: List of generated responses
|
| 264 |
+
"""
|
| 265 |
+
responses = []
|
| 266 |
+
for i, (video_path, text_prompt) in enumerate(video_text_pairs):
|
| 267 |
+
logger.info(f"Processing batch item {i+1}/{len(video_text_pairs)}")
|
| 268 |
+
response = self.generate_response(video_path, text_prompt, **generation_kwargs)
|
| 269 |
+
responses.append(response)
|
| 270 |
+
|
| 271 |
+
# Clear cache between generations to manage memory
|
| 272 |
+
if torch.cuda.is_available():
|
| 273 |
+
torch.cuda.empty_cache()
|
| 274 |
+
|
| 275 |
+
return responses
|
| 276 |
+
|
| 277 |
+
def main():
|
| 278 |
+
"""Main function demonstrating usage of the NVOmni model."""
|
| 279 |
+
|
| 280 |
+
# Configuration
|
| 281 |
+
MODEL_PATH = "./"
|
| 282 |
+
VIDEO_PATH = "xxx.mp4"
|
| 283 |
+
TEXT_PROMPT = "Assess the video, followed by a detailed description of it's video and audio contents."
|
| 284 |
+
|
| 285 |
+
num_video_frames=128
|
| 286 |
+
audio_length="max_3600"
|
| 287 |
+
load_audio_in_video=True
|
| 288 |
+
|
| 289 |
+
add_to_sys_path_direct(MODEL_PATH)
|
| 290 |
+
|
| 291 |
+
# Initialize the inference class
|
| 292 |
+
logger.info("Initializing NVOmni Video Inference...")
|
| 293 |
+
inferencer = NVOmniVideoInference(MODEL_PATH, torch_dtype="torch.float16")
|
| 294 |
+
|
| 295 |
+
if inferencer.model is None:
|
| 296 |
+
logger.error("Failed to initialize model. Exiting.")
|
| 297 |
+
return
|
| 298 |
+
|
| 299 |
+
# Generate response
|
| 300 |
+
logger.info("Starting inference...")
|
| 301 |
+
response = inferencer.generate_response(
|
| 302 |
+
video_path=VIDEO_PATH,
|
| 303 |
+
text_prompt=TEXT_PROMPT,
|
| 304 |
+
num_video_frames=num_video_frames,
|
| 305 |
+
load_audio_in_video=load_audio_in_video,
|
| 306 |
+
audio_length=audio_length,
|
| 307 |
+
max_new_tokens=1024,
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
if response:
|
| 311 |
+
print("\n" + "="*60)
|
| 312 |
+
print("GENERATED RESPONSE")
|
| 313 |
+
print("="*60)
|
| 314 |
+
print(response)
|
| 315 |
+
print("="*60)
|
| 316 |
+
else:
|
| 317 |
+
logger.error("Failed to generate response")
|
| 318 |
+
|
| 319 |
+
# Example of batch processing
|
| 320 |
+
if False:
|
| 321 |
+
logger.info("\nExample: Batch processing")
|
| 322 |
+
batch_pairs = [
|
| 323 |
+
(VIDEO_PATH, "What is happening in this video?"),
|
| 324 |
+
(VIDEO_PATH, "Describe the audio content of this video."),
|
| 325 |
+
]
|
| 326 |
+
|
| 327 |
+
batch_responses = inferencer.batch_generate(batch_pairs, max_new_tokens=128)
|
| 328 |
+
|
| 329 |
+
for i, (pair, response) in enumerate(zip(batch_pairs, batch_responses)):
|
| 330 |
+
print(f"\n--- Batch Response {i+1} ---")
|
| 331 |
+
print(f"Prompt: {pair[1]}")
|
| 332 |
+
print(f"Response: {response}")
|
| 333 |
+
|
| 334 |
+
if __name__ == "__main__":
|
| 335 |
+
main()
|
example_mini_audio.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""
|
| 17 |
+
Example script for audio transcription using the model.
|
| 18 |
+
|
| 19 |
+
This script demonstrates how to:
|
| 20 |
+
1. Load the model and processor
|
| 21 |
+
2. Configure audio processing parameters
|
| 22 |
+
3. Process audio input
|
| 23 |
+
4. Generate transcription output
|
| 24 |
+
|
| 25 |
+
Usage:
|
| 26 |
+
python example_mini_audio.py --model_path <path_to_model> --audio_path <path_to_audio>
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
from transformers import AutoProcessor, AutoModel, AutoConfig, AutoModelForCausalLM
|
| 30 |
+
import torch
|
| 31 |
+
import os
|
| 32 |
+
import argparse
|
| 33 |
+
|
| 34 |
+
# Configuration
|
| 35 |
+
parser = argparse.ArgumentParser(description="Audio transcription example")
|
| 36 |
+
parser.add_argument("--model_path", type=str, default="./", help="Path to the model")
|
| 37 |
+
parser.add_argument("--audio_path", type=str, required=True, help="Path to the audio file")
|
| 38 |
+
parser.add_argument("--max_new_tokens", type=int, default=1024, help="Maximum number of tokens to generate")
|
| 39 |
+
parser.add_argument("--num_video_frames", type=int, default=128, help="Number of video frames to process")
|
| 40 |
+
parser.add_argument("--audio_length", type=str, default="max_3600", help="Maximum audio length")
|
| 41 |
+
|
| 42 |
+
args = parser.parse_args()
|
| 43 |
+
|
| 44 |
+
model_path = args.model_path
|
| 45 |
+
audio_path = args.audio_path
|
| 46 |
+
generation_kwargs = {"max_new_tokens": args.max_new_tokens, "max_length": 99999999}
|
| 47 |
+
load_audio_in_video = True
|
| 48 |
+
num_video_frames = args.num_video_frames
|
| 49 |
+
audio_length = args.audio_length
|
| 50 |
+
|
| 51 |
+
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
| 52 |
+
|
| 53 |
+
model = AutoModel.from_pretrained(model_path,
|
| 54 |
+
trust_remote_code=True,
|
| 55 |
+
torch_dtype="torch.float16",
|
| 56 |
+
device_map="auto")
|
| 57 |
+
|
| 58 |
+
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
|
| 59 |
+
generation_config = model.default_generation_config
|
| 60 |
+
generation_config.update(**generation_kwargs)
|
| 61 |
+
|
| 62 |
+
model.config.load_audio_in_video = load_audio_in_video
|
| 63 |
+
processor.config.load_audio_in_video = load_audio_in_video
|
| 64 |
+
if num_video_frames > 0:
|
| 65 |
+
model.config.num_video_frames = num_video_frames
|
| 66 |
+
processor.config.num_video_frames = num_video_frames
|
| 67 |
+
if audio_length != -1:
|
| 68 |
+
model.config.audio_chunk_length = audio_length
|
| 69 |
+
processor.config.audio_chunk_length = audio_length
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
conversation = [{
|
| 73 |
+
"role": "user",
|
| 74 |
+
"content": [
|
| 75 |
+
{"type": "audio", "audio": audio_path},
|
| 76 |
+
{"type": "text", "text": "Transcribe the whole speech."}
|
| 77 |
+
]
|
| 78 |
+
}]
|
| 79 |
+
text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
|
| 80 |
+
|
| 81 |
+
inputs = processor([text])
|
| 82 |
+
|
| 83 |
+
output_ids = model.generate(
|
| 84 |
+
input_ids=inputs.input_ids,
|
| 85 |
+
media=getattr(inputs, 'media', None),
|
| 86 |
+
media_config=getattr(inputs, 'media_config', None),
|
| 87 |
+
generation_config=generation_config,
|
| 88 |
+
)
|
| 89 |
+
print(processor.tokenizer.batch_decode(output_ids, skip_special_tokens=True))
|
example_mini_image.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""
|
| 17 |
+
Example script for image understanding using the model.
|
| 18 |
+
|
| 19 |
+
This script demonstrates how to:
|
| 20 |
+
1. Load the model and processor
|
| 21 |
+
2. Process image input
|
| 22 |
+
3. Generate description output
|
| 23 |
+
|
| 24 |
+
Usage:
|
| 25 |
+
python example_mini_image.py --model_path <path_to_model> --image_path <path_to_image>
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
from transformers import AutoProcessor, AutoModel, AutoConfig, AutoModelForCausalLM
|
| 29 |
+
import torch
|
| 30 |
+
import os
|
| 31 |
+
import argparse
|
| 32 |
+
|
| 33 |
+
# Configuration
|
| 34 |
+
parser = argparse.ArgumentParser(description="Image understanding example")
|
| 35 |
+
parser.add_argument("--model_path", type=str, default="./", help="Path to the model")
|
| 36 |
+
parser.add_argument("--image_path", type=str, required=True, help="Path to the image file")
|
| 37 |
+
parser.add_argument("--max_new_tokens", type=int, default=1024, help="Maximum number of tokens to generate")
|
| 38 |
+
parser.add_argument("--prompt", type=str, default="Describe the image in detail.", help="Text prompt for the model")
|
| 39 |
+
|
| 40 |
+
args = parser.parse_args()
|
| 41 |
+
|
| 42 |
+
model_path = args.model_path
|
| 43 |
+
image_path = args.image_path
|
| 44 |
+
generation_kwargs = {"max_new_tokens": args.max_new_tokens, "max_length": 99999999}
|
| 45 |
+
|
| 46 |
+
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
| 47 |
+
|
| 48 |
+
model = AutoModel.from_pretrained(
|
| 49 |
+
model_path,
|
| 50 |
+
trust_remote_code=True,
|
| 51 |
+
torch_dtype=torch.float16,
|
| 52 |
+
device_map="auto"
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
|
| 56 |
+
generation_config = model.default_generation_config
|
| 57 |
+
generation_config.update(**generation_kwargs)
|
| 58 |
+
|
| 59 |
+
conversation = [{
|
| 60 |
+
"role": "user",
|
| 61 |
+
"content": [
|
| 62 |
+
{"type": "image", "image": image_path},
|
| 63 |
+
{"type": "text", "text": args.prompt}
|
| 64 |
+
]
|
| 65 |
+
}]
|
| 66 |
+
text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
|
| 67 |
+
|
| 68 |
+
inputs = processor([text])
|
| 69 |
+
|
| 70 |
+
output_ids = model.generate(
|
| 71 |
+
input_ids=inputs.input_ids,
|
| 72 |
+
media=getattr(inputs, 'media', None),
|
| 73 |
+
media_config=getattr(inputs, 'media_config', None),
|
| 74 |
+
generation_config=generation_config,
|
| 75 |
+
)
|
| 76 |
+
print(processor.tokenizer.batch_decode(output_ids, skip_special_tokens=True))
|
example_mini_video.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""
|
| 17 |
+
Example script for video understanding using the model.
|
| 18 |
+
|
| 19 |
+
This script demonstrates how to:
|
| 20 |
+
1. Load the model and processor
|
| 21 |
+
2. Configure video and audio processing parameters
|
| 22 |
+
3. Process video input with optional audio
|
| 23 |
+
4. Generate description output
|
| 24 |
+
|
| 25 |
+
Usage:
|
| 26 |
+
python example_mini_video.py --model_path <path_to_model> --video_path <path_to_video>
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
from transformers import AutoProcessor, AutoModel, AutoConfig, AutoModelForCausalLM
|
| 30 |
+
import torch
|
| 31 |
+
import os
|
| 32 |
+
import argparse
|
| 33 |
+
|
| 34 |
+
# Configuration
|
| 35 |
+
parser = argparse.ArgumentParser(description="Video understanding example")
|
| 36 |
+
parser.add_argument("--model_path", type=str, default="./", help="Path to the model")
|
| 37 |
+
parser.add_argument("--video_path", type=str, required=True, help="Path to the video file")
|
| 38 |
+
parser.add_argument("--max_new_tokens", type=int, default=1024, help="Maximum number of tokens to generate")
|
| 39 |
+
parser.add_argument("--num_video_frames", type=int, default=128, help="Number of video frames to process")
|
| 40 |
+
parser.add_argument("--audio_length", type=str, default="max_3600", help="Maximum audio length")
|
| 41 |
+
parser.add_argument("--prompt", type=str, default="What are they talking about in detail?", help="Text prompt for the model")
|
| 42 |
+
parser.add_argument("--load_audio", action="store_true", default=True, help="Load audio from video")
|
| 43 |
+
|
| 44 |
+
args = parser.parse_args()
|
| 45 |
+
|
| 46 |
+
model_path = args.model_path
|
| 47 |
+
video_path = args.video_path
|
| 48 |
+
generation_kwargs = {"max_new_tokens": args.max_new_tokens, "max_length": 99999999}
|
| 49 |
+
load_audio_in_video = args.load_audio
|
| 50 |
+
num_video_frames = args.num_video_frames
|
| 51 |
+
audio_length = args.audio_length
|
| 52 |
+
text_prompt = args.prompt
|
| 53 |
+
|
| 54 |
+
assert os.path.exists(video_path), f"Video path {video_path} does not exist."
|
| 55 |
+
|
| 56 |
+
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
| 57 |
+
|
| 58 |
+
model = AutoModel.from_pretrained(
|
| 59 |
+
model_path,
|
| 60 |
+
trust_remote_code=True,
|
| 61 |
+
torch_dtype=torch.float16,
|
| 62 |
+
device_map="auto"
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
|
| 66 |
+
generation_config = model.default_generation_config
|
| 67 |
+
generation_config.update(**generation_kwargs)
|
| 68 |
+
|
| 69 |
+
model.config.load_audio_in_video = load_audio_in_video
|
| 70 |
+
processor.config.load_audio_in_video = load_audio_in_video
|
| 71 |
+
if num_video_frames > 0:
|
| 72 |
+
model.config.num_video_frames = num_video_frames
|
| 73 |
+
processor.config.num_video_frames = num_video_frames
|
| 74 |
+
if audio_length != -1:
|
| 75 |
+
model.config.audio_chunk_length = audio_length
|
| 76 |
+
processor.config.audio_chunk_length = audio_length
|
| 77 |
+
|
| 78 |
+
def forward_inference(video_path, text_prompt):
|
| 79 |
+
"""Run inference on video with text prompt."""
|
| 80 |
+
print(f"Text prompt: {text_prompt}")
|
| 81 |
+
print(f"Video path: {video_path}")
|
| 82 |
+
conversation = [{
|
| 83 |
+
"role": "user",
|
| 84 |
+
"content": [
|
| 85 |
+
{"type": "video", "video": video_path},
|
| 86 |
+
{"type": "text", "text": text_prompt}
|
| 87 |
+
]
|
| 88 |
+
}]
|
| 89 |
+
text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
|
| 90 |
+
|
| 91 |
+
inputs = processor([text])
|
| 92 |
+
|
| 93 |
+
output_ids = model.generate(
|
| 94 |
+
input_ids=inputs.input_ids,
|
| 95 |
+
media=getattr(inputs, 'media', None),
|
| 96 |
+
media_config=getattr(inputs, 'media_config', None),
|
| 97 |
+
generation_config=generation_config,
|
| 98 |
+
)
|
| 99 |
+
print(processor.tokenizer.batch_decode(output_ids, skip_special_tokens=True))
|
| 100 |
+
|
| 101 |
+
forward_inference(video_path, text_prompt)
|
llm/added_tokens.json
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"<image>": 151649,
|
| 3 |
+
"<sound>": 151652,
|
| 4 |
+
"<speech>": 151651,
|
| 5 |
+
"<vila/sentinel>": 151648,
|
| 6 |
+
"<vila/video>": 151650,
|
| 7 |
+
"<|endoftext|>": 151643,
|
| 8 |
+
"<|im_end|>": 151645,
|
| 9 |
+
"<|im_start|>": 151644,
|
| 10 |
+
"<|image_bos|>": 151653,
|
| 11 |
+
"<|image_eos|>": 151654,
|
| 12 |
+
"<|sound_bos|>": 151659,
|
| 13 |
+
"<|sound_eos|>": 151660,
|
| 14 |
+
"<|speech_bos|>": 151657,
|
| 15 |
+
"<|speech_eos|>": 151658,
|
| 16 |
+
"<|video_bos|>": 151655,
|
| 17 |
+
"<|video_eos|>": 151656,
|
| 18 |
+
"[BOS]": 151646,
|
| 19 |
+
"[PAD]": 151647
|
| 20 |
+
}
|
llm/config.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
llm/generation_config.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token_id": 151643,
|
| 3 |
+
"do_sample": true,
|
| 4 |
+
"eos_token_id": [
|
| 5 |
+
151645,
|
| 6 |
+
151643
|
| 7 |
+
],
|
| 8 |
+
"pad_token_id": 151643,
|
| 9 |
+
"repetition_penalty": 1.05,
|
| 10 |
+
"temperature": 0.7,
|
| 11 |
+
"top_k": 20,
|
| 12 |
+
"top_p": 0.8,
|
| 13 |
+
"transformers_version": "4.46.0"
|
| 14 |
+
}
|
llm/merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
llm/model-00001-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:94f591013a08df71d152c96e5bc415bedc434bf30514fb03bd39c8d49e7161cd
|
| 3 |
+
size 4874772072
|
llm/model-00002-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4f8d257b5d45c7d51d3abd0b25a524ae8234cca1d6536976dfa833aa9bb06ffe
|
| 3 |
+
size 4932751008
|
llm/model-00003-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:872102a5b7897af896ea518db185fda955786e8f45d70a49c48bbb8a4a45d305
|
| 3 |
+
size 4330865200
|
llm/model-00004-of-00004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9b82d845caeacc0f77859bece772e05aaf828b51b25946d4ab58801845130b30
|
| 3 |
+
size 1087106176
|
llm/model.safetensors.index.json
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"total_size": 15225455616
|
| 4 |
+
},
|
| 5 |
+
"weight_map": {
|
| 6 |
+
"lm_head.weight": "model-00004-of-00004.safetensors",
|
| 7 |
+
"model.embed_tokens.weight": "model-00001-of-00004.safetensors",
|
| 8 |
+
"model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 9 |
+
"model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 10 |
+
"model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 11 |
+
"model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 12 |
+
"model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 13 |
+
"model.layers.0.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 14 |
+
"model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 15 |
+
"model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 16 |
+
"model.layers.0.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 17 |
+
"model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 18 |
+
"model.layers.0.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 19 |
+
"model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 20 |
+
"model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 21 |
+
"model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 22 |
+
"model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 23 |
+
"model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 24 |
+
"model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 25 |
+
"model.layers.1.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 26 |
+
"model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 27 |
+
"model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 28 |
+
"model.layers.1.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 29 |
+
"model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 30 |
+
"model.layers.1.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 31 |
+
"model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 32 |
+
"model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 33 |
+
"model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 34 |
+
"model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 35 |
+
"model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 36 |
+
"model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 37 |
+
"model.layers.10.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 38 |
+
"model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 39 |
+
"model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 40 |
+
"model.layers.10.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 41 |
+
"model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 42 |
+
"model.layers.10.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 43 |
+
"model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 44 |
+
"model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 45 |
+
"model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 46 |
+
"model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 47 |
+
"model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 48 |
+
"model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 49 |
+
"model.layers.11.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 50 |
+
"model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 51 |
+
"model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 52 |
+
"model.layers.11.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 53 |
+
"model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 54 |
+
"model.layers.11.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 55 |
+
"model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 56 |
+
"model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 57 |
+
"model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 58 |
+
"model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 59 |
+
"model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 60 |
+
"model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 61 |
+
"model.layers.12.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 62 |
+
"model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 63 |
+
"model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 64 |
+
"model.layers.12.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 65 |
+
"model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 66 |
+
"model.layers.12.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 67 |
+
"model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 68 |
+
"model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 69 |
+
"model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 70 |
+
"model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 71 |
+
"model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 72 |
+
"model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 73 |
+
"model.layers.13.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 74 |
+
"model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 75 |
+
"model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 76 |
+
"model.layers.13.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 77 |
+
"model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 78 |
+
"model.layers.13.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 79 |
+
"model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 80 |
+
"model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 81 |
+
"model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 82 |
+
"model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 83 |
+
"model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 84 |
+
"model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 85 |
+
"model.layers.14.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 86 |
+
"model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 87 |
+
"model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 88 |
+
"model.layers.14.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 89 |
+
"model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 90 |
+
"model.layers.14.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 91 |
+
"model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 92 |
+
"model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 93 |
+
"model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 94 |
+
"model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 95 |
+
"model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 96 |
+
"model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 97 |
+
"model.layers.15.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 98 |
+
"model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 99 |
+
"model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 100 |
+
"model.layers.15.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 101 |
+
"model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 102 |
+
"model.layers.15.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 103 |
+
"model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 104 |
+
"model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 105 |
+
"model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 106 |
+
"model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 107 |
+
"model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 108 |
+
"model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 109 |
+
"model.layers.16.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 110 |
+
"model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 111 |
+
"model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 112 |
+
"model.layers.16.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 113 |
+
"model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 114 |
+
"model.layers.16.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 115 |
+
"model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 116 |
+
"model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 117 |
+
"model.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 118 |
+
"model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 119 |
+
"model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 120 |
+
"model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 121 |
+
"model.layers.17.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 122 |
+
"model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 123 |
+
"model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 124 |
+
"model.layers.17.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 125 |
+
"model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 126 |
+
"model.layers.17.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 127 |
+
"model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 128 |
+
"model.layers.18.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 129 |
+
"model.layers.18.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 130 |
+
"model.layers.18.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 131 |
+
"model.layers.18.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 132 |
+
"model.layers.18.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 133 |
+
"model.layers.18.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 134 |
+
"model.layers.18.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 135 |
+
"model.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 136 |
+
"model.layers.18.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 137 |
+
"model.layers.18.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 138 |
+
"model.layers.18.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 139 |
+
"model.layers.18.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 140 |
+
"model.layers.19.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 141 |
+
"model.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 142 |
+
"model.layers.19.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 143 |
+
"model.layers.19.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 144 |
+
"model.layers.19.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 145 |
+
"model.layers.19.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 146 |
+
"model.layers.19.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 147 |
+
"model.layers.19.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 148 |
+
"model.layers.19.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 149 |
+
"model.layers.19.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 150 |
+
"model.layers.19.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 151 |
+
"model.layers.19.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 152 |
+
"model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 153 |
+
"model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 154 |
+
"model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 155 |
+
"model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 156 |
+
"model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 157 |
+
"model.layers.2.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 158 |
+
"model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 159 |
+
"model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 160 |
+
"model.layers.2.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 161 |
+
"model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 162 |
+
"model.layers.2.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 163 |
+
"model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 164 |
+
"model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 165 |
+
"model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 166 |
+
"model.layers.20.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 167 |
+
"model.layers.20.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 168 |
+
"model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 169 |
+
"model.layers.20.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 170 |
+
"model.layers.20.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 171 |
+
"model.layers.20.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 172 |
+
"model.layers.20.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 173 |
+
"model.layers.20.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 174 |
+
"model.layers.20.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 175 |
+
"model.layers.20.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 176 |
+
"model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 177 |
+
"model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 178 |
+
"model.layers.21.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 179 |
+
"model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 180 |
+
"model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 181 |
+
"model.layers.21.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 182 |
+
"model.layers.21.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 183 |
+
"model.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 184 |
+
"model.layers.21.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 185 |
+
"model.layers.21.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 186 |
+
"model.layers.21.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 187 |
+
"model.layers.21.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 188 |
+
"model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 189 |
+
"model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 190 |
+
"model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 191 |
+
"model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 192 |
+
"model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 193 |
+
"model.layers.22.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 194 |
+
"model.layers.22.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 195 |
+
"model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 196 |
+
"model.layers.22.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 197 |
+
"model.layers.22.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 198 |
+
"model.layers.22.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 199 |
+
"model.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 200 |
+
"model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 201 |
+
"model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 202 |
+
"model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 203 |
+
"model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 204 |
+
"model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 205 |
+
"model.layers.23.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 206 |
+
"model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 207 |
+
"model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 208 |
+
"model.layers.23.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 209 |
+
"model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 210 |
+
"model.layers.23.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 211 |
+
"model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 212 |
+
"model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 213 |
+
"model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 214 |
+
"model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 215 |
+
"model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 216 |
+
"model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 217 |
+
"model.layers.24.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 218 |
+
"model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 219 |
+
"model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 220 |
+
"model.layers.24.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 221 |
+
"model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 222 |
+
"model.layers.24.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 223 |
+
"model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 224 |
+
"model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 225 |
+
"model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 226 |
+
"model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 227 |
+
"model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 228 |
+
"model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 229 |
+
"model.layers.25.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 230 |
+
"model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 231 |
+
"model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 232 |
+
"model.layers.25.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 233 |
+
"model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 234 |
+
"model.layers.25.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 235 |
+
"model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 236 |
+
"model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 237 |
+
"model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 238 |
+
"model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 239 |
+
"model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 240 |
+
"model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 241 |
+
"model.layers.26.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 242 |
+
"model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 243 |
+
"model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 244 |
+
"model.layers.26.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 245 |
+
"model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 246 |
+
"model.layers.26.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 247 |
+
"model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 248 |
+
"model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 249 |
+
"model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
| 250 |
+
"model.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
| 251 |
+
"model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
| 252 |
+
"model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
| 253 |
+
"model.layers.27.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
| 254 |
+
"model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
| 255 |
+
"model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
| 256 |
+
"model.layers.27.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
| 257 |
+
"model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
| 258 |
+
"model.layers.27.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
| 259 |
+
"model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
| 260 |
+
"model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 261 |
+
"model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 262 |
+
"model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 263 |
+
"model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 264 |
+
"model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 265 |
+
"model.layers.3.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 266 |
+
"model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 267 |
+
"model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 268 |
+
"model.layers.3.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 269 |
+
"model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 270 |
+
"model.layers.3.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 271 |
+
"model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 272 |
+
"model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 273 |
+
"model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 274 |
+
"model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 275 |
+
"model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 276 |
+
"model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 277 |
+
"model.layers.4.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 278 |
+
"model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 279 |
+
"model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 280 |
+
"model.layers.4.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 281 |
+
"model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 282 |
+
"model.layers.4.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 283 |
+
"model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 284 |
+
"model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 285 |
+
"model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 286 |
+
"model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 287 |
+
"model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 288 |
+
"model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 289 |
+
"model.layers.5.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 290 |
+
"model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 291 |
+
"model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 292 |
+
"model.layers.5.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 293 |
+
"model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 294 |
+
"model.layers.5.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 295 |
+
"model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 296 |
+
"model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 297 |
+
"model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 298 |
+
"model.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 299 |
+
"model.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 300 |
+
"model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 301 |
+
"model.layers.6.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 302 |
+
"model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 303 |
+
"model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 304 |
+
"model.layers.6.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 305 |
+
"model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 306 |
+
"model.layers.6.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 307 |
+
"model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 308 |
+
"model.layers.7.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 309 |
+
"model.layers.7.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
| 310 |
+
"model.layers.7.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
| 311 |
+
"model.layers.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
| 312 |
+
"model.layers.7.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
| 313 |
+
"model.layers.7.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 314 |
+
"model.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 315 |
+
"model.layers.7.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 316 |
+
"model.layers.7.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 317 |
+
"model.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 318 |
+
"model.layers.7.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 319 |
+
"model.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 320 |
+
"model.layers.8.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 321 |
+
"model.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 322 |
+
"model.layers.8.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 323 |
+
"model.layers.8.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 324 |
+
"model.layers.8.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 325 |
+
"model.layers.8.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
| 326 |
+
"model.layers.8.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
| 327 |
+
"model.layers.8.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
| 328 |
+
"model.layers.8.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
| 329 |
+
"model.layers.8.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
| 330 |
+
"model.layers.8.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
| 331 |
+
"model.layers.8.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
| 332 |
+
"model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 333 |
+
"model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
| 334 |
+
"model.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
| 335 |
+
"model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
| 336 |
+
"model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
| 337 |
+
"model.layers.9.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
| 338 |
+
"model.layers.9.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
| 339 |
+
"model.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
| 340 |
+
"model.layers.9.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
| 341 |
+
"model.layers.9.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
| 342 |
+
"model.layers.9.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
| 343 |
+
"model.layers.9.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
| 344 |
+
"model.norm.weight": "model-00003-of-00004.safetensors"
|
| 345 |
+
}
|
| 346 |
+
}
|
llm/special_tokens_map.json
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"additional_special_tokens": [
|
| 3 |
+
{
|
| 4 |
+
"content": "<|sound_bos|>",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"content": "<|sound_eos|>",
|
| 12 |
+
"lstrip": false,
|
| 13 |
+
"normalized": false,
|
| 14 |
+
"rstrip": false,
|
| 15 |
+
"single_word": false
|
| 16 |
+
}
|
| 17 |
+
],
|
| 18 |
+
"bos_token": {
|
| 19 |
+
"content": "[BOS]",
|
| 20 |
+
"lstrip": false,
|
| 21 |
+
"normalized": false,
|
| 22 |
+
"rstrip": false,
|
| 23 |
+
"single_word": false
|
| 24 |
+
},
|
| 25 |
+
"eos_token": {
|
| 26 |
+
"content": "<|im_end|>",
|
| 27 |
+
"lstrip": false,
|
| 28 |
+
"normalized": false,
|
| 29 |
+
"rstrip": false,
|
| 30 |
+
"single_word": false
|
| 31 |
+
},
|
| 32 |
+
"pad_token": {
|
| 33 |
+
"content": "[PAD]",
|
| 34 |
+
"lstrip": false,
|
| 35 |
+
"normalized": false,
|
| 36 |
+
"rstrip": false,
|
| 37 |
+
"single_word": false
|
| 38 |
+
}
|
| 39 |
+
}
|
llm/tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:491635783283196cfd9ab5d019617234a246b35a58da4761afd6ad77380f43c8
|
| 3 |
+
size 11415920
|
llm/tokenizer_config.json
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": false,
|
| 3 |
+
"added_tokens_decoder": {
|
| 4 |
+
"151643": {
|
| 5 |
+
"content": "<|endoftext|>",
|
| 6 |
+
"lstrip": false,
|
| 7 |
+
"normalized": false,
|
| 8 |
+
"rstrip": false,
|
| 9 |
+
"single_word": false,
|
| 10 |
+
"special": true
|
| 11 |
+
},
|
| 12 |
+
"151644": {
|
| 13 |
+
"content": "<|im_start|>",
|
| 14 |
+
"lstrip": false,
|
| 15 |
+
"normalized": false,
|
| 16 |
+
"rstrip": false,
|
| 17 |
+
"single_word": false,
|
| 18 |
+
"special": true
|
| 19 |
+
},
|
| 20 |
+
"151645": {
|
| 21 |
+
"content": "<|im_end|>",
|
| 22 |
+
"lstrip": false,
|
| 23 |
+
"normalized": false,
|
| 24 |
+
"rstrip": false,
|
| 25 |
+
"single_word": false,
|
| 26 |
+
"special": true
|
| 27 |
+
},
|
| 28 |
+
"151646": {
|
| 29 |
+
"content": "[BOS]",
|
| 30 |
+
"lstrip": false,
|
| 31 |
+
"normalized": false,
|
| 32 |
+
"rstrip": false,
|
| 33 |
+
"single_word": false,
|
| 34 |
+
"special": true
|
| 35 |
+
},
|
| 36 |
+
"151647": {
|
| 37 |
+
"content": "[PAD]",
|
| 38 |
+
"lstrip": false,
|
| 39 |
+
"normalized": false,
|
| 40 |
+
"rstrip": false,
|
| 41 |
+
"single_word": false,
|
| 42 |
+
"special": true
|
| 43 |
+
},
|
| 44 |
+
"151648": {
|
| 45 |
+
"content": "<vila/sentinel>",
|
| 46 |
+
"lstrip": false,
|
| 47 |
+
"normalized": false,
|
| 48 |
+
"rstrip": false,
|
| 49 |
+
"single_word": false,
|
| 50 |
+
"special": true
|
| 51 |
+
},
|
| 52 |
+
"151649": {
|
| 53 |
+
"content": "<image>",
|
| 54 |
+
"lstrip": false,
|
| 55 |
+
"normalized": false,
|
| 56 |
+
"rstrip": false,
|
| 57 |
+
"single_word": false,
|
| 58 |
+
"special": true
|
| 59 |
+
},
|
| 60 |
+
"151650": {
|
| 61 |
+
"content": "<vila/video>",
|
| 62 |
+
"lstrip": false,
|
| 63 |
+
"normalized": false,
|
| 64 |
+
"rstrip": false,
|
| 65 |
+
"single_word": false,
|
| 66 |
+
"special": true
|
| 67 |
+
},
|
| 68 |
+
"151651": {
|
| 69 |
+
"content": "<speech>",
|
| 70 |
+
"lstrip": false,
|
| 71 |
+
"normalized": false,
|
| 72 |
+
"rstrip": false,
|
| 73 |
+
"single_word": false,
|
| 74 |
+
"special": true
|
| 75 |
+
},
|
| 76 |
+
"151652": {
|
| 77 |
+
"content": "<sound>",
|
| 78 |
+
"lstrip": false,
|
| 79 |
+
"normalized": false,
|
| 80 |
+
"rstrip": false,
|
| 81 |
+
"single_word": false,
|
| 82 |
+
"special": true
|
| 83 |
+
},
|
| 84 |
+
"151653": {
|
| 85 |
+
"content": "<|image_bos|>",
|
| 86 |
+
"lstrip": false,
|
| 87 |
+
"normalized": false,
|
| 88 |
+
"rstrip": false,
|
| 89 |
+
"single_word": false,
|
| 90 |
+
"special": true
|
| 91 |
+
},
|
| 92 |
+
"151654": {
|
| 93 |
+
"content": "<|image_eos|>",
|
| 94 |
+
"lstrip": false,
|
| 95 |
+
"normalized": false,
|
| 96 |
+
"rstrip": false,
|
| 97 |
+
"single_word": false,
|
| 98 |
+
"special": true
|
| 99 |
+
},
|
| 100 |
+
"151655": {
|
| 101 |
+
"content": "<|video_bos|>",
|
| 102 |
+
"lstrip": false,
|
| 103 |
+
"normalized": false,
|
| 104 |
+
"rstrip": false,
|
| 105 |
+
"single_word": false,
|
| 106 |
+
"special": true
|
| 107 |
+
},
|
| 108 |
+
"151656": {
|
| 109 |
+
"content": "<|video_eos|>",
|
| 110 |
+
"lstrip": false,
|
| 111 |
+
"normalized": false,
|
| 112 |
+
"rstrip": false,
|
| 113 |
+
"single_word": false,
|
| 114 |
+
"special": true
|
| 115 |
+
},
|
| 116 |
+
"151657": {
|
| 117 |
+
"content": "<|speech_bos|>",
|
| 118 |
+
"lstrip": false,
|
| 119 |
+
"normalized": false,
|
| 120 |
+
"rstrip": false,
|
| 121 |
+
"single_word": false,
|
| 122 |
+
"special": true
|
| 123 |
+
},
|
| 124 |
+
"151658": {
|
| 125 |
+
"content": "<|speech_eos|>",
|
| 126 |
+
"lstrip": false,
|
| 127 |
+
"normalized": false,
|
| 128 |
+
"rstrip": false,
|
| 129 |
+
"single_word": false,
|
| 130 |
+
"special": true
|
| 131 |
+
},
|
| 132 |
+
"151659": {
|
| 133 |
+
"content": "<|sound_bos|>",
|
| 134 |
+
"lstrip": false,
|
| 135 |
+
"normalized": false,
|
| 136 |
+
"rstrip": false,
|
| 137 |
+
"single_word": false,
|
| 138 |
+
"special": true
|
| 139 |
+
},
|
| 140 |
+
"151660": {
|
| 141 |
+
"content": "<|sound_eos|>",
|
| 142 |
+
"lstrip": false,
|
| 143 |
+
"normalized": false,
|
| 144 |
+
"rstrip": false,
|
| 145 |
+
"single_word": false,
|
| 146 |
+
"special": true
|
| 147 |
+
}
|
| 148 |
+
},
|
| 149 |
+
"additional_special_tokens": [
|
| 150 |
+
"<|sound_bos|>",
|
| 151 |
+
"<|sound_eos|>"
|
| 152 |
+
],
|
| 153 |
+
"bos_token": "[BOS]",
|
| 154 |
+
"chat_template": "{% if messages[0]['role'] != 'system' %}{{ '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}{% endif %}{% for message in messages if message['content'] is not none %}{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
|
| 155 |
+
"clean_up_tokenization_spaces": false,
|
| 156 |
+
"eos_token": "<|im_end|>",
|
| 157 |
+
"errors": "replace",
|
| 158 |
+
"legacy": false,
|
| 159 |
+
"model_max_length": 14000,
|
| 160 |
+
"pad_token": "[PAD]",
|
| 161 |
+
"padding_side": "right",
|
| 162 |
+
"split_special_tokens": false,
|
| 163 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
| 164 |
+
"unk_token": null
|
| 165 |
+
}
|
llm/vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
media.py
ADDED
|
@@ -0,0 +1,555 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import glob
|
| 17 |
+
import time
|
| 18 |
+
import random
|
| 19 |
+
import os
|
| 20 |
+
import tempfile
|
| 21 |
+
from collections import defaultdict
|
| 22 |
+
from io import BytesIO
|
| 23 |
+
from typing import Any, Dict, List, Optional, Union
|
| 24 |
+
import io
|
| 25 |
+
import cv2
|
| 26 |
+
import kaldiio
|
| 27 |
+
import librosa
|
| 28 |
+
import soundfile as sf
|
| 29 |
+
import torch
|
| 30 |
+
import numpy as np
|
| 31 |
+
import PIL
|
| 32 |
+
import PIL.Image
|
| 33 |
+
import requests
|
| 34 |
+
import tarfile
|
| 35 |
+
import whisper
|
| 36 |
+
import decord
|
| 37 |
+
from decord import AudioReader, cpu
|
| 38 |
+
|
| 39 |
+
from transformers import PretrainedConfig
|
| 40 |
+
|
| 41 |
+
MEDIA_TOKENS = {
|
| 42 |
+
"image": "<image>",
|
| 43 |
+
"video": "<vila/video>",
|
| 44 |
+
"speech": "<speech>",
|
| 45 |
+
"sound": "<sound>",
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class Media:
|
| 50 |
+
"""Base class for media objects."""
|
| 51 |
+
pass
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class File(Media):
|
| 55 |
+
"""File-based media object."""
|
| 56 |
+
def __init__(self, path: str) -> None:
|
| 57 |
+
self.path = path
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class Image(File):
|
| 61 |
+
"""Image media object."""
|
| 62 |
+
pass
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class Video(File):
|
| 66 |
+
"""Video media object."""
|
| 67 |
+
pass
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class Speech(File):
|
| 71 |
+
"""Speech audio media object."""
|
| 72 |
+
def __init__(self, path, extension: str = None) -> None:
|
| 73 |
+
self.path = path
|
| 74 |
+
self.extension = extension
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class Sound(File):
|
| 78 |
+
"""Sound/music audio media object."""
|
| 79 |
+
def __init__(self, path, extension: str = None) -> None:
|
| 80 |
+
self.path = path
|
| 81 |
+
self.extension = extension
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def make_list(obj: Any) -> List:
|
| 85 |
+
"""Convert object to list if not already a list."""
|
| 86 |
+
return obj if isinstance(obj, list) else [obj]
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _extract_image(image: Union[Image, PIL.Image.Image]) -> PIL.Image.Image:
|
| 90 |
+
"""Extract PIL Image from Image object or return PIL Image as-is."""
|
| 91 |
+
if isinstance(image, Image):
|
| 92 |
+
if image.path.startswith("http://") or image.path.startswith("https://"):
|
| 93 |
+
image = PIL.Image.open(requests.get(image.path, stream=True).raw)
|
| 94 |
+
else:
|
| 95 |
+
image = PIL.Image.open(image.path)
|
| 96 |
+
return image
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _load_video_bytesio(
|
| 100 |
+
video_bytesio: BytesIO, *, num_frames: int, config: PretrainedConfig, load_aud: bool = False
|
| 101 |
+
) -> List[PIL.Image.Image]:
|
| 102 |
+
"""Load video from BytesIO object by writing to temporary file."""
|
| 103 |
+
with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
|
| 104 |
+
temp_video.write(video_bytesio.read())
|
| 105 |
+
temp_video_name = temp_video.name
|
| 106 |
+
return _load_video(temp_video_name, num_frames=num_frames, load_aud=load_aud, config=config)
|
| 107 |
+
|
| 108 |
+
def get_overlap(inp1, inp2):
|
| 109 |
+
"""
|
| 110 |
+
Calculates the overlapping time frame between a video clip and an audio segment.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
inp1 (list): [start_sec, end_sec]
|
| 114 |
+
inp2 (list): [start_sec, end_sec]
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
tuple or None: (overlap_start, overlap_end) if overlap exists, else None.
|
| 118 |
+
"""
|
| 119 |
+
# Calculate the maximum start time and minimum end time
|
| 120 |
+
overlap_start = max(inp1[0], inp2[0])
|
| 121 |
+
overlap_end = min(inp1[1], inp2[1])
|
| 122 |
+
|
| 123 |
+
# Check if there is an actual overlap
|
| 124 |
+
if overlap_start < overlap_end:
|
| 125 |
+
return (overlap_start, overlap_end)
|
| 126 |
+
else:
|
| 127 |
+
return None
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def _load_video(
|
| 131 |
+
video_path: str, *, num_frames: int, config: PretrainedConfig, load_aud: bool = False
|
| 132 |
+
) -> List[PIL.Image.Image]:
|
| 133 |
+
# Load video frames from a directory
|
| 134 |
+
if os.path.isdir(video_path):
|
| 135 |
+
frame_paths = sorted(glob.glob(os.path.join(video_path, "*")))
|
| 136 |
+
indices = np.round(np.linspace(0, len(frame_paths) - 1, num_frames)).astype(int)
|
| 137 |
+
return [PIL.Image.open(frame_paths[index]) for index in indices]
|
| 138 |
+
|
| 139 |
+
# Load video frames from a video file
|
| 140 |
+
vidcap = cv2.VideoCapture(video_path)
|
| 141 |
+
|
| 142 |
+
# Load audio if available and needed
|
| 143 |
+
audio_info = None
|
| 144 |
+
if load_aud:
|
| 145 |
+
try:
|
| 146 |
+
aud_feature, audio_info = _load_speech(video_path, config)
|
| 147 |
+
except Exception as e:
|
| 148 |
+
aud_feature = None
|
| 149 |
+
else:
|
| 150 |
+
aud_feature = None
|
| 151 |
+
|
| 152 |
+
# Find the last frame as frame count might not be accurate
|
| 153 |
+
frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 154 |
+
while frame_count > 0:
|
| 155 |
+
vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
|
| 156 |
+
if vidcap.grab():
|
| 157 |
+
break
|
| 158 |
+
frame_count -= 1
|
| 159 |
+
else:
|
| 160 |
+
raise ValueError(f"Video '{video_path}' has no frames.")
|
| 161 |
+
|
| 162 |
+
# Extract frames uniformly
|
| 163 |
+
indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int)
|
| 164 |
+
|
| 165 |
+
fps = vidcap.get(cv2.CAP_PROP_FPS)
|
| 166 |
+
video_duration = frame_count / fps
|
| 167 |
+
|
| 168 |
+
# When load_audio_in_video and interleaved_vis_aud_in_video is True, we need to load frames for each video segment
|
| 169 |
+
if config.load_audio_in_video and config.interleaved_vis_aud_in_video and aud_feature is not None:
|
| 170 |
+
segment_duration = config.interleaved_video_segment_duration
|
| 171 |
+
if segment_duration == -1:
|
| 172 |
+
raise ValueError("video_segment_duration is not set")
|
| 173 |
+
|
| 174 |
+
segment_vis_indices_list = []
|
| 175 |
+
segment_aud_indices_list = []
|
| 176 |
+
segment_counts = np.ceil(video_duration / segment_duration).astype(int)
|
| 177 |
+
|
| 178 |
+
if type(aud_feature) == dict:
|
| 179 |
+
aud_feas = aud_feature["input_features"]
|
| 180 |
+
else:
|
| 181 |
+
aud_feas = aud_feature
|
| 182 |
+
audio_start_sec = audio_info['audio_start_sec']
|
| 183 |
+
audio_end_sec = audio_info['audio_end_sample_sec']
|
| 184 |
+
|
| 185 |
+
stft_frames_per_second = config.audio_sampling_rate // config.audio_hop_length
|
| 186 |
+
|
| 187 |
+
_idx = 0
|
| 188 |
+
aud_sample_start_idx = 0
|
| 189 |
+
for i in range(segment_counts):
|
| 190 |
+
end_frame = min((i+1) * segment_duration * fps, frame_count)
|
| 191 |
+
|
| 192 |
+
_indices = []
|
| 193 |
+
while _idx < len(indices) and indices[_idx] < end_frame and _idx < len(indices):
|
| 194 |
+
_indices.append(indices[_idx])
|
| 195 |
+
_idx += 1
|
| 196 |
+
segment_vis_indices_list.append(_indices)
|
| 197 |
+
clip_start_sec = i * segment_duration
|
| 198 |
+
clip_end_sec = min(clip_start_sec + segment_duration, video_duration)
|
| 199 |
+
|
| 200 |
+
# get the audio indices for the current clip
|
| 201 |
+
overlap = get_overlap([clip_start_sec, clip_end_sec], [audio_start_sec, audio_end_sec])
|
| 202 |
+
if overlap is not None:
|
| 203 |
+
aud_sample_end_idx = round((overlap[1] - audio_start_sec) * stft_frames_per_second)
|
| 204 |
+
segment_aud_indices_list.append([aud_sample_start_idx, aud_sample_end_idx])
|
| 205 |
+
aud_sample_start_idx = aud_sample_end_idx
|
| 206 |
+
else:
|
| 207 |
+
segment_aud_indices_list.append([])
|
| 208 |
+
frames = {}
|
| 209 |
+
frame_times = {}
|
| 210 |
+
for index in indices:
|
| 211 |
+
if index in frames:
|
| 212 |
+
continue
|
| 213 |
+
vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
|
| 214 |
+
success, frame = vidcap.read()
|
| 215 |
+
if not success:
|
| 216 |
+
print(f"Failed to read frame {index} from video '{video_path}'. Skipped.")
|
| 217 |
+
continue
|
| 218 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 219 |
+
frames[index] = PIL.Image.fromarray(frame)
|
| 220 |
+
frame_times[index] = index / fps
|
| 221 |
+
|
| 222 |
+
output_frames = [frames[index] for index in indices if index in frames]
|
| 223 |
+
output_frame_times = [frame_times[index] for index in indices if index in frame_times]
|
| 224 |
+
|
| 225 |
+
video_info = {}
|
| 226 |
+
if config.load_audio_in_video and config.interleaved_vis_aud_in_video and aud_feature is not None:
|
| 227 |
+
new_segment_vis_indices_list = []
|
| 228 |
+
processed_frame_index = 0
|
| 229 |
+
for i, segment_indices in enumerate(segment_vis_indices_list):
|
| 230 |
+
new_segment_vis_indices_list.append([])
|
| 231 |
+
for index in segment_indices:
|
| 232 |
+
if index in frames:
|
| 233 |
+
new_segment_vis_indices_list[-1].append(processed_frame_index)
|
| 234 |
+
processed_frame_index += 1
|
| 235 |
+
segment_vis_indices_list = new_segment_vis_indices_list
|
| 236 |
+
|
| 237 |
+
video_info["segment_vis_indices_list"] = segment_vis_indices_list
|
| 238 |
+
video_info["segment_aud_indices_list"] = segment_aud_indices_list
|
| 239 |
+
video_info['expected_frame_count'] = len(indices)
|
| 240 |
+
video_info['video_path'] = video_path
|
| 241 |
+
if audio_info is not None:
|
| 242 |
+
audio_info['video_path'] = video_path
|
| 243 |
+
video_info['has_audio'] = aud_feature is not None
|
| 244 |
+
video_info['video_duration'] = video_duration
|
| 245 |
+
video_info['audio_info'] = audio_info
|
| 246 |
+
|
| 247 |
+
# calculate the time of each frame
|
| 248 |
+
video_info['video_frame_times'] = output_frame_times
|
| 249 |
+
|
| 250 |
+
return output_frames, aud_feature, video_info
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def _extract_video(video: Video, config: PretrainedConfig) -> List[PIL.Image.Image]:
|
| 254 |
+
num_frames = config.num_video_frames
|
| 255 |
+
aud_fea = None
|
| 256 |
+
|
| 257 |
+
if getattr(config, "fps") != 0:
|
| 258 |
+
print("Extracting frames from video with specified FPS is not supported yet. Ignored.")
|
| 259 |
+
|
| 260 |
+
if isinstance(video.path, BytesIO):
|
| 261 |
+
frames, aud_fea, video_info = _load_video_bytesio(
|
| 262 |
+
video.path, num_frames=num_frames, config=config, load_aud=config.load_audio_in_video
|
| 263 |
+
)
|
| 264 |
+
else:
|
| 265 |
+
frames, aud_fea, video_info = _load_video(
|
| 266 |
+
video.path, num_frames=num_frames, config=config, load_aud=config.load_audio_in_video
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
if config.load_audio_in_video:
|
| 270 |
+
return frames, aud_fea, video_info
|
| 271 |
+
else:
|
| 272 |
+
return frames, video_info
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def soundFile_read_audio(audio_file, offset=None, duration=None, dtype='float32'):
|
| 276 |
+
if dtype not in ['int32', 'float32']:
|
| 277 |
+
print("audio dtype must be int32 or float32. Default to float32")
|
| 278 |
+
dtype = 'float32'
|
| 279 |
+
# return read audio and its sample rate
|
| 280 |
+
if isinstance(audio_file, bytes):
|
| 281 |
+
audio_file = io.BytesIO(audio_file)
|
| 282 |
+
with sf.SoundFile(audio_file, 'r') as f:
|
| 283 |
+
sample_rate = f.samplerate
|
| 284 |
+
if offset is not None and offset > 0:
|
| 285 |
+
f.seek(int(offset * sample_rate))
|
| 286 |
+
if duration is not None and duration > 0:
|
| 287 |
+
samples = f.read(int(duration * sample_rate), dtype=dtype)
|
| 288 |
+
else:
|
| 289 |
+
samples = f.read(dtype=dtype)
|
| 290 |
+
return samples, sample_rate
|
| 291 |
+
|
| 292 |
+
def load_audio_from_tar(tar_file, audio_file):
|
| 293 |
+
with tarfile.open(tar_file, 'r') as tar:
|
| 294 |
+
audio_member = tar.getmember(audio_file)
|
| 295 |
+
audio_file = tar.extractfile(audio_member)
|
| 296 |
+
return librosa.load(audio_file)
|
| 297 |
+
|
| 298 |
+
def _load_audio_file(audio_path: str, config: PretrainedConfig):
|
| 299 |
+
# Load video frames from a directory
|
| 300 |
+
if audio_path is None:
|
| 301 |
+
return None
|
| 302 |
+
|
| 303 |
+
dirname = os.path.dirname(audio_path)
|
| 304 |
+
filename = os.path.basename(audio_path)
|
| 305 |
+
|
| 306 |
+
if dirname.endswith(".tar"):
|
| 307 |
+
speech, sample_rate = load_audio_from_tar(dirname, filename)
|
| 308 |
+
else:
|
| 309 |
+
sample_rate = config.audio_sampling_rate
|
| 310 |
+
speech = whisper.load_audio(audio_path, sr=sample_rate)
|
| 311 |
+
|
| 312 |
+
return speech, sample_rate
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def _load_audio(audio: Union[str, dict], config: PretrainedConfig):
|
| 316 |
+
if isinstance(audio, str):
|
| 317 |
+
return _load_audio_file(audio, config)
|
| 318 |
+
elif isinstance(audio, dict):
|
| 319 |
+
audio_sample = audio['sample']
|
| 320 |
+
if isinstance(audio_sample, (bytes, io.BytesIO)):
|
| 321 |
+
offset = audio.get('offset', None)
|
| 322 |
+
duration = audio.get('duration', None)
|
| 323 |
+
dtype = audio.get('dtype', 'float32')
|
| 324 |
+
return soundFile_read_audio(
|
| 325 |
+
audio_sample, offset=offset, duration=duration, dtype=dtype
|
| 326 |
+
)
|
| 327 |
+
elif isinstance(audio_sample, np.ndarray):
|
| 328 |
+
return audio_sample, audio.get('sample_rate')
|
| 329 |
+
else:
|
| 330 |
+
raise ValueError(f"Expect the loaded audio to be a processed numpy array or raw bytes. Got {type(audio_sample)}")
|
| 331 |
+
else:
|
| 332 |
+
raise ValueError(f"Expect input to be a path string or dict. Got {type(audio)}")
|
| 333 |
+
|
| 334 |
+
def _whisper_process(audio, sample_rate, audio_chunk_length, max_chunks_per_file):
|
| 335 |
+
outputs = []
|
| 336 |
+
num_audio_chunks = 0
|
| 337 |
+
|
| 338 |
+
chunk_length = audio_chunk_length * sample_rate
|
| 339 |
+
for i in range(0, len(audio), chunk_length):
|
| 340 |
+
chunk = audio[i : i + chunk_length]
|
| 341 |
+
chunk = whisper.pad_or_trim(chunk)
|
| 342 |
+
if chunk.dtype != np.float32:
|
| 343 |
+
chunk = chunk.astype(np.float32)
|
| 344 |
+
mel = whisper.log_mel_spectrogram(chunk, n_mels=128)
|
| 345 |
+
num_audio_chunks+=1
|
| 346 |
+
outputs.append(mel)
|
| 347 |
+
if num_audio_chunks == max_chunks_per_file:
|
| 348 |
+
break
|
| 349 |
+
|
| 350 |
+
frames = torch.stack(outputs, dim=0)
|
| 351 |
+
return frames.numpy().tolist()
|
| 352 |
+
|
| 353 |
+
def _load_speech(speech, config: PretrainedConfig):
|
| 354 |
+
if type(speech) == str:
|
| 355 |
+
speech_path = speech
|
| 356 |
+
else:
|
| 357 |
+
speech_path = speech.path
|
| 358 |
+
|
| 359 |
+
# Load video frames from a directory
|
| 360 |
+
if speech_path is None:
|
| 361 |
+
return None
|
| 362 |
+
speech_outputs = []
|
| 363 |
+
|
| 364 |
+
if config.audio_chunk_length and not (type(config.audio_chunk_length) == str and "max" in config.audio_chunk_length):
|
| 365 |
+
try:
|
| 366 |
+
config.audio_chunk_length = int(config.audio_chunk_length)
|
| 367 |
+
except Exception as e:
|
| 368 |
+
print(f"Error setting audio_chunk_length: {e}")
|
| 369 |
+
raise e
|
| 370 |
+
|
| 371 |
+
audio_n_samples_limit = config.audio_chunk_length * config.audio_sampling_rate
|
| 372 |
+
|
| 373 |
+
def load_wav(speech_path):
|
| 374 |
+
speech, sr = librosa.load(speech_path, sr=config.audio_sampling_rate)
|
| 375 |
+
cur_max_length = speech.shape[0]
|
| 376 |
+
ori_audio_duration = cur_max_length / sr
|
| 377 |
+
return speech, ori_audio_duration
|
| 378 |
+
|
| 379 |
+
def get_audio(speech, audio_n_samples):
|
| 380 |
+
|
| 381 |
+
if type(speech) == decord.audio_reader.AudioReader:
|
| 382 |
+
ori_n_samples = speech.shape[1]
|
| 383 |
+
else:
|
| 384 |
+
ori_n_samples = speech.shape[0]
|
| 385 |
+
|
| 386 |
+
# random audio smaple
|
| 387 |
+
audio_start_sample_id = 0
|
| 388 |
+
audio_end_sample_id = ori_n_samples
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
load_max_audio = type(config.audio_chunk_length) == str and "max" in config.audio_chunk_length
|
| 392 |
+
if hasattr(config, 'random_audio_sample') and not load_max_audio:
|
| 393 |
+
if ori_n_samples > audio_n_samples:
|
| 394 |
+
audio_start_sample_id = random.randint(0, ori_n_samples - audio_n_samples)
|
| 395 |
+
audio_end_sample_id = audio_start_sample_id + audio_n_samples
|
| 396 |
+
else:
|
| 397 |
+
if load_max_audio:
|
| 398 |
+
if "_" in config.audio_chunk_length:
|
| 399 |
+
max_audio_chunk_length = int(config.audio_chunk_length.split("_")[1])
|
| 400 |
+
max_audio_n_samples = max_audio_chunk_length * config.audio_sampling_rate
|
| 401 |
+
audio_n_samples = min(ori_n_samples, max_audio_n_samples)
|
| 402 |
+
audio_end_sample_id = audio_n_samples
|
| 403 |
+
else:
|
| 404 |
+
audio_n_samples = ori_n_samples
|
| 405 |
+
audio_end_sample_id = audio_n_samples
|
| 406 |
+
else:
|
| 407 |
+
audio_end_sample_id = min(audio_n_samples, ori_n_samples)
|
| 408 |
+
|
| 409 |
+
if type(speech) == decord.audio_reader.AudioReader:
|
| 410 |
+
speech = speech[audio_start_sample_id:audio_end_sample_id].asnumpy()[0]
|
| 411 |
+
else:
|
| 412 |
+
speech = speech[audio_start_sample_id:audio_end_sample_id]
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
return speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id
|
| 416 |
+
|
| 417 |
+
if isinstance(speech_path, dict):
|
| 418 |
+
if "offset" in speech_path:
|
| 419 |
+
speech, ori_sample_rate = _load_audio(speech_path, config)
|
| 420 |
+
|
| 421 |
+
else:
|
| 422 |
+
speech = speech_path["sample"]
|
| 423 |
+
ori_sample_rate = speech_path["sample_rate"]
|
| 424 |
+
|
| 425 |
+
# resample the speech based on current sample rate
|
| 426 |
+
speech = librosa.resample(speech, orig_sr=ori_sample_rate, target_sr=config.audio_sampling_rate)
|
| 427 |
+
# variable audio sequence lengths
|
| 428 |
+
ori_audio_duration = speech.shape[0] / config.audio_sampling_rate
|
| 429 |
+
speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id = get_audio(speech, audio_n_samples_limit)
|
| 430 |
+
|
| 431 |
+
elif isinstance(speech_path, BytesIO):
|
| 432 |
+
if speech.extension == ".wav":
|
| 433 |
+
# speech, sr = librosa.load(speech_path, sr=config.audio_sampling_rate)
|
| 434 |
+
# ori_audio_duration = speech.shape[0] / sr
|
| 435 |
+
speech, ori_audio_duration = load_wav(speech_path)
|
| 436 |
+
speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id = get_audio(speech, audio_n_samples_limit)
|
| 437 |
+
else:
|
| 438 |
+
raise ValueError(f"Unsupported audio extension: {speech.extension}")
|
| 439 |
+
|
| 440 |
+
elif ".mat" in speech_path or ".ark" in speech_path:
|
| 441 |
+
rate, speech = kaldiio.load_mat(speech_path)
|
| 442 |
+
speech = librosa.resample(speech, orig_sr=rate, target_sr=config.audio_sampling_rate)
|
| 443 |
+
speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id = get_audio(speech, audio_n_samples_limit)
|
| 444 |
+
ori_audio_duration = speech.shape[0] / config.audio_sampling_rate
|
| 445 |
+
elif ".mp4" in speech_path:
|
| 446 |
+
# Load audio from video file
|
| 447 |
+
ar = AudioReader(speech_path, ctx=cpu(0), sample_rate=config.audio_sampling_rate, mono=True)
|
| 448 |
+
cur_max_length = ar.shape[1]
|
| 449 |
+
ori_audio_duration = cur_max_length / config.audio_sampling_rate
|
| 450 |
+
speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id = get_audio(ar, audio_n_samples_limit)
|
| 451 |
+
else:
|
| 452 |
+
assert os.path.exists(speech_path), f"File {speech_path} does not exist"
|
| 453 |
+
speech, ori_audio_duration = load_wav(speech_path)
|
| 454 |
+
speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id = get_audio(speech, audio_n_samples_limit)
|
| 455 |
+
|
| 456 |
+
# convert to float
|
| 457 |
+
speech = speech.astype(np.float32)
|
| 458 |
+
audio_n_samples = int(np.ceil(speech.shape[0] / (config.audio_sampling_rate * 30)) * (config.audio_sampling_rate * 30))
|
| 459 |
+
|
| 460 |
+
speech = whisper.pad_or_trim(speech, length=audio_n_samples) # we don't pad or trim here, instead, we pad based on the max length of all audio samples in the batch size later
|
| 461 |
+
|
| 462 |
+
new_audio_chunk_length = int(audio_n_samples // config.audio_sampling_rate)
|
| 463 |
+
audio_start_sec = audio_start_sample_id / config.audio_sampling_rate
|
| 464 |
+
audio_end_sample_sec = audio_end_sample_id / config.audio_sampling_rate
|
| 465 |
+
|
| 466 |
+
audio_info = {}
|
| 467 |
+
audio_info['new_audio_chunk_length'] = new_audio_chunk_length
|
| 468 |
+
audio_info['new_audio_n_samples'] = audio_n_samples
|
| 469 |
+
audio_info['ori_audio_duration'] = ori_audio_duration
|
| 470 |
+
audio_info['audio_start_sec'] = audio_start_sec
|
| 471 |
+
audio_info['audio_end_sample_sec'] = audio_end_sample_sec
|
| 472 |
+
|
| 473 |
+
return speech, audio_info
|
| 474 |
+
|
| 475 |
+
def _extract_speech(speech: Speech, config: PretrainedConfig):
|
| 476 |
+
frames, audio_info = _load_speech(speech, config)
|
| 477 |
+
return frames, audio_info
|
| 478 |
+
|
| 479 |
+
_extract_sound = _extract_speech
|
| 480 |
+
def extract_media(
|
| 481 |
+
messages: List[Dict[str, Any]],
|
| 482 |
+
config: Optional[PretrainedConfig] = None,
|
| 483 |
+
draft: bool = False,
|
| 484 |
+
) -> Dict[str, List[Any]]:
|
| 485 |
+
media = defaultdict(list)
|
| 486 |
+
|
| 487 |
+
if not hasattr(config, "load_audio_in_video"):
|
| 488 |
+
print(f"Warning: load_audio_in_video not in config, set to False")
|
| 489 |
+
config.load_audio_in_video = False
|
| 490 |
+
|
| 491 |
+
for message in messages:
|
| 492 |
+
text = ""
|
| 493 |
+
for part in make_list(message["value"]):
|
| 494 |
+
if isinstance(part, str):
|
| 495 |
+
for token in MEDIA_TOKENS.values():
|
| 496 |
+
if token in part:
|
| 497 |
+
print(f"Media token '{token}' found in text: '{part}'. Removed.")
|
| 498 |
+
part = part.replace(token, "").strip()
|
| 499 |
+
text += part
|
| 500 |
+
elif isinstance(part, (Image, PIL.Image.Image)):
|
| 501 |
+
if draft:
|
| 502 |
+
media["image"].append(part)
|
| 503 |
+
else:
|
| 504 |
+
media["image"].append(_extract_image(part))
|
| 505 |
+
text += MEDIA_TOKENS["image"]
|
| 506 |
+
elif isinstance(part, Video):
|
| 507 |
+
if draft:
|
| 508 |
+
media["video"].append(part)
|
| 509 |
+
else:
|
| 510 |
+
if config.load_audio_in_video:
|
| 511 |
+
output, aud_fea, video_info = _extract_video(part, config)
|
| 512 |
+
media["video"].append(output)
|
| 513 |
+
media["video_info"].append(video_info)
|
| 514 |
+
if aud_fea is not None:
|
| 515 |
+
media["sound"].append(aud_fea)
|
| 516 |
+
media["audio_info"].append(video_info['audio_info'])
|
| 517 |
+
text += MEDIA_TOKENS["sound"]
|
| 518 |
+
else:
|
| 519 |
+
output, video_info = _extract_video(part, config)
|
| 520 |
+
media["video"].append(output)
|
| 521 |
+
media["video_info"].append(video_info)
|
| 522 |
+
text += MEDIA_TOKENS["video"]
|
| 523 |
+
elif isinstance(part, Speech):
|
| 524 |
+
if draft:
|
| 525 |
+
if config.unified_audio_encoder:
|
| 526 |
+
media["sound"].append(part)
|
| 527 |
+
text += MEDIA_TOKENS["sound"]
|
| 528 |
+
else:
|
| 529 |
+
media["speech"].append(part)
|
| 530 |
+
text += MEDIA_TOKENS["speech"]
|
| 531 |
+
else:
|
| 532 |
+
output, audio_info = _extract_speech(part, config)
|
| 533 |
+
if output is not None:
|
| 534 |
+
if config.unified_audio_encoder:
|
| 535 |
+
media["sound"].append(output)
|
| 536 |
+
text += MEDIA_TOKENS["sound"]
|
| 537 |
+
else:
|
| 538 |
+
media["speech"].append(output)
|
| 539 |
+
text += MEDIA_TOKENS["speech"]
|
| 540 |
+
media["audio_info"].append(audio_info)
|
| 541 |
+
elif isinstance(part, Sound):
|
| 542 |
+
if draft:
|
| 543 |
+
media["sound"].append(part)
|
| 544 |
+
text += MEDIA_TOKENS["sound"]
|
| 545 |
+
else:
|
| 546 |
+
output, audio_info = _extract_sound(part, config)
|
| 547 |
+
if output is not None:
|
| 548 |
+
media["sound"].append(output)
|
| 549 |
+
media["audio_info"].append(audio_info)
|
| 550 |
+
text += MEDIA_TOKENS["sound"]
|
| 551 |
+
else:
|
| 552 |
+
print(f"part: {part}")
|
| 553 |
+
raise ValueError(f"Unsupported prompt part type: {type(part)}")
|
| 554 |
+
message["value"] = text
|
| 555 |
+
return media
|
media_encoder.py
ADDED
|
@@ -0,0 +1,955 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from functools import partial
|
| 17 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from torch import nn
|
| 21 |
+
from torch.nn import Module, ModuleList
|
| 22 |
+
import numpy as np
|
| 23 |
+
from einops import rearrange, repeat
|
| 24 |
+
from torch.cuda.amp import autocast
|
| 25 |
+
from torch import nn, einsum, broadcast_tensors, Tensor
|
| 26 |
+
from beartype import beartype
|
| 27 |
+
from beartype.typing import Literal, Union, Optional
|
| 28 |
+
from math import pi, log
|
| 29 |
+
import math
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class CacheFeatures(object):
|
| 33 |
+
def __init__(self, value, type):
|
| 34 |
+
self.value = value
|
| 35 |
+
self.type = type
|
| 36 |
+
def my_to(self, device, dtype):
|
| 37 |
+
self.value['features'] = self.value['features'].to(device, dtype) if 'features' in self.value and self.value['features'] is not None else None
|
| 38 |
+
return self
|
| 39 |
+
def __call__(self):
|
| 40 |
+
return self.value
|
| 41 |
+
|
| 42 |
+
def exists(val):
|
| 43 |
+
return val is not None
|
| 44 |
+
|
| 45 |
+
def default(val, d):
|
| 46 |
+
return val if exists(val) else d
|
| 47 |
+
|
| 48 |
+
# broadcat, as tortoise-tts was using it
|
| 49 |
+
|
| 50 |
+
def broadcat(tensors, dim = -1):
|
| 51 |
+
broadcasted_tensors = broadcast_tensors(*tensors)
|
| 52 |
+
|
| 53 |
+
def pool(x: torch.Tensor, size: int, dim: int) -> torch.Tensor:
|
| 54 |
+
# return x.view(x.shape[:dim] + (-1, size) + x.shape[dim + 1 :]).mean(dim + 1)
|
| 55 |
+
# Reshape x to group elements along the specified dimension into chunks of 'size', then average over those chunks.
|
| 56 |
+
|
| 57 |
+
# Check if the dimension is divisible by the pool size, if not pad with mean values
|
| 58 |
+
if x.shape[dim] % size != 0:
|
| 59 |
+
print(f"Warning: dimension {dim} with size {x.shape[dim]} is not divisible by pool size {size}, padding with mean values")
|
| 60 |
+
remainder = x.shape[dim] % size
|
| 61 |
+
pad_len = size - remainder
|
| 62 |
+
|
| 63 |
+
# Get the mean of the last few elements along the dimension to be pooled
|
| 64 |
+
last_elements = x.narrow(dim, x.shape[dim] - remainder, remainder)
|
| 65 |
+
mean_value = last_elements.mean()
|
| 66 |
+
|
| 67 |
+
# Create padding tensor with the same shape as x except for the dimension being pooled
|
| 68 |
+
pad_shape = list(x.shape)
|
| 69 |
+
pad_shape[dim] = pad_len
|
| 70 |
+
padding = torch.ones(pad_shape, device=x.device, dtype=x.dtype) * mean_value
|
| 71 |
+
|
| 72 |
+
# Concatenate the original tensor with the padding along the specified dimension
|
| 73 |
+
x = torch.cat([x, padding], dim=dim)
|
| 74 |
+
|
| 75 |
+
shape_before = x.shape[:dim]
|
| 76 |
+
shape_after = x.shape[dim + 1 :]
|
| 77 |
+
new_shape = shape_before + (-1, size) + shape_after
|
| 78 |
+
x_reshaped = x.view(new_shape)
|
| 79 |
+
return x_reshaped.mean(dim + 1)
|
| 80 |
+
|
| 81 |
+
def rotate_half(x):
|
| 82 |
+
x = rearrange(x, '... (d r) -> ... d r', r = 2)
|
| 83 |
+
x1, x2 = x.unbind(dim = -1)
|
| 84 |
+
x = torch.stack((-x2, x1), dim = -1)
|
| 85 |
+
return rearrange(x, '... d r -> ... (d r)')
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def apply_rotary_emb(freqs, t, start_index = 0, scale = 1., seq_dim = -2):
|
| 89 |
+
with torch.amp.autocast(device_type='cuda', enabled=False):
|
| 90 |
+
ori_dtype = t.dtype
|
| 91 |
+
embed_dtype = torch.float64
|
| 92 |
+
t = t.to(embed_dtype)
|
| 93 |
+
if t.ndim == 3:
|
| 94 |
+
seq_len = t.shape[seq_dim]
|
| 95 |
+
freqs = freqs[-seq_len:].to(t)
|
| 96 |
+
|
| 97 |
+
rot_dim = freqs.shape[-1]
|
| 98 |
+
end_index = start_index + rot_dim
|
| 99 |
+
|
| 100 |
+
assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
|
| 101 |
+
|
| 102 |
+
t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
|
| 103 |
+
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
|
| 104 |
+
return torch.cat((t_left, t, t_right), dim = -1).to(ori_dtype)
|
| 105 |
+
|
| 106 |
+
class MaxTimeContinuousTimeRotaryEmbedding(nn.Module):
|
| 107 |
+
def __init__(self, dim, max_time, period_mode="shortest", device=None):
|
| 108 |
+
super().__init__()
|
| 109 |
+
assert dim % 2 == 0, "RoPE embedding dimension must be even"
|
| 110 |
+
|
| 111 |
+
# Set max period = max_time
|
| 112 |
+
if period_mode == "shortest": # shortest period is max_time
|
| 113 |
+
base = 5
|
| 114 |
+
inv_freq = 2 * math.pi / (max_time * (base ** (torch.arange(0, dim // 2).float() / (dim // 2))))
|
| 115 |
+
elif period_mode == "longest": # longest period is max_time ** ((dim // 2) / (dim // 2 - 1))
|
| 116 |
+
theta = max_time ** ((dim // 2) / (dim // 2 - 1))
|
| 117 |
+
inv_freq = 2 * math.pi / ((theta ** (torch.arange(0, dim // 2).float() / (dim // 2))))
|
| 118 |
+
else:
|
| 119 |
+
raise ValueError(f"Invalid period mode: {period_mode}")
|
| 120 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 121 |
+
|
| 122 |
+
def forward(self, time_values: torch.Tensor):
|
| 123 |
+
"""
|
| 124 |
+
time_values: [batch_size, seq_len], in seconds (or any continuous unit)
|
| 125 |
+
Returns:
|
| 126 |
+
cos, sin: [batch_size, seq_len, dim]
|
| 127 |
+
"""
|
| 128 |
+
batch_size, seq_len = time_values.shape
|
| 129 |
+
time_values_exp = time_values[:, None, :] # [batch, 1, seq_len]
|
| 130 |
+
freqs = (self.inv_freq[None, :, None] @ time_values_exp).transpose(1, 2) # [batch, seq_len, dim//2]
|
| 131 |
+
# emb = torch.cat([freqs, freqs], dim=-1) # [batch, seq_len, dim]
|
| 132 |
+
# return emb.cos(), emb.sin()
|
| 133 |
+
return freqs
|
| 134 |
+
|
| 135 |
+
def get_axial_freqs(self, *dims):
|
| 136 |
+
Colon = slice(None)
|
| 137 |
+
all_freqs = []
|
| 138 |
+
|
| 139 |
+
for ind, dim in enumerate(dims):
|
| 140 |
+
pos = torch.arange(dim, device = self.device)
|
| 141 |
+
|
| 142 |
+
freqs = self.forward(pos, seq_len = dim)
|
| 143 |
+
|
| 144 |
+
all_axis = [None] * len(dims)
|
| 145 |
+
all_axis[ind] = Colon
|
| 146 |
+
|
| 147 |
+
new_axis_slice = (Ellipsis, *all_axis, Colon)
|
| 148 |
+
all_freqs.append(freqs[new_axis_slice])
|
| 149 |
+
|
| 150 |
+
all_freqs = broadcast_tensors(*all_freqs)
|
| 151 |
+
return torch.cat(all_freqs, dim = -1)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class RotaryEmbedding(Module):
|
| 157 |
+
@beartype
|
| 158 |
+
def __init__(
|
| 159 |
+
self,
|
| 160 |
+
dim,
|
| 161 |
+
custom_freqs: Optional[Tensor] = None,
|
| 162 |
+
freqs_for: Union[Literal['lang', 'pixel', 'constant']] = 'lang',
|
| 163 |
+
theta = 10000,
|
| 164 |
+
max_freq = 10,
|
| 165 |
+
num_freqs = 1,
|
| 166 |
+
learned_freq = False,
|
| 167 |
+
use_xpos = False,
|
| 168 |
+
xpos_scale_base = 512,
|
| 169 |
+
interpolate_factor = 1.,
|
| 170 |
+
theta_rescale_factor = 1.,
|
| 171 |
+
seq_before_head_dim = False,
|
| 172 |
+
cache_if_possible = True,
|
| 173 |
+
max_time = None
|
| 174 |
+
):
|
| 175 |
+
super().__init__()
|
| 176 |
+
|
| 177 |
+
self.dim = dim
|
| 178 |
+
self.freqs_for = freqs_for
|
| 179 |
+
self.max_freq = max_freq
|
| 180 |
+
self.num_freqs = num_freqs
|
| 181 |
+
self.learned_freq = learned_freq
|
| 182 |
+
self.use_xpos = use_xpos
|
| 183 |
+
self.xpos_scale_base = xpos_scale_base
|
| 184 |
+
self.interpolate_factor = interpolate_factor
|
| 185 |
+
self.theta_rescale_factor = theta_rescale_factor
|
| 186 |
+
self.cache_if_possible = cache_if_possible
|
| 187 |
+
self.max_time = max_time
|
| 188 |
+
|
| 189 |
+
self.tmp_store('cached_freqs', None)
|
| 190 |
+
self.tmp_store('cached_scales', None)
|
| 191 |
+
|
| 192 |
+
# Adjust theta to avoid angle wrapping after large times
|
| 193 |
+
if exists(max_time) and freqs_for == 'lang':
|
| 194 |
+
# Make sure highest frequency completes 1 full rotation over max time
|
| 195 |
+
# theta = base of exponent: higher theta → lower frequency range
|
| 196 |
+
# max_time * (1/theta^(0)) = 2pi => theta = max_time / (2pi)
|
| 197 |
+
theta = max_time / (2 * pi)
|
| 198 |
+
|
| 199 |
+
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
| 200 |
+
|
| 201 |
+
self.theta = theta
|
| 202 |
+
|
| 203 |
+
if exists(custom_freqs):
|
| 204 |
+
freqs = custom_freqs
|
| 205 |
+
elif freqs_for == 'lang':
|
| 206 |
+
freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
|
| 207 |
+
elif freqs_for == 'pixel':
|
| 208 |
+
freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
|
| 209 |
+
elif freqs_for == 'constant':
|
| 210 |
+
freqs = torch.ones(num_freqs).float()
|
| 211 |
+
|
| 212 |
+
self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)
|
| 213 |
+
|
| 214 |
+
self.learned_freq = learned_freq
|
| 215 |
+
|
| 216 |
+
# dummy for device
|
| 217 |
+
|
| 218 |
+
self.tmp_store('dummy', torch.tensor(0))
|
| 219 |
+
|
| 220 |
+
# default sequence dimension
|
| 221 |
+
|
| 222 |
+
self.seq_before_head_dim = seq_before_head_dim
|
| 223 |
+
self.default_seq_dim = -3 if seq_before_head_dim else -2
|
| 224 |
+
|
| 225 |
+
# interpolation factors
|
| 226 |
+
|
| 227 |
+
assert interpolate_factor >= 1.
|
| 228 |
+
self.interpolate_factor = interpolate_factor
|
| 229 |
+
|
| 230 |
+
# xpos
|
| 231 |
+
if not use_xpos:
|
| 232 |
+
self.tmp_store('scale', None)
|
| 233 |
+
return
|
| 234 |
+
|
| 235 |
+
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
|
| 236 |
+
self.scale_base = xpos_scale_base
|
| 237 |
+
self.tmp_store('scale', scale)
|
| 238 |
+
|
| 239 |
+
# add apply_rotary_emb as static method
|
| 240 |
+
|
| 241 |
+
self.apply_rotary_emb = staticmethod(apply_rotary_emb)
|
| 242 |
+
|
| 243 |
+
@property
|
| 244 |
+
def device(self):
|
| 245 |
+
return self.dummy.device
|
| 246 |
+
|
| 247 |
+
def tmp_store(self, key, value):
|
| 248 |
+
self.register_buffer(key, value, persistent = False)
|
| 249 |
+
|
| 250 |
+
def get_seq_pos(self, seq_len, device, dtype, offset = 0):
|
| 251 |
+
return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor
|
| 252 |
+
|
| 253 |
+
def rotate_queries_or_keys(self, t, seq_dim = None, offset = 0):
|
| 254 |
+
seq_dim = default(seq_dim, self.default_seq_dim)
|
| 255 |
+
|
| 256 |
+
assert not self.use_xpos, 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings'
|
| 257 |
+
|
| 258 |
+
device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
|
| 259 |
+
|
| 260 |
+
freqs = self.forward(self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset), seq_len = seq_len, offset = offset)
|
| 261 |
+
|
| 262 |
+
if seq_dim == -3:
|
| 263 |
+
freqs = rearrange(freqs, 'n d -> n 1 d')
|
| 264 |
+
|
| 265 |
+
return apply_rotary_emb(freqs, t, seq_dim = seq_dim)
|
| 266 |
+
|
| 267 |
+
def rotate_queries_with_cached_keys(self, q, k, seq_dim = None, offset = 0):
|
| 268 |
+
seq_dim = default(seq_dim, self.default_seq_dim)
|
| 269 |
+
|
| 270 |
+
q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
|
| 271 |
+
assert q_len <= k_len
|
| 272 |
+
|
| 273 |
+
rotated_q = self.rotate_queries_or_keys(q, seq_dim = seq_dim, offset = k_len - q_len + offset)
|
| 274 |
+
rotated_k = self.rotate_queries_or_keys(k, seq_dim = seq_dim, offset = offset)
|
| 275 |
+
|
| 276 |
+
rotated_q = rotated_q.type(q.dtype)
|
| 277 |
+
rotated_k = rotated_k.type(k.dtype)
|
| 278 |
+
|
| 279 |
+
return rotated_q, rotated_k
|
| 280 |
+
|
| 281 |
+
def rotate_queries_and_keys(self, q, k, seq_dim = None):
|
| 282 |
+
seq_dim = default(seq_dim, self.default_seq_dim)
|
| 283 |
+
|
| 284 |
+
assert self.use_xpos
|
| 285 |
+
device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
|
| 286 |
+
|
| 287 |
+
seq = self.get_seq_pos(seq_len, dtype = dtype, device = device)
|
| 288 |
+
|
| 289 |
+
freqs = self.forward(seq, seq_len = seq_len)
|
| 290 |
+
scale = self.get_scale(seq, seq_len = seq_len).to(dtype)
|
| 291 |
+
|
| 292 |
+
if seq_dim == -3:
|
| 293 |
+
freqs = rearrange(freqs, 'n d -> n 1 d')
|
| 294 |
+
scale = rearrange(scale, 'n d -> n 1 d')
|
| 295 |
+
|
| 296 |
+
rotated_q = apply_rotary_emb(freqs, q, scale = scale, seq_dim = seq_dim)
|
| 297 |
+
rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1, seq_dim = seq_dim)
|
| 298 |
+
|
| 299 |
+
rotated_q = rotated_q.type(q.dtype)
|
| 300 |
+
rotated_k = rotated_k.type(k.dtype)
|
| 301 |
+
|
| 302 |
+
return rotated_q, rotated_k
|
| 303 |
+
|
| 304 |
+
@beartype
|
| 305 |
+
def get_scale(
|
| 306 |
+
self,
|
| 307 |
+
t: Tensor,
|
| 308 |
+
seq_len: Optional[int] = None,
|
| 309 |
+
offset = 0
|
| 310 |
+
):
|
| 311 |
+
assert self.use_xpos
|
| 312 |
+
|
| 313 |
+
should_cache = (
|
| 314 |
+
self.cache_if_possible and
|
| 315 |
+
exists(seq_len)
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
if (
|
| 319 |
+
should_cache and \
|
| 320 |
+
exists(self.cached_scales) and \
|
| 321 |
+
(seq_len + offset) <= self.cached_scales.shape[0]
|
| 322 |
+
):
|
| 323 |
+
return self.cached_scales[offset:(offset + seq_len)]
|
| 324 |
+
|
| 325 |
+
scale = 1.
|
| 326 |
+
if self.use_xpos:
|
| 327 |
+
power = (t - len(t) // 2) / self.scale_base
|
| 328 |
+
scale = self.scale ** rearrange(power, 'n -> n 1')
|
| 329 |
+
scale = torch.cat((scale, scale), dim = -1)
|
| 330 |
+
|
| 331 |
+
if should_cache:
|
| 332 |
+
self.tmp_store('cached_scales', scale)
|
| 333 |
+
|
| 334 |
+
return scale
|
| 335 |
+
|
| 336 |
+
def get_axial_freqs(self, *dims):
|
| 337 |
+
Colon = slice(None)
|
| 338 |
+
all_freqs = []
|
| 339 |
+
|
| 340 |
+
for ind, dim in enumerate(dims):
|
| 341 |
+
if self.freqs_for == 'pixel':
|
| 342 |
+
pos = torch.linspace(-1, 1, steps = dim, device = self.device)
|
| 343 |
+
else:
|
| 344 |
+
pos = torch.arange(dim, device = self.device)
|
| 345 |
+
|
| 346 |
+
freqs = self.forward(pos, seq_len = dim)
|
| 347 |
+
|
| 348 |
+
all_axis = [None] * len(dims)
|
| 349 |
+
all_axis[ind] = Colon
|
| 350 |
+
|
| 351 |
+
new_axis_slice = (Ellipsis, *all_axis, Colon)
|
| 352 |
+
all_freqs.append(freqs[new_axis_slice])
|
| 353 |
+
|
| 354 |
+
all_freqs = broadcast_tensors(*all_freqs)
|
| 355 |
+
return torch.cat(all_freqs, dim = -1)
|
| 356 |
+
|
| 357 |
+
def forward(
|
| 358 |
+
self,
|
| 359 |
+
t: Tensor,
|
| 360 |
+
seq_len = None,
|
| 361 |
+
offset = 0
|
| 362 |
+
):
|
| 363 |
+
should_cache = (
|
| 364 |
+
self.cache_if_possible and \
|
| 365 |
+
not self.learned_freq and \
|
| 366 |
+
exists(seq_len) and \
|
| 367 |
+
self.freqs_for != 'pixel'
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
if (
|
| 371 |
+
should_cache and \
|
| 372 |
+
exists(self.cached_freqs) and \
|
| 373 |
+
(offset + seq_len) <= self.cached_freqs.shape[0]
|
| 374 |
+
):
|
| 375 |
+
return self.cached_freqs[offset:(offset + seq_len)].detach()
|
| 376 |
+
|
| 377 |
+
freqs = self.freqs
|
| 378 |
+
|
| 379 |
+
# Scale time to keep t * freq <= 2pi
|
| 380 |
+
if hasattr(self, 'max_time') and self.max_time is not None:
|
| 381 |
+
t = t / self.max_time * (2 * pi)
|
| 382 |
+
|
| 383 |
+
freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
|
| 384 |
+
freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
|
| 385 |
+
|
| 386 |
+
if should_cache:
|
| 387 |
+
self.tmp_store('cached_freqs', freqs.detach())
|
| 388 |
+
|
| 389 |
+
return freqs
|
| 390 |
+
|
| 391 |
+
class BaseEncoder(nn.Module):
|
| 392 |
+
def __init__(self, parent: nn.Module) -> None:
|
| 393 |
+
super().__init__()
|
| 394 |
+
self._parent = [parent]
|
| 395 |
+
|
| 396 |
+
@property
|
| 397 |
+
def parent(self) -> nn.Module:
|
| 398 |
+
return self._parent[0]
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
class BasicImageEncoder(BaseEncoder):
|
| 402 |
+
def __init__(
|
| 403 |
+
self,
|
| 404 |
+
parent: torch.nn.Module,
|
| 405 |
+
start_tokens: Optional[str] = None,
|
| 406 |
+
end_tokens: Optional[str] = "\n",
|
| 407 |
+
) -> None:
|
| 408 |
+
super().__init__(parent)
|
| 409 |
+
end_tokens = None if end_tokens == "None" else end_tokens
|
| 410 |
+
self.start_tokens = start_tokens
|
| 411 |
+
self.end_tokens = end_tokens
|
| 412 |
+
|
| 413 |
+
def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
|
| 414 |
+
if tokens is None:
|
| 415 |
+
return None
|
| 416 |
+
token_ids = self.parent.tokenizer(tokens).input_ids
|
| 417 |
+
token_ids = torch.tensor(token_ids, device=self.parent.device)
|
| 418 |
+
return self.parent.llm_model_embed_tokens(token_ids)
|
| 419 |
+
|
| 420 |
+
def _process_features(
|
| 421 |
+
self,
|
| 422 |
+
features: torch.Tensor,
|
| 423 |
+
start_token_embeds: Optional[torch.Tensor],
|
| 424 |
+
end_token_embeds: Optional[torch.Tensor],
|
| 425 |
+
) -> torch.Tensor:
|
| 426 |
+
if start_token_embeds is not None:
|
| 427 |
+
features = torch.cat([start_token_embeds, features], dim=0)
|
| 428 |
+
if end_token_embeds is not None:
|
| 429 |
+
features = torch.cat([features, end_token_embeds], dim=0)
|
| 430 |
+
return features
|
| 431 |
+
|
| 432 |
+
def forward(self, images: List[torch.Tensor], config: Dict[str, Any], mm_info: dict) -> List[torch.Tensor]:
|
| 433 |
+
images = torch.stack(images, dim=0)
|
| 434 |
+
features = self.parent.encode_images(images, block_sizes=config.get("block_sizes"))
|
| 435 |
+
process_features = partial(
|
| 436 |
+
self._process_features,
|
| 437 |
+
start_token_embeds=self.embed_tokens(self.start_tokens),
|
| 438 |
+
end_token_embeds=self.embed_tokens(self.end_tokens),
|
| 439 |
+
)
|
| 440 |
+
return [process_features(f) for f in features]
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
class BasicVideoEncoder(BaseEncoder):
|
| 444 |
+
def __init__(
|
| 445 |
+
self,
|
| 446 |
+
parent: torch.nn.Module,
|
| 447 |
+
start_tokens: Optional[str] = None,
|
| 448 |
+
end_tokens: Optional[str] = "\n",
|
| 449 |
+
) -> None:
|
| 450 |
+
super().__init__(parent)
|
| 451 |
+
end_tokens = None if end_tokens == "None" else end_tokens
|
| 452 |
+
self.start_tokens = start_tokens
|
| 453 |
+
self.end_tokens = end_tokens
|
| 454 |
+
|
| 455 |
+
def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
|
| 456 |
+
if tokens is None:
|
| 457 |
+
return None
|
| 458 |
+
token_ids = self.parent.tokenizer(tokens).input_ids
|
| 459 |
+
token_ids = torch.tensor(token_ids, device=self.parent.device)
|
| 460 |
+
return self.parent.llm_model_embed_tokens(token_ids)
|
| 461 |
+
|
| 462 |
+
def _process_features(
|
| 463 |
+
self,
|
| 464 |
+
features: torch.Tensor,
|
| 465 |
+
start_token_embeds: Optional[torch.Tensor],
|
| 466 |
+
end_token_embeds: Optional[torch.Tensor],
|
| 467 |
+
) -> torch.Tensor:
|
| 468 |
+
if start_token_embeds is not None:
|
| 469 |
+
start_embeds = torch.stack([start_token_embeds] * features.shape[0], dim=0)
|
| 470 |
+
features = torch.cat([start_embeds, features], dim=1)
|
| 471 |
+
if end_token_embeds is not None:
|
| 472 |
+
end_embeds = torch.stack([end_token_embeds] * features.shape[0], dim=0)
|
| 473 |
+
features = torch.cat([features, end_embeds], dim=1)
|
| 474 |
+
return features.flatten(0, 1)
|
| 475 |
+
|
| 476 |
+
def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
|
| 477 |
+
num_frames = [video.shape[0] for video in videos]
|
| 478 |
+
images = torch.cat(videos, dim=0)
|
| 479 |
+
features = self.parent.encode_images(images)
|
| 480 |
+
features = torch.split(features, num_frames)
|
| 481 |
+
process_features = partial(
|
| 482 |
+
self._process_features,
|
| 483 |
+
start_token_embeds=self.embed_tokens(self.start_tokens),
|
| 484 |
+
end_token_embeds=self.embed_tokens(self.end_tokens),
|
| 485 |
+
)
|
| 486 |
+
return [process_features(f) for f in features]
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
class BasicSoundEncoder(BaseEncoder):
|
| 492 |
+
def __init__(
|
| 493 |
+
self,
|
| 494 |
+
parent: torch.nn.Module,
|
| 495 |
+
start_tokens: Optional[str] = None,
|
| 496 |
+
end_tokens: Optional[str] = "\n",
|
| 497 |
+
embed_time = "True",
|
| 498 |
+
trope_theta = 50000,
|
| 499 |
+
trope_dim = 128,
|
| 500 |
+
max_time = None,
|
| 501 |
+
time_embed_type = "pixel",
|
| 502 |
+
period_fix = False,
|
| 503 |
+
) -> None:
|
| 504 |
+
super().__init__(parent)
|
| 505 |
+
end_tokens = None if end_tokens == "None" else end_tokens
|
| 506 |
+
if embed_time == "True":
|
| 507 |
+
embed_time = True
|
| 508 |
+
elif embed_time == "False":
|
| 509 |
+
embed_time = False
|
| 510 |
+
self.start_tokens = start_tokens
|
| 511 |
+
self.end_tokens = end_tokens
|
| 512 |
+
|
| 513 |
+
if embed_time == "False" or embed_time == False:
|
| 514 |
+
self.embed_time = False
|
| 515 |
+
else:
|
| 516 |
+
self.embed_time = True
|
| 517 |
+
self.time_embed_type = time_embed_type
|
| 518 |
+
|
| 519 |
+
period_mode = None
|
| 520 |
+
if type(period_fix) == str:
|
| 521 |
+
if period_fix == "shortest":
|
| 522 |
+
period_fix = "MTCT"
|
| 523 |
+
period_mode = "shortest"
|
| 524 |
+
elif period_fix == "longest":
|
| 525 |
+
period_fix = "MTCT"
|
| 526 |
+
period_mode = "longest"
|
| 527 |
+
|
| 528 |
+
self.period_fix = period_fix
|
| 529 |
+
self.max_time = max_time
|
| 530 |
+
|
| 531 |
+
if period_fix == "MTCT":
|
| 532 |
+
if period_mode is None:
|
| 533 |
+
self.pos_emb = MaxTimeContinuousTimeRotaryEmbedding(
|
| 534 |
+
dim = trope_dim,
|
| 535 |
+
max_time = max_time,
|
| 536 |
+
)
|
| 537 |
+
else:
|
| 538 |
+
self.pos_emb = MaxTimeContinuousTimeRotaryEmbedding(
|
| 539 |
+
dim = trope_dim,
|
| 540 |
+
max_time = max_time,
|
| 541 |
+
period_mode = period_mode,
|
| 542 |
+
)
|
| 543 |
+
|
| 544 |
+
elif time_embed_type in ["pixel", "lang"]:
|
| 545 |
+
if trope_dim is None and max_time is None:
|
| 546 |
+
raise ValueError("trope_dim or max_time is required when embed_time is True")
|
| 547 |
+
self.pos_emb = RotaryEmbedding(
|
| 548 |
+
dim = trope_dim,
|
| 549 |
+
freqs_for = time_embed_type,
|
| 550 |
+
max_freq = 256,
|
| 551 |
+
max_time = max_time,
|
| 552 |
+
)
|
| 553 |
+
elif time_embed_type == "learned_embed":
|
| 554 |
+
self.time_embed = parent.sound_mm_projector.time_embed
|
| 555 |
+
else:
|
| 556 |
+
raise ValueError(f"Invalid time_embed_type: {time_embed_type}")
|
| 557 |
+
|
| 558 |
+
def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
|
| 559 |
+
if tokens is None:
|
| 560 |
+
return None
|
| 561 |
+
token_ids = self.parent.tokenizer(tokens).input_ids
|
| 562 |
+
token_ids = torch.tensor(token_ids, device=self.parent.device)
|
| 563 |
+
# return self.parent.llm.model.embed_tokens(token_ids)
|
| 564 |
+
return self.parent.llm_model_embed_tokens(token_ids)
|
| 565 |
+
|
| 566 |
+
def _process_features(
|
| 567 |
+
self,
|
| 568 |
+
features: torch.Tensor,
|
| 569 |
+
start_token_embeds: Optional[torch.Tensor],
|
| 570 |
+
end_token_embeds: Optional[torch.Tensor],
|
| 571 |
+
times: Optional[torch.Tensor] = None,
|
| 572 |
+
time_embed: Optional[torch.Tensor] = None,
|
| 573 |
+
) -> torch.Tensor:
|
| 574 |
+
|
| 575 |
+
features = features.to(self.parent.device)
|
| 576 |
+
device = features.device
|
| 577 |
+
dtype = features.dtype
|
| 578 |
+
|
| 579 |
+
if self.embed_time:
|
| 580 |
+
device = features.device
|
| 581 |
+
dtype = features.dtype
|
| 582 |
+
|
| 583 |
+
# Handle different embedding types
|
| 584 |
+
if self.time_embed_type in ["pixel", "lang"]:
|
| 585 |
+
times = times.unsqueeze(0)
|
| 586 |
+
new_times = times
|
| 587 |
+
pos_emb = self.pos_emb.to(device)
|
| 588 |
+
if self.period_fix == "True":
|
| 589 |
+
if self.max_time is not None:
|
| 590 |
+
angle = new_times.to(device) / self.max_time * 2 * np.pi
|
| 591 |
+
else:
|
| 592 |
+
angle = new_times.to(device)
|
| 593 |
+
elif self.period_fix == "MTCT":
|
| 594 |
+
freqs = self.pos_emb(new_times.float())
|
| 595 |
+
freqs = freqs.squeeze(0)
|
| 596 |
+
features = apply_rotary_emb(freqs, features)
|
| 597 |
+
else:
|
| 598 |
+
angle = (-new_times * 2 * np.pi).to(device)
|
| 599 |
+
|
| 600 |
+
if not self.period_fix == "MTCT":
|
| 601 |
+
freqs = pos_emb.get_axial_freqs(new_times.shape[0], features.shape[-2]).to(device)
|
| 602 |
+
angle_expanded = angle.unsqueeze(2)
|
| 603 |
+
angle_expanded = angle_expanded.expand(new_times.shape[0], features.shape[-2], freqs.shape[-1])
|
| 604 |
+
freqs = freqs * angle_expanded
|
| 605 |
+
freqs = freqs.squeeze(0)
|
| 606 |
+
# ori_dtype = features.dtype
|
| 607 |
+
# embed_dtype = torch.float32
|
| 608 |
+
# features = features.to(embed_dtype)
|
| 609 |
+
features = apply_rotary_emb(freqs, features)
|
| 610 |
+
# features = features.to(ori_dtype)
|
| 611 |
+
elif self.time_embed_type == "learned_embed": # Learned embedding
|
| 612 |
+
# Add time embeddings to features
|
| 613 |
+
features = features + time_embed
|
| 614 |
+
else:
|
| 615 |
+
raise ValueError(f"Invalid time_embed_type: {self.time_embed_type}")
|
| 616 |
+
|
| 617 |
+
if start_token_embeds is not None:
|
| 618 |
+
features = torch.cat([start_token_embeds, features], dim=0)
|
| 619 |
+
if end_token_embeds is not None:
|
| 620 |
+
features = torch.cat([features, end_token_embeds], dim=0)
|
| 621 |
+
return features
|
| 622 |
+
|
| 623 |
+
def forward(self, sounds: List[torch.Tensor], config: Dict[str, Any], mm_info: dict) -> List[torch.Tensor]:
|
| 624 |
+
# sounds = torch.stack(sounds, dim=0)
|
| 625 |
+
features = self.parent.encode_sound(sounds, mm_info=mm_info)
|
| 626 |
+
process_features = partial(
|
| 627 |
+
self._process_features,
|
| 628 |
+
start_token_embeds=self.embed_tokens(self.start_tokens),
|
| 629 |
+
end_token_embeds=self.embed_tokens(self.end_tokens),
|
| 630 |
+
)
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
if self.embed_time:
|
| 634 |
+
new_features = []
|
| 635 |
+
device = features[0].device
|
| 636 |
+
fea_count = len(features)
|
| 637 |
+
aud_idx = 0
|
| 638 |
+
bs = len(mm_info["audio_info"])
|
| 639 |
+
|
| 640 |
+
if self.time_embed_type == "learned_embed": # Learned embedding, we need to first collect all times and only do time embedding once
|
| 641 |
+
times_list = []
|
| 642 |
+
for i in range(bs):
|
| 643 |
+
_audio_info = mm_info["audio_info"][i]
|
| 644 |
+
if _audio_info is not None:
|
| 645 |
+
for j in range(len(_audio_info)):
|
| 646 |
+
_feature = features[aud_idx]
|
| 647 |
+
if _audio_info[j] == "dummy":
|
| 648 |
+
times = torch.zeros(_feature.shape[0], device=device, dtype=_feature.dtype)
|
| 649 |
+
else:
|
| 650 |
+
audio_chunk_length = _audio_info[j]["new_audio_chunk_length"]
|
| 651 |
+
sec_per_embed = audio_chunk_length / _feature.shape[0]
|
| 652 |
+
audio_start_sec = _audio_info[j]["audio_start_sec"]
|
| 653 |
+
times = [audio_start_sec + i * sec_per_embed + sec_per_embed / 2 for i in range(_feature.shape[0])]
|
| 654 |
+
times = torch.tensor(times).to(device)
|
| 655 |
+
times_list.append(times)
|
| 656 |
+
aud_idx += 1
|
| 657 |
+
|
| 658 |
+
times = torch.stack(times_list, dim=0)
|
| 659 |
+
time_embeds = self.time_embed(times, dtype=features[0].dtype)
|
| 660 |
+
|
| 661 |
+
aud_idx = 0
|
| 662 |
+
for i in range(bs):
|
| 663 |
+
_audio_info = mm_info["audio_info"][i]
|
| 664 |
+
if _audio_info is not None:
|
| 665 |
+
for j in range(len(_audio_info)):
|
| 666 |
+
try:
|
| 667 |
+
_feature = features[aud_idx]
|
| 668 |
+
except Exception as e:
|
| 669 |
+
print(f"Error: {e}. Length of features: {len(features)}. Length of _audio_info: {len(_audio_info)}. Length of _feature: {_feature.shape[0]}")
|
| 670 |
+
raise e
|
| 671 |
+
if _audio_info[j] == "dummy":
|
| 672 |
+
times = torch.zeros(_feature.shape[0], device=device, dtype=_feature.dtype)
|
| 673 |
+
else:
|
| 674 |
+
audio_chunk_length = _audio_info[j]["new_audio_chunk_length"]
|
| 675 |
+
sec_per_embed = audio_chunk_length / _feature.shape[0]
|
| 676 |
+
audio_start_sec = _audio_info[j]["audio_start_sec"]
|
| 677 |
+
times = [audio_start_sec + i * sec_per_embed + sec_per_embed / 2 for i in range(_feature.shape[0])]
|
| 678 |
+
times = torch.tensor(times).to(device)
|
| 679 |
+
if self.time_embed_type == "learned_embed":
|
| 680 |
+
_feature = process_features(_feature, time_embed=time_embeds[aud_idx])
|
| 681 |
+
else:
|
| 682 |
+
_feature = process_features(_feature, times=times)
|
| 683 |
+
new_features.append(_feature)
|
| 684 |
+
aud_idx += 1
|
| 685 |
+
|
| 686 |
+
assert aud_idx == fea_count , "aud_idx: {}, fea_count: {}".format(aud_idx, fea_count)
|
| 687 |
+
features = new_features
|
| 688 |
+
else:
|
| 689 |
+
features = [process_features(f) for f in features]
|
| 690 |
+
return features
|
| 691 |
+
|
| 692 |
+
# return [process_features(f) for f in feature
|
| 693 |
+
|
| 694 |
+
class TSPVideoEncoder(BasicVideoEncoder):
|
| 695 |
+
def __init__(
|
| 696 |
+
self,
|
| 697 |
+
parent: torch.nn.Module,
|
| 698 |
+
pool_sizes: List[Tuple[int, int, int]],
|
| 699 |
+
start_tokens: Optional[str] = None,
|
| 700 |
+
end_tokens: Optional[str] = "\n",
|
| 701 |
+
sep_tokens: Optional[str] = None,
|
| 702 |
+
embed_time: str = "False",
|
| 703 |
+
trope_theta = 50000,
|
| 704 |
+
trope_dim = 128,
|
| 705 |
+
max_time = None,
|
| 706 |
+
time_embed_type = "pixel",
|
| 707 |
+
period_fix = False,
|
| 708 |
+
) -> None:
|
| 709 |
+
super().__init__(parent, start_tokens=start_tokens, end_tokens=end_tokens)
|
| 710 |
+
self.pool_sizes = pool_sizes
|
| 711 |
+
self.sep_tokens = sep_tokens
|
| 712 |
+
|
| 713 |
+
if embed_time == "False":
|
| 714 |
+
self.embed_time = False
|
| 715 |
+
else:
|
| 716 |
+
self.embed_time = True
|
| 717 |
+
self.time_embed_type = time_embed_type
|
| 718 |
+
|
| 719 |
+
period_mode = None
|
| 720 |
+
if type(period_fix) == str:
|
| 721 |
+
if period_fix == "shortest":
|
| 722 |
+
period_fix = "MTCT"
|
| 723 |
+
period_mode = "shortest"
|
| 724 |
+
elif period_fix == "longest":
|
| 725 |
+
period_fix = "MTCT"
|
| 726 |
+
period_mode = "longest"
|
| 727 |
+
|
| 728 |
+
self.period_fix = period_fix
|
| 729 |
+
self.max_time = max_time
|
| 730 |
+
|
| 731 |
+
if period_fix == "MTCT":
|
| 732 |
+
if period_mode is None:
|
| 733 |
+
self.pos_emb = MaxTimeContinuousTimeRotaryEmbedding(
|
| 734 |
+
dim = trope_dim,
|
| 735 |
+
max_time = max_time,
|
| 736 |
+
)
|
| 737 |
+
else:
|
| 738 |
+
self.pos_emb = MaxTimeContinuousTimeRotaryEmbedding(
|
| 739 |
+
dim = trope_dim,
|
| 740 |
+
max_time = max_time,
|
| 741 |
+
period_mode = period_mode,
|
| 742 |
+
)
|
| 743 |
+
|
| 744 |
+
elif time_embed_type in ["pixel", "lang"]:
|
| 745 |
+
if trope_dim is None and max_time is None:
|
| 746 |
+
raise ValueError("trope_dim or max_time is required when embed_time is True")
|
| 747 |
+
|
| 748 |
+
if time_embed_type == "lang":
|
| 749 |
+
self.pos_emb = RotaryEmbedding(
|
| 750 |
+
dim = trope_dim,
|
| 751 |
+
freqs_for = 'lang',
|
| 752 |
+
theta = trope_theta,
|
| 753 |
+
max_time = max_time,
|
| 754 |
+
)
|
| 755 |
+
elif time_embed_type == "pixel":
|
| 756 |
+
self.pos_emb = RotaryEmbedding(
|
| 757 |
+
dim = trope_dim,
|
| 758 |
+
freqs_for = time_embed_type,
|
| 759 |
+
max_freq = 256
|
| 760 |
+
)
|
| 761 |
+
elif time_embed_type == "learned_embed":
|
| 762 |
+
self.time_embed = parent.mm_projector.time_embed
|
| 763 |
+
else:
|
| 764 |
+
raise ValueError(f"Invalid time_embed_type: {time_embed_type}")
|
| 765 |
+
|
| 766 |
+
def _process_features(
|
| 767 |
+
self,
|
| 768 |
+
inputs: torch.Tensor,
|
| 769 |
+
start_token_embeds: Optional[torch.Tensor],
|
| 770 |
+
end_token_embeds: Optional[torch.Tensor],
|
| 771 |
+
sep_token_embeds: Optional[torch.Tensor],
|
| 772 |
+
times: Optional[torch.Tensor] = None,
|
| 773 |
+
time_embed: Optional[torch.Tensor] = None,
|
| 774 |
+
) -> torch.Tensor:
|
| 775 |
+
nt, ns = inputs.shape[:2]
|
| 776 |
+
nl = int(ns**0.5)
|
| 777 |
+
outputs = []
|
| 778 |
+
for pool_size in self.pool_sizes:
|
| 779 |
+
features = inputs.view(nt, nl, nl, -1)
|
| 780 |
+
for dim, p in enumerate(pool_size):
|
| 781 |
+
try:
|
| 782 |
+
features = pool(features, p, dim=dim)
|
| 783 |
+
except Exception as e:
|
| 784 |
+
print(f"Error: Pooling failed: {e}")
|
| 785 |
+
print(f"inputs.shape: {inputs.shape}, features.shape: {features.shape}, pool_size: {p}, dim: {dim}")
|
| 786 |
+
raise e
|
| 787 |
+
features = features.flatten(1, 2)
|
| 788 |
+
|
| 789 |
+
if self.embed_time:
|
| 790 |
+
device = features.device
|
| 791 |
+
dtype = features.dtype
|
| 792 |
+
if self.time_embed_type in ["pixel", "lang"]:
|
| 793 |
+
# consider the pooling in self.pool_sizes
|
| 794 |
+
temporal_pool_size = pool_size[0]
|
| 795 |
+
if temporal_pool_size != 1:
|
| 796 |
+
if len(times) % temporal_pool_size != 0:
|
| 797 |
+
# pad
|
| 798 |
+
print(f"Warning: length of times: {len(times)} is not a multiple of temporal_pool_size: {temporal_pool_size}")
|
| 799 |
+
remainder = len(times) % temporal_pool_size
|
| 800 |
+
pad_len = temporal_pool_size - remainder
|
| 801 |
+
last_window_mean_times = times[-remainder:].mean()
|
| 802 |
+
times = torch.cat([times, torch.ones(pad_len).to(times.device) * last_window_mean_times])
|
| 803 |
+
new_times = pool(times, temporal_pool_size, 0)
|
| 804 |
+
else:
|
| 805 |
+
new_times = times
|
| 806 |
+
|
| 807 |
+
pos_emb = self.pos_emb.to(device)
|
| 808 |
+
if self.period_fix == "True":
|
| 809 |
+
if self.max_time is not None:
|
| 810 |
+
angle = new_times.to(device) / self.max_time * 2 * np.pi
|
| 811 |
+
else:
|
| 812 |
+
angle = new_times.to(device)
|
| 813 |
+
elif self.period_fix == "MTCT":
|
| 814 |
+
if new_times.ndim == 1:
|
| 815 |
+
new_times = new_times.unsqueeze(0)
|
| 816 |
+
freqs = self.pos_emb(new_times.float())
|
| 817 |
+
freqs = freqs.squeeze(0)
|
| 818 |
+
freqs = freqs.unsqueeze(1)
|
| 819 |
+
features = apply_rotary_emb(freqs, features, seq_dim=0)
|
| 820 |
+
else:
|
| 821 |
+
angle = (-new_times * 2 * np.pi).to(device)
|
| 822 |
+
|
| 823 |
+
if not self.period_fix == "MTCT":
|
| 824 |
+
freqs = pos_emb.get_axial_freqs(new_times.shape[0], features.shape[-2]).to(device)
|
| 825 |
+
angle_expanded = angle.unsqueeze(1).unsqueeze(2)
|
| 826 |
+
angle_expanded = angle_expanded.expand(new_times.shape[0], features.shape[-2], freqs.shape[-1])
|
| 827 |
+
freqs = freqs * angle_expanded
|
| 828 |
+
# ori_dtype = features.dtype
|
| 829 |
+
# embed_dtype = torch.float32
|
| 830 |
+
# features = features.to(embed_dtype)
|
| 831 |
+
features = apply_rotary_emb(freqs, features)
|
| 832 |
+
# features = features.to(ori_dtype)
|
| 833 |
+
elif self.time_embed_type == "learned_embed": # Learned embedding
|
| 834 |
+
# Add time embeddings to features
|
| 835 |
+
features = features + time_embed
|
| 836 |
+
else:
|
| 837 |
+
raise ValueError(f"Invalid time_embed_type: {self.time_embed_type}")
|
| 838 |
+
|
| 839 |
+
features = super()._process_features(
|
| 840 |
+
features,
|
| 841 |
+
start_token_embeds=start_token_embeds,
|
| 842 |
+
end_token_embeds=end_token_embeds,
|
| 843 |
+
)
|
| 844 |
+
if sep_token_embeds is not None:
|
| 845 |
+
features = torch.cat([features, sep_token_embeds], dim=0)
|
| 846 |
+
outputs.append(features)
|
| 847 |
+
return torch.cat(outputs, dim=0)
|
| 848 |
+
|
| 849 |
+
def forward(self, videos: List[torch.Tensor], config: Dict[str, Any], mm_info: dict) -> List[torch.Tensor]:
|
| 850 |
+
cache_feas = []
|
| 851 |
+
cache_feas_index = []
|
| 852 |
+
for _idx in range(len(videos)):
|
| 853 |
+
if type(videos[_idx]) == CacheFeatures:
|
| 854 |
+
cache_feas.append(videos[_idx])
|
| 855 |
+
cache_feas_index.append(_idx)
|
| 856 |
+
|
| 857 |
+
num_frames = [
|
| 858 |
+
_.value['features'].shape[0] if isinstance(_, CacheFeatures) else _.shape[0]
|
| 859 |
+
for _ in videos
|
| 860 |
+
]
|
| 861 |
+
|
| 862 |
+
features = self.parent.encode_video(videos, mm_info=mm_info, num_frames=num_frames)
|
| 863 |
+
features = torch.split(features, num_frames)
|
| 864 |
+
|
| 865 |
+
process_features = partial(
|
| 866 |
+
self._process_features,
|
| 867 |
+
start_token_embeds=self.embed_tokens(self.start_tokens),
|
| 868 |
+
end_token_embeds=self.embed_tokens(self.end_tokens),
|
| 869 |
+
sep_token_embeds=self.embed_tokens(self.sep_tokens),
|
| 870 |
+
)
|
| 871 |
+
|
| 872 |
+
|
| 873 |
+
if self.embed_time:
|
| 874 |
+
bs = len(mm_info["video_info"])
|
| 875 |
+
vid_idx = 0
|
| 876 |
+
device = features[0].device
|
| 877 |
+
|
| 878 |
+
if self.time_embed_type == "learned_embed":
|
| 879 |
+
# Learned embedding, we need to first collect all times from all videos and only do time embedding once
|
| 880 |
+
times_list = []
|
| 881 |
+
for i in range(bs):
|
| 882 |
+
_video_info = mm_info["video_info"][i]
|
| 883 |
+
if _video_info is not None:
|
| 884 |
+
for j in range(len(_video_info)):
|
| 885 |
+
_feature = features[vid_idx]
|
| 886 |
+
if _video_info[j] == "dummy":
|
| 887 |
+
times = torch.zeros(_feature.shape[0], device=device, dtype=_feature.dtype)
|
| 888 |
+
else:
|
| 889 |
+
times = _video_info[j]["video_frame_times"]
|
| 890 |
+
times = torch.tensor(times).to(device)
|
| 891 |
+
|
| 892 |
+
for pool_size in self.pool_sizes:
|
| 893 |
+
temporal_pool_size = pool_size[0]
|
| 894 |
+
if temporal_pool_size != 1:
|
| 895 |
+
if len(times) % temporal_pool_size != 0:
|
| 896 |
+
# pad
|
| 897 |
+
print(f"Warning: length of times: {len(times)} is not a multiple of temporal_pool_size: {temporal_pool_size}")
|
| 898 |
+
remainder = len(times) % temporal_pool_size
|
| 899 |
+
pad_len = temporal_pool_size - remainder
|
| 900 |
+
last_window_mean_times = times[-remainder:].mean()
|
| 901 |
+
times = torch.cat([times, torch.ones(pad_len).to(times.device) * last_window_mean_times])
|
| 902 |
+
times = pool(times, temporal_pool_size, 0)
|
| 903 |
+
|
| 904 |
+
times_list.append(times)
|
| 905 |
+
vid_idx += 1
|
| 906 |
+
|
| 907 |
+
# pad the times to the same length
|
| 908 |
+
ori_lens = [len(times) for times in times_list]
|
| 909 |
+
max_len = max(ori_lens)
|
| 910 |
+
for i in range(len(times_list)):
|
| 911 |
+
if len(times_list[i]) < max_len:
|
| 912 |
+
times_list[i] = torch.cat([times_list[i], torch.zeros(max_len - len(times_list[i])).to(times_list[i].device)])
|
| 913 |
+
times = torch.stack(times_list, dim=0)
|
| 914 |
+
time_embeds = self.time_embed(times, dtype=features[0].dtype)
|
| 915 |
+
|
| 916 |
+
# remove the padding for each embed
|
| 917 |
+
new_time_embeds = []
|
| 918 |
+
for i in range(len(times_list)):
|
| 919 |
+
new_time_embeds.append(time_embeds[i][:ori_lens[i]].unsqueeze(1).expand(-1, features[0].shape[1], -1))
|
| 920 |
+
|
| 921 |
+
# add dummy embed to the first embed
|
| 922 |
+
new_time_embeds[0] = new_time_embeds[0] + 0 * time_embeds.mean()
|
| 923 |
+
|
| 924 |
+
new_features = []
|
| 925 |
+
fea_count = len(features)
|
| 926 |
+
vid_idx = 0
|
| 927 |
+
for i in range(bs):
|
| 928 |
+
_video_info = mm_info["video_info"][i]
|
| 929 |
+
if _video_info is not None:
|
| 930 |
+
for j in range(len(_video_info)):
|
| 931 |
+
_feature = features[vid_idx]
|
| 932 |
+
if _video_info[j] == "dummy":
|
| 933 |
+
times = torch.zeros(_feature.shape[0], device=device, dtype=_feature.dtype)
|
| 934 |
+
else:
|
| 935 |
+
times = _video_info[j]["video_frame_times"]
|
| 936 |
+
times = torch.tensor(times).to(device)
|
| 937 |
+
if self.time_embed_type == "learned_embed":
|
| 938 |
+
_feature = process_features(_feature, time_embed=new_time_embeds[vid_idx])
|
| 939 |
+
else:
|
| 940 |
+
_feature = process_features(_feature, times=times)
|
| 941 |
+
new_features.append(_feature)
|
| 942 |
+
vid_idx += 1
|
| 943 |
+
|
| 944 |
+
assert vid_idx == fea_count, "vid_idx: {}, fea_count: {}".format(vid_idx, fea_count)
|
| 945 |
+
features = new_features
|
| 946 |
+
else:
|
| 947 |
+
features = [process_features(f) for f in features]
|
| 948 |
+
return features
|
| 949 |
+
|
| 950 |
+
def _encode_video_frames(self, video_frames: torch.Tensor) -> torch.Tensor:
|
| 951 |
+
"""Helper method to encode video frames when cached features are not available."""
|
| 952 |
+
features = self.parent.encode_images(video_frames.unsqueeze(0))
|
| 953 |
+
return features.squeeze(0)
|
| 954 |
+
|
| 955 |
+
|
mm_projector/config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "outputs/model/mm_projector",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"MultimodalProjector"
|
| 5 |
+
],
|
| 6 |
+
"mm_projector_type": "mlp_downsample",
|
| 7 |
+
"model_type": "v2l_projector",
|
| 8 |
+
"torch_dtype": "bfloat16",
|
| 9 |
+
"transformers_version": "4.46.0"
|
| 10 |
+
}
|
mm_projector/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e80a60453d3c104445816f05e1926f676e89a2f99ceb950e4876e99a1e391913
|
| 3 |
+
size 124850712
|
mm_utils.py
ADDED
|
@@ -0,0 +1,567 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
# Note: dynamic_preprocess and find_closest_aspect_ratio are referenced from https://github.com/OpenGVLab/InternVL
|
| 18 |
+
|
| 19 |
+
import base64
|
| 20 |
+
import os
|
| 21 |
+
import tempfile
|
| 22 |
+
from io import BytesIO
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
import torch
|
| 26 |
+
from PIL import Image
|
| 27 |
+
from transformers import StoppingCriteria
|
| 28 |
+
|
| 29 |
+
from .constants import DEFAULT_IMAGE_TOKEN
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_frame_from_vcap(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
|
| 33 |
+
"""Extract frames from video capture object."""
|
| 34 |
+
import cv2
|
| 35 |
+
|
| 36 |
+
if fps is None or frame_count is None:
|
| 37 |
+
# Recompute if either fps or frame_count is None
|
| 38 |
+
fps = vidcap.get(cv2.CAP_PROP_FPS)
|
| 39 |
+
frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 40 |
+
if fps == 0 or frame_count == 0:
|
| 41 |
+
print(f"Video file not found. return empty images. {video_file_name}")
|
| 42 |
+
return [
|
| 43 |
+
Image.new("RGB", (720, 720)),
|
| 44 |
+
] * num_frames, 0
|
| 45 |
+
|
| 46 |
+
duration = frame_count / fps
|
| 47 |
+
frame_interval = frame_count // num_frames
|
| 48 |
+
if frame_interval == 0 and frame_count <= 1:
|
| 49 |
+
print(f"frame_interval is equal to 0. return empty image. {video_file_name}")
|
| 50 |
+
return [
|
| 51 |
+
Image.new("RGB", (720, 720)),
|
| 52 |
+
] * num_frames, 0
|
| 53 |
+
# print("duration:", duration, "frames:", frame_count, "intervals:", frame_interval)
|
| 54 |
+
|
| 55 |
+
images = []
|
| 56 |
+
count = 0
|
| 57 |
+
success = True
|
| 58 |
+
frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
|
| 59 |
+
while success:
|
| 60 |
+
# print("frame_count:", frame_count, "count:", count, "num_frames:", num_frames, "frame_interval:", frame_interval)
|
| 61 |
+
if frame_count >= num_frames:
|
| 62 |
+
success, frame = vidcap.read()
|
| 63 |
+
if count in frame_indices:
|
| 64 |
+
try:
|
| 65 |
+
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 66 |
+
im_pil = Image.fromarray(img)
|
| 67 |
+
images.append(im_pil)
|
| 68 |
+
except BaseException:
|
| 69 |
+
continue
|
| 70 |
+
if len(images) >= num_frames:
|
| 71 |
+
return images, num_frames
|
| 72 |
+
count += 1
|
| 73 |
+
else:
|
| 74 |
+
# Left padding frames if the video is not long enough
|
| 75 |
+
success, frame = vidcap.read()
|
| 76 |
+
if success:
|
| 77 |
+
try:
|
| 78 |
+
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 79 |
+
im_pil = Image.fromarray(img)
|
| 80 |
+
images.append(im_pil)
|
| 81 |
+
except BaseException:
|
| 82 |
+
continue
|
| 83 |
+
count += 1
|
| 84 |
+
else:
|
| 85 |
+
break
|
| 86 |
+
if len(images) == 0:
|
| 87 |
+
raise ValueError("Did not find enough frames in the video. return empty image.")
|
| 88 |
+
|
| 89 |
+
return images, len(images)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def get_frame_from_vcap_with_fps(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
|
| 93 |
+
"""
|
| 94 |
+
Extract frames from video capture with FPS consideration.
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
vidcap: OpenCV video capture object
|
| 98 |
+
num_frames: Maximum number of frames the model can support
|
| 99 |
+
max_fps: Maximum FPS the model can support
|
| 100 |
+
fps: FPS of the input video
|
| 101 |
+
frame_count: Number of frames in the input video
|
| 102 |
+
video_file_name: Name of the video file for logging
|
| 103 |
+
"""
|
| 104 |
+
import random
|
| 105 |
+
import cv2
|
| 106 |
+
|
| 107 |
+
if fps is None or frame_count is None:
|
| 108 |
+
# Recompute if either fps or frame_count is None
|
| 109 |
+
fps = vidcap.get(cv2.CAP_PROP_FPS)
|
| 110 |
+
frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 111 |
+
|
| 112 |
+
if fps == 0 or frame_count == 0:
|
| 113 |
+
print(f"Video file not found. return empty images. {video_file_name}")
|
| 114 |
+
empty_video_frames = int(random.uniform(2, 8 * max_fps))
|
| 115 |
+
return [
|
| 116 |
+
Image.new("RGB", (720, 720)),
|
| 117 |
+
] * empty_video_frames, 0
|
| 118 |
+
|
| 119 |
+
duration = frame_count / fps
|
| 120 |
+
# print("duration:", duration, "frames:", frame_count, "fps:", fps, "num_frames:", num_frames, "max_fps:", max_fps)
|
| 121 |
+
# If the video is too long (longer than max_fps and num_frames can support),
|
| 122 |
+
# we will use lower fps to sample frames.
|
| 123 |
+
if duration >= num_frames / max_fps:
|
| 124 |
+
frame_interval = frame_count // num_frames
|
| 125 |
+
|
| 126 |
+
# If the video is too short, we will skip the video if there is only one frame.
|
| 127 |
+
if frame_interval == 0 and frame_count <= 1:
|
| 128 |
+
print(f"frame_interval is equal to 0. return empty image. {video_file_name}")
|
| 129 |
+
empty_video_frames = int(random.uniform(2, 8 * max_fps))
|
| 130 |
+
return [
|
| 131 |
+
Image.new("RGB", (720, 720)),
|
| 132 |
+
] * empty_video_frames, 0
|
| 133 |
+
|
| 134 |
+
images = []
|
| 135 |
+
count = 0
|
| 136 |
+
success = True
|
| 137 |
+
frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
|
| 138 |
+
|
| 139 |
+
while success:
|
| 140 |
+
if frame_count >= num_frames:
|
| 141 |
+
if count in frame_indices:
|
| 142 |
+
success, frame = vidcap.read()
|
| 143 |
+
try:
|
| 144 |
+
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 145 |
+
im_pil = Image.fromarray(img)
|
| 146 |
+
images.append(im_pil)
|
| 147 |
+
except:
|
| 148 |
+
continue
|
| 149 |
+
if len(images) >= num_frames:
|
| 150 |
+
return images, num_frames
|
| 151 |
+
else:
|
| 152 |
+
success = vidcap.grab()
|
| 153 |
+
count += 1
|
| 154 |
+
else:
|
| 155 |
+
# Left padding frames if the video is not long enough
|
| 156 |
+
success, frame = vidcap.read()
|
| 157 |
+
if success:
|
| 158 |
+
try:
|
| 159 |
+
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 160 |
+
im_pil = Image.fromarray(img)
|
| 161 |
+
images.append(im_pil)
|
| 162 |
+
except:
|
| 163 |
+
continue
|
| 164 |
+
count += 1
|
| 165 |
+
else:
|
| 166 |
+
break
|
| 167 |
+
else:
|
| 168 |
+
frames_required = int(duration * max_fps)
|
| 169 |
+
frame_indices = np.linspace(0, frame_count - 1, frames_required, dtype=int)
|
| 170 |
+
if frames_required == 0:
|
| 171 |
+
print(f"frames_required is fewer than 2. Duration {duration}, return empty image.")
|
| 172 |
+
empty_video_frames = int(random.uniform(2, 8 * max_fps))
|
| 173 |
+
return [
|
| 174 |
+
Image.new("RGB", (720, 720)),
|
| 175 |
+
] * empty_video_frames, 0
|
| 176 |
+
elif frames_required == 1:
|
| 177 |
+
frame_indices = np.linspace(0, frame_count - 1, 2, dtype=int)
|
| 178 |
+
images = []
|
| 179 |
+
count = 0
|
| 180 |
+
looked = 0
|
| 181 |
+
success = True
|
| 182 |
+
|
| 183 |
+
while success:
|
| 184 |
+
success, frame = vidcap.read()
|
| 185 |
+
if success and (looked in frame_indices):
|
| 186 |
+
try:
|
| 187 |
+
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 188 |
+
im_pil = Image.fromarray(img)
|
| 189 |
+
images.append(im_pil)
|
| 190 |
+
except:
|
| 191 |
+
continue
|
| 192 |
+
count += 1
|
| 193 |
+
looked += 1
|
| 194 |
+
|
| 195 |
+
if len(images) == 0:
|
| 196 |
+
empty_video_frames = int(random.uniform(2, 8 * max_fps))
|
| 197 |
+
return [
|
| 198 |
+
Image.new("RGB", (720, 720)),
|
| 199 |
+
] * empty_video_frames, 0
|
| 200 |
+
else:
|
| 201 |
+
return images, len(images)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def opencv_extract_frames(vpath_or_bytesio, frames=6, max_fps=0.0, fps=None, frame_count=None):
|
| 205 |
+
"""
|
| 206 |
+
Extract frames from a video using OpenCV.
|
| 207 |
+
|
| 208 |
+
Args:
|
| 209 |
+
vpath_or_bytesio (str or BytesIO): Path to the video file or BytesIO object containing the video.
|
| 210 |
+
frames (int): Number of frames to extract from the video.
|
| 211 |
+
fps (float): Frames per second of the video. If 0.0, the function will extract frames at equal intervals.
|
| 212 |
+
|
| 213 |
+
Returns:
|
| 214 |
+
list: List of PIL Images extracted from the video.
|
| 215 |
+
|
| 216 |
+
Raises:
|
| 217 |
+
NotImplementedError: If the type of `vpath_or_bytesio` is not supported.
|
| 218 |
+
"""
|
| 219 |
+
import cv2
|
| 220 |
+
|
| 221 |
+
if isinstance(vpath_or_bytesio, str):
|
| 222 |
+
vidcap = cv2.VideoCapture(vpath_or_bytesio)
|
| 223 |
+
if max_fps > 0.0:
|
| 224 |
+
return get_frame_from_vcap_with_fps(
|
| 225 |
+
vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio
|
| 226 |
+
)
|
| 227 |
+
return get_frame_from_vcap(
|
| 228 |
+
vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio
|
| 229 |
+
)
|
| 230 |
+
elif isinstance(vpath_or_bytesio, (BytesIO,)):
|
| 231 |
+
# assuming mp4
|
| 232 |
+
with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
|
| 233 |
+
temp_video.write(vpath_or_bytesio.read())
|
| 234 |
+
temp_video_name = temp_video.name
|
| 235 |
+
vidcap = cv2.VideoCapture(temp_video_name)
|
| 236 |
+
if max_fps > 0.0:
|
| 237 |
+
return get_frame_from_vcap_with_fps(
|
| 238 |
+
vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name
|
| 239 |
+
)
|
| 240 |
+
return get_frame_from_vcap(
|
| 241 |
+
vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name
|
| 242 |
+
)
|
| 243 |
+
else:
|
| 244 |
+
raise NotImplementedError(type(vpath_or_bytesio))
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def load_image_from_base64(image):
|
| 248 |
+
"""Load PIL Image from base64 encoded string."""
|
| 249 |
+
return Image.open(BytesIO(base64.b64decode(image)))
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def expand2square(pil_img, background_color):
|
| 253 |
+
"""
|
| 254 |
+
Expand the given PIL image to a square shape by adding padding.
|
| 255 |
+
|
| 256 |
+
Parameters:
|
| 257 |
+
- pil_img: The PIL image to be expanded.
|
| 258 |
+
- background_color: The color of the padding to be added.
|
| 259 |
+
|
| 260 |
+
Returns:
|
| 261 |
+
- The expanded PIL image.
|
| 262 |
+
|
| 263 |
+
If the image is already square, it is returned as is.
|
| 264 |
+
If the image is wider than it is tall, padding is added to the top and bottom.
|
| 265 |
+
If the image is taller than it is wide, padding is added to the left and right.
|
| 266 |
+
"""
|
| 267 |
+
width, height = pil_img.size
|
| 268 |
+
if pil_img.mode == "L":
|
| 269 |
+
background_color = background_color[0]
|
| 270 |
+
if width == height:
|
| 271 |
+
return pil_img
|
| 272 |
+
elif width > height:
|
| 273 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
| 274 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
| 275 |
+
return result
|
| 276 |
+
else:
|
| 277 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
| 278 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
| 279 |
+
return result
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
|
| 283 |
+
"""Find the closest aspect ratio from target ratios."""
|
| 284 |
+
best_ratio_diff = float("inf")
|
| 285 |
+
best_ratio = (1, 1)
|
| 286 |
+
area = width * height
|
| 287 |
+
for ratio in target_ratios:
|
| 288 |
+
target_aspect_ratio = ratio[0] / ratio[1]
|
| 289 |
+
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
| 290 |
+
if ratio_diff < best_ratio_diff:
|
| 291 |
+
best_ratio_diff = ratio_diff
|
| 292 |
+
best_ratio = ratio
|
| 293 |
+
elif ratio_diff == best_ratio_diff:
|
| 294 |
+
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
|
| 295 |
+
best_ratio = ratio
|
| 296 |
+
return best_ratio
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=384, use_thumbnail=True):
|
| 300 |
+
"""Dynamically preprocess image into multiple tiles based on aspect ratio."""
|
| 301 |
+
orig_width, orig_height = image.size
|
| 302 |
+
aspect_ratio = orig_width / orig_height
|
| 303 |
+
|
| 304 |
+
# Calculate the existing image aspect ratio
|
| 305 |
+
target_ratios = {
|
| 306 |
+
(i, j)
|
| 307 |
+
for n in range(min_num, max_num + 1)
|
| 308 |
+
for i in range(1, n + 1)
|
| 309 |
+
for j in range(1, n + 1)
|
| 310 |
+
if i * j <= max_num and i * j >= min_num
|
| 311 |
+
}
|
| 312 |
+
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
| 313 |
+
|
| 314 |
+
# find the closest aspect ratio to the target
|
| 315 |
+
target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
|
| 316 |
+
|
| 317 |
+
# calculate the target width and height
|
| 318 |
+
target_width = image_size * target_aspect_ratio[0]
|
| 319 |
+
target_height = image_size * target_aspect_ratio[1]
|
| 320 |
+
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
| 321 |
+
|
| 322 |
+
# resize the image
|
| 323 |
+
resized_img = image.resize((target_width, target_height))
|
| 324 |
+
processed_images = []
|
| 325 |
+
for i in range(blocks):
|
| 326 |
+
box = (
|
| 327 |
+
(i % (target_width // image_size)) * image_size,
|
| 328 |
+
(i // (target_width // image_size)) * image_size,
|
| 329 |
+
((i % (target_width // image_size)) + 1) * image_size,
|
| 330 |
+
((i // (target_width // image_size)) + 1) * image_size,
|
| 331 |
+
)
|
| 332 |
+
# split the image
|
| 333 |
+
split_img = resized_img.crop(box)
|
| 334 |
+
processed_images.append(split_img)
|
| 335 |
+
assert len(processed_images) == blocks
|
| 336 |
+
if use_thumbnail and len(processed_images) != 1:
|
| 337 |
+
thumbnail_img = image.resize((image_size, image_size))
|
| 338 |
+
processed_images.append(thumbnail_img)
|
| 339 |
+
return processed_images
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def dynamic_s2_preprocess(image, s2_scales=[384, 768, 1152], max_num=12, image_size=384):
|
| 343 |
+
"""Dynamically preprocess image with multi-scale S2 strategy."""
|
| 344 |
+
orig_width, orig_height = image.size
|
| 345 |
+
aspect_ratio = orig_width / orig_height
|
| 346 |
+
min_num = (s2_scales[-1] // s2_scales[0]) ** 2
|
| 347 |
+
|
| 348 |
+
processed_images = []
|
| 349 |
+
|
| 350 |
+
# Add tiles for all but the last scale using fixed square ratio
|
| 351 |
+
|
| 352 |
+
for scale in s2_scales[:-1]:
|
| 353 |
+
target_width = image_size * (scale // s2_scales[0])
|
| 354 |
+
target_height = image_size * (scale // s2_scales[0])
|
| 355 |
+
blocks = (scale // s2_scales[0]) ** 2
|
| 356 |
+
|
| 357 |
+
# resize the image
|
| 358 |
+
resized_img = image.resize((target_width, target_height))
|
| 359 |
+
for i in range(blocks):
|
| 360 |
+
box = (
|
| 361 |
+
(i % (target_width // image_size)) * image_size,
|
| 362 |
+
(i // (target_width // image_size)) * image_size,
|
| 363 |
+
((i % (target_width // image_size)) + 1) * image_size,
|
| 364 |
+
((i // (target_width // image_size)) + 1) * image_size,
|
| 365 |
+
)
|
| 366 |
+
# split the image
|
| 367 |
+
split_img = resized_img.crop(box)
|
| 368 |
+
processed_images.append(split_img)
|
| 369 |
+
|
| 370 |
+
# Add tiles for the last scale using dynamic aspect ratio
|
| 371 |
+
|
| 372 |
+
# Calculate the existing image aspect ratio
|
| 373 |
+
target_ratios = {
|
| 374 |
+
(i, j)
|
| 375 |
+
for n in range(min_num, max_num + 1)
|
| 376 |
+
for i in range(1, n + 1)
|
| 377 |
+
for j in range(1, n + 1)
|
| 378 |
+
if i * j <= max_num and i * j >= min_num
|
| 379 |
+
}
|
| 380 |
+
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
| 381 |
+
|
| 382 |
+
# find the closest aspect ratio to the target
|
| 383 |
+
target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
|
| 384 |
+
|
| 385 |
+
# calculate the target width and height
|
| 386 |
+
target_width = image_size * target_aspect_ratio[0]
|
| 387 |
+
target_height = image_size * target_aspect_ratio[1]
|
| 388 |
+
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
| 389 |
+
|
| 390 |
+
# resize the image
|
| 391 |
+
resized_img = image.resize((target_width, target_height))
|
| 392 |
+
for i in range(blocks):
|
| 393 |
+
box = (
|
| 394 |
+
(i % (target_width // image_size)) * image_size,
|
| 395 |
+
(i // (target_width // image_size)) * image_size,
|
| 396 |
+
((i % (target_width // image_size)) + 1) * image_size,
|
| 397 |
+
((i // (target_width // image_size)) + 1) * image_size,
|
| 398 |
+
)
|
| 399 |
+
# split the image
|
| 400 |
+
split_img = resized_img.crop(box)
|
| 401 |
+
processed_images.append(split_img)
|
| 402 |
+
|
| 403 |
+
return processed_images, (target_aspect_ratio[1], target_aspect_ratio[0])
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def dynamic_process_images_and_prompt(images, prompt, data_args, image_folder=None, max_tiles=None):
|
| 407 |
+
prompt = prompt.split(DEFAULT_IMAGE_TOKEN)
|
| 408 |
+
idx = 0
|
| 409 |
+
all_images = []
|
| 410 |
+
for img in images:
|
| 411 |
+
processed_images = process_image(img, data_args, image_folder, enable_dynamic_res=True, max_tiles=max_tiles)
|
| 412 |
+
all_images.append(processed_images)
|
| 413 |
+
prompt.insert(idx + 1, f"{DEFAULT_IMAGE_TOKEN}\n" * processed_images.shape[0])
|
| 414 |
+
idx += 2
|
| 415 |
+
prompt = "".join(prompt)
|
| 416 |
+
if all_images:
|
| 417 |
+
all_images = torch.cat(all_images)
|
| 418 |
+
else:
|
| 419 |
+
all_images = None
|
| 420 |
+
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, "")
|
| 421 |
+
return all_images, prompt
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def dynamic_s2_process_images_and_prompt(images, prompt, data_args, image_folder=None):
|
| 425 |
+
idx = 0
|
| 426 |
+
all_images = []
|
| 427 |
+
all_block_size = []
|
| 428 |
+
for img in images:
|
| 429 |
+
processed_images, block_size = process_image(img, data_args, image_folder, enable_dynamic_s2=True)
|
| 430 |
+
all_images.append(processed_images)
|
| 431 |
+
all_block_size.append(block_size)
|
| 432 |
+
idx += 2
|
| 433 |
+
if all_images:
|
| 434 |
+
all_images = torch.cat(all_images)
|
| 435 |
+
else:
|
| 436 |
+
all_images = None
|
| 437 |
+
return all_images, all_block_size
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def process_image(
|
| 441 |
+
image_file, data_args, image_folder, enable_dynamic_res=False, enable_dynamic_s2=False, max_tiles=None
|
| 442 |
+
):
|
| 443 |
+
processor = data_args.image_processor
|
| 444 |
+
if isinstance(image_file, str):
|
| 445 |
+
if image_folder is not None:
|
| 446 |
+
image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
|
| 447 |
+
else:
|
| 448 |
+
image = Image.open(image_file).convert("RGB")
|
| 449 |
+
else:
|
| 450 |
+
# image is stored in bytearray
|
| 451 |
+
image = image_file
|
| 452 |
+
image = image.convert("RGB")
|
| 453 |
+
if hasattr(data_args.image_processor, "crop_size"):
|
| 454 |
+
# CLIP vision tower
|
| 455 |
+
crop_size = data_args.image_processor.crop_size
|
| 456 |
+
else:
|
| 457 |
+
# SIGLIP vision tower
|
| 458 |
+
assert hasattr(data_args.image_processor, "size")
|
| 459 |
+
crop_size = data_args.image_processor.size
|
| 460 |
+
if "dynamic_s2" in data_args.image_aspect_ratio and enable_dynamic_s2:
|
| 461 |
+
assert crop_size["height"] == crop_size["width"]
|
| 462 |
+
images, block_size = dynamic_s2_preprocess(
|
| 463 |
+
image, s2_scales=data_args.s2_scales, max_num=data_args.max_tiles, image_size=crop_size["height"]
|
| 464 |
+
)
|
| 465 |
+
images = [processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images]
|
| 466 |
+
return torch.stack(images), block_size
|
| 467 |
+
if "dynamic" in data_args.image_aspect_ratio and enable_dynamic_res:
|
| 468 |
+
assert crop_size["height"] == crop_size["width"]
|
| 469 |
+
if max_tiles is not None:
|
| 470 |
+
max_num = max_tiles
|
| 471 |
+
else:
|
| 472 |
+
max_num = data_args.max_tiles
|
| 473 |
+
images = dynamic_preprocess(image, min_num=data_args.min_tiles, max_num=max_num, image_size=crop_size["height"])
|
| 474 |
+
images = [processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images]
|
| 475 |
+
return torch.stack(images)
|
| 476 |
+
|
| 477 |
+
if data_args.image_aspect_ratio == "resize":
|
| 478 |
+
image = image.resize((crop_size["width"], crop_size["height"]))
|
| 479 |
+
elif data_args.image_aspect_ratio == "pad":
|
| 480 |
+
image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
|
| 481 |
+
image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
|
| 482 |
+
else:
|
| 483 |
+
# Using default behavior of the vision encoder
|
| 484 |
+
# For CLIP, default is central crop
|
| 485 |
+
# For Radio, default is central crop
|
| 486 |
+
# For Siglip, default is resize
|
| 487 |
+
# For InternVIT, default is resize
|
| 488 |
+
image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
|
| 489 |
+
return image
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def process_images(images, image_processor, model_cfg, enable_dynamic_res=False, max_tiles=None):
|
| 493 |
+
"""Process a batch of images using the model's image processor."""
|
| 494 |
+
model_cfg.image_processor = image_processor
|
| 495 |
+
new_images = [
|
| 496 |
+
process_image(image, model_cfg, None, enable_dynamic_res=enable_dynamic_res, max_tiles=max_tiles)
|
| 497 |
+
for image in images
|
| 498 |
+
]
|
| 499 |
+
|
| 500 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
| 501 |
+
if len(new_images[0].shape) == 4:
|
| 502 |
+
new_images = torch.cat(new_images, dim=0)
|
| 503 |
+
elif len(new_images[0].shape) == 3:
|
| 504 |
+
new_images = torch.stack(new_images, dim=0)
|
| 505 |
+
else:
|
| 506 |
+
raise ValueError(f"new_images rank does not equal to 4, rank: {len(new_images[0].shape)}")
|
| 507 |
+
else:
|
| 508 |
+
raise ValueError("The shape of images in new_images is different!")
|
| 509 |
+
return new_images
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def tokenizer_image_token(prompt, tokenizer, return_tensors=None, return_ids=True):
|
| 513 |
+
"""Tokenize prompt with media tokens."""
|
| 514 |
+
if return_ids:
|
| 515 |
+
return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]
|
| 516 |
+
else:
|
| 517 |
+
return tokenizer(prompt, return_tensors=return_tensors)
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
def is_gemma_tokenizer(tokenizer):
|
| 521 |
+
"""Check if the tokenizer is a Gemma tokenizer."""
|
| 522 |
+
return "gemma" in tokenizer.__class__.__name__.lower()
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
def get_model_name_from_path(model_path):
|
| 526 |
+
"""Extract model name from file path."""
|
| 527 |
+
model_path = model_path.strip("/")
|
| 528 |
+
model_paths = model_path.split("/")
|
| 529 |
+
if model_paths[-1].startswith("checkpoint-"):
|
| 530 |
+
return model_paths[-2] + "_" + model_paths[-1]
|
| 531 |
+
else:
|
| 532 |
+
return model_paths[-1]
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
| 536 |
+
"""Stopping criteria based on keyword tokens."""
|
| 537 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
| 538 |
+
self.keywords = keywords
|
| 539 |
+
self.keyword_ids = []
|
| 540 |
+
self.max_keyword_len = 0
|
| 541 |
+
for keyword in keywords:
|
| 542 |
+
cur_keyword_ids = tokenizer(keyword).input_ids
|
| 543 |
+
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
|
| 544 |
+
cur_keyword_ids = cur_keyword_ids[1:]
|
| 545 |
+
if len(cur_keyword_ids) > self.max_keyword_len:
|
| 546 |
+
self.max_keyword_len = len(cur_keyword_ids)
|
| 547 |
+
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
| 548 |
+
self.tokenizer = tokenizer
|
| 549 |
+
self.start_len = input_ids.shape[1]
|
| 550 |
+
|
| 551 |
+
def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
| 552 |
+
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
|
| 553 |
+
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
|
| 554 |
+
for keyword_id in self.keyword_ids:
|
| 555 |
+
if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
|
| 556 |
+
return True
|
| 557 |
+
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
|
| 558 |
+
for keyword in self.keywords:
|
| 559 |
+
if keyword in outputs:
|
| 560 |
+
return True
|
| 561 |
+
return False
|
| 562 |
+
|
| 563 |
+
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
| 564 |
+
outputs = []
|
| 565 |
+
for i in range(output_ids.shape[0]):
|
| 566 |
+
outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
|
| 567 |
+
return all(outputs)
|
model_utils_packing.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from importlib import import_module
|
| 17 |
+
from typing import Tuple
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import transformers
|
| 21 |
+
from torch import nn
|
| 22 |
+
from torch.nn import functional as F
|
| 23 |
+
|
| 24 |
+
__all__ = ["patch"]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _get_unpad_data(attention_mask: torch.Tensor, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
| 28 |
+
if hasattr(_get_unpad_data, "seqlens_in_batch"):
|
| 29 |
+
seqlens_in_batch = _get_unpad_data.seqlens_in_batch
|
| 30 |
+
else:
|
| 31 |
+
seqlens_in_batch = torch.sum(attention_mask, dim=1)
|
| 32 |
+
|
| 33 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
| 34 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
| 35 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
| 36 |
+
return indices, cu_seqlens, max_seqlen_in_batch
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def set_seqlens_in_batch(seqlens_in_batch: torch.Tensor) -> None:
|
| 40 |
+
_get_unpad_data.seqlens_in_batch = seqlens_in_batch
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def patch(model: nn.Module) -> None:
|
| 44 |
+
if transformers.__version__ < "4.43.0":
|
| 45 |
+
m = import_module(model.__module__)
|
| 46 |
+
if not hasattr(m, "_get_unpad_data"):
|
| 47 |
+
raise ValueError(f"Module {m} does not have function '_get_unpad_data' for packing")
|
| 48 |
+
m._get_unpad_data = _get_unpad_data
|
| 49 |
+
else:
|
| 50 |
+
transformers.modeling_flash_attention_utils._get_unpad_data = _get_unpad_data
|
modeling_vila.py
ADDED
|
@@ -0,0 +1,1834 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import copy
|
| 17 |
+
import json
|
| 18 |
+
import logging
|
| 19 |
+
import numpy as np
|
| 20 |
+
import os
|
| 21 |
+
import os.path
|
| 22 |
+
import os.path as osp
|
| 23 |
+
import shutil
|
| 24 |
+
import warnings
|
| 25 |
+
from abc import ABC
|
| 26 |
+
from collections import OrderedDict, defaultdict, deque
|
| 27 |
+
from copy import deepcopy
|
| 28 |
+
from itertools import chain
|
| 29 |
+
from threading import Thread
|
| 30 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 31 |
+
|
| 32 |
+
import torch
|
| 33 |
+
import torch.distributed as dist
|
| 34 |
+
import torch.nn as nn
|
| 35 |
+
import torch.nn.functional as F
|
| 36 |
+
import torchvision
|
| 37 |
+
from einops import rearrange
|
| 38 |
+
from PIL import Image
|
| 39 |
+
from transformers import (
|
| 40 |
+
AutoConfig,
|
| 41 |
+
AutoModel,
|
| 42 |
+
AutoProcessor,
|
| 43 |
+
AutoTokenizer,
|
| 44 |
+
GenerationConfig,
|
| 45 |
+
LogitsProcessor,
|
| 46 |
+
PretrainedConfig,
|
| 47 |
+
PreTrainedModel,
|
| 48 |
+
Qwen2Config,
|
| 49 |
+
Qwen2ForCausalLM,
|
| 50 |
+
Qwen2PreTrainedModel,
|
| 51 |
+
TextIteratorStreamer,
|
| 52 |
+
WhisperFeatureExtractor,
|
| 53 |
+
)
|
| 54 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 55 |
+
from transformers.modeling_utils import ContextManagers, no_init_weights
|
| 56 |
+
|
| 57 |
+
from .auto_processor import VILAProcessor
|
| 58 |
+
from .base_projector import MultimodalProjector, MultimodalProjectorConfig
|
| 59 |
+
from .sound_base_projector import SoundMultimodalProjector, SoundMultimodalProjectorConfig
|
| 60 |
+
from .speech_base_projector import SpeechMultimodalProjector, SpeechMultimodalProjectorConfig
|
| 61 |
+
|
| 62 |
+
from .builder import build_llm_and_tokenizer
|
| 63 |
+
from .configuration_vila import VILAConfig
|
| 64 |
+
from .constants import *
|
| 65 |
+
from .conversation import SeparatorStyle, default_conversation
|
| 66 |
+
from .distributed import all_gather as vila_all_gather
|
| 67 |
+
from .media import extract_media
|
| 68 |
+
from .media_encoder import BasicImageEncoder, BasicVideoEncoder, TSPVideoEncoder, BasicSoundEncoder, CacheFeatures
|
| 69 |
+
from .mm_utils import process_image, process_images
|
| 70 |
+
from .model_utils_packing import set_seqlens_in_batch
|
| 71 |
+
from .siglip_encoder import SiglipVisionTower, SiglipVisionTowerDynamicS2, SiglipVisionTowerS2
|
| 72 |
+
from .tokenizer_utils import tokenize_conversation
|
| 73 |
+
from .utils import get_model_config, load_tokenizer_then_handle_media_tokens_and_chat_template
|
| 74 |
+
|
| 75 |
+
from .constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, NUM_EXTRA_TOKENS_VILA, NUM_EXTRA_TOKENS_XVILA
|
| 76 |
+
from .qwen_audio_encoder import Qwen2AudioTower
|
| 77 |
+
import whisper
|
| 78 |
+
|
| 79 |
+
from .audio_encoder import AudioTower
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def build_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
|
| 83 |
+
"""Build multimodal projector from path or configuration."""
|
| 84 |
+
if model_type_or_path is None:
|
| 85 |
+
return None
|
| 86 |
+
if config.resume_path:
|
| 87 |
+
assert os.path.exists(model_type_or_path), f"Resume mm projector path {model_type_or_path} does not exist!"
|
| 88 |
+
return MultimodalProjector.from_pretrained(model_type_or_path, config)
|
| 89 |
+
else:
|
| 90 |
+
mm_projector_cfg = MultimodalProjectorConfig(model_type_or_path)
|
| 91 |
+
mm_projector = MultimodalProjector(mm_projector_cfg, config)
|
| 92 |
+
return mm_projector
|
| 93 |
+
|
| 94 |
+
def build_speech_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
|
| 95 |
+
"""Build speech multimodal projector from path or configuration."""
|
| 96 |
+
if model_type_or_path is None:
|
| 97 |
+
return None
|
| 98 |
+
if config.resume_path:
|
| 99 |
+
assert os.path.exists(model_type_or_path), f"Resume speech mm projector path {model_type_or_path} does not exist!"
|
| 100 |
+
_model = SpeechMultimodalProjector.from_pretrained(
|
| 101 |
+
model_type_or_path, config, torch_dtype=eval(config.model_dtype)
|
| 102 |
+
)
|
| 103 |
+
return _model
|
| 104 |
+
else:
|
| 105 |
+
speech_mm_projector_cfg = SpeechMultimodalProjectorConfig(model_type_or_path)
|
| 106 |
+
speech_mm_projector = SpeechMultimodalProjector(speech_mm_projector_cfg, config).to(eval(config.model_dtype))
|
| 107 |
+
return speech_mm_projector
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def build_sound_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
|
| 111 |
+
"""Build sound multimodal projector from path or configuration."""
|
| 112 |
+
if model_type_or_path is None:
|
| 113 |
+
return None
|
| 114 |
+
|
| 115 |
+
if type(config.model_dtype) == str:
|
| 116 |
+
model_dtype = eval(config.model_dtype)
|
| 117 |
+
else:
|
| 118 |
+
model_dtype = config.model_dtype
|
| 119 |
+
if config.resume_path:
|
| 120 |
+
assert os.path.exists(model_type_or_path), f"Resume sound mm projector path {model_type_or_path} does not exist!"
|
| 121 |
+
_model = SoundMultimodalProjector.from_pretrained(
|
| 122 |
+
model_type_or_path, config, torch_dtype=model_dtype
|
| 123 |
+
)
|
| 124 |
+
return _model
|
| 125 |
+
else:
|
| 126 |
+
sound_mm_projector_cfg = SoundMultimodalProjectorConfig(model_type_or_path)
|
| 127 |
+
sound_mm_projector = SoundMultimodalProjector(sound_mm_projector_cfg, config).to(model_dtype)
|
| 128 |
+
return sound_mm_projector
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def check_dot_in_model_path(model_path: str):
|
| 132 |
+
"""Check if the model path contains a dot, which may affect model loading."""
|
| 133 |
+
if osp.isdir(model_path):
|
| 134 |
+
if "." in osp.abspath(model_path):
|
| 135 |
+
return True
|
| 136 |
+
else:
|
| 137 |
+
if "." in model_path:
|
| 138 |
+
return True
|
| 139 |
+
return False
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def get_vila_version(model_path: str) -> str:
|
| 143 |
+
VERSIONS = ["vila1.5", "vila-u", "longvila", "nvila", "vila-m3"]
|
| 144 |
+
for version in VERSIONS:
|
| 145 |
+
if version in model_path.lower():
|
| 146 |
+
return version
|
| 147 |
+
return None
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def generate_jinja_template(conv_mode: str) -> str:
|
| 151 |
+
if conv_mode == "vicuna_v1":
|
| 152 |
+
return """{% set system_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. " %}
|
| 153 |
+
{% set roles = ["user", "assistant"] %}
|
| 154 |
+
{% set sep = " " %}
|
| 155 |
+
|
| 156 |
+
{{ system_prompt }}
|
| 157 |
+
|
| 158 |
+
{% for message in messages %}
|
| 159 |
+
{% if message['role'] == roles[0] %}
|
| 160 |
+
{{ "USER: " }}{{ sep }}{{ message['content'] }}{{ sep }}
|
| 161 |
+
{% else %}
|
| 162 |
+
{{ "ASSISTANT: " }}{{ sep }}{{ message['content'] }}{{ sep }}
|
| 163 |
+
{% endif %}
|
| 164 |
+
{% endfor %}
|
| 165 |
+
{% if messages[-1]['role'] == 'user' %}
|
| 166 |
+
{{ "ASSISTANT:" }}
|
| 167 |
+
{% endif %}
|
| 168 |
+
"""
|
| 169 |
+
elif conv_mode == "llama_3":
|
| 170 |
+
return """{% set system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|>" %}
|
| 171 |
+
{% set roles = ["<|start_header_id|>user<|end_header_id|>\\n\\n", "<|start_header_id|>assistant<|end_header_id|>\\n\\n"]%}
|
| 172 |
+
{% set sep = "<|eot_id|>" %}
|
| 173 |
+
|
| 174 |
+
{{ system_prompt }}
|
| 175 |
+
{% for message in messages %}
|
| 176 |
+
{% if message['role'] == 'user' %}
|
| 177 |
+
{{ roles[0] }}{{ message['content'] }}{{ sep }}
|
| 178 |
+
{% else %}
|
| 179 |
+
{{ roles[1] }}{{ message['content'] }}{{ sep }}
|
| 180 |
+
{% endif %}
|
| 181 |
+
{% endfor %}
|
| 182 |
+
{% if messages[-1]['role'] == 'user' %}
|
| 183 |
+
{{ roles[1] }}
|
| 184 |
+
{% endif %}
|
| 185 |
+
"""
|
| 186 |
+
elif conv_mode == "hermes_2":
|
| 187 |
+
return """{% set system_prompt = "<|im_start|>system\nAnswer the questions." %}
|
| 188 |
+
{% set roles = ["<|im_start|>user\n", "<|im_start|>assistant\n"] %}
|
| 189 |
+
{% set sep = "<|im_end|>" %}
|
| 190 |
+
|
| 191 |
+
{{ system_prompt }}{{ sep }}
|
| 192 |
+
|
| 193 |
+
{% for message in messages %}
|
| 194 |
+
{% if message['role'] == 'user' %}
|
| 195 |
+
{{ roles[0] }}{{ message['content'] }}{{ sep }}
|
| 196 |
+
{% else %}
|
| 197 |
+
{{ roles[1] }}{{ message['content'] }}{{ sep }}
|
| 198 |
+
{% endif %}
|
| 199 |
+
{% endfor %}"""
|
| 200 |
+
else:
|
| 201 |
+
raise NotImplementedError(f"Jinja template generation is not implemented for {conv_mode}.")
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def build_vision_tower(model_name_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
|
| 205 |
+
"""Build vision tower from path or configuration."""
|
| 206 |
+
# Skip vision tower instantiation if path is None
|
| 207 |
+
if model_name_or_path is None:
|
| 208 |
+
return None
|
| 209 |
+
|
| 210 |
+
vision_tower_arch = None
|
| 211 |
+
if config.resume_path and "radio" not in model_name_or_path:
|
| 212 |
+
assert os.path.exists(model_name_or_path), f"Resume vision tower path {model_name_or_path} does not exist!"
|
| 213 |
+
vision_tower_cfg = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
|
| 214 |
+
vision_tower_arch = vision_tower_cfg.architectures[0].lower()
|
| 215 |
+
vision_tower_name = vision_tower_arch if vision_tower_arch is not None else model_name_or_path
|
| 216 |
+
|
| 217 |
+
use_s2 = getattr(config, "s2", False)
|
| 218 |
+
use_dynamic_s2 = getattr(config, "dynamic_s2", False)
|
| 219 |
+
|
| 220 |
+
if "siglip" in vision_tower_name:
|
| 221 |
+
if use_dynamic_s2:
|
| 222 |
+
vision_tower = SiglipVisionTowerDynamicS2(model_name_or_path, config)
|
| 223 |
+
elif use_s2:
|
| 224 |
+
vision_tower = SiglipVisionTowerS2(model_name_or_path, config)
|
| 225 |
+
else:
|
| 226 |
+
vision_tower = SiglipVisionTower(model_name_or_path, config)
|
| 227 |
+
else:
|
| 228 |
+
raise NotImplementedError(f"Unknown vision tower: {model_name_or_path}")
|
| 229 |
+
|
| 230 |
+
config.mm_hidden_size = (
|
| 231 |
+
vision_tower.config.hidden_size if not (use_s2 or use_dynamic_s2) else vision_tower.hidden_size
|
| 232 |
+
)
|
| 233 |
+
return vision_tower
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def build_audio_tower(model_name_or_path: str, config: PretrainedConfig, encoder_type: str) -> PreTrainedModel:
|
| 237 |
+
"""Build audio tower for sound or speech processing."""
|
| 238 |
+
assert encoder_type in ["sound", "speech"]
|
| 239 |
+
|
| 240 |
+
# Skip tower instantiation if path is None
|
| 241 |
+
if model_name_or_path is None:
|
| 242 |
+
return None
|
| 243 |
+
|
| 244 |
+
model_type = "af3"
|
| 245 |
+
|
| 246 |
+
if model_type == "af3":
|
| 247 |
+
model = Qwen2AudioTower(model_name_or_path, config)
|
| 248 |
+
output_dim = 1280
|
| 249 |
+
else:
|
| 250 |
+
raise NotImplementedError(f"Not implemented for this encoder: {model_name_or_path}")
|
| 251 |
+
|
| 252 |
+
if encoder_type == "sound":
|
| 253 |
+
config.sound_hidden_size = output_dim
|
| 254 |
+
elif encoder_type == "speech":
|
| 255 |
+
config.speech_hidden_size = output_dim
|
| 256 |
+
else:
|
| 257 |
+
raise NotImplementedError(f"Not implemented for this encoder: {model_name_or_path}")
|
| 258 |
+
|
| 259 |
+
return model
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
class VILAPretrainedModel(PreTrainedModel):
|
| 263 |
+
config_class = VILAConfig
|
| 264 |
+
main_input_name = "input_embeds"
|
| 265 |
+
supports_gradient_checkpointing = True
|
| 266 |
+
_supports_flash_attn_2 = True
|
| 267 |
+
_no_split_modules = ["Qwen2DecoderLayer", "SiglipEncoderLayer"]
|
| 268 |
+
|
| 269 |
+
def __init__(self, config: VILAConfig, *args, **kwargs):
|
| 270 |
+
super().__init__(config)
|
| 271 |
+
self.config = config
|
| 272 |
+
cfgs = get_model_config(config)
|
| 273 |
+
|
| 274 |
+
if len(cfgs) == 7:
|
| 275 |
+
(
|
| 276 |
+
llm_cfg,
|
| 277 |
+
vision_tower_cfg,
|
| 278 |
+
speech_tower_cfg,
|
| 279 |
+
sound_tower_cfg,
|
| 280 |
+
mm_projector_cfg,
|
| 281 |
+
speech_mm_projector_cfg,
|
| 282 |
+
sound_mm_projector_cfg,
|
| 283 |
+
) = cfgs
|
| 284 |
+
else:
|
| 285 |
+
raise ValueError(
|
| 286 |
+
"`llm_cfg` `mm_projector_cfg` `speech_mm_projector_cfg` `sound_mm_projector_cfg` `vision_tower_cfg` `speech_tower_cfg` `sound_tower_cfg` not found in the config."
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
# loading on auto by default
|
| 290 |
+
device_map = kwargs.get("device_map", "auto")
|
| 291 |
+
self.mm_projector = build_mm_projector(mm_projector_cfg, config)
|
| 292 |
+
self.vision_tower = build_vision_tower(vision_tower_cfg, config)
|
| 293 |
+
|
| 294 |
+
if speech_tower_cfg:
|
| 295 |
+
self.speech_tower = build_audio_tower(speech_tower_cfg, config, encoder_type="speech")
|
| 296 |
+
self.speech_mm_projector = build_speech_mm_projector(speech_mm_projector_cfg, config)
|
| 297 |
+
if sound_tower_cfg:
|
| 298 |
+
self.sound_tower = build_audio_tower(sound_tower_cfg, config, encoder_type="sound")
|
| 299 |
+
self.sound_mm_projector = build_sound_mm_projector(sound_mm_projector_cfg, config)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
if device_map in ["auto", "cuda"]:
|
| 303 |
+
self.mm_projector = self.mm_projector.cuda()
|
| 304 |
+
self.vision_tower = self.vision_tower.cuda()
|
| 305 |
+
self.speech_tower = self.speech_tower.cuda() if hasattr(self, "speech_tower") else None
|
| 306 |
+
self.sound_tower = self.sound_tower.cuda() if hasattr(self, "sound_tower") else None
|
| 307 |
+
self.speech_mm_projector = self.speech_mm_projector.cuda() if hasattr(self, "speech_mm_projector") else None
|
| 308 |
+
self.sound_mm_projector = self.sound_mm_projector.cuda() if hasattr(self, "sound_mm_projector") else None
|
| 309 |
+
# set device_map auto can autoamtically shard llm to different devices
|
| 310 |
+
self.llm, self.tokenizer = self.init_llm(llm_cfg, config, device_map=device_map)
|
| 311 |
+
|
| 312 |
+
self.llm_model_embed_tokens = self.llm.model.embed_tokens
|
| 313 |
+
|
| 314 |
+
self.tokenizer.padding_side = "left"
|
| 315 |
+
|
| 316 |
+
self.vocab_size = len(self.tokenizer)
|
| 317 |
+
self.update_vocab_size = lambda: setattr(self, "vocab_size", len(self.tokenizer))
|
| 318 |
+
|
| 319 |
+
self.encoders = {}
|
| 320 |
+
for name in ["image", "video", "speech", "sound"]:
|
| 321 |
+
encoder_config = getattr(self.config, f"{name}_encoder")
|
| 322 |
+
if isinstance(encoder_config, str):
|
| 323 |
+
encoder_config = json.loads(encoder_config)
|
| 324 |
+
if encoder_config.get("embed_time", False) == "True":
|
| 325 |
+
if "trope_dim" not in encoder_config and encoder_config.get("time_embed_type", "") in ["pixel", "lang"]:
|
| 326 |
+
encoder_config["trope_dim"] = self.config.hidden_size // 2
|
| 327 |
+
print(f"Warning: trope_dim not found in config, defaulting to hidden_size // 2: {encoder_config['trope_dim']}")
|
| 328 |
+
|
| 329 |
+
encoder_config.pop('_target_')
|
| 330 |
+
if name == "video":
|
| 331 |
+
self.encoders[name] = TSPVideoEncoder(parent=self, **encoder_config)
|
| 332 |
+
elif name == "image":
|
| 333 |
+
self.encoders[name] = BasicImageEncoder(self)
|
| 334 |
+
else:
|
| 335 |
+
self.encoders[name] = BasicSoundEncoder(parent=self, **encoder_config)
|
| 336 |
+
|
| 337 |
+
self.post_config()
|
| 338 |
+
self.is_loaded = True
|
| 339 |
+
|
| 340 |
+
self.llm_only_need_embed = kwargs.get("llm_only_need_embed", False)
|
| 341 |
+
if self.llm_only_need_embed:
|
| 342 |
+
print("We only need the embed_tokens in llm.")
|
| 343 |
+
del self.llm
|
| 344 |
+
self.llm = None
|
| 345 |
+
torch.cuda.empty_cache()
|
| 346 |
+
|
| 347 |
+
assert (
|
| 348 |
+
self.llm is not None
|
| 349 |
+
or self.vision_tower is not None
|
| 350 |
+
or self.speech_tower is not None
|
| 351 |
+
or self.mm_projector is not None
|
| 352 |
+
or self.speech_mm_projector is not None
|
| 353 |
+
), "At least one of the components must be instantiated."
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
@classmethod
|
| 357 |
+
def copy_or_symlink_directory(cls, model_path, output_dir, copy=True):
|
| 358 |
+
# Create output directory if it doesn't exist
|
| 359 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 360 |
+
# Create symlinks for all files in model_path to output_dir
|
| 361 |
+
for item in os.listdir(model_path):
|
| 362 |
+
src_path = os.path.join(model_path, item)
|
| 363 |
+
dst_path = os.path.join(output_dir, item)
|
| 364 |
+
|
| 365 |
+
# Remove existing file/directory at destination if it exists
|
| 366 |
+
if os.path.exists(dst_path):
|
| 367 |
+
if os.path.islink(dst_path):
|
| 368 |
+
os.unlink(dst_path)
|
| 369 |
+
elif os.path.isdir(dst_path):
|
| 370 |
+
shutil.rmtree(dst_path)
|
| 371 |
+
else:
|
| 372 |
+
os.remove(dst_path)
|
| 373 |
+
|
| 374 |
+
# Create symlink
|
| 375 |
+
if copy:
|
| 376 |
+
if os.path.isdir(src_path):
|
| 377 |
+
shutil.copytree(src_path, dst_path)
|
| 378 |
+
else:
|
| 379 |
+
shutil.copy2(src_path, dst_path)
|
| 380 |
+
print(f"Copied {src_path} to {dst_path}")
|
| 381 |
+
else:
|
| 382 |
+
os.symlink(src_path, dst_path)
|
| 383 |
+
print(f"Created symlink from {src_path} to {dst_path}")
|
| 384 |
+
|
| 385 |
+
@classmethod
|
| 386 |
+
def copy_remote_py_files(cls, output_dir, copy=True):
|
| 387 |
+
# copy .py and README for next loading
|
| 388 |
+
current_file_path = os.path.abspath(__file__)
|
| 389 |
+
current_folder = os.path.dirname(current_file_path)
|
| 390 |
+
for file_name in os.listdir(current_folder):
|
| 391 |
+
if file_name == "INSTRUCTIONS.md":
|
| 392 |
+
src_fname = os.path.join(current_folder, file_name)
|
| 393 |
+
dst_fname = os.path.join(output_dir, "README.md")
|
| 394 |
+
if os.path.exists(dst_fname):
|
| 395 |
+
old_readme = open(dst_fname).read()
|
| 396 |
+
else:
|
| 397 |
+
old_readme = ""
|
| 398 |
+
with open(src_fname) as src, open(dst_fname, "w") as dst:
|
| 399 |
+
dst.write(src.read())
|
| 400 |
+
dst.write(old_readme)
|
| 401 |
+
print("[HF] README", src_fname, "to", dst_fname)
|
| 402 |
+
if file_name.endswith(".py") or file_name.endswith(".jinja"):
|
| 403 |
+
full_file_name = os.path.join(current_folder, file_name)
|
| 404 |
+
if os.path.isfile(full_file_name):
|
| 405 |
+
if copy:
|
| 406 |
+
shutil.copy(full_file_name, output_dir)
|
| 407 |
+
print("[HF] copying", full_file_name, "to", output_dir)
|
| 408 |
+
else:
|
| 409 |
+
# symlink to ease development
|
| 410 |
+
if os.path.exists(os.path.join(output_dir, file_name)):
|
| 411 |
+
os.remove(os.path.join(output_dir, file_name))
|
| 412 |
+
os.symlink(full_file_name, os.path.join(output_dir, file_name))
|
| 413 |
+
print("[HF] linking", full_file_name, "to", output_dir)
|
| 414 |
+
|
| 415 |
+
def save_pretrained(self, output_dir, state_dict=None, **kwargs):
|
| 416 |
+
if state_dict is None:
|
| 417 |
+
# other wise fetch from deepspeed
|
| 418 |
+
# state_dict = accelerator.get_state_dict(is_deepspeed_enabled)
|
| 419 |
+
state_dict = self.state_dict()
|
| 420 |
+
|
| 421 |
+
if getattr(self, "tokenizer", None):
|
| 422 |
+
self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))
|
| 423 |
+
|
| 424 |
+
if self.get_llm():
|
| 425 |
+
print(f"saving llm to {osp.join(output_dir, 'llm')}")
|
| 426 |
+
self.llm.config._name_or_path = osp.join(output_dir, "llm")
|
| 427 |
+
llm_state_dict = OrderedDict({k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k})
|
| 428 |
+
self.llm.save_pretrained(os.path.join(output_dir, "llm"), state_dict=llm_state_dict)
|
| 429 |
+
self.config.llm_cfg = self.llm.config
|
| 430 |
+
|
| 431 |
+
if self.get_vision_tower():
|
| 432 |
+
print(f"saving vision_tower to {osp.join(output_dir, 'vision_tower')}")
|
| 433 |
+
self.vision_tower.config._name_or_path = osp.join(output_dir, "vision_tower")
|
| 434 |
+
vision_tower_state_dict = OrderedDict(
|
| 435 |
+
{k.split("vision_tower.vision_tower.")[-1]: v for k, v in state_dict.items() if "vision_tower" in k}
|
| 436 |
+
)
|
| 437 |
+
self.vision_tower.vision_tower.save_pretrained(
|
| 438 |
+
os.path.join(output_dir, "vision_tower"),
|
| 439 |
+
state_dict=vision_tower_state_dict,
|
| 440 |
+
)
|
| 441 |
+
self.vision_tower.image_processor.save_pretrained(os.path.join(output_dir, "vision_tower"))
|
| 442 |
+
self.config.vision_tower_cfg = self.vision_tower.config
|
| 443 |
+
if hasattr(self.config.vision_tower_cfg, "auto_map"):
|
| 444 |
+
if "radio" not in self.get_vision_tower().__class__.__name__.lower():
|
| 445 |
+
delattr(self.config.vision_tower_cfg, "auto_map")
|
| 446 |
+
if self.get_speech_tower():
|
| 447 |
+
print(f"saving speech_tower to {osp.join(output_dir, 'speech_tower')}")
|
| 448 |
+
self.speech_tower.config._name_or_path = osp.join(output_dir, "speech_tower").replace(
|
| 449 |
+
"tmp-checkpoint", "checkpoint"
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
speech_tower_state_dict = OrderedDict(
|
| 453 |
+
{k.split("speech_tower.audio_tower.")[-1]: v for k, v in state_dict.items() if "speech_tower" in k}
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
self.speech_tower.audio_tower.save_pretrained(
|
| 457 |
+
os.path.join(output_dir, "speech_tower"),
|
| 458 |
+
state_dict=speech_tower_state_dict,
|
| 459 |
+
)
|
| 460 |
+
self.config.speech_tower_cfg = self.speech_tower.config
|
| 461 |
+
|
| 462 |
+
if self.get_sound_tower():
|
| 463 |
+
print(f"saving sound_tower to {osp.join(output_dir, 'sound_tower')}")
|
| 464 |
+
self.sound_tower.config._name_or_path = osp.join(output_dir, "sound_tower").replace(
|
| 465 |
+
"tmp-checkpoint", "checkpoint"
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
sound_tower_state_dict = OrderedDict(
|
| 469 |
+
{k.split("sound_tower.audio_tower.")[-1]: v for k, v in state_dict.items() if "sound_tower" in k}
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
self.sound_tower.audio_tower.save_pretrained(
|
| 473 |
+
os.path.join(output_dir, "sound_tower"),
|
| 474 |
+
state_dict=sound_tower_state_dict,
|
| 475 |
+
)
|
| 476 |
+
self.config.sound_tower_cfg = self.sound_tower.config
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
if self.get_mm_projector():
|
| 481 |
+
print(f"saving mm_projector to {osp.join(output_dir, 'mm_projector')}")
|
| 482 |
+
self.mm_projector.config._name_or_path = osp.join(output_dir, "mm_projector")
|
| 483 |
+
mm_projector_state_dict = OrderedDict(
|
| 484 |
+
{k.split("mm_projector.")[-1]: v for k, v in state_dict.items() if "mm_projector" in k}
|
| 485 |
+
)
|
| 486 |
+
self.mm_projector.save_pretrained(
|
| 487 |
+
os.path.join(output_dir, "mm_projector"),
|
| 488 |
+
state_dict=mm_projector_state_dict,
|
| 489 |
+
)
|
| 490 |
+
self.config.mm_projector_cfg = self.mm_projector.config
|
| 491 |
+
|
| 492 |
+
if self.get_speech_mm_projector():
|
| 493 |
+
print(f"saving speech_mm_projector to {osp.join(output_dir, 'speech_mm_projector')}")
|
| 494 |
+
self.speech_mm_projector.config._name_or_path = osp.join(output_dir, "speech_mm_projector").replace(
|
| 495 |
+
"tmp-checkpoint", "checkpoint"
|
| 496 |
+
)
|
| 497 |
+
speech_mm_projector_state_dict = OrderedDict(
|
| 498 |
+
{k.split("speech_mm_projector.")[-1]: v for k, v in state_dict.items() if "speech_mm_projector" in k}
|
| 499 |
+
)
|
| 500 |
+
self.speech_mm_projector.save_pretrained(
|
| 501 |
+
os.path.join(output_dir, "speech_mm_projector"),
|
| 502 |
+
state_dict=speech_mm_projector_state_dict,
|
| 503 |
+
)
|
| 504 |
+
self.config.speech_mm_projector_cfg = self.speech_mm_projector.config
|
| 505 |
+
|
| 506 |
+
if self.get_sound_mm_projector():
|
| 507 |
+
print(f"saving sound_mm_projector to {osp.join(output_dir, 'sound_mm_projector')}")
|
| 508 |
+
self.sound_mm_projector.config._name_or_path = osp.join(output_dir, "sound_mm_projector").replace(
|
| 509 |
+
"tmp-checkpoint", "checkpoint"
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
sound_mm_projector_state_dict = OrderedDict(
|
| 513 |
+
{k.split("sound_mm_projector.")[-1]: v for k, v in state_dict.items() if "sound_mm_projector" in k}
|
| 514 |
+
)
|
| 515 |
+
self.sound_mm_projector.save_pretrained(
|
| 516 |
+
os.path.join(output_dir, "sound_mm_projector"),
|
| 517 |
+
state_dict=sound_mm_projector_state_dict,
|
| 518 |
+
)
|
| 519 |
+
self.config.sound_mm_projector_cfg = self.sound_mm_projector.config
|
| 520 |
+
|
| 521 |
+
# update and save top-level config
|
| 522 |
+
self.config._name_or_path = output_dir
|
| 523 |
+
self.config.architectures = [self.__class__.__name__]
|
| 524 |
+
self.config.save_pretrained(output_dir)
|
| 525 |
+
|
| 526 |
+
# copy .py and README for next loading
|
| 527 |
+
self.copy_remote_py_files(output_dir)
|
| 528 |
+
|
| 529 |
+
@classmethod
|
| 530 |
+
def from_pretrained(
|
| 531 |
+
cls,
|
| 532 |
+
pretrained_model_name_or_path: Optional[str] = None,
|
| 533 |
+
*model_args,
|
| 534 |
+
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
|
| 535 |
+
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
| 536 |
+
ignore_mismatched_sizes: bool = False,
|
| 537 |
+
force_download: bool = False,
|
| 538 |
+
local_files_only: bool = False,
|
| 539 |
+
token: Optional[Union[str, bool]] = None,
|
| 540 |
+
revision: str = "main",
|
| 541 |
+
use_safetensors: Optional[bool] = None,
|
| 542 |
+
weights_only: bool = True,
|
| 543 |
+
**kwargs,
|
| 544 |
+
):
|
| 545 |
+
# print("DEBUG2", kwargs); input()
|
| 546 |
+
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
|
| 547 |
+
if kwargs.get("torch_dtype", None) is not None:
|
| 548 |
+
config.torch_dtype = kwargs.get("torch_dtype", None)
|
| 549 |
+
config.model_dtype = kwargs.get("torch_dtype", None)
|
| 550 |
+
if type(kwargs.get("torch_dtype", None)) == str:
|
| 551 |
+
kwargs["torch_dtype"] = eval(kwargs.get("torch_dtype", None))
|
| 552 |
+
else:
|
| 553 |
+
kwargs["torch_dtype"] = kwargs.get("torch_dtype", None)
|
| 554 |
+
return cls._from_config(config, **kwargs)
|
| 555 |
+
|
| 556 |
+
def init_llm(self, llm_config, config, *args, **kwargs):
|
| 557 |
+
"""Initialize language model and tokenizer."""
|
| 558 |
+
self.llm, self.tokenizer = build_llm_and_tokenizer(llm_config, config, *args, **kwargs)
|
| 559 |
+
|
| 560 |
+
self.pad_token_list = (
|
| 561 |
+
self.tokenizer.pad_token_id,
|
| 562 |
+
self.tokenizer.eos_token_id,
|
| 563 |
+
self.tokenizer.tokenize("<|endoftext|>")[0], # for Qwen
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
self.vocab_size = len(self.tokenizer)
|
| 567 |
+
self.update_vocab_size = lambda: setattr(self, "vocab_size", len(self.tokenizer))
|
| 568 |
+
# XGrammar tokenizer and grammar compiler
|
| 569 |
+
# lazy init only when specified json output during inference
|
| 570 |
+
self.grammar_compiler = None
|
| 571 |
+
# self.llm.resize_token_embeddings(len(self.tokenizer))
|
| 572 |
+
return self.llm, self.tokenizer
|
| 573 |
+
|
| 574 |
+
def post_config(self):
|
| 575 |
+
self.training = self.llm.training
|
| 576 |
+
if self.training:
|
| 577 |
+
self.train()
|
| 578 |
+
else:
|
| 579 |
+
self.eval()
|
| 580 |
+
|
| 581 |
+
# configuration
|
| 582 |
+
if getattr(self.config, "llm_cfg", None) is None:
|
| 583 |
+
self.config.llm_cfg = self.llm.config
|
| 584 |
+
if getattr(self.config, "vision_tower_cfg", None) is None:
|
| 585 |
+
self.config.vision_tower_cfg = self.vision_tower.config
|
| 586 |
+
if getattr(self.config, "mm_projector_cfg", None) is None:
|
| 587 |
+
self.config.mm_projector_cfg = self.mm_projector.config
|
| 588 |
+
if getattr(self.config, "speech_tower_cfg", None) is None and hasattr(self, "speech_tower"):
|
| 589 |
+
self.config.speech_tower_cfg = self.speech_tower.config
|
| 590 |
+
if getattr(self.config, "sound_tower_cfg", None) is None and hasattr(self, "sound_tower"):
|
| 591 |
+
self.config.sound_tower_cfg = self.sound_tower.config
|
| 592 |
+
if getattr(self.config, "speech_mm_projector_cfg", None) is None and hasattr(self, "speech_mm_projector"):
|
| 593 |
+
self.config.speech_mm_projector_cfg = self.speech_mm_projector.config
|
| 594 |
+
if getattr(self.config, "sound_mm_projector_cfg", None) is None and hasattr(self, "sound_mm_projector"):
|
| 595 |
+
self.config.sound_mm_projector_cfg = self.sound_mm_projector.config
|
| 596 |
+
|
| 597 |
+
def get_llm(self):
|
| 598 |
+
llm = getattr(self, "llm", None)
|
| 599 |
+
if type(llm) is list:
|
| 600 |
+
llm = llm[0]
|
| 601 |
+
return llm
|
| 602 |
+
|
| 603 |
+
def get_lm_head(self):
|
| 604 |
+
lm_head = getattr(self.get_llm(), "lm_head", None)
|
| 605 |
+
return lm_head
|
| 606 |
+
|
| 607 |
+
def get_vision_tower(self):
|
| 608 |
+
vision_tower = getattr(self, "vision_tower", None)
|
| 609 |
+
if type(vision_tower) is list:
|
| 610 |
+
vision_tower = vision_tower[0]
|
| 611 |
+
return vision_tower
|
| 612 |
+
|
| 613 |
+
def get_speech_tower(self):
|
| 614 |
+
speech_tower = getattr(self, "speech_tower", None)
|
| 615 |
+
if type(speech_tower) is list:
|
| 616 |
+
speech_tower = speech_tower[0]
|
| 617 |
+
return speech_tower
|
| 618 |
+
|
| 619 |
+
def get_sound_tower(self):
|
| 620 |
+
sound_tower = getattr(self, "sound_tower", None)
|
| 621 |
+
if type(sound_tower) is list:
|
| 622 |
+
sound_tower = sound_tower[0]
|
| 623 |
+
return sound_tower
|
| 624 |
+
|
| 625 |
+
def get_mm_projector(self):
|
| 626 |
+
mm_projector = getattr(self, "mm_projector", None)
|
| 627 |
+
if type(mm_projector) is list:
|
| 628 |
+
mm_projector = mm_projector[0]
|
| 629 |
+
return mm_projector
|
| 630 |
+
|
| 631 |
+
def get_sound_mm_projector(self):
|
| 632 |
+
sound_mm_projector = getattr(self, "sound_mm_projector", None)
|
| 633 |
+
if type(sound_mm_projector) is list:
|
| 634 |
+
sound_mm_projector = sound_mm_projector[0]
|
| 635 |
+
return sound_mm_projector
|
| 636 |
+
|
| 637 |
+
def get_speech_tower(self):
|
| 638 |
+
speech_tower = getattr(self, "speech_tower", None)
|
| 639 |
+
if type(speech_tower) is list:
|
| 640 |
+
speech_tower = speech_tower[0]
|
| 641 |
+
return speech_tower
|
| 642 |
+
|
| 643 |
+
def get_speech_mm_projector(self):
|
| 644 |
+
speech_mm_projector = getattr(self, "speech_mm_projector", None)
|
| 645 |
+
if type(speech_mm_projector) is list:
|
| 646 |
+
speech_mm_projector = speech_mm_projector[0]
|
| 647 |
+
return speech_mm_projector
|
| 648 |
+
|
| 649 |
+
def freezed_module_patch(self):
|
| 650 |
+
"""
|
| 651 |
+
Huggingface will call model.train() at each training_step. To ensure the expected behaviors for modules like dropout, batchnorm, etc., we need to call model.eval() for the freezed modules.
|
| 652 |
+
"""
|
| 653 |
+
if self.training:
|
| 654 |
+
if self.get_llm() and not getattr(self.config, "tune_language_model", False):
|
| 655 |
+
pass
|
| 656 |
+
if self.get_vision_tower() and not getattr(self.config, "tune_vision_tower", False):
|
| 657 |
+
self.get_vision_tower().eval()
|
| 658 |
+
if self.get_speech_tower() and not getattr(self.config, "tune_speech_tower", False):
|
| 659 |
+
self.get_speech_tower().eval()
|
| 660 |
+
if self.get_sound_tower() and not getattr(self.config, "tune_sound_tower", False):
|
| 661 |
+
self.get_sound_tower().eval()
|
| 662 |
+
if self.get_mm_projector() and not getattr(self.config, "tune_mm_projector", False):
|
| 663 |
+
self.get_mm_projector().eval()
|
| 664 |
+
if self.get_speech_mm_projector() and not getattr(self.config, "tune_speech_mm_projector", False):
|
| 665 |
+
self.get_speech_mm_projector().eval()
|
| 666 |
+
if self.get_sound_mm_projector() and not getattr(self.config, "tune_sound_mm_projector", False):
|
| 667 |
+
self.get_sound_mm_projector().eval()
|
| 668 |
+
|
| 669 |
+
|
| 670 |
+
class VILAForCausalLM(VILAPretrainedModel):
|
| 671 |
+
def __init__(self, config: VILAConfig, *args, **kwargs):
|
| 672 |
+
super().__init__(config, *args, **kwargs)
|
| 673 |
+
|
| 674 |
+
def merge_features_for_dynamic_s2(self, image_features, block_sizes):
|
| 675 |
+
scales = self.get_vision_tower().scales
|
| 676 |
+
resize_output_to_scale_idx = self.get_vision_tower().resize_output_to_scale_idx
|
| 677 |
+
|
| 678 |
+
image_features_each_image = []
|
| 679 |
+
new_block_sizes = []
|
| 680 |
+
block_cnt = 0
|
| 681 |
+
for block_size_each_image in block_sizes:
|
| 682 |
+
if block_size_each_image is None:
|
| 683 |
+
cur_features = image_features[block_cnt : block_cnt + 1]
|
| 684 |
+
cur_features = rearrange(cur_features, "1 (h w) c -> 1 c h w", h=int(cur_features.shape[1] ** 0.5))
|
| 685 |
+
cur_features = cur_features.repeat(1, len(scales), 1, 1)
|
| 686 |
+
image_features_each_image.append(cur_features)
|
| 687 |
+
new_block_sizes.append((1, 1))
|
| 688 |
+
block_cnt += 1
|
| 689 |
+
else:
|
| 690 |
+
cur_features_each_scale = []
|
| 691 |
+
for scale in scales[:-1]:
|
| 692 |
+
num_blocks_this_scale = (scale // scales[0]) ** 2
|
| 693 |
+
cur_features_each_scale.append(
|
| 694 |
+
self.merge_chessboard(
|
| 695 |
+
image_features[block_cnt : block_cnt + num_blocks_this_scale],
|
| 696 |
+
num_split_h=scale // scales[0],
|
| 697 |
+
num_split_w=scale // scales[0],
|
| 698 |
+
)
|
| 699 |
+
) # 1 * C * H * W
|
| 700 |
+
block_cnt += num_blocks_this_scale
|
| 701 |
+
num_blocks_last_scale = block_size_each_image[0] * block_size_each_image[1]
|
| 702 |
+
cur_features_each_scale.append(
|
| 703 |
+
self.merge_chessboard(
|
| 704 |
+
image_features[block_cnt : block_cnt + num_blocks_last_scale],
|
| 705 |
+
num_split_h=block_size_each_image[0],
|
| 706 |
+
num_split_w=block_size_each_image[1],
|
| 707 |
+
)
|
| 708 |
+
) # 1 * C * H * W
|
| 709 |
+
block_cnt += num_blocks_last_scale
|
| 710 |
+
|
| 711 |
+
# resize and concat features from different scales
|
| 712 |
+
output_size = cur_features_each_scale[resize_output_to_scale_idx].shape[-2:]
|
| 713 |
+
cur_features = torch.cat(
|
| 714 |
+
[
|
| 715 |
+
F.interpolate(cur_features_each_scale[i].to(torch.float32), size=output_size, mode="area").to(
|
| 716 |
+
cur_features_each_scale[i].dtype
|
| 717 |
+
)
|
| 718 |
+
for i in range(len(cur_features_each_scale))
|
| 719 |
+
],
|
| 720 |
+
dim=1,
|
| 721 |
+
)
|
| 722 |
+
|
| 723 |
+
image_features_each_image.append(cur_features)
|
| 724 |
+
|
| 725 |
+
if resize_output_to_scale_idx == len(scales) - 1 or resize_output_to_scale_idx == -1:
|
| 726 |
+
new_block_sizes.append(block_size_each_image)
|
| 727 |
+
else:
|
| 728 |
+
new_block_sizes.append(
|
| 729 |
+
(
|
| 730 |
+
scales[resize_output_to_scale_idx] // scales[0],
|
| 731 |
+
scales[resize_output_to_scale_idx] // scales[0],
|
| 732 |
+
)
|
| 733 |
+
)
|
| 734 |
+
|
| 735 |
+
assert block_cnt == len(image_features)
|
| 736 |
+
|
| 737 |
+
return image_features_each_image, new_block_sizes
|
| 738 |
+
|
| 739 |
+
@staticmethod
|
| 740 |
+
def split_chessboard(x, num_split_h, num_split_w):
|
| 741 |
+
"""
|
| 742 |
+
x: b * c * h * w
|
| 743 |
+
out: b * c * h * w
|
| 744 |
+
Deividing x into num_split**2 sub-squares, and concatenate all the sub-squares on the batch dimension
|
| 745 |
+
"""
|
| 746 |
+
B, C, H, W = x.shape
|
| 747 |
+
assert H % num_split_h == 0 and W % num_split_w == 0
|
| 748 |
+
h, w = H // num_split_h, W // num_split_w
|
| 749 |
+
x_split = torch.cat(
|
| 750 |
+
[x[:, :, i * h : (i + 1) * h, j * w : (j + 1) * w] for i in range(num_split_h) for j in range(num_split_w)],
|
| 751 |
+
dim=0,
|
| 752 |
+
)
|
| 753 |
+
return x_split
|
| 754 |
+
|
| 755 |
+
@staticmethod
|
| 756 |
+
def merge_chessboard(x, num_split_h, num_split_w):
|
| 757 |
+
"""
|
| 758 |
+
x: b * n * c or b * h * w * c
|
| 759 |
+
out: b * c * h * w
|
| 760 |
+
Assuming x contains num_split**2 sub-squares concatenated along batch dimension, merge the sub-squares back to the original whole square.
|
| 761 |
+
"""
|
| 762 |
+
B = x.shape[0]
|
| 763 |
+
if x.dim() == 3:
|
| 764 |
+
N = x.shape[1]
|
| 765 |
+
x = rearrange(x, "b (h w) c -> b c h w", h=int(N**0.5), w=int(N**0.5))
|
| 766 |
+
|
| 767 |
+
assert B % (num_split_h * num_split_w) == 0
|
| 768 |
+
b = B // (num_split_h * num_split_w)
|
| 769 |
+
|
| 770 |
+
x_merge = torch.cat(
|
| 771 |
+
[
|
| 772 |
+
torch.cat(
|
| 773 |
+
[x[(i * num_split_w + j) * b : (i * num_split_w + j + 1) * b] for j in range(num_split_w)], dim=-1
|
| 774 |
+
)
|
| 775 |
+
for i in range(num_split_h)
|
| 776 |
+
],
|
| 777 |
+
dim=-2,
|
| 778 |
+
)
|
| 779 |
+
|
| 780 |
+
return x_merge
|
| 781 |
+
|
| 782 |
+
def encode_video(self, inp, block_sizes: Optional[Optional[Tuple[int, ...]]] = None, mm_info: Optional[dict] = None, num_frames: Optional[List[int]] = None):
|
| 783 |
+
bs = len(inp)
|
| 784 |
+
cache_feas = []
|
| 785 |
+
cache_feas_index = []
|
| 786 |
+
inp_block_sizes = block_sizes
|
| 787 |
+
|
| 788 |
+
# handle cache features
|
| 789 |
+
for _idx in range(len(inp)):
|
| 790 |
+
if type(inp[_idx]) == CacheFeatures:
|
| 791 |
+
cache_feas.append(inp[_idx])
|
| 792 |
+
cache_feas_index.append(_idx)
|
| 793 |
+
raw_images = [_ for _ in inp if type(_) != CacheFeatures]
|
| 794 |
+
|
| 795 |
+
raw_videos_num_frames = [_.shape[0] for _ in raw_images]
|
| 796 |
+
if len(raw_images) > 0:
|
| 797 |
+
images = torch.cat(raw_images, dim=0)
|
| 798 |
+
else:
|
| 799 |
+
images = []
|
| 800 |
+
|
| 801 |
+
if block_sizes is None:
|
| 802 |
+
block_sizes = [None] * len(images)
|
| 803 |
+
|
| 804 |
+
def _load_video_features(image_features, cache_feas, cache_feas_index, raw_videos_num_frames):
|
| 805 |
+
# load cache features
|
| 806 |
+
if len(cache_feas) > 0:
|
| 807 |
+
if len(image_features) > 0:
|
| 808 |
+
image_features = torch.split(image_features, raw_videos_num_frames)
|
| 809 |
+
new_image_features = []
|
| 810 |
+
cache_feas_idx = 0
|
| 811 |
+
raw_fea_idx = 0
|
| 812 |
+
for _idx in range(bs):
|
| 813 |
+
if _idx in cache_feas_index:
|
| 814 |
+
new_image_features.append(cache_feas[cache_feas_idx].value['features'].to(self.device, self.dtype))
|
| 815 |
+
cache_feas_idx += 1
|
| 816 |
+
else:
|
| 817 |
+
new_image_features.append(image_features[raw_fea_idx])
|
| 818 |
+
raw_fea_idx += 1
|
| 819 |
+
|
| 820 |
+
assert len(new_image_features) == bs
|
| 821 |
+
image_features = new_image_features
|
| 822 |
+
image_features = torch.cat(image_features, dim=0)
|
| 823 |
+
return image_features
|
| 824 |
+
|
| 825 |
+
if getattr(self.config, "dynamic_s2", False):
|
| 826 |
+
|
| 827 |
+
if len(images) > 0:
|
| 828 |
+
image_features = self.get_vision_tower()(images)
|
| 829 |
+
|
| 830 |
+
image_features, new_block_sizes = self.merge_features_for_dynamic_s2(image_features, block_sizes)
|
| 831 |
+
|
| 832 |
+
image_features = [
|
| 833 |
+
self.split_chessboard(x, block_size[0], block_size[1])
|
| 834 |
+
for x, block_size in zip(image_features, new_block_sizes)
|
| 835 |
+
] # list of B * C * H * W tensors
|
| 836 |
+
image_features = torch.cat(
|
| 837 |
+
[rearrange(x, "b c h w -> b (h w) c") for x in image_features], dim=0
|
| 838 |
+
) # B * N * C
|
| 839 |
+
else:
|
| 840 |
+
image_features = []
|
| 841 |
+
|
| 842 |
+
# load cache features
|
| 843 |
+
image_features = _load_video_features(image_features, cache_feas, cache_feas_index, raw_videos_num_frames)
|
| 844 |
+
|
| 845 |
+
# if hasattr(self.config, "save_data") and self.config.save_data and num_frames is not None: # video
|
| 846 |
+
# _save_video_features(image_features, mm_info, inp)
|
| 847 |
+
|
| 848 |
+
if inp_block_sizes is None:
|
| 849 |
+
new_block_sizes = [(1, 1)] * len(image_features)
|
| 850 |
+
else:
|
| 851 |
+
raise ValueError(f"inp_block_sizes is not None: {inp_block_sizes}")
|
| 852 |
+
image_features = image_features.to(self.device, self.dtype)
|
| 853 |
+
image_features = self.get_mm_projector()(image_features)
|
| 854 |
+
image_features = list(
|
| 855 |
+
image_features.split([block_size[0] * block_size[1] for block_size in new_block_sizes], dim=0)
|
| 856 |
+
)
|
| 857 |
+
image_features = [
|
| 858 |
+
self.merge_chessboard(x, block_size[0], block_size[1])
|
| 859 |
+
for x, block_size in zip(image_features, new_block_sizes)
|
| 860 |
+
] # list of 1 * C * H * W tensors
|
| 861 |
+
image_features = [rearrange(x, "1 c h w -> (h w) c") for x in image_features] # list of N * C tensors
|
| 862 |
+
if all([feature.shape[0] == image_features[0].shape[0] for feature in image_features]):
|
| 863 |
+
image_features = torch.stack(image_features, dim=0)
|
| 864 |
+
else:
|
| 865 |
+
if len(images) > 0:
|
| 866 |
+
image_features = self.get_vision_tower()(images)
|
| 867 |
+
else:
|
| 868 |
+
image_features = []
|
| 869 |
+
|
| 870 |
+
# load cache features
|
| 871 |
+
image_features = _load_video_features(image_features, cache_feas, cache_feas_index, raw_videos_num_frames)
|
| 872 |
+
|
| 873 |
+
image_features = self.get_mm_projector()(image_features)
|
| 874 |
+
return image_features
|
| 875 |
+
|
| 876 |
+
def encode_images(self, images, block_sizes: Optional[Optional[Tuple[int, ...]]] = None, mm_info: Optional[dict] = None, num_frames: Optional[List[int]] = None):
|
| 877 |
+
if block_sizes is None:
|
| 878 |
+
block_sizes = [None] * len(images)
|
| 879 |
+
|
| 880 |
+
if getattr(self.config, "dynamic_s2", False):
|
| 881 |
+
image_features = self.get_vision_tower()(images)
|
| 882 |
+
|
| 883 |
+
image_features, new_block_sizes = self.merge_features_for_dynamic_s2(image_features, block_sizes)
|
| 884 |
+
|
| 885 |
+
image_features = [
|
| 886 |
+
self.split_chessboard(x, block_size[0], block_size[1])
|
| 887 |
+
for x, block_size in zip(image_features, new_block_sizes)
|
| 888 |
+
] # list of B * C * H * W tensors
|
| 889 |
+
image_features = torch.cat(
|
| 890 |
+
[rearrange(x, "b c h w -> b (h w) c") for x in image_features], dim=0
|
| 891 |
+
) # B * N * C
|
| 892 |
+
|
| 893 |
+
image_features = self.get_mm_projector()(image_features)
|
| 894 |
+
image_features = list(
|
| 895 |
+
image_features.split([block_size[0] * block_size[1] for block_size in new_block_sizes], dim=0)
|
| 896 |
+
)
|
| 897 |
+
image_features = [
|
| 898 |
+
self.merge_chessboard(x, block_size[0], block_size[1])
|
| 899 |
+
for x, block_size in zip(image_features, new_block_sizes)
|
| 900 |
+
] # list of 1 * C * H * W tensors
|
| 901 |
+
image_features = [rearrange(x, "1 c h w -> (h w) c") for x in image_features] # list of N * C tensors
|
| 902 |
+
if all([feature.shape[0] == image_features[0].shape[0] for feature in image_features]):
|
| 903 |
+
image_features = torch.stack(image_features, dim=0)
|
| 904 |
+
else:
|
| 905 |
+
image_features = self.get_vision_tower()(images)
|
| 906 |
+
|
| 907 |
+
image_features = self.get_mm_projector()(image_features)
|
| 908 |
+
return image_features
|
| 909 |
+
|
| 910 |
+
def encode_sound(self, sounds, mm_info: Optional[dict] = None):
|
| 911 |
+
|
| 912 |
+
audio_features, audio_output_lengths = self.get_sound_tower()(sounds)
|
| 913 |
+
|
| 914 |
+
use_fea_downsample = False
|
| 915 |
+
if getattr(self.config, "sound_mm_projector", "") != "":
|
| 916 |
+
if "mlp_downsample" in getattr(self.config, "sound_mm_projector", ""):
|
| 917 |
+
use_fea_downsample = True
|
| 918 |
+
else:
|
| 919 |
+
sound_mm_projector_cfg = getattr(self.config, "sound_mm_projector_cfg", None)
|
| 920 |
+
if sound_mm_projector_cfg is not None:
|
| 921 |
+
if type(sound_mm_projector_cfg) == dict:
|
| 922 |
+
if "mlp_downsample" in sound_mm_projector_cfg["sound_mm_projector_type"]:
|
| 923 |
+
use_fea_downsample = True
|
| 924 |
+
elif type(sound_mm_projector_cfg) == SoundMultimodalProjectorConfig:
|
| 925 |
+
if "mlp_downsample" in sound_mm_projector_cfg.sound_mm_projector_type:
|
| 926 |
+
use_fea_downsample = True
|
| 927 |
+
|
| 928 |
+
if not use_fea_downsample:
|
| 929 |
+
audio_features = self.get_sound_mm_projector()(audio_features)
|
| 930 |
+
|
| 931 |
+
if audio_output_lengths is not None:
|
| 932 |
+
# split the batch
|
| 933 |
+
new_audio_features = []
|
| 934 |
+
start = 0
|
| 935 |
+
for length in audio_output_lengths:
|
| 936 |
+
new_audio_features.append(audio_features[start : start + length])
|
| 937 |
+
start += length
|
| 938 |
+
audio_features = new_audio_features
|
| 939 |
+
|
| 940 |
+
if use_fea_downsample:
|
| 941 |
+
audio_features = torch.stack(audio_features, dim=0)
|
| 942 |
+
audio_features = self.get_sound_mm_projector()(audio_features)
|
| 943 |
+
|
| 944 |
+
return audio_features
|
| 945 |
+
|
| 946 |
+
def train(self, mode: bool = True):
|
| 947 |
+
super().train(mode)
|
| 948 |
+
return self
|
| 949 |
+
|
| 950 |
+
def _embed(
|
| 951 |
+
self,
|
| 952 |
+
input_ids: torch.Tensor,
|
| 953 |
+
media: Dict[str, List[torch.Tensor]],
|
| 954 |
+
media_config: Dict[str, Dict[str, Any]],
|
| 955 |
+
labels: Optional[torch.Tensor],
|
| 956 |
+
attention_mask: Optional[torch.Tensor],
|
| 957 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 958 |
+
media = copy.deepcopy(media)
|
| 959 |
+
media_config = copy.deepcopy(media_config)
|
| 960 |
+
|
| 961 |
+
labels = labels if labels is not None else torch.full_like(input_ids, IGNORE_INDEX)
|
| 962 |
+
attention_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids, dtype=torch.bool)
|
| 963 |
+
|
| 964 |
+
PROCESS_GROUP_MANAGER = None
|
| 965 |
+
if PROCESS_GROUP_MANAGER is not None:
|
| 966 |
+
for name in media:
|
| 967 |
+
self.encoders[name].end_tokens = None
|
| 968 |
+
|
| 969 |
+
# Extract text and media embeddings
|
| 970 |
+
text_embeds = self.llm_model_embed_tokens(input_ids)
|
| 971 |
+
|
| 972 |
+
mm_info = {}
|
| 973 |
+
if "video_info" in media:
|
| 974 |
+
video_info = media["video_info"]
|
| 975 |
+
del media["video_info"]
|
| 976 |
+
mm_info['video_info'] = video_info
|
| 977 |
+
else:
|
| 978 |
+
video_info = None
|
| 979 |
+
|
| 980 |
+
if "audio_info" in media:
|
| 981 |
+
audio_info = media["audio_info"]
|
| 982 |
+
del media["audio_info"]
|
| 983 |
+
mm_info['audio_info'] = audio_info
|
| 984 |
+
else:
|
| 985 |
+
audio_info = None
|
| 986 |
+
|
| 987 |
+
if media is not None:
|
| 988 |
+
media_embeds = self.__embed_media_tokens(media, media_config, mm_info)
|
| 989 |
+
else:
|
| 990 |
+
# no media was provided, so we just return an empty dict
|
| 991 |
+
media_embeds = {}
|
| 992 |
+
|
| 993 |
+
if PROCESS_GROUP_MANAGER is not None:
|
| 994 |
+
media_embeds_video = []
|
| 995 |
+
for i, images in enumerate(media_embeds["video"]):
|
| 996 |
+
num_video_frame = media["video"][i].shape[0]
|
| 997 |
+
media_embeds_video += torch.unbind(images.reshape(num_video_frame, -1, images.shape[-1]))
|
| 998 |
+
media_embeds["video"] = deque(media_embeds_video)
|
| 999 |
+
|
| 1000 |
+
# This is a workaround to make sure the dummy embeddings are consumed
|
| 1001 |
+
while media_embeds.get("dummy"):
|
| 1002 |
+
dummy_embed = media_embeds["dummy"].popleft()
|
| 1003 |
+
text_embeds += torch.sum(dummy_embed) * 0
|
| 1004 |
+
|
| 1005 |
+
# Based on segment_aud_indices_list and segment_vis_indices_list, get interleaved vis-aud embeddings for video
|
| 1006 |
+
video_sound_embeds_idx = 0
|
| 1007 |
+
sep_embed = self.encoders["video"].embed_tokens("\n")
|
| 1008 |
+
text_embeds = text_embeds.to(self.dtype)
|
| 1009 |
+
sep_embed = sep_embed.to(text_embeds.dtype)
|
| 1010 |
+
|
| 1011 |
+
if video_info is not None and self.config.load_audio_in_video and self.config.interleaved_vis_aud_in_video:
|
| 1012 |
+
assert self.encoders["video"].end_tokens is None, "end_tokens must be None for interleaved vis-aud in video"
|
| 1013 |
+
new_video_embeds = deque()
|
| 1014 |
+
video_embeds_idx = 0
|
| 1015 |
+
for k in range(len(video_info)):
|
| 1016 |
+
if video_info[k] is None:
|
| 1017 |
+
continue
|
| 1018 |
+
for i in range(len(video_info[k])):
|
| 1019 |
+
has_audio = video_info[k][i]["has_audio"]
|
| 1020 |
+
if not has_audio:
|
| 1021 |
+
new_video_embeds.append(media_embeds["video"][video_embeds_idx])
|
| 1022 |
+
video_embeds_idx += 1
|
| 1023 |
+
continue
|
| 1024 |
+
|
| 1025 |
+
# Check bounds for sound embeddings
|
| 1026 |
+
if video_sound_embeds_idx >= len(media_embeds["sound"]):
|
| 1027 |
+
raise ValueError(f"Sound embeddings index {video_sound_embeds_idx} out of bounds for video_info[{k}][{i}]")
|
| 1028 |
+
|
| 1029 |
+
segment_aud_indices_list = video_info[k][i]["segment_aud_indices_list"]
|
| 1030 |
+
segment_vis_indices_list = video_info[k][i]["segment_vis_indices_list"]
|
| 1031 |
+
|
| 1032 |
+
vis_fea_len_per_frame = media_embeds["video"][video_embeds_idx].shape[0] / video_info[k][i]["expected_frame_count"]
|
| 1033 |
+
aud_fea_len_per_stft_frame = media_embeds["sound"][video_sound_embeds_idx].shape[0] / audio_info[k][i]["new_audio_n_stft_frames"]
|
| 1034 |
+
vis_end = 0
|
| 1035 |
+
aud_end = 0
|
| 1036 |
+
_new_video_embed = []
|
| 1037 |
+
for j in range(len(segment_vis_indices_list)):
|
| 1038 |
+
_vis_aud_fea = []
|
| 1039 |
+
if len(segment_vis_indices_list[j]) > 0:
|
| 1040 |
+
_new_frames = [int(np.ceil((_frame+1) * vis_fea_len_per_frame)) for _frame in segment_vis_indices_list[j]]
|
| 1041 |
+
_vis_fea_end = _new_frames[-1]
|
| 1042 |
+
# Ensure we don't exceed the available features
|
| 1043 |
+
_vis_fea_end = min(_vis_fea_end, media_embeds["video"][video_embeds_idx].shape[0])
|
| 1044 |
+
if j == len(segment_vis_indices_list) - 1 and i == len(video_info) - 1 and k == len(video_info[i]) - 1 and not _vis_fea_end == media_embeds["video"][video_embeds_idx].shape[0]:
|
| 1045 |
+
print(f"Warning: The number of last interleaved video features does not match the video feature length. Expected: {media_embeds['video'][video_embeds_idx].shape[0]}, Got: {_vis_fea_end}")
|
| 1046 |
+
_vis_fea_end = media_embeds["video"][video_embeds_idx].shape[0]
|
| 1047 |
+
_vis_fea = media_embeds["video"][video_embeds_idx][vis_end:_vis_fea_end]
|
| 1048 |
+
vis_end = _vis_fea_end
|
| 1049 |
+
_vis_aud_fea.append(_vis_fea)
|
| 1050 |
+
_vis_aud_fea.append(sep_embed)
|
| 1051 |
+
if len(segment_aud_indices_list[j]) > 0:
|
| 1052 |
+
_new_audio_indices = [int(np.ceil(_fea * aud_fea_len_per_stft_frame)) for _fea in segment_aud_indices_list[j]]
|
| 1053 |
+
_aud_fea_end = _new_audio_indices[-1]
|
| 1054 |
+
# Ensure we don't exceed the available features
|
| 1055 |
+
_aud_fea_end = min(_aud_fea_end, media_embeds["sound"][video_sound_embeds_idx].shape[0])
|
| 1056 |
+
_aud_fea = media_embeds["sound"][video_sound_embeds_idx][aud_end:_aud_fea_end]
|
| 1057 |
+
_vis_aud_fea.append(_aud_fea)
|
| 1058 |
+
aud_end = _aud_fea_end
|
| 1059 |
+
_vis_aud_fea.append(sep_embed)
|
| 1060 |
+
_new_video_embed.append(torch.cat(_vis_aud_fea, dim=0))
|
| 1061 |
+
video_sound_embeds_idx += 1
|
| 1062 |
+
new_video_embeds.append(torch.cat(_new_video_embed, dim=0))
|
| 1063 |
+
video_embeds_idx += 1
|
| 1064 |
+
|
| 1065 |
+
assert len(new_video_embeds) == len(media_embeds["video"]), "The number of new video embeddings does not match the number of original video embeddings."
|
| 1066 |
+
media_embeds["video"] = new_video_embeds
|
| 1067 |
+
# Remove padding
|
| 1068 |
+
batch_size = labels.shape[0]
|
| 1069 |
+
text_embeds = [text_embeds[k][attention_mask[k]] for k in range(batch_size)]
|
| 1070 |
+
labels = [labels[k][attention_mask[k]] for k in range(batch_size)]
|
| 1071 |
+
# Build inverse mapping from token ID to media name
|
| 1072 |
+
media_tokens = {}
|
| 1073 |
+
for name, token_id in self.tokenizer.media_token_ids.items():
|
| 1074 |
+
media_tokens[token_id] = name
|
| 1075 |
+
|
| 1076 |
+
# Fuse text and media embeddings
|
| 1077 |
+
inputs_m, labels_m = [], []
|
| 1078 |
+
sound_embeds_idx = 0
|
| 1079 |
+
for k in range(batch_size):
|
| 1080 |
+
inputs_mk, labels_mk = [], []
|
| 1081 |
+
pos = 0
|
| 1082 |
+
while pos < len(labels[k]):
|
| 1083 |
+
if input_ids[k][pos].item() in media_tokens:
|
| 1084 |
+
name = media_tokens[input_ids[k][pos].item()] if PROCESS_GROUP_MANAGER is None else "video"
|
| 1085 |
+
if input_ids[k][pos].item() == self.tokenizer.media_token_ids["sound"]:
|
| 1086 |
+
if self.config.interleaved_vis_aud_in_video:
|
| 1087 |
+
if sound_embeds_idx < video_sound_embeds_idx:
|
| 1088 |
+
media_embeds[name].popleft()
|
| 1089 |
+
sound_embeds_idx += 1
|
| 1090 |
+
pos += 1
|
| 1091 |
+
continue
|
| 1092 |
+
sound_embeds_idx += 1
|
| 1093 |
+
|
| 1094 |
+
end = pos + 1
|
| 1095 |
+
input = media_embeds[name].popleft()
|
| 1096 |
+
label = torch.full([input.shape[0]], IGNORE_INDEX, device=labels[k].device, dtype=labels[k].dtype)
|
| 1097 |
+
else:
|
| 1098 |
+
end = pos
|
| 1099 |
+
while end < len(labels[k]) and input_ids[k][end].item() not in media_tokens:
|
| 1100 |
+
end += 1
|
| 1101 |
+
input = text_embeds[k][pos:end]
|
| 1102 |
+
label = labels[k][pos:end]
|
| 1103 |
+
|
| 1104 |
+
inputs_mk.append(input)
|
| 1105 |
+
labels_mk.append(label)
|
| 1106 |
+
pos = end
|
| 1107 |
+
inputs_m.append(torch.cat(inputs_mk, dim=0))
|
| 1108 |
+
labels_m.append(torch.cat(labels_mk, dim=0))
|
| 1109 |
+
inputs, labels = inputs_m, labels_m
|
| 1110 |
+
|
| 1111 |
+
inputs[0] += sep_embed.mean() * 0 # dummy embedding
|
| 1112 |
+
# Check if all media embeddings are consumed
|
| 1113 |
+
|
| 1114 |
+
for name in media_embeds:
|
| 1115 |
+
if media_embeds[name]:
|
| 1116 |
+
raise ValueError(f"Not all {name} embeddings are consumed! Still {len(media_embeds[name])} left.")
|
| 1117 |
+
|
| 1118 |
+
# Truncate sequences to `model_max_length` as media embeddings are inserted
|
| 1119 |
+
inputs, labels = self.__truncate_sequence(inputs, labels)
|
| 1120 |
+
|
| 1121 |
+
# Pad sequences to the longest one in the batch
|
| 1122 |
+
return self.__batchify_sequence(inputs, labels)
|
| 1123 |
+
|
| 1124 |
+
def __embed_media_tokens(
|
| 1125 |
+
self,
|
| 1126 |
+
media: Dict[str, List[torch.Tensor]],
|
| 1127 |
+
media_config: Dict[str, Dict[str, Any]],
|
| 1128 |
+
mm_info,
|
| 1129 |
+
) -> Dict[str, List[torch.Tensor]]:
|
| 1130 |
+
embeds = defaultdict(deque)
|
| 1131 |
+
|
| 1132 |
+
if self.config.unified_audio_encoder:
|
| 1133 |
+
assert len(media["speech"]) == 0
|
| 1134 |
+
|
| 1135 |
+
for name in media:
|
| 1136 |
+
_encoder = self.encoders[name]
|
| 1137 |
+
if name in ["speech", "sound"] and self.config.unified_audio_encoder:
|
| 1138 |
+
_encoder = self.encoders["sound"]
|
| 1139 |
+
|
| 1140 |
+
if self.training:
|
| 1141 |
+
# Gather metainfo of media objects from all ranks
|
| 1142 |
+
if name in ["speech", "sound"]:
|
| 1143 |
+
|
| 1144 |
+
info = []
|
| 1145 |
+
if type(media.get(name, {})) is dict:
|
| 1146 |
+
for _dict in media.get(name, {}):
|
| 1147 |
+
info.append({k: {"shape": v.shape, "dtype": v.dtype} for k, v in _dict.items()})
|
| 1148 |
+
elif type(media.get(name, {})) is list:
|
| 1149 |
+
info = [{"shape": tensor.shape, "dtype": tensor.dtype} for tensor in media.get(name, [])]
|
| 1150 |
+
else:
|
| 1151 |
+
raise ValueError(f"Unsupported media type: {type(media.get(name, {}))}")
|
| 1152 |
+
|
| 1153 |
+
infos_list = vila_all_gather(info)
|
| 1154 |
+
infos = list(chain(*infos_list))
|
| 1155 |
+
|
| 1156 |
+
# The entire batch does not contain any media objects of this type.
|
| 1157 |
+
if not infos:
|
| 1158 |
+
continue
|
| 1159 |
+
|
| 1160 |
+
# for audio encoding, we have to ensure the batch size is the same for all ranks. If not, we need to pad the batch with dummy tensors to the max batch size
|
| 1161 |
+
max_batch_size = max(len(_info) for _info in infos_list)
|
| 1162 |
+
missing_batch_size = max_batch_size - len(info)
|
| 1163 |
+
|
| 1164 |
+
_media = media.get(name, [])
|
| 1165 |
+
|
| 1166 |
+
_medias = list(chain(vila_all_gather(_media)))
|
| 1167 |
+
if missing_batch_size > 0:
|
| 1168 |
+
for i in range(missing_batch_size):
|
| 1169 |
+
# use one of the media tensors to create a dummy tensor
|
| 1170 |
+
if type(media.get(name, {})) is dict:
|
| 1171 |
+
_dummy = {k: v.clone().to(device=self.device) for k, v in _medias[0].items()}
|
| 1172 |
+
elif type(media.get(name, {})) is list:
|
| 1173 |
+
if type(_medias[0]) is torch.Tensor:
|
| 1174 |
+
_dummy = _medias[0].clone().to(device=self.device)
|
| 1175 |
+
elif type(_medias[0]) is np.ndarray:
|
| 1176 |
+
_dummy = _medias[0].copy()
|
| 1177 |
+
else:
|
| 1178 |
+
raise ValueError(f"Unsupported media type: {type(_medias[0])}")
|
| 1179 |
+
else:
|
| 1180 |
+
raise ValueError(f"Unsupported media type: {type(media.get(name, {}))}")
|
| 1181 |
+
_media.append(_dummy)
|
| 1182 |
+
mm_info["audio_info"].append(["dummy"])
|
| 1183 |
+
|
| 1184 |
+
# we need to also align the length of all audio samples in the batch size
|
| 1185 |
+
cur_batch_max_audio_samples = max(len(_audio) for _audio in _medias)
|
| 1186 |
+
cur_batch_max_audio_samples = int(np.ceil(cur_batch_max_audio_samples / (self.config.audio_sampling_rate * 30)) * (self.config.audio_sampling_rate * 30)) # should be multiple of 30 seconds
|
| 1187 |
+
cur_batch_max_audio_samples = min(cur_batch_max_audio_samples, self.config.audio_chunk_length * self.config.audio_sampling_rate)
|
| 1188 |
+
cur_batch_max_audio_duration = cur_batch_max_audio_samples // self.config.audio_sampling_rate
|
| 1189 |
+
|
| 1190 |
+
|
| 1191 |
+
whisper_feature_extractor = WhisperFeatureExtractor.from_pretrained(
|
| 1192 |
+
self.config._name_or_path, chunk_length=cur_batch_max_audio_duration, sampling_rate=self.config.audio_sampling_rate, hop_length=self.config.audio_hop_length
|
| 1193 |
+
)
|
| 1194 |
+
|
| 1195 |
+
# use WhisperFeatureExtractor in transformers to load
|
| 1196 |
+
new_media = []
|
| 1197 |
+
|
| 1198 |
+
aud_idx = 0
|
| 1199 |
+
for _batch_idx in range(len(mm_info["audio_info"])):
|
| 1200 |
+
_audio_info = mm_info["audio_info"][_batch_idx]
|
| 1201 |
+
if _audio_info is not None:
|
| 1202 |
+
for _mm_idx in range(len(_audio_info)):
|
| 1203 |
+
_audio = _media[aud_idx]
|
| 1204 |
+
if type(_audio) is torch.Tensor:
|
| 1205 |
+
device = _audio.device
|
| 1206 |
+
dtype = _audio.dtype
|
| 1207 |
+
_audio = _audio.cpu().float()
|
| 1208 |
+
else:
|
| 1209 |
+
# logger.warning(f"The audio type is not a tensor, which is unexpected. Using the device and dtype of the model. media: {media}, mm_info: {mm_info}")
|
| 1210 |
+
device = self.device
|
| 1211 |
+
dtype = self.dtype
|
| 1212 |
+
_audio = whisper.pad_or_trim(_audio, length=cur_batch_max_audio_samples)
|
| 1213 |
+
aud_idx += 1
|
| 1214 |
+
stft_features = whisper_feature_extractor(
|
| 1215 |
+
_audio,
|
| 1216 |
+
sampling_rate=self.config.audio_sampling_rate,
|
| 1217 |
+
return_attention_mask=True,
|
| 1218 |
+
padding="max_length",
|
| 1219 |
+
return_tensors="pt",
|
| 1220 |
+
).to(device, dtype)
|
| 1221 |
+
new_media.append(stft_features)
|
| 1222 |
+
if _audio_info[_mm_idx] != "dummy":
|
| 1223 |
+
_audio_info[_mm_idx]["new_audio_chunk_length"] = cur_batch_max_audio_duration
|
| 1224 |
+
_audio_info[_mm_idx]["new_audio_n_samples"] = cur_batch_max_audio_samples
|
| 1225 |
+
_audio_info[_mm_idx]["audio_end_sample_sec"] = _audio_info[_mm_idx]["audio_start_sec"] + cur_batch_max_audio_duration
|
| 1226 |
+
_audio_info[_mm_idx]["new_audio_n_stft_frames"] = stft_features["input_features"].shape[-1]
|
| 1227 |
+
|
| 1228 |
+
assert aud_idx == len(_media), "The number of audio info does not match the number of audio samples."
|
| 1229 |
+
_media = new_media
|
| 1230 |
+
|
| 1231 |
+
_fea = _encoder(_media, media_config[name], mm_info)
|
| 1232 |
+
# [751, 1536]
|
| 1233 |
+
# consume dummy features later
|
| 1234 |
+
_dummy_fea = _fea[len(info) :]
|
| 1235 |
+
embeds["dummy"].extend(_dummy_fea)
|
| 1236 |
+
|
| 1237 |
+
# remove the dummy features
|
| 1238 |
+
_real_fea = _fea[: len(info)]
|
| 1239 |
+
if len(info) > 0:
|
| 1240 |
+
embeds[name] = deque(_real_fea)
|
| 1241 |
+
|
| 1242 |
+
else:
|
| 1243 |
+
# Gather metainfo of media objects from all ranks
|
| 1244 |
+
info = [{"shape": tensor.shape, "dtype": tensor.dtype} for tensor in media.get(name, [])]
|
| 1245 |
+
infos = list(chain(vila_all_gather(info)))
|
| 1246 |
+
|
| 1247 |
+
# The entire batch does not contain any media objects of this type.
|
| 1248 |
+
if not infos:
|
| 1249 |
+
continue
|
| 1250 |
+
|
| 1251 |
+
# Create a dummy tensor to ensure the encoder is called, otherwise the training will hang.
|
| 1252 |
+
if media.get(name) is None or len(media[name]) == 0:
|
| 1253 |
+
dummy = torch.zeros(infos[0]["shape"], dtype=infos[0]["dtype"], device=self.device)
|
| 1254 |
+
embeds["dummy"].extend(self.encoders[name]([dummy], media_config[name]))
|
| 1255 |
+
continue
|
| 1256 |
+
embeds[name] = deque(self.encoders[name](media[name], media_config[name]))
|
| 1257 |
+
|
| 1258 |
+
else:
|
| 1259 |
+
if name == "sound":
|
| 1260 |
+
all_audio_chunk_lengths = []
|
| 1261 |
+
for _sample_idx in range(len(media[name])):
|
| 1262 |
+
for _mm_idx in range(len(mm_info["audio_info"][_sample_idx])):
|
| 1263 |
+
_new_audio_chunk_length = mm_info["audio_info"][_sample_idx][_mm_idx]["new_audio_chunk_length"]
|
| 1264 |
+
all_audio_chunk_lengths.append(_new_audio_chunk_length)
|
| 1265 |
+
cur_batch_max_audio_duration = max(all_audio_chunk_lengths)
|
| 1266 |
+
cur_batch_max_audio_samples = cur_batch_max_audio_duration * self.config.audio_sampling_rate
|
| 1267 |
+
# for qwen omni audio
|
| 1268 |
+
# cur_batch_max_audio_samples = 960000
|
| 1269 |
+
|
| 1270 |
+
whisper_feature_extractor = WhisperFeatureExtractor.from_pretrained(
|
| 1271 |
+
self.config._name_or_path, chunk_length=cur_batch_max_audio_duration, sampling_rate=self.config.audio_sampling_rate, hop_length=self.config.audio_hop_length
|
| 1272 |
+
)
|
| 1273 |
+
|
| 1274 |
+
new_media = []
|
| 1275 |
+
_idx = 0
|
| 1276 |
+
assert len(all_audio_chunk_lengths) == len(media[name]), "The number of audio chunk lengths does not match the number of audio samples."
|
| 1277 |
+
|
| 1278 |
+
_media = media.get(name, [])
|
| 1279 |
+
aud_idx = 0
|
| 1280 |
+
for _batch_idx in range(len(mm_info["audio_info"])):
|
| 1281 |
+
_audio_info = mm_info["audio_info"][_batch_idx]
|
| 1282 |
+
if _audio_info is not None:
|
| 1283 |
+
for _mm_idx in range(len(_audio_info)):
|
| 1284 |
+
_audio = _media[aud_idx]
|
| 1285 |
+
if type(_audio) is torch.Tensor:
|
| 1286 |
+
device = _audio.device
|
| 1287 |
+
dtype = _audio.dtype
|
| 1288 |
+
_audio = _audio.cpu().float()
|
| 1289 |
+
else:
|
| 1290 |
+
device = self.device
|
| 1291 |
+
dtype = self.dtype
|
| 1292 |
+
_audio = whisper.pad_or_trim(_audio, length=cur_batch_max_audio_samples)
|
| 1293 |
+
aud_idx += 1
|
| 1294 |
+
stft_features = whisper_feature_extractor(
|
| 1295 |
+
_audio,
|
| 1296 |
+
sampling_rate=self.config.audio_sampling_rate,
|
| 1297 |
+
return_attention_mask=True,
|
| 1298 |
+
padding="max_length",
|
| 1299 |
+
return_tensors="pt",
|
| 1300 |
+
).to(device, dtype)
|
| 1301 |
+
|
| 1302 |
+
new_media.append(stft_features)
|
| 1303 |
+
if _audio_info[_mm_idx] != "dummy":
|
| 1304 |
+
_audio_info[_mm_idx]["new_audio_chunk_length"] = cur_batch_max_audio_duration
|
| 1305 |
+
_audio_info[_mm_idx]["new_audio_n_samples"] = cur_batch_max_audio_samples
|
| 1306 |
+
_audio_info[_mm_idx]["audio_end_sample_sec"] = _audio_info[_mm_idx]["audio_start_sec"] + cur_batch_max_audio_duration
|
| 1307 |
+
_audio_info[_mm_idx]["new_audio_n_stft_frames"] = stft_features["input_features"].shape[-1]
|
| 1308 |
+
media[name] = new_media
|
| 1309 |
+
|
| 1310 |
+
if len(media[name]) > 0:
|
| 1311 |
+
embeds[name] = deque(_encoder(media[name], media_config[name], mm_info))
|
| 1312 |
+
return embeds
|
| 1313 |
+
|
| 1314 |
+
def __truncate_sequence(
|
| 1315 |
+
self, inputs: List[torch.Tensor], labels: List[torch.Tensor]
|
| 1316 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 1317 |
+
if self.training and any(len(input) > self.tokenizer.model_max_length for input in inputs):
|
| 1318 |
+
warnings.warn(f"Truncating sequences to `model_max_length` ({self.tokenizer.model_max_length}).")
|
| 1319 |
+
inputs = [input[: self.tokenizer.model_max_length] for input in inputs]
|
| 1320 |
+
labels = [label[: self.tokenizer.model_max_length] for label in labels]
|
| 1321 |
+
return inputs, labels
|
| 1322 |
+
|
| 1323 |
+
def __batchify_sequence(
|
| 1324 |
+
self, inputs: List[torch.Tensor], labels: List[torch.Tensor]
|
| 1325 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 1326 |
+
batch_size = len(inputs)
|
| 1327 |
+
device = inputs[0].device
|
| 1328 |
+
hidden_size = inputs[0].shape[1]
|
| 1329 |
+
max_length = max(inputs[k].shape[0] for k in range(batch_size))
|
| 1330 |
+
attention_mask = torch.ones((batch_size, max_length), dtype=torch.bool, device=device)
|
| 1331 |
+
|
| 1332 |
+
inputs_p, labels_p = [], []
|
| 1333 |
+
for k in range(batch_size):
|
| 1334 |
+
size_pk = max_length - inputs[k].shape[0]
|
| 1335 |
+
inputs_pk = torch.zeros((size_pk, hidden_size), dtype=inputs[k].dtype, device=device)
|
| 1336 |
+
labels_pk = torch.full((size_pk,), IGNORE_INDEX, dtype=labels[k].dtype, device=device)
|
| 1337 |
+
if self.tokenizer.padding_side == "right":
|
| 1338 |
+
attention_mask[k, inputs[k].shape[0] :] = False
|
| 1339 |
+
inputs_pk = torch.cat([inputs[k], inputs_pk], dim=0)
|
| 1340 |
+
labels_pk = torch.cat([labels[k], labels_pk], dim=0)
|
| 1341 |
+
else:
|
| 1342 |
+
labels[k] = labels[k].to(device)
|
| 1343 |
+
attention_mask[k, : -inputs[k].shape[0]] = False
|
| 1344 |
+
inputs_pk = torch.cat([inputs_pk, inputs[k]], dim=0)
|
| 1345 |
+
labels_pk = torch.cat([labels_pk, labels[k]], dim=0)
|
| 1346 |
+
inputs_p.append(inputs_pk)
|
| 1347 |
+
labels_p.append(labels_pk)
|
| 1348 |
+
|
| 1349 |
+
inputs = torch.stack(inputs_p, dim=0)
|
| 1350 |
+
labels = torch.stack(labels_p, dim=0)
|
| 1351 |
+
return inputs, labels, attention_mask
|
| 1352 |
+
|
| 1353 |
+
def repack_multimodal_data(self, inputs_embeds, attention_mask, position_ids, labels):
|
| 1354 |
+
# Handle sequence parallelism
|
| 1355 |
+
PROCESS_GROUP_MANAGER = None
|
| 1356 |
+
|
| 1357 |
+
# We do re-sharding instead of packing here to ensure the sequence length is the same across all ranks.
|
| 1358 |
+
if PROCESS_GROUP_MANAGER is not None:
|
| 1359 |
+
sp_degree = PROCESS_GROUP_MANAGER.sp_degree
|
| 1360 |
+
sp_rank = PROCESS_GROUP_MANAGER.sp_rank
|
| 1361 |
+
sp_group = PROCESS_GROUP_MANAGER.sp_pg
|
| 1362 |
+
ring_degree = PROCESS_GROUP_MANAGER.ring_degree
|
| 1363 |
+
ring_rank = PROCESS_GROUP_MANAGER.ring_rank
|
| 1364 |
+
ring_type = PROCESS_GROUP_MANAGER.ring_type
|
| 1365 |
+
ulysses_degree = PROCESS_GROUP_MANAGER.ulysses_degree
|
| 1366 |
+
ulysses_rank = PROCESS_GROUP_MANAGER.ulysses_rank
|
| 1367 |
+
|
| 1368 |
+
bs, shard_seqlen = position_ids.shape
|
| 1369 |
+
sp_seq_len = [torch.zeros(1, dtype=torch.int64, device=position_ids.device) for _ in range(sp_degree)]
|
| 1370 |
+
dist.all_gather(sp_seq_len, torch.tensor(shard_seqlen, device=position_ids.device), group=sp_group)
|
| 1371 |
+
sp_seq_len_cat = torch.cat(sp_seq_len, dim=0)
|
| 1372 |
+
|
| 1373 |
+
if sp_rank == 0:
|
| 1374 |
+
original_start_id = 0
|
| 1375 |
+
else:
|
| 1376 |
+
original_start_id = torch.sum(sp_seq_len_cat[:sp_rank]).item()
|
| 1377 |
+
original_end_id = torch.sum(sp_seq_len_cat[: sp_rank + 1]).item()
|
| 1378 |
+
|
| 1379 |
+
# Gather attention_mask, position_ids, labels and input_embeds
|
| 1380 |
+
all_inputs_embeds = torch.zeros(
|
| 1381 |
+
bs,
|
| 1382 |
+
torch.sum(sp_seq_len_cat),
|
| 1383 |
+
inputs_embeds.shape[-1],
|
| 1384 |
+
dtype=inputs_embeds.dtype,
|
| 1385 |
+
device=inputs_embeds.device,
|
| 1386 |
+
).contiguous()
|
| 1387 |
+
all_inputs_embeds[:, original_start_id:original_end_id, :] += inputs_embeds
|
| 1388 |
+
dist.barrier(group=sp_group)
|
| 1389 |
+
dist.all_reduce(all_inputs_embeds, group=sp_group)
|
| 1390 |
+
dist.barrier(group=sp_group)
|
| 1391 |
+
|
| 1392 |
+
attention_mask_list = [
|
| 1393 |
+
torch.zeros((bs, sp_seq_len[i]), dtype=attention_mask.dtype, device=attention_mask.device)
|
| 1394 |
+
for i in range(sp_degree)
|
| 1395 |
+
]
|
| 1396 |
+
position_ids_list = [
|
| 1397 |
+
torch.zeros((bs, sp_seq_len[i]), dtype=position_ids.dtype, device=position_ids.device)
|
| 1398 |
+
for i in range(sp_degree)
|
| 1399 |
+
]
|
| 1400 |
+
labels_list = [
|
| 1401 |
+
torch.zeros((bs, sp_seq_len[i]), dtype=labels.dtype, device=labels.device) for i in range(sp_degree)
|
| 1402 |
+
]
|
| 1403 |
+
|
| 1404 |
+
dist.all_gather(attention_mask_list, attention_mask, group=sp_group)
|
| 1405 |
+
dist.all_gather(position_ids_list, position_ids, group=sp_group)
|
| 1406 |
+
dist.all_gather(labels_list, labels, group=sp_group)
|
| 1407 |
+
|
| 1408 |
+
effective_seqlen_list = [attention_mask_list[i].sum(dim=-1) for i in range(sp_degree)]
|
| 1409 |
+
effective_seqlen = torch.stack(effective_seqlen_list, dim=-1)
|
| 1410 |
+
effective_seqlen_batch_list = torch.unbind(effective_seqlen, dim=0)
|
| 1411 |
+
|
| 1412 |
+
global_attention_mask_list = []
|
| 1413 |
+
global_position_ids_list = []
|
| 1414 |
+
global_labels_list = []
|
| 1415 |
+
global_inputs_embeds_list = []
|
| 1416 |
+
for i in range(bs):
|
| 1417 |
+
global_attention_mask_batch_list = []
|
| 1418 |
+
global_position_ids_batch_list = []
|
| 1419 |
+
global_labels_batch_list = []
|
| 1420 |
+
global_inputs_embeds_batch_list = []
|
| 1421 |
+
for j in range(sp_degree):
|
| 1422 |
+
eff_len = effective_seqlen_batch_list[i][j]
|
| 1423 |
+
prev_len = torch.sum(sp_seq_len_cat[:j]).item() if j > 0 else 0
|
| 1424 |
+
|
| 1425 |
+
global_attention_mask_batch_list.append(attention_mask_list[j][i, :eff_len])
|
| 1426 |
+
global_position_ids_batch_list.append(position_ids_list[j][i, :eff_len])
|
| 1427 |
+
global_labels_batch_list.append(labels_list[j][i, :eff_len])
|
| 1428 |
+
global_inputs_embeds_batch_list.append(all_inputs_embeds[i, prev_len : prev_len + eff_len, :])
|
| 1429 |
+
global_attention_mask_list.append(torch.cat(global_attention_mask_batch_list, dim=0))
|
| 1430 |
+
global_position_ids_list.append(torch.cat(global_position_ids_batch_list, dim=0))
|
| 1431 |
+
global_labels_list.append(torch.cat(global_labels_batch_list, dim=0))
|
| 1432 |
+
global_inputs_embeds_list.append(torch.cat(global_inputs_embeds_batch_list, dim=0))
|
| 1433 |
+
|
| 1434 |
+
global_attention_mask = torch.nn.utils.rnn.pad_sequence(
|
| 1435 |
+
global_attention_mask_list, batch_first=True, padding_value=False
|
| 1436 |
+
)
|
| 1437 |
+
global_position_ids = torch.nn.utils.rnn.pad_sequence(
|
| 1438 |
+
global_position_ids_list, batch_first=True, padding_value=-1
|
| 1439 |
+
)
|
| 1440 |
+
global_labels = torch.nn.utils.rnn.pad_sequence(
|
| 1441 |
+
global_labels_list, batch_first=True, padding_value=IGNORE_INDEX
|
| 1442 |
+
)
|
| 1443 |
+
global_inputs_embeds = torch.nn.utils.rnn.pad_sequence(
|
| 1444 |
+
global_inputs_embeds_list, batch_first=True, padding_value=0
|
| 1445 |
+
)
|
| 1446 |
+
|
| 1447 |
+
# Re-shard the inputs
|
| 1448 |
+
if ring_degree > 1:
|
| 1449 |
+
total_effective_seqlen = torch.sum(effective_seqlen, dim=1)
|
| 1450 |
+
new_seqlen_per_rank = total_effective_seqlen // sp_degree
|
| 1451 |
+
assert torch.all(
|
| 1452 |
+
total_effective_seqlen % sp_degree == 0
|
| 1453 |
+
), "total_effective_seqlen must be divisible by sp_degree"
|
| 1454 |
+
|
| 1455 |
+
max_new_seqlen = torch.max(new_seqlen_per_rank).item()
|
| 1456 |
+
|
| 1457 |
+
new_attention_mask = torch.zeros(
|
| 1458 |
+
(bs, max_new_seqlen), dtype=global_attention_mask.dtype, device=global_attention_mask.device
|
| 1459 |
+
)
|
| 1460 |
+
new_position_ids = torch.zeros(
|
| 1461 |
+
(bs, max_new_seqlen), dtype=global_position_ids.dtype, device=global_position_ids.device
|
| 1462 |
+
)
|
| 1463 |
+
new_labels = torch.full(
|
| 1464 |
+
(bs, max_new_seqlen), IGNORE_INDEX, dtype=global_labels.dtype, device=global_labels.device
|
| 1465 |
+
)
|
| 1466 |
+
new_inputs_embeds = torch.zeros(
|
| 1467 |
+
(bs, max_new_seqlen, global_inputs_embeds.shape[-1]),
|
| 1468 |
+
dtype=global_inputs_embeds.dtype,
|
| 1469 |
+
device=global_inputs_embeds.device,
|
| 1470 |
+
)
|
| 1471 |
+
|
| 1472 |
+
if ring_type == "ring_varlen":
|
| 1473 |
+
for i in range(bs):
|
| 1474 |
+
start_idx = new_seqlen_per_rank[i] * sp_rank
|
| 1475 |
+
end_idx = start_idx + new_seqlen_per_rank[i]
|
| 1476 |
+
new_attention_mask[i, : new_seqlen_per_rank[i]] = global_attention_mask[i, start_idx:end_idx]
|
| 1477 |
+
new_position_ids[i, : new_seqlen_per_rank[i]] = global_position_ids[i, start_idx:end_idx]
|
| 1478 |
+
new_labels[i, : new_seqlen_per_rank[i]] = global_labels[i, start_idx:end_idx]
|
| 1479 |
+
new_inputs_embeds[i, : new_seqlen_per_rank[i], :] = global_inputs_embeds[
|
| 1480 |
+
i, start_idx:end_idx, :
|
| 1481 |
+
]
|
| 1482 |
+
elif ring_type == "zigzag_ring_varlen":
|
| 1483 |
+
chunk_size = total_effective_seqlen // (2 * sp_degree)
|
| 1484 |
+
for i in range(bs):
|
| 1485 |
+
# Zigzag pattern indices
|
| 1486 |
+
if sp_degree == ring_degree:
|
| 1487 |
+
forward_rank_idx = sp_rank
|
| 1488 |
+
backward_rank_idx = 2 * sp_degree - sp_rank - 1
|
| 1489 |
+
else:
|
| 1490 |
+
ulysses_offset = ulysses_rank * ring_degree * 2
|
| 1491 |
+
forward_rank_idx = ring_rank + ulysses_offset
|
| 1492 |
+
backward_rank_idx = sp_degree - ring_rank - 1 + ulysses_offset
|
| 1493 |
+
|
| 1494 |
+
# Calculate start and end indices for the forward and backward zigzag
|
| 1495 |
+
start_idx_fwd = forward_rank_idx * chunk_size[i]
|
| 1496 |
+
end_idx_fwd = start_idx_fwd + chunk_size[i]
|
| 1497 |
+
|
| 1498 |
+
start_idx_bwd = backward_rank_idx * chunk_size[i]
|
| 1499 |
+
end_idx_bwd = start_idx_bwd + chunk_size[i]
|
| 1500 |
+
|
| 1501 |
+
# Fill new tensors with zigzag data
|
| 1502 |
+
new_attention_mask[i, : chunk_size[i]] = global_attention_mask[i, start_idx_fwd:end_idx_fwd]
|
| 1503 |
+
new_attention_mask[i, chunk_size[i] : 2 * chunk_size[i]] = global_attention_mask[
|
| 1504 |
+
i, start_idx_bwd:end_idx_bwd
|
| 1505 |
+
]
|
| 1506 |
+
|
| 1507 |
+
new_position_ids[i, : chunk_size[i]] = global_position_ids[i, start_idx_fwd:end_idx_fwd]
|
| 1508 |
+
new_position_ids[i, chunk_size[i] : 2 * chunk_size[i]] = global_position_ids[
|
| 1509 |
+
i, start_idx_bwd:end_idx_bwd
|
| 1510 |
+
]
|
| 1511 |
+
|
| 1512 |
+
new_labels[i, : chunk_size[i]] = global_labels[i, start_idx_fwd:end_idx_fwd]
|
| 1513 |
+
new_labels[i, chunk_size[i] : 2 * chunk_size[i]] = global_labels[i, start_idx_bwd:end_idx_bwd]
|
| 1514 |
+
|
| 1515 |
+
new_inputs_embeds[i, : chunk_size[i], :] = global_inputs_embeds[i, start_idx_fwd:end_idx_fwd, :]
|
| 1516 |
+
new_inputs_embeds[i, chunk_size[i] : 2 * chunk_size[i], :] = global_inputs_embeds[
|
| 1517 |
+
i, start_idx_bwd:end_idx_bwd, :
|
| 1518 |
+
]
|
| 1519 |
+
else:
|
| 1520 |
+
raise ValueError(f"Invalid ring_type: {ring_type}")
|
| 1521 |
+
else:
|
| 1522 |
+
global_seq_len = global_attention_mask.shape[-1]
|
| 1523 |
+
seq_len_sharded = global_seq_len // sp_degree
|
| 1524 |
+
start_idx_reshard = seq_len_sharded * sp_rank
|
| 1525 |
+
end_idx_reshard = start_idx_reshard + seq_len_sharded if sp_rank < sp_degree - 1 else global_seq_len
|
| 1526 |
+
|
| 1527 |
+
new_attention_mask = torch.narrow(
|
| 1528 |
+
global_attention_mask, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
|
| 1529 |
+
)
|
| 1530 |
+
new_position_ids = torch.narrow(
|
| 1531 |
+
global_position_ids, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
|
| 1532 |
+
)
|
| 1533 |
+
new_labels = torch.narrow(global_labels, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard)
|
| 1534 |
+
new_inputs_embeds = torch.narrow(
|
| 1535 |
+
global_inputs_embeds, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
|
| 1536 |
+
)
|
| 1537 |
+
|
| 1538 |
+
return new_inputs_embeds, new_attention_mask, new_position_ids, new_labels
|
| 1539 |
+
|
| 1540 |
+
device = inputs_embeds.device
|
| 1541 |
+
batch_size = inputs_embeds.shape[0]
|
| 1542 |
+
seqlens = [attention_mask[k].sum().item() for k in range(batch_size)]
|
| 1543 |
+
|
| 1544 |
+
# Pack all sequences together
|
| 1545 |
+
inputs_embeds_p = [inputs_embeds[k][attention_mask[k]] for k in range(batch_size)]
|
| 1546 |
+
attention_mask_p = [torch.ones(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)]
|
| 1547 |
+
position_ids_p = [torch.arange(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)]
|
| 1548 |
+
labels_p = [labels[k][attention_mask[k]] for k in range(batch_size)]
|
| 1549 |
+
|
| 1550 |
+
# Add one dummy token at the end of the packed sequence to ensure that `_get_unpacked_data` will be called
|
| 1551 |
+
inputs_embeds_p.append(torch.zeros(1, inputs_embeds.shape[-1], dtype=inputs_embeds.dtype, device=device))
|
| 1552 |
+
attention_mask_p.append(torch.tensor([0], dtype=torch.int, device=device))
|
| 1553 |
+
position_ids_p.append(torch.tensor([0], dtype=torch.int, device=device))
|
| 1554 |
+
labels_p.append(torch.tensor([IGNORE_INDEX], dtype=torch.int, device=device))
|
| 1555 |
+
|
| 1556 |
+
# Mask the first token of each sequence to avoid contamination
|
| 1557 |
+
for label in labels_p:
|
| 1558 |
+
label[0] = IGNORE_INDEX
|
| 1559 |
+
|
| 1560 |
+
# Batch the data
|
| 1561 |
+
inputs_embeds_p = torch.cat(inputs_embeds_p, dim=0).unsqueeze(0)
|
| 1562 |
+
attention_mask_p = torch.cat(attention_mask_p, dim=0).unsqueeze(0)
|
| 1563 |
+
position_ids_p = torch.cat(position_ids_p, dim=0).unsqueeze(0)
|
| 1564 |
+
labels_p = torch.cat(labels_p, dim=0).unsqueeze(0)
|
| 1565 |
+
|
| 1566 |
+
if hasattr(
|
| 1567 |
+
self, "pad_to_multiple_of"
|
| 1568 |
+
): # related to quantization, please refer to ModelArguments for more information.
|
| 1569 |
+
assert len(labels_p.shape) == 2
|
| 1570 |
+
batch_size, max_length, cur_length = labels_p.shape[0], labels_p.shape[1], labels_p.shape[1]
|
| 1571 |
+
hidden_size = inputs_embeds_p.shape[-1]
|
| 1572 |
+
|
| 1573 |
+
if max_length % self.pad_to_multiple_of != 0:
|
| 1574 |
+
max_length = ((max_length // self.pad_to_multiple_of) + 1) * self.pad_to_multiple_of
|
| 1575 |
+
difference = max_length - cur_length
|
| 1576 |
+
|
| 1577 |
+
inputs_embeds_p = torch.cat(
|
| 1578 |
+
(
|
| 1579 |
+
inputs_embeds_p,
|
| 1580 |
+
torch.full((batch_size, difference, hidden_size), self.llm.pad_token_id).to(inputs_embeds_p),
|
| 1581 |
+
),
|
| 1582 |
+
dim=1,
|
| 1583 |
+
)
|
| 1584 |
+
labels_p = torch.cat((labels_p, torch.full((batch_size, difference), IGNORE_INDEX).to(labels_p)), dim=1)
|
| 1585 |
+
attention_mask_p = torch.cat(
|
| 1586 |
+
(
|
| 1587 |
+
attention_mask_p,
|
| 1588 |
+
torch.zeros((batch_size, difference), dtype=torch.bool).to(attention_mask_p),
|
| 1589 |
+
),
|
| 1590 |
+
dim=1,
|
| 1591 |
+
)
|
| 1592 |
+
position_ids_p = torch.cat(
|
| 1593 |
+
(position_ids_p, torch.full((batch_size, difference), -1).to(position_ids_p)), dim=1
|
| 1594 |
+
)
|
| 1595 |
+
|
| 1596 |
+
return inputs_embeds_p, attention_mask_p, position_ids_p, labels_p
|
| 1597 |
+
|
| 1598 |
+
def forward(
|
| 1599 |
+
self,
|
| 1600 |
+
input_ids: torch.LongTensor = None,
|
| 1601 |
+
media: Optional[Dict[str, List[torch.Tensor]]] = None,
|
| 1602 |
+
images: Optional[torch.FloatTensor] = None,
|
| 1603 |
+
media_config: Optional[List] = None,
|
| 1604 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 1605 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1606 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1607 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 1608 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1609 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1610 |
+
packing: bool = True,
|
| 1611 |
+
force_packing: bool = False,
|
| 1612 |
+
seqlens_in_batch: Optional[torch.LongTensor] = None,
|
| 1613 |
+
dpo_forward: bool = False,
|
| 1614 |
+
**kwargs,
|
| 1615 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 1616 |
+
self.freezed_module_patch()
|
| 1617 |
+
|
| 1618 |
+
if images is not None:
|
| 1619 |
+
if media is not None:
|
| 1620 |
+
raise ValueError("Both 'media' and 'images' are provided. Please provide only one.")
|
| 1621 |
+
print("The 'images' argument is deprecated. Please use 'media' instead.")
|
| 1622 |
+
media = {"image": images}
|
| 1623 |
+
|
| 1624 |
+
if media_config is None:
|
| 1625 |
+
media_config = defaultdict(dict)
|
| 1626 |
+
|
| 1627 |
+
if inputs_embeds is None:
|
| 1628 |
+
inputs_embeds, labels, attention_mask = self._embed(input_ids, media, media_config, labels, attention_mask)
|
| 1629 |
+
|
| 1630 |
+
if force_packing or (packing and self.training and not dpo_forward):
|
| 1631 |
+
if seqlens_in_batch is None:
|
| 1632 |
+
seqlens_in_batch = torch.sum(attention_mask, dim=1)
|
| 1633 |
+
set_seqlens_in_batch(seqlens_in_batch)
|
| 1634 |
+
|
| 1635 |
+
(inputs_embeds, attention_mask, position_ids, labels) = self.repack_multimodal_data(
|
| 1636 |
+
inputs_embeds, attention_mask, position_ids, labels
|
| 1637 |
+
)
|
| 1638 |
+
|
| 1639 |
+
outputs = self.llm(
|
| 1640 |
+
inputs_embeds=inputs_embeds,
|
| 1641 |
+
attention_mask=attention_mask,
|
| 1642 |
+
position_ids=position_ids,
|
| 1643 |
+
past_key_values=past_key_values,
|
| 1644 |
+
labels=labels,
|
| 1645 |
+
**kwargs,
|
| 1646 |
+
)
|
| 1647 |
+
|
| 1648 |
+
if self.training and getattr(self.config, "time_token_ids", []):
|
| 1649 |
+
outputs.loss = soft_cross_entropy(
|
| 1650 |
+
outputs.logits,
|
| 1651 |
+
labels,
|
| 1652 |
+
soft_tokens=self.config.time_token_ids,
|
| 1653 |
+
std=self.config.soft_ce_std,
|
| 1654 |
+
)
|
| 1655 |
+
|
| 1656 |
+
if dpo_forward:
|
| 1657 |
+
return outputs.logits, labels
|
| 1658 |
+
|
| 1659 |
+
return outputs
|
| 1660 |
+
|
| 1661 |
+
@torch.inference_mode()
|
| 1662 |
+
def generate(
|
| 1663 |
+
self,
|
| 1664 |
+
input_ids: Optional[torch.FloatTensor] = None,
|
| 1665 |
+
media: Optional[Dict[str, List[torch.Tensor]]] = None,
|
| 1666 |
+
media_config: Dict[str, Dict[str, Any]] = None,
|
| 1667 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 1668 |
+
return_output_ids_only: bool = True,
|
| 1669 |
+
**generation_kwargs,
|
| 1670 |
+
) -> torch.LongTensor:
|
| 1671 |
+
"""
|
| 1672 |
+
input_tokens: <image> describe the image
|
| 1673 |
+
media: [Tensor(1, 3, 384, 384), ]
|
| 1674 |
+
----------->
|
| 1675 |
+
input_tokens: 36000 001 002 003 004
|
| 1676 |
+
input_emds: <media emd> 001 002 003 004
|
| 1677 |
+
"""
|
| 1678 |
+
inputs_embeds, _, attention_mask = self._embed(input_ids, media, media_config, None, attention_mask)
|
| 1679 |
+
|
| 1680 |
+
output_ids = self.llm.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs)
|
| 1681 |
+
|
| 1682 |
+
if return_output_ids_only:
|
| 1683 |
+
return_value = output_ids
|
| 1684 |
+
else:
|
| 1685 |
+
# by default, return the input_ids and output_ids concatenated to keep consistency with the community VLMs like qwen
|
| 1686 |
+
generation_config = generation_kwargs.get("generation_config", None)
|
| 1687 |
+
if generation_config is not None:
|
| 1688 |
+
num_generations = generation_config.num_return_sequences
|
| 1689 |
+
repeat_input_ids = input_ids.repeat_interleave(num_generations, dim=0)
|
| 1690 |
+
return_value = torch.cat([repeat_input_ids, output_ids], dim=-1)
|
| 1691 |
+
else:
|
| 1692 |
+
return_value = torch.cat([input_ids, output_ids], dim=-1)
|
| 1693 |
+
|
| 1694 |
+
return return_value
|
| 1695 |
+
|
| 1696 |
+
@torch.inference_mode()
|
| 1697 |
+
def generate_content(
|
| 1698 |
+
self,
|
| 1699 |
+
prompt: Union[str, List],
|
| 1700 |
+
generation_config: Optional[GenerationConfig] = None,
|
| 1701 |
+
response_format=None,
|
| 1702 |
+
) -> str:
|
| 1703 |
+
conversation = [{"from": "human", "value": prompt}]
|
| 1704 |
+
|
| 1705 |
+
# Convert response format to logits processor
|
| 1706 |
+
xgr_logits_processor = None
|
| 1707 |
+
|
| 1708 |
+
# Extract media from the conversation
|
| 1709 |
+
|
| 1710 |
+
media = extract_media(conversation, self.config)
|
| 1711 |
+
|
| 1712 |
+
# Process media
|
| 1713 |
+
media_config = defaultdict(dict)
|
| 1714 |
+
for name in media:
|
| 1715 |
+
if name == "image":
|
| 1716 |
+
if len(media["image"]) == 1 and self.config.image_aspect_ratio in ["dynamic", "dynamic_s2"]:
|
| 1717 |
+
self.config.image_processor = self.vision_tower.image_processor
|
| 1718 |
+
if self.config.image_aspect_ratio == "dynamic":
|
| 1719 |
+
images = process_image(media["image"][0], self.config, None, enable_dynamic_res=True).half()
|
| 1720 |
+
conversation[0]["value"] = conversation[0]["value"].replace(
|
| 1721 |
+
DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0]
|
| 1722 |
+
)
|
| 1723 |
+
else:
|
| 1724 |
+
if type(self.config.s2_scales) is str:
|
| 1725 |
+
self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
|
| 1726 |
+
images, block_sizes = process_image(
|
| 1727 |
+
media["image"][0], self.config, None, enable_dynamic_s2=True
|
| 1728 |
+
)
|
| 1729 |
+
images = images.half()
|
| 1730 |
+
media_config[name]["block_sizes"] = [block_sizes]
|
| 1731 |
+
else:
|
| 1732 |
+
images = process_images(media["image"], self.vision_tower.image_processor, self.config).half()
|
| 1733 |
+
media[name] = [image for image in images]
|
| 1734 |
+
elif name == "video":
|
| 1735 |
+
if self.config.image_aspect_ratio == "dynamic" and self.config.video_max_tiles > 1:
|
| 1736 |
+
media[name] = [
|
| 1737 |
+
process_images(
|
| 1738 |
+
images,
|
| 1739 |
+
self.vision_tower.image_processor,
|
| 1740 |
+
self.config,
|
| 1741 |
+
enable_dynamic_res=True,
|
| 1742 |
+
max_tiles=self.config.video_max_tiles,
|
| 1743 |
+
).half()
|
| 1744 |
+
for images in media[name]
|
| 1745 |
+
]
|
| 1746 |
+
elif self.config.image_aspect_ratio == "dynamic_s2" and self.config.video_max_tiles > 1:
|
| 1747 |
+
self.config.image_processor = self.vision_tower.image_processor
|
| 1748 |
+
if type(self.config.s2_scales) is str:
|
| 1749 |
+
self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
|
| 1750 |
+
media[name] = [
|
| 1751 |
+
torch.cat(
|
| 1752 |
+
[
|
| 1753 |
+
process_image(
|
| 1754 |
+
image,
|
| 1755 |
+
self.config,
|
| 1756 |
+
None,
|
| 1757 |
+
enable_dynamic_s2=True,
|
| 1758 |
+
max_tiles=self.config.video_max_tiles,
|
| 1759 |
+
)[0].half()
|
| 1760 |
+
for image in images
|
| 1761 |
+
]
|
| 1762 |
+
)
|
| 1763 |
+
for images in media[name]
|
| 1764 |
+
]
|
| 1765 |
+
else:
|
| 1766 |
+
media[name] = [
|
| 1767 |
+
process_images(images, self.vision_tower.image_processor, self.config)
|
| 1768 |
+
for images in media[name]
|
| 1769 |
+
]
|
| 1770 |
+
elif name == "speech":
|
| 1771 |
+
speeches = media["speech"]
|
| 1772 |
+
media[name] = [speech for speech in speeches]
|
| 1773 |
+
elif name == "sound":
|
| 1774 |
+
# sounds = process_sounds(media["sound"]).half()
|
| 1775 |
+
sounds = media["sound"]
|
| 1776 |
+
# media[name] = [{k: v.half() for sound in sounds for k, v in sound.items()]
|
| 1777 |
+
for sound in sounds:
|
| 1778 |
+
if type(sound) is dict:
|
| 1779 |
+
for k, v in sound.items():
|
| 1780 |
+
sound[k] = v.half()
|
| 1781 |
+
media[name] = [sound for sound in sounds]
|
| 1782 |
+
elif name == "video_info":
|
| 1783 |
+
media[name] = [media["video_info"]]
|
| 1784 |
+
elif name == "audio_info":
|
| 1785 |
+
media[name] = [media["audio_info"]]
|
| 1786 |
+
else:
|
| 1787 |
+
raise ValueError(f"Unsupported media type: {name}")
|
| 1788 |
+
|
| 1789 |
+
# Tokenize the conversation
|
| 1790 |
+
input_ids = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True).unsqueeze(0).cuda()
|
| 1791 |
+
|
| 1792 |
+
# Set up the generation config
|
| 1793 |
+
generation_config = generation_config or self.default_generation_config
|
| 1794 |
+
|
| 1795 |
+
# Generate the response
|
| 1796 |
+
try:
|
| 1797 |
+
output_ids = self.generate(
|
| 1798 |
+
input_ids=input_ids,
|
| 1799 |
+
media=media,
|
| 1800 |
+
media_config=media_config,
|
| 1801 |
+
generation_config=generation_config,
|
| 1802 |
+
logits_processor=xgr_logits_processor, # structured generation
|
| 1803 |
+
)
|
| 1804 |
+
except ValueError:
|
| 1805 |
+
if not generation_config.do_sample:
|
| 1806 |
+
raise
|
| 1807 |
+
logging.warning("Generation failed with sampling, retrying with greedy decoding.")
|
| 1808 |
+
generation_config.do_sample = False
|
| 1809 |
+
output_ids = self.generate(
|
| 1810 |
+
input_ids=input_ids,
|
| 1811 |
+
media=media,
|
| 1812 |
+
media_config=media_config,
|
| 1813 |
+
generation_config=generation_config,
|
| 1814 |
+
logits_processor=xgr_logits_processor,
|
| 1815 |
+
)
|
| 1816 |
+
|
| 1817 |
+
# Decode the response
|
| 1818 |
+
response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
|
| 1819 |
+
return response
|
| 1820 |
+
|
| 1821 |
+
@property
|
| 1822 |
+
def default_generation_config(self) -> GenerationConfig:
|
| 1823 |
+
generation_config = copy.deepcopy(self.generation_config or GenerationConfig())
|
| 1824 |
+
if self.tokenizer.eos_token_id is None:
|
| 1825 |
+
raise ValueError("Tokenizer must have an EOS token")
|
| 1826 |
+
if generation_config.max_length == GenerationConfig().max_length:
|
| 1827 |
+
generation_config.max_length = self.tokenizer.model_max_length
|
| 1828 |
+
if generation_config.pad_token_id is None:
|
| 1829 |
+
generation_config.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
|
| 1830 |
+
if generation_config.bos_token_id is None:
|
| 1831 |
+
generation_config.bos_token_id = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
|
| 1832 |
+
if generation_config.eos_token_id is None:
|
| 1833 |
+
generation_config.eos_token_id = self.tokenizer.eos_token_id
|
| 1834 |
+
return generation_config
|
preprocessor_config.json
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"chunk_length": 30,
|
| 3 |
+
"feature_extractor_type": "WhisperFeatureExtractor",
|
| 4 |
+
"feature_size": 128,
|
| 5 |
+
"hop_length": 160,
|
| 6 |
+
"n_fft": 400,
|
| 7 |
+
"n_samples": 480000,
|
| 8 |
+
"nb_max_frames": 3000,
|
| 9 |
+
"padding_side": "right",
|
| 10 |
+
"padding_value": 0.0,
|
| 11 |
+
"return_attention_mask": true,
|
| 12 |
+
"sampling_rate": 16000
|
| 13 |
+
}
|
pyproject.toml
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=61.0"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "omnivinci"
|
| 7 |
+
version = "1.0.0"
|
| 8 |
+
description = "omnivinci"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.10"
|
| 11 |
+
classifiers = [
|
| 12 |
+
"Programming Language :: Python :: 3",
|
| 13 |
+
"License :: OSI Approved :: Apache Software License",
|
| 14 |
+
]
|
| 15 |
+
dependencies = [
|
| 16 |
+
"torch==2.3.0", "torchvision==0.18.0",
|
| 17 |
+
"transformers==4.46.0", "tokenizers>=0.15.2", "sentencepiece==0.1.99", "shortuuid",
|
| 18 |
+
"accelerate==0.34.2", "peft>=0.9.0", "bitsandbytes==0.43.2",
|
| 19 |
+
"pydantic<2,>=1", "markdown2[all]", "numpy==1.26.4", "scikit-learn==1.2.2",
|
| 20 |
+
"gradio==3.35.2", "gradio_client==0.2.9",
|
| 21 |
+
"requests", "httpx", "uvicorn", "fastapi", "fire", "seaborn", "ring_flash_attn==0.1.1",
|
| 22 |
+
"einops==0.6.1", "einops-exts==0.0.4", "timm==0.9.12",
|
| 23 |
+
"openpyxl==3.1.2", "pytorchvideo==0.1.5", "decord==0.6.0",
|
| 24 |
+
"datasets==2.16.1", "openai==1.8.0", "webdataset==0.2.86",
|
| 25 |
+
"nltk==3.3", "pywsd==1.2.4", "opencv-python-headless==4.8.0.76",
|
| 26 |
+
"s2wrapper@git+https://github.com/bfshi/scaling_on_scales",
|
| 27 |
+
"tyro", "pytest", "pre-commit", "loguru", "hydra-core", "xgrammar"
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
[project.scripts]
|
| 31 |
+
|
| 32 |
+
[project.optional-dependencies]
|
| 33 |
+
train = ["deepspeed==0.9.5", "ninja", "wandb"]
|
| 34 |
+
eval = ["word2number", "Levenshtein", "nltk", "pywsd"]
|
| 35 |
+
|
| 36 |
+
[project.urls]
|
| 37 |
+
"Homepage" = "https://github.com/NVlabs/OmniVinci"
|
| 38 |
+
"Bug Tracker" = "https://github.com/NVlabs/OmniVinci"
|
| 39 |
+
|
| 40 |
+
[tool.triton]
|
| 41 |
+
triton = {version = "3.0.0.post20240610003544", file = "https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/07c94329-d4c3-4ad4-9e6b-f904a60032ec/pypi/download/triton-nightly/3.post20240610003544/triton_nightly-3.0.0.post20240610003544-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", sha256 = "ac2c36a49bf9c2bb780909b38096fb718f17efd78b88a1ca1d649f6d063cdc2c"}
|
| 42 |
+
|
| 43 |
+
[tool.black]
|
| 44 |
+
line-length = 120
|
| 45 |
+
|
| 46 |
+
[tool.isort]
|
| 47 |
+
profile = "black"
|
| 48 |
+
multi_line_output = 3
|
| 49 |
+
include_trailing_comma = true
|
| 50 |
+
force_grid_wrap = 0
|
| 51 |
+
use_parentheses = true
|
| 52 |
+
ensure_newline_before_comments = true
|
| 53 |
+
line_length = 120
|
| 54 |
+
|
| 55 |
+
[tool.setuptools.packages.find]
|
| 56 |
+
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
|
| 57 |
+
|
| 58 |
+
[tool.wheel]
|
| 59 |
+
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
|
qwen2.jinja
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{% if messages[0]['role'] != 'system' %}
|
| 2 |
+
{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}
|
| 3 |
+
{% endif %}
|
| 4 |
+
|
| 5 |
+
{% for message in messages if message['content'] is not none %}
|
| 6 |
+
{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}
|
| 7 |
+
{% endfor %}
|
| 8 |
+
|
| 9 |
+
{% if add_generation_prompt %}
|
| 10 |
+
{{ '<|im_start|>assistant\n' }}
|
| 11 |
+
{% endif %}
|
qwen_audio_encoder.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from transformers import PretrainedConfig, Qwen2AudioEncoder, Qwen2AudioForConditionalGeneration
|
| 18 |
+
|
| 19 |
+
from .audio_encoder import AudioTower
|
| 20 |
+
|
| 21 |
+
class Qwen2AudioTower(AudioTower):
|
| 22 |
+
def __init__(self, model_name_or_path: str, config: PretrainedConfig):
|
| 23 |
+
super().__init__(model_name_or_path, config)
|
| 24 |
+
self.audio_tower = Qwen2AudioEncoder.from_pretrained(model_name_or_path, attn_implementation="flash_attention_2")
|
| 25 |
+
self.is_loaded = True
|
| 26 |
+
self.audio_chunk_unit_duration = 30
|
| 27 |
+
self.audio_chunk_unit_length = 3000
|
| 28 |
+
|
| 29 |
+
def forward(self, sounds):
|
| 30 |
+
if type(sounds) is list:
|
| 31 |
+
sound_features = []
|
| 32 |
+
audio_output_lengths = []
|
| 33 |
+
for sound in sounds:
|
| 34 |
+
if hasattr(sound, "input_features") or (type(sound) is dict and "input_features" in sound):
|
| 35 |
+
sound = sound["input_features"]
|
| 36 |
+
|
| 37 |
+
sound_feature = self.forward_audio_tower_batch(sound)
|
| 38 |
+
sound_feature = sound_feature.to(sound.dtype)
|
| 39 |
+
sound_features.append(sound_feature)
|
| 40 |
+
audio_output_lengths.append(sound_feature.shape[1])
|
| 41 |
+
if len(sound_features) > 0:
|
| 42 |
+
sound_features = torch.cat(sound_features, dim=1).squeeze(0)
|
| 43 |
+
else:
|
| 44 |
+
raise NotImplementedError("Not implemented for this encoder")
|
| 45 |
+
|
| 46 |
+
return sound_features, audio_output_lengths
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def forward_audio_tower_batch(self, inp):
|
| 50 |
+
"""
|
| 51 |
+
Process long audio input by splitting into fixed-size chunks (30 seconds),
|
| 52 |
+
padding if needed, batching them together, and processing through the audio tower.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
inp: Tensor of shape (batch_size, n_mels, seq_len)
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
Tensor of shape (batch_size, num_chunks * chunk_seq_len, hidden_size)
|
| 59 |
+
"""
|
| 60 |
+
batch_size, n_mels, seq_len = inp.shape
|
| 61 |
+
chunk_length = self.audio_chunk_unit_length
|
| 62 |
+
num_chunks = (seq_len + chunk_length - 1) // chunk_length # Ceiling division
|
| 63 |
+
|
| 64 |
+
padded_chunks = []
|
| 65 |
+
|
| 66 |
+
for i in range(num_chunks):
|
| 67 |
+
start_idx = i * chunk_length
|
| 68 |
+
end_idx = min(start_idx + chunk_length, seq_len)
|
| 69 |
+
|
| 70 |
+
# Extract and pad chunk if necessary
|
| 71 |
+
chunk = inp[:, :, start_idx:end_idx]
|
| 72 |
+
if chunk.shape[2] < chunk_length:
|
| 73 |
+
pad_len = chunk_length - chunk.shape[2]
|
| 74 |
+
chunk = torch.nn.functional.pad(chunk, (0, pad_len), mode='constant', value=0)
|
| 75 |
+
|
| 76 |
+
padded_chunks.append(chunk)
|
| 77 |
+
|
| 78 |
+
# Stack chunks along batch dimension
|
| 79 |
+
all_chunks = torch.cat(padded_chunks, dim=0).reshape(batch_size * num_chunks, n_mels, chunk_length)
|
| 80 |
+
|
| 81 |
+
# Forward pass through the audio tower
|
| 82 |
+
chunk_outputs = self.audio_tower(all_chunks)
|
| 83 |
+
hidden_states = chunk_outputs.last_hidden_state
|
| 84 |
+
|
| 85 |
+
# Reshape back to (batch_size, num_chunks * seq_len', hidden_size)
|
| 86 |
+
_, chunk_seq_len, hidden_size = hidden_states.shape
|
| 87 |
+
hidden_states = hidden_states.reshape(batch_size, num_chunks * chunk_seq_len, hidden_size)
|
| 88 |
+
|
| 89 |
+
return hidden_states
|
siglip_encoder.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
from accelerate.hooks import add_hook_to_module
|
| 21 |
+
from einops import rearrange
|
| 22 |
+
from s2wrapper import forward as multiscale_forward
|
| 23 |
+
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, SiglipImageProcessor
|
| 24 |
+
from transformers.image_processing_utils import BaseImageProcessor
|
| 25 |
+
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
| 26 |
+
from transformers.models.siglip import SiglipVisionModel
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class VisionTower(nn.Module):
|
| 30 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
| 31 |
+
super().__init__()
|
| 32 |
+
|
| 33 |
+
self.is_loaded = False
|
| 34 |
+
|
| 35 |
+
self.vision_tower_name = vision_tower
|
| 36 |
+
self.select_layer = getattr(args, "mm_vision_select_layer", -2)
|
| 37 |
+
self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
|
| 38 |
+
|
| 39 |
+
self.cfg_only = None
|
| 40 |
+
|
| 41 |
+
def feature_select(self, image_forward_outs):
|
| 42 |
+
image_features = image_forward_outs.hidden_states[self.select_layer]
|
| 43 |
+
if self.select_feature == "patch":
|
| 44 |
+
image_features = image_features[:, 1:]
|
| 45 |
+
elif self.select_feature == "cls_patch":
|
| 46 |
+
image_features = image_features
|
| 47 |
+
else:
|
| 48 |
+
raise ValueError(f"Unexpected select feature: {self.select_feature}")
|
| 49 |
+
return image_features
|
| 50 |
+
|
| 51 |
+
def _maybe_resize_pos_embeds(
|
| 52 |
+
self,
|
| 53 |
+
model: PreTrainedModel,
|
| 54 |
+
image_processor: BaseImageProcessor,
|
| 55 |
+
resolution: int = -1,
|
| 56 |
+
interpolate_mode: str = "linear",
|
| 57 |
+
):
|
| 58 |
+
if resolution in [model.config.image_size, -1]:
|
| 59 |
+
return
|
| 60 |
+
print(
|
| 61 |
+
f"Resizing vision model's position embeddings to support higher vision resolution: from {model.config.image_size} to {resolution} ..."
|
| 62 |
+
)
|
| 63 |
+
embeddings = model.vision_model.embeddings
|
| 64 |
+
patch_size = embeddings.patch_size
|
| 65 |
+
num_new_tokens = int((resolution // patch_size) ** 2)
|
| 66 |
+
|
| 67 |
+
old_embeddings = embeddings.position_embedding
|
| 68 |
+
match interpolate_mode:
|
| 69 |
+
case "linear":
|
| 70 |
+
# Step 1: Calculate the corresponding patch ID (pid) in the current resolution (M patches) based on the target resolution (N patches). Formula: pid = pid / N * M
|
| 71 |
+
# Step 2: Obtain new embeddings by interpolating between the embeddings of the two nearest calculated patch IDs. Formula: new_embeds = (pid - floor(pid)) * embeds[ceil(pid)] + (ceil(pid) - pid) * embeds[floor(pid)]
|
| 72 |
+
import torch
|
| 73 |
+
import torch.nn as nn
|
| 74 |
+
|
| 75 |
+
if is_deepspeed_zero3_enabled():
|
| 76 |
+
try:
|
| 77 |
+
import deepspeed
|
| 78 |
+
except ImportError:
|
| 79 |
+
raise ImportError("DeepSpeed is not installed. Please install it with `pip install deepspeed`.")
|
| 80 |
+
with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
|
| 81 |
+
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
|
| 82 |
+
else:
|
| 83 |
+
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
|
| 84 |
+
new_embeddings = nn.Embedding(
|
| 85 |
+
num_new_tokens,
|
| 86 |
+
old_embedding_dim,
|
| 87 |
+
dtype=old_embeddings.weight.dtype,
|
| 88 |
+
device=old_embeddings.weight.device,
|
| 89 |
+
)
|
| 90 |
+
mapped_indices = (
|
| 91 |
+
torch.arange(num_new_tokens).to(old_embeddings.weight.device)
|
| 92 |
+
/ (num_new_tokens - 1)
|
| 93 |
+
* (old_num_tokens - 1)
|
| 94 |
+
)
|
| 95 |
+
floor_indices = torch.clamp(mapped_indices.floor().long(), min=0, max=old_num_tokens - 1)
|
| 96 |
+
ceil_indices = torch.clamp(mapped_indices.ceil().long(), min=0, max=old_num_tokens - 1)
|
| 97 |
+
if is_deepspeed_zero3_enabled():
|
| 98 |
+
params = [old_embeddings.weight, new_embeddings.weight]
|
| 99 |
+
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
|
| 100 |
+
interpolated_embeds = (mapped_indices - floor_indices)[:, None] * old_embeddings.weight.data[
|
| 101 |
+
ceil_indices, :
|
| 102 |
+
] + (ceil_indices - mapped_indices)[:, None] * old_embeddings.weight.data[floor_indices, :]
|
| 103 |
+
else:
|
| 104 |
+
interpolated_embeds = (mapped_indices - floor_indices)[:, None] * old_embeddings.weight.data[
|
| 105 |
+
ceil_indices, :
|
| 106 |
+
] + (ceil_indices - mapped_indices)[:, None] * old_embeddings.weight.data[floor_indices, :]
|
| 107 |
+
new_embeddings.weight.data = interpolated_embeds
|
| 108 |
+
case _:
|
| 109 |
+
raise NotImplementedError
|
| 110 |
+
|
| 111 |
+
if hasattr(old_embeddings, "_hf_hook"):
|
| 112 |
+
hook = old_embeddings._hf_hook
|
| 113 |
+
add_hook_to_module(new_embeddings, hook)
|
| 114 |
+
new_embeddings.requires_grad_(old_embeddings.weight.requires_grad)
|
| 115 |
+
|
| 116 |
+
# Update vision encoder's configurations
|
| 117 |
+
model.config.image_size = resolution
|
| 118 |
+
if hasattr(image_processor, "crop_size"):
|
| 119 |
+
# CLIP vision tower
|
| 120 |
+
image_processor.crop_size = resolution
|
| 121 |
+
else:
|
| 122 |
+
# SIGLIP vision tower
|
| 123 |
+
assert hasattr(image_processor, "size")
|
| 124 |
+
image_processor.size = {"height": resolution, "width": resolution}
|
| 125 |
+
|
| 126 |
+
embeddings.position_embedding = new_embeddings
|
| 127 |
+
embeddings.image_size = resolution
|
| 128 |
+
embeddings.num_patches = embeddings.num_positions = num_new_tokens
|
| 129 |
+
embeddings.position_ids = (
|
| 130 |
+
torch.arange(embeddings.num_positions).expand((1, -1)).to(old_embeddings.weight.device)
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
def forward(self, images):
|
| 134 |
+
if type(images) is list:
|
| 135 |
+
image_features = []
|
| 136 |
+
for image in images:
|
| 137 |
+
image_forward_out = self.vision_tower(
|
| 138 |
+
image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
|
| 139 |
+
output_hidden_states=True,
|
| 140 |
+
)
|
| 141 |
+
image_feature = self.feature_select(image_forward_out).to(image.dtype)
|
| 142 |
+
image_features.append(image_feature)
|
| 143 |
+
else:
|
| 144 |
+
image_forward_outs = self.vision_tower(
|
| 145 |
+
images.to(device=self.device, dtype=self.dtype),
|
| 146 |
+
output_hidden_states=True,
|
| 147 |
+
)
|
| 148 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
| 149 |
+
|
| 150 |
+
return image_features
|
| 151 |
+
|
| 152 |
+
@property
|
| 153 |
+
def dummy_feature(self):
|
| 154 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
| 155 |
+
|
| 156 |
+
@property
|
| 157 |
+
def dtype(self):
|
| 158 |
+
return self.vision_tower.dtype
|
| 159 |
+
|
| 160 |
+
@property
|
| 161 |
+
def device(self):
|
| 162 |
+
return self.vision_tower.device
|
| 163 |
+
|
| 164 |
+
@property
|
| 165 |
+
def config(self):
|
| 166 |
+
if self.is_loaded:
|
| 167 |
+
return self.vision_tower.config
|
| 168 |
+
else:
|
| 169 |
+
return self.cfg_only
|
| 170 |
+
|
| 171 |
+
@property
|
| 172 |
+
def hidden_size(self):
|
| 173 |
+
return self.config.hidden_size
|
| 174 |
+
|
| 175 |
+
@property
|
| 176 |
+
def num_patches(self):
|
| 177 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class VisionTowerS2(VisionTower):
|
| 181 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
| 182 |
+
super().__init__(vision_tower, args, delay_load)
|
| 183 |
+
|
| 184 |
+
self.scales = list(map(int, args.s2_scales.split(",")))
|
| 185 |
+
self.scales.sort()
|
| 186 |
+
self.max_split_size = args.s2_max_split_size
|
| 187 |
+
self.resize_output_to_scale_idx = getattr(args, "s2_resize_output_to_scale_idx", 0)
|
| 188 |
+
|
| 189 |
+
def forward_feature(self, images):
|
| 190 |
+
image_forward_outs = self.vision_tower(
|
| 191 |
+
images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
|
| 192 |
+
)
|
| 193 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
| 194 |
+
return image_features
|
| 195 |
+
|
| 196 |
+
def forward(self, images):
|
| 197 |
+
if type(images) is list:
|
| 198 |
+
image_feature = []
|
| 199 |
+
for image in images:
|
| 200 |
+
image_feature = multiscale_forward(
|
| 201 |
+
self.forward_feature,
|
| 202 |
+
image.unsqueeze(0),
|
| 203 |
+
img_sizes=self.scales,
|
| 204 |
+
max_split_size=self.max_split_size,
|
| 205 |
+
resize_output_to_idx=self.resize_output_to_scale_idx,
|
| 206 |
+
)
|
| 207 |
+
image_features.append(image_feature)
|
| 208 |
+
else:
|
| 209 |
+
image_features = multiscale_forward(
|
| 210 |
+
self.forward_feature,
|
| 211 |
+
images,
|
| 212 |
+
img_sizes=self.scales,
|
| 213 |
+
max_split_size=self.max_split_size,
|
| 214 |
+
resize_output_to_idx=self.resize_output_to_scale_idx,
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
return image_features
|
| 218 |
+
|
| 219 |
+
@property
|
| 220 |
+
def hidden_size(self):
|
| 221 |
+
return self.config.hidden_size * len(self.scales)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class VisionTowerDynamicS2(VisionTower):
|
| 225 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
| 226 |
+
super().__init__(vision_tower, args, delay_load)
|
| 227 |
+
|
| 228 |
+
self.scales = list(map(int, args.s2_scales.split(",")))
|
| 229 |
+
self.scales.sort()
|
| 230 |
+
self.max_split_size = args.s2_max_split_size
|
| 231 |
+
self.resize_output_to_scale_idx = getattr(args, "s2_resize_output_to_scale_idx", 0)
|
| 232 |
+
|
| 233 |
+
def forward_feature(self, images):
|
| 234 |
+
image_forward_outs = self.vision_tower(
|
| 235 |
+
images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
|
| 236 |
+
)
|
| 237 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
| 238 |
+
return image_features
|
| 239 |
+
|
| 240 |
+
def forward(self, images):
|
| 241 |
+
assert type(images) is not list
|
| 242 |
+
image_features = self.forward_feature(images)
|
| 243 |
+
|
| 244 |
+
return image_features
|
| 245 |
+
|
| 246 |
+
@property
|
| 247 |
+
def hidden_size(self):
|
| 248 |
+
return self.config.hidden_size * len(self.scales)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class SiglipVisionTower(VisionTower):
|
| 252 |
+
def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
|
| 253 |
+
super().__init__(model_name_or_path, config)
|
| 254 |
+
self.vision_tower = SiglipVisionModel.from_pretrained(
|
| 255 |
+
model_name_or_path,
|
| 256 |
+
attn_implementation=config._attn_implementation,
|
| 257 |
+
torch_dtype=eval(config.model_dtype),
|
| 258 |
+
)
|
| 259 |
+
self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
|
| 260 |
+
self.is_loaded = True
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class SiglipVisionTowerS2(VisionTowerS2):
|
| 264 |
+
def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
|
| 265 |
+
super().__init__(model_name_or_path, config)
|
| 266 |
+
self.vision_tower = SiglipVisionModel.from_pretrained(
|
| 267 |
+
model_name_or_path,
|
| 268 |
+
attn_implementation=config._attn_implementation,
|
| 269 |
+
torch_dtype=eval(config.model_dtype),
|
| 270 |
+
)
|
| 271 |
+
self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
|
| 272 |
+
# Make sure it crops/resizes the image to the largest scale in self.scales to maintain high-res information
|
| 273 |
+
self.image_processor.size["height"] = self.image_processor.size["width"] = self.scales[-1]
|
| 274 |
+
self.is_loaded = True
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class SiglipVisionTowerDynamicS2(VisionTowerDynamicS2):
|
| 278 |
+
def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
|
| 279 |
+
super().__init__(model_name_or_path, config)
|
| 280 |
+
if type(config.model_dtype) == str:
|
| 281 |
+
model_dtype = eval(config.model_dtype)
|
| 282 |
+
else:
|
| 283 |
+
model_dtype = config.model_dtype
|
| 284 |
+
|
| 285 |
+
self.vision_tower = SiglipVisionModel.from_pretrained(
|
| 286 |
+
model_name_or_path,
|
| 287 |
+
attn_implementation="flash_attention_2",
|
| 288 |
+
torch_dtype=model_dtype,
|
| 289 |
+
)
|
| 290 |
+
self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
|
| 291 |
+
# Make sure it crops/resizes the image to the largest scale in self.scales to maintain high-res information
|
| 292 |
+
self.image_processor.size["height"] = self.image_processor.size["width"] = self.scales[0]
|
| 293 |
+
self.is_loaded = True
|
sound_base_projector.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
from einops import rearrange
|
| 20 |
+
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
| 21 |
+
|
| 22 |
+
class SoundMultimodalProjectorConfig(PretrainedConfig):
|
| 23 |
+
"""Configuration for sound multimodal projector."""
|
| 24 |
+
|
| 25 |
+
model_type = "sound_mm_projector"
|
| 26 |
+
|
| 27 |
+
def __init__(self, sound_mm_projector_type: str = None, **kwargs):
|
| 28 |
+
super().__init__()
|
| 29 |
+
self.sound_mm_projector_type = sound_mm_projector_type
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class AudioDownSampleBlock(nn.Module):
|
| 33 |
+
"""Downsample audio features using 1D convolution."""
|
| 34 |
+
def __init__(self, embed_dim):
|
| 35 |
+
super().__init__()
|
| 36 |
+
self.conv1 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
|
| 37 |
+
|
| 38 |
+
def forward(self, x):
|
| 39 |
+
x = rearrange(x, "b t c -> b c t")
|
| 40 |
+
x = self.conv1(x)
|
| 41 |
+
x = rearrange(x, "b c t -> b t c")
|
| 42 |
+
return x
|
| 43 |
+
|
| 44 |
+
class AudioDownSamplePoolBlock(nn.Module):
|
| 45 |
+
"""Downsample audio features using average pooling."""
|
| 46 |
+
|
| 47 |
+
def __init__(self, embed_dim):
|
| 48 |
+
super().__init__()
|
| 49 |
+
self.pool = nn.AvgPool1d(kernel_size=2)
|
| 50 |
+
|
| 51 |
+
def forward(self, x):
|
| 52 |
+
x = rearrange(x, "b t c -> b c t")
|
| 53 |
+
x = self.pool(x)
|
| 54 |
+
x = rearrange(x, "b c t -> b t c")
|
| 55 |
+
return x
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class AudioDownSampleMaxPoolBlock(nn.Module):
|
| 59 |
+
"""Downsample audio features using max pooling."""
|
| 60 |
+
|
| 61 |
+
def __init__(self, embed_dim):
|
| 62 |
+
super().__init__()
|
| 63 |
+
self.pool = nn.MaxPool1d(kernel_size=2)
|
| 64 |
+
|
| 65 |
+
def forward(self, x):
|
| 66 |
+
x = rearrange(x, "b t c -> b c t")
|
| 67 |
+
x = self.pool(x)
|
| 68 |
+
x = rearrange(x, "b c t -> b t c")
|
| 69 |
+
return x
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class SoundMultimodalProjector(PreTrainedModel):
|
| 73 |
+
"""Sound multimodal projector for mapping audio features to LLM space."""
|
| 74 |
+
config_class = SoundMultimodalProjectorConfig
|
| 75 |
+
|
| 76 |
+
def __init__(self, sound_mm_projector_cfg: SoundMultimodalProjectorConfig, config: PretrainedConfig):
|
| 77 |
+
super().__init__(sound_mm_projector_cfg)
|
| 78 |
+
if hasattr(config, "sound_mm_projector"):
|
| 79 |
+
sound_mm_projector_type = config.sound_mm_projector
|
| 80 |
+
else:
|
| 81 |
+
sound_mm_projector_type = sound_mm_projector_cfg.sound_mm_projector_type
|
| 82 |
+
self.sound_mm_projector_type = sound_mm_projector_type
|
| 83 |
+
self.config.sound_mm_projector_type = sound_mm_projector_type
|
| 84 |
+
|
| 85 |
+
if hasattr(config, "sound_mm_projector_cfg") and type(config.sound_mm_projector_cfg) == dict:
|
| 86 |
+
config.sound_mm_projector_cfg["sound_mm_projector_type"] = sound_mm_projector_type
|
| 87 |
+
|
| 88 |
+
if sound_mm_projector_type == "mlp":
|
| 89 |
+
self.layers = nn.Sequential(
|
| 90 |
+
nn.Linear(config.sound_hidden_size, config.hidden_size),
|
| 91 |
+
nn.GELU(),
|
| 92 |
+
nn.Linear(config.hidden_size, config.hidden_size),
|
| 93 |
+
)
|
| 94 |
+
elif sound_mm_projector_type == "mlp_downsample":
|
| 95 |
+
self.downsample_block = AudioDownSampleBlock(config.sound_hidden_size)
|
| 96 |
+
self.layers = nn.Sequential(
|
| 97 |
+
nn.Linear(config.sound_hidden_size, config.hidden_size),
|
| 98 |
+
nn.GELU(),
|
| 99 |
+
nn.Linear(config.hidden_size, config.hidden_size),
|
| 100 |
+
)
|
| 101 |
+
elif sound_mm_projector_type == "mlp_downsample_pool":
|
| 102 |
+
self.downsample_block = AudioDownSamplePoolBlock(config.sound_hidden_size)
|
| 103 |
+
self.layers = nn.Sequential(
|
| 104 |
+
nn.Linear(config.sound_hidden_size, config.hidden_size),
|
| 105 |
+
nn.GELU(),
|
| 106 |
+
nn.Linear(config.hidden_size, config.hidden_size),
|
| 107 |
+
)
|
| 108 |
+
elif sound_mm_projector_type == "mlp_downsample_pool_max":
|
| 109 |
+
self.downsample_block = AudioDownSampleMaxPoolBlock(config.sound_hidden_size)
|
| 110 |
+
self.layers = nn.Sequential(
|
| 111 |
+
nn.Linear(config.sound_hidden_size, config.hidden_size),
|
| 112 |
+
nn.GELU(),
|
| 113 |
+
nn.Linear(config.hidden_size, config.hidden_size),
|
| 114 |
+
)
|
| 115 |
+
else:
|
| 116 |
+
raise ValueError(f"Unknown projector type: {sound_mm_projector_type}")
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def forward(self, x, *args, **kwargs):
|
| 120 |
+
if self.sound_mm_projector_type in ["mlp_downsample", "mlp_downsample_pool", "mlp_downsample_pool_max"]:
|
| 121 |
+
x = self.downsample_block(x)
|
| 122 |
+
return self.layers(x)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
AutoConfig.register("sound_mm_projector", SoundMultimodalProjectorConfig)
|
| 126 |
+
AutoModel.register(SoundMultimodalProjectorConfig, SoundMultimodalProjector)
|
sound_mm_projector/config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "/lustre/fs12/portfolios/llmservice/projects/llmservice_fm_vision/users/hanrongy/project/vila/VILA-Internal/../exp_log/nvomni-8b-video-0d1-trope128_omniT_ras_n16_bs2048_ga8_mstep-1_j20250718/outputs/model/sound_mm_projector",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"SoundMultimodalProjector"
|
| 5 |
+
],
|
| 6 |
+
"model_type": "sound_mm_projector",
|
| 7 |
+
"sound_mm_projector_type": "mlp",
|
| 8 |
+
"torch_dtype": "bfloat16",
|
| 9 |
+
"transformers_version": "4.46.0"
|
| 10 |
+
}
|
sound_mm_projector/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bb57ebfdeb51af4a1c0de931fd43e6a4b93277552ad02ad01b1d9ba720bcb9a4
|
| 3 |
+
size 34879856
|
sound_tower/config.json
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "outputs/model/sound_tower",
|
| 3 |
+
"activation_dropout": 0.0,
|
| 4 |
+
"activation_function": "gelu",
|
| 5 |
+
"architectures": [
|
| 6 |
+
"Qwen2AudioEncoder"
|
| 7 |
+
],
|
| 8 |
+
"attention_dropout": 0.0,
|
| 9 |
+
"audio_config": {
|
| 10 |
+
"activation_function": "gelu",
|
| 11 |
+
"d_model": 1280,
|
| 12 |
+
"encoder_attention_heads": 20,
|
| 13 |
+
"encoder_ffn_dim": 5120,
|
| 14 |
+
"encoder_layers": 32,
|
| 15 |
+
"max_source_positions": 1500,
|
| 16 |
+
"model_type": "qwen2_audio_encoder",
|
| 17 |
+
"num_mel_bins": 128,
|
| 18 |
+
"scale_embedding": false
|
| 19 |
+
},
|
| 20 |
+
"audio_token_index": 151646,
|
| 21 |
+
"d_model": 1280,
|
| 22 |
+
"dropout": 0.0,
|
| 23 |
+
"encoder_attention_heads": 20,
|
| 24 |
+
"encoder_ffn_dim": 5120,
|
| 25 |
+
"encoder_layerdrop": 0.0,
|
| 26 |
+
"encoder_layers": 32,
|
| 27 |
+
"ignore_index": -100,
|
| 28 |
+
"init_std": 0.02,
|
| 29 |
+
"max_source_positions": 1500,
|
| 30 |
+
"model_type": "qwen2_audio_encoder",
|
| 31 |
+
"num_hidden_layers": 32,
|
| 32 |
+
"num_mel_bins": 128,
|
| 33 |
+
"scale_embedding": false,
|
| 34 |
+
"text_config": {
|
| 35 |
+
"bos_token_id": 151643,
|
| 36 |
+
"eos_token_id": 151645,
|
| 37 |
+
"intermediate_size": 11008,
|
| 38 |
+
"max_position_embeddings": 8192,
|
| 39 |
+
"model_type": "qwen2",
|
| 40 |
+
"rms_norm_eps": 1e-05,
|
| 41 |
+
"rope_theta": 10000,
|
| 42 |
+
"sliding_window": 32768,
|
| 43 |
+
"torch_dtype": "bfloat16",
|
| 44 |
+
"use_mrope": false,
|
| 45 |
+
"vocab_size": 156032
|
| 46 |
+
},
|
| 47 |
+
"torch_dtype": "bfloat16",
|
| 48 |
+
"transformers_version": "4.46.0",
|
| 49 |
+
"vocab_size": 156032
|
| 50 |
+
}
|
sound_tower/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:01fcbfa8ac3d3bc4c5ab97c439dfecfea2a9c2e061031280efed292fc37b4a44
|
| 3 |
+
size 1273988176
|