Sayed223 commited on
Commit
cf13548
·
verified ·
1 Parent(s): 13417b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -94
app.py CHANGED
@@ -1,94 +1,94 @@
1
- import torch
2
- from torchvision import transforms
3
- from PIL import Image
4
- from fpdf import FPDF
5
- from datetime import datetime
6
- import gradio as gr
7
- from model import get_model
8
-
9
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
- MODEL_PATH = "ecg_multiclass.pth"
11
- IMG_SIZE = (224, 224)
12
-
13
- DESCRIPTIONS = {
14
- "Normal": "This ECG indicates a normal heart rhythm with no abnormalities detected.",
15
- "Myocardial_Infarction": "This ECG suggests a myocardial infarction (heart attack). Immediate medical attention is recommended.",
16
- "Abnormal_heartbeat": "This ECG shows an abnormal heartbeat pattern indicating possible arrhythmia or heart irregularities."
17
- }
18
-
19
- # Load model checkpoint
20
- checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
21
- classes = checkpoint["classes"]
22
- num_classes = len(classes)
23
-
24
- model = get_model(num_classes=num_classes, weights=None).to(DEVICE)
25
- model.load_state_dict(checkpoint["model_state"])
26
- model.eval()
27
-
28
- transform = transforms.Compose([
29
- transforms.Grayscale(),
30
- transforms.Resize(IMG_SIZE),
31
- transforms.ToTensor(),
32
- transforms.Normalize([0.5], [0.5])
33
- ])
34
-
35
-
36
- def generate_report(image):
37
- if image is None:
38
- return None, "Please upload an image."
39
-
40
- # Convert
41
- img = Image.fromarray(image).convert("L")
42
- input_tensor = transform(img).unsqueeze(0).to(DEVICE)
43
-
44
- # Predict
45
- with torch.no_grad():
46
- output = model(input_tensor)
47
- probabilities = torch.softmax(output, dim=1)[0]
48
- pred_index = torch.argmax(probabilities).item()
49
- pred_class = classes[pred_index]
50
-
51
- # Prepare probability text
52
- prob_text = ""
53
- for i, cls in enumerate(classes):
54
- prob_text += f"{cls}: {probabilities[i] * 100:.2f}%\n"
55
-
56
- # Generate PDF
57
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
58
- pdf_path = f"ecg_report_{timestamp}.pdf"
59
-
60
- pdf = FPDF()
61
- pdf.add_page()
62
- pdf.set_font("Arial", "B", 16)
63
- pdf.cell(0, 10, "ECG Prediction Report", ln=True, align="C")
64
- pdf.ln(10)
65
-
66
- pdf.set_font("Arial", "", 12)
67
- pdf.multi_cell(0, 10, f"Predicted Class: {pred_class}\n")
68
- pdf.multi_cell(0, 8, f"Probabilities:\n{prob_text}")
69
- pdf.multi_cell(0, 8, f"Description:\n{DESCRIPTIONS.get(pred_class, 'No description.')}")
70
- pdf.ln(10)
71
-
72
- # Save uploaded image temporarily for embedding
73
- img_path = f"temp_img_{timestamp}.png"
74
- img.save(img_path)
75
-
76
- pdf.image(img_path, x=30, w=150)
77
- pdf.output(pdf_path)
78
-
79
- return pdf_path, f"Prediction complete: {pred_class}"
80
-
81
-
82
- # Gradio UI
83
- interface = gr.Interface(
84
- fn=generate_report,
85
- inputs=gr.Image(type="numpy", label="Upload ECG Image"),
86
- outputs=[
87
- gr.File(label="Download PDF Report"),
88
- gr.Textbox(label="Status")
89
- ],
90
- title="ECG Classification & PDF Report Generator",
91
- description="Upload an ECG image to get an AI-generated PDF report."
92
- )
93
-
94
- interface.launch()
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from PIL import Image
4
+ from fpdf import FPDF
5
+ from datetime import datetime
6
+ import gradio as gr
7
+ from model import get_model
8
+
9
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+ MODEL_PATH = "ecg_multiclass.pth"
11
+ IMG_SIZE = (224, 224)
12
+
13
+ DESCRIPTIONS = {
14
+ "Normal": "This ECG indicates a normal heart rhythm with no abnormalities detected.",
15
+ "Myocardial_Infarction": "This ECG suggests a myocardial infarction (heart attack). Immediate medical attention is recommended.",
16
+ "Abnormal_heartbeat": "This ECG shows an abnormal heartbeat pattern indicating possible arrhythmia or heart irregularities."
17
+ }
18
+
19
+ # Load model checkpoint
20
+ checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
21
+ classes = checkpoint["classes"]
22
+ num_classes = len(classes)
23
+
24
+ model = get_model(num_classes=num_classes,pretrained=False)
25
+ model.load_state_dict(checkpoint["model_state"])
26
+ model.eval()
27
+
28
+ transform = transforms.Compose([
29
+ transforms.Grayscale(),
30
+ transforms.Resize(IMG_SIZE),
31
+ transforms.ToTensor(),
32
+ transforms.Normalize([0.5], [0.5])
33
+ ])
34
+
35
+
36
+ def generate_report(image):
37
+ if image is None:
38
+ return None, "Please upload an image."
39
+
40
+ # Convert
41
+ img = Image.fromarray(image).convert("L")
42
+ input_tensor = transform(img).unsqueeze(0).to(DEVICE)
43
+
44
+ # Predict
45
+ with torch.no_grad():
46
+ output = model(input_tensor)
47
+ probabilities = torch.softmax(output, dim=1)[0]
48
+ pred_index = torch.argmax(probabilities).item()
49
+ pred_class = classes[pred_index]
50
+
51
+ # Prepare probability text
52
+ prob_text = ""
53
+ for i, cls in enumerate(classes):
54
+ prob_text += f"{cls}: {probabilities[i] * 100:.2f}%\n"
55
+
56
+ # Generate PDF
57
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
58
+ pdf_path = f"ecg_report_{timestamp}.pdf"
59
+
60
+ pdf = FPDF()
61
+ pdf.add_page()
62
+ pdf.set_font("Arial", "B", 16)
63
+ pdf.cell(0, 10, "ECG Prediction Report", ln=True, align="C")
64
+ pdf.ln(10)
65
+
66
+ pdf.set_font("Arial", "", 12)
67
+ pdf.multi_cell(0, 10, f"Predicted Class: {pred_class}\n")
68
+ pdf.multi_cell(0, 8, f"Probabilities:\n{prob_text}")
69
+ pdf.multi_cell(0, 8, f"Description:\n{DESCRIPTIONS.get(pred_class, 'No description.')}")
70
+ pdf.ln(10)
71
+
72
+ # Save uploaded image temporarily for embedding
73
+ img_path = f"temp_img_{timestamp}.png"
74
+ img.save(img_path)
75
+
76
+ pdf.image(img_path, x=30, w=150)
77
+ pdf.output(pdf_path)
78
+
79
+ return pdf_path, f"Prediction complete: {pred_class}"
80
+
81
+
82
+ # Gradio UI
83
+ interface = gr.Interface(
84
+ fn=generate_report,
85
+ inputs=gr.Image(type="numpy", label="Upload ECG Image"),
86
+ outputs=[
87
+ gr.File(label="Download PDF Report"),
88
+ gr.Textbox(label="Status")
89
+ ],
90
+ title="ECG Classification & PDF Report Generator",
91
+ description="Upload an ECG image to get an AI-generated PDF report."
92
+ )
93
+
94
+ interface.launch()