from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from peft import PeftModel
import torch
from qwen_vl_utils import process_vision_info
import io
import os

app = FastAPI(title="Qwen OCR API")

# ---------------------------------------------------------------------
# CONFIGURATION
# ---------------------------------------------------------------------
MODEL_NAME = "NAMAA-Space/Qari-OCR-0.2.2.1-VL-2B-Instruct"   # your fine-tuned model
BASE_MODEL = "Qwen/Qwen2.5-VL-2B-Instruct"                   # base model for LoRA fallback

# detect device
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# ---------------------------------------------------------------------
# MODEL LOADING
# ---------------------------------------------------------------------
print("🚀 Loading model...")

try:
    # Try loading as a full model
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        torch_dtype=dtype,
        device_map=device,
        trust_remote_code=True
    )
    print(f"✅ Loaded main model: {MODEL_NAME}")

except Exception as e:
    print(f"⚠️ Direct load failed: {e}")
    print("➡️ Trying as LoRA/PEFT adapter...")
    base = Qwen2VLForConditionalGeneration.from_pretrained(
        BASE_MODEL,
        torch_dtype=dtype,
        device_map=device,
        trust_remote_code=True
    )
    model = PeftModel.from_pretrained(base, MODEL_NAME)
    print(f"✅ Loaded LoRA adapter on base model: {BASE_MODEL}")

# load the processor (fall back to the base repo in case the fine-tuned
# or adapter repo does not ship processor files)
try:
    processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
except Exception:
    processor = AutoProcessor.from_pretrained(BASE_MODEL, trust_remote_code=True)
model.to(device)
model.eval()

# ---------------------------------------------------------------------
# OCR ENDPOINT
# ---------------------------------------------------------------------
@app.post("/ocr")
async def ocr(file: UploadFile = File(...)):
    try:
        # Load image
        image_bytes = await file.read()
        img = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Prompt
        prompt = (
            "Below is an image of one page of a document. "
            "Return its natural text representation accurately, without hallucination."
        )

        # Format message
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": prompt}
            ]
        }]

        # Prepare model inputs
        text_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text_prompt],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt"
        ).to(device)

        # Generate output
        with torch.no_grad():
            gen_ids = model.generate(**inputs, max_new_tokens=2000)

        # Trim the prompt tokens from each sequence, keeping only the newly generated text
        trimmed_ids = [o[len(i):] for i, o in zip(inputs.input_ids, gen_ids)]
        result = processor.batch_decode(trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

        return JSONResponse({"text": result})

    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)

# ---------------------------------------------------------------------
# MAIN ENTRY (for local debugging)
# ---------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))