File size: 4,764 Bytes
05c7848
 
9837a1e
05c7848
9837a1e
363f1a6
9837a1e
05c7848
363f1a6
05c7848
 
 
363f1a6
05c7848
9837a1e
 
363f1a6
9837a1e
 
05c7848
 
363f1a6
9837a1e
363f1a6
05c7848
 
9837a1e
05c7848
9837a1e
05c7848
 
9837a1e
 
363f1a6
 
 
 
 
 
 
9837a1e
 
 
 
 
363f1a6
9837a1e
05c7848
 
9837a1e
 
363f1a6
9837a1e
 
05c7848
363f1a6
05c7848
 
 
 
 
363f1a6
05c7848
9837a1e
363f1a6
 
 
 
 
 
 
 
 
 
 
 
 
9837a1e
363f1a6
9837a1e
363f1a6
9837a1e
05c7848
9837a1e
363f1a6
9837a1e
05c7848
363f1a6
 
 
 
 
 
9837a1e
363f1a6
05c7848
9837a1e
05c7848
 
 
 
9837a1e
363f1a6
9837a1e
05c7848
9837a1e
 
363f1a6
9837a1e
05c7848
363f1a6
9837a1e
 
363f1a6
 
 
 
 
 
9837a1e
 
 
 
 
 
 
 
 
363f1a6
9837a1e
363f1a6
05c7848
9837a1e
363f1a6
05c7848
 
 
9837a1e
05c7848
 
 
 
363f1a6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
from PIL import Image
import tempfile
import os
import gc

# --- MODEL CONFIGURATION ---
# Hugging Face Hub repo ids of the two SDXL checkpoints offered in the UI.
MODEL_REALISTIC = "stabilityai/stable-diffusion-xl-base-1.0"
MODEL_PONY_REALISM = "john6666/pony-realism-v23-sdxl"

# Choose the compute device; half precision is only used on CUDA, the CPU
# path stays in float32.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
else:
    device, dtype = "cpu", torch.float32

# Module-level slots for the loaded pipelines (filled in after load_model).
pipe_realistic = pipe_pony = None

def load_model(model_id, *, compile_unet=True):
    """Load an SDXL pipeline, swap in a DPM multistep scheduler and move it to ``device``.

    Args:
        model_id: Hugging Face Hub repo id (or local path) of an SDXL
            checkpoint; ``use_safetensors=True`` means it must provide
            ``.safetensors`` weights.
        compile_unet: When True (the default), and the module-level ``device``
            is ``"cuda"`` and this torch build exposes ``torch.compile``, the
            UNet is wrapped with ``torch.compile``. Pass False to skip it.

    Returns:
        The loaded ``StableDiffusionXLPipeline``, already on ``device``.
    """
    print(f"⏳ Cargando modelo: {model_id}")

    pipe = StableDiffusionXLPipeline.from_pretrained(
        model_id,
        torch_dtype=dtype,
        use_safetensors=True,
        low_cpu_mem_usage=True,
    )

    # Replace the checkpoint's scheduler, reusing its configuration.
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

    # Move weights to the target device *before* compiling (the order shown in
    # the diffusers torch.compile guide), so tracing happens on device weights.
    pipe.to(device)

    if compile_unet and device == "cuda" and hasattr(torch, "compile"):
        try:
            # NOTE: torch.compile is lazy - this call only wraps the module.
            # Tracing errors (e.g. graph breaks rejected by fullgraph=True)
            # surface on the first pipe(...) call, not here, so this except
            # only catches setup failures. Changing width/height between calls
            # will likely trigger recompilation.
            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
        except Exception as e:
            print(f"⚠️ No se pudo compilar el modelo: {e}")

    print(f"✅ Modelo {model_id.split('/')[-1]} listo")
    return pipe

# Load both pipelines eagerly at import time, in order: realistic, then pony.
# Both remain loaded (and moved to `device` by load_model) for the app's life.
print("⏳ Cargando modelos iniciales...")
pipe_realistic, pipe_pony = map(load_model, (MODEL_REALISTIC, MODEL_PONY_REALISM))
print("✅ Todos los modelos iniciales listos")

# --- GENERATION FUNCTION ---
def generate_image(prompt, negative_prompt, model_choice, steps, guidance, width, height, seed):
    """Generate one image with the selected SDXL pipeline and save a JPEG copy.

    Args:
        prompt: Positive text prompt.
        negative_prompt: What to avoid; an empty string is sent to the
            pipeline as ``None`` (i.e. "no negative prompt").
        model_choice: ``"pony"`` selects ``pipe_pony``; any other value
            selects ``pipe_realistic``.
        steps: Number of inference steps (coerced to ``int``).
        guidance: Classifier-free guidance scale.
        width: Output width in pixels (coerced to ``int``).
        height: Output height in pixels (coerced to ``int``).
        seed: Seed for reproducible output. A negative value, or ``None``
            (e.g. a cleared Gradio number field), means a random seed.

    Returns:
        ``(image, jpeg_path)`` on success, where ``jpeg_path`` is a temporary
        file that is not deleted automatically; ``(None, None)`` if generation
        fails (the error is printed).
    """
    pipe = pipe_pony if model_choice == "pony" else pipe_realistic

    # gr.Number may deliver None (cleared field) or a float, while
    # torch.Generator.manual_seed requires an int.
    seed = -1 if seed is None else int(seed)
    generator = None
    if seed >= 0:
        generator = torch.Generator(device=pipe.device).manual_seed(seed)

    print(f"🔄 Generando imagen con '{model_choice}'...")

    try:
        result = pipe(
            prompt=prompt,
            # diffusers only uses SDXL's zeroed negative embedding
            # (force_zeros_for_empty_prompt) when negative_prompt is None; an
            # empty textbox should mean "no negative prompt", not "encode ''".
            negative_prompt=negative_prompt or None,
            num_inference_steps=int(steps),
            guidance_scale=guidance,
            width=int(width),
            height=int(height),
            generator=generator,
        )

        image = result.images[0]

        # Write through the already-open handle: re-opening the file by name
        # while NamedTemporaryFile still holds it fails on Windows.
        # delete=False so the file still exists when Gradio serves it.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_file:
            image.save(temp_file, format="JPEG", quality=95)
            temp_filepath = temp_file.name

        return image, temp_filepath

    except Exception as e:
        print(f"❌ Error al generar imagen: {e}")
        return None, None

    finally:
        # Release memory held by intermediate tensors between requests.
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()


# --- GRADIO INTERFACE ---
# Layout: prompts and settings in the left column, result and download on the right.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="violet")) as demo:
    gr.Markdown("# 🖼️✨ Generador de Imágenes Hiperrealistas con SDXL")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt (Descripción positiva)",
                placeholder="Ej: photorealistic portrait in 9:16, cinematic, natural light...",
                lines=3
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt (Lo que NO quieres)",
                placeholder="Ej: blurry, text, low quality, deformed...",
                lines=2
            )

            with gr.Row():
                # The second element of each pair ("realistic"/"pony") is the
                # value generate_image receives as model_choice.
                model_choice = gr.Radio(
                    choices=[
                        ("Realista Puro", "realistic"),
                        ("Pony Realism", "pony")
                    ],
                    label="Modelo",
                    value="realistic"
                )
                # Any negative seed makes generate_image use a random seed.
                seed = gr.Number(value=-1, label="Seed (-1 = aleatorio)", precision=0)

            with gr.Accordion("Ajustes Avanzados", open=False):
                steps = gr.Slider(20, 80, value=40, step=5, label="Pasos de inferencia")
                # Higher guidance makes the output follow the prompt more
                # closely (less variation); "Creatividad" described the opposite.
                guidance = gr.Slider(1, 15, value=7.5, step=0.5, label="Guidance scale (Fidelidad al prompt)")
                with gr.Row():
                    width = gr.Slider(512, 1024, value=768, step=64, label="Ancho")
                    height = gr.Slider(512, 1024, value=768, step=64, label="Alto")

            btn = gr.Button("🎨 Generar Imagen", variant="primary")

        with gr.Column():
            output_img = gr.Image(label="Resultado", type="pil", height=512)
            download_btn = gr.File(label="Descargar JPG")

    # Input order must match generate_image's positional parameters.
    btn.click(
        fn=generate_image,
        inputs=[prompt, negative_prompt, model_choice, steps, guidance, width, height, seed],
        outputs=[output_img, download_btn]
    )

if __name__ == "__main__":
    # Start the web server only when run as a script, not on import.
    demo.launch()