mckell commited on
Commit
b66efbf
·
verified ·
1 Parent(s): 2d32c0f

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -9
app.py CHANGED
@@ -59,13 +59,25 @@ CHECKPOINT_FILENAMES = {
59
  }
60
 
61
 
62
- def download_data(output_dir: Path) -> None:
63
- """Download data from HuggingFace Hub."""
64
- from huggingface_hub import snapshot_download
65
 
66
- print(f"Downloading data from {DATA_REPO_ID}...")
67
- print(f"Output directory: {output_dir.absolute()}")
 
 
 
 
 
 
 
 
 
 
 
68
 
 
69
  snapshot_download(
70
  repo_id=DATA_REPO_ID,
71
  repo_type="dataset",
@@ -75,10 +87,15 @@ def download_data(output_dir: Path) -> None:
75
  print(f"Data downloaded to {output_dir}")
76
 
77
 
78
- def download_checkpoint(output_dir: Path, model: str) -> None:
79
- """Download model checkpoint."""
80
- import urllib.request
 
 
 
81
 
 
 
82
  if model not in CHECKPOINT_URLS:
83
  print(f"Unknown model: {model}")
84
  return
@@ -93,8 +110,18 @@ def download_checkpoint(output_dir: Path, model: str) -> None:
93
  print(f"Checkpoint exists: {filepath}")
94
  return
95
 
 
 
 
 
 
 
 
 
 
 
96
  url = CHECKPOINT_URLS[model]
97
- print(f"Downloading {model} checkpoint (~1GB)...")
98
  print(f" URL: {url}")
99
 
100
  try:
 
59
  }
60
 
61
 
62
+ def download_data_r2(output_dir: Path) -> bool:
63
+ """Download data from Cloudflare R2. Returns True on success."""
64
+ from diffviews.data.r2_cache import R2DataStore
65
 
66
+ store = R2DataStore()
67
+ if not store.enabled:
68
+ return False
69
+
70
+ print(f"Downloading data from R2...")
71
+ for model in ["dmd2", "edm"]:
72
+ store.download_model_data(model, output_dir)
73
+ return True
74
+
75
+
76
+ def download_data_hf(output_dir: Path) -> None:
77
+ """Fallback: download data from HuggingFace Hub."""
78
+ from huggingface_hub import snapshot_download
79
 
80
+ print(f"Downloading data from {DATA_REPO_ID} (HF fallback)...")
81
  snapshot_download(
82
  repo_id=DATA_REPO_ID,
83
  repo_type="dataset",
 
87
  print(f"Data downloaded to {output_dir}")
88
 
89
 
90
+ def download_data(output_dir: Path) -> None:
91
+ """Download data: R2 first, HF fallback."""
92
+ print(f"Output directory: {output_dir.absolute()}")
93
+ if not download_data_r2(output_dir):
94
+ download_data_hf(output_dir)
95
+
96
 
97
+ def download_checkpoint(output_dir: Path, model: str) -> None:
98
+ """Download model checkpoint: R2 first, URL fallback."""
99
  if model not in CHECKPOINT_URLS:
100
  print(f"Unknown model: {model}")
101
  return
 
110
  print(f"Checkpoint exists: {filepath}")
111
  return
112
 
113
+ # Try R2 first
114
+ from diffviews.data.r2_cache import R2DataStore
115
+ store = R2DataStore()
116
+ r2_key = f"data/{model}/checkpoints/{filename}"
117
+ if store.enabled and store.download_file(r2_key, filepath):
118
+ print(f"Checkpoint downloaded from R2: {filepath} ({filepath.stat().st_size / 1e6:.1f} MB)")
119
+ return
120
+
121
+ # Fallback to direct URL
122
+ import urllib.request
123
  url = CHECKPOINT_URLS[model]
124
+ print(f"Downloading {model} checkpoint from URL (~1GB)...")
125
  print(f" URL: {url}")
126
 
127
  try: