Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from torch import nn | |
| from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor | |
# --- 1. CONFIGURATION ---
# Trained weights, loaded below as a state_dict on top of the pretrained SegFormer.
MODEL_PATH = "modelo_mejorado.pth"

# Class names indexed by model output id; id 0 is the background class.
LABELS = [
    "fondo",
    "wheat leaf rust",
    "wheat powdery mildew",
    "wheat septoria blotch",
    "wheat stem rust",
    "wheat stripe rust"
]

# RGB overlay colour for each class id, aligned index-by-index with LABELS.
PALETA_COLORES = [
    [0, 0, 0],        # 0 background (intended to be transparent in the overlay)
    [220, 38, 38],    # 1 leaf rust (deep red)
    [22, 163, 74],    # 2 powdery mildew (leaf green)
    [37, 99, 235],    # 3 septoria blotch (strong blue)
    [234, 179, 8],    # 4 stem rust (golden yellow)
    [219, 39, 119]    # 5 stripe rust (strong pink)
]

# Prefer the GPU when one is visible to torch; otherwise run on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Usando dispositivo: {device}")
# --- 2. LOAD MODEL ---
# Build the SegFormer-B4 architecture from the pretrained checkpoint with a
# decode head sized for len(LABELS) classes, then overwrite its weights with
# the trained state_dict stored at MODEL_PATH.
checkpoint_name = "nvidia/segformer-b4-finetuned-ade-512-512"
try:
    model_inference = SegformerForSemanticSegmentation.from_pretrained(
        checkpoint_name,
        num_labels=len(LABELS),
        id2label={i: label for i, label in enumerate(LABELS)},
        label2id={label: i for i, label in enumerate(LABELS)},
        # The pretrained head's class count differs from len(LABELS); let the
        # mismatched classifier be re-initialised (the state_dict replaces it).
        ignore_mismatched_sizes=True,
    )
    # The file is expected to hold a plain state_dict (it is passed straight to
    # load_state_dict), so refuse to unpickle arbitrary objects.
    state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
    model_inference.load_state_dict(state_dict)
    model_inference.to(device)
    model_inference.eval()

    image_processor = SegformerImageProcessor.from_pretrained(checkpoint_name)
    # predecir_enfermedad resizes images to 512x512 itself, so the processor
    # must not resize again; it still rescales pixel values.
    image_processor.do_resize = False
    image_processor.do_rescale = True
except Exception as e:
    # Without the model no request can be served. Fail at startup with the real
    # cause instead of letting the app start and crash later with a NameError
    # on `model_inference` / `image_processor` at the first prediction.
    print(f"Error cargando modelo: {e}")
    raise
else:
    print("Modelo cargado correctamente.")
# --- 3. PREDICTION FUNCTION ---
def predecir_enfermedad(image):
    """Segment wheat-disease regions in ``image``; return an overlay and a diagnosis.

    The photo is converted to RGB, resized to 512x512 and run through the
    SegFormer model. The per-pixel argmax class mask is scaled back to the
    original size with nearest-neighbour resampling. Only pixels predicted as a
    disease (class id != 0) are tinted with their ``PALETA_COLORES`` colour at
    45% opacity; background pixels keep their original colour.

    Args:
        image: PIL image from the Gradio input, or ``None`` if nothing was uploaded.

    Returns:
        ``(overlay, markdown)``: the tinted PIL image (``None`` when there is no
        input) and a Markdown string listing detected disease labels, or a
        "healthy plant" message when only background was predicted.
    """
    if image is None:
        return None, "⚠️ Por favor sube una imagen primero."

    # Force 3 channels: RGBA or greyscale uploads would otherwise reach the
    # image processor with the wrong channel count.
    rgb_image = image.convert("RGB")
    original_size = rgb_image.size  # PIL size is (width, height)

    inputs = image_processor(images=rgb_image.resize((512, 512)), return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        logits = model_inference(**inputs).logits
    # The logits may be at a lower resolution than the 512x512 input; upsample
    # before the argmax so the mask matches the model input size.
    logits_upsampled = nn.functional.interpolate(
        logits, size=(512, 512), mode="bilinear", align_corners=False
    )
    # squeeze(0) drops only the batch dimension (batch size is 1 here).
    pred_mask = logits_upsampled.argmax(dim=1).squeeze(0).cpu().numpy()

    # np.unique returns sorted ids, so labels are listed in LABELS order.
    classes_found = [LABELS[int(c)] for c in np.unique(pred_mask) if c != 0]

    # Scale the class-id mask (not a colour image) back to the photo size.
    # NEAREST keeps ids exact. uint8 is enough because there are < 256 classes.
    class_full = np.asarray(
        Image.fromarray(pred_mask.astype(np.uint8)).resize(original_size, resample=Image.NEAREST)
    )
    palette = np.asarray(PALETA_COLORES, dtype=np.float32)
    color_full = palette[class_full]

    # Tint only diseased pixels. Blending the whole mask (as Image.blend would)
    # mixes the black background colour into every healthy pixel and darkens
    # it by 45%, although the background is meant to be transparent.
    alpha = 0.45
    base = np.asarray(rgb_image, dtype=np.float32)
    overlay = base.copy()
    diseased = class_full != 0
    overlay[diseased] = (1.0 - alpha) * base[diseased] + alpha * color_full[diseased]
    final_image = Image.fromarray(np.rint(overlay).astype(np.uint8))

    if classes_found:
        # Markdown formatting so the text renders nicely in the UI
        diagnosis = "### 🦠 Enfermedades Detectadas:\n" + "\n".join(f"- **{c.capitalize()}**" for c in classes_found)
    else:
        diagnosis = "### ✅ Planta Sana\nNo se detectaron enfermedades (Solo fondo)."
    return final_image, diagnosis
# --- 4. INTERFACE DESIGN (Blocks) ---
# Page styling: a centred, width-capped container plus heading, description
# and footer text styles (referenced via elem_classes in the layout).
custom_css = """
.container { max-width: 1100px; margin: auto; padding-top: 20px; }
h1 { text-align: center; color: #2E7D32; font-family: 'Helvetica', sans-serif; font-weight: bold; }
.description { text-align: center; font-size: 1.1em; color: #555; margin-bottom: 20px; }
.footer { text-align: center; margin-top: 30px; font-size: 0.8em; color: #888; }
"""

# Visual theme: Soft base with earthy/green hues and a dark-green primary button.
_primary_button_overrides = dict(
    button_primary_background_fill="#2E7D32",
    button_primary_background_fill_hover="#1B5E20",
    button_primary_text_color="white",
)
theme = gr.themes.Soft(primary_hue="emerald", neutral_hue="stone")
theme = theme.set(**_primary_button_overrides)
# Gradio UI: header, an input/output row, a footer, and the button wiring.
# Components render in creation order inside their enclosing layout context.
# NOTE(review): the nesting below was reconstructed from a copy that had lost
# its indentation; confirm it against the deployed app.
with gr.Blocks(theme=theme, css=custom_css, title="Wheat Disease AI") as demo:
    # Header
    with gr.Column(elem_classes=["container"]):
        gr.Markdown("# 🌾 Detector Inteligente de Trigo")
        gr.Markdown("Sube una fotografía de la hoja de trigo para segmentar y diagnosticar enfermedades automáticamente.", elem_classes=["description"])
        # Main row: input on the left, output on the right
        with gr.Row():
            # Left column (input)
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="📸 Foto de la Hoja", height=400)
                # Large green (primary-variant) button
                btn_predict = gr.Button("🔍 Analizar Enfermedad", variant="primary", scale=0)
                # Examples (optional, if example photos are uploaded to the folder)
                # gr.Examples(["ejemplo1.jpg"], inputs=input_image)
            # Right column (output)
            with gr.Column(scale=1):
                output_image = gr.Image(type="pil", label="🧪 Resultado (Máscara)", height=400)
                output_text = gr.Markdown(label="Diagnóstico") # Markdown so the diagnosis can use rich text
        # Footer
        gr.Markdown("---")
        gr.Markdown("Modelo **SegFormer B4** | Entrenado en PyTorch | Implementación para Tesis", elem_classes=["footer"])
    # Button logic: run the prediction on the uploaded image and fill both outputs
    btn_predict.click(
        fn=predecir_enfermedad,
        inputs=input_image,
        outputs=[output_image, output_text]
    )
# Start the Gradio server only when run as a script.
if __name__ == "__main__":
    demo.launch()