# PRISM2.0: backend/model_handler.py
# Models are loaded lazily on first use to avoid startup timeouts.
import gc
import os

import easyocr
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, AutoProcessor, AutoModelForZeroShotImageClassification

from backend.utils import build_transform
class ModelHandler:
    """Holds the InternVL, EasyOCR, and CLIP models, loading each lazily on first use."""

    def __init__(self):
        try:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            print(f"Using device: {self.device}", flush=True)
            self.transform = build_transform()
            # Initialize model placeholders; actual loading is deferred to the load_* methods.
            self.model_int = None
            self.tokenizer_int = None
            self.reader = None
            self.model_clip = None
            self.processor_clip = None
        except Exception as e:
            print(f"CRITICAL ERROR in ModelHandler.__init__: {e}", flush=True)
            import traceback
            traceback.print_exc()
    def load_internvl(self):
        if self.model_int is not None and self.tokenizer_int is not None:
            return
        print("Loading InternVL model...", flush=True)
        try:
            # Prefer a local checkout if present, otherwise fall back to the HF Hub ID.
            local_path = os.path.join("Models", "InternVL2_5-1B-MPO")
            if os.path.exists(local_path):
                internvl_model_path = local_path
                print(f"Loading InternVL from local path: {internvl_model_path}", flush=True)
            else:
                internvl_model_path = "OpenGVLab/InternVL2_5-1B-MPO"  # HF Hub ID
                print(f"Local model not found. Downloading InternVL from HF Hub: {internvl_model_path}", flush=True)
            self.model_int = AutoModel.from_pretrained(
                internvl_model_path,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True
            ).eval().to(self.device)  # match the device of the pixel tensors built in intern()
            # Belt-and-braces: zero out dropout probabilities (eval() already disables dropout).
            for module in self.model_int.modules():
                if isinstance(module, torch.nn.Dropout):
                    module.p = 0
            self.tokenizer_int = AutoTokenizer.from_pretrained(internvl_model_path, trust_remote_code=True)
            print("\nInternVL model and tokenizer loaded successfully.", flush=True)
        except Exception as e:
            print(f"\nError loading InternVL model or tokenizer: {e}", flush=True)
            import traceback
            traceback.print_exc()
            self.model_int = None
            self.tokenizer_int = None
    def load_easyocr(self):
        if self.reader is not None:
            return
        print("Loading EasyOCR model...", flush=True)
        try:
            # EasyOCR downloads its weights automatically if they are not present.
            self.reader = easyocr.Reader(['en', 'hi'], gpu=False)
            print("\nEasyOCR reader initialized successfully.", flush=True)
        except Exception as e:
            print(f"\nError initializing EasyOCR reader: {e}", flush=True)
            self.reader = None
    def load_clip(self):
        if self.model_clip is not None and self.processor_clip is not None:
            return
        print("Loading CLIP model...", flush=True)
        try:
            # Prefer a local checkout if present, otherwise fall back to the HF Hub ID.
            local_path = os.path.join("Models", "clip-vit-base-patch32")
            if os.path.exists(local_path):
                clip_model_path = local_path
                print(f"Loading CLIP from local path: {clip_model_path}", flush=True)
            else:
                clip_model_path = "openai/clip-vit-base-patch32"  # HF Hub ID
                print(f"Local model not found. Downloading CLIP from HF Hub: {clip_model_path}", flush=True)
            self.processor_clip = AutoProcessor.from_pretrained(clip_model_path)
            self.model_clip = AutoModelForZeroShotImageClassification.from_pretrained(clip_model_path).to(self.device)
            print("\nCLIP model and processor loaded successfully.", flush=True)
        except Exception as e:
            print(f"\nError loading CLIP model or processor: {e}", flush=True)
            self.model_clip = None
            self.processor_clip = None
    def easyocr_ocr(self, image):
        self.load_easyocr()
        if not self.reader:
            return ""
        image_np = np.array(image)
        results = self.reader.readtext(image_np, detail=1)
        del image_np
        gc.collect()
        if not results:
            return ""
        # Order boxes by their top-left corner (top-to-bottom, then left-to-right)
        # so the concatenated text roughly follows reading order.
        sorted_results = sorted(results, key=lambda x: (x[0][0][1], x[0][0][0]))
        ordered_text = " ".join([res[1] for res in sorted_results]).strip()
        return ordered_text
    def intern(self, image, prompt, max_tokens):
        self.load_internvl()
        if not self.model_int or not self.tokenizer_int:
            return ""
        pixel_values = self.transform(image).unsqueeze(0).to(self.device).to(torch.bfloat16)
        with torch.no_grad():
            response, _ = self.model_int.chat(
                self.tokenizer_int,
                pixel_values,
                prompt,
                # Greedy decoding: with do_sample=False, the sampling knobs
                # (temperature, top_p) are effectively inert.
                generation_config={
                    "max_new_tokens": max_tokens,
                    "do_sample": False,
                    "num_beams": 1,
                    "temperature": 1.0,
                    "top_p": 1.0,
                    "repetition_penalty": 1.0,
                    "length_penalty": 1.0,
                    "pad_token_id": self.tokenizer_int.pad_token_id
                },
                history=None,
                return_history=True
            )
        del pixel_values
        gc.collect()
        return response
    def clip(self, image, labels):
        self.load_clip()
        if not self.model_clip or not self.processor_clip:
            return None
        processed = self.processor_clip(
            text=labels,
            images=image,
            padding=True,
            return_tensors="pt"
        ).to(self.device)
        # Drop the local references before the explicit GC pass; the caller's
        # objects are unaffected.
        del image, labels
        gc.collect()
        return processed
    def get_clip_probs(self, image, labels):
        # clip() already calls load_clip(), so this call is a cheap no-op; it is
        # kept so this method stays safe if clip() ever stops loading the model.
        self.load_clip()
        inputs = self.clip(image, labels)
        if inputs is None:
            return None
        with torch.no_grad():
            outputs = self.model_clip(**inputs)
            logits_per_image = outputs.logits_per_image
            # One probability per label for each image in the batch.
            probs = F.softmax(logits_per_image, dim=1)
        del inputs, outputs, logits_per_image
        gc.collect()
        return probs
# Global instance shared by the other backend modules.
model_handler = ModelHandler()
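
# A minimal usage sketch (not part of the original module): exercises the lazy
# loaders end to end. Assumes Pillow is installed and that "sample.png" exists;
# both are illustrative placeholders, not project requirements.
if __name__ == "__main__":
    from PIL import Image

    img = Image.open("sample.png").convert("RGB")

    # Each call below triggers the corresponding load_* method on first use.
    print(model_handler.easyocr_ocr(img))
    print(model_handler.intern(img, "Describe this image.", max_tokens=64))

    labels = ["a document", "a photo", "a diagram"]
    probs = model_handler.get_clip_probs(img, labels)
    if probs is not None:
        # probs has shape (1, len(labels)); pair each label with its score.
        for label, p in zip(labels, probs[0].tolist()):
            print(f"{label}: {p:.3f}")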