Sayed223's picture
Update app.py
a97991a verified
import tempfile
from datetime import datetime
from pathlib import Path

import gradio as gr
import torch
from fpdf import FPDF
from PIL import Image
from torchvision import transforms

from model import get_model
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# Checkpoint file read at start-up.
MODEL_PATH = "ecg_multiclass.pth"

# Target size handed to the Resize transform.
IMG_SIZE = (224, 224)

# Text printed in the report for each predicted class name.
DESCRIPTIONS = dict(
    Normal=(
        "This ECG indicates a normal heart rhythm with no abnormalities detected."
    ),
    Myocardial_Infarction=(
        "This ECG suggests a myocardial infarction (heart attack). Immediate medical attention is recommended."
    ),
    Abnormal_heartbeat=(
        "This ECG shows an abnormal heartbeat pattern indicating possible arrhythmia or heart irregularities."
    ),
)
# Load the model checkpoint. The checkpoint is read as a mapping with a
# "classes" entry (the class names, indexed by model output position) and a
# "model_state" entry (the weights).
# NOTE: torch.load deserializes the file; only load checkpoints from a
# trusted source.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
classes = checkpoint["classes"]
num_classes = len(classes)

model = get_model(num_classes=num_classes, pretrained=False)
model.load_state_dict(checkpoint["model_state"])
# Inputs are moved to DEVICE before inference, so the model's parameters must
# live there as well; otherwise a CUDA host fails with a device mismatch.
model.to(DEVICE)
model.eval()

# Preprocessing applied to every uploaded image before it reaches the model:
# single channel, fixed size, tensor conversion, then normalisation with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
def generate_report(image):
    """Classify an uploaded ECG image and write the result to a PDF report.

    Args:
        image: Image data accepted by ``PIL.Image.fromarray`` (the Gradio
            input is configured with ``type="numpy"``), or ``None`` when
            nothing was uploaded.

    Returns:
        A ``(pdf_path, status)`` tuple. ``pdf_path`` is the path of the
        generated PDF as a string, or ``None`` when no image was supplied.
        ``status`` is a short message for display.
    """
    if image is None:
        return None, "Please upload an image."

    # Convert to a single-channel image and build a batch of one.
    img = Image.fromarray(image).convert("L")
    input_tensor = transform(img).unsqueeze(0).to(DEVICE)

    # Predict
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = torch.softmax(output, dim=1)[0]
    pred_index = torch.argmax(probabilities).item()
    pred_class = classes[pred_index]

    # One "name: percent" line per class, each terminated by a newline.
    percentages = (probabilities * 100).tolist()
    prob_text = "".join(
        f"{cls}: {pct:.2f}%\n" for cls, pct in zip(classes, percentages)
    )

    # The timestamp only has second resolution, so each report gets its own
    # freshly created directory; concurrent requests cannot overwrite each
    # other's output even when their file names coincide.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    report_dir = Path(tempfile.mkdtemp(prefix="ecg_report_"))
    pdf_path = report_dir / f"ecg_report_{timestamp}.pdf"

    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", "B", 16)
    pdf.cell(0, 10, "ECG Prediction Report", ln=True, align="C")
    pdf.ln(10)
    pdf.set_font("Arial", "", 12)
    pdf.multi_cell(0, 10, f"Predicted Class: {pred_class}\n")
    pdf.multi_cell(0, 8, f"Probabilities:\n{prob_text}")
    pdf.multi_cell(0, 8, f"Description:\n{DESCRIPTIONS.get(pred_class, 'No description.')}")
    pdf.ln(10)

    # The image has to exist on disk to be embedded. Keep it in a scratch
    # directory that is removed once the PDF has been written, so no
    # temporary PNG is left behind per request.
    with tempfile.TemporaryDirectory() as scratch_dir:
        img_path = Path(scratch_dir) / f"temp_img_{timestamp}.png"
        img.save(img_path)
        pdf.image(str(img_path), x=30, w=150)
        pdf.output(str(pdf_path))

    return str(pdf_path), f"Prediction complete: {pred_class}"
# Gradio UI: a single image upload feeding generate_report, whose two return
# values are shown as a downloadable file and a status text box.
ecg_input = gr.Image(type="numpy", label="Upload ECG Image")
report_outputs = [
    gr.File(label="Download PDF Report"),
    gr.Textbox(label="Status"),
]

interface = gr.Interface(
    fn=generate_report,
    inputs=ecg_input,
    outputs=report_outputs,
    title="ECG Classification & PDF Report Generator",
    description="Upload an ECG image to get an AI-generated PDF report.",
)

interface.launch()