# Main script for KBVQA: Knowledge-Based Visual Question Answering Module
#
# This module is the central component for implementing the designed model architecture for the Knowledge-Based Visual
# Question Answering (KB-VQA) project. It integrates various sub-modules, including image captioning, object detection,
# and a fine-tuned language model, to provide a comprehensive solution for answering questions based on visual input.
#
# --- Description ---
# **KBVQA class**:
# The KBVQA class encapsulates the functionality needed to perform visual question answering using a combination of
# multimodal models. The class handles the following tasks:
# - Loading and managing a fine-tuned language model (LLaMA-2) for question answering.
# - Integrating an image captioning model to generate descriptive captions for input images.
# - Utilizing an object detection model to identify and describe objects within the images.
# - Formatting and generating prompts for the language model based on the image captions and detected objects.
# - Providing methods to analyze images and generate answers to user-provided questions.
#
# **prepare_kbvqa_model function**:
# The prepare_kbvqa_model function orchestrates the loading and initialization of the KBVQA class, ensuring it is
# ready for inference.
#
# --- Instructions ---
# **Model Preparation**:
# Use the prepare_kbvqa_model function to prepare and initialize the KBVQA system, ensuring all required models are
# loaded and ready for use.
# **Image Processing and Question Answering**:
# - Use the get_caption method to generate captions for input images.
# - Use the detect_objects method to identify and describe objects in the images.
# - Use the generate_answer method to answer questions based on the image captions and detected objects.
#
# This module forms the backbone of the KB-VQA project, integrating advanced models to provide an end-to-end solution
# for visual question answering tasks.
# Ensure all dependencies are installed and the required configuration file is in place before running this script.
# The configurations for the KBVQA class are defined in the 'my_model/config/kbvqa_config.py' file.
#
# ---------- Please run this module to utilize the full KB-VQA functionality ----------#
# ---------- Please ensure this is run on a GPU ----------#
import streamlit as st
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from typing import Tuple, Optional

from my_model.utilities.gen_utilities import free_gpu_resources
from my_model.captioner.image_captioning import ImageCaptioningModel
from my_model.detector.object_detection import ObjectDetector
import my_model.config.kbvqa_config as config


class KBVQA:
    """
    The KBVQA class encapsulates the functionality for the Knowledge-Based Visual Question Answering (KBVQA) model.
    It integrates various components such as an image captioning model, an object detection model, and a language
    model (LLaMA-2) fine-tuned on the OK-VQA dataset for generating answers to visual questions.

    Attributes:
        kbvqa_model_name (str): Name of the fine-tuned language model used for KBVQA.
        quantization (str): The quantization setting for the model (e.g., '4bit', '8bit').
        max_context_window (int): The maximum number of tokens allowed in the model's context window.
        add_eos_token (bool): Flag to indicate whether to add an end-of-sentence token to the tokenizer.
        trust_remote (bool): Flag to indicate whether to trust remote code when using the tokenizer.
        use_fast (bool): Flag to indicate whether to use the fast version of the tokenizer.
        low_cpu_mem_usage (bool): Flag to optimize model loading for low CPU memory usage.
        kbvqa_tokenizer (Optional[AutoTokenizer]): The tokenizer for the KBVQA model.
        captioner (Optional[ImageCaptioningModel]): The model used for generating image captions.
        detector (Optional[ObjectDetector]): The object detection model.
        detection_model (Optional[str]): The name of the object detection model.
        detection_confidence (Optional[float]): The confidence threshold for object detection.
        kbvqa_model (Optional[AutoModelForCausalLM]): The fine-tuned language model for KBVQA.
        bnb_config (BitsAndBytesConfig): Configuration for the BitsAndBytes-optimized model.
        access_token (str): Access token for the Hugging Face API.
        current_prompt_length (Optional[int]): Length (in tokens) of the most recent prompt.

    Methods:
        create_bnb_config: Creates a BitsAndBytes configuration based on the quantization setting.
        load_caption_model: Loads the image captioning model.
        get_caption: Generates a caption for a given image.
        load_detector: Loads the object detection model.
        detect_objects: Detects objects in a given image.
        load_fine_tuned_model: Loads the fine-tuned KBVQA model along with its tokenizer.
        all_models_loaded: Checks if all the required models are loaded.
        format_prompt: Formats the prompt for the KBVQA model.
        trim_objects: Trims the last object from the detected objects string.
        generate_answer: Generates an answer to a given question using the KBVQA model.
    """

    def __init__(self) -> None:
        """
        Initializes the KBVQA instance with configuration parameters.
        """

        if st.session_state["method"] == "7b-Fine-Tuned Model":
            self.kbvqa_model_name: str = config.KBVQA_MODEL_NAME_7b
        elif st.session_state["method"] == "13b-Fine-Tuned Model":
            self.kbvqa_model_name: str = config.KBVQA_MODEL_NAME_13b
        self.quantization: str = config.QUANTIZATION
        self.max_context_window: int = config.MAX_CONTEXT_WINDOW  # set to 4,000 tokens
        self.add_eos_token: bool = config.ADD_EOS_TOKEN
        self.trust_remote: bool = config.TRUST_REMOTE
        self.use_fast: bool = config.USE_FAST
        self.low_cpu_mem_usage: bool = config.LOW_CPU_MEM_USAGE
        self.kbvqa_tokenizer: Optional[AutoTokenizer] = None
        self.captioner: Optional[ImageCaptioningModel] = None
        self.detector: Optional[ObjectDetector] = None
        self.detection_model: Optional[str] = None
        self.detection_confidence: Optional[float] = None
        self.kbvqa_model: Optional[AutoModelForCausalLM] = None
        self.bnb_config: BitsAndBytesConfig = self.create_bnb_config()
        self.access_token: str = config.HUGGINGFACE_TOKEN
        self.current_prompt_length: Optional[int] = None

    def create_bnb_config(self) -> BitsAndBytesConfig:
        """
        Creates a BitsAndBytes configuration based on the quantization setting.

        Returns:
            BitsAndBytesConfig: Configuration for the BitsAndBytes-optimized model.
        """

        if self.quantization == '4bit':
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
            )
        elif self.quantization == '8bit':
            # 8-bit loading takes no nested-quantization options; the bnb_4bit_* family of
            # arguments applies to 4-bit loading only.
            return BitsAndBytesConfig(load_in_8bit=True)
        else:
            raise ValueError(f"Unsupported quantization setting: {self.quantization}")
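
    # Illustrative sketch of how the returned config is consumed (this mirrors
    # load_fine_tuned_model below; 'model_name' is a placeholder):
    #   bnb_config = self.create_bnb_config()
    #   model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)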

    def load_caption_model(self) -> None:
        """
        Loads the image captioning model into the KBVQA instance.

        Returns:
            None
        """

        self.captioner = ImageCaptioningModel()
        self.captioner.load_model()
        free_gpu_resources()

    def get_caption(self, img: Image.Image) -> str:
        """
        Generates a caption for a given image using the image captioning model.

        Args:
            img (PIL.Image.Image): The image for which to generate a caption.

        Returns:
            str: The generated caption for the image.
        """

        caption = self.captioner.generate_caption(img)
        free_gpu_resources()
        return caption

    def load_detector(self, model: str) -> None:
        """
        Loads the object detection model.

        Args:
            model (str): The name of the object detection model to load.

        Returns:
            None
        """

        self.detector = ObjectDetector()
        self.detector.load_model(model)
        free_gpu_resources()

    def detect_objects(self, img: Image.Image) -> Tuple[Image.Image, str]:
        """
        Detects objects in a given image using the loaded object detection model.

        Args:
            img (PIL.Image.Image): The image in which to detect objects.

        Returns:
            tuple: A tuple containing the image with detected objects drawn and a string representation of the
                detected objects.
        """

        image = self.detector.process_image(img)
        free_gpu_resources()
        detected_objects_string, detected_objects_list = self.detector.detect_objects(
            image, threshold=st.session_state['confidence_level'])
        free_gpu_resources()
        image_with_boxes = self.detector.draw_boxes(img, detected_objects_list)
        free_gpu_resources()
        return image_with_boxes, detected_objects_string
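
    # Illustrative usage (hypothetical file name; the confidence threshold is read from
    # st.session_state['confidence_level']):
    #   boxed_image, objects_str = kbvqa.detect_objects(Image.open('street.jpg'))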

    def load_fine_tuned_model(self) -> None:
        """
        Loads the fine-tuned KBVQA model along with its tokenizer.

        Returns:
            None
        """

        self.kbvqa_model = AutoModelForCausalLM.from_pretrained(self.kbvqa_model_name,
                                                                device_map="auto",
                                                                low_cpu_mem_usage=self.low_cpu_mem_usage,
                                                                quantization_config=self.bnb_config,
                                                                token=self.access_token)
        free_gpu_resources()
        self.kbvqa_tokenizer = AutoTokenizer.from_pretrained(self.kbvqa_model_name,
                                                             use_fast=self.use_fast,
                                                             trust_remote_code=self.trust_remote,
                                                             add_eos_token=self.add_eos_token,
                                                             token=self.access_token)
        free_gpu_resources()

    def all_models_loaded(self) -> bool:
        """
        Checks if all the required models (KBVQA, captioner, detector) are loaded.

        Returns:
            bool: True if all models are loaded, False otherwise.
        """

        return self.kbvqa_model is not None and self.captioner is not None and self.detector is not None

    def format_prompt(self, current_query: str, history: Optional[str] = None, sys_prompt: Optional[str] = None,
                      caption: Optional[str] = None, objects: Optional[str] = None) -> str:
        """
        Formats the prompt for the KBVQA model based on the provided parameters.
        This implements the Prompt Engineering Module of the overall KB-VQA architecture.

        Args:
            current_query (str): The current question to be answered.
            history (str, optional): The history of previous interactions.
            sys_prompt (str, optional): The system prompt or instructions for the model.
            caption (str, optional): The caption of the image.
            objects (str, optional): The detected objects in the image.

        Returns:
            str: The formatted prompt for the KBVQA model.
        """

        # These are the special tokens designed for the model to be fine-tuned on.
        B_CAP = '[CAP]'
        E_CAP = '[/CAP]'
        B_QES = '[QES]'
        E_QES = '[/QES]'
        B_OBJ = '[OBJ]'
        E_OBJ = '[/OBJ]'

        # These are the default special tokens of the LLaMA-2 chat model.
        B_SENT = '<s>'
        E_SENT = '</s>'
        B_INST = '[INST]'
        E_INST = '[/INST]'
        B_SYS = '<<SYS>>\n'
        E_SYS = '\n<</SYS>>\n\n'

        current_query = current_query.strip()
        if sys_prompt is None:
            sys_prompt = config.SYSTEM_PROMPT.strip()

        # History can be used to facilitate multi-turn chat; it is not used for the Run Inference tool within the
        # demo app.
        if history is None:
            if objects is None:
                p = f"""{B_SENT}{B_INST} {B_SYS}{sys_prompt}{E_SYS}{B_CAP}{caption}{E_CAP}{B_QES}{current_query}{E_QES}{E_INST}"""
            else:
                p = f"""{B_SENT}{B_INST} {B_SYS}{sys_prompt}{E_SYS}{B_CAP}{caption}{E_CAP}{B_OBJ}{objects}{E_OBJ}{B_QES}taking into consideration the objects with high certainty, {current_query}{E_QES}{E_INST}"""
        else:
            p = f"""{history}\n{B_SENT}{B_INST} {B_QES}{current_query}{E_QES}{E_INST}"""

        return p
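
    # Illustrative shape of the prompt produced above (caption and question are hypothetical;
    # the system prompt comes from config.SYSTEM_PROMPT):
    #   <s>[INST] <<SYS>>
    #   {system prompt}
    #   <</SYS>>
    #
    #   [CAP]a man riding a horse on a beach[/CAP][OBJ]{detected objects}[/OBJ][QES]taking into
    #   consideration the objects with high certainty, what animal is this?[/QES][/INST]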

    @staticmethod
    def trim_objects(detected_objects_str: str) -> str:
        """
        Trims the last object from the detected objects string.
        This is implemented to ensure that the prompt length stays within the context window (the threshold is set
        to 4,000 tokens).

        Args:
            detected_objects_str (str): String containing detected objects.

        Returns:
            str: The string with the last object removed.
        """

        objects = detected_objects_str.strip().split("\n")
        if len(objects) >= 1:
            return "\n".join(objects[:-1])
        return ""

    def generate_answer(self, question: str, caption: str, detected_objects_str: str) -> str:
        """
        Generates an answer to a given question using the KBVQA model.

        Args:
            question (str): The question to be answered.
            caption (str): The caption of the image related to the question.
            detected_objects_str (str): The string representation of detected objects in the image.

        Returns:
            str: The generated answer to the question.
        """

        free_gpu_resources()
        prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
        num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
        self.current_prompt_length = num_tokens
        trim = False  # flag used to check whether a prompt trim is required or not.
        # max_context_window is set to 4,000 tokens; refer to the config file.
        if self.current_prompt_length > self.max_context_window:
            trim = True
            st.warning(
                f"Prompt length is {self.current_prompt_length}, which is larger than the maximum context window of"
                f" LLaMA-2. Objects detected with low confidence will be removed one at a time until the prompt"
                f" length is within the maximum context window ...")
            # An object is trimmed from the bottom of the list until the overall prompt length is within the
            # context window.
            while self.current_prompt_length > self.max_context_window:
                detected_objects_str = self.trim_objects(detected_objects_str)
                prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
                self.current_prompt_length = len(self.kbvqa_tokenizer.tokenize(prompt))
                if detected_objects_str == "":
                    break  # Break if no objects are left.
        if trim:
            st.warning(f"New prompt length is: {self.current_prompt_length}")
            trim = False

        model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to('cuda')
        free_gpu_resources()
        input_ids = model_inputs["input_ids"]
        output_ids = self.kbvqa_model.generate(input_ids)
        free_gpu_resources()
        index = input_ids.shape[1]  # needed to avoid echoing the input prompt in the output.
        history = self.kbvqa_tokenizer.decode(output_ids[0], skip_special_tokens=False)  # full text, kept for multi-turn use.
        output_text = self.kbvqa_tokenizer.decode(output_ids[0][index:], skip_special_tokens=True)
        return output_text.capitalize()
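
    # Illustrative trimming walk-through (numbers are hypothetical): with max_context_window at
    # 4,000 tokens and a 4,120-token prompt, the loop above drops the last-listed object (which
    # the warning assumes is the lowest-confidence one), re-tokenizes the rebuilt prompt, and
    # repeats until the prompt fits or no objects remain.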


def prepare_kbvqa_model(only_reload_detection_model: bool = False, force_reload: bool = False) -> KBVQA:
    """
    Prepares the KBVQA model for use, including loading the necessary sub-models.
    This serves as the main function for loading and reloading the KB-VQA model.

    Args:
        only_reload_detection_model (bool): If True, only the object detection model is reloaded.
        force_reload (bool): If True, forces the reload of all models.

    Returns:
        KBVQA: An instance of the KBVQA model ready for inference.
    """

    if force_reload:
        loading_message = 'Reloading model.. this should take no more than 2 or 3 minutes!'
        try:
            del st.session_state['kbvqa']
        except KeyError:
            pass
        free_gpu_resources()
    else:
        loading_message = 'Loading model.. this should take no more than 2 or 3 minutes!'
        free_gpu_resources()

    kbvqa = KBVQA()
    kbvqa.detection_model = st.session_state.detection_model

    # Progress bar for model loading.
    with st.spinner(loading_message):
        if not only_reload_detection_model:
            progress_bar = st.progress(0)
            kbvqa.load_detector(kbvqa.detection_model)
            progress_bar.progress(33)
            kbvqa.load_caption_model()
            free_gpu_resources()
            progress_bar.progress(75)
            st.text('Almost there :)')
            kbvqa.load_fine_tuned_model()
            free_gpu_resources()
            progress_bar.progress(100)
        else:
            free_gpu_resources()
            progress_bar = st.progress(0)
            kbvqa.load_detector(kbvqa.detection_model)
            progress_bar.progress(100)

    if kbvqa.all_models_loaded():
        st.success('Model loaded successfully and ready for inference!')
        kbvqa.kbvqa_model.eval()
        free_gpu_resources()
    return kbvqa


if __name__ == "__main__":
    pass

#### Example of how to use the module ####
#
# # Prepare the KBVQA model (requires a Streamlit session with 'method', 'detection_model' and
# # 'confidence_level' set in st.session_state):
# kbvqa = prepare_kbvqa_model()
#
# # Load an image
# image = Image.open('path_to_image.jpg')
#
# # Generate a caption for the image
# caption = kbvqa.get_caption(image)
#
# # Detect objects in the image
# image_with_boxes, detected_objects_str = kbvqa.detect_objects(image)
#
# # Generate an answer to a question about the image
# question = "What is the object in the image?"
# answer = kbvqa.generate_answer(question, caption, detected_objects_str)
# print(f"Answer: {answer}")