from typing import Any

import torch
import torch.nn as nn


class ContentEncoder(nn.Module):
    """Wraps modality-specific encoders and batches their content embeddings."""

    def __init__(
        self,
        embed_dim: int,
        text_encoder: nn.Module | None = None,
        llm_encoder: nn.Module | None = None,
        video_encoder: nn.Module | None = None,
        midi_encoder: nn.Module | None = None,
        phoneme_encoder: nn.Module | None = None,
        pitch_encoder: nn.Module | None = None,
        audio_encoder: nn.Module | None = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.text_encoder = text_encoder
        self.llm_encoder = llm_encoder
        self.midi_encoder = midi_encoder
        self.phoneme_encoder = phoneme_encoder
        self.pitch_encoder = pitch_encoder
        self.audio_encoder = audio_encoder
        self.video_encoder = video_encoder

    def encode_content(
        self, batch_content: list[Any], batch_task: list[str],
        device: str | torch.device
    ):
        batch_content_output = []
        batch_content_mask = []
        batch_la_content_output = []
        batch_la_content_output_mask = []
        # Fallback value for samples whose encoder output has no "mask" entry.
        zero_la_content = torch.zeros(1, 1, self.embed_dim, device=device)

        for content, task in zip(batch_content, batch_task):
            if task == "audio_editing":
                raw_waveform = torch.as_tensor(content["audio"]).float()
                waveform_with_batch_dim = raw_waveform.unsqueeze(0).to(device)
                waveform_lengths = torch.as_tensor([raw_waveform.shape[0]])

                # Encode the caption; the waveform is also passed to the text encoder.
                content_output_dict = self.text_encoder(
                    [content["caption"]], waveform_with_batch_dim
                )
                # Encode the waveform to obtain length-aligned content features.
                audio_dict = {
                    "waveform": waveform_with_batch_dim,
                    "waveform_lengths": waveform_lengths
                }
                audio_output_dict = self.audio_encoder(**audio_dict)
                la_content_output_dict = {
                    "output": audio_output_dict["output"],
                    "mask": audio_output_dict["mask"]
                }

                batch_content_output.append(content_output_dict["output"][0])
                batch_content_mask.append(content_output_dict["mask"][0])
                batch_la_content_output.append(la_content_output_dict["output"][0])
                batch_la_content_output_mask.append(
                    la_content_output_dict.get("mask", zero_la_content)[0]
                )

        batch_content_output = nn.utils.rnn.pad_sequence(
            batch_content_output, batch_first=True, padding_value=0
        )
        batch_content_mask = nn.utils.rnn.pad_sequence(
            batch_content_mask, batch_first=True, padding_value=False
        )
        batch_la_content_output = nn.utils.rnn.pad_sequence(
            batch_la_content_output, batch_first=True, padding_value=0
        )
        batch_la_content_output_mask = nn.utils.rnn.pad_sequence(
            batch_la_content_output_mask, batch_first=True, padding_value=False
        )
        return {
            "content": batch_content_output,
            "content_mask": batch_content_mask,
            "length_aligned_content": batch_la_content_output,
            "time_aligned_content_mask": batch_la_content_output_mask
        }


class BatchedContentEncoder(ContentEncoder):
    """Variant of ContentEncoder that encodes all captions in one batched call."""

    def encode_content(
        self, batch_content: list[dict], batch_task: list[str],
        device: str | torch.device
    ):
        assert all(task == "audio_editing" for task in batch_task), \
            "BatchedContentEncoder currently supports only the audio_editing task"

        captions = []
        waveforms = []
        waveform_lengths = []
        for content in batch_content:
            raw_waveform = torch.as_tensor(content["audio"]).float().to(device)
            captions.append(content["caption"])
            waveforms.append(raw_waveform)
            waveform_lengths.append(raw_waveform.shape[0])

        # Encode all captions (and their waveforms) in a single batched call.
        content_output_dict = self.text_encoder(
            captions, waveforms
        )

        # Encode each waveform separately to obtain length-aligned content.
        batch_la_content_output = []
        batch_la_content_output_mask = []
        for i in range(len(batch_content)):
            audio_dict = {
                "waveform": waveforms[i].unsqueeze(0),
                "waveform_lengths": torch.as_tensor([waveform_lengths[i]], device=device)
            }
            audio_output_dict = self.audio_encoder(**audio_dict)
            batch_la_content_output.append(audio_output_dict["output"][0])
            batch_la_content_output_mask.append(audio_output_dict["mask"][0])

        batch_la_content_output = nn.utils.rnn.pad_sequence(
            batch_la_content_output, batch_first=True, padding_value=0
        )
        batch_la_content_output_mask = nn.utils.rnn.pad_sequence(
            batch_la_content_output_mask, batch_first=True, padding_value=False
        )

        return {
            "content": content_output_dict["output"],
            "content_mask": content_output_dict["mask"],
            "length_aligned_content": batch_la_content_output,
            "time_aligned_content_mask": batch_la_content_output_mask
        }
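

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal sketch of how these encoders might be driven for the
# "audio_editing" task. The _StubTextEncoder and _StubAudioEncoder classes are
# hypothetical placeholders, not part of this module: they only mimic the
# assumed return format (a dict with "output" and "mask" tensors) so the
# padding and batching logic can be exercised end to end on random data.
class _StubTextEncoder(nn.Module):
    """Hypothetical caption encoder: fixed-length random embeddings per caption."""

    def __init__(self, embed_dim: int, seq_len: int = 8):
        super().__init__()
        self.embed_dim = embed_dim
        self.seq_len = seq_len

    def forward(self, captions, waveforms):
        batch = len(captions)
        return {
            "output": torch.randn(batch, self.seq_len, self.embed_dim),
            "mask": torch.ones(batch, self.seq_len, dtype=torch.bool),
        }


class _StubAudioEncoder(nn.Module):
    """Hypothetical audio encoder: one random frame per `hop` waveform samples."""

    def __init__(self, embed_dim: int, hop: int = 160):
        super().__init__()
        self.embed_dim = embed_dim
        self.hop = hop

    def forward(self, waveform, waveform_lengths):
        frames = max(1, waveform.shape[-1] // self.hop)
        return {
            "output": torch.randn(waveform.shape[0], frames, self.embed_dim),
            "mask": torch.ones(waveform.shape[0], frames, dtype=torch.bool),
        }


if __name__ == "__main__":
    embed_dim = 16
    encoder = ContentEncoder(
        embed_dim=embed_dim,
        text_encoder=_StubTextEncoder(embed_dim),
        audio_encoder=_StubAudioEncoder(embed_dim),
    )
    # Two samples with different waveform lengths to show the padding behaviour.
    batch_content = [
        {"audio": torch.randn(16000), "caption": "remove the dog barking"},
        {"audio": torch.randn(24000), "caption": "add light rain in the background"},
    ]
    batch_task = ["audio_editing", "audio_editing"]

    out = encoder.encode_content(batch_content, batch_task, device="cpu")
    for key, value in out.items():
        print(key, tuple(value.shape))

    batched = BatchedContentEncoder(
        embed_dim=embed_dim,
        text_encoder=_StubTextEncoder(embed_dim),
        audio_encoder=_StubAudioEncoder(embed_dim),
    )
    batched_out = batched.encode_content(batch_content, batch_task, device="cpu")
    print("batched content shape:", tuple(batched_out["content"].shape))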