LLDDWW commited on
Commit
76e08e0
ยท
1 Parent(s): c990a0c

Rewrite Gradio app with proper parsing

Browse files
Files changed (1) hide show
  1. app.py +136 -91
app.py CHANGED
@@ -1,113 +1,158 @@
1
- # app.py (HF Space)
2
-
3
-
4
- # dose_per_intake: e.g., "1ํšŒ 1์ •", "1์ •", "5 mL"
5
- m_dose = re.search(r"(1ํšŒ\s*)?(\d+)\s*([๊ฐ€-ํžฃa-zA-Z]+|mL|ml|mg)", t)
6
- dose_per_intake = None
7
- if m_dose:
8
- dose_per_intake = f"{m_dose.group(2)} {m_dose.group(3)}"
9
-
10
-
11
- # drug name (heuristic): token before mg/mL or first uppercase-like word
12
- m_drug = re.search(r"([๊ฐ€-ํžฃA-Za-z]+)\s*(\d+\s*(mg|mL|ml))", t)
13
- drug_name = m_drug.group(1) if m_drug else None
14
-
15
-
16
- return {
17
- "drug_name": drug_name,
18
- "dose_per_intake": dose_per_intake,
19
- "times_per_day": times_per_day,
20
- "time_slots": time_slots or None,
21
- }
22
-
23
-
24
-
25
-
26
- def ocr_and_parse(img) -> Dict[str, Any]:
27
- """Run OCR then parse fields with basic validation."""
28
- # OCR output is a list of dicts with 'generated_text'
29
- raw = ocr(img)[0]["generated_text"]
30
- fields = parse_fields(raw)
31
-
32
-
33
- # basic validation messages
34
- warn = []
35
- if not fields["drug_name"]:
36
- warn.append("์•ฝ ์ด๋ฆ„ ์ธ์‹์ด ๋ถˆํ™•์‹คํ•ฉ๋‹ˆ๋‹ค.")
37
- if not fields["times_per_day"]:
38
- warn.append("1์ผ ํšŸ์ˆ˜๋ฅผ ์ฐพ์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค (์˜ˆ: 1์ผ 3ํšŒ).")
39
-
40
-
41
- return {"raw_text": raw, "fields": fields, "warnings": warn}
42
 
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
 
46
  def render_card(fields: Dict[str, Any]) -> Image.Image:
47
- """Draw a simple schedule card (PNG) from parsed fields."""
48
- W, H = 720, 400
49
- img = Image.new("RGB", (W, H), "white")
50
- d = ImageDraw.Draw(img)
51
- title = "์˜ค๋Š˜ ๋ณต์šฉ ์ผ์ •"
52
- d.rectangle((0, 0, W, 60), fill=(230, 240, 255))
53
- d.text((24, 18), title, fill=(0, 0, 0))
54
 
 
 
 
55
 
56
- y = 90
57
- def line(label, value):
58
- nonlocal y
59
- d.text((24, y), f"{label}", fill=(60, 60, 60))
60
- d.text((180, y), f": {value if value else '-'}", fill=(0, 0, 0))
61
- y += 34
62
 
 
 
 
 
 
 
63
 
64
- line("์•ฝ ์ด๋ฆ„", fields.get("drug_name"))
65
- line("1ํšŒ ์šฉ๋Ÿ‰", fields.get("dose_per_intake"))
66
- line("1์ผ ํšŸ์ˆ˜", fields.get("times_per_day"))
67
- slots = ", ".join(fields.get("time_slots") or [])
68
- line("์‹œ๊ฐ„๋Œ€", slots if slots else None)
69
 
 
 
70
 
71
- d.text((24, H-60), "โ€ป ์˜๋ฃŒ์ง„ ์ฒ˜๋ฐฉ ์šฐ์„ , ๋ณธ ์•ฑ์€ ์•ˆ๋‚ด์šฉ์ž…๋‹ˆ๋‹ค.", fill=(120, 120, 120))
72
- return img
 
73
 
74
 
 
 
 
 
 
 
 
 
 
75
 
76
 
77
- def run_pipeline(image):
78
- if image is None:
79
- return "์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•˜์„ธ์š”.", None, None
80
- out = ocr_and_parse(image)
81
- card = render_card(out["fields"])
82
- csv_row = to_csv_row(out)
83
- return json.dumps(out, ensure_ascii=False, indent=2), card, csv_row
84
 
85
-
86
-
87
-
88
- def to_csv_row(out: Dict[str, Any]) -> str:
89
- f = out["fields"]
90
- row = [
91
- f.get("drug_name") or "",
92
- f.get("dose_per_intake") or "",
93
- str(f.get("times_per_day") or ""),
94
- ";".join(f.get("time_slots") or []),
95
- ]
96
- return ",".join(row)
97
 
98
 
99
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
100
- gr.Markdown("# MedCard-KR ยท ์•ฝ๋ด‰ํˆฌ OCR โ†’ ๋ณต์šฉ ์ผ์ • ์นด๋“œ")
101
- with gr.Row():
102
- with gr.Column():
103
- img_in = gr.Image(type="pil", label="์•ฝ ๋ด‰ํˆฌ/๋ผ๋ฒจ ์‚ฌ์ง„")
104
- btn = gr.Button("์ธ์‹ & ์นด๋“œ ์ƒ์„ฑ", variant="primary")
105
- csv = gr.Textbox(label="CSV(์•ฝ๋ช…,1ํšŒ์šฉ๋Ÿ‰,1์ผํšŸ์ˆ˜,์‹œ๊ฐ„๋Œ€)")
106
- with gr.Column():
107
- json_out = gr.Code(label="์ธ์‹ ๊ฒฐ๊ณผ(JSON)")
108
- card = gr.Image(type="pil", label="์ผ์ • ์นด๋“œ(๋ฏธ๋ฆฌ๋ณด๊ธฐ)")
109
- btn.click(run_pipeline, inputs=img_in, outputs=[json_out, card, csv])
110
 
111
 
112
  if __name__ == "__main__":
113
- demo.queue().launch()
 
1
+ import json
2
+ import re
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ import gradio as gr
6
+ from PIL import Image, ImageDraw
7
+ from transformers import pipeline
8
+
9
+ # --- OCR pipeline ---------------------------------------------------------
10
+ # We use a light-weight printed-text OCR model that works well for receipts/labels.
11
+ ocr = pipeline("image-to-text", model="microsoft/trocr-base-printed")
12
+
13
+ # Korean keywords describing time slots on prescription labels.
14
+ TIME_KEYWORDS = [
15
+ "์•„์นจ",
16
+ "์ ์‹ฌ",
17
+ "์ €๋…",
18
+ "์ทจ์นจ",
19
+ "์ž๊ธฐ",
20
+ "์‹์ „",
21
+ "์‹ํ›„",
22
+ "์‹๊ฐ„",
23
+ "๊ธฐ์ƒ",
24
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
 
27
+ def _extract_time_slots(text: str) -> List[str]:
28
+ slots = []
29
+ for kw in TIME_KEYWORDS:
30
+ if kw in text:
31
+ slots.append(kw)
32
+ # Also capture explicit times like 08:00 ํ˜น์€ 8์‹œ
33
+ for match in re.findall(r"(\d{1,2}[:์‹œ]\d{0,2})", text):
34
+ norm = match.replace("์‹œ", ":")
35
+ if norm.endswith(":"):
36
+ norm += "00"
37
+ if norm not in slots:
38
+ slots.append(norm)
39
+ return slots
40
+
41
+
42
+ def parse_fields(raw: str) -> Dict[str, Any]:
43
+ """Extract drug name and dosage information from OCR text."""
44
+ text = raw.replace("\n", " ")
45
+ text = re.sub(r"\s+", " ", text)
46
+
47
+ # 1) ์•ฝ ์ด๋ฆ„: ๋‹จ์–ด + ์šฉ๋Ÿ‰ ํŒจํ„ด ์ฃผ๋ณ€์—์„œ ์ฐพ๊ธฐ
48
+ drug_name: Optional[str] = None
49
+ drug_match = re.search(r"([๊ฐ€-ํžฃA-Za-z]+)\s*(\d+)\s*(mg|mL|ML|์ •)", text)
50
+ if drug_match:
51
+ drug_name = drug_match.group(1)
52
+ else:
53
+ fallback = re.search(r"([๊ฐ€-ํžฃA-Za-z]{2,})", text)
54
+ drug_name = fallback.group(1) if fallback else None
55
+
56
+ # 2) 1ํšŒ ์šฉ๋Ÿ‰: "1ํšŒ 1์ •", "1์ •", "5 mL" ๋“ฑ
57
+ dose_per_intake: Optional[str] = None
58
+ dose_match = re.search(r"(1ํšŒ\s*)?(\d+[\./]?\d*)\s*([๊ฐ€-ํžฃA-Za-z]+|mL|ml|mg|์ •)", text)
59
+ if dose_match:
60
+ dose_per_intake = f"{dose_match.group(2)} {dose_match.group(3)}".strip()
61
+
62
+ # 3) 1์ผ ๋ณต์šฉ ํšŸ์ˆ˜: "1์ผ 3ํšŒ", "ํ•˜๋ฃจ 2ํšŒ"
63
+ times_per_day: Optional[int] = None
64
+ times_match = re.search(r"(?:1์ผ|ํ•˜๋ฃจ)\s*(\d+)\s*ํšŒ", text)
65
+ if times_match:
66
+ times_per_day = int(times_match.group(1))
67
+
68
+ # 4) ์‹œ๊ฐ„๋Œ€ ํ‚ค์›Œ๋“œ/์‹œ๊ฐ ์ถ”์ถœ
69
+ time_slots = _extract_time_slots(text)
70
+
71
+ return {
72
+ "drug_name": drug_name,
73
+ "dose_per_intake": dose_per_intake,
74
+ "times_per_day": times_per_day,
75
+ "time_slots": time_slots or None,
76
+ }
77
+
78
+
79
+ def ocr_and_parse(image: Image.Image) -> Dict[str, Any]:
80
+ raw_text = ocr(image)[0]["generated_text"]
81
+ fields = parse_fields(raw_text)
82
+
83
+ warnings: List[str] = []
84
+ if not fields["drug_name"]:
85
+ warnings.append("์•ฝ ์ด๋ฆ„ ์ธ์‹์ด ๋ถˆํ™•์‹คํ•ฉ๋‹ˆ๋‹ค.")
86
+ if not fields["times_per_day"]:
87
+ warnings.append("1์ผ ํšŸ์ˆ˜๋ฅผ ์ฐพ์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค (์˜ˆ: 1์ผ 3ํšŒ).")
88
+
89
+ return {"raw_text": raw_text, "fields": fields, "warnings": warnings}
90
 
91
 
92
  def render_card(fields: Dict[str, Any]) -> Image.Image:
93
+ width, height = 720, 400
94
+ img = Image.new("RGB", (width, height), "white")
95
+ draw = ImageDraw.Draw(img)
 
 
 
 
96
 
97
+ header_text = "์˜ค๋Š˜ ๋ณต์šฉ ์ผ์ •"
98
+ draw.rectangle((0, 0, width, 60), fill=(230, 240, 255))
99
+ draw.text((24, 18), header_text, fill=(0, 0, 0))
100
 
101
+ y = 90
 
 
 
 
 
102
 
103
+ def add_line(label: str, value: Optional[str]):
104
+ nonlocal y
105
+ draw.text((24, y), label, fill=(60, 60, 60))
106
+ display = value if value else "-"
107
+ draw.text((180, y), f": {display}", fill=(0, 0, 0))
108
+ y += 34
109
 
110
+ add_line("์•ฝ ์ด๋ฆ„", fields.get("drug_name"))
111
+ add_line("1ํšŒ ์šฉ๋Ÿ‰", fields.get("dose_per_intake"))
112
+ add_line("1์ผ ํšŸ์ˆ˜", str(fields.get("times_per_day") or ""))
 
 
113
 
114
+ slots = fields.get("time_slots") or []
115
+ add_line("์‹œ๊ฐ„๋Œ€", ", ".join(slots) if slots else None)
116
 
117
+ footer = "โ€ป ์˜๋ฃŒ์ง„ ์ฒ˜๋ฐฉ์ด ์šฐ์„ ์ด๋ฉฐ, ๋ณธ ์•ฑ์€ ์ฐธ๊ณ ์šฉ์ž…๋‹ˆ๋‹ค."
118
+ draw.text((24, height - 60), footer, fill=(120, 120, 120))
119
+ return img
120
 
121
 
122
+ def to_csv_row(output: Dict[str, Any]) -> str:
123
+ fields = output["fields"]
124
+ row = [
125
+ fields.get("drug_name") or "",
126
+ fields.get("dose_per_intake") or "",
127
+ str(fields.get("times_per_day") or ""),
128
+ ";".join(fields.get("time_slots") or []),
129
+ ]
130
+ return ",".join(row)
131
 
132
 
133
+ def run_pipeline(image: Optional[Image.Image]):
134
+ if image is None:
135
+ return "์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•˜์„ธ์š”.", None, None
 
 
 
 
136
 
137
+ output = ocr_and_parse(image)
138
+ card = render_card(output["fields"])
139
+ csv_row = to_csv_row(output)
140
+ json_text = json.dumps(output, ensure_ascii=False, indent=2)
141
+ return json_text, card, csv_row
 
 
 
 
 
 
 
142
 
143
 
144
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
145
+ gr.Markdown("# MedCard-KR ยท ์•ฝ๋ด‰ํˆฌ OCR โ†’ ๋ณต์šฉ ์ผ์ • ์นด๋“œ")
146
+ with gr.Row():
147
+ with gr.Column():
148
+ img_in = gr.Image(type="pil", label="์•ฝ ๋ด‰ํˆฌ/๋ผ๋ฒจ ์‚ฌ์ง„")
149
+ btn = gr.Button("์ธ์‹ & ์นด๋“œ ์ƒ์„ฑ", variant="primary")
150
+ csv_box = gr.Textbox(label="CSV(์•ฝ๋ช…,1ํšŒ์šฉ๋Ÿ‰,1์ผํšŸ์ˆ˜,์‹œ๊ฐ„๋Œ€)")
151
+ with gr.Column():
152
+ json_out = gr.Code(label="์ธ์‹ ๊ฒฐ๊ณผ(JSON)")
153
+ card_out = gr.Image(type="pil", label="์ผ์ • ์นด๋“œ(๋ฏธ๋ฆฌ๋ณด๊ธฐ)")
154
+ btn.click(run_pipeline, inputs=img_in, outputs=[json_out, card_out, csv_box])
155
 
156
 
157
  if __name__ == "__main__":
158
+ demo.queue().launch()