Update models.py
Browse files
models.py
CHANGED
|
@@ -722,7 +722,7 @@ def load_F0_models(path):
|
|
| 722 |
# load F0 model
|
| 723 |
|
| 724 |
F0_model = JDCNet(num_class=1, seq_len=192)
|
| 725 |
-
params = torch.load(path, map_location="cpu")["net"]
|
| 726 |
F0_model.load_state_dict(params)
|
| 727 |
_ = F0_model.train()
|
| 728 |
|
|
@@ -739,7 +739,7 @@ def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
|
|
| 739 |
|
| 740 |
def _load_model(model_config, model_path):
|
| 741 |
model = ASRCNN(**model_config)
|
| 742 |
-
params = torch.load(model_path, map_location="cpu")["model"]
|
| 743 |
model.load_state_dict(params)
|
| 744 |
return model
|
| 745 |
|
|
@@ -862,7 +862,7 @@ def build_model(args, text_aligner, pitch_extractor, bert):
|
|
| 862 |
|
| 863 |
|
| 864 |
def load_checkpoint(model, optimizer, path, load_only_params=True, ignore_modules=[]):
|
| 865 |
-
state = torch.load(path, map_location="cpu")
|
| 866 |
params = state["net"]
|
| 867 |
for key in model:
|
| 868 |
if key in params and key not in ignore_modules:
|
|
|
|
| 722 |
# load F0 model
|
| 723 |
|
| 724 |
F0_model = JDCNet(num_class=1, seq_len=192)
|
| 725 |
+
params = torch.load(path, map_location="cpu", weights_only=False)["net"]
|
| 726 |
F0_model.load_state_dict(params)
|
| 727 |
_ = F0_model.train()
|
| 728 |
|
|
|
|
| 739 |
|
| 740 |
def _load_model(model_config, model_path):
|
| 741 |
model = ASRCNN(**model_config)
|
| 742 |
+
params = torch.load(model_path, map_location="cpu", weights_only=False)["model"]
|
| 743 |
model.load_state_dict(params)
|
| 744 |
return model
|
| 745 |
|
|
|
|
| 862 |
|
| 863 |
|
| 864 |
def load_checkpoint(model, optimizer, path, load_only_params=True, ignore_modules=[]):
|
| 865 |
+
state = torch.load(path, map_location="cpu", weights_only=False)
|
| 866 |
params = state["net"]
|
| 867 |
for key in model:
|
| 868 |
if key in params and key not in ignore_modules:
|