Use get_input_embeddings() Instead of Accessing .embed_tokens Directly

#8
1_Pooling/config.json DELETED
@@ -1,10 +0,0 @@
1
- {
2
- "word_embedding_dimension": 3584,
3
- "pooling_mode_cls_token": false,
4
- "pooling_mode_mean_tokens": false,
5
- "pooling_mode_max_tokens": false,
6
- "pooling_mode_mean_sqrt_len_tokens": false,
7
- "pooling_mode_weightedmean_tokens": false,
8
- "pooling_mode_lasttoken": true,
9
- "include_prompt": true
10
- }
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -3691,110 +3691,61 @@ The `GME` models support three types of input: **text**, **image**, and **image-
3691
  |[`gme-Qwen2-VL-2B`](https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct) | 2.21B | 32768 | 1536 | 65.27 | 68.41 | 64.45 |
3692
  |[`gme-Qwen2-VL-7B`](https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-7B-Instruct) | 8.29B | 32768 | 3584 | 67.48 | 71.36 | 67.44 |
3693
 
3694
-
3695
-
3696
  ## Usage
3697
-
3698
-
3699
- **Transformers**
3700
-
3701
- The remote code has some issues with `transformers>=4.52.0`, please downgrade or use `sentence_transformers`
3702
 
3703
  ```python
3704
- from transformers import AutoModel
3705
- from transformers.utils.versions import require_version
3706
-
3707
-
3708
- require_version(
3709
- "transformers<4.52.0",
3710
- "The remote code has some issues with transformers>=4.52.0, please downgrade: pip install transformers==4.51.3"
3711
- )
3712
 
 
3713
 
3714
- t2i_prompt = 'Find an image that matches the given text.'
3715
  texts = [
3716
- "The Tesla Cybertruck is a battery electric pickup truck built by Tesla, Inc. since 2023.",
3717
- "Alibaba office.",
3718
  ]
3719
  images = [
3720
- 'https://upload.wikimedia.org/wikipedia/commons/e/e9/Tesla_Cybertruck_damaged_window.jpg',
3721
- 'https://upload.wikimedia.org/wikipedia/commons/e/e0/TaobaoCity_Alibaba_Xixi_Park.jpg',
3722
  ]
3723
 
3724
-
3725
- gme = AutoModel.from_pretrained(
3726
- "Alibaba-NLP/gme-Qwen2-VL-7B-Instruct",
3727
- torch_dtype="float16", device_map='cuda', trust_remote_code=True
3728
- )
3729
-
3730
-
3731
  # Single-modal embedding
3732
  e_text = gme.get_text_embeddings(texts=texts)
3733
  e_image = gme.get_image_embeddings(images=images)
3734
- print('Single-modal', (e_text @ e_image.T).tolist())
3735
- ## Single-modal [[0.279296875, 0.0002658367156982422], [0.06427001953125, 0.304443359375]]
3736
 
3737
  # How to set embedding instruction
3738
- e_query = gme.get_text_embeddings(texts=texts, instruction=t2i_prompt)
3739
  # If is_query=False, we always use the default instruction.
3740
  e_corpus = gme.get_image_embeddings(images=images, is_query=False)
3741
- print('Single-modal with instruction', (e_query @ e_corpus.T).tolist())
3742
- ## Single-modal with instruction [[0.32861328125, 0.026336669921875], [0.09466552734375, 0.3134765625]]
3743
 
3744
  # Fused-modal embedding
3745
  e_fused = gme.get_fused_embeddings(texts=texts, images=images)
3746
- print('Fused-modal', (e_fused @ e_fused.T).tolist())
3747
- ## Fused-modal [[1.0, 0.0308685302734375], [0.0308685302734375, 1.0]]
3748
- ```
3749
 
 
3750
 
3751
- **sentence_transformers**
3752
-
3753
- The `encode` function accept `str` or `dict` with key(s) in `{'text', 'image', 'prompt'}`.
3754
-
3755
- **Do not pass `prompt` as the argument to `encode`**, pass as the input as a `dict` with a `prompt` key.
3756
 
3757
  ```python
3758
- from sentence_transformers import SentenceTransformer
3759
 
 
3760
 
3761
- t2i_prompt = 'Find an image that matches the given text.'
3762
- texts = [
3763
- "The Tesla Cybertruck is a battery electric pickup truck built by Tesla, Inc. since 2023.",
3764
- "Alibaba office.",
3765
- ]
3766
- images = [
3767
- 'https://upload.wikimedia.org/wikipedia/commons/e/e9/Tesla_Cybertruck_damaged_window.jpg',
3768
- 'https://upload.wikimedia.org/wikipedia/commons/e/e0/TaobaoCity_Alibaba_Xixi_Park.jpg',
3769
- ]
3770
-
3771
-
3772
- gme_st = SentenceTransformer("Alibaba-NLP/gme-Qwen2-VL-7B-Instruct")
3773
-
3774
- # Single-modal embedding
3775
- e_text = gme_st.encode(texts, convert_to_tensor=True)
3776
- e_image = gme_st.encode([dict(image=i) for i in images], convert_to_tensor=True)
3777
- print('Single-modal', (e_text @ e_image.T).tolist())
3778
- ## Single-modal [[0.27880859375, 0.0005745887756347656], [0.06500244140625, 0.306640625]]
3779
-
3780
- # How to set embedding instruction
3781
- e_query = gme_st.encode([dict(text=t, prompt=t2i_prompt) for t in texts], convert_to_tensor=True)
3782
- # If no prompt, we always use the default instruction.
3783
- e_corpus = gme_st.encode([dict(image=i) for i in images], convert_to_tensor=True)
3784
- print('Single-modal with instruction', (e_query @ e_corpus.T).tolist())
3785
- ## Single-modal with instruction [[0.328369140625, 0.0269927978515625], [0.09521484375, 0.316162109375]]
3786
-
3787
- # Fused-modal embedding
3788
- e_fused = gme_st.encode([dict(text=t, image=i) for t, i in zip(texts, images)], convert_to_tensor=True)
3789
- print('Fused-modal', (e_fused @ e_fused.T).tolist())
3790
- ## Fused-modal [[0.99951171875, 0.0311737060546875], [0.0311737060546875, 1.0009765625]]
3791
  ```
3792
 
3793
-
 
3794
 
3795
  ## Evaluation
3796
 
3797
- We validated the performance on our universal multimodal retrieval benchmark (**UMRB**, see [Release UMRB](https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-7B-Instruct/discussions/2)) among others.
3798
 
3799
  | | | Single-modal | | Cross-modal | | | Fused-modal | | | | Avg. |
3800
  |--------------------|------|:------------:|:---------:|:-----------:|:-----------:|:---------:|:-----------:|:----------:|:----------:|:-----------:|:----------:|
 
3691
  |[`gme-Qwen2-VL-2B`](https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct) | 2.21B | 32768 | 1536 | 65.27 | 68.41 | 64.45 |
3692
  |[`gme-Qwen2-VL-7B`](https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-7B-Instruct) | 8.29B | 32768 | 3584 | 67.48 | 71.36 | 67.44 |
3693
 
 
 
3694
  ## Usage
3695
+ **Use with custom code**
 
 
 
 
3696
 
3697
  ```python
3698
+ # You can find the script gme_inference.py in https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct/blob/main/gme_inference.py
3699
+ from gme_inference import GmeQwen2VL
 
 
 
 
 
 
3700
 
3701
+ model = GmeQwen2VL('Alibaba-NLP/gme-Qwen2-VL-7B-Instruct')
3702
 
 
3703
  texts = [
3704
+ "What kind of car is this?",
3705
+ "The Tesla Cybertruck is a battery electric pickup truck built by Tesla, Inc. since 2023."
3706
  ]
3707
  images = [
3708
+ 'https://en.wikipedia.org/wiki/File:Tesla_Cybertruck_damaged_window.jpg',
3709
+ 'https://en.wikipedia.org/wiki/File:2024_Tesla_Cybertruck_Foundation_Series,_front_left_(Greenwich).jpg',
3710
  ]
3711
 
 
 
 
 
 
 
 
3712
  # Single-modal embedding
3713
  e_text = gme.get_text_embeddings(texts=texts)
3714
  e_image = gme.get_image_embeddings(images=images)
3715
+ print((e_text * e_image).sum(-1))
3716
+ ## tensor([0.1702, 0.5278], dtype=torch.float16)
3717
 
3718
  # How to set embedding instruction
3719
+ e_query = gme.get_text_embeddings(texts=texts, instruction='Find an image that matches the given text.')
3720
  # If is_query=False, we always use the default instruction.
3721
  e_corpus = gme.get_image_embeddings(images=images, is_query=False)
3722
+ print((e_query * e_corpus).sum(-1))
3723
+ ## tensor([0.2000, 0.5752], dtype=torch.float16)
3724
 
3725
  # Fused-modal embedding
3726
  e_fused = gme.get_fused_embeddings(texts=texts, images=images)
3727
+ print((e_fused[0] * e_fused[1]).sum())
3728
+ ## tensor(0.6826, dtype=torch.float16)
 
3729
 
3730
+ ```
3731
 
3732
+ <!-- <details>
3733
+ <summary>With transformers</summary>
 
 
 
3734
 
3735
  ```python
3736
+ # Requires transformers>=4.46.2
3737
 
3738
+ TODO
3739
 
3740
+ # [[0.3016996383666992, 0.7503870129585266, 0.3203084468841553]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3741
  ```
3742
 
3743
+ </details>
3744
+ -->
3745
 
3746
  ## Evaluation
3747
 
3748
+ We validated the performance on our universal multimodal retrieval benchmark (**UMRB**) among others.
3749
 
3750
  | | | Single-modal | | Cross-modal | | | Fused-modal | | | | Avg. |
3751
  |--------------------|------|:------------:|:---------:|:-----------:|:-----------:|:---------:|:-----------:|:----------:|:----------:|:-----------:|:----------:|
config.json CHANGED
@@ -1,13 +1,8 @@
1
  {
2
- "_name_or_path": "Alibaba-NLP/gme-Qwen2-VL-7B-Instruct",
3
  "architectures": [
4
- "Qwen2VLForConditionalGeneration",
5
- "GmeQwen2VL"
6
  ],
7
- "auto_map": {
8
- "AutoConfig": "modeling_gme_qwen2vl.GmeQwen2VLConfig",
9
- "AutoModel": "modeling_gme_qwen2vl.GmeQwen2VL"
10
- },
11
  "attention_dropout": 0.0,
12
  "bos_token_id": 151643,
13
  "eos_token_id": 151645,
 
1
  {
2
+ "_name_or_path": "gme-Qwen2-VL-7B-Instruct",
3
  "architectures": [
4
+ "Qwen2VLForConditionalGeneration"
 
5
  ],
 
 
 
 
6
  "attention_dropout": 0.0,
7
  "bos_token_id": 151643,
8
  "eos_token_id": 151645,
config_sentence_transformers.json DELETED
@@ -1,7 +0,0 @@
1
- {
2
- "prompts": {
3
- "query": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
4
- },
5
- "default_prompt_name": null,
6
- "similarity_fn_name": null
7
- }
 
 
 
 
 
 
 
 
custom_st.py DELETED
@@ -1,221 +0,0 @@
1
- from io import BytesIO
2
- from typing import Any, Dict, Optional, List
3
- import torch
4
- from PIL import Image
5
- from sentence_transformers.models import Transformer as BaseTransformer
6
- from transformers import AutoModelForVision2Seq, AutoProcessor
7
-
8
-
9
- class MultiModalTransformer(BaseTransformer):
10
- def __init__(
11
- self,
12
- model_name_or_path: str,
13
- cache_dir: Optional[str] = None,
14
- tokenizer_args: Optional[Dict[str, Any]] = None,
15
- min_image_tokens: int = 256,
16
- max_image_tokens: int = 1280,
17
- max_length: int = 1800,
18
- **kwargs,
19
- ):
20
- super().__init__(model_name_or_path, **kwargs)
21
- if tokenizer_args is None:
22
- tokenizer_args = {}
23
- tokenizer_args.pop("trust_remote_code", None)
24
-
25
- # Initialize processor
26
- min_pixels = min_image_tokens * 28 * 28
27
- max_pixels = max_image_tokens * 28 * 28
28
- self.processor = AutoProcessor.from_pretrained(
29
- model_name_or_path, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs
30
- )
31
- self.processor.tokenizer.padding_side = 'right'
32
- self.sep = ' '
33
- self.max_length = max_length
34
- self.normalize = True
35
-
36
- def _load_model(
37
- self,
38
- model_name_or_path: str,
39
- config,
40
- cache_dir: str,
41
- backend: str,
42
- is_peft_model: bool,
43
- **model_args,
44
- ) -> None:
45
- model_args.pop("trust_remote_code", None)
46
- self.auto_model = AutoModelForVision2Seq.from_pretrained(
47
- model_name_or_path, torch_dtype=torch.float16, **model_args
48
- )
49
-
50
- def forward(
51
- self, features: Dict[str, torch.Tensor], **kwargs
52
- ) -> Dict[str, torch.Tensor]:
53
- if features.get("inputs_embeds", None) is None:
54
- features["inputs_embeds"] = self.auto_model.base_model.get_input_embeddings()(features["input_ids"])
55
- if features.get("pixel_values", None) is not None:
56
- features["pixel_values"] = features["pixel_values"].type(self.auto_model.visual.get_dtype())
57
- image_embeds = self.auto_model.visual(
58
- features["pixel_values"], grid_thw=features["image_grid_thw"]
59
- )
60
- image_mask = features["input_ids"] == self.auto_model.config.image_token_id
61
- features["inputs_embeds"][image_mask] = image_embeds
62
- # features.pop("pixel_values")
63
- # features.pop("image_grid_thw")
64
- # features.pop("input_ids")
65
- inputs = {k: v for k, v in features.items() if k in 'position_ids,attention_mask,inputs_embeds'}
66
- outputs = self.auto_model.model(
67
- **inputs,
68
- return_dict=True,
69
- output_hidden_states=True,
70
- # **kwargs
71
- )
72
- # pooling_mask = features["attention_mask"] if features.get("pooling_mask", None) is None else features["pooling_mask"]
73
- # left_padding = (pooling_mask[:, -1].sum() == pooling_mask.shape[0]) # TODO
74
- # if left_padding:
75
- # embeddings = outputs.last_hidden_state
76
- # else:
77
- # sequence_lengths = pooling_mask.sum(dim=1) - 1
78
- # embeddings = outputs.last_hidden_state[torch.arange(
79
- # outputs.last_hidden_state.shape[0], device=outputs.last_hidden_state.device
80
- # ), sequence_lengths]
81
- features.update({"token_embeddings": outputs.last_hidden_state})
82
- return features
83
-
84
- def tokenize(self, texts: List[List[Dict[str, Any]]] | List[str]) -> Dict[str, torch.Tensor]:
85
- default_instruction = 'You are a helpful assistant.'
86
-
87
- all_texts, all_images = list(), list()
88
- for item in texts:
89
- if isinstance(item, str):
90
- txt, img, inst = item, None, default_instruction
91
- elif isinstance(item, dict):
92
- txt = item.get('text', None)
93
- img = item.get('image', None)
94
- inst = item.get('prompt', default_instruction)
95
- else:
96
- raise RuntimeError(f'Input format not supported! {item=}')
97
-
98
- input_str = ''
99
- if img is None:
100
- all_images = None # All examples in the same batch are consistent
101
- # or will have ValueError: Could not make a flat list of images from xxxx
102
- else:
103
- input_str += '<|vision_start|><|image_pad|><|vision_end|>'
104
- img = fetch_image(img)
105
- all_images.append(img)
106
- if txt is not None:
107
- input_str += txt
108
- msg = f'<|im_start|>system\n{inst}<|im_end|>\n<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>'
109
- all_texts.append(msg)
110
-
111
- inputs = self.processor(
112
- text=all_texts,
113
- images=all_images,
114
- padding="longest",
115
- truncation=True,
116
- max_length=self.max_seq_length,
117
- return_tensors='pt'
118
- )
119
- return inputs
120
-
121
-
122
- ### Copied from qwen_vl_utils.vision_process.py
123
- import base64
124
- from io import BytesIO
125
- import requests
126
-
127
- IMAGE_FACTOR = 28
128
- MIN_PIXELS = 4 * 28 * 28
129
- MAX_PIXELS = 16384 * 28 * 28
130
- MAX_RATIO = 200
131
-
132
-
133
- def round_by_factor(number: int, factor: int) -> int:
134
- """Returns the closest integer to 'number' that is divisible by 'factor'."""
135
- return round(number / factor) * factor
136
-
137
-
138
- def ceil_by_factor(number: int, factor: int) -> int:
139
- """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
140
- return math.ceil(number / factor) * factor
141
-
142
-
143
- def floor_by_factor(number: int, factor: int) -> int:
144
- """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
145
- return math.floor(number / factor) * factor
146
-
147
-
148
- def smart_resize(
149
- height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
150
- ) -> tuple[int, int]:
151
- """
152
- Rescales the image so that the following conditions are met:
153
-
154
- 1. Both dimensions (height and width) are divisible by 'factor'.
155
-
156
- 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
157
-
158
- 3. The aspect ratio of the image is maintained as closely as possible.
159
- """
160
- h_bar = max(factor, round_by_factor(height, factor))
161
- w_bar = max(factor, round_by_factor(width, factor))
162
- if h_bar * w_bar > max_pixels:
163
- beta = math.sqrt((height * width) / max_pixels)
164
- h_bar = floor_by_factor(height / beta, factor)
165
- w_bar = floor_by_factor(width / beta, factor)
166
- elif h_bar * w_bar < min_pixels:
167
- beta = math.sqrt(min_pixels / (height * width))
168
- h_bar = ceil_by_factor(height * beta, factor)
169
- w_bar = ceil_by_factor(width * beta, factor)
170
-
171
- if max(h_bar, w_bar) / min(h_bar, w_bar) > MAX_RATIO:
172
- logging.warning(
173
- f"Absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(h_bar, w_bar) / min(h_bar, w_bar)}"
174
- )
175
- if h_bar > w_bar:
176
- h_bar = w_bar * MAX_RATIO
177
- else:
178
- w_bar = h_bar * MAX_RATIO
179
- return h_bar, w_bar
180
-
181
-
182
- def fetch_image(image: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> Image.Image:
183
- image_obj = None
184
- if isinstance(image, Image.Image):
185
- image_obj = image
186
- elif image.startswith("http://") or image.startswith("https://"):
187
- image_obj = Image.open(requests.get(image, stream=True).raw)
188
- elif image.startswith("file://"):
189
- image_obj = Image.open(image[7:])
190
- elif image.startswith("data:image"):
191
- if "base64," in image:
192
- _, base64_data = image.split("base64,", 1)
193
- data = base64.b64decode(base64_data)
194
- image_obj = Image.open(BytesIO(data))
195
- else:
196
- image_obj = Image.open(image)
197
- if image_obj is None:
198
- raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
199
- image = image_obj.convert("RGB")
200
- ## resize
201
- # if "resized_height" in ele and "resized_width" in ele:
202
- # resized_height, resized_width = smart_resize(
203
- # ele["resized_height"],
204
- # ele["resized_width"],
205
- # factor=size_factor,
206
- # )
207
- # else:
208
- width, height = image.size
209
- # min_pixels = ele.get("min_pixels", MIN_PIXELS)
210
- # max_pixels = ele.get("max_pixels", MAX_PIXELS)
211
- resized_height, resized_width = smart_resize(
212
- height,
213
- width,
214
- factor=size_factor,
215
- min_pixels=MIN_PIXELS,
216
- max_pixels=MAX_PIXELS,
217
- )
218
- image = image.resize((resized_width, resized_height))
219
-
220
- return image
221
- ###
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modeling_gme_qwen2vl.py DELETED
@@ -1,337 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import base64
4
- import logging
5
- import math
6
- import os
7
- from io import BytesIO
8
- from typing import Any, Dict, List, Optional, Union
9
-
10
- import requests
11
- import torch
12
- from PIL import Image
13
- from torch.utils.data import DataLoader
14
- from tqdm.autonotebook import tqdm
15
- from transformers import AutoProcessor, PreTrainedModel
16
- from transformers.models.qwen2_vl.modeling_qwen2_vl import (
17
- Qwen2VisionTransformerPretrainedModel,
18
- Qwen2VLConfig,
19
- Qwen2VLForConditionalGeneration,
20
- Qwen2VLModel,
21
- )
22
- from transformers.utils.versions import require_version
23
-
24
-
25
- require_version(
26
- "transformers<4.52.0",
27
- "This code has some issues with transformers>=4.52.0, please downgrade: pip install transformers==4.51.3"
28
- )
29
-
30
-
31
- class GmeQwen2VLConfig(Qwen2VLConfig):
32
- # model_type = ''
33
-
34
- def __init__(
35
- self,
36
- min_image_tokens: int = 256,
37
- max_image_tokens: int = 1280,
38
- max_length: int = 1800,
39
- **kwargs: Any,
40
- ) -> None:
41
- super().__init__(**kwargs)
42
- self.min_image_tokens = min_image_tokens
43
- self.max_image_tokens = max_image_tokens
44
- self.max_length = max_length
45
-
46
-
47
- class GmeQwen2VL(PreTrainedModel):
48
- config_class = GmeQwen2VLConfig
49
- base_model_prefix = "model"
50
- supports_gradient_checkpointing = True
51
- _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"]
52
- # _skip_keys_device_placement = "past_key_values"
53
- _supports_flash_attn_2 = True
54
- _supports_sdpa = True
55
- # _supports_cache_class = True
56
- _supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions`
57
- # _tied_weights_keys = ["lm_head.weight"]
58
-
59
- def __init__(self, config: GmeQwen2VLConfig, **kwargs: Any) -> None:
60
- super().__init__(config)
61
- self.visual = Qwen2VisionTransformerPretrainedModel._from_config(config.vision_config)
62
- self.model = Qwen2VLModel(config)
63
- self.vocab_size = config.vocab_size
64
- # self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
65
- self.rope_deltas = None # cache rope_deltas here
66
-
67
- min_pixels: int = config.min_image_tokens * 28 * 28
68
- max_pixels: int = config.max_image_tokens * 28 * 28
69
- self.processor = AutoProcessor.from_pretrained(
70
- config._name_or_path, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs
71
- )
72
- self.max_length: int = config.max_length
73
- self.normalize: bool = True
74
- self.processor.tokenizer.padding_side = "right"
75
- self.default_instruction: str = "You are a helpful assistant."
76
- self.sep: str = " "
77
-
78
- # Initialize weights and apply final processing
79
- self.post_init()
80
-
81
- def forward(
82
- self,
83
- input_ids: Optional[torch.LongTensor] = None,
84
- attention_mask: Optional[torch.Tensor] = None,
85
- position_ids: Optional[torch.LongTensor] = None,
86
- past_key_values: Optional[List[torch.FloatTensor]] = None,
87
- inputs_embeds: Optional[torch.FloatTensor] = None,
88
- pixel_values: Optional[torch.Tensor] = None,
89
- # pixel_values_videos: Optional[torch.FloatTensor] = None,
90
- image_grid_thw: Optional[torch.LongTensor] = None,
91
- # video_grid_thw: Optional[torch.LongTensor] = None,
92
- pooling_mask: Optional[torch.LongTensor] = None,
93
- **kwargs
94
- ) -> torch.Tensor:
95
- if inputs_embeds is None:
96
- inputs_embeds = self.model.get_input_embeddings()(input_ids)
97
- if pixel_values is not None:
98
- pixel_values = pixel_values.type(self.visual.get_dtype())
99
- image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
100
- image_mask = input_ids == self.config.image_token_id
101
- inputs_embeds[image_mask] = image_embeds
102
- # if pixel_values_videos is not None:
103
- # pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
104
- # video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw).to(inputs_embeds.device)
105
- # video_mask = input_ids == self.config.video_token_id
106
- # inputs_embeds[video_mask] = video_embeds
107
- if attention_mask is not None:
108
- attention_mask = attention_mask.to(inputs_embeds.device)
109
-
110
- outputs = self.model(
111
- input_ids=None,
112
- position_ids=position_ids,
113
- attention_mask=attention_mask,
114
- past_key_values=past_key_values,
115
- inputs_embeds=inputs_embeds,
116
- )
117
-
118
- pooling_mask = attention_mask if pooling_mask is None else pooling_mask
119
- left_padding = (pooling_mask[:, -1].sum() == pooling_mask.shape[0]) # TODO
120
- if left_padding:
121
- embeddings = outputs.last_hidden_state[:, -1]
122
- else:
123
- sequence_lengths = pooling_mask.sum(dim=1) - 1
124
- batch_size = outputs.last_hidden_state.shape[0]
125
- embeddings = outputs.last_hidden_state[torch.arange(
126
- batch_size, device=outputs.last_hidden_state.device
127
- ), sequence_lengths]
128
- if self.normalize:
129
- embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
130
- return embeddings.contiguous()
131
-
132
- def embed(self, texts: list[str], images: list[Image.Image], is_query=True, instruction=None, **kwargs):
133
- self.eval()
134
- # Inputs must be batched
135
- input_texts, input_images = list(), list()
136
- for t, i in zip(texts, images):
137
- if not is_query or instruction is None:
138
- instruction = self.default_instruction
139
- input_str = ''
140
- if i is None:
141
- input_images = None # All examples in the same batch are consistent
142
- else:
143
- input_str += '<|vision_start|><|image_pad|><|vision_end|>'
144
- i = fetch_image(i)
145
- input_images.append(i)
146
- if t is not None:
147
- input_str += t
148
- msg = f'<|im_start|>system\n{instruction}<|im_end|>\n<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>'
149
- input_texts.append(msg)
150
-
151
- inputs = self.processor(
152
- text=input_texts,
153
- images=input_images,
154
- padding=True,
155
- truncation=True,
156
- max_length=self.max_length,
157
- return_tensors='pt'
158
- )
159
- inputs = {k: v.to(self.device) for k, v in inputs.items()} # TODO
160
- with torch.inference_mode():
161
- embeddings = self.forward(**inputs)
162
- return embeddings
163
-
164
- def encode(self, sentences: list[str], *, prompt_name=None, **kwargs):
165
- return self.get_fused_embeddings(texts=sentences, prompt_name=prompt_name, **kwargs)
166
-
167
- def encode_queries(self, queries: List[str], **kwargs):
168
- embeddings = self.encode(queries, **kwargs)
169
- return embeddings
170
-
171
- def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs):
172
- if type(corpus) is dict:
173
- sentences = [
174
- (corpus["title"][i] + self.sep + corpus["text"][i]).strip()
175
- if "title" in corpus
176
- else corpus["text"][i].strip()
177
- for i in range(len(corpus["text"]))
178
- ]
179
- else:
180
- sentences = [
181
- (doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
182
- for doc in corpus
183
- ]
184
- embeddings = self.encode(sentences, is_query=False, **kwargs)
185
- return embeddings
186
-
187
- def get_image_embeddings(self, images: list[Image.Image] | DataLoader, **kwargs):
188
- return self.get_fused_embeddings(images=images, **kwargs)
189
-
190
- def get_text_embeddings(self, texts: list[str], **kwargs):
191
- return self.get_fused_embeddings(texts=texts, **kwargs)
192
-
193
- def get_fused_embeddings(self, texts: list[str] = None, images: list[Image.Image] | DataLoader = None, **kwargs):
194
- if isinstance(images, DataLoader):
195
- image_loader = images
196
- batch_size = image_loader.batch_size
197
- image_loader.dataset.transform = None
198
- else:
199
- batch_size = kwargs.pop('batch_size', 32)
200
- if images is None:
201
- image_loader = None
202
- else:
203
- image_loader = DataLoader(
204
- images,
205
- batch_size=batch_size,
206
- shuffle=False,
207
- collate_fn=custom_collate_fn,
208
- num_workers=min(math.floor(os.cpu_count() / 2), 8),
209
- )
210
-
211
- if texts is None:
212
- assert image_loader is not None
213
- n_batch = len(image_loader)
214
- else:
215
- n_batch = len(texts) // batch_size + int(len(texts) % batch_size > 0)
216
- image_loader = image_loader or [None] * n_batch
217
-
218
- all_embeddings = list()
219
- none_batch = [None] * batch_size
220
- show_progress_bar = kwargs.pop('show_progress_bar', False)
221
- pbar = tqdm(total=n_batch, disable=not show_progress_bar, mininterval=1, miniters=10, desc='encode')
222
- for n, img_batch in zip(range(0, n_batch * batch_size, batch_size), image_loader):
223
- text_batch = none_batch if texts is None else texts[n: n+batch_size]
224
- img_batch = none_batch if img_batch is None else img_batch
225
- embeddings = self.embed(texts=text_batch, images=img_batch, **kwargs)
226
- pbar.update(1)
227
- all_embeddings.append(embeddings.cpu())
228
- pbar.close()
229
- all_embeddings = torch.cat(all_embeddings, dim=0)
230
- return all_embeddings
231
-
232
-
233
- def custom_collate_fn(batch):
234
- return batch
235
-
236
-
237
- ### Copied from qwen_vl_utils.vision_process.py
238
- import base64
239
- from io import BytesIO
240
- import requests
241
-
242
- IMAGE_FACTOR = 28
243
- MIN_PIXELS = 4 * 28 * 28
244
- MAX_PIXELS = 16384 * 28 * 28
245
- MAX_RATIO = 200
246
-
247
-
248
- def round_by_factor(number: int, factor: int) -> int:
249
- """Returns the closest integer to 'number' that is divisible by 'factor'."""
250
- return round(number / factor) * factor
251
-
252
-
253
- def ceil_by_factor(number: int, factor: int) -> int:
254
- """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
255
- return math.ceil(number / factor) * factor
256
-
257
-
258
- def floor_by_factor(number: int, factor: int) -> int:
259
- """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
260
- return math.floor(number / factor) * factor
261
-
262
-
263
- def smart_resize(
264
- height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
265
- ) -> tuple[int, int]:
266
- """
267
- Rescales the image so that the following conditions are met:
268
-
269
- 1. Both dimensions (height and width) are divisible by 'factor'.
270
-
271
- 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
272
-
273
- 3. The aspect ratio of the image is maintained as closely as possible.
274
- """
275
- h_bar = max(factor, round_by_factor(height, factor))
276
- w_bar = max(factor, round_by_factor(width, factor))
277
- if h_bar * w_bar > max_pixels:
278
- beta = math.sqrt((height * width) / max_pixels)
279
- h_bar = floor_by_factor(height / beta, factor)
280
- w_bar = floor_by_factor(width / beta, factor)
281
- elif h_bar * w_bar < min_pixels:
282
- beta = math.sqrt(min_pixels / (height * width))
283
- h_bar = ceil_by_factor(height * beta, factor)
284
- w_bar = ceil_by_factor(width * beta, factor)
285
-
286
- if max(h_bar, w_bar) / min(h_bar, w_bar) > MAX_RATIO:
287
- logging.warning(
288
- f"Absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(h_bar, w_bar) / min(h_bar, w_bar)}"
289
- )
290
- if h_bar > w_bar:
291
- h_bar = w_bar * MAX_RATIO
292
- else:
293
- w_bar = h_bar * MAX_RATIO
294
- return h_bar, w_bar
295
-
296
-
297
- def fetch_image(image: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> Image.Image:
298
- image_obj = None
299
- if isinstance(image, Image.Image):
300
- image_obj = image
301
- elif image.startswith("http://") or image.startswith("https://"):
302
- image_obj = Image.open(requests.get(image, stream=True).raw)
303
- elif image.startswith("file://"):
304
- image_obj = Image.open(image[7:])
305
- elif image.startswith("data:image"):
306
- if "base64," in image:
307
- _, base64_data = image.split("base64,", 1)
308
- data = base64.b64decode(base64_data)
309
- image_obj = Image.open(BytesIO(data))
310
- else:
311
- image_obj = Image.open(image)
312
- if image_obj is None:
313
- raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
314
- image = image_obj.convert("RGB")
315
- ## resize
316
- # if "resized_height" in ele and "resized_width" in ele:
317
- # resized_height, resized_width = smart_resize(
318
- # ele["resized_height"],
319
- # ele["resized_width"],
320
- # factor=size_factor,
321
- # )
322
- # else:
323
- width, height = image.size
324
- # min_pixels = ele.get("min_pixels", MIN_PIXELS)
325
- # max_pixels = ele.get("max_pixels", MAX_PIXELS)
326
- resized_height, resized_width = smart_resize(
327
- height,
328
- width,
329
- factor=size_factor,
330
- min_pixels=MIN_PIXELS,
331
- max_pixels=MAX_PIXELS,
332
- )
333
- image = image.resize((resized_width, resized_height))
334
-
335
- return image
336
- ###
337
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modules.json DELETED
@@ -1,20 +0,0 @@
1
- [
2
- {
3
- "idx": 0,
4
- "name": "0",
5
- "path": "",
6
- "type": "custom_st.MultiModalTransformer"
7
- },
8
- {
9
- "idx": 1,
10
- "name": "1",
11
- "path": "1_Pooling",
12
- "type": "sentence_transformers.models.Pooling"
13
- },
14
- {
15
- "idx": 2,
16
- "name": "2",
17
- "path": "2_Normalize",
18
- "type": "sentence_transformers.models.Normalize"
19
- }
20
- ]