QJerry committed
Commit 95e2d44 · verified · 1 Parent(s): eef332d

Create app.py

Files changed (1)
  1. app.py +418 -0

app.py ADDED
@@ -0,0 +1,418 @@
+ import gradio as gr
+ import torch
+ import os
+ import json
+ import random
+ import sys
+ import logging
+ import warnings
+ import re
+ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
+ from transformers import AutoModel, AutoTokenizer
+ from dataclasses import dataclass
+
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
+ from diffusers import ZImagePipeline
+ from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel
+ from pe import prompt_template
+
+
+ # ==================== Environment Variables ================================
+ MODEL_PATH = os.environ.get("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")
+ ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "true").lower() == "true"
+ ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "true").lower() == "true"
+ ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "_flash_3")
+ DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY")
+ # ===========================================================================
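+ # Example launch with overrides (values are illustrative; assumes the weights
+ # have already been downloaded to a local directory):
+ #   MODEL_PATH=/models/Z-Image-Turbo ENABLE_COMPILE=false python app.py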
+
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ warnings.filterwarnings("ignore")
+ logging.getLogger("transformers").setLevel(logging.ERROR)
+
+ RESOLUTION_SET = [
+     "1024x1024 ( 1:1 )",
+     "1152x896 ( 9:7 )",
+     "896x1152 ( 7:9 )",
+     "1152x864 ( 4:3 )",
+     "864x1152 ( 3:4 )",
+     "1248x832 ( 3:2 )",
+     "832x1248 ( 2:3 )",
+     "1280x720 ( 16:9 )",
+     "720x1280 ( 9:16 )",
+     "1344x576 ( 21:9 )",
+     "576x1344 ( 9:21 )",
+ ]
+
+ RES_CHOICES = {
+     "1024": [
+         "1024x1024 ( 1:1 )",
+         "1152x896 ( 9:7 )",
+         "896x1152 ( 7:9 )",
+         "1152x864 ( 4:3 )",
+         "864x1152 ( 3:4 )",
+         "1248x832 ( 3:2 )",
+         "832x1248 ( 2:3 )",
+         "1280x720 ( 16:9 )",
+         "720x1280 ( 9:16 )",
+         "1344x576 ( 21:9 )",
+         "576x1344 ( 9:21 )",
+     ],
+ }
+
+ def get_resolution(resolution):
+     # Resolution strings are "WIDTHxHEIGHT" (e.g. "1280x720 ( 16:9 )"),
+     # so this returns (width, height).
+     match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution)
+     if match:
+         return int(match.group(1)), int(match.group(2))
+     return 1024, 1024
+
+ def load_models(model_path, enable_compile=False, attention_backend="native"):
+     # Expects a local directory containing vae/, text_encoder/, tokenizer/
+     # and transformer/ subfolders.
+     print(f"Loading models from {model_path}...")
+     if not os.path.exists(model_path):
+         raise FileNotFoundError(f"Model directory not found: {model_path}")
+
+     vae = AutoencoderKL.from_pretrained(
+         os.path.join(model_path, "vae"),
+         torch_dtype=torch.bfloat16,
+         device_map="cuda",
+     )
+
+     text_encoder = AutoModel.from_pretrained(
+         os.path.join(model_path, "text_encoder"),
+         torch_dtype=torch.bfloat16,
+         device_map="cuda",
+     ).eval()
+
+     tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))
+     tokenizer.padding_side = "left"
+
+     if enable_compile:
+         print("Enabling torch.compile optimizations...")
+         torch._inductor.config.conv_1x1_as_mm = True
+         torch._inductor.config.coordinate_descent_tuning = True
+         torch._inductor.config.epilogue_fusion = False
+         torch._inductor.config.coordinate_descent_check_all_directions = True
+         torch._inductor.config.max_autotune_gemm = True
+         torch._inductor.config.max_autotune_gemm_backends = "TRITON,ATEN"
+         torch._inductor.config.triton.cudagraphs = False
+
+     pipe = ZImagePipeline(
+         scheduler=None,
+         vae=vae,
+         text_encoder=text_encoder,
+         tokenizer=tokenizer,
+         transformer=None,
+     )
+
+     if enable_compile:
+         pipe.vae.disable_tiling()
+
+     transformer = ZImageTransformer2DModel.from_pretrained(
+         os.path.join(model_path, "transformer")
+     ).to("cuda", torch.bfloat16)
+
+     pipe.transformer = transformer
+     pipe.transformer.set_attention_backend(attention_backend)
+
+     if enable_compile:
+         print("Compiling transformer...")
+         pipe.transformer = torch.compile(
+             pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False
+         )
+
+     pipe.to("cuda", torch.bfloat16)
+
+     return pipe
+
+ def generate_image(
+     pipe,
+     prompt,
+     resolution="1024x1024",
+     seed=-1,
+     guidance_scale=5.0,
+     num_inference_steps=50,
+     shift=3.0,
+     max_sequence_length=512,
+ ):
+     # get_resolution parses "WIDTHxHEIGHT", so unpack as (width, height).
+     width, height = get_resolution(resolution)
+
+     if seed == -1:
+         seed = torch.randint(0, 1000000, (1,)).item()
+     print(f"Using seed: {seed}")
+
+     generator = torch.Generator("cuda").manual_seed(seed)
+
+     # Rebuild the scheduler on every call so the requested shift takes effect.
+     scheduler = FlowMatchEulerDiscreteScheduler(
+         num_train_timesteps=1000,
+         shift=shift,
+     )
+     pipe.scheduler = scheduler
+
+     image = pipe(
+         prompt=prompt,
+         height=height,
+         width=width,
+         guidance_scale=guidance_scale,
+         num_inference_steps=num_inference_steps,
+         generator=generator,
+         max_sequence_length=max_sequence_length,
+     ).images[0]
+
+     return image
+
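+ # Warmup matters mainly for the torch.compile path: compilation is triggered
+ # lazily per input shape, so a few throwaway generations at each offered
+ # resolution move that one-time cost off the first real request.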
+ def warmup_model(pipe, resolutions):
+     print("Starting warmup phase...")
+
+     dummy_prompt = "warmup"
+
+     for res_str in resolutions:
+         print(f"Warming up for resolution: {res_str}")
+         try:
+             for i in range(3):
+                 generate_image(
+                     pipe,
+                     prompt=dummy_prompt,
+                     resolution=res_str,
+                     num_inference_steps=9,
+                     guidance_scale=0.0,
+                     seed=42 + i,
+                 )
+         except Exception as e:
+             print(f"Warmup failed for {res_str}: {e}")
+
+     print("Warmup completed.")
+
+ # ==================== Prompt Expander ====================
+ @dataclass
+ class PromptOutput:
+     status: bool
+     prompt: str
+     seed: int
+     system_prompt: str
+     message: str
+
+ class PromptExpander:
+     def __init__(self, backend="api", **kwargs):
+         self.backend = backend
+
+     def decide_system_prompt(self, template_name=None):
+         return prompt_template
+
+ class APIPromptExpander(PromptExpander):
+     def __init__(self, api_config=None, **kwargs):
+         super().__init__(backend="api", **kwargs)
+         self.api_config = api_config or {}
+         self.client = self._init_api_client()
+
+     def _init_api_client(self):
+         try:
+             from openai import OpenAI
+             api_key = self.api_config.get("api_key") or DASHSCOPE_API_KEY
+             base_url = self.api_config.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
+
+             if not api_key:
+                 print("Warning: DASHSCOPE_API_KEY not found.")
+                 return None
+
+             return OpenAI(api_key=api_key, base_url=base_url)
+         except ImportError:
+             print("Please install openai: pip install openai")
+             return None
+         except Exception as e:
+             print(f"Failed to initialize API client: {e}")
+             return None
+
+     def __call__(self, prompt, system_prompt=None, seed=-1, **kwargs):
+         return self.extend(prompt, system_prompt, seed, **kwargs)
+
+     def extend(self, prompt, system_prompt=None, seed=-1, **kwargs):
+         if self.client is None:
+             return PromptOutput(False, "", seed, system_prompt, "API client not initialized")
+
+         if system_prompt is None:
+             system_prompt = self.decide_system_prompt()
+
+         # If the template embeds the user prompt itself, fill it in and send
+         # a placeholder user message instead.
+         if "{prompt}" in system_prompt:
+             system_prompt = system_prompt.format(prompt=prompt)
+             prompt = " "
+
+         try:
+             model = self.api_config.get("model", "qwen3-max-preview")
+             response = self.client.chat.completions.create(
+                 model=model,
+                 messages=[
+                     {"role": "system", "content": system_prompt},
+                     {"role": "user", "content": prompt}
+                 ],
+                 temperature=0.7,
+                 top_p=0.8,
+             )
+
+             content = response.choices[0].message.content
+             json_start = content.find("```json")
+             if json_start != -1:
+                 json_end = content.find("```", json_start + 7)
+                 try:
+                     json_str = content[json_start + 7 : json_end].strip()
+                     data = json.loads(json_str)
+                     expanded_prompt = data.get("revised_prompt", content)
+                 except Exception:
+                     # Fall back to the raw reply if the JSON block is malformed.
+                     expanded_prompt = content
+             else:
+                 expanded_prompt = content
+
+             return PromptOutput(
+                 status=True,
+                 prompt=expanded_prompt,
+                 seed=seed,
+                 system_prompt=system_prompt,
+                 message=content
+             )
+         except Exception as e:
+             return PromptOutput(False, "", seed, system_prompt, str(e))
+
+ def create_prompt_expander(backend="api", **kwargs):
+     if backend == "api":
+         return APIPromptExpander(**kwargs)
+     raise ValueError("Only 'api' backend is supported.")
+
+ pipe = None
+ prompt_expander = None
+
+ def init_app():
+     global pipe, prompt_expander
+
+     try:
+         pipe = load_models(MODEL_PATH, enable_compile=ENABLE_COMPILE, attention_backend=ATTENTION_BACKEND)
+         print(f"Model loaded. Compile: {ENABLE_COMPILE}, Backend: {ATTENTION_BACKEND}")
+
+         if ENABLE_WARMUP:
+             all_resolutions = []
+             for cat in RES_CHOICES.values():
+                 all_resolutions.extend(cat)
+             warmup_model(pipe, all_resolutions)
+
+     except Exception as e:
+         print(f"Error loading model: {e}")
+         pipe = None
+
+     try:
+         prompt_expander = create_prompt_expander(backend="api", api_config={"model": "qwen3-max-preview"})
+         print("Prompt expander initialized.")
+     except Exception as e:
+         print(f"Error initializing prompt expander: {e}")
+         prompt_expander = None
+
+ def prompt_enhance(prompt, enable_enhance):
+     if not enable_enhance or not prompt_expander:
+         return prompt, "Enhancement disabled or not available."
+
+     if not prompt.strip():
+         return "", "Please enter a prompt."
+
+     try:
+         result = prompt_expander(prompt)
+         if result.status:
+             return result.prompt, result.message
+         else:
+             return prompt, f"Enhancement failed: {result.message}"
+     except Exception as e:
+         return prompt, f"Error: {str(e)}"
+
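+ # Note: generate() pins guidance_scale to 0.0; few-step "Turbo" checkpoints
+ # are typically guidance-distilled and run without classifier-free guidance.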
+ def generate(prompt, resolution, seed, steps, shift, enhance):
+     if pipe is None:
+         raise gr.Error("Model not loaded.")
+
+     final_prompt = prompt
+
+     if enhance:
+         final_prompt, _ = prompt_enhance(prompt, True)
+         print(f"Enhanced prompt: {final_prompt}")
+
+     if seed == -1:
+         seed = random.randint(0, 1000000)
+
+     try:
+         # Drop the aspect-ratio suffix, e.g. "1024x1024 ( 1:1 )" -> "1024x1024".
+         resolution_str = resolution.split(" ")[0]
+     except Exception:
+         resolution_str = "1024x1024"
+
+     image = generate_image(
+         pipe=pipe,
+         prompt=final_prompt,
+         resolution=resolution_str,
+         seed=seed,
+         guidance_scale=0.0,
+         num_inference_steps=steps,
+         shift=shift,
+     )
+
+     return image, final_prompt, str(seed)
+
+ # ==================== Gradio Interface ====================
+ init_app()
+
+ with gr.Blocks(title="Z-Image Demo") as demo:
+     gr.Markdown("# Z-Image Generation Demo")
+
+     with gr.Row():
+         with gr.Column():
+             prompt_input = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt here...")
+             with gr.Row():
+                 enable_enhance = gr.Checkbox(label="Enhance Prompt (DashScope)", value=True)
+                 enhance_btn = gr.Button("Enhance Only")
+
+             with gr.Row():
+                 choices = [int(k) for k in RES_CHOICES.keys()]
+                 res_cat = gr.Dropdown(value=1024, choices=choices, label="Resolution Category")
+
+                 initial_res_choices = RES_CHOICES["1024"]
+                 resolution = gr.Dropdown(
+                     value=initial_res_choices[0],
+                     choices=initial_res_choices,
+                     label="Resolution"
+                 )
+             seed = gr.Number(label="Seed", value=-1, precision=0, info="-1 for random")
+
+             with gr.Row():
+                 steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=9, step=1)
+                 shift = gr.Slider(label="Shift", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
+
+             generate_btn = gr.Button("Generate", variant="primary")
+
+         with gr.Column():
+             output_image = gr.Image(label="Generated Image", format="png")
+             final_prompt_output = gr.Textbox(label="Final Prompt Used", lines=3, interactive=False)
+             used_seed = gr.Textbox(label="Seed Used", interactive=False)
+
+     def update_res_choices(_res_cat):
+         if str(_res_cat) in RES_CHOICES:
+             res_choices = RES_CHOICES[str(_res_cat)]
+         else:
+             res_choices = RES_CHOICES["1024"]
+         return gr.update(value=res_choices[0], choices=res_choices)
+
+     res_cat.change(update_res_choices, inputs=res_cat, outputs=resolution)
+
+     enhance_btn.click(
+         prompt_enhance,
+         inputs=[prompt_input, enable_enhance],
+         outputs=[prompt_input, final_prompt_output]
+     )
+
+     generate_btn.click(
+         generate,
+         inputs=[
+             prompt_input,
+             resolution,
+             seed,
+             steps,
+             shift,
+             enable_enhance
+         ],
+         outputs=[output_image, final_prompt_output, used_seed]
+     )
+
+ if __name__ == "__main__":
+     demo.launch()