Spaces:
Paused
Paused
| ###################################### version 5 ############################################ | |
| # import os | |
| # from typing import Iterable, Dict, Any, Optional, List | |
| # from threading import Thread # unused in version 5 (generation is no longer streamed); harmless if left | |
| # import time # unused in version 5 (no streaming loop to throttle); harmless if left | |
| # import gradio as gr | |
| # import spaces | |
| # import torch | |
| # from PIL import Image | |
| # import pandas as pd | |
| # from transformers import ( | |
| # Qwen3VLForConditionalGeneration, | |
| # AutoModelForCausalLM, | |
| # AutoProcessor, | |
| # ) | |
| # from gradio.themes import Soft | |
| # from gradio.themes.utils import colors, fonts, sizes | |
| # # ============================================================ | |
| # # Character Error Rate (CER) | |
| # # ============================================================ | |
| # def levenshtein(a: str, b: str) -> int: | |
| # """Case-insensitive Levenshtein edit distance between a and b, used to calculate CER.""" | |
| # a, b = a.lower(), b.lower() | |
| # if a == b: | |
| # return 0 | |
| # if not a: | |
| # return len(b) | |
| # if not b: | |
| # return len(a) | |
| # dp = list(range(len(b) + 1)) | |
| # for i, ca in enumerate(a, 1): | |
| # prev = dp[0] | |
| # dp[0] = i | |
| # for j, cb in enumerate(b, 1): | |
| # cur = dp[j] | |
| # cost = 0 if ca == cb else 1 | |
| # dp[j] = min(dp[j] + 1, dp[j - 1] + 1, prev + cost) | |
| # prev = cur | |
| # return dp[-1] | |
| # def character_error_rate(pred: str, target: str) -> float: | |
| # """Calculate the Character Error Rate (CER) in percent: edit distance divided by len(target); 0.0 when target is empty.""" | |
| # target = target or "" | |
| # distance = levenshtein(pred, target) | |
| # return (distance / len(target)) * 100 if len(target) > 0 else 0.0 | |
| # # ============================================================ | |
| # # Private repo: dynamic import + Excel download | |
| # # ============================================================ | |
| # import importlib.util | |
| # from huggingface_hub import hf_hub_download | |
| # REPO_ID = "IFMedTech/Medibot_OCR_model" # private backend repo | |
| # PY_MODULES: Dict[str, str] = { | |
| # "clinical_NER.py": "ClinicalNER", | |
| # "tf_idf_phonetic.py": "TfidfPhoneticMatcher", | |
| # "symspell_matcher.py": "SymSpellMatcher", | |
| # "rapid_fuzz_matcher.py": "RapidFuzzMatcher", | |
| # } | |
| # HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN") # must be set in Space secrets | |
| # def _dynamic_import(module_path: str, class_name: str): | |
| # spec = importlib.util.spec_from_file_location(class_name, module_path) | |
| # module = importlib.util.module_from_spec(spec) | |
| # spec.loader.exec_module(module) # type: ignore | |
| # return getattr(module, class_name) | |
| # priv_classes: Dict[str, Any] = {} | |
| # drug_xlsx_path: Optional[str] = None | |
| # BACKEND_INIT_ERROR: Optional[str] = None | |
| # print("[Private] HF_TOKEN present?:", HF_TOKEN is not None) | |
| # if HF_TOKEN is None: | |
| # BACKEND_INIT_ERROR = "HUGGINGFACE_TOKEN env var is not set in this Space." | |
| # print("[Private] WARNING:", BACKEND_INIT_ERROR) | |
| # else: | |
| # print(f"[Private] Using repo: {REPO_ID}") | |
| # # 1) Load python modules (best-effort) | |
| # for fname, cls_name in PY_MODULES.items(): | |
| # try: | |
| # print(f"[Private] Downloading module file: {fname}") | |
| # path = hf_hub_download( | |
| # repo_id=REPO_ID, | |
| # filename=fname, | |
| # token=HF_TOKEN, | |
| # repo_type="model", | |
| # ) | |
| # priv_classes[cls_name] = _dynamic_import(path, cls_name) | |
| # print(f"[Private] Loaded class {cls_name} from {fname}") | |
| # except Exception as e: | |
| # msg = f"Failed to load {fname}: {e}" | |
| # print("[Private]", msg) | |
| # BACKEND_INIT_ERROR = (BACKEND_INIT_ERROR or "") + f" | {msg}" | |
| # # 2) Load Excel dictionary | |
| # try: | |
| # print("[Private] Downloading Excel file: Medibot_Drugs_Cleaned_Updated.xlsx") | |
| # drug_xlsx_path = hf_hub_download( | |
| # repo_id=REPO_ID, | |
| # filename="Medibot_Drugs_Cleaned_Updated.xlsx", | |
| # token=HF_TOKEN, | |
| # repo_type="model", | |
| # ) | |
| # print(f"[Private] Downloaded Excel at: {drug_xlsx_path}") | |
| # df_debug = pd.read_excel(drug_xlsx_path, nrows=3) | |
| # print( | |
| # f"[Private] Excel loaded successfully. " | |
| # f"Shape={df_debug.shape}, cols={list(df_debug.columns)}" | |
| # ) | |
| # except Exception as e: | |
| # msg = f"ERROR loading Excel: {e}" | |
| # print("[Private]", msg) | |
| # BACKEND_INIT_ERROR = (BACKEND_INIT_ERROR or "") + f" | {msg}" | |
| # drug_xlsx_path = None | |
| # # ============================================================ | |
| # # THEME | |
| # # ============================================================ | |
| # colors.steel_blue = colors.Color( | |
| # name="steel_blue", | |
| # c50="#EBF3F8", | |
| # c100="#D3E5F0", | |
| # c200="#A8CCE1", | |
| # c300="#7DB3D2", | |
| # c400="#529AC3", | |
| # c500="#4682B4", | |
| # c600="#3E72A0", | |
| # c700="#36638C", | |
| # c800="#2E5378", | |
| # c900="#264364", | |
| # c950="#1E3450", | |
| # ) | |
| # class SteelBlueTheme(Soft): | |
| # def __init__( | |
| # self, | |
| # *, | |
| # primary_hue: colors.Color | str = colors.gray, | |
| # secondary_hue: colors.Color | str = colors.steel_blue, | |
| # neutral_hue: colors.Color | str = colors.slate, | |
| # text_size: sizes.Size | str = sizes.text_lg, | |
| # font: fonts.Font | str | Iterable[fonts.Font | str] = ( | |
| # fonts.GoogleFont("Outfit"), | |
| # "Arial", | |
| # "sans-serif", | |
| # ), | |
| # font_mono: fonts.Font | str | Iterable[fonts.Font | str] = ( | |
| # fonts.GoogleFont("IBM Plex Mono"), | |
| # "ui-monospace", | |
| # "monospace", | |
| # ), | |
| # ): | |
| # super().__init__( | |
| # primary_hue=primary_hue, | |
| # secondary_hue=secondary_hue, | |
| # neutral_hue=neutral_hue, | |
| # text_size=text_size, | |
| # font=font, | |
| # font_mono=font_mono, | |
| # ) | |
| # super().set( | |
| # background_fill_primary="*primary_50", | |
| # background_fill_primary_dark="*primary_900", | |
| # body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)", | |
| # body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)", | |
| # button_primary_text_color="white", | |
| # button_primary_text_color_hover="white", | |
| # button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)", | |
| # button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)", | |
| # button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)", | |
| # button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_500)", | |
| # button_secondary_text_color="black", | |
| # button_secondary_text_color_hover="white", | |
| # button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)", | |
| # button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)", | |
| # button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)", | |
| # button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)", | |
| # slider_color="*secondary_500", | |
| # slider_color_dark="*secondary_600", | |
| # block_title_text_weight="600", | |
| # block_border_width="3px", | |
| # block_shadow="*shadow_drop_lg", | |
| # button_primary_shadow="*shadow_drop_lg", | |
| # button_large_padding="11px", | |
| # color_accent_soft="*primary_100", | |
| # block_label_background_fill="*primary_200", | |
| # ) | |
| # steel_blue_theme = SteelBlueTheme() | |
| # css = """ | |
| # #main-title h1 { font-size: 2.3em !important; } | |
| # #output-title h2 { font-size: 2.1em !important; } | |
| # """ | |
| # # ============================================================ | |
| # # RUNTIME / DEVICE | |
| # # ============================================================ | |
| # os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0") | |
| # print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES")) | |
| # print("torch.__version__ =", torch.__version__) | |
| # print("torch.version.cuda =", torch.version.cuda) | |
| # print("cuda available =", torch.cuda.is_available()) | |
| # print("cuda device count =", torch.cuda.device_count()) | |
| # if torch.cuda.is_available(): | |
| # print("using device =", torch.cuda.get_device_name(0)) | |
| # use_cuda = torch.cuda.is_available() | |
| # device = torch.device("cuda:0" if use_cuda else "cpu") | |
| # if use_cuda: | |
| # torch.backends.cudnn.benchmark = True | |
| # DTYPE_FP16 = torch.float16 if use_cuda else torch.float32 | |
| # DTYPE_BF16 = torch.bfloat16 if use_cuda else torch.float32 | |
| # # ============================================================ | |
| # # OCR MODELS: Chandra-OCR + Dots.OCR | |
| # # ============================================================ | |
| # MODEL_ID_V = "datalab-to/chandra" | |
| # processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True) | |
| # model_v = Qwen3VLForConditionalGeneration.from_pretrained( | |
| # MODEL_ID_V, trust_remote_code=True, torch_dtype=DTYPE_FP16 | |
| # ).to(device).eval() | |
| # MODEL_PATH_D = "prithivMLmods/Dots.OCR-Latest-BF16" | |
| # processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True) | |
| # attn_impl = "sdpa" | |
| # try: | |
| # import flash_attn # noqa: F401 | |
| # if use_cuda: | |
| # attn_impl = "flash_attention_2" | |
| # except Exception: | |
| # attn_impl = "sdpa" | |
| # model_d = AutoModelForCausalLM.from_pretrained( | |
| # MODEL_PATH_D, | |
| # attn_implementation=attn_impl, | |
| # torch_dtype=DTYPE_BF16, | |
| # device_map="auto" if use_cuda else None, | |
| # trust_remote_code=True, | |
| # ).eval() | |
| # if not use_cuda: | |
| # model_d.to(device) | |
| # # ============================================================ | |
| # # GENERATION (no raw output UI; one markdown return) | |
| # # ============================================================ | |
| # MAX_MAX_NEW_TOKENS = 4096 | |
| # DEFAULT_MAX_NEW_TOKENS = 2048 | |
| # @spaces.GPU | |
| # def generate_image( | |
| # model_name: str, | |
| # text: str, | |
| # image: Image.Image, | |
| # max_new_tokens: int, | |
| # temperature: float, | |
| # top_p: float, | |
| # top_k: int, | |
| # repetition_penalty: float, | |
| # spell_algo: str, | |
| # ) -> str: | |
| # """ | |
| # Returns a single Markdown string: | |
| # - Medications (extracted) | |
| # - Spell-check suggestions | |
| # No raw OCR text is returned to the UI. | |
| # """ | |
| # try: | |
| # if image is None: | |
| # return "Please upload an image." | |
| # # Choose processor/model | |
| # if model_name == "Chandra-OCR": | |
| # processor, model = processor_v, model_v | |
| # elif model_name == "Dots.OCR": | |
| # processor, model = processor_d, model_d | |
| # else: | |
| # return "Invalid model selected." | |
| # # Build prompt | |
| # messages = [ | |
| # { | |
| # "role": "user", | |
| # "content": [ | |
| # {"type": "image"}, | |
| # {"type": "text", "text": text}, | |
| # ], | |
| # } | |
| # ] | |
| # prompt_full = processor.apply_chat_template( | |
| # messages, tokenize=False, add_generation_prompt=True | |
| # ) | |
| # # Preprocess | |
| # inputs = processor( | |
| # text=[prompt_full], images=[image], return_tensors="pt", padding=True | |
| # ) | |
| # inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()} | |
| # # Generate (no streaming) | |
| # gen_kwargs = dict( | |
| # **inputs, | |
| # max_new_tokens=max_new_tokens, | |
| # do_sample=True, | |
| # temperature=temperature, | |
| # top_p=top_p, | |
| # top_k=top_k, | |
| # repetition_penalty=repetition_penalty, | |
| # ) | |
| # outputs = model.generate(**gen_kwargs) | |
| # tokenizer = getattr(processor, "tokenizer", None) or processor | |
| # generated = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0] | |
| # final_ocr_text = generated.strip() | |
| # # -------------------------------------------------------- | |
| # # 2) Medications extraction | |
| # # -------------------------------------------------------- | |
| # meds: List[str] = [] | |
| # if model_name == "Dots.OCR": | |
| # try: | |
| # if "ClinicalNER" in priv_classes and HF_TOKEN is not None: | |
| # ClinicalNER = priv_classes["ClinicalNER"] | |
| # ner = ClinicalNER(token=HF_TOKEN) | |
| # ner_output = ner(final_ocr_text) or [] | |
| # meds = [ | |
| # m.strip() | |
| # for m in ner_output | |
| # if isinstance(m, str) and m.strip() | |
| # ] | |
| # print("[NER] (Dots.OCR) ClinicalNER meds:", meds) | |
| # else: | |
| # print("[NER] ClinicalNER unavailable or missing HF token; skipping.") | |
| # except Exception as e: | |
| # print(f"[NER] Error running ClinicalNER: {e}") | |
| # if not meds: | |
| # meds = [ | |
| # line.strip() | |
| # for line in final_ocr_text.splitlines() | |
| # if line.strip() | |
| # ] | |
| # print("[NER] (Dots.OCR) Fallback to lines, count:", len(meds)) | |
| # else: # Chandra-OCR | |
| # meds = [ | |
| # line.strip() | |
| # for line in final_ocr_text.splitlines() | |
| # if line.strip() | |
| # ] | |
| # print("[NER] (Chandra-OCR) Line-based meds only, count:", len(meds)) | |
| # print("[DEBUG] meds count:", len(meds)) | |
| # print("[DEBUG] drug_xlsx_path in generate_image:", drug_xlsx_path) | |
| # # -------------------------------------------------------- | |
| # # 3) Markdown: Medications only (no Raw OCR section) | |
| # # -------------------------------------------------------- | |
| # md = "### Medications (extracted)\n" | |
| # if meds: | |
| # for m in meds: | |
| # md += f"- {m}\n" | |
| # else: | |
| # md += "- None detected\n" | |
| # # -------------------------------------------------------- | |
| # # 4) Spell-check (med list) with CER | |
| # # -------------------------------------------------------- | |
| # spell_section = "\n---\n### Spell-check suggestions (" + spell_algo + ")\n" | |
| # corr: Dict[str, List] = {} | |
| # if BACKEND_INIT_ERROR: | |
| # spell_section += f"- [DEBUG] Backend init error: {BACKEND_INIT_ERROR}\n" | |
| # try: | |
| # if meds and drug_xlsx_path: | |
| # try: | |
| # df_dbg = pd.read_excel(drug_xlsx_path) | |
| # print( | |
| # f"[Spell DEBUG] Excel read OK: path={drug_xlsx_path}, " | |
| # f"shape={df_dbg.shape}, cols={list(df_dbg.columns)}" | |
| # ) | |
| # spell_section += ( | |
| # f"- [DEBUG] Excel read OK; shape={df_dbg.shape}, " | |
| # f"cols={list(df_dbg.columns)}\n" | |
| # ) | |
| # except Exception as e: | |
| # print(f"[Spell DEBUG] ERROR reading Excel in generate_image: {e}") | |
| # spell_section += f"- [DEBUG] Excel read error: {e}\n" | |
| # if ( | |
| # spell_algo == "TF-IDF + Phonetic" | |
| # and "TfidfPhoneticMatcher" in priv_classes | |
| # ): | |
| # print("[Spell DEBUG] Using TfidfPhoneticMatcher") | |
| # Cls = priv_classes["TfidfPhoneticMatcher"] | |
| # checker = Cls( | |
| # xlsx_path=drug_xlsx_path, | |
| # column="Combined_Drugs", | |
| # ngram_size=3, | |
| # phonetic_weight=0.4, | |
| # ) | |
| # corr = checker.match_list(meds, top_k=5, tfidf_threshold=0.15) | |
| # elif spell_algo == "SymSpell" and "SymSpellMatcher" in priv_classes: | |
| # print("[Spell DEBUG] Using SymSpellMatcher") | |
| # Cls = priv_classes["SymSpellMatcher"] | |
| # checker = Cls( | |
| # xlsx_path=drug_xlsx_path, | |
| # column="Combined_Drugs", | |
| # max_edit=2, | |
| # prefix_len=7, | |
| # ) | |
| # corr = checker.match_list(meds, top_k=5, min_score=0.4) | |
| # elif spell_algo == "RapidFuzz" and "RapidFuzzMatcher" in priv_classes: | |
| # print("[Spell DEBUG] Using RapidFuzzMatcher") | |
| # Cls = priv_classes["RapidFuzzMatcher"] | |
| # checker = Cls(xlsx_path=drug_xlsx_path, column="Combined_Drugs") | |
| # corr = checker.match_list(meds, top_k=5, threshold=70.0) | |
| # else: | |
| # spell_section += ( | |
| # "- Spell-check backend unavailable " | |
| # "(no matcher class for selected algorithm).\n" | |
| # ) | |
| # else: | |
| # if not meds: | |
| # spell_section += "- No medications extracted (empty med list).\n" | |
| # if not drug_xlsx_path: | |
| # spell_section += ( | |
| # "- Drug Excel dictionary path missing " | |
| # "(drug_xlsx_path is None).\n" | |
| # ) | |
| # except Exception as e: | |
| # print(f"[Spell DEBUG] Spell-check error: {e}") | |
| # spell_section += f"- Spell-check error: {e}\n" | |
| # if corr: | |
| # for raw in meds: | |
| # suggestions = corr.get(raw, []) | |
| # if suggestions: | |
| # spell_section += f"- **{raw}**\n" | |
| # for cand, score in suggestions: | |
| # cer = character_error_rate(cand, raw) | |
| # spell_section += ( | |
| # f" - {cand} (score={score:.3f}, CER={cer:.3f}%)\n" | |
| # ) | |
| # else: | |
| # spell_section += f"- **{raw}**\n - (no suggestions)\n" | |
| # final_md = md + spell_section | |
| # return final_md | |
| # except Exception as e: | |
| # # Catch-all so the GPU worker does not crash | |
| # print(f"[ERROR] generate_image crashed: {e}") | |
| # import traceback | |
| # traceback.print_exc() | |
| # return f"Error while processing: {e}" | |
| # # ============================================================ | |
| # # UI | |
| # # ============================================================ | |
| # image_examples = [ | |
| # ["examples/test_1.jpeg"], | |
| # ["examples/test_4.jpeg"], | |
| # ["examples/test_5.jpeg"], | |
| # ] | |
| # with gr.Blocks(css=css, theme=steel_blue_theme) as demo: | |
| # gr.Markdown( | |
| # "# **Handwritten Doctor's Prescription Reading**", elem_id="main-title" | |
| # ) | |
| # with gr.Row(): | |
| # with gr.Column(scale=2): | |
| # image_upload = gr.Image( | |
| # type="pil", label="Upload Image", height=290 | |
| # ) | |
| # image_submit = gr.Button("Submit", variant="primary") | |
| # gr.Examples( | |
| # examples=image_examples, | |
| # inputs=[image_upload], | |
| # label="Example Images", | |
| # ) | |
| # spell_choice = gr.Radio( | |
| # choices=["TF-IDF + Phonetic", "SymSpell", "RapidFuzz"], | |
| # label="Select Spell-check Approach", | |
| # value="TF-IDF + Phonetic", | |
| # ) | |
| # with gr.Accordion("Advanced options", open=False): | |
| # max_new_tokens = gr.Slider( | |
| # label="Max new tokens", | |
| # minimum=1, | |
| # maximum=MAX_MAX_NEW_TOKENS, | |
| # step=1, | |
| # value=DEFAULT_MAX_NEW_TOKENS, | |
| # ) | |
| # temperature = gr.Slider( | |
| # label="Temperature", | |
| # minimum=0.1, | |
| # maximum=4.0, | |
| # step=0.1, | |
| # value=0.7, | |
| # ) | |
| # top_p = gr.Slider( | |
| # label="Top-p (nucleus sampling)", | |
| # minimum=0.05, | |
| # maximum=1.0, | |
| # step=0.05, | |
| # value=0.9, | |
| # ) | |
| # top_k = gr.Slider( | |
| # label="Top-k", | |
| # minimum=1, | |
| # maximum=1000, | |
| # step=1, | |
| # value=50, | |
| # ) | |
| # repetition_penalty = gr.Slider( | |
| # label="Repetition penalty", | |
| # minimum=1.0, | |
| # maximum=2.0, | |
| # step=0.05, | |
| # value=1.1, | |
| # ) | |
| # with gr.Column(scale=3): | |
| # gr.Markdown("## Output", elem_id="output-title") | |
| # with gr.Accordion("(Result.md)", open=True): | |
| # markdown_output = gr.Markdown(label="(Result.Md)") | |
| # model_choice = gr.Radio( | |
| # choices=["Chandra-OCR", "Dots.OCR"], | |
| # label="Select OCR Model", | |
| # value="Chandra-OCR", | |
| # ) | |
| # # Hard-coded query text (passed into the 'text' parameter) | |
| # query_state = gr.State( | |
| # "Extract medicine or drugs names along with dosage amount or quantity" | |
| # ) | |
| # image_submit.click( | |
| # fn=generate_image, | |
| # inputs=[ | |
| # model_choice, | |
| # query_state, | |
| # image_upload, | |
| # max_new_tokens, | |
| # temperature, | |
| # top_p, | |
| # top_k, | |
| # repetition_penalty, | |
| # spell_choice, | |
| # ], | |
| # outputs=[markdown_output], | |
| # ) | |
| # if __name__ == "__main__": | |
| # demo.queue(max_size=50).launch( | |
| # mcp_server=True, ssr_mode=False, show_error=True | |
| # ) | |
| ###################################### version 4 ######################################### | |
| # import os | |
| # import time | |
| # from threading import Thread | |
| # from typing import Iterable, Dict, Any, Optional, List | |
| # import gradio as gr | |
| # import spaces | |
| # import torch | |
| # from PIL import Image | |
| # import pandas as pd # Excel read + debug | |
| # from transformers import ( | |
| # Qwen3VLForConditionalGeneration, | |
| # AutoModelForCausalLM, | |
| # AutoProcessor, | |
| # TextIteratorStreamer, | |
| # ) | |
| # from gradio.themes import Soft | |
| # from gradio.themes.utils import colors, fonts, sizes | |
| # # ============================================================ | |
| # # Character Error Rate (CER) | |
| # # ============================================================ | |
| # def levenshtein(a: str, b: str) -> int: | |
| # """Case-insensitive Levenshtein edit distance between a and b, used to calculate CER.""" | |
| # a, b = a.lower(), b.lower() | |
| # if a == b: | |
| # return 0 | |
| # if not a: | |
| # return len(b) | |
| # if not b: | |
| # return len(a) | |
| # dp = list(range(len(b) + 1)) | |
| # for i, ca in enumerate(a, 1): | |
| # prev = dp[0] | |
| # dp[0] = i | |
| # for j, cb in enumerate(b, 1): | |
| # cur = dp[j] | |
| # cost = 0 if ca == cb else 1 | |
| # dp[j] = min(dp[j] + 1, dp[j - 1] + 1, prev + cost) | |
| # prev = cur | |
| # return dp[-1] | |
| # def character_error_rate(pred: str, target: str) -> float: | |
| # """Calculate the Character Error Rate (CER) in percent: edit distance divided by len(target); 0.0 when target is empty.""" | |
| # target = target or "" | |
| # distance = levenshtein(pred, target) | |
| # return (distance / len(target)) * 100 if len(target) > 0 else 0.0 | |
| # # ============================================================ | |
| # # Private repo: dynamic import + Excel download | |
| # # ============================================================ | |
| # import importlib.util | |
| # from huggingface_hub import hf_hub_download | |
| # REPO_ID = "IFMedTech/Medibot_OCR_model" # private backend repo | |
| # # Filenames in the repo → class names they define | |
| # PY_MODULES: Dict[str, str] = { | |
| # "clinical_NER.py": "ClinicalNER", | |
| # "tf_idf_phonetic.py": "TfidfPhoneticMatcher", | |
| # "symspell_matcher.py": "SymSpellMatcher", | |
| # "rapid_fuzz_matcher.py": "RapidFuzzMatcher", | |
| # } | |
| # HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN") # must be set in Space secrets | |
| # def _dynamic_import(module_path: str, class_name: str): | |
| # spec = importlib.util.spec_from_file_location(class_name, module_path) | |
| # module = importlib.util.module_from_spec(spec) | |
| # spec.loader.exec_module(module) # type: ignore | |
| # return getattr(module, class_name) | |
| # priv_classes: Dict[str, Any] = {} | |
| # drug_xlsx_path: Optional[str] = None | |
| # BACKEND_INIT_ERROR: Optional[str] = None | |
| # print("[Private] HF_TOKEN present?:", HF_TOKEN is not None) | |
| # if HF_TOKEN is None: | |
| # BACKEND_INIT_ERROR = "HUGGINGFACE_TOKEN env var is not set in this Space." | |
| # print("[Private] WARNING:", BACKEND_INIT_ERROR) | |
| # else: | |
| # print(f"[Private] Using repo: {REPO_ID}") | |
| # # 1) Load python modules (best-effort: failure of one file will not block others) | |
| # for fname, cls_name in PY_MODULES.items(): | |
| # try: | |
| # print(f"[Private] Downloading module file: {fname}") | |
| # path = hf_hub_download( | |
| # repo_id=REPO_ID, | |
| # filename=fname, | |
| # token=HF_TOKEN, | |
| # repo_type="model", | |
| # ) | |
| # priv_classes[cls_name] = _dynamic_import(path, cls_name) | |
| # print(f"[Private] Loaded class {cls_name} from {fname}") | |
| # except Exception as e: | |
| # msg = f"Failed to load {fname}: {e}" | |
| # print("[Private]", msg) | |
| # BACKEND_INIT_ERROR = (BACKEND_INIT_ERROR or "") + f" | {msg}" | |
| # # 2) Load Excel dictionary | |
| # try: | |
| # print("[Private] Downloading Excel file: Medibot_Drugs_Cleaned_Updated.xlsx") | |
| # drug_xlsx_path = hf_hub_download( | |
| # repo_id=REPO_ID, | |
| # filename="Medibot_Drugs_Cleaned_Updated.xlsx", | |
| # token=HF_TOKEN, | |
| # repo_type="model", | |
| # ) | |
| # print(f"[Private] Downloaded Excel at: {drug_xlsx_path}") | |
| # # Debug: verify read | |
| # df_debug = pd.read_excel(drug_xlsx_path, nrows=3) | |
| # print( | |
| # f"[Private] Excel loaded successfully. " | |
| # f"Shape={df_debug.shape}, cols={list(df_debug.columns)}" | |
| # ) | |
| # except Exception as e: | |
| # msg = f"ERROR loading Excel: {e}" | |
| # print("[Private]", msg) | |
| # BACKEND_INIT_ERROR = (BACKEND_INIT_ERROR or "") + f" | {msg}" | |
| # drug_xlsx_path = None | |
| # # ============================================================ | |
| # # THEME | |
| # # ============================================================ | |
| # colors.steel_blue = colors.Color( | |
| # name="steel_blue", | |
| # c50="#EBF3F8", | |
| # c100="#D3E5F0", | |
| # c200="#A8CCE1", | |
| # c300="#7DB3D2", | |
| # c400="#529AC3", | |
| # c500="#4682B4", | |
| # c600="#3E72A0", | |
| # c700="#36638C", | |
| # c800="#2E5378", | |
| # c900="#264364", | |
| # c950="#1E3450", | |
| # ) | |
| # class SteelBlueTheme(Soft): | |
| # def __init__( | |
| # self, | |
| # *, | |
| # primary_hue: colors.Color | str = colors.gray, | |
| # secondary_hue: colors.Color | str = colors.steel_blue, | |
| # neutral_hue: colors.Color | str = colors.slate, | |
| # text_size: sizes.Size | str = sizes.text_lg, | |
| # font: fonts.Font | str | Iterable[fonts.Font | str] = ( | |
| # fonts.GoogleFont("Outfit"), | |
| # "Arial", | |
| # "sans-serif", | |
| # ), | |
| # font_mono: fonts.Font | str | Iterable[fonts.Font | str] = ( | |
| # fonts.GoogleFont("IBM Plex Mono"), | |
| # "ui-monospace", | |
| # "monospace", | |
| # ), | |
| # ): | |
| # super().__init__( | |
| # primary_hue=primary_hue, | |
| # secondary_hue=secondary_hue, | |
| # neutral_hue=neutral_hue, | |
| # text_size=text_size, | |
| # font=font, | |
| # font_mono=font_mono, | |
| # ) | |
| # super().set( | |
| # background_fill_primary="*primary_50", | |
| # background_fill_primary_dark="*primary_900", | |
| # body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)", | |
| # body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)", | |
| # button_primary_text_color="white", | |
| # button_primary_text_color_hover="white", | |
| # button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)", | |
| # button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)", | |
| # button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)", | |
| # button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_500)", | |
| # button_secondary_text_color="black", | |
| # button_secondary_text_color_hover="white", | |
| # button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)", | |
| # button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)", | |
| # button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)", | |
| # button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)", | |
| # slider_color="*secondary_500", | |
| # slider_color_dark="*secondary_600", | |
| # block_title_text_weight="600", | |
| # block_border_width="3px", | |
| # block_shadow="*shadow_drop_lg", | |
| # button_primary_shadow="*shadow_drop_lg", | |
| # button_large_padding="11px", | |
| # color_accent_soft="*primary_100", | |
| # block_label_background_fill="*primary_200", | |
| # ) | |
| # steel_blue_theme = SteelBlueTheme() | |
| # css = """ | |
| # #main-title h1 { font-size: 2.3em !important; } | |
| # #output-title h2 { font-size: 2.1em !important; } | |
| # """ | |
| # # ============================================================ | |
| # # RUNTIME / DEVICE | |
| # # ============================================================ | |
| # os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0") | |
| # print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES")) | |
| # print("torch.__version__ =", torch.__version__) | |
| # print("torch.version.cuda =", torch.version.cuda) | |
| # print("cuda available =", torch.cuda.is_available()) | |
| # print("cuda device count =", torch.cuda.device_count()) | |
| # if torch.cuda.is_available(): | |
| # print("using device =", torch.cuda.get_device_name(0)) | |
| # use_cuda = torch.cuda.is_available() | |
| # device = torch.device("cuda:0" if use_cuda else "cpu") | |
| # if use_cuda: | |
| # torch.backends.cudnn.benchmark = True | |
| # DTYPE_FP16 = torch.float16 if use_cuda else torch.float32 | |
| # DTYPE_BF16 = torch.bfloat16 if use_cuda else torch.float32 | |
| # # ============================================================ | |
| # # OCR MODELS: Chandra-OCR + Dots.OCR | |
| # # ============================================================ | |
| # # 1) Chandra-OCR (Qwen3VL) | |
| # MODEL_ID_V = "datalab-to/chandra" | |
| # processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True) | |
| # model_v = Qwen3VLForConditionalGeneration.from_pretrained( | |
| # MODEL_ID_V, trust_remote_code=True, torch_dtype=DTYPE_FP16 | |
| # ).to(device).eval() | |
| # # 2) Dots.OCR (flash_attn2 if available, else SDPA) | |
| # MODEL_PATH_D = "prithivMLmods/Dots.OCR-Latest-BF16" | |
| # processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True) | |
| # attn_impl = "sdpa" | |
| # try: | |
| # import flash_attn # noqa: F401 | |
| # if use_cuda: | |
| # attn_impl = "flash_attention_2" | |
| # except Exception: | |
| # attn_impl = "sdpa" | |
| # model_d = AutoModelForCausalLM.from_pretrained( | |
| # MODEL_PATH_D, | |
| # attn_implementation=attn_impl, | |
| # torch_dtype=DTYPE_BF16, | |
| # device_map="auto" if use_cuda else None, | |
| # trust_remote_code=True, | |
| # ).eval() | |
| # if not use_cuda: | |
| # model_d.to(device) | |
| # # ============================================================ | |
| # # GENERATION (OCR → Med extraction → Spell-check + CER) | |
| # # ClinicalNER is used ONLY for Dots.OCR. | |
| # # ============================================================ | |
| # MAX_MAX_NEW_TOKENS = 4096 | |
| # DEFAULT_MAX_NEW_TOKENS = 2048 | |
| # @spaces.GPU # you can add duration=... if you hit timeouts | |
| # def generate_image( | |
| # model_name: str, | |
| # text: str, | |
| # image: Image.Image, | |
| # max_new_tokens: int, | |
| # temperature: float, | |
| # top_p: float, | |
| # top_k: int, | |
| # repetition_penalty: float, | |
| # spell_algo: str, | |
| # ): | |
| # """ | |
| # 1) Stream OCR tokens to Raw output. | |
| # 2) For Dots.OCR: run ClinicalNER → meds list (with fallback to line-based). | |
| # For Chandra-OCR: DO NOT call ClinicalNER; meds from OCR lines only. | |
| # 3) Apply selected spell-check algorithm on meds using Excel dictionary. | |
| # 4) Compute CER for each suggestion and display in markdown. | |
| # """ | |
| # if image is None: | |
| # yield "Please upload an image.", "Please upload an image." | |
| # return | |
| # # Choose processor/model | |
| # if model_name == "Chandra-OCR": | |
| # processor, model = processor_v, model_v | |
| # elif model_name == "Dots.OCR": | |
| # processor, model = processor_d, model_d | |
| # else: | |
| # yield "Invalid model selected.", "Invalid model selected." | |
| # return | |
| # # Prompt (text is provided via gr.State) | |
| # messages = [ | |
| # { | |
| # "role": "user", | |
| # "content": [ | |
| # {"type": "image"}, | |
| # {"type": "text", "text": text}, | |
| # ], | |
| # } | |
| # ] | |
| # prompt_full = processor.apply_chat_template( | |
| # messages, tokenize=False, add_generation_prompt=True | |
| # ) | |
| # # Preprocess | |
| # inputs = processor( | |
| # text=[prompt_full], images=[image], return_tensors="pt", padding=True | |
| # ) | |
| # inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()} | |
| # # Streamer | |
| # tokenizer = getattr(processor, "tokenizer", None) or processor | |
| # streamer = TextIteratorStreamer( | |
| # tokenizer, skip_prompt=True, skip_special_tokens=True | |
| # ) | |
| # gen_kwargs = dict( | |
| # **inputs, | |
| # streamer=streamer, | |
| # max_new_tokens=max_new_tokens, | |
| # do_sample=True, | |
| # temperature=temperature, | |
| # top_p=top_p, | |
| # top_k=top_k, | |
| # repetition_penalty=repetition_penalty, | |
| # ) | |
| # # Start generation in background thread | |
| # thread = Thread(target=model.generate, kwargs=gen_kwargs) | |
| # thread.start() | |
| # # 1) Live OCR streaming : show raw text while generating | |
| # buffer = "" | |
| # for new_text in streamer: | |
| # buffer += new_text.replace("<|im_end|>", "") | |
| # time.sleep(0.01) | |
| # md_stream = "Processing..." | |
| # yield buffer,md_stream #buffer # two outputs: raw + md (same during stream) | |
| # final_ocr_text = buffer.strip() | |
| # # -------------------------------------------------------- | |
| # # 2) Medications extraction | |
| # # -------------------------------------------------------- | |
| # meds: List[str] = [] | |
| # if model_name == "Dots.OCR": | |
| # # ClinicalNER ONLY for Dots.OCR | |
| # try: | |
| # if "ClinicalNER" in priv_classes and HF_TOKEN is not None: | |
| # ClinicalNER = priv_classes["ClinicalNER"] | |
| # ner = ClinicalNER(token=HF_TOKEN) | |
| # ner_output = ner(final_ocr_text) or [] | |
| # meds = [ | |
| # m.strip() | |
| # for m in ner_output | |
| # if isinstance(m, str) and m.strip() | |
| # ] | |
| # print("[NER] (Dots.OCR) ClinicalNER meds:", meds) | |
| # else: | |
| # print("[NER] ClinicalNER unavailable or missing HF token; skipping.") | |
| # except Exception as e: | |
| # print(f"[NER] Error running ClinicalNER: {e}") | |
| # # Fallback if ClinicalNER returns nothing | |
| # if not meds: | |
| # meds = [ | |
| # line.strip() | |
| # for line in final_ocr_text.splitlines() | |
| # if line.strip() | |
| # ] | |
| # print("[NER] (Dots.OCR) Fallback to lines, count:", len(meds)) | |
| # elif model_name == "Chandra-OCR": | |
| # # NO ClinicalNER for Chandra; just use text lines | |
| # meds = [ | |
| # line.strip() | |
| # for line in final_ocr_text.splitlines() | |
| # if line.strip() | |
| # ] | |
| # print("[NER] (Chandra-OCR) Line-based meds only, count:", len(meds)) | |
| # print("[DEBUG] meds count:", len(meds)) | |
| # print("[DEBUG] drug_xlsx_path in generate_image:", drug_xlsx_path) | |
| # # -------------------------------------------------------- | |
| # # 3) Build Markdown base: OCR text + med list | |
| # # -------------------------------------------------------- | |
| # # md = "### Raw OCR Output\n" | |
| # # md += "```\n" + (final_ocr_text or "(empty)") + "\n```\n" | |
| # md="" | |
| # md += "\n---\n### Medications (extracted)\n" | |
| # if meds: | |
| # for m in meds: | |
| # md += f"- {m}\n" | |
| # else: | |
| # md += "- None detected\n" | |
| # # -------------------------------------------------------- | |
| # # 4) Spell-check (med list) with CER | |
| # # -------------------------------------------------------- | |
| # spell_section = "\n---\n### Spell-check suggestions (" + spell_algo + ")\n" | |
| # corr: Dict[str, List] = {} | |
| # if BACKEND_INIT_ERROR: | |
| # spell_section += f"- [DEBUG] Backend init error: {BACKEND_INIT_ERROR}\n" | |
| # try: | |
| # if meds and drug_xlsx_path: | |
| # # Optional Excel debug read | |
| # try: | |
| # df_dbg = pd.read_excel(drug_xlsx_path) | |
| # print( | |
| # f"[Spell DEBUG] Excel read OK: path={drug_xlsx_path}, " | |
| # f"shape={df_dbg.shape}, cols={list(df_dbg.columns)}" | |
| # ) | |
| # spell_section += ( | |
| # f"- [DEBUG] Excel read OK; shape={df_dbg.shape}, " | |
| # f"cols={list(df_dbg.columns)}\n" | |
| # ) | |
| # except Exception as e: | |
| # print(f"[Spell DEBUG] ERROR reading Excel in generate_image: {e}") | |
| # spell_section += f"- [DEBUG] Excel read error: {e}\n" | |
| # # Pick matcher based on spell_algo | |
| # if ( | |
| # spell_algo == "TF-IDF + Phonetic" | |
| # and "TfidfPhoneticMatcher" in priv_classes | |
| # ): | |
| # print("[Spell DEBUG] Using TfidfPhoneticMatcher") | |
| # Cls = priv_classes["TfidfPhoneticMatcher"] | |
| # checker = Cls( | |
| # xlsx_path=drug_xlsx_path, | |
| # column="Combined_Drugs", | |
| # ngram_size=3, | |
| # phonetic_weight=0.4, | |
| # ) | |
| # corr = checker.match_list(meds, top_k=5, tfidf_threshold=0.15) | |
| # elif spell_algo == "SymSpell" and "SymSpellMatcher" in priv_classes: | |
| # print("[Spell DEBUG] Using SymSpellMatcher") | |
| # Cls = priv_classes["SymSpellMatcher"] | |
| # checker = Cls( | |
| # xlsx_path=drug_xlsx_path, | |
| # column="Combined_Drugs", | |
| # max_edit=2, | |
| # prefix_len=7, | |
| # ) | |
| # corr = checker.match_list(meds, top_k=5, min_score=0.4) | |
| # elif spell_algo == "RapidFuzz" and "RapidFuzzMatcher" in priv_classes: | |
| # print("[Spell DEBUG] Using RapidFuzzMatcher") | |
| # Cls = priv_classes["RapidFuzzMatcher"] | |
| # checker = Cls(xlsx_path=drug_xlsx_path, column="Combined_Drugs") | |
| # corr = checker.match_list(meds, top_k=5, threshold=70.0) | |
| # else: | |
| # spell_section += ( | |
| # "- Spell-check backend unavailable " | |
| # "(no matcher class for selected algorithm).\n" | |
| # ) | |
| # else: | |
| # if not meds: | |
| # spell_section += "- No medications extracted (empty med list).\n" | |
| # if not drug_xlsx_path: | |
| # spell_section += ( | |
| # "- Drug Excel dictionary path missing " | |
| # "(drug_xlsx_path is None).\n" | |
| # ) | |
| # except Exception as e: | |
| # print(f"[Spell DEBUG] Spell-check error: {e}") | |
| # spell_section += f"- Spell-check error: {e}\n" | |
| # # Format suggestions (top-5 per med, with scores + CER) | |
| # if corr: | |
| # for raw in meds: | |
| # suggestions = corr.get(raw, []) | |
| # if suggestions: | |
| # spell_section += f"- **{raw}**\n" | |
| # for cand, score in suggestions: | |
| # cer = character_error_rate(cand, raw) | |
| # spell_section += ( | |
| # f" - {cand} (score={score:.3f}, CER={cer:.3f}%)\n" | |
| # ) | |
| # else: | |
| # spell_section += f"- **{raw}**\n - (no suggestions)\n" | |
| # final_md = md + spell_section | |
| # # Final yield: raw OCR text + full markdown | |
| # # yield final_ocr_text, final_md | |
| # yield final_md | |
| # # ============================================================ | |
| # # UI | |
| # # ============================================================ | |
| # image_examples = [ | |
| # ["examples/test_1.jpeg"], | |
| # ["examples/test_4.jpeg"], | |
| # ["examples/test_5.jpeg"], | |
| # ] | |
| # with gr.Blocks(css=css, theme=steel_blue_theme) as demo: | |
| # gr.Markdown( | |
| # "# **Handwritten Doctor's Prescription Reading**", elem_id="main-title" | |
| # ) | |
| # with gr.Row(): | |
| # with gr.Column(scale=2): | |
| # image_upload = gr.Image( | |
| # type="pil", label="Upload Image", height=290 | |
| # ) | |
| # image_submit = gr.Button("Submit", variant="primary") | |
| # gr.Examples( | |
| # examples=image_examples, | |
| # inputs=[image_upload], | |
| # label="Example Images", | |
| # ) | |
| # spell_choice = gr.Radio( | |
| # choices=["TF-IDF + Phonetic", "SymSpell", "RapidFuzz"], | |
| # label="Select Spell-check Approach", | |
| # value="TF-IDF + Phonetic", | |
| # ) | |
| # with gr.Accordion("Advanced options", open=False): | |
| # max_new_tokens = gr.Slider( | |
| # label="Max new tokens", | |
| # minimum=1, | |
| # maximum=MAX_MAX_NEW_TOKENS, | |
| # step=1, | |
| # value=DEFAULT_MAX_NEW_TOKENS, | |
| # ) | |
| # temperature = gr.Slider( | |
| # label="Temperature", | |
| # minimum=0.1, | |
| # maximum=4.0, | |
| # step=0.1, | |
| # value=0.7, | |
| # ) | |
| # top_p = gr.Slider( | |
| # label="Top-p (nucleus sampling)", | |
| # minimum=0.05, | |
| # maximum=1.0, | |
| # step=0.05, | |
| # value=0.9, | |
| # ) | |
| # top_k = gr.Slider( | |
| # label="Top-k", | |
| # minimum=1, | |
| # maximum=1000, | |
| # step=1, | |
| # value=50, | |
| # ) | |
| # repetition_penalty = gr.Slider( | |
| # label="Repetition penalty", | |
| # minimum=1.0, | |
| # maximum=2.0, | |
| # step=0.05, | |
| # value=1.1, | |
| # ) | |
| # with gr.Column(scale=3): | |
| # gr.Markdown("## Output", elem_id="output-title") | |
| # # output = gr.Textbox( | |
| # # label="Raw Output Stream", | |
| # # interactive=False, | |
| # # lines=11, | |
| # # show_copy_button=True, | |
| # # ) | |
| # with gr.Accordion("(Result.md)", open=False): | |
| # markdown_output = gr.Markdown(label="(Result.Md)") | |
| # model_choice = gr.Radio( | |
| # choices=["Chandra-OCR", "Dots.OCR"], | |
| # label="Select OCR Model", | |
| # value="Chandra-OCR", | |
| # ) | |
| # # Hard-coded query text (passed into the 'text' parameter) | |
| # query_state = gr.State( | |
| # "Extract medicine or drugs names along with dosage amount or quantity" | |
| # ) | |
| # image_submit.click( | |
| # fn=generate_image, | |
| # inputs=[ | |
| # model_choice, | |
| # query_state, | |
| # image_upload, | |
| # max_new_tokens, | |
| # temperature, | |
| # top_p, | |
| # top_k, | |
| # repetition_penalty, | |
| # spell_choice, | |
| # ], | |
| # outputs=[markdown_output], | |
| # ) | |
| # if __name__ == "__main__": | |
| # demo.queue(max_size=50).launch( | |
| # mcp_server=True, ssr_mode=False, show_error=True | |
| # ) | |
| ################################### version 3 ######################################## | |
| import os | |
| import time | |
| from threading import Thread | |
| from typing import Iterable, Dict, Any, Optional, List | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from PIL import Image | |
| import pandas as pd # for reading Excel and debugging | |
| from transformers import ( | |
| Qwen3VLForConditionalGeneration, | |
| AutoModelForCausalLM, | |
| AutoProcessor, | |
| TextIteratorStreamer, | |
| ) | |
| from gradio.themes import Soft | |
| from gradio.themes.utils import colors, fonts, sizes | |
| # ----------------------------- | |
| # Character Error Rate (CER) | |
| # ----------------------------- | |
def levenshtein(a: str, b: str) -> int:
    """Return the case-insensitive Levenshtein distance between ``a`` and ``b``.

    Both inputs are lower-cased before comparison. ``None`` is treated as the
    empty string so that a missing prediction or reference does not raise.

    Args:
        a: First string (or ``None``).
        b: Second string (or ``None``).

    Returns:
        The minimum number of single-character insertions, deletions and
        substitutions needed to turn ``a`` into ``b``.
    """
    a = (a or "").lower()
    b = (b or "").lower()
    if a == b:
        return 0
    if not a:
        return len(b)
    if not b:
        return len(a)

    # Single-row dynamic programme: dp[j] holds the distance between the
    # prefix of ``a`` processed so far and the first j characters of ``b``.
    dp = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        prev = dp[0]  # value of the diagonal cell from the previous row
        dp[0] = i
        for j, cb in enumerate(b, 1):
            cur = dp[j]
            cost = 0 if ca == cb else 1
            dp[j] = min(dp[j] + 1, dp[j - 1] + 1, prev + cost)
            prev = cur
    return dp[-1]
def character_error_rate(pred: str, target: str) -> float:
    """Return the Character Error Rate (CER) of ``pred`` against ``target`` in percent.

    The CER is the Levenshtein distance between the two strings divided by the
    length of ``target``, multiplied by 100.

    Args:
        pred: The predicted/candidate string; ``None`` is treated as empty.
        target: The reference string; ``None`` is treated as empty.

    Returns:
        The CER as a percentage. When ``target`` is empty the function returns
        ``0.0`` rather than dividing by zero.
    """
    # Guard both operands; previously only ``target`` was protected, so a
    # ``None`` prediction crashed inside ``levenshtein``.
    pred = pred or ""
    target = target or ""
    if not target:
        return 0.0
    distance = levenshtein(pred, target)
    return (distance / len(target)) * 100
| # ----------------------------- | |
| # Private repo: dynamic import | |
| # ----------------------------- | |
| import importlib.util | |
| from huggingface_hub import hf_hub_download | |
REPO_ID = "IFMedTech/Medibot_OCR_model"  # private backend repo

# Map filenames in the private repo to the class each one exports.
PY_MODULES = {
    "ner.py": "ClinicalNER",  # NER is only applied for Dots.OCR output
    "tfidf_phonetic.py": "TfidfPhoneticMatcher",
    "symspell_matcher.py": "SymSpellMatcher",
    "rapidfuzz_matcher.py": "RapidFuzzMatcher",
    # 'Medibot_Drugs_Cleaned_Updated.xlsx' is data, not a module
}

# Access token for the private repo. ``HUGGINGFACE_TOKEN`` keeps priority;
# ``HF_TOKEN`` is accepted as a fallback. An empty value is normalised to
# ``None`` so the ``HF_TOKEN is None`` checks elsewhere treat it as "not set".
HF_TOKEN = (
    os.environ.get("HUGGINGFACE_TOKEN") or os.environ.get("HF_TOKEN") or None
)
| def _dynamic_import(module_path: str, class_name: str): | |
| spec = importlib.util.spec_from_file_location(class_name, module_path) | |
| module = importlib.util.module_from_spec(spec) | |
| spec.loader.exec_module(module) # type: ignore | |
| return getattr(module, class_name) | |
| # Load private classes and Excel dictionary (once at import time) | |
# Registry of dynamically imported backend classes, keyed by class name, and
# the local path of the downloaded drug dictionary. Both stay empty/None when
# the token is missing or anything in the download/import sequence fails.
priv_classes: Dict[str, Any] = {}
drug_xlsx_path: Optional[str] = None

if HF_TOKEN is None:
    print("[Private] WARNING: HUGGINGFACE_TOKEN not set; NER/Spell-check will be unavailable.")
else:
    try:
        # Fetch each backend module from the private repo and import its class.
        for fname, cls in PY_MODULES.items():
            path = hf_hub_download(repo_id=REPO_ID, filename=fname, token=HF_TOKEN)
            if not cls:
                continue
            priv_classes[cls] = _dynamic_import(path, cls)
            print(f"[Private] Loaded class: {cls} from {fname}")

        drug_xlsx_path = hf_hub_download(
            repo_id=REPO_ID,
            filename="Medibot_Drugs_Cleaned_Updated.xlsx",
            token=HF_TOKEN,
        )
        print(f"[Private] Downloaded Excel at: {drug_xlsx_path}")

        # DEBUG: read the workbook once and report its shape; a failure here
        # is logged but does not disable the backend.
        try:
            df_debug = pd.read_excel(drug_xlsx_path)
        except Exception as e:
            print(f"[Private] ERROR reading Excel for debug: {e}")
        else:
            print(f"[Private] Excel loaded successfully. Shape: {df_debug.shape}")
    except Exception as e:
        # Any download/import failure disables the whole private backend.
        print(f"[Private] ERROR loading private backend: {e}")
        priv_classes = {}
        drug_xlsx_path = None
| # ---------------------------- | |
| # THEME | |
| # ---------------------------- | |
# Register a custom "steel_blue" palette on gradio's colors module so it can
# be referenced like the built-in hues (shade name -> hex value).
_STEEL_BLUE_SHADES = {
    "c50": "#EBF3F8",
    "c100": "#D3E5F0",
    "c200": "#A8CCE1",
    "c300": "#7DB3D2",
    "c400": "#529AC3",
    "c500": "#4682B4",
    "c600": "#3E72A0",
    "c700": "#36638C",
    "c800": "#2E5378",
    "c900": "#264364",
    "c950": "#1E3450",
}
colors.steel_blue = colors.Color(name="steel_blue", **_STEEL_BLUE_SHADES)
class SteelBlueTheme(Soft):
    """``Soft``-derived Gradio theme with gray primary and steel-blue secondary hues.

    All constructor arguments are keyword-only and are forwarded unchanged to
    ``Soft.__init__``; afterwards a fixed set of style variables is overridden.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.steel_blue,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"),
            "Arial",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )

        # Page and body backgrounds.
        backgrounds = {
            "background_fill_primary": "*primary_50",
            "background_fill_primary_dark": "*primary_900",
            "body_background_fill": "linear-gradient(135deg, *primary_200, *primary_100)",
            "body_background_fill_dark": "linear-gradient(135deg, *primary_900, *primary_800)",
        }
        # Primary (steel-blue) buttons.
        primary_buttons = {
            "button_primary_text_color": "white",
            "button_primary_text_color_hover": "white",
            "button_primary_background_fill": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_dark": "linear-gradient(90deg, *secondary_600, *secondary_800)",
            "button_primary_background_fill_hover_dark": "linear-gradient(90deg, *secondary_500, *secondary_500)",
        }
        # Secondary (gray) buttons.
        secondary_buttons = {
            "button_secondary_text_color": "black",
            "button_secondary_text_color_hover": "white",
            "button_secondary_background_fill": "linear-gradient(90deg, *primary_300, *primary_300)",
            "button_secondary_background_fill_hover": "linear-gradient(90deg, *primary_400, *primary_400)",
            "button_secondary_background_fill_dark": "linear-gradient(90deg, *primary_500, *primary_600)",
            "button_secondary_background_fill_hover_dark": "linear-gradient(90deg, *primary_500, *primary_500)",
        }
        # Sliders, blocks and miscellaneous accents.
        misc = {
            "slider_color": "*secondary_500",
            "slider_color_dark": "*secondary_600",
            "block_title_text_weight": "600",
            "block_border_width": "3px",
            "block_shadow": "*shadow_drop_lg",
            "button_primary_shadow": "*shadow_drop_lg",
            "button_large_padding": "11px",
            "color_accent_soft": "*primary_100",
            "block_label_background_fill": "*primary_200",
        }
        super().set(**backgrounds, **primary_buttons, **secondary_buttons, **misc)
# Theme instance handed to gr.Blocks below.
steel_blue_theme = SteelBlueTheme()

# Extra CSS: enlarge the page title and the "Output" heading, which are
# targeted through their elem_id values.
css: str = """
#main-title h1 { font-size: 2.3em !important; }
#output-title h2 { font-size: 2.1em !important; }
"""
| # ---------------------------- | |
| # RUNTIME / DEVICE | |
| # ---------------------------- | |
# Default to the first GPU unless the environment already chose one.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

# Start-up diagnostics about the torch / CUDA environment.
print(f"CUDA_VISIBLE_DEVICES = {os.environ.get('CUDA_VISIBLE_DEVICES')}")
print(f"torch.__version__ = {torch.__version__}")
print(f"torch.version.cuda = {torch.version.cuda}")
print(f"cuda available = {torch.cuda.is_available()}")
print(f"cuda device count = {torch.cuda.device_count()}")

use_cuda = torch.cuda.is_available()
if use_cuda:
    print(f"using device = {torch.cuda.get_device_name(0)}")
    torch.backends.cudnn.benchmark = True

device = torch.device("cuda:0") if use_cuda else torch.device("cpu")

# Reduced-precision dtypes are only selected on GPU; on CPU both names map to
# float32.
if use_cuda:
    DTYPE_FP16, DTYPE_BF16 = torch.float16, torch.bfloat16
else:
    DTYPE_FP16, DTYPE_BF16 = torch.float32, torch.float32
| # ---------------------------- | |
| # OCR MODELS: Chandra-OCR + Dots.OCR | |
| # ---------------------------- | |
| # 1) Chandra-OCR (Qwen3VL) | |
MODEL_ID_V = "datalab-to/chandra"
processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
_chandra = Qwen3VLForConditionalGeneration.from_pretrained(
    MODEL_ID_V,
    trust_remote_code=True,
    torch_dtype=DTYPE_FP16,
)
model_v = _chandra.to(device).eval()

# 2) Dots.OCR (flash_attn2 if available, else SDPA)
MODEL_PATH_D = "prithivMLmods/Dots.OCR-Latest-BF16"
processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)

# FlashAttention-2 is selected only when the package imports cleanly AND a
# CUDA device is in use; every other case uses SDPA.
try:
    import flash_attn  # noqa: F401
except Exception:
    _flash_attn_ok = False
else:
    _flash_attn_ok = True
attn_impl = "flash_attention_2" if (_flash_attn_ok and use_cuda) else "sdpa"

_dots_device_map = "auto" if use_cuda else None
model_d = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH_D,
    attn_implementation=attn_impl,
    torch_dtype=DTYPE_BF16,
    device_map=_dots_device_map,
    trust_remote_code=True,
).eval()
# Without CUDA no device_map is given, so move the model explicitly.
if not use_cuda:
    model_d.to(device)
| # ---------------------------- | |
| # GENERATION (OCR → NER (Dots only) → Spell-check + CER) | |
| # ---------------------------- | |
MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 2048

# Spell-check algorithm label -> (backend class name, extra constructor
# kwargs, extra ``match_list`` kwargs).
_SPELL_MATCHERS = {
    "TF-IDF + Phonetic": (
        "TfidfPhoneticMatcher",
        {"ngram_size": 3, "phonetic_weight": 0.4},
        {"tfidf_threshold": 0.15},
    ),
    "SymSpell": (
        "SymSpellMatcher",
        {"max_edit": 2, "prefix_len": 7},
        {"min_score": 0.4},
    ),
    "RapidFuzz": (
        "RapidFuzzMatcher",
        {},
        {"threshold": 70.0},
    ),
}


def _extract_meds(model_name: str, ocr_text: str) -> List[str]:
    """Return the medication strings for ``ocr_text``.

    ClinicalNER is tried only for ``"Dots.OCR"``; if it is unavailable, fails,
    or finds nothing (and always for other models) every non-blank OCR line is
    used instead.
    """
    meds: List[str] = []
    if model_name == "Dots.OCR":
        try:
            if "ClinicalNER" in priv_classes and HF_TOKEN is not None:
                ner = priv_classes["ClinicalNER"](token=HF_TOKEN)
                ner_output = ner(ocr_text) or []
                # Expecting list[str]; ignore anything else.
                meds = [
                    m.strip()
                    for m in ner_output
                    if isinstance(m, str) and m.strip()
                ]
                print("[NER] Extracted meds (from ClinicalNER):", meds)
            else:
                print("[NER] ClinicalNER not available or no HF token.")
        except Exception as e:
            # Best effort: an NER failure falls through to the line fallback.
            print(f"[NER] Error running ClinicalNER: {e}")

    if not meds:
        meds = [line.strip() for line in ocr_text.splitlines() if line.strip()]
        print("[NER] Using line-based meds fallback, count:", len(meds))
    return meds


def _run_spell_check(meds: List[str], spell_algo: str) -> tuple:
    """Run the selected matcher and return ``(corr, notes)``.

    ``corr`` is whatever the matcher's ``match_list`` returned (empty dict if
    nothing ran); ``notes`` is a list of markdown bullet lines explaining why
    no suggestions are available, or describing an error.
    """
    corr: Dict[str, List] = {}
    notes: List[str] = []

    if not (meds and drug_xlsx_path):
        if not meds:
            notes.append("- No medications extracted (empty med list).\n")
        if not drug_xlsx_path:
            notes.append(
                "- Drug Excel dictionary path missing (drug_xlsx_path is None).\n"
            )
        return corr, notes

    class_name, init_kwargs, match_kwargs = _SPELL_MATCHERS.get(
        spell_algo, (None, {}, {})
    )
    if class_name is None or class_name not in priv_classes:
        notes.append("- Spell-check backend unavailable (no matcher class).\n")
        return corr, notes

    # NOTE: the workbook is no longer re-read here just to print its shape;
    # that diagnostic already runs once when the module loads.
    try:
        checker = priv_classes[class_name](
            xlsx_path=drug_xlsx_path, column="Combined_Drugs", **init_kwargs
        )
        corr = checker.match_list(meds, top_k=5, **match_kwargs) or {}
    except Exception as e:
        notes.append(f"- Spell-check error: {e}\n")
    return corr, notes


def _format_suggestions(meds: List[str], corr: Dict[str, List]) -> List[str]:
    """Render the suggestions per medication as markdown lines with score and CER."""
    lines: List[str] = []
    for raw in meds:
        suggestions = corr.get(raw, [])
        if not suggestions:
            lines.append(f"- **{raw}**\n  - (no suggestions)\n")
            continue
        lines.append(f"- **{raw}**\n")
        for cand, score in suggestions:
            cer = character_error_rate(cand, raw)
            # Two-space indent so the entry nests under the medication bullet.
            lines.append(f"  - {cand} (score={score:.3f}, CER={cer:.3f}%)\n")
    return lines


# you can add duration=... if needed, e.g. @spaces.GPU(duration=240)
def generate_image(
    model_name: str,
    text: str,
    image: Image.Image,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
    spell_algo: str,
):
    """Run OCR on ``image`` and stream ``(raw_text, markdown)`` pairs.

    1) Stream OCR tokens; while streaming both outputs show the raw text.
    2) If ``model_name == 'Dots.OCR'``, run ClinicalNER to get the medication
       list; otherwise (or when NER yields nothing) use the OCR lines.
    3) Apply the selected spell-check (TF-IDF+Phonetic / SymSpell / RapidFuzz)
       using the Excel dictionary, and compute CER for each suggestion.
    4) The final markdown shows the OCR text, the medication list and the
       spell-check suggestions with scores and CER.

    Args:
        model_name: ``"Chandra-OCR"`` or ``"Dots.OCR"``.
        text: Instruction placed next to the image in the chat prompt.
        image: Input image; ``None`` yields an "upload" message.
        max_new_tokens: Generation budget; clamped to [1, MAX_MAX_NEW_TOKENS].
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.
        top_k: Top-k sampling cut-off.
        repetition_penalty: Repetition penalty passed to ``generate``.
        spell_algo: Label of the spell-check algorithm to apply.

    Yields:
        Tuples ``(raw_text, markdown)`` for the two output components.
    """
    if image is None:
        # Two outputs: raw textbox + markdown
        yield "Please upload an image.", "Please upload an image."
        return

    if model_name == "Chandra-OCR":
        processor, model = processor_v, model_v
    elif model_name == "Dots.OCR":
        processor, model = processor_d, model_d
    else:
        yield "Invalid model selected.", "Invalid model selected."
        return

    # Build prompt from text parameter (kept via gr.State)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": text},
            ],
        }
    ]
    prompt_full = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Preprocess
    inputs = processor(
        text=[prompt_full], images=[image], return_tensors="pt", padding=True
    )
    inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}

    # Streamer
    tokenizer = getattr(processor, "tokenizer", None) or processor
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        # UI sliders may deliver floats; generation expects integers here.
        max_new_tokens=max(1, min(int(max_new_tokens), MAX_MAX_NEW_TOKENS)),
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=int(top_k),
        repetition_penalty=repetition_penalty,
    )

    generation_errors: List[Exception] = []

    def _generate() -> None:
        """Run generation; on failure record the error and unblock the consumer."""
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:
            generation_errors.append(exc)
            # Without this the ``for ... in streamer`` loop below would wait
            # forever for an end-of-stream signal that never arrives.
            streamer.end()

    # Start generation in background thread
    thread = Thread(target=_generate)
    thread.start()

    # 1) Live OCR streaming to Raw (and mirror to Markdown during stream)
    buffer = ""
    for new_text in streamer:
        buffer += new_text.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer, buffer
    thread.join()

    if generation_errors:
        print(f"[OCR] Generation failed: {generation_errors[0]}")
        yield buffer, f"OCR generation failed: {generation_errors[0]}"
        return

    # Final raw text
    final_ocr_text = buffer.strip()

    # 2) Clinical NER (Dots.OCR only) with line-based fallback
    meds = _extract_meds(model_name, final_ocr_text)
    print("[DEBUG] meds count:", len(meds))
    print("[DEBUG] drug_xlsx_path in generate_image:", drug_xlsx_path)

    # Build Markdown: OCR text + NER section. The fence is lengthened until it
    # no longer occurs in the OCR text, so backticks in the text cannot close
    # the code block early.
    fence = "```"
    while fence in final_ocr_text:
        fence += "`"
    parts: List[str] = [
        "### Raw OCR Output\n",
        f"{fence}\n{final_ocr_text or '(empty)'}\n{fence}\n",
        "\n---\n### Clinical NER (Medications)\n",
    ]
    if meds:
        parts.extend(f"- {m}\n" for m in meds)
    else:
        parts.append("- None detected\n")

    # 3) Spell-check (med list) with CER
    parts.append(f"\n---\n### Spell-check suggestions ({spell_algo})\n")
    corr, notes = _run_spell_check(meds, spell_algo)
    parts.extend(notes)
    if corr:
        parts.extend(_format_suggestions(meds, corr))

    # 4) Final yield: raw unchanged; Markdown with NER + spell-check + CER
    yield final_ocr_text, "".join(parts)
| # ---------------------------- | |
| # UI | |
| # ---------------------------- | |
| # IMPORTANT: examples must match the number of inputs (here: only image) | |
image_examples = [[path] for path in ("examples/3.jpg", "examples/1.jpg", "examples/2.jpg")]

with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
    gr.Markdown(
        "# **Handwritten Doctor's Prescription Reading**",
        elem_id="main-title",
    )

    with gr.Row():
        # Left column: image input, spell-check choice and sampling controls.
        with gr.Column(scale=2):
            image_upload = gr.Image(type="pil", label="Upload Image", height=290)
            image_submit = gr.Button("Submit", variant="primary")
            gr.Examples(
                examples=image_examples,
                inputs=[image_upload],
                label="Example Images",
            )

            # Spell-check selection
            spell_choice = gr.Radio(
                choices=["TF-IDF + Phonetic", "SymSpell", "RapidFuzz"],
                label="Select Spell-check Approach",
                value="TF-IDF + Phonetic",
            )

            with gr.Accordion("Advanced options", open=False):
                max_new_tokens = gr.Slider(
                    label="Max new tokens",
                    minimum=1,
                    maximum=MAX_MAX_NEW_TOKENS,
                    step=1,
                    value=DEFAULT_MAX_NEW_TOKENS,
                )
                temperature = gr.Slider(
                    label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7
                )
                top_p = gr.Slider(
                    label="Top-p (nucleus sampling)",
                    minimum=0.05,
                    maximum=1.0,
                    step=0.05,
                    value=0.9,
                )
                top_k = gr.Slider(
                    label="Top-k", minimum=1, maximum=1000, step=1, value=50
                )
                repetition_penalty = gr.Slider(
                    label="Repetition penalty",
                    minimum=1.0,
                    maximum=2.0,
                    step=0.05,
                    value=1.1,
                )

        # Right column: raw stream, rendered markdown result and model choice.
        with gr.Column(scale=3):
            gr.Markdown("## Output", elem_id="output-title")
            output = gr.Textbox(
                label="Raw Output Stream",
                interactive=False,
                lines=11,
                show_copy_button=True,
            )
            with gr.Accordion("(Result.md)", open=False):
                markdown_output = gr.Markdown(label="(Result.Md)")

            model_choice = gr.Radio(
                choices=["Chandra-OCR", "Dots.OCR"],
                label="Select OCR Model",
                value="Chandra-OCR",
            )

    # Hard-coded instruction text, passed as gr.State to match the 'text' parameter
    query_state = gr.State(
        "Extract medicine or drugs names along with dosage amount or quantity"
    )

    # Order must match the positional parameters of generate_image.
    _generation_inputs = [
        model_choice,
        query_state,
        image_upload,
        max_new_tokens,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
        spell_choice,
    ]
    _generation_outputs = [output, markdown_output]
    image_submit.click(
        fn=generate_image,
        inputs=_generation_inputs,
        outputs=_generation_outputs,
    )
if __name__ == "__main__":
    # Enable request queuing (at most 50 pending jobs), then start the server.
    _queued_demo = demo.queue(max_size=50)
    _queued_demo.launch(mcp_server=True, ssr_mode=False, show_error=True)
| ######################################### version 2 ######################################################################### | |
| # import os | |
| # import time | |
| # from threading import Thread | |
| # from typing import Iterable, Dict, Any, Optional, List | |
| # import gradio as gr | |
| # import spaces | |
| # import torch | |
| # from PIL import Image | |
| # from transformers import ( | |
| # Qwen3VLForConditionalGeneration, | |
| # AutoModelForCausalLM, | |
| # AutoProcessor, | |
| # TextIteratorStreamer, | |
| # ) | |
| # from gradio.themes import Soft | |
| # from gradio.themes.utils import colors, fonts, sizes | |
| # # ----------------------------- | |
| # # Character Error Rate (CER) | |
| # # ----------------------------- | |
| # def levenshtein(a: str, b: str) -> int: | |
| # """Levenshtein distance to calculate CER.""" | |
| # a, b = a.lower(), b.lower() | |
| # if a == b: | |
| # return 0 | |
| # if not a: | |
| # return len(b) | |
| # if not b: | |
| # return len(a) | |
| # dp = list(range(len(b) + 1)) | |
| # for i, ca in enumerate(a, 1): | |
| # prev = dp[0] | |
| # dp[0] = i | |
| # for j, cb in enumerate(b, 1): | |
| # cur = dp[j] | |
| # cost = 0 if ca == cb else 1 | |
| # dp[j] = min(dp[j] + 1, dp[j - 1] + 1, prev + cost) | |
| # prev = cur | |
| # return dp[-1] | |
| # def character_error_rate(pred: str, target: str) -> float: | |
| # """Calculate the Character Error Rate (CER) in percent.""" | |
| # target = target or "" | |
| # distance = levenshtein(pred, target) | |
| # return (distance / len(target)) * 100 if len(target) > 0 else 0.0 | |
| # # ----------------------------- | |
| # # Private repo: dynamic import | |
| # # ----------------------------- | |
| # import importlib.util | |
| # from huggingface_hub import hf_hub_download | |
| # REPO_ID = "IFMedTech/Medibot_OCR_model" # private backend repo | |
| # # Map filenames to exported class names | |
| # PY_MODULES = { | |
| # "ner.py": "ClinicalNER", # NER is only applied for Dots.OCR output | |
| # "tfidf_phonetic.py": "TfidfPhoneticMatcher", | |
| # "symspell_matcher.py": "SymSpellMatcher", | |
| # "rapidfuzz_matcher.py": "RapidFuzzMatcher", | |
| # # 'Medibot_Drugs_Cleaned_Updated.xlsx' is data, not a module | |
| # } | |
| # HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN") | |
| # def _dynamic_import(module_path: str, class_name: str): | |
| # spec = importlib.util.spec_from_file_location(class_name, module_path) | |
| # module = importlib.util.module_from_spec(spec) | |
| # spec.loader.exec_module(module) # type: ignore | |
| # return getattr(module, class_name) | |
| # # Load private classes and Excel dictionary (once at import time) | |
| # priv_classes: Dict[str, Any] = {} | |
| # drug_xlsx_path: Optional[str] = None | |
| # try: | |
| # if HF_TOKEN is None: | |
| # print("[Private] WARNING: HUGGINGFACE_TOKEN not set; NER/Spell-check will be unavailable.") | |
| # else: | |
| # for fname, cls in PY_MODULES.items(): | |
| # path = hf_hub_download(repo_id=REPO_ID, filename=fname, token=HF_TOKEN) | |
| # if cls: | |
| # priv_classes[cls] = _dynamic_import(path, cls) | |
| # print(f"[Private] Loaded class: {cls} from {fname}") | |
| # drug_xlsx_path = hf_hub_download( | |
| # repo_id=REPO_ID, | |
| # filename="Medibot_Drugs_Cleaned_Updated.xlsx", | |
| # token=HF_TOKEN, | |
| # ) | |
| # print(f"[Private] Downloaded Excel at: {drug_xlsx_path}") | |
| # except Exception as e: | |
| # print(f"[Private] ERROR loading private backend: {e}") | |
| # priv_classes = {} | |
| # drug_xlsx_path = None | |
| # # ---------------------------- | |
| # # THEME | |
| # # ---------------------------- | |
| # colors.steel_blue = colors.Color( | |
| # name="steel_blue", | |
| # c50="#EBF3F8", | |
| # c100="#D3E5F0", | |
| # c200="#A8CCE1", | |
| # c300="#7DB3D2", | |
| # c400="#529AC3", | |
| # c500="#4682B4", | |
| # c600="#3E72A0", | |
| # c700="#36638C", | |
| # c800="#2E5378", | |
| # c900="#264364", | |
| # c950="#1E3450", | |
| # ) | |
| # class SteelBlueTheme(Soft): | |
| # def __init__( | |
| # self, | |
| # *, | |
| # primary_hue: colors.Color | str = colors.gray, | |
| # secondary_hue: colors.Color | str = colors.steel_blue, | |
| # neutral_hue: colors.Color | str = colors.slate, | |
| # text_size: sizes.Size | str = sizes.text_lg, | |
| # font: fonts.Font | str | Iterable[fonts.Font | str] = ( | |
| # fonts.GoogleFont("Outfit"), | |
| # "Arial", | |
| # "sans-serif", | |
| # ), | |
| # font_mono: fonts.Font | str | Iterable[fonts.Font | str] = ( | |
| # fonts.GoogleFont("IBM Plex Mono"), | |
| # "ui-monospace", | |
| # "monospace", | |
| # ), | |
| # ): | |
| # super().__init__( | |
| # primary_hue=primary_hue, | |
| # secondary_hue=secondary_hue, | |
| # neutral_hue=neutral_hue, | |
| # text_size=text_size, | |
| # font=font, | |
| # font_mono=font_mono, | |
| # ) | |
| # super().set( | |
| # background_fill_primary="*primary_50", | |
| # background_fill_primary_dark="*primary_900", | |
| # body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)", | |
| # body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)", | |
| # button_primary_text_color="white", | |
| # button_primary_text_color_hover="white", | |
| # button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)", | |
| # button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)", | |
| # button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)", | |
| # button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_500)", | |
| # button_secondary_text_color="black", | |
| # button_secondary_text_color_hover="white", | |
| # button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)", | |
| # button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)", | |
| # button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)", | |
| # button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)", | |
| # slider_color="*secondary_500", | |
| # slider_color_dark="*secondary_600", | |
| # block_title_text_weight="600", | |
| # block_border_width="3px", | |
| # block_shadow="*shadow_drop_lg", | |
| # button_primary_shadow="*shadow_drop_lg", | |
| # button_large_padding="11px", | |
| # color_accent_soft="*primary_100", | |
| # block_label_background_fill="*primary_200", | |
| # ) | |
| # steel_blue_theme = SteelBlueTheme() | |
| # css = """ | |
| # #main-title h1 { font-size: 2.3em !important; } | |
| # #output-title h2 { font-size: 2.1em !important; } | |
| # """ | |
| # # ---------------------------- | |
| # # RUNTIME / DEVICE | |
| # # ---------------------------- | |
| # os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0") | |
| # print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES")) | |
| # print("torch.__version__ =", torch.__version__) | |
| # print("torch.version.cuda =", torch.version.cuda) | |
| # print("cuda available =", torch.cuda.is_available()) | |
| # print("cuda device count =", torch.cuda.device_count()) | |
| # if torch.cuda.is_available(): | |
| # print("using device =", torch.cuda.get_device_name(0)) | |
| # use_cuda = torch.cuda.is_available() | |
| # device = torch.device("cuda:0" if use_cuda else "cpu") | |
| # if use_cuda: | |
| # torch.backends.cudnn.benchmark = True | |
| # DTYPE_FP16 = torch.float16 if use_cuda else torch.float32 | |
| # DTYPE_BF16 = torch.bfloat16 if use_cuda else torch.float32 | |
| # # ---------------------------- | |
| # # OCR MODELS: Chandra-OCR + Dots.OCR | |
| # # ---------------------------- | |
| # # 1) Chandra-OCR (Qwen3VL) | |
| # MODEL_ID_V = "datalab-to/chandra" | |
| # processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True) | |
| # model_v = Qwen3VLForConditionalGeneration.from_pretrained( | |
| # MODEL_ID_V, trust_remote_code=True, torch_dtype=DTYPE_FP16 | |
| # ).to(device).eval() | |
| # # 2) Dots.OCR (flash_attn2 if available, else SDPA) | |
| # MODEL_PATH_D = "prithivMLmods/Dots.OCR-Latest-BF16" | |
| # processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True) | |
| # attn_impl = "sdpa" | |
| # try: | |
| # import flash_attn # noqa: F401 | |
| # if use_cuda: | |
| # attn_impl = "flash_attention_2" | |
| # except Exception: | |
| # attn_impl = "sdpa" | |
| # model_d = AutoModelForCausalLM.from_pretrained( | |
| # MODEL_PATH_D, | |
| # attn_implementation=attn_impl, | |
| # torch_dtype=DTYPE_BF16, | |
| # device_map="auto" if use_cuda else None, | |
| # trust_remote_code=True, | |
| # ).eval() | |
| # if not use_cuda: | |
| # model_d.to(device) | |
| # # ---------------------------- | |
| # # GENERATION (OCR → NER (Dots only) → Spell-check + CER) | |
| # # ---------------------------- | |
| # MAX_MAX_NEW_TOKENS = 4096 | |
| # DEFAULT_MAX_NEW_TOKENS = 2048 | |
| # @spaces.GPU # you can add duration=... if needed, e.g. @spaces.GPU(duration=240) | |
| # def generate_image( | |
| # model_name: str, | |
| # text: str, | |
| # image: Image.Image, | |
| # max_new_tokens: int, | |
| # temperature: float, | |
| # top_p: float, | |
| # top_k: int, | |
| # repetition_penalty: float, | |
| # spell_algo: str, | |
| # ): | |
| # """ | |
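| # 1) Stream OCR tokens to the Raw output textbox as they are generated. | |
| # 2) If model_name == 'Dots.OCR', run ClinicalNER → list[str] meds. | |
| # For Chandra-OCR, skip NER. If no meds result (NER skipped, unavailable, | |
| # failed, or empty), fall back to treating each non-empty OCR line as a med. | |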
| # 3) Apply selected spell-check (TF-IDF+Phonetic / SymSpell / RapidFuzz) | |
| # using Excel dict, and compute CER for each suggestion. | |
| # 4) Markdown shows OCR text, NER list (if any), and spell-check top-5 | |
| # suggestions with scores and CER. | |
| # """ | |
| # if image is None: | |
| # # Two outputs: raw textbox + markdown | |
| # yield "Please upload an image.", "Please upload an image." | |
| # return | |
| # if model_name == "Chandra-OCR": | |
| # processor, model = processor_v, model_v | |
| # elif model_name == "Dots.OCR": | |
| # processor, model = processor_d, model_d | |
| # else: | |
| # yield "Invalid model selected.", "Invalid model selected." | |
| # return | |
| # # Build prompt from text parameter (kept via gr.State) | |
| # messages = [ | |
| # { | |
| # "role": "user", | |
| # "content": [ | |
| # {"type": "image"}, | |
| # {"type": "text", "text": text}, | |
| # ], | |
| # } | |
| # ] | |
| # prompt_full = processor.apply_chat_template( | |
| # messages, tokenize=False, add_generation_prompt=True | |
| # ) | |
| # # Preprocess | |
| # inputs = processor( | |
| # text=[prompt_full], images=[image], return_tensors="pt", padding=True | |
| # ) | |
| # inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()} | |
| # # Streamer | |
| # tokenizer = getattr(processor, "tokenizer", None) or processor | |
| # streamer = TextIteratorStreamer( | |
| # tokenizer, skip_prompt=True, skip_special_tokens=True | |
| # ) | |
| # gen_kwargs = dict( | |
| # **inputs, | |
| # streamer=streamer, | |
| # max_new_tokens=max_new_tokens, | |
| # do_sample=True, | |
| # temperature=temperature, | |
| # top_p=top_p, | |
| # top_k=top_k, | |
| # repetition_penalty=repetition_penalty, | |
| # ) | |
| # # Start generation in background thread | |
| # thread = Thread(target=model.generate, kwargs=gen_kwargs) | |
| # thread.start() | |
| # # 1) Live OCR streaming to Raw (and mirror to Markdown during stream) | |
| # buffer = "" | |
| # for new_text in streamer: | |
| # buffer += new_text.replace("<|im_end|>", "") | |
| # time.sleep(0.01) | |
| # # During streaming, just show the raw text in both components | |
| # yield buffer, buffer | |
| # # Final raw text | |
| # final_ocr_text = buffer.strip() | |
| # # ------------------------- | |
| # # 2) Clinical NER (Dots.OCR only) | |
| # # ------------------------- | |
| # meds: List[str] = [] | |
| # if model_name == "Dots.OCR": | |
| # try: | |
| # if "ClinicalNER" in priv_classes and HF_TOKEN is not None: | |
| # ClinicalNER = priv_classes["ClinicalNER"] | |
| # ner = ClinicalNER(token=HF_TOKEN) # model_id can be passed if needed | |
| # ner_output = ner(final_ocr_text) or [] | |
| # # Expecting list[str]; be robust: | |
| # meds = [m.strip() for m in ner_output if isinstance(m, str) and m.strip()] | |
| # print("[NER] Extracted meds:", meds) | |
| # else: | |
| # print("[NER] ClinicalNER not available or no HF token.") | |
| # except Exception as e: | |
| # print(f"[NER] Error running ClinicalNER: {e}") | |
| # # Fallback: if no meds found (or Chandra-OCR), derive meds from OCR lines | |
| # if not meds: | |
| # meds = [line.strip() for line in final_ocr_text.splitlines() if line.strip()] | |
| # print("[NER] Using line-based meds fallback, count:", len(meds)) | |
| # # ------------------------- | |
| # # Build Markdown: OCR text + NER section | |
| # # ------------------------- | |
| # md = "### Raw OCR Output\n" | |
| # md += "```\n" + (final_ocr_text or "(empty)") + "\n```\n" | |
| # md += "\n---\n### Clinical NER (Medications)\n" | |
| # if meds: | |
| # for m in meds: | |
| # md += f"- {m}\n" | |
| # else: | |
| # md += "- None detected\n" | |
| # # ------------------------- | |
| # # 3) Spell-check (med list) with CER | |
| # # ------------------------- | |
| # spell_section = "\n---\n### Spell-check suggestions (" + spell_algo + ")\n" | |
| # corr: Dict[str, List] = {} | |
| # try: | |
| # if meds and drug_xlsx_path: | |
| # if ( | |
| # spell_algo == "TF-IDF + Phonetic" | |
| # and "TfidfPhoneticMatcher" in priv_classes | |
| # ): | |
| # Cls = priv_classes["TfidfPhoneticMatcher"] | |
| # checker = Cls( | |
| # xlsx_path=drug_xlsx_path, | |
| # column="Combined_Drugs", | |
| # ngram_size=3, | |
| # phonetic_weight=0.4, | |
| # ) | |
| # corr = checker.match_list(meds, top_k=5, tfidf_threshold=0.15) | |
| # elif spell_algo == "SymSpell" and "SymSpellMatcher" in priv_classes: | |
| # Cls = priv_classes["SymSpellMatcher"] | |
| # checker = Cls( | |
| # xlsx_path=drug_xlsx_path, | |
| # column="Combined_Drugs", | |
| # max_edit=2, | |
| # prefix_len=7, | |
| # ) | |
| # corr = checker.match_list(meds, top_k=5, min_score=0.4) | |
| # elif ( | |
| # spell_algo == "RapidFuzz" and "RapidFuzzMatcher" in priv_classes | |
| # ): | |
| # Cls = priv_classes["RapidFuzzMatcher"] | |
| # checker = Cls(xlsx_path=drug_xlsx_path, column="Combined_Drugs") | |
| # corr = checker.match_list(meds, top_k=5, threshold=70.0) | |
| # else: | |
| # spell_section += "- Spell-check backend unavailable.\n" | |
| # else: | |
| # spell_section += "- No NER/med list or Excel dictionary missing.\n" | |
| # except Exception as e: | |
| # spell_section += f"- Spell-check error: {e}\n" | |
| # # Format suggestions (top-5 per med, with scores + CER) | |
| # if corr: | |
| # for raw in meds: | |
| # suggestions = corr.get(raw, []) | |
| # if suggestions: | |
| # spell_section += f"- **{raw}**\n" | |
| # for cand, score in suggestions: | |
| # cer = character_error_rate(cand, raw) | |
| # spell_section += ( | |
| # f" - {cand} " | |
| # f"(score={score:.3f}, CER={cer:.3f}%)\n" | |
| # ) | |
| # else: | |
| # spell_section += f"- **{raw}**\n - (no suggestions)\n" | |
| # final_md = md + spell_section | |
| # # 4) Final yield: raw unchanged; Markdown with NER + spell-check + CER | |
| # yield final_ocr_text, final_md | |
| # # ---------------------------- | |
| # # UI | |
| # # ---------------------------- | |
| # # IMPORTANT: examples must match the number of inputs (here: only image) | |
| # image_examples = [ | |
| # ["examples/3.jpg"], | |
| # ["examples/1.jpg"], | |
| # ["examples/2.jpg"], | |
| # ] | |
| # with gr.Blocks(css=css, theme=steel_blue_theme) as demo: | |
| # gr.Markdown( | |
| # "# **Handwritten Doctor's Prescription Reading**", elem_id="main-title" | |
| # ) | |
| # with gr.Row(): | |
| # with gr.Column(scale=2): | |
| # image_upload = gr.Image( | |
| # type="pil", label="Upload Image", height=290 | |
| # ) | |
| # image_submit = gr.Button("Submit", variant="primary") | |
| # gr.Examples( | |
| # examples=image_examples, | |
| # inputs=[image_upload], | |
| # label="Example Images", | |
| # ) | |
| # # Spell-check selection | |
| # spell_choice = gr.Radio( | |
| # choices=["TF-IDF + Phonetic", "SymSpell", "RapidFuzz"], | |
| # label="Select Spell-check Approach", | |
| # value="TF-IDF + Phonetic", | |
| # ) | |
| # with gr.Accordion("Advanced options", open=False): | |
| # max_new_tokens = gr.Slider( | |
| # label="Max new tokens", | |
| # minimum=1, | |
| # maximum=MAX_MAX_NEW_TOKENS, | |
| # step=1, | |
| # value=DEFAULT_MAX_NEW_TOKENS, | |
| # ) | |
| # temperature = gr.Slider( | |
| # label="Temperature", | |
| # minimum=0.1, | |
| # maximum=4.0, | |
| # step=0.1, | |
| # value=0.7, | |
| # ) | |
| # top_p = gr.Slider( | |
| # label="Top-p (nucleus sampling)", | |
| # minimum=0.05, | |
| # maximum=1.0, | |
| # step=0.05, | |
| # value=0.9, | |
| # ) | |
| # top_k = gr.Slider( | |
| # label="Top-k", | |
| # minimum=1, | |
| # maximum=1000, | |
| # step=1, | |
| # value=50, | |
| # ) | |
| # repetition_penalty = gr.Slider( | |
| # label="Repetition penalty", | |
| # minimum=1.0, | |
| # maximum=2.0, | |
| # step=0.05, | |
| # value=1.1, | |
| # ) | |
| # with gr.Column(scale=3): | |
| # gr.Markdown("## Output", elem_id="output-title") | |
| # output = gr.Textbox( | |
| # label="Raw Output Stream", | |
| # interactive=False, | |
| # lines=11, | |
| # show_copy_button=True, | |
| # ) | |
| # with gr.Accordion("(Result.md)", open=False): | |
| # markdown_output = gr.Markdown(label="(Result.Md)") | |
| # model_choice = gr.Radio( | |
| # choices=["Chandra-OCR", "Dots.OCR"], | |
| # label="Select OCR Model", | |
| # value="Chandra-OCR", | |
| # ) | |
| # # Hard-coded instruction text, passed as gr.State to match the 'text' parameter | |
| # query_state = gr.State( | |
| # "Extract medicine or drugs names along with dosage amount or quantity" | |
| # ) | |
| # image_submit.click( | |
| # fn=generate_image, | |
| # inputs=[ | |
| # model_choice, | |
| # query_state, | |
| # image_upload, | |
| # max_new_tokens, | |
| # temperature, | |
| # top_p, | |
| # top_k, | |
| # repetition_penalty, | |
| # spell_choice, | |
| # ], | |
| # outputs=[output, markdown_output], | |
| # ) | |
| # if __name__ == "__main__": | |
| # demo.queue(max_size=50).launch( | |
| # mcp_server=True, ssr_mode=False, show_error=True | |
| # ) | |
| ##################################### version 1 ####################################################### | |
| # import os | |
| # import time | |
| # from threading import Thread | |
| # from typing import Iterable, Dict, Any, Optional, List | |
| # import gradio as gr | |
| # import spaces | |
| # import torch | |
| # from PIL import Image | |
| # from transformers import ( | |
| # Qwen3VLForConditionalGeneration, | |
| # AutoModelForCausalLM, | |
| # AutoProcessor, | |
| # TextIteratorStreamer, | |
| # ) | |
| # from gradio.themes import Soft | |
| # from gradio.themes.utils import colors, fonts, sizes | |
| # # ----------------------------- | |
| # # Private repo: dynamic import | |
| # # ----------------------------- | |
| # import importlib.util | |
| # from huggingface_hub import hf_hub_download | |
| # REPO_ID = "IFMedTech/Medibot_OCR_model" # private backend repo | |
| # # Map filenames to exported class names | |
| # PY_MODULES = { | |
| # "ner.py": "ClinicalNER", | |
| # "tfidf_phonetic.py": "TfidfPhoneticMatcher", | |
| # "symspell_matcher.py": "SymSpellMatcher", | |
| # "rapidfuzz_matcher.py": "RapidFuzzMatcher", | |
| # # The Excel drug dictionary ("Medibot_Drugs_Cleaned_Updated.xlsx") is data, not a module; it is downloaded separately below | |
| # } | |
| # HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN") | |
| # def _dynamic_import(module_path: str, class_name: str): | |
| # spec = importlib.util.spec_from_file_location(class_name, module_path) | |
| # module = importlib.util.module_from_spec(spec) | |
| # spec.loader.exec_module(module) # type: ignore | |
| # return getattr(module, class_name) | |
| # # Load private classes and Excel dictionary | |
| # priv_classes: Dict[str, Any] = {} | |
| # drug_xlsx_path: Optional[str] = None | |
| # try: | |
| # if HF_TOKEN is None: | |
| # print("[Private] WARNING: HUGGINGFACE_TOKEN not set; NER/Spell-check will be unavailable.") | |
| # else: | |
| # for fname, cls in PY_MODULES.items(): | |
| # path = hf_hub_download(repo_id=REPO_ID, filename=fname, token=HF_TOKEN) | |
| # if cls: | |
| # priv_classes[cls] = _dynamic_import(path, cls) | |
| # print(f"[Private] Loaded class: {cls} from {fname}") | |
| # drug_xlsx_path = hf_hub_download(repo_id=REPO_ID, filename="Medibot_Drugs_Cleaned_Updated.xlsx", token=HF_TOKEN) | |
| # print(f"[Private] Downloaded Excel at: {drug_xlsx_path}") | |
| # except Exception as e: | |
| # print(f"[Private] ERROR loading private backend: {e}") | |
| # priv_classes = {} | |
| # drug_xlsx_path = None | |
| # # ---------------------------- | |
| # # THEME | |
| # # ---------------------------- | |
| # colors.steel_blue = colors.Color( | |
| # name="steel_blue", | |
| # c50="#EBF3F8", | |
| # c100="#D3E5F0", | |
| # c200="#A8CCE1", | |
| # c300="#7DB3D2", | |
| # c400="#529AC3", | |
| # c500="#4682B4", | |
| # c600="#3E72A0", | |
| # c700="#36638C", | |
| # c800="#2E5378", | |
| # c900="#264364", | |
| # c950="#1E3450", | |
| # ) | |
| # class SteelBlueTheme(Soft): | |
| # def __init__( | |
| # self, | |
| # *, | |
| # primary_hue: colors.Color | str = colors.gray, | |
| # secondary_hue: colors.Color | str = colors.steel_blue, | |
| # neutral_hue: colors.Color | str = colors.slate, | |
| # text_size: sizes.Size | str = sizes.text_lg, | |
| # font: fonts.Font | str | Iterable[fonts.Font | str] = ( | |
| # fonts.GoogleFont("Outfit"), "Arial", "sans-serif", | |
| # ), | |
| # font_mono: fonts.Font | str | Iterable[fonts.Font | str] = ( | |
| # fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace", | |
| # ), | |
| # ): | |
| # super().__init__( | |
| # primary_hue=primary_hue, | |
| # secondary_hue=secondary_hue, | |
| # neutral_hue=neutral_hue, | |
| # text_size=text_size, | |
| # font=font, | |
| # font_mono=font_mono, | |
| # ) | |
| # super().set( | |
| # background_fill_primary="*primary_50", | |
| # background_fill_primary_dark="*primary_900", | |
| # body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)", | |
| # body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)", | |
| # button_primary_text_color="white", | |
| # button_primary_text_color_hover="white", | |
| # button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)", | |
| # button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)", | |
| # button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)", | |
| # button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_500)", | |
| # button_secondary_text_color="black", | |
| # button_secondary_text_color_hover="white", | |
| # button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)", | |
| # button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)", | |
| # button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)", | |
| # button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)", | |
| # slider_color="*secondary_500", | |
| # slider_color_dark="*secondary_600", | |
| # block_title_text_weight="600", | |
| # block_border_width="3px", | |
| # block_shadow="*shadow_drop_lg", | |
| # button_primary_shadow="*shadow_drop_lg", | |
| # button_large_padding="11px", | |
| # color_accent_soft="*primary_100", | |
| # block_label_background_fill="*primary_200", | |
| # ) | |
| # steel_blue_theme = SteelBlueTheme() | |
| # css = """ | |
| # #main-title h1 { font-size: 2.3em !important; } | |
| # #output-title h2 { font-size: 2.1em !important; } | |
| # """ | |
| # # ---------------------------- | |
| # # RUNTIME / DEVICE | |
| # # ---------------------------- | |
| # os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0") | |
| # print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES")) | |
| # print("torch.__version__ =", torch.__version__) | |
| # print("torch.version.cuda =", torch.version.cuda) | |
| # print("cuda available =", torch.cuda.is_available()) | |
| # print("cuda device count =", torch.cuda.device_count()) | |
| # if torch.cuda.is_available(): | |
| # print("using device =", torch.cuda.get_device_name(0)) | |
| # use_cuda = torch.cuda.is_available() | |
| # device = torch.device("cuda:0" if use_cuda else "cpu") | |
| # if use_cuda: | |
| # torch.backends.cudnn.benchmark = True | |
| # DTYPE_FP16 = torch.float16 if use_cuda else torch.float32 | |
| # DTYPE_BF16 = torch.bfloat16 if use_cuda else torch.float32 | |
| # # ---------------------------- | |
| # # OCR MODELS: Chandra-OCR + Dots.OCR | |
| # # ---------------------------- | |
| # # 1) Chandra-OCR (Qwen3VL) | |
| # MODEL_ID_V = "datalab-to/chandra" | |
| # processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True) | |
| # model_v = Qwen3VLForConditionalGeneration.from_pretrained( | |
| # MODEL_ID_V, trust_remote_code=True, torch_dtype=DTYPE_FP16 | |
| # ).to(device).eval() | |
| # # 2) Dots.OCR (flash_attn2 if available, else SDPA) | |
| # MODEL_PATH_D = "prithivMLmods/Dots.OCR-Latest-BF16" | |
| # processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True) | |
| # attn_impl = "sdpa" | |
| # try: | |
| # import flash_attn # noqa: F401 | |
| # if use_cuda: | |
| # attn_impl = "flash_attention_2" | |
| # except Exception: | |
| # attn_impl = "sdpa" | |
| # model_d = AutoModelForCausalLM.from_pretrained( | |
| # MODEL_PATH_D, | |
| # attn_implementation=attn_impl, | |
| # torch_dtype=DTYPE_BF16, | |
| # device_map="auto" if use_cuda else None, | |
| # trust_remote_code=True | |
| # ).eval() | |
| # if not use_cuda: | |
| # model_d.to(device) | |
| # # ---------------------------- | |
| # # GENERATION (OCR → line-based med list → Spell-check; ClinicalNER is disabled in this version) | |
| # # ---------------------------- | |
| # MAX_MAX_NEW_TOKENS = 4096 | |
| # DEFAULT_MAX_NEW_TOKENS = 2048 | |
| # @spaces.GPU | |
| # def generate_image(model_name: str, | |
| # text: str, | |
| # image: Image.Image, | |
| # max_new_tokens: int, | |
| # temperature: float, | |
| # top_p: float, | |
| # top_k: int, | |
| # repetition_penalty: float, | |
| # spell_algo: str): | |
| # """ | |
| # 1) Stream OCR tokens to the Raw output textbox as they are generated. | |
| # 2) After the stream completes, split the final raw text into non-empty lines and use them as the medication list (the ClinicalNER call is commented out in this version). | |
| # 3) Apply selected spell-check (TF-IDF+Phonetic / SymSpell / RapidFuzz) using Excel dict. | |
| # 4) Markdown shows OCR text + the medication list + spell-check top-5 suggestions with scores. | |
| # """ | |
| # if image is None: | |
| # yield "Please upload an image.", "Please upload an image." | |
| # return | |
| # if model_name == "Chandra-OCR": | |
| # processor, model = processor_v, model_v | |
| # elif model_name == "Dots.OCR": | |
| # processor, model = processor_d, model_d | |
| # else: | |
| # yield "Invalid model selected.", "Invalid model selected." | |
| # return | |
| # # Build prompt | |
| # messages = [{ | |
| # "role": "user", | |
| # "content": [ | |
| # {"type": "image"}, | |
| # {"type": "text", "text": text}, | |
| # ] | |
| # }] | |
| # prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) | |
| # # Preprocess | |
| # inputs = processor(text=[prompt_full], images=[image], return_tensors="pt", padding=True) | |
| # inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()} | |
| # # Streamer | |
| # tokenizer = getattr(processor, "tokenizer", None) or processor | |
| # streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) | |
| # gen_kwargs = dict( | |
| # **inputs, | |
| # streamer=streamer, | |
| # max_new_tokens=max_new_tokens, | |
| # do_sample=True, | |
| # temperature=temperature, | |
| # top_p=top_p, | |
| # top_k=top_k, | |
| # repetition_penalty=repetition_penalty, | |
| # ) | |
| # # Start generation | |
| # thread = Thread(target=model.generate, kwargs=gen_kwargs) | |
| # thread.start() | |
| # # 1) Live OCR streaming to Raw (and mirror to Markdown during stream) | |
| # buffer = "" | |
| # for new_text in streamer: | |
| # buffer += new_text.replace("<|im_end|>", "") | |
| # time.sleep(0.01) | |
| # yield buffer, buffer | |
| # # Final raw text for downstream processing | |
| # final_ocr_text = buffer | |
| # # 2) Clinical NER (from private repo) -- disabled in this version; the line-based split further down builds the med list instead | |
| # # meds: List[str] = [] | |
| # # try: | |
| # # if "ClinicalNER" in priv_classes: | |
| # # ClinicalNER = priv_classes["ClinicalNER"] | |
| # # ner = ClinicalNER(token=HF_TOKEN) # pass model_id=... if using your own model | |
| # # meds = ner(final_ocr_text) or [] | |
| # # else: | |
| # # print("[NER] ClinicalNER not available.") | |
| # # except Exception as e: | |
| # # print(f"[NER] Error running ClinicalNER: {e}") | |
| # raw_ocr_text = buffer.strip() | |
| # meds = [line.strip() for line in raw_ocr_text.split('\n') if line.strip()] | |
| # # Build Markdown with OCR + NER section | |
| # md = final_ocr_text | |
| # md += "\n\n---\n### Clinical NER (Medications)\n" | |
| # if meds: | |
| # for m in meds: | |
| # md += f"- {m}\n" | |
| # else: | |
| # md += "- None detected\n" | |
| # # 3) Spell-check on the line-based medication list using selected approach + Excel dict | |
| # spell_section = "\n---\n### Spell-check suggestions (" + spell_algo + ")\n" | |
| # corr: Dict[str, List] = {} | |
| # try: | |
| # if meds and drug_xlsx_path: | |
| # if spell_algo == "TF-IDF + Phonetic" and "TfidfPhoneticMatcher" in priv_classes: | |
| # Cls = priv_classes["TfidfPhoneticMatcher"] | |
| # checker = Cls(xlsx_path=drug_xlsx_path, column="Combined_Drugs", ngram_size=3, phonetic_weight=0.4) | |
| # corr = checker.match_list(meds, top_k=5, tfidf_threshold=0.15) | |
| # elif spell_algo == "SymSpell" and "SymSpellMatcher" in priv_classes: | |
| # Cls = priv_classes["SymSpellMatcher"] | |
| # checker = Cls(xlsx_path=drug_xlsx_path, column="Combined_Drugs", max_edit=2, prefix_len=7) | |
| # corr = checker.match_list(meds, top_k=5, min_score=0.4) | |
| # elif spell_algo == "RapidFuzz" and "RapidFuzzMatcher" in priv_classes: | |
| # Cls = priv_classes["RapidFuzzMatcher"] | |
| # checker = Cls(xlsx_path=drug_xlsx_path, column="Combined_Drugs") | |
| # corr = checker.match_list(meds, top_k=5, threshold=70.0) | |
| # else: | |
| # spell_section += "- Spell-check backend unavailable.\n" | |
| # else: | |
| # spell_section += "- No NER output or Excel dictionary missing.\n" | |
| # except Exception as e: | |
| # spell_section += f"- Spell-check error: {e}\n" | |
| # # Format suggestions (top-5 with scores) | |
| # if corr: | |
| # for raw in meds: | |
| # suggestions = corr.get(raw, []) | |
| # if suggestions: | |
| # spell_section += f"- **{raw}**\n" | |
| # for cand, score in suggestions: | |
| # spell_section += f" - {cand} (score={score:.3f})\n" | |
| # else: | |
| # spell_section += f"- **{raw}**\n - (no suggestions)\n" | |
| # final_md = md + spell_section | |
| # # 4) Final yield: raw unchanged; Markdown with NER + spell-check | |
| # yield final_ocr_text, final_md | |
| # # ---------------------------- | |
| # # UI | |
| # # ---------------------------- | |
| # image_examples = [ | |
| # ["OCR the content perfectly.", "examples/3.jpg"], | |
| # ["Perform OCR on the image.", "examples/1.jpg"], | |
| # ["Extract the contents. [page].", "examples/2.jpg"], | |
| # ] | |
| # with gr.Blocks(css=css, theme=steel_blue_theme) as demo: | |
| # gr.Markdown("# **Handwritten Doctor's Prescription Reading**", elem_id="main-title") | |
| # with gr.Row(): | |
| # with gr.Column(scale=2): | |
| # #image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...") | |
| # image_upload = gr.Image(type="pil", label="Upload Image", height=290) | |
| # image_submit = gr.Button("Submit", variant="primary") | |
| # gr.Examples(examples=image_examples, inputs=[image_upload]) | |
| # # Spell-check selection | |
| # spell_choice = gr.Radio( | |
| # choices=["TF-IDF + Phonetic", "SymSpell", "RapidFuzz"], | |
| # label="Select Spell-check Approach", | |
| # value="TF-IDF + Phonetic" | |
| # ) | |
| # with gr.Accordion("Advanced options", open=False): | |
| # max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS) | |
| # temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7) | |
| # top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9) | |
| # top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50) | |
| # repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1) | |
| # with gr.Column(scale=3): | |
| # gr.Markdown("## Output", elem_id="output-title") | |
| # output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=11, show_copy_button=True) | |
| # with gr.Accordion("(Result.md)", open=False): | |
| # markdown_output = gr.Markdown(label="(Result.Md)") | |
| # model_choice = gr.Radio( | |
| # choices=["Chandra-OCR", "Dots.OCR"], | |
| # label="Select OCR Model", | |
| # value="Chandra-OCR" | |
| # ) | |
| # image_submit.click( | |
| # fn=generate_image, | |
| # inputs=[model_choice,gr.State("Extract medicine or drugs names along with dosage amount or quantity") , image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty, spell_choice], | |
| # outputs=[output, markdown_output] | |
| # ) | |
| # if __name__ == "__main__": | |
| # demo.queue(max_size=50).launch(mcp_server=True, ssr_mode=False, show_error=True) | |