"""
HuggingFace Spaces entry point for diffviews.

This file is the main entry point for HF Spaces deployment.
It downloads required data and checkpoints on startup, then launches the Gradio app.

Requirements:
    Python 3.10+
    Gradio 6.0+

Environment variables:
    DIFFVIEWS_DATA_DIR: Override data directory (default: data)
    DIFFVIEWS_CHECKPOINT: Which checkpoint(s) to download (dmd2, edm, all, none,
        or a comma-separated list; default: all)
    DIFFVIEWS_DEVICE: Override device (cuda, mps, cpu; auto-detected if not set)
    DIFFVIEWS_BRANCH: Git branch of diffviews to clone
        (default: diffviews-gradio6-HFz-CFr2)
    DIFFVIEWS_PCA_COMPONENTS: PCA pre-reduction components used when regenerating
        UMAP (default: 50; 0, none, off or empty disables it)
"""
import os
import subprocess
import sys
from pathlib import Path
# Install diffviews from git to bypass pip cache issues
_REPO_URL = "https://github.com/mckellcarter/diffviews.git"
_REPO_BRANCH = os.environ.get("DIFFVIEWS_BRANCH", "diffviews-gradio6-HFz-CFr2")
_REPO_DIR = "/tmp/diffviews"

# Remove stale pip-installed version so our clone takes priority.
# pip is invoked through the running interpreter (`sys.executable -m pip`) so
# the uninstall targets this environment rather than whichever `pip`
# executable happens to come first on PATH.
# Best-effort: the exit status is ignored on purpose, since the package may
# simply not be installed.
subprocess.run(
    [sys.executable, "-m", "pip", "uninstall", "-y", "diffviews"],
    capture_output=True,
    check=False,
)

# Purge any cached diffviews modules so the next import resolves against the
# clone placed on sys.path below. Iterate over a copy because the dict is
# mutated inside the loop.
for mod in list(sys.modules):
    if mod == "diffviews" or mod.startswith("diffviews."):
        del sys.modules[mod]

# Clone only when the target directory is absent; an existing directory is
# reused as-is (it is not refreshed or switched to another branch).
if not os.path.exists(_REPO_DIR):
    print(f"Cloning diffviews from {_REPO_BRANCH}...")
    subprocess.run(
        ["git", "clone", "--depth=1", "-b", _REPO_BRANCH, _REPO_URL, _REPO_DIR],
        check=True,
    )

# Put the clone first on the import path so it is found before any other copy.
sys.path.insert(0, _REPO_DIR)
import spaces
# Data source configuration
DATA_REPO_ID = "mckell/diffviews_demo_data"

# Direct-download URL of each supported model's checkpoint.
CHECKPOINT_URLS = dict(
    dmd2=(
        "https://huggingface.co/mckell/diffviews-dmd2-checkpoint/"
        "resolve/main/dmd2-imagenet-64-10step.pkl"
    ),
    edm=(
        "https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/"
        "edm-imagenet-64x64-cond-adm.pkl"
    ),
)

# Local filename of each checkpoint: the last path segment of its URL.
CHECKPOINT_FILENAMES = {
    model_key: ckpt_url.rsplit("/", 1)[-1]
    for model_key, ckpt_url in CHECKPOINT_URLS.items()
}
def download_data_r2(output_dir: Path) -> bool:
    """Try to fetch the demo data from Cloudflare R2.

    Returns False without downloading anything when the R2 store reports
    itself as not enabled. Otherwise requests the data of each model
    ("dmd2", then "edm") into ``output_dir`` and returns True.
    """
    from diffviews.data.r2_cache import R2DataStore

    r2_store = R2DataStore()
    if r2_store.enabled:
        print("Downloading data from R2...")
        for model_name in ("dmd2", "edm"):
            r2_store.download_model_data(model_name, output_dir)
        return True
    return False
def download_data_hf(output_dir: Path) -> None:
    """Fallback path: snapshot the ``DATA_REPO_ID`` dataset repo into ``output_dir``."""
    from huggingface_hub import snapshot_download

    print(f"Downloading data from {DATA_REPO_ID} (HF fallback)...")
    snapshot_options = {
        "repo_id": DATA_REPO_ID,
        "repo_type": "dataset",
        "local_dir": output_dir,
        "revision": "main",
    }
    snapshot_download(**snapshot_options)
    print(f"Data downloaded to {output_dir}")
def download_data(output_dir: Path) -> None:
    """Fetch the demo data into ``output_dir``: R2 is tried first, HuggingFace Hub second."""
    print(f"Output directory: {output_dir.absolute()}")
    fetched_from_r2 = download_data_r2(output_dir)
    if fetched_from_r2:
        return
    download_data_hf(output_dir)
def download_checkpoint(output_dir: Path, model: str) -> None:
    """Download a model checkpoint: R2 first, direct URL as fallback.

    The checkpoint is stored at ``output_dir/<model>/checkpoints/<filename>``.
    Nothing is downloaded when that file already exists. Failures of the URL
    download are reported on stdout and swallowed (best-effort), so the app
    can still start without the checkpoint.

    Args:
        output_dir: Root data directory.
        model: Model key; must be a key of ``CHECKPOINT_URLS``, otherwise the
            call prints a message and returns.
    """
    if model not in CHECKPOINT_URLS:
        print(f"Unknown model: {model}")
        return

    ckpt_dir = output_dir / model / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    filename = CHECKPOINT_FILENAMES[model]
    filepath = ckpt_dir / filename

    if filepath.exists():
        print(f"Checkpoint exists: {filepath}")
        return

    # Try R2 first
    from diffviews.data.r2_cache import R2DataStore

    store = R2DataStore()
    r2_key = f"data/{model}/checkpoints/{filename}"
    if store.enabled and store.download_file(r2_key, filepath):
        print(f"Checkpoint downloaded from R2: {filepath} ({filepath.stat().st_size / 1e6:.1f} MB)")
        return

    # Fallback to direct URL
    import urllib.request

    url = CHECKPOINT_URLS[model]
    print(f"Downloading {model} checkpoint from URL (~1GB)...")
    print(f" URL: {url}")

    # Download to a sibling ".part" file and rename only on success. Writing
    # straight to `filepath` would leave a truncated file behind when the
    # transfer fails, and the `filepath.exists()` check above would then
    # treat that corrupt file as a valid checkpoint on the next startup.
    partial_path = filepath.with_name(filepath.name + ".part")
    try:
        urllib.request.urlretrieve(url, partial_path)
        partial_path.replace(filepath)
        print(f" Done ({filepath.stat().st_size / 1e6:.1f} MB)")
    except Exception as e:
        # Deliberate best-effort: report and continue without the checkpoint,
        # but do not leave the incomplete download on disk.
        partial_path.unlink(missing_ok=True)
        print(f" Error downloading checkpoint: {e}")
        print(" Generation will be disabled without checkpoint")
def get_pca_components() -> int | None:
"""Read PCA pre-reduction setting from env. None = disabled."""
val = os.environ.get("DIFFVIEWS_PCA_COMPONENTS", "50")
if val.lower() in ("0", "none", "off", ""):
return None
return int(val)
def regenerate_umap(data_dir: Path, model: str) -> bool:
    """Recompute a model's UMAP from its activations and re-save the result.

    Intended to keep the saved UMAP pickle usable when the runtime
    environment (numba) differs from the one that produced the original.

    Inputs are read from ``data_dir/<model>``: activations under
    ``activations/imagenet_real``, metadata from
    ``metadata/imagenet_real/dataset_info.json``, and the first ``*.csv``
    found in ``embeddings``. UMAP parameters come from the JSON file sitting
    next to that CSV when it exists, otherwise from built-in defaults.

    Returns:
        True when the embeddings were recomputed and saved; False when the
        inputs are missing or any step of the recomputation raised.
    """
    import json

    from diffviews.processing.umap import (
        compute_umap,
        load_dataset_activations,
        save_embeddings,
    )

    model_root = data_dir / model
    activation_dir = model_root / "activations" / "imagenet_real"
    metadata_path = model_root / "metadata" / "imagenet_real" / "dataset_info.json"
    embeddings_dir = model_root / "embeddings"

    # Both the activations and their metadata are required.
    if not (activation_dir.exists() and metadata_path.exists()):
        print(f" Skipping UMAP regeneration for {model}: missing activations")
        return False

    # The existing embeddings CSV anchors the sibling .json / .pkl paths.
    csv_path = next(embeddings_dir.glob("*.csv"), None)
    if csv_path is None:
        print(f" Skipping UMAP regeneration for {model}: no embeddings CSV")
        return False

    json_path = csv_path.with_suffix(".json")
    pkl_path = csv_path.with_suffix(".pkl")

    # Parameters stored alongside the CSV replace the defaults wholesale.
    if json_path.exists():
        with open(json_path, "r") as params_file:
            umap_params = json.load(params_file)
    else:
        umap_params = {
            "n_neighbors": 15,
            "min_dist": 0.1,
            "layers": ["encoder_bottleneck", "midblock"],
        }

    pca_components = get_pca_components()
    n_neighbors = umap_params.get("n_neighbors", 15)
    min_dist = umap_params.get("min_dist", 0.1)

    print(f" Regenerating UMAP for {model}...")
    print(f" Params: n_neighbors={n_neighbors}, min_dist={min_dist}, pca={pca_components}")

    try:
        activations, metadata_df = load_dataset_activations(activation_dir, metadata_path)
        print(f" Loaded {activations.shape[0]} activations ({activations.shape[1]} dims)")

        # UMAP fit, with optional PCA pre-reduction.
        embeddings, reducer, scaler, pca_reducer = compute_umap(
            activations,
            n_neighbors=n_neighbors,
            min_dist=min_dist,
            normalize=True,
            pca_components=pca_components,
        )

        # Saving overwrites the existing pickle with one built in this environment.
        save_embeddings(
            embeddings, metadata_df, csv_path, umap_params, reducer, scaler, pca_reducer
        )
    except Exception as e:
        # Best-effort: report the failure and let startup continue.
        print(f" Error regenerating UMAP: {e}")
        return False

    print(f" UMAP pickle regenerated: {pkl_path}")
    return True
def check_umap_compatibility(data_dir: Path, model: str) -> bool:
    """Probe whether a model's saved UMAP pickle works in this environment.

    Loads the first ``*.pkl`` found in ``data_dir/<model>/embeddings`` and
    pushes one random sample through the stored scaler (if any) and reducer.

    Returns:
        True when there is no pickle, when the pickle holds no reducer, or
        when the probe transform succeeds; False when loading or transforming
        raises.
    """
    pkl_path = next((data_dir / model / "embeddings").glob("*.pkl"), None)
    if pkl_path is None:
        return True  # No pickle to check

    try:
        import pickle

        # NOTE: pickle.load executes arbitrary code from the file; only safe
        # for pickles coming from a trusted source.
        with open(pkl_path, "rb") as pkl_file:
            umap_data = pickle.load(pkl_file)

        reducer = umap_data.get("reducer")
        if reducer is None:
            return True

        import numpy as np

        # NOTE(review): the probe is hard-coded to 100 features; presumably a
        # scaler/reducer fitted on a different width rejects it and the check
        # reports "incompatible" regardless of numba -- confirm.
        probe = np.random.randn(1, 100).astype(np.float32)

        scaler = umap_data.get("scaler")
        probe_scaled = scaler.transform(probe) if scaler else probe

        # The reducer transform is the call that exercises the numba JIT.
        reducer.transform(probe_scaled)
        return True
    except Exception as e:
        print(f" UMAP compatibility check failed for {model}: {e}")
        return False
def ensure_data_ready(data_dir: Path, checkpoints: list) -> bool:
    """Make sure data, checkpoints and UMAP pickles are in place.

    Steps, in order:
      1. Detect which models already have data (config.json, at least one
         embeddings CSV and at least one ``sample_*.png`` image). The data
         download runs only when no model qualifies.
      2. Request the checkpoint of every model named in ``checkpoints``.
      3. Regenerate the UMAP of each model that has an embeddings CSV.

    Always returns True.
    """
    known_models = ("dmd2", "edm")

    print(f"Checking for existing data in {data_dir.absolute()}...")

    # Check which models have data (config + embeddings + images)
    models_with_data = []
    for model in known_models:
        model_root = data_dir / model
        embeddings_dir = model_root / "embeddings"
        images_dir = model_root / "images" / "imagenet_real"

        if not ((model_root / "config.json").exists() and embeddings_dir.exists()):
            continue

        csv_files = list(embeddings_dir.glob("*.csv"))
        png_files = []
        if images_dir.exists():
            png_files = list(images_dir.glob("sample_*.png"))

        if csv_files and png_files:
            models_with_data.append(model)
            print(f" Found {model}: {len(csv_files)} csv, {len(png_files)} images")

    if models_with_data:
        print(f"Data already present: {models_with_data}")
    else:
        print("Data not found, downloading...")
        download_data(data_dir)

    # Download checkpoints only if not present
    for model in checkpoints:
        download_checkpoint(data_dir, model)

    # Regenerate UMAP for all models to ensure numba compatibility
    # This is fast enough to do on every startup and avoids compatibility issues
    print("\nRegenerating UMAP pickles for numba compatibility...")
    for model in known_models:
        model_root = data_dir / model
        if not model_root.exists():
            print(f" {model}: model dir not found, skipping")
            continue

        embeddings_dir = model_root / "embeddings"
        has_csv = embeddings_dir.exists() and any(embeddings_dir.glob("*.csv"))
        if not has_csv:
            print(f" {model}: no embeddings found, skipping")
            continue

        print(f" {model}: regenerating UMAP...")
        regenerate_umap(data_dir, model)

    return True
def get_device() -> str:
    """Choose the compute device.

    A non-empty ``DIFFVIEWS_DEVICE`` environment variable is returned as-is
    (torch is not imported in that case). Otherwise the choice is "cuda" when
    available, then "mps" when available, then "cpu".
    """
    forced_device = os.environ.get("DIFFVIEWS_DEVICE")
    if forced_device:
        return forced_device

    import torch

    if torch.cuda.is_available():
        return "cuda"

    # Not every torch build exposes torch.backends.mps, hence the hasattr guard.
    mps_ready = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    return "mps" if mps_ready else "cpu"
@spaces.GPU(duration=120)
def generate_on_gpu(
    model_name, all_neighbors, class_label,
    n_steps, m_steps, s_max, s_min, guidance, noise_mode,
    extract_layers, can_project
):
    """Run masked generation on GPU. Must live in app_file for ZeroGPU detection.

    Uses the module-level visualizer from ``diffviews.visualization.app`` and
    holds its generation lock for the whole call. Activations prepared from
    ``all_neighbors`` are installed as masks on the model adapter, the
    multistep generator is run for one sample, and the hooks are removed
    again whether or not generation succeeds.

    Returns:
        The generator's result, or None when the adapter cannot be loaded or
        no activation dict can be prepared.
    """
    from diffviews.visualization.app import _app_visualizer as visualizer
    from diffviews.core.masking import ActivationMasker
    from diffviews.core.generator import generate_with_mask_multistep

    with visualizer._generation_lock:
        adapter = visualizer.load_adapter(model_name)
        if adapter is None:
            return None

        layer_activations = visualizer.prepare_activation_dict(model_name, all_neighbors)
        if layer_activations is None:
            return None

        masker = ActivationMasker(adapter)
        for masked_layer, masked_activation in layer_activations.items():
            masker.set_mask(masked_layer, masked_activation)
        masker.register_hooks(list(layer_activations))

        try:
            # Argument conversion stays inside the try so that a bad value
            # still goes through the hook cleanup below.
            generation_options = dict(
                class_label=class_label,
                num_steps=int(n_steps),
                mask_steps=int(m_steps),
                sigma_max=float(s_max),
                sigma_min=float(s_min),
                guidance_scale=float(guidance),
                # UI labels end in " noise"; that suffix is dropped here.
                noise_mode=(noise_mode or "stochastic noise").replace(" noise", ""),
                num_samples=1,
                device=visualizer.device,
                # Layer extraction and trajectory are requested only when
                # projection is possible.
                extract_layers=extract_layers if can_project else None,
                return_trajectory=can_project,
                return_intermediates=True,
                return_noised_inputs=True,
            )
            return generate_with_mask_multistep(adapter, masker, **generation_options)
        finally:
            masker.remove_hooks()
@spaces.GPU(duration=180)
def extract_layer_on_gpu(model_name, layer_name, batch_size=32):
    """Extract layer activations on GPU. Must live in app_file for ZeroGPU detection.

    Delegates to ``extract_layer_activations`` on the module-level visualizer
    of ``diffviews.visualization.app`` and returns its result unchanged.
    """
    import diffviews.visualization.app as viz_app

    shared_visualizer = viz_app._app_visualizer
    return shared_visualizer.extract_layer_activations(model_name, layer_name, batch_size)
def _setup():
    """Prepare data, build the visualizer and return the queued Gradio app.

    Reads ``DIFFVIEWS_DATA_DIR`` (default "data") and ``DIFFVIEWS_CHECKPOINT``
    (default "all"; "none" selects no checkpoint; anything else is split on
    commas), makes sure data and checkpoints are ready, wires the GPU-decorated
    functions of this file into the visualization module, and creates the app
    with a request queue of size 20.
    """
    data_dir = Path(os.environ.get("DIFFVIEWS_DATA_DIR", "data"))
    checkpoint_config = os.environ.get("DIFFVIEWS_CHECKPOINT", "all")
    device = get_device()

    if checkpoint_config == "none":
        checkpoints = []
    elif checkpoint_config == "all":
        checkpoints = list(CHECKPOINT_URLS)
    else:
        stripped_names = (part.strip() for part in checkpoint_config.split(","))
        checkpoints = [name for name in stripped_names if name]

    banner_rule = "=" * 50
    print(banner_rule)
    print("DiffViews - Diffusion Activation Visualizer")
    print(banner_rule)
    print(f"Data directory: {data_dir.absolute()}")
    print(f"Device: {device}")
    print(f"Checkpoints: {checkpoints}")
    print(banner_rule)

    ensure_data_ready(data_dir, checkpoints)

    import diffviews.visualization.app as viz_mod
    from diffviews.visualization.app import (
        GradioVisualizer,
        create_gradio_app,
    )

    # Inject ZeroGPU-decorated functions into visualization module
    # so Gradio callbacks use the versions codefind can detect
    viz_mod._generate_on_gpu = generate_on_gpu
    viz_mod._extract_layer_on_gpu = extract_layer_on_gpu

    print("\nInitializing visualizer...")
    visualizer = GradioVisualizer(data_dir=data_dir, device=device)

    print("Creating Gradio app...")
    gradio_app = create_gradio_app(visualizer)
    gradio_app.queue(max_size=20)
    return gradio_app
# Built at import time on purpose: Gradio hot-reload imports this module
# without running the __main__ branch, and it looks for the app under the
# name `demo`.
demo = _setup()

if __name__ == "__main__":
    import gradio as gr
    from diffviews.visualization.app import CUSTOM_CSS, PLOTLY_HANDLER_JS

    # Listen on all interfaces, port 7860, without a public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "theme": gr.themes.Soft(),
        "css": CUSTOM_CSS,
        "js": PLOTLY_HANDLER_JS,
    }
    demo.launch(**launch_options)
|