Emeritus-21 committed on
Commit de5e2ab · verified · 1 Parent(s): 5e2af9a

Update app.py

Files changed (1):
  app.py +29 -70
app.py CHANGED
@@ -3,13 +3,10 @@ from threading import Thread
  import gradio as gr
  import spaces
  from PIL import Image
+ import numpy as np
+ import cv2
  import torch
- from transformers import AutoProcessor, AutoModelForImageTextToText, Qwen2_5_VLForConditionalGeneration
- from reportlab.platypus import SimpleDocTemplate, Paragraph
- from reportlab.lib.styles import getSampleStyleSheet
- from docx import Document
- from gtts import gTTS
- from jiwer import cer
+ from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

  # ---------------- Models ----------------
  MODEL_PATHS = {
@@ -35,76 +32,38 @@ for name, (repo_id, cls) in MODEL_PATHS.items():
      except Exception as e:
          print(f"⚠️ Failed to load {name}: {e}")

- # ---------------- Helpers ----------------
- def _build_inputs(processor, tokenizer, image: Image.Image, prompt: str):
-     messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
-     if tokenizer and hasattr(tokenizer, "apply_chat_template"):
-         chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-         return processor(text=[chat_prompt], images=[image], return_tensors="pt")
-     return processor(text=[prompt], images=[image], return_tensors="pt")
-
- def _decode_text(model, processor, tokenizer, output_ids, prompt: str):
-     try:
-         decoded_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
-         prompt_start = decoded_text.find(prompt)
-         if prompt_start != -1:
-             decoded_text = decoded_text[prompt_start + len(prompt):].strip()
-         else:
-             decoded_text = decoded_text.strip()
-         return decoded_text
-     except Exception:
-         try:
-             decoded_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
-             prompt_start = decoded_text.find(prompt)
-             if prompt_start != -1:
-                 decoded_text = decoded_text[prompt_start + len(prompt):].strip()
-             return decoded_text
-         except Exception:
-             return str(output_ids).strip()
-
- # 🚀 Updated prompt with underline tagging instructions
- def _default_prompt(query: str | None) -> str:
-     if query and query.strip():
-         return query.strip()
-     return (
-         "You are a professional Handwritten OCR system.\n"
-         "TASK: Read the handwritten image and transcribe the text EXACTLY as written.\n"
-         "- Preserve original structure and line breaks.\n"
-         "- Keep spacing, bullet points, numbering, and indentation.\n"
-         "- Render tables as Markdown tables if present.\n"
-         "- Detect and mark UNDERLINED text with <u>...</u> tags.\n"
-         "- If text is double-underlined, wrap twice: <u><u>...</u></u>.\n"
-         "- Do NOT autocorrect spelling or grammar.\n"
-         "- Do NOT merge lines.\n"
-         "Return RAW transcription only."
-     )
-
- # ---------------- OCR Function ----------------
+ # ---------------- Underline Detection ----------------
+ def detect_underlines(image: Image.Image):
+     cv_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
+     _, thresh = cv2.threshold(cv_img, 150, 255, cv2.THRESH_BINARY_INV)
+     kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (30, 1))
+     detected_lines = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=2)
+     return detected_lines
+
+ # ---------------- OCR + Underline ----------------
  @spaces.GPU
- def ocr_image(image: Image.Image, model_choice: str, query: str = None,
-               max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT,
-               temperature: float = 0.1, top_p: float = 1.0, top_k: int = 0, repetition_penalty: float = 1.0,
-               progress=gr.Progress(track_tqdm=True)):
-     if image is None: return "Please upload or capture an image."
+ def ocr_with_underlines(image: Image.Image, model_choice: str, query: str = None,
+                         max_new_tokens: int = MAX_NEW_TOKENS_DEFAULT):
+     if image is None: return "Please upload an image."
      if model_choice not in _loaded_models: return f"Invalid model: {model_choice}"
-     processor, model, tokenizer = _loaded_processors[model_choice], _loaded_models[model_choice], getattr(_loaded_processors[model_choice], "tokenizer", None)
-     prompt = _default_prompt(query)
-     batch = _build_inputs(processor, tokenizer, image, prompt).to(device)
+     processor, model = _loaded_processors[model_choice], _loaded_models[model_choice]
+
+     # Run OCR
+     inputs = processor(images=image, text="Transcribe handwriting.", return_tensors="pt").to(device)
      with torch.inference_mode():
-         output_ids = model.generate(**batch, max_new_tokens=max_new_tokens, do_sample=False,
-                                     temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
-     return _decode_text(model, processor, tokenizer, output_ids, prompt).replace("<|im_end|>", "").strip()
+         output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
+     raw_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
+
+     # Run CV underline detection
+     underline_mask = detect_underlines(image)
+     if np.sum(underline_mask) > 5000:
+         raw_text = f"<u>{raw_text}</u>"
+
+     return raw_text.strip()

- # ---------------- Export Helpers ----------------
- def _safe_text(text: str) -> str: return (text or "").strip()
- def save_as_pdf(text): ...
- def save_as_word(text): ...
- def save_as_audio(text): ...
- def calculate_cer_score(gt, pred): ...
-
- # ---------------- Gradio Interface ----------------
+ # ---------------- Gradio UI ----------------
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
-     gr.Markdown("## ✍🏾 Wilson OCR (Prompt underline mode)")
+     gr.Markdown("## ✍🏾 Wilson OCR (OpenCV underline mode)")
      model_choice = gr.Radio(choices=list(MODEL_PATHS.keys()), value=list(MODEL_PATHS.keys())[0], label="Select OCR Model")

      with gr.Tab("🖼 Image Inference"):
@@ -113,7 +72,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
          extract_btn = gr.Button("📤 Extract RAW Text", variant="primary")
          raw_output = gr.Textbox(label="📜 RAW Structured Output", lines=18, show_copy_button=True)

-     extract_btn.click(fn=ocr_image, inputs=[image_input, model_choice, query_input], outputs=[raw_output])
+     extract_btn.click(fn=ocr_with_underlines, inputs=[image_input, model_choice, query_input], outputs=[raw_output])

  if __name__ == "__main__":
      demo.queue().launch(share=True)
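Why the new detector works: detect_underlines() relies on morphological opening
with a deliberately wide, flat kernel. cv2.getStructuringElement(cv2.MORPH_RECT,
(30, 1)) erodes away any dark stroke narrower than roughly 30 px horizontally
(letters, vertical strokes), so only long horizontal runs such as underlines
survive the opening. A minimal, self-contained sketch of the same pipeline on a
synthetic image (the drawn test image and printed pixel counts are illustrative,
not part of the app):

    import cv2
    import numpy as np
    from PIL import Image, ImageDraw

    # Synthetic page: one long horizontal "underline" and one vertical stroke.
    img = Image.new("RGB", (400, 120), "white")
    draw = ImageDraw.Draw(img)
    draw.line((40, 80, 360, 80), fill="black", width=3)    # underline-like rule
    draw.line((200, 10, 200, 70), fill="black", width=3)   # letter-like vertical

    gray = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2GRAY)
    _, thresh = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY_INV)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (30, 1))  # wide and flat
    opened = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=2)

    print("foreground px before opening:", np.count_nonzero(thresh))
    print("foreground px after opening: ", np.count_nonzero(opened))

The vertical stroke is erased entirely while the 320-px rule survives, which is
the property the <u> heuristic depends on.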
 
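One behavioural difference from the removed _decode_text() helper is worth
flagging: for Qwen2.5-VL-style models, generate() returns the echoed prompt
tokens followed by the newly generated ones, and the old helper searched for the
prompt string and cut it off, while the new code decodes the full sequence
as-is. If the instruction text leaks into the transcription, a common fix is to
slice the prompt tokens off before decoding; a sketch under that assumption
(decode_new_tokens is a hypothetical helper, not part of this commit):

    import torch

    def decode_new_tokens(processor, model, inputs, max_new_tokens: int) -> str:
        # Keep only the tokens produced after the prompt, since generate()
        # prepends the echoed input ids to its output.
        with torch.inference_mode():
            output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        new_tokens = output_ids[:, inputs["input_ids"].shape[1]:]
        return processor.batch_decode(new_tokens, skip_special_tokens=True)[0]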
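The np.sum(underline_mask) > 5000 gate is coarser than it looks: the mask holds
0/255 values, so 5000 corresponds to only about 20 foreground pixels, and when
it fires the entire transcription is wrapped in a single <u>...</u> pair rather
than just the underlined words. The mask already localizes each underline, so
per-line tagging is possible later; a sketch of extracting bounding boxes
(underline_boxes is illustrative, not part of this commit):

    import cv2
    import numpy as np

    def underline_boxes(mask: np.ndarray) -> list[tuple[int, int, int, int]]:
        # One (x, y, w, h) box per detected underline segment, instead of a
        # single global yes/no decision.
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        return [cv2.boundingRect(c) for c in contours]

Each box's y position could then be matched against the OCR output's line order
to decide which transcribed line receives the <u> tags.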