fix: checkpoint path print
Browse files
domino-training-test.ipynb
CHANGED
|
@@ -1052,14 +1052,13 @@
|
|
| 1052 |
" surf_factors = None # If not available, set to None\n",
|
| 1053 |
" \n",
|
| 1054 |
" # Load the best model checkpoint\n",
|
| 1055 |
-
"
|
| 1056 |
-
"
|
| 1057 |
-
"
|
| 1058 |
-
" key=lambda p: p.stat().st_mtime,\n",
|
| 1059 |
-
" )\n",
|
| 1060 |
" )\n",
|
|
|
|
|
|
|
| 1061 |
" model.load_state_dict(best_checkpoint) # Load the model state\n",
|
| 1062 |
-
" print(f\"Model loaded: {best_checkpoint}\")\n",
|
| 1063 |
" \n",
|
| 1064 |
" # Set the path to save predictions\n",
|
| 1065 |
" pred_save_path = SAVE_PATH\n",
|
|
|
|
| 1052 |
" surf_factors = None # If not available, set to None\n",
|
| 1053 |
" \n",
|
| 1054 |
" # Load the best model checkpoint\n",
|
| 1055 |
+
" checkpoint_path = max(\n",
|
| 1056 |
+
" (CHECKPOINT_DIR / \"best_model\").glob(\"DoMINO.0.*.pt\"),\n",
|
| 1057 |
+
" key=lambda p: p.stat().st_mtime,\n",
|
|
|
|
|
|
|
| 1058 |
" )\n",
|
| 1059 |
+
" print(f\"Loading checkpoint: {checkpoint_path.name}\") # Print only the checkpoint name\n",
|
| 1060 |
+
" best_checkpoint = torch.load(checkpoint_path)\n",
|
| 1061 |
" model.load_state_dict(best_checkpoint) # Load the model state\n",
|
|
|
|
| 1062 |
" \n",
|
| 1063 |
" # Set the path to save predictions\n",
|
| 1064 |
" pred_save_path = SAVE_PATH\n",
|