TymaaHammouda commited on
Commit
6d1fbad
·
1 Parent(s): 3e9dbbe

Update path

Browse files
Files changed (2) hide show
  1. Nested/trainers/BaseTrainer.py +8 -1
  2. app.py +2 -0
Nested/trainers/BaseTrainer.py CHANGED
@@ -3,6 +3,7 @@ import torch
3
  import logging
4
  import natsort
5
  import glob
 
6
 
7
  logger = logging.getLogger(__name__)
8
 
@@ -114,5 +115,11 @@ class BaseTrainer:
114
  logger.info("Loading checkpoint %s", checkpoint_path)
115
 
116
  device = None if torch.cuda.is_available() else torch.device('cpu')
117
- checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
 
 
 
 
 
 
118
  self.model.load_state_dict(checkpoint["model"])
 
3
  import logging
4
  import natsort
5
  import glob
6
+ from huggingface_hub import hf_hub_download, snapshot_download
7
 
8
  logger = logging.getLogger(__name__)
9
 
 
115
  logger.info("Loading checkpoint %s", checkpoint_path)
116
 
117
  device = None if torch.cuda.is_available() else torch.device('cpu')
118
+ # checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
119
+ repo_path = snapshot_download(repo_id="SinaLab/Nested")
120
+
121
+ model_file = os.path.join(repo_path, "checkpoints", "checkpoint_2.pt")
122
+
123
+ checkpoint = torch.load(model_file, map_location=device, weights_only=False)
124
+
125
  self.model.load_state_dict(checkpoint["model"])
app.py CHANGED
@@ -28,6 +28,8 @@ encoder = AutoModel.from_pretrained(pretrained_path).eval()
28
  # filename="checkpoints/checkpoint_2.pt"
29
  # )
30
 
 
 
31
  checkpoint_path = snapshot_download(repo_id="SinaLab/Nested", allow_patterns="checkpoints/")
32
  print("checkpoint_path : ", checkpoint_path)
33
 
 
28
  # filename="checkpoints/checkpoint_2.pt"
29
  # )
30
 
31
+
32
+
33
  checkpoint_path = snapshot_download(repo_id="SinaLab/Nested", allow_patterns="checkpoints/")
34
  print("checkpoint_path : ", checkpoint_path)
35