import os
import sys
import time
import json
import gc
import random
import torch
import gradio as gr
import requests
from threading import Lock, Event, Thread
from contextlib import contextmanager
from urllib.parse import urlparse
from huggingface_hub import hf_hub_url, HfApi
# ===== LOGGING =====
LOG_BUFFER = []
LOG_LOCK = Lock()
def log(msg):
    with LOG_LOCK:
        t = time.strftime("%H:%M:%S")
        entry = f"{t} | {msg}"
        LOG_BUFFER.append(entry)
        if len(LOG_BUFFER) > 500:
            LOG_BUFFER.pop(0)
        return "\n".join(LOG_BUFFER)
# ===== ENV SETUP =====
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
os.environ["TRANSFORMERS_CACHE"] = "./hf_cache"
os.environ["HF_DATASETS_CACHE"] = "./hf_cache"
os.makedirs("./hf_cache", exist_ok=True)
torch.set_grad_enabled(False)
torch.set_num_threads(min(8, os.cpu_count() or 1))
torch.set_float32_matmul_precision("medium")
DEVICE = "cpu"
DTYPE = torch.float32
try:
    from diffusers import ZImagePipeline, GGUFQuantizationConfig, ZImageTransformer2DModel
    log("Loaded diffusers modules")
except ImportError as e:
    log(f"Import error: {e}")
    sys.exit(1)
# ===== DOWNLOAD CONTEXT =====
interrupt_event = Event()
pipe_cache = {}
download_lock = Lock()
# ===== MODEL LIST =====
MODEL_SPECS = {
    "Turbo Full": "Tongyi-MAI/Z-Image-Turbo",
    "Turbo Q2_K GGUF": "unsloth/Z-Image-Turbo-GGUF",
}
# ===== DOWNLOAD HELPERS =====
def list_repo_files(repo_id):
    """
    Return a list of (filename, size) tuples using the Hub metadata API
    (no actual data downloaded).
    """
    try:
        info = HfApi().model_info(repo_id, files_metadata=True)
        return [(s.rfilename, s.size or 0) for s in info.siblings]
    except Exception as e:
        log(f"List failed: {e}")
        return []
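# Illustrative, uncalled helper: a quick way to eyeball what list_repo_files
# returns for the GGUF repo from MODEL_SPECS. The MB formatting is an
# arbitrary choice for readability.
def _example_list_files():
    for name, size in list_repo_files("unsloth/Z-Image-Turbo-GGUF"):
        print(f"{name}: {size / 1e6:.1f} MB")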
def download_file_chunked(repo_id, filename, target_dir, progress_updater):
    """
    Download a single file by streaming its resolved URL in chunks.
    Supports resume by sending a Range header for any existing partial file.
    """
    local_path = os.path.join(target_dir, filename)
    tmp_path = local_path + ".part"
    os.makedirs(os.path.dirname(local_path), exist_ok=True)
    already = os.path.getsize(tmp_path) if os.path.exists(tmp_path) else 0
    # hf_hub_url only builds the resolve URL; it performs no network request
    url = hf_hub_url(repo_id, filename)
    headers = {"Range": f"bytes={already}-"} if already > 0 else {}
    with requests.get(url, headers=headers, stream=True, timeout=10) as r:
        r.raise_for_status()
        if already > 0 and r.status_code != 206:
            # server ignored the Range header; restart from scratch
            already = 0
        total = int(r.headers.get("Content-Length", 0)) + already
        with open(tmp_path, "ab" if already > 0 else "wb") as f:
            downloaded = already
            for chunk in r.iter_content(chunk_size=1024 * 256):
                if interrupt_event.is_set():
                    return False
                if not chunk:
                    continue
                f.write(chunk)
                downloaded += len(chunk)
                progress_updater(downloaded / max(total, 1))
    os.rename(tmp_path, local_path)
    return True
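# Illustrative, uncalled helper: fetching one file with a plain print-based
# progress callback. The filename below is a hypothetical example; any path
# reported by list_repo_files works.
def _example_single_file_download():
    ok = download_file_chunked(
        "Tongyi-MAI/Z-Image-Turbo",              # repo id from MODEL_SPECS
        "model_index.json",                      # hypothetical small file
        "./hf_cache/example",
        lambda frac: print(f"\r{frac * 100:5.1f}%", end=""),
    )
    print(" done" if ok else " interrupted")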
def parallel_download_repo(repo_id, progress: gr.Progress):
    """
    Download all files in the repo in parallel, one thread per file,
    reporting aggregate byte progress.
    """
    base_dir = os.path.join("./hf_cache", repo_id.replace("/", "_"))
    files = list_repo_files(repo_id)
    if not files:
        progress(1.0, desc="No files to download")
        return
    total_bytes = sum(sz for _, sz in files) or 1
    downloaded_bytes = 0

    def file_thread(filename, size):
        nonlocal downloaded_bytes
        success = download_file_chunked(
            repo_id, filename, base_dir,
            lambda frac: progress((downloaded_bytes + frac * size) / total_bytes,
                                  desc=f"{filename} {frac * 100:.1f}%")
        )
        if success:
            with download_lock:
                downloaded_bytes += size

    threads = []
    for fname, size in files:
        if interrupt_event.is_set():
            break
        # skip files that are already fully cached
        local_full = os.path.join(base_dir, fname)
        if os.path.exists(local_full) and os.path.getsize(local_full) == size:
            with download_lock:
                downloaded_bytes += size
            continue
        t = Thread(target=file_thread, args=(fname, size))
        t.start()
        threads.append(t)
    for t in threads:
        t.join()
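# parallel_download_repo only ever calls its progress object as
# progress(fraction, desc=...), so any callable with that signature can stand
# in for gr.Progress when testing outside the UI. A minimal stand-in sketch:
class _PrintProgress:
    def __call__(self, frac, desc=""):
        print(f"\r{desc} {frac * 100:5.1f}%", end="")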
# ===== PIPELINE LOADER =====
def load_pipeline(model_key):
    """
    Load the pipeline for the selected model, using the quantized GGUF
    transformer when a GGUF variant is selected.
    """
    if model_key in pipe_cache:
        return pipe_cache[model_key]
    repo = MODEL_SPECS[model_key]
    repo_cache = os.path.join("./hf_cache", repo.replace("/", "_"))
    # ensure the repo was fetched by the preload step
    if not os.path.isdir(repo_cache) or not os.listdir(repo_cache):
        raise gr.Error("Model not downloaded; press Preload first")
    if "GGUF" in model_key:
        # pick the first .gguf file in the cached repo
        files = [f for f in os.listdir(repo_cache) if f.endswith(".gguf")]
        if not files:
            raise gr.Error("Quantized file not found")
        gguf = os.path.join(repo_cache, files[0])
        transformer = ZImageTransformer2DModel.from_single_file(
            gguf,
            quantization_config=GGUFQuantizationConfig(compute_dtype=DTYPE),
            torch_dtype=DTYPE
        )
        # non-transformer components (VAE, text encoder, scheduler) come
        # from the base repo via the standard Hub cache
        pipe = ZImagePipeline.from_pretrained(
            "Tongyi-MAI/Z-Image-Turbo",
            transformer=transformer,
            torch_dtype=DTYPE,
            cache_dir="./hf_cache"
        )
    else:
        pipe = ZImagePipeline.from_pretrained(repo_cache, torch_dtype=DTYPE, local_files_only=True)
    pipe.to(DEVICE)
    pipe.vae.eval()
    pipe.text_encoder.eval()
    pipe.transformer.eval()
    pipe_cache[model_key] = pipe
    return pipe
@contextmanager
def managed_memory():
    try:
        yield
    finally:
        # two passes help reclaim objects held in reference cycles
        gc.collect()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# ===== GENERATION =====
@torch.inference_mode()
def generate(prompt, quality_mode, seed, model_key):
    if not prompt.strip():
        raise gr.Error("Prompt cannot be empty")
    # (steps, resolution) per quality preset
    PRESETS = {
        "ultra_fast": (1, 256),
        "fast": (1, 256),
        "balanced": (2, 256),
        "quality": (4, 384),
        "ultra_quality": (4, 512),
    }
    steps, size = PRESETS.get(quality_mode, (1, 256))
    width = height = size
    seed = int(seed) if seed >= 0 else random.randint(0, 2**31 - 1)
    log(f"Gen: {prompt[:30]} | {quality_mode} | {model_key} | seed={seed}")
    with managed_memory():
        pipe = load_pipeline(model_key)
        gen = torch.Generator("cpu").manual_seed(seed)
        previews = []
        start = time.time()

        def cb(ppl, step, timestep, cbk):
            if interrupt_event.is_set():
                ppl._interrupt = True
            if step % 2 == 0:
                # guarded so preview decoding failures never abort generation
                try:
                    previews.append(ppl.image_from_latents(cbk["latents"]))
                except Exception:
                    pass
            return cbk

        result = pipe(
            prompt=prompt,
            negative_prompt=None,
            width=width,
            height=height,
            num_inference_steps=steps,
            guidance_scale=0.0,
            generator=gen,
            callback_on_step_end=cb,
            callback_on_step_end_tensor_inputs=["latents"],
            output_type="pil"
        )
        final = result.images[0]
        previews.append(final)
        log(f"Generated in {time.time() - start:.1f}s")
        return final, seed, previews
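# The preview callback above calls ppl.image_from_latents, which is not part
# of the public diffusers pipeline API, so previews may silently stay empty.
# Below is a minimal, uncalled sketch of decoding latents through the
# pipeline's VAE instead, ASSUMING an SD-style VAE with a scaling_factor
# config entry; Z-Image's latent scaling may differ.
def _decode_latents_sketch(pipe, latents):
    from PIL import Image
    latents = latents / pipe.vae.config.scaling_factor   # assumed convention
    image = pipe.vae.decode(latents).sample              # (B, C, H, W) in [-1, 1]
    image = (image / 2 + 0.5).clamp(0, 1)
    arr = (image[0].permute(1, 2, 0).cpu().numpy() * 255).astype("uint8")
    return Image.fromarray(arr)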
# ===== GRADIO UI =====
with gr.Blocks(title="Z‑Image Turbo CPU Downloader + UI A‑Progress") as demo:
gr.Markdown("## True parallel download UI + chunked progress")
with gr.Row():
with gr.Column():
prompt = gr.Textbox(label="Prompt", lines=4)
quality = gr.Radio(["ultra_fast","fast","balanced","quality","ultra_quality"], value="fast")
seed = gr.Number(value=-1, precision=0, label="Seed")
model_select = gr.Dropdown(list(MODEL_SPECS.keys()), value=list(MODEL_SPECS.keys())[0], label="Model")
preload = gr.Button("PRELOAD MODELS")
gen_btn = gr.Button("GENERATE")
stop_btn = gr.Button("STOP")
with gr.Column():
out_image = gr.Image(label="Final")
used_seed = gr.Number(label="Seed Used")
preview = gr.Gallery(label="Preview Frames")
logs = gr.Textbox(label="Logs", lines=25)
def do_preload(progress=gr.Progress()):
interrupt_event.clear()
for key, repo in MODEL_SPECS.items():
parallel_download_repo(repo, progress)
return log("📦 Preload finished")
def do_gen(prompt, quality, seed, model_key):
interrupt_event.clear()
img, used, previews = generate(prompt, quality, seed, model_key)
return img, used, previews, log("🧠 Generation done")
def do_stop():
interrupt_event.set()
return log("🔴 Interrupt set")
preload.click(do_preload, outputs=logs)
gen_btn.click(do_gen, inputs=[prompt,quality,seed,model_select],
outputs=[out_image,used_seed,preview,logs])
stop_btn.click(do_stop, outputs=logs)
demo.queue()
demo.launch(server_name="0.0.0.0", server_port=7860)