# ZIT-Controlnet / app.py
import gradio as gr
import numpy as np
import random
import json
import spaces
import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from videox_fun.pipeline import ZImageControlPipeline
from videox_fun.models import ZImageControlTransformer2DModel
from transformers import AutoTokenizer, Qwen3ForCausalLM
from image_utils import get_image_latent, scale_image
# from videox_fun.utils.utils import get_image_latent
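# scale_image and get_image_latent are local helpers from image_utils.py. From their use in
# inference() below: scale_image resizes the PIL input by `image_scale` and returns
# (image, width, height); get_image_latent converts that image into the control tensor passed
# to the pipeline. (Behavior inferred from usage here, not from the helper source.)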
# MODEL_REPO = "Tongyi-MAI/Z-Image-Turbo"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1280
# git clone https://huggingface.co/Tongyi-MAI/Z-Image-Turbo
MODEL_LOCAL = "models/Z-Image-Turbo/"
# curl -L -o Z-Image-Turbo-Fun-Controlnet-Union.safetensors https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union/resolve/main/Z-Image-Turbo-Fun-Controlnet-Union.safetensors
TRANSFORMER_LOCAL = "models/Z-Image-Turbo-Fun-Controlnet-Union.safetensors"
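# Both local paths above must exist before launch; the git clone / curl commands in the
# comments show how to fetch the base model and the ControlNet-Union checkpoint.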
weight_dtype = torch.bfloat16
# load transformer
transformer = ZImageControlTransformer2DModel.from_pretrained(
MODEL_LOCAL,
subfolder="transformer",
low_cpu_mem_usage=True,
torch_dtype=torch.bfloat16,
transformer_additional_kwargs={
"control_layers_places": [0, 5, 10, 15, 20, 25],
"control_in_dim": 16
},
).to(torch.bfloat16)
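# Overlay the Fun ControlNet-Union checkpoint on top of the base transformer weights;
# strict=False tolerates keys present on only one side (the counts are printed below).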
if TRANSFORMER_LOCAL is not None:
print(f"From checkpoint: {TRANSFORMER_LOCAL}")
if TRANSFORMER_LOCAL.endswith("safetensors"):
from safetensors.torch import load_file, safe_open
state_dict = load_file(TRANSFORMER_LOCAL)
else:
state_dict = torch.load(TRANSFORMER_LOCAL, map_location="cpu")
state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
m, u = transformer.load_state_dict(state_dict, strict=False)
print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
# load ZImageControlPipeline
vae = AutoencoderKL.from_pretrained(
MODEL_LOCAL,
subfolder="vae"
).to(weight_dtype)
tokenizer = AutoTokenizer.from_pretrained(
MODEL_LOCAL, subfolder="tokenizer"
)
text_encoder = Qwen3ForCausalLM.from_pretrained(
MODEL_LOCAL, subfolder="text_encoder", torch_dtype=weight_dtype,
low_cpu_mem_usage=True,
)
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3)
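# Flow-matching Euler scheduler used for the few-step (default 8) Turbo sampling in this demo.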
pipe = ZImageControlPipeline(
vae=vae,
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
scheduler=scheduler,
)
pipe.transformer = transformer
pipe.to("cuda")
# ======== AoTI compilation + FA3 ========
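# Tag the repeated transformer blocks so spaces.aoti_blocks_load can swap in ahead-of-time
# compiled kernels (FlashAttention-3 variant) published under zerogpu-aoti/Z-Image.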
pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
spaces.aoti_blocks_load(pipe.transformer.layers,
"zerogpu-aoti/Z-Image", variant="fa3")
@spaces.GPU
def inference(
prompt,
input_image,
image_scale=1.0,
    control_context_scale=0.75,
seed=42,
randomize_seed=True,
guidance_scale=1.5,
num_inference_steps=8,
progress=gr.Progress(track_tqdm=True),
):
# process image
    if input_image is None:
        print("Error: input_image is empty.")
        return None, seed
input_image, width, height = scale_image(input_image, image_scale)
control_image = get_image_latent(input_image, sample_size=[height, width])[:, :, 0]
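    # get_image_latent appears to return a frame-stacked tensor; [:, :, 0] keeps the single
    # control frame at the requested (height, width) resolution.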
# generation
if randomize_seed:
seed = random.randint(0, MAX_SEED)
generator = torch.Generator().manual_seed(seed)
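    # A dedicated Generator makes a given seed reproduce the same image.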
image = pipe(
prompt=prompt,
height=height,
width=width,
generator=generator,
guidance_scale=guidance_scale,
control_image=control_image,
num_inference_steps=num_inference_steps,
control_context_scale=control_context_scale,
).images[0]
return image, seed
def read_file(path: str) -> str:
with open(path, 'r', encoding='utf-8') as f:
content = f.read()
return content
css = """
#col-container {
margin: 0 auto;
max-width: 960px;
}
"""
with open('static/data.json', 'r') as file:
data = json.load(file)
examples = data['examples']
with gr.Blocks(css=css) as demo:
with gr.Column(elem_id="col-container"):
with gr.Column():
gr.HTML(read_file("static/header.html"))
with gr.Row(equal_height=True):
with gr.Column():
input_image = gr.Image(
height=290, sources=['upload', 'clipboard'],
image_mode='RGB',
# elem_id="image_upload",
type="pil", label="Upload")
prompt = gr.Textbox(
label="Prompt",
show_label=False,
lines=2,
placeholder="Enter your prompt",
container=False,
)
run_button = gr.Button("Run", variant="primary")
with gr.Column():
output_image = gr.Image(label="Result", show_label=False)
with gr.Accordion("Advanced Settings", open=False):
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=0,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Row():
image_scale = gr.Slider(
label="Image scale",
minimum=0.5,
maximum=2.0,
step=0.1,
value=1.0,
)
control_context_scale = gr.Slider(
label="Control context scale",
minimum=0.0,
maximum=1.0,
step=0.1,
value=0.75,
)
with gr.Row():
guidance_scale = gr.Slider(
label="Guidance scale",
minimum=0.0,
maximum=10.0,
step=0.1,
value=2.5,
)
num_inference_steps = gr.Slider(
label="Number of inference steps",
minimum=1,
maximum=30,
step=1,
value=8,
)
gr.Examples(examples=examples, inputs=[input_image, prompt])
gr.HTML(read_file("static/footer.html"))
gr.on(
triggers=[run_button.click, prompt.submit],
fn=inference,
inputs=[
prompt,
input_image,
image_scale,
control_context_scale,
seed,
randomize_seed,
guidance_scale,
num_inference_steps,
],
outputs=[output_image, seed],
)
if __name__ == "__main__":
demo.launch(mcp_server=True)
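# A minimal sketch of calling this app programmatically with gradio_client, assuming the
# default api_name ("/inference") derived from the function name and a locally running
# server; the argument order follows the inputs list wired up above.
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860")
#   image_path, used_seed = client.predict(
#       "a watercolor painting of a lighthouse",   # prompt
#       handle_file("control.png"),                # input_image
#       1.0,                                       # image_scale
#       0.75,                                      # control_context_scale
#       0,                                         # seed
#       True,                                      # randomize_seed
#       2.5,                                       # guidance_scale
#       8,                                         # num_inference_steps
#       api_name="/inference",
#   )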