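"""Zero-shot voice cloning TTS built on Llasa-3B and XCodec2.

The reference audio is encoded into discrete speech tokens with XCodec2,
those tokens are prepended as a prompt, Llasa-3B generates new speech
tokens for the target text, and XCodec2 decodes them back to a 16 kHz
waveform.
"""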
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import soundfile as sf
from xcodec2.modeling_xcodec2 import XCodec2Model
import torchaudio
class TextToSpeech:
def __init__(self, sample_audio_path, sample_audio_text):
self.sample_audio_text = sample_audio_text
        # Initialize the Llasa-3B LLM and the XCodec2 speech codec
llasa_3b = "HKUSTAudio/Llasa-3B"
xcodec2 = "HKUSTAudio/xcodec2"
self.tokenizer = AutoTokenizer.from_pretrained(llasa_3b)
self.llasa_3b_model = AutoModelForCausalLM.from_pretrained(
llasa_3b,
trust_remote_code=True,
device_map="auto",
)
self.llasa_3b_model.eval()
self.xcodec_model = XCodec2Model.from_pretrained(xcodec2)
self.xcodec_model.eval().cuda()
        # Load and preprocess the reference audio
waveform, sample_rate = torchaudio.load(sample_audio_path)
if len(waveform[0]) / sample_rate > 15:
print("已将音频裁剪至前15秒。")
waveform = waveform[:, : sample_rate * 15]
        # Downmix to mono if the audio is stereo
if waveform.size(0) > 1:
waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
else:
waveform_mono = waveform
self.prompt_wav = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=16000
)(waveform_mono)
# Encode the prompt wav
vq_code_prompt = self.xcodec_model.encode_code(input_waveform=self.prompt_wav)
vq_code_prompt = vq_code_prompt[0, 0, :]
self.speech_ids_prefix = self.ids_to_speech_tokens(vq_code_prompt)
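        # Token id that terminates speech generation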
self.speech_end_id = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
def ids_to_speech_tokens(self, speech_ids):
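        """Convert integer codec ids to token strings, e.g. 123 -> "<|s_123|>"."""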
speech_tokens_str = []
for speech_id in speech_ids:
speech_tokens_str.append(f"<|s_{speech_id}|>")
return speech_tokens_str
def extract_speech_ids(self, speech_tokens_str):
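        """Parse "<|s_123|>"-style token strings back into integer codec ids."""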
speech_ids = []
for token_str in speech_tokens_str:
if token_str.startswith("<|s_") and token_str.endswith("|>"):
num_str = token_str[4:-2]
num = int(num_str)
speech_ids.append(num)
else:
print(f"Unexpected token: {token_str}")
return speech_ids
@torch.inference_mode()
def infer(self, target_text):
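        """Synthesize `target_text` in the reference speaker's voice.

        Returns a (sample_rate, waveform) tuple, or None for empty input.
        """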
if len(target_text) == 0:
return None
elif len(target_text) > 300:
print("文本过长,请保持在300字符以内。")
target_text = target_text[:300]
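        # Prepend the reference transcript so the generated speech continues the prompt voice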
input_text = self.sample_audio_text + " " + target_text
formatted_text = (
f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
)
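        # The assistant turn already contains the reference speech tokens, so generation continues from them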
chat = [
{
"role": "user",
"content": "Convert the text to speech:" + formatted_text,
},
{
"role": "assistant",
"content": "<|SPEECH_GENERATION_START|>"
+ "".join(self.speech_ids_prefix),
},
]
input_ids = self.tokenizer.apply_chat_template(
chat, tokenize=True, return_tensors="pt", continue_final_message=True
)
input_ids = input_ids.to("cuda")
outputs = self.llasa_3b_model.generate(
input_ids,
max_length=2048,
eos_token_id=self.speech_end_id,
do_sample=True,
top_p=1,
temperature=0.8,
)
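        # Keep the prompt speech tokens plus the newly generated ones; drop the trailing end token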
generated_ids = outputs[0][input_ids.shape[1] - len(self.speech_ids_prefix): -1]
speech_tokens = self.tokenizer.batch_decode(
generated_ids, skip_special_tokens=True
)
speech_tokens = self.extract_speech_ids(speech_tokens)
speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
gen_wav = self.xcodec_model.decode_code(speech_tokens)
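        # Trim the decoded samples that correspond to the reference prompt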
gen_wav = gen_wav[:, :, self.prompt_wav.shape[1]:]
return (16000, gen_wav[0, 0, :].cpu().numpy())
if __name__ == "__main__":
    # If you run into problems, try converting the reference audio to WAV or MP3, trimming it to under 15 seconds, and shortening the prompt text.
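    # Transcript of the reference audio; it must match the content of sample.wav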
sample_audio_text = "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。"
sample_audio_path = os.path.join(os.path.dirname(__file__), "sample.wav")
tts = TextToSpeech(sample_audio_path, sample_audio_text)
target_text = "晚上好啊,吃了吗您"
result = tts.infer(target_text)
sf.write(os.path.join(os.path.dirname(__file__), "output.wav"), result[1], result[0])
target_text = "我是老北京正黄旗!"
result = tts.infer(target_text)
sf.write(os.path.join(os.path.dirname(__file__), "output1.wav"), result[1], result[0])