"""
HuggingFace Spaces entry point for diffviews.

This file is the main entry point for HF Spaces deployment.
It downloads required data and checkpoints on startup, then launches the Gradio app.

Requirements:
    Python 3.10+
    Gradio 6.0+

Environment variables:
    DIFFVIEWS_DATA_DIR: Override data directory (default: data)
    DIFFVIEWS_CHECKPOINT: Which checkpoint(s) to download (dmd2, edm, all, none,
        or a comma-separated list of model names; default: dmd2)
    DIFFVIEWS_DEVICE: Override device (cuda, mps, cpu; auto-detected if not set)
"""
| |
|
| | import os |
| | from pathlib import Path |
| |
|
| | |
# HuggingFace Hub dataset repository holding the demo data.
DATA_REPO_ID = "mckell/diffviews_demo_data"

# File name each checkpoint is stored under, keyed by model name.
CHECKPOINT_FILENAMES = {
    "dmd2": "dmd2-imagenet-64-10step.pkl",
    "edm": "edm-imagenet-64x64-cond-adm.pkl",
}

# Download URL for each checkpoint, keyed by model name. Key order matters:
# it is the order used when every checkpoint is requested.
_DMD2_BASE_URL = "https://huggingface.co/mckell/diffviews-dmd2-checkpoint/resolve/main/"
_EDM_BASE_URL = "https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/"
CHECKPOINT_URLS = {
    "dmd2": _DMD2_BASE_URL + CHECKPOINT_FILENAMES["dmd2"],
    "edm": _EDM_BASE_URL + CHECKPOINT_FILENAMES["edm"],
}
| |
|
| |
|
def download_data(output_dir: Path) -> None:
    """Fetch the demo dataset snapshot from the HuggingFace Hub.

    Args:
        output_dir: Directory the snapshot of ``DATA_REPO_ID`` is written to.
    """
    # Imported lazily so the module can be loaded without huggingface_hub.
    from huggingface_hub import snapshot_download

    print(f"Downloading data from {DATA_REPO_ID}...")
    print(f"Output directory: {output_dir.absolute()}")

    request = {
        "repo_id": DATA_REPO_ID,
        "repo_type": "dataset",
        "local_dir": output_dir,
        "revision": "main",
    }
    snapshot_download(**request)

    print(f"Data downloaded to {output_dir}")
| |
|
| |
|
def download_checkpoint(output_dir: Path, model: str) -> None:
    """Download the checkpoint for ``model`` unless it is already on disk.

    The file is stored at ``output_dir / model / "checkpoints" / <filename>``.
    Download errors are reported on stdout and swallowed, so a failed
    download does not stop the caller.

    The download is written to a ``.part`` sibling file and only renamed to
    the final name once it has completed. Without this, a failed or
    interrupted transfer would leave a truncated file at the final path, and
    the "checkpoint exists" shortcut below would accept it on the next start.

    Args:
        output_dir: Root data directory.
        model: Key into ``CHECKPOINT_URLS`` / ``CHECKPOINT_FILENAMES``;
            unknown names are reported and ignored.
    """
    import urllib.request

    if model not in CHECKPOINT_URLS:
        print(f"Unknown model: {model}")
        return

    ckpt_dir = output_dir / model / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    filename = CHECKPOINT_FILENAMES[model]
    filepath = ckpt_dir / filename

    if filepath.exists():
        print(f"Checkpoint exists: {filepath}")
        return

    url = CHECKPOINT_URLS[model]
    print(f"Downloading {model} checkpoint (~1GB)...")
    print(f" URL: {url}")

    # Temporary target; a stale one from an earlier interrupted run is
    # simply overwritten by urlretrieve.
    partial_path = filepath.with_name(filepath.name + ".part")

    try:
        urllib.request.urlretrieve(url, partial_path)
        # Publish the file under its final name only after a full download.
        partial_path.replace(filepath)
        print(f" Done ({filepath.stat().st_size / 1e6:.1f} MB)")
    except Exception as e:
        # Best-effort cleanup: never leave a truncated download behind, but
        # do not let a cleanup failure mask the original error report.
        try:
            partial_path.unlink(missing_ok=True)
        except OSError:
            pass
        print(f" Error downloading checkpoint: {e}")
        print(" Generation will be disabled without checkpoint")
| |
|
| |
|
def ensure_data_ready(data_dir: Path, checkpoints: list) -> bool:
    """Make sure demo data and the requested checkpoints are on disk.

    A model counts as having data when ``data_dir / <model>`` contains a
    ``config.json``, an ``embeddings`` directory with at least one ``*.csv``
    file, and an ``images/imagenet_real`` directory with at least one
    ``sample_*.png`` file. The dataset is downloaded only when no model
    qualifies.

    Args:
        data_dir: Root data directory.
        checkpoints: Model names whose checkpoints should be downloaded.

    Returns:
        Always ``True``.
    """
    print(f"Checking for existing data in {data_dir.absolute()}...")

    ready = []
    for name in ("dmd2", "edm"):
        model_root = data_dir / name
        emb_dir = model_root / "embeddings"
        if not ((model_root / "config.json").exists() and emb_dir.exists()):
            continue

        img_dir = model_root / "images" / "imagenet_real"
        n_csv = len(list(emb_dir.glob("*.csv")))
        n_png = len(list(img_dir.glob("sample_*.png"))) if img_dir.exists() else 0
        if n_csv == 0 or n_png == 0:
            continue

        ready.append(name)
        print(f" Found {name}: {n_csv} csv, {n_png} images")

    if ready:
        print(f"Data already present: {ready}")
    else:
        print("Data not found, downloading...")
        download_data(data_dir)

    # Checkpoints are handled separately from the dataset snapshot.
    for name in checkpoints:
        download_checkpoint(data_dir, name)

    return True
| |
|
| |
|
def get_device() -> str:
    """Pick the compute device to run on.

    A non-empty ``DIFFVIEWS_DEVICE`` environment variable is returned as-is.
    Otherwise the choice falls through "cuda", then "mps", then "cpu",
    depending on what torch reports as available.
    """
    forced = os.environ.get("DIFFVIEWS_DEVICE")
    if forced:
        return forced

    # Imported lazily: not needed when the device is forced via environment.
    import torch

    if not torch.cuda.is_available():
        mps_usable = (
            hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        )
        return "mps" if mps_usable else "cpu"
    return "cuda"
| |
|
| |
|
def main():
    """Prepare data and checkpoints, then launch the Gradio app.

    Reads ``DIFFVIEWS_DATA_DIR`` (default ``data``) and
    ``DIFFVIEWS_CHECKPOINT`` (default ``dmd2``) from the environment. The
    checkpoint setting may be ``all``, ``none``, or a comma-separated list
    of model names.
    """
    data_dir = Path(os.environ.get("DIFFVIEWS_DATA_DIR", "data"))
    checkpoint_config = os.environ.get("DIFFVIEWS_CHECKPOINT", "dmd2")
    device = get_device()

    # Turn the checkpoint setting into a list of model names.
    if checkpoint_config == "none":
        checkpoints = []
    elif checkpoint_config == "all":
        checkpoints = list(CHECKPOINT_URLS.keys())
    else:
        checkpoints = [
            part for part in map(str.strip, checkpoint_config.split(",")) if part
        ]

    rule = "=" * 50
    banner = (
        rule,
        "DiffViews - Diffusion Activation Visualizer",
        rule,
        f"Data directory: {data_dir.absolute()}",
        f"Device: {device}",
        f"Checkpoints: {checkpoints}",
        rule,
    )
    for line in banner:
        print(line)

    ensure_data_ready(data_dir, checkpoints)

    # Imported only after the data step, as in the original ordering.
    import gradio as gr
    from diffviews.visualization.app import (
        CUSTOM_CSS,
        PLOTLY_HANDLER_JS,
        GradioVisualizer,
        create_gradio_app,
    )

    print("\nInitializing visualizer...")
    visualizer = GradioVisualizer(data_dir=data_dir, device=device)

    print("Creating Gradio app...")
    app = create_gradio_app(visualizer)

    print("Launching...")
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "theme": gr.themes.Soft(),
        "css": CUSTOM_CSS,
        "js": PLOTLY_HANDLER_JS,
    }
    queued_app = app.queue(max_size=20)
    queued_app.launch(**launch_options)
| |
|
| |
|
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
| |
|