Upload folder using huggingface_hub
Browse files- .gradio/certificate.pem +31 -0
- app.py +133 -0
- configs/cleanmel_offline.yaml +67 -0
- configs/vocos_offline.yaml +44 -0
- model/__pycache__/cleanmel.cpython-310.pyc +0 -0
- model/__pycache__/stft.cpython-310.pyc +0 -0
- model/cleanmel.py +401 -0
- model/stft.py +154 -0
- model/vocos/__init__.py +1 -0
- model/vocos/__pycache__/__init__.cpython-310.pyc +0 -0
- model/vocos/__pycache__/__init__.cpython-312.pyc +0 -0
- model/vocos/__pycache__/__init__.cpython-39.pyc +0 -0
- model/vocos/__pycache__/dataset.cpython-310.pyc +0 -0
- model/vocos/__pycache__/discriminators.cpython-310.pyc +0 -0
- model/vocos/__pycache__/discriminators.cpython-39.pyc +0 -0
- model/vocos/__pycache__/experiment.cpython-310.pyc +0 -0
- model/vocos/__pycache__/experiment.cpython-312.pyc +0 -0
- model/vocos/__pycache__/experiment.cpython-39.pyc +0 -0
- model/vocos/__pycache__/feature_extractors.cpython-310.pyc +0 -0
- model/vocos/__pycache__/feature_extractors.cpython-39.pyc +0 -0
- model/vocos/__pycache__/heads.cpython-310.pyc +0 -0
- model/vocos/__pycache__/helpers.cpython-310.pyc +0 -0
- model/vocos/__pycache__/loss.cpython-310.pyc +0 -0
- model/vocos/__pycache__/models.cpython-310.pyc +0 -0
- model/vocos/__pycache__/modules.cpython-310.pyc +0 -0
- model/vocos/__pycache__/modules.cpython-39.pyc +0 -0
- model/vocos/__pycache__/pretrained.cpython-310.pyc +0 -0
- model/vocos/__pycache__/spectral_ops.cpython-310.pyc +0 -0
- model/vocos/dataset.py +93 -0
- model/vocos/discriminators.py +211 -0
- model/vocos/experiment.py +398 -0
- model/vocos/feature_extractors.py +170 -0
- model/vocos/heads.py +164 -0
- model/vocos/helpers.py +71 -0
- model/vocos/loss.py +114 -0
- model/vocos/models.py +118 -0
- model/vocos/modules.py +213 -0
- model/vocos/pretrained.py +162 -0
- model/vocos/spectral_ops.py +192 -0
.gradio/certificate.pem
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-----BEGIN CERTIFICATE-----
|
| 2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
| 3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
| 4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
| 5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
| 6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
| 7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
| 8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
| 9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
| 10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
| 11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
| 12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
| 13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
| 14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
| 15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
| 16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
| 17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
| 18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
| 19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
| 20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
| 21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
| 22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
| 23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
| 24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
| 25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
| 26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
| 27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
| 28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
| 29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
| 30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
| 31 |
+
-----END CERTIFICATE-----
|
app.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import spaces
|
| 3 |
+
import tempfile
|
| 4 |
+
import soundfile as sf
|
| 5 |
+
import gradio as gr
|
| 6 |
+
import librosa as lb
|
| 7 |
+
import yaml
|
| 8 |
+
import numpy as np
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
from model.cleanmel import CleanMel
|
| 11 |
+
from model.vocos.pretrained import Vocos
|
| 12 |
+
from model.stft import InputSTFT, TargetMel
|
| 13 |
+
|
| 14 |
+
DEVICE = torch.device("cuda:5")
|
| 15 |
+
|
| 16 |
+
def read_audio(file_path):
|
| 17 |
+
audio, sample_rate = sf.read(file_path)
|
| 18 |
+
if audio.ndim > 1:
|
| 19 |
+
audio = audio[:, 0]
|
| 20 |
+
if sample_rate != 16000:
|
| 21 |
+
audio = lb.resample(audio, orig_sr=sample_rate, target_sr=16000)
|
| 22 |
+
sample_rate = 16000
|
| 23 |
+
|
| 24 |
+
return torch.tensor(audio).float().squeeze().unsqueeze(0)
|
| 25 |
+
|
| 26 |
+
def stft(audio):
|
| 27 |
+
transform = InputSTFT(
|
| 28 |
+
n_fft=512,
|
| 29 |
+
n_win=512,
|
| 30 |
+
n_hop=128,
|
| 31 |
+
normalize=False,
|
| 32 |
+
center=True,
|
| 33 |
+
onesided=True,
|
| 34 |
+
online=False
|
| 35 |
+
).to(DEVICE)
|
| 36 |
+
return transform(audio)
|
| 37 |
+
|
| 38 |
+
def mel_transform(audio, X_norm):
|
| 39 |
+
transform = TargetMel(
|
| 40 |
+
sample_rate=16000,
|
| 41 |
+
n_fft=512,
|
| 42 |
+
n_win=512,
|
| 43 |
+
n_hop=128,
|
| 44 |
+
n_mels=80,
|
| 45 |
+
f_min=0,
|
| 46 |
+
f_max=8000,
|
| 47 |
+
power=2,
|
| 48 |
+
center=True,
|
| 49 |
+
normalize=False,
|
| 50 |
+
onesided=True,
|
| 51 |
+
mel_norm="slaney",
|
| 52 |
+
mel_scale="slaney",
|
| 53 |
+
librosa_mel=True,
|
| 54 |
+
online=False
|
| 55 |
+
).to(DEVICE)
|
| 56 |
+
return transform(audio, X_norm)
|
| 57 |
+
|
| 58 |
+
def load_cleanmel(model_name):
|
| 59 |
+
model_config = f"./configs/cleanmel_offline.yaml"
|
| 60 |
+
model_config = yaml.safe_load(open(model_config, "r"))["model"]["arch"]["init_args"]
|
| 61 |
+
cleanmel = CleanMel(**model_config)
|
| 62 |
+
cleanmel.load_state_dict(torch.load(f"./ckpts/CleanMel/{model_name}.ckpt"))
|
| 63 |
+
return cleanmel.eval()
|
| 64 |
+
|
| 65 |
+
def load_vocos(model_name):
|
| 66 |
+
vocos = Vocos.from_hparams(config_path="./configs/vocos_offline.yaml")
|
| 67 |
+
vocos = Vocos.from_pretrained(None, model_path=f"./ckpts/Vocos/{model_name}.pt", model=vocos)
|
| 68 |
+
return vocos.eval()
|
| 69 |
+
|
| 70 |
+
def get_mrm_pred(Y_hat, x, X_norm):
|
| 71 |
+
X_noisy = mel_transform(x, X_norm)
|
| 72 |
+
Y_hat = Y_hat.squeeze()
|
| 73 |
+
Y_hat = torch.square(Y_hat * (torch.sqrt(X_noisy) + 1e-10))
|
| 74 |
+
return Y_hat
|
| 75 |
+
|
| 76 |
+
def safe_log(x):
|
| 77 |
+
return torch.log(torch.clip(x, min=1e-5))
|
| 78 |
+
|
| 79 |
+
@spaces.GPU
|
| 80 |
+
@torch.inference_mode()
|
| 81 |
+
def enhance_cleanmel_L_mask(audio_path):
|
| 82 |
+
model = load_cleanmel("offline_CleanMel_L_mask").to(DEVICE)
|
| 83 |
+
vocos = load_vocos("vocos_offline").to(DEVICE)
|
| 84 |
+
x = read_audio(audio_path).to(DEVICE)
|
| 85 |
+
X, X_norm = stft(x)
|
| 86 |
+
Y_hat = model(X)
|
| 87 |
+
MRM_hat = torch.sigmoid(Y_hat)
|
| 88 |
+
Y_hat = get_mrm_pred(MRM_hat, x, X_norm)
|
| 89 |
+
logMel_hat = safe_log(Y_hat)
|
| 90 |
+
y_hat = vocos(logMel_hat, X_norm)
|
| 91 |
+
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
|
| 92 |
+
sf.write(tmp_file.name, y_hat.squeeze().cpu().numpy(), 16000)
|
| 93 |
+
with tempfile.NamedTemporaryFile(suffix='.npy', delete=False) as tmp_logmel_np_file:
|
| 94 |
+
np.save(tmp_logmel_np_file.name, logMel_hat.squeeze().cpu().numpy())
|
| 95 |
+
logMel_img = logMel_hat.squeeze().cpu().numpy()[::-1, :]
|
| 96 |
+
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_logmel_img:
|
| 97 |
+
# give a plt figure size according to the logMel shape
|
| 98 |
+
plt.figure(figsize=(logMel_img.shape[1] / 100, logMel_img.shape[0] / 50))
|
| 99 |
+
plt.clf()
|
| 100 |
+
plt.imshow(logMel_img, vmin=-11, cmap="jet")
|
| 101 |
+
plt.tight_layout()
|
| 102 |
+
plt.ylabel("Mel bands")
|
| 103 |
+
plt.xlabel("Time (second)")
|
| 104 |
+
plt.yticks([0, 80], [80, 0])
|
| 105 |
+
dur = x.shape[-1] / 16000
|
| 106 |
+
xticks = [int(x) for x in np.linspace(0, logMel_img.shape[-1], 11)]
|
| 107 |
+
xticks_str = ["{:.1f}".format(x) for x in np.linspace(0, dur, 11)]
|
| 108 |
+
plt.xticks(xticks, xticks_str)
|
| 109 |
+
plt.savefig(tmp_logmel_img.name)
|
| 110 |
+
|
| 111 |
+
return tmp_file.name, tmp_logmel_img.name, tmp_logmel_np_file.name
|
| 112 |
+
|
| 113 |
+
if __name__ == "__main__":
|
| 114 |
+
demo = gr.Blocks()
|
| 115 |
+
with gr.Blocks(title="CleanMel Demo") as demo:
|
| 116 |
+
gr.Markdown("## CleanMel Demo")
|
| 117 |
+
gr.Markdown("This demo showcases the CleanMel model for speech enhancement.")
|
| 118 |
+
|
| 119 |
+
with gr.Row():
|
| 120 |
+
audio_input = gr.Audio(label="Input Audio", type="filepath", sources="upload")
|
| 121 |
+
enhance_button = gr.Button("Enhance Audio")
|
| 122 |
+
|
| 123 |
+
output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
|
| 124 |
+
output_mel = gr.Image(label="Output LogMel Spectrogram", type="filepath", visible=True)
|
| 125 |
+
output_np = gr.File(label="Enhanced LogMel Spec. (.npy)", type="filepath")
|
| 126 |
+
|
| 127 |
+
enhance_button.click(
|
| 128 |
+
enhance_cleanmel_L_mask,
|
| 129 |
+
inputs=audio_input,
|
| 130 |
+
outputs=[output_audio, output_mel, output_np]
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
demo.launch(debug=False, share=True)
|
configs/cleanmel_offline.yaml
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
seed_everything: 2
|
| 2 |
+
|
| 3 |
+
trainer:
|
| 4 |
+
gradient_clip_val: 10
|
| 5 |
+
gradient_clip_algorithm: norm
|
| 6 |
+
devices: null
|
| 7 |
+
accelerator: gpu
|
| 8 |
+
strategy: ddp_find_unused_parameters_false
|
| 9 |
+
sync_batchnorm: false
|
| 10 |
+
precision: 32
|
| 11 |
+
num_sanity_val_steps: 3
|
| 12 |
+
deterministic: true
|
| 13 |
+
max_epochs: 100
|
| 14 |
+
log_every_n_steps: 40
|
| 15 |
+
|
| 16 |
+
model:
|
| 17 |
+
arch:
|
| 18 |
+
class_path: model.arch.cleanmel.CleanMel
|
| 19 |
+
init_args:
|
| 20 |
+
dim_input: 2
|
| 21 |
+
dim_output: 1
|
| 22 |
+
n_layers: 16
|
| 23 |
+
dim_hidden: 144
|
| 24 |
+
layer_linear_freq: 1
|
| 25 |
+
f_kernel_size: 5
|
| 26 |
+
f_conv_groups: 8
|
| 27 |
+
n_freqs: 257
|
| 28 |
+
n_mels: 80
|
| 29 |
+
mamba_state: 16
|
| 30 |
+
mamba_conv_kernel: 4
|
| 31 |
+
online: false
|
| 32 |
+
sr: 16000
|
| 33 |
+
n_fft: 512
|
| 34 |
+
input_stft:
|
| 35 |
+
class_path: model.io.stft.InputSTFT
|
| 36 |
+
init_args:
|
| 37 |
+
n_fft: 512
|
| 38 |
+
n_win: 512
|
| 39 |
+
n_hop: 128
|
| 40 |
+
center: true
|
| 41 |
+
normalize: false
|
| 42 |
+
onesided: true
|
| 43 |
+
online: false
|
| 44 |
+
target_stft:
|
| 45 |
+
class_path: model.io.stft.TargetMel
|
| 46 |
+
init_args:
|
| 47 |
+
sample_rate: 16000
|
| 48 |
+
n_fft: 512
|
| 49 |
+
n_win: 512
|
| 50 |
+
n_hop: 128
|
| 51 |
+
n_mels: 80
|
| 52 |
+
f_min: 0
|
| 53 |
+
f_max: 8000
|
| 54 |
+
power: 2
|
| 55 |
+
center: true
|
| 56 |
+
normalize: false
|
| 57 |
+
onesided: true
|
| 58 |
+
mel_norm: "slaney"
|
| 59 |
+
mel_scale: "slaney"
|
| 60 |
+
librosa_mel: true
|
| 61 |
+
online: false
|
| 62 |
+
|
| 63 |
+
optimizer: [AdamW, { lr: 0.001, weight_decay: 0.001}]
|
| 64 |
+
lr_scheduler: [ExponentialLR, { gamma: 0.99 }]
|
| 65 |
+
exp_name: exp
|
| 66 |
+
metrics: [DNSMOS]
|
| 67 |
+
log_eps: 1e-5
|
configs/vocos_offline.yaml
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
feature_extractor:
|
| 2 |
+
class_path: model.vocos.feature_extractors.MelSpectrogramFeatures
|
| 3 |
+
init_args:
|
| 4 |
+
sample_rate: 16000
|
| 5 |
+
n_fft: 512
|
| 6 |
+
n_win: 512
|
| 7 |
+
n_hop: 128
|
| 8 |
+
n_mels: 80
|
| 9 |
+
f_min: 0
|
| 10 |
+
f_max: 8000
|
| 11 |
+
power: 2
|
| 12 |
+
center: true
|
| 13 |
+
normalize: false
|
| 14 |
+
onesided: true
|
| 15 |
+
mel_norm: slaney
|
| 16 |
+
mel_scale: slaney
|
| 17 |
+
librosa_mel: true
|
| 18 |
+
clip_val: 0.00001
|
| 19 |
+
backbone:
|
| 20 |
+
class_path: model.vocos.models.VocosBackbone
|
| 21 |
+
init_args:
|
| 22 |
+
input_channels: 80
|
| 23 |
+
dim: 512
|
| 24 |
+
intermediate_dim: 1536
|
| 25 |
+
num_layers: 8
|
| 26 |
+
layer_scale_init_value: null
|
| 27 |
+
adanorm_num_embeddings: null
|
| 28 |
+
head:
|
| 29 |
+
class_path: model.vocos.heads.ISTFTHead
|
| 30 |
+
init_args:
|
| 31 |
+
dim: 512
|
| 32 |
+
n_fft: 512
|
| 33 |
+
hop_length: 128
|
| 34 |
+
padding: center
|
| 35 |
+
sample_rate: 16000
|
| 36 |
+
initial_learning_rate: 0.0005
|
| 37 |
+
num_warmup_steps: 0
|
| 38 |
+
mel_loss_coeff: 45.0
|
| 39 |
+
mrd_loss_coeff: 0.1
|
| 40 |
+
pretrain_mel_steps: 0
|
| 41 |
+
decay_mel_coeff: false
|
| 42 |
+
evaluate_utmos: true
|
| 43 |
+
evaluate_pesq: true
|
| 44 |
+
evaluate_periodicty: true
|
model/__pycache__/cleanmel.cpython-310.pyc
ADDED
|
Binary file (12 kB). View file
|
|
|
model/__pycache__/stft.cpython-310.pyc
ADDED
|
Binary file (4.27 kB). View file
|
|
|
model/cleanmel.py
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import pytorch_lightning
|
| 8 |
+
import librosa
|
| 9 |
+
|
| 10 |
+
from torch import Tensor
|
| 11 |
+
from torch.nn import Parameter, init
|
| 12 |
+
from torch.nn.common_types import _size_1_t
|
| 13 |
+
|
| 14 |
+
from mamba_ssm import Mamba
|
| 15 |
+
from mamba_ssm.utils.generation import InferenceParams
|
| 16 |
+
|
| 17 |
+
class LinearGroup(nn.Module):
|
| 18 |
+
|
| 19 |
+
def __init__(self, in_features: int, out_features: int, num_groups: int, bias: bool = True) -> None:
|
| 20 |
+
super(LinearGroup, self).__init__()
|
| 21 |
+
self.in_features = in_features
|
| 22 |
+
self.out_features = out_features
|
| 23 |
+
self.num_groups = num_groups
|
| 24 |
+
self.weight = Parameter(torch.empty((num_groups, out_features, in_features)))
|
| 25 |
+
if bias:
|
| 26 |
+
self.bias = Parameter(torch.empty(num_groups, out_features))
|
| 27 |
+
else:
|
| 28 |
+
self.register_parameter('bias', None)
|
| 29 |
+
self.reset_parameters()
|
| 30 |
+
|
| 31 |
+
def reset_parameters(self) -> None:
|
| 32 |
+
# same as linear
|
| 33 |
+
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
| 34 |
+
if self.bias is not None:
|
| 35 |
+
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
|
| 36 |
+
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
| 37 |
+
init.uniform_(self.bias, -bound, bound)
|
| 38 |
+
|
| 39 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 40 |
+
"""shape [..., group, feature]"""
|
| 41 |
+
x = torch.einsum("...gh,gkh->...gk", x, self.weight)
|
| 42 |
+
if self.bias is not None:
|
| 43 |
+
x = x + self.bias
|
| 44 |
+
return x
|
| 45 |
+
|
| 46 |
+
def extra_repr(self) -> str:
|
| 47 |
+
return f"{self.in_features}, {self.out_features}, num_groups={self.num_groups}, bias={True if self.bias is not None else False}"
|
| 48 |
+
|
| 49 |
+
class LayerNorm(nn.LayerNorm):
|
| 50 |
+
|
| 51 |
+
def __init__(self, seq_last: bool, **kwargs) -> None:
|
| 52 |
+
"""
|
| 53 |
+
Arg s:
|
| 54 |
+
seq_last (bool): whether the sequence dim is the last dim
|
| 55 |
+
"""
|
| 56 |
+
super().__init__(**kwargs)
|
| 57 |
+
self.seq_last = seq_last
|
| 58 |
+
|
| 59 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 60 |
+
if self.seq_last:
|
| 61 |
+
input = input.transpose(-1, 1) # [B, H, Seq] -> [B, Seq, H], or [B,H,w,h] -> [B,h,w,H]
|
| 62 |
+
o = super().forward(input)
|
| 63 |
+
if self.seq_last:
|
| 64 |
+
o = o.transpose(-1, 1)
|
| 65 |
+
return o
|
| 66 |
+
|
| 67 |
+
class CausalConv1d(nn.Conv1d):
|
| 68 |
+
|
| 69 |
+
def __init__(
|
| 70 |
+
self,
|
| 71 |
+
in_channels: int,
|
| 72 |
+
out_channels: int,
|
| 73 |
+
kernel_size: _size_1_t,
|
| 74 |
+
stride: _size_1_t = 1,
|
| 75 |
+
padding: _size_1_t | str = 0,
|
| 76 |
+
dilation: _size_1_t = 1,
|
| 77 |
+
groups: int = 1,
|
| 78 |
+
bias: bool = True,
|
| 79 |
+
padding_mode: str = 'zeros',
|
| 80 |
+
device=None,
|
| 81 |
+
dtype=None,
|
| 82 |
+
look_ahead: int = 0,
|
| 83 |
+
) -> None:
|
| 84 |
+
super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype)
|
| 85 |
+
self.look_ahead = look_ahead
|
| 86 |
+
assert look_ahead <= self.kernel_size[0] - 1, (look_ahead, self.kernel_size)
|
| 87 |
+
|
| 88 |
+
def forward(self, x: Tensor, state: Dict[int, Any] = None) -> Tensor:
|
| 89 |
+
# x [B,H,T]
|
| 90 |
+
B, H, T = x.shape
|
| 91 |
+
if state is None or id(self) not in state:
|
| 92 |
+
x = F.pad(x, pad=(self.kernel_size[0] - 1 - self.look_ahead, self.look_ahead))
|
| 93 |
+
else:
|
| 94 |
+
x = torch.concat([state[id(self)], x], dim=-1)
|
| 95 |
+
if state is not None:
|
| 96 |
+
state[id(self)] = x[..., -self.kernel_size + 1:]
|
| 97 |
+
x = super().forward(x)
|
| 98 |
+
return x
|
| 99 |
+
|
| 100 |
+
class CleanMelLayer(nn.Module):
|
| 101 |
+
|
| 102 |
+
def __init__(
|
| 103 |
+
self,
|
| 104 |
+
dim_hidden: int,
|
| 105 |
+
dim_squeeze: int,
|
| 106 |
+
n_freqs: int,
|
| 107 |
+
dropout: Tuple[float, float, float] = (0, 0, 0),
|
| 108 |
+
f_kernel_size: int = 5,
|
| 109 |
+
f_conv_groups: int = 8,
|
| 110 |
+
padding: str = 'zeros',
|
| 111 |
+
full: nn.Module = None,
|
| 112 |
+
mamba_state: int = None,
|
| 113 |
+
mamba_conv_kernel: int = None,
|
| 114 |
+
online: bool = False,
|
| 115 |
+
) -> None:
|
| 116 |
+
super().__init__()
|
| 117 |
+
self.online = online
|
| 118 |
+
# cross-band block
|
| 119 |
+
# frequency-convolutional module
|
| 120 |
+
self.fconv1 = nn.ModuleList([
|
| 121 |
+
LayerNorm(seq_last=True, normalized_shape=dim_hidden),
|
| 122 |
+
nn.Conv1d(in_channels=dim_hidden, out_channels=dim_hidden, kernel_size=f_kernel_size, groups=f_conv_groups, padding='same', padding_mode=padding),
|
| 123 |
+
nn.PReLU(dim_hidden),
|
| 124 |
+
])
|
| 125 |
+
# full-band linear module
|
| 126 |
+
self.norm_full = LayerNorm(seq_last=False, normalized_shape=dim_hidden)
|
| 127 |
+
self.full_share = False if full == None else True
|
| 128 |
+
self.squeeze = nn.Sequential(nn.Conv1d(in_channels=dim_hidden, out_channels=dim_squeeze, kernel_size=1), nn.SiLU())
|
| 129 |
+
self.dropout_full = nn.Dropout2d(dropout[2]) if dropout[2] > 0 else None
|
| 130 |
+
self.full = LinearGroup(n_freqs, n_freqs, num_groups=dim_squeeze) if full == None else full
|
| 131 |
+
self.unsqueeze = nn.Sequential(nn.Conv1d(in_channels=dim_squeeze, out_channels=dim_hidden, kernel_size=1), nn.SiLU())
|
| 132 |
+
# frequency-convolutional module
|
| 133 |
+
self.fconv2 = nn.ModuleList([
|
| 134 |
+
LayerNorm(seq_last=True, normalized_shape=dim_hidden),
|
| 135 |
+
nn.Conv1d(in_channels=dim_hidden, out_channels=dim_hidden, kernel_size=f_kernel_size, groups=f_conv_groups, padding='same', padding_mode=padding),
|
| 136 |
+
nn.PReLU(dim_hidden),
|
| 137 |
+
])
|
| 138 |
+
|
| 139 |
+
# narrow-band block
|
| 140 |
+
self.norm_mamba = LayerNorm(seq_last=False, normalized_shape=dim_hidden)
|
| 141 |
+
if online:
|
| 142 |
+
self.mamba = Mamba(d_model=dim_hidden, d_state=mamba_state, d_conv=mamba_conv_kernel, layer_idx=0)
|
| 143 |
+
else:
|
| 144 |
+
self.mamba = nn.ModuleList([
|
| 145 |
+
Mamba(d_model=dim_hidden, d_state=mamba_state, d_conv=mamba_conv_kernel, layer_idx=0),
|
| 146 |
+
Mamba(d_model=dim_hidden, d_state=mamba_state, d_conv=mamba_conv_kernel, layer_idx=1),
|
| 147 |
+
])
|
| 148 |
+
|
| 149 |
+
self.dropout_mamba = nn.Dropout(dropout[0])
|
| 150 |
+
|
| 151 |
+
def forward(self, x: Tensor, inference: bool = False) -> Tensor:
|
| 152 |
+
x = x + self._fconv(self.fconv1, x)
|
| 153 |
+
x = x + self._full(x)
|
| 154 |
+
x = x + self._fconv(self.fconv2, x)
|
| 155 |
+
if self.online:
|
| 156 |
+
x = x + self._mamba(x, self.mamba, self.norm_mamba, self.dropout_mamba, inference)
|
| 157 |
+
else:
|
| 158 |
+
x_fw = x + self._mamba(x, self.mamba[0], self.norm_mamba, self.dropout_mamba, inference)
|
| 159 |
+
x_bw = x.flip(dims=[2]) + self._mamba(x.flip(dims=[2]), self.mamba[1], self.norm_mamba, self.dropout_mamba, inference)
|
| 160 |
+
x = (x_fw + x_bw.flip(dims=[2])) / 2
|
| 161 |
+
return x
|
| 162 |
+
|
| 163 |
+
def _mamba(self, x: Tensor, mamba: Mamba, norm: nn.Module, dropout: nn.Module, inference: bool = False):
|
| 164 |
+
B, F, T, H = x.shape
|
| 165 |
+
x = norm(x)
|
| 166 |
+
x = x.reshape(B * F, T, H)
|
| 167 |
+
if inference:
|
| 168 |
+
inference_params = InferenceParams(T, B * F)
|
| 169 |
+
xs = []
|
| 170 |
+
for i in range(T):
|
| 171 |
+
inference_params.seqlen_offset = i
|
| 172 |
+
xi = mamba.forward(x[:, [i], :], inference_params)
|
| 173 |
+
xs.append(xi)
|
| 174 |
+
x = torch.concat(xs, dim=1)
|
| 175 |
+
else:
|
| 176 |
+
x = mamba.forward(x)
|
| 177 |
+
x = x.reshape(B, F, T, H)
|
| 178 |
+
return dropout(x)
|
| 179 |
+
|
| 180 |
+
def _fconv(self, ml: nn.ModuleList, x: Tensor) -> Tensor:
|
| 181 |
+
B, F, T, H = x.shape
|
| 182 |
+
x = x.permute(0, 2, 3, 1) # [B,T,H,F]
|
| 183 |
+
x = x.reshape(B * T, H, F)
|
| 184 |
+
for m in ml:
|
| 185 |
+
x = m(x)
|
| 186 |
+
x = x.reshape(B, T, H, F)
|
| 187 |
+
x = x.permute(0, 3, 1, 2) # [B,F,T,H]
|
| 188 |
+
return x
|
| 189 |
+
|
| 190 |
+
def _full(self, x: Tensor) -> Tensor:
|
| 191 |
+
B, F, T, H = x.shape
|
| 192 |
+
x = self.norm_full(x)
|
| 193 |
+
x = x.permute(0, 2, 3, 1) # [B,T,H,F]
|
| 194 |
+
x = x.reshape(B * T, H, F)
|
| 195 |
+
x = self.squeeze(x) # [B*T,H',F]
|
| 196 |
+
if self.dropout_full:
|
| 197 |
+
x = x.reshape(B, T, -1, F)
|
| 198 |
+
x = x.transpose(1, 3) # [B,F,H',T]
|
| 199 |
+
x = self.dropout_full(x) # dropout some frequencies in one utterance
|
| 200 |
+
x = x.transpose(1, 3) # [B,T,H',F]
|
| 201 |
+
x = x.reshape(B * T, -1, F)
|
| 202 |
+
x = self.full(x) # [B*T,H',F]
|
| 203 |
+
x = self.unsqueeze(x) # [B*T,H,F]
|
| 204 |
+
x = x.reshape(B, T, H, F)
|
| 205 |
+
x = x.permute(0, 3, 1, 2) # [B,F,T,H]
|
| 206 |
+
return x
|
| 207 |
+
|
| 208 |
+
def extra_repr(self) -> str:
|
| 209 |
+
return f"full_share={self.full_share}"
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class CleanMel(nn.Module):
|
| 213 |
+
|
| 214 |
+
def __init__(
|
| 215 |
+
self,
|
| 216 |
+
dim_input: int, # the input dim for each time-frequency point
|
| 217 |
+
dim_output: int, # the output dim for each time-frequency point
|
| 218 |
+
n_layers: int,
|
| 219 |
+
n_freqs: int,
|
| 220 |
+
n_mels: int = 80,
|
| 221 |
+
layer_linear_freq: int = 1,
|
| 222 |
+
encoder_kernel_size: int = 5,
|
| 223 |
+
dim_hidden: int = 192,
|
| 224 |
+
dropout: Tuple[float, float, float] = (0, 0, 0),
|
| 225 |
+
f_kernel_size: int = 5,
|
| 226 |
+
f_conv_groups: int = 8,
|
| 227 |
+
padding: str = 'zeros',
|
| 228 |
+
mamba_state: int = 16,
|
| 229 |
+
mamba_conv_kernel: int = 4,
|
| 230 |
+
online: bool = True,
|
| 231 |
+
sr: int = 16000,
|
| 232 |
+
n_fft: int = 512,
|
| 233 |
+
):
|
| 234 |
+
super().__init__()
|
| 235 |
+
self.layer_linear_freq = layer_linear_freq
|
| 236 |
+
self.online = online
|
| 237 |
+
# encoder
|
| 238 |
+
self.encoder = CausalConv1d(in_channels=dim_input, out_channels=dim_hidden, kernel_size=encoder_kernel_size, look_ahead=0)
|
| 239 |
+
# cleanmel layers
|
| 240 |
+
full = None
|
| 241 |
+
layers = []
|
| 242 |
+
for l in range(n_layers):
|
| 243 |
+
layer = CleanMelLayer(
|
| 244 |
+
dim_hidden=dim_hidden,
|
| 245 |
+
dim_squeeze=8 if l < layer_linear_freq else dim_hidden,
|
| 246 |
+
n_freqs=n_freqs if l < layer_linear_freq else n_mels,
|
| 247 |
+
dropout=dropout,
|
| 248 |
+
f_kernel_size=f_kernel_size,
|
| 249 |
+
f_conv_groups=f_conv_groups,
|
| 250 |
+
padding=padding,
|
| 251 |
+
full=full if l > layer_linear_freq else None,
|
| 252 |
+
online=online,
|
| 253 |
+
mamba_conv_kernel=mamba_conv_kernel,
|
| 254 |
+
mamba_state=mamba_state,
|
| 255 |
+
)
|
| 256 |
+
if hasattr(layer, 'full'):
|
| 257 |
+
full = layer.full
|
| 258 |
+
layers.append(layer)
|
| 259 |
+
self.layers = nn.ModuleList(layers)
|
| 260 |
+
# Mel filterbank
|
| 261 |
+
linear2mel = librosa.filters.mel(**{"sr": sr, "n_fft": n_fft, "n_mels": n_mels})
|
| 262 |
+
self.register_buffer("linear2mel", torch.nn.Parameter(torch.tensor(linear2mel.T, dtype=torch.float32)))
|
| 263 |
+
# decoder
|
| 264 |
+
self.decoder = nn.Linear(in_features=dim_hidden, out_features=dim_output)
|
| 265 |
+
|
| 266 |
+
def forward(self, x: Tensor, inference: bool = False) -> Tensor:
|
| 267 |
+
# x: [Batch, Freq, Time, Feature]
|
| 268 |
+
B, F, T, H0 = x.shape
|
| 269 |
+
x = self.encoder(x.reshape(B * F, T, H0).permute(0, 2, 1)).permute(0, 2, 1)
|
| 270 |
+
|
| 271 |
+
H = x.shape[2]
|
| 272 |
+
x = x.reshape(B, F, T, H)
|
| 273 |
+
# First Cross-Narrow band block in Linear Frequency
|
| 274 |
+
for i in range(self.layer_linear_freq):
|
| 275 |
+
m = self.layers[i]
|
| 276 |
+
x = m(x, inference).contiguous()
|
| 277 |
+
|
| 278 |
+
# Mel-filterbank
|
| 279 |
+
x = torch.einsum("bfth,fm->bmth", x, self.linear2mel)
|
| 280 |
+
|
| 281 |
+
for i in range(self.layer_linear_freq, len(self.layers)):
|
| 282 |
+
m = self.layers[i]
|
| 283 |
+
x = m(x, inference).contiguous()
|
| 284 |
+
|
| 285 |
+
y = self.decoder(x).squeeze(-1)
|
| 286 |
+
return y.contiguous()
|
| 287 |
+
|
| 288 |
+
if __name__ == '__main__':
|
| 289 |
+
# a quick demo here for the CleanMel model
|
| 290 |
+
# input: wavs
|
| 291 |
+
# output: enhanced log-mel spectrogram
|
| 292 |
+
pytorch_lightning.seed_everything(1234)
|
| 293 |
+
import soundfile as sf
|
| 294 |
+
import matplotlib.pyplot as plt
|
| 295 |
+
import numpy as np
|
| 296 |
+
from model.io.stft import InputSTFT
|
| 297 |
+
from model.io.stft import TargetMel
|
| 298 |
+
from torch.utils.flop_counter import FlopCounterMode
|
| 299 |
+
|
| 300 |
+
online=False
|
| 301 |
+
# Define input STFT and target Mel
|
| 302 |
+
stft = InputSTFT(
|
| 303 |
+
n_fft=512,
|
| 304 |
+
n_win=512,
|
| 305 |
+
n_hop=128,
|
| 306 |
+
center=True,
|
| 307 |
+
normalize=False,
|
| 308 |
+
onesided=True,
|
| 309 |
+
online=online).to("cuda")
|
| 310 |
+
|
| 311 |
+
target_mel = TargetMel(
|
| 312 |
+
sample_rate=16000,
|
| 313 |
+
n_fft=512,
|
| 314 |
+
n_win=512,
|
| 315 |
+
n_hop=128,
|
| 316 |
+
n_mels=80,
|
| 317 |
+
f_min=0,
|
| 318 |
+
f_max=8000,
|
| 319 |
+
power=2,
|
| 320 |
+
center=True,
|
| 321 |
+
normalize=False,
|
| 322 |
+
onesided=True,
|
| 323 |
+
mel_norm="slaney",
|
| 324 |
+
mel_scale="slaney",
|
| 325 |
+
librosa_mel=True,
|
| 326 |
+
online=online).to("cuda")
|
| 327 |
+
|
| 328 |
+
def customize_soxnorm(wav, gain=-3, factor=None):
|
| 329 |
+
wav = np.clip(wav, a_max=1, a_min=-1)
|
| 330 |
+
if factor is None:
|
| 331 |
+
linear_gain = 10 ** (gain / 20)
|
| 332 |
+
factor = linear_gain / np.abs(wav).max()
|
| 333 |
+
wav = wav * factor
|
| 334 |
+
return wav, factor
|
| 335 |
+
else:
|
| 336 |
+
wav = wav * factor
|
| 337 |
+
return wav, None
|
| 338 |
+
|
| 339 |
+
# Noisy file path
|
| 340 |
+
wav = "./src/demos/noisy_CHIME-real_F05_442C020S_STR_REAL.wav"
|
| 341 |
+
wavname = wav.split("/")[-1].split(".")[0]
|
| 342 |
+
|
| 343 |
+
print(f"Processing {wav}")
|
| 344 |
+
noisy, fs = sf.read(wav)
|
| 345 |
+
dur = len(noisy) / fs
|
| 346 |
+
noisy, factor = customize_soxnorm(noisy, gain=-3)
|
| 347 |
+
noisy = torch.tensor(noisy).unsqueeze(0).float().to("cuda")
|
| 348 |
+
# vocos norm
|
| 349 |
+
x = stft(noisy)
|
| 350 |
+
# Load the model
|
| 351 |
+
hidden=96
|
| 352 |
+
depth=8
|
| 353 |
+
model = CleanMel(
|
| 354 |
+
dim_input=2,
|
| 355 |
+
dim_output=1,
|
| 356 |
+
n_layers=depth,
|
| 357 |
+
dim_hidden=hidden,
|
| 358 |
+
layer_linear_freq=1,
|
| 359 |
+
f_kernel_size=5,
|
| 360 |
+
f_conv_groups=8,
|
| 361 |
+
n_freqs=257,
|
| 362 |
+
mamba_state=16,
|
| 363 |
+
mamba_conv_kernel=4,
|
| 364 |
+
online=online,
|
| 365 |
+
sr=16000,
|
| 366 |
+
n_fft=512
|
| 367 |
+
).to("cuda")
|
| 368 |
+
|
| 369 |
+
# Load the pretrained model
|
| 370 |
+
state_dict = torch.load("./pretrained/CleanMel_S_L1.ckpt")
|
| 371 |
+
model.load_state_dict(state_dict)
|
| 372 |
+
|
| 373 |
+
model.eval()
|
| 374 |
+
with FlopCounterMode(model, display=False) as fcm:
|
| 375 |
+
y_hat = model(x, inference=False)
|
| 376 |
+
flops_forward_eval = fcm.get_total_flops()
|
| 377 |
+
params_eval = sum(param.numel() for param in model.parameters())
|
| 378 |
+
print(f"flops_forward={flops_forward_eval/1e9 / dur:.2f}G")
|
| 379 |
+
print(f"params={params_eval/1e6:.2f} M")
|
| 380 |
+
|
| 381 |
+
# y_hat is the enhanced log-mel spectrogram
|
| 382 |
+
y_hat = y_hat[0].cpu().detach().numpy()
|
| 383 |
+
|
| 384 |
+
# sanity check
|
| 385 |
+
if wavname == "noisy_CHIME-real_F05_442C020S_STR_REAL":
|
| 386 |
+
assert np.allclose(y_hat, np.load("./src/inference/check_CHIME-real_F05_442C020S_STR_REAL.npy"), atol=1e-5)
|
| 387 |
+
|
| 388 |
+
# plot the enhanced mel spectrogram
|
| 389 |
+
noisy_mel = target_mel(noisy)
|
| 390 |
+
noisy_mel = torch.log(noisy_mel.clamp(min=1e-5))[0].cpu().detach().numpy()
|
| 391 |
+
vmax = math.log(1e2)
|
| 392 |
+
vmin = math.log(1e-5)
|
| 393 |
+
plt.figure(figsize=(8, 4))
|
| 394 |
+
plt.subplot(2, 1, 1)
|
| 395 |
+
plt.imshow(noisy_mel, aspect='auto', origin='lower', cmap='jet', vmax=vmax, vmin=vmin)
|
| 396 |
+
plt.colorbar()
|
| 397 |
+
plt.subplot(2, 1, 2)
|
| 398 |
+
plt.imshow(y_hat, aspect='auto', origin='lower', cmap='jet', vmax=vmax, vmin=vmin)
|
| 399 |
+
plt.colorbar()
|
| 400 |
+
plt.tight_layout()
|
| 401 |
+
plt.savefig(f"./src/inference/{wavname}.png")
|
model/stft.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import librosa
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import random
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
from typing import Optional
|
| 7 |
+
from torchaudio.transforms import Spectrogram
|
| 8 |
+
from torchaudio.transforms import Spectrogram, MelScale
|
| 9 |
+
|
| 10 |
+
def soxnorm(wav: torch.Tensor, gain, factor=None):
|
| 11 |
+
"""sox norm, used in Vocos codes;
|
| 12 |
+
"""
|
| 13 |
+
wav = torch.clip(wav, max=1, min=-1).float()
|
| 14 |
+
if factor is None:
|
| 15 |
+
linear_gain = 10 ** (gain / 20)
|
| 16 |
+
factor = linear_gain / torch.abs(wav).max().item()
|
| 17 |
+
wav = wav * factor
|
| 18 |
+
else:
|
| 19 |
+
# for clean speech, normed by the noisy factor
|
| 20 |
+
wav = wav * factor
|
| 21 |
+
assert torch.all(wav.abs() <= 1), f"out wavform is not in [-1, 1], {wav.abs().max()}"
|
| 22 |
+
return wav, factor
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class InputSTFT(nn.Module):
|
| 26 |
+
"""
|
| 27 |
+
The STFT of the input signal of CleanMel (STFT coefficients);
|
| 28 |
+
In online mode, the recursive normalization is used.
|
| 29 |
+
"""
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
n_fft: int,
|
| 33 |
+
n_win: int,
|
| 34 |
+
n_hop: int,
|
| 35 |
+
center: bool,
|
| 36 |
+
normalize: bool,
|
| 37 |
+
onesided: bool,
|
| 38 |
+
online: bool = False):
|
| 39 |
+
super().__init__()
|
| 40 |
+
|
| 41 |
+
self.online = online
|
| 42 |
+
self.stft=Spectrogram(
|
| 43 |
+
n_fft=n_fft,
|
| 44 |
+
win_length=n_win,
|
| 45 |
+
hop_length=n_hop,
|
| 46 |
+
normalized=normalize,
|
| 47 |
+
center=center,
|
| 48 |
+
onesided=onesided,
|
| 49 |
+
power=None
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
def forward(self, x):
|
| 53 |
+
if self.online:
|
| 54 |
+
# recursive normalization
|
| 55 |
+
x = self.stft(x)
|
| 56 |
+
x_mag = x.abs()
|
| 57 |
+
x_norm = recursive_normalization(x_mag)
|
| 58 |
+
x = x / x_norm.clamp(min=1e-8)
|
| 59 |
+
x = torch.view_as_real(x)
|
| 60 |
+
else:
|
| 61 |
+
# vocos dBFS normalization
|
| 62 |
+
x, x_norm = soxnorm(x, random.randint(-6, -1) if self.training else -3)
|
| 63 |
+
x = self.stft(x)
|
| 64 |
+
x = torch.view_as_real(x)
|
| 65 |
+
return x, x_norm
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class LibrosaMelScale(nn.Module):
|
| 69 |
+
r"""Pytorch implementation of librosa mel scale to align with common ESPNet ASR models;
|
| 70 |
+
You might need to define .
|
| 71 |
+
"""
|
| 72 |
+
def __init__(self, n_mels, sample_rate, f_min, f_max, n_stft, norm=None, mel_scale="slaney"):
|
| 73 |
+
super(LibrosaMelScale, self).__init__()
|
| 74 |
+
|
| 75 |
+
_mel_options = dict(
|
| 76 |
+
sr=sample_rate,
|
| 77 |
+
n_fft=(n_stft - 1) * 2,
|
| 78 |
+
n_mels=n_mels,
|
| 79 |
+
fmin=f_min,
|
| 80 |
+
fmax=f_max if f_max is not None else float(sample_rate // 2),
|
| 81 |
+
htk=mel_scale=="htk",
|
| 82 |
+
norm=norm
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
fb = torch.from_numpy(librosa.filters.mel(**_mel_options).T).float()
|
| 86 |
+
self.register_buffer("fb", fb)
|
| 87 |
+
|
| 88 |
+
def forward(self, specgram):
|
| 89 |
+
mel_specgram = torch.matmul(specgram.transpose(-1, -2), self.fb).transpose(-1, -2)
|
| 90 |
+
return mel_specgram
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class TargetMel(nn.Module):
|
| 94 |
+
"""
|
| 95 |
+
This class generates the enhancement TARGET mel spectrogram;
|
| 96 |
+
"""
|
| 97 |
+
def __init__(
|
| 98 |
+
self,
|
| 99 |
+
sample_rate: int,
|
| 100 |
+
n_fft: int,
|
| 101 |
+
n_win: int,
|
| 102 |
+
n_hop: int,
|
| 103 |
+
n_mels: int,
|
| 104 |
+
f_min: int,
|
| 105 |
+
f_max: int,
|
| 106 |
+
power: int,
|
| 107 |
+
center: bool,
|
| 108 |
+
normalize: bool,
|
| 109 |
+
onesided: bool,
|
| 110 |
+
mel_norm: str | None,
|
| 111 |
+
mel_scale: str,
|
| 112 |
+
librosa_mel: bool = True,
|
| 113 |
+
online: bool = False,
|
| 114 |
+
):
|
| 115 |
+
super().__init__()
|
| 116 |
+
# This implementation vs torchaudio.transforms.MelSpectrogram: Add librosa melscale
|
| 117 |
+
# librosa melscale is numerically different from the torchaudio melscale (x_diff > 1e-5)
|
| 118 |
+
|
| 119 |
+
self.sample_rate = sample_rate
|
| 120 |
+
self.n_fft = n_fft
|
| 121 |
+
self.online = online
|
| 122 |
+
self.stft = Spectrogram(
|
| 123 |
+
n_fft=n_fft,
|
| 124 |
+
win_length=n_win,
|
| 125 |
+
hop_length=n_hop,
|
| 126 |
+
power=None if online else power,
|
| 127 |
+
normalized=normalize,
|
| 128 |
+
center=center,
|
| 129 |
+
onesided=onesided,
|
| 130 |
+
)
|
| 131 |
+
mel_method = LibrosaMelScale if librosa_mel else MelScale
|
| 132 |
+
self.mel_scale = mel_method(
|
| 133 |
+
n_mels=n_mels,
|
| 134 |
+
sample_rate=sample_rate,
|
| 135 |
+
f_min=f_min,
|
| 136 |
+
f_max=f_max,
|
| 137 |
+
n_stft=n_fft // 2 + 1,
|
| 138 |
+
norm=mel_norm,
|
| 139 |
+
mel_scale=mel_scale,
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
def forward(self, x: Tensor, x_norm=None):
|
| 143 |
+
if self.online:
|
| 144 |
+
# apply recursive normalization to target waveform
|
| 145 |
+
spectrogram = self.stft(x)
|
| 146 |
+
spectrogram = spectrogram / (x_norm + 1e-8)
|
| 147 |
+
spectrogram = spectrogram.abs().pow(2) # to power spectrogram
|
| 148 |
+
else:
|
| 149 |
+
# apply vocos dBFS normalization to target waveform
|
| 150 |
+
x, _ = soxnorm(x, None, x_norm)
|
| 151 |
+
spectrogram = self.stft(x)
|
| 152 |
+
# mel spectrogram
|
| 153 |
+
mel_specgram = self.mel_scale(spectrogram)
|
| 154 |
+
return mel_specgram
|
model/vocos/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
model/vocos/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (171 Bytes). View file
|
|
|
model/vocos/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (177 Bytes). View file
|
|
|
model/vocos/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (171 Bytes). View file
|
|
|
model/vocos/__pycache__/dataset.cpython-310.pyc
ADDED
|
Binary file (3.91 kB). View file
|
|
|
model/vocos/__pycache__/discriminators.cpython-310.pyc
ADDED
|
Binary file (8 kB). View file
|
|
|
model/vocos/__pycache__/discriminators.cpython-39.pyc
ADDED
|
Binary file (7.98 kB). View file
|
|
|
model/vocos/__pycache__/experiment.cpython-310.pyc
ADDED
|
Binary file (13.5 kB). View file
|
|
|
model/vocos/__pycache__/experiment.cpython-312.pyc
ADDED
|
Binary file (22.9 kB). View file
|
|
|
model/vocos/__pycache__/experiment.cpython-39.pyc
ADDED
|
Binary file (12.8 kB). View file
|
|
|
model/vocos/__pycache__/feature_extractors.cpython-310.pyc
ADDED
|
Binary file (6.03 kB). View file
|
|
|
model/vocos/__pycache__/feature_extractors.cpython-39.pyc
ADDED
|
Binary file (5.92 kB). View file
|
|
|
model/vocos/__pycache__/heads.cpython-310.pyc
ADDED
|
Binary file (6.83 kB). View file
|
|
|
model/vocos/__pycache__/helpers.cpython-310.pyc
ADDED
|
Binary file (2.75 kB). View file
|
|
|
model/vocos/__pycache__/loss.cpython-310.pyc
ADDED
|
Binary file (4.82 kB). View file
|
|
|
model/vocos/__pycache__/models.cpython-310.pyc
ADDED
|
Binary file (5.02 kB). View file
|
|
|
model/vocos/__pycache__/modules.cpython-310.pyc
ADDED
|
Binary file (6.68 kB). View file
|
|
|
model/vocos/__pycache__/modules.cpython-39.pyc
ADDED
|
Binary file (6.63 kB). View file
|
|
|
model/vocos/__pycache__/pretrained.cpython-310.pyc
ADDED
|
Binary file (7.41 kB). View file
|
|
|
model/vocos/__pycache__/spectral_ops.cpython-310.pyc
ADDED
|
Binary file (6.86 kB). View file
|
|
|
model/vocos/dataset.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
from pytorch_lightning.utilities.types import EVAL_DATALOADERS
|
| 5 |
+
import torch
|
| 6 |
+
import torchaudio
|
| 7 |
+
import warnings
|
| 8 |
+
from pytorch_lightning import LightningDataModule
|
| 9 |
+
from torch.utils.data import Dataset, DataLoader
|
| 10 |
+
|
| 11 |
+
torch.set_num_threads(1)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
|
| 15 |
+
class DataConfig:
|
| 16 |
+
filelist_path: str
|
| 17 |
+
sampling_rate: int
|
| 18 |
+
num_samples: int
|
| 19 |
+
batch_size: int
|
| 20 |
+
num_workers: int
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class VocosDataModule(LightningDataModule):
|
| 24 |
+
def __init__(self, train_params: DataConfig, val_params: DataConfig):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.train_config = train_params
|
| 27 |
+
self.val_config = val_params
|
| 28 |
+
|
| 29 |
+
def _get_dataloder(self, cfg: DataConfig, train: bool):
|
| 30 |
+
dataset = VocosDataset(cfg, train=train)
|
| 31 |
+
dataloader = DataLoader(
|
| 32 |
+
dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers, shuffle=train, pin_memory=True,
|
| 33 |
+
)
|
| 34 |
+
return dataloader
|
| 35 |
+
|
| 36 |
+
def train_dataloader(self) -> DataLoader:
|
| 37 |
+
return self._get_dataloder(self.train_config, train=True)
|
| 38 |
+
|
| 39 |
+
def val_dataloader(self) -> DataLoader:
|
| 40 |
+
return self._get_dataloder(self.val_config, train=False)
|
| 41 |
+
|
| 42 |
+
def test_dataloader(self) -> DataLoader:
|
| 43 |
+
return self.val_dataloader()
|
| 44 |
+
|
| 45 |
+
class VocosDataset(Dataset):
|
| 46 |
+
def __init__(self, cfg: DataConfig, train: bool):
|
| 47 |
+
with open(cfg.filelist_path) as f:
|
| 48 |
+
self.filelist = f.read().splitlines()
|
| 49 |
+
self.sampling_rate = cfg.sampling_rate
|
| 50 |
+
self.num_samples = cfg.num_samples
|
| 51 |
+
self.train = train
|
| 52 |
+
|
| 53 |
+
def __len__(self) -> int:
|
| 54 |
+
return len(self.filelist)
|
| 55 |
+
|
| 56 |
+
def customize_soxnorm(self, wav, gain=-3, factor=None):
|
| 57 |
+
wav = np.clip(wav, a_max=1, a_min=-1)
|
| 58 |
+
if factor is None:
|
| 59 |
+
linear_gain = 10 ** (gain / 20)
|
| 60 |
+
wav = wav / np.abs(wav).max() * linear_gain
|
| 61 |
+
return wav, linear_gain / np.abs(wav).max()
|
| 62 |
+
else:
|
| 63 |
+
wav = wav * factor
|
| 64 |
+
return wav, None
|
| 65 |
+
|
| 66 |
+
def __getitem__(self, index: int) -> torch.Tensor:
|
| 67 |
+
audio_path = self.filelist[index]
|
| 68 |
+
try:
|
| 69 |
+
y, sr = torchaudio.load(audio_path)
|
| 70 |
+
except:
|
| 71 |
+
warnings.warn(f"Error loading {audio_path}")
|
| 72 |
+
return self.__getitem__(np.random.randint(len(self.filelist)))
|
| 73 |
+
if y.size(-1) == 0:
|
| 74 |
+
return self.__getitem__(np.random.randint(len(self.filelist)))
|
| 75 |
+
if y.size(0) > 1:
|
| 76 |
+
# mix to mono
|
| 77 |
+
y = y.mean(dim=0, keepdim=True)
|
| 78 |
+
gain = np.random.uniform(-1, -6) if self.train else -3
|
| 79 |
+
y, _ = self.customize_soxnorm(y, gain)
|
| 80 |
+
if sr != self.sampling_rate:
|
| 81 |
+
y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=self.sampling_rate)
|
| 82 |
+
if y.size(-1) < self.num_samples:
|
| 83 |
+
pad_length = self.num_samples - y.size(-1)
|
| 84 |
+
padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
|
| 85 |
+
y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
|
| 86 |
+
elif self.train:
|
| 87 |
+
start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
|
| 88 |
+
y = y[:, start : start + self.num_samples]
|
| 89 |
+
else:
|
| 90 |
+
# During validation, take always the first segment for determinism
|
| 91 |
+
y = y[:, : self.num_samples]
|
| 92 |
+
|
| 93 |
+
return y[0]
|
model/vocos/discriminators.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.nn import Conv2d
|
| 7 |
+
from torch.nn.utils import weight_norm
|
| 8 |
+
from torchaudio.transforms import Spectrogram
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class MultiPeriodDiscriminator(nn.Module):
|
| 12 |
+
"""
|
| 13 |
+
Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan.
|
| 14 |
+
Additionally, it allows incorporating conditional information with a learned embeddings table.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
periods (tuple[int]): Tuple of periods for each discriminator.
|
| 18 |
+
num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
|
| 19 |
+
Defaults to None.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(self, periods: Tuple[int, ...] = (2, 3, 5, 7, 11), num_embeddings: Optional[int] = None):
|
| 23 |
+
super().__init__()
|
| 24 |
+
self.discriminators = nn.ModuleList([DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods])
|
| 25 |
+
|
| 26 |
+
def forward(
|
| 27 |
+
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None
|
| 28 |
+
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
|
| 29 |
+
y_d_rs = []
|
| 30 |
+
y_d_gs = []
|
| 31 |
+
fmap_rs = []
|
| 32 |
+
fmap_gs = []
|
| 33 |
+
for d in self.discriminators:
|
| 34 |
+
y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
|
| 35 |
+
y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
|
| 36 |
+
y_d_rs.append(y_d_r)
|
| 37 |
+
fmap_rs.append(fmap_r)
|
| 38 |
+
y_d_gs.append(y_d_g)
|
| 39 |
+
fmap_gs.append(fmap_g)
|
| 40 |
+
|
| 41 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class DiscriminatorP(nn.Module):
|
| 45 |
+
def __init__(
|
| 46 |
+
self,
|
| 47 |
+
period: int,
|
| 48 |
+
in_channels: int = 1,
|
| 49 |
+
kernel_size: int = 5,
|
| 50 |
+
stride: int = 3,
|
| 51 |
+
lrelu_slope: float = 0.1,
|
| 52 |
+
num_embeddings: Optional[int] = None,
|
| 53 |
+
):
|
| 54 |
+
super().__init__()
|
| 55 |
+
self.period = period
|
| 56 |
+
self.convs = nn.ModuleList(
|
| 57 |
+
[
|
| 58 |
+
weight_norm(Conv2d(in_channels, 32, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
|
| 59 |
+
weight_norm(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
|
| 60 |
+
weight_norm(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
|
| 61 |
+
weight_norm(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
|
| 62 |
+
weight_norm(Conv2d(1024, 1024, (kernel_size, 1), (1, 1), padding=(kernel_size // 2, 0))),
|
| 63 |
+
]
|
| 64 |
+
)
|
| 65 |
+
if num_embeddings is not None:
|
| 66 |
+
self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=1024)
|
| 67 |
+
torch.nn.init.zeros_(self.emb.weight)
|
| 68 |
+
|
| 69 |
+
self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
| 70 |
+
self.lrelu_slope = lrelu_slope
|
| 71 |
+
|
| 72 |
+
def forward(
|
| 73 |
+
self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
|
| 74 |
+
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
| 75 |
+
x = x.unsqueeze(1)
|
| 76 |
+
fmap = []
|
| 77 |
+
# 1d to 2d
|
| 78 |
+
b, c, t = x.shape
|
| 79 |
+
if t % self.period != 0: # pad first
|
| 80 |
+
n_pad = self.period - (t % self.period)
|
| 81 |
+
x = torch.nn.functional.pad(x, (0, n_pad), "reflect")
|
| 82 |
+
t = t + n_pad
|
| 83 |
+
x = x.view(b, c, t // self.period, self.period)
|
| 84 |
+
|
| 85 |
+
for i, l in enumerate(self.convs):
|
| 86 |
+
x = l(x)
|
| 87 |
+
x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
|
| 88 |
+
if i > 0:
|
| 89 |
+
fmap.append(x)
|
| 90 |
+
if cond_embedding_id is not None:
|
| 91 |
+
emb = self.emb(cond_embedding_id)
|
| 92 |
+
h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
|
| 93 |
+
else:
|
| 94 |
+
h = 0
|
| 95 |
+
x = self.conv_post(x)
|
| 96 |
+
fmap.append(x)
|
| 97 |
+
x += h
|
| 98 |
+
x = torch.flatten(x, 1, -1)
|
| 99 |
+
|
| 100 |
+
return x, fmap
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class MultiResolutionDiscriminator(nn.Module):
|
| 104 |
+
def __init__(
|
| 105 |
+
self,
|
| 106 |
+
fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
|
| 107 |
+
num_embeddings: Optional[int] = None,
|
| 108 |
+
):
|
| 109 |
+
"""
|
| 110 |
+
Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
|
| 111 |
+
Additionally, it allows incorporating conditional information with a learned embeddings table.
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
|
| 115 |
+
num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
|
| 116 |
+
Defaults to None.
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
super().__init__()
|
| 120 |
+
self.discriminators = nn.ModuleList(
|
| 121 |
+
[DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
def forward(
|
| 125 |
+
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
|
| 126 |
+
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
|
| 127 |
+
y_d_rs = []
|
| 128 |
+
y_d_gs = []
|
| 129 |
+
fmap_rs = []
|
| 130 |
+
fmap_gs = []
|
| 131 |
+
|
| 132 |
+
for d in self.discriminators:
|
| 133 |
+
y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
|
| 134 |
+
y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
|
| 135 |
+
y_d_rs.append(y_d_r)
|
| 136 |
+
fmap_rs.append(fmap_r)
|
| 137 |
+
y_d_gs.append(y_d_g)
|
| 138 |
+
fmap_gs.append(fmap_g)
|
| 139 |
+
|
| 140 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class DiscriminatorR(nn.Module):
|
| 144 |
+
def __init__(
|
| 145 |
+
self,
|
| 146 |
+
window_length: int,
|
| 147 |
+
num_embeddings: Optional[int] = None,
|
| 148 |
+
channels: int = 32,
|
| 149 |
+
hop_factor: float = 0.25,
|
| 150 |
+
bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
|
| 151 |
+
):
|
| 152 |
+
super().__init__()
|
| 153 |
+
self.window_length = window_length
|
| 154 |
+
self.hop_factor = hop_factor
|
| 155 |
+
self.spec_fn = Spectrogram(
|
| 156 |
+
n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
|
| 157 |
+
)
|
| 158 |
+
n_fft = window_length // 2 + 1
|
| 159 |
+
bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
|
| 160 |
+
self.bands = bands
|
| 161 |
+
convs = lambda: nn.ModuleList(
|
| 162 |
+
[
|
| 163 |
+
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
|
| 164 |
+
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
| 165 |
+
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
| 166 |
+
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
| 167 |
+
weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
|
| 168 |
+
]
|
| 169 |
+
)
|
| 170 |
+
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
|
| 171 |
+
|
| 172 |
+
if num_embeddings is not None:
|
| 173 |
+
self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
|
| 174 |
+
torch.nn.init.zeros_(self.emb.weight)
|
| 175 |
+
|
| 176 |
+
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
|
| 177 |
+
|
| 178 |
+
def spectrogram(self, x):
|
| 179 |
+
# Remove DC offset
|
| 180 |
+
x = x - x.mean(dim=-1, keepdims=True)
|
| 181 |
+
# Peak normalize the volume of input audio
|
| 182 |
+
x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
|
| 183 |
+
x = self.spec_fn(x)
|
| 184 |
+
x = torch.view_as_real(x)
|
| 185 |
+
x = rearrange(x, "b f t c -> b c t f")
|
| 186 |
+
# Split into bands
|
| 187 |
+
x_bands = [x[..., b[0] : b[1]] for b in self.bands]
|
| 188 |
+
return x_bands
|
| 189 |
+
|
| 190 |
+
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
|
| 191 |
+
x_bands = self.spectrogram(x)
|
| 192 |
+
fmap = []
|
| 193 |
+
x = []
|
| 194 |
+
for band, stack in zip(x_bands, self.band_convs):
|
| 195 |
+
for i, layer in enumerate(stack):
|
| 196 |
+
band = layer(band)
|
| 197 |
+
band = torch.nn.functional.leaky_relu(band, 0.1)
|
| 198 |
+
if i > 0:
|
| 199 |
+
fmap.append(band)
|
| 200 |
+
x.append(band)
|
| 201 |
+
x = torch.cat(x, dim=-1)
|
| 202 |
+
if cond_embedding_id is not None:
|
| 203 |
+
emb = self.emb(cond_embedding_id)
|
| 204 |
+
h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
|
| 205 |
+
else:
|
| 206 |
+
h = 0
|
| 207 |
+
x = self.conv_post(x)
|
| 208 |
+
fmap.append(x)
|
| 209 |
+
x += h
|
| 210 |
+
|
| 211 |
+
return x, fmap
|
model/vocos/experiment.py
ADDED
|
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
import torch
|
| 6 |
+
import torchaudio
|
| 7 |
+
import transformers
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
from model.vocos.offline.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator
|
| 11 |
+
from model.vocos.offline.feature_extractors import FeatureExtractor
|
| 12 |
+
from model.vocos.offline.heads import FourierHead
|
| 13 |
+
from model.vocos.offline.helpers import plot_spectrogram_to_numpy
|
| 14 |
+
from model.vocos.offline.loss import DiscriminatorLoss, GeneratorLoss, FeatureMatchingLoss, MelSpecReconstructionLoss
|
| 15 |
+
# from models.vocos.offline.models import Backbone
|
| 16 |
+
from model.vocos.offline.modules import safe_log
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class VocosExp(pl.LightningModule):
|
| 20 |
+
# noinspection PyUnusedLocal
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
feature_extractor: FeatureExtractor,
|
| 24 |
+
backbone: nn.Module,
|
| 25 |
+
head: nn.Module,
|
| 26 |
+
sample_rate: int,
|
| 27 |
+
initial_learning_rate: float,
|
| 28 |
+
num_warmup_steps: int = 0,
|
| 29 |
+
mel_loss_coeff: float = 45,
|
| 30 |
+
mrd_loss_coeff: float = 1.0,
|
| 31 |
+
pretrain_mel_steps: int = 0,
|
| 32 |
+
decay_mel_coeff: bool = False,
|
| 33 |
+
evaluate_utmos: bool = False,
|
| 34 |
+
evaluate_pesq: bool = False,
|
| 35 |
+
evaluate_periodicty: bool = False,
|
| 36 |
+
):
|
| 37 |
+
"""
|
| 38 |
+
Args:
|
| 39 |
+
feature_extractor (FeatureExtractor): An instance of FeatureExtractor to extract features from audio signals.
|
| 40 |
+
backbone (Backbone): An instance of Backbone model.
|
| 41 |
+
head (FourierHead): An instance of Fourier head to generate spectral coefficients and reconstruct a waveform.
|
| 42 |
+
sample_rate (int): Sampling rate of the audio signals.
|
| 43 |
+
initial_learning_rate (float): Initial learning rate for the optimizer.
|
| 44 |
+
num_warmup_steps (int): Number of steps for the warmup phase of learning rate scheduler. Default is 0.
|
| 45 |
+
mel_loss_coeff (float, optional): Coefficient for Mel-spectrogram loss in the loss function. Default is 45.
|
| 46 |
+
mrd_loss_coeff (float, optional): Coefficient for Multi Resolution Discriminator loss. Default is 1.0.
|
| 47 |
+
pretrain_mel_steps (int, optional): Number of steps to pre-train the model without the GAN objective. Default is 0.
|
| 48 |
+
decay_mel_coeff (bool, optional): If True, the Mel-spectrogram loss coefficient is decayed during training. Default is False.
|
| 49 |
+
evaluate_utmos (bool, optional): If True, UTMOS scores are computed for each validation run.
|
| 50 |
+
evaluate_pesq (bool, optional): If True, PESQ scores are computed for each validation run.
|
| 51 |
+
evaluate_periodicty (bool, optional): If True, periodicity scores are computed for each validation run.
|
| 52 |
+
"""
|
| 53 |
+
super().__init__()
|
| 54 |
+
self.save_hyperparameters(ignore=["feature_extractor", "backbone", "head"])
|
| 55 |
+
self.feature_extractor = feature_extractor
|
| 56 |
+
self.backbone = backbone
|
| 57 |
+
self.head = head
|
| 58 |
+
self.sample_rate = sample_rate
|
| 59 |
+
self.initial_learning_rate = initial_learning_rate
|
| 60 |
+
self.num_warmup_steps = num_warmup_steps
|
| 61 |
+
self.mel_loss_coeff = mel_loss_coeff
|
| 62 |
+
self.mrd_loss_coeff = mrd_loss_coeff
|
| 63 |
+
self.pretrain_mel_steps = pretrain_mel_steps
|
| 64 |
+
self.decay_mel_coeff = decay_mel_coeff
|
| 65 |
+
self.evaluate_utmos = evaluate_utmos
|
| 66 |
+
self.evaluate_pesq = evaluate_pesq
|
| 67 |
+
self.evaluate_periodicty = evaluate_periodicty
|
| 68 |
+
|
| 69 |
+
self.multiperioddisc = MultiPeriodDiscriminator()
|
| 70 |
+
self.multiresddisc = MultiResolutionDiscriminator()
|
| 71 |
+
|
| 72 |
+
self.disc_loss = DiscriminatorLoss()
|
| 73 |
+
self.gen_loss = GeneratorLoss()
|
| 74 |
+
self.feat_matching_loss = FeatureMatchingLoss()
|
| 75 |
+
self.melspec_loss = MelSpecReconstructionLoss(sample_rate=sample_rate)
|
| 76 |
+
|
| 77 |
+
self.train_discriminator = False
|
| 78 |
+
self.base_mel_coeff = self.mel_loss_coeff = mel_loss_coeff
|
| 79 |
+
self.temp_cache=None
|
| 80 |
+
self.temp_grad=None
|
| 81 |
+
|
| 82 |
+
def configure_optimizers(self):
|
| 83 |
+
disc_params = [
|
| 84 |
+
{"params": self.multiperioddisc.parameters()},
|
| 85 |
+
{"params": self.multiresddisc.parameters()},
|
| 86 |
+
]
|
| 87 |
+
gen_params = [
|
| 88 |
+
{"params": self.feature_extractor.parameters()},
|
| 89 |
+
{"params": self.backbone.parameters()},
|
| 90 |
+
{"params": self.head.parameters()},
|
| 91 |
+
]
|
| 92 |
+
|
| 93 |
+
opt_disc = torch.optim.AdamW(disc_params, lr=self.initial_learning_rate, betas=(0.8, 0.9))
|
| 94 |
+
opt_gen = torch.optim.AdamW(gen_params, lr=self.initial_learning_rate, betas=(0.8, 0.9))
|
| 95 |
+
|
| 96 |
+
max_steps = self.trainer.max_steps // 2 # Max steps per optimizer
|
| 97 |
+
scheduler_disc = transformers.get_cosine_schedule_with_warmup(
|
| 98 |
+
opt_disc, num_warmup_steps=self.num_warmup_steps, num_training_steps=max_steps,
|
| 99 |
+
)
|
| 100 |
+
scheduler_gen = transformers.get_cosine_schedule_with_warmup(
|
| 101 |
+
opt_gen, num_warmup_steps=self.num_warmup_steps, num_training_steps=max_steps,
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
return (
|
| 105 |
+
[opt_disc, opt_gen],
|
| 106 |
+
[{"scheduler": scheduler_disc, "interval": "step"}, {"scheduler": scheduler_gen, "interval": "step"}],
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
def forward(self, audio_input, **kwargs):
|
| 110 |
+
features = self.feature_extractor(audio_input, **kwargs)
|
| 111 |
+
x = self.backbone(features, **kwargs)
|
| 112 |
+
audio_output = self.head(x)
|
| 113 |
+
return audio_output
|
| 114 |
+
|
| 115 |
+
def training_step(self, batch, batch_idx, optimizer_idx, **kwargs):
|
| 116 |
+
audio_input = batch
|
| 117 |
+
# train discriminator
|
| 118 |
+
if optimizer_idx == 0 and self.train_discriminator:
|
| 119 |
+
with torch.no_grad():
|
| 120 |
+
audio_hat = self(audio_input, **kwargs)
|
| 121 |
+
real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(y=audio_input, y_hat=audio_hat, **kwargs,)
|
| 122 |
+
real_score_mrd, gen_score_mrd, _, _ = self.multiresddisc(y=audio_input, y_hat=audio_hat, **kwargs,)
|
| 123 |
+
loss_mp, loss_mp_real, _ = self.disc_loss(
|
| 124 |
+
disc_real_outputs=real_score_mp, disc_generated_outputs=gen_score_mp
|
| 125 |
+
)
|
| 126 |
+
loss_mrd, loss_mrd_real, _ = self.disc_loss(
|
| 127 |
+
disc_real_outputs=real_score_mrd, disc_generated_outputs=gen_score_mrd
|
| 128 |
+
)
|
| 129 |
+
loss_mp /= len(loss_mp_real)
|
| 130 |
+
loss_mrd /= len(loss_mrd_real)
|
| 131 |
+
loss = loss_mp + self.mrd_loss_coeff * loss_mrd
|
| 132 |
+
|
| 133 |
+
self.log("discriminator/total", loss, prog_bar=True)
|
| 134 |
+
self.log("discriminator/multi_period_loss", loss_mp)
|
| 135 |
+
self.log("discriminator/multi_res_loss", loss_mrd)
|
| 136 |
+
return loss
|
| 137 |
+
|
| 138 |
+
# train generator
|
| 139 |
+
if optimizer_idx == 1:
|
| 140 |
+
audio_hat = self(audio_input, **kwargs)
|
| 141 |
+
if self.train_discriminator:
|
| 142 |
+
_, gen_score_mp, fmap_rs_mp, fmap_gs_mp = self.multiperioddisc(
|
| 143 |
+
y=audio_input, y_hat=audio_hat, **kwargs,
|
| 144 |
+
)
|
| 145 |
+
_, gen_score_mrd, fmap_rs_mrd, fmap_gs_mrd = self.multiresddisc(
|
| 146 |
+
y=audio_input, y_hat=audio_hat, **kwargs,
|
| 147 |
+
)
|
| 148 |
+
loss_gen_mp, list_loss_gen_mp = self.gen_loss(disc_outputs=gen_score_mp)
|
| 149 |
+
loss_gen_mrd, list_loss_gen_mrd = self.gen_loss(disc_outputs=gen_score_mrd)
|
| 150 |
+
loss_gen_mp = loss_gen_mp / len(list_loss_gen_mp)
|
| 151 |
+
loss_gen_mrd = loss_gen_mrd / len(list_loss_gen_mrd)
|
| 152 |
+
loss_fm_mp = self.feat_matching_loss(fmap_r=fmap_rs_mp, fmap_g=fmap_gs_mp) / len(fmap_rs_mp)
|
| 153 |
+
loss_fm_mrd = self.feat_matching_loss(fmap_r=fmap_rs_mrd, fmap_g=fmap_gs_mrd) / len(fmap_rs_mrd)
|
| 154 |
+
|
| 155 |
+
self.log("generator/multi_period_loss", loss_gen_mp)
|
| 156 |
+
self.log("generator/multi_res_loss", loss_gen_mrd)
|
| 157 |
+
self.log("generator/feature_matching_mp", loss_fm_mp)
|
| 158 |
+
self.log("generator/feature_matching_mrd", loss_fm_mrd)
|
| 159 |
+
else:
|
| 160 |
+
loss_gen_mp = loss_gen_mrd = loss_fm_mp = loss_fm_mrd = 0
|
| 161 |
+
|
| 162 |
+
mel_loss = self.melspec_loss(audio_hat, audio_input)
|
| 163 |
+
loss = (
|
| 164 |
+
loss_gen_mp
|
| 165 |
+
+ self.mrd_loss_coeff * loss_gen_mrd
|
| 166 |
+
+ loss_fm_mp
|
| 167 |
+
+ self.mrd_loss_coeff * loss_fm_mrd
|
| 168 |
+
+ self.mel_loss_coeff * mel_loss
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
self.log("generator/total_loss", loss, prog_bar=True)
|
| 172 |
+
self.log("mel_loss_coeff", self.mel_loss_coeff)
|
| 173 |
+
self.log("generator/mel_loss", mel_loss)
|
| 174 |
+
|
| 175 |
+
if self.global_step % 1000 == 0 and self.global_rank == 0:
|
| 176 |
+
self.logger.experiment.add_audio(
|
| 177 |
+
"train/audio_in", audio_input[0].data.cpu(), self.global_step, self.sample_rate
|
| 178 |
+
)
|
| 179 |
+
self.logger.experiment.add_audio(
|
| 180 |
+
"train/audio_pred", audio_hat[0].data.cpu(), self.global_step, self.sample_rate
|
| 181 |
+
)
|
| 182 |
+
with torch.no_grad():
|
| 183 |
+
mel = safe_log(self.melspec_loss.mel_spec(audio_input[0]))
|
| 184 |
+
mel_hat = safe_log(self.melspec_loss.mel_spec(audio_hat[0]))
|
| 185 |
+
self.logger.experiment.add_image(
|
| 186 |
+
"train/mel_target",
|
| 187 |
+
plot_spectrogram_to_numpy(mel.data.cpu().numpy()),
|
| 188 |
+
self.global_step,
|
| 189 |
+
dataformats="HWC",
|
| 190 |
+
)
|
| 191 |
+
self.logger.experiment.add_image(
|
| 192 |
+
"train/mel_pred",
|
| 193 |
+
plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()),
|
| 194 |
+
self.global_step,
|
| 195 |
+
dataformats="HWC",
|
| 196 |
+
)
|
| 197 |
+
return loss
|
| 198 |
+
|
| 199 |
+
def on_validation_epoch_start(self):
|
| 200 |
+
if self.evaluate_utmos:
|
| 201 |
+
from model.vocos.metrics.UTMOS import UTMOSScore
|
| 202 |
+
# if not hasattr(self, "utmos_model"):
|
| 203 |
+
self.utmos_model = UTMOSScore(device=self.device)
|
| 204 |
+
|
| 205 |
+
def validation_step(self, batch, batch_idx, **kwargs):
|
| 206 |
+
audio_input = batch
|
| 207 |
+
audio_hat = self(audio_input, **kwargs)
|
| 208 |
+
|
| 209 |
+
audio_16_khz = torchaudio.functional.resample(audio_input, orig_freq=self.sample_rate, new_freq=16000)
|
| 210 |
+
audio_hat_16khz = torchaudio.functional.resample(audio_hat, orig_freq=self.sample_rate, new_freq=16000)
|
| 211 |
+
|
| 212 |
+
if self.evaluate_periodicty:
|
| 213 |
+
from model.vocos.metrics.periodicity import calculate_periodicity_metrics
|
| 214 |
+
|
| 215 |
+
periodicity_loss, pitch_loss, f1_score = calculate_periodicity_metrics(audio_16_khz, audio_hat_16khz)
|
| 216 |
+
else:
|
| 217 |
+
periodicity_loss = pitch_loss = f1_score = 0
|
| 218 |
+
|
| 219 |
+
if self.evaluate_utmos:
|
| 220 |
+
utmos_score = self.utmos_model.score(audio_hat_16khz.unsqueeze(1)).mean()
|
| 221 |
+
else:
|
| 222 |
+
utmos_score = torch.zeros(1, device=self.device)
|
| 223 |
+
|
| 224 |
+
if self.evaluate_pesq:
|
| 225 |
+
from pesq import pesq
|
| 226 |
+
|
| 227 |
+
pesq_score = 0
|
| 228 |
+
for ref, deg in zip(audio_16_khz.cpu().numpy(), audio_hat_16khz.cpu().numpy()):
|
| 229 |
+
pesq_score += pesq(16000, ref, deg, "wb", on_error=1)
|
| 230 |
+
pesq_score /= len(audio_16_khz)
|
| 231 |
+
pesq_score = torch.tensor(pesq_score)
|
| 232 |
+
else:
|
| 233 |
+
pesq_score = torch.zeros(1, device=self.device)
|
| 234 |
+
|
| 235 |
+
mel_loss = self.melspec_loss(audio_hat.unsqueeze(1), audio_input.unsqueeze(1))
|
| 236 |
+
total_loss = mel_loss + (5 - utmos_score) + (5 - pesq_score)
|
| 237 |
+
|
| 238 |
+
return {
|
| 239 |
+
"val_loss": total_loss,
|
| 240 |
+
"mel_loss": mel_loss,
|
| 241 |
+
"utmos_score": utmos_score,
|
| 242 |
+
"pesq_score": pesq_score,
|
| 243 |
+
"periodicity_loss": periodicity_loss,
|
| 244 |
+
"pitch_loss": pitch_loss,
|
| 245 |
+
"f1_score": f1_score,
|
| 246 |
+
"audio_input": audio_input[0],
|
| 247 |
+
"audio_pred": audio_hat[0],
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
def validation_epoch_end(self, outputs):
|
| 251 |
+
if self.global_rank == 0:
|
| 252 |
+
for i, output in enumerate(outputs):
|
| 253 |
+
*_, audio_in, audio_pred = output.values()
|
| 254 |
+
self.logger.experiment.add_audio(
|
| 255 |
+
f"val_in_{i}", audio_in.data.cpu().numpy(), self.global_step, self.sample_rate
|
| 256 |
+
)
|
| 257 |
+
self.logger.experiment.add_audio(
|
| 258 |
+
f"val_pred_{i}", audio_pred.data.cpu().numpy(), self.global_step, self.sample_rate
|
| 259 |
+
)
|
| 260 |
+
mel_target = safe_log(self.melspec_loss.mel_spec(audio_in))
|
| 261 |
+
mel_hat = safe_log(self.melspec_loss.mel_spec(audio_pred))
|
| 262 |
+
self.logger.experiment.add_image(
|
| 263 |
+
f"val_mel_target_{i}",
|
| 264 |
+
plot_spectrogram_to_numpy(mel_target.data.cpu().numpy()),
|
| 265 |
+
self.global_step,
|
| 266 |
+
dataformats="HWC",
|
| 267 |
+
)
|
| 268 |
+
self.logger.experiment.add_image(
|
| 269 |
+
f"val_mel_hat_{i}",
|
| 270 |
+
plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()),
|
| 271 |
+
self.global_step,
|
| 272 |
+
dataformats="HWC",
|
| 273 |
+
)
|
| 274 |
+
avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
|
| 275 |
+
mel_loss = torch.stack([x["mel_loss"] for x in outputs]).mean()
|
| 276 |
+
utmos_score = torch.stack([x["utmos_score"] for x in outputs]).mean()
|
| 277 |
+
pesq_score = torch.stack([x["pesq_score"] for x in outputs]).mean()
|
| 278 |
+
periodicity_loss = np.array([x["periodicity_loss"] for x in outputs]).mean()
|
| 279 |
+
pitch_loss = np.array([x["pitch_loss"] for x in outputs]).mean()
|
| 280 |
+
f1_score = np.array([x["f1_score"] for x in outputs]).mean()
|
| 281 |
+
|
| 282 |
+
self.log("val_loss", avg_loss, sync_dist=True)
|
| 283 |
+
self.log("val/mel_loss", mel_loss, sync_dist=True)
|
| 284 |
+
self.log("val/utmos_score", utmos_score, sync_dist=True)
|
| 285 |
+
self.log("val/pesq_score", pesq_score, sync_dist=True)
|
| 286 |
+
self.log("val/periodicity_loss", periodicity_loss, sync_dist=True)
|
| 287 |
+
self.log("val/pitch_loss", pitch_loss, sync_dist=True)
|
| 288 |
+
self.log("val/f1_score", f1_score, sync_dist=True)
|
| 289 |
+
|
| 290 |
+
return {
|
| 291 |
+
"avg_loss": avg_loss,
|
| 292 |
+
"mel_loss": mel_loss,
|
| 293 |
+
"utmos_score": utmos_score,
|
| 294 |
+
"pesq_score": pesq_score,
|
| 295 |
+
"periodicity_loss": periodicity_loss,
|
| 296 |
+
"pitch_loss": pitch_loss,
|
| 297 |
+
"f1_score": f1_score,
|
| 298 |
+
}
|
| 299 |
+
def on_test_epoch_start(self):
|
| 300 |
+
self.on_validation_epoch_start()
|
| 301 |
+
|
| 302 |
+
def test_step(self, *args, **kwargs):
|
| 303 |
+
return self.validation_step(*args, **kwargs)
|
| 304 |
+
|
| 305 |
+
def test_epoch_end(self, outputs):
|
| 306 |
+
results = self.validation_epoch_end(outputs)
|
| 307 |
+
print(results)
|
| 308 |
+
@property
|
| 309 |
+
def global_step(self):
|
| 310 |
+
"""
|
| 311 |
+
Override global_step so that it returns the total number of batches processed
|
| 312 |
+
"""
|
| 313 |
+
return self.trainer.fit_loop.epoch_loop.total_batch_idx
|
| 314 |
+
|
| 315 |
+
def on_train_batch_start(self, *args):
|
| 316 |
+
if self.global_step >= self.pretrain_mel_steps:
|
| 317 |
+
self.train_discriminator = True
|
| 318 |
+
else:
|
| 319 |
+
self.train_discriminator = False
|
| 320 |
+
|
| 321 |
+
def on_train_batch_end(self, *args):
|
| 322 |
+
def mel_loss_coeff_decay(current_step, num_cycles=0.5):
|
| 323 |
+
max_steps = self.trainer.max_steps // 2
|
| 324 |
+
if current_step < self.num_warmup_steps:
|
| 325 |
+
return 1.0
|
| 326 |
+
progress = float(current_step - self.num_warmup_steps) / float(
|
| 327 |
+
max(1, max_steps - self.num_warmup_steps)
|
| 328 |
+
)
|
| 329 |
+
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
|
| 330 |
+
|
| 331 |
+
if self.decay_mel_coeff:
|
| 332 |
+
self.mel_loss_coeff = self.base_mel_coeff * mel_loss_coeff_decay(self.global_step + 1)
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
class VocosEncodecExp(VocosExp):
|
| 336 |
+
"""
|
| 337 |
+
VocosEncodecExp is a subclass of VocosExp that overrides the parent experiment to function as a conditional GAN.
|
| 338 |
+
It manages an additional `bandwidth_id` attribute, which denotes a learnable embedding corresponding to
|
| 339 |
+
a specific bandwidth value of EnCodec. During training, a random bandwidth_id is generated for each step,
|
| 340 |
+
while during validation, a fixed bandwidth_id is used.
|
| 341 |
+
"""
|
| 342 |
+
|
| 343 |
+
def __init__(
|
| 344 |
+
self,
|
| 345 |
+
feature_extractor: FeatureExtractor,
|
| 346 |
+
backbone: pl.LightningModule,
|
| 347 |
+
head: pl.LightningModule,
|
| 348 |
+
sample_rate: int,
|
| 349 |
+
initial_learning_rate: float,
|
| 350 |
+
num_warmup_steps: int,
|
| 351 |
+
mel_loss_coeff: float = 45,
|
| 352 |
+
mrd_loss_coeff: float = 1.0,
|
| 353 |
+
pretrain_mel_steps: int = 0,
|
| 354 |
+
decay_mel_coeff: bool = False,
|
| 355 |
+
evaluate_utmos: bool = False,
|
| 356 |
+
evaluate_pesq: bool = False,
|
| 357 |
+
evaluate_periodicty: bool = False,
|
| 358 |
+
):
|
| 359 |
+
super().__init__(
|
| 360 |
+
feature_extractor,
|
| 361 |
+
backbone,
|
| 362 |
+
head,
|
| 363 |
+
sample_rate,
|
| 364 |
+
initial_learning_rate,
|
| 365 |
+
num_warmup_steps,
|
| 366 |
+
mel_loss_coeff,
|
| 367 |
+
mrd_loss_coeff,
|
| 368 |
+
pretrain_mel_steps,
|
| 369 |
+
decay_mel_coeff,
|
| 370 |
+
evaluate_utmos,
|
| 371 |
+
evaluate_pesq,
|
| 372 |
+
evaluate_periodicty,
|
| 373 |
+
)
|
| 374 |
+
# Override with conditional discriminators
|
| 375 |
+
self.multiperioddisc = MultiPeriodDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths))
|
| 376 |
+
self.multiresddisc = MultiResolutionDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths))
|
| 377 |
+
|
| 378 |
+
def training_step(self, *args):
|
| 379 |
+
bandwidth_id = torch.randint(low=0, high=len(self.feature_extractor.bandwidths), size=(1,), device=self.device,)
|
| 380 |
+
output = super().training_step(*args, bandwidth_id=bandwidth_id)
|
| 381 |
+
return output
|
| 382 |
+
|
| 383 |
+
def validation_step(self, *args):
|
| 384 |
+
bandwidth_id = torch.tensor([0], device=self.device)
|
| 385 |
+
output = super().validation_step(*args, bandwidth_id=bandwidth_id)
|
| 386 |
+
return output
|
| 387 |
+
|
| 388 |
+
def validation_epoch_end(self, outputs):
|
| 389 |
+
if self.global_rank == 0:
|
| 390 |
+
*_, audio_in, _ = outputs[0].values()
|
| 391 |
+
# Resynthesis with encodec for reference
|
| 392 |
+
self.feature_extractor.encodec.set_target_bandwidth(self.feature_extractor.bandwidths[0])
|
| 393 |
+
encodec_audio = self.feature_extractor.encodec(audio_in[None, None, :])
|
| 394 |
+
self.logger.experiment.add_audio(
|
| 395 |
+
"encodec", encodec_audio[0, 0].data.cpu().numpy(), self.global_step, self.sample_rate,
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
super().validation_epoch_end(outputs)
|
model/vocos/feature_extractors.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import librosa
|
| 5 |
+
from encodec import EncodecModel
|
| 6 |
+
from torch import nn
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
from typing import Optional
|
| 9 |
+
from torchaudio.transforms import Spectrogram, MelScale
|
| 10 |
+
from model.vocos.modules import safe_log
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class FeatureExtractor(nn.Module):
|
| 14 |
+
"""Base class for feature extractors."""
|
| 15 |
+
|
| 16 |
+
def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 17 |
+
"""
|
| 18 |
+
Extract features from the given audio.
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
audio (Tensor): Input audio waveform.
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
Tensor: Extracted features of shape (B, C, L), where B is the batch size,
|
| 25 |
+
C denotes output features, and L is the sequence length.
|
| 26 |
+
"""
|
| 27 |
+
raise NotImplementedError("Subclasses must implement the forward method.")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class LibrosaMelScale(nn.Module):
|
| 31 |
+
r"""This MelScale has a create_fb_matrix function that can be used to create a filterbank matrix.
|
| 32 |
+
same as previous torchaudio version
|
| 33 |
+
"""
|
| 34 |
+
__constants__ = ["n_mels", "sample_rate", "f_min", "f_max"]
|
| 35 |
+
|
| 36 |
+
def __init__(
|
| 37 |
+
self,
|
| 38 |
+
n_mels: int = 128,
|
| 39 |
+
sample_rate: int = 16000,
|
| 40 |
+
f_min: float = 0.0,
|
| 41 |
+
f_max: Optional[float] = None,
|
| 42 |
+
n_stft: int = 201,
|
| 43 |
+
norm: Optional[str] = None,
|
| 44 |
+
mel_scale: str = "htk",
|
| 45 |
+
) -> None:
|
| 46 |
+
super(LibrosaMelScale, self).__init__()
|
| 47 |
+
self.n_mels = n_mels
|
| 48 |
+
self.sample_rate = sample_rate
|
| 49 |
+
self.f_max = f_max if f_max is not None else float(sample_rate // 2)
|
| 50 |
+
self.f_min = f_min
|
| 51 |
+
self.norm = norm
|
| 52 |
+
self.mel_scale = mel_scale
|
| 53 |
+
|
| 54 |
+
if f_min > self.f_max:
|
| 55 |
+
raise ValueError("Require f_min: {} <= f_max: {}".format(f_min, self.f_max))
|
| 56 |
+
_mel_options = dict(
|
| 57 |
+
sr=sample_rate,
|
| 58 |
+
n_fft=(n_stft - 1) * 2,
|
| 59 |
+
n_mels=n_mels,
|
| 60 |
+
fmin=f_min,
|
| 61 |
+
fmax=f_max,
|
| 62 |
+
htk=mel_scale=="htk",
|
| 63 |
+
norm=norm
|
| 64 |
+
)
|
| 65 |
+
fb = torch.from_numpy(librosa.filters.mel(**_mel_options).T).float()
|
| 66 |
+
self.register_buffer("fb", fb)
|
| 67 |
+
|
| 68 |
+
def forward(self, specgram):
|
| 69 |
+
mel_specgram = torch.matmul(specgram.transpose(-1, -2), self.fb).transpose(-1, -2)
|
| 70 |
+
return mel_specgram
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class MelSpectrogramFeatures(FeatureExtractor):
|
| 74 |
+
def __init__(
|
| 75 |
+
self,
|
| 76 |
+
sample_rate: int,
|
| 77 |
+
n_fft: int,
|
| 78 |
+
n_win: int,
|
| 79 |
+
n_hop: int,
|
| 80 |
+
n_mels: int,
|
| 81 |
+
f_min: int,
|
| 82 |
+
f_max: int,
|
| 83 |
+
power: int,
|
| 84 |
+
center: bool,
|
| 85 |
+
normalize: bool,
|
| 86 |
+
onesided: bool,
|
| 87 |
+
mel_norm: str | None,
|
| 88 |
+
mel_scale: str,
|
| 89 |
+
librosa_mel: bool = True,
|
| 90 |
+
clip_val: float = 1e-7
|
| 91 |
+
):
|
| 92 |
+
super().__init__()
|
| 93 |
+
# This implementation vs torchaudio.transforms.MelSpectrogram: Add librosa melscale
|
| 94 |
+
# librosa melscale is numerically different from the torchaudio melscale (x_diff > 1e-5)
|
| 95 |
+
self.n_fft = n_fft
|
| 96 |
+
self.spectrogram = Spectrogram(
|
| 97 |
+
n_fft=n_fft,
|
| 98 |
+
win_length=n_win,
|
| 99 |
+
hop_length=n_hop,
|
| 100 |
+
power=power,
|
| 101 |
+
normalized=normalize,
|
| 102 |
+
center=center,
|
| 103 |
+
onesided=onesided,
|
| 104 |
+
)
|
| 105 |
+
mel_method = LibrosaMelScale if librosa_mel else MelScale
|
| 106 |
+
self.mel_scale = mel_method(
|
| 107 |
+
n_mels=n_mels,
|
| 108 |
+
sample_rate=sample_rate,
|
| 109 |
+
f_min=f_min,
|
| 110 |
+
f_max=f_max,
|
| 111 |
+
n_stft=n_fft // 2 + 1,
|
| 112 |
+
norm=mel_norm,
|
| 113 |
+
mel_scale=mel_scale,
|
| 114 |
+
)
|
| 115 |
+
self.clip_val = clip_val
|
| 116 |
+
|
| 117 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 118 |
+
# Compute Spectrogram
|
| 119 |
+
specgram = self.spectrogram(x)
|
| 120 |
+
mel_specgram = self.mel_scale(specgram)
|
| 121 |
+
return safe_log(mel_specgram, self.clip_val)
|
| 122 |
+
|
| 123 |
+
class EncodecFeatures(FeatureExtractor):
|
| 124 |
+
def __init__(
|
| 125 |
+
self,
|
| 126 |
+
encodec_model: str = "encodec_24khz",
|
| 127 |
+
bandwidths: List[float] = [1.5, 3.0, 6.0, 12.0],
|
| 128 |
+
train_codebooks: bool = False,
|
| 129 |
+
):
|
| 130 |
+
super().__init__()
|
| 131 |
+
if encodec_model == "encodec_24khz":
|
| 132 |
+
encodec = EncodecModel.encodec_model_24khz
|
| 133 |
+
elif encodec_model == "encodec_48khz":
|
| 134 |
+
encodec = EncodecModel.encodec_model_48khz
|
| 135 |
+
else:
|
| 136 |
+
raise ValueError(
|
| 137 |
+
f"Unsupported encodec_model: {encodec_model}. Supported options are 'encodec_24khz' and 'encodec_48khz'."
|
| 138 |
+
)
|
| 139 |
+
self.encodec = encodec(pretrained=True)
|
| 140 |
+
for param in self.encodec.parameters():
|
| 141 |
+
param.requires_grad = False
|
| 142 |
+
self.num_q = self.encodec.quantizer.get_num_quantizers_for_bandwidth(
|
| 143 |
+
self.encodec.frame_rate, bandwidth=max(bandwidths)
|
| 144 |
+
)
|
| 145 |
+
codebook_weights = torch.cat([vq.codebook for vq in self.encodec.quantizer.vq.layers[: self.num_q]], dim=0)
|
| 146 |
+
self.codebook_weights = torch.nn.Parameter(codebook_weights, requires_grad=train_codebooks)
|
| 147 |
+
self.bandwidths = bandwidths
|
| 148 |
+
|
| 149 |
+
@torch.no_grad()
|
| 150 |
+
def get_encodec_codes(self, audio):
|
| 151 |
+
audio = audio.unsqueeze(1)
|
| 152 |
+
emb = self.encodec.encoder(audio)
|
| 153 |
+
codes = self.encodec.quantizer.encode(emb, self.encodec.frame_rate, self.encodec.bandwidth)
|
| 154 |
+
return codes
|
| 155 |
+
|
| 156 |
+
def forward(self, audio: torch.Tensor, **kwargs):
|
| 157 |
+
bandwidth_id = kwargs.get("bandwidth_id")
|
| 158 |
+
if bandwidth_id is None:
|
| 159 |
+
raise ValueError("The 'bandwidth_id' argument is required")
|
| 160 |
+
self.encodec.eval() # Force eval mode as Pytorch Lightning automatically sets child modules to training mode
|
| 161 |
+
self.encodec.set_target_bandwidth(self.bandwidths[bandwidth_id])
|
| 162 |
+
codes = self.get_encodec_codes(audio)
|
| 163 |
+
# Instead of summing in the loop, it stores subsequent VQ dictionaries in a single `self.codebook_weights`
|
| 164 |
+
# with offsets given by the number of bins, and finally summed in a vectorized operation.
|
| 165 |
+
offsets = torch.arange(
|
| 166 |
+
0, self.encodec.quantizer.bins * len(codes), self.encodec.quantizer.bins, device=audio.device
|
| 167 |
+
)
|
| 168 |
+
embeddings_idxs = codes + offsets.view(-1, 1, 1)
|
| 169 |
+
features = torch.nn.functional.embedding(embeddings_idxs, self.codebook_weights).sum(dim=0)
|
| 170 |
+
return features.transpose(1, 2)
|
model/vocos/heads.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
|
| 6 |
+
|
| 7 |
+
from model.vocos.spectral_ops import IMDCT, ISTFT
|
| 8 |
+
from model.vocos.modules import symexp
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class FourierHead(nn.Module):
|
| 12 |
+
"""Base class for inverse fourier modules."""
|
| 13 |
+
|
| 14 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 15 |
+
"""
|
| 16 |
+
Args:
|
| 17 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
| 18 |
+
L is the sequence length, and H denotes the model dimension.
|
| 19 |
+
|
| 20 |
+
Returns:
|
| 21 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
| 22 |
+
"""
|
| 23 |
+
raise NotImplementedError("Subclasses must implement the forward method.")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ISTFTHead(FourierHead):
|
| 27 |
+
"""
|
| 28 |
+
ISTFT Head module for predicting STFT complex coefficients.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
dim (int): Hidden dimension of the model.
|
| 32 |
+
n_fft (int): Size of Fourier transform.
|
| 33 |
+
hop_length (int): The distance between neighboring sliding window frames, which should align with
|
| 34 |
+
the resolution of the input features.
|
| 35 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
|
| 39 |
+
super().__init__()
|
| 40 |
+
out_dim = n_fft + 2
|
| 41 |
+
self.out = torch.nn.Linear(dim, out_dim)
|
| 42 |
+
self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)
|
| 43 |
+
|
| 44 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 45 |
+
"""
|
| 46 |
+
Forward pass of the ISTFTHead module.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
| 50 |
+
L is the sequence length, and H denotes the model dimension.
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
| 54 |
+
"""
|
| 55 |
+
x = self.out(x).transpose(1, 2)
|
| 56 |
+
mag, p = x.chunk(2, dim=1)
|
| 57 |
+
mag = torch.exp(mag)
|
| 58 |
+
mag = torch.clip(mag, max=1e2) # safeguard to prevent excessively large magnitudes
|
| 59 |
+
# wrapping happens here. These two lines produce real and imaginary value
|
| 60 |
+
x = torch.cos(p)
|
| 61 |
+
y = torch.sin(p)
|
| 62 |
+
# recalculating phase here does not produce anything new
|
| 63 |
+
# only costs time
|
| 64 |
+
# phase = torch.atan2(y, x)
|
| 65 |
+
# S = mag * torch.exp(phase * 1j)
|
| 66 |
+
# better directly produce the complex value
|
| 67 |
+
S = mag * (x + 1j * y)
|
| 68 |
+
audio = self.istft(S)
|
| 69 |
+
return audio
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class IMDCTSymExpHead(FourierHead):
|
| 73 |
+
"""
|
| 74 |
+
IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
dim (int): Hidden dimension of the model.
|
| 78 |
+
mdct_frame_len (int): Length of the MDCT frame.
|
| 79 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
| 80 |
+
sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
|
| 81 |
+
based on perceptual scaling. Defaults to None.
|
| 82 |
+
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
def __init__(
|
| 86 |
+
self,
|
| 87 |
+
dim: int,
|
| 88 |
+
mdct_frame_len: int,
|
| 89 |
+
padding: str = "same",
|
| 90 |
+
sample_rate: Optional[int] = None,
|
| 91 |
+
clip_audio: bool = False,
|
| 92 |
+
):
|
| 93 |
+
super().__init__()
|
| 94 |
+
out_dim = mdct_frame_len // 2
|
| 95 |
+
self.out = nn.Linear(dim, out_dim)
|
| 96 |
+
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
|
| 97 |
+
self.clip_audio = clip_audio
|
| 98 |
+
|
| 99 |
+
if sample_rate is not None:
|
| 100 |
+
# optionally init the last layer following mel-scale
|
| 101 |
+
m_max = _hz_to_mel(sample_rate // 2)
|
| 102 |
+
m_pts = torch.linspace(0, m_max, out_dim)
|
| 103 |
+
f_pts = _mel_to_hz(m_pts)
|
| 104 |
+
scale = 1 - (f_pts / f_pts.max())
|
| 105 |
+
|
| 106 |
+
with torch.no_grad():
|
| 107 |
+
self.out.weight.mul_(scale.view(-1, 1))
|
| 108 |
+
|
| 109 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 110 |
+
"""
|
| 111 |
+
Forward pass of the IMDCTSymExpHead module.
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
| 115 |
+
L is the sequence length, and H denotes the model dimension.
|
| 116 |
+
|
| 117 |
+
Returns:
|
| 118 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
| 119 |
+
"""
|
| 120 |
+
x = self.out(x)
|
| 121 |
+
x = symexp(x)
|
| 122 |
+
x = torch.clip(x, min=-1e2, max=1e2) # safeguard to prevent excessively large magnitudes
|
| 123 |
+
audio = self.imdct(x)
|
| 124 |
+
if self.clip_audio:
|
| 125 |
+
audio = torch.clip(x, min=-1.0, max=1.0)
|
| 126 |
+
|
| 127 |
+
return audio
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class IMDCTCosHead(FourierHead):
|
| 131 |
+
"""
|
| 132 |
+
IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
dim (int): Hidden dimension of the model.
|
| 136 |
+
mdct_frame_len (int): Length of the MDCT frame.
|
| 137 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
| 138 |
+
clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
|
| 139 |
+
"""
|
| 140 |
+
|
| 141 |
+
def __init__(self, dim: int, mdct_frame_len: int, padding: str = "same", clip_audio: bool = False):
|
| 142 |
+
super().__init__()
|
| 143 |
+
self.clip_audio = clip_audio
|
| 144 |
+
self.out = nn.Linear(dim, mdct_frame_len)
|
| 145 |
+
self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
|
| 146 |
+
|
| 147 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 148 |
+
"""
|
| 149 |
+
Forward pass of the IMDCTCosHead module.
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
|
| 153 |
+
L is the sequence length, and H denotes the model dimension.
|
| 154 |
+
|
| 155 |
+
Returns:
|
| 156 |
+
Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
|
| 157 |
+
"""
|
| 158 |
+
x = self.out(x)
|
| 159 |
+
m, p = x.chunk(2, dim=2)
|
| 160 |
+
m = torch.exp(m).clip(max=1e2) # safeguard to prevent excessively large magnitudes
|
| 161 |
+
audio = self.imdct(m * torch.cos(p))
|
| 162 |
+
if self.clip_audio:
|
| 163 |
+
audio = torch.clip(x, min=-1.0, max=1.0)
|
| 164 |
+
return audio
|
model/vocos/helpers.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from matplotlib import pyplot as plt
|
| 5 |
+
from pytorch_lightning import Callback
|
| 6 |
+
|
| 7 |
+
matplotlib.use("Agg")
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def save_figure_to_numpy(fig: plt.Figure) -> np.ndarray:
|
| 11 |
+
"""
|
| 12 |
+
Save a matplotlib figure to a numpy array.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
fig (Figure): Matplotlib figure object.
|
| 16 |
+
|
| 17 |
+
Returns:
|
| 18 |
+
ndarray: Numpy array representing the figure.
|
| 19 |
+
"""
|
| 20 |
+
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
|
| 21 |
+
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
| 22 |
+
return data
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def plot_spectrogram_to_numpy(spectrogram: np.ndarray) -> np.ndarray:
|
| 26 |
+
"""
|
| 27 |
+
Plot a spectrogram and convert it to a numpy array.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
spectrogram (ndarray): Spectrogram data.
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
ndarray: Numpy array representing the plotted spectrogram.
|
| 34 |
+
"""
|
| 35 |
+
spectrogram = spectrogram.astype(np.float32)
|
| 36 |
+
fig, ax = plt.subplots(figsize=(12, 3))
|
| 37 |
+
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
|
| 38 |
+
plt.colorbar(im, ax=ax)
|
| 39 |
+
plt.xlabel("Frames")
|
| 40 |
+
plt.ylabel("Channels")
|
| 41 |
+
plt.tight_layout()
|
| 42 |
+
|
| 43 |
+
fig.canvas.draw()
|
| 44 |
+
data = save_figure_to_numpy(fig)
|
| 45 |
+
plt.close()
|
| 46 |
+
return data
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class GradNormCallback(Callback):
|
| 50 |
+
"""
|
| 51 |
+
Callback to log the gradient norm.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
def on_after_backward(self, trainer, model):
|
| 55 |
+
model.log("grad_norm", gradient_norm(model))
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def gradient_norm(model: torch.nn.Module, norm_type: float = 2.0) -> torch.Tensor:
|
| 59 |
+
"""
|
| 60 |
+
Compute the gradient norm.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
model (Module): PyTorch model.
|
| 64 |
+
norm_type (float, optional): Type of the norm. Defaults to 2.0.
|
| 65 |
+
|
| 66 |
+
Returns:
|
| 67 |
+
Tensor: Gradient norm.
|
| 68 |
+
"""
|
| 69 |
+
grads = [p.grad for p in model.parameters() if p.grad is not None]
|
| 70 |
+
total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type)
|
| 71 |
+
return total_norm
|
model/vocos/loss.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torchaudio
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
from model.vocos.offline.modules import safe_log
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class MelSpecReconstructionLoss(nn.Module):
|
| 11 |
+
"""
|
| 12 |
+
L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
def __init__(
|
| 16 |
+
self, sample_rate: int = 24000, n_fft: int = 1024, hop_length: int = 256, n_mels: int = 100,
|
| 17 |
+
):
|
| 18 |
+
super().__init__()
|
| 19 |
+
self.mel_spec = torchaudio.transforms.MelSpectrogram(
|
| 20 |
+
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, center=True, power=1,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
def forward(self, y_hat, y) -> torch.Tensor:
|
| 24 |
+
"""
|
| 25 |
+
Args:
|
| 26 |
+
y_hat (Tensor): Predicted audio waveform.
|
| 27 |
+
y (Tensor): Ground truth audio waveform.
|
| 28 |
+
|
| 29 |
+
Returns:
|
| 30 |
+
Tensor: L1 loss between the mel-scaled magnitude spectrograms.
|
| 31 |
+
"""
|
| 32 |
+
mel_hat = safe_log(self.mel_spec(y_hat))
|
| 33 |
+
mel = safe_log(self.mel_spec(y))
|
| 34 |
+
|
| 35 |
+
loss = torch.nn.functional.l1_loss(mel, mel_hat)
|
| 36 |
+
|
| 37 |
+
return loss
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class GeneratorLoss(nn.Module):
|
| 41 |
+
"""
|
| 42 |
+
Generator Loss module. Calculates the loss for the generator based on discriminator outputs.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
def forward(self, disc_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
| 46 |
+
"""
|
| 47 |
+
Args:
|
| 48 |
+
disc_outputs (List[Tensor]): List of discriminator outputs.
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from
|
| 52 |
+
the sub-discriminators
|
| 53 |
+
"""
|
| 54 |
+
loss = torch.zeros(1, device=disc_outputs[0].device, dtype=disc_outputs[0].dtype)
|
| 55 |
+
gen_losses = []
|
| 56 |
+
for dg in disc_outputs:
|
| 57 |
+
l = torch.mean(torch.clamp(1 - dg, min=0))
|
| 58 |
+
gen_losses.append(l)
|
| 59 |
+
loss += l
|
| 60 |
+
|
| 61 |
+
return loss, gen_losses
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class DiscriminatorLoss(nn.Module):
|
| 65 |
+
"""
|
| 66 |
+
Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs.
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
def forward(
|
| 70 |
+
self, disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
|
| 71 |
+
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
|
| 72 |
+
"""
|
| 73 |
+
Args:
|
| 74 |
+
disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples.
|
| 75 |
+
disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples.
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from
|
| 79 |
+
the sub-discriminators for real outputs, and a list of
|
| 80 |
+
loss values for generated outputs.
|
| 81 |
+
"""
|
| 82 |
+
loss = torch.zeros(1, device=disc_real_outputs[0].device, dtype=disc_real_outputs[0].dtype)
|
| 83 |
+
r_losses = []
|
| 84 |
+
g_losses = []
|
| 85 |
+
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
| 86 |
+
r_loss = torch.mean(torch.clamp(1 - dr, min=0))
|
| 87 |
+
g_loss = torch.mean(torch.clamp(1 + dg, min=0))
|
| 88 |
+
loss += r_loss + g_loss
|
| 89 |
+
r_losses.append(r_loss)
|
| 90 |
+
g_losses.append(g_loss)
|
| 91 |
+
|
| 92 |
+
return loss, r_losses, g_losses
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class FeatureMatchingLoss(nn.Module):
|
| 96 |
+
"""
|
| 97 |
+
Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators.
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
def forward(self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
|
| 101 |
+
"""
|
| 102 |
+
Args:
|
| 103 |
+
fmap_r (List[List[Tensor]]): List of feature maps from real samples.
|
| 104 |
+
fmap_g (List[List[Tensor]]): List of feature maps from generated samples.
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
Tensor: The calculated feature matching loss.
|
| 108 |
+
"""
|
| 109 |
+
loss = torch.zeros(1, device=fmap_r[0][0].device, dtype=fmap_r[0][0].dtype)
|
| 110 |
+
for dr, dg in zip(fmap_r, fmap_g):
|
| 111 |
+
for rl, gl in zip(dr, dg):
|
| 112 |
+
loss += torch.mean(torch.abs(rl - gl))
|
| 113 |
+
|
| 114 |
+
return loss
|
model/vocos/models.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn.utils import weight_norm
|
| 6 |
+
|
| 7 |
+
from model.vocos.modules import ConvNeXtBlock, ResBlock1, AdaLayerNorm
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Backbone(nn.Module):
|
| 11 |
+
"""Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
|
| 12 |
+
|
| 13 |
+
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 14 |
+
"""
|
| 15 |
+
Args:
|
| 16 |
+
x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
|
| 17 |
+
C denotes output features, and L is the sequence length.
|
| 18 |
+
|
| 19 |
+
Returns:
|
| 20 |
+
Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
|
| 21 |
+
and H denotes the model dimension.
|
| 22 |
+
"""
|
| 23 |
+
raise NotImplementedError("Subclasses must implement the forward method.")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class VocosBackbone(Backbone):
|
| 27 |
+
"""
|
| 28 |
+
Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
input_channels (int): Number of input features channels.
|
| 32 |
+
dim (int): Hidden dimension of the model.
|
| 33 |
+
intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
|
| 34 |
+
num_layers (int): Number of ConvNeXtBlock layers.
|
| 35 |
+
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
|
| 36 |
+
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
| 37 |
+
None means non-conditional model. Defaults to None.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
input_channels: int,
|
| 43 |
+
dim: int,
|
| 44 |
+
intermediate_dim: int,
|
| 45 |
+
num_layers: int,
|
| 46 |
+
layer_scale_init_value: Optional[float] = None,
|
| 47 |
+
adanorm_num_embeddings: Optional[int] = None,
|
| 48 |
+
):
|
| 49 |
+
super().__init__()
|
| 50 |
+
self.input_channels = input_channels
|
| 51 |
+
self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
|
| 52 |
+
self.adanorm = adanorm_num_embeddings is not None
|
| 53 |
+
if adanorm_num_embeddings:
|
| 54 |
+
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
|
| 55 |
+
else:
|
| 56 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
| 57 |
+
layer_scale_init_value = layer_scale_init_value or 1 / num_layers
|
| 58 |
+
self.convnext = nn.ModuleList(
|
| 59 |
+
[
|
| 60 |
+
ConvNeXtBlock(
|
| 61 |
+
dim=dim,
|
| 62 |
+
intermediate_dim=intermediate_dim,
|
| 63 |
+
layer_scale_init_value=layer_scale_init_value,
|
| 64 |
+
adanorm_num_embeddings=adanorm_num_embeddings,
|
| 65 |
+
)
|
| 66 |
+
for _ in range(num_layers)
|
| 67 |
+
]
|
| 68 |
+
)
|
| 69 |
+
self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
|
| 70 |
+
self.apply(self._init_weights)
|
| 71 |
+
|
| 72 |
+
def _init_weights(self, m):
|
| 73 |
+
if isinstance(m, (nn.Conv1d, nn.Linear)):
|
| 74 |
+
nn.init.trunc_normal_(m.weight, std=0.02)
|
| 75 |
+
nn.init.constant_(m.bias, 0)
|
| 76 |
+
|
| 77 |
+
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 78 |
+
bandwidth_id = kwargs.get('bandwidth_id', None)
|
| 79 |
+
x = self.embed(x)
|
| 80 |
+
if self.adanorm:
|
| 81 |
+
assert bandwidth_id is not None
|
| 82 |
+
x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
|
| 83 |
+
else:
|
| 84 |
+
x = self.norm(x.transpose(1, 2))
|
| 85 |
+
x = x.transpose(1, 2)
|
| 86 |
+
for conv_block in self.convnext:
|
| 87 |
+
x = conv_block(x, cond_embedding_id=bandwidth_id)
|
| 88 |
+
x = self.final_layer_norm(x.transpose(1, 2))
|
| 89 |
+
return x
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class VocosResNetBackbone(Backbone):
|
| 93 |
+
"""
|
| 94 |
+
Vocos backbone module built with ResBlocks.
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
input_channels (int): Number of input features channels.
|
| 98 |
+
dim (int): Hidden dimension of the model.
|
| 99 |
+
num_blocks (int): Number of ResBlock1 blocks.
|
| 100 |
+
layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
|
| 101 |
+
"""
|
| 102 |
+
|
| 103 |
+
def __init__(
|
| 104 |
+
self, input_channels, dim, num_blocks, layer_scale_init_value=None,
|
| 105 |
+
):
|
| 106 |
+
super().__init__()
|
| 107 |
+
self.input_channels = input_channels
|
| 108 |
+
self.embed = weight_norm(nn.Conv1d(input_channels, dim, kernel_size=3, padding=1))
|
| 109 |
+
layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
|
| 110 |
+
self.resnet = nn.Sequential(
|
| 111 |
+
*[ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) for _ in range(num_blocks)]
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
|
| 115 |
+
x = self.embed(x)
|
| 116 |
+
x = self.resnet(x)
|
| 117 |
+
x = x.transpose(1, 2)
|
| 118 |
+
return x
|
model/vocos/modules.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn.utils import weight_norm, remove_weight_norm
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ConvNeXtBlock(nn.Module):
|
| 9 |
+
"""ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
|
| 10 |
+
|
| 11 |
+
Args:
|
| 12 |
+
dim (int): Number of input channels.
|
| 13 |
+
intermediate_dim (int): Dimensionality of the intermediate layer.
|
| 14 |
+
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
| 15 |
+
Defaults to None.
|
| 16 |
+
adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
|
| 17 |
+
None means non-conditional LayerNorm. Defaults to None.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(
|
| 21 |
+
self,
|
| 22 |
+
dim: int,
|
| 23 |
+
intermediate_dim: int,
|
| 24 |
+
layer_scale_init_value: float,
|
| 25 |
+
adanorm_num_embeddings: Optional[int] = None,
|
| 26 |
+
):
|
| 27 |
+
super().__init__()
|
| 28 |
+
self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
|
| 29 |
+
self.adanorm = adanorm_num_embeddings is not None
|
| 30 |
+
if adanorm_num_embeddings:
|
| 31 |
+
self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
|
| 32 |
+
else:
|
| 33 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
| 34 |
+
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
|
| 35 |
+
self.act = nn.GELU()
|
| 36 |
+
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
| 37 |
+
self.gamma = (
|
| 38 |
+
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
|
| 39 |
+
if layer_scale_init_value > 0
|
| 40 |
+
else None
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 44 |
+
residual = x
|
| 45 |
+
x = self.dwconv(x)
|
| 46 |
+
x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
|
| 47 |
+
if self.adanorm:
|
| 48 |
+
assert cond_embedding_id is not None
|
| 49 |
+
x = self.norm(x, cond_embedding_id)
|
| 50 |
+
else:
|
| 51 |
+
x = self.norm(x)
|
| 52 |
+
x = self.pwconv1(x)
|
| 53 |
+
x = self.act(x)
|
| 54 |
+
x = self.pwconv2(x)
|
| 55 |
+
if self.gamma is not None:
|
| 56 |
+
x = self.gamma * x
|
| 57 |
+
x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
|
| 58 |
+
|
| 59 |
+
x = residual + x
|
| 60 |
+
return x
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class AdaLayerNorm(nn.Module):
|
| 64 |
+
"""
|
| 65 |
+
Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
num_embeddings (int): Number of embeddings.
|
| 69 |
+
embedding_dim (int): Dimension of the embeddings.
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
|
| 73 |
+
super().__init__()
|
| 74 |
+
self.eps = eps
|
| 75 |
+
self.dim = embedding_dim
|
| 76 |
+
self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
|
| 77 |
+
self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
|
| 78 |
+
torch.nn.init.ones_(self.scale.weight)
|
| 79 |
+
torch.nn.init.zeros_(self.shift.weight)
|
| 80 |
+
|
| 81 |
+
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
|
| 82 |
+
scale = self.scale(cond_embedding_id)
|
| 83 |
+
shift = self.shift(cond_embedding_id)
|
| 84 |
+
x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
|
| 85 |
+
x = x * scale + shift
|
| 86 |
+
return x
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class ResBlock1(nn.Module):
|
| 90 |
+
"""
|
| 91 |
+
ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
|
| 92 |
+
but without upsampling layers.
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
dim (int): Number of input channels.
|
| 96 |
+
kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
|
| 97 |
+
dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
|
| 98 |
+
Defaults to (1, 3, 5).
|
| 99 |
+
lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
|
| 100 |
+
Defaults to 0.1.
|
| 101 |
+
layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
|
| 102 |
+
Defaults to None.
|
| 103 |
+
"""
|
| 104 |
+
|
| 105 |
+
def __init__(
|
| 106 |
+
self,
|
| 107 |
+
dim: int,
|
| 108 |
+
kernel_size: int = 3,
|
| 109 |
+
dilation: Tuple[int, int, int] = (1, 3, 5),
|
| 110 |
+
lrelu_slope: float = 0.1,
|
| 111 |
+
layer_scale_init_value: Optional[float] = None,
|
| 112 |
+
):
|
| 113 |
+
super().__init__()
|
| 114 |
+
self.lrelu_slope = lrelu_slope
|
| 115 |
+
self.convs1 = nn.ModuleList(
|
| 116 |
+
[
|
| 117 |
+
weight_norm(
|
| 118 |
+
nn.Conv1d(
|
| 119 |
+
dim,
|
| 120 |
+
dim,
|
| 121 |
+
kernel_size,
|
| 122 |
+
1,
|
| 123 |
+
dilation=dilation[0],
|
| 124 |
+
padding=self.get_padding(kernel_size, dilation[0]),
|
| 125 |
+
)
|
| 126 |
+
),
|
| 127 |
+
weight_norm(
|
| 128 |
+
nn.Conv1d(
|
| 129 |
+
dim,
|
| 130 |
+
dim,
|
| 131 |
+
kernel_size,
|
| 132 |
+
1,
|
| 133 |
+
dilation=dilation[1],
|
| 134 |
+
padding=self.get_padding(kernel_size, dilation[1]),
|
| 135 |
+
)
|
| 136 |
+
),
|
| 137 |
+
weight_norm(
|
| 138 |
+
nn.Conv1d(
|
| 139 |
+
dim,
|
| 140 |
+
dim,
|
| 141 |
+
kernel_size,
|
| 142 |
+
1,
|
| 143 |
+
dilation=dilation[2],
|
| 144 |
+
padding=self.get_padding(kernel_size, dilation[2]),
|
| 145 |
+
)
|
| 146 |
+
),
|
| 147 |
+
]
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
self.convs2 = nn.ModuleList(
|
| 151 |
+
[
|
| 152 |
+
weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
|
| 153 |
+
weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
|
| 154 |
+
weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
|
| 155 |
+
]
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
self.gamma = nn.ParameterList(
|
| 159 |
+
[
|
| 160 |
+
nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
|
| 161 |
+
if layer_scale_init_value is not None
|
| 162 |
+
else None,
|
| 163 |
+
nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
|
| 164 |
+
if layer_scale_init_value is not None
|
| 165 |
+
else None,
|
| 166 |
+
nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
|
| 167 |
+
if layer_scale_init_value is not None
|
| 168 |
+
else None,
|
| 169 |
+
]
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 173 |
+
for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
|
| 174 |
+
xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
|
| 175 |
+
xt = c1(xt)
|
| 176 |
+
xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
|
| 177 |
+
xt = c2(xt)
|
| 178 |
+
if gamma is not None:
|
| 179 |
+
xt = gamma * xt
|
| 180 |
+
x = xt + x
|
| 181 |
+
return x
|
| 182 |
+
|
| 183 |
+
def remove_weight_norm(self):
|
| 184 |
+
for l in self.convs1:
|
| 185 |
+
remove_weight_norm(l)
|
| 186 |
+
for l in self.convs2:
|
| 187 |
+
remove_weight_norm(l)
|
| 188 |
+
|
| 189 |
+
@staticmethod
|
| 190 |
+
def get_padding(kernel_size: int, dilation: int = 1) -> int:
|
| 191 |
+
return int((kernel_size * dilation - dilation) / 2)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
|
| 195 |
+
"""
|
| 196 |
+
Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
|
| 197 |
+
|
| 198 |
+
Args:
|
| 199 |
+
x (Tensor): Input tensor.
|
| 200 |
+
clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
Tensor: Element-wise logarithm of the input tensor with clipping applied.
|
| 204 |
+
"""
|
| 205 |
+
return torch.log(torch.clip(x, min=clip_val))
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def symlog(x: torch.Tensor) -> torch.Tensor:
|
| 209 |
+
return torch.sign(x) * torch.log1p(x.abs())
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def symexp(x: torch.Tensor) -> torch.Tensor:
|
| 213 |
+
return torch.sign(x) * (torch.exp(x.abs()) - 1)
|
model/vocos/pretrained.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict, Tuple, Union, Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import yaml
|
| 7 |
+
from huggingface_hub import hf_hub_download
|
| 8 |
+
from torch import nn
|
| 9 |
+
from model.vocos.feature_extractors import FeatureExtractor, EncodecFeatures
|
| 10 |
+
from model.vocos.heads import FourierHead
|
| 11 |
+
from model.vocos.models import Backbone
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
|
| 15 |
+
"""Instantiates a class with the given args and init.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
args: Positional arguments required for instantiation.
|
| 19 |
+
init: Dict of the form {"class_path":...,"init_args":...}.
|
| 20 |
+
|
| 21 |
+
Returns:
|
| 22 |
+
The instantiated class object.
|
| 23 |
+
"""
|
| 24 |
+
kwargs = init.get("init_args", {})
|
| 25 |
+
if not isinstance(args, tuple):
|
| 26 |
+
args = (args,)
|
| 27 |
+
class_module, class_name = init["class_path"].rsplit(".", 1)
|
| 28 |
+
module = __import__(class_module, fromlist=[class_name])
|
| 29 |
+
args_class = getattr(module, class_name)
|
| 30 |
+
return args_class(*args, **kwargs)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class Vocos(nn.Module):
|
| 34 |
+
"""
|
| 35 |
+
The Vocos class represents a Fourier-based neural vocoder for audio synthesis.
|
| 36 |
+
This class is primarily designed for inference, with support for loading from pretrained
|
| 37 |
+
model checkpoints. It consists of three main components: a feature extractor,
|
| 38 |
+
a backbone, and a head.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def __init__(
|
| 42 |
+
self, feature_extractor: nn.Module, backbone: Backbone, head: FourierHead,
|
| 43 |
+
):
|
| 44 |
+
super().__init__()
|
| 45 |
+
self.feature_extractor = feature_extractor
|
| 46 |
+
self.backbone = backbone
|
| 47 |
+
self.head = head
|
| 48 |
+
|
| 49 |
+
@classmethod
|
| 50 |
+
def from_hparams(cls, config_path: str) -> "Vocos":
|
| 51 |
+
"""
|
| 52 |
+
Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
|
| 53 |
+
"""
|
| 54 |
+
with open(config_path, "r") as f:
|
| 55 |
+
config = yaml.safe_load(f)
|
| 56 |
+
feature_extractor = instantiate_class(args=(), init=config["feature_extractor"])
|
| 57 |
+
backbone = instantiate_class(args=(), init=config["backbone"])
|
| 58 |
+
head = instantiate_class(args=(), init=config["head"])
|
| 59 |
+
model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
|
| 60 |
+
return model
|
| 61 |
+
|
| 62 |
+
@classmethod
|
| 63 |
+
def from_pretrained(self, config_path: str, model_path: str, model: nn.Module=None) -> "Vocos":
|
| 64 |
+
"""
|
| 65 |
+
Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub.
|
| 66 |
+
"""
|
| 67 |
+
if model is None:
|
| 68 |
+
model = self.from_hparams(config_path)
|
| 69 |
+
state_dict = torch.load(model_path, map_location="cpu")
|
| 70 |
+
prefixes = ("backbone", "feature_extractor", "head")
|
| 71 |
+
state_dict = {
|
| 72 |
+
key: value
|
| 73 |
+
for key, value in state_dict.items()
|
| 74 |
+
if any(key.startswith(prefix) for prefix in prefixes)
|
| 75 |
+
}
|
| 76 |
+
if isinstance(model.feature_extractor, EncodecFeatures):
|
| 77 |
+
encodec_parameters = {
|
| 78 |
+
"feature_extractor.encodec." + key: value
|
| 79 |
+
for key, value in model.feature_extractor.encodec.state_dict().items()
|
| 80 |
+
}
|
| 81 |
+
state_dict.update(encodec_parameters)
|
| 82 |
+
model.load_state_dict(state_dict)
|
| 83 |
+
model.eval()
|
| 84 |
+
return model
|
| 85 |
+
|
| 86 |
+
@torch.inference_mode()
|
| 87 |
+
def forward(self, features_input: torch.Tensor, X_norm, **kwargs: Any) -> torch.Tensor:
|
| 88 |
+
"""
|
| 89 |
+
Method to run a copy-synthesis from audio waveform. The feature extractor first processes the audio input,
|
| 90 |
+
which is then passed through the backbone and the head to reconstruct the audio output.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
audio_input (Tensor): The input tensor representing the audio waveform of shape (B, T),
|
| 94 |
+
where B is the batch size and L is the waveform length.
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
|
| 99 |
+
"""
|
| 100 |
+
audio_output = self.decode(features_input, **kwargs)
|
| 101 |
+
return audio_output / X_norm
|
| 102 |
+
|
| 103 |
+
@torch.inference_mode()
|
| 104 |
+
def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
|
| 105 |
+
"""
|
| 106 |
+
Method to decode audio waveform from already calculated features. The features input is passed through
|
| 107 |
+
the backbone and the head to reconstruct the audio output.
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size,
|
| 111 |
+
C denotes the feature dimension, and L is the sequence length.
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
|
| 115 |
+
"""
|
| 116 |
+
x = self.backbone(features_input, **kwargs)
|
| 117 |
+
audio_output = self.head(x)
|
| 118 |
+
return audio_output
|
| 119 |
+
|
| 120 |
+
@torch.inference_mode()
|
| 121 |
+
def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor:
|
| 122 |
+
"""
|
| 123 |
+
Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's
|
| 124 |
+
codebook weights.
|
| 125 |
+
|
| 126 |
+
Args:
|
| 127 |
+
codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L),
|
| 128 |
+
where K is the number of codebooks, B is the batch size and L is the sequence length.
|
| 129 |
+
|
| 130 |
+
Returns:
|
| 131 |
+
Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension,
|
| 132 |
+
and L is the sequence length.
|
| 133 |
+
"""
|
| 134 |
+
assert isinstance(
|
| 135 |
+
self.feature_extractor, EncodecFeatures
|
| 136 |
+
), "Feature extractor should be an instance of EncodecFeatures"
|
| 137 |
+
|
| 138 |
+
if codes.dim() == 2:
|
| 139 |
+
codes = codes.unsqueeze(1)
|
| 140 |
+
|
| 141 |
+
n_bins = self.feature_extractor.encodec.quantizer.bins
|
| 142 |
+
offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device)
|
| 143 |
+
embeddings_idxs = codes + offsets.view(-1, 1, 1)
|
| 144 |
+
features = torch.nn.functional.embedding(embeddings_idxs, self.feature_extractor.codebook_weights).sum(dim=0)
|
| 145 |
+
features = features.transpose(1, 2)
|
| 146 |
+
|
| 147 |
+
return features
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
if __name__ == "__main__":
|
| 151 |
+
model = Vocos.from_pretrained(
|
| 152 |
+
"/nvmework3/shaonian/MelSpatialNet/MelSpatialNet/models/vocos/pretrained/pretrained_rec_normed.yaml",
|
| 153 |
+
"/nvmework3/shaonian/MelSpatialNet/MelSpatialNet/models/vocos/pretrained/vocos_hop128_clip1e-5_rts.ckpt").to("meta")
|
| 154 |
+
x = torch.randn(1, 80, 501)
|
| 155 |
+
x = x.to('meta')
|
| 156 |
+
from torch.utils.flop_counter import FlopCounterMode # requires torch>=2.1.0
|
| 157 |
+
with FlopCounterMode(model, display=False) as fcm:
|
| 158 |
+
y = model.decode(x)
|
| 159 |
+
flops_forward_eval = fcm.get_total_flops()
|
| 160 |
+
|
| 161 |
+
params_eval = sum(param.numel() for param in model.parameters())
|
| 162 |
+
print(f"flops_forward={flops_forward_eval/4e9:.2f}G, params={params_eval/1e6:.2f} M")
|
model/vocos/spectral_ops.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import scipy
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn, view_as_real, view_as_complex
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ISTFT(nn.Module):
|
| 8 |
+
"""
|
| 9 |
+
Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
|
| 10 |
+
windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
|
| 11 |
+
See issue: https://github.com/pytorch/pytorch/issues/62323
|
| 12 |
+
Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
|
| 13 |
+
The NOLA constraint is met as we trim padded samples anyway.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
n_fft (int): Size of Fourier transform.
|
| 17 |
+
hop_length (int): The distance between neighboring sliding window frames.
|
| 18 |
+
win_length (int): The size of window frame and STFT filter.
|
| 19 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
|
| 23 |
+
super().__init__()
|
| 24 |
+
if padding not in ["center", "same"]:
|
| 25 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
| 26 |
+
self.padding = padding
|
| 27 |
+
self.n_fft = n_fft
|
| 28 |
+
self.hop_length = hop_length
|
| 29 |
+
self.win_length = win_length
|
| 30 |
+
window = torch.hann_window(win_length)
|
| 31 |
+
self.register_buffer("window", window)
|
| 32 |
+
|
| 33 |
+
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
| 34 |
+
"""
|
| 35 |
+
Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
|
| 39 |
+
N is the number of frequency bins, and T is the number of time frames.
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
|
| 43 |
+
"""
|
| 44 |
+
if self.padding == "center":
|
| 45 |
+
# Fallback to pytorch native implementation
|
| 46 |
+
return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
|
| 47 |
+
elif self.padding == "same":
|
| 48 |
+
pad = (self.win_length - self.hop_length) // 2
|
| 49 |
+
else:
|
| 50 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
| 51 |
+
|
| 52 |
+
assert spec.dim() == 3, "Expected a 3D tensor as input"
|
| 53 |
+
B, N, T = spec.shape
|
| 54 |
+
|
| 55 |
+
# Inverse FFT
|
| 56 |
+
ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
|
| 57 |
+
ifft = ifft * self.window[None, :, None]
|
| 58 |
+
|
| 59 |
+
# Overlap and Add
|
| 60 |
+
output_size = (T - 1) * self.hop_length + self.win_length
|
| 61 |
+
y = torch.nn.functional.fold(
|
| 62 |
+
ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
|
| 63 |
+
)[:, 0, 0, pad:-pad]
|
| 64 |
+
|
| 65 |
+
# Window envelope
|
| 66 |
+
window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
|
| 67 |
+
window_envelope = torch.nn.functional.fold(
|
| 68 |
+
window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
|
| 69 |
+
).squeeze()[pad:-pad]
|
| 70 |
+
|
| 71 |
+
# Normalize
|
| 72 |
+
assert (window_envelope > 1e-11).all()
|
| 73 |
+
y = y / window_envelope
|
| 74 |
+
|
| 75 |
+
return y
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class MDCT(nn.Module):
|
| 79 |
+
"""
|
| 80 |
+
Modified Discrete Cosine Transform (MDCT) module.
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
frame_len (int): Length of the MDCT frame.
|
| 84 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
def __init__(self, frame_len: int, padding: str = "same"):
|
| 88 |
+
super().__init__()
|
| 89 |
+
if padding not in ["center", "same"]:
|
| 90 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
| 91 |
+
self.padding = padding
|
| 92 |
+
self.frame_len = frame_len
|
| 93 |
+
N = frame_len // 2
|
| 94 |
+
n0 = (N + 1) / 2
|
| 95 |
+
window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
|
| 96 |
+
self.register_buffer("window", window)
|
| 97 |
+
|
| 98 |
+
pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
|
| 99 |
+
post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
|
| 100 |
+
# view_as_real: NCCL Backend does not support ComplexFloat data type
|
| 101 |
+
# https://github.com/pytorch/pytorch/issues/71613
|
| 102 |
+
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
|
| 103 |
+
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
|
| 104 |
+
|
| 105 |
+
def forward(self, audio: torch.Tensor) -> torch.Tensor:
|
| 106 |
+
"""
|
| 107 |
+
Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
|
| 111 |
+
and T is the length of the audio.
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
|
| 115 |
+
and N is the number of frequency bins.
|
| 116 |
+
"""
|
| 117 |
+
if self.padding == "center":
|
| 118 |
+
audio = torch.nn.functional.pad(audio, (self.frame_len // 2, self.frame_len // 2))
|
| 119 |
+
elif self.padding == "same":
|
| 120 |
+
# hop_length is 1/2 frame_len
|
| 121 |
+
audio = torch.nn.functional.pad(audio, (self.frame_len // 4, self.frame_len // 4))
|
| 122 |
+
else:
|
| 123 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
| 124 |
+
|
| 125 |
+
x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
|
| 126 |
+
N = self.frame_len // 2
|
| 127 |
+
x = x * self.window.expand(x.shape)
|
| 128 |
+
X = torch.fft.fft(x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1)[..., :N]
|
| 129 |
+
res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
|
| 130 |
+
return torch.real(res) * np.sqrt(2)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class IMDCT(nn.Module):
|
| 134 |
+
"""
|
| 135 |
+
Inverse Modified Discrete Cosine Transform (IMDCT) module.
|
| 136 |
+
|
| 137 |
+
Args:
|
| 138 |
+
frame_len (int): Length of the MDCT frame.
|
| 139 |
+
padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
|
| 140 |
+
"""
|
| 141 |
+
|
| 142 |
+
def __init__(self, frame_len: int, padding: str = "same"):
|
| 143 |
+
super().__init__()
|
| 144 |
+
if padding not in ["center", "same"]:
|
| 145 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
| 146 |
+
self.padding = padding
|
| 147 |
+
self.frame_len = frame_len
|
| 148 |
+
N = frame_len // 2
|
| 149 |
+
n0 = (N + 1) / 2
|
| 150 |
+
window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
|
| 151 |
+
self.register_buffer("window", window)
|
| 152 |
+
|
| 153 |
+
pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
|
| 154 |
+
post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
|
| 155 |
+
self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
|
| 156 |
+
self.register_buffer("post_twiddle", view_as_real(post_twiddle))
|
| 157 |
+
|
| 158 |
+
def forward(self, X: torch.Tensor) -> torch.Tensor:
|
| 159 |
+
"""
|
| 160 |
+
Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
|
| 161 |
+
|
| 162 |
+
Args:
|
| 163 |
+
X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
|
| 164 |
+
L is the number of frames, and N is the number of frequency bins.
|
| 165 |
+
|
| 166 |
+
Returns:
|
| 167 |
+
Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
|
| 168 |
+
"""
|
| 169 |
+
B, L, N = X.shape
|
| 170 |
+
Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
|
| 171 |
+
Y[..., :N] = X
|
| 172 |
+
Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
|
| 173 |
+
y = torch.fft.ifft(Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1)
|
| 174 |
+
y = torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) * np.sqrt(N) * np.sqrt(2)
|
| 175 |
+
result = y * self.window.expand(y.shape)
|
| 176 |
+
output_size = (1, (L + 1) * N)
|
| 177 |
+
audio = torch.nn.functional.fold(
|
| 178 |
+
result.transpose(1, 2),
|
| 179 |
+
output_size=output_size,
|
| 180 |
+
kernel_size=(1, self.frame_len),
|
| 181 |
+
stride=(1, self.frame_len // 2),
|
| 182 |
+
)[:, 0, 0, :]
|
| 183 |
+
|
| 184 |
+
if self.padding == "center":
|
| 185 |
+
pad = self.frame_len // 2
|
| 186 |
+
elif self.padding == "same":
|
| 187 |
+
pad = self.frame_len // 4
|
| 188 |
+
else:
|
| 189 |
+
raise ValueError("Padding must be 'center' or 'same'.")
|
| 190 |
+
|
| 191 |
+
audio = audio[:, pad:-pad]
|
| 192 |
+
return audio
|