| | from transformers import BaseImageProcessor, ImageProcessingMixin |
| | from transformers.processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs |
| | import math |
| | from typing import Iterable, Optional, Tuple, List, TypedDict, Literal, Union, overload |
| |
|
| | from PIL import Image |
| | import torch |
| | import numpy as np |
| | import torchvision |
| | from torch import nn |
| | from torch.nn import functional as F, LayerNorm |
| | from torchvision.transforms.functional import InterpolationMode |
| | from transformers.activations import ACT2FN |
| | from torchvision import transforms |
| | from torchvision.transforms.functional import InterpolationMode |
| | from transformers.feature_extraction_utils import BatchFeature, TensorType |
| | from transformers.image_utils import ImageInput |
| | from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack |
| | from math import ceil |
| | from itertools import product |
| |
|
| |
|
| |
|
# Longest allowed image edge in pixels; larger images are scaled down
# (see ImagePatcher.get_image_size_for_preprocess).
MAX_IMAGE_SIZE: int = 3024
| |
|
class Step3VLImagePixelInputs(TypedDict):
    """Pixel-level image inputs for the Step3VL model."""
    # Discriminator for the input union.
    type: Literal["pixel_values"]
    # Full-image tensors, concatenated along the batch dimension.
    pixel_values: torch.Tensor
    # Crop-patch tensors; absent when no image produced patches.
    patch_pixel_values: Optional[torch.Tensor]
    # Per-image patch counts, aligned with pixel_values order.
    num_patches: list[int]
| |
|
| |
|
class Step3VLImageEmbeddingInputs(TypedDict):
    """Precomputed image-embedding inputs for the Step3VL model."""
    # Discriminator for the input union.
    type: Literal["image_embeds"]
    # Image features already projected to the language-model space.
    image_embeds: torch.Tensor
| |
|
| |
|
# (resized full image, crop patches, per-patch newline mask or None).
# NOTE(review): ImagePatcher.__call__ actually yields list[bool] for the
# third element; the list[int] annotation here looks stale — confirm.
ImageWithPatches = tuple[Image.Image, list[Image.Image], list[int] | None]
| |
|
| |
|
class GPUToTensor(torch.nn.Module):
    """Convert a PIL image or HWC numpy array to a CHW float tensor.

    Numpy inputs are moved to CUDA when available (hence the name) and
    uint8 data is rescaled to [0, 1]; PIL inputs take the standard
    torchvision ``ToTensor`` path and stay on the CPU.
    """

    def forward(self, raw_image: Union[np.ndarray,
                                       Image.Image]) -> torch.Tensor:
        # PIL images: delegate to torchvision (handles mode/scaling).
        if isinstance(raw_image, Image.Image):
            return transforms.ToTensor()(raw_image)
        # Grayscale HxW -> HxWx3 by replicating along a new channel axis.
        if raw_image.ndim == 2:
            raw_image = raw_image[:, :, None].repeat(3, -1)
        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tensor = torch.from_numpy(raw_image).to(target)
        # HWC -> CHW, contiguous for the downstream transforms.
        tensor = torch.permute(tensor, (2, 0, 1)).contiguous()
        # uint8 [0, 255] -> float32 [0, 1]; other dtypes pass through.
        if tensor.dtype == torch.uint8:
            tensor = tensor.to(torch.float32).div(255)
        return tensor
| |
|
class Step3VisionProcessor(BaseImageProcessor):
    """Tensor pipeline for Step3 vision inputs.

    Builds two torchvision pipelines: one resizing to ``size`` for full
    images and one resizing to ``patch_size`` for crop patches. Both
    normalize with CLIP statistics before resizing.
    """

    def __init__(self, size, interpolation_mode="bicubic", patch_size=None):
        # CLIP normalization statistics.
        mean = [0.48145466, 0.4578275, 0.40821073]
        std = [0.26862954, 0.26130258, 0.27577711]
        # Patches fall back to the full-image resolution when not given,
        # so patch_size is never None past this point.
        patch_size = patch_size if patch_size is not None else size

        # Hoisted: the mode ternary was previously duplicated verbatim
        # in both Compose definitions.
        interpolation = (InterpolationMode.BICUBIC if interpolation_mode
                         == "bicubic" else InterpolationMode.BILINEAR)

        def build_pipeline(target_size):
            # NOTE: Normalize runs before Resize, mirroring the original
            # pipeline order; both operate on tensors, so the order only
            # affects interpolation rounding, not correctness.
            return transforms.Compose([
                GPUToTensor(),
                transforms.Normalize(mean, std),
                transforms.Resize((target_size, target_size),
                                  interpolation=interpolation,
                                  antialias=True),
            ])

        self.transform = build_pipeline(size)
        # The original trailing `if patch_size is not None else None` was
        # dead code (patch_size defaults to `size` above); the patch
        # pipeline is always constructed.
        self.patch_transform = build_pipeline(patch_size)
        # NOTE(review): super().__init__() is intentionally not called,
        # matching the original — confirm BaseImageProcessor state is not
        # needed here.

    def __call__(self, image, is_patch=False):
        """Return {"pixel_values": 1xCxHxW tensor} for one image.

        `is_patch` selects the patch-resolution pipeline.
        """
        pipeline = self.patch_transform if is_patch else self.transform
        return {"pixel_values": pipeline(image).unsqueeze(0)}
| |
|
class ImagePatcher:
    """Crops oversized images into square sliding-window patches.

    The patcher derives a window size from the image's aspect ratio and
    longest edge, resizes the image so whole windows tile it, then crops
    the windows row by row, also reporting which patches end a grid row.
    """

    def determine_window_size(self, long: int, short: int) -> int:
        """Return the sliding-window edge length, or 0 for "no patching".

        `long`/`short` are the longer/shorter image edges in pixels.
        """
        if long <= 728:
            # Small image: patch only when clearly elongated (> 1.5:1),
            # using the short edge as the window.
            return short if long / short > 1.5 else 0
        # Large image: 504px windows; very elongated images (> 4:1) are
        # capped at the short edge so a window never exceeds the image.
        return min(short, 504) if long / short > 4 else 504

    def slide_window(
        self,
        width: int,
        height: int,
        sizes: list[tuple[int, int]],
        steps: list[tuple[int, int]],
        img_rate_thr: float = 0.6,
    ) -> tuple[list[tuple[int, int, int, int]], tuple[int, int]]:
        """Enumerate sliding windows over a width x height canvas.

        Returns ([(x, y, w, h), ...], (x_num, y_num)), where x_num/y_num
        are the per-axis window counts of the LAST (size, step) pair
        processed. `img_rate_thr` is validated but otherwise unused.
        """
        assert 1 >= img_rate_thr >= 0, "The `in_rate_thr` should lie in 0~1"
        windows = []

        for size, step in zip(sizes, steps):
            size_w, size_h = size
            step_w, step_h = step

            # Window starts along x; the last start is clamped so the
            # final window stays inside the canvas.
            x_num = 1 if width <= size_w else ceil((width - size_w) / step_w +
                                                   1)
            x_start = [step_w * i for i in range(x_num)]
            if len(x_start) > 1 and x_start[-1] + size_w > width:
                x_start[-1] = width - size_w

            # Same clamping along y.
            y_num = 1 if height <= size_h else ceil((height - size_h) /
                                                    step_h + 1)
            y_start = [step_h * i for i in range(y_num)]
            if len(y_start) > 1 and y_start[-1] + size_h > height:
                y_start[-1] = height - size_h

            # product() yields (y, x) pairs; swap columns to (x, y), then
            # append the bottom-right corner (x, y) + (size_w, size_h).
            start = np.array(list(product(y_start, x_start)), dtype=int)
            start[:, [0, 1]] = start[:, [1, 0]]
            windows.append(np.concatenate([start, start + size], axis=1))
        windows = np.concatenate(windows, axis=0)

        # Convert (x1, y1, x2, y2) boxes to (x, y, w, h).
        return [(int(box[0]), int(box[1]), int(box[2] - box[0]),
                 int(box[3] - box[1])) for box in windows], (x_num, y_num)

    def square_pad(self, img: Image.Image) -> Image.Image:
        """Pad `img` with black to a square, anchoring it top-left."""
        w, h = img.size
        if w == h:
            return img
        size = max(w, h)
        padded = Image.new(img.mode, (size, size), 0)
        padded.paste(img, (0, 0))
        return padded

    def get_image_size_for_padding(self, img_width: int,
                                   img_height: int) -> tuple[int, int]:
        """Return a square target size for tiny, extremely elongated
        images (< 32px short edge and aspect > 4:1); others unchanged."""
        ratio = img_width / img_height
        if min(img_height, img_width) < 32 and (ratio > 4 or ratio < 1 / 4):
            new_size = max(img_height, img_width)
            return new_size, new_size
        return img_width, img_height

    def get_image_size_for_preprocess(self, img_width: int,
                                      img_height: int) -> tuple[int, int]:
        """Scale the size down so the longer edge fits MAX_IMAGE_SIZE."""
        if max(img_height, img_width) > MAX_IMAGE_SIZE:
            scale_factor = MAX_IMAGE_SIZE / max(img_height, img_width)
            img_width = int(img_width * scale_factor)
            img_height = int(img_height * scale_factor)
        return img_width, img_height

    def get_image_size_for_crop(self, img_width: int, img_height: int,
                                window_size: int):
        """Snap each edge to a multiple of `window_size`.

        Edges shorter than one window stay as-is; otherwise an edge is
        rounded up when the fractional overhang exceeds 20% of a window,
        else rounded down.
        """
        w_ratio = img_width / window_size
        h_ratio = img_height / window_size

        if w_ratio < 1:
            width_new = img_width
        else:
            decimal_w = w_ratio - img_width // window_size
            w_ratio = int(w_ratio) + 1 if decimal_w > 0.2 else int(w_ratio)
            width_new = window_size * w_ratio
        if h_ratio < 1:
            height_new = img_height
        else:
            decimal_h = h_ratio - img_height // window_size
            h_ratio = int(h_ratio) + 1 if decimal_h > 0.2 else int(h_ratio)
            height_new = window_size * h_ratio
        return int(width_new), int(height_new)

    def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int):
        """Crop a th x tw region whose top-left corner is (row=i, col=j)."""
        target = img.crop((j, i, j + tw, i + th))
        return target

    def get_num_patches(self, img_width: int,
                        img_height: int) -> tuple[int, int]:
        """Return (patch count, newline count) for an image size.

        Mirrors the geometry of __call__ without performing any crops.
        """
        img_width, img_height = self.get_image_size_for_padding(
            img_width, img_height)
        img_width, img_height = self.get_image_size_for_preprocess(
            img_width, img_height)
        window_size = self.determine_window_size(max(img_height, img_width),
                                                 min(img_height, img_width))
        if window_size == 0:
            return 0, 0
        else:
            img_width, img_height = self.get_image_size_for_crop(
                img_width, img_height, window_size)
            center_list, (x_num, y_num) = self.slide_window(
                img_width, img_height, [(window_size, window_size)],
                [(window_size, window_size)])
            # Rows of patches; when the grid ends exactly on a row
            # boundary the trailing newline is dropped, matching the
            # newlines.pop() in __call__.
            full_rows = (len(center_list) - 1) // x_num + 1
            if len(center_list) > 0 and len(center_list) % x_num == 0:
                full_rows -= 1
            return len(center_list), full_rows

    def __call__(
        self, img: Image.Image
    ) -> tuple[Image.Image, list[Image.Image], list[bool] | None]:
        """Patch `img`.

        Returns (resized full image, patch crops, newline mask); the
        mask marks patches that end a grid row (final row excluded) and
        is None when no patching is performed.
        """
        img_width, img_height = img.size
        new_img_width, new_img_height = self.get_image_size_for_padding(
            img_width, img_height)
        if new_img_width != img_width or new_img_height != img_height:
            img = self.square_pad(img)
            img_width, img_height = img.size

        new_img_width, new_img_height = self.get_image_size_for_preprocess(
            img_width, img_height)
        img = img.resize((new_img_width, new_img_height),
                         Image.Resampling.BILINEAR)
        window_size = self.determine_window_size(
            max(new_img_height, new_img_width),
            min(new_img_height, new_img_width))

        if window_size == 0:
            # Small/near-square image: the resized image alone suffices.
            return img, [], None
        else:
            new_img_width, new_img_height = self.get_image_size_for_crop(
                new_img_width, new_img_height, window_size)
            if (new_img_width, new_img_height) != (img_width, img_height):
                img_for_crop = img.resize((new_img_width, new_img_height),
                                          Image.Resampling.BILINEAR)
            else:
                img_for_crop = img

            patches = []
            newlines = []
            center_list, (x_num, y_num) = self.slide_window(
                new_img_width, new_img_height, [(window_size, window_size)],
                [(window_size, window_size)])
            for patch_id, center_lf_point in enumerate(center_list):
                x, y, patch_w, patch_h = center_lf_point
                # patch_crop takes (row, col) order, hence (y, x).
                big_patch = self.patch_crop(img_for_crop, y, x, patch_h,
                                            patch_w)
                patches.append(big_patch)
                # Mark the last patch of each grid row.
                if (patch_id + 1) % x_num == 0:
                    newlines.append(patch_id)

            # The very last patch never carries a newline marker.
            if newlines and newlines[-1] == len(patches) - 1:
                newlines.pop()

            return img, patches, [i in newlines for i in range(len(patches))] if len(patches) > 0 else None
| |
|
| |
|
| |
|
| |
|
class Step3VLProcessor(ProcessorMixin):
    """HF processor pairing the Step3 image pipeline with a tokenizer.

    Each `<im_patch>` placeholder in the input text is expanded into the
    full per-image token sequence: patch-feature spans (with start/end
    and optional newline markers) followed by the global image-feature
    span.
    """

    attributes = ["tokenizer"]
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        tokenizer=None,
        chat_template=None,
        **kwargs
    ) -> None:
        # Resolutions fed to the vision tower: full image and crop patch.
        self.image_size = 728
        self.patch_size = 504

        self.image_preprocessor = Step3VisionProcessor(self.image_size,
                                                       "bilinear",
                                                       self.patch_size)

        # Feature-token counts emitted per full image / per patch.
        self.num_image_feature_size = 169
        self.num_patch_feature_size = 81
        self.image_token = "<im_patch>"
        self.image_feature_placeholder = (self.image_token *
                                          self.num_image_feature_size)
        self.patch_feature_placeholder = (self.image_token *
                                          self.num_patch_feature_size)
        super().__init__(tokenizer=tokenizer, chat_template=chat_template, **kwargs)
        self.patcher = ImagePatcher()

    @property
    def image_token_id(self) -> int:
        # Resolved lazily so the tokenizer may be attached after init.
        return self.tokenizer.get_vocab()[self.image_token]

    def get_num_image_tokens(self, img_width: int, img_height: int) -> int:
        """Total prompt tokens one image of this size expands to."""
        num_patches, num_newlines = self.patcher.get_num_patches(
            img_width, img_height)

        # Per patch: features + <patch_start>/<patch_end> (the +2); full
        # image: features + <im_start>/<im_end>; plus newline tokens.
        return num_patches * (
            self.num_patch_feature_size +
            2) + self.num_image_feature_size + 2 + num_newlines

    def _split_images(self,
                      images: list[Image.Image]) -> list[ImageWithPatches]:
        """Run the patcher over each image."""
        result = []
        for img in images:
            result.append(self.patcher(img))
        return result

    def _convert_images_to_pixel_values(
        self,
        images: list[Image.Image],
        is_patch: bool = False,
    ) -> list[torch.Tensor]:
        """Preprocess each image to a 1xCxHxW tensor."""
        return [
            self.image_preprocessor(img, is_patch=is_patch)["pixel_values"]
            for img in images
        ]

    def _get_patch_repl(
        self,
        num_patches: int,
        patch_newline_mask: list[bool] | None,
    ) -> tuple[str, list[int]]:
        """Build the text and token ids for `num_patches` patch spans.

        Patches flagged in `patch_newline_mask` are followed by a
        <patch_newline> marker.
        """
        text = ""
        token_ids = []
        for i in range(num_patches):
            # Loop-invariant sanity check that mask and count agree.
            assert len(patch_newline_mask) == num_patches
            text += f"<patch_start>{self.patch_feature_placeholder}<patch_end>"
            token_ids.extend(
                [self.tokenizer.convert_tokens_to_ids("<patch_start>")] +
                [self.image_token_id] * self.num_patch_feature_size +
                [self.tokenizer.convert_tokens_to_ids("<patch_end>")])
            if patch_newline_mask and patch_newline_mask[i]:
                text += "<patch_newline>"
                token_ids.append(
                    self.tokenizer.convert_tokens_to_ids("<patch_newline>"))
        return text, token_ids

    def _get_image_repl(
        self,
        num_images: int,
    ) -> tuple[str, list[int]]:
        """Build the text and token ids for `num_images` full-image spans."""
        text = f"<im_start>{self.image_feature_placeholder}<im_end>"
        token_ids = [
            self.tokenizer.convert_tokens_to_ids("<im_start>")
        ] + [self.image_token_id] * self.num_image_feature_size + [
            self.tokenizer.convert_tokens_to_ids("<im_end>")
        ]
        return text * num_images, token_ids * num_images

    def _get_image_repl_features(
        self,
        num_images: int,
        num_patches: int,
        patch_new_line_idx: Optional[list[bool]],
    ) -> tuple[str, list[int]]:
        """Combine patch spans (if any) with the full-image span.

        Patch tokens precede the image tokens.
        """
        if num_patches > 0:
            patch_repl, patch_repl_ids = self._get_patch_repl(
                num_patches, patch_new_line_idx)
        else:
            patch_repl = ""
            patch_repl_ids = []
        image_repl, image_repl_ids = self._get_image_repl(num_images)
        return patch_repl + image_repl, patch_repl_ids + image_repl_ids

    def replace_placeholder(self, text: str, placeholder: str,
                            repls: list[str]) -> str:
        """Replace the i-th occurrence of `placeholder` with `repls[i]`.

        Raises ValueError when occurrence and replacement counts differ.
        """
        parts = text.split(placeholder)

        if len(parts) - 1 != len(repls):
            raise ValueError(
                "The number of placeholders does not match the number of replacements."
            )

        result = [parts[0]]
        for i, repl in enumerate(repls):
            result.append(repl)
            result.append(parts[i + 1])

        return "".join(result)

    def __call__(
        self,
        text: Optional[Union[str, list[str]]] = None,
        images: ImageInput | None = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """Tokenize `text`, expanding image placeholders, and batch the
        pixel inputs for `images` into a BatchFeature."""
        if images is not None:
            # NOTE(review): fetch_images is inherited from the
            # transformers image-processor base class; presumably it
            # resolves URLs/paths to PIL images — confirm it also passes
            # already-loaded images through unchanged.
            images = self.image_preprocessor.fetch_images(images)
        # Normalize text to a list and images to a flat list.
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        elif not isinstance(images, list):
            images = [images]
        elif isinstance(images[0], list):
            # Nested batch form: only the first sample's images are used.
            images = images[0]

        if len(images) == 0:
            image_inputs = {}
            text_inputs = self.tokenizer(text)
        else:
            splitted_images_data = self._split_images(images)
            pixel_values_lst = []
            patch_pixel_values_lst = []
            patch_newline_mask_lst = []
            image_repl_str_lst = []
            image_repl_ids_lst = []
            num_patches = []
            for raw_img, img_patches, patch_newline_mask in splitted_images_data:
                pixel_values_lst.extend(
                    self._convert_images_to_pixel_values([raw_img]))

                if len(img_patches) > 0:
                    patch_pixel_values_lst.extend(
                        self._convert_images_to_pixel_values(img_patches,
                                                             is_patch=True))
                # Recorded even when 0 so counts align with pixel_values.
                num_patches.append(len(img_patches))

                image_repl_str, image_repl_ids = self._get_image_repl_features(
                    1, len(img_patches), patch_newline_mask)
                image_repl_str_lst.append(image_repl_str)
                image_repl_ids_lst.extend(image_repl_ids)

                if patch_newline_mask is not None:
                    patch_newline_mask_lst.extend(patch_newline_mask)

            image_inputs = {
                "pixel_values": torch.cat(pixel_values_lst),
                "num_patches": num_patches,
            }
            if patch_pixel_values_lst:
                image_inputs["patch_pixel_values"] = torch.cat(
                    patch_pixel_values_lst)
            if patch_newline_mask_lst:
                image_inputs["patch_newline_mask"] = torch.tensor(
                    patch_newline_mask_lst, dtype=torch.bool)

            # Each <im_patch> occurrence in the prompt is swapped for the
            # corresponding image's full replacement string.
            text = [
                self.replace_placeholder(t, self.image_token,
                                         image_repl_str_lst) for t in text
            ]
            text_inputs = self.tokenizer(text)

        return BatchFeature(
            {
                **text_inputs,
                **image_inputs,
            },
            tensor_type=return_tensors,
        )

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)
| | |
# Public API of this module.
__all__ = ["Step3VLProcessor"]
| |
|