File size: 5,556 Bytes
"""
HuggingFace Spaces entry point for diffviews.
This file is the main entry point for HF Spaces deployment.
It downloads required data and checkpoints on startup, then launches the Gradio app.
Requirements:
Python 3.10+
Gradio 6.0+
Environment variables:
DIFFVIEWS_DATA_DIR: Override data directory (default: data)
DIFFVIEWS_CHECKPOINT: Which checkpoint to download (dmd2, edm, all, none; default: dmd2)
DIFFVIEWS_DEVICE: Override device (cuda, mps, cpu; auto-detected if not set)
"""
import os
from pathlib import Path
# Data source configuration

# Hub repository id passed to snapshot_download (as a dataset repo).
DATA_REPO_ID = "mckell/diffviews_demo_data"

# Local file name used for each model's checkpoint.
CHECKPOINT_FILENAMES = {
    "dmd2": "dmd2-imagenet-64-10step.pkl",
    "edm": "edm-imagenet-64x64-cond-adm.pkl",
}

# Download URL per model: hosting prefix + the checkpoint's file name.
CHECKPOINT_URLS = {
    "dmd2": (
        "https://huggingface.co/mckell/diffviews-dmd2-checkpoint/resolve/main/"
        + CHECKPOINT_FILENAMES["dmd2"]
    ),
    "edm": (
        "https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/"
        + CHECKPOINT_FILENAMES["edm"]
    ),
}
def download_data(output_dir: Path) -> None:
    """Fetch the ``main`` revision of the demo dataset repo into ``output_dir``.

    Args:
        output_dir: Local directory handed to ``snapshot_download`` as
            ``local_dir``.
    """
    from huggingface_hub import snapshot_download

    print(f"Downloading data from {DATA_REPO_ID}...")
    print(f"Output directory: {output_dir.absolute()}")

    request = {
        "repo_id": DATA_REPO_ID,
        "repo_type": "dataset",
        "local_dir": output_dir,
        "revision": "main",
    }
    snapshot_download(**request)

    print(f"Data downloaded to {output_dir}")
def download_checkpoint(output_dir: Path, model: str) -> None:
    """Download the checkpoint for ``model`` unless it is already on disk.

    The file is stored at ``output_dir / model / "checkpoints" / <filename>``.
    Failures are reported on stdout and never raised, so the caller can
    continue without the checkpoint.

    Args:
        output_dir: Root data directory.
        model: Key into ``CHECKPOINT_URLS`` / ``CHECKPOINT_FILENAMES``.
            Unknown keys are reported and ignored.
    """
    import contextlib
    import urllib.request

    if model not in CHECKPOINT_URLS:
        print(f"Unknown model: {model}")
        return

    ckpt_dir = output_dir / model / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    filepath = ckpt_dir / CHECKPOINT_FILENAMES[model]
    if filepath.exists():
        print(f"Checkpoint exists: {filepath}")
        return

    url = CHECKPOINT_URLS[model]
    print(f"Downloading {model} checkpoint (~1GB)...")
    print(f" URL: {url}")

    # Download into a sibling ".part" file and rename it into place only once
    # the transfer has finished. Writing straight to `filepath` would leave a
    # truncated file behind after an interrupted download, and the
    # `filepath.exists()` check above would then accept it on the next start.
    partial = filepath.with_name(filepath.name + ".part")
    try:
        urllib.request.urlretrieve(url, partial)
        partial.replace(filepath)
        print(f" Done ({filepath.stat().st_size / 1e6:.1f} MB)")
    except Exception as e:
        # Deliberately broad: a missing checkpoint only disables generation,
        # it must not prevent the app from starting.
        print(f" Error downloading checkpoint: {e}")
        print(" Generation will be disabled without checkpoint")
    finally:
        # After a successful rename the partial file no longer exists; after a
        # failure or interruption this removes the leftover fragment.
        with contextlib.suppress(OSError):
            partial.unlink(missing_ok=True)
def ensure_data_ready(data_dir: Path, checkpoints: list) -> bool:
    """Make sure demo data and the requested checkpoints are on disk.

    A model counts as present when ``<data_dir>/<model>`` has a
    ``config.json``, an ``embeddings`` directory with at least one ``*.csv``
    and at least one ``images/imagenet_real/sample_*.png``. The dataset is
    downloaded only when no model qualifies.

    Args:
        data_dir: Root data directory.
        checkpoints: Model names forwarded one by one to
            ``download_checkpoint``.

    Returns:
        Always ``True``.
    """
    print(f"Checking for existing data in {data_dir.absolute()}...")

    ready_models = []
    for name in ("dmd2", "edm"):
        model_root = data_dir / name
        emb_dir = model_root / "embeddings"
        if not ((model_root / "config.json").exists() and emb_dir.exists()):
            continue

        real_images = model_root / "images" / "imagenet_real"
        csvs = [*emb_dir.glob("*.csv")]
        pngs = [*real_images.glob("sample_*.png")] if real_images.exists() else []
        if not csvs or not pngs:
            continue

        ready_models.append(name)
        print(f" Found {name}: {len(csvs)} csv, {len(pngs)} images")

    if ready_models:
        print(f"Data already present: {ready_models}")
    else:
        print("Data not found, downloading...")
        download_data(data_dir)

    # download_checkpoint skips files that already exist
    for ckpt in checkpoints:
        download_checkpoint(data_dir, ckpt)

    return True
def get_device() -> str:
    """Pick the device string to run on.

    A non-empty ``DIFFVIEWS_DEVICE`` environment variable is returned as-is.
    Otherwise torch is probed: ``"cuda"`` if available, then ``"mps"``,
    falling back to ``"cpu"``.
    """
    forced = os.environ.get("DIFFVIEWS_DEVICE")
    if forced:
        return forced

    # torch is imported only when no override is set
    import torch

    if torch.cuda.is_available():
        return "cuda"

    mps_ready = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    return "mps" if mps_ready else "cpu"
def main():
    """Run the Space: read settings, fetch data, build and launch the app.

    Settings come from ``DIFFVIEWS_DATA_DIR`` (default ``data``) and
    ``DIFFVIEWS_CHECKPOINT`` (default ``dmd2``; ``all``, ``none`` or a
    comma-separated list of model names). The device comes from
    ``get_device()``.
    """
    # Configuration from environment
    data_dir = Path(os.environ.get("DIFFVIEWS_DATA_DIR", "data"))
    requested = os.environ.get("DIFFVIEWS_CHECKPOINT", "dmd2")
    device = get_device()

    # Turn the checkpoint setting into a list of model names
    if requested == "none":
        checkpoints = []
    elif requested == "all":
        checkpoints = list(CHECKPOINT_URLS)
    else:
        checkpoints = [name for name in map(str.strip, requested.split(",")) if name]

    rule = "=" * 50
    print(rule)
    print("DiffViews - Diffusion Activation Visualizer")
    print(rule)
    print(f"Data directory: {data_dir.absolute()}")
    print(f"Device: {device}")
    print(f"Checkpoints: {checkpoints}")
    print(rule)

    # Data and checkpoints are handled before the UI packages are imported
    ensure_data_ready(data_dir, checkpoints)

    import gradio as gr
    from diffviews.visualization.app import (
        CUSTOM_CSS,
        PLOTLY_HANDLER_JS,
        GradioVisualizer,
        create_gradio_app,
    )

    print("\nInitializing visualizer...")
    visualizer = GradioVisualizer(data_dir=data_dir, device=device)

    print("Creating Gradio app...")
    app = create_gradio_app(visualizer)

    print("Launching...")
    queued_app = app.queue(max_size=20)
    # HF Spaces expects server on 0.0.0.0:7860
    queued_app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,  # Spaces handles public URL
        theme=gr.themes.Soft(),
        css=CUSTOM_CSS,
        js=PLOTLY_HANDLER_JS,
    )
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|