"""Gradio app that classifies an ECG image and produces a downloadable PDF report."""

import os
import tempfile
import uuid
from datetime import datetime

import gradio as gr
import torch
from fpdf import FPDF
from PIL import Image
from torchvision import transforms

from model import get_model

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "ecg_multiclass.pth"
IMG_SIZE = (224, 224)

# Human-readable explanation per predicted class name. The report looks these
# up with a fallback, so a checkpoint class missing here gets "No description.".
DESCRIPTIONS = {
    "Normal": "This ECG indicates a normal heart rhythm with no abnormalities detected.",
    "Myocardial_Infarction": "This ECG suggests a myocardial infarction (heart attack). Immediate medical attention is recommended.",
    "Abnormal_heartbeat": "This ECG shows an abnormal heartbeat pattern indicating possible arrhythmia or heart irregularities.",
}

# Load model checkpoint.
# NOTE(review): torch.load unpickles the file; only point MODEL_PATH at a
# trusted checkpoint.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
classes = checkpoint["classes"]
num_classes = len(classes)
model = get_model(num_classes=num_classes, pretrained=False)
model.load_state_dict(checkpoint["model_state"])
# get_model() is not told about DEVICE, and load_state_dict() copies values
# into the existing parameters without moving them. The input tensor is sent
# to DEVICE, so the weights must be moved there as well or CUDA inference
# fails with a device mismatch.
model.to(DEVICE)
model.eval()

# Preprocessing: single-channel, fixed size, scaled to roughly [-1, 1].
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])


def _predict(img):
    """Classify a PIL image and return (predicted class, per-class percentages).

    The percentages are Python floats ordered like ``classes``.
    """
    input_tensor = transform(img).unsqueeze(0).to(DEVICE)  # add batch dimension
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = torch.softmax(output, dim=1)[0]
    pred_index = torch.argmax(probabilities).item()
    return classes[pred_index], (probabilities * 100).tolist()


def _write_pdf(pdf_path, img, pred_class, prob_text):
    """Write the report to ``pdf_path``: prediction, probabilities, description, image."""
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", "B", 16)
    pdf.cell(0, 10, "ECG Prediction Report", ln=True, align="C")
    pdf.ln(10)

    pdf.set_font("Arial", "", 12)
    pdf.multi_cell(0, 10, f"Predicted Class: {pred_class}\n")
    pdf.multi_cell(0, 8, f"Probabilities:\n{prob_text}")
    pdf.multi_cell(0, 8, f"Description:\n{DESCRIPTIONS.get(pred_class, 'No description.')}")
    pdf.ln(10)

    # The image is handed to FPDF as a file path. Use a uniquely named
    # temporary PNG so concurrent requests never share it, and always delete
    # it so nothing is left on disk.
    fd, img_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        img.save(img_path)
        pdf.image(img_path, x=30, w=150)
        pdf.output(pdf_path)
    finally:
        os.remove(img_path)


def generate_report(image):
    """Classify an uploaded ECG image and build a PDF report for it.

    Args:
        image: The upload as delivered by ``gr.Image(type="numpy")``, or None
            when nothing was uploaded.

    Returns:
        A ``(pdf_path, status_message)`` pair; ``pdf_path`` is None when no
        image was given.
    """
    if image is None:
        return None, "Please upload an image."

    # Convert
    img = Image.fromarray(image).convert("L")

    # Predict
    pred_class, percentages = _predict(img)
    prob_text = "".join(
        f"{cls}: {pct:.2f}%\n" for cls, pct in zip(classes, percentages)
    )

    # Generate PDF. The random suffix keeps two reports created within the
    # same second from overwriting each other.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    pdf_path = f"ecg_report_{timestamp}_{uuid.uuid4().hex[:8]}.pdf"
    _write_pdf(pdf_path, img, pred_class, prob_text)

    return pdf_path, f"Prediction complete: {pred_class}"


# Gradio UI
interface = gr.Interface(
    fn=generate_report,
    inputs=gr.Image(type="numpy", label="Upload ECG Image"),
    outputs=[
        gr.File(label="Download PDF Report"),
        gr.Textbox(label="Status"),
    ],
    title="ECG Classification & PDF Report Generator",
    description="Upload an ECG image to get an AI-generated PDF report.",
)

interface.launch()