noblebarkrr committed on
Commit
d0cd3b0
·
1 Parent(s): 3ccdc25
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. app.py +131 -0
  2. assets/translations.py +171 -0
  3. model_list.py +0 -0
  4. multi_inference.py +303 -0
  5. requirements.txt +50 -0
  6. separator/audio_writer.py +85 -0
  7. separator/ensemble.py +192 -0
  8. separator/models/bandit/core/__init__.py +744 -0
  9. separator/models/bandit/core/data/__init__.py +2 -0
  10. separator/models/bandit/core/data/_types.py +18 -0
  11. separator/models/bandit/core/data/augmentation.py +107 -0
  12. separator/models/bandit/core/data/augmented.py +35 -0
  13. separator/models/bandit/core/data/base.py +69 -0
  14. separator/models/bandit/core/data/dnr/__init__.py +0 -0
  15. separator/models/bandit/core/data/dnr/datamodule.py +74 -0
  16. separator/models/bandit/core/data/dnr/dataset.py +392 -0
  17. separator/models/bandit/core/data/dnr/preprocess.py +54 -0
  18. separator/models/bandit/core/data/musdb/__init__.py +0 -0
  19. separator/models/bandit/core/data/musdb/datamodule.py +77 -0
  20. separator/models/bandit/core/data/musdb/dataset.py +280 -0
  21. separator/models/bandit/core/data/musdb/preprocess.py +238 -0
  22. separator/models/bandit/core/data/musdb/validation.yaml +15 -0
  23. separator/models/bandit/core/loss/__init__.py +2 -0
  24. separator/models/bandit/core/loss/_complex.py +34 -0
  25. separator/models/bandit/core/loss/_multistem.py +45 -0
  26. separator/models/bandit/core/loss/_timefreq.py +113 -0
  27. separator/models/bandit/core/loss/snr.py +146 -0
  28. separator/models/bandit/core/metrics/__init__.py +9 -0
  29. separator/models/bandit/core/metrics/_squim.py +383 -0
  30. separator/models/bandit/core/metrics/snr.py +150 -0
  31. separator/models/bandit/core/model/__init__.py +3 -0
  32. separator/models/bandit/core/model/_spectral.py +58 -0
  33. separator/models/bandit/core/model/bsrnn/__init__.py +23 -0
  34. separator/models/bandit/core/model/bsrnn/bandsplit.py +139 -0
  35. separator/models/bandit/core/model/bsrnn/core.py +661 -0
  36. separator/models/bandit/core/model/bsrnn/maskestim.py +347 -0
  37. separator/models/bandit/core/model/bsrnn/tfmodel.py +317 -0
  38. separator/models/bandit/core/model/bsrnn/utils.py +583 -0
  39. separator/models/bandit/core/model/bsrnn/wrapper.py +882 -0
  40. separator/models/bandit/core/utils/__init__.py +0 -0
  41. separator/models/bandit/core/utils/audio.py +463 -0
  42. separator/models/bandit/model_from_config.py +31 -0
  43. separator/models/bandit_v2/bandit.py +367 -0
  44. separator/models/bandit_v2/bandsplit.py +130 -0
  45. separator/models/bandit_v2/film.py +25 -0
  46. separator/models/bandit_v2/maskestim.py +281 -0
  47. separator/models/bandit_v2/tfmodel.py +145 -0
  48. separator/models/bandit_v2/utils.py +523 -0
  49. separator/models/bs_roformer/__init__.py +2 -0
  50. separator/models/bs_roformer/__pycache__/__init__.cpython-310.pyc +0 -0
app.py ADDED
@@ -0,0 +1,131 @@
1
+ import os
2
+ import sys
3
+ import time
4
+ import shutil
5
+ import argparse
6
+ from datetime import datetime
7
+ import gradio as gr
8
+ os.system("pip install https://github.com/noblebarkrr/mvsepless/blob/bd611441e48e918650e6860738894673b3a1a5f1/fixed/audio_separator-0.32.0-py3-none-any.whl")
9
+ from multi_inference import MVSEPLESS, OUTPUT_FORMATS
10
+ from assets.translations import TRANSLATIONS, TRANSLATIONS_STEMS
11
+
12
+ OUTPUT_DIR = os.path.join(os.getcwd(), "output")
13
+ plugins_dir = os.path.join(os.getcwd(), "plugins")
14
+ os.makedirs(plugins_dir, exist_ok=True)
15
+
16
+ CURRENT_LANG = "ru"
17
+
18
+ def t(key, **kwargs):
19
+ """Функция для получения перевода с подстановкой значений"""
20
+ lang = CURRENT_LANG
21
+ translation = TRANSLATIONS.get(lang, {}).get(key, key)
22
+ return translation.format(**kwargs) if kwargs else translation
23
+
24
+ def t_stem(key, **kwargs):
25
+ """Функция для получения перевода с подстановкой значений"""
26
+ lang = CURRENT_LANG
27
+ translation = TRANSLATIONS_STEMS.get(lang, {}).get(key, key)
28
+ return translation.format(**kwargs) if kwargs else translation
29
+
30
+ def gen_out_dir():
31
+ return os.path.join(OUTPUT_DIR, datetime.now().strftime("%Y%m%d_%H%M%S"))
32
+
33
+ mvsepless = MVSEPLESS()
34
+
35
+ def sep_wrapper(input_file, model_type, model_name, ext_inst, vr_aggr, output_format, output_bitrate, selected_stems):
36
+ results = mvsepless.separator(input_file=input_file, output_dir=gen_out_dir(), model_type=model_type, model_name=model_name, ext_inst=ext_inst, vr_aggr=vr_aggr, output_format=output_format, output_bitrate=f'{output_bitrate}k', call_method="cli", selected_stems=selected_stems)
37
+ stems = []
38
+ if results:
39
+ for i, (stem, output_file) in enumerate(results[:20]):
40
+ stems.append(gr.update(
41
+ visible=True,
42
+ label=t_stem(stem),
43
+ value=output_file
44
+ ))
45
+
46
+ while len(stems) < 20:
47
+ stems.append(gr.update(visible=False, label=None, value=None))
48
+
49
+ return tuple(stems)
50
+
51
+
52
+ theme = gr.themes.Default(
53
+ primary_hue="violet",
54
+ secondary_hue="cyan",
55
+ neutral_hue="blue",
56
+ spacing_size="sm",
57
+ font=[gr.themes.GoogleFont("Tektur"), 'ui-sans-serif', 'system-ui', 'sans-serif'],
58
+ ).set(
59
+ body_text_color='*neutral_950',
60
+ body_text_color_subdued='*neutral_500',
61
+ background_fill_primary='*neutral_200',
62
+ background_fill_primary_dark='*neutral_800',
63
+ border_color_accent='*primary_950',
64
+ border_color_accent_dark='*neutral_700',
65
+ border_color_accent_subdued='*primary_500',
66
+ border_color_primary='*primary_800',
67
+ border_color_primary_dark='*neutral_400',
68
+ color_accent_soft='*primary_100',
69
+ color_accent_soft_dark='*neutral_800',
70
+ link_text_color='*secondary_700',
71
+ link_text_color_active='*secondary_700',
72
+ link_text_color_hover='*secondary_800',
73
+ link_text_color_visited='*secondary_600',
74
+ link_text_color_visited_dark='*secondary_700',
75
+ block_background_fill='*background_fill_secondary',
76
+ block_background_fill_dark='*neutral_950',
77
+ block_label_background_fill='*secondary_400',
78
+ block_label_text_color='*neutral_800',
79
+ panel_background_fill='*background_fill_primary',
80
+ checkbox_background_color='*background_fill_secondary',
81
+ checkbox_label_background_fill_dark='*neutral_900',
82
+ input_background_fill_dark='*neutral_900',
83
+ input_background_fill_focus='*neutral_100',
84
+ input_background_fill_focus_dark='*neutral_950',
85
+ button_small_radius='*radius_sm',
86
+ button_secondary_background_fill='*neutral_400',
87
+ button_secondary_background_fill_dark='*neutral_500',
88
+ button_secondary_background_fill_hover_dark='*neutral_950'
89
+ )
90
+
91
+
92
+ def create_app():
93
+ with gr.Row():
94
+ with gr.Column():
95
+ input_audio = gr.Audio(label=t("select_file"), interactive=True, type="filepath")
96
+ input_audio_path = gr.Textbox(label=t("audio_path"), info=t("audio_path_info"), interactive=True)
97
+ with gr.Column():
98
+ with gr.Row():
99
+ model_type = gr.Dropdown(label=t("model_type"), choices=mvsepless.get_mt(), value=mvsepless.get_mt()[0], interactive=True, filterable=False)
100
+ model_name = gr.Dropdown(label=t("model_name"), choices=mvsepless.get_mn(mvsepless.get_mt()[0]), value=mvsepless.get_mn(mvsepless.get_mt()[0])[0], interactive=True, filterable=False)
101
+ target_instrument = gr.Textbox(label=t("target_instrument"), value=mvsepless.get_tgt_inst(mvsepless.get_mt()[0], mvsepless.get_mn(mvsepless.get_mt()[0])[0]), interactive=False)
102
+ vr_aggr = gr.Slider(0, 100, step=1, label=t("vr_aggressiveness"), visible=False, value=5, interactive=True)
103
+ extract_instrumental = gr.Checkbox(label=t("extract_instrumental"), value=True, interactive=True)
104
+ stems_list = gr.CheckboxGroup(label=t("stems_list"), info=t("stems_info", target_instrument="vocals"), choices=mvsepless.get_stems(mvsepless.get_mt()[0], mvsepless.get_mn(mvsepless.get_mt()[0])[0]), value=None, interactive=False)
105
+ with gr.Row():
106
+ output_format, output_bitrate = gr.Dropdown(label=t("output_format"), choices=OUTPUT_FORMATS, value="mp3", interactive=True, filterable=False), gr.Slider(32, 320, step=1, label=t("bitrate"), value=320, interactive=True)
107
+ separate_btn = gr.Button(t("separate_btn"), variant="primary", interactive=True)
108
+ download_via_zip_btn = gr.DownloadButton(label="Download via zip", visible=False, interactive=True)
109
+ output_stems = []
110
+ for _ in range(10):
111
+ with gr.Row():
112
+ audio1 = gr.Audio(visible=False, interactive=False, type="filepath", show_download_button=True)
113
+ audio2 = gr.Audio(visible=False, interactive=False, type="filepath", show_download_button=True)
114
+ output_stems.extend([audio1, audio2])
115
+
116
+ input_audio.upload(fn=(lambda x: gr.update(value=x)), inputs=input_audio, outputs=input_audio_path)
117
+ model_type.change(fn=(lambda x: gr.update(choices=mvsepless.get_mn(x), value=mvsepless.get_mn(x)[0])), inputs=model_type, outputs=model_name).then(fn=(lambda x: (gr.update(visible=False if x in ["vr", "mdx"] else True), gr.update(visible=True if x == "vr" else False))), inputs=model_type, outputs=[extract_instrumental, vr_aggr])
118
+ model_name.change(fn=(lambda x, y: gr.update(choices=mvsepless.get_stems(x, y), value=None)), inputs=[model_type, model_name], outputs=stems_list).then(fn=(lambda x, y: (gr.update(interactive=True if mvsepless.get_tgt_inst(x, y) == None else None, info=t("stems_info", target_instrument=mvsepless.get_tgt_inst(x, y)) if mvsepless.get_tgt_inst(x, y) is not None else t("stems_info2")), gr.update(value=mvsepless.get_tgt_inst(x, y)), gr.update(value=True if mvsepless.get_tgt_inst(x, y) is not None else False))), inputs=[model_type, model_name], outputs=[stems_list, target_instrument, extract_instrumental])
119
+ separate_btn.click(fn=sep_wrapper, inputs=[input_audio_path, model_type, model_name, extract_instrumental, vr_aggr, output_format, output_bitrate, stems_list], outputs=output_stems, show_progress_on=input_audio)
120
+
121
+
122
+ CURRENT_LANG = "ru"
123
+ css = """
124
+ .fixed-height { height: 160px !important; min-height: 160px !important; }
125
+ .fixed-height2 { height: 250px !important; min-height: 250px !important; }
126
+ """
127
+
128
+ with gr.Blocks(theme=theme, css=css) as app:
129
+ create_app()
130
+
131
+ app.launch(allowed_paths=["/"], server_port=7860, share=False)
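app.py wires the Gradio UI to MVSEPLESS.separator(), which returns a list of (stem, path) pairs that sep_wrapper maps onto up to 20 audio outputs. Below is a minimal sketch (not part of this commit) of driving the same backend without the UI; the input path is a placeholder, and it assumes the repository root is the working directory so that multi_inference is importable.

import os
from datetime import datetime

from multi_inference import MVSEPLESS

mvsepless = MVSEPLESS()
out_dir = os.path.join(os.getcwd(), "output", datetime.now().strftime("%Y%m%d_%H%M%S"))

# First available model type/name from model_list.py; results are (stem, path) tuples.
model_type = mvsepless.get_mt()[0]
model_name = mvsepless.get_mn(model_type)[0]

results = mvsepless.separator(
    input_file="song.wav",        # placeholder input path
    output_dir=out_dir,
    model_type=model_type,
    model_name=model_name,
    ext_inst=True,
    output_format="mp3",
    output_bitrate="320k",
    call_method="cli",
)
for stem, path in results:
    print(stem, "->", path)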
assets/translations.py ADDED
@@ -0,0 +1,171 @@
1
+ TRANSLATIONS = {
2
+ "ru": {
3
+ "app_title": "MVSEPLESS",
4
+ "separation": "Разделение",
5
+ "plugins": "Плагины",
6
+ "select_file": "Выберите файл",
7
+ "audio_path": "Путь к файлу",
8
+ "audio_path_info": "Здесь можно ввести путь к файлу, либо загрузить его выше и получить путь к загруженному файлу",
9
+ "model_type": "Тип модели",
10
+ "model_name": "Имя модели",
11
+ "vr_aggressiveness": "Агрессивность для VR моделей",
12
+ "extract_instrumental": "Извлечь инструментал",
13
+ "stems_list": "Список стемов",
14
+ "output_format": "Формат вывода",
15
+ "separate_btn": "Разделить",
16
+ "upload": "Загрузка плагинов (.py)",
17
+ "upload_btn": "Загрузить",
18
+ "loading_plugin": "Загружается плагин: {name}",
19
+ "error_loading_plugin": "Произошла ошибка при загрузке плагина: {e}",
20
+ "target_instrument": "Целевой инструмент",
21
+ "stems_info": "Выбор стемов недоступен\nДля извлечения второго стема включите \"Извлечь инструментал\"",
22
+ "stems_info2": "Для получения остатка (при выбранных стемах), включите \"Извлечь инструментал\"",
23
+ "bitrate": "Битрейт (Кбит/сек)"
24
+ },
25
+ "en": {
26
+ "app_title": "MVSEPLESS",
27
+ "separation": "Separation",
28
+ "plugins": "Plugins",
29
+ "select_file": "Select File",
30
+ "audio_path": "Audio path",
31
+ "audio_path_info": "You can enter the file path here, or upload it above and get the path to the uploaded file.",
32
+ "model_type": "Model Type",
33
+ "model_name": "Model Name",
34
+ "vr_aggressiveness": "Aggressiveness for VR Models",
35
+ "extract_instrumental": "Extract Instrumental",
36
+ "stems_list": "Stems List",
37
+ "output_format": "Output Format",
38
+ "separate_btn": "Separate",
39
+ "upload": "Upload plugins (.py)",
40
+ "upload_btn": "Upload",
41
+ "loading_plugin": "Loading plugin: {name}",
42
+ "error_loading_plugin": "As error occured loading plugin: {e}",
43
+ "target_instrument": "Target instrument",
44
+ "stems_info": "Stem selection unavailable\nEnable \"Extract Instrumental\" to extract the second stem",
45
+ "stems_info2": "To extract the residual (with selected_stems), enable \"Extract Instrumental\"",
46
+ "bitrate": "Bitrate (Kbit/sec)"
47
+ }
48
+ }
49
+
50
+ TRANSLATIONS_STEMS = {
51
+ "ru": {
52
+ "vocals": "Вокал",
53
+ "Vocals": "Вокал",
54
+ "other": "Другое",
55
+ "Other": "Другое",
56
+ "Instrumental": "Инструментал",
57
+ "instrumnetal": "Инструментал",
58
+ "instrumental +": "Инструментал +",
59
+ "instrumental -": "Инструментал -",
60
+ "Bleed": "Фон",
61
+ "Guitar": "Гитара",
62
+ "drums": "Барабаны",
63
+ "bass": "Бас",
64
+ "karaoke": "Караоке",
65
+ "reverb": "Реверберация",
66
+ "noreverb": "Без реверберации",
67
+ "aspiration": "Придыхание",
68
+ "dry": "Сухой звук",
69
+ "crowd": "Толпа",
70
+ "percussions": "Перкуссия",
71
+ "piano": "Пианино",
72
+ "guitar": "Гитара",
73
+ "male": "Мужской",
74
+ "female": "Женский",
75
+ "kick": "Кик",
76
+ "snare": "Малый барабан",
77
+ "toms": "Том-томы",
78
+ "hh": "Хай-хэт",
79
+ "ride": "Райд",
80
+ "crash": "Крэш",
81
+ "similarity": "Сходство",
82
+ "difference": "Различие",
83
+ "inst": "Инструмент",
84
+ "orch": "Оркестр",
85
+ "No Woodwinds": "Без деревянных духовых",
86
+ "Woodwinds": "Деревянные духовые",
87
+ "No Echo": "Без эха",
88
+ "Echo": "Эхо",
89
+ "No Reverb": "Без реверберации",
90
+ "Reverb": "Реверберация",
91
+ "Noise": "Шум",
92
+ "No Noise": "Без шума",
93
+ "Dry": "Сухой звук",
94
+ "No Dry": "Не сухой звук",
95
+ "Breath": "Дыхание",
96
+ "No Breath": "Без дыхания",
97
+ "No Crowd": "Без толпы",
98
+ "Crowd": "Толпа",
99
+ "No Other": "Без другого",
100
+ "Bass": "Бас",
101
+ "No Bass": "Без баса",
102
+ "Drums": "Барабаны",
103
+ "No Drums": "Без барабанов",
104
+ "speech": "Речь",
105
+ "music": "Музыка",
106
+ "effects": "Эффекты",
107
+ "sfx": "Звуковые эффекты",
108
+ "inverted +": "Инверсия +",
109
+ "inverted -": "Инверсия -"
110
+ },
111
+ "en": {
112
+ "vocals": "Vocals",
113
+ "Vocals": "Vocals",
114
+ "other": "Other",
115
+ "Other": "Other",
116
+ "Instrumental": "Instrumental",
117
+ "instrumnetal": "Instrumental",
118
+ "instrumental +": "Instrumental +",
119
+ "instrumental -": "Instrumental -",
120
+ "Bleed": "Bleed",
121
+ "Guitar": "Guitar",
122
+ "drums": "Drums",
123
+ "bass": "Bass",
124
+ "karaoke": "Karaoke",
125
+ "reverb": "Reverb",
126
+ "noreverb": "No reverb",
127
+ "aspiration": "Aspiration",
128
+ "dry": "Dry",
129
+ "crowd": "Crowd",
130
+ "percussions": "Percussions",
131
+ "piano": "Piano",
132
+ "guitar": "Guitar",
133
+ "male": "Male",
134
+ "female": "Female",
135
+ "kick": "Kick",
136
+ "snare": "Snare",
137
+ "toms": "Toms",
138
+ "hh": "Hi-hat",
139
+ "ride": "Ride",
140
+ "crash": "Crash",
141
+ "similarity": "Similarity",
142
+ "difference": "Difference",
143
+ "inst": "Instrument",
144
+ "orch": "Orchestra",
145
+ "No Woodwinds": "No Woodwinds",
146
+ "Woodwinds": "Woodwinds",
147
+ "No Echo": "No Echo",
148
+ "Echo": "Echo",
149
+ "No Reverb": "No Reverb",
150
+ "Reverb": "Reverb",
151
+ "Noise": "Noise",
152
+ "No Noise": "No Noise",
153
+ "Dry": "Dry",
154
+ "No Dry": "No Dry",
155
+ "Breath": "Breath",
156
+ "No Breath": "No Breath",
157
+ "No Crowd": "No Crowd",
158
+ "Crowd": "Crowd",
159
+ "No Other": "No Other",
160
+ "Bass": "Bass",
161
+ "No Bass": "No Bass",
162
+ "Drums": "Drums",
163
+ "No Drums": "No Drums",
164
+ "speech": "Speech",
165
+ "music": "Music",
166
+ "effects": "Effects",
167
+ "sfx": "SFX",
168
+ "inverted +": "Inverted +",
169
+ "inverted -": "Inverted -"
170
+ }
171
+ }
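TRANSLATIONS and TRANSLATIONS_STEMS are plain nested dicts keyed by language code and then by message key; t() and t_stem() in app.py fall back to the key itself when a translation is missing and run str.format() on any keyword arguments. A minimal sketch (not part of this commit) of the same lookup pattern:

from assets.translations import TRANSLATIONS

def translate(key, lang="en", **kwargs):
    # Fall back to the key itself when no translation exists, as t() does in app.py.
    text = TRANSLATIONS.get(lang, {}).get(key, key)
    return text.format(**kwargs) if kwargs else text

print(translate("loading_plugin", name="demo.py"))  # "Loading plugin: demo.py"
print(translate("unknown_key"))                      # falls back to "unknown_key"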
model_list.py ADDED
The diff for this file is too large to render. See raw diff
 
multi_inference.py ADDED
@@ -0,0 +1,303 @@
1
+ import os
2
+ import time
3
+ import shutil
4
+ import sys
5
+ import gc
6
+ import argparse
7
+ import json
8
+ import subprocess
9
+ from datetime import datetime
10
+ from tabulate import tabulate
11
+
12
+ SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
13
+ sys.path.append(SCRIPT_DIR)
14
+ os.chdir(SCRIPT_DIR)
15
+
16
+ from model_list import models_data
17
+ from utils.preedit_config import conf_editor
18
+ from utils.download_models import download_model
19
+
20
+ MODELS_CACHE_DIR = os.path.join(SCRIPT_DIR, "separator", "models_cache")
21
+ MODEL_TYPES = ["mel_band_roformer", "bs_roformer", "mdx23c", "scnet", "htdemucs", "bandit", "bandit_v2", "vr", "mdx"]
22
+ OUTPUT_FORMATS = ["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "aiff"]
23
+
24
+ class MVSEPLESS:
25
+ def __init__(self):
26
+ self.models_cache_dir = os.path.join(SCRIPT_DIR, "separator", "models_cache")
27
+ self.model_types = MODEL_TYPES
28
+ self.output_formats = OUTPUT_FORMATS
29
+
30
+ def get_mt(self):
31
+ return list(models_data.keys())
32
+
33
+ def get_mn(self, model_type):
34
+ return list(models_data[model_type].keys())
35
+
36
+ def get_stems(self, model_type, model_name):
37
+ stems = models_data[model_type][model_name]["stems"]
38
+ return stems
39
+
40
+ def get_tgt_inst(self, model_type, model_name):
41
+ target_instrument = models_data[model_type][model_name]["target_instrument"]
42
+ return target_instrument
43
+
44
+ def display_models_info(self, filter: str = None):
45
+ print("\nAvailable Models Information:")
46
+ print("=" * 50)
47
+
48
+ for model_type in models_data:
49
+ print(f"\nModel Type: {model_type.upper()}")
50
+ print("-" * 50)
51
+
52
+ table_data = []
53
+ headers = ["Model Name", "Stems", "Target Instrument", "Primary Stem"]
54
+
55
+ for model_name in models_data[model_type]:
56
+ model_info = models_data[model_type][model_name]
57
+
58
+ if filter and filter not in model_info.get('stems', []):
59
+ continue
60
+
61
+ stems = "\n".join(model_info.get('stems', [])) if 'stems' in model_info else "N/A"
62
+ target = model_info.get('target_instrument', "N/A")
63
+ primary = model_info.get('primary_stem', "N/A")
64
+
65
+ table_data.append([model_name, stems, target, primary])
66
+
67
+ print(tabulate(table_data, headers=headers, tablefmt="grid"))
68
+ print()
69
+
70
+ def separator(
71
+ self,
72
+ input_file: str = None,
73
+ output_dir: str = None,
74
+ model_type: str = "mel_band_roformer",
75
+ model_name: str = "Mel-Band-Roformer_Vocals_kimberley_jensen",
76
+ ext_inst: bool = False,
77
+ vr_aggr: int = 5,
78
+ output_format: str = "wav",
79
+ output_bitrate: str = "320k",
80
+ template: str = "NAME_(STEM)_MODEL",
81
+ call_method: str = "cli",
82
+ selected_stems: list = None
83
+ ):
84
+ if selected_stems is None:
85
+ selected_stems = []
86
+
87
+ if not input_file:
88
+ print("Please, input path to input file")
89
+ return [("None", "/none/none.mp3")]
90
+
91
+ if not os.path.exists(input_file):
92
+ print("Input file not exist")
93
+ return [("None", "/none/none.mp3")]
94
+
95
+ if "STEM" not in template:
96
+ template = template + "_STEM"
97
+
98
+ print(f"Starting inference: {model_type}/{model_name}, bitrate={output_bitrate}, method={call_method}, stems={selected_stems}")
99
+ os.makedirs(output_dir, exist_ok=True)
100
+
101
+ if model_type in ["mel_band_roformer", "bs_roformer", "mdx23c", "scnet", "htdemucs", "bandit", "bandit_v2"]:
102
+ try:
103
+ info = models_data[model_type][model_name]
104
+ except KeyError:
105
+ print("Model not exist")
106
+ return [("None", "/none/none.mp3")]
107
+
108
+ conf, ckpt = download_model(self.models_cache_dir, model_name, model_type,
109
+ info["checkpoint_url"], info["config_url"])
110
+ if model_type != "htdemucs":
111
+ conf_editor(conf)
112
+
113
+ if call_method == "cli":
114
+ cmd = ["python", "-m", "separator.msst_separator", f'--input "{input_file}"',
115
+ f'--store_dir "{output_dir}"', f'--model_type "{model_type}"',
116
+ f'--model_name "{model_name}"', f'--config_path "{conf}"',
117
+ f'--start_check_point "{ckpt}"', f'--output_format "{output_format}"',
118
+ f'--output_bitrate "{output_bitrate}"', f'--template "{template}"',
119
+ "--save_results_info"]
120
+ if ext_inst:
121
+ cmd.append("--extract_instrumental")
122
+ if selected_stems:
123
+ instruments = " ".join(f'"{s}"' for s in selected_stems)
124
+ cmd.append(f'--selected_instruments {instruments}')
125
+ subprocess.run(" ".join(cmd), shell=True, check=True)
126
+
127
+ results_path = os.path.join(output_dir, "results.json")
128
+ if os.path.exists(results_path):
129
+ with open(results_path, encoding="utf-8") as f:
130
+ return json.load(f)
131
+ return [("None", "/none/none.mp3")]
132
+
133
+ elif call_method == "direct":
134
+ from separator.msst_separator import mvsep_offline
135
+ try:
136
+ return mvsep_offline(
137
+ input_path=input_file, store_dir=output_dir, model_type=model_type,
138
+ config_path=conf, start_check_point=ckpt, extract_instrumental=ext_inst,
139
+ output_format=output_format, output_bitrate=output_bitrate,
140
+ model_name=model_name, template=template, selected_instruments=selected_stems
141
+ )
142
+ except Exception as e:
143
+ print(e)
144
+ return [("None", "/none/none.mp3")]
145
+
146
+ elif model_type in ["vr", "mdx"]:
147
+ try:
148
+ info = models_data[model_type][model_name]
149
+ except KeyError:
150
+ print("Model not exist")
151
+ return [("None", "/none/none.mp3")]
152
+
153
+ if model_type == "vr" and info.get("custom_vr", False):
154
+ conf, ckpt = download_model(self.models_cache_dir, model_name, model_type,
155
+ info["checkpoint_url"], info["config_url"])
156
+ primary_stem = info["primary_stem"]
157
+
158
+ if call_method == "cli":
159
+ cmd = ["python", "-m", "separator.uvr_sep", "custom_vr",
160
+ f'--input_file "{input_file}"', f'--ckpt_path "{ckpt}"',
161
+ f'--config_path "{conf}"', f'--bitrate "{output_bitrate}"',
162
+ f'--model_name "{model_name}"', f'--template "{template}"',
163
+ f'--output_format "{output_format}"', f'--primary_stem "{primary_stem}"',
164
+ f'--aggression {vr_aggr}', f'--output_dir "{output_dir}"']
165
+ if selected_stems:
166
+ instruments = " ".join(f'"{s}"' for s in selected_stems)
167
+ cmd.append(f'--selected_instruments {instruments}')
168
+ subprocess.run(" ".join(cmd), shell=True, check=True)
169
+
170
+ results_path = os.path.join(output_dir, "results.json")
171
+ if os.path.exists(results_path):
172
+ with open(results_path, encoding="utf-8") as f:
173
+ return json.load(f)
174
+ return [("None", "/none/none.mp3")]
175
+
176
+ elif call_method == "direct":
177
+ from separator.uvr_sep import custom_vr_separate
178
+ try:
179
+ return custom_vr_separate(
180
+ input_file=input_file, ckpt_path=ckpt, config_path=conf,
181
+ bitrate=output_bitrate, model_name=model_name, template=template,
182
+ output_format=output_format, primary_stem=primary_stem,
183
+ aggression=vr_aggr, output_dir=output_dir,
184
+ selected_instruments=selected_stems
185
+ )
186
+ except Exception as e:
187
+ print(e)
188
+ return [("None", "/none/none.mp3")]
189
+ else:
190
+ if call_method == "cli":
191
+ cmd = ["python", "-m", "separator.uvr_sep", "uvr",
192
+ f'--input_file "{input_file}"', f'--output_dir "{output_dir}"',
193
+ f'--template "{template}"', f'--bitrate "{output_bitrate}"',
194
+ f'--model_dir "{self.models_cache_dir}"', f'--model_type "{model_type}"',
195
+ f'--model_name "{model_name}"', f'--output_format "{output_format}"',
196
+ f'--aggression {vr_aggr}']
197
+ if selected_stems:
198
+ instruments = " ".join(f'"{s}"' for s in selected_stems)
199
+ cmd.append(f'--selected_instruments {instruments}')
200
+ subprocess.run(" ".join(cmd), shell=True, check=True)
201
+
202
+ results_path = os.path.join(output_dir, "results.json")
203
+ if os.path.exists(results_path):
204
+ with open(results_path, encoding="utf-8") as f:
205
+ return json.load(f)
206
+ return [("None", "/none/none.mp3")]
207
+
208
+ elif call_method == "direct":
209
+ from separator.uvr_sep import non_custom_uvr_inference
210
+ try:
211
+ return non_custom_uvr_inference(
212
+ input_file=input_file, output_dir=output_dir, template=template,
213
+ bitrate=output_bitrate, model_dir=self.models_cache_dir,
214
+ model_type=model_type, model_name=model_name,
215
+ output_format=output_format, aggression=vr_aggr,
216
+ selected_instruments=selected_stems
217
+ )
218
+ except Exception as e:
219
+ print(e)
220
+ return [("None", "/none/none.mp3")]
221
+
222
+ print("Unsupported model type")
223
+ return [("None", "/none/none.mp3")]
224
+
225
+ def parse_args():
226
+ parser = argparse.ArgumentParser(description="Multi-model inference for audio separation in Google Colab")
227
+ subparsers = parser.add_subparsers(dest='command', required=True, help='Sub-command help')
228
+
229
+ list_models = subparsers.add_parser('list', help='List available models')
230
+ list_models.add_argument("-l_filter", "--list_filter", type=str, default=None, help="Only list models that include the specified stem")
231
+
232
+ separate = subparsers.add_parser('separate', help='Run separation on a file or directory')
233
+ separate.add_argument("-i", "--input", type=str, required=True, help="Input file or directory")
234
+ separate.add_argument("-o", "--output", type=str, required=True, help="Output directory")
235
+ separate.add_argument("-mt", "--model_type", type=str, required=True, choices=MODEL_TYPES, help="Model type")
236
+ separate.add_argument("-mn", "--model_name", type=str, required=True, help="Model name")
237
+ separate.add_argument("-inst", "--instrumental", action='store_true', help="Extract instrumental")
238
+ separate.add_argument("-stems", "--stems", nargs="+", help="Select output stems")
239
+ separate.add_argument("-bitrate", "--bitrate", type=str, default="320k", help="Output bitrate")
240
+ separate.add_argument("-of", "--format", type=str, default="mp3", help="Output format")
241
+ separate.add_argument("-vr_aggr", "--vr_arch_aggressive", type=int, default=5, help="Aggression for VR ARCH models")
242
+ separate.add_argument('--template', type=str, default='NAME_STEM', help='Naming template for output files')
243
+ separate.add_argument("-l_out", "--list_output", action='store_true', help="Show list output files")
244
+
245
+ return parser.parse_args()
246
+
247
+ if __name__ == "__main__":
248
+ args = parse_args()
249
+ mvsepless = MVSEPLESS()
250
+
251
+ if args.command == 'list':
252
+ mvsepless.display_models_info(args.list_filter)
253
+
254
+ elif args.command == 'separate':
255
+ if os.path.isfile(args.input):
256
+ results = mvsepless.separator(
257
+ input_file=args.input,
258
+ output_dir=args.output,
259
+ model_type=args.model_type,
260
+ model_name=args.model_name,
261
+ ext_inst=args.instrumental,
262
+ vr_aggr=args.vr_arch_aggressive,
263
+ output_format=args.format,
264
+ output_bitrate=args.bitrate,
265
+ template=args.template,
266
+ call_method="cli",
267
+ selected_stems=args.stems
268
+ )
269
+ if args.list_output:
270
+ print("Results\n")
271
+ for stem, path in results:
272
+ print(f"Stem - {stem}\nPath - {path}\n")
273
+
274
+ elif os.path.isdir(args.input):
275
+ batch_results = []
276
+ for file in os.listdir(args.input):
277
+ abs_path_file = os.path.join(args.input, file)
278
+ if os.path.isfile(abs_path_file):
279
+ base_name = os.path.splitext(os.path.basename(abs_path_file))[0]
280
+ output_subdir = os.path.join(args.output, base_name)
281
+
282
+ results = mvsepless.separator(
283
+ input_file=abs_path_file,
284
+ output_dir=output_subdir,
285
+ model_type=args.model_type,
286
+ model_name=args.model_name,
287
+ ext_inst=args.instrumental,
288
+ vr_aggr=args.vr_arch_aggressive,
289
+ output_format=args.format,
290
+ output_bitrate=args.bitrate,
291
+ template=args.template,
292
+ call_method="cli",
293
+ selected_stems=args.stems
294
+ )
295
+ batch_results.append((base_name, results))
296
+
297
+ if args.list_output:
298
+ print("Results\n")
299
+ for name, stems in batch_results:
300
+ print(f"Name - {name}")
301
+ for stem, path in stems:
302
+ print(f" Stem - {stem}\n Path - {path}\n")
303
+
requirements.txt ADDED
@@ -0,0 +1,50 @@
1
+ torch==2.6.0
2
+ torchvision==0.21.0
3
+ torchaudio==2.6.0
4
+ torchcrepe==0.0.23
5
+ numpy==2.0.2
6
+ pandas==2.2.2
7
+ scipy==1.15.3
8
+ librosa==0.9.1
9
+ matplotlib==3.9.0
10
+ tqdm==4.67.1
11
+ einops==0.8.1
12
+ protobuf==5.29.4
13
+ soundfile==0.13.1
14
+ pydub==0.25.1
15
+ pyloudnorm==0.1.1
16
+ praat-parselmouth==0.4.5
17
+ webrtcvad==2.0.10
18
+ edge-tts==7.0.2
19
+ audiomentations==0.24.0
20
+ pedalboard==0.8.1
21
+ ffmpeg-python==0.2.0
22
+ faiss-cpu==1.11
23
+ ml_collections==1.1.0
24
+ timm==1.0.15
25
+ wandb==0.19.11
26
+ accelerate==1.7.0
27
+ bitsandbytes==0.46.0
28
+ tokenizers==0.19
29
+ huggingface-hub==0.28.1
30
+ transformers==4.41
31
+ https://github.com/noblebarkrr/mvsepless/blob/bd611441e48e918650e6860738894673b3a1a5f1/fixed/fairseq_fixed-0.13.0-cp311-cp311-linux_x86_64.whl
32
+ torchseg==0.0.1a4
33
+ demucs==4.0.0
34
+ asteroid==0.7.0
35
+ prodigyopt==1.1.2
36
+ torch_log_wmse==0.3.0
37
+ rotary_embedding_torch==0.6.5
38
+ local-attention==1.11.1
39
+ tenacity==9.1.2
40
+ gradio==5.38.2
41
+ omegaconf==2.3.0
42
+ beartype==0.18.5
43
+ spafe==0.3.2
44
+ torch_audiomentations==0.12.0
45
+ auraloss==0.4.0
46
+ onnxruntime-gpu>=1.17
47
+ yt_dlp
48
+ python-magic
49
+ pyngrok
50
+
separator/audio_writer.py ADDED
@@ -0,0 +1,85 @@
1
+ from pydub import AudioSegment
2
+ import numpy as np
3
+
4
+ def write_audio_file(output_file_path, numpy_array, sample_rate, output_format, bitrate):
5
+ """
6
+ Write an audio file from a numpy array in the given format using pydub.
7
+
8
+ Parameters:
9
+ output_file_path (str): Path for saving the file (without extension)
10
+ numpy_array (numpy.ndarray): Audio data as a numpy array
11
+ sample_rate (int): Sample rate (in Hz)
12
+ output_format (str): Output file format ('mp3', 'flac', 'wav', 'aiff', 'm4a', 'aac', 'ogg', 'opus')
13
+ bitrate (str, optional): Audio bitrate passed to the encoder (e.g. '320k')
14
+ """
15
+ try:
16
+ # Validate and normalize the input data
17
+ if not isinstance(numpy_array, np.ndarray):
18
+ raise ValueError("Input must be a numpy array")
19
+
20
+ # Reshape to (samples, channels)
21
+ if len(numpy_array.shape) == 1:
22
+ numpy_array = numpy_array.reshape(-1, 1) # Mono
23
+ elif len(numpy_array.shape) == 2:
24
+ if numpy_array.shape[0] == 2: # If shaped (channels, samples)
25
+ numpy_array = numpy_array.T # Transpose to (samples, channels)
26
+ else:
27
+ raise ValueError("Input array must be 1D or 2D")
28
+
29
+ # Clip floats to [-1.0, 1.0] and convert to 16-bit PCM
30
+ if np.issubdtype(numpy_array.dtype, np.floating):
31
+ numpy_array = np.clip(numpy_array, -1.0, 1.0)
32
+ numpy_array = (numpy_array * 32767).astype(np.int16)
33
+ elif numpy_array.dtype != np.int16:
34
+ numpy_array = numpy_array.astype(np.int16)
35
+
36
+ # Build the AudioSegment
37
+ if numpy_array.shape[1] == 1: # Mono
38
+ audio_segment = AudioSegment(
39
+ numpy_array.tobytes(),
40
+ frame_rate=sample_rate,
41
+ sample_width=2, # 16-bit = 2 bytes
42
+ channels=1
43
+ )
44
+ else: # Stereo
45
+ # For stereo, interleave left- and right-channel samples
46
+ interleaved = np.empty((numpy_array.shape[0] * 2,), dtype=np.int16)
47
+ interleaved[0::2] = numpy_array[:, 0] # Left channel
48
+ interleaved[1::2] = numpy_array[:, 1] # Right channel
49
+ audio_segment = AudioSegment(
50
+ interleaved.tobytes(),
51
+ frame_rate=sample_rate,
52
+ sample_width=2,
53
+ channels=2
54
+ )
55
+
56
+ # Build export parameters
57
+
58
+ parameters = {}
59
+ if bitrate:
60
+ parameters['bitrate'] = bitrate
61
+
62
+ # Map of supported output formats to pydub/ffmpeg format names
63
+ format_mapping = {
64
+ 'mp3': 'mp3',
65
+ 'flac': 'flac',
66
+ 'wav': 'wav',
67
+ 'aiff': 'aiff',
68
+ 'm4a': 'ipod', # pydub uses the 'ipod' muxer for m4a
69
+ 'aac': 'adts', # pydub uses 'adts' for aac
70
+ 'ogg': 'ogg',
71
+ 'opus': 'opus'
72
+ }
73
+
74
+ if output_format not in format_mapping:
75
+ raise ValueError(f"Unsupported format: {output_format}. Supported formats are: {list(format_mapping.keys())}")
76
+
77
+ # Append the file extension if it is missing
78
+ if not output_file_path.lower().endswith(f'.{output_format}'):
79
+ output_file_path = f"{output_file_path}.{output_format}"
80
+
81
+ # Export in the requested format
82
+ audio_segment.export(output_file_path, format=format_mapping[output_format], **parameters)
83
+
84
+ except Exception as e:
85
+ raise RuntimeError(f"Error writing audio file: {str(e)}")
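write_audio_file() accepts float audio in [-1, 1] or int16, as a 1-D mono array or a 2-D array. A minimal sketch (not part of this commit) that writes a one-second stereo test tone; it assumes pydub can find an ffmpeg binary for the mp3 export:

import numpy as np
from separator.audio_writer import write_audio_file

sr = 44100
t = np.linspace(0, 1.0, sr, endpoint=False)
tone = 0.5 * np.sin(2 * np.pi * 440.0 * t)   # float samples in [-1, 1]
stereo = np.stack([tone, tone])              # (channels, samples); transposed internally

# The ".mp3" extension is appended automatically -> "test_tone.mp3"
write_audio_file("test_tone", stereo, sr, "mp3", "192k")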
separator/ensemble.py ADDED
@@ -0,0 +1,192 @@
1
+ # coding: utf-8
2
+ __author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
3
+
4
+ import os
5
+ import sys
6
+ import librosa
7
+ import tempfile
8
+ import soundfile as sf
9
+ import numpy as np
10
+ import argparse
11
+ from separator.audio_writer import write_audio_file
12
+
13
+
14
+ def stft(wave, nfft, hl):
15
+ wave_left = np.asfortranarray(wave[0])
16
+ wave_right = np.asfortranarray(wave[1])
17
+ spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl)
18
+ spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl)
19
+ spec = np.asfortranarray([spec_left, spec_right])
20
+ return spec
21
+
22
+
23
+ def istft(spec, hl, length):
24
+ spec_left = np.asfortranarray(spec[0])
25
+ spec_right = np.asfortranarray(spec[1])
26
+ wave_left = librosa.istft(spec_left, hop_length=hl, length=length)
27
+ wave_right = librosa.istft(spec_right, hop_length=hl, length=length)
28
+ wave = np.asfortranarray([wave_left, wave_right])
29
+ return wave
30
+
31
+
32
+ def absmax(a, *, axis):
33
+ dims = list(a.shape)
34
+ dims.pop(axis)
35
+ indices = np.ogrid[tuple(slice(0, d) for d in dims)]
36
+ argmax = np.abs(a).argmax(axis=axis)
37
+ # Convert indices to list before insertion
38
+ indices = list(indices)
39
+ indices.insert(axis % len(a.shape), argmax)
40
+ return a[tuple(indices)]
41
+
42
+
43
+ def absmin(a, *, axis):
44
+ dims = list(a.shape)
45
+ dims.pop(axis)
46
+ indices = np.ogrid[tuple(slice(0, d) for d in dims)]
47
+ argmax = np.abs(a).argmin(axis=axis)
48
+ indices.insert((len(a.shape) + axis) % len(a.shape), argmax)
49
+ return a[tuple(indices)]
50
+
51
+
52
+ def lambda_max(arr, axis=None, key=None, keepdims=False):
53
+ idxs = np.argmax(key(arr), axis)
54
+ if axis is not None:
55
+ idxs = np.expand_dims(idxs, axis)
56
+ result = np.take_along_axis(arr, idxs, axis)
57
+ if not keepdims:
58
+ result = np.squeeze(result, axis=axis)
59
+ return result
60
+ else:
61
+ return arr.flatten()[idxs]
62
+
63
+
64
+ def lambda_min(arr, axis=None, key=None, keepdims=False):
65
+ idxs = np.argmin(key(arr), axis)
66
+ if axis is not None:
67
+ idxs = np.expand_dims(idxs, axis)
68
+ result = np.take_along_axis(arr, idxs, axis)
69
+ if not keepdims:
70
+ result = np.squeeze(result, axis=axis)
71
+ return result
72
+ else:
73
+ return arr.flatten()[idxs]
74
+
75
+
76
+ def average_waveforms(pred_track, weights, algorithm):
77
+ """
78
+ :param pred_track: shape = (num, channels, length)
79
+ :param weights: shape = (num, )
80
+ :param algorithm: One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft
81
+ :return: averaged waveform in shape (channels, length)
82
+ """
83
+
84
+ pred_track = np.array(pred_track)
85
+ final_length = pred_track.shape[-1]
86
+
87
+ mod_track = []
88
+ for i in range(pred_track.shape[0]):
89
+ if algorithm == 'avg_wave':
90
+ mod_track.append(pred_track[i] * weights[i])
91
+ elif algorithm in ['median_wave', 'min_wave', 'max_wave']:
92
+ mod_track.append(pred_track[i])
93
+ elif algorithm in ['avg_fft', 'min_fft', 'max_fft', 'median_fft']:
94
+ spec = stft(pred_track[i], nfft=2048, hl=1024)
95
+ if algorithm in ['avg_fft']:
96
+ mod_track.append(spec * weights[i])
97
+ else:
98
+ mod_track.append(spec)
99
+ pred_track = np.array(mod_track)
100
+
101
+ if algorithm in ['avg_wave']:
102
+ pred_track = pred_track.sum(axis=0)
103
+ pred_track /= np.array(weights).sum().T
104
+ elif algorithm in ['median_wave']:
105
+ pred_track = np.median(pred_track, axis=0)
106
+ elif algorithm in ['min_wave']:
107
+ pred_track = np.array(pred_track)
108
+ pred_track = lambda_min(pred_track, axis=0, key=np.abs)
109
+ elif algorithm in ['max_wave']:
110
+ pred_track = np.array(pred_track)
111
+ pred_track = lambda_max(pred_track, axis=0, key=np.abs)
112
+ elif algorithm in ['avg_fft']:
113
+ pred_track = pred_track.sum(axis=0)
114
+ pred_track /= np.array(weights).sum()
115
+ pred_track = istft(pred_track, 1024, final_length)
116
+ elif algorithm in ['min_fft']:
117
+ pred_track = np.array(pred_track)
118
+ pred_track = lambda_min(pred_track, axis=0, key=np.abs)
119
+ pred_track = istft(pred_track, 1024, final_length)
120
+ elif algorithm in ['max_fft']:
121
+ pred_track = np.array(pred_track)
122
+ pred_track = absmax(pred_track, axis=0)
123
+ pred_track = istft(pred_track, 1024, final_length)
124
+ elif algorithm in ['median_fft']:
125
+ pred_track = np.array(pred_track)
126
+ pred_track = np.median(pred_track, axis=0)
127
+ pred_track = istft(pred_track, 1024, final_length)
128
+ return pred_track
129
+
130
+
131
+ def ensemble_audio_files(files, output="res.wav", ensemble_type='avg_wave', weights=None, out_format="wav"):
132
+ """
133
+ Основная функция для объединения аудиофайлов
134
+
135
+ :param files: список путей к аудиофайлам
136
+ :param output: путь для сохранения результата
137
+ :param ensemble_type: метод объединения (avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft)
138
+ :param weights: список весов для каждого файла (None для равных весов)
139
+ :return: None
140
+ """
141
+ print('Ensemble type: {}'.format(ensemble_type))
142
+ print('Number of input files: {}'.format(len(files)))
143
+ if weights is not None:
144
+ weights = np.array(weights)
145
+ else:
146
+ weights = np.ones(len(files))
147
+ print('Weights: {}'.format(weights))
148
+ print('Output file: {}'.format(output))
149
+
150
+ data = []
151
+ sr = None
152
+ for f in files:
153
+ if not os.path.isfile(f):
154
+ print('Error. Can\'t find file: {}. Check paths.'.format(f))
155
+ exit()
156
+ print('Reading file: {}'.format(f))
157
+ wav, current_sr = librosa.load(f, sr=None, mono=False)
158
+ if sr is None:
159
+ sr = current_sr
160
+ elif sr != current_sr:
161
+ print('Error: Sample rates must be equal for all files')
162
+ exit()
163
+ print("Waveform shape: {} sample rate: {}".format(wav.shape, sr))
164
+ data.append(wav)
165
+
166
+ data = np.array(data)
167
+ res = average_waveforms(data, weights, ensemble_type)
168
+ print('Result shape: {}'.format(res.shape))
169
+
170
+ output_wav = f"{output}_orig.wav"
171
+ output = f"{output}.{out_format}"
172
+
173
+ if out_format in ["wav", "flac"]:
174
+
175
+ sf.write(output, res.T, sr, subtype='PCM_16')
176
+ sf.write(output_wav, res.T, sr, subtype='PCM_16')
177
+
178
+ elif out_format in ["mp3", "m4a", "aac", "ogg", "opus", "aiff"]:
179
+
180
+ write_audio_file(output, res.T, sr, out_format, "320k")
181
+ sf.write(output_wav, res.T, sr, subtype='PCM_16')
182
+
183
+ return output, output_wav
184
+
185
+
186
+
187
+ # input_settings = [("demucs / v4", 1.0, "vocals"), ("mel_band_roformer / mel_4_stems", 0.5, "vocals")]
188
+
189
+ # out, wav = ensembless(input_audio, input_settings, "max_fft", format)
190
+
191
+
192
+
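ensemble_audio_files() blends several renderings of the same stem using one of the waveform- or spectrum-domain algorithms handled in average_waveforms(). A minimal sketch (not part of this commit), assuming two vocal renders with identical sample rates; the input file names are placeholders:

from separator.ensemble import ensemble_audio_files

out, out_wav = ensemble_audio_files(
    files=["vocals_model_a.wav", "vocals_model_b.wav"],  # placeholder inputs
    output="vocals_ensemble",       # extension is added from out_format
    ensemble_type="avg_wave",
    weights=[1.0, 0.5],
    out_format="mp3",
)
print(out, out_wav)  # vocals_ensemble.mp3 vocals_ensemble_orig.wav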
separator/models/bandit/core/__init__.py ADDED
@@ -0,0 +1,744 @@
1
+ import os.path
2
+ from collections import defaultdict
3
+ from itertools import chain, combinations
4
+ from typing import (
5
+ Any,
6
+ Dict,
7
+ Iterator,
8
+ Mapping, Optional,
9
+ Tuple, Type,
10
+ TypedDict
11
+ )
12
+
13
+ import pytorch_lightning as pl
14
+ import torch
15
+ import torchaudio as ta
16
+ import torchmetrics as tm
17
+ from asteroid import losses as asteroid_losses
18
+ # from deepspeed.ops.adam import DeepSpeedCPUAdam
19
+ # from geoopt import optim as gooptim
20
+ from pytorch_lightning.utilities.types import STEP_OUTPUT
21
+ from torch import nn, optim
22
+ from torch.optim import lr_scheduler
23
+ from torch.optim.lr_scheduler import LRScheduler
24
+
25
+ from models.bandit.core import loss, metrics as metrics_, model
26
+ from models.bandit.core.data._types import BatchedDataDict
27
+ from models.bandit.core.data.augmentation import BaseAugmentor, StemAugmentor
28
+ from models.bandit.core.utils import audio as audio_
29
+ from models.bandit.core.utils.audio import BaseFader
30
+
31
+ # from pandas.io.json._normalize import nested_to_record
32
+
33
+ ConfigDict = TypedDict('ConfigDict', {'name': str, 'kwargs': Dict[str, Any]})
34
+
35
+
36
+ class SchedulerConfigDict(ConfigDict):
37
+ monitor: str
38
+
39
+
40
+ OptimizerSchedulerConfigDict = TypedDict(
41
+ 'OptimizerSchedulerConfigDict',
42
+ {"optimizer": ConfigDict, "scheduler": SchedulerConfigDict},
43
+ total=False
44
+ )
45
+
46
+
47
+ class LRSchedulerReturnDict(TypedDict, total=False):
48
+ scheduler: LRScheduler
49
+ monitor: str
50
+
51
+
52
+ class ConfigureOptimizerReturnDict(TypedDict, total=False):
53
+ optimizer: torch.optim.Optimizer
54
+ lr_scheduler: LRSchedulerReturnDict
55
+
56
+
57
+ OutputType = Dict[str, Any]
58
+ MetricsType = Dict[str, torch.Tensor]
59
+
60
+
61
+ def get_optimizer_class(name: str) -> Type[optim.Optimizer]:
62
+
63
+ if name == "DeepSpeedCPUAdam":
64
+ return DeepSpeedCPUAdam
65
+
66
+ for module in [optim, gooptim]:
67
+ if name in module.__dict__:
68
+ return module.__dict__[name]
69
+
70
+ raise NameError
71
+
72
+
73
+ def parse_optimizer_config(
74
+ config: OptimizerSchedulerConfigDict,
75
+ parameters: Iterator[nn.Parameter]
76
+ ) -> ConfigureOptimizerReturnDict:
77
+ optim_class = get_optimizer_class(config["optimizer"]["name"])
78
+ optimizer = optim_class(parameters, **config["optimizer"]["kwargs"])
79
+
80
+ optim_dict: ConfigureOptimizerReturnDict = {
81
+ "optimizer": optimizer,
82
+ }
83
+
84
+ if "scheduler" in config:
85
+
86
+ lr_scheduler_class_ = config["scheduler"]["name"]
87
+ lr_scheduler_class = lr_scheduler.__dict__[lr_scheduler_class_]
88
+ lr_scheduler_dict: LRSchedulerReturnDict = {
89
+ "scheduler": lr_scheduler_class(
90
+ optimizer,
91
+ **config["scheduler"]["kwargs"]
92
+ )
93
+ }
94
+
95
+ if lr_scheduler_class_ == "ReduceLROnPlateau":
96
+ lr_scheduler_dict["monitor"] = config["scheduler"]["monitor"]
97
+
98
+ optim_dict["lr_scheduler"] = lr_scheduler_dict
99
+
100
+ return optim_dict
101
+
102
+
103
+ def parse_model_config(config: ConfigDict) -> Any:
104
+ name = config["name"]
105
+
106
+ for module in [model]:
107
+ if name in module.__dict__:
108
+ return module.__dict__[name](**config["kwargs"])
109
+
110
+ raise NameError
111
+
112
+
113
+ _LEGACY_LOSS_NAMES = ["HybridL1Loss"]
114
+
115
+
116
+ def _parse_legacy_loss_config(config: ConfigDict) -> nn.Module:
117
+ name = config["name"]
118
+
119
+ if name == "HybridL1Loss":
120
+ return loss.TimeFreqL1Loss(**config["kwargs"])
121
+
122
+ raise NameError
123
+
124
+
125
+ def parse_loss_config(config: ConfigDict) -> nn.Module:
126
+ name = config["name"]
127
+
128
+ if name in _LEGACY_LOSS_NAMES:
129
+ return _parse_legacy_loss_config(config)
130
+
131
+ for module in [loss, nn.modules.loss, asteroid_losses]:
132
+ if name in module.__dict__:
133
+ # print(config["kwargs"])
134
+ return module.__dict__[name](**config["kwargs"])
135
+
136
+ raise NameError
137
+
138
+
139
+ def get_metric(config: ConfigDict) -> tm.Metric:
140
+ name = config["name"]
141
+
142
+ for module in [tm, metrics_]:
143
+ if name in module.__dict__:
144
+ return module.__dict__[name](**config["kwargs"])
145
+ raise NameError
146
+
147
+
148
+ def parse_metric_config(config: Dict[str, ConfigDict]) -> tm.MetricCollection:
149
+ metrics = {}
150
+
151
+ for metric in config:
152
+ metrics[metric] = get_metric(config[metric])
153
+
154
+ return tm.MetricCollection(metrics)
155
+
156
+
157
+ def parse_fader_config(config: ConfigDict) -> BaseFader:
158
+ name = config["name"]
159
+
160
+ for module in [audio_]:
161
+ if name in module.__dict__:
162
+ return module.__dict__[name](**config["kwargs"])
163
+
164
+ raise NameError
165
+
166
+
167
+ class LightningSystem(pl.LightningModule):
168
+ _VOX_STEMS = ["speech", "vocals"]
169
+ _BG_STEMS = ["background", "effects", "mne"]
170
+
171
+ def __init__(
172
+ self,
173
+ config: Dict,
174
+ loss_adjustment: float = 1.0,
175
+ attach_fader: bool = False
176
+ ) -> None:
177
+ super().__init__()
178
+ self.optimizer_config = config["optimizer"]
179
+ self.model = parse_model_config(config["model"])
180
+ self.loss = parse_loss_config(config["loss"])
181
+ self.metrics = nn.ModuleDict(
182
+ {
183
+ stem: parse_metric_config(config["metrics"]["dev"])
184
+ for stem in self.model.stems
185
+ }
186
+ )
187
+
188
+ self.metrics.disallow_fsdp = True
189
+
190
+ self.test_metrics = nn.ModuleDict(
191
+ {
192
+ stem: parse_metric_config(config["metrics"]["test"])
193
+ for stem in self.model.stems
194
+ }
195
+ )
196
+
197
+ self.test_metrics.disallow_fsdp = True
198
+
199
+ self.fs = config["model"]["kwargs"]["fs"]
200
+
201
+ self.fader_config = config["inference"]["fader"]
202
+ if attach_fader:
203
+ self.fader = parse_fader_config(config["inference"]["fader"])
204
+ else:
205
+ self.fader = None
206
+
207
+ self.augmentation: Optional[BaseAugmentor]
208
+ if config.get("augmentation", None) is not None:
209
+ self.augmentation = StemAugmentor(**config["augmentation"])
210
+ else:
211
+ self.augmentation = None
212
+
213
+ self.predict_output_path: Optional[str] = None
214
+ self.loss_adjustment = loss_adjustment
215
+
216
+ self.val_prefix = None
217
+ self.test_prefix = None
218
+
219
+
220
+ def configure_optimizers(self) -> Any:
221
+ return parse_optimizer_config(
222
+ self.optimizer_config,
223
+ self.trainer.model.parameters()
224
+ )
225
+
226
+ def compute_loss(self, batch: BatchedDataDict, output: OutputType) -> Dict[
227
+ str, torch.Tensor]:
228
+ return {"loss": self.loss(output, batch)}
229
+
230
+ def update_metrics(
231
+ self,
232
+ batch: BatchedDataDict,
233
+ output: OutputType,
234
+ mode: str
235
+ ) -> None:
236
+
237
+ if mode == "test":
238
+ metrics = self.test_metrics
239
+ else:
240
+ metrics = self.metrics
241
+
242
+ for stem, metric in metrics.items():
243
+
244
+ if stem == "mne:+":
245
+ stem = "mne"
246
+
247
+ # print(f"matching for {stem}")
248
+ if mode == "train":
249
+ metric.update(
250
+ output["audio"][stem],#.cpu(),
251
+ batch["audio"][stem],#.cpu()
252
+ )
253
+ else:
254
+ if stem not in batch["audio"]:
255
+ matched = False
256
+ if stem in self._VOX_STEMS:
257
+ for bstem in self._VOX_STEMS:
258
+ if bstem in batch["audio"]:
259
+ batch["audio"][stem] = batch["audio"][bstem]
260
+ matched = True
261
+ break
262
+ elif stem in self._BG_STEMS:
263
+ for bstem in self._BG_STEMS:
264
+ if bstem in batch["audio"]:
265
+ batch["audio"][stem] = batch["audio"][bstem]
266
+ matched = True
267
+ break
268
+ else:
269
+ matched = True
270
+
271
+ # print(batch["audio"].keys())
272
+
273
+ if matched:
274
+ # print(f"matched {stem}!")
275
+ if stem == "mne" and "mne" not in output["audio"]:
276
+ output["audio"]["mne"] = output["audio"]["music"] + output["audio"]["effects"]
277
+
278
+ metric.update(
279
+ output["audio"][stem],#.cpu(),
280
+ batch["audio"][stem],#.cpu(),
281
+ )
282
+
283
+ # print(metric.compute())
284
+ def compute_metrics(self, mode: str="dev") -> Dict[
285
+ str, torch.Tensor]:
286
+
287
+ if mode == "test":
288
+ metrics = self.test_metrics
289
+ else:
290
+ metrics = self.metrics
291
+
292
+ metric_dict = {}
293
+
294
+ for stem, metric in metrics.items():
295
+ md = metric.compute()
296
+ metric_dict.update(
297
+ {f"{stem}/{k}": v for k, v in md.items()}
298
+ )
299
+
300
+ self.log_dict(metric_dict, prog_bar=True, logger=False)
301
+
302
+ return metric_dict
303
+
304
+ def reset_metrics(self, test_mode: bool = False) -> None:
305
+
306
+ if test_mode:
307
+ metrics = self.test_metrics
308
+ else:
309
+ metrics = self.metrics
310
+
311
+ for _, metric in metrics.items():
312
+ metric.reset()
313
+
314
+
315
+ def forward(self, batch: BatchedDataDict) -> Any:
316
+ batch, output = self.model(batch)
317
+
318
+
319
+ return batch, output
320
+
321
+ def common_step(self, batch: BatchedDataDict, mode: str) -> Any:
322
+ batch, output = self.forward(batch)
323
+ # print(batch)
324
+ # print(output)
325
+ loss_dict = self.compute_loss(batch, output)
326
+
327
+ with torch.no_grad():
328
+ self.update_metrics(batch, output, mode=mode)
329
+
330
+ if mode == "train":
331
+ self.log("loss", loss_dict["loss"], prog_bar=True)
332
+
333
+ return output, loss_dict
334
+
335
+
336
+ def training_step(self, batch: BatchedDataDict) -> Dict[str, Any]:
337
+
338
+ if self.augmentation is not None:
339
+ with torch.no_grad():
340
+ batch = self.augmentation(batch)
341
+
342
+ _, loss_dict = self.common_step(batch, mode="train")
343
+
344
+ with torch.inference_mode():
345
+ self.log_dict_with_prefix(
346
+ loss_dict,
347
+ "train",
348
+ batch_size=batch["audio"]["mixture"].shape[0]
349
+ )
350
+
351
+ loss_dict["loss"] *= self.loss_adjustment
352
+
353
+ return loss_dict
354
+
355
+ def on_train_batch_end(
356
+ self, outputs: STEP_OUTPUT, batch: BatchedDataDict, batch_idx: int
357
+ ) -> None:
358
+
359
+ metric_dict = self.compute_metrics()
360
+ self.log_dict_with_prefix(metric_dict, "train")
361
+ self.reset_metrics()
362
+
363
+ def validation_step(
364
+ self,
365
+ batch: BatchedDataDict,
366
+ batch_idx: int,
367
+ dataloader_idx: int = 0
368
+ ) -> Dict[str, Any]:
369
+
370
+ with torch.inference_mode():
371
+ curr_val_prefix = f"val{dataloader_idx}" if dataloader_idx > 0 else "val"
372
+
373
+ if curr_val_prefix != self.val_prefix:
374
+ # print(f"Switching to validation dataloader {dataloader_idx}")
375
+ if self.val_prefix is not None:
376
+ self._on_validation_epoch_end()
377
+ self.val_prefix = curr_val_prefix
378
+ _, loss_dict = self.common_step(batch, mode="val")
379
+
380
+ self.log_dict_with_prefix(
381
+ loss_dict,
382
+ self.val_prefix,
383
+ batch_size=batch["audio"]["mixture"].shape[0],
384
+ prog_bar=True,
385
+ add_dataloader_idx=False
386
+ )
387
+
388
+ return loss_dict
389
+
390
+ def on_validation_epoch_end(self) -> None:
391
+ self._on_validation_epoch_end()
392
+
393
+ def _on_validation_epoch_end(self) -> None:
394
+ metric_dict = self.compute_metrics()
395
+ self.log_dict_with_prefix(metric_dict, self.val_prefix, prog_bar=True,
396
+ add_dataloader_idx=False)
397
+ # self.logger.save()
398
+ # print(self.val_prefix, "Validation metrics:", metric_dict)
399
+ self.reset_metrics()
400
+
401
+
402
+ def old_predtest_step(
403
+ self,
404
+ batch: BatchedDataDict,
405
+ batch_idx: int,
406
+ dataloader_idx: int = 0
407
+ ) -> Tuple[BatchedDataDict, OutputType]:
408
+
409
+ audio_batch = batch["audio"]["mixture"]
410
+ track_batch = batch.get("track", ["" for _ in range(len(audio_batch))])
411
+
412
+ output_list_of_dicts = [
413
+ self.fader(
414
+ audio[None, ...],
415
+ lambda a: self.test_forward(a, track)
416
+ )
417
+ for audio, track in zip(audio_batch, track_batch)
418
+ ]
419
+
420
+ output_dict_of_lists = defaultdict(list)
421
+
422
+ for output_dict in output_list_of_dicts:
423
+ for stem, audio in output_dict.items():
424
+ output_dict_of_lists[stem].append(audio)
425
+
426
+ output = {
427
+ "audio": {
428
+ stem: torch.concat(output_list, dim=0)
429
+ for stem, output_list in output_dict_of_lists.items()
430
+ }
431
+ }
432
+
433
+ return batch, output
434
+
435
+ def predtest_step(
436
+ self,
437
+ batch: BatchedDataDict,
438
+ batch_idx: int = -1,
439
+ dataloader_idx: int = 0
440
+ ) -> Tuple[BatchedDataDict, OutputType]:
441
+
442
+ if getattr(self.model, "bypass_fader", False):
443
+ batch, output = self.model(batch)
444
+ else:
445
+ audio_batch = batch["audio"]["mixture"]
446
+ output = self.fader(
447
+ audio_batch,
448
+ lambda a: self.test_forward(a, "", batch=batch)
449
+ )
450
+
451
+ return batch, output
452
+
453
+ def test_forward(
454
+ self,
455
+ audio: torch.Tensor,
456
+ track: str = "",
457
+ batch: BatchedDataDict = None
458
+ ) -> torch.Tensor:
459
+
460
+ if self.fader is None:
461
+ self.attach_fader()
462
+
463
+ cond = batch.get("condition", None)
464
+
465
+ if cond is not None and cond.shape[0] == 1:
466
+ cond = cond.repeat(audio.shape[0], 1)
467
+
468
+ _, output = self.forward(
469
+ {"audio": {"mixture": audio},
470
+ "track": track,
471
+ "condition": cond,
472
+ }
473
+ ) # TODO: support track properly
474
+
475
+ return output["audio"]
476
+
477
+ def on_test_epoch_start(self) -> None:
478
+ self.attach_fader(force_reattach=True)
479
+
480
+ def test_step(
481
+ self,
482
+ batch: BatchedDataDict,
483
+ batch_idx: int,
484
+ dataloader_idx: int = 0
485
+ ) -> Any:
486
+ curr_test_prefix = f"test{dataloader_idx}"
487
+
488
+ # print(batch["audio"].keys())
489
+
490
+ if curr_test_prefix != self.test_prefix:
491
+ # print(f"Switching to test dataloader {dataloader_idx}")
492
+ if self.test_prefix is not None:
493
+ self._on_test_epoch_end()
494
+ self.test_prefix = curr_test_prefix
495
+
496
+ with torch.inference_mode():
497
+ _, output = self.predtest_step(batch, batch_idx, dataloader_idx)
498
+ # print(output)
499
+ self.update_metrics(batch, output, mode="test")
500
+
501
+ return output
502
+
503
+ def on_test_epoch_end(self) -> None:
504
+ self._on_test_epoch_end()
505
+
506
+ def _on_test_epoch_end(self) -> None:
507
+ metric_dict = self.compute_metrics(mode="test")
508
+ self.log_dict_with_prefix(metric_dict, self.test_prefix, prog_bar=True,
509
+ add_dataloader_idx=False)
510
+ # self.logger.save()
511
+ # print(self.test_prefix, "Test metrics:", metric_dict)
512
+ self.reset_metrics()
513
+
514
+ def predict_step(
515
+ self,
516
+ batch: BatchedDataDict,
517
+ batch_idx: int = 0,
518
+ dataloader_idx: int = 0,
519
+ include_track_name: Optional[bool] = None,
520
+ get_no_vox_combinations: bool = True,
521
+ get_residual: bool = False,
522
+ treat_batch_as_channels: bool = False,
523
+ fs: Optional[int] = None,
524
+ ) -> Any:
525
+ assert self.predict_output_path is not None
526
+
527
+ batch_size = batch["audio"]["mixture"].shape[0]
528
+
529
+ if include_track_name is None:
530
+ include_track_name = batch_size > 1
531
+
532
+ with torch.inference_mode():
533
+ batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
534
+ print('Pred test finished...')
535
+ torch.cuda.empty_cache()
536
+ metric_dict = {}
537
+
538
+ if get_residual:
539
+ mixture = batch["audio"]["mixture"]
540
+ extracted = sum([output["audio"][stem] for stem in output["audio"]])
541
+ residual = mixture - extracted
542
+ print(extracted.shape, mixture.shape, residual.shape)
543
+
544
+ output["audio"]["residual"] = residual
545
+
546
+ if get_no_vox_combinations:
547
+ no_vox_stems = [
548
+ stem for stem in output["audio"] if
549
+ stem not in self._VOX_STEMS
550
+ ]
551
+ no_vox_combinations = chain.from_iterable(
552
+ combinations(no_vox_stems, r) for r in
553
+ range(2, len(no_vox_stems) + 1)
554
+ )
555
+
556
+ for combination in no_vox_combinations:
557
+ combination_ = list(combination)
558
+ output["audio"]["+".join(combination_)] = sum(
559
+ [output["audio"][stem] for stem in combination_]
560
+ )
561
+
562
+ if treat_batch_as_channels:
563
+ for stem in output["audio"]:
564
+ output["audio"][stem] = output["audio"][stem].reshape(
565
+ 1, -1, output["audio"][stem].shape[-1]
566
+ )
567
+ batch_size = 1
568
+
569
+ for b in range(batch_size):
570
+ print("!!", b)
571
+ for stem in output["audio"]:
572
+ print(f"Saving audio for {stem} to {self.predict_output_path}")
573
+ track_name = batch["track"][b].split("/")[-1]
574
+
575
+ if batch.get("audio", {}).get(stem, None) is not None:
576
+ self.test_metrics[stem].reset()
577
+ metrics = self.test_metrics[stem](
578
+ batch["audio"][stem][[b], ...],
579
+ output["audio"][stem][[b], ...]
580
+ )
581
+ snr = metrics["snr"]
582
+ sisnr = metrics["sisnr"]
583
+ sdr = metrics["sdr"]
584
+ metric_dict[stem] = metrics
585
+ print(
586
+ track_name,
587
+ f"snr={snr:2.2f} dB",
588
+ f"sisnr={sisnr:2.2f}",
589
+ f"sdr={sdr:2.2f} dB",
590
+ )
591
+ filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
592
+ else:
593
+ filename = f"{stem}.wav"
594
+
595
+ if include_track_name:
596
+ output_dir = os.path.join(
597
+ self.predict_output_path,
598
+ track_name
599
+ )
600
+ else:
601
+ output_dir = self.predict_output_path
602
+
603
+ os.makedirs(output_dir, exist_ok=True)
604
+
605
+ if fs is None:
606
+ fs = self.fs
607
+
608
+ ta.save(
609
+ os.path.join(output_dir, filename),
610
+ output["audio"][stem][b, ...].cpu(),
611
+ fs,
612
+ )
613
+
614
+ return metric_dict
615
+
616
+ def get_stems(
617
+ self,
618
+ batch: BatchedDataDict,
619
+ batch_idx: int = 0,
620
+ dataloader_idx: int = 0,
621
+ include_track_name: Optional[bool] = None,
622
+ get_no_vox_combinations: bool = True,
623
+ get_residual: bool = False,
624
+ treat_batch_as_channels: bool = False,
625
+ fs: Optional[int] = None,
626
+ ) -> Any:
627
+ assert self.predict_output_path is not None
628
+
629
+ batch_size = batch["audio"]["mixture"].shape[0]
630
+
631
+ if include_track_name is None:
632
+ include_track_name = batch_size > 1
633
+
634
+ with torch.inference_mode():
635
+ batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
636
+ torch.cuda.empty_cache()
637
+ metric_dict = {}
638
+
639
+ if get_residual:
640
+ mixture = batch["audio"]["mixture"]
641
+ extracted = sum([output["audio"][stem] for stem in output["audio"]])
642
+ residual = mixture - extracted
643
+ # print(extracted.shape, mixture.shape, residual.shape)
644
+
645
+ output["audio"]["residual"] = residual
646
+
647
+ if get_no_vox_combinations:
648
+ no_vox_stems = [
649
+ stem for stem in output["audio"] if
650
+ stem not in self._VOX_STEMS
651
+ ]
652
+ no_vox_combinations = chain.from_iterable(
653
+ combinations(no_vox_stems, r) for r in
654
+ range(2, len(no_vox_stems) + 1)
655
+ )
656
+
657
+ for combination in no_vox_combinations:
658
+ combination_ = list(combination)
659
+ output["audio"]["+".join(combination_)] = sum(
660
+ [output["audio"][stem] for stem in combination_]
661
+ )
662
+
663
+ if treat_batch_as_channels:
664
+ for stem in output["audio"]:
665
+ output["audio"][stem] = output["audio"][stem].reshape(
666
+ 1, -1, output["audio"][stem].shape[-1]
667
+ )
668
+ batch_size = 1
669
+
670
+ result = {}
671
+ for b in range(batch_size):
672
+ for stem in output["audio"]:
673
+ track_name = batch["track"][b].split("/")[-1]
674
+
675
+ if batch.get("audio", {}).get(stem, None) is not None:
676
+ self.test_metrics[stem].reset()
677
+ metrics = self.test_metrics[stem](
678
+ batch["audio"][stem][[b], ...],
679
+ output["audio"][stem][[b], ...]
680
+ )
681
+ snr = metrics["snr"]
682
+ sisnr = metrics["sisnr"]
683
+ sdr = metrics["sdr"]
684
+ metric_dict[stem] = metrics
685
+ print(
686
+ track_name,
687
+ f"snr={snr:2.2f} dB",
688
+ f"sisnr={sisnr:2.2f}",
689
+ f"sdr={sdr:2.2f} dB",
690
+ )
691
+ filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
692
+ else:
693
+ filename = f"{stem}.wav"
694
+
695
+ if include_track_name:
696
+ output_dir = os.path.join(
697
+ self.predict_output_path,
698
+ track_name
699
+ )
700
+ else:
701
+ output_dir = self.predict_output_path
702
+
703
+ os.makedirs(output_dir, exist_ok=True)
704
+
705
+ if fs is None:
706
+ fs = self.fs
707
+
708
+ result[stem] = output["audio"][stem][b, ...].cpu().numpy()
709
+
710
+ return result
711
+
712
+ def load_state_dict(
713
+ self, state_dict: Mapping[str, Any], strict: bool = False
714
+ ) -> Any:
715
+
716
+ return super().load_state_dict(state_dict, strict=False)
717
+
718
+
719
+ def set_predict_output_path(self, path: str) -> None:
720
+ self.predict_output_path = path
721
+ os.makedirs(self.predict_output_path, exist_ok=True)
722
+
723
+ self.attach_fader()
724
+
725
+ def attach_fader(self, force_reattach=False) -> None:
726
+ if self.fader is None or force_reattach:
727
+ self.fader = parse_fader_config(self.fader_config)
728
+ self.fader.to(self.device)
729
+
730
+
731
+ def log_dict_with_prefix(
732
+ self,
733
+ dict_: Dict[str, torch.Tensor],
734
+ prefix: str,
735
+ batch_size: Optional[int] = None,
736
+ **kwargs: Any
737
+ ) -> None:
738
+ self.log_dict(
739
+ {f"{prefix}/{k}": v for k, v in dict_.items()},
740
+ batch_size=batch_size,
741
+ logger=True,
742
+ sync_dist=True,
743
+ **kwargs,
744
+ )
separator/models/bandit/core/data/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .dnr.datamodule import DivideAndRemasterDataModule
+ from .musdb.datamodule import MUSDB18DataModule
separator/models/bandit/core/data/_types.py ADDED
@@ -0,0 +1,18 @@
+ from typing import Dict, Sequence, TypedDict
+
+ import torch
+
+ AudioDict = Dict[str, torch.Tensor]
+
+ DataDict = TypedDict('DataDict', {'audio': AudioDict, 'track': str})
+
+ BatchedDataDict = TypedDict(
+     'BatchedDataDict',
+     {'audio': AudioDict, 'track': Sequence[str]}
+ )
+
+
+ class DataDictWithLanguage(TypedDict):
+     audio: AudioDict
+     track: str
+     language: str
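For orientation, a minimal sketch of an item conforming to DataDict; the stem names and tensor shapes below are illustrative only, not mandated by the type definitions:

import torch

# Hypothetical single-track item: per-stem waveform tensors plus a track id.
item = {
    "audio": {
        "mixture": torch.zeros(2, 44100),  # (channels, samples)
        "speech": torch.zeros(2, 44100),
        "music": torch.zeros(2, 44100),
    },
    "track": "train/track_0001",
}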
separator/models/bandit/core/data/augmentation.py ADDED
@@ -0,0 +1,107 @@
1
+ from abc import ABC
2
+ from typing import Any, Dict, Union
3
+
4
+ import torch
5
+ import torch_audiomentations as tam
6
+ from torch import nn
7
+
8
+ from models.bandit.core.data._types import BatchedDataDict, DataDict
9
+
10
+
11
+ class BaseAugmentor(nn.Module, ABC):
12
+ def forward(self, item: Union[DataDict, BatchedDataDict]) -> Union[
13
+ DataDict, BatchedDataDict]:
14
+ raise NotImplementedError
15
+
16
+
17
+ class StemAugmentor(BaseAugmentor):
18
+ def __init__(
19
+ self,
20
+ audiomentations: Dict[str, Dict[str, Any]],
21
+ fix_clipping: bool = True,
22
+ scaler_margin: float = 0.5,
23
+ apply_both_default_and_common: bool = False,
24
+ ) -> None:
25
+ super().__init__()
26
+
27
+ augmentations = {}
28
+
29
+ self.has_default = "[default]" in audiomentations
30
+ self.has_common = "[common]" in audiomentations
31
+ self.apply_both_default_and_common = apply_both_default_and_common
32
+
33
+ for stem in audiomentations:
34
+ if audiomentations[stem]["name"] == "Compose":
35
+ augmentations[stem] = getattr(
36
+ tam,
37
+ audiomentations[stem]["name"]
38
+ )(
39
+ [
40
+ getattr(tam, aug["name"])(**aug["kwargs"])
41
+ for aug in
42
+ audiomentations[stem]["kwargs"]["transforms"]
43
+ ],
44
+ **audiomentations[stem]["kwargs"]["kwargs"],
45
+ )
46
+ else:
47
+ augmentations[stem] = getattr(
48
+ tam,
49
+ audiomentations[stem]["name"]
50
+ )(
51
+ **audiomentations[stem]["kwargs"]
52
+ )
53
+
54
+ self.augmentations = nn.ModuleDict(augmentations)
55
+ self.fix_clipping = fix_clipping
56
+ self.scaler_margin = scaler_margin
57
+
58
+ def check_and_fix_clipping(
59
+ self, item: Union[DataDict, BatchedDataDict]
60
+ ) -> Union[DataDict, BatchedDataDict]:
61
+ max_abs = []
62
+
63
+ for stem in item["audio"]:
64
+ max_abs.append(item["audio"][stem].abs().max().item())
65
+
66
+ if max(max_abs) > 1.0:
67
+ scaler = 1.0 / (max(max_abs) + torch.rand(
68
+ (1,),
69
+ device=item["audio"]["mixture"].device
70
+ ) * self.scaler_margin)
71
+
72
+ for stem in item["audio"]:
73
+ item["audio"][stem] *= scaler
74
+
75
+ return item
76
+
77
+ def forward(self, item: Union[DataDict, BatchedDataDict]) -> Union[
78
+ DataDict, BatchedDataDict]:
79
+
80
+ for stem in item["audio"]:
81
+ if stem == "mixture":
82
+ continue
83
+
84
+ if self.has_common:
85
+ item["audio"][stem] = self.augmentations["[common]"](
86
+ item["audio"][stem]
87
+ ).samples
88
+
89
+ if stem in self.augmentations:
90
+ item["audio"][stem] = self.augmentations[stem](
91
+ item["audio"][stem]
92
+ ).samples
93
+ elif self.has_default:
94
+ if not self.has_common or self.apply_both_default_and_common:
95
+ item["audio"][stem] = self.augmentations["[default]"](
96
+ item["audio"][stem]
97
+ ).samples
98
+
99
+ item["audio"]["mixture"] = sum(
100
+ [item["audio"][stem] for stem in item["audio"]
101
+ if stem != "mixture"]
102
+ ) # type: ignore[call-overload, assignment]
103
+
104
+ if self.fix_clipping:
105
+ item = self.check_and_fix_clipping(item)
106
+
107
+ return item
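For reference, a hedged sketch of the kind of audiomentations mapping that StemAugmentor above expects: each key is a stem name (or the special "[default]" / "[common]" keys) and each value names a torch_audiomentations transform plus its kwargs. The stem names, transform choices, and parameter values here are placeholders, not taken from this repository:

# Hypothetical StemAugmentor config.
audiomentations = {
    "speech": {
        "name": "Compose",
        "kwargs": {
            "transforms": [
                {"name": "Gain", "kwargs": {"min_gain_in_db": -6.0, "max_gain_in_db": 6.0, "p": 0.5}},
                {"name": "PolarityInversion", "kwargs": {"p": 0.5}},
            ],
            "kwargs": {},  # extra kwargs passed to Compose itself
        },
    },
    # Fallback applied to any stem without its own entry.
    "[default]": {"name": "Gain", "kwargs": {"min_gain_in_db": -3.0, "max_gain_in_db": 3.0, "p": 0.5}},
}

augmentor = StemAugmentor(audiomentations)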
separator/models/bandit/core/data/augmented.py ADDED
@@ -0,0 +1,35 @@
1
+ import warnings
2
+ from typing import Dict, Optional, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.utils import data
7
+
8
+
9
+ class AugmentedDataset(data.Dataset):
10
+ def __init__(
11
+ self,
12
+ dataset: data.Dataset,
13
+ augmentation: nn.Module = nn.Identity(),
14
+ target_length: Optional[int] = None,
15
+ ) -> None:
16
+ warnings.warn(
17
+ "This class is no longer used. Attach augmentation to "
18
+ "the LightningSystem instead.",
19
+ DeprecationWarning,
20
+ )
21
+
22
+ self.dataset = dataset
23
+ self.augmentation = augmentation
24
+
25
+ self.ds_length: int = len(dataset) # type: ignore[arg-type]
26
+ self.length = target_length if target_length is not None else self.ds_length
27
+
28
+ def __getitem__(self, index: int) -> Dict[str, Union[str, Dict[str,
29
+ torch.Tensor]]]:
30
+ item = self.dataset[index % self.ds_length]
31
+ item = self.augmentation(item)
32
+ return item
33
+
34
+ def __len__(self) -> int:
35
+ return self.length
separator/models/bandit/core/data/base.py ADDED
@@ -0,0 +1,69 @@
1
+ import os
2
+ from abc import ABC, abstractmethod
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ import numpy as np
6
+ import pedalboard as pb
7
+ import torch
8
+ import torchaudio as ta
9
+ from torch.utils import data
10
+
11
+ from models.bandit.core.data._types import AudioDict, DataDict
12
+
13
+
14
+ class BaseSourceSeparationDataset(data.Dataset, ABC):
15
+ def __init__(
16
+ self, split: str,
17
+ stems: List[str],
18
+ files: List[str],
19
+ data_path: str,
20
+ fs: int,
21
+ npy_memmap: bool,
22
+ recompute_mixture: bool
23
+ ):
24
+ self.split = split
25
+ self.stems = stems
26
+ self.stems_no_mixture = [s for s in stems if s != "mixture"]
27
+ self.files = files
28
+ self.data_path = data_path
29
+ self.fs = fs
30
+ self.npy_memmap = npy_memmap
31
+ self.recompute_mixture = recompute_mixture
32
+
33
+ @abstractmethod
34
+ def get_stem(
35
+ self,
36
+ *,
37
+ stem: str,
38
+ identifier: Dict[str, Any]
39
+ ) -> torch.Tensor:
40
+ raise NotImplementedError
41
+
42
+ def _get_audio(self, stems, identifier: Dict[str, Any]):
43
+ audio = {}
44
+ for stem in stems:
45
+ audio[stem] = self.get_stem(stem=stem, identifier=identifier)
46
+
47
+ return audio
48
+
49
+ def get_audio(self, identifier: Dict[str, Any]) -> AudioDict:
50
+
51
+ if self.recompute_mixture:
52
+ audio = self._get_audio(
53
+ self.stems_no_mixture,
54
+ identifier=identifier
55
+ )
56
+ audio["mixture"] = self.compute_mixture(audio)
57
+ return audio
58
+ else:
59
+ return self._get_audio(self.stems, identifier=identifier)
60
+
61
+ @abstractmethod
62
+ def get_identifier(self, index: int) -> Dict[str, Any]:
63
+ pass
64
+
65
+ def compute_mixture(self, audio: AudioDict) -> torch.Tensor:
66
+
67
+ return sum(
68
+ audio[stem] for stem in audio if stem != "mixture"
69
+ )
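To illustrate the contract the abstract base class above imposes, here is a minimal, hypothetical subclass; the directory layout, stem names, and sample rate are assumptions made for the sketch only:

import os
from typing import Any, Dict, List

import torch
import torchaudio as ta


class ToySeparationDataset(BaseSourceSeparationDataset):
    """Hypothetical layout: one folder per track, one <stem>.wav per stem."""

    def __init__(self, data_path: str, files: List[str]) -> None:
        super().__init__(
            split="train",
            stems=["mixture", "speech"],
            files=files,
            data_path=data_path,
            fs=44100,
            npy_memmap=False,
            recompute_mixture=False,
        )

    def get_identifier(self, index: int) -> Dict[str, Any]:
        return {"track": self.files[index]}

    def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
        # Load one stem for one track.
        path = os.path.join(self.data_path, identifier["track"], f"{stem}.wav")
        audio, _ = ta.load(path)
        return audio

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, index: int):
        identifier = self.get_identifier(index)
        return {"audio": self.get_audio(identifier), "track": identifier["track"]}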
separator/models/bandit/core/data/dnr/__init__.py ADDED
File without changes
separator/models/bandit/core/data/dnr/datamodule.py ADDED
@@ -0,0 +1,74 @@
1
+ import os
2
+ from typing import Mapping, Optional
3
+
4
+ import pytorch_lightning as pl
5
+
6
+ from .dataset import (
7
+ DivideAndRemasterDataset,
8
+ DivideAndRemasterDeterministicChunkDataset,
9
+ DivideAndRemasterRandomChunkDataset,
10
+ DivideAndRemasterRandomChunkDatasetWithSpeechReverb
11
+ )
12
+
13
+
14
+ def DivideAndRemasterDataModule(
15
+ data_root: str = "$DATA_ROOT/DnR/v2",
16
+ batch_size: int = 2,
17
+ num_workers: int = 8,
18
+ train_kwargs: Optional[Mapping] = None,
19
+ val_kwargs: Optional[Mapping] = None,
20
+ test_kwargs: Optional[Mapping] = None,
21
+ datamodule_kwargs: Optional[Mapping] = None,
22
+ use_speech_reverb: bool = False
23
+ # augmentor=None
24
+ ) -> pl.LightningDataModule:
25
+ if train_kwargs is None:
26
+ train_kwargs = {}
27
+
28
+ if val_kwargs is None:
29
+ val_kwargs = {}
30
+
31
+ if test_kwargs is None:
32
+ test_kwargs = {}
33
+
34
+ if datamodule_kwargs is None:
35
+ datamodule_kwargs = {}
36
+
37
+ if num_workers is None:
38
+ num_workers = os.cpu_count()
39
+
40
+ if num_workers is None:
41
+ num_workers = 32
42
+
43
+ num_workers = min(num_workers, 64)
44
+
45
+ if use_speech_reverb:
46
+ train_cls = DivideAndRemasterRandomChunkDatasetWithSpeechReverb
47
+ else:
48
+ train_cls = DivideAndRemasterRandomChunkDataset
49
+
50
+ train_dataset = train_cls(
51
+ data_root, "train", **train_kwargs
52
+ )
53
+
54
+ # if augmentor is not None:
55
+ # train_dataset = AugmentedDataset(train_dataset, augmentor)
56
+
57
+ datamodule = pl.LightningDataModule.from_datasets(
58
+ train_dataset=train_dataset,
59
+ val_dataset=DivideAndRemasterDeterministicChunkDataset(
60
+ data_root, "val", **val_kwargs
61
+ ),
62
+ test_dataset=DivideAndRemasterDataset(
63
+ data_root,
64
+ "test",
65
+ **test_kwargs
66
+ ),
67
+ batch_size=batch_size,
68
+ num_workers=num_workers,
69
+ **datamodule_kwargs
70
+ )
71
+
72
+ datamodule.predict_dataloader = datamodule.test_dataloader # type: ignore[method-assign]
73
+
74
+ return datamodule
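A hedged usage sketch of the factory above; the data root and chunking values are placeholders, and the keyword names mirror the dataset constructors defined in dataset.py:

datamodule = DivideAndRemasterDataModule(
    data_root="/path/to/DnR/v2np",  # placeholder path
    batch_size=4,
    num_workers=8,
    train_kwargs={"target_length": 20000, "chunk_size_second": 6.0},
    val_kwargs={"chunk_size_second": 6.0, "hop_size_second": 3.0},
)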
separator/models/bandit/core/data/dnr/dataset.py ADDED
@@ -0,0 +1,392 @@
1
+ import os
2
+ from abc import ABC
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ import numpy as np
6
+ import pedalboard as pb
7
+ import torch
8
+ import torchaudio as ta
9
+ from torch.utils import data
10
+
11
+ from models.bandit.core.data._types import AudioDict, DataDict
12
+ from models.bandit.core.data.base import BaseSourceSeparationDataset
13
+
14
+
15
+ class DivideAndRemasterBaseDataset(BaseSourceSeparationDataset, ABC):
16
+ ALLOWED_STEMS = ["mixture", "speech", "music", "effects", "mne"]
17
+ STEM_NAME_MAP = {
18
+ "mixture": "mix",
19
+ "speech": "speech",
20
+ "music": "music",
21
+ "effects": "sfx",
22
+ }
23
+ SPLIT_NAME_MAP = {"train": "tr", "val": "cv", "test": "tt"}
24
+
25
+ FULL_TRACK_LENGTH_SECOND = 60
26
+ FULL_TRACK_LENGTH_SAMPLES = FULL_TRACK_LENGTH_SECOND * 44100
27
+
28
+ def __init__(
29
+ self,
30
+ split: str,
31
+ stems: List[str],
32
+ files: List[str],
33
+ data_path: str,
34
+ fs: int = 44100,
35
+ npy_memmap: bool = True,
36
+ recompute_mixture: bool = False,
37
+ ) -> None:
38
+ super().__init__(
39
+ split=split,
40
+ stems=stems,
41
+ files=files,
42
+ data_path=data_path,
43
+ fs=fs,
44
+ npy_memmap=npy_memmap,
45
+ recompute_mixture=recompute_mixture
46
+ )
47
+
48
+ def get_stem(
49
+ self,
50
+ *,
51
+ stem: str,
52
+ identifier: Dict[str, Any]
53
+ ) -> torch.Tensor:
54
+
55
+ if stem == "mne":
56
+ return self.get_stem(
57
+ stem="music",
58
+ identifier=identifier) + self.get_stem(
59
+ stem="effects",
60
+ identifier=identifier)
61
+
62
+ track = identifier["track"]
63
+ path = os.path.join(self.data_path, track)
64
+
65
+ if self.npy_memmap:
66
+ audio = np.load(
67
+ os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.npy"),
68
+ mmap_mode="r"
69
+ )
70
+ else:
71
+ # noinspection PyUnresolvedReferences
72
+ audio, _ = ta.load(
73
+ os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.wav")
74
+ )
75
+
76
+ return audio
77
+
78
+ def get_identifier(self, index):
79
+ return dict(track=self.files[index])
80
+
81
+ def __getitem__(self, index: int) -> DataDict:
82
+ identifier = self.get_identifier(index)
83
+ audio = self.get_audio(identifier)
84
+
85
+ return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
86
+
87
+
88
+ class DivideAndRemasterDataset(DivideAndRemasterBaseDataset):
89
+ def __init__(
90
+ self,
91
+ data_root: str,
92
+ split: str,
93
+ stems: Optional[List[str]] = None,
94
+ fs: int = 44100,
95
+ npy_memmap: bool = True,
96
+ ) -> None:
97
+
98
+ if stems is None:
99
+ stems = self.ALLOWED_STEMS
100
+ self.stems = stems
101
+
102
+ data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
103
+
104
+ files = sorted(os.listdir(data_path))
105
+ files = [
106
+ f
107
+ for f in files
108
+ if (not f.startswith(".")) and os.path.isdir(
109
+ os.path.join(data_path, f)
110
+ )
111
+ ]
112
+ # pprint(list(enumerate(files)))
113
+ if split == "train":
114
+ assert len(files) == 3406, len(files)
115
+ elif split == "val":
116
+ assert len(files) == 487, len(files)
117
+ elif split == "test":
118
+ assert len(files) == 973, len(files)
119
+
120
+ self.n_tracks = len(files)
121
+
122
+ super().__init__(
123
+ data_path=data_path,
124
+ split=split,
125
+ stems=stems,
126
+ files=files,
127
+ fs=fs,
128
+ npy_memmap=npy_memmap,
129
+ )
130
+
131
+ def __len__(self) -> int:
132
+ return self.n_tracks
133
+
134
+
135
+ class DivideAndRemasterRandomChunkDataset(DivideAndRemasterBaseDataset):
136
+ def __init__(
137
+ self,
138
+ data_root: str,
139
+ split: str,
140
+ target_length: int,
141
+ chunk_size_second: float,
142
+ stems: Optional[List[str]] = None,
143
+ fs: int = 44100,
144
+ npy_memmap: bool = True,
145
+ ) -> None:
146
+
147
+ if stems is None:
148
+ stems = self.ALLOWED_STEMS
149
+ self.stems = stems
150
+
151
+ data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
152
+
153
+ files = sorted(os.listdir(data_path))
154
+ files = [
155
+ f
156
+ for f in files
157
+ if (not f.startswith(".")) and os.path.isdir(
158
+ os.path.join(data_path, f)
159
+ )
160
+ ]
161
+
162
+ if split == "train":
163
+ assert len(files) == 3406, len(files)
164
+ elif split == "val":
165
+ assert len(files) == 487, len(files)
166
+ elif split == "test":
167
+ assert len(files) == 973, len(files)
168
+
169
+ self.n_tracks = len(files)
170
+
171
+ self.target_length = target_length
172
+ self.chunk_size = int(chunk_size_second * fs)
173
+
174
+ super().__init__(
175
+ data_path=data_path,
176
+ split=split,
177
+ stems=stems,
178
+ files=files,
179
+ fs=fs,
180
+ npy_memmap=npy_memmap,
181
+ )
182
+
183
+ def __len__(self) -> int:
184
+ return self.target_length
185
+
186
+ def get_identifier(self, index):
187
+ return super().get_identifier(index % self.n_tracks)
188
+
189
+ def get_stem(
190
+ self,
191
+ *,
192
+ stem: str,
193
+ identifier: Dict[str, Any],
194
+ chunk_here: bool = False,
195
+ ) -> torch.Tensor:
196
+
197
+ stem = super().get_stem(
198
+ stem=stem,
199
+ identifier=identifier
200
+ )
201
+
202
+ if chunk_here:
203
+ start = np.random.randint(
204
+ 0,
205
+ self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size
206
+ )
207
+ end = start + self.chunk_size
208
+
209
+ stem = stem[:, start:end]
210
+
211
+ return stem
212
+
213
+ def __getitem__(self, index: int) -> DataDict:
214
+ identifier = self.get_identifier(index)
215
+ # self.index_lock = index
216
+ audio = self.get_audio(identifier)
217
+ # self.index_lock = None
218
+
219
+ start = np.random.randint(
220
+ 0,
221
+ self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size
222
+ )
223
+ end = start + self.chunk_size
224
+
225
+ audio = {
226
+ k: v[:, start:end] for k, v in audio.items()
227
+ }
228
+
229
+ return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
230
+
231
+
232
+ class DivideAndRemasterDeterministicChunkDataset(DivideAndRemasterBaseDataset):
233
+ def __init__(
234
+ self,
235
+ data_root: str,
236
+ split: str,
237
+ chunk_size_second: float,
238
+ hop_size_second: float,
239
+ stems: Optional[List[str]] = None,
240
+ fs: int = 44100,
241
+ npy_memmap: bool = True,
242
+ ) -> None:
243
+
244
+ if stems is None:
245
+ stems = self.ALLOWED_STEMS
246
+ self.stems = stems
247
+
248
+ data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
249
+
250
+ files = sorted(os.listdir(data_path))
251
+ files = [
252
+ f
253
+ for f in files
254
+ if (not f.startswith(".")) and os.path.isdir(
255
+ os.path.join(data_path, f)
256
+ )
257
+ ]
258
+ # pprint(list(enumerate(files)))
259
+ if split == "train":
260
+ assert len(files) == 3406, len(files)
261
+ elif split == "val":
262
+ assert len(files) == 487, len(files)
263
+ elif split == "test":
264
+ assert len(files) == 973, len(files)
265
+
266
+ self.n_tracks = len(files)
267
+
268
+ self.chunk_size = int(chunk_size_second * fs)
269
+ self.hop_size = int(hop_size_second * fs)
270
+ self.n_chunks_per_track = int(
271
+ (
272
+ self.FULL_TRACK_LENGTH_SECOND - chunk_size_second) / hop_size_second
273
+ )
274
+
275
+ self.length = self.n_tracks * self.n_chunks_per_track
276
+
277
+ super().__init__(
278
+ data_path=data_path,
279
+ split=split,
280
+ stems=stems,
281
+ files=files,
282
+ fs=fs,
283
+ npy_memmap=npy_memmap,
284
+ )
285
+
286
+ def get_identifier(self, index):
287
+ return super().get_identifier(index % self.n_tracks)
288
+
289
+ def __len__(self) -> int:
290
+ return self.length
291
+
292
+ def __getitem__(self, item: int) -> DataDict:
293
+
294
+ index = item % self.n_tracks
295
+ chunk = item // self.n_tracks
296
+
297
+ data_ = super().__getitem__(index)
298
+
299
+ audio = data_["audio"]
300
+
301
+ start = chunk * self.hop_size
302
+ end = start + self.chunk_size
303
+
304
+ for stem in self.stems:
305
+ data_["audio"][stem] = audio[stem][:, start:end]
306
+
307
+ return data_
308
+
309
+
310
+ class DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
311
+ DivideAndRemasterRandomChunkDataset
312
+ ):
313
+ def __init__(
314
+ self,
315
+ data_root: str,
316
+ split: str,
317
+ target_length: int,
318
+ chunk_size_second: float,
319
+ stems: Optional[List[str]] = None,
320
+ fs: int = 44100,
321
+ npy_memmap: bool = True,
322
+ ) -> None:
323
+
324
+ if stems is None:
325
+ stems = self.ALLOWED_STEMS
326
+
327
+ stems_no_mixture = [s for s in stems if s != "mixture"]
328
+
329
+ super().__init__(
330
+ data_root=data_root,
331
+ split=split,
332
+ target_length=target_length,
333
+ chunk_size_second=chunk_size_second,
334
+ stems=stems_no_mixture,
335
+ fs=fs,
336
+ npy_memmap=npy_memmap,
337
+ )
338
+
339
+ self.stems = stems
340
+ self.stems_no_mixture = stems_no_mixture
341
+
342
+ def __getitem__(self, index: int) -> DataDict:
343
+
344
+ data_ = super().__getitem__(index)
345
+
346
+ dry = data_["audio"]["speech"][:]
347
+ n_samples = dry.shape[-1]
348
+
349
+ wet_level = np.random.rand()
350
+
351
+ speech = pb.Reverb(
352
+ room_size=np.random.rand(),
353
+ damping=np.random.rand(),
354
+ wet_level=wet_level,
355
+ dry_level=(1 - wet_level),
356
+ width=np.random.rand()
357
+ ).process(dry, self.fs, buffer_size=8192 * 4)[..., :n_samples]
358
+
359
+ data_["audio"]["speech"] = speech
360
+
361
+ data_["audio"]["mixture"] = sum(
362
+ [data_["audio"][s] for s in self.stems_no_mixture]
363
+ )
364
+
365
+ return data_
366
+
367
+ def __len__(self) -> int:
368
+ return super().__len__()
369
+
370
+
371
+ if __name__ == "__main__":
372
+
373
+ from pprint import pprint
374
+ from tqdm.auto import tqdm
375
+
376
+ for split_ in ["train", "val", "test"]:
377
+ ds = DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
378
+ data_root="$DATA_ROOT/DnR/v2np",
379
+ split=split_,
380
+ target_length=100,
381
+ chunk_size_second=6.0
382
+ )
383
+
384
+ print(split_, len(ds))
385
+
386
+ for track_ in tqdm(ds): # type: ignore
387
+ pprint(track_)
388
+ track_["audio"] = {k: v.shape for k, v in track_["audio"].items()}
389
+ pprint(track_)
390
+ # break
391
+
392
+ break
separator/models/bandit/core/data/dnr/preprocess.py ADDED
@@ -0,0 +1,54 @@
1
+ import glob
2
+ import os
3
+ from typing import Tuple
4
+
5
+ import numpy as np
6
+ import torchaudio as ta
7
+ from tqdm.contrib.concurrent import process_map
8
+
9
+
10
+ def process_one(inputs: Tuple[str, str, int]) -> None:
11
+ infile, outfile, target_fs = inputs
12
+
13
+ dir = os.path.dirname(outfile)
14
+ os.makedirs(dir, exist_ok=True)
15
+
16
+ data, fs = ta.load(infile)
17
+
18
+ if fs != target_fs:
19
+ data = ta.functional.resample(data, fs, target_fs, resampling_method="sinc_interp_kaiser")
20
+ fs = target_fs
21
+
22
+ data = data.numpy()
23
+ data = data.astype(np.float32)
24
+
25
+ if os.path.exists(outfile):
26
+ data_ = np.load(outfile)
27
+ if np.allclose(data, data_):
28
+ return
29
+
30
+ np.save(outfile, data)
31
+
32
+
33
+ def preprocess(
34
+ data_path: str,
35
+ output_path: str,
36
+ fs: int
37
+ ) -> None:
38
+ files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
39
+ print(files)
40
+ outfiles = [
41
+ f.replace(data_path, output_path).replace(".wav", ".npy") for f in
42
+ files
43
+ ]
44
+
45
+ os.makedirs(output_path, exist_ok=True)
46
+ inputs = list(zip(files, outfiles, [fs] * len(files)))
47
+
48
+ process_map(process_one, inputs, chunksize=32)
49
+
50
+
51
+ if __name__ == "__main__":
52
+ import fire
53
+
54
+ fire.Fire()
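Since the module is exposed through python-fire, the preprocessing entry point can equally be called directly from Python; a sketch with placeholder paths:

preprocess(
    data_path="/path/to/DnR/v2",      # source wav tree (placeholder)
    output_path="/path/to/DnR/v2np",  # mirrored tree of .npy files (placeholder)
    fs=44100,
)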
separator/models/bandit/core/data/musdb/__init__.py ADDED
File without changes
separator/models/bandit/core/data/musdb/datamodule.py ADDED
@@ -0,0 +1,77 @@
1
+ import os.path
2
+ from typing import Mapping, Optional
3
+
4
+ import pytorch_lightning as pl
5
+
6
+ from models.bandit.core.data.musdb.dataset import (
7
+ MUSDB18BaseDataset,
8
+ MUSDB18FullTrackDataset,
9
+ MUSDB18SadDataset,
10
+ MUSDB18SadOnTheFlyAugmentedDataset
11
+ )
12
+
13
+
14
+ def MUSDB18DataModule(
15
+ data_root: str = "$DATA_ROOT/MUSDB18/HQ",
16
+ target_stem: str = "vocals",
17
+ batch_size: int = 2,
18
+ num_workers: int = 8,
19
+ train_kwargs: Optional[Mapping] = None,
20
+ val_kwargs: Optional[Mapping] = None,
21
+ test_kwargs: Optional[Mapping] = None,
22
+ datamodule_kwargs: Optional[Mapping] = None,
23
+ use_on_the_fly: bool = True,
24
+ npy_memmap: bool = True
25
+ ) -> pl.LightningDataModule:
26
+ if train_kwargs is None:
27
+ train_kwargs = {}
28
+
29
+ if val_kwargs is None:
30
+ val_kwargs = {}
31
+
32
+ if test_kwargs is None:
33
+ test_kwargs = {}
34
+
35
+ if datamodule_kwargs is None:
36
+ datamodule_kwargs = {}
37
+
38
+ train_dataset: MUSDB18BaseDataset
39
+
40
+ if use_on_the_fly:
41
+ train_dataset = MUSDB18SadOnTheFlyAugmentedDataset(
42
+ data_root=os.path.join(data_root, "saded-np"),
43
+ split="train",
44
+ target_stem=target_stem,
45
+ **train_kwargs
46
+ )
47
+ else:
48
+ train_dataset = MUSDB18SadDataset(
49
+ data_root=os.path.join(data_root, "saded-np"),
50
+ split="train",
51
+ target_stem=target_stem,
52
+ **train_kwargs
53
+ )
54
+
55
+ datamodule = pl.LightningDataModule.from_datasets(
56
+ train_dataset=train_dataset,
57
+ val_dataset=MUSDB18SadDataset(
58
+ data_root=os.path.join(data_root, "saded-np"),
59
+ split="val",
60
+ target_stem=target_stem,
61
+ **val_kwargs
62
+ ),
63
+ test_dataset=MUSDB18FullTrackDataset(
64
+ data_root=os.path.join(data_root, "canonical"),
65
+ split="test",
66
+ **test_kwargs
67
+ ),
68
+ batch_size=batch_size,
69
+ num_workers=num_workers,
70
+ **datamodule_kwargs
71
+ )
72
+
73
+ datamodule.predict_dataloader = ( # type: ignore[method-assign]
74
+ datamodule.test_dataloader
75
+ )
76
+
77
+ return datamodule
separator/models/bandit/core/data/musdb/dataset.py ADDED
@@ -0,0 +1,280 @@
1
+ import os
2
+ from abc import ABC
3
+ from typing import List, Optional, Tuple
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torchaudio as ta
8
+ from torch.utils import data
9
+
10
+ from models.bandit.core.data._types import AudioDict, DataDict
11
+ from models.bandit.core.data.base import BaseSourceSeparationDataset
12
+
13
+
14
+ class MUSDB18BaseDataset(BaseSourceSeparationDataset, ABC):
15
+
16
+ ALLOWED_STEMS = ["mixture", "vocals", "bass", "drums", "other"]
17
+
18
+ def __init__(
19
+ self,
20
+ split: str,
21
+ stems: List[str],
22
+ files: List[str],
23
+ data_path: str,
24
+ fs: int = 44100,
25
+ npy_memmap=False,
26
+ ) -> None:
27
+ super().__init__(
28
+ split=split,
29
+ stems=stems,
30
+ files=files,
31
+ data_path=data_path,
32
+ fs=fs,
33
+ npy_memmap=npy_memmap,
34
+ recompute_mixture=False
35
+ )
36
+
37
+ def get_stem(self, *, stem: str, identifier) -> torch.Tensor:
38
+ track = identifier["track"]
39
+ path = os.path.join(self.data_path, track)
40
+ # noinspection PyUnresolvedReferences
41
+
42
+ if self.npy_memmap:
43
+ audio = np.load(os.path.join(path, f"{stem}.wav.npy"), mmap_mode="r")
44
+ else:
45
+ audio, _ = ta.load(os.path.join(path, f"{stem}.wav"))
46
+
47
+ return audio
48
+
49
+ def get_identifier(self, index):
50
+ return dict(track=self.files[index])
51
+
52
+ def __getitem__(self, index: int) -> DataDict:
53
+ identifier = self.get_identifier(index)
54
+ audio = self.get_audio(identifier)
55
+
56
+ return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
57
+
58
+
59
+ class MUSDB18FullTrackDataset(MUSDB18BaseDataset):
60
+
61
+ N_TRAIN_TRACKS = 100
62
+ N_TEST_TRACKS = 50
63
+ VALIDATION_FILES = [
64
+ "Actions - One Minute Smile",
65
+ "Clara Berry And Wooldog - Waltz For My Victims",
66
+ "Johnny Lokke - Promises & Lies",
67
+ "Patrick Talbot - A Reason To Leave",
68
+ "Triviul - Angelsaint",
69
+ "Alexander Ross - Goodbye Bolero",
70
+ "Fergessen - Nos Palpitants",
71
+ "Leaf - Summerghost",
72
+ "Skelpolu - Human Mistakes",
73
+ "Young Griffo - Pennies",
74
+ "ANiMAL - Rockshow",
75
+ "James May - On The Line",
76
+ "Meaxic - Take A Step",
77
+ "Traffic Experiment - Sirens",
78
+ ]
79
+
80
+ def __init__(
81
+ self, data_root: str, split: str, stems: Optional[List[
82
+ str]] = None
83
+ ) -> None:
84
+
85
+ if stems is None:
86
+ stems = self.ALLOWED_STEMS
87
+ self.stems = stems
88
+
89
+ if split == "test":
90
+ subset = "test"
91
+ elif split in ["train", "val"]:
92
+ subset = "train"
93
+ else:
94
+ raise NameError
95
+
96
+ data_path = os.path.join(data_root, subset)
97
+
98
+ files = sorted(os.listdir(data_path))
99
+ files = [f for f in files if not f.startswith(".")]
100
+ # pprint(list(enumerate(files)))
101
+ if subset == "train":
102
+ assert len(files) == 100, len(files)
103
+ if split == "train":
104
+ files = [f for f in files if f not in self.VALIDATION_FILES]
105
+ assert len(files) == 100 - len(self.VALIDATION_FILES)
106
+ else:
107
+ files = [f for f in files if f in self.VALIDATION_FILES]
108
+ assert len(files) == len(self.VALIDATION_FILES)
109
+ else:
110
+ split = "test"
111
+ assert len(files) == 50
112
+
113
+ self.n_tracks = len(files)
114
+
115
+ super().__init__(
116
+ data_path=data_path,
117
+ split=split,
118
+ stems=stems,
119
+ files=files
120
+ )
121
+
122
+ def __len__(self) -> int:
123
+ return self.n_tracks
124
+
125
+ class MUSDB18SadDataset(MUSDB18BaseDataset):
126
+ def __init__(
127
+ self,
128
+ data_root: str,
129
+ split: str,
130
+ target_stem: str,
131
+ stems: Optional[List[str]] = None,
132
+ target_length: Optional[int] = None,
133
+ npy_memmap=False,
134
+ ) -> None:
135
+
136
+ if stems is None:
137
+ stems = self.ALLOWED_STEMS
138
+
139
+ data_path = os.path.join(data_root, target_stem, split)
140
+
141
+ files = sorted(os.listdir(data_path))
142
+ files = [f for f in files if not f.startswith(".")]
143
+
144
+ super().__init__(
145
+ data_path=data_path,
146
+ split=split,
147
+ stems=stems,
148
+ files=files,
149
+ npy_memmap=npy_memmap
150
+ )
151
+ self.n_segments = len(files)
152
+ self.target_stem = target_stem
153
+ self.target_length = (
154
+ target_length if target_length is not None else self.n_segments
155
+ )
156
+
157
+ def __len__(self) -> int:
158
+ return self.target_length
159
+
160
+ def __getitem__(self, index: int) -> DataDict:
161
+
162
+ index = index % self.n_segments
163
+
164
+ return super().__getitem__(index)
165
+
166
+ def get_identifier(self, index):
167
+ return super().get_identifier(index % self.n_segments)
168
+
169
+
170
+ class MUSDB18SadOnTheFlyAugmentedDataset(MUSDB18SadDataset):
171
+ def __init__(
172
+ self,
173
+ data_root: str,
174
+ split: str,
175
+ target_stem: str,
176
+ stems: Optional[List[str]] = None,
177
+ target_length: int = 20000,
178
+ apply_probability: Optional[float] = None,
179
+ chunk_size_second: float = 3.0,
180
+ random_scale_range_db: Tuple[float, float] = (-10, 10),
181
+ drop_probability: float = 0.1,
182
+ rescale: bool = True,
183
+ ) -> None:
184
+ super().__init__(data_root, split, target_stem, stems)
185
+
186
+ if apply_probability is None:
187
+ apply_probability = (
188
+ target_length - self.n_segments) / target_length
189
+
190
+ self.apply_probability = apply_probability
191
+ self.drop_probability = drop_probability
192
+ self.chunk_size_second = chunk_size_second
193
+ self.random_scale_range_db = random_scale_range_db
194
+ self.rescale = rescale
195
+
196
+ self.chunk_size_sample = int(self.chunk_size_second * self.fs)
197
+ self.target_length = target_length
198
+
199
+ def __len__(self) -> int:
200
+ return self.target_length
201
+
202
+ def __getitem__(self, index: int) -> DataDict:
203
+
204
+ index = index % self.n_segments
205
+
206
+ # if np.random.rand() > self.apply_probability:
207
+ # return super().__getitem__(index)
208
+
209
+ audio = {}
210
+ identifier = self.get_identifier(index)
211
+
212
+ # assert self.target_stem in self.stems_no_mixture
213
+ for stem in self.stems_no_mixture:
214
+ if stem == self.target_stem:
215
+ identifier_ = identifier
216
+ else:
217
+ if np.random.rand() < self.apply_probability:
218
+ index_ = np.random.randint(self.n_segments)
219
+ identifier_ = self.get_identifier(index_)
220
+ else:
221
+ identifier_ = identifier
222
+
223
+ audio[stem] = self.get_stem(stem=stem, identifier=identifier_)
224
+
225
+ # if stem == self.target_stem:
226
+
227
+ if self.chunk_size_sample < audio[stem].shape[-1]:
228
+ chunk_start = np.random.randint(
229
+ audio[stem].shape[-1] - self.chunk_size_sample
230
+ )
231
+ else:
232
+ chunk_start = 0
233
+
234
+ if np.random.rand() < self.drop_probability:
235
+ # db_scale = "-inf"
236
+ linear_scale = 0.0
237
+ else:
238
+ db_scale = np.random.uniform(*self.random_scale_range_db)
239
+ linear_scale = np.power(10, db_scale / 20)
240
+ # db_scale = f"{db_scale:+2.1f}"
241
+ # print(linear_scale)
242
+ audio[stem][...,
243
+ chunk_start: chunk_start + self.chunk_size_sample] = (
244
+ linear_scale
245
+ * audio[stem][...,
246
+ chunk_start: chunk_start + self.chunk_size_sample]
247
+ )
248
+
249
+ audio["mixture"] = self.compute_mixture(audio)
250
+
251
+ if self.rescale:
252
+ max_abs_val = max(
253
+ [torch.max(torch.abs(audio[stem])) for stem in self.stems]
254
+ ) # type: ignore[type-var]
255
+ if max_abs_val > 1:
256
+ audio = {k: v / max_abs_val for k, v in audio.items()}
257
+
258
+ track = identifier["track"]
259
+
260
+ return {"audio": audio, "track": f"{self.split}/{track}"}
261
+
262
+ # if __name__ == "__main__":
263
+ #
264
+ # from pprint import pprint
265
+ # from tqdm.auto import tqdm
266
+ #
267
+ # for split_ in ["train", "val", "test"]:
268
+ # ds = MUSDB18SadOnTheFlyAugmentedDataset(
269
+ # data_root="$DATA_ROOT/MUSDB18/HQ/saded",
270
+ # split=split_,
271
+ # target_stem="vocals"
272
+ # )
273
+ #
274
+ # print(split_, len(ds))
275
+ #
276
+ # for track_ in tqdm(ds):
277
+ # track_["audio"] = {
278
+ # k: v.shape for k, v in track_["audio"].items()
279
+ # }
280
+ # pprint(track_)
separator/models/bandit/core/data/musdb/preprocess.py ADDED
@@ -0,0 +1,238 @@
1
+ import glob
2
+ import os
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torchaudio as ta
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ from tqdm.contrib.concurrent import process_map
10
+
11
+ from core.data._types import DataDict
12
+ from core.data.musdb.dataset import MUSDB18FullTrackDataset
13
+ import pyloudnorm as pyln
14
+
15
+ class SourceActivityDetector(nn.Module):
16
+ def __init__(
17
+ self,
18
+ analysis_stem: str,
19
+ output_path: str,
20
+ fs: int = 44100,
21
+ segment_length_second: float = 6.0,
22
+ hop_length_second: float = 3.0,
23
+ n_chunks: int = 10,
24
+ chunk_epsilon: float = 1e-5,
25
+ energy_threshold_quantile: float = 0.15,
26
+ segment_epsilon: float = 1e-3,
27
+ salient_proportion_threshold: float = 0.5,
28
+ target_lufs: float = -24
29
+ ) -> None:
30
+ super().__init__()
31
+
32
+ self.fs = fs
33
+ self.segment_length = int(segment_length_second * self.fs)
34
+ self.hop_length = int(hop_length_second * self.fs)
35
+ self.n_chunks = n_chunks
36
+ assert self.segment_length % self.n_chunks == 0
37
+ self.chunk_size = self.segment_length // self.n_chunks
38
+ self.chunk_epsilon = chunk_epsilon
39
+ self.energy_threshold_quantile = energy_threshold_quantile
40
+ self.segment_epsilon = segment_epsilon
41
+ self.salient_proportion_threshold = salient_proportion_threshold
42
+ self.analysis_stem = analysis_stem
43
+
44
+ self.meter = pyln.Meter(self.fs)
45
+ self.target_lufs = target_lufs
46
+
47
+ self.output_path = output_path
48
+
49
+ def forward(self, data: DataDict) -> None:
50
+
51
+ stem_ = self.analysis_stem if (
52
+ self.analysis_stem != "none") else "mixture"
53
+
54
+ x = data["audio"][stem_]
55
+
56
+ xnp = x.numpy()
57
+ loudness = self.meter.integrated_loudness(xnp.T)
58
+
59
+ for stem in data["audio"]:
60
+ s = data["audio"][stem]
61
+ s = pyln.normalize.loudness(s.numpy().T, loudness, self.target_lufs).T
62
+ s = torch.as_tensor(s)
63
+ data["audio"][stem] = s
64
+
65
+ if x.ndim == 3:
66
+ assert x.shape[0] == 1
67
+ x = x[0]
68
+
69
+ n_chan, n_samples = x.shape
70
+
71
+ n_segments = (
72
+ int(
73
+ np.ceil((n_samples - self.segment_length) / self.hop_length)
74
+ ) + 1
75
+ )
76
+
77
+ segments = torch.zeros((n_segments, n_chan, self.segment_length))
78
+ for i in range(n_segments):
79
+ start = i * self.hop_length
80
+ end = start + self.segment_length
81
+ end = min(end, n_samples)
82
+
83
+ xseg = x[:, start:end]
84
+
85
+ if end - start < self.segment_length:
86
+ xseg = F.pad(
87
+ xseg,
88
+ pad=(0, self.segment_length - (end - start)),
89
+ value=torch.nan
90
+ )
91
+
92
+ segments[i, :, :] = xseg
93
+
94
+ chunks = segments.reshape(
95
+ (n_segments, n_chan, self.n_chunks, self.chunk_size)
96
+ )
97
+
98
+ if self.analysis_stem != "none":
99
+ chunk_energies = torch.mean(torch.square(chunks), dim=(1, 3))
100
+ chunk_energies = torch.nan_to_num(chunk_energies, nan=0)
101
+ chunk_energies[chunk_energies == 0] = self.chunk_epsilon
102
+
103
+ energy_threshold = torch.nanquantile(
104
+ chunk_energies, q=self.energy_threshold_quantile
105
+ )
106
+
107
+ if energy_threshold < self.segment_epsilon:
108
+ energy_threshold = self.segment_epsilon # type: ignore[assignment]
109
+
110
+ chunks_above_threshold = chunk_energies > energy_threshold
111
+ n_chunks_above_threshold = torch.mean(
112
+ chunks_above_threshold.to(torch.float), dim=-1
113
+ )
114
+
115
+ segment_above_threshold = (
116
+ n_chunks_above_threshold > self.salient_proportion_threshold
117
+ )
118
+
119
+ if torch.sum(segment_above_threshold) == 0:
120
+ return
121
+
122
+ else:
123
+ segment_above_threshold = torch.ones((n_segments,))
124
+
125
+ for i in range(n_segments):
126
+ if not segment_above_threshold[i]:
127
+ continue
128
+
129
+ outpath = os.path.join(
130
+ self.output_path,
131
+ self.analysis_stem,
132
+ f"{data['track']} - {self.analysis_stem}{i:03d}",
133
+ )
134
+ os.makedirs(outpath, exist_ok=True)
135
+
136
+ for stem in data["audio"]:
137
+ if stem == self.analysis_stem:
138
+ segment = torch.nan_to_num(segments[i, :, :], nan=0)
139
+ else:
140
+ start = i * self.hop_length
141
+ end = start + self.segment_length
142
+ end = min(n_samples, end)
143
+
144
+ segment = data["audio"][stem][:, start:end]
145
+
146
+ if end - start < self.segment_length:
147
+ segment = F.pad(
148
+ segment,
149
+ (0, self.segment_length - (end - start))
150
+ )
151
+
152
+ assert segment.shape[-1] == self.segment_length, segment.shape
153
+
154
+ # ta.save(os.path.join(outpath, f"{stem}.wav"), segment, self.fs)
155
+
156
+ np.save(os.path.join(outpath, f"{stem}.wav"), segment)
157
+
158
+
159
+ def preprocess(
160
+ analysis_stem: str,
161
+ output_path: str = "/data/MUSDB18/HQ/saded-np",
162
+ fs: int = 44100,
163
+ segment_length_second: float = 6.0,
164
+ hop_length_second: float = 3.0,
165
+ n_chunks: int = 10,
166
+ chunk_epsilon: float = 1e-5,
167
+ energy_threshold_quantile: float = 0.15,
168
+ segment_epsilon: float = 1e-3,
169
+ salient_proportion_threshold: float = 0.5,
170
+ ) -> None:
171
+
172
+ sad = SourceActivityDetector(
173
+ analysis_stem=analysis_stem,
174
+ output_path=output_path,
175
+ fs=fs,
176
+ segment_length_second=segment_length_second,
177
+ hop_length_second=hop_length_second,
178
+ n_chunks=n_chunks,
179
+ chunk_epsilon=chunk_epsilon,
180
+ energy_threshold_quantile=energy_threshold_quantile,
181
+ segment_epsilon=segment_epsilon,
182
+ salient_proportion_threshold=salient_proportion_threshold,
183
+ )
184
+
185
+ for split in ["train", "val", "test"]:
186
+ ds = MUSDB18FullTrackDataset(
187
+ data_root="/data/MUSDB18/HQ/canonical",
188
+ split=split,
189
+ )
190
+
191
+ tracks = []
192
+ for i, track in enumerate(tqdm(ds, total=len(ds))):
193
+ if i % 32 == 0 and tracks:
194
+ process_map(sad, tracks, max_workers=8)
195
+ tracks = []
196
+ tracks.append(track)
197
+ process_map(sad, tracks, max_workers=8)
198
+
199
+ def loudness_norm_one(
200
+ inputs
201
+ ):
202
+ infile, outfile, target_lufs = inputs
203
+
204
+ audio, fs = ta.load(infile)
205
+ audio = audio.mean(dim=0, keepdim=True).numpy().T
206
+
207
+ meter = pyln.Meter(fs)
208
+ loudness = meter.integrated_loudness(audio)
209
+ audio = pyln.normalize.loudness(audio, loudness, target_lufs)
210
+
211
+ os.makedirs(os.path.dirname(outfile), exist_ok=True)
212
+ np.save(outfile, audio.T)
213
+
214
+ def loudness_norm(
215
+ data_path: str,
216
+ # output_path: str,
217
+ target_lufs = -17.0,
218
+ ):
219
+ files = glob.glob(
220
+ os.path.join(data_path, "**", "*.wav"), recursive=True
221
+ )
222
+
223
+ outfiles = [
224
+ f.replace(".wav", ".npy").replace("saded", "saded-np") for f in files
225
+ ]
226
+
227
+ files = [(f, o, target_lufs) for f, o in zip(files, outfiles)]
228
+
229
+ process_map(loudness_norm_one, files, chunksize=2)
230
+
231
+
232
+
233
+ if __name__ == "__main__":
234
+
235
+ from tqdm.auto import tqdm
236
+ import fire
237
+
238
+ fire.Fire()
separator/models/bandit/core/data/musdb/validation.yaml ADDED
@@ -0,0 +1,15 @@
+ validation:
+   - 'Actions - One Minute Smile'
+   - 'Clara Berry And Wooldog - Waltz For My Victims'
+   - 'Johnny Lokke - Promises & Lies'
+   - 'Patrick Talbot - A Reason To Leave'
+   - 'Triviul - Angelsaint'
+   - 'Alexander Ross - Goodbye Bolero'
+   - 'Fergessen - Nos Palpitants'
+   - 'Leaf - Summerghost'
+   - 'Skelpolu - Human Mistakes'
+   - 'Young Griffo - Pennies'
+   - 'ANiMAL - Rockshow'
+   - 'James May - On The Line'
+   - 'Meaxic - Take A Step'
+   - 'Traffic Experiment - Sirens'
separator/models/bandit/core/loss/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from ._multistem import MultiStemWrapperFromConfig
+ from ._timefreq import ReImL1Loss, ReImL2Loss, TimeFreqL1Loss, TimeFreqL2Loss, TimeFreqSignalNoisePNormRatioLoss
separator/models/bandit/core/loss/_complex.py ADDED
@@ -0,0 +1,34 @@
+ from typing import Any
+
+ import torch
+ from torch import nn
+ from torch.nn.modules import loss as _loss
+ from torch.nn.modules.loss import _Loss
+
+
+ class ReImLossWrapper(_Loss):
+     def __init__(self, module: _Loss) -> None:
+         super().__init__()
+         self.module = module
+
+     def forward(
+         self,
+         preds: torch.Tensor,
+         target: torch.Tensor
+     ) -> torch.Tensor:
+         return self.module(
+             torch.view_as_real(preds),
+             torch.view_as_real(target)
+         )
+
+
+ class ReImL1Loss(ReImLossWrapper):
+     def __init__(self, **kwargs: Any) -> None:
+         l1_loss = _loss.L1Loss(**kwargs)
+         super().__init__(module=(l1_loss))
+
+
+ class ReImL2Loss(ReImLossWrapper):
+     def __init__(self, **kwargs: Any) -> None:
+         l2_loss = _loss.MSELoss(**kwargs)
+         super().__init__(module=(l2_loss))
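A small sketch of how the ReIm wrappers above are applied: the complex prediction and target are viewed as stacked real/imaginary parts before the wrapped loss is evaluated. The shapes here are illustrative:

import torch

pred = torch.randn(1, 2, 1025, 100, dtype=torch.complex64)    # e.g. a complex spectrogram
target = torch.randn(1, 2, 1025, 100, dtype=torch.complex64)

loss = ReImL1Loss()(pred, target)  # L1 over real and imaginary components jointly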
separator/models/bandit/core/loss/_multistem.py ADDED
@@ -0,0 +1,45 @@
1
+ from typing import Any, Dict
2
+
3
+ import torch
4
+ from asteroid import losses as asteroid_losses
5
+ from torch import nn
6
+ from torch.nn.modules.loss import _Loss
7
+
8
+ from . import snr
9
+
10
+
11
+ def parse_loss(name: str, kwargs: Dict[str, Any]) -> _Loss:
12
+
13
+ for module in [nn.modules.loss, snr, asteroid_losses, asteroid_losses.sdr]:
14
+ if name in module.__dict__:
15
+ return module.__dict__[name](**kwargs)
16
+
17
+ raise NameError
18
+
19
+
20
+ class MultiStemWrapper(_Loss):
21
+ def __init__(self, module: _Loss, modality: str = "audio") -> None:
22
+ super().__init__()
23
+ self.loss = module
24
+ self.modality = modality
25
+
26
+ def forward(
27
+ self,
28
+ preds: Dict[str, Dict[str, torch.Tensor]],
29
+ target: Dict[str, Dict[str, torch.Tensor]],
30
+ ) -> torch.Tensor:
31
+ loss = {
32
+ stem: self.loss(
33
+ preds[self.modality][stem],
34
+ target[self.modality][stem]
35
+ )
36
+ for stem in preds[self.modality] if stem in target[self.modality]
37
+ }
38
+
39
+ return sum(list(loss.values()))
40
+
41
+
42
+ class MultiStemWrapperFromConfig(MultiStemWrapper):
43
+ def __init__(self, name: str, kwargs: Any, modality: str = "audio") -> None:
44
+ loss = parse_loss(name, kwargs)
45
+ super().__init__(module=loss, modality=modality)
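For reference, a hedged sketch of MultiStemWrapperFromConfig in use: parse_loss resolves the name against torch.nn, the local snr module, and asteroid, and the wrapper sums the loss over stems present in both dictionaries. Stem names and shapes below are illustrative:

import torch

# "L1Loss" resolves to torch.nn.L1Loss via parse_loss.
criterion = MultiStemWrapperFromConfig(name="L1Loss", kwargs={}, modality="audio")

preds = {"audio": {"speech": torch.randn(1, 2, 44100), "music": torch.randn(1, 2, 44100)}}
target = {"audio": {"speech": torch.randn(1, 2, 44100), "music": torch.randn(1, 2, 44100)}}

loss = criterion(preds, target)  # sum of per-stem L1 losses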
separator/models/bandit/core/loss/_timefreq.py ADDED
@@ -0,0 +1,113 @@
1
+ from typing import Any, Dict, Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn.modules.loss import _Loss
6
+
7
+ from models.bandit.core.loss._multistem import MultiStemWrapper
8
+ from models.bandit.core.loss._complex import ReImL1Loss, ReImL2Loss, ReImLossWrapper
9
+ from models.bandit.core.loss.snr import SignalNoisePNormRatio
10
+
11
+ class TimeFreqWrapper(_Loss):
12
+ def __init__(
13
+ self,
14
+ time_module: _Loss,
15
+ freq_module: Optional[_Loss] = None,
16
+ time_weight: float = 1.0,
17
+ freq_weight: float = 1.0,
18
+ multistem: bool = True,
19
+ ) -> None:
20
+ super().__init__()
21
+
22
+ if freq_module is None:
23
+ freq_module = time_module
24
+
25
+ if multistem:
26
+ time_module = MultiStemWrapper(time_module, modality="audio")
27
+ freq_module = MultiStemWrapper(freq_module, modality="spectrogram")
28
+
29
+ self.time_module = time_module
30
+ self.freq_module = freq_module
31
+
32
+ self.time_weight = time_weight
33
+ self.freq_weight = freq_weight
34
+
35
+ # TODO: add better type hints
36
+ def forward(self, preds: Any, target: Any) -> torch.Tensor:
37
+
38
+ return self.time_weight * self.time_module(
39
+ preds, target
40
+ ) + self.freq_weight * self.freq_module(preds, target)
41
+
42
+
43
+ class TimeFreqL1Loss(TimeFreqWrapper):
44
+ def __init__(
45
+ self,
46
+ time_weight: float = 1.0,
47
+ freq_weight: float = 1.0,
48
+ tkwargs: Optional[Dict[str, Any]] = None,
49
+ fkwargs: Optional[Dict[str, Any]] = None,
50
+ multistem: bool = True,
51
+ ) -> None:
52
+ if tkwargs is None:
53
+ tkwargs = {}
54
+ if fkwargs is None:
55
+ fkwargs = {}
56
+ time_module = (nn.L1Loss(**tkwargs))
57
+ freq_module = ReImL1Loss(**fkwargs)
58
+ super().__init__(
59
+ time_module,
60
+ freq_module,
61
+ time_weight,
62
+ freq_weight,
63
+ multistem
64
+ )
65
+
66
+
67
+ class TimeFreqL2Loss(TimeFreqWrapper):
68
+ def __init__(
69
+ self,
70
+ time_weight: float = 1.0,
71
+ freq_weight: float = 1.0,
72
+ tkwargs: Optional[Dict[str, Any]] = None,
73
+ fkwargs: Optional[Dict[str, Any]] = None,
74
+ multistem: bool = True,
75
+ ) -> None:
76
+ if tkwargs is None:
77
+ tkwargs = {}
78
+ if fkwargs is None:
79
+ fkwargs = {}
80
+ time_module = nn.MSELoss(**tkwargs)
81
+ freq_module = ReImL2Loss(**fkwargs)
82
+ super().__init__(
83
+ time_module,
84
+ freq_module,
85
+ time_weight,
86
+ freq_weight,
87
+ multistem
88
+ )
89
+
90
+
91
+
92
+ class TimeFreqSignalNoisePNormRatioLoss(TimeFreqWrapper):
93
+ def __init__(
94
+ self,
95
+ time_weight: float = 1.0,
96
+ freq_weight: float = 1.0,
97
+ tkwargs: Optional[Dict[str, Any]] = None,
98
+ fkwargs: Optional[Dict[str, Any]] = None,
99
+ multistem: bool = True,
100
+ ) -> None:
101
+ if tkwargs is None:
102
+ tkwargs = {}
103
+ if fkwargs is None:
104
+ fkwargs = {}
105
+ time_module = SignalNoisePNormRatio(**tkwargs)
106
+ freq_module = SignalNoisePNormRatio(**fkwargs)
107
+ super().__init__(
108
+ time_module,
109
+ freq_module,
110
+ time_weight,
111
+ freq_weight,
112
+ multistem
113
+ )
separator/models/bandit/core/loss/snr.py ADDED
@@ -0,0 +1,146 @@
1
+ import torch
2
+ from torch.nn.modules.loss import _Loss
3
+ from torch.nn import functional as F
4
+
5
+ class SignalNoisePNormRatio(_Loss):
6
+ def __init__(
7
+ self,
8
+ p: float = 1.0,
9
+ scale_invariant: bool = False,
10
+ zero_mean: bool = False,
11
+ take_log: bool = True,
12
+ reduction: str = "mean",
13
+ EPS: float = 1e-3,
14
+ ) -> None:
15
+ assert reduction != "sum", NotImplementedError
16
+ super().__init__(reduction=reduction)
17
+ assert not zero_mean
18
+
19
+ self.p = p
20
+
21
+ self.EPS = EPS
22
+ self.take_log = take_log
23
+
24
+ self.scale_invariant = scale_invariant
25
+
26
+ def forward(
27
+ self,
28
+ est_target: torch.Tensor,
29
+ target: torch.Tensor
30
+ ) -> torch.Tensor:
31
+
32
+ target_ = target
33
+ if self.scale_invariant:
34
+ ndim = target.ndim
35
+ dot = torch.sum(est_target * torch.conj(target), dim=-1, keepdim=True)
36
+ s_target_energy = (
37
+ torch.sum(target * torch.conj(target), dim=-1, keepdim=True)
38
+ )
39
+
40
+ if ndim > 2:
41
+ dot = torch.sum(dot, dim=list(range(1, ndim)), keepdim=True)
42
+ s_target_energy = torch.sum(s_target_energy, dim=list(range(1, ndim)), keepdim=True)
43
+
44
+ target_scaler = (dot + 1e-8) / (s_target_energy + 1e-8)
45
+ target = target_ * target_scaler
46
+
47
+ if torch.is_complex(est_target):
48
+ est_target = torch.view_as_real(est_target)
49
+ target = torch.view_as_real(target)
50
+
51
+
52
+ batch_size = est_target.shape[0]
53
+ est_target = est_target.reshape(batch_size, -1)
54
+ target = target.reshape(batch_size, -1)
55
+ # target_ = target_.reshape(batch_size, -1)
56
+
57
+ if self.p == 1:
58
+ e_error = torch.abs(est_target-target).mean(dim=-1)
59
+ e_target = torch.abs(target).mean(dim=-1)
60
+ elif self.p == 2:
61
+ e_error = torch.square(est_target-target).mean(dim=-1)
62
+ e_target = torch.square(target).mean(dim=-1)
63
+ else:
64
+ raise NotImplementedError
65
+
66
+ if self.take_log:
67
+ loss = 10*(torch.log10(e_error + self.EPS) - torch.log10(e_target + self.EPS))
68
+ else:
69
+ loss = (e_error + self.EPS)/(e_target + self.EPS)
70
+
71
+ if self.reduction == "mean":
72
+ loss = loss.mean()
73
+ elif self.reduction == "sum":
74
+ loss = loss.sum()
75
+
76
+ return loss
77
+
78
+
79
+
80
+ class MultichannelSingleSrcNegSDR(_Loss):
81
+ def __init__(
82
+ self,
83
+ sdr_type: str,
84
+ p: float = 2.0,
85
+ zero_mean: bool = True,
86
+ take_log: bool = True,
87
+ reduction: str = "mean",
88
+ EPS: float = 1e-8,
89
+ ) -> None:
90
+ assert reduction != "sum", NotImplementedError
91
+ super().__init__(reduction=reduction)
92
+
93
+ assert sdr_type in ["snr", "sisdr", "sdsdr"]
94
+ self.sdr_type = sdr_type
95
+ self.zero_mean = zero_mean
96
+ self.take_log = take_log
97
+ self.EPS = 1e-8
98
+
99
+ self.p = p
100
+
101
+ def forward(
102
+ self,
103
+ est_target: torch.Tensor,
104
+ target: torch.Tensor
105
+ ) -> torch.Tensor:
106
+ if target.size() != est_target.size() or target.ndim != 3:
107
+ raise TypeError(
108
+ f"Inputs must be of shape [batch, time], got {target.size()} and {est_target.size()} instead"
109
+ )
110
+ # Step 1. Zero-mean norm
111
+ if self.zero_mean:
112
+ mean_source = torch.mean(target, dim=[1, 2], keepdim=True)
113
+ mean_estimate = torch.mean(est_target, dim=[1, 2], keepdim=True)
114
+ target = target - mean_source
115
+ est_target = est_target - mean_estimate
116
+ # Step 2. Pair-wise SI-SDR.
117
+ if self.sdr_type in ["sisdr", "sdsdr"]:
118
+ # [batch, 1]
119
+ dot = torch.sum(est_target * target, dim=[1, 2], keepdim=True)
120
+ # [batch, 1]
121
+ s_target_energy = (
122
+ torch.sum(target ** 2, dim=[1, 2], keepdim=True) + self.EPS
123
+ )
124
+ # [batch, time]
125
+ scaled_target = dot * target / s_target_energy
126
+ else:
127
+ # [batch, time]
128
+ scaled_target = target
129
+ if self.sdr_type in ["sdsdr", "snr"]:
130
+ e_noise = est_target - target
131
+ else:
132
+ e_noise = est_target - scaled_target
133
+ # [batch]
134
+
135
+ if self.p == 2.0:
136
+ losses = torch.sum(scaled_target ** 2, dim=[1, 2]) / (
137
+ torch.sum(e_noise ** 2, dim=[1, 2]) + self.EPS
138
+ )
139
+ else:
140
+ losses = torch.norm(scaled_target, p=self.p, dim=[1, 2]) / (
141
+ torch.linalg.vector_norm(e_noise, p=self.p, dim=[1, 2]) + self.EPS
142
+ )
143
+ if self.take_log:
144
+ losses = 10 * torch.log10(losses + self.EPS)
145
+ losses = losses.mean() if self.reduction == "mean" else losses
146
+ return -losses
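Note: a minimal smoke test of SignalNoisePNormRatio (the import path assumes the separator directory is on sys.path, matching the "from models.bandit..." imports used elsewhere in this commit; tensor shapes are illustrative, not taken from any training config):

import torch
from models.bandit.core.loss.snr import SignalNoisePNormRatio

# The estimate is the target plus a small perturbation, so the
# 10*log10(error energy / target energy) loss should come out clearly negative.
loss_fn = SignalNoisePNormRatio(p=1.0, take_log=True)
target = torch.randn(4, 2, 44100)                  # (batch, channels, samples)
estimate = target + 0.1 * torch.randn_like(target)
loss = loss_fn(estimate, target)                   # scalar, reduction="mean" by default
print(loss.item())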
separator/models/bandit/core/metrics/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .snr import (
2
+ ChunkMedianScaleInvariantSignalDistortionRatio,
3
+ ChunkMedianScaleInvariantSignalNoiseRatio,
4
+ ChunkMedianSignalDistortionRatio,
5
+ ChunkMedianSignalNoiseRatio,
6
+ SafeSignalDistortionRatio,
7
+ )
8
+
9
+ # from .mushra import EstimatedMushraScore
separator/models/bandit/core/metrics/_squim.py ADDED
@@ -0,0 +1,383 @@
1
+ from dataclasses import dataclass
2
+
3
+ from torchaudio._internal import load_state_dict_from_url
4
+
5
+ import math
6
+ from typing import List, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ def transform_wb_pesq_range(x: float) -> float:
14
+ """The metric defined by ITU-T P.862 is often called 'PESQ score', which is defined
15
+ for narrow-band signals and has a value range of [-0.5, 4.5] exactly. Here, we use the metric
16
+ defined by ITU-T P.862.2, commonly known as 'wide-band PESQ', which will be referred to as "PESQ score".
17
+
18
+ Args:
19
+ x (float): Narrow-band PESQ score.
20
+
21
+ Returns:
22
+ (float): Wide-band PESQ score.
23
+ """
24
+ return 0.999 + (4.999 - 0.999) / (1 + math.exp(-1.3669 * x + 3.8224))
25
+
26
+
27
+ PESQRange: Tuple[float, float] = (
28
+ 1.0, # P.862.2 uses a different input filter than P.862, and the lower bound of
29
+ # the raw score is not -0.5 anymore. It's hard to figure out the true lower bound.
30
+ # We are using 1.0 as a reasonable approximation.
31
+ transform_wb_pesq_range(4.5),
32
+ )
33
+
34
+
35
+ class RangeSigmoid(nn.Module):
36
+ def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None:
37
+ super(RangeSigmoid, self).__init__()
38
+ assert isinstance(val_range, tuple) and len(val_range) == 2
39
+ self.val_range: Tuple[float, float] = val_range
40
+ self.sigmoid: nn.modules.Module = nn.Sigmoid()
41
+
42
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
43
+ out = self.sigmoid(x) * (self.val_range[1] - self.val_range[0]) + self.val_range[0]
44
+ return out
45
+
46
+
47
+ class Encoder(nn.Module):
48
+ """Encoder module that transform 1D waveform to 2D representations.
49
+
50
+ Args:
51
+ feat_dim (int, optional): The feature dimension after Encoder module. (Default: 512)
52
+ win_len (int, optional): kernel size in the Conv1D layer. (Default: 32)
53
+ """
54
+
55
+ def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None:
56
+ super(Encoder, self).__init__()
57
+
58
+ self.conv1d = nn.Conv1d(1, feat_dim, win_len, stride=win_len // 2, bias=False)
59
+
60
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
61
+ """Apply waveforms to convolutional layer and ReLU layer.
62
+
63
+ Args:
64
+ x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.
65
+
66
+ Returns:
67
+ (torch.Tensor): Feature Tensor with dimensions `(batch, channel, frame)`.
68
+ """
69
+ out = x.unsqueeze(dim=1)
70
+ out = F.relu(self.conv1d(out))
71
+ return out
72
+
73
+
74
+ class SingleRNN(nn.Module):
75
+ def __init__(self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0) -> None:
76
+ super(SingleRNN, self).__init__()
77
+
78
+ self.rnn_type = rnn_type
79
+ self.input_size = input_size
80
+ self.hidden_size = hidden_size
81
+
82
+ self.rnn: nn.modules.Module = getattr(nn, rnn_type)(
83
+ input_size,
84
+ hidden_size,
85
+ 1,
86
+ dropout=dropout,
87
+ batch_first=True,
88
+ bidirectional=True,
89
+ )
90
+
91
+ self.proj = nn.Linear(hidden_size * 2, input_size)
92
+
93
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
94
+ # input shape: batch, seq, dim
95
+ out, _ = self.rnn(x)
96
+ out = self.proj(out)
97
+ return out
98
+
99
+
100
+ class DPRNN(nn.Module):
101
+ """*Dual-path recurrent neural networks (DPRNN)* :cite:`luo2020dual`.
102
+
103
+ Args:
104
+ feat_dim (int, optional): The feature dimension after Encoder module. (Default: 64)
105
+ hidden_dim (int, optional): Hidden dimension in the RNN layer of DPRNN. (Default: 128)
106
+ num_blocks (int, optional): Number of DPRNN layers. (Default: 6)
107
+ rnn_type (str, optional): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"]. (Default: "LSTM")
108
+ d_model (int, optional): The number of expected features in the input. (Default: 256)
109
+ chunk_size (int, optional): Chunk size of input for DPRNN. (Default: 100)
110
+ chunk_stride (int, optional): Stride of chunk input for DPRNN. (Default: 50)
111
+ """
112
+
113
+ def __init__(
114
+ self,
115
+ feat_dim: int = 64,
116
+ hidden_dim: int = 128,
117
+ num_blocks: int = 6,
118
+ rnn_type: str = "LSTM",
119
+ d_model: int = 256,
120
+ chunk_size: int = 100,
121
+ chunk_stride: int = 50,
122
+ ) -> None:
123
+ super(DPRNN, self).__init__()
124
+
125
+ self.num_blocks = num_blocks
126
+
127
+ self.row_rnn = nn.ModuleList([])
128
+ self.col_rnn = nn.ModuleList([])
129
+ self.row_norm = nn.ModuleList([])
130
+ self.col_norm = nn.ModuleList([])
131
+ for _ in range(num_blocks):
132
+ self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
133
+ self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
134
+ self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
135
+ self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
136
+ self.conv = nn.Sequential(
137
+ nn.Conv2d(feat_dim, d_model, 1),
138
+ nn.PReLU(),
139
+ )
140
+ self.chunk_size = chunk_size
141
+ self.chunk_stride = chunk_stride
142
+
143
+ def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
144
+ # input shape: (B, N, T)
145
+ seq_len = x.shape[-1]
146
+
147
+ rest = self.chunk_size - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size
148
+ out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride])
149
+
150
+ return out, rest
151
+
152
+ def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
153
+ out, rest = self.pad_chunk(x)
154
+ batch_size, feat_dim, seq_len = out.shape
155
+
156
+ segments1 = out[:, :, : -self.chunk_stride].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
157
+ segments2 = out[:, :, self.chunk_stride :].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
158
+ out = torch.cat([segments1, segments2], dim=3)
159
+ out = out.view(batch_size, feat_dim, -1, self.chunk_size).transpose(2, 3).contiguous()
160
+
161
+ return out, rest
162
+
163
+ def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor:
164
+ batch_size, dim, _, _ = x.shape
165
+ out = x.transpose(2, 3).contiguous().view(batch_size, dim, -1, self.chunk_size * 2)
166
+ out1 = out[:, :, :, : self.chunk_size].contiguous().view(batch_size, dim, -1)[:, :, self.chunk_stride :]
167
+ out2 = out[:, :, :, self.chunk_size :].contiguous().view(batch_size, dim, -1)[:, :, : -self.chunk_stride]
168
+ out = out1 + out2
169
+ if rest > 0:
170
+ out = out[:, :, :-rest]
171
+ out = out.contiguous()
172
+ return out
173
+
174
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
175
+ x, rest = self.chunking(x)
176
+ batch_size, _, dim1, dim2 = x.shape
177
+ out = x
178
+ for row_rnn, row_norm, col_rnn, col_norm in zip(self.row_rnn, self.row_norm, self.col_rnn, self.col_norm):
179
+ row_in = out.permute(0, 3, 2, 1).contiguous().view(batch_size * dim2, dim1, -1).contiguous()
180
+ row_out = row_rnn(row_in)
181
+ row_out = row_out.view(batch_size, dim2, dim1, -1).permute(0, 3, 2, 1).contiguous()
182
+ row_out = row_norm(row_out)
183
+ out = out + row_out
184
+
185
+ col_in = out.permute(0, 2, 3, 1).contiguous().view(batch_size * dim1, dim2, -1).contiguous()
186
+ col_out = col_rnn(col_in)
187
+ col_out = col_out.view(batch_size, dim1, dim2, -1).permute(0, 3, 1, 2).contiguous()
188
+ col_out = col_norm(col_out)
189
+ out = out + col_out
190
+ out = self.conv(out)
191
+ out = self.merging(out, rest)
192
+ out = out.transpose(1, 2).contiguous()
193
+ return out
194
+
195
+
196
+ class AutoPool(nn.Module):
197
+ def __init__(self, pool_dim: int = 1) -> None:
198
+ super(AutoPool, self).__init__()
199
+ self.pool_dim: int = pool_dim
200
+ self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim)
201
+ self.register_parameter("alpha", nn.Parameter(torch.ones(1)))
202
+
203
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
204
+ weight = self.softmax(torch.mul(x, self.alpha))
205
+ out = torch.sum(torch.mul(x, weight), dim=self.pool_dim)
206
+ return out
207
+
208
+
209
+ class SquimObjective(nn.Module):
210
+ """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **objective** metric scores
211
+ for speech enhancement (e.g., STOI, PESQ, and SI-SDR).
212
+
213
+ Args:
214
+ encoder (torch.nn.Module): Encoder module to transform 1D waveform to 2D feature representation.
215
+ dprnn (torch.nn.Module): DPRNN module to model sequential feature.
216
+ branches (torch.nn.ModuleList): Transformer branches in which each branch estimates one objective metric score.
217
+ """
218
+
219
+ def __init__(
220
+ self,
221
+ encoder: nn.Module,
222
+ dprnn: nn.Module,
223
+ branches: nn.ModuleList,
224
+ ):
225
+ super(SquimObjective, self).__init__()
226
+ self.encoder = encoder
227
+ self.dprnn = dprnn
228
+ self.branches = branches
229
+
230
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
231
+ """
232
+ Args:
233
+ x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.
234
+
235
+ Returns:
236
+ List(torch.Tensor): List of score Tensors. Each Tensor has dimension `(batch,)`.
237
+ """
238
+ if x.ndim != 2:
239
+ raise ValueError(f"The input must be a 2D Tensor. Found dimension {x.ndim}.")
240
+ x = x / (torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20)
241
+ out = self.encoder(x)
242
+ out = self.dprnn(out)
243
+ scores = []
244
+ for branch in self.branches:
245
+ scores.append(branch(out).squeeze(dim=1))
246
+ return scores
247
+
248
+
249
+ def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
250
+ """Create branch module after DPRNN model for predicting metric score.
251
+
252
+ Args:
253
+ d_model (int): The number of expected features in the input.
254
+ nhead (int): Number of heads in the multi-head attention model.
255
+ metric (str): The metric name to predict.
256
+
257
+ Returns:
258
+ (nn.Module): Returned module to predict corresponding metric score.
259
+ """
260
+ layer1 = nn.TransformerEncoderLayer(d_model, nhead, d_model * 4, dropout=0.0, batch_first=True)
261
+ layer2 = AutoPool()
262
+ if metric == "stoi":
263
+ layer3 = nn.Sequential(
264
+ nn.Linear(d_model, d_model),
265
+ nn.PReLU(),
266
+ nn.Linear(d_model, 1),
267
+ RangeSigmoid(),
268
+ )
269
+ elif metric == "pesq":
270
+ layer3 = nn.Sequential(
271
+ nn.Linear(d_model, d_model),
272
+ nn.PReLU(),
273
+ nn.Linear(d_model, 1),
274
+ RangeSigmoid(val_range=PESQRange),
275
+ )
276
+ else:
277
+ layer3: nn.modules.Module = nn.Sequential(nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1))
278
+ return nn.Sequential(layer1, layer2, layer3)
279
+
280
+
281
+ def squim_objective_model(
282
+ feat_dim: int,
283
+ win_len: int,
284
+ d_model: int,
285
+ nhead: int,
286
+ hidden_dim: int,
287
+ num_blocks: int,
288
+ rnn_type: str,
289
+ chunk_size: int,
290
+ chunk_stride: Optional[int] = None,
291
+ ) -> SquimObjective:
292
+ """Build a custome :class:`torchaudio.prototype.models.SquimObjective` model.
293
+
294
+ Args:
295
+ feat_dim (int): The feature dimension after Encoder module.
296
+ win_len (int): Kernel size in the Encoder module.
297
+ d_model (int): The number of expected features in the input.
298
+ nhead (int): Number of heads in the multi-head attention model.
299
+ hidden_dim (int): Hidden dimension in the RNN layer of DPRNN.
300
+ num_blocks (int): Number of DPRNN layers.
301
+ rnn_type (str): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"].
302
+ chunk_size (int): Chunk size of input for DPRNN.
303
+ chunk_stride (int or None, optional): Stride of chunk input for DPRNN.
304
+ """
305
+ if chunk_stride is None:
306
+ chunk_stride = chunk_size // 2
307
+ encoder = Encoder(feat_dim, win_len)
308
+ dprnn = DPRNN(feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride)
309
+ branches = nn.ModuleList(
310
+ [
311
+ _create_branch(d_model, nhead, "stoi"),
312
+ _create_branch(d_model, nhead, "pesq"),
313
+ _create_branch(d_model, nhead, "sisdr"),
314
+ ]
315
+ )
316
+ return SquimObjective(encoder, dprnn, branches)
317
+
318
+
319
+ def squim_objective_base() -> SquimObjective:
320
+ """Build :class:`torchaudio.prototype.models.SquimObjective` model with default arguments."""
321
+ return squim_objective_model(
322
+ feat_dim=256,
323
+ win_len=64,
324
+ d_model=256,
325
+ nhead=4,
326
+ hidden_dim=256,
327
+ num_blocks=2,
328
+ rnn_type="LSTM",
329
+ chunk_size=71,
330
+ )
331
+
332
+ @dataclass
333
+ class SquimObjectiveBundle:
334
+
335
+ _path: str
336
+ _sample_rate: float
337
+
338
+ def _get_state_dict(self, dl_kwargs):
339
+ url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
340
+ dl_kwargs = {} if dl_kwargs is None else dl_kwargs
341
+ state_dict = load_state_dict_from_url(url, **dl_kwargs)
342
+ return state_dict
343
+
344
+ def get_model(self, *, dl_kwargs=None) -> SquimObjective:
345
+ """Construct the SquimObjective model, and load the pretrained weight.
346
+
347
+ The weight file is downloaded from the internet and cached with
348
+ :func:`torch.hub.load_state_dict_from_url`
349
+
350
+ Args:
351
+ dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
352
+
353
+ Returns:
354
+ Variation of :py:class:`~torchaudio.models.SquimObjective`.
355
+ """
356
+ model = squim_objective_base()
357
+ model.load_state_dict(self._get_state_dict(dl_kwargs))
358
+ model.eval()
359
+ return model
360
+
361
+ @property
362
+ def sample_rate(self):
363
+ """Sample rate of the audio that the model is trained on.
364
+
365
+ :type: float
366
+ """
367
+ return self._sample_rate
368
+
369
+
370
+ SQUIM_OBJECTIVE = SquimObjectiveBundle(
371
+ "squim_objective_dns2020.pth",
372
+ _sample_rate=16000,
373
+ )
374
+ SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in
375
+ :cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.
376
+
377
+ The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
378
+ The weights are under `Creative Commons Attribution 4.0 International License
379
+ <https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE>`__.
380
+
381
+ Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
382
+ """
383
+
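As a rough usage sketch of this vendored SQUIM objective bundle (assuming the pretrained weights at the URL referenced in SquimObjectiveBundle are reachable and the input is mono 16 kHz audio shaped (batch, time)):

import torch
from models.bandit.core.metrics._squim import SQUIM_OBJECTIVE

model = SQUIM_OBJECTIVE.get_model()                     # downloads and caches the weights
waveform = torch.randn(1, SQUIM_OBJECTIVE.sample_rate)  # one second of dummy 16 kHz audio
with torch.no_grad():
    stoi, pesq, si_sdr = model(waveform)                # branch order set in squim_objective_model()
print(stoi.item(), pesq.item(), si_sdr.item())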
separator/models/bandit/core/metrics/snr.py ADDED
@@ -0,0 +1,150 @@
1
+ from typing import Any, Callable
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torchmetrics as tm
6
+ from torch._C import _LinAlgError
7
+ from torchmetrics import functional as tmF
8
+
9
+
10
+ class SafeSignalDistortionRatio(tm.SignalDistortionRatio):
11
+ def __init__(self, **kwargs) -> None:
12
+ super().__init__(**kwargs)
13
+
14
+ def update(self, *args, **kwargs) -> Any:
15
+ try:
16
+ super().update(*args, **kwargs)
17
+ except Exception:
18
+ pass
19
+
20
+ def compute(self) -> Any:
21
+ if self.total == 0:
22
+ return torch.tensor(torch.nan)
23
+ return super().compute()
24
+
25
+
26
+ class BaseChunkMedianSignalRatio(tm.Metric):
27
+ def __init__(
28
+ self,
29
+ func: Callable,
30
+ window_size: int,
31
+ hop_size: int = None,
32
+ zero_mean: bool = False,
33
+ ) -> None:
34
+ super().__init__()
35
+
36
+ # self.zero_mean = zero_mean
37
+ self.func = func
38
+ self.window_size = window_size
39
+ if hop_size is None:
40
+ hop_size = window_size
41
+ self.hop_size = hop_size
42
+
43
+ self.add_state(
44
+ "sum_snr",
45
+ default=torch.tensor(0.0),
46
+ dist_reduce_fx="sum"
47
+ )
48
+ self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
49
+
50
+ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
51
+
52
+ n_samples = target.shape[-1]
53
+
54
+ n_chunks = int(
55
+ np.ceil((n_samples - self.window_size) / self.hop_size) + 1
56
+ )
57
+
58
+ snr_chunk = []
59
+
60
+ for i in range(n_chunks):
61
+ start = i * self.hop_size
62
+
63
+ if n_samples - start < self.window_size:
64
+ continue
65
+
66
+ end = start + self.window_size
67
+
68
+ try:
69
+ chunk_snr = self.func(
70
+ preds[..., start:end],
71
+ target[..., start:end]
72
+ )
73
+
74
+ # print(preds.shape, chunk_snr.shape)
75
+
76
+ if torch.all(torch.isfinite(chunk_snr)):
77
+ snr_chunk.append(chunk_snr)
78
+ except _LinAlgError:
79
+ pass
80
+
81
+ snr_chunk = torch.stack(snr_chunk, dim=-1)
82
+ snr_batch, _ = torch.nanmedian(snr_chunk, dim=-1)
83
+
84
+ self.sum_snr += snr_batch.sum()
85
+ self.total += snr_batch.numel()
86
+
87
+ def compute(self) -> Any:
88
+ return self.sum_snr / self.total
89
+
90
+
91
+ class ChunkMedianSignalNoiseRatio(BaseChunkMedianSignalRatio):
92
+ def __init__(
93
+ self,
94
+ window_size: int,
95
+ hop_size: int = None,
96
+ zero_mean: bool = False
97
+ ) -> None:
98
+ super().__init__(
99
+ func=tmF.signal_noise_ratio,
100
+ window_size=window_size,
101
+ hop_size=hop_size,
102
+ zero_mean=zero_mean,
103
+ )
104
+
105
+
106
+ class ChunkMedianScaleInvariantSignalNoiseRatio(BaseChunkMedianSignalRatio):
107
+ def __init__(
108
+ self,
109
+ window_size: int,
110
+ hop_size: int = None,
111
+ zero_mean: bool = False
112
+ ) -> None:
113
+ super().__init__(
114
+ func=tmF.scale_invariant_signal_noise_ratio,
115
+ window_size=window_size,
116
+ hop_size=hop_size,
117
+ zero_mean=zero_mean,
118
+ )
119
+
120
+
121
+ class ChunkMedianSignalDistortionRatio(BaseChunkMedianSignalRatio):
122
+ def __init__(
123
+ self,
124
+ window_size: int,
125
+ hop_size: int = None,
126
+ zero_mean: bool = False
127
+ ) -> None:
128
+ super().__init__(
129
+ func=tmF.signal_distortion_ratio,
130
+ window_size=window_size,
131
+ hop_size=hop_size,
132
+ zero_mean=zero_mean,
133
+ )
134
+
135
+
136
+ class ChunkMedianScaleInvariantSignalDistortionRatio(
137
+ BaseChunkMedianSignalRatio
138
+ ):
139
+ def __init__(
140
+ self,
141
+ window_size: int,
142
+ hop_size: int = None,
143
+ zero_mean: bool = False
144
+ ) -> None:
145
+ super().__init__(
146
+ func=tmF.scale_invariant_signal_distortion_ratio,
147
+ window_size=window_size,
148
+ hop_size=hop_size,
149
+ zero_mean=zero_mean,
150
+ )
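These metrics follow the standard torchmetrics update/compute cycle; a sketch on dummy stereo stems, with window and hop sizes (in samples) chosen arbitrarily for illustration:

import torch
from models.bandit.core.metrics.snr import ChunkMedianSignalNoiseRatio

metric = ChunkMedianSignalNoiseRatio(window_size=44100, hop_size=22050)
target = torch.randn(2, 2, 4 * 44100)             # (batch, channels, samples)
preds = target + 0.05 * torch.randn_like(target)
metric.update(preds, target)                      # per-example median SNR over chunks
print(metric.compute())                           # mean over everything accumulated so far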
separator/models/bandit/core/model/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .bsrnn.wrapper import (
2
+ MultiMaskMultiSourceBandSplitRNNSimple,
3
+ )
separator/models/bandit/core/model/_spectral.py ADDED
@@ -0,0 +1,58 @@
1
+ from typing import Dict, Optional
2
+
3
+ import torch
4
+ import torchaudio as ta
5
+ from torch import nn
6
+
7
+
8
+ class _SpectralComponent(nn.Module):
9
+ def __init__(
10
+ self,
11
+ n_fft: int = 2048,
12
+ win_length: Optional[int] = 2048,
13
+ hop_length: int = 512,
14
+ window_fn: str = "hann_window",
15
+ wkwargs: Optional[Dict] = None,
16
+ power: Optional[int] = None,
17
+ center: bool = True,
18
+ normalized: bool = True,
19
+ pad_mode: str = "constant",
20
+ onesided: bool = True,
21
+ **kwargs,
22
+ ) -> None:
23
+ super().__init__()
24
+
25
+ assert power is None
26
+
27
+ window_fn = torch.__dict__[window_fn]
28
+
29
+ self.stft = (
30
+ ta.transforms.Spectrogram(
31
+ n_fft=n_fft,
32
+ win_length=win_length,
33
+ hop_length=hop_length,
34
+ pad_mode=pad_mode,
35
+ pad=0,
36
+ window_fn=window_fn,
37
+ wkwargs=wkwargs,
38
+ power=power,
39
+ normalized=normalized,
40
+ center=center,
41
+ onesided=onesided,
42
+ )
43
+ )
44
+
45
+ self.istft = (
46
+ ta.transforms.InverseSpectrogram(
47
+ n_fft=n_fft,
48
+ win_length=win_length,
49
+ hop_length=hop_length,
50
+ pad_mode=pad_mode,
51
+ pad=0,
52
+ window_fn=window_fn,
53
+ wkwargs=wkwargs,
54
+ normalized=normalized,
55
+ center=center,
56
+ onesided=onesided,
57
+ )
58
+ )
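Because power=None keeps the spectrogram complex, the istft transform can invert the stft transform up to numerical error; a small round-trip sketch under the defaults above:

import torch
from models.bandit.core.model._spectral import _SpectralComponent

spec = _SpectralComponent(n_fft=2048, hop_length=512)
x = torch.randn(2, 2, 44100)            # (batch, channels, samples)
X = spec.stft(x)                        # complex (batch, channels, n_freq, n_time)
y = spec.istft(X, x.shape[-1])          # inverse, trimmed back to the original length
print(torch.allclose(x, y, atol=1e-4))  # expected: True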
separator/models/bandit/core/model/bsrnn/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ from abc import ABC
2
+ from typing import Iterable, Mapping, Union
3
+
4
+ from torch import nn
5
+
6
+ from models.bandit.core.model.bsrnn.bandsplit import BandSplitModule
7
+ from models.bandit.core.model.bsrnn.tfmodel import (
8
+ SeqBandModellingModule,
9
+ TransformerTimeFreqModule,
10
+ )
11
+
12
+
13
+ class BandsplitCoreBase(nn.Module, ABC):
14
+ band_split: nn.Module
15
+ tf_model: nn.Module
16
+ mask_estim: Union[nn.Module, Mapping[str, nn.Module], Iterable[nn.Module]]
17
+
18
+ def __init__(self) -> None:
19
+ super().__init__()
20
+
21
+ @staticmethod
22
+ def mask(x, m):
23
+ return x * m
separator/models/bandit/core/model/bsrnn/bandsplit.py ADDED
@@ -0,0 +1,139 @@
1
+ from typing import List, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from models.bandit.core.model.bsrnn.utils import (
7
+ band_widths_from_specs,
8
+ check_no_gap,
9
+ check_no_overlap,
10
+ check_nonzero_bandwidth,
11
+ )
12
+
13
+
14
+ class NormFC(nn.Module):
15
+ def __init__(
16
+ self,
17
+ emb_dim: int,
18
+ bandwidth: int,
19
+ in_channel: int,
20
+ normalize_channel_independently: bool = False,
21
+ treat_channel_as_feature: bool = True,
22
+ ) -> None:
23
+ super().__init__()
24
+
25
+ self.treat_channel_as_feature = treat_channel_as_feature
26
+
27
+ if normalize_channel_independently:
28
+ raise NotImplementedError
29
+
30
+ reim = 2
31
+
32
+ self.norm = nn.LayerNorm(in_channel * bandwidth * reim)
33
+
34
+ fc_in = bandwidth * reim
35
+
36
+ if treat_channel_as_feature:
37
+ fc_in *= in_channel
38
+ else:
39
+ assert emb_dim % in_channel == 0
40
+ emb_dim = emb_dim // in_channel
41
+
42
+ self.fc = nn.Linear(fc_in, emb_dim)
43
+
44
+ def forward(self, xb):
45
+ # xb = (batch, n_time, in_chan, reim * band_width)
46
+
47
+ batch, n_time, in_chan, ribw = xb.shape
48
+ xb = self.norm(xb.reshape(batch, n_time, in_chan * ribw))
49
+ # (batch, n_time, in_chan * reim * band_width)
50
+
51
+ if not self.treat_channel_as_feature:
52
+ xb = xb.reshape(batch, n_time, in_chan, ribw)
53
+ # (batch, n_time, in_chan, reim * band_width)
54
+
55
+ zb = self.fc(xb)
56
+ # (batch, n_time, emb_dim)
57
+ # OR
58
+ # (batch, n_time, in_chan, emb_dim_per_chan)
59
+
60
+ if not self.treat_channel_as_feature:
61
+ batch, n_time, in_chan, emb_dim_per_chan = zb.shape
62
+ # (batch, n_time, in_chan, emb_dim_per_chan)
63
+ zb = zb.reshape((batch, n_time, in_chan * emb_dim_per_chan))
64
+
65
+ return zb # (batch, n_time, emb_dim)
66
+
67
+
68
+ class BandSplitModule(nn.Module):
69
+ def __init__(
70
+ self,
71
+ band_specs: List[Tuple[float, float]],
72
+ emb_dim: int,
73
+ in_channel: int,
74
+ require_no_overlap: bool = False,
75
+ require_no_gap: bool = True,
76
+ normalize_channel_independently: bool = False,
77
+ treat_channel_as_feature: bool = True,
78
+ ) -> None:
79
+ super().__init__()
80
+
81
+ check_nonzero_bandwidth(band_specs)
82
+
83
+ if require_no_gap:
84
+ check_no_gap(band_specs)
85
+
86
+ if require_no_overlap:
87
+ check_no_overlap(band_specs)
88
+
89
+ self.band_specs = band_specs
90
+ # list of [fstart, fend) in index.
91
+ # Note that fend is exclusive.
92
+ self.band_widths = band_widths_from_specs(band_specs)
93
+ self.n_bands = len(band_specs)
94
+ self.emb_dim = emb_dim
95
+
96
+ self.norm_fc_modules = nn.ModuleList(
97
+ [ # type: ignore
98
+ (
99
+ NormFC(
100
+ emb_dim=emb_dim,
101
+ bandwidth=bw,
102
+ in_channel=in_channel,
103
+ normalize_channel_independently=normalize_channel_independently,
104
+ treat_channel_as_feature=treat_channel_as_feature,
105
+ )
106
+ )
107
+ for bw in self.band_widths
108
+ ]
109
+ )
110
+
111
+ def forward(self, x: torch.Tensor):
112
+ # x = complex spectrogram (batch, in_chan, n_freq, n_time)
113
+
114
+ batch, in_chan, _, n_time = x.shape
115
+
116
+ z = torch.zeros(
117
+ size=(batch, self.n_bands, n_time, self.emb_dim),
118
+ device=x.device
119
+ )
120
+
121
+ xr = torch.view_as_real(x) # batch, in_chan, n_freq, n_time, 2
122
+ xr = torch.permute(
123
+ xr,
124
+ (0, 3, 1, 4, 2)
125
+ ) # batch, n_time, in_chan, 2, n_freq
126
+ batch, n_time, in_chan, reim, band_width = xr.shape
127
+ for i, nfm in enumerate(self.norm_fc_modules):
128
+ # print(f"bandsplit/band{i:02d}")
129
+ fstart, fend = self.band_specs[i]
130
+ xb = xr[..., fstart:fend]
131
+ # (batch, n_time, in_chan, reim, band_width)
132
+ xb = torch.reshape(xb, (batch, n_time, in_chan, -1))
133
+ # (batch, n_time, in_chan, reim * band_width)
134
+ # z.append(nfm(xb)) # (batch, n_time, emb_dim)
135
+ z[:, i, :, :] = nfm(xb.contiguous())
136
+
137
+ # z = torch.stack(z, dim=1)
138
+
139
+ return z
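A toy instantiation to illustrate the expected shapes; the two band specs below are made-up [fstart, fend) bin indices covering a 5-bin spectrogram, not a real band layout:

import torch
from models.bandit.core.model.bsrnn.bandsplit import BandSplitModule

module = BandSplitModule(
    band_specs=[(0, 2), (2, 5)],   # contiguous, non-overlapping, no gaps
    emb_dim=16,
    in_channel=2,
)
x = torch.randn(3, 2, 5, 10, dtype=torch.complex64)  # (batch, in_chan, n_freq, n_time)
z = module(x)
print(z.shape)  # torch.Size([3, 2, 10, 16]) -> (batch, n_bands, n_time, emb_dim)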
separator/models/bandit/core/model/bsrnn/core.py ADDED
@@ -0,0 +1,661 @@
1
+ from typing import Dict, List, Optional, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from models.bandit.core.model.bsrnn import BandsplitCoreBase
8
+ from models.bandit.core.model.bsrnn.bandsplit import BandSplitModule
9
+ from models.bandit.core.model.bsrnn.maskestim import (
10
+ MaskEstimationModule,
11
+ OverlappingMaskEstimationModule
12
+ )
13
+ from models.bandit.core.model.bsrnn.tfmodel import (
14
+ ConvolutionalTimeFreqModule,
15
+ SeqBandModellingModule,
16
+ TransformerTimeFreqModule
17
+ )
18
+
19
+
20
+ class MultiMaskBandSplitCoreBase(BandsplitCoreBase):
21
+ def __init__(self) -> None:
22
+ super().__init__()
23
+
24
+ def forward(self, x, cond=None, compute_residual: bool = True):
25
+ # x = complex spectrogram (batch, in_chan, n_freq, n_time)
26
+ # print(x.shape)
27
+ batch, in_chan, n_freq, n_time = x.shape
28
+ x = torch.reshape(x, (-1, 1, n_freq, n_time))
29
+
30
+ z = self.band_split(x) # (batch, emb_dim, n_band, n_time)
31
+
32
+ # if torch.any(torch.isnan(z)):
33
+ # raise ValueError("z nan")
34
+
35
+ # print(z)
36
+ q = self.tf_model(z) # (batch, emb_dim, n_band, n_time)
37
+ # print(q)
38
+
39
+
40
+ # if torch.any(torch.isnan(q)):
41
+ # raise ValueError("q nan")
42
+
43
+ out = {}
44
+
45
+ for stem, mem in self.mask_estim.items():
46
+ m = mem(q, cond=cond)
47
+
48
+ # if torch.any(torch.isnan(m)):
49
+ # raise ValueError("m nan", stem)
50
+
51
+ s = self.mask(x, m)
52
+ s = torch.reshape(s, (batch, in_chan, n_freq, n_time))
53
+ out[stem] = s
54
+
55
+ return {"spectrogram": out}
56
+
57
+
58
+
59
+ def instantiate_mask_estim(self,
60
+ in_channel: int,
61
+ stems: List[str],
62
+ band_specs: List[Tuple[float, float]],
63
+ emb_dim: int,
64
+ mlp_dim: int,
65
+ cond_dim: int,
66
+ hidden_activation: str,
67
+
68
+ hidden_activation_kwargs: Optional[Dict] = None,
69
+ complex_mask: bool = True,
70
+ overlapping_band: bool = False,
71
+ freq_weights: Optional[List[torch.Tensor]] = None,
72
+ n_freq: Optional[int] = None,
73
+ use_freq_weights: bool = True,
74
+ mult_add_mask: bool = False
75
+ ):
76
+ if hidden_activation_kwargs is None:
77
+ hidden_activation_kwargs = {}
78
+
79
+ if "mne:+" in stems:
80
+ stems = [s for s in stems if s != "mne:+"]
81
+
82
+ if overlapping_band:
83
+ assert freq_weights is not None
84
+ assert n_freq is not None
85
+
86
+ if mult_add_mask:
87
+
88
+ self.mask_estim = nn.ModuleDict(
89
+ {
90
+ stem: MultAddMaskEstimationModule(
91
+ band_specs=band_specs,
92
+ freq_weights=freq_weights,
93
+ n_freq=n_freq,
94
+ emb_dim=emb_dim,
95
+ mlp_dim=mlp_dim,
96
+ in_channel=in_channel,
97
+ hidden_activation=hidden_activation,
98
+ hidden_activation_kwargs=hidden_activation_kwargs,
99
+ complex_mask=complex_mask,
100
+ use_freq_weights=use_freq_weights,
101
+ )
102
+ for stem in stems
103
+ }
104
+ )
105
+ else:
106
+ self.mask_estim = nn.ModuleDict(
107
+ {
108
+ stem: OverlappingMaskEstimationModule(
109
+ band_specs=band_specs,
110
+ freq_weights=freq_weights,
111
+ n_freq=n_freq,
112
+ emb_dim=emb_dim,
113
+ mlp_dim=mlp_dim,
114
+ in_channel=in_channel,
115
+ hidden_activation=hidden_activation,
116
+ hidden_activation_kwargs=hidden_activation_kwargs,
117
+ complex_mask=complex_mask,
118
+ use_freq_weights=use_freq_weights,
119
+ )
120
+ for stem in stems
121
+ }
122
+ )
123
+ else:
124
+ self.mask_estim = nn.ModuleDict(
125
+ {
126
+ stem: MaskEstimationModule(
127
+ band_specs=band_specs,
128
+ emb_dim=emb_dim,
129
+ mlp_dim=mlp_dim,
130
+ cond_dim=cond_dim,
131
+ in_channel=in_channel,
132
+ hidden_activation=hidden_activation,
133
+ hidden_activation_kwargs=hidden_activation_kwargs,
134
+ complex_mask=complex_mask,
135
+ )
136
+ for stem in stems
137
+ }
138
+ )
139
+
140
+ def instantiate_bandsplit(self,
141
+ in_channel: int,
142
+ band_specs: List[Tuple[float, float]],
143
+ require_no_overlap: bool = False,
144
+ require_no_gap: bool = True,
145
+ normalize_channel_independently: bool = False,
146
+ treat_channel_as_feature: bool = True,
147
+ emb_dim: int = 128
148
+ ):
149
+ self.band_split = BandSplitModule(
150
+ in_channel=in_channel,
151
+ band_specs=band_specs,
152
+ require_no_overlap=require_no_overlap,
153
+ require_no_gap=require_no_gap,
154
+ normalize_channel_independently=normalize_channel_independently,
155
+ treat_channel_as_feature=treat_channel_as_feature,
156
+ emb_dim=emb_dim,
157
+ )
158
+
159
+ class SingleMaskBandsplitCoreBase(BandsplitCoreBase):
160
+ def __init__(self, **kwargs) -> None:
161
+ super().__init__()
162
+
163
+ def forward(self, x):
164
+ # x = complex spectrogram (batch, in_chan, n_freq, n_time)
165
+ z = self.band_split(x) # (batch, emb_dim, n_band, n_time)
166
+ q = self.tf_model(z) # (batch, emb_dim, n_band, n_time)
167
+ m = self.mask_estim(q) # (batch, in_chan, n_freq, n_time)
168
+
169
+ s = self.mask(x, m)
170
+
171
+ return s
172
+
173
+
174
+ class SingleMaskBandsplitCoreRNN(
175
+ SingleMaskBandsplitCoreBase,
176
+ ):
177
+ def __init__(
178
+ self,
179
+ in_channel: int,
180
+ band_specs: List[Tuple[float, float]],
181
+ require_no_overlap: bool = False,
182
+ require_no_gap: bool = True,
183
+ normalize_channel_independently: bool = False,
184
+ treat_channel_as_feature: bool = True,
185
+ n_sqm_modules: int = 12,
186
+ emb_dim: int = 128,
187
+ rnn_dim: int = 256,
188
+ bidirectional: bool = True,
189
+ rnn_type: str = "LSTM",
190
+ mlp_dim: int = 512,
191
+ hidden_activation: str = "Tanh",
192
+ hidden_activation_kwargs: Optional[Dict] = None,
193
+ complex_mask: bool = True,
194
+ ) -> None:
195
+ super().__init__()
196
+ self.band_split = (BandSplitModule(
197
+ in_channel=in_channel,
198
+ band_specs=band_specs,
199
+ require_no_overlap=require_no_overlap,
200
+ require_no_gap=require_no_gap,
201
+ normalize_channel_independently=normalize_channel_independently,
202
+ treat_channel_as_feature=treat_channel_as_feature,
203
+ emb_dim=emb_dim,
204
+ ))
205
+ self.tf_model = (SeqBandModellingModule(
206
+ n_modules=n_sqm_modules,
207
+ emb_dim=emb_dim,
208
+ rnn_dim=rnn_dim,
209
+ bidirectional=bidirectional,
210
+ rnn_type=rnn_type,
211
+ ))
212
+ self.mask_estim = (MaskEstimationModule(
213
+ in_channel=in_channel,
214
+ band_specs=band_specs,
215
+ emb_dim=emb_dim,
216
+ mlp_dim=mlp_dim,
217
+ hidden_activation=hidden_activation,
218
+ hidden_activation_kwargs=hidden_activation_kwargs,
219
+ complex_mask=complex_mask,
220
+ ))
221
+
222
+
223
+ class SingleMaskBandsplitCoreTransformer(
224
+ SingleMaskBandsplitCoreBase,
225
+ ):
226
+ def __init__(
227
+ self,
228
+ in_channel: int,
229
+ band_specs: List[Tuple[float, float]],
230
+ require_no_overlap: bool = False,
231
+ require_no_gap: bool = True,
232
+ normalize_channel_independently: bool = False,
233
+ treat_channel_as_feature: bool = True,
234
+ n_sqm_modules: int = 12,
235
+ emb_dim: int = 128,
236
+ rnn_dim: int = 256,
237
+ bidirectional: bool = True,
238
+ tf_dropout: float = 0.0,
239
+ mlp_dim: int = 512,
240
+ hidden_activation: str = "Tanh",
241
+ hidden_activation_kwargs: Optional[Dict] = None,
242
+ complex_mask: bool = True,
243
+ ) -> None:
244
+ super().__init__()
245
+ self.band_split = BandSplitModule(
246
+ in_channel=in_channel,
247
+ band_specs=band_specs,
248
+ require_no_overlap=require_no_overlap,
249
+ require_no_gap=require_no_gap,
250
+ normalize_channel_independently=normalize_channel_independently,
251
+ treat_channel_as_feature=treat_channel_as_feature,
252
+ emb_dim=emb_dim,
253
+ )
254
+ self.tf_model = TransformerTimeFreqModule(
255
+ n_modules=n_sqm_modules,
256
+ emb_dim=emb_dim,
257
+ rnn_dim=rnn_dim,
258
+ bidirectional=bidirectional,
259
+ dropout=tf_dropout,
260
+ )
261
+ self.mask_estim = MaskEstimationModule(
262
+ in_channel=in_channel,
263
+ band_specs=band_specs,
264
+ emb_dim=emb_dim,
265
+ mlp_dim=mlp_dim,
266
+ hidden_activation=hidden_activation,
267
+ hidden_activation_kwargs=hidden_activation_kwargs,
268
+ complex_mask=complex_mask,
269
+ )
270
+
271
+
272
+ class MultiSourceMultiMaskBandSplitCoreRNN(MultiMaskBandSplitCoreBase):
273
+ def __init__(
274
+ self,
275
+ in_channel: int,
276
+ stems: List[str],
277
+ band_specs: List[Tuple[float, float]],
278
+ require_no_overlap: bool = False,
279
+ require_no_gap: bool = True,
280
+ normalize_channel_independently: bool = False,
281
+ treat_channel_as_feature: bool = True,
282
+ n_sqm_modules: int = 12,
283
+ emb_dim: int = 128,
284
+ rnn_dim: int = 256,
285
+ bidirectional: bool = True,
286
+ rnn_type: str = "LSTM",
287
+ mlp_dim: int = 512,
288
+ cond_dim: int = 0,
289
+ hidden_activation: str = "Tanh",
290
+ hidden_activation_kwargs: Optional[Dict] = None,
291
+ complex_mask: bool = True,
292
+ overlapping_band: bool = False,
293
+ freq_weights: Optional[List[torch.Tensor]] = None,
294
+ n_freq: Optional[int] = None,
295
+ use_freq_weights: bool = True,
296
+ mult_add_mask: bool = False
297
+ ) -> None:
298
+
299
+ super().__init__()
300
+ self.instantiate_bandsplit(
301
+ in_channel=in_channel,
302
+ band_specs=band_specs,
303
+ require_no_overlap=require_no_overlap,
304
+ require_no_gap=require_no_gap,
305
+ normalize_channel_independently=normalize_channel_independently,
306
+ treat_channel_as_feature=treat_channel_as_feature,
307
+ emb_dim=emb_dim
308
+ )
309
+
310
+
311
+ self.tf_model = (
312
+ SeqBandModellingModule(
313
+ n_modules=n_sqm_modules,
314
+ emb_dim=emb_dim,
315
+ rnn_dim=rnn_dim,
316
+ bidirectional=bidirectional,
317
+ rnn_type=rnn_type,
318
+ )
319
+ )
320
+
321
+ self.mult_add_mask = mult_add_mask
322
+
323
+ self.instantiate_mask_estim(
324
+ in_channel=in_channel,
325
+ stems=stems,
326
+ band_specs=band_specs,
327
+ emb_dim=emb_dim,
328
+ mlp_dim=mlp_dim,
329
+ cond_dim=cond_dim,
330
+ hidden_activation=hidden_activation,
331
+ hidden_activation_kwargs=hidden_activation_kwargs,
332
+ complex_mask=complex_mask,
333
+ overlapping_band=overlapping_band,
334
+ freq_weights=freq_weights,
335
+ n_freq=n_freq,
336
+ use_freq_weights=use_freq_weights,
337
+ mult_add_mask=mult_add_mask
338
+ )
339
+
340
+ @staticmethod
341
+ def _mult_add_mask(x, m):
342
+
343
+ assert m.ndim == 5
344
+
345
+ mm = m[..., 0]
346
+ am = m[..., 1]
347
+
348
+ # print(mm.shape, am.shape, x.shape, m.shape)
349
+
350
+ return x * mm + am
351
+
352
+ def mask(self, x, m):
353
+ if self.mult_add_mask:
354
+
355
+ return self._mult_add_mask(x, m)
356
+ else:
357
+ return super().mask(x, m)
358
+
359
+
360
+ class MultiSourceMultiMaskBandSplitCoreTransformer(
361
+ MultiMaskBandSplitCoreBase,
362
+ ):
363
+ def __init__(
364
+ self,
365
+ in_channel: int,
366
+ stems: List[str],
367
+ band_specs: List[Tuple[float, float]],
368
+ require_no_overlap: bool = False,
369
+ require_no_gap: bool = True,
370
+ normalize_channel_independently: bool = False,
371
+ treat_channel_as_feature: bool = True,
372
+ n_sqm_modules: int = 12,
373
+ emb_dim: int = 128,
374
+ rnn_dim: int = 256,
375
+ bidirectional: bool = True,
376
+ tf_dropout: float = 0.0,
377
+ mlp_dim: int = 512,
378
+ hidden_activation: str = "Tanh",
379
+ hidden_activation_kwargs: Optional[Dict] = None,
380
+ complex_mask: bool = True,
381
+ overlapping_band: bool = False,
382
+ freq_weights: Optional[List[torch.Tensor]] = None,
383
+ n_freq: Optional[int] = None,
384
+ use_freq_weights:bool=True,
385
+ rnn_type: str = "LSTM",
386
+ cond_dim: int = 0,
387
+ mult_add_mask: bool = False
388
+ ) -> None:
389
+ super().__init__()
390
+ self.instantiate_bandsplit(
391
+ in_channel=in_channel,
392
+ band_specs=band_specs,
393
+ require_no_overlap=require_no_overlap,
394
+ require_no_gap=require_no_gap,
395
+ normalize_channel_independently=normalize_channel_independently,
396
+ treat_channel_as_feature=treat_channel_as_feature,
397
+ emb_dim=emb_dim
398
+ )
399
+ self.tf_model = TransformerTimeFreqModule(
400
+ n_modules=n_sqm_modules,
401
+ emb_dim=emb_dim,
402
+ rnn_dim=rnn_dim,
403
+ bidirectional=bidirectional,
404
+ dropout=tf_dropout,
405
+ )
406
+
407
+ self.instantiate_mask_estim(
408
+ in_channel=in_channel,
409
+ stems=stems,
410
+ band_specs=band_specs,
411
+ emb_dim=emb_dim,
412
+ mlp_dim=mlp_dim,
413
+ cond_dim=cond_dim,
414
+ hidden_activation=hidden_activation,
415
+ hidden_activation_kwargs=hidden_activation_kwargs,
416
+ complex_mask=complex_mask,
417
+ overlapping_band=overlapping_band,
418
+ freq_weights=freq_weights,
419
+ n_freq=n_freq,
420
+ use_freq_weights=use_freq_weights,
421
+ mult_add_mask=mult_add_mask
422
+ )
423
+
424
+
425
+
426
+ class MultiSourceMultiMaskBandSplitCoreConv(
427
+ MultiMaskBandSplitCoreBase,
428
+ ):
429
+ def __init__(
430
+ self,
431
+ in_channel: int,
432
+ stems: List[str],
433
+ band_specs: List[Tuple[float, float]],
434
+ require_no_overlap: bool = False,
435
+ require_no_gap: bool = True,
436
+ normalize_channel_independently: bool = False,
437
+ treat_channel_as_feature: bool = True,
438
+ n_sqm_modules: int = 12,
439
+ emb_dim: int = 128,
440
+ rnn_dim: int = 256,
441
+ bidirectional: bool = True,
442
+ tf_dropout: float = 0.0,
443
+ mlp_dim: int = 512,
444
+ hidden_activation: str = "Tanh",
445
+ hidden_activation_kwargs: Optional[Dict] = None,
446
+ complex_mask: bool = True,
447
+ overlapping_band: bool = False,
448
+ freq_weights: Optional[List[torch.Tensor]] = None,
449
+ n_freq: Optional[int] = None,
450
+ use_freq_weights:bool=True,
451
+ rnn_type: str = "LSTM",
452
+ cond_dim: int = 0,
453
+ mult_add_mask: bool = False
454
+ ) -> None:
455
+ super().__init__()
456
+ self.instantiate_bandsplit(
457
+ in_channel=in_channel,
458
+ band_specs=band_specs,
459
+ require_no_overlap=require_no_overlap,
460
+ require_no_gap=require_no_gap,
461
+ normalize_channel_independently=normalize_channel_independently,
462
+ treat_channel_as_feature=treat_channel_as_feature,
463
+ emb_dim=emb_dim
464
+ )
465
+ self.tf_model = ConvolutionalTimeFreqModule(
466
+ n_modules=n_sqm_modules,
467
+ emb_dim=emb_dim,
468
+ rnn_dim=rnn_dim,
469
+ bidirectional=bidirectional,
470
+ dropout=tf_dropout,
471
+ )
472
+
473
+ self.instantiate_mask_estim(
474
+ in_channel=in_channel,
475
+ stems=stems,
476
+ band_specs=band_specs,
477
+ emb_dim=emb_dim,
478
+ mlp_dim=mlp_dim,
479
+ cond_dim=cond_dim,
480
+ hidden_activation=hidden_activation,
481
+ hidden_activation_kwargs=hidden_activation_kwargs,
482
+ complex_mask=complex_mask,
483
+ overlapping_band=overlapping_band,
484
+ freq_weights=freq_weights,
485
+ n_freq=n_freq,
486
+ use_freq_weights=use_freq_weights,
487
+ mult_add_mask=mult_add_mask
488
+ )
489
+
490
+
491
+ class PatchingMaskBandsplitCoreBase(MultiMaskBandSplitCoreBase):
492
+ def __init__(self) -> None:
493
+ super().__init__()
494
+
495
+ def mask(self, x, m):
496
+ # x.shape = (batch, n_channel, n_freq, n_time)
497
+ # m.shape = (kernel_freq, kernel_time, batch, n_channel, n_freq, n_time)
498
+
499
+ _, n_channel, kernel_freq, kernel_time, n_freq, n_time = m.shape
500
+ padding = ((kernel_freq - 1) // 2, (kernel_time - 1) // 2)
501
+
502
+ xf = F.unfold(
503
+ x,
504
+ kernel_size=(kernel_freq, kernel_time),
505
+ padding=padding,
506
+ stride=(1, 1),
507
+ )
508
+
509
+ xf = xf.view(
510
+ -1,
511
+ n_channel,
512
+ kernel_freq,
513
+ kernel_time,
514
+ n_freq,
515
+ n_time,
516
+ )
517
+
518
+ sf = xf * m
519
+
520
+ sf = sf.view(
521
+ -1,
522
+ n_channel * kernel_freq * kernel_time,
523
+ n_freq * n_time,
524
+ )
525
+
526
+ s = F.fold(
527
+ sf,
528
+ output_size=(n_freq, n_time),
529
+ kernel_size=(kernel_freq, kernel_time),
530
+ padding=padding,
531
+ stride=(1, 1),
532
+ ).view(
533
+ -1,
534
+ n_channel,
535
+ n_freq,
536
+ n_time,
537
+ )
538
+
539
+ return s
540
+
541
+ def old_mask(self, x, m):
542
+ # x.shape = (batch, n_channel, n_freq, n_time)
543
+ # m.shape = (kernel_freq, kernel_time, batch, n_channel, n_freq, n_time)
544
+
545
+ s = torch.zeros_like(x)
546
+
547
+ _, n_channel, n_freq, n_time = x.shape
548
+ kernel_freq, kernel_time, _, _, _, _ = m.shape
549
+
550
+ # print(x.shape, m.shape)
551
+
552
+ kernel_freq_half = (kernel_freq - 1) // 2
553
+ kernel_time_half = (kernel_time - 1) // 2
554
+
555
+ for ifreq in range(kernel_freq):
556
+ for itime in range(kernel_time):
557
+ df, dt = kernel_freq_half - ifreq, kernel_time_half - itime
558
+ x = x.roll(shifts=(df, dt), dims=(2, 3))
559
+
560
+ # if `df` > 0:
561
+ # x[:, :, :df, :] = 0
562
+ # elif `df` < 0:
563
+ # x[:, :, df:, :] = 0
564
+
565
+ # if `dt` > 0:
566
+ # x[:, :, :, :dt] = 0
567
+ # elif `dt` < 0:
568
+ # x[:, :, :, dt:] = 0
569
+
570
+ fslice = slice(max(0, df), min(n_freq, n_freq + df))
571
+ tslice = slice(max(0, dt), min(n_time, n_time + dt))
572
+
573
+ s[:, :, fslice, tslice] += x[:, :, fslice, tslice] * m[ifreq,
574
+ itime, :,
575
+ :, fslice,
576
+ tslice]
577
+
578
+ return s
579
+
580
+
581
+ class MultiSourceMultiPatchingMaskBandSplitCoreRNN(
582
+ PatchingMaskBandsplitCoreBase
583
+ ):
584
+ def __init__(
585
+ self,
586
+ in_channel: int,
587
+ stems: List[str],
588
+ band_specs: List[Tuple[float, float]],
589
+ mask_kernel_freq: int,
590
+ mask_kernel_time: int,
591
+ conv_kernel_freq: int,
592
+ conv_kernel_time: int,
593
+ kernel_norm_mlp_version: int,
594
+ require_no_overlap: bool = False,
595
+ require_no_gap: bool = True,
596
+ normalize_channel_independently: bool = False,
597
+ treat_channel_as_feature: bool = True,
598
+ n_sqm_modules: int = 12,
599
+ emb_dim: int = 128,
600
+ rnn_dim: int = 256,
601
+ bidirectional: bool = True,
602
+ rnn_type: str = "LSTM",
603
+ mlp_dim: int = 512,
604
+ hidden_activation: str = "Tanh",
605
+ hidden_activation_kwargs: Optional[Dict] = None,
606
+ complex_mask: bool = True,
607
+ overlapping_band: bool = False,
608
+ freq_weights: Optional[List[torch.Tensor]] = None,
609
+ n_freq: Optional[int] = None,
610
+ ) -> None:
611
+
612
+ super().__init__()
613
+ self.band_split = BandSplitModule(
614
+ in_channel=in_channel,
615
+ band_specs=band_specs,
616
+ require_no_overlap=require_no_overlap,
617
+ require_no_gap=require_no_gap,
618
+ normalize_channel_independently=normalize_channel_independently,
619
+ treat_channel_as_feature=treat_channel_as_feature,
620
+ emb_dim=emb_dim,
621
+ )
622
+
623
+ self.tf_model = (
624
+ SeqBandModellingModule(
625
+ n_modules=n_sqm_modules,
626
+ emb_dim=emb_dim,
627
+ rnn_dim=rnn_dim,
628
+ bidirectional=bidirectional,
629
+ rnn_type=rnn_type,
630
+ )
631
+ )
632
+
633
+ if hidden_activation_kwargs is None:
634
+ hidden_activation_kwargs = {}
635
+
636
+ if overlapping_band:
637
+ assert freq_weights is not None
638
+ assert n_freq is not None
639
+ self.mask_estim = nn.ModuleDict(
640
+ {
641
+ stem: PatchingMaskEstimationModule(
642
+ band_specs=band_specs,
643
+ freq_weights=freq_weights,
644
+ n_freq=n_freq,
645
+ emb_dim=emb_dim,
646
+ mlp_dim=mlp_dim,
647
+ in_channel=in_channel,
648
+ hidden_activation=hidden_activation,
649
+ hidden_activation_kwargs=hidden_activation_kwargs,
650
+ complex_mask=complex_mask,
651
+ mask_kernel_freq=mask_kernel_freq,
652
+ mask_kernel_time=mask_kernel_time,
653
+ conv_kernel_freq=conv_kernel_freq,
654
+ conv_kernel_time=conv_kernel_time,
655
+ kernel_norm_mlp_version=kernel_norm_mlp_version
656
+ )
657
+ for stem in stems
658
+ }
659
+ )
660
+ else:
661
+ raise NotImplementedError
separator/models/bandit/core/model/bsrnn/maskestim.py ADDED
@@ -0,0 +1,347 @@
1
+ import warnings
2
+ from typing import Dict, List, Optional, Tuple, Type
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn.modules import activation
7
+
8
+ from models.bandit.core.model.bsrnn.utils import (
9
+ band_widths_from_specs,
10
+ check_no_gap,
11
+ check_no_overlap,
12
+ check_nonzero_bandwidth,
13
+ )
14
+
15
+
16
+ class BaseNormMLP(nn.Module):
17
+ def __init__(
18
+ self,
19
+ emb_dim: int,
20
+ mlp_dim: int,
21
+ bandwidth: int,
22
+ in_channel: Optional[int],
23
+ hidden_activation: str = "Tanh",
24
+ hidden_activation_kwargs=None,
25
+ complex_mask: bool = True, ):
26
+
27
+ super().__init__()
28
+ if hidden_activation_kwargs is None:
29
+ hidden_activation_kwargs = {}
30
+ self.hidden_activation_kwargs = hidden_activation_kwargs
31
+ self.norm = nn.LayerNorm(emb_dim)
32
+ self.hidden = torch.jit.script(nn.Sequential(
33
+ nn.Linear(in_features=emb_dim, out_features=mlp_dim),
34
+ activation.__dict__[hidden_activation](
35
+ **self.hidden_activation_kwargs
36
+ ),
37
+ ))
38
+
39
+ self.bandwidth = bandwidth
40
+ self.in_channel = in_channel
41
+
42
+ self.complex_mask = complex_mask
43
+ self.reim = 2 if complex_mask else 1
44
+ self.glu_mult = 2
45
+
46
+
47
+ class NormMLP(BaseNormMLP):
48
+ def __init__(
49
+ self,
50
+ emb_dim: int,
51
+ mlp_dim: int,
52
+ bandwidth: int,
53
+ in_channel: Optional[int],
54
+ hidden_activation: str = "Tanh",
55
+ hidden_activation_kwargs=None,
56
+ complex_mask: bool = True,
57
+ ) -> None:
58
+ super().__init__(
59
+ emb_dim=emb_dim,
60
+ mlp_dim=mlp_dim,
61
+ bandwidth=bandwidth,
62
+ in_channel=in_channel,
63
+ hidden_activation=hidden_activation,
64
+ hidden_activation_kwargs=hidden_activation_kwargs,
65
+ complex_mask=complex_mask,
66
+ )
67
+
68
+ self.output = torch.jit.script(
69
+ nn.Sequential(
70
+ nn.Linear(
71
+ in_features=mlp_dim,
72
+ out_features=bandwidth * in_channel * self.reim * 2,
73
+ ),
74
+ nn.GLU(dim=-1),
75
+ )
76
+ )
77
+
78
+ def reshape_output(self, mb):
79
+ # print(mb.shape)
80
+ batch, n_time, _ = mb.shape
81
+ if self.complex_mask:
82
+ mb = mb.reshape(
83
+ batch,
84
+ n_time,
85
+ self.in_channel,
86
+ self.bandwidth,
87
+ self.reim
88
+ ).contiguous()
89
+ # print(mb.shape)
90
+ mb = torch.view_as_complex(
91
+ mb
92
+ ) # (batch, n_time, in_channel, bandwidth)
93
+ else:
94
+ mb = mb.reshape(batch, n_time, self.in_channel, self.bandwidth)
95
+
96
+ mb = torch.permute(
97
+ mb,
98
+ (0, 2, 3, 1)
99
+ ) # (batch, in_channel, bandwidth, n_time)
100
+
101
+ return mb
102
+
103
+ def forward(self, qb):
104
+ # qb = (batch, n_time, emb_dim)
105
+
106
+ # if torch.any(torch.isnan(qb)):
107
+ # raise ValueError("qb0")
108
+
109
+
110
+ qb = self.norm(qb) # (batch, n_time, emb_dim)
111
+
112
+ # if torch.any(torch.isnan(qb)):
113
+ # raise ValueError("qb1")
114
+
115
+ qb = self.hidden(qb) # (batch, n_time, mlp_dim)
116
+ # if torch.any(torch.isnan(qb)):
117
+ # raise ValueError("qb2")
118
+ mb = self.output(qb) # (batch, n_time, bandwidth * in_channel * reim)
119
+ # if torch.any(torch.isnan(qb)):
120
+ # raise ValueError("mb")
121
+ mb = self.reshape_output(mb) # (batch, in_channel, bandwidth, n_time)
122
+
123
+ return mb
124
+
125
+
126
+ class MultAddNormMLP(NormMLP):
127
+ def __init__(self, emb_dim: int, mlp_dim: int, bandwidth: int, in_channel: "int | None", hidden_activation: str = "Tanh", hidden_activation_kwargs=None, complex_mask: bool = True) -> None:
128
+ super().__init__(emb_dim, mlp_dim, bandwidth, in_channel, hidden_activation, hidden_activation_kwargs, complex_mask)
129
+
130
+ self.output2 = torch.jit.script(
131
+ nn.Sequential(
132
+ nn.Linear(
133
+ in_features=mlp_dim,
134
+ out_features=bandwidth * in_channel * self.reim * 2,
135
+ ),
136
+ nn.GLU(dim=-1),
137
+ )
138
+ )
139
+
140
+ def forward(self, qb):
141
+
142
+ qb = self.norm(qb) # (batch, n_time, emb_dim)
143
+ qb = self.hidden(qb) # (batch, n_time, mlp_dim)
144
+ mmb = self.output(qb) # (batch, n_time, bandwidth * in_channel * reim)
145
+ mmb = self.reshape_output(mmb) # (batch, in_channel, bandwidth, n_time)
146
+ amb = self.output2(qb) # (batch, n_time, bandwidth * in_channel * reim)
147
+ amb = self.reshape_output(amb) # (batch, in_channel, bandwidth, n_time)
148
+
149
+ return mmb, amb
150
+
151
+
152
+ class MaskEstimationModuleSuperBase(nn.Module):
153
+ pass
154
+
155
+
156
+ class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
157
+ def __init__(
158
+ self,
159
+ band_specs: List[Tuple[float, float]],
160
+ emb_dim: int,
161
+ mlp_dim: int,
162
+ in_channel: Optional[int],
163
+ hidden_activation: str = "Tanh",
164
+ hidden_activation_kwargs: Dict = None,
165
+ complex_mask: bool = True,
166
+ norm_mlp_cls: Type[nn.Module] = NormMLP,
167
+ norm_mlp_kwargs: Dict = None,
168
+ ) -> None:
169
+ super().__init__()
170
+
171
+ self.band_widths = band_widths_from_specs(band_specs)
172
+ self.n_bands = len(band_specs)
173
+
174
+ if hidden_activation_kwargs is None:
175
+ hidden_activation_kwargs = {}
176
+
177
+ if norm_mlp_kwargs is None:
178
+ norm_mlp_kwargs = {}
179
+
180
+ self.norm_mlp = nn.ModuleList(
181
+ [
182
+ (
183
+ norm_mlp_cls(
184
+ bandwidth=self.band_widths[b],
185
+ emb_dim=emb_dim,
186
+ mlp_dim=mlp_dim,
187
+ in_channel=in_channel,
188
+ hidden_activation=hidden_activation,
189
+ hidden_activation_kwargs=hidden_activation_kwargs,
190
+ complex_mask=complex_mask,
191
+ **norm_mlp_kwargs,
192
+ )
193
+ )
194
+ for b in range(self.n_bands)
195
+ ]
196
+ )
197
+
198
+ def compute_masks(self, q):
199
+ batch, n_bands, n_time, emb_dim = q.shape
200
+
201
+ masks = []
202
+
203
+ for b, nmlp in enumerate(self.norm_mlp):
204
+ # print(f"maskestim/{b:02d}")
205
+ qb = q[:, b, :, :]
206
+ mb = nmlp(qb)
207
+ masks.append(mb)
208
+
209
+ return masks
210
+
211
+
212
+
213
+ class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
214
+ def __init__(
215
+ self,
216
+ in_channel: int,
217
+ band_specs: List[Tuple[float, float]],
218
+ freq_weights: List[torch.Tensor],
219
+ n_freq: int,
220
+ emb_dim: int,
221
+ mlp_dim: int,
222
+ cond_dim: int = 0,
223
+ hidden_activation: str = "Tanh",
224
+ hidden_activation_kwargs: Dict = None,
225
+ complex_mask: bool = True,
226
+ norm_mlp_cls: Type[nn.Module] = NormMLP,
227
+ norm_mlp_kwargs: Dict = None,
228
+ use_freq_weights: bool = True,
229
+ ) -> None:
230
+ check_nonzero_bandwidth(band_specs)
231
+ check_no_gap(band_specs)
232
+
233
+ # if cond_dim > 0:
234
+ # raise NotImplementedError
235
+
236
+ super().__init__(
237
+ band_specs=band_specs,
238
+ emb_dim=emb_dim + cond_dim,
239
+ mlp_dim=mlp_dim,
240
+ in_channel=in_channel,
241
+ hidden_activation=hidden_activation,
242
+ hidden_activation_kwargs=hidden_activation_kwargs,
243
+ complex_mask=complex_mask,
244
+ norm_mlp_cls=norm_mlp_cls,
245
+ norm_mlp_kwargs=norm_mlp_kwargs,
246
+ )
247
+
248
+ self.n_freq = n_freq
249
+ self.band_specs = band_specs
250
+ self.in_channel = in_channel
251
+
252
+ if freq_weights is not None:
253
+ for i, fw in enumerate(freq_weights):
254
+ self.register_buffer(f"freq_weights/{i}", fw)
255
+
256
+ self.use_freq_weights = use_freq_weights
257
+ else:
258
+ self.use_freq_weights = False
259
+
260
+ self.cond_dim = cond_dim
261
+
262
+ def forward(self, q, cond=None):
263
+ # q = (batch, n_bands, n_time, emb_dim)
264
+
265
+ batch, n_bands, n_time, emb_dim = q.shape
266
+
267
+ if cond is not None:
268
+ # print(cond)
269
+ if cond.ndim == 2:
270
+ cond = cond[:, None, None, :].expand(-1, n_bands, n_time, -1)
271
+ elif cond.ndim == 3:
272
+ assert cond.shape[1] == n_time
273
+ else:
274
+ raise ValueError(f"Invalid cond shape: {cond.shape}")
275
+
276
+ q = torch.cat([q, cond], dim=-1)
277
+ elif self.cond_dim > 0:
278
+ cond = torch.ones(
279
+ (batch, n_bands, n_time, self.cond_dim),
280
+ device=q.device,
281
+ dtype=q.dtype,
282
+ )
283
+ q = torch.cat([q, cond], dim=-1)
284
+ else:
285
+ pass
286
+
287
+ mask_list = self.compute_masks(
288
+ q
289
+ ) # [n_bands * (batch, in_channel, bandwidth, n_time)]
290
+
291
+ masks = torch.zeros(
292
+ (batch, self.in_channel, self.n_freq, n_time),
293
+ device=q.device,
294
+ dtype=mask_list[0].dtype,
295
+ )
296
+
297
+ for im, mask in enumerate(mask_list):
298
+ fstart, fend = self.band_specs[im]
299
+ if self.use_freq_weights:
300
+ fw = self.get_buffer(f"freq_weights/{im}")[:, None]
301
+ mask = mask * fw
302
+ masks[:, :, fstart:fend, :] += mask
303
+
304
+ return masks
305
+
306
+
307
+ class MaskEstimationModule(OverlappingMaskEstimationModule):
308
+ def __init__(
309
+ self,
310
+ band_specs: List[Tuple[float, float]],
311
+ emb_dim: int,
312
+ mlp_dim: int,
313
+ in_channel: Optional[int],
314
+ hidden_activation: str = "Tanh",
315
+ hidden_activation_kwargs: Dict = None,
316
+ complex_mask: bool = True,
317
+ **kwargs,
318
+ ) -> None:
319
+ check_nonzero_bandwidth(band_specs)
320
+ check_no_gap(band_specs)
321
+ check_no_overlap(band_specs)
322
+ super().__init__(
323
+ in_channel=in_channel,
324
+ band_specs=band_specs,
325
+ freq_weights=None,
326
+ n_freq=None,
327
+ emb_dim=emb_dim,
328
+ mlp_dim=mlp_dim,
329
+ hidden_activation=hidden_activation,
330
+ hidden_activation_kwargs=hidden_activation_kwargs,
331
+ complex_mask=complex_mask,
332
+ )
333
+
334
+ def forward(self, q, cond=None):
335
+ # q = (batch, n_bands, n_time, emb_dim)
336
+
337
+ masks = self.compute_masks(
338
+ q
339
+ ) # [n_bands * (batch, in_channel, bandwidth, n_time)]
340
+
341
+ # TODO: currently this requires band specs to have no gap and no overlap
342
+ masks = torch.concat(
343
+ masks,
344
+ dim=2
345
+ ) # (batch, in_channel, n_freq, n_time)
346
+
347
+ return masks
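A minimal usage sketch for the contiguous-band mask estimator above (not part of the diff; the package import path is an assumption, and the NormMLP defaults defined earlier in this file are relied on):

import torch
from separator.models.bandit.core.model.bsrnn.maskestim import MaskEstimationModule

# Two contiguous, non-overlapping bands covering n_fft // 2 + 1 = 1025 bins.
band_specs = [(0, 512), (512, 1025)]
estim = MaskEstimationModule(
    band_specs=band_specs,
    emb_dim=128,
    mlp_dim=512,
    in_channel=2,  # stereo
)
q = torch.randn(1, len(band_specs), 100, 128)  # (batch, n_bands, n_time, emb_dim)
masks = estim(q)  # (batch, in_channel, n_freq, n_time); complex-valued when complex_mask=True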
separator/models/bandit/core/model/bsrnn/tfmodel.py ADDED
@@ -0,0 +1,317 @@
+ import warnings
+ 
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from torch.nn.modules import rnn
+ 
+ import torch.backends.cuda
+ 
+ 
+ class TimeFrequencyModellingModule(nn.Module):
+     def __init__(self) -> None:
+         super().__init__()
+ 
+ 
+ class ResidualRNN(nn.Module):
+     def __init__(
+             self,
+             emb_dim: int,
+             rnn_dim: int,
+             bidirectional: bool = True,
+             rnn_type: str = "LSTM",
+             use_batch_trick: bool = True,
+             use_layer_norm: bool = True,
+     ) -> None:
+         # n_group is the size of the 2nd dim
+         super().__init__()
+ 
+         self.use_layer_norm = use_layer_norm
+         if use_layer_norm:
+             self.norm = nn.LayerNorm(emb_dim)
+         else:
+             self.norm = nn.GroupNorm(num_groups=emb_dim, num_channels=emb_dim)
+ 
+         self.rnn = rnn.__dict__[rnn_type](
+             input_size=emb_dim,
+             hidden_size=rnn_dim,
+             num_layers=1,
+             batch_first=True,
+             bidirectional=bidirectional,
+         )
+ 
+         self.fc = nn.Linear(
+             in_features=rnn_dim * (2 if bidirectional else 1),
+             out_features=emb_dim
+         )
+ 
+         self.use_batch_trick = use_batch_trick
+         if not self.use_batch_trick:
+             warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")
+ 
+     def forward(self, z):
+         # z = (batch, n_uncrossed, n_across, emb_dim)
+ 
+         z0 = torch.clone(z)
+ 
+         # print(z.device)
+ 
+         if self.use_layer_norm:
+             z = self.norm(z)  # (batch, n_uncrossed, n_across, emb_dim)
+         else:
+             z = torch.permute(
+                 z, (0, 3, 1, 2)
+             )  # (batch, emb_dim, n_uncrossed, n_across)
+ 
+             z = self.norm(z)  # (batch, emb_dim, n_uncrossed, n_across)
+ 
+             z = torch.permute(
+                 z, (0, 2, 3, 1)
+             )  # (batch, n_uncrossed, n_across, emb_dim)
+ 
+         batch, n_uncrossed, n_across, emb_dim = z.shape
+ 
+         if self.use_batch_trick:
+             z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
+ 
+             z = self.rnn(z.contiguous())[0]  # (batch * n_uncrossed, n_across, dir_rnn_dim)
+ 
+             z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))
+             # (batch, n_uncrossed, n_across, dir_rnn_dim)
+         else:
+             # Note: this is EXTREMELY SLOW
+             zlist = []
+             for i in range(n_uncrossed):
+                 zi = self.rnn(z[:, i, :, :])[0]  # (batch, n_across, emb_dim)
+                 zlist.append(zi)
+ 
+             z = torch.stack(
+                 zlist,
+                 dim=1
+             )  # (batch, n_uncrossed, n_across, dir_rnn_dim)
+ 
+         z = self.fc(z)  # (batch, n_uncrossed, n_across, emb_dim)
+ 
+         z = z + z0
+ 
+         return z
+ 
+ 
+ class SeqBandModellingModule(TimeFrequencyModellingModule):
+     def __init__(
+             self,
+             n_modules: int = 12,
+             emb_dim: int = 128,
+             rnn_dim: int = 256,
+             bidirectional: bool = True,
+             rnn_type: str = "LSTM",
+             parallel_mode=False,
+     ) -> None:
+         super().__init__()
+         self.seqband = nn.ModuleList([])
+ 
+         if parallel_mode:
+             for _ in range(n_modules):
+                 self.seqband.append(
+                     nn.ModuleList(
+                         [
+                             ResidualRNN(
+                                 emb_dim=emb_dim,
+                                 rnn_dim=rnn_dim,
+                                 bidirectional=bidirectional,
+                                 rnn_type=rnn_type,
+                             ),
+                             ResidualRNN(
+                                 emb_dim=emb_dim,
+                                 rnn_dim=rnn_dim,
+                                 bidirectional=bidirectional,
+                                 rnn_type=rnn_type,
+                             ),
+                         ]
+                     )
+                 )
+         else:
+             for _ in range(2 * n_modules):
+                 self.seqband.append(
+                     ResidualRNN(
+                         emb_dim=emb_dim,
+                         rnn_dim=rnn_dim,
+                         bidirectional=bidirectional,
+                         rnn_type=rnn_type,
+                     )
+                 )
+ 
+         self.parallel_mode = parallel_mode
+ 
+     def forward(self, z):
+         # z = (batch, n_bands, n_time, emb_dim)
+ 
+         if self.parallel_mode:
+             for sbm_pair in self.seqband:
+                 # z: (batch, n_bands, n_time, emb_dim)
+                 sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
+                 zt = sbm_t(z)  # (batch, n_bands, n_time, emb_dim)
+                 zf = sbm_f(z.transpose(1, 2))  # (batch, n_time, n_bands, emb_dim)
+                 z = zt + zf.transpose(1, 2)
+         else:
+             for sbm in self.seqband:
+                 z = sbm(z)
+                 z = z.transpose(1, 2)
+ 
+                 # (batch, n_bands, n_time, emb_dim)
+                 # --> (batch, n_time, n_bands, emb_dim)
+                 # OR
+                 # (batch, n_time, n_bands, emb_dim)
+                 # --> (batch, n_bands, n_time, emb_dim)
+ 
+         q = z
+         return q  # (batch, n_bands, n_time, emb_dim)
+ 
+ 
+ class ResidualTransformer(nn.Module):
+     def __init__(
+             self,
+             emb_dim: int = 128,
+             rnn_dim: int = 256,
+             bidirectional: bool = True,
+             dropout: float = 0.0,
+     ) -> None:
+         # n_group is the size of the 2nd dim
+         super().__init__()
+ 
+         self.tf = nn.TransformerEncoderLayer(
+             d_model=emb_dim,
+             nhead=4,
+             dim_feedforward=rnn_dim,
+             batch_first=True
+         )
+ 
+         self.is_causal = not bidirectional
+         self.dropout = dropout
+ 
+     def forward(self, z):
+         batch, n_uncrossed, n_across, emb_dim = z.shape
+         z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
+         z = self.tf(z, is_causal=self.is_causal)  # (batch * n_uncrossed, n_across, emb_dim)
+         z = torch.reshape(z, (batch, n_uncrossed, n_across, emb_dim))
+ 
+         return z
+ 
+ 
+ class TransformerTimeFreqModule(TimeFrequencyModellingModule):
+     def __init__(
+             self,
+             n_modules: int = 12,
+             emb_dim: int = 128,
+             rnn_dim: int = 256,
+             bidirectional: bool = True,
+             dropout: float = 0.0,
+     ) -> None:
+         super().__init__()
+         self.norm = nn.LayerNorm(emb_dim)
+         self.seqband = nn.ModuleList([])
+ 
+         for _ in range(2 * n_modules):
+             self.seqband.append(
+                 ResidualTransformer(
+                     emb_dim=emb_dim,
+                     rnn_dim=rnn_dim,
+                     bidirectional=bidirectional,
+                     dropout=dropout,
+                 )
+             )
+ 
+     def forward(self, z):
+         # z = (batch, n_bands, n_time, emb_dim)
+         z = self.norm(z)  # (batch, n_bands, n_time, emb_dim)
+ 
+         for sbm in self.seqband:
+             z = sbm(z)
+             z = z.transpose(1, 2)
+ 
+             # (batch, n_bands, n_time, emb_dim)
+             # --> (batch, n_time, n_bands, emb_dim)
+             # OR
+             # (batch, n_time, n_bands, emb_dim)
+             # --> (batch, n_bands, n_time, emb_dim)
+ 
+         q = z
+         return q  # (batch, n_bands, n_time, emb_dim)
+ 
+ 
+ class ResidualConvolution(nn.Module):
+     def __init__(
+             self,
+             emb_dim: int = 128,
+             rnn_dim: int = 256,
+             bidirectional: bool = True,
+             dropout: float = 0.0,
+     ) -> None:
+         # n_group is the size of the 2nd dim
+         super().__init__()
+         self.norm = nn.InstanceNorm2d(emb_dim, affine=True)
+ 
+         self.conv = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=emb_dim,
+                 out_channels=rnn_dim,
+                 kernel_size=(3, 3),
+                 padding="same",
+                 stride=(1, 1),
+             ),
+             nn.Tanhshrink()
+         )
+ 
+         self.is_causal = not bidirectional
+         self.dropout = dropout
+ 
+         self.fc = nn.Conv2d(
+             in_channels=rnn_dim,
+             out_channels=emb_dim,
+             kernel_size=(1, 1),
+             padding="same",
+             stride=(1, 1),
+         )
+ 
+     def forward(self, z):
+         # z = (batch, emb_dim, n_uncrossed, n_across)
+ 
+         z0 = torch.clone(z)
+ 
+         z = self.norm(z)
+         z = self.conv(z)
+         z = self.fc(z)
+         z = z + z0
+ 
+         return z
+ 
+ 
+ class ConvolutionalTimeFreqModule(TimeFrequencyModellingModule):
+     def __init__(
+             self,
+             n_modules: int = 12,
+             emb_dim: int = 128,
+             rnn_dim: int = 256,
+             bidirectional: bool = True,
+             dropout: float = 0.0,
+     ) -> None:
+         super().__init__()
+         self.seqband = torch.jit.script(
+             nn.Sequential(
+                 *[
+                     ResidualConvolution(
+                         emb_dim=emb_dim,
+                         rnn_dim=rnn_dim,
+                         bidirectional=bidirectional,
+                         dropout=dropout,
+                     )
+                     for _ in range(2 * n_modules)
+                 ]
+             )
+         )
+ 
+     def forward(self, z):
+         # z = (batch, n_bands, n_time, emb_dim)
+ 
+         z = torch.permute(z, (0, 3, 1, 2))  # (batch, emb_dim, n_bands, n_time)
+ 
+         z = self.seqband(z)  # (batch, emb_dim, n_bands, n_time)
+ 
+         z = torch.permute(z, (0, 2, 3, 1))  # (batch, n_bands, n_time, emb_dim)
+ 
+         return z
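A small sanity-check sketch for the RNN time/frequency module above (not part of the diff; the import path is an assumption):

import torch
from separator.models.bandit.core.model.bsrnn.tfmodel import SeqBandModellingModule

tf_model = SeqBandModellingModule(n_modules=2, emb_dim=128, rnn_dim=256)
z = torch.randn(1, 64, 100, 128)  # (batch, n_bands, n_time, emb_dim)
q = tf_model(z)  # alternates band-axis and time-axis RNNs; shape is preserved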
separator/models/bandit/core/model/bsrnn/utils.py ADDED
@@ -0,0 +1,583 @@
1
+ import os
2
+ from abc import abstractmethod
3
+ from typing import Any, Callable
4
+
5
+ import numpy as np
6
+ import torch
7
+ from librosa import hz_to_midi, midi_to_hz
8
+ from torch import Tensor
9
+ from torchaudio import functional as taF
10
+ from spafe.fbanks import bark_fbanks
11
+ from spafe.utils.converters import erb2hz, hz2bark, hz2erb
12
+ from torchaudio.functional.functional import _create_triangular_filterbank
13
+
14
+
15
+ def band_widths_from_specs(band_specs):
16
+ return [e - i for i, e in band_specs]
17
+
18
+
19
+ def check_nonzero_bandwidth(band_specs):
20
+ # pprint(band_specs)
21
+ for fstart, fend in band_specs:
22
+ if fend - fstart <= 0:
23
+ raise ValueError("Bands cannot be zero-width")
24
+
25
+
26
+ def check_no_overlap(band_specs):
+     fend_prev = -1
+     for fstart_curr, fend_curr in band_specs:
+         if fstart_curr <= fend_prev:
+             raise ValueError("Bands cannot overlap")
+         fend_prev = fend_curr  # without this update, overlaps were never detected
31
+
32
+
33
+ def check_no_gap(band_specs):
34
+ fstart, _ = band_specs[0]
35
+ assert fstart == 0
36
+
37
+ fend_prev = -1
38
+ for fstart_curr, fend_curr in band_specs:
39
+ if fstart_curr - fend_prev > 1:
40
+ raise ValueError("Bands cannot leave gap")
41
+ fend_prev = fend_curr
42
+
43
+
44
+ class BandsplitSpecification:
45
+ def __init__(self, nfft: int, fs: int) -> None:
46
+ self.fs = fs
47
+ self.nfft = nfft
48
+ self.nyquist = fs / 2
49
+ self.max_index = nfft // 2 + 1
50
+
51
+ self.split500 = self.hertz_to_index(500)
52
+ self.split1k = self.hertz_to_index(1000)
53
+ self.split2k = self.hertz_to_index(2000)
54
+ self.split4k = self.hertz_to_index(4000)
55
+ self.split8k = self.hertz_to_index(8000)
56
+ self.split16k = self.hertz_to_index(16000)
57
+ self.split20k = self.hertz_to_index(20000)
58
+
59
+ self.above20k = [(self.split20k, self.max_index)]
60
+ self.above16k = [(self.split16k, self.split20k)] + self.above20k
61
+
62
+ def index_to_hertz(self, index: int):
63
+ return index * self.fs / self.nfft
64
+
65
+ def hertz_to_index(self, hz: float, round: bool = True):
66
+ index = hz * self.nfft / self.fs
67
+
68
+ if round:
69
+ index = int(np.round(index))
70
+
71
+ return index
72
+
73
+ def get_band_specs_with_bandwidth(
74
+ self,
75
+ start_index,
76
+ end_index,
77
+ bandwidth_hz
78
+ ):
79
+ band_specs = []
80
+ lower = start_index
81
+
82
+ while lower < end_index:
83
+ upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz)))
84
+ upper = min(upper, end_index)
85
+
86
+ band_specs.append((lower, upper))
87
+ lower = upper
88
+
89
+ return band_specs
90
+
91
+ @abstractmethod
92
+ def get_band_specs(self):
93
+ raise NotImplementedError
94
+
95
+
96
+ class VocalBandsplitSpecification(BandsplitSpecification):
97
+ def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
98
+ super().__init__(nfft=nfft, fs=fs)
99
+
100
+ self.version = version
101
+
102
+ def get_band_specs(self):
103
+ return getattr(self, f"version{self.version}")()
104
+
105
+ @property
106
+ def version1(self):
107
+ return self.get_band_specs_with_bandwidth(
108
+ start_index=0, end_index=self.max_index, bandwidth_hz=1000
109
+ )
110
+
111
+ def version2(self):
112
+ below16k = self.get_band_specs_with_bandwidth(
113
+ start_index=0, end_index=self.split16k, bandwidth_hz=1000
114
+ )
115
+ below20k = self.get_band_specs_with_bandwidth(
116
+ start_index=self.split16k,
117
+ end_index=self.split20k,
118
+ bandwidth_hz=2000
119
+ )
120
+
121
+ return below16k + below20k + self.above20k
122
+
123
+ def version3(self):
124
+ below8k = self.get_band_specs_with_bandwidth(
125
+ start_index=0, end_index=self.split8k, bandwidth_hz=1000
126
+ )
127
+ below16k = self.get_band_specs_with_bandwidth(
128
+ start_index=self.split8k,
129
+ end_index=self.split16k,
130
+ bandwidth_hz=2000
131
+ )
132
+
133
+ return below8k + below16k + self.above16k
134
+
135
+ def version4(self):
136
+ below1k = self.get_band_specs_with_bandwidth(
137
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
138
+ )
139
+ below8k = self.get_band_specs_with_bandwidth(
140
+ start_index=self.split1k,
141
+ end_index=self.split8k,
142
+ bandwidth_hz=1000
143
+ )
144
+ below16k = self.get_band_specs_with_bandwidth(
145
+ start_index=self.split8k,
146
+ end_index=self.split16k,
147
+ bandwidth_hz=2000
148
+ )
149
+
150
+ return below1k + below8k + below16k + self.above16k
151
+
152
+ def version5(self):
153
+ below1k = self.get_band_specs_with_bandwidth(
154
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
155
+ )
156
+ below16k = self.get_band_specs_with_bandwidth(
157
+ start_index=self.split1k,
158
+ end_index=self.split16k,
159
+ bandwidth_hz=1000
160
+ )
161
+ below20k = self.get_band_specs_with_bandwidth(
162
+ start_index=self.split16k,
163
+ end_index=self.split20k,
164
+ bandwidth_hz=2000
165
+ )
166
+ return below1k + below16k + below20k + self.above20k
167
+
168
+ def version6(self):
169
+ below1k = self.get_band_specs_with_bandwidth(
170
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
171
+ )
172
+ below4k = self.get_band_specs_with_bandwidth(
173
+ start_index=self.split1k,
174
+ end_index=self.split4k,
175
+ bandwidth_hz=500
176
+ )
177
+ below8k = self.get_band_specs_with_bandwidth(
178
+ start_index=self.split4k,
179
+ end_index=self.split8k,
180
+ bandwidth_hz=1000
181
+ )
182
+ below16k = self.get_band_specs_with_bandwidth(
183
+ start_index=self.split8k,
184
+ end_index=self.split16k,
185
+ bandwidth_hz=2000
186
+ )
187
+ return below1k + below4k + below8k + below16k + self.above16k
188
+
189
+ def version7(self):
190
+ below1k = self.get_band_specs_with_bandwidth(
191
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
192
+ )
193
+ below4k = self.get_band_specs_with_bandwidth(
194
+ start_index=self.split1k,
195
+ end_index=self.split4k,
196
+ bandwidth_hz=250
197
+ )
198
+ below8k = self.get_band_specs_with_bandwidth(
199
+ start_index=self.split4k,
200
+ end_index=self.split8k,
201
+ bandwidth_hz=500
202
+ )
203
+ below16k = self.get_band_specs_with_bandwidth(
204
+ start_index=self.split8k,
205
+ end_index=self.split16k,
206
+ bandwidth_hz=1000
207
+ )
208
+ below20k = self.get_band_specs_with_bandwidth(
209
+ start_index=self.split16k,
210
+ end_index=self.split20k,
211
+ bandwidth_hz=2000
212
+ )
213
+ return below1k + below4k + below8k + below16k + below20k + self.above20k
214
+
215
+
216
+ class OtherBandsplitSpecification(VocalBandsplitSpecification):
217
+ def __init__(self, nfft: int, fs: int) -> None:
218
+ super().__init__(nfft=nfft, fs=fs, version="7")
219
+
220
+
221
+ class BassBandsplitSpecification(BandsplitSpecification):
222
+ def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
223
+ super().__init__(nfft=nfft, fs=fs)
224
+
225
+ def get_band_specs(self):
226
+ below500 = self.get_band_specs_with_bandwidth(
227
+ start_index=0, end_index=self.split500, bandwidth_hz=50
228
+ )
229
+ below1k = self.get_band_specs_with_bandwidth(
230
+ start_index=self.split500,
231
+ end_index=self.split1k,
232
+ bandwidth_hz=100
233
+ )
234
+ below4k = self.get_band_specs_with_bandwidth(
235
+ start_index=self.split1k,
236
+ end_index=self.split4k,
237
+ bandwidth_hz=500
238
+ )
239
+ below8k = self.get_band_specs_with_bandwidth(
240
+ start_index=self.split4k,
241
+ end_index=self.split8k,
242
+ bandwidth_hz=1000
243
+ )
244
+ below16k = self.get_band_specs_with_bandwidth(
245
+ start_index=self.split8k,
246
+ end_index=self.split16k,
247
+ bandwidth_hz=2000
248
+ )
249
+ above16k = [(self.split16k, self.max_index)]
250
+
251
+ return below500 + below1k + below4k + below8k + below16k + above16k
252
+
253
+
254
+ class DrumBandsplitSpecification(BandsplitSpecification):
255
+ def __init__(self, nfft: int, fs: int) -> None:
256
+ super().__init__(nfft=nfft, fs=fs)
257
+
258
+ def get_band_specs(self):
259
+ below1k = self.get_band_specs_with_bandwidth(
260
+ start_index=0, end_index=self.split1k, bandwidth_hz=50
261
+ )
262
+ below2k = self.get_band_specs_with_bandwidth(
263
+ start_index=self.split1k,
264
+ end_index=self.split2k,
265
+ bandwidth_hz=100
266
+ )
267
+ below4k = self.get_band_specs_with_bandwidth(
268
+ start_index=self.split2k,
269
+ end_index=self.split4k,
270
+ bandwidth_hz=250
271
+ )
272
+ below8k = self.get_band_specs_with_bandwidth(
273
+ start_index=self.split4k,
274
+ end_index=self.split8k,
275
+ bandwidth_hz=500
276
+ )
277
+ below16k = self.get_band_specs_with_bandwidth(
278
+ start_index=self.split8k,
279
+ end_index=self.split16k,
280
+ bandwidth_hz=1000
281
+ )
282
+ above16k = [(self.split16k, self.max_index)]
283
+
284
+ return below1k + below2k + below4k + below8k + below16k + above16k
285
+
286
+
287
+
288
+
289
+ class PerceptualBandsplitSpecification(BandsplitSpecification):
290
+ def __init__(
291
+ self,
292
+ nfft: int,
293
+ fs: int,
294
+ fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
295
+ n_bands: int,
296
+ f_min: float = 0.0,
297
+ f_max: float = None
298
+ ) -> None:
299
+ super().__init__(nfft=nfft, fs=fs)
300
+ self.n_bands = n_bands
301
+ if f_max is None:
302
+ f_max = fs / 2
303
+
304
+ self.filterbank = fbank_fn(
305
+ n_bands, fs, f_min, f_max, self.max_index
306
+ )
307
+
308
+ weight_per_bin = torch.sum(
309
+ self.filterbank,
310
+ dim=0,
311
+ keepdim=True
312
+ ) # (1, n_freqs)
313
+ normalized_mel_fb = self.filterbank / weight_per_bin # (n_mels, n_freqs)
314
+
315
+ freq_weights = []
316
+ band_specs = []
317
+ for i in range(self.n_bands):
318
+ active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
319
+ if isinstance(active_bins, int):
320
+ active_bins = (active_bins, active_bins)
321
+ if len(active_bins) == 0:
322
+ continue
323
+ start_index = active_bins[0]
324
+ end_index = active_bins[-1] + 1
325
+ band_specs.append((start_index, end_index))
326
+ freq_weights.append(normalized_mel_fb[i, start_index:end_index])
327
+
328
+ self.freq_weights = freq_weights
329
+ self.band_specs = band_specs
330
+
331
+ def get_band_specs(self):
332
+ return self.band_specs
333
+
334
+ def get_freq_weights(self):
335
+ return self.freq_weights
336
+
337
+ def save_to_file(self, dir_path: str) -> None:
338
+
339
+ os.makedirs(dir_path, exist_ok=True)
340
+
341
+ import pickle
342
+
343
+ with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
344
+ pickle.dump(
345
+ {
346
+ "band_specs": self.band_specs,
347
+ "freq_weights": self.freq_weights,
348
+ "filterbank": self.filterbank,
349
+ },
350
+ f,
351
+ )
352
+
353
+ def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
354
+ fb = taF.melscale_fbanks(
355
+ n_mels=n_bands,
356
+ sample_rate=fs,
357
+ f_min=f_min,
358
+ f_max=f_max,
359
+ n_freqs=n_freqs,
360
+ ).T
361
+
362
+ fb[0, 0] = 1.0
363
+
364
+ return fb
365
+
366
+
367
+ class MelBandsplitSpecification(PerceptualBandsplitSpecification):
368
+ def __init__(
369
+ self,
370
+ nfft: int,
371
+ fs: int,
372
+ n_bands: int,
373
+ f_min: float = 0.0,
374
+ f_max: float = None
375
+ ) -> None:
376
+ super().__init__(fbank_fn=mel_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
377
+
378
+ def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs,
379
+ scale="constant"):
380
+
381
+ nfft = 2 * (n_freqs - 1)
382
+ df = fs / nfft
383
+ # init freqs
384
+ f_max = f_max or fs / 2
385
+ f_min = f_min or 0
386
+ f_min = fs / nfft
387
+
388
+ n_octaves = np.log2(f_max / f_min)
389
+ n_octaves_per_band = n_octaves / n_bands
390
+ bandwidth_mult = np.power(2.0, n_octaves_per_band)
391
+
392
+ low_midi = max(0, hz_to_midi(f_min))
393
+ high_midi = hz_to_midi(f_max)
394
+ midi_points = np.linspace(low_midi, high_midi, n_bands)
395
+ hz_pts = midi_to_hz(midi_points)
396
+
397
+ low_pts = hz_pts / bandwidth_mult
398
+ high_pts = hz_pts * bandwidth_mult
399
+
400
+ low_bins = np.floor(low_pts / df).astype(int)
401
+ high_bins = np.ceil(high_pts / df).astype(int)
402
+
403
+ fb = np.zeros((n_bands, n_freqs))
404
+
405
+ for i in range(n_bands):
406
+ fb[i, low_bins[i]:high_bins[i]+1] = 1.0
407
+
408
+ fb[0, :low_bins[0]] = 1.0
409
+ fb[-1, high_bins[-1]+1:] = 1.0
410
+
411
+ return torch.as_tensor(fb)
412
+
413
+ class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
414
+ def __init__(
415
+ self,
416
+ nfft: int,
417
+ fs: int,
418
+ n_bands: int,
419
+ f_min: float = 0.0,
420
+ f_max: float = None
421
+ ) -> None:
422
+ super().__init__(fbank_fn=musical_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
423
+
424
+
425
+ def bark_filterbank(
426
+ n_bands, fs, f_min, f_max, n_freqs
427
+ ):
428
+ nfft = 2 * (n_freqs -1)
429
+ fb, _ = bark_fbanks.bark_filter_banks(
430
+ nfilts=n_bands,
431
+ nfft=nfft,
432
+ fs=fs,
433
+ low_freq=f_min,
434
+ high_freq=f_max,
435
+ scale="constant"
436
+ )
437
+
438
+ return torch.as_tensor(fb)
439
+
440
+ class BarkBandsplitSpecification(PerceptualBandsplitSpecification):
441
+ def __init__(
442
+ self,
443
+ nfft: int,
444
+ fs: int,
445
+ n_bands: int,
446
+ f_min: float = 0.0,
447
+ f_max: float = None
448
+ ) -> None:
449
+ super().__init__(fbank_fn=bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
450
+
451
+
452
+ def triangular_bark_filterbank(
453
+ n_bands, fs, f_min, f_max, n_freqs
454
+ ):
455
+
456
+ all_freqs = torch.linspace(0, fs // 2, n_freqs)
457
+
458
+ # calculate mel freq bins
459
+ m_min = hz2bark(f_min)
460
+ m_max = hz2bark(f_max)
461
+
462
+ m_pts = torch.linspace(m_min, m_max, n_bands + 2)
463
+ f_pts = 600 * torch.sinh(m_pts / 6)
464
+
465
+ # create filterbank
466
+ fb = _create_triangular_filterbank(all_freqs, f_pts)
467
+
468
+ fb = fb.T
469
+
470
+ first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
471
+ first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
472
+
473
+ fb[first_active_band, :first_active_bin] = 1.0
474
+
475
+ return fb
476
+
477
+ class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification):
478
+ def __init__(
479
+ self,
480
+ nfft: int,
481
+ fs: int,
482
+ n_bands: int,
483
+ f_min: float = 0.0,
484
+ f_max: float = None
485
+ ) -> None:
486
+ super().__init__(fbank_fn=triangular_bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
487
+
488
+
489
+
490
+ def minibark_filterbank(
491
+ n_bands, fs, f_min, f_max, n_freqs
492
+ ):
493
+ fb = bark_filterbank(
494
+ n_bands,
495
+ fs,
496
+ f_min,
497
+ f_max,
498
+ n_freqs
499
+ )
500
+
501
+ fb[fb < np.sqrt(0.5)] = 0.0
502
+
503
+ return fb
504
+
505
+ class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification):
506
+ def __init__(
507
+ self,
508
+ nfft: int,
509
+ fs: int,
510
+ n_bands: int,
511
+ f_min: float = 0.0,
512
+ f_max: float = None
513
+ ) -> None:
514
+ super().__init__(fbank_fn=minibark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
515
+
516
+
517
+
518
+
519
+
520
+ def erb_filterbank(
521
+ n_bands: int,
522
+ fs: int,
523
+ f_min: float,
524
+ f_max: float,
525
+ n_freqs: int,
526
+ ) -> Tensor:
527
+ # freq bins
528
+ A = (1000 * np.log(10)) / (24.7 * 4.37)
529
+ all_freqs = torch.linspace(0, fs // 2, n_freqs)
530
+
531
+ # calculate mel freq bins
532
+ m_min = hz2erb(f_min)
533
+ m_max = hz2erb(f_max)
534
+
535
+ m_pts = torch.linspace(m_min, m_max, n_bands + 2)
536
+ f_pts = (torch.pow(10, (m_pts / A)) - 1)/ 0.00437
537
+
538
+ # create filterbank
539
+ fb = _create_triangular_filterbank(all_freqs, f_pts)
540
+
541
+ fb = fb.T
542
+
543
+
544
+ first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
545
+ first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
546
+
547
+ fb[first_active_band, :first_active_bin] = 1.0
548
+
549
+ return fb
550
+
551
+
552
+
553
+ class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification):
554
+ def __init__(
555
+ self,
556
+ nfft: int,
557
+ fs: int,
558
+ n_bands: int,
559
+ f_min: float = 0.0,
560
+ f_max: float = None
561
+ ) -> None:
562
+ super().__init__(fbank_fn=erb_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
563
+
564
+ if __name__ == "__main__":
565
+ import pandas as pd
566
+
567
+ band_defs = []
568
+
569
+ for bands in [VocalBandsplitSpecification]:
570
+ band_name = bands.__name__.replace("BandsplitSpecification", "")
571
+
572
+ mbs = bands(nfft=2048, fs=44100).get_band_specs()
573
+
574
+ for i, (f_min, f_max) in enumerate(mbs):
575
+ band_defs.append({
576
+ "band": band_name,
577
+ "band_index": i,
578
+ "f_min": f_min,
579
+ "f_max": f_max
580
+ })
581
+
582
+ df = pd.DataFrame(band_defs)
583
+ df.to_csv("vox7bands.csv", index=False)
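A short sketch of how these band-split specifications might be queried (not part of the diff; the import path is an assumption):

from separator.models.bandit.core.model.bsrnn.utils import (
    MelBandsplitSpecification,
    VocalBandsplitSpecification,
)

vox = VocalBandsplitSpecification(nfft=2048, fs=44100, version="7")
print(len(vox.get_band_specs()))  # non-overlapping (start_bin, end_bin) pairs

mel = MelBandsplitSpecification(nfft=2048, fs=44100, n_bands=64)
print(len(mel.get_band_specs()), len(mel.get_freq_weights()))  # overlapping bands with per-bin weights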
separator/models/bandit/core/model/bsrnn/wrapper.py ADDED
@@ -0,0 +1,882 @@
1
+ from pprint import pprint
2
+ from typing import Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ from models.bandit.core.model._spectral import _SpectralComponent
8
+ from models.bandit.core.model.bsrnn.utils import (
9
+ BarkBandsplitSpecification, BassBandsplitSpecification,
10
+ DrumBandsplitSpecification,
11
+ EquivalentRectangularBandsplitSpecification, MelBandsplitSpecification,
12
+ MusicalBandsplitSpecification, OtherBandsplitSpecification,
13
+ TriangularBarkBandsplitSpecification, VocalBandsplitSpecification,
14
+ )
15
+ from .core import (
16
+ MultiSourceMultiMaskBandSplitCoreConv,
17
+ MultiSourceMultiMaskBandSplitCoreRNN,
18
+ MultiSourceMultiMaskBandSplitCoreTransformer,
19
+ MultiSourceMultiPatchingMaskBandSplitCoreRNN, SingleMaskBandsplitCoreRNN,
20
+ SingleMaskBandsplitCoreTransformer,
21
+ )
22
+
23
+ import pytorch_lightning as pl
24
+
25
+ def get_band_specs(band_specs, n_fft, fs, n_bands=None):
26
+ if band_specs in ["dnr:speech", "dnr:vox7", "musdb:vocals", "musdb:vox7"]:
27
+ bsm = VocalBandsplitSpecification(
28
+ nfft=n_fft, fs=fs
29
+ ).get_band_specs()
30
+ freq_weights = None
31
+ overlapping_band = False
32
+ elif "tribark" in band_specs:
33
+ assert n_bands is not None
34
+ specs = TriangularBarkBandsplitSpecification(
35
+ nfft=n_fft,
36
+ fs=fs,
37
+ n_bands=n_bands
38
+ )
39
+ bsm = specs.get_band_specs()
40
+ freq_weights = specs.get_freq_weights()
41
+ overlapping_band = True
42
+ elif "bark" in band_specs:
43
+ assert n_bands is not None
44
+ specs = BarkBandsplitSpecification(
45
+ nfft=n_fft,
46
+ fs=fs,
47
+ n_bands=n_bands
48
+ )
49
+ bsm = specs.get_band_specs()
50
+ freq_weights = specs.get_freq_weights()
51
+ overlapping_band = True
52
+ elif "erb" in band_specs:
53
+ assert n_bands is not None
54
+ specs = EquivalentRectangularBandsplitSpecification(
55
+ nfft=n_fft,
56
+ fs=fs,
57
+ n_bands=n_bands
58
+ )
59
+ bsm = specs.get_band_specs()
60
+ freq_weights = specs.get_freq_weights()
61
+ overlapping_band = True
62
+ elif "musical" in band_specs:
63
+ assert n_bands is not None
64
+ specs = MusicalBandsplitSpecification(
65
+ nfft=n_fft,
66
+ fs=fs,
67
+ n_bands=n_bands
68
+ )
69
+ bsm = specs.get_band_specs()
70
+ freq_weights = specs.get_freq_weights()
71
+ overlapping_band = True
72
+ elif band_specs == "dnr:mel" or "mel" in band_specs:
73
+ assert n_bands is not None
74
+ specs = MelBandsplitSpecification(
75
+ nfft=n_fft,
76
+ fs=fs,
77
+ n_bands=n_bands
78
+ )
79
+ bsm = specs.get_band_specs()
80
+ freq_weights = specs.get_freq_weights()
81
+ overlapping_band = True
82
+ else:
83
+ raise NameError
84
+
85
+ return bsm, freq_weights, overlapping_band
86
+
87
+
88
+ def get_band_specs_map(band_specs_map, n_fft, fs, n_bands=None):
89
+ if band_specs_map == "musdb:all":
90
+ bsm = {
91
+ "vocals": VocalBandsplitSpecification(
92
+ nfft=n_fft, fs=fs
93
+ ).get_band_specs(),
94
+ "drums": DrumBandsplitSpecification(
95
+ nfft=n_fft, fs=fs
96
+ ).get_band_specs(),
97
+ "bass": BassBandsplitSpecification(
98
+ nfft=n_fft, fs=fs
99
+ ).get_band_specs(),
100
+ "other": OtherBandsplitSpecification(
101
+ nfft=n_fft, fs=fs
102
+ ).get_band_specs(),
103
+ }
104
+ freq_weights = None
105
+ overlapping_band = False
106
+ elif band_specs_map == "dnr:vox7":
107
+ bsm_, freq_weights, overlapping_band = get_band_specs(
108
+ "dnr:speech", n_fft, fs, n_bands
109
+ )
110
+ bsm = {
111
+ "speech": bsm_,
112
+ "music": bsm_,
113
+ "effects": bsm_
114
+ }
115
+ elif "dnr:vox7:" in band_specs_map:
116
+ stem = band_specs_map.split(":")[-1]
117
+ bsm_, freq_weights, overlapping_band = get_band_specs(
118
+ "dnr:speech", n_fft, fs, n_bands
119
+ )
120
+ bsm = {
121
+ stem: bsm_
122
+ }
123
+ else:
124
+ raise NameError
125
+
126
+ return bsm, freq_weights, overlapping_band
127
+
128
+
129
+ class BandSplitWrapperBase(pl.LightningModule):
130
+ bsrnn: nn.Module
131
+
132
+ def __init__(self, **kwargs):
133
+ super().__init__()
134
+
135
+
136
+ class SingleMaskMultiSourceBandSplitBase(
137
+ BandSplitWrapperBase,
138
+ _SpectralComponent
139
+ ):
140
+ def __init__(
141
+ self,
142
+ band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
143
+ fs: int = 44100,
144
+ n_fft: int = 2048,
145
+ win_length: Optional[int] = 2048,
146
+ hop_length: int = 512,
147
+ window_fn: str = "hann_window",
148
+ wkwargs: Optional[Dict] = None,
149
+ power: Optional[int] = None,
150
+ center: bool = True,
151
+ normalized: bool = True,
152
+ pad_mode: str = "constant",
153
+ onesided: bool = True,
154
+ n_bands: int = None,
155
+ ) -> None:
156
+ super().__init__(
157
+ n_fft=n_fft,
158
+ win_length=win_length,
159
+ hop_length=hop_length,
160
+ window_fn=window_fn,
161
+ wkwargs=wkwargs,
162
+ power=power,
163
+ center=center,
164
+ normalized=normalized,
165
+ pad_mode=pad_mode,
166
+ onesided=onesided,
167
+ )
168
+
169
+ if isinstance(band_specs_map, str):
170
+ self.band_specs_map, self.freq_weights, self.overlapping_band = get_band_specs_map(
171
+ band_specs_map,
172
+ n_fft,
173
+ fs,
174
+ n_bands=n_bands
175
+ )
176
+
177
+ self.stems = list(self.band_specs_map.keys())
178
+
179
+ def forward(self, batch):
180
+ audio = batch["audio"]
181
+
182
+ with torch.no_grad():
183
+ batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in
184
+ audio}
185
+
186
+ X = batch["spectrogram"]["mixture"]
187
+ length = batch["audio"]["mixture"].shape[-1]
188
+
189
+ output = {"spectrogram": {}, "audio": {}}
190
+
191
+ for stem, bsrnn in self.bsrnn.items():
192
+ S = bsrnn(X)
193
+ s = self.istft(S, length)
194
+ output["spectrogram"][stem] = S
195
+ output["audio"][stem] = s
196
+
197
+ return batch, output
198
+
199
+
200
+ class MultiMaskMultiSourceBandSplitBase(
201
+ BandSplitWrapperBase,
202
+ _SpectralComponent
203
+ ):
204
+ def __init__(
205
+ self,
206
+ stems: List[str],
207
+ band_specs: Union[str, List[Tuple[float, float]]],
208
+ fs: int = 44100,
209
+ n_fft: int = 2048,
210
+ win_length: Optional[int] = 2048,
211
+ hop_length: int = 512,
212
+ window_fn: str = "hann_window",
213
+ wkwargs: Optional[Dict] = None,
214
+ power: Optional[int] = None,
215
+ center: bool = True,
216
+ normalized: bool = True,
217
+ pad_mode: str = "constant",
218
+ onesided: bool = True,
219
+ n_bands: int = None,
220
+ ) -> None:
221
+ super().__init__(
222
+ n_fft=n_fft,
223
+ win_length=win_length,
224
+ hop_length=hop_length,
225
+ window_fn=window_fn,
226
+ wkwargs=wkwargs,
227
+ power=power,
228
+ center=center,
229
+ normalized=normalized,
230
+ pad_mode=pad_mode,
231
+ onesided=onesided,
232
+ )
233
+
234
+ if isinstance(band_specs, str):
235
+ self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
236
+ band_specs,
237
+ n_fft,
238
+ fs,
239
+ n_bands
240
+ )
241
+
242
+ self.stems = stems
243
+
244
+ def forward(self, batch):
245
+ # with torch.no_grad():
246
+ audio = batch["audio"]
247
+ cond = batch.get("condition", None)
248
+ with torch.no_grad():
249
+ batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in
250
+ audio}
251
+
252
+ X = batch["spectrogram"]["mixture"]
253
+ length = batch["audio"]["mixture"].shape[-1]
254
+
255
+ output = self.bsrnn(X, cond=cond)
256
+ output["audio"] = {}
257
+
258
+ for stem, S in output["spectrogram"].items():
259
+ s = self.istft(S, length)
260
+ output["audio"][stem] = s
261
+
262
+ return batch, output
263
+
264
+
265
+ class MultiMaskMultiSourceBandSplitBaseSimple(
266
+ BandSplitWrapperBase,
267
+ _SpectralComponent
268
+ ):
269
+ def __init__(
270
+ self,
271
+ stems: List[str],
272
+ band_specs: Union[str, List[Tuple[float, float]]],
273
+ fs: int = 44100,
274
+ n_fft: int = 2048,
275
+ win_length: Optional[int] = 2048,
276
+ hop_length: int = 512,
277
+ window_fn: str = "hann_window",
278
+ wkwargs: Optional[Dict] = None,
279
+ power: Optional[int] = None,
280
+ center: bool = True,
281
+ normalized: bool = True,
282
+ pad_mode: str = "constant",
283
+ onesided: bool = True,
284
+ n_bands: int = None,
285
+ ) -> None:
286
+ super().__init__(
287
+ n_fft=n_fft,
288
+ win_length=win_length,
289
+ hop_length=hop_length,
290
+ window_fn=window_fn,
291
+ wkwargs=wkwargs,
292
+ power=power,
293
+ center=center,
294
+ normalized=normalized,
295
+ pad_mode=pad_mode,
296
+ onesided=onesided,
297
+ )
298
+
299
+ if isinstance(band_specs, str):
300
+ self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
301
+ band_specs,
302
+ n_fft,
303
+ fs,
304
+ n_bands
305
+ )
306
+
307
+ self.stems = stems
308
+
309
+ def forward(self, batch):
310
+ with torch.no_grad():
311
+ X = self.stft(batch)
312
+ length = batch.shape[-1]
313
+ output = self.bsrnn(X, cond=None)
314
+ res = []
315
+ for stem, S in output["spectrogram"].items():
316
+ s = self.istft(S, length)
317
+ res.append(s)
318
+ res = torch.stack(res, dim=1)
319
+ return res
320
+
321
+
322
+ class SingleMaskMultiSourceBandSplitRNN(SingleMaskMultiSourceBandSplitBase):
323
+ def __init__(
324
+ self,
325
+ in_channel: int,
326
+ band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
327
+ fs: int = 44100,
328
+ require_no_overlap: bool = False,
329
+ require_no_gap: bool = True,
330
+ normalize_channel_independently: bool = False,
331
+ treat_channel_as_feature: bool = True,
332
+ n_sqm_modules: int = 12,
333
+ emb_dim: int = 128,
334
+ rnn_dim: int = 256,
335
+ bidirectional: bool = True,
336
+ rnn_type: str = "LSTM",
337
+ mlp_dim: int = 512,
338
+ hidden_activation: str = "Tanh",
339
+ hidden_activation_kwargs: Optional[Dict] = None,
340
+ complex_mask: bool = True,
341
+ n_fft: int = 2048,
342
+ win_length: Optional[int] = 2048,
343
+ hop_length: int = 512,
344
+ window_fn: str = "hann_window",
345
+ wkwargs: Optional[Dict] = None,
346
+ power: Optional[int] = None,
347
+ center: bool = True,
348
+ normalized: bool = True,
349
+ pad_mode: str = "constant",
350
+ onesided: bool = True,
351
+ ) -> None:
352
+ super().__init__(
353
+ band_specs_map=band_specs_map,
354
+ fs=fs,
355
+ n_fft=n_fft,
356
+ win_length=win_length,
357
+ hop_length=hop_length,
358
+ window_fn=window_fn,
359
+ wkwargs=wkwargs,
360
+ power=power,
361
+ center=center,
362
+ normalized=normalized,
363
+ pad_mode=pad_mode,
364
+ onesided=onesided,
365
+ )
366
+
367
+ self.bsrnn = nn.ModuleDict(
368
+ {
369
+ src: SingleMaskBandsplitCoreRNN(
370
+ band_specs=specs,
371
+ in_channel=in_channel,
372
+ require_no_overlap=require_no_overlap,
373
+ require_no_gap=require_no_gap,
374
+ normalize_channel_independently=normalize_channel_independently,
375
+ treat_channel_as_feature=treat_channel_as_feature,
376
+ n_sqm_modules=n_sqm_modules,
377
+ emb_dim=emb_dim,
378
+ rnn_dim=rnn_dim,
379
+ bidirectional=bidirectional,
380
+ rnn_type=rnn_type,
381
+ mlp_dim=mlp_dim,
382
+ hidden_activation=hidden_activation,
383
+ hidden_activation_kwargs=hidden_activation_kwargs,
384
+ complex_mask=complex_mask,
385
+ )
386
+ for src, specs in self.band_specs_map.items()
387
+ }
388
+ )
389
+
390
+
391
+ class SingleMaskMultiSourceBandSplitTransformer(
392
+ SingleMaskMultiSourceBandSplitBase
393
+ ):
394
+ def __init__(
395
+ self,
396
+ in_channel: int,
397
+ band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
398
+ fs: int = 44100,
399
+ require_no_overlap: bool = False,
400
+ require_no_gap: bool = True,
401
+ normalize_channel_independently: bool = False,
402
+ treat_channel_as_feature: bool = True,
403
+ n_sqm_modules: int = 12,
404
+ emb_dim: int = 128,
405
+ rnn_dim: int = 256,
406
+ bidirectional: bool = True,
407
+ tf_dropout: float = 0.0,
408
+ mlp_dim: int = 512,
409
+ hidden_activation: str = "Tanh",
410
+ hidden_activation_kwargs: Optional[Dict] = None,
411
+ complex_mask: bool = True,
412
+ n_fft: int = 2048,
413
+ win_length: Optional[int] = 2048,
414
+ hop_length: int = 512,
415
+ window_fn: str = "hann_window",
416
+ wkwargs: Optional[Dict] = None,
417
+ power: Optional[int] = None,
418
+ center: bool = True,
419
+ normalized: bool = True,
420
+ pad_mode: str = "constant",
421
+ onesided: bool = True,
422
+ ) -> None:
423
+ super().__init__(
424
+ band_specs_map=band_specs_map,
425
+ fs=fs,
426
+ n_fft=n_fft,
427
+ win_length=win_length,
428
+ hop_length=hop_length,
429
+ window_fn=window_fn,
430
+ wkwargs=wkwargs,
431
+ power=power,
432
+ center=center,
433
+ normalized=normalized,
434
+ pad_mode=pad_mode,
435
+ onesided=onesided,
436
+ )
437
+
438
+ self.bsrnn = nn.ModuleDict(
439
+ {
440
+ src: SingleMaskBandsplitCoreTransformer(
441
+ band_specs=specs,
442
+ in_channel=in_channel,
443
+ require_no_overlap=require_no_overlap,
444
+ require_no_gap=require_no_gap,
445
+ normalize_channel_independently=normalize_channel_independently,
446
+ treat_channel_as_feature=treat_channel_as_feature,
447
+ n_sqm_modules=n_sqm_modules,
448
+ emb_dim=emb_dim,
449
+ rnn_dim=rnn_dim,
450
+ bidirectional=bidirectional,
451
+ tf_dropout=tf_dropout,
452
+ mlp_dim=mlp_dim,
453
+ hidden_activation=hidden_activation,
454
+ hidden_activation_kwargs=hidden_activation_kwargs,
455
+ complex_mask=complex_mask,
456
+ )
457
+ for src, specs in self.band_specs_map.items()
458
+ }
459
+ )
460
+
461
+
462
+ class MultiMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
463
+ def __init__(
464
+ self,
465
+ in_channel: int,
466
+ stems: List[str],
467
+ band_specs: Union[str, List[Tuple[float, float]]],
468
+ fs: int = 44100,
469
+ require_no_overlap: bool = False,
470
+ require_no_gap: bool = True,
471
+ normalize_channel_independently: bool = False,
472
+ treat_channel_as_feature: bool = True,
473
+ n_sqm_modules: int = 12,
474
+ emb_dim: int = 128,
475
+ rnn_dim: int = 256,
476
+ cond_dim: int = 0,
477
+ bidirectional: bool = True,
478
+ rnn_type: str = "LSTM",
479
+ mlp_dim: int = 512,
480
+ hidden_activation: str = "Tanh",
481
+ hidden_activation_kwargs: Optional[Dict] = None,
482
+ complex_mask: bool = True,
483
+ n_fft: int = 2048,
484
+ win_length: Optional[int] = 2048,
485
+ hop_length: int = 512,
486
+ window_fn: str = "hann_window",
487
+ wkwargs: Optional[Dict] = None,
488
+ power: Optional[int] = None,
489
+ center: bool = True,
490
+ normalized: bool = True,
491
+ pad_mode: str = "constant",
492
+ onesided: bool = True,
493
+ n_bands: int = None,
494
+ use_freq_weights: bool = True,
495
+ normalize_input: bool = False,
496
+ mult_add_mask: bool = False,
497
+ freeze_encoder: bool = False,
498
+ ) -> None:
499
+ super().__init__(
500
+ stems=stems,
501
+ band_specs=band_specs,
502
+ fs=fs,
503
+ n_fft=n_fft,
504
+ win_length=win_length,
505
+ hop_length=hop_length,
506
+ window_fn=window_fn,
507
+ wkwargs=wkwargs,
508
+ power=power,
509
+ center=center,
510
+ normalized=normalized,
511
+ pad_mode=pad_mode,
512
+ onesided=onesided,
513
+ n_bands=n_bands,
514
+ )
515
+
516
+ self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
517
+ stems=stems,
518
+ band_specs=self.band_specs,
519
+ in_channel=in_channel,
520
+ require_no_overlap=require_no_overlap,
521
+ require_no_gap=require_no_gap,
522
+ normalize_channel_independently=normalize_channel_independently,
523
+ treat_channel_as_feature=treat_channel_as_feature,
524
+ n_sqm_modules=n_sqm_modules,
525
+ emb_dim=emb_dim,
526
+ rnn_dim=rnn_dim,
527
+ bidirectional=bidirectional,
528
+ rnn_type=rnn_type,
529
+ mlp_dim=mlp_dim,
530
+ cond_dim=cond_dim,
531
+ hidden_activation=hidden_activation,
532
+ hidden_activation_kwargs=hidden_activation_kwargs,
533
+ complex_mask=complex_mask,
534
+ overlapping_band=self.overlapping_band,
535
+ freq_weights=self.freq_weights,
536
+ n_freq=n_fft // 2 + 1,
537
+ use_freq_weights=use_freq_weights,
538
+ mult_add_mask=mult_add_mask
539
+ )
540
+
541
+ self.normalize_input = normalize_input
542
+ self.cond_dim = cond_dim
543
+
544
+ if freeze_encoder:
545
+ for param in self.bsrnn.band_split.parameters():
546
+ param.requires_grad = False
547
+
548
+ for param in self.bsrnn.tf_model.parameters():
549
+ param.requires_grad = False
550
+
551
+
552
+ class MultiMaskMultiSourceBandSplitRNNSimple(MultiMaskMultiSourceBandSplitBaseSimple):
553
+ def __init__(
554
+ self,
555
+ in_channel: int,
556
+ stems: List[str],
557
+ band_specs: Union[str, List[Tuple[float, float]]],
558
+ fs: int = 44100,
559
+ require_no_overlap: bool = False,
560
+ require_no_gap: bool = True,
561
+ normalize_channel_independently: bool = False,
562
+ treat_channel_as_feature: bool = True,
563
+ n_sqm_modules: int = 12,
564
+ emb_dim: int = 128,
565
+ rnn_dim: int = 256,
566
+ cond_dim: int = 0,
567
+ bidirectional: bool = True,
568
+ rnn_type: str = "LSTM",
569
+ mlp_dim: int = 512,
570
+ hidden_activation: str = "Tanh",
571
+ hidden_activation_kwargs: Optional[Dict] = None,
572
+ complex_mask: bool = True,
573
+ n_fft: int = 2048,
574
+ win_length: Optional[int] = 2048,
575
+ hop_length: int = 512,
576
+ window_fn: str = "hann_window",
577
+ wkwargs: Optional[Dict] = None,
578
+ power: Optional[int] = None,
579
+ center: bool = True,
580
+ normalized: bool = True,
581
+ pad_mode: str = "constant",
582
+ onesided: bool = True,
583
+ n_bands: int = None,
584
+ use_freq_weights: bool = True,
585
+ normalize_input: bool = False,
586
+ mult_add_mask: bool = False,
587
+ freeze_encoder: bool = False,
588
+ ) -> None:
589
+ super().__init__(
590
+ stems=stems,
591
+ band_specs=band_specs,
592
+ fs=fs,
593
+ n_fft=n_fft,
594
+ win_length=win_length,
595
+ hop_length=hop_length,
596
+ window_fn=window_fn,
597
+ wkwargs=wkwargs,
598
+ power=power,
599
+ center=center,
600
+ normalized=normalized,
601
+ pad_mode=pad_mode,
602
+ onesided=onesided,
603
+ n_bands=n_bands,
604
+ )
605
+
606
+ self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
607
+ stems=stems,
608
+ band_specs=self.band_specs,
609
+ in_channel=in_channel,
610
+ require_no_overlap=require_no_overlap,
611
+ require_no_gap=require_no_gap,
612
+ normalize_channel_independently=normalize_channel_independently,
613
+ treat_channel_as_feature=treat_channel_as_feature,
614
+ n_sqm_modules=n_sqm_modules,
615
+ emb_dim=emb_dim,
616
+ rnn_dim=rnn_dim,
617
+ bidirectional=bidirectional,
618
+ rnn_type=rnn_type,
619
+ mlp_dim=mlp_dim,
620
+ cond_dim=cond_dim,
621
+ hidden_activation=hidden_activation,
622
+ hidden_activation_kwargs=hidden_activation_kwargs,
623
+ complex_mask=complex_mask,
624
+ overlapping_band=self.overlapping_band,
625
+ freq_weights=self.freq_weights,
626
+ n_freq=n_fft // 2 + 1,
627
+ use_freq_weights=use_freq_weights,
628
+ mult_add_mask=mult_add_mask
629
+ )
630
+
631
+ self.normalize_input = normalize_input
632
+ self.cond_dim = cond_dim
633
+
634
+ if freeze_encoder:
635
+ for param in self.bsrnn.band_split.parameters():
636
+ param.requires_grad = False
637
+
638
+ for param in self.bsrnn.tf_model.parameters():
639
+ param.requires_grad = False
640
+
641
+
642
+ class MultiMaskMultiSourceBandSplitTransformer(
643
+ MultiMaskMultiSourceBandSplitBase
644
+ ):
645
+ def __init__(
646
+ self,
647
+ in_channel: int,
648
+ stems: List[str],
649
+ band_specs: Union[str, List[Tuple[float, float]]],
650
+ fs: int = 44100,
651
+ require_no_overlap: bool = False,
652
+ require_no_gap: bool = True,
653
+ normalize_channel_independently: bool = False,
654
+ treat_channel_as_feature: bool = True,
655
+ n_sqm_modules: int = 12,
656
+ emb_dim: int = 128,
657
+ rnn_dim: int = 256,
658
+ cond_dim: int = 0,
659
+ bidirectional: bool = True,
660
+ rnn_type: str = "LSTM",
661
+ mlp_dim: int = 512,
662
+ hidden_activation: str = "Tanh",
663
+ hidden_activation_kwargs: Optional[Dict] = None,
664
+ complex_mask: bool = True,
665
+ n_fft: int = 2048,
666
+ win_length: Optional[int] = 2048,
667
+ hop_length: int = 512,
668
+ window_fn: str = "hann_window",
669
+ wkwargs: Optional[Dict] = None,
670
+ power: Optional[int] = None,
671
+ center: bool = True,
672
+ normalized: bool = True,
673
+ pad_mode: str = "constant",
674
+ onesided: bool = True,
675
+ n_bands: int = None,
676
+ use_freq_weights: bool = True,
677
+ normalize_input: bool = False,
678
+ mult_add_mask: bool = False
679
+ ) -> None:
680
+ super().__init__(
681
+ stems=stems,
682
+ band_specs=band_specs,
683
+ fs=fs,
684
+ n_fft=n_fft,
685
+ win_length=win_length,
686
+ hop_length=hop_length,
687
+ window_fn=window_fn,
688
+ wkwargs=wkwargs,
689
+ power=power,
690
+ center=center,
691
+ normalized=normalized,
692
+ pad_mode=pad_mode,
693
+ onesided=onesided,
694
+ n_bands=n_bands,
695
+ )
696
+
697
+ self.bsrnn = MultiSourceMultiMaskBandSplitCoreTransformer(
698
+ stems=stems,
699
+ band_specs=self.band_specs,
700
+ in_channel=in_channel,
701
+ require_no_overlap=require_no_overlap,
702
+ require_no_gap=require_no_gap,
703
+ normalize_channel_independently=normalize_channel_independently,
704
+ treat_channel_as_feature=treat_channel_as_feature,
705
+ n_sqm_modules=n_sqm_modules,
706
+ emb_dim=emb_dim,
707
+ rnn_dim=rnn_dim,
708
+ bidirectional=bidirectional,
709
+ rnn_type=rnn_type,
710
+ mlp_dim=mlp_dim,
711
+ cond_dim=cond_dim,
712
+ hidden_activation=hidden_activation,
713
+ hidden_activation_kwargs=hidden_activation_kwargs,
714
+ complex_mask=complex_mask,
715
+ overlapping_band=self.overlapping_band,
716
+ freq_weights=self.freq_weights,
717
+ n_freq=n_fft // 2 + 1,
718
+ use_freq_weights=use_freq_weights,
719
+ mult_add_mask=mult_add_mask
720
+ )
721
+
722
+
723
+
724
+ class MultiMaskMultiSourceBandSplitConv(
725
+ MultiMaskMultiSourceBandSplitBase
726
+ ):
727
+ def __init__(
728
+ self,
729
+ in_channel: int,
730
+ stems: List[str],
731
+ band_specs: Union[str, List[Tuple[float, float]]],
732
+ fs: int = 44100,
733
+ require_no_overlap: bool = False,
734
+ require_no_gap: bool = True,
735
+ normalize_channel_independently: bool = False,
736
+ treat_channel_as_feature: bool = True,
737
+ n_sqm_modules: int = 12,
738
+ emb_dim: int = 128,
739
+ rnn_dim: int = 256,
740
+ cond_dim: int = 0,
741
+ bidirectional: bool = True,
742
+ rnn_type: str = "LSTM",
743
+ mlp_dim: int = 512,
744
+ hidden_activation: str = "Tanh",
745
+ hidden_activation_kwargs: Optional[Dict] = None,
746
+ complex_mask: bool = True,
747
+ n_fft: int = 2048,
748
+ win_length: Optional[int] = 2048,
749
+ hop_length: int = 512,
750
+ window_fn: str = "hann_window",
751
+ wkwargs: Optional[Dict] = None,
752
+ power: Optional[int] = None,
753
+ center: bool = True,
754
+ normalized: bool = True,
755
+ pad_mode: str = "constant",
756
+ onesided: bool = True,
757
+ n_bands: int = None,
758
+ use_freq_weights: bool = True,
759
+ normalize_input: bool = False,
760
+ mult_add_mask: bool = False
761
+ ) -> None:
762
+ super().__init__(
763
+ stems=stems,
764
+ band_specs=band_specs,
765
+ fs=fs,
766
+ n_fft=n_fft,
767
+ win_length=win_length,
768
+ hop_length=hop_length,
769
+ window_fn=window_fn,
770
+ wkwargs=wkwargs,
771
+ power=power,
772
+ center=center,
773
+ normalized=normalized,
774
+ pad_mode=pad_mode,
775
+ onesided=onesided,
776
+ n_bands=n_bands,
777
+ )
778
+
779
+ self.bsrnn = MultiSourceMultiMaskBandSplitCoreConv(
780
+ stems=stems,
781
+ band_specs=self.band_specs,
782
+ in_channel=in_channel,
783
+ require_no_overlap=require_no_overlap,
784
+ require_no_gap=require_no_gap,
785
+ normalize_channel_independently=normalize_channel_independently,
786
+ treat_channel_as_feature=treat_channel_as_feature,
787
+ n_sqm_modules=n_sqm_modules,
788
+ emb_dim=emb_dim,
789
+ rnn_dim=rnn_dim,
790
+ bidirectional=bidirectional,
791
+ rnn_type=rnn_type,
792
+ mlp_dim=mlp_dim,
793
+ cond_dim=cond_dim,
794
+ hidden_activation=hidden_activation,
795
+ hidden_activation_kwargs=hidden_activation_kwargs,
796
+ complex_mask=complex_mask,
797
+ overlapping_band=self.overlapping_band,
798
+ freq_weights=self.freq_weights,
799
+ n_freq=n_fft // 2 + 1,
800
+ use_freq_weights=use_freq_weights,
801
+ mult_add_mask=mult_add_mask
802
+ )
803
+ class PatchingMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
804
+ def __init__(
805
+ self,
806
+ in_channel: int,
807
+ stems: List[str],
808
+ band_specs: Union[str, List[Tuple[float, float]]],
809
+ kernel_norm_mlp_version: int = 1,
810
+ mask_kernel_freq: int = 3,
811
+ mask_kernel_time: int = 3,
812
+ conv_kernel_freq: int = 1,
813
+ conv_kernel_time: int = 1,
814
+ fs: int = 44100,
815
+ require_no_overlap: bool = False,
816
+ require_no_gap: bool = True,
817
+ normalize_channel_independently: bool = False,
818
+ treat_channel_as_feature: bool = True,
819
+ n_sqm_modules: int = 12,
820
+ emb_dim: int = 128,
821
+ rnn_dim: int = 256,
822
+ bidirectional: bool = True,
823
+ rnn_type: str = "LSTM",
824
+ mlp_dim: int = 512,
825
+ hidden_activation: str = "Tanh",
826
+ hidden_activation_kwargs: Optional[Dict] = None,
827
+ complex_mask: bool = True,
828
+ n_fft: int = 2048,
829
+ win_length: Optional[int] = 2048,
830
+ hop_length: int = 512,
831
+ window_fn: str = "hann_window",
832
+ wkwargs: Optional[Dict] = None,
833
+ power: Optional[int] = None,
834
+ center: bool = True,
835
+ normalized: bool = True,
836
+ pad_mode: str = "constant",
837
+ onesided: bool = True,
838
+ n_bands: int = None,
839
+ ) -> None:
840
+ super().__init__(
841
+ stems=stems,
842
+ band_specs=band_specs,
843
+ fs=fs,
844
+ n_fft=n_fft,
845
+ win_length=win_length,
846
+ hop_length=hop_length,
847
+ window_fn=window_fn,
848
+ wkwargs=wkwargs,
849
+ power=power,
850
+ center=center,
851
+ normalized=normalized,
852
+ pad_mode=pad_mode,
853
+ onesided=onesided,
854
+ n_bands=n_bands,
855
+ )
856
+
857
+ self.bsrnn = MultiSourceMultiPatchingMaskBandSplitCoreRNN(
858
+ stems=stems,
859
+ band_specs=self.band_specs,
860
+ in_channel=in_channel,
861
+ require_no_overlap=require_no_overlap,
862
+ require_no_gap=require_no_gap,
863
+ normalize_channel_independently=normalize_channel_independently,
864
+ treat_channel_as_feature=treat_channel_as_feature,
865
+ n_sqm_modules=n_sqm_modules,
866
+ emb_dim=emb_dim,
867
+ rnn_dim=rnn_dim,
868
+ bidirectional=bidirectional,
869
+ rnn_type=rnn_type,
870
+ mlp_dim=mlp_dim,
871
+ hidden_activation=hidden_activation,
872
+ hidden_activation_kwargs=hidden_activation_kwargs,
873
+ complex_mask=complex_mask,
874
+ overlapping_band=self.overlapping_band,
875
+ freq_weights=self.freq_weights,
876
+ n_freq=n_fft // 2 + 1,
877
+ mask_kernel_freq=mask_kernel_freq,
878
+ mask_kernel_time=mask_kernel_time,
879
+ conv_kernel_freq=conv_kernel_freq,
880
+ conv_kernel_time=conv_kernel_time,
881
+ kernel_norm_mlp_version=kernel_norm_mlp_version,
882
+ )
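A hedged sketch of the band-spec helper defined at the top of this file (not part of the diff; it assumes the separator directory is on sys.path so that the module's own models.bandit imports resolve):

from models.bandit.core.model.bsrnn.wrapper import get_band_specs

# "musical:64" hits the "musical" branch; n_bands is required there.
band_specs, freq_weights, overlapping = get_band_specs(
    "musical:64", n_fft=2048, fs=44100, n_bands=64
)
print(len(band_specs), overlapping)  # overlapping perceptual bands come with frequency weights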
separator/models/bandit/core/utils/__init__.py ADDED
File without changes
separator/models/bandit/core/utils/audio.py ADDED
@@ -0,0 +1,463 @@
1
+ from collections import defaultdict
2
+
3
+ from tqdm.auto import tqdm
4
+ from typing import Callable, Dict, List, Optional, Tuple
5
+
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+
12
+ @torch.jit.script
13
+ def merge(
14
+ combined: torch.Tensor,
15
+ original_batch_size: int,
16
+ n_channel: int,
17
+ n_chunks: int,
18
+ chunk_size: int, ):
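+ # Regroup per-chunk outputs into (batch * n_channel, chunk_size, n_chunks) so F.fold can overlap-add them along the last axis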
19
+ combined = torch.reshape(
20
+ combined,
21
+ (original_batch_size, n_chunks, n_channel, chunk_size)
22
+ )
23
+ combined = torch.permute(combined, (0, 2, 3, 1)).reshape(
24
+ original_batch_size * n_channel,
25
+ chunk_size,
26
+ n_chunks
27
+ )
28
+
29
+ return combined
30
+
31
+
32
+ @torch.jit.script
33
+ def unfold(
34
+ padded_audio: torch.Tensor,
35
+ original_batch_size: int,
36
+ n_channel: int,
37
+ chunk_size: int,
38
+ hop_size: int
39
+ ) -> torch.Tensor:
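+ # Cut the padded waveform into overlapping chunks (chunk_size long, hop_size apart) and fold the chunk axis into the batch dimension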
40
+
41
+ unfolded_input = F.unfold(
42
+ padded_audio[:, :, None, :],
43
+ kernel_size=(1, chunk_size),
44
+ stride=(1, hop_size)
45
+ )
46
+
47
+ _, _, n_chunks = unfolded_input.shape
48
+ unfolded_input = unfolded_input.view(
49
+ original_batch_size,
50
+ n_channel,
51
+ chunk_size,
52
+ n_chunks
53
+ )
54
+ unfolded_input = torch.permute(
55
+ unfolded_input,
56
+ (0, 3, 1, 2)
57
+ ).reshape(
58
+ original_batch_size * n_chunks,
59
+ n_channel,
60
+ chunk_size
61
+ )
62
+
63
+ return unfolded_input
64
+
65
+
66
+ @torch.jit.script
67
+ # @torch.compile
68
+ def merge_chunks_all(
69
+ combined: torch.Tensor,
70
+ original_batch_size: int,
71
+ n_channel: int,
72
+ n_samples: int,
73
+ n_padded_samples: int,
74
+ n_chunks: int,
75
+ chunk_size: int,
76
+ hop_size: int,
77
+ edge_frame_pad_sizes: Tuple[int, int],
78
+ standard_window: torch.Tensor,
79
+ first_window: torch.Tensor,
80
+ last_window: torch.Tensor
81
+ ):
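+ # Overlap-add every chunk with the standard window via F.fold, then trim the reflected edge padding and the zero padding added for the final chunk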
82
+ combined = merge(
83
+ combined,
84
+ original_batch_size,
85
+ n_channel,
86
+ n_chunks,
87
+ chunk_size
88
+ )
89
+
90
+ combined = combined * standard_window[:, None].to(combined.device)
91
+
92
+ combined = F.fold(
93
+ combined.to(torch.float32), output_size=(1, n_padded_samples),
94
+ kernel_size=(1, chunk_size),
95
+ stride=(1, hop_size)
96
+ )
97
+
98
+ combined = combined.view(
99
+ original_batch_size,
100
+ n_channel,
101
+ n_padded_samples
102
+ )
103
+
104
+ pad_front, pad_back = edge_frame_pad_sizes
105
+ combined = combined[..., pad_front:-pad_back]
106
+
107
+ combined = combined[..., :n_samples]
108
+
109
+ return combined
110
+
111
+ # @torch.jit.script
112
+
113
+
114
+ def merge_chunks_edge(
115
+ combined: torch.Tensor,
116
+ original_batch_size: int,
117
+ n_channel: int,
118
+ n_samples: int,
119
+ n_padded_samples: int,
120
+ n_chunks: int,
121
+ chunk_size: int,
122
+ hop_size: int,
123
+ edge_frame_pad_sizes: Tuple[int, int],
124
+ standard_window: torch.Tensor,
125
+ first_window: torch.Tensor,
126
+ last_window: torch.Tensor
127
+ ):
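+ # Overlap-add as above, but the first and last chunks get dedicated windows instead of relying on reflected edge padding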
128
+ combined = merge(
129
+ combined,
130
+ original_batch_size,
131
+ n_channel,
132
+ n_chunks,
133
+ chunk_size
134
+ )
135
+
136
+ combined[..., 0] = combined[..., 0] * first_window
137
+ combined[..., -1] = combined[..., -1] * last_window
138
+ combined[..., 1:-1] = combined[..., 1:-1] * standard_window[:, None]
140
+
141
+ combined = F.fold(
142
+ combined, output_size=(1, n_padded_samples),
143
+ kernel_size=(1, chunk_size),
144
+ stride=(1, hop_size)
145
+ )
146
+
147
+ combined = combined.view(
148
+ original_batch_size,
149
+ n_channel,
150
+ n_padded_samples
151
+ )
152
+
153
+ combined = combined[..., :n_samples]
154
+
155
+ return combined
156
+
157
+
158
+ class BaseFader(nn.Module):
159
+ def __init__(
160
+ self,
161
+ chunk_size_second: float,
162
+ hop_size_second: float,
163
+ fs: int,
164
+ fade_edge_frames: bool,
165
+ batch_size: int,
166
+ ) -> None:
167
+ super().__init__()
168
+
169
+ self.chunk_size = int(chunk_size_second * fs)
170
+ self.hop_size = int(hop_size_second * fs)
171
+ self.overlap_size = self.chunk_size - self.hop_size
172
+ self.fade_edge_frames = fade_edge_frames
173
+ self.batch_size = batch_size
174
+
175
+ # @torch.jit.script
176
+ def prepare(self, audio):
177
+
178
+ if self.fade_edge_frames:
179
+ audio = F.pad(audio, self.edge_frame_pad_sizes, mode="reflect")
180
+
181
+ n_samples = audio.shape[-1]
182
+ n_chunks = int(
183
+ np.ceil((n_samples - self.chunk_size) / self.hop_size) + 1
184
+ )
185
+
186
+ padded_size = (n_chunks - 1) * self.hop_size + self.chunk_size
187
+ pad_size = padded_size - n_samples
188
+
189
+ padded_audio = F.pad(audio, (0, pad_size))
190
+
191
+ return padded_audio, n_chunks
192
+
193
+ def forward(
194
+ self,
195
+ audio: torch.Tensor,
196
+ model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
197
+ ):
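+ # Stage audio on the CPU, run model_fn over batched chunks (skipping silent ones), then overlap-add each stem back to full length on the original device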
198
+
199
+ original_dtype = audio.dtype
200
+ original_device = audio.device
201
+
202
+ audio = audio.to("cpu")
203
+
204
+ original_batch_size, n_channel, n_samples = audio.shape
205
+ padded_audio, n_chunks = self.prepare(audio)
206
+ del audio
207
+ n_padded_samples = padded_audio.shape[-1]
208
+
209
+ if n_channel > 1:
210
+ padded_audio = padded_audio.view(
211
+ original_batch_size * n_channel, 1, n_padded_samples
212
+ )
213
+
214
+ unfolded_input = unfold(
215
+ padded_audio,
216
+ original_batch_size,
217
+ n_channel,
218
+ self.chunk_size, self.hop_size
219
+ )
220
+
221
+ n_total_chunks, n_channel, chunk_size = unfolded_input.shape
222
+
223
+ n_batch = np.ceil(n_total_chunks / self.batch_size).astype(int)
224
+
225
+ chunks_in = [
226
+ unfolded_input[
227
+ b * self.batch_size:(b + 1) * self.batch_size, ...].clone()
228
+ for b in range(n_batch)
229
+ ]
230
+
231
+ all_chunks_out = defaultdict(
232
+ lambda: torch.zeros_like(
233
+ unfolded_input, device="cpu"
234
+ )
235
+ )
236
+
237
+ # for b, cin in enumerate(tqdm(chunks_in)):
238
+ for b, cin in enumerate(chunks_in):
239
+ if torch.allclose(cin, torch.tensor(0.0)):
240
+ del cin
241
+ continue
242
+
243
+ chunks_out = model_fn(cin.to(original_device))
244
+ del cin
245
+ for s, c in chunks_out.items():
246
+ all_chunks_out[s][b * self.batch_size:(b + 1) * self.batch_size, ...] = c.cpu()
248
+ del chunks_out
249
+
250
+ del unfolded_input
251
+ del padded_audio
252
+
253
+ if self.fade_edge_frames:
254
+ fn = merge_chunks_all
255
+ else:
256
+ fn = merge_chunks_edge
257
+ outputs = {}
258
+
259
+ torch.cuda.empty_cache()
260
+
261
+ for s, c in all_chunks_out.items():
262
+ combined: torch.Tensor = fn(
263
+ c,
264
+ original_batch_size,
265
+ n_channel,
266
+ n_samples,
267
+ n_padded_samples,
268
+ n_chunks,
269
+ self.chunk_size,
270
+ self.hop_size,
271
+ self.edge_frame_pad_sizes,
272
+ self.standard_window,
273
+ self.__dict__.get("first_window", self.standard_window),
274
+ self.__dict__.get("last_window", self.standard_window)
275
+ )
276
+
277
+ outputs[s] = combined.to(
278
+ dtype=original_dtype,
279
+ device=original_device
280
+ )
281
+
282
+ return {
283
+ "audio": outputs
284
+ }
285
+ #
286
+ # def old_forward(
287
+ # self,
288
+ # audio: torch.Tensor,
289
+ # model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
290
+ # ):
291
+ #
292
+ # n_samples = audio.shape[-1]
293
+ # original_batch_size = audio.shape[0]
294
+ #
295
+ # padded_audio, n_chunks = self.prepare(audio)
296
+ #
297
+ # ndim = padded_audio.ndim
298
+ # broadcaster = [1 for _ in range(ndim - 1)] + [self.chunk_size]
299
+ #
300
+ # outputs = defaultdict(
301
+ # lambda: torch.zeros_like(
302
+ # padded_audio, device=audio.device, dtype=torch.float64
303
+ # )
304
+ # )
305
+ #
306
+ # all_chunks_out = []
307
+ # len_chunks_in = []
308
+ #
309
+ # batch_size_ = int(self.batch_size // original_batch_size)
310
+ # for b in range(int(np.ceil(n_chunks / batch_size_))):
311
+ # chunks_in = []
312
+ # for j in range(batch_size_):
313
+ # i = b * batch_size_ + j
314
+ # if i == n_chunks:
315
+ # break
316
+ #
317
+ # start = i * hop_size
318
+ # end = start + self.chunk_size
319
+ # chunk_in = padded_audio[..., start:end]
320
+ # chunks_in.append(chunk_in)
321
+ #
322
+ # chunks_in = torch.concat(chunks_in, dim=0)
323
+ # chunks_out = model_fn(chunks_in)
324
+ # all_chunks_out.append(chunks_out)
325
+ # len_chunks_in.append(len(chunks_in))
326
+ #
327
+ # for b, (chunks_out, lci) in enumerate(
328
+ # zip(all_chunks_out, len_chunks_in)
329
+ # ):
330
+ # for stem in chunks_out:
331
+ # for j in range(lci // original_batch_size):
332
+ # i = b * batch_size_ + j
333
+ #
334
+ # if self.fade_edge_frames:
335
+ # window = self.standard_window
336
+ # else:
337
+ # if i == 0:
338
+ # window = self.first_window
339
+ # elif i == n_chunks - 1:
340
+ # window = self.last_window
341
+ # else:
342
+ # window = self.standard_window
343
+ #
344
+ # start = i * hop_size
345
+ # end = start + self.chunk_size
346
+ #
347
+ # chunk_out = chunks_out[stem][j * original_batch_size: (j + 1) * original_batch_size,
348
+ # ...]
349
+ # contrib = window.view(*broadcaster) * chunk_out
350
+ # outputs[stem][..., start:end] = (
351
+ # outputs[stem][..., start:end] + contrib
352
+ # )
353
+ #
354
+ # if self.fade_edge_frames:
355
+ # pad_front, pad_back = self.edge_frame_pad_sizes
356
+ # outputs = {k: v[..., pad_front:-pad_back] for k, v in
357
+ # outputs.items()}
358
+ #
359
+ # outputs = {k: v[..., :n_samples].to(audio.dtype) for k, v in
360
+ # outputs.items()}
361
+ #
362
+ # return {
363
+ # "audio": outputs
364
+ # }
365
+
366
+
367
+ class LinearFader(BaseFader):
368
+ def __init__(
369
+ self,
370
+ chunk_size_second: float,
371
+ hop_size_second: float,
372
+ fs: int,
373
+ fade_edge_frames: bool = False,
374
+ batch_size: int = 1,
375
+ ) -> None:
376
+
377
+ assert hop_size_second >= chunk_size_second / 2
378
+
379
+ super().__init__(
380
+ chunk_size_second=chunk_size_second,
381
+ hop_size_second=hop_size_second,
382
+ fs=fs,
383
+ fade_edge_frames=fade_edge_frames,
384
+ batch_size=batch_size,
385
+ )
386
+
387
+ in_fade = torch.linspace(0.0, 1.0, self.overlap_size + 1)[:-1]
388
+ out_fade = torch.linspace(1.0, 0.0, self.overlap_size + 1)[1:]
389
+ center_ones = torch.ones(self.chunk_size - 2 * self.overlap_size)
390
+ inout_ones = torch.ones(self.overlap_size)
391
+
392
+ # registering the windows as buffers/parameters lets Lightning move them to the right device for us
393
+ self.register_buffer(
394
+ "standard_window",
395
+ torch.concat([in_fade, center_ones, out_fade])
396
+ )
397
+
398
+ self.fade_edge_frames = fade_edge_frames
399
+ self.edge_frame_pad_sizes = (self.overlap_size, self.overlap_size)  # plural, to match the attribute read in BaseFader.prepare()
400
+
401
+ if not self.fade_edge_frames:
402
+ self.first_window = nn.Parameter(
403
+ torch.concat([inout_ones, center_ones, out_fade]),
404
+ requires_grad=False
405
+ )
406
+ self.last_window = nn.Parameter(
407
+ torch.concat([in_fade, center_ones, inout_ones]),
408
+ requires_grad=False
409
+ )
410
+
411
+
412
+ class OverlapAddFader(BaseFader):
413
+ def __init__(
414
+ self,
415
+ window_type: str,
416
+ chunk_size_second: float,
417
+ hop_size_second: float,
418
+ fs: int,
419
+ batch_size: int = 1,
420
+ ) -> None:
421
+ assert (chunk_size_second / hop_size_second) % 2 == 0
422
+ assert int(chunk_size_second * fs) % 2 == 0
423
+
424
+ super().__init__(
425
+ chunk_size_second=chunk_size_second,
426
+ hop_size_second=hop_size_second,
427
+ fs=fs,
428
+ fade_edge_frames=True,
429
+ batch_size=batch_size,
430
+ )
431
+
432
+ self.hop_multiplier = self.chunk_size / (2 * self.hop_size)
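+ # for a Hann-type window, the overlapping windows sum to roughly chunk_size / (2 * hop_size), so dividing by this keeps the overlap-add near unity gain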
433
+ # print(f"hop multiplier: {self.hop_multiplier}")
434
+
435
+ self.edge_frame_pad_sizes = (
436
+ 2 * self.overlap_size,
437
+ 2 * self.overlap_size
438
+ )
439
+
440
+ self.register_buffer(
441
+ "standard_window", torch.windows.__dict__[window_type](
442
+ self.chunk_size, sym=False, # dtype=torch.float64
443
+ ) / self.hop_multiplier
444
+ )
445
+
446
+
447
+ if __name__ == "__main__":
448
+ import torchaudio as ta
449
+ fs = 44100
450
+ ola = OverlapAddFader(
451
+ "hann",
452
+ 6.0,
453
+ 1.0,
454
+ fs,
455
+ batch_size=16
456
+ )
457
+ audio_, _ = ta.load(
458
+ "$DATA_ROOT/MUSDB18/HQ/canonical/test/BKS - Too "
459
+ "Much/vocals.wav"
460
+ )
461
+ audio_ = audio_[None, ...]
462
+ out = ola(audio_, lambda x: {"stem": x})["audio"]["stem"]
463
+ print(torch.allclose(out, audio_))
separator/models/bandit/model_from_config.py ADDED
@@ -0,0 +1,31 @@
1
+ import sys
2
+ import os.path
3
+ import torch
4
+
5
+ code_path = os.path.dirname(os.path.abspath(__file__)) + '/'
6
+ sys.path.append(code_path)
7
+
8
+ import yaml
9
+ from ml_collections import ConfigDict
10
+
11
+ torch.set_float32_matmul_precision("medium")
12
+
13
+
14
+ def get_model(
15
+ config_path,
16
+ weights_path,
17
+ device,
18
+ ):
19
+ from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple
20
+
21
+ f = open(config_path)
22
+ config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
23
+ f.close()
24
+
25
+ model = MultiMaskMultiSourceBandSplitRNNSimple(
26
+ **config.model
27
+ )
28
+ d = torch.load(code_path + 'model_bandit_plus_dnr_sdr_11.47.chpt')
29
+ model.load_state_dict(d)
30
+ model.to(device)
31
+ return model, config
separator/models/bandit_v2/bandit.py ADDED
@@ -0,0 +1,367 @@
1
+ from typing import Dict, List, Optional
2
+
3
+ import torch
4
+ import torchaudio as ta
5
+ from torch import nn
6
+ import pytorch_lightning as pl
7
+
8
+ from .bandsplit import BandSplitModule
9
+ from .maskestim import OverlappingMaskEstimationModule
10
+ from .tfmodel import SeqBandModellingModule
11
+ from .utils import MusicalBandsplitSpecification
12
+
13
+
14
+
15
+ class BaseEndToEndModule(pl.LightningModule):
16
+ def __init__(
17
+ self,
18
+ ) -> None:
19
+ super().__init__()
20
+
21
+
22
+ class BaseBandit(BaseEndToEndModule):
23
+ def __init__(
24
+ self,
25
+ in_channels: int,
26
+ fs: int,
27
+ band_type: str = "musical",
28
+ n_bands: int = 64,
29
+ require_no_overlap: bool = False,
30
+ require_no_gap: bool = True,
31
+ normalize_channel_independently: bool = False,
32
+ treat_channel_as_feature: bool = True,
33
+ n_sqm_modules: int = 12,
34
+ emb_dim: int = 128,
35
+ rnn_dim: int = 256,
36
+ bidirectional: bool = True,
37
+ rnn_type: str = "LSTM",
38
+ n_fft: int = 2048,
39
+ win_length: Optional[int] = 2048,
40
+ hop_length: int = 512,
41
+ window_fn: str = "hann_window",
42
+ wkwargs: Optional[Dict] = None,
43
+ power: Optional[int] = None,
44
+ center: bool = True,
45
+ normalized: bool = True,
46
+ pad_mode: str = "constant",
47
+ onesided: bool = True,
48
+ ):
49
+ super().__init__()
50
+
51
+ self.in_channels = in_channels
52
+
53
+ self.instantitate_spectral(
54
+ n_fft=n_fft,
55
+ win_length=win_length,
56
+ hop_length=hop_length,
57
+ window_fn=window_fn,
58
+ wkwargs=wkwargs,
59
+ power=power,
60
+ normalized=normalized,
61
+ center=center,
62
+ pad_mode=pad_mode,
63
+ onesided=onesided,
64
+ )
65
+
66
+ self.instantiate_bandsplit(
67
+ in_channels=in_channels,
68
+ band_type=band_type,
69
+ n_bands=n_bands,
70
+ require_no_overlap=require_no_overlap,
71
+ require_no_gap=require_no_gap,
72
+ normalize_channel_independently=normalize_channel_independently,
73
+ treat_channel_as_feature=treat_channel_as_feature,
74
+ emb_dim=emb_dim,
75
+ n_fft=n_fft,
76
+ fs=fs,
77
+ )
78
+
79
+ self.instantiate_tf_modelling(
80
+ n_sqm_modules=n_sqm_modules,
81
+ emb_dim=emb_dim,
82
+ rnn_dim=rnn_dim,
83
+ bidirectional=bidirectional,
84
+ rnn_type=rnn_type,
85
+ )
86
+
87
+ def instantitate_spectral(
88
+ self,
89
+ n_fft: int = 2048,
90
+ win_length: Optional[int] = 2048,
91
+ hop_length: int = 512,
92
+ window_fn: str = "hann_window",
93
+ wkwargs: Optional[Dict] = None,
94
+ power: Optional[int] = None,
95
+ normalized: bool = True,
96
+ center: bool = True,
97
+ pad_mode: str = "constant",
98
+ onesided: bool = True,
99
+ ):
100
+ assert power is None
101
+
102
+ window_fn = torch.__dict__[window_fn]
103
+
104
+ self.stft = ta.transforms.Spectrogram(
105
+ n_fft=n_fft,
106
+ win_length=win_length,
107
+ hop_length=hop_length,
108
+ pad_mode=pad_mode,
109
+ pad=0,
110
+ window_fn=window_fn,
111
+ wkwargs=wkwargs,
112
+ power=power,
113
+ normalized=normalized,
114
+ center=center,
115
+ onesided=onesided,
116
+ )
117
+
118
+ self.istft = ta.transforms.InverseSpectrogram(
119
+ n_fft=n_fft,
120
+ win_length=win_length,
121
+ hop_length=hop_length,
122
+ pad_mode=pad_mode,
123
+ pad=0,
124
+ window_fn=window_fn,
125
+ wkwargs=wkwargs,
126
+ normalized=normalized,
127
+ center=center,
128
+ onesided=onesided,
129
+ )
130
+
131
+ def instantiate_bandsplit(
132
+ self,
133
+ in_channels: int,
134
+ band_type: str = "musical",
135
+ n_bands: int = 64,
136
+ require_no_overlap: bool = False,
137
+ require_no_gap: bool = True,
138
+ normalize_channel_independently: bool = False,
139
+ treat_channel_as_feature: bool = True,
140
+ emb_dim: int = 128,
141
+ n_fft: int = 2048,
142
+ fs: int = 44100,
143
+ ):
144
+ assert band_type == "musical"
145
+
146
+ self.band_specs = MusicalBandsplitSpecification(
147
+ nfft=n_fft, fs=fs, n_bands=n_bands
148
+ )
149
+
150
+ self.band_split = BandSplitModule(
151
+ in_channels=in_channels,
152
+ band_specs=self.band_specs.get_band_specs(),
153
+ require_no_overlap=require_no_overlap,
154
+ require_no_gap=require_no_gap,
155
+ normalize_channel_independently=normalize_channel_independently,
156
+ treat_channel_as_feature=treat_channel_as_feature,
157
+ emb_dim=emb_dim,
158
+ )
159
+
160
+ def instantiate_tf_modelling(
161
+ self,
162
+ n_sqm_modules: int = 12,
163
+ emb_dim: int = 128,
164
+ rnn_dim: int = 256,
165
+ bidirectional: bool = True,
166
+ rnn_type: str = "LSTM",
167
+ ):
168
+ try:
169
+ self.tf_model = torch.compile(
170
+ SeqBandModellingModule(
171
+ n_modules=n_sqm_modules,
172
+ emb_dim=emb_dim,
173
+ rnn_dim=rnn_dim,
174
+ bidirectional=bidirectional,
175
+ rnn_type=rnn_type,
176
+ ),
177
+ disable=True,
178
+ )
179
+ except Exception as e:
180
+ self.tf_model = SeqBandModellingModule(
181
+ n_modules=n_sqm_modules,
182
+ emb_dim=emb_dim,
183
+ rnn_dim=rnn_dim,
184
+ bidirectional=bidirectional,
185
+ rnn_type=rnn_type,
186
+ )
187
+
188
+ def mask(self, x, m):
189
+ return x * m
190
+
191
+ def forward(self, batch, mode="train"):
192
+ # The model expects mono input but receives stereo, so each channel is processed independently
193
+ init_shape = batch.shape
194
+ if not isinstance(batch, dict):
195
+ mono = batch.view(-1, 1, batch.shape[-1])
196
+ batch = {
197
+ "mixture": {
198
+ "audio": mono
199
+ }
200
+ }
201
+
202
+ with torch.no_grad():
203
+ mixture = batch["mixture"]["audio"]
204
+
205
+ x = self.stft(mixture)
206
+ batch["mixture"]["spectrogram"] = x
207
+
208
+ if "sources" in batch.keys():
209
+ for stem in batch["sources"].keys():
210
+ s = batch["sources"][stem]["audio"]
211
+ s = self.stft(s)
212
+ batch["sources"][stem]["spectrogram"] = s
213
+
214
+ batch = self.separate(batch)
215
+
216
+ if 1:
217
+ b = []
218
+ for s in self.stems:
219
+ # Reshape back to the original stereo layout
220
+ r = batch['estimates'][s]['audio'].view(-1, init_shape[1], init_shape[2])
221
+ b.append(r)
222
+ # Return a single stacked tensor rather than per-stem dictionary entries
223
+ batch = torch.stack(b, dim=1)
224
+ return batch
225
+
226
+ def encode(self, batch):
227
+ x = batch["mixture"]["spectrogram"]
228
+ length = batch["mixture"]["audio"].shape[-1]
229
+
230
+ z = self.band_split(x) # (batch, emb_dim, n_band, n_time)
231
+ q = self.tf_model(z) # (batch, emb_dim, n_band, n_time)
232
+
233
+ return x, q, length
234
+
235
+ def separate(self, batch):
236
+ raise NotImplementedError
237
+
238
+
239
+ class Bandit(BaseBandit):
240
+ def __init__(
241
+ self,
242
+ in_channels: int,
243
+ stems: List[str],
244
+ band_type: str = "musical",
245
+ n_bands: int = 64,
246
+ require_no_overlap: bool = False,
247
+ require_no_gap: bool = True,
248
+ normalize_channel_independently: bool = False,
249
+ treat_channel_as_feature: bool = True,
250
+ n_sqm_modules: int = 12,
251
+ emb_dim: int = 128,
252
+ rnn_dim: int = 256,
253
+ bidirectional: bool = True,
254
+ rnn_type: str = "LSTM",
255
+ mlp_dim: int = 512,
256
+ hidden_activation: str = "Tanh",
257
+ hidden_activation_kwargs: Dict | None = None,
258
+ complex_mask: bool = True,
259
+ use_freq_weights: bool = True,
260
+ n_fft: int = 2048,
261
+ win_length: int | None = 2048,
262
+ hop_length: int = 512,
263
+ window_fn: str = "hann_window",
264
+ wkwargs: Dict | None = None,
265
+ power: int | None = None,
266
+ center: bool = True,
267
+ normalized: bool = True,
268
+ pad_mode: str = "constant",
269
+ onesided: bool = True,
270
+ fs: int = 44100,
271
+ stft_precisions="32",
272
+ bandsplit_precisions="bf16",
273
+ tf_model_precisions="bf16",
274
+ mask_estim_precisions="bf16",
275
+ ):
276
+ super().__init__(
277
+ in_channels=in_channels,
278
+ band_type=band_type,
279
+ n_bands=n_bands,
280
+ require_no_overlap=require_no_overlap,
281
+ require_no_gap=require_no_gap,
282
+ normalize_channel_independently=normalize_channel_independently,
283
+ treat_channel_as_feature=treat_channel_as_feature,
284
+ n_sqm_modules=n_sqm_modules,
285
+ emb_dim=emb_dim,
286
+ rnn_dim=rnn_dim,
287
+ bidirectional=bidirectional,
288
+ rnn_type=rnn_type,
289
+ n_fft=n_fft,
290
+ win_length=win_length,
291
+ hop_length=hop_length,
292
+ window_fn=window_fn,
293
+ wkwargs=wkwargs,
294
+ power=power,
295
+ center=center,
296
+ normalized=normalized,
297
+ pad_mode=pad_mode,
298
+ onesided=onesided,
299
+ fs=fs,
300
+ )
301
+
302
+ self.stems = stems
303
+
304
+ self.instantiate_mask_estim(
305
+ in_channels=in_channels,
306
+ stems=stems,
307
+ emb_dim=emb_dim,
308
+ mlp_dim=mlp_dim,
309
+ hidden_activation=hidden_activation,
310
+ hidden_activation_kwargs=hidden_activation_kwargs,
311
+ complex_mask=complex_mask,
312
+ n_freq=n_fft // 2 + 1,
313
+ use_freq_weights=use_freq_weights,
314
+ )
315
+
316
+ def instantiate_mask_estim(
317
+ self,
318
+ in_channels: int,
319
+ stems: List[str],
320
+ emb_dim: int,
321
+ mlp_dim: int,
322
+ hidden_activation: str,
323
+ hidden_activation_kwargs: Optional[Dict] = None,
324
+ complex_mask: bool = True,
325
+ n_freq: Optional[int] = None,
326
+ use_freq_weights: bool = False,
327
+ ):
328
+ if hidden_activation_kwargs is None:
329
+ hidden_activation_kwargs = {}
330
+
331
+ assert n_freq is not None
332
+
333
+ self.mask_estim = nn.ModuleDict(
334
+ {
335
+ stem: OverlappingMaskEstimationModule(
336
+ band_specs=self.band_specs.get_band_specs(),
337
+ freq_weights=self.band_specs.get_freq_weights(),
338
+ n_freq=n_freq,
339
+ emb_dim=emb_dim,
340
+ mlp_dim=mlp_dim,
341
+ in_channels=in_channels,
342
+ hidden_activation=hidden_activation,
343
+ hidden_activation_kwargs=hidden_activation_kwargs,
344
+ complex_mask=complex_mask,
345
+ use_freq_weights=use_freq_weights,
346
+ )
347
+ for stem in stems
348
+ }
349
+ )
350
+
351
+ def separate(self, batch):
352
+ batch["estimates"] = {}
353
+
354
+ x, q, length = self.encode(batch)
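+ # one mask estimator per stem: apply its estimated mask to the mixture spectrogram and invert the STFT to recover that stem's audio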
355
+
356
+ for stem, mem in self.mask_estim.items():
357
+ m = mem(q)
358
+
359
+ s = self.mask(x, m.to(x.dtype))
360
+ s = torch.reshape(s, x.shape)
361
+ batch["estimates"][stem] = {
362
+ "audio": self.istft(s, length),
363
+ "spectrogram": s,
364
+ }
365
+
366
+ return batch
367
+
separator/models/bandit_v2/bandsplit.py ADDED
@@ -0,0 +1,130 @@
1
+ from typing import List, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.utils.checkpoint import checkpoint_sequential
6
+
7
+ from .utils import (
8
+ band_widths_from_specs,
9
+ check_no_gap,
10
+ check_no_overlap,
11
+ check_nonzero_bandwidth,
12
+ )
13
+
14
+
15
+ class NormFC(nn.Module):
16
+ def __init__(
17
+ self,
18
+ emb_dim: int,
19
+ bandwidth: int,
20
+ in_channels: int,
21
+ normalize_channel_independently: bool = False,
22
+ treat_channel_as_feature: bool = True,
23
+ ) -> None:
24
+ super().__init__()
25
+
26
+ if not treat_channel_as_feature:
27
+ raise NotImplementedError
28
+
29
+ self.treat_channel_as_feature = treat_channel_as_feature
30
+
31
+ if normalize_channel_independently:
32
+ raise NotImplementedError
33
+
34
+ reim = 2
35
+
36
+ norm = nn.LayerNorm(in_channels * bandwidth * reim)
37
+
38
+ fc_in = bandwidth * reim
39
+
40
+ if treat_channel_as_feature:
41
+ fc_in *= in_channels
42
+ else:
43
+ assert emb_dim % in_channels == 0
44
+ emb_dim = emb_dim // in_channels
45
+
46
+ fc = nn.Linear(fc_in, emb_dim)
47
+
48
+ self.combined = nn.Sequential(norm, fc)
49
+
50
+ def forward(self, xb):
51
+ return checkpoint_sequential(self.combined, 1, xb, use_reentrant=False)
52
+
53
+
54
+ class BandSplitModule(nn.Module):
55
+ def __init__(
56
+ self,
57
+ band_specs: List[Tuple[float, float]],
58
+ emb_dim: int,
59
+ in_channels: int,
60
+ require_no_overlap: bool = False,
61
+ require_no_gap: bool = True,
62
+ normalize_channel_independently: bool = False,
63
+ treat_channel_as_feature: bool = True,
64
+ ) -> None:
65
+ super().__init__()
66
+
67
+ check_nonzero_bandwidth(band_specs)
68
+
69
+ if require_no_gap:
70
+ check_no_gap(band_specs)
71
+
72
+ if require_no_overlap:
73
+ check_no_overlap(band_specs)
74
+
75
+ self.band_specs = band_specs
76
+ # list of [fstart, fend) in index.
77
+ # Note that fend is exclusive.
78
+ self.band_widths = band_widths_from_specs(band_specs)
79
+ self.n_bands = len(band_specs)
80
+ self.emb_dim = emb_dim
81
+
82
+ try:
83
+ self.norm_fc_modules = nn.ModuleList(
84
+ [ # type: ignore
85
+ torch.compile(
86
+ NormFC(
87
+ emb_dim=emb_dim,
88
+ bandwidth=bw,
89
+ in_channels=in_channels,
90
+ normalize_channel_independently=normalize_channel_independently,
91
+ treat_channel_as_feature=treat_channel_as_feature,
92
+ ),
93
+ disable=True,
94
+ )
95
+ for bw in self.band_widths
96
+ ]
97
+ )
98
+ except Exception as e:
99
+ self.norm_fc_modules = nn.ModuleList(
100
+ [ # type: ignore
101
+ NormFC(
102
+ emb_dim=emb_dim,
103
+ bandwidth=bw,
104
+ in_channels=in_channels,
105
+ normalize_channel_independently=normalize_channel_independently,
106
+ treat_channel_as_feature=treat_channel_as_feature,
107
+ )
108
+ for bw in self.band_widths
109
+ ]
110
+ )
111
+
112
+ def forward(self, x: torch.Tensor):
113
+ # x = complex spectrogram (batch, in_chan, n_freq, n_time)
114
+
115
+ batch, in_chan, band_width, n_time = x.shape
116
+
117
+ z = torch.zeros(
118
+ size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device
119
+ )
120
+
121
+ x = torch.permute(x, (0, 3, 1, 2)).contiguous()
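+ # for each band, flatten (channel, real/imag, frequency) into one feature vector and project it to the shared emb_dim embedding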
122
+
123
+ for i, nfm in enumerate(self.norm_fc_modules):
124
+ fstart, fend = self.band_specs[i]
125
+ xb = x[:, :, :, fstart:fend]
126
+ xb = torch.view_as_real(xb)
127
+ xb = torch.reshape(xb, (batch, n_time, -1))
128
+ z[:, i, :, :] = nfm(xb)
129
+
130
+ return z
separator/models/bandit_v2/film.py ADDED
@@ -0,0 +1,25 @@
1
+ from torch import nn
2
+ import torch
3
+
4
+ class FiLM(nn.Module):
5
+ def __init__(self):
6
+ super().__init__()
7
+
8
+ def forward(self, x, gamma, beta):
9
+ return gamma * x + beta
10
+
11
+
12
+ class BTFBroadcastedFiLM(nn.Module):
13
+ def __init__(self):
14
+ super().__init__()
15
+ self.film = FiLM()
16
+
17
+ def forward(self, x, gamma, beta):
18
+
19
+ gamma = gamma[None, None, None, :]
20
+ beta = beta[None, None, None, :]
21
+
22
+ return self.film(x, gamma, beta)
23
+
24
+
25
+
separator/models/bandit_v2/maskestim.py ADDED
@@ -0,0 +1,281 @@
1
+ from typing import Dict, List, Optional, Tuple, Type
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn.modules import activation
6
+ from torch.utils.checkpoint import checkpoint_sequential
7
+
8
+ from .utils import (
9
+ band_widths_from_specs,
10
+ check_no_gap,
11
+ check_no_overlap,
12
+ check_nonzero_bandwidth,
13
+ )
14
+
15
+
16
+ class BaseNormMLP(nn.Module):
17
+ def __init__(
18
+ self,
19
+ emb_dim: int,
20
+ mlp_dim: int,
21
+ bandwidth: int,
22
+ in_channels: Optional[int],
23
+ hidden_activation: str = "Tanh",
24
+ hidden_activation_kwargs=None,
25
+ complex_mask: bool = True,
26
+ ):
27
+ super().__init__()
28
+ if hidden_activation_kwargs is None:
29
+ hidden_activation_kwargs = {}
30
+ self.hidden_activation_kwargs = hidden_activation_kwargs
31
+ self.norm = nn.LayerNorm(emb_dim)
32
+ self.hidden = nn.Sequential(
33
+ nn.Linear(in_features=emb_dim, out_features=mlp_dim),
34
+ activation.__dict__[hidden_activation](**self.hidden_activation_kwargs),
35
+ )
36
+
37
+ self.bandwidth = bandwidth
38
+ self.in_channels = in_channels
39
+
40
+ self.complex_mask = complex_mask
41
+ self.reim = 2 if complex_mask else 1
42
+ self.glu_mult = 2
43
+
44
+
45
+ class NormMLP(BaseNormMLP):
46
+ def __init__(
47
+ self,
48
+ emb_dim: int,
49
+ mlp_dim: int,
50
+ bandwidth: int,
51
+ in_channels: Optional[int],
52
+ hidden_activation: str = "Tanh",
53
+ hidden_activation_kwargs=None,
54
+ complex_mask: bool = True,
55
+ ) -> None:
56
+ super().__init__(
57
+ emb_dim=emb_dim,
58
+ mlp_dim=mlp_dim,
59
+ bandwidth=bandwidth,
60
+ in_channels=in_channels,
61
+ hidden_activation=hidden_activation,
62
+ hidden_activation_kwargs=hidden_activation_kwargs,
63
+ complex_mask=complex_mask,
64
+ )
65
+
66
+ self.output = nn.Sequential(
67
+ nn.Linear(
68
+ in_features=mlp_dim,
69
+ out_features=bandwidth * in_channels * self.reim * 2,
70
+ ),
71
+ nn.GLU(dim=-1),
72
+ )
73
+
74
+ try:
75
+ self.combined = torch.compile(
76
+ nn.Sequential(self.norm, self.hidden, self.output), disable=True
77
+ )
78
+ except Exception as e:
79
+ self.combined = nn.Sequential(self.norm, self.hidden, self.output)
80
+
81
+ def reshape_output(self, mb):
82
+ # print(mb.shape)
83
+ batch, n_time, _ = mb.shape
84
+ if self.complex_mask:
85
+ mb = mb.reshape(
86
+ batch, n_time, self.in_channels, self.bandwidth, self.reim
87
+ ).contiguous()
88
+ # print(mb.shape)
89
+ mb = torch.view_as_complex(mb) # (batch, n_time, in_channels, bandwidth)
90
+ else:
91
+ mb = mb.reshape(batch, n_time, self.in_channels, self.bandwidth)
92
+
93
+ mb = torch.permute(mb, (0, 2, 3, 1)) # (batch, in_channels, bandwidth, n_time)
94
+
95
+ return mb
96
+
97
+ def forward(self, qb):
98
+ # qb = (batch, n_time, emb_dim)
99
+ # qb = self.norm(qb) # (batch, n_time, emb_dim)
100
+ # qb = self.hidden(qb) # (batch, n_time, mlp_dim)
101
+ # mb = self.output(qb) # (batch, n_time, bandwidth * in_channels * reim)
102
+
103
+ mb = checkpoint_sequential(self.combined, 2, qb, use_reentrant=False)
104
+ mb = self.reshape_output(mb) # (batch, in_channels, bandwidth, n_time)
105
+
106
+ return mb
107
+
108
+
109
+ class MaskEstimationModuleSuperBase(nn.Module):
110
+ pass
111
+
112
+
113
+ class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
114
+ def __init__(
115
+ self,
116
+ band_specs: List[Tuple[float, float]],
117
+ emb_dim: int,
118
+ mlp_dim: int,
119
+ in_channels: Optional[int],
120
+ hidden_activation: str = "Tanh",
121
+ hidden_activation_kwargs: Dict = None,
122
+ complex_mask: bool = True,
123
+ norm_mlp_cls: Type[nn.Module] = NormMLP,
124
+ norm_mlp_kwargs: Dict = None,
125
+ ) -> None:
126
+ super().__init__()
127
+
128
+ self.band_widths = band_widths_from_specs(band_specs)
129
+ self.n_bands = len(band_specs)
130
+
131
+ if hidden_activation_kwargs is None:
132
+ hidden_activation_kwargs = {}
133
+
134
+ if norm_mlp_kwargs is None:
135
+ norm_mlp_kwargs = {}
136
+
137
+ self.norm_mlp = nn.ModuleList(
138
+ [
139
+ norm_mlp_cls(
140
+ bandwidth=self.band_widths[b],
141
+ emb_dim=emb_dim,
142
+ mlp_dim=mlp_dim,
143
+ in_channels=in_channels,
144
+ hidden_activation=hidden_activation,
145
+ hidden_activation_kwargs=hidden_activation_kwargs,
146
+ complex_mask=complex_mask,
147
+ **norm_mlp_kwargs,
148
+ )
149
+ for b in range(self.n_bands)
150
+ ]
151
+ )
152
+
153
+ def compute_masks(self, q):
154
+ batch, n_bands, n_time, emb_dim = q.shape
155
+
156
+ masks = []
157
+
158
+ for b, nmlp in enumerate(self.norm_mlp):
159
+ # print(f"maskestim/{b:02d}")
160
+ qb = q[:, b, :, :]
161
+ mb = nmlp(qb)
162
+ masks.append(mb)
163
+
164
+ return masks
165
+
166
+ def compute_mask(self, q, b):
167
+ batch, n_bands, n_time, emb_dim = q.shape
168
+ qb = q[:, b, :, :]
169
+ mb = self.norm_mlp[b](qb)
170
+ return mb
171
+
172
+
173
+ class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
174
+ def __init__(
175
+ self,
176
+ in_channels: int,
177
+ band_specs: List[Tuple[float, float]],
178
+ freq_weights: List[torch.Tensor],
179
+ n_freq: int,
180
+ emb_dim: int,
181
+ mlp_dim: int,
182
+ cond_dim: int = 0,
183
+ hidden_activation: str = "Tanh",
184
+ hidden_activation_kwargs: Dict = None,
185
+ complex_mask: bool = True,
186
+ norm_mlp_cls: Type[nn.Module] = NormMLP,
187
+ norm_mlp_kwargs: Dict = None,
188
+ use_freq_weights: bool = False,
189
+ ) -> None:
190
+ check_nonzero_bandwidth(band_specs)
191
+ check_no_gap(band_specs)
192
+
193
+ if cond_dim > 0:
194
+ raise NotImplementedError
195
+
196
+ super().__init__(
197
+ band_specs=band_specs,
198
+ emb_dim=emb_dim + cond_dim,
199
+ mlp_dim=mlp_dim,
200
+ in_channels=in_channels,
201
+ hidden_activation=hidden_activation,
202
+ hidden_activation_kwargs=hidden_activation_kwargs,
203
+ complex_mask=complex_mask,
204
+ norm_mlp_cls=norm_mlp_cls,
205
+ norm_mlp_kwargs=norm_mlp_kwargs,
206
+ )
207
+
208
+ self.n_freq = n_freq
209
+ self.band_specs = band_specs
210
+ self.in_channels = in_channels
211
+
212
+ if freq_weights is not None and use_freq_weights:
213
+ for i, fw in enumerate(freq_weights):
214
+ self.register_buffer(f"freq_weights/{i}", fw)
215
+
216
+ self.use_freq_weights = use_freq_weights
217
+ else:
218
+ self.use_freq_weights = False
219
+
220
+ def forward(self, q):
221
+ # q = (batch, n_bands, n_time, emb_dim)
222
+
223
+ batch, n_bands, n_time, emb_dim = q.shape
224
+
225
+ masks = torch.zeros(
226
+ (batch, self.in_channels, self.n_freq, n_time),
227
+ device=q.device,
228
+ dtype=torch.complex64,
229
+ )
230
+
231
+ for im in range(n_bands):
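+ # bands may overlap in frequency, so each band's mask (scaled by its frequency weights when enabled) is accumulated additively into the full-resolution mask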
232
+ fstart, fend = self.band_specs[im]
233
+
234
+ mask = self.compute_mask(q, im)
235
+
236
+ if self.use_freq_weights:
237
+ fw = self.get_buffer(f"freq_weights/{im}")[:, None]
238
+ mask = mask * fw
239
+ masks[:, :, fstart:fend, :] += mask
240
+
241
+ return masks
242
+
243
+
244
+ class MaskEstimationModule(OverlappingMaskEstimationModule):
245
+ def __init__(
246
+ self,
247
+ band_specs: List[Tuple[float, float]],
248
+ emb_dim: int,
249
+ mlp_dim: int,
250
+ in_channels: Optional[int],
251
+ hidden_activation: str = "Tanh",
252
+ hidden_activation_kwargs: Dict = None,
253
+ complex_mask: bool = True,
254
+ **kwargs,
255
+ ) -> None:
256
+ check_nonzero_bandwidth(band_specs)
257
+ check_no_gap(band_specs)
258
+ check_no_overlap(band_specs)
259
+ super().__init__(
260
+ in_channels=in_channels,
261
+ band_specs=band_specs,
262
+ freq_weights=None,
263
+ n_freq=None,
264
+ emb_dim=emb_dim,
265
+ mlp_dim=mlp_dim,
266
+ hidden_activation=hidden_activation,
267
+ hidden_activation_kwargs=hidden_activation_kwargs,
268
+ complex_mask=complex_mask,
269
+ )
270
+
271
+ def forward(self, q, cond=None):
272
+ # q = (batch, n_bands, n_time, emb_dim)
273
+
274
+ masks = self.compute_masks(
275
+ q
276
+ ) # [n_bands * (batch, in_channels, bandwidth, n_time)]
277
+
278
+ # TODO: currently this requires band specs to have no gap and no overlap
279
+ masks = torch.concat(masks, dim=2) # (batch, in_channels, n_freq, n_time)
280
+
281
+ return masks
separator/models/bandit_v2/tfmodel.py ADDED
@@ -0,0 +1,145 @@
1
+ import warnings
2
+
3
+ import torch
4
+ import torch.backends.cuda
5
+ from torch import nn
6
+ from torch.nn.modules import rnn
7
+ from torch.utils.checkpoint import checkpoint_sequential
8
+
9
+
10
+ class TimeFrequencyModellingModule(nn.Module):
11
+ def __init__(self) -> None:
12
+ super().__init__()
13
+
14
+
15
+ class ResidualRNN(nn.Module):
16
+ def __init__(
17
+ self,
18
+ emb_dim: int,
19
+ rnn_dim: int,
20
+ bidirectional: bool = True,
21
+ rnn_type: str = "LSTM",
22
+ use_batch_trick: bool = True,
23
+ use_layer_norm: bool = True,
24
+ ) -> None:
25
+ # n_group is the size of the 2nd dim
26
+ super().__init__()
27
+
28
+ assert use_layer_norm
29
+ assert use_batch_trick
30
+
31
+ self.use_layer_norm = use_layer_norm
32
+ self.norm = nn.LayerNorm(emb_dim)
33
+ self.rnn = rnn.__dict__[rnn_type](
34
+ input_size=emb_dim,
35
+ hidden_size=rnn_dim,
36
+ num_layers=1,
37
+ batch_first=True,
38
+ bidirectional=bidirectional,
39
+ )
40
+
41
+ self.fc = nn.Linear(
42
+ in_features=rnn_dim * (2 if bidirectional else 1), out_features=emb_dim
43
+ )
44
+
45
+ self.use_batch_trick = use_batch_trick
46
+ if not self.use_batch_trick:
47
+ warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")
48
+
49
+ def forward(self, z):
50
+ # z = (batch, n_uncrossed, n_across, emb_dim)
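+ # 'batch trick': fold the un-modelled axis into the batch so a single RNN call covers every band (or frame) in parallel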
51
+
52
+ z0 = torch.clone(z)
53
+ z = self.norm(z)
54
+
55
+ batch, n_uncrossed, n_across, emb_dim = z.shape
56
+ z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
57
+ z = self.rnn(z)[0]
58
+ z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))
59
+
60
+ z = self.fc(z) # (batch, n_uncrossed, n_across, emb_dim)
61
+
62
+ z = z + z0
63
+
64
+ return z
65
+
66
+
67
+ class Transpose(nn.Module):
68
+ def __init__(self, dim0: int, dim1: int) -> None:
69
+ super().__init__()
70
+ self.dim0 = dim0
71
+ self.dim1 = dim1
72
+
73
+ def forward(self, z):
74
+ return z.transpose(self.dim0, self.dim1)
75
+
76
+
77
+ class SeqBandModellingModule(TimeFrequencyModellingModule):
78
+ def __init__(
79
+ self,
80
+ n_modules: int = 12,
81
+ emb_dim: int = 128,
82
+ rnn_dim: int = 256,
83
+ bidirectional: bool = True,
84
+ rnn_type: str = "LSTM",
85
+ parallel_mode=False,
86
+ ) -> None:
87
+ super().__init__()
88
+
89
+ self.n_modules = n_modules
90
+
91
+ if parallel_mode:
92
+ self.seqband = nn.ModuleList([])
93
+ for _ in range(n_modules):
94
+ self.seqband.append(
95
+ nn.ModuleList(
96
+ [
97
+ ResidualRNN(
98
+ emb_dim=emb_dim,
99
+ rnn_dim=rnn_dim,
100
+ bidirectional=bidirectional,
101
+ rnn_type=rnn_type,
102
+ ),
103
+ ResidualRNN(
104
+ emb_dim=emb_dim,
105
+ rnn_dim=rnn_dim,
106
+ bidirectional=bidirectional,
107
+ rnn_type=rnn_type,
108
+ ),
109
+ ]
110
+ )
111
+ )
112
+ else:
113
+ seqband = []
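+ # alternate ResidualRNN with Transpose(1, 2) so successive blocks model along time and along bands in turn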
114
+ for _ in range(2 * n_modules):
115
+ seqband += [
116
+ ResidualRNN(
117
+ emb_dim=emb_dim,
118
+ rnn_dim=rnn_dim,
119
+ bidirectional=bidirectional,
120
+ rnn_type=rnn_type,
121
+ ),
122
+ Transpose(1, 2),
123
+ ]
124
+
125
+ self.seqband = nn.Sequential(*seqband)
126
+
127
+ self.parallel_mode = parallel_mode
128
+
129
+ def forward(self, z):
130
+ # z = (batch, n_bands, n_time, emb_dim)
131
+
132
+ if self.parallel_mode:
133
+ for sbm_pair in self.seqband:
134
+ # z: (batch, n_bands, n_time, emb_dim)
135
+ sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
136
+ zt = sbm_t(z) # (batch, n_bands, n_time, emb_dim)
137
+ zf = sbm_f(z.transpose(1, 2)) # (batch, n_time, n_bands, emb_dim)
138
+ z = zt + zf.transpose(1, 2)
139
+ else:
140
+ z = checkpoint_sequential(
141
+ self.seqband, self.n_modules, z, use_reentrant=False
142
+ )
143
+
144
+ q = z
145
+ return q # (batch, n_bands, n_time, emb_dim)
separator/models/bandit_v2/utils.py ADDED
@@ -0,0 +1,523 @@
1
+ import os
2
+ from abc import abstractmethod
3
+ from typing import Callable
4
+
5
+ import numpy as np
6
+ import torch
7
+ from librosa import hz_to_midi, midi_to_hz
8
+ from torchaudio import functional as taF
9
+
10
+ # from spafe.fbanks import bark_fbanks
11
+ # from spafe.utils.converters import erb2hz, hz2bark, hz2erb
12
+
13
+
14
+ def band_widths_from_specs(band_specs):
15
+ return [e - i for i, e in band_specs]
16
+
17
+
18
+ def check_nonzero_bandwidth(band_specs):
19
+ # pprint(band_specs)
20
+ for fstart, fend in band_specs:
21
+ if fend - fstart <= 0:
22
+ raise ValueError("Bands cannot be zero-width")
23
+
24
+
25
+ def check_no_overlap(band_specs):
26
+ fend_prev = -1
27
+ for fstart_curr, fend_curr in band_specs:
28
+ if fstart_curr <= fend_prev:
29
+ raise ValueError("Bands cannot overlap")
+ fend_prev = fend_curr  # update each iteration so every band is checked against the previous band's end
30
+
31
+
32
+ def check_no_gap(band_specs):
33
+ fstart, _ = band_specs[0]
34
+ assert fstart == 0
35
+
36
+ fend_prev = -1
37
+ for fstart_curr, fend_curr in band_specs:
38
+ if fstart_curr - fend_prev > 1:
39
+ raise ValueError("Bands cannot leave gap")
40
+ fend_prev = fend_curr
41
+
42
+
43
+ class BandsplitSpecification:
44
+ def __init__(self, nfft: int, fs: int) -> None:
45
+ self.fs = fs
46
+ self.nfft = nfft
47
+ self.nyquist = fs / 2
48
+ self.max_index = nfft // 2 + 1
49
+
50
+ self.split500 = self.hertz_to_index(500)
51
+ self.split1k = self.hertz_to_index(1000)
52
+ self.split2k = self.hertz_to_index(2000)
53
+ self.split4k = self.hertz_to_index(4000)
54
+ self.split8k = self.hertz_to_index(8000)
55
+ self.split16k = self.hertz_to_index(16000)
56
+ self.split20k = self.hertz_to_index(20000)
57
+
58
+ self.above20k = [(self.split20k, self.max_index)]
59
+ self.above16k = [(self.split16k, self.split20k)] + self.above20k
60
+
61
+ def index_to_hertz(self, index: int):
62
+ return index * self.fs / self.nfft
63
+
64
+ def hertz_to_index(self, hz: float, round: bool = True):
65
+ index = hz * self.nfft / self.fs
66
+
67
+ if round:
68
+ index = int(np.round(index))
69
+
70
+ return index
71
+
72
+ def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
73
+ band_specs = []
74
+ lower = start_index
75
+
76
+ while lower < end_index:
77
+ upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz)))
78
+ upper = min(upper, end_index)
79
+
80
+ band_specs.append((lower, upper))
81
+ lower = upper
82
+
83
+ return band_specs
84
+
85
+ @abstractmethod
86
+ def get_band_specs(self):
87
+ raise NotImplementedError
88
+
89
+
90
+ class VocalBandsplitSpecification(BandsplitSpecification):
91
+ def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
92
+ super().__init__(nfft=nfft, fs=fs)
93
+
94
+ self.version = version
95
+
96
+ def get_band_specs(self):
97
+ return getattr(self, f"version{self.version}")()
98
+
99
100
+ def version1(self):
101
+ return self.get_band_specs_with_bandwidth(
102
+ start_index=0, end_index=self.max_index, bandwidth_hz=1000
103
+ )
104
+
105
+ def version2(self):
106
+ below16k = self.get_band_specs_with_bandwidth(
107
+ start_index=0, end_index=self.split16k, bandwidth_hz=1000
108
+ )
109
+ below20k = self.get_band_specs_with_bandwidth(
110
+ start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
111
+ )
112
+
113
+ return below16k + below20k + self.above20k
114
+
115
+ def version3(self):
116
+ below8k = self.get_band_specs_with_bandwidth(
117
+ start_index=0, end_index=self.split8k, bandwidth_hz=1000
118
+ )
119
+ below16k = self.get_band_specs_with_bandwidth(
120
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
121
+ )
122
+
123
+ return below8k + below16k + self.above16k
124
+
125
+ def version4(self):
126
+ below1k = self.get_band_specs_with_bandwidth(
127
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
128
+ )
129
+ below8k = self.get_band_specs_with_bandwidth(
130
+ start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000
131
+ )
132
+ below16k = self.get_band_specs_with_bandwidth(
133
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
134
+ )
135
+
136
+ return below1k + below8k + below16k + self.above16k
137
+
138
+ def version5(self):
139
+ below1k = self.get_band_specs_with_bandwidth(
140
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
141
+ )
142
+ below16k = self.get_band_specs_with_bandwidth(
143
+ start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000
144
+ )
145
+ below20k = self.get_band_specs_with_bandwidth(
146
+ start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
147
+ )
148
+ return below1k + below16k + below20k + self.above20k
149
+
150
+ def version6(self):
151
+ below1k = self.get_band_specs_with_bandwidth(
152
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
153
+ )
154
+ below4k = self.get_band_specs_with_bandwidth(
155
+ start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
156
+ )
157
+ below8k = self.get_band_specs_with_bandwidth(
158
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
159
+ )
160
+ below16k = self.get_band_specs_with_bandwidth(
161
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
162
+ )
163
+ return below1k + below4k + below8k + below16k + self.above16k
164
+
165
+ def version7(self):
166
+ below1k = self.get_band_specs_with_bandwidth(
167
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
168
+ )
169
+ below4k = self.get_band_specs_with_bandwidth(
170
+ start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250
171
+ )
172
+ below8k = self.get_band_specs_with_bandwidth(
173
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
174
+ )
175
+ below16k = self.get_band_specs_with_bandwidth(
176
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
177
+ )
178
+ below20k = self.get_band_specs_with_bandwidth(
179
+ start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
180
+ )
181
+ return below1k + below4k + below8k + below16k + below20k + self.above20k
182
+
183
+
184
+ class OtherBandsplitSpecification(VocalBandsplitSpecification):
185
+ def __init__(self, nfft: int, fs: int) -> None:
186
+ super().__init__(nfft=nfft, fs=fs, version="7")
187
+
188
+
189
+ class BassBandsplitSpecification(BandsplitSpecification):
190
+ def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
191
+ super().__init__(nfft=nfft, fs=fs)
192
+
193
+ def get_band_specs(self):
194
+ below500 = self.get_band_specs_with_bandwidth(
195
+ start_index=0, end_index=self.split500, bandwidth_hz=50
196
+ )
197
+ below1k = self.get_band_specs_with_bandwidth(
198
+ start_index=self.split500, end_index=self.split1k, bandwidth_hz=100
199
+ )
200
+ below4k = self.get_band_specs_with_bandwidth(
201
+ start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
202
+ )
203
+ below8k = self.get_band_specs_with_bandwidth(
204
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
205
+ )
206
+ below16k = self.get_band_specs_with_bandwidth(
207
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
208
+ )
209
+ above16k = [(self.split16k, self.max_index)]
210
+
211
+ return below500 + below1k + below4k + below8k + below16k + above16k
212
+
213
+
214
+ class DrumBandsplitSpecification(BandsplitSpecification):
215
+ def __init__(self, nfft: int, fs: int) -> None:
216
+ super().__init__(nfft=nfft, fs=fs)
217
+
218
+ def get_band_specs(self):
219
+ below1k = self.get_band_specs_with_bandwidth(
220
+ start_index=0, end_index=self.split1k, bandwidth_hz=50
221
+ )
222
+ below2k = self.get_band_specs_with_bandwidth(
223
+ start_index=self.split1k, end_index=self.split2k, bandwidth_hz=100
224
+ )
225
+ below4k = self.get_band_specs_with_bandwidth(
226
+ start_index=self.split2k, end_index=self.split4k, bandwidth_hz=250
227
+ )
228
+ below8k = self.get_band_specs_with_bandwidth(
229
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
230
+ )
231
+ below16k = self.get_band_specs_with_bandwidth(
232
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
233
+ )
234
+ above16k = [(self.split16k, self.max_index)]
235
+
236
+ return below1k + below2k + below4k + below8k + below16k + above16k
237
+
238
+
239
+ class PerceptualBandsplitSpecification(BandsplitSpecification):
240
+ def __init__(
241
+ self,
242
+ nfft: int,
243
+ fs: int,
244
+ fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
245
+ n_bands: int,
246
+ f_min: float = 0.0,
247
+ f_max: float = None,
248
+ ) -> None:
249
+ super().__init__(nfft=nfft, fs=fs)
250
+ self.n_bands = n_bands
251
+ if f_max is None:
252
+ f_max = fs / 2
253
+
254
+ self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)
255
+
256
+ weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True) # (1, n_freqs)
257
+ normalized_mel_fb = self.filterbank / weight_per_bin # (n_mels, n_freqs)
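+ # a frequency bin can belong to several overlapping bands; dividing by the per-bin total makes the band weights for each bin sum to one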
258
+
259
+ freq_weights = []
260
+ band_specs = []
261
+ for i in range(self.n_bands):
262
+ active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
263
+ if isinstance(active_bins, int):
264
+ active_bins = (active_bins, active_bins)
265
+ if len(active_bins) == 0:
266
+ continue
267
+ start_index = active_bins[0]
268
+ end_index = active_bins[-1] + 1
269
+ band_specs.append((start_index, end_index))
270
+ freq_weights.append(normalized_mel_fb[i, start_index:end_index])
271
+
272
+ self.freq_weights = freq_weights
273
+ self.band_specs = band_specs
274
+
275
+ def get_band_specs(self):
276
+ return self.band_specs
277
+
278
+ def get_freq_weights(self):
279
+ return self.freq_weights
280
+
281
+ def save_to_file(self, dir_path: str) -> None:
282
+ os.makedirs(dir_path, exist_ok=True)
283
+
284
+ import pickle
285
+
286
+ with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
287
+ pickle.dump(
288
+ {
289
+ "band_specs": self.band_specs,
290
+ "freq_weights": self.freq_weights,
291
+ "filterbank": self.filterbank,
292
+ },
293
+ f,
294
+ )
295
+
296
+
297
+ def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
298
+ fb = taF.melscale_fbanks(
299
+ n_mels=n_bands,
300
+ sample_rate=fs,
301
+ f_min=f_min,
302
+ f_max=f_max,
303
+ n_freqs=n_freqs,
304
+ ).T
305
+
306
+ fb[0, 0] = 1.0
307
+
308
+ return fb
309
+
310
+
311
+ class MelBandsplitSpecification(PerceptualBandsplitSpecification):
312
+ def __init__(
313
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
314
+ ) -> None:
315
+ super().__init__(
316
+ fbank_fn=mel_filterbank,
317
+ nfft=nfft,
318
+ fs=fs,
319
+ n_bands=n_bands,
320
+ f_min=f_min,
321
+ f_max=f_max,
322
+ )
323
+
324
+
325
+ def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
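+ # rectangular (0/1) filterbank whose band centres are spaced evenly on the MIDI (log-frequency) scale, each band spanning the same number of octaves around its centre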
326
+ nfft = 2 * (n_freqs - 1)
327
+ df = fs / nfft
328
+ # init freqs
329
+ f_max = f_max or fs / 2
330
+ f_min = f_min or 0
331
+ f_min = fs / nfft
332
+
333
+ n_octaves = np.log2(f_max / f_min)
334
+ n_octaves_per_band = n_octaves / n_bands
335
+ bandwidth_mult = np.power(2.0, n_octaves_per_band)
336
+
337
+ low_midi = max(0, hz_to_midi(f_min))
338
+ high_midi = hz_to_midi(f_max)
339
+ midi_points = np.linspace(low_midi, high_midi, n_bands)
340
+ hz_pts = midi_to_hz(midi_points)
341
+
342
+ low_pts = hz_pts / bandwidth_mult
343
+ high_pts = hz_pts * bandwidth_mult
344
+
345
+ low_bins = np.floor(low_pts / df).astype(int)
346
+ high_bins = np.ceil(high_pts / df).astype(int)
347
+
348
+ fb = np.zeros((n_bands, n_freqs))
349
+
350
+ for i in range(n_bands):
351
+ fb[i, low_bins[i] : high_bins[i] + 1] = 1.0
352
+
353
+ fb[0, : low_bins[0]] = 1.0
354
+ fb[-1, high_bins[-1] + 1 :] = 1.0
355
+
356
+ return torch.as_tensor(fb)
357
+
358
+
359
+ class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
360
+ def __init__(
361
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
362
+ ) -> None:
363
+ super().__init__(
364
+ fbank_fn=musical_filterbank,
365
+ nfft=nfft,
366
+ fs=fs,
367
+ n_bands=n_bands,
368
+ f_min=f_min,
369
+ f_max=f_max,
370
+ )
371
+
372
+
373
+ # def bark_filterbank(
374
+ # n_bands, fs, f_min, f_max, n_freqs
375
+ # ):
376
+ # nfft = 2 * (n_freqs -1)
377
+ # fb, _ = bark_fbanks.bark_filter_banks(
378
+ # nfilts=n_bands,
379
+ # nfft=nfft,
380
+ # fs=fs,
381
+ # low_freq=f_min,
382
+ # high_freq=f_max,
383
+ # scale="constant"
384
+ # )
385
+
386
+ # return torch.as_tensor(fb)
387
+
388
+ # class BarkBandsplitSpecification(PerceptualBandsplitSpecification):
389
+ # def __init__(
390
+ # self,
391
+ # nfft: int,
392
+ # fs: int,
393
+ # n_bands: int,
394
+ # f_min: float = 0.0,
395
+ # f_max: float = None
396
+ # ) -> None:
397
+ # super().__init__(fbank_fn=bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
398
+
399
+
400
+ # def triangular_bark_filterbank(
401
+ # n_bands, fs, f_min, f_max, n_freqs
402
+ # ):
403
+
404
+ # all_freqs = torch.linspace(0, fs // 2, n_freqs)
405
+
406
+ # # calculate mel freq bins
407
+ # m_min = hz2bark(f_min)
408
+ # m_max = hz2bark(f_max)
409
+
410
+ # m_pts = torch.linspace(m_min, m_max, n_bands + 2)
411
+ # f_pts = 600 * torch.sinh(m_pts / 6)
412
+
413
+ # # create filterbank
414
+ # fb = _create_triangular_filterbank(all_freqs, f_pts)
415
+
416
+ # fb = fb.T
417
+
418
+ # first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
419
+ # first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
420
+
421
+ # fb[first_active_band, :first_active_bin] = 1.0
422
+
423
+ # return fb
424
+
425
+ # class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification):
426
+ # def __init__(
427
+ # self,
428
+ # nfft: int,
429
+ # fs: int,
430
+ # n_bands: int,
431
+ # f_min: float = 0.0,
432
+ # f_max: float = None
433
+ # ) -> None:
434
+ # super().__init__(fbank_fn=triangular_bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
435
+
436
+
437
+ # def minibark_filterbank(
438
+ # n_bands, fs, f_min, f_max, n_freqs
439
+ # ):
440
+ # fb = bark_filterbank(
441
+ # n_bands,
442
+ # fs,
443
+ # f_min,
444
+ # f_max,
445
+ # n_freqs
446
+ # )
447
+
448
+ # fb[fb < np.sqrt(0.5)] = 0.0
449
+
450
+ # return fb
451
+
452
+ # class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification):
453
+ # def __init__(
454
+ # self,
455
+ # nfft: int,
456
+ # fs: int,
457
+ # n_bands: int,
458
+ # f_min: float = 0.0,
459
+ # f_max: float = None
460
+ # ) -> None:
461
+ # super().__init__(fbank_fn=minibark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
462
+
463
+
464
+ # def erb_filterbank(
465
+ # n_bands: int,
466
+ # fs: int,
467
+ # f_min: float,
468
+ # f_max: float,
469
+ # n_freqs: int,
470
+ # ) -> Tensor:
471
+ # # freq bins
472
+ # A = (1000 * np.log(10)) / (24.7 * 4.37)
473
+ # all_freqs = torch.linspace(0, fs // 2, n_freqs)
474
+
475
+ # # calculate mel freq bins
476
+ # m_min = hz2erb(f_min)
477
+ # m_max = hz2erb(f_max)
478
+
479
+ # m_pts = torch.linspace(m_min, m_max, n_bands + 2)
480
+ # f_pts = (torch.pow(10, (m_pts / A)) - 1)/ 0.00437
481
+
482
+ # # create filterbank
483
+ # fb = _create_triangular_filterbank(all_freqs, f_pts)
484
+
485
+ # fb = fb.T
486
+
487
+
488
+ # first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
489
+ # first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
490
+
491
+ # fb[first_active_band, :first_active_bin] = 1.0
492
+
493
+ # return fb
494
+
495
+
496
+ # class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification):
497
+ # def __init__(
498
+ # self,
499
+ # nfft: int,
500
+ # fs: int,
501
+ # n_bands: int,
502
+ # f_min: float = 0.0,
503
+ # f_max: float = None
504
+ # ) -> None:
505
+ # super().__init__(fbank_fn=erb_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
506
+
507
+ if __name__ == "__main__":
508
+ import pandas as pd
509
+
510
+ band_defs = []
511
+
512
+ for bands in [VocalBandsplitSpecification]:
513
+ band_name = bands.__name__.replace("BandsplitSpecification", "")
514
+
515
+ mbs = bands(nfft=2048, fs=44100).get_band_specs()
516
+
517
+ for i, (f_min, f_max) in enumerate(mbs):
518
+ band_defs.append(
519
+ {"band": band_name, "band_index": i, "f_min": f_min, "f_max": f_max}
520
+ )
521
+
522
+ df = pd.DataFrame(band_defs)
523
+ df.to_csv("vox7bands.csv", index=False)
separator/models/bs_roformer/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from models.bs_roformer.bs_roformer import BSRoformer
2
+ from models.bs_roformer.mel_band_roformer import MelBandRoformer
separator/models/bs_roformer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (281 Bytes).