# ============================================================
# Hugging Face Spaces GPU app (ZeroGPU)
# IMPORTANT:
# - `spaces` MUST be imported before torch or any CUDA library
# - GPU work MUST run inside a function decorated with @spaces.GPU
# ============================================================
import spaces  # must come first, before torch
import os
import random
import gc
import gradio as gr
import numpy as np
from PIL import Image
import torch
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    EulerAncestralDiscreteScheduler,
)
from transformers import CLIPTokenizer, CLIPTextModel
from huggingface_hub import login
# ============================================================
# Config
# ============================================================
MODEL_ID = "telcom/dee-unlearning-tiny-sd"
REVISION = "main"
HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
if HF_TOKEN:
    login(token=HF_TOKEN)
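# On ZeroGPU Spaces, importing `spaces` first patches torch so that models
# can be moved to CUDA at startup even though a GPU is only attached while
# a @spaces.GPU function is running.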
device = torch.device("cuda")
dtype = torch.float16
IMAGE_SIZE = 512  # SD 1.x models are trained at 512x512
MAX_SEED = np.iinfo(np.int32).max  # 2**31 - 1
# ============================================================
# Load model (once at startup)
# ============================================================
pipe_txt2img = StableDiffusionPipeline.from_pretrained(
    MODEL_ID,
    revision=REVISION,
    torch_dtype=dtype,
    safety_checker=None,  # skip the NSFW checker to save VRAM and load time
).to(device)
# 🔑 Force-load the tokenizer + text encoder from the model repo
pipe_txt2img.tokenizer = CLIPTokenizer.from_pretrained(
    MODEL_ID, subfolder="tokenizer"
)
pipe_txt2img.text_encoder = CLIPTextModel.from_pretrained(
    MODEL_ID,
    subfolder="text_encoder",
    torch_dtype=dtype,
).to(device)
# Scheduler: Euler Ancestral, a fast sampler that does well at ~20-30 steps
pipe_txt2img.scheduler = EulerAncestralDiscreteScheduler.from_config(
    pipe_txt2img.scheduler.config
)
# Memory optimisations
pipe_txt2img.enable_attention_slicing()
pipe_txt2img.enable_vae_slicing()
try:
    # xformers is optional; fall back to standard attention if unavailable
    pipe_txt2img.enable_xformers_memory_efficient_attention()
except Exception:
    pass
pipe_txt2img.set_progress_bar_config(disable=True)
# Img2Img pipeline, built from the same components (no extra VRAM)
pipe_img2img = StableDiffusionImg2ImgPipeline(
    **pipe_txt2img.components
).to(device)
pipe_img2img.scheduler = EulerAncestralDiscreteScheduler.from_config(
    pipe_img2img.scheduler.config
)
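# Mirror the txt2img progress-bar setting on the shared img2img pipeline
pipe_img2img.set_progress_bar_config(disable=True)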
# ============================================================
# GPU inference function (ZeroGPU attaches a GPU only for this call)
# ============================================================
@spaces.GPU
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    guidance_scale,
    num_inference_steps,
    init_image,
    strength,
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # int() guards against Gradio delivering the slider value as a float
    generator = torch.Generator(device=device).manual_seed(int(seed))
    try:
        with torch.inference_mode():
            if init_image is not None:
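                # Hedged addition: resize the init image to the generation
                # resolution; img2img accepts other sizes, but 512x512
                # matches what the model was trained on.
                init_image = init_image.convert("RGB").resize(
                    (IMAGE_SIZE, IMAGE_SIZE), Image.LANCZOS
                )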
                image = pipe_img2img(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    image=init_image,
                    strength=float(strength),
                    guidance_scale=float(guidance_scale),
                    num_inference_steps=int(num_inference_steps),
                    generator=generator,
                ).images[0]
            else:
                image = pipe_txt2img(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    width=IMAGE_SIZE,
                    height=IMAGE_SIZE,
                    guidance_scale=float(guidance_scale),
                    num_inference_steps=int(num_inference_steps),
                    generator=generator,
                ).images[0]
        return image, f"Seed: {seed}"
    finally:
        # Free GPU memory between requests
        gc.collect()
        torch.cuda.empty_cache()
# ============================================================
# UI
# ============================================================
with gr.Blocks(title="Stable Diffusion (512×512)") as demo:
    gr.Markdown("## Stable Diffusion Generator (GPU, 512×512)")
    prompt = gr.Textbox(
        label="Prompt",
        placeholder="Describe the image you want to generate",
        lines=2,
    )
    init_image = gr.Image(
        label="Initial image (optional, enables img2img)",
        type="pil",
    )
    run_button = gr.Button("Generate")
    result = gr.Image(label="Result")
    status = gr.Markdown("")
    with gr.Accordion("Advanced Settings", open=False):
        negative_prompt = gr.Textbox(
            label="Negative prompt",
            value="nsfw, (low quality, worst quality:1.2), watermark, signature, ugly, deformed",
        )
        seed = gr.Slider(0, MAX_SEED, step=1, value=0, label="Seed")
        randomize_seed = gr.Checkbox(True, label="Randomize seed")
        guidance_scale = gr.Slider(1, 20, step=0.5, value=7.5, label="Guidance scale")
        num_inference_steps = gr.Slider(1, 40, step=1, value=30, label="Steps")
        strength = gr.Slider(0.0, 1.0, step=0.05, value=0.7, label="Image strength (img2img)")
    run_button.click(
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
            init_image,
            strength,
        ],
        outputs=[result, status],
    )
demo.queue().launch(server_name="0.0.0.0", server_port=7860)
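# Note: outside the Spaces runtime the @spaces.GPU decorator is a no-op, so
# running `python app.py` locally needs a CUDA GPU for this file as written.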