# Author: juancho2112 (Hugging Face Space)
# Last commit: "Update app.py" (29fe0a1, verified)
import gradio as gr
import torch
import numpy as np
from PIL import Image
from torch import nn
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
# --- 1. Configuration ---
MODEL_PATH = "modelo_mejorado.pth"  # fine-tuned state dict loaded at startup

# One row per model output class, in class-id order: (label, RGB overlay colour).
# Keeping each label next to its colour keeps LABELS and PALETA_COLORES aligned.
_CLASES = (
    ("fondo", (0, 0, 0)),                     # background (black)
    ("wheat leaf rust", (220, 38, 38)),       # strong red
    ("wheat powdery mildew", (22, 163, 74)),  # leaf green
    ("wheat septoria blotch", (37, 99, 235)), # strong blue
    ("wheat stem rust", (234, 179, 8)),       # gold yellow
    ("wheat stripe rust", (219, 39, 119)),    # strong pink
)
LABELS = [nombre for nombre, _ in _CLASES]
PALETA_COLORES = [list(rgb) for _, rgb in _CLASES]

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Usando dispositivo: {device}")
# --- 2. Load model ---
# Build the SegFormer-B4 architecture from the public ADE20K checkpoint, resize
# its classification head to len(LABELS) classes, then overwrite every weight
# with the fine-tuned state dict stored in MODEL_PATH.
checkpoint_name = "nvidia/segformer-b4-finetuned-ade-512-512"
try:
    model_inference = SegformerForSemanticSegmentation.from_pretrained(
        checkpoint_name,
        num_labels=len(LABELS),
        id2label={i: label for i, label in enumerate(LABELS)},
        label2id={label: i for i, label in enumerate(LABELS)},
        # The pretrained head has a different class count; let it be re-created
        # at the new size (its weights are replaced by load_state_dict below).
        ignore_mismatched_sizes=True,
    )
    state_dict = torch.load(MODEL_PATH, map_location=device)
    model_inference.load_state_dict(state_dict)
    model_inference.to(device)
    model_inference.eval()

    image_processor = SegformerImageProcessor.from_pretrained(checkpoint_name)
    # The prediction function resizes to 512x512 itself, so the processor only
    # rescales pixel values and normalizes.
    image_processor.do_resize = False
    image_processor.do_rescale = True
    print("Modelo cargado correctamente.")
except Exception as e:
    print(f"Error cargando modelo: {e}")
    # Without the model and processor every prediction would fail later with a
    # NameError; fail fast so the real cause is visible at startup.
    raise
# --- 3. Prediction function ---
def predecir_enfermedad(image):
    """Segment wheat-leaf diseases in an uploaded photo.

    Args:
        image: PIL image from the Gradio input component, or ``None`` when the
            button is pressed without uploading anything.

    Returns:
        A ``(overlay, diagnosis)`` tuple. ``overlay`` is the photo at its
        original size, with every pixel predicted as a disease tinted with that
        class's colour; background pixels are left untouched. ``diagnosis`` is
        a Markdown string listing the detected classes. If ``image`` is
        ``None``, returns ``(None, warning_message)``.
    """
    if image is None:
        return None, "⚠️ Por favor sube una imagen primero."

    # Work in RGB so alpha/palette/grayscale uploads still give the processor
    # three channels and can be blended with the RGB colour mask.
    base = image.convert("RGB")
    original_size = base.size  # PIL sizes are (width, height)

    # The model is fed a fixed 512x512 input (the processor's own resize is off).
    inputs = image_processor(images=base.resize((512, 512)), return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model_inference(**inputs).logits

    # SegFormer predicts at a lower resolution than its input; upsample the
    # logits back to 512x512 before taking the per-pixel argmax.
    logits = nn.functional.interpolate(logits, size=(512, 512), mode="bilinear", align_corners=False)
    pred_mask = logits.argmax(dim=1)[0].cpu().numpy()  # (512, 512) class ids

    # np.unique is sorted, so classes are reported in class-id order.
    classes_found = [LABELS[class_id] for class_id in np.unique(pred_mask) if class_id != 0]

    # Colour image of the prediction (class 0 -> black) plus a binary mask of
    # diseased pixels, both scaled back to the original photo size.
    palette = np.asarray(PALETA_COLORES, dtype=np.uint8)
    color_pil = Image.fromarray(palette[pred_mask]).resize(original_size, resample=Image.NEAREST)
    disease_pil = Image.fromarray(np.where(pred_mask != 0, 255, 0).astype(np.uint8)).resize(
        original_size, resample=Image.NEAREST
    )

    # Tint only where disease was found. Blending the whole frame would mix the
    # background with black and darken healthy areas by 45%.
    tinted = Image.blend(base, color_pil, alpha=0.45)
    final_image = Image.composite(tinted, base, disease_pil)

    if classes_found:
        # Markdown so the Gradio output renders a heading and a bullet list.
        diagnosis = "### 🦠 Enfermedades Detectadas:\n" + "\n".join(
            f"- **{c.capitalize()}**" for c in classes_found
        )
    else:
        diagnosis = "### ✅ Planta Sana\nNo se detectaron enfermedades (Solo fondo)."
    return final_image, diagnosis
# --- 4. Interface layout (Blocks) ---
# Custom CSS, one rule per line: centred container, green title, muted
# description and small grey footer. Selectors match the elem_classes used
# in the layout.
custom_css = "\n".join([
    "",
    ".container { max-width: 1100px; margin: auto; padding-top: 20px; }",
    "h1 { text-align: center; color: #2E7D32; font-family: 'Helvetica', sans-serif; font-weight: bold; }",
    ".description { text-align: center; font-size: 1.1em; color: #555; margin-bottom: 20px; }",
    ".footer { text-align: center; margin-top: 30px; font-size: 0.8em; color: #888; }",
    "",
])
# Visual theme: Gradio's Soft preset in earthy/green hues, with a dark-green
# primary button that turns a darker green on hover.
theme = gr.themes.Soft(primary_hue="emerald", neutral_hue="stone")
theme = theme.set(
    button_primary_background_fill="#2E7D32",
    button_primary_background_fill_hover="#1B5E20",
    button_primary_text_color="white",
)
with gr.Blocks(theme=theme, css=custom_css, title="Wheat Disease AI") as demo:
    with gr.Column(elem_classes=["container"]):
        # Header: title and one-line instructions.
        gr.Markdown("# 🌾 Detector Inteligente de Trigo")
        gr.Markdown(
            "Sube una fotografía de la hoja de trigo para segmentar y diagnosticar enfermedades automáticamente.",
            elem_classes=["description"],
        )

        # Main row: upload + button on the left, results on the right.
        with gr.Row():
            with gr.Column(scale=1):
                foto_hoja = gr.Image(type="pil", label="📸 Foto de la Hoja", height=400)
                boton_analizar = gr.Button("🔍 Analizar Enfermedad", variant="primary", scale=0)
                # Optional: once sample photos are committed, add
                # gr.Examples([...], inputs=foto_hoja) here.

            with gr.Column(scale=1):
                foto_resultado = gr.Image(type="pil", label="🧪 Resultado (Máscara)", height=400)
                # Markdown output so the diagnosis renders headings and bullets.
                texto_diagnostico = gr.Markdown(label="Diagnóstico")

        # Footer.
        gr.Markdown("---")
        gr.Markdown(
            "Modelo **SegFormer B4** | Entrenado en PyTorch | Implementación para Tesis",
            elem_classes=["footer"],
        )

    # Button wiring: uploaded photo in, (overlay image, Markdown diagnosis) out.
    boton_analizar.click(
        fn=predecir_enfermedad,
        inputs=foto_hoja,
        outputs=[foto_resultado, texto_diagnostico],
    )

if __name__ == "__main__":
    demo.launch()