Spaces: Running on Zero

# app.py
# ============================================================
# IMPORTANT: imports order matters for Hugging Face Spaces
# ============================================================
import os
import gc
import random
import warnings
import logging
import inspect
from contextlib import nullcontext

# ---- Spaces GPU decorator (must be imported early) ----------
try:
    import spaces  # noqa: F401
    SPACES_AVAILABLE = True
except Exception:
    SPACES_AVAILABLE = False
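
# On ZeroGPU hardware, `spaces` must be imported before anything that initializes
# CUDA; the @spaces.GPU decorator applied to `infer` further down is what actually
# attaches a GPU to each inference call.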

import gradio as gr
import numpy as np
from PIL import Image
import torch
from huggingface_hub import login

# ============================================================
# Try importing Z-Image pipelines (requires diffusers>=0.36.0)
# ============================================================
ZIMAGE_AVAILABLE = True
ZIMAGE_IMPORT_ERROR = None
try:
    from diffusers import (
        ZImagePipeline,
        ZImageImg2ImgPipeline,
        FlowMatchEulerDiscreteScheduler,
    )
except Exception as e:
    ZIMAGE_AVAILABLE = False
    ZIMAGE_IMPORT_ERROR = repr(e)

# ============================================================
# Config
# ============================================================
MODEL_PATH = os.environ.get("MODEL_PATH", "telcom/dee-z-image").strip()
ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3").strip()
ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "false").lower() == "true"
HF_TOKEN = os.getenv("HF_TOKEN", "").strip()

if HF_TOKEN:
    login(token=HF_TOKEN)

os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)

MAX_SEED = np.iinfo(np.int32).max

# ============================================================
# Device & dtype
# ============================================================
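# Prefer bf16 where the GPU supports it (Ampere and newer); otherwise fall back
# to fp16 on CUDA and fp32 on CPU.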
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")

if cuda_available and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
    dtype = torch.bfloat16
elif cuda_available:
    dtype = torch.float16
else:
    dtype = torch.float32

MAX_IMAGE_SIZE = 1536 if cuda_available else 768

fallback_msg = ""
if not cuda_available:
    fallback_msg = "GPU unavailable. Running in CPU fallback mode (slow)."

# ============================================================
# Load pipelines
# ============================================================
pipe_txt2img = None
pipe_img2img = None
model_loaded = False
load_error = None


def _set_attention_backend_best_effort(p):
    try:
        if hasattr(p, "transformer") and hasattr(p.transformer, "set_attention_backend"):
            p.transformer.set_attention_backend(ATTENTION_BACKEND)
    except Exception:
        pass


def _compile_best_effort(p):
    if not (ENABLE_COMPILE and device.type == "cuda"):
        return
    try:
        if hasattr(p, "transformer"):
            p.transformer = torch.compile(
                p.transformer,
                mode="max-autotune-no-cudagraphs",
                fullgraph=False,
            )
    except Exception:
        pass
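
# torch.compile is opt-in (ENABLE_COMPILE=true) because the first compiled call
# pays a noticeable warm-up cost; it tends to pay off only for long-running,
# frequently hit Spaces.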

if ZIMAGE_AVAILABLE:
    try:
        fp_kwargs = {
            "torch_dtype": dtype,
            "use_safetensors": True,
        }
        if HF_TOKEN:
            fp_kwargs["token"] = HF_TOKEN

        pipe_txt2img = ZImagePipeline.from_pretrained(MODEL_PATH, **fp_kwargs).to(device)
        _set_attention_backend_best_effort(pipe_txt2img)
        _compile_best_effort(pipe_txt2img)
        try:
            pipe_txt2img.set_progress_bar_config(disable=True)
        except Exception:
            pass

        # Share weights/components with img2img pipeline
        pipe_img2img = ZImageImg2ImgPipeline(**pipe_txt2img.components).to(device)
        _set_attention_backend_best_effort(pipe_img2img)
        try:
            pipe_img2img.set_progress_bar_config(disable=True)
        except Exception:
            pass

        model_loaded = True
    except Exception as e:
        load_error = repr(e)
        model_loaded = False
else:
    load_error = (
        "Z-Image pipelines not available in your diffusers install.\n\n"
        f"Import error:\n{ZIMAGE_IMPORT_ERROR}\n\n"
        "Fix: set requirements.txt to diffusers>=0.36.0 (or install Diffusers from source)."
    )
    model_loaded = False

# ============================================================
# Helpers
# ============================================================
def make_error_image(w: int, h: int) -> Image.Image:
    return Image.new("RGB", (int(w), int(h)), (18, 18, 22))


def prep_init_image(img: Image.Image, width: int, height: int) -> Image.Image:
    if img is None:
        return None
    if not isinstance(img, Image.Image):
        return None
    img = img.convert("RGB")
    if img.size != (width, height):
        img = img.resize((width, height), Image.LANCZOS)
    return img


def _call_pipeline(pipe, kwargs: dict):
    """
    Robust call: only pass kwargs the pipeline actually accepts.
    This avoids crashes if a particular build does not support negative_prompt, etc.
    """
    try:
        sig = inspect.signature(pipe.__call__)
        allowed = set(sig.parameters.keys())
        filtered = {k: v for k, v in kwargs.items() if k in allowed and v is not None}
        return pipe(**filtered)
    except Exception:
        # Fallback: try raw kwargs (some pipelines use **kwargs internally)
        return pipe(**{k: v for k, v in kwargs.items() if v is not None})

# ============================================================
# Inference
# ============================================================
def _infer_impl(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    shift,
    max_sequence_length,
    init_image,
    strength,
):
    width = int(width)
    height = int(height)
    seed = int(seed)

    if not model_loaded:
        return make_error_image(width, height), f"Model load failed: {load_error}"

    prompt = (prompt or "").strip()
    if not prompt:
        return make_error_image(width, height), "Error: prompt is empty."

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    status = f"Seed: {seed}"
    if fallback_msg:
        status += f" | {fallback_msg}"

    gs = float(guidance_scale)
    steps = int(num_inference_steps)
    msl = int(max_sequence_length)
    st = float(strength)

    neg = (negative_prompt or "").strip()
    if not neg:
        neg = None

    init_image = prep_init_image(init_image, width, height)

    # Update scheduler (shift) per run
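    # (shift rescales the flow-matching sigma schedule; larger values spend more
    # of the trajectory at high-noise timesteps, exposed in the UI as "Time shift")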
    try:
        scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=float(shift))
        pipe_txt2img.scheduler = scheduler
        pipe_img2img.scheduler = scheduler
    except Exception:
        pass

    try:
        base_kwargs = dict(
            prompt=prompt,
            height=height,
            width=width,
            guidance_scale=gs,
            num_inference_steps=steps,
            generator=generator,
            max_sequence_length=msl,
        )
        # only passed if supported by the pipeline
        if neg is not None:
            base_kwargs["negative_prompt"] = neg

        # Autocast only on CUDA; on CPU run without mixed precision.
        autocast_ctx = (
            torch.autocast("cuda", dtype=dtype) if device.type == "cuda" else nullcontext()
        )
        with torch.inference_mode(), autocast_ctx:
            if init_image is not None:
                out = _call_pipeline(
                    pipe_img2img,
                    {**base_kwargs, "image": init_image, "strength": st},
                )
            else:
                out = _call_pipeline(pipe_txt2img, base_kwargs)

        img = out.images[0]
        return img, status
    except Exception as e:
        return make_error_image(width, height), f"Error: {type(e).__name__}: {e}"
    finally:
        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()

if SPACES_AVAILABLE:
    # On ZeroGPU, a GPU is attached only to calls wrapped in @spaces.GPU.
    @spaces.GPU
    def infer(*args, **kwargs):
        return _infer_impl(*args, **kwargs)
else:
    def infer(*args, **kwargs):
        return _infer_impl(*args, **kwargs)

# ============================================================
# UI (simple black style like your SDXL example)
# ============================================================
CSS = """
body {
    background: #000;
    color: #fff;
}
"""

with gr.Blocks(css=CSS, title="Z-Image txt2img + img2img") as demo:
    if fallback_msg:
        gr.Markdown(f"**{fallback_msg}**")
    if not model_loaded:
        gr.Markdown(f"⚠️ Model failed to load:\n\n{load_error}")

    gr.Markdown("## Z-Image Generator (txt2img + img2img)")

    prompt = gr.Textbox(label="Prompt", lines=2)
    init_image = gr.Image(label="Initial image (optional)", type="pil")
    run_button = gr.Button("Generate")
    result = gr.Image(label="Result")
    status = gr.Markdown("")

    with gr.Accordion("Advanced Settings", open=False):
        negative_prompt = gr.Textbox(label="Negative prompt (optional)")
        seed = gr.Slider(0, MAX_SEED, step=1, value=0, label="Seed")
        randomize_seed = gr.Checkbox(value=True, label="Randomize seed")
        # Keep the default size within range when MAX_IMAGE_SIZE is reduced on CPU.
        width = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=min(1024, MAX_IMAGE_SIZE), label="Width")
        height = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=min(1024, MAX_IMAGE_SIZE), label="Height")
        guidance_scale = gr.Slider(0.0, 10.0, step=0.1, value=0.0, label="Guidance scale")
        num_inference_steps = gr.Slider(1, 100, step=1, value=8, label="Steps")
        shift = gr.Slider(1.0, 10.0, step=0.1, value=3.0, label="Time shift")
        max_sequence_length = gr.Slider(64, 512, step=64, value=512, label="Max sequence length")
        strength = gr.Slider(0.0, 1.0, step=0.05, value=0.6, label="Image strength (img2img)")

    run_button.click(
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            shift,
            max_sequence_length,
            init_image,
            strength,
        ],
        outputs=[result, status],
    )

if __name__ == "__main__":
    demo.queue().launch(ssr_mode=False)
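
# ------------------------------------------------------------
# requirements.txt (sketch). Only the diffusers>=0.36.0 constraint comes from
# the error message above; the remaining package names are inferred from this
# script's imports and are deliberately left unpinned.
#
#   diffusers>=0.36.0
#   transformers
#   gradio
#   spaces
#   torch
#   numpy
#   Pillow
#   huggingface_hub
# ------------------------------------------------------------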