# qwen3/app.py
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from peft import PeftModel
import torch
from qwen_vl_utils import process_vision_info
import io
import os
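
# NOTE (assumption): besides the packages imported above (fastapi, pillow,
# transformers, peft, torch, qwen_vl_utils), the app will typically also need
# `python-multipart` (FastAPI requires it to parse UploadFile/File form data),
# `uvicorn` to serve, and `accelerate` when loading models with `device_map`.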
app = FastAPI(title="Qwen OCR API")
# ---------------------------------------------------------------------
# CONFIGURATION
# ---------------------------------------------------------------------
MODEL_NAME = "NAMAA-Space/Qari-OCR-0.2.2.1-VL-2B-Instruct"  # fine-tuned OCR model
BASE_MODEL = "Qwen/Qwen2-VL-2B-Instruct"  # base model for the LoRA fallback (matches Qwen2VLForConditionalGeneration)
# detect device
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
# ---------------------------------------------------------------------
# MODEL LOADING
# ---------------------------------------------------------------------
print("🚀 Loading model...")
try:
    # Try loading the repo as a full (merged) model first
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        torch_dtype=dtype,
        device_map=device,
        trust_remote_code=True
    )
    print(f"✅ Loaded main model: {MODEL_NAME}")
except Exception as e:
    print(f"⚠️ Direct load failed: {e}")
    print("➡️ Trying as LoRA/PEFT adapter...")
    # Load the base model, then attach the adapter weights on top of it
    base = Qwen2VLForConditionalGeneration.from_pretrained(
        BASE_MODEL,
        torch_dtype=dtype,
        device_map=device,
        trust_remote_code=True
    )
    model = PeftModel.from_pretrained(base, MODEL_NAME)
    print(f"✅ Loaded LoRA adapter on base model: {BASE_MODEL}")
# load processor
processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
model.to(device)
model.eval()
# ---------------------------------------------------------------------
# OCR ENDPOINT
# ---------------------------------------------------------------------
@app.post("/ocr")
async def ocr(file: UploadFile = File(...)):
    try:
        # Load the uploaded image
        image_bytes = await file.read()
        img = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # OCR prompt
        prompt = (
            "Below is an image of one page of a document. "
            "Return its natural text representation accurately, without hallucination."
        )

        # Build the chat message in Qwen2-VL format
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": prompt}
            ]
        }]

        # Prepare model inputs
        text_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text_prompt],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt"
        ).to(device)

        # Generate output
        with torch.no_grad():
            gen_ids = model.generate(**inputs, max_new_tokens=2000)

        # Strip the prompt tokens from the generated sequences, then decode
        trimmed_ids = [o[len(i):] for i, o in zip(inputs.input_ids, gen_ids)]
        result = processor.batch_decode(trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

        return JSONResponse({"text": result})

    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
# ---------------------------------------------------------------------
# MAIN ENTRY (for local debugging)
# ---------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
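
# ---------------------------------------------------------------------
# EXAMPLE CLIENT (illustrative sketch, not part of the app)
# ---------------------------------------------------------------------
# A minimal way to call the /ocr endpoint once the server is running on the
# default port 7860; "page.png" is a placeholder file name.
#
#   import requests
#
#   with open("page.png", "rb") as f:
#       resp = requests.post("http://localhost:7860/ocr", files={"file": f})
#   resp.raise_for_status()
#   print(resp.json()["text"])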