|
|
import gradio as gr |
|
|
import torch |
|
|
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler |
|
|
from PIL import Image |
|
|
import tempfile |
|
|
import os |
|
|
import gc |
|
|
|
|
|
|
|
|
# Model identifiers passed to ``StableDiffusionXLPipeline.from_pretrained``.
MODEL_REALISTIC = "stabilityai/stable-diffusion-xl-base-1.0"
MODEL_PONY_REALISM = "john6666/pony-realism-v23-sdxl"

# Run on the GPU when CUDA is available; half precision is only used there,
# the CPU path falls back to float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Pipeline handles; they stay ``None`` until the models are loaded below.
pipe_realistic = None
pipe_pony = None
|
|
|
|
|
def load_model(model_id):
    """Load an SDXL pipeline and prepare it for inference.

    Args:
        model_id: Model identifier handed to
            ``StableDiffusionXLPipeline.from_pretrained``.

    Returns:
        The loaded pipeline, using a ``DPMSolverMultistepScheduler`` built
        from the original scheduler's config and moved to the module-level
        ``device``.
    """
    print(f"⏳ Cargando modelo: {model_id}")

    pipe = StableDiffusionXLPipeline.from_pretrained(
        model_id,
        torch_dtype=dtype,
        use_safetensors=True,
        low_cpu_mem_usage=True,
    )

    # Swap the scheduler while reusing the configuration shipped with the model.
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

    # Compilation is only attempted on CUDA and only when this torch build
    # exposes ``torch.compile``; a failure here is reported and the
    # uncompiled UNet is kept.
    # NOTE(review): torch.compile is lazy, so real compilation errors will
    # presumably surface on the first pipeline call rather than in this
    # try block -- confirm whether that is acceptable.
    if device == "cuda" and hasattr(torch, "compile"):
        try:
            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
        except Exception as e:
            print(f"⚠️ No se pudo compilar el modelo: {e}")

    pipe.to(device)

    print(f"✅ Modelo {model_id.split('/')[-1]} listo")
    return pipe
|
|
|
|
|
|
|
|
# Both pipelines are loaded eagerly as a module-level side effect, so they are
# already in memory before the UI below is built.
print("⏳ Cargando modelos iniciales...")
pipe_realistic = load_model(MODEL_REALISTIC)
pipe_pony = load_model(MODEL_PONY_REALISM)
print("✅ Todos los modelos iniciales listos")
|
|
|
|
|
|
|
|
def generate_image(prompt, negative_prompt, model_choice, steps, guidance, width, height, seed):
    """Generate one image with the selected pipeline and save a JPEG copy.

    Args:
        prompt: Positive text prompt.
        negative_prompt: Text describing what the image should not contain.
        model_choice: ``"pony"`` selects ``pipe_pony``; any other value
            selects ``pipe_realistic``.
        steps: Number of inference steps.
        guidance: Guidance scale forwarded to the pipeline.
        width: Requested image width.
        height: Requested image height.
        seed: A value ``>= 0`` seeds a ``torch.Generator``; a negative value
            or ``None`` (empty input field) leaves seeding to the pipeline.

    Returns:
        ``(image, path)`` where ``image`` is the first generated image and
        ``path`` is a temporary ``.jpg`` file that is not deleted here, or
        ``(None, None)`` when generation or saving fails.
    """
    # The module-level pipelines are only read, so no ``global`` is needed.
    pipe = pipe_pony if model_choice == "pony" else pipe_realistic

    print(f"🔄 Generando imagen con '{model_choice}'...")

    try:
        # Seed handling lives inside the try block so that a bad seed is
        # reported through the same error path as any other failure.
        generator = None
        if seed is not None and seed >= 0:
            generator = torch.Generator(device=pipe.device).manual_seed(int(seed))

        # Numeric UI values are coerced defensively: the pipeline needs
        # integral step counts and dimensions.
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=int(steps),
            guidance_scale=guidance,
            width=int(width),
            height=int(height),
            generator=generator,
        )
        image = result.images[0]

        # Write through the already-open handle instead of re-opening the
        # file by name: re-opening a NamedTemporaryFile that is still open
        # is not portable (it fails on Windows).
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_file:
            image.save(temp_file, format="JPEG", quality=95)
            temp_filepath = temp_file.name

        return image, temp_filepath

    except Exception as e:
        print(f"❌ Error al generar imagen: {e}")
        return None, None

    finally:
        # Release Python garbage and cached CUDA blocks after every request,
        # whether it succeeded or not.
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: inputs in the left column, generated image and download link in
# the right column. The click handler is registered inside the Blocks context.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="violet")) as demo:
    gr.Markdown("# 🖼️✨ Generador de Imágenes Hiperrealistas con SDXL")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt (Descripción positiva)",
                placeholder="Ej: photorealistic portrait in 9:16, cinematic, natural light...",
                lines=3,
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt (Lo que NO quieres)",
                placeholder="Ej: blurry, text, low quality, deformed...",
                lines=2,
            )

            with gr.Row():
                # Each choice is (label shown to the user, value passed to
                # generate_image as model_choice).
                model_choice = gr.Radio(
                    choices=[
                        ("Realista Puro", "realistic"),
                        ("Pony Realism", "pony"),
                    ],
                    label="Modelo",
                    value="realistic",
                )
                seed = gr.Number(value=-1, label="Seed (-1 = aleatorio)", precision=0)

            with gr.Accordion("Ajustes Avanzados", open=False):
                steps = gr.Slider(20, 80, value=40, step=5, label="Pasos de inferencia")
                guidance = gr.Slider(1, 15, value=7.5, step=0.5, label="Guidance scale (Creatividad)")
                with gr.Row():
                    width = gr.Slider(512, 1024, value=768, step=64, label="Ancho")
                    height = gr.Slider(512, 1024, value=768, step=64, label="Alto")

            btn = gr.Button("🎨 Generar Imagen", variant="primary")

        with gr.Column():
            output_img = gr.Image(label="Resultado", type="pil", height=512)
            download_btn = gr.File(label="Descargar JPG")

    # The input order must match generate_image's positional parameters; the
    # two outputs receive the returned (image, file path) pair.
    btn.click(
        fn=generate_image,
        inputs=[prompt, negative_prompt, model_choice, steps, guidance, width, height, seed],
        outputs=[output_img, download_btn],
    )
|
|
|
|
|
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()