import gradio as gr
import torch
import numpy as np
from PIL import Image
from torch import nn
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

# --- 1. CONFIGURATION ---
MODEL_PATH = "modelo_mejorado.pth"
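# Fine-tuned weights; expected to be a state_dict saved via torch.save(model.state_dict(), ...)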

LABELS = [
    "background",
    "wheat leaf rust",
    "wheat powdery mildew",
    "wheat septoria blotch",
    "wheat stem rust",
    "wheat stripe rust"
]

PALETA_COLORES = [
    [0, 0, 0],       # Background (left unhighlighted)
    [220, 38, 38],   # Leaf rust (deep red)
    [22, 163, 74],   # Powdery mildew (leaf green)
    [37, 99, 235],   # Septoria blotch (strong blue)
    [234, 179, 8],   # Stem rust (golden yellow)
    [219, 39, 119]   # Stripe rust (hot pink)
]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Usando dispositivo: {device}")

# --- 2. LOAD MODEL (same logic as before) ---
checkpoint_name = "nvidia/segformer-b4-finetuned-ade-512-512"

try:
    model_inference = SegformerForSemanticSegmentation.from_pretrained(
        checkpoint_name,
        num_labels=len(LABELS),
        id2label={i: label for i, label in enumerate(LABELS)},
        label2id={label: i for i, label in enumerate(LABELS)},
        ignore_mismatched_sizes=True
    )
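    # Load the fine-tuned weights; map_location keeps this working on CPU-only machines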
    state_dict = torch.load(MODEL_PATH, map_location=device)
    model_inference.load_state_dict(state_dict)
    model_inference.to(device)
    model_inference.eval()

    image_processor = SegformerImageProcessor.from_pretrained(checkpoint_name)
    image_processor.do_resize = False   # the image is resized manually in the prediction function
    image_processor.do_rescale = True   # scale pixel values by 1/255
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    raise  # fail fast: the app cannot run without the model and processor

# --- 3. PREDICTION FUNCTION ---
def predecir_enfermedad(image):
    if image is None:
        return None, "⚠️ Please upload an image first."

    original_size = image.size  # PIL size is (width, height)
    img_resized = image.resize((512, 512))
    
    inputs = image_processor(images=img_resized, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    with torch.no_grad():
        outputs = model_inference(**inputs)
        logits = outputs.logits
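        # SegFormer predicts logits at 1/4 of the input resolution; upsample to 512x512 before taking argmax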
        logits_upsampled = nn.functional.interpolate(logits, size=(512, 512), mode="bilinear", align_corners=False)
        pred_mask = logits_upsampled.argmax(dim=1).squeeze().cpu().numpy()

    color_mask = np.zeros((512, 512, 3), dtype=np.uint8)
    classes_found = []
    unique_classes = np.unique(pred_mask)
    
    for class_id in unique_classes:
        if class_id == 0: continue
        classes_found.append(LABELS[class_id])
        color_mask[pred_mask == class_id] = PALETA_COLORES[class_id]

    # Overlay: blend the color mask onto the photo at 45% opacity (the black
    # background blends too, so the whole image is slightly darkened)
    mask_pil = Image.fromarray(color_mask).resize(original_size, resample=Image.NEAREST)
    final_image = Image.blend(image.convert("RGB"), mask_pil.convert("RGB"), alpha=0.45)
    
    if len(classes_found) > 0:
        # Markdown formatting so the report reads nicely
        diagnosis = "### 🦠 Detected Diseases:\n" + "\n".join(f"- **{c.capitalize()}**" for c in classes_found)
    else:
        diagnosis = "### ✅ Healthy Plant\nNo diseases detected (background only)."

    return final_image, diagnosis
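
# Optional smoke test for the prediction pipeline (a minimal sketch;
# "sample_leaf.jpg" is a hypothetical local file, not part of this repo).
# Uncomment to check inference end-to-end without launching the UI:
# test_img = Image.open("sample_leaf.jpg").convert("RGB")
# overlay, report = predecir_enfermedad(test_img)
# overlay.save("overlay_preview.png")
# print(report)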

# --- 4. INTERFACE LAYOUT (Blocks) ---

# Custom CSS for styling
custom_css = """
.container { max-width: 1100px; margin: auto; padding-top: 20px; }
h1 { text-align: center; color: #2E7D32; font-family: 'Helvetica', sans-serif; font-weight: bold; }
.description { text-align: center; font-size: 1.1em; color: #555; margin-bottom: 20px; }
.footer { text-align: center; margin-top: 30px; font-size: 0.8em; color: #888; }
"""

# Visual theme (earth/green tones)
theme = gr.themes.Soft(
    primary_hue="emerald",
    neutral_hue="stone",
).set(
    button_primary_background_fill="#2E7D32",
    button_primary_background_fill_hover="#1B5E20",
    button_primary_text_color="white",
)

with gr.Blocks(theme=theme, css=custom_css, title="Wheat Disease AI") as demo:
    
    # Header
    with gr.Column(elem_classes=["container"]):
        gr.Markdown("# 🌾 Detector Inteligente de Trigo")
        gr.Markdown("Sube una fotografía de la hoja de trigo para segmentar y diagnosticar enfermedades automáticamente.", elem_classes=["description"])
        
        # Main row: input on the left, output on the right
        with gr.Row():
            # Left column (input)
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="📸 Leaf Photo", height=400)
                # Large green primary button
                btn_predict = gr.Button("🔍 Analyze Disease", variant="primary", scale=0)
                
                # Examples (optional; add sample photos to the repo folder first)
                # gr.Examples(["ejemplo1.jpg"], inputs=input_image)

            # Right column (output)
            with gr.Column(scale=1):
                output_image = gr.Image(type="pil", label="🧪 Result (Mask)", height=400)
                output_text = gr.Markdown(label="Diagnosis")  # Markdown component for rich-text output
        
        # Footer
        gr.Markdown("---")
        gr.Markdown("Modelo **SegFormer B4** | Entrenado en PyTorch | Implementación para Tesis", elem_classes=["footer"])

    # Wire the button to the prediction function
    btn_predict.click(
        fn=predecir_enfermedad, 
        inputs=input_image, 
        outputs=[output_image, output_text]
    )

if __name__ == "__main__":
    demo.launch()
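    # For public sharing or deployment behind a reverse proxy, standard Gradio
    # launch options can be passed instead, e.g.:
    # demo.launch(share=True, server_name="0.0.0.0", server_port=7860)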