import math
import os

import numpy as np
import torch
from PIL import Image

import decord
import natsort

from vita_audio.constants import (
    IMAGENET_DEFAULT_MEAN,
    IMAGENET_DEFAULT_STD,
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    OPENAI_CLIP_MEAN,
    OPENAI_CLIP_STD,
)
class ImageProcessor:
    def __init__(
        self,
        process_type,
        image_size=448,
        normalize_type="imagenet",
        min_patch_grid=1,
        max_patch_grid=6,
    ):
        self.process_type = process_type
        self.image_size = image_size

        # Pick the normalization statistics that match the vision backbone.
        if normalize_type == "imagenet":
            MEAN, STD = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
        elif normalize_type == "clip":
            MEAN, STD = OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
        elif normalize_type == "siglip":
            MEAN, STD = IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
        else:
            raise NotImplementedError(f"Unknown normalize_type: {normalize_type}")
        self.mean = MEAN
        self.std = STD

        self.patch_size = image_size
        self.min_patch_grid = min_patch_grid
        self.max_patch_grid = max_patch_grid

        if self.process_type == "anyres":
            # Every (rows, cols) grid between min_patch_grid and max_patch_grid.
            self.grid_pinpoints = [
                (i, j)
                for i in range(min_patch_grid, max_patch_grid + 1)
                for j in range(min_patch_grid, max_patch_grid + 1)
            ]
            self.possible_resolutions = [
                [dim * self.patch_size for dim in pair] for pair in self.grid_pinpoints
            ]
            print(f"grid_pinpoints {self.grid_pinpoints}")
            print(f"possible_resolutions {self.possible_resolutions}")

        if self.process_type == "dynamic":
            max_num = self.max_patch_grid
            min_num = self.min_patch_grid
            # Enumerate all tile grids (cols, rows) whose total tile count
            # lies within [min_num, max_num].
            target_ratios = set(
                (i, j)
                for n in range(min_num, max_num + 1)
                for i in range(1, n + 1)
                for j in range(1, n + 1)
                if i * j <= max_num and i * j >= min_num
            )
            self.target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
            self.possible_resolutions = [
                [dim * self.patch_size for dim in pair] for pair in self.target_ratios
            ]
            print(f"target_ratios {self.target_ratios}")
            print(f"possible_resolutions {self.possible_resolutions}")
    def get_frame_paths(self, frame_root, num_frames=8):
        os.makedirs(frame_root, exist_ok=True)
        self.frame_tmpl = "frame-{}-of-{}.jpg"
        return [
            os.path.join(frame_root, self.frame_tmpl.format(i, num_frames))
            for i in range(1, num_frames + 1)
        ]
    def save_video_frames(self, vid_path, max_fps=1, num_frames=8):
        vid = decord.VideoReader(vid_path, num_threads=1)
        # Sample at most num_frames frames, and never faster than max_fps.
        step_size = len(vid) / (num_frames + 1)
        fps = vid.get_avg_fps()
        step_size = max(fps / max_fps, step_size)
        indices = [int(i * step_size) for i in range(0, num_frames)]
        indices = [i for i in indices if i < len(vid)]
        num_frames = len(indices)

        frame_paths = self.get_frame_paths(vid_path + ".saved_frames", num_frames)
        # Reuse previously extracted frames if all of them are already on disk.
        if all(os.path.exists(p) for p in frame_paths):
            return frame_paths

        images = [Image.fromarray(vid[i].asnumpy()) for i in indices]
        for im, pth in zip(images, frame_paths):
            im.save(pth)
        return frame_paths
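    # Example (a sketch; "clip.mp4" is a hypothetical local file, not from the
    # original source):
    #   proc = ImageProcessor("dynamic")
    #   paths = proc.save_video_frames("clip.mp4", max_fps=1, num_frames=8)
    #   -> writes clip.mp4.saved_frames/frame-1-of-8.jpg ... and returns the paths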
    def get_video_frames(self, vid_path, max_fps=1, num_frames=8):
        # Same sampling as save_video_frames, but returns PIL images in memory
        # instead of writing them to disk.
        vid = decord.VideoReader(vid_path, num_threads=1)
        step_size = len(vid) / (num_frames + 1)
        fps = vid.get_avg_fps()
        step_size = max(fps / max_fps, step_size)
        indices = [int(i * step_size) for i in range(0, num_frames)]
        indices = [i for i in indices if i < len(vid)]
        return [Image.fromarray(vid[i].asnumpy()) for i in indices]
    def process_video(self, video_file_or_dir, max_num_frame=8, max_fps=1):
        if os.path.isdir(video_file_or_dir):
            # A directory of pre-extracted frames: collect and sort the images.
            all_filepath = []
            for root, dirs, files in os.walk(video_file_or_dir):
                for filename in files:
                    if filename.endswith(("png", "jpeg", "jpg")):
                        all_filepath.append(os.path.join(root, filename))
            if len(all_filepath) == 0:
                return None
            all_filepath = natsort.natsorted(all_filepath)
            total_frame = len(all_filepath)
            # ShareGPTVideo frames are extracted at 2 fps; everything else at 1 fps.
            if "ShareGPTVideo" in video_file_or_dir:
                fps = 2
            else:
                fps = 1
            # Subsample uniformly, capped at max_num_frame; keep at least one
            # frame so very short clips do not divide by zero below.
            target_frame = int(min(total_frame / fps * max_fps, max_num_frame))
            target_frame = max(target_frame, 1)
            index = [int(1.0 * total_frame / target_frame) * x for x in range(target_frame)]
            selected_filepath = [all_filepath[x] for x in index]
            img_or_path_list = selected_filepath
        elif os.path.isfile(video_file_or_dir):
            img_or_path_list = self.get_video_frames(
                video_file_or_dir, num_frames=max_num_frame, max_fps=max_fps
            )
        else:
            raise NotImplementedError(f"Not a file or directory: {video_file_or_dir}")
        return self.process_images(img_or_path_list), img_or_path_list
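    # Example (a sketch; "demo.mp4" is a hypothetical local file):
    #   proc = ImageProcessor("dynamic")
    #   tensor, frames = proc.process_video("demo.mp4", max_num_frame=8, max_fps=1)
    #   tensor has shape [len(frames), 3, image_size, image_size]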
    def process_images(self, img_or_path_list):
        # Accept file paths, PIL images, or anything already image-like.
        if isinstance(img_or_path_list[0], str):
            images = [Image.open(x).convert("RGB") for x in img_or_path_list]
        elif isinstance(img_or_path_list[0], Image.Image):
            images = [x.convert("RGB") for x in img_or_path_list]
        else:
            images = img_or_path_list

        def expand2square(pil_img, background_color):
            # Pad the shorter side so the image becomes square, content centered.
            width, height = pil_img.size
            if width == height:
                return pil_img
            elif width > height:
                result = Image.new(pil_img.mode, (width, width), background_color)
                result.paste(pil_img, (0, (width - height) // 2))
                return result
            else:
                result = Image.new(pil_img.mode, (height, height), background_color)
                result.paste(pil_img, ((height - width) // 2, 0))
                return result

        image_tensor = torch.ones([len(images), 3, self.image_size, self.image_size])
        for i, image in enumerate(images):
            # Pad with the normalization mean so padding maps to zero after
            # normalization.
            image = expand2square(image, tuple(int(x * 255) for x in self.mean))
            image = image.resize(
                (self.image_size, self.image_size), resample=Image.Resampling.BICUBIC
            )
            image = np.array(image, dtype=np.float32) / 255.0
            mean = np.array(self.mean, dtype=image.dtype)
            std = np.array(self.std, dtype=image.dtype)
            image = (image - mean) / std
            image = torch.tensor(image, dtype=torch.float32)
            image = image.permute(2, 0, 1)  # HWC -> CHW
            image_tensor[i] = image
        return image_tensor
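    # Example (a sketch; "cat.jpg" is a hypothetical local file):
    #   proc = ImageProcessor("dynamic")
    #   batch = proc.process_images(["cat.jpg"])  # -> FloatTensor [1, 3, 448, 448]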
    def process_images_with_subpatch(self, img_or_path):
        if self.process_type == "anyres":
            return self.process_anyres(img_or_path)
        if self.process_type == "dynamic":
            return self.process_dynamic(img_or_path)
        if isinstance(img_or_path, str):
            image = Image.open(img_or_path).convert("RGB")
        elif isinstance(img_or_path, Image.Image):
            image = img_or_path.convert("RGB")
        else:
            image = img_or_path
        # Fixed-resolution fallback: a single resized patch.
        return self.process_images([image])
    def process_anyres(self, img_or_path):
        if isinstance(img_or_path, str):
            image = Image.open(img_or_path).convert("RGB")
        elif isinstance(img_or_path, Image.Image):
            image = img_or_path.convert("RGB")
        else:
            image = img_or_path

        best_resolution = select_best_resolution(image.size, self.possible_resolutions)
        image_padded = resize_and_pad_image(image, best_resolution)
        patches = divide_to_patches(image_padded, self.patch_size)
        if best_resolution == (self.patch_size, self.patch_size):
            # A single-patch resolution: the sub-patch would just duplicate
            # the base image.
            image_patches = [image]
        else:
            # Base image first, then the high-resolution sub-patches.
            image_patches = [image] + patches
        image_patches = self.process_images(image_patches)
        return image_patches, best_resolution
    def process_dynamic(self, img_or_path):
        if isinstance(img_or_path, str):
            image = Image.open(img_or_path).convert("RGB")
        elif isinstance(img_or_path, Image.Image):
            image = img_or_path.convert("RGB")
        else:
            image = img_or_path

        image_patches, best_resolution = dynamic_preprocess(
            image,
            min_num=self.min_patch_grid,
            max_num=self.max_patch_grid,
            image_size=self.patch_size,
            use_thumbnail=True,
        )
        image_patches = self.process_images(image_patches)
        return image_patches, best_resolution
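# Hedged usage sketch for the class above (the "dynamic" path); the variable
# names and sample file are assumptions, not part of the original source:
#   proc = ImageProcessor(process_type="dynamic", image_size=448)
#   patches, (w, h) = proc.process_images_with_subpatch("photo.jpg")
#   patches is a FloatTensor [n_tiles (+1 thumbnail), 3, 448, 448] and (w, h)
#   is the tiled target resolution chosen for the image's aspect ratio.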
def select_best_resolution(original_size, possible_resolutions):
    """
    Selects the best resolution from a list of possible resolutions based on the original size.

    Args:
        original_size (tuple): The original size of the image in the format (width, height).
        possible_resolutions (list): A list of possible resolutions in the format
            [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The best fit resolution in the format (width, height).
    """
    original_width, original_height = original_size
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float("inf")

    for width, height in possible_resolutions:
        # Calculate the downscaled size that preserves the aspect ratio.
        scale = min(width / original_width, height / original_height)
        downscaled_width = int(original_width * scale)
        downscaled_height = int(original_height * scale)

        # Effective resolution: image pixels that survive the downscale.
        # Wasted resolution: canvas pixels left over as padding.
        effective_resolution = min(
            downscaled_width * downscaled_height, original_width * original_height
        )
        wasted_resolution = (width * height) - effective_resolution

        if effective_resolution > max_effective_resolution or (
            effective_resolution == max_effective_resolution
            and wasted_resolution < min_wasted_resolution
        ):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (width, height)

    return best_fit
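# Worked example (toy values, author-added): a 1000x500 image against two
# candidate canvases. (896, 448) fits the 2:1 image exactly (zero padding
# wasted), so it wins:
#   >>> select_best_resolution((1000, 500), [(448, 448), (896, 448)])
#   (896, 448)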
def resize_and_pad_image(image, target_resolution):
    """
    Resize and pad an image to a target resolution while maintaining aspect ratio.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The target resolution (width, height) of the image.

    Returns:
        PIL.Image.Image: The resized and padded image.
    """
    original_width, original_height = image.size
    target_width, target_height = target_resolution

    # Determine which dimension (width or height) the image fills completely.
    scale_w = target_width / original_width
    scale_h = target_height / original_height
    if scale_w < scale_h:
        # Width is filled completely; height is padded.
        new_width = target_width
        new_height = min(math.ceil(original_height * scale_w), target_height)
    else:
        # Height is filled completely; width is padded.
        new_height = target_height
        new_width = min(math.ceil(original_width * scale_h), target_width)

    # Resize, then paste centered onto a black canvas of the target size.
    resized_image = image.resize((new_width, new_height))
    new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
    paste_x = (target_width - new_width) // 2
    paste_y = (target_height - new_height) // 2
    new_image.paste(resized_image, (paste_x, paste_y))
    return new_image
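# Worked example (toy values, author-added): an 800x600 image into a 448x448
# canvas. scale_w = 0.56 < scale_h ~= 0.747, so the width is filled; content
# is resized to 448x336 and padded 56 px on the top and bottom:
#   >>> resize_and_pad_image(Image.new("RGB", (800, 600)), (448, 448)).size
#   (448, 448)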
def divide_to_patches(image, patch_size):
    """
    Divides an image into patches of a specified size.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): The size of each patch.

    Returns:
        list: A list of PIL.Image.Image objects representing the patches.
    """
    patches = []
    width, height = image.size
    # Crop row by row, left to right.
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            box = (j, i, j + patch_size, i + patch_size)
            patches.append(image.crop(box))
    return patches
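# Worked example (toy values, author-added): a 896x448 image with 448-pixel
# patches yields a 2x1 grid:
#   >>> len(divide_to_patches(Image.new("RGB", (896, 448)), 448))
#   2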
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float("inf")
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            # Tie-break: prefer the larger grid only if the image has enough
            # pixels to fill at least half of it.
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio
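# Worked example (toy values, author-added): a 1000x500 image (ratio 2.0)
# against 448-pixel tiles picks the 2x1 grid:
#   >>> find_closest_aspect_ratio(2.0, [(1, 1), (2, 1), (2, 2)], 1000, 500, 448)
#   (2, 1)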
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Enumerate candidate tile grids (cols, rows) whose tile count lies within
    # [min_num, max_num], sorted by tile count.
    target_ratios = set(
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num
    )
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # Find the grid whose aspect ratio is closest to the image's.
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size
    )

    # Calculate the target width and height.
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # Resize the image onto the tiled canvas, then crop it into tiles.
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    cols = target_width // image_size
    for i in range(blocks):
        box = (
            (i % cols) * image_size,
            (i // cols) * image_size,
            ((i % cols) + 1) * image_size,
            ((i // cols) + 1) * image_size,
        )
        processed_images.append(resized_img.crop(box))
    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        # Prepend a low-resolution thumbnail of the whole image.
        thumbnail_img = image.resize((image_size, image_size))
        processed_images = [thumbnail_img] + processed_images
    return processed_images, (target_width, target_height)
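# Hedged smoke test (author-added sketch, not part of the original source): a
# synthetic 1000x500 image has aspect ratio 2.0, so dynamic_preprocess should
# pick the 2x1 grid -> two 448x448 tiles plus a prepended thumbnail.
if __name__ == "__main__":
    demo = Image.new("RGB", (1000, 500), (128, 128, 128))
    tiles, resolution = dynamic_preprocess(
        demo, min_num=1, max_num=6, image_size=448, use_thumbnail=True
    )
    print(f"tiles: {len(tiles)}, target resolution: {resolution}")
    # -> tiles: 3, target resolution: (896, 448)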