gmastrapas committed
Commit 1da93af · verified · 1 Parent(s): 0d648f5

Model update

Files changed (5)
  1. README1.md +114 -0
  2. infer.py +499 -0
  3. infer_utils.py +247 -0
  4. modeling_jvlm.py +1 -1
  5. test_jvlm.py +504 -0
README1.md ADDED
@@ -0,0 +1,114 @@
+ # JVLM - Jina Vision Language Model
+
+ Minimal inference script for JVLM with streaming output and batch processing.
+
+ ## Quick Start
+
+ ```bash
+ python infer.py -i test_image.jpg -p "Describe the image"
+ ```
+
+ ## Requirements
+
+ ```bash
+ uv sync
+ ```
+
+ Optional extras:
+ ```bash
+ uv sync --extra accelerate   # recommended for automatic device selection
+ uv sync --extra tensorflow   # only needed for tensorflow resize methods
+ ```
+
+ ## Usage
+
+ ### CLI
+
+ ```bash
+ # Single image (streaming)
+ python infer.py -i photo.jpg -p "What's in this image?"
+
+ # Remote image URL
+ python infer.py -i https://example.com/image.jpg -p "Describe this"
+
+ # Multiple images (local and remote)
+ python infer.py -i img1.jpg -i https://example.com/img2.jpg -i img3.jpg -p "Compare these images"
+
+ # Glob pattern support (quote patterns to prevent shell expansion)
+ python infer.py -i "*.jpg" -p "Describe"
+ python infer.py -i "photos/*.png" -i "images/*.jpg" -p "What do you see?"
+
+ # Non-streaming
+ python infer.py -i photo.jpg -p "What's in this image?" --no-stream
+
+ # Custom model
+ python infer.py -m /path/to/model -i image.png -p "Describe the scene"
+
+ # Custom max tokens
+ python infer.py -i photo.jpg -p "Explain in detail" --max-tokens 2048
+
+ # Prompt position control
+ python infer.py -i photo.jpg -p "Describe" --prompt-first
+
+ # Map mode: apply one prompt to multiple images
+ python infer.py --map -i "*.jpg" -p "What is this?"
+
+ # Map mode: apply multiple prompts to one image
+ python infer.py --map -i photo.jpg -p "What breed?" -p "What color?" -p "Happy or sad?"
+ ```
+
+ **Options:**
+ - `-i, --image`: image path, URL, or glob pattern (can specify multiple times, default: test_image.jpg)
+ - `-p, --prompt`: text prompt (can specify multiple times with --map, default: "Describe the image for me in 100 words")
+ - `-m, --model`: model path (default: ".")
+ - `--max-tokens`: maximum output tokens (default: 1024)
+ - `--no-stream`: disable streaming (default: stream token-by-token)
+ - `--no-image-labels`: disable ordinal labels for multi-image inputs (default: enabled)
+ - `--prompt-first`: place prompt before images instead of after (may affect output quality)
+ - `--map`: map mode - apply single prompt to multiple images OR multiple prompts to single image
+
+ ### Python
+
+ ```python
+ from PIL import Image
+ import torch
+ from transformers import AutoModelForCausalLM, AutoProcessor
+
+ processor = AutoProcessor.from_pretrained(".", trust_remote_code=True, use_fast=False)
+ model = AutoModelForCausalLM.from_pretrained(
+     ".", trust_remote_code=True, dtype=torch.bfloat16,
+     device_map="auto"
+ )
+
+ device = next(model.parameters()).device
+
+ image = Image.open("test_image.jpg")
+ inputs = [{
+     'role': 'user',
+     'content': [{'type': 'image', 'image': image}, {'type': 'text', 'text': "Describe this"}]
+ }]
+
+ messages, images = processor.apply_chat_template(inputs, add_generation_prompt=True)
+ processed_inputs = processor(messages=messages, images=images)
+ batched_inputs = processor.collate([processed_inputs], max_sequence_length=4096)
+
+ # Move to device
+ device_inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v
+                  for k, v in batched_inputs.items()}
+
+ # Streaming generation
+ with torch.no_grad(), torch.autocast(device.type, dtype=torch.bfloat16):
+     for token_id in model.stream_generate(
+         input_ids=device_inputs['input_ids'],
+         images=device_inputs['images'],
+         image_masks=device_inputs['image_masks'],
+         image_input_idx=device_inputs['image_input_idx'],
+         max_new_tokens=256,
+     ):
+         text = processor.tokenizer.decode([token_id], skip_special_tokens=True)
+         print(text, end='', flush=True)
+ ```
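For non-streaming generation, here is a minimal sketch that reuses the `processor`, `model`, `device`, and `device_inputs` objects from the example above and mirrors the non-streaming branch of `infer_utils.py` in this commit:

```python
# Non-streaming sketch; assumes the objects built in the example above.
with torch.no_grad(), torch.autocast(device.type, dtype=torch.bfloat16):
    outputs = model.generate(
        input_ids=device_inputs['input_ids'],
        images=device_inputs['images'],
        image_masks=device_inputs['image_masks'],
        image_input_idx=device_inputs['image_input_idx'],
        max_new_tokens=256,
    )

# infer_utils.py reads the generated ids from outputs.token_ids
text = processor.tokenizer.decode(outputs.token_ids[0, 0], skip_special_tokens=True)
print(text)
```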
+
+ ## Notes
+
+ Streaming uses the `stream_generate()` method with a KV cache in `modeling_jvlm.py`.
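`infer_utils.py` additionally buffers the decoded text so that only complete words (or newline-terminated chunks) are flushed while streaming. A trimmed sketch of that loop, assuming the same `device_inputs` as in the Python example above and omitting the CJK-character branch:

```python
# Streaming with a decode buffer, adapted from infer_utils.py (CJK handling omitted).
token_cache, print_len, text = [], 0, ''
with torch.no_grad(), torch.autocast(device.type, dtype=torch.bfloat16):
    for token_id in model.stream_generate(
        input_ids=device_inputs['input_ids'],
        images=device_inputs['images'],
        image_masks=device_inputs['image_masks'],
        image_input_idx=device_inputs['image_input_idx'],
        max_new_tokens=256,
    ):
        token_cache.append(token_id)
        text = processor.tokenizer.decode(token_cache, skip_special_tokens=True)
        if text.endswith('\n'):
            # flush the full line and reset the buffer
            printable, token_cache, print_len = text[print_len:], [], 0
        else:
            # flush only up to the last space so partial words are never printed
            printable = text[print_len:text.rfind(' ') + 1]
            print_len += len(printable)
        print(printable, end='', flush=True)

# flush whatever is still buffered once generation ends
if print_len < len(text):
    print(text[print_len:], end='', flush=True)
print()
```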
infer.py ADDED
@@ -0,0 +1,499 @@
1
+ import argparse
2
+ import glob
3
+ import logging
4
+ import os
5
+ import sys
6
+ from time import perf_counter
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+ from urllib.parse import urlparse
9
+
10
+ import torch
11
+ from transformers import (
12
+ AutoModelForCausalLM, AutoProcessor, GenerationConfig, TextStreamer
13
+ )
14
+ from transformers.utils import is_flash_attn_2_available
15
+
16
+
17
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
18
+ TEST_IMAGE = './assets/the_persistence_of_memory.jpg'
19
+
20
+ log = logging.getLogger(__name__)
21
+
22
+
23
+ class Timer:
24
+ def __enter__(self):
25
+ self.start = perf_counter()
26
+ self.readout = None
27
+ return self
28
+
29
+ def __exit__(self, *_, **__):
30
+ self.time = perf_counter() - self.start
31
+ self.readout = f'{self.time:.3f}'
32
+
33
+
34
+ def _resolve_device_dtype_and_attn() -> Tuple[torch.device, torch.dtype, str]:
35
+ if torch.cuda.is_available():
36
+ device = torch.device('cuda')
37
+ if is_flash_attn_2_available():
38
+ dtype = torch.bfloat16
39
+ attn_implementation = 'flash_attention_2'
40
+ else:
41
+ dtype = torch.float16
42
+ attn_implementation = 'sdpa'
43
+ else:
44
+ if torch.backends.mps.is_available():
45
+ device = torch.device('mps')
46
+ else:
47
+ device = torch.device('cpu')
48
+ dtype = torch.float32
49
+ attn_implementation = 'sdpa'
50
+
51
+ return device, dtype, attn_implementation
52
+
53
+
54
+ def _build_conversations(
55
+ images: Optional[List[str]],
56
+ prompts: Optional[List[str]],
57
+ batched: bool = False,
58
+ prompt_first: bool = False,
59
+ image_labels: bool = False,
60
+ ):
61
+ def _is_url(_path: str) -> bool:
62
+ try:
63
+ result = urlparse(_path)
64
+ return result.scheme in ('http', 'https')
65
+ except Exception:
66
+ return False
67
+
68
+ images = images or []
69
+ expanded_image_paths = []
70
+ for path in images:
71
+ if _is_url(path):
72
+ expanded_image_paths.append(path)
73
+ elif any(char in path for char in ['*', '?', '[', ']']):
74
+ matched_files = glob.glob(path)
75
+ if matched_files:
76
+ expanded_image_paths.extend(sorted(matched_files))
77
+ else:
78
+ log.warning(f'No files matched pattern "{path}"')
79
+ else:
80
+ expanded_image_paths.append(path)
81
+ images = expanded_image_paths or [TEST_IMAGE]
82
+ n_images = len(images)
83
+
84
+ if prompts is None:
85
+ prompts = (
86
+ ['Describe the image in 100 words'] if n_images == 1 or batched else
87
+ ['Describe the images in 100 words']
88
+ )
89
+ n_prompts = len(prompts)
90
+
91
+ if n_images == 1 and n_prompts == 1:
92
+ examples = [([images[0]], prompts[0])]
93
+ elif batched:
94
+ if n_images > 1 and n_prompts == 1:
95
+ prompt = prompts[0]
96
+ log.info(f'Batch mode: Applying 1 prompt to {n_images} images')
97
+ examples = [([image], prompt) for image in images]
98
+ elif n_images == 1 and n_prompts > 1:
99
+ image = images[0]
100
+ log.info(f'\nBatch mode: Applying {n_prompts} prompts to 1 image')
101
+ examples = [([image], prompt) for prompt in prompts]
102
+ elif n_images > 1 and n_images == n_prompts:
103
+ log.info(f'\nBatch mode: Applying {n_prompts} prompts to {n_images} images')
104
+ examples = [([image], prompt) for image, prompt in zip(images, prompts)]
105
+ else:
106
+ log.error(
107
+ 'Batch mode requires either (multiple images + 1 prompt) or '
108
+ '(1 image + multiple prompts) or (multiple images + multiple prompts) '
109
+ 'with equal number of images and prompts. Got '
110
+ f'{n_images} images and {n_prompts} prompts'
111
+ )
112
+ sys.exit(1)
113
+ else:
114
+ if n_prompts > 1:
115
+ log.error(
116
+ 'Non-batch mode requires 1+ images and 1 prompt. Got '
117
+ f'{n_images} images and {n_prompts} prompts'
118
+ )
119
+ sys.exit(1)
120
+ examples = [(images, prompts[0])]
121
+
122
+ conversations = []
123
+ allimages = []
124
+ allprompts = []
125
+ ordinals = [
126
+ 'first', 'second', 'third', 'fourth', 'fifth',
127
+ 'sixth', 'seventh', 'eighth', 'ninth', 'tenth',
128
+ ]
129
+ for images, prompt in examples:
130
+ content = []
131
+ allimages.append(images)
132
+ allprompts.append(prompt)
133
+ if prompt_first:
134
+ content.append({'type': 'text', 'text': prompt})
135
+ if len(images) > 1 and image_labels:
136
+ for idx, img in enumerate(images):
137
+ ordinal = ordinals[idx] if idx < len(ordinals) else f'{idx+1}th'
138
+ image = images[idx]
139
+ descriptor = f'url: {image}'
140
+ if os.path.isfile(image):
141
+ descriptor = f'filename: {os.path.basename(image)}'
142
+ content.append({
143
+ 'type': 'text',
144
+ 'text': f'(this is the {ordinal} image, {descriptor})',
145
+ })
146
+ content.append({'type': 'image', 'image': img})
147
+ else:
148
+ content.extend([{'type': 'image', 'image': image} for image in images])
149
+ if not prompt_first:
150
+ content.append({'type': 'text', 'text': prompt})
151
+ conversations.append({'role': 'user', 'content': content})
152
+
153
+ return conversations, allimages, allprompts
154
+
155
+
156
+ def _token_usage_report(
157
+ inputs: Dict[str, Any],
158
+ images: List[Any],
159
+ max_sequence_length: int,
160
+ special_image_token_ids: Dict[str, int],
161
+ ):
162
+ """Report token usage statistics in tree format."""
163
+ n_images = len(images)
164
+ input_ids = inputs['input_ids']
165
+ attention_mask = inputs['attention_mask']
166
+
167
+ # Total tokens in sequence (non-padding)
168
+ total_tokens = attention_mask.sum().item()
169
+
170
+ # Count ALL image-related tokens directly from input_ids
171
+ image_patch_id = special_image_token_ids['image_patch_token_id']
172
+ image_start_id = special_image_token_ids['image_start_token_id']
173
+ image_end_id = special_image_token_ids['image_end_token_id']
174
+ image_col_id = special_image_token_ids['image_col_token_id']
175
+
176
+ num_patch = (input_ids[0] == image_patch_id).sum().item()
177
+ num_start = (input_ids[0] == image_start_id).sum().item()
178
+ num_end = (input_ids[0] == image_end_id).sum().item()
179
+ num_col = (input_ids[0] == image_col_id).sum().item()
180
+
181
+ # Total image tokens = all image-related special tokens
182
+ total_image_tokens = num_patch + num_start + num_end + num_col
183
+
184
+ # Pure text tokens (excluding all image-related tokens)
185
+ text_token_count = total_tokens - total_image_tokens
186
+
187
+ report = [
188
+ f'Input Context Window Layout (max: {max_sequence_length} tokens):',
189
+ f'├── Total: {total_tokens} tokens '
190
+ f'({((total_tokens / max_sequence_length) * 100):.1f}%)',
191
+ ]
192
+ # Count tokens per image by finding img_start and img_end boundaries
193
+ # Each image is delimited by img_start and img_end tokens
194
+ tokens_per_image_list = []
195
+
196
+ # Find all img_start and img_end positions in input_ids
197
+ start_positions = (input_ids[0] == image_start_id).nonzero(
198
+ as_tuple=True
199
+ )[0].tolist()
200
+ end_positions = (input_ids[0] == image_end_id).nonzero(as_tuple=True)[0].tolist()
201
+
202
+ if len(start_positions) > 0 and len(end_positions) > 0:
203
+ # Each image typically has 2 start and 2 end tokens
204
+ # Determine actual number of images in context
205
+ n_starts_per_image = 2 # typical case
206
+ n_images_in_context = len(start_positions) // n_starts_per_image
207
+
208
+ # Warn if not all images fit in context
209
+ if n_images_in_context < n_images:
210
+ log.warning(
211
+ f'Only {n_images_in_context}/{n_images} images fit in context window'
212
+ )
213
+
214
+ for idx in range(n_images):
215
+ if idx < n_images_in_context:
216
+ # Get the start and end indices for this image
217
+ start_idx_begin = idx * n_starts_per_image
218
+ end_idx_end = (idx + 1) * n_starts_per_image
219
+ if (
220
+ start_idx_begin < len(start_positions) and
221
+ end_idx_end <= len(end_positions)
222
+ ):
223
+ # First start position and last end position define the image span
224
+ first_start = start_positions[start_idx_begin]
225
+ last_end = end_positions[end_idx_end - 1]
226
+ # Count tokens from first start to last end (inclusive)
227
+ num_tokens = last_end - first_start + 1
228
+ tokens_per_image_list.append(num_tokens)
229
+ else:
230
+ tokens_per_image_list.append(0)
231
+ else:
232
+ # Image didn't fit in context
233
+ tokens_per_image_list.append(0)
234
+ else:
235
+ # Fallback to uniform division if we can't find boundaries
236
+ tokens_per_image = total_image_tokens // n_images if n_images > 0 else 0
237
+ tokens_per_image_list = [tokens_per_image] * n_images
238
+
239
+ for idx in range(n_images):
240
+ img = images[idx]
241
+ n_tokens = tokens_per_image_list[idx] if idx < len(tokens_per_image_list) else 0
242
+ pct = (n_tokens / max_sequence_length * 100)
243
+ report.append(
244
+ f'├── Image {idx + 1}: {img.width}x{img.height} → {n_tokens} '
245
+ f'tokens ({pct:.1f}%)'
246
+ )
247
+
248
+ text_pct = (text_token_count / max_sequence_length * 100)
249
+ report.append(f'└── Text: {text_token_count} tokens ({text_pct:.1f}%)')
250
+
251
+ return '\n'.join(report)
252
+
253
+
254
+ def test_jvlm():
255
+ parser = argparse.ArgumentParser(
256
+ description='jina-vlm-v1 vision-language model inference.'
257
+ )
258
+ parser.add_argument(
259
+ '-m',
260
+ '--model',
261
+ default='.',
262
+ help=(
263
+ 'Model path. Set this to "jinaai/jina-vlm-v1" if you are running this '
264
+ 'script outside this repo.'
265
+ )
266
+ )
267
+ parser.add_argument(
268
+ '-i',
269
+ '--image',
270
+ action='append',
271
+ help='Image path or glob pattern (can specify multiple times, e.g., "*.jpg").'
272
+ )
273
+ parser.add_argument(
274
+ '-p',
275
+ '--prompt',
276
+ action='append',
277
+ help='Text prompt (can specify multiple times with --map).',
278
+ )
279
+ parser.add_argument(
280
+ '--max-crops',
281
+ type=int,
282
+ default=12,
283
+ help='Maximum crops (default: 12).',
284
+ )
285
+ parser.add_argument(
286
+ '--max-tokens',
287
+ type=int,
288
+ default=1024,
289
+ help='Maximum output tokens (default: 1024).',
290
+ )
291
+ parser.add_argument(
292
+ '--max-pixels',
293
+ type=int,
294
+ default=None,
295
+ help=(
296
+ 'Max pixels per image, bigger images are resized and the aspect ratio is '
297
+ 'preserved (default: None).'
298
+ ),
299
+ )
300
+ parser.add_argument(
301
+ '--no-stream',
302
+ action='store_true',
303
+ help='Disable streaming (default: stream token-by-token)',
304
+ )
305
+ parser.add_argument(
306
+ '--image-labels',
307
+ action='store_true',
308
+ help=(
309
+ 'Enable ordinal text labels after each image '
310
+ '(default: no image labels for multi-image)'
311
+ ),
312
+ )
313
+ parser.add_argument(
314
+ '--prompt-first',
315
+ action='store_true',
316
+ help=(
317
+ 'Place prompt before images instead of after (default: prompt after images)'
318
+ ),
319
+ )
320
+ parser.add_argument(
321
+ '--batched',
322
+ action='store_true',
323
+ help=(
324
+ 'Batch mode: apply single prompt to multiple images (or single image to '
325
+ 'multiple prompts) with KV cache reuse.'
326
+ ),
327
+ )
328
+ args = parser.parse_args()
329
+
330
+ print('Welcome to the jinaai/jina-vlm-v1 playground ✨')
331
+ print('Use this script to test our model!')
332
+ print('- Jina AI')
333
+ print()
334
+
335
+ print('Model path: ', args.model)
336
+ print('Loading the processor ...')
337
+ processor = AutoProcessor.from_pretrained(
338
+ args.model, trust_remote_code=True, use_fast=False,
339
+ )
340
+ print('Done ✅')
341
+ print()
342
+
343
+ print('Specifying device, dtype and attention implementation ...')
344
+ device, dtype, attn_implementation = _resolve_device_dtype_and_attn()
345
+ print(f'Using attention implementation: {attn_implementation}')
346
+ print(f'Using device: {device}')
347
+ print(f'Using dtype: {dtype}')
348
+ print()
349
+
350
+ print('Loading the model ...')
351
+ model = AutoModelForCausalLM.from_pretrained(
352
+ args.model,
353
+ trust_remote_code=True,
354
+ dtype=dtype,
355
+ low_cpu_mem_usage=True,
356
+ device_map=device.type,
357
+ attn_implementation=attn_implementation,
358
+ )
359
+ max_sequence_length = getattr(model.config, 'max_sequence_length', 40960)
360
+ n_params = sum(p.numel() for p in model.parameters())
361
+ print(f'Max sequence length: {max_sequence_length}')
362
+ print(f'Number of parameters: {n_params}')
363
+ print('Done ✅')
364
+ print()
365
+
366
+ print('Let\'s create some conversations ...')
367
+ conversations, images, prompts = _build_conversations(
368
+ args.image,
369
+ args.prompt,
370
+ batched=args.batched,
371
+ prompt_first=args.prompt_first,
372
+ image_labels=args.image_labels
373
+ )
374
+ n_conversations = len(conversations)
375
+ print(f'Built {n_conversations} conversations 🚀')
376
+ print()
377
+
378
+ print('Transforming conversations to numbers ...')
379
+ timer = Timer()
380
+ with timer:
381
+ texts = processor.apply_chat_template(conversations, add_generation_prompt=True)
382
+ inputs = processor(
383
+ text=texts,
384
+ images=images,
385
+ padding='longest',
386
+ max_length=max_sequence_length,
387
+ max_crops=args.max_crops,
388
+ max_pixels=args.max_pixels,
389
+ do_resize=True if args.max_pixels is not None else False,
390
+ return_tensors='pt',
391
+ )
392
+ device_inputs = {}
393
+ for k, v in inputs.items():
394
+ if k == 'labels':
395
+ continue
396
+ if isinstance(v, torch.Tensor):
397
+ if v.is_floating_point():
398
+ device_inputs[k] = v.to(device, dtype=dtype, non_blocking=True)
399
+ else:
400
+ device_inputs[k] = v.to(device, non_blocking=True)
401
+ else:
402
+ device_inputs[k] = v
403
+
404
+ processing_time = timer.readout
405
+ special_image_token_ids = {
406
+ 'image_patch_token_id': processor.image_processor.image_patch_token_id,
407
+ 'image_start_token_id': processor.image_processor.image_start_token_id,
408
+ 'image_end_token_id': processor.image_processor.image_end_token_id,
409
+ 'image_col_token_id': processor.image_processor.image_col_token_id,
410
+ }
411
+ token_usage_reports = []
412
+ for idx in range(n_conversations):
413
+ ith_inputs = {k: v[idx] for k, v in inputs.items()}
414
+ token_usage_report = _token_usage_report(
415
+ ith_inputs,
416
+ images[idx],
417
+ max_sequence_length=max_sequence_length,
418
+ special_image_token_ids=special_image_token_ids,
419
+ )
420
+ token_usage_reports.append(token_usage_report)
421
+ print(f'Processed {n_conversations} conversations in {processing_time}s')
422
+ print('All done 🪄')
423
+ print()
424
+
425
+ print('Running inference ...')
426
+ generated_tokens = 0
427
+ input_prompts = inputs['input_ids']
428
+
429
+ if args.no_stream:
430
+ print('Non-streaming mode')
431
+ print('Inference will run in a batch')
432
+ print()
433
+
434
+ with (
435
+ timer,
436
+ torch.no_grad(),
437
+ torch.autocast(device.type, enabled=(device.type != 'mps'), dtype=dtype),
438
+ ):
439
+ output = model.generate(
440
+ **device_inputs,
441
+ generation_config=GenerationConfig(
442
+ max_new_tokens=args.max_tokens, do_sample=False,
443
+ ),
444
+ return_dict_in_generate=True,
+ )
445
+ generation_time, generation_readout = timer.time, timer.readout
446
+
447
+ for idx in range(n_conversations):
448
+ out = output.sequences[idx][len(input_prompts[idx].tolist()):]
449
+ generated_tokens += len(out)
450
+ response = processor.tokenizer.decode(out, skip_special_tokens=True)
451
+ print(f'Conversation {idx + 1}/{n_conversations}')
452
+ print(f'├── 🖼️Images: {images[idx]}')
453
+ print(f'├── 📜Prompt: {prompts[idx]}')
454
+ print(f'├── 💬Chat: {texts[idx]}')
455
+ print(f'└── 🧠Response: {response}')
456
+ print('Token usage report:')
457
+ print(token_usage_reports[idx])
458
+ print()
459
+ else:
460
+ print('Streaming mode')
461
+ print('Inference will run sequentially')
462
+ print()
463
+
464
+ streamer = TextStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
465
+ for idx in range(n_conversations):
466
+ print(f'Conversation {idx + 1}/{n_conversations}')
467
+ print(f'├── 🖼️Images: {images[idx]}')
468
+ print(f'├── 📜Prompt: {prompts[idx]}')
469
+ print(f'├── 💬Chat: {texts[idx]}')
470
+ print(f'└── 🧠Response: ')
471
+ ith_inputs = {k: v[idx].unsqueeze(0) for k, v in device_inputs.items()}
472
+ with (
473
+ timer,
474
+ torch.no_grad(),
475
+ torch.autocast(device.type, enabled=(device.type != 'mps'), dtype=dtype)
476
+ ):
477
+ output = model.generate(
478
+ **ith_inputs,
479
+ streamer=streamer,
480
+ generation_config=GenerationConfig(
481
+ max_new_tokens=args.max_tokens, do_sample=False,
482
+ ),
483
+ return_dict_in_generate=True,
+ )
484
+ out = output.sequences[0][len(input_prompts[idx].tolist()):]
485
+ generated_tokens += len(out)
486
+ print('Token usage report:')
487
+ print(token_usage_reports[idx])
488
+ print()
489
+
490
+ generation_time, generation_readout = timer.time, timer.readout
491
+
492
+ res_per_sec = n_conversations / generation_time if generation_time > 0 else 0
493
+ tok_per_sec = generated_tokens / generation_time if generation_time > 0 else 0
494
+ print('Done ✅')
495
+ print(f'Generated {n_conversations} responses in {generation_readout}s')
496
+ print(f'{res_per_sec:.2f} res/s {tok_per_sec:.2f} tok/s')
497
+
498
+ if __name__ == '__main__':
499
+ test_jvlm()
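For reference, the flags defined by the argument parser in this script are `--batched`, `--image-labels`, `--max-crops`, and `--max-pixels` (README1.md above refers to `--map` and `--no-image-labels` instead). Example invocations based on the definitions above; the image paths are placeholders:

```bash
# one prompt applied to every matching image as a batch
python infer.py --batched -i "assets/*.jpg" -p "What is this?"

# several prompts against a single image as a batch
python infer.py --batched -i photo.jpg -p "What breed?" -p "What color?"

# multiple images in one conversation, with ordinal image labels
python infer.py -i a.jpg -i b.jpg -p "Compare these images" --image-labels

# cap the number of crops and resize very large images
python infer.py -i scan.png -p "Transcribe the text" --max-crops 8 --max-pixels 1000000
```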
infer_utils.py ADDED
@@ -0,0 +1,247 @@
1
+ import os
2
+ import tempfile
3
+ import urllib.request
4
+ from urllib.parse import urlparse
5
+ import time
6
+ import torch
7
+
8
+
9
+ def is_url(path):
10
+ """Check if a path is a URL"""
11
+ try:
12
+ result = urlparse(path)
13
+ return result.scheme in ("http", "https")
14
+ except Exception:
15
+ return False
16
+
17
+
18
+ def download_image(url):
19
+ """Download image from URL to temporary file"""
20
+ try:
21
+ # Create temp file with proper extension
22
+ parsed = urlparse(url)
23
+ ext = os.path.splitext(parsed.path)[1] or ".jpg"
24
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=ext)
25
+ temp_path = temp_file.name
26
+ temp_file.close()
27
+
28
+ # Download image
29
+ urllib.request.urlretrieve(url, temp_path)
30
+ print(f"Downloaded image from: {url}")
31
+ return temp_path
32
+ except Exception as e:
33
+ raise RuntimeError(f"Failed to download image from {url}: {e}")
34
+
35
+
36
+ def print_token_stats(batched_inputs, images_list, model, processor):
37
+ """Print token usage statistics in tree format.
38
+
39
+ Comment out the call to this function if you don't want to see the stats.
40
+ """
41
+ input_ids = batched_inputs['input_ids']
42
+ max_ctx_len = model.config.max_sequence_length
43
+ image_input_idx = batched_inputs.get('image_input_idx')
44
+
45
+ # Total tokens in sequence (non-padding)
46
+ valid_mask = input_ids[0] != -1
47
+ total_tokens = valid_mask.sum().item()
48
+
49
+ # Count ALL image-related tokens directly from input_ids
50
+ image_patch_id = processor.image_preprocessor.image_patch_token_id
51
+ image_start_id = processor.image_preprocessor.image_start_token_id
52
+ image_end_id = processor.image_preprocessor.image_end_token_id
53
+ image_col_id = processor.image_preprocessor.image_col_token_id
54
+
55
+ num_patch = (input_ids[0] == image_patch_id).sum().item()
56
+ num_start = (input_ids[0] == image_start_id).sum().item()
57
+ num_end = (input_ids[0] == image_end_id).sum().item()
58
+ num_col = (input_ids[0] == image_col_id).sum().item()
59
+
60
+ # Total image tokens = all image-related special tokens
61
+ total_image_tokens = num_patch + num_start + num_end + num_col
62
+
63
+ # Pure text tokens (excluding all image-related tokens)
64
+ text_token_count = total_tokens - total_image_tokens
65
+
66
+ print("Input Context Window Layout (max: {} tokens):".format(max_ctx_len))
67
+ print("└── Total: {} tokens ({:.1f}%)".format(
68
+ total_tokens, (total_tokens / max_ctx_len) * 100))
69
+
70
+ # Count tokens per image by finding img_start and img_end boundaries
71
+ # Each image is delimited by img_start and img_end tokens
72
+ tokens_per_image_list = []
73
+
74
+ # Find all img_start and img_end positions in input_ids
75
+ start_positions = (input_ids[0] == image_start_id).nonzero(as_tuple=True)[0].tolist()
76
+ end_positions = (input_ids[0] == image_end_id).nonzero(as_tuple=True)[0].tolist()
77
+
78
+ if len(start_positions) > 0 and len(end_positions) > 0:
79
+ # Each image typically has 2 start and 2 end tokens
80
+ # Determine actual number of images in context
81
+ num_starts_per_image = 2 # typical case
82
+ num_images_in_context = len(start_positions) // num_starts_per_image
83
+
84
+ # Warn if not all images fit in context
85
+ if num_images_in_context < len(images_list):
86
+ print(f"Warning: Only {num_images_in_context}/{len(images_list)} images fit in context window")
87
+
88
+ for img_idx in range(len(images_list)):
89
+ if img_idx < num_images_in_context:
90
+ # Get the start and end indices for this image
91
+ start_idx_begin = img_idx * num_starts_per_image
92
+ start_idx_end = (img_idx + 1) * num_starts_per_image
93
+ end_idx_begin = img_idx * num_starts_per_image
94
+ end_idx_end = (img_idx + 1) * num_starts_per_image
95
+
96
+ if start_idx_begin < len(start_positions) and end_idx_end <= len(end_positions):
97
+ # First start position and last end position define the image span
98
+ first_start = start_positions[start_idx_begin]
99
+ last_end = end_positions[end_idx_end - 1]
100
+ # Count tokens from first start to last end (inclusive)
101
+ num_tokens = last_end - first_start + 1
102
+ tokens_per_image_list.append(num_tokens)
103
+ else:
104
+ tokens_per_image_list.append(0)
105
+ else:
106
+ # Image didn't fit in context
107
+ tokens_per_image_list.append(0)
108
+ else:
109
+ # Fallback to uniform division if we can't find boundaries
110
+ tokens_per_image = total_image_tokens // len(images_list) if len(images_list) > 0 else 0
111
+ tokens_per_image_list = [tokens_per_image] * len(images_list)
112
+
113
+ for img_idx in range(len(images_list)):
114
+ img = images_list[img_idx]
115
+ num_tokens = tokens_per_image_list[img_idx] if img_idx < len(tokens_per_image_list) else 0
116
+ pct = (num_tokens / max_ctx_len * 100)
117
+ if img_idx < len(images_list) - 1:
118
+ print(" ├��─ Image {}: {}x{} → {} tokens ({:.1f}%)".format(
119
+ img_idx + 1, img.width, img.height, num_tokens, pct))
120
+ else:
121
+ print(" ├── Image {}: {}x{} → {} tokens ({:.1f}%)".format(
122
+ img_idx + 1, img.width, img.height, num_tokens, pct))
123
+
124
+ # Show text last
125
+ text_pct = (text_token_count / max_ctx_len * 100)
126
+ print(" └── Text: {} tokens ({:.1f}%)".format(text_token_count, text_pct))
127
+ print()
128
+
129
+
130
+ def is_chinese_char(cp):
131
+ """Check if character is CJK"""
132
+ if (
133
+ (cp >= 0x4E00 and cp <= 0x9FFF)
134
+ or (cp >= 0x3400 and cp <= 0x4DBF)
135
+ or (cp >= 0x20000 and cp <= 0x2A6DF)
136
+ or (cp >= 0x2A700 and cp <= 0x2B73F)
137
+ or (cp >= 0x2B740 and cp <= 0x2B81F)
138
+ or (cp >= 0x2B820 and cp <= 0x2CEAF)
139
+ or (cp >= 0xF900 and cp <= 0xFAFF)
140
+ or (cp >= 0x2F800 and cp <= 0x2FA1F)
141
+ ):
142
+ return True
143
+ return False
144
+
145
+
146
+ def build_content(images, prompt, prompt_first):
147
+ """Build content list with proper ordering."""
148
+ content = []
149
+ if prompt_first:
150
+ content.append({"type": "text", "text": prompt})
151
+ content.extend([{"type": "image", "image": img} for img in images])
152
+ else:
153
+ content.extend([{"type": "image", "image": img} for img in images])
154
+ content.append({"type": "text", "text": prompt})
155
+ return content
156
+
157
+
158
+ def generate_single(model, processor, content, args, device, prefer_mps, max_crops=12):
159
+ """
160
+ Generate output for a single image-prompt pair.
161
+
162
+ Returns:
163
+ text: Generated text
164
+ elapsed_time: Time taken
165
+ num_tokens: Number of tokens generated
166
+ """
167
+ inputs = [{"role": "user", "content": content}]
168
+ messages, images = processor.apply_chat_template(inputs, add_generation_prompt=True)
169
+ processed_inputs = processor(messages=messages, images=images)
170
+
171
+ # Use model's max sequence length from config
172
+ max_seq_len = getattr(model.config, 'max_sequence_length', 40960)
173
+ batched_inputs = processor.collate(
174
+ [processed_inputs], max_sequence_length=max_seq_len, max_crops=max_crops * len(images)
175
+ )
176
+
177
+ device_inputs = {}
178
+ for k, v in batched_inputs.items():
179
+ if isinstance(v, torch.Tensor):
180
+ if prefer_mps and v.is_floating_point():
181
+ device_inputs[k] = v.to(device, dtype=torch.float32, non_blocking=True)
182
+ else:
183
+ device_inputs[k] = v.to(device, non_blocking=True)
184
+ else:
185
+ device_inputs[k] = v
186
+
187
+ with (
188
+ torch.no_grad(),
189
+ torch.autocast(device, enabled=(device != "mps"), dtype=torch.bfloat16),
190
+ ):
191
+ if args.no_stream:
192
+ start_time = time.time()
193
+ outputs = model.generate(
194
+ input_ids=device_inputs["input_ids"],
195
+ images=device_inputs.get("images"),
196
+ image_masks=device_inputs.get("image_masks"),
197
+ image_input_idx=device_inputs["image_input_idx"],
198
+ max_new_tokens=args.max_tokens,
199
+ )
200
+ elapsed_time = time.time() - start_time
201
+ text = processor.tokenizer.decode(
202
+ outputs.token_ids[0, 0], skip_special_tokens=True
203
+ )
204
+ num_tokens = len(outputs.token_ids[0, 0])
205
+
206
+ print(text)
207
+ else:
208
+ # Streaming mode
209
+ token_cache = []
210
+ print_len = 0
211
+ token_count = 0
212
+ start_time = time.time()
213
+ for token_id in model.stream_generate(
214
+ input_ids=device_inputs["input_ids"],
215
+ position_ids=device_inputs.get("position_ids"),
216
+ images=device_inputs.get("images"),
217
+ image_masks=device_inputs.get("image_masks"),
218
+ image_input_idx=device_inputs["image_input_idx"],
219
+ max_new_tokens=args.max_tokens,
220
+ ):
221
+ token_cache.append(token_id)
222
+ token_count += 1
223
+ text = processor.tokenizer.decode(token_cache, skip_special_tokens=True)
224
+
225
+ if text.endswith("\n"):
226
+ printable_text = text[print_len:]
227
+ token_cache = []
228
+ print_len = 0
229
+ elif len(text) > 0 and is_chinese_char(ord(text[-1])):
230
+ printable_text = text[print_len:]
231
+ print_len += len(printable_text)
232
+ else:
233
+ printable_text = text[print_len : text.rfind(" ") + 1]
234
+ print_len += len(printable_text)
235
+
236
+ print(printable_text, end="", flush=True)
237
+
238
+ elapsed_time = time.time() - start_time
239
+ if print_len < len(text):
240
+ print(text[print_len:], end="", flush=True)
241
+ print()
242
+ num_tokens = token_count
243
+
244
+ tok_per_sec = num_tokens / elapsed_time if elapsed_time > 0 else 0
245
+ print(f"{tok_per_sec:.2f} tok/s")
246
+
247
+ return text, elapsed_time, num_tokens
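A hypothetical end-to-end use of the helpers above; the image sources are placeholders, the model is loaded from the current directory as in README1.md, and `device` is passed as a plain string because `generate_single()` compares it against `"mps"`:

```python
from argparse import Namespace

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Load the model the same way the README's Python example does.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
processor = AutoProcessor.from_pretrained('.', trust_remote_code=True, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    '.', trust_remote_code=True, dtype=torch.bfloat16, device_map=device
)

# Placeholder sources: one remote URL, one local file.
sources = ['https://example.com/cat.jpg', 'photo.jpg']
paths = [download_image(s) if is_url(s) else s for s in sources]
images = [Image.open(p) for p in paths]

args = Namespace(no_stream=False, max_tokens=256)   # stands in for the CLI namespace
content = build_content(images, 'Compare these images', prompt_first=False)
text, elapsed, n_tokens = generate_single(
    model, processor, content, args, device, prefer_mps=False, max_crops=12
)
print(f'{n_tokens} tokens in {elapsed:.2f}s')
```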
modeling_jvlm.py CHANGED
@@ -492,7 +492,7 @@ class JinaVLM(JinaPreTrainedModel):
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
         image_features = None
-        if images is not None:
+        if images is not None and images.shape[1] > 0:
             image_out = self.vision_model(images, image_masks)
             image_features = image_out.last_hidden_state
         return self.language_model(
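The new guard skips the vision tower when the batch carries an images tensor with zero crops (for example, a text-only conversation). A small illustration of the condition, with assumed, purely illustrative shapes:

```python
import torch

# (batch, n_crops, ...) with zero crops: nothing for the vision tower to encode.
images = torch.empty(1, 0, 588)

if images is not None and images.shape[1] > 0:
    print('run the vision model')
else:
    print('skip the vision model; image_features stays None')  # this branch runs
```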
test_jvlm.py ADDED
@@ -0,0 +1,504 @@
1
+ import argparse
2
+ import glob
3
+ import logging
4
+ import os
5
+ import sys
6
+ from time import perf_counter
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+ from urllib.parse import urlparse
9
+
10
+ import torch
11
+ from transformers import (
12
+ AutoModelForCausalLM, AutoProcessor, GenerationConfig, TextStreamer
13
+ )
14
+ from transformers.utils import is_flash_attn_2_available
15
+
16
+
17
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
18
+ TEST_IMAGE = './assets/the_persistence_of_memory.jpg'
19
+
20
+ log = logging.getLogger(__name__)
21
+
22
+
23
+ class Timer:
24
+ def __enter__(self):
25
+ self.start = perf_counter()
26
+ self.readout = None
27
+ return self
28
+
29
+ def __exit__(self, *_, **__):
30
+ self.time = perf_counter() - self.start
31
+ self.readout = f'{self.time:.3f}'
32
+
33
+
34
+ def _resolve_device_dtype_and_attn() -> Tuple[torch.device, torch.dtype, str]:
35
+ if torch.cuda.is_available():
36
+ device = torch.device('cuda')
37
+ if is_flash_attn_2_available():
38
+ dtype = torch.bfloat16
39
+ attn_implementation = 'flash_attention_2'
40
+ else:
41
+ dtype = torch.float16
42
+ attn_implementation = 'sdpa'
43
+ else:
44
+ if torch.backends.mps.is_available():
45
+ device = torch.device('mps')
46
+ else:
47
+ device = torch.device('cpu')
48
+ dtype = torch.float32
49
+ attn_implementation = 'sdpa'
50
+
51
+ return device, dtype, attn_implementation
52
+
53
+
54
+ def _build_conversations(
55
+ images: Optional[List[str]],
56
+ prompts: Optional[List[str]],
57
+ batched: bool = False,
58
+ prompt_first: bool = False,
59
+ image_labels: bool = False,
60
+ ):
61
+ def _is_url(_path: str) -> bool:
62
+ try:
63
+ result = urlparse(_path)
64
+ return result.scheme in ('http', 'https')
65
+ except Exception:
66
+ return False
67
+
68
+ images = images or []
69
+ expanded_image_paths = []
70
+ for path in images:
71
+ if _is_url(path):
72
+ expanded_image_paths.append(path)
73
+ elif any(char in path for char in ['*', '?', '[', ']']):
74
+ matched_files = glob.glob(path)
75
+ if matched_files:
76
+ expanded_image_paths.extend(sorted(matched_files))
77
+ else:
78
+ log.warning(f'No files matched pattern "{path}"')
79
+ else:
80
+ expanded_image_paths.append(path)
81
+ images = expanded_image_paths
82
+ n_images = len(images)
83
+ if prompts is None:
84
+ if len(images) == 0:
85
+ images = [TEST_IMAGE]
86
+ n_images = len(images)
87
+ prompts = (
88
+ ['Describe the image in 100 words'] if n_images == 1 or batched else
89
+ ['Describe the images in 100 words']
90
+ )
91
+ n_prompts = len(prompts)
92
+
93
+ if n_images == 0:
94
+ examples = [([], prompt) for prompt in prompts]
95
+ elif n_images == 1 and n_prompts == 1:
96
+ examples = [([images[0]], prompts[0])]
97
+ elif batched:
98
+ if n_images > 1 and n_prompts == 1:
99
+ prompt = prompts[0]
100
+ log.info(f'Batch mode: Applying 1 prompt to {n_images} images')
101
+ examples = [([image], prompt) for image in images]
102
+ elif n_images == 1 and n_prompts > 1:
103
+ image = images[0]
104
+ log.info(f'\nBatch mode: Applying {n_prompts} prompts to 1 image')
105
+ examples = [([image], prompt) for prompt in prompts]
106
+ elif n_images > 1 and n_images == n_prompts:
107
+ log.info(f'\nBatch mode: Applying {n_prompts} prompts to {n_images} images')
108
+ examples = [([image], prompt) for image, prompt in zip(images, prompts)]
109
+ else:
110
+ log.error(
111
+ 'Batch mode requires either (multiple images + 1 prompt) or '
112
+ '(1 image + multiple prompts) or (multiple images + multiple prompts) '
113
+ 'with equal number of images and prompts. Got '
114
+ f'{n_images} images and {n_prompts} prompts'
115
+ )
116
+ sys.exit(1)
117
+ else:
118
+ if n_prompts > 1:
119
+ log.error(
120
+ 'Non-batch mode requires 1+ images and 1 prompt. Got '
121
+ f'{n_images} images and {n_prompts} prompts'
122
+ )
123
+ sys.exit(1)
124
+ examples = [(images, prompts[0])]
125
+
126
+ conversations = []
127
+ allimages = []
128
+ allprompts = []
129
+ ordinals = [
130
+ 'first', 'second', 'third', 'fourth', 'fifth',
131
+ 'sixth', 'seventh', 'eighth', 'ninth', 'tenth',
132
+ ]
133
+ for images, prompt in examples:
134
+ content = []
135
+ allimages.append(images)
136
+ allprompts.append(prompt)
137
+ if prompt_first:
138
+ content.append({'type': 'text', 'text': prompt})
139
+ if len(images) > 1 and image_labels:
140
+ for idx, img in enumerate(images):
141
+ ordinal = ordinals[idx] if idx < len(ordinals) else f'{idx+1}th'
142
+ image = images[idx]
143
+ descriptor = f'url: {image}'
144
+ if os.path.isfile(image):
145
+ descriptor = f'filename: {os.path.basename(image)}'
146
+ content.append({
147
+ 'type': 'text',
148
+ 'text': f'(this is the {ordinal} image, {descriptor})',
149
+ })
150
+ content.append({'type': 'image', 'image': img})
151
+ else:
152
+ content.extend([{'type': 'image', 'image': image} for image in images])
153
+ if not prompt_first:
154
+ content.append({'type': 'text', 'text': prompt})
155
+ conversations.append([{'role': 'user', 'content': content}])
156
+
157
+ return conversations, allimages, allprompts
158
+
159
+
160
+ def _token_usage_report(
161
+ inputs: Dict[str, Any],
162
+ n_images: int,
163
+ max_sequence_length: int,
164
+ special_image_token_ids: Dict[str, int],
165
+ ):
166
+ """Report token usage statistics in tree format."""
167
+ input_ids = inputs['input_ids']
168
+ attention_mask = inputs['attention_mask']
169
+
170
+ # Total tokens in sequence (non-padding)
171
+ total_tokens = attention_mask.sum().item()
172
+
173
+ # Count ALL image-related tokens directly from input_ids
174
+ image_patch_id = special_image_token_ids['image_patch_token_id']
175
+ image_start_id = special_image_token_ids['image_start_token_id']
176
+ image_end_id = special_image_token_ids['image_end_token_id']
177
+ image_column_token_id = special_image_token_ids['image_column_token_id']
178
+
179
+ num_patch = (input_ids == image_patch_id).sum().item()
180
+ num_start = (input_ids == image_start_id).sum().item()
181
+ num_end = (input_ids == image_end_id).sum().item()
182
+ num_col = (input_ids == image_column_token_id).sum().item()
183
+
184
+ # Total image tokens = all image-related special tokens
185
+ total_image_tokens = num_patch + num_start + num_end + num_col
186
+
187
+ # Pure text tokens (excluding all image-related tokens)
188
+ text_token_count = total_tokens - total_image_tokens
189
+
190
+ report = [
191
+ f'Input Context Window Layout (max: {max_sequence_length} tokens):',
192
+ f'├── Total: {total_tokens} tokens '
193
+ f'({((total_tokens / max_sequence_length) * 100):.1f}%)',
194
+ ]
195
+ # Count tokens per image by finding img_start and img_end boundaries
196
+ # Each image is delimited by img_start and img_end tokens
197
+ tokens_per_image_list = []
198
+
199
+ # Find all img_start and img_end positions in input_ids
200
+ start_positions = (input_ids == image_start_id).nonzero(
201
+ as_tuple=True
202
+ )[0].tolist()
203
+ end_positions = (input_ids == image_end_id).nonzero(as_tuple=True)[0].tolist()
204
+
205
+ if len(start_positions) > 0 and len(end_positions) > 0:
206
+ # Each image typically has 2 start and 2 end tokens
207
+ # Determine actual number of images in context
208
+ n_starts_per_image = 2 # typical case
209
+ n_images_in_context = len(start_positions) // n_starts_per_image
210
+
211
+ # Warn if not all images fit in context
212
+ if n_images_in_context < n_images:
213
+ log.warning(
214
+ f'Only {n_images_in_context}/{n_images} images fit in context window'
215
+ )
216
+
217
+ for idx in range(n_images):
218
+ if idx < n_images_in_context:
219
+ # Get the start and end indices for this image
220
+ start_idx_begin = idx * n_starts_per_image
221
+ end_idx_end = (idx + 1) * n_starts_per_image
222
+ if (
223
+ start_idx_begin < len(start_positions) and
224
+ end_idx_end <= len(end_positions)
225
+ ):
226
+ # First start position and last end position define the image span
227
+ first_start = start_positions[start_idx_begin]
228
+ last_end = end_positions[end_idx_end - 1]
229
+ # Count tokens from first start to last end (inclusive)
230
+ num_tokens = last_end - first_start + 1
231
+ tokens_per_image_list.append(num_tokens)
232
+ else:
233
+ tokens_per_image_list.append(0)
234
+ else:
235
+ # Image didn't fit in context
236
+ tokens_per_image_list.append(0)
237
+ else:
238
+ # Fallback to uniform division if we can't find boundaries
239
+ tokens_per_image = total_image_tokens // n_images if n_images > 0 else 0
240
+ tokens_per_image_list = [tokens_per_image] * n_images
241
+
242
+ for idx in range(n_images):
243
+ n_tokens = tokens_per_image_list[idx] if idx < len(tokens_per_image_list) else 0
244
+ pct = (n_tokens / max_sequence_length * 100)
245
+ report.append(f'├── Image {idx + 1} → {n_tokens} tokens ({pct:.1f}%)')
246
+
247
+ text_pct = (text_token_count / max_sequence_length * 100)
248
+ report.append(f'└── Text: {text_token_count} tokens ({text_pct:.1f}%)')
249
+
250
+ return '\n'.join(report)
251
+
252
+
253
+ def test_jvlm():
254
+ parser = argparse.ArgumentParser(
255
+ description='jina-vlm-v1 vision-language model inference.'
256
+ )
257
+ parser.add_argument(
258
+ '-m',
259
+ '--model',
260
+ default='.',
261
+ help=(
262
+ 'Model path. Set this to "jinaai/jina-vlm-v1" if you are running this '
263
+ 'script outside this repo.'
264
+ )
265
+ )
266
+ parser.add_argument(
267
+ '-i',
268
+ '--image',
269
+ action='append',
270
+ help='Image path or glob pattern (can specify multiple times, e.g., "*.jpg").'
271
+ )
272
+ parser.add_argument(
273
+ '-p',
274
+ '--prompt',
275
+ action='append',
276
+ help='Text prompt (can specify multiple times with --map).',
277
+ )
278
+ parser.add_argument(
279
+ '--max-crops',
280
+ type=int,
281
+ default=12,
282
+ help='Maximum crops (default: 12).',
283
+ )
284
+ parser.add_argument(
285
+ '--max-tokens',
286
+ type=int,
287
+ default=1024,
288
+ help='Maximum output tokens (default: 1024).',
289
+ )
290
+ parser.add_argument(
291
+ '--max-pixels',
292
+ type=int,
293
+ default=None,
294
+ help=(
295
+ 'Max pixels per image, bigger images are resized and the aspect ratio is '
296
+ 'preserved (default: None).'
297
+ ),
298
+ )
299
+ parser.add_argument(
300
+ '--no-stream',
301
+ action='store_true',
302
+ help='Disable streaming (default: stream token-by-token)',
303
+ )
304
+ parser.add_argument(
305
+ '--image-labels',
306
+ action='store_true',
307
+ help=(
308
+ 'Enable ordinal text labels after each image '
309
+ '(default: no image labels for multi-image)'
310
+ ),
311
+ )
312
+ parser.add_argument(
313
+ '--prompt-first',
314
+ action='store_true',
315
+ help=(
316
+ 'Place prompt before images instead of after (default: prompt after images)'
317
+ ),
318
+ )
319
+ parser.add_argument(
320
+ '--batched',
321
+ action='store_true',
322
+ help=(
323
+ 'Batch mode: apply single prompt to multiple images (or single image to '
324
+ 'multiple prompts) with KV cache reuse.'
325
+ ),
326
+ )
327
+ args = parser.parse_args()
328
+
329
+ print()
330
+ print('Welcome to the jinaai/jina-vlm-v1 playground ✨')
331
+ print('Use this script to test our model!')
332
+ print('- Jina AI')
333
+ print()
334
+ print('--- Loading the model ...')
335
+ print('Specifying device, dtype and attention implementation ...')
336
+ device, dtype, attn_implementation = _resolve_device_dtype_and_attn()
337
+ print(f'Using attention implementation: {attn_implementation}')
338
+ print(f'Using device: {device}')
339
+ print(f'Using dtype: {dtype}')
340
+ print('Model path: ', args.model)
341
+ processor = AutoProcessor.from_pretrained(
342
+ args.model, trust_remote_code=True, use_fast=False,
343
+ )
344
+ model = AutoModelForCausalLM.from_pretrained(
345
+ args.model,
346
+ trust_remote_code=True,
347
+ dtype=dtype,
348
+ low_cpu_mem_usage=True,
349
+ device_map=device.type,
350
+ attn_implementation=attn_implementation,
351
+ )
352
+ max_sequence_length = getattr(model.config, 'max_sequence_length', 40960)
353
+ n_params = sum(p.numel() for p in model.parameters())
354
+ print(f'Max sequence length: {max_sequence_length}')
355
+ print(f'Number of parameters: {n_params}')
356
+ print('Done ✅')
357
+ print()
358
+
359
+ print('--- Let\'s create some conversations ...')
360
+ conversations, images, prompts = _build_conversations(
361
+ args.image,
362
+ args.prompt,
363
+ batched=args.batched,
364
+ prompt_first=args.prompt_first,
365
+ image_labels=args.image_labels
366
+ )
367
+ n_conversations = len(conversations)
368
+ print(f'Built {n_conversations} conversations 🚀')
369
+ print()
370
+
371
+ print('--- Transforming conversations to numbers ...')
372
+ timer = Timer()
373
+ with timer:
374
+ texts = processor.apply_chat_template(conversations, add_generation_prompt=True)
375
+ print(texts)
376
+ print(images)
377
+ inputs = processor(
378
+ text=texts,
379
+ images=images,
380
+ padding='longest',
381
+ max_length=max_sequence_length,
382
+ max_crops=args.max_crops,
383
+ max_pixels=args.max_pixels,
384
+ do_resize=True if args.max_pixels is not None else False,
385
+ return_tensors='pt',
386
+ )
387
+ print(inputs['images'])
388
+ print(inputs['image_input_idx'])
389
+ texts = texts if isinstance(texts, list) else [texts]
390
+ device_inputs = {}
391
+ for k, v in inputs.items():
392
+ if k == 'labels':
393
+ continue
394
+ if isinstance(v, torch.Tensor):
395
+ if v.is_floating_point():
396
+ device_inputs[k] = v.to(device, dtype=dtype, non_blocking=True)
397
+ else:
398
+ device_inputs[k] = v.to(device, non_blocking=True)
399
+ else:
400
+ device_inputs[k] = v
401
+
402
+ processing_time = timer.readout
403
+ special_image_token_ids = {
404
+ 'image_patch_token_id': processor.image_patch_token_id,
405
+ 'image_start_token_id': processor.image_start_token_id,
406
+ 'image_end_token_id': processor.image_end_token_id,
407
+ 'image_column_token_id': processor.image_column_token_id,
408
+ }
409
+ token_usage_reports = []
410
+ for idx in range(n_conversations):
411
+ ith_inputs = {k: v[idx] for k, v in inputs.items()}
412
+ token_usage_report = _token_usage_report(
413
+ ith_inputs,
414
+ len(images[idx]),
415
+ max_sequence_length=max_sequence_length,
416
+ special_image_token_ids=special_image_token_ids,
417
+ )
418
+ token_usage_reports.append(token_usage_report)
419
+ print(f'Processed {n_conversations} conversations in {processing_time}s')
420
+ print('All done 🪄')
421
+ print()
422
+
423
+ print('--- Running inference ...')
424
+ generated_tokens = 0
425
+ input_prompts = inputs['input_ids']
426
+
427
+ if args.no_stream:
428
+ print('Non-streaming mode')
429
+ print('Inference will run in a batch')
430
+ print()
431
+
432
+ with (
433
+ timer,
434
+ torch.no_grad(),
435
+ torch.autocast(device.type, enabled=(device.type != 'mps'), dtype=dtype),
436
+ ):
437
+ output = model.generate(
438
+ **device_inputs,
439
+ generation_config=GenerationConfig(
440
+ max_new_tokens=args.max_tokens, do_sample=False,
441
+ ),
442
+ return_dict_in_generate=True,
443
+ use_model_defaults=True,
444
+ )
445
+ generation_time, generation_readout = timer.time, timer.readout
446
+
447
+ for idx in range(n_conversations):
448
+ out = output.sequences[idx][len(input_prompts[idx].tolist()):]
449
+ generated_tokens += len(out)
450
+ response = processor.tokenizer.decode(out, skip_special_tokens=True)
451
+ print(f'* Conversation {idx + 1}/{n_conversations}')
452
+ print(f'├── 🖼️Images: {images[idx]}')
453
+ print(f'├── 📜Prompt: {prompts[idx]}')
454
+ print(f'├── 💬Chat: {texts[idx]}')
455
+ print(f'└── 🧠Response:{response}')
456
+ print('Token usage report:')
457
+ print(token_usage_reports[idx])
458
+ print()
459
+ else:
460
+ print('Streaming mode')
461
+ print('Inference will run sequentially')
462
+ print()
463
+
464
+ streamer = TextStreamer(
465
+ processor.tokenizer, skip_prompt=True, skip_special_tokens=True
466
+ )
467
+ for idx in range(n_conversations):
468
+ print(f'* Conversation {idx + 1}/{n_conversations}')
469
+ print(f'├── 🖼️Images: {images[idx]}')
470
+ print(f'├── 📜Prompt: {prompts[idx]}')
471
+ print(f'├── 💬Chat: {texts[idx]}')
472
+ print(f'└── 🧠Response:', end='')
473
+ ith_inputs = {k: v[idx].unsqueeze(0) for k, v in device_inputs.items()}
474
+ with (
475
+ timer,
476
+ torch.no_grad(),
477
+ torch.autocast(device.type, enabled=(device.type != 'mps'), dtype=dtype)
478
+ ):
479
+ output = model.generate(
480
+ **ith_inputs,
481
+ streamer=streamer,
482
+ generation_config=GenerationConfig(
483
+ max_new_tokens=args.max_tokens, do_sample=False,
484
+ ),
485
+ return_dict_in_generate=True,
486
+ use_model_defaults=True,
487
+ )
488
+ out = output.sequences[0][len(input_prompts[idx].tolist()):]
489
+ generated_tokens += len(out)
490
+ print('Token usage report:')
491
+ print(token_usage_reports[idx])
492
+ print()
493
+
494
+ generation_time, generation_readout = timer.time, timer.readout
495
+
496
+ res_per_sec = n_conversations / generation_time if generation_time > 0 else 0
497
+ tok_per_sec = generated_tokens / generation_time if generation_time > 0 else 0
498
+ print(f'Generated {n_conversations} responses in {generation_readout}s')
499
+ print(f'{res_per_sec:.2f} res/s {tok_per_sec:.2f} tok/s')
500
+ print('Done ✅')
501
+
502
+
503
+ if __name__ == '__main__':
504
+ test_jvlm()