Spaces:
Running
Running
| import os | |
| import re | |
| from fairseq import checkpoint_utils | |
def get_index_path_from_model(sid):
    """Return the path of the first ``.index`` file that matches a model name.

    The model name is derived from ``sid`` by removing a trailing ``.pth`` or
    ``.onnx`` extension, dropping any directory part, and, when the name ends
    in ``_e<digits>_s<digits>``, stripping those last two underscore-separated
    fields.

    The directory named by the ``index_root`` environment variable is walked
    bottom-up. Files whose name ends in ``.index`` and does not contain
    ``"trained"`` are candidates; the first candidate whose full path contains
    the model name is returned.

    Args:
        sid: Model file name or path, e.g. ``"weights/voice_e100_s2000.pth"``.

    Returns:
        The joined path of the first matching index file, or ``""`` when
        nothing matches or ``index_root`` is unset or empty.
    """
    index_root = os.getenv("index_root")
    if not index_root:
        # os.walk(None) raises TypeError; without a configured root there is
        # nothing to search, so report "not found" the same way as a miss.
        return ""

    # Anchor both alternatives to the end of the string. The unanchored form
    # r'\.pth|\.onnx$' would also delete ".pth" from the middle of a name.
    model_stem = re.sub(r"\.(?:pth|onnx)$", "", sid)
    model_name = os.path.split(model_stem)[-1]  # keep the name, drop directories

    # Checkpoints saved as <name>_e<epoch>_s<step> share the index of <name>.
    if re.match(r".+_e\d+_s\d+$", model_name):
        model_name = model_name.rsplit("_", 2)[0]

    # Search lazily and stop at the first hit instead of listing every file.
    for root, _, files in os.walk(index_root, topdown=False):
        for name in files:
            if not name.endswith(".index") or "trained" in name:
                continue
            candidate = os.path.join(root, name)
            # The match is against the whole path, so a directory named after
            # the model also counts.
            if model_name in candidate:
                return candidate
    return ""
def load_hubert(config):
    """Load the HuBERT checkpoint and prepare it according to ``config``.

    The checkpoint at ``assets/hubert/hubert_base.pt`` is loaded through
    fairseq's ``checkpoint_utils.load_model_ensemble_and_task``; only the
    first model of the returned ensemble is used.

    Args:
        config: Object exposing ``device`` (passed to ``.to()``) and
            ``is_half`` (truthy selects ``.half()``, otherwise ``.float()``).

    Returns:
        The result of calling ``.eval()`` on the moved and cast model.
    """
    ensemble, _saved_cfg, _task = checkpoint_utils.load_model_ensemble_and_task(
        ["assets/hubert/hubert_base.pt"],
        suffix="",
    )
    # Place the model on the configured device before choosing the precision.
    model = ensemble[0].to(config.device)
    model = model.half() if config.is_half else model.float()
    return model.eval()