import gradio as gr
import numpy as np
import random, json, spaces, torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from videox_fun.pipeline import ZImageControlPipeline
from videox_fun.models import ZImageControlTransformer2DModel
from transformers import AutoTokenizer, Qwen3ForCausalLM
from utils.image_utils import get_image_latent, rescale_image
from utils.prompt_utils import polish_prompt
# from controlnet_aux import HEDdetector, MLSDdetector, OpenposeDetector, CannyDetector, MidasDetector
from controlnet_aux.processor import Processor

# MODEL_REPO = "Tongyi-MAI/Z-Image-Turbo"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1280

# git clone https://huggingface.co/Tongyi-MAI/Z-Image-Turbo
MODEL_LOCAL = "models/Z-Image-Turbo/"
# curl -L -o Z-Image-Turbo-Fun-Controlnet-Union.safetensors https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union/resolve/main/Z-Image-Turbo-Fun-Controlnet-Union.safetensors
TRANSFORMER_LOCAL = "models/Z-Image-Turbo-Fun-Controlnet-Union.safetensors"

weight_dtype = torch.bfloat16

# Load the transformer and register the extra control layers used by the ControlNet-Union checkpoint.
transformer = ZImageControlTransformer2DModel.from_pretrained(
    MODEL_LOCAL,
    subfolder="transformer",
    # low_cpu_mem_usage=False,
    # torch_dtype=torch.bfloat16,
    transformer_additional_kwargs={
        "control_layers_places": [0, 5, 10, 15, 20, 25],
        "control_in_dim": 16,
    },
).to("cuda", torch.bfloat16)

# Overlay the ControlNet-Union weights on top of the base transformer.
if TRANSFORMER_LOCAL is not None:
    print(f"From checkpoint: {TRANSFORMER_LOCAL}")
    from safetensors.torch import load_file, safe_open

    state_dict = load_file(TRANSFORMER_LOCAL)
    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
    m, u = transformer.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")

# Load the remaining components and assemble the ZImageControlPipeline.
vae = AutoencoderKL.from_pretrained(
    MODEL_LOCAL, subfolder="vae"
).to(weight_dtype)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_LOCAL, subfolder="tokenizer"
)
text_encoder = Qwen3ForCausalLM.from_pretrained(
    MODEL_LOCAL,
    subfolder="text_encoder",
    torch_dtype=weight_dtype,
    low_cpu_mem_usage=False,
)
# scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3)
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    MODEL_LOCAL, subfolder="scheduler"
)

pipe = ZImageControlPipeline(
    vae=vae,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    transformer=transformer,
    scheduler=scheduler,
)
pipe.to("cuda", torch.bfloat16)
# pipe.transformer = transformer
# pipe.to("cuda")

# ======== AoTI compilation + FA3 ========
pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3")


def prepare(prompt):
    """Polish the user prompt before generation."""
    polished_prompt = polish_prompt(prompt)
    return polished_prompt


@spaces.GPU
def inference(
    prompt,
    input_image,
    image_scale=1.0,
    control_mode='Canny',
    control_context_scale=0.75,
    seed=42,
    randomize_seed=True,
    guidance_scale=1.5,
    num_inference_steps=8,
    progress=gr.Progress(track_tqdm=True),
):
    # Preprocess the reference image into a control map.
    print("DEBUG: process image")
    if input_image is None:
        print("Error: input_image is empty.")
        return None, seed, None  # keep the three outputs the click handler expects
    # input_image, width, height = scale_image(input_image, image_scale)

    # Map the UI control mode to a controlnet_aux processor id.
    processor_id = 'canny'
    if control_mode == 'HED':
        processor_id = 'softedge_hed'
    elif control_mode == 'Midas':
        processor_id = 'depth_midas'
    elif control_mode == 'MLSD':
        processor_id = 'mlsd'
    elif control_mode == 'Pose':
        processor_id = 'openpose_full'
    print(f"DEBUG: processor_id={processor_id}")
    processor = Processor(processor_id)
    # Width and height must be divisible by 16.
    control_image, width, height = rescale_image(input_image, image_scale, 16)
    # Run the preprocessor at a fixed 1024x1024, then resize the control map back to the target size.
    control_image = control_image.resize((1024, 1024))
    print("DEBUG: processor running")
    control_image = processor(control_image, to_pil=True)
    control_image = control_image.resize((width, height))
    print("DEBUG: control_image_torch")
    control_image_torch = get_image_latent(control_image, sample_size=[height, width])[:, :, 0]

    # generation
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)

    image = pipe(
        prompt=prompt,
        height=height,
        width=width,
        generator=generator,
        guidance_scale=guidance_scale,
        control_image=control_image_torch,
        num_inference_steps=num_inference_steps,
        control_context_scale=control_context_scale,
    ).images[0]

    return image, seed, control_image


def read_file(path: str) -> str:
    with open(path, 'r', encoding='utf-8') as f:
        content = f.read()
    return content


css = """
#col-container {
    margin: 0 auto;
    max-width: 960px;
}
"""

with open('static/data.json', 'r') as file:
    data = json.load(file)
examples = data['examples']

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        with gr.Column():
            gr.HTML(read_file("static/header.html"))

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    height=290,
                    sources=['upload', 'clipboard'],
                    image_mode='RGB',
                    # elem_id="image_upload",
                    type="pil",
                    label="Upload",
                )
                prompt = gr.Textbox(
                    label="Prompt",
                    show_label=False,
                    lines=2,
                    placeholder="Enter your prompt",
                    container=False,
                )
                control_mode = gr.Radio(
                    choices=["HED", "Canny", "Midas", "MLSD", "Pose"],
                    value="HED",
                    label="Control Mode",
                )
                run_button = gr.Button("Generate", variant="primary")

            with gr.Column():
                output_image = gr.Image(label="Generated image", show_label=False)
                polished_prompt = gr.Textbox(label="Polished prompt", interactive=False)
                with gr.Accordion("Preprocessor output", open=False):
                    control_image = gr.Image(label="Control image", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
            with gr.Row():
                image_scale = gr.Slider(
                    label="Image scale",
                    minimum=0.5,
                    maximum=2.0,
                    step=0.1,
                    value=1.0,
                )
                control_context_scale = gr.Slider(
                    label="Control context scale",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    value=0.75,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=2.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=30,
                    step=1,
                    value=8,
                )

        gr.Examples(examples=examples, inputs=[input_image, prompt])
        gr.HTML(read_file("static/footer.html"))

    # Polish the prompt first, then run the ControlNet-guided generation.
    run_button.click(
        fn=prepare,
        inputs=prompt,
        outputs=[polished_prompt],
        # outputs=gr.State(),  # Pass to the next function, not to UI at this step
    ).then(
        fn=inference,
        inputs=[
            polished_prompt,
            input_image,
            image_scale,
            control_mode,
            control_context_scale,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[output_image, seed, control_image],
    )

if __name__ == "__main__":
    demo.launch(mcp_server=True)