File size: 5,556 Bytes
0b04352
 
 
 
 
 
857d986
 
 
 
0b04352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb2867a
 
0ec7bae
eb2867a
 
 
 
0ec7bae
 
 
 
 
 
 
 
 
 
 
 
 
eb2867a
 
0b04352
 
 
eb2867a
0b04352
eb2867a
0b04352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
857d986
 
 
 
 
 
 
0b04352
f208a6d
0b04352
 
 
 
 
f208a6d
0b04352
 
 
 
 
 
 
 
857d986
 
 
0b04352
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
"""
HuggingFace Spaces entry point for diffviews.

This file is the main entry point for HF Spaces deployment.
It downloads required data and checkpoints on startup, then launches the Gradio app.

Requirements:
    Python 3.10+
    Gradio 6.0+

Environment variables:
    DIFFVIEWS_DATA_DIR: Override data directory (default: data)
    DIFFVIEWS_CHECKPOINT: Which checkpoint(s) to download (dmd2, edm, a
        comma-separated list such as "dmd2,edm", all, or none; default: dmd2)
    DIFFVIEWS_DEVICE: Override device (cuda, mps, cpu; auto-detected if not set)
"""

import os
from pathlib import Path

# Where the demo data and model checkpoints are fetched from.
# DATA_REPO_ID is a HuggingFace Hub dataset repository; CHECKPOINT_URLS maps a
# model name to the direct download URL of its checkpoint pickle.
DATA_REPO_ID = "mckell/diffviews_demo_data"
CHECKPOINT_URLS = {
    "dmd2": "https://huggingface.co/mckell/diffviews-dmd2-checkpoint"
    "/resolve/main/dmd2-imagenet-64-10step.pkl",
    "edm": "https://nvlabs-fi-cdn.nvidia.com/edm/pretrained"
    "/edm-imagenet-64x64-cond-adm.pkl",
}
# Local file name for each checkpoint: the last path segment of its URL.
CHECKPOINT_FILENAMES = {
    model_name: model_url.rsplit("/", 1)[-1]
    for model_name, model_url in CHECKPOINT_URLS.items()
}


def download_data(output_dir: Path) -> None:
    """Fetch the demo dataset snapshot from the HuggingFace Hub.

    Args:
        output_dir: Directory passed to ``snapshot_download`` as ``local_dir``;
            the ``main`` revision of the ``DATA_REPO_ID`` dataset repository
            is requested.
    """
    # Imported lazily so the module can be imported without huggingface_hub.
    from huggingface_hub import snapshot_download

    print(f"Downloading data from {DATA_REPO_ID}...")
    print(f"Output directory: {output_dir.absolute()}")

    hub_request = {
        "repo_id": DATA_REPO_ID,
        "repo_type": "dataset",
        "local_dir": output_dir,
        "revision": "main",
    }
    snapshot_download(**hub_request)

    print(f"Data downloaded to {output_dir}")


def download_checkpoint(output_dir: Path, model: str) -> None:
    """Download the checkpoint file for ``model`` if it is not already present.

    The file is stored at ``output_dir / model / "checkpoints" / <filename>``,
    where the file name comes from ``CHECKPOINT_FILENAMES``. The download is
    written to a temporary ``.part`` file next to the target and renamed into
    place only after it completes, so an interrupted transfer never leaves a
    truncated file that a later run would mistake for a valid checkpoint.

    Failures are reported on stdout and do not raise: the caller continues
    without the checkpoint.

    Args:
        output_dir: Root data directory.
        model: Key into ``CHECKPOINT_URLS``; unknown keys are reported and
            skipped.
    """
    import urllib.request

    if model not in CHECKPOINT_URLS:
        print(f"Unknown model: {model}")
        return

    ckpt_dir = output_dir / model / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    filename = CHECKPOINT_FILENAMES[model]
    filepath = ckpt_dir / filename

    if filepath.exists():
        print(f"Checkpoint exists: {filepath}")
        return

    url = CHECKPOINT_URLS[model]
    print(f"Downloading {model} checkpoint (~1GB)...")
    print(f"  URL: {url}")

    # Download next to the final location so the rename stays on one filesystem.
    partial_path = filepath.with_name(filepath.name + ".part")

    try:
        urllib.request.urlretrieve(url, partial_path)
        # Publish the file only once it has been fully written.
        partial_path.replace(filepath)
        print(f"  Done ({filepath.stat().st_size / 1e6:.1f} MB)")
    except Exception as e:
        # Deliberately broad: a missing checkpoint is a degraded mode, not a
        # fatal error, so every download failure is reported and swallowed.
        try:
            partial_path.unlink(missing_ok=True)
        except OSError as cleanup_error:
            print(f"  Could not remove partial download: {cleanup_error}")
        print(f"  Error downloading checkpoint: {e}")
        print("  Generation will be disabled without checkpoint")


def ensure_data_ready(data_dir: Path, checkpoints: list) -> bool:
    """Make sure visualization data and the requested checkpoints exist.

    A model ("dmd2" or "edm") counts as having data when its ``config.json``
    and ``embeddings`` directory exist, the embeddings directory holds at
    least one ``*.csv`` file, and ``images/imagenet_real`` holds at least one
    ``sample_*.png`` file. The dataset is downloaded only when no model
    qualifies.

    Args:
        data_dir: Root data directory to inspect and download into.
        checkpoints: Model names whose checkpoints should be fetched.

    Returns:
        Always ``True``.
    """
    print(f"Checking for existing data in {data_dir.absolute()}...")

    ready_models = []
    for name in ("dmd2", "edm"):
        model_root = data_dir / name
        emb_dir = model_root / "embeddings"

        # Config and embeddings directory are both mandatory.
        if not ((model_root / "config.json").exists() and emb_dir.exists()):
            continue

        real_images = model_root / "images" / "imagenet_real"
        csvs = list(emb_dir.glob("*.csv"))
        pngs = []
        if real_images.exists():
            pngs = list(real_images.glob("sample_*.png"))

        if not (csvs and pngs):
            continue

        ready_models.append(name)
        print(f"  Found {name}: {len(csvs)} csv, {len(pngs)} images")

    if ready_models:
        print(f"Data already present: {ready_models}")
    else:
        print("Data not found, downloading...")
        download_data(data_dir)

    # download_checkpoint skips files that already exist.
    for name in checkpoints:
        download_checkpoint(data_dir, name)

    return True


def get_device() -> str:
    """Pick the compute device name.

    A non-empty ``DIFFVIEWS_DEVICE`` environment variable is returned as-is.
    Otherwise torch is queried: "cuda" if CUDA is available, else "mps" if
    the MPS backend exists and is available, else "cpu".
    """
    forced = os.environ.get("DIFFVIEWS_DEVICE")
    if forced:
        return forced

    # Imported only when auto-detection is actually needed.
    import torch

    if torch.cuda.is_available():
        return "cuda"

    # The hasattr guard covers torch builds without an MPS backend module.
    has_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    return "mps" if has_mps else "cpu"


def main():
    """Run the HF Spaces entry point: prepare data, then launch the app.

    Reads ``DIFFVIEWS_DATA_DIR`` and ``DIFFVIEWS_CHECKPOINT`` from the
    environment, resolves the device, makes sure data and checkpoints are in
    place, builds the visualizer and serves the Gradio app.
    """
    # Settings come from the environment, with defaults.
    data_dir = Path(os.environ.get("DIFFVIEWS_DATA_DIR", "data"))
    requested = os.environ.get("DIFFVIEWS_CHECKPOINT", "dmd2")
    device = get_device()

    # "none" and "all" are keywords; anything else is a comma-separated list.
    if requested == "none":
        checkpoints = []
    elif requested == "all":
        checkpoints = list(CHECKPOINT_URLS.keys())
    else:
        stripped = (piece.strip() for piece in requested.split(","))
        checkpoints = [piece for piece in stripped if piece]

    rule = "=" * 50
    banner = (
        rule,
        "DiffViews - Diffusion Activation Visualizer",
        rule,
        f"Data directory: {data_dir.absolute()}",
        f"Device: {device}",
        f"Checkpoints: {checkpoints}",
        rule,
    )
    for line in banner:
        print(line)

    # Download anything that is missing before the app starts.
    ensure_data_ready(data_dir, checkpoints)

    # Heavy imports are deferred until the data is in place.
    import gradio as gr
    from diffviews.visualization.app import (
        GradioVisualizer,
        create_gradio_app,
        CUSTOM_CSS,
        PLOTLY_HANDLER_JS,
    )

    print("\nInitializing visualizer...")
    visualizer = GradioVisualizer(data_dir=data_dir, device=device)

    print("Creating Gradio app...")
    app = create_gradio_app(visualizer)

    print("Launching...")
    queued_app = app.queue(max_size=20)
    # HF Spaces expects server on 0.0.0.0:7860
    queued_app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,  # Spaces handles public URL
        theme=gr.themes.Soft(),
        css=CUSTOM_CSS,
        js=PLOTLY_HANDLER_JS,
    )


# Run the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()