mckell commited on
Commit
cc18481
·
verified ·
1 Parent(s): 73553fd

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -9
app.py CHANGED
@@ -10,7 +10,7 @@ Requirements:
10
 
11
  Environment variables:
12
  DIFFVIEWS_DATA_DIR: Override data directory (default: data)
13
- DIFFVIEWS_CHECKPOINT: Which checkpoint to download (dmd2, edm, all, none; default: dmd2)
14
  DIFFVIEWS_DEVICE: Override device (cuda, mps, cpu; auto-detected if not set)
15
  """
16
 
@@ -218,18 +218,22 @@ def ensure_data_ready(data_dir: Path, checkpoints: list) -> bool:
218
  for model in checkpoints:
219
  download_checkpoint(data_dir, model)
220
 
221
- # Check UMAP compatibility and regenerate if needed
222
- print("\nChecking UMAP compatibility...")
 
223
  for model in ["dmd2", "edm"]:
224
  model_dir = data_dir / model
225
  if not model_dir.exists():
 
226
  continue
227
 
228
- if not check_umap_compatibility(data_dir, model):
229
- print(f" {model}: UMAP incompatible, regenerating...")
230
- regenerate_umap(data_dir, model)
231
- else:
232
- print(f" {model}: UMAP compatible")
 
 
233
 
234
  return True
235
 
@@ -253,7 +257,7 @@ def main():
253
  """Main entry point for HF Spaces."""
254
  # Configuration from environment
255
  data_dir = Path(os.environ.get("DIFFVIEWS_DATA_DIR", "data"))
256
- checkpoint_config = os.environ.get("DIFFVIEWS_CHECKPOINT", "dmd2")
257
  device = get_device()
258
 
259
  # Parse checkpoint config
 
10
 
11
  Environment variables:
12
  DIFFVIEWS_DATA_DIR: Override data directory (default: data)
13
+ DIFFVIEWS_CHECKPOINT: Which checkpoint to download (dmd2, edm, all, none; default: all)
14
  DIFFVIEWS_DEVICE: Override device (cuda, mps, cpu; auto-detected if not set)
15
  """
16
 
 
218
  for model in checkpoints:
219
  download_checkpoint(data_dir, model)
220
 
221
+ # Regenerate UMAP for all models to ensure numba compatibility
222
+ # This is fast enough to do on every startup and avoids compatibility issues
223
+ print("\nRegenerating UMAP pickles for numba compatibility...")
224
  for model in ["dmd2", "edm"]:
225
  model_dir = data_dir / model
226
  if not model_dir.exists():
227
+ print(f" {model}: model dir not found, skipping")
228
  continue
229
 
230
+ embeddings_dir = model_dir / "embeddings"
231
+ if not embeddings_dir.exists() or not list(embeddings_dir.glob("*.csv")):
232
+ print(f" {model}: no embeddings found, skipping")
233
+ continue
234
+
235
+ print(f" {model}: regenerating UMAP...")
236
+ regenerate_umap(data_dir, model)
237
 
238
  return True
239
 
 
257
  """Main entry point for HF Spaces."""
258
  # Configuration from environment
259
  data_dir = Path(os.environ.get("DIFFVIEWS_DATA_DIR", "data"))
260
+ checkpoint_config = os.environ.get("DIFFVIEWS_CHECKPOINT", "all") # Download all by default
261
  device = get_device()
262
 
263
  # Parse checkpoint config