Spaces:
Running
Running
Commit
·
d0cd3b0
1
Parent(s):
3ccdc25
test1
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +131 -0
- assets/translations.py +171 -0
- model_list.py +0 -0
- multi_inference.py +303 -0
- requirements.txt +50 -0
- separator/audio_writer.py +85 -0
- separator/ensemble.py +192 -0
- separator/models/bandit/core/__init__.py +744 -0
- separator/models/bandit/core/data/__init__.py +2 -0
- separator/models/bandit/core/data/_types.py +18 -0
- separator/models/bandit/core/data/augmentation.py +107 -0
- separator/models/bandit/core/data/augmented.py +35 -0
- separator/models/bandit/core/data/base.py +69 -0
- separator/models/bandit/core/data/dnr/__init__.py +0 -0
- separator/models/bandit/core/data/dnr/datamodule.py +74 -0
- separator/models/bandit/core/data/dnr/dataset.py +392 -0
- separator/models/bandit/core/data/dnr/preprocess.py +54 -0
- separator/models/bandit/core/data/musdb/__init__.py +0 -0
- separator/models/bandit/core/data/musdb/datamodule.py +77 -0
- separator/models/bandit/core/data/musdb/dataset.py +280 -0
- separator/models/bandit/core/data/musdb/preprocess.py +238 -0
- separator/models/bandit/core/data/musdb/validation.yaml +15 -0
- separator/models/bandit/core/loss/__init__.py +2 -0
- separator/models/bandit/core/loss/_complex.py +34 -0
- separator/models/bandit/core/loss/_multistem.py +45 -0
- separator/models/bandit/core/loss/_timefreq.py +113 -0
- separator/models/bandit/core/loss/snr.py +146 -0
- separator/models/bandit/core/metrics/__init__.py +9 -0
- separator/models/bandit/core/metrics/_squim.py +383 -0
- separator/models/bandit/core/metrics/snr.py +150 -0
- separator/models/bandit/core/model/__init__.py +3 -0
- separator/models/bandit/core/model/_spectral.py +58 -0
- separator/models/bandit/core/model/bsrnn/__init__.py +23 -0
- separator/models/bandit/core/model/bsrnn/bandsplit.py +139 -0
- separator/models/bandit/core/model/bsrnn/core.py +661 -0
- separator/models/bandit/core/model/bsrnn/maskestim.py +347 -0
- separator/models/bandit/core/model/bsrnn/tfmodel.py +317 -0
- separator/models/bandit/core/model/bsrnn/utils.py +583 -0
- separator/models/bandit/core/model/bsrnn/wrapper.py +882 -0
- separator/models/bandit/core/utils/__init__.py +0 -0
- separator/models/bandit/core/utils/audio.py +463 -0
- separator/models/bandit/model_from_config.py +31 -0
- separator/models/bandit_v2/bandit.py +367 -0
- separator/models/bandit_v2/bandsplit.py +130 -0
- separator/models/bandit_v2/film.py +25 -0
- separator/models/bandit_v2/maskestim.py +281 -0
- separator/models/bandit_v2/tfmodel.py +145 -0
- separator/models/bandit_v2/utils.py +523 -0
- separator/models/bs_roformer/__init__.py +2 -0
- separator/models/bs_roformer/__pycache__/__init__.cpython-310.pyc +0 -0
app.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import time
|
| 4 |
+
import shutil
|
| 5 |
+
import argparse
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
import gradio as gr
|
| 8 |
+
os.system("pip install https://github.com/noblebarkrr/mvsepless/blob/bd611441e48e918650e6860738894673b3a1a5f1/fixed/audio_separator-0.32.0-py3-none-any.whl")
|
| 9 |
+
from multi_inference import MVSEPLESS, OUTPUT_FORMATS
|
| 10 |
+
from assets.translations import TRANSLATIONS, TRANSLATIONS_STEMS
|
| 11 |
+
|
| 12 |
+
OUTPUT_DIR = os.path.join(os.getcwd(), "output")
|
| 13 |
+
plugins_dir = os.path.join(os.getcwd(), "plugins")
|
| 14 |
+
os.makedirs(plugins_dir, exist_ok=True)
|
| 15 |
+
|
| 16 |
+
CURRENT_LANG = "ru"
|
| 17 |
+
|
| 18 |
+
def t(key, **kwargs):
|
| 19 |
+
"""Функция для получения перевода с подстановкой значений"""
|
| 20 |
+
lang = CURRENT_LANG
|
| 21 |
+
translation = TRANSLATIONS.get(lang, {}).get(key, key)
|
| 22 |
+
return translation.format(**kwargs) if kwargs else translation
|
| 23 |
+
|
| 24 |
+
def t_stem(key, **kwargs):
|
| 25 |
+
"""Функция для получения перевода с подстановкой значений"""
|
| 26 |
+
lang = CURRENT_LANG
|
| 27 |
+
translation = TRANSLATIONS_STEMS.get(lang, {}).get(key, key)
|
| 28 |
+
return translation.format(**kwargs) if kwargs else translation
|
| 29 |
+
|
| 30 |
+
def gen_out_dir():
|
| 31 |
+
return os.path.join(OUTPUT_DIR, datetime.now().strftime("%Y%m%d_%H%M%S"))
|
| 32 |
+
|
| 33 |
+
mvsepless = MVSEPLESS()
|
| 34 |
+
|
| 35 |
+
def sep_wrapper(a, b, c, d, e, f, g, h):
|
| 36 |
+
results = mvsepless.separator(input_file=a, output_dir=gen_out_dir(), model_type=b, model_name=c, ext_inst=d, vr_aggr=e, output_format=f, output_bitrate=f'{g}k', call_method="cli", selected_stems=h)
|
| 37 |
+
stems = []
|
| 38 |
+
if results:
|
| 39 |
+
for i, (stem, output_file) in enumerate(results[:20]):
|
| 40 |
+
stems.append(gr.update(
|
| 41 |
+
visible=True,
|
| 42 |
+
label=t_stem(stem),
|
| 43 |
+
value=output_file
|
| 44 |
+
))
|
| 45 |
+
|
| 46 |
+
while len(stems) < 20:
|
| 47 |
+
stems.append(gr.update(visible=False, label=None, value=None))
|
| 48 |
+
|
| 49 |
+
return tuple(stems)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
theme = gr.themes.Default(
|
| 53 |
+
primary_hue="violet",
|
| 54 |
+
secondary_hue="cyan",
|
| 55 |
+
neutral_hue="blue",
|
| 56 |
+
spacing_size="sm",
|
| 57 |
+
font=[gr.themes.GoogleFont("Tektur"), 'ui-sans-serif', 'system-ui', 'sans-serif'],
|
| 58 |
+
).set(
|
| 59 |
+
body_text_color='*neutral_950',
|
| 60 |
+
body_text_color_subdued='*neutral_500',
|
| 61 |
+
background_fill_primary='*neutral_200',
|
| 62 |
+
background_fill_primary_dark='*neutral_800',
|
| 63 |
+
border_color_accent='*primary_950',
|
| 64 |
+
border_color_accent_dark='*neutral_700',
|
| 65 |
+
border_color_accent_subdued='*primary_500',
|
| 66 |
+
border_color_primary='*primary_800',
|
| 67 |
+
border_color_primary_dark='*neutral_400',
|
| 68 |
+
color_accent_soft='*primary_100',
|
| 69 |
+
color_accent_soft_dark='*neutral_800',
|
| 70 |
+
link_text_color='*secondary_700',
|
| 71 |
+
link_text_color_active='*secondary_700',
|
| 72 |
+
link_text_color_hover='*secondary_800',
|
| 73 |
+
link_text_color_visited='*secondary_600',
|
| 74 |
+
link_text_color_visited_dark='*secondary_700',
|
| 75 |
+
block_background_fill='*background_fill_secondary',
|
| 76 |
+
block_background_fill_dark='*neutral_950',
|
| 77 |
+
block_label_background_fill='*secondary_400',
|
| 78 |
+
block_label_text_color='*neutral_800',
|
| 79 |
+
panel_background_fill='*background_fill_primary',
|
| 80 |
+
checkbox_background_color='*background_fill_secondary',
|
| 81 |
+
checkbox_label_background_fill_dark='*neutral_900',
|
| 82 |
+
input_background_fill_dark='*neutral_900',
|
| 83 |
+
input_background_fill_focus='*neutral_100',
|
| 84 |
+
input_background_fill_focus_dark='*neutral_950',
|
| 85 |
+
button_small_radius='*radius_sm',
|
| 86 |
+
button_secondary_background_fill='*neutral_400',
|
| 87 |
+
button_secondary_background_fill_dark='*neutral_500',
|
| 88 |
+
button_secondary_background_fill_hover_dark='*neutral_950'
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def create_app():
|
| 93 |
+
with gr.Row():
|
| 94 |
+
with gr.Column():
|
| 95 |
+
input_audio = gr.Audio(label=t("select_file"), interactive=True, type="filepath")
|
| 96 |
+
input_audio_path = gr.Textbox(label=t("audio_path"), info=t("audio_path_info"), interactive=True)
|
| 97 |
+
with gr.Column():
|
| 98 |
+
with gr.Row():
|
| 99 |
+
model_type = gr.Dropdown(label=t("model_type"), choices=mvsepless.get_mt(), value=mvsepless.get_mt()[0], interactive=True, filterable=False)
|
| 100 |
+
model_name = gr.Dropdown(label=t("model_name"), choices=mvsepless.get_mn(mvsepless.get_mt()[0]), value=mvsepless.get_mn(mvsepless.get_mt()[0])[0], interactive=True, filterable=False)
|
| 101 |
+
target_instrument = gr.Textbox(label=t("target_instrument"), value=mvsepless.get_tgt_inst(mvsepless.get_mt()[0], mvsepless.get_mn(mvsepless.get_mt()[0])[0]), interactive=False)
|
| 102 |
+
vr_aggr = gr.Slider(0, 100, step=1, label=t("vr_aggressiveness"), visible=False, value=5, interactive=True)
|
| 103 |
+
extract_instrumental = gr.Checkbox(label=t("extract_instrumental"), value=True, interactive=True)
|
| 104 |
+
stems_list = gr.CheckboxGroup(label=t("stems_list"), info=t("stems_info", target_instrument="vocals"), choices=mvsepless.get_stems(mvsepless.get_mt()[0], mvsepless.get_mn(mvsepless.get_mt()[0])[0]), value=None, interactive=False)
|
| 105 |
+
with gr.Row():
|
| 106 |
+
output_format, output_bitrate = gr.Dropdown(label=t("output_format"), choices=OUTPUT_FORMATS, value="mp3", interactive=True, filterable=False), gr.Slider(32, 320, step=1, label=t("bitrate"), value=320, interactive=True)
|
| 107 |
+
separate_btn = gr.Button(t("separate_btn"), variant="primary", interactive=True)
|
| 108 |
+
download_via_zip_btn = gr.DownloadButton(label="Download via zip", visible=False, interactive=True)
|
| 109 |
+
output_stems = []
|
| 110 |
+
for _ in range(10):
|
| 111 |
+
with gr.Row():
|
| 112 |
+
audio1 = gr.Audio(visible=False, interactive=False, type="filepath", show_download_button=True)
|
| 113 |
+
audio2 = gr.Audio(visible=False, interactive=False, type="filepath", show_download_button=True)
|
| 114 |
+
output_stems.extend([audio1, audio2])
|
| 115 |
+
|
| 116 |
+
input_audio.upload(fn=(lambda x: gr.update(value=x)), inputs=input_audio, outputs=input_audio_path)
|
| 117 |
+
model_type.change(fn=(lambda x: gr.update(choices=mvsepless.get_mn(x), value=mvsepless.get_mn(x)[0])), inputs=model_type, outputs=model_name).then(fn=(lambda x: (gr.update(visible=False if x in ["vr", "mdx"] else True), gr.update(visible=True if x == "vr" else False))), inputs=model_type, outputs=[extract_instrumental, vr_aggr])
|
| 118 |
+
model_name.change(fn=(lambda x, y: gr.update(choices=mvsepless.get_stems(x, y), value=None)), inputs=[model_type, model_name], outputs=stems_list).then(fn=(lambda x, y: (gr.update(interactive=True if mvsepless.get_tgt_inst(x, y) == None else None, info=t("stems_info", target_instrument=mvsepless.get_tgt_inst(x, y)) if mvsepless.get_tgt_inst(x, y) is not None else t("stems_info2")), gr.update(value=mvsepless.get_tgt_inst(x, y)), gr.update(value=True if mvsepless.get_tgt_inst(x, y) is not None else False))), inputs=[model_type, model_name], outputs=[stems_list, target_instrument, extract_instrumental])
|
| 119 |
+
separate_btn.click(fn=sep_wrapper, inputs=[input_audio_path, model_type, model_name, extract_instrumental, vr_aggr, output_format, output_bitrate, stems_list], outputs=output_stems, show_progress_on=input_audio)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
CURRENT_LANG = ru
|
| 123 |
+
css = """
|
| 124 |
+
.fixed-height { height: 160px !important; min-height: 160px !important; }
|
| 125 |
+
.fixed-height2 { height: 250px !important; min-height: 250px !important; }
|
| 126 |
+
"""
|
| 127 |
+
|
| 128 |
+
with gr.Blocks(theme=theme, css=css) as app:
|
| 129 |
+
create_app()
|
| 130 |
+
|
| 131 |
+
app.launch(allowed_paths=["/"], server_port=7860, share=False)
|
assets/translations.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TRANSLATIONS = {
|
| 2 |
+
"ru": {
|
| 3 |
+
"app_title": "MVSEPLESS",
|
| 4 |
+
"separation": "Разделение",
|
| 5 |
+
"plugins": "Плагины",
|
| 6 |
+
"select_file": "Выберите файл",
|
| 7 |
+
"audio_path": "Путь к файлу",
|
| 8 |
+
"audio_path_info": "Здесь можно ввести путь к файлу, либо загрузить его выше и получить путь к загруженному файлу",
|
| 9 |
+
"model_type": "Тип модели",
|
| 10 |
+
"model_name": "Имя модели",
|
| 11 |
+
"vr_aggressiveness": "Агрессивность для VR моделей",
|
| 12 |
+
"extract_instrumental": "Извлечь инструментал",
|
| 13 |
+
"stems_list": "Список стемов",
|
| 14 |
+
"output_format": "Формат вывода",
|
| 15 |
+
"separate_btn": "Разделить",
|
| 16 |
+
"upload": "Загрузка плагинов (.py)",
|
| 17 |
+
"upload_btn": "Загрузить",
|
| 18 |
+
"loading_plugin": "Загружается плагин: {name}",
|
| 19 |
+
"error_loading_plugin": "Произошла ошибка при загрузке плагина: {e}",
|
| 20 |
+
"target_instrument": "Целевой инструмент",
|
| 21 |
+
"stems_info": "Выбор стемов недоступен\nДля извлечения второго стема включите \"Извлечь инструментал\"",
|
| 22 |
+
"stems_info2": "Для получения остатка (при выбранных стемах), включите \"Извлечь инструментал\"",
|
| 23 |
+
"bitrate": "Битрейт (Кбит/сек)"
|
| 24 |
+
},
|
| 25 |
+
"en": {
|
| 26 |
+
"app_title": "MVSEPLESS",
|
| 27 |
+
"separation": "Separation",
|
| 28 |
+
"plugins": "Plugins",
|
| 29 |
+
"select_file": "Select File",
|
| 30 |
+
"audio_path": "Audio path",
|
| 31 |
+
"audio_path_info": "You can enter the file path here, or upload it above and get the path to the uploaded file.",
|
| 32 |
+
"model_type": "Model Type",
|
| 33 |
+
"model_name": "Model Name",
|
| 34 |
+
"vr_aggressiveness": "Aggressiveness for VR Models",
|
| 35 |
+
"extract_instrumental": "Extract Instrumental",
|
| 36 |
+
"stems_list": "Stems List",
|
| 37 |
+
"output_format": "Output Format",
|
| 38 |
+
"separate_btn": "Separate",
|
| 39 |
+
"upload": "Upload plugins (.py)",
|
| 40 |
+
"upload_btn": "Upload",
|
| 41 |
+
"loading_plugin": "Loading plugin: {name}",
|
| 42 |
+
"error_loading_plugin": "As error occured loading plugin: {e}",
|
| 43 |
+
"target_instrument": "Target instrument",
|
| 44 |
+
"stems_info": "Stem selection unavailable\nEnable \"Extract Instrumental\" to extract the second stem",
|
| 45 |
+
"stems_info2": "To extract the residual (with selected_stems), enable \"Extract Instrumental\"",
|
| 46 |
+
"bitrate": "Bitrate (Kbit/sec)"
|
| 47 |
+
}
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
TRANSLATIONS_STEMS = {
|
| 51 |
+
"ru": {
|
| 52 |
+
"vocals": "Вокал",
|
| 53 |
+
"Vocals": "Вокал",
|
| 54 |
+
"other": "Другое",
|
| 55 |
+
"Other": "Другое",
|
| 56 |
+
"Instrumental": "Инструментал",
|
| 57 |
+
"instrumnetal": "Инструментал",
|
| 58 |
+
"instrumental +": "Инструментал +",
|
| 59 |
+
"instrumental -": "Инструментал -",
|
| 60 |
+
"Bleed": "Фон",
|
| 61 |
+
"Guitar": "Гитара",
|
| 62 |
+
"drums": "Барабаны",
|
| 63 |
+
"bass": "Бас",
|
| 64 |
+
"karaoke": "Караоке",
|
| 65 |
+
"reverb": "Реверберация",
|
| 66 |
+
"noreverb": "Без реверберации",
|
| 67 |
+
"aspiration": "Придыхание",
|
| 68 |
+
"dry": "Сухой звук",
|
| 69 |
+
"crowd": "Толпа",
|
| 70 |
+
"percussions": "Перкуссия",
|
| 71 |
+
"piano": "Пианино",
|
| 72 |
+
"guitar": "Гитара",
|
| 73 |
+
"male": "Мужской",
|
| 74 |
+
"female": "Женский",
|
| 75 |
+
"kick": "Кик",
|
| 76 |
+
"snare": "Малый барабан",
|
| 77 |
+
"toms": "Том-томы",
|
| 78 |
+
"hh": "Хай-хэт",
|
| 79 |
+
"ride": "Райд",
|
| 80 |
+
"crash": "Крэш",
|
| 81 |
+
"similarity": "Сходство",
|
| 82 |
+
"difference": "Различие",
|
| 83 |
+
"inst": "Инструмент",
|
| 84 |
+
"orch": "Оркестр",
|
| 85 |
+
"No Woodwinds": "Без деревянных духовых",
|
| 86 |
+
"Woodwinds": "Деревянные духовые",
|
| 87 |
+
"No Echo": "Без эха",
|
| 88 |
+
"Echo": "Эхо",
|
| 89 |
+
"No Reverb": "Без реверберации",
|
| 90 |
+
"Reverb": "Реверберация",
|
| 91 |
+
"Noise": "Шум",
|
| 92 |
+
"No Noise": "Без шума",
|
| 93 |
+
"Dry": "Сухой звук",
|
| 94 |
+
"No Dry": "Не сухой звук",
|
| 95 |
+
"Breath": "Дыхание",
|
| 96 |
+
"No Breath": "Без дыхания",
|
| 97 |
+
"No Crowd": "Без толпы",
|
| 98 |
+
"Crowd": "Толпа",
|
| 99 |
+
"No Other": "Без другого",
|
| 100 |
+
"Bass": "Бас",
|
| 101 |
+
"No Bass": "Без баса",
|
| 102 |
+
"Drums": "Барабаны",
|
| 103 |
+
"No Drums": "Без барабанов",
|
| 104 |
+
"speech": "Речь",
|
| 105 |
+
"music": "Музыка",
|
| 106 |
+
"effects": "Эффекты",
|
| 107 |
+
"sfx": "Звуковые эффекты",
|
| 108 |
+
"inverted +": "Инверсия +",
|
| 109 |
+
"inverted -": "Инверсия -"
|
| 110 |
+
},
|
| 111 |
+
"en": {
|
| 112 |
+
"vocals": "Vocals",
|
| 113 |
+
"Vocals": "Vocals",
|
| 114 |
+
"other": "Other",
|
| 115 |
+
"Other": "Other",
|
| 116 |
+
"Instrumental": "Instrumental",
|
| 117 |
+
"instrumnetal": "Instrumental",
|
| 118 |
+
"instrumental +": "Instrumental +",
|
| 119 |
+
"instrumental -": "Instrumental -",
|
| 120 |
+
"Bleed": "Bleed",
|
| 121 |
+
"Guitar": "Guitar",
|
| 122 |
+
"drums": "Drums",
|
| 123 |
+
"bass": "Bass",
|
| 124 |
+
"karaoke": "Karaoke",
|
| 125 |
+
"reverb": "Reverb",
|
| 126 |
+
"noreverb": "No reverb",
|
| 127 |
+
"aspiration": "Aspiration",
|
| 128 |
+
"dry": "Dry",
|
| 129 |
+
"crowd": "Crowd",
|
| 130 |
+
"percussions": "Percussions",
|
| 131 |
+
"piano": "Piano",
|
| 132 |
+
"guitar": "Guitar",
|
| 133 |
+
"male": "Male",
|
| 134 |
+
"female": "Female",
|
| 135 |
+
"kick": "Kick",
|
| 136 |
+
"snare": "Snare",
|
| 137 |
+
"toms": "Toms",
|
| 138 |
+
"hh": "Hi-hat",
|
| 139 |
+
"ride": "Ride",
|
| 140 |
+
"crash": "Crash",
|
| 141 |
+
"similarity": "Similarity",
|
| 142 |
+
"difference": "Difference",
|
| 143 |
+
"inst": "Instrument",
|
| 144 |
+
"orch": "Orchestra",
|
| 145 |
+
"No Woodwinds": "No Woodwinds",
|
| 146 |
+
"Woodwinds": "Woodwinds",
|
| 147 |
+
"No Echo": "No Echo",
|
| 148 |
+
"Echo": "Echo",
|
| 149 |
+
"No Reverb": "No Reverb",
|
| 150 |
+
"Reverb": "Reverb",
|
| 151 |
+
"Noise": "Noise",
|
| 152 |
+
"No Noise": "No Noise",
|
| 153 |
+
"Dry": "Dry",
|
| 154 |
+
"No Dry": "No Dry",
|
| 155 |
+
"Breath": "Breath",
|
| 156 |
+
"No Breath": "No Breath",
|
| 157 |
+
"No Crowd": "No Crowd",
|
| 158 |
+
"Crowd": "Crowd",
|
| 159 |
+
"No Other": "No Other",
|
| 160 |
+
"Bass": "Bass",
|
| 161 |
+
"No Bass": "No Bass",
|
| 162 |
+
"Drums": "Drums",
|
| 163 |
+
"No Drums": "No Drums",
|
| 164 |
+
"speech": "Speech",
|
| 165 |
+
"music": "Music",
|
| 166 |
+
"effects": "Effects",
|
| 167 |
+
"sfx": "SFX",
|
| 168 |
+
"inverted +": "Inverted +",
|
| 169 |
+
"inverted -": "Inverted -"
|
| 170 |
+
}
|
| 171 |
+
}
|
model_list.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
multi_inference.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import shutil
|
| 4 |
+
import sys
|
| 5 |
+
import gc
|
| 6 |
+
import argparse
|
| 7 |
+
import json
|
| 8 |
+
import subprocess
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from tabulate import tabulate
|
| 11 |
+
|
| 12 |
+
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 13 |
+
sys.path.append(SCRIPT_DIR)
|
| 14 |
+
os.chdir(SCRIPT_DIR)
|
| 15 |
+
|
| 16 |
+
from model_list import models_data
|
| 17 |
+
from utils.preedit_config import conf_editor
|
| 18 |
+
from utils.download_models import download_model
|
| 19 |
+
|
| 20 |
+
MODELS_CACHE_DIR = os.path.join(SCRIPT_DIR, "separator", "models_cache")
|
| 21 |
+
MODEL_TYPES = ["mel_band_roformer", "bs_roformer", "mdx23c", "scnet", "htdemucs", "bandit", "bandit_v2", "vr", "mdx"]
|
| 22 |
+
OUTPUT_FORMATS = ["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "aiff"]
|
| 23 |
+
|
| 24 |
+
class MVSEPLESS:
|
| 25 |
+
def __init__(self):
|
| 26 |
+
self.models_cache_dir = os.path.join(SCRIPT_DIR, "separator", "models_cache")
|
| 27 |
+
self.model_types = MODEL_TYPES
|
| 28 |
+
self.output_formats = OUTPUT_FORMATS
|
| 29 |
+
|
| 30 |
+
def get_mt(self):
|
| 31 |
+
return list(models_data.keys())
|
| 32 |
+
|
| 33 |
+
def get_mn(self, model_type):
|
| 34 |
+
return list(models_data[model_type].keys())
|
| 35 |
+
|
| 36 |
+
def get_stems(self, model_type, model_name):
|
| 37 |
+
stems = models_data[model_type][model_name]["stems"]
|
| 38 |
+
return stems
|
| 39 |
+
|
| 40 |
+
def get_tgt_inst(self, model_type, model_name):
|
| 41 |
+
target_instrument = models_data[model_type][model_name]["target_instrument"]
|
| 42 |
+
return target_instrument
|
| 43 |
+
|
| 44 |
+
def display_models_info(self, filter: str = None):
|
| 45 |
+
print("\nAvailable Models Information:")
|
| 46 |
+
print("=" * 50)
|
| 47 |
+
|
| 48 |
+
for model_type in models_data:
|
| 49 |
+
print(f"\nModel Type: {model_type.upper()}")
|
| 50 |
+
print("-" * 50)
|
| 51 |
+
|
| 52 |
+
table_data = []
|
| 53 |
+
headers = ["Model Name", "Stems", "Target Instrument", "Primary Stem"]
|
| 54 |
+
|
| 55 |
+
for model_name in models_data[model_type]:
|
| 56 |
+
model_info = models_data[model_type][model_name]
|
| 57 |
+
|
| 58 |
+
if filter and filter not in model_info.get('stems', []):
|
| 59 |
+
continue
|
| 60 |
+
|
| 61 |
+
stems = "\n".join(model_info.get('stems', [])) if 'stems' in model_info else "N/A"
|
| 62 |
+
target = model_info.get('target_instrument', "N/A")
|
| 63 |
+
primary = model_info.get('primary_stem', "N/A")
|
| 64 |
+
|
| 65 |
+
table_data.append([model_name, stems, target, primary])
|
| 66 |
+
|
| 67 |
+
print(tabulate(table_data, headers=headers, tablefmt="grid"))
|
| 68 |
+
print()
|
| 69 |
+
|
| 70 |
+
def separator(
|
| 71 |
+
self,
|
| 72 |
+
input_file: str = None,
|
| 73 |
+
output_dir: str = None,
|
| 74 |
+
model_type: str = "mel_band_roformer",
|
| 75 |
+
model_name: str = "Mel-Band-Roformer_Vocals_kimberley_jensen",
|
| 76 |
+
ext_inst: bool = False,
|
| 77 |
+
vr_aggr: int = 5,
|
| 78 |
+
output_format: str = "wav",
|
| 79 |
+
output_bitrate: str = "320k",
|
| 80 |
+
template: str = "NAME_(STEM)_MODEL",
|
| 81 |
+
call_method: str = "cli",
|
| 82 |
+
selected_stems: list = None
|
| 83 |
+
):
|
| 84 |
+
if selected_stems is None:
|
| 85 |
+
selected_stems = []
|
| 86 |
+
|
| 87 |
+
if not input_file:
|
| 88 |
+
print("Please, input path to input file")
|
| 89 |
+
return [("None", "/none/none.mp3")]
|
| 90 |
+
|
| 91 |
+
if not os.path.exists(input_file):
|
| 92 |
+
print("Input file not exist")
|
| 93 |
+
return [("None", "/none/none.mp3")]
|
| 94 |
+
|
| 95 |
+
if "STEM" not in template:
|
| 96 |
+
template = template + "_STEM"
|
| 97 |
+
|
| 98 |
+
print(f"Starting inference: {model_type}/{model_name}, bitrate={output_bitrate}, method={call_method}, stems={selected_stems}")
|
| 99 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 100 |
+
|
| 101 |
+
if model_type in ["mel_band_roformer", "bs_roformer", "mdx23c", "scnet", "htdemucs", "bandit", "bandit_v2"]:
|
| 102 |
+
try:
|
| 103 |
+
info = models_data[model_type][model_name]
|
| 104 |
+
except KeyError:
|
| 105 |
+
print("Model not exist")
|
| 106 |
+
return [("None", "/none/none.mp3")]
|
| 107 |
+
|
| 108 |
+
conf, ckpt = download_model(self.models_cache_dir, model_name, model_type,
|
| 109 |
+
info["checkpoint_url"], info["config_url"])
|
| 110 |
+
if model_type != "htdemucs":
|
| 111 |
+
conf_editor(conf)
|
| 112 |
+
|
| 113 |
+
if call_method == "cli":
|
| 114 |
+
cmd = ["python", "-m", "separator.msst_separator", f'--input "{input_file}"',
|
| 115 |
+
f'--store_dir "{output_dir}"', f'--model_type "{model_type}"',
|
| 116 |
+
f'--model_name "{model_name}"', f'--config_path "{conf}"',
|
| 117 |
+
f'--start_check_point "{ckpt}"', f'--output_format "{output_format}"',
|
| 118 |
+
f'--output_bitrate "{output_bitrate}"', f'--template "{template}"',
|
| 119 |
+
"--save_results_info"]
|
| 120 |
+
if ext_inst:
|
| 121 |
+
cmd.append("--extract_instrumental")
|
| 122 |
+
if selected_stems:
|
| 123 |
+
instruments = " ".join(f'"{s}"' for s in selected_stems)
|
| 124 |
+
cmd.append(f'--selected_instruments {instruments}')
|
| 125 |
+
subprocess.run(" ".join(cmd), shell=True, check=True)
|
| 126 |
+
|
| 127 |
+
results_path = os.path.join(output_dir, "results.json")
|
| 128 |
+
if os.path.exists(results_path):
|
| 129 |
+
with open(results_path, encoding="utf-8") as f:
|
| 130 |
+
return json.load(f)
|
| 131 |
+
return [("None", "/none/none.mp3")]
|
| 132 |
+
|
| 133 |
+
elif call_method == "direct":
|
| 134 |
+
from separator.msst_separator import mvsep_offline
|
| 135 |
+
try:
|
| 136 |
+
return mvsep_offline(
|
| 137 |
+
input_path=input_file, store_dir=output_dir, model_type=model_type,
|
| 138 |
+
config_path=conf, start_check_point=ckpt, extract_instrumental=ext_inst,
|
| 139 |
+
output_format=output_format, output_bitrate=output_bitrate,
|
| 140 |
+
model_name=model_name, template=template, selected_instruments=selected_stems
|
| 141 |
+
)
|
| 142 |
+
except Exception as e:
|
| 143 |
+
print(e)
|
| 144 |
+
return [("None", "/none/none.mp3")]
|
| 145 |
+
|
| 146 |
+
elif model_type in ["vr", "mdx"]:
|
| 147 |
+
try:
|
| 148 |
+
info = models_data[model_type][model_name]
|
| 149 |
+
except KeyError:
|
| 150 |
+
print("Model not exist")
|
| 151 |
+
return [("None", "/none/none.mp3")]
|
| 152 |
+
|
| 153 |
+
if model_type == "vr" and info.get("custom_vr", False):
|
| 154 |
+
conf, ckpt = download_model(self.models_cache_dir, model_name, model_type,
|
| 155 |
+
info["checkpoint_url"], info["config_url"])
|
| 156 |
+
primary_stem = info["primary_stem"]
|
| 157 |
+
|
| 158 |
+
if call_method == "cli":
|
| 159 |
+
cmd = ["python", "-m", "separator.uvr_sep", "custom_vr",
|
| 160 |
+
f'--input_file "{input_file}"', f'--ckpt_path "{ckpt}"',
|
| 161 |
+
f'--config_path "{conf}"', f'--bitrate "{output_bitrate}"',
|
| 162 |
+
f'--model_name "{model_name}"', f'--template "{template}"',
|
| 163 |
+
f'--output_format "{output_format}"', f'--primary_stem "{primary_stem}"',
|
| 164 |
+
f'--aggression {vr_aggr}', f'--output_dir "{output_dir}"']
|
| 165 |
+
if selected_stems:
|
| 166 |
+
instruments = " ".join(f'"{s}"' for s in selected_stems)
|
| 167 |
+
cmd.append(f'--selected_instruments {instruments}')
|
| 168 |
+
subprocess.run(" ".join(cmd), shell=True, check=True)
|
| 169 |
+
|
| 170 |
+
results_path = os.path.join(output_dir, "results.json")
|
| 171 |
+
if os.path.exists(results_path):
|
| 172 |
+
with open(results_path, encoding="utf-8") as f:
|
| 173 |
+
return json.load(f)
|
| 174 |
+
return [("None", "/none/none.mp3")]
|
| 175 |
+
|
| 176 |
+
elif call_method == "direct":
|
| 177 |
+
from separator.uvr_sep import custom_vr_separate
|
| 178 |
+
try:
|
| 179 |
+
return custom_vr_separate(
|
| 180 |
+
input_file=input_file, ckpt_path=ckpt, config_path=conf,
|
| 181 |
+
bitrate=output_bitrate, model_name=model_name, template=template,
|
| 182 |
+
output_format=output_format, primary_stem=primary_stem,
|
| 183 |
+
aggression=vr_aggr, output_dir=output_dir,
|
| 184 |
+
selected_instruments=selected_stems
|
| 185 |
+
)
|
| 186 |
+
except Exception as e:
|
| 187 |
+
print(e)
|
| 188 |
+
return [("None", "/none/none.mp3")]
|
| 189 |
+
else:
|
| 190 |
+
if call_method == "cli":
|
| 191 |
+
cmd = ["python", "-m", "separator.uvr_sep", "uvr",
|
| 192 |
+
f'--input_file "{input_file}"', f'--output_dir "{output_dir}"',
|
| 193 |
+
f'--template "{template}"', f'--bitrate "{output_bitrate}"',
|
| 194 |
+
f'--model_dir "{self.models_cache_dir}"', f'--model_type "{model_type}"',
|
| 195 |
+
f'--model_name "{model_name}"', f'--output_format "{output_format}"',
|
| 196 |
+
f'--aggression {vr_aggr}']
|
| 197 |
+
if selected_stems:
|
| 198 |
+
instruments = " ".join(f'"{s}"' for s in selected_stems)
|
| 199 |
+
cmd.append(f'--selected_instruments {instruments}')
|
| 200 |
+
subprocess.run(" ".join(cmd), shell=True, check=True)
|
| 201 |
+
|
| 202 |
+
results_path = os.path.join(output_dir, "results.json")
|
| 203 |
+
if os.path.exists(results_path):
|
| 204 |
+
with open(results_path, encoding="utf-8") as f:
|
| 205 |
+
return json.load(f)
|
| 206 |
+
return [("None", "/none/none.mp3")]
|
| 207 |
+
|
| 208 |
+
elif call_method == "direct":
|
| 209 |
+
from separator.uvr_sep import non_custom_uvr_inference
|
| 210 |
+
try:
|
| 211 |
+
return non_custom_uvr_inference(
|
| 212 |
+
input_file=input_file, output_dir=output_dir, template=template,
|
| 213 |
+
bitrate=output_bitrate, model_dir=self.models_cache_dir,
|
| 214 |
+
model_type=model_type, model_name=model_name,
|
| 215 |
+
output_format=output_format, aggression=vr_aggr,
|
| 216 |
+
selected_instruments=selected_stems
|
| 217 |
+
)
|
| 218 |
+
except Exception as e:
|
| 219 |
+
print(e)
|
| 220 |
+
return [("None", "/none/none.mp3")]
|
| 221 |
+
|
| 222 |
+
print("Unsupported model type")
|
| 223 |
+
return [("None", "/none/none.mp3")]
|
| 224 |
+
|
| 225 |
+
def parse_args():
|
| 226 |
+
parser = argparse.ArgumentParser(description="Multi-inference for separation audio in Google Colab")
|
| 227 |
+
subparsers = parser.add_subparsers(dest='command', required=True, help='Sub-command help')
|
| 228 |
+
|
| 229 |
+
list_models = subparsers.add_parser('list', help='List of exist models')
|
| 230 |
+
list_models.add_argument("-l_filter", "--list_filter", type=str, default=None, help="Show models in list only with specified stem")
|
| 231 |
+
|
| 232 |
+
separate = subparsers.add_parser('separate', help='Separate I/O params')
|
| 233 |
+
separate.add_argument("-i", "--input", type=str, required=True, help="Input file or directory")
|
| 234 |
+
separate.add_argument("-o", "--output", type=str, required=True, help="Output directory")
|
| 235 |
+
separate.add_argument("-mt", "--model_type", type=str, required=True, choices=MODEL_TYPES, help="Model type")
|
| 236 |
+
separate.add_argument("-mn", "--model_name", type=str, required=True, help="Model name")
|
| 237 |
+
separate.add_argument("-inst", "--instrumental", action='store_true', help="Extract instrumental")
|
| 238 |
+
separate.add_argument("-stems", "--stems", nargs="+", help="Select output stems")
|
| 239 |
+
separate.add_argument("-bitrate", "--bitrate", type=str, default="320k", help="Output bitrate")
|
| 240 |
+
separate.add_argument("-of", "--format", type=str, default="mp3", help="Output format")
|
| 241 |
+
separate.add_argument("-vr_aggr", "--vr_arch_aggressive", type=int, default=5, help="Aggression for VR ARCH models")
|
| 242 |
+
separate.add_argument('--template', type=str, default='NAME_STEM', help='Template naming of output files')
|
| 243 |
+
separate.add_argument("-l_out", "--list_output", action='store_true', help="Show list output files")
|
| 244 |
+
|
| 245 |
+
return parser.parse_args()
|
| 246 |
+
|
| 247 |
+
if __name__ == "__main__":
|
| 248 |
+
args = parse_args()
|
| 249 |
+
mvsepless = MVSEPLESS()
|
| 250 |
+
|
| 251 |
+
if args.command == 'list':
|
| 252 |
+
mvsepless.display_models_info(args.list_filter)
|
| 253 |
+
|
| 254 |
+
elif args.command == 'separate':
|
| 255 |
+
if os.path.isfile(args.input):
|
| 256 |
+
results = mvsepless.separator(
|
| 257 |
+
input_file=args.input,
|
| 258 |
+
output_dir=args.output,
|
| 259 |
+
model_type=args.model_type,
|
| 260 |
+
model_name=args.model_name,
|
| 261 |
+
ext_inst=args.instrumental,
|
| 262 |
+
vr_aggr=args.vr_arch_aggressive,
|
| 263 |
+
output_format=args.format,
|
| 264 |
+
output_bitrate=args.bitrate,
|
| 265 |
+
template=args.template,
|
| 266 |
+
call_method="cli",
|
| 267 |
+
selected_stems=args.stems
|
| 268 |
+
)
|
| 269 |
+
if args.list_output:
|
| 270 |
+
print("Results\n")
|
| 271 |
+
for stem, path in results:
|
| 272 |
+
print(f"Stem - {stem}\nPath - {path}\n")
|
| 273 |
+
|
| 274 |
+
elif os.path.isdir(args.input):
|
| 275 |
+
batch_results = []
|
| 276 |
+
for file in os.listdir(args.input):
|
| 277 |
+
abs_path_file = os.path.join(args.input, file)
|
| 278 |
+
if os.path.isfile(abs_path_file):
|
| 279 |
+
base_name = os.path.splitext(os.path.basename(abs_path_file))[0]
|
| 280 |
+
output_subdir = os.path.join(args.output, base_name)
|
| 281 |
+
|
| 282 |
+
results = mvsepless.separator(
|
| 283 |
+
input_file=abs_path_file,
|
| 284 |
+
output_dir=output_subdir,
|
| 285 |
+
model_type=args.model_type,
|
| 286 |
+
model_name=args.model_name,
|
| 287 |
+
ext_inst=args.instrumental,
|
| 288 |
+
vr_aggr=args.vr_arch_aggressive,
|
| 289 |
+
output_format=args.format,
|
| 290 |
+
output_bitrate=args.bitrate,
|
| 291 |
+
template=args.template,
|
| 292 |
+
call_method="cli",
|
| 293 |
+
selected_stems=args.stems
|
| 294 |
+
)
|
| 295 |
+
batch_results.append((base_name, results))
|
| 296 |
+
|
| 297 |
+
if args.list_output:
|
| 298 |
+
print("Results\n")
|
| 299 |
+
for name, stems in batch_results:
|
| 300 |
+
print(f"Name - {name}")
|
| 301 |
+
for stem, path in stems:
|
| 302 |
+
print(f" Stem - {stem}\n Path - {path}\n")
|
| 303 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.6.0
|
| 2 |
+
torchvision==0.21.0
|
| 3 |
+
torchaudio==2.6.0
|
| 4 |
+
torchcrepe==0.0.23
|
| 5 |
+
numpy==2.0.2
|
| 6 |
+
pandas==2.2.2
|
| 7 |
+
scipy==1.15.3
|
| 8 |
+
librosa==0.9.1
|
| 9 |
+
matplotlib==3.9.0
|
| 10 |
+
tqdm==4.67.1
|
| 11 |
+
einops==0.8.1
|
| 12 |
+
protobuf==5.29.4
|
| 13 |
+
soundfile==0.13.1
|
| 14 |
+
pydub==0.25.1
|
| 15 |
+
pyloudnorm==0.1.1
|
| 16 |
+
praat-parselmouth==0.4.5
|
| 17 |
+
webrtcvad==2.0.10
|
| 18 |
+
edge-tts==7.0.2
|
| 19 |
+
audiomentations==0.24.0
|
| 20 |
+
pedalboard==0.8.1
|
| 21 |
+
ffmpeg-python==0.2.0
|
| 22 |
+
faiss-cpu==1.11
|
| 23 |
+
ml_collections==1.1.0
|
| 24 |
+
timm==1.0.15
|
| 25 |
+
wandb==0.19.11
|
| 26 |
+
accelerate==1.7.0
|
| 27 |
+
bitsandbytes==0.46.0
|
| 28 |
+
tokenizers==0.19
|
| 29 |
+
huggingface-hub==0.28.1
|
| 30 |
+
transformers==4.41
|
| 31 |
+
https://github.com/noblebarkrr/mvsepless/blob/bd611441e48e918650e6860738894673b3a1a5f1/fixed/fairseq_fixed-0.13.0-cp311-cp311-linux_x86_64.whl
|
| 32 |
+
torchseg==0.0.1a4
|
| 33 |
+
demucs==4.0.0
|
| 34 |
+
asteroid==0.7.0
|
| 35 |
+
prodigyopt==1.1.2
|
| 36 |
+
torch_log_wmse==0.3.0
|
| 37 |
+
rotary_embedding_torch==0.6.5
|
| 38 |
+
local-attention==1.11.1
|
| 39 |
+
tenacity==9.1.2
|
| 40 |
+
gradio==5.38.2
|
| 41 |
+
omegaconf==2.3.0
|
| 42 |
+
beartype==0.18.5
|
| 43 |
+
spafe==0.3.2
|
| 44 |
+
torch_audiomentations==0.12.0
|
| 45 |
+
auraloss==0.4.0
|
| 46 |
+
onnxruntime-gpu>=1.17
|
| 47 |
+
yt_dlp
|
| 48 |
+
python-magic
|
| 49 |
+
pyngrok
|
| 50 |
+
|
separator/audio_writer.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydub import AudioSegment
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
def write_audio_file(output_file_path, numpy_array, sample_rate, output_format, bitrate):
|
| 5 |
+
"""
|
| 6 |
+
Записывает аудиофайл из numpy массива в указанном формате с помощью pydub.
|
| 7 |
+
|
| 8 |
+
Параметры:
|
| 9 |
+
output_file_path (str): Путь для сохранения файла (без расширения)
|
| 10 |
+
numpy_array (numpy.ndarray): Аудиоданные в виде numpy массива
|
| 11 |
+
sample_rate (int): Частота дискретизации (в Гц)
|
| 12 |
+
output_format (str): Формат выходного файла ('mp3', 'flac', 'wav', 'aiff', 'm4a', 'aac', 'ogg', 'opus')
|
| 13 |
+
encoder_settings (dict, optional): Cловарь с настройками кодировки аудио
|
| 14 |
+
"""
|
| 15 |
+
try:
|
| 16 |
+
# Проверка и нормализация входных данных
|
| 17 |
+
if not isinstance(numpy_array, np.ndarray):
|
| 18 |
+
raise ValueError("Input must be a numpy array")
|
| 19 |
+
|
| 20 |
+
# Преобразование в правильную форму (samples, channels)
|
| 21 |
+
if len(numpy_array.shape) == 1:
|
| 22 |
+
numpy_array = numpy_array.reshape(-1, 1) # Моно
|
| 23 |
+
elif len(numpy_array.shape) == 2:
|
| 24 |
+
if numpy_array.shape[0] == 2: # Если (channels, samples)
|
| 25 |
+
numpy_array = numpy_array.T # Транспонируем в (samples, channels)
|
| 26 |
+
else:
|
| 27 |
+
raise ValueError("Input array must be 1D or 2D")
|
| 28 |
+
|
| 29 |
+
# Нормализация до диапазона [-1.0, 1.0] если нужно
|
| 30 |
+
if np.issubdtype(numpy_array.dtype, np.floating):
|
| 31 |
+
numpy_array = np.clip(numpy_array, -1.0, 1.0)
|
| 32 |
+
numpy_array = (numpy_array * 32767).astype(np.int16)
|
| 33 |
+
elif numpy_array.dtype != np.int16:
|
| 34 |
+
numpy_array = numpy_array.astype(np.int16)
|
| 35 |
+
|
| 36 |
+
# Создание AudioSegment
|
| 37 |
+
if numpy_array.shape[1] == 1: # Моно
|
| 38 |
+
audio_segment = AudioSegment(
|
| 39 |
+
numpy_array.tobytes(),
|
| 40 |
+
frame_rate=sample_rate,
|
| 41 |
+
sample_width=2, # 16-bit = 2 bytes
|
| 42 |
+
channels=1
|
| 43 |
+
)
|
| 44 |
+
else: # Стерео
|
| 45 |
+
# Для стерео нужно чередовать байты левого и правого каналов
|
| 46 |
+
interleaved = np.empty((numpy_array.shape[0] * 2,), dtype=np.int16)
|
| 47 |
+
interleaved[0::2] = numpy_array[:, 0] # Левый канал
|
| 48 |
+
interleaved[1::2] = numpy_array[:, 1] # Правый канал
|
| 49 |
+
audio_segment = AudioSegment(
|
| 50 |
+
interleaved.tobytes(),
|
| 51 |
+
frame_rate=sample_rate,
|
| 52 |
+
sample_width=2,
|
| 53 |
+
channels=2
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
# Формирование параметров экспорта
|
| 57 |
+
|
| 58 |
+
parameters = {}
|
| 59 |
+
if bitrate:
|
| 60 |
+
parameters['bitrate'] = bitrate
|
| 61 |
+
|
| 62 |
+
# Поддержка различных форматов
|
| 63 |
+
format_mapping = {
|
| 64 |
+
'mp3': 'mp3',
|
| 65 |
+
'flac': 'flac',
|
| 66 |
+
'wav': 'wav',
|
| 67 |
+
'aiff': 'aiff',
|
| 68 |
+
'm4a': 'ipod', # для m4a в pydub используется кодек ipod
|
| 69 |
+
'aac': 'adts', # для aac в pydub используется adts
|
| 70 |
+
'ogg': 'ogg',
|
| 71 |
+
'opus': 'opus'
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
if output_format not in format_mapping:
|
| 75 |
+
raise ValueError(f"Unsupported format: {output_format}. Supported formats are: {list(format_mapping.keys())}")
|
| 76 |
+
|
| 77 |
+
# Добавление расширения файла, если его нет
|
| 78 |
+
if not output_file_path.lower().endswith(f'.{output_format}'):
|
| 79 |
+
output_file_path = f"{output_file_path}.{output_format}"
|
| 80 |
+
|
| 81 |
+
# Экспорт в нужный формат
|
| 82 |
+
audio_segment.export(output_file_path, format=format_mapping[output_format], **parameters)
|
| 83 |
+
|
| 84 |
+
except Exception as e:
|
| 85 |
+
raise RuntimeError(f"Error writing audio file: {str(e)}")
|
separator/ensemble.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding: utf-8
|
| 2 |
+
__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
import librosa
|
| 7 |
+
import tempfile
|
| 8 |
+
import soundfile as sf
|
| 9 |
+
import numpy as np
|
| 10 |
+
import argparse
|
| 11 |
+
from separator.audio_writer import write_audio_file
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def stft(wave, nfft, hl):
|
| 15 |
+
wave_left = np.asfortranarray(wave[0])
|
| 16 |
+
wave_right = np.asfortranarray(wave[1])
|
| 17 |
+
spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl)
|
| 18 |
+
spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl)
|
| 19 |
+
spec = np.asfortranarray([spec_left, spec_right])
|
| 20 |
+
return spec
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def istft(spec, hl, length):
|
| 24 |
+
spec_left = np.asfortranarray(spec[0])
|
| 25 |
+
spec_right = np.asfortranarray(spec[1])
|
| 26 |
+
wave_left = librosa.istft(spec_left, hop_length=hl, length=length)
|
| 27 |
+
wave_right = librosa.istft(spec_right, hop_length=hl, length=length)
|
| 28 |
+
wave = np.asfortranarray([wave_left, wave_right])
|
| 29 |
+
return wave
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def absmax(a, *, axis):
|
| 33 |
+
dims = list(a.shape)
|
| 34 |
+
dims.pop(axis)
|
| 35 |
+
indices = np.ogrid[tuple(slice(0, d) for d in dims)]
|
| 36 |
+
argmax = np.abs(a).argmax(axis=axis)
|
| 37 |
+
# Convert indices to list before insertion
|
| 38 |
+
indices = list(indices)
|
| 39 |
+
indices.insert(axis % len(a.shape), argmax)
|
| 40 |
+
return a[tuple(indices)]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def absmin(a, *, axis):
|
| 44 |
+
dims = list(a.shape)
|
| 45 |
+
dims.pop(axis)
|
| 46 |
+
indices = np.ogrid[tuple(slice(0, d) for d in dims)]
|
| 47 |
+
argmax = np.abs(a).argmin(axis=axis)
|
| 48 |
+
indices.insert((len(a.shape) + axis) % len(a.shape), argmax)
|
| 49 |
+
return a[tuple(indices)]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def lambda_max(arr, axis=None, key=None, keepdims=False):
|
| 53 |
+
idxs = np.argmax(key(arr), axis)
|
| 54 |
+
if axis is not None:
|
| 55 |
+
idxs = np.expand_dims(idxs, axis)
|
| 56 |
+
result = np.take_along_axis(arr, idxs, axis)
|
| 57 |
+
if not keepdims:
|
| 58 |
+
result = np.squeeze(result, axis=axis)
|
| 59 |
+
return result
|
| 60 |
+
else:
|
| 61 |
+
return arr.flatten()[idxs]
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def lambda_min(arr, axis=None, key=None, keepdims=False):
|
| 65 |
+
idxs = np.argmin(key(arr), axis)
|
| 66 |
+
if axis is not None:
|
| 67 |
+
idxs = np.expand_dims(idxs, axis)
|
| 68 |
+
result = np.take_along_axis(arr, idxs, axis)
|
| 69 |
+
if not keepdims:
|
| 70 |
+
result = np.squeeze(result, axis=axis)
|
| 71 |
+
return result
|
| 72 |
+
else:
|
| 73 |
+
return arr.flatten()[idxs]
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def average_waveforms(pred_track, weights, algorithm):
|
| 77 |
+
"""
|
| 78 |
+
:param pred_track: shape = (num, channels, length)
|
| 79 |
+
:param weights: shape = (num, )
|
| 80 |
+
:param algorithm: One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft
|
| 81 |
+
:return: averaged waveform in shape (channels, length)
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
pred_track = np.array(pred_track)
|
| 85 |
+
final_length = pred_track.shape[-1]
|
| 86 |
+
|
| 87 |
+
mod_track = []
|
| 88 |
+
for i in range(pred_track.shape[0]):
|
| 89 |
+
if algorithm == 'avg_wave':
|
| 90 |
+
mod_track.append(pred_track[i] * weights[i])
|
| 91 |
+
elif algorithm in ['median_wave', 'min_wave', 'max_wave']:
|
| 92 |
+
mod_track.append(pred_track[i])
|
| 93 |
+
elif algorithm in ['avg_fft', 'min_fft', 'max_fft', 'median_fft']:
|
| 94 |
+
spec = stft(pred_track[i], nfft=2048, hl=1024)
|
| 95 |
+
if algorithm in ['avg_fft']:
|
| 96 |
+
mod_track.append(spec * weights[i])
|
| 97 |
+
else:
|
| 98 |
+
mod_track.append(spec)
|
| 99 |
+
pred_track = np.array(mod_track)
|
| 100 |
+
|
| 101 |
+
if algorithm in ['avg_wave']:
|
| 102 |
+
pred_track = pred_track.sum(axis=0)
|
| 103 |
+
pred_track /= np.array(weights).sum().T
|
| 104 |
+
elif algorithm in ['median_wave']:
|
| 105 |
+
pred_track = np.median(pred_track, axis=0)
|
| 106 |
+
elif algorithm in ['min_wave']:
|
| 107 |
+
pred_track = np.array(pred_track)
|
| 108 |
+
pred_track = lambda_min(pred_track, axis=0, key=np.abs)
|
| 109 |
+
elif algorithm in ['max_wave']:
|
| 110 |
+
pred_track = np.array(pred_track)
|
| 111 |
+
pred_track = lambda_max(pred_track, axis=0, key=np.abs)
|
| 112 |
+
elif algorithm in ['avg_fft']:
|
| 113 |
+
pred_track = pred_track.sum(axis=0)
|
| 114 |
+
pred_track /= np.array(weights).sum()
|
| 115 |
+
pred_track = istft(pred_track, 1024, final_length)
|
| 116 |
+
elif algorithm in ['min_fft']:
|
| 117 |
+
pred_track = np.array(pred_track)
|
| 118 |
+
pred_track = lambda_min(pred_track, axis=0, key=np.abs)
|
| 119 |
+
pred_track = istft(pred_track, 1024, final_length)
|
| 120 |
+
elif algorithm in ['max_fft']:
|
| 121 |
+
pred_track = np.array(pred_track)
|
| 122 |
+
pred_track = absmax(pred_track, axis=0)
|
| 123 |
+
pred_track = istft(pred_track, 1024, final_length)
|
| 124 |
+
elif algorithm in ['median_fft']:
|
| 125 |
+
pred_track = np.array(pred_track)
|
| 126 |
+
pred_track = np.median(pred_track, axis=0)
|
| 127 |
+
pred_track = istft(pred_track, 1024, final_length)
|
| 128 |
+
return pred_track
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def ensemble_audio_files(files, output="res.wav", ensemble_type='avg_wave', weights=None, out_format="wav"):
|
| 132 |
+
"""
|
| 133 |
+
Основная функция для объединения аудиофайлов
|
| 134 |
+
|
| 135 |
+
:param files: список путей к аудиофайлам
|
| 136 |
+
:param output: путь для сохранения результата
|
| 137 |
+
:param ensemble_type: метод объединения (avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft)
|
| 138 |
+
:param weights: список весов для каждого файла (None для равных весов)
|
| 139 |
+
:return: None
|
| 140 |
+
"""
|
| 141 |
+
print('Ensemble type: {}'.format(ensemble_type))
|
| 142 |
+
print('Number of input files: {}'.format(len(files)))
|
| 143 |
+
if weights is not None:
|
| 144 |
+
weights = np.array(weights)
|
| 145 |
+
else:
|
| 146 |
+
weights = np.ones(len(files))
|
| 147 |
+
print('Weights: {}'.format(weights))
|
| 148 |
+
print('Output file: {}'.format(output))
|
| 149 |
+
|
| 150 |
+
data = []
|
| 151 |
+
sr = None
|
| 152 |
+
for f in files:
|
| 153 |
+
if not os.path.isfile(f):
|
| 154 |
+
print('Error. Can\'t find file: {}. Check paths.'.format(f))
|
| 155 |
+
exit()
|
| 156 |
+
print('Reading file: {}'.format(f))
|
| 157 |
+
wav, current_sr = librosa.load(f, sr=None, mono=False)
|
| 158 |
+
if sr is None:
|
| 159 |
+
sr = current_sr
|
| 160 |
+
elif sr != current_sr:
|
| 161 |
+
print('Error: Sample rates must be equal for all files')
|
| 162 |
+
exit()
|
| 163 |
+
print("Waveform shape: {} sample rate: {}".format(wav.shape, sr))
|
| 164 |
+
data.append(wav)
|
| 165 |
+
|
| 166 |
+
data = np.array(data)
|
| 167 |
+
res = average_waveforms(data, weights, ensemble_type)
|
| 168 |
+
print('Result shape: {}'.format(res.shape))
|
| 169 |
+
|
| 170 |
+
output_wav = f"{output}_orig.wav"
|
| 171 |
+
output = f"{output}.{out_format}"
|
| 172 |
+
|
| 173 |
+
if out_format in ["wav", "flac"]:
|
| 174 |
+
|
| 175 |
+
sf.write(output, res.T, sr, subtype='PCM_16')
|
| 176 |
+
sf.write(output_wav, res.T, sr, subtype='PCM_16')
|
| 177 |
+
|
| 178 |
+
elif out_format in ["mp3", "m4a", "aac", "ogg", "opus", "aiff"]:
|
| 179 |
+
|
| 180 |
+
write_audio_file(output, res.T, sr, out_format, "320k")
|
| 181 |
+
sf.write(output_wav, res.T, sr, subtype='PCM_16')
|
| 182 |
+
|
| 183 |
+
return output, output_wav
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# input_settings = [("demucs / v4", 1.0, "vocals"), ("mel_band_roformer / mel_4_stems", 0.5, "vocals")]
|
| 188 |
+
|
| 189 |
+
# out, wav = ensembless(input_audio, input_settings, "max_fft", format)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
|
separator/models/bandit/core/__init__.py
ADDED
|
@@ -0,0 +1,744 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os.path
from collections import defaultdict
from itertools import chain, combinations
from typing import (
    Any,
    Dict,
    Iterator,
    Mapping, Optional,
    Tuple, Type,
    TypedDict
)

import pytorch_lightning as pl
import torch
import torchaudio as ta
import torchmetrics as tm
from asteroid import losses as asteroid_losses
# from deepspeed.ops.adam import DeepSpeedCPUAdam
# from geoopt import optim as gooptim
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch import nn, optim
from torch.optim import lr_scheduler
from torch.optim.lr_scheduler import LRScheduler

from models.bandit.core import loss, metrics as metrics_, model
from models.bandit.core.data._types import BatchedDataDict
from models.bandit.core.data.augmentation import BaseAugmentor, StemAugmentor
from models.bandit.core.utils import audio as audio_
from models.bandit.core.utils.audio import BaseFader

# from pandas.io.json._normalize import nested_to_record

ConfigDict = TypedDict('ConfigDict', {'name': str, 'kwargs': Dict[str, Any]})


class SchedulerConfigDict(ConfigDict):
    monitor: str


OptimizerSchedulerConfigDict = TypedDict(
    'OptimizerSchedulerConfigDict',
    {"optimizer": ConfigDict, "scheduler": SchedulerConfigDict},
    total=False
)


class LRSchedulerReturnDict(TypedDict, total=False):
    scheduler: LRScheduler
    monitor: str


class ConfigureOptimizerReturnDict(TypedDict, total=False):
    optimizer: torch.optim.Optimizer
    lr_scheduler: LRSchedulerReturnDict


OutputType = Dict[str, Any]
MetricsType = Dict[str, torch.Tensor]


def get_optimizer_class(name: str) -> Type[optim.Optimizer]:

    if name == "DeepSpeedCPUAdam":
        # Requires the commented-out deepspeed import above to be re-enabled.
        return DeepSpeedCPUAdam

    # geoopt optimizers (`gooptim`) can also be searched here once the
    # commented-out import above is re-enabled.
    for module in [optim]:
        if name in module.__dict__:
            return module.__dict__[name]

    raise NameError(f"Optimizer {name} not found")


def parse_optimizer_config(
        config: OptimizerSchedulerConfigDict,
        parameters: Iterator[nn.Parameter]
) -> ConfigureOptimizerReturnDict:
    optim_class = get_optimizer_class(config["optimizer"]["name"])
    optimizer = optim_class(parameters, **config["optimizer"]["kwargs"])

    optim_dict: ConfigureOptimizerReturnDict = {
        "optimizer": optimizer,
    }

    if "scheduler" in config:

        lr_scheduler_class_ = config["scheduler"]["name"]
        lr_scheduler_class = lr_scheduler.__dict__[lr_scheduler_class_]
        lr_scheduler_dict: LRSchedulerReturnDict = {
            "scheduler": lr_scheduler_class(
                optimizer,
                **config["scheduler"]["kwargs"]
            )
        }

        if lr_scheduler_class_ == "ReduceLROnPlateau":
            lr_scheduler_dict["monitor"] = config["scheduler"]["monitor"]

        optim_dict["lr_scheduler"] = lr_scheduler_dict

    return optim_dict


def parse_model_config(config: ConfigDict) -> Any:
    name = config["name"]

    for module in [model]:
        if name in module.__dict__:
            return module.__dict__[name](**config["kwargs"])

    raise NameError


_LEGACY_LOSS_NAMES = ["HybridL1Loss"]


def _parse_legacy_loss_config(config: ConfigDict) -> nn.Module:
    name = config["name"]

    if name == "HybridL1Loss":
        return loss.TimeFreqL1Loss(**config["kwargs"])

    raise NameError


def parse_loss_config(config: ConfigDict) -> nn.Module:
    name = config["name"]

    if name in _LEGACY_LOSS_NAMES:
        return _parse_legacy_loss_config(config)

    for module in [loss, nn.modules.loss, asteroid_losses]:
        if name in module.__dict__:
            # print(config["kwargs"])
            return module.__dict__[name](**config["kwargs"])

    raise NameError


def get_metric(config: ConfigDict) -> tm.Metric:
    name = config["name"]

    for module in [tm, metrics_]:
        if name in module.__dict__:
            return module.__dict__[name](**config["kwargs"])
    raise NameError


def parse_metric_config(config: Dict[str, ConfigDict]) -> tm.MetricCollection:
    metrics = {}

    for metric in config:
        metrics[metric] = get_metric(config[metric])

    return tm.MetricCollection(metrics)


def parse_fader_config(config: ConfigDict) -> BaseFader:
    name = config["name"]

    for module in [audio_]:
        if name in module.__dict__:
            return module.__dict__[name](**config["kwargs"])

    raise NameError


class LightningSystem(pl.LightningModule):
    _VOX_STEMS = ["speech", "vocals"]
    _BG_STEMS = ["background", "effects", "mne"]

    def __init__(
            self,
            config: Dict,
            loss_adjustment: float = 1.0,
            attach_fader: bool = False
    ) -> None:
        super().__init__()
        self.optimizer_config = config["optimizer"]
        self.model = parse_model_config(config["model"])
        self.loss = parse_loss_config(config["loss"])
        self.metrics = nn.ModuleDict(
            {
                stem: parse_metric_config(config["metrics"]["dev"])
                for stem in self.model.stems
            }
        )

        self.metrics.disallow_fsdp = True

        self.test_metrics = nn.ModuleDict(
            {
                stem: parse_metric_config(config["metrics"]["test"])
                for stem in self.model.stems
            }
        )

        self.test_metrics.disallow_fsdp = True

        self.fs = config["model"]["kwargs"]["fs"]

        self.fader_config = config["inference"]["fader"]
        if attach_fader:
            self.fader = parse_fader_config(config["inference"]["fader"])
        else:
            self.fader = None

        self.augmentation: Optional[BaseAugmentor]
        if config.get("augmentation", None) is not None:
            self.augmentation = StemAugmentor(**config["augmentation"])
        else:
            self.augmentation = None

        self.predict_output_path: Optional[str] = None
        self.loss_adjustment = loss_adjustment

        self.val_prefix = None
        self.test_prefix = None

    def configure_optimizers(self) -> Any:
        return parse_optimizer_config(
            self.optimizer_config,
            self.trainer.model.parameters()
        )

    def compute_loss(self, batch: BatchedDataDict, output: OutputType) -> Dict[
        str, torch.Tensor]:
        return {"loss": self.loss(output, batch)}

    def update_metrics(
            self,
            batch: BatchedDataDict,
            output: OutputType,
            mode: str
    ) -> None:

        if mode == "test":
            metrics = self.test_metrics
        else:
            metrics = self.metrics

        for stem, metric in metrics.items():

            if stem == "mne:+":
                stem = "mne"

            # print(f"matching for {stem}")
            if mode == "train":
                metric.update(
                    output["audio"][stem],  # .cpu(),
                    batch["audio"][stem],  # .cpu()
                )
            else:
                if stem not in batch["audio"]:
                    matched = False
                    if stem in self._VOX_STEMS:
                        for bstem in self._VOX_STEMS:
                            if bstem in batch["audio"]:
                                batch["audio"][stem] = batch["audio"][bstem]
                                matched = True
                                break
                    elif stem in self._BG_STEMS:
                        for bstem in self._BG_STEMS:
                            if bstem in batch["audio"]:
                                batch["audio"][stem] = batch["audio"][bstem]
                                matched = True
                                break
                else:
                    matched = True

                # print(batch["audio"].keys())

                if matched:
                    # print(f"matched {stem}!")
                    if stem == "mne" and "mne" not in output["audio"]:
                        output["audio"]["mne"] = output["audio"]["music"] + output["audio"]["effects"]

                    metric.update(
                        output["audio"][stem],  # .cpu(),
                        batch["audio"][stem],  # .cpu(),
                    )

                    # print(metric.compute())

    def compute_metrics(self, mode: str = "dev") -> Dict[
        str, torch.Tensor]:

        if mode == "test":
            metrics = self.test_metrics
        else:
            metrics = self.metrics

        metric_dict = {}

        for stem, metric in metrics.items():
            md = metric.compute()
            metric_dict.update(
                {f"{stem}/{k}": v for k, v in md.items()}
            )

        self.log_dict(metric_dict, prog_bar=True, logger=False)

        return metric_dict

    def reset_metrics(self, test_mode: bool = False) -> None:

        if test_mode:
            metrics = self.test_metrics
        else:
            metrics = self.metrics

        for _, metric in metrics.items():
            metric.reset()

    def forward(self, batch: BatchedDataDict) -> Any:
        batch, output = self.model(batch)

        return batch, output

    def common_step(self, batch: BatchedDataDict, mode: str) -> Any:
        batch, output = self.forward(batch)
        # print(batch)
        # print(output)
        loss_dict = self.compute_loss(batch, output)

        with torch.no_grad():
            self.update_metrics(batch, output, mode=mode)

        if mode == "train":
            self.log("loss", loss_dict["loss"], prog_bar=True)

        return output, loss_dict

    def training_step(self, batch: BatchedDataDict) -> Dict[str, Any]:

        if self.augmentation is not None:
            with torch.no_grad():
                batch = self.augmentation(batch)

        _, loss_dict = self.common_step(batch, mode="train")

        with torch.inference_mode():
            self.log_dict_with_prefix(
                loss_dict,
                "train",
                batch_size=batch["audio"]["mixture"].shape[0]
            )

        loss_dict["loss"] *= self.loss_adjustment

        return loss_dict

    def on_train_batch_end(
            self, outputs: STEP_OUTPUT, batch: BatchedDataDict, batch_idx: int
    ) -> None:

        metric_dict = self.compute_metrics()
        self.log_dict_with_prefix(metric_dict, "train")
        self.reset_metrics()

    def validation_step(
            self,
            batch: BatchedDataDict,
            batch_idx: int,
            dataloader_idx: int = 0
    ) -> Dict[str, Any]:

        with torch.inference_mode():
            curr_val_prefix = f"val{dataloader_idx}" if dataloader_idx > 0 else "val"

            if curr_val_prefix != self.val_prefix:
                # print(f"Switching to validation dataloader {dataloader_idx}")
                if self.val_prefix is not None:
                    self._on_validation_epoch_end()
                self.val_prefix = curr_val_prefix
            _, loss_dict = self.common_step(batch, mode="val")

            self.log_dict_with_prefix(
                loss_dict,
                self.val_prefix,
                batch_size=batch["audio"]["mixture"].shape[0],
                prog_bar=True,
                add_dataloader_idx=False
            )

        return loss_dict

    def on_validation_epoch_end(self) -> None:
        self._on_validation_epoch_end()

    def _on_validation_epoch_end(self) -> None:
        metric_dict = self.compute_metrics()
        self.log_dict_with_prefix(metric_dict, self.val_prefix, prog_bar=True,
                                  add_dataloader_idx=False)
        # self.logger.save()
        # print(self.val_prefix, "Validation metrics:", metric_dict)
        self.reset_metrics()

    def old_predtest_step(
            self,
            batch: BatchedDataDict,
            batch_idx: int,
            dataloader_idx: int = 0
    ) -> Tuple[BatchedDataDict, OutputType]:

        audio_batch = batch["audio"]["mixture"]
        track_batch = batch.get("track", ["" for _ in range(len(audio_batch))])

        output_list_of_dicts = [
            self.fader(
                audio[None, ...],
                lambda a: self.test_forward(a, track)
            )
            for audio, track in zip(audio_batch, track_batch)
        ]

        output_dict_of_lists = defaultdict(list)

        for output_dict in output_list_of_dicts:
            for stem, audio in output_dict.items():
                output_dict_of_lists[stem].append(audio)

        output = {
            "audio": {
                stem: torch.concat(output_list, dim=0)
                for stem, output_list in output_dict_of_lists.items()
            }
        }

        return batch, output

    def predtest_step(
            self,
            batch: BatchedDataDict,
            batch_idx: int = -1,
            dataloader_idx: int = 0
    ) -> Tuple[BatchedDataDict, OutputType]:

        if getattr(self.model, "bypass_fader", False):
            batch, output = self.model(batch)
        else:
            audio_batch = batch["audio"]["mixture"]
            output = self.fader(
                audio_batch,
                lambda a: self.test_forward(a, "", batch=batch)
            )

        return batch, output

    def test_forward(
            self,
            audio: torch.Tensor,
            track: str = "",
            batch: Optional[BatchedDataDict] = None
    ) -> torch.Tensor:

        if self.fader is None:
            self.attach_fader()

        # `batch` may be None when called without a full batch dict
        # (e.g. from old_predtest_step).
        cond = batch.get("condition", None) if batch is not None else None

        if cond is not None and cond.shape[0] == 1:
            cond = cond.repeat(audio.shape[0], 1)

        _, output = self.forward(
            {"audio": {"mixture": audio},
             "track": track,
             "condition": cond,
             }
        )  # TODO: support track properly

        return output["audio"]

    def on_test_epoch_start(self) -> None:
        self.attach_fader(force_reattach=True)

    def test_step(
            self,
            batch: BatchedDataDict,
            batch_idx: int,
            dataloader_idx: int = 0
    ) -> Any:
        curr_test_prefix = f"test{dataloader_idx}"

        # print(batch["audio"].keys())

        if curr_test_prefix != self.test_prefix:
            # print(f"Switching to test dataloader {dataloader_idx}")
            if self.test_prefix is not None:
                self._on_test_epoch_end()
            self.test_prefix = curr_test_prefix

        with torch.inference_mode():
            _, output = self.predtest_step(batch, batch_idx, dataloader_idx)
            # print(output)
            self.update_metrics(batch, output, mode="test")

        return output

    def on_test_epoch_end(self) -> None:
        self._on_test_epoch_end()

    def _on_test_epoch_end(self) -> None:
        metric_dict = self.compute_metrics(mode="test")
        self.log_dict_with_prefix(metric_dict, self.test_prefix, prog_bar=True,
                                  add_dataloader_idx=False)
        # self.logger.save()
        # print(self.test_prefix, "Test metrics:", metric_dict)
        self.reset_metrics()

    def predict_step(
            self,
            batch: BatchedDataDict,
            batch_idx: int = 0,
            dataloader_idx: int = 0,
            include_track_name: Optional[bool] = None,
            get_no_vox_combinations: bool = True,
            get_residual: bool = False,
            treat_batch_as_channels: bool = False,
            fs: Optional[int] = None,
    ) -> Any:
        assert self.predict_output_path is not None

        batch_size = batch["audio"]["mixture"].shape[0]

        if include_track_name is None:
            include_track_name = batch_size > 1

        with torch.inference_mode():
            batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
            print('Pred test finished...')
            torch.cuda.empty_cache()
            metric_dict = {}

            if get_residual:
                mixture = batch["audio"]["mixture"]
                extracted = sum([output["audio"][stem] for stem in output["audio"]])
                residual = mixture - extracted
                print(extracted.shape, mixture.shape, residual.shape)

                output["audio"]["residual"] = residual

            if get_no_vox_combinations:
                no_vox_stems = [
                    stem for stem in output["audio"] if
                    stem not in self._VOX_STEMS
                ]
                no_vox_combinations = chain.from_iterable(
                    combinations(no_vox_stems, r) for r in
                    range(2, len(no_vox_stems) + 1)
                )

                for combination in no_vox_combinations:
                    combination_ = list(combination)
                    output["audio"]["+".join(combination_)] = sum(
                        [output["audio"][stem] for stem in combination_]
                    )

            if treat_batch_as_channels:
                for stem in output["audio"]:
                    output["audio"][stem] = output["audio"][stem].reshape(
                        1, -1, output["audio"][stem].shape[-1]
                    )
                batch_size = 1

            for b in range(batch_size):
                print("!!", b)
                for stem in output["audio"]:
                    print(f"Saving audio for {stem} to {self.predict_output_path}")
                    track_name = batch["track"][b].split("/")[-1]

                    if batch.get("audio", {}).get(stem, None) is not None:
                        self.test_metrics[stem].reset()
                        metrics = self.test_metrics[stem](
                            batch["audio"][stem][[b], ...],
                            output["audio"][stem][[b], ...]
                        )
                        snr = metrics["snr"]
                        sisnr = metrics["sisnr"]
                        sdr = metrics["sdr"]
                        metric_dict[stem] = metrics
                        print(
                            track_name,
                            f"snr={snr:2.2f} dB",
                            f"sisnr={sisnr:2.2f}",
                            f"sdr={sdr:2.2f} dB",
                        )
                        filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
                    else:
                        filename = f"{stem}.wav"

                    if include_track_name:
                        output_dir = os.path.join(
                            self.predict_output_path,
                            track_name
                        )
                    else:
                        output_dir = self.predict_output_path

                    os.makedirs(output_dir, exist_ok=True)

                    if fs is None:
                        fs = self.fs

                    ta.save(
                        os.path.join(output_dir, filename),
                        output["audio"][stem][b, ...].cpu(),
                        fs,
                    )

        return metric_dict

    def get_stems(
            self,
            batch: BatchedDataDict,
            batch_idx: int = 0,
            dataloader_idx: int = 0,
            include_track_name: Optional[bool] = None,
            get_no_vox_combinations: bool = True,
            get_residual: bool = False,
            treat_batch_as_channels: bool = False,
            fs: Optional[int] = None,
    ) -> Any:
        assert self.predict_output_path is not None

        batch_size = batch["audio"]["mixture"].shape[0]

        if include_track_name is None:
            include_track_name = batch_size > 1

        with torch.inference_mode():
            batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
            torch.cuda.empty_cache()
            metric_dict = {}

            if get_residual:
                mixture = batch["audio"]["mixture"]
                extracted = sum([output["audio"][stem] for stem in output["audio"]])
                residual = mixture - extracted
                # print(extracted.shape, mixture.shape, residual.shape)

                output["audio"]["residual"] = residual

            if get_no_vox_combinations:
                no_vox_stems = [
                    stem for stem in output["audio"] if
                    stem not in self._VOX_STEMS
                ]
                no_vox_combinations = chain.from_iterable(
                    combinations(no_vox_stems, r) for r in
                    range(2, len(no_vox_stems) + 1)
                )

                for combination in no_vox_combinations:
                    combination_ = list(combination)
                    output["audio"]["+".join(combination_)] = sum(
                        [output["audio"][stem] for stem in combination_]
                    )

            if treat_batch_as_channels:
                for stem in output["audio"]:
                    output["audio"][stem] = output["audio"][stem].reshape(
                        1, -1, output["audio"][stem].shape[-1]
                    )
                batch_size = 1

            result = {}
            for b in range(batch_size):
                for stem in output["audio"]:
                    track_name = batch["track"][b].split("/")[-1]

                    if batch.get("audio", {}).get(stem, None) is not None:
                        self.test_metrics[stem].reset()
                        metrics = self.test_metrics[stem](
                            batch["audio"][stem][[b], ...],
                            output["audio"][stem][[b], ...]
                        )
                        snr = metrics["snr"]
                        sisnr = metrics["sisnr"]
                        sdr = metrics["sdr"]
                        metric_dict[stem] = metrics
                        print(
                            track_name,
                            f"snr={snr:2.2f} dB",
                            f"sisnr={sisnr:2.2f}",
                            f"sdr={sdr:2.2f} dB",
                        )
                        filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
                    else:
                        filename = f"{stem}.wav"

                    if include_track_name:
                        output_dir = os.path.join(
                            self.predict_output_path,
                            track_name
                        )
                    else:
                        output_dir = self.predict_output_path

                    os.makedirs(output_dir, exist_ok=True)

                    if fs is None:
                        fs = self.fs

                    result[stem] = output["audio"][stem][b, ...].cpu().numpy()

        return result

    def load_state_dict(
            self, state_dict: Mapping[str, Any], strict: bool = False
    ) -> Any:

        return super().load_state_dict(state_dict, strict=False)

    def set_predict_output_path(self, path: str) -> None:
        self.predict_output_path = path
        os.makedirs(self.predict_output_path, exist_ok=True)

        self.attach_fader()

    def attach_fader(self, force_reattach=False) -> None:
        if self.fader is None or force_reattach:
            self.fader = parse_fader_config(self.fader_config)
            self.fader.to(self.device)

    def log_dict_with_prefix(
            self,
            dict_: Dict[str, torch.Tensor],
            prefix: str,
            batch_size: Optional[int] = None,
            **kwargs: Any
    ) -> None:
        self.log_dict(
            {f"{prefix}/{k}": v for k, v in dict_.items()},
            batch_size=batch_size,
            logger=True,
            sync_dist=True,
            **kwargs,
        )
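As an aside (not part of the commit), the config-driven helpers above can be exercised on their own. A minimal sketch, assuming the `models.bandit.core` package and its dependencies are importable; the loss and optimizer names are stock `torch` names that the parsers look up by string:

import torch
from torch import nn
from models.bandit.core import parse_loss_config, parse_optimizer_config

net = nn.Linear(4, 4)  # hypothetical stand-in for a real separation model

# Falls through to nn.modules.loss.L1Loss after the bandit loss registry.
loss_fn = parse_loss_config({"name": "L1Loss", "kwargs": {}})

optim_dict = parse_optimizer_config(
    {
        "optimizer": {"name": "Adam", "kwargs": {"lr": 1e-3}},
        "scheduler": {
            "name": "ReduceLROnPlateau",
            "kwargs": {"factor": 0.5},
            "monitor": "val/loss",  # required only for ReduceLROnPlateau
        },
    },
    net.parameters(),
)
print(type(optim_dict["optimizer"]), optim_dict["lr_scheduler"]["monitor"])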
separator/models/bandit/core/data/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .dnr.datamodule import DivideAndRemasterDataModule
from .musdb.datamodule import MUSDB18DataModule
separator/models/bandit/core/data/_types.py
ADDED
@@ -0,0 +1,18 @@
from typing import Dict, Sequence, TypedDict

import torch

AudioDict = Dict[str, torch.Tensor]

DataDict = TypedDict('DataDict', {'audio': AudioDict, 'track': str})

BatchedDataDict = TypedDict(
    'BatchedDataDict',
    {'audio': AudioDict, 'track': Sequence[str]}
)


class DataDictWithLanguage(TypedDict):
    audio: AudioDict
    track: str
    language: str
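For orientation (illustrative, not part of the commit): a single `DataDict` item for a stereo 3-second clip at 44.1 kHz would carry one tensor per stem plus a track identifier, e.g.:

import torch
from models.bandit.core.data._types import DataDict

item: DataDict = {
    "audio": {
        "mixture": torch.zeros(2, 3 * 44100),  # (channels, samples)
        "speech": torch.zeros(2, 3 * 44100),
    },
    "track": "train/track0001",  # hypothetical identifier
}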
separator/models/bandit/core/data/augmentation.py
ADDED
@@ -0,0 +1,107 @@
from abc import ABC
from typing import Any, Dict, Union

import torch
import torch_audiomentations as tam
from torch import nn

from models.bandit.core.data._types import BatchedDataDict, DataDict


class BaseAugmentor(nn.Module, ABC):
    def forward(self, item: Union[DataDict, BatchedDataDict]) -> Union[
        DataDict, BatchedDataDict]:
        raise NotImplementedError


class StemAugmentor(BaseAugmentor):
    def __init__(
            self,
            audiomentations: Dict[str, Dict[str, Any]],
            fix_clipping: bool = True,
            scaler_margin: float = 0.5,
            apply_both_default_and_common: bool = False,
    ) -> None:
        super().__init__()

        augmentations = {}

        self.has_default = "[default]" in audiomentations
        self.has_common = "[common]" in audiomentations
        self.apply_both_default_and_common = apply_both_default_and_common

        for stem in audiomentations:
            if audiomentations[stem]["name"] == "Compose":
                augmentations[stem] = getattr(
                    tam,
                    audiomentations[stem]["name"]
                )(
                    [
                        getattr(tam, aug["name"])(**aug["kwargs"])
                        for aug in
                        audiomentations[stem]["kwargs"]["transforms"]
                    ],
                    **audiomentations[stem]["kwargs"]["kwargs"],
                )
            else:
                augmentations[stem] = getattr(
                    tam,
                    audiomentations[stem]["name"]
                )(
                    **audiomentations[stem]["kwargs"]
                )

        self.augmentations = nn.ModuleDict(augmentations)
        self.fix_clipping = fix_clipping
        self.scaler_margin = scaler_margin

    def check_and_fix_clipping(
            self, item: Union[DataDict, BatchedDataDict]
    ) -> Union[DataDict, BatchedDataDict]:
        max_abs = []

        for stem in item["audio"]:
            max_abs.append(item["audio"][stem].abs().max().item())

        if max(max_abs) > 1.0:
            scaler = 1.0 / (max(max_abs) + torch.rand(
                (1,),
                device=item["audio"]["mixture"].device
            ) * self.scaler_margin)

            for stem in item["audio"]:
                item["audio"][stem] *= scaler

        return item

    def forward(self, item: Union[DataDict, BatchedDataDict]) -> Union[
        DataDict, BatchedDataDict]:

        for stem in item["audio"]:
            if stem == "mixture":
                continue

            if self.has_common:
                item["audio"][stem] = self.augmentations["[common]"](
                    item["audio"][stem]
                ).samples

            if stem in self.augmentations:
                item["audio"][stem] = self.augmentations[stem](
                    item["audio"][stem]
                ).samples
            elif self.has_default:
                if not self.has_common or self.apply_both_default_and_common:
                    item["audio"][stem] = self.augmentations["[default]"](
                        item["audio"][stem]
                    ).samples

        item["audio"]["mixture"] = sum(
            [item["audio"][stem] for stem in item["audio"]
             if stem != "mixture"]
        )  # type: ignore[call-overload, assignment]

        if self.fix_clipping:
            item = self.check_and_fix_clipping(item)

        return item
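A hedged sketch (not in the commit) of the `audiomentations` config this class consumes: keys are stem names, plus the special `[default]`/`[common]` entries; transform names and kwargs follow torch_audiomentations. The exact values here are illustrative only:

from models.bandit.core.data.augmentation import StemAugmentor

augmentation_config = {
    # applied to any stem without its own entry
    "[default]": {
        "name": "Gain",
        "kwargs": {"min_gain_in_db": -6.0, "max_gain_in_db": 6.0, "p": 0.5},
    },
    # stem-specific chain, built via tam.Compose
    "speech": {
        "name": "Compose",
        "kwargs": {
            "transforms": [
                {"name": "Gain",
                 "kwargs": {"min_gain_in_db": -3.0, "max_gain_in_db": 3.0,
                            "p": 0.5}},
                {"name": "PolarityInversion", "kwargs": {"p": 0.5}},
            ],
            "kwargs": {},  # forwarded to tam.Compose itself
        },
    },
}
augmentor = StemAugmentor(audiomentations=augmentation_config)

Note that the mixture is always recomputed from the augmented stems, and `check_and_fix_clipping` rescales all stems together when any of them exceeds full scale.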
separator/models/bandit/core/data/augmented.py
ADDED
@@ -0,0 +1,35 @@
import warnings
from typing import Dict, Optional, Union

import torch
from torch import nn
from torch.utils import data


class AugmentedDataset(data.Dataset):
    def __init__(
            self,
            dataset: data.Dataset,
            augmentation: nn.Module = nn.Identity(),
            target_length: Optional[int] = None,
    ) -> None:
        warnings.warn(
            "This class is no longer used. Attach augmentation to "
            "the LightningSystem instead.",
            DeprecationWarning,
        )

        self.dataset = dataset
        self.augmentation = augmentation

        self.ds_length: int = len(dataset)  # type: ignore[arg-type]
        self.length = target_length if target_length is not None else self.ds_length

    def __getitem__(self, index: int) -> Dict[str, Union[str, Dict[str,
    torch.Tensor]]]:
        item = self.dataset[index % self.ds_length]
        item = self.augmentation(item)
        return item

    def __len__(self) -> int:
        return self.length
separator/models/bandit/core/data/base.py
ADDED
@@ -0,0 +1,69 @@
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

import numpy as np
import pedalboard as pb
import torch
import torchaudio as ta
from torch.utils import data

from models.bandit.core.data._types import AudioDict, DataDict


class BaseSourceSeparationDataset(data.Dataset, ABC):
    def __init__(
            self, split: str,
            stems: List[str],
            files: List[str],
            data_path: str,
            fs: int,
            npy_memmap: bool,
            recompute_mixture: bool
    ):
        self.split = split
        self.stems = stems
        self.stems_no_mixture = [s for s in stems if s != "mixture"]
        self.files = files
        self.data_path = data_path
        self.fs = fs
        self.npy_memmap = npy_memmap
        self.recompute_mixture = recompute_mixture

    @abstractmethod
    def get_stem(
            self,
            *,
            stem: str,
            identifier: Dict[str, Any]
    ) -> torch.Tensor:
        raise NotImplementedError

    def _get_audio(self, stems, identifier: Dict[str, Any]):
        audio = {}
        for stem in stems:
            audio[stem] = self.get_stem(stem=stem, identifier=identifier)

        return audio

    def get_audio(self, identifier: Dict[str, Any]) -> AudioDict:

        if self.recompute_mixture:
            audio = self._get_audio(
                self.stems_no_mixture,
                identifier=identifier
            )
            audio["mixture"] = self.compute_mixture(audio)
            return audio
        else:
            return self._get_audio(self.stems, identifier=identifier)

    @abstractmethod
    def get_identifier(self, index: int) -> Dict[str, Any]:
        pass

    def compute_mixture(self, audio: AudioDict) -> torch.Tensor:

        return sum(
            audio[stem] for stem in audio if stem != "mixture"
        )
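Not part of the commit: a minimal concrete subclass sketch showing the two hooks the ABC expects, and how `recompute_mixture` rebuilds the mix from the stems. All names here are hypothetical:

import torch
from models.bandit.core.data.base import BaseSourceSeparationDataset

class ToyDataset(BaseSourceSeparationDataset):
    def get_stem(self, *, stem, identifier):
        # one second of stereo silence per stem, for illustration
        return torch.zeros(2, self.fs)

    def get_identifier(self, index):
        return {"track": self.files[index]}

ds = ToyDataset(
    split="train", stems=["mixture", "speech"], files=["t0"],
    data_path="/tmp", fs=44100, npy_memmap=False, recompute_mixture=True,
)
audio = ds.get_audio(ds.get_identifier(0))  # "mixture" summed from stems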
separator/models/bandit/core/data/dnr/__init__.py
ADDED
File without changes
separator/models/bandit/core/data/dnr/datamodule.py
ADDED
@@ -0,0 +1,74 @@
import os
from typing import Mapping, Optional

import pytorch_lightning as pl

from .dataset import (
    DivideAndRemasterDataset,
    DivideAndRemasterDeterministicChunkDataset,
    DivideAndRemasterRandomChunkDataset,
    DivideAndRemasterRandomChunkDatasetWithSpeechReverb
)


def DivideAndRemasterDataModule(
        data_root: str = "$DATA_ROOT/DnR/v2",
        batch_size: int = 2,
        num_workers: int = 8,
        train_kwargs: Optional[Mapping] = None,
        val_kwargs: Optional[Mapping] = None,
        test_kwargs: Optional[Mapping] = None,
        datamodule_kwargs: Optional[Mapping] = None,
        use_speech_reverb: bool = False
        # augmentor=None
) -> pl.LightningDataModule:
    if train_kwargs is None:
        train_kwargs = {}

    if val_kwargs is None:
        val_kwargs = {}

    if test_kwargs is None:
        test_kwargs = {}

    if datamodule_kwargs is None:
        datamodule_kwargs = {}

    if num_workers is None:
        num_workers = os.cpu_count()

    if num_workers is None:
        num_workers = 32

    num_workers = min(num_workers, 64)

    if use_speech_reverb:
        train_cls = DivideAndRemasterRandomChunkDatasetWithSpeechReverb
    else:
        train_cls = DivideAndRemasterRandomChunkDataset

    train_dataset = train_cls(
        data_root, "train", **train_kwargs
    )

    # if augmentor is not None:
    #     train_dataset = AugmentedDataset(train_dataset, augmentor)

    datamodule = pl.LightningDataModule.from_datasets(
        train_dataset=train_dataset,
        val_dataset=DivideAndRemasterDeterministicChunkDataset(
            data_root, "val", **val_kwargs
        ),
        test_dataset=DivideAndRemasterDataset(
            data_root,
            "test",
            **test_kwargs
        ),
        batch_size=batch_size,
        num_workers=num_workers,
        **datamodule_kwargs
    )

    datamodule.predict_dataloader = datamodule.test_dataloader  # type: ignore[method-assign]

    return datamodule
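Not part of the commit: a hedged usage sketch, assuming the DnR v2 directory layout (with `$DATA_ROOT` expanded to a real path) and the chunking kwargs that the train/val dataset classes below require:

from models.bandit.core.data import DivideAndRemasterDataModule

dm = DivideAndRemasterDataModule(
    data_root="/data/DnR/v2",  # example path
    batch_size=2,
    num_workers=4,
    # forwarded to DivideAndRemasterRandomChunkDataset
    train_kwargs={"target_length": 20000, "chunk_size_second": 6.0},
    # forwarded to DivideAndRemasterDeterministicChunkDataset
    val_kwargs={"chunk_size_second": 6.0, "hop_size_second": 3.0},
)
train_loader = dm.train_dataloader()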
separator/models/bandit/core/data/dnr/dataset.py
ADDED
@@ -0,0 +1,392 @@
import os
from abc import ABC
from typing import Any, Dict, List, Optional

import numpy as np
import pedalboard as pb
import torch
import torchaudio as ta
from torch.utils import data

from models.bandit.core.data._types import AudioDict, DataDict
from models.bandit.core.data.base import BaseSourceSeparationDataset


class DivideAndRemasterBaseDataset(BaseSourceSeparationDataset, ABC):
    ALLOWED_STEMS = ["mixture", "speech", "music", "effects", "mne"]
    STEM_NAME_MAP = {
        "mixture": "mix",
        "speech": "speech",
        "music": "music",
        "effects": "sfx",
    }
    SPLIT_NAME_MAP = {"train": "tr", "val": "cv", "test": "tt"}

    FULL_TRACK_LENGTH_SECOND = 60
    FULL_TRACK_LENGTH_SAMPLES = FULL_TRACK_LENGTH_SECOND * 44100

    def __init__(
            self,
            split: str,
            stems: List[str],
            files: List[str],
            data_path: str,
            fs: int = 44100,
            npy_memmap: bool = True,
            recompute_mixture: bool = False,
    ) -> None:
        super().__init__(
            split=split,
            stems=stems,
            files=files,
            data_path=data_path,
            fs=fs,
            npy_memmap=npy_memmap,
            recompute_mixture=recompute_mixture
        )

    def get_stem(
            self,
            *,
            stem: str,
            identifier: Dict[str, Any]
    ) -> torch.Tensor:

        if stem == "mne":
            return self.get_stem(
                stem="music",
                identifier=identifier) + self.get_stem(
                stem="effects",
                identifier=identifier)

        track = identifier["track"]
        path = os.path.join(self.data_path, track)

        if self.npy_memmap:
            audio = np.load(
                os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.npy"),
                mmap_mode="r"
            )
        else:
            # noinspection PyUnresolvedReferences
            audio, _ = ta.load(
                os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.wav")
            )

        return audio

    def get_identifier(self, index):
        return dict(track=self.files[index])

    def __getitem__(self, index: int) -> DataDict:
        identifier = self.get_identifier(index)
        audio = self.get_audio(identifier)

        return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}


class DivideAndRemasterDataset(DivideAndRemasterBaseDataset):
    def __init__(
            self,
            data_root: str,
            split: str,
            stems: Optional[List[str]] = None,
            fs: int = 44100,
            npy_memmap: bool = True,
    ) -> None:

        if stems is None:
            stems = self.ALLOWED_STEMS
        self.stems = stems

        data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])

        files = sorted(os.listdir(data_path))
        files = [
            f
            for f in files
            if (not f.startswith(".")) and os.path.isdir(
                os.path.join(data_path, f)
            )
        ]
        # pprint(list(enumerate(files)))
        if split == "train":
            assert len(files) == 3406, len(files)
        elif split == "val":
            assert len(files) == 487, len(files)
        elif split == "test":
            assert len(files) == 973, len(files)

        self.n_tracks = len(files)

        super().__init__(
            data_path=data_path,
            split=split,
            stems=stems,
            files=files,
            fs=fs,
            npy_memmap=npy_memmap,
        )

    def __len__(self) -> int:
        return self.n_tracks


class DivideAndRemasterRandomChunkDataset(DivideAndRemasterBaseDataset):
    def __init__(
            self,
            data_root: str,
            split: str,
            target_length: int,
            chunk_size_second: float,
            stems: Optional[List[str]] = None,
            fs: int = 44100,
            npy_memmap: bool = True,
    ) -> None:

        if stems is None:
            stems = self.ALLOWED_STEMS
        self.stems = stems

        data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])

        files = sorted(os.listdir(data_path))
        files = [
            f
            for f in files
            if (not f.startswith(".")) and os.path.isdir(
                os.path.join(data_path, f)
            )
        ]

        if split == "train":
            assert len(files) == 3406, len(files)
        elif split == "val":
            assert len(files) == 487, len(files)
        elif split == "test":
            assert len(files) == 973, len(files)

        self.n_tracks = len(files)

        self.target_length = target_length
        self.chunk_size = int(chunk_size_second * fs)

        super().__init__(
            data_path=data_path,
            split=split,
            stems=stems,
            files=files,
            fs=fs,
            npy_memmap=npy_memmap,
        )

    def __len__(self) -> int:
        return self.target_length

    def get_identifier(self, index):
        return super().get_identifier(index % self.n_tracks)

    def get_stem(
            self,
            *,
            stem: str,
            identifier: Dict[str, Any],
            chunk_here: bool = False,
    ) -> torch.Tensor:

        stem = super().get_stem(
            stem=stem,
            identifier=identifier
        )

        if chunk_here:
            start = np.random.randint(
                0,
                self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size
            )
            end = start + self.chunk_size

            stem = stem[:, start:end]

        return stem

    def __getitem__(self, index: int) -> DataDict:
        identifier = self.get_identifier(index)
        # self.index_lock = index
        audio = self.get_audio(identifier)
        # self.index_lock = None

        start = np.random.randint(
            0,
            self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size
        )
        end = start + self.chunk_size

        audio = {
            k: v[:, start:end] for k, v in audio.items()
        }

        return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}


class DivideAndRemasterDeterministicChunkDataset(DivideAndRemasterBaseDataset):
    def __init__(
            self,
            data_root: str,
            split: str,
            chunk_size_second: float,
            hop_size_second: float,
            stems: Optional[List[str]] = None,
            fs: int = 44100,
            npy_memmap: bool = True,
    ) -> None:

        if stems is None:
            stems = self.ALLOWED_STEMS
        self.stems = stems

        data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])

        files = sorted(os.listdir(data_path))
        files = [
            f
            for f in files
            if (not f.startswith(".")) and os.path.isdir(
                os.path.join(data_path, f)
            )
        ]
        # pprint(list(enumerate(files)))
        if split == "train":
            assert len(files) == 3406, len(files)
        elif split == "val":
            assert len(files) == 487, len(files)
        elif split == "test":
            assert len(files) == 973, len(files)

        self.n_tracks = len(files)

        self.chunk_size = int(chunk_size_second * fs)
        self.hop_size = int(hop_size_second * fs)
        self.n_chunks_per_track = int(
            (self.FULL_TRACK_LENGTH_SECOND - chunk_size_second) / hop_size_second
        )

        self.length = self.n_tracks * self.n_chunks_per_track

        super().__init__(
            data_path=data_path,
            split=split,
            stems=stems,
            files=files,
            fs=fs,
            npy_memmap=npy_memmap,
        )

    def get_identifier(self, index):
        return super().get_identifier(index % self.n_tracks)

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, item: int) -> DataDict:

        index = item % self.n_tracks
        chunk = item // self.n_tracks

        data_ = super().__getitem__(index)

        audio = data_["audio"]

        start = chunk * self.hop_size
        end = start + self.chunk_size

        for stem in self.stems:
            data_["audio"][stem] = audio[stem][:, start:end]

        return data_


class DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
    DivideAndRemasterRandomChunkDataset
):
    def __init__(
            self,
            data_root: str,
            split: str,
            target_length: int,
            chunk_size_second: float,
            stems: Optional[List[str]] = None,
            fs: int = 44100,
            npy_memmap: bool = True,
    ) -> None:

        if stems is None:
            stems = self.ALLOWED_STEMS

        stems_no_mixture = [s for s in stems if s != "mixture"]

        super().__init__(
            data_root=data_root,
            split=split,
            target_length=target_length,
            chunk_size_second=chunk_size_second,
            stems=stems_no_mixture,
            fs=fs,
            npy_memmap=npy_memmap,
        )

        self.stems = stems
        self.stems_no_mixture = stems_no_mixture

    def __getitem__(self, index: int) -> DataDict:

        data_ = super().__getitem__(index)

        dry = data_["audio"]["speech"][:]
        n_samples = dry.shape[-1]

        wet_level = np.random.rand()

        speech = pb.Reverb(
            room_size=np.random.rand(),
            damping=np.random.rand(),
            wet_level=wet_level,
            dry_level=(1 - wet_level),
            width=np.random.rand()
        ).process(dry, self.fs, buffer_size=8192 * 4)[..., :n_samples]

        data_["audio"]["speech"] = speech

        data_["audio"]["mixture"] = sum(
            [data_["audio"][s] for s in self.stems_no_mixture]
        )

        return data_

    def __len__(self) -> int:
        return super().__len__()


if __name__ == "__main__":

    from pprint import pprint
    from tqdm.auto import tqdm

    for split_ in ["train", "val", "test"]:
        ds = DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
            data_root="$DATA_ROOT/DnR/v2np",
            split=split_,
            target_length=100,
            chunk_size_second=6.0
        )

        print(split_, len(ds))

        for track_ in tqdm(ds):  # type: ignore
            pprint(track_)
            track_["audio"] = {k: v.shape for k, v in track_["audio"].items()}
            pprint(track_)
            # break

        break
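A note on the deterministic dataset's index arithmetic (illustrative numbers, not from the commit): indices walk across tracks first, then advance the chunk, so consecutive items come from different tracks. With the 487 val tracks, chunk_size_second=6 and hop_size_second=3, each 60-second track yields int((60 - 6) / 3) = 18 chunks, so len(ds) = 487 * 18 = 8766, and an item decomposes as:

item = 1000
n_tracks = 487
index = item % n_tracks   # 26 -> which track
chunk = item // n_tracks  # 2  -> which chunk within that track
start = chunk * 3 * 44100  # sample offset of the chunk at fs = 44100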
separator/models/bandit/core/data/dnr/preprocess.py
ADDED
@@ -0,0 +1,54 @@
import glob
import os
from typing import Tuple

import numpy as np
import torchaudio as ta
from tqdm.contrib.concurrent import process_map


def process_one(inputs: Tuple[str, str, int]) -> None:
    infile, outfile, target_fs = inputs

    dir = os.path.dirname(outfile)
    os.makedirs(dir, exist_ok=True)

    data, fs = ta.load(infile)

    if fs != target_fs:
        data = ta.functional.resample(data, fs, target_fs, resampling_method="sinc_interp_kaiser")
        fs = target_fs

    data = data.numpy()
    data = data.astype(np.float32)

    if os.path.exists(outfile):
        data_ = np.load(outfile)
        if np.allclose(data, data_):
            return

    np.save(outfile, data)


def preprocess(
        data_path: str,
        output_path: str,
        fs: int
) -> None:
    files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
    print(files)
    outfiles = [
        f.replace(data_path, output_path).replace(".wav", ".npy") for f in
        files
    ]

    os.makedirs(output_path, exist_ok=True)
    inputs = list(zip(files, outfiles, [fs] * len(files)))

    process_map(process_one, inputs, chunksize=32)


if __name__ == "__main__":
    import fire

    fire.Fire()
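Not part of the commit: since `fire.Fire()` exposes the module's functions, the preprocessing can also be called directly from Python; a hedged sketch with example paths, mirroring the `.npy` layout the memmap datasets above expect:

from models.bandit.core.data.dnr.preprocess import preprocess

# Resamples every wav under data_path to 44.1 kHz and mirrors it as
# float32 .npy files under output_path (paths here are hypothetical).
preprocess(data_path="/data/DnR/v2", output_path="/data/DnR/v2np", fs=44100)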
separator/models/bandit/core/data/musdb/__init__.py
ADDED
File without changes
separator/models/bandit/core/data/musdb/datamodule.py
ADDED
@@ -0,0 +1,77 @@
import os.path
from typing import Mapping, Optional

import pytorch_lightning as pl

from models.bandit.core.data.musdb.dataset import (
    MUSDB18BaseDataset,
    MUSDB18FullTrackDataset,
    MUSDB18SadDataset,
    MUSDB18SadOnTheFlyAugmentedDataset
)


def MUSDB18DataModule(
        data_root: str = "$DATA_ROOT/MUSDB18/HQ",
        target_stem: str = "vocals",
        batch_size: int = 2,
        num_workers: int = 8,
        train_kwargs: Optional[Mapping] = None,
        val_kwargs: Optional[Mapping] = None,
        test_kwargs: Optional[Mapping] = None,
        datamodule_kwargs: Optional[Mapping] = None,
        use_on_the_fly: bool = True,
        npy_memmap: bool = True
) -> pl.LightningDataModule:
    if train_kwargs is None:
        train_kwargs = {}

    if val_kwargs is None:
        val_kwargs = {}

    if test_kwargs is None:
        test_kwargs = {}

    if datamodule_kwargs is None:
        datamodule_kwargs = {}

    train_dataset: MUSDB18BaseDataset

    if use_on_the_fly:
        train_dataset = MUSDB18SadOnTheFlyAugmentedDataset(
            data_root=os.path.join(data_root, "saded-np"),
            split="train",
            target_stem=target_stem,
            **train_kwargs
        )
    else:
        train_dataset = MUSDB18SadDataset(
            data_root=os.path.join(data_root, "saded-np"),
            split="train",
            target_stem=target_stem,
            **train_kwargs
        )

    datamodule = pl.LightningDataModule.from_datasets(
        train_dataset=train_dataset,
        val_dataset=MUSDB18SadDataset(
            data_root=os.path.join(data_root, "saded-np"),
            split="val",
            target_stem=target_stem,
            **val_kwargs
        ),
        test_dataset=MUSDB18FullTrackDataset(
            data_root=os.path.join(data_root, "canonical"),
            split="test",
            **test_kwargs
        ),
        batch_size=batch_size,
        num_workers=num_workers,
        **datamodule_kwargs
    )

    datamodule.predict_dataloader = (  # type: ignore[method-assign]
        datamodule.test_dataloader
    )

    return datamodule
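Not part of the commit: a hedged usage sketch. It assumes a MUSDB18-HQ tree with `saded-np` (SAD-chunked, npy) and `canonical` subdirectories as the function expects; any extra kwargs for the SAD datasets are defined in the dataset module below:

from models.bandit.core.data import MUSDB18DataModule

dm = MUSDB18DataModule(
    data_root="/data/MUSDB18/HQ",  # example path
    target_stem="vocals",
    batch_size=2,
)
test_loader = dm.test_dataloader()  # full-track canonical test split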
separator/models/bandit/core/data/musdb/dataset.py
ADDED
@@ -0,0 +1,280 @@

import os
from abc import ABC
from typing import List, Optional, Tuple

import numpy as np
import torch
import torchaudio as ta
from torch.utils import data

from models.bandit.core.data._types import AudioDict, DataDict
from models.bandit.core.data.base import BaseSourceSeparationDataset


class MUSDB18BaseDataset(BaseSourceSeparationDataset, ABC):

    ALLOWED_STEMS = ["mixture", "vocals", "bass", "drums", "other"]

    def __init__(
            self,
            split: str,
            stems: List[str],
            files: List[str],
            data_path: str,
            fs: int = 44100,
            npy_memmap=False,
    ) -> None:
        super().__init__(
            split=split,
            stems=stems,
            files=files,
            data_path=data_path,
            fs=fs,
            npy_memmap=npy_memmap,
            recompute_mixture=False
        )

    def get_stem(self, *, stem: str, identifier) -> torch.Tensor:
        track = identifier["track"]
        path = os.path.join(self.data_path, track)

        # noinspection PyUnresolvedReferences
        if self.npy_memmap:
            audio = np.load(os.path.join(path, f"{stem}.wav.npy"), mmap_mode="r")
        else:
            audio, _ = ta.load(os.path.join(path, f"{stem}.wav"))

        return audio

    def get_identifier(self, index):
        return dict(track=self.files[index])

    def __getitem__(self, index: int) -> DataDict:
        identifier = self.get_identifier(index)
        audio = self.get_audio(identifier)

        return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}


class MUSDB18FullTrackDataset(MUSDB18BaseDataset):

    N_TRAIN_TRACKS = 100
    N_TEST_TRACKS = 50
    VALIDATION_FILES = [
        "Actions - One Minute Smile",
        "Clara Berry And Wooldog - Waltz For My Victims",
        "Johnny Lokke - Promises & Lies",
        "Patrick Talbot - A Reason To Leave",
        "Triviul - Angelsaint",
        "Alexander Ross - Goodbye Bolero",
        "Fergessen - Nos Palpitants",
        "Leaf - Summerghost",
        "Skelpolu - Human Mistakes",
        "Young Griffo - Pennies",
        "ANiMAL - Rockshow",
        "James May - On The Line",
        "Meaxic - Take A Step",
        "Traffic Experiment - Sirens",
    ]

    def __init__(
            self,
            data_root: str,
            split: str,
            stems: Optional[List[str]] = None
    ) -> None:

        if stems is None:
            stems = self.ALLOWED_STEMS
        self.stems = stems

        if split == "test":
            subset = "test"
        elif split in ["train", "val"]:
            subset = "train"
        else:
            raise NameError(f"Unknown split: {split}")

        data_path = os.path.join(data_root, subset)

        files = sorted(os.listdir(data_path))
        files = [f for f in files if not f.startswith(".")]

        if subset == "train":
            assert len(files) == self.N_TRAIN_TRACKS, len(files)
            if split == "train":
                files = [f for f in files if f not in self.VALIDATION_FILES]
                assert len(files) == self.N_TRAIN_TRACKS - len(self.VALIDATION_FILES)
            else:
                files = [f for f in files if f in self.VALIDATION_FILES]
                assert len(files) == len(self.VALIDATION_FILES)
        else:
            assert len(files) == self.N_TEST_TRACKS

        self.n_tracks = len(files)

        super().__init__(
            data_path=data_path,
            split=split,
            stems=stems,
            files=files
        )

    def __len__(self) -> int:
        return self.n_tracks


class MUSDB18SadDataset(MUSDB18BaseDataset):
    def __init__(
            self,
            data_root: str,
            split: str,
            target_stem: str,
            stems: Optional[List[str]] = None,
            target_length: Optional[int] = None,
            npy_memmap=False,
    ) -> None:

        if stems is None:
            stems = self.ALLOWED_STEMS

        data_path = os.path.join(data_root, target_stem, split)

        files = sorted(os.listdir(data_path))
        files = [f for f in files if not f.startswith(".")]

        super().__init__(
            data_path=data_path,
            split=split,
            stems=stems,
            files=files,
            npy_memmap=npy_memmap
        )
        self.n_segments = len(files)
        self.target_stem = target_stem
        self.target_length = (
            target_length if target_length is not None else self.n_segments
        )

    def __len__(self) -> int:
        return self.target_length

    def __getitem__(self, index: int) -> DataDict:
        index = index % self.n_segments
        return super().__getitem__(index)

    def get_identifier(self, index):
        return super().get_identifier(index % self.n_segments)


class MUSDB18SadOnTheFlyAugmentedDataset(MUSDB18SadDataset):
    def __init__(
            self,
            data_root: str,
            split: str,
            target_stem: str,
            stems: Optional[List[str]] = None,
            target_length: int = 20000,
            apply_probability: Optional[float] = None,
            chunk_size_second: float = 3.0,
            random_scale_range_db: Tuple[float, float] = (-10, 10),
            drop_probability: float = 0.1,
            rescale: bool = True,
    ) -> None:
        super().__init__(data_root, split, target_stem, stems)

        if apply_probability is None:
            apply_probability = (target_length - self.n_segments) / target_length

        self.apply_probability = apply_probability
        self.drop_probability = drop_probability
        self.chunk_size_second = chunk_size_second
        self.random_scale_range_db = random_scale_range_db
        self.rescale = rescale

        self.chunk_size_sample = int(self.chunk_size_second * self.fs)
        self.target_length = target_length

    def __len__(self) -> int:
        return self.target_length

    def __getitem__(self, index: int) -> DataDict:

        index = index % self.n_segments

        audio = {}
        identifier = self.get_identifier(index)

        for stem in self.stems_no_mixture:
            # The target stem always comes from the indexed segment; each other
            # stem is swapped for a randomly chosen segment with probability
            # apply_probability.
            if stem == self.target_stem:
                identifier_ = identifier
            else:
                if np.random.rand() < self.apply_probability:
                    index_ = np.random.randint(self.n_segments)
                    identifier_ = self.get_identifier(index_)
                else:
                    identifier_ = identifier

            audio[stem] = self.get_stem(stem=stem, identifier=identifier_)

            if self.chunk_size_sample < audio[stem].shape[-1]:
                chunk_start = np.random.randint(
                    audio[stem].shape[-1] - self.chunk_size_sample
                )
            else:
                chunk_start = 0

            # Randomly rescale (or, with probability drop_probability, mute)
            # a chunk of the stem.
            if np.random.rand() < self.drop_probability:
                linear_scale = 0.0
            else:
                db_scale = np.random.uniform(*self.random_scale_range_db)
                linear_scale = np.power(10, db_scale / 20)

            chunk = slice(chunk_start, chunk_start + self.chunk_size_sample)
            audio[stem][..., chunk] = linear_scale * audio[stem][..., chunk]

        audio["mixture"] = self.compute_mixture(audio)

        if self.rescale:
            max_abs_val = max(
                [torch.max(torch.abs(audio[stem])) for stem in self.stems]
            )  # type: ignore[type-var]
            if max_abs_val > 1:
                audio = {k: v / max_abs_val for k, v in audio.items()}

        track = identifier["track"]

        return {"audio": audio, "track": f"{self.split}/{track}"}


# if __name__ == "__main__":
#
#     from pprint import pprint
#     from tqdm.auto import tqdm
#
#     for split_ in ["train", "val", "test"]:
#         ds = MUSDB18SadOnTheFlyAugmentedDataset(
#             data_root="$DATA_ROOT/MUSDB18/HQ/saded",
#             split=split_,
#             target_stem="vocals"
#         )
#
#         print(split_, len(ds))
#
#         for track_ in tqdm(ds):
#             track_["audio"] = {
#                 k: v.shape for k, v in track_["audio"].items()
#             }
#             pprint(track_)
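
For reference, the random gain in the on-the-fly augmentation converts decibels to a linear factor as linear = 10^(dB/20), so the default ±10 dB range scales a chunk by roughly 0.316x to 3.16x. A quick check:

import numpy as np

for db in (-10.0, 0.0, 10.0):
    print(db, np.power(10, db / 20))  # -> ~0.316, 1.0, ~3.162
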
separator/models/bandit/core/data/musdb/preprocess.py
ADDED
@@ -0,0 +1,238 @@

import glob
import os

import numpy as np
import torch
import torchaudio as ta
from torch import nn
from torch.nn import functional as F
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import process_map

from models.bandit.core.data._types import DataDict
from models.bandit.core.data.musdb.dataset import MUSDB18FullTrackDataset
import pyloudnorm as pyln


class SourceActivityDetector(nn.Module):
    def __init__(
            self,
            analysis_stem: str,
            output_path: str,
            fs: int = 44100,
            segment_length_second: float = 6.0,
            hop_length_second: float = 3.0,
            n_chunks: int = 10,
            chunk_epsilon: float = 1e-5,
            energy_threshold_quantile: float = 0.15,
            segment_epsilon: float = 1e-3,
            salient_proportion_threshold: float = 0.5,
            target_lufs: float = -24
    ) -> None:
        super().__init__()

        self.fs = fs
        self.segment_length = int(segment_length_second * self.fs)
        self.hop_length = int(hop_length_second * self.fs)
        self.n_chunks = n_chunks
        assert self.segment_length % self.n_chunks == 0
        self.chunk_size = self.segment_length // self.n_chunks
        self.chunk_epsilon = chunk_epsilon
        self.energy_threshold_quantile = energy_threshold_quantile
        self.segment_epsilon = segment_epsilon
        self.salient_proportion_threshold = salient_proportion_threshold
        self.analysis_stem = analysis_stem

        self.meter = pyln.Meter(self.fs)
        self.target_lufs = target_lufs

        self.output_path = output_path

    def forward(self, data: DataDict) -> None:

        stem_ = self.analysis_stem if (
                self.analysis_stem != "none") else "mixture"

        x = data["audio"][stem_]

        # Loudness-normalize every stem to target_lufs, using the loudness
        # measured on the analysis stem.
        xnp = x.numpy()
        loudness = self.meter.integrated_loudness(xnp.T)

        for stem in data["audio"]:
            s = data["audio"][stem]
            s = pyln.normalize.loudness(s.numpy().T, loudness, self.target_lufs).T
            s = torch.as_tensor(s)
            data["audio"][stem] = s

        if x.ndim == 3:
            assert x.shape[0] == 1
            x = x[0]

        n_chan, n_samples = x.shape

        n_segments = (
            int(
                np.ceil((n_samples - self.segment_length) / self.hop_length)
            ) + 1
        )

        segments = torch.zeros((n_segments, n_chan, self.segment_length))
        for i in range(n_segments):
            start = i * self.hop_length
            end = start + self.segment_length
            end = min(end, n_samples)

            xseg = x[:, start:end]

            if end - start < self.segment_length:
                xseg = F.pad(
                    xseg,
                    pad=(0, self.segment_length - (end - start)),
                    value=torch.nan
                )

            segments[i, :, :] = xseg

        chunks = segments.reshape(
            (n_segments, n_chan, self.n_chunks, self.chunk_size)
        )

        if self.analysis_stem != "none":
            chunk_energies = torch.mean(torch.square(chunks), dim=(1, 3))
            chunk_energies = torch.nan_to_num(chunk_energies, nan=0)
            chunk_energies[chunk_energies == 0] = self.chunk_epsilon

            energy_threshold = torch.nanquantile(
                chunk_energies, q=self.energy_threshold_quantile
            )

            if energy_threshold < self.segment_epsilon:
                energy_threshold = self.segment_epsilon  # type: ignore[assignment]

            chunks_above_threshold = chunk_energies > energy_threshold
            n_chunks_above_threshold = torch.mean(
                chunks_above_threshold.to(torch.float), dim=-1
            )

            segment_above_threshold = (
                n_chunks_above_threshold > self.salient_proportion_threshold
            )

            if torch.sum(segment_above_threshold) == 0:
                return
        else:
            segment_above_threshold = torch.ones((n_segments,))

        for i in range(n_segments):
            if not segment_above_threshold[i]:
                continue

            outpath = os.path.join(
                self.output_path,
                self.analysis_stem,
                f"{data['track']} - {self.analysis_stem}{i:03d}",
            )
            os.makedirs(outpath, exist_ok=True)

            for stem in data["audio"]:
                if stem == self.analysis_stem:
                    segment = torch.nan_to_num(segments[i, :, :], nan=0)
                else:
                    start = i * self.hop_length
                    end = start + self.segment_length
                    end = min(n_samples, end)

                    segment = data["audio"][stem][:, start:end]

                    if end - start < self.segment_length:
                        segment = F.pad(
                            segment,
                            (0, self.segment_length - (end - start))
                        )

                assert segment.shape[-1] == self.segment_length, segment.shape

                # np.save appends ".npy", so this writes "<stem>.wav.npy",
                # which is what MUSDB18BaseDataset.get_stem expects when
                # npy_memmap=True.
                np.save(os.path.join(outpath, f"{stem}.wav"), segment)


def preprocess(
        analysis_stem: str,
        output_path: str = "/data/MUSDB18/HQ/saded-np",
        fs: int = 44100,
        segment_length_second: float = 6.0,
        hop_length_second: float = 3.0,
        n_chunks: int = 10,
        chunk_epsilon: float = 1e-5,
        energy_threshold_quantile: float = 0.15,
        segment_epsilon: float = 1e-3,
        salient_proportion_threshold: float = 0.5,
) -> None:

    sad = SourceActivityDetector(
        analysis_stem=analysis_stem,
        output_path=output_path,
        fs=fs,
        segment_length_second=segment_length_second,
        hop_length_second=hop_length_second,
        n_chunks=n_chunks,
        chunk_epsilon=chunk_epsilon,
        energy_threshold_quantile=energy_threshold_quantile,
        segment_epsilon=segment_epsilon,
        salient_proportion_threshold=salient_proportion_threshold,
    )

    for split in ["train", "val", "test"]:
        ds = MUSDB18FullTrackDataset(
            data_root="/data/MUSDB18/HQ/canonical",
            split=split,
        )

        # Process tracks in batches of 32 to bound memory usage.
        tracks = []
        for i, track in enumerate(tqdm(ds, total=len(ds))):
            if i % 32 == 0 and tracks:
                process_map(sad, tracks, max_workers=8)
                tracks = []
            tracks.append(track)
        process_map(sad, tracks, max_workers=8)


def loudness_norm_one(inputs):
    infile, outfile, target_lufs = inputs

    audio, fs = ta.load(infile)
    audio = audio.mean(dim=0, keepdim=True).numpy().T

    meter = pyln.Meter(fs)
    loudness = meter.integrated_loudness(audio)
    audio = pyln.normalize.loudness(audio, loudness, target_lufs)

    os.makedirs(os.path.dirname(outfile), exist_ok=True)
    np.save(outfile, audio.T)


def loudness_norm(
        data_path: str,
        target_lufs: float = -17.0,
):
    files = glob.glob(
        os.path.join(data_path, "**", "*.wav"), recursive=True
    )

    outfiles = [
        f.replace(".wav", ".npy").replace("saded", "saded-np") for f in files
    ]

    jobs = [(f, o, target_lufs) for f, o in zip(files, outfiles)]

    process_map(loudness_norm_one, jobs, chunksize=2)


if __name__ == "__main__":
    import fire

    fire.Fire()
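
Since the entry point is a bare fire.Fire(), the module's functions double as CLI subcommands. A hypothetical driver, assuming the separator/ directory is on PYTHONPATH and the MUSDB18-HQ "canonical" tree sits at the hard-coded default location used in preprocess() above:

from models.bandit.core.data.musdb.preprocess import preprocess

preprocess(analysis_stem="vocals", output_path="/data/MUSDB18/HQ/saded-np")
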
separator/models/bandit/core/data/musdb/validation.yaml
ADDED
@@ -0,0 +1,15 @@

validation:
  - 'Actions - One Minute Smile'
  - 'Clara Berry And Wooldog - Waltz For My Victims'
  - 'Johnny Lokke - Promises & Lies'
  - 'Patrick Talbot - A Reason To Leave'
  - 'Triviul - Angelsaint'
  - 'Alexander Ross - Goodbye Bolero'
  - 'Fergessen - Nos Palpitants'
  - 'Leaf - Summerghost'
  - 'Skelpolu - Human Mistakes'
  - 'Young Griffo - Pennies'
  - 'ANiMAL - Rockshow'
  - 'James May - On The Line'
  - 'Meaxic - Take A Step'
  - 'Traffic Experiment - Sirens'
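
This list mirrors MUSDB18FullTrackDataset.VALIDATION_FILES in dataset.py. A minimal loading sketch, assuming PyYAML is available and the path is resolved relative to the repo root:

import yaml

with open("separator/models/bandit/core/data/musdb/validation.yaml") as f:
    validation_tracks = yaml.safe_load(f)["validation"]
print(len(validation_tracks))  # 14
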
separator/models/bandit/core/loss/__init__.py
ADDED
@@ -0,0 +1,2 @@

from ._multistem import MultiStemWrapperFromConfig
from ._timefreq import ReImL1Loss, ReImL2Loss, TimeFreqL1Loss, TimeFreqL2Loss, TimeFreqSignalNoisePNormRatioLoss
separator/models/bandit/core/loss/_complex.py
ADDED
@@ -0,0 +1,34 @@

from typing import Any

import torch
from torch.nn.modules import loss as _loss
from torch.nn.modules.loss import _Loss


class ReImLossWrapper(_Loss):
    # Applies a real-valued loss to the stacked real/imaginary parts of
    # complex tensors.
    def __init__(self, module: _Loss) -> None:
        super().__init__()
        self.module = module

    def forward(
            self,
            preds: torch.Tensor,
            target: torch.Tensor
    ) -> torch.Tensor:
        return self.module(
            torch.view_as_real(preds),
            torch.view_as_real(target)
        )


class ReImL1Loss(ReImLossWrapper):
    def __init__(self, **kwargs: Any) -> None:
        l1_loss = _loss.L1Loss(**kwargs)
        super().__init__(module=l1_loss)


class ReImL2Loss(ReImLossWrapper):
    def __init__(self, **kwargs: Any) -> None:
        l2_loss = _loss.MSELoss(**kwargs)
        super().__init__(module=l2_loss)
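
A quick sanity check of the wrapper on complex STFT-like tensors (dummy data):

import torch

loss_fn = ReImL1Loss()
preds = torch.randn(2, 257, 100, dtype=torch.complex64)
target = torch.randn(2, 257, 100, dtype=torch.complex64)
# L1 over the stacked real/imaginary parts: shape (2, 257, 100, 2) after view_as_real
print(loss_fn(preds, target))
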
separator/models/bandit/core/loss/_multistem.py
ADDED
@@ -0,0 +1,45 @@

from typing import Any, Dict

import torch
from asteroid import losses as asteroid_losses
from torch import nn
from torch.nn.modules.loss import _Loss

from . import snr


def parse_loss(name: str, kwargs: Dict[str, Any]) -> _Loss:
    # Resolve a loss class by name from torch, the local snr module, or asteroid.
    for module in [nn.modules.loss, snr, asteroid_losses, asteroid_losses.sdr]:
        if name in module.__dict__:
            return module.__dict__[name](**kwargs)

    raise NameError(f"Could not find a loss named {name}")


class MultiStemWrapper(_Loss):
    def __init__(self, module: _Loss, modality: str = "audio") -> None:
        super().__init__()
        self.loss = module
        self.modality = modality

    def forward(
            self,
            preds: Dict[str, Dict[str, torch.Tensor]],
            target: Dict[str, Dict[str, torch.Tensor]],
    ) -> torch.Tensor:
        # Apply the wrapped loss per stem, then sum over stems.
        loss = {
            stem: self.loss(
                preds[self.modality][stem],
                target[self.modality][stem]
            )
            for stem in preds[self.modality] if stem in target[self.modality]
        }

        return sum(loss.values())


class MultiStemWrapperFromConfig(MultiStemWrapper):
    def __init__(self, name: str, kwargs: Any, modality: str = "audio") -> None:
        loss = parse_loss(name, kwargs)
        super().__init__(module=loss, modality=modality)
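
The wrapper expects nested {modality: {stem: tensor}} dicts; a minimal sketch with dummy stereo audio:

import torch

loss_fn = MultiStemWrapperFromConfig(name="L1Loss", kwargs={})
preds = {"audio": {"vocals": torch.randn(2, 2, 44100), "bass": torch.randn(2, 2, 44100)}}
target = {"audio": {"vocals": torch.randn(2, 2, 44100), "bass": torch.randn(2, 2, 44100)}}
print(loss_fn(preds, target))  # sum of per-stem L1 losses
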
separator/models/bandit/core/loss/_timefreq.py
ADDED
@@ -0,0 +1,113 @@

from typing import Any, Dict, Optional

import torch
from torch import nn
from torch.nn.modules.loss import _Loss

from models.bandit.core.loss._multistem import MultiStemWrapper
from models.bandit.core.loss._complex import ReImL1Loss, ReImL2Loss
from models.bandit.core.loss.snr import SignalNoisePNormRatio


class TimeFreqWrapper(_Loss):
    def __init__(
            self,
            time_module: _Loss,
            freq_module: Optional[_Loss] = None,
            time_weight: float = 1.0,
            freq_weight: float = 1.0,
            multistem: bool = True,
    ) -> None:
        super().__init__()

        if freq_module is None:
            freq_module = time_module

        if multistem:
            time_module = MultiStemWrapper(time_module, modality="audio")
            freq_module = MultiStemWrapper(freq_module, modality="spectrogram")

        self.time_module = time_module
        self.freq_module = freq_module

        self.time_weight = time_weight
        self.freq_weight = freq_weight

    # TODO: add better type hints
    def forward(self, preds: Any, target: Any) -> torch.Tensor:
        return self.time_weight * self.time_module(
            preds, target
        ) + self.freq_weight * self.freq_module(preds, target)


class TimeFreqL1Loss(TimeFreqWrapper):
    def __init__(
            self,
            time_weight: float = 1.0,
            freq_weight: float = 1.0,
            tkwargs: Optional[Dict[str, Any]] = None,
            fkwargs: Optional[Dict[str, Any]] = None,
            multistem: bool = True,
    ) -> None:
        if tkwargs is None:
            tkwargs = {}
        if fkwargs is None:
            fkwargs = {}
        time_module = nn.L1Loss(**tkwargs)
        freq_module = ReImL1Loss(**fkwargs)
        super().__init__(
            time_module,
            freq_module,
            time_weight,
            freq_weight,
            multistem
        )


class TimeFreqL2Loss(TimeFreqWrapper):
    def __init__(
            self,
            time_weight: float = 1.0,
            freq_weight: float = 1.0,
            tkwargs: Optional[Dict[str, Any]] = None,
            fkwargs: Optional[Dict[str, Any]] = None,
            multistem: bool = True,
    ) -> None:
        if tkwargs is None:
            tkwargs = {}
        if fkwargs is None:
            fkwargs = {}
        time_module = nn.MSELoss(**tkwargs)
        freq_module = ReImL2Loss(**fkwargs)
        super().__init__(
            time_module,
            freq_module,
            time_weight,
            freq_weight,
            multistem
        )


class TimeFreqSignalNoisePNormRatioLoss(TimeFreqWrapper):
    def __init__(
            self,
            time_weight: float = 1.0,
            freq_weight: float = 1.0,
            tkwargs: Optional[Dict[str, Any]] = None,
            fkwargs: Optional[Dict[str, Any]] = None,
            multistem: bool = True,
    ) -> None:
        if tkwargs is None:
            tkwargs = {}
        if fkwargs is None:
            fkwargs = {}
        time_module = SignalNoisePNormRatio(**tkwargs)
        freq_module = SignalNoisePNormRatio(**fkwargs)
        super().__init__(
            time_module,
            freq_module,
            time_weight,
            freq_weight,
            multistem
        )
separator/models/bandit/core/loss/snr.py
ADDED
@@ -0,0 +1,146 @@

import torch
from torch.nn.modules.loss import _Loss


class SignalNoisePNormRatio(_Loss):
    def __init__(
            self,
            p: float = 1.0,
            scale_invariant: bool = False,
            zero_mean: bool = False,
            take_log: bool = True,
            reduction: str = "mean",
            EPS: float = 1e-3,
    ) -> None:
        if reduction == "sum":
            raise NotImplementedError
        super().__init__(reduction=reduction)
        assert not zero_mean

        self.p = p

        self.EPS = EPS
        self.take_log = take_log

        self.scale_invariant = scale_invariant

    def forward(
            self,
            est_target: torch.Tensor,
            target: torch.Tensor
    ) -> torch.Tensor:

        target_ = target
        if self.scale_invariant:
            ndim = target.ndim
            dot = torch.sum(est_target * torch.conj(target), dim=-1, keepdim=True)
            s_target_energy = (
                torch.sum(target * torch.conj(target), dim=-1, keepdim=True)
            )

            if ndim > 2:
                dot = torch.sum(dot, dim=list(range(1, ndim)), keepdim=True)
                s_target_energy = torch.sum(
                    s_target_energy, dim=list(range(1, ndim)), keepdim=True
                )

            target_scaler = (dot + 1e-8) / (s_target_energy + 1e-8)
            target = target_ * target_scaler

        if torch.is_complex(est_target):
            est_target = torch.view_as_real(est_target)
            target = torch.view_as_real(target)

        batch_size = est_target.shape[0]
        est_target = est_target.reshape(batch_size, -1)
        target = target.reshape(batch_size, -1)

        if self.p == 1:
            e_error = torch.abs(est_target - target).mean(dim=-1)
            e_target = torch.abs(target).mean(dim=-1)
        elif self.p == 2:
            e_error = torch.square(est_target - target).mean(dim=-1)
            e_target = torch.square(target).mean(dim=-1)
        else:
            raise NotImplementedError

        if self.take_log:
            loss = 10 * (
                torch.log10(e_error + self.EPS) - torch.log10(e_target + self.EPS)
            )
        else:
            loss = (e_error + self.EPS) / (e_target + self.EPS)

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()

        return loss


class MultichannelSingleSrcNegSDR(_Loss):
    def __init__(
            self,
            sdr_type: str,
            p: float = 2.0,
            zero_mean: bool = True,
            take_log: bool = True,
            reduction: str = "mean",
            EPS: float = 1e-8,
    ) -> None:
        if reduction == "sum":
            raise NotImplementedError
        super().__init__(reduction=reduction)

        assert sdr_type in ["snr", "sisdr", "sdsdr"]
        self.sdr_type = sdr_type
        self.zero_mean = zero_mean
        self.take_log = take_log
        self.EPS = EPS

        self.p = p

    def forward(
            self,
            est_target: torch.Tensor,
            target: torch.Tensor
    ) -> torch.Tensor:
        if target.size() != est_target.size() or target.ndim != 3:
            raise TypeError(
                f"Inputs must be of shape [batch, channel, time], "
                f"got {target.size()} and {est_target.size()} instead"
            )
        # Step 1. Zero-mean norm
        if self.zero_mean:
            mean_source = torch.mean(target, dim=[1, 2], keepdim=True)
            mean_estimate = torch.mean(est_target, dim=[1, 2], keepdim=True)
            target = target - mean_source
            est_target = est_target - mean_estimate
        # Step 2. Pair-wise SI-SDR.
        if self.sdr_type in ["sisdr", "sdsdr"]:
            # [batch, 1, 1]
            dot = torch.sum(est_target * target, dim=[1, 2], keepdim=True)
            # [batch, 1, 1]
            s_target_energy = (
                torch.sum(target ** 2, dim=[1, 2], keepdim=True) + self.EPS
            )
            # [batch, channel, time]
            scaled_target = dot * target / s_target_energy
        else:
            # [batch, channel, time]
            scaled_target = target
        if self.sdr_type in ["sdsdr", "snr"]:
            e_noise = est_target - target
        else:
            e_noise = est_target - scaled_target
        # [batch]
        if self.p == 2.0:
            losses = torch.sum(scaled_target ** 2, dim=[1, 2]) / (
                torch.sum(e_noise ** 2, dim=[1, 2]) + self.EPS
            )
        else:
            losses = torch.norm(scaled_target, p=self.p, dim=[1, 2]) / (
                torch.linalg.vector_norm(e_noise, ord=self.p, dim=[1, 2]) + self.EPS
            )
        if self.take_log:
            losses = 10 * torch.log10(losses + self.EPS)
        losses = losses.mean() if self.reduction == "mean" else losses
        return -losses
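
With p=1 and take_log=True, SignalNoisePNormRatio evaluates 10 * (log10(mean|error| + EPS) - log10(mean|target| + EPS)), an error-to-signal ratio in dB that decreases as the estimate improves. A quick numeric check on dummy tensors:

import torch

loss_fn = SignalNoisePNormRatio(p=1.0, take_log=True)
target = torch.ones(1, 1000)
print(loss_fn(target.clone(), target))        # perfect estimate: ~10*log10(1e-3 / 1.001) ~ -30
print(loss_fn(torch.zeros(1, 1000), target))  # all-zero estimate: ~10*log10(1.001 / 1.001) = 0
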
separator/models/bandit/core/metrics/__init__.py
ADDED
@@ -0,0 +1,9 @@

from .snr import (
    ChunkMedianScaleInvariantSignalDistortionRatio,
    ChunkMedianScaleInvariantSignalNoiseRatio,
    ChunkMedianSignalDistortionRatio,
    ChunkMedianSignalNoiseRatio,
    SafeSignalDistortionRatio,
)

# from .mushra import EstimatedMushraScore
separator/models/bandit/core/metrics/_squim.py
ADDED
@@ -0,0 +1,383 @@

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio._internal import load_state_dict_from_url


def transform_wb_pesq_range(x: float) -> float:
    """The metric defined by ITU-T P.862 is often called 'PESQ score', which is defined
    for narrow-band signals and has a value range of [-0.5, 4.5] exactly. Here, we use the metric
    defined by ITU-T P.862.2, commonly known as 'wide-band PESQ' and will be referred to as "PESQ score".

    Args:
        x (float): Narrow-band PESQ score.

    Returns:
        (float): Wide-band PESQ score.
    """
    return 0.999 + (4.999 - 0.999) / (1 + math.exp(-1.3669 * x + 3.8224))


PESQRange: Tuple[float, float] = (
    1.0,  # P.862.2 uses a different input filter than P.862, and the lower bound of
    # the raw score is not -0.5 anymore. It's hard to figure out the true lower bound.
    # We are using 1.0 as a reasonable approximation.
    transform_wb_pesq_range(4.5),
)


class RangeSigmoid(nn.Module):
    def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None:
        super(RangeSigmoid, self).__init__()
        assert isinstance(val_range, tuple) and len(val_range) == 2
        self.val_range: Tuple[float, float] = val_range
        self.sigmoid: nn.modules.Module = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.sigmoid(x) * (self.val_range[1] - self.val_range[0]) + self.val_range[0]
        return out


class Encoder(nn.Module):
    """Encoder module that transforms a 1D waveform into 2D representations.

    Args:
        feat_dim (int, optional): The feature dimension after the Encoder module. (Default: 512)
        win_len (int, optional): Kernel size in the Conv1d layer. (Default: 32)
    """

    def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None:
        super(Encoder, self).__init__()

        self.conv1d = nn.Conv1d(1, feat_dim, win_len, stride=win_len // 2, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply waveforms to the convolutional layer and a ReLU layer.

        Args:
            x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.

        Returns:
            (torch.Tensor): Feature Tensor with dimensions `(batch, channel, frame)`.
        """
        out = x.unsqueeze(dim=1)
        out = F.relu(self.conv1d(out))
        return out


class SingleRNN(nn.Module):
    def __init__(self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0) -> None:
        super(SingleRNN, self).__init__()

        self.rnn_type = rnn_type
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.rnn: nn.modules.Module = getattr(nn, rnn_type)(
            input_size,
            hidden_size,
            1,
            dropout=dropout,
            batch_first=True,
            bidirectional=True,
        )

        self.proj = nn.Linear(hidden_size * 2, input_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # input shape: batch, seq, dim
        out, _ = self.rnn(x)
        out = self.proj(out)
        return out


class DPRNN(nn.Module):
    """*Dual-path recurrent neural networks (DPRNN)* :cite:`luo2020dual`.

    Args:
        feat_dim (int, optional): The feature dimension after the Encoder module. (Default: 64)
        hidden_dim (int, optional): Hidden dimension in the RNN layer of DPRNN. (Default: 128)
        num_blocks (int, optional): Number of DPRNN layers. (Default: 6)
        rnn_type (str, optional): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"]. (Default: "LSTM")
        d_model (int, optional): The number of expected features in the input. (Default: 256)
        chunk_size (int, optional): Chunk size of input for DPRNN. (Default: 100)
        chunk_stride (int, optional): Stride of chunk input for DPRNN. (Default: 50)
    """

    def __init__(
        self,
        feat_dim: int = 64,
        hidden_dim: int = 128,
        num_blocks: int = 6,
        rnn_type: str = "LSTM",
        d_model: int = 256,
        chunk_size: int = 100,
        chunk_stride: int = 50,
    ) -> None:
        super(DPRNN, self).__init__()

        self.num_blocks = num_blocks

        self.row_rnn = nn.ModuleList([])
        self.col_rnn = nn.ModuleList([])
        self.row_norm = nn.ModuleList([])
        self.col_norm = nn.ModuleList([])
        for _ in range(num_blocks):
            self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
            self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
            self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
            self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
        self.conv = nn.Sequential(
            nn.Conv2d(feat_dim, d_model, 1),
            nn.PReLU(),
        )
        self.chunk_size = chunk_size
        self.chunk_stride = chunk_stride

    def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
        # input shape: (B, N, T)
        seq_len = x.shape[-1]

        rest = self.chunk_size - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size
        out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride])

        return out, rest

    def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
        out, rest = self.pad_chunk(x)
        batch_size, feat_dim, seq_len = out.shape

        segments1 = out[:, :, : -self.chunk_stride].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
        segments2 = out[:, :, self.chunk_stride :].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
        out = torch.cat([segments1, segments2], dim=3)
        out = out.view(batch_size, feat_dim, -1, self.chunk_size).transpose(2, 3).contiguous()

        return out, rest

    def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor:
        batch_size, dim, _, _ = x.shape
        out = x.transpose(2, 3).contiguous().view(batch_size, dim, -1, self.chunk_size * 2)
        out1 = out[:, :, :, : self.chunk_size].contiguous().view(batch_size, dim, -1)[:, :, self.chunk_stride :]
        out2 = out[:, :, :, self.chunk_size :].contiguous().view(batch_size, dim, -1)[:, :, : -self.chunk_stride]
        out = out1 + out2
        if rest > 0:
            out = out[:, :, :-rest]
        out = out.contiguous()
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, rest = self.chunking(x)
        batch_size, _, dim1, dim2 = x.shape
        out = x
        for row_rnn, row_norm, col_rnn, col_norm in zip(self.row_rnn, self.row_norm, self.col_rnn, self.col_norm):
            row_in = out.permute(0, 3, 2, 1).contiguous().view(batch_size * dim2, dim1, -1).contiguous()
            row_out = row_rnn(row_in)
            row_out = row_out.view(batch_size, dim2, dim1, -1).permute(0, 3, 2, 1).contiguous()
            row_out = row_norm(row_out)
            out = out + row_out

            col_in = out.permute(0, 2, 3, 1).contiguous().view(batch_size * dim1, dim2, -1).contiguous()
            col_out = col_rnn(col_in)
            col_out = col_out.view(batch_size, dim1, dim2, -1).permute(0, 3, 1, 2).contiguous()
            col_out = col_norm(col_out)
            out = out + col_out
        out = self.conv(out)
        out = self.merging(out, rest)
        out = out.transpose(1, 2).contiguous()
        return out


class AutoPool(nn.Module):
    def __init__(self, pool_dim: int = 1) -> None:
        super(AutoPool, self).__init__()
        self.pool_dim: int = pool_dim
        self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim)
        self.register_parameter("alpha", nn.Parameter(torch.ones(1)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.softmax(torch.mul(x, self.alpha))
        out = torch.sum(torch.mul(x, weight), dim=self.pool_dim)
        return out


class SquimObjective(nn.Module):
    """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **objective** metric scores
    for speech enhancement (e.g., STOI, PESQ, and SI-SDR).

    Args:
        encoder (torch.nn.Module): Encoder module to transform 1D waveform to 2D feature representation.
        dprnn (torch.nn.Module): DPRNN module to model sequential features.
        branches (torch.nn.ModuleList): Transformer branches in which each branch estimates one objective metric score.
    """

    def __init__(
        self,
        encoder: nn.Module,
        dprnn: nn.Module,
        branches: nn.ModuleList,
    ):
        super(SquimObjective, self).__init__()
        self.encoder = encoder
        self.dprnn = dprnn
        self.branches = branches

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Args:
            x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.

        Returns:
            List(torch.Tensor): List of score Tensors. Each Tensor is with dimension `(batch,)`.
        """
        if x.ndim != 2:
            raise ValueError(f"The input must be a 2D Tensor. Found dimension {x.ndim}.")
        x = x / (torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20)
        out = self.encoder(x)
        out = self.dprnn(out)
        scores = []
        for branch in self.branches:
            scores.append(branch(out).squeeze(dim=1))
        return scores


def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
    """Create a branch module after the DPRNN model for predicting a metric score.

    Args:
        d_model (int): The number of expected features in the input.
        nhead (int): Number of heads in the multi-head attention model.
        metric (str): The metric name to predict.

    Returns:
        (nn.Module): Returned module to predict the corresponding metric score.
    """
    layer1 = nn.TransformerEncoderLayer(d_model, nhead, d_model * 4, dropout=0.0, batch_first=True)
    layer2 = AutoPool()
    if metric == "stoi":
        layer3 = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.PReLU(),
            nn.Linear(d_model, 1),
            RangeSigmoid(),
        )
    elif metric == "pesq":
        layer3 = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.PReLU(),
            nn.Linear(d_model, 1),
            RangeSigmoid(val_range=PESQRange),
        )
    else:
        layer3 = nn.Sequential(nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1))
    return nn.Sequential(layer1, layer2, layer3)


def squim_objective_model(
    feat_dim: int,
    win_len: int,
    d_model: int,
    nhead: int,
    hidden_dim: int,
    num_blocks: int,
    rnn_type: str,
    chunk_size: int,
    chunk_stride: Optional[int] = None,
) -> SquimObjective:
    """Build a custom :class:`torchaudio.prototype.models.SquimObjective` model.

    Args:
        feat_dim (int, optional): The feature dimension after the Encoder module.
        win_len (int): Kernel size in the Encoder module.
        d_model (int): The number of expected features in the input.
        nhead (int): Number of heads in the multi-head attention model.
        hidden_dim (int): Hidden dimension in the RNN layer of DPRNN.
        num_blocks (int): Number of DPRNN layers.
        rnn_type (str): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"].
        chunk_size (int): Chunk size of input for DPRNN.
        chunk_stride (int or None, optional): Stride of chunk input for DPRNN.
    """
    if chunk_stride is None:
        chunk_stride = chunk_size // 2
    encoder = Encoder(feat_dim, win_len)
    dprnn = DPRNN(feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride)
    branches = nn.ModuleList(
        [
            _create_branch(d_model, nhead, "stoi"),
            _create_branch(d_model, nhead, "pesq"),
            _create_branch(d_model, nhead, "sisdr"),
        ]
    )
    return SquimObjective(encoder, dprnn, branches)


def squim_objective_base() -> SquimObjective:
    """Build a :class:`torchaudio.prototype.models.SquimObjective` model with default arguments."""
    return squim_objective_model(
        feat_dim=256,
        win_len=64,
        d_model=256,
        nhead=4,
        hidden_dim=256,
        num_blocks=2,
        rnn_type="LSTM",
        chunk_size=71,
    )


@dataclass
class SquimObjectiveBundle:

    _path: str
    _sample_rate: float

    def _get_state_dict(self, dl_kwargs):
        url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
        state_dict = load_state_dict_from_url(url, **dl_kwargs)
        return state_dict

    def get_model(self, *, dl_kwargs=None) -> SquimObjective:
        """Construct the SquimObjective model, and load the pretrained weights.

        The weight file is downloaded from the internet and cached with
        :func:`torch.hub.load_state_dict_from_url`.

        Args:
            dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.

        Returns:
            Variation of :py:class:`~torchaudio.models.SquimObjective`.
        """
        model = squim_objective_base()
        model.load_state_dict(self._get_state_dict(dl_kwargs))
        model.eval()
        return model

    @property
    def sample_rate(self):
        """Sample rate of the audio that the model is trained on.

        :type: float
        """
        return self._sample_rate


SQUIM_OBJECTIVE = SquimObjectiveBundle(
    "squim_objective_dns2020.pth",
    _sample_rate=16000,
)
SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using the approach described in
:cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.

The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
The weights are under the `Creative Commons Attribution 4.0 International License
<https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE>`__.

Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
"""
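
A minimal scoring sketch (the bundle downloads its weights on first use; the input must be a mono 16 kHz batch, per `sample_rate`, and the dummy waveform here stands in for real speech):

import torch

model = SQUIM_OBJECTIVE.get_model()
waveform = torch.randn(1, 16000)  # 1 second of (dummy) 16 kHz audio
with torch.no_grad():
    stoi, pesq, si_sdr = model(waveform)  # branch order: stoi, pesq, sisdr
print(stoi.item(), pesq.item(), si_sdr.item())
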
separator/models/bandit/core/metrics/snr.py
ADDED
@@ -0,0 +1,150 @@
from typing import Any, Callable, Optional

import numpy as np
import torch
import torchmetrics as tm
from torch._C import _LinAlgError
from torchmetrics import functional as tmF


class SafeSignalDistortionRatio(tm.SignalDistortionRatio):
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def update(self, *args, **kwargs) -> Any:
        try:
            super().update(*args, **kwargs)
        except Exception:
            # Skip batches on which the SDR solver fails instead of crashing.
            pass

    def compute(self) -> Any:
        if self.total == 0:
            return torch.tensor(torch.nan)
        return super().compute()


class BaseChunkMedianSignalRatio(tm.Metric):
    def __init__(
            self,
            func: Callable,
            window_size: int,
            hop_size: Optional[int] = None,
            zero_mean: bool = False,
    ) -> None:
        super().__init__()

        # self.zero_mean = zero_mean
        self.func = func
        self.window_size = window_size
        if hop_size is None:
            hop_size = window_size
        self.hop_size = hop_size

        self.add_state(
            "sum_snr",
            default=torch.tensor(0.0),
            dist_reduce_fx="sum"
        )
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:

        n_samples = target.shape[-1]

        n_chunks = int(
            np.ceil((n_samples - self.window_size) / self.hop_size) + 1
        )

        snr_chunk = []

        for i in range(n_chunks):
            start = i * self.hop_size

            if n_samples - start < self.window_size:
                continue

            end = start + self.window_size

            try:
                chunk_snr = self.func(
                    preds[..., start:end],
                    target[..., start:end]
                )

                # print(preds.shape, chunk_snr.shape)

                if torch.all(torch.isfinite(chunk_snr)):
                    snr_chunk.append(chunk_snr)
            except _LinAlgError:
                pass

        if len(snr_chunk) == 0:
            # Every chunk failed or was non-finite; skip this batch.
            return

        snr_chunk = torch.stack(snr_chunk, dim=-1)
        snr_batch, _ = torch.nanmedian(snr_chunk, dim=-1)

        self.sum_snr += snr_batch.sum()
        self.total += snr_batch.numel()

    def compute(self) -> Any:
        return self.sum_snr / self.total


class ChunkMedianSignalNoiseRatio(BaseChunkMedianSignalRatio):
    def __init__(
            self,
            window_size: int,
            hop_size: Optional[int] = None,
            zero_mean: bool = False
    ) -> None:
        super().__init__(
            func=tmF.signal_noise_ratio,
            window_size=window_size,
            hop_size=hop_size,
            zero_mean=zero_mean,
        )


class ChunkMedianScaleInvariantSignalNoiseRatio(BaseChunkMedianSignalRatio):
    def __init__(
            self,
            window_size: int,
            hop_size: Optional[int] = None,
            zero_mean: bool = False
    ) -> None:
        super().__init__(
            func=tmF.scale_invariant_signal_noise_ratio,
            window_size=window_size,
            hop_size=hop_size,
            zero_mean=zero_mean,
        )


class ChunkMedianSignalDistortionRatio(BaseChunkMedianSignalRatio):
    def __init__(
            self,
            window_size: int,
            hop_size: Optional[int] = None,
            zero_mean: bool = False
    ) -> None:
        super().__init__(
            func=tmF.signal_distortion_ratio,
            window_size=window_size,
            hop_size=hop_size,
            zero_mean=zero_mean,
        )


class ChunkMedianScaleInvariantSignalDistortionRatio(
    BaseChunkMedianSignalRatio
):
    def __init__(
            self,
            window_size: int,
            hop_size: Optional[int] = None,
            zero_mean: bool = False
    ) -> None:
        super().__init__(
            func=tmF.scale_invariant_signal_distortion_ratio,
            window_size=window_size,
            hop_size=hop_size,
            zero_mean=zero_mean,
        )
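
A short sketch of how the chunk-median metrics above are meant to be used: the ratio is computed on fixed-size windows, the per-example median over windows is taken, and `compute` averages that median over everything accumulated so far (all tensor shapes here are illustrative):

import torch

metric = ChunkMedianSignalNoiseRatio(window_size=44100)  # hop_size defaults to window_size
preds = torch.randn(4, 2, 3 * 44100)                     # (batch, channel, samples)
target = preds + 0.1 * torch.randn_like(preds)
metric.update(preds, target)
print(metric.compute())  # scalar: mean over batch/channel of per-example chunk medians
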
separator/models/bandit/core/model/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .bsrnn.wrapper import (
    MultiMaskMultiSourceBandSplitRNNSimple,
)
separator/models/bandit/core/model/_spectral.py
ADDED
@@ -0,0 +1,58 @@
from typing import Dict, Optional

import torch
import torchaudio as ta
from torch import nn


class _SpectralComponent(nn.Module):
    def __init__(
            self,
            n_fft: int = 2048,
            win_length: Optional[int] = 2048,
            hop_length: int = 512,
            window_fn: str = "hann_window",
            wkwargs: Optional[Dict] = None,
            power: Optional[int] = None,
            center: bool = True,
            normalized: bool = True,
            pad_mode: str = "constant",
            onesided: bool = True,
            **kwargs,
    ) -> None:
        super().__init__()

        assert power is None

        window_fn = torch.__dict__[window_fn]

        self.stft = ta.transforms.Spectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode=pad_mode,
            pad=0,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            normalized=normalized,
            center=center,
            onesided=onesided,
        )

        self.istft = ta.transforms.InverseSpectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode=pad_mode,
            pad=0,
            window_fn=window_fn,
            wkwargs=wkwargs,
            normalized=normalized,
            center=center,
            onesided=onesided,
        )
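
A round-trip sketch for `_SpectralComponent` (illustrative, not part of the repo): with `power=None` the forward transform keeps the complex STFT, so `istft` can reconstruct the waveform up to numerical error when given the original length:

import torch

spec = _SpectralComponent(n_fft=2048, hop_length=512)
x = torch.randn(1, 2, 44100)            # (batch, channel, samples)
X = spec.stft(x)                        # complex, (batch, channel, n_fft // 2 + 1, n_frames)
y = spec.istft(X, x.shape[-1])          # the length argument trims the centre padding
print(torch.allclose(x, y, atol=1e-4))  # True up to windowing round-off
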
separator/models/bandit/core/model/bsrnn/__init__.py
ADDED
@@ -0,0 +1,23 @@
from abc import ABC
from typing import Iterable, Mapping, Union

from torch import nn

from models.bandit.core.model.bsrnn.bandsplit import BandSplitModule
from models.bandit.core.model.bsrnn.tfmodel import (
    SeqBandModellingModule,
    TransformerTimeFreqModule,
)


class BandsplitCoreBase(nn.Module, ABC):
    band_split: nn.Module
    tf_model: nn.Module
    mask_estim: Union[nn.Module, Mapping[str, nn.Module], Iterable[nn.Module]]

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def mask(x, m):
        return x * m
separator/models/bandit/core/model/bsrnn/bandsplit.py
ADDED
@@ -0,0 +1,139 @@
from typing import List, Tuple

import torch
from torch import nn

from models.bandit.core.model.bsrnn.utils import (
    band_widths_from_specs,
    check_no_gap,
    check_no_overlap,
    check_nonzero_bandwidth,
)


class NormFC(nn.Module):
    def __init__(
            self,
            emb_dim: int,
            bandwidth: int,
            in_channel: int,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
    ) -> None:
        super().__init__()

        self.treat_channel_as_feature = treat_channel_as_feature

        if normalize_channel_independently:
            raise NotImplementedError

        reim = 2

        self.norm = nn.LayerNorm(in_channel * bandwidth * reim)

        fc_in = bandwidth * reim

        if treat_channel_as_feature:
            fc_in *= in_channel
        else:
            assert emb_dim % in_channel == 0
            emb_dim = emb_dim // in_channel

        self.fc = nn.Linear(fc_in, emb_dim)

    def forward(self, xb):
        # xb = (batch, n_time, in_chan, reim * band_width)

        batch, n_time, in_chan, ribw = xb.shape
        xb = self.norm(xb.reshape(batch, n_time, in_chan * ribw))
        # (batch, n_time, in_chan * reim * band_width)

        if not self.treat_channel_as_feature:
            xb = xb.reshape(batch, n_time, in_chan, ribw)
            # (batch, n_time, in_chan, reim * band_width)

        zb = self.fc(xb)
        # (batch, n_time, emb_dim)
        # OR
        # (batch, n_time, in_chan, emb_dim_per_chan)

        if not self.treat_channel_as_feature:
            batch, n_time, in_chan, emb_dim_per_chan = zb.shape
            # (batch, n_time, in_chan, emb_dim_per_chan)
            zb = zb.reshape((batch, n_time, in_chan * emb_dim_per_chan))

        return zb  # (batch, n_time, emb_dim)


class BandSplitModule(nn.Module):
    def __init__(
            self,
            band_specs: List[Tuple[float, float]],
            emb_dim: int,
            in_channel: int,
            require_no_overlap: bool = False,
            require_no_gap: bool = True,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
    ) -> None:
        super().__init__()

        check_nonzero_bandwidth(band_specs)

        if require_no_gap:
            check_no_gap(band_specs)

        if require_no_overlap:
            check_no_overlap(band_specs)

        self.band_specs = band_specs
        # list of [fstart, fend) in index.
        # Note that fend is exclusive.
        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)
        self.emb_dim = emb_dim

        self.norm_fc_modules = nn.ModuleList(
            [  # type: ignore
                NormFC(
                    emb_dim=emb_dim,
                    bandwidth=bw,
                    in_channel=in_channel,
                    normalize_channel_independently=normalize_channel_independently,
                    treat_channel_as_feature=treat_channel_as_feature,
                )
                for bw in self.band_widths
            ]
        )

    def forward(self, x: torch.Tensor):
        # x = complex spectrogram (batch, in_chan, n_freq, n_time)

        batch, in_chan, _, n_time = x.shape

        z = torch.zeros(
            size=(batch, self.n_bands, n_time, self.emb_dim),
            device=x.device
        )

        xr = torch.view_as_real(x)  # batch, in_chan, n_freq, n_time, 2
        xr = torch.permute(
            xr,
            (0, 3, 1, 4, 2)
        )  # batch, n_time, in_chan, 2, n_freq
        batch, n_time, in_chan, reim, band_width = xr.shape
        for i, nfm in enumerate(self.norm_fc_modules):
            # print(f"bandsplit/band{i:02d}")
            fstart, fend = self.band_specs[i]
            xb = xr[..., fstart:fend]
            # (batch, n_time, in_chan, reim, band_width)
            xb = torch.reshape(xb, (batch, n_time, in_chan, -1))
            # (batch, n_time, in_chan, reim * band_width)
            # z.append(nfm(xb))  # (batch, n_time, emb_dim)
            z[:, i, :, :] = nfm(xb.contiguous())

        # z = torch.stack(z, dim=1)

        return z
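
A shape sketch for `BandSplitModule` with a toy two-band layout (real band specs come from the utils module; the indices below are made up for a 1025-bin spectrum):

import torch

band_specs = [(0, 512), (512, 1025)]  # [fstart, fend) index pairs, no gap
module = BandSplitModule(band_specs=band_specs, emb_dim=128, in_channel=2)
x = torch.randn(1, 2, 1025, 100, dtype=torch.complex64)  # (batch, in_chan, n_freq, n_time)
z = module(x)
print(z.shape)  # torch.Size([1, 2, 100, 128]) == (batch, n_bands, n_time, emb_dim)
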
separator/models/bandit/core/model/bsrnn/core.py
ADDED
@@ -0,0 +1,661 @@
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn
from torch.nn import functional as F

from models.bandit.core.model.bsrnn import BandsplitCoreBase
from models.bandit.core.model.bsrnn.bandsplit import BandSplitModule
from models.bandit.core.model.bsrnn.maskestim import (
    MaskEstimationModule,
    OverlappingMaskEstimationModule
)
from models.bandit.core.model.bsrnn.tfmodel import (
    ConvolutionalTimeFreqModule,
    SeqBandModellingModule,
    TransformerTimeFreqModule
)


class MultiMaskBandSplitCoreBase(BandsplitCoreBase):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x, cond=None, compute_residual: bool = True):
        # x = complex spectrogram (batch, in_chan, n_freq, n_time)
        # print(x.shape)
        batch, in_chan, n_freq, n_time = x.shape
        x = torch.reshape(x, (-1, 1, n_freq, n_time))

        z = self.band_split(x)  # (batch, n_bands, n_time, emb_dim)

        # if torch.any(torch.isnan(z)):
        #     raise ValueError("z nan")

        # print(z)
        q = self.tf_model(z)  # (batch, n_bands, n_time, emb_dim)
        # print(q)

        # if torch.any(torch.isnan(q)):
        #     raise ValueError("q nan")

        out = {}

        for stem, mem in self.mask_estim.items():
            m = mem(q, cond=cond)

            # if torch.any(torch.isnan(m)):
            #     raise ValueError("m nan", stem)

            s = self.mask(x, m)
            s = torch.reshape(s, (batch, in_chan, n_freq, n_time))
            out[stem] = s

        return {"spectrogram": out}

    def instantiate_mask_estim(
            self,
            in_channel: int,
            stems: List[str],
            band_specs: List[Tuple[float, float]],
            emb_dim: int,
            mlp_dim: int,
            cond_dim: int,
            hidden_activation: str,
            hidden_activation_kwargs: Optional[Dict] = None,
            complex_mask: bool = True,
            overlapping_band: bool = False,
            freq_weights: Optional[List[torch.Tensor]] = None,
            n_freq: Optional[int] = None,
            use_freq_weights: bool = True,
            mult_add_mask: bool = False
    ):
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        if "mne:+" in stems:
            stems = [s for s in stems if s != "mne:+"]

        if overlapping_band:
            assert freq_weights is not None
            assert n_freq is not None

            if mult_add_mask:
                # NOTE: MultAddMaskEstimationModule is not imported in this
                # file; it must be made available before this branch is used.
                self.mask_estim = nn.ModuleDict(
                    {
                        stem: MultAddMaskEstimationModule(
                            band_specs=band_specs,
                            freq_weights=freq_weights,
                            n_freq=n_freq,
                            emb_dim=emb_dim,
                            mlp_dim=mlp_dim,
                            in_channel=in_channel,
                            hidden_activation=hidden_activation,
                            hidden_activation_kwargs=hidden_activation_kwargs,
                            complex_mask=complex_mask,
                            use_freq_weights=use_freq_weights,
                        )
                        for stem in stems
                    }
                )
            else:
                self.mask_estim = nn.ModuleDict(
                    {
                        stem: OverlappingMaskEstimationModule(
                            band_specs=band_specs,
                            freq_weights=freq_weights,
                            n_freq=n_freq,
                            emb_dim=emb_dim,
                            mlp_dim=mlp_dim,
                            in_channel=in_channel,
                            hidden_activation=hidden_activation,
                            hidden_activation_kwargs=hidden_activation_kwargs,
                            complex_mask=complex_mask,
                            use_freq_weights=use_freq_weights,
                        )
                        for stem in stems
                    }
                )
        else:
            self.mask_estim = nn.ModuleDict(
                {
                    stem: MaskEstimationModule(
                        band_specs=band_specs,
                        emb_dim=emb_dim,
                        mlp_dim=mlp_dim,
                        cond_dim=cond_dim,
                        in_channel=in_channel,
                        hidden_activation=hidden_activation,
                        hidden_activation_kwargs=hidden_activation_kwargs,
                        complex_mask=complex_mask,
                    )
                    for stem in stems
                }
            )

    def instantiate_bandsplit(
            self,
            in_channel: int,
            band_specs: List[Tuple[float, float]],
            require_no_overlap: bool = False,
            require_no_gap: bool = True,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
            emb_dim: int = 128
    ):
        self.band_split = BandSplitModule(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )


class SingleMaskBandsplitCoreBase(BandsplitCoreBase):
    def __init__(self, **kwargs) -> None:
        super().__init__()

    def forward(self, x):
        # x = complex spectrogram (batch, in_chan, n_freq, n_time)
        z = self.band_split(x)  # (batch, n_bands, n_time, emb_dim)
        q = self.tf_model(z)  # (batch, n_bands, n_time, emb_dim)
        m = self.mask_estim(q)  # (batch, in_chan, n_freq, n_time)

        s = self.mask(x, m)

        return s


class SingleMaskBandsplitCoreRNN(
    SingleMaskBandsplitCoreBase,
):
    def __init__(
            self,
            in_channel: int,
            band_specs: List[Tuple[float, float]],
            require_no_overlap: bool = False,
            require_no_gap: bool = True,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
            n_sqm_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            bidirectional: bool = True,
            rnn_type: str = "LSTM",
            mlp_dim: int = 512,
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs: Optional[Dict] = None,
            complex_mask: bool = True,
    ) -> None:
        super().__init__()
        self.band_split = BandSplitModule(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )
        self.tf_model = SeqBandModellingModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )
        self.mask_estim = MaskEstimationModule(
            in_channel=in_channel,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )


class SingleMaskBandsplitCoreTransformer(
    SingleMaskBandsplitCoreBase,
):
    def __init__(
            self,
            in_channel: int,
            band_specs: List[Tuple[float, float]],
            require_no_overlap: bool = False,
            require_no_gap: bool = True,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
            n_sqm_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            bidirectional: bool = True,
            tf_dropout: float = 0.0,
            mlp_dim: int = 512,
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs: Optional[Dict] = None,
            complex_mask: bool = True,
    ) -> None:
        super().__init__()
        self.band_split = BandSplitModule(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )
        self.tf_model = TransformerTimeFreqModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            dropout=tf_dropout,
        )
        self.mask_estim = MaskEstimationModule(
            in_channel=in_channel,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )


class MultiSourceMultiMaskBandSplitCoreRNN(MultiMaskBandSplitCoreBase):
    def __init__(
            self,
            in_channel: int,
            stems: List[str],
            band_specs: List[Tuple[float, float]],
            require_no_overlap: bool = False,
            require_no_gap: bool = True,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
            n_sqm_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            bidirectional: bool = True,
            rnn_type: str = "LSTM",
            mlp_dim: int = 512,
            cond_dim: int = 0,
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs: Optional[Dict] = None,
            complex_mask: bool = True,
            overlapping_band: bool = False,
            freq_weights: Optional[List[torch.Tensor]] = None,
            n_freq: Optional[int] = None,
            use_freq_weights: bool = True,
            mult_add_mask: bool = False
    ) -> None:

        super().__init__()
        self.instantiate_bandsplit(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim
        )

        self.tf_model = SeqBandModellingModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )

        self.mult_add_mask = mult_add_mask

        self.instantiate_mask_estim(
            in_channel=in_channel,
            stems=stems,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=overlapping_band,
            freq_weights=freq_weights,
            n_freq=n_freq,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask
        )

    @staticmethod
    def _mult_add_mask(x, m):

        assert m.ndim == 5

        mm = m[..., 0]
        am = m[..., 1]

        # print(mm.shape, am.shape, x.shape, m.shape)

        return x * mm + am

    def mask(self, x, m):
        if self.mult_add_mask:
            return self._mult_add_mask(x, m)
        else:
            return super().mask(x, m)


class MultiSourceMultiMaskBandSplitCoreTransformer(
    MultiMaskBandSplitCoreBase,
):
    def __init__(
            self,
            in_channel: int,
            stems: List[str],
            band_specs: List[Tuple[float, float]],
            require_no_overlap: bool = False,
            require_no_gap: bool = True,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
            n_sqm_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            bidirectional: bool = True,
            tf_dropout: float = 0.0,
            mlp_dim: int = 512,
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs: Optional[Dict] = None,
            complex_mask: bool = True,
            overlapping_band: bool = False,
            freq_weights: Optional[List[torch.Tensor]] = None,
            n_freq: Optional[int] = None,
            use_freq_weights: bool = True,
            rnn_type: str = "LSTM",
            cond_dim: int = 0,
            mult_add_mask: bool = False
    ) -> None:
        super().__init__()
        self.instantiate_bandsplit(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim
        )
        self.tf_model = TransformerTimeFreqModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            dropout=tf_dropout,
        )

        self.instantiate_mask_estim(
            in_channel=in_channel,
            stems=stems,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=overlapping_band,
            freq_weights=freq_weights,
            n_freq=n_freq,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask
        )


class MultiSourceMultiMaskBandSplitCoreConv(
    MultiMaskBandSplitCoreBase,
):
    def __init__(
            self,
            in_channel: int,
            stems: List[str],
            band_specs: List[Tuple[float, float]],
            require_no_overlap: bool = False,
            require_no_gap: bool = True,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
            n_sqm_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            bidirectional: bool = True,
            tf_dropout: float = 0.0,
            mlp_dim: int = 512,
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs: Optional[Dict] = None,
            complex_mask: bool = True,
            overlapping_band: bool = False,
            freq_weights: Optional[List[torch.Tensor]] = None,
            n_freq: Optional[int] = None,
            use_freq_weights: bool = True,
            rnn_type: str = "LSTM",
            cond_dim: int = 0,
            mult_add_mask: bool = False
    ) -> None:
        super().__init__()
        self.instantiate_bandsplit(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim
        )
        self.tf_model = ConvolutionalTimeFreqModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            dropout=tf_dropout,
        )

        self.instantiate_mask_estim(
            in_channel=in_channel,
            stems=stems,
            band_specs=band_specs,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=overlapping_band,
            freq_weights=freq_weights,
            n_freq=n_freq,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask
        )


class PatchingMaskBandsplitCoreBase(MultiMaskBandSplitCoreBase):
    def __init__(self) -> None:
        super().__init__()

    def mask(self, x, m):
        # x.shape = (batch, n_channel, n_freq, n_time)
        # m.shape = (batch, n_channel, kernel_freq, kernel_time, n_freq, n_time)

        _, n_channel, kernel_freq, kernel_time, n_freq, n_time = m.shape
        padding = ((kernel_freq - 1) // 2, (kernel_time - 1) // 2)

        xf = F.unfold(
            x,
            kernel_size=(kernel_freq, kernel_time),
            padding=padding,
            stride=(1, 1),
        )

        xf = xf.view(
            -1,
            n_channel,
            kernel_freq,
            kernel_time,
            n_freq,
            n_time,
        )

        sf = xf * m

        sf = sf.view(
            -1,
            n_channel * kernel_freq * kernel_time,
            n_freq * n_time,
        )

        s = F.fold(
            sf,
            output_size=(n_freq, n_time),
            kernel_size=(kernel_freq, kernel_time),
            padding=padding,
            stride=(1, 1),
        ).view(
            -1,
            n_channel,
            n_freq,
            n_time,
        )

        return s

    def old_mask(self, x, m):
        # x.shape = (batch, n_channel, n_freq, n_time)
        # m.shape = (kernel_freq, kernel_time, batch, n_channel, n_freq, n_time)

        s = torch.zeros_like(x)

        _, n_channel, n_freq, n_time = x.shape
        kernel_freq, kernel_time, _, _, _, _ = m.shape

        # print(x.shape, m.shape)

        kernel_freq_half = (kernel_freq - 1) // 2
        kernel_time_half = (kernel_time - 1) // 2

        for ifreq in range(kernel_freq):
            for itime in range(kernel_time):
                df, dt = kernel_freq_half - ifreq, kernel_time_half - itime
                x = x.roll(shifts=(df, dt), dims=(2, 3))

                # if df > 0:
                #     x[:, :, :df, :] = 0
                # elif df < 0:
                #     x[:, :, df:, :] = 0

                # if dt > 0:
                #     x[:, :, :, :dt] = 0
                # elif dt < 0:
                #     x[:, :, :, dt:] = 0

                fslice = slice(max(0, df), min(n_freq, n_freq + df))
                tslice = slice(max(0, dt), min(n_time, n_time + dt))

                s[:, :, fslice, tslice] += (
                    x[:, :, fslice, tslice]
                    * m[ifreq, itime, :, :, fslice, tslice]
                )

        return s


class MultiSourceMultiPatchingMaskBandSplitCoreRNN(
    PatchingMaskBandsplitCoreBase
):
    def __init__(
            self,
            in_channel: int,
            stems: List[str],
            band_specs: List[Tuple[float, float]],
            mask_kernel_freq: int,
            mask_kernel_time: int,
            conv_kernel_freq: int,
            conv_kernel_time: int,
            kernel_norm_mlp_version: int,
            require_no_overlap: bool = False,
            require_no_gap: bool = True,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
            n_sqm_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            bidirectional: bool = True,
            rnn_type: str = "LSTM",
            mlp_dim: int = 512,
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs: Optional[Dict] = None,
            complex_mask: bool = True,
            overlapping_band: bool = False,
            freq_weights: Optional[List[torch.Tensor]] = None,
            n_freq: Optional[int] = None,
    ) -> None:

        super().__init__()
        self.band_split = BandSplitModule(
            in_channel=in_channel,
            band_specs=band_specs,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )

        self.tf_model = SeqBandModellingModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )

        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        if overlapping_band:
            assert freq_weights is not None
            assert n_freq is not None
            # NOTE: PatchingMaskEstimationModule is not imported in this file;
            # it must be made available before this branch is used.
            self.mask_estim = nn.ModuleDict(
                {
                    stem: PatchingMaskEstimationModule(
                        band_specs=band_specs,
                        freq_weights=freq_weights,
                        n_freq=n_freq,
                        emb_dim=emb_dim,
                        mlp_dim=mlp_dim,
                        in_channel=in_channel,
                        hidden_activation=hidden_activation,
                        hidden_activation_kwargs=hidden_activation_kwargs,
                        complex_mask=complex_mask,
                        mask_kernel_freq=mask_kernel_freq,
                        mask_kernel_time=mask_kernel_time,
                        conv_kernel_freq=conv_kernel_freq,
                        conv_kernel_time=conv_kernel_time,
                        kernel_norm_mlp_version=kernel_norm_mlp_version
                    )
                    for stem in stems
                }
            )
        else:
            raise NotImplementedError
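
An end-to-end shape sketch for `MultiSourceMultiMaskBandSplitCoreRNN` with deliberately tiny hyperparameters (the real configurations live in the repo's config files; everything below is illustrative):

import torch

core = MultiSourceMultiMaskBandSplitCoreRNN(
    in_channel=1,
    stems=["speech", "music"],
    band_specs=[(0, 256), (256, 513)],
    n_sqm_modules=2,
    emb_dim=32,
    rnn_dim=64,
    mlp_dim=64,
)
x = torch.randn(1, 1, 513, 50, dtype=torch.complex64)  # mono complex spectrogram
out = core(x)["spectrogram"]
print({stem: s.shape for stem, s in out.items()})  # each stem: (1, 1, 513, 50)
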
separator/models/bandit/core/model/bsrnn/maskestim.py
ADDED
@@ -0,0 +1,347 @@
import warnings
from typing import Dict, List, Optional, Tuple, Type

import torch
from torch import nn
from torch.nn.modules import activation

from models.bandit.core.model.bsrnn.utils import (
    band_widths_from_specs,
    check_no_gap,
    check_no_overlap,
    check_nonzero_bandwidth,
)


class BaseNormMLP(nn.Module):
    def __init__(
            self,
            emb_dim: int,
            mlp_dim: int,
            bandwidth: int,
            in_channel: Optional[int],
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs=None,
            complex_mask: bool = True,
    ):
        super().__init__()
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}
        self.hidden_activation_kwargs = hidden_activation_kwargs
        self.norm = nn.LayerNorm(emb_dim)
        self.hidden = torch.jit.script(nn.Sequential(
            nn.Linear(in_features=emb_dim, out_features=mlp_dim),
            activation.__dict__[hidden_activation](
                **self.hidden_activation_kwargs
            ),
        ))

        self.bandwidth = bandwidth
        self.in_channel = in_channel

        self.complex_mask = complex_mask
        self.reim = 2 if complex_mask else 1
        self.glu_mult = 2


class NormMLP(BaseNormMLP):
    def __init__(
            self,
            emb_dim: int,
            mlp_dim: int,
            bandwidth: int,
            in_channel: Optional[int],
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs=None,
            complex_mask: bool = True,
    ) -> None:
        super().__init__(
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            bandwidth=bandwidth,
            in_channel=in_channel,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

        self.output = torch.jit.script(
            nn.Sequential(
                nn.Linear(
                    in_features=mlp_dim,
                    out_features=bandwidth * in_channel * self.reim * 2,
                ),
                nn.GLU(dim=-1),
            )
        )

    def reshape_output(self, mb):
        # print(mb.shape)
        batch, n_time, _ = mb.shape
        if self.complex_mask:
            mb = mb.reshape(
                batch,
                n_time,
                self.in_channel,
                self.bandwidth,
                self.reim
            ).contiguous()
            # print(mb.shape)
            mb = torch.view_as_complex(
                mb
            )  # (batch, n_time, in_channel, bandwidth)
        else:
            mb = mb.reshape(batch, n_time, self.in_channel, self.bandwidth)

        mb = torch.permute(
            mb,
            (0, 2, 3, 1)
        )  # (batch, in_channel, bandwidth, n_time)

        return mb

    def forward(self, qb):
        # qb = (batch, n_time, emb_dim)

        # if torch.any(torch.isnan(qb)):
        #     raise ValueError("qb0")

        qb = self.norm(qb)  # (batch, n_time, emb_dim)

        # if torch.any(torch.isnan(qb)):
        #     raise ValueError("qb1")

        qb = self.hidden(qb)  # (batch, n_time, mlp_dim)
        # if torch.any(torch.isnan(qb)):
        #     raise ValueError("qb2")
        mb = self.output(qb)  # (batch, n_time, bandwidth * in_channel * reim)
        # if torch.any(torch.isnan(mb)):
        #     raise ValueError("mb")
        mb = self.reshape_output(mb)  # (batch, in_channel, bandwidth, n_time)

        return mb


class MultAddNormMLP(NormMLP):
    def __init__(
            self,
            emb_dim: int,
            mlp_dim: int,
            bandwidth: int,
            in_channel: Optional[int],
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs=None,
            complex_mask: bool = True,
    ) -> None:
        super().__init__(
            emb_dim,
            mlp_dim,
            bandwidth,
            in_channel,
            hidden_activation,
            hidden_activation_kwargs,
            complex_mask,
        )

        self.output2 = torch.jit.script(
            nn.Sequential(
                nn.Linear(
                    in_features=mlp_dim,
                    out_features=bandwidth * in_channel * self.reim * 2,
                ),
                nn.GLU(dim=-1),
            )
        )

    def forward(self, qb):

        qb = self.norm(qb)  # (batch, n_time, emb_dim)
        qb = self.hidden(qb)  # (batch, n_time, mlp_dim)
        mmb = self.output(qb)  # (batch, n_time, bandwidth * in_channel * reim)
        mmb = self.reshape_output(mmb)  # (batch, in_channel, bandwidth, n_time)
        amb = self.output2(qb)  # (batch, n_time, bandwidth * in_channel * reim)
        amb = self.reshape_output(amb)  # (batch, in_channel, bandwidth, n_time)

        return mmb, amb


class MaskEstimationModuleSuperBase(nn.Module):
    pass


class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
    def __init__(
            self,
            band_specs: List[Tuple[float, float]],
            emb_dim: int,
            mlp_dim: int,
            in_channel: Optional[int],
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs: Optional[Dict] = None,
            complex_mask: bool = True,
            norm_mlp_cls: Type[nn.Module] = NormMLP,
            norm_mlp_kwargs: Optional[Dict] = None,
    ) -> None:
        super().__init__()

        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)

        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        if norm_mlp_kwargs is None:
            norm_mlp_kwargs = {}

        self.norm_mlp = nn.ModuleList(
            [
                norm_mlp_cls(
                    bandwidth=self.band_widths[b],
                    emb_dim=emb_dim,
                    mlp_dim=mlp_dim,
                    in_channel=in_channel,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                    **norm_mlp_kwargs,
                )
                for b in range(self.n_bands)
            ]
        )

    def compute_masks(self, q):
        batch, n_bands, n_time, emb_dim = q.shape

        masks = []

        for b, nmlp in enumerate(self.norm_mlp):
            # print(f"maskestim/{b:02d}")
            qb = q[:, b, :, :]
            mb = nmlp(qb)
            masks.append(mb)

        return masks


class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
    def __init__(
            self,
            in_channel: int,
            band_specs: List[Tuple[float, float]],
            freq_weights: List[torch.Tensor],
            n_freq: int,
            emb_dim: int,
            mlp_dim: int,
            cond_dim: int = 0,
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs: Optional[Dict] = None,
            complex_mask: bool = True,
            norm_mlp_cls: Type[nn.Module] = NormMLP,
            norm_mlp_kwargs: Optional[Dict] = None,
            use_freq_weights: bool = True,
    ) -> None:
        check_nonzero_bandwidth(band_specs)
        check_no_gap(band_specs)

        # if cond_dim > 0:
        #     raise NotImplementedError

        super().__init__(
            band_specs=band_specs,
            emb_dim=emb_dim + cond_dim,
            mlp_dim=mlp_dim,
            in_channel=in_channel,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            norm_mlp_cls=norm_mlp_cls,
            norm_mlp_kwargs=norm_mlp_kwargs,
        )

        self.n_freq = n_freq
        self.band_specs = band_specs
        self.in_channel = in_channel

        if freq_weights is not None:
            for i, fw in enumerate(freq_weights):
                self.register_buffer(f"freq_weights/{i}", fw)

            self.use_freq_weights = use_freq_weights
        else:
            self.use_freq_weights = False

        self.cond_dim = cond_dim

    def forward(self, q, cond=None):
        # q = (batch, n_bands, n_time, emb_dim)

        batch, n_bands, n_time, emb_dim = q.shape

        if cond is not None:
            # print(cond)
            if cond.ndim == 2:
                cond = cond[:, None, None, :].expand(-1, n_bands, n_time, -1)
            elif cond.ndim == 3:
                assert cond.shape[1] == n_time
            else:
                raise ValueError(f"Invalid cond shape: {cond.shape}")

            q = torch.cat([q, cond], dim=-1)
        elif self.cond_dim > 0:
            cond = torch.ones(
                (batch, n_bands, n_time, self.cond_dim),
                device=q.device,
                dtype=q.dtype,
            )
            q = torch.cat([q, cond], dim=-1)
        else:
            pass

        mask_list = self.compute_masks(
            q
        )  # [n_bands * (batch, in_channel, bandwidth, n_time)]

        masks = torch.zeros(
            (batch, self.in_channel, self.n_freq, n_time),
            device=q.device,
            dtype=mask_list[0].dtype,
        )

        for im, mask in enumerate(mask_list):
            fstart, fend = self.band_specs[im]
            if self.use_freq_weights:
                fw = self.get_buffer(f"freq_weights/{im}")[:, None]
                mask = mask * fw
            masks[:, :, fstart:fend, :] += mask

        return masks


class MaskEstimationModule(OverlappingMaskEstimationModule):
    def __init__(
            self,
            band_specs: List[Tuple[float, float]],
            emb_dim: int,
            mlp_dim: int,
            in_channel: Optional[int],
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs: Optional[Dict] = None,
            complex_mask: bool = True,
            **kwargs,
    ) -> None:
        check_nonzero_bandwidth(band_specs)
        check_no_gap(band_specs)
        check_no_overlap(band_specs)
        super().__init__(
            in_channel=in_channel,
            band_specs=band_specs,
            freq_weights=None,
            n_freq=None,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

    def forward(self, q, cond=None):
        # q = (batch, n_bands, n_time, emb_dim)

        masks = self.compute_masks(
            q
        )  # [n_bands * (batch, in_channel, bandwidth, n_time)]

        # TODO: currently this requires band specs to have no gap and no overlap
        masks = torch.concat(
            masks,
            dim=2
        )  # (batch, in_channel, n_freq, n_time)

        return masks
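
A shape sketch for a single `NormMLP` band decoder (toy sizes): it maps the per-band embedding sequence to a complex time-frequency mask for that band's bins:

import torch

nmlp = NormMLP(emb_dim=32, mlp_dim=64, bandwidth=256, in_channel=2)
qb = torch.randn(1, 50, 32)   # (batch, n_time, emb_dim) for one band
mb = nmlp(qb)
print(mb.shape, mb.dtype)     # (1, 2, 256, 50), torch.complex64
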
separator/models/bandit/core/model/bsrnn/tfmodel.py
ADDED
|
@@ -0,0 +1,317 @@
import warnings

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.modules import rnn

import torch.backends.cuda


class TimeFrequencyModellingModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()


class ResidualRNN(nn.Module):
    def __init__(
            self,
            emb_dim: int,
            rnn_dim: int,
            bidirectional: bool = True,
            rnn_type: str = "LSTM",
            use_batch_trick: bool = True,
            use_layer_norm: bool = True,
    ) -> None:
        # n_group is the size of the 2nd dim
        super().__init__()

        self.use_layer_norm = use_layer_norm
        if use_layer_norm:
            self.norm = nn.LayerNorm(emb_dim)
        else:
            self.norm = nn.GroupNorm(num_groups=emb_dim, num_channels=emb_dim)

        self.rnn = rnn.__dict__[rnn_type](
            input_size=emb_dim,
            hidden_size=rnn_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=bidirectional,
        )

        self.fc = nn.Linear(
            in_features=rnn_dim * (2 if bidirectional else 1),
            out_features=emb_dim
        )

        self.use_batch_trick = use_batch_trick
        if not self.use_batch_trick:
            warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")

    def forward(self, z):
        # z = (batch, n_uncrossed, n_across, emb_dim)
        z0 = torch.clone(z)

        if self.use_layer_norm:
            z = self.norm(z)  # (batch, n_uncrossed, n_across, emb_dim)
        else:
            # GroupNorm normalizes over the channel dim, so move emb_dim to
            # position 1 and back.
            z = torch.permute(z, (0, 3, 1, 2))  # (batch, emb_dim, n_uncrossed, n_across)
            z = self.norm(z)  # (batch, emb_dim, n_uncrossed, n_across)
            z = torch.permute(z, (0, 2, 3, 1))  # (batch, n_uncrossed, n_across, emb_dim)

        batch, n_uncrossed, n_across, emb_dim = z.shape

        if self.use_batch_trick:
            # Fold the uncrossed axis into the batch so a single RNN call
            # processes all of its slices at once.
            z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
            z = self.rnn(z.contiguous())[0]  # (batch * n_uncrossed, n_across, dir_rnn_dim)
            z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))
            # (batch, n_uncrossed, n_across, dir_rnn_dim)
        else:
            # Note: this is EXTREMELY SLOW
            zlist = []
            for i in range(n_uncrossed):
                zi = self.rnn(z[:, i, :, :])[0]  # (batch, n_across, dir_rnn_dim)
                zlist.append(zi)

            z = torch.stack(zlist, dim=1)  # (batch, n_uncrossed, n_across, dir_rnn_dim)

        z = self.fc(z)  # (batch, n_uncrossed, n_across, emb_dim)

        z = z + z0

        return z


class SeqBandModellingModule(TimeFrequencyModellingModule):
    def __init__(
            self,
            n_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            bidirectional: bool = True,
            rnn_type: str = "LSTM",
            parallel_mode=False,
    ) -> None:
        super().__init__()
        self.seqband = nn.ModuleList([])

        if parallel_mode:
            # One (time, band) pair of ResidualRNNs per module, applied in
            # parallel and summed in forward().
            for _ in range(n_modules):
                self.seqband.append(
                    nn.ModuleList(
                        [
                            ResidualRNN(
                                emb_dim=emb_dim,
                                rnn_dim=rnn_dim,
                                bidirectional=bidirectional,
                                rnn_type=rnn_type,
                            ),
                            ResidualRNN(
                                emb_dim=emb_dim,
                                rnn_dim=rnn_dim,
                                bidirectional=bidirectional,
                                rnn_type=rnn_type,
                            ),
                        ]
                    )
                )
        else:
            # Sequential mode: 2 * n_modules blocks; the transpose in
            # forward() alternates each block between the time and band axes.
            for _ in range(2 * n_modules):
                self.seqband.append(
                    ResidualRNN(
                        emb_dim=emb_dim,
                        rnn_dim=rnn_dim,
                        bidirectional=bidirectional,
                        rnn_type=rnn_type,
                    )
                )

        self.parallel_mode = parallel_mode

    def forward(self, z):
        # z = (batch, n_bands, n_time, emb_dim)
        if self.parallel_mode:
            for sbm_pair in self.seqband:
                sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
                zt = sbm_t(z)  # (batch, n_bands, n_time, emb_dim)
                zf = sbm_f(z.transpose(1, 2))  # (batch, n_time, n_bands, emb_dim)
                z = zt + zf.transpose(1, 2)
        else:
            for sbm in self.seqband:
                z = sbm(z)
                # Swap band and time axes so the next block models the other
                # dimension:
                # (batch, n_bands, n_time, emb_dim)
                #   <--> (batch, n_time, n_bands, emb_dim)
                z = z.transpose(1, 2)

        q = z
        return q  # (batch, n_bands, n_time, emb_dim)


class ResidualTransformer(nn.Module):
    def __init__(
            self,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            bidirectional: bool = True,
            dropout: float = 0.0,
    ) -> None:
        # n_group is the size of the 2nd dim
        super().__init__()

        self.tf = nn.TransformerEncoderLayer(
            d_model=emb_dim,
            nhead=4,
            dim_feedforward=rnn_dim,
            batch_first=True
        )

        self.is_causal = not bidirectional
        self.dropout = dropout

    def forward(self, z):
        batch, n_uncrossed, n_across, emb_dim = z.shape
        z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
        z = self.tf(z, is_causal=self.is_causal)
        z = torch.reshape(z, (batch, n_uncrossed, n_across, emb_dim))

        return z


class TransformerTimeFreqModule(TimeFrequencyModellingModule):
    def __init__(
            self,
            n_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            bidirectional: bool = True,
            dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(emb_dim)
        self.seqband = nn.ModuleList([])

        for _ in range(2 * n_modules):
            self.seqband.append(
                ResidualTransformer(
                    emb_dim=emb_dim,
                    rnn_dim=rnn_dim,
                    bidirectional=bidirectional,
                    dropout=dropout,
                )
            )

    def forward(self, z):
        # z = (batch, n_bands, n_time, emb_dim)
        z = self.norm(z)  # (batch, n_bands, n_time, emb_dim)

        for sbm in self.seqband:
            z = sbm(z)
            # Alternate between band-wise and time-wise modelling:
            # (batch, n_bands, n_time, emb_dim)
            #   <--> (batch, n_time, n_bands, emb_dim)
            z = z.transpose(1, 2)

        q = z
        return q  # (batch, n_bands, n_time, emb_dim)


class ResidualConvolution(nn.Module):
    def __init__(
            self,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            bidirectional: bool = True,
            dropout: float = 0.0,
    ) -> None:
        # n_group is the size of the 2nd dim
        super().__init__()
        self.norm = nn.InstanceNorm2d(emb_dim, affine=True)

        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=emb_dim,
                out_channels=rnn_dim,
                kernel_size=(3, 3),
                padding="same",
                stride=(1, 1),
            ),
            nn.Tanhshrink()
        )

        self.is_causal = not bidirectional
        self.dropout = dropout

        self.fc = nn.Conv2d(
            in_channels=rnn_dim,
            out_channels=emb_dim,
            kernel_size=(1, 1),
            padding="same",
            stride=(1, 1),
        )

    def forward(self, z):
        # z = (batch, emb_dim, n_uncrossed, n_across); channels-first, as
        # permuted by ConvolutionalTimeFreqModule below.
        z0 = torch.clone(z)

        z = self.norm(z)  # (batch, emb_dim, n_uncrossed, n_across)
        z = self.conv(z)  # (batch, rnn_dim, n_uncrossed, n_across)
        z = self.fc(z)    # (batch, emb_dim, n_uncrossed, n_across)
        z = z + z0

        return z


class ConvolutionalTimeFreqModule(TimeFrequencyModellingModule):
    def __init__(
            self,
            n_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            bidirectional: bool = True,
            dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.seqband = torch.jit.script(nn.Sequential(
            *[ResidualConvolution(
                emb_dim=emb_dim,
                rnn_dim=rnn_dim,
                bidirectional=bidirectional,
                dropout=dropout,
            ) for _ in range(2 * n_modules)]))

    def forward(self, z):
        # z = (batch, n_bands, n_time, emb_dim)
        z = torch.permute(z, (0, 3, 1, 2))  # (batch, emb_dim, n_bands, n_time)

        z = self.seqband(z)  # (batch, emb_dim, n_bands, n_time)

        z = torch.permute(z, (0, 2, 3, 1))  # (batch, n_bands, n_time, emb_dim)

        return z
separator/models/bandit/core/model/bsrnn/utils.py
ADDED
@@ -0,0 +1,583 @@
import os
from abc import abstractmethod
from typing import Any, Callable

import numpy as np
import torch
from librosa import hz_to_midi, midi_to_hz
from torch import Tensor
from torchaudio import functional as taF
from spafe.fbanks import bark_fbanks
from spafe.utils.converters import erb2hz, hz2bark, hz2erb
from torchaudio.functional.functional import _create_triangular_filterbank


def band_widths_from_specs(band_specs):
    return [e - i for i, e in band_specs]


def check_nonzero_bandwidth(band_specs):
    # pprint(band_specs)
    for fstart, fend in band_specs:
        if fend - fstart <= 0:
            raise ValueError("Bands cannot be zero-width")


def check_no_overlap(band_specs):
    # Band ends are exclusive, so touching bands (prev end == next start)
    # are allowed.
    fend_prev = -1
    for fstart_curr, fend_curr in band_specs:
        if fstart_curr < fend_prev:
            raise ValueError("Bands cannot overlap")
        fend_prev = fend_curr


def check_no_gap(band_specs):
    fstart, _ = band_specs[0]
    assert fstart == 0

    fend_prev = -1
    for fstart_curr, fend_curr in band_specs:
        if fstart_curr - fend_prev > 1:
            raise ValueError("Bands cannot leave gap")
        fend_prev = fend_curr


class BandsplitSpecification:
    def __init__(self, nfft: int, fs: int) -> None:
        self.fs = fs
        self.nfft = nfft
        self.nyquist = fs / 2
        self.max_index = nfft // 2 + 1

        self.split500 = self.hertz_to_index(500)
        self.split1k = self.hertz_to_index(1000)
        self.split2k = self.hertz_to_index(2000)
        self.split4k = self.hertz_to_index(4000)
        self.split8k = self.hertz_to_index(8000)
        self.split16k = self.hertz_to_index(16000)
        self.split20k = self.hertz_to_index(20000)

        self.above20k = [(self.split20k, self.max_index)]
        self.above16k = [(self.split16k, self.split20k)] + self.above20k

    def index_to_hertz(self, index: int):
        return index * self.fs / self.nfft

    def hertz_to_index(self, hz: float, round: bool = True):
        index = hz * self.nfft / self.fs

        if round:
            index = int(np.round(index))

        return index

    def get_band_specs_with_bandwidth(
            self,
            start_index,
            end_index,
            bandwidth_hz
    ):
        band_specs = []
        lower = start_index

        while lower < end_index:
            upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz)))
            upper = min(upper, end_index)

            band_specs.append((lower, upper))
            lower = upper

        return band_specs

    @abstractmethod
    def get_band_specs(self):
        raise NotImplementedError


class VocalBandsplitSpecification(BandsplitSpecification):
    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        super().__init__(nfft=nfft, fs=fs)

        self.version = version

    def get_band_specs(self):
        return getattr(self, f"version{self.version}")()

    def version1(self):
        return self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.max_index, bandwidth_hz=1000
        )

    def version2(self):
        below16k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k,
            end_index=self.split20k,
            bandwidth_hz=2000
        )

        return below16k + below20k + self.above20k

    def version3(self):
        below8k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k,
            end_index=self.split16k,
            bandwidth_hz=2000
        )

        return below8k + below16k + self.above16k

    def version4(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k,
            end_index=self.split8k,
            bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k,
            end_index=self.split16k,
            bandwidth_hz=2000
        )

        return below1k + below8k + below16k + self.above16k

    def version5(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k,
            end_index=self.split16k,
            bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k,
            end_index=self.split20k,
            bandwidth_hz=2000
        )
        return below1k + below16k + below20k + self.above20k

    def version6(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k,
            end_index=self.split4k,
            bandwidth_hz=500
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k,
            end_index=self.split8k,
            bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k,
            end_index=self.split16k,
            bandwidth_hz=2000
        )
        return below1k + below4k + below8k + below16k + self.above16k

    def version7(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k,
            end_index=self.split4k,
            bandwidth_hz=250
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k,
            end_index=self.split8k,
            bandwidth_hz=500
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k,
            end_index=self.split16k,
            bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k,
            end_index=self.split20k,
            bandwidth_hz=2000
        )
        return below1k + below4k + below8k + below16k + below20k + self.above20k


class OtherBandsplitSpecification(VocalBandsplitSpecification):
    def __init__(self, nfft: int, fs: int) -> None:
        super().__init__(nfft=nfft, fs=fs, version="7")


class BassBandsplitSpecification(BandsplitSpecification):
    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        below500 = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split500, bandwidth_hz=50
        )
        below1k = self.get_band_specs_with_bandwidth(
            start_index=self.split500,
            end_index=self.split1k,
            bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k,
            end_index=self.split4k,
            bandwidth_hz=500
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k,
            end_index=self.split8k,
            bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k,
            end_index=self.split16k,
            bandwidth_hz=2000
        )
        above16k = [(self.split16k, self.max_index)]

        return below500 + below1k + below4k + below8k + below16k + above16k


class DrumBandsplitSpecification(BandsplitSpecification):
    def __init__(self, nfft: int, fs: int) -> None:
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=50
        )
        below2k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k,
            end_index=self.split2k,
            bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split2k,
            end_index=self.split4k,
            bandwidth_hz=250
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k,
            end_index=self.split8k,
            bandwidth_hz=500
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k,
            end_index=self.split16k,
            bandwidth_hz=1000
        )
        above16k = [(self.split16k, self.max_index)]

        return below1k + below2k + below4k + below8k + below16k + above16k


class PerceptualBandsplitSpecification(BandsplitSpecification):
    def __init__(
            self,
            nfft: int,
            fs: int,
            fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
            n_bands: int,
            f_min: float = 0.0,
            f_max: float = None
    ) -> None:
        super().__init__(nfft=nfft, fs=fs)
        self.n_bands = n_bands
        if f_max is None:
            f_max = fs / 2

        self.filterbank = fbank_fn(
            n_bands, fs, f_min, f_max, self.max_index
        )

        weight_per_bin = torch.sum(
            self.filterbank,
            dim=0,
            keepdim=True
        )  # (1, n_freqs)
        normalized_mel_fb = self.filterbank / weight_per_bin  # (n_mels, n_freqs)

        freq_weights = []
        band_specs = []
        for i in range(self.n_bands):
            active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
            if isinstance(active_bins, int):
                active_bins = (active_bins, active_bins)
            if len(active_bins) == 0:
                continue
            start_index = active_bins[0]
            end_index = active_bins[-1] + 1
            band_specs.append((start_index, end_index))
            freq_weights.append(normalized_mel_fb[i, start_index:end_index])

        self.freq_weights = freq_weights
        self.band_specs = band_specs

    def get_band_specs(self):
        return self.band_specs

    def get_freq_weights(self):
        return self.freq_weights

    def save_to_file(self, dir_path: str) -> None:
        os.makedirs(dir_path, exist_ok=True)

        import pickle

        with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
            pickle.dump(
                {
                    "band_specs": self.band_specs,
                    "freq_weights": self.freq_weights,
                    "filterbank": self.filterbank,
                },
                f,
            )


def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
    fb = taF.melscale_fbanks(
        n_mels=n_bands,
        sample_rate=fs,
        f_min=f_min,
        f_max=f_max,
        n_freqs=n_freqs,
    ).T

    fb[0, 0] = 1.0

    return fb


class MelBandsplitSpecification(PerceptualBandsplitSpecification):
    def __init__(
            self,
            nfft: int,
            fs: int,
            n_bands: int,
            f_min: float = 0.0,
            f_max: float = None
    ) -> None:
        super().__init__(fbank_fn=mel_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)


def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs,
                       scale="constant"):
    nfft = 2 * (n_freqs - 1)
    df = fs / nfft
    # init freqs
    f_max = f_max or fs / 2
    f_min = f_min or 0
    f_min = fs / nfft

    n_octaves = np.log2(f_max / f_min)
    n_octaves_per_band = n_octaves / n_bands
    bandwidth_mult = np.power(2.0, n_octaves_per_band)

    low_midi = max(0, hz_to_midi(f_min))
    high_midi = hz_to_midi(f_max)
    midi_points = np.linspace(low_midi, high_midi, n_bands)
    hz_pts = midi_to_hz(midi_points)

    low_pts = hz_pts / bandwidth_mult
    high_pts = hz_pts * bandwidth_mult

    low_bins = np.floor(low_pts / df).astype(int)
    high_bins = np.ceil(high_pts / df).astype(int)

    fb = np.zeros((n_bands, n_freqs))

    for i in range(n_bands):
        fb[i, low_bins[i]:high_bins[i] + 1] = 1.0

    fb[0, :low_bins[0]] = 1.0
    fb[-1, high_bins[-1] + 1:] = 1.0

    return torch.as_tensor(fb)


class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
    def __init__(
            self,
            nfft: int,
            fs: int,
            n_bands: int,
            f_min: float = 0.0,
            f_max: float = None
    ) -> None:
        super().__init__(fbank_fn=musical_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)


def bark_filterbank(
        n_bands, fs, f_min, f_max, n_freqs
):
    nfft = 2 * (n_freqs - 1)
    fb, _ = bark_fbanks.bark_filter_banks(
        nfilts=n_bands,
        nfft=nfft,
        fs=fs,
        low_freq=f_min,
        high_freq=f_max,
        scale="constant"
    )

    return torch.as_tensor(fb)


class BarkBandsplitSpecification(PerceptualBandsplitSpecification):
    def __init__(
            self,
            nfft: int,
            fs: int,
            n_bands: int,
            f_min: float = 0.0,
            f_max: float = None
    ) -> None:
        super().__init__(fbank_fn=bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)


def triangular_bark_filterbank(
        n_bands, fs, f_min, f_max, n_freqs
):
    all_freqs = torch.linspace(0, fs // 2, n_freqs)

    # calculate bark freq bins
    m_min = hz2bark(f_min)
    m_max = hz2bark(f_max)

    m_pts = torch.linspace(m_min, m_max, n_bands + 2)
    f_pts = 600 * torch.sinh(m_pts / 6)

    # create filterbank
    fb = _create_triangular_filterbank(all_freqs, f_pts)

    fb = fb.T

    first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
    first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]

    fb[first_active_band, :first_active_bin] = 1.0

    return fb


class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification):
    def __init__(
            self,
            nfft: int,
            fs: int,
            n_bands: int,
            f_min: float = 0.0,
            f_max: float = None
    ) -> None:
        super().__init__(fbank_fn=triangular_bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)


def minibark_filterbank(
        n_bands, fs, f_min, f_max, n_freqs
):
    fb = bark_filterbank(
        n_bands,
        fs,
        f_min,
        f_max,
        n_freqs
    )

    fb[fb < np.sqrt(0.5)] = 0.0

    return fb


class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification):
    def __init__(
            self,
            nfft: int,
            fs: int,
            n_bands: int,
            f_min: float = 0.0,
            f_max: float = None
    ) -> None:
        super().__init__(fbank_fn=minibark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)


def erb_filterbank(
        n_bands: int,
        fs: int,
        f_min: float,
        f_max: float,
        n_freqs: int,
) -> Tensor:
    # freq bins
    A = (1000 * np.log(10)) / (24.7 * 4.37)
    all_freqs = torch.linspace(0, fs // 2, n_freqs)

    # calculate ERB freq bins
    m_min = hz2erb(f_min)
    m_max = hz2erb(f_max)

    m_pts = torch.linspace(m_min, m_max, n_bands + 2)
    f_pts = (torch.pow(10, (m_pts / A)) - 1) / 0.00437

    # create filterbank
    fb = _create_triangular_filterbank(all_freqs, f_pts)

    fb = fb.T

    first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
    first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]

    fb[first_active_band, :first_active_bin] = 1.0

    return fb


class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification):
    def __init__(
            self,
            nfft: int,
            fs: int,
            n_bands: int,
            f_min: float = 0.0,
            f_max: float = None
    ) -> None:
        super().__init__(fbank_fn=erb_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)


if __name__ == "__main__":
    import pandas as pd

    band_defs = []

    for bands in [VocalBandsplitSpecification]:
        band_name = bands.__name__.replace("BandsplitSpecification", "")

        mbs = bands(nfft=2048, fs=44100).get_band_specs()

        for i, (f_min, f_max) in enumerate(mbs):
            band_defs.append({
                "band": band_name,
                "band_index": i,
                "f_min": f_min,
                "f_max": f_max
            })

    df = pd.DataFrame(band_defs)
    df.to_csv("vox7bands.csv", index=False)
separator/models/bandit/core/model/bsrnn/wrapper.py
ADDED
@@ -0,0 +1,882 @@
from pprint import pprint
from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import nn

from models.bandit.core.model._spectral import _SpectralComponent
from models.bandit.core.model.bsrnn.utils import (
    BarkBandsplitSpecification, BassBandsplitSpecification,
    DrumBandsplitSpecification,
    EquivalentRectangularBandsplitSpecification, MelBandsplitSpecification,
    MusicalBandsplitSpecification, OtherBandsplitSpecification,
    TriangularBarkBandsplitSpecification, VocalBandsplitSpecification,
)
from .core import (
    MultiSourceMultiMaskBandSplitCoreConv,
    MultiSourceMultiMaskBandSplitCoreRNN,
    MultiSourceMultiMaskBandSplitCoreTransformer,
    MultiSourceMultiPatchingMaskBandSplitCoreRNN, SingleMaskBandsplitCoreRNN,
    SingleMaskBandsplitCoreTransformer,
)

import pytorch_lightning as pl


def get_band_specs(band_specs, n_fft, fs, n_bands=None):
    if band_specs in ["dnr:speech", "dnr:vox7", "musdb:vocals", "musdb:vox7"]:
        bsm = VocalBandsplitSpecification(
            nfft=n_fft, fs=fs
        ).get_band_specs()
        freq_weights = None
        overlapping_band = False
    elif "tribark" in band_specs:
        assert n_bands is not None
        specs = TriangularBarkBandsplitSpecification(
            nfft=n_fft,
            fs=fs,
            n_bands=n_bands
        )
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    elif "bark" in band_specs:
        assert n_bands is not None
        specs = BarkBandsplitSpecification(
            nfft=n_fft,
            fs=fs,
            n_bands=n_bands
        )
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    elif "erb" in band_specs:
        assert n_bands is not None
        specs = EquivalentRectangularBandsplitSpecification(
            nfft=n_fft,
            fs=fs,
            n_bands=n_bands
        )
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    elif "musical" in band_specs:
        assert n_bands is not None
        specs = MusicalBandsplitSpecification(
            nfft=n_fft,
            fs=fs,
            n_bands=n_bands
        )
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    elif band_specs == "dnr:mel" or "mel" in band_specs:
        assert n_bands is not None
        specs = MelBandsplitSpecification(
            nfft=n_fft,
            fs=fs,
            n_bands=n_bands
        )
        bsm = specs.get_band_specs()
        freq_weights = specs.get_freq_weights()
        overlapping_band = True
    else:
        raise NameError

    return bsm, freq_weights, overlapping_band


def get_band_specs_map(band_specs_map, n_fft, fs, n_bands=None):
    if band_specs_map == "musdb:all":
        bsm = {
            "vocals": VocalBandsplitSpecification(
                nfft=n_fft, fs=fs
            ).get_band_specs(),
            "drums": DrumBandsplitSpecification(
                nfft=n_fft, fs=fs
            ).get_band_specs(),
            "bass": BassBandsplitSpecification(
                nfft=n_fft, fs=fs
            ).get_band_specs(),
            "other": OtherBandsplitSpecification(
                nfft=n_fft, fs=fs
            ).get_band_specs(),
        }
        freq_weights = None
        overlapping_band = False
    elif band_specs_map == "dnr:vox7":
        bsm_, freq_weights, overlapping_band = get_band_specs(
            "dnr:speech", n_fft, fs, n_bands
        )
        bsm = {
            "speech": bsm_,
            "music": bsm_,
            "effects": bsm_
        }
    elif "dnr:vox7:" in band_specs_map:
        stem = band_specs_map.split(":")[-1]
        bsm_, freq_weights, overlapping_band = get_band_specs(
            "dnr:speech", n_fft, fs, n_bands
        )
        bsm = {
            stem: bsm_
        }
    else:
        raise NameError

    return bsm, freq_weights, overlapping_band
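# Illustrative usage (an assumption-labelled sketch, not part of the original
# file): get_band_specs resolves a band-spec string into the triple
# (band_specs, freq_weights, overlapping_band) consumed by the cores below,
# e.g.
#
#     bsm, fw, overlap = get_band_specs("musical", n_fft=2048, fs=44100,
#                                       n_bands=64)
#     # bsm: 64 (start_bin, end_bin) ranges; fw: per-band bin weights;
#     # overlap: True for the perceptual (filterbank-based) schemes
#
# while get_band_specs_map does the same per stem: "musdb:all" yields a dict
# keyed by "vocals", "drums", "bass", "other".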

class BandSplitWrapperBase(pl.LightningModule):
    bsrnn: nn.Module

    def __init__(self, **kwargs):
        super().__init__()


class SingleMaskMultiSourceBandSplitBase(
        BandSplitWrapperBase,
        _SpectralComponent
):
    def __init__(
            self,
            band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
            fs: int = 44100,
            n_fft: int = 2048,
            win_length: Optional[int] = 2048,
            hop_length: int = 512,
            window_fn: str = "hann_window",
            wkwargs: Optional[Dict] = None,
            power: Optional[int] = None,
            center: bool = True,
            normalized: bool = True,
            pad_mode: str = "constant",
            onesided: bool = True,
            n_bands: int = None,
    ) -> None:
        super().__init__(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        if isinstance(band_specs_map, str):
            self.band_specs_map, self.freq_weights, self.overlapping_band = get_band_specs_map(
                band_specs_map,
                n_fft,
                fs,
                n_bands=n_bands
            )

        self.stems = list(self.band_specs_map.keys())

    def forward(self, batch):
        audio = batch["audio"]

        with torch.no_grad():
            batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio}

        X = batch["spectrogram"]["mixture"]
        length = batch["audio"]["mixture"].shape[-1]

        output = {"spectrogram": {}, "audio": {}}

        for stem, bsrnn in self.bsrnn.items():
            S = bsrnn(X)
            s = self.istft(S, length)
            output["spectrogram"][stem] = S
            output["audio"][stem] = s

        return batch, output


class MultiMaskMultiSourceBandSplitBase(
        BandSplitWrapperBase,
        _SpectralComponent
):
    def __init__(
            self,
            stems: List[str],
            band_specs: Union[str, List[Tuple[float, float]]],
            fs: int = 44100,
            n_fft: int = 2048,
            win_length: Optional[int] = 2048,
            hop_length: int = 512,
            window_fn: str = "hann_window",
            wkwargs: Optional[Dict] = None,
            power: Optional[int] = None,
            center: bool = True,
            normalized: bool = True,
            pad_mode: str = "constant",
            onesided: bool = True,
            n_bands: int = None,
    ) -> None:
        super().__init__(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        if isinstance(band_specs, str):
            self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
                band_specs,
                n_fft,
                fs,
                n_bands
            )

        self.stems = stems

    def forward(self, batch):
        # with torch.no_grad():
        audio = batch["audio"]
        cond = batch.get("condition", None)
        with torch.no_grad():
            batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in audio}

        X = batch["spectrogram"]["mixture"]
        length = batch["audio"]["mixture"].shape[-1]

        output = self.bsrnn(X, cond=cond)
        output["audio"] = {}

        for stem, S in output["spectrogram"].items():
            s = self.istft(S, length)
            output["audio"][stem] = s

        return batch, output


class MultiMaskMultiSourceBandSplitBaseSimple(
        BandSplitWrapperBase,
        _SpectralComponent
):
    def __init__(
            self,
            stems: List[str],
            band_specs: Union[str, List[Tuple[float, float]]],
            fs: int = 44100,
            n_fft: int = 2048,
            win_length: Optional[int] = 2048,
            hop_length: int = 512,
            window_fn: str = "hann_window",
            wkwargs: Optional[Dict] = None,
            power: Optional[int] = None,
            center: bool = True,
            normalized: bool = True,
            pad_mode: str = "constant",
            onesided: bool = True,
            n_bands: int = None,
    ) -> None:
        super().__init__(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        if isinstance(band_specs, str):
            self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
                band_specs,
                n_fft,
                fs,
                n_bands
            )

        self.stems = stems

    def forward(self, batch):
        with torch.no_grad():
            X = self.stft(batch)
        length = batch.shape[-1]
        output = self.bsrnn(X, cond=None)
        res = []
        for stem, S in output["spectrogram"].items():
            s = self.istft(S, length)
            res.append(s)
        res = torch.stack(res, dim=1)
        return res


class SingleMaskMultiSourceBandSplitRNN(SingleMaskMultiSourceBandSplitBase):
    def __init__(
            self,
            in_channel: int,
            band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
            fs: int = 44100,
            require_no_overlap: bool = False,
            require_no_gap: bool = True,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
            n_sqm_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            bidirectional: bool = True,
            rnn_type: str = "LSTM",
            mlp_dim: int = 512,
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs: Optional[Dict] = None,
            complex_mask: bool = True,
            n_fft: int = 2048,
            win_length: Optional[int] = 2048,
            hop_length: int = 512,
            window_fn: str = "hann_window",
            wkwargs: Optional[Dict] = None,
            power: Optional[int] = None,
            center: bool = True,
            normalized: bool = True,
            pad_mode: str = "constant",
            onesided: bool = True,
    ) -> None:
        super().__init__(
            band_specs_map=band_specs_map,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        self.bsrnn = nn.ModuleDict(
            {
                src: SingleMaskBandsplitCoreRNN(
                    band_specs=specs,
                    in_channel=in_channel,
                    require_no_overlap=require_no_overlap,
                    require_no_gap=require_no_gap,
                    normalize_channel_independently=normalize_channel_independently,
                    treat_channel_as_feature=treat_channel_as_feature,
                    n_sqm_modules=n_sqm_modules,
                    emb_dim=emb_dim,
                    rnn_dim=rnn_dim,
                    bidirectional=bidirectional,
                    rnn_type=rnn_type,
                    mlp_dim=mlp_dim,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                )
                for src, specs in self.band_specs_map.items()
            }
        )


class SingleMaskMultiSourceBandSplitTransformer(
        SingleMaskMultiSourceBandSplitBase
):
    def __init__(
            self,
            in_channel: int,
            band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
            fs: int = 44100,
            require_no_overlap: bool = False,
            require_no_gap: bool = True,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
            n_sqm_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            bidirectional: bool = True,
            tf_dropout: float = 0.0,
            mlp_dim: int = 512,
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs: Optional[Dict] = None,
            complex_mask: bool = True,
            n_fft: int = 2048,
            win_length: Optional[int] = 2048,
            hop_length: int = 512,
            window_fn: str = "hann_window",
            wkwargs: Optional[Dict] = None,
            power: Optional[int] = None,
            center: bool = True,
            normalized: bool = True,
            pad_mode: str = "constant",
            onesided: bool = True,
    ) -> None:
        super().__init__(
            band_specs_map=band_specs_map,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        self.bsrnn = nn.ModuleDict(
            {
                src: SingleMaskBandsplitCoreTransformer(
                    band_specs=specs,
                    in_channel=in_channel,
                    require_no_overlap=require_no_overlap,
                    require_no_gap=require_no_gap,
                    normalize_channel_independently=normalize_channel_independently,
                    treat_channel_as_feature=treat_channel_as_feature,
                    n_sqm_modules=n_sqm_modules,
                    emb_dim=emb_dim,
                    rnn_dim=rnn_dim,
                    bidirectional=bidirectional,
                    tf_dropout=tf_dropout,
                    mlp_dim=mlp_dim,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                )
                for src, specs in self.band_specs_map.items()
            }
        )


class MultiMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
    def __init__(
            self,
            in_channel: int,
            stems: List[str],
            band_specs: Union[str, List[Tuple[float, float]]],
            fs: int = 44100,
            require_no_overlap: bool = False,
            require_no_gap: bool = True,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
            n_sqm_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            cond_dim: int = 0,
            bidirectional: bool = True,
            rnn_type: str = "LSTM",
            mlp_dim: int = 512,
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs: Optional[Dict] = None,
            complex_mask: bool = True,
            n_fft: int = 2048,
            win_length: Optional[int] = 2048,
            hop_length: int = 512,
            window_fn: str = "hann_window",
            wkwargs: Optional[Dict] = None,
            power: Optional[int] = None,
            center: bool = True,
            normalized: bool = True,
            pad_mode: str = "constant",
            onesided: bool = True,
            n_bands: int = None,
            use_freq_weights: bool = True,
            normalize_input: bool = False,
            mult_add_mask: bool = False,
            freeze_encoder: bool = False,
    ) -> None:
        super().__init__(
            stems=stems,
            band_specs=band_specs,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            n_bands=n_bands,
        )

        self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask
        )

        self.normalize_input = normalize_input
        self.cond_dim = cond_dim

        if freeze_encoder:
            for param in self.bsrnn.band_split.parameters():
                param.requires_grad = False

            for param in self.bsrnn.tf_model.parameters():
                param.requires_grad = False


class MultiMaskMultiSourceBandSplitRNNSimple(MultiMaskMultiSourceBandSplitBaseSimple):
    def __init__(
            self,
            in_channel: int,
            stems: List[str],
            band_specs: Union[str, List[Tuple[float, float]]],
            fs: int = 44100,
            require_no_overlap: bool = False,
            require_no_gap: bool = True,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
            n_sqm_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            cond_dim: int = 0,
            bidirectional: bool = True,
            rnn_type: str = "LSTM",
            mlp_dim: int = 512,
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs: Optional[Dict] = None,
            complex_mask: bool = True,
            n_fft: int = 2048,
            win_length: Optional[int] = 2048,
            hop_length: int = 512,
            window_fn: str = "hann_window",
            wkwargs: Optional[Dict] = None,
            power: Optional[int] = None,
            center: bool = True,
            normalized: bool = True,
            pad_mode: str = "constant",
            onesided: bool = True,
            n_bands: int = None,
            use_freq_weights: bool = True,
            normalize_input: bool = False,
            mult_add_mask: bool = False,
            freeze_encoder: bool = False,
    ) -> None:
        super().__init__(
            stems=stems,
            band_specs=band_specs,
            fs=fs,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            n_bands=n_bands,
        )

        self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
            stems=stems,
            band_specs=self.band_specs,
            in_channel=in_channel,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            cond_dim=cond_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            overlapping_band=self.overlapping_band,
            freq_weights=self.freq_weights,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
            mult_add_mask=mult_add_mask
        )

        self.normalize_input = normalize_input
        self.cond_dim = cond_dim

        if freeze_encoder:
            for param in self.bsrnn.band_split.parameters():
                param.requires_grad = False

            for param in self.bsrnn.tf_model.parameters():
                param.requires_grad = False


class MultiMaskMultiSourceBandSplitTransformer(
        MultiMaskMultiSourceBandSplitBase
):
    def __init__(
            self,
            in_channel: int,
            stems: List[str],
            band_specs: Union[str, List[Tuple[float, float]]],
            fs: int = 44100,
|
| 651 |
+
require_no_overlap: bool = False,
|
| 652 |
+
require_no_gap: bool = True,
|
| 653 |
+
normalize_channel_independently: bool = False,
|
| 654 |
+
treat_channel_as_feature: bool = True,
|
| 655 |
+
n_sqm_modules: int = 12,
|
| 656 |
+
emb_dim: int = 128,
|
| 657 |
+
rnn_dim: int = 256,
|
| 658 |
+
cond_dim: int = 0,
|
| 659 |
+
bidirectional: bool = True,
|
| 660 |
+
rnn_type: str = "LSTM",
|
| 661 |
+
mlp_dim: int = 512,
|
| 662 |
+
hidden_activation: str = "Tanh",
|
| 663 |
+
hidden_activation_kwargs: Optional[Dict] = None,
|
| 664 |
+
complex_mask: bool = True,
|
| 665 |
+
n_fft: int = 2048,
|
| 666 |
+
win_length: Optional[int] = 2048,
|
| 667 |
+
hop_length: int = 512,
|
| 668 |
+
window_fn: str = "hann_window",
|
| 669 |
+
wkwargs: Optional[Dict] = None,
|
| 670 |
+
power: Optional[int] = None,
|
| 671 |
+
center: bool = True,
|
| 672 |
+
normalized: bool = True,
|
| 673 |
+
pad_mode: str = "constant",
|
| 674 |
+
onesided: bool = True,
|
| 675 |
+
n_bands: int = None,
|
| 676 |
+
use_freq_weights: bool = True,
|
| 677 |
+
normalize_input: bool = False,
|
| 678 |
+
mult_add_mask: bool = False
|
| 679 |
+
) -> None:
|
| 680 |
+
super().__init__(
|
| 681 |
+
stems=stems,
|
| 682 |
+
band_specs=band_specs,
|
| 683 |
+
fs=fs,
|
| 684 |
+
n_fft=n_fft,
|
| 685 |
+
win_length=win_length,
|
| 686 |
+
hop_length=hop_length,
|
| 687 |
+
window_fn=window_fn,
|
| 688 |
+
wkwargs=wkwargs,
|
| 689 |
+
power=power,
|
| 690 |
+
center=center,
|
| 691 |
+
normalized=normalized,
|
| 692 |
+
pad_mode=pad_mode,
|
| 693 |
+
onesided=onesided,
|
| 694 |
+
n_bands=n_bands,
|
| 695 |
+
)
|
| 696 |
+
|
| 697 |
+
self.bsrnn = MultiSourceMultiMaskBandSplitCoreTransformer(
|
| 698 |
+
stems=stems,
|
| 699 |
+
band_specs=self.band_specs,
|
| 700 |
+
in_channel=in_channel,
|
| 701 |
+
require_no_overlap=require_no_overlap,
|
| 702 |
+
require_no_gap=require_no_gap,
|
| 703 |
+
normalize_channel_independently=normalize_channel_independently,
|
| 704 |
+
treat_channel_as_feature=treat_channel_as_feature,
|
| 705 |
+
n_sqm_modules=n_sqm_modules,
|
| 706 |
+
emb_dim=emb_dim,
|
| 707 |
+
rnn_dim=rnn_dim,
|
| 708 |
+
bidirectional=bidirectional,
|
| 709 |
+
rnn_type=rnn_type,
|
| 710 |
+
mlp_dim=mlp_dim,
|
| 711 |
+
cond_dim=cond_dim,
|
| 712 |
+
hidden_activation=hidden_activation,
|
| 713 |
+
hidden_activation_kwargs=hidden_activation_kwargs,
|
| 714 |
+
complex_mask=complex_mask,
|
| 715 |
+
overlapping_band=self.overlapping_band,
|
| 716 |
+
freq_weights=self.freq_weights,
|
| 717 |
+
n_freq=n_fft // 2 + 1,
|
| 718 |
+
use_freq_weights=use_freq_weights,
|
| 719 |
+
mult_add_mask=mult_add_mask
|
| 720 |
+
)
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
class MultiMaskMultiSourceBandSplitConv(
|
| 725 |
+
MultiMaskMultiSourceBandSplitBase
|
| 726 |
+
):
|
| 727 |
+
def __init__(
|
| 728 |
+
self,
|
| 729 |
+
in_channel: int,
|
| 730 |
+
stems: List[str],
|
| 731 |
+
band_specs: Union[str, List[Tuple[float, float]]],
|
| 732 |
+
fs: int = 44100,
|
| 733 |
+
require_no_overlap: bool = False,
|
| 734 |
+
require_no_gap: bool = True,
|
| 735 |
+
normalize_channel_independently: bool = False,
|
| 736 |
+
treat_channel_as_feature: bool = True,
|
| 737 |
+
n_sqm_modules: int = 12,
|
| 738 |
+
emb_dim: int = 128,
|
| 739 |
+
rnn_dim: int = 256,
|
| 740 |
+
cond_dim: int = 0,
|
| 741 |
+
bidirectional: bool = True,
|
| 742 |
+
rnn_type: str = "LSTM",
|
| 743 |
+
mlp_dim: int = 512,
|
| 744 |
+
hidden_activation: str = "Tanh",
|
| 745 |
+
hidden_activation_kwargs: Optional[Dict] = None,
|
| 746 |
+
complex_mask: bool = True,
|
| 747 |
+
n_fft: int = 2048,
|
| 748 |
+
win_length: Optional[int] = 2048,
|
| 749 |
+
hop_length: int = 512,
|
| 750 |
+
window_fn: str = "hann_window",
|
| 751 |
+
wkwargs: Optional[Dict] = None,
|
| 752 |
+
power: Optional[int] = None,
|
| 753 |
+
center: bool = True,
|
| 754 |
+
normalized: bool = True,
|
| 755 |
+
pad_mode: str = "constant",
|
| 756 |
+
onesided: bool = True,
|
| 757 |
+
n_bands: int = None,
|
| 758 |
+
use_freq_weights: bool = True,
|
| 759 |
+
normalize_input: bool = False,
|
| 760 |
+
mult_add_mask: bool = False
|
| 761 |
+
) -> None:
|
| 762 |
+
super().__init__(
|
| 763 |
+
stems=stems,
|
| 764 |
+
band_specs=band_specs,
|
| 765 |
+
fs=fs,
|
| 766 |
+
n_fft=n_fft,
|
| 767 |
+
win_length=win_length,
|
| 768 |
+
hop_length=hop_length,
|
| 769 |
+
window_fn=window_fn,
|
| 770 |
+
wkwargs=wkwargs,
|
| 771 |
+
power=power,
|
| 772 |
+
center=center,
|
| 773 |
+
normalized=normalized,
|
| 774 |
+
pad_mode=pad_mode,
|
| 775 |
+
onesided=onesided,
|
| 776 |
+
n_bands=n_bands,
|
| 777 |
+
)
|
| 778 |
+
|
| 779 |
+
self.bsrnn = MultiSourceMultiMaskBandSplitCoreConv(
|
| 780 |
+
stems=stems,
|
| 781 |
+
band_specs=self.band_specs,
|
| 782 |
+
in_channel=in_channel,
|
| 783 |
+
require_no_overlap=require_no_overlap,
|
| 784 |
+
require_no_gap=require_no_gap,
|
| 785 |
+
normalize_channel_independently=normalize_channel_independently,
|
| 786 |
+
treat_channel_as_feature=treat_channel_as_feature,
|
| 787 |
+
n_sqm_modules=n_sqm_modules,
|
| 788 |
+
emb_dim=emb_dim,
|
| 789 |
+
rnn_dim=rnn_dim,
|
| 790 |
+
bidirectional=bidirectional,
|
| 791 |
+
rnn_type=rnn_type,
|
| 792 |
+
mlp_dim=mlp_dim,
|
| 793 |
+
cond_dim=cond_dim,
|
| 794 |
+
hidden_activation=hidden_activation,
|
| 795 |
+
hidden_activation_kwargs=hidden_activation_kwargs,
|
| 796 |
+
complex_mask=complex_mask,
|
| 797 |
+
overlapping_band=self.overlapping_band,
|
| 798 |
+
freq_weights=self.freq_weights,
|
| 799 |
+
n_freq=n_fft // 2 + 1,
|
| 800 |
+
use_freq_weights=use_freq_weights,
|
| 801 |
+
mult_add_mask=mult_add_mask
|
| 802 |
+
)
|
| 803 |
+
class PatchingMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
|
| 804 |
+
def __init__(
|
| 805 |
+
self,
|
| 806 |
+
in_channel: int,
|
| 807 |
+
stems: List[str],
|
| 808 |
+
band_specs: Union[str, List[Tuple[float, float]]],
|
| 809 |
+
kernel_norm_mlp_version: int = 1,
|
| 810 |
+
mask_kernel_freq: int = 3,
|
| 811 |
+
mask_kernel_time: int = 3,
|
| 812 |
+
conv_kernel_freq: int = 1,
|
| 813 |
+
conv_kernel_time: int = 1,
|
| 814 |
+
fs: int = 44100,
|
| 815 |
+
require_no_overlap: bool = False,
|
| 816 |
+
require_no_gap: bool = True,
|
| 817 |
+
normalize_channel_independently: bool = False,
|
| 818 |
+
treat_channel_as_feature: bool = True,
|
| 819 |
+
n_sqm_modules: int = 12,
|
| 820 |
+
emb_dim: int = 128,
|
| 821 |
+
rnn_dim: int = 256,
|
| 822 |
+
bidirectional: bool = True,
|
| 823 |
+
rnn_type: str = "LSTM",
|
| 824 |
+
mlp_dim: int = 512,
|
| 825 |
+
hidden_activation: str = "Tanh",
|
| 826 |
+
hidden_activation_kwargs: Optional[Dict] = None,
|
| 827 |
+
complex_mask: bool = True,
|
| 828 |
+
n_fft: int = 2048,
|
| 829 |
+
win_length: Optional[int] = 2048,
|
| 830 |
+
hop_length: int = 512,
|
| 831 |
+
window_fn: str = "hann_window",
|
| 832 |
+
wkwargs: Optional[Dict] = None,
|
| 833 |
+
power: Optional[int] = None,
|
| 834 |
+
center: bool = True,
|
| 835 |
+
normalized: bool = True,
|
| 836 |
+
pad_mode: str = "constant",
|
| 837 |
+
onesided: bool = True,
|
| 838 |
+
n_bands: int = None,
|
| 839 |
+
) -> None:
|
| 840 |
+
super().__init__(
|
| 841 |
+
stems=stems,
|
| 842 |
+
band_specs=band_specs,
|
| 843 |
+
fs=fs,
|
| 844 |
+
n_fft=n_fft,
|
| 845 |
+
win_length=win_length,
|
| 846 |
+
hop_length=hop_length,
|
| 847 |
+
window_fn=window_fn,
|
| 848 |
+
wkwargs=wkwargs,
|
| 849 |
+
power=power,
|
| 850 |
+
center=center,
|
| 851 |
+
normalized=normalized,
|
| 852 |
+
pad_mode=pad_mode,
|
| 853 |
+
onesided=onesided,
|
| 854 |
+
n_bands=n_bands,
|
| 855 |
+
)
|
| 856 |
+
|
| 857 |
+
self.bsrnn = MultiSourceMultiPatchingMaskBandSplitCoreRNN(
|
| 858 |
+
stems=stems,
|
| 859 |
+
band_specs=self.band_specs,
|
| 860 |
+
in_channel=in_channel,
|
| 861 |
+
require_no_overlap=require_no_overlap,
|
| 862 |
+
require_no_gap=require_no_gap,
|
| 863 |
+
normalize_channel_independently=normalize_channel_independently,
|
| 864 |
+
treat_channel_as_feature=treat_channel_as_feature,
|
| 865 |
+
n_sqm_modules=n_sqm_modules,
|
| 866 |
+
emb_dim=emb_dim,
|
| 867 |
+
rnn_dim=rnn_dim,
|
| 868 |
+
bidirectional=bidirectional,
|
| 869 |
+
rnn_type=rnn_type,
|
| 870 |
+
mlp_dim=mlp_dim,
|
| 871 |
+
hidden_activation=hidden_activation,
|
| 872 |
+
hidden_activation_kwargs=hidden_activation_kwargs,
|
| 873 |
+
complex_mask=complex_mask,
|
| 874 |
+
overlapping_band=self.overlapping_band,
|
| 875 |
+
freq_weights=self.freq_weights,
|
| 876 |
+
n_freq=n_fft // 2 + 1,
|
| 877 |
+
mask_kernel_freq=mask_kernel_freq,
|
| 878 |
+
mask_kernel_time=mask_kernel_time,
|
| 879 |
+
conv_kernel_freq=conv_kernel_freq,
|
| 880 |
+
conv_kernel_time=conv_kernel_time,
|
| 881 |
+
kernel_norm_mlp_version=kernel_norm_mlp_version,
|
| 882 |
+
)
|
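Taken together, the four wrappers differ only in which band-split core they plug into the shared STFT scaffolding. As a rough construction sketch (not from the repo: the stem names are illustrative, and the string form of `band_specs` is assumed to select a named banding scheme):

    from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNN

    model = MultiMaskMultiSourceBandSplitRNN(
        in_channel=1,
        stems=["speech", "music", "effects"],
        band_specs="musical",
        n_bands=64,
        freeze_encoder=True,  # band-split encoder and TF model stay fixed
    )
    # with freeze_encoder=True, only the mask-estimation head remains trainable
    trainable = [n for n, p in model.named_parameters() if p.requires_grad]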
separator/models/bandit/core/utils/__init__.py
ADDED
File without changes

separator/models/bandit/core/utils/audio.py
ADDED
@@ -0,0 +1,463 @@
from collections import defaultdict

from tqdm.auto import tqdm
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


@torch.jit.script
def merge(
        combined: torch.Tensor,
        original_batch_size: int,
        n_channel: int,
        n_chunks: int,
        chunk_size: int,
):
    combined = torch.reshape(
        combined,
        (original_batch_size, n_chunks, n_channel, chunk_size)
    )
    combined = torch.permute(combined, (0, 2, 3, 1)).reshape(
        original_batch_size * n_channel,
        chunk_size,
        n_chunks
    )

    return combined


@torch.jit.script
def unfold(
        padded_audio: torch.Tensor,
        original_batch_size: int,
        n_channel: int,
        chunk_size: int,
        hop_size: int
) -> torch.Tensor:
    unfolded_input = F.unfold(
        padded_audio[:, :, None, :],
        kernel_size=(1, chunk_size),
        stride=(1, hop_size)
    )

    _, _, n_chunks = unfolded_input.shape
    unfolded_input = unfolded_input.view(
        original_batch_size,
        n_channel,
        chunk_size,
        n_chunks
    )
    unfolded_input = torch.permute(
        unfolded_input,
        (0, 3, 1, 2)
    ).reshape(
        original_batch_size * n_chunks,
        n_channel,
        chunk_size
    )

    return unfolded_input


@torch.jit.script
# @torch.compile
def merge_chunks_all(
        combined: torch.Tensor,
        original_batch_size: int,
        n_channel: int,
        n_samples: int,
        n_padded_samples: int,
        n_chunks: int,
        chunk_size: int,
        hop_size: int,
        edge_frame_pad_sizes: Tuple[int, int],
        standard_window: torch.Tensor,
        first_window: torch.Tensor,
        last_window: torch.Tensor
):
    combined = merge(
        combined,
        original_batch_size,
        n_channel,
        n_chunks,
        chunk_size
    )

    combined = combined * standard_window[:, None].to(combined.device)

    combined = F.fold(
        combined.to(torch.float32),
        output_size=(1, n_padded_samples),
        kernel_size=(1, chunk_size),
        stride=(1, hop_size)
    )

    combined = combined.view(
        original_batch_size,
        n_channel,
        n_padded_samples
    )

    pad_front, pad_back = edge_frame_pad_sizes
    combined = combined[..., pad_front:-pad_back]

    combined = combined[..., :n_samples]

    return combined


# @torch.jit.script
def merge_chunks_edge(
        combined: torch.Tensor,
        original_batch_size: int,
        n_channel: int,
        n_samples: int,
        n_padded_samples: int,
        n_chunks: int,
        chunk_size: int,
        hop_size: int,
        edge_frame_pad_sizes: Tuple[int, int],
        standard_window: torch.Tensor,
        first_window: torch.Tensor,
        last_window: torch.Tensor
):
    combined = merge(
        combined,
        original_batch_size,
        n_channel,
        n_chunks,
        chunk_size
    )

    combined[..., 0] = combined[..., 0] * first_window
    combined[..., -1] = combined[..., -1] * last_window
    combined[..., 1:-1] = combined[..., 1:-1] * standard_window[:, None]

    combined = F.fold(
        combined,
        output_size=(1, n_padded_samples),
        kernel_size=(1, chunk_size),
        stride=(1, hop_size)
    )

    combined = combined.view(
        original_batch_size,
        n_channel,
        n_padded_samples
    )

    combined = combined[..., :n_samples]

    return combined


class BaseFader(nn.Module):
    def __init__(
            self,
            chunk_size_second: float,
            hop_size_second: float,
            fs: int,
            fade_edge_frames: bool,
            batch_size: int,
    ) -> None:
        super().__init__()

        self.chunk_size = int(chunk_size_second * fs)
        self.hop_size = int(hop_size_second * fs)
        self.overlap_size = self.chunk_size - self.hop_size
        self.fade_edge_frames = fade_edge_frames
        self.batch_size = batch_size

    # @torch.jit.script
    def prepare(self, audio):
        if self.fade_edge_frames:
            audio = F.pad(audio, self.edge_frame_pad_sizes, mode="reflect")

        n_samples = audio.shape[-1]
        n_chunks = int(
            np.ceil((n_samples - self.chunk_size) / self.hop_size) + 1
        )

        padded_size = (n_chunks - 1) * self.hop_size + self.chunk_size
        pad_size = padded_size - n_samples

        padded_audio = F.pad(audio, (0, pad_size))

        return padded_audio, n_chunks

    def forward(
            self,
            audio: torch.Tensor,
            model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
    ):
        original_dtype = audio.dtype
        original_device = audio.device

        audio = audio.to("cpu")

        original_batch_size, n_channel, n_samples = audio.shape
        padded_audio, n_chunks = self.prepare(audio)
        del audio
        n_padded_samples = padded_audio.shape[-1]

        if n_channel > 1:
            padded_audio = padded_audio.view(
                original_batch_size * n_channel, 1, n_padded_samples
            )

        unfolded_input = unfold(
            padded_audio,
            original_batch_size,
            n_channel,
            self.chunk_size,
            self.hop_size
        )

        n_total_chunks, n_channel, chunk_size = unfolded_input.shape

        n_batch = np.ceil(n_total_chunks / self.batch_size).astype(int)

        chunks_in = [
            unfolded_input[
            b * self.batch_size:(b + 1) * self.batch_size, ...].clone()
            for b in range(n_batch)
        ]

        all_chunks_out = defaultdict(
            lambda: torch.zeros_like(
                unfolded_input, device="cpu"
            )
        )

        # for b, cin in enumerate(tqdm(chunks_in)):
        for b, cin in enumerate(chunks_in):
            # skip all-silent batches; their output stays zero
            if torch.allclose(cin, torch.tensor(0.0)):
                del cin
                continue

            chunks_out = model_fn(cin.to(original_device))
            del cin
            for s, c in chunks_out.items():
                all_chunks_out[s][
                b * self.batch_size:(b + 1) * self.batch_size, ...] = c.cpu()
            del chunks_out

        del unfolded_input
        del padded_audio

        if self.fade_edge_frames:
            fn = merge_chunks_all
        else:
            fn = merge_chunks_edge
        outputs = {}

        torch.cuda.empty_cache()

        for s, c in all_chunks_out.items():
            combined: torch.Tensor = fn(
                c,
                original_batch_size,
                n_channel,
                n_samples,
                n_padded_samples,
                n_chunks,
                self.chunk_size,
                self.hop_size,
                self.edge_frame_pad_sizes,
                self.standard_window,
                self.__dict__.get("first_window", self.standard_window),
                self.__dict__.get("last_window", self.standard_window)
            )

            outputs[s] = combined.to(
                dtype=original_dtype,
                device=original_device
            )

        return {
            "audio": outputs
        }
    #
    # def old_forward(
    #         self,
    #         audio: torch.Tensor,
    #         model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
    # ):
    #
    #     n_samples = audio.shape[-1]
    #     original_batch_size = audio.shape[0]
    #
    #     padded_audio, n_chunks = self.prepare(audio)
    #
    #     ndim = padded_audio.ndim
    #     broadcaster = [1 for _ in range(ndim - 1)] + [self.chunk_size]
    #
    #     outputs = defaultdict(
    #         lambda: torch.zeros_like(
    #             padded_audio, device=audio.device, dtype=torch.float64
    #         )
    #     )
    #
    #     all_chunks_out = []
    #     len_chunks_in = []
    #
    #     batch_size_ = int(self.batch_size // original_batch_size)
    #     for b in range(int(np.ceil(n_chunks / batch_size_))):
    #         chunks_in = []
    #         for j in range(batch_size_):
    #             i = b * batch_size_ + j
    #             if i == n_chunks:
    #                 break
    #
    #             start = i * hop_size
    #             end = start + self.chunk_size
    #             chunk_in = padded_audio[..., start:end]
    #             chunks_in.append(chunk_in)
    #
    #         chunks_in = torch.concat(chunks_in, dim=0)
    #         chunks_out = model_fn(chunks_in)
    #         all_chunks_out.append(chunks_out)
    #         len_chunks_in.append(len(chunks_in))
    #
    #     for b, (chunks_out, lci) in enumerate(
    #             zip(all_chunks_out, len_chunks_in)
    #     ):
    #         for stem in chunks_out:
    #             for j in range(lci // original_batch_size):
    #                 i = b * batch_size_ + j
    #
    #                 if self.fade_edge_frames:
    #                     window = self.standard_window
    #                 else:
    #                     if i == 0:
    #                         window = self.first_window
    #                     elif i == n_chunks - 1:
    #                         window = self.last_window
    #                     else:
    #                         window = self.standard_window
    #
    #                 start = i * hop_size
    #                 end = start + self.chunk_size
    #
    #                 chunk_out = chunks_out[stem][
    #                             j * original_batch_size:(j + 1) * original_batch_size, ...]
    #                 contrib = window.view(*broadcaster) * chunk_out
    #                 outputs[stem][..., start:end] = (
    #                     outputs[stem][..., start:end] + contrib
    #                 )
    #
    #     if self.fade_edge_frames:
    #         pad_front, pad_back = self.edge_frame_pad_sizes
    #         outputs = {k: v[..., pad_front:-pad_back] for k, v in
    #                    outputs.items()}
    #
    #     outputs = {k: v[..., :n_samples].to(audio.dtype) for k, v in
    #                outputs.items()}
    #
    #     return {
    #         "audio": outputs
    #     }


class LinearFader(BaseFader):
    def __init__(
            self,
            chunk_size_second: float,
            hop_size_second: float,
            fs: int,
            fade_edge_frames: bool = False,
            batch_size: int = 1,
    ) -> None:
        assert hop_size_second >= chunk_size_second / 2

        super().__init__(
            chunk_size_second=chunk_size_second,
            hop_size_second=hop_size_second,
            fs=fs,
            fade_edge_frames=fade_edge_frames,
            batch_size=batch_size,
        )

        in_fade = torch.linspace(0.0, 1.0, self.overlap_size + 1)[:-1]
        out_fade = torch.linspace(1.0, 0.0, self.overlap_size + 1)[1:]
        center_ones = torch.ones(self.chunk_size - 2 * self.overlap_size)
        inout_ones = torch.ones(self.overlap_size)

        # registering buffers / non-trainable parameters lets Lightning
        # move these windows to the right device for us
        self.register_buffer(
            "standard_window",
            torch.concat([in_fade, center_ones, out_fade])
        )

        self.fade_edge_frames = fade_edge_frames
        # NB: fixed from the original `edge_frame_pad_size` (singular), which
        # `prepare()` and `forward()` would never find
        self.edge_frame_pad_sizes = (self.overlap_size, self.overlap_size)

        if not self.fade_edge_frames:
            self.first_window = nn.Parameter(
                torch.concat([inout_ones, center_ones, out_fade]),
                requires_grad=False
            )
            self.last_window = nn.Parameter(
                torch.concat([in_fade, center_ones, inout_ones]),
                requires_grad=False
            )


class OverlapAddFader(BaseFader):
    def __init__(
            self,
            window_type: str,
            chunk_size_second: float,
            hop_size_second: float,
            fs: int,
            batch_size: int = 1,
    ) -> None:
        assert (chunk_size_second / hop_size_second) % 2 == 0
        assert int(chunk_size_second * fs) % 2 == 0

        super().__init__(
            chunk_size_second=chunk_size_second,
            hop_size_second=hop_size_second,
            fs=fs,
            fade_edge_frames=True,
            batch_size=batch_size,
        )

        self.hop_multiplier = self.chunk_size / (2 * self.hop_size)
        # print(f"hop multiplier: {self.hop_multiplier}")

        self.edge_frame_pad_sizes = (
            2 * self.overlap_size,
            2 * self.overlap_size
        )

        self.register_buffer(
            "standard_window", torch.windows.__dict__[window_type](
                self.chunk_size, sym=False,  # dtype=torch.float64
            ) / self.hop_multiplier
        )


if __name__ == "__main__":
    import torchaudio as ta

    fs = 44100
    ola = OverlapAddFader(
        "hann",
        6.0,
        1.0,
        fs,
        batch_size=16
    )
    audio_, _ = ta.load(
        "$DATA_ROOT/MUSDB18/HQ/canonical/test/BKS - Too "
        "Much/vocals.wav"
    )
    audio_ = audio_[None, ...]
    out = ola(audio_, lambda x: {"stem": x})["audio"]["stem"]
    print(torch.allclose(out, audio_))
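The overlap-add scheme above works because of the constant-overlap-add (COLA) property: a periodic window summed at hop chunk/(2k) is constant, and dividing by hop_multiplier = chunk_size / (2 * hop_size) normalizes that constant to one. A self-contained sanity check of the identity (illustrative sizes, independent of the classes above):

    import torch

    chunk, hop = 8, 2  # chunk/hop is even, matching OverlapAddFader's assert
    w = torch.hann_window(chunk, periodic=True) / (chunk / (2 * hop))

    cover = torch.zeros(4 * chunk)
    for start in range(0, len(cover) - chunk + 1, hop):
        cover[start:start + chunk] += w  # overlap-add of the scaled windows

    # away from the edges, the shifted windows sum to exactly one
    print(torch.allclose(cover[chunk:-chunk], torch.ones(2 * chunk)))  # True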
separator/models/bandit/model_from_config.py
ADDED
@@ -0,0 +1,31 @@
import sys
import os.path
import torch

code_path = os.path.dirname(os.path.abspath(__file__)) + '/'
sys.path.append(code_path)

import yaml
from ml_collections import ConfigDict

torch.set_float32_matmul_precision("medium")


def get_model(
        config_path,
        weights_path,
        device,
):
    from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple

    with open(config_path) as f:
        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))

    model = MultiMaskMultiSourceBandSplitRNNSimple(
        **config.model
    )
    # NB: `weights_path` is currently unused; the checkpoint name is hardcoded
    # and expected to sit next to this file.
    d = torch.load(code_path + 'model_bandit_plus_dnr_sdr_11.47.chpt')
    model.load_state_dict(d)
    model.to(device)
    return model, config
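A hypothetical call site for get_model (the config filename is a placeholder; note that weights_path is accepted but ignored in this version):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, config = get_model(
        config_path="config.yaml",  # YAML whose `model` section holds constructor kwargs
        weights_path="ignored.chpt",
        device=device,
    )
    model.eval()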
separator/models/bandit_v2/bandit.py
ADDED
@@ -0,0 +1,367 @@
from typing import Dict, List, Optional

import torch
import torchaudio as ta
from torch import nn
import pytorch_lightning as pl

from .bandsplit import BandSplitModule
from .maskestim import OverlappingMaskEstimationModule
from .tfmodel import SeqBandModellingModule
from .utils import MusicalBandsplitSpecification


class BaseEndToEndModule(pl.LightningModule):
    def __init__(
            self,
    ) -> None:
        super().__init__()


class BaseBandit(BaseEndToEndModule):
    def __init__(
            self,
            in_channels: int,
            fs: int,
            band_type: str = "musical",
            n_bands: int = 64,
            require_no_overlap: bool = False,
            require_no_gap: bool = True,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
            n_sqm_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            bidirectional: bool = True,
            rnn_type: str = "LSTM",
            n_fft: int = 2048,
            win_length: Optional[int] = 2048,
            hop_length: int = 512,
            window_fn: str = "hann_window",
            wkwargs: Optional[Dict] = None,
            power: Optional[int] = None,
            center: bool = True,
            normalized: bool = True,
            pad_mode: str = "constant",
            onesided: bool = True,
    ):
        super().__init__()

        self.in_channels = in_channels

        self.instantitate_spectral(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            normalized=normalized,
            center=center,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        self.instantiate_bandsplit(
            in_channels=in_channels,
            band_type=band_type,
            n_bands=n_bands,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
            n_fft=n_fft,
            fs=fs,
        )

        self.instantiate_tf_modelling(
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )

    def instantitate_spectral(
            self,
            n_fft: int = 2048,
            win_length: Optional[int] = 2048,
            hop_length: int = 512,
            window_fn: str = "hann_window",
            wkwargs: Optional[Dict] = None,
            power: Optional[int] = None,
            normalized: bool = True,
            center: bool = True,
            pad_mode: str = "constant",
            onesided: bool = True,
    ):
        assert power is None

        window_fn = torch.__dict__[window_fn]

        self.stft = ta.transforms.Spectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode=pad_mode,
            pad=0,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            normalized=normalized,
            center=center,
            onesided=onesided,
        )

        self.istft = ta.transforms.InverseSpectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode=pad_mode,
            pad=0,
            window_fn=window_fn,
            wkwargs=wkwargs,
            normalized=normalized,
            center=center,
            onesided=onesided,
        )

    def instantiate_bandsplit(
            self,
            in_channels: int,
            band_type: str = "musical",
            n_bands: int = 64,
            require_no_overlap: bool = False,
            require_no_gap: bool = True,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
            emb_dim: int = 128,
            n_fft: int = 2048,
            fs: int = 44100,
    ):
        assert band_type == "musical"

        self.band_specs = MusicalBandsplitSpecification(
            nfft=n_fft, fs=fs, n_bands=n_bands
        )

        self.band_split = BandSplitModule(
            in_channels=in_channels,
            band_specs=self.band_specs.get_band_specs(),
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )

    def instantiate_tf_modelling(
            self,
            n_sqm_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            bidirectional: bool = True,
            rnn_type: str = "LSTM",
    ):
        try:
            self.tf_model = torch.compile(
                SeqBandModellingModule(
                    n_modules=n_sqm_modules,
                    emb_dim=emb_dim,
                    rnn_dim=rnn_dim,
                    bidirectional=bidirectional,
                    rnn_type=rnn_type,
                ),
                disable=True,
            )
        except Exception:  # torch.compile unavailable; use the eager module
            self.tf_model = SeqBandModellingModule(
                n_modules=n_sqm_modules,
                emb_dim=emb_dim,
                rnn_dim=rnn_dim,
                bidirectional=bidirectional,
                rnn_type=rnn_type,
            )

    def mask(self, x, m):
        return x * m

    def forward(self, batch, mode="train"):
        # The core model is mono; multichannel input is handled by folding the
        # channels into the batch dimension and processing them independently.
        # NB: `batch.shape` below assumes a raw waveform tensor is passed in.
        init_shape = batch.shape
        if not isinstance(batch, dict):
            mono = batch.view(-1, 1, batch.shape[-1])
            batch = {
                "mixture": {
                    "audio": mono
                }
            }

        with torch.no_grad():
            mixture = batch["mixture"]["audio"]

            x = self.stft(mixture)
            batch["mixture"]["spectrogram"] = x

            if "sources" in batch.keys():
                for stem in batch["sources"].keys():
                    s = batch["sources"][stem]["audio"]
                    s = self.stft(s)
                    batch["sources"][stem]["spectrogram"] = s

        batch = self.separate(batch)

        if 1:  # always re-pack the per-stem estimates into a single tensor
            b = []
            for s in self.stems:
                # restore the original (batch, channel, time) layout per stem
                r = batch['estimates'][s]['audio'].view(-1, init_shape[1], init_shape[2])
                b.append(r)
            # return one stacked tensor rather than a dict of independent stems
            batch = torch.stack(b, dim=1)
        return batch

    def encode(self, batch):
        x = batch["mixture"]["spectrogram"]
        length = batch["mixture"]["audio"].shape[-1]

        z = self.band_split(x)  # (batch, emb_dim, n_band, n_time)
        q = self.tf_model(z)  # (batch, emb_dim, n_band, n_time)

        return x, q, length

    def separate(self, batch):
        raise NotImplementedError


class Bandit(BaseBandit):
    def __init__(
            self,
            in_channels: int,
            stems: List[str],
            band_type: str = "musical",
            n_bands: int = 64,
            require_no_overlap: bool = False,
            require_no_gap: bool = True,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
            n_sqm_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            bidirectional: bool = True,
            rnn_type: str = "LSTM",
            mlp_dim: int = 512,
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs: Dict | None = None,
            complex_mask: bool = True,
            use_freq_weights: bool = True,
            n_fft: int = 2048,
            win_length: int | None = 2048,
            hop_length: int = 512,
            window_fn: str = "hann_window",
            wkwargs: Dict | None = None,
            power: int | None = None,
            center: bool = True,
            normalized: bool = True,
            pad_mode: str = "constant",
            onesided: bool = True,
            fs: int = 44100,
            stft_precisions="32",
            bandsplit_precisions="bf16",
            tf_model_precisions="bf16",
            mask_estim_precisions="bf16",
    ):
        super().__init__(
            in_channels=in_channels,
            band_type=band_type,
            n_bands=n_bands,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            fs=fs,
        )

        self.stems = stems

        self.instantiate_mask_estim(
            in_channels=in_channels,
            stems=stems,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
        )

    def instantiate_mask_estim(
            self,
            in_channels: int,
            stems: List[str],
            emb_dim: int,
            mlp_dim: int,
            hidden_activation: str,
            hidden_activation_kwargs: Optional[Dict] = None,
            complex_mask: bool = True,
            n_freq: Optional[int] = None,
            use_freq_weights: bool = False,
    ):
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        assert n_freq is not None

        self.mask_estim = nn.ModuleDict(
            {
                stem: OverlappingMaskEstimationModule(
                    band_specs=self.band_specs.get_band_specs(),
                    freq_weights=self.band_specs.get_freq_weights(),
                    n_freq=n_freq,
                    emb_dim=emb_dim,
                    mlp_dim=mlp_dim,
                    in_channels=in_channels,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                    use_freq_weights=use_freq_weights,
                )
                for stem in stems
            }
        )

    def separate(self, batch):
        batch["estimates"] = {}

        x, q, length = self.encode(batch)

        for stem, mem in self.mask_estim.items():
            m = mem(q)

            s = self.mask(x, m.to(x.dtype))
            s = torch.reshape(s, x.shape)
            batch["estimates"][stem] = {
                "audio": self.istft(s, length),
                "spectrogram": s,
            }

        return batch
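To see how the channel folding in Bandit.forward plays out, a smoke-test sketch (stem names and sizes are illustrative, and all bandit_v2 submodules must be importable):

    import torch

    model = Bandit(in_channels=1, stems=["speech", "music", "sfx"], fs=44100)
    model.eval()
    with torch.no_grad():
        mix = torch.randn(2, 2, 44100)  # (batch, stereo, samples)
        est = model(mix)                # channels are folded into the batch dim
    print(est.shape)  # expected: (2, 3, 2, 44100) = (batch, stem, channel, samples)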
separator/models/bandit_v2/bandsplit.py
ADDED
@@ -0,0 +1,130 @@
from typing import List, Tuple

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

from .utils import (
    band_widths_from_specs,
    check_no_gap,
    check_no_overlap,
    check_nonzero_bandwidth,
)


class NormFC(nn.Module):
    def __init__(
            self,
            emb_dim: int,
            bandwidth: int,
            in_channels: int,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
    ) -> None:
        super().__init__()

        if not treat_channel_as_feature:
            raise NotImplementedError

        self.treat_channel_as_feature = treat_channel_as_feature

        if normalize_channel_independently:
            raise NotImplementedError

        reim = 2

        norm = nn.LayerNorm(in_channels * bandwidth * reim)

        fc_in = bandwidth * reim

        if treat_channel_as_feature:
            fc_in *= in_channels
        else:
            assert emb_dim % in_channels == 0
            emb_dim = emb_dim // in_channels

        fc = nn.Linear(fc_in, emb_dim)

        self.combined = nn.Sequential(norm, fc)

    def forward(self, xb):
        return checkpoint_sequential(self.combined, 1, xb, use_reentrant=False)


class BandSplitModule(nn.Module):
    def __init__(
            self,
            band_specs: List[Tuple[float, float]],
            emb_dim: int,
            in_channels: int,
            require_no_overlap: bool = False,
            require_no_gap: bool = True,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
    ) -> None:
        super().__init__()

        check_nonzero_bandwidth(band_specs)

        if require_no_gap:
            check_no_gap(band_specs)

        if require_no_overlap:
            check_no_overlap(band_specs)

        self.band_specs = band_specs
        # list of [fstart, fend) in index; note that fend is exclusive
        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)
        self.emb_dim = emb_dim

        try:
            self.norm_fc_modules = nn.ModuleList(
                [  # type: ignore
                    torch.compile(
                        NormFC(
                            emb_dim=emb_dim,
                            bandwidth=bw,
                            in_channels=in_channels,
                            normalize_channel_independently=normalize_channel_independently,
                            treat_channel_as_feature=treat_channel_as_feature,
                        ),
                        disable=True,
                    )
                    for bw in self.band_widths
                ]
            )
        except Exception:  # torch.compile unavailable; use the eager modules
            self.norm_fc_modules = nn.ModuleList(
                [  # type: ignore
                    NormFC(
                        emb_dim=emb_dim,
                        bandwidth=bw,
                        in_channels=in_channels,
                        normalize_channel_independently=normalize_channel_independently,
                        treat_channel_as_feature=treat_channel_as_feature,
                    )
                    for bw in self.band_widths
                ]
            )

    def forward(self, x: torch.Tensor):
        # x = complex spectrogram (batch, in_chan, n_freq, n_time)

        batch, in_chan, band_width, n_time = x.shape

        z = torch.zeros(
            size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device
        )

        x = torch.permute(x, (0, 3, 1, 2)).contiguous()

        for i, nfm in enumerate(self.norm_fc_modules):
            fstart, fend = self.band_specs[i]
            xb = x[:, :, :, fstart:fend]
            xb = torch.view_as_real(xb)
            xb = torch.reshape(xb, (batch, n_time, -1))
            z[:, i, :, :] = nfm(xb)

        return z
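Shape walk-through for BandSplitModule (the band layout below is made up for illustration; real specs come from MusicalBandsplitSpecification):

    import torch

    band_specs = [(0, 256), (256, 640), (640, 1025)]  # [fstart, fend) bins, gap-free
    bsm = BandSplitModule(band_specs=band_specs, emb_dim=128, in_channels=1)

    x = torch.randn(2, 1, 1025, 50, dtype=torch.complex64)  # (batch, chan, freq, time)
    z = bsm(x)
    print(z.shape)  # torch.Size([2, 3, 50, 128]) = (batch, n_bands, n_time, emb_dim)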
separator/models/bandit_v2/film.py
ADDED
@@ -0,0 +1,25 @@
from torch import nn
import torch


class FiLM(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, gamma, beta):
        return gamma * x + beta


class BTFBroadcastedFiLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.film = FiLM()

    def forward(self, x, gamma, beta):
        gamma = gamma[None, None, None, :]
        beta = beta[None, None, None, :]

        return self.film(x, gamma, beta)
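FiLM is a plain feature-wise affine modulation, gamma * x + beta; the broadcast variant assumes gamma and beta index the last (feature) axis of x. A quick check:

    import torch

    film = BTFBroadcastedFiLM()
    x = torch.randn(4, 8, 16, 32)  # (batch, band, time, feature)
    gamma = torch.full((32,), 2.0)
    beta = torch.zeros(32)
    print(torch.allclose(film(x, gamma, beta), 2.0 * x))  # True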
separator/models/bandit_v2/maskestim.py
ADDED
@@ -0,0 +1,281 @@
| 1 |
+
from typing import Dict, List, Optional, Tuple, Type
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn.modules import activation
|
| 6 |
+
from torch.utils.checkpoint import checkpoint_sequential
|
| 7 |
+
|
| 8 |
+
from .utils import (
|
| 9 |
+
band_widths_from_specs,
|
| 10 |
+
check_no_gap,
|
| 11 |
+
check_no_overlap,
|
| 12 |
+
check_nonzero_bandwidth,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class BaseNormMLP(nn.Module):
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
emb_dim: int,
|
| 20 |
+
mlp_dim: int,
|
| 21 |
+
bandwidth: int,
|
| 22 |
+
in_channels: Optional[int],
|
| 23 |
+
hidden_activation: str = "Tanh",
|
| 24 |
+
hidden_activation_kwargs=None,
|
| 25 |
+
complex_mask: bool = True,
|
| 26 |
+
):
|
| 27 |
+
super().__init__()
|
| 28 |
+
if hidden_activation_kwargs is None:
|
| 29 |
+
hidden_activation_kwargs = {}
|
| 30 |
+
self.hidden_activation_kwargs = hidden_activation_kwargs
|
| 31 |
+
self.norm = nn.LayerNorm(emb_dim)
|
| 32 |
+
self.hidden = nn.Sequential(
|
| 33 |
+
nn.Linear(in_features=emb_dim, out_features=mlp_dim),
|
| 34 |
+
activation.__dict__[hidden_activation](**self.hidden_activation_kwargs),
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
self.bandwidth = bandwidth
|
| 38 |
+
self.in_channels = in_channels
|
| 39 |
+
|
| 40 |
+
self.complex_mask = complex_mask
|
| 41 |
+
self.reim = 2 if complex_mask else 1
|
| 42 |
+
self.glu_mult = 2
|
| 43 |
+
|
| 44 |
+
|
class NormMLP(BaseNormMLP):
    def __init__(
        self,
        emb_dim: int,
        mlp_dim: int,
        bandwidth: int,
        in_channels: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
    ) -> None:
        super().__init__(
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            bandwidth=bandwidth,
            in_channels=in_channels,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

        # the Linear outputs twice the target width so that the GLU, which
        # halves the last dimension, yields bandwidth * in_channels * reim
        self.output = nn.Sequential(
            nn.Linear(
                in_features=mlp_dim,
                out_features=bandwidth * in_channels * self.reim * 2,
            ),
            nn.GLU(dim=-1),
        )

        try:
            self.combined = torch.compile(
                nn.Sequential(self.norm, self.hidden, self.output), disable=True
            )
        except Exception:
            # torch.compile is unavailable on older PyTorch versions
            self.combined = nn.Sequential(self.norm, self.hidden, self.output)

    def reshape_output(self, mb):
        batch, n_time, _ = mb.shape
        if self.complex_mask:
            mb = mb.reshape(
                batch, n_time, self.in_channels, self.bandwidth, self.reim
            ).contiguous()
            mb = torch.view_as_complex(mb)  # (batch, n_time, in_channels, bandwidth)
        else:
            mb = mb.reshape(batch, n_time, self.in_channels, self.bandwidth)

        mb = torch.permute(mb, (0, 2, 3, 1))  # (batch, in_channels, bandwidth, n_time)

        return mb

    def forward(self, qb):
        # qb = (batch, n_time, emb_dim)
        # combined = norm -> hidden -> output, run under activation checkpointing
        mb = checkpoint_sequential(self.combined, 2, qb, use_reentrant=False)
        mb = self.reshape_output(mb)  # (batch, in_channels, bandwidth, n_time)

        return mb

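# Usage sketch for NormMLP (illustrative only, not part of the original file;
# the sizes below are made up):
#
#     mlp = NormMLP(emb_dim=128, mlp_dim=256, bandwidth=48, in_channels=2)
#     qb = torch.randn(4, 100, 128)   # (batch, n_time, emb_dim)
#     mask = mlp(qb)
#     print(mask.shape, mask.dtype)   # torch.Size([4, 2, 48, 100]) torch.complex64
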
class MaskEstimationModuleSuperBase(nn.Module):
    pass


class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        in_channels: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        norm_mlp_cls: Type[nn.Module] = NormMLP,
        norm_mlp_kwargs: Optional[Dict] = None,
    ) -> None:
        super().__init__()

        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)

        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        if norm_mlp_kwargs is None:
            norm_mlp_kwargs = {}

        # one norm + MLP head per band, each sized to its band's width
        self.norm_mlp = nn.ModuleList(
            [
                norm_mlp_cls(
                    bandwidth=self.band_widths[b],
                    emb_dim=emb_dim,
                    mlp_dim=mlp_dim,
                    in_channels=in_channels,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                    **norm_mlp_kwargs,
                )
                for b in range(self.n_bands)
            ]
        )

    def compute_masks(self, q):
        batch, n_bands, n_time, emb_dim = q.shape

        masks = []

        for b, nmlp in enumerate(self.norm_mlp):
            qb = q[:, b, :, :]
            mb = nmlp(qb)
            masks.append(mb)

        return masks

    def compute_mask(self, q, b):
        batch, n_bands, n_time, emb_dim = q.shape
        qb = q[:, b, :, :]
        mb = self.norm_mlp[b](qb)
        return mb

+
class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
|
| 174 |
+
def __init__(
|
| 175 |
+
self,
|
| 176 |
+
in_channels: int,
|
| 177 |
+
band_specs: List[Tuple[float, float]],
|
| 178 |
+
freq_weights: List[torch.Tensor],
|
| 179 |
+
n_freq: int,
|
| 180 |
+
emb_dim: int,
|
| 181 |
+
mlp_dim: int,
|
| 182 |
+
cond_dim: int = 0,
|
| 183 |
+
hidden_activation: str = "Tanh",
|
| 184 |
+
hidden_activation_kwargs: Dict = None,
|
| 185 |
+
complex_mask: bool = True,
|
| 186 |
+
norm_mlp_cls: Type[nn.Module] = NormMLP,
|
| 187 |
+
norm_mlp_kwargs: Dict = None,
|
| 188 |
+
use_freq_weights: bool = False,
|
| 189 |
+
) -> None:
|
| 190 |
+
check_nonzero_bandwidth(band_specs)
|
| 191 |
+
check_no_gap(band_specs)
|
| 192 |
+
|
| 193 |
+
if cond_dim > 0:
|
| 194 |
+
raise NotImplementedError
|
| 195 |
+
|
| 196 |
+
super().__init__(
|
| 197 |
+
band_specs=band_specs,
|
| 198 |
+
emb_dim=emb_dim + cond_dim,
|
| 199 |
+
mlp_dim=mlp_dim,
|
| 200 |
+
in_channels=in_channels,
|
| 201 |
+
hidden_activation=hidden_activation,
|
| 202 |
+
hidden_activation_kwargs=hidden_activation_kwargs,
|
| 203 |
+
complex_mask=complex_mask,
|
| 204 |
+
norm_mlp_cls=norm_mlp_cls,
|
| 205 |
+
norm_mlp_kwargs=norm_mlp_kwargs,
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
self.n_freq = n_freq
|
| 209 |
+
self.band_specs = band_specs
|
| 210 |
+
self.in_channels = in_channels
|
| 211 |
+
|
| 212 |
+
if freq_weights is not None and use_freq_weights:
|
| 213 |
+
for i, fw in enumerate(freq_weights):
|
| 214 |
+
self.register_buffer(f"freq_weights/{i}", fw)
|
| 215 |
+
|
| 216 |
+
self.use_freq_weights = use_freq_weights
|
| 217 |
+
else:
|
| 218 |
+
self.use_freq_weights = False
|
| 219 |
+
|
| 220 |
+
def forward(self, q):
|
| 221 |
+
# q = (batch, n_bands, n_time, emb_dim)
|
| 222 |
+
|
| 223 |
+
batch, n_bands, n_time, emb_dim = q.shape
|
| 224 |
+
|
| 225 |
+
masks = torch.zeros(
|
| 226 |
+
(batch, self.in_channels, self.n_freq, n_time),
|
| 227 |
+
device=q.device,
|
| 228 |
+
dtype=torch.complex64,
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
for im in range(n_bands):
|
| 232 |
+
fstart, fend = self.band_specs[im]
|
| 233 |
+
|
| 234 |
+
mask = self.compute_mask(q, im)
|
| 235 |
+
|
| 236 |
+
if self.use_freq_weights:
|
| 237 |
+
fw = self.get_buffer(f"freq_weights/{im}")[:, None]
|
| 238 |
+
mask = mask * fw
|
| 239 |
+
masks[:, :, fstart:fend, :] += mask
|
| 240 |
+
|
| 241 |
+
return masks
|
| 242 |
+
|
| 243 |
+
|
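# The forward pass above scatters each band's mask back onto the full
# frequency axis; where bands overlap, contributions are summed bin-wise
# (hence the normalized freq_weights). A toy illustration of that
# accumulation, with made-up band specs (not from the original file):
#
#     band_specs = [(0, 6), (4, 10)]   # bins 4-5 belong to both bands
#     masks = torch.zeros(1, 1, 10, 3, dtype=torch.complex64)
#     for fstart, fend in band_specs:
#         masks[:, :, fstart:fend, :] += torch.ones(
#             1, 1, fend - fstart, 3, dtype=torch.complex64
#         )
#     print(masks[0, 0, :, 0].real)    # 1.0 outside the overlap, 2.0 inside it
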
class MaskEstimationModule(OverlappingMaskEstimationModule):
    def __init__(
        self,
        band_specs: List[Tuple[float, float]],
        emb_dim: int,
        mlp_dim: int,
        in_channels: Optional[int],
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        **kwargs,
    ) -> None:
        check_nonzero_bandwidth(band_specs)
        check_no_gap(band_specs)
        check_no_overlap(band_specs)
        super().__init__(
            in_channels=in_channels,
            band_specs=band_specs,
            freq_weights=None,
            n_freq=None,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
        )

    def forward(self, q, cond=None):
        # q = (batch, n_bands, n_time, emb_dim)

        masks = self.compute_masks(
            q
        )  # [n_bands * (batch, in_channels, bandwidth, n_time)]

        # TODO: currently this requires band specs to have no gap and no overlap
        masks = torch.concat(masks, dim=2)  # (batch, in_channels, n_freq, n_time)

        return masks
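# End-to-end sketch for MaskEstimationModule (illustrative only; values are
# made up): contiguous, non-overlapping band specs are concatenated back into
# one full-band complex mask.
#
#     band_specs = [(0, 256), (256, 640), (640, 1025)]   # 1025 bins total
#     module = MaskEstimationModule(
#         band_specs=band_specs, emb_dim=128, mlp_dim=256, in_channels=2
#     )
#     q = torch.randn(1, 3, 50, 128)   # (batch, n_bands, n_time, emb_dim)
#     print(module(q).shape)           # torch.Size([1, 2, 1025, 50])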
separator/models/bandit_v2/tfmodel.py
ADDED
@@ -0,0 +1,145 @@
import warnings

import torch
import torch.backends.cuda
from torch import nn
from torch.nn.modules import rnn
from torch.utils.checkpoint import checkpoint_sequential


class TimeFrequencyModellingModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()

class ResidualRNN(nn.Module):
    def __init__(
        self,
        emb_dim: int,
        rnn_dim: int,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        use_batch_trick: bool = True,
        use_layer_norm: bool = True,
    ) -> None:
        # n_group is the size of the 2nd dim
        super().__init__()

        # only the layer-norm + batch-trick configuration is supported here
        assert use_layer_norm
        assert use_batch_trick

        self.use_layer_norm = use_layer_norm
        self.norm = nn.LayerNorm(emb_dim)
        # look up the RNN class (e.g. nn.LSTM, nn.GRU) by name
        self.rnn = rnn.__dict__[rnn_type](
            input_size=emb_dim,
            hidden_size=rnn_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=bidirectional,
        )

        self.fc = nn.Linear(
            in_features=rnn_dim * (2 if bidirectional else 1), out_features=emb_dim
        )

        self.use_batch_trick = use_batch_trick
        if not self.use_batch_trick:
            warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")

    def forward(self, z):
        # z = (batch, n_uncrossed, n_across, emb_dim)

        z0 = torch.clone(z)
        z = self.norm(z)

        # batch trick: fold the uncrossed axis into the batch axis so the RNN
        # runs over all rows in a single call instead of a Python loop
        batch, n_uncrossed, n_across, emb_dim = z.shape
        z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
        z = self.rnn(z)[0]
        z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))

        z = self.fc(z)  # (batch, n_uncrossed, n_across, emb_dim)

        z = z + z0  # residual connection

        return z

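# The reshape-based "batch trick" above is equivalent to looping the RNN over
# each row of the uncrossed axis, just in a single call. A quick equivalence
# check with made-up sizes (illustrative only, not part of the original file):
#
#     lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True,
#                    bidirectional=True)
#     z = torch.randn(2, 3, 5, 8)      # (batch, n_uncrossed, n_across, emb_dim)
#     folded = lstm(z.reshape(6, 5, 8))[0].reshape(2, 3, 5, -1)
#     looped = torch.stack([lstm(z[:, i])[0] for i in range(3)], dim=1)
#     print(torch.allclose(folded, looped, atol=1e-6))   # True
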
class Transpose(nn.Module):
    def __init__(self, dim0: int, dim1: int) -> None:
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, z):
        return z.transpose(self.dim0, self.dim1)


class SeqBandModellingModule(TimeFrequencyModellingModule):
    def __init__(
        self,
        n_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        parallel_mode: bool = False,
    ) -> None:
        super().__init__()

        self.n_modules = n_modules

        if parallel_mode:
            # each entry holds a (time RNN, band RNN) pair applied to the same input
            self.seqband = nn.ModuleList([])
            for _ in range(n_modules):
                self.seqband.append(
                    nn.ModuleList(
                        [
                            ResidualRNN(
                                emb_dim=emb_dim,
                                rnn_dim=rnn_dim,
                                bidirectional=bidirectional,
                                rnn_type=rnn_type,
                            ),
                            ResidualRNN(
                                emb_dim=emb_dim,
                                rnn_dim=rnn_dim,
                                bidirectional=bidirectional,
                                rnn_type=rnn_type,
                            ),
                        ]
                    )
                )
        else:
            # sequential mode: alternate RNN passes across time and across bands,
            # swapping the two axes in between
            seqband = []
            for _ in range(2 * n_modules):
                seqband += [
                    ResidualRNN(
                        emb_dim=emb_dim,
                        rnn_dim=rnn_dim,
                        bidirectional=bidirectional,
                        rnn_type=rnn_type,
                    ),
                    Transpose(1, 2),
                ]

            self.seqband = nn.Sequential(*seqband)

        self.parallel_mode = parallel_mode

    def forward(self, z):
        # z = (batch, n_bands, n_time, emb_dim)

        if self.parallel_mode:
            for sbm_pair in self.seqband:
                # z: (batch, n_bands, n_time, emb_dim)
                sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
                zt = sbm_t(z)  # (batch, n_bands, n_time, emb_dim)
                zf = sbm_f(z.transpose(1, 2))  # (batch, n_time, n_bands, emb_dim)
                z = zt + zf.transpose(1, 2)
        else:
            z = checkpoint_sequential(
                self.seqband, self.n_modules, z, use_reentrant=False
            )

        q = z
        return q  # (batch, n_bands, n_time, emb_dim)
separator/models/bandit_v2/utils.py
ADDED
@@ -0,0 +1,523 @@
import os
from abc import abstractmethod
from typing import Callable, Optional

import numpy as np
import torch
from librosa import hz_to_midi, midi_to_hz
from torchaudio import functional as taF

# from spafe.fbanks import bark_fbanks
# from spafe.utils.converters import erb2hz, hz2bark, hz2erb

def band_widths_from_specs(band_specs):
    # each spec is an (inclusive start, exclusive end) pair of bin indices
    return [e - i for i, e in band_specs]


def check_nonzero_bandwidth(band_specs):
    for fstart, fend in band_specs:
        if fend - fstart <= 0:
            raise ValueError("Bands cannot be zero-width")


def check_no_overlap(band_specs):
    fend_prev = -1
    for fstart_curr, fend_curr in band_specs:
        # bands are end-exclusive, so adjacent bands may share a boundary index
        if fstart_curr < fend_prev:
            raise ValueError("Bands cannot overlap")
        fend_prev = fend_curr


def check_no_gap(band_specs):
    fstart, _ = band_specs[0]
    assert fstart == 0

    fend_prev = -1
    for fstart_curr, fend_curr in band_specs:
        if fstart_curr - fend_prev > 1:
            raise ValueError("Bands cannot leave gap")
        fend_prev = fend_curr

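# Quick sanity check of the validators on toy specs (illustrative only, not
# part of the original file):
#
#     specs = [(0, 5), (5, 9), (9, 12)]   # contiguous, end-exclusive
#     check_nonzero_bandwidth(specs)      # passes
#     check_no_gap(specs)                 # passes: each band starts where the last ended
#     check_no_overlap(specs)             # passes: shared boundaries are not overlaps
#     check_no_overlap([(0, 5), (3, 9)])  # raises ValueError("Bands cannot overlap")
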
+
class BandsplitSpecification:
|
| 44 |
+
def __init__(self, nfft: int, fs: int) -> None:
|
| 45 |
+
self.fs = fs
|
| 46 |
+
self.nfft = nfft
|
| 47 |
+
self.nyquist = fs / 2
|
| 48 |
+
self.max_index = nfft // 2 + 1
|
| 49 |
+
|
| 50 |
+
self.split500 = self.hertz_to_index(500)
|
| 51 |
+
self.split1k = self.hertz_to_index(1000)
|
| 52 |
+
self.split2k = self.hertz_to_index(2000)
|
| 53 |
+
self.split4k = self.hertz_to_index(4000)
|
| 54 |
+
self.split8k = self.hertz_to_index(8000)
|
| 55 |
+
self.split16k = self.hertz_to_index(16000)
|
| 56 |
+
self.split20k = self.hertz_to_index(20000)
|
| 57 |
+
|
| 58 |
+
self.above20k = [(self.split20k, self.max_index)]
|
| 59 |
+
self.above16k = [(self.split16k, self.split20k)] + self.above20k
|
| 60 |
+
|
| 61 |
+
def index_to_hertz(self, index: int):
|
| 62 |
+
return index * self.fs / self.nfft
|
| 63 |
+
|
| 64 |
+
def hertz_to_index(self, hz: float, round: bool = True):
|
| 65 |
+
index = hz * self.nfft / self.fs
|
| 66 |
+
|
| 67 |
+
if round:
|
| 68 |
+
index = int(np.round(index))
|
| 69 |
+
|
| 70 |
+
return index
|
| 71 |
+
|
| 72 |
+
def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
|
| 73 |
+
band_specs = []
|
| 74 |
+
lower = start_index
|
| 75 |
+
|
| 76 |
+
while lower < end_index:
|
| 77 |
+
upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz)))
|
| 78 |
+
upper = min(upper, end_index)
|
| 79 |
+
|
| 80 |
+
band_specs.append((lower, upper))
|
| 81 |
+
lower = upper
|
| 82 |
+
|
| 83 |
+
return band_specs
|
| 84 |
+
|
| 85 |
+
@abstractmethod
|
| 86 |
+
def get_band_specs(self):
|
| 87 |
+
raise NotImplementedError
|
| 88 |
+
|
| 89 |
+
|
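# For example, with nfft=2048 and fs=44100, hertz_to_index(1000) is
# round(1000 * 2048 / 44100) = 46 bins, so tiling up to 16 kHz in ~1 kHz bands
# gives (illustrative only; the base class is instantiable even though
# get_band_specs is abstract):
#
#     spec = BandsplitSpecification(nfft=2048, fs=44100)
#     bands = spec.get_band_specs_with_bandwidth(0, spec.split16k, bandwidth_hz=1000)
#     print(bands[:3])    # [(0, 46), (46, 92), (92, 138)]
#     print(bands[-1])    # (736, 743) -- the last band is clipped at split16k
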
class VocalBandsplitSpecification(BandsplitSpecification):
    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        super().__init__(nfft=nfft, fs=fs)

        self.version = version

    def get_band_specs(self):
        # dispatch to the versionN() method selected at construction time
        return getattr(self, f"version{self.version}")()

    def version1(self):
        # a single uniform 1 kHz grid over the whole spectrum
        return self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.max_index, bandwidth_hz=1000
        )

    def version2(self):
        below16k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )

        return below16k + below20k + self.above20k

    def version3(self):
        below8k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )

        return below8k + below16k + self.above16k

    def version4(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )

        return below1k + below8k + below16k + self.above16k

    def version5(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )
        return below1k + below16k + below20k + self.above20k

    def version6(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )
        return below1k + below4k + below8k + below16k + self.above16k

    def version7(self):
        # progressively wider bands: 100 Hz below 1 kHz up to 2 kHz below 20 kHz
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
            start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
        )
        return below1k + below4k + below8k + below16k + below20k + self.above20k


class OtherBandsplitSpecification(VocalBandsplitSpecification):
    def __init__(self, nfft: int, fs: int) -> None:
        super().__init__(nfft=nfft, fs=fs, version="7")

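# Usage sketch (illustrative only): the default version-7 vocal split tiles
# the whole spectrum with progressively wider bands and always ends at
# max_index:
#
#     spec = VocalBandsplitSpecification(nfft=2048, fs=44100)   # version="7"
#     bands = spec.get_band_specs()
#     print(bands[0], bands[-1])   # (0, 5) ... (929, 1025)
#     assert bands[-1][1] == spec.max_index
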
class BassBandsplitSpecification(BandsplitSpecification):
    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        below500 = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split500, bandwidth_hz=50
        )
        below1k = self.get_band_specs_with_bandwidth(
            start_index=self.split500, end_index=self.split1k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
        )
        above16k = [(self.split16k, self.max_index)]

        return below500 + below1k + below4k + below8k + below16k + above16k


class DrumBandsplitSpecification(BandsplitSpecification):
    def __init__(self, nfft: int, fs: int) -> None:
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        below1k = self.get_band_specs_with_bandwidth(
            start_index=0, end_index=self.split1k, bandwidth_hz=50
        )
        below2k = self.get_band_specs_with_bandwidth(
            start_index=self.split1k, end_index=self.split2k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
            start_index=self.split2k, end_index=self.split4k, bandwidth_hz=250
        )
        below8k = self.get_band_specs_with_bandwidth(
            start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
        )
        below16k = self.get_band_specs_with_bandwidth(
            start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
        )
        above16k = [(self.split16k, self.max_index)]

        return below1k + below2k + below4k + below8k + below16k + above16k

class PerceptualBandsplitSpecification(BandsplitSpecification):
    def __init__(
        self,
        nfft: int,
        fs: int,
        fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
        n_bands: int,
        f_min: float = 0.0,
        f_max: Optional[float] = None,
    ) -> None:
        super().__init__(nfft=nfft, fs=fs)
        self.n_bands = n_bands
        if f_max is None:
            f_max = fs / 2

        self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)

        # normalize so the per-bin weights of overlapping bands sum to one
        weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True)  # (1, n_freqs)
        normalized_mel_fb = self.filterbank / weight_per_bin  # (n_mels, n_freqs)

        # each band covers the contiguous bin range where its filter is nonzero
        freq_weights = []
        band_specs = []
        for i in range(self.n_bands):
            active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
            if isinstance(active_bins, int):
                active_bins = (active_bins, active_bins)
            if len(active_bins) == 0:
                continue
            start_index = active_bins[0]
            end_index = active_bins[-1] + 1
            band_specs.append((start_index, end_index))
            freq_weights.append(normalized_mel_fb[i, start_index:end_index])

        self.freq_weights = freq_weights
        self.band_specs = band_specs

    def get_band_specs(self):
        return self.band_specs

    def get_freq_weights(self):
        return self.freq_weights

    def save_to_file(self, dir_path: str) -> None:
        os.makedirs(dir_path, exist_ok=True)

        import pickle

        with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
            pickle.dump(
                {
                    "band_specs": self.band_specs,
                    "freq_weights": self.freq_weights,
                    "filterbank": self.filterbank,
                },
                f,
            )

def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
    fb = taF.melscale_fbanks(
        n_mels=n_bands,
        sample_rate=fs,
        f_min=f_min,
        f_max=f_max,
        n_freqs=n_freqs,
    ).T  # transpose to (n_mels, n_freqs)

    # make sure the DC bin is covered by the first band
    fb[0, 0] = 1.0

    return fb


class MelBandsplitSpecification(PerceptualBandsplitSpecification):
    def __init__(
        self,
        nfft: int,
        fs: int,
        n_bands: int,
        f_min: float = 0.0,
        f_max: Optional[float] = None,
    ) -> None:
        super().__init__(
            fbank_fn=mel_filterbank,
            nfft=nfft,
            fs=fs,
            n_bands=n_bands,
            f_min=f_min,
            f_max=f_max,
        )

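# Sketch of how the perceptual spec turns a mel filterbank into overlapping
# band specs plus per-bin weights (illustrative only; sizes are made up):
#
#     spec = MelBandsplitSpecification(nfft=2048, fs=44100, n_bands=8)
#     bands = spec.get_band_specs()      # overlapping (start, end) bin ranges
#     weights = spec.get_freq_weights()  # one weight vector per band
#     print(len(bands), weights[0].shape)
#     # across overlapping bands, the per-bin weights sum to 1 by construction
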
def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
    nfft = 2 * (n_freqs - 1)
    df = fs / nfft
    # init freqs
    f_max = f_max or fs / 2
    f_min = fs / nfft  # the requested f_min is overridden: clamp to the first nonzero bin

    n_octaves = np.log2(f_max / f_min)
    n_octaves_per_band = n_octaves / n_bands
    bandwidth_mult = np.power(2.0, n_octaves_per_band)

    low_midi = max(0, hz_to_midi(f_min))
    high_midi = hz_to_midi(f_max)
    midi_points = np.linspace(low_midi, high_midi, n_bands)
    hz_pts = midi_to_hz(midi_points)

    # each band spans one band-spacing below to one above its center frequency
    low_pts = hz_pts / bandwidth_mult
    high_pts = hz_pts * bandwidth_mult

    low_bins = np.floor(low_pts / df).astype(int)
    high_bins = np.ceil(high_pts / df).astype(int)

    fb = np.zeros((n_bands, n_freqs))

    for i in range(n_bands):
        fb[i, low_bins[i] : high_bins[i] + 1] = 1.0

    # extend the outermost bands to cover the spectrum edges
    fb[0, : low_bins[0]] = 1.0
    fb[-1, high_bins[-1] + 1 :] = 1.0

    return torch.as_tensor(fb)


class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
    def __init__(
        self,
        nfft: int,
        fs: int,
        n_bands: int,
        f_min: float = 0.0,
        f_max: Optional[float] = None,
    ) -> None:
        super().__init__(
            fbank_fn=musical_filterbank,
            nfft=nfft,
            fs=fs,
            n_bands=n_bands,
            f_min=f_min,
            f_max=f_max,
        )

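# The musical bands are spaced uniformly on the MIDI (log-frequency) axis, so
# in linear bin terms the high bands are much wider than the low ones
# (illustrative only; sizes are made up):
#
#     fb = musical_filterbank(n_bands=12, fs=44100, f_min=0.0, f_max=None,
#                             n_freqs=1025)
#     print(fb.shape)        # torch.Size([12, 1025])
#     print(fb.sum(dim=1))   # active bins per band grow toward high frequencies
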
# def bark_filterbank(
#     n_bands, fs, f_min, f_max, n_freqs
# ):
#     nfft = 2 * (n_freqs - 1)
#     fb, _ = bark_fbanks.bark_filter_banks(
#         nfilts=n_bands,
#         nfft=nfft,
#         fs=fs,
#         low_freq=f_min,
#         high_freq=f_max,
#         scale="constant"
#     )
#
#     return torch.as_tensor(fb)

# class BarkBandsplitSpecification(PerceptualBandsplitSpecification):
#     def __init__(
#         self,
#         nfft: int,
#         fs: int,
#         n_bands: int,
#         f_min: float = 0.0,
#         f_max: float = None
#     ) -> None:
#         super().__init__(fbank_fn=bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)


# def triangular_bark_filterbank(
#     n_bands, fs, f_min, f_max, n_freqs
# ):
#     all_freqs = torch.linspace(0, fs // 2, n_freqs)
#
#     # calculate mel freq bins
#     m_min = hz2bark(f_min)
#     m_max = hz2bark(f_max)
#
#     m_pts = torch.linspace(m_min, m_max, n_bands + 2)
#     f_pts = 600 * torch.sinh(m_pts / 6)
#
#     # create filterbank
#     fb = _create_triangular_filterbank(all_freqs, f_pts)
#
#     fb = fb.T
#
#     first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
#     first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
#
#     fb[first_active_band, :first_active_bin] = 1.0
#
#     return fb

# class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification):
#     def __init__(
#         self,
#         nfft: int,
#         fs: int,
#         n_bands: int,
#         f_min: float = 0.0,
#         f_max: float = None
#     ) -> None:
#         super().__init__(fbank_fn=triangular_bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)


# def minibark_filterbank(
#     n_bands, fs, f_min, f_max, n_freqs
# ):
#     fb = bark_filterbank(
#         n_bands,
#         fs,
#         f_min,
#         f_max,
#         n_freqs
#     )
#
#     fb[fb < np.sqrt(0.5)] = 0.0
#
#     return fb

# class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification):
#     def __init__(
#         self,
#         nfft: int,
#         fs: int,
#         n_bands: int,
#         f_min: float = 0.0,
#         f_max: float = None
#     ) -> None:
#         super().__init__(fbank_fn=minibark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)


# def erb_filterbank(
#     n_bands: int,
#     fs: int,
#     f_min: float,
#     f_max: float,
#     n_freqs: int,
# ) -> Tensor:
#     # freq bins
#     A = (1000 * np.log(10)) / (24.7 * 4.37)
#     all_freqs = torch.linspace(0, fs // 2, n_freqs)
#
#     # calculate mel freq bins
#     m_min = hz2erb(f_min)
#     m_max = hz2erb(f_max)
#
#     m_pts = torch.linspace(m_min, m_max, n_bands + 2)
#     f_pts = (torch.pow(10, (m_pts / A)) - 1) / 0.00437
#
#     # create filterbank
#     fb = _create_triangular_filterbank(all_freqs, f_pts)
#
#     fb = fb.T
#
#     first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
#     first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
#
#     fb[first_active_band, :first_active_bin] = 1.0
#
#     return fb


# class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification):
#     def __init__(
#         self,
#         nfft: int,
#         fs: int,
#         n_bands: int,
#         f_min: float = 0.0,
#         f_max: float = None
#     ) -> None:
#         super().__init__(fbank_fn=erb_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)

if __name__ == "__main__":
    import pandas as pd

    band_defs = []

    for bands in [VocalBandsplitSpecification]:
        band_name = bands.__name__.replace("BandsplitSpecification", "")

        mbs = bands(nfft=2048, fs=44100).get_band_specs()

        # note: despite the names, f_min/f_max here are STFT bin indices
        for i, (f_min, f_max) in enumerate(mbs):
            band_defs.append(
                {"band": band_name, "band_index": i, "f_min": f_min, "f_max": f_max}
            )

    df = pd.DataFrame(band_defs)
    df.to_csv("vox7bands.csv", index=False)
separator/models/bs_roformer/__init__.py
ADDED
@@ -0,0 +1,2 @@
from models.bs_roformer.bs_roformer import BSRoformer
from models.bs_roformer.mel_band_roformer import MelBandRoformer
separator/models/bs_roformer/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (281 Bytes).