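"""Lazy-loading handler for the backend's vision models.

Wraps three models behind a single ModelHandler:

* InternVL2_5-1B-MPO (image-grounded chat) via transformers,
* EasyOCR (English and Hindi text extraction),
* CLIP ViT-B/32 (zero-shot image classification).

The transformer checkpoints are loaded on first use, from a local Models/
directory when present and from the Hugging Face Hub otherwise; EasyOCR
manages its own weight downloads.
"""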
import os
import torch
import easyocr
import numpy as np
import gc
from transformers import AutoTokenizer, AutoModel, AutoProcessor, AutoModelForZeroShotImageClassification
import torch.nn.functional as F
from backend.utils import build_transform

class ModelHandler:
    def __init__(self):
        try:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            print(f"Using device: {self.device}", flush=True)
            self.transform = build_transform()
            
            # Initialize model placeholders
            self.model_int = None
            self.tokenizer_int = None
            self.reader = None
            self.model_clip = None
            self.processor_clip = None
            
        except Exception as e:
            print(f"CRITICAL ERROR in ModelHandler.__init__: {e}", flush=True)
            import traceback
            traceback.print_exc()

    def load_internvl(self):
        if self.model_int is not None and self.tokenizer_int is not None:
            return

        print("Loading InternVL model...", flush=True)
        try:
            # Check if local path exists, otherwise use HF Hub ID
            local_path = os.path.join("Models", "InternVL2_5-1B-MPO")
            if os.path.exists(local_path):
                internvl_model_path = local_path
                print(f"Loading InternVL from local path: {internvl_model_path}", flush=True)
            else:
                internvl_model_path = "OpenGVLab/InternVL2_5-1B-MPO" # HF Hub ID
                print(f"Local model not found. Downloading InternVL from HF Hub: {internvl_model_path}", flush=True)

            self.model_int = AutoModel.from_pretrained(
                internvl_model_path,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True
            ).eval().to(self.device)  # match the device used for inputs in intern()

            # eval() already disables dropout; zeroing the rate as well keeps
            # outputs deterministic even if a module is switched back to train.
            for module in self.model_int.modules():
                if isinstance(module, torch.nn.Dropout):
                    module.p = 0

            self.tokenizer_int = AutoTokenizer.from_pretrained(internvl_model_path, trust_remote_code=True)
            print("\nInternVL model and tokenizer loaded successfully.", flush=True)
        except Exception as e:
            print(f"\nError loading InternVL model or tokenizer: {e}", flush=True)
            import traceback
            traceback.print_exc()
            self.model_int = None
            self.tokenizer_int = None

    def load_easyocr(self):
        if self.reader is not None:
            return

        print("Loading EasyOCR model...", flush=True)
        try:
            # EasyOCR downloads its detection/recognition weights on first use.
            # gpu=False pins OCR to the CPU, leaving GPU memory for the
            # transformer models.
            self.reader = easyocr.Reader(['en', 'hi'], gpu=False)
            print("\nEasyOCR reader initialized successfully.", flush=True)
        except Exception as e:
            print(f"\nError initializing EasyOCR reader: {e}", flush=True)
            self.reader = None

    def load_clip(self):
        if self.model_clip is not None and self.processor_clip is not None:
            return

        print("Loading CLIP model...", flush=True)
        try:
            # Prefer a local checkpoint; fall back to the HF Hub.
            local_path = os.path.join("Models", "clip-vit-base-patch32")
            if os.path.exists(local_path):
                clip_model_path = local_path
                print(f"Loading CLIP from local path: {clip_model_path}", flush=True)
            else:
                clip_model_path = "openai/clip-vit-base-patch32" # HF Hub ID
                print(f"Local model not found. Downloading CLIP from HF Hub: {clip_model_path}", flush=True)

            self.processor_clip = AutoProcessor.from_pretrained(clip_model_path)
            self.model_clip = AutoModelForZeroShotImageClassification.from_pretrained(clip_model_path).to(self.device)
            print("\nCLIP model and processor loaded successfully.", flush=True)
        except Exception as e:
            print(f"\nError loading CLIP model or processor: {e}", flush=True)
            self.model_clip = None
            self.processor_clip = None

    def easyocr_ocr(self, image):
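        """OCR `image` and return the detected text joined in reading order."""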
        self.load_easyocr()
        if not self.reader:
            return ""
        image_np = np.array(image)
        results = self.reader.readtext(image_np, detail=1)
        
        del image_np
        gc.collect()

        if not results:
            return ""
        
        # Order detections top-to-bottom, then left-to-right, by each box's
        # top-left corner, and join them into a single string.
        sorted_results = sorted(results, key=lambda x: (x[0][0][1], x[0][0][0]))
        ordered_text = " ".join([res[1] for res in sorted_results]).strip()
        return ordered_text

    def intern(self, image, prompt, max_tokens):
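        """Chat with InternVL about `image` using `prompt`; returns the text
        reply, generating at most `max_tokens` new tokens."""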
        self.load_internvl()
        if not self.model_int or not self.tokenizer_int:
            return ""
            
        pixel_values = self.transform(image).unsqueeze(0).to(self.device).to(torch.bfloat16)
        with torch.no_grad():
            response, _ = self.model_int.chat(
                self.tokenizer_int,
                pixel_values,
                prompt,
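                # Deterministic greedy decoding: with do_sample=False, the
                # sampling parameters below (temperature, top_p) are inert.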
                generation_config={
                    "max_new_tokens": max_tokens,
                    "do_sample": False,
                    "num_beams": 1,
                    "temperature": 1.0,
                    "top_p": 1.0,
                    "repetition_penalty": 1.0,
                    "length_penalty": 1.0,
                    "pad_token_id": self.tokenizer_int.pad_token_id
                },
                history=None,
                return_history=True
            )
        
        del pixel_values
        gc.collect()
        return response

    def clip(self, image, labels):
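        """Preprocess `image` and text `labels` into CLIP inputs on self.device."""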
        self.load_clip()
        if not self.model_clip or not self.processor_clip:
            return None

        processed = self.processor_clip(
            text=labels,
            images=image,
            padding=True,
            return_tensors="pt"
        ).to(self.device)

        # Drop the local references before returning; only the tensors in
        # `processed` are still needed.
        del image, labels
        gc.collect()
        return processed

    def get_clip_probs(self, image, labels):
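        """Score `image` against `labels` with CLIP; returns softmax probabilities."""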
        # clip() already lazy-loads the model, but keep this guard in case
        # that implementation changes.
        self.load_clip()
        inputs = self.clip(image, labels)
        if inputs is None:
            return None
            
        with torch.no_grad():
            outputs = self.model_clip(**inputs)
        
        logits_per_image = outputs.logits_per_image
        probs = F.softmax(logits_per_image, dim=1)
        
        del inputs, outputs, logits_per_image
        gc.collect()
        
        return probs

# Global instance shared by the other backend modules.
model_handler = ModelHandler()
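
# Usage sketch (illustrative only; the module path and image file below are
# assumptions, not defined in this file):
#
#     from PIL import Image
#     from backend.model_handler import model_handler  # assumed module name
#
#     img = Image.open("sample.png").convert("RGB")  # hypothetical input
#     print(model_handler.easyocr_ocr(img))
#     print(model_handler.intern(img, "<image>\nDescribe this image.", 128))
#     print(model_handler.get_clip_probs(img, ["a document", "a photo"]))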