Upload app.py
Browse files
app.py
CHANGED
|
@@ -10,7 +10,7 @@ Requirements:
|
|
| 10 |
|
| 11 |
Environment variables:
|
| 12 |
DIFFVIEWS_DATA_DIR: Override data directory (default: data)
|
| 13 |
-
DIFFVIEWS_CHECKPOINT: Which checkpoint to download (dmd2, edm, all, none; default:
|
| 14 |
DIFFVIEWS_DEVICE: Override device (cuda, mps, cpu; auto-detected if not set)
|
| 15 |
"""
|
| 16 |
|
|
@@ -218,18 +218,22 @@ def ensure_data_ready(data_dir: Path, checkpoints: list) -> bool:
|
|
| 218 |
for model in checkpoints:
|
| 219 |
download_checkpoint(data_dir, model)
|
| 220 |
|
| 221 |
-
#
|
| 222 |
-
|
|
|
|
| 223 |
for model in ["dmd2", "edm"]:
|
| 224 |
model_dir = data_dir / model
|
| 225 |
if not model_dir.exists():
|
|
|
|
| 226 |
continue
|
| 227 |
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
|
|
|
|
|
|
| 233 |
|
| 234 |
return True
|
| 235 |
|
|
@@ -253,7 +257,7 @@ def main():
|
|
| 253 |
"""Main entry point for HF Spaces."""
|
| 254 |
# Configuration from environment
|
| 255 |
data_dir = Path(os.environ.get("DIFFVIEWS_DATA_DIR", "data"))
|
| 256 |
-
checkpoint_config = os.environ.get("DIFFVIEWS_CHECKPOINT", "
|
| 257 |
device = get_device()
|
| 258 |
|
| 259 |
# Parse checkpoint config
|
|
|
|
| 10 |
|
| 11 |
Environment variables:
|
| 12 |
DIFFVIEWS_DATA_DIR: Override data directory (default: data)
|
| 13 |
+
DIFFVIEWS_CHECKPOINT: Which checkpoint to download (dmd2, edm, all, none; default: all)
|
| 14 |
DIFFVIEWS_DEVICE: Override device (cuda, mps, cpu; auto-detected if not set)
|
| 15 |
"""
|
| 16 |
|
|
|
|
| 218 |
for model in checkpoints:
|
| 219 |
download_checkpoint(data_dir, model)
|
| 220 |
|
| 221 |
+
# Regenerate UMAP for all models to ensure numba compatibility
|
| 222 |
+
# This is fast enough to do on every startup and avoids compatibility issues
|
| 223 |
+
print("\nRegenerating UMAP pickles for numba compatibility...")
|
| 224 |
for model in ["dmd2", "edm"]:
|
| 225 |
model_dir = data_dir / model
|
| 226 |
if not model_dir.exists():
|
| 227 |
+
print(f" {model}: model dir not found, skipping")
|
| 228 |
continue
|
| 229 |
|
| 230 |
+
embeddings_dir = model_dir / "embeddings"
|
| 231 |
+
if not embeddings_dir.exists() or not list(embeddings_dir.glob("*.csv")):
|
| 232 |
+
print(f" {model}: no embeddings found, skipping")
|
| 233 |
+
continue
|
| 234 |
+
|
| 235 |
+
print(f" {model}: regenerating UMAP...")
|
| 236 |
+
regenerate_umap(data_dir, model)
|
| 237 |
|
| 238 |
return True
|
| 239 |
|
|
|
|
| 257 |
"""Main entry point for HF Spaces."""
|
| 258 |
# Configuration from environment
|
| 259 |
data_dir = Path(os.environ.get("DIFFVIEWS_DATA_DIR", "data"))
|
| 260 |
+
checkpoint_config = os.environ.get("DIFFVIEWS_CHECKPOINT", "all") # Download all by default
|
| 261 |
device = get_device()
|
| 262 |
|
| 263 |
# Parse checkpoint config
|