# PRISM2.0 / backend /model_handler.py
# devranx's picture
# Fix deprecation warning
# cc1ce63
# raw
# history blame
# 5.75 kB
import os
import torch
import easyocr
import numpy as np
import gc
from transformers import AutoTokenizer, AutoModel, AutoProcessor, AutoModelForZeroShotImageClassification
import torch.nn.functional as F
from backend.utils import build_transform
class ModelHandler:
    """Load the backend's models once and expose thin inference helpers.

    Three components are managed:

    * InternVL (``model_int`` / ``tokenizer_int``) used by :meth:`intern`.
    * An EasyOCR reader (``reader``) used by :meth:`easyocr_ocr`.
    * CLIP (``model_clip`` / ``processor_clip``) used by :meth:`clip` and
      :meth:`get_clip_probs`.

    Each component is loaded independently. A component that fails to load
    is left as ``None`` and the matching helper returns an empty result
    (``""`` or ``None``) instead of raising.
    """

    def __init__(self):
        # Define every attribute before any work that can fail, so the
        # "is this model available?" guards in the public methods never hit
        # AttributeError when initialisation is aborted part-way.
        self.device = None
        self.transform = None
        self.model_int = None
        self.tokenizer_int = None
        self.reader = None
        self.model_clip = None
        self.processor_clip = None
        try:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            print(f"Using device: {self.device}", flush=True)
            self.transform = build_transform()
            self.load_models()
        except Exception as e:
            # Deliberately best-effort: report the failure and keep the
            # (partially initialised) handler usable rather than crashing
            # the importing process.
            print(f"CRITICAL ERROR in ModelHandler.__init__: {e}", flush=True)
            import traceback
            traceback.print_exc()

    def load_models(self):
        """Load InternVL, EasyOCR and CLIP, each isolated from the others' failures."""
        self._load_internvl()
        self._load_easyocr()
        self._load_clip()

    @staticmethod
    def _resolve_model_path(local_dir, hub_id, label):
        """Return ``Models/<local_dir>`` if it exists on disk, else ``hub_id``.

        ``label`` is only used in the progress message that is printed.
        """
        local_path = os.path.join("Models", local_dir)
        if os.path.exists(local_path):
            print(f"Loading {label} from local path: {local_path}", flush=True)
            return local_path
        print(f"Local model not found. Downloading {label} from HF Hub: {hub_id}", flush=True)
        return hub_id

    def _load_internvl(self):
        """Load the InternVL model and tokenizer; leave both ``None`` on failure."""
        try:
            internvl_model_path = self._resolve_model_path(
                "InternVL2_5-1B-MPO", "OpenGVLab/InternVL2_5-1B-MPO", "InternVL"
            )
            model = AutoModel.from_pretrained(
                internvl_model_path,
                dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            ).eval()
            # intern() sends its pixel tensor to self.device, so the model
            # has to live on the same device (mirrors what is done for CLIP).
            self.model_int = model.to(self.device)
            # Zero every dropout probability in addition to eval() mode.
            for module in self.model_int.modules():
                if isinstance(module, torch.nn.Dropout):
                    module.p = 0
            self.tokenizer_int = AutoTokenizer.from_pretrained(
                internvl_model_path, trust_remote_code=True
            )
            print("\nInternVL model and tokenizer loaded successfully.", flush=True)
        except Exception as e:
            print(f"\nError loading InternVL model or tokenizer: {e}", flush=True)
            import traceback
            traceback.print_exc()
            self.model_int = None
            self.tokenizer_int = None

    def _load_easyocr(self):
        """Create the EasyOCR reader (English + Hindi, CPU); ``None`` on failure."""
        try:
            # EasyOCR automatically handles downloading if not present
            self.reader = easyocr.Reader(['en', 'hi'], gpu=False)
            print("\nEasyOCR reader initialized successfully.", flush=True)
        except Exception as e:
            print(f"\nError initializing EasyOCR reader: {e}", flush=True)
            import traceback
            traceback.print_exc()
            self.reader = None

    def _load_clip(self):
        """Load the CLIP processor and model; leave both ``None`` on failure."""
        try:
            clip_model_path = self._resolve_model_path(
                "clip-vit-base-patch32", "openai/clip-vit-base-patch32", "CLIP"
            )
            self.processor_clip = AutoProcessor.from_pretrained(clip_model_path)
            self.model_clip = AutoModelForZeroShotImageClassification.from_pretrained(
                clip_model_path
            ).to(self.device)
            print("\nCLIP model and processor loaded successfully.", flush=True)
        except Exception as e:
            print(f"\nError loading CLIP model or processor: {e}", flush=True)
            import traceback
            traceback.print_exc()
            self.model_clip = None
            self.processor_clip = None

    def easyocr_ocr(self, image):
        """Run OCR on ``image`` and return the recognised text as one string.

        ``image`` is converted with ``np.array`` before being handed to the
        reader. Detections are ordered by the second, then the first,
        coordinate of each box's first point (presumably top-left y, then x)
        and their texts are joined with single spaces.

        Returns ``""`` when the reader is unavailable or nothing is detected.
        """
        if self.reader is None:
            return ""
        image_np = np.array(image)
        results = self.reader.readtext(image_np, detail=1)
        # Drop the array copy eagerly; images can be large.
        del image_np
        gc.collect()
        if not results:
            return ""
        ordered = sorted(results, key=lambda det: (det[0][0][1], det[0][0][0]))
        return " ".join(det[1] for det in ordered).strip()

    def intern(self, image, prompt, max_tokens):
        """Ask the InternVL model ``prompt`` about ``image``.

        ``image`` is passed through ``self.transform``, given a leading
        dimension of size 1, moved to ``self.device`` and cast to bfloat16.
        Decoding is greedy (``do_sample=False``, ``num_beams=1``) and limited
        to ``max_tokens`` new tokens.

        Returns the response from ``model.chat`` (the second returned value
        is discarded), or ``""`` when the model or tokenizer is unavailable.
        """
        if self.model_int is None or self.tokenizer_int is None:
            return ""
        pixel_values = self.transform(image).unsqueeze(0).to(self.device).to(torch.bfloat16)
        generation_config = {
            "max_new_tokens": max_tokens,
            "do_sample": False,
            "num_beams": 1,
            "temperature": 1.0,
            "top_p": 1.0,
            "repetition_penalty": 1.0,
            "length_penalty": 1.0,
            "pad_token_id": self.tokenizer_int.pad_token_id,
        }
        with torch.no_grad():
            response, _ = self.model_int.chat(
                self.tokenizer_int,
                pixel_values,
                prompt,
                generation_config=generation_config,
                history=None,
                return_history=True,
            )
        del pixel_values
        gc.collect()
        return response

    def clip(self, image, labels):
        """Build CLIP inputs for ``image`` against the text ``labels``.

        Returns the processor output (PyTorch tensors, padded, moved to
        ``self.device``), or ``None`` when CLIP is unavailable.
        """
        if self.model_clip is None or self.processor_clip is None:
            return None
        processed = self.processor_clip(
            text=labels,
            images=image,
            padding=True,
            return_tensors="pt",
        ).to(self.device)
        del image, labels
        gc.collect()
        return processed

    def get_clip_probs(self, image, labels):
        """Return the softmax over ``labels`` of CLIP's ``logits_per_image``.

        The softmax is taken along dimension 1. Returns ``None`` when CLIP
        is unavailable.
        """
        inputs = self.clip(image, labels)
        if inputs is None:
            return None
        with torch.no_grad():
            outputs = self.model_clip(**inputs)
            logits_per_image = outputs.logits_per_image
            probs = F.softmax(logits_per_image, dim=1)
        del inputs, outputs, logits_per_image
        gc.collect()
        return probs
# Create a global instance to be used by modules.
# NOTE: constructing ModelHandler runs its __init__, which calls load_models(),
# so all model loading happens when this module is first imported.
model_handler = ModelHandler()