rahul7star committed on
Commit 1b8645f · verified · 1 Parent(s): 1727d4f

Create app1.py

Files changed (1)
  1. app1.py +197 -0
app1.py ADDED
@@ -0,0 +1,197 @@
+ import torch
+ import spaces
+ import gradio as gr
+ from diffusers import DiffusionPipeline
+ import diffusers
+ import io
+ import logging
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ # ------------------------
+ # GLOBAL LOG BUFFER
+ # ------------------------
+ log_buffer = io.StringIO()
+
+ def log(msg):
+     print(msg)
+     log_buffer.write(msg + "\n")
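+ # log() mirrors every message to stdout and into the in-memory buffer
+ # that is returned to the UI's debug textbox after each generation.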
+
+ # Enable diffusers debug logs
+ diffusers.utils.logging.set_verbosity_info()
+
+ log("Loading Z-Image-Turbo pipeline...")
+
+ # ------------------------
+ # Load FP8 text encoder + tokenizer
+ # ------------------------
+ log("Loading FP8 Qwen3-4B tokenizer + text encoder...")
+ fp8_tokenizer = AutoTokenizer.from_pretrained(
+     "jiangchengchengNLP/qwen3-4b-fp8-scaled"
+ )
+ fp8_text_encoder = AutoModelForCausalLM.from_pretrained(
+     "jiangchengchengNLP/qwen3-4b-fp8-scaled",
+     device_map="auto",
+     torch_dtype=torch.bfloat16,  # could be torch.float8_e4m3fn if your PyTorch build supports it
+ )
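+ # NOTE (assumption): loading with torch_dtype=torch.bfloat16 upcasts the
+ # FP8-scaled checkpoint to bf16 at load time; the memory savings of true
+ # FP8 execution would need an FP8 dtype plus kernel support.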
+
+ # ------------------------
+ # Load Z-Image-Turbo
+ # ------------------------
+ pipe = DiffusionPipeline.from_pretrained(
+     "Tongyi-MAI/Z-Image-Turbo",
+     torch_dtype=torch.bfloat16,
+     low_cpu_mem_usage=False,
+     attn_implementation="kernels-community/vllm-flash-attn3",
+ )
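+ # The attn_implementation string names a Hub kernel repo (Hugging Face
+ # "kernels" integration). If that backend cannot be fetched on the runtime,
+ # dropping the argument should fall back to default attention (assumption).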
+
+ # Inject FP8 tokenizer + text encoder
+ pipe.tokenizer = fp8_tokenizer
+ pipe.text_encoder = fp8_text_encoder
+ pipe.to("cuda")
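+ # Swapping components like this assumes the FP8 Qwen3-4B exposes the same
+ # hidden-state shape and tokenizer behavior the pipeline's stock text
+ # encoder expects; a mismatch would surface as a shape error inside pipe().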
+
+ # ------------------------
+ # Pipeline debug info
+ # ------------------------
+ def pipeline_debug_info(pipe):
+     info = []
+     info.append("=== PIPELINE DEBUG INFO ===")
+
+     try:
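+         # diffusers model configs are dict-like (FrozenDict), so .get()
+         # returns None for keys this checkpoint does not define, rather
+         # than raising.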
+         tr = pipe.transformer.config
+         info.append(f"Transformer Class: {pipe.transformer.__class__.__name__}")
+         # Z-Image-Turbo keys
+         info.append(f"Hidden dim: {tr.get('hidden_dim')}")
+         info.append(f"Attention heads: {tr.get('num_heads')}")
+         info.append(f"Depth (layers): {tr.get('depth')}")
+         info.append(f"Patch size: {tr.get('patch_size')}")
+         info.append(f"MLP ratio: {tr.get('mlp_ratio')}")
+         info.append(f"Attention backend: {tr.get('attn_implementation')}")
+     except Exception as e:
+         info.append(f"Transformer diagnostics failed: {e}")
+
+     # VAE info
+     try:
+         vae = pipe.vae.config
+         info.append(f"VAE latent channels: {vae.latent_channels}")
+         info.append(f"VAE scaling factor: {vae.scaling_factor}")
+     except Exception as e:
+         info.append(f"VAE diagnostics failed: {e}")
+
+     return "\n".join(info)
+
+
+ def latent_shape_info(h, w, pipe):
+     try:
+         c = pipe.vae.config.latent_channels
+         # Spatial compression comes from the VAE's conv stack; config.scaling_factor
+         # only rescales latent *values* and must not be used for spatial dims.
+         # Assumes an AutoencoderKL-style config with block_out_channels.
+         down = 2 ** (len(pipe.vae.config.block_out_channels) - 1)
+         h_lat = int(h) // down
+         w_lat = int(w) // down
+         return f"Latent shape → ({c}, {h_lat}, {w_lat})"
+     except Exception as e:
+         return f"Latent shape calc failed: {e}"
+
+
+ # ------------------------
+ # IMAGE GENERATOR
+ # ------------------------
+ @spaces.GPU
+ def generate_image(prompt, height, width, num_inference_steps, seed, randomize_seed, num_images):
+     log_buffer.truncate(0)
+     log_buffer.seek(0)
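+     # The buffer is module-level shared state, reset at the start of each
+     # request; concurrent requests would interleave logs. Acceptable for a
+     # single ZeroGPU worker (an assumption about how this Space is deployed).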
+
+     log("=== NEW GENERATION REQUEST ===")
+     log(f"Prompt: {prompt}")
+     log(f"Height: {height}, Width: {width}")
+     log(f"Inference Steps: {num_inference_steps}")
+     log(f"Num Images: {num_images}")
+
+     if randomize_seed:
+         seed = torch.randint(0, 2**32 - 1, (1,)).item()
+         log(f"Randomized Seed → {seed}")
+     else:
+         log(f"Seed: {seed}")
+
+     num_images = min(max(1, int(num_images)), 3)
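+     # Clamp to the 1-3 range the slider advertises, in case the function is
+     # invoked directly (e.g., via the API) with an out-of-range value.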
+
+     # Debug pipeline info
+     log(pipeline_debug_info(pipe))
+
+     generator = torch.Generator("cuda").manual_seed(int(seed))
+
+     log("Running pipeline forward()...")
+     result = pipe(
+         prompt=prompt,
+         height=int(height),
+         width=int(width),
+         num_inference_steps=int(num_inference_steps),
+         guidance_scale=0.0,
+         generator=generator,
+         max_sequence_length=1024,
+         num_images_per_prompt=num_images,
+         output_type="pil",
+     )
+
+     # Tensor diagnostics (shapes only)
+     try:
+         log(f"VAE latent channels: {pipe.vae.config.latent_channels}")
+         log(f"VAE scaling factor: {pipe.vae.config.scaling_factor}")
+         log(latent_shape_info(height, width, pipe))
+     except Exception as e:
+         log(f"Latent diagnostics error: {e}")
+
+     log("Pipeline finished.")
+     log("Returning images...")
+
+     return result.images, seed, log_buffer.getvalue()
+
+
+ # ------------------------
+ # GRADIO UI
+ # ------------------------
+ examples = [
+     ["Young Chinese woman in red Hanfu, intricate embroidery..."],
+     ["A majestic dragon soaring through clouds at sunset..."],
+     ["Cozy coffee shop interior, warm lighting, rain on windows..."],
+     ["Astronaut riding a horse on Mars, cinematic lighting..."],
+     ["Portrait of a wise old wizard..."],
+ ]
+
+ with gr.Blocks(title="Z-Image-Turbo Multi Image Demo") as demo:
+     gr.Markdown("# 🎨 Z-Image-Turbo — Multi Image")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             prompt = gr.Textbox(label="Prompt", lines=4)
+
+             with gr.Row():
+                 height = gr.Slider(512, 2048, 1024, step=64, label="Height")
+                 width = gr.Slider(512, 2048, 1024, step=64, label="Width")
+
+             num_images = gr.Slider(1, 3, 2, step=1, label="Number of Images")
+
+             num_inference_steps = gr.Slider(
+                 1, 20, 9, step=1, label="Inference Steps",
+                 info="9 steps = 8 DiT forward passes",
+             )
+
+             with gr.Row():
+                 seed = gr.Number(label="Seed", value=42, precision=0)
+                 randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
+
+             generate_btn = gr.Button("🚀 Generate", variant="primary")
+
+         with gr.Column(scale=1):
+             output_images = gr.Gallery(label="Generated Images")
+             used_seed = gr.Number(label="Seed Used", interactive=False)
+             debug_log = gr.Textbox(label="Debug Log Output", lines=25, interactive=False)
+
+     gr.Examples(examples=examples, inputs=[prompt], cache_examples=False)
+
+     generate_btn.click(
+         fn=generate_image,
+         inputs=[prompt, height, width, num_inference_steps, seed, randomize_seed, num_images],
+         outputs=[output_images, used_seed, debug_log],
+     )
+
+ if __name__ == "__main__":
+     demo.launch()