import os
import sys
import time
import json
import gc
import random
import fnmatch
from collections import deque
from concurrent.futures import ThreadPoolExecutor, wait
from threading import Lock, Event, Thread
from contextlib import contextmanager
from urllib.parse import urlparse

# ===== ENV SETUP =====
# huggingface_hub (imported below, and also pulled in by gradio/diffusers)
# copies these variables into module-level constants when it is first
# imported, so they must be set BEFORE any of those imports to take effect.
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
os.environ["TRANSFORMERS_CACHE"] = "./hf_cache"
os.environ["HF_DATASETS_CACHE"] = "./hf_cache"
os.makedirs("./hf_cache", exist_ok=True)

import torch  # noqa: E402
import gradio as gr  # noqa: E402
import requests  # noqa: E402
from huggingface_hub import HfApi, hf_hub_download, hf_hub_url  # noqa: E402
from huggingface_hub.utils import RepositoryNotFoundError  # noqa: E402

# ===== LOGGING =====
LOG_LOCK = Lock()
# Bounded ring buffer: deque(maxlen=...) drops the oldest line in O(1),
# unlike list.pop(0) which shifts the whole list.
LOG_BUFFER = deque(maxlen=500)


def log(msg):
    """Record ``msg`` with an HH:MM:SS timestamp and return the whole log.

    Only the newest 500 lines are kept. The returned string (lines joined
    with newlines) is what the UI writes into its "Logs" textbox.
    """
    with LOG_LOCK:
        LOG_BUFFER.append(f"{time.strftime('%H:%M:%S')} | {msg}")
        return "\n".join(LOG_BUFFER)


torch.set_grad_enabled(False)
torch.set_num_threads(min(8, os.cpu_count() or 1))
torch.set_float32_matmul_precision("medium")

DEVICE = "cpu"
DTYPE = torch.float32

try:
    from diffusers import ZImagePipeline, GGUFQuantizationConfig, ZImageTransformer2DModel
    log("Loaded diffusers modules")
except ImportError as e:
    log(f"Import error: {e}")
    # The log buffer is only shown by the UI, which never starts on this
    # path; sys.exit(str) prints the message to stderr and exits with status 1.
    sys.exit(f"Import error: {e}")

# ===== DOWNLOAD CONTEXT =====
CACHE_ROOT = "./hf_cache"
interrupt_event = Event()  # set by STOP; polled by downloads and generation
pipe_cache = {}  # model_key -> loaded pipeline
download_lock = Lock()  # guards the shared per-file progress counters

# ===== MODEL LIST =====
MODEL_SPECS = {
    "Turbo Full": "Tongyi-MAI/Z-Image-Turbo",
    "Turbo Q2_K GGUF": "unsloth/Z-Image-Turbo-GGUF",
}

# Repo that supplies the non-transformer components (text encoder, VAE,
# tokenizer, scheduler) when a GGUF transformer is used.
BASE_REPO = "Tongyi-MAI/Z-Image-Turbo"

# Optional per-model filename filters (case-insensitive fnmatch patterns).
# The GGUF repo publishes many quantization levels; only the one named in the
# model key is needed. Models not listed here download their whole repo.
# NOTE(review): assumes the unsloth file names contain "Q2_K" — confirm.
MODEL_FILE_PATTERNS = {
    "Turbo Q2_K GGUF": ["*q2_k*.gguf"],
}

# quality_mode -> (num_inference_steps, square image side in pixels)
QUALITY_PRESETS = {
    "ultra_fast": (1, 256),
    "fast": (1, 256),
    "balanced": (2, 256),
    "quality": (4, 384),
    "ultra_quality": (4, 512),
}


# ===== DOWNLOAD HELPERS =====
def _repo_dir(repo_id):
    """Return the local directory that the preloaded files of ``repo_id`` live in."""
    return os.path.join(CACHE_ROOT, repo_id.replace("/", "_"))


def _matches_any(name, patterns):
    """Return True if ``name`` matches any fnmatch pattern, ignoring case."""
    lowered = name.lower()
    return any(fnmatch.fnmatch(lowered, p.lower()) for p in patterns)


def list_repo_files(repo_id):
    """
    Return a list of (filename, size_in_bytes) tuples for every file in a model repo.

    Uses the Hub metadata API, so no file data is downloaded. Sizes the Hub
    does not report are returned as 0. On failure the reason is logged and
    an empty list is returned.
    """
    try:
        info = HfApi().model_info(repo_id, files_metadata=True)
    except RepositoryNotFoundError as e:
        log(f"List failed: repository not found: {repo_id} ({e})")
        return []
    except Exception as e:  # network / auth errors: report and let caller skip
        log(f"List failed: {e}")
        return []
    return [(s.rfilename, s.size or 0) for s in (info.siblings or [])]


def download_file_chunked(repo_id, filename, target_dir, progress_updater):
    """
    Stream one repo file to ``target_dir/filename``, resuming a partial download.

    Data goes to ``<local_path>.part`` and is moved into place only when the
    stream completes, so an interrupted download resumes on the next call via
    an HTTP Range request.

    Args:
        repo_id: Hub repo id, e.g. "Tongyi-MAI/Z-Image-Turbo".
        filename: Path of the file inside the repo (may contain "/").
        target_dir: Local directory that mirrors the repo layout.
        progress_updater: Called with the completed fraction in [0, 1].

    Returns:
        True when the file is complete on disk. False if ``interrupt_event``
        was set; the .part file is kept so the download can resume later.

    Raises:
        requests.RequestException: on network failure or a non-success status.
    """
    local_path = os.path.join(target_dir, filename)
    tmp_path = local_path + ".part"
    os.makedirs(os.path.dirname(local_path), exist_ok=True)

    already = os.path.getsize(tmp_path) if os.path.exists(tmp_path) else 0

    # hf_hub_url only formats the "resolve" URL (no network call, never raises
    # RepositoryNotFoundError); requests follows the redirect to the CDN.
    url = hf_hub_url(repo_id, filename)
    headers = {"Range": f"bytes={already}-"} if already > 0 else {}

    with requests.get(url, headers=headers, stream=True, timeout=10) as r:
        if r.status_code == 416 and already > 0:
            # Range starts at EOF: the .part file already holds the whole file
            # (same convention as wget's resume).
            os.replace(tmp_path, local_path)
            progress_updater(1.0)
            return True
        r.raise_for_status()
        if r.status_code != 206:
            # Server ignored the Range header and is sending the full file;
            # appending it to the partial data would corrupt the result.
            already = 0
        total = int(r.headers.get("Content-Length", 0)) + already
        downloaded = already
        with open(tmp_path, "ab" if already else "wb") as f:
            for chunk in r.iter_content(chunk_size=1024 * 256):
                if interrupt_event.is_set():
                    return False
                if not chunk:
                    continue
                f.write(chunk)
                downloaded += len(chunk)
                if total > 0:  # Content-Length may be absent
                    progress_updater(min(downloaded / total, 1.0))
    os.replace(tmp_path, local_path)
    return True


def parallel_download_repo(repo_id, progress: gr.Progress, allow_patterns=None, max_workers=8):
    """
    Download the files of ``repo_id`` into ./hf_cache/<owner>_<name> in parallel.

    Files already on disk with the expected size are skipped. Workers only
    update shared byte counters; this (calling) thread polls them and reports
    aggregate progress. gr.Progress locates the active request through
    context variables that plain worker threads do not inherit, so calling it
    from workers would be silently ignored.

    Args:
        repo_id: Hub repo id.
        progress: Gradio progress tracker of the current event.
        allow_patterns: Optional iterable of fnmatch patterns; when given, only
            files whose name matches one of them (case-insensitively) are fetched.
        max_workers: Maximum number of concurrent file downloads.
    """
    base_dir = _repo_dir(repo_id)
    files = list_repo_files(repo_id)
    if allow_patterns:
        files = [(name, size) for name, size in files if _matches_any(name, allow_patterns)]
    if not files:
        progress(1.0, desc="No files to download")
        return

    total_bytes = sum(size for _, size in files) or 1
    done_bytes = {}  # filename -> bytes completed; guarded by download_lock
    errors = []  # "filename: error" strings; guarded by download_lock

    def file_task(filename, size):
        if interrupt_event.is_set():
            return

        def update(frac):
            with download_lock:
                done_bytes[filename] = frac * size

        try:
            ok = download_file_chunked(repo_id, filename, base_dir, update)
        except Exception as e:  # one failed file must not abort the others
            with download_lock:
                errors.append(f"{filename}: {e}")
            return
        if ok:
            with download_lock:
                done_bytes[filename] = size

    pending = []
    for fname, size in files:
        local_full = os.path.join(base_dir, fname)
        if os.path.exists(local_full) and os.path.getsize(local_full) == size:
            done_bytes[fname] = size  # already cached; no workers exist yet
        else:
            pending.append((fname, size))

    def report():
        with download_lock:
            frac = min(sum(done_bytes.values()) / total_bytes, 1.0)
        progress(frac, desc=f"{repo_id} {frac * 100:.1f}%")

    with ThreadPoolExecutor(max_workers=max(1, max_workers)) as pool:
        futures = {pool.submit(file_task, name, size) for name, size in pending}
        while futures:
            _, futures = wait(futures, timeout=0.5)
            report()

    for err in errors:
        log(f"Download failed: {err}")
    if interrupt_event.is_set():
        log(f"Download of {repo_id} interrupted")
    report()


# ===== PIPELINE LOADER =====
def load_pipeline(model_key):
    """
    Return a ready ZImagePipeline for ``model_key``, building and caching it on first use.

    "Full" models load entirely from their preloaded local directory. GGUF
    models load the quantized transformer from the preloaded .gguf file that
    matches MODEL_FILE_PATTERNS, and take the remaining components from the
    base repo. The base repo comes from its preloaded local copy when present,
    otherwise from the Hub.

    Raises:
        gr.Error: if the required files have not been preloaded.
    """
    if model_key in pipe_cache:
        return pipe_cache[model_key]

    repo = MODEL_SPECS[model_key]
    repo_cache = _repo_dir(repo)
    if not os.path.isdir(repo_cache) or not os.listdir(repo_cache):
        raise gr.Error("Model not downloaded; press Preload first")

    # On CPU a second fp32 pipeline can easily exhaust RAM, so keep at most
    # one cached pipeline: drop any other before building a new one.
    if pipe_cache:
        pipe_cache.clear()
        gc.collect()

    if "GGUF" in model_key:
        patterns = MODEL_FILE_PATTERNS.get(model_key, ["*.gguf"])
        # ".gguf.part" leftovers fail the endswith test; sort for determinism.
        candidates = sorted(
            f for f in os.listdir(repo_cache)
            if f.lower().endswith(".gguf") and _matches_any(f, patterns)
        )
        if not candidates:
            raise gr.Error("Quantized file not found")
        gguf = os.path.join(repo_cache, candidates[0])
        transformer = ZImageTransformer2DModel.from_single_file(
            gguf,
            quantization_config=GGUFQuantizationConfig(compute_dtype=DTYPE),
            torch_dtype=DTYPE,
        )
        base_local = _repo_dir(BASE_REPO)
        if os.path.isfile(os.path.join(base_local, "model_index.json")):
            # Reuse the preloaded base repo instead of re-downloading it
            # into the hub cache layout.
            pipe = ZImagePipeline.from_pretrained(
                base_local, transformer=transformer, torch_dtype=DTYPE, local_files_only=True
            )
        else:
            pipe = ZImagePipeline.from_pretrained(
                BASE_REPO, transformer=transformer, torch_dtype=DTYPE, cache_dir=CACHE_ROOT
            )
    else:
        if not os.path.isfile(os.path.join(repo_cache, "model_index.json")):
            raise gr.Error("Model not downloaded; press Preload first")
        pipe = ZImagePipeline.from_pretrained(repo_cache, torch_dtype=DTYPE, local_files_only=True)

    pipe.to(DEVICE)
    pipe.vae.eval()
    pipe.text_encoder.eval()
    pipe.transformer.eval()
    pipe_cache[model_key] = pipe
    return pipe


@contextmanager
def managed_memory():
    """Run the enclosed block, then collect garbage and free the CUDA cache if present."""
    try:
        yield
    finally:
        gc.collect()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def _latents_to_pil(pipe, latents):
    """
    Best-effort decode of intermediate latents into a PIL image for previews.

    Follows the usual diffusers decode path: undo VAE scaling/shift, decode,
    then post-process.
    NOTE(review): assumes ZImagePipeline latents are unpacked (B, C, H, W) —
    confirm. Callers must tolerate exceptions.
    """
    vae = pipe.vae
    scaling = getattr(vae.config, "scaling_factor", None) or 1.0
    shift = getattr(vae.config, "shift_factor", None) or 0.0
    decoded = vae.decode(latents.to(vae.dtype) / scaling + shift, return_dict=False)[0]
    return pipe.image_processor.postprocess(decoded, output_type="pil")[0]


# ===== GENERATION =====
@torch.inference_mode()
def generate(prompt, quality_mode, seed, model_key):
    """
    Generate one image and return ``(final_image, seed_used, preview_images)``.

    Args:
        prompt: Text prompt; must contain non-whitespace characters.
        quality_mode: Key of QUALITY_PRESETS; unknown keys fall back to (1, 256).
        seed: Non-negative integer seed; negative or None picks a random seed.
        model_key: Key of MODEL_SPECS.

    Preview images are decoded every other step (best effort) and the final
    image is appended last. Setting ``interrupt_event`` asks the pipeline to
    stop early.

    Raises:
        gr.Error: on an empty prompt, or if the model was not preloaded.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Prompt cannot be empty")

    steps, size = QUALITY_PRESETS.get(quality_mode, (1, 256))
    width = height = size
    # gr.Number yields None when the field is cleared; treat that like -1.
    if seed is not None and seed >= 0:
        seed = int(seed)
    else:
        seed = random.randint(0, 2**31 - 1)
    log(f"Gen: {prompt[:30]} | {quality_mode} | {model_key} | seed={seed}")

    with managed_memory():
        pipe = load_pipeline(model_key)
        gen = torch.Generator("cpu").manual_seed(seed)
        previews = []
        preview_broken = False
        start = time.time()

        def cb(ppl, step, timestep, cbk):
            nonlocal preview_broken
            if interrupt_event.is_set():
                # diffusers pipelines skip remaining steps once _interrupt is set.
                ppl._interrupt = True
            if step % 2 == 0 and not preview_broken:
                try:
                    previews.append(_latents_to_pil(ppl, cbk["latents"]))
                except Exception as e:  # previews are optional; log once and stop trying
                    preview_broken = True
                    log(f"Preview decode failed, disabling previews: {e}")
            return cbk

        result = pipe(
            prompt=prompt,
            negative_prompt=None,
            width=width,
            height=height,
            num_inference_steps=steps,
            guidance_scale=0.0,
            generator=gen,
            callback_on_step_end=cb,
            callback_on_step_end_tensor_inputs=["latents"],
            output_type="pil",
        )
        final = result.images[0]
        previews.append(final)
        log(f"Generated in {time.time() - start:.1f}s")
        return final, seed, previews


# ===== GRADIO UI =====
with gr.Blocks(title="Z‑Image Turbo CPU Downloader + UI A‑Progress") as demo:
    gr.Markdown("## True parallel download UI + chunked progress")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", lines=4)
            quality = gr.Radio(["ultra_fast", "fast", "balanced", "quality", "ultra_quality"], value="fast")
            seed = gr.Number(value=-1, precision=0, label="Seed")
            model_select = gr.Dropdown(list(MODEL_SPECS.keys()), value=list(MODEL_SPECS.keys())[0], label="Model")
            preload = gr.Button("PRELOAD MODELS")
            gen_btn = gr.Button("GENERATE")
            stop_btn = gr.Button("STOP")
        with gr.Column():
            out_image = gr.Image(label="Final")
            used_seed = gr.Number(label="Seed Used")
            preview = gr.Gallery(label="Preview Frames")
            logs = gr.Textbox(label="Logs", lines=25)

    def do_preload(progress=gr.Progress()):
        """Download every model in MODEL_SPECS (filtered by MODEL_FILE_PATTERNS); return the log."""
        interrupt_event.clear()
        for key, repo in MODEL_SPECS.items():
            if interrupt_event.is_set():
                log("🔴 Preload interrupted")
                break
            parallel_download_repo(repo, progress, allow_patterns=MODEL_FILE_PATTERNS.get(key))
        return log("📦 Preload finished")

    def do_gen(prompt, quality, seed, model_key):
        """Run generate() and append a completion line to the log."""
        interrupt_event.clear()
        img, used, previews = generate(prompt, quality, seed, model_key)
        return img, used, previews, log("🧠 Generation done")

    def do_stop():
        """Signal running downloads/generation to stop; return the log."""
        interrupt_event.set()
        return log("🔴 Interrupt set")

    preload.click(do_preload, outputs=logs)
    gen_btn.click(do_gen, inputs=[prompt, quality, seed, model_select], outputs=[out_image, used_seed, preview, logs])
    stop_btn.click(do_stop, outputs=logs)

demo.queue()
demo.launch(server_name="0.0.0.0", server_port=7860)