import json
import random

import gradio as gr
import numpy as np
import spaces  # kept ahead of torch, matching the original import order
import torch
from diffusers import AutoencoderKL, DiffusionPipeline, FlowMatchEulerDiscreteScheduler
from videox_fun.pipeline import ZImageControlPipeline
from videox_fun.models import ZImageControlTransformer2DModel
from transformers import AutoTokenizer, Qwen3ForCausalLM

from image_utils import get_image_latent, scale_image

# (an alternative source of get_image_latent is videox_fun.utils.utils)

# Upstream Hub repo: Tongyi-MAI/Z-Image-Turbo (weights are read from a local clone, see MODEL_LOCAL).
MAX_SEED = np.iinfo(np.int32).max
# NOTE(review): MAX_IMAGE_SIZE is not referenced in the visible code; confirm whether
# scale_image is meant to clamp output size to it.
MAX_IMAGE_SIZE = 1280

# git clone https://huggingface.co/Tongyi-MAI/Z-Image-Turbo
MODEL_LOCAL = "models/Z-Image-Turbo/"
# curl -L -o Z-Image-Turbo-Fun-Controlnet-Union.safetensors https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union/resolve/main/Z-Image-Turbo-Fun-Controlnet-Union.safetensors
TRANSFORMER_LOCAL = "models/Z-Image-Turbo-Fun-Controlnet-Union.safetensors"

# Single dtype used for every model component loaded below.
weight_dtype = torch.bfloat16

# ---- Control transformer: base weights from MODEL_LOCAL plus control layers ----
transformer = ZImageControlTransformer2DModel.from_pretrained(
    MODEL_LOCAL,
    subfolder="transformer",
    low_cpu_mem_usage=True,
    torch_dtype=weight_dtype,
    transformer_additional_kwargs={
        # Transformer layer indices that receive control injections, and the
        # channel count of the control input.
        "control_layers_places": [0, 5, 10, 15, 20, 25],
        "control_in_dim": 16,
    },
).to(weight_dtype)

if TRANSFORMER_LOCAL is not None:
    print(f"From checkpoint: {TRANSFORMER_LOCAL}")
    if TRANSFORMER_LOCAL.endswith("safetensors"):
        from safetensors.torch import load_file

        state_dict = load_file(TRANSFORMER_LOCAL)
    else:
        # NOTE: torch.load unpickles the file; only point TRANSFORMER_LOCAL at trusted checkpoints.
        state_dict = torch.load(TRANSFORMER_LOCAL, map_location="cpu")
    # Some checkpoints nest the weights under a "state_dict" key.
    state_dict = state_dict.get("state_dict", state_dict)

    # strict=False tolerates missing/unexpected keys; only their counts are reported.
    missing, unexpected = transformer.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
    # load_state_dict copies the tensors into the model's parameters, so drop the
    # checkpoint dict instead of keeping a second copy of the weights alive at module scope.
    del state_dict

# ---- Remaining ZImageControlPipeline components ----
vae = AutoencoderKL.from_pretrained(MODEL_LOCAL, subfolder="vae").to(weight_dtype)
tokenizer = AutoTokenizer.from_pretrained(MODEL_LOCAL, subfolder="tokenizer")
text_encoder = Qwen3ForCausalLM.from_pretrained(
    MODEL_LOCAL,
    subfolder="text_encoder",
    torch_dtype=weight_dtype,
    low_cpu_mem_usage=True,
)
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3)

pipe = ZImageControlPipeline(
    vae=vae,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    transformer=transformer,
    scheduler=scheduler,
)
pipe.transformer = transformer
pipe.to("cuda")

# ======== AoTI blocks + FA3 ========
# Tag the repeated block class, then load the pre-built AoTI artifacts
# (variant "fa3") for pipe.transformer.layers from "zerogpu-aoti/Z-Image"
# via the `spaces` helper.
pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3")


@spaces.GPU
def inference(
    prompt,
    input_image,
    image_scale=1.0,
    control_context_scale=0.75,
    seed=42,
    randomize_seed=True,
    guidance_scale=1.5,
    num_inference_steps=8,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an image from a text prompt, guided by a control image.

    Args:
        prompt: Text prompt for the generation.
        input_image: Image used as the control input. It is rescaled with
            ``scale_image`` and converted with ``get_image_latent`` before being
            passed to the pipeline as ``control_image``.
        image_scale: Scale factor passed to ``scale_image``. The width and height
            it returns are used as the output size.
        control_context_scale: Control conditioning weight forwarded to the
            pipeline.
        seed: Random seed. It is replaced when ``randomize_seed`` is true.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of inference steps run by the pipeline.
        progress: Gradio progress tracker that mirrors tqdm bars.

    Returns:
        A tuple of the first generated image and the seed that was used.

    Raises:
        gr.Error: If no input image was provided.
    """
    if input_image is None:
        # The event has two outputs (image, seed). A bare `return None` would
        # trip Gradio's output-count check instead of informing the user.
        raise gr.Error("Please upload an input image.")

    input_image, width, height = scale_image(input_image, image_scale)
    # NOTE(review): [:, :, 0] presumably drops a frame/time axis from the
    # latent-shaped tensor; confirm against get_image_latent.
    control_image = get_image_latent(input_image, sample_size=[height, width])[:, :, 0]

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider and MCP clients may deliver numbers as floats. manual_seed and the
    # step count both need ints.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)

    image = pipe(
        prompt=prompt,
        height=height,
        width=width,
        generator=generator,
        guidance_scale=guidance_scale,
        control_image=control_image,
        num_inference_steps=int(num_inference_steps),
        control_context_scale=control_context_scale,
    ).images[0]
    return image, seed


def read_file(path: str) -> str:
    """Return the full UTF-8 text content of the file at *path*."""
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


css = """
#col-container {
    margin: 0 auto;
    max-width: 960px;
}
"""

# Gradio 6 moved `css` from gr.Blocks() to launch(). Pass it wherever the
# installed version expects it so the #col-container rule is actually applied.
_GRADIO_MAJOR = int(gr.__version__.split(".")[0])
_blocks_kwargs = {"css": css} if _GRADIO_MAJOR < 6 else {}
_launch_kwargs = {"css": css} if _GRADIO_MAJOR >= 6 else {}

with open("static/data.json", "r", encoding="utf-8") as file:
    data = json.load(file)
examples = data["examples"]

with gr.Blocks(**_blocks_kwargs) as demo:
    with gr.Column(elem_id="col-container"):
        with gr.Column():
            gr.HTML(read_file("static/header.html"))
        with gr.Row(equal_height=True):
            with gr.Column():
                input_image = gr.Image(
                    height=290,
                    sources=["upload", "clipboard"],
                    image_mode="RGB",
                    type="pil",
                    label="Upload",
                )
                prompt = gr.Textbox(
                    label="Prompt",
                    show_label=False,
                    lines=2,
                    placeholder="Enter your prompt",
                    container=False,
                )
                run_button = gr.Button("Run", variant="primary")
            with gr.Column():
                output_image = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                image_scale = gr.Slider(
                    label="Image scale",
                    minimum=0.5,
                    maximum=2.0,
                    step=0.1,
                    value=1.0,
                )
                control_context_scale = gr.Slider(
                    label="Control context scale",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    value=0.75,
                )
            with gr.Row():
                # NOTE(review): this UI default (2.5) differs from inference()'s
                # own default (1.5); confirm which value is intended.
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=2.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=30,
                    step=1,
                    value=8,
                )

        gr.Examples(examples=examples, inputs=[input_image, prompt])
        gr.HTML(read_file("static/footer.html"))

    # Input order must match inference()'s positional parameters.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=inference,
        inputs=[
            prompt,
            input_image,
            image_scale,
            control_context_scale,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[output_image, seed],
    )

if __name__ == "__main__":
    demo.launch(mcp_server=True, **_launch_kwargs)