Upload app.py
Browse files
app.py
CHANGED
|
@@ -59,13 +59,25 @@ CHECKPOINT_FILENAMES = {
|
|
| 59 |
}
|
| 60 |
|
| 61 |
|
| 62 |
-
def
|
| 63 |
-
"""Download data from
|
| 64 |
-
from
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
|
|
|
| 69 |
snapshot_download(
|
| 70 |
repo_id=DATA_REPO_ID,
|
| 71 |
repo_type="dataset",
|
|
@@ -75,10 +87,15 @@ def download_data(output_dir: Path) -> None:
|
|
| 75 |
print(f"Data downloaded to {output_dir}")
|
| 76 |
|
| 77 |
|
| 78 |
-
def
|
| 79 |
-
"""Download
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
| 81 |
|
|
|
|
|
|
|
| 82 |
if model not in CHECKPOINT_URLS:
|
| 83 |
print(f"Unknown model: {model}")
|
| 84 |
return
|
|
@@ -93,8 +110,18 @@ def download_checkpoint(output_dir: Path, model: str) -> None:
|
|
| 93 |
print(f"Checkpoint exists: {filepath}")
|
| 94 |
return
|
| 95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
url = CHECKPOINT_URLS[model]
|
| 97 |
-
print(f"Downloading {model} checkpoint (~1GB)...")
|
| 98 |
print(f" URL: {url}")
|
| 99 |
|
| 100 |
try:
|
|
|
|
| 59 |
}
|
| 60 |
|
| 61 |
|
| 62 |
+
def download_data_r2(output_dir: Path) -> bool:
|
| 63 |
+
"""Download data from Cloudflare R2. Returns True on success."""
|
| 64 |
+
from diffviews.data.r2_cache import R2DataStore
|
| 65 |
|
| 66 |
+
store = R2DataStore()
|
| 67 |
+
if not store.enabled:
|
| 68 |
+
return False
|
| 69 |
+
|
| 70 |
+
print(f"Downloading data from R2...")
|
| 71 |
+
for model in ["dmd2", "edm"]:
|
| 72 |
+
store.download_model_data(model, output_dir)
|
| 73 |
+
return True
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def download_data_hf(output_dir: Path) -> None:
|
| 77 |
+
"""Fallback: download data from HuggingFace Hub."""
|
| 78 |
+
from huggingface_hub import snapshot_download
|
| 79 |
|
| 80 |
+
print(f"Downloading data from {DATA_REPO_ID} (HF fallback)...")
|
| 81 |
snapshot_download(
|
| 82 |
repo_id=DATA_REPO_ID,
|
| 83 |
repo_type="dataset",
|
|
|
|
| 87 |
print(f"Data downloaded to {output_dir}")
|
| 88 |
|
| 89 |
|
| 90 |
+
def download_data(output_dir: Path) -> None:
|
| 91 |
+
"""Download data: R2 first, HF fallback."""
|
| 92 |
+
print(f"Output directory: {output_dir.absolute()}")
|
| 93 |
+
if not download_data_r2(output_dir):
|
| 94 |
+
download_data_hf(output_dir)
|
| 95 |
+
|
| 96 |
|
| 97 |
+
def download_checkpoint(output_dir: Path, model: str) -> None:
|
| 98 |
+
"""Download model checkpoint: R2 first, URL fallback."""
|
| 99 |
if model not in CHECKPOINT_URLS:
|
| 100 |
print(f"Unknown model: {model}")
|
| 101 |
return
|
|
|
|
| 110 |
print(f"Checkpoint exists: {filepath}")
|
| 111 |
return
|
| 112 |
|
| 113 |
+
# Try R2 first
|
| 114 |
+
from diffviews.data.r2_cache import R2DataStore
|
| 115 |
+
store = R2DataStore()
|
| 116 |
+
r2_key = f"data/{model}/checkpoints/{filename}"
|
| 117 |
+
if store.enabled and store.download_file(r2_key, filepath):
|
| 118 |
+
print(f"Checkpoint downloaded from R2: {filepath} ({filepath.stat().st_size / 1e6:.1f} MB)")
|
| 119 |
+
return
|
| 120 |
+
|
| 121 |
+
# Fallback to direct URL
|
| 122 |
+
import urllib.request
|
| 123 |
url = CHECKPOINT_URLS[model]
|
| 124 |
+
print(f"Downloading {model} checkpoint from URL (~1GB)...")
|
| 125 |
print(f" URL: {url}")
|
| 126 |
|
| 127 |
try:
|