Model update

- README1.md +114 -0
- infer.py +499 -0
- infer_utils.py +247 -0
- modeling_jvlm.py +1 -1
- test_jvlm.py +504 -0
README1.md
ADDED
@@ -0,0 +1,114 @@
# JVLM - Jina Vision Language Model

Minimal inference script for JVLM with streaming output and batch processing.

## Quick Start

```bash
python infer.py -i test_image.jpg -p "Describe the image"
```

## Requirements

```bash
uv sync
```

Optional extras:
```bash
uv sync --extra accelerate   # recommended for automatic device selection
uv sync --extra tensorflow   # only needed for tensorflow resize methods
```

## Usage

### CLI

```bash
# Single image (streaming)
python infer.py -i photo.jpg -p "What's in this image?"

# Remote image URL
python infer.py -i https://example.com/image.jpg -p "Describe this"

# Multiple images (local and remote)
python infer.py -i img1.jpg -i https://example.com/img2.jpg -i img3.jpg -p "Compare these images"

# Glob pattern support (quote patterns to prevent shell expansion)
python infer.py -i "*.jpg" -p "Describe"
python infer.py -i "photos/*.png" -i "images/*.jpg" -p "What do you see?"

# Non-streaming
python infer.py -i photo.jpg -p "What's in this image?" --no-stream

# Custom model
python infer.py -m /path/to/model -i image.png -p "Describe the scene"

# Custom max tokens
python infer.py -i photo.jpg -p "Explain in detail" --max-tokens 2048

# Prompt position control
python infer.py -i photo.jpg -p "Describe" --prompt-first

# Map mode: apply one prompt to multiple images
python infer.py --map -i "*.jpg" -p "What is this?"

# Map mode: apply multiple prompts to one image
python infer.py --map -i photo.jpg -p "What breed?" -p "What color?" -p "Happy or sad?"
```

**Options:**
- `-i, --image`: image path, URL, or glob pattern (can specify multiple times, default: test_image.jpg)
- `-p, --prompt`: text prompt (can specify multiple times with --map, default: "Describe the image for me in 100 words")
- `-m, --model`: model path (default: ".")
- `--max-tokens`: maximum output tokens (default: 1024)
- `--no-stream`: disable streaming (default: stream token-by-token)
- `--no-image-labels`: disable ordinal labels for multi-image inputs (default: enabled)
- `--prompt-first`: place prompt before images instead of after (may affect output quality)
- `--map`: map mode - apply single prompt to multiple images OR multiple prompts to single image

### Python

```python
from PIL import Image
import torch
from transformers import AutoModelForCausalLM, AutoProcessor

processor = AutoProcessor.from_pretrained(".", trust_remote_code=True, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    ".", trust_remote_code=True, dtype=torch.bfloat16,
    device_map="auto"
)

device = next(model.parameters()).device

image = Image.open("test_image.jpg")
inputs = [{
    'role': 'user',
    'content': [{'type': 'image', 'image': image}, {'type': 'text', 'text': "Describe this"}]
}]

messages, images = processor.apply_chat_template(inputs, add_generation_prompt=True)
processed_inputs = processor(messages=messages, images=images)
batched_inputs = processor.collate([processed_inputs], max_sequence_length=4096)

# Move to device
device_inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                 for k, v in batched_inputs.items()}

# Streaming generation
with torch.no_grad(), torch.autocast(device.type, dtype=torch.bfloat16):
    for token_id in model.stream_generate(
        input_ids=device_inputs['input_ids'],
        images=device_inputs['images'],
        image_masks=device_inputs['image_masks'],
        image_input_idx=device_inputs['image_input_idx'],
        max_new_tokens=256,
    ):
        text = processor.tokenizer.decode([token_id], skip_special_tokens=True)
        print(text, end='', flush=True)
```

## Notes

Streaming uses the `stream_generate()` method with a KV cache in `modeling_jvlm.py`.
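A minimal non-streaming sketch, reusing `device_inputs` from the Python example above. It assumes the model's custom `generate()` accepts the same image arguments as `stream_generate()` and returns an object with a `token_ids` tensor, which is how `infer_utils.py` consumes it:

```python
# Non-streaming generation (sketch): decode the whole response at once.
# Assumes generate() returns an output object with token_ids of shape
# [batch, beams, seq_len], as used in infer_utils.py.
with torch.no_grad(), torch.autocast(device.type, dtype=torch.bfloat16):
    outputs = model.generate(
        input_ids=device_inputs['input_ids'],
        images=device_inputs['images'],
        image_masks=device_inputs['image_masks'],
        image_input_idx=device_inputs['image_input_idx'],
        max_new_tokens=256,
    )
response = processor.tokenizer.decode(outputs.token_ids[0, 0], skip_special_tokens=True)
print(response)
```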
infer.py
ADDED
@@ -0,0 +1,499 @@
import argparse
import glob
import logging
import os
import sys
from time import perf_counter
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse

import torch
from transformers import (
    AutoModelForCausalLM, AutoProcessor, GenerationConfig, TextStreamer
)
from transformers.utils import is_flash_attn_2_available


os.environ['TOKENIZERS_PARALLELISM'] = 'false'
TEST_IMAGE = './assets/the_persistence_of_memory.jpg'

log = logging.getLogger(__name__)


class Timer:
    def __enter__(self):
        self.start = perf_counter()
        self.readout = None
        return self

    def __exit__(self, *_, **__):
        self.time = perf_counter() - self.start
        self.readout = f'{self.time:.3f}'


def _resolve_device_dtype_and_attn() -> Tuple[torch.device, torch.dtype, str]:
    if torch.cuda.is_available():
        device = torch.device('cuda')
        if is_flash_attn_2_available():
            dtype = torch.bfloat16
            attn_implementation = 'flash_attention_2'
        else:
            dtype = torch.float16
            attn_implementation = 'sdpa'
    else:
        if torch.backends.mps.is_available():
            device = torch.device('mps')
        else:
            device = torch.device('cpu')
        dtype = torch.float32
        attn_implementation = 'sdpa'

    return device, dtype, attn_implementation


def _build_conversations(
    images: Optional[List[str]],
    prompts: Optional[List[str]],
    batched: bool = False,
    prompt_first: bool = False,
    image_labels: bool = False,
):
    def _is_url(_path: str) -> bool:
        try:
            result = urlparse(_path)
            return result.scheme in ('http', 'https')
        except Exception:
            return False

    images = images or []
    expanded_image_paths = []
    for path in images:
        if _is_url(path):
            expanded_image_paths.append(path)
        elif any(char in path for char in ['*', '?', '[', ']']):
            matched_files = glob.glob(path)
            if matched_files:
                expanded_image_paths.extend(sorted(matched_files))
            else:
                log.warning(f'No files matched pattern "{path}"')
        else:
            expanded_image_paths.append(path)
    images = expanded_image_paths or [TEST_IMAGE]
    n_images = len(images)

    if prompts is None:
        prompts = (
            ['Describe the image in 100 words'] if n_images == 1 or batched else
            ['Describe the images in 100 words']
        )
    n_prompts = len(prompts)

    if n_images == 1 and n_prompts == 1:
        examples = [([images[0]], prompts[0])]
    elif batched:
        if n_images > 1 and n_prompts == 1:
            prompt = prompts[0]
            log.info(f'Batch mode: Applying 1 prompt to {n_images} images')
            examples = [([image], prompt) for image in images]
        elif n_images == 1 and n_prompts > 1:
            image = images[0]
            log.info(f'\nBatch mode: Applying {n_prompts} prompts to 1 image')
            examples = [([image], prompt) for prompt in prompts]
        elif n_images > 1 and n_images == n_prompts:
            log.info(f'\nBatch mode: Applying {n_prompts} prompts to {n_images} images')
            examples = [([image], prompt) for image, prompt in zip(images, prompts)]
        else:
            log.error(
                'Batch mode requires either (multiple images + 1 prompt) or '
                '(1 image + multiple prompts) or (multiple images + multiple prompts) '
                'with equal number of images and prompts. Got '
                f'{n_images} images and {n_prompts} prompts'
            )
            sys.exit(1)
    else:
        if n_prompts > 1:
            log.error(
                'Non-batch mode requires 1+ images and 1 prompt. Got '
                f'{n_images} images and {n_prompts} prompts'
            )
            sys.exit(1)
        examples = [(images, prompts[0])]

    conversations = []
    allimages = []
    allprompts = []
    ordinals = [
        'first', 'second', 'third', 'fourth', 'fifth',
        'sixth', 'seventh', 'eighth', 'ninth', 'tenth',
    ]
    for images, prompt in examples:
        content = []
        allimages.append(images)
        allprompts.append(prompt)
        if prompt_first:
            content.append({'type': 'text', 'text': prompt})
        if len(images) > 1 and image_labels:
            for idx, img in enumerate(images):
                ordinal = ordinals[idx] if idx < len(ordinals) else f'{idx + 1}th'
                image = images[idx]
                descriptor = f'url: {image}'
                if os.path.isfile(image):
                    descriptor = f'filename: {os.path.basename(image)}'
                content.append({
                    'type': 'text',
                    'text': f'(this is the {ordinal} image, {descriptor})',
                })
                content.append({'type': 'image', 'image': img})
        else:
            content.extend([{'type': 'image', 'image': image} for image in images])
        if not prompt_first:
            content.append({'type': 'text', 'text': prompt})
        conversations.append({'role': 'user', 'content': content})

    return conversations, allimages, allprompts


def _token_usage_report(
    inputs: Dict[str, Any],
    images: List[Any],
    max_sequence_length: int,
    special_image_token_ids: Dict[str, int],
):
    """Report token usage statistics in tree format."""
    n_images = len(images)
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']

    # Total tokens in sequence (non-padding)
    total_tokens = attention_mask.sum().item()

    # Count ALL image-related tokens directly from input_ids
    image_patch_id = special_image_token_ids['image_patch_token_id']
    image_start_id = special_image_token_ids['image_start_token_id']
    image_end_id = special_image_token_ids['image_end_token_id']
    image_col_id = special_image_token_ids['image_col_token_id']

    num_patch = (input_ids[0] == image_patch_id).sum().item()
    num_start = (input_ids[0] == image_start_id).sum().item()
    num_end = (input_ids[0] == image_end_id).sum().item()
    num_col = (input_ids[0] == image_col_id).sum().item()

    # Total image tokens = all image-related special tokens
    total_image_tokens = num_patch + num_start + num_end + num_col

    # Pure text tokens (excluding all image-related tokens)
    text_token_count = total_tokens - total_image_tokens

    report = [
        f'Input Context Window Layout (max: {max_sequence_length} tokens):',
        f'├── Total: {total_tokens} tokens '
        f'({((total_tokens / max_sequence_length) * 100):.1f}%)',
    ]
    # Count tokens per image by finding img_start and img_end boundaries
    # Each image is delimited by img_start and img_end tokens
    tokens_per_image_list = []

    # Find all img_start and img_end positions in input_ids
    start_positions = (input_ids[0] == image_start_id).nonzero(
        as_tuple=True
    )[0].tolist()
    end_positions = (input_ids[0] == image_end_id).nonzero(as_tuple=True)[0].tolist()

    if len(start_positions) > 0 and len(end_positions) > 0:
        # Each image typically has 2 start and 2 end tokens
        # Determine actual number of images in context
        n_starts_per_image = 2  # typical case
        n_images_in_context = len(start_positions) // n_starts_per_image

        # Warn if not all images fit in context
        if n_images_in_context < n_images:
            log.warning(
                f'Only {n_images_in_context}/{n_images} images fit in context window'
            )

        for idx in range(n_images):
            if idx < n_images_in_context:
                # Get the start and end indices for this image
                start_idx_begin = idx * n_starts_per_image
                end_idx_end = (idx + 1) * n_starts_per_image
                if (
                    start_idx_begin < len(start_positions) and
                    end_idx_end <= len(end_positions)
                ):
                    # First start position and last end position define the image span
                    first_start = start_positions[start_idx_begin]
                    last_end = end_positions[end_idx_end - 1]
                    # Count tokens from first start to last end (inclusive)
                    num_tokens = last_end - first_start + 1
                    tokens_per_image_list.append(num_tokens)
                else:
                    tokens_per_image_list.append(0)
            else:
                # Image didn't fit in context
                tokens_per_image_list.append(0)
    else:
        # Fallback to uniform division if we can't find boundaries
        tokens_per_image = total_image_tokens // n_images if n_images > 0 else 0
        tokens_per_image_list = [tokens_per_image] * n_images

    for idx in range(n_images):
        img = images[idx]
        n_tokens = tokens_per_image_list[idx] if idx < len(tokens_per_image_list) else 0
        pct = (n_tokens / max_sequence_length * 100)
        report.append(
            f'├── Image {idx + 1}: {img.width}x{img.height} → {n_tokens} '
            f'tokens ({pct:.1f}%)'
        )

    text_pct = (text_token_count / max_sequence_length * 100)
    report.append(f'└── Text: {text_token_count} tokens ({text_pct:.1f}%)')

    return '\n'.join(report)


def test_jvlm():
    parser = argparse.ArgumentParser(
        description='jina-vlm-v1 vision-language model inference.'
    )
    parser.add_argument(
        '-m',
        '--model',
        default='.',
        help=(
            'Model path. Set this to "jinaai/jina-vlm-v1" if you are running this '
            'script outside this repo.'
        )
    )
    parser.add_argument(
        '-i',
        '--image',
        action='append',
        help='Image path or glob pattern (can specify multiple times, e.g., "*.jpg").'
    )
    parser.add_argument(
        '-p',
        '--prompt',
        action='append',
        help='Text prompt (can specify multiple times with --map).',
    )
    parser.add_argument(
        '--max-crops',
        type=int,
        default=12,
        help='Maximum crops (default: 12).',
    )
    parser.add_argument(
        '--max-tokens',
        type=int,
        default=1024,
        help='Maximum output tokens (default: 1024).',
    )
    parser.add_argument(
        '--max-pixels',
        type=int,
        default=None,
        help=(
            'Max pixels per image, bigger images are resized and the aspect ratio is '
            'preserved (default: None).'
        ),
    )
    parser.add_argument(
        '--no-stream',
        action='store_true',
        help='Disable streaming (default: stream token-by-token)',
    )
    parser.add_argument(
        '--image-labels',
        action='store_true',
        help=(
            'Enable ordinal text labels after each image '
            '(default: no image labels for multi-image)'
        ),
    )
    parser.add_argument(
        '--prompt-first',
        action='store_true',
        help=(
            'Place prompt before images instead of after (default: prompt after images)'
        ),
    )
    parser.add_argument(
        '--batched',
        action='store_true',
        help=(
            'Batch mode: apply single prompt to multiple images (or single image to '
            'multiple prompts) with KV cache reuse.'
        ),
    )
    args = parser.parse_args()

    print('Welcome to the jinaai/jina-vlm-v1 playground ✨')
    print('Use this script to test our model!')
    print('- Jina AI')
    print()

    print('Model path: ', args.model)
    print('Loading the processor ...')
    processor = AutoProcessor.from_pretrained(
        args.model, trust_remote_code=True, use_fast=False,
    )
    print('Done ✅')
    print()

    print('Specifying device, dtype and attention implementation ...')
    device, dtype, attn_implementation = _resolve_device_dtype_and_attn()
    print(f'Using attention implementation: {attn_implementation}')
    print(f'Using device: {device}')
    print(f'Using dtype: {dtype}')
    print()

    print('Loading the model ...')
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        trust_remote_code=True,
        dtype=dtype,
        low_cpu_mem_usage=True,
        device_map=device.type,
        attn_implementation=attn_implementation,
    )
    max_sequence_length = getattr(model.config, 'max_sequence_length', 40960)
    n_params = sum(p.numel() for p in model.parameters())
    print(f'Max sequence length: {max_sequence_length}')
    print(f'Number of parameters: {n_params}')
    print('Done ✅')
    print()

    print('Let\'s create some conversations ...')
    conversations, images, prompts = _build_conversations(
        args.image,
        args.prompt,
        batched=args.batched,
        prompt_first=args.prompt_first,
        image_labels=args.image_labels
    )
    n_conversations = len(conversations)
    print(f'Built {n_conversations} conversations 🚀')
    print()

    print('Transforming conversations to numbers ...')
    timer = Timer()
    with timer:
        texts = processor.apply_chat_template(conversations, add_generation_prompt=True)
        inputs = processor(
            text=texts,
            images=images,
            padding='longest',
            max_length=max_sequence_length,
            max_crops=args.max_crops,
            max_pixels=args.max_pixels,
            do_resize=True if args.max_pixels is not None else False,
            return_tensors='pt',
        )
        device_inputs = {}
        for k, v in inputs.items():
            if k == 'labels':
                continue
            if isinstance(v, torch.Tensor):
                if v.is_floating_point():
                    device_inputs[k] = v.to(device, dtype=dtype, non_blocking=True)
                else:
                    device_inputs[k] = v.to(device, non_blocking=True)
            else:
                device_inputs[k] = v

    processing_time = timer.readout
    special_image_token_ids = {
        'image_patch_token_id': processor.image_processor.image_patch_token_id,
        'image_start_token_id': processor.image_processor.image_start_token_id,
        'image_end_token_id': processor.image_processor.image_end_token_id,
        'image_col_token_id': processor.image_processor.image_col_token_id,
    }
    token_usage_reports = []
    for idx in range(n_conversations):
        ith_inputs = {k: v[idx] for k, v in inputs.items()}
        token_usage_report = _token_usage_report(
            ith_inputs,
            images[idx],
            max_sequence_length=max_sequence_length,
            special_image_token_ids=special_image_token_ids,
        )
        token_usage_reports.append(token_usage_report)
    print(f'Processed {n_conversations} conversations in {processing_time}s')
    print('All done 🪄')
    print()

    print('Running inference ...')
    generated_tokens = 0
    input_prompts = inputs['input_ids']

    if args.no_stream:
        print('Non-streaming mode')
        print('Inference will run in a batch')
        print()

        with (
            timer,
            torch.no_grad(),
            torch.autocast(device.type, enabled=(device.type != 'mps'), dtype=dtype),
        ):
            output = model.generate(
                **device_inputs,
                generation_config=GenerationConfig(
                    max_new_tokens=args.max_tokens, do_sample=False,
                ),
            )
        generation_time, generation_readout = timer.time, timer.readout

        for idx in range(n_conversations):
            out = output.sequences[idx][len(input_prompts[idx].tolist()):]
            generated_tokens += len(out)
            response = processor.tokenizer.decode(out, skip_special_tokens=True)
            print(f'Conversation {idx + 1}/{n_conversations}')
            print(f'├── 🖼️Images: {images[idx]}')
            print(f'├── 📜Prompt: {prompts[idx]}')
            print(f'├── 💬Chat: {texts[idx]}')
            print(f'└── 🧠Response: {response}')
            print('Token usage report:')
            print(token_usage_reports[idx])
            print()
    else:
        print('Streaming mode')
        print('Inference will run sequentially')
        print()

        streamer = TextStreamer(processor.tokenizer)
        for idx in range(n_conversations):
            print(f'Conversation {idx + 1}/{n_conversations}')
            print(f'├── 🖼️Images: {images[idx]}')
            print(f'├── 📜Prompt: {prompts[idx]}')
            print(f'├── 💬Chat: {texts[idx]}')
            print(f'└── 🧠Response: ')
            ith_inputs = {k: v[idx].unsqueeze(0) for k, v in device_inputs.items()}
            with (
                timer,
                torch.no_grad(),
                torch.autocast(device.type, enabled=(device.type != 'mps'), dtype=dtype)
            ):
                output = model.generate(
                    **ith_inputs,
                    streamer=streamer,
                    generation_config=GenerationConfig(
                        max_new_tokens=args.max_tokens, do_sample=False,
                    ),
                )
            out = output.sequences[0][len(input_prompts[idx].tolist()):]
            generated_tokens += len(out)
            print('Token usage report:')
            print(token_usage_reports[idx])
            print()

        generation_time, generation_readout = timer.time, timer.readout

    res_per_sec = n_conversations / generation_time if generation_time > 0 else 0
    tok_per_sec = generated_tokens / generation_time if generation_time > 0 else 0
    print('Done ✅')
    print(f'Generated {n_conversations} responses in {generation_readout}s')
    print(f'{res_per_sec:.2f} res/s {tok_per_sec:.2f} tok/s')


if __name__ == '__main__':
    test_jvlm()
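The `--batched` pairing rules implemented by `_build_conversations` above reduce to three cases. A standalone sketch of just that pairing step (the `pair_examples` helper is hypothetical, for illustration only; the real function also expands globs, builds chat content, and exits on invalid combinations):

```python
# Sketch of the --batched pairing logic from infer.py's _build_conversations.
def pair_examples(images, prompts):
    if len(images) > 1 and len(prompts) == 1:
        # one prompt applied to every image
        return [([img], prompts[0]) for img in images]
    if len(images) == 1 and len(prompts) > 1:
        # every prompt applied to the single image
        return [([images[0]], p) for p in prompts]
    if len(images) == len(prompts):
        # element-wise pairing
        return [([img], p) for img, p in zip(images, prompts)]
    raise ValueError('unsupported images/prompts combination')

print(pair_examples(['a.jpg', 'b.jpg'], ['What is this?']))
# [(['a.jpg'], 'What is this?'), (['b.jpg'], 'What is this?')]
```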
infer_utils.py
ADDED
@@ -0,0 +1,247 @@
import os
import tempfile
import urllib.request
from urllib.parse import urlparse
import time
import torch


def is_url(path):
    """Check if a path is a URL"""
    try:
        result = urlparse(path)
        return result.scheme in ("http", "https")
    except Exception:
        return False


def download_image(url):
    """Download image from URL to temporary file"""
    try:
        # Create temp file with proper extension
        parsed = urlparse(url)
        ext = os.path.splitext(parsed.path)[1] or ".jpg"
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=ext)
        temp_path = temp_file.name
        temp_file.close()

        # Download image
        urllib.request.urlretrieve(url, temp_path)
        print(f"Downloaded image from: {url}")
        return temp_path
    except Exception as e:
        raise RuntimeError(f"Failed to download image from {url}: {e}")


def print_token_stats(batched_inputs, images_list, model, processor):
    """Print token usage statistics in tree format.

    Comment out the call to this function if you don't want to see the stats.
    """
    input_ids = batched_inputs['input_ids']
    max_ctx_len = model.config.max_sequence_length
    image_input_idx = batched_inputs.get('image_input_idx')

    # Total tokens in sequence (non-padding)
    valid_mask = input_ids[0] != -1
    total_tokens = valid_mask.sum().item()

    # Count ALL image-related tokens directly from input_ids
    image_patch_id = processor.image_preprocessor.image_patch_token_id
    image_start_id = processor.image_preprocessor.image_start_token_id
    image_end_id = processor.image_preprocessor.image_end_token_id
    image_col_id = processor.image_preprocessor.image_col_token_id

    num_patch = (input_ids[0] == image_patch_id).sum().item()
    num_start = (input_ids[0] == image_start_id).sum().item()
    num_end = (input_ids[0] == image_end_id).sum().item()
    num_col = (input_ids[0] == image_col_id).sum().item()

    # Total image tokens = all image-related special tokens
    total_image_tokens = num_patch + num_start + num_end + num_col

    # Pure text tokens (excluding all image-related tokens)
    text_token_count = total_tokens - total_image_tokens

    print("Input Context Window Layout (max: {} tokens):".format(max_ctx_len))
    print("└── Total: {} tokens ({:.1f}%)".format(
        total_tokens, (total_tokens / max_ctx_len) * 100))

    # Count tokens per image by finding img_start and img_end boundaries
    # Each image is delimited by img_start and img_end tokens
    tokens_per_image_list = []

    # Find all img_start and img_end positions in input_ids
    start_positions = (input_ids[0] == image_start_id).nonzero(as_tuple=True)[0].tolist()
    end_positions = (input_ids[0] == image_end_id).nonzero(as_tuple=True)[0].tolist()

    if len(start_positions) > 0 and len(end_positions) > 0:
        # Each image typically has 2 start and 2 end tokens
        # Determine actual number of images in context
        num_starts_per_image = 2  # typical case
        num_images_in_context = len(start_positions) // num_starts_per_image

        # Warn if not all images fit in context
        if num_images_in_context < len(images_list):
            print(f"Warning: Only {num_images_in_context}/{len(images_list)} images fit in context window")

        for img_idx in range(len(images_list)):
            if img_idx < num_images_in_context:
                # Get the start and end indices for this image
                start_idx_begin = img_idx * num_starts_per_image
                start_idx_end = (img_idx + 1) * num_starts_per_image
                end_idx_begin = img_idx * num_starts_per_image
                end_idx_end = (img_idx + 1) * num_starts_per_image

                if start_idx_begin < len(start_positions) and end_idx_end <= len(end_positions):
                    # First start position and last end position define the image span
                    first_start = start_positions[start_idx_begin]
                    last_end = end_positions[end_idx_end - 1]
                    # Count tokens from first start to last end (inclusive)
                    num_tokens = last_end - first_start + 1
                    tokens_per_image_list.append(num_tokens)
                else:
                    tokens_per_image_list.append(0)
            else:
                # Image didn't fit in context
                tokens_per_image_list.append(0)
    else:
        # Fallback to uniform division if we can't find boundaries
        tokens_per_image = total_image_tokens // len(images_list) if len(images_list) > 0 else 0
        tokens_per_image_list = [tokens_per_image] * len(images_list)

    for img_idx in range(len(images_list)):
        img = images_list[img_idx]
        num_tokens = tokens_per_image_list[img_idx] if img_idx < len(tokens_per_image_list) else 0
        pct = (num_tokens / max_ctx_len * 100)
        print("    ├── Image {}: {}x{} → {} tokens ({:.1f}%)".format(
            img_idx + 1, img.width, img.height, num_tokens, pct))

    # Show text last
    text_pct = (text_token_count / max_ctx_len * 100)
    print("    └── Text: {} tokens ({:.1f}%)".format(text_token_count, text_pct))
    print()


def is_chinese_char(cp):
    """Check if character is CJK"""
    if (
        (cp >= 0x4E00 and cp <= 0x9FFF)
        or (cp >= 0x3400 and cp <= 0x4DBF)
        or (cp >= 0x20000 and cp <= 0x2A6DF)
        or (cp >= 0x2A700 and cp <= 0x2B73F)
        or (cp >= 0x2B740 and cp <= 0x2B81F)
        or (cp >= 0x2B820 and cp <= 0x2CEAF)
        or (cp >= 0xF900 and cp <= 0xFAFF)
        or (cp >= 0x2F800 and cp <= 0x2FA1F)
    ):
        return True
    return False


def build_content(images, prompt, prompt_first):
    """Build content list with proper ordering."""
    content = []
    if prompt_first:
        content.append({"type": "text", "text": prompt})
        content.extend([{"type": "image", "image": img} for img in images])
    else:
        content.extend([{"type": "image", "image": img} for img in images])
        content.append({"type": "text", "text": prompt})
    return content


def generate_single(model, processor, content, args, device, prefer_mps, max_crops=12):
    """
    Generate output for a single image-prompt pair.

    Returns:
        text: Generated text
        elapsed_time: Time taken
        num_tokens: Number of tokens generated
    """
    inputs = [{"role": "user", "content": content}]
    messages, images = processor.apply_chat_template(inputs, add_generation_prompt=True)
    processed_inputs = processor(messages=messages, images=images)

    # Use model's max sequence length from config
    max_seq_len = getattr(model.config, 'max_sequence_length', 40960)
    batched_inputs = processor.collate(
        [processed_inputs], max_sequence_length=max_seq_len, max_crops=max_crops * len(images)
    )

    device_inputs = {}
    for k, v in batched_inputs.items():
        if isinstance(v, torch.Tensor):
            if prefer_mps and v.is_floating_point():
                device_inputs[k] = v.to(device, dtype=torch.float32, non_blocking=True)
            else:
                device_inputs[k] = v.to(device, non_blocking=True)
        else:
            device_inputs[k] = v

    with (
        torch.no_grad(),
        torch.autocast(device, enabled=(device != "mps"), dtype=torch.bfloat16),
    ):
        if args.no_stream:
            start_time = time.time()
            outputs = model.generate(
                input_ids=device_inputs["input_ids"],
                images=device_inputs.get("images"),
                image_masks=device_inputs.get("image_masks"),
                image_input_idx=device_inputs["image_input_idx"],
                max_new_tokens=args.max_tokens,
            )
            elapsed_time = time.time() - start_time
            text = processor.tokenizer.decode(
                outputs.token_ids[0, 0], skip_special_tokens=True
            )
            num_tokens = len(outputs.token_ids[0, 0])

            print(text)
        else:
            # Streaming mode
            token_cache = []
            print_len = 0
            token_count = 0
            start_time = time.time()
            for token_id in model.stream_generate(
                input_ids=device_inputs["input_ids"],
                position_ids=device_inputs.get("position_ids"),
                images=device_inputs.get("images"),
                image_masks=device_inputs.get("image_masks"),
                image_input_idx=device_inputs["image_input_idx"],
                max_new_tokens=args.max_tokens,
            ):
                token_cache.append(token_id)
                token_count += 1
                text = processor.tokenizer.decode(token_cache, skip_special_tokens=True)

                if text.endswith("\n"):
                    printable_text = text[print_len:]
                    token_cache = []
                    print_len = 0
                elif len(text) > 0 and is_chinese_char(ord(text[-1])):
                    printable_text = text[print_len:]
                    print_len += len(printable_text)
                else:
                    printable_text = text[print_len : text.rfind(" ") + 1]
                    print_len += len(printable_text)

                print(printable_text, end="", flush=True)

            elapsed_time = time.time() - start_time
            if print_len < len(text):
                print(text[print_len:], end="", flush=True)
            print()
            num_tokens = token_count

    tok_per_sec = num_tokens / elapsed_time if elapsed_time > 0 else 0
    print(f"{tok_per_sec:.2f} tok/s")

    return text, elapsed_time, num_tokens
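A hypothetical wiring of the helpers above, assuming `model` and `processor` are already loaded as in the README example; `args` only needs the `no_stream` and `max_tokens` attributes that `generate_single` reads, and the device is passed as a string as the function's autocast call expects:

```python
# Sketch: compose build_content() and generate_single() for one image-prompt pair.
from types import SimpleNamespace
from PIL import Image

args = SimpleNamespace(no_stream=False, max_tokens=256)  # minimal stand-in for argparse args
image = Image.open('test_image.jpg')
content = build_content([image], 'Describe this image', prompt_first=False)
text, elapsed, n_tokens = generate_single(
    model, processor, content, args,
    device='cuda', prefer_mps=False, max_crops=12,
)
```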
modeling_jvlm.py
CHANGED
@@ -492,7 +492,7 @@ class JinaVLM(JinaPreTrainedModel):
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
         image_features = None
-        if images is not None:
+        if images is not None and images.shape[1] > 0:
             image_out = self.vision_model(images, image_masks)
             image_features = image_out.last_hidden_state
         return self.language_model(
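The added `images.shape[1] > 0` check lets a text-only batch carry an images tensor with zero crops (rather than `None`) without invoking the vision tower. A small illustration of the guard; the tensor shape used here is an assumption for demonstration only:

```python
import torch

# Hypothetical text-only batch: an images tensor is present but has zero crops.
images = torch.empty(1, 0, 576, 588)  # assumed [batch, n_crops, n_patches, patch_dim]

if images is not None and images.shape[1] > 0:
    print('run the vision tower')
else:
    print('skip the vision tower; text-only forward')  # this branch is taken
```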
test_jvlm.py
ADDED
@@ -0,0 +1,504 @@
| 1 |
+
import argparse
|
| 2 |
+
import glob
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
from time import perf_counter
|
| 7 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 8 |
+
from urllib.parse import urlparse
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from transformers import (
|
| 12 |
+
AutoModelForCausalLM, AutoProcessor, GenerationConfig, TextStreamer
|
| 13 |
+
)
|
| 14 |
+
from transformers.utils import is_flash_attn_2_available
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
| 18 |
+
TEST_IMAGE = './assets/the_persistence_of_memory.jpg'
|
| 19 |
+
|
| 20 |
+
log = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class Timer:
|
| 24 |
+
def __enter__(self):
|
| 25 |
+
self.start = perf_counter()
|
| 26 |
+
self.readout = None
|
| 27 |
+
return self
|
| 28 |
+
|
| 29 |
+
def __exit__(self, *_, **__):
|
| 30 |
+
self.time = perf_counter() - self.start
|
| 31 |
+
self.readout = f'{self.time:.3f}'
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _resolve_device_dtype_and_attn() -> Tuple[torch.device, torch.dtype, str]:
|
| 35 |
+
if torch.cuda.is_available():
|
| 36 |
+
device = torch.device('cuda')
|
| 37 |
+
if is_flash_attn_2_available():
|
| 38 |
+
dtype = torch.bfloat16
|
| 39 |
+
attn_implementation = 'flash_attention_2'
|
| 40 |
+
else:
|
| 41 |
+
dtype = torch.float16
|
| 42 |
+
attn_implementation = 'sdpa'
|
| 43 |
+
else:
|
| 44 |
+
if torch.backends.mps.is_available():
|
| 45 |
+
device = torch.device('mps')
|
| 46 |
+
else:
|
| 47 |
+
device = torch.device('cpu')
|
| 48 |
+
dtype = torch.float32
|
| 49 |
+
attn_implementation = 'sdpa'
|
| 50 |
+
|
| 51 |
+
return device, dtype, attn_implementation
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _build_conversations(
|
| 55 |
+
images: Optional[List[str]],
|
| 56 |
+
prompts: Optional[List[str]],
|
| 57 |
+
batched: bool = False,
|
| 58 |
+
prompt_first: bool = False,
|
| 59 |
+
image_labels: bool = False,
|
| 60 |
+
):
|
| 61 |
+
def _is_url(_path: str) -> bool:
|
| 62 |
+
try:
|
| 63 |
+
result = urlparse(_path)
|
| 64 |
+
return result.scheme in ('http', 'https')
|
| 65 |
+
except:
|
| 66 |
+
return False
|
| 67 |
+
|
| 68 |
+
images = images or []
|
| 69 |
+
expanded_image_paths = []
|
| 70 |
+
for path in images:
|
| 71 |
+
if _is_url(path):
|
| 72 |
+
expanded_image_paths.append(path)
|
| 73 |
+
elif any(char in path for char in ['*', '?', '[', ']']):
|
| 74 |
+
matched_files = glob.glob(path)
|
| 75 |
+
if matched_files:
|
| 76 |
+
expanded_image_paths.extend(sorted(matched_files))
|
| 77 |
+
else:
|
| 78 |
+
log.warning(f'No files matched pattern "{path}"')
|
| 79 |
+
else:
|
| 80 |
+
expanded_image_paths.append(path)
|
| 81 |
+
images = expanded_image_paths
|
| 82 |
+
n_images = len(images)
|
| 83 |
+
if prompts is None:
|
| 84 |
+
if len(images) == 0:
|
| 85 |
+
images = [TEST_IMAGE]
|
| 86 |
+
n_images = len(images)
|
| 87 |
+
prompts = (
|
| 88 |
+
['Describe the image in 100 words'] if n_images == 1 or batched else
|
| 89 |
+
['Describe the images in 100 words']
|
| 90 |
+
)
|
| 91 |
+
n_prompts = len(prompts)
|
| 92 |
+
|
| 93 |
+
if n_images == 0:
|
| 94 |
+
examples = [([], prompt) for prompt in prompts]
|
| 95 |
+
elif n_images == 1 and n_prompts == 1:
|
| 96 |
+
examples = [([images[0]], prompts[0])]
|
| 97 |
+
elif batched:
|
| 98 |
+
if n_images > 1 and n_prompts == 1:
|
| 99 |
+
prompt = prompts[0]
|
| 100 |
+
log.info(f'Batch mode: Applying 1 prompt to {n_images} images')
|
| 101 |
+
examples = [([image], prompt) for image in images]
|
| 102 |
+
elif n_images == 1 and n_prompts > 1:
|
| 103 |
+
image = images[0]
|
| 104 |
+
log.info(f'\nBatch mode: Applying {n_prompts} prompts to 1 image')
|
| 105 |
+
examples = [([image], prompt) for prompt in prompts]
|
| 106 |
+
elif n_images > 1 and n_images == n_prompts:
|
| 107 |
+
log.info(f'\nBatch mode: Applying {n_prompts} prompts to {n_images} images')
|
| 108 |
+
examples = [([image], prompt) for image, prompt in zip(images, prompts)]
|
| 109 |
+
else:
|
| 110 |
+
log.error(
|
| 111 |
+
'Batch mode requires either (multiple images + 1 prompt) or '
|
| 112 |
+
'(1 image + multiple prompts) or (multiple images + multiple prompts) '
|
| 113 |
+
'with equal number of images and prompts. Got '
|
| 114 |
+
f'{n_images} images and {n_prompts} prompts'
|
| 115 |
+
)
|
| 116 |
+
sys.exit(1)
|
| 117 |
+
else:
|
| 118 |
+
if n_prompts > 1:
|
| 119 |
+
log.error(
|
| 120 |
+
'Non-batch mode requires 1+ images and 1 prompt. Got '
|
| 121 |
+
f'{n_images} images and {n_prompts} prompts'
|
| 122 |
+
)
|
| 123 |
+
sys.exit(1)
|
| 124 |
+
examples = [(images, prompts[0])]
|
| 125 |
+
|
| 126 |
+
conversations = []
|
| 127 |
+
allimages = []
|
| 128 |
+
allprompts = []
|
| 129 |
+
ordinals = [
|
| 130 |
+
'first', 'second', 'third', 'fourth', 'fifth',
|
| 131 |
+
'sixth', 'seventh', 'eighth', 'ninth', 'tenth',
|
| 132 |
+
]
|
| 133 |
+
for images, prompt in examples:
|
| 134 |
+
content = []
|
| 135 |
+
allimages.append(images)
|
| 136 |
+
allprompts.append(prompt)
|
| 137 |
+
if prompt_first:
|
| 138 |
+
content.append({'type': 'text', 'text': prompt})
|
| 139 |
+
if len(images) > 1 and image_labels:
|
| 140 |
+
for idx, img in enumerate(images):
|
| 141 |
+
ordinal = ordinals[idx] if idx < len(ordinals) else f'{idx+1}th'
|
| 142 |
+
image = images[idx]
|
| 143 |
+
descriptor = f'url: {image}'
|
| 144 |
+
if os.path.isfile(image):
|
| 145 |
+
descriptor = f'filename: {os.path.basename(image)}'
|
| 146 |
+
content.append({
|
| 147 |
+
'type': 'text',
|
| 148 |
+
'text': f'(this is the {ordinal} image, {descriptor})',
|
| 149 |
+
})
|
| 150 |
+
content.append({'type': 'image', 'image': img})
|
| 151 |
+
else:
|
| 152 |
+
content.extend([{'type': 'image', 'image': image} for image in images])
|
| 153 |
+
if not prompt_first:
|
| 154 |
+
content.append({'type': 'text', 'text': prompt})
|
| 155 |
+
conversations.append([{'role': 'user', 'content': content}])
|
| 156 |
+
|
| 157 |
+
return conversations, allimages, allprompts
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def _token_usage_report(
|
| 161 |
+
inputs: Dict[str, Any],
|
| 162 |
+
n_images: int,
|
| 163 |
+
max_sequence_length: int,
|
| 164 |
+
special_image_token_ids: Dict[str, int],
|
| 165 |
+
):
|
| 166 |
+
"""Report token usage statistics in tree format."""
|
| 167 |
+
input_ids = inputs['input_ids']
|
| 168 |
+
attention_mask = inputs['attention_mask']
|
| 169 |
+
|
| 170 |
+
# Total tokens in sequence (non-padding)
|
| 171 |
+
total_tokens = attention_mask.sum().item()
|
| 172 |
+
|
| 173 |
+
# Count ALL image-related tokens directly from input_ids
|
| 174 |
+
image_patch_id = special_image_token_ids['image_patch_token_id']
|
| 175 |
+
image_start_id = special_image_token_ids['image_start_token_id']
|
| 176 |
+
image_end_id = special_image_token_ids['image_end_token_id']
|
| 177 |
+
image_column_token_id = special_image_token_ids['image_column_token_id']
|
| 178 |
+
|
| 179 |
+
num_patch = (input_ids == image_patch_id).sum().item()
|
| 180 |
+
num_start = (input_ids == image_start_id).sum().item()
|
| 181 |
+
num_end = (input_ids == image_end_id).sum().item()
|
| 182 |
+
num_col = (input_ids == image_column_token_id).sum().item()
|
| 183 |
+
|
| 184 |
+
# Total image tokens = all image-related special tokens
|
| 185 |
+
total_image_tokens = num_patch + num_start + num_end + num_col
|
| 186 |
+
|
| 187 |
+
# Pure text tokens (excluding all image-related tokens)
|
| 188 |
+
text_token_count = total_tokens - total_image_tokens
|
| 189 |
+
|
| 190 |
+
report = [
|
| 191 |
+
f'Input Context Window Layout (max: {max_sequence_length} tokens):',
|
| 192 |
+
f'├── Total: {total_tokens} tokens '
|
| 193 |
+
f'({((total_tokens / max_sequence_length) * 100):.1f}%)',
|
| 194 |
+
]
|
| 195 |
+
# Count tokens per image by finding img_start and img_end boundaries
|
| 196 |
+
# Each image is delimited by img_start and img_end tokens
|
| 197 |
+
tokens_per_image_list = []
|
| 198 |
+
|
| 199 |
+
# Find all img_start and img_end positions in input_ids
|
| 200 |
+
start_positions = (input_ids == image_start_id).nonzero(
|
| 201 |
+
as_tuple=True
|
| 202 |
+
)[0].tolist()
|
| 203 |
+
end_positions = (input_ids == image_end_id).nonzero(as_tuple=True)[0].tolist()
|
| 204 |
+
|
| 205 |
+
if len(start_positions) > 0 and len(end_positions) > 0:
|
| 206 |
+
# Each image typically has 2 start and 2 end tokens
|
| 207 |
+
# Determine actual number of images in context
|
        n_starts_per_image = 2  # typical case
        n_images_in_context = len(start_positions) // n_starts_per_image

        # Warn if not all images fit in context
        if n_images_in_context < n_images:
            log.warning(
                f'Only {n_images_in_context}/{n_images} images fit in context window'
            )

        for idx in range(n_images):
            if idx < n_images_in_context:
                # Get the start and end indices for this image
                start_idx_begin = idx * n_starts_per_image
                end_idx_end = (idx + 1) * n_starts_per_image
                if (
                    start_idx_begin < len(start_positions) and
                    end_idx_end <= len(end_positions)
                ):
                    # First start position and last end position define the image span
                    first_start = start_positions[start_idx_begin]
                    last_end = end_positions[end_idx_end - 1]
                    # Count tokens from first start to last end (inclusive)
                    num_tokens = last_end - first_start + 1
                    tokens_per_image_list.append(num_tokens)
                else:
                    tokens_per_image_list.append(0)
            else:
                # Image didn't fit in context
                tokens_per_image_list.append(0)
    else:
        # Fallback to uniform division if we can't find boundaries
        tokens_per_image = total_image_tokens // n_images if n_images > 0 else 0
        tokens_per_image_list = [tokens_per_image] * n_images

    for idx in range(n_images):
        n_tokens = tokens_per_image_list[idx] if idx < len(tokens_per_image_list) else 0
        pct = n_tokens / max_sequence_length * 100
        report.append(f'├── Image {idx + 1} → {n_tokens} tokens ({pct:.1f}%)')

    text_pct = text_token_count / max_sequence_length * 100
    report.append(f'└── Text: {text_token_count} tokens ({text_pct:.1f}%)')

    return '\n'.join(report)


def test_jvlm():
    parser = argparse.ArgumentParser(
        description='jina-vlm-v1 vision-language model inference.'
    )
    parser.add_argument(
        '-m',
        '--model',
        default='.',
        help=(
            'Model path. Set this to "jinaai/jina-vlm-v1" if you are running this '
            'script outside this repo.'
        ),
    )
    parser.add_argument(
        '-i',
        '--image',
        action='append',
        help='Image path or glob pattern (can specify multiple times, e.g., "*.jpg").',
    )
    parser.add_argument(
        '-p',
        '--prompt',
        action='append',
        help='Text prompt (can specify multiple times with --batched).',
    )
    parser.add_argument(
        '--max-crops',
        type=int,
        default=12,
        help='Maximum crops (default: 12).',
    )
    parser.add_argument(
        '--max-tokens',
        type=int,
        default=1024,
        help='Maximum output tokens (default: 1024).',
    )
    parser.add_argument(
        '--max-pixels',
        type=int,
        default=None,
        help=(
            'Max pixels per image; bigger images are resized and the aspect ratio is '
            'preserved (default: None).'
        ),
    )
    parser.add_argument(
        '--no-stream',
        action='store_true',
        help='Disable streaming (default: stream token-by-token).',
    )
    parser.add_argument(
        '--image-labels',
        action='store_true',
        help=(
            'Enable ordinal text labels after each image '
            '(default: no image labels for multi-image).'
        ),
    )
    parser.add_argument(
        '--prompt-first',
        action='store_true',
        help=(
            'Place prompt before images instead of after (default: prompt after images).'
        ),
    )
    parser.add_argument(
        '--batched',
        action='store_true',
        help=(
            'Batch mode: apply single prompt to multiple images (or single image to '
            'multiple prompts) with KV cache reuse.'
        ),
    )
    args = parser.parse_args()
    print()
    print('Welcome to the jinaai/jina-vlm-v1 playground ✨')
    print('Use this script to test our model!')
    print('- Jina AI')
    print()
    print('--- Loading the model ...')
    print('Specifying device, dtype and attention implementation ...')
    device, dtype, attn_implementation = _resolve_device_dtype_and_attn()
    print(f'Using attention implementation: {attn_implementation}')
    print(f'Using device: {device}')
    print(f'Using dtype: {dtype}')
    print('Model path:', args.model)
    processor = AutoProcessor.from_pretrained(
        args.model, trust_remote_code=True, use_fast=False,
    )
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        trust_remote_code=True,
        dtype=dtype,
        low_cpu_mem_usage=True,
        device_map=device.type,
        attn_implementation=attn_implementation,
    )
    max_sequence_length = getattr(model.config, 'max_sequence_length', 40960)
    n_params = sum(p.numel() for p in model.parameters())
    print(f'Max sequence length: {max_sequence_length}')
    print(f'Number of parameters: {n_params}')
    print('Done ✅')
    print()

    print('--- Let\'s create some conversations ...')
    conversations, images, prompts = _build_conversations(
        args.image,
        args.prompt,
        batched=args.batched,
        prompt_first=args.prompt_first,
        image_labels=args.image_labels,
    )
    n_conversations = len(conversations)
    print(f'Built {n_conversations} conversations 🚀')
    print()

    print('--- Transforming conversations to numbers ...')
    timer = Timer()
    with timer:
        texts = processor.apply_chat_template(conversations, add_generation_prompt=True)
        print(texts)
        print(images)
        inputs = processor(
            text=texts,
            images=images,
            padding='longest',
            max_length=max_sequence_length,
            max_crops=args.max_crops,
            max_pixels=args.max_pixels,
            do_resize=args.max_pixels is not None,
            return_tensors='pt',
        )
        print(inputs['images'])
        print(inputs['image_input_idx'])
        texts = texts if isinstance(texts, list) else [texts]
        device_inputs = {}
        for k, v in inputs.items():
            if k == 'labels':
                continue
            if isinstance(v, torch.Tensor):
                if v.is_floating_point():
                    device_inputs[k] = v.to(device, dtype=dtype, non_blocking=True)
                else:
                    device_inputs[k] = v.to(device, non_blocking=True)
            else:
                device_inputs[k] = v

    processing_time = timer.readout
    special_image_token_ids = {
        'image_patch_token_id': processor.image_patch_token_id,
        'image_start_token_id': processor.image_start_token_id,
        'image_end_token_id': processor.image_end_token_id,
        'image_column_token_id': processor.image_column_token_id,
    }
    token_usage_reports = []
    for idx in range(n_conversations):
        ith_inputs = {k: v[idx] for k, v in inputs.items()}
        token_usage_report = _token_usage_report(
            ith_inputs,
            len(images[idx]),
            max_sequence_length=max_sequence_length,
            special_image_token_ids=special_image_token_ids,
        )
        token_usage_reports.append(token_usage_report)
    print(f'Processed {n_conversations} conversations in {processing_time}s')
    print('All done 🪄')
    print()
    print('--- Running inference ...')
    generated_tokens = 0
    input_prompts = inputs['input_ids']

    if args.no_stream:
        print('Non-streaming mode')
        print('Inference will run in a batch')
        print()

        with (
            timer,
            torch.no_grad(),
            torch.autocast(device.type, enabled=(device.type != 'mps'), dtype=dtype),
        ):
            output = model.generate(
                **device_inputs,
                generation_config=GenerationConfig(
                    max_new_tokens=args.max_tokens, do_sample=False,
                ),
                return_dict_in_generate=True,
                use_model_defaults=True,
            )
        generation_time, generation_readout = timer.time, timer.readout

        for idx in range(n_conversations):
            out = output.sequences[idx][len(input_prompts[idx].tolist()):]
            generated_tokens += len(out)
            response = processor.tokenizer.decode(out, skip_special_tokens=True)
            print(f'* Conversation {idx + 1}/{n_conversations}')
            print(f'├── 🖼️Images: {images[idx]}')
            print(f'├── 📜Prompt: {prompts[idx]}')
            print(f'├── 💬Chat: {texts[idx]}')
            print(f'└── 🧠Response: {response}')
            print('Token usage report:')
            print(token_usage_reports[idx])
            print()
    else:
        print('Streaming mode')
        print('Inference will run sequentially')
        print()

        streamer = TextStreamer(
            processor.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        for idx in range(n_conversations):
            print(f'* Conversation {idx + 1}/{n_conversations}')
            print(f'├── 🖼️Images: {images[idx]}')
            print(f'├── 📜Prompt: {prompts[idx]}')
            print(f'├── 💬Chat: {texts[idx]}')
            print('└── 🧠Response: ', end='')
            ith_inputs = {k: v[idx].unsqueeze(0) for k, v in device_inputs.items()}
            with (
                timer,
                torch.no_grad(),
                torch.autocast(device.type, enabled=(device.type != 'mps'), dtype=dtype)
            ):
                output = model.generate(
                    **ith_inputs,
                    streamer=streamer,
                    generation_config=GenerationConfig(
                        max_new_tokens=args.max_tokens, do_sample=False,
                    ),
                    return_dict_in_generate=True,
                    use_model_defaults=True,
                )
            out = output.sequences[0][len(input_prompts[idx].tolist()):]
            generated_tokens += len(out)
            print('Token usage report:')
            print(token_usage_reports[idx])
            print()

        generation_time, generation_readout = timer.time, timer.readout

    res_per_sec = n_conversations / generation_time if generation_time > 0 else 0
    tok_per_sec = generated_tokens / generation_time if generation_time > 0 else 0
    print(f'Generated {n_conversations} responses in {generation_readout}s')
    print(f'{res_per_sec:.2f} res/s {tok_per_sec:.2f} tok/s')
    print('Done ✅')


if __name__ == '__main__':
    test_jvlm()
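
For quick reference, here are two example invocations of the script above. The flag names come from the argument parser defined in this file; the image filenames are placeholders.

```bash
# Batched mode: apply one prompt to several images, non-streaming
python test_jvlm.py --batched -i img1.jpg -i img2.jpg -p "Describe what you see" --no-stream

# Single streaming conversation with ordinal image labels and a tighter crop budget
python test_jvlm.py -i page.png -p "Transcribe the text in this scan" --image-labels --max-crops 8
```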