import torch
from torchvision import transforms
from PIL import Image
from fpdf import FPDF
from datetime import datetime
import gradio as gr
from model import get_model

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "ecg_multiclass.pth"
IMG_SIZE = (224, 224)

DESCRIPTIONS = {
    "Normal": "This ECG indicates a normal heart rhythm with no abnormalities detected.",
    "Myocardial_Infarction": "This ECG suggests a myocardial infarction (heart attack). Immediate medical attention is recommended.",
    "Abnormal_heartbeat": "This ECG shows an abnormal heartbeat pattern indicating possible arrhythmia or heart irregularities."
}
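# Sketch (hypothetical helper, not part of the original app): generate_report()
# looks descriptions up with DESCRIPTIONS.get(pred_class, ...), so a class name
# stored in the checkpoint that differs from these keys (for example only in
# case) silently falls back to the default text. Calling a check like this one
# on the loaded class list would surface such mismatches at startup.
def find_undescribed_classes(class_names, descriptions):
    """Return, in order, the class names that have no entry in `descriptions`."""
    return [name for name in class_names if name not in descriptions]

# find_undescribed_classes(["Normal", "normal"], DESCRIPTIONS)  ->  ["normal"]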

# Load the checkpoint; it must be a dict holding the class names ("classes") and the weights ("model_state")
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
classes = checkpoint["classes"]
num_classes = len(classes)

model = get_model(num_classes=num_classes, weights=None).to(DEVICE)
model.load_state_dict(checkpoint["model_state"])
model.eval()
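# Sketch (hypothetical helper): the loading code above assumes the checkpoint is
# a dict with a non-empty "classes" list and a "model_state" state dict. Calling
# a check like this one on the loaded object, before building the model, would
# replace a bare KeyError at startup with a clearer message. The required keys
# come from the code above; the error wording is illustrative.
REQUIRED_CHECKPOINT_KEYS = ("classes", "model_state")


def check_checkpoint_layout(ckpt):
    """Raise ValueError if `ckpt` lacks what this app reads; else return the class count."""
    if not isinstance(ckpt, dict):
        raise ValueError(f"expected a dict checkpoint, got {type(ckpt).__name__}")
    missing = [key for key in REQUIRED_CHECKPOINT_KEYS if key not in ckpt]
    if missing:
        raise ValueError(f"checkpoint is missing keys: {missing}")
    if not ckpt["classes"]:
        raise ValueError("checkpoint 'classes' list is empty")
    return len(ckpt["classes"])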

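# Preprocessing: one grayscale channel, resize to IMG_SIZE, scale pixels to [0, 1], then normalize to [-1, 1]
transform = transforms.Compose([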
    transforms.Grayscale(),
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
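# Sketch (illustration only): transforms.ToTensor() scales 8-bit pixel values to
# [0.0, 1.0], and transforms.Normalize([0.5], [0.5]) then computes
# (x - mean) / std = (x - 0.5) / 0.5 for the single grayscale channel, so inputs
# reach the model in [-1.0, 1.0]. This hypothetical pure-Python helper reproduces
# that arithmetic for one 8-bit pixel value; it does not replace the pipeline above.
def normalized_pixel(value_0_255, mean=0.5, std=0.5):
    """Map an 8-bit grayscale value the way ToTensor + Normalize(mean, std) would."""
    x = value_0_255 / 255.0   # ToTensor: uint8 -> float in [0, 1]
    return (x - mean) / std   # Normalize: (x - mean) / std

# normalized_pixel(0) -> -1.0, normalized_pixel(255) -> 1.0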


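def generate_report(image):
    """Classify an uploaded ECG image and write a PDF report.

    Returns a (pdf_path, status_message) tuple matching the two Gradio outputs.
    """
    if image is None:
        return None, "Please upload an image."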

    # Convert the uploaded numpy array to a grayscale PIL image, preprocess it, and add a batch dimension
    img = Image.fromarray(image).convert("L")
    input_tensor = transform(img).unsqueeze(0).to(DEVICE)

    # Predict
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = torch.softmax(output, dim=1)[0]
        pred_index = torch.argmax(probabilities).item()
        pred_class = classes[pred_index]

    # Prepare probability text
    prob_text = ""
    for i, cls in enumerate(classes):
        prob_text += f"{cls}: {probabilities[i].item() * 100:.2f}%\n"

    # Generate PDF
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    pdf_path = f"ecg_report_{timestamp}.pdf"

    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", "B", 16)
    pdf.cell(0, 10, "ECG Prediction Report", ln=True, align="C")
    pdf.ln(10)

    pdf.set_font("Arial", "", 12)
    pdf.multi_cell(0, 10, f"Predicted Class: {pred_class}\n")
    pdf.multi_cell(0, 8, f"Probabilities:\n{prob_text}")
    pdf.multi_cell(0, 8, f"Description:\n{DESCRIPTIONS.get(pred_class, 'No description.')}")
    pdf.ln(10)

    # Save the grayscale image as a PNG so FPDF can embed it (this file is not deleted afterwards)
    img_path = f"temp_img_{timestamp}.png"
    img.save(img_path)

    pdf.image(img_path, x=30, w=150)
    pdf.output(pdf_path)

    return pdf_path, f"Prediction complete: {pred_class}"
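# Sketch (illustration only): generate_report() turns the model's raw outputs
# (logits) into probabilities with torch.softmax(output, dim=1) and picks the
# class with torch.argmax. This hypothetical pure-Python function does the same
# for one row of logits, p_i = exp(z_i) / sum_j exp(z_j), after subtracting the
# largest logit; that shift leaves the probabilities unchanged but keeps exp()
# from overflowing on large inputs.
import math


def softmax_argmax(logits, class_names):
    """Return (probabilities, predicted_class) for one row of logits."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]  # shifted for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)  # first index of the maximum
    return probs, class_names[best]

# softmax_argmax([1.0, 1.0], ["A", "B"])  ->  ([0.5, 0.5], "A")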


# Gradio UI
interface = gr.Interface(
    fn=generate_report,
    inputs=gr.Image(type="numpy", label="Upload ECG Image"),
    outputs=[
        gr.File(label="Download PDF Report"),
        gr.Textbox(label="Status")
    ],
    title="ECG Classification & PDF Report Generator",
    description="Upload an ECG image to get an AI-generated PDF report."
)
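# Sketch (hypothetical helpers, not used by the app as written): generate_report()
# names its files with a one-second timestamp in the working directory, so two
# requests within the same second write to the same paths, and the temp_img_*.png
# files are never removed. Assuming both are undesirable, tempfile.mkstemp() gives
# collision-free names, and a context manager can delete the PNG once the PDF has
# been written. The "ecg_report_" prefix matches the original filenames; the
# helper names are illustrative.
import os
import tempfile
from contextlib import contextmanager
from datetime import datetime


def unique_report_path(directory=".", prefix="ecg_report_"):
    """Create an empty, uniquely named .pdf file in `directory` and return its path."""
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    fd, path = tempfile.mkstemp(prefix=f"{prefix}{stamp}_", suffix=".pdf", dir=directory)
    os.close(fd)  # mkstemp returns an open descriptor; the PDF writer reopens the path
    return path


@contextmanager
def temporary_png_path():
    """Yield a unique .png path and delete that file on exit, even after an error."""
    fd, path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        yield path
    finally:
        if os.path.exists(path):
            os.remove(path)

# Possible use inside generate_report():
#     pdf_path = unique_report_path()
#     with temporary_png_path() as img_path:
#         img.save(img_path)
#         pdf.image(img_path, x=30, w=150)
#         pdf.output(pdf_path)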

if __name__ == "__main__":
    interface.launch()