Spaces:
Sleeping
Sleeping
import os
import tempfile
from datetime import datetime

import gradio as gr
import torch
from fpdf import FPDF
from PIL import Image
from torchvision import transforms

from model import get_model
# Inference device: the GPU when PyTorch can see one, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Checkpoint file loaded once at start-up.
MODEL_PATH = "ecg_multiclass.pth"

# Target size handed to transforms.Resize for every uploaded image.
IMG_SIZE = (224, 224)

# Report text for each predicted class name. Keys are looked up with the class
# names stored in the checkpoint; names missing here get a fallback text.
DESCRIPTIONS = dict(
    Normal=(
        "This ECG indicates a normal heart rhythm with no abnormalities detected."
    ),
    Myocardial_Infarction=(
        "This ECG suggests a myocardial infarction (heart attack). "
        "Immediate medical attention is recommended."
    ),
    Abnormal_heartbeat=(
        "This ECG shows an abnormal heartbeat pattern indicating possible "
        "arrhythmia or heart irregularities."
    ),
)
# Load the trained checkpoint: a dict holding the class names under "classes"
# (indexed by model output position) and the network weights under
# "model_state". Tensors are remapped onto DEVICE while loading.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
classes = checkpoint["classes"]
num_classes = len(classes)

model = get_model(num_classes=num_classes, pretrained=False)
model.load_state_dict(checkpoint["model_state"])
# load_state_dict copies values into the existing parameters in place and does
# not move them, so place the model on DEVICE explicitly (a no-op if get_model
# already did). Otherwise inference fails on CUDA hosts, where the input tensor
# is sent to the GPU.
model.to(DEVICE)
# Inference mode for layers such as dropout / batch-norm.
model.eval()
# Preprocessing applied to every uploaded image before inference:
# 1. force a single channel;
# 2. resize to IMG_SIZE;
# 3. convert to a float tensor in [0, 1];
# 4. normalize with mean 0.5 / std 0.5, which maps values into [-1, 1].
_preprocess_steps = [
    transforms.Grayscale(),
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
]
transform = transforms.Compose(_preprocess_steps)
def _predict(img):
    """Classify a PIL image and return ``(pred_class, probabilities)``.

    ``probabilities`` is a list of Python floats in the same order as ``classes``.
    """
    # Add a batch dimension of 1 and put the input on the inference device.
    input_tensor = transform(img).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        output = model(input_tensor)
    # Softmax over the class dimension, then keep the single batch entry.
    probs = torch.softmax(output, dim=1)[0]
    pred_index = int(torch.argmax(probs))
    # Copy to host memory once so later formatting never touches the device.
    return classes[pred_index], probs.cpu().tolist()


def _format_probabilities(probabilities):
    """Return one newline-terminated ``"<class>: <percent>%"`` line per class."""
    return "".join(
        f"{cls}: {prob * 100:.2f}%\n" for cls, prob in zip(classes, probabilities)
    )


def _write_pdf(img, pred_class, prob_text, pdf_path):
    """Write the one-page prediction report to ``pdf_path``, embedding ``img``."""
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", "B", 16)
    pdf.cell(0, 10, "ECG Prediction Report", ln=True, align="C")
    pdf.ln(10)
    pdf.set_font("Arial", "", 12)
    pdf.multi_cell(0, 10, f"Predicted Class: {pred_class}\n")
    pdf.multi_cell(0, 8, f"Probabilities:\n{prob_text}")
    pdf.multi_cell(0, 8, f"Description:\n{DESCRIPTIONS.get(pred_class, 'No description.')}")
    pdf.ln(10)

    # pdf.image() is given a file path. Write the image to a uniquely named
    # temporary PNG so concurrent requests cannot clash. Always delete it so
    # files do not pile up on disk.
    fd, img_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        img.save(img_path)
        pdf.image(img_path, x=30, w=150)
        pdf.output(pdf_path)
    finally:
        os.remove(img_path)


def generate_report(image):
    """Classify an uploaded ECG image and build a downloadable PDF report.

    Args:
        image: The upload as a NumPy array (from ``gr.Image(type="numpy")``),
            or ``None`` when nothing was uploaded.

    Returns:
        A ``(pdf_path, status_message)`` tuple. ``pdf_path`` is a relative path
        (current working directory), or ``None`` when no image was provided.
    """
    if image is None:
        return None, "Please upload an image."

    # The preprocessing pipeline is single-channel; PIL mode "L" is grayscale.
    img = Image.fromarray(image).convert("L")
    pred_class, probabilities = _predict(img)
    prob_text = _format_probabilities(probabilities)

    # Microsecond resolution keeps reports for requests made within the same
    # second from overwriting each other.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    pdf_path = f"ecg_report_{timestamp}.pdf"
    _write_pdf(img, pred_class, prob_text, pdf_path)
    return pdf_path, f"Prediction complete: {pred_class}"
# Gradio UI: one image upload in; a downloadable PDF and a status line out.
interface = gr.Interface(
    fn=generate_report,
    inputs=gr.Image(type="numpy", label="Upload ECG Image"),
    outputs=[
        gr.File(label="Download PDF Report"),
        gr.Textbox(label="Status"),
    ],
    title="ECG Classification & PDF Report Generator",
    description="Upload an ECG image to get an AI-generated PDF report.",
)

# Start the web server only when this file is run as a script, so importing
# the module (e.g. from tests or another app) does not launch the UI.
if __name__ == "__main__":
    interface.launch()