Spaces:
Build error
Build error
| from ttts.vqvae.xtts_dvae import DiscreteVAE | |
| from ttts.diffusion.model import DiffusionTts | |
| from ttts.gpt.model import UnifiedVoice | |
| from ttts.classifier.model import AudioMiniEncoderWithClassifierHead | |
| from omegaconf import OmegaConf | |
| from ttts.diffusion.aa_model import AA_diffusion | |
| import json | |
| import torch | |
| import os | |
def load_model(model_name, model_path, config_path, device):
    """Build one of the TTS sub-models and load its checkpoint weights.

    Args:
        model_name: Which model to build: ``'vqvae'``, ``'gpt'``,
            ``'diffusion'`` or ``'classifier'``.
        model_path: Checkpoint path (``~`` is expanded). The loaded object
            must be indexable by ``'model'``, which holds the state dict.
        config_path: Config path (``~`` is expanded). Paths ending in
            ``.json`` are parsed with ``json``; anything else is read with
            ``OmegaConf.load``.
        device: Passed to ``torch.load`` as ``map_location`` and to
            ``model.to``.

    Returns:
        The constructed model with weights loaded (``strict=True``), moved to
        ``device`` and switched to eval mode.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    supported = ('vqvae', 'gpt', 'diffusion', 'classifier')
    # Fail fast with a clear message instead of reading files and then
    # hitting an unbound ``model`` at the final return.
    if model_name not in supported:
        raise ValueError(
            f"Unknown model_name {model_name!r}; "
            f"expected one of {', '.join(supported)}"
        )

    config_path = os.path.expanduser(config_path)
    model_path = os.path.expanduser(model_path)

    if config_path.endswith('.json'):
        # Context manager so the config file handle is always closed.
        with open(config_path) as config_file:
            config = json.load(config_file)
    else:
        config = OmegaConf.load(config_path)

    if model_name == 'vqvae':
        model = DiscreteVAE(**config['vqvae'])
    elif model_name == 'gpt':
        model = UnifiedVoice(**config['gpt'])
    elif model_name == 'diffusion':
        # AA_diffusion receives the whole config object, not a sub-section.
        model = AA_diffusion(config)
    else:  # 'classifier' -- the only remaining supported name
        model = AudioMiniEncoderWithClassifierHead(**config['classifier'])

    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    state_dict = torch.load(model_path, map_location=device)['model']
    model.load_state_dict(state_dict, strict=True)
    model = model.to(device)
    return model.eval()