from typing import Any
import torch
import torch.nn as nn
class ContentEncoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        text_encoder: nn.Module | None = None,
        llm_encoder: nn.Module | None = None,
        video_encoder: nn.Module | None = None,
        midi_encoder: nn.Module | None = None,
        phoneme_encoder: nn.Module | None = None,
        pitch_encoder: nn.Module | None = None,
        audio_encoder: nn.Module | None = None
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.text_encoder = text_encoder
        self.llm_encoder = llm_encoder
        self.midi_encoder = midi_encoder
        self.phoneme_encoder = phoneme_encoder
        self.pitch_encoder = pitch_encoder
        self.audio_encoder = audio_encoder
        self.video_encoder = video_encoder
    def encode_content(
        self, batch_content: list[Any], batch_task: list[str],
        device: str | torch.device
    ):
        batch_content_output = []
        batch_content_mask = []
        batch_la_content_output = []
        batch_la_content_output_mask = []
        # Zero tensor used as a fallback when an encoder returns no
        # length-aligned ("la") mask.
        zero_la_content = torch.zeros(1, 1, self.embed_dim, device=device)
        for content, task in zip(batch_content, batch_task):
            if task != "audio_editing":
                raise NotImplementedError(f"Unsupported task: {task}")
            raw_waveform = torch.as_tensor(content["audio"]).float()
            waveform_with_batch_dim = raw_waveform.unsqueeze(0).to(device)
            waveform_lengths = torch.as_tensor([raw_waveform.shape[0]], device=device)
            # Note: the text encoder is actually an audio-LLM encoder; it
            # encodes both the waveform and the caption.
            content_output_dict = self.text_encoder(
                [content["caption"]], waveform_with_batch_dim
            )
            audio_dict = {
                "waveform": waveform_with_batch_dim,
                "waveform_lengths": waveform_lengths
            }
            audio_output_dict = self.audio_encoder(**audio_dict)
            la_content_output_dict = {
                "output": audio_output_dict["output"],
                "mask": audio_output_dict["mask"]
            }
            batch_content_output.append(content_output_dict["output"][0])
            batch_content_mask.append(content_output_dict["mask"][0])
            batch_la_content_output.append(la_content_output_dict["output"][0])
            batch_la_content_output_mask.append(
                la_content_output_dict.get("mask", zero_la_content)[0]
            )
        # Pad per-item outputs and masks to the max length in the batch.
        batch_content_output = nn.utils.rnn.pad_sequence(
            batch_content_output, batch_first=True, padding_value=0
        )
        batch_content_mask = nn.utils.rnn.pad_sequence(
            batch_content_mask, batch_first=True, padding_value=False
        )
        batch_la_content_output = nn.utils.rnn.pad_sequence(
            batch_la_content_output, batch_first=True, padding_value=0
        )
        batch_la_content_output_mask = nn.utils.rnn.pad_sequence(
            batch_la_content_output_mask, batch_first=True, padding_value=False
        )
        return {
            "content": batch_content_output,
            "content_mask": batch_content_mask,
            "length_aligned_content": batch_la_content_output,
            "time_aligned_content_mask": batch_la_content_output_mask
        }
class BatchedContentEncoder(ContentEncoder):
    def encode_content(
        self, batch_content: list[dict], batch_task: list[str],
        device: str | torch.device
    ):
        assert all(task == "audio_editing" for task in batch_task), \
            "BatchedContentEncoder currently only supports audio_editing"
        captions = []
        waveforms = []
        waveform_lengths = []
        for content in batch_content:
            raw_waveform = torch.as_tensor(content["audio"]).float().to(device)
            captions.append(content["caption"])
            waveforms.append(raw_waveform)
            waveform_lengths.append(raw_waveform.shape[0])
        # The audio-LLM text encoder handles all captions and waveforms in a
        # single batched call.
        content_output_dict = self.text_encoder(
            captions, waveforms
        )
        batch_la_content_output = []
        batch_la_content_output_mask = []
        # The length-aligned features still go through the audio encoder one
        # item at a time.
        for i in range(len(batch_content)):
            audio_dict = {
                "waveform": waveforms[i].unsqueeze(0),
                "waveform_lengths": torch.as_tensor([waveform_lengths[i]], device=device)
            }
            audio_output_dict = self.audio_encoder(**audio_dict)
            batch_la_content_output.append(audio_output_dict["output"][0])
            batch_la_content_output_mask.append(audio_output_dict["mask"][0])
        # Pad the per-item audio encoder outputs and masks to a common length.
        batch_la_content_output = nn.utils.rnn.pad_sequence(
            batch_la_content_output, batch_first=True, padding_value=0
        )
        batch_la_content_output_mask = nn.utils.rnn.pad_sequence(
            batch_la_content_output_mask, batch_first=True, padding_value=False
        )
        return {
            "content": content_output_dict["output"],
            "content_mask": content_output_dict["mask"],
            "length_aligned_content": batch_la_content_output,
            "time_aligned_content_mask": batch_la_content_output_mask
        }
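
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal, hedged example of driving both encoders with stand-in modules.
# The stub interface below (a module returning a dict with "output" and
# "mask") is an assumption inferred from how encode_content consumes the
# encoder results; the real text/audio encoders live elsewhere.
if __name__ == "__main__":

    class _StubEncoder(nn.Module):
        # Hypothetical stand-in: emits random features and an all-True mask.
        def __init__(self, embed_dim: int, num_frames: int):
            super().__init__()
            self.embed_dim = embed_dim
            self.num_frames = num_frames

        def forward(self, *args, **kwargs):
            # Infer the batch size from a leading list argument (captions);
            # default to 1 for the per-item audio encoder calls.
            batch_size = len(args[0]) if args and isinstance(args[0], list) else 1
            output = torch.randn(batch_size, self.num_frames, self.embed_dim)
            mask = torch.ones(batch_size, self.num_frames, dtype=torch.bool)
            return {"output": output, "mask": mask}

    embed_dim = 16
    batch_content = [
        {"audio": torch.randn(16000), "caption": "remove the background noise"},
        {"audio": torch.randn(24000), "caption": "add reverb to the vocals"},
    ]
    batch_task = ["audio_editing", "audio_editing"]

    for cls in (ContentEncoder, BatchedContentEncoder):
        encoder = cls(
            embed_dim=embed_dim,
            text_encoder=_StubEncoder(embed_dim, num_frames=8),
            audio_encoder=_StubEncoder(embed_dim, num_frames=32),
        )
        outputs = encoder.encode_content(batch_content, batch_task, device="cpu")
        print(cls.__name__, {k: tuple(v.shape) for k, v in outputs.items()})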