sergio-sanz-rodriguez commited on
Commit
6cfaed9
·
1 Parent(s): a9b94d9

fixed bug model.py

Browse files
Files changed (1) hide show
  1. model.py +1 -2
model.py CHANGED
@@ -6,7 +6,6 @@ from torchvision.transforms import v2
6
  def load_model(model: torch.nn.Module,
7
  model_weights_dir: str,
8
  model_weights_name: str):
9
- #hidden_units: int):
10
 
11
  """Loads a PyTorch model from a target directory.
12
 
@@ -34,7 +33,7 @@ def load_model(model: torch.nn.Module,
34
  # Load the model
35
  print(f"[INFO] Loading model from: {model_path}")
36
 
37
- model.load_state_dict(torch.load(model_path, weights_only=True))
38
 
39
  return model
40
 
 
6
  def load_model(model: torch.nn.Module,
7
  model_weights_dir: str,
8
  model_weights_name: str):
 
9
 
10
  """Loads a PyTorch model from a target directory.
11
 
 
33
  # Load the model
34
  print(f"[INFO] Loading model from: {model_path}")
35
 
36
+ model.load_state_dict(torch.load(model_path, weights_only=True, map_location=torch.device('cpu')))
37
 
38
  return model
39