# tianfengping.tfp
# check prompt path1
# c747137
# raw
# history blame
# 12.6 kB
#
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional, Callable, List, Generator
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
from cosyvoice_rodis.utils.common import IGNORE_ID
from cosyvoice_rodis.transformer.label_smoothing_loss import LabelSmoothingLoss
from cosyvoice_rodis.utils.common import th_accuracy
from cosyvoice_rodis.utils.losses import OrthogonalityLoss
class TransformerLM(torch.nn.Module):
    """Decoder-only LM that predicts speech tokens conditioned on text.

    The LM input sequence is ``[sos_eos, speaker(+emotion) embedding, encoded
    text, task_id, speech tokens]``. The decoder predicts the speech tokens
    followed by an EOS id equal to ``speech_token_size``. Optionally an emotion
    embedding is added to the speaker embedding, and an orthogonality loss
    between the two is added to the cross-entropy loss.
    """

    def __init__(
            self,
            text_encoder_input_size: int,
            llm_input_size: int,
            llm_output_size: int,
            text_token_size: int,
            speech_token_size: int,
            text_encoder: torch.nn.Module,
            llm: torch.nn.Module,
            sampling: Callable,
            length_normalized_loss: bool = True,
            lsm_weight: float = 0.0,
            spk_embed_dim: int = 192,
            orth_loss: bool = False,
            cross_orth_loss: bool = False,
            emotion_embedding: bool = False,
    ):
        """
        Args:
            text_encoder_input_size: width of the text token embedding fed to ``text_encoder``.
            llm_input_size: width of every token in the LM input sequence.
            llm_output_size: width of ``llm`` outputs; projected to speech-token logits.
            text_token_size: text vocabulary size.
            speech_token_size: speech vocabulary size; id ``speech_token_size`` is used as EOS.
            text_encoder: encoder for embedded text; must expose ``output_size()``.
            llm: the language-model backbone.
            sampling: callable ``(scores, decoded_tokens, sampling) -> ids`` used at inference.
            length_normalized_loss: forwarded as ``LabelSmoothingLoss(normalize_length=...)``.
            lsm_weight: label-smoothing weight.
            spk_embed_dim: dimension of the speaker / emotion embedding vectors.
            orth_loss: add an orthogonality loss between projected speaker and emotion
                embeddings (only active together with ``emotion_embedding``).
            cross_orth_loss: use the pairwise cross-sample variant of that loss.
            emotion_embedding: read ``batch['emotion_embedding']`` and add it to the
                speaker embedding.
        """
        super().__init__()
        self.llm_input_size = llm_input_size
        self.speech_token_size = speech_token_size
        # 1. build text token inputs related modules
        self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
        self.text_encoder = text_encoder
        self.text_encoder_affine_layer = nn.Linear(
            self.text_encoder.output_size(),
            llm_input_size
        )

        # 2. build speech token language model related modules
        self.orth_loss = orth_loss
        self.cross_orth_loss = cross_orth_loss
        self.emotion_embedding = emotion_embedding
        # row indices into llm_embedding: sos/eos marker and task-id marker
        self.sos_eos = 0
        self.task_id = 1
        self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
        self.llm = llm
        # one extra output class for EOS (id == speech_token_size)
        self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
        self.criterion_ce = LabelSmoothingLoss(
            size=speech_token_size + 1,
            padding_idx=IGNORE_ID,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

        # 3. [Optional] build speech token related modules
        self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
        # projection shared by speaker and emotion embeddings (spk_embed_dim -> llm_input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)

        # 4. sampling method
        self.sampling = sampling
        if orth_loss:
            self.speaker_projector = nn.Linear(spk_embed_dim, spk_embed_dim)
            self.emotion_projector = nn.Linear(spk_embed_dim, spk_embed_dim)

    def encode(
            self,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ):
        """Run the text encoder and project its output to ``llm_input_size``.

        Returns:
            ``(encoder_out, encoder_out_lens)``. The lengths are the per-sample
            sums of the encoder's output mask.
        """
        encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        encoder_out = self.text_encoder_affine_layer(encoder_out)
        return encoder_out, encoder_out_lens

    def _embed_condition(self, x: torch.Tensor) -> torch.Tensor:
        """L2-normalize ``x`` along dim 1, project it with ``spk_embed_affine_layer``
        and add a length-1 sequence axis at dim 1.

        Assumes ``x`` is ``(B, spk_embed_dim)`` (TODO confirm for all callers).
        Casts ``x`` to the affine layer's dtype first so mixed-precision inputs work.
        """
        x = F.normalize(x, dim=1)
        weight_dtype = self.spk_embed_affine_layer.weight.dtype
        if x.dtype != weight_dtype:
            x = x.to(weight_dtype)
        return self.spk_embed_affine_layer(x).unsqueeze(1)

    @staticmethod
    def _cross_orth_loss(embedding: torch.Tensor, emotion_embedding: torch.Tensor):
        """Mean of ``|dot(embedding[i], emotion_embedding[j])|`` over all pairs ``i < j``.

        Returns:
            ``(mean, last_pair)``. ``last_pair`` is the value for the pair
            ``(B-2, B-1)``; it is what the former pairwise loop left in
            ``contrastive_loss`` and is kept for logging. Both are zero tensors
            when the batch has fewer than two samples.
        """
        batch_size = embedding.size(0)
        if batch_size < 2:
            zero = embedding.new_zeros(())
            return zero, zero
        # pair_abs[i, j] == |dot(embedding[i], emotion_embedding[j])|
        pair_abs = torch.abs(embedding @ emotion_embedding.transpose(0, 1))
        num_pairs = batch_size * (batch_size - 1) / 2
        mean = torch.triu(pair_abs, diagonal=1).sum() / num_pairs
        return mean, pair_abs[-2, -1]

    def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
        """Build the padded LM input batch.

        Each sample is the concatenation ``[sos_eos, embedding[i], text[i],
        task_id, speech[i]]`` along the time axis. Text and speech are first
        trimmed to their true lengths.

        Returns:
            ``(lm_input, lm_input_len)``. ``lm_input`` is padded with
            ``IGNORE_ID``; ``lm_input_len`` is an int32 tensor of per-sample lengths.
        """
        text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
        speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
        sos_eos = sos_eos_emb.squeeze(dim=0)
        task_id = task_id_emb.squeeze(dim=0)
        lm_input = [torch.concat([sos_eos, embedding[i], text_token[i], task_id, speech_token[i]], dim=0)
                    for i in range(len(text_token))]
        lm_input_len = torch.tensor([x.size(0) for x in lm_input], dtype=torch.int32)
        lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
        return lm_input, lm_input_len

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Compute the training loss for one batch.

        Args:
            batch: input dict. Reads 'text_token', 'text_token_len',
                'speech_token', 'speech_token_len' and 'embedding', plus
                'emotion_embedding' when ``self.emotion_embedding`` is set.
            device: device the batch tensors are moved to.

        Returns:
            dict with 'loss' (CE, plus the orthogonality term when enabled),
            'acc', 'ce_loss', 'orth_loss' and 'contrastive_loss'. Every loss
            entry is a tensor, so callers can call ``.item()`` on any of them.
        """
        text_token = batch['text_token'].to(device)
        text_token_len = batch['text_token_len'].to(device)
        speech_token = batch['speech_token'].to(device)
        speech_token_len = batch['speech_token_len'].to(device)
        embedding = batch['embedding'].to(device)

        # 1. optional emotion embedding
        emotion_embedding = batch['emotion_embedding'].to(device) if self.emotion_embedding else None

        # 2. orthogonality loss. Both terms start as zero tensors so the returned
        #    dict is always well defined. (Previously `contrastive_loss` was
        #    unbound unless cross_orth_loss was active.)
        orth_loss = torch.tensor(0.0, device=device)
        contrastive_loss = torch.tensor(0.0, device=device)
        if self.orth_loss and self.emotion_embedding:
            embedding = self.speaker_projector(embedding)
            emotion_embedding = self.emotion_projector(emotion_embedding)
            embedding = embedding + emotion_embedding
            if self.cross_orth_loss:
                orth_loss, contrastive_loss = self._cross_orth_loss(embedding, emotion_embedding)
            else:
                orth_loss = OrthogonalityLoss(embedding, emotion_embedding)

        # 3. targets: ignore the sos, embedding and text positions, then predict
        #    each speech token followed by EOS (id == speech_token_size)
        lm_target = [
            torch.tensor(
                [IGNORE_ID] * (2 + text_token_len[i]) +
                speech_token[i, :speech_token_len[i]].tolist() +
                [self.speech_token_size]
            )
            for i in range(text_token.size(0))
        ]
        lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)

        # 4. encode text
        text_token = self.text_embedding(text_token)
        text_token, text_token_len = self.encode(text_token, text_token_len)

        # 5. speaker (+ emotion) conditioning token
        embedding = self._embed_condition(embedding)
        if self.emotion_embedding and emotion_embedding is not None:
            embedding = embedding + self._embed_condition(emotion_embedding)

        # 6. assemble LM input
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        speech_token = self.speech_embedding(speech_token)
        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
                                                         task_id_emb, speech_token, speech_token_len)

        # 7. run lm forward
        lm_output, _ = self.llm(lm_input, lm_input_len.to(device))
        logits = self.llm_decoder(lm_output)
        ce_loss = self.criterion_ce(logits, lm_target)
        acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
        if self.orth_loss and self.emotion_embedding:
            loss = ce_loss + orth_loss
        else:
            loss = ce_loss
        return {'loss': loss, 'acc': acc, "ce_loss": ce_loss, "orth_loss": orth_loss, "contrastive_loss": contrastive_loss}

    def sampling_ids(
            self,
            weighted_scores: torch.Tensor,
            decoded_tokens: List,
            sampling: int,
            ignore_eos: bool = True,
            max_trials: int = 100,
    ):
        """Draw token ids with ``self.sampling``.

        When ``ignore_eos`` is True, resample until the draw contains no EOS
        (id == ``speech_token_size``).

        Raises:
            RuntimeError: if EOS is still drawn after ``max_trials`` resamples.
                This replaces the previous unbounded loop.
        """
        num_trials = 0
        while True:
            top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
            if (not ignore_eos) or (self.speech_token_size not in top_ids):
                return top_ids
            num_trials += 1
            if num_trials > max_trials:
                raise RuntimeError(
                    'sampling reaches max_trials {} and still gets eos when ignore_eos is True, '
                    'check your input!'.format(max_trials))

    @torch.inference_mode()
    def inference(
            self,
            text: torch.Tensor,
            text_len: torch.Tensor,
            prompt_text: torch.Tensor,
            prompt_text_len: torch.Tensor,
            prompt_speech_token: torch.Tensor,
            prompt_speech_token_len: torch.Tensor,
            embedding: torch.Tensor,
            emotion_embedding: Optional[torch.Tensor] = None,
            sampling: int = 25,
            max_token_text_ratio: float = 20,
            min_token_text_ratio: float = 2,
    ) -> Generator[int, None, None]:
        """Autoregressively generate speech token ids, yielding them one by one.

        The prompt text is prepended to ``text``, and the prompt speech tokens
        seed the speech part of the sequence. The number of generated tokens is
        bounded by ``min/max_token_text_ratio`` times the non-prompt text length
        (measured after encoding). Generation stops early at EOS once the
        minimum length is reached.
        """
        device = text.device
        text = torch.concat([prompt_text, text], dim=1)
        # Not `+=`: that would mutate the caller's `text_len` tensor in place.
        text_len = text_len + prompt_text_len
        text = self.text_embedding(text)

        # 1. encode text
        text, text_len = self.encode(text, text_len)

        # 2. encode embedding
        if embedding.shape[0] != 0:
            embedding = self._embed_condition(embedding)
        else:
            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)

        # 3. handle emotion embedding (assumed 1-D here; a batch axis is added)
        if self.emotion_embedding and emotion_embedding is not None:
            emotion_embedding = self._embed_condition(emotion_embedding.unsqueeze(0).to(torch.float32))
            # NOTE(review): if no speaker embedding was given, `embedding` has length 0
            # and broadcasting drops the emotion token entirely; confirm this is intended.
            embedding = embedding + emotion_embedding

        # 4. concat llm_input
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
        lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)

        # 5. cal min/max_length from the non-prompt part of the text
        min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
        max_len = int((text_len - prompt_text_len) * max_token_text_ratio)

        # 6. step by step decode
        out_tokens = []
        offset = 0
        att_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device)
        cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device)
        for i in range(max_len):
            att_mask = torch.tril(
                torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device).to(torch.bool))
            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(
                lm_input, offset=offset, required_cache_size=-1,
                att_cache=att_cache, cnn_cache=cnn_cache,
                att_mask=att_mask)
            logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
            # force continue decode first token
            if i == 0:
                logp[:, self.speech_token_size] = -float('inf')
            top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=i < min_len).item()
            if top_ids == self.speech_token_size:
                break
            # in stream mode, yield token one by one
            yield top_ids
            out_tokens.append(top_ids)
            offset += lm_input.size(1)
            lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)