saeid1999 committed on
Commit
3bdb5e3
·
verified ·
1 Parent(s): 2679a38

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +181 -159
app.py CHANGED
@@ -3,196 +3,218 @@ from fastapi.responses import JSONResponse
3
  from PIL import Image
4
  from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
5
  import torch
 
6
  from qwen_vl_utils import process_vision_info
7
- import fitz # PyMuPDF
8
  import io
9
- import os
10
- from typing import List, Dict
11
- import tempfile
12
- import uvicorn
13
- from contextlib import asynccontextmanager
14
 
15
- # Initialize model and processor
16
- model_name = "NAMAA-Space/Qari-OCR-0.2.2.1-VL-2B-Instruct"
17
  model = None
18
  processor = None
19
  max_tokens = 2000
20
 
21
- @asynccontextmanager
22
- async def lifespan(app: FastAPI):
23
- """Load model on startup"""
24
  global model, processor
25
 
26
- # Detect device
27
- device = "cuda" if torch.cuda.is_available() else "cpu"
28
- print(f"Using device: {device}")
29
 
30
- # Load model with appropriate settings for CPU or GPU
31
- if device == "cpu":
32
- model = Qwen2VLForConditionalGeneration.from_pretrained(
33
- model_name,
34
- torch_dtype=torch.float32,
35
- device_map="cpu"
36
- )
37
- else:
38
- model = Qwen2VLForConditionalGeneration.from_pretrained(
39
- model_name,
40
- torch_dtype="auto",
41
- device_map="auto"
42
- )
43
 
 
44
  processor = AutoProcessor.from_pretrained(model_name)
45
- print("Model loaded successfully!")
46
 
47
- yield
48
-
49
- # Cleanup on shutdown
50
- print("Shutting down...")
51
-
52
- app = FastAPI(
53
- title="PDF OCR API",
54
- description="Convert PDF to images and extract text using Qari OCR model",
55
- version="1.0.0",
56
- lifespan=lifespan
57
- )
58
 
59
- def pdf_to_images(pdf_bytes: bytes) -> List[Image.Image]:
60
- """Convert PDF pages to PIL Images"""
61
- images = []
62
- pdf_document = fitz.open(stream=pdf_bytes, filetype="pdf")
63
-
64
- for page_num in range(len(pdf_document)):
65
- page = pdf_document[page_num]
66
- # Render page to image at 300 DPI for better quality
67
- pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72))
68
- img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
69
- images.append(img)
70
-
71
- pdf_document.close()
72
- return images
73
 
74
- def process_image_ocr(image: Image.Image, temp_path: str) -> str:
75
- """Process a single image with OCR"""
76
- prompt = "Below is the image of one page of a document, as well as some raw textual content that was previously extracted for it. Just return the plain text representation of this document as if you were reading it naturally. Do not hallucinate."
 
 
 
 
77
 
78
- # Save image temporarily
79
- image.save(temp_path)
 
80
 
81
- messages = [
82
- {
83
- "role": "user",
84
- "content": [
85
- {"type": "image", "image": f"file://{temp_path}"},
86
- {"type": "text", "text": prompt},
87
- ],
88
- }
89
- ]
90
 
91
- text = processor.apply_chat_template(
92
- messages, tokenize=False, add_generation_prompt=True
93
- )
94
- image_inputs, video_inputs = process_vision_info(messages)
95
- inputs = processor(
96
- text=[text],
97
- images=image_inputs,
98
- videos=video_inputs,
99
- padding=True,
100
- return_tensors="pt",
101
- )
102
- inputs = inputs.to(model.device)
103
 
104
- generated_ids = model.generate(**inputs, max_new_tokens=max_tokens)
105
- generated_ids_trimmed = [
106
- out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
107
- ]
108
- output_text = processor.batch_decode(
109
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
110
- )[0]
111
 
112
- # Clean up temp file
113
- if os.path.exists(temp_path):
114
- os.remove(temp_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
- return output_text
 
 
 
117
 
118
- @app.post("/ocr", response_model=Dict)
119
- async def extract_text_from_pdf(file: UploadFile = File(...)):
120
  """
121
- Extract text from PDF file using OCR
122
 
123
- - **file**: PDF file (max 5MB)
 
124
 
125
- Returns JSON with extracted text for each page
 
126
  """
127
- # Validate file type
128
- if not file.filename.endswith('.pdf'):
129
- raise HTTPException(status_code=400, detail="Only PDF files are allowed")
130
-
131
- # Read file content
132
- pdf_bytes = await file.read()
133
-
134
- # Check file size (5MB limit)
135
- file_size = len(pdf_bytes)
136
- if file_size > 5 * 1024 * 1024: # 5MB in bytes
137
- raise HTTPException(
138
- status_code=400,
139
- detail=f"File size ({file_size / (1024*1024):.2f}MB) exceeds 5MB limit"
140
- )
141
-
142
- try:
143
- # Convert PDF to images
144
- images = pdf_to_images(pdf_bytes)
145
 
146
- if not images:
147
- raise HTTPException(status_code=400, detail="No pages found in PDF")
148
 
149
- # Process each page
150
- results = []
151
- with tempfile.TemporaryDirectory() as temp_dir:
152
- for i, image in enumerate(images):
153
- temp_path = os.path.join(temp_dir, f"page_{i}.png")
154
- try:
155
- ocr_text = process_image_ocr(image, temp_path)
156
- results.append({
157
- "page": i + 1,
158
- "text": ocr_text
159
- })
160
- except Exception as e:
161
- results.append({
162
- "page": i + 1,
163
- "error": str(e)
164
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
- return {
167
- "success": True,
168
- "filename": file.filename,
169
- "total_pages": len(images),
170
- "file_size_mb": round(file_size / (1024*1024), 2),
171
- "results": results
172
- }
173
 
174
- except Exception as e:
175
- raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")
176
 
177
- @app.get("/")
178
- async def root():
179
- """API status endpoint"""
180
- return {
181
- "message": "PDF OCR API is running",
182
- "model": model_name,
183
- "max_file_size": "5MB",
184
- "endpoints": {
185
- "POST /ocr": "Extract text from PDF"
186
- }
187
- }
188
-
189
- @app.get("/health")
190
- async def health_check():
191
- """Health check endpoint"""
192
- return {
193
- "status": "healthy",
194
- "model_loaded": model is not None
195
- }
196
 
197
  if __name__ == "__main__":
198
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
3
  from PIL import Image
4
  from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
5
  import torch
6
+ import os
7
  from qwen_vl_utils import process_vision_info
 
8
  import io
9
+ from typing import Optional
10
+ import uuid
11
+
12
app = FastAPI(
    title="Qari OCR API",
    description="OCR API using Qwen2VL model",
)

# Shared state: both stay None until the startup hook has assigned them.
model = None
processor = None
# Passed as max_new_tokens to model.generate() by the OCR endpoints.
max_tokens = 2000
18
 
19
@app.on_event("startup")
async def load_model():
    """Populate the module-level ``model`` and ``processor`` at application startup.

    Both objects are loaded from the same checkpoint name; progress is
    reported with ``print``.
    """
    global model, processor

    checkpoint = "NAMAA-Space/Qari-OCR-0.2.2.1-VL-2B-Instruct"
    load_options = {"torch_dtype": "auto", "device_map": "auto"}

    print("Loading model...")
    model = Qwen2VLForConditionalGeneration.from_pretrained(checkpoint, **load_options)

    print("Loading processor...")
    processor = AutoProcessor.from_pretrained(checkpoint)

    print("Model and processor loaded successfully!")
 
 
 
 
 
 
 
 
 
 
37
 
38
@app.get("/")
async def root():
    """Return a static status payload naming the OCR model served by this API."""
    status_payload = dict(
        status="running",
        model="NAMAA-Space/Qari-OCR-0.2.2.1-VL-2B-Instruct",
        message="OCR API is ready",
    )
    return status_payload
 
 
 
 
 
 
46
 
47
@app.post("/ocr")
async def perform_ocr(
    file: UploadFile = File(...),
    prompt: Optional[str] = None
):
    """
    Perform OCR on a single uploaded image.

    Args:
        file: Uploaded image; its declared content type must start with "image/".
        prompt: Optional custom prompt (defaults to the standard OCR prompt).

    Returns:
        JSONResponse with the keys "success", "text" and "filename".

    Raises:
        HTTPException: 503 while the model/processor are still unset,
            400 when the upload is not declared as an image,
            500 when reading, preprocessing or generation fails.
    """
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    # content_type may be None when the client sends no Content-Type for the
    # part; treat that as "not an image" instead of failing on None.startswith.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    # Unique name so concurrent requests never share a temporary file.
    temp_filename = f"temp_{uuid.uuid4()}.png"

    try:
        # Persist the upload so it can be referenced through a file:// URI.
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))
        image.save(temp_filename)

        if prompt is None:
            prompt = "Below is the image of one page of a document, as well as some raw textual content that was previously extracted for it. Just return the plain text representation of this document as if you were reading it naturally. Do not hallucinate."

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": f"file://{temp_filename}"},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # Follow the device the model was actually placed on; a hard-coded
        # "cuda" fails on CPU-only hosts even though loading succeeded.
        inputs = inputs.to(model.device)

        generated_ids = model.generate(**inputs, max_new_tokens=max_tokens)
        # Drop the prompt tokens so that only newly generated text is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]

        return JSONResponse(content={
            "success": True,
            "text": output_text,
            "filename": file.filename
        })

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e

    finally:
        # Remove the temporary file whether or not processing succeeded.
        if os.path.exists(temp_filename):
            os.remove(temp_filename)
131
 
132
@app.post("/ocr-batch")
async def perform_ocr_batch(files: list[UploadFile] = File(...)):
    """
    Perform OCR on multiple images, one after another.

    A failure on one file is recorded in that file's result entry and does
    not abort the remaining files.

    Args:
        files: List of uploaded image files.

    Returns:
        JSONResponse of the form {"results": [...]}, with one entry per
        input file containing "filename", "success" and either "text" or
        "error".

    Raises:
        HTTPException: 503 while the model/processor are still unset.
    """
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    # The prompt is the same for every file, so build it once.
    prompt = "Below is the image of one page of a document, as well as some raw textual content that was previously extracted for it. Just return the plain text representation of this document as if you were reading it naturally. Do not hallucinate."

    results = []

    for file in files:
        # content_type may be None when the client sends no Content-Type for
        # the part; report that as a non-image instead of raising.
        if not (file.content_type or "").startswith("image/"):
            results.append({
                "filename": file.filename,
                "success": False,
                "error": "File must be an image"
            })
            continue

        # Unique name so files and concurrent requests never collide.
        temp_filename = f"temp_{uuid.uuid4()}.png"

        try:
            contents = await file.read()
            image = Image.open(io.BytesIO(contents))
            image.save(temp_filename)

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": f"file://{temp_filename}"},
                        {"type": "text", "text": prompt},
                    ],
                }
            ]

            text = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            # Follow the device the model was actually placed on; a
            # hard-coded "cuda" fails on CPU-only hosts.
            inputs = inputs.to(model.device)

            generated_ids = model.generate(**inputs, max_new_tokens=max_tokens)
            # Drop the prompt tokens so only newly generated text is decoded.
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )[0]

            results.append({
                "filename": file.filename,
                "success": True,
                "text": output_text
            })

        except Exception as e:
            # Best-effort batch: record the failure and move on.
            results.append({
                "filename": file.filename,
                "success": False,
                "error": str(e)
            })

        finally:
            # Remove the temporary file whether or not this item succeeded.
            if os.path.exists(temp_filename):
                os.remove(temp_filename)

    return JSONResponse(content={"results": results})
 
217
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
if __name__ == "__main__":
    # Imported here because the module-level `import uvicorn` is not among
    # the imports visible in this revision; only needed when run as a script.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)