# medibotOCR / app.py
# IFMedTechdemo's picture
# Update app.py
# 834c4ff verified
###################################### version 5 ############################################
# import os
# from typing import Iterable, Dict, Any, Optional, List
# from threading import Thread # no longer needed but harmless if left
# import time # no longer needed but harmless if left
# import gradio as gr
# import spaces
# import torch
# from PIL import Image
# import pandas as pd
# from transformers import (
# Qwen3VLForConditionalGeneration,
# AutoModelForCausalLM,
# AutoProcessor,
# )
# from gradio.themes import Soft
# from gradio.themes.utils import colors, fonts, sizes
# # ============================================================
# # Character Error Rate (CER)
# # ============================================================
# def levenshtein(a: str, b: str) -> int:
# """Levenshtein distance to calculate CER."""
# a, b = a.lower(), b.lower()
# if a == b:
# return 0
# if not a:
# return len(b)
# if not b:
# return len(a)
# dp = list(range(len(b) + 1))
# for i, ca in enumerate(a, 1):
# prev = dp[0]
# dp[0] = i
# for j, cb in enumerate(b, 1):
# cur = dp[j]
# cost = 0 if ca == cb else 1
# dp[j] = min(dp[j] + 1, dp[j - 1] + 1, prev + cost)
# prev = cur
# return dp[-1]
# def character_error_rate(pred: str, target: str) -> float:
# """Calculate the Character Error Rate (CER) in percent."""
# target = target or ""
# distance = levenshtein(pred, target)
# return (distance / len(target)) * 100 if len(target) > 0 else 0.0
# # ============================================================
# # Private repo: dynamic import + Excel download
# # ============================================================
# import importlib.util
# from huggingface_hub import hf_hub_download
# REPO_ID = "IFMedTech/Medibot_OCR_model" # private backend repo
# PY_MODULES: Dict[str, str] = {
# "clinical_NER.py": "ClinicalNER",
# "tf_idf_phonetic.py": "TfidfPhoneticMatcher",
# "symspell_matcher.py": "SymSpellMatcher",
# "rapid_fuzz_matcher.py": "RapidFuzzMatcher",
# }
# HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN") # must be set in Space secrets
# def _dynamic_import(module_path: str, class_name: str):
# spec = importlib.util.spec_from_file_location(class_name, module_path)
# module = importlib.util.module_from_spec(spec)
# spec.loader.exec_module(module) # type: ignore
# return getattr(module, class_name)
# priv_classes: Dict[str, Any] = {}
# drug_xlsx_path: Optional[str] = None
# BACKEND_INIT_ERROR: Optional[str] = None
# print("[Private] HF_TOKEN present?:", HF_TOKEN is not None)
# if HF_TOKEN is None:
# BACKEND_INIT_ERROR = "HUGGINGFACE_TOKEN env var is not set in this Space."
# print("[Private] WARNING:", BACKEND_INIT_ERROR)
# else:
# print(f"[Private] Using repo: {REPO_ID}")
# # 1) Load python modules (best-effort)
# for fname, cls_name in PY_MODULES.items():
# try:
# print(f"[Private] Downloading module file: {fname}")
# path = hf_hub_download(
# repo_id=REPO_ID,
# filename=fname,
# token=HF_TOKEN,
# repo_type="model",
# )
# priv_classes[cls_name] = _dynamic_import(path, cls_name)
# print(f"[Private] Loaded class {cls_name} from {fname}")
# except Exception as e:
# msg = f"Failed to load {fname}: {e}"
# print("[Private]", msg)
# BACKEND_INIT_ERROR = (BACKEND_INIT_ERROR or "") + f" | {msg}"
# # 2) Load Excel dictionary
# try:
# print("[Private] Downloading Excel file: Medibot_Drugs_Cleaned_Updated.xlsx")
# drug_xlsx_path = hf_hub_download(
# repo_id=REPO_ID,
# filename="Medibot_Drugs_Cleaned_Updated.xlsx",
# token=HF_TOKEN,
# repo_type="model",
# )
# print(f"[Private] Downloaded Excel at: {drug_xlsx_path}")
# df_debug = pd.read_excel(drug_xlsx_path, nrows=3)
# print(
# f"[Private] Excel loaded successfully. "
# f"Shape={df_debug.shape}, cols={list(df_debug.columns)}"
# )
# except Exception as e:
# msg = f"ERROR loading Excel: {e}"
# print("[Private]", msg)
# BACKEND_INIT_ERROR = (BACKEND_INIT_ERROR or "") + f" | {msg}"
# drug_xlsx_path = None
# # ============================================================
# # THEME
# # ============================================================
# colors.steel_blue = colors.Color(
# name="steel_blue",
# c50="#EBF3F8",
# c100="#D3E5F0",
# c200="#A8CCE1",
# c300="#7DB3D2",
# c400="#529AC3",
# c500="#4682B4",
# c600="#3E72A0",
# c700="#36638C",
# c800="#2E5378",
# c900="#264364",
# c950="#1E3450",
# )
# class SteelBlueTheme(Soft):
# def __init__(
# self,
# *,
# primary_hue: colors.Color | str = colors.gray,
# secondary_hue: colors.Color | str = colors.steel_blue,
# neutral_hue: colors.Color | str = colors.slate,
# text_size: sizes.Size | str = sizes.text_lg,
# font: fonts.Font | str | Iterable[fonts.Font | str] = (
# fonts.GoogleFont("Outfit"),
# "Arial",
# "sans-serif",
# ),
# font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
# fonts.GoogleFont("IBM Plex Mono"),
# "ui-monospace",
# "monospace",
# ),
# ):
# super().__init__(
# primary_hue=primary_hue,
# secondary_hue=secondary_hue,
# neutral_hue=neutral_hue,
# text_size=text_size,
# font=font,
# font_mono=font_mono,
# )
# super().set(
# background_fill_primary="*primary_50",
# background_fill_primary_dark="*primary_900",
# body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
# body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
# button_primary_text_color="white",
# button_primary_text_color_hover="white",
# button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
# button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
# button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)",
# button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_500)",
# button_secondary_text_color="black",
# button_secondary_text_color_hover="white",
# button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
# button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
# button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
# button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
# slider_color="*secondary_500",
# slider_color_dark="*secondary_600",
# block_title_text_weight="600",
# block_border_width="3px",
# block_shadow="*shadow_drop_lg",
# button_primary_shadow="*shadow_drop_lg",
# button_large_padding="11px",
# color_accent_soft="*primary_100",
# block_label_background_fill="*primary_200",
# )
# steel_blue_theme = SteelBlueTheme()
# css = """
# #main-title h1 { font-size: 2.3em !important; }
# #output-title h2 { font-size: 2.1em !important; }
# """
# # ============================================================
# # RUNTIME / DEVICE
# # ============================================================
# os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
# print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES"))
# print("torch.__version__ =", torch.__version__)
# print("torch.version.cuda =", torch.version.cuda)
# print("cuda available =", torch.cuda.is_available())
# print("cuda device count =", torch.cuda.device_count())
# if torch.cuda.is_available():
# print("using device =", torch.cuda.get_device_name(0))
# use_cuda = torch.cuda.is_available()
# device = torch.device("cuda:0" if use_cuda else "cpu")
# if use_cuda:
# torch.backends.cudnn.benchmark = True
# DTYPE_FP16 = torch.float16 if use_cuda else torch.float32
# DTYPE_BF16 = torch.bfloat16 if use_cuda else torch.float32
# # ============================================================
# # OCR MODELS: Chandra-OCR + Dots.OCR
# # ============================================================
# MODEL_ID_V = "datalab-to/chandra"
# processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
# model_v = Qwen3VLForConditionalGeneration.from_pretrained(
# MODEL_ID_V, trust_remote_code=True, torch_dtype=DTYPE_FP16
# ).to(device).eval()
# MODEL_PATH_D = "prithivMLmods/Dots.OCR-Latest-BF16"
# processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
# attn_impl = "sdpa"
# try:
# import flash_attn # noqa: F401
# if use_cuda:
# attn_impl = "flash_attention_2"
# except Exception:
# attn_impl = "sdpa"
# model_d = AutoModelForCausalLM.from_pretrained(
# MODEL_PATH_D,
# attn_implementation=attn_impl,
# torch_dtype=DTYPE_BF16,
# device_map="auto" if use_cuda else None,
# trust_remote_code=True,
# ).eval()
# if not use_cuda:
# model_d.to(device)
# # ============================================================
# # GENERATION (no raw output UI; one markdown return)
# # ============================================================
# MAX_MAX_NEW_TOKENS = 4096
# DEFAULT_MAX_NEW_TOKENS = 2048
# @spaces.GPU
# def generate_image(
# model_name: str,
# text: str,
# image: Image.Image,
# max_new_tokens: int,
# temperature: float,
# top_p: float,
# top_k: int,
# repetition_penalty: float,
# spell_algo: str,
# ) -> str:
# """
# Returns a single Markdown string:
# - Medications (extracted)
# - Spell-check suggestions
# No raw OCR text is returned to the UI.
# """
# try:
# if image is None:
# return "Please upload an image."
# # Choose processor/model
# if model_name == "Chandra-OCR":
# processor, model = processor_v, model_v
# elif model_name == "Dots.OCR":
# processor, model = processor_d, model_d
# else:
# return "Invalid model selected."
# # Build prompt
# messages = [
# {
# "role": "user",
# "content": [
# {"type": "image"},
# {"type": "text", "text": text},
# ],
# }
# ]
# prompt_full = processor.apply_chat_template(
# messages, tokenize=False, add_generation_prompt=True
# )
# # Preprocess
# inputs = processor(
# text=[prompt_full], images=[image], return_tensors="pt", padding=True
# )
# inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}
# # Generate (no streaming)
# gen_kwargs = dict(
# **inputs,
# max_new_tokens=max_new_tokens,
# do_sample=True,
# temperature=temperature,
# top_p=top_p,
# top_k=top_k,
# repetition_penalty=repetition_penalty,
# )
# outputs = model.generate(**gen_kwargs)
# tokenizer = getattr(processor, "tokenizer", None) or processor
# generated = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
# final_ocr_text = generated.strip()
# # --------------------------------------------------------
# # 2) Medications extraction
# # --------------------------------------------------------
# meds: List[str] = []
# if model_name == "Dots.OCR":
# try:
# if "ClinicalNER" in priv_classes and HF_TOKEN is not None:
# ClinicalNER = priv_classes["ClinicalNER"]
# ner = ClinicalNER(token=HF_TOKEN)
# ner_output = ner(final_ocr_text) or []
# meds = [
# m.strip()
# for m in ner_output
# if isinstance(m, str) and m.strip()
# ]
# print("[NER] (Dots.OCR) ClinicalNER meds:", meds)
# else:
# print("[NER] ClinicalNER unavailable or missing HF token; skipping.")
# except Exception as e:
# print(f"[NER] Error running ClinicalNER: {e}")
# if not meds:
# meds = [
# line.strip()
# for line in final_ocr_text.splitlines()
# if line.strip()
# ]
# print("[NER] (Dots.OCR) Fallback to lines, count:", len(meds))
# else: # Chandra-OCR
# meds = [
# line.strip()
# for line in final_ocr_text.splitlines()
# if line.strip()
# ]
# print("[NER] (Chandra-OCR) Line-based meds only, count:", len(meds))
# print("[DEBUG] meds count:", len(meds))
# print("[DEBUG] drug_xlsx_path in generate_image:", drug_xlsx_path)
# # --------------------------------------------------------
# # 3) Markdown: Medications only (no Raw OCR section)
# # --------------------------------------------------------
# md = "### Medications (extracted)\n"
# if meds:
# for m in meds:
# md += f"- {m}\n"
# else:
# md += "- None detected\n"
# # --------------------------------------------------------
# # 4) Spell-check (med list) with CER
# # --------------------------------------------------------
# spell_section = "\n---\n### Spell-check suggestions (" + spell_algo + ")\n"
# corr: Dict[str, List] = {}
# if BACKEND_INIT_ERROR:
# spell_section += f"- [DEBUG] Backend init error: {BACKEND_INIT_ERROR}\n"
# try:
# if meds and drug_xlsx_path:
# try:
# df_dbg = pd.read_excel(drug_xlsx_path)
# print(
# f"[Spell DEBUG] Excel read OK: path={drug_xlsx_path}, "
# f"shape={df_dbg.shape}, cols={list(df_dbg.columns)}"
# )
# spell_section += (
# f"- [DEBUG] Excel read OK; shape={df_dbg.shape}, "
# f"cols={list(df_dbg.columns)}\n"
# )
# except Exception as e:
# print(f"[Spell DEBUG] ERROR reading Excel in generate_image: {e}")
# spell_section += f"- [DEBUG] Excel read error: {e}\n"
# if (
# spell_algo == "TF-IDF + Phonetic"
# and "TfidfPhoneticMatcher" in priv_classes
# ):
# print("[Spell DEBUG] Using TfidfPhoneticMatcher")
# Cls = priv_classes["TfidfPhoneticMatcher"]
# checker = Cls(
# xlsx_path=drug_xlsx_path,
# column="Combined_Drugs",
# ngram_size=3,
# phonetic_weight=0.4,
# )
# corr = checker.match_list(meds, top_k=5, tfidf_threshold=0.15)
# elif spell_algo == "SymSpell" and "SymSpellMatcher" in priv_classes:
# print("[Spell DEBUG] Using SymSpellMatcher")
# Cls = priv_classes["SymSpellMatcher"]
# checker = Cls(
# xlsx_path=drug_xlsx_path,
# column="Combined_Drugs",
# max_edit=2,
# prefix_len=7,
# )
# corr = checker.match_list(meds, top_k=5, min_score=0.4)
# elif spell_algo == "RapidFuzz" and "RapidFuzzMatcher" in priv_classes:
# print("[Spell DEBUG] Using RapidFuzzMatcher")
# Cls = priv_classes["RapidFuzzMatcher"]
# checker = Cls(xlsx_path=drug_xlsx_path, column="Combined_Drugs")
# corr = checker.match_list(meds, top_k=5, threshold=70.0)
# else:
# spell_section += (
# "- Spell-check backend unavailable "
# "(no matcher class for selected algorithm).\n"
# )
# else:
# if not meds:
# spell_section += "- No medications extracted (empty med list).\n"
# if not drug_xlsx_path:
# spell_section += (
# "- Drug Excel dictionary path missing "
# "(drug_xlsx_path is None).\n"
# )
# except Exception as e:
# print(f"[Spell DEBUG] Spell-check error: {e}")
# spell_section += f"- Spell-check error: {e}\n"
# if corr:
# for raw in meds:
# suggestions = corr.get(raw, [])
# if suggestions:
# spell_section += f"- **{raw}**\n"
# for cand, score in suggestions:
# cer = character_error_rate(cand, raw)
# spell_section += (
# f" - {cand} (score={score:.3f}, CER={cer:.3f}%)\n"
# )
# else:
# spell_section += f"- **{raw}**\n - (no suggestions)\n"
# final_md = md + spell_section
# return final_md
# except Exception as e:
# # Catch-all so the GPU worker does not crash
# print(f"[ERROR] generate_image crashed: {e}")
# import traceback
# traceback.print_exc()
# return f"Error while processing: {e}"
# # ============================================================
# # UI
# # ============================================================
# image_examples = [
# ["examples/test_1.jpeg"],
# ["examples/test_4.jpeg"],
# ["examples/test_5.jpeg"],
# ]
# with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
# gr.Markdown(
# "# **Handwritten Doctor's Prescription Reading**", elem_id="main-title"
# )
# with gr.Row():
# with gr.Column(scale=2):
# image_upload = gr.Image(
# type="pil", label="Upload Image", height=290
# )
# image_submit = gr.Button("Submit", variant="primary")
# gr.Examples(
# examples=image_examples,
# inputs=[image_upload],
# label="Example Images",
# )
# spell_choice = gr.Radio(
# choices=["TF-IDF + Phonetic", "SymSpell", "RapidFuzz"],
# label="Select Spell-check Approach",
# value="TF-IDF + Phonetic",
# )
# with gr.Accordion("Advanced options", open=False):
# max_new_tokens = gr.Slider(
# label="Max new tokens",
# minimum=1,
# maximum=MAX_MAX_NEW_TOKENS,
# step=1,
# value=DEFAULT_MAX_NEW_TOKENS,
# )
# temperature = gr.Slider(
# label="Temperature",
# minimum=0.1,
# maximum=4.0,
# step=0.1,
# value=0.7,
# )
# top_p = gr.Slider(
# label="Top-p (nucleus sampling)",
# minimum=0.05,
# maximum=1.0,
# step=0.05,
# value=0.9,
# )
# top_k = gr.Slider(
# label="Top-k",
# minimum=1,
# maximum=1000,
# step=1,
# value=50,
# )
# repetition_penalty = gr.Slider(
# label="Repetition penalty",
# minimum=1.0,
# maximum=2.0,
# step=0.05,
# value=1.1,
# )
# with gr.Column(scale=3):
# gr.Markdown("## Output", elem_id="output-title")
# with gr.Accordion("(Result.md)", open=True):
# markdown_output = gr.Markdown(label="(Result.Md)")
# model_choice = gr.Radio(
# choices=["Chandra-OCR", "Dots.OCR"],
# label="Select OCR Model",
# value="Chandra-OCR",
# )
# # Hard-coded query text (passed into the 'text' parameter)
# query_state = gr.State(
# "Extract medicine or drugs names along with dosage amount or quantity"
# )
# image_submit.click(
# fn=generate_image,
# inputs=[
# model_choice,
# query_state,
# image_upload,
# max_new_tokens,
# temperature,
# top_p,
# top_k,
# repetition_penalty,
# spell_choice,
# ],
# outputs=[markdown_output],
# )
# if __name__ == "__main__":
# demo.queue(max_size=50).launch(
# mcp_server=True, ssr_mode=False, show_error=True
# )
###################################### version 4 #########################################
# import os
# import time
# from threading import Thread
# from typing import Iterable, Dict, Any, Optional, List
# import gradio as gr
# import spaces
# import torch
# from PIL import Image
# import pandas as pd # Excel read + debug
# from transformers import (
# Qwen3VLForConditionalGeneration,
# AutoModelForCausalLM,
# AutoProcessor,
# TextIteratorStreamer,
# )
# from gradio.themes import Soft
# from gradio.themes.utils import colors, fonts, sizes
# # ============================================================
# # Character Error Rate (CER)
# # ============================================================
# def levenshtein(a: str, b: str) -> int:
# """Levenshtein distance to calculate CER."""
# a, b = a.lower(), b.lower()
# if a == b:
# return 0
# if not a:
# return len(b)
# if not b:
# return len(a)
# dp = list(range(len(b) + 1))
# for i, ca in enumerate(a, 1):
# prev = dp[0]
# dp[0] = i
# for j, cb in enumerate(b, 1):
# cur = dp[j]
# cost = 0 if ca == cb else 1
# dp[j] = min(dp[j] + 1, dp[j - 1] + 1, prev + cost)
# prev = cur
# return dp[-1]
# def character_error_rate(pred: str, target: str) -> float:
# """Calculate the Character Error Rate (CER) in percent."""
# target = target or ""
# distance = levenshtein(pred, target)
# return (distance / len(target)) * 100 if len(target) > 0 else 0.0
# # ============================================================
# # Private repo: dynamic import + Excel download
# # ============================================================
# import importlib.util
# from huggingface_hub import hf_hub_download
# REPO_ID = "IFMedTech/Medibot_OCR_model" # private backend repo
# # Filenames in the repo → class names they define
# PY_MODULES: Dict[str, str] = {
# "clinical_NER.py": "ClinicalNER",
# "tf_idf_phonetic.py": "TfidfPhoneticMatcher",
# "symspell_matcher.py": "SymSpellMatcher",
# "rapid_fuzz_matcher.py": "RapidFuzzMatcher",
# }
# HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN") # must be set in Space secrets
# def _dynamic_import(module_path: str, class_name: str):
# spec = importlib.util.spec_from_file_location(class_name, module_path)
# module = importlib.util.module_from_spec(spec)
# spec.loader.exec_module(module) # type: ignore
# return getattr(module, class_name)
# priv_classes: Dict[str, Any] = {}
# drug_xlsx_path: Optional[str] = None
# BACKEND_INIT_ERROR: Optional[str] = None
# print("[Private] HF_TOKEN present?:", HF_TOKEN is not None)
# if HF_TOKEN is None:
# BACKEND_INIT_ERROR = "HUGGINGFACE_TOKEN env var is not set in this Space."
# print("[Private] WARNING:", BACKEND_INIT_ERROR)
# else:
# print(f"[Private] Using repo: {REPO_ID}")
# # 1) Load python modules (best-effort: failure of one file will not block others)
# for fname, cls_name in PY_MODULES.items():
# try:
# print(f"[Private] Downloading module file: {fname}")
# path = hf_hub_download(
# repo_id=REPO_ID,
# filename=fname,
# token=HF_TOKEN,
# repo_type="model",
# )
# priv_classes[cls_name] = _dynamic_import(path, cls_name)
# print(f"[Private] Loaded class {cls_name} from {fname}")
# except Exception as e:
# msg = f"Failed to load {fname}: {e}"
# print("[Private]", msg)
# BACKEND_INIT_ERROR = (BACKEND_INIT_ERROR or "") + f" | {msg}"
# # 2) Load Excel dictionary
# try:
# print("[Private] Downloading Excel file: Medibot_Drugs_Cleaned_Updated.xlsx")
# drug_xlsx_path = hf_hub_download(
# repo_id=REPO_ID,
# filename="Medibot_Drugs_Cleaned_Updated.xlsx",
# token=HF_TOKEN,
# repo_type="model",
# )
# print(f"[Private] Downloaded Excel at: {drug_xlsx_path}")
# # Debug: verify read
# df_debug = pd.read_excel(drug_xlsx_path, nrows=3)
# print(
# f"[Private] Excel loaded successfully. "
# f"Shape={df_debug.shape}, cols={list(df_debug.columns)}"
# )
# except Exception as e:
# msg = f"ERROR loading Excel: {e}"
# print("[Private]", msg)
# BACKEND_INIT_ERROR = (BACKEND_INIT_ERROR or "") + f" | {msg}"
# drug_xlsx_path = None
# # ============================================================
# # THEME
# # ============================================================
# colors.steel_blue = colors.Color(
# name="steel_blue",
# c50="#EBF3F8",
# c100="#D3E5F0",
# c200="#A8CCE1",
# c300="#7DB3D2",
# c400="#529AC3",
# c500="#4682B4",
# c600="#3E72A0",
# c700="#36638C",
# c800="#2E5378",
# c900="#264364",
# c950="#1E3450",
# )
# class SteelBlueTheme(Soft):
# def __init__(
# self,
# *,
# primary_hue: colors.Color | str = colors.gray,
# secondary_hue: colors.Color | str = colors.steel_blue,
# neutral_hue: colors.Color | str = colors.slate,
# text_size: sizes.Size | str = sizes.text_lg,
# font: fonts.Font | str | Iterable[fonts.Font | str] = (
# fonts.GoogleFont("Outfit"),
# "Arial",
# "sans-serif",
# ),
# font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
# fonts.GoogleFont("IBM Plex Mono"),
# "ui-monospace",
# "monospace",
# ),
# ):
# super().__init__(
# primary_hue=primary_hue,
# secondary_hue=secondary_hue,
# neutral_hue=neutral_hue,
# text_size=text_size,
# font=font,
# font_mono=font_mono,
# )
# super().set(
# background_fill_primary="*primary_50",
# background_fill_primary_dark="*primary_900",
# body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
# body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
# button_primary_text_color="white",
# button_primary_text_color_hover="white",
# button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
# button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
# button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)",
# button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_500)",
# button_secondary_text_color="black",
# button_secondary_text_color_hover="white",
# button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
# button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
# button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
# button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
# slider_color="*secondary_500",
# slider_color_dark="*secondary_600",
# block_title_text_weight="600",
# block_border_width="3px",
# block_shadow="*shadow_drop_lg",
# button_primary_shadow="*shadow_drop_lg",
# button_large_padding="11px",
# color_accent_soft="*primary_100",
# block_label_background_fill="*primary_200",
# )
# steel_blue_theme = SteelBlueTheme()
# css = """
# #main-title h1 { font-size: 2.3em !important; }
# #output-title h2 { font-size: 2.1em !important; }
# """
# # ============================================================
# # RUNTIME / DEVICE
# # ============================================================
# os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
# print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES"))
# print("torch.__version__ =", torch.__version__)
# print("torch.version.cuda =", torch.version.cuda)
# print("cuda available =", torch.cuda.is_available())
# print("cuda device count =", torch.cuda.device_count())
# if torch.cuda.is_available():
# print("using device =", torch.cuda.get_device_name(0))
# use_cuda = torch.cuda.is_available()
# device = torch.device("cuda:0" if use_cuda else "cpu")
# if use_cuda:
# torch.backends.cudnn.benchmark = True
# DTYPE_FP16 = torch.float16 if use_cuda else torch.float32
# DTYPE_BF16 = torch.bfloat16 if use_cuda else torch.float32
# # ============================================================
# # OCR MODELS: Chandra-OCR + Dots.OCR
# # ============================================================
# # 1) Chandra-OCR (Qwen3VL)
# MODEL_ID_V = "datalab-to/chandra"
# processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
# model_v = Qwen3VLForConditionalGeneration.from_pretrained(
# MODEL_ID_V, trust_remote_code=True, torch_dtype=DTYPE_FP16
# ).to(device).eval()
# # 2) Dots.OCR (flash_attn2 if available, else SDPA)
# MODEL_PATH_D = "prithivMLmods/Dots.OCR-Latest-BF16"
# processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
# attn_impl = "sdpa"
# try:
# import flash_attn # noqa: F401
# if use_cuda:
# attn_impl = "flash_attention_2"
# except Exception:
# attn_impl = "sdpa"
# model_d = AutoModelForCausalLM.from_pretrained(
# MODEL_PATH_D,
# attn_implementation=attn_impl,
# torch_dtype=DTYPE_BF16,
# device_map="auto" if use_cuda else None,
# trust_remote_code=True,
# ).eval()
# if not use_cuda:
# model_d.to(device)
# # ============================================================
# # GENERATION (OCR → Med extraction → Spell-check + CER)
# # ClinicalNER is used ONLY for Dots.OCR.
# # ============================================================
# MAX_MAX_NEW_TOKENS = 4096
# DEFAULT_MAX_NEW_TOKENS = 2048
# @spaces.GPU # you can add duration=... if you hit timeouts
# def generate_image(
# model_name: str,
# text: str,
# image: Image.Image,
# max_new_tokens: int,
# temperature: float,
# top_p: float,
# top_k: int,
# repetition_penalty: float,
# spell_algo: str,
# ):
# """
# 1) Stream OCR tokens to Raw output.
# 2) For Dots.OCR: run ClinicalNER → meds list (with fallback to line-based).
# For Chandra-OCR: DO NOT call ClinicalNER; meds from OCR lines only.
# 3) Apply selected spell-check algorithm on meds using Excel dictionary.
# 4) Compute CER for each suggestion and display in markdown.
# """
# if image is None:
# yield "Please upload an image.", "Please upload an image."
# return
# # Choose processor/model
# if model_name == "Chandra-OCR":
# processor, model = processor_v, model_v
# elif model_name == "Dots.OCR":
# processor, model = processor_d, model_d
# else:
# yield "Invalid model selected.", "Invalid model selected."
# return
# # Prompt (text is provided via gr.State)
# messages = [
# {
# "role": "user",
# "content": [
# {"type": "image"},
# {"type": "text", "text": text},
# ],
# }
# ]
# prompt_full = processor.apply_chat_template(
# messages, tokenize=False, add_generation_prompt=True
# )
# # Preprocess
# inputs = processor(
# text=[prompt_full], images=[image], return_tensors="pt", padding=True
# )
# inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}
# # Streamer
# tokenizer = getattr(processor, "tokenizer", None) or processor
# streamer = TextIteratorStreamer(
# tokenizer, skip_prompt=True, skip_special_tokens=True
# )
# gen_kwargs = dict(
# **inputs,
# streamer=streamer,
# max_new_tokens=max_new_tokens,
# do_sample=True,
# temperature=temperature,
# top_p=top_p,
# top_k=top_k,
# repetition_penalty=repetition_penalty,
# )
# # Start generation in background thread
# thread = Thread(target=model.generate, kwargs=gen_kwargs)
# thread.start()
# # 1) Live OCR streaming : show raw text while generating
# buffer = ""
# for new_text in streamer:
# buffer += new_text.replace("<|im_end|>", "")
# time.sleep(0.01)
# md_stream = "Processing..."
# yield buffer,md_stream #buffer # two outputs: raw + md (same during stream)
# final_ocr_text = buffer.strip()
# # --------------------------------------------------------
# # 2) Medications extraction
# # --------------------------------------------------------
# meds: List[str] = []
# if model_name == "Dots.OCR":
# # ClinicalNER ONLY for Dots.OCR
# try:
# if "ClinicalNER" in priv_classes and HF_TOKEN is not None:
# ClinicalNER = priv_classes["ClinicalNER"]
# ner = ClinicalNER(token=HF_TOKEN)
# ner_output = ner(final_ocr_text) or []
# meds = [
# m.strip()
# for m in ner_output
# if isinstance(m, str) and m.strip()
# ]
# print("[NER] (Dots.OCR) ClinicalNER meds:", meds)
# else:
# print("[NER] ClinicalNER unavailable or missing HF token; skipping.")
# except Exception as e:
# print(f"[NER] Error running ClinicalNER: {e}")
# # Fallback if ClinicalNER returns nothing
# if not meds:
# meds = [
# line.strip()
# for line in final_ocr_text.splitlines()
# if line.strip()
# ]
# print("[NER] (Dots.OCR) Fallback to lines, count:", len(meds))
# elif model_name == "Chandra-OCR":
# # NO ClinicalNER for Chandra; just use text lines
# meds = [
# line.strip()
# for line in final_ocr_text.splitlines()
# if line.strip()
# ]
# print("[NER] (Chandra-OCR) Line-based meds only, count:", len(meds))
# print("[DEBUG] meds count:", len(meds))
# print("[DEBUG] drug_xlsx_path in generate_image:", drug_xlsx_path)
# # --------------------------------------------------------
# # 3) Build Markdown base: OCR text + med list
# # --------------------------------------------------------
# # md = "### Raw OCR Output\n"
# # md += "```\n" + (final_ocr_text or "(empty)") + "\n```\n"
# md=""
# md += "\n---\n### Medications (extracted)\n"
# if meds:
# for m in meds:
# md += f"- {m}\n"
# else:
# md += "- None detected\n"
# # --------------------------------------------------------
# # 4) Spell-check (med list) with CER
# # --------------------------------------------------------
# spell_section = "\n---\n### Spell-check suggestions (" + spell_algo + ")\n"
# corr: Dict[str, List] = {}
# if BACKEND_INIT_ERROR:
# spell_section += f"- [DEBUG] Backend init error: {BACKEND_INIT_ERROR}\n"
# try:
# if meds and drug_xlsx_path:
# # Optional Excel debug read
# try:
# df_dbg = pd.read_excel(drug_xlsx_path)
# print(
# f"[Spell DEBUG] Excel read OK: path={drug_xlsx_path}, "
# f"shape={df_dbg.shape}, cols={list(df_dbg.columns)}"
# )
# spell_section += (
# f"- [DEBUG] Excel read OK; shape={df_dbg.shape}, "
# f"cols={list(df_dbg.columns)}\n"
# )
# except Exception as e:
# print(f"[Spell DEBUG] ERROR reading Excel in generate_image: {e}")
# spell_section += f"- [DEBUG] Excel read error: {e}\n"
# # Pick matcher based on spell_algo
# if (
# spell_algo == "TF-IDF + Phonetic"
# and "TfidfPhoneticMatcher" in priv_classes
# ):
# print("[Spell DEBUG] Using TfidfPhoneticMatcher")
# Cls = priv_classes["TfidfPhoneticMatcher"]
# checker = Cls(
# xlsx_path=drug_xlsx_path,
# column="Combined_Drugs",
# ngram_size=3,
# phonetic_weight=0.4,
# )
# corr = checker.match_list(meds, top_k=5, tfidf_threshold=0.15)
# elif spell_algo == "SymSpell" and "SymSpellMatcher" in priv_classes:
# print("[Spell DEBUG] Using SymSpellMatcher")
# Cls = priv_classes["SymSpellMatcher"]
# checker = Cls(
# xlsx_path=drug_xlsx_path,
# column="Combined_Drugs",
# max_edit=2,
# prefix_len=7,
# )
# corr = checker.match_list(meds, top_k=5, min_score=0.4)
# elif spell_algo == "RapidFuzz" and "RapidFuzzMatcher" in priv_classes:
# print("[Spell DEBUG] Using RapidFuzzMatcher")
# Cls = priv_classes["RapidFuzzMatcher"]
# checker = Cls(xlsx_path=drug_xlsx_path, column="Combined_Drugs")
# corr = checker.match_list(meds, top_k=5, threshold=70.0)
# else:
# spell_section += (
# "- Spell-check backend unavailable "
# "(no matcher class for selected algorithm).\n"
# )
# else:
# if not meds:
# spell_section += "- No medications extracted (empty med list).\n"
# if not drug_xlsx_path:
# spell_section += (
# "- Drug Excel dictionary path missing "
# "(drug_xlsx_path is None).\n"
# )
# except Exception as e:
# print(f"[Spell DEBUG] Spell-check error: {e}")
# spell_section += f"- Spell-check error: {e}\n"
# # Format suggestions (top-5 per med, with scores + CER)
# if corr:
# for raw in meds:
# suggestions = corr.get(raw, [])
# if suggestions:
# spell_section += f"- **{raw}**\n"
# for cand, score in suggestions:
# cer = character_error_rate(cand, raw)
# spell_section += (
# f" - {cand} (score={score:.3f}, CER={cer:.3f}%)\n"
# )
# else:
# spell_section += f"- **{raw}**\n - (no suggestions)\n"
# final_md = md + spell_section
# # Final yield: raw OCR text + full markdown
# # yield final_ocr_text, final_md
# yield final_md
# # ============================================================
# # UI
# # ============================================================
# image_examples = [
# ["examples/test_1.jpeg"],
# ["examples/test_4.jpeg"],
# ["examples/test_5.jpeg"],
# ]
# with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
# gr.Markdown(
# "# **Handwritten Doctor's Prescription Reading**", elem_id="main-title"
# )
# with gr.Row():
# with gr.Column(scale=2):
# image_upload = gr.Image(
# type="pil", label="Upload Image", height=290
# )
# image_submit = gr.Button("Submit", variant="primary")
# gr.Examples(
# examples=image_examples,
# inputs=[image_upload],
# label="Example Images",
# )
# spell_choice = gr.Radio(
# choices=["TF-IDF + Phonetic", "SymSpell", "RapidFuzz"],
# label="Select Spell-check Approach",
# value="TF-IDF + Phonetic",
# )
# with gr.Accordion("Advanced options", open=False):
# max_new_tokens = gr.Slider(
# label="Max new tokens",
# minimum=1,
# maximum=MAX_MAX_NEW_TOKENS,
# step=1,
# value=DEFAULT_MAX_NEW_TOKENS,
# )
# temperature = gr.Slider(
# label="Temperature",
# minimum=0.1,
# maximum=4.0,
# step=0.1,
# value=0.7,
# )
# top_p = gr.Slider(
# label="Top-p (nucleus sampling)",
# minimum=0.05,
# maximum=1.0,
# step=0.05,
# value=0.9,
# )
# top_k = gr.Slider(
# label="Top-k",
# minimum=1,
# maximum=1000,
# step=1,
# value=50,
# )
# repetition_penalty = gr.Slider(
# label="Repetition penalty",
# minimum=1.0,
# maximum=2.0,
# step=0.05,
# value=1.1,
# )
# with gr.Column(scale=3):
# gr.Markdown("## Output", elem_id="output-title")
# # output = gr.Textbox(
# # label="Raw Output Stream",
# # interactive=False,
# # lines=11,
# # show_copy_button=True,
# # )
# with gr.Accordion("(Result.md)", open=False):
# markdown_output = gr.Markdown(label="(Result.Md)")
# model_choice = gr.Radio(
# choices=["Chandra-OCR", "Dots.OCR"],
# label="Select OCR Model",
# value="Chandra-OCR",
# )
# # Hard-coded query text (passed into the 'text' parameter)
# query_state = gr.State(
# "Extract medicine or drugs names along with dosage amount or quantity"
# )
# image_submit.click(
# fn=generate_image,
# inputs=[
# model_choice,
# query_state,
# image_upload,
# max_new_tokens,
# temperature,
# top_p,
# top_k,
# repetition_penalty,
# spell_choice,
# ],
# outputs=[markdown_output],
# )
# if __name__ == "__main__":
# demo.queue(max_size=50).launch(
# mcp_server=True, ssr_mode=False, show_error=True
# )
################################### version 3 (active; versions 4 and 5 above are commented out) ########################################
import os
import time
from threading import Thread
from typing import Iterable, Dict, Any, Optional, List
import gradio as gr
import spaces
import torch
from PIL import Image
import pandas as pd # for reading Excel and debugging
from transformers import (
Qwen3VLForConditionalGeneration,
AutoModelForCausalLM,
AutoProcessor,
TextIteratorStreamer,
)
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
# -----------------------------
# Character Error Rate (CER)
# -----------------------------
def levenshtein(a: str, b: str) -> int:
    """Return the case-insensitive Levenshtein edit distance between *a* and *b*.

    Insertions, deletions and substitutions each cost 1. Both strings are
    lower-cased first, so ``levenshtein("Aspirin", "aspirin") == 0``.
    Uses a single rolling DP row, i.e. O(len(b)) extra memory.
    """
    s, t = a.lower(), b.lower()
    if s == t:
        return 0
    if not s or not t:
        # Exactly one side is empty here (equal strings returned above).
        return len(s) or len(t)
    row = list(range(len(t) + 1))
    for i, ch_s in enumerate(s, start=1):
        # `diag` holds the previous row's value at column j-1.
        diag, row[0] = row[0], i
        for j, ch_t in enumerate(t, start=1):
            above = row[j]
            row[j] = min(above + 1, row[j - 1] + 1, diag + (ch_s != ch_t))
            diag = above
    return row[-1]


def character_error_rate(pred: str, target: str) -> float:
    """Return the Character Error Rate (CER) of *pred* against *target*, in percent.

    CER = levenshtein(pred, target) / len(target) * 100 (comparison is
    case-insensitive, see ``levenshtein``). The value can exceed 100 when
    *pred* is much longer than *target*.

    ``None`` for either argument is treated as an empty string. An empty
    *target* yields 0.0 by convention, even when *pred* is non-empty.
    """
    pred = pred or ""
    target = target or ""
    if not target:
        return 0.0
    return (levenshtein(pred, target) / len(target)) * 100
# -----------------------------
# Private repo: dynamic import
# -----------------------------
import importlib.util
from huggingface_hub import hf_hub_download
# Private Hugging Face repo hosting the NER / spell-check modules and the drug
# dictionary spreadsheet.
REPO_ID = "IFMedTech/Medibot_OCR_model"

# Module filename inside REPO_ID -> name of the class loaded from that file.
# ClinicalNER is only applied to Dots.OCR output. The drug dictionary
# ('Medibot_Drugs_Cleaned_Updated.xlsx') is data, not a module, so it is
# downloaded separately and does not appear here.
PY_MODULES: Dict[str, str] = {
    "ner.py": "ClinicalNER",
    "tfidf_phonetic.py": "TfidfPhoneticMatcher",
    "symspell_matcher.py": "SymSpellMatcher",
    "rapidfuzz_matcher.py": "RapidFuzzMatcher",
}

# Access token for the private repo; None when the variable is not set.
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
def _dynamic_import(module_path: str, class_name: str):
    """Execute the Python source file at *module_path* and return its *class_name* attribute.

    The module is created under the name *class_name* and registered in
    ``sys.modules`` before its body runs, following the stdlib "import a
    source file directly" recipe. Without that registration, module code that
    looks itself up through ``sys.modules[cls.__module__]`` (e.g. ``@dataclass``
    with string annotations, pickling) fails during import.

    Raises:
        ImportError: no loader can be determined for *module_path*.
        AttributeError: the module does not define *class_name*.
        Any exception raised while executing the module body (the partial
        module is removed from ``sys.modules`` again in that case).
    """
    import sys  # local: only this helper needs it

    spec = importlib.util.spec_from_file_location(class_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a Python module from {module_path!r}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[class_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Don't leave a half-initialised module behind.
        if sys.modules.get(class_name) is module:
            del sys.modules[class_name]
        raise
    return getattr(module, class_name)
# Load private classes and the Excel drug dictionary (once at import time).
# Each artifact is loaded independently: a failure to download/import one of
# them is logged and only disables that piece, instead of discarding every
# backend that did load successfully.
priv_classes: Dict[str, Any] = {}  # class name -> class object
drug_xlsx_path: Optional[str] = None  # local path of the downloaded .xlsx
if HF_TOKEN is None:
    print("[Private] WARNING: HUGGINGFACE_TOKEN not set; NER/Spell-check will be unavailable.")
else:
    for fname, cls in PY_MODULES.items():
        if not cls:
            continue
        try:
            path = hf_hub_download(repo_id=REPO_ID, filename=fname, token=HF_TOKEN)
            priv_classes[cls] = _dynamic_import(path, cls)
            print(f"[Private] Loaded class: {cls} from {fname}")
        except Exception as e:
            print(f"[Private] ERROR loading {cls} from {fname}: {e}")

    try:
        drug_xlsx_path = hf_hub_download(
            repo_id=REPO_ID,
            filename="Medibot_Drugs_Cleaned_Updated.xlsx",
            token=HF_TOKEN,
        )
        print(f"[Private] Downloaded Excel at: {drug_xlsx_path}")
    except Exception as e:
        print(f"[Private] ERROR downloading drug dictionary: {e}")
        drug_xlsx_path = None

    # DEBUG: read the Excel once and print its shape (sanity check only).
    if drug_xlsx_path:
        try:
            df_debug = pd.read_excel(drug_xlsx_path)
            print(f"[Private] Excel loaded successfully. Shape: {df_debug.shape}")
        except Exception as e:
            print(f"[Private] ERROR reading Excel for debug: {e}")
# ----------------------------
# THEME
# ----------------------------
# Register a custom "steel_blue" hue on gradio's `colors` module so it can be
# referenced like the built-in hues; SteelBlueTheme uses it as its default
# secondary_hue. Shades go from lightest (c50) to darkest (c950).
colors.steel_blue = colors.Color(
    name="steel_blue",
    c50="#EBF3F8",
    c100="#D3E5F0",
    c200="#A8CCE1",
    c300="#7DB3D2",
    c400="#529AC3",
    c500="#4682B4",
    c600="#3E72A0",
    c700="#36638C",
    c800="#2E5378",
    c900="#264364",
    c950="#1E3450",
)
class SteelBlueTheme(Soft):
    """Gradio ``Soft`` theme variant: gray primary, steel-blue secondary, slate neutral.

    The keyword-only constructor arguments are forwarded unchanged to ``Soft``;
    afterwards a fixed set of theme-variable overrides (gradients for the body
    and buttons, slider colours, borders, shadows) is applied via ``set``.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.steel_blue,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"),
            "Arial",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        base_options = {
            "primary_hue": primary_hue,
            "secondary_hue": secondary_hue,
            "neutral_hue": neutral_hue,
            "text_size": text_size,
            "font": font,
            "font_mono": font_mono,
        }
        super().__init__(**base_options)

        # Theme-variable overrides; "*name" values refer to other theme variables.
        overrides = {
            "background_fill_primary": "*primary_50",
            "background_fill_primary_dark": "*primary_900",
            "body_background_fill": "linear-gradient(135deg, *primary_200, *primary_100)",
            "body_background_fill_dark": "linear-gradient(135deg, *primary_900, *primary_800)",
            "button_primary_text_color": "white",
            "button_primary_text_color_hover": "white",
            "button_primary_background_fill": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_dark": "linear-gradient(90deg, *secondary_600, *secondary_800)",
            "button_primary_background_fill_hover_dark": "linear-gradient(90deg, *secondary_500, *secondary_500)",
            "button_secondary_text_color": "black",
            "button_secondary_text_color_hover": "white",
            "button_secondary_background_fill": "linear-gradient(90deg, *primary_300, *primary_300)",
            "button_secondary_background_fill_hover": "linear-gradient(90deg, *primary_400, *primary_400)",
            "button_secondary_background_fill_dark": "linear-gradient(90deg, *primary_500, *primary_600)",
            "button_secondary_background_fill_hover_dark": "linear-gradient(90deg, *primary_500, *primary_500)",
            "slider_color": "*secondary_500",
            "slider_color_dark": "*secondary_600",
            "block_title_text_weight": "600",
            "block_border_width": "3px",
            "block_shadow": "*shadow_drop_lg",
            "button_primary_shadow": "*shadow_drop_lg",
            "button_large_padding": "11px",
            "color_accent_soft": "*primary_100",
            "block_label_background_fill": "*primary_200",
        }
        super().set(**overrides)
# Extra CSS: enlarge the page title (h1) and the "Output" heading (h2). The ids
# match the elem_id values given to the corresponding gr.Markdown components.
css = """
#main-title h1 { font-size: 2.3em !important; }
#output-title h2 { font-size: 2.1em !important; }
"""

# Single theme instance shared by the Blocks app.
steel_blue_theme = SteelBlueTheme()
# ----------------------------
# RUNTIME / DEVICE
# ----------------------------
# Expose GPU 0 unless the environment already pins specific devices.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

# Startup diagnostics.
print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("torch.__version__ =", torch.__version__)
print("torch.version.cuda =", torch.version.cuda)
use_cuda = torch.cuda.is_available()
print("cuda available =", use_cuda)
print("cuda device count =", torch.cuda.device_count())
if use_cuda:
    print("using device =", torch.cuda.get_device_name(0))
    torch.backends.cudnn.benchmark = True

device = torch.device("cuda:0" if use_cuda else "cpu")

# Half-precision dtypes on GPU; both fall back to float32 on CPU.
DTYPE_FP16, DTYPE_BF16 = (
    (torch.float16, torch.bfloat16) if use_cuda else (torch.float32, torch.float32)
)
# ----------------------------
# OCR MODELS: Chandra-OCR + Dots.OCR
# ----------------------------
# 1) Chandra-OCR: loaded as Qwen3VLForConditionalGeneration in DTYPE_FP16,
#    then moved to `device` and switched to inference mode.
MODEL_ID_V = "datalab-to/chandra"
processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
model_v = Qwen3VLForConditionalGeneration.from_pretrained(
    MODEL_ID_V,
    trust_remote_code=True,
    torch_dtype=DTYPE_FP16,
)
model_v = model_v.to(device).eval()

# 2) Dots.OCR: FlashAttention-2 when the `flash_attn` package imports AND a GPU
#    is available, otherwise PyTorch SDPA.
MODEL_PATH_D = "prithivMLmods/Dots.OCR-Latest-BF16"
processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
try:
    import flash_attn  # noqa: F401
except Exception:
    attn_impl = "sdpa"
else:
    attn_impl = "flash_attention_2" if use_cuda else "sdpa"

# On GPU, accelerate places the weights (device_map="auto"); on CPU the model
# is moved to `device` explicitly afterwards.
model_d = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH_D,
    attn_implementation=attn_impl,
    torch_dtype=DTYPE_BF16,
    device_map="auto" if use_cuda else None,
    trust_remote_code=True,
).eval()
if not use_cuda:
    model_d.to(device)
# ----------------------------
# GENERATION (OCR → NER (Dots only) → Spell-check + CER)
# ----------------------------
# Upper bound and default value of the "Max new tokens" slider.
MAX_MAX_NEW_TOKENS = 4_096
DEFAULT_MAX_NEW_TOKENS = 2_048
def _extract_meds(model_name: str, ocr_text: str) -> List[str]:
    """Return the medication strings to spell-check for *ocr_text*.

    For Dots.OCR output the private ClinicalNER backend is tried first (when it
    was loaded and HF_TOKEN is set). If that yields nothing, or for any other
    model, every non-blank OCR line is used instead.
    """
    meds: List[str] = []
    if model_name == "Dots.OCR":
        try:
            if "ClinicalNER" in priv_classes and HF_TOKEN is not None:
                ner_cls = priv_classes["ClinicalNER"]
                # NOTE(review): a new ClinicalNER is constructed per request; if its
                # constructor loads a model this is costly -- consider caching it.
                ner = ner_cls(token=HF_TOKEN)  # model_id can be passed if needed
                ner_output = ner(ocr_text) or []
                # Expecting list[str]; drop non-strings and blanks defensively.
                meds = [m.strip() for m in ner_output if isinstance(m, str) and m.strip()]
                print("[NER] Extracted meds (from ClinicalNER):", meds)
            else:
                print("[NER] ClinicalNER not available or no HF token.")
        except Exception as e:
            print(f"[NER] Error running ClinicalNER: {e}")
    if not meds:
        meds = [line.strip() for line in ocr_text.splitlines() if line.strip()]
        print("[NER] Using line-based meds fallback, count:", len(meds))
    return meds


def _spell_check_section(meds: List[str], spell_algo: str) -> str:
    """Build the Markdown "Spell-check suggestions" section for *meds*.

    Runs the matcher chosen by *spell_algo* against the drug dictionary at
    ``drug_xlsx_path`` and lists up to 5 suggestions per medication, each with
    the matcher's score and its CER against the OCR string. Missing backends,
    a missing dictionary and matcher/formatting errors are reported as bullet
    lines instead of being raised.
    """
    section = f"\n---\n### Spell-check suggestions ({spell_algo})\n"
    if not meds or not drug_xlsx_path:
        if not meds:
            section += "- No medications extracted (empty med list).\n"
        if not drug_xlsx_path:
            section += "- Drug Excel dictionary path missing (drug_xlsx_path is None).\n"
        return section

    corr: Dict[str, List] = {}
    try:
        if spell_algo == "TF-IDF + Phonetic" and "TfidfPhoneticMatcher" in priv_classes:
            checker = priv_classes["TfidfPhoneticMatcher"](
                xlsx_path=drug_xlsx_path,
                column="Combined_Drugs",
                ngram_size=3,
                phonetic_weight=0.4,
            )
            corr = checker.match_list(meds, top_k=5, tfidf_threshold=0.15)
        elif spell_algo == "SymSpell" and "SymSpellMatcher" in priv_classes:
            checker = priv_classes["SymSpellMatcher"](
                xlsx_path=drug_xlsx_path,
                column="Combined_Drugs",
                max_edit=2,
                prefix_len=7,
            )
            corr = checker.match_list(meds, top_k=5, min_score=0.4)
        elif spell_algo == "RapidFuzz" and "RapidFuzzMatcher" in priv_classes:
            checker = priv_classes["RapidFuzzMatcher"](
                xlsx_path=drug_xlsx_path, column="Combined_Drugs"
            )
            corr = checker.match_list(meds, top_k=5, threshold=70.0)
        else:
            return section + "- Spell-check backend unavailable (no matcher class).\n"
    except Exception as e:
        return section + f"- Spell-check error: {e}\n"

    if not corr:
        return section + "- No suggestions found.\n"

    # Format suggestions (top-5 per med, with scores + CER). Sub-bullets use a
    # two-space indent so Markdown renders them nested under the medication.
    try:
        for raw in meds:
            suggestions = corr.get(raw, [])
            section += f"- **{raw}**\n"
            if not suggestions:
                section += "  - (no suggestions)\n"
                continue
            for cand, score in suggestions:
                cer = character_error_rate(cand, raw)
                section += f"  - {cand} (score={score:.3f}, CER={cer:.3f}%)\n"
    except Exception as e:
        section += f"- Error formatting spell-check suggestions: {e}\n"
    return section


@spaces.GPU  # you can add duration=... if needed, e.g. @spaces.GPU(duration=240)
def generate_image(
    model_name: str,
    text: str,
    image: Image.Image,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
    spell_algo: str,
):
    """Run OCR on *image*, then NER and spell-check; yield ``(raw_text, markdown)`` pairs.

    1) While the selected OCR model generates, the accumulated raw text is
       yielded into both output components.
    2) Medication strings are extracted: ClinicalNER for Dots.OCR (when
       available); otherwise, or as fallback, the non-blank OCR lines.
    3) The selected spell-check matcher (TF-IDF + Phonetic / SymSpell /
       RapidFuzz) runs against the Excel drug dictionary; each suggestion is
       shown with its score and CER.
    4) The final yield pairs the raw text with the full Markdown report.

    If generation fails, the error is reported in the outputs instead of the
    stream blocking forever.
    """
    if image is None:
        # Two outputs: raw textbox + markdown
        yield "Please upload an image.", "Please upload an image."
        return

    if model_name == "Chandra-OCR":
        processor, model = processor_v, model_v
    elif model_name == "Dots.OCR":
        processor, model = processor_d, model_d
    else:
        yield "Invalid model selected.", "Invalid model selected."
        return

    # Build prompt from the text parameter (kept via gr.State).
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": text},
            ],
        }
    ]
    prompt_full = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[prompt_full], images=[image], return_tensors="pt", padding=True
    )
    inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}

    tokenizer = getattr(processor, "tokenizer", None) or processor
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        # UI slider values may arrive as floats; these two must be ints.
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=int(top_k),
        repetition_penalty=repetition_penalty,
    )

    gen_errors: List[BaseException] = []

    def _generate() -> None:
        # If generate() raises, the streamer never receives its end-of-stream
        # signal and the consumer loop below would block forever; end it here.
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:
            gen_errors.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    # 1) Live OCR streaming: mirror the raw text into both components.
    buffer = ""
    for new_text in streamer:
        buffer += new_text.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer, buffer
    thread.join()
    final_ocr_text = buffer.strip()

    if gen_errors:
        err_msg = f"OCR generation failed: {gen_errors[0]}"
        print(f"[OCR] {err_msg}")
        yield (final_ocr_text or err_msg), err_msg
        return

    # 2) Medication extraction (ClinicalNER for Dots.OCR, line fallback otherwise).
    meds = _extract_meds(model_name, final_ocr_text)
    print("[DEBUG] meds count:", len(meds))
    print("[DEBUG] drug_xlsx_path in generate_image:", drug_xlsx_path)

    # Markdown: OCR text + NER section.
    md = "### Raw OCR Output\n"
    md += "```\n" + (final_ocr_text or "(empty)") + "\n```\n"
    md += "\n---\n### Clinical NER (Medications)\n"
    md += "".join(f"- {m}\n" for m in meds) if meds else "- None detected\n"

    # 3) Spell-check with CER, then 4) final yield (raw text unchanged).
    yield final_ocr_text, md + _spell_check_section(meds, spell_algo)
# ----------------------------
# UI
# ----------------------------
# IMPORTANT: examples must match the number of inputs (here: only image)
# Example prescriptions listed under the upload widget. Each row carries exactly
# one value because gr.Examples is wired to a single input (the image).
image_examples = [[f"examples/{stem}.jpg"] for stem in ("3", "1", "2")]
# Gradio UI. Components are laid out in the order they are created inside the
# context managers, so statement order here determines the page layout.
with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
    gr.Markdown(
        "# **Handwritten Doctor's Prescription Reading**", elem_id="main-title"
    )
    with gr.Row():
        # Left column: image input, examples, spell-check choice, sampling options.
        with gr.Column(scale=2):
            image_upload = gr.Image(
                type="pil", label="Upload Image", height=290
            )
            image_submit = gr.Button("Submit", variant="primary")
            gr.Examples(
                examples=image_examples,
                inputs=[image_upload],
                label="Example Images",
            )
            # Spell-check selection; generate_image compares these labels as
            # exact strings to pick the matcher.
            spell_choice = gr.Radio(
                choices=["TF-IDF + Phonetic", "SymSpell", "RapidFuzz"],
                label="Select Spell-check Approach",
                value="TF-IDF + Phonetic",
            )
            # Sampling parameters forwarded to model.generate() by generate_image.
            with gr.Accordion("Advanced options", open=False):
                max_new_tokens = gr.Slider(
                    label="Max new tokens",
                    minimum=1,
                    maximum=MAX_MAX_NEW_TOKENS,
                    step=1,
                    value=DEFAULT_MAX_NEW_TOKENS,
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=4.0,
                    step=0.1,
                    value=0.7,
                )
                top_p = gr.Slider(
                    label="Top-p (nucleus sampling)",
                    minimum=0.05,
                    maximum=1.0,
                    step=0.05,
                    value=0.9,
                )
                top_k = gr.Slider(
                    label="Top-k",
                    minimum=1,
                    maximum=1000,
                    step=1,
                    value=50,
                )
                repetition_penalty = gr.Slider(
                    label="Repetition penalty",
                    minimum=1.0,
                    maximum=2.0,
                    step=0.05,
                    value=1.1,
                )
        # Right column: outputs and the OCR model selector.
        with gr.Column(scale=3):
            gr.Markdown("## Output", elem_id="output-title")
            # Receives the first value of each (raw_text, markdown) yield.
            output = gr.Textbox(
                label="Raw Output Stream",
                interactive=False,
                lines=11,
                show_copy_button=True,
            )
            with gr.Accordion("(Result.md)", open=False):
                # Receives the second value of each yield (the Markdown report).
                markdown_output = gr.Markdown(label="(Result.Md)")
            model_choice = gr.Radio(
                choices=["Chandra-OCR", "Dots.OCR"],
                label="Select OCR Model",
                value="Chandra-OCR",
            )
    # Hard-coded instruction text, passed as gr.State to match the 'text' parameter
    query_state = gr.State(
        "Extract medicine or drugs names along with dosage amount or quantity"
    )
    # `inputs` order must match generate_image's positional parameters;
    # `outputs` order matches the (raw_text, markdown) tuples it yields.
    image_submit.click(
        fn=generate_image,
        inputs=[
            model_choice,
            query_state,
            image_upload,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            spell_choice,
        ],
        outputs=[output, markdown_output],
    )
if __name__ == "__main__":
    # Queue at most 50 pending requests; mcp_server=True additionally exposes
    # the app through Gradio's MCP server support.
    launch_options = dict(mcp_server=True, ssr_mode=False, show_error=True)
    demo.queue(max_size=50).launch(**launch_options)
######################################### version 2 #########################################################################
# import os
# import time
# from threading import Thread
# from typing import Iterable, Dict, Any, Optional, List
# import gradio as gr
# import spaces
# import torch
# from PIL import Image
# from transformers import (
# Qwen3VLForConditionalGeneration,
# AutoModelForCausalLM,
# AutoProcessor,
# TextIteratorStreamer,
# )
# from gradio.themes import Soft
# from gradio.themes.utils import colors, fonts, sizes
# # -----------------------------
# # Character Error Rate (CER)
# # -----------------------------
# def levenshtein(a: str, b: str) -> int:
# """Levenshtein distance to calculate CER."""
# a, b = a.lower(), b.lower()
# if a == b:
# return 0
# if not a:
# return len(b)
# if not b:
# return len(a)
# dp = list(range(len(b) + 1))
# for i, ca in enumerate(a, 1):
# prev = dp[0]
# dp[0] = i
# for j, cb in enumerate(b, 1):
# cur = dp[j]
# cost = 0 if ca == cb else 1
# dp[j] = min(dp[j] + 1, dp[j - 1] + 1, prev + cost)
# prev = cur
# return dp[-1]
# def character_error_rate(pred: str, target: str) -> float:
# """Calculate the Character Error Rate (CER) in percent."""
# target = target or ""
# distance = levenshtein(pred, target)
# return (distance / len(target)) * 100 if len(target) > 0 else 0.0
# # -----------------------------
# # Private repo: dynamic import
# # -----------------------------
# import importlib.util
# from huggingface_hub import hf_hub_download
# REPO_ID = "IFMedTech/Medibot_OCR_model" # private backend repo
# # Map filenames to exported class names
# PY_MODULES = {
# "ner.py": "ClinicalNER", # NER is only applied for Dots.OCR output
# "tfidf_phonetic.py": "TfidfPhoneticMatcher",
# "symspell_matcher.py": "SymSpellMatcher",
# "rapidfuzz_matcher.py": "RapidFuzzMatcher",
# # 'Medibot_Drugs_Cleaned_Updated.xlsx' is data, not a module
# }
# HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
# def _dynamic_import(module_path: str, class_name: str):
# spec = importlib.util.spec_from_file_location(class_name, module_path)
# module = importlib.util.module_from_spec(spec)
# spec.loader.exec_module(module) # type: ignore
# return getattr(module, class_name)
# # Load private classes and Excel dictionary (once at import time)
# priv_classes: Dict[str, Any] = {}
# drug_xlsx_path: Optional[str] = None
# try:
# if HF_TOKEN is None:
# print("[Private] WARNING: HUGGINGFACE_TOKEN not set; NER/Spell-check will be unavailable.")
# else:
# for fname, cls in PY_MODULES.items():
# path = hf_hub_download(repo_id=REPO_ID, filename=fname, token=HF_TOKEN)
# if cls:
# priv_classes[cls] = _dynamic_import(path, cls)
# print(f"[Private] Loaded class: {cls} from {fname}")
# drug_xlsx_path = hf_hub_download(
# repo_id=REPO_ID,
# filename="Medibot_Drugs_Cleaned_Updated.xlsx",
# token=HF_TOKEN,
# )
# print(f"[Private] Downloaded Excel at: {drug_xlsx_path}")
# except Exception as e:
# print(f"[Private] ERROR loading private backend: {e}")
# priv_classes = {}
# drug_xlsx_path = None
# # ----------------------------
# # THEME
# # ----------------------------
# colors.steel_blue = colors.Color(
# name="steel_blue",
# c50="#EBF3F8",
# c100="#D3E5F0",
# c200="#A8CCE1",
# c300="#7DB3D2",
# c400="#529AC3",
# c500="#4682B4",
# c600="#3E72A0",
# c700="#36638C",
# c800="#2E5378",
# c900="#264364",
# c950="#1E3450",
# )
# class SteelBlueTheme(Soft):
# def __init__(
# self,
# *,
# primary_hue: colors.Color | str = colors.gray,
# secondary_hue: colors.Color | str = colors.steel_blue,
# neutral_hue: colors.Color | str = colors.slate,
# text_size: sizes.Size | str = sizes.text_lg,
# font: fonts.Font | str | Iterable[fonts.Font | str] = (
# fonts.GoogleFont("Outfit"),
# "Arial",
# "sans-serif",
# ),
# font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
# fonts.GoogleFont("IBM Plex Mono"),
# "ui-monospace",
# "monospace",
# ),
# ):
# super().__init__(
# primary_hue=primary_hue,
# secondary_hue=secondary_hue,
# neutral_hue=neutral_hue,
# text_size=text_size,
# font=font,
# font_mono=font_mono,
# )
# super().set(
# background_fill_primary="*primary_50",
# background_fill_primary_dark="*primary_900",
# body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
# body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
# button_primary_text_color="white",
# button_primary_text_color_hover="white",
# button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
# button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
# button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)",
# button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_500)",
# button_secondary_text_color="black",
# button_secondary_text_color_hover="white",
# button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
# button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
# button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
# button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
# slider_color="*secondary_500",
# slider_color_dark="*secondary_600",
# block_title_text_weight="600",
# block_border_width="3px",
# block_shadow="*shadow_drop_lg",
# button_primary_shadow="*shadow_drop_lg",
# button_large_padding="11px",
# color_accent_soft="*primary_100",
# block_label_background_fill="*primary_200",
# )
# steel_blue_theme = SteelBlueTheme()
# css = """
# #main-title h1 { font-size: 2.3em !important; }
# #output-title h2 { font-size: 2.1em !important; }
# """
# # ----------------------------
# # RUNTIME / DEVICE
# # ----------------------------
# os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
# print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES"))
# print("torch.__version__ =", torch.__version__)
# print("torch.version.cuda =", torch.version.cuda)
# print("cuda available =", torch.cuda.is_available())
# print("cuda device count =", torch.cuda.device_count())
# if torch.cuda.is_available():
# print("using device =", torch.cuda.get_device_name(0))
# use_cuda = torch.cuda.is_available()
# device = torch.device("cuda:0" if use_cuda else "cpu")
# if use_cuda:
# torch.backends.cudnn.benchmark = True
# DTYPE_FP16 = torch.float16 if use_cuda else torch.float32
# DTYPE_BF16 = torch.bfloat16 if use_cuda else torch.float32
# # ----------------------------
# # OCR MODELS: Chandra-OCR + Dots.OCR
# # ----------------------------
# # 1) Chandra-OCR (Qwen3VL)
# MODEL_ID_V = "datalab-to/chandra"
# processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
# model_v = Qwen3VLForConditionalGeneration.from_pretrained(
# MODEL_ID_V, trust_remote_code=True, torch_dtype=DTYPE_FP16
# ).to(device).eval()
# # 2) Dots.OCR (flash_attn2 if available, else SDPA)
# MODEL_PATH_D = "prithivMLmods/Dots.OCR-Latest-BF16"
# processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
# attn_impl = "sdpa"
# try:
# import flash_attn # noqa: F401
# if use_cuda:
# attn_impl = "flash_attention_2"
# except Exception:
# attn_impl = "sdpa"
# model_d = AutoModelForCausalLM.from_pretrained(
# MODEL_PATH_D,
# attn_implementation=attn_impl,
# torch_dtype=DTYPE_BF16,
# device_map="auto" if use_cuda else None,
# trust_remote_code=True,
# ).eval()
# if not use_cuda:
# model_d.to(device)
# # ----------------------------
# # GENERATION (OCR → NER (Dots only) → Spell-check + CER)
# # ----------------------------
# MAX_MAX_NEW_TOKENS = 4096
# DEFAULT_MAX_NEW_TOKENS = 2048
# @spaces.GPU # you can add duration=... if needed, e.g. @spaces.GPU(duration=240)
# def generate_image(
# model_name: str,
# text: str,
# image: Image.Image,
# max_new_tokens: int,
# temperature: float,
# top_p: float,
# top_k: int,
# repetition_penalty: float,
# spell_algo: str,
# ):
# """
# 1) Stream OCR tokens to Raw output (unchanged).
# 2) If model_name == 'Dots.OCR', run ClinicalNER → list[str] meds.
# For Chandra-OCR, skip NER.
# 3) Apply selected spell-check (TF-IDF+Phonetic / SymSpell / RapidFuzz)
# using Excel dict, and compute CER for each suggestion.
# 4) Markdown shows OCR text, NER list (if any), and spell-check top-5
# suggestions with scores and CER.
# """
# if image is None:
# # Two outputs: raw textbox + markdown
# yield "Please upload an image.", "Please upload an image."
# return
# if model_name == "Chandra-OCR":
# processor, model = processor_v, model_v
# elif model_name == "Dots.OCR":
# processor, model = processor_d, model_d
# else:
# yield "Invalid model selected.", "Invalid model selected."
# return
# # Build prompt from text parameter (kept via gr.State)
# messages = [
# {
# "role": "user",
# "content": [
# {"type": "image"},
# {"type": "text", "text": text},
# ],
# }
# ]
# prompt_full = processor.apply_chat_template(
# messages, tokenize=False, add_generation_prompt=True
# )
# # Preprocess
# inputs = processor(
# text=[prompt_full], images=[image], return_tensors="pt", padding=True
# )
# inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}
# # Streamer
# tokenizer = getattr(processor, "tokenizer", None) or processor
# streamer = TextIteratorStreamer(
# tokenizer, skip_prompt=True, skip_special_tokens=True
# )
# gen_kwargs = dict(
# **inputs,
# streamer=streamer,
# max_new_tokens=max_new_tokens,
# do_sample=True,
# temperature=temperature,
# top_p=top_p,
# top_k=top_k,
# repetition_penalty=repetition_penalty,
# )
# # Start generation in background thread
# thread = Thread(target=model.generate, kwargs=gen_kwargs)
# thread.start()
# # 1) Live OCR streaming to Raw (and mirror to Markdown during stream)
# buffer = ""
# for new_text in streamer:
# buffer += new_text.replace("<|im_end|>", "")
# time.sleep(0.01)
# # During streaming, just show the raw text in both components
# yield buffer, buffer
# # Final raw text
# final_ocr_text = buffer.strip()
# # -------------------------
# # 2) Clinical NER (Dots.OCR only)
# # -------------------------
# meds: List[str] = []
# if model_name == "Dots.OCR":
# try:
# if "ClinicalNER" in priv_classes and HF_TOKEN is not None:
# ClinicalNER = priv_classes["ClinicalNER"]
# ner = ClinicalNER(token=HF_TOKEN) # model_id can be passed if needed
# ner_output = ner(final_ocr_text) or []
# # Expecting list[str]; be robust:
# meds = [m.strip() for m in ner_output if isinstance(m, str) and m.strip()]
# print("[NER] Extracted meds:", meds)
# else:
# print("[NER] ClinicalNER not available or no HF token.")
# except Exception as e:
# print(f"[NER] Error running ClinicalNER: {e}")
# # Fallback: if no meds found (or Chandra-OCR), derive meds from OCR lines
# if not meds:
# meds = [line.strip() for line in final_ocr_text.splitlines() if line.strip()]
# print("[NER] Using line-based meds fallback, count:", len(meds))
# # -------------------------
# # Build Markdown: OCR text + NER section
# # -------------------------
# md = "### Raw OCR Output\n"
# md += "```\n" + (final_ocr_text or "(empty)") + "\n```\n"
# md += "\n---\n### Clinical NER (Medications)\n"
# if meds:
# for m in meds:
# md += f"- {m}\n"
# else:
# md += "- None detected\n"
# # -------------------------
# # 3) Spell-check (med list) with CER
# # -------------------------
# spell_section = "\n---\n### Spell-check suggestions (" + spell_algo + ")\n"
# corr: Dict[str, List] = {}
# try:
# if meds and drug_xlsx_path:
# if (
# spell_algo == "TF-IDF + Phonetic"
# and "TfidfPhoneticMatcher" in priv_classes
# ):
# Cls = priv_classes["TfidfPhoneticMatcher"]
# checker = Cls(
# xlsx_path=drug_xlsx_path,
# column="Combined_Drugs",
# ngram_size=3,
# phonetic_weight=0.4,
# )
# corr = checker.match_list(meds, top_k=5, tfidf_threshold=0.15)
# elif spell_algo == "SymSpell" and "SymSpellMatcher" in priv_classes:
# Cls = priv_classes["SymSpellMatcher"]
# checker = Cls(
# xlsx_path=drug_xlsx_path,
# column="Combined_Drugs",
# max_edit=2,
# prefix_len=7,
# )
# corr = checker.match_list(meds, top_k=5, min_score=0.4)
# elif (
# spell_algo == "RapidFuzz" and "RapidFuzzMatcher" in priv_classes
# ):
# Cls = priv_classes["RapidFuzzMatcher"]
# checker = Cls(xlsx_path=drug_xlsx_path, column="Combined_Drugs")
# corr = checker.match_list(meds, top_k=5, threshold=70.0)
# else:
# spell_section += "- Spell-check backend unavailable.\n"
# else:
# spell_section += "- No NER/med list or Excel dictionary missing.\n"
# except Exception as e:
# spell_section += f"- Spell-check error: {e}\n"
# # Format suggestions (top-5 per med, with scores + CER)
# if corr:
# for raw in meds:
# suggestions = corr.get(raw, [])
# if suggestions:
# spell_section += f"- **{raw}**\n"
# for cand, score in suggestions:
# cer = character_error_rate(cand, raw)
# spell_section += (
# f" - {cand} "
# f"(score={score:.3f}, CER={cer:.3f}%)\n"
# )
# else:
# spell_section += f"- **{raw}**\n - (no suggestions)\n"
# final_md = md + spell_section
# # 4) Final yield: raw unchanged; Markdown with NER + spell-check + CER
# yield final_ocr_text, final_md
# # ----------------------------
# # UI
# # ----------------------------
# # IMPORTANT: examples must match the number of inputs (here: only image)
# image_examples = [
# ["examples/3.jpg"],
# ["examples/1.jpg"],
# ["examples/2.jpg"],
# ]
# with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
# gr.Markdown(
# "# **Handwritten Doctor's Prescription Reading**", elem_id="main-title"
# )
# with gr.Row():
# with gr.Column(scale=2):
# image_upload = gr.Image(
# type="pil", label="Upload Image", height=290
# )
# image_submit = gr.Button("Submit", variant="primary")
# gr.Examples(
# examples=image_examples,
# inputs=[image_upload],
# label="Example Images",
# )
# # Spell-check selection
# spell_choice = gr.Radio(
# choices=["TF-IDF + Phonetic", "SymSpell", "RapidFuzz"],
# label="Select Spell-check Approach",
# value="TF-IDF + Phonetic",
# )
# with gr.Accordion("Advanced options", open=False):
# max_new_tokens = gr.Slider(
# label="Max new tokens",
# minimum=1,
# maximum=MAX_MAX_NEW_TOKENS,
# step=1,
# value=DEFAULT_MAX_NEW_TOKENS,
# )
# temperature = gr.Slider(
# label="Temperature",
# minimum=0.1,
# maximum=4.0,
# step=0.1,
# value=0.7,
# )
# top_p = gr.Slider(
# label="Top-p (nucleus sampling)",
# minimum=0.05,
# maximum=1.0,
# step=0.05,
# value=0.9,
# )
# top_k = gr.Slider(
# label="Top-k",
# minimum=1,
# maximum=1000,
# step=1,
# value=50,
# )
# repetition_penalty = gr.Slider(
# label="Repetition penalty",
# minimum=1.0,
# maximum=2.0,
# step=0.05,
# value=1.1,
# )
# with gr.Column(scale=3):
# gr.Markdown("## Output", elem_id="output-title")
# output = gr.Textbox(
# label="Raw Output Stream",
# interactive=False,
# lines=11,
# show_copy_button=True,
# )
# with gr.Accordion("(Result.md)", open=False):
# markdown_output = gr.Markdown(label="(Result.Md)")
# model_choice = gr.Radio(
# choices=["Chandra-OCR", "Dots.OCR"],
# label="Select OCR Model",
# value="Chandra-OCR",
# )
# # Hard-coded instruction text, passed as gr.State to match the 'text' parameter
# query_state = gr.State(
# "Extract medicine or drugs names along with dosage amount or quantity"
# )
# image_submit.click(
# fn=generate_image,
# inputs=[
# model_choice,
# query_state,
# image_upload,
# max_new_tokens,
# temperature,
# top_p,
# top_k,
# repetition_penalty,
# spell_choice,
# ],
# outputs=[output, markdown_output],
# )
# if __name__ == "__main__":
# demo.queue(max_size=50).launch(
# mcp_server=True, ssr_mode=False, show_error=True
# )
##################################### version 1 #######################################################
# import os
# import time
# from threading import Thread
# from typing import Iterable, Dict, Any, Optional, List
# import gradio as gr
# import spaces
# import torch
# from PIL import Image
# from transformers import (
# Qwen3VLForConditionalGeneration,
# AutoModelForCausalLM,
# AutoProcessor,
# TextIteratorStreamer,
# )
# from gradio.themes import Soft
# from gradio.themes.utils import colors, fonts, sizes
# # -----------------------------
# # Private repo: dynamic import
# # -----------------------------
# import importlib.util
# from huggingface_hub import hf_hub_download
# REPO_ID = "IFMedTech/Medibot_OCR_model" # private backend repo
# # Map filenames to exported class names
# PY_MODULES = {
# "ner.py": "ClinicalNER",
# "tfidf_phonetic.py": "TfidfPhoneticMatcher",
# "symspell_matcher.py": "SymSpellMatcher",
# "rapidfuzz_matcher.py": "RapidFuzzMatcher",
# # 'drug_dictionary.xlsx' is data, not a module
# }
# HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
# def _dynamic_import(module_path: str, class_name: str):
# spec = importlib.util.spec_from_file_location(class_name, module_path)
# module = importlib.util.module_from_spec(spec)
# spec.loader.exec_module(module) # type: ignore
# return getattr(module, class_name)
# # Load private classes and Excel dictionary
# priv_classes: Dict[str, Any] = {}
# drug_xlsx_path: Optional[str] = None
# try:
# if HF_TOKEN is None:
# print("[Private] WARNING: HUGGINGFACE_TOKEN not set; NER/Spell-check will be unavailable.")
# else:
# for fname, cls in PY_MODULES.items():
# path = hf_hub_download(repo_id=REPO_ID, filename=fname, token=HF_TOKEN)
# if cls:
# priv_classes[cls] = _dynamic_import(path, cls)
# print(f"[Private] Loaded class: {cls} from {fname}")
# drug_xlsx_path = hf_hub_download(repo_id=REPO_ID, filename="Medibot_Drugs_Cleaned_Updated.xlsx", token=HF_TOKEN)
# print(f"[Private] Downloaded Excel at: {drug_xlsx_path}")
# except Exception as e:
# print(f"[Private] ERROR loading private backend: {e}")
# priv_classes = {}
# drug_xlsx_path = None
# # ----------------------------
# # THEME
# # ----------------------------
# colors.steel_blue = colors.Color(
# name="steel_blue",
# c50="#EBF3F8",
# c100="#D3E5F0",
# c200="#A8CCE1",
# c300="#7DB3D2",
# c400="#529AC3",
# c500="#4682B4",
# c600="#3E72A0",
# c700="#36638C",
# c800="#2E5378",
# c900="#264364",
# c950="#1E3450",
# )
# class SteelBlueTheme(Soft):
# def __init__(
# self,
# *,
# primary_hue: colors.Color | str = colors.gray,
# secondary_hue: colors.Color | str = colors.steel_blue,
# neutral_hue: colors.Color | str = colors.slate,
# text_size: sizes.Size | str = sizes.text_lg,
# font: fonts.Font | str | Iterable[fonts.Font | str] = (
# fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
# ),
# font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
# fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
# ),
# ):
# super().__init__(
# primary_hue=primary_hue,
# secondary_hue=secondary_hue,
# neutral_hue=neutral_hue,
# text_size=text_size,
# font=font,
# font_mono=font_mono,
# )
# super().set(
# background_fill_primary="*primary_50",
# background_fill_primary_dark="*primary_900",
# body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
# body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
# button_primary_text_color="white",
# button_primary_text_color_hover="white",
# button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
# button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
# button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)",
# button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_500)",
# button_secondary_text_color="black",
# button_secondary_text_color_hover="white",
# button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
# button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
# button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
# button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
# slider_color="*secondary_500",
# slider_color_dark="*secondary_600",
# block_title_text_weight="600",
# block_border_width="3px",
# block_shadow="*shadow_drop_lg",
# button_primary_shadow="*shadow_drop_lg",
# button_large_padding="11px",
# color_accent_soft="*primary_100",
# block_label_background_fill="*primary_200",
# )
# steel_blue_theme = SteelBlueTheme()
# css = """
# #main-title h1 { font-size: 2.3em !important; }
# #output-title h2 { font-size: 2.1em !important; }
# """
# # ----------------------------
# # RUNTIME / DEVICE
# # ----------------------------
# os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
# print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES"))
# print("torch.__version__ =", torch.__version__)
# print("torch.version.cuda =", torch.version.cuda)
# print("cuda available =", torch.cuda.is_available())
# print("cuda device count =", torch.cuda.device_count())
# if torch.cuda.is_available():
# print("using device =", torch.cuda.get_device_name(0))
# use_cuda = torch.cuda.is_available()
# device = torch.device("cuda:0" if use_cuda else "cpu")
# if use_cuda:
# torch.backends.cudnn.benchmark = True
# DTYPE_FP16 = torch.float16 if use_cuda else torch.float32
# DTYPE_BF16 = torch.bfloat16 if use_cuda else torch.float32
# # ----------------------------
# # OCR MODELS: Chandra-OCR + Dots.OCR
# # ----------------------------
# # 1) Chandra-OCR (Qwen3VL)
# MODEL_ID_V = "datalab-to/chandra"
# processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
# model_v = Qwen3VLForConditionalGeneration.from_pretrained(
# MODEL_ID_V, trust_remote_code=True, torch_dtype=DTYPE_FP16
# ).to(device).eval()
# # 2) Dots.OCR (flash_attn2 if available, else SDPA)
# MODEL_PATH_D = "prithivMLmods/Dots.OCR-Latest-BF16"
# processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
# attn_impl = "sdpa"
# try:
# import flash_attn # noqa: F401
# if use_cuda:
# attn_impl = "flash_attention_2"
# except Exception:
# attn_impl = "sdpa"
# model_d = AutoModelForCausalLM.from_pretrained(
# MODEL_PATH_D,
# attn_implementation=attn_impl,
# torch_dtype=DTYPE_BF16,
# device_map="auto" if use_cuda else None,
# trust_remote_code=True
# ).eval()
# if not use_cuda:
# model_d.to(device)
# # ----------------------------
# # GENERATION (OCR → NER → Spell-check)
# # ----------------------------
# MAX_MAX_NEW_TOKENS = 4096
# DEFAULT_MAX_NEW_TOKENS = 2048
# @spaces.GPU
# def generate_image(model_name: str,
# text: str,
# image: Image.Image,
# max_new_tokens: int,
# temperature: float,
# top_p: float,
# top_k: int,
# repetition_penalty: float,
# spell_algo: str):
# """
# 1) Stream OCR tokens to Raw output (unchanged).
# 2) After stream completes, split the final raw text into non-empty lines → list[str] meds (the ClinicalNER call is currently disabled).
# 3) Apply selected spell-check (TF-IDF+Phonetic / SymSpell / RapidFuzz) using Excel dict.
# 4) Markdown shows OCR + candidate-medication list + spell-check top-5 suggestions with scores.
# """
# if image is None:
# yield "Please upload an image.", "Please upload an image."
# return
# if model_name == "Chandra-OCR":
# processor, model = processor_v, model_v
# elif model_name == "Dots.OCR":
# processor, model = processor_d, model_d
# else:
# yield "Invalid model selected.", "Invalid model selected."
# return
# # Build prompt
# messages = [{
# "role": "user",
# "content": [
# {"type": "image"},
# {"type": "text", "text": text},
# ]
# }]
# prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# # Preprocess
# inputs = processor(text=[prompt_full], images=[image], return_tensors="pt", padding=True)
# inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}
# # Streamer
# tokenizer = getattr(processor, "tokenizer", None) or processor
# streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# gen_kwargs = dict(
# **inputs,
# streamer=streamer,
# max_new_tokens=max_new_tokens,
# do_sample=True,
# temperature=temperature,
# top_p=top_p,
# top_k=top_k,
# repetition_penalty=repetition_penalty,
# )
# # Start generation
# thread = Thread(target=model.generate, kwargs=gen_kwargs)
# thread.start()
# # 1) Live OCR streaming to Raw (and mirror to Markdown during stream)
# buffer = ""
# for new_text in streamer:
# buffer += new_text.replace("<|im_end|>", "")
# time.sleep(0.01)
# yield buffer, buffer
# # Final raw text for downstream processing
# final_ocr_text = buffer
# # 2) Clinical NER (from private repo) -- currently disabled; candidates are taken line-by-line from the OCR text instead
# # meds: List[str] = []
# # try:
# # if "ClinicalNER" in priv_classes:
# # ClinicalNER = priv_classes["ClinicalNER"]
# # ner = ClinicalNER(token=HF_TOKEN) # pass model_id=... if using your own model
# # meds = ner(final_ocr_text) or []
# # else:
# # print("[NER] ClinicalNER not available.")
# # except Exception as e:
# # print(f"[NER] Error running ClinicalNER: {e}")
# raw_ocr_text = buffer.strip()
# meds = [line.strip() for line in raw_ocr_text.split('\n') if line.strip()]
# # Build Markdown with OCR + candidate-medication section (still headed "Clinical NER")
# md = final_ocr_text
# md += "\n\n---\n### Clinical NER (Medications)\n"
# if meds:
# for m in meds:
# md += f"- {m}\n"
# else:
# md += "- None detected\n"
# # 3) Spell-check the candidate lines using the selected approach + Excel dict
# spell_section = "\n---\n### Spell-check suggestions (" + spell_algo + ")\n"
# corr: Dict[str, List] = {}
# try:
# if meds and drug_xlsx_path:
# if spell_algo == "TF-IDF + Phonetic" and "TfidfPhoneticMatcher" in priv_classes:
# Cls = priv_classes["TfidfPhoneticMatcher"]
# checker = Cls(xlsx_path=drug_xlsx_path, column="Combined_Drugs", ngram_size=3, phonetic_weight=0.4)
# corr = checker.match_list(meds, top_k=5, tfidf_threshold=0.15)
# elif spell_algo == "SymSpell" and "SymSpellMatcher" in priv_classes:
# Cls = priv_classes["SymSpellMatcher"]
# checker = Cls(xlsx_path=drug_xlsx_path, column="Combined_Drugs", max_edit=2, prefix_len=7)
# corr = checker.match_list(meds, top_k=5, min_score=0.4)
# elif spell_algo == "RapidFuzz" and "RapidFuzzMatcher" in priv_classes:
# Cls = priv_classes["RapidFuzzMatcher"]
# checker = Cls(xlsx_path=drug_xlsx_path, column="Combined_Drugs")
# corr = checker.match_list(meds, top_k=5, threshold=70.0)
# else:
# spell_section += "- Spell-check backend unavailable.\n"
# else:
# spell_section += "- No NER output or Excel dictionary missing.\n"
# except Exception as e:
# spell_section += f"- Spell-check error: {e}\n"
# # Format suggestions (top-5 with scores)
# if corr:
# for raw in meds:
# suggestions = corr.get(raw, [])
# if suggestions:
# spell_section += f"- **{raw}**\n"
# for cand, score in suggestions:
# spell_section += f" - {cand} (score={score:.3f})\n"
# else:
# spell_section += f"- **{raw}**\n - (no suggestions)\n"
# final_md = md + spell_section
# # 4) Final yield: raw unchanged; Markdown with NER + spell-check
# yield final_ocr_text, final_md
# # ----------------------------
# # UI
# # ----------------------------
# image_examples = [
# ["OCR the content perfectly.", "examples/3.jpg"],
# ["Perform OCR on the image.", "examples/1.jpg"],
# ["Extract the contents. [page].", "examples/2.jpg"],
# ]
# with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
# gr.Markdown("# **Handwritten Doctor's Prescription Reading**", elem_id="main-title")
# with gr.Row():
# with gr.Column(scale=2):
# #image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
# image_upload = gr.Image(type="pil", label="Upload Image", height=290)
# image_submit = gr.Button("Submit", variant="primary")
# gr.Examples(examples=image_examples, inputs=[image_upload])
# # Spell-check selection
# spell_choice = gr.Radio(
# choices=["TF-IDF + Phonetic", "SymSpell", "RapidFuzz"],
# label="Select Spell-check Approach",
# value="TF-IDF + Phonetic"
# )
# with gr.Accordion("Advanced options", open=False):
# max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
# temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7)
# top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
# top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
# repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1)
# with gr.Column(scale=3):
# gr.Markdown("## Output", elem_id="output-title")
# output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=11, show_copy_button=True)
# with gr.Accordion("(Result.md)", open=False):
# markdown_output = gr.Markdown(label="(Result.Md)")
# model_choice = gr.Radio(
# choices=["Chandra-OCR", "Dots.OCR"],
# label="Select OCR Model",
# value="Chandra-OCR"
# )
# image_submit.click(
# fn=generate_image,
# inputs=[model_choice,gr.State("Extract medicine or drugs names along with dosage amount or quantity") , image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty, spell_choice],
# outputs=[output, markdown_output]
# )
# if __name__ == "__main__":
# demo.queue(max_size=50).launch(mcp_server=True, ssr_mode=False, show_error=True)