# Vibeforge / app.py
# Commit 1ed16aa (verified) by rickveloper: "Update app.py"
import os, random, re, torch
from typing import List, Tuple
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
# =========================
# SPEED PRESET
# =========================
# Default to SD-Turbo, a distilled Stable Diffusion 2.1 model meant to produce
# usable images in very few denoising steps. Override with the MODEL_ID
# environment variable; an unset *or empty* value falls back to the default.
DEFAULT_MODEL_ID = "stabilityai/sd-turbo"
MODEL_ID = os.getenv("MODEL_ID") or DEFAULT_MODEL_ID
# Use the GPU when one is visible. Use half precision there; keep full
# precision on CPU, where float16 kernels are poorly supported.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
# Minimal NSFW prompt guardrail. Matching prompts are refused outright rather
# than generated and blurred. Each entry is a regex fragment; the fragments are
# OR-ed into one case-insensitive pattern. Word boundaries keep e.g. "class" or
# "bass" from tripping the "ass" rule.
NSFW_TERMS = [
    r"\bnsfw\b", r"\bnude\b", r"\bnudity\b", r"\bsex\b", r"\bexplicit\b", r"\bporn\b",
    r"\bboobs\b", r"\bbutt\b", r"\bass\b", r"\bnaked\b", r"\btits\b",
    # "+" is not a word character, so a trailing \b after it would only match
    # when a letter/digit immediately follows ("18+x"). Without it, a plain
    # "18+" is caught as well.
    r"\b18\+", r"\berotic\b", r"\bfetish\b",
]
NSFW_REGEX = re.compile("|".join(NSFW_TERMS), flags=re.IGNORECASE)
def _blocked_tile(reason: str, w=384, h=384) -> Image.Image:
    """Render a dark placeholder image announcing that a request was blocked.

    Args:
        reason: Second line of text, drawn under the word "BLOCKED".
        w: Tile width in pixels (coerced to int).
        h: Tile height in pixels (coerced to int).

    Returns:
        A ``w`` x ``h`` RGB image with the two-line message centred on it.
    """
    w, h = int(w), int(h)
    img = Image.new("RGB", (w, h), (18, 20, 26))
    draw = ImageDraw.Draw(img)
    text = f"BLOCKED\n{reason}"
    try:
        font = ImageFont.truetype("DejaVuSans-Bold.ttf", 26)
    except (OSError, ImportError):
        # Font file missing (OSError) or Pillow built without FreeType
        # (ImportError): fall back to Pillow's built-in font instead of
        # swallowing every possible exception with a bare except.
        font = ImageFont.load_default()
    left, top, right, bottom = draw.multiline_textbbox(
        (0, 0), text, font=font, align="center"
    )
    # The bbox may start at a non-zero offset (glyph bearings / ascent), so
    # subtract that origin to centre the rendered ink, not the anchor point.
    x = (w - (right - left)) // 2 - left
    y = (h - (bottom - top)) // 2 - top
    draw.multiline_text((x, y), text, font=font, fill=(255, 255, 255), align="center")
    return img
def _is_nsfw(s: str) -> bool:
    """Report whether *s* contains any term matched by ``NSFW_REGEX``.

    Empty strings and ``None`` are treated as clean.
    """
    if not s:
        return False
    return NSFW_REGEX.search(s) is not None
# -------------------------
# Load pipeline (fast path)
# -------------------------
# Inference only: turn autograd off process-wide so no graph is recorded.
torch.set_grad_enabled(False)

# safety_checker=None switches off diffusers' post-hoc image safety checker
# entirely; content filtering happens up front on the prompt text instead.
pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    safety_checker=None,
)

# Replace the model's shipped scheduler with a DPM-Solver multistep scheduler
# built from the same scheduler config.
# NOTE(review): confirm this actually beats SD-Turbo's own scheduler at 1-4 steps.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

pipe = pipe.to(DEVICE)
if DEVICE == "cuda":
    # Lower peak GPU memory at a small speed cost.
    pipe.enable_attention_slicing()
    pipe.enable_vae_slicing()
# -------------------------
# Generate fn (kept lean)
# -------------------------
def generate(
    prompt: str,
    negative_prompt: str,
    steps: int,
    guidance: float,
    width: int,
    height: int,
    seed: int,
    batch_size: int
) -> Tuple[List[Image.Image], str]:
    """Run the module-level ``pipe`` for one UI request.

    Args:
        prompt: Text prompt; must contain non-whitespace text.
        negative_prompt: Optional negative prompt ("" or None means none).
        steps: Denoising steps, clamped to 1..12.
        guidance: CFG scale, clamped to 0.0..2.0.
        width: Output width in pixels.
        height: Output height in pixels.
        seed: RNG seed; a negative value or None picks a random seed.
        batch_size: Number of images to generate (at least 1).

    Returns:
        ``(images, status_markdown)``. For an empty prompt the list is empty;
        for a blocked prompt it holds a single placeholder tile.
    """
    prompt = prompt or ""
    if not prompt.strip():
        return [], "Add a prompt first."
    negative_prompt = negative_prompt or ""
    width, height = int(width), int(height)

    # Block obvious NSFW prompts before spending any compute.
    # NOTE(review): this also rejects negative prompts such as "nsfw, nude",
    # which users often add precisely to steer away from that content.
    if _is_nsfw(prompt) or _is_nsfw(negative_prompt):
        return [_blocked_tile("NSFW prompt detected", width, height)], "Blocked: NSFW prompt."

    # SD-Turbo is designed for tiny step counts + low/zero CFG; clamp to the
    # same ranges the UI sliders offer.
    steps = max(1, min(int(steps), 12))
    guidance = max(0.0, min(float(guidance), 2.0))
    # Gradio widgets may deliver floats; the pipeline needs an int count.
    batch_size = max(1, int(batch_size))

    # A cleared gr.Number yields None; treat it like "-1 = random".
    if seed is None or seed < 0:
        seed = random.randint(0, 2**31 - 1)
    seed = int(seed)  # manual_seed requires an int
    generator = torch.Generator(device=DEVICE).manual_seed(seed)

    out = pipe(
        prompt=prompt,
        # diffusers only encodes/uses the negative prompt when classifier-free
        # guidance is active (guidance_scale > 1).
        negative_prompt=(negative_prompt or None),
        num_inference_steps=steps,
        guidance_scale=guidance,
        width=width,
        height=height,
        num_images_per_prompt=batch_size,
        generator=generator
    )
    # The safety checker is disabled at load time, so no per-image NSFW flags
    # are consulted here; filtering relies on the prompt check above.
    msg = (
        f"Model: {MODEL_ID} • Seed: {seed} • Steps: {steps} • CFG: {guidance} • "
        f"{width}x{height} • Batch: {batch_size}"
    )
    return out.images, msg
# -------------------------
# UI (defaults tuned for CPU)
# -------------------------
# Gradio Blocks layout: components are placed in the order they are created
# inside the nested Row/Column contexts, so statement order defines the layout.
with gr.Blocks(title="VibeForge — Fast (CPU-friendly) Image Gen") as demo:
    # Header text rendered above the controls.
    gr.Markdown(
"""
# VibeForge ⚒️
**Fast, clean image generation (CPU-friendly).**
Uses **SD-Turbo** tuned for low steps. NSFW inputs are blocked.
"""
    )
    with gr.Row():
        # Left column: all inputs plus the Generate button.
        with gr.Column(scale=3):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="a neon-lit lighthouse on a stormy cliff at night, cinematic, volumetric fog, high contrast"
            )
            # NOTE(review): diffusers appears to use negative_prompt only when
            # guidance_scale > 1; with the CFG default of 0.5 below this box
            # may have no effect — confirm.
            negative = gr.Textbox(label="Negative Prompt", placeholder="low quality, watermark, overexposed")
            with gr.Row():
                # Slider ranges mirror the clamps applied in generate()
                # (steps 1..12, CFG 0.0..2.0).
                steps = gr.Slider(1, 12, value=4, step=1, label="Steps (SD-Turbo sweet spot: 2-6)")
                guidance = gr.Slider(0.0, 2.0, value=0.5, step=0.1, label="CFG (SD-Turbo likes low)")
            with gr.Row():
                # Fixed output sizes; all are multiples of 64.
                width = gr.Dropdown(choices=[384, 448, 512], value=384, label="Width")
                height = gr.Dropdown(choices=[384, 448, 512], value=384, label="Height")
            with gr.Row():
                # A negative seed makes generate() pick a random one;
                # precision=0 asks Gradio for whole numbers.
                seed = gr.Number(value=-1, label="Seed (-1 = random)", precision=0)
                batch = gr.Slider(1, 2, value=1, step=1, label="Batch (keep small on CPU)")
            go = gr.Button("Generate", variant="primary")
        # Right column: generated images and the status line from generate().
        with gr.Column(scale=5):
            gallery = gr.Gallery(label="Output", columns=2, height=448)
            info = gr.Markdown()
    # Inputs are passed positionally, so this list must stay in the same order
    # as generate()'s parameters; outputs match its (images, message) return.
    go.click(
        fn=generate,
        inputs=[prompt, negative, steps, guidance, width, height, seed, batch],
        outputs=[gallery, info]
    )
def _main() -> None:
    """Start the Gradio server for ``demo``."""
    demo.launch()


if __name__ == "__main__":
    _main()