leoye committed
Commit fd01e7c · 0 parents

Initial commit

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete set of changes.
.gitattributes ADDED
@@ -0,0 +1,39 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ asset/omni_benchmarks.png filter=lfs diff=lfs merge=lfs -text
37
+ asset/omni_benchmarks2.png filter=lfs diff=lfs merge=lfs -text
38
+ llm/tokenizer.json filter=lfs diff=lfs merge=lfs -text
39
+ asset/performance.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,95 @@
1
+ # <span style="background: linear-gradient(45deg, #667eea 0%, #764ba2 25%, #f093fb 50%, #f5576c 75%, #4facfe 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; background-clip: text; font-weight: bold; font-size: 1.1em;">**OmniVinci: Enhancing Architecture and Data for Omni-Modal Understanding LLM**</span> <br />
2
+
3
+ [![Paper](https://img.shields.io/badge/ArXiv-Paper-brown)](https://arxiv.org/)
4
+ [![Code](https://img.shields.io/badge/GitHub-Link-blue)](https://github.com/NVlabs)
5
+ [![Model](https://img.shields.io/badge/HuggingFace-Model-yellow)](https://huggingface.co/nvidia/omnivinci)
6
+
7
+
8
+ ## Introduction
9
+ OmniVinci is an NVIDIA research project focused on exploring omni-modal LLMs that can not only see and read but also listen, speak, and reason.
10
+
11
+ OmniVinci ranks among the best omni-modality understanding models. Check out its performance on some of the most popular omni-modality, audio, and vision benchmarks:
12
+ <p align="center">
13
+ <img src="./asset/performance.png" width="80%"/>
14
+ </p>
15
+
16
+
17
+ ## Quickstart
18
+
19
+ Below, we provide simple examples to show how to use our model with Transformers.
20
+
21
+ ### Environment Setup
22
+
23
+ 1. Download the Hugging Face repository and navigate into it:
24
+ ```bash
25
+ huggingface-cli download nvidia/omnivinci --local-dir ./omnivinci --local-dir-use-symlinks False
26
+ cd ./omnivinci
27
+ ```
28
+
29
+ 2. Set up the Python environment (based on the NVILA codebase):
30
+ ```bash
31
+ bash ./environment_setup.sh omnivinci
32
+ ```
33
+
34
+ ### 🤗 Transformers Usage
35
+
36
+ #### Video (with Audio) Inference Example
37
+ ```python
38
+ from transformers import AutoProcessor, AutoModel, AutoConfig, AutoModelForCausalLM
39
+ import torch
40
+ import os
41
+
42
+ # default: Load the model on the available device(s)
43
+ model_path = "./"
44
+ video_path = "xxx.mp4"
45
+ generation_kwargs = {"max_new_tokens": 1024, "max_length": 99999999}
46
+ load_audio_in_video = True
47
+ num_video_frames = 128
48
+ audio_length = "max_3600"
49
+
50
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
51
+
52
+ model = AutoModel.from_pretrained(model_path,
53
+ trust_remote_code=True,
54
+ torch_dtype=torch.float16,
55
+ device_map="auto")
56
+
57
+ processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
58
+ generation_config = model.default_generation_config
59
+ generation_config.update(**generation_kwargs)
60
+
61
+ model.config.load_audio_in_video = load_audio_in_video
62
+ processor.config.load_audio_in_video = load_audio_in_video
63
+ if num_video_frames > 0:
64
+ model.config.num_video_frames = num_video_frames
65
+ processor.config.num_video_frames = num_video_frames
66
+ if audio_length != -1:
67
+ model.config.audio_chunk_length = audio_length
68
+ processor.config.audio_chunk_length = audio_length
69
+
70
+
71
+ conversation = [{
72
+ "role": "user",
73
+ "content": [
74
+ {"type": "video", "video":video_path},
75
+ {"type": "text", "text": "Assess the video, followed by a detailed description of its video and audio contents."}
76
+ ]
77
+ }]
78
+ text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
79
+
80
+ inputs = processor([text])
81
+
82
+ output_ids = model.generate(
83
+ input_ids=inputs.input_ids,
84
+ media=getattr(inputs, 'media', None),
85
+ media_config=getattr(inputs, 'media_config', None),
86
+ generation_config=generation_config,
87
+ )
88
+ print(processor.tokenizer.batch_decode(output_ids, skip_special_tokens=True))
89
+ ```
90
+
91
+ - **For audio and image inference examples, please refer to `example_mini_audio.py` and `example_mini_image.py`.**
92
+
93
+
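The two helper scripts are not included in this 50-file view. As a rough, unofficial sketch, an audio-only call should follow the same pattern as the video example above, replacing the video entry with an `{"type": "audio", "audio": ...}` content item (one of the content types handled in `auto_processor.py`). The audio path and generation settings below are placeholder assumptions, not the contents of `example_mini_audio.py`.

```python
# Unofficial sketch of audio-only inference; mirrors the video example above.
# "xxx.wav" and max_new_tokens are placeholder assumptions.
from transformers import AutoModel, AutoProcessor
import torch

model_path = "./"
audio_path = "xxx.wav"  # placeholder audio file

model = AutoModel.from_pretrained(model_path,
                                  trust_remote_code=True,
                                  torch_dtype=torch.float16,
                                  device_map="auto")
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

generation_config = model.default_generation_config
generation_config.update(max_new_tokens=256)

conversation = [{
    "role": "user",
    "content": [
        {"type": "audio", "audio": audio_path},
        {"type": "text", "text": "Describe the audio content in detail."},
    ],
}]
text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
inputs = processor([text])

output_ids = model.generate(
    input_ids=inputs.input_ids,
    media=getattr(inputs, "media", None),
    media_config=getattr(inputs, "media_config", None),
    generation_config=generation_config,
)
print(processor.tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```

As in the video example, options such as `model.config.audio_chunk_length` can be adjusted before calling `generate`.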
94
+ ## License / Terms of Use
95
+ The model is released under the [NVIDIA OneWay Noncommercial License](asset/NVIDIA_OneWay_Noncommercial_License.docx).
asset/NVIDIA_OneWay_Noncommercial_License.docx ADDED
Binary file (20.6 kB).
 
asset/omni_benchmarks.png ADDED

Git LFS Details

  • SHA256: 582f1f1c454a3c775162ed469ceed6a76aeca1f3e4d57e5c7710ae0eb1310dfa
  • Pointer size: 131 Bytes
  • Size of remote file: 684 kB
asset/omni_benchmarks2.png ADDED

Git LFS Details

  • SHA256: a7d759ea879119e1b894dd16d353eb24409d9d1c7d206a3b79eee3e093cce28d
  • Pointer size: 131 Bytes
  • Size of remote file: 151 kB
asset/performance.png ADDED

Git LFS Details

  • SHA256: 33be284f6fcff5627ebb9e3597944ac5fca9f8003a78a465ece7be7eefa3d0c1
  • Pointer size: 131 Bytes
  • Size of remote file: 233 kB
audio_encoder.py ADDED
@@ -0,0 +1,70 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+
21
+
22
+ class AudioTower(nn.Module):
23
+ def __init__(self, audio_tower, args, delay_load=False):
24
+ super().__init__()
25
+
26
+ self.is_loaded = False
27
+
28
+ self.audio_tower_name = audio_tower
29
+ self.cfg_only = None
30
+
31
+ def forward(self, sounds):
32
+ if type(sounds) is list:
33
+ sound_features = []
34
+ audio_output_lengths = []
35
+ for sound in sounds:
36
+ if hasattr(sound, "input_features"):
37
+ sound = sound["input_features"]
38
+ sound_feature = self.audio_tower(sound)
39
+ sound_feature = sound_feature.last_hidden_state
40
+ sound_feature = sound_feature.to(sound.dtype)
41
+ sound_features.append(sound_feature)
42
+ audio_output_lengths.append(sound_feature.shape[1])
43
+ sound_features = torch.cat(sound_features, dim=1).squeeze(0)
44
+ else:
45
+ raise NotImplementedError("Not implemented for this encoder")
46
+
47
+ return sound_features, audio_output_lengths
48
+
49
+ @property
50
+ def dummy_feature(self):
51
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
52
+
53
+ @property
54
+ def dtype(self):
55
+ return self.audio_tower.dtype
56
+
57
+ @property
58
+ def config(self):
59
+ if self.is_loaded:
60
+ return self.audio_tower.config
61
+ else:
62
+ return self.cfg_only
63
+
64
+ @property
65
+ def device(self):
66
+ return self.audio_tower.device
67
+
68
+ @property
69
+ def hidden_size(self):
70
+ return self.config.hidden_size
auto_processor.py ADDED
@@ -0,0 +1,476 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import base64
+ import copy
17
+ import os
18
+ import os.path as osp
19
+ import warnings
20
+ from collections import defaultdict
21
+ from io import BytesIO
22
+ from typing import List, Optional, Union
23
+
24
+ import PIL.Image
25
+ import requests
26
+ import torch
27
+ from transformers import AutoConfig, AutoImageProcessor, AutoModel, AutoProcessor, AutoTokenizer
28
+ from transformers.feature_extraction_utils import BatchFeature
29
+ from transformers.image_utils import ImageInput
30
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
31
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
32
+ from transformers.utils import logging
33
+
34
+ from .constants import DEFAULT_IMAGE_TOKEN, MEDIA_TOKENS
35
+ from .media import Image, Video, extract_media, Sound
36
+ from .mm_utils import process_image, process_images
37
+ from .tokenizer_utils import tokenize_conversation
38
+
39
+
40
+ def to_rgb(pil_image: PIL.Image.Image) -> PIL.Image.Image:
41
+ """Convert PIL image to RGB format."""
42
+ if pil_image.mode == "RGBA":
43
+ white_background = PIL.Image.new("RGB", pil_image.size, (255, 255, 255))
44
+ white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
45
+ return white_background
46
+ else:
47
+ return pil_image.convert("RGB")
48
+
49
+
50
+ def fetch_image(ele: dict[str, str | PIL.Image.Image], size_factor=None) -> PIL.Image.Image:
51
+ """Fetch and load image from various sources (local path, URL, base64, PIL.Image)."""
52
+ if "image" in ele:
53
+ image = ele["image"]
54
+ else:
55
+ image = ele["image_url"]
56
+ image_obj = None
57
+ if isinstance(image, PIL.Image.Image):
58
+ image_obj = image
59
+ elif image.startswith("http://") or image.startswith("https://"):
60
+ response = requests.get(image, stream=True)
61
+ image_obj = PIL.Image.open(BytesIO(response.content))
62
+ elif image.startswith("file://"):
63
+ image_obj = PIL.Image.open(image[7:])
64
+ elif image.startswith("data:image"):
65
+ if "base64," in image:
66
+ _, base64_data = image.split("base64,", 1)
67
+ data = base64.b64decode(base64_data)
68
+ image_obj = PIL.Image.open(BytesIO(data))
69
+ else:
70
+ image_obj = PIL.Image.open(image)
71
+ if image_obj is None:
72
+ raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
73
+ image = to_rgb(image_obj)
74
+
75
+ return image
76
+
77
+
78
+ def fetch_image_url_or_fpath(url_or_fpath):
79
+ """Fetch image from URL or local file path, returns local file path."""
80
+ if url_or_fpath.startswith("http") or url_or_fpath.startswith("https"):
81
+ import tempfile
82
+
83
+ import requests
84
+
85
+ # Download the image to a temporary file
86
+ temp_dir = tempfile.mkdtemp()
87
+ temp_file = os.path.join(temp_dir, os.path.basename(url_or_fpath))
88
+
89
+ response = requests.get(url_or_fpath, stream=True)
90
+ response.raise_for_status()
91
+
92
+ with open(temp_file, "wb") as f:
93
+ for chunk in response.iter_content(chunk_size=8192):
94
+ f.write(chunk)
95
+
96
+ return temp_file
97
+ elif url_or_fpath.startswith("file://"):
98
+ fpath = url_or_fpath.replace("file://", "")
99
+ assert osp.exists(fpath), f"File {fpath} does not exist"
100
+ return fpath
101
+ elif osp.exists(url_or_fpath):
102
+ assert osp.isfile(url_or_fpath), f"File {url_or_fpath} does not exist"
103
+ return url_or_fpath
104
+ else:
105
+ raise ValueError(f"Unsupported image path: {url_or_fpath}")
106
+
107
+
108
+ def pad_fn(input_ids_list: List[torch.Tensor], padding_value=0, target_len=None, padding_side="left") -> torch.Tensor:
109
+ # tensor shape is (batch_size, seq_len)
110
+ max_len = max([ids.shape[1] for ids in input_ids_list])
111
+ if target_len is not None:
112
+ assert target_len >= max_len, "target_len must be greater than or equal to max_len"
113
+ max_len = target_len
114
+
115
+ new_input_ids_list = []
116
+ for i, input_ids in enumerate(input_ids_list):
117
+ pad_tensor = torch.ones_like(input_ids) * padding_value
118
+ curr_len = input_ids.shape[1]
119
+ pad_tensor = pad_tensor[:, : max_len - curr_len]
120
+ if padding_side == "right":
121
+ input_ids = torch.cat((input_ids, pad_tensor), dim=1)
122
+ else:
123
+ input_ids = torch.cat((pad_tensor, input_ids), dim=1)
124
+ new_input_ids_list.append(input_ids)
125
+ return torch.cat(new_input_ids_list, dim=0)
126
+
127
+
128
+ def extract_value_from_conv(chat):
129
+ value = []
130
+ if isinstance(chat["content"], str):
131
+ value.append(chat["content"])
132
+ return value
133
+
134
+ # otherwise, it's a list of content
135
+ for content in chat["content"]:
136
+ if content["type"] == "image":
137
+ if "path" in content:
138
+ # VILA style, can be either filepath or http url
139
+ value.append(Image(fetch_image_url_or_fpath(content["path"])))
140
+ elif "image" in content:
141
+ # Qwen style
142
+ value.append(Image(fetch_image_url_or_fpath(content["image"])))
143
+ elif "image_pil" in content:
144
+ # Qwen style
145
+ assert isinstance(content["image_pil"], PIL.Image.Image), f"Type of image_pil must be PIL.Image.Image"
146
+ value.append(content["image_pil"])
147
+ else:
148
+ raise ValueError(f"Type = `image` , but no `path` or `image` in {chat['content']}")
149
+ elif content["type"] == "video":
150
+ if "video" in content:
151
+ # Qwen style
152
+ value.append(Video(fetch_image_url_or_fpath(content["video"])))
153
+ else:
154
+ raise ValueError(f"Type = `video` , but no `video` in {chat['content']}")
155
+ elif content["type"] == "text":
156
+ value.append(content["text"])
157
+ elif content["type"] == "audio":
158
+ value.append(Sound(fetch_image_url_or_fpath(content["audio"])))
159
+ elif content["type"] == "sound":
160
+ value.append(Sound(fetch_image_url_or_fpath(content["sound"])))
161
+ elif content["type"] == "speech":
162
+ value.append(Sound(fetch_image_url_or_fpath(content["speech"])))
163
+ else:
164
+ raise ValueError(f"Unsupported content type: {content['type']}")
165
+ return value
166
+
167
+
168
+ class VILAProcessorKwargs(ProcessingKwargs, total=False):
169
+ _defaults = {
170
+ "text_kwargs": {
171
+ "padding": False,
172
+ },
173
+ }
174
+
175
+
176
+ class VILAProcessor(ProcessorMixin):
177
+ attributes = []
178
+ valid_kwargs = []
179
+
180
+ def __init__(
181
+ self, image_processor=None, tokenizer=None, chat_template=None, config=None, padding_side="left", **kwargs
182
+ ):
183
+ self.image_token = MEDIA_TOKENS["image"]
184
+ self.video_token = MEDIA_TOKENS["video"]
185
+ self.speech_token = MEDIA_TOKENS["speech"]
186
+ self.sound_token = MEDIA_TOKENS["sound"]
187
+ self.config = config
188
+ self.image_processor = image_processor
189
+ self.tokenizer = tokenizer
190
+ self.padding_side = padding_side
191
+
192
+ # Use <|endoftext|> token as padding token for Qwen models
193
+ self.pad_token_id = self.tokenizer("<|endoftext|>").input_ids[0]
194
+ self.eos_token_id = self.tokenizer.eos_token_id
195
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
196
+
197
+ @staticmethod
198
+ def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
199
+ """
200
+ Extract vision information from conversations.
201
+ Reference: qwen_vl_utils
202
+ """
203
+ vision_infos = []
204
+ if isinstance(conversations[0], dict):
205
+ conversations = [conversations]
206
+ for conversation in conversations:
207
+ for message in conversation:
208
+ if isinstance(message["content"], list):
209
+ for ele in message["content"]:
210
+ if (
211
+ "image" in ele
212
+ or "image_url" in ele
213
+ or "video" in ele
214
+ or ele["type"] in ("image", "image_url", "video")
215
+ ):
216
+ vision_infos.append(ele)
217
+ return vision_infos
218
+
219
+ @staticmethod
220
+ def process_vision_info(
221
+ conversations: list[dict] | list[list[dict]],
222
+ return_video_kwargs: bool = False,
223
+ ) -> tuple[list[PIL.Image.Image] | None, list[torch.Tensor | list[PIL.Image.Image]] | None, Optional[dict]]:
224
+ """
225
+ Process vision information from conversations.
226
+ Reference: qwen_vl_utils
227
+
228
+ Note: NVILA does not depend on this function, but maintains the same interface.
229
+ """
230
+ vision_infos = VILAProcessor.extract_vision_info(conversations)
231
+ # Read images or videos
232
+ image_inputs = []
233
+ video_inputs = []
234
+ video_sample_fps_list = []
235
+ for vision_info in vision_infos:
236
+ if "image" in vision_info or "image_url" in vision_info:
237
+ image_inputs.append(fetch_image(vision_info))
238
+ elif "video" in vision_info:
239
+ video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True)
240
+ video_sample_fps_list.append(video_sample_fps)
241
+ video_inputs.append(video_input)
242
+ else:
243
+ raise ValueError("image, image_url or video should in content.")
244
+ if len(image_inputs) == 0:
245
+ image_inputs = None
246
+ if len(video_inputs) == 0:
247
+ video_inputs = None
248
+ if return_video_kwargs:
249
+ return image_inputs, video_inputs, {"fps": video_sample_fps_list}
250
+ return image_inputs, video_inputs
251
+
252
+ @staticmethod
253
+ def move_data_to_device(cls, prompt_inputs):
254
+ def _move_data_to_device(item):
255
+ # wrap function grpo trainer _prepare_input
256
+ kwargs = {"device": cls.args.device}
257
+ if cls.is_deepspeed_enabled and (torch.is_floating_point(item) or torch.is_complex(item)):
258
+ kwargs.update({"dtype": cls.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
259
+ return item.to(**kwargs)
260
+
261
+ prompt_inputs.input_ids = _move_data_to_device(prompt_inputs.input_ids)
262
+ prompt_inputs.attention_mask = _move_data_to_device(prompt_inputs.attention_mask)
263
+ if "image" in prompt_inputs.media:
264
+ prompt_inputs.media["image"] = [_move_data_to_device(img) for img in prompt_inputs.media["image"]]
265
+ return prompt_inputs
266
+
267
+ @classmethod
268
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
269
+ padding_side = kwargs.get("padding_side", "left")
270
+ if os.path.isdir(pretrained_model_name_or_path):
271
+ pretrained_model_name_or_path = pretrained_model_name_or_path
272
+ else:
273
+ print(f"pretrained_model_name_or_path {pretrained_model_name_or_path} is not a directory, downloading")
274
+ from huggingface_hub import snapshot_download
275
+
276
+ pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path)
277
+
278
+ image_processor = AutoImageProcessor.from_pretrained(
279
+ osp.join(pretrained_model_name_or_path, "vision_tower"), trust_remote_code=True
280
+ )
281
+ tokenizer = AutoTokenizer.from_pretrained(
282
+ osp.join(pretrained_model_name_or_path, "llm"), trust_remote_code=True
283
+ )
284
+ config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
285
+
286
+ return cls(image_processor=image_processor, tokenizer=tokenizer, config=config, padding_side=padding_side)
287
+
288
+ def __repr__(self):
289
+ return f"VILAProcessor(image_processor=SigLip, tokenizer={self.tokenizer}, config={self.config})"
290
+
291
+ def __call__(
292
+ self,
293
+ conversation=None,
294
+ **kwargs: Unpack[VILAProcessorKwargs],
295
+ ) -> BatchFeature:
296
+ """
297
+ The `conv` will be look like
298
+ [
299
+ {
300
+ 'from': 'human',
301
+ 'value': [
302
+ <transformers_modules.NVILA-Lite-2B-hf-preview.media.Image object at 0x154e68e4c460>,
303
+ 'What are the common elements in these pictures?'
304
+ ]
305
+ }
306
+ ]
307
+ and `conversation` will be a list of such `conv`s
308
+ """
309
+ if kwargs.get("text", None) is not None:
310
+ conversation = kwargs.get("text")
311
+ assert conversation is not None, "`conversation` or `text` is required"
312
+ padding_side = kwargs.get("padding_side", self.padding_side)
313
+
314
+ input_ids_list = []
315
+ attention_mask = []
316
+ media = defaultdict(list)
317
+ media_config = defaultdict(dict)
318
+ for conv in conversation:
319
+ feat = self.__single_call__(conv, **kwargs)
320
+ input_ids_list.append(feat.input_ids)
321
+ attention_mask.append(feat.attention_mask)
322
+ for name in feat.media:
323
+ media[name] += feat.media[name]
324
+ for name in feat.media_config:
325
+ media_config[name].update(feat.media_config[name])
326
+
327
+ # pad the input_ids to batchfy
328
+ input_ids = pad_fn(
329
+ input_ids_list,
330
+ padding_value=self.pad_token_id,
331
+ padding_side=padding_side,
332
+ )
333
+ # Ignore the pad token in the attention mask
334
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
335
+ attention_mask[input_ids == self.pad_token_id] = False
336
+ input_texts = self.tokenizer.batch_decode(input_ids)
337
+ bdata = BatchFeature(
338
+ data={
339
+ # "input_texts": input_texts,
340
+ "input_ids": input_ids,
341
+ "attention_mask": attention_mask,
342
+ "media": media,
343
+ "media_config": media_config,
344
+ }
345
+ )
346
+ return bdata
347
+
348
+ def __single_call__(
349
+ self,
350
+ conversation,
351
+ images: ImageInput = None,
352
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
353
+ videos = None,
354
+ **kwargs: Unpack[VILAProcessorKwargs],
355
+ ) -> BatchFeature:
356
+ conversation = copy.deepcopy(conversation)
357
+ media = extract_media(conversation, self.config)
358
+ # Process media
359
+ media_config = defaultdict(dict)
360
+ for name in media:
361
+ if name == "image":
362
+ if len(media["image"]) == 1 and self.config.image_aspect_ratio in ["dynamic", "dynamic_s2"]:
363
+ self.config.image_processor = self.image_processor
364
+ if self.config.image_aspect_ratio == "dynamic":
365
+ images = process_image(media["image"][0], self.config, None, enable_dynamic_res=True).half()
366
+ # Note: This assumes images appear at the first conversation position
367
+ conversation[0]["value"] = conversation[0]["value"].replace(
368
+ DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0]
369
+ )
370
+ else:
371
+ if type(self.config.s2_scales) is str:
372
+ self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
373
+ images, block_sizes = process_image(
374
+ media["image"][0], self.config, None, enable_dynamic_s2=True
375
+ )
376
+ images = images.half()
377
+ media_config[name]["block_sizes"] = [block_sizes]
378
+ else:
379
+ images = process_images(media["image"], self.image_processor, self.config).half()
380
+ media[name] = [image for image in images]
381
+ elif name == "video":
382
+ media[name] = [
383
+ process_images(images, self.image_processor, self.config).half() for images in media[name]
384
+ ]
385
+ elif name == "speech":
386
+ speeches = media["speech"]
387
+ media[name] = [speech for speech in speeches]
388
+ elif name == "sound":
389
+ sounds = media["sound"]
390
+ for sound in sounds:
391
+ if type(sound) is dict:
392
+ for k, v in sound.items():
393
+ sound[k] = v.half()
394
+ media[name] = [sound for sound in sounds]
395
+ elif name == "video_info":
396
+ media[name] = [media["video_info"]]
397
+ elif name == "audio_info":
398
+ media[name] = [media["audio_info"]]
399
+ else:
400
+ raise ValueError(f"Unsupported media type: {name}")
401
+
402
+ inputs = tokenize_conversation(
403
+ conversation,
404
+ self.tokenizer,
405
+ mm_use_bos_eos_tokens=self.config.mm_use_bos_eos_tokens,
406
+ unified_audio_encoder=self.config.unified_audio_encoder,
407
+ add_generation_prompt=True,
408
+ )
409
+
410
+ input_ids = inputs.unsqueeze(0)
411
+
412
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
413
+ return BatchFeature(
414
+ data={
415
+ "input_ids": input_ids,
416
+ "attention_mask": attention_mask,
417
+ "media": media,
418
+ "media_config": media_config,
419
+ }
420
+ )
421
+
422
+ def batch_decode(self, *args, **kwargs):
423
+ """
424
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
425
+ refer to the docstring of this method for more information.
426
+ """
427
+ return self.tokenizer.batch_decode(*args, **kwargs)
428
+
429
+ def decode(self, *args, **kwargs):
430
+ """
431
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
432
+ the docstring of this method for more information.
433
+ """
434
+ return self.tokenizer.decode(*args, **kwargs)
435
+
436
+ def post_process_image_text_to_text(self, generated_outputs):
437
+ """
438
+ Post-process the output of the model to decode the text.
439
+
440
+ Args:
441
+ generated_outputs (`torch.Tensor` or `np.ndarray`):
442
+ The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
443
+ or `(sequence_length,)`.
444
+
445
+ Returns:
446
+ `List[str]`: The decoded text.
447
+ """
448
+ return self.tokenizer.batch_decode(
449
+ generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
450
+ )
451
+
452
+ @property
453
+ def model_input_names(self):
454
+ tokenizer_input_names = self.tokenizer.model_input_names
455
+ image_processor_input_names = self.image_processor.model_input_names
456
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
457
+
458
+ def convert_gpt_conv_to_vila_conv(self, conversation):
459
+ vila_conv = []
460
+ for chat in conversation:
461
+ vila_chat = {"from": "", "value": []}
462
+ if chat["role"] in ("user", "system"):
463
+ # user allows to input image and text
464
+ vila_chat["from"] = "human" if chat["role"] == "user" else "system"
465
+ vila_chat["value"] = extract_value_from_conv(chat)
466
+ elif chat["role"] == "assistant":
467
+ vila_chat["from"] = "gpt"
468
+ vila_chat["value"] = extract_value_from_conv(chat)
469
+ else:
470
+ raise ValueError(f"Unsupported role: {chat['role']} in chat {chat}")
471
+ vila_conv.append(vila_chat)
472
+
473
+ return vila_conv
474
+
475
+ def apply_chat_template(self, conversation, add_generation_prompt=True, **kwargs):
476
+ return self.convert_gpt_conv_to_vila_conv(conversation)
base_projector.py ADDED
@@ -0,0 +1,236 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import re
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
22
+
23
+
24
+ class IdentityMap(nn.Module):
25
+ """Identity mapping that returns input unchanged."""
26
+ def __init__(self):
27
+ super().__init__()
28
+
29
+ def forward(self, x, *args, **kwargs):
30
+ return x
31
+
32
+ @property
33
+ def config(self):
34
+ return {"mm_projector_type": "identity"}
35
+
36
+
37
+ class SimpleResBlock(nn.Module):
38
+ """Simple residual block with layer normalization."""
39
+
40
+ def __init__(self, channels):
41
+ super().__init__()
42
+ self.pre_norm = nn.LayerNorm(channels)
43
+ self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels))
44
+
45
+ def forward(self, x):
46
+ x = self.pre_norm(x)
47
+ return x + self.proj(x)
48
+
49
+
50
+ class DownSampleBlock(nn.Module):
51
+ """Downsample 2D feature maps by rearranging into 2x2 blocks."""
52
+ def forward(self, x):
53
+ vit_embeds = x
54
+ h = w = int(vit_embeds.shape[1] ** 0.5)
55
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
56
+ vit_embeds = self.flat_square(vit_embeds)
57
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
58
+ return vit_embeds
59
+
60
+ def flat_square(self, x):
61
+ n, w, h, c = x.size()
62
+ if w % 2 == 1:
63
+ x = torch.concat([x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
64
+ n, w, h, c = x.size()
65
+ if h % 2 == 1:
66
+ x = torch.concat([x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
67
+ n, w, h, c = x.size()
68
+ x = x.contiguous()
69
+ x = x.view(n, w, int(h / 2), int(c * 2))
70
+ x = x.permute(0, 2, 1, 3).contiguous()
71
+ x = x.view(n, int(h / 2), int(w / 2), int(c * 4))
72
+ x = x.permute(0, 2, 1, 3).contiguous()
73
+ return x
74
+
75
+
76
+ class DownSample2x2BlockFix(nn.Module):
77
+ """Downsample 2D feature maps by rearranging into 2x2 blocks (fixed version)."""
78
+
79
+ def forward(self, x):
80
+ vit_embeds = x
81
+ h = w = int(vit_embeds.shape[1] ** 0.5)
82
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
83
+ vit_embeds = flat_square_2x2(vit_embeds)
84
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
85
+ return vit_embeds
86
+
87
+
88
+ def flat_square_2x2(x):
89
+ """Rearrange feature map into 2x2 blocks."""
90
+ n, w, h, c = x.size()
91
+ if w % 2 == 1:
92
+ x = torch.concat([x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
93
+ n, w, h, c = x.size()
94
+ x = x.contiguous()
95
+ if h % 2 == 1:
96
+ x = torch.concat([x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
97
+ n, w, h, c = x.size()
98
+ x = x.view(n, w, int(h / 2), int(c * 2))
99
+ x = x.permute(0, 2, 1, 3).contiguous()
100
+ x = x.view(n, int(h / 2), int(w / 2), int(c * 4))
101
+ x = x.permute(0, 2, 1, 3).contiguous()
102
+ return x
103
+
104
+
105
+ class DownSample3x3BlockFix(nn.Module):
106
+ """Downsample 2D feature maps by rearranging into 3x3 blocks (fixed version)."""
107
+
108
+ def forward(self, x):
109
+ vit_embeds = x
110
+ h = w = int(vit_embeds.shape[1] ** 0.5)
111
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
112
+ vit_embeds = flat_square_3x3(vit_embeds)
113
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
114
+ return vit_embeds
115
+
116
+
117
+ def flat_square_3x3(x):
118
+ """Rearrange feature map into 3x3 blocks."""
119
+ n, w, h, c = x.size()
120
+ if w % 3 != 0:
121
+ x = torch.concat([x, torch.zeros((n, 3 - (w % 3), h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
122
+ n, w, h, c = x.size()
123
+ x = x.contiguous()
124
+ if h % 3 != 0:
125
+ x = torch.concat([x, torch.zeros((n, w, 3 - (h % 3), c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
126
+ n, w, h, c = x.size()
127
+ x = x.view(n, w, int(h / 3), int(c * 3))
128
+ x = x.permute(0, 2, 1, 3).contiguous()
129
+ x = x.view(n, int(h / 3), int(w / 3), int(c * 9))
130
+ x = x.permute(0, 2, 1, 3).contiguous()
131
+ return x
132
+
133
+
134
+ class MultimodalProjectorConfig(PretrainedConfig):
135
+ """Configuration for vision-to-language projector."""
136
+
137
+ model_type = "v2l_projector"
138
+
139
+ def __init__(self, mm_projector_type: str = None, **kwargs):
140
+ super().__init__()
141
+ self.mm_projector_type = mm_projector_type
142
+
143
+
144
+ class MultimodalProjector(PreTrainedModel):
145
+ """Multimodal projector for mapping vision features to LLM space."""
146
+ config_class = MultimodalProjectorConfig
147
+
148
+ def __init__(self, mm_projector_cfg: MultimodalProjectorConfig, config: PretrainedConfig):
149
+ super().__init__(mm_projector_cfg)
150
+ mm_projector_type = mm_projector_cfg.mm_projector_type
151
+ self.downsample_rate = 1
152
+ if mm_projector_type == "identity":
153
+ self.layers = IdentityMap()
154
+ elif mm_projector_type == "linear":
155
+ self.layers = nn.Linear(config.mm_hidden_size, config.hidden_size)
156
+ elif mm_projector_type == "mlp_downsample":
157
+ self.layers = nn.Sequential(
158
+ DownSampleBlock(),
159
+ nn.LayerNorm(config.mm_hidden_size * 4),
160
+ nn.Linear(config.mm_hidden_size * 4, config.hidden_size),
161
+ nn.GELU(),
162
+ nn.Linear(config.hidden_size, config.hidden_size),
163
+ )
164
+ self.downsample_rate = 2
165
+ elif mm_projector_type == "mlp_downsample_2x2_fix":
166
+ self.layers = nn.Sequential(
167
+ DownSample2x2BlockFix(),
168
+ nn.LayerNorm(config.mm_hidden_size * 4),
169
+ nn.Linear(config.mm_hidden_size * 4, config.hidden_size),
170
+ nn.GELU(),
171
+ nn.Linear(config.hidden_size, config.hidden_size),
172
+ )
173
+ self.downsample_rate = 2
174
+ elif mm_projector_type == "mlp_downsample_3x3_fix":
175
+ self.layers = nn.Sequential(
176
+ DownSample3x3BlockFix(),
177
+ nn.LayerNorm(config.mm_hidden_size * 9),
178
+ nn.Linear(config.mm_hidden_size * 9, config.mm_hidden_size * 3),
179
+ nn.GELU(),
180
+ nn.LayerNorm(config.mm_hidden_size * 3),
181
+ nn.Linear(config.mm_hidden_size * 3, config.hidden_size),
182
+ nn.GELU(),
183
+ nn.Linear(config.hidden_size, config.hidden_size),
184
+ )
185
+ self.downsample_rate = 3
186
+ elif mm_projector_type == "mlp_downsample_3x3_s2":
187
+ self.layers = nn.Sequential(
188
+ DownSample3x3BlockFix(),
189
+ nn.LayerNorm(config.mm_hidden_size * 9),
190
+ nn.Linear(config.mm_hidden_size * 9, config.mm_hidden_size * 3),
191
+ nn.GELU(),
192
+ nn.LayerNorm(config.mm_hidden_size * 3),
193
+ nn.Linear(config.mm_hidden_size * 3, config.mm_hidden_size),
194
+ nn.GELU(),
195
+ nn.LayerNorm(config.mm_hidden_size),
196
+ nn.Linear(config.mm_hidden_size, config.mm_hidden_size // 3),
197
+ nn.GELU(),
198
+ nn.LayerNorm(config.mm_hidden_size // 3),
199
+ nn.Linear(config.mm_hidden_size // 3, config.hidden_size),
200
+ nn.GELU(),
201
+ nn.Linear(config.hidden_size, config.hidden_size),
202
+ )
203
+ elif mm_projector_type == "mlp_downsample_3x3_s2_new":
204
+ self.layers = nn.Sequential(
205
+ DownSample3x3BlockFix(),
206
+ nn.LayerNorm(config.mm_hidden_size * 9),
207
+ nn.Linear(config.mm_hidden_size * 9, config.mm_hidden_size * 4),
208
+ nn.GELU(),
209
+ nn.LayerNorm(config.mm_hidden_size * 4),
210
+ nn.Linear(config.mm_hidden_size * 4, config.mm_hidden_size * 2),
211
+ nn.GELU(),
212
+ nn.LayerNorm(config.mm_hidden_size * 2),
213
+ nn.Linear(config.mm_hidden_size * 2, config.mm_hidden_size),
214
+ nn.GELU(),
215
+ nn.LayerNorm(config.mm_hidden_size),
216
+ nn.Linear(config.mm_hidden_size, config.mm_hidden_size // 3),
217
+ nn.GELU(),
218
+ nn.LayerNorm(config.mm_hidden_size // 3),
219
+ nn.Linear(config.mm_hidden_size // 3, config.hidden_size),
220
+ nn.GELU(),
221
+ nn.Linear(config.hidden_size, config.hidden_size),
222
+ )
223
+ else:
224
+ mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", mm_projector_type)
225
+ if mlp_gelu_match:
226
+ mlp_depth = int(mlp_gelu_match.group(1))
227
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
228
+ for _ in range(1, mlp_depth):
229
+ modules.append(nn.GELU())
230
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
231
+ self.layers = nn.Sequential(*modules)
232
+ else:
233
+ raise ValueError(f"Unknown projector type: {mm_projector_type}")
234
+
235
+ def forward(self, x, *args, **kwargs):
236
+ return self.layers(x)
builder.py ADDED
@@ -0,0 +1,253 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import math
18
+ import os
19
+ import os.path as osp
20
+ import warnings
21
+ from dataclasses import asdict
22
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
23
+
24
+ import torch
25
+ import transformers
26
+ from huggingface_hub import file_exists, repo_exists
27
+ from huggingface_hub.utils import HFValidationError
28
+ from transformers import (
29
+ AutoConfig,
30
+ AutoModelForCausalLM,
31
+ AutoTokenizer,
32
+ PretrainedConfig,
33
+ PreTrainedModel,
34
+ PreTrainedTokenizer,
35
+ )
36
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
37
+
38
+ from .constants import MEDIA_TOKENS, SENTINEL_TOKEN
39
+ from .conversation import SeparatorStyle, default_conversation
40
+
41
+ DUMMY_CONVERSATION = [
42
+ {"from": "human", "value": "question"},
43
+ {"from": "gpt", "value": "answer"},
44
+ ] * 10
45
+
46
+
47
+ def tokenizer_image_token(prompt, tokenizer, return_tensors=None):
48
+ """Tokenize a prompt and return input IDs."""
49
+ return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]
50
+
51
+
52
+ def has_tokenizer(repo_id_or_path: str) -> bool:
53
+ """Check if a tokenizer exists at the given path or repository."""
54
+ # Check if the tokenizer is in a local directory
55
+ if osp.exists(osp.join(repo_id_or_path, "tokenizer_config.json")):
56
+ return True
57
+
58
+ # Check if the tokenizer is in a Hugging Face Hub repo
59
+ try:
60
+ return repo_exists(repo_id_or_path) and file_exists(repo_id_or_path, "tokenizer_config.json")
61
+ except HFValidationError:
62
+ return False
63
+
64
+
65
+ def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
66
+ """Add sentinel token to tokenizer if not already present."""
67
+ if not hasattr(tokenizer, "sentinel_token"):
68
+ tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
69
+ tokenizer.sentinel_token = SENTINEL_TOKEN
70
+ tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)
71
+
72
+
73
+ def tokenize_conversation_legacy(
74
+ messages: Sequence[Dict[str, str]],
75
+ tokenizer: transformers.PreTrainedTokenizer,
76
+ add_generation_prompt: bool = False,
77
+ overrides: Optional[Dict[str, str]] = None,
78
+ no_system_prompt: bool = False,
79
+ ) -> torch.Tensor:
80
+ """Tokenize conversation using legacy format."""
81
+ conv = default_conversation.copy()
82
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
83
+
84
+ if no_system_prompt:
85
+ conv.system = ""
86
+
87
+ # Skip the first message if it is not from human
88
+ if messages[0]["from"] != "human":
89
+ messages = messages[1:]
90
+
91
+ # Add a generation prompt if needed
92
+ if add_generation_prompt:
93
+ messages.append({"from": "gpt", "value": None})
94
+
95
+ conv.messages = []
96
+ for turn, message in enumerate(messages):
97
+ role = roles[message["from"]]
98
+ assert role == conv.roles[turn % 2]
99
+ if overrides is not None and message["from"] in overrides:
100
+ conv.append_message(role, overrides[message["from"]])
101
+ else:
102
+ conv.append_message(role, message["value"])
103
+
104
+ return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")
105
+
106
+
107
+ def tokenize_conversation(
108
+ messages: Sequence[Dict[str, str]],
109
+ tokenizer: transformers.PreTrainedTokenizer,
110
+ add_generation_prompt: bool = False,
111
+ overrides: Optional[Dict[str, str]] = None,
112
+ no_system_prompt: bool = False,
113
+ ) -> torch.Tensor:
114
+ """Tokenize conversation using modern chat template format."""
115
+ # Normalize the conversation before tokenization
116
+ for message in messages:
117
+ message["value"] = message["value"].strip()
118
+
119
+ if default_conversation.sep_style != SeparatorStyle.AUTO:
120
+ return tokenize_conversation_legacy(
121
+ messages,
122
+ tokenizer,
123
+ add_generation_prompt=add_generation_prompt,
124
+ overrides=overrides,
125
+ no_system_prompt=no_system_prompt,
126
+ )
127
+
128
+ conversation = []
129
+ for m in messages:
130
+ message = {}
131
+ if m["from"] == "human":
132
+ message["role"] = "user"
133
+ elif m["from"] == "gpt":
134
+ message["role"] = "assistant"
135
+ else:
136
+ raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")
137
+
138
+ message["content"] = m["value"]
139
+ if overrides is not None and m["from"] in overrides:
140
+ message["content"] = overrides[m["from"]]
141
+ conversation.append(message)
142
+
143
+ if no_system_prompt:
144
+ conversation = [{"role": "system", "content": ""}] + conversation
145
+
146
+ text = tokenizer.apply_chat_template(
147
+ conversation,
148
+ add_generation_prompt=add_generation_prompt,
149
+ tokenize=False,
150
+ )
151
+ return tokenizer_image_token(text, tokenizer, return_tensors="pt")
152
+
153
+
154
+ def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
155
+ """Infer stop tokens from tokenizer by analyzing dummy conversation."""
156
+ _maybe_add_sentinel_token(tokenizer)
157
+ template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})
158
+
159
+ stop_tokens = {tokenizer.eos_token}
160
+ for k in range(template.size(0) - 1):
161
+ if template[k] == tokenizer.sentinel_token_id:
162
+ stop_token = tokenizer.decode(template[k + 1])
163
+ stop_tokens.add(stop_token)
164
+ return list(stop_tokens)
165
+
166
+
167
+ def context_length_extension(config):
168
+ """Extend context length using RoPE scaling if needed."""
169
+ orig_ctx_len = getattr(config, "max_position_embeddings", None)
170
+ model_max_length = getattr(config, "model_max_length", None)
171
+ if orig_ctx_len and model_max_length > orig_ctx_len:
172
+ print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
173
+ scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
174
+ config.rope_scaling = {"type": "linear", "factor": scaling_factor}
175
+ return config
176
+
177
+
178
+ def build_llm_and_tokenizer(
179
+ model_name_or_path: str,
180
+ config: PretrainedConfig,
181
+ attn_implementation=None,
182
+ model_max_length=None,
183
+ *args,
184
+ **kwargs,
185
+ ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
186
+ """Build language model and tokenizer from pretrained checkpoint."""
187
+ llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
188
+ llm_cfg._attn_implementation = attn_implementation
189
+ llm_cfg.model_max_length = model_max_length
190
+ if model_max_length is not None:
191
+ context_length_extension(llm_cfg)
192
+
193
+ # Quantization related
194
+ quantization_restore_from_checkpoint = False
195
+
196
+ if type(config.model_dtype) == str:
197
+ model_dtype = eval(config.model_dtype)
198
+ else:
199
+ model_dtype = config.model_dtype
200
+
201
+ if quantization_restore_from_checkpoint:
202
+ fp8_model_name_or_path = kwargs.pop("fp8_llm_cfg", None)
203
+
204
+ llm = AutoModelForCausalLM.from_pretrained(
205
+ fp8_model_name_or_path, config=llm_cfg, torch_dtype=model_dtype, *args, **kwargs
206
+ )
207
+ else:
208
+ if is_deepspeed_zero3_enabled():
209
+ kwargs.pop("device_map")
210
+ llm = AutoModelForCausalLM.from_pretrained(
211
+ model_name_or_path, config=llm_cfg, torch_dtype=model_dtype, *args, **kwargs
212
+ )
213
+ print(f"Loaded model from {model_name_or_path} with dtype {model_dtype}")
214
+
215
+ # Locate the tokenizer.
216
+ llm_path = model_name_or_path
217
+ if not has_tokenizer(llm_path):
218
+ llm_path = osp.join(llm_path, "llm")
219
+ if not has_tokenizer(llm_path):
220
+ raise ValueError(f"Cannot find tokenizer in {llm_path}.")
221
+
222
+ tokenizer = AutoTokenizer.from_pretrained(llm_path, padding_side="right", use_fast=True, legacy=False)
223
+ if model_max_length is not None:
224
+ tokenizer.model_max_length = model_max_length
225
+
226
+ # Load chat template if specified.
227
+ if getattr(config, "chat_template", None) is not None:
228
+ print(f"Using chat template: {config.chat_template}")
229
+ fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja")
230
+ if not os.path.exists(fpath):
231
+ fpath = os.path.join(os.path.dirname(model_name_or_path), f"{config.chat_template}.jinja")
232
+ with open(fpath) as fd:
233
+ chat_template = fd.read()
234
+ tokenizer.chat_template = chat_template.replace(" ", "").replace("\n", "")
235
+
236
+ # Set stop tokens for the tokenizer
237
+ tokenizer.stop_tokens = infer_stop_tokens(tokenizer)
238
+ tokenizer.stop_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.stop_tokens)
239
+
240
+ # Add media tokens to the tokenizer
241
+ tokenizer.media_tokens = MEDIA_TOKENS
242
+ tokenizer.media_token_ids = {}
243
+ for name, token in MEDIA_TOKENS.items():
244
+ if config.speech_tower_cfg is None and name == "speech":
245
+ continue
246
+ if config.sound_tower_cfg is None and name == "sound":
247
+ continue
248
+ tokenizer.add_tokens([token], special_tokens=True)
249
+ tokenizer.media_token_ids[name] = tokenizer.convert_tokens_to_ids(token)
250
+ tokenizer.media_tokens[name] = token
251
+
252
+ config.hidden_size = llm.config.hidden_size
253
+ return llm, tokenizer
config.json ADDED
The diff for this file is too large to render. See raw diff
 
configuration_vila.py ADDED
@@ -0,0 +1,127 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import json
18
+ import math
19
+ import os
20
+ import os.path as osp
21
+ from copy import deepcopy
22
+ from threading import Thread
23
+ from typing import List, Optional
24
+
25
+ import torch
26
+ import torchvision
27
+ from PIL import Image
28
+ from transformers import (
29
+ AutoProcessor,
30
+ PretrainedConfig,
31
+ PreTrainedModel,
32
+ Qwen2Config,
33
+ Qwen2ForCausalLM,
34
+ Qwen2PreTrainedModel,
35
+ TextIteratorStreamer,
36
+ )
37
+
38
+
39
+ class VILAConfig(PretrainedConfig):
40
+ model_type = "vila"
41
+ keys_to_ignore_at_inference = ["past_key_values"]
42
+
43
+ def __init__(
44
+ self,
45
+ llm_cfg=None,
46
+ vision_tower_cfg=None,
47
+ mm_projector_cfg=None,
48
+ speech_tower_cfg=None,
49
+ sound_tower_cfg=None,
50
+ speech_mm_projector_cfg=None,
51
+ sound_mm_projector_cfg=None,
52
+ architectures=None,
53
+ resume_path=None,
54
+ hidden_size=None,
55
+ mm_hidden_size=None,
56
+ image_aspect_ratio=None,
57
+ num_video_frames=None,
58
+ fps=None,
59
+ mm_vision_select_layer=None,
60
+ mm_vision_select_feature=None,
61
+ mm_use_im_start_end=False,
62
+ mm_use_im_patch_token=False,
63
+ mm_projector_lr=None,
64
+ vision_tower_lr=None,
65
+ vision_resolution=None,
66
+ interpolate_mode=None,
67
+ s2=None,
68
+ dynamic_s2=None,
69
+ s2_scales=None,
70
+ s2_max_split_size=None,
71
+ s2_resize_output_to_scale_idx=0,
72
+ min_tiles: Optional[int] = 1,
73
+ max_tiles: Optional[int] = 12,
74
+ num_time_tokens=None,
75
+ time_token_format=None,
76
+ image_encoder: str = '{"_target_": "llava.model.encoders.BasicImageEncoder"}',
77
+ video_encoder: str = '{"_target_": "llava.model.encoders.TSPVideoEncoder"}',
78
+ sound_encoder: str = '{"_target_": "llava.model.encoders.BasicSoundEncoder"}',
79
+ speech_encoder: str = '{"_target_": "llava.model.encoders.BasicSpeechEncoder"}',
80
+ **kwargs,
81
+ ):
82
+ super().__init__(**kwargs)
83
+
84
+ self.architectures = architectures
85
+ self.llm_cfg = llm_cfg
86
+ self.vision_tower_cfg = vision_tower_cfg
87
+ self.mm_projector_cfg = mm_projector_cfg
88
+ self.speech_tower_cfg = speech_tower_cfg
89
+ self.sound_tower_cfg = sound_tower_cfg
90
+ self.speech_mm_projector_cfg = speech_mm_projector_cfg
91
+ self.sound_mm_projector_cfg = sound_mm_projector_cfg
92
+ self.resume_path = resume_path
93
+
94
+ self.hidden_size = hidden_size
95
+ self.mm_hidden_size = mm_hidden_size
96
+ self.image_aspect_ratio = image_aspect_ratio
97
+ self.num_video_frames = num_video_frames
98
+ self.fps = fps
99
+ self.mm_vision_select_layer = mm_vision_select_layer
100
+ self.mm_vision_select_feature = mm_vision_select_feature
101
+ self.mm_use_im_start_end = mm_use_im_start_end
102
+ self.mm_use_im_patch_token = mm_use_im_patch_token
103
+ self.mm_projector_lr = mm_projector_lr
104
+ self.vision_tower_lr = vision_tower_lr
105
+ self.vision_resolution = vision_resolution
106
+ self.interpolate_mode = interpolate_mode
107
+ self.s2 = s2
108
+ self.dynamic_s2 = dynamic_s2
109
+ self.s2_scales = s2_scales
110
+ self.s2_max_split_size = s2_max_split_size
111
+ self.s2_resize_output_to_scale_idx = s2_resize_output_to_scale_idx
112
+ self.min_tiles = min_tiles
113
+ self.max_tiles = max_tiles
114
+ self.num_time_tokens = num_time_tokens
115
+ self.time_token_format = time_token_format
116
+
117
+ self.image_encoder = image_encoder
118
+ self.video_encoder = video_encoder
119
+ self.sound_encoder = sound_encoder
120
+ self.speech_encoder = speech_encoder
121
+ self.audio_sampling_rate = 16000
122
+ self.audio_chunk_length = 120
123
+ self.interleaved_vis_aud_in_video = True
124
+ self.interleaved_video_segment_duration = 30
125
+ self.audio_hop_length = 60
126
+
127
+ super().__init__(**kwargs)
constants.py ADDED
@@ -0,0 +1,83 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
18
+
19
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
20
+ WORKER_HEART_BEAT_INTERVAL = 15
21
+
22
+ LOGDIR = "."
23
+
24
+ # Model Constants
25
+ IGNORE_INDEX = -100
26
+ DEFAULT_IMAGE_TOKEN = "<image>"
27
+ DEFAULT_SOUND_TOKEN = "<sound>"
28
+ DEFAULT_SPEECH_TOKEN = "<speech>"
29
+ SENTINEL_TOKEN = "<vila/sentinel>"
30
+ DEFAULT_IM_START_TOKEN = "<im_start>"
31
+ DEFAULT_IM_END_TOKEN = "<im_end>"
32
+
33
+
34
+ SENTINEL_TOKEN = "<vila/sentinel>"
35
+
36
+ MEDIA_TOKENS = {
37
+ "image": "<image>",
38
+ "video": "<vila/video>",
39
+ "speech": "<speech>",
40
+ "sound": "<sound>",
41
+ }
42
+
43
+ # Token IDs for different model variants:
44
+ """
45
+ vila:
46
+ 151643: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
47
+ 151644: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
48
+ 151645: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
49
+ 151646: AddedToken("[BOS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
50
+ 151647: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
51
+ 151648: AddedToken("<vila/sentinel>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
52
+ 151649: AddedToken("<image>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
53
+ 151650: AddedToken("<vila/video>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
54
+
55
+ xvila:
56
+ 151643: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
57
+ 151644: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
58
+ 151645: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
59
+ 151646: AddedToken("[BOS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
60
+ 151647: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
61
+ 151648: AddedToken("<vila/sentinel>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
62
+ 151649: AddedToken("<image>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
63
+ 151650: AddedToken("<vila/video>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
64
+ 151651: AddedToken("<speech>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
65
+ 151652: AddedToken("<sound>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
66
+ 151653: AddedToken("<|image_bos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
67
+ 151654: AddedToken("<|image_eos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
68
+ 151655: AddedToken("<|video_bos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
69
+ 151656: AddedToken("<|video_eos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
70
+ 151657: AddedToken("<|speech_bos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
71
+ 151658: AddedToken("<|speech_eos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
72
+ 151659: AddedToken("<|sound_bos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
73
+ 151660: AddedToken("<|sound_eos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
74
+ """
75
+ MM_BOS_EOS_TOKENS = {
76
+ "image": ["<|image_bos|>", "<|image_eos|>"],
77
+ "video": ["<|video_bos|>", "<|video_eos|>"],
78
+ "speech": ["<|speech_bos|>", "<|speech_eos|>"],
79
+ "sound": ["<|sound_bos|>", "<|sound_eos|>"],
80
+ }
81
+
82
+ NUM_EXTRA_TOKENS_VILA = 8
83
+ NUM_EXTRA_TOKENS_XVILA = 10
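
A minimal sketch of how the constants above compose into a prompt fragment, assuming the repository root is on `sys.path` so that `constants` imports directly; the wrapping shown illustrates the token layout of the xvila variant rather than the exact string the processor emits:

    from constants import MEDIA_TOKENS, MM_BOS_EOS_TOKENS

    def wrap_media(kind: str) -> str:
        # e.g. kind="video" -> "<|video_bos|><vila/video><|video_eos|>"
        bos, eos = MM_BOS_EOS_TOKENS[kind]
        return f"{bos}{MEDIA_TOKENS[kind]}{eos}"

    prompt = wrap_media("video") + "\nWhat are they talking about in detail?"
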
conversation.py ADDED
@@ -0,0 +1,189 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
17
+
18
+ import dataclasses
19
+ from enum import Enum, auto
20
+ from typing import List
21
+
22
+
23
+ class SeparatorStyle(Enum):
24
+ """Different separator style."""
25
+
26
+ AUTO = auto()
27
+ TWO = auto()
28
+ MPT = auto()
29
+ PLAIN = auto()
30
+ LLAMA_3 = auto()
31
+
32
+
33
+ @dataclasses.dataclass
34
+ class Conversation:
35
+ """A class that keeps all conversation history."""
36
+
37
+ system: str
38
+ roles: List[str]
39
+ messages: List[List[str]]
40
+ sep_style: SeparatorStyle = SeparatorStyle.AUTO
41
+ sep: str = "###"
42
+ sep2: str = None
43
+ version: str = "Unknown"
44
+
45
+ def get_prompt(self):
46
+ messages = self.messages
47
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
48
+ messages = self.messages.copy()
49
+ init_role, init_msg = messages[0].copy()
50
+ init_msg = init_msg[0].replace("<image>", "").strip()
51
+ messages[0] = (init_role, "<image>\n" + init_msg)
52
+
53
+ if self.sep_style == SeparatorStyle.TWO:
54
+ seps = [self.sep, self.sep2]
55
+ ret = self.system + seps[0]
56
+ for i, (role, message) in enumerate(messages):
57
+ if message:
58
+ if type(message) is tuple:
59
+ message, _, _ = message
60
+ ret += role + ": " + message + seps[i % 2]
61
+ else:
62
+ ret += role + ":"
63
+ elif self.sep_style == SeparatorStyle.LLAMA_3:
64
+ ret = self.system + self.sep
65
+ for rid, (role, message) in enumerate(messages):
66
+ if message:
67
+ if type(message) is tuple:
68
+ message = message[0]
69
+ sep = self.sep if rid < len(messages) - 1 else self.sep2
70
+ ret += role + message + sep
71
+ else:
72
+ ret += role
73
+ elif self.sep_style == SeparatorStyle.MPT:
74
+ ret = self.system + self.sep
75
+ for role, message in messages:
76
+ if message:
77
+ if type(message) is tuple:
78
+ message, _, _ = message
79
+ ret += role + message + self.sep
80
+ else:
81
+ ret += role
82
+ elif self.sep_style == SeparatorStyle.PLAIN:
83
+ seps = [self.sep, self.sep2]
84
+ ret = self.system
85
+ for i, (role, message) in enumerate(messages):
86
+ if message:
87
+ if type(message) is tuple:
88
+ message, _, _ = message
89
+ ret += message + seps[i % 2]
90
+ else:
91
+ ret += ""
92
+ else:
93
+ raise ValueError(f"Invalid style: {self.sep_style}")
94
+
95
+ return ret
96
+
97
+ def append_message(self, role, message):
98
+ self.messages.append([role, message])
99
+
100
+ def copy(self):
101
+ return Conversation(
102
+ system=self.system,
103
+ roles=self.roles,
104
+ messages=[[x, y] for x, y in self.messages],
105
+ sep_style=self.sep_style,
106
+ sep=self.sep,
107
+ sep2=self.sep2,
108
+ version=self.version,
109
+ )
110
+
111
+
112
+ conv_auto = Conversation(
113
+ system="",
114
+ roles=("", ""),
115
+ messages=(),
116
+ sep_style=SeparatorStyle.AUTO,
117
+ sep="\n",
118
+ )
119
+
120
+ conv_vicuna_v1 = Conversation(
121
+ system="A chat between a curious user and an artificial intelligence assistant. "
122
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
123
+ roles=("USER", "ASSISTANT"),
124
+ version="v1",
125
+ messages=(),
126
+ sep_style=SeparatorStyle.TWO,
127
+ sep=" ",
128
+ sep2="</s>",
129
+ )
130
+
131
+ conv_llava_plain = Conversation(
132
+ system="",
133
+ roles=("", ""),
134
+ messages=(),
135
+ sep_style=SeparatorStyle.PLAIN,
136
+ sep="\n",
137
+ )
138
+
139
+ hermes_2 = Conversation(
140
+ system="<|im_start|>system\nAnswer the questions.",
141
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
142
+ sep_style=SeparatorStyle.MPT,
143
+ sep="<|im_end|>",
144
+ messages=(),
145
+ version="hermes-2",
146
+ )
147
+
148
+ llama_3_chat = Conversation(
149
+ system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. "
150
+ "You are able to understand the visual content that the user provides, "
151
+ "and assist the user with a variety of tasks using natural language.",
152
+ roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
153
+ version="llama_v3",
154
+ messages=(),
155
+ sep_style=SeparatorStyle.LLAMA_3,
156
+ sep="<|eot_id|>",
157
+ sep2="<|end_of_text|>",
158
+ )
159
+
160
+
161
+ default_conversation = conv_auto
162
+ conv_templates = {
163
+ "auto": conv_auto,
164
+ "hermes-2": hermes_2,
165
+ "llama_3": llama_3_chat,
166
+ "v1": conv_vicuna_v1,
167
+ "vicuna_v1": conv_vicuna_v1,
168
+ "plain": conv_llava_plain,
169
+ }
170
+
171
+
172
+ CONVERSATION_MODE_MAPPING = {
173
+ "vila1.5-3b": "vicuna_v1",
174
+ "vila1.5-8b": "llama_3",
175
+ "vila1.5-13b": "vicuna_v1",
176
+ "vila1.5-40b": "hermes-2",
177
+ "llama-3": "llama_3",
178
+ "llama3": "llama_3",
179
+ }
180
+
181
+
182
+ def auto_set_conversation_mode(model_name_or_path: str) -> None:
183
+ """Automatically set conversation mode based on model name/path."""
184
+ global default_conversation
185
+ for k, v in CONVERSATION_MODE_MAPPING.items():
186
+ if k in model_name_or_path.lower():
187
+ print(f"Setting conversation mode to `{v}` based on model name/path `{model_name_or_path}`.")
188
+ default_conversation = conv_templates[v]
189
+ return
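
The templates above can be exercised on their own; a minimal sketch, assuming `conversation.py` is importable from the checkout root (the `<vila/video>` placeholder comes from constants.py and is only illustrative here):

    import conversation as conv_lib

    conv = conv_lib.conv_templates["hermes-2"].copy()   # copy before mutating the shared template
    conv.append_message(conv.roles[0], "<vila/video>\nDescribe the clip.")
    conv.append_message(conv.roles[1], None)            # leave the assistant turn open
    print(conv.get_prompt())
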
distributed.py ADDED
@@ -0,0 +1,89 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import os
18
+ import warnings
19
+ from typing import Any, List, Optional
20
+
21
+ from torch import distributed as dist
22
+
23
+ __all__ = [
24
+ "init",
25
+ "is_initialized",
26
+ "size",
27
+ "rank",
28
+ "local_size",
29
+ "local_rank",
30
+ "is_main",
31
+ "barrier",
32
+ "gather",
33
+ "all_gather",
34
+ ]
35
+
36
+
37
+ def init() -> None:
38
+ if "RANK" not in os.environ:
39
+ warnings.warn("Environment variable `RANK` is not set. Skipping distributed initialization.")
40
+ return
41
+ dist.init_process_group(backend="nccl", init_method="env://")
42
+
43
+
44
+ def is_initialized() -> bool:
45
+ return dist.is_initialized()
46
+
47
+
48
+ def size() -> int:
49
+ return int(os.environ.get("WORLD_SIZE", 1))
50
+
51
+
52
+ def rank() -> int:
53
+ return int(os.environ.get("RANK", 0))
54
+
55
+
56
+ def local_size() -> int:
57
+ return int(os.environ.get("LOCAL_WORLD_SIZE", 1))
58
+
59
+
60
+ def local_rank() -> int:
61
+ return int(os.environ.get("LOCAL_RANK", 0))
62
+
63
+
64
+ def is_main() -> bool:
65
+ return rank() == 0
66
+
67
+
68
+ def barrier() -> None:
69
+ dist.barrier()
70
+
71
+
72
+ def gather(obj: Any, dst: int = 0) -> Optional[List[Any]]:
73
+ if not is_initialized():
74
+ return [obj]
75
+ if is_main():
76
+ objs = [None for _ in range(size())]
77
+ dist.gather_object(obj, objs, dst=dst)
78
+ return objs
79
+ else:
80
+ dist.gather_object(obj, dst=dst)
81
+ return None
82
+
83
+
84
+ def all_gather(obj: Any) -> List[Any]:
85
+ if not is_initialized():
86
+ return [obj]
87
+ objs = [None for _ in range(size())]
88
+ dist.all_gather_object(objs, obj)
89
+ return objs
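
These helpers are thin wrappers around `torch.distributed`; a minimal sketch of gathering per-rank results, assuming the module is importable as `distributed` and the script is launched with `torchrun` so that `RANK`/`WORLD_SIZE`/`LOCAL_RANK` are set (without them, `init()` only warns and the gather degenerates to a single-element list):

    import distributed as dist_utils

    dist_utils.init()
    local_result = {"rank": dist_utils.rank(), "num_samples": 100}
    merged = dist_utils.all_gather(local_result)   # one entry per rank
    if dist_utils.is_main():
        print(f"collected results from {len(merged)} ranks")
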
environment_setup.sh ADDED
@@ -0,0 +1,60 @@
1
+ #!/usr/bin/env bash
2
+ set -e
3
+
4
+ CONDA_ENV=${1:-""}
5
+ if [ -n "$CONDA_ENV" ]; then
6
+ # This is required to activate the conda environment
7
+ eval "$(conda shell.bash hook)"
8
+
9
+ conda create -n $CONDA_ENV python=3.10.14 -y
10
+ conda activate $CONDA_ENV
11
+ # This is optional if you prefer to use the built-in nvcc
12
+ conda install -c nvidia cuda-toolkit=12.2 -y
13
+ else
14
+ echo "Skipping conda environment creation. Make sure you have the correct environment activated."
15
+ fi
16
+
17
+ # Use uv to speed up installations
18
+ pip install uv
19
+ alias uvp="uv pip"
20
+
21
+ echo "[INFO] Using python $(which python)"
22
+ echo "[INFO] Using pip $(which pip)"
23
+ echo "[INFO] Using uv $(which uv)"
24
+
25
+ # This is required to enable PEP 660 support
26
+ uv pip install --upgrade pip setuptools
27
+
28
+ # Install FlashAttention2
29
+ uv pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.8/flash_attn-2.5.8+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
30
+
31
+ # Install VILA
32
+ uv pip install -e ".[train,eval]"
33
+
34
+ # numpy introduces a lot of dependency issues, so it is installed separately from pyproject.toml
35
+ pip install numpy==1.26.4
36
+
37
+ # audio
38
+ uv pip install soundfile librosa openai-whisper ftfy
39
+ conda install -c conda-forge ffmpeg
40
+ uv pip install jiwer
41
+
42
+ # Downgrade protobuf to 3.20 for backward compatibility
43
+ uv pip install protobuf==3.20.*
44
+
45
+ # Replace transformers and deepspeed files
46
+ site_pkg_path=$(python -c 'import site; print(site.getsitepackages()[0])')
47
+ cp -rv ./transformers/modeling_utils.py $site_pkg_path/transformers/modeling_utils.py # for using qwen 2.5 omni checkpoint
48
+
49
+ # for benchmark adoption
50
+ uv pip install faiss-gpu-cu12
51
+
52
+ # Quantization requires the newest Triton version and introduces dependency issues
53
+ uv pip install triton==3.1.0 # only needed when using FP8LinearQwen2Config, QLlavaLlamaConfig, etc.; note it is not compatible with mamba-ssm
54
+
55
+ uv pip install kaldiio
56
+
57
+ # for rotary embedding
58
+ uv pip install beartype
59
+
60
+ uv pip install pydantic==1.10.22
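
A small post-install check, assuming the environment created by this script is active; the module list mirrors the packages installed above (openai-whisper imports as `whisper`):

    import importlib

    for name in ["torch", "flash_attn", "soundfile", "librosa", "whisper", "jiwer", "kaldiio"]:
        module = importlib.import_module(name)
        print(name, getattr(module, "__version__", "unknown"))
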
example_infer.py ADDED
@@ -0,0 +1,335 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ from transformers import AutoProcessor, AutoModel, AutoConfig, GenerationConfig
18
+ import torch
19
+ import os
20
+ import time
21
+ from pathlib import Path
22
+ from typing import List, Dict, Any, Optional, Union
23
+ import logging
24
+ import sys
25
+ os.environ["HF_HUB_OFFLINE"] = "1" # Use local cache for models
26
+
27
+ # Set up logging
28
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
29
+ logger = logging.getLogger(__name__)
30
+
31
+ def add_to_sys_path_direct(model_path):
32
+ """Add model path directly to sys.path"""
33
+ if model_path not in sys.path:
34
+ sys.path.insert(0, model_path) # Insert at beginning for priority
35
+ print(f"✓ Added to sys.path: {model_path}")
36
+ else:
37
+ print(f"Already in sys.path: {model_path}")
38
+
39
+ class NVOmniVideoInference:
40
+ """A class to handle NVOmni video model inference with improved error handling and flexibility."""
41
+
42
+ def __init__(self, model_path: str, torch_dtype=torch.float16, device_map="auto"):
43
+ """
44
+ Initialize the NVOmni model for video inference.
45
+
46
+ Args:
47
+ model_path (str): Path to the model directory
48
+ torch_dtype: PyTorch data type for model weights
49
+ device_map (str): Device mapping strategy for model loading
50
+ """
51
+ self.model_path = model_path
52
+ self.torch_dtype = torch_dtype
53
+ self.device_map = device_map
54
+ self.model = None
55
+ self.processor = None
56
+ self.config = None
57
+ self.device = None
58
+
59
+ self.load_model()
60
+
61
+ def validate_paths(self, model_path: str, video_path: str = None) -> bool:
62
+ """Validate that required paths exist."""
63
+ if not Path(model_path).exists():
64
+ logger.error(f"Model path does not exist: {model_path}")
65
+ return False
66
+
67
+ if video_path and not Path(video_path).exists():
68
+ logger.error(f"Video path does not exist: {video_path}")
69
+ return False
70
+
71
+ return True
72
+
73
+ def load_model(self) -> bool:
74
+ """Load the model, processor, and config with error handling."""
75
+ if not self.validate_paths(self.model_path):
76
+ return False
77
+
78
+ if True:
79
+ logger.info("Loading model configuration...")
80
+ self.config = AutoConfig.from_pretrained(self.model_path, trust_remote_code=True)
81
+
82
+ logger.info("Loading model...")
83
+ start_time = time.time()
84
+ self.model = AutoModel.from_pretrained(
85
+ self.model_path,
86
+ trust_remote_code=True,
87
+ torch_dtype=self.torch_dtype,
88
+ device_map=self.device_map,
89
+ low_cpu_mem_usage=True # More memory efficient loading
90
+ )
91
+ load_time = time.time() - start_time
92
+ logger.info(f"Model loaded in {load_time:.2f} seconds")
93
+
94
+ logger.info("Loading processor...")
95
+ self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)
96
+
97
+ # Set device for single-device setups
98
+ if hasattr(self.model, 'device'):
99
+ self.device = self.model.device
100
+ else:
101
+ self.device = next(self.model.parameters(), torch.empty(0)).device  # falls back to CPU when the model has no parameters
102
+
103
+ logger.info(f"Model successfully loaded on device: {self.device}")
104
+ self._print_model_info()
105
+ return True
106
+
107
+ def _print_model_info(self):
108
+ """Print useful information about the loaded model."""
109
+ logger.info("=" * 50)
110
+ logger.info("MODEL INFORMATION")
111
+ logger.info("=" * 50)
112
+
113
+ if self.config:
114
+ logger.info(f"Model type: {getattr(self.config, 'model_type', 'Unknown')}")
115
+ logger.info(f"Hidden size: {getattr(self.config, 'hidden_size', 'Unknown')}")
116
+
117
+ if self.model and torch.cuda.is_available():
118
+ logger.info(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
119
+ logger.info(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
120
+
121
+ def create_conversation(self, video_path: str, text_prompt: str) -> List[Dict[str, Any]]:
122
+ """
123
+ Create a conversation format for the model.
124
+
125
+ Args:
126
+ video_path (str): Path to the video file
127
+ text_prompt (str): Text prompt for the model
128
+
129
+ Returns:
130
+ List[Dict]: Conversation in the expected format
131
+ """
132
+ return [{
133
+ "role": "user",
134
+ "content": [
135
+ {"type": "video", "video": video_path},
136
+ {"type": "text", "text": text_prompt}
137
+ ]
138
+ }]
139
+
140
+ @torch.inference_mode()
141
+ def generate_response(
142
+ self,
143
+ video_path: str,
144
+ text_prompt: str,
145
+ max_new_tokens: int = 256,
146
+ temperature: float = None,
147
+ top_p: float = None,
148
+ do_sample: bool = None,
149
+ num_video_frames: int = -1,
150
+ load_audio_in_video: bool = True,
151
+ audio_length: Union[int, str] = "max_3600",
152
+ ) -> Optional[str]:
153
+ """
154
+ Generate a response from the model given a video and text prompt.
155
+
156
+ Args:
157
+ video_path (str): Path to the video file
158
+ text_prompt (str): Text prompt for the model
159
+ max_new_tokens (int): Maximum number of new tokens to generate
160
+ temperature (float): Sampling temperature
161
+ top_p (float): Top-p sampling parameter
162
+ do_sample (bool): Whether to use sampling
163
+ num_video_frames (int): Number of video frames to sample (-1 keeps the model default)
+ load_audio_in_video (bool): Whether to also load the audio track from the video
+ audio_length (Union[int, str]): Maximum audio length setting, e.g. "max_3600"
164
+
165
+ Returns:
166
+ Optional[str]: Generated response or None if failed
167
+ """
168
+ if not self.model or not self.processor:
169
+ logger.error("Model or processor not loaded. Please initialize the model first.")
170
+ return None
171
+
172
+ if not self.validate_paths(self.model_path, video_path):
173
+ return None
174
+
175
+ # try:
176
+ if True:
177
+
178
+ logger.info(f"Processing video: {video_path}")
179
+ logger.info(f"Text prompt: {text_prompt}")
180
+
181
+ # Create conversation
182
+ conversation = self.create_conversation(video_path, text_prompt)
183
+
184
+ # Apply chat template
185
+ text = self.processor.apply_chat_template(
186
+ conversation,
187
+ tokenize=False,
188
+ add_generation_prompt=True
189
+ )
190
+ logger.info(f"Chat template applied")
191
+
192
+ # set model params
193
+ self.model.config.load_audio_in_video = load_audio_in_video
194
+ self.processor.config.load_audio_in_video = load_audio_in_video
195
+ if num_video_frames > 0:
196
+ self.model.config.num_video_frames = num_video_frames
197
+ self.processor.config.num_video_frames = num_video_frames
198
+ if audio_length != -1:
199
+ self.model.config.audio_chunk_length = audio_length
200
+ self.processor.config.audio_chunk_length = audio_length
201
+ logger.info(f"Model config - load_audio_in_video: {self.model.config.load_audio_in_video}, num_video_frames: {self.model.config.num_video_frames}, audio_chunk_length: {self.model.config.audio_chunk_length}")
202
+
203
+ # Process inputs
204
+ start_time = time.time()
205
+ inputs = self.processor([text])
206
+
207
+ # Move inputs to the correct device if needed
208
+ if hasattr(inputs, 'input_ids') and inputs.input_ids is not None:
209
+ inputs.input_ids = inputs.input_ids.to(self.device)
210
+
211
+ processing_time = time.time() - start_time
212
+ logger.info(f"Input processing completed in {processing_time:.2f} seconds")
213
+
214
+ logger.info("Generating response...")
215
+ start_time = time.time()
216
+
217
+ generation_kwargs = {"max_new_tokens": max_new_tokens, "max_length": 99999999}
218
+ if top_p is not None:
219
+ generation_kwargs["top_p"] = top_p
220
+ if do_sample is not None:
221
+ generation_kwargs["do_sample"] = do_sample
222
+ if temperature is not None:
223
+ generation_kwargs["temperature"] = temperature
224
+
225
+ generation_config = self.model.default_generation_config
226
+ generation_config.update(**generation_kwargs)
227
+
228
+ logger.info(f"Generation config: {generation_config.to_dict()}")
229
+
230
+
231
+ with torch.no_grad():
232
+ output_ids = self.model.generate(
233
+ input_ids=inputs.input_ids,
234
+ media=getattr(inputs, 'media', None),
235
+ media_config=getattr(inputs, 'media_config', None),
236
+ generation_config=generation_config,
237
+ )
238
+
239
+ generation_time = time.time() - start_time
240
+ logger.info(f"Generation completed in {generation_time:.2f} seconds")
241
+
242
+ # Decode response
243
+ response = self.processor.tokenizer.batch_decode(
244
+ output_ids,
245
+ skip_special_tokens=True
246
+ )[0]
247
+
248
+ return response
249
+
250
+ def batch_generate(
251
+ self,
252
+ video_text_pairs: List[tuple],
253
+ **generation_kwargs
254
+ ) -> List[Optional[str]]:
255
+ """
256
+ Generate responses for multiple video-text pairs.
257
+
258
+ Args:
259
+ video_text_pairs (List[tuple]): List of (video_path, text_prompt) tuples
260
+ **generation_kwargs: Arguments passed to generate_response
261
+
262
+ Returns:
263
+ List[Optional[str]]: List of generated responses
264
+ """
265
+ responses = []
266
+ for i, (video_path, text_prompt) in enumerate(video_text_pairs):
267
+ logger.info(f"Processing batch item {i+1}/{len(video_text_pairs)}")
268
+ response = self.generate_response(video_path, text_prompt, **generation_kwargs)
269
+ responses.append(response)
270
+
271
+ # Clear cache between generations to manage memory
272
+ if torch.cuda.is_available():
273
+ torch.cuda.empty_cache()
274
+
275
+ return responses
276
+
277
+ def main():
278
+ """Main function demonstrating usage of the NVOmni model."""
279
+
280
+ # Configuration
281
+ MODEL_PATH = "./"
282
+ VIDEO_PATH = "xxx.mp4"
283
+ TEXT_PROMPT = "Assess the video, followed by a detailed description of it's video and audio contents."
284
+
285
+ num_video_frames=128
286
+ audio_length="max_3600"
287
+ load_audio_in_video=True
288
+
289
+ add_to_sys_path_direct(MODEL_PATH)
290
+
291
+ # Initialize the inference class
292
+ logger.info("Initializing NVOmni Video Inference...")
293
+ inferencer = NVOmniVideoInference(MODEL_PATH, torch_dtype=torch.float16)
294
+
295
+ if inferencer.model is None:
296
+ logger.error("Failed to initialize model. Exiting.")
297
+ return
298
+
299
+ # Generate response
300
+ logger.info("Starting inference...")
301
+ response = inferencer.generate_response(
302
+ video_path=VIDEO_PATH,
303
+ text_prompt=TEXT_PROMPT,
304
+ num_video_frames=num_video_frames,
305
+ load_audio_in_video=load_audio_in_video,
306
+ audio_length=audio_length,
307
+ max_new_tokens=1024,
308
+ )
309
+
310
+ if response:
311
+ print("\n" + "="*60)
312
+ print("GENERATED RESPONSE")
313
+ print("="*60)
314
+ print(response)
315
+ print("="*60)
316
+ else:
317
+ logger.error("Failed to generate response")
318
+
319
+ # Example of batch processing
320
+ if False:
321
+ logger.info("\nExample: Batch processing")
322
+ batch_pairs = [
323
+ (VIDEO_PATH, "What is happening in this video?"),
324
+ (VIDEO_PATH, "Describe the audio content of this video."),
325
+ ]
326
+
327
+ batch_responses = inferencer.batch_generate(batch_pairs, max_new_tokens=128)
328
+
329
+ for i, (pair, response) in enumerate(zip(batch_pairs, batch_responses)):
330
+ print(f"\n--- Batch Response {i+1} ---")
331
+ print(f"Prompt: {pair[1]}")
332
+ print(f"Response: {response}")
333
+
334
+ if __name__ == "__main__":
335
+ main()
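
Beyond the `main()` demo above, the wrapper class can be driven programmatically; a minimal sketch, assuming the checkout root is the working directory and `my_clip.mp4` is replaced with a real local video:

    import torch
    from example_infer import NVOmniVideoInference

    runner = NVOmniVideoInference("./", torch_dtype=torch.float16)
    answer = runner.generate_response(
        video_path="my_clip.mp4",
        text_prompt="Summarize the spoken content.",
        max_new_tokens=256,
    )
    print(answer)
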
example_mini_audio.py ADDED
@@ -0,0 +1,89 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """
17
+ Example script for audio transcription using the model.
18
+
19
+ This script demonstrates how to:
20
+ 1. Load the model and processor
21
+ 2. Configure audio processing parameters
22
+ 3. Process audio input
23
+ 4. Generate transcription output
24
+
25
+ Usage:
26
+ python example_mini_audio.py --model_path <path_to_model> --audio_path <path_to_audio>
27
+ """
28
+
29
+ from transformers import AutoProcessor, AutoModel, AutoConfig, AutoModelForCausalLM
30
+ import torch
31
+ import os
32
+ import argparse
33
+
34
+ # Configuration
35
+ parser = argparse.ArgumentParser(description="Audio transcription example")
36
+ parser.add_argument("--model_path", type=str, default="./", help="Path to the model")
37
+ parser.add_argument("--audio_path", type=str, required=True, help="Path to the audio file")
38
+ parser.add_argument("--max_new_tokens", type=int, default=1024, help="Maximum number of tokens to generate")
39
+ parser.add_argument("--num_video_frames", type=int, default=128, help="Number of video frames to process")
40
+ parser.add_argument("--audio_length", type=str, default="max_3600", help="Maximum audio length")
41
+
42
+ args = parser.parse_args()
43
+
44
+ model_path = args.model_path
45
+ audio_path = args.audio_path
46
+ generation_kwargs = {"max_new_tokens": args.max_new_tokens, "max_length": 99999999}
47
+ load_audio_in_video = True
48
+ num_video_frames = args.num_video_frames
49
+ audio_length = args.audio_length
50
+
51
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
52
+
53
+ model = AutoModel.from_pretrained(model_path,
54
+ trust_remote_code=True,
55
+ torch_dtype="torch.float16",
56
+ device_map="auto")
57
+
58
+ processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
59
+ generation_config = model.default_generation_config
60
+ generation_config.update(**generation_kwargs)
61
+
62
+ model.config.load_audio_in_video = load_audio_in_video
63
+ processor.config.load_audio_in_video = load_audio_in_video
64
+ if num_video_frames > 0:
65
+ model.config.num_video_frames = num_video_frames
66
+ processor.config.num_video_frames = num_video_frames
67
+ if audio_length != -1:
68
+ model.config.audio_chunk_length = audio_length
69
+ processor.config.audio_chunk_length = audio_length
70
+
71
+
72
+ conversation = [{
73
+ "role": "user",
74
+ "content": [
75
+ {"type": "audio", "audio": audio_path},
76
+ {"type": "text", "text": "Transcribe the whole speech."}
77
+ ]
78
+ }]
79
+ text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
80
+
81
+ inputs = processor([text])
82
+
83
+ output_ids = model.generate(
84
+ input_ids=inputs.input_ids,
85
+ media=getattr(inputs, 'media', None),
86
+ media_config=getattr(inputs, 'media_config', None),
87
+ generation_config=generation_config,
88
+ )
89
+ print(processor.tokenizer.batch_decode(output_ids, skip_special_tokens=True))
example_mini_image.py ADDED
@@ -0,0 +1,76 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """
17
+ Example script for image understanding using the model.
18
+
19
+ This script demonstrates how to:
20
+ 1. Load the model and processor
21
+ 2. Process image input
22
+ 3. Generate description output
23
+
24
+ Usage:
25
+ python example_mini_image.py --model_path <path_to_model> --image_path <path_to_image>
26
+ """
27
+
28
+ from transformers import AutoProcessor, AutoModel, AutoConfig, AutoModelForCausalLM
29
+ import torch
30
+ import os
31
+ import argparse
32
+
33
+ # Configuration
34
+ parser = argparse.ArgumentParser(description="Image understanding example")
35
+ parser.add_argument("--model_path", type=str, default="./", help="Path to the model")
36
+ parser.add_argument("--image_path", type=str, required=True, help="Path to the image file")
37
+ parser.add_argument("--max_new_tokens", type=int, default=1024, help="Maximum number of tokens to generate")
38
+ parser.add_argument("--prompt", type=str, default="Describe the image in detail.", help="Text prompt for the model")
39
+
40
+ args = parser.parse_args()
41
+
42
+ model_path = args.model_path
43
+ image_path = args.image_path
44
+ generation_kwargs = {"max_new_tokens": args.max_new_tokens, "max_length": 99999999}
45
+
46
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
47
+
48
+ model = AutoModel.from_pretrained(
49
+ model_path,
50
+ trust_remote_code=True,
51
+ torch_dtype=torch.float16,
52
+ device_map="auto"
53
+ )
54
+
55
+ processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
56
+ generation_config = model.default_generation_config
57
+ generation_config.update(**generation_kwargs)
58
+
59
+ conversation = [{
60
+ "role": "user",
61
+ "content": [
62
+ {"type": "image", "image": image_path},
63
+ {"type": "text", "text": args.prompt}
64
+ ]
65
+ }]
66
+ text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
67
+
68
+ inputs = processor([text])
69
+
70
+ output_ids = model.generate(
71
+ input_ids=inputs.input_ids,
72
+ media=getattr(inputs, 'media', None),
73
+ media_config=getattr(inputs, 'media_config', None),
74
+ generation_config=generation_config,
75
+ )
76
+ print(processor.tokenizer.batch_decode(output_ids, skip_special_tokens=True))
example_mini_video.py ADDED
@@ -0,0 +1,101 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """
17
+ Example script for video understanding using the model.
18
+
19
+ This script demonstrates how to:
20
+ 1. Load the model and processor
21
+ 2. Configure video and audio processing parameters
22
+ 3. Process video input with optional audio
23
+ 4. Generate description output
24
+
25
+ Usage:
26
+ python example_mini_video.py --model_path <path_to_model> --video_path <path_to_video>
27
+ """
28
+
29
+ from transformers import AutoProcessor, AutoModel, AutoConfig, AutoModelForCausalLM
30
+ import torch
31
+ import os
32
+ import argparse
33
+
34
+ # Configuration
35
+ parser = argparse.ArgumentParser(description="Video understanding example")
36
+ parser.add_argument("--model_path", type=str, default="./", help="Path to the model")
37
+ parser.add_argument("--video_path", type=str, required=True, help="Path to the video file")
38
+ parser.add_argument("--max_new_tokens", type=int, default=1024, help="Maximum number of tokens to generate")
39
+ parser.add_argument("--num_video_frames", type=int, default=128, help="Number of video frames to process")
40
+ parser.add_argument("--audio_length", type=str, default="max_3600", help="Maximum audio length")
41
+ parser.add_argument("--prompt", type=str, default="What are they talking about in detail?", help="Text prompt for the model")
42
+ parser.add_argument("--load_audio", action="store_true", default=True, help="Load audio from video")
43
+
44
+ args = parser.parse_args()
45
+
46
+ model_path = args.model_path
47
+ video_path = args.video_path
48
+ generation_kwargs = {"max_new_tokens": args.max_new_tokens, "max_length": 99999999}
49
+ load_audio_in_video = args.load_audio
50
+ num_video_frames = args.num_video_frames
51
+ audio_length = args.audio_length
52
+ text_prompt = args.prompt
53
+
54
+ assert os.path.exists(video_path), f"Video path {video_path} does not exist."
55
+
56
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
57
+
58
+ model = AutoModel.from_pretrained(
59
+ model_path,
60
+ trust_remote_code=True,
61
+ torch_dtype=torch.float16,
62
+ device_map="auto"
63
+ )
64
+
65
+ processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
66
+ generation_config = model.default_generation_config
67
+ generation_config.update(**generation_kwargs)
68
+
69
+ model.config.load_audio_in_video = load_audio_in_video
70
+ processor.config.load_audio_in_video = load_audio_in_video
71
+ if num_video_frames > 0:
72
+ model.config.num_video_frames = num_video_frames
73
+ processor.config.num_video_frames = num_video_frames
74
+ if audio_length != -1:
75
+ model.config.audio_chunk_length = audio_length
76
+ processor.config.audio_chunk_length = audio_length
77
+
78
+ def forward_inference(video_path, text_prompt):
79
+ """Run inference on video with text prompt."""
80
+ print(f"Text prompt: {text_prompt}")
81
+ print(f"Video path: {video_path}")
82
+ conversation = [{
83
+ "role": "user",
84
+ "content": [
85
+ {"type": "video", "video": video_path},
86
+ {"type": "text", "text": text_prompt}
87
+ ]
88
+ }]
89
+ text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
90
+
91
+ inputs = processor([text])
92
+
93
+ output_ids = model.generate(
94
+ input_ids=inputs.input_ids,
95
+ media=getattr(inputs, 'media', None),
96
+ media_config=getattr(inputs, 'media_config', None),
97
+ generation_config=generation_config,
98
+ )
99
+ print(processor.tokenizer.batch_decode(output_ids, skip_special_tokens=True))
100
+
101
+ forward_inference(video_path, text_prompt)
llm/added_tokens.json ADDED
@@ -0,0 +1,20 @@
1
+ {
2
+ "<image>": 151649,
3
+ "<sound>": 151652,
4
+ "<speech>": 151651,
5
+ "<vila/sentinel>": 151648,
6
+ "<vila/video>": 151650,
7
+ "<|endoftext|>": 151643,
8
+ "<|im_end|>": 151645,
9
+ "<|im_start|>": 151644,
10
+ "<|image_bos|>": 151653,
11
+ "<|image_eos|>": 151654,
12
+ "<|sound_bos|>": 151659,
13
+ "<|sound_eos|>": 151660,
14
+ "<|speech_bos|>": 151657,
15
+ "<|speech_eos|>": 151658,
16
+ "<|video_bos|>": 151655,
17
+ "<|video_eos|>": 151656,
18
+ "[BOS]": 151646,
19
+ "[PAD]": 151647
20
+ }
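
The IDs in this file should line up with the xvila token table documented in constants.py; a quick sanity-check sketch, assuming the tokenizer in `llm/` loads with `AutoTokenizer`:

    import json
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("./llm")
    with open("./llm/added_tokens.json") as f:
        added = json.load(f)
    for token, expected_id in added.items():
        assert tok.convert_tokens_to_ids(token) == expected_id, token
    print(f"all {len(added)} added-token ids match")
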
llm/config.json ADDED
The diff for this file is too large to render. See raw diff
 
llm/generation_config.json ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ "bos_token_id": 151643,
3
+ "do_sample": true,
4
+ "eos_token_id": [
5
+ 151645,
6
+ 151643
7
+ ],
8
+ "pad_token_id": 151643,
9
+ "repetition_penalty": 1.05,
10
+ "temperature": 0.7,
11
+ "top_k": 20,
12
+ "top_p": 0.8,
13
+ "transformers_version": "4.46.0"
14
+ }
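
These sampling defaults are picked up when the `llm/` subfolder is loaded; a minimal sketch of reading and overriding them, assuming the folder is used as-is:

    from transformers import GenerationConfig

    gen_cfg = GenerationConfig.from_pretrained("./llm")
    gen_cfg.update(max_new_tokens=256)   # sets matching attributes in place, returns any unused kwargs
    print(gen_cfg.temperature, gen_cfg.top_p, gen_cfg.max_new_tokens)
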
llm/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
llm/model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:94f591013a08df71d152c96e5bc415bedc434bf30514fb03bd39c8d49e7161cd
3
+ size 4874772072
llm/model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f8d257b5d45c7d51d3abd0b25a524ae8234cca1d6536976dfa833aa9bb06ffe
3
+ size 4932751008
llm/model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:872102a5b7897af896ea518db185fda955786e8f45d70a49c48bbb8a4a45d305
3
+ size 4330865200
llm/model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b82d845caeacc0f77859bece772e05aaf828b51b25946d4ab58801845130b30
3
+ size 1087106176
llm/model.safetensors.index.json ADDED
@@ -0,0 +1,346 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 15225455616
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "model-00004-of-00004.safetensors",
7
+ "model.embed_tokens.weight": "model-00001-of-00004.safetensors",
8
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
9
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
10
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
11
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
12
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
13
+ "model.layers.0.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
14
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
15
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
16
+ "model.layers.0.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
17
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
18
+ "model.layers.0.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
19
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
20
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
21
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
22
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
23
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
24
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
25
+ "model.layers.1.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
26
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
27
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
28
+ "model.layers.1.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
29
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
30
+ "model.layers.1.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
31
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
32
+ "model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
33
+ "model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
34
+ "model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
35
+ "model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
36
+ "model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
37
+ "model.layers.10.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
38
+ "model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
39
+ "model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
40
+ "model.layers.10.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
41
+ "model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
42
+ "model.layers.10.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
43
+ "model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
44
+ "model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
45
+ "model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
46
+ "model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
47
+ "model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
48
+ "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
49
+ "model.layers.11.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
50
+ "model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
51
+ "model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
52
+ "model.layers.11.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
53
+ "model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
54
+ "model.layers.11.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
55
+ "model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
56
+ "model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
57
+ "model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
58
+ "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
59
+ "model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
60
+ "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
61
+ "model.layers.12.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
62
+ "model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
63
+ "model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
64
+ "model.layers.12.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
65
+ "model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
66
+ "model.layers.12.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
67
+ "model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
68
+ "model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
69
+ "model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
70
+ "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
71
+ "model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
72
+ "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
73
+ "model.layers.13.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
74
+ "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
75
+ "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
76
+ "model.layers.13.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
77
+ "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
78
+ "model.layers.13.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
79
+ "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
80
+ "model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
81
+ "model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
82
+ "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
83
+ "model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
84
+ "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
85
+ "model.layers.14.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
86
+ "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
87
+ "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
88
+ "model.layers.14.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
89
+ "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
90
+ "model.layers.14.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
91
+ "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
92
+ "model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
93
+ "model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
94
+ "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
95
+ "model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
96
+ "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
97
+ "model.layers.15.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
98
+ "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
99
+ "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
100
+ "model.layers.15.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
101
+ "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
102
+ "model.layers.15.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
103
+ "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
104
+ "model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
105
+ "model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
106
+ "model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
107
+ "model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
108
+ "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
109
+ "model.layers.16.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
110
+ "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
111
+ "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
112
+ "model.layers.16.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
113
+ "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
114
+ "model.layers.16.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
115
+ "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
116
+ "model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
117
+ "model.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
118
+ "model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
119
+ "model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
120
+ "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
121
+ "model.layers.17.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
122
+ "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
123
+ "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
124
+ "model.layers.17.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
125
+ "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
126
+ "model.layers.17.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
127
+ "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
128
+ "model.layers.18.input_layernorm.weight": "model-00003-of-00004.safetensors",
129
+ "model.layers.18.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
130
+ "model.layers.18.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
131
+ "model.layers.18.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
132
+ "model.layers.18.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
133
+ "model.layers.18.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
134
+ "model.layers.18.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
135
+ "model.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
136
+ "model.layers.18.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
137
+ "model.layers.18.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
138
+ "model.layers.18.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
139
+ "model.layers.18.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
140
+ "model.layers.19.input_layernorm.weight": "model-00003-of-00004.safetensors",
141
+ "model.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
142
+ "model.layers.19.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
143
+ "model.layers.19.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
144
+ "model.layers.19.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
145
+ "model.layers.19.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
146
+ "model.layers.19.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
147
+ "model.layers.19.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
148
+ "model.layers.19.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
149
+ "model.layers.19.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
150
+ "model.layers.19.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
151
+ "model.layers.19.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
152
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
153
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
154
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
155
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
156
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
157
+ "model.layers.2.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
158
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
159
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
160
+ "model.layers.2.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
161
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
162
+ "model.layers.2.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
163
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
164
+ "model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
165
+ "model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
166
+ "model.layers.20.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
167
+ "model.layers.20.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
168
+ "model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
169
+ "model.layers.20.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
170
+ "model.layers.20.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
171
+ "model.layers.20.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
172
+ "model.layers.20.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
173
+ "model.layers.20.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
174
+ "model.layers.20.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
175
+ "model.layers.20.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
176
+ "model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
177
+ "model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
178
+ "model.layers.21.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
179
+ "model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
180
+ "model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
181
+ "model.layers.21.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
182
+ "model.layers.21.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
183
+ "model.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
184
+ "model.layers.21.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
185
+ "model.layers.21.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
186
+ "model.layers.21.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
187
+ "model.layers.21.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
188
+ "model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
189
+ "model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
190
+ "model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
191
+ "model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
192
+ "model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
193
+ "model.layers.22.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
194
+ "model.layers.22.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
195
+ "model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
196
+ "model.layers.22.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
197
+ "model.layers.22.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
198
+ "model.layers.22.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
199
+ "model.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
200
+ "model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
201
+ "model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
202
+ "model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
203
+ "model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
204
+ "model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
205
+ "model.layers.23.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
206
+ "model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
207
+ "model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
208
+ "model.layers.23.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
209
+ "model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
210
+ "model.layers.23.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
211
+ "model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
212
+ "model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
213
+ "model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
214
+ "model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
215
+ "model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
216
+ "model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
217
+ "model.layers.24.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
218
+ "model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
219
+ "model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
220
+ "model.layers.24.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
221
+ "model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
222
+ "model.layers.24.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
223
+ "model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
224
+ "model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
225
+ "model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
226
+ "model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
227
+ "model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
228
+ "model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
229
+ "model.layers.25.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
230
+ "model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
231
+ "model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
232
+ "model.layers.25.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
233
+ "model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
234
+ "model.layers.25.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
235
+ "model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
236
+ "model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
237
+ "model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
238
+ "model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
239
+ "model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
240
+ "model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
241
+ "model.layers.26.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
242
+ "model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
243
+ "model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
244
+ "model.layers.26.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
245
+ "model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
246
+ "model.layers.26.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
247
+ "model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
248
+ "model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
249
+ "model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
250
+ "model.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
251
+ "model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
252
+ "model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
253
+ "model.layers.27.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
254
+ "model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
255
+ "model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
256
+ "model.layers.27.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
257
+ "model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
258
+ "model.layers.27.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
259
+ "model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
260
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
261
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
262
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
263
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
264
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
265
+ "model.layers.3.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
266
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
267
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
268
+ "model.layers.3.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
269
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
270
+ "model.layers.3.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
271
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
272
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
273
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
274
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
275
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
276
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
277
+ "model.layers.4.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
278
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
279
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
280
+ "model.layers.4.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
281
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
282
+ "model.layers.4.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
283
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
284
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
285
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
286
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
287
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
288
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
289
+ "model.layers.5.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
290
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
291
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
292
+ "model.layers.5.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
293
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
294
+ "model.layers.5.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
295
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
296
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
297
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
298
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
299
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
300
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
301
+ "model.layers.6.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
302
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
303
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
304
+ "model.layers.6.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
305
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
306
+ "model.layers.6.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
307
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
308
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00004.safetensors",
309
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
310
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
311
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
312
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
313
+ "model.layers.7.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
314
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
315
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
316
+ "model.layers.7.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
317
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
318
+ "model.layers.7.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
319
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
320
+ "model.layers.8.input_layernorm.weight": "model-00002-of-00004.safetensors",
321
+ "model.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
322
+ "model.layers.8.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
323
+ "model.layers.8.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
324
+ "model.layers.8.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
325
+ "model.layers.8.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
326
+ "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
327
+ "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
328
+ "model.layers.8.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
329
+ "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
330
+ "model.layers.8.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
331
+ "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
332
+ "model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
333
+ "model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
334
+ "model.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
335
+ "model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
336
+ "model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
337
+ "model.layers.9.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
338
+ "model.layers.9.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
339
+ "model.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
340
+ "model.layers.9.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
341
+ "model.layers.9.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
342
+ "model.layers.9.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
343
+ "model.layers.9.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
344
+ "model.norm.weight": "model-00003-of-00004.safetensors"
345
+ }
346
+ }
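The weight map above records, for each parameter name, which safetensors shard holds it; `transformers` resolves this automatically at load time. As a hedged illustration only, the index can also be queried directly to read a single tensor. The `llm/` paths and the parameter name below are assumptions based on this commit's layout.

```python
import json

from safetensors import safe_open

# Illustrative paths; adjust to where the checkpoint shards actually live.
index_path = "llm/model.safetensors.index.json"
param_name = "model.layers.20.mlp.down_proj.weight"

with open(index_path) as f:
    index = json.load(f)

shard_file = index["weight_map"][param_name]         # e.g. "model-00003-of-00004.safetensors"
with safe_open(f"llm/{shard_file}", framework="pt", device="cpu") as shard:
    tensor = shard.get_tensor(param_name)             # reads only this tensor from the shard
print(param_name, tuple(tensor.shape))
```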
llm/special_tokens_map.json ADDED
@@ -0,0 +1,39 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ {
4
+ "content": "<|sound_bos|>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ },
10
+ {
11
+ "content": "<|sound_eos|>",
12
+ "lstrip": false,
13
+ "normalized": false,
14
+ "rstrip": false,
15
+ "single_word": false
16
+ }
17
+ ],
18
+ "bos_token": {
19
+ "content": "[BOS]",
20
+ "lstrip": false,
21
+ "normalized": false,
22
+ "rstrip": false,
23
+ "single_word": false
24
+ },
25
+ "eos_token": {
26
+ "content": "<|im_end|>",
27
+ "lstrip": false,
28
+ "normalized": false,
29
+ "rstrip": false,
30
+ "single_word": false
31
+ },
32
+ "pad_token": {
33
+ "content": "[PAD]",
34
+ "lstrip": false,
35
+ "normalized": false,
36
+ "rstrip": false,
37
+ "single_word": false
38
+ }
39
+ }
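The map above is what `AutoTokenizer` uses to register `[BOS]`, `<|im_end|>`, and `[PAD]` as the BOS/EOS/PAD tokens and to add `<|sound_bos|>`/`<|sound_eos|>` as additional special tokens. A minimal sketch of inspecting them, assuming the tokenizer files are loaded from the `llm/` subfolder of the `nvidia/omnivinci` repository:

```python
from transformers import AutoTokenizer

# Assumption: the tokenizer files sit under the llm/ subfolder, as in this commit.
tok = AutoTokenizer.from_pretrained("nvidia/omnivinci", subfolder="llm")

print(tok.bos_token, tok.eos_token, tok.pad_token)   # [BOS] <|im_end|> [PAD]
print(tok.additional_special_tokens)                 # ['<|sound_bos|>', '<|sound_eos|>']
print(tok.convert_tokens_to_ids("<|sound_bos|>"))    # 151659, per tokenizer_config.json below
```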
llm/tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:491635783283196cfd9ab5d019617234a246b35a58da4761afd6ad77380f43c8
3
+ size 11415920
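The three lines above are a Git LFS pointer rather than the tokenizer itself: the roughly 11.4 MB `tokenizer.json` is stored in LFS and identified by the SHA-256 of its contents. A small sketch for checking that a fetched copy matches the pointer (the local path is an assumption about where the repository was cloned):

```python
import hashlib

# Assumption: the repo is cloned locally and the LFS object has been pulled to llm/tokenizer.json.
with open("llm/tokenizer.json", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

expected = "491635783283196cfd9ab5d019617234a246b35a58da4761afd6ad77380f43c8"
print("tokenizer.json matches its LFS pointer:", digest == expected)
```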
llm/tokenizer_config.json ADDED
@@ -0,0 +1,165 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "151643": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "151644": {
13
+ "content": "<|im_start|>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "151645": {
21
+ "content": "<|im_end|>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "151646": {
29
+ "content": "[BOS]",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "151647": {
37
+ "content": "[PAD]",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "151648": {
45
+ "content": "<vila/sentinel>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": true
51
+ },
52
+ "151649": {
53
+ "content": "<image>",
54
+ "lstrip": false,
55
+ "normalized": false,
56
+ "rstrip": false,
57
+ "single_word": false,
58
+ "special": true
59
+ },
60
+ "151650": {
61
+ "content": "<vila/video>",
62
+ "lstrip": false,
63
+ "normalized": false,
64
+ "rstrip": false,
65
+ "single_word": false,
66
+ "special": true
67
+ },
68
+ "151651": {
69
+ "content": "<speech>",
70
+ "lstrip": false,
71
+ "normalized": false,
72
+ "rstrip": false,
73
+ "single_word": false,
74
+ "special": true
75
+ },
76
+ "151652": {
77
+ "content": "<sound>",
78
+ "lstrip": false,
79
+ "normalized": false,
80
+ "rstrip": false,
81
+ "single_word": false,
82
+ "special": true
83
+ },
84
+ "151653": {
85
+ "content": "<|image_bos|>",
86
+ "lstrip": false,
87
+ "normalized": false,
88
+ "rstrip": false,
89
+ "single_word": false,
90
+ "special": true
91
+ },
92
+ "151654": {
93
+ "content": "<|image_eos|>",
94
+ "lstrip": false,
95
+ "normalized": false,
96
+ "rstrip": false,
97
+ "single_word": false,
98
+ "special": true
99
+ },
100
+ "151655": {
101
+ "content": "<|video_bos|>",
102
+ "lstrip": false,
103
+ "normalized": false,
104
+ "rstrip": false,
105
+ "single_word": false,
106
+ "special": true
107
+ },
108
+ "151656": {
109
+ "content": "<|video_eos|>",
110
+ "lstrip": false,
111
+ "normalized": false,
112
+ "rstrip": false,
113
+ "single_word": false,
114
+ "special": true
115
+ },
116
+ "151657": {
117
+ "content": "<|speech_bos|>",
118
+ "lstrip": false,
119
+ "normalized": false,
120
+ "rstrip": false,
121
+ "single_word": false,
122
+ "special": true
123
+ },
124
+ "151658": {
125
+ "content": "<|speech_eos|>",
126
+ "lstrip": false,
127
+ "normalized": false,
128
+ "rstrip": false,
129
+ "single_word": false,
130
+ "special": true
131
+ },
132
+ "151659": {
133
+ "content": "<|sound_bos|>",
134
+ "lstrip": false,
135
+ "normalized": false,
136
+ "rstrip": false,
137
+ "single_word": false,
138
+ "special": true
139
+ },
140
+ "151660": {
141
+ "content": "<|sound_eos|>",
142
+ "lstrip": false,
143
+ "normalized": false,
144
+ "rstrip": false,
145
+ "single_word": false,
146
+ "special": true
147
+ }
148
+ },
149
+ "additional_special_tokens": [
150
+ "<|sound_bos|>",
151
+ "<|sound_eos|>"
152
+ ],
153
+ "bos_token": "[BOS]",
154
+ "chat_template": "{% if messages[0]['role'] != 'system' %}{{ '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}{% endif %}{% for message in messages if message['content'] is not none %}{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
155
+ "clean_up_tokenization_spaces": false,
156
+ "eos_token": "<|im_end|>",
157
+ "errors": "replace",
158
+ "legacy": false,
159
+ "model_max_length": 14000,
160
+ "pad_token": "[PAD]",
161
+ "padding_side": "right",
162
+ "split_special_tokens": false,
163
+ "tokenizer_class": "Qwen2Tokenizer",
164
+ "unk_token": null
165
+ }
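The `chat_template` entry above wraps every turn in `<|im_start|>role ... <|im_end|>` and prepends a default system prompt when the conversation does not start with one. A short sketch of rendering it via `apply_chat_template`, with the same `llm/` subfolder assumption as in the earlier snippet:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("nvidia/omnivinci", subfolder="llm")  # assumed layout

messages = [{"role": "user", "content": "Describe the sounds in this clip."}]
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
# <|im_start|>system
# You are a helpful assistant<|im_end|>
# <|im_start|>user
# Describe the sounds in this clip.<|im_end|>
# <|im_start|>assistant
```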
llm/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
media.py ADDED
@@ -0,0 +1,555 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import glob
17
+ import time
18
+ import random
19
+ import os
20
+ import tempfile
21
+ from collections import defaultdict
22
+ from io import BytesIO
23
+ from typing import Any, Dict, List, Optional, Union
24
+ import io
25
+ import cv2
26
+ import kaldiio
27
+ import librosa
28
+ import soundfile as sf
29
+ import torch
30
+ import numpy as np
31
+ import PIL
32
+ import PIL.Image
33
+ import requests
34
+ import tarfile
35
+ import whisper
36
+ import decord
37
+ from decord import AudioReader, cpu
38
+
39
+ from transformers import PretrainedConfig
40
+
41
+ MEDIA_TOKENS = {
42
+ "image": "<image>",
43
+ "video": "<vila/video>",
44
+ "speech": "<speech>",
45
+ "sound": "<sound>",
46
+ }
47
+
48
+
49
+ class Media:
50
+ """Base class for media objects."""
51
+ pass
52
+
53
+
54
+ class File(Media):
55
+ """File-based media object."""
56
+ def __init__(self, path: str) -> None:
57
+ self.path = path
58
+
59
+
60
+ class Image(File):
61
+ """Image media object."""
62
+ pass
63
+
64
+
65
+ class Video(File):
66
+ """Video media object."""
67
+ pass
68
+
69
+
70
+ class Speech(File):
71
+ """Speech audio media object."""
72
+ def __init__(self, path, extension: str = None) -> None:
73
+ self.path = path
74
+ self.extension = extension
75
+
76
+
77
+ class Sound(File):
78
+ """Sound/music audio media object."""
79
+ def __init__(self, path, extension: str = None) -> None:
80
+ self.path = path
81
+ self.extension = extension
82
+
83
+
84
+ def make_list(obj: Any) -> List:
85
+ """Convert object to list if not already a list."""
86
+ return obj if isinstance(obj, list) else [obj]
87
+
88
+
89
+ def _extract_image(image: Union[Image, PIL.Image.Image]) -> PIL.Image.Image:
90
+ """Extract PIL Image from Image object or return PIL Image as-is."""
91
+ if isinstance(image, Image):
92
+ if image.path.startswith("http://") or image.path.startswith("https://"):
93
+ image = PIL.Image.open(requests.get(image.path, stream=True).raw)
94
+ else:
95
+ image = PIL.Image.open(image.path)
96
+ return image
97
+
98
+
99
+ def _load_video_bytesio(
100
+ video_bytesio: BytesIO, *, num_frames: int, config: PretrainedConfig, load_aud: bool = False
101
+ ) -> List[PIL.Image.Image]:
102
+ """Load video from BytesIO object by writing to temporary file."""
103
+ with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
104
+ temp_video.write(video_bytesio.read())
105
+ temp_video_name = temp_video.name
106
+ return _load_video(temp_video_name, num_frames=num_frames, load_aud=load_aud, config=config)
107
+
108
+ def get_overlap(inp1, inp2):
109
+ """
110
+ Calculates the overlapping time frame between a video clip and an audio segment.
111
+
112
+ Args:
113
+ inp1 (list): [start_sec, end_sec]
114
+ inp2 (list): [start_sec, end_sec]
115
+
116
+ Returns:
117
+ tuple or None: (overlap_start, overlap_end) if overlap exists, else None.
118
+ """
119
+ # Calculate the maximum start time and minimum end time
120
+ overlap_start = max(inp1[0], inp2[0])
121
+ overlap_end = min(inp1[1], inp2[1])
122
+
123
+ # Check if there is an actual overlap
124
+ if overlap_start < overlap_end:
125
+ return (overlap_start, overlap_end)
126
+ else:
127
+ return None
128
+
129
+
130
+ def _load_video(
131
+ video_path: str, *, num_frames: int, config: PretrainedConfig, load_aud: bool = False
132
+ ) -> List[PIL.Image.Image]:
133
+ # Load video frames from a directory
134
+ if os.path.isdir(video_path):
135
+ frame_paths = sorted(glob.glob(os.path.join(video_path, "*")))
136
+ indices = np.round(np.linspace(0, len(frame_paths) - 1, num_frames)).astype(int)
137
+ return [PIL.Image.open(frame_paths[index]) for index in indices]
138
+
139
+ # Load video frames from a video file
140
+ vidcap = cv2.VideoCapture(video_path)
141
+
142
+ # Load audio if available and needed
143
+ audio_info = None
144
+ if load_aud:
145
+ try:
146
+ aud_feature, audio_info = _load_speech(video_path, config)
147
+ except Exception as e:
148
+ aud_feature = None
149
+ else:
150
+ aud_feature = None
151
+
152
+ # Find the last frame as frame count might not be accurate
153
+ frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
154
+ while frame_count > 0:
155
+ vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
156
+ if vidcap.grab():
157
+ break
158
+ frame_count -= 1
159
+ else:
160
+ raise ValueError(f"Video '{video_path}' has no frames.")
161
+
162
+ # Extract frames uniformly
163
+ indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int)
164
+
165
+ fps = vidcap.get(cv2.CAP_PROP_FPS)
166
+ video_duration = frame_count / fps
167
+
168
+ # When load_audio_in_video and interleaved_vis_aud_in_video is True, we need to load frames for each video segment
169
+ if config.load_audio_in_video and config.interleaved_vis_aud_in_video and aud_feature is not None:
170
+ segment_duration = config.interleaved_video_segment_duration
171
+ if segment_duration == -1:
172
+ raise ValueError("video_segment_duration is not set")
173
+
174
+ segment_vis_indices_list = []
175
+ segment_aud_indices_list = []
176
+ segment_counts = np.ceil(video_duration / segment_duration).astype(int)
177
+
178
+ if type(aud_feature) == dict:
179
+ aud_feas = aud_feature["input_features"]
180
+ else:
181
+ aud_feas = aud_feature
182
+ audio_start_sec = audio_info['audio_start_sec']
183
+ audio_end_sec = audio_info['audio_end_sample_sec']
184
+
185
+ stft_frames_per_second = config.audio_sampling_rate // config.audio_hop_length
186
+
187
+ _idx = 0
188
+ aud_sample_start_idx = 0
189
+ for i in range(segment_counts):
190
+ end_frame = min((i+1) * segment_duration * fps, frame_count)
191
+
192
+ _indices = []
193
+ while _idx < len(indices) and indices[_idx] < end_frame and _idx < len(indices):
194
+ _indices.append(indices[_idx])
195
+ _idx += 1
196
+ segment_vis_indices_list.append(_indices)
197
+ clip_start_sec = i * segment_duration
198
+ clip_end_sec = min(clip_start_sec + segment_duration, video_duration)
199
+
200
+ # get the audio indices for the current clip
201
+ overlap = get_overlap([clip_start_sec, clip_end_sec], [audio_start_sec, audio_end_sec])
202
+ if overlap is not None:
203
+ aud_sample_end_idx = round((overlap[1] - audio_start_sec) * stft_frames_per_second)
204
+ segment_aud_indices_list.append([aud_sample_start_idx, aud_sample_end_idx])
205
+ aud_sample_start_idx = aud_sample_end_idx
206
+ else:
207
+ segment_aud_indices_list.append([])
208
+ frames = {}
209
+ frame_times = {}
210
+ for index in indices:
211
+ if index in frames:
212
+ continue
213
+ vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
214
+ success, frame = vidcap.read()
215
+ if not success:
216
+ print(f"Failed to read frame {index} from video '{video_path}'. Skipped.")
217
+ continue
218
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
219
+ frames[index] = PIL.Image.fromarray(frame)
220
+ frame_times[index] = index / fps
221
+
222
+ output_frames = [frames[index] for index in indices if index in frames]
223
+ output_frame_times = [frame_times[index] for index in indices if index in frame_times]
224
+
225
+ video_info = {}
226
+ if config.load_audio_in_video and config.interleaved_vis_aud_in_video and aud_feature is not None:
227
+ new_segment_vis_indices_list = []
228
+ processed_frame_index = 0
229
+ for i, segment_indices in enumerate(segment_vis_indices_list):
230
+ new_segment_vis_indices_list.append([])
231
+ for index in segment_indices:
232
+ if index in frames:
233
+ new_segment_vis_indices_list[-1].append(processed_frame_index)
234
+ processed_frame_index += 1
235
+ segment_vis_indices_list = new_segment_vis_indices_list
236
+
237
+ video_info["segment_vis_indices_list"] = segment_vis_indices_list
238
+ video_info["segment_aud_indices_list"] = segment_aud_indices_list
239
+ video_info['expected_frame_count'] = len(indices)
240
+ video_info['video_path'] = video_path
241
+ if audio_info is not None:
242
+ audio_info['video_path'] = video_path
243
+ video_info['has_audio'] = aud_feature is not None
244
+ video_info['video_duration'] = video_duration
245
+ video_info['audio_info'] = audio_info
246
+
247
+ # calculate the time of each frame
248
+ video_info['video_frame_times'] = output_frame_times
249
+
250
+ return output_frames, aud_feature, video_info
251
+
252
+
253
+ def _extract_video(video: Video, config: PretrainedConfig) -> List[PIL.Image.Image]:
254
+ num_frames = config.num_video_frames
255
+ aud_fea = None
256
+
257
+ if getattr(config, "fps") != 0:
258
+ print("Extracting frames from video with specified FPS is not supported yet. Ignored.")
259
+
260
+ if isinstance(video.path, BytesIO):
261
+ frames, aud_fea, video_info = _load_video_bytesio(
262
+ video.path, num_frames=num_frames, config=config, load_aud=config.load_audio_in_video
263
+ )
264
+ else:
265
+ frames, aud_fea, video_info = _load_video(
266
+ video.path, num_frames=num_frames, config=config, load_aud=config.load_audio_in_video
267
+ )
268
+
269
+ if config.load_audio_in_video:
270
+ return frames, aud_fea, video_info
271
+ else:
272
+ return frames, video_info
273
+
274
+
275
+ def soundFile_read_audio(audio_file, offset=None, duration=None, dtype='float32'):
276
+ if dtype not in ['int32', 'float32']:
277
+ print("audio dtype must be int32 or float32. Default to float32")
278
+ dtype = 'float32'
279
+ # return read audio and its sample rate
280
+ if isinstance(audio_file, bytes):
281
+ audio_file = io.BytesIO(audio_file)
282
+ with sf.SoundFile(audio_file, 'r') as f:
283
+ sample_rate = f.samplerate
284
+ if offset is not None and offset > 0:
285
+ f.seek(int(offset * sample_rate))
286
+ if duration is not None and duration > 0:
287
+ samples = f.read(int(duration * sample_rate), dtype=dtype)
288
+ else:
289
+ samples = f.read(dtype=dtype)
290
+ return samples, sample_rate
291
+
292
+ def load_audio_from_tar(tar_file, audio_file):
293
+ with tarfile.open(tar_file, 'r') as tar:
294
+ audio_member = tar.getmember(audio_file)
295
+ audio_file = tar.extractfile(audio_member)
296
+ return librosa.load(audio_file)
297
+
298
+ def _load_audio_file(audio_path: str, config: PretrainedConfig):
299
+ # Load video frames from a directory
300
+ if audio_path is None:
301
+ return None
302
+
303
+ dirname = os.path.dirname(audio_path)
304
+ filename = os.path.basename(audio_path)
305
+
306
+ if dirname.endswith(".tar"):
307
+ speech, sample_rate = load_audio_from_tar(dirname, filename)
308
+ else:
309
+ sample_rate = config.audio_sampling_rate
310
+ speech = whisper.load_audio(audio_path, sr=sample_rate)
311
+
312
+ return speech, sample_rate
313
+
314
+
315
+ def _load_audio(audio: Union[str, dict], config: PretrainedConfig):
316
+ if isinstance(audio, str):
317
+ return _load_audio_file(audio, config)
318
+ elif isinstance(audio, dict):
319
+ audio_sample = audio['sample']
320
+ if isinstance(audio_sample, (bytes, io.BytesIO)):
321
+ offset = audio.get('offset', None)
322
+ duration = audio.get('duration', None)
323
+ dtype = audio.get('dtype', 'float32')
324
+ return soundFile_read_audio(
325
+ audio_sample, offset=offset, duration=duration, dtype=dtype
326
+ )
327
+ elif isinstance(audio_sample, np.ndarray):
328
+ return audio_sample, audio.get('sample_rate')
329
+ else:
330
+ raise ValueError(f"Expect the loaded audio to be a processed numpy array or raw bytes. Got {type(audio_sample)}")
331
+ else:
332
+ raise ValueError(f"Expect input to be a path string or dict. Got {type(audio)}")
333
+
334
+ def _whisper_process(audio, sample_rate, audio_chunk_length, max_chunks_per_file):
335
+ outputs = []
336
+ num_audio_chunks = 0
337
+
338
+ chunk_length = audio_chunk_length * sample_rate
339
+ for i in range(0, len(audio), chunk_length):
340
+ chunk = audio[i : i + chunk_length]
341
+ chunk = whisper.pad_or_trim(chunk)
342
+ if chunk.dtype != np.float32:
343
+ chunk = chunk.astype(np.float32)
344
+ mel = whisper.log_mel_spectrogram(chunk, n_mels=128)
345
+ num_audio_chunks+=1
346
+ outputs.append(mel)
347
+ if num_audio_chunks == max_chunks_per_file:
348
+ break
349
+
350
+ frames = torch.stack(outputs, dim=0)
351
+ return frames.numpy().tolist()
352
+
353
+ def _load_speech(speech, config: PretrainedConfig):
354
+ if type(speech) == str:
355
+ speech_path = speech
356
+ else:
357
+ speech_path = speech.path
358
+
359
+ # Load video frames from a directory
360
+ if speech_path is None:
361
+ return None
362
+ speech_outputs = []
363
+
364
+ if config.audio_chunk_length and not (type(config.audio_chunk_length) == str and "max" in config.audio_chunk_length):
365
+ try:
366
+ config.audio_chunk_length = int(config.audio_chunk_length)
367
+ except Exception as e:
368
+ print(f"Error setting audio_chunk_length: {e}")
369
+ raise e
370
+
371
+ audio_n_samples_limit = config.audio_chunk_length * config.audio_sampling_rate
372
+
373
+ def load_wav(speech_path):
374
+ speech, sr = librosa.load(speech_path, sr=config.audio_sampling_rate)
375
+ cur_max_length = speech.shape[0]
376
+ ori_audio_duration = cur_max_length / sr
377
+ return speech, ori_audio_duration
378
+
379
+ def get_audio(speech, audio_n_samples):
380
+
381
+ if type(speech) == decord.audio_reader.AudioReader:
382
+ ori_n_samples = speech.shape[1]
383
+ else:
384
+ ori_n_samples = speech.shape[0]
385
+
386
+ # random audio smaple
387
+ audio_start_sample_id = 0
388
+ audio_end_sample_id = ori_n_samples
389
+
390
+
391
+ load_max_audio = type(config.audio_chunk_length) == str and "max" in config.audio_chunk_length
392
+ if hasattr(config, 'random_audio_sample') and not load_max_audio:
393
+ if ori_n_samples > audio_n_samples:
394
+ audio_start_sample_id = random.randint(0, ori_n_samples - audio_n_samples)
395
+ audio_end_sample_id = audio_start_sample_id + audio_n_samples
396
+ else:
397
+ if load_max_audio:
398
+ if "_" in config.audio_chunk_length:
399
+ max_audio_chunk_length = int(config.audio_chunk_length.split("_")[1])
400
+ max_audio_n_samples = max_audio_chunk_length * config.audio_sampling_rate
401
+ audio_n_samples = min(ori_n_samples, max_audio_n_samples)
402
+ audio_end_sample_id = audio_n_samples
403
+ else:
404
+ audio_n_samples = ori_n_samples
405
+ audio_end_sample_id = audio_n_samples
406
+ else:
407
+ audio_end_sample_id = min(audio_n_samples, ori_n_samples)
408
+
409
+ if type(speech) == decord.audio_reader.AudioReader:
410
+ speech = speech[audio_start_sample_id:audio_end_sample_id].asnumpy()[0]
411
+ else:
412
+ speech = speech[audio_start_sample_id:audio_end_sample_id]
413
+
414
+
415
+ return speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id
416
+
417
+ if isinstance(speech_path, dict):
418
+ if "offset" in speech_path:
419
+ speech, ori_sample_rate = _load_audio(speech_path, config)
420
+
421
+ else:
422
+ speech = speech_path["sample"]
423
+ ori_sample_rate = speech_path["sample_rate"]
424
+
425
+ # resample the speech based on current sample rate
426
+ speech = librosa.resample(speech, orig_sr=ori_sample_rate, target_sr=config.audio_sampling_rate)
427
+ # variable audio sequence lengths
428
+ ori_audio_duration = speech.shape[0] / config.audio_sampling_rate
429
+ speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id = get_audio(speech, audio_n_samples_limit)
430
+
431
+ elif isinstance(speech_path, BytesIO):
432
+ if speech.extension == ".wav":
433
+ # speech, sr = librosa.load(speech_path, sr=config.audio_sampling_rate)
434
+ # ori_audio_duration = speech.shape[0] / sr
435
+ speech, ori_audio_duration = load_wav(speech_path)
436
+ speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id = get_audio(speech, audio_n_samples_limit)
437
+ else:
438
+ raise ValueError(f"Unsupported audio extension: {speech.extension}")
439
+
440
+ elif ".mat" in speech_path or ".ark" in speech_path:
441
+ rate, speech = kaldiio.load_mat(speech_path)
442
+ speech = librosa.resample(speech, orig_sr=rate, target_sr=config.audio_sampling_rate)
443
+ speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id = get_audio(speech, audio_n_samples_limit)
444
+ ori_audio_duration = speech.shape[0] / config.audio_sampling_rate
445
+ elif ".mp4" in speech_path:
446
+ # Load audio from video file
447
+ ar = AudioReader(speech_path, ctx=cpu(0), sample_rate=config.audio_sampling_rate, mono=True)
448
+ cur_max_length = ar.shape[1]
449
+ ori_audio_duration = cur_max_length / config.audio_sampling_rate
450
+ speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id = get_audio(ar, audio_n_samples_limit)
451
+ else:
452
+ assert os.path.exists(speech_path), f"File {speech_path} does not exist"
453
+ speech, ori_audio_duration = load_wav(speech_path)
454
+ speech, audio_n_samples, audio_start_sample_id, audio_end_sample_id = get_audio(speech, audio_n_samples_limit)
455
+
456
+ # convert to float
457
+ speech = speech.astype(np.float32)
458
+ audio_n_samples = int(np.ceil(speech.shape[0] / (config.audio_sampling_rate * 30)) * (config.audio_sampling_rate * 30))
459
+
460
+ speech = whisper.pad_or_trim(speech, length=audio_n_samples) # we don't pad or trim here, instead, we pad based on the max length of all audio samples in the batch size later
461
+
462
+ new_audio_chunk_length = int(audio_n_samples // config.audio_sampling_rate)
463
+ audio_start_sec = audio_start_sample_id / config.audio_sampling_rate
464
+ audio_end_sample_sec = audio_end_sample_id / config.audio_sampling_rate
465
+
466
+ audio_info = {}
467
+ audio_info['new_audio_chunk_length'] = new_audio_chunk_length
468
+ audio_info['new_audio_n_samples'] = audio_n_samples
469
+ audio_info['ori_audio_duration'] = ori_audio_duration
470
+ audio_info['audio_start_sec'] = audio_start_sec
471
+ audio_info['audio_end_sample_sec'] = audio_end_sample_sec
472
+
473
+ return speech, audio_info
474
+
475
+ def _extract_speech(speech: Speech, config: PretrainedConfig):
476
+ frames, audio_info = _load_speech(speech, config)
477
+ return frames, audio_info
478
+
479
+ _extract_sound = _extract_speech
480
+ def extract_media(
481
+ messages: List[Dict[str, Any]],
482
+ config: Optional[PretrainedConfig] = None,
483
+ draft: bool = False,
484
+ ) -> Dict[str, List[Any]]:
485
+ media = defaultdict(list)
486
+
487
+ if not hasattr(config, "load_audio_in_video"):
488
+ print(f"Warning: load_audio_in_video not in config, set to False")
489
+ config.load_audio_in_video = False
490
+
491
+ for message in messages:
492
+ text = ""
493
+ for part in make_list(message["value"]):
494
+ if isinstance(part, str):
495
+ for token in MEDIA_TOKENS.values():
496
+ if token in part:
497
+ print(f"Media token '{token}' found in text: '{part}'. Removed.")
498
+ part = part.replace(token, "").strip()
499
+ text += part
500
+ elif isinstance(part, (Image, PIL.Image.Image)):
501
+ if draft:
502
+ media["image"].append(part)
503
+ else:
504
+ media["image"].append(_extract_image(part))
505
+ text += MEDIA_TOKENS["image"]
506
+ elif isinstance(part, Video):
507
+ if draft:
508
+ media["video"].append(part)
509
+ else:
510
+ if config.load_audio_in_video:
511
+ output, aud_fea, video_info = _extract_video(part, config)
512
+ media["video"].append(output)
513
+ media["video_info"].append(video_info)
514
+ if aud_fea is not None:
515
+ media["sound"].append(aud_fea)
516
+ media["audio_info"].append(video_info['audio_info'])
517
+ text += MEDIA_TOKENS["sound"]
518
+ else:
519
+ output, video_info = _extract_video(part, config)
520
+ media["video"].append(output)
521
+ media["video_info"].append(video_info)
522
+ text += MEDIA_TOKENS["video"]
523
+ elif isinstance(part, Speech):
524
+ if draft:
525
+ if config.unified_audio_encoder:
526
+ media["sound"].append(part)
527
+ text += MEDIA_TOKENS["sound"]
528
+ else:
529
+ media["speech"].append(part)
530
+ text += MEDIA_TOKENS["speech"]
531
+ else:
532
+ output, audio_info = _extract_speech(part, config)
533
+ if output is not None:
534
+ if config.unified_audio_encoder:
535
+ media["sound"].append(output)
536
+ text += MEDIA_TOKENS["sound"]
537
+ else:
538
+ media["speech"].append(output)
539
+ text += MEDIA_TOKENS["speech"]
540
+ media["audio_info"].append(audio_info)
541
+ elif isinstance(part, Sound):
542
+ if draft:
543
+ media["sound"].append(part)
544
+ text += MEDIA_TOKENS["sound"]
545
+ else:
546
+ output, audio_info = _extract_sound(part, config)
547
+ if output is not None:
548
+ media["sound"].append(output)
549
+ media["audio_info"].append(audio_info)
550
+ text += MEDIA_TOKENS["sound"]
551
+ else:
552
+ print(f"part: {part}")
553
+ raise ValueError(f"Unsupported prompt part type: {type(part)}")
554
+ message["value"] = text
555
+ return media
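`media.py` above defines the media wrappers (`Image`, `Video`, `Speech`, `Sound`) and `extract_media`, which decodes each attachment, appends the matching `MEDIA_TOKENS` placeholder to the message text, and collects the decoded frames and audio features. A hedged usage sketch follows; the video path is illustrative, and the config is assumed to already carry the fields this loader reads (`num_video_frames`, `fps`, `load_audio_in_video`, and the `audio_*` settings).

```python
from transformers import AutoConfig

from media import Video, extract_media

# Assumption: the published config defines the media-related fields used by media.py.
config = AutoConfig.from_pretrained("nvidia/omnivinci", trust_remote_code=True)

messages = [{"value": [Video("demo.mp4"), "What happens in this clip?"]}]  # demo.mp4 is illustrative
media = extract_media(messages, config)

print(messages[0]["value"])    # text now contains "<vila/video>" (plus "<sound>" if audio was loaded)
print(len(media["video"][0]))  # number of frames sampled from the video
```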
media_encoder.py ADDED
@@ -0,0 +1,955 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partial
17
+ from typing import Any, Dict, List, Optional, Tuple
18
+
19
+ import torch
20
+ from torch import nn
21
+ from torch.nn import Module, ModuleList
22
+ import numpy as np
23
+ from einops import rearrange, repeat
24
+ from torch.cuda.amp import autocast
25
+ from torch import nn, einsum, broadcast_tensors, Tensor
26
+ from beartype import beartype
27
+ from beartype.typing import Literal, Union, Optional
28
+ from math import pi, log
29
+ import math
30
+
31
+
32
+ class CacheFeatures(object):
33
+ def __init__(self, value, type):
34
+ self.value = value
35
+ self.type = type
36
+ def my_to(self, device, dtype):
37
+ self.value['features'] = self.value['features'].to(device, dtype) if 'features' in self.value and self.value['features'] is not None else None
38
+ return self
39
+ def __call__(self):
40
+ return self.value
41
+
42
+ def exists(val):
43
+ return val is not None
44
+
45
+ def default(val, d):
46
+ return val if exists(val) else d
47
+
48
+ # broadcat, as tortoise-tts was using it
49
+
50
+ def broadcat(tensors, dim = -1):
51
+ broadcasted_tensors = broadcast_tensors(*tensors)
52
+
53
+ def pool(x: torch.Tensor, size: int, dim: int) -> torch.Tensor:
54
+ # return x.view(x.shape[:dim] + (-1, size) + x.shape[dim + 1 :]).mean(dim + 1)
55
+ # Reshape x to group elements along the specified dimension into chunks of 'size', then average over those chunks.
56
+
57
+ # Check if the dimension is divisible by the pool size, if not pad with mean values
58
+ if x.shape[dim] % size != 0:
59
+ print(f"Warning: dimension {dim} with size {x.shape[dim]} is not divisible by pool size {size}, padding with mean values")
60
+ remainder = x.shape[dim] % size
61
+ pad_len = size - remainder
62
+
63
+ # Get the mean of the last few elements along the dimension to be pooled
64
+ last_elements = x.narrow(dim, x.shape[dim] - remainder, remainder)
65
+ mean_value = last_elements.mean()
66
+
67
+ # Create padding tensor with the same shape as x except for the dimension being pooled
68
+ pad_shape = list(x.shape)
69
+ pad_shape[dim] = pad_len
70
+ padding = torch.ones(pad_shape, device=x.device, dtype=x.dtype) * mean_value
71
+
72
+ # Concatenate the original tensor with the padding along the specified dimension
73
+ x = torch.cat([x, padding], dim=dim)
74
+
75
+ shape_before = x.shape[:dim]
76
+ shape_after = x.shape[dim + 1 :]
77
+ new_shape = shape_before + (-1, size) + shape_after
78
+ x_reshaped = x.view(new_shape)
79
+ return x_reshaped.mean(dim + 1)
80
+
81
+ def rotate_half(x):
82
+ x = rearrange(x, '... (d r) -> ... d r', r = 2)
83
+ x1, x2 = x.unbind(dim = -1)
84
+ x = torch.stack((-x2, x1), dim = -1)
85
+ return rearrange(x, '... d r -> ... (d r)')
86
+
87
+
88
+ def apply_rotary_emb(freqs, t, start_index = 0, scale = 1., seq_dim = -2):
89
+ with torch.amp.autocast(device_type='cuda', enabled=False):
90
+ ori_dtype = t.dtype
91
+ embed_dtype = torch.float64
92
+ t = t.to(embed_dtype)
93
+ if t.ndim == 3:
94
+ seq_len = t.shape[seq_dim]
95
+ freqs = freqs[-seq_len:].to(t)
96
+
97
+ rot_dim = freqs.shape[-1]
98
+ end_index = start_index + rot_dim
99
+
100
+ assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
101
+
102
+ t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
103
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
104
+ return torch.cat((t_left, t, t_right), dim = -1).to(ori_dtype)
105
+
106
+ class MaxTimeContinuousTimeRotaryEmbedding(nn.Module):
107
+ def __init__(self, dim, max_time, period_mode="shortest", device=None):
108
+ super().__init__()
109
+ assert dim % 2 == 0, "RoPE embedding dimension must be even"
110
+
111
+ # Set max period = max_time
112
+ if period_mode == "shortest": # shortest period is max_time
113
+ base = 5
114
+ inv_freq = 2 * math.pi / (max_time * (base ** (torch.arange(0, dim // 2).float() / (dim // 2))))
115
+ elif period_mode == "longest": # longest period is max_time ** ((dim // 2) / (dim // 2 - 1))
116
+ theta = max_time ** ((dim // 2) / (dim // 2 - 1))
117
+ inv_freq = 2 * math.pi / ((theta ** (torch.arange(0, dim // 2).float() / (dim // 2))))
118
+ else:
119
+ raise ValueError(f"Invalid period mode: {period_mode}")
120
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
121
+
122
+ def forward(self, time_values: torch.Tensor):
123
+ """
124
+ time_values: [batch_size, seq_len], in seconds (or any continuous unit)
125
+ Returns:
126
+ cos, sin: [batch_size, seq_len, dim]
127
+ """
128
+ batch_size, seq_len = time_values.shape
129
+ time_values_exp = time_values[:, None, :] # [batch, 1, seq_len]
130
+ freqs = (self.inv_freq[None, :, None] @ time_values_exp).transpose(1, 2) # [batch, seq_len, dim//2]
131
+ # emb = torch.cat([freqs, freqs], dim=-1) # [batch, seq_len, dim]
132
+ # return emb.cos(), emb.sin()
133
+ return freqs
134
+
135
+ def get_axial_freqs(self, *dims):
136
+ Colon = slice(None)
137
+ all_freqs = []
138
+
139
+ for ind, dim in enumerate(dims):
140
+ pos = torch.arange(dim, device = self.device)
141
+
142
+ freqs = self.forward(pos, seq_len = dim)
143
+
144
+ all_axis = [None] * len(dims)
145
+ all_axis[ind] = Colon
146
+
147
+ new_axis_slice = (Ellipsis, *all_axis, Colon)
148
+ all_freqs.append(freqs[new_axis_slice])
149
+
150
+ all_freqs = broadcast_tensors(*all_freqs)
151
+ return torch.cat(all_freqs, dim = -1)
152
+
153
+
154
+
155
+
156
+ class RotaryEmbedding(Module):
157
+ @beartype
158
+ def __init__(
159
+ self,
160
+ dim,
161
+ custom_freqs: Optional[Tensor] = None,
162
+ freqs_for: Union[Literal['lang', 'pixel', 'constant']] = 'lang',
163
+ theta = 10000,
164
+ max_freq = 10,
165
+ num_freqs = 1,
166
+ learned_freq = False,
167
+ use_xpos = False,
168
+ xpos_scale_base = 512,
169
+ interpolate_factor = 1.,
170
+ theta_rescale_factor = 1.,
171
+ seq_before_head_dim = False,
172
+ cache_if_possible = True,
173
+ max_time = None
174
+ ):
175
+ super().__init__()
176
+
177
+ self.dim = dim
178
+ self.freqs_for = freqs_for
179
+ self.max_freq = max_freq
180
+ self.num_freqs = num_freqs
181
+ self.learned_freq = learned_freq
182
+ self.use_xpos = use_xpos
183
+ self.xpos_scale_base = xpos_scale_base
184
+ self.interpolate_factor = interpolate_factor
185
+ self.theta_rescale_factor = theta_rescale_factor
186
+ self.cache_if_possible = cache_if_possible
187
+ self.max_time = max_time
188
+
189
+ self.tmp_store('cached_freqs', None)
190
+ self.tmp_store('cached_scales', None)
191
+
192
+ # Adjust theta to avoid angle wrapping after large times
193
+ if exists(max_time) and freqs_for == 'lang':
194
+ # Make sure highest frequency completes 1 full rotation over max time
195
+ # theta = base of exponent: higher theta → lower frequency range
196
+ # max_time * (1/theta^(0)) = 2pi => theta = max_time / (2pi)
197
+ theta = max_time / (2 * pi)
198
+
199
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
200
+
201
+ self.theta = theta
202
+
203
+ if exists(custom_freqs):
204
+ freqs = custom_freqs
205
+ elif freqs_for == 'lang':
206
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
207
+ elif freqs_for == 'pixel':
208
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
209
+ elif freqs_for == 'constant':
210
+ freqs = torch.ones(num_freqs).float()
211
+
212
+ self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)
213
+
214
+ self.learned_freq = learned_freq
215
+
216
+ # dummy for device
217
+
218
+ self.tmp_store('dummy', torch.tensor(0))
219
+
220
+ # default sequence dimension
221
+
222
+ self.seq_before_head_dim = seq_before_head_dim
223
+ self.default_seq_dim = -3 if seq_before_head_dim else -2
224
+
225
+ # interpolation factors
226
+
227
+ assert interpolate_factor >= 1.
228
+ self.interpolate_factor = interpolate_factor
229
+
230
+ # xpos
231
+ if not use_xpos:
232
+ self.tmp_store('scale', None)
233
+ return
234
+
235
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
236
+ self.scale_base = xpos_scale_base
237
+ self.tmp_store('scale', scale)
238
+
239
+ # add apply_rotary_emb as static method
240
+
241
+ self.apply_rotary_emb = staticmethod(apply_rotary_emb)
242
+
243
+ @property
244
+ def device(self):
245
+ return self.dummy.device
246
+
247
+ def tmp_store(self, key, value):
248
+ self.register_buffer(key, value, persistent = False)
249
+
250
+ def get_seq_pos(self, seq_len, device, dtype, offset = 0):
251
+ return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor
252
+
253
+ def rotate_queries_or_keys(self, t, seq_dim = None, offset = 0):
254
+ seq_dim = default(seq_dim, self.default_seq_dim)
255
+
256
+ assert not self.use_xpos, 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings'
257
+
258
+ device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
259
+
260
+ freqs = self.forward(self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset), seq_len = seq_len, offset = offset)
261
+
262
+ if seq_dim == -3:
263
+ freqs = rearrange(freqs, 'n d -> n 1 d')
264
+
265
+ return apply_rotary_emb(freqs, t, seq_dim = seq_dim)
266
+
267
+ def rotate_queries_with_cached_keys(self, q, k, seq_dim = None, offset = 0):
268
+ seq_dim = default(seq_dim, self.default_seq_dim)
269
+
270
+ q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
271
+ assert q_len <= k_len
272
+
273
+ rotated_q = self.rotate_queries_or_keys(q, seq_dim = seq_dim, offset = k_len - q_len + offset)
274
+ rotated_k = self.rotate_queries_or_keys(k, seq_dim = seq_dim, offset = offset)
275
+
276
+ rotated_q = rotated_q.type(q.dtype)
277
+ rotated_k = rotated_k.type(k.dtype)
278
+
279
+ return rotated_q, rotated_k
280
+
281
+ def rotate_queries_and_keys(self, q, k, seq_dim = None):
282
+ seq_dim = default(seq_dim, self.default_seq_dim)
283
+
284
+ assert self.use_xpos
285
+ device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
286
+
287
+ seq = self.get_seq_pos(seq_len, dtype = dtype, device = device)
288
+
289
+ freqs = self.forward(seq, seq_len = seq_len)
290
+ scale = self.get_scale(seq, seq_len = seq_len).to(dtype)
291
+
292
+ if seq_dim == -3:
293
+ freqs = rearrange(freqs, 'n d -> n 1 d')
294
+ scale = rearrange(scale, 'n d -> n 1 d')
295
+
296
+ rotated_q = apply_rotary_emb(freqs, q, scale = scale, seq_dim = seq_dim)
297
+ rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1, seq_dim = seq_dim)
298
+
299
+ rotated_q = rotated_q.type(q.dtype)
300
+ rotated_k = rotated_k.type(k.dtype)
301
+
302
+ return rotated_q, rotated_k
303
+
304
+ @beartype
305
+ def get_scale(
306
+ self,
307
+ t: Tensor,
308
+ seq_len: Optional[int] = None,
309
+ offset = 0
310
+ ):
311
+ assert self.use_xpos
312
+
313
+ should_cache = (
314
+ self.cache_if_possible and
315
+ exists(seq_len)
316
+ )
317
+
318
+ if (
319
+ should_cache and \
320
+ exists(self.cached_scales) and \
321
+ (seq_len + offset) <= self.cached_scales.shape[0]
322
+ ):
323
+ return self.cached_scales[offset:(offset + seq_len)]
324
+
325
+ scale = 1.
326
+ if self.use_xpos:
327
+ power = (t - len(t) // 2) / self.scale_base
328
+ scale = self.scale ** rearrange(power, 'n -> n 1')
329
+ scale = torch.cat((scale, scale), dim = -1)
330
+
331
+ if should_cache:
332
+ self.tmp_store('cached_scales', scale)
333
+
334
+ return scale
335
+
336
+ def get_axial_freqs(self, *dims):
337
+ Colon = slice(None)
338
+ all_freqs = []
339
+
340
+ for ind, dim in enumerate(dims):
341
+ if self.freqs_for == 'pixel':
342
+ pos = torch.linspace(-1, 1, steps = dim, device = self.device)
343
+ else:
344
+ pos = torch.arange(dim, device = self.device)
345
+
346
+ freqs = self.forward(pos, seq_len = dim)
347
+
348
+ all_axis = [None] * len(dims)
349
+ all_axis[ind] = Colon
350
+
351
+ new_axis_slice = (Ellipsis, *all_axis, Colon)
352
+ all_freqs.append(freqs[new_axis_slice])
353
+
354
+ all_freqs = broadcast_tensors(*all_freqs)
355
+ return torch.cat(all_freqs, dim = -1)
356
+
357
+ def forward(
358
+ self,
359
+ t: Tensor,
360
+ seq_len = None,
361
+ offset = 0
362
+ ):
363
+ should_cache = (
364
+ self.cache_if_possible and \
365
+ not self.learned_freq and \
366
+ exists(seq_len) and \
367
+ self.freqs_for != 'pixel'
368
+ )
369
+
370
+ if (
371
+ should_cache and \
372
+ exists(self.cached_freqs) and \
373
+ (offset + seq_len) <= self.cached_freqs.shape[0]
374
+ ):
375
+ return self.cached_freqs[offset:(offset + seq_len)].detach()
376
+
377
+ freqs = self.freqs
378
+
379
+ # Scale time to keep t * freq <= 2pi
380
+ if hasattr(self, 'max_time') and self.max_time is not None:
381
+ t = t / self.max_time * (2 * pi)
382
+
383
+ freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
384
+ freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
385
+
386
+ if should_cache:
387
+ self.tmp_store('cached_freqs', freqs.detach())
388
+
389
+ return freqs
390
+
391
+ class BaseEncoder(nn.Module):
392
+ def __init__(self, parent: nn.Module) -> None:
393
+ super().__init__()
394
+ self._parent = [parent]
395
+
396
+ @property
397
+ def parent(self) -> nn.Module:
398
+ return self._parent[0]
399
+
400
+
401
+ class BasicImageEncoder(BaseEncoder):
402
+ def __init__(
403
+ self,
404
+ parent: torch.nn.Module,
405
+ start_tokens: Optional[str] = None,
406
+ end_tokens: Optional[str] = "\n",
407
+ ) -> None:
408
+ super().__init__(parent)
409
+ end_tokens = None if end_tokens == "None" else end_tokens
410
+ self.start_tokens = start_tokens
411
+ self.end_tokens = end_tokens
412
+
413
+ def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
414
+ if tokens is None:
415
+ return None
416
+ token_ids = self.parent.tokenizer(tokens).input_ids
417
+ token_ids = torch.tensor(token_ids, device=self.parent.device)
418
+ return self.parent.llm_model_embed_tokens(token_ids)
419
+
420
+ def _process_features(
421
+ self,
422
+ features: torch.Tensor,
423
+ start_token_embeds: Optional[torch.Tensor],
424
+ end_token_embeds: Optional[torch.Tensor],
425
+ ) -> torch.Tensor:
426
+ if start_token_embeds is not None:
427
+ features = torch.cat([start_token_embeds, features], dim=0)
428
+ if end_token_embeds is not None:
429
+ features = torch.cat([features, end_token_embeds], dim=0)
430
+ return features
431
+
432
+ def forward(self, images: List[torch.Tensor], config: Dict[str, Any], mm_info: dict) -> List[torch.Tensor]:
433
+ images = torch.stack(images, dim=0)
434
+ features = self.parent.encode_images(images, block_sizes=config.get("block_sizes"))
435
+ process_features = partial(
436
+ self._process_features,
437
+ start_token_embeds=self.embed_tokens(self.start_tokens),
438
+ end_token_embeds=self.embed_tokens(self.end_tokens),
439
+ )
440
+ return [process_features(f) for f in features]
441
+
442
+
443
+ class BasicVideoEncoder(BaseEncoder):
444
+ def __init__(
445
+ self,
446
+ parent: torch.nn.Module,
447
+ start_tokens: Optional[str] = None,
448
+ end_tokens: Optional[str] = "\n",
449
+ ) -> None:
450
+ super().__init__(parent)
451
+ end_tokens = None if end_tokens == "None" else end_tokens
452
+ self.start_tokens = start_tokens
453
+ self.end_tokens = end_tokens
454
+
455
+ def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
456
+ if tokens is None:
457
+ return None
458
+ token_ids = self.parent.tokenizer(tokens).input_ids
459
+ token_ids = torch.tensor(token_ids, device=self.parent.device)
460
+ return self.parent.llm_model_embed_tokens(token_ids)
461
+
462
+ def _process_features(
463
+ self,
464
+ features: torch.Tensor,
465
+ start_token_embeds: Optional[torch.Tensor],
466
+ end_token_embeds: Optional[torch.Tensor],
467
+ ) -> torch.Tensor:
468
+ if start_token_embeds is not None:
469
+ start_embeds = torch.stack([start_token_embeds] * features.shape[0], dim=0)
470
+ features = torch.cat([start_embeds, features], dim=1)
471
+ if end_token_embeds is not None:
472
+ end_embeds = torch.stack([end_token_embeds] * features.shape[0], dim=0)
473
+ features = torch.cat([features, end_embeds], dim=1)
474
+ return features.flatten(0, 1)
475
+
476
+ def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
477
+ num_frames = [video.shape[0] for video in videos]
478
+ images = torch.cat(videos, dim=0)
479
+ features = self.parent.encode_images(images)
480
+ features = torch.split(features, num_frames)
481
+ process_features = partial(
482
+ self._process_features,
483
+ start_token_embeds=self.embed_tokens(self.start_tokens),
484
+ end_token_embeds=self.embed_tokens(self.end_tokens),
485
+ )
486
+ return [process_features(f) for f in features]
487
+
488
+
489
+
490
+
491
+ class BasicSoundEncoder(BaseEncoder):
492
+ def __init__(
493
+ self,
494
+ parent: torch.nn.Module,
495
+ start_tokens: Optional[str] = None,
496
+ end_tokens: Optional[str] = "\n",
497
+ embed_time = "True",
498
+ trope_theta = 50000,
499
+ trope_dim = 128,
500
+ max_time = None,
501
+ time_embed_type = "pixel",
502
+ period_fix = False,
503
+ ) -> None:
504
+ super().__init__(parent)
505
+ end_tokens = None if end_tokens == "None" else end_tokens
506
+ if embed_time == "True":
507
+ embed_time = True
508
+ elif embed_time == "False":
509
+ embed_time = False
510
+ self.start_tokens = start_tokens
511
+ self.end_tokens = end_tokens
512
+
513
+ if embed_time == "False" or embed_time == False:
514
+ self.embed_time = False
515
+ else:
516
+ self.embed_time = True
517
+ self.time_embed_type = time_embed_type
518
+
519
+ period_mode = None
520
+ if type(period_fix) == str:
521
+ if period_fix == "shortest":
522
+ period_fix = "MTCT"
523
+ period_mode = "shortest"
524
+ elif period_fix == "longest":
525
+ period_fix = "MTCT"
526
+ period_mode = "longest"
527
+
528
+ self.period_fix = period_fix
529
+ self.max_time = max_time
530
+
531
+ if period_fix == "MTCT":
532
+ if period_mode is None:
533
+ self.pos_emb = MaxTimeContinuousTimeRotaryEmbedding(
534
+ dim = trope_dim,
535
+ max_time = max_time,
536
+ )
537
+ else:
538
+ self.pos_emb = MaxTimeContinuousTimeRotaryEmbedding(
539
+ dim = trope_dim,
540
+ max_time = max_time,
541
+ period_mode = period_mode,
542
+ )
543
+
544
+ elif time_embed_type in ["pixel", "lang"]:
545
+ if trope_dim is None and max_time is None:
546
+ raise ValueError("trope_dim or max_time is required when embed_time is True")
547
+ self.pos_emb = RotaryEmbedding(
548
+ dim = trope_dim,
549
+ freqs_for = time_embed_type,
550
+ max_freq = 256,
551
+ max_time = max_time,
552
+ )
553
+ elif time_embed_type == "learned_embed":
554
+ self.time_embed = parent.sound_mm_projector.time_embed
555
+ else:
556
+ raise ValueError(f"Invalid time_embed_type: {time_embed_type}")
557
+
558
+ def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
559
+ if tokens is None:
560
+ return None
561
+ token_ids = self.parent.tokenizer(tokens).input_ids
562
+ token_ids = torch.tensor(token_ids, device=self.parent.device)
563
+ # return self.parent.llm.model.embed_tokens(token_ids)
564
+ return self.parent.llm_model_embed_tokens(token_ids)
565
+
566
+ def _process_features(
567
+ self,
568
+ features: torch.Tensor,
569
+ start_token_embeds: Optional[torch.Tensor],
570
+ end_token_embeds: Optional[torch.Tensor],
571
+ times: Optional[torch.Tensor] = None,
572
+ time_embed: Optional[torch.Tensor] = None,
573
+ ) -> torch.Tensor:
574
+
575
+ features = features.to(self.parent.device)
576
+ device = features.device
577
+ dtype = features.dtype
578
+
579
+ if self.embed_time:
580
+ device = features.device
581
+ dtype = features.dtype
582
+
583
+ # Handle different embedding types
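+ # "pixel"/"lang": rotary time embedding driven by the absolute timestamp of each audio embedding.
+ # "MTCT": MaxTimeContinuousTimeRotaryEmbedding applied directly to the raw timestamps.
+ # "learned_embed": a learned time embedding from the sound projector is added to the features.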
584
+ if self.time_embed_type in ["pixel", "lang"]:
585
+ times = times.unsqueeze(0)
586
+ new_times = times
587
+ pos_emb = self.pos_emb.to(device)
588
+ if self.period_fix == "True":
589
+ if self.max_time is not None:
590
+ angle = new_times.to(device) / self.max_time * 2 * np.pi
591
+ else:
592
+ angle = new_times.to(device)
593
+ elif self.period_fix == "MTCT":
594
+ freqs = self.pos_emb(new_times.float())
595
+ freqs = freqs.squeeze(0)
596
+ features = apply_rotary_emb(freqs, features)
597
+ else:
598
+ angle = (-new_times * 2 * np.pi).to(device)
599
+
600
+ if not self.period_fix == "MTCT":
601
+ freqs = pos_emb.get_axial_freqs(new_times.shape[0], features.shape[-2]).to(device)
602
+ angle_expanded = angle.unsqueeze(2)
603
+ angle_expanded = angle_expanded.expand(new_times.shape[0], features.shape[-2], freqs.shape[-1])
604
+ freqs = freqs * angle_expanded
605
+ freqs = freqs.squeeze(0)
606
+ # ori_dtype = features.dtype
607
+ # embed_dtype = torch.float32
608
+ # features = features.to(embed_dtype)
609
+ features = apply_rotary_emb(freqs, features)
610
+ # features = features.to(ori_dtype)
611
+ elif self.time_embed_type == "learned_embed": # Learned embedding
612
+ # Add time embeddings to features
613
+ features = features + time_embed
614
+ else:
615
+ raise ValueError(f"Invalid time_embed_type: {self.time_embed_type}")
616
+
617
+ if start_token_embeds is not None:
618
+ features = torch.cat([start_token_embeds, features], dim=0)
619
+ if end_token_embeds is not None:
620
+ features = torch.cat([features, end_token_embeds], dim=0)
621
+ return features
622
+
623
+ def forward(self, sounds: List[torch.Tensor], config: Dict[str, Any], mm_info: dict) -> List[torch.Tensor]:
624
+ # sounds = torch.stack(sounds, dim=0)
625
+ features = self.parent.encode_sound(sounds, mm_info=mm_info)
626
+ process_features = partial(
627
+ self._process_features,
628
+ start_token_embeds=self.embed_tokens(self.start_tokens),
629
+ end_token_embeds=self.embed_tokens(self.end_tokens),
630
+ )
631
+
632
+
633
+ if self.embed_time:
634
+ new_features = []
635
+ device = features[0].device
636
+ fea_count = len(features)
637
+ aud_idx = 0
638
+ bs = len(mm_info["audio_info"])
639
+
640
+ if self.time_embed_type == "learned_embed": # Learned embedding, we need to first collect all times and only do time embedding once
641
+ times_list = []
642
+ for i in range(bs):
643
+ _audio_info = mm_info["audio_info"][i]
644
+ if _audio_info is not None:
645
+ for j in range(len(_audio_info)):
646
+ _feature = features[aud_idx]
647
+ if _audio_info[j] == "dummy":
648
+ times = torch.zeros(_feature.shape[0], device=device, dtype=_feature.dtype)
649
+ else:
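+ # Each audio embedding is assigned the mid-point timestamp of the audio span it covers.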
650
+ audio_chunk_length = _audio_info[j]["new_audio_chunk_length"]
651
+ sec_per_embed = audio_chunk_length / _feature.shape[0]
652
+ audio_start_sec = _audio_info[j]["audio_start_sec"]
653
+ times = [audio_start_sec + i * sec_per_embed + sec_per_embed / 2 for i in range(_feature.shape[0])]
654
+ times = torch.tensor(times).to(device)
655
+ times_list.append(times)
656
+ aud_idx += 1
657
+
658
+ times = torch.stack(times_list, dim=0)
659
+ time_embeds = self.time_embed(times, dtype=features[0].dtype)
660
+
661
+ aud_idx = 0
662
+ for i in range(bs):
663
+ _audio_info = mm_info["audio_info"][i]
664
+ if _audio_info is not None:
665
+ for j in range(len(_audio_info)):
666
+ try:
667
+ _feature = features[aud_idx]
668
+ except Exception as e:
669
+ print(f"Error: {e}. Length of features: {len(features)}. Length of _audio_info: {len(_audio_info)}. Length of _feature: {_feature.shape[0]}")
670
+ raise e
671
+ if _audio_info[j] == "dummy":
672
+ times = torch.zeros(_feature.shape[0], device=device, dtype=_feature.dtype)
673
+ else:
674
+ audio_chunk_length = _audio_info[j]["new_audio_chunk_length"]
675
+ sec_per_embed = audio_chunk_length / _feature.shape[0]
676
+ audio_start_sec = _audio_info[j]["audio_start_sec"]
677
+ times = [audio_start_sec + i * sec_per_embed + sec_per_embed / 2 for i in range(_feature.shape[0])]
678
+ times = torch.tensor(times).to(device)
679
+ if self.time_embed_type == "learned_embed":
680
+ _feature = process_features(_feature, time_embed=time_embeds[aud_idx])
681
+ else:
682
+ _feature = process_features(_feature, times=times)
683
+ new_features.append(_feature)
684
+ aud_idx += 1
685
+
686
+ assert aud_idx == fea_count, "aud_idx: {}, fea_count: {}".format(aud_idx, fea_count)
687
+ features = new_features
688
+ else:
689
+ features = [process_features(f) for f in features]
690
+ return features
691
+
+
694
+ class TSPVideoEncoder(BasicVideoEncoder):
695
+ def __init__(
696
+ self,
697
+ parent: torch.nn.Module,
698
+ pool_sizes: List[Tuple[int, int, int]],
699
+ start_tokens: Optional[str] = None,
700
+ end_tokens: Optional[str] = "\n",
701
+ sep_tokens: Optional[str] = None,
702
+ embed_time: str = "False",
703
+ trope_theta = 50000,
704
+ trope_dim = 128,
705
+ max_time = None,
706
+ time_embed_type = "pixel",
707
+ period_fix = False,
708
+ ) -> None:
709
+ super().__init__(parent, start_tokens=start_tokens, end_tokens=end_tokens)
710
+ self.pool_sizes = pool_sizes
711
+ self.sep_tokens = sep_tokens
712
+
713
+ if embed_time == "False":
714
+ self.embed_time = False
715
+ else:
716
+ self.embed_time = True
717
+ self.time_embed_type = time_embed_type
718
+
719
+ period_mode = None
720
+ if type(period_fix) == str:
721
+ if period_fix == "shortest":
722
+ period_fix = "MTCT"
723
+ period_mode = "shortest"
724
+ elif period_fix == "longest":
725
+ period_fix = "MTCT"
726
+ period_mode = "longest"
727
+
728
+ self.period_fix = period_fix
729
+ self.max_time = max_time
730
+
731
+ if period_fix == "MTCT":
732
+ if period_mode is None:
733
+ self.pos_emb = MaxTimeContinuousTimeRotaryEmbedding(
734
+ dim = trope_dim,
735
+ max_time = max_time,
736
+ )
737
+ else:
738
+ self.pos_emb = MaxTimeContinuousTimeRotaryEmbedding(
739
+ dim = trope_dim,
740
+ max_time = max_time,
741
+ period_mode = period_mode,
742
+ )
743
+
744
+ elif time_embed_type in ["pixel", "lang"]:
745
+ if trope_dim is None and max_time is None:
746
+ raise ValueError("trope_dim or max_time is required when embed_time is True")
747
+
748
+ if time_embed_type == "lang":
749
+ self.pos_emb = RotaryEmbedding(
750
+ dim = trope_dim,
751
+ freqs_for = 'lang',
752
+ theta = trope_theta,
753
+ max_time = max_time,
754
+ )
755
+ elif time_embed_type == "pixel":
756
+ self.pos_emb = RotaryEmbedding(
757
+ dim = trope_dim,
758
+ freqs_for = time_embed_type,
759
+ max_freq = 256
760
+ )
761
+ elif time_embed_type == "learned_embed":
762
+ self.time_embed = parent.mm_projector.time_embed
763
+ else:
764
+ raise ValueError(f"Invalid time_embed_type: {time_embed_type}")
765
+
766
+ def _process_features(
767
+ self,
768
+ inputs: torch.Tensor,
769
+ start_token_embeds: Optional[torch.Tensor],
770
+ end_token_embeds: Optional[torch.Tensor],
771
+ sep_token_embeds: Optional[torch.Tensor],
772
+ times: Optional[torch.Tensor] = None,
773
+ time_embed: Optional[torch.Tensor] = None,
774
+ ) -> torch.Tensor:
775
+ nt, ns = inputs.shape[:2]
776
+ nl = int(ns**0.5)
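+ # ns is assumed to be a perfect square: each frame contributes an nl x nl grid of spatial tokens.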
777
+ outputs = []
778
+ for pool_size in self.pool_sizes:
779
+ features = inputs.view(nt, nl, nl, -1)
780
+ for dim, p in enumerate(pool_size):
781
+ try:
782
+ features = pool(features, p, dim=dim)
783
+ except Exception as e:
784
+ print(f"Error: Pooling failed: {e}")
785
+ print(f"inputs.shape: {inputs.shape}, features.shape: {features.shape}, pool_size: {p}, dim: {dim}")
786
+ raise e
787
+ features = features.flatten(1, 2)
788
+
789
+ if self.embed_time:
790
+ device = features.device
791
+ dtype = features.dtype
792
+ if self.time_embed_type in ["pixel", "lang"]:
793
+ # consider the pooling in self.pool_sizes
794
+ temporal_pool_size = pool_size[0]
795
+ if temporal_pool_size != 1:
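+ # Pool the frame timestamps with the same temporal window as the features so each pooled token keeps an aligned timestamp.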
796
+ if len(times) % temporal_pool_size != 0:
797
+ # pad
798
+ print(f"Warning: length of times: {len(times)} is not a multiple of temporal_pool_size: {temporal_pool_size}")
799
+ remainder = len(times) % temporal_pool_size
800
+ pad_len = temporal_pool_size - remainder
801
+ last_window_mean_times = times[-remainder:].mean()
802
+ times = torch.cat([times, torch.ones(pad_len).to(times.device) * last_window_mean_times])
803
+ new_times = pool(times, temporal_pool_size, 0)
804
+ else:
805
+ new_times = times
806
+
807
+ pos_emb = self.pos_emb.to(device)
808
+ if self.period_fix == "True":
809
+ if self.max_time is not None:
810
+ angle = new_times.to(device) / self.max_time * 2 * np.pi
811
+ else:
812
+ angle = new_times.to(device)
813
+ elif self.period_fix == "MTCT":
814
+ if new_times.ndim == 1:
815
+ new_times = new_times.unsqueeze(0)
816
+ freqs = self.pos_emb(new_times.float())
817
+ freqs = freqs.squeeze(0)
818
+ freqs = freqs.unsqueeze(1)
819
+ features = apply_rotary_emb(freqs, features, seq_dim=0)
820
+ else:
821
+ angle = (-new_times * 2 * np.pi).to(device)
822
+
823
+ if not self.period_fix == "MTCT":
824
+ freqs = pos_emb.get_axial_freqs(new_times.shape[0], features.shape[-2]).to(device)
825
+ angle_expanded = angle.unsqueeze(1).unsqueeze(2)
826
+ angle_expanded = angle_expanded.expand(new_times.shape[0], features.shape[-2], freqs.shape[-1])
827
+ freqs = freqs * angle_expanded
828
+ # ori_dtype = features.dtype
829
+ # embed_dtype = torch.float32
830
+ # features = features.to(embed_dtype)
831
+ features = apply_rotary_emb(freqs, features)
832
+ # features = features.to(ori_dtype)
833
+ elif self.time_embed_type == "learned_embed": # Learned embedding
834
+ # Add time embeddings to features
835
+ features = features + time_embed
836
+ else:
837
+ raise ValueError(f"Invalid time_embed_type: {self.time_embed_type}")
838
+
839
+ features = super()._process_features(
840
+ features,
841
+ start_token_embeds=start_token_embeds,
842
+ end_token_embeds=end_token_embeds,
843
+ )
844
+ if sep_token_embeds is not None:
845
+ features = torch.cat([features, sep_token_embeds], dim=0)
846
+ outputs.append(features)
847
+ return torch.cat(outputs, dim=0)
848
+
849
+ def forward(self, videos: List[torch.Tensor], config: Dict[str, Any], mm_info: dict) -> List[torch.Tensor]:
850
+ cache_feas = []
851
+ cache_feas_index = []
852
+ for _idx in range(len(videos)):
853
+ if type(videos[_idx]) == CacheFeatures:
854
+ cache_feas.append(videos[_idx])
855
+ cache_feas_index.append(_idx)
856
+
857
+ num_frames = [
858
+ _.value['features'].shape[0] if isinstance(_, CacheFeatures) else _.shape[0]
859
+ for _ in videos
860
+ ]
861
+
862
+ features = self.parent.encode_video(videos, mm_info=mm_info, num_frames=num_frames)
863
+ features = torch.split(features, num_frames)
864
+
865
+ process_features = partial(
866
+ self._process_features,
867
+ start_token_embeds=self.embed_tokens(self.start_tokens),
868
+ end_token_embeds=self.embed_tokens(self.end_tokens),
869
+ sep_token_embeds=self.embed_tokens(self.sep_tokens),
870
+ )
871
+
872
+
873
+ if self.embed_time:
874
+ bs = len(mm_info["video_info"])
875
+ vid_idx = 0
876
+ device = features[0].device
877
+
878
+ if self.time_embed_type == "learned_embed":
879
+ # Learned embedding, we need to first collect all times from all videos and only do time embedding once
880
+ times_list = []
881
+ for i in range(bs):
882
+ _video_info = mm_info["video_info"][i]
883
+ if _video_info is not None:
884
+ for j in range(len(_video_info)):
885
+ _feature = features[vid_idx]
886
+ if _video_info[j] == "dummy":
887
+ times = torch.zeros(_feature.shape[0], device=device, dtype=_feature.dtype)
888
+ else:
889
+ times = _video_info[j]["video_frame_times"]
890
+ times = torch.tensor(times).to(device)
891
+
892
+ for pool_size in self.pool_sizes:
893
+ temporal_pool_size = pool_size[0]
894
+ if temporal_pool_size != 1:
895
+ if len(times) % temporal_pool_size != 0:
896
+ # pad
897
+ print(f"Warning: length of times: {len(times)} is not a multiple of temporal_pool_size: {temporal_pool_size}")
898
+ remainder = len(times) % temporal_pool_size
899
+ pad_len = temporal_pool_size - remainder
900
+ last_window_mean_times = times[-remainder:].mean()
901
+ times = torch.cat([times, torch.ones(pad_len).to(times.device) * last_window_mean_times])
902
+ times = pool(times, temporal_pool_size, 0)
903
+
904
+ times_list.append(times)
905
+ vid_idx += 1
906
+
907
+ # pad the times to the same length
908
+ ori_lens = [len(times) for times in times_list]
909
+ max_len = max(ori_lens)
910
+ for i in range(len(times_list)):
911
+ if len(times_list[i]) < max_len:
912
+ times_list[i] = torch.cat([times_list[i], torch.zeros(max_len - len(times_list[i])).to(times_list[i].device)])
913
+ times = torch.stack(times_list, dim=0)
914
+ time_embeds = self.time_embed(times, dtype=features[0].dtype)
915
+
916
+ # remove the padding for each embed
917
+ new_time_embeds = []
918
+ for i in range(len(times_list)):
919
+ new_time_embeds.append(time_embeds[i][:ori_lens[i]].unsqueeze(1).expand(-1, features[0].shape[1], -1))
920
+
921
+ # Touch every time embedding with a zero-weighted term so all of them stay in the autograd graph (e.g. avoids unused-parameter issues in distributed training).
922
+ new_time_embeds[0] = new_time_embeds[0] + 0 * time_embeds.mean()
923
+
924
+ new_features = []
925
+ fea_count = len(features)
926
+ vid_idx = 0
927
+ for i in range(bs):
928
+ _video_info = mm_info["video_info"][i]
929
+ if _video_info is not None:
930
+ for j in range(len(_video_info)):
931
+ _feature = features[vid_idx]
932
+ if _video_info[j] == "dummy":
933
+ times = torch.zeros(_feature.shape[0], device=device, dtype=_feature.dtype)
934
+ else:
935
+ times = _video_info[j]["video_frame_times"]
936
+ times = torch.tensor(times).to(device)
937
+ if self.time_embed_type == "learned_embed":
938
+ _feature = process_features(_feature, time_embed=new_time_embeds[vid_idx])
939
+ else:
940
+ _feature = process_features(_feature, times=times)
941
+ new_features.append(_feature)
942
+ vid_idx += 1
943
+
944
+ assert vid_idx == fea_count, "vid_idx: {}, fea_count: {}".format(vid_idx, fea_count)
945
+ features = new_features
946
+ else:
947
+ features = [process_features(f) for f in features]
948
+ return features
949
+
950
+ def _encode_video_frames(self, video_frames: torch.Tensor) -> torch.Tensor:
951
+ """Helper method to encode video frames when cached features are not available."""
952
+ features = self.parent.encode_images(video_frames.unsqueeze(0))
953
+ return features.squeeze(0)
954
+
955
+
mm_projector/config.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "_name_or_path": "outputs/model/mm_projector",
3
+ "architectures": [
4
+ "MultimodalProjector"
5
+ ],
6
+ "mm_projector_type": "mlp_downsample",
7
+ "model_type": "v2l_projector",
8
+ "torch_dtype": "bfloat16",
9
+ "transformers_version": "4.46.0"
10
+ }
mm_projector/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e80a60453d3c104445816f05e1926f676e89a2f99ceb950e4876e99a1e391913
3
+ size 124850712
mm_utils.py ADDED
@@ -0,0 +1,567 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # Note: dynamic_preprocess and find_closest_aspect_ratio are referenced from https://github.com/OpenGVLab/InternVL
18
+
19
+ import base64
20
+ import os
21
+ import tempfile
22
+ from io import BytesIO
23
+
24
+ import numpy as np
25
+ import torch
26
+ from PIL import Image
27
+ from transformers import StoppingCriteria
28
+
29
+ from .constants import DEFAULT_IMAGE_TOKEN
30
+
31
+
32
+ def get_frame_from_vcap(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
33
+ """Extract frames from video capture object."""
34
+ import cv2
35
+
36
+ if fps is None or frame_count is None:
37
+ # Recompute if either fps or frame_count is None
38
+ fps = vidcap.get(cv2.CAP_PROP_FPS)
39
+ frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
40
+ if fps == 0 or frame_count == 0:
41
+ print(f"Video file not found. return empty images. {video_file_name}")
42
+ return [
43
+ Image.new("RGB", (720, 720)),
44
+ ] * num_frames, 0
45
+
46
+ duration = frame_count / fps
47
+ frame_interval = frame_count // num_frames
48
+ if frame_interval == 0 and frame_count <= 1:
49
+ print(f"frame_interval is equal to 0. return empty image. {video_file_name}")
50
+ return [
51
+ Image.new("RGB", (720, 720)),
52
+ ] * num_frames, 0
53
+ # print("duration:", duration, "frames:", frame_count, "intervals:", frame_interval)
54
+
55
+ images = []
56
+ count = 0
57
+ success = True
58
+ frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
59
+ while success:
60
+ # print("frame_count:", frame_count, "count:", count, "num_frames:", num_frames, "frame_interval:", frame_interval)
61
+ if frame_count >= num_frames:
62
+ success, frame = vidcap.read()
63
+ if count in frame_indices:
64
+ try:
65
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
66
+ im_pil = Image.fromarray(img)
67
+ images.append(im_pil)
68
+ except BaseException:
69
+ continue
70
+ if len(images) >= num_frames:
71
+ return images, num_frames
72
+ count += 1
73
+ else:
74
+ # Left padding frames if the video is not long enough
75
+ success, frame = vidcap.read()
76
+ if success:
77
+ try:
78
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
79
+ im_pil = Image.fromarray(img)
80
+ images.append(im_pil)
81
+ except BaseException:
82
+ continue
83
+ count += 1
84
+ else:
85
+ break
86
+ if len(images) == 0:
87
+ raise ValueError("Did not find enough frames in the video. return empty image.")
88
+
89
+ return images, len(images)
90
+
91
+
92
+ def get_frame_from_vcap_with_fps(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
93
+ """
94
+ Extract frames from video capture with FPS consideration.
95
+
96
+ Args:
97
+ vidcap: OpenCV video capture object
98
+ num_frames: Maximum number of frames the model can support
99
+ max_fps: Maximum FPS the model can support
100
+ fps: FPS of the input video
101
+ frame_count: Number of frames in the input video
102
+ video_file_name: Name of the video file for logging
103
+ """
104
+ import random
105
+ import cv2
106
+
107
+ if fps is None or frame_count is None:
108
+ # Recompute if either fps or frame_count is None
109
+ fps = vidcap.get(cv2.CAP_PROP_FPS)
110
+ frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
111
+
112
+ if fps == 0 or frame_count == 0:
113
+ print(f"Video file not found. return empty images. {video_file_name}")
114
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
115
+ return [
116
+ Image.new("RGB", (720, 720)),
117
+ ] * empty_video_frames, 0
118
+
119
+ duration = frame_count / fps
120
+ # print("duration:", duration, "frames:", frame_count, "fps:", fps, "num_frames:", num_frames, "max_fps:", max_fps)
121
+ # If the video is too long (longer than max_fps and num_frames can support),
122
+ # we will use lower fps to sample frames.
123
+ if duration >= num_frames / max_fps:
124
+ frame_interval = frame_count // num_frames
125
+
126
+ # If the video is too short, we will skip the video if there is only one frame.
127
+ if frame_interval == 0 and frame_count <= 1:
128
+ print(f"frame_interval is equal to 0. return empty image. {video_file_name}")
129
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
130
+ return [
131
+ Image.new("RGB", (720, 720)),
132
+ ] * empty_video_frames, 0
133
+
134
+ images = []
135
+ count = 0
136
+ success = True
137
+ frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
138
+
139
+ while success:
140
+ if frame_count >= num_frames:
141
+ if count in frame_indices:
142
+ success, frame = vidcap.read()
143
+ try:
144
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
145
+ im_pil = Image.fromarray(img)
146
+ images.append(im_pil)
147
+ except:
148
+ continue
149
+ if len(images) >= num_frames:
150
+ return images, num_frames
151
+ else:
152
+ success = vidcap.grab()
153
+ count += 1
154
+ else:
155
+ # Left padding frames if the video is not long enough
156
+ success, frame = vidcap.read()
157
+ if success:
158
+ try:
159
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
160
+ im_pil = Image.fromarray(img)
161
+ images.append(im_pil)
162
+ except:
163
+ continue
164
+ count += 1
165
+ else:
166
+ break
167
+ else:
168
+ frames_required = int(duration * max_fps)
169
+ frame_indices = np.linspace(0, frame_count - 1, frames_required, dtype=int)
170
+ if frames_required == 0:
171
+ print(f"frames_required is fewer than 2. Duration {duration}, return empty image.")
172
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
173
+ return [
174
+ Image.new("RGB", (720, 720)),
175
+ ] * empty_video_frames, 0
176
+ elif frames_required == 1:
177
+ frame_indices = np.linspace(0, frame_count - 1, 2, dtype=int)
178
+ images = []
179
+ count = 0
180
+ looked = 0
181
+ success = True
182
+
183
+ while success:
184
+ success, frame = vidcap.read()
185
+ if success and (looked in frame_indices):
186
+ try:
187
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
188
+ im_pil = Image.fromarray(img)
189
+ images.append(im_pil)
190
+ except:
191
+ continue
192
+ count += 1
193
+ looked += 1
194
+
195
+ if len(images) == 0:
196
+ empty_video_frames = int(random.uniform(2, 8 * max_fps))
197
+ return [
198
+ Image.new("RGB", (720, 720)),
199
+ ] * empty_video_frames, 0
200
+ else:
201
+ return images, len(images)
202
+
203
+
204
+ def opencv_extract_frames(vpath_or_bytesio, frames=6, max_fps=0.0, fps=None, frame_count=None):
205
+ """
206
+ Extract frames from a video using OpenCV.
207
+
208
+ Args:
209
+ vpath_or_bytesio (str or BytesIO): Path to the video file or BytesIO object containing the video.
210
+ frames (int): Number of frames to extract from the video.
211
+ max_fps (float): Maximum sampling FPS. If 0.0, frames are sampled at equal intervals over the whole video.
+ fps (float): Frames per second of the video, if already known; recomputed from the capture otherwise.
+ frame_count (int): Total number of frames in the video, if already known; recomputed from the capture otherwise.
212
+
213
+ Returns:
214
+ list: List of PIL Images extracted from the video.
215
+
216
+ Raises:
217
+ NotImplementedError: If the type of `vpath_or_bytesio` is not supported.
218
+ """
219
+ import cv2
220
+
221
+ if isinstance(vpath_or_bytesio, str):
222
+ vidcap = cv2.VideoCapture(vpath_or_bytesio)
223
+ if max_fps > 0.0:
224
+ return get_frame_from_vcap_with_fps(
225
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio
226
+ )
227
+ return get_frame_from_vcap(
228
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio
229
+ )
230
+ elif isinstance(vpath_or_bytesio, (BytesIO,)):
231
+ # assuming mp4
232
+ with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
233
+ temp_video.write(vpath_or_bytesio.read())
234
+ temp_video_name = temp_video.name
235
+ vidcap = cv2.VideoCapture(temp_video_name)
236
+ if max_fps > 0.0:
237
+ return get_frame_from_vcap_with_fps(
238
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name
239
+ )
240
+ return get_frame_from_vcap(
241
+ vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name
242
+ )
243
+ else:
244
+ raise NotImplementedError(type(vpath_or_bytesio))
245
+
246
+
247
+ def load_image_from_base64(image):
248
+ """Load PIL Image from base64 encoded string."""
249
+ return Image.open(BytesIO(base64.b64decode(image)))
250
+
251
+
252
+ def expand2square(pil_img, background_color):
253
+ """
254
+ Expand the given PIL image to a square shape by adding padding.
255
+
256
+ Parameters:
257
+ - pil_img: The PIL image to be expanded.
258
+ - background_color: The color of the padding to be added.
259
+
260
+ Returns:
261
+ - The expanded PIL image.
262
+
263
+ If the image is already square, it is returned as is.
264
+ If the image is wider than it is tall, padding is added to the top and bottom.
265
+ If the image is taller than it is wide, padding is added to the left and right.
266
+ """
267
+ width, height = pil_img.size
268
+ if pil_img.mode == "L":
269
+ background_color = background_color[0]
270
+ if width == height:
271
+ return pil_img
272
+ elif width > height:
273
+ result = Image.new(pil_img.mode, (width, width), background_color)
274
+ result.paste(pil_img, (0, (width - height) // 2))
275
+ return result
276
+ else:
277
+ result = Image.new(pil_img.mode, (height, height), background_color)
278
+ result.paste(pil_img, ((height - width) // 2, 0))
279
+ return result
280
+
281
+
282
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
283
+ """Find the closest aspect ratio from target ratios."""
284
+ best_ratio_diff = float("inf")
285
+ best_ratio = (1, 1)
286
+ area = width * height
287
+ for ratio in target_ratios:
288
+ target_aspect_ratio = ratio[0] / ratio[1]
289
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
290
+ if ratio_diff < best_ratio_diff:
291
+ best_ratio_diff = ratio_diff
292
+ best_ratio = ratio
293
+ elif ratio_diff == best_ratio_diff:
294
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
295
+ best_ratio = ratio
296
+ return best_ratio
297
+
298
+
299
+ def dynamic_preprocess(image, min_num=1, max_num=12, image_size=384, use_thumbnail=True):
300
+ """Dynamically preprocess image into multiple tiles based on aspect ratio."""
301
+ orig_width, orig_height = image.size
302
+ aspect_ratio = orig_width / orig_height
303
+
304
+ # Calculate the existing image aspect ratio
305
+ target_ratios = {
306
+ (i, j)
307
+ for n in range(min_num, max_num + 1)
308
+ for i in range(1, n + 1)
309
+ for j in range(1, n + 1)
310
+ if i * j <= max_num and i * j >= min_num
311
+ }
312
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
313
+
314
+ # find the closest aspect ratio to the target
315
+ target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
316
+
317
+ # calculate the target width and height
318
+ target_width = image_size * target_aspect_ratio[0]
319
+ target_height = image_size * target_aspect_ratio[1]
320
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
321
+
322
+ # resize the image
323
+ resized_img = image.resize((target_width, target_height))
324
+ processed_images = []
325
+ for i in range(blocks):
326
+ box = (
327
+ (i % (target_width // image_size)) * image_size,
328
+ (i // (target_width // image_size)) * image_size,
329
+ ((i % (target_width // image_size)) + 1) * image_size,
330
+ ((i // (target_width // image_size)) + 1) * image_size,
331
+ )
332
+ # split the image
333
+ split_img = resized_img.crop(box)
334
+ processed_images.append(split_img)
335
+ assert len(processed_images) == blocks
336
+ if use_thumbnail and len(processed_images) != 1:
337
+ thumbnail_img = image.resize((image_size, image_size))
338
+ processed_images.append(thumbnail_img)
339
+ return processed_images
340
+
341
+
342
+ def dynamic_s2_preprocess(image, s2_scales=[384, 768, 1152], max_num=12, image_size=384):
343
+ """Dynamically preprocess image with multi-scale S2 strategy."""
344
+ orig_width, orig_height = image.size
345
+ aspect_ratio = orig_width / orig_height
346
+ min_num = (s2_scales[-1] // s2_scales[0]) ** 2
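+ # e.g. with the default scales [384, 768, 1152], min_num = (1152 // 384) ** 2 = 9.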
347
+
348
+ processed_images = []
349
+
350
+ # Add tiles for all but the last scale using fixed square ratio
351
+
352
+ for scale in s2_scales[:-1]:
353
+ target_width = image_size * (scale // s2_scales[0])
354
+ target_height = image_size * (scale // s2_scales[0])
355
+ blocks = (scale // s2_scales[0]) ** 2
356
+
357
+ # resize the image
358
+ resized_img = image.resize((target_width, target_height))
359
+ for i in range(blocks):
360
+ box = (
361
+ (i % (target_width // image_size)) * image_size,
362
+ (i // (target_width // image_size)) * image_size,
363
+ ((i % (target_width // image_size)) + 1) * image_size,
364
+ ((i // (target_width // image_size)) + 1) * image_size,
365
+ )
366
+ # split the image
367
+ split_img = resized_img.crop(box)
368
+ processed_images.append(split_img)
369
+
370
+ # Add tiles for the last scale using dynamic aspect ratio
371
+
372
+ # Calculate the existing image aspect ratio
373
+ target_ratios = {
374
+ (i, j)
375
+ for n in range(min_num, max_num + 1)
376
+ for i in range(1, n + 1)
377
+ for j in range(1, n + 1)
378
+ if i * j <= max_num and i * j >= min_num
379
+ }
380
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
381
+
382
+ # find the closest aspect ratio to the target
383
+ target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
384
+
385
+ # calculate the target width and height
386
+ target_width = image_size * target_aspect_ratio[0]
387
+ target_height = image_size * target_aspect_ratio[1]
388
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
389
+
390
+ # resize the image
391
+ resized_img = image.resize((target_width, target_height))
392
+ for i in range(blocks):
393
+ box = (
394
+ (i % (target_width // image_size)) * image_size,
395
+ (i // (target_width // image_size)) * image_size,
396
+ ((i % (target_width // image_size)) + 1) * image_size,
397
+ ((i // (target_width // image_size)) + 1) * image_size,
398
+ )
399
+ # split the image
400
+ split_img = resized_img.crop(box)
401
+ processed_images.append(split_img)
402
+
403
+ return processed_images, (target_aspect_ratio[1], target_aspect_ratio[0])
404
+
405
+
406
+ def dynamic_process_images_and_prompt(images, prompt, data_args, image_folder=None, max_tiles=None):
407
+ prompt = prompt.split(DEFAULT_IMAGE_TOKEN)
408
+ idx = 0
409
+ all_images = []
410
+ for img in images:
411
+ processed_images = process_image(img, data_args, image_folder, enable_dynamic_res=True, max_tiles=max_tiles)
412
+ all_images.append(processed_images)
413
+ prompt.insert(idx + 1, f"{DEFAULT_IMAGE_TOKEN}\n" * processed_images.shape[0])
414
+ idx += 2
415
+ prompt = "".join(prompt)
416
+ if all_images:
417
+ all_images = torch.cat(all_images)
418
+ else:
419
+ all_images = None
420
+ prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, "")
421
+ return all_images, prompt
422
+
423
+
424
+ def dynamic_s2_process_images_and_prompt(images, prompt, data_args, image_folder=None):
425
+ idx = 0
426
+ all_images = []
427
+ all_block_size = []
428
+ for img in images:
429
+ processed_images, block_size = process_image(img, data_args, image_folder, enable_dynamic_s2=True)
430
+ all_images.append(processed_images)
431
+ all_block_size.append(block_size)
432
+ idx += 2
433
+ if all_images:
434
+ all_images = torch.cat(all_images)
435
+ else:
436
+ all_images = None
437
+ return all_images, all_block_size
438
+
439
+
440
+ def process_image(
441
+ image_file, data_args, image_folder, enable_dynamic_res=False, enable_dynamic_s2=False, max_tiles=None
442
+ ):
443
+ processor = data_args.image_processor
444
+ if isinstance(image_file, str):
445
+ if image_folder is not None:
446
+ image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
447
+ else:
448
+ image = Image.open(image_file).convert("RGB")
449
+ else:
450
+ # image is stored in bytearray
451
+ image = image_file
452
+ image = image.convert("RGB")
453
+ if hasattr(data_args.image_processor, "crop_size"):
454
+ # CLIP vision tower
455
+ crop_size = data_args.image_processor.crop_size
456
+ else:
457
+ # SIGLIP vision tower
458
+ assert hasattr(data_args.image_processor, "size")
459
+ crop_size = data_args.image_processor.size
460
+ if "dynamic_s2" in data_args.image_aspect_ratio and enable_dynamic_s2:
461
+ assert crop_size["height"] == crop_size["width"]
462
+ images, block_size = dynamic_s2_preprocess(
463
+ image, s2_scales=data_args.s2_scales, max_num=data_args.max_tiles, image_size=crop_size["height"]
464
+ )
465
+ images = [processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images]
466
+ return torch.stack(images), block_size
467
+ if "dynamic" in data_args.image_aspect_ratio and enable_dynamic_res:
468
+ assert crop_size["height"] == crop_size["width"]
469
+ if max_tiles is not None:
470
+ max_num = max_tiles
471
+ else:
472
+ max_num = data_args.max_tiles
473
+ images = dynamic_preprocess(image, min_num=data_args.min_tiles, max_num=max_num, image_size=crop_size["height"])
474
+ images = [processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images]
475
+ return torch.stack(images)
476
+
477
+ if data_args.image_aspect_ratio == "resize":
478
+ image = image.resize((crop_size["width"], crop_size["height"]))
479
+ elif data_args.image_aspect_ratio == "pad":
480
+ image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
481
+ image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
482
+ else:
483
+ # Using default behavior of the vision encoder
484
+ # For CLIP, default is central crop
485
+ # For Radio, default is central crop
486
+ # For Siglip, default is resize
487
+ # For InternVIT, default is resize
488
+ image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
489
+ return image
490
+
491
+
492
+ def process_images(images, image_processor, model_cfg, enable_dynamic_res=False, max_tiles=None):
493
+ """Process a batch of images using the model's image processor."""
494
+ model_cfg.image_processor = image_processor
495
+ new_images = [
496
+ process_image(image, model_cfg, None, enable_dynamic_res=enable_dynamic_res, max_tiles=max_tiles)
497
+ for image in images
498
+ ]
499
+
500
+ if all(x.shape == new_images[0].shape for x in new_images):
501
+ if len(new_images[0].shape) == 4:
502
+ new_images = torch.cat(new_images, dim=0)
503
+ elif len(new_images[0].shape) == 3:
504
+ new_images = torch.stack(new_images, dim=0)
505
+ else:
506
+ raise ValueError(f"new_images rank does not equal to 4, rank: {len(new_images[0].shape)}")
507
+ else:
508
+ raise ValueError("The shape of images in new_images is different!")
509
+ return new_images
510
+
511
+
512
+ def tokenizer_image_token(prompt, tokenizer, return_tensors=None, return_ids=True):
513
+ """Tokenize prompt with media tokens."""
514
+ if return_ids:
515
+ return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]
516
+ else:
517
+ return tokenizer(prompt, return_tensors=return_tensors)
518
+
519
+
520
+ def is_gemma_tokenizer(tokenizer):
521
+ """Check if the tokenizer is a Gemma tokenizer."""
522
+ return "gemma" in tokenizer.__class__.__name__.lower()
523
+
524
+
525
+ def get_model_name_from_path(model_path):
526
+ """Extract model name from file path."""
527
+ model_path = model_path.strip("/")
528
+ model_paths = model_path.split("/")
529
+ if model_paths[-1].startswith("checkpoint-"):
530
+ return model_paths[-2] + "_" + model_paths[-1]
531
+ else:
532
+ return model_paths[-1]
533
+
534
+
535
+ class KeywordsStoppingCriteria(StoppingCriteria):
536
+ """Stopping criteria based on keyword tokens."""
537
+ def __init__(self, keywords, tokenizer, input_ids):
538
+ self.keywords = keywords
539
+ self.keyword_ids = []
540
+ self.max_keyword_len = 0
541
+ for keyword in keywords:
542
+ cur_keyword_ids = tokenizer(keyword).input_ids
543
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
544
+ cur_keyword_ids = cur_keyword_ids[1:]
545
+ if len(cur_keyword_ids) > self.max_keyword_len:
546
+ self.max_keyword_len = len(cur_keyword_ids)
547
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
548
+ self.tokenizer = tokenizer
549
+ self.start_len = input_ids.shape[1]
550
+
551
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
552
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
553
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
554
+ for keyword_id in self.keyword_ids:
555
+ if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
556
+ return True
557
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
558
+ for keyword in self.keywords:
559
+ if keyword in outputs:
560
+ return True
561
+ return False
562
+
563
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
564
+ outputs = []
565
+ for i in range(output_ids.shape[0]):
566
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
567
+ return all(outputs)
model_utils_packing.py ADDED
@@ -0,0 +1,50 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from importlib import import_module
17
+ from typing import Tuple
18
+
19
+ import torch
20
+ import transformers
21
+ from torch import nn
22
+ from torch.nn import functional as F
23
+
24
+ __all__ = ["patch"]
25
+
26
+
27
+ def _get_unpad_data(attention_mask: torch.Tensor, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, int]:
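+ # Drop-in replacement for the flash-attention unpad helper: use sequence lengths registered via set_seqlens_in_batch (sequence packing) when available, otherwise fall back to the attention mask.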
28
+ if hasattr(_get_unpad_data, "seqlens_in_batch"):
29
+ seqlens_in_batch = _get_unpad_data.seqlens_in_batch
30
+ else:
31
+ seqlens_in_batch = torch.sum(attention_mask, dim=1)
32
+
33
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
34
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
35
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
36
+ return indices, cu_seqlens, max_seqlen_in_batch
37
+
38
+
39
+ def set_seqlens_in_batch(seqlens_in_batch: torch.Tensor) -> None:
40
+ _get_unpad_data.seqlens_in_batch = seqlens_in_batch
41
+
42
+
43
+ def patch(model: nn.Module) -> None:
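+ # Redirect the library's _get_unpad_data to the packing-aware version above; the patch target differs depending on the installed transformers version.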
44
+ if transformers.__version__ < "4.43.0":
45
+ m = import_module(model.__module__)
46
+ if not hasattr(m, "_get_unpad_data"):
47
+ raise ValueError(f"Module {m} does not have function '_get_unpad_data' for packing")
48
+ m._get_unpad_data = _get_unpad_data
49
+ else:
50
+ transformers.modeling_flash_attention_utils._get_unpad_data = _get_unpad_data
modeling_vila.py ADDED
@@ -0,0 +1,1834 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import copy
17
+ import json
18
+ import logging
19
+ import numpy as np
20
+ import os
21
+ import os.path
22
+ import os.path as osp
23
+ import shutil
24
+ import warnings
25
+ from abc import ABC
26
+ from collections import OrderedDict, defaultdict, deque
27
+ from copy import deepcopy
28
+ from itertools import chain
29
+ from threading import Thread
30
+ from typing import Any, Dict, List, Optional, Tuple, Union
31
+
32
+ import torch
33
+ import torch.distributed as dist
34
+ import torch.nn as nn
35
+ import torch.nn.functional as F
36
+ import torchvision
37
+ from einops import rearrange
38
+ from PIL import Image
39
+ from transformers import (
40
+ AutoConfig,
41
+ AutoModel,
42
+ AutoProcessor,
43
+ AutoTokenizer,
44
+ GenerationConfig,
45
+ LogitsProcessor,
46
+ PretrainedConfig,
47
+ PreTrainedModel,
48
+ Qwen2Config,
49
+ Qwen2ForCausalLM,
50
+ Qwen2PreTrainedModel,
51
+ TextIteratorStreamer,
52
+ WhisperFeatureExtractor,
53
+ )
54
+ from transformers.modeling_outputs import CausalLMOutputWithPast
55
+ from transformers.modeling_utils import ContextManagers, no_init_weights
56
+
57
+ from .auto_processor import VILAProcessor
58
+ from .base_projector import MultimodalProjector, MultimodalProjectorConfig
59
+ from .sound_base_projector import SoundMultimodalProjector, SoundMultimodalProjectorConfig
60
+ from .speech_base_projector import SpeechMultimodalProjector, SpeechMultimodalProjectorConfig
61
+
62
+ from .builder import build_llm_and_tokenizer
63
+ from .configuration_vila import VILAConfig
64
+ from .constants import *
65
+ from .conversation import SeparatorStyle, default_conversation
66
+ from .distributed import all_gather as vila_all_gather
67
+ from .media import extract_media
68
+ from .media_encoder import BasicImageEncoder, BasicVideoEncoder, TSPVideoEncoder, BasicSoundEncoder, CacheFeatures
69
+ from .mm_utils import process_image, process_images
70
+ from .model_utils_packing import set_seqlens_in_batch
71
+ from .siglip_encoder import SiglipVisionTower, SiglipVisionTowerDynamicS2, SiglipVisionTowerS2
72
+ from .tokenizer_utils import tokenize_conversation
73
+ from .utils import get_model_config, load_tokenizer_then_handle_media_tokens_and_chat_template
74
+
75
+ from .constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, NUM_EXTRA_TOKENS_VILA, NUM_EXTRA_TOKENS_XVILA
76
+ from .qwen_audio_encoder import Qwen2AudioTower
77
+ import whisper
78
+
79
+ from .audio_encoder import AudioTower
80
+
81
+
82
+ def build_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
83
+ """Build multimodal projector from path or configuration."""
84
+ if model_type_or_path is None:
85
+ return None
86
+ if config.resume_path:
87
+ assert os.path.exists(model_type_or_path), f"Resume mm projector path {model_type_or_path} does not exist!"
88
+ return MultimodalProjector.from_pretrained(model_type_or_path, config)
89
+ else:
90
+ mm_projector_cfg = MultimodalProjectorConfig(model_type_or_path)
91
+ mm_projector = MultimodalProjector(mm_projector_cfg, config)
92
+ return mm_projector
93
+
94
+ def build_speech_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
95
+ """Build speech multimodal projector from path or configuration."""
96
+ if model_type_or_path is None:
97
+ return None
98
+ if config.resume_path:
99
+ assert os.path.exists(model_type_or_path), f"Resume speech mm projector path {model_type_or_path} does not exist!"
100
+ _model = SpeechMultimodalProjector.from_pretrained(
101
+ model_type_or_path, config, torch_dtype=eval(config.model_dtype)
102
+ )
103
+ return _model
104
+ else:
105
+ speech_mm_projector_cfg = SpeechMultimodalProjectorConfig(model_type_or_path)
106
+ speech_mm_projector = SpeechMultimodalProjector(speech_mm_projector_cfg, config).to(eval(config.model_dtype))
107
+ return speech_mm_projector
108
+
109
+
110
+ def build_sound_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
111
+ """Build sound multimodal projector from path or configuration."""
112
+ if model_type_or_path is None:
113
+ return None
114
+
115
+ if type(config.model_dtype) == str:
116
+ model_dtype = eval(config.model_dtype)
117
+ else:
118
+ model_dtype = config.model_dtype
119
+ if config.resume_path:
120
+ assert os.path.exists(model_type_or_path), f"Resume sound mm projector path {model_type_or_path} does not exist!"
121
+ _model = SoundMultimodalProjector.from_pretrained(
122
+ model_type_or_path, config, torch_dtype=model_dtype
123
+ )
124
+ return _model
125
+ else:
126
+ sound_mm_projector_cfg = SoundMultimodalProjectorConfig(model_type_or_path)
127
+ sound_mm_projector = SoundMultimodalProjector(sound_mm_projector_cfg, config).to(model_dtype)
128
+ return sound_mm_projector
129
+
130
+
131
+ def check_dot_in_model_path(model_path: str):
132
+ """Check if the model path contains a dot, which may affect model loading."""
133
+ if osp.isdir(model_path):
134
+ if "." in osp.abspath(model_path):
135
+ return True
136
+ else:
137
+ if "." in model_path:
138
+ return True
139
+ return False
140
+
141
+
142
+ def get_vila_version(model_path: str) -> str:
143
+ VERSIONS = ["vila1.5", "vila-u", "longvila", "nvila", "vila-m3"]
144
+ for version in VERSIONS:
145
+ if version in model_path.lower():
146
+ return version
147
+ return None
148
+
149
+
150
+ def generate_jinja_template(conv_mode: str) -> str:
151
+ if conv_mode == "vicuna_v1":
152
+ return """{% set system_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. " %}
153
+ {% set roles = ["user", "assistant"] %}
154
+ {% set sep = " " %}
155
+
156
+ {{ system_prompt }}
157
+
158
+ {% for message in messages %}
159
+ {% if message['role'] == roles[0] %}
160
+ {{ "USER: " }}{{ sep }}{{ message['content'] }}{{ sep }}
161
+ {% else %}
162
+ {{ "ASSISTANT: " }}{{ sep }}{{ message['content'] }}{{ sep }}
163
+ {% endif %}
164
+ {% endfor %}
165
+ {% if messages[-1]['role'] == 'user' %}
166
+ {{ "ASSISTANT:" }}
167
+ {% endif %}
168
+ """
169
+ elif conv_mode == "llama_3":
170
+ return """{% set system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|>" %}
171
+ {% set roles = ["<|start_header_id|>user<|end_header_id|>\\n\\n", "<|start_header_id|>assistant<|end_header_id|>\\n\\n"]%}
172
+ {% set sep = "<|eot_id|>" %}
173
+
174
+ {{ system_prompt }}
175
+ {% for message in messages %}
176
+ {% if message['role'] == 'user' %}
177
+ {{ roles[0] }}{{ message['content'] }}{{ sep }}
178
+ {% else %}
179
+ {{ roles[1] }}{{ message['content'] }}{{ sep }}
180
+ {% endif %}
181
+ {% endfor %}
182
+ {% if messages[-1]['role'] == 'user' %}
183
+ {{ roles[1] }}
184
+ {% endif %}
185
+ """
186
+ elif conv_mode == "hermes_2":
187
+ return """{% set system_prompt = "<|im_start|>system\nAnswer the questions." %}
188
+ {% set roles = ["<|im_start|>user\n", "<|im_start|>assistant\n"] %}
189
+ {% set sep = "<|im_end|>" %}
190
+
191
+ {{ system_prompt }}{{ sep }}
192
+
193
+ {% for message in messages %}
194
+ {% if message['role'] == 'user' %}
195
+ {{ roles[0] }}{{ message['content'] }}{{ sep }}
196
+ {% else %}
197
+ {{ roles[1] }}{{ message['content'] }}{{ sep }}
198
+ {% endif %}
199
+ {% endfor %}"""
200
+ else:
201
+ raise NotImplementedError(f"Jinja template generation is not implemented for {conv_mode}.")
202
+
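+ # Editor's sketch (not part of the original commit): one way to exercise the
+ # template returned above directly with jinja2; the conv_mode and toy message
+ # are placeholders.
+ def _example_render_chat_template() -> str:
+     import jinja2
+     template = jinja2.Template(generate_jinja_template("hermes_2"))
+     return template.render(messages=[{"role": "user", "content": "Describe the sound."}])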
203
+
204
+ def build_vision_tower(model_name_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
205
+ """Build vision tower from path or configuration."""
206
+ # Skip vision tower instantiation if path is None
207
+ if model_name_or_path is None:
208
+ return None
209
+
210
+ vision_tower_arch = None
211
+ if config.resume_path and "radio" not in model_name_or_path:
212
+ assert os.path.exists(model_name_or_path), f"Resume vision tower path {model_name_or_path} does not exist!"
213
+ vision_tower_cfg = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
214
+ vision_tower_arch = vision_tower_cfg.architectures[0].lower()
215
+ vision_tower_name = vision_tower_arch if vision_tower_arch is not None else model_name_or_path
216
+
217
+ use_s2 = getattr(config, "s2", False)
218
+ use_dynamic_s2 = getattr(config, "dynamic_s2", False)
219
+
220
+ if "siglip" in vision_tower_name:
221
+ if use_dynamic_s2:
222
+ vision_tower = SiglipVisionTowerDynamicS2(model_name_or_path, config)
223
+ elif use_s2:
224
+ vision_tower = SiglipVisionTowerS2(model_name_or_path, config)
225
+ else:
226
+ vision_tower = SiglipVisionTower(model_name_or_path, config)
227
+ else:
228
+ raise NotImplementedError(f"Unknown vision tower: {model_name_or_path}")
229
+
230
+ config.mm_hidden_size = (
231
+ vision_tower.config.hidden_size if not (use_s2 or use_dynamic_s2) else vision_tower.hidden_size
232
+ )
233
+ return vision_tower
234
+
235
+
236
+ def build_audio_tower(model_name_or_path: str, config: PretrainedConfig, encoder_type: str) -> PreTrainedModel:
237
+ """Build audio tower for sound or speech processing."""
238
+ assert encoder_type in ["sound", "speech"]
239
+
240
+ # Skip tower instantiation if path is None
241
+ if model_name_or_path is None:
242
+ return None
243
+
244
+ model_type = "af3"
245
+
246
+ if model_type == "af3":
247
+ model = Qwen2AudioTower(model_name_or_path, config)
248
+ output_dim = 1280
249
+ else:
250
+ raise NotImplementedError(f"Not implemented for this encoder: {model_name_or_path}")
251
+
252
+ if encoder_type == "sound":
253
+ config.sound_hidden_size = output_dim
254
+ elif encoder_type == "speech":
255
+ config.speech_hidden_size = output_dim
256
+ else:
257
+ raise NotImplementedError(f"Not implemented for this encoder: {model_name_or_path}")
258
+
259
+ return model
260
+
261
+
262
+ class VILAPretrainedModel(PreTrainedModel):
263
+ config_class = VILAConfig
264
+ main_input_name = "input_embeds"
265
+ supports_gradient_checkpointing = True
266
+ _supports_flash_attn_2 = True
267
+ _no_split_modules = ["Qwen2DecoderLayer", "SiglipEncoderLayer"]
268
+
269
+ def __init__(self, config: VILAConfig, *args, **kwargs):
270
+ super().__init__(config)
271
+ self.config = config
272
+ cfgs = get_model_config(config)
273
+
274
+ if len(cfgs) == 7:
275
+ (
276
+ llm_cfg,
277
+ vision_tower_cfg,
278
+ speech_tower_cfg,
279
+ sound_tower_cfg,
280
+ mm_projector_cfg,
281
+ speech_mm_projector_cfg,
282
+ sound_mm_projector_cfg,
283
+ ) = cfgs
284
+ else:
285
+ raise ValueError(
286
+ "`llm_cfg` `mm_projector_cfg` `speech_mm_projector_cfg` `sound_mm_projector_cfg` `vision_tower_cfg` `speech_tower_cfg` `sound_tower_cfg` not found in the config."
287
+ )
288
+
289
+ # default to device_map="auto" when loading
290
+ device_map = kwargs.get("device_map", "auto")
291
+ self.mm_projector = build_mm_projector(mm_projector_cfg, config)
292
+ self.vision_tower = build_vision_tower(vision_tower_cfg, config)
293
+
294
+ if speech_tower_cfg:
295
+ self.speech_tower = build_audio_tower(speech_tower_cfg, config, encoder_type="speech")
296
+ self.speech_mm_projector = build_speech_mm_projector(speech_mm_projector_cfg, config)
297
+ if sound_tower_cfg:
298
+ self.sound_tower = build_audio_tower(sound_tower_cfg, config, encoder_type="sound")
299
+ self.sound_mm_projector = build_sound_mm_projector(sound_mm_projector_cfg, config)
300
+
301
+
302
+ if device_map in ["auto", "cuda"]:
303
+ self.mm_projector = self.mm_projector.cuda()
304
+ self.vision_tower = self.vision_tower.cuda()
305
+ self.speech_tower = self.speech_tower.cuda() if hasattr(self, "speech_tower") else None
306
+ self.sound_tower = self.sound_tower.cuda() if hasattr(self, "sound_tower") else None
307
+ self.speech_mm_projector = self.speech_mm_projector.cuda() if hasattr(self, "speech_mm_projector") else None
308
+ self.sound_mm_projector = self.sound_mm_projector.cuda() if hasattr(self, "sound_mm_projector") else None
309
+ # setting device_map to "auto" automatically shards the llm across devices
310
+ self.llm, self.tokenizer = self.init_llm(llm_cfg, config, device_map=device_map)
311
+
312
+ self.llm_model_embed_tokens = self.llm.model.embed_tokens
313
+
314
+ self.tokenizer.padding_side = "left"
315
+
316
+ self.vocab_size = len(self.tokenizer)
317
+ self.update_vocab_size = lambda: setattr(self, "vocab_size", len(self.tokenizer))
318
+
319
+ self.encoders = {}
320
+ for name in ["image", "video", "speech", "sound"]:
321
+ encoder_config = getattr(self.config, f"{name}_encoder")
322
+ if isinstance(encoder_config, str):
323
+ encoder_config = json.loads(encoder_config)
324
+ if encoder_config.get("embed_time", False) == "True":
325
+ if "trope_dim" not in encoder_config and encoder_config.get("time_embed_type", "") in ["pixel", "lang"]:
326
+ encoder_config["trope_dim"] = self.config.hidden_size // 2
327
+ print(f"Warning: trope_dim not found in config, defaulting to hidden_size // 2: {encoder_config['trope_dim']}")
328
+
329
+ encoder_config.pop('_target_')
330
+ if name == "video":
331
+ self.encoders[name] = TSPVideoEncoder(parent=self, **encoder_config)
332
+ elif name == "image":
333
+ self.encoders[name] = BasicImageEncoder(self)
334
+ else:
335
+ self.encoders[name] = BasicSoundEncoder(parent=self, **encoder_config)
336
+
337
+ self.post_config()
338
+ self.is_loaded = True
339
+
340
+ self.llm_only_need_embed = kwargs.get("llm_only_need_embed", False)
341
+ if self.llm_only_need_embed:
342
+ print("We only need the embed_tokens in llm.")
343
+ del self.llm
344
+ self.llm = None
345
+ torch.cuda.empty_cache()
346
+
347
+ assert (
348
+ self.llm is not None
349
+ or self.vision_tower is not None
350
+ or self.speech_tower is not None
351
+ or self.mm_projector is not None
352
+ or self.speech_mm_projector is not None
353
+ ), "At least one of the components must be instantiated."
354
+
355
+
356
+ @classmethod
357
+ def copy_or_symlink_directory(cls, model_path, output_dir, copy=True):
358
+ # Create output directory if it doesn't exist
359
+ os.makedirs(output_dir, exist_ok=True)
360
+ # Copy or symlink all files in model_path to output_dir
361
+ for item in os.listdir(model_path):
362
+ src_path = os.path.join(model_path, item)
363
+ dst_path = os.path.join(output_dir, item)
364
+
365
+ # Remove existing file/directory at destination if it exists
366
+ if os.path.exists(dst_path):
367
+ if os.path.islink(dst_path):
368
+ os.unlink(dst_path)
369
+ elif os.path.isdir(dst_path):
370
+ shutil.rmtree(dst_path)
371
+ else:
372
+ os.remove(dst_path)
373
+
374
+ # Copy or create symlink
375
+ if copy:
376
+ if os.path.isdir(src_path):
377
+ shutil.copytree(src_path, dst_path)
378
+ else:
379
+ shutil.copy2(src_path, dst_path)
380
+ print(f"Copied {src_path} to {dst_path}")
381
+ else:
382
+ os.symlink(src_path, dst_path)
383
+ print(f"Created symlink from {src_path} to {dst_path}")
384
+
385
+ @classmethod
386
+ def copy_remote_py_files(cls, output_dir, copy=True):
387
+ # copy .py and README for next loading
388
+ current_file_path = os.path.abspath(__file__)
389
+ current_folder = os.path.dirname(current_file_path)
390
+ for file_name in os.listdir(current_folder):
391
+ if file_name == "INSTRUCTIONS.md":
392
+ src_fname = os.path.join(current_folder, file_name)
393
+ dst_fname = os.path.join(output_dir, "README.md")
394
+ if os.path.exists(dst_fname):
395
+ old_readme = open(dst_fname).read()
396
+ else:
397
+ old_readme = ""
398
+ with open(src_fname) as src, open(dst_fname, "w") as dst:
399
+ dst.write(src.read())
400
+ dst.write(old_readme)
401
+ print("[HF] README", src_fname, "to", dst_fname)
402
+ if file_name.endswith(".py") or file_name.endswith(".jinja"):
403
+ full_file_name = os.path.join(current_folder, file_name)
404
+ if os.path.isfile(full_file_name):
405
+ if copy:
406
+ shutil.copy(full_file_name, output_dir)
407
+ print("[HF] copying", full_file_name, "to", output_dir)
408
+ else:
409
+ # symlink to ease development
410
+ if os.path.exists(os.path.join(output_dir, file_name)):
411
+ os.remove(os.path.join(output_dir, file_name))
412
+ os.symlink(full_file_name, os.path.join(output_dir, file_name))
413
+ print("[HF] linking", full_file_name, "to", output_dir)
414
+
415
+ def save_pretrained(self, output_dir, state_dict=None, **kwargs):
416
+ if state_dict is None:
417
+ # otherwise fetch from deepspeed
418
+ # state_dict = accelerator.get_state_dict(is_deepspeed_enabled)
419
+ state_dict = self.state_dict()
420
+
421
+ if getattr(self, "tokenizer", None):
422
+ self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))
423
+
424
+ if self.get_llm():
425
+ print(f"saving llm to {osp.join(output_dir, 'llm')}")
426
+ self.llm.config._name_or_path = osp.join(output_dir, "llm")
427
+ llm_state_dict = OrderedDict({k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k})
428
+ self.llm.save_pretrained(os.path.join(output_dir, "llm"), state_dict=llm_state_dict)
429
+ self.config.llm_cfg = self.llm.config
430
+
431
+ if self.get_vision_tower():
432
+ print(f"saving vision_tower to {osp.join(output_dir, 'vision_tower')}")
433
+ self.vision_tower.config._name_or_path = osp.join(output_dir, "vision_tower")
434
+ vision_tower_state_dict = OrderedDict(
435
+ {k.split("vision_tower.vision_tower.")[-1]: v for k, v in state_dict.items() if "vision_tower" in k}
436
+ )
437
+ self.vision_tower.vision_tower.save_pretrained(
438
+ os.path.join(output_dir, "vision_tower"),
439
+ state_dict=vision_tower_state_dict,
440
+ )
441
+ self.vision_tower.image_processor.save_pretrained(os.path.join(output_dir, "vision_tower"))
442
+ self.config.vision_tower_cfg = self.vision_tower.config
443
+ if hasattr(self.config.vision_tower_cfg, "auto_map"):
444
+ if "radio" not in self.get_vision_tower().__class__.__name__.lower():
445
+ delattr(self.config.vision_tower_cfg, "auto_map")
446
+ if self.get_speech_tower():
447
+ print(f"saving speech_tower to {osp.join(output_dir, 'speech_tower')}")
448
+ self.speech_tower.config._name_or_path = osp.join(output_dir, "speech_tower").replace(
449
+ "tmp-checkpoint", "checkpoint"
450
+ )
451
+
452
+ speech_tower_state_dict = OrderedDict(
453
+ {k.split("speech_tower.audio_tower.")[-1]: v for k, v in state_dict.items() if "speech_tower" in k}
454
+ )
455
+
456
+ self.speech_tower.audio_tower.save_pretrained(
457
+ os.path.join(output_dir, "speech_tower"),
458
+ state_dict=speech_tower_state_dict,
459
+ )
460
+ self.config.speech_tower_cfg = self.speech_tower.config
461
+
462
+ if self.get_sound_tower():
463
+ print(f"saving sound_tower to {osp.join(output_dir, 'sound_tower')}")
464
+ self.sound_tower.config._name_or_path = osp.join(output_dir, "sound_tower").replace(
465
+ "tmp-checkpoint", "checkpoint"
466
+ )
467
+
468
+ sound_tower_state_dict = OrderedDict(
469
+ {k.split("sound_tower.audio_tower.")[-1]: v for k, v in state_dict.items() if "sound_tower" in k}
470
+ )
471
+
472
+ self.sound_tower.audio_tower.save_pretrained(
473
+ os.path.join(output_dir, "sound_tower"),
474
+ state_dict=sound_tower_state_dict,
475
+ )
476
+ self.config.sound_tower_cfg = self.sound_tower.config
477
+
478
+
479
+
480
+ if self.get_mm_projector():
481
+ print(f"saving mm_projector to {osp.join(output_dir, 'mm_projector')}")
482
+ self.mm_projector.config._name_or_path = osp.join(output_dir, "mm_projector")
483
+ mm_projector_state_dict = OrderedDict(
484
+ {k.split("mm_projector.")[-1]: v for k, v in state_dict.items() if "mm_projector" in k}
485
+ )
486
+ self.mm_projector.save_pretrained(
487
+ os.path.join(output_dir, "mm_projector"),
488
+ state_dict=mm_projector_state_dict,
489
+ )
490
+ self.config.mm_projector_cfg = self.mm_projector.config
491
+
492
+ if self.get_speech_mm_projector():
493
+ print(f"saving speech_mm_projector to {osp.join(output_dir, 'speech_mm_projector')}")
494
+ self.speech_mm_projector.config._name_or_path = osp.join(output_dir, "speech_mm_projector").replace(
495
+ "tmp-checkpoint", "checkpoint"
496
+ )
497
+ speech_mm_projector_state_dict = OrderedDict(
498
+ {k.split("speech_mm_projector.")[-1]: v for k, v in state_dict.items() if "speech_mm_projector" in k}
499
+ )
500
+ self.speech_mm_projector.save_pretrained(
501
+ os.path.join(output_dir, "speech_mm_projector"),
502
+ state_dict=speech_mm_projector_state_dict,
503
+ )
504
+ self.config.speech_mm_projector_cfg = self.speech_mm_projector.config
505
+
506
+ if self.get_sound_mm_projector():
507
+ print(f"saving sound_mm_projector to {osp.join(output_dir, 'sound_mm_projector')}")
508
+ self.sound_mm_projector.config._name_or_path = osp.join(output_dir, "sound_mm_projector").replace(
509
+ "tmp-checkpoint", "checkpoint"
510
+ )
511
+
512
+ sound_mm_projector_state_dict = OrderedDict(
513
+ {k.split("sound_mm_projector.")[-1]: v for k, v in state_dict.items() if "sound_mm_projector" in k}
514
+ )
515
+ self.sound_mm_projector.save_pretrained(
516
+ os.path.join(output_dir, "sound_mm_projector"),
517
+ state_dict=sound_mm_projector_state_dict,
518
+ )
519
+ self.config.sound_mm_projector_cfg = self.sound_mm_projector.config
520
+
521
+ # update and save top-level config
522
+ self.config._name_or_path = output_dir
523
+ self.config.architectures = [self.__class__.__name__]
524
+ self.config.save_pretrained(output_dir)
525
+
526
+ # copy .py and README for next loading
527
+ self.copy_remote_py_files(output_dir)
528
+
529
+ @classmethod
530
+ def from_pretrained(
531
+ cls,
532
+ pretrained_model_name_or_path: Optional[str] = None,
533
+ *model_args,
534
+ config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
535
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
536
+ ignore_mismatched_sizes: bool = False,
537
+ force_download: bool = False,
538
+ local_files_only: bool = False,
539
+ token: Optional[Union[str, bool]] = None,
540
+ revision: str = "main",
541
+ use_safetensors: Optional[bool] = None,
542
+ weights_only: bool = True,
543
+ **kwargs,
544
+ ):
545
+ # print("DEBUG2", kwargs); input()
546
+ config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
547
+ if kwargs.get("torch_dtype", None) is not None:
548
+ config.torch_dtype = kwargs.get("torch_dtype", None)
549
+ config.model_dtype = kwargs.get("torch_dtype", None)
550
+ if type(kwargs.get("torch_dtype", None)) == str:
551
+ kwargs["torch_dtype"] = eval(kwargs.get("torch_dtype", None))
552
+ else:
553
+ kwargs["torch_dtype"] = kwargs.get("torch_dtype", None)
554
+ return cls._from_config(config, **kwargs)
555
+
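+ # Editor's sketch (not verified against the released checkpoint): a typical
+ # remote-code load path for this class; the repo id and dtype string are
+ # placeholders (string dtypes are eval'ed into torch dtypes above).
+ #   model = AutoModel.from_pretrained("nvidia/omnivinci",
+ #                                     trust_remote_code=True,
+ #                                     torch_dtype="torch.bfloat16")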
556
+ def init_llm(self, llm_config, config, *args, **kwargs):
557
+ """Initialize language model and tokenizer."""
558
+ self.llm, self.tokenizer = build_llm_and_tokenizer(llm_config, config, *args, **kwargs)
559
+
560
+ self.pad_token_list = (
561
+ self.tokenizer.pad_token_id,
562
+ self.tokenizer.eos_token_id,
563
+ self.tokenizer.tokenize("<|endoftext|>")[0], # for Qwen
564
+ )
565
+
566
+ self.vocab_size = len(self.tokenizer)
567
+ self.update_vocab_size = lambda: setattr(self, "vocab_size", len(self.tokenizer))
568
+ # XGrammar tokenizer and grammar compiler
569
+ # lazy init only when specified json output during inference
570
+ self.grammar_compiler = None
571
+ # self.llm.resize_token_embeddings(len(self.tokenizer))
572
+ return self.llm, self.tokenizer
573
+
574
+ def post_config(self):
575
+ self.training = self.llm.training
576
+ if self.training:
577
+ self.train()
578
+ else:
579
+ self.eval()
580
+
581
+ # configuration
582
+ if getattr(self.config, "llm_cfg", None) is None:
583
+ self.config.llm_cfg = self.llm.config
584
+ if getattr(self.config, "vision_tower_cfg", None) is None:
585
+ self.config.vision_tower_cfg = self.vision_tower.config
586
+ if getattr(self.config, "mm_projector_cfg", None) is None:
587
+ self.config.mm_projector_cfg = self.mm_projector.config
588
+ if getattr(self.config, "speech_tower_cfg", None) is None and hasattr(self, "speech_tower"):
589
+ self.config.speech_tower_cfg = self.speech_tower.config
590
+ if getattr(self.config, "sound_tower_cfg", None) is None and hasattr(self, "sound_tower"):
591
+ self.config.sound_tower_cfg = self.sound_tower.config
592
+ if getattr(self.config, "speech_mm_projector_cfg", None) is None and hasattr(self, "speech_mm_projector"):
593
+ self.config.speech_mm_projector_cfg = self.speech_mm_projector.config
594
+ if getattr(self.config, "sound_mm_projector_cfg", None) is None and hasattr(self, "sound_mm_projector"):
595
+ self.config.sound_mm_projector_cfg = self.sound_mm_projector.config
596
+
597
+ def get_llm(self):
598
+ llm = getattr(self, "llm", None)
599
+ if type(llm) is list:
600
+ llm = llm[0]
601
+ return llm
602
+
603
+ def get_lm_head(self):
604
+ lm_head = getattr(self.get_llm(), "lm_head", None)
605
+ return lm_head
606
+
607
+ def get_vision_tower(self):
608
+ vision_tower = getattr(self, "vision_tower", None)
609
+ if type(vision_tower) is list:
610
+ vision_tower = vision_tower[0]
611
+ return vision_tower
612
+
613
+ def get_speech_tower(self):
614
+ speech_tower = getattr(self, "speech_tower", None)
615
+ if type(speech_tower) is list:
616
+ speech_tower = speech_tower[0]
617
+ return speech_tower
618
+
619
+ def get_sound_tower(self):
620
+ sound_tower = getattr(self, "sound_tower", None)
621
+ if type(sound_tower) is list:
622
+ sound_tower = sound_tower[0]
623
+ return sound_tower
624
+
625
+ def get_mm_projector(self):
626
+ mm_projector = getattr(self, "mm_projector", None)
627
+ if type(mm_projector) is list:
628
+ mm_projector = mm_projector[0]
629
+ return mm_projector
630
+
631
+ def get_sound_mm_projector(self):
632
+ sound_mm_projector = getattr(self, "sound_mm_projector", None)
633
+ if type(sound_mm_projector) is list:
634
+ sound_mm_projector = sound_mm_projector[0]
635
+ return sound_mm_projector
636
+
637
+ def get_speech_mm_projector(self):
644
+ speech_mm_projector = getattr(self, "speech_mm_projector", None)
645
+ if type(speech_mm_projector) is list:
646
+ speech_mm_projector = speech_mm_projector[0]
647
+ return speech_mm_projector
648
+
649
+ def freezed_module_patch(self):
650
+ """
651
+ Hugging Face calls model.train() at every training step. To keep the expected behavior of modules such as dropout and batchnorm, we call model.eval() on the frozen modules.
652
+ """
653
+ if self.training:
654
+ if self.get_llm() and not getattr(self.config, "tune_language_model", False):
655
+ pass
656
+ if self.get_vision_tower() and not getattr(self.config, "tune_vision_tower", False):
657
+ self.get_vision_tower().eval()
658
+ if self.get_speech_tower() and not getattr(self.config, "tune_speech_tower", False):
659
+ self.get_speech_tower().eval()
660
+ if self.get_sound_tower() and not getattr(self.config, "tune_sound_tower", False):
661
+ self.get_sound_tower().eval()
662
+ if self.get_mm_projector() and not getattr(self.config, "tune_mm_projector", False):
663
+ self.get_mm_projector().eval()
664
+ if self.get_speech_mm_projector() and not getattr(self.config, "tune_speech_mm_projector", False):
665
+ self.get_speech_mm_projector().eval()
666
+ if self.get_sound_mm_projector() and not getattr(self.config, "tune_sound_mm_projector", False):
667
+ self.get_sound_mm_projector().eval()
668
+
669
+
670
+ class VILAForCausalLM(VILAPretrainedModel):
671
+ def __init__(self, config: VILAConfig, *args, **kwargs):
672
+ super().__init__(config, *args, **kwargs)
673
+
674
+ def merge_features_for_dynamic_s2(self, image_features, block_sizes):
675
+ scales = self.get_vision_tower().scales
676
+ resize_output_to_scale_idx = self.get_vision_tower().resize_output_to_scale_idx
677
+
678
+ image_features_each_image = []
679
+ new_block_sizes = []
680
+ block_cnt = 0
681
+ for block_size_each_image in block_sizes:
682
+ if block_size_each_image is None:
683
+ cur_features = image_features[block_cnt : block_cnt + 1]
684
+ cur_features = rearrange(cur_features, "1 (h w) c -> 1 c h w", h=int(cur_features.shape[1] ** 0.5))
685
+ cur_features = cur_features.repeat(1, len(scales), 1, 1)
686
+ image_features_each_image.append(cur_features)
687
+ new_block_sizes.append((1, 1))
688
+ block_cnt += 1
689
+ else:
690
+ cur_features_each_scale = []
691
+ for scale in scales[:-1]:
692
+ num_blocks_this_scale = (scale // scales[0]) ** 2
693
+ cur_features_each_scale.append(
694
+ self.merge_chessboard(
695
+ image_features[block_cnt : block_cnt + num_blocks_this_scale],
696
+ num_split_h=scale // scales[0],
697
+ num_split_w=scale // scales[0],
698
+ )
699
+ ) # 1 * C * H * W
700
+ block_cnt += num_blocks_this_scale
701
+ num_blocks_last_scale = block_size_each_image[0] * block_size_each_image[1]
702
+ cur_features_each_scale.append(
703
+ self.merge_chessboard(
704
+ image_features[block_cnt : block_cnt + num_blocks_last_scale],
705
+ num_split_h=block_size_each_image[0],
706
+ num_split_w=block_size_each_image[1],
707
+ )
708
+ ) # 1 * C * H * W
709
+ block_cnt += num_blocks_last_scale
710
+
711
+ # resize and concat features from different scales
712
+ output_size = cur_features_each_scale[resize_output_to_scale_idx].shape[-2:]
713
+ cur_features = torch.cat(
714
+ [
715
+ F.interpolate(cur_features_each_scale[i].to(torch.float32), size=output_size, mode="area").to(
716
+ cur_features_each_scale[i].dtype
717
+ )
718
+ for i in range(len(cur_features_each_scale))
719
+ ],
720
+ dim=1,
721
+ )
722
+
723
+ image_features_each_image.append(cur_features)
724
+
725
+ if resize_output_to_scale_idx == len(scales) - 1 or resize_output_to_scale_idx == -1:
726
+ new_block_sizes.append(block_size_each_image)
727
+ else:
728
+ new_block_sizes.append(
729
+ (
730
+ scales[resize_output_to_scale_idx] // scales[0],
731
+ scales[resize_output_to_scale_idx] // scales[0],
732
+ )
733
+ )
734
+
735
+ assert block_cnt == len(image_features)
736
+
737
+ return image_features_each_image, new_block_sizes
738
+
739
+ @staticmethod
740
+ def split_chessboard(x, num_split_h, num_split_w):
741
+ """
742
+ x: b * c * h * w
743
+ out: (b * num_split_h * num_split_w) * c * (h // num_split_h) * (w // num_split_w)
744
+ Divide x into num_split_h * num_split_w sub-squares and concatenate them along the batch dimension.
745
+ """
746
+ B, C, H, W = x.shape
747
+ assert H % num_split_h == 0 and W % num_split_w == 0
748
+ h, w = H // num_split_h, W // num_split_w
749
+ x_split = torch.cat(
750
+ [x[:, :, i * h : (i + 1) * h, j * w : (j + 1) * w] for i in range(num_split_h) for j in range(num_split_w)],
751
+ dim=0,
752
+ )
753
+ return x_split
754
+
755
+ @staticmethod
756
+ def merge_chessboard(x, num_split_h, num_split_w):
757
+ """
758
+ x: b * n * c or b * c * h * w
759
+ out: b * c * h * w
760
+ Assuming x contains num_split_h * num_split_w sub-squares concatenated along the batch dimension, merge the sub-squares back into the original whole square.
761
+ """
762
+ B = x.shape[0]
763
+ if x.dim() == 3:
764
+ N = x.shape[1]
765
+ x = rearrange(x, "b (h w) c -> b c h w", h=int(N**0.5), w=int(N**0.5))
766
+
767
+ assert B % (num_split_h * num_split_w) == 0
768
+ b = B // (num_split_h * num_split_w)
769
+
770
+ x_merge = torch.cat(
771
+ [
772
+ torch.cat(
773
+ [x[(i * num_split_w + j) * b : (i * num_split_w + j + 1) * b] for j in range(num_split_w)], dim=-1
774
+ )
775
+ for i in range(num_split_h)
776
+ ],
777
+ dim=-2,
778
+ )
779
+
780
+ return x_merge
781
+
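+ # Editor's note: the demo below is an illustrative sketch added for this write-up,
+ # not part of the original commit. It checks that split_chessboard and
+ # merge_chessboard are inverses for a hypothetical 2x2 split.
+ @staticmethod
+ def _demo_chessboard_roundtrip():
+     x = torch.randn(1, 4, 8, 8)                                # B * C * H * W
+     tiles = VILAForCausalLM.split_chessboard(x, 2, 2)          # -> (4, 4, 4, 4)
+     merged = VILAForCausalLM.merge_chessboard(tiles, 2, 2)     # -> (1, 4, 8, 8)
+     assert torch.equal(x, merged)
+     return merged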
782
+ def encode_video(self, inp, block_sizes: Optional[Tuple[int, ...]] = None, mm_info: Optional[dict] = None, num_frames: Optional[List[int]] = None):
783
+ bs = len(inp)
784
+ cache_feas = []
785
+ cache_feas_index = []
786
+ inp_block_sizes = block_sizes
787
+
788
+ # handle cache features
789
+ for _idx in range(len(inp)):
790
+ if type(inp[_idx]) == CacheFeatures:
791
+ cache_feas.append(inp[_idx])
792
+ cache_feas_index.append(_idx)
793
+ raw_images = [_ for _ in inp if type(_) != CacheFeatures]
794
+
795
+ raw_videos_num_frames = [_.shape[0] for _ in raw_images]
796
+ if len(raw_images) > 0:
797
+ images = torch.cat(raw_images, dim=0)
798
+ else:
799
+ images = []
800
+
801
+ if block_sizes is None:
802
+ block_sizes = [None] * len(images)
803
+
804
+ def _load_video_features(image_features, cache_feas, cache_feas_index, raw_videos_num_frames):
805
+ # load cache features
806
+ if len(cache_feas) > 0:
807
+ if len(image_features) > 0:
808
+ image_features = torch.split(image_features, raw_videos_num_frames)
809
+ new_image_features = []
810
+ cache_feas_idx = 0
811
+ raw_fea_idx = 0
812
+ for _idx in range(bs):
813
+ if _idx in cache_feas_index:
814
+ new_image_features.append(cache_feas[cache_feas_idx].value['features'].to(self.device, self.dtype))
815
+ cache_feas_idx += 1
816
+ else:
817
+ new_image_features.append(image_features[raw_fea_idx])
818
+ raw_fea_idx += 1
819
+
820
+ assert len(new_image_features) == bs
821
+ image_features = new_image_features
822
+ image_features = torch.cat(image_features, dim=0)
823
+ return image_features
824
+
825
+ if getattr(self.config, "dynamic_s2", False):
826
+
827
+ if len(images) > 0:
828
+ image_features = self.get_vision_tower()(images)
829
+
830
+ image_features, new_block_sizes = self.merge_features_for_dynamic_s2(image_features, block_sizes)
831
+
832
+ image_features = [
833
+ self.split_chessboard(x, block_size[0], block_size[1])
834
+ for x, block_size in zip(image_features, new_block_sizes)
835
+ ] # list of B * C * H * W tensors
836
+ image_features = torch.cat(
837
+ [rearrange(x, "b c h w -> b (h w) c") for x in image_features], dim=0
838
+ ) # B * N * C
839
+ else:
840
+ image_features = []
841
+
842
+ # load cache features
843
+ image_features = _load_video_features(image_features, cache_feas, cache_feas_index, raw_videos_num_frames)
844
+
845
+ # if hasattr(self.config, "save_data") and self.config.save_data and num_frames is not None: # video
846
+ # _save_video_features(image_features, mm_info, inp)
847
+
848
+ if inp_block_sizes is None:
849
+ new_block_sizes = [(1, 1)] * len(image_features)
850
+ else:
851
+ raise ValueError(f"inp_block_sizes is not None: {inp_block_sizes}")
852
+ image_features = image_features.to(self.device, self.dtype)
853
+ image_features = self.get_mm_projector()(image_features)
854
+ image_features = list(
855
+ image_features.split([block_size[0] * block_size[1] for block_size in new_block_sizes], dim=0)
856
+ )
857
+ image_features = [
858
+ self.merge_chessboard(x, block_size[0], block_size[1])
859
+ for x, block_size in zip(image_features, new_block_sizes)
860
+ ] # list of 1 * C * H * W tensors
861
+ image_features = [rearrange(x, "1 c h w -> (h w) c") for x in image_features] # list of N * C tensors
862
+ if all([feature.shape[0] == image_features[0].shape[0] for feature in image_features]):
863
+ image_features = torch.stack(image_features, dim=0)
864
+ else:
865
+ if len(images) > 0:
866
+ image_features = self.get_vision_tower()(images)
867
+ else:
868
+ image_features = []
869
+
870
+ # load cache features
871
+ image_features = _load_video_features(image_features, cache_feas, cache_feas_index, raw_videos_num_frames)
872
+
873
+ image_features = self.get_mm_projector()(image_features)
874
+ return image_features
875
+
876
+ def encode_images(self, images, block_sizes: Optional[Tuple[int, ...]] = None, mm_info: Optional[dict] = None, num_frames: Optional[List[int]] = None):
877
+ if block_sizes is None:
878
+ block_sizes = [None] * len(images)
879
+
880
+ if getattr(self.config, "dynamic_s2", False):
881
+ image_features = self.get_vision_tower()(images)
882
+
883
+ image_features, new_block_sizes = self.merge_features_for_dynamic_s2(image_features, block_sizes)
884
+
885
+ image_features = [
886
+ self.split_chessboard(x, block_size[0], block_size[1])
887
+ for x, block_size in zip(image_features, new_block_sizes)
888
+ ] # list of B * C * H * W tensors
889
+ image_features = torch.cat(
890
+ [rearrange(x, "b c h w -> b (h w) c") for x in image_features], dim=0
891
+ ) # B * N * C
892
+
893
+ image_features = self.get_mm_projector()(image_features)
894
+ image_features = list(
895
+ image_features.split([block_size[0] * block_size[1] for block_size in new_block_sizes], dim=0)
896
+ )
897
+ image_features = [
898
+ self.merge_chessboard(x, block_size[0], block_size[1])
899
+ for x, block_size in zip(image_features, new_block_sizes)
900
+ ] # list of 1 * C * H * W tensors
901
+ image_features = [rearrange(x, "1 c h w -> (h w) c") for x in image_features] # list of N * C tensors
902
+ if all([feature.shape[0] == image_features[0].shape[0] for feature in image_features]):
903
+ image_features = torch.stack(image_features, dim=0)
904
+ else:
905
+ image_features = self.get_vision_tower()(images)
906
+
907
+ image_features = self.get_mm_projector()(image_features)
908
+ return image_features
909
+
910
+ def encode_sound(self, sounds, mm_info: Optional[dict] = None):
911
+
912
+ audio_features, audio_output_lengths = self.get_sound_tower()(sounds)
913
+
914
+ use_fea_downsample = False
915
+ if getattr(self.config, "sound_mm_projector", "") != "":
916
+ if "mlp_downsample" in getattr(self.config, "sound_mm_projector", ""):
917
+ use_fea_downsample = True
918
+ else:
919
+ sound_mm_projector_cfg = getattr(self.config, "sound_mm_projector_cfg", None)
920
+ if sound_mm_projector_cfg is not None:
921
+ if type(sound_mm_projector_cfg) == dict:
922
+ if "mlp_downsample" in sound_mm_projector_cfg["sound_mm_projector_type"]:
923
+ use_fea_downsample = True
924
+ elif type(sound_mm_projector_cfg) == SoundMultimodalProjectorConfig:
925
+ if "mlp_downsample" in sound_mm_projector_cfg.sound_mm_projector_type:
926
+ use_fea_downsample = True
927
+
928
+ if not use_fea_downsample:
929
+ audio_features = self.get_sound_mm_projector()(audio_features)
930
+
931
+ if audio_output_lengths is not None:
932
+ # split the batch
933
+ new_audio_features = []
934
+ start = 0
935
+ for length in audio_output_lengths:
936
+ new_audio_features.append(audio_features[start : start + length])
937
+ start += length
938
+ audio_features = new_audio_features
939
+
940
+ if use_fea_downsample:
941
+ audio_features = torch.stack(audio_features, dim=0)
942
+ audio_features = self.get_sound_mm_projector()(audio_features)
943
+
944
+ return audio_features
945
+
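+ # Editor's sketch (illustrative values, not part of the original commit): the
+ # per-clip split performed above is equivalent to torch.split by output length:
+ #   audio_features = torch.randn(10, 8)          # 10 frames, hidden size 8
+ #   audio_output_lengths = [4, 6]                # two clips in the batch
+ #   per_clip = list(torch.split(audio_features, audio_output_lengths, dim=0))
+ #   # per_clip[0].shape == (4, 8); per_clip[1].shape == (6, 8)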
946
+ def train(self, mode: bool = True):
947
+ super().train(mode)
948
+ return self
949
+
950
+ def _embed(
951
+ self,
952
+ input_ids: torch.Tensor,
953
+ media: Dict[str, List[torch.Tensor]],
954
+ media_config: Dict[str, Dict[str, Any]],
955
+ labels: Optional[torch.Tensor],
956
+ attention_mask: Optional[torch.Tensor],
957
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
958
+ media = copy.deepcopy(media)
959
+ media_config = copy.deepcopy(media_config)
960
+
961
+ labels = labels if labels is not None else torch.full_like(input_ids, IGNORE_INDEX)
962
+ attention_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids, dtype=torch.bool)
963
+
964
+ PROCESS_GROUP_MANAGER = None
965
+ if PROCESS_GROUP_MANAGER is not None:
966
+ for name in media:
967
+ self.encoders[name].end_tokens = None
968
+
969
+ # Extract text and media embeddings
970
+ text_embeds = self.llm_model_embed_tokens(input_ids)
971
+
972
+ mm_info = {}
973
+ if "video_info" in media:
974
+ video_info = media["video_info"]
975
+ del media["video_info"]
976
+ mm_info['video_info'] = video_info
977
+ else:
978
+ video_info = None
979
+
980
+ if "audio_info" in media:
981
+ audio_info = media["audio_info"]
982
+ del media["audio_info"]
983
+ mm_info['audio_info'] = audio_info
984
+ else:
985
+ audio_info = None
986
+
987
+ if media is not None:
988
+ media_embeds = self.__embed_media_tokens(media, media_config, mm_info)
989
+ else:
990
+ # no media was provided, so we just return an empty dict
991
+ media_embeds = {}
992
+
993
+ if PROCESS_GROUP_MANAGER is not None:
994
+ media_embeds_video = []
995
+ for i, images in enumerate(media_embeds["video"]):
996
+ num_video_frame = media["video"][i].shape[0]
997
+ media_embeds_video += torch.unbind(images.reshape(num_video_frame, -1, images.shape[-1]))
998
+ media_embeds["video"] = deque(media_embeds_video)
999
+
1000
+ # This is a workaround to make sure the dummy embeddings are consumed
1001
+ while media_embeds.get("dummy"):
1002
+ dummy_embed = media_embeds["dummy"].popleft()
1003
+ text_embeds += torch.sum(dummy_embed) * 0
1004
+
1005
+ # Based on segment_aud_indices_list and segment_vis_indices_list, get interleaved vis-aud embeddings for video
1006
+ video_sound_embeds_idx = 0
1007
+ sep_embed = self.encoders["video"].embed_tokens("\n")
1008
+ text_embeds = text_embeds.to(self.dtype)
1009
+ sep_embed = sep_embed.to(text_embeds.dtype)
1010
+
1011
+ if video_info is not None and self.config.load_audio_in_video and self.config.interleaved_vis_aud_in_video:
1012
+ assert self.encoders["video"].end_tokens is None, "end_tokens must be None for interleaved vis-aud in video"
1013
+ new_video_embeds = deque()
1014
+ video_embeds_idx = 0
1015
+ for k in range(len(video_info)):
1016
+ if video_info[k] is None:
1017
+ continue
1018
+ for i in range(len(video_info[k])):
1019
+ has_audio = video_info[k][i]["has_audio"]
1020
+ if not has_audio:
1021
+ new_video_embeds.append(media_embeds["video"][video_embeds_idx])
1022
+ video_embeds_idx += 1
1023
+ continue
1024
+
1025
+ # Check bounds for sound embeddings
1026
+ if video_sound_embeds_idx >= len(media_embeds["sound"]):
1027
+ raise ValueError(f"Sound embeddings index {video_sound_embeds_idx} out of bounds for video_info[{k}][{i}]")
1028
+
1029
+ segment_aud_indices_list = video_info[k][i]["segment_aud_indices_list"]
1030
+ segment_vis_indices_list = video_info[k][i]["segment_vis_indices_list"]
1031
+
1032
+ vis_fea_len_per_frame = media_embeds["video"][video_embeds_idx].shape[0] / video_info[k][i]["expected_frame_count"]
1033
+ aud_fea_len_per_stft_frame = media_embeds["sound"][video_sound_embeds_idx].shape[0] / audio_info[k][i]["new_audio_n_stft_frames"]
1034
+ vis_end = 0
1035
+ aud_end = 0
1036
+ _new_video_embed = []
1037
+ for j in range(len(segment_vis_indices_list)):
1038
+ _vis_aud_fea = []
1039
+ if len(segment_vis_indices_list[j]) > 0:
1040
+ _new_frames = [int(np.ceil((_frame+1) * vis_fea_len_per_frame)) for _frame in segment_vis_indices_list[j]]
1041
+ _vis_fea_end = _new_frames[-1]
1042
+ # Ensure we don't exceed the available features
1043
+ _vis_fea_end = min(_vis_fea_end, media_embeds["video"][video_embeds_idx].shape[0])
1044
+ if j == len(segment_vis_indices_list) - 1 and i == len(video_info) - 1 and k == len(video_info[i]) - 1 and not _vis_fea_end == media_embeds["video"][video_embeds_idx].shape[0]:
1045
+ print(f"Warning: The number of last interleaved video features does not match the video feature length. Expected: {media_embeds['video'][video_embeds_idx].shape[0]}, Got: {_vis_fea_end}")
1046
+ _vis_fea_end = media_embeds["video"][video_embeds_idx].shape[0]
1047
+ _vis_fea = media_embeds["video"][video_embeds_idx][vis_end:_vis_fea_end]
1048
+ vis_end = _vis_fea_end
1049
+ _vis_aud_fea.append(_vis_fea)
1050
+ _vis_aud_fea.append(sep_embed)
1051
+ if len(segment_aud_indices_list[j]) > 0:
1052
+ _new_audio_indices = [int(np.ceil(_fea * aud_fea_len_per_stft_frame)) for _fea in segment_aud_indices_list[j]]
1053
+ _aud_fea_end = _new_audio_indices[-1]
1054
+ # Ensure we don't exceed the available features
1055
+ _aud_fea_end = min(_aud_fea_end, media_embeds["sound"][video_sound_embeds_idx].shape[0])
1056
+ _aud_fea = media_embeds["sound"][video_sound_embeds_idx][aud_end:_aud_fea_end]
1057
+ _vis_aud_fea.append(_aud_fea)
1058
+ aud_end = _aud_fea_end
1059
+ _vis_aud_fea.append(sep_embed)
1060
+ _new_video_embed.append(torch.cat(_vis_aud_fea, dim=0))
1061
+ video_sound_embeds_idx += 1
1062
+ new_video_embeds.append(torch.cat(_new_video_embed, dim=0))
1063
+ video_embeds_idx += 1
1064
+
1065
+ assert len(new_video_embeds) == len(media_embeds["video"]), "The number of new video embeddings does not match the number of original video embeddings."
1066
+ media_embeds["video"] = new_video_embeds
1067
+ # Remove padding
1068
+ batch_size = labels.shape[0]
1069
+ text_embeds = [text_embeds[k][attention_mask[k]] for k in range(batch_size)]
1070
+ labels = [labels[k][attention_mask[k]] for k in range(batch_size)]
1071
+ # Build inverse mapping from token ID to media name
1072
+ media_tokens = {}
1073
+ for name, token_id in self.tokenizer.media_token_ids.items():
1074
+ media_tokens[token_id] = name
1075
+
1076
+ # Fuse text and media embeddings
1077
+ inputs_m, labels_m = [], []
1078
+ sound_embeds_idx = 0
1079
+ for k in range(batch_size):
1080
+ inputs_mk, labels_mk = [], []
1081
+ pos = 0
1082
+ while pos < len(labels[k]):
1083
+ if input_ids[k][pos].item() in media_tokens:
1084
+ name = media_tokens[input_ids[k][pos].item()] if PROCESS_GROUP_MANAGER is None else "video"
1085
+ if input_ids[k][pos].item() == self.tokenizer.media_token_ids["sound"]:
1086
+ if self.config.interleaved_vis_aud_in_video:
1087
+ if sound_embeds_idx < video_sound_embeds_idx:
1088
+ media_embeds[name].popleft()
1089
+ sound_embeds_idx += 1
1090
+ pos += 1
1091
+ continue
1092
+ sound_embeds_idx += 1
1093
+
1094
+ end = pos + 1
1095
+ input = media_embeds[name].popleft()
1096
+ label = torch.full([input.shape[0]], IGNORE_INDEX, device=labels[k].device, dtype=labels[k].dtype)
1097
+ else:
1098
+ end = pos
1099
+ while end < len(labels[k]) and input_ids[k][end].item() not in media_tokens:
1100
+ end += 1
1101
+ input = text_embeds[k][pos:end]
1102
+ label = labels[k][pos:end]
1103
+
1104
+ inputs_mk.append(input)
1105
+ labels_mk.append(label)
1106
+ pos = end
1107
+ inputs_m.append(torch.cat(inputs_mk, dim=0))
1108
+ labels_m.append(torch.cat(labels_mk, dim=0))
1109
+ inputs, labels = inputs_m, labels_m
1110
+
1111
+ inputs[0] += sep_embed.mean() * 0 # dummy embedding
1112
+ # Check if all media embeddings are consumed
1113
+
1114
+ for name in media_embeds:
1115
+ if media_embeds[name]:
1116
+ raise ValueError(f"Not all {name} embeddings are consumed! Still {len(media_embeds[name])} left.")
1117
+
1118
+ # Truncate sequences to `model_max_length` as media embeddings are inserted
1119
+ inputs, labels = self.__truncate_sequence(inputs, labels)
1120
+
1121
+ # Pad sequences to the longest one in the batch
1122
+ return self.__batchify_sequence(inputs, labels)
1123
+
1124
+ def __embed_media_tokens(
1125
+ self,
1126
+ media: Dict[str, List[torch.Tensor]],
1127
+ media_config: Dict[str, Dict[str, Any]],
1128
+ mm_info,
1129
+ ) -> Dict[str, List[torch.Tensor]]:
1130
+ embeds = defaultdict(deque)
1131
+
1132
+ if self.config.unified_audio_encoder:
1133
+ assert len(media["speech"]) == 0
1134
+
1135
+ for name in media:
1136
+ _encoder = self.encoders[name]
1137
+ if name in ["speech", "sound"] and self.config.unified_audio_encoder:
1138
+ _encoder = self.encoders["sound"]
1139
+
1140
+ if self.training:
1141
+ # Gather metainfo of media objects from all ranks
1142
+ if name in ["speech", "sound"]:
1143
+
1144
+ info = []
1145
+ if type(media.get(name, {})) is dict:
1146
+ for _dict in media.get(name, {}):
1147
+ info.append({k: {"shape": v.shape, "dtype": v.dtype} for k, v in _dict.items()})
1148
+ elif type(media.get(name, {})) is list:
1149
+ info = [{"shape": tensor.shape, "dtype": tensor.dtype} for tensor in media.get(name, [])]
1150
+ else:
1151
+ raise ValueError(f"Unsupported media type: {type(media.get(name, {}))}")
1152
+
1153
+ infos_list = vila_all_gather(info)
1154
+ infos = list(chain(*infos_list))
1155
+
1156
+ # The entire batch does not contain any media objects of this type.
1157
+ if not infos:
1158
+ continue
1159
+
1160
+ # For audio encoding, the batch size must match across all ranks; if it does not, pad the batch with dummy tensors up to the max batch size.
1161
+ max_batch_size = max(len(_info) for _info in infos_list)
1162
+ missing_batch_size = max_batch_size - len(info)
1163
+
1164
+ _media = media.get(name, [])
1165
+
1166
+ _medias = list(chain(vila_all_gather(_media)))
1167
+ if missing_batch_size > 0:
1168
+ for i in range(missing_batch_size):
1169
+ # use one of the media tensors to create a dummy tensor
1170
+ if type(media.get(name, {})) is dict:
1171
+ _dummy = {k: v.clone().to(device=self.device) for k, v in _medias[0].items()}
1172
+ elif type(media.get(name, {})) is list:
1173
+ if type(_medias[0]) is torch.Tensor:
1174
+ _dummy = _medias[0].clone().to(device=self.device)
1175
+ elif type(_medias[0]) is np.ndarray:
1176
+ _dummy = _medias[0].copy()
1177
+ else:
1178
+ raise ValueError(f"Unsupported media type: {type(_medias[0])}")
1179
+ else:
1180
+ raise ValueError(f"Unsupported media type: {type(media.get(name, {}))}")
1181
+ _media.append(_dummy)
1182
+ mm_info["audio_info"].append(["dummy"])
1183
+
1184
+ # we also need to align the length of all audio samples in the batch
1185
+ cur_batch_max_audio_samples = max(len(_audio) for _audio in _medias)
1186
+ cur_batch_max_audio_samples = int(np.ceil(cur_batch_max_audio_samples / (self.config.audio_sampling_rate * 30)) * (self.config.audio_sampling_rate * 30)) # should be multiple of 30 seconds
1187
+ cur_batch_max_audio_samples = min(cur_batch_max_audio_samples, self.config.audio_chunk_length * self.config.audio_sampling_rate)
1188
+ cur_batch_max_audio_duration = cur_batch_max_audio_samples // self.config.audio_sampling_rate
1189
+
1190
+
1191
+ whisper_feature_extractor = WhisperFeatureExtractor.from_pretrained(
1192
+ self.config._name_or_path, chunk_length=cur_batch_max_audio_duration, sampling_rate=self.config.audio_sampling_rate, hop_length=self.config.audio_hop_length
1193
+ )
1194
+
1195
+ # use WhisperFeatureExtractor from transformers to compute the input features
1196
+ new_media = []
1197
+
1198
+ aud_idx = 0
1199
+ for _batch_idx in range(len(mm_info["audio_info"])):
1200
+ _audio_info = mm_info["audio_info"][_batch_idx]
1201
+ if _audio_info is not None:
1202
+ for _mm_idx in range(len(_audio_info)):
1203
+ _audio = _media[aud_idx]
1204
+ if type(_audio) is torch.Tensor:
1205
+ device = _audio.device
1206
+ dtype = _audio.dtype
1207
+ _audio = _audio.cpu().float()
1208
+ else:
1209
+ # logger.warning(f"The audio type is not a tensor, which is unexpected. Using the device and dtype of the model. media: {media}, mm_info: {mm_info}")
1210
+ device = self.device
1211
+ dtype = self.dtype
1212
+ _audio = whisper.pad_or_trim(_audio, length=cur_batch_max_audio_samples)
1213
+ aud_idx += 1
1214
+ stft_features = whisper_feature_extractor(
1215
+ _audio,
1216
+ sampling_rate=self.config.audio_sampling_rate,
1217
+ return_attention_mask=True,
1218
+ padding="max_length",
1219
+ return_tensors="pt",
1220
+ ).to(device, dtype)
1221
+ new_media.append(stft_features)
1222
+ if _audio_info[_mm_idx] != "dummy":
1223
+ _audio_info[_mm_idx]["new_audio_chunk_length"] = cur_batch_max_audio_duration
1224
+ _audio_info[_mm_idx]["new_audio_n_samples"] = cur_batch_max_audio_samples
1225
+ _audio_info[_mm_idx]["audio_end_sample_sec"] = _audio_info[_mm_idx]["audio_start_sec"] + cur_batch_max_audio_duration
1226
+ _audio_info[_mm_idx]["new_audio_n_stft_frames"] = stft_features["input_features"].shape[-1]
1227
+
1228
+ assert aud_idx == len(_media), "The number of audio info does not match the number of audio samples."
1229
+ _media = new_media
1230
+
1231
+ _fea = _encoder(_media, media_config[name], mm_info)
1232
+ # [751, 1536]
1233
+ # consume dummy features later
1234
+ _dummy_fea = _fea[len(info) :]
1235
+ embeds["dummy"].extend(_dummy_fea)
1236
+
1237
+ # remove the dummy features
1238
+ _real_fea = _fea[: len(info)]
1239
+ if len(info) > 0:
1240
+ embeds[name] = deque(_real_fea)
1241
+
1242
+ else:
1243
+ # Gather metainfo of media objects from all ranks
1244
+ info = [{"shape": tensor.shape, "dtype": tensor.dtype} for tensor in media.get(name, [])]
1245
+ infos = list(chain(vila_all_gather(info)))
1246
+
1247
+ # The entire batch does not contain any media objects of this type.
1248
+ if not infos:
1249
+ continue
1250
+
1251
+ # Create a dummy tensor to ensure the encoder is called, otherwise the training will hang.
1252
+ if media.get(name) is None or len(media[name]) == 0:
1253
+ dummy = torch.zeros(infos[0]["shape"], dtype=infos[0]["dtype"], device=self.device)
1254
+ embeds["dummy"].extend(self.encoders[name]([dummy], media_config[name]))
1255
+ continue
1256
+ embeds[name] = deque(self.encoders[name](media[name], media_config[name]))
1257
+
1258
+ else:
1259
+ if name == "sound":
1260
+ all_audio_chunk_lengths = []
1261
+ for _sample_idx in range(len(media[name])):
1262
+ for _mm_idx in range(len(mm_info["audio_info"][_sample_idx])):
1263
+ _new_audio_chunk_length = mm_info["audio_info"][_sample_idx][_mm_idx]["new_audio_chunk_length"]
1264
+ all_audio_chunk_lengths.append(_new_audio_chunk_length)
1265
+ cur_batch_max_audio_duration = max(all_audio_chunk_lengths)
1266
+ cur_batch_max_audio_samples = cur_batch_max_audio_duration * self.config.audio_sampling_rate
1267
+ # for qwen omni audio
1268
+ # cur_batch_max_audio_samples = 960000
1269
+
1270
+ whisper_feature_extractor = WhisperFeatureExtractor.from_pretrained(
1271
+ self.config._name_or_path, chunk_length=cur_batch_max_audio_duration, sampling_rate=self.config.audio_sampling_rate, hop_length=self.config.audio_hop_length
1272
+ )
1273
+
1274
+ new_media = []
1275
+ _idx = 0
1276
+ assert len(all_audio_chunk_lengths) == len(media[name]), "The number of audio chunk lengths does not match the number of audio samples."
1277
+
1278
+ _media = media.get(name, [])
1279
+ aud_idx = 0
1280
+ for _batch_idx in range(len(mm_info["audio_info"])):
1281
+ _audio_info = mm_info["audio_info"][_batch_idx]
1282
+ if _audio_info is not None:
1283
+ for _mm_idx in range(len(_audio_info)):
1284
+ _audio = _media[aud_idx]
1285
+ if type(_audio) is torch.Tensor:
1286
+ device = _audio.device
1287
+ dtype = _audio.dtype
1288
+ _audio = _audio.cpu().float()
1289
+ else:
1290
+ device = self.device
1291
+ dtype = self.dtype
1292
+ _audio = whisper.pad_or_trim(_audio, length=cur_batch_max_audio_samples)
1293
+ aud_idx += 1
1294
+ stft_features = whisper_feature_extractor(
1295
+ _audio,
1296
+ sampling_rate=self.config.audio_sampling_rate,
1297
+ return_attention_mask=True,
1298
+ padding="max_length",
1299
+ return_tensors="pt",
1300
+ ).to(device, dtype)
1301
+
1302
+ new_media.append(stft_features)
1303
+ if _audio_info[_mm_idx] != "dummy":
1304
+ _audio_info[_mm_idx]["new_audio_chunk_length"] = cur_batch_max_audio_duration
1305
+ _audio_info[_mm_idx]["new_audio_n_samples"] = cur_batch_max_audio_samples
1306
+ _audio_info[_mm_idx]["audio_end_sample_sec"] = _audio_info[_mm_idx]["audio_start_sec"] + cur_batch_max_audio_duration
1307
+ _audio_info[_mm_idx]["new_audio_n_stft_frames"] = stft_features["input_features"].shape[-1]
1308
+ media[name] = new_media
1309
+
1310
+ if len(media[name]) > 0:
1311
+ embeds[name] = deque(_encoder(media[name], media_config[name], mm_info))
1312
+ return embeds
1313
+
1314
+ def __truncate_sequence(
1315
+ self, inputs: List[torch.Tensor], labels: List[torch.Tensor]
1316
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1317
+ if self.training and any(len(input) > self.tokenizer.model_max_length for input in inputs):
1318
+ warnings.warn(f"Truncating sequences to `model_max_length` ({self.tokenizer.model_max_length}).")
1319
+ inputs = [input[: self.tokenizer.model_max_length] for input in inputs]
1320
+ labels = [label[: self.tokenizer.model_max_length] for label in labels]
1321
+ return inputs, labels
1322
+
1323
+ def __batchify_sequence(
1324
+ self, inputs: List[torch.Tensor], labels: List[torch.Tensor]
1325
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1326
+ batch_size = len(inputs)
1327
+ device = inputs[0].device
1328
+ hidden_size = inputs[0].shape[1]
1329
+ max_length = max(inputs[k].shape[0] for k in range(batch_size))
1330
+ attention_mask = torch.ones((batch_size, max_length), dtype=torch.bool, device=device)
1331
+
1332
+ inputs_p, labels_p = [], []
1333
+ for k in range(batch_size):
1334
+ size_pk = max_length - inputs[k].shape[0]
1335
+ inputs_pk = torch.zeros((size_pk, hidden_size), dtype=inputs[k].dtype, device=device)
1336
+ labels_pk = torch.full((size_pk,), IGNORE_INDEX, dtype=labels[k].dtype, device=device)
1337
+ if self.tokenizer.padding_side == "right":
1338
+ attention_mask[k, inputs[k].shape[0] :] = False
1339
+ inputs_pk = torch.cat([inputs[k], inputs_pk], dim=0)
1340
+ labels_pk = torch.cat([labels[k], labels_pk], dim=0)
1341
+ else:
1342
+ labels[k] = labels[k].to(device)
1343
+ attention_mask[k, : -inputs[k].shape[0]] = False
1344
+ inputs_pk = torch.cat([inputs_pk, inputs[k]], dim=0)
1345
+ labels_pk = torch.cat([labels_pk, labels[k]], dim=0)
1346
+ inputs_p.append(inputs_pk)
1347
+ labels_p.append(labels_pk)
1348
+
1349
+ inputs = torch.stack(inputs_p, dim=0)
1350
+ labels = torch.stack(labels_p, dim=0)
1351
+ return inputs, labels, attention_mask
1352
+
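+ # Editor's sketch (hypothetical shapes, not part of the original commit): with the
+ # default padding_side="left" set in __init__, a batch of embedded sequences of
+ # lengths 3 and 1 is padded on the left, so the mask and labels come out as:
+ #   attention_mask -> [[True, True, True], [False, False, True]]
+ #   labels         -> [[a0,   a1,   a2  ], [IGNORE_INDEX, IGNORE_INDEX, b0]]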
1353
+ def repack_multimodal_data(self, inputs_embeds, attention_mask, position_ids, labels):
1354
+ # Handle sequence parallelism
1355
+ PROCESS_GROUP_MANAGER = None
1356
+
1357
+ # We do re-sharding instead of packing here to ensure the sequence length is the same across all ranks.
1358
+ if PROCESS_GROUP_MANAGER is not None:
1359
+ sp_degree = PROCESS_GROUP_MANAGER.sp_degree
1360
+ sp_rank = PROCESS_GROUP_MANAGER.sp_rank
1361
+ sp_group = PROCESS_GROUP_MANAGER.sp_pg
1362
+ ring_degree = PROCESS_GROUP_MANAGER.ring_degree
1363
+ ring_rank = PROCESS_GROUP_MANAGER.ring_rank
1364
+ ring_type = PROCESS_GROUP_MANAGER.ring_type
1365
+ ulysses_degree = PROCESS_GROUP_MANAGER.ulysses_degree
1366
+ ulysses_rank = PROCESS_GROUP_MANAGER.ulysses_rank
1367
+
1368
+ bs, shard_seqlen = position_ids.shape
1369
+ sp_seq_len = [torch.zeros(1, dtype=torch.int64, device=position_ids.device) for _ in range(sp_degree)]
1370
+ dist.all_gather(sp_seq_len, torch.tensor(shard_seqlen, device=position_ids.device), group=sp_group)
1371
+ sp_seq_len_cat = torch.cat(sp_seq_len, dim=0)
1372
+
1373
+ if sp_rank == 0:
1374
+ original_start_id = 0
1375
+ else:
1376
+ original_start_id = torch.sum(sp_seq_len_cat[:sp_rank]).item()
1377
+ original_end_id = torch.sum(sp_seq_len_cat[: sp_rank + 1]).item()
1378
+
1379
+ # Gather attention_mask, position_ids, labels and input_embeds
1380
+ all_inputs_embeds = torch.zeros(
1381
+ bs,
1382
+ torch.sum(sp_seq_len_cat),
1383
+ inputs_embeds.shape[-1],
1384
+ dtype=inputs_embeds.dtype,
1385
+ device=inputs_embeds.device,
1386
+ ).contiguous()
1387
+ all_inputs_embeds[:, original_start_id:original_end_id, :] += inputs_embeds
1388
+ dist.barrier(group=sp_group)
1389
+ dist.all_reduce(all_inputs_embeds, group=sp_group)
1390
+ dist.barrier(group=sp_group)
1391
+
1392
+ attention_mask_list = [
1393
+ torch.zeros((bs, sp_seq_len[i]), dtype=attention_mask.dtype, device=attention_mask.device)
1394
+ for i in range(sp_degree)
1395
+ ]
1396
+ position_ids_list = [
1397
+ torch.zeros((bs, sp_seq_len[i]), dtype=position_ids.dtype, device=position_ids.device)
1398
+ for i in range(sp_degree)
1399
+ ]
1400
+ labels_list = [
1401
+ torch.zeros((bs, sp_seq_len[i]), dtype=labels.dtype, device=labels.device) for i in range(sp_degree)
1402
+ ]
1403
+
1404
+ dist.all_gather(attention_mask_list, attention_mask, group=sp_group)
1405
+ dist.all_gather(position_ids_list, position_ids, group=sp_group)
1406
+ dist.all_gather(labels_list, labels, group=sp_group)
1407
+
1408
+ effective_seqlen_list = [attention_mask_list[i].sum(dim=-1) for i in range(sp_degree)]
1409
+ effective_seqlen = torch.stack(effective_seqlen_list, dim=-1)
1410
+ effective_seqlen_batch_list = torch.unbind(effective_seqlen, dim=0)
1411
+
1412
+ global_attention_mask_list = []
1413
+ global_position_ids_list = []
1414
+ global_labels_list = []
1415
+ global_inputs_embeds_list = []
1416
+ for i in range(bs):
1417
+ global_attention_mask_batch_list = []
1418
+ global_position_ids_batch_list = []
1419
+ global_labels_batch_list = []
1420
+ global_inputs_embeds_batch_list = []
1421
+ for j in range(sp_degree):
1422
+ eff_len = effective_seqlen_batch_list[i][j]
1423
+ prev_len = torch.sum(sp_seq_len_cat[:j]).item() if j > 0 else 0
1424
+
1425
+ global_attention_mask_batch_list.append(attention_mask_list[j][i, :eff_len])
1426
+ global_position_ids_batch_list.append(position_ids_list[j][i, :eff_len])
1427
+ global_labels_batch_list.append(labels_list[j][i, :eff_len])
1428
+ global_inputs_embeds_batch_list.append(all_inputs_embeds[i, prev_len : prev_len + eff_len, :])
1429
+ global_attention_mask_list.append(torch.cat(global_attention_mask_batch_list, dim=0))
1430
+ global_position_ids_list.append(torch.cat(global_position_ids_batch_list, dim=0))
1431
+ global_labels_list.append(torch.cat(global_labels_batch_list, dim=0))
1432
+ global_inputs_embeds_list.append(torch.cat(global_inputs_embeds_batch_list, dim=0))
1433
+
1434
+ global_attention_mask = torch.nn.utils.rnn.pad_sequence(
1435
+ global_attention_mask_list, batch_first=True, padding_value=False
1436
+ )
1437
+ global_position_ids = torch.nn.utils.rnn.pad_sequence(
1438
+ global_position_ids_list, batch_first=True, padding_value=-1
1439
+ )
1440
+ global_labels = torch.nn.utils.rnn.pad_sequence(
1441
+ global_labels_list, batch_first=True, padding_value=IGNORE_INDEX
1442
+ )
1443
+ global_inputs_embeds = torch.nn.utils.rnn.pad_sequence(
1444
+ global_inputs_embeds_list, batch_first=True, padding_value=0
1445
+ )
1446
+
1447
+ # Re-shard the inputs
1448
+ if ring_degree > 1:
1449
+ total_effective_seqlen = torch.sum(effective_seqlen, dim=1)
1450
+ new_seqlen_per_rank = total_effective_seqlen // sp_degree
1451
+ assert torch.all(
1452
+ total_effective_seqlen % sp_degree == 0
1453
+ ), "total_effective_seqlen must be divisible by sp_degree"
1454
+
1455
+ max_new_seqlen = torch.max(new_seqlen_per_rank).item()
1456
+
1457
+ new_attention_mask = torch.zeros(
1458
+ (bs, max_new_seqlen), dtype=global_attention_mask.dtype, device=global_attention_mask.device
1459
+ )
1460
+ new_position_ids = torch.zeros(
1461
+ (bs, max_new_seqlen), dtype=global_position_ids.dtype, device=global_position_ids.device
1462
+ )
1463
+ new_labels = torch.full(
1464
+ (bs, max_new_seqlen), IGNORE_INDEX, dtype=global_labels.dtype, device=global_labels.device
1465
+ )
1466
+ new_inputs_embeds = torch.zeros(
1467
+ (bs, max_new_seqlen, global_inputs_embeds.shape[-1]),
1468
+ dtype=global_inputs_embeds.dtype,
1469
+ device=global_inputs_embeds.device,
1470
+ )
1471
+
1472
+ if ring_type == "ring_varlen":
1473
+ for i in range(bs):
1474
+ start_idx = new_seqlen_per_rank[i] * sp_rank
1475
+ end_idx = start_idx + new_seqlen_per_rank[i]
1476
+ new_attention_mask[i, : new_seqlen_per_rank[i]] = global_attention_mask[i, start_idx:end_idx]
1477
+ new_position_ids[i, : new_seqlen_per_rank[i]] = global_position_ids[i, start_idx:end_idx]
1478
+ new_labels[i, : new_seqlen_per_rank[i]] = global_labels[i, start_idx:end_idx]
1479
+ new_inputs_embeds[i, : new_seqlen_per_rank[i], :] = global_inputs_embeds[
1480
+ i, start_idx:end_idx, :
1481
+ ]
1482
+ elif ring_type == "zigzag_ring_varlen":
1483
+ chunk_size = total_effective_seqlen // (2 * sp_degree)
1484
+ for i in range(bs):
1485
+ # Zigzag pattern indices
1486
+ if sp_degree == ring_degree:
1487
+ forward_rank_idx = sp_rank
1488
+ backward_rank_idx = 2 * sp_degree - sp_rank - 1
1489
+ else:
1490
+ ulysses_offset = ulysses_rank * ring_degree * 2
1491
+ forward_rank_idx = ring_rank + ulysses_offset
1492
+ backward_rank_idx = sp_degree - ring_rank - 1 + ulysses_offset
1493
+
1494
+ # Calculate start and end indices for the forward and backward zigzag
1495
+ start_idx_fwd = forward_rank_idx * chunk_size[i]
1496
+ end_idx_fwd = start_idx_fwd + chunk_size[i]
1497
+
1498
+ start_idx_bwd = backward_rank_idx * chunk_size[i]
1499
+ end_idx_bwd = start_idx_bwd + chunk_size[i]
1500
+
1501
+ # Fill new tensors with zigzag data
1502
+ new_attention_mask[i, : chunk_size[i]] = global_attention_mask[i, start_idx_fwd:end_idx_fwd]
1503
+ new_attention_mask[i, chunk_size[i] : 2 * chunk_size[i]] = global_attention_mask[
1504
+ i, start_idx_bwd:end_idx_bwd
1505
+ ]
1506
+
1507
+ new_position_ids[i, : chunk_size[i]] = global_position_ids[i, start_idx_fwd:end_idx_fwd]
1508
+ new_position_ids[i, chunk_size[i] : 2 * chunk_size[i]] = global_position_ids[
1509
+ i, start_idx_bwd:end_idx_bwd
1510
+ ]
1511
+
1512
+ new_labels[i, : chunk_size[i]] = global_labels[i, start_idx_fwd:end_idx_fwd]
1513
+ new_labels[i, chunk_size[i] : 2 * chunk_size[i]] = global_labels[i, start_idx_bwd:end_idx_bwd]
1514
+
1515
+ new_inputs_embeds[i, : chunk_size[i], :] = global_inputs_embeds[i, start_idx_fwd:end_idx_fwd, :]
1516
+ new_inputs_embeds[i, chunk_size[i] : 2 * chunk_size[i], :] = global_inputs_embeds[
1517
+ i, start_idx_bwd:end_idx_bwd, :
1518
+ ]
1519
+ else:
1520
+ raise ValueError(f"Invalid ring_type: {ring_type}")
1521
+ else:
1522
+ global_seq_len = global_attention_mask.shape[-1]
1523
+ seq_len_sharded = global_seq_len // sp_degree
1524
+ start_idx_reshard = seq_len_sharded * sp_rank
1525
+ end_idx_reshard = start_idx_reshard + seq_len_sharded if sp_rank < sp_degree - 1 else global_seq_len
1526
+
1527
+ new_attention_mask = torch.narrow(
1528
+ global_attention_mask, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
1529
+ )
1530
+ new_position_ids = torch.narrow(
1531
+ global_position_ids, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
1532
+ )
1533
+ new_labels = torch.narrow(global_labels, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard)
1534
+ new_inputs_embeds = torch.narrow(
1535
+ global_inputs_embeds, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
1536
+ )
1537
+
1538
+ return new_inputs_embeds, new_attention_mask, new_position_ids, new_labels
1539
+
1540
+ device = inputs_embeds.device
1541
+ batch_size = inputs_embeds.shape[0]
1542
+ seqlens = [attention_mask[k].sum().item() for k in range(batch_size)]
1543
+
1544
+ # Pack all sequences together
1545
+ inputs_embeds_p = [inputs_embeds[k][attention_mask[k]] for k in range(batch_size)]
1546
+ attention_mask_p = [torch.ones(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)]
1547
+ position_ids_p = [torch.arange(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)]
1548
+ labels_p = [labels[k][attention_mask[k]] for k in range(batch_size)]
1549
+
1550
+ # Add one dummy token at the end of the packed sequence to ensure that `_get_unpacked_data` will be called
1551
+ inputs_embeds_p.append(torch.zeros(1, inputs_embeds.shape[-1], dtype=inputs_embeds.dtype, device=device))
1552
+ attention_mask_p.append(torch.tensor([0], dtype=torch.int, device=device))
1553
+ position_ids_p.append(torch.tensor([0], dtype=torch.int, device=device))
1554
+ labels_p.append(torch.tensor([IGNORE_INDEX], dtype=torch.int, device=device))
1555
+
1556
+ # Mask the first token of each sequence to avoid contamination
1557
+ for label in labels_p:
1558
+ label[0] = IGNORE_INDEX
1559
+
1560
+ # Batch the data
1561
+ inputs_embeds_p = torch.cat(inputs_embeds_p, dim=0).unsqueeze(0)
1562
+ attention_mask_p = torch.cat(attention_mask_p, dim=0).unsqueeze(0)
1563
+ position_ids_p = torch.cat(position_ids_p, dim=0).unsqueeze(0)
1564
+ labels_p = torch.cat(labels_p, dim=0).unsqueeze(0)
1565
+
1566
+ if hasattr(
1567
+ self, "pad_to_multiple_of"
1568
+ ): # related to quantization, please refer to ModelArguments for more information.
1569
+ assert len(labels_p.shape) == 2
1570
+ batch_size, max_length, cur_length = labels_p.shape[0], labels_p.shape[1], labels_p.shape[1]
1571
+ hidden_size = inputs_embeds_p.shape[-1]
1572
+
1573
+ if max_length % self.pad_to_multiple_of != 0:
1574
+ max_length = ((max_length // self.pad_to_multiple_of) + 1) * self.pad_to_multiple_of
1575
+ difference = max_length - cur_length
1576
+
1577
+ inputs_embeds_p = torch.cat(
1578
+ (
1579
+ inputs_embeds_p,
1580
+ torch.full((batch_size, difference, hidden_size), self.llm.pad_token_id).to(inputs_embeds_p),
1581
+ ),
1582
+ dim=1,
1583
+ )
1584
+ labels_p = torch.cat((labels_p, torch.full((batch_size, difference), IGNORE_INDEX).to(labels_p)), dim=1)
1585
+ attention_mask_p = torch.cat(
1586
+ (
1587
+ attention_mask_p,
1588
+ torch.zeros((batch_size, difference), dtype=torch.bool).to(attention_mask_p),
1589
+ ),
1590
+ dim=1,
1591
+ )
1592
+ position_ids_p = torch.cat(
1593
+ (position_ids_p, torch.full((batch_size, difference), -1).to(position_ids_p)), dim=1
1594
+ )
1595
+
1596
+ return inputs_embeds_p, attention_mask_p, position_ids_p, labels_p
1597
+
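The packing path above drops padding, concatenates every sequence in the batch into a single row, restarts `position_ids` at zero for each original sequence, and sets the first label of each sequence to `IGNORE_INDEX` so no loss leaks across sequence boundaries. Below is a minimal, self-contained sketch of that packing step on dummy tensors; the function name and shapes are illustrative, not part of the repository's API.

```python
import torch

IGNORE_INDEX = -100  # same convention as the code above

def pack_sequences(inputs_embeds, attention_mask, labels):
    """Pack a padded batch (B, T, H) into a single row (1, sum_of_lengths, H)."""
    device = inputs_embeds.device
    batch_size = inputs_embeds.shape[0]
    seqlens = [attention_mask[k].sum().item() for k in range(batch_size)]

    embeds, pos_ids, labs = [], [], []
    for k in range(batch_size):
        embeds.append(inputs_embeds[k][attention_mask[k]])                             # drop padding
        pos_ids.append(torch.arange(seqlens[k], dtype=torch.int, device=device))       # restart positions
        lab = labels[k][attention_mask[k]].clone()
        lab[0] = IGNORE_INDEX                                                           # no loss across the boundary
        labs.append(lab)

    return (torch.cat(embeds).unsqueeze(0),
            torch.cat(pos_ids).unsqueeze(0),
            torch.cat(labs).unsqueeze(0))

# Example: two sequences of lengths 3 and 2 packed into one row of length 5.
emb = torch.randn(2, 4, 8)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)
lbl = torch.randint(0, 10, (2, 4))
packed_emb, packed_pos, packed_lbl = pack_sequences(emb, mask, lbl)
print(packed_emb.shape, packed_pos.tolist())  # torch.Size([1, 5, 8]) [[0, 1, 2, 0, 1]]
```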
1598
+ def forward(
1599
+ self,
1600
+ input_ids: torch.LongTensor = None,
1601
+ media: Optional[Dict[str, List[torch.Tensor]]] = None,
1602
+ images: Optional[torch.FloatTensor] = None,
1603
+ media_config: Optional[List] = None,
1604
+ pixel_values: Optional[torch.FloatTensor] = None,
1605
+ attention_mask: Optional[torch.Tensor] = None,
1606
+ position_ids: Optional[torch.LongTensor] = None,
1607
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1608
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1609
+ labels: Optional[torch.LongTensor] = None,
1610
+ packing: bool = True,
1611
+ force_packing: bool = False,
1612
+ seqlens_in_batch: Optional[torch.LongTensor] = None,
1613
+ dpo_forward: bool = False,
1614
+ **kwargs,
1615
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1616
+ self.freezed_module_patch()
1617
+
1618
+ if images is not None:
1619
+ if media is not None:
1620
+ raise ValueError("Both 'media' and 'images' are provided. Please provide only one.")
1621
+ print("The 'images' argument is deprecated. Please use 'media' instead.")
1622
+ media = {"image": images}
1623
+
1624
+ if media_config is None:
1625
+ media_config = defaultdict(dict)
1626
+
1627
+ if inputs_embeds is None:
1628
+ inputs_embeds, labels, attention_mask = self._embed(input_ids, media, media_config, labels, attention_mask)
1629
+
1630
+ if force_packing or (packing and self.training and not dpo_forward):
1631
+ if seqlens_in_batch is None:
1632
+ seqlens_in_batch = torch.sum(attention_mask, dim=1)
1633
+ set_seqlens_in_batch(seqlens_in_batch)
1634
+
1635
+ (inputs_embeds, attention_mask, position_ids, labels) = self.repack_multimodal_data(
1636
+ inputs_embeds, attention_mask, position_ids, labels
1637
+ )
1638
+
1639
+ outputs = self.llm(
1640
+ inputs_embeds=inputs_embeds,
1641
+ attention_mask=attention_mask,
1642
+ position_ids=position_ids,
1643
+ past_key_values=past_key_values,
1644
+ labels=labels,
1645
+ **kwargs,
1646
+ )
1647
+
1648
+ if self.training and getattr(self.config, "time_token_ids", []):
1649
+ outputs.loss = soft_cross_entropy(
1650
+ outputs.logits,
1651
+ labels,
1652
+ soft_tokens=self.config.time_token_ids,
1653
+ std=self.config.soft_ce_std,
1654
+ )
1655
+
1656
+ if dpo_forward:
1657
+ return outputs.logits, labels
1658
+
1659
+ return outputs
1660
+
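When `time_token_ids` are configured, the standard loss is replaced by `soft_cross_entropy`, whose `std` argument suggests the target mass is smoothed over neighbouring time tokens rather than placed on a single one-hot label. The repository's `soft_cross_entropy` is defined elsewhere; the sketch below only illustrates the general idea of a Gaussian-smoothed target over an ordered set of time bins and should not be read as the actual implementation.

```python
import torch
import torch.nn.functional as F

def gaussian_soft_targets(label_pos, num_bins, std=1.0):
    """Soft target over `num_bins` ordered time bins, centred on `label_pos` (illustrative only)."""
    bins = torch.arange(num_bins, dtype=torch.float)
    weights = torch.exp(-0.5 * ((bins - label_pos) / std) ** 2)
    return weights / weights.sum()

# Example: 10 time bins, ground-truth bin 4 -> most target mass around bin 4.
target = gaussian_soft_targets(label_pos=4, num_bins=10, std=1.0)
logits = torch.randn(10)
loss = -(target * F.log_softmax(logits, dim=-1)).sum()
print(target, loss.item())
```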
1661
+ @torch.inference_mode()
1662
+ def generate(
1663
+ self,
1664
+ input_ids: Optional[torch.FloatTensor] = None,
1665
+ media: Optional[Dict[str, List[torch.Tensor]]] = None,
1666
+ media_config: Dict[str, Dict[str, Any]] = None,
1667
+ attention_mask: Optional[torch.LongTensor] = None,
1668
+ return_output_ids_only: bool = True,
1669
+ **generation_kwargs,
1670
+ ) -> torch.LongTensor:
1671
+ """
1672
+ input_tokens: <image> describe the image
1673
+ media: [Tensor(1, 3, 384, 384), ]
1674
+ ----------->
1675
+ input_tokens: 36000 001 002 003 004
1676
+ input_embeds: <media embed> 001 002 003 004
1677
+ """
1678
+ inputs_embeds, _, attention_mask = self._embed(input_ids, media, media_config, None, attention_mask)
1679
+
1680
+ output_ids = self.llm.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs)
1681
+
1682
+ if return_output_ids_only:
1683
+ return_value = output_ids
1684
+ else:
1685
+ # by default, return the input_ids and output_ids concatenated to keep consistency with the community VLMs like qwen
1686
+ generation_config = generation_kwargs.get("generation_config", None)
1687
+ if generation_config is not None:
1688
+ num_generations = generation_config.num_return_sequences
1689
+ repeat_input_ids = input_ids.repeat_interleave(num_generations, dim=0)
1690
+ return_value = torch.cat([repeat_input_ids, output_ids], dim=-1)
1691
+ else:
1692
+ return_value = torch.cat([input_ids, output_ids], dim=-1)
1693
+
1694
+ return return_value
1695
+
1696
+ @torch.inference_mode()
1697
+ def generate_content(
1698
+ self,
1699
+ prompt: Union[str, List],
1700
+ generation_config: Optional[GenerationConfig] = None,
1701
+ response_format=None,
1702
+ ) -> str:
1703
+ conversation = [{"from": "human", "value": prompt}]
1704
+
1705
+ # Convert response format to logits processor
1706
+ xgr_logits_processor = None
1707
+
1708
+ # Extract media from the conversation
1709
+
1710
+ media = extract_media(conversation, self.config)
1711
+
1712
+ # Process media
1713
+ media_config = defaultdict(dict)
1714
+ for name in media:
1715
+ if name == "image":
1716
+ if len(media["image"]) == 1 and self.config.image_aspect_ratio in ["dynamic", "dynamic_s2"]:
1717
+ self.config.image_processor = self.vision_tower.image_processor
1718
+ if self.config.image_aspect_ratio == "dynamic":
1719
+ images = process_image(media["image"][0], self.config, None, enable_dynamic_res=True).half()
1720
+ conversation[0]["value"] = conversation[0]["value"].replace(
1721
+ DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0]
1722
+ )
1723
+ else:
1724
+ if type(self.config.s2_scales) is str:
1725
+ self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
1726
+ images, block_sizes = process_image(
1727
+ media["image"][0], self.config, None, enable_dynamic_s2=True
1728
+ )
1729
+ images = images.half()
1730
+ media_config[name]["block_sizes"] = [block_sizes]
1731
+ else:
1732
+ images = process_images(media["image"], self.vision_tower.image_processor, self.config).half()
1733
+ media[name] = [image for image in images]
1734
+ elif name == "video":
1735
+ if self.config.image_aspect_ratio == "dynamic" and self.config.video_max_tiles > 1:
1736
+ media[name] = [
1737
+ process_images(
1738
+ images,
1739
+ self.vision_tower.image_processor,
1740
+ self.config,
1741
+ enable_dynamic_res=True,
1742
+ max_tiles=self.config.video_max_tiles,
1743
+ ).half()
1744
+ for images in media[name]
1745
+ ]
1746
+ elif self.config.image_aspect_ratio == "dynamic_s2" and self.config.video_max_tiles > 1:
1747
+ self.config.image_processor = self.vision_tower.image_processor
1748
+ if type(self.config.s2_scales) is str:
1749
+ self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
1750
+ media[name] = [
1751
+ torch.cat(
1752
+ [
1753
+ process_image(
1754
+ image,
1755
+ self.config,
1756
+ None,
1757
+ enable_dynamic_s2=True,
1758
+ max_tiles=self.config.video_max_tiles,
1759
+ )[0].half()
1760
+ for image in images
1761
+ ]
1762
+ )
1763
+ for images in media[name]
1764
+ ]
1765
+ else:
1766
+ media[name] = [
1767
+ process_images(images, self.vision_tower.image_processor, self.config)
1768
+ for images in media[name]
1769
+ ]
1770
+ elif name == "speech":
1771
+ speeches = media["speech"]
1772
+ media[name] = [speech for speech in speeches]
1773
+ elif name == "sound":
1774
+ # sounds = process_sounds(media["sound"]).half()
1775
+ sounds = media["sound"]
1776
+ # media[name] = [{k: v.half() for sound in sounds for k, v in sound.items()]
1777
+ for sound in sounds:
1778
+ if type(sound) is dict:
1779
+ for k, v in sound.items():
1780
+ sound[k] = v.half()
1781
+ media[name] = [sound for sound in sounds]
1782
+ elif name == "video_info":
1783
+ media[name] = [media["video_info"]]
1784
+ elif name == "audio_info":
1785
+ media[name] = [media["audio_info"]]
1786
+ else:
1787
+ raise ValueError(f"Unsupported media type: {name}")
1788
+
1789
+ # Tokenize the conversation
1790
+ input_ids = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True).unsqueeze(0).cuda()
1791
+
1792
+ # Set up the generation config
1793
+ generation_config = generation_config or self.default_generation_config
1794
+
1795
+ # Generate the response
1796
+ try:
1797
+ output_ids = self.generate(
1798
+ input_ids=input_ids,
1799
+ media=media,
1800
+ media_config=media_config,
1801
+ generation_config=generation_config,
1802
+ logits_processor=xgr_logits_processor, # structured generation
1803
+ )
1804
+ except ValueError:
1805
+ if not generation_config.do_sample:
1806
+ raise
1807
+ logging.warning("Generation failed with sampling, retrying with greedy decoding.")
1808
+ generation_config.do_sample = False
1809
+ output_ids = self.generate(
1810
+ input_ids=input_ids,
1811
+ media=media,
1812
+ media_config=media_config,
1813
+ generation_config=generation_config,
1814
+ logits_processor=xgr_logits_processor,
1815
+ )
1816
+
1817
+ # Decode the response
1818
+ response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
1819
+ return response
1820
+
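`generate_content` is the high-level entry point: it builds a single-turn conversation, extracts the media, tokenizes with the chat template, and decodes the response. A hedged end-to-end sketch is shown below; the `AutoModel.from_pretrained(..., trust_remote_code=True)` loading path and the list-of-parts prompt format (a PIL image followed by text) are assumptions based on the model card and the `Union[str, List]` signature above, so the exact loading call may differ.

```python
import torch
from PIL import Image
from transformers import AutoModel, GenerationConfig

# Assumption: the checkpoint exposes this model via trust_remote_code; loading details may differ.
model = AutoModel.from_pretrained(
    "nvidia/omnivinci", torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto"
).eval()

prompt = [Image.open("example.jpg"), "Describe the image in one sentence."]
response = model.generate_content(
    prompt,
    generation_config=GenerationConfig(max_new_tokens=128, do_sample=False),
)
print(response)
```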
1821
+ @property
1822
+ def default_generation_config(self) -> GenerationConfig:
1823
+ generation_config = copy.deepcopy(self.generation_config or GenerationConfig())
1824
+ if self.tokenizer.eos_token_id is None:
1825
+ raise ValueError("Tokenizer must have an EOS token")
1826
+ if generation_config.max_length == GenerationConfig().max_length:
1827
+ generation_config.max_length = self.tokenizer.model_max_length
1828
+ if generation_config.pad_token_id is None:
1829
+ generation_config.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
1830
+ if generation_config.bos_token_id is None:
1831
+ generation_config.bos_token_id = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
1832
+ if generation_config.eos_token_id is None:
1833
+ generation_config.eos_token_id = self.tokenizer.eos_token_id
1834
+ return generation_config
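`default_generation_config` only fills in token ids that are missing, falling back to the tokenizer's EOS token where necessary, so callers normally pass just the decoding parameters. The same fallback logic in isolation, using a standalone tokenizer (the checkpoint name below is an illustrative stand-in):

```python
from transformers import AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")  # illustrative tokenizer

cfg = GenerationConfig(max_new_tokens=64, do_sample=False)
# Fall back to the tokenizer's ids, and ultimately to EOS, exactly as the property above does.
cfg.pad_token_id = cfg.pad_token_id if cfg.pad_token_id is not None else (tokenizer.pad_token_id or tokenizer.eos_token_id)
cfg.bos_token_id = cfg.bos_token_id if cfg.bos_token_id is not None else (tokenizer.bos_token_id or tokenizer.eos_token_id)
cfg.eos_token_id = cfg.eos_token_id if cfg.eos_token_id is not None else tokenizer.eos_token_id
print(cfg.pad_token_id, cfg.bos_token_id, cfg.eos_token_id)
```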
preprocessor_config.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "chunk_length": 30,
3
+ "feature_extractor_type": "WhisperFeatureExtractor",
4
+ "feature_size": 128,
5
+ "hop_length": 160,
6
+ "n_fft": 400,
7
+ "n_samples": 480000,
8
+ "nb_max_frames": 3000,
9
+ "padding_side": "right",
10
+ "padding_value": 0.0,
11
+ "return_attention_mask": true,
12
+ "sampling_rate": 16000
13
+ }
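This preprocessor config describes a Whisper-style feature extractor: 128 mel bins over 16 kHz audio with a 160-sample hop (10 ms per frame), so one 30 s chunk of 480 000 samples becomes 3000 frames. A short sketch constructing the same extractor directly in Transformers and featurizing a clip (the silent waveform is just a placeholder):

```python
import numpy as np
from transformers import WhisperFeatureExtractor

# Same values as preprocessor_config.json.
extractor = WhisperFeatureExtractor(
    feature_size=128, sampling_rate=16000, hop_length=160, chunk_length=30, n_fft=400
)

waveform = np.zeros(16000 * 5, dtype=np.float32)  # 5 s of silence at 16 kHz
features = extractor(waveform, sampling_rate=16000, return_attention_mask=True, return_tensors="pt")
print(features["input_features"].shape)  # (1, 128, 3000): padded to a full 30 s chunk
```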
pyproject.toml ADDED
@@ -0,0 +1,59 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "omnivinci"
7
+ version = "1.0.0"
8
+ description = "omnivinci"
9
+ readme = "README.md"
10
+ requires-python = ">=3.10"
11
+ classifiers = [
12
+ "Programming Language :: Python :: 3",
13
+ "License :: OSI Approved :: Apache Software License",
14
+ ]
15
+ dependencies = [
16
+ "torch==2.3.0", "torchvision==0.18.0",
17
+ "transformers==4.46.0", "tokenizers>=0.15.2", "sentencepiece==0.1.99", "shortuuid",
18
+ "accelerate==0.34.2", "peft>=0.9.0", "bitsandbytes==0.43.2",
19
+ "pydantic<2,>=1", "markdown2[all]", "numpy==1.26.4", "scikit-learn==1.2.2",
20
+ "gradio==3.35.2", "gradio_client==0.2.9",
21
+ "requests", "httpx", "uvicorn", "fastapi", "fire", "seaborn", "ring_flash_attn==0.1.1",
22
+ "einops==0.6.1", "einops-exts==0.0.4", "timm==0.9.12",
23
+ "openpyxl==3.1.2", "pytorchvideo==0.1.5", "decord==0.6.0",
24
+ "datasets==2.16.1", "openai==1.8.0", "webdataset==0.2.86",
25
+ "nltk==3.3", "pywsd==1.2.4", "opencv-python-headless==4.8.0.76",
26
+ "s2wrapper@git+https://github.com/bfshi/scaling_on_scales",
27
+ "tyro", "pytest", "pre-commit", "loguru", "hydra-core", "xgrammar"
28
+ ]
29
+
30
+ [project.scripts]
31
+
32
+ [project.optional-dependencies]
33
+ train = ["deepspeed==0.9.5", "ninja", "wandb"]
34
+ eval = ["word2number", "Levenshtein", "nltk", "pywsd"]
35
+
36
+ [project.urls]
37
+ "Homepage" = "https://github.com/NVlabs/OmniVinci"
38
+ "Bug Tracker" = "https://github.com/NVlabs/OmniVinci"
39
+
40
+ [tool.triton]
41
+ triton = {version = "3.0.0.post20240610003544", file = "https://aiinfra.pkgs.visualstudio.com/2692857e-05ef-43b4-ba9c-ccf1c22c437c/_packaging/07c94329-d4c3-4ad4-9e6b-f904a60032ec/pypi/download/triton-nightly/3.post20240610003544/triton_nightly-3.0.0.post20240610003544-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", sha256 = "ac2c36a49bf9c2bb780909b38096fb718f17efd78b88a1ca1d649f6d063cdc2c"}
42
+
43
+ [tool.black]
44
+ line-length = 120
45
+
46
+ [tool.isort]
47
+ profile = "black"
48
+ multi_line_output = 3
49
+ include_trailing_comma = true
50
+ force_grid_wrap = 0
51
+ use_parentheses = true
52
+ ensure_newline_before_comments = true
53
+ line_length = 120
54
+
55
+ [tool.setuptools.packages.find]
56
+ exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
57
+
58
+ [tool.wheel]
59
+ exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
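The dependency pins above are exact (for example `torch==2.3.0` and `transformers==4.46.0`), so mismatched environments are a common source of import-time or runtime errors. A small sketch that checks a few of the pins at runtime; the subset of packages verified here is just an illustration:

```python
from importlib.metadata import version, PackageNotFoundError

# Subset of the pins from pyproject.toml, chosen for illustration.
expected = {"torch": "2.3.0", "transformers": "4.46.0", "accelerate": "0.34.2"}

for name, want in expected.items():
    try:
        have = version(name)
    except PackageNotFoundError:
        print(f"{name}: not installed (expected {want})")
        continue
    status = "OK" if have == want else f"MISMATCH (expected {want})"
    print(f"{name}: {have} {status}")
```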
qwen2.jinja ADDED
@@ -0,0 +1,11 @@
1
+ {% if messages[0]['role'] != 'system' %}
2
+ {{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}
3
+ {% endif %}
4
+
5
+ {% for message in messages if message['content'] is not none %}
6
+ {{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}
7
+ {% endfor %}
8
+
9
+ {% if add_generation_prompt %}
10
+ {{ '<|im_start|>assistant\n' }}
11
+ {% endif %}
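qwen2.jinja is a Qwen2-style chat template: it injects a default system message when none is given, wraps each turn in `<|im_start|>role ... <|im_end|>` markers, and appends the assistant header when `add_generation_prompt` is set. A sketch of rendering it with a Transformers tokenizer; the tokenizer checkpoint here is an illustrative stand-in for any Qwen2-family tokenizer:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")  # illustrative tokenizer
chat_template = open("qwen2.jinja").read()

messages = [{"role": "user", "content": "Describe the image <image>"}]
text = tokenizer.apply_chat_template(
    messages, chat_template=chat_template, tokenize=False, add_generation_prompt=True
)
print(text)
# Prints the prompt wrapped in <|im_start|>/<|im_end|> markers, ending with the assistant header.
```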
qwen_audio_encoder.py ADDED
@@ -0,0 +1,89 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ from transformers import PretrainedConfig, Qwen2AudioEncoder, Qwen2AudioForConditionalGeneration
18
+
19
+ from .audio_encoder import AudioTower
20
+
21
+ class Qwen2AudioTower(AudioTower):
22
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig):
23
+ super().__init__(model_name_or_path, config)
24
+ self.audio_tower = Qwen2AudioEncoder.from_pretrained(model_name_or_path, attn_implementation="flash_attention_2")
25
+ self.is_loaded = True
26
+ self.audio_chunk_unit_duration = 30
27
+ self.audio_chunk_unit_length = 3000
28
+
29
+ def forward(self, sounds):
30
+ if type(sounds) is list:
31
+ sound_features = []
32
+ audio_output_lengths = []
33
+ for sound in sounds:
34
+ if hasattr(sound, "input_features") or (type(sound) is dict and "input_features" in sound):
35
+ sound = sound["input_features"]
36
+
37
+ sound_feature = self.forward_audio_tower_batch(sound)
38
+ sound_feature = sound_feature.to(sound.dtype)
39
+ sound_features.append(sound_feature)
40
+ audio_output_lengths.append(sound_feature.shape[1])
41
+ if len(sound_features) > 0:
42
+ sound_features = torch.cat(sound_features, dim=1).squeeze(0)
43
+ else:
44
+ raise NotImplementedError("Not implemented for this encoder")
45
+
46
+ return sound_features, audio_output_lengths
47
+
48
+
49
+ def forward_audio_tower_batch(self, inp):
50
+ """
51
+ Process long audio input by splitting into fixed-size chunks (30 seconds),
52
+ padding if needed, batching them together, and processing through the audio tower.
53
+
54
+ Args:
55
+ inp: Tensor of shape (batch_size, n_mels, seq_len)
56
+
57
+ Returns:
58
+ Tensor of shape (batch_size, num_chunks * chunk_seq_len, hidden_size)
59
+ """
60
+ batch_size, n_mels, seq_len = inp.shape
61
+ chunk_length = self.audio_chunk_unit_length
62
+ num_chunks = (seq_len + chunk_length - 1) // chunk_length # Ceiling division
63
+
64
+ padded_chunks = []
65
+
66
+ for i in range(num_chunks):
67
+ start_idx = i * chunk_length
68
+ end_idx = min(start_idx + chunk_length, seq_len)
69
+
70
+ # Extract and pad chunk if necessary
71
+ chunk = inp[:, :, start_idx:end_idx]
72
+ if chunk.shape[2] < chunk_length:
73
+ pad_len = chunk_length - chunk.shape[2]
74
+ chunk = torch.nn.functional.pad(chunk, (0, pad_len), mode='constant', value=0)
75
+
76
+ padded_chunks.append(chunk)
77
+
78
+ # Stack chunks along batch dimension
79
+ all_chunks = torch.cat(padded_chunks, dim=0).reshape(batch_size * num_chunks, n_mels, chunk_length)
80
+
81
+ # Forward pass through the audio tower
82
+ chunk_outputs = self.audio_tower(all_chunks)
83
+ hidden_states = chunk_outputs.last_hidden_state
84
+
85
+ # Reshape back to (batch_size, num_chunks * seq_len', hidden_size)
86
+ _, chunk_seq_len, hidden_size = hidden_states.shape
87
+ hidden_states = hidden_states.reshape(batch_size, num_chunks * chunk_seq_len, hidden_size)
88
+
89
+ return hidden_states
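`forward_audio_tower_batch` slices a long mel spectrogram into fixed 3000-frame (30 s) chunks, zero-pads the last chunk, and folds the chunks into the batch dimension so a single encoder pass covers the whole clip. The chunking arithmetic in isolation, with no encoder involved:

```python
import torch

chunk_length = 3000                      # 30 s at 100 mel frames per second
inp = torch.randn(1, 128, 7000)          # e.g. a 70 s clip -> 3 chunks
batch_size, n_mels, seq_len = inp.shape

num_chunks = (seq_len + chunk_length - 1) // chunk_length   # ceiling division -> 3
chunks = []
for i in range(num_chunks):
    chunk = inp[:, :, i * chunk_length : (i + 1) * chunk_length]
    if chunk.shape[2] < chunk_length:                        # zero-pad the tail chunk
        chunk = torch.nn.functional.pad(chunk, (0, chunk_length - chunk.shape[2]))
    chunks.append(chunk)

batched = torch.cat(chunks, dim=0)       # (batch_size * num_chunks, n_mels, chunk_length)
print(num_chunks, batched.shape)         # 3 torch.Size([3, 128, 3000])
```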
siglip_encoder.py ADDED
@@ -0,0 +1,293 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from accelerate.hooks import add_hook_to_module
21
+ from einops import rearrange
22
+ from s2wrapper import forward as multiscale_forward
23
+ from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, SiglipImageProcessor
24
+ from transformers.image_processing_utils import BaseImageProcessor
25
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
26
+ from transformers.models.siglip import SiglipVisionModel
27
+
28
+
29
+ class VisionTower(nn.Module):
30
+ def __init__(self, vision_tower, args, delay_load=False):
31
+ super().__init__()
32
+
33
+ self.is_loaded = False
34
+
35
+ self.vision_tower_name = vision_tower
36
+ self.select_layer = getattr(args, "mm_vision_select_layer", -2)
37
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
38
+
39
+ self.cfg_only = None
40
+
41
+ def feature_select(self, image_forward_outs):
42
+ image_features = image_forward_outs.hidden_states[self.select_layer]
43
+ if self.select_feature == "patch":
44
+ image_features = image_features[:, 1:]
45
+ elif self.select_feature == "cls_patch":
46
+ image_features = image_features
47
+ else:
48
+ raise ValueError(f"Unexpected select feature: {self.select_feature}")
49
+ return image_features
50
+
51
+ def _maybe_resize_pos_embeds(
52
+ self,
53
+ model: PreTrainedModel,
54
+ image_processor: BaseImageProcessor,
55
+ resolution: int = -1,
56
+ interpolate_mode: str = "linear",
57
+ ):
58
+ if resolution in [model.config.image_size, -1]:
59
+ return
60
+ print(
61
+ f"Resizing vision model's position embeddings to support higher vision resolution: from {model.config.image_size} to {resolution} ..."
62
+ )
63
+ embeddings = model.vision_model.embeddings
64
+ patch_size = embeddings.patch_size
65
+ num_new_tokens = int((resolution // patch_size) ** 2)
66
+
67
+ old_embeddings = embeddings.position_embedding
68
+ match interpolate_mode:
69
+ case "linear":
70
+ # Step 1: Calculate the corresponding patch ID (pid) in the current resolution (M patches) based on the target resolution (N patches). Formula: pid = pid / N * M
71
+ # Step 2: Obtain new embeddings by interpolating between the embeddings of the two nearest calculated patch IDs. Formula: new_embeds = (pid - floor(pid)) * embeds[ceil(pid)] + (ceil(pid) - pid) * embeds[floor(pid)]
72
+ import torch
73
+ import torch.nn as nn
74
+
75
+ if is_deepspeed_zero3_enabled():
76
+ try:
77
+ import deepspeed
78
+ except ImportError:
79
+ raise ImportError("DeepSpeed is not installed. Please install it with `pip install deepspeed`.")
80
+ with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
81
+ old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
82
+ else:
83
+ old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
84
+ new_embeddings = nn.Embedding(
85
+ num_new_tokens,
86
+ old_embedding_dim,
87
+ dtype=old_embeddings.weight.dtype,
88
+ device=old_embeddings.weight.device,
89
+ )
90
+ mapped_indices = (
91
+ torch.arange(num_new_tokens).to(old_embeddings.weight.device)
92
+ / (num_new_tokens - 1)
93
+ * (old_num_tokens - 1)
94
+ )
95
+ floor_indices = torch.clamp(mapped_indices.floor().long(), min=0, max=old_num_tokens - 1)
96
+ ceil_indices = torch.clamp(mapped_indices.ceil().long(), min=0, max=old_num_tokens - 1)
97
+ if is_deepspeed_zero3_enabled():
98
+ params = [old_embeddings.weight, new_embeddings.weight]
99
+ with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
100
+ interpolated_embeds = (mapped_indices - floor_indices)[:, None] * old_embeddings.weight.data[
101
+ ceil_indices, :
102
+ ] + (1 - (mapped_indices - floor_indices))[:, None] * old_embeddings.weight.data[floor_indices, :]  # weight 1 - frac keeps rows where ceil == floor
103
+ else:
104
+ interpolated_embeds = (mapped_indices - floor_indices)[:, None] * old_embeddings.weight.data[
105
+ ceil_indices, :
106
+ ] + (1 - (mapped_indices - floor_indices))[:, None] * old_embeddings.weight.data[floor_indices, :]  # weight 1 - frac keeps rows where ceil == floor
107
+ new_embeddings.weight.data = interpolated_embeds
108
+ case _:
109
+ raise NotImplementedError
110
+
111
+ if hasattr(old_embeddings, "_hf_hook"):
112
+ hook = old_embeddings._hf_hook
113
+ add_hook_to_module(new_embeddings, hook)
114
+ new_embeddings.requires_grad_(old_embeddings.weight.requires_grad)
115
+
116
+ # Update vision encoder's configurations
117
+ model.config.image_size = resolution
118
+ if hasattr(image_processor, "crop_size"):
119
+ # CLIP vision tower
120
+ image_processor.crop_size = resolution
121
+ else:
122
+ # SIGLIP vision tower
123
+ assert hasattr(image_processor, "size")
124
+ image_processor.size = {"height": resolution, "width": resolution}
125
+
126
+ embeddings.position_embedding = new_embeddings
127
+ embeddings.image_size = resolution
128
+ embeddings.num_patches = embeddings.num_positions = num_new_tokens
129
+ embeddings.position_ids = (
130
+ torch.arange(embeddings.num_positions).expand((1, -1)).to(old_embeddings.weight.device)
131
+ )
132
+
133
+ def forward(self, images):
134
+ if type(images) is list:
135
+ image_features = []
136
+ for image in images:
137
+ image_forward_out = self.vision_tower(
138
+ image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
139
+ output_hidden_states=True,
140
+ )
141
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
142
+ image_features.append(image_feature)
143
+ else:
144
+ image_forward_outs = self.vision_tower(
145
+ images.to(device=self.device, dtype=self.dtype),
146
+ output_hidden_states=True,
147
+ )
148
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
149
+
150
+ return image_features
151
+
152
+ @property
153
+ def dummy_feature(self):
154
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
155
+
156
+ @property
157
+ def dtype(self):
158
+ return self.vision_tower.dtype
159
+
160
+ @property
161
+ def device(self):
162
+ return self.vision_tower.device
163
+
164
+ @property
165
+ def config(self):
166
+ if self.is_loaded:
167
+ return self.vision_tower.config
168
+ else:
169
+ return self.cfg_only
170
+
171
+ @property
172
+ def hidden_size(self):
173
+ return self.config.hidden_size
174
+
175
+ @property
176
+ def num_patches(self):
177
+ return (self.config.image_size // self.config.patch_size) ** 2
178
+
179
+
180
+ class VisionTowerS2(VisionTower):
181
+ def __init__(self, vision_tower, args, delay_load=False):
182
+ super().__init__(vision_tower, args, delay_load)
183
+
184
+ self.scales = list(map(int, args.s2_scales.split(",")))
185
+ self.scales.sort()
186
+ self.max_split_size = args.s2_max_split_size
187
+ self.resize_output_to_scale_idx = getattr(args, "s2_resize_output_to_scale_idx", 0)
188
+
189
+ def forward_feature(self, images):
190
+ image_forward_outs = self.vision_tower(
191
+ images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
192
+ )
193
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
194
+ return image_features
195
+
196
+ def forward(self, images):
197
+ if type(images) is list:
198
+ image_features = []
199
+ for image in images:
200
+ image_feature = multiscale_forward(
201
+ self.forward_feature,
202
+ image.unsqueeze(0),
203
+ img_sizes=self.scales,
204
+ max_split_size=self.max_split_size,
205
+ resize_output_to_idx=self.resize_output_to_scale_idx,
206
+ )
207
+ image_features.append(image_feature)
208
+ else:
209
+ image_features = multiscale_forward(
210
+ self.forward_feature,
211
+ images,
212
+ img_sizes=self.scales,
213
+ max_split_size=self.max_split_size,
214
+ resize_output_to_idx=self.resize_output_to_scale_idx,
215
+ )
216
+
217
+ return image_features
218
+
219
+ @property
220
+ def hidden_size(self):
221
+ return self.config.hidden_size * len(self.scales)
222
+
223
+
224
+ class VisionTowerDynamicS2(VisionTower):
225
+ def __init__(self, vision_tower, args, delay_load=False):
226
+ super().__init__(vision_tower, args, delay_load)
227
+
228
+ self.scales = list(map(int, args.s2_scales.split(",")))
229
+ self.scales.sort()
230
+ self.max_split_size = args.s2_max_split_size
231
+ self.resize_output_to_scale_idx = getattr(args, "s2_resize_output_to_scale_idx", 0)
232
+
233
+ def forward_feature(self, images):
234
+ image_forward_outs = self.vision_tower(
235
+ images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
236
+ )
237
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
238
+ return image_features
239
+
240
+ def forward(self, images):
241
+ assert type(images) is not list
242
+ image_features = self.forward_feature(images)
243
+
244
+ return image_features
245
+
246
+ @property
247
+ def hidden_size(self):
248
+ return self.config.hidden_size * len(self.scales)
249
+
250
+
251
+ class SiglipVisionTower(VisionTower):
252
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
253
+ super().__init__(model_name_or_path, config)
254
+ self.vision_tower = SiglipVisionModel.from_pretrained(
255
+ model_name_or_path,
256
+ attn_implementation=config._attn_implementation,
257
+ torch_dtype=eval(config.model_dtype),
258
+ )
259
+ self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
260
+ self.is_loaded = True
261
+
262
+
263
+ class SiglipVisionTowerS2(VisionTowerS2):
264
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
265
+ super().__init__(model_name_or_path, config)
266
+ self.vision_tower = SiglipVisionModel.from_pretrained(
267
+ model_name_or_path,
268
+ attn_implementation=config._attn_implementation,
269
+ torch_dtype=eval(config.model_dtype),
270
+ )
271
+ self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
272
+ # Make sure it crops/resizes the image to the largest scale in self.scales to maintain high-res information
273
+ self.image_processor.size["height"] = self.image_processor.size["width"] = self.scales[-1]
274
+ self.is_loaded = True
275
+
276
+
277
+ class SiglipVisionTowerDynamicS2(VisionTowerDynamicS2):
278
+ def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
279
+ super().__init__(model_name_or_path, config)
280
+ if type(config.model_dtype) == str:
281
+ model_dtype = eval(config.model_dtype)
282
+ else:
283
+ model_dtype = config.model_dtype
284
+
285
+ self.vision_tower = SiglipVisionModel.from_pretrained(
286
+ model_name_or_path,
287
+ attn_implementation="flash_attention_2",
288
+ torch_dtype=model_dtype,
289
+ )
290
+ self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
291
+ # Set the processor's crop/resize size to the base scale in self.scales; dynamic S2 constructs the higher-resolution views by tiling
292
+ self.image_processor.size["height"] = self.image_processor.size["width"] = self.scales[0]
293
+ self.is_loaded = True
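`_maybe_resize_pos_embeds` maps each new patch position into the old position range and mixes the two nearest old embedding rows. The core interpolation on a toy table, with `1 - frac` as the weight on the floor row so positions that map exactly onto an old index copy that row:

```python
import torch

old = torch.arange(5, dtype=torch.float).unsqueeze(1).repeat(1, 4)   # 5 old positions, dim 4
old_num, new_num = old.shape[0], 9

mapped = torch.arange(new_num, dtype=torch.float) / (new_num - 1) * (old_num - 1)
floor_idx = mapped.floor().long().clamp(0, old_num - 1)
ceil_idx = mapped.ceil().long().clamp(0, old_num - 1)

frac = (mapped - floor_idx)[:, None]
new = frac * old[ceil_idx] + (1 - frac) * old[floor_idx]
print(new[:, 0])  # tensor([0.0, 0.5, 1.0, ..., 4.0]): positions spread evenly over the old range
```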
sound_base_projector.py ADDED
@@ -0,0 +1,126 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ from einops import rearrange
20
+ from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
21
+
22
+ class SoundMultimodalProjectorConfig(PretrainedConfig):
23
+ """Configuration for sound multimodal projector."""
24
+
25
+ model_type = "sound_mm_projector"
26
+
27
+ def __init__(self, sound_mm_projector_type: str = None, **kwargs):
28
+ super().__init__()
29
+ self.sound_mm_projector_type = sound_mm_projector_type
30
+
31
+
32
+ class AudioDownSampleBlock(nn.Module):
33
+ """Downsample audio features using 1D convolution."""
34
+ def __init__(self, embed_dim):
35
+ super().__init__()
36
+ self.conv1 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
37
+
38
+ def forward(self, x):
39
+ x = rearrange(x, "b t c -> b c t")
40
+ x = self.conv1(x)
41
+ x = rearrange(x, "b c t -> b t c")
42
+ return x
43
+
44
+ class AudioDownSamplePoolBlock(nn.Module):
45
+ """Downsample audio features using average pooling."""
46
+
47
+ def __init__(self, embed_dim):
48
+ super().__init__()
49
+ self.pool = nn.AvgPool1d(kernel_size=2)
50
+
51
+ def forward(self, x):
52
+ x = rearrange(x, "b t c -> b c t")
53
+ x = self.pool(x)
54
+ x = rearrange(x, "b c t -> b t c")
55
+ return x
56
+
57
+
58
+ class AudioDownSampleMaxPoolBlock(nn.Module):
59
+ """Downsample audio features using max pooling."""
60
+
61
+ def __init__(self, embed_dim):
62
+ super().__init__()
63
+ self.pool = nn.MaxPool1d(kernel_size=2)
64
+
65
+ def forward(self, x):
66
+ x = rearrange(x, "b t c -> b c t")
67
+ x = self.pool(x)
68
+ x = rearrange(x, "b c t -> b t c")
69
+ return x
70
+
71
+
72
+ class SoundMultimodalProjector(PreTrainedModel):
73
+ """Sound multimodal projector for mapping audio features to LLM space."""
74
+ config_class = SoundMultimodalProjectorConfig
75
+
76
+ def __init__(self, sound_mm_projector_cfg: SoundMultimodalProjectorConfig, config: PretrainedConfig):
77
+ super().__init__(sound_mm_projector_cfg)
78
+ if hasattr(config, "sound_mm_projector"):
79
+ sound_mm_projector_type = config.sound_mm_projector
80
+ else:
81
+ sound_mm_projector_type = sound_mm_projector_cfg.sound_mm_projector_type
82
+ self.sound_mm_projector_type = sound_mm_projector_type
83
+ self.config.sound_mm_projector_type = sound_mm_projector_type
84
+
85
+ if hasattr(config, "sound_mm_projector_cfg") and type(config.sound_mm_projector_cfg) == dict:
86
+ config.sound_mm_projector_cfg["sound_mm_projector_type"] = sound_mm_projector_type
87
+
88
+ if sound_mm_projector_type == "mlp":
89
+ self.layers = nn.Sequential(
90
+ nn.Linear(config.sound_hidden_size, config.hidden_size),
91
+ nn.GELU(),
92
+ nn.Linear(config.hidden_size, config.hidden_size),
93
+ )
94
+ elif sound_mm_projector_type == "mlp_downsample":
95
+ self.downsample_block = AudioDownSampleBlock(config.sound_hidden_size)
96
+ self.layers = nn.Sequential(
97
+ nn.Linear(config.sound_hidden_size, config.hidden_size),
98
+ nn.GELU(),
99
+ nn.Linear(config.hidden_size, config.hidden_size),
100
+ )
101
+ elif sound_mm_projector_type == "mlp_downsample_pool":
102
+ self.downsample_block = AudioDownSamplePoolBlock(config.sound_hidden_size)
103
+ self.layers = nn.Sequential(
104
+ nn.Linear(config.sound_hidden_size, config.hidden_size),
105
+ nn.GELU(),
106
+ nn.Linear(config.hidden_size, config.hidden_size),
107
+ )
108
+ elif sound_mm_projector_type == "mlp_downsample_pool_max":
109
+ self.downsample_block = AudioDownSampleMaxPoolBlock(config.sound_hidden_size)
110
+ self.layers = nn.Sequential(
111
+ nn.Linear(config.sound_hidden_size, config.hidden_size),
112
+ nn.GELU(),
113
+ nn.Linear(config.hidden_size, config.hidden_size),
114
+ )
115
+ else:
116
+ raise ValueError(f"Unknown projector type: {sound_mm_projector_type}")
117
+
118
+
119
+ def forward(self, x, *args, **kwargs):
120
+ if self.sound_mm_projector_type in ["mlp_downsample", "mlp_downsample_pool", "mlp_downsample_pool_max"]:
121
+ x = self.downsample_block(x)
122
+ return self.layers(x)
123
+
124
+
125
+ AutoConfig.register("sound_mm_projector", SoundMultimodalProjectorConfig)
126
+ AutoModel.register(SoundMultimodalProjectorConfig, SoundMultimodalProjector)
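The `mlp_downsample*` projector variants first halve the temporal length of the audio features (a stride-2 Conv1d, or average/max pooling with window 2) and then map from `sound_hidden_size` to the LLM's `hidden_size` with a two-layer MLP. A standalone sketch of the shape bookkeeping; the hidden sizes below are illustrative (1280 matches the sound tower's `d_model`, the LLM width is a placeholder):

```python
import torch
import torch.nn as nn
from einops import rearrange

sound_hidden_size, llm_hidden_size = 1280, 3584   # illustrative sizes

downsample = nn.Conv1d(sound_hidden_size, sound_hidden_size, kernel_size=3, stride=2, padding=1)
mlp = nn.Sequential(
    nn.Linear(sound_hidden_size, llm_hidden_size),
    nn.GELU(),
    nn.Linear(llm_hidden_size, llm_hidden_size),
)

x = torch.randn(1, 750, sound_hidden_size)   # (batch, time, channels) audio features
x = rearrange(x, "b t c -> b c t")           # Conv1d expects channels-first
x = downsample(x)                            # stride 2 halves the time axis: 750 -> 375
x = rearrange(x, "b c t -> b t c")
x = mlp(x)
print(x.shape)                               # torch.Size([1, 375, 3584])
```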
sound_mm_projector/config.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "_name_or_path": "/lustre/fs12/portfolios/llmservice/projects/llmservice_fm_vision/users/hanrongy/project/vila/VILA-Internal/../exp_log/nvomni-8b-video-0d1-trope128_omniT_ras_n16_bs2048_ga8_mstep-1_j20250718/outputs/model/sound_mm_projector",
3
+ "architectures": [
4
+ "SoundMultimodalProjector"
5
+ ],
6
+ "model_type": "sound_mm_projector",
7
+ "sound_mm_projector_type": "mlp",
8
+ "torch_dtype": "bfloat16",
9
+ "transformers_version": "4.46.0"
10
+ }
sound_mm_projector/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb57ebfdeb51af4a1c0de931fd43e6a4b93277552ad02ad01b1d9ba720bcb9a4
3
+ size 34879856
sound_tower/config.json ADDED
@@ -0,0 +1,50 @@
1
+ {
2
+ "_name_or_path": "outputs/model/sound_tower",
3
+ "activation_dropout": 0.0,
4
+ "activation_function": "gelu",
5
+ "architectures": [
6
+ "Qwen2AudioEncoder"
7
+ ],
8
+ "attention_dropout": 0.0,
9
+ "audio_config": {
10
+ "activation_function": "gelu",
11
+ "d_model": 1280,
12
+ "encoder_attention_heads": 20,
13
+ "encoder_ffn_dim": 5120,
14
+ "encoder_layers": 32,
15
+ "max_source_positions": 1500,
16
+ "model_type": "qwen2_audio_encoder",
17
+ "num_mel_bins": 128,
18
+ "scale_embedding": false
19
+ },
20
+ "audio_token_index": 151646,
21
+ "d_model": 1280,
22
+ "dropout": 0.0,
23
+ "encoder_attention_heads": 20,
24
+ "encoder_ffn_dim": 5120,
25
+ "encoder_layerdrop": 0.0,
26
+ "encoder_layers": 32,
27
+ "ignore_index": -100,
28
+ "init_std": 0.02,
29
+ "max_source_positions": 1500,
30
+ "model_type": "qwen2_audio_encoder",
31
+ "num_hidden_layers": 32,
32
+ "num_mel_bins": 128,
33
+ "scale_embedding": false,
34
+ "text_config": {
35
+ "bos_token_id": 151643,
36
+ "eos_token_id": 151645,
37
+ "intermediate_size": 11008,
38
+ "max_position_embeddings": 8192,
39
+ "model_type": "qwen2",
40
+ "rms_norm_eps": 1e-05,
41
+ "rope_theta": 10000,
42
+ "sliding_window": 32768,
43
+ "torch_dtype": "bfloat16",
44
+ "use_mrope": false,
45
+ "vocab_size": 156032
46
+ },
47
+ "torch_dtype": "bfloat16",
48
+ "transformers_version": "4.46.0",
49
+ "vocab_size": 156032
50
+ }
sound_tower/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:01fcbfa8ac3d3bc4c5ab97c439dfecfea2a9c2e061031280efed292fc37b4a44
3
+ size 1273988176