File size: 5,647 Bytes
c2623c6
 
 
 
 
 
 
 
 
 
 
 
 
 
d8bc12c
 
c2623c6
 
 
 
 
 
 
476db41
c2623c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5398f0
476db41
 
d8bc12c
 
 
 
 
 
 
 
c2623c6
d8bc12c
c2623c6
 
d8bc12c
c2623c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d8bc12c
 
c2623c6
 
 
 
d8bc12c
 
c2623c6
 
 
476db41
 
 
 
c2623c6
 
 
 
 
 
 
 
 
 
 
 
 
 
d8bc12c
 
c2623c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
476db41
b5398f0
78f791e
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import gradio as gr
import torch
import gc
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
import logging

# Module-wide logging: INFO level, logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class AIImageGeneratorNSFW:
    """Lazy-loading wrapper around a Stable Diffusion XL diffusers pipeline.

    The heavy pipeline (base model + LoRA) is built on first use via
    ``load_model()``; ``generate_image()`` then produces one PIL image per call.
    """

    def __init__(self):
        # Pipeline is constructed lazily on the first generate call.
        self.pipeline = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_id = "segmind/Segmind-DE-XL"
        self.lora_id = "urn:air:sdxl:lora:civitai:141300@341068"
        self.is_model_loaded = False
        logger.info(f"Inicializando en dispositivo: {self.device}")

    def load_model(self):
        """Download/build the SDXL pipeline and apply the LoRA.

        Returns:
            bool: True if the pipeline is ready, False on any load failure
            (the error is logged, never raised, so callers can degrade).
        """
        if self.is_model_loaded:
            return True
        try:
            logger.info("Cargando modelo base NSFW con LoRA y optimización...")
            # fp16 only makes sense on GPU; CPU runs in fp32.
            torch_dtype = torch.float16 if self.device == "cuda" else torch.float32

            tokenizer_1 = CLIPTokenizer.from_pretrained(self.model_id, subfolder="tokenizer", use_fast=False)
            tokenizer_2 = CLIPTokenizer.from_pretrained(self.model_id, subfolder="tokenizer_2", use_fast=False)

            text_encoder_1 = CLIPTextModel.from_pretrained(self.model_id, subfolder="text_encoder", torch_dtype=torch_dtype, low_cpu_mem_usage=True)
            text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(self.model_id, subfolder="text_encoder_2", torch_dtype=torch_dtype, low_cpu_mem_usage=True)

            # BUG FIX: SDXL pipelines take their two tokenizers/encoders as the
            # separate keyword args tokenizer/tokenizer_2 and
            # text_encoder/text_encoder_2 — NOT as lists, which the original
            # passed. Also dropped safety_checker=None: SDXL pipelines have no
            # safety_checker component, so the kwarg was meaningless.
            self.pipeline = StableDiffusionXLPipeline.from_pretrained(
                self.model_id,
                tokenizer=tokenizer_1,
                tokenizer_2=tokenizer_2,
                text_encoder=text_encoder_1,
                text_encoder_2=text_encoder_2,
                torch_dtype=torch_dtype,
                scheduler=EulerDiscreteScheduler.from_pretrained(self.model_id, subfolder="scheduler"),
                use_safetensors=True,
                variant="fp16" if self.device == "cuda" else None,
            )

            # BUG FIX: diffusers exposes load_lora_weights(), not load_lora().
            # The default LoRA scale is 1.0, matching the original weight=1.0.
            # NOTE(review): a civitai URN is unlikely to resolve via the Hub;
            # it probably must be downloaded to a local .safetensors path
            # first — confirm before shipping.
            self.pipeline.load_lora_weights(self.lora_id)

            if self.device == "cuda":
                self.pipeline.enable_model_cpu_offload()
                self.pipeline.enable_vae_slicing()
                # NOTE(review): torch.compile combined with model CPU offload
                # can retrigger compilation as modules move devices — confirm
                # the reduce-overhead mode actually helps on this hardware.
                self.pipeline.unet = torch.compile(self.pipeline.unet, mode='reduce-overhead')
                # BUG FIX: the pipeline attributes are text_encoder and
                # text_encoder_2; the original assigned to a nonexistent
                # text_encoder_1 attribute, so the first encoder was never
                # actually replaced with its compiled version.
                self.pipeline.text_encoder = torch.compile(text_encoder_1, mode='reduce-overhead')
                self.pipeline.text_encoder_2 = torch.compile(text_encoder_2, mode='reduce-overhead')

            self.is_model_loaded = True
            logger.info("Modelo NSFW con LoRA cargado y optimizado correctamente.")
            return True
        except Exception as e:
            logger.error(f"Error cargando modelo NSFW con LoRA: {e}")
            return False

    def generate_image(self, prompt, width=1024, height=576, steps=35, guidance_scale=12.0):
        """Generate one image for *prompt*.

        Args:
            prompt: text prompt for the pipeline.
            width, height: requested size in pixels; rounded down to a
                multiple of 8 as the model requires.
            steps: number of inference steps.
            guidance_scale: classifier-free guidance strength.

        Returns:
            PIL.Image.Image on success, None on failure (error is logged).
        """
        if not self.is_model_loaded and not self.load_model():
            return None
        try:
            with torch.inference_mode():
                # Fresh random seed per call so repeated prompts vary.
                seed = torch.randint(0, 2**32, (1,)).item()
                generator = torch.Generator(self.device).manual_seed(seed)
                result = self.pipeline(
                    prompt=prompt,
                    width=(width // 8) * 8,
                    height=(height // 8) * 8,
                    num_inference_steps=steps,
                    guidance_scale=guidance_scale,
                    generator=generator,
                    output_type="pil",
                )
            return result.images[0]
        except Exception as e:
            logger.error(f"Error generando imagen NSFW: {e}")
            return None
        finally:
            # Single cleanup path (the original duplicated this in both the
            # try body and the except handler).
            gc.collect()
            if self.device == "cuda":
                torch.cuda.empty_cache()

def initialize_generator_nsfw():
    """Return the process-wide AIImageGeneratorNSFW, creating it on first call.

    A module-level singleton keeps the heavy pipeline from being built twice.
    """
    global generator_nsfw
    # Idiom fix: with `global` declared, assign the name directly instead of
    # the original's redundant globals()['generator_nsfw'] indexing.
    if 'generator_nsfw' not in globals():
        generator_nsfw = AIImageGeneratorNSFW()
    return generator_nsfw

def generate_image_nsfw(prompt, width, height, steps, guidance_scale):
    """Gradio callback: validate the prompt and delegate to the generator.

    Returns a PIL image, or None when the prompt is missing/blank or
    generation fails.
    """
    # BUG FIX: the original called prompt.strip() unconditionally, raising
    # AttributeError when Gradio passes None; it also built the (expensive)
    # generator before validating the prompt. Guard first, then initialize.
    if not prompt or not prompt.strip():
        return None
    gen = initialize_generator_nsfw()
    return gen.generate_image(
        prompt=prompt,
        width=int(width),
        height=int(height),
        steps=int(steps),
        guidance_scale=float(guidance_scale),
    )

def create_nsfw_interface():
    """Build and return the Gradio Blocks UI wired to generate_image_nsfw."""
    with gr.Blocks(title="Generador de Imágenes NSFW con IA - Stable Diffusion XL") as ui:
        gr.Markdown("# 🎨 Generador NSFW basado en Stable Diffusion XL\n_Uso responsable y solo para adultos_")

        # Input widgets: free-text prompt plus the numeric generation knobs.
        prompt_box = gr.Textbox(
            label="Prompt para la imagen NSFW",
            placeholder="Describe el contenido explícito...",
            lines=3,
        )
        width_slider = gr.Slider(512, 1536, value=1024, step=8, label="Ancho (pixeles)")
        height_slider = gr.Slider(512, 1536, value=576, step=8, label="Alto (pixeles)")
        steps_slider = gr.Slider(10, 50, value=35, step=1, label="Pasos de inferencia")
        cfg_slider = gr.Slider(1.0, 20.0, value=12.0, step=0.1, label="Escala de guía")

        run_button = gr.Button("Generar Imagen NSFW")
        output_image = gr.Image(label="Imagen generada")

        # One click handler: slider/textbox values in, single image out.
        run_button.click(
            fn=generate_image_nsfw,
            inputs=[prompt_box, width_slider, height_slider, steps_slider, cfg_slider],
            outputs=output_image,
        )
    return ui

# Module-level interface instance so hosting platforms (e.g. Hugging Face
# Spaces) can pick it up simply by importing this module.
nsfw_app = create_nsfw_interface()

if __name__ == "__main__":
    # Launch the Gradio server when executed as a script.
    nsfw_app.launch()