# Hugging Face Space app (page-scrape header removed)
| import os | |
| import sys | |
| import time | |
| import json | |
| import gc | |
| import random | |
| import torch | |
| import gradio as gr | |
| import requests | |
| from threading import Lock, Event, Thread | |
| from contextlib import contextmanager | |
| from urllib.parse import urlparse | |
| from huggingface_hub import hf_hub_download, hf_hub_url | |
| from huggingface_hub.utils import RepositoryNotFoundError | |
# ===== LOGGING =====
LOG_BUFFER = []
LOG_LOCK = Lock()

def log(msg):
    """Append a timestamped entry to the in-memory log and return the whole log as text."""
    with LOG_LOCK:
        stamp = time.strftime("%H:%M:%S")
        LOG_BUFFER.append(f"{stamp} | {msg}")
        # Cap the buffer at the 500 most recent entries.
        del LOG_BUFFER[:-500]
        return "\n".join(LOG_BUFFER)
# ===== ENV SETUP =====
# Disable telemetry/hf_transfer and route every Hub cache into a local dir.
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
os.environ["TRANSFORMERS_CACHE"] = "./hf_cache"
os.environ["HF_DATASETS_CACHE"] = "./hf_cache"
os.makedirs("./hf_cache", exist_ok=True)

# Inference-only process: autograd off, CPU thread count bounded.
torch.set_grad_enabled(False)
torch.set_num_threads(min(8, os.cpu_count() or 1))
torch.set_float32_matmul_precision("medium")
DEVICE = "cpu"          # CPU-only Space
DTYPE = torch.float32

try:
    # NOTE(review): ZImagePipeline/GGUF support needs a recent diffusers
    # build — confirm the Space pins a compatible version.
    from diffusers import ZImagePipeline, GGUFQuantizationConfig, ZImageTransformer2DModel
    log("Loaded diffusers modules")
except ImportError as e:
    log(f"Import error: {e}")
    sys.exit(1)
# ===== DOWNLOAD CONTEXT =====
interrupt_event = Event()  # set by the STOP button; polled by downloads and the gen callback
pipe_cache = {}            # model_key -> loaded pipeline, memoized for the process lifetime
download_lock = Lock()     # guards downloaded_bytes accounting across download threads
# ===== MODEL LIST =====
# UI label -> Hugging Face repo id
MODEL_SPECS = {
    "Turbo Full": "Tongyi-MAI/Z-Image-Turbo",
    "Turbo Q2_K GGUF": "unsloth/Z-Image-Turbo-GGUF"
}
# ===== DOWNLOAD HELPERS =====
def list_repo_files(repo_id):
    """
    Return a list of (filename, size_in_bytes) tuples for every file in
    *repo_id*, without downloading any file content.

    Returns an empty list on any failure (network error, missing repo,
    missing size metadata is reported as 0).
    """
    try:
        # hf_hub_download() has no dry-run mode and never returns file
        # metadata; query the repo metadata endpoint instead.
        from huggingface_hub import HfApi
        info = HfApi().model_info(repo_id, files_metadata=True)
        return [(s.rfilename, s.size or 0) for s in info.siblings]
    except Exception as e:
        log(f"List failed: {e}")
        return []
def download_file_chunked(repo_id, filename, target_dir, progress_updater):
    """
    Stream one repo file into *target_dir* with resume support.

    Data is written to "<file>.part" and renamed on completion so a partial
    file is never mistaken for a finished one.  If a ".part" file already
    exists, an HTTP Range request resumes from its current size.

    progress_updater: callable taking a float fraction in [0, 1].
    Returns True on success, False if interrupt_event was set mid-transfer
    (the .part file is kept for a later resume).
    Raises requests.HTTPError on a failed HTTP response.
    """
    local_path = os.path.join(target_dir, filename)
    tmp_path = local_path + ".part"
    os.makedirs(os.path.dirname(local_path), exist_ok=True)

    already = os.path.getsize(tmp_path) if os.path.exists(tmp_path) else 0

    # hf_hub_url() only constructs the resolve URL — it performs no request,
    # so there is no repo-existence error to catch here.  (The old fallback
    # assigned a *local file path* from hf_hub_download to `url`, which
    # requests.get cannot fetch.)
    url = hf_hub_url(repo_id, filename)

    headers = {"Range": f"bytes={already}-"} if already > 0 else {}
    with requests.get(url, headers=headers, stream=True, timeout=10) as r:
        r.raise_for_status()  # fail loudly on 4xx/5xx instead of writing an error page
        total = int(r.headers.get("Content-Length", 0)) + already
        downloaded = already
        with open(tmp_path, "ab") as f:
            for chunk in r.iter_content(chunk_size=1024 * 256):
                if interrupt_event.is_set():
                    return False
                if not chunk:
                    continue
                f.write(chunk)
                downloaded += len(chunk)
                # Guard against servers that omit Content-Length.
                if total > 0:
                    progress_updater(downloaded / total)
    os.rename(tmp_path, local_path)
    return True
def parallel_download_repo(repo_id, progress: gr.Progress):
    """
    Download every file of *repo_id* into ./hf_cache/<repo> using one thread
    per file, reporting aggregate byte progress through *progress*.

    Files whose on-disk size already matches the repo metadata are skipped.
    Honors interrupt_event: no new threads are started once it is set.
    """
    base_dir = os.path.join("./hf_cache", repo_id.replace("/", "_"))
    files = list_repo_files(repo_id)
    if not files:
        progress(1.0, desc="No files to download")
        return

    # `or 1` guards the division below when all sizes are reported as 0.
    total_bytes = sum(sz for _, sz in files) or 1
    downloaded_bytes = 0

    def file_thread(filename, size):
        nonlocal downloaded_bytes
        # Snapshot the running total so this file's progress fraction does
        # not jump when sibling threads finish mid-download.
        base = downloaded_bytes
        success = download_file_chunked(
            repo_id, filename, base_dir,
            lambda frac: progress((base + frac * size) / total_bytes,
                                  desc=f"{filename} {frac*100:.1f}%")
        )
        if success:
            with download_lock:
                downloaded_bytes += size

    threads = []
    for fname, size in files:
        if interrupt_event.is_set():
            break
        # Skip files that are already fully cached.
        local_full = os.path.join(base_dir, fname)
        if os.path.exists(local_full) and os.path.getsize(local_full) == size:
            with download_lock:  # worker threads may already be mutating this
                downloaded_bytes += size
            continue
        t = Thread(target=file_thread, args=(fname, size))
        t.start()
        threads.append(t)
    for t in threads:
        t.join()
# ===== PIPELINE LOADER =====
def load_pipeline(model_key):
    """
    Load the HF pipeline, using quantized GGUF if selected.

    model_key: a key of MODEL_SPECS.  Loaded pipelines are memoized in
    pipe_cache, so each model is built at most once per process.

    Raises gr.Error when the model has not been preloaded into ./hf_cache
    or the expected .gguf file is missing.
    """
    if model_key in pipe_cache:
        return pipe_cache[model_key]
    repo = MODEL_SPECS[model_key]
    repo_cache = os.path.join("./hf_cache", repo.replace("/", "_"))
    # ensure cache
    if not os.path.isdir(repo_cache) or not os.listdir(repo_cache):
        raise gr.Error("Model not downloaded; press Preload first")
    # load model
    if "GGUF" in model_key:
        # pick .gguf file (first match wins if several quantizations exist)
        files = [f for f in os.listdir(repo_cache) if f.endswith(".gguf")]
        if not files:
            raise gr.Error("Quantized file not found")
        gguf = os.path.join(repo_cache, files[0])
        transformer = ZImageTransformer2DModel.from_single_file(
            gguf,
            quantization_config=GGUFQuantizationConfig(compute_dtype=DTYPE),
            torch_dtype=DTYPE
        )
        # NOTE(review): non-transformer components come from the hub repo id
        # here, not from the manually downloaded repo_cache — confirm this is
        # intended; it can trigger a network download at load time.
        pipe = ZImagePipeline.from_pretrained(
            "Tongyi-MAI/Z-Image-Turbo",
            transformer=transformer,
            torch_dtype=DTYPE,
            cache_dir="./hf_cache"
        )
    else:
        pipe = ZImagePipeline.from_pretrained(repo_cache, torch_dtype=DTYPE, local_files_only=True)
    pipe.to(DEVICE)
    # Put every submodule in inference mode.
    pipe.vae.eval()
    pipe.text_encoder.eval()
    pipe.transformer.eval()
    pipe_cache[model_key] = pipe
    return pipe
@contextmanager
def managed_memory():
    """
    Context manager that forces garbage collection (and empties the CUDA
    cache when available) after the wrapped work, whether it succeeded or
    raised.

    Bug fix: the original lacked @contextmanager, so `with managed_memory():`
    failed at runtime — a bare generator has no __enter__/__exit__.
    """
    try:
        yield
    finally:
        # Two passes help reclaim cycles exposed by the first collection.
        gc.collect()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# ===== GENERATION =====
def generate(prompt, quality_mode, seed, model_key):
    """
    Run one text-to-image generation.

    prompt: non-empty text prompt.
    quality_mode: one of the PRESETS keys; unknown values fall back to (1, 256).
    seed: int >= 0 for a fixed seed; any negative value draws a random one.
    model_key: key of MODEL_SPECS.

    Returns (final PIL image, seed actually used, list of preview images).
    Raises gr.Error when the prompt is empty or missing.
    """
    # Guard None as well as "" — gr.Textbox can deliver None.
    if not prompt or not prompt.strip():
        raise gr.Error("Prompt cannot be empty")
    # (num_inference_steps, square image size) per preset.
    PRESETS = {
        "ultra_fast": (1, 256),
        "fast": (1, 256),
        "balanced": (2, 256),
        "quality": (4, 384),
        "ultra_quality": (4, 512),
    }
    steps, size = PRESETS.get(quality_mode, (1, 256))
    width = height = size
    seed = int(seed) if seed >= 0 else random.randint(0, 2**31 - 1)
    log(f"Gen: {prompt[:30]} | {quality_mode} | {model_key} | seed={seed}")
    with managed_memory():
        pipe = load_pipeline(model_key)
        gen = torch.Generator("cpu").manual_seed(seed)
        previews = []
        start = time.time()

        def cb(ppl, step, timestep, cbk):
            # Cooperative cancel: the pipeline checks _interrupt each step.
            if interrupt_event.is_set():
                ppl._interrupt = True
            if step % 2 == 0:
                try:
                    # NOTE(review): image_from_latents is not a standard
                    # diffusers pipeline method — confirm it exists on
                    # ZImagePipeline.  Previews are best-effort either way.
                    previews.append(ppl.image_from_latents(cbk["latents"]))
                except Exception:
                    # Narrowed from bare `except:` so KeyboardInterrupt /
                    # SystemExit still propagate.
                    pass
            return cbk

        result = pipe(
            prompt=prompt,
            negative_prompt=None,
            width=width,
            height=height,
            num_inference_steps=steps,
            guidance_scale=0.0,  # Turbo model: guidance disabled by design
            generator=gen,
            callback_on_step_end=cb,
            callback_on_step_end_tensor_inputs=["latents"],
            output_type="pil"
        )
        final = result.images[0]
        previews.append(final)
    log(f"Generated in {time.time()-start:.1f}s")
    return final, seed, previews
# ===== GRADIO UI =====
with gr.Blocks(title="Z‑Image Turbo CPU Downloader + UI A‑Progress") as demo:
    gr.Markdown("## True parallel download UI + chunked progress")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", lines=4)
            quality = gr.Radio(["ultra_fast","fast","balanced","quality","ultra_quality"], value="fast")
            seed = gr.Number(value=-1, precision=0, label="Seed")  # -1 => random seed per run
            model_select = gr.Dropdown(list(MODEL_SPECS.keys()), value=list(MODEL_SPECS.keys())[0], label="Model")
            preload = gr.Button("PRELOAD MODELS")
            gen_btn = gr.Button("GENERATE")
            stop_btn = gr.Button("STOP")
        with gr.Column():
            out_image = gr.Image(label="Final")
            used_seed = gr.Number(label="Seed Used")
            preview = gr.Gallery(label="Preview Frames")
            logs = gr.Textbox(label="Logs", lines=25)

    def do_preload(progress=gr.Progress()):
        """Download every repo in MODEL_SPECS; returns the updated log text."""
        interrupt_event.clear()
        for key, repo in MODEL_SPECS.items():
            parallel_download_repo(repo, progress)
        return log("📦 Preload finished")

    def do_gen(prompt, quality, seed, model_key):
        """Run one generation; clears any stale interrupt flag first."""
        interrupt_event.clear()
        img, used, previews = generate(prompt, quality, seed, model_key)
        return img, used, previews, log("🧠 Generation done")

    def do_stop():
        """Signal downloads and the diffusion callback to stop cooperatively."""
        interrupt_event.set()
        return log("🔴 Interrupt set")

    preload.click(do_preload, outputs=logs)
    gen_btn.click(do_gen, inputs=[prompt,quality,seed,model_select],
                  outputs=[out_image,used_seed,preview,logs])
    stop_btn.click(do_stop, outputs=logs)

demo.queue()
demo.launch(server_name="0.0.0.0", server_port=7860)