# ZIT-Controlnet / app.py
import gradio as gr
import numpy as np
import random, json, spaces, torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from videox_fun.pipeline import ZImageControlPipeline
from videox_fun.models import ZImageControlTransformer2DModel
from transformers import AutoTokenizer, Qwen3ForCausalLM
from utils.image_utils import get_image_latent, rescale_image
from utils.prompt_utils import polish_prompt
# from controlnet_aux import HEDdetector, MLSDdetector, OpenposeDetector, CannyDetector, MidasDetector
from controlnet_aux.processor import Processor
# MODEL_REPO = "Tongyi-MAI/Z-Image-Turbo"
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1280
# git clone https://huggingface.co/Tongyi-MAI/Z-Image-Turbo
MODEL_LOCAL = "models/Z-Image-Turbo/"
# curl -L -o Z-Image-Turbo-Fun-Controlnet-Union.safetensors https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union/resolve/main/Z-Image-Turbo-Fun-Controlnet-Union.safetensors
TRANSFORMER_LOCAL = "models/Z-Image-Turbo-Fun-Controlnet-Union.safetensors"
weight_dtype = torch.bfloat16
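# Loading strategy for the sections below: the base Z-Image-Turbo transformer is
# instantiated from the local checkout with the extra control-branch kwargs, then the
# Fun-Controlnet-Union safetensors checkpoint is loaded on top with strict=False
# (missing/unexpected key counts are printed so the merge can be inspected).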
# load transformer
transformer = ZImageControlTransformer2DModel.from_pretrained(
    MODEL_LOCAL,
    subfolder="transformer",
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
    transformer_additional_kwargs={
        "control_layers_places": [0, 5, 10, 15, 20, 25],
        "control_in_dim": 16,
    },
).to(torch.bfloat16)
if TRANSFORMER_LOCAL is not None:
    print(f"From checkpoint: {TRANSFORMER_LOCAL}")
    from safetensors.torch import load_file
    state_dict = load_file(TRANSFORMER_LOCAL)
    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
    m, u = transformer.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
# load ZImageControlPipeline
vae = AutoencoderKL.from_pretrained(
    MODEL_LOCAL,
    subfolder="vae",
).to(weight_dtype)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_LOCAL, subfolder="tokenizer"
)
text_encoder = Qwen3ForCausalLM.from_pretrained(
    MODEL_LOCAL, subfolder="text_encoder", torch_dtype=weight_dtype,
    low_cpu_mem_usage=False,
)
# scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3)
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    MODEL_LOCAL,
    subfolder="scheduler",
)
pipe = ZImageControlPipeline(
    vae=vae,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    transformer=transformer,
    scheduler=scheduler,
)
pipe.transformer = transformer
pipe.to("cuda")
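# The pipeline mirrors the usual diffusers layout: AutoencoderKL VAE, a Qwen3 causal LM as
# text encoder, the flow-matching Euler scheduler, and the control-enabled transformer above.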
# ======== AoTI compilation + FA3 ========
pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
spaces.aoti_blocks_load(pipe.transformer.layers,
                        "zerogpu-aoti/Z-Image", variant="fa3")
def prepare(prompt):
    polished_prompt = polish_prompt(prompt)
    return polished_prompt
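# prepare() is the first step of the click chain below: it passes the raw prompt through
# polish_prompt (utils.prompt_utils), which is assumed to rewrite/expand it, and the result
# both fills the "Polished prompt" box and feeds inference().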
@spaces.GPU
def inference(
    prompt,
    input_image,
    image_scale=1.0,
    control_mode='Canny',
    control_context_scale=0.75,
    seed=42,
    randomize_seed=True,
    guidance_scale=1.5,
    num_inference_steps=8,
    progress=gr.Progress(track_tqdm=True),
):
    # process the control image
    print("DEBUG: process image")
    if input_image is None:
        print("Error: input_image is empty.")
        return None, seed, None
    # map the UI control mode to a controlnet_aux processor id
    processor_id = 'canny'
    if control_mode == 'HED':
        processor_id = 'softedge_hed'
    elif control_mode == 'Midas':
        processor_id = 'depth_midas'
    elif control_mode == 'MLSD':
        processor_id = 'mlsd'
    elif control_mode == 'Pose':
        processor_id = 'openpose_full'
    print(f"DEBUG: processor_id={processor_id}")
    processor = Processor(processor_id)
    # width and height must be divisible by 16
    control_image, width, height = rescale_image(input_image, image_scale, 16)
    control_image = control_image.resize((1024, 1024))
    print("DEBUG: processor running")
    control_image = processor(control_image, to_pil=True)
    control_image = control_image.resize((width, height))
    print("DEBUG: control_image_torch")
    control_image_torch = get_image_latent(control_image, sample_size=[height, width])[:, :, 0]
    # generation
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
    image = pipe(
        prompt=prompt,
        height=height,
        width=width,
        generator=generator,
        guidance_scale=guidance_scale,
        control_image=control_image_torch,
        num_inference_steps=num_inference_steps,
        control_context_scale=control_context_scale,
    ).images[0]
    return image, seed, control_image
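# Minimal direct-call sketch (hypothetical, bypasses the Gradio UI; assumes a local test image):
# from PIL import Image
# img = Image.open("test.png").convert("RGB")
# out, used_seed, ctrl = inference("a cat on a sofa", img, control_mode="Canny", randomize_seed=False)
# out.save("result.png")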
def read_file(path: str) -> str:
    with open(path, 'r', encoding='utf-8') as f:
        content = f.read()
    return content
css = """
#col-container {
    margin: 0 auto;
    max-width: 960px;
}
"""
with open('static/data.json', 'r') as file:
    data = json.load(file)
examples = data['examples']
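# Based on the gr.Examples inputs below, static/data.json is expected to look like
# {"examples": [[<image path>, <prompt>], ...]}.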
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        with gr.Column():
            gr.HTML(read_file("static/header.html"))
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    height=290, sources=['upload', 'clipboard'],
                    image_mode='RGB',
                    # elem_id="image_upload",
                    type="pil", label="Upload")
                prompt = gr.Textbox(
                    label="Prompt",
                    show_label=False,
                    lines=2,
                    placeholder="Enter your prompt",
                    container=False,
                )
                control_mode = gr.Radio(
                    choices=["HED", "Canny", "Midas", "MLSD", "Pose"],
                    value="HED",
                    label="Control Mode"
                )
                run_button = gr.Button("Generate", variant="primary")
            with gr.Column():
                output_image = gr.Image(label="Generated image", show_label=False)
                polished_prompt = gr.Textbox(label="Polished prompt", interactive=False)
                with gr.Accordion("Preprocessor output", open=False):
                    control_image = gr.Image(label="Control image", show_label=False)
        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
            with gr.Row():
                image_scale = gr.Slider(
                    label="Image scale",
                    minimum=0.5,
                    maximum=2.0,
                    step=0.1,
                    value=1.0,
                )
                control_context_scale = gr.Slider(
                    label="Control context scale",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    value=0.75,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=2.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=30,
                    step=1,
                    value=8,
                )
        gr.Examples(examples=examples, inputs=[input_image, prompt])
        gr.HTML(read_file("static/footer.html"))
    run_button.click(
        fn=prepare,
        inputs=prompt,
        outputs=[polished_prompt],
        # outputs=gr.State(),  # Pass to the next function, not to the UI at this step
    ).then(
        fn=inference,
        inputs=[
            polished_prompt,
            input_image,
            image_scale,
            control_mode,
            control_context_scale,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[output_image, seed, control_image],
    )
if __name__ == "__main__":
    demo.launch(mcp_server=True)