""" HuggingFace Spaces entry point for diffviews. This file is the main entry point for HF Spaces deployment. It downloads required data and checkpoints on startup, then launches the Gradio app. Requirements: Python 3.10+ Gradio 6.0+ Environment variables: DIFFVIEWS_DATA_DIR: Override data directory (default: data) DIFFVIEWS_CHECKPOINT: Which checkpoint to download (dmd2, edm, all, none; default: all) DIFFVIEWS_DEVICE: Override device (cuda, mps, cpu; auto-detected if not set) """ import os import subprocess import sys from pathlib import Path # Install diffviews from git to bypass pip cache issues _REPO_URL = "https://github.com/mckellcarter/diffviews.git" _REPO_BRANCH = os.environ.get("DIFFVIEWS_BRANCH", "diffviews-gradio6-HFz-CFr2") _REPO_DIR = "/tmp/diffviews" # Remove stale pip-installed version so our clone takes priority subprocess.run(["pip", "uninstall", "-y", "diffviews"], capture_output=True) # Purge any cached diffviews modules for mod in list(sys.modules): if mod == "diffviews" or mod.startswith("diffviews."): del sys.modules[mod] if not os.path.exists(_REPO_DIR): print(f"Cloning diffviews from {_REPO_BRANCH}...") subprocess.run( ["git", "clone", "--depth=1", "-b", _REPO_BRANCH, _REPO_URL, _REPO_DIR], check=True, ) sys.path.insert(0, _REPO_DIR) import spaces # Data source configuration DATA_REPO_ID = "mckell/diffviews_demo_data" CHECKPOINT_URLS = { "dmd2": ( "https://huggingface.co/mckell/diffviews-dmd2-checkpoint/" "resolve/main/dmd2-imagenet-64-10step.pkl" ), "edm": ( "https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/" "edm-imagenet-64x64-cond-adm.pkl" ), } CHECKPOINT_FILENAMES = { "dmd2": "dmd2-imagenet-64-10step.pkl", "edm": "edm-imagenet-64x64-cond-adm.pkl", } def download_data_r2(output_dir: Path) -> bool: """Download data from Cloudflare R2. Returns True on success.""" from diffviews.data.r2_cache import R2DataStore store = R2DataStore() if not store.enabled: return False print(f"Downloading data from R2...") for model in ["dmd2", "edm"]: store.download_model_data(model, output_dir) return True def download_data_hf(output_dir: Path) -> None: """Fallback: download data from HuggingFace Hub.""" from huggingface_hub import snapshot_download print(f"Downloading data from {DATA_REPO_ID} (HF fallback)...") snapshot_download( repo_id=DATA_REPO_ID, repo_type="dataset", local_dir=output_dir, revision="main", ) print(f"Data downloaded to {output_dir}") def download_data(output_dir: Path) -> None: """Download data: R2 first, HF fallback.""" print(f"Output directory: {output_dir.absolute()}") if not download_data_r2(output_dir): download_data_hf(output_dir) def download_checkpoint(output_dir: Path, model: str) -> None: """Download model checkpoint: R2 first, URL fallback.""" if model not in CHECKPOINT_URLS: print(f"Unknown model: {model}") return ckpt_dir = output_dir / model / "checkpoints" ckpt_dir.mkdir(parents=True, exist_ok=True) filename = CHECKPOINT_FILENAMES[model] filepath = ckpt_dir / filename if filepath.exists(): print(f"Checkpoint exists: {filepath}") return # Try R2 first from diffviews.data.r2_cache import R2DataStore store = R2DataStore() r2_key = f"data/{model}/checkpoints/{filename}" if store.enabled and store.download_file(r2_key, filepath): print(f"Checkpoint downloaded from R2: {filepath} ({filepath.stat().st_size / 1e6:.1f} MB)") return # Fallback to direct URL import urllib.request url = CHECKPOINT_URLS[model] print(f"Downloading {model} checkpoint from URL (~1GB)...") print(f" URL: {url}") try: urllib.request.urlretrieve(url, filepath) print(f" Done ({filepath.stat().st_size / 1e6:.1f} MB)") except Exception as e: print(f" Error downloading checkpoint: {e}") print(" Generation will be disabled without checkpoint") def get_pca_components() -> int | None: """Read PCA pre-reduction setting from env. None = disabled.""" val = os.environ.get("DIFFVIEWS_PCA_COMPONENTS", "50") if val.lower() in ("0", "none", "off", ""): return None return int(val) def regenerate_umap(data_dir: Path, model: str) -> bool: """Regenerate UMAP pickle for a model to ensure numba compatibility. This recomputes UMAP from activations and saves new pickle file. Required when running on different environment than original pickle was created. """ from diffviews.processing.umap import ( load_dataset_activations, compute_umap, save_embeddings, ) import json model_dir = data_dir / model activation_dir = model_dir / "activations" / "imagenet_real" metadata_path = model_dir / "metadata" / "imagenet_real" / "dataset_info.json" embeddings_dir = model_dir / "embeddings" # Check if activations exist if not activation_dir.exists() or not metadata_path.exists(): print(f" Skipping UMAP regeneration for {model}: missing activations") return False # Find existing embeddings CSV to get parameters csv_files = list(embeddings_dir.glob("*.csv")) if not csv_files: print(f" Skipping UMAP regeneration for {model}: no embeddings CSV") return False csv_path = csv_files[0] json_path = csv_path.with_suffix(".json") pkl_path = csv_path.with_suffix(".pkl") # Load UMAP params from existing JSON umap_params = {"n_neighbors": 15, "min_dist": 0.1, "layers": ["encoder_bottleneck", "midblock"]} if json_path.exists(): with open(json_path, "r") as f: umap_params = json.load(f) pca_components = get_pca_components() print(f" Regenerating UMAP for {model}...") print(f" Params: n_neighbors={umap_params.get('n_neighbors', 15)}, min_dist={umap_params.get('min_dist', 0.1)}, pca={pca_components}") try: # Load activations activations, metadata_df = load_dataset_activations(activation_dir, metadata_path) print(f" Loaded {activations.shape[0]} activations ({activations.shape[1]} dims)") # Compute UMAP (with optional PCA pre-reduction) embeddings, reducer, scaler, pca_reducer = compute_umap( activations, n_neighbors=umap_params.get("n_neighbors", 15), min_dist=umap_params.get("min_dist", 0.1), normalize=True, pca_components=pca_components, ) # Save (overwrites existing pickle with compatible version) save_embeddings(embeddings, metadata_df, csv_path, umap_params, reducer, scaler, pca_reducer) print(f" UMAP pickle regenerated: {pkl_path}") return True except Exception as e: print(f" Error regenerating UMAP: {e}") return False def check_umap_compatibility(data_dir: Path, model: str) -> bool: """Check if UMAP pickle is compatible with current numba environment.""" embeddings_dir = data_dir / model / "embeddings" pkl_files = list(embeddings_dir.glob("*.pkl")) if not pkl_files: return True # No pickle to check pkl_path = pkl_files[0] try: import pickle with open(pkl_path, "rb") as f: umap_data = pickle.load(f) reducer = umap_data.get("reducer") if reducer is None: return True # Try a dummy transform to check numba compatibility import numpy as np dummy = np.random.randn(1, 100).astype(np.float32) # This will fail if numba JIT is incompatible scaler = umap_data.get("scaler") if scaler: dummy_scaled = scaler.transform(dummy) else: dummy_scaled = dummy # The actual transform - this triggers numba JIT _ = reducer.transform(dummy_scaled) return True except Exception as e: print(f" UMAP compatibility check failed for {model}: {e}") return False def ensure_data_ready(data_dir: Path, checkpoints: list) -> bool: """Ensure data and checkpoints are downloaded.""" print(f"Checking for existing data in {data_dir.absolute()}...") # Check which models have data (config + embeddings + images) models_with_data = [] for model in ["dmd2", "edm"]: config_path = data_dir / model / "config.json" embeddings_dir = data_dir / model / "embeddings" images_dir = data_dir / model / "images" / "imagenet_real" if not config_path.exists(): continue if not embeddings_dir.exists(): continue csv_files = list(embeddings_dir.glob("*.csv")) png_files = list(images_dir.glob("sample_*.png")) if images_dir.exists() else [] if csv_files and png_files: models_with_data.append(model) print(f" Found {model}: {len(csv_files)} csv, {len(png_files)} images") if not models_with_data: print("Data not found, downloading...") download_data(data_dir) else: print(f"Data already present: {models_with_data}") # Download checkpoints only if not present for model in checkpoints: download_checkpoint(data_dir, model) # Regenerate UMAP for all models to ensure numba compatibility # This is fast enough to do on every startup and avoids compatibility issues print("\nRegenerating UMAP pickles for numba compatibility...") for model in ["dmd2", "edm"]: model_dir = data_dir / model if not model_dir.exists(): print(f" {model}: model dir not found, skipping") continue embeddings_dir = model_dir / "embeddings" if not embeddings_dir.exists() or not list(embeddings_dir.glob("*.csv")): print(f" {model}: no embeddings found, skipping") continue print(f" {model}: regenerating UMAP...") regenerate_umap(data_dir, model) return True def get_device() -> str: """Auto-detect best available device.""" override = os.environ.get("DIFFVIEWS_DEVICE") if override: return override import torch if torch.cuda.is_available(): return "cuda" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): return "mps" return "cpu" @spaces.GPU(duration=120) def generate_on_gpu( model_name, all_neighbors, class_label, n_steps, m_steps, s_max, s_min, guidance, noise_mode, extract_layers, can_project ): """Run masked generation on GPU. Must live in app_file for ZeroGPU detection.""" from diffviews.visualization.app import _app_visualizer as visualizer from diffviews.core.masking import ActivationMasker from diffviews.core.generator import generate_with_mask_multistep with visualizer._generation_lock: adapter = visualizer.load_adapter(model_name) if adapter is None: return None activation_dict = visualizer.prepare_activation_dict(model_name, all_neighbors) if activation_dict is None: return None masker = ActivationMasker(adapter) for layer_name, activation in activation_dict.items(): masker.set_mask(layer_name, activation) masker.register_hooks(list(activation_dict.keys())) try: result = generate_with_mask_multistep( adapter, masker, class_label=class_label, num_steps=int(n_steps), mask_steps=int(m_steps), sigma_max=float(s_max), sigma_min=float(s_min), guidance_scale=float(guidance), noise_mode=(noise_mode or "stochastic noise").replace(" noise", ""), num_samples=1, device=visualizer.device, extract_layers=extract_layers if can_project else None, return_trajectory=can_project, return_intermediates=True, return_noised_inputs=True, ) finally: masker.remove_hooks() return result @spaces.GPU(duration=180) def extract_layer_on_gpu(model_name, layer_name, batch_size=32): """Extract layer activations on GPU. Must live in app_file for ZeroGPU detection.""" from diffviews.visualization.app import _app_visualizer as visualizer return visualizer.extract_layer_activations(model_name, layer_name, batch_size) def _setup(): """Initialize data, visualizer, and Gradio app.""" data_dir = Path(os.environ.get("DIFFVIEWS_DATA_DIR", "data")) checkpoint_config = os.environ.get("DIFFVIEWS_CHECKPOINT", "all") device = get_device() if checkpoint_config == "all": checkpoints = list(CHECKPOINT_URLS.keys()) elif checkpoint_config == "none": checkpoints = [] else: checkpoints = [c.strip() for c in checkpoint_config.split(",") if c.strip()] print("=" * 50) print("DiffViews - Diffusion Activation Visualizer") print("=" * 50) print(f"Data directory: {data_dir.absolute()}") print(f"Device: {device}") print(f"Checkpoints: {checkpoints}") print("=" * 50) ensure_data_ready(data_dir, checkpoints) import diffviews.visualization.app as viz_mod from diffviews.visualization.app import ( GradioVisualizer, create_gradio_app, ) # Inject ZeroGPU-decorated functions into visualization module # so Gradio callbacks use the versions codefind can detect viz_mod._generate_on_gpu = generate_on_gpu viz_mod._extract_layer_on_gpu = extract_layer_on_gpu print("\nInitializing visualizer...") visualizer = GradioVisualizer(data_dir=data_dir, device=device) print("Creating Gradio app...") app = create_gradio_app(visualizer) app.queue(max_size=20) return app # Module-level setup so Gradio hot-reload (which imports but doesn't call main) # still initializes everything and finds the app as `demo`. demo = _setup() if __name__ == "__main__": import gradio as gr from diffviews.visualization.app import CUSTOM_CSS, PLOTLY_HANDLER_JS demo.launch( server_name="0.0.0.0", server_port=7860, share=False, theme=gr.themes.Soft(), css=CUSTOM_CSS, js=PLOTLY_HANDLER_JS, )