import os
import torch
import easyocr
import numpy as np
import gc
from transformers import AutoTokenizer, AutoModel, AutoProcessor, AutoModelForZeroShotImageClassification
import torch.nn.functional as F
from backend.utils import build_transform

class ModelHandler:
    def __init__(self):
        try:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            print(f"Using device: {self.device}", flush=True)
            self.transform = build_transform()
            self.load_models()
        except Exception as e:
            print(f"CRITICAL ERROR in ModelHandler.__init__: {e}", flush=True)
            import traceback
            traceback.print_exc()

    def load_models(self):
        # MODEL 1: InternVL
        try:
            # Check if local path exists, otherwise use HF Hub ID
            local_path = os.path.join("Models", "InternVL2_5-1B-MPO")
            if os.path.exists(local_path):
                internvl_model_path = local_path
                print(f"Loading InternVL from local path: {internvl_model_path}", flush=True)
            else:
                internvl_model_path = "OpenGVLab/InternVL2_5-1B-MPO" # HF Hub ID
                print(f"Local model not found. Downloading InternVL from HF Hub: {internvl_model_path}", flush=True)

            self.model_int = AutoModel.from_pretrained(
                internvl_model_path,
                dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True
            ).eval().to(self.device)  # move to the same device the input tensors are sent to in intern()

            # Zero out dropout probabilities (a no-op while the model is in eval mode,
            # but guards against any module being switched back to train mode later)
            for module in self.model_int.modules():
                if isinstance(module, torch.nn.Dropout):
                    module.p = 0

            self.tokenizer_int = AutoTokenizer.from_pretrained(internvl_model_path, trust_remote_code=True)
            print("\nInternVL model and tokenizer loaded successfully.", flush=True)
        except Exception as e:
            print(f"\nError loading InternVL model or tokenizer: {e}", flush=True)
            import traceback
            traceback.print_exc()
            self.model_int = None
            self.tokenizer_int = None

        # MODEL 2: EasyOCR
        try:
            # EasyOCR automatically handles downloading if not present
            self.reader = easyocr.Reader(['en', 'hi'], gpu=False)
            print("\nEasyOCR reader initialized successfully.", flush=True)
        except Exception as e:
            print(f"\nError initializing EasyOCR reader: {e}", flush=True)
            self.reader = None

        # MODEL 3: CLIP
        try:
            local_path = os.path.join("Models", "clip-vit-base-patch32")
            if os.path.exists(local_path):
                clip_model_path = local_path
                print(f"Loading CLIP from local path: {clip_model_path}")
            else:
                clip_model_path = "openai/clip-vit-base-patch32" # HF Hub ID
                print(f"Local model not found. Downloading CLIP from HF Hub: {clip_model_path}")

            self.processor_clip = AutoProcessor.from_pretrained(clip_model_path)
            self.model_clip = AutoModelForZeroShotImageClassification.from_pretrained(clip_model_path).to(self.device)
            print("\nCLIP model and processor loaded successfully.")
        except Exception as e:
            print(f"\nError loading CLIP model or processor: {e}")
            self.model_clip = None
            self.processor_clip = None

    def easyocr_ocr(self, image):
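        """Run EasyOCR on a PIL image and return the detected text joined in
        approximate reading order (sorted by the top-left corner of each box)."""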
        if not self.reader:
            return ""
        image_np = np.array(image)
        results = self.reader.readtext(image_np, detail=1)
        
        del image_np
        gc.collect()

        if not results:
            return ""
        
        sorted_results = sorted(results, key=lambda x: (x[0][0][1], x[0][0][0]))
        ordered_text = " ".join([res[1] for res in sorted_results]).strip()
        return ordered_text

    def intern(self, image, prompt, max_tokens):
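        """Run the InternVL chat model on a single image with the given prompt,
        using greedy decoding capped at `max_tokens` new tokens."""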
        if not self.model_int or not self.tokenizer_int:
            return ""
            
        pixel_values = self.transform(image).unsqueeze(0).to(self.device).to(torch.bfloat16)
        with torch.no_grad():
            response, _ = self.model_int.chat(
                self.tokenizer_int,
                pixel_values,
                prompt,
                generation_config={
                    "max_new_tokens": max_tokens,
                    "do_sample": False,
                    "num_beams": 1,
                    "temperature": 1.0,
                    "top_p": 1.0,
                    "repetition_penalty": 1.0,
                    "length_penalty": 1.0,
                    "pad_token_id": self.tokenizer_int.pad_token_id
                },
                history=None,
                return_history=True
            )
        
        del pixel_values
        gc.collect()
        return response

    def clip(self, image, labels):
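        """Preprocess an image and a list of candidate text labels into CLIP
        input tensors on the target device (or None if CLIP failed to load)."""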
        if not self.model_clip or not self.processor_clip:
            return None

        processed = self.processor_clip(
            text=labels,
            images=image,
            padding=True,
            return_tensors="pt"
        ).to(self.device)

        del image, labels
        gc.collect()
        return processed

    def get_clip_probs(self, image, labels):
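        """Return the softmax probability of each label matching the image,
        computed via CLIP zero-shot classification (or None on failure)."""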
        inputs = self.clip(image, labels)
        if inputs is None:
            return None
            
        with torch.no_grad():
            outputs = self.model_clip(**inputs)
        
        logits_per_image = outputs.logits_per_image
        probs = F.softmax(logits_per_image, dim=1)
        
        del inputs, outputs, logits_per_image
        gc.collect()
        
        return probs

# Create a global instance to be used by modules
model_handler = ModelHandler()
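
# A minimal usage sketch (assumes Pillow is installed and a local test image
# "sample.jpg" exists; both are illustrative assumptions, not part of the module).
# Kept under a __main__ guard so importing this module still only creates `model_handler`.
if __name__ == "__main__":
    from PIL import Image

    img = Image.open("sample.jpg").convert("RGB")  # hypothetical test image

    # OCR the image text
    print(model_handler.easyocr_ocr(img))

    # Ask InternVL a question about the image
    print(model_handler.intern(img, "Describe this image.", max_tokens=128))

    # Zero-shot classify the image against a few candidate labels
    labels = ["a document", "a photograph", "a screenshot"]
    probs = model_handler.get_clip_probs(img, labels)
    if probs is not None:
        print({label: float(p) for label, p in zip(labels, probs[0])})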