import os
import gc
import logging

import cv2
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText

# Use the GPU when available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Limit allocator block size to reduce fragmentation, then free cached GPU memory.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:1024'
torch.cuda.empty_cache()
gc.collect()
class SMOLVLM2:
    def __init__(self, model_name="HuggingFaceTB/SmolVLM2-500M-Video-Instruct", memory_efficient=True):
        self.half = True
        self.processor = AutoProcessor.from_pretrained(model_name)
        # Load the model in fp16; enable FlashAttention-2 only on GPUs that support it.
        if self.supports_flash_attention(device_id=0):
            self.model = AutoModelForImageTextToText.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                _attn_implementation="flash_attention_2"
            ).to(device)
        else:
            self.model = AutoModelForImageTextToText.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
            ).to(device)
        logging.info("Model loaded")
        self.print_gpu_memory()
    def print_gpu_memory(self):
        """Log current GPU memory usage."""
        logging.info(f"Allocated memory: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
        logging.info(f"Cached memory: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
    ## Check whether the GPU supports FlashAttention
    def supports_flash_attention(self, device_id=0):
        """Check if the GPU supports FlashAttention (requires compute capability >= 8.0)."""
        if not torch.cuda.is_available():
            return False
        support = False
        major, minor = torch.cuda.get_device_capability(device_id)
        if major < 8:
            print("GPU does not support FlashAttention")
        else:
            support = True
        return support
    def run_inference_on_image(self, image_path, query):
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "path": image_path},
                    {"type": "text", "text": query}
                ]
            }
        ]
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors='pt'
        )
        # Cast floating-point inputs to fp16 to match the model, then move everything to the device.
        if self.half:
            inputs = inputs.to(torch.half).to(device)
        else:
            inputs = inputs.to(device)
        generated_ids = self.model.generate(**inputs, do_sample=False, max_new_tokens=1024)
        generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
        # Release the inputs and cached blocks to keep GPU memory usage low.
        del inputs
        torch.cuda.empty_cache()
        # The decoded output contains the full chat transcript; return only the final line (the answer).
        return generated_texts[0].split('\n')[-1]
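
# Example usage: a minimal sketch, assuming the class above is used as-is.
# The image path and query below are hypothetical placeholders, not part of the original file.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    vlm = SMOLVLM2()
    answer = vlm.run_inference_on_image("sample.jpg", "Describe this image.")
    print(answer)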