IFMedTechdemo commited on
Commit
a779cac
·
verified ·
1 Parent(s): 40cc5c9

Update app.py

Browse files

total change in code version 2.
The generate_image signature now exactly matches the inputs=[...] order (including the gr.State(...) for text).

The function always yields two values, and the UI defines outputs=[output, markdown_output], so Gradio will not error.

ClinicalNER is only invoked when model_name == "Dots.OCR"; otherwise, medications are derived from line-splitting as a fallback.

Spell-check suggests up to 5 matches per med (depending on your private matcher implementation), and each line includes both score and CER.

Files changed (1) hide show
  1. app.py +223 -97
app.py CHANGED
@@ -1,10 +1,7 @@
1
- ###################################### version 2 ########################################################
2
-
3
  import os
4
  import time
5
  from threading import Thread
6
  from typing import Iterable, Dict, Any, Optional, List
7
- import pandas as pd # For reading Excel file
8
 
9
  import gradio as gr
10
  import spaces
@@ -21,14 +18,11 @@ from transformers import (
21
  from gradio.themes import Soft
22
  from gradio.themes.utils import colors, fonts, sizes
23
 
24
- MAX_MAX_NEW_TOKENS = 4096
25
- DEFAULT_MAX_NEW_TOKENS = 2048
26
-
27
-
28
  # -----------------------------
29
- # Character Error Rate (CER) Calculation
30
  # -----------------------------
31
 
 
32
  def levenshtein(a: str, b: str) -> int:
33
  """Levenshtein distance to calculate CER."""
34
  a, b = a.lower(), b.lower()
@@ -45,14 +39,17 @@ def levenshtein(a: str, b: str) -> int:
45
  for j, cb in enumerate(b, 1):
46
  cur = dp[j]
47
  cost = 0 if ca == cb else 1
48
- dp[j] = min(dp[j] + 1, dp[j-1] + 1, prev + cost)
49
  prev = cur
50
  return dp[-1]
51
 
 
52
  def character_error_rate(pred: str, target: str) -> float:
53
- """Calculate the Character Error Rate (CER)."""
 
54
  distance = levenshtein(pred, target)
55
- return (distance / len(target)) * 100 if len(target) > 0 else 0
 
56
 
57
  # -----------------------------
58
  # Private repo: dynamic import
@@ -64,14 +61,15 @@ REPO_ID = "IFMedTech/Medibot_OCR_model" # private backend repo
64
 
65
  # Map filenames to exported class names
66
  PY_MODULES = {
67
- "ner.py": "ClinicalNER", # NER is only applied for Dots.OCR output
68
  "tfidf_phonetic.py": "TfidfPhoneticMatcher",
69
  "symspell_matcher.py": "SymSpellMatcher",
70
  "rapidfuzz_matcher.py": "RapidFuzzMatcher",
71
- # 'drug_dictionary.xlsx' is data, not a module
72
  }
73
 
74
- HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
 
75
 
76
  def _dynamic_import(module_path: str, class_name: str):
77
  spec = importlib.util.spec_from_file_location(class_name, module_path)
@@ -79,7 +77,8 @@ def _dynamic_import(module_path: str, class_name: str):
79
  spec.loader.exec_module(module) # type: ignore
80
  return getattr(module, class_name)
81
 
82
- # Load private classes and Excel dictionary
 
83
  priv_classes: Dict[str, Any] = {}
84
  drug_xlsx_path: Optional[str] = None
85
  try:
@@ -91,7 +90,11 @@ try:
91
  if cls:
92
  priv_classes[cls] = _dynamic_import(path, cls)
93
  print(f"[Private] Loaded class: {cls} from {fname}")
94
- drug_xlsx_path = hf_hub_download(repo_id=REPO_ID, filename="Medibot_Drugs_Cleaned_Updated.xlsx", token=HF_TOKEN)
 
 
 
 
95
  print(f"[Private] Downloaded Excel at: {drug_xlsx_path}")
96
  except Exception as e:
97
  print(f"[Private] ERROR loading private backend: {e}")
@@ -116,6 +119,7 @@ colors.steel_blue = colors.Color(
116
  c950="#1E3450",
117
  )
118
 
 
119
  class SteelBlueTheme(Soft):
120
  def __init__(
121
  self,
@@ -125,10 +129,14 @@ class SteelBlueTheme(Soft):
125
  neutral_hue: colors.Color | str = colors.slate,
126
  text_size: sizes.Size | str = sizes.text_lg,
127
  font: fonts.Font | str | Iterable[fonts.Font | str] = (
128
- fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
 
 
129
  ),
130
  font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
131
- fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
 
 
132
  ),
133
  ):
134
  super().__init__(
@@ -167,6 +175,7 @@ class SteelBlueTheme(Soft):
167
  block_label_background_fill="*primary_200",
168
  )
169
 
 
170
  steel_blue_theme = SteelBlueTheme()
171
 
172
  css = """
@@ -177,11 +186,7 @@ css = """
177
  # ----------------------------
178
  # RUNTIME / DEVICE
179
  # ----------------------------
180
-
181
- # Ensure CUDA_VISIBLE_DEVICES is set correctly to use GPU
182
  os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
183
-
184
- # Check if CUDA is available and print relevant information
185
  print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES"))
186
  print("torch.__version__ =", torch.__version__)
187
  print("torch.version.cuda =", torch.version.cuda)
@@ -214,6 +219,7 @@ processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True
214
  attn_impl = "sdpa"
215
  try:
216
  import flash_attn # noqa: F401
 
217
  if use_cuda:
218
  attn_impl = "flash_attention_2"
219
  except Exception:
@@ -224,31 +230,41 @@ model_d = AutoModelForCausalLM.from_pretrained(
224
  attn_implementation=attn_impl,
225
  torch_dtype=DTYPE_BF16,
226
  device_map="auto" if use_cuda else None,
227
- trust_remote_code=True
228
  ).eval()
229
  if not use_cuda:
230
  model_d.to(device)
231
 
232
  # ----------------------------
233
- # GENERATION (OCR → Spell-check)
234
  # ----------------------------
 
 
235
 
236
- @spaces.GPU
237
- def generate_image(model_name: str,
238
- text: str,
239
- image: Image.Image,
240
- max_new_tokens: int,
241
- temperature: float,
242
- top_p: float,
243
- top_k: int,
244
- repetition_penalty: float,
245
- spell_algo: str):
 
 
 
246
  """
247
- 1) Stream OCR tokens to Raw output.
248
- 2) Directly apply spell-check algorithms (TF-IDF+Phonetic, SymSpell, or RapidFuzz).
249
- 3) Only apply Clinical NER to Dots.OCR output, then apply spell-check on the result.
 
 
 
 
250
  """
251
  if image is None:
 
252
  yield "Please upload an image.", "Please upload an image."
253
  return
254
 
@@ -260,23 +276,31 @@ def generate_image(model_name: str,
260
  yield "Invalid model selected.", "Invalid model selected."
261
  return
262
 
263
- # Build prompt
264
- messages = [{
265
- "role": "user",
266
- "content": [
267
- {"type": "image"},
268
- {"type": "text", "text": text},
269
- ]
270
- }]
271
- prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
 
 
 
272
 
273
  # Preprocess
274
- inputs = processor(text=[prompt_full], images=[image], return_tensors="pt", padding=True)
 
 
275
  inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}
276
 
277
  # Streamer
278
  tokenizer = getattr(processor, "tokenizer", None) or processor
279
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
 
280
 
281
  gen_kwargs = dict(
282
  **inputs,
@@ -289,132 +313,234 @@ def generate_image(model_name: str,
289
  repetition_penalty=repetition_penalty,
290
  )
291
 
292
- # Start generation
293
  thread = Thread(target=model.generate, kwargs=gen_kwargs)
294
  thread.start()
295
 
296
- # 1) Live OCR streaming to Raw
297
  buffer = ""
298
  for new_text in streamer:
299
  buffer += new_text.replace("<|im_end|>", "")
300
  time.sleep(0.01)
 
301
  yield buffer, buffer
302
 
303
- # Final raw OCR output (buffer)
304
  final_ocr_text = buffer.strip()
305
 
306
- # 2) Apply Clinical NER ONLY for Dots.OCR output
307
- meds = []
 
 
308
  if model_name == "Dots.OCR":
309
  try:
310
- if "ClinicalNER" in priv_classes:
311
  ClinicalNER = priv_classes["ClinicalNER"]
312
- ner = ClinicalNER(token=HF_TOKEN) # pass model_id=... if using your own model
313
- meds = ner(final_ocr_text) or []
314
- print("Extracted meds:", meds) # Print extracted meds
 
 
315
  else:
316
- print("[NER] ClinicalNER not available.")
317
  except Exception as e:
318
  print(f"[NER] Error running ClinicalNER: {e}")
319
 
320
- # 3) Apply selected spell-check algorithm (directly on raw OCR output or NER output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
  spell_section = "\n---\n### Spell-check suggestions (" + spell_algo + ")\n"
322
  corr: Dict[str, List] = {}
323
 
324
  try:
325
- if final_ocr_text and drug_xlsx_path:
326
- # Print meds and the number of rows in the drug_xlsx_path
327
- print("Meds:", meds)
328
- print("Rows in drug_xlsx_path:", len(pd.read_excel(drug_xlsx_path)))
329
-
330
- if spell_algo == "TF-IDF + Phonetic" and "TfidfPhoneticMatcher" in priv_classes:
331
  Cls = priv_classes["TfidfPhoneticMatcher"]
332
- checker = Cls(xlsx_path=drug_xlsx_path, column="Combined_Drugs", ngram_size=3, phonetic_weight=0.4)
333
- corr = checker.match_list([final_ocr_text], top_k=5, tfidf_threshold=0.15)
 
 
 
 
 
334
 
335
  elif spell_algo == "SymSpell" and "SymSpellMatcher" in priv_classes:
336
  Cls = priv_classes["SymSpellMatcher"]
337
- checker = Cls(xlsx_path=drug_xlsx_path, column="Combined_Drugs", max_edit=2, prefix_len=7)
338
- corr = checker.match_list([final_ocr_text], top_k=5, min_score=0.4)
339
-
340
- elif spell_algo == "RapidFuzz" and "RapidFuzzMatcher" in priv_classes:
 
 
 
 
 
 
 
341
  Cls = priv_classes["RapidFuzzMatcher"]
342
  checker = Cls(xlsx_path=drug_xlsx_path, column="Combined_Drugs")
343
- corr = checker.match_list([final_ocr_text], top_k=5, threshold=70.0)
344
  else:
345
  spell_section += "- Spell-check backend unavailable.\n"
346
  else:
347
- spell_section += "- No OCR output or Excel dictionary missing.\n"
348
  except Exception as e:
349
  spell_section += f"- Spell-check error: {e}\n"
350
 
351
- # Format spell-check suggestions (top-5 with CER)
352
  if corr:
353
- for raw in [final_ocr_text]:
354
  suggestions = corr.get(raw, [])
355
  if suggestions:
356
  spell_section += f"- **{raw}**\n"
357
  for cand, score in suggestions:
358
- cer = character_error_rate(cand, raw) # Calculate CER
359
- spell_section += f" - {cand} (score={score:.3f}, CER={cer:.3f}%)\n"
 
 
 
360
  else:
361
  spell_section += f"- **{raw}**\n - (no suggestions)\n"
362
 
363
- final_md = spell_section # Only spell-check suggestions
364
 
365
- # 4) Final yield: raw unchanged; Markdown with spell-check
366
  yield final_ocr_text, final_md
367
 
 
368
  # ----------------------------
369
  # UI
370
  # ----------------------------
371
-
372
  image_examples = [
373
- ["OCR the content perfectly.", "examples/3.jpg"],
374
- ["Perform OCR on the image.", "examples/1.jpg"],
375
- ["Extract the contents. [page].", "examples/2.jpg"],
376
  ]
377
 
378
  with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
379
- gr.Markdown("# **Handwritten Doctor's Prescription Reading**", elem_id="main-title")
 
 
380
  with gr.Row():
381
  with gr.Column(scale=2):
382
- image_upload = gr.Image(type="pil", label="Upload Image", height=290)
 
 
383
  image_submit = gr.Button("Submit", variant="primary")
384
- gr.Examples(examples=image_examples, inputs=[image_upload])
 
 
 
 
385
 
386
  # Spell-check selection
387
  spell_choice = gr.Radio(
388
  choices=["TF-IDF + Phonetic", "SymSpell", "RapidFuzz"],
389
  label="Select Spell-check Approach",
390
- value="TF-IDF + Phonetic"
391
  )
392
 
393
  with gr.Accordion("Advanced options", open=False):
394
- max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
395
- temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7)
396
- top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
397
- top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
398
- repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399
 
400
  with gr.Column(scale=3):
401
  gr.Markdown("## Output", elem_id="output-title")
402
- output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=11, show_copy_button=True)
 
 
 
 
 
 
 
403
 
404
  model_choice = gr.Radio(
405
  choices=["Chandra-OCR", "Dots.OCR"],
406
  label="Select OCR Model",
407
- value="Chandra-OCR"
408
  )
409
 
 
 
 
 
 
410
  image_submit.click(
411
  fn=generate_image,
412
- inputs=[model_choice, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty, spell_choice],
413
- outputs=[output]
 
 
 
 
 
 
 
 
 
 
414
  )
415
 
416
  if __name__ == "__main__":
417
- demo.queue(max_size=50).launch(mcp_server=True, ssr_mode=False, show_error=True)
 
 
418
 
419
 
420
 
 
 
 
1
  import os
2
  import time
3
  from threading import Thread
4
  from typing import Iterable, Dict, Any, Optional, List
 
5
 
6
  import gradio as gr
7
  import spaces
 
18
  from gradio.themes import Soft
19
  from gradio.themes.utils import colors, fonts, sizes
20
 
 
 
 
 
21
  # -----------------------------
22
+ # Character Error Rate (CER)
23
  # -----------------------------
24
 
25
+
26
  def levenshtein(a: str, b: str) -> int:
27
  """Levenshtein distance to calculate CER."""
28
  a, b = a.lower(), b.lower()
 
39
  for j, cb in enumerate(b, 1):
40
  cur = dp[j]
41
  cost = 0 if ca == cb else 1
42
+ dp[j] = min(dp[j] + 1, dp[j - 1] + 1, prev + cost)
43
  prev = cur
44
  return dp[-1]
45
 
46
+
47
  def character_error_rate(pred: str, target: str) -> float:
48
+ """Calculate the Character Error Rate (CER) in percent."""
49
+ target = target or ""
50
  distance = levenshtein(pred, target)
51
+ return (distance / len(target)) * 100 if len(target) > 0 else 0.0
52
+
53
 
54
  # -----------------------------
55
  # Private repo: dynamic import
 
61
 
62
  # Map filenames to exported class names
63
  PY_MODULES = {
64
+ "ner.py": "ClinicalNER", # NER is only applied for Dots.OCR output
65
  "tfidf_phonetic.py": "TfidfPhoneticMatcher",
66
  "symspell_matcher.py": "SymSpellMatcher",
67
  "rapidfuzz_matcher.py": "RapidFuzzMatcher",
68
+ # 'Medibot_Drugs_Cleaned_Updated.xlsx' is data, not a module
69
  }
70
 
71
+ HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
72
+
73
 
74
  def _dynamic_import(module_path: str, class_name: str):
75
  spec = importlib.util.spec_from_file_location(class_name, module_path)
 
77
  spec.loader.exec_module(module) # type: ignore
78
  return getattr(module, class_name)
79
 
80
+
81
+ # Load private classes and Excel dictionary (once at import time)
82
  priv_classes: Dict[str, Any] = {}
83
  drug_xlsx_path: Optional[str] = None
84
  try:
 
90
  if cls:
91
  priv_classes[cls] = _dynamic_import(path, cls)
92
  print(f"[Private] Loaded class: {cls} from {fname}")
93
+ drug_xlsx_path = hf_hub_download(
94
+ repo_id=REPO_ID,
95
+ filename="Medibot_Drugs_Cleaned_Updated.xlsx",
96
+ token=HF_TOKEN,
97
+ )
98
  print(f"[Private] Downloaded Excel at: {drug_xlsx_path}")
99
  except Exception as e:
100
  print(f"[Private] ERROR loading private backend: {e}")
 
119
  c950="#1E3450",
120
  )
121
 
122
+
123
  class SteelBlueTheme(Soft):
124
  def __init__(
125
  self,
 
129
  neutral_hue: colors.Color | str = colors.slate,
130
  text_size: sizes.Size | str = sizes.text_lg,
131
  font: fonts.Font | str | Iterable[fonts.Font | str] = (
132
+ fonts.GoogleFont("Outfit"),
133
+ "Arial",
134
+ "sans-serif",
135
  ),
136
  font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
137
+ fonts.GoogleFont("IBM Plex Mono"),
138
+ "ui-monospace",
139
+ "monospace",
140
  ),
141
  ):
142
  super().__init__(
 
175
  block_label_background_fill="*primary_200",
176
  )
177
 
178
+
179
  steel_blue_theme = SteelBlueTheme()
180
 
181
  css = """
 
186
  # ----------------------------
187
  # RUNTIME / DEVICE
188
  # ----------------------------
 
 
189
  os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
 
 
190
  print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES"))
191
  print("torch.__version__ =", torch.__version__)
192
  print("torch.version.cuda =", torch.version.cuda)
 
219
  attn_impl = "sdpa"
220
  try:
221
  import flash_attn # noqa: F401
222
+
223
  if use_cuda:
224
  attn_impl = "flash_attention_2"
225
  except Exception:
 
230
  attn_implementation=attn_impl,
231
  torch_dtype=DTYPE_BF16,
232
  device_map="auto" if use_cuda else None,
233
+ trust_remote_code=True,
234
  ).eval()
235
  if not use_cuda:
236
  model_d.to(device)
237
 
238
  # ----------------------------
239
+ # GENERATION (OCR → NER (Dots only) → Spell-check + CER)
240
  # ----------------------------
241
+ MAX_MAX_NEW_TOKENS = 4096
242
+ DEFAULT_MAX_NEW_TOKENS = 2048
243
 
244
+
245
+ @spaces.GPU # you can add duration=... if needed, e.g. @spaces.GPU(duration=240)
246
+ def generate_image(
247
+ model_name: str,
248
+ text: str,
249
+ image: Image.Image,
250
+ max_new_tokens: int,
251
+ temperature: float,
252
+ top_p: float,
253
+ top_k: int,
254
+ repetition_penalty: float,
255
+ spell_algo: str,
256
+ ):
257
  """
258
+ 1) Stream OCR tokens to Raw output (unchanged).
259
+ 2) If model_name == 'Dots.OCR', run ClinicalNER → list[str] meds.
260
+ For Chandra-OCR, skip NER.
261
+ 3) Apply selected spell-check (TF-IDF+Phonetic / SymSpell / RapidFuzz)
262
+ using Excel dict, and compute CER for each suggestion.
263
+ 4) Markdown shows OCR text, NER list (if any), and spell-check top-5
264
+ suggestions with scores and CER.
265
  """
266
  if image is None:
267
+ # Two outputs: raw textbox + markdown
268
  yield "Please upload an image.", "Please upload an image."
269
  return
270
 
 
276
  yield "Invalid model selected.", "Invalid model selected."
277
  return
278
 
279
+ # Build prompt from text parameter (kept via gr.State)
280
+ messages = [
281
+ {
282
+ "role": "user",
283
+ "content": [
284
+ {"type": "image"},
285
+ {"type": "text", "text": text},
286
+ ],
287
+ }
288
+ ]
289
+ prompt_full = processor.apply_chat_template(
290
+ messages, tokenize=False, add_generation_prompt=True
291
+ )
292
 
293
  # Preprocess
294
+ inputs = processor(
295
+ text=[prompt_full], images=[image], return_tensors="pt", padding=True
296
+ )
297
  inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}
298
 
299
  # Streamer
300
  tokenizer = getattr(processor, "tokenizer", None) or processor
301
+ streamer = TextIteratorStreamer(
302
+ tokenizer, skip_prompt=True, skip_special_tokens=True
303
+ )
304
 
305
  gen_kwargs = dict(
306
  **inputs,
 
313
  repetition_penalty=repetition_penalty,
314
  )
315
 
316
+ # Start generation in background thread
317
  thread = Thread(target=model.generate, kwargs=gen_kwargs)
318
  thread.start()
319
 
320
+ # 1) Live OCR streaming to Raw (and mirror to Markdown during stream)
321
  buffer = ""
322
  for new_text in streamer:
323
  buffer += new_text.replace("<|im_end|>", "")
324
  time.sleep(0.01)
325
+ # During streaming, just show the raw text in both components
326
  yield buffer, buffer
327
 
328
+ # Final raw text
329
  final_ocr_text = buffer.strip()
330
 
331
+ # -------------------------
332
+ # 2) Clinical NER (Dots.OCR only)
333
+ # -------------------------
334
+ meds: List[str] = []
335
  if model_name == "Dots.OCR":
336
  try:
337
+ if "ClinicalNER" in priv_classes and HF_TOKEN is not None:
338
  ClinicalNER = priv_classes["ClinicalNER"]
339
+ ner = ClinicalNER(token=HF_TOKEN) # model_id can be passed if needed
340
+ ner_output = ner(final_ocr_text) or []
341
+ # Expecting list[str]; be robust:
342
+ meds = [m.strip() for m in ner_output if isinstance(m, str) and m.strip()]
343
+ print("[NER] Extracted meds:", meds)
344
  else:
345
+ print("[NER] ClinicalNER not available or no HF token.")
346
  except Exception as e:
347
  print(f"[NER] Error running ClinicalNER: {e}")
348
 
349
+ # Fallback: if no meds found (or Chandra-OCR), derive meds from OCR lines
350
+ if not meds:
351
+ meds = [line.strip() for line in final_ocr_text.splitlines() if line.strip()]
352
+ print("[NER] Using line-based meds fallback, count:", len(meds))
353
+
354
+ # -------------------------
355
+ # Build Markdown: OCR text + NER section
356
+ # -------------------------
357
+ md = "### Raw OCR Output\n"
358
+ md += "```\n" + (final_ocr_text or "(empty)") + "\n```\n"
359
+
360
+ md += "\n---\n### Clinical NER (Medications)\n"
361
+ if meds:
362
+ for m in meds:
363
+ md += f"- {m}\n"
364
+ else:
365
+ md += "- None detected\n"
366
+
367
+ # -------------------------
368
+ # 3) Spell-check (med list) with CER
369
+ # -------------------------
370
  spell_section = "\n---\n### Spell-check suggestions (" + spell_algo + ")\n"
371
  corr: Dict[str, List] = {}
372
 
373
  try:
374
+ if meds and drug_xlsx_path:
375
+ if (
376
+ spell_algo == "TF-IDF + Phonetic"
377
+ and "TfidfPhoneticMatcher" in priv_classes
378
+ ):
 
379
  Cls = priv_classes["TfidfPhoneticMatcher"]
380
+ checker = Cls(
381
+ xlsx_path=drug_xlsx_path,
382
+ column="Combined_Drugs",
383
+ ngram_size=3,
384
+ phonetic_weight=0.4,
385
+ )
386
+ corr = checker.match_list(meds, top_k=5, tfidf_threshold=0.15)
387
 
388
  elif spell_algo == "SymSpell" and "SymSpellMatcher" in priv_classes:
389
  Cls = priv_classes["SymSpellMatcher"]
390
+ checker = Cls(
391
+ xlsx_path=drug_xlsx_path,
392
+ column="Combined_Drugs",
393
+ max_edit=2,
394
+ prefix_len=7,
395
+ )
396
+ corr = checker.match_list(meds, top_k=5, min_score=0.4)
397
+
398
+ elif (
399
+ spell_algo == "RapidFuzz" and "RapidFuzzMatcher" in priv_classes
400
+ ):
401
  Cls = priv_classes["RapidFuzzMatcher"]
402
  checker = Cls(xlsx_path=drug_xlsx_path, column="Combined_Drugs")
403
+ corr = checker.match_list(meds, top_k=5, threshold=70.0)
404
  else:
405
  spell_section += "- Spell-check backend unavailable.\n"
406
  else:
407
+ spell_section += "- No NER/med list or Excel dictionary missing.\n"
408
  except Exception as e:
409
  spell_section += f"- Spell-check error: {e}\n"
410
 
411
+ # Format suggestions (top-5 per med, with scores + CER)
412
  if corr:
413
+ for raw in meds:
414
  suggestions = corr.get(raw, [])
415
  if suggestions:
416
  spell_section += f"- **{raw}**\n"
417
  for cand, score in suggestions:
418
+ cer = character_error_rate(cand, raw)
419
+ spell_section += (
420
+ f" - {cand} "
421
+ f"(score={score:.3f}, CER={cer:.3f}%)\n"
422
+ )
423
  else:
424
  spell_section += f"- **{raw}**\n - (no suggestions)\n"
425
 
426
+ final_md = md + spell_section
427
 
428
+ # 4) Final yield: raw unchanged; Markdown with NER + spell-check + CER
429
  yield final_ocr_text, final_md
430
 
431
+
432
  # ----------------------------
433
  # UI
434
  # ----------------------------
435
+ # IMPORTANT: examples must match the number of inputs (here: only image)
436
  image_examples = [
437
+ ["examples/3.jpg"],
438
+ ["examples/1.jpg"],
439
+ ["examples/2.jpg"],
440
  ]
441
 
442
  with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
443
+ gr.Markdown(
444
+ "# **Handwritten Doctor's Prescription Reading**", elem_id="main-title"
445
+ )
446
  with gr.Row():
447
  with gr.Column(scale=2):
448
+ image_upload = gr.Image(
449
+ type="pil", label="Upload Image", height=290
450
+ )
451
  image_submit = gr.Button("Submit", variant="primary")
452
+ gr.Examples(
453
+ examples=image_examples,
454
+ inputs=[image_upload],
455
+ label="Example Images",
456
+ )
457
 
458
  # Spell-check selection
459
  spell_choice = gr.Radio(
460
  choices=["TF-IDF + Phonetic", "SymSpell", "RapidFuzz"],
461
  label="Select Spell-check Approach",
462
+ value="TF-IDF + Phonetic",
463
  )
464
 
465
  with gr.Accordion("Advanced options", open=False):
466
+ max_new_tokens = gr.Slider(
467
+ label="Max new tokens",
468
+ minimum=1,
469
+ maximum=MAX_MAX_NEW_TOKENS,
470
+ step=1,
471
+ value=DEFAULT_MAX_NEW_TOKENS,
472
+ )
473
+ temperature = gr.Slider(
474
+ label="Temperature",
475
+ minimum=0.1,
476
+ maximum=4.0,
477
+ step=0.1,
478
+ value=0.7,
479
+ )
480
+ top_p = gr.Slider(
481
+ label="Top-p (nucleus sampling)",
482
+ minimum=0.05,
483
+ maximum=1.0,
484
+ step=0.05,
485
+ value=0.9,
486
+ )
487
+ top_k = gr.Slider(
488
+ label="Top-k",
489
+ minimum=1,
490
+ maximum=1000,
491
+ step=1,
492
+ value=50,
493
+ )
494
+ repetition_penalty = gr.Slider(
495
+ label="Repetition penalty",
496
+ minimum=1.0,
497
+ maximum=2.0,
498
+ step=0.05,
499
+ value=1.1,
500
+ )
501
 
502
  with gr.Column(scale=3):
503
  gr.Markdown("## Output", elem_id="output-title")
504
+ output = gr.Textbox(
505
+ label="Raw Output Stream",
506
+ interactive=False,
507
+ lines=11,
508
+ show_copy_button=True,
509
+ )
510
+ with gr.Accordion("(Result.md)", open=False):
511
+ markdown_output = gr.Markdown(label="(Result.Md)")
512
 
513
  model_choice = gr.Radio(
514
  choices=["Chandra-OCR", "Dots.OCR"],
515
  label="Select OCR Model",
516
+ value="Chandra-OCR",
517
  )
518
 
519
+ # Hard-coded instruction text, passed as gr.State to match the 'text' parameter
520
+ query_state = gr.State(
521
+ "Extract medicine or drugs names along with dosage amount or quantity"
522
+ )
523
+
524
  image_submit.click(
525
  fn=generate_image,
526
+ inputs=[
527
+ model_choice,
528
+ query_state,
529
+ image_upload,
530
+ max_new_tokens,
531
+ temperature,
532
+ top_p,
533
+ top_k,
534
+ repetition_penalty,
535
+ spell_choice,
536
+ ],
537
+ outputs=[output, markdown_output],
538
  )
539
 
540
  if __name__ == "__main__":
541
+ demo.queue(max_size=50).launch(
542
+ mcp_server=True, ssr_mode=False, show_error=True
543
+ )
544
 
545
 
546