import gradio as gr
import torch
import numpy as np
from PIL import Image
from torch import nn
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

# --- 1. CONFIGURATION ---
MODEL_PATH = "modelo_mejorado.pth"

# Class names; the list index is the class id produced by the argmax below.
LABELS = [
    "fondo",
    "wheat leaf rust",
    "wheat powdery mildew",
    "wheat septoria blotch",
    "wheat stem rust",
    "wheat stripe rust",
]

# RGB overlay colour for each class id (same order as LABELS).
PALETA_COLORES = [
    [0, 0, 0],       # Background (black, i.e. no colour added)
    [220, 38, 38],   # Leaf rust (intense red)
    [22, 163, 74],   # Powdery mildew (leaf green)
    [37, 99, 235],   # Septoria (strong blue)
    [234, 179, 8],   # Stem rust (golden yellow)
    [219, 39, 119],  # Stripe rust (strong pink)
]

# Size the image is resized to before inference, as (width, height).
INFERENCE_SIZE = (512, 512)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Usando dispositivo: {device}")

# --- 2. LOAD MODEL ---
checkpoint_name = "nvidia/segformer-b4-finetuned-ade-512-512"

# Bind both names up front so that a failed load leaves them as None rather
# than undefined; predecir_enfermedad() checks for None before using them.
model_inference = None
image_processor = None

try:
    model_inference = SegformerForSemanticSegmentation.from_pretrained(
        checkpoint_name,
        num_labels=len(LABELS),
        id2label={i: label for i, label in enumerate(LABELS)},
        label2id={label: i for i, label in enumerate(LABELS)},
        ignore_mismatched_sizes=True,
    )
    # NOTE(review): torch.load unpickles the file, so MODEL_PATH must come
    # from a trusted source. Consider weights_only=True if the installed
    # PyTorch version supports it.
    state_dict = torch.load(MODEL_PATH, map_location=device)
    model_inference.load_state_dict(state_dict)
    model_inference.to(device)
    model_inference.eval()

    image_processor = SegformerImageProcessor.from_pretrained(checkpoint_name)
    # Resizing is done manually in predecir_enfermedad().
    image_processor.do_resize = False
    image_processor.do_rescale = True
    print("Modelo cargado correctamente.")
except Exception as e:
    # Keep the app importable (best effort), but never leave a half-loaded
    # model behind: a model without the fine-tuned weights must not be used.
    model_inference = None
    image_processor = None
    print(f"Error cargando modelo: {e}")


# --- 3. PREDICTION FUNCTION ---
def predecir_enfermedad(image):
    """Segment a wheat-leaf photo and describe the classes that were found.

    Args:
        image: PIL image supplied by the UI, or None when nothing was
            uploaded.

    Returns:
        A ``(overlay, diagnosis)`` tuple. ``overlay`` is the input image
        blended with the colour-coded class mask (or None when no prediction
        could be made) and ``diagnosis`` is a Markdown string for the UI.
    """
    if image is None:
        return None, "⚠️ Por favor sube una imagen primero."
    if model_inference is None or image_processor is None:
        # Model loading failed at start-up; report it instead of raising.
        return None, "❌ El modelo no se pudo cargar. Revisa los registros del servidor."

    # Work in RGB from the start so the model always receives three channels.
    rgb_image = image.convert("RGB")
    original_size = rgb_image.size
    img_resized = rgb_image.resize(INFERENCE_SIZE)

    inputs = image_processor(images=img_resized, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model_inference(**inputs)

    # Resize the logits to the inference resolution before taking the argmax.
    # PIL sizes are (width, height); interpolate expects (height, width).
    width, height = img_resized.size
    logits_upsampled = nn.functional.interpolate(
        outputs.logits, size=(height, width), mode="bilinear", align_corners=False
    )
    # Single image per call, so take the first (only) item of the batch.
    pred_mask = logits_upsampled.argmax(dim=1)[0].cpu().numpy()

    # Palette lookup: every pixel's class id is replaced by its RGB colour.
    color_mask = np.asarray(PALETA_COLORES, dtype=np.uint8)[pred_mask]

    # Class 0 is the background and is not reported as a disease.
    classes_found = [
        LABELS[class_id] for class_id in np.unique(pred_mask) if class_id != 0
    ]

    mask_pil = Image.fromarray(color_mask).resize(original_size, resample=Image.NEAREST)
    # NOTE(review): blending against the black background colour also darkens
    # the non-diseased areas of the photo; confirm this look is intended.
    final_image = Image.blend(rgb_image, mask_pil, alpha=0.45)

    if classes_found:
        # Markdown formatting so the text renders nicely in the UI.
        diagnosis = "### 🦠 Enfermedades Detectadas:\n" + "\n".join(
            f"- **{c.capitalize()}**" for c in classes_found
        )
    else:
        diagnosis = "### ✅ Planta Sana\nNo se detectaron enfermedades (Solo fondo)."

    return final_image, diagnosis
# --- 4. INTERFACE LAYOUT (Blocks) ---

# Custom CSS for the page container, heading, description and footer.
custom_css = (
    " .container { max-width: 1100px; margin: auto; padding-top: 20px; }"
    " h1 { text-align: center; color: #2E7D32; font-family: 'Helvetica', sans-serif; font-weight: bold; }"
    " .description { text-align: center; font-size: 1.1em; color: #555; margin-bottom: 20px; }"
    " .footer { text-align: center; margin-top: 30px; font-size: 0.8em; color: #888; }"
    " "
)

# Visual theme (earth / green colours) with a custom primary button.
_primary_button_style = {
    "button_primary_background_fill": "#2E7D32",
    "button_primary_background_fill_hover": "#1B5E20",
    "button_primary_text_color": "white",
}
theme = gr.themes.Soft(primary_hue="emerald", neutral_hue="stone")
theme = theme.set(**_primary_button_style)

# NOTE(review): the nesting below was reconstructed from a source whose
# indentation was lost; confirm that the "container" column is meant to wrap
# the main row and the footer as well as the header.
with gr.Blocks(theme=theme, css=custom_css, title="Wheat Disease AI") as demo:
    with gr.Column(elem_classes=["container"]):
        # Header
        gr.Markdown("# 🌾 Detector Inteligente de Trigo")
        gr.Markdown(
            "Sube una fotografía de la hoja de trigo para segmentar y diagnosticar enfermedades automáticamente.",
            elem_classes=["description"],
        )

        # Main row: input on the left, output on the right.
        with gr.Row():
            # Left column (input)
            with gr.Column(scale=1):
                leaf_photo = gr.Image(
                    type="pil",
                    label="📸 Foto de la Hoja",
                    height=400,
                )
                # Large green button
                analyze_button = gr.Button(
                    "🔍 Analizar Enfermedad",
                    variant="primary",
                    scale=0,
                )
                # Optional: add example photos to the folder and enable
                # gr.Examples(["ejemplo1.jpg"], inputs=leaf_photo)

            # Right column (output)
            with gr.Column(scale=1):
                overlay_view = gr.Image(
                    type="pil",
                    label="🧪 Resultado (Máscara)",
                    height=400,
                )
                # Markdown is used so the diagnosis can carry rich text.
                diagnosis_view = gr.Markdown(label="Diagnóstico")

        # Footer
        gr.Markdown("---")
        gr.Markdown(
            "Modelo **SegFormer B4** | Entrenado en PyTorch | Implementación para Tesis",
            elem_classes=["footer"],
        )

    # Button logic: run the prediction and fill both output widgets.
    analyze_button.click(
        fn=predecir_enfermedad,
        inputs=leaf_photo,
        outputs=[overlay_view, diagnosis_view],
    )

if __name__ == "__main__":
    demo.launch()