Commit
·
6cfaed9
1
Parent(s):
a9b94d9
fixed bug model.py
Browse files
model.py
CHANGED
|
@@ -6,7 +6,6 @@ from torchvision.transforms import v2
|
|
| 6 |
def load_model(model: torch.nn.Module,
|
| 7 |
model_weights_dir: str,
|
| 8 |
model_weights_name: str):
|
| 9 |
-
#hidden_units: int):
|
| 10 |
|
| 11 |
"""Loads a PyTorch model from a target directory.
|
| 12 |
|
|
@@ -34,7 +33,7 @@ def load_model(model: torch.nn.Module,
|
|
| 34 |
# Load the model
|
| 35 |
print(f"[INFO] Loading model from: {model_path}")
|
| 36 |
|
| 37 |
-
model.load_state_dict(torch.load(model_path, weights_only=True))
|
| 38 |
|
| 39 |
return model
|
| 40 |
|
|
|
|
| 6 |
def load_model(model: torch.nn.Module,
|
| 7 |
model_weights_dir: str,
|
| 8 |
model_weights_name: str):
|
|
|
|
| 9 |
|
| 10 |
"""Loads a PyTorch model from a target directory.
|
| 11 |
|
|
|
|
| 33 |
# Load the model
|
| 34 |
print(f"[INFO] Loading model from: {model_path}")
|
| 35 |
|
| 36 |
+
model.load_state_dict(torch.load(model_path, weights_only=True, map_location=torch.device('cpu')))
|
| 37 |
|
| 38 |
return model
|
| 39 |
|