Spaces:
Running
Running
Commit
·
6d1fbad
1
Parent(s):
3e9dbbe
Update path
Browse files- Nested/trainers/BaseTrainer.py +8 -1
- app.py +2 -0
Nested/trainers/BaseTrainer.py
CHANGED
|
@@ -3,6 +3,7 @@ import torch
|
|
| 3 |
import logging
|
| 4 |
import natsort
|
| 5 |
import glob
|
|
|
|
| 6 |
|
| 7 |
logger = logging.getLogger(__name__)
|
| 8 |
|
|
@@ -114,5 +115,11 @@ class BaseTrainer:
|
|
| 114 |
logger.info("Loading checkpoint %s", checkpoint_path)
|
| 115 |
|
| 116 |
device = None if torch.cuda.is_available() else torch.device('cpu')
|
| 117 |
-
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
self.model.load_state_dict(checkpoint["model"])
|
|
|
|
| 3 |
import logging
|
| 4 |
import natsort
|
| 5 |
import glob
|
| 6 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
| 7 |
|
| 8 |
logger = logging.getLogger(__name__)
|
| 9 |
|
|
|
|
| 115 |
logger.info("Loading checkpoint %s", checkpoint_path)
|
| 116 |
|
| 117 |
device = None if torch.cuda.is_available() else torch.device('cpu')
|
| 118 |
+
# checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
|
| 119 |
+
repo_path = snapshot_download(repo_id="SinaLab/Nested")
|
| 120 |
+
|
| 121 |
+
model_file = os.path.join(repo_path, "checkpoints", "checkpoint_2.pt")
|
| 122 |
+
|
| 123 |
+
checkpoint = torch.load(model_file, map_location=device, weights_only=False)
|
| 124 |
+
|
| 125 |
self.model.load_state_dict(checkpoint["model"])
|
app.py
CHANGED
|
@@ -28,6 +28,8 @@ encoder = AutoModel.from_pretrained(pretrained_path).eval()
|
|
| 28 |
# filename="checkpoints/checkpoint_2.pt"
|
| 29 |
# )
|
| 30 |
|
|
|
|
|
|
|
| 31 |
checkpoint_path = snapshot_download(repo_id="SinaLab/Nested", allow_patterns="checkpoints/")
|
| 32 |
print("checkpoint_path : ", checkpoint_path)
|
| 33 |
|
|
|
|
| 28 |
# filename="checkpoints/checkpoint_2.pt"
|
| 29 |
# )
|
| 30 |
|
| 31 |
+
|
| 32 |
+
|
| 33 |
checkpoint_path = snapshot_download(repo_id="SinaLab/Nested", allow_patterns="checkpoints/")
|
| 34 |
print("checkpoint_path : ", checkpoint_path)
|
| 35 |
|