from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from peft import PeftModel
import torch
from qwen_vl_utils import process_vision_info
import io
import os

app = FastAPI(title="Qwen OCR API")
# ---------------------------------------------------------------------
# CONFIGURATION
# ---------------------------------------------------------------------
MODEL_NAME = "NAMAA-Space/Qari-OCR-0.2.2.1-VL-2B-Instruct"  # your fine-tuned model
BASE_MODEL = "Qwen/Qwen2-VL-2B-Instruct"  # base model for LoRA fallback (Qwen2-VL, matching the model class below)
# detect device
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
# ---------------------------------------------------------------------
# MODEL LOADING
# ---------------------------------------------------------------------
print("🚀 Loading model...")
try:
    # Try loading as a full model
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        torch_dtype=dtype,
        device_map=device,
        trust_remote_code=True
    )
    print(f"✅ Loaded main model: {MODEL_NAME}")
except Exception as e:
    print(f"⚠️ Direct load failed: {e}")
    print("➡️ Trying as LoRA/PEFT adapter...")
    base = Qwen2VLForConditionalGeneration.from_pretrained(
        BASE_MODEL,
        torch_dtype=dtype,
        device_map=device,
        trust_remote_code=True
    )
    model = PeftModel.from_pretrained(base, MODEL_NAME)
    print(f"✅ Loaded LoRA adapter on base model: {BASE_MODEL}")
# Load the processor; adapter-only repos often ship no processor files,
# so fall back to the base model's processor if needed
try:
    processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
except Exception:
    processor = AutoProcessor.from_pretrained(BASE_MODEL, trust_remote_code=True)

model.to(device)  # no-op if device_map already placed the weights
model.eval()
# ---------------------------------------------------------------------
# OCR ENDPOINT
# ---------------------------------------------------------------------
@app.post("/ocr")
async def ocr(file: UploadFile = File(...)):
    try:
        # Read the upload and normalize to RGB
        image_bytes = await file.read()
        img = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # OCR instruction prompt
        prompt = (
            "Below is an image of one page of a document. "
            "Return its natural text representation accurately, without hallucination."
        )

        # Chat-style message pairing the image with the instruction
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": prompt}
            ]
        }]

        # Prepare model inputs
        text_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text_prompt],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt"
        ).to(device)

        # Generate output
        with torch.no_grad():
            gen_ids = model.generate(**inputs, max_new_tokens=2000)

        # Strip the prompt tokens so only newly generated text is decoded
        trimmed_ids = [o[len(i):] for i, o in zip(inputs.input_ids, gen_ids)]
        result = processor.batch_decode(trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

        return JSONResponse({"text": result})
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
# ---------------------------------------------------------------------
# MAIN ENTRY (for local debugging)
# ---------------------------------------------------------------------
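# Assuming this file is saved as app.py, the server can also be launched from
# a shell with: uvicorn app:app --host 0.0.0.0 --port 7860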
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))