Update app.py
Browse files
app.py
CHANGED
|
@@ -1,82 +1,177 @@
|
|
| 1 |
-
from fastapi import FastAPI
|
| 2 |
-
from
|
| 3 |
-
from
|
|
|
|
| 4 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
app = FastAPI(
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
print(f"Loading model on {device}...")
|
| 12 |
-
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
| 13 |
-
model = AutoModelForCausalLM.from_pretrained(
|
| 14 |
-
MODEL_NAME,
|
| 15 |
-
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
|
| 16 |
-
device_map="auto",
|
| 17 |
-
trust_remote_code=True
|
| 18 |
)
|
| 19 |
-
print("Model loaded!")
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
|
|
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
max_new_tokens=req.max_tokens,
|
| 41 |
-
temperature=req.temperature,
|
| 42 |
-
do_sample=True
|
| 43 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
|
|
|
| 47 |
|
| 48 |
-
return
|
| 49 |
|
| 50 |
-
@app.post("/
|
| 51 |
-
def
|
| 52 |
-
|
| 53 |
-
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
return {
|
| 66 |
-
"
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
| 70 |
}
|
| 71 |
|
| 72 |
@app.get("/health")
|
| 73 |
-
def
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
| 77 |
|
| 78 |
if __name__ == "__main__":
|
| 79 |
-
|
| 80 |
-
uvicorn.run("app:app", host="0.0.0.0", port=7860)
|
| 81 |
-
|
| 82 |
-
|
|
|
|
| 1 |
+
from fastapi import FastAPI, File, UploadFile, HTTPException
|
| 2 |
+
from fastapi.responses import JSONResponse
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
|
| 5 |
import torch
|
| 6 |
+
from qwen_vl_utils import process_vision_info
|
| 7 |
+
import fitz # PyMuPDF
|
| 8 |
+
import io
|
| 9 |
+
import os
|
| 10 |
+
from typing import List, Dict
|
| 11 |
+
import tempfile
|
| 12 |
+
import uvicorn
|
| 13 |
|
| 14 |
+
# Application instance that the startup hook and route decorators in this
# module register against; title/description/version are its metadata.
app = FastAPI(
    title="PDF OCR API",
    description="Convert PDF to images and extract text using Qari OCR model",
    version="1.0.0"
)
|
|
|
|
| 19 |
|
| 20 |
+
# Model configuration and lazily populated globals.
# `model_name` is the identifier handed to both `from_pretrained` calls in
# `load_model`.
model_name = "NAMAA-Space/Qari-OCR-0.2.2.1-VL-2B-Instruct"
# `model` and `processor` stay None until `load_model` assigns them at startup;
# `health_check` reports whether `model` has been set.
model = None
processor = None
# Passed as `max_new_tokens` to `model.generate` for every page.
max_tokens = 2000
|
| 25 |
|
| 26 |
+
@app.on_event("startup")
async def load_model():
    """Populate the module-level ``model`` and ``processor`` at app startup."""
    global model, processor

    # Both the dtype and the device placement are left to the library ("auto").
    load_options = {"torch_dtype": "auto", "device_map": "auto"}
    model = Qwen2VLForConditionalGeneration.from_pretrained(model_name, **load_options)
    processor = AutoProcessor.from_pretrained(model_name)

    print("Model loaded successfully!")
|
| 37 |
+
|
| 38 |
+
def pdf_to_images(pdf_bytes: bytes) -> List[Image.Image]:
    """Render every page of an in-memory PDF to a PIL image.

    Args:
        pdf_bytes: Raw bytes of the PDF document.

    Returns:
        One "RGB" image per page, in page order, rendered at 300 DPI.

    Raises:
        Any exception raised by ``fitz.open`` or by page rendering. The
        document handle is closed before a rendering error propagates.
    """
    # 300/72 is the zoom factor for 300 DPI output (PDF points are 1/72 inch).
    # The matrix does not depend on the page, so build it once.
    zoom = 300 / 72
    render_matrix = fitz.Matrix(zoom, zoom)

    images = []
    pdf_document = fitz.open(stream=pdf_bytes, filetype="pdf")
    try:
        for page in pdf_document:
            pix = page.get_pixmap(matrix=render_matrix)
            img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
            images.append(img)
    finally:
        # Release the document even when a page fails to render.
        pdf_document.close()

    return images
|
| 52 |
|
| 53 |
+
def process_image_ocr(image: Image.Image, temp_path: str) -> str:
    """Run the OCR model on a single page image and return the decoded text.

    Args:
        image: Page image to transcribe.
        temp_path: Path the image is written to so it can be referenced as a
            ``file://`` URI in the chat message. The file is removed before
            this function returns or raises.

    Returns:
        The first decoded sequence produced by the model for this page.

    Raises:
        Any exception raised while saving the image, preparing inputs,
        generating, or decoding. The temporary file is still cleaned up.
    """
    prompt = (
        "Below is the image of one page of a document, as well as some raw "
        "textual content that was previously extracted for it. Just return the "
        "plain text representation of this document as if you were reading it "
        "naturally. Do not hallucinate."
    )

    try:
        # Save image temporarily so it can be passed by path.
        image.save(temp_path)

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": f"file://{temp_path}"},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(model.device)

        generated_ids = model.generate(**inputs, max_new_tokens=max_tokens)
        # Drop the leading len(in_ids) tokens of each output sequence so that
        # only the tokens following the input are decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]

        return output_text
    finally:
        # Clean up the temp file on success and on failure alike.
        if os.path.exists(temp_path):
            os.remove(temp_path)
|
| 96 |
|
| 97 |
+
@app.post("/ocr", response_model=Dict)
async def extract_text_from_pdf(file: UploadFile = File(...)):
    """
    Extract text from PDF file using OCR

    - **file**: PDF file (max 5MB)

    Returns JSON with extracted text for each page
    """
    bytes_per_mb = 1024 * 1024
    max_file_size = 5 * bytes_per_mb  # 5MB in bytes

    # Validate file type. The upload may carry no filename at all, and the
    # extension is compared case-insensitively so "SCAN.PDF" is accepted.
    filename = file.filename or ""
    if not filename.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files are allowed")

    # Read file content
    pdf_bytes = await file.read()

    # Check file size (5MB limit)
    file_size = len(pdf_bytes)
    if file_size > max_file_size:
        raise HTTPException(
            status_code=400,
            detail=f"File size ({file_size / bytes_per_mb:.2f}MB) exceeds 5MB limit"
        )

    try:
        # Convert PDF to images
        images = pdf_to_images(pdf_bytes)

        if not images:
            raise HTTPException(status_code=400, detail="No pages found in PDF")

        # Process each page. A failure on one page is recorded in that page's
        # result entry and does not abort the remaining pages.
        # NOTE(review): process_image_ocr is a synchronous call made from this
        # coroutine, so other requests wait while pages are being processed.
        results = []
        with tempfile.TemporaryDirectory() as temp_dir:
            for i, image in enumerate(images):
                temp_path = os.path.join(temp_dir, f"page_{i}.png")
                try:
                    ocr_text = process_image_ocr(image, temp_path)
                    results.append({
                        "page": i + 1,
                        "text": ocr_text
                    })
                except Exception as e:
                    results.append({
                        "page": i + 1,
                        "error": str(e)
                    })

        return {
            "success": True,
            "filename": file.filename,
            "total_pages": len(images),
            "file_size_mb": round(file_size / bytes_per_mb, 2),
            "results": results
        }

    except HTTPException:
        # Deliberate HTTP errors raised above (e.g. the 400 for an empty PDF)
        # must reach the client unchanged instead of being re-wrapped as 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Processing failed: {str(e)}"
        ) from e
|
| 155 |
+
|
| 156 |
+
@app.get("/")
async def root():
    """Return a short description of the service: status, model, limit, routes."""
    available_endpoints = {"POST /ocr": "Extract text from PDF"}
    service_info = {
        "message": "PDF OCR API is running",
        "model": model_name,
        "max_file_size": "5MB",
        "endpoints": available_endpoints,
    }
    return service_info
|
| 167 |
|
| 168 |
@app.get("/health")
async def health_check():
    """Report a fixed healthy status plus whether ``model`` has been set."""
    model_ready = model is not None
    return {"status": "healthy", "model_loaded": model_ready}
|
| 175 |
|
| 176 |
# When executed as a script, serve `app` with uvicorn on all interfaces
# (0.0.0.0), port 7860.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
|
|
|
|
|
|