programmersd committed
Commit 29958ad · verified · 1 Parent(s): 966d466

Update app.py

Files changed (1)
  1. app.py +142 -122
app.py CHANGED
@@ -5,60 +5,93 @@ import time
  import random
  import torch
  import gradio as gr
- from threading import Lock
+ from threading import Lock, Event
  from contextlib import contextmanager
+ from huggingface_hub import snapshot_download
+ # LocalEntryNotFoundError is exposed via huggingface_hub.utils, not the package root
+ from huggingface_hub.utils import LocalEntryNotFoundError
+
+ # ----------- LOGGING -----------

- # --- LOGGING FOR UI ---
  LOG_BUFFER = []
  LOG_LOCK = Lock()

- def log(message):
-     print(message)
+ def log(msg):
      with LOG_LOCK:
-         LOG_BUFFER.append(f"{time.strftime('%H:%M:%S')} | {message}")
+         timestamp = time.strftime('%H:%M:%S')
+         LOG_BUFFER.append(f"{timestamp} | {msg}")
          if len(LOG_BUFFER) > 500:
              LOG_BUFFER.pop(0)
+     print(msg)
      return "\n".join(LOG_BUFFER)

- _initial_logs = log("🚀 Initializing Ultimate Z-Image Turbo CPU Edition...")
+ # ----------- ENV CONFIG -----------

- # CPU THREAD OPTIMIZATION
  CPU_THREADS = min(8, os.cpu_count() or 1)
- os.environ["OMP_NUM_THREADS"] = str(CPU_THREADS)
- os.environ["MKL_NUM_THREADS"] = str(CPU_THREADS)
- os.environ["OPENBLAS_NUM_THREADS"] = str(CPU_THREADS)
- os.environ["VECLIB_MAXIMUM_THREADS"] = str(CPU_THREADS)
- os.environ["NUMEXPR_NUM_THREADS"] = str(CPU_THREADS)
+ for var in ["OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS", "VECLIB_MAXIMUM_THREADS", "NUMEXPR_NUM_THREADS"]:
+     os.environ[var] = str(CPU_THREADS)
+
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
  os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
+ os.environ["HF_HUB_OFFLINE"] = "1"
+ os.environ["TRANSFORMERS_OFFLINE"] = "1"
  os.environ["TRANSFORMERS_CACHE"] = "./hf_cache"
  os.environ["HF_DATASETS_CACHE"] = "./hf_cache"

- torch.set_num_threads(CPU_THREADS)
  torch.set_grad_enabled(False)
+ torch.set_num_threads(CPU_THREADS)
  torch.backends.mkldnn.enabled = True
- torch.backends.mkldnn.deterministic = False
- torch.set_flush_denormal(True)
  torch.set_float32_matmul_precision("medium")

  DEVICE = "cpu"
  DTYPE = torch.float32
- CACHE_DIR = "./hf_cache"
- os.makedirs(CACHE_DIR, exist_ok=True)
-
- log(f"⚡ CPU Threads: {CPU_THREADS}, Device: {DEVICE}, DType: {DTYPE}")
+ os.makedirs("./hf_cache", exist_ok=True)

  try:
      from diffusers import ZImagePipeline
-     log("📦 diffusers imported successfully")
+     log("Imported diffusers successfully.")
  except ImportError as e:
-     log(f"Import Error: {e}")
+     log(f"Import diffusers failed: {e}")
      sys.exit(1)

- pipe = None
- _pipe_lock = Lock()
- _generation_lock = Lock()
+ pipe_cache = {}
+ pipe_lock = Lock()
+ generation_lock = Lock()
+ interrupt_event = Event()
+
+ # ----------- SNAPSHOT WITH RETRY -----------
+
+ MODEL_SPECS = {
+     "Z-Image Turbo": "Tongyi-MAI/Z-Image-Turbo",
+     # Optionally add quantized variants here
+     # "Z-Image Turbo GGUF": "unsloth/Z-Image-Turbo-GGUF",
+ }
+
+ def download_snapshot_with_retry(repo_id, local_path, retries=3):
+     attempt = 1
+     while attempt <= retries:
+         log(f"Snapshot attempt {attempt}/{retries} for {repo_id}...")
+         try:
+             # snapshot_download respects the HF cache and skips files already downloaded
+             path = snapshot_download(repo_id=repo_id, local_dir=local_path, local_dir_use_symlinks=False)
+             log(f"Snapshot fully downloaded: {path}")
+             return path
+         except Exception as e:
+             log(f"⚠️ snapshot_download failed: {e}")
+             attempt += 1
+             time.sleep(2)
+     raise RuntimeError(f"Failed to download snapshot of {repo_id} after {retries} attempts")
+
+ # Ensure snapshots are present before the UI starts
+ for model_name, repo_id in MODEL_SPECS.items():
+     local_dir = os.path.join("./hf_cache", f"{model_name}_snapshot")
+     if not os.path.isdir(local_dir) or not os.listdir(local_dir):
+         log(f"📥 No snapshot for {model_name}, starting download...")
+         try:
+             download_snapshot_with_retry(repo_id, local_dir, retries=3)
+         except RuntimeError as err:
+             log(f"❌ Snapshot download error: {err}")
+
+ # ----------- PIPELINE LOADING -----------

  @contextmanager
  def managed_memory():
@@ -70,77 +103,68 @@ def managed_memory():
      if torch.cuda.is_available():
          torch.cuda.empty_cache()

- def load_pipeline():
-     global pipe
-     with _pipe_lock:
-         if pipe is not None:
-             return pipe
-
-         log("📦 Loading Z-Image Turbo pipeline...")
-         start_load = time.time()
-
-         pipe = ZImagePipeline.from_pretrained(
-             "Tongyi-MAI/Z-Image-Turbo",
-             torch_dtype=DTYPE,
-             cache_dir=CACHE_DIR,
-             low_cpu_mem_usage=True
-         )
-
-         pipe = pipe.to(DEVICE)
+ def load_pipeline(model_name):
+     if model_name in pipe_cache:
+         return pipe_cache[model_name]
+     with pipe_lock:
+         log(f"Loading {model_name} pipeline.")
+         repo_dir = os.path.join("./hf_cache", f"{model_name}_snapshot")
+         try:
+             pipe = ZImagePipeline.from_pretrained(repo_dir, torch_dtype=DTYPE, local_files_only=True, low_cpu_mem_usage=True)
+         except LocalEntryNotFoundError:
+             log(f"Incomplete local snapshot for {model_name}, retrying online load.")
+             pipe = ZImagePipeline.from_pretrained(MODEL_SPECS[model_name], torch_dtype=DTYPE, cache_dir="./hf_cache", low_cpu_mem_usage=True)
+         pipe.to(DEVICE)
          pipe.vae.eval()
          pipe.text_encoder.eval()
          pipe.transformer.eval()
-
          try:
-             pipe.transformer = torch.compile(
-                 pipe.transformer,
-                 mode="reduce-overhead",
-                 fullgraph=False,
-                 dynamic=False
-             )
-             log("✅ Transformer compiled successfully!")
-         except Exception as compile_error:
-             log(f"⚠️ torch.compile() failed: {compile_error}")
-
-     load_time = time.time() - start_load
-     log(f"✅ Pipeline loaded in {load_time:.2f}s")
+             pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")
+             log("Transformer compiled.")
+         except Exception as e:
+             log(f"Transformer compile skipped: {e}")
+         pipe_cache[model_name] = pipe
      return pipe

+ # ----------- GENERATION LOGIC -----------
+
  @torch.inference_mode()
  @torch.no_grad()
- def generate(prompt, quality_mode, seed, progress=gr.Progress()):
+ def generate(prompt, quality_mode, seed, model_name):
      if not prompt.strip():
-         raise gr.Error("🎯 Prompt cannot be empty!")
-
-     quality_settings = {
-         "ultra_fast": {"steps": 1, "width": 256, "height": 256},
-         "fast": {"steps": 1, "width": 256, "height": 256},
-         "balanced": {"steps": 2, "width": 256, "height": 256},
-         "quality": {"steps": 4, "width": 384, "height": 384},
-         "ultra_quality": {"steps": 4, "width": 512, "height": 512}
+         raise gr.Error("Prompt cannot be empty!")
+
+     PRESETS = {
+         "ultra_fast": (1, 256),
+         "fast": (1, 256),
+         "balanced": (2, 256),
+         "quality": (4, 384),
+         "ultra_quality": (4, 512),
      }
-     settings = quality_settings.get(quality_mode, quality_settings["fast"])
-     steps, width, height = settings["steps"], settings["width"], settings["height"]
-
-     seed = int(seed) if seed >= 0 else random.randint(0, 2**31 - 1)
-     log(f"🎨 Generating: '{prompt[:50]}...' | Mode: {quality_mode} | {width}x{height} | Seed: {seed}")
-
-     with managed_memory():
-         with _generation_lock:
-             pipe = load_pipeline()
-             generator = torch.Generator("cpu").manual_seed(seed)
-             start_time = time.time()
-
-             def diffusers_progress_callback(pipeline, step_index, timestep, callback_kwargs):
-                 elapsed = time.time() - start_time
-                 avg = elapsed / (step_index + 1) if step_index >= 0 else 0
-                 remaining = avg * (steps - step_index - 1)
-                 progress(
-                     (step_index + 1) / steps,
-                     desc=f"Step {step_index+1}/{steps} | ETA {remaining:.1f}s"
-                 )
-                 return callback_kwargs
+     steps, size = PRESETS.get(quality_mode, (1, 256))
+     width = height = size
+
+     seed = int(seed) if seed >= 0 else random.randint(0, (2**31) - 1)
+     log(f"Generating: '{prompt[:40]}...' | {quality_mode} | {width}x{height} | seed={seed}")
+
+     with managed_memory(), generation_lock:
+         pipe = load_pipeline(model_name)
+         generator = torch.Generator("cpu").manual_seed(seed)
+         start_time = time.time()

+         preview_images = []
+
+         def progress_cb(pipeline, step_idx, timestep, cbk):
+             if interrupt_event.is_set():
+                 raise KeyboardInterrupt("Generation interrupted")
+             if step_idx % 2 == 0:  # preview every 2 steps
+                 try:
+                     # image_from_latents is not a public diffusers API; the guard
+                     # below makes previews a silent no-op when it is unavailable
+                     preview_images.append(pipeline.image_from_latents(pipeline.latents))
+                 except Exception:
+                     pass
+             return cbk
+
+         try:
              result = pipe(
                  prompt=prompt,
                  negative_prompt=None,
@@ -149,61 +173,57 @@ def generate(prompt, quality_mode, seed, progress=gr.Progress()):
                  num_inference_steps=steps,
                  guidance_scale=0.0,
                  generator=generator,
-                 callback_on_step_end=diffusers_progress_callback,
+                 callback_on_step_end=progress_cb,
                  callback_on_step_end_tensor_inputs=["latents"],
                  output_type="pil"
              )
-
-             image = result.images[0]
-             elapsed = time.time() - start_time
-             log(f"✅ Generated in {elapsed:.2f}s | Seed: {seed}")
-
-             del result
-             gc.collect()
-
-             return image, seed
+             final_image = result.images[0]
+             log(f"Done in {time.time()-start_time:.1f}s")
+         except KeyboardInterrupt:
+             log("⚠️ Generation interrupted.")
+             return None, seed, preview_images
+
+         del result
+         gc.collect()
+
+         preview_images.append(final_image)
+         return final_image, seed, preview_images

- with gr.Blocks(title="🚀 Z-Image Turbo Pro Max + Live Logs") as demo:
-     gr.Markdown("## GPU-FREE CPU Turbo - Live Logs Below")
+ # ----------- GRADIO UI -----------
+
+ with gr.Blocks(title="🤩✨ Z-Image Turbo CPU Ultimate + Retry + Preview + Interrupt") as demo:
+     gr.Markdown("## Full-feature CPU image generator with snapshot retry + preview frames")

      with gr.Row():
          with gr.Column():
              prompt = gr.Textbox(label="Prompt", lines=4)
              quality_mode = gr.Radio(
-                 choices=[
-                     ("Ultra Fast", "ultra_fast"),
-                     ("Fast", "fast"),
-                     ("Balanced", "balanced"),
-                     ("Quality", "quality"),
-                     ("Ultra Quality", "ultra_quality")
-                 ],
+                 choices=["ultra_fast", "fast", "balanced", "quality", "ultra_quality"],
                  value="fast",
                  label="Quality Mode"
              )
-             seed = gr.Number(value=-1, precision=0, label="Seed")
-             generate_btn = gr.Button("GENERATE")
+             seed = gr.Number(value=-1, precision=0, label="Seed (-1=random)")
+             model_choice = gr.Dropdown(list(MODEL_SPECS.keys()), value=list(MODEL_SPECS.keys())[0], label="Select model")
+             gen_btn = gr.Button("GENERATE")
+             interrupt_btn = gr.Button("STOP")
+
          with gr.Column():
-             output_image = gr.Image(label="Output")
-             used_seed = gr.Number(label="Seed Used", interactive=False)
-             log_output = gr.Textbox(
-                 label="Live System Log",
-                 lines=15,
-                 interactive=False
-             )
+             out_img = gr.Image(label="Final Image")
+             out_seed = gr.Number(label="Seed Used", interactive=False)
+             preview_gallery = gr.Gallery(label="Preview frames")
+             log_output = gr.Textbox(label="Live System Log", lines=15, interactive=False)

-     def wrapped_generate(prompt, quality_mode, seed):
-         image, used_seed = generate(prompt, quality_mode, seed)
-         logs = log("🧠 Latest status: Finished generation.")
-         return image, used_seed, logs
+     def on_generate(prompt, quality_mode, seed, model_choice):
+         interrupt_event.clear()
+         final_img, used_seed, previews = generate(prompt, quality_mode, seed, model_choice)
+         return final_img, used_seed, previews, log("Generation done.")

-     generate_btn.click(
-         wrapped_generate,
-         inputs=[prompt, quality_mode, seed],
-         outputs=[output_image, used_seed, log_output],
-         concurrency_limit=1
-     )
+     def on_interrupt():
+         interrupt_event.set()
+         return log("📌 Interrupt requested")

- demo.queue(max_size=3)
+     gen_btn.click(on_generate, inputs=[prompt, quality_mode, seed, model_choice], outputs=[out_img, out_seed, preview_gallery, log_output])
+     interrupt_btn.click(on_interrupt, inputs=None, outputs=log_output)

- if __name__ == "__main__":
-     demo.launch(server_name="0.0.0.0", server_port=7860)
+ demo.queue()
+ demo.launch(server_name="0.0.0.0", server_port=7860)