PRISM / model_handler.py
devranx's picture
Upload 20 files
03e275e verified
import os
import torch
import easyocr
import numpy as np
import gc
from transformers import AutoTokenizer, AutoModel, AutoProcessor, AutoModelForZeroShotImageClassification
import torch.nn.functional as F
from utils import build_transform
class ModelHandler:
    """Load and serve the three models used by the app: InternVL, EasyOCR and CLIP.

    Every model is loaded once by ``load_models``. A model that fails to load
    is stored as ``None`` and the matching inference method then returns an
    empty result (``""`` or ``None``) instead of raising.
    """

    def __init__(self, device=None):
        """Select the compute device, build the image transform and load all models.

        Args:
            device: Anything ``torch.device`` accepts (for example ``"cuda"``),
                or ``None`` to keep the previous default of running on CPU.
        """
        self.device = torch.device("cpu" if device is None else device)
        self.transform = build_transform()
        self.load_models()

    @staticmethod
    def _resolve_model_path(local_dir_name, hub_id, label):
        """Return ``Models/<local_dir_name>`` if it exists, otherwise ``hub_id``.

        ``label`` is only used in the progress message that is printed.
        """
        local_path = os.path.join("Models", local_dir_name)
        if os.path.exists(local_path):
            print(f"Loading {label} from local path: {local_path}")
            return local_path
        print(f"Local model not found. Downloading {label} from HF Hub: {hub_id}")
        return hub_id

    def load_models(self):
        """Load InternVL, EasyOCR and CLIP; each failure is reported and isolated."""
        self._load_internvl()
        self._load_easyocr()
        self._load_clip()

    def _load_internvl(self):
        """Load the InternVL model and tokenizer; both become ``None`` on failure."""
        try:
            model_path = self._resolve_model_path(
                "InternVL2_5-1B-MPO", "OpenGVLab/InternVL2_5-1B-MPO", "InternVL"
            )
            # NOTE(security): trust_remote_code=True executes Python shipped with
            # the model repository; only point this at sources you trust.
            model = AutoModel.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            ).eval()
            # `intern` moves its inputs to self.device, so the weights must
            # live there too (a no-op for the default CPU device).
            model = model.to(self.device)
            # Zero every dropout probability so generation stays deterministic
            # even if the model were ever switched out of eval mode.
            for module in model.modules():
                if isinstance(module, torch.nn.Dropout):
                    module.p = 0
            self.model_int = model
            self.tokenizer_int = AutoTokenizer.from_pretrained(
                model_path, trust_remote_code=True
            )
            print("\nInternVL model and tokenizer loaded successfully.")
        except Exception as e:
            # Best effort: the app keeps running without InternVL.
            print(f"\nError loading InternVL model or tokenizer: {e}")
            self.model_int = None
            self.tokenizer_int = None

    def _load_easyocr(self):
        """Create the EasyOCR reader for English and Hindi; ``None`` on failure."""
        try:
            # EasyOCR automatically handles downloading if not present.
            # GPU use follows the handler's device instead of a separate flag.
            self.reader = easyocr.Reader(
                ['en', 'hi'], gpu=self.device.type == "cuda"
            )
            print("\nEasyOCR reader initialized successfully.")
        except Exception as e:
            print(f"\nError initializing EasyOCR reader: {e}")
            self.reader = None

    def _load_clip(self):
        """Load the CLIP processor and model; both become ``None`` on failure."""
        try:
            model_path = self._resolve_model_path(
                "clip-vit-base-patch32", "openai/clip-vit-base-patch32", "CLIP"
            )
            self.processor_clip = AutoProcessor.from_pretrained(model_path)
            self.model_clip = AutoModelForZeroShotImageClassification.from_pretrained(
                model_path
            ).to(self.device)
            print("\nCLIP model and processor loaded successfully.")
        except Exception as e:
            print(f"\nError loading CLIP model or processor: {e}")
            self.model_clip = None
            self.processor_clip = None

    def easyocr_ocr(self, image):
        """Run EasyOCR on ``image`` and return the recognised text as one string.

        Args:
            image: Anything ``np.array`` can convert (presumably a PIL image).

        Returns:
            The detected text fragments joined by single spaces, or ``""``
            when the reader is unavailable or nothing was detected.
        """
        # Identity check: a loaded reader must never be skipped just because
        # the object happens to evaluate as falsy.
        if self.reader is None:
            return ""
        image_np = np.array(image)
        results = self.reader.readtext(image_np, detail=1)
        # Free the pixel buffer early to keep peak memory down.
        del image_np
        gc.collect()
        if not results:
            return ""
        # Each result is indexed as (bbox, text, ...). Order by the first bbox
        # point's second then first coordinate -- presumably top-left (y, x),
        # i.e. top-to-bottom then left-to-right.
        ordered = sorted(results, key=lambda det: (det[0][0][1], det[0][0][0]))
        return " ".join(det[1] for det in ordered).strip()

    def intern(self, image, prompt, max_tokens):
        """Ask InternVL a question about ``image`` using greedy decoding.

        Args:
            image: Input accepted by ``self.transform``.
            prompt: Text prompt passed to the model's ``chat`` method.
            max_tokens: Upper bound on newly generated tokens.

        Returns:
            The model's response, or ``""`` when InternVL is not loaded.
        """
        # `is None` avoids calling the tokenizer's __len__ just to test presence.
        if self.model_int is None or self.tokenizer_int is None:
            return ""
        # Add a batch dimension and match the dtype the weights were loaded in.
        pixel_values = self.transform(image).unsqueeze(0).to(
            self.device, dtype=torch.bfloat16
        )
        generation_config = {
            "max_new_tokens": max_tokens,
            "do_sample": False,
            "num_beams": 1,
            "temperature": 1.0,
            "top_p": 1.0,
            "repetition_penalty": 1.0,
            "length_penalty": 1.0,
            "pad_token_id": self.tokenizer_int.pad_token_id,
        }
        with torch.no_grad():
            # return_history=True makes chat() return a pair; the history
            # is discarded.
            response, _ = self.model_int.chat(
                self.tokenizer_int,
                pixel_values,
                prompt,
                generation_config=generation_config,
                history=None,
                return_history=True,
            )
        del pixel_values
        gc.collect()
        return response

    def clip(self, image, labels):
        """Prepare CLIP inputs for ``image`` against the candidate ``labels``.

        Returns:
            The processor output moved to ``self.device``, or ``None`` when
            CLIP is not loaded.
        """
        if self.model_clip is None or self.processor_clip is None:
            return None
        processed = self.processor_clip(
            text=labels,
            images=image,
            padding=True,
            return_tensors="pt",
        ).to(self.device)
        # The caller still owns `image` and `labels`, so there is nothing to
        # `del` here; only collect whatever preprocessing left behind.
        gc.collect()
        return processed

    def get_clip_probs(self, image, labels):
        """Return CLIP's softmax probabilities of ``labels`` for ``image``.

        Returns:
            ``softmax(logits_per_image, dim=1)`` as a tensor (presumably one
            row per image and one column per label), or ``None`` when CLIP is
            not loaded.
        """
        inputs = self.clip(image, labels)
        if inputs is None:
            return None
        with torch.no_grad():
            outputs = self.model_clip(**inputs)
            logits_per_image = outputs.logits_per_image
            probs = F.softmax(logits_per_image, dim=1)
        del inputs, outputs, logits_per_image
        gc.collect()
        return probs
# Create a global instance to be used by modules.
# NOTE: constructing ModelHandler calls load_models(), so importing this
# module loads (and may download) all three models as a side effect.
model_handler = ModelHandler()