"""FastAPI service that exposes the Qari OCR model (a Qwen2-VL fine-tune) through a single /ocr endpoint."""

from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from peft import PeftModel
import torch
from qwen_vl_utils import process_vision_info
import io
import os

app = FastAPI(title="Qwen OCR API")

# Fine-tuned OCR checkpoint to load.
MODEL_NAME = "NAMAA-Space/Qari-OCR-0.2.2.1-VL-2B-Instruct"
BASE_MODEL = "Qwen/Qwen2-VL-2B-Instruct"  # Qwen2-VL 2B base, matching the Qwen2VLForConditionalGeneration class used below

# Prefer GPU with half precision; fall back to CPU in full precision.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

print("🚀 Loading model...")

# Try loading the checkpoint as a full model first; if that fails, load it as a
# LoRA/PEFT adapter on top of the base model.
try:
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        torch_dtype=dtype,
        device_map=device,
        trust_remote_code=True
    )
    print(f"✅ Loaded main model: {MODEL_NAME}")
except Exception as e:
    print(f"⚠️ Direct load failed: {e}")
    print("➡️ Trying as LoRA/PEFT adapter...")
    base = Qwen2VLForConditionalGeneration.from_pretrained(
        BASE_MODEL,
        torch_dtype=dtype,
        device_map=device,
        trust_remote_code=True
    )
    model = PeftModel.from_pretrained(base, MODEL_NAME)
    print(f"✅ Loaded LoRA adapter on base model: {BASE_MODEL}")

processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
model.to(device)
model.eval()

@app.post("/ocr")
async def ocr(file: UploadFile = File(...)):
    try:
        # Read the uploaded file and decode it into an RGB image.
        image_bytes = await file.read()
        img = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        prompt = (
            "Below is an image of one page of a document. "
            "Return its natural text representation accurately, without hallucination."
        )

        # Build a single-turn chat message containing the image and the OCR instruction.
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": prompt}
            ]
        }]

        # Apply the chat template, extract the vision inputs, and tokenize everything together.
        text_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text_prompt],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            gen_ids = model.generate(**inputs, max_new_tokens=2000)

        # Strip the prompt tokens from the output so only the newly generated text is decoded.
        trimmed_ids = [o[len(i):] for i, o in zip(inputs.input_ids, gen_ids)]
        result = processor.batch_decode(trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

        return JSONResponse({"text": result})

    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
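
# Minimal client sketch (an illustration, not part of the service). It assumes the
# "requests" package is installed, the server is running locally on port 7860, and
# "page.png" is the document image to transcribe:
#
#   import requests
#   with open("page.png", "rb") as f:
#       resp = requests.post("http://localhost:7860/ocr", files={"file": f})
#   print(resp.json()["text"])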

if __name__ == "__main__":
    import uvicorn
    # Default to port 7860; override with the PORT environment variable.
    uvicorn.run("app:app", host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))