Safetensors export and infer.
Browse files- infer.py +18 -5
- requirements.txt +1 -0
- scripts/strip_checkpoint.py +11 -4
infer.py
CHANGED
|
@@ -3,14 +3,13 @@ import os
|
|
| 3 |
import pprint
|
| 4 |
import time
|
| 5 |
from typing import List, Tuple, Optional, Dict, Any
|
| 6 |
-
import yaml
|
| 7 |
-
|
| 8 |
import numpy as np
|
| 9 |
import cv2
|
| 10 |
import torch
|
| 11 |
import torch.nn.functional as F
|
| 12 |
from torch.amp import autocast
|
| 13 |
from tqdm import tqdm
|
|
|
|
| 14 |
|
| 15 |
from src.wireseghr.model import WireSegHR
|
| 16 |
from pathlib import Path
|
|
@@ -256,7 +255,7 @@ def main():
|
|
| 256 |
"--ckpt",
|
| 257 |
type=str,
|
| 258 |
default="",
|
| 259 |
-
help="Optional checkpoint (.pt
|
| 260 |
)
|
| 261 |
parser.add_argument(
|
| 262 |
"--save_prob", action="store_true", help="Also save probability .npy"
|
|
@@ -356,8 +355,22 @@ def main():
|
|
| 356 |
if ckpt_path:
|
| 357 |
assert Path(ckpt_path).is_file(), f"Checkpoint not found: {ckpt_path}"
|
| 358 |
print(f"[WireSegHR][infer] Loading checkpoint: {ckpt_path}")
|
| 359 |
-
|
| 360 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
model.eval()
|
| 362 |
|
| 363 |
# Benchmark mode
|
|
|
|
| 3 |
import pprint
|
| 4 |
import time
|
| 5 |
from typing import List, Tuple, Optional, Dict, Any
|
|
|
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
import cv2
|
| 8 |
import torch
|
| 9 |
import torch.nn.functional as F
|
| 10 |
from torch.amp import autocast
|
| 11 |
from tqdm import tqdm
|
| 12 |
+
from safetensors.torch import load_file as safe_load_file
|
| 13 |
|
| 14 |
from src.wireseghr.model import WireSegHR
|
| 15 |
from pathlib import Path
|
|
|
|
| 255 |
"--ckpt",
|
| 256 |
type=str,
|
| 257 |
default="",
|
| 258 |
+
help="Optional checkpoint (.pt with {'model': state_dict} or .safetensors with pure state_dict)",
|
| 259 |
)
|
| 260 |
parser.add_argument(
|
| 261 |
"--save_prob", action="store_true", help="Also save probability .npy"
|
|
|
|
| 355 |
if ckpt_path:
|
| 356 |
assert Path(ckpt_path).is_file(), f"Checkpoint not found: {ckpt_path}"
|
| 357 |
print(f"[WireSegHR][infer] Loading checkpoint: {ckpt_path}")
|
| 358 |
+
suffix = Path(ckpt_path).suffix.lower()
|
| 359 |
+
if suffix == ".safetensors":
|
| 360 |
+
# Safetensors exports contain a pure state_dict
|
| 361 |
+
state_dict = safe_load_file(ckpt_path)
|
| 362 |
+
model.load_state_dict(state_dict)
|
| 363 |
+
else:
|
| 364 |
+
print(
|
| 365 |
+
"[WireSegHR][infer][WARN] Loading a PyTorch checkpoint. Prefer .safetensors for inference-only weights."
|
| 366 |
+
)
|
| 367 |
+
# PyTorch .pt/.pth checkpoints expected to have {'model': state_dict}
|
| 368 |
+
state = torch.load(ckpt_path, map_location=device)
|
| 369 |
+
assert "model" in state, (
|
| 370 |
+
"Expected a dict with key 'model' for PyTorch checkpoint. "
|
| 371 |
+
"Use scripts/strip_checkpoint.py or provide a .safetensors file."
|
| 372 |
+
)
|
| 373 |
+
model.load_state_dict(state["model"])
|
| 374 |
model.eval()
|
| 375 |
|
| 376 |
# Benchmark mode
|
requirements.txt
CHANGED
|
@@ -8,3 +8,4 @@ PyYAML>=6.0.1
|
|
| 8 |
tqdm>=4.65.0
|
| 9 |
gdown>=5.1.0
|
| 10 |
pydrive2
|
|
|
|
|
|
| 8 |
tqdm>=4.65.0
|
| 9 |
gdown>=5.1.0
|
| 10 |
pydrive2
|
| 11 |
+
safetensors
|
scripts/strip_checkpoint.py
CHANGED
|
@@ -3,6 +3,7 @@
|
|
| 3 |
import argparse
|
| 4 |
from pathlib import Path
|
| 5 |
import torch
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
def main():
|
|
@@ -10,7 +11,8 @@ def main():
|
|
| 10 |
description="Strip training checkpoint to inference-only weights (FP32)."
|
| 11 |
)
|
| 12 |
parser.add_argument("--in", dest="inp", type=str, required=True, help="Path to training checkpoint .pt")
|
| 13 |
-
parser.add_argument("--out", dest="out", type=str, required=True, help="Path to save weights-only .pt")
|
|
|
|
| 14 |
args = parser.parse_args()
|
| 15 |
|
| 16 |
in_path = Path(args.inp)
|
|
@@ -38,9 +40,14 @@ def main():
|
|
| 38 |
#in the future, can cast to bfloat if necessary.
|
| 39 |
# state_dict = {k: (v.float() if torch.is_floating_point(v) else v) for k, v in state_dict.items()}
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
|
| 46 |
if __name__ == "__main__":
|
|
|
|
| 3 |
import argparse
|
| 4 |
from pathlib import Path
|
| 5 |
import torch
|
| 6 |
+
from safetensors.torch import save_file as safetensors_save_file
|
| 7 |
|
| 8 |
|
| 9 |
def main():
|
|
|
|
| 11 |
description="Strip training checkpoint to inference-only weights (FP32)."
|
| 12 |
)
|
| 13 |
parser.add_argument("--in", dest="inp", type=str, required=True, help="Path to training checkpoint .pt")
|
| 14 |
+
parser.add_argument("--out", dest="out", type=str, required=True, help="Path to save weights-only .pt or .safetensors")
|
| 15 |
+
# Output format is inferred from --out extension
|
| 16 |
args = parser.parse_args()
|
| 17 |
|
| 18 |
in_path = Path(args.inp)
|
|
|
|
| 40 |
#in the future, can cast to bfloat if necessary.
|
| 41 |
# state_dict = {k: (v.float() if torch.is_floating_point(v) else v) for k, v in state_dict.items()}
|
| 42 |
|
| 43 |
+
suffix = out_path.suffix.lower()
|
| 44 |
+
if suffix == ".safetensors":
|
| 45 |
+
safetensors_save_file(state_dict, str(out_path))
|
| 46 |
+
print(f"[strip_checkpoint] Saved safetensors (pure state_dict) to: {out_path}")
|
| 47 |
+
else:
|
| 48 |
+
to_save = {"model": state_dict}
|
| 49 |
+
torch.save(to_save, str(out_path))
|
| 50 |
+
print(f"[strip_checkpoint] Saved dict with only 'model' to: {out_path}")
|
| 51 |
|
| 52 |
|
| 53 |
if __name__ == "__main__":
|