Spaces: Running on Zero
| import spaces | |
| from dataclasses import dataclass | |
| import json | |
| import logging | |
| import os | |
| import random | |
| import re | |
| import sys | |
| import warnings | |
| from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoModel, AutoTokenizer | |
| sys.path.append(os.path.dirname(os.path.abspath(__file__))) | |
| from diffusers import ZImagePipeline | |
| from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel | |
| from pe import prompt_template | |
# ==================== Environment Variables ==================================
# Runtime configuration; every knob can be overridden via the process env.
MODEL_PATH = os.getenv("MODEL_PATH", "Tongyi-MAI/Z-Image-Turbo")  # HF repo id or local checkpoint dir
ENABLE_COMPILE = os.getenv("ENABLE_COMPILE", "true").lower() == "true"  # torch.compile the transformer
ENABLE_WARMUP = os.getenv("ENABLE_WARMUP", "true").lower() == "true"  # pre-run every resolution at startup
ATTENTION_BACKEND = os.getenv("ATTENTION_BACKEND", "flash_3")  # attention kernel selection
DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY")  # optional: enables the prompt expander
HF_TOKEN = os.getenv("HF_TOKEN")  # optional: auth for gated Hub repos
# =============================================================================
# Silence tokenizer fork warnings and noisy library logging for a clean demo log.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)
# Resolution presets keyed by base size. Each entry is "WIDTHxHEIGHT ( ratio )";
# the parenthesized ratio is display-only and stripped before parsing.
RES_CHOICES = {
    "1024": [
        "1024x1024 ( 1:1 )",
        "1152x896 ( 9:7 )",
        "896x1152 ( 7:9 )",
        "1152x864 ( 4:3 )",
        "864x1152 ( 3:4 )",
        "1248x832 ( 3:2 )",
        "832x1248 ( 2:3 )",
        "1280x720 ( 16:9 )",
        "720x1280 ( 9:16 )",
        "1344x576 ( 21:9 )",
        "576x1344 ( 9:21 )",
    ],
    "1280": [
        "1280x1280 ( 1:1 )",
        "1440x1120 ( 9:7 )",
        "1120x1440 ( 7:9 )",
        "1472x1104 ( 4:3 )",
        "1104x1472 ( 3:4 )",
        "1536x1024 ( 3:2 )",
        "1024x1536 ( 2:3 )",
        "1600x896 ( 16:9 )",
        "896x1600 ( 9:16 )",  # 896 rather than 900: dimensions must divide by 16
        "1680x720 ( 21:9 )",
        "720x1680 ( 9:21 )",
    ],
}
# Flat list of every preset across categories, used for the resolution dropdown.
RESOLUTION_SET = [res for group in RES_CHOICES.values() for res in group]
# Example prompts (Korean) surfaced in the gr.Examples widget below.
# Entries range from short scene descriptions to long, poster-style layout
# specs that embed exact English text the model should render verbatim.
# Each item is a single-element list because gr.Examples maps one input.
EXAMPLE_PROMPTS = [
    ["ํ ๋จ์ฑ๊ณผ ๊ทธ์ ํธ๋ค์ด ์ด์ธ๋ฆฌ๋ ์์์ ์ ๊ณ ์ค๋ด ์กฐ๋ช ์๋ ๊ด๊ฐ๋ค์ด ์๋ ๋ฐฐ๊ฒฝ์์ ๊ฐ ์ผ์ ์ฐธ๊ฐํ๊ณ ์๋ ๋ชจ์ต."],
    [
        "๋ถ์๊ธฐ ์๋ ์ด๋์ด ํค์ ์ธ๋ฌผ ์ฌ์ง, ์ฐ์ํ ์ค๊ตญ ์ฌ์ฑ์ด ์ด๋์ด ๋ฐฉ์ ์๋ค. ๊ฐํ ๋น์ด ์ ํฐ๋ฅผ ํต๊ณผํด ๊ทธ๋ ์ ์ผ๊ต์ ๋ฒ๊ฐ ๋ชจ์์ ์ ๋ช ํ ๋น๊ณผ ๊ทธ๋ฆผ์๋ฅผ ํฌ์ฌํ๋ฉฐ ํ์ชฝ ๋๋ง์ ์ ํํ ๋น์ถ๋ค. ๋์ ๋๋น, ๋ช ์ ๊ฒฝ๊ณ๊ฐ ์ ๋ช ํ๋ฉฐ, ์ ๋น๋ก์ด ๋๋, ๋ผ์ด์นด ์นด๋ฉ๋ผ ์์กฐ."
    ],
    [
        "๋ฐ๊ฒ ์กฐ๋ช ๋ ์๋ฆฌ๋ฒ ์ดํฐ ์์์ ๊ธด ๊ฒ์ ๋จธ๋ฆฌ๋ฅผ ํ ์ ์ ๋์์์ ์ฌ์ฑ์ด ๊ฑฐ์ธ์ ํฅํด ์ ์นด๋ฅผ ์ฐ๋ ์ค๊ฐ ๊ฑฐ๋ฆฌ ์ค๋งํธํฐ ์ ์นด ์ฌ์ง. ๊ทธ๋ ๋ ํฐ์ ๊ฝ๋ฌด๋ฌ๊ฐ ์๋ ๊ฒ์์ ์คํ์๋ ํฌ๋กญํ๊ณผ ์ด๋์ด ์ฒญ๋ฐ์ง๋ฅผ ์ ๊ณ ์๋ค. ๋จธ๋ฆฌ๋ฅผ ์ฝ๊ฐ ๊ธฐ์ธ์ด๊ณ ์ ์ ์ ๋พฐ์กฑํ๊ฒ ๋ด๋ฐ์ด ํค์คํ๋ ๋ฏํ ํฌ์ฆ๋ก ๋งค์ฐ ๊ท์ฝ๊ณ ์ฅ๋์ค๋ฌ์ด ๋ชจ์ต์ด๋ค. ์ค๋ฅธ์์ ์ง์ ํ์ ์ค๋งํธํฐ์ ๋ค๊ณ ์ผ๊ตด ์ผ๋ถ๋ฅผ ๊ฐ๋ฆฌ๊ณ ์์ผ๋ฉฐ, ํ๋ฉด ์นด๋ฉ๋ผ ๋ ์ฆ๊ฐ ๊ฑฐ์ธ์ ํฅํ๊ณ ์๋ค."
    ],
    [
        "๋นจ๊ฐ ํํธ๋ฅผ ์ ์ ์ ์ ์ค๊ตญ ์ฌ์ฑ, ์ ๊ตํ ์์. ์๋ฒฝํ ๋ฉ์ดํฌ์ , ๋ถ์ ๊ฝ๋ฌด๋ฌ ์ด๋ง ์ฅ์. ์ ๊ตํ ๋์ ์ชฝ์ง ๋จธ๋ฆฌ, ๊ธ๋น ๋ดํฉ ๋จธ๋ฆฌ ์ฅ์, ๋ถ์ ๊ฝ, ๊ตฌ์ฌ. ์ฌ์ธ๊ณผ ๋๋ฌด, ์๊ฐ ๊ทธ๋ ค์ง ๋ฅ๊ทผ ์ ์ด์ ๋ถ์ฑ๋ฅผ ๋ค๊ณ ์๋ค. ๋ค์จ ๋ฒ๊ฐ ๋ชจ์ ๋จํ (โก๏ธ), ๋ฐ์ ๋ ธ๋์ ๋น, ํผ์น ์ผ์ชฝ ์๋ฐ๋ฅ ์์. ๋ถ๋๋ฝ๊ฒ ์กฐ๋ช ๋ ์ผ์ธ ๋ฐค ๋ฐฐ๊ฒฝ, ์ค๋ฃจ์ฃ์ ๋ค์ธต ํ(์์ ๋์ํ), ํ๋ฆฟํ ์ปฌ๋ฌ ๋จผ ๋ถ๋น๋ค."
    ],
    [
        '''๊ณ ์ํ๊ณ ์ฅ์ํ ์ค๊ตญ ํ๊ฒฝ์ ๋ฌ์ฌํ ์ธ๋ก ํ์์ ๋์งํธ ์ผ๋ฌ์คํธ๋ ์ด์ ์ผ๋ก, ์ ํต์ ์ธ ์ฐ์ํ ์คํ์ผ์ ํ๋์ ์ด๊ณ ๊น๋ํ ๋ฏธํ์ผ๋ก ์ฌํด์ํ๋ค. ์ฅ๋ฉด์ ์ค์ ๊ณ๊ณก์ ๋๋ฌ์ผ ๋ค์ํ ํ๋์๊ณผ ์ฒญ๋ก์ ์์์ ์ฐ๋ ์์ ๊ฐํ๋ฅธ ์ ๋ฒฝ์ด ์ง๋ฐฐํ๋ค. ๋ฉ๋ฆฌ ์ฐ๋ค์ด ์ธต์ธต์ด ์ฐํ ํ๋์๊ณผ ํฐ์ ์๊ฐ ์์ผ๋ก ์ฌ๋ผ์ง๋ฉฐ ๊ฐํ ๋๊ธฐ ์๊ทผ๊ฐ๊ณผ ๊น์ด๋ฅผ ๋ง๋ค์ด๋ธ๋ค. ๊ณ ์ํ ์ฒญ๋ก์ ๊ฐ์ด ๊ตฌ์ฑ์ ์ค์์ ๊ฐ๋ก์ง๋ฌ ํ๋ฅด๋ฉฐ, ์์ ์ ํต ์ค๊ตญ ๋ฐฐ, ์๋ง๋ ์ผํ์ด ๋ฌผ ์๋ฅผ ํญํดํ๊ณ ์๋ค. ๋ฐฐ๋ ๋ฐ์ ๋ ธ๋์ ์ฒ๋ง๊ณผ ๋ถ์ ์ ์ฒด๋ฅผ ๊ฐ์ง๊ณ ์์ผ๋ฉฐ ๋ค์ ๋ถ๋๋ฌ์ด ๋ฌผ๊ฒฐ์ ๋จ๊ธด๋ค. ์ฌ๋ฌ ๋ช ์ ํฌ๋ฏธํ ์ธ๋ฌผ๋ค์ ํ์ฐ๊ณ ์๋ค. ๋ น์ ๋๋ฌด์ ์ผ๋ถ ๋งจ๊ฐ์ง ๋๋ฌด๋ฅผ ํฌํจํ ๋๋ฌธ๋๋ฌธํ ์์์ด ๋ฐ์ ์ ๋ฐ๊ณผ ๋ด์ฐ๋ฆฌ์ ๋ถ์ด ์๋ค. ์ ์ฒด ์กฐ๋ช ์ ๋ถ๋๋ฝ๊ณ ํ์ฐ๋์ด ์ ์ฒด ์ฅ๋ฉด์ ํ์จํ ๋น์ ๋๋ฆฌ์ด๋ค. ์ด๋ฏธ์ง ์ค์์ ํ ์คํธ๊ฐ ๊ฒน์ณ์ ธ ์๋ค. ํ ์คํธ ๋ธ๋ก ์๋จ์๋ ์์ํ๋ ๋ฌธ์๊ฐ ํฌํจ๋ ์๊ณ ๋นจ๊ฐ์์ ์ํ ๋์ฅ ๊ฐ์ ๋ก๊ณ ๊ฐ ์๋ค. ๊ทธ ์๋ ์์ ๊ฒ์์ ์ฐ์ธ๋ฆฌํ ๊ธ๊ผด๋ก 'Zao-Xiang * East Beauty & West Fashion * Z-Image'๋ผ๋ ๋จ์ด๊ฐ ์๋ค. ๊ทธ ๋ฐ๋ก ์๋ ๋ ํฐ ์ฐ์ํ ๊ฒ์์ ์ธ๋ฆฌํ ๊ธ๊ผด๋ก 'SHOW & SHARE CREATIVITY WITH THE WORLD'๋ผ๋ ๋จ์ด๊ฐ ์๋ค. ๊ทธ ์ค์๋ "SHOW & SHARE", "CREATIVITY", "WITH THE WORLD"๊ฐ ์๋ค.'''
    ],
    [
        """๊ฐ์์ ์ํ ใํ์์ ๋งใ(The Taste of Memory)์ ์ํ ํฌ์คํฐ. ์ฅ๋ฉด์ ์๋ฐํ 19์ธ๊ธฐ ์คํ์ผ ์ฃผ๋ฐฉ์ ์ค์ ๋์ด ์๋ค. ํ๋ฉด ์ค์์ ์ ๊ฐ์ ๋จธ๋ฆฌ์ ์์ ์ฝง์์ผ์ ๊ฐ์ง ์ค๋ ๋จ์ฑ(๋ฐฐ์ฐ ์์ ํํ ๋ฆฌ๊ฑด ์ฐ๊ธฐ)์ด ๋๋ฌด ํ ์ด๋ธ ๋ค์ ์ ์์ผ๋ฉฐ, ํฐ์ ์ ์ธ , ๊ฒ์์ ์กฐ๋ผ, ๋ฒ ์ด์ง์ ์์น๋ง๋ฅผ ์ ๊ณ ์๊ณ ํ ์ฌ์ฑ์ ๋ฐ๋ผ๋ณด๋ฉฐ ์์ ํฐ ๋ฉ์ด๋ฆฌ์ ์๊ณ ๊ธฐ๋ฅผ ๋ค๊ณ ์์ผ๋ฉฐ ์๋์๋ ๋๋ฌด ๋๋ง๊ฐ ์๋ค. ๊ทธ์ ์ค๋ฅธ์ชฝ์๋ ๋์ ์ชฝ์ง ๋จธ๋ฆฌ๋ฅผ ํ ๊ฒ์ ๋จธ๋ฆฌ ์ฌ์ฑ(๋ฐฐ์ฐ ์๋ฆฌ๋ ๋ฐด์ค ์ฐ๊ธฐ)์ด ํ ์ด๋ธ์ ๊ธฐ๋์ด ๊ทธ์๊ฒ ๋ถ๋๋ฝ๊ฒ ๋ฏธ์์ง๊ณ ์๋ค. ๊ทธ๋ ๋ ์ฐํ ์ ์ ์ธ ์ ์๋จ์ ํฐ์, ํ๋จ์ ํ๋์์ธ ๊ธด ์น๋ง๋ฅผ ์ ๊ณ ์๋ค. ํ ์ด๋ธ ์์๋ ๋ค์ง ํ์ ์๋ฐฐ์ถ ์ฑ๊ฐ ์๋ ๋๋ง ์ธ์๋ ํฐ์ ๋์๊ธฐ ์ ์, ์ ์ ํ ํ๋ธ๊ฐ ์๊ณ , ์ผ์ชฝ ๋๋ฌด ์์ ์์๋ ์ง์ ์ ํฌ๋ ํ ์ก์ด๊ฐ ๋์ฌ ์๋ค. ๋ฐฐ๊ฒฝ์ ๊ฑฐ์น ๊ฒ ํ๋ฐฑ์์ผ๋ก ๋ฏธ์ฅ๋ ๋ฒฝ์ด๋ฉฐ ํ๊ฒฝํ ํ ์ ์ด ๊ฑธ๋ ค ์๋ค. ๊ฐ์ฅ ์ค๋ฅธ์ชฝ ์์ ๋ ์์๋ ๋ณต๊ณ ํ ์ค์ผ ๋จํ๊ฐ ๋์ฌ ์๋ค. ํฌ์คํฐ์๋ ๋ง์ ํ ์คํธ ์ ๋ณด๊ฐ ์๋ค. ์ผ์ชฝ ์๋จ์๋ ํฐ์ ์ฐ์ธ๋ฆฌํ ๊ธ๊ผด๋ก "ARTISAN FILMS PRESENTS"๊ฐ ์๊ณ ๊ทธ ์๋์ "ELEANOR VANCE"์ "ACADEMY AWARDยฎ WINNER"๊ฐ ์๋ค. ์ค๋ฅธ์ชฝ ์๋จ์๋ "ARTHUR PENHALIGON"๊ณผ "GOLDEN GLOBEยฎ AWARD WINNER"๊ฐ ์ฐ์ฌ ์๋ค. ์๋จ ์ค์์๋ ์ ๋์ค ์ํ์ ์๊ณ๊ด ๋ก๊ณ ๊ฐ ์๊ณ ์๋์ "SUNDANCE FILM FESTIVAL GRAND JURY PRIZE 2024"๊ฐ ์ฐ์ฌ ์๋ค. ์ฃผ์ ์ ๋ชฉ "THE TASTE OF MEMORY"๋ ํฐ์์ ํฐ ์ธ๋ฆฌํ ๊ธ๊ผด๋ก ํ๋จ์ ๋์ ๋๊ฒ ํ์๋์ด ์๋ค. ์ ๋ชฉ ์๋์๋ "A FILM BY Tongyi Interaction Lab"์ด ๋ช ์๋์ด ์๋ค. ํ๋จ ์์ญ์๋ ํฐ์ ์์ ๊ธ์จ๋ก "SCREENPLAY BY ANNA REID", "CULINARY DIRECTION BY JAMES CARTER" ๋ฐ Artisan Films, Riverstone Pictures, Heritage Media ๋ฑ ์๋ง์ ์ ์์ฌ ๋ก๊ณ ๋ฅผ ํฌํจํ ์ ์ฒด ์ถ์ฐ์ง ๋ฐ ์ ์์ง ๋ช ๋จ์ด ๋์ด๋์ด ์๋ค. ์ ์ฒด์ ์ธ ์คํ์ผ์ ์ฌ์ค์ฃผ์๋ก ๋ฐ๋ปํ๊ณ ๋ถ๋๋ฌ์ด ์กฐ๋ช ๋ฐฉ์์ ์ฑํํ์ฌ ์น๋ฐํ ๋ถ์๊ธฐ๋ฅผ ์กฐ์ฑํ๋ค. ์์กฐ๋ ๊ฐ์, ๋ฒ ์ด์ง, ๋ถ๋๋ฌ์ด ๋ น์ ๋ฑ ๋์ง์ ํค์ด ์ฃผ๋ฅผ ์ด๋ฃฌ๋ค. ๋ ๋ฐฐ์ฐ์ ๋ชธ์ ๋ชจ๋ ํ๋ฆฌ์์ ์๋ ค ์๋ค."""
    ],
    [
        """์ ์ฌ๊ฐํ ๊ตฌ๋์ ํด๋ก์ฆ์ ์ฌ์ง์ผ๋ก, ๊ฑฐ๋ํ๊ณ ์ ๋ช ํ ๋ น์ ์๋ฌผ ์์ด ์ฃผ์ ์ด๋ฉฐ ํ ์คํธ๊ฐ ๊ฒน์ณ์ ธ ํฌ์คํฐ๋ ์ก์ง ํ์ง ๊ฐ์ ์ธ๊ด์ ๊ฐ์ถ๊ณ ์๋ค. ์ฃผ์ ํผ์ฌ์ฒด๋ ์ผ์ชฝ ํ๋จ์์ ์ค๋ฅธ์ชฝ ์๋จ์ผ๋ก ๋๊ฐ์ ์ผ๋ก ๊ตฌ๋ถ๋ฌ์ ธ ํ๋ ์์ ๊ฐ๋ก์ง๋ฅด๋ ๋๊ป๊ณ ์์ค ๊ฐ์ ์ง๊ฐ์ ์์ด๋ค. ํ๋ฉด์ด ๋งค์ฐ ๋ฐ์ฌ์ ์ด์ด์ ๋ฐ์ ์ง์ฌ๊ด์์ ํฌ์ฐฉํ์ฌ ๋๋๋ฌ์ง ํ์ด๋ผ์ดํธ๋ฅผ ํ์ฑํ๊ณ ๋ฐ์ ๋ฉด ์๋ ํํํ ๋ฏธ์ธ ์๋งฅ์ด ๋๋ฌ๋๋ค. ๋ฐฐ๊ฒฝ์ ๋ค๋ฅธ ์ง์ ๋ น์ ์๋ค๋ก ๊ตฌ์ฑ๋์ด ์์ผ๋ฉฐ ์ฝ๊ฐ ์ด์ ์ด ํ๋ ค์ ธ ์์ ํผ์ฌ๊ณ ์ฌ๋ ํจ๊ณผ๋ฅผ ๋ง๋ค์ด ์ ๊ฒฝ์ ์ฃผ์ ์์ ๊ฐ์กฐํ๋ค. ์ ์ฒด์ ์ธ ์คํ์ผ์ ์ฌ์ค์ ์ธ ์ฌ์ง์ผ๋ก ๋ฐ์ ์๊ณผ ์ด๋์ด ๊ทธ๋ฆผ์ ๋ฐฐ๊ฒฝ ์ฌ์ด์ ๋์ ๋๋น๋ฅผ ํ์ฑํ๋ค. ์ด๋ฏธ์ง์๋ ์ฌ๋ฌ ๋ ๋๋ง๋ ํ ์คํธ๊ฐ ์๋ค. ์ผ์ชฝ ์๋จ์๋ ํฐ์ ์ธ๋ฆฌํ ๊ธ๊ผด๋ก "PIXEL-PEEPERS GUILD Presents"๋ผ๋ ํ ์คํธ๊ฐ ์๋ค. ์ค๋ฅธ์ชฝ ์๋จ์๋ ํฐ์ ์ธ๋ฆฌํ ๊ธ๊ผด๋ก "[Instant Noodle] ๆณก้ข่ฐๆๅ "๋ผ๋ ํ ์คํธ๊ฐ ์๋ค. ์ผ์ชฝ์๋ ์์ง์ผ๋ก ๋ฐฐ์ด๋ ์ ๋ชฉ "Render Distance: Max"๊ฐ ํฐ์ ์ธ๋ฆฌํ ๊ธ๊ผด๋ก ๋์ด ์๋ค. ์ผ์ชฝ ํ๋จ์๋ ๋ค์ฏ ๊ฐ์ ๊ฑฐ๋ํ ํฐ์ ์ก์ฒด ํ์ "ๆพๅกๅจ...็็ง"๊ฐ ์๋ค. ์ค๋ฅธ์ชฝ ํ๋จ์๋ ์์ ํฐ์ ์ธ๋ฆฌํ ๊ธ๊ผด๋ก "Leica Glowโข Unobtanium X-1"์ด ์๊ณ , ๊ทธ ๋ฐ๋ก ์์๋ ํฐ์ ์ก์ฒด๋ก ์ฐ์ธ ์ด๋ฆ "่กๅ "๊ฐ ์๋ค. ์๋ณ๋ ํต์ฌ ๊ฐ์ฒด์๋ ๋ธ๋๋ ํฝ์ ํผํผ์ค ๊ธธ๋, ์ ํ ๋ผ์ธ ์ธ์คํดํธ ๋๋ค ์กฐ๋ฏธ๋ฃ ํจํค์ง, ์นด๋ฉ๋ผ ๋ชจ๋ธ Unobtaniumโข X-1 ๋ฐ ์ฌ์ง๊ฐ ์ด๋ฆ Zao-Xiang์ด ํฌํจ๋๋ค."""
    ],
]
def get_resolution(resolution):
    """Parse a resolution label such as "1152x896 ( 9:7 )" into ints.

    Returns:
        tuple[int, int]: (width, height); falls back to (1024, 1024) when no
        "<digits> x <digits>" pattern can be found in the string.
    """
    # The character class also accepts a mojibake variant of the multiplication
    # sign that appears in some labels, alongside the plain ASCII "x".
    found = re.search(r"(\d+)\s*[รx]\s*(\d+)", resolution)
    if found is None:
        return 1024, 1024
    return int(found.group(1)), int(found.group(2))
def load_models(model_path, enable_compile=False, attention_backend="native"):
    """Build the Z-Image pipeline (VAE, text encoder, tokenizer, transformer) on CUDA.

    Args:
        model_path: Either a local checkpoint directory or, when the path does
            not exist on disk, a Hugging Face Hub repo id.
        enable_compile: If True, apply torch.compile + inductor tuning flags
            to the transformer.
        attention_backend: Attention kernel name forwarded to the transformer.

    Returns:
        The assembled ZImagePipeline, moved to CUDA in bfloat16. The scheduler
        slot is left as None; generate_image() installs one per call.
    """
    print(f"Loading models from {model_path}...")
    # NOTE(review): `use_auth_token` is deprecated in recent huggingface_hub /
    # transformers in favor of `token` — confirm against pinned versions.
    use_auth_token = HF_TOKEN if HF_TOKEN else True
    # Non-existent local path => treat model_path as a Hub repo id.
    if not os.path.exists(model_path):
        vae = AutoencoderKL.from_pretrained(
            f"{model_path}",
            subfolder="vae",
            torch_dtype=torch.bfloat16,
            device_map="cuda",
            use_auth_token=use_auth_token,
        )
        text_encoder = AutoModel.from_pretrained(
            f"{model_path}",
            subfolder="text_encoder",
            torch_dtype=torch.bfloat16,
            device_map="cuda",
            use_auth_token=use_auth_token,
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(f"{model_path}", subfolder="tokenizer", use_auth_token=use_auth_token)
    else:
        # Local checkpoint layout: <model_path>/{vae,text_encoder,tokenizer,transformer}
        vae = AutoencoderKL.from_pretrained(
            os.path.join(model_path, "vae"), torch_dtype=torch.bfloat16, device_map="cuda"
        )
        text_encoder = AutoModel.from_pretrained(
            os.path.join(model_path, "text_encoder"),
            torch_dtype=torch.bfloat16,
            device_map="cuda",
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))
    # Left padding so prompt tokens align at the sequence end for batching.
    tokenizer.padding_side = "left"
    if enable_compile:
        # Inductor tuning knobs; cudagraphs is off because input shapes vary
        # across the supported resolutions.
        print("Enabling torch.compile optimizations...")
        torch._inductor.config.conv_1x1_as_mm = True
        torch._inductor.config.coordinate_descent_tuning = True
        torch._inductor.config.epilogue_fusion = False
        torch._inductor.config.coordinate_descent_check_all_directions = True
        torch._inductor.config.max_autotune_gemm = True
        torch._inductor.config.max_autotune_gemm_backends = "TRITON,ATEN"
        torch._inductor.config.triton.cudagraphs = False
    # Assemble the pipeline first; the transformer is attached afterwards.
    pipe = ZImagePipeline(scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None)
    if enable_compile:
        # Tiling would break shape assumptions made by the compiled graphs.
        pipe.vae.disable_tiling()
    if not os.path.exists(model_path):
        transformer = ZImageTransformer2DModel.from_pretrained(
            f"{model_path}", subfolder="transformer", use_auth_token=use_auth_token
        ).to("cuda", torch.bfloat16)
    else:
        transformer = ZImageTransformer2DModel.from_pretrained(os.path.join(model_path, "transformer")).to(
            "cuda", torch.bfloat16
        )
    pipe.transformer = transformer
    pipe.transformer.set_attention_backend(attention_backend)
    if enable_compile:
        print("Compiling transformer...")
        pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False)
    pipe.to("cuda", torch.bfloat16)
    return pipe
def generate_image(
    pipe,
    prompt,
    resolution="1024x1024",
    seed=42,
    guidance_scale=5.0,
    num_inference_steps=50,
    shift=3.0,
    max_sequence_length=512,
    progress=gr.Progress(track_tqdm=True),
):
    """Run one sampling pass of the pipeline and return a single PIL image.

    A fresh FlowMatchEulerDiscreteScheduler is installed on every call so the
    `shift` value always takes effect; the seeded CUDA generator makes the
    result reproducible for a given (prompt, resolution, seed, steps).
    """
    target_w, target_h = get_resolution(resolution)
    # Replace the scheduler rather than mutate it: load_models() leaves the
    # scheduler slot empty on purpose.
    pipe.scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=shift)
    rng = torch.Generator("cuda").manual_seed(seed)
    result = pipe(
        prompt=prompt,
        height=target_h,
        width=target_w,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=rng,
        max_sequence_length=max_sequence_length,
    )
    return result.images[0]
def warmup_model(pipe, resolutions):
    """Pre-run the pipeline at every resolution to trigger compilation/caching.

    Three short passes per resolution; any failure for one resolution is
    logged and skipped so warmup never blocks app startup.
    """
    print("Starting warmup phase...")
    dummy_prompt = "warmup"
    for res_str in resolutions:
        print(f"Warming up for resolution: {res_str}")
        try:
            for attempt in range(3):
                generate_image(
                    pipe,
                    prompt=dummy_prompt,
                    resolution=res_str,
                    num_inference_steps=9,
                    guidance_scale=0.0,
                    seed=42 + attempt,
                )
        except Exception as e:
            # Best-effort: record and move on to the next resolution.
            print(f"Warmup failed for {res_str}: {e}")
    print("Warmup completed.")
# ==================== Prompt Expander ====================
@dataclass
class PromptOutput:
    """Result of one prompt-expansion attempt.

    BUG FIX: the @dataclass decorator was missing, so every construction site
    (e.g. PromptOutput(False, "", seed, ...)) raised TypeError — bare class
    annotations do not generate an __init__.
    """

    status: bool  # True when expansion succeeded
    prompt: str  # the expanded prompt ("" on failure)
    seed: int  # seed passed through unchanged
    system_prompt: str  # system prompt that was used
    message: str  # raw model response, or an error description
class PromptExpander:
    """Base class for prompt expanders; concrete backends subclass this."""

    def __init__(self, backend="api", **kwargs):
        # Extra kwargs are accepted (and ignored) so subclasses can pass through.
        self.backend = backend

    def decide_system_prompt(self, template_name=None):
        # template_name is currently unused; the single imported template from
        # pe.py is always returned.
        return prompt_template
class APIPromptExpander(PromptExpander):
    """Prompt expander backed by an OpenAI-compatible API (DashScope).

    The client is initialized eagerly; any failure (missing package, missing
    key) leaves self.client as None and extend() degrades gracefully.
    """

    def __init__(self, api_config=None, **kwargs):
        super().__init__(backend="api", **kwargs)
        self.api_config = api_config or {}
        self.client = self._init_api_client()

    def _init_api_client(self):
        """Return an OpenAI-compatible client, or None when unavailable."""
        try:
            from openai import OpenAI

            api_key = self.api_config.get("api_key") or DASHSCOPE_API_KEY
            base_url = self.api_config.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
            if not api_key:
                print("Warning: DASHSCOPE_API_KEY not found.")
                return None
            return OpenAI(api_key=api_key, base_url=base_url)
        except ImportError:
            print("Please install openai: pip install openai")
            return None
        except Exception as e:
            print(f"Failed to initialize API client: {e}")
            return None

    def __call__(self, prompt, system_prompt=None, seed=-1, **kwargs):
        return self.extend(prompt, system_prompt, seed, **kwargs)

    def extend(self, prompt, system_prompt=None, seed=-1, **kwargs):
        """Ask the LLM to rewrite `prompt`; returns a PromptOutput either way.

        When the system prompt embeds the user prompt via "{prompt}", the user
        message is replaced with a single space so the prompt is not sent twice.
        """
        if self.client is None:
            return PromptOutput(False, "", seed, system_prompt, "API client not initialized")
        if system_prompt is None:
            system_prompt = self.decide_system_prompt()
        if "{prompt}" in system_prompt:
            system_prompt = system_prompt.format(prompt=prompt)
            prompt = " "
        try:
            model = self.api_config.get("model", "qwen3-max-preview")
            response = self.client.chat.completions.create(
                model=model,
                messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
                temperature=0.7,
                top_p=0.8,
            )
            content = response.choices[0].message.content
            # Prefer a ```json fenced block with a "revised_prompt" field; fall
            # back to the raw response text on any parsing problem.
            json_start = content.find("```json")
            if json_start != -1:
                json_end = content.find("```", json_start + 7)
                if json_end == -1:
                    # BUG FIX: an unterminated fence previously produced the
                    # slice [start+7:-1], silently dropping the last character.
                    json_end = len(content)
                try:
                    json_str = content[json_start + 7 : json_end].strip()
                    data = json.loads(json_str)
                    expanded_prompt = data.get("revised_prompt", content)
                except (json.JSONDecodeError, AttributeError):
                    # Narrowed from a bare except: malformed JSON, or a JSON
                    # document that is not an object (no .get).
                    expanded_prompt = content
            else:
                expanded_prompt = content
            return PromptOutput(
                status=True, prompt=expanded_prompt, seed=seed, system_prompt=system_prompt, message=content
            )
        except Exception as e:
            return PromptOutput(False, "", seed, system_prompt, str(e))
def create_prompt_expander(backend="api", **kwargs):
    """Factory for prompt expanders.

    Raises:
        ValueError: for any backend other than "api".
    """
    if backend != "api":
        raise ValueError("Only 'api' backend is supported.")
    return APIPromptExpander(**kwargs)
# Module singletons, populated by init_app() at import time.
pipe = None
prompt_expander = None


def init_app():
    """Load the diffusion pipeline and the prompt expander into module globals.

    Each step may fail independently (missing weights, missing API key);
    failures are logged and the corresponding global stays None so the UI can
    still come up in a degraded mode.
    """
    global pipe, prompt_expander
    try:
        pipe = load_models(MODEL_PATH, enable_compile=ENABLE_COMPILE, attention_backend=ATTENTION_BACKEND)
        print(f"Model loaded. Compile: {ENABLE_COMPILE}, Backend: {ATTENTION_BACKEND}")
        if ENABLE_WARMUP:
            # Warm every preset across all categories.
            warmup_model(pipe, [res for group in RES_CHOICES.values() for res in group])
    except Exception as e:
        print(f"Error loading model: {e}")
        pipe = None
    try:
        prompt_expander = create_prompt_expander(backend="api", api_config={"model": "qwen3-max-preview"})
        print("Prompt expander initialized.")
    except Exception as e:
        print(f"Error initializing prompt expander: {e}")
        prompt_expander = None
def prompt_enhance(prompt, enable_enhance):
    """Optionally expand `prompt` via the module-level prompt_expander.

    Returns:
        tuple[str, str]: (prompt to use, status/diagnostic message). On any
        failure the original prompt is returned unchanged.
    """
    # Guard evaluates enable_enhance first, so the module global is never
    # touched when enhancement is switched off.
    if not (enable_enhance and prompt_expander):
        return prompt, "ํ๋กฌํํธ ํฅ์์ด ๋นํ์ฑํ๋์๊ฑฐ๋ ์ฌ์ฉํ ์ ์์ต๋๋ค."
    if not prompt.strip():
        return "", "ํ๋กฌํํธ๋ฅผ ์ ๋ ฅํด์ฃผ์ธ์."
    try:
        result = prompt_expander(prompt)
    except Exception as e:
        return prompt, f"์ค๋ฅ: {str(e)}"
    if result.status:
        return result.prompt, result.message
    return prompt, f"ํฅ์ ์คํจ: {result.message}"
def generate(
    prompt, resolution="1024x1024 ( 1:1 )", seed=42, steps=9, shift=3.0, enhance=False, random_seed=True, gallery_images=None, progress=gr.Progress(track_tqdm=True)
):
    """
    Generate an image using the Z-Image model based on the provided prompt and settings.
    This function is triggered when the user clicks the "Generate" button. It processes
    the input prompt (optionally enhancing it), configures generation parameters, and
    produces an image using the Z-Image diffusion transformer pipeline.
    Args:
        prompt (str): Text prompt describing the desired image content
        resolution (str): Output resolution in format "WIDTHxHEIGHT ( RATIO )" (e.g., "1024x1024 ( 1:1 )")
            valid options, 1024 category:
            - "1024x1024 ( 1:1 )"
            - "1152x896 ( 9:7 )"
            - "896x1152 ( 7:9 )"
            - "1152x864 ( 4:3 )"
            - "864x1152 ( 3:4 )"
            - "1248x832 ( 3:2 )"
            - "832x1248 ( 2:3 )"
            - "1280x720 ( 16:9 )"
            - "720x1280 ( 9:16 )"
            - "1344x576 ( 21:9 )"
            - "576x1344 ( 9:21 )"
            1280 category:
            - "1280x1280 ( 1:1 )"
            - "1440x1120 ( 9:7 )"
            - "1120x1440 ( 7:9 )"
            - "1472x1104 ( 4:3 )"
            - "1104x1472 ( 3:4 )"
            - "1536x1024 ( 3:2 )"
            - "1024x1536 ( 2:3 )"
            - "1600x896 ( 16:9 )"
            - "896x1600 ( 9:16 )"
            - "1680x720 ( 21:9 )"
            - "720x1680 ( 9:21 )"
        seed (int): Seed for reproducible generation
        steps (int): Number of inference steps for the diffusion process
        shift (float): Time shift parameter for the flow matching scheduler
        enhance (bool): Whether to enhance the prompt via DashScope (currently disabled in the UI; do not use)
        random_seed (bool): Whether to generate a new random seed, if True will ignore the seed input
        gallery_images (list): List of previously generated images to append to (only needed for the Gradio UI)
        progress (gr.Progress): Gradio progress tracker for displaying generation progress (only needed for the Gradio UI)
    Returns:
        tuple: (gallery_images, seed_str, seed_int)
            - gallery_images: Updated list of generated images including the new image
            - seed_str: String representation of the seed used for generation
            - seed_int: Integer representation of the seed used for generation
    """
    if pipe is None:
        raise gr.Error("๋ชจ๋ธ์ด ๋ก๋๋์ง ์์์ต๋๋ค.")
    final_prompt = prompt
    if enhance:
        final_prompt, _ = prompt_enhance(prompt, True)
        print(f"Enhanced prompt: {final_prompt}")
    # Pick the effective seed: fresh random, or the user's unless it is the
    # -1 sentinel.
    if random_seed:
        new_seed = random.randint(1, 1000000)
    else:
        new_seed = seed if seed != -1 else random.randint(1, 1000000)
    try:
        # Strip the " ( ratio )" suffix; get_resolution() parses the rest.
        resolution_str = resolution.split(" ")[0]
    except Exception:
        # Narrowed from a bare except: still falls back on any malformed
        # value, but no longer swallows KeyboardInterrupt/SystemExit.
        resolution_str = "1024x1024"
    image = generate_image(
        pipe=pipe,
        prompt=final_prompt,
        resolution=resolution_str,
        seed=new_seed,
        guidance_scale=0.0,  # Turbo model: classifier-free guidance disabled
        num_inference_steps=int(steps + 1),  # +1: the UI slider shows 8, pipeline runs 9
        shift=shift,
    )
    if gallery_images is None:
        gallery_images = []
    gallery_images.append(image)
    return gallery_images, str(new_seed), int(new_seed)
# Build the pipeline/expander singletons at import time (Spaces startup).
init_app()
# ==================== AoTI (Ahead of Time Inductor compilation) ====================
# NOTE(review): this assumes init_app() succeeded — on failure `pipe` is None
# and the attribute access below raises at import; confirm that fail-fast
# behavior is intended for the ZeroGPU runtime.
pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
spaces.aoti_blocks_load(pipe.transformer.layers, "zerogpu-aoti/Z-Image", variant="fa3")
# ==================== Gradio UI ====================
# Two-column layout: controls + example prompts on the left, gallery on the right.
with gr.Blocks(title="Z-Image ๋ฐ๋ชจ") as demo:
    gr.Markdown(
        """<div align="center">
# Z-Image ์ด๋ฏธ์ง ์์ฑ ๋ฐ๋ชจ
[](https://github.com/Tongyi-MAI/Z-Image)
*๋จ์ผ ์คํธ๋ฆผ ๋ํจ์ ํธ๋์คํฌ๋จธ๋ฅผ ์ฌ์ฉํ ํจ์จ์ ์ธ ์ด๋ฏธ์ง ์์ฑ ๊ธฐ๋ฐ ๋ชจ๋ธ*
</div>"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(label="ํ๋กฌํํธ", lines=3, placeholder="ํ๋กฌํํธ๋ฅผ ์ ๋ ฅํ์ธ์...")
            # PE components (Temporarily disabled)
            # with gr.Row():
            #     enable_enhance = gr.Checkbox(label="ํ๋กฌํํธ ํฅ์ (DashScope)", value=False)
            #     enhance_btn = gr.Button("ํฅ์๋ง ์คํ")
            with gr.Row():
                # Category dropdown (1024 / 1280) filters the resolution list.
                choices = [int(k) for k in RES_CHOICES.keys()]
                res_cat = gr.Dropdown(value=1024, choices=choices, label="ํด์๋ ์นดํ ๊ณ ๋ฆฌ")
                initial_res_choices = RES_CHOICES["1024"]
                resolution = gr.Dropdown(value=initial_res_choices[0], choices=RESOLUTION_SET, label="๋๋น x ๋์ด (๋น์จ)")
            with gr.Row():
                seed = gr.Number(label="์๋", value=42, precision=0)
                random_seed = gr.Checkbox(label="๋๋ค ์๋", value=True)
            with gr.Row():
                # Step count is fixed for the Turbo model (slider is read-only).
                steps = gr.Slider(label="์คํ ์", minimum=1, maximum=100, value=8, step=1, interactive=False)
                shift = gr.Slider(label="์๊ฐ ์ด๋", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
            generate_btn = gr.Button("์์ฑ", variant="primary")
            # Example prompts
            gr.Markdown("### ๐ ์์ ํ๋กฌํํธ")
            gr.Examples(examples=EXAMPLE_PROMPTS, inputs=prompt_input, label=None)
        with gr.Column(scale=1):
            output_gallery = gr.Gallery(
                label="์์ฑ๋ ์ด๋ฏธ์ง", columns=2, rows=2, height=600, object_fit="contain", format="png", interactive=False
            )
            used_seed = gr.Textbox(label="์ฌ์ฉ๋ ์๋", interactive=False)

    def update_res_choices(_res_cat):
        # Swap the resolution dropdown contents when the category changes;
        # unknown categories fall back to the 1024 presets.
        if str(_res_cat) in RES_CHOICES:
            res_choices = RES_CHOICES[str(_res_cat)]
        else:
            res_choices = RES_CHOICES["1024"]
        return gr.update(value=res_choices[0], choices=res_choices)

    res_cat.change(update_res_choices, inputs=res_cat, outputs=resolution, api_visibility="private")
    # PE enhancement button (Temporarily disabled)
    # enhance_btn.click(
    #     prompt_enhance,
    #     inputs=[prompt_input, enable_enhance],
    #     outputs=[prompt_input, final_prompt_output]
    # )
    # Dummy enable_enhance variable set to False
    enable_enhance = gr.State(value=False)
    generate_btn.click(
        generate,
        inputs=[prompt_input, resolution, seed, steps, shift, enable_enhance, random_seed, output_gallery],
        outputs=[output_gallery, used_seed, seed],
        api_visibility="public",
    )
# Widen the default fillable container so the two-column layout fits.
css='''
.fillable{max-width: 1230px !important}
'''
if __name__ == "__main__":
    # NOTE(review): current Gradio expects `css=` on gr.Blocks(...), not on
    # launch() — confirm the installed version honors this argument.
    demo.launch(css=css, mcp_server=True)