Spaces:
Sleeping
Sleeping
| import os | |
| import io | |
| import torch | |
| import PIL | |
| from PIL import Image | |
| from typing import Optional, Union, List | |
| from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration | |
| import bitsandbytes | |
| import accelerate | |
| from my_model.config import captioning_config as config | |
| from my_model.utilities.gen_utilities import free_gpu_resources | |
class ImageCaptioningModel:
    """
    Image captioning wrapper around the InstructBlip model.

    Configuration values are copied from ``captioning_config`` when the
    instance is created. The processor and model stay ``None`` until
    ``load_model`` is called.

    Attributes:
        model_type (str): Type of the model to use.
        processor (InstructBlipProcessor or None): The processor for handling image input.
        model (InstructBlipForConditionalGeneration or None): The loaded model.
        prompt (str): Prompt for the model.
        max_image_size (int): Maximum size (longest side, in pixels) for the input image.
        min_length (int): Minimum length of the generated caption.
        max_new_tokens (int): Maximum number of new tokens to generate.
        model_path (str): Path to the pre-trained model.
        device_map (str): Device map for model loading.
        torch_dtype (torch.dtype): Data type for torch tensors.
        load_in_8bit (bool): Whether to load the model in 8-bit precision.
        load_in_4bit (bool): Whether to load the model in 4-bit precision.
        low_cpu_mem_usage (bool): Whether to optimize for low CPU memory usage.
        skip_special_tokens (bool): Whether to skip special tokens in the generated captions.
    """

    def __init__(self) -> None:
        """
        Initializes the ImageCaptioningModel class with configuration settings.
        """
        self.model_type = config.MODEL_TYPE
        self.processor = None
        self.model = None
        self.prompt = config.PROMPT
        self.max_image_size = config.MAX_IMAGE_SIZE
        self.min_length = config.MIN_LENGTH
        self.max_new_tokens = config.MAX_NEW_TOKENS
        self.model_path = config.MODEL_PATH
        self.device_map = config.DEVICE_MAP
        self.torch_dtype = config.TORCH_DTYPE
        self.load_in_8bit = config.LOAD_IN_8BIT
        self.load_in_4bit = config.LOAD_IN_4BIT
        self.low_cpu_mem_usage = config.LOW_CPU_MEM_USAGE
        self.skip_special_tokens = config.SKIP_SPECIAL_TOKENS

    @property
    def skip_secial_tokens(self) -> bool:
        """Misspelled alias of ``skip_special_tokens``, kept so existing callers keep working."""
        return self.skip_special_tokens

    @skip_secial_tokens.setter
    def skip_secial_tokens(self, value: bool) -> None:
        self.skip_special_tokens = value

    def load_model(self) -> None:
        """
        Loads the InstructBlip model and processor based on the specified configuration.

        When both 4-bit and 8-bit loading are requested, 4-bit is switched off
        and 8-bit is used. Nothing is loaded unless ``model_type`` is
        ``'i_blip'``.
        """
        if self.load_in_4bit and self.load_in_8bit:
            # Only one quantization mode may be active; 8-bit takes precedence.
            self.load_in_4bit = False

        if self.model_type == 'i_blip':
            # NOTE(review): the quantization/dtype/device arguments are also
            # forwarded to the processor; they look model-specific — confirm
            # the processor actually needs them.
            self.processor = InstructBlipProcessor.from_pretrained(
                self.model_path,
                load_in_8bit=self.load_in_8bit,
                load_in_4bit=self.load_in_4bit,
                torch_dtype=self.torch_dtype,
                device_map=self.device_map,
            )
            free_gpu_resources()

            self.model = InstructBlipForConditionalGeneration.from_pretrained(
                self.model_path,
                load_in_8bit=self.load_in_8bit,
                load_in_4bit=self.load_in_4bit,
                torch_dtype=self.torch_dtype,
                low_cpu_mem_usage=self.low_cpu_mem_usage,
                device_map=self.device_map,
            )
            free_gpu_resources()

    @staticmethod
    def _scaled_size(width: int, height: int, max_image_size: int) -> Optional[tuple]:
        """
        Computes the (width, height) that fits the longest side into max_image_size.

        Args:
            width (int): Current image width in pixels.
            height (int): Current image height in pixels.
            max_image_size (int): Upper bound for the longest side.

        Returns:
            Optional[tuple]: The new ``(width, height)``, or None when the image
            already fits (or has no pixels) and must be left untouched.
        """
        longest_side = max(width, height)
        if longest_side <= 0 or max_image_size >= longest_side:
            return None
        scale = max_image_size / longest_side
        # Clamp to one pixel so an extremely thin image never gets a 0-sized side.
        return max(1, int(width * scale)), max(1, int(height * scale))

    def resize_image(self, image: "Image.Image", max_image_size: Optional[int] = None) -> "Image.Image":
        """
        Resizes the image to fit within the specified maximum size while maintaining aspect ratio.

        Images whose longest side is already within the limit are returned
        unchanged; images are never enlarged.

        Args:
            image (Image.Image): The input image to resize.
            max_image_size (Optional[int]): The maximum size for the longest side. When None,
                the instance's ``max_image_size`` is used, falling back to the
                ``MAX_IMAGE_SIZE`` environment variable (default 1024).

        Returns:
            Image.Image: The resized image, or the input image when no resize is needed.
        """
        if max_image_size is None:
            max_image_size = getattr(self, "max_image_size", None)
        if max_image_size is None:
            max_image_size = int(os.getenv("MAX_IMAGE_SIZE", "1024"))

        # PIL reports size as (width, height), and resize() expects the same order.
        width, height = image.size
        new_size = self._scaled_size(width, height, max_image_size)
        if new_size is not None:
            image = image.resize(new_size, resample=PIL.Image.Resampling.LANCZOS)
        return image

    def generate_caption(self, image_path: Union[str, io.IOBase, "Image.Image"]) -> str:
        """
        Generates a caption for the given image.

        Args:
            image_path (Union[str, io.IOBase, Image.Image]): The path to the image, file-like object, or PIL Image.

        Returns:
            str: The generated caption for the image, stripped of surrounding whitespace.

        Raises:
            RuntimeError: If the model or processor has not been loaded yet.
        """
        if self.model is None or self.processor is None:
            raise RuntimeError("Captioning model is not loaded; call load_model() first.")

        # NOTE(review): cleanup is invoked twice back-to-back here and after
        # generation; presumably the second pass is intentional — confirm
        # against free_gpu_resources.
        free_gpu_resources()
        free_gpu_resources()
        try:
            if isinstance(image_path, Image.Image):
                image = image_path
            else:
                # Paths and file-like objects are handed to PIL, which raises
                # its own error for anything it cannot open.
                image = Image.open(image_path)

            image = self.resize_image(image)

            # NOTE(review): inputs are always moved to "cuda", independent of
            # device_map — confirm this matches where the model is placed.
            inputs = self.processor(image, self.prompt, return_tensors="pt").to("cuda", self.torch_dtype)
            outputs = self.model.generate(**inputs, min_length=self.min_length, max_new_tokens=self.max_new_tokens)
            caption = self.processor.decode(outputs[0], skip_special_tokens=self.skip_special_tokens).strip()
        finally:
            # Run the cleanup even when opening the image or generation fails.
            free_gpu_resources()
            free_gpu_resources()
        return caption

    def generate_captions_for_multiple_images(self, image_paths: List[Union[str, io.IOBase, "Image.Image"]]) -> List[str]:
        """
        Generates captions for multiple images.

        Images are captioned one at a time, in the order given.

        Args:
            image_paths (List[Union[str, io.IOBase, Image.Image]]): A list of paths to images, file-like objects, or PIL Images.

        Returns:
            List[str]: A list of captions for the provided images.
        """
        return [self.generate_caption(image_path) for image_path in image_paths]
def get_caption(img: Union[str, io.IOBase, Image.Image]) -> str:
    """
    Creates a new captioning model, loads it, and captions one image.

    A new ImageCaptioningModel is built and loaded on every call.

    Args:
        img (Union[str, io.IOBase, Image.Image]): The path to the image, file-like object, or PIL Image.

    Returns:
        str: The generated caption for the image.
    """
    model_wrapper = ImageCaptioningModel()

    # Cleanup runs before loading, after loading, and after generation.
    free_gpu_resources()
    model_wrapper.load_model()
    free_gpu_resources()

    result = model_wrapper.generate_caption(img)
    free_gpu_resources()
    return result