saeid1999 committed
Commit 6ebad2b · verified · 1 Parent(s): bec3387

Update app.py

Files changed (1)
  1. app.py +157 -62
app.py CHANGED
@@ -1,82 +1,177 @@
- from fastapi import FastAPI
- from pydantic import BaseModel
- from transformers import AutoModelForCausalLM, AutoTokenizer
- import torch
-
- app = FastAPI()
-
- MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- print(f"Loading model on {device}...")
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
- model = AutoModelForCausalLM.from_pretrained(
-     MODEL_NAME,
-     torch_dtype=torch.float16 if device == "cuda" else torch.float32,
-     device_map="auto",
-     trust_remote_code=True
- )
- print("Model loaded!")
-
- class ChatRequest(BaseModel):
-     message: str
-     max_tokens: int = 512
-     temperature: float = 0.7
-
- class CompletionRequest(BaseModel):
-     messages: list
-     max_tokens: int = 512
-     temperature: float = 0.7
-     stream: bool = False
-
- @app.post("/chat")
- def chat(req: ChatRequest):
-     messages = [{"role": "user", "content": req.message}]
-     text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-     inputs = tokenizer([text], return_tensors="pt").to(device)
-
-     outputs = model.generate(
-         **inputs,
-         max_new_tokens=req.max_tokens,
-         temperature=req.temperature,
-         do_sample=True
-     )
-
-     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-     response = response.split("assistant\n")[-1].strip()
-
-     return {"response": response}
-
- @app.post("/v1/chat/completions")
- def completions(req: CompletionRequest):
-     text = tokenizer.apply_chat_template(req.messages, tokenize=False, add_generation_prompt=True)
-     inputs = tokenizer([text], return_tensors="pt").to(device)
-
-     outputs = model.generate(
-         **inputs,
-         max_new_tokens=req.max_tokens,
-         temperature=req.temperature,
-         do_sample=True
-     )
-
-     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-     response = response.split("assistant\n")[-1].strip()
-
-     return {
-         "choices": [{
-             "message": {"role": "assistant", "content": response},
-             "finish_reason": "stop"
-         }]
-     }
-
- @app.get("/health")
- def health():
-     return {"status": "ok"}
-
- if __name__ == "__main__":
-     import uvicorn
-     uvicorn.run("app:app", host="0.0.0.0", port=7860)
+ from fastapi import FastAPI, File, UploadFile, HTTPException
+ from fastapi.responses import JSONResponse
+ from PIL import Image
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+ import torch
+ from qwen_vl_utils import process_vision_info
+ import fitz  # PyMuPDF
+ import io
+ import os
+ from typing import List, Dict
+ import tempfile
+ import uvicorn
+
+ app = FastAPI(
+     title="PDF OCR API",
+     description="Convert PDF to images and extract text using Qari OCR model",
+     version="1.0.0"
+ )
+
+ # Initialize model and processor
+ model_name = "NAMAA-Space/Qari-OCR-0.2.2.1-VL-2B-Instruct"
+ model = None
+ processor = None
+ max_tokens = 2000
+
+ @app.on_event("startup")
+ async def load_model():
+     """Load model on startup"""
+     global model, processor
+     model = Qwen2VLForConditionalGeneration.from_pretrained(
+         model_name,
+         torch_dtype="auto",
+         device_map="auto"
+     )
+     processor = AutoProcessor.from_pretrained(model_name)
+     print("Model loaded successfully!")
+
+ def pdf_to_images(pdf_bytes: bytes) -> List[Image.Image]:
+     """Convert PDF pages to PIL Images"""
+     images = []
+     pdf_document = fitz.open(stream=pdf_bytes, filetype="pdf")
+
+     for page_num in range(len(pdf_document)):
+         page = pdf_document[page_num]
+         # Render page to image at 300 DPI for better quality
+         pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72))
+         img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
+         images.append(img)
+
+     pdf_document.close()
+     return images
+
+ def process_image_ocr(image: Image.Image, temp_path: str) -> str:
+     """Process a single image with OCR"""
+     prompt = "Below is the image of one page of a document, as well as some raw textual content that was previously extracted for it. Just return the plain text representation of this document as if you were reading it naturally. Do not hallucinate."
+
+     # Save image temporarily
+     image.save(temp_path)
+
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": f"file://{temp_path}"},
+                 {"type": "text", "text": prompt},
+             ],
+         }
+     ]
+
+     text = processor.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
+     image_inputs, video_inputs = process_vision_info(messages)
+     inputs = processor(
+         text=[text],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt",
+     )
+     inputs = inputs.to(model.device)
+
+     generated_ids = model.generate(**inputs, max_new_tokens=max_tokens)
+     generated_ids_trimmed = [
+         out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+     ]
+     output_text = processor.batch_decode(
+         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+     )[0]
+
+     # Clean up temp file
+     if os.path.exists(temp_path):
+         os.remove(temp_path)
+
+     return output_text
+
+ @app.post("/ocr", response_model=Dict)
+ async def extract_text_from_pdf(file: UploadFile = File(...)):
+     """
+     Extract text from PDF file using OCR
+
+     - **file**: PDF file (max 5MB)
+
+     Returns JSON with extracted text for each page
+     """
+     # Validate file type
+     if not file.filename.endswith('.pdf'):
+         raise HTTPException(status_code=400, detail="Only PDF files are allowed")
+
+     # Read file content
+     pdf_bytes = await file.read()
+
+     # Check file size (5MB limit)
+     file_size = len(pdf_bytes)
+     if file_size > 5 * 1024 * 1024:  # 5MB in bytes
+         raise HTTPException(
+             status_code=400,
+             detail=f"File size ({file_size / (1024*1024):.2f}MB) exceeds 5MB limit"
+         )
+
+     try:
+         # Convert PDF to images
+         images = pdf_to_images(pdf_bytes)
+
+         if not images:
+             raise HTTPException(status_code=400, detail="No pages found in PDF")
+
+         # Process each page
+         results = []
+         with tempfile.TemporaryDirectory() as temp_dir:
+             for i, image in enumerate(images):
+                 temp_path = os.path.join(temp_dir, f"page_{i}.png")
+                 try:
+                     ocr_text = process_image_ocr(image, temp_path)
+                     results.append({
+                         "page": i + 1,
+                         "text": ocr_text
+                     })
+                 except Exception as e:
+                     results.append({
+                         "page": i + 1,
+                         "error": str(e)
+                     })
+
+         return {
+             "success": True,
+             "filename": file.filename,
+             "total_pages": len(images),
+             "file_size_mb": round(file_size / (1024*1024), 2),
+             "results": results
+         }
+
+     except HTTPException:
+         # Re-raise HTTP errors (e.g. the empty-PDF 400) instead of masking them as 500s
+         raise
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")
+
+ @app.get("/")
+ async def root():
+     """API status endpoint"""
+     return {
+         "message": "PDF OCR API is running",
+         "model": model_name,
+         "max_file_size": "5MB",
+         "endpoints": {
+             "POST /ocr": "Extract text from PDF"
+         }
+     }
+
+ @app.get("/health")
+ async def health_check():
+     """Health check endpoint"""
+     return {
+         "status": "healthy",
+         "model_loaded": model is not None
+     }
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=7860)
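
For quick verification, a minimal client sketch for the new /ocr endpoint is shown below. It assumes the app is running locally on port 7860 (the port used above) and that a file named sample.pdf exists next to the script; the host URL and filename are illustrative assumptions, not part of the commit.

import requests  # assumed to be available in the client environment

# "sample.pdf" is a hypothetical input; any PDF under the API's 5MB limit works.
with open("sample.pdf", "rb") as f:
    resp = requests.post(
        "http://localhost:7860/ocr",
        files={"file": ("sample.pdf", f, "application/pdf")},
    )
resp.raise_for_status()

data = resp.json()
print(f"{data['filename']}: {data['total_pages']} page(s), {data['file_size_mb']}MB")
for page in data["results"]:
    # Each entry carries either "text" (OCR output) or "error" (a per-page failure).
    print(f"--- page {page['page']} ---")
    print(page.get("text", page.get("error")))

Per-page failures are reported inline in "results" rather than failing the whole request, so a client should check each entry for an "error" key before using its "text".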