Update app.py
app.py
CHANGED
@@ -3,196 +3,218 @@ from fastapi.responses import JSONResponse
 from PIL import Image
 from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
 import torch
 from qwen_vl_utils import process_vision_info
-import fitz  # PyMuPDF
 import io
-import
-
-
-
-from contextlib import asynccontextmanager

-#
-model_name = "NAMAA-Space/Qari-OCR-0.2.2.1-VL-2B-Instruct"
 model = None
 processor = None
 max_tokens = 2000

-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    """Load model on startup"""
     global model, processor

-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    print(f"Using device: {device}")

-
-
-
-
-
-
-        )
-    else:
-        model = Qwen2VLForConditionalGeneration.from_pretrained(
-            model_name,
-            torch_dtype="auto",
-            device_map="auto"
-        )

     processor = AutoProcessor.from_pretrained(model_name)
-    print("Model loaded successfully!")

-
-
-    # Cleanup on shutdown
-    print("Shutting down...")
-
-app = FastAPI(
-    title="PDF OCR API",
-    description="Convert PDF to images and extract text using Qari OCR model",
-    version="1.0.0",
-    lifespan=lifespan
-)

-
-
-
-
-
-
-
-
-        pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72))
-        img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
-        images.append(img)
-
-    pdf_document.close()
-    return images

-
-
-

-
-

-
-
-
-
-
-                {"type": "text", "text": prompt},
-            ],
-        }
-    ]

-
-
-
-    image_inputs, video_inputs = process_vision_info(messages)
-    inputs = processor(
-        text=[text],
-        images=image_inputs,
-        videos=video_inputs,
-        padding=True,
-        return_tensors="pt",
-    )
-    inputs = inputs.to(model.device)

-
-
-        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-    ]
-    output_text = processor.batch_decode(
-        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-    )[0]

-
-
-

-

-@app.post("/ocr"
-async def
     """
-

-

-    Returns
     """
-
-
-
-
-
-
-
-
-
-
-
-
-
-    )
-
-    try:
-        # Convert PDF to images
-        images = pdf_to_images(pdf_bytes)

-
-            raise HTTPException(status_code=400, detail="No pages found in PDF")

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-

-
-
-
-            "total_pages": len(images),
-            "file_size_mb": round(file_size / (1024*1024), 2),
-            "results": results
-        }

-
-        raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")

-@app.get("/")
-async def root():
-    """API status endpoint"""
-    return {
-        "message": "PDF OCR API is running",
-        "model": model_name,
-        "max_file_size": "5MB",
-        "endpoints": {
-            "POST /ocr": "Extract text from PDF"
-        }
-    }
-
-@app.get("/health")
-async def health_check():
-    """Health check endpoint"""
-    return {
-        "status": "healthy",
-        "model_loaded": model is not None
-    }

 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=7860)
 from PIL import Image
 from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
 import torch
+import os
 from qwen_vl_utils import process_vision_info
 import io
+from typing import Optional
+import uuid
+
+app = FastAPI(title="Qari OCR API", description="OCR API using Qwen2VL model")

+# Global variables for model and processor
 model = None
 processor = None
 max_tokens = 2000

+@app.on_event("startup")
+async def load_model():
+    """Load model and processor on startup"""
     global model, processor

+    model_name = "NAMAA-Space/Qari-OCR-0.2.2.1-VL-2B-Instruct"

+    print("Loading model...")
+    model = Qwen2VLForConditionalGeneration.from_pretrained(
+        model_name,
+        torch_dtype="auto",
+        device_map="auto"
+    )

+    print("Loading processor...")
     processor = AutoProcessor.from_pretrained(model_name)

+    print("Model and processor loaded successfully!")
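
(Editor's note: recent FastAPI releases deprecate @app.on_event("startup") in favor of a lifespan context manager, which the previous version of this file used. A minimal sketch of the equivalent wiring, assuming the same load_model body; the names and wiring below are illustrative, not part of this commit:)

    from contextlib import asynccontextmanager
    from fastapi import FastAPI

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        await load_model()  # reuse the startup logic above
        yield               # the app serves requests while suspended here
        # optional shutdown cleanup would go after the yield

    app = FastAPI(title="Qari OCR API", lifespan=lifespan)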

+@app.get("/")
+async def root():
+    """Health check endpoint"""
+    return {
+        "status": "running",
+        "model": "NAMAA-Space/Qari-OCR-0.2.2.1-VL-2B-Instruct",
+        "message": "OCR API is ready"
+    }

+@app.post("/ocr")
+async def perform_ocr(
+    file: UploadFile = File(...),
+    prompt: Optional[str] = None
+):
+    """
+    Perform OCR on uploaded image

+    Args:
+        file: Image file (PNG, JPG, JPEG)
+        prompt: Optional custom prompt (defaults to standard OCR prompt)

+    Returns:
+        JSON with extracted text
+    """
+    if model is None or processor is None:
+        raise HTTPException(status_code=503, detail="Model not loaded yet")

+    # Validate file type
+    if not file.content_type.startswith("image/"):
+        raise HTTPException(status_code=400, detail="File must be an image")

+    # Generate unique filename
+    temp_filename = f"temp_{uuid.uuid4()}.png"

+    try:
+        # Read and save image
+        contents = await file.read()
+        image = Image.open(io.BytesIO(contents))
+        image.save(temp_filename)
+
+        # Default prompt if not provided
+        if prompt is None:
+            prompt = "Below is the image of one page of a document, as well as some raw textual content that was previously extracted for it. Just return the plain text representation of this document as if you were reading it naturally. Do not hallucinate."
+
+        # Prepare messages
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "image": f"file://{temp_filename}"},
+                    {"type": "text", "text": prompt},
+                ],
+            }
+        ]
+
+        # Process inputs
+        text = processor.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+        image_inputs, video_inputs = process_vision_info(messages)
+        inputs = processor(
+            text=[text],
+            images=image_inputs,
+            videos=video_inputs,
+            padding=True,
+            return_tensors="pt",
+        )
+        inputs = inputs.to("cuda")
+
+        # Generate output
+        generated_ids = model.generate(**inputs, max_new_tokens=max_tokens)
+        generated_ids_trimmed = [
+            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        ]
+        output_text = processor.batch_decode(
+            generated_ids_trimmed,
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=False
+        )[0]
+
+        return JSONResponse(content={
+            "success": True,
+            "text": output_text,
+            "filename": file.filename
+        })
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")

+    finally:
+        # Clean up temporary file
+        if os.path.exists(temp_filename):
+            os.remove(temp_filename)

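(Editor's note: for reference, a minimal client call against this endpoint; illustrative, not part of the commit. It assumes the server is running locally on port 7860, per the uvicorn.run call at the bottom, and that a page.png exists. Note that prompt arrives as a query parameter here, since it is declared as a plain Optional[str]:)

    import requests

    # Hypothetical local test of the /ocr endpoint above.
    with open("page.png", "rb") as f:
        resp = requests.post(
            "http://localhost:7860/ocr",
            params={"prompt": "Return the plain text of this page."},  # optional
            files={"file": ("page.png", f, "image/png")},  # content type must be image/*
        )
    print(resp.json())  # expected shape: {"success": true, "text": "...", "filename": "page.png"}

(If the installed qwen_vl_utils accepts PIL images directly in the "image" field, as recent versions document, the temp-file round trip could likely be dropped; the file:// path used here works but costs a disk write per request.)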
+@app.post("/ocr-batch")
+async def perform_ocr_batch(files: list[UploadFile] = File(...)):
     """
+    Perform OCR on multiple images

+    Args:
+        files: List of image files

+    Returns:
+        JSON with extracted text for each image
     """
+    if model is None or processor is None:
+        raise HTTPException(status_code=503, detail="Model not loaded yet")
+
+    results = []
+
+    for file in files:
+        if not file.content_type.startswith("image/"):
+            results.append({
+                "filename": file.filename,
+                "success": False,
+                "error": "File must be an image"
+            })
+            continue

+        temp_filename = f"temp_{uuid.uuid4()}.png"

+        try:
+            contents = await file.read()
+            image = Image.open(io.BytesIO(contents))
+            image.save(temp_filename)
+
+            prompt = "Below is the image of one page of a document, as well as some raw textual content that was previously extracted for it. Just return the plain text representation of this document as if you were reading it naturally. Do not hallucinate."
+
+            messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image", "image": f"file://{temp_filename}"},
+                        {"type": "text", "text": prompt},
+                    ],
+                }
+            ]
+
+            text = processor.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=True
+            )
+            image_inputs, video_inputs = process_vision_info(messages)
+            inputs = processor(
+                text=[text],
+                images=image_inputs,
+                videos=video_inputs,
+                padding=True,
+                return_tensors="pt",
+            )
+            inputs = inputs.to("cuda")
+
+            generated_ids = model.generate(**inputs, max_new_tokens=max_tokens)
+            generated_ids_trimmed = [
+                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+            ]
+            output_text = processor.batch_decode(
+                generated_ids_trimmed,
+                skip_special_tokens=True,
+                clean_up_tokenization_spaces=False
+            )[0]
+
+            results.append({
+                "filename": file.filename,
+                "success": True,
+                "text": output_text
+            })
+
+        except Exception as e:
+            results.append({
+                "filename": file.filename,
+                "success": False,
+                "error": str(e)
+            })

+        finally:
+            if os.path.exists(temp_filename):
+                os.remove(temp_filename)

+    return JSONResponse(content={"results": results})
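(Editor's note: a matching client call for the batch endpoint, again illustrative and not part of the commit; each upload is sent under the repeated "files" field:)

    import requests

    # Hypothetical batch request against /ocr-batch.
    uploads = [
        ("files", ("a.png", open("a.png", "rb"), "image/png")),
        ("files", ("b.jpg", open("b.jpg", "rb"), "image/jpeg")),
    ]
    resp = requests.post("http://localhost:7860/ocr-batch", files=uploads)
    print(resp.json())  # {"results": [{"filename": "a.png", "success": true, "text": "..."}, ...]}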


 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=7860)