# Scraped from the Hugging Face Spaces file viewer (not part of the program).
# Space status at capture time: Runtime error
# File size: 5,465 bytes | revision 29fe0a1
# (The viewer's line-number gutter, 1-146, was removed from this header.)
import gradio as gr
import torch
import numpy as np
from PIL import Image
from torch import nn
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
# --- 1. CONFIGURATION ---
# Path of the fine-tuned weights file that is passed to torch.load.
MODEL_PATH = "modelo_mejorado.pth"

# Class names. The position in this list is the class id: the list is
# enumerated to build id2label/label2id and is indexed with the predicted ids.
LABELS = [
    "fondo",
    "wheat leaf rust",
    "wheat powdery mildew",
    "wheat septoria blotch",
    "wheat stem rust",
    "wheat stripe rust",
]

# One RGB colour per entry of LABELS, in the same order.
PALETA_COLORES = [
    [0, 0, 0],        # background (intended as transparent)
    [220, 38, 38],    # leaf rust (intense red)
    [22, 163, 74],    # powdery mildew (leaf green)
    [37, 99, 235],    # septoria blotch (strong blue)
    [234, 179, 8],    # stem rust (golden yellow)
    [219, 39, 119],   # stripe rust (strong pink)
]

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Usando dispositivo: {device}")
# --- 2. LOAD MODEL (same logic as the previous version) ---
checkpoint_name = "nvidia/segformer-b4-finetuned-ade-512-512"

# Define both globals up front so that a failed load leaves them as None
# instead of undefined names that only surface later as a NameError.
model_inference = None
image_processor = None

try:
    # Build the architecture from the base checkpoint with a head sized for
    # LABELS; ignore_mismatched_sizes lets the differently sized head load.
    model_inference = SegformerForSemanticSegmentation.from_pretrained(
        checkpoint_name,
        num_labels=len(LABELS),
        id2label={i: label for i, label in enumerate(LABELS)},
        label2id={label: i for i, label in enumerate(LABELS)},
        ignore_mismatched_sizes=True,
    )

    # NOTE(security): torch.load unpickles the file. Only load weights from a
    # trusted source; consider weights_only=True where the installed torch
    # version supports it.
    state_dict = torch.load(MODEL_PATH, map_location=device)
    model_inference.load_state_dict(state_dict)
    model_inference.to(device)
    model_inference.eval()

    # The prediction code resizes the image itself, so the processor's own
    # resizing is switched off; rescaling is left on.
    image_processor = SegformerImageProcessor.from_pretrained(checkpoint_name)
    image_processor.do_resize = False
    image_processor.do_rescale = True
    print("Modelo cargado correctamente.")
except Exception as e:
    # Best effort: keep the app importable, but never leave a half-initialised
    # pair behind (e.g. a model without its processor).
    model_inference = None
    image_processor = None
    print(f"Error cargando modelo: {e}")
# --- 3. PREDICTION FUNCTION ---
def predecir_enfermedad(image):
    """Segment wheat diseases in ``image`` and describe what was found.

    Args:
        image: PIL image coming from the Gradio input (declared with
            ``type="pil"``), or ``None`` when nothing has been uploaded.

    Returns:
        A ``(overlay, diagnosis)`` tuple. ``overlay`` is a PIL RGB image at
        the size of the input with the detected classes tinted using
        ``PALETA_COLORES``; it is ``None`` when there is no input image or
        the model is unavailable. ``diagnosis`` is a Markdown string.
    """
    if image is None:
        return None, "⚠️ Por favor sube una imagen primero."

    # The module-level loader swallows load errors, so the model globals may
    # be missing (or None). Report that instead of failing with a NameError.
    try:
        processor, model = image_processor, model_inference
    except NameError:
        processor = model = None
    if processor is None or model is None:
        return None, "⚠️ El modelo no está disponible. Revisa los registros de inicio de la aplicación."

    # Normalise to 3-channel RGB first: uploads may be RGBA, palette or
    # grayscale, which would otherwise reach the processor with the wrong
    # number of channels.
    base = image.convert("RGB")
    original_size = base.size  # PIL order: (width, height)
    img_resized = base.resize((512, 512))

    inputs = processor(images=img_resized, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits

    # Bring the logits to the 512x512 working resolution before taking the
    # per-pixel argmax over the class dimension.
    logits_upsampled = nn.functional.interpolate(
        logits, size=(512, 512), mode="bilinear", align_corners=False
    )
    pred_mask = logits_upsampled.argmax(dim=1)[0].cpu().numpy()

    # Vectorised palette lookup: every pixel gets the colour of its class id
    # (class 0 maps to black, exactly as the zero-initialised mask did).
    palette = np.asarray(PALETA_COLORES, dtype=np.uint8)
    color_mask = palette[pred_mask]

    # np.unique returns sorted ids, so names are listed in class-id order.
    classes_found = [LABELS[class_id] for class_id in np.unique(pred_mask) if class_id != 0]

    # NEAREST keeps the resized masks free of interpolated in-between values.
    mask_pil = Image.fromarray(color_mask).resize(original_size, resample=Image.NEAREST)
    disease_region = Image.fromarray((pred_mask != 0).astype(np.uint8) * 255).resize(
        original_size, resample=Image.NEAREST
    )

    # Tint only the pixels assigned to a disease. Blending the whole frame
    # against the black background colour would darken healthy areas, whereas
    # the background class is meant to be transparent.
    blended = Image.blend(base, mask_pil, alpha=0.45)
    final_image = Image.composite(blended, base, disease_region)

    if classes_found:
        # Markdown formatting so the text renders nicely.
        diagnosis = "### 🦠 Enfermedades Detectadas:\n" + "\n".join(
            f"- **{c.capitalize()}**" for c in classes_found
        )
    else:
        diagnosis = "### ✅ Planta Sana\nNo se detectaron enfermedades (Solo fondo)."
    return final_image, diagnosis
# --- 4. INTERFACE DESIGN (Blocks) ---
# Custom CSS for the classes referenced through elem_classes in the layout.
custom_css = """
.container { max-width: 1100px; margin: auto; padding-top: 20px; }
h1 { text-align: center; color: #2E7D32; font-family: 'Helvetica', sans-serif; font-weight: bold; }
.description { text-align: center; font-size: 1.1em; color: #555; margin-bottom: 20px; }
.footer { text-align: center; margin-top: 30px; font-size: 0.8em; color: #888; }
"""

# Visual theme (earth/green colours) with a dark-green primary button.
theme = gr.themes.Soft(primary_hue="emerald", neutral_hue="stone").set(
    button_primary_background_fill="#2E7D32",
    button_primary_background_fill_hover="#1B5E20",
    button_primary_text_color="white",
)
with gr.Blocks(theme=theme, css=custom_css, title="Wheat Disease AI") as demo:
    # NOTE(review): the original indentation was lost; this nesting is
    # reconstructed from the section comments (header, main row and footer
    # all inside the "container" column) -- confirm against the real file.
    with gr.Column(elem_classes=["container"]):
        # Header
        gr.Markdown("# 🌾 Detector Inteligente de Trigo")
        gr.Markdown(
            "Sube una fotografía de la hoja de trigo para segmentar y diagnosticar enfermedades automáticamente.",
            elem_classes=["description"],
        )

        # Main row: input on the left, output on the right.
        with gr.Row():
            # Left column (input)
            with gr.Column(scale=1):
                input_image = gr.Image(
                    type="pil",
                    label="📸 Foto de la Hoja",
                    height=400,
                )
                # Large green button (primary variant).
                btn_predict = gr.Button(
                    "🔍 Analizar Enfermedad",
                    variant="primary",
                    scale=0,
                )
                # Examples (optional, if example photos are uploaded to the folder)
                # gr.Examples(["ejemplo1.jpg"], inputs=input_image)

            # Right column (output)
            with gr.Column(scale=1):
                output_image = gr.Image(
                    type="pil",
                    label="🧪 Resultado (Máscara)",
                    height=400,
                )
                # Markdown component so the diagnosis can use rich text.
                output_text = gr.Markdown(label="Diagnóstico")

        # Footer
        gr.Markdown("---")
        gr.Markdown(
            "Modelo **SegFormer B4** | Entrenado en PyTorch | Implementación para Tesis",
            elem_classes=["footer"],
        )

    # Button logic: run the prediction and fill both output components.
    btn_predict.click(
        fn=predecir_enfermedad,
        inputs=input_image,
        outputs=[output_image, output_text],
    )

if __name__ == "__main__":
    demo.launch()