from typing import List

import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
from modules import shared
from modules.text_generation import encode

from .minigpt4.mini_gpt4 import MiniGPT4
from .minigpt4.processor import Blip2ImageEvalProcessor

class MiniGPT4_Pipeline(AbstractMultimodalPipeline):
    def __init__(self, params: dict) -> None:
        super().__init__()
        # BLIP-2 style preprocessing: resize and normalize images for the vision encoder
        self.image_processor = Blip2ImageEvalProcessor()

    @staticmethod
    def placeholder_token_id() -> int:
        return 1

    @staticmethod
    def image_start() -> str:
        return "<Img>"

    @staticmethod
    def image_end() -> str:
        return "</Img>"

    @staticmethod
    def num_image_embeds() -> int:
        return 32

    @staticmethod
    def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
        return shared.model.model.embed_tokens(input_ids).to(shared.model.device, dtype=shared.model.dtype)

    @staticmethod
    def placeholder_embeddings() -> torch.Tensor:
        # Embed the "<ImgContent>" placeholder tokens; they are later replaced
        # by the actual image embeddings at generation time.
        placeholders = encode("<ImgContent>", add_bos_token=False, add_special_tokens=False)[0]
        return MiniGPT4_Pipeline.embed_tokens(placeholders.to(shared.model.device, dtype=torch.int64)).to(dtype=shared.model.dtype)

    def embed_images(self, images: List[Image.Image]) -> torch.Tensor:
        # Preprocess and batch the images, then project them into the LLM's
        # embedding space via the MiniGPT-4 vision tower.
        im = torch.stack([self.image_processor(image) for image in images])
        image_emb = self.vision_tower.encode_img(im)
        return image_emb.to(shared.model.device, dtype=shared.model.dtype)
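
# Illustrative sketch (not part of the original file): how a loaded pipeline
# might be used to embed a single image. The `params` value, the image path,
# and the expected output shape below are assumptions for illustration only.
#
#     pipeline = MiniGPT4_7b_Pipeline(params={})
#     image = Image.open("example.jpg").convert("RGB")
#     emb = pipeline.embed_images([image])
#     # emb: roughly (1, 32, 4096) -- num_image_embeds() vectors per image in
#     # the LLaMA embedding space, spliced in place of the placeholder tokens.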

class MiniGPT4_13b_Pipeline(MiniGPT4_Pipeline):
    def __init__(self, params: dict) -> None:
        super().__init__(params)
        # Download the pretrained 13B MiniGPT-4 checkpoint from the Hub and
        # load it into the vision tower (strict=False because the checkpoint
        # only covers part of the model, e.g. the projection layer).
        ckpt_path = hf_hub_download("Vision-CAIR/MiniGPT-4", "pretrained_minigpt4.pth")
        ckpt = torch.load(ckpt_path, map_location="cpu")
        self.vision_tower = MiniGPT4(llama_hidden_size=5120,
                                     vision_dtype=self._get_dtype("vision_bits", params),
                                     vision_device=self._get_device("vision_device", params),
                                     projector_device=self._get_device("projector_device", params),
                                     projector_dtype=self._get_dtype("projector_bits", params))
        self.vision_tower.load_state_dict(ckpt['model'], strict=False)

    @staticmethod
    def name() -> str:
        return "minigpt4-13b"

class MiniGPT4_7b_Pipeline(MiniGPT4_Pipeline):
    def __init__(self, params: dict) -> None:
        super().__init__(params)
        # "prerained_minigpt4_7b.pth" is spelled as published in the
        # ckpt/minigpt4-7B repository on the Hub; do not "fix" the filename.
        ckpt_path = hf_hub_download("ckpt/minigpt4-7B", "prerained_minigpt4_7b.pth")
        ckpt = torch.load(ckpt_path, map_location="cpu")
        self.vision_tower = MiniGPT4(llama_hidden_size=4096,
                                     vision_dtype=self._get_dtype("vision_bits", params),
                                     vision_device=self._get_device("vision_device", params),
                                     projector_device=self._get_device("projector_device", params),
                                     projector_dtype=self._get_dtype("projector_bits", params))
        self.vision_tower.load_state_dict(ckpt['model'], strict=False)

    @staticmethod
    def name() -> str:
        return "minigpt4-7b"