import gc
import os
import traceback

import easyocr
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, AutoProcessor, AutoModelForZeroShotImageClassification

from backend.utils import build_transform
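

# ModelHandler wires together the three models used by this backend:
# InternVL2_5-1B-MPO for vision-language chat, EasyOCR for text extraction,
# and CLIP ViT-B/32 for zero-shot image classification.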
class ModelHandler:
    def __init__(self):
        try:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            print(f"Using device: {self.device}", flush=True)
            self.transform = build_transform()
            self.load_models()
        except Exception as e:
            print(f"CRITICAL ERROR in ModelHandler.__init__: {e}", flush=True)
            traceback.print_exc()

    def load_models(self):
        # MODEL 1: InternVL
        try:
            # Prefer a local checkout if present, otherwise use the HF Hub ID.
            local_path = os.path.join("Models", "InternVL2_5-1B-MPO")
            if os.path.exists(local_path):
                internvl_model_path = local_path
                print(f"Loading InternVL from local path: {internvl_model_path}", flush=True)
            else:
                internvl_model_path = "OpenGVLab/InternVL2_5-1B-MPO"  # HF Hub ID
                print(f"Local model not found. Downloading InternVL from HF Hub: {internvl_model_path}", flush=True)
            self.model_int = AutoModel.from_pretrained(
                internvl_model_path,
                dtype=torch.bfloat16,  # spelled torch_dtype on older transformers releases
                low_cpu_mem_usage=True,
                trust_remote_code=True
            ).to(self.device).eval()  # move to self.device to match pixel_values in intern()
            # Zero out dropout (redundant with eval(), but harmless belt-and-suspenders).
            for module in self.model_int.modules():
                if isinstance(module, torch.nn.Dropout):
                    module.p = 0
            self.tokenizer_int = AutoTokenizer.from_pretrained(internvl_model_path, trust_remote_code=True)
            print("\nInternVL model and tokenizer loaded successfully.", flush=True)
        except Exception as e:
            print(f"\nError loading InternVL model or tokenizer: {e}", flush=True)
            traceback.print_exc()
            self.model_int = None
            self.tokenizer_int = None

        # MODEL 2: EasyOCR
        try:
            # EasyOCR downloads its model weights automatically if not present.
            self.reader = easyocr.Reader(['en', 'hi'], gpu=False)
            print("\nEasyOCR reader initialized successfully.", flush=True)
        except Exception as e:
            print(f"\nError initializing EasyOCR reader: {e}", flush=True)
            self.reader = None
# MODEL 3: CLIP
try:
local_path = os.path.join("Models", "clip-vit-base-patch32")
if os.path.exists(local_path):
clip_model_path = local_path
print(f"Loading CLIP from local path: {clip_model_path}")
else:
clip_model_path = "openai/clip-vit-base-patch32" # HF Hub ID
print(f"Local model not found. Downloading CLIP from HF Hub: {clip_model_path}")
self.processor_clip = AutoProcessor.from_pretrained(clip_model_path)
self.model_clip = AutoModelForZeroShotImageClassification.from_pretrained(clip_model_path).to(self.device)
print("\nCLIP model and processor loaded successfully.")
except Exception as e:
print(f"\nError loading CLIP model or processor: {e}")
self.model_clip = None
self.processor_clip = None
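
    # The loaders above fail independently: a model that cannot be loaded is
    # left as None, and the corresponding method below returns an empty or
    # None result instead of raising.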
    def easyocr_ocr(self, image):
        if not self.reader:
            return ""
        image_np = np.array(image)
        # detail=1 returns (bounding_box, text, confidence) triples.
        results = self.reader.readtext(image_np, detail=1)
        del image_np
        gc.collect()
        if not results:
            return ""
        # Sort boxes by the y, then x, of their top-left corner so the joined
        # text roughly follows reading order (top-to-bottom, left-to-right).
        sorted_results = sorted(results, key=lambda x: (x[0][0][1], x[0][0][0]))
        ordered_text = " ".join([res[1] for res in sorted_results]).strip()
        return ordered_text

    def intern(self, image, prompt, max_tokens):
        if not self.model_int or not self.tokenizer_int:
            return ""
        pixel_values = self.transform(image).unsqueeze(0).to(self.device).to(torch.bfloat16)
        # Fall back to the EOS token when the tokenizer defines no pad token.
        pad_id = self.tokenizer_int.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer_int.eos_token_id
        with torch.no_grad():
            response, _ = self.model_int.chat(
                self.tokenizer_int,
                pixel_values,
                prompt,
                generation_config={
                    "max_new_tokens": max_tokens,
                    # Greedy decoding; the sampling knobs below are inert with do_sample=False.
                    "do_sample": False,
                    "num_beams": 1,
                    "temperature": 1.0,
                    "top_p": 1.0,
                    "repetition_penalty": 1.0,
                    "length_penalty": 1.0,
                    "pad_token_id": pad_id
                },
                history=None,
                return_history=True
            )
        del pixel_values
        gc.collect()
        return response

    def clip(self, image, labels):
        if not self.model_clip or not self.processor_clip:
            return None
        processed = self.processor_clip(
            text=labels,
            images=image,
            padding=True,
            return_tensors="pt"
        ).to(self.device)
        del image, labels
        gc.collect()
        return processed

    def get_clip_probs(self, image, labels):
        inputs = self.clip(image, labels)
        if inputs is None:
            return None
        with torch.no_grad():
            outputs = self.model_clip(**inputs)
            logits_per_image = outputs.logits_per_image
            # Softmax over the label axis yields one probability per label.
            probs = F.softmax(logits_per_image, dim=1)
        del inputs, outputs, logits_per_image
        gc.collect()
        return probs


# Create a global instance to be used by modules
model_handler = ModelHandler()
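

# A minimal usage sketch, not exercised by the backend itself. Assumptions are
# labeled: "sample.png" is a placeholder path and the CLIP labels are
# illustrative; InternVL chat prompts conventionally carry an <image> tag.
if __name__ == "__main__":
    from PIL import Image

    img = Image.open("sample.png").convert("RGB")  # hypothetical test image

    # OCR: recognized text in rough reading order.
    print(model_handler.easyocr_ocr(img))

    # InternVL chat: greedy caption of the image.
    print(model_handler.intern(img, "<image>\nDescribe this image.", max_tokens=128))

    # CLIP zero-shot classification: one probability per candidate label.
    probs = model_handler.get_clip_probs(img, ["a document", "a photo"])
    if probs is not None:
        print(probs.tolist())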