mckell commited on
Commit
0b04352
·
verified ·
1 Parent(s): b8b99da

Initial seed upload

Browse files
Files changed (3) hide show
  1. SPACES_README.md +30 -0
  2. app.py +164 -0
  3. requirements.txt +23 -0
SPACES_README.md ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: DiffViews
3
+ emoji: 🔬
4
+ colorFrom: purple
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 4.44.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ # DiffViews - Diffusion Activation Visualizer
14
+
15
+ Interactive visualization of diffusion model activations projected to 2D via UMAP.
16
+
17
+ ## Features
18
+ - Explore activation space of diffusion models
19
+ - Select points and find nearest neighbors
20
+ - Generate images from averaged neighbor activations
21
+ - Visualize denoising trajectories
22
+
23
+ ## Usage
24
+ 1. Hover over points to preview samples
25
+ 2. Click to select a point
26
+ 3. Click nearby points or use "Suggest KNN" to add neighbors
27
+ 4. Click "Generate from Neighbors" to create new images
28
+
29
+ ## Note
30
+ First launch downloads ~2.5GB of data and checkpoints. Generation on CPU takes ~30-60s per image.
app.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HuggingFace Spaces entry point for diffviews.
3
+
4
+ This file is the main entry point for HF Spaces deployment.
5
+ It downloads required data and checkpoints on startup, then launches the Gradio app.
6
+
7
+ Environment variables:
8
+ DIFFVIEWS_DATA_DIR: Override data directory (default: data)
9
+ DIFFVIEWS_CHECKPOINT: Which checkpoint to download (dmd2, edm, all, none; default: dmd2)
10
+ DIFFVIEWS_DEVICE: Override device (cuda, mps, cpu; auto-detected if not set)
11
+ """
12
+
13
+ import os
14
+ from pathlib import Path
15
+
16
+ # Data source configuration
17
+ DATA_REPO_ID = "mckell/diffviews_demo_data"
18
+ CHECKPOINT_URLS = {
19
+ "dmd2": (
20
+ "https://huggingface.co/mckell/diffviews-dmd2-checkpoint/"
21
+ "resolve/main/dmd2-imagenet-64-10step.pkl"
22
+ ),
23
+ "edm": (
24
+ "https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/"
25
+ "edm-imagenet-64x64-cond-adm.pkl"
26
+ ),
27
+ }
28
+ CHECKPOINT_FILENAMES = {
29
+ "dmd2": "dmd2-imagenet-64-10step.pkl",
30
+ "edm": "edm-imagenet-64x64-cond-adm.pkl",
31
+ }
32
+
33
+
34
+ def download_data(output_dir: Path) -> None:
35
+ """Download data from HuggingFace Hub."""
36
+ from huggingface_hub import snapshot_download
37
+
38
+ print(f"Downloading data from {DATA_REPO_ID}...")
39
+ print(f"Output directory: {output_dir.absolute()}")
40
+
41
+ snapshot_download(
42
+ repo_id=DATA_REPO_ID,
43
+ repo_type="dataset",
44
+ local_dir=output_dir,
45
+ revision="main",
46
+ )
47
+ print(f"Data downloaded to {output_dir}")
48
+
49
+
50
+ def download_checkpoint(output_dir: Path, model: str) -> None:
51
+ """Download model checkpoint."""
52
+ import urllib.request
53
+
54
+ if model not in CHECKPOINT_URLS:
55
+ print(f"Unknown model: {model}")
56
+ return
57
+
58
+ ckpt_dir = output_dir / model / "checkpoints"
59
+ ckpt_dir.mkdir(parents=True, exist_ok=True)
60
+
61
+ filename = CHECKPOINT_FILENAMES[model]
62
+ filepath = ckpt_dir / filename
63
+
64
+ if filepath.exists():
65
+ print(f"Checkpoint exists: {filepath}")
66
+ return
67
+
68
+ url = CHECKPOINT_URLS[model]
69
+ print(f"Downloading {model} checkpoint (~1GB)...")
70
+ print(f" URL: {url}")
71
+
72
+ try:
73
+ urllib.request.urlretrieve(url, filepath)
74
+ print(f" Done ({filepath.stat().st_size / 1e6:.1f} MB)")
75
+ except Exception as e:
76
+ print(f" Error downloading checkpoint: {e}")
77
+ print(" Generation will be disabled without checkpoint")
78
+
79
+
80
+ def ensure_data_ready(data_dir: Path, checkpoints: list) -> bool:
81
+ """Ensure data and checkpoints are downloaded."""
82
+ # Check if data exists (look for config files)
83
+ has_data = any(
84
+ (data_dir / model / "config.json").exists()
85
+ for model in ["dmd2", "edm"]
86
+ )
87
+
88
+ if not has_data:
89
+ print("Data not found, downloading...")
90
+ download_data(data_dir)
91
+ else:
92
+ print(f"Data found in {data_dir}")
93
+
94
+ # Download checkpoints
95
+ for model in checkpoints:
96
+ download_checkpoint(data_dir, model)
97
+
98
+ return True
99
+
100
+
101
+ def get_device() -> str:
102
+ """Auto-detect best available device."""
103
+ override = os.environ.get("DIFFVIEWS_DEVICE")
104
+ if override:
105
+ return override
106
+
107
+ import torch
108
+
109
+ if torch.cuda.is_available():
110
+ return "cuda"
111
+ if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
112
+ return "mps"
113
+ return "cpu"
114
+
115
+
116
+ def main():
117
+ """Main entry point for HF Spaces."""
118
+ # Configuration from environment
119
+ data_dir = Path(os.environ.get("DIFFVIEWS_DATA_DIR", "data"))
120
+ checkpoint_config = os.environ.get("DIFFVIEWS_CHECKPOINT", "dmd2")
121
+ device = get_device()
122
+
123
+ # Parse checkpoint config
124
+ if checkpoint_config == "all":
125
+ checkpoints = list(CHECKPOINT_URLS.keys())
126
+ elif checkpoint_config == "none":
127
+ checkpoints = []
128
+ else:
129
+ checkpoints = [c.strip() for c in checkpoint_config.split(",") if c.strip()]
130
+
131
+ print("=" * 50)
132
+ print("DiffViews - Diffusion Activation Visualizer")
133
+ print("=" * 50)
134
+ print(f"Data directory: {data_dir.absolute()}")
135
+ print(f"Device: {device}")
136
+ print(f"Checkpoints: {checkpoints}")
137
+ print("=" * 50)
138
+
139
+ # Ensure data is ready
140
+ ensure_data_ready(data_dir, checkpoints)
141
+
142
+ # Import and launch visualizer
143
+ from diffviews.visualization.app import GradioVisualizer, create_gradio_app
144
+
145
+ print("\nInitializing visualizer...")
146
+ visualizer = GradioVisualizer(
147
+ data_dir=data_dir,
148
+ device=device,
149
+ )
150
+
151
+ print("Creating Gradio app...")
152
+ app = create_gradio_app(visualizer)
153
+
154
+ print("Launching...")
155
+ # HF Spaces expects server on 0.0.0.0:7860
156
+ app.queue(max_size=20).launch(
157
+ server_name="0.0.0.0",
158
+ server_port=7860,
159
+ share=False, # Spaces handles public URL
160
+ )
161
+
162
+
163
+ if __name__ == "__main__":
164
+ main()
requirements.txt ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DiffViews - HuggingFace Spaces Requirements
2
+ # Install diffviews package from GitHub
3
+ git+https://github.com/mckellcarter/diffviews.git
4
+
5
+ # Core dependencies
6
+ torch>=2.0.0
7
+ numpy>=1.21.0
8
+ pandas>=1.5.0
9
+ pillow>=9.0.0
10
+ scikit-learn>=1.0.0
11
+ umap-learn>=0.5.0
12
+ tqdm>=4.60.0
13
+
14
+ # Visualization
15
+ gradio>=4.0.0
16
+ plotly>=5.18.0
17
+ matplotlib>=3.5.0
18
+
19
+ # HuggingFace Hub for data download
20
+ huggingface_hub>=0.19.0
21
+
22
+ # Optional but useful
23
+ scipy>=1.7.0