Upload DAC

- config.json +16 -0
- model.py +212 -212
- model.safetensors +3 -0
config.json
ADDED

{
  "architectures": [
    "DAC"
  ],
  "auto_map": {
    "AutoConfig": "model.DACConfig",
    "AutoModel": "model.DAC"
  },
  "decoding_chunk_rate": 0.1,
  "decoding_overlap_rate": 0.1,
  "encoding_chunk_size_in_sec": 1,
  "model_type": "dac",
  "model_type_by_sampling_freq": "16khz",
  "torch_dtype": "float32",
  "transformers_version": "4.44.0"
}
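The `auto_map` entries wire this repo into the Transformers custom-code loading path: `AutoConfig` resolves to `DACConfig` and `AutoModel` to `DAC`, both defined in `model.py` below, so loading requires `trust_remote_code=True`. A minimal loading sketch; the repo id here is a placeholder, not the actual one for this commit:

from transformers import AutoModel

# hypothetical repo id; substitute the real "<user>/<repo>" this commit belongs to
model = AutoModel.from_pretrained('<user>/dac-16khz', trust_remote_code=True)
model.eval()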
model.py
CHANGED

(The removed and added contents of this file are identical in the diff.)

from typing import Union

import numpy as np
import torch
import torchaudio
import torch.nn as nn
import torchaudio.transforms as transforms
from transformers import PretrainedConfig, PreTrainedModel

import dac
from audiotools import AudioSignal

from utils import freeze


class DACConfig(PretrainedConfig):
    model_type = 'dac'

    def __init__(self,
                 model_type_by_sampling_freq: str = '44khz',
                 encoding_chunk_size_in_sec: int = 1,
                 decoding_chunk_rate: float = 0.1,
                 decoding_overlap_rate: float = 0.1,
                 **kwargs):
        """
        Initializes the model configuration.

        Args:
            model_type_by_sampling_freq (str, optional): The model type based on the sampling frequency. Defaults to '44khz'. Choose among ['44khz', '24khz', '16khz'].
            encoding_chunk_size_in_sec (int, optional): The size of the encoding chunk in seconds. Defaults to 1.
            decoding_chunk_rate (float, optional): The decoding chunk rate. Must be between 0 and 1. Defaults to 0.1.
            decoding_overlap_rate (float, optional): The decoding overlap rate. Must be between 0 and 1. Defaults to 0.1.
            **kwargs: Additional keyword arguments.

        Raises:
            AssertionError: If model_type_by_sampling_freq is not one of ['44khz', '24khz', '16khz'].
            AssertionError: If decoding_chunk_rate is not between 0 and 1.
            AssertionError: If decoding_overlap_rate is not between 0 and 1.
        """
        super().__init__(**kwargs)
        self.model_type_by_sampling_freq = model_type_by_sampling_freq
        self.encoding_chunk_size_in_sec = encoding_chunk_size_in_sec
        self.decoding_chunk_rate = decoding_chunk_rate
        self.decoding_overlap_rate = decoding_overlap_rate

        assert model_type_by_sampling_freq.lower() in ['44khz', '24khz', '16khz']
        assert 0 < decoding_chunk_rate <= 1.0, '`decoding_chunk_rate` must be between 0 and 1.'
        assert 0 <= decoding_overlap_rate < 1.0, '`decoding_overlap_rate` must be between 0 and 1.'


class DAC(PreTrainedModel):
    config_class = DACConfig

    def __init__(self, config):
        super().__init__(config)

        self.model_type_by_sampling_freq = config.model_type_by_sampling_freq.lower()
        self.model_type_by_sampling_freq_int = {'44khz': 44100, '24khz': 24000, '16khz': 16000}[self.model_type_by_sampling_freq]
        self.encoding_chunk_size_in_sec = config.encoding_chunk_size_in_sec
        self.decoding_chunk_rate = config.decoding_chunk_rate
        self.decoding_overlap_rate = config.decoding_overlap_rate

        dac_path = dac.utils.download(model_type=self.model_type_by_sampling_freq)
        self.dac = dac.DAC.load(dac_path)
        self.dac.eval()
        freeze(self.dac)

        self.downsampling_rate = int(np.prod(self.dac.encoder_rates))  # 512

    def load_audio(self, filename: str):
        waveform, sample_rate = torchaudio.load(filename)  # waveform: (n_channels, length); sample_rate: const.
        return waveform, sample_rate

    def resample_audio(self, waveform: torch.FloatTensor, orig_sr: int, target_sr: int):
        """
        - sr: sampling rate
        - waveform: (n_channels, length)
        """
        if orig_sr == target_sr:
            return waveform

        converter = transforms.Resample(orig_freq=orig_sr, new_freq=target_sr)
        waveform = converter(waveform)  # (n_channels, new_length)
        return waveform  # (n_channels, new_length)

    def to_mono_channel(self, waveform: torch.FloatTensor):
        """
        - waveform: (n_channels, length)
        """
        n_channels = waveform.shape[0]
        if n_channels > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)  # (1, length)
        return waveform  # (1, length)

    @torch.no_grad()
    def encode(self, audio_fname: str):
        self.eval()

        waveform, sr = self.load_audio(audio_fname)
        waveform = self.resample_audio(waveform, orig_sr=sr, target_sr=self.model_type_by_sampling_freq_int)
        sr = self.model_type_by_sampling_freq_int
        waveform = self.to_mono_channel(waveform)  # DAC accepts a mono channel only.

        zq, s = self._chunk_encoding(waveform, sr)
        return zq, s

    def _chunk_encoding(self, waveform: torch.FloatTensor, sr: int):
        # TODO: can this be parallelized?
        """
        waveform: (c, l)
        """
        x = waveform  # brief varname
        x = x.unsqueeze(1)  # (b, 1, l); add a null batch dim
        chunk_size = int(self.encoding_chunk_size_in_sec * sr)

        # Round `chunk_size` down to a multiple of the DAC hop length to prevent any padding in
        # `dac.preprocess`, which would cause audible gaps between chunks in the resulting audio.
        remainder = chunk_size % self.dac.hop_length
        chunk_size = chunk_size - remainder

        # process
        zq_list, s_list = [], []
        audio_length = x.shape[-1]
        for start in range(0, audio_length, chunk_size):
            end = start + chunk_size
            chunk = x[:, :, start:end]
            chunk = self.dac.preprocess(chunk, sr)
            zq, s, _, _, _ = self.dac.encode(chunk.to(self.device))
            zq = zq.cpu()
            s = s.cpu()
            r"""
            "zq" : Tensor[B x D x T]
                Quantized continuous representation of input
                = summation of all the residual quantized vectors across every RVQ level
                = E(x) = z = \sum_n^N zq_n, where N is the number of codebooks
            "s" : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
                * first element in the N dimension = first RVQ level
            """
            zq_list.append(zq)
            s_list.append(s)
            torch.cuda.empty_cache()

        zq = torch.cat(zq_list, dim=2).float()  # (1, d, length)
        s = torch.cat(s_list, dim=2).long()  # (1, n_rvq, length)

        return zq, s

    @torch.no_grad()
    def decode(self, *, zq: Union[torch.FloatTensor, None] = None, s: Union[torch.IntTensor, None] = None):
        """
        zq: (b, d, length)
        s: (b, n_rvq, length)
        At least one of `zq` and `s` must be provided; if both are given, `s` takes precedence.
        """
        if zq is None and s is None:
            raise ValueError('One of `zq` and `s` must be provided.')
        self.eval()

        if zq is not None:
            waveform = self._chunk_decoding(zq)  # (b, 1, length); the output always has a mono channel.
        if s is not None:
            zq = self.code_to_zq(s)
            waveform = self._chunk_decoding(zq)  # (b, 1, length); the output always has a mono channel.

        return waveform

    def _chunk_decoding(self, zq: torch.FloatTensor):
        """
        zq: (b, d, length)
        """
        length = zq.shape[-1]
        chunk_size = int(self.decoding_chunk_rate * length)
        overlap_size = round(self.decoding_overlap_rate * chunk_size)  # overlap size in terms of token length
        overlap_size_in_data_space = round(overlap_size * self.downsampling_rate)
        waveform_concat = None
        for start in range(0, length, chunk_size - overlap_size):
            end = start + chunk_size
            chunk = zq[:, :, start:end]  # (b, d, chunk_size)
            waveform = self.dac.decode(chunk.to(self.device))  # (b, 1, chunk_size*self.downsampling_rate)
            waveform = waveform.cpu()

            if waveform_concat is None:
                waveform_concat = waveform.clone()
            else:
                if self.decoding_overlap_rate != 0.:
                    prev_x = waveform_concat[:, :, :-overlap_size_in_data_space]
                    rest_of_new_x = waveform[:, :, overlap_size_in_data_space:]
                    overlap_x_from_prev_x = waveform_concat[:, :, -overlap_size_in_data_space:]  # (b, 1, overlap_size_in_data_space)
                    overlap_x_from_new_x = waveform[:, :, :overlap_size_in_data_space]  # (b, 1, overlap_size_in_data_space)
                    overlap = (overlap_x_from_prev_x + overlap_x_from_new_x) / 2  # take the mean; there may be a better strategy, but this works well enough.
                    waveform_concat = torch.cat((prev_x, overlap, rest_of_new_x), dim=-1)  # (b, 1, ..)
                else:
                    prev_x = waveform_concat
                    rest_of_new_x = waveform
                    waveform_concat = torch.cat((prev_x, rest_of_new_x), dim=-1)  # (b, 1, ..)
        return waveform_concat  # (b, 1, length)

    def code_to_zq(self, s: torch.IntTensor):
        """
        s: (b, n_rvq, length)
        """
        zq, _, _ = self.dac.quantizer.from_codes(s.to(self.device))  # zq: (b, d, length)
        zq = zq.cpu()
        return zq

    def save_tensor(self, tensor: torch.Tensor, fname: str) -> None:
        torch.save(tensor.cpu(), fname)

    def load_tensor(self, fname: str):
        return torch.load(fname)

    def waveform_to_audiofile(self, waveform: torch.FloatTensor, fname: str) -> None:
        AudioSignal(waveform, sample_rate=self.model_type_by_sampling_freq_int).write(fname)
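For context, a short end-to-end sketch of how the class above is meant to be used once loaded (the audio file names are placeholders):

# assumes `model` was loaded via AutoModel as shown after config.json above
zq, s = model.encode('input.wav')    # zq: (1, d, T) continuous codes; s: (1, n_rvq, T) discrete codes
waveform = model.decode(s=s)         # (1, 1, length); reconstructed mono waveform
model.waveform_to_audiofile(waveform, 'reconstruction.wav')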
model.safetensors
ADDED

version https://git-lfs.github.com/spec/v1
oid sha256:d4eedd71256d763a5e9806e32e96bb33d7daff6dc10acbaab5403e4057a45771
size 296740304
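This file is a Git LFS pointer: the actual weights (~297 MB) are stored out of band, and the `oid` is the SHA-256 of the real file. A minimal integrity check, assuming the full weights file has been downloaded in place of the pointer:

import hashlib

h = hashlib.sha256()
with open('model.safetensors', 'rb') as f:
    for block in iter(lambda: f.read(1 << 20), b''):  # hash in 1 MiB blocks
        h.update(block)
# oid from the LFS pointer above
assert h.hexdigest() == 'd4eedd71256d763a5e9806e32e96bb33d7daff6dc10acbaab5403e4057a45771'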