import gc
import os
import traceback

import easyocr
import numpy as np
import torch
import torch.nn.functional as F
from transformers import (
    AutoModel,
    AutoModelForZeroShotImageClassification,
    AutoProcessor,
    AutoTokenizer,
)

from backend.utils import build_transform


class ModelHandler:
    """Lazily load and run the InternVL, EasyOCR and CLIP models.

    Nothing heavy is loaded in ``__init__``; each model is loaded by its
    ``load_*`` method the first time the matching inference method is called.
    Loading failures are reported on stdout/stderr and leave the respective
    attributes as ``None``; the inference methods then return an empty
    result (``""`` or ``None``) instead of raising.
    """

    def __init__(self):
        # Placeholders are assigned before anything that can fail, so the
        # instance always has every attribute the other methods read, even
        # if device detection or build_transform() raises below.
        self.model_int = None
        self.tokenizer_int = None
        self.reader = None
        self.model_clip = None
        self.processor_clip = None
        self.transform = None
        self.device = torch.device("cpu")

        try:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            print(f"Using device: {self.device}", flush=True)
            self.transform = build_transform()
        except Exception as e:
            # Best-effort construction: report the problem but keep the
            # (partially functional) handler usable.
            print(f"CRITICAL ERROR in ModelHandler.__init__: {e}", flush=True)
            traceback.print_exc()

    def load_internvl(self):
        """Load the InternVL model and tokenizer if they are not loaded yet.

        A local copy under ``Models/InternVL2_5-1B-MPO`` is preferred; when
        that path does not exist the Hugging Face Hub ID is used instead.
        On failure both ``model_int`` and ``tokenizer_int`` are reset to
        ``None``.
        """
        if self.model_int is not None and self.tokenizer_int is not None:
            return

        print("Loading InternVL model...", flush=True)
        try:
            # Check if local path exists, otherwise use HF Hub ID
            local_path = os.path.join("Models", "InternVL2_5-1B-MPO")
            if os.path.exists(local_path):
                internvl_model_path = local_path
                print(f"Loading InternVL from local path: {internvl_model_path}", flush=True)
            else:
                internvl_model_path = "OpenGVLab/InternVL2_5-1B-MPO"  # HF Hub ID
                print(
                    f"Local model not found. Downloading InternVL from HF Hub: {internvl_model_path}",
                    flush=True,
                )

            model = AutoModel.from_pretrained(
                internvl_model_path,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            ).eval()
            # intern() moves pixel_values to self.device, so the model must
            # live on the same device; otherwise inference fails with a
            # device mismatch whenever CUDA is available.
            self.model_int = model.to(self.device)

            # Force every dropout probability to zero.
            for module in self.model_int.modules():
                if isinstance(module, torch.nn.Dropout):
                    module.p = 0

            self.tokenizer_int = AutoTokenizer.from_pretrained(
                internvl_model_path, trust_remote_code=True
            )
            print("\nInternVL model and tokenizer loaded successfully.", flush=True)
        except Exception as e:
            print(f"\nError loading InternVL model or tokenizer: {e}", flush=True)
            traceback.print_exc()
            # Never leave a half-loaded pair behind.
            self.model_int = None
            self.tokenizer_int = None

    def load_easyocr(self):
        """Create the EasyOCR reader (English + Hindi, CPU) if not created yet.

        On failure ``reader`` is left as ``None``.
        """
        if self.reader is not None:
            return

        print("Loading EasyOCR model...", flush=True)
        try:
            # EasyOCR automatically handles downloading if not present
            self.reader = easyocr.Reader(['en', 'hi'], gpu=False)
            print("\nEasyOCR reader initialized successfully.", flush=True)
        except Exception as e:
            print(f"\nError initializing EasyOCR reader: {e}", flush=True)
            traceback.print_exc()
            self.reader = None

    def load_clip(self):
        """Load the CLIP model and processor if they are not loaded yet.

        A local copy under ``Models/clip-vit-base-patch32`` is preferred;
        when that path does not exist the Hugging Face Hub ID is used
        instead. The model is moved to ``self.device``. On failure both
        ``model_clip`` and ``processor_clip`` are reset to ``None``.
        """
        if self.model_clip is not None and self.processor_clip is not None:
            return

        print("Loading CLIP model...", flush=True)
        try:
            local_path = os.path.join("Models", "clip-vit-base-patch32")
            if os.path.exists(local_path):
                clip_model_path = local_path
                print(f"Loading CLIP from local path: {clip_model_path}", flush=True)
            else:
                clip_model_path = "openai/clip-vit-base-patch32"  # HF Hub ID
                print(
                    f"Local model not found. Downloading CLIP from HF Hub: {clip_model_path}",
                    flush=True,
                )

            self.processor_clip = AutoProcessor.from_pretrained(clip_model_path)
            self.model_clip = AutoModelForZeroShotImageClassification.from_pretrained(
                clip_model_path
            ).to(self.device)
            print("\nCLIP model and processor loaded successfully.", flush=True)
        except Exception as e:
            print(f"\nError loading CLIP model or processor: {e}", flush=True)
            traceback.print_exc()
            self.model_clip = None
            self.processor_clip = None

    def easyocr_ocr(self, image) -> str:
        """Run EasyOCR on *image* and return the recognised text as one string.

        Args:
            image: Anything ``np.array`` can convert to an image array
                (presumably a PIL image -- confirm against callers).

        Returns:
            The detected text fragments joined by single spaces, or ``""``
            when the reader is unavailable or nothing was detected.
        """
        self.load_easyocr()
        if self.reader is None:
            return ""

        image_np = np.array(image)
        results = self.reader.readtext(image_np, detail=1)
        # Release the pixel copy promptly; images can be large.
        del image_np
        gc.collect()

        if not results:
            return ""

        # Each result is indexed as (box, text, ...). Order fragments by the
        # first box point: second coordinate first, then first coordinate
        # (i.e. row-major reading order if points are (x, y) -- see EasyOCR
        # docs).
        sorted_results = sorted(results, key=lambda res: (res[0][0][1], res[0][0][0]))
        return " ".join(res[1] for res in sorted_results).strip()

    def intern(self, image, prompt, max_tokens):
        """Ask InternVL to answer *prompt* about *image*.

        Args:
            image: Input accepted by the transform from ``build_transform()``.
            prompt: Text prompt passed to the model's ``chat`` method.
            max_tokens: Upper bound passed as ``max_new_tokens``.

        Returns:
            The model's response, or ``""`` when the model, tokenizer or
            image transform is unavailable.
        """
        self.load_internvl()
        if self.model_int is None or self.tokenizer_int is None:
            return ""
        if self.transform is None:
            # build_transform() failed during __init__; nothing to feed the model.
            return ""

        pixel_values = self.transform(image).unsqueeze(0).to(self.device).to(torch.bfloat16)
        with torch.no_grad():
            response, _ = self.model_int.chat(
                self.tokenizer_int,
                pixel_values,
                prompt,
                generation_config={
                    "max_new_tokens": max_tokens,
                    "do_sample": False,
                    "num_beams": 1,
                    "temperature": 1.0,
                    "top_p": 1.0,
                    "repetition_penalty": 1.0,
                    "length_penalty": 1.0,
                    "pad_token_id": self.tokenizer_int.pad_token_id,
                },
                history=None,
                return_history=True,
            )

        del pixel_values
        gc.collect()
        return response

    def clip(self, image, labels):
        """Preprocess *image* and *labels* into CLIP model inputs.

        Args:
            image: Image(s) passed to the CLIP processor as ``images``.
            labels: Candidate label text(s) passed as ``text``.

        Returns:
            The processor output moved to ``self.device``, or ``None`` when
            the CLIP model or processor is unavailable.
        """
        self.load_clip()
        if self.model_clip is None or self.processor_clip is None:
            return None

        processed = self.processor_clip(
            text=labels,
            images=image,
            padding=True,
            return_tensors="pt",
        ).to(self.device)
        gc.collect()
        return processed

    def get_clip_probs(self, image, labels):
        """Return CLIP's softmax probabilities of *labels* for *image*.

        Args:
            image: Image(s) passed to the CLIP processor.
            labels: Candidate label text(s).

        Returns:
            ``softmax(logits_per_image, dim=1)`` as a tensor, or ``None``
            when CLIP is unavailable.
        """
        # clip() loads the model on demand and returns None if that failed.
        inputs = self.clip(image, labels)
        if inputs is None:
            return None

        with torch.no_grad():
            outputs = self.model_clip(**inputs)
            logits_per_image = outputs.logits_per_image
            probs = F.softmax(logits_per_image, dim=1)

        del inputs, outputs, logits_per_image
        gc.collect()
        return probs


# Create a global instance to be used by modules
model_handler = ModelHandler()