Sayed223 commited on
Commit
72647b2
·
verified ·
1 Parent(s): 16e3031

Upload 5 files

Browse files
Files changed (5) hide show
  1. README.md +8 -12
  2. app.py +94 -0
  3. ecg_multiclass.pth +3 -0
  4. model.py +17 -0
  5. requirements.txt +5 -0
README.md CHANGED
@@ -1,12 +1,8 @@
1
- ---
2
- title: Ecg Ai Classifier
3
- emoji: 🐨
4
- colorFrom: indigo
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 5.49.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ # ECG Classification AI
2
+
3
+ Upload an ECG image and receive:
4
+ - Predicted class
5
+ - Probability scores
6
+ - Automatically generated PDF report
7
+
8
+ Model: Custom CNN (3-class)
 
 
 
 
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from PIL import Image
4
+ from fpdf import FPDF
5
+ from datetime import datetime
6
+ import gradio as gr
7
+ from model import get_model
8
+
9
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+ MODEL_PATH = "ecg_multiclass.pth"
11
+ IMG_SIZE = (224, 224)
12
+
13
+ DESCRIPTIONS = {
14
+ "Normal": "This ECG indicates a normal heart rhythm with no abnormalities detected.",
15
+ "Myocardial_Infarction": "This ECG suggests a myocardial infarction (heart attack). Immediate medical attention is recommended.",
16
+ "Abnormal_heartbeat": "This ECG shows an abnormal heartbeat pattern indicating possible arrhythmia or heart irregularities."
17
+ }
18
+
19
+ # Load model checkpoint
20
+ checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
21
+ classes = checkpoint["classes"]
22
+ num_classes = len(classes)
23
+
24
+ model = get_model(num_classes=num_classes, weights=None).to(DEVICE)
25
+ model.load_state_dict(checkpoint["model_state"])
26
+ model.eval()
27
+
28
+ transform = transforms.Compose([
29
+ transforms.Grayscale(),
30
+ transforms.Resize(IMG_SIZE),
31
+ transforms.ToTensor(),
32
+ transforms.Normalize([0.5], [0.5])
33
+ ])
34
+
35
+
36
+ def generate_report(image):
37
+ if image is None:
38
+ return None, "Please upload an image."
39
+
40
+ # Convert
41
+ img = Image.fromarray(image).convert("L")
42
+ input_tensor = transform(img).unsqueeze(0).to(DEVICE)
43
+
44
+ # Predict
45
+ with torch.no_grad():
46
+ output = model(input_tensor)
47
+ probabilities = torch.softmax(output, dim=1)[0]
48
+ pred_index = torch.argmax(probabilities).item()
49
+ pred_class = classes[pred_index]
50
+
51
+ # Prepare probability text
52
+ prob_text = ""
53
+ for i, cls in enumerate(classes):
54
+ prob_text += f"{cls}: {probabilities[i] * 100:.2f}%\n"
55
+
56
+ # Generate PDF
57
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
58
+ pdf_path = f"ecg_report_{timestamp}.pdf"
59
+
60
+ pdf = FPDF()
61
+ pdf.add_page()
62
+ pdf.set_font("Arial", "B", 16)
63
+ pdf.cell(0, 10, "ECG Prediction Report", ln=True, align="C")
64
+ pdf.ln(10)
65
+
66
+ pdf.set_font("Arial", "", 12)
67
+ pdf.multi_cell(0, 10, f"Predicted Class: {pred_class}\n")
68
+ pdf.multi_cell(0, 8, f"Probabilities:\n{prob_text}")
69
+ pdf.multi_cell(0, 8, f"Description:\n{DESCRIPTIONS.get(pred_class, 'No description.')}")
70
+ pdf.ln(10)
71
+
72
+ # Save uploaded image temporarily for embedding
73
+ img_path = f"temp_img_{timestamp}.png"
74
+ img.save(img_path)
75
+
76
+ pdf.image(img_path, x=30, w=150)
77
+ pdf.output(pdf_path)
78
+
79
+ return pdf_path, f"Prediction complete: {pred_class}"
80
+
81
+
82
+ # Gradio UI
83
+ interface = gr.Interface(
84
+ fn=generate_report,
85
+ inputs=gr.Image(type="numpy", label="Upload ECG Image"),
86
+ outputs=[
87
+ gr.File(label="Download PDF Report"),
88
+ gr.Textbox(label="Status")
89
+ ],
90
+ title="ECG Classification & PDF Report Generator",
91
+ description="Upload an ECG image to get an AI-generated PDF report."
92
+ )
93
+
94
+ interface.launch()
ecg_multiclass.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9bf90aa76d576a335b5c6c3c3f49c374d6d5ba6775c5db80678e63612a868739
3
+ size 44765643
model.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model.py
2
+ import torch.nn as nn
3
+ from torchvision.models import resnet18
4
+
5
+ def get_model(num_classes, pretrained=True):
6
+ """
7
+ Returns a CNN model adapted for grayscale ECG images
8
+ """
9
+ model = resnet18(pretrained=pretrained)
10
+
11
+ # Change first layer to accept 1-channel input (grayscale)
12
+ model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
13
+
14
+ # Change the output layer for our number of classes
15
+ model.fc = nn.Linear(model.fc.in_features, num_classes)
16
+
17
+ return model
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ pillow
4
+ fpdf
5
+ gradio