Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,13 +3,10 @@ from threading import Thread
|
|
| 3 |
import gradio as gr
|
| 4 |
import spaces
|
| 5 |
from PIL import Image
|
|
|
|
|
|
|
| 6 |
import torch
|
| 7 |
-
from transformers import AutoProcessor,
|
| 8 |
-
from reportlab.platypus import SimpleDocTemplate, Paragraph
|
| 9 |
-
from reportlab.lib.styles import getSampleStyleSheet
|
| 10 |
-
from docx import Document
|
| 11 |
-
from gtts import gTTS
|
| 12 |
-
from jiwer import cer
|
| 13 |
|
| 14 |
# ---------------- Models ----------------
|
| 15 |
MODEL_PATHS = {
|
|
@@ -35,76 +32,38 @@ for name, (repo_id, cls) in MODEL_PATHS.items():
|
|
| 35 |
except Exception as e:
|
| 36 |
print(f"β οΈ Failed to load {name}: {e}")
|
| 37 |
|
| 38 |
-
# ----------------
|
| 39 |
-
def
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
return
|
| 45 |
|
| 46 |
-
|
| 47 |
-
try:
|
| 48 |
-
decoded_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
|
| 49 |
-
prompt_start = decoded_text.find(prompt)
|
| 50 |
-
if prompt_start != -1:
|
| 51 |
-
decoded_text = decoded_text[prompt_start + len(prompt):].strip()
|
| 52 |
-
else:
|
| 53 |
-
decoded_text = decoded_text.strip()
|
| 54 |
-
return decoded_text
|
| 55 |
-
except Exception:
|
| 56 |
-
try:
|
| 57 |
-
decoded_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|
| 58 |
-
prompt_start = decoded_text.find(prompt)
|
| 59 |
-
if prompt_start != -1:
|
| 60 |
-
decoded_text = decoded_text[prompt_start + len(prompt):].strip()
|
| 61 |
-
return decoded_text
|
| 62 |
-
except Exception:
|
| 63 |
-
return str(output_ids).strip()
|
| 64 |
-
|
| 65 |
-
# π Updated prompt with underline tagging instructions
|
| 66 |
-
def _default_prompt(query: str | None) -> str:
|
| 67 |
-
if query and query.strip():
|
| 68 |
-
return query.strip()
|
| 69 |
-
return (
|
| 70 |
-
"You are a professional Handwritten OCR system.\n"
|
| 71 |
-
"TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
|
| 72 |
-
"- Preserve original structure and line breaks.\n"
|
| 73 |
-
"- Keep spacing, bullet points, numbering, and indentation.\n"
|
| 74 |
-
"- Render tables as Markdown tables if present.\n"
|
| 75 |
-
"- Detect and mark UNDERLINED text with <u>...</u> tags.\n"
|
| 76 |
-
"- If text is double-underlined, wrap twice: <u><u>...</u></u>.\n"
|
| 77 |
-
"- Do NOT autocorrect spelling or grammar.\n"
|
| 78 |
-
"- Do NOT merge lines.\n"
|
| 79 |
-
"Return RAW transcription only."
|
| 80 |
-
)
|
| 81 |
-
|
| 82 |
-
# ---------------- OCR Function ----------------
|
| 83 |
@spaces.GPU
|
| 84 |
-
def
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
progress=gr.Progress(track_tqdm=True)):
|
| 88 |
-
if image is None: return "Please upload or capture an image."
|
| 89 |
if model_choice not in _loaded_models: return f"Invalid model: {model_choice}"
|
| 90 |
-
processor, model
|
| 91 |
-
|
| 92 |
-
|
|
|
|
| 93 |
with torch.inference_mode():
|
| 94 |
-
output_ids = model.generate(**
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
-
|
| 99 |
-
def _safe_text(text: str) -> str: return (text or "").strip()
|
| 100 |
-
def save_as_pdf(text): ...
|
| 101 |
-
def save_as_word(text): ...
|
| 102 |
-
def save_as_audio(text): ...
|
| 103 |
-
def calculate_cer_score(gt, pred): ...
|
| 104 |
|
| 105 |
-
# ---------------- Gradio
|
| 106 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 107 |
-
gr.Markdown("## βπΎ Wilson OCR (
|
| 108 |
model_choice = gr.Radio(choices=list(MODEL_PATHS.keys()), value=list(MODEL_PATHS.keys())[0], label="Select OCR Model")
|
| 109 |
|
| 110 |
with gr.Tab("πΌ Image Inference"):
|
|
@@ -113,7 +72,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
| 113 |
extract_btn = gr.Button("π€ Extract RAW Text", variant="primary")
|
| 114 |
raw_output = gr.Textbox(label="π RAW Structured Output", lines=18, show_copy_button=True)
|
| 115 |
|
| 116 |
-
extract_btn.click(fn=
|
| 117 |
|
| 118 |
if __name__ == "__main__":
|
| 119 |
demo.queue().launch(share=True)
|
|
|
|
| 3 |
import gradio as gr
|
| 4 |
import spaces
|
| 5 |
from PIL import Image
|
| 6 |
+
import numpy as np
|
| 7 |
+
import cv2
|
| 8 |
import torch
|
| 9 |
+
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
# ---------------- Models ----------------
|
| 12 |
MODEL_PATHS = {
|
|
|
|
| 32 |
except Exception as e:
|
| 33 |
print(f"β οΈ Failed to load {name}: {e}")
|
| 34 |
|
| 35 |
+
# ---------------- Underline Detection ----------------
|
| 36 |
+
def detect_underlines(image: Image.Image, min_line_width: int = 30,
                      ink_threshold: int = 150):
    """Return a binary mask of long horizontal strokes (candidate underlines).

    Args:
        image: PIL image of the handwritten page, in any PIL mode.
        min_line_width: Width in pixels of the 1-pixel-high rectangular
            kernel used for the morphological opening.
        ink_threshold: Grey level at or below which a pixel counts as ink.

    Returns:
        The array produced by ``cv2.morphologyEx``: ink pixels that survive
        the horizontal opening are 255, everything else is 0.
    """
    # Normalise the mode first: np.array() on an "L", "P" or "1" image yields
    # a 2-D array, which cv2.COLOR_RGB2GRAY rejects.
    rgb_pixels = np.array(image.convert("RGB"))
    gray = cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2GRAY)

    # Inverted threshold: dark ink becomes foreground (255) on a 0 background.
    _, ink_mask = cv2.threshold(gray, ink_threshold, 255, cv2.THRESH_BINARY_INV)

    # Opening with a wide, flat kernel removes everything except strokes that
    # run horizontally for at least the kernel width (letters are erased).
    horizontal_kernel = cv2.getStructuringElement(
        cv2.MORPH_RECT, (min_line_width, 1)
    )
    return cv2.morphologyEx(
        ink_mask, cv2.MORPH_OPEN, horizontal_kernel, iterations=2
    )
|
| 42 |
|
| 43 |
+
# ---------------- OCR + Underline ----------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
@spaces.GPU
def ocr_with_underlines(image: Image.Image, model_choice: str, query: str | None = None,
                        max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT):
    """Transcribe a handwritten image and tag it when underlines are detected.

    Args:
        image: PIL image to transcribe; ``None`` yields a user-facing message.
        model_choice: Key into the loaded model/processor registries.
        query: Optional custom prompt. Blank or ``None`` falls back to the
            built-in "Transcribe handwriting." instruction.
        max_new_tokens: Generation budget forwarded to ``model.generate``.

    Returns:
        The stripped transcription, wrapped in ``<u>...</u>`` when the
        OpenCV underline mask contains enough foreground, or an error
        message string for missing image / unknown model.
    """
    if image is None:
        return "Please upload an image."
    if model_choice not in _loaded_models:
        return f"Invalid model: {model_choice}"
    processor, model = _loaded_processors[model_choice], _loaded_models[model_choice]

    # Honour the prompt typed in the UI instead of silently discarding it.
    prompt = query.strip() if query and query.strip() else "Transcribe handwriting."

    # NOTE(review): Qwen2.5-VL-style processors rely on the chat template to
    # insert the image placeholder tokens; confirm for every entry in
    # MODEL_PATHS. Processors without a template fall back to the raw prompt.
    messages = [{
        "role": "user",
        "content": [{"type": "image"}, {"type": "text", "text": prompt}],
    }]
    try:
        model_text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except (AttributeError, ValueError):
        model_text = prompt

    # Run OCR
    inputs = processor(images=image, text=model_text, return_tensors="pt").to(device)
    with torch.inference_mode():
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # Decoder-only models return prompt + completion; drop the echoed prompt
    # so the instruction text does not leak into the transcription. The
    # prefix check keeps models that return only new tokens untouched.
    prompt_ids = inputs["input_ids"]
    prompt_len = prompt_ids.shape[1]
    if output_ids.shape[1] >= prompt_len and torch.equal(
        output_ids[:, :prompt_len].cpu(), prompt_ids.cpu()
    ):
        output_ids = output_ids[:, prompt_len:]

    # Strip before tagging so the <u> tags hug the text.
    raw_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

    # Run CV underline detection
    underline_mask = detect_underlines(image)
    # The comparison is against the summed mask values, not a pixel count.
    if raw_text and np.sum(underline_mask) > 5000:
        raw_text = f"<u>{raw_text}</u>"

    return raw_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
+
# ---------------- Gradio UI ----------------
|
| 65 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 66 |
+
gr.Markdown("## βπΎ Wilson OCR (OpenCV underline mode)")
|
| 67 |
model_choice = gr.Radio(choices=list(MODEL_PATHS.keys()), value=list(MODEL_PATHS.keys())[0], label="Select OCR Model")
|
| 68 |
|
| 69 |
with gr.Tab("πΌ Image Inference"):
|
|
|
|
| 72 |
extract_btn = gr.Button("π€ Extract RAW Text", variant="primary")
|
| 73 |
raw_output = gr.Textbox(label="π RAW Structured Output", lines=18, show_copy_button=True)
|
| 74 |
|
| 75 |
+
extract_btn.click(fn=ocr_with_underlines, inputs=[image_input, model_choice, query_input], outputs=[raw_output])
|
| 76 |
|
| 77 |
if __name__ == "__main__":
|
| 78 |
demo.queue().launch(share=True)
|