carmelog commited on
Commit
3b3d375
·
1 Parent(s): d066e2c

fix: checkpoint path print

Browse files
Files changed (1) hide show
  1. domino-training-test.ipynb +5 -6
domino-training-test.ipynb CHANGED
@@ -1052,14 +1052,13 @@
1052
  " surf_factors = None # If not available, set to None\n",
1053
  " \n",
1054
  " # Load the best model checkpoint\n",
1055
- " best_checkpoint = torch.load(\n",
1056
- " max(\n",
1057
- " (CHECKPOINT_DIR / \"best_model\").glob(\"DoMINO.0.*.pt\"),\n",
1058
- " key=lambda p: p.stat().st_mtime,\n",
1059
- " )\n",
1060
  " )\n",
 
 
1061
  " model.load_state_dict(best_checkpoint) # Load the model state\n",
1062
- " print(f\"Model loaded: {best_checkpoint}\")\n",
1063
  " \n",
1064
  " # Set the path to save predictions\n",
1065
  " pred_save_path = SAVE_PATH\n",
 
1052
  " surf_factors = None # If not available, set to None\n",
1053
  " \n",
1054
  " # Load the best model checkpoint\n",
1055
+ " checkpoint_path = max(\n",
1056
+ " (CHECKPOINT_DIR / \"best_model\").glob(\"DoMINO.0.*.pt\"),\n",
1057
+ " key=lambda p: p.stat().st_mtime,\n",
 
 
1058
  " )\n",
1059
+ " print(f\"Loading checkpoint: {checkpoint_path.name}\") # Print only the checkpoint name\n",
1060
+ " best_checkpoint = torch.load(checkpoint_path)\n",
1061
  " model.load_state_dict(best_checkpoint) # Load the model state\n",
 
1062
  " \n",
1063
  " # Set the path to save predictions\n",
1064
  " pred_save_path = SAVE_PATH\n",