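# Text-to-speech inference script: loads a trained acoustic model and vocoder,
# converts raw text (English or Mandarin) to a phoneme ID sequence, and writes
# synthesized samples to the configured result path. The emotion ID is either
# given on the command line or inferred from the text with a RoBERTa classifier.
#
# Example invocation (script and config names are illustrative):
#   python synthesize.py --mode single --restore_step 900000 \
#       --text "Hello world" --speaker_id 0 \
#       -p config/preprocess.yaml -m config/model.yaml -t config/train.yaml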
import re
import argparse
from string import punctuation

import torch
import yaml
import numpy as np
from torch.utils.data import DataLoader
from g2p_en import G2p
from pypinyin import pinyin, Style
from transformers import RobertaTokenizerFast, AutoModelForSequenceClassification

from utils.model import get_model, get_vocoder
from utils.tools import to_device, synth_samples, get_roberta_emotion_embeddings
from dataset import TextDataset
from text import text_to_sequence
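
# Emotion classifier: "roberta_pretrained" is assumed to be a local directory
# holding a RoBERTa checkpoint fine-tuned for emotion classification; the stock
# "roberta-base" tokenizer is used for encoding.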
ro_model = "roberta_pretrained"
roberta_tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
roberta_model = AutoModelForSequenceClassification.from_pretrained(ro_model)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
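

# read_lexicon expects a pronouncing dictionary with one word per line followed
# by its phones (an MFA-style lexicon, for example):
#   HELLO  HH AH0 L OW1
#   WORLD  W ER1 L D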
def read_lexicon(lex_path):
    lexicon = {}
    with open(lex_path) as f:
        for line in f:
            temp = re.split(r"\s+", line.strip("\n"))
            word = temp[0]
            phones = temp[1:]
            if word.lower() not in lexicon:
                lexicon[word.lower()] = phones
    return lexicon
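

# preprocess_english: words found in the lexicon use its phones; anything else
# falls back to g2p_en, and lone punctuation is mapped to the silence token "sp".
# Example (the exact phones depend on the lexicon):
#   "Hello, world" -> "{HH AH0 L OW1 sp W ER1 L D}"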
def preprocess_english(text, preprocess_config):
    text = text.rstrip(punctuation)
    lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"])

    g2p = G2p()
    phones = []
    words = re.split(r"([,;.\-\?\!\s+])", text)
    for w in words:
        if w.lower() in lexicon:
            phones += lexicon[w.lower()]
        else:
            phones += list(filter(lambda p: p != " ", g2p(w)))
    phones = "{" + "}{".join(phones) + "}"
    # Replace lone punctuation tokens with the silence token, then flatten the braces
    phones = re.sub(r"\{[^\w\s]?\}", "{sp}", phones)
    phones = phones.replace("}{", " ")

    print("Raw Text Sequence: {}".format(text))
    print("Phoneme Sequence: {}".format(phones))
    sequence = np.array(
        text_to_sequence(
            phones, preprocess_config["preprocessing"]["text"]["text_cleaners"]
        )
    )

    return sequence
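

# preprocess_mandarin: converts the text to tone3 pinyin (e.g. "你好" ->
# ["ni3", "hao3"]) and maps each syllable through the lexicon; unknown
# syllables become "sp".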
def preprocess_mandarin(text, preprocess_config):
    lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"])

    phones = []
    pinyins = [
        p[0]
        for p in pinyin(
            text, style=Style.TONE3, strict=False, neutral_tone_with_five=True
        )
    ]
    for p in pinyins:
        if p in lexicon:
            phones += lexicon[p]
        else:
            phones.append("sp")

    phones = "{" + " ".join(phones) + "}"
    print("Raw Text Sequence: {}".format(text))
    print("Phoneme Sequence: {}".format(phones))
    sequence = np.array(
        text_to_sequence(
            phones, preprocess_config["preprocessing"]["text"]["text_cleaners"]
        )
    )

    return sequence
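

# Each batch is (ids, raw_texts, speakers, texts, text_lens, max_text_len,
# emotions); the forward pass consumes batch[2:], dropping the ids and raw strings.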
def synthesize(model, step, configs, vocoder, batchs, control_values):
    preprocess_config, model_config, train_config = configs
    pitch_control, energy_control, duration_control = control_values

    for batch in batchs:
        batch = to_device(batch, device)
        with torch.no_grad():
            # Forward
            output = model(
                *(batch[2:]),
                p_control=pitch_control,
                e_control=energy_control,
                d_control=duration_control
            )
            synth_samples(
                batch,
                output,
                vocoder,
                model_config,
                preprocess_config,
                train_config["path"]["result_path"],
            )


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--restore_step", type=int, required=True)
    parser.add_argument(
        "--mode",
        type=str,
        choices=["batch", "single"],
        required=True,
        help="Synthesize a whole dataset or a single sentence",
    )
    parser.add_argument(
        "--source",
        type=str,
        default=None,
        help="path to a source file with format like train.txt and val.txt, for batch mode only",
    )
    parser.add_argument(
        "--text",
        type=str,
        default=None,
        help="raw text to synthesize, for single-sentence mode only",
    )
    parser.add_argument(
        "--speaker_id",
        type=int,
        default=0,
        help="speaker ID for multi-speaker synthesis, for single-sentence mode only",
    )
    parser.add_argument(
        "--emotion_id",
        type=int,
        default=0,
        help="emotion ID for multi-emotion synthesis, for single-sentence mode only",
    )
    parser.add_argument(
        "--bert_embed",
        type=int,
        default=0,
        help="set to 1 to infer the emotion ID from the input text with the "
        "RoBERTa classifier instead of using --emotion_id",
    )
    parser.add_argument(
        "-p",
        "--preprocess_config",
        type=str,
        required=True,
        help="path to preprocess.yaml",
    )
    parser.add_argument(
        "-m", "--model_config", type=str, required=True, help="path to model.yaml"
    )
    parser.add_argument(
        "-t", "--train_config", type=str, required=True, help="path to train.yaml"
    )
    parser.add_argument(
        "--pitch_control",
        type=float,
        default=1.0,
        help="control the pitch of the whole utterance, larger value for higher pitch",
    )
    parser.add_argument(
        "--energy_control",
        type=float,
        default=1.0,
        help="control the energy of the whole utterance, larger value for larger volume",
    )
    parser.add_argument(
        "--duration_control",
        type=float,
        default=1.0,
        help="control the speed of the whole utterance, larger value for slower speaking rate",
    )
    args = parser.parse_args()
    # Check source texts
    if args.mode == "batch":
        assert args.source is not None and args.text is None
    if args.mode == "single":
        assert args.source is None and args.text is not None

    # Read Config
    preprocess_config = yaml.load(
        open(args.preprocess_config, "r"), Loader=yaml.FullLoader
    )
    model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader)
    train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader)
    configs = (preprocess_config, model_config, train_config)
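    # Config keys read directly in this file (the rest are consumed by get_model,
    # get_vocoder, and synth_samples):
    #   preprocess.yaml: path.lexicon_path, preprocessing.text.text_cleaners,
    #                    preprocessing.text.language ("en" or "zh")
    #   train.yaml:      path.result_path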
    # Get model
    model = get_model(args, configs, device, train=False)

    # Load vocoder
    vocoder = get_vocoder(model_config, device)

    # Preprocess texts
    if args.mode == "batch":
        # Get dataset: a metadata file in the same format as train.txt / val.txt
        dataset = TextDataset(args.source, preprocess_config)
        batchs = DataLoader(
            dataset,
            batch_size=8,
            collate_fn=dataset.collate_fn,
        )
| if args.mode == "single": | |
| if np.array([args.bert_embed]) == 0: | |
| emotions = np.array([args.emotion_id]) | |
| # print(f'FS2 emotions: {emotions}') | |
| else: | |
| emotions = get_roberta_emotion_embeddings( | |
| roberta_tokenizer, roberta_model, args.text) | |
| emotions = torch.argmax(emotions, dim=1).cpu().numpy() | |
| # print(f'RoBERTa emotions {emotions}') | |
| ids = raw_texts = [args.text[:100]] | |
| speakers = np.array([args.speaker_id]) | |
| if preprocess_config["preprocessing"]["text"]["language"] == "en": | |
| texts = np.array( | |
| [preprocess_english(args.text, preprocess_config)]) | |
| elif preprocess_config["preprocessing"]["text"]["language"] == "zh": | |
| texts = np.array( | |
| [preprocess_mandarin(args.text, preprocess_config)]) | |
| text_lens = np.array([len(texts[0])]) | |
| batchs = [(ids, raw_texts, speakers, texts, | |
| text_lens, max(text_lens), emotions)] | |
    control_values = args.pitch_control, args.energy_control, args.duration_control

    synthesize(model, args.restore_step, configs, vocoder, batchs, control_values)