fix strip checkpoint saving
Browse files
scripts/strip_checkpoint.py
CHANGED
|
@@ -35,11 +35,12 @@ def main():
|
|
| 35 |
)
|
| 36 |
state_dict = ckpt
|
| 37 |
|
| 38 |
-
#
|
| 39 |
-
state_dict = {k: (v.float() if torch.is_floating_point(v) else v) for k, v in state_dict.items()}
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
|
|
|
| 43 |
|
| 44 |
|
| 45 |
if __name__ == "__main__":
|
|
|
|
| 35 |
)
|
| 36 |
state_dict = ckpt
|
| 37 |
|
| 38 |
+
#in the future, can cast to bfloat if necessary.
|
| 39 |
+
# state_dict = {k: (v.float() if torch.is_floating_point(v) else v) for k, v in state_dict.items()}
|
| 40 |
|
| 41 |
+
to_save = {"model": state_dict}
|
| 42 |
+
torch.save(to_save, str(out_path))
|
| 43 |
+
print(f"[strip_checkpoint] Saved dict with only 'model' to: {out_path}")
|
| 44 |
|
| 45 |
|
| 46 |
if __name__ == "__main__":
|