diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..71b0f52feaea731130d5d2605908c2a4b854ed5e
--- /dev/null
+++ b/app.py
@@ -0,0 +1,131 @@
+import os
+import sys
+import time
+import shutil
+import argparse
+from datetime import datetime
+import gradio as gr
+os.system("pip install https://github.com/noblebarkrr/mvsepless/blob/bd611441e48e918650e6860738894673b3a1a5f1/fixed/audio_separator-0.32.0-py3-none-any.whl")
+from multi_inference import MVSEPLESS, OUTPUT_FORMATS
+from assets.translations import TRANSLATIONS, TRANSLATIONS_STEMS
+
+OUTPUT_DIR = os.path.join(os.getcwd(), "output")
+plugins_dir = os.path.join(os.getcwd(), "plugins")
+os.makedirs(plugins_dir, exist_ok=True)
+
+CURRENT_LANG = "ru"
+
+def t(key, **kwargs):
+    """Return the translation for a key, substituting any format values."""
+    lang = CURRENT_LANG
+    translation = TRANSLATIONS.get(lang, {}).get(key, key)
+    return translation.format(**kwargs) if kwargs else translation
+
+def t_stem(key, **kwargs):
+    """Return the translated stem name for a key, substituting any format values."""
+    lang = CURRENT_LANG
+    translation = TRANSLATIONS_STEMS.get(lang, {}).get(key, key)
+    return translation.format(**kwargs) if kwargs else translation
+
+def gen_out_dir():
+    return os.path.join(OUTPUT_DIR, datetime.now().strftime("%Y%m%d_%H%M%S"))
+
+mvsepless = MVSEPLESS()
+
+def sep_wrapper(a, b, c, d, e, f, g, h):
+    results = mvsepless.separator(input_file=a, output_dir=gen_out_dir(), model_type=b, model_name=c, ext_inst=d, vr_aggr=e, output_format=f, output_bitrate=f'{g}k', call_method="cli", selected_stems=h)
+    stems = []
+    if results:
+        for i, (stem, output_file) in enumerate(results[:20]):
+            stems.append(gr.update(
+                visible=True,
+                label=t_stem(stem),
+                value=output_file
+            ))
+
+    while len(stems) < 20:
+        stems.append(gr.update(visible=False, label=None, value=None))
+
+    return tuple(stems)
+
+
+theme = gr.themes.Default(
+    primary_hue="violet",
+    secondary_hue="cyan",
+    neutral_hue="blue",
+    spacing_size="sm",
+    font=[gr.themes.GoogleFont("Tektur"), 'ui-sans-serif', 'system-ui', 'sans-serif'],
+    ).set(
+    body_text_color='*neutral_950',
+    body_text_color_subdued='*neutral_500',
+    background_fill_primary='*neutral_200',
+    background_fill_primary_dark='*neutral_800',
+    border_color_accent='*primary_950',
+    border_color_accent_dark='*neutral_700',
+    border_color_accent_subdued='*primary_500',
+    border_color_primary='*primary_800',
+    border_color_primary_dark='*neutral_400',
+    color_accent_soft='*primary_100',
+    color_accent_soft_dark='*neutral_800',
+    link_text_color='*secondary_700',
+    link_text_color_active='*secondary_700',
+    link_text_color_hover='*secondary_800',
+    link_text_color_visited='*secondary_600',
+    link_text_color_visited_dark='*secondary_700',
+    block_background_fill='*background_fill_secondary',
+    block_background_fill_dark='*neutral_950',
+    block_label_background_fill='*secondary_400',
+    block_label_text_color='*neutral_800',
+    panel_background_fill='*background_fill_primary',
+    checkbox_background_color='*background_fill_secondary',
+    checkbox_label_background_fill_dark='*neutral_900',
+    input_background_fill_dark='*neutral_900',
+    input_background_fill_focus='*neutral_100',
+    input_background_fill_focus_dark='*neutral_950',
+    button_small_radius='*radius_sm',
+    button_secondary_background_fill='*neutral_400',
+    button_secondary_background_fill_dark='*neutral_500',
+    button_secondary_background_fill_hover_dark='*neutral_950'
+    )
+
+
+def create_app():
+    with gr.Row():
+        with gr.Column():
+            input_audio = gr.Audio(label=t("select_file"), interactive=True, type="filepath")
+            input_audio_path = gr.Textbox(label=t("audio_path"), info=t("audio_path_info"), interactive=True)
+        with gr.Column():
+            with gr.Row():
+                model_type = gr.Dropdown(label=t("model_type"), choices=mvsepless.get_mt(), value=mvsepless.get_mt()[0], interactive=True, filterable=False)
+                model_name = gr.Dropdown(label=t("model_name"), choices=mvsepless.get_mn(mvsepless.get_mt()[0]), value=mvsepless.get_mn(mvsepless.get_mt()[0])[0], interactive=True, filterable=False)
+            target_instrument = gr.Textbox(label=t("target_instrument"), value=mvsepless.get_tgt_inst(mvsepless.get_mt()[0], mvsepless.get_mn(mvsepless.get_mt()[0])[0]), interactive=False)
+            vr_aggr = gr.Slider(0, 100, step=1, label=t("vr_aggressiveness"), visible=False, value=5, interactive=True)
+            extract_instrumental = gr.Checkbox(label=t("extract_instrumental"), value=True, interactive=True)
+            stems_list = gr.CheckboxGroup(label=t("stems_list"), info=t("stems_info", target_instrument="vocals"), choices=mvsepless.get_stems(mvsepless.get_mt()[0], mvsepless.get_mn(mvsepless.get_mt()[0])[0]), value=None, interactive=False)
+    with gr.Row():
+        output_format, output_bitrate = gr.Dropdown(label=t("output_format"), choices=OUTPUT_FORMATS, value="mp3", interactive=True, filterable=False), gr.Slider(32, 320, step=1, label=t("bitrate"), value=320, interactive=True)
+    separate_btn = gr.Button(t("separate_btn"), variant="primary", interactive=True)
+    download_via_zip_btn = gr.DownloadButton(label="Download via zip", visible=False, interactive=True)
+    output_stems = []
+    for _ in range(10):
+        with gr.Row():
+            audio1 = gr.Audio(visible=False, interactive=False, type="filepath", show_download_button=True)
+            audio2 = gr.Audio(visible=False, interactive=False, type="filepath", show_download_button=True)
+        output_stems.extend([audio1, audio2])
+
+    input_audio.upload(fn=(lambda x: gr.update(value=x)), inputs=input_audio, outputs=input_audio_path)
+    model_type.change(fn=(lambda x: gr.update(choices=mvsepless.get_mn(x), value=mvsepless.get_mn(x)[0])), inputs=model_type, outputs=model_name).then(fn=(lambda x: (gr.update(visible=False if x in ["vr", "mdx"] else True), gr.update(visible=True if x == "vr" else False))), inputs=model_type, outputs=[extract_instrumental, vr_aggr])
+    model_name.change(fn=(lambda x, y: gr.update(choices=mvsepless.get_stems(x, y), value=None)), inputs=[model_type, model_name], outputs=stems_list).then(fn=(lambda x, y: (gr.update(interactive=True if mvsepless.get_tgt_inst(x, y) is None else None, info=t("stems_info", target_instrument=mvsepless.get_tgt_inst(x, y)) if mvsepless.get_tgt_inst(x, y) is not None else t("stems_info2")), gr.update(value=mvsepless.get_tgt_inst(x, y)), gr.update(value=True if mvsepless.get_tgt_inst(x, y) is not None else False))), inputs=[model_type, model_name], outputs=[stems_list, target_instrument, extract_instrumental])
+    separate_btn.click(fn=sep_wrapper, inputs=[input_audio_path, model_type, model_name, extract_instrumental, vr_aggr, output_format, output_bitrate, stems_list], outputs=output_stems, show_progress_on=input_audio)
+
+
+CURRENT_LANG = "ru"
+css = """
+.fixed-height { height: 160px !important; min-height: 160px !important; }
+.fixed-height2 { height: 250px !important; min-height: 250px !important; }
+"""
+
+with gr.Blocks(theme=theme, css=css) as app:
+    create_app()
+
+app.launch(allowed_paths=["/"], server_port=7860, share=False)
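For reference, a minimal standalone sketch (not part of the patch) of the lookup-and-fallback behaviour implemented by t() in app.py above; the TRANSLATIONS table here is a trimmed, illustrative copy of the one added in assets/translations.py below. Unknown keys fall back to the key itself, and keyword arguments are substituted via str.format:

# Illustrative, self-contained copy of app.py's translation helper.
TRANSLATIONS = {"en": {"separate_btn": "Separate", "loading_plugin": "Loading plugin: {name}"}}
CURRENT_LANG = "en"

def t(key, **kwargs):
    # Look up the key for the current language; fall back to the key itself if missing.
    translation = TRANSLATIONS.get(CURRENT_LANG, {}).get(key, key)
    return translation.format(**kwargs) if kwargs else translation

print(t("separate_btn"))                    # Separate
print(t("loading_plugin", name="demo.py"))  # Loading plugin: demo.py
print(t("missing_key"))                     # missing_key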
diff --git a/assets/translations.py b/assets/translations.py
new file mode 100644
index 0000000000000000000000000000000000000000..59c7ea3dc7e15b325b9c2dafa9f8f50f10d369a3
--- /dev/null
+++ b/assets/translations.py
@@ -0,0 +1,171 @@
+TRANSLATIONS = {
+    "ru": {
+        "app_title": "MVSEPLESS",
+        "separation": "Разделение",
+        "plugins": "Плагины",
+        "select_file": "Выберите файл",
+        "audio_path": "Путь к файлу",
+        "audio_path_info": "Здесь можно ввести путь к файлу, либо загрузить его выше и получить путь к загруженному файлу",
+        "model_type": "Тип модели",
+        "model_name": "Имя модели",
+        "vr_aggressiveness": "Агрессивность для VR моделей",
+        "extract_instrumental": "Извлечь инструментал",
+        "stems_list": "Список стемов",
+        "output_format": "Формат вывода",
+        "separate_btn": "Разделить",
+        "upload": "Загрузка плагинов (.py)",
+        "upload_btn": "Загрузить",
+        "loading_plugin": "Загружается плагин: {name}",
+        "error_loading_plugin": "Произошла ошибка при загрузке плагина: {e}",
+        "target_instrument": "Целевой инструмент",
+        "stems_info": "Выбор стемов недоступен\nДля извлечения второго стема включите \"Извлечь инструментал\"",
+        "stems_info2": "Для получения остатка (при выбранных стемах), включите \"Извлечь инструментал\"",
+        "bitrate": "Битрейт (Кбит/сек)"
+    },
+    "en": {
+        "app_title": "MVSEPLESS",
+        "separation": "Separation",
+        "plugins": "Plugins",
+        "select_file": "Select File",
+        "audio_path": "Audio path",
+        "audio_path_info": "You can enter the file path here, or upload it above and get the path to the uploaded file.",
+        "model_type": "Model Type",
+        "model_name": "Model Name",
+        "vr_aggressiveness": "Aggressiveness for VR Models",
+        "extract_instrumental": "Extract Instrumental",
+        "stems_list": "Stems List",
+        "output_format": "Output Format",
+        "separate_btn": "Separate",
+        "upload": "Upload plugins (.py)",
+        "upload_btn": "Upload",
+        "loading_plugin": "Loading plugin: {name}",
+        "error_loading_plugin": "An error occurred while loading plugin: {e}",
+        "target_instrument": "Target instrument",
+        "stems_info": "Stem selection unavailable\nEnable \"Extract Instrumental\" to extract the second stem",
+        "stems_info2": "To extract the residual (when stems are selected), enable \"Extract Instrumental\"",
+        "bitrate": "Bitrate (Kbit/sec)"
+    }
+}
+
+TRANSLATIONS_STEMS = {
+    "ru": {
+        "vocals": "Вокал",
+        "Vocals": "Вокал",
+        "other": "Другое",
+        "Other": "Другое",
+        "Instrumental": "Инструментал",
+        "instrumnetal": "Инструментал",
+        "instrumental +": "Инструментал +",
+        "instrumental -": "Инструментал -",
+        "Bleed": "Фон",
+        "Guitar": "Гитара",
+        "drums": "Барабаны",
+        "bass": "Бас",
+        "karaoke": "Караоке",
+        "reverb": "Реверберация",
+        "noreverb": "Без реверберации",
+        "aspiration": "Придыхание",
+        "dry": "Сухой звук",
+        "crowd": "Толпа",
+        "percussions": "Перкуссия",
+        "piano": "Пианино",
+        "guitar": "Гитара",
+        "male": "Мужской",
+        "female": "Женский",
+        "kick": "Кик",
+        "snare": "Малый барабан",
+        "toms": "Том-томы",
+        "hh": "Хай-хэт",
+        "ride": "Райд",
+        "crash": "Крэш",
+        "similarity": "Сходство",
+        "difference": "Различие",
+        "inst": "Инструмент",
+        "orch": "Оркестр",
+        "No Woodwinds": "Без деревянных духовых",
+        "Woodwinds": "Деревянные духовые",
+        "No Echo": "Без эха",
+        "Echo": "Эхо",
+        "No Reverb": "Без реверберации",
+        "Reverb": "Реверберация",
+        "Noise": "Шум",
+        "No Noise": "Без шума",
+        "Dry": "Сухой звук",
+        "No Dry": "Не сухой звук",
+        "Breath": "Дыхание",
+        "No Breath": "Без дыхания",
+        "No Crowd": "Без толпы",
+        "Crowd": "Толпа",
+        "No Other": "Без другого",
+        "Bass": "Бас",
+        "No Bass": "Без баса",
+        "Drums": "Барабаны",
+        "No Drums": "Без барабанов",
+        "speech": "Речь",
"music": "Музыка", + "effects": "Эффекты", + "sfx": "Звуковые эффекты", + "inverted +": "Инверсия +", + "inverted -": "Инверсия -" + }, + "en": { + "vocals": "Vocals", + "Vocals": "Vocals", + "other": "Other", + "Other": "Other", + "Instrumental": "Instrumental", + "instrumnetal": "Instrumental", + "instrumental +": "Instrumental +", + "instrumental -": "Instrumental -", + "Bleed": "Bleed", + "Guitar": "Guitar", + "drums": "Drums", + "bass": "Bass", + "karaoke": "Karaoke", + "reverb": "Reverb", + "noreverb": "No reverb", + "aspiration": "Aspiration", + "dry": "Dry", + "crowd": "Crowd", + "percussions": "Percussions", + "piano": "Piano", + "guitar": "Guitar", + "male": "Male", + "female": "Female", + "kick": "Kick", + "snare": "Snare", + "toms": "Toms", + "hh": "Hi-hat", + "ride": "Ride", + "crash": "Crash", + "similarity": "Similarity", + "difference": "Difference", + "inst": "Instrument", + "orch": "Orchestra", + "No Woodwinds": "No Woodwinds", + "Woodwinds": "Woodwinds", + "No Echo": "No Echo", + "Echo": "Echo", + "No Reverb": "No Reverb", + "Reverb": "Reverb", + "Noise": "Noise", + "No Noise": "No Noise", + "Dry": "Dry", + "No Dry": "No Dry", + "Breath": "Breath", + "No Breath": "No Breath", + "No Crowd": "No Crowd", + "Crowd": "Crowd", + "No Other": "No Other", + "Bass": "Bass", + "No Bass": "No Bass", + "Drums": "Drums", + "No Drums": "No Drums", + "speech": "Speech", + "music": "Music", + "effects": "Effects", + "sfx": "SFX", + "inverted +": "Inverted +", + "inverted -": "Inverted -" + } +} diff --git a/model_list.py b/model_list.py new file mode 100644 index 0000000000000000000000000000000000000000..a6534cdf33ad7cc8da0e47e1a49b39a6773af8ce --- /dev/null +++ b/model_list.py @@ -0,0 +1,1942 @@ +models_data = { + + "mel_band_roformer": { + + "Mel-Band-Roformer_Vocals_kimberley_jensen": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Vocals by Kimberley Jensen", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/KimberleyJSN/melbandroformer/resolve/main/MelBandRoformer.ckpt?download=true", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/KimberleyJensen/config_vocals_mel_band_roformer_kj.yaml" + }, + + "Mel-Band-Roformer_InstVoc_Duality_v1_unwa": { + "category": "Инструментал и вокал", + "full_name": "Mel-Band Roformer InstVoc Duality v1 by Unwa", + "stems": ["Vocals", "Instrumental"], + "target_instrument": None, + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-InstVoc-Duality/resolve/main/melband_roformer_instvoc_duality_v1.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-InstVoc-Duality/resolve/main/config_melbandroformer_instvoc_duality.yaml?download=true" + }, + + "Mel-Band-Roformer_InstVoc_Duality_v2_unwa": { + "category": "Инструментал и вокал", + "full_name": "Mel-Band Roformer InstVoc Duality v2 by Unwa", + "stems": ["Vocals", "Instrumental"], + "target_instrument": None, + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-InstVoc-Duality/resolve/main/melband_roformer_instvox_duality_v2.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-InstVoc-Duality/resolve/main/config_melbandroformer_instvoc_duality.yaml?download=true" + }, + + "Mel-Band-Roformer_Kim_FT_v1_unwa": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Kim FT v1 by Unwa", + "stems": ["Vocals", "other"], + "target_instrument": "Vocals", + "checkpoint_url": 
"https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/kimmel_unwa_ft.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/config_kimmel_unwa_ft.yaml?download=true" + }, + + "Mel-Band-Roformer_Kim_FT_v2_unwa": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Kim FT v2 by Unwa", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/kimmel_unwa_ft2.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/config_kimmel_unwa_ft.yaml?download=true" + }, + + "Mel-Band-Roformer_Kim_FT_v2_bleedless_unwa": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Kim FT v2 Bleedless by Unwa", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/kimmel_unwa_ft2_bleedless.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/config_kimmel_unwa_ft.yaml?download=true" + }, + + "Mel-Band-Roformer_Kim_FT_v3_prev_unwa": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Kim FT v3 preview by Unwa", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/kimmel_unwa_ft3_prev.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/config_kimmel_unwa_ft.yaml?download=true" + }, + + "Mel-Band-Roformer_Big_Beta_v1_unwa": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Big Beta v1 by Unwa", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/melband_roformer_big_beta1.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/config_melbandroformer_big.yaml?download=true" + }, + + "Mel-Band-Roformer_Big_Beta_v2_unwa": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Big Beta v2 by Unwa", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/melband_roformer_big_beta2.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/config_melbandroformer_big.yaml?download=true" + }, + + "Mel-Band-Roformer_Big_Beta_v3_unwa": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Big Beta v3 by Unwa", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/melband_roformer_big_beta3.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/config_melbandroformer_big.yaml?download=true" + }, + + "Mel-Band-Roformer_Big_Beta_v4_unwa": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Big Beta v4 by Unwa", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/melband_roformer_big_beta4.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/config_melbandroformer_big_beta4.yaml?download=true" + }, + + "Mel-Band-Roformer_Big_Beta_v5e_unwa": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Vocals Big 
Beta v5e by Unwa", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/big_beta5e.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/big_beta5e.yaml?download=true" + }, + + "Mel-Band-Roformer_Big_Beta_v6_unwa": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Big Beta v6 by Unwa", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/big_beta6.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/big_beta6.yaml?download=true" + }, + + "Mel-Band-Roformer_Big_Beta_v6x_unwa": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Big Beta v6x by Unwa", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/big_beta6x.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/big_beta6x.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_v1_unwa": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental v1 by Unwa", + "stems": ["vocals", "other"], + "target_instrument": "other", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/melband_roformer_inst_v1.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/config_melbandroformer_inst.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_v1_plus_unwa": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental v1+ by Unwa", + "stems": ["vocals", "other"], + "target_instrument": "other", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/inst_v1_plus_test.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/config_melbandroformer_inst.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_v1e_unwa": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental v1e by Unwa", + "stems": ["vocals", "other"], + "target_instrument": "other", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/inst_v1e.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/config_melbandroformer_inst.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_v1e_plus_unwa": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental v1e Plus by Unwa", + "stems": ["vocals", "other"], + "target_instrument": "other", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/inst_v1e_plus.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/config_melbandroformer_inst.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_v2_unwa": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental v2 by Unwa", + "stems": ["vocals", "other"], + "target_instrument": "other", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/melband_roformer_inst_v2.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/config_melbandroformer_inst_v2.yaml?download=true" + }, + + "Mel-Band-Roformer_Small_v1_unwa": { + 
"category": "Вокал", + "full_name": "Mel-Band Roformer Small v1 by Unwa", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-small/resolve/main/melband_roformer_small_v1.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/Mel-Band-Roformer-small/resolve/main/config_melbandroformer_small.yaml?download=true" + }, + + "Mel-Band-Roformer_Bleed_Suppressor_v1_unwa_97chris": { + "category": "Шум", + "full_name": "Mel-Band Roformer Bleed Suppressor v1 by Unwa / 97chris", + "stems": ["Instrumental", "Bleed"], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/jarredou/bleed_suppressor_melband_rofo_by_unwa_97chris/resolve/main/bleed_suppressor_v1.ckpt?download=true", + "config_url": "https://huggingface.co/jarredou/bleed_suppressor_melband_rofo_by_unwa_97chris/resolve/main/config_bleed_suppressor_v1.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_becruliy": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental by Becruily", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/becruily/mel-band-roformer-instrumental/resolve/main/mel_band_roformer_instrumental_becruily.ckpt?download=true", + "config_url": "https://huggingface.co/becruily/mel-band-roformer-instrumental/resolve/main/config_instrumental_becruily.yaml?download=true" + }, + + "Mel-Band-Roformer_Guitar_becruily": { + "category": "Гитара", + "full_name": "Mel-Band Roformer Instrumental by Becruily", + "stems": ["Guitar", "Other"], + "target_instrument": "Guitar", + "checkpoint_url": "https://huggingface.co/becruily/mel-band-roformer-guitar/resolve/main/becruily_guitar.ckpt?download=true", + "config_url": "https://huggingface.co/becruily/mel-band-roformer-guitar/resolve/main/config_guitar_becruily.yaml?download=true" + }, + + "Mel-Band-Roformer_Karaoke_becruily": { + "category": "Караоке", + "full_name": "Mel-Band Roformer Karaoke by Becruily", + "stems": ["Vocals", "Instrumental"], + "target_instrument": None, + "checkpoint_url": "https://huggingface.co/becruily/mel-band-roformer-karaoke/resolve/main/mel_band_roformer_karaoke_becruily.ckpt?download=true", + "config_url": "https://huggingface.co/becruily/mel-band-roformer-karaoke/resolve/main/config_karaoke_becruily.yaml?download=true" + }, + + "Mel-Band-Roformer_Vocals_becruily": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Vocals by Becruily", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/becruily/mel-band-roformer-vocals/resolve/main/mel_band_roformer_vocals_becruily.ckpt?download=true", + "config_url": "https://huggingface.co/becruily/mel-band-roformer-vocals/resolve/main/config_vocals_becruily.yaml?download=true" + }, + + "Mel-Band-Roformer_SYHFT_v1_syh99999": { + "category": "Вокал", + "full_name": "Mel-Band Roformer SYHFT v1 by SYH99999", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFT/resolve/main/MelBandRoformerSYHFT.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFT/resolve/main/config_vocals_mel_band_roformer_ft.yaml?download=true" + }, + + "Mel-Band-Roformer_SYHFT_v2_syh99999": { + "category": "Вокал", + "full_name": "Mel-Band Roformer SYHFT v2 by SYH99999", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": 
"https://huggingface.co/SYH99999/MelBandRoformerSYHFTV2/resolve/main/MelBandRoformerSYHFTV2.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFT/resolve/main/config_vocals_mel_band_roformer_ft.yaml?download=true" + }, + + "Mel-Band-Roformer_SYHFT_v2.5_syh99999": { + "category": "Вокал", + "full_name": "Mel-Band Roformer SYHFT v2.5 by SYH99999", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTV2.5/resolve/main/MelBandRoformerSYHFTV2.5.ckpt/MelBandRoformerSYHFTV2.5.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFT/resolve/main/config_vocals_mel_band_roformer_ft.yaml?download=true" + }, + + "Mel-Band-Roformer_SYHFT_v3_syh99999": { + "category": "Вокал", + "full_name": "Mel-Band Roformer SYHFT v3 by SYH99999", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTV3Epsilon/resolve/main/MelBandRoformerSYHFTV3Epsilon.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFT/resolve/main/config_vocals_mel_band_roformer_ft.yaml?download=true" + }, + + "Mel-Band-Roformer_BIG_SYHFT_v1_Fast_syh99999": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Big SYHFT v1 Fast by SYH99999", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerBigSYHFTV1Fast/resolve/main/MelBandRoformerBigSYHFTV1.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformerBigSYHFTV1Fast/resolve/main/config.yaml?download=true" + }, + + "Mel-Band-Roformer_Merged_SYHFT_Beta_v1_syh99999": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Merged Beta v1 by SYH99999", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerMergedSYHFTBeta1/resolve/main/merge_syhft.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFT/resolve/main/config_vocals_mel_band_roformer_ft.yaml?download=true" + }, + + "Mel-Band-Roformer_SYHFT_B1_model1_syh99999": { + "category": "Вокал", + "full_name": "Mel-Band Roformer SYHFT B1 1 by SYH99999", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/model.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/config.yaml?download=true" + }, + + "Mel-Band-Roformer_SYHFT_B1_model2_syh99999": { + "category": "Вокал", + "full_name": "Mel-Band Roformer SYHFT B1 2 by SYH99999", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/model2.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/config.yaml?download=true" + }, + + "Mel-Band-Roformer_SYHFT_B1_model3_syh99999": { + "category": "Вокал", + "full_name": "Mel-Band Roformer SYHFT B1 3 by SYH99999", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/model3.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/config.yaml?download=true" + }, + + "Mel-Band-Roformer_4_stems_FT_Large_v1_syh99999": { + 
"category": "4 стема", + "full_name": "Mel-Band Roformer 4 Stems FT Large v1 by SYH99999", + "stems": ["vocals", "drums", "bass", "other"], + "target_instrument": None, + "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformer4StemFTLarge/resolve/main/MelBandRoformer4StemFTLarge.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformer4StemFTLarge/resolve/main/config.yaml?download=true" + }, + + "Mel-Band-Roformer_4_stems_FT_Large_v2_syh99999": { + "category": "4 стема", + "full_name": "Mel-Band Roformer 4 Stems FT Large v2 by SYH99999", + "stems": ["vocals", "drums", "bass", "other"], + "target_instrument": None, + "checkpoint_url": "https://huggingface.co/SYH99999/MelBandRoformer4StemFTLarge/resolve/main/ver2.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformer4StemFTLarge/resolve/main/config.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_1652_essid": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental by Essid (sdr 16.52)", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/Essid/Essid-MelBandRoformer/resolve/3960860f7895c87a12707ca6b378df2b3c68e2c0/model_mel_band_roformer_ep_17_sdr_16.5244.ckpt?download=true", + "config_url": "https://huggingface.co/Essid/Essid-MelBandRoformer/resolve/4768859bd59bc699d33f4567e82082993dde7eb9/config_vocals_mel_band_roformer_essid.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_1681_essid": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental by Essid (sdr 16.81)", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/Essid/Essid-MelBandRoformer/resolve/main/essid_mel_inst_old.ckpt?download=true", + "config_url": "https://huggingface.co/Essid/Essid-MelBandRoformer/resolve/4768859bd59bc699d33f4567e82082993dde7eb9/config_vocals_mel_band_roformer_essid.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_Fv1_gabox": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental Fv1 by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxFv1.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_Fv2_gabox": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental Fv2 by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxFv2.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_Fv3_gabox": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental Fv3 by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxFv3.ckpt?download=true", + "config_url": 
"https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_Fv4N_gabox": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental Fv4 Noise by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_Fv4Noise.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_Fv5_gabox": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental Fv5 by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/INSTV5.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_Fv5N_gabox": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental Fv5 Noise by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/INSTV5N.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_Fv6_gabox": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental Fv6 by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/INSTV6.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_Fv6N_gabox": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental Fv6 Noise by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/INSTV6N.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_Fv7_gabox": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental Fv7 by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/Inst_GaboxV7.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_Fv7N_gabox": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental Fv7 Noise by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Instrumental", + "checkpoint_url": 
"https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/INSTV7N.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_Fv7_plus_gabox": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental Fv7+ by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/5fba9605d4b6bc1a31c04c50d08d757c5107d23f/melbandroformers/experimental/instv7plus.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_Fv7z_gabox": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental Fv7z by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/Inst_GaboxFv7z.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + + + "Mel-Band-Roformer_Instrumental_Fv8_gabox": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental Fv8 by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/Inst_GaboxFv8.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_Fv8b_gabox": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental Fv8b by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/experimental/Inst_FV8b.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + + + "Mel-Band-Roformer_Instrumental_Fv9_gabox": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental Fv9 by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/experimental/Inst_Fv9.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_Fv10_gabox": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental Fv10 by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/experimental/INSTV10.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_FvX_gabox": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental FvX by GaboxR67", + "stems": 
["Vocals", "Instrumental"], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/Inst_GaboxFVX.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_Bv1_gabox": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental Bv1 by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxBv1.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_Bv2_gabox": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental Bv2 by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxBv2.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_Bv3_gabox": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental Bv3 by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxBv3.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_small_gabox": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Instrumental Small by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/experimental/small_inst.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + + + "Mel-Band-Roformer_Vocals_Fv1_gabox": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Vocals Fv1 by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Vocals", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gaboxFv1.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.yaml?download=true" + }, + + "Mel-Band-Roformer_Vocals_Fv2_gabox": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Vocals Fv2 by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Vocals", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gaboxFv2.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.yaml?download=true" + }, + + "Mel-Band-Roformer_Vocals_Fv3_gabox": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Vocals Fv3 by GaboxR67", + "stems": ["Vocals", "Instrumental"], + 
"target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_Fv3.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.yaml?download=true" + }, + + "Mel-Band-Roformer_Vocals_Fv4_gabox": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Vocals Fv4 by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Vocals", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_fv4.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.yaml?download=true" + }, + + "Mel-Band-Roformer_Vocals_Fv5_gabox": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Vocals Fv5 by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Vocals", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_fv5.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.yaml?download=true" + }, + + "Mel-Band-Roformer_Karaoke_25_02_2025_gabox": { + "category": "Караоке", + "full_name": "Mel-Band Roformer Karaoke 25-02-2025 by GaboxR67", + "stems": ["karaoke", "other"], + "target_instrument": "karaoke", + "checkpoint_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/gabox_karaoke_25_02_2025.ckpt?download=true", + "config_url": "https://huggingface.co/jarredou/aufr33-viperx-karaoke-melroformer-model/resolve/main/config_mel_band_roformer_karaoke.yaml?download=true" + }, + + "Mel-Band-Roformer_Karaoke_28_02_2025_gabox": { + "category": "Караоке", + "full_name": "Mel-Band Roformer Karaoke 28-02-2025 by GaboxR67", + "stems": ["karaoke", "other"], + "target_instrument": "karaoke", + "checkpoint_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/gabox_karaoke_28_02_2025.ckpt?download=true", + "config_url": "https://huggingface.co/jarredou/aufr33-viperx-karaoke-melroformer-model/resolve/main/config_mel_band_roformer_karaoke.yaml?download=true" + }, + + "Mel-Band-Roformer_Karaoke_v1_gabox": { + "category": "Караоке", + "full_name": "Mel-Band Roformer Karaoke v1 by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Vocals", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/27e73ca2beec0ab7daa46e366159753a166612e1/melbandroformers/karaoke/Karaoke_GaboxV1.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/27e73ca2beec0ab7daa46e366159753a166612e1/melbandroformers/karaoke/karaokegabox_1750911344.yaml?download=true" + }, + + "Mel-Band-Roformer_Denoise_DeBleed_gabox": { + "category": "Шум", + "full_name": "Mel-Band Roformer Denoise DeBleed by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Instrumental", + "checkpoint_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/denoisedebleed.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml?download=true" + }, + + "Mel-Band-Roformer_Karaoke_Fusion_gonzaluigi": { + "category": "Караоке", + "full_name": "Mel-Band Roformer Karaoke Fusion by Gonzaluigi", + "stems": ["Vocals", "Instrumental"], + 
"target_instrument": "Vocals", + "checkpoint_url": "https://huggingface.co/Gonzaluigi/Mel-Band-Karaoke-Fusion/resolve/main/mel_band_karaoke_fusion_standard.ckpt?download=true", + "config_url": "https://huggingface.co/Gonzaluigi/Mel-Band-Karaoke-Fusion/resolve/main/melband_karaokefusion_gonza.yaml?download=true" + }, + + "Mel-Band-Roformer_Karaoke_Fusion_Aggr_gonzaluigi": { + "category": "Караоке", + "full_name": "Mel-Band Roformer Karaoke Fusion Aggressive by Gonzaluigi", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Vocals", + "checkpoint_url": "https://huggingface.co/Gonzaluigi/Mel-Band-Karaoke-Fusion/resolve/main/mel_band_karaoke_fusion_aggressive.ckpt?download=true", + "config_url": "https://huggingface.co/Gonzaluigi/Mel-Band-Karaoke-Fusion/resolve/main/melband_karaokefusion_gonza.yaml?download=true" + }, + + + + + + + + "Mel-Band-Roformer_DeReverb_anvuew": { + "category": "Реверб", + "full_name": "Mel-Band Roformer DeReverb by Anvuew", + "stems": ["reverb", "noreverb"], + "target_instrument": "noreverb", + "checkpoint_url": "https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt?download=true", + "config_url": "https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew.yaml?download=true" + }, + + "Mel-Band-Roformer_DeReverb_Less_Aggr_anvuew": { + "category": "Реверб", + "full_name": "Mel-Band Roformer DeReverb Less Aggressive by Anvuew", + "stems": ["reverb", "noreverb"], + "target_instrument": "noreverb", + "checkpoint_url": "https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_less_aggressive_anvuew_sdr_18.8050.ckpt?download=true", + "config_url": "https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew.yaml?download=true" + }, + + "Mel-Band-Roformer_DeReverb_Mono_anvuew": { + "category": "Реверб", + "full_name": "Mel-Band Roformer DeReverb Mono by Anvuew", + "stems": ["reverb", "noreverb"], + "target_instrument": "noreverb", + "checkpoint_url": "https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_mono_anvuew_sdr_20.4029.ckpt?download=true", + "config_url": "https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew.yaml?download=true" + }, + + "Mel-Band-Roformer_Aspiration_sucial": { + "category": "Дыхание", + "full_name": "Mel-Band Roformer Aspiration by Sucial", + "stems": ["aspiration", "other"], + "target_instrument": None, + "checkpoint_url": "https://huggingface.co/Sucial/Aspiration_Mel_Band_Roformer/resolve/main/aspiration_mel_band_roformer_less_aggr_sdr_18.1201.ckpt?download=true", + "config_url": "https://huggingface.co/Sucial/Aspiration_Mel_Band_Roformer/resolve/main/config_aspiration_mel_band_roformer.yaml?download=true" + }, + + "Mel-Band-Roformer_DeReverb-Echo_v1_sucial": { + "category": "Реверб и эхо", + "full_name": "Mel-Band Roformer DeReverb-Echo by Sucial", + "stems": ["dry", "other"], + "target_instrument": None, + "checkpoint_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/dereverb-echo_mel_band_roformer_sdr_10.0169.ckpt?download=true", + "config_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/config_dereverb-echo_mel_band_roformer.yaml?download=true" + }, + + "Mel-Band-Roformer_DeBigReverb_sucial": { + "category": "Реверб", + "full_name": "Mel-Band Roformer DeBigReverb by Sucial", + 
"stems": ["dry", "other"], + "target_instrument": "dry", + "checkpoint_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/de_big_reverb_mbr_ep_362.ckpt?download=true", + "config_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/config_dereverb_echo_mbr_v2.yaml?download=true" + }, + + "Mel-Band-Roformer_DeSuperBigReverb_sucial": { + "category": "Реверб", + "full_name": "Mel-Band Roformer Super Big DeReverb by Sucial", + "stems": ["dry", "other"], + "target_instrument": "dry", + "checkpoint_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/de_super_big_reverb_mbr_ep_346.ckpt?download=true", + "config_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/config_dereverb_echo_mbr_v2.yaml?download=true" + }, + + "Mel-Band-Roformer_DeReverb-Echo_Fused_sucial": { + "category": "Реверб и эхо", + "full_name": "Mel-Band Roformer DeReverb-Echo Fused by Sucial", + "stems": ["dry", "other"], + "target_instrument": "dry", + "checkpoint_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/dereverb_echo_mbr_fused_0.5_v2_0.25_big_0.25_super.ckpt?download=true", + "config_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/config_dereverb_echo_mbr_v2.yaml?download=true" + }, + + "Mel-Band-Roformer_DeReverb-Echo_v2_sucial": { + "category": "Реверб и эхо", + "full_name": "Mel-Band Roformer DeReverb-Echo v2 by Sucial", + "stems": ["dry", "other"], + "target_instrument": "dry", + "checkpoint_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/dereverb_echo_mbr_v2_sdr_dry_13.4843.ckpt?download=true", + "config_url": "https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/config_dereverb_echo_mbr_v2.yaml?download=true" + }, + + "Mel-Band-Roformer_Karaoke_aufr33_viperx": { + "category": "Караоке", + "full_name": "Mel-Band Roformer Karaoke by Aufr33 & ViperX", + "stems": ["karaoke", "other"], + "target_instrument": "karaoke", + "checkpoint_url": "https://huggingface.co/jarredou/aufr33-viperx-karaoke-melroformer-model/resolve/main/mel_band_roformer_karaoke_aufr33_viperx_sdr_10.1956.ckpt?download=true", + "config_url": "https://huggingface.co/jarredou/aufr33-viperx-karaoke-melroformer-model/resolve/main/config_mel_band_roformer_karaoke.yaml?download=true" + }, + + "Mel-Band-Roformer_DeNoise_aufr33": { + "category": "Шум", + "full_name": "Mel-Band Roformer DeNoise by Aufr33", + "stems": ["dry", "other"], + "target_instrument": "dry", + "checkpoint_url": "https://huggingface.co/jarredou/aufr33_MelBand_Denoise/resolve/main/denoise_mel_band_roformer_aufr33_aggr_sdr_27.9768.ckpt?download=true", + "config_url": "https://huggingface.co/jarredou/aufr33_MelBand_Denoise/resolve/main/model_mel_band_roformer_denoise.yaml?download=true" + }, + + "Mel-Band-Roformer_Denoise_Aggr_aufr33": { + "category": "Шум", + "full_name": "Mel-Band Roformer DeNoise Aggressive by Aufr33", + "stems": ["dry", "other"], + "target_instrument": "dry", + "checkpoint_url": "https://huggingface.co/jarredou/aufr33_MelBand_Denoise/resolve/main/denoise_mel_band_roformer_aufr33_aggr_sdr_27.9768.ckpt?download=true", + "config_url": "https://huggingface.co/jarredou/aufr33_MelBand_Denoise/resolve/main/model_mel_band_roformer_denoise.yaml?download=true" + }, + + "Mel-Band-Roformer_Crowd_aufr33_viperx": { + "category": "Звуки толпы", + "full_name": "Mel-Band Roformer Crowd by Aufr33 & ViperX", + "stems": ["crowd", "other"], + 
"target_instrument": "crowd", + "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.4/mel_band_roformer_crowd_aufr33_viperx_sdr_8.7144.ckpt", + "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.4/model_mel_band_roformer_crowd.yaml" + }, + + "Mel-Band-Roformer_Vocals_viperx": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Vocals by ViperX", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/viperx/model_mel_band_roformer_ep_3005_sdr_11.4360.yaml" + }, + + "Mel-Band-Roformer_Vocals_Fullness_aname": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Vocals Fullness by Aname", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Aname-Tommy/MelBandRoformers/resolve/main/FullnessVocalModel.ckpt?download=true", + "config_url": "https://huggingface.co/Aname-Tommy/MelBandRoformers/resolve/main/config.yaml?download=true" + }, + + "Mel-Band-Roformer_Kim_FT_v1_aname": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Kim FT v1 by Aname", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Aname-Tommy/Test/resolve/main/model_kim.ckpt?download=true", + "config_url": "https://huggingface.co/Aname-Tommy/MelBandRoformers/resolve/main/config.yaml?download=true" + }, + + "Mel-Band-Roformer_Kim_FT_v2_aname": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Kim FT v2 by Aname", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Aname-Tommy/Test/resolve/main/model_kim_2.ckpt?download=true", + "config_url": "https://huggingface.co/Aname-Tommy/MelBandRoformers/resolve/main/config.yaml?download=true" + }, + + "Mel-Band-Roformer_Kim_FT_v2_Fullness_aname": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Kim FT v2 Fullness by Aname", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Aname-Tommy/Test/resolve/main/model_kim_2_fullness.ckpt?download=true", + "config_url": "https://huggingface.co/Aname-Tommy/MelBandRoformers/resolve/main/config.yaml?download=true" + }, + + "Mel-Band-Roformer_Kim_FT_v3_aname": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Kim FT v3 by Aname", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Aname-Tommy/Test/resolve/main/model_kim_3.ckpt?download=true", + "config_url": "https://huggingface.co/Aname-Tommy/MelBandRoformers/resolve/main/config.yaml?download=true" + }, + + "Mel-Band-Roformer_kapm_aname": { + "category": "Вокал", + "full_name": "Mel-Band Roformer kapm by Aname", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/Aname-Tommy/Test/resolve/main/kapm.ckpt?download=true", + "config_url": "https://huggingface.co/Aname-Tommy/Test/resolve/main/kapm.yaml?download=true" + }, + + "Mel-Band-Roformer_Small_aname": { + "category": "Вокал", + "full_name": "Mel-Band Roformer Small by Aname", + "stems": ["Vocals", "Instrumental"], + "target_instrument": None, + "checkpoint_url": 
"https://huggingface.co/Aname-Tommy/Mel_Band_Roformer_small/resolve/main/mel_band_roformer_small.ckpt?download=true", + "config_url": "https://huggingface.co/Aname-Tommy/Mel_Band_Roformer_small/resolve/main/config.yaml?download=true" + }, + + "Mel-Band-Roformer_4_stems_Large_aname": { + "category": "4 стема", + "full_name": "Mel-Band Roformer 4 Stems Large by Aname", + "stems": ["vocals", "drums", "bass", "other"], + "target_instrument": None, + "checkpoint_url": "https://huggingface.co/Aname-Tommy/melbandroformer4stems/resolve/main/mel_band_roformer_4stems_large_ver1.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformer4StemFTLarge/resolve/main/config.yaml?download=true" + }, + + "Mel-Band-Roformer_4_stems_v2_Large_aname": { + "category": "4 стема", + "full_name": "Mel-Band Roformer 4 Stems v2 Large by Aname", + "stems": ["vocals", "drums", "bass", "other"], + "target_instrument": None, + "checkpoint_url": "https://huggingface.co/Aname-Tommy/Test/resolve/main/4stemsver2.ckpt?download=true", + "config_url": "https://huggingface.co/SYH99999/MelBandRoformer4StemFTLarge/resolve/main/config.yaml?download=true" + }, + + "Mel-Band-Roformer_4_stems_XL_aname": { + "category": "4 стема", + "full_name": "Mel-Band Roformer 4 Stems XL by Aname", + "stems": ["vocals", "drums", "bass", "other"], + "target_instrument": None, + "checkpoint_url": "https://huggingface.co/Aname-Tommy/melbandroformer4stems/resolve/main/mel_band_roformer_4stems_xl_ver1.ckpt?download=true", + "config_url": "https://huggingface.co/Aname-Tommy/melbandroformer4stems/resolve/main/config_xl.yaml?download=true" + }, + + "Mel-Band-Roformer_Drums_yolkispaliks": { + "category": "Ударные", + "full_name": "Mel-Band Roformer Drums Experimental by yolkispalkis", + "stems": ["percussions", "other"], + "target_instrument": "percussions", + "checkpoint_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/model_mel_band_roformer_ep_11_sdr_7.6853.ckpt?download=true", + "config_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/config_drums_musdb18_moises_mel_band_roformer.yaml?download=true" + }, + + "Mel-Band-Roformer_Instrumental_Metal_Preview_meskvlla33": { + "category": "Инструментал", + "full_name": "Mel-Band Roformer Metal Inst Preview by Mesk", + "stems": ["vocals", "other"], + "target_instrument": "other", + "checkpoint_url": "https://huggingface.co/meskvlla33/metal_roformer_preview/resolve/main/metal_roformer_inst_mesk_preview.ckpt?download=true", + "config_url": "https://huggingface.co/meskvlla33/metal_roformer_preview/resolve/main/config_inst_metal_roformer_mesk.yaml?download=true" + } + + }, + + "bs_roformer": { + + "BS-Roformer_Drums_beatloo_labs": { + "category": "Ударные", + "full_name": "BS Roformer Drums Experimental by BeatLoo Labs", + "stems": ["vocals", "drums", "bass", "other"], + "target_instrument": "drums", + "checkpoint_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/model_drums_bs_roformer_ep_12_sdr_7.2279.ckpt?download=true", + "config_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/config_4_drums_bs_roformer.yaml?download=true" + }, + + "BS-Roformer_Bass_beatloo_labs": { + "category": "Басс", + "full_name": "BS Roformer Bass Experimental by BeatLoo Labs", + "stems": ["vocals", "drums", "bass", "other"], + "target_instrument": "bass", + "checkpoint_url": 
"https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/model_bass_bs_roformer_ep_10_sdr_5.7862.ckpt?download=true", + "config_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/config_4_bass_bs_roformer.yaml?download=true" + }, + + "BS-Roformer_Vocals_1296_viperx": { + "category": "Вокал", + "full_name": "BS Roformer Vocals (sdr 12.96) by ViperX", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "Vocals", + "checkpoint_url": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/model_bs_roformer_ep_368_sdr_12.9628.ckpt", + "config_url": "https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/mdx_c_configs/model_bs_roformer_ep_368_sdr_12.9628.yaml" + }, + + "BS-Roformer_Other_viperx": { + "category": "Прочее", + "full_name": "BS Roformer Other by ViperX", + "stems": ["vocals", "other"], + "target_instrument": "other", + "checkpoint_url": "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/model_bs_roformer_ep_937_sdr_10.5309.ckpt", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/viperx/model_bs_roformer_ep_937_sdr_10.5309.yaml" + }, + + "BS-Roformer_Revive_v1_unwa": { + "category": "Вокал", + "full_name": "BS Roformer Vocals Revive v1 by Unwa", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/BS-Roformer-Revive/resolve/main/bs_roformer_revive.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/BS-Roformer-Revive/resolve/main/config.yaml?download=true" + }, + + "BS-Roformer_Revive_v2_unwa": { + "category": "Вокал", + "full_name": "BS Roformer Vocals Revive v2 by Unwa", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/BS-Roformer-Revive/resolve/main/bs_roformer_revive2.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/BS-Roformer-Revive/resolve/main/config.yaml?download=true" + }, + + "BS-Roformer_Revive_v3e_unwa": { + "category": "Вокал", + "full_name": "BS Roformer Vocals Revive v3e by Unwa", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/BS-Roformer-Revive/resolve/main/bs_roformer_revive3e.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/BS-Roformer-Revive/resolve/main/config.yaml?download=true" + }, + + + "BS-Roformer_Resurrection_unwa": { + "category": "Вокал", + "full_name": "BS Roformer Vocals Resurrection by Unwa", + "stems": ["vocals", "other"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/pcunwa/BS-Roformer-Resurrection/resolve/main/BS-Roformer-Resurrection.ckpt?download=true", + "config_url": "https://huggingface.co/pcunwa/BS-Roformer-Resurrection/resolve/main/BS-Roformer-Resurrection-Config.yaml?download=true" + }, + + + "BS-Roformer_VocTest_gabox": { + "category": "Вокал", + "full_name": "BS Roformer Vocals by GaboxR67", + "stems": ["Vocals", "Instrumental"], + "target_instrument": "vocals", + "checkpoint_url": "https://huggingface.co/GaboxR67/BSRoformerVocTest/resolve/main/voc_gaboxBSR.ckpt?download=true", + "config_url": "https://huggingface.co/GaboxR67/BSRoformerVocTest/resolve/main/voc_gaboxBSroformer.yaml?download=true" + }, + + "BS-Roformer_SW": { + "category": "6 стемов", + "full_name": "BS Roformer SW", + "stems": ["bass", "drums", "other", "piano", "guitar", "vocals"], + "target_instrument": None, + 
"checkpoint_url": "https://github.com/undef13/splifft/releases/download/v0.0.1/roformer-fp16.pt", + "config_url": "https://huggingface.co/noblebarkrr/all_models_for_mel_band_roformer/resolve/main/BS-Roformer_SW_config.yaml?download=true" + }, + + "BS-Roformer_4_stems_zfturbo": { + "category": "4 стема", + "full_name": "BS Roformer 4 Stems by ZFTurbo", + "stems": ["vocals", "drums", "bass", "other"], + "target_instrument": None, + "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.12/model_bs_roformer_ep_17_sdr_9.6568.ckpt", + "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.12/config_bs_roformer_384_8_2_485100.yaml" + }, + + "BS-Roformer_4_stems_FT_syh99999": { + "category": "4 стема", + "full_name": "BS Roformer 4 Stems FT by SYH99999", + "stems": ["vocals", "drums", "bass", "other"], + "target_instrument": None, + "checkpoint_url": "https://huggingface.co/SYH99999/bs_roformer_4stems_ft/resolve/main/bs_roformer_4stems_ft.pth?download=true", + "config_url": "https://huggingface.co/SYH99999/bs_roformer_4stems_ft/resolve/main/config.yaml?download=true" + }, + + "BS-Roformer_Chorus_Male-Female_146_sucial": { + "category": "Мужской/Женский вокал", + "full_name": "BS Roformer Male-Female (ep 146) by Sucial", + "stems": ["male", "female"], + "target_instrument": None, + "checkpoint_url": "https://huggingface.co/Sucial/Chorus_Male_Female_BS_Roformer/resolve/main/model_chorus_bs_roformer_ep_146_sdr_23.8613.ckpt?download=true", + "config_url": "https://huggingface.co/Sucial/Chorus_Male_Female_BS_Roformer/resolve/main/config_chorus_male_female_bs_roformer.yaml?download=true" + }, + + "BS-Roformer_Chorus_Male-Female_267_sucial": { + "category": "Мужской/Женский вокал", + "full_name": "BS Roformer Male-Female (ep 267) by Sucial", + "stems": ["male", "female"], + "target_instrument": None, + "checkpoint_url": "https://huggingface.co/Sucial/Chorus_Male_Female_BS_Roformer/resolve/main/model_chorus_bs_roformer_ep_267_sdr_24.1275.ckpt?download=true", + "config_url": "https://huggingface.co/Sucial/Chorus_Male_Female_BS_Roformer/resolve/main/config_chorus_male_female_bs_roformer.yaml?download=true" + }, + + "BS-Roformer_Male-Female_aufr33": { + "category": "Мужской/Женский вокал", + "full_name": "BS Roformer Male-Female by Aufr33", + "stems": ["male", "female"], + "target_instrument": None, + "checkpoint_url": "https://huggingface.co/RareSirMix/AIModelRehosting/resolve/main/bs_roformer_male_female_by_aufr33_sdr_7.2889.ckpt", + "config_url": "https://huggingface.co/Sucial/Chorus_Male_Female_BS_Roformer/resolve/main/config_chorus_male_female_bs_roformer.yaml" + }, + + "BS-Roformer_Deverb_256_8_anvuew": { + "category": "Реверб", + "full_name": "BS Roformer Deverb 256-8 by Anvuew", + "stems": ["reverb", "noreverb"], + "target_instrument": "noreverb", + "checkpoint_url": "https://huggingface.co/anvuew/deverb_bs_roformer/resolve/main/deverb_bs_roformer_8_256dim_8depth.ckpt?download=true", + "config_url": "https://huggingface.co/anvuew/deverb_bs_roformer/resolve/main/deverb_bs_roformer_8_256dim_8depth.yaml?download=true" + }, + + "BS-Roformer_Deverb_384_10_anvuew": { + "category": "Реверб", + "full_name": "BS Roformer Deverb 384-10 by Anvuew", + "stems": ["reverb", "noreverb"], + "target_instrument": "noreverb", + "checkpoint_url": "https://huggingface.co/anvuew/deverb_bs_roformer/resolve/main/deverb_bs_roformer_8_384dim_10depth.ckpt?download=true", + "config_url": 
"https://huggingface.co/anvuew/deverb_bs_roformer/resolve/main/deverb_bs_roformer_8_384dim_10depth.yaml?download=true" + }, + + "BS-Roformer_4_stems_aname": { + "category": "4 стема", + "full_name": "BS Roformer 4 stems by Aname", + "stems": ["vocals", "drums", "bass", "other"], + "target_instrument": None, + "checkpoint_url": "https://huggingface.co/RareSirMix/AIModelRehosting/resolve/main/Amane4stem_bs_roformer.ckpt", + "config_url": "https://huggingface.co/RareSirMix/AIModelRehosting/resolve/main/Amane4stem_bs_roformer.yaml" + } + }, + + "mdx23c": { + + "MDX23C_InstVoc_HQ_zfturbo": { + "category": "Инструментал и вокал", + "full_name": "MDX23C Inst-Voc HQ by ZFTurbo", + "stems": ["vocals", "other"], + "target_instrument": None, + "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.0/model_vocals_mdx23c_sdr_10.17.ckpt", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_vocals_mdx23c.yaml" + }, + + "MDX23C_8kFFT_InstVoc_HQ_v1": { + "category": "Инструментал и вокал", + "full_name": "MDX23C 8k FFT Inst-Voc HQ v1", + "stems": ["Vocals", "Instrumental"], + "target_instrument": None, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDX23C/MDX23C-8KFFT-InstVoc_HQ.ckpt?download=true", + "config_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDX23C/model_2_stem_full_band_8k.yaml?download=true" + }, + + "MDX23C_8kFFT_InstVoc_HQ_v2": { + "category": "Инструментал и вокал", + "full_name": "MDX23C 8k FFT Inst-Voc HQ v2", + "stems": ["Vocals", "Instrumental"], + "target_instrument": None, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDX23C/MDX23C-8KFFT-InstVoc_HQ_2.ckpt?download=true", + "config_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDX23C/model_2_stem_full_band_8k.yaml?download=true" + }, + + "MDX23C_D1581": { + "category": "Инструментал и вокал", + "full_name": "MDX23C D1581", + "stems": ["Vocals", "Instrumental"], + "target_instrument": None, + "checkpoint_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDX23C/MDX23C_D1581.ckpt?download=true", + "config_url": "https://huggingface.co/Politrees/UVR_resources/resolve/main/models/MDX23C/model_2_stem_061321.yaml?download=true" + }, + + + "MDX23C_DrumSep_aufr33_jarredou": { + "category": "DrumSep", + "full_name": "MDX23C DrumSep by Aufr33 & Jarredou", + "stems": ["kick", "snare", "toms", "hh", "ride", "crash"], + "target_instrument": None, + "checkpoint_url": "https://github.com/jarredou/models/releases/download/aufr33-jarredou_MDX23C_DrumSep_model_v0.1/aufr33-jarredou_DrumSep_model_mdx23c_ep_141_sdr_10.8059.ckpt", + "config_url": "https://github.com/jarredou/models/releases/download/aufr33-jarredou_MDX23C_DrumSep_model_v0.1/aufr33-jarredou_DrumSep_model_mdx23c_ep_141_sdr_10.8059.yaml" + }, + + "MDX23C_DeReverb_aufr33_jarredou": { + "category": "Реверб", + "full_name": "MDX23C DeReverb by Aufr33 & Jarredou", + "stems": ["dry", "other"], + "target_instrument": "dry", + "checkpoint_url": "https://huggingface.co/jarredou/aufr33_jarredou_MDXv3_DeReverb/resolve/main/dereverb_mdx23c_sdr_6.9096.ckpt", + "config_url": "https://huggingface.co/jarredou/aufr33_jarredou_MDXv3_DeReverb/resolve/main/config_dereverb_mdx23c.yaml" + }, + + "MDX23C_Mid_Side_wesleyr36": { + "category": "Фантомный центр", + "full_name": "MDX23C Mid-Side by WesleyR36", + "stems": ["similarity", "difference"], + 
"target_instrument": "similarity", + "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.10/model_mdx23c_ep_271_l1_freq_72.2383.ckpt", + "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.10/config_mdx23c_similarity.yaml" + }, + + "MDX23C_4_stems_zfturbo": { + "category": "4 стема", + "full_name": "MDX23C 4 Stems by ZFTurbo", + "stems": ["vocals", "drums", "bass", "other"], + "target_instrument": None, + "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.1/model_mdx23c_ep_168_sdr_7.0207.ckpt", + "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.1/config_musdb18_mdx23c.yaml" + }, + + "MDX23C_Orchestra_verosment": { + "category": "Оркестр", + "full_name": "MDX23C Orchestra Experimental by Verosment", + "stems": ["inst", "orch"], + "target_instrument": "orch", + "checkpoint_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/model_mdx23c_ep_120_sdr_4.4174.ckpt?download=true", + "config_url": "https://huggingface.co/am2460162/msst_failed_failed_test/resolve/main/config_orchestra_mdx23c.yaml?download=true" + } + + }, + + "scnet": { + + "SCNet_4_stems_zfturbo": { + "category": "4 стема", + "full_name": "SCNet 4 Stems by ZFTurbo", + "stems": ["vocals", "drums", "bass", "other"], + "target_instrument": None, + "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.9/SCNet-large_starrytong_fixed.ckpt", + "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.9/config_musdb18_scnet_large_starrytong.yaml" + }, + + "SCNet_XL_IHF_4_stems_zfturbo": { + "category": "4 стема", + "full_name": "SCNet XL IHF 4 Stems by ZFTurbo", + "stems": ["vocals", "drums", "bass", "other"], + "target_instrument": None, + "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.15/model_scnet_ep_36_sdr_10.0891.ckpt", + "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.15/config_musdb18_scnet_xl_more_wide_v5.yaml" + }, + + + "SCNet_XL_4_stems_starrytong": { + "category": "4 стема", + "full_name": "SCNet 4 Stems XL by StarryTong", + "stems": ["vocals", "drums", "bass", "other"], + "target_instrument": None, + "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.13/model_scnet_ep_54_sdr_9.8051.ckpt", + "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.13/config_musdb18_scnet_xl.yaml" + }, + + "SCNet_XL_4_stems_zftrubo": { + "category": "4 стема", + "full_name": "SCNet 4 Stems XL by ZFTurbo", + "stems": ["vocals", "drums", "bass", "other"], + "target_instrument": None, + "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.6/scnet_checkpoint_musdb18.ckpt", + "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.6/config_musdb18_scnet.yaml" + }, + + "SCNet_Large_Jazz_4_stems_jorisvaneyghen": { + "category": "4 стема", + "full_name": "SCNet Large Jazz model by Joris Vaneyghen", + "stems": ["piano", "drums", "bass", "other"], + "target_instrument": None, + "checkpoint_url": "https://huggingface.co/jorisvaneyghen/SCNet/resolve/main/model_jazz_scnet_large.ckpt?download=true", + "config_url": 
"https://huggingface.co/spaces/jorisvaneyghen/jazz_playalong/resolve/main/configs/config_jazz_scnet_large.yaml?download=true" + }, + + "SCNet_XL_Jazz_4_stems_jorisvaneyghen": { + "category": "4 стема", + "full_name": "SCNet XL Jazz model by Joris Vaneyghen", + "stems": ["piano", "drums", "bass", "other"], + "target_instrument": None, + "checkpoint_url": "https://huggingface.co/jorisvaneyghen/SCNet/resolve/main/model_jazz_scnet_xl.ckpt?download=true", + "config_url": "https://huggingface.co/spaces/jorisvaneyghen/jazz_playalong/resolve/main/configs/config_jazz_scnet_xl.yaml?download=true" + } + + }, + + "vr": { + + "1_HP-UVR": { + "category": "Инструментал", + "full_name": "VR Arch Single Model v5: 1_HP-UVR", + "stems": ["Instrumental", "Vocals"], + "custom_vr": False, + "target_instrument": None + }, + + "2_HP-UVR": { + "category": "Инструментал", + "full_name": "VR Arch Single Model v5: 2_HP-UVR", + "stems": ["Instrumental", "Vocals"], + "custom_vr": False, + "target_instrument": None + }, + + "3_HP-Vocal-UVR": { + "category": "Вокал", + "full_name": "VR Arch Single Model v5: 3_HP-Vocal-UVR", + "stems": ["Vocals", "Instrumental"], + "custom_vr": False, + "target_instrument": None + }, + + "4_HP-Vocal-UVR": { + "category": "Вокал", + "full_name": "VR Arch Single Model v5: 4_HP-Vocal-UVR", + "stems": ["Vocals", "Instrumental"], + "custom_vr": False, + "target_instrument": None + }, + + "5_HP-Karaoke-UVR": { + "category": "Караоке", + "full_name": "VR Arch Single Model v5: 5_HP-Karaoke-UVR", + "stems": ["Instrumental", "Vocals"], + "custom_vr": False, + "target_instrument": None + }, + + "6_HP-Karaoke-UVR": { + "category": "Караоке", + "full_name": "VR Arch Single Model v5: 6_HP-Karaoke-UVR", + "stems": ["Instrumental", "Vocals"], + "custom_vr": False, + "target_instrument": None + }, + + "7_HP2-UVR": { + "category": "Инструментал и вокал", + "full_name": "VR Arch Single Model v5: 7_HP2-UVR", + "stems": ["Instrumental", "Vocals"], + "custom_vr": False, + "target_instrument": None + }, + + "8_HP2-UVR": { + "category": "Инструментал и вокал", + "full_name": "VR Arch Single Model v5: 8_HP2-UVR", + "stems": ["Instrumental", "Vocals"], + "custom_vr": False, + "target_instrument": None + }, + + "9_HP2-UVR": { + "category": "Инструментал и вокал", + "full_name": "VR Arch Single Model v5: 9_HP2-UVR", + "stems": ["Instrumental", "Vocals"], + "custom_vr": False, + "target_instrument": None + }, + + "10_SP-UVR-2B-32000-1": { + "category": "Инструментал и вокал", + "full_name": "VR Arch Single Model v5: 10_SP-UVR-2B-32000-1", + "stems": ["Instrumental", "Vocals"], + "custom_vr": False, + "target_instrument": None + }, + + "11_SP-UVR-2B-32000-2": { + "category": "Инструментал и вокал", + "full_name": "VR Arch Single Model v5: 11_SP-UVR-2B-32000-2", + "stems": ["Instrumental", "Vocals"], + "custom_vr": False, + "target_instrument": None + }, + + "12_SP-UVR-3B-44100": { + "category": "Инструментал и вокал", + "full_name": "VR Arch Single Model v5: 12_SP-UVR-3B-44100", + "stems": ["Instrumental", "Vocals"], + "custom_vr": False, + "target_instrument": None + }, + + "13_SP-UVR-4B-44100-1": { + "category": "Инструментал и вокал", + "full_name": "VR Arch Single Model v5: 13_SP-UVR-4B-44100-1", + "stems": ["Instrumental", "Vocals"], + "custom_vr": False, + "target_instrument": None + }, + + "14_SP-UVR-4B-44100-2": { + "category": "Инструментал и вокал", + "full_name": "VR Arch Single Model v5: 14_SP-UVR-4B-44100-2", + "stems": ["Instrumental", "Vocals"], + "custom_vr": False, + "target_instrument": None + }, + + 
"15_SP-UVR-MID-44100-1": { + "category": "Инструментал и вокал", + "full_name": "VR Arch Single Model v5: 15_SP-UVR-MID-44100-1", + "stems": ["Instrumental", "Vocals"], + "custom_vr": False, + "target_instrument": None + }, + + "16_SP-UVR-MID-44100-2": { + "category": "Инструментал и вокал", + "full_name": "VR Arch Single Model v5: 16_SP-UVR-MID-44100-2", + "stems": ["Instrumental", "Vocals"], + "custom_vr": False, + "target_instrument": None + }, + + "17_HP-Wind_Inst-UVR": { + "category": "Деревянные духовые", + "full_name": "VR Arch Single Model v5: 17_HP-Wind_Inst-UVR", + "stems": ["No Woodwinds", "Woodwinds"], + "custom_vr": False, + "target_instrument": None + }, + + "UVR-De-Echo-Aggressive": { + "category": "Эхо", + "full_name": "VR Arch Single Model v5: UVR-De-Echo-Aggressive by FoxJoy", + "stems": ["No Echo", "Echo"], + "custom_vr": False, + "target_instrument": None + }, + + "UVR-De-Echo-Normal": { + "category": "Эхо", + "full_name": "VR Arch Single Model v5: UVR-De-Echo-Normal by FoxJoy", + "stems": ["No Echo", "Echo"], + "custom_vr": False, + "target_instrument": None + }, + + "UVR-DeEcho-DeReverb": { + "category": "Реверб", + "full_name": "VR Arch Single Model v5: UVR-DeEcho-DeReverb by FoxJoy", + "stems": ["No Reverb", "Reverb"], + "custom_vr": False, + "target_instrument": None + }, + + "UVR-DeNoise-Lite": { + "category": "Шум", + "full_name": "VR Arch Single Model v5: UVR-DeNoise-Lite by FoxJoy", + "stems": ["Noise", "No Noise"], + "custom_vr": False, + "target_instrument": None + }, + + "UVR-DeNoise": { + "category": "Шум", + "full_name": "VR Arch Single Model v5: UVR-DeNoise by FoxJoy", + "stems": ["Noise", "No Noise"], + "custom_vr": False, + "target_instrument": None + }, + + "UVR-BVE-4B_SN-44100-1": { + "category": "Караоке", + "full_name": "VR Arch Single Model v5: UVR-BVE-4B_SN-44100", + "stems": ["Vocals", "Instrumental"], + "custom_vr": False, + "target_instrument": None + }, + + "UVR-BVE-v2-4B-SN-44100": { + "category": "Караоке", + "full_name": "VR Arch Single Model v4: UVR-BVE-v2-4B-SN-44100", + "stems": ["Vocals", "Instrumental"], + "custom_vr": True, + "primary_stem": "Vocals", + "target_instrument": None, + "checkpoint_url": "https://huggingface.co/RareSirMix/AIModelRehosting/resolve/main/UVR-5-1_4band_v4_ms_fullband_BVE_v2_by_aufr33.pth?download=true", + "config_url": "https://huggingface.co/RareSirMix/AIModelRehosting/resolve/main/4band_v4_ms_fullband.yaml?download=true" + }, + + "MGM-v5-KAROKEE-32000-BETA1": { + "category": "Караоке", + "full_name": "VR Arch Single Model v5: MGM-v5-KAROKEE-32000-BETA1", + "stems": ["Vocals", "Instrumental"], + "custom_vr": True, + "primary_stem": "Instrumental", + "target_instrument": None, + "checkpoint_url": "https://github.com/lucassantilli/UVR-Colab-GUI/releases/download/m5.1/MGM-v5-KAROKEE-32000-BETA1.pth", + "config_url": "https://github.com/lucassantilli/UVR-Colab-GUI/raw/refs/heads/main/modelparams/2band_32000.json" + }, + + "MGM-v5-KAROKEE-32000-BETA2-AGR": { + "category": "Караоке", + "full_name": "VR Arch Single Model v5: MGM-v5-KAROKEE-32000-BETA2-AGR.pth", + "stems": ["Vocals", "Instrumental"], + "custom_vr": True, + "primary_stem": "Instrumental", + "target_instrument": None, + "checkpoint_url": "https://github.com/lucassantilli/UVR-Colab-GUI/releases/download/m5.1/MGM-v5-KAROKEE-32000-BETA2-AGR.pth", + "config_url": "https://github.com/lucassantilli/UVR-Colab-GUI/raw/refs/heads/main/modelparams/2band_32000_agr.json" + }, + + + + + "MGM_HIGHEND_v4": { + "category": "Инструментал и вокал", + "full_name": "VR 
Arch Single Model v4: MGM_HIGHEND_v4", + "stems": ["Instrumental", "Vocals"], + "custom_vr": False, + "target_instrument": None + }, + + "MGM_LOWEND_A_v4": { + "category": "Инструментал и вокал", + "full_name": "VR Arch Single Model v4: MGM_LOWEND_A_v4", + "stems": ["Instrumental", "Vocals"], + "custom_vr": False, + "target_instrument": None + }, + + "MGM_LOWEND_B_v4": { + "category": "Инструментал и вокал", + "full_name": "VR Arch Single Model v4: MGM_LOWEND_B_v4", + "stems": ["Instrumental", "Vocals"], + "custom_vr": False, + "target_instrument": None + }, + + "MGM_MAIN_v4": { + "category": "Инструментал и вокал", + "full_name": "VR Arch Single Model v4: MGM_MAIN_v4", + "stems": ["Instrumental", "Vocals"], + "custom_vr": False, + "target_instrument": None + }, + + "UVR-De-Reverb-aufr33-jarredou": { + "category": "Реверб", + "full_name": "VR Arch Single Model v4: UVR-De-Reverb by aufr33-jarredou", + "stems": ["Dry", "No Dry"], + "custom_vr": False, + "target_instrument": None + }, + + "UVR-De-Breath-sucial-v1": { + "category": "Дыхание", + "full_name": "VR Arch Single Model v4: UVR-De-Breath v1 by Sucial", + "stems": ["Breath", "No Breath"], + "custom_vr": True, + "primary_stem": "Breath", + "target_instrument": None, + "checkpoint_url": "https://huggingface.co/Sucial/De-Breathe-Models/resolve/main/UVR_De-Breathe_1band_sr44100_hl1024_Sucial_v1.pth?download=true", + "config_url": "https://huggingface.co/Sucial/De-Breathe-Models/resolve/main/1band_sr44100_hl1024.json?download=true" + }, + + "UVR-De-Breath-sucial-v2": { + "category": "Дыхание", + "full_name": "VR Arch Single Model v4: UVR-De-Breath v2 by Sucial", + "stems": ["Breath", "No Breath"], + "custom_vr": True, + "primary_stem": "Breath", + "target_instrument": None, + "checkpoint_url": "https://huggingface.co/Sucial/De-Breathe-Models/resolve/main/UVR_De-Breathe_1band_sr44100_hl1024_Sucial_v2.pth?download=true", + "config_url": "https://huggingface.co/Sucial/De-Breathe-Models/resolve/main/1band_sr44100_hl1024.json?download=true" + }, + + "VR_Harmonic_Noise_Sep": { + "category": "Дыхание", + "full_name": "VR Arch Single Model v5: Harmonic_Noise_Sep", + "stems": ["Noise", "No Noise"], + "custom_vr": True, + "primary_stem": "Noise", + "target_instrument": None, + "checkpoint_url": "https://huggingface.co/Sucial/MSST-WebUI/resolve/main/All_Models/VR_Models/Harmonic_Noise_Separation_yxlllc.pth?download=true", + "config_url": "https://github.com/SUC-DriverOld/MSST-WebUI/raw/refs/heads/main/configs_backup/vr_modelparams/1band_sr44100_hl1024.json" + } + + }, + + "mdx": { + + "UVR-MDX-NET-Inst_HQ_1": { + "category": "Инструментал", + "full_name": "MDX-Net Model: UVR-MDX-NET Inst HQ 1", + "stems": ["Instrumental", "Vocals"], + "target_instrument": None + }, + + "UVR-MDX-NET-Inst_HQ_2": { + "category": "Инструментал", + "full_name": "MDX-Net Model: UVR-MDX-NET Inst HQ 2", + "stems": ["Instrumental", "Vocals"], + "target_instrument": None + }, + + "UVR-MDX-NET-Inst_HQ_3": { + "category": "Инструментал", + "full_name": "MDX-Net Model: UVR-MDX-NET Inst HQ 3", + "stems": ["Instrumental", "Vocals"], + "target_instrument": None + }, + + "UVR-MDX-NET-Inst_HQ_4": { + "category": "Инструментал", + "full_name": "MDX-Net Model: UVR-MDX-NET Inst HQ 4", + "stems": ["Instrumental", "Vocals"], + "target_instrument": None + }, + + "UVR-MDX-NET-Inst_HQ_5": { + "category": "Инструментал", + "full_name": "MDX-Net Model: UVR-MDX-NET Inst HQ 5", + "stems": ["Instrumental", "Vocals"], + "target_instrument": None + }, + + "UVR_MDXNET_Main": { + "category": 
"Инструментал и вокал", + "full_name": "MDX-Net Model: UVR-MDX-NET Main", + "stems": ["Vocals", "Instrumental"], + "target_instrument": None + }, + + "UVR-MDX-NET-Inst_Main": { + "category": "Инструментал", + "full_name": "MDX-Net Model: UVR-MDX-NET Inst Main", + "stems": ["Instrumental", "Vocals"], + "target_instrument": None + }, + + "UVR_MDXNET_1_9703": { + "category": "Инструментал и вокал", + "full_name": "MDX-Net Model: UVR-MDX-NET 1", + "stems": ["Vocals", "Instrumental"], + "target_instrument": None + }, + + "UVR_MDXNET_2_9682": { + "category": "Инструментал и вокал", + "full_name": "MDX-Net Model: UVR-MDX-NET 2", + "stems": ["Vocals", "Instrumental"], + "target_instrument": None + }, + + "UVR_MDXNET_3_9662": { + "category": "Инструментал и вокал", + "full_name": "MDX-Net Model: UVR-MDX-NET 3", + "stems": ["Vocals", "Instrumental"], + "target_instrument": None + }, + + "UVR-MDX-NET-Inst_1": { + "category": "Инструментал", + "full_name": "MDX-Net Model: UVR-MDX-NET Inst 1", + "stems": ["Instrumental", "Vocals"], + "target_instrument": None + }, + + "UVR-MDX-NET-Inst_2": { + "category": "Инструментал", + "full_name": "MDX-Net Model: UVR-MDX-NET Inst 2", + "stems": ["Instrumental", "Vocals"], + "target_instrument": None + }, + + "UVR-MDX-NET-Inst_3": { + "category": "Инструментал", + "full_name": "MDX-Net Model: UVR-MDX-NET Inst 3", + "stems": ["Instrumental", "Vocals"], + "target_instrument": None + }, + + "UVR_MDXNET_KARA": { + "category": "Караоке", + "full_name": "MDX-Net Model: UVR-MDX-NET Karaoke", + "stems": ["Vocals", "Instrumental"], + "target_instrument": None + }, + + "UVR_MDXNET_KARA_2": { + "category": "Караоке", + "full_name": "MDX-Net Model: UVR-MDX-NET Karaoke 2", + "stems": ["Instrumental", "Vocals"], + "target_instrument": None + }, + + "UVR_MDXNET_9482": { + "category": "Инструментал и вокал", + "full_name": "MDX-Net Model: UVR_MDXNET_9482", + "stems": ["Vocals", "Instrumental"], + "target_instrument": None + }, + + "UVR-MDX-NET-Voc_FT": { + "category": "Вокал", + "full_name": "MDX-Net Model: UVR-MDX-NET Voc FT", + "stems": ["Vocals", "Instrumental"], + "target_instrument": None + }, + + "Kim_Vocal_1": { + "category": "Вокал", + "full_name": "MDX-Net Model: Kim Vocal 1", + "stems": ["Vocals", "Instrumental"], + "target_instrument": None + }, + + "Kim_Vocal_2": { + "category": "Вокал", + "full_name": "MDX-Net Model: Kim Vocal 2", + "stems": ["Vocals", "Instrumental"], + "target_instrument": None + }, + + "Kim_Inst": { + "category": "Инструментал", + "full_name": "MDX-Net Model: Kim Inst", + "stems": ["Instrumental", "Vocals"], + "target_instrument": None + }, + + "Reverb_HQ_By_FoxJoy": { + "category": "Реверб", + "full_name": "MDX-Net Model: Reverb HQ By FoxJoy", + "stems": ["Reverb", "No Reverb"], + "target_instrument": None + }, + + "UVR-MDX-NET_Crowd_HQ_1": { + "category": "Звуки толпы", + "full_name": "MDX-Net Model: UVR-MDX-NET Crowd HQ 1 By Aufr33", + "stems": ["No Crowd", "Crowd"], + "target_instrument": None + }, + + "kuielab_a_vocals": { + "category": "Вокал", + "full_name": "MDX-Net Model: kuielab_a_vocals", + "stems": ["Vocals", "Instrumental"], + "target_instrument": None + }, + + "kuielab_a_other": { + "category": "Прочее", + "full_name": "MDX-Net Model: kuielab_a_other", + "stems": ["Other", "No Other"], + "target_instrument": None + }, + + "kuielab_a_bass": { + "category": "Басс", + "full_name": "MDX-Net Model: kuielab_a_bass", + "stems": ["Bass", "No Bass"], + "target_instrument": None + }, + + "kuielab_a_drums": { + "category": "Ударные", + 
"full_name": "MDX-Net Model: kuielab_a_drums", + "stems": ["Drums", "No Drums"], + "target_instrument": None + }, + + "kuielab_b_vocals": { + "category": "Вокал", + "full_name": "MDX-Net Model: kuielab_b_vocals", + "stems": ["Vocals", "Instrumental"], + "target_instrument": None + }, + + "kuielab_b_other": { + "category": "Прочее", + "full_name": "MDX-Net Model: kuielab_b_other", + "stems": ["Other", "No Other"], + "target_instrument": None + }, + + "kuielab_b_bass": { + "category": "Басс", + "full_name": "MDX-Net Model: kuielab_b_bass", + "stems": ["Bass", "No Bass"], + "target_instrument": None + }, + + "kuielab_b_drums": { + "category": "Ударные", + "full_name": "MDX-Net Model: kuielab_b_drums", + "stems": ["Drums", "No Drums"], + "target_instrument": None + }, + + "UVR-MDX-NET_Main_340": { + "category": "Инструментал и вокал", + "full_name": "MDX-Net Model VIP: UVR-MDX-NET_Main_340", + "stems": ["Vocals", "Instrumental"], + "target_instrument": None + }, + + "UVR-MDX-NET_Main_390": { + "category": "Инструментал и вокал", + "full_name": "MDX-Net Model VIP: UVR-MDX-NET_Main_390", + "stems": ["Vocals", "Instrumental"], + "target_instrument": None + }, + + "UVR-MDX-NET_Main_406": { + "category": "Инструментал и вокал", + "full_name": "MDX-Net Model VIP: UVR-MDX-NET_Main_406", + "stems": ["Vocals", "Instrumental"], + "target_instrument": None + }, + + "UVR-MDX-NET_Main_427": { + "category": "Инструментал и вокал", + "full_name": "MDX-Net Model VIP: UVR-MDX-NET_Main_427", + "stems": ["Vocals", "Instrumental"], + "target_instrument": None + }, + + "UVR-MDX-NET_Main_438": { + "category": "Инструментал и вокал", + "full_name": "MDX-Net Model VIP: UVR-MDX-NET_Main_438", + "stems": ["Vocals", "Instrumental"], + "target_instrument": None + }, + + "UVR-MDX-NET_Inst_82_beta": { + "category": "Инструментал", + "full_name": "MDX-Net Model VIP: UVR-MDX-NET_Inst_82_beta", + "stems": ["Instrumental", "Vocals"], + "target_instrument": None + }, + + "UVR-MDX-NET_Inst_90_beta": { + "category": "Инструментал", + "full_name": "MDX-Net Model VIP: UVR-MDX-NET_Inst_90_beta", + "stems": ["Instrumental", "Vocals"], + "target_instrument": None + }, + + "UVR-MDX-NET_Inst_187_beta": { + "category": "Инструментал", + "full_name": "MDX-Net Model VIP: UVR-MDX-NET_Inst_187_beta", + "stems": ["Instrumental", "Vocals"], + "target_instrument": None + }, + + "UVR-MDX-NET-Inst_full_292": { + "category": "Инструментал", + "full_name": "MDX-Net Model VIP: UVR-MDX-NET-Inst_full_292", + "stems": ["Instrumental", "Vocals"], + "target_instrument": None + } + + }, + + "htdemucs": { + + "HTDemucs4_MVSep_vocals": { + "category": "Вокал", + "full_name": "HTDemucs4 (MVSep finetuned)", + "stems": ["vocals", "other"], + "target_instrument": None, + "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.0/model_vocals_htdemucs_sdr_8.78.ckpt", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_vocals_htdemucs.yaml" + }, + + "HTDemucs4": { + "category": "4 стема", + "full_name": "HTDemucs4", + "stems": ["vocals", "drums", "bass", "other"], + "target_instrument": None, + "checkpoint_url": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/955717e8-8726e21a.th", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_musdb18_htdemucs.yaml" + }, + + "HTDemucs4_6s": { + "category": "6 стемов", + "full_name": "HTDemucs4 (6 stems)", + "stems": ["vocals", "drums", "bass", "other", 
"guitar", "piano"], + "target_instrument": None, + "checkpoint_url": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/5c90dfd2-34c22ccb.th", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_htdemucs_6stems.yaml" + }, + + "Demucs3_mmi": { + "category": "4 стема", + "full_name": "Demucs3 mmi", + "stems": ["vocals", "drums", "bass", "other"], + "target_instrument": None, + "checkpoint_url": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/75fc33f5-1941ce65.th", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_musdb18_demucs3_mmi.yaml" + }, + + "HTDemucs4_FT_Bass": { + "category": "Басс", + "full_name": "HTDemucs4 FT Bass", + "stems": ["vocals", "drums", "bass", "other"], + "target_instrument": None, + "checkpoint_url": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/d12395a8-e57c48e6.th", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_musdb18_htdemucs.yaml" + }, + + "HTDemucs4_FT_Drums": { + "category": "Ударные", + "full_name": "HTDemucs4 FT Drums", + "stems": ["vocals", "drums", "bass", "other"], + "target_instrument": None, + "checkpoint_url": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/f7e0c4bc-ba3fe64a.th", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_musdb18_htdemucs.yaml" + }, + + "HTDemucs4_FT_Vocals": { + "category": "Вокал", + "full_name": "HTDemucs4 FT Vocals", + "stems": ["vocals", "drums", "bass", "other"], + "target_instrument": None, + "checkpoint_url": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/04573f0d-f3cf25b2.th", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_musdb18_htdemucs.yaml" + }, + + "HTDemucs4_FT_Other": { + "category": "Прочее", + "full_name": "HTDemucs4 FT Other", + "stems": ["vocals", "drums", "bass", "other"], + "target_instrument": None, + "checkpoint_url": "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/92cfc3b6-ef3bcb9c.th", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_musdb18_htdemucs.yaml" + }, + + "HTDemucs4_Mid_Side_wesleyr36": { + "category": "Фантомный центр", + "full_name": "HTDemucs4 MId-Side by wesleyr36", + "stems": ["similarity", "difference"], + "target_instrument": None, + "checkpoint_url": "https://huggingface.co/jarredou/HTDemucs_Similarity_Extractor_by_wesleyr36/resolve/main/model_htdemucs_ep_21_sdr_13.6970.ckpt?download=true", + "config_url": "https://huggingface.co/jarredou/HTDemucs_Similarity_Extractor_by_wesleyr36/resolve/main/config_htdemucs_similarity.yaml?download=true" + } + + }, + + "bandit": { + + "Bandit_Plus": { + "category": "Кинематограф", + "full_name": "Bandit Plus: Cinematic Bandit Plus (by kwatcharasupat)", + "stems": ["speech", "music", "effects"], + "target_instrument": None, + "checkpoint_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.3/model_bandit_plus_dnr_sdr_11.47.chpt", + "config_url": "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.3/config_dnr_bandit_bsrnn_multi_mus64.yaml" + }, + + }, + + "bandit_v2": { + + "Bandit_v2_Multi": { + "category": "Кинематограф", + "full_name": "Bandit v2: Cinematic Bandit v2 Multilang (by kwatcharasupat)", + "stems": ["speech", "music", "sfx"], + 
"target_instrument": None, + "checkpoint_url": "https://huggingface.co/jarredou/banditv2_state_dicts_only/resolve/main/checkpoint-multi_state_dict.ckpt", + "config_url": "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/refs/heads/main/configs/config_dnr_bandit_v2_mus64.yaml" + }, + + } + +} + +medley_vox_models = { + + "multi_singing_librispeech": { + "checkpoint_url": "https://huggingface.co/Cyru5/MedleyVox/resolve/main/multi_singing_librispeech/vocals.pth?download=true", + "config_url": "https://huggingface.co/Cyru5/MedleyVox/resolve/main/multi_singing_librispeech/vocals.json?download=true" + }, + + "multi_singing_librispeech_138": { + "checkpoint_url": "https://huggingface.co/Cyru5/MedleyVox/resolve/main/multi_singing_librispeech_138/vocals.pth?download=true", + "config_url": "https://huggingface.co/Cyru5/MedleyVox/resolve/main/multi_singing_librispeech_138/vocals.json?download=true" + }, + + "singing_librispeech_ft_isrnet": { + "checkpoint_url": "https://huggingface.co/Cyru5/MedleyVox/resolve/main/singing_librispeech_ft_iSRNet/vocals.pth?download=true", + "config_url": "https://huggingface.co/Cyru5/MedleyVox/resolve/main/singing_librispeech_ft_iSRNet/vocals.json?download=true" + }, + + "singing_librispeech_isrnet": { + "checkpoint_url": "https://huggingface.co/Cyru5/MedleyVox/resolve/main/singing_librispeech_iSRNet/vocals.pth?download=true", + "config_url": "https://huggingface.co/Cyru5/MedleyVox/resolve/main/singing_librispeech_iSRNet/vocals.json?download=true" + }, + + "vocal_231": { + "checkpoint_url": "https://huggingface.co/Cyru5/MedleyVox/resolve/main/vocal%20231/vocals.pth?download=true", + "config_url": "https://huggingface.co/Cyru5/MedleyVox/resolve/main/vocal%20231/vocals.json?download=true" + }, + + "vocals_135": { + "checkpoint_url": "https://huggingface.co/Cyru5/MedleyVox/resolve/main/vocals%20135/vocals.pth?download=true", + "config_url": "https://huggingface.co/Cyru5/MedleyVox/resolve/main/vocals%20135/vocals.json?download=true" + }, + + "vocals_163": { + "checkpoint_url": "https://huggingface.co/Cyru5/MedleyVox/resolve/main/vocals%20163/vocals.pth?download=true", + "config_url": "https://huggingface.co/Cyru5/MedleyVox/resolve/main/vocals%20163/vocals.json?download=true" + }, + + "vocals_188": { + "checkpoint_url": "https://huggingface.co/Cyru5/MedleyVox/resolve/main/vocals%20188/vocals.pth?download=true", + "config_url": "https://huggingface.co/Cyru5/MedleyVox/resolve/main/vocals%20188/vocals.json?download=true" + }, + + "vocals_200": { + "checkpoint_url": "https://huggingface.co/Cyru5/MedleyVox/resolve/main/vocals%20200/vocals.pth?download=true", + "config_url": "https://huggingface.co/Cyru5/MedleyVox/resolve/main/vocals%20200/vocals.json?download=true" + }, + + "vocals_238": { + "checkpoint_url": "https://huggingface.co/Cyru5/MedleyVox/resolve/main/vocals%20238/vocals.pth?download=true", + "config_url": "https://huggingface.co/Cyru5/MedleyVox/resolve/main/vocals%20238/vocals.json?download=true" + } + +} + + + + + + + + + + + diff --git a/multi_inference.py b/multi_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..026bd9e7f72559162a41233a47b2bd93c6b0533d --- /dev/null +++ b/multi_inference.py @@ -0,0 +1,303 @@ +import os +import time +import shutil +import sys +import gc +import argparse +import json +import subprocess +from datetime import datetime +from tabulate import tabulate + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(SCRIPT_DIR) +os.chdir(SCRIPT_DIR) + +from model_list 
import models_data +from utils.preedit_config import conf_editor +from utils.download_models import download_model + +MODELS_CACHE_DIR = os.path.join(SCRIPT_DIR, "separator", "models_cache") +MODEL_TYPES = ["mel_band_roformer", "bs_roformer", "mdx23c", "scnet", "htdemucs", "bandit", "bandit_v2", "vr", "mdx"] +OUTPUT_FORMATS = ["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "aiff"] + +class MVSEPLESS: + def __init__(self): + self.models_cache_dir = os.path.join(SCRIPT_DIR, "separator", "models_cache") + self.model_types = MODEL_TYPES + self.output_formats = OUTPUT_FORMATS + + def get_mt(self): + return list(models_data.keys()) + + def get_mn(self, model_type): + return list(models_data[model_type].keys()) + + def get_stems(self, model_type, model_name): + stems = models_data[model_type][model_name]["stems"] + return stems + + def get_tgt_inst(self, model_type, model_name): + target_instrument = models_data[model_type][model_name]["target_instrument"] + return target_instrument + + def display_models_info(self, filter: str = None): + print("\nAvailable Models Information:") + print("=" * 50) + + for model_type in models_data: + print(f"\nModel Type: {model_type.upper()}") + print("-" * 50) + + table_data = [] + headers = ["Model Name", "Stems", "Target Instrument", "Primary Stem"] + + for model_name in models_data[model_type]: + model_info = models_data[model_type][model_name] + + if filter and filter not in model_info.get('stems', []): + continue + + stems = "\n".join(model_info.get('stems', [])) if 'stems' in model_info else "N/A" + target = model_info.get('target_instrument', "N/A") + primary = model_info.get('primary_stem', "N/A") + + table_data.append([model_name, stems, target, primary]) + + print(tabulate(table_data, headers=headers, tablefmt="grid")) + print() + + def separator( + self, + input_file: str = None, + output_dir: str = None, + model_type: str = "mel_band_roformer", + model_name: str = "Mel-Band-Roformer_Vocals_kimberley_jensen", + ext_inst: bool = False, + vr_aggr: int = 5, + output_format: str = "wav", + output_bitrate: str = "320k", + template: str = "NAME_(STEM)_MODEL", + call_method: str = "cli", + selected_stems: list = None + ): + if selected_stems is None: + selected_stems = [] + + if not input_file: + print("Please, input path to input file") + return [("None", "/none/none.mp3")] + + if not os.path.exists(input_file): + print("Input file not exist") + return [("None", "/none/none.mp3")] + + if "STEM" not in template: + template = template + "_STEM" + + print(f"Starting inference: {model_type}/{model_name}, bitrate={output_bitrate}, method={call_method}, stems={selected_stems}") + os.makedirs(output_dir, exist_ok=True) + + if model_type in ["mel_band_roformer", "bs_roformer", "mdx23c", "scnet", "htdemucs", "bandit", "bandit_v2"]: + try: + info = models_data[model_type][model_name] + except KeyError: + print("Model not exist") + return [("None", "/none/none.mp3")] + + conf, ckpt = download_model(self.models_cache_dir, model_name, model_type, + info["checkpoint_url"], info["config_url"]) + if model_type != "htdemucs": + conf_editor(conf) + + if call_method == "cli": + cmd = ["python", "-m", "separator.msst_separator", f'--input "{input_file}"', + f'--store_dir "{output_dir}"', f'--model_type "{model_type}"', + f'--model_name "{model_name}"', f'--config_path "{conf}"', + f'--start_check_point "{ckpt}"', f'--output_format "{output_format}"', + f'--output_bitrate "{output_bitrate}"', f'--template "{template}"', + "--save_results_info"] + if ext_inst: + 
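+ # separator.msst_separator (invoked via the CLI command above) is expected to render the complement of the target stem ("instrumental") when this flag is passed.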
cmd.append("--extract_instrumental") + if selected_stems: + instruments = " ".join(f'"{s}"' for s in selected_stems) + cmd.append(f'--selected_instruments {instruments}') + subprocess.run(" ".join(cmd), shell=True, check=True) + + results_path = os.path.join(output_dir, "results.json") + if os.path.exists(results_path): + with open(results_path, encoding="utf-8") as f: + return json.load(f) + return [("None", "/none/none.mp3")] + + elif call_method == "direct": + from separator.msst_separator import mvsep_offline + try: + return mvsep_offline( + input_path=input_file, store_dir=output_dir, model_type=model_type, + config_path=conf, start_check_point=ckpt, extract_instrumental=ext_inst, + output_format=output_format, output_bitrate=output_bitrate, + model_name=model_name, template=template, selected_instruments=selected_stems + ) + except Exception as e: + print(e) + return [("None", "/none/none.mp3")] + + elif model_type in ["vr", "mdx"]: + try: + info = models_data[model_type][model_name] + except KeyError: + print("Model not exist") + return [("None", "/none/none.mp3")] + + if model_type == "vr" and info.get("custom_vr", False): + conf, ckpt = download_model(self.models_cache_dir, model_name, model_type, + info["checkpoint_url"], info["config_url"]) + primary_stem = info["primary_stem"] + + if call_method == "cli": + cmd = ["python", "-m", "separator.uvr_sep", "custom_vr", + f'--input_file "{input_file}"', f'--ckpt_path "{ckpt}"', + f'--config_path "{conf}"', f'--bitrate "{output_bitrate}"', + f'--model_name "{model_name}"', f'--template "{template}"', + f'--output_format "{output_format}"', f'--primary_stem "{primary_stem}"', + f'--aggression {vr_aggr}', f'--output_dir "{output_dir}"'] + if selected_stems: + instruments = " ".join(f'"{s}"' for s in selected_stems) + cmd.append(f'--selected_instruments {instruments}') + subprocess.run(" ".join(cmd), shell=True, check=True) + + results_path = os.path.join(output_dir, "results.json") + if os.path.exists(results_path): + with open(results_path, encoding="utf-8") as f: + return json.load(f) + return [("None", "/none/none.mp3")] + + elif call_method == "direct": + from separator.uvr_sep import custom_vr_separate + try: + return custom_vr_separate( + input_file=input_file, ckpt_path=ckpt, config_path=conf, + bitrate=output_bitrate, model_name=model_name, template=template, + output_format=output_format, primary_stem=primary_stem, + aggression=vr_aggr, output_dir=output_dir, + selected_instruments=selected_stems + ) + except Exception as e: + print(e) + return [("None", "/none/none.mp3")] + else: + if call_method == "cli": + cmd = ["python", "-m", "separator.uvr_sep", "uvr", + f'--input_file "{input_file}"', f'--output_dir "{output_dir}"', + f'--template "{template}"', f'--bitrate "{output_bitrate}"', + f'--model_dir "{self.models_cache_dir}"', f'--model_type "{model_type}"', + f'--model_name "{model_name}"', f'--output_format "{output_format}"', + f'--aggression {vr_aggr}'] + if selected_stems: + instruments = " ".join(f'"{s}"' for s in selected_stems) + cmd.append(f'--selected_instruments {instruments}') + subprocess.run(" ".join(cmd), shell=True, check=True) + + results_path = os.path.join(output_dir, "results.json") + if os.path.exists(results_path): + with open(results_path, encoding="utf-8") as f: + return json.load(f) + return [("None", "/none/none.mp3")] + + elif call_method == "direct": + from separator.uvr_sep import non_custom_uvr_inference + try: + return non_custom_uvr_inference( + input_file=input_file, output_dir=output_dir, 
template=template, + bitrate=output_bitrate, model_dir=self.models_cache_dir, + model_type=model_type, model_name=model_name, + output_format=output_format, aggression=vr_aggr, + selected_instruments=selected_stems + ) + except Exception as e: + print(e) + return [("None", "/none/none.mp3")] + + print("Unsupported model type") + return [("None", "/none/none.mp3")] + +def parse_args(): + parser = argparse.ArgumentParser(description="Multi-inference for separation audio in Google Colab") + subparsers = parser.add_subparsers(dest='command', required=True, help='Sub-command help') + + list_models = subparsers.add_parser('list', help='List of exist models') + list_models.add_argument("-l_filter", "--list_filter", type=str, default=None, help="Show models in list only with specified stem") + + separate = subparsers.add_parser('separate', help='Separate I/O params') + separate.add_argument("-i", "--input", type=str, required=True, help="Input file or directory") + separate.add_argument("-o", "--output", type=str, required=True, help="Output directory") + separate.add_argument("-mt", "--model_type", type=str, required=True, choices=MODEL_TYPES, help="Model type") + separate.add_argument("-mn", "--model_name", type=str, required=True, help="Model name") + separate.add_argument("-inst", "--instrumental", action='store_true', help="Extract instrumental") + separate.add_argument("-stems", "--stems", nargs="+", help="Select output stems") + separate.add_argument("-bitrate", "--bitrate", type=str, default="320k", help="Output bitrate") + separate.add_argument("-of", "--format", type=str, default="mp3", help="Output format") + separate.add_argument("-vr_aggr", "--vr_arch_aggressive", type=int, default=5, help="Aggression for VR ARCH models") + separate.add_argument('--template', type=str, default='NAME_STEM', help='Template naming of output files') + separate.add_argument("-l_out", "--list_output", action='store_true', help="Show list output files") + + return parser.parse_args() + +if __name__ == "__main__": + args = parse_args() + mvsepless = MVSEPLESS() + + if args.command == 'list': + mvsepless.display_models_info(args.list_filter) + + elif args.command == 'separate': + if os.path.isfile(args.input): + results = mvsepless.separator( + input_file=args.input, + output_dir=args.output, + model_type=args.model_type, + model_name=args.model_name, + ext_inst=args.instrumental, + vr_aggr=args.vr_arch_aggressive, + output_format=args.format, + output_bitrate=args.bitrate, + template=args.template, + call_method="cli", + selected_stems=args.stems + ) + if args.list_output: + print("Results\n") + for stem, path in results: + print(f"Stem - {stem}\nPath - {path}\n") + + elif os.path.isdir(args.input): + batch_results = [] + for file in os.listdir(args.input): + abs_path_file = os.path.join(args.input, file) + if os.path.isfile(abs_path_file): + base_name = os.path.splitext(os.path.basename(abs_path_file))[0] + output_subdir = os.path.join(args.output, base_name) + + results = mvsepless.separator( + input_file=abs_path_file, + output_dir=output_subdir, + model_type=args.model_type, + model_name=args.model_name, + ext_inst=args.instrumental, + vr_aggr=args.vr_arch_aggressive, + output_format=args.format, + output_bitrate=args.bitrate, + template=args.template, + call_method="cli", + selected_stems=args.stems + ) + batch_results.append((base_name, results)) + + if args.list_output: + print("Results\n") + for name, stems in batch_results: + print(f"Name - {name}") + for stem, path in stems: + print(f" Stem - 
{stem}\n Path - {path}\n") + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..501c5af997247d6a126fe1480af36d36eb818da4 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,50 @@ +torch==2.6.0 +torchvision==0.21.0 +torchaudio==2.6.0 +torchcrepe==0.0.23 +numpy==2.0.2 +pandas==2.2.2 +scipy==1.15.3 +librosa==0.9.1 +matplotlib==3.9.0 +tqdm==4.67.1 +einops==0.8.1 +protobuf==5.29.4 +soundfile==0.13.1 +pydub==0.25.1 +pyloudnorm==0.1.1 +praat-parselmouth==0.4.5 +webrtcvad==2.0.10 +edge-tts==7.0.2 +audiomentations==0.24.0 +pedalboard==0.8.1 +ffmpeg-python==0.2.0 +faiss-cpu==1.11 +ml_collections==1.1.0 +timm==1.0.15 +wandb==0.19.11 +accelerate==1.7.0 +bitsandbytes==0.46.0 +tokenizers==0.19 +huggingface-hub==0.28.1 +transformers==4.41 +https://github.com/noblebarkrr/mvsepless/blob/bd611441e48e918650e6860738894673b3a1a5f1/fixed/fairseq_fixed-0.13.0-cp311-cp311-linux_x86_64.whl +torchseg==0.0.1a4 +demucs==4.0.0 +asteroid==0.7.0 +prodigyopt==1.1.2 +torch_log_wmse==0.3.0 +rotary_embedding_torch==0.6.5 +local-attention==1.11.1 +tenacity==9.1.2 +gradio==5.38.2 +omegaconf==2.3.0 +beartype==0.18.5 +spafe==0.3.2 +torch_audiomentations==0.12.0 +auraloss==0.4.0 +onnxruntime-gpu>=1.17 +yt_dlp +python-magic +pyngrok + diff --git a/separator/audio_writer.py b/separator/audio_writer.py new file mode 100644 index 0000000000000000000000000000000000000000..78eb1e60b9508a9659acc7d507a179c66ea8e1bc --- /dev/null +++ b/separator/audio_writer.py @@ -0,0 +1,85 @@ +from pydub import AudioSegment +import numpy as np + +def write_audio_file(output_file_path, numpy_array, sample_rate, output_format, bitrate): + """ + Записывает аудиофайл из numpy массива в указанном формате с помощью pydub. + + Параметры: + output_file_path (str): Путь для сохранения файла (без расширения) + numpy_array (numpy.ndarray): Аудиоданные в виде numpy массива + sample_rate (int): Частота дискретизации (в Гц) + output_format (str): Формат выходного файла ('mp3', 'flac', 'wav', 'aiff', 'm4a', 'aac', 'ogg', 'opus') + encoder_settings (dict, optional): Cловарь с настройками кодировки аудио + """ + try: + # Проверка и нормализация входных данных + if not isinstance(numpy_array, np.ndarray): + raise ValueError("Input must be a numpy array") + + # Преобразование в правильную форму (samples, channels) + if len(numpy_array.shape) == 1: + numpy_array = numpy_array.reshape(-1, 1) # Моно + elif len(numpy_array.shape) == 2: + if numpy_array.shape[0] == 2: # Если (channels, samples) + numpy_array = numpy_array.T # Транспонируем в (samples, channels) + else: + raise ValueError("Input array must be 1D or 2D") + + # Нормализация до диапазона [-1.0, 1.0] если нужно + if np.issubdtype(numpy_array.dtype, np.floating): + numpy_array = np.clip(numpy_array, -1.0, 1.0) + numpy_array = (numpy_array * 32767).astype(np.int16) + elif numpy_array.dtype != np.int16: + numpy_array = numpy_array.astype(np.int16) + + # Создание AudioSegment + if numpy_array.shape[1] == 1: # Моно + audio_segment = AudioSegment( + numpy_array.tobytes(), + frame_rate=sample_rate, + sample_width=2, # 16-bit = 2 bytes + channels=1 + ) + else: # Стерео + # Для стерео нужно чередовать байты левого и правого каналов + interleaved = np.empty((numpy_array.shape[0] * 2,), dtype=np.int16) + interleaved[0::2] = numpy_array[:, 0] # Левый канал + interleaved[1::2] = numpy_array[:, 1] # Правый канал + audio_segment = AudioSegment( + interleaved.tobytes(), + frame_rate=sample_rate, + sample_width=2, + channels=2 + ) + + # Формирование параметров экспорта 
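+ # Build the pydub export parameters: a non-empty bitrate string (e.g. "320k") is handed to AudioSegment.export(), which forwards it to ffmpeg; it mainly matters for lossy formats.
+ # Note: the shape check above accepts 1-D mono input or a (2, samples) channels-first array; a (samples, 2) array is rejected.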
+ + parameters = {} + if bitrate: + parameters['bitrate'] = bitrate + + # Поддержка различных форматов + format_mapping = { + 'mp3': 'mp3', + 'flac': 'flac', + 'wav': 'wav', + 'aiff': 'aiff', + 'm4a': 'ipod', # для m4a в pydub используется кодек ipod + 'aac': 'adts', # для aac в pydub используется adts + 'ogg': 'ogg', + 'opus': 'opus' + } + + if output_format not in format_mapping: + raise ValueError(f"Unsupported format: {output_format}. Supported formats are: {list(format_mapping.keys())}") + + # Добавление расширения файла, если его нет + if not output_file_path.lower().endswith(f'.{output_format}'): + output_file_path = f"{output_file_path}.{output_format}" + + # Экспорт в нужный формат + audio_segment.export(output_file_path, format=format_mapping[output_format], **parameters) + + except Exception as e: + raise RuntimeError(f"Error writing audio file: {str(e)}") \ No newline at end of file diff --git a/separator/ensemble.py b/separator/ensemble.py new file mode 100644 index 0000000000000000000000000000000000000000..1fedbec4ad75f9d09783d1dd1837c066af69e321 --- /dev/null +++ b/separator/ensemble.py @@ -0,0 +1,192 @@ +# coding: utf-8 +__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/' + +import os +import sys +import librosa +import tempfile +import soundfile as sf +import numpy as np +import argparse +from separator.audio_writer import write_audio_file + + +def stft(wave, nfft, hl): + wave_left = np.asfortranarray(wave[0]) + wave_right = np.asfortranarray(wave[1]) + spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl) + spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl) + spec = np.asfortranarray([spec_left, spec_right]) + return spec + + +def istft(spec, hl, length): + spec_left = np.asfortranarray(spec[0]) + spec_right = np.asfortranarray(spec[1]) + wave_left = librosa.istft(spec_left, hop_length=hl, length=length) + wave_right = librosa.istft(spec_right, hop_length=hl, length=length) + wave = np.asfortranarray([wave_left, wave_right]) + return wave + + +def absmax(a, *, axis): + dims = list(a.shape) + dims.pop(axis) + indices = np.ogrid[tuple(slice(0, d) for d in dims)] + argmax = np.abs(a).argmax(axis=axis) + # Convert indices to list before insertion + indices = list(indices) + indices.insert(axis % len(a.shape), argmax) + return a[tuple(indices)] + + +def absmin(a, *, axis): + dims = list(a.shape) + dims.pop(axis) + indices = np.ogrid[tuple(slice(0, d) for d in dims)] + argmax = np.abs(a).argmin(axis=axis) + indices.insert((len(a.shape) + axis) % len(a.shape), argmax) + return a[tuple(indices)] + + +def lambda_max(arr, axis=None, key=None, keepdims=False): + idxs = np.argmax(key(arr), axis) + if axis is not None: + idxs = np.expand_dims(idxs, axis) + result = np.take_along_axis(arr, idxs, axis) + if not keepdims: + result = np.squeeze(result, axis=axis) + return result + else: + return arr.flatten()[idxs] + + +def lambda_min(arr, axis=None, key=None, keepdims=False): + idxs = np.argmin(key(arr), axis) + if axis is not None: + idxs = np.expand_dims(idxs, axis) + result = np.take_along_axis(arr, idxs, axis) + if not keepdims: + result = np.squeeze(result, axis=axis) + return result + else: + return arr.flatten()[idxs] + + +def average_waveforms(pred_track, weights, algorithm): + """ + :param pred_track: shape = (num, channels, length) + :param weights: shape = (num, ) + :param algorithm: One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft + :return: averaged waveform in shape (channels, length) + """ + + 
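+ # How the algorithms differ (see the branches below):
+ #   *_wave variants combine the raw waveforms; *_fft variants first take an STFT (n_fft=2048, hop=1024, stereo input assumed), combine the spectrograms, then invert back to a waveform of the original length.
+ #   "avg" applies the per-source weights as a weighted mean, while "median"/"min"/"max" ignore the weights and keep the element-wise median / smallest-magnitude / largest-magnitude value.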
pred_track = np.array(pred_track) + final_length = pred_track.shape[-1] + + mod_track = [] + for i in range(pred_track.shape[0]): + if algorithm == 'avg_wave': + mod_track.append(pred_track[i] * weights[i]) + elif algorithm in ['median_wave', 'min_wave', 'max_wave']: + mod_track.append(pred_track[i]) + elif algorithm in ['avg_fft', 'min_fft', 'max_fft', 'median_fft']: + spec = stft(pred_track[i], nfft=2048, hl=1024) + if algorithm in ['avg_fft']: + mod_track.append(spec * weights[i]) + else: + mod_track.append(spec) + pred_track = np.array(mod_track) + + if algorithm in ['avg_wave']: + pred_track = pred_track.sum(axis=0) + pred_track /= np.array(weights).sum().T + elif algorithm in ['median_wave']: + pred_track = np.median(pred_track, axis=0) + elif algorithm in ['min_wave']: + pred_track = np.array(pred_track) + pred_track = lambda_min(pred_track, axis=0, key=np.abs) + elif algorithm in ['max_wave']: + pred_track = np.array(pred_track) + pred_track = lambda_max(pred_track, axis=0, key=np.abs) + elif algorithm in ['avg_fft']: + pred_track = pred_track.sum(axis=0) + pred_track /= np.array(weights).sum() + pred_track = istft(pred_track, 1024, final_length) + elif algorithm in ['min_fft']: + pred_track = np.array(pred_track) + pred_track = lambda_min(pred_track, axis=0, key=np.abs) + pred_track = istft(pred_track, 1024, final_length) + elif algorithm in ['max_fft']: + pred_track = np.array(pred_track) + pred_track = absmax(pred_track, axis=0) + pred_track = istft(pred_track, 1024, final_length) + elif algorithm in ['median_fft']: + pred_track = np.array(pred_track) + pred_track = np.median(pred_track, axis=0) + pred_track = istft(pred_track, 1024, final_length) + return pred_track + + +def ensemble_audio_files(files, output="res.wav", ensemble_type='avg_wave', weights=None, out_format="wav"): + """ + Основная функция для объединения аудиофайлов + + :param files: список путей к аудиофайлам + :param output: путь для сохранения результата + :param ensemble_type: метод объединения (avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft) + :param weights: список весов для каждого файла (None для равных весов) + :return: None + """ + print('Ensemble type: {}'.format(ensemble_type)) + print('Number of input files: {}'.format(len(files))) + if weights is not None: + weights = np.array(weights) + else: + weights = np.ones(len(files)) + print('Weights: {}'.format(weights)) + print('Output file: {}'.format(output)) + + data = [] + sr = None + for f in files: + if not os.path.isfile(f): + print('Error. Can\'t find file: {}. 
Check paths.'.format(f)) + exit() + print('Reading file: {}'.format(f)) + wav, current_sr = librosa.load(f, sr=None, mono=False) + if sr is None: + sr = current_sr + elif sr != current_sr: + print('Error: Sample rates must be equal for all files') + exit() + print("Waveform shape: {} sample rate: {}".format(wav.shape, sr)) + data.append(wav) + + data = np.array(data) + res = average_waveforms(data, weights, ensemble_type) + print('Result shape: {}'.format(res.shape)) + + output_wav = f"{output}_orig.wav" + output = f"{output}.{out_format}" + + if out_format in ["wav", "flac"]: + + sf.write(output, res.T, sr, subtype='PCM_16') + sf.write(output_wav, res.T, sr, subtype='PCM_16') + + elif out_format in ["mp3", "m4a", "aac", "ogg", "opus", "aiff"]: + + write_audio_file(output, res.T, sr, out_format, "320k") + sf.write(output_wav, res.T, sr, subtype='PCM_16') + + return output, output_wav + + + +# input_settings = [("demucs / v4", 1.0, "vocals"), ("mel_band_roformer / mel_4_stems", 0.5, "vocals")] + +# out, wav = ensembless(input_audio, input_settings, "max_fft", format) + + + diff --git a/separator/models/bandit/core/__init__.py b/separator/models/bandit/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a4d6d7953709c2f86a6b484e49c7715b58bbe86a --- /dev/null +++ b/separator/models/bandit/core/__init__.py @@ -0,0 +1,744 @@ +import os.path +from collections import defaultdict +from itertools import chain, combinations +from typing import ( + Any, + Dict, + Iterator, + Mapping, Optional, + Tuple, Type, + TypedDict +) + +import pytorch_lightning as pl +import torch +import torchaudio as ta +import torchmetrics as tm +from asteroid import losses as asteroid_losses +# from deepspeed.ops.adam import DeepSpeedCPUAdam +# from geoopt import optim as gooptim +from pytorch_lightning.utilities.types import STEP_OUTPUT +from torch import nn, optim +from torch.optim import lr_scheduler +from torch.optim.lr_scheduler import LRScheduler + +from models.bandit.core import loss, metrics as metrics_, model +from models.bandit.core.data._types import BatchedDataDict +from models.bandit.core.data.augmentation import BaseAugmentor, StemAugmentor +from models.bandit.core.utils import audio as audio_ +from models.bandit.core.utils.audio import BaseFader + +# from pandas.io.json._normalize import nested_to_record + +ConfigDict = TypedDict('ConfigDict', {'name': str, 'kwargs': Dict[str, Any]}) + + +class SchedulerConfigDict(ConfigDict): + monitor: str + + +OptimizerSchedulerConfigDict = TypedDict( + 'OptimizerSchedulerConfigDict', + {"optimizer": ConfigDict, "scheduler": SchedulerConfigDict}, + total=False +) + + +class LRSchedulerReturnDict(TypedDict, total=False): + scheduler: LRScheduler + monitor: str + + +class ConfigureOptimizerReturnDict(TypedDict, total=False): + optimizer: torch.optim.Optimizer + lr_scheduler: LRSchedulerReturnDict + + +OutputType = Dict[str, Any] +MetricsType = Dict[str, torch.Tensor] + + +def get_optimizer_class(name: str) -> Type[optim.Optimizer]: + + if name == "DeepSpeedCPUAdam": + return DeepSpeedCPUAdam + + for module in [optim, gooptim]: + if name in module.__dict__: + return module.__dict__[name] + + raise NameError + + +def parse_optimizer_config( + config: OptimizerSchedulerConfigDict, + parameters: Iterator[nn.Parameter] +) -> ConfigureOptimizerReturnDict: + optim_class = get_optimizer_class(config["optimizer"]["name"]) + optimizer = optim_class(parameters, **config["optimizer"]["kwargs"]) + + optim_dict: ConfigureOptimizerReturnDict = { + 
"optimizer": optimizer, + } + + if "scheduler" in config: + + lr_scheduler_class_ = config["scheduler"]["name"] + lr_scheduler_class = lr_scheduler.__dict__[lr_scheduler_class_] + lr_scheduler_dict: LRSchedulerReturnDict = { + "scheduler": lr_scheduler_class( + optimizer, + **config["scheduler"]["kwargs"] + ) + } + + if lr_scheduler_class_ == "ReduceLROnPlateau": + lr_scheduler_dict["monitor"] = config["scheduler"]["monitor"] + + optim_dict["lr_scheduler"] = lr_scheduler_dict + + return optim_dict + + +def parse_model_config(config: ConfigDict) -> Any: + name = config["name"] + + for module in [model]: + if name in module.__dict__: + return module.__dict__[name](**config["kwargs"]) + + raise NameError + + +_LEGACY_LOSS_NAMES = ["HybridL1Loss"] + + +def _parse_legacy_loss_config(config: ConfigDict) -> nn.Module: + name = config["name"] + + if name == "HybridL1Loss": + return loss.TimeFreqL1Loss(**config["kwargs"]) + + raise NameError + + +def parse_loss_config(config: ConfigDict) -> nn.Module: + name = config["name"] + + if name in _LEGACY_LOSS_NAMES: + return _parse_legacy_loss_config(config) + + for module in [loss, nn.modules.loss, asteroid_losses]: + if name in module.__dict__: + # print(config["kwargs"]) + return module.__dict__[name](**config["kwargs"]) + + raise NameError + + +def get_metric(config: ConfigDict) -> tm.Metric: + name = config["name"] + + for module in [tm, metrics_]: + if name in module.__dict__: + return module.__dict__[name](**config["kwargs"]) + raise NameError + + +def parse_metric_config(config: Dict[str, ConfigDict]) -> tm.MetricCollection: + metrics = {} + + for metric in config: + metrics[metric] = get_metric(config[metric]) + + return tm.MetricCollection(metrics) + + +def parse_fader_config(config: ConfigDict) -> BaseFader: + name = config["name"] + + for module in [audio_]: + if name in module.__dict__: + return module.__dict__[name](**config["kwargs"]) + + raise NameError + + +class LightningSystem(pl.LightningModule): + _VOX_STEMS = ["speech", "vocals"] + _BG_STEMS = ["background", "effects", "mne"] + + def __init__( + self, + config: Dict, + loss_adjustment: float = 1.0, + attach_fader: bool = False + ) -> None: + super().__init__() + self.optimizer_config = config["optimizer"] + self.model = parse_model_config(config["model"]) + self.loss = parse_loss_config(config["loss"]) + self.metrics = nn.ModuleDict( + { + stem: parse_metric_config(config["metrics"]["dev"]) + for stem in self.model.stems + } + ) + + self.metrics.disallow_fsdp = True + + self.test_metrics = nn.ModuleDict( + { + stem: parse_metric_config(config["metrics"]["test"]) + for stem in self.model.stems + } + ) + + self.test_metrics.disallow_fsdp = True + + self.fs = config["model"]["kwargs"]["fs"] + + self.fader_config = config["inference"]["fader"] + if attach_fader: + self.fader = parse_fader_config(config["inference"]["fader"]) + else: + self.fader = None + + self.augmentation: Optional[BaseAugmentor] + if config.get("augmentation", None) is not None: + self.augmentation = StemAugmentor(**config["augmentation"]) + else: + self.augmentation = None + + self.predict_output_path: Optional[str] = None + self.loss_adjustment = loss_adjustment + + self.val_prefix = None + self.test_prefix = None + + + def configure_optimizers(self) -> Any: + return parse_optimizer_config( + self.optimizer_config, + self.trainer.model.parameters() + ) + + def compute_loss(self, batch: BatchedDataDict, output: OutputType) -> Dict[ + str, torch.Tensor]: + return {"loss": self.loss(output, batch)} + + def 
update_metrics( + self, + batch: BatchedDataDict, + output: OutputType, + mode: str + ) -> None: + + if mode == "test": + metrics = self.test_metrics + else: + metrics = self.metrics + + for stem, metric in metrics.items(): + + if stem == "mne:+": + stem = "mne" + + # print(f"matching for {stem}") + if mode == "train": + metric.update( + output["audio"][stem],#.cpu(), + batch["audio"][stem],#.cpu() + ) + else: + if stem not in batch["audio"]: + matched = False + if stem in self._VOX_STEMS: + for bstem in self._VOX_STEMS: + if bstem in batch["audio"]: + batch["audio"][stem] = batch["audio"][bstem] + matched = True + break + elif stem in self._BG_STEMS: + for bstem in self._BG_STEMS: + if bstem in batch["audio"]: + batch["audio"][stem] = batch["audio"][bstem] + matched = True + break + else: + matched = True + + # print(batch["audio"].keys()) + + if matched: + # print(f"matched {stem}!") + if stem == "mne" and "mne" not in output["audio"]: + output["audio"]["mne"] = output["audio"]["music"] + output["audio"]["effects"] + + metric.update( + output["audio"][stem],#.cpu(), + batch["audio"][stem],#.cpu(), + ) + + # print(metric.compute()) + def compute_metrics(self, mode: str="dev") -> Dict[ + str, torch.Tensor]: + + if mode == "test": + metrics = self.test_metrics + else: + metrics = self.metrics + + metric_dict = {} + + for stem, metric in metrics.items(): + md = metric.compute() + metric_dict.update( + {f"{stem}/{k}": v for k, v in md.items()} + ) + + self.log_dict(metric_dict, prog_bar=True, logger=False) + + return metric_dict + + def reset_metrics(self, test_mode: bool = False) -> None: + + if test_mode: + metrics = self.test_metrics + else: + metrics = self.metrics + + for _, metric in metrics.items(): + metric.reset() + + + def forward(self, batch: BatchedDataDict) -> Any: + batch, output = self.model(batch) + + + return batch, output + + def common_step(self, batch: BatchedDataDict, mode: str) -> Any: + batch, output = self.forward(batch) + # print(batch) + # print(output) + loss_dict = self.compute_loss(batch, output) + + with torch.no_grad(): + self.update_metrics(batch, output, mode=mode) + + if mode == "train": + self.log("loss", loss_dict["loss"], prog_bar=True) + + return output, loss_dict + + + def training_step(self, batch: BatchedDataDict) -> Dict[str, Any]: + + if self.augmentation is not None: + with torch.no_grad(): + batch = self.augmentation(batch) + + _, loss_dict = self.common_step(batch, mode="train") + + with torch.inference_mode(): + self.log_dict_with_prefix( + loss_dict, + "train", + batch_size=batch["audio"]["mixture"].shape[0] + ) + + loss_dict["loss"] *= self.loss_adjustment + + return loss_dict + + def on_train_batch_end( + self, outputs: STEP_OUTPUT, batch: BatchedDataDict, batch_idx: int + ) -> None: + + metric_dict = self.compute_metrics() + self.log_dict_with_prefix(metric_dict, "train") + self.reset_metrics() + + def validation_step( + self, + batch: BatchedDataDict, + batch_idx: int, + dataloader_idx: int = 0 + ) -> Dict[str, Any]: + + with torch.inference_mode(): + curr_val_prefix = f"val{dataloader_idx}" if dataloader_idx > 0 else "val" + + if curr_val_prefix != self.val_prefix: + # print(f"Switching to validation dataloader {dataloader_idx}") + if self.val_prefix is not None: + self._on_validation_epoch_end() + self.val_prefix = curr_val_prefix + _, loss_dict = self.common_step(batch, mode="val") + + self.log_dict_with_prefix( + loss_dict, + self.val_prefix, + batch_size=batch["audio"]["mixture"].shape[0], + prog_bar=True, + add_dataloader_idx=False + ) 
+ + return loss_dict + + def on_validation_epoch_end(self) -> None: + self._on_validation_epoch_end() + + def _on_validation_epoch_end(self) -> None: + metric_dict = self.compute_metrics() + self.log_dict_with_prefix(metric_dict, self.val_prefix, prog_bar=True, + add_dataloader_idx=False) + # self.logger.save() + # print(self.val_prefix, "Validation metrics:", metric_dict) + self.reset_metrics() + + + def old_predtest_step( + self, + batch: BatchedDataDict, + batch_idx: int, + dataloader_idx: int = 0 + ) -> Tuple[BatchedDataDict, OutputType]: + + audio_batch = batch["audio"]["mixture"] + track_batch = batch.get("track", ["" for _ in range(len(audio_batch))]) + + output_list_of_dicts = [ + self.fader( + audio[None, ...], + lambda a: self.test_forward(a, track) + ) + for audio, track in zip(audio_batch, track_batch) + ] + + output_dict_of_lists = defaultdict(list) + + for output_dict in output_list_of_dicts: + for stem, audio in output_dict.items(): + output_dict_of_lists[stem].append(audio) + + output = { + "audio": { + stem: torch.concat(output_list, dim=0) + for stem, output_list in output_dict_of_lists.items() + } + } + + return batch, output + + def predtest_step( + self, + batch: BatchedDataDict, + batch_idx: int = -1, + dataloader_idx: int = 0 + ) -> Tuple[BatchedDataDict, OutputType]: + + if getattr(self.model, "bypass_fader", False): + batch, output = self.model(batch) + else: + audio_batch = batch["audio"]["mixture"] + output = self.fader( + audio_batch, + lambda a: self.test_forward(a, "", batch=batch) + ) + + return batch, output + + def test_forward( + self, + audio: torch.Tensor, + track: str = "", + batch: BatchedDataDict = None + ) -> torch.Tensor: + + if self.fader is None: + self.attach_fader() + + cond = batch.get("condition", None) + + if cond is not None and cond.shape[0] == 1: + cond = cond.repeat(audio.shape[0], 1) + + _, output = self.forward( + {"audio": {"mixture": audio}, + "track": track, + "condition": cond, + } + ) # TODO: support track properly + + return output["audio"] + + def on_test_epoch_start(self) -> None: + self.attach_fader(force_reattach=True) + + def test_step( + self, + batch: BatchedDataDict, + batch_idx: int, + dataloader_idx: int = 0 + ) -> Any: + curr_test_prefix = f"test{dataloader_idx}" + + # print(batch["audio"].keys()) + + if curr_test_prefix != self.test_prefix: + # print(f"Switching to test dataloader {dataloader_idx}") + if self.test_prefix is not None: + self._on_test_epoch_end() + self.test_prefix = curr_test_prefix + + with torch.inference_mode(): + _, output = self.predtest_step(batch, batch_idx, dataloader_idx) + # print(output) + self.update_metrics(batch, output, mode="test") + + return output + + def on_test_epoch_end(self) -> None: + self._on_test_epoch_end() + + def _on_test_epoch_end(self) -> None: + metric_dict = self.compute_metrics(mode="test") + self.log_dict_with_prefix(metric_dict, self.test_prefix, prog_bar=True, + add_dataloader_idx=False) + # self.logger.save() + # print(self.test_prefix, "Test metrics:", metric_dict) + self.reset_metrics() + + def predict_step( + self, + batch: BatchedDataDict, + batch_idx: int = 0, + dataloader_idx: int = 0, + include_track_name: Optional[bool] = None, + get_no_vox_combinations: bool = True, + get_residual: bool = False, + treat_batch_as_channels: bool = False, + fs: Optional[int] = None, + ) -> Any: + assert self.predict_output_path is not None + + batch_size = batch["audio"]["mixture"].shape[0] + + if include_track_name is None: + include_track_name = batch_size > 1 + + with 
torch.inference_mode(): + batch, output = self.predtest_step(batch, batch_idx, dataloader_idx) + print('Pred test finished...') + torch.cuda.empty_cache() + metric_dict = {} + + if get_residual: + mixture = batch["audio"]["mixture"] + extracted = sum([output["audio"][stem] for stem in output["audio"]]) + residual = mixture - extracted + print(extracted.shape, mixture.shape, residual.shape) + + output["audio"]["residual"] = residual + + if get_no_vox_combinations: + no_vox_stems = [ + stem for stem in output["audio"] if + stem not in self._VOX_STEMS + ] + no_vox_combinations = chain.from_iterable( + combinations(no_vox_stems, r) for r in + range(2, len(no_vox_stems) + 1) + ) + + for combination in no_vox_combinations: + combination_ = list(combination) + output["audio"]["+".join(combination_)] = sum( + [output["audio"][stem] for stem in combination_] + ) + + if treat_batch_as_channels: + for stem in output["audio"]: + output["audio"][stem] = output["audio"][stem].reshape( + 1, -1, output["audio"][stem].shape[-1] + ) + batch_size = 1 + + for b in range(batch_size): + print("!!", b) + for stem in output["audio"]: + print(f"Saving audio for {stem} to {self.predict_output_path}") + track_name = batch["track"][b].split("/")[-1] + + if batch.get("audio", {}).get(stem, None) is not None: + self.test_metrics[stem].reset() + metrics = self.test_metrics[stem]( + batch["audio"][stem][[b], ...], + output["audio"][stem][[b], ...] + ) + snr = metrics["snr"] + sisnr = metrics["sisnr"] + sdr = metrics["sdr"] + metric_dict[stem] = metrics + print( + track_name, + f"snr={snr:2.2f} dB", + f"sisnr={sisnr:2.2f}", + f"sdr={sdr:2.2f} dB", + ) + filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav" + else: + filename = f"{stem}.wav" + + if include_track_name: + output_dir = os.path.join( + self.predict_output_path, + track_name + ) + else: + output_dir = self.predict_output_path + + os.makedirs(output_dir, exist_ok=True) + + if fs is None: + fs = self.fs + + ta.save( + os.path.join(output_dir, filename), + output["audio"][stem][b, ...].cpu(), + fs, + ) + + return metric_dict + + def get_stems( + self, + batch: BatchedDataDict, + batch_idx: int = 0, + dataloader_idx: int = 0, + include_track_name: Optional[bool] = None, + get_no_vox_combinations: bool = True, + get_residual: bool = False, + treat_batch_as_channels: bool = False, + fs: Optional[int] = None, + ) -> Any: + assert self.predict_output_path is not None + + batch_size = batch["audio"]["mixture"].shape[0] + + if include_track_name is None: + include_track_name = batch_size > 1 + + with torch.inference_mode(): + batch, output = self.predtest_step(batch, batch_idx, dataloader_idx) + torch.cuda.empty_cache() + metric_dict = {} + + if get_residual: + mixture = batch["audio"]["mixture"] + extracted = sum([output["audio"][stem] for stem in output["audio"]]) + residual = mixture - extracted + # print(extracted.shape, mixture.shape, residual.shape) + + output["audio"]["residual"] = residual + + if get_no_vox_combinations: + no_vox_stems = [ + stem for stem in output["audio"] if + stem not in self._VOX_STEMS + ] + no_vox_combinations = chain.from_iterable( + combinations(no_vox_stems, r) for r in + range(2, len(no_vox_stems) + 1) + ) + + for combination in no_vox_combinations: + combination_ = list(combination) + output["audio"]["+".join(combination_)] = sum( + [output["audio"][stem] for stem in combination_] + ) + + if treat_batch_as_channels: + for stem in output["audio"]: + output["audio"][stem] = output["audio"][stem].reshape( + 1, -1, 
output["audio"][stem].shape[-1] + ) + batch_size = 1 + + result = {} + for b in range(batch_size): + for stem in output["audio"]: + track_name = batch["track"][b].split("/")[-1] + + if batch.get("audio", {}).get(stem, None) is not None: + self.test_metrics[stem].reset() + metrics = self.test_metrics[stem]( + batch["audio"][stem][[b], ...], + output["audio"][stem][[b], ...] + ) + snr = metrics["snr"] + sisnr = metrics["sisnr"] + sdr = metrics["sdr"] + metric_dict[stem] = metrics + print( + track_name, + f"snr={snr:2.2f} dB", + f"sisnr={sisnr:2.2f}", + f"sdr={sdr:2.2f} dB", + ) + filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav" + else: + filename = f"{stem}.wav" + + if include_track_name: + output_dir = os.path.join( + self.predict_output_path, + track_name + ) + else: + output_dir = self.predict_output_path + + os.makedirs(output_dir, exist_ok=True) + + if fs is None: + fs = self.fs + + result[stem] = output["audio"][stem][b, ...].cpu().numpy() + + return result + + def load_state_dict( + self, state_dict: Mapping[str, Any], strict: bool = False + ) -> Any: + + return super().load_state_dict(state_dict, strict=False) + + + def set_predict_output_path(self, path: str) -> None: + self.predict_output_path = path + os.makedirs(self.predict_output_path, exist_ok=True) + + self.attach_fader() + + def attach_fader(self, force_reattach=False) -> None: + if self.fader is None or force_reattach: + self.fader = parse_fader_config(self.fader_config) + self.fader.to(self.device) + + + def log_dict_with_prefix( + self, + dict_: Dict[str, torch.Tensor], + prefix: str, + batch_size: Optional[int] = None, + **kwargs: Any + ) -> None: + self.log_dict( + {f"{prefix}/{k}": v for k, v in dict_.items()}, + batch_size=batch_size, + logger=True, + sync_dist=True, + **kwargs, + ) \ No newline at end of file diff --git a/separator/models/bandit/core/data/__init__.py b/separator/models/bandit/core/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1087fe2c4d7d3048295cdf73c0725a015bc0d129 --- /dev/null +++ b/separator/models/bandit/core/data/__init__.py @@ -0,0 +1,2 @@ +from .dnr.datamodule import DivideAndRemasterDataModule +from .musdb.datamodule import MUSDB18DataModule \ No newline at end of file diff --git a/separator/models/bandit/core/data/_types.py b/separator/models/bandit/core/data/_types.py new file mode 100644 index 0000000000000000000000000000000000000000..9499f9a80b5dec7b5b0e7882849e4f7b2c801ccf --- /dev/null +++ b/separator/models/bandit/core/data/_types.py @@ -0,0 +1,18 @@ +from typing import Dict, Sequence, TypedDict + +import torch + +AudioDict = Dict[str, torch.Tensor] + +DataDict = TypedDict('DataDict', {'audio': AudioDict, 'track': str}) + +BatchedDataDict = TypedDict( + 'BatchedDataDict', + {'audio': AudioDict, 'track': Sequence[str]} +) + + +class DataDictWithLanguage(TypedDict): + audio: AudioDict + track: str + language: str diff --git a/separator/models/bandit/core/data/augmentation.py b/separator/models/bandit/core/data/augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..238214bf17a69e71f48e8761e1ead05b17d0fa5a --- /dev/null +++ b/separator/models/bandit/core/data/augmentation.py @@ -0,0 +1,107 @@ +from abc import ABC +from typing import Any, Dict, Union + +import torch +import torch_audiomentations as tam +from torch import nn + +from models.bandit.core.data._types import BatchedDataDict, DataDict + + +class BaseAugmentor(nn.Module, ABC): + def forward(self, item: Union[DataDict, BatchedDataDict]) -> Union[ + 
DataDict, BatchedDataDict]: + raise NotImplementedError + + +class StemAugmentor(BaseAugmentor): + def __init__( + self, + audiomentations: Dict[str, Dict[str, Any]], + fix_clipping: bool = True, + scaler_margin: float = 0.5, + apply_both_default_and_common: bool = False, + ) -> None: + super().__init__() + + augmentations = {} + + self.has_default = "[default]" in audiomentations + self.has_common = "[common]" in audiomentations + self.apply_both_default_and_common = apply_both_default_and_common + + for stem in audiomentations: + if audiomentations[stem]["name"] == "Compose": + augmentations[stem] = getattr( + tam, + audiomentations[stem]["name"] + )( + [ + getattr(tam, aug["name"])(**aug["kwargs"]) + for aug in + audiomentations[stem]["kwargs"]["transforms"] + ], + **audiomentations[stem]["kwargs"]["kwargs"], + ) + else: + augmentations[stem] = getattr( + tam, + audiomentations[stem]["name"] + )( + **audiomentations[stem]["kwargs"] + ) + + self.augmentations = nn.ModuleDict(augmentations) + self.fix_clipping = fix_clipping + self.scaler_margin = scaler_margin + + def check_and_fix_clipping( + self, item: Union[DataDict, BatchedDataDict] + ) -> Union[DataDict, BatchedDataDict]: + max_abs = [] + + for stem in item["audio"]: + max_abs.append(item["audio"][stem].abs().max().item()) + + if max(max_abs) > 1.0: + scaler = 1.0 / (max(max_abs) + torch.rand( + (1,), + device=item["audio"]["mixture"].device + ) * self.scaler_margin) + + for stem in item["audio"]: + item["audio"][stem] *= scaler + + return item + + def forward(self, item: Union[DataDict, BatchedDataDict]) -> Union[ + DataDict, BatchedDataDict]: + + for stem in item["audio"]: + if stem == "mixture": + continue + + if self.has_common: + item["audio"][stem] = self.augmentations["[common]"]( + item["audio"][stem] + ).samples + + if stem in self.augmentations: + item["audio"][stem] = self.augmentations[stem]( + item["audio"][stem] + ).samples + elif self.has_default: + if not self.has_common or self.apply_both_default_and_common: + item["audio"][stem] = self.augmentations["[default]"]( + item["audio"][stem] + ).samples + + item["audio"]["mixture"] = sum( + [item["audio"][stem] for stem in item["audio"] + if stem != "mixture"] + ) # type: ignore[call-overload, assignment] + + if self.fix_clipping: + item = self.check_and_fix_clipping(item) + + return item diff --git a/separator/models/bandit/core/data/augmented.py b/separator/models/bandit/core/data/augmented.py new file mode 100644 index 0000000000000000000000000000000000000000..84d19599a6579eb5afd304ef6da76a6cbca49045 --- /dev/null +++ b/separator/models/bandit/core/data/augmented.py @@ -0,0 +1,35 @@ +import warnings +from typing import Dict, Optional, Union + +import torch +from torch import nn +from torch.utils import data + + +class AugmentedDataset(data.Dataset): + def __init__( + self, + dataset: data.Dataset, + augmentation: nn.Module = nn.Identity(), + target_length: Optional[int] = None, + ) -> None: + warnings.warn( + "This class is no longer used. 
Attach augmentation to " + "the LightningSystem instead.", + DeprecationWarning, + ) + + self.dataset = dataset + self.augmentation = augmentation + + self.ds_length: int = len(dataset) # type: ignore[arg-type] + self.length = target_length if target_length is not None else self.ds_length + + def __getitem__(self, index: int) -> Dict[str, Union[str, Dict[str, + torch.Tensor]]]: + item = self.dataset[index % self.ds_length] + item = self.augmentation(item) + return item + + def __len__(self) -> int: + return self.length diff --git a/separator/models/bandit/core/data/base.py b/separator/models/bandit/core/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..a7b6c33a85b93c32209138e3d21bfc8e0f270cac --- /dev/null +++ b/separator/models/bandit/core/data/base.py @@ -0,0 +1,69 @@ +import os +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + +import numpy as np +import pedalboard as pb +import torch +import torchaudio as ta +from torch.utils import data + +from models.bandit.core.data._types import AudioDict, DataDict + + +class BaseSourceSeparationDataset(data.Dataset, ABC): + def __init__( + self, split: str, + stems: List[str], + files: List[str], + data_path: str, + fs: int, + npy_memmap: bool, + recompute_mixture: bool + ): + self.split = split + self.stems = stems + self.stems_no_mixture = [s for s in stems if s != "mixture"] + self.files = files + self.data_path = data_path + self.fs = fs + self.npy_memmap = npy_memmap + self.recompute_mixture = recompute_mixture + + @abstractmethod + def get_stem( + self, + *, + stem: str, + identifier: Dict[str, Any] + ) -> torch.Tensor: + raise NotImplementedError + + def _get_audio(self, stems, identifier: Dict[str, Any]): + audio = {} + for stem in stems: + audio[stem] = self.get_stem(stem=stem, identifier=identifier) + + return audio + + def get_audio(self, identifier: Dict[str, Any]) -> AudioDict: + + if self.recompute_mixture: + audio = self._get_audio( + self.stems_no_mixture, + identifier=identifier + ) + audio["mixture"] = self.compute_mixture(audio) + return audio + else: + return self._get_audio(self.stems, identifier=identifier) + + @abstractmethod + def get_identifier(self, index: int) -> Dict[str, Any]: + pass + + def compute_mixture(self, audio: AudioDict) -> torch.Tensor: + + return sum( + audio[stem] for stem in audio if stem != "mixture" + ) diff --git a/separator/models/bandit/core/data/dnr/__init__.py b/separator/models/bandit/core/data/dnr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/separator/models/bandit/core/data/dnr/datamodule.py b/separator/models/bandit/core/data/dnr/datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..dc5550608aabf460eb1781576112ed60185dd318 --- /dev/null +++ b/separator/models/bandit/core/data/dnr/datamodule.py @@ -0,0 +1,74 @@ +import os +from typing import Mapping, Optional + +import pytorch_lightning as pl + +from .dataset import ( + DivideAndRemasterDataset, + DivideAndRemasterDeterministicChunkDataset, + DivideAndRemasterRandomChunkDataset, + DivideAndRemasterRandomChunkDatasetWithSpeechReverb +) + + +def DivideAndRemasterDataModule( + data_root: str = "$DATA_ROOT/DnR/v2", + batch_size: int = 2, + num_workers: int = 8, + train_kwargs: Optional[Mapping] = None, + val_kwargs: Optional[Mapping] = None, + test_kwargs: Optional[Mapping] = None, + datamodule_kwargs: Optional[Mapping] = None, + use_speech_reverb: bool = False + # 
augmentor=None +) -> pl.LightningDataModule: + if train_kwargs is None: + train_kwargs = {} + + if val_kwargs is None: + val_kwargs = {} + + if test_kwargs is None: + test_kwargs = {} + + if datamodule_kwargs is None: + datamodule_kwargs = {} + + if num_workers is None: + num_workers = os.cpu_count() + + if num_workers is None: + num_workers = 32 + + num_workers = min(num_workers, 64) + + if use_speech_reverb: + train_cls = DivideAndRemasterRandomChunkDatasetWithSpeechReverb + else: + train_cls = DivideAndRemasterRandomChunkDataset + + train_dataset = train_cls( + data_root, "train", **train_kwargs + ) + + # if augmentor is not None: + # train_dataset = AugmentedDataset(train_dataset, augmentor) + + datamodule = pl.LightningDataModule.from_datasets( + train_dataset=train_dataset, + val_dataset=DivideAndRemasterDeterministicChunkDataset( + data_root, "val", **val_kwargs + ), + test_dataset=DivideAndRemasterDataset( + data_root, + "test", + **test_kwargs + ), + batch_size=batch_size, + num_workers=num_workers, + **datamodule_kwargs + ) + + datamodule.predict_dataloader = datamodule.test_dataloader # type: ignore[method-assign] + + return datamodule diff --git a/separator/models/bandit/core/data/dnr/dataset.py b/separator/models/bandit/core/data/dnr/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7b241cf0dd474eafbfc9db3ec2f4987d12596de4 --- /dev/null +++ b/separator/models/bandit/core/data/dnr/dataset.py @@ -0,0 +1,392 @@ +import os +from abc import ABC +from typing import Any, Dict, List, Optional + +import numpy as np +import pedalboard as pb +import torch +import torchaudio as ta +from torch.utils import data + +from models.bandit.core.data._types import AudioDict, DataDict +from models.bandit.core.data.base import BaseSourceSeparationDataset + + +class DivideAndRemasterBaseDataset(BaseSourceSeparationDataset, ABC): + ALLOWED_STEMS = ["mixture", "speech", "music", "effects", "mne"] + STEM_NAME_MAP = { + "mixture": "mix", + "speech": "speech", + "music": "music", + "effects": "sfx", + } + SPLIT_NAME_MAP = {"train": "tr", "val": "cv", "test": "tt"} + + FULL_TRACK_LENGTH_SECOND = 60 + FULL_TRACK_LENGTH_SAMPLES = FULL_TRACK_LENGTH_SECOND * 44100 + + def __init__( + self, + split: str, + stems: List[str], + files: List[str], + data_path: str, + fs: int = 44100, + npy_memmap: bool = True, + recompute_mixture: bool = False, + ) -> None: + super().__init__( + split=split, + stems=stems, + files=files, + data_path=data_path, + fs=fs, + npy_memmap=npy_memmap, + recompute_mixture=recompute_mixture + ) + + def get_stem( + self, + *, + stem: str, + identifier: Dict[str, Any] + ) -> torch.Tensor: + + if stem == "mne": + return self.get_stem( + stem="music", + identifier=identifier) + self.get_stem( + stem="effects", + identifier=identifier) + + track = identifier["track"] + path = os.path.join(self.data_path, track) + + if self.npy_memmap: + audio = np.load( + os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.npy"), + mmap_mode="r" + ) + else: + # noinspection PyUnresolvedReferences + audio, _ = ta.load( + os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.wav") + ) + + return audio + + def get_identifier(self, index): + return dict(track=self.files[index]) + + def __getitem__(self, index: int) -> DataDict: + identifier = self.get_identifier(index) + audio = self.get_audio(identifier) + + return {"audio": audio, "track": f"{self.split}/{identifier['track']}"} + + +class DivideAndRemasterDataset(DivideAndRemasterBaseDataset): + def __init__( + self, + data_root: str, + split: 
str, + stems: Optional[List[str]] = None, + fs: int = 44100, + npy_memmap: bool = True, + ) -> None: + + if stems is None: + stems = self.ALLOWED_STEMS + self.stems = stems + + data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split]) + + files = sorted(os.listdir(data_path)) + files = [ + f + for f in files + if (not f.startswith(".")) and os.path.isdir( + os.path.join(data_path, f) + ) + ] + # pprint(list(enumerate(files))) + if split == "train": + assert len(files) == 3406, len(files) + elif split == "val": + assert len(files) == 487, len(files) + elif split == "test": + assert len(files) == 973, len(files) + + self.n_tracks = len(files) + + super().__init__( + data_path=data_path, + split=split, + stems=stems, + files=files, + fs=fs, + npy_memmap=npy_memmap, + ) + + def __len__(self) -> int: + return self.n_tracks + + +class DivideAndRemasterRandomChunkDataset(DivideAndRemasterBaseDataset): + def __init__( + self, + data_root: str, + split: str, + target_length: int, + chunk_size_second: float, + stems: Optional[List[str]] = None, + fs: int = 44100, + npy_memmap: bool = True, + ) -> None: + + if stems is None: + stems = self.ALLOWED_STEMS + self.stems = stems + + data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split]) + + files = sorted(os.listdir(data_path)) + files = [ + f + for f in files + if (not f.startswith(".")) and os.path.isdir( + os.path.join(data_path, f) + ) + ] + + if split == "train": + assert len(files) == 3406, len(files) + elif split == "val": + assert len(files) == 487, len(files) + elif split == "test": + assert len(files) == 973, len(files) + + self.n_tracks = len(files) + + self.target_length = target_length + self.chunk_size = int(chunk_size_second * fs) + + super().__init__( + data_path=data_path, + split=split, + stems=stems, + files=files, + fs=fs, + npy_memmap=npy_memmap, + ) + + def __len__(self) -> int: + return self.target_length + + def get_identifier(self, index): + return super().get_identifier(index % self.n_tracks) + + def get_stem( + self, + *, + stem: str, + identifier: Dict[str, Any], + chunk_here: bool = False, + ) -> torch.Tensor: + + stem = super().get_stem( + stem=stem, + identifier=identifier + ) + + if chunk_here: + start = np.random.randint( + 0, + self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size + ) + end = start + self.chunk_size + + stem = stem[:, start:end] + + return stem + + def __getitem__(self, index: int) -> DataDict: + identifier = self.get_identifier(index) + # self.index_lock = index + audio = self.get_audio(identifier) + # self.index_lock = None + + start = np.random.randint( + 0, + self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size + ) + end = start + self.chunk_size + + audio = { + k: v[:, start:end] for k, v in audio.items() + } + + return {"audio": audio, "track": f"{self.split}/{identifier['track']}"} + + +class DivideAndRemasterDeterministicChunkDataset(DivideAndRemasterBaseDataset): + def __init__( + self, + data_root: str, + split: str, + chunk_size_second: float, + hop_size_second: float, + stems: Optional[List[str]] = None, + fs: int = 44100, + npy_memmap: bool = True, + ) -> None: + + if stems is None: + stems = self.ALLOWED_STEMS + self.stems = stems + + data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split]) + + files = sorted(os.listdir(data_path)) + files = [ + f + for f in files + if (not f.startswith(".")) and os.path.isdir( + os.path.join(data_path, f) + ) + ] + # pprint(list(enumerate(files))) + if split == "train": + assert len(files) == 3406, len(files) + elif split == "val": + assert 
len(files) == 487, len(files) + elif split == "test": + assert len(files) == 973, len(files) + + self.n_tracks = len(files) + + self.chunk_size = int(chunk_size_second * fs) + self.hop_size = int(hop_size_second * fs) + self.n_chunks_per_track = int( + ( + self.FULL_TRACK_LENGTH_SECOND - chunk_size_second) / hop_size_second + ) + + self.length = self.n_tracks * self.n_chunks_per_track + + super().__init__( + data_path=data_path, + split=split, + stems=stems, + files=files, + fs=fs, + npy_memmap=npy_memmap, + ) + + def get_identifier(self, index): + return super().get_identifier(index % self.n_tracks) + + def __len__(self) -> int: + return self.length + + def __getitem__(self, item: int) -> DataDict: + + index = item % self.n_tracks + chunk = item // self.n_tracks + + data_ = super().__getitem__(index) + + audio = data_["audio"] + + start = chunk * self.hop_size + end = start + self.chunk_size + + for stem in self.stems: + data_["audio"][stem] = audio[stem][:, start:end] + + return data_ + + +class DivideAndRemasterRandomChunkDatasetWithSpeechReverb( + DivideAndRemasterRandomChunkDataset +): + def __init__( + self, + data_root: str, + split: str, + target_length: int, + chunk_size_second: float, + stems: Optional[List[str]] = None, + fs: int = 44100, + npy_memmap: bool = True, + ) -> None: + + if stems is None: + stems = self.ALLOWED_STEMS + + stems_no_mixture = [s for s in stems if s != "mixture"] + + super().__init__( + data_root=data_root, + split=split, + target_length=target_length, + chunk_size_second=chunk_size_second, + stems=stems_no_mixture, + fs=fs, + npy_memmap=npy_memmap, + ) + + self.stems = stems + self.stems_no_mixture = stems_no_mixture + + def __getitem__(self, index: int) -> DataDict: + + data_ = super().__getitem__(index) + + dry = data_["audio"]["speech"][:] + n_samples = dry.shape[-1] + + wet_level = np.random.rand() + + speech = pb.Reverb( + room_size=np.random.rand(), + damping=np.random.rand(), + wet_level=wet_level, + dry_level=(1 - wet_level), + width=np.random.rand() + ).process(dry, self.fs, buffer_size=8192 * 4)[..., :n_samples] + + data_["audio"]["speech"] = speech + + data_["audio"]["mixture"] = sum( + [data_["audio"][s] for s in self.stems_no_mixture] + ) + + return data_ + + def __len__(self) -> int: + return super().__len__() + + +if __name__ == "__main__": + + from pprint import pprint + from tqdm.auto import tqdm + + for split_ in ["train", "val", "test"]: + ds = DivideAndRemasterRandomChunkDatasetWithSpeechReverb( + data_root="$DATA_ROOT/DnR/v2np", + split=split_, + target_length=100, + chunk_size_second=6.0 + ) + + print(split_, len(ds)) + + for track_ in tqdm(ds): # type: ignore + pprint(track_) + track_["audio"] = {k: v.shape for k, v in track_["audio"].items()} + pprint(track_) + # break + + break diff --git a/separator/models/bandit/core/data/dnr/preprocess.py b/separator/models/bandit/core/data/dnr/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..9d0b58690f3bae726b0655dbade6393c89bf8c9e --- /dev/null +++ b/separator/models/bandit/core/data/dnr/preprocess.py @@ -0,0 +1,54 @@ +import glob +import os +from typing import Tuple + +import numpy as np +import torchaudio as ta +from tqdm.contrib.concurrent import process_map + + +def process_one(inputs: Tuple[str, str, int]) -> None: + infile, outfile, target_fs = inputs + + dir = os.path.dirname(outfile) + os.makedirs(dir, exist_ok=True) + + data, fs = ta.load(infile) + + if fs != target_fs: + data = ta.functional.resample(data, fs, target_fs, 
resampling_method="sinc_interp_kaiser") + fs = target_fs + + data = data.numpy() + data = data.astype(np.float32) + + if os.path.exists(outfile): + data_ = np.load(outfile) + if np.allclose(data, data_): + return + + np.save(outfile, data) + + +def preprocess( + data_path: str, + output_path: str, + fs: int +) -> None: + files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True) + print(files) + outfiles = [ + f.replace(data_path, output_path).replace(".wav", ".npy") for f in + files + ] + + os.makedirs(output_path, exist_ok=True) + inputs = list(zip(files, outfiles, [fs] * len(files))) + + process_map(process_one, inputs, chunksize=32) + + +if __name__ == "__main__": + import fire + + fire.Fire() diff --git a/separator/models/bandit/core/data/musdb/__init__.py b/separator/models/bandit/core/data/musdb/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/separator/models/bandit/core/data/musdb/datamodule.py b/separator/models/bandit/core/data/musdb/datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..a8984daebd535b25f0551d348c91dbd1702fb9da --- /dev/null +++ b/separator/models/bandit/core/data/musdb/datamodule.py @@ -0,0 +1,77 @@ +import os.path +from typing import Mapping, Optional + +import pytorch_lightning as pl + +from models.bandit.core.data.musdb.dataset import ( + MUSDB18BaseDataset, + MUSDB18FullTrackDataset, + MUSDB18SadDataset, + MUSDB18SadOnTheFlyAugmentedDataset +) + + +def MUSDB18DataModule( + data_root: str = "$DATA_ROOT/MUSDB18/HQ", + target_stem: str = "vocals", + batch_size: int = 2, + num_workers: int = 8, + train_kwargs: Optional[Mapping] = None, + val_kwargs: Optional[Mapping] = None, + test_kwargs: Optional[Mapping] = None, + datamodule_kwargs: Optional[Mapping] = None, + use_on_the_fly: bool = True, + npy_memmap: bool = True +) -> pl.LightningDataModule: + if train_kwargs is None: + train_kwargs = {} + + if val_kwargs is None: + val_kwargs = {} + + if test_kwargs is None: + test_kwargs = {} + + if datamodule_kwargs is None: + datamodule_kwargs = {} + + train_dataset: MUSDB18BaseDataset + + if use_on_the_fly: + train_dataset = MUSDB18SadOnTheFlyAugmentedDataset( + data_root=os.path.join(data_root, "saded-np"), + split="train", + target_stem=target_stem, + **train_kwargs + ) + else: + train_dataset = MUSDB18SadDataset( + data_root=os.path.join(data_root, "saded-np"), + split="train", + target_stem=target_stem, + **train_kwargs + ) + + datamodule = pl.LightningDataModule.from_datasets( + train_dataset=train_dataset, + val_dataset=MUSDB18SadDataset( + data_root=os.path.join(data_root, "saded-np"), + split="val", + target_stem=target_stem, + **val_kwargs + ), + test_dataset=MUSDB18FullTrackDataset( + data_root=os.path.join(data_root, "canonical"), + split="test", + **test_kwargs + ), + batch_size=batch_size, + num_workers=num_workers, + **datamodule_kwargs + ) + + datamodule.predict_dataloader = ( # type: ignore[method-assign] + datamodule.test_dataloader + ) + + return datamodule diff --git a/separator/models/bandit/core/data/musdb/dataset.py b/separator/models/bandit/core/data/musdb/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..53f5d9afdfe383600b5f89767c4ef1f4b54f4a47 --- /dev/null +++ b/separator/models/bandit/core/data/musdb/dataset.py @@ -0,0 +1,280 @@ +import os +from abc import ABC +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torchaudio as ta +from torch.utils import 
data + +from models.bandit.core.data._types import AudioDict, DataDict +from models.bandit.core.data.base import BaseSourceSeparationDataset + + +class MUSDB18BaseDataset(BaseSourceSeparationDataset, ABC): + + ALLOWED_STEMS = ["mixture", "vocals", "bass", "drums", "other"] + + def __init__( + self, + split: str, + stems: List[str], + files: List[str], + data_path: str, + fs: int = 44100, + npy_memmap=False, + ) -> None: + super().__init__( + split=split, + stems=stems, + files=files, + data_path=data_path, + fs=fs, + npy_memmap=npy_memmap, + recompute_mixture=False + ) + + def get_stem(self, *, stem: str, identifier) -> torch.Tensor: + track = identifier["track"] + path = os.path.join(self.data_path, track) + # noinspection PyUnresolvedReferences + + if self.npy_memmap: + audio = np.load(os.path.join(path, f"{stem}.wav.npy"), mmap_mode="r") + else: + audio, _ = ta.load(os.path.join(path, f"{stem}.wav")) + + return audio + + def get_identifier(self, index): + return dict(track=self.files[index]) + + def __getitem__(self, index: int) -> DataDict: + identifier = self.get_identifier(index) + audio = self.get_audio(identifier) + + return {"audio": audio, "track": f"{self.split}/{identifier['track']}"} + + +class MUSDB18FullTrackDataset(MUSDB18BaseDataset): + + N_TRAIN_TRACKS = 100 + N_TEST_TRACKS = 50 + VALIDATION_FILES = [ + "Actions - One Minute Smile", + "Clara Berry And Wooldog - Waltz For My Victims", + "Johnny Lokke - Promises & Lies", + "Patrick Talbot - A Reason To Leave", + "Triviul - Angelsaint", + "Alexander Ross - Goodbye Bolero", + "Fergessen - Nos Palpitants", + "Leaf - Summerghost", + "Skelpolu - Human Mistakes", + "Young Griffo - Pennies", + "ANiMAL - Rockshow", + "James May - On The Line", + "Meaxic - Take A Step", + "Traffic Experiment - Sirens", + ] + + def __init__( + self, data_root: str, split: str, stems: Optional[List[ + str]] = None + ) -> None: + + if stems is None: + stems = self.ALLOWED_STEMS + self.stems = stems + + if split == "test": + subset = "test" + elif split in ["train", "val"]: + subset = "train" + else: + raise NameError + + data_path = os.path.join(data_root, subset) + + files = sorted(os.listdir(data_path)) + files = [f for f in files if not f.startswith(".")] + # pprint(list(enumerate(files))) + if subset == "train": + assert len(files) == 100, len(files) + if split == "train": + files = [f for f in files if f not in self.VALIDATION_FILES] + assert len(files) == 100 - len(self.VALIDATION_FILES) + else: + files = [f for f in files if f in self.VALIDATION_FILES] + assert len(files) == len(self.VALIDATION_FILES) + else: + split = "test" + assert len(files) == 50 + + self.n_tracks = len(files) + + super().__init__( + data_path=data_path, + split=split, + stems=stems, + files=files + ) + + def __len__(self) -> int: + return self.n_tracks + +class MUSDB18SadDataset(MUSDB18BaseDataset): + def __init__( + self, + data_root: str, + split: str, + target_stem: str, + stems: Optional[List[str]] = None, + target_length: Optional[int] = None, + npy_memmap=False, + ) -> None: + + if stems is None: + stems = self.ALLOWED_STEMS + + data_path = os.path.join(data_root, target_stem, split) + + files = sorted(os.listdir(data_path)) + files = [f for f in files if not f.startswith(".")] + + super().__init__( + data_path=data_path, + split=split, + stems=stems, + files=files, + npy_memmap=npy_memmap + ) + self.n_segments = len(files) + self.target_stem = target_stem + self.target_length = ( + target_length if target_length is not None else self.n_segments + ) + + def 
__len__(self) -> int: + return self.target_length + + def __getitem__(self, index: int) -> DataDict: + + index = index % self.n_segments + + return super().__getitem__(index) + + def get_identifier(self, index): + return super().get_identifier(index % self.n_segments) + + +class MUSDB18SadOnTheFlyAugmentedDataset(MUSDB18SadDataset): + def __init__( + self, + data_root: str, + split: str, + target_stem: str, + stems: Optional[List[str]] = None, + target_length: int = 20000, + apply_probability: Optional[float] = None, + chunk_size_second: float = 3.0, + random_scale_range_db: Tuple[float, float] = (-10, 10), + drop_probability: float = 0.1, + rescale: bool = True, + ) -> None: + super().__init__(data_root, split, target_stem, stems) + + if apply_probability is None: + apply_probability = ( + target_length - self.n_segments) / target_length + + self.apply_probability = apply_probability + self.drop_probability = drop_probability + self.chunk_size_second = chunk_size_second + self.random_scale_range_db = random_scale_range_db + self.rescale = rescale + + self.chunk_size_sample = int(self.chunk_size_second * self.fs) + self.target_length = target_length + + def __len__(self) -> int: + return self.target_length + + def __getitem__(self, index: int) -> DataDict: + + index = index % self.n_segments + + # if np.random.rand() > self.apply_probability: + # return super().__getitem__(index) + + audio = {} + identifier = self.get_identifier(index) + + # assert self.target_stem in self.stems_no_mixture + for stem in self.stems_no_mixture: + if stem == self.target_stem: + identifier_ = identifier + else: + if np.random.rand() < self.apply_probability: + index_ = np.random.randint(self.n_segments) + identifier_ = self.get_identifier(index_) + else: + identifier_ = identifier + + audio[stem] = self.get_stem(stem=stem, identifier=identifier_) + + # if stem == self.target_stem: + + if self.chunk_size_sample < audio[stem].shape[-1]: + chunk_start = np.random.randint( + audio[stem].shape[-1] - self.chunk_size_sample + ) + else: + chunk_start = 0 + + if np.random.rand() < self.drop_probability: + # db_scale = "-inf" + linear_scale = 0.0 + else: + db_scale = np.random.uniform(*self.random_scale_range_db) + linear_scale = np.power(10, db_scale / 20) + # db_scale = f"{db_scale:+2.1f}" + # print(linear_scale) + audio[stem][..., + chunk_start: chunk_start + self.chunk_size_sample] = ( + linear_scale + * audio[stem][..., + chunk_start: chunk_start + self.chunk_size_sample] + ) + + audio["mixture"] = self.compute_mixture(audio) + + if self.rescale: + max_abs_val = max( + [torch.max(torch.abs(audio[stem])) for stem in self.stems] + ) # type: ignore[type-var] + if max_abs_val > 1: + audio = {k: v / max_abs_val for k, v in audio.items()} + + track = identifier["track"] + + return {"audio": audio, "track": f"{self.split}/{track}"} + +# if __name__ == "__main__": +# +# from pprint import pprint +# from tqdm.auto import tqdm +# +# for split_ in ["train", "val", "test"]: +# ds = MUSDB18SadOnTheFlyAugmentedDataset( +# data_root="$DATA_ROOT/MUSDB18/HQ/saded", +# split=split_, +# target_stem="vocals" +# ) +# +# print(split_, len(ds)) +# +# for track_ in tqdm(ds): +# track_["audio"] = { +# k: v.shape for k, v in track_["audio"].items() +# } +# pprint(track_) diff --git a/separator/models/bandit/core/data/musdb/preprocess.py b/separator/models/bandit/core/data/musdb/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..5d5892e5c3f4acef5bbb183a746d76475c461810 --- /dev/null +++ 
b/separator/models/bandit/core/data/musdb/preprocess.py @@ -0,0 +1,238 @@ +import glob +import os + +import numpy as np +import torch +import torchaudio as ta +from torch import nn +from torch.nn import functional as F +from tqdm.contrib.concurrent import process_map + +from core.data._types import DataDict +from core.data.musdb.dataset import MUSDB18FullTrackDataset +import pyloudnorm as pyln + +class SourceActivityDetector(nn.Module): + def __init__( + self, + analysis_stem: str, + output_path: str, + fs: int = 44100, + segment_length_second: float = 6.0, + hop_length_second: float = 3.0, + n_chunks: int = 10, + chunk_epsilon: float = 1e-5, + energy_threshold_quantile: float = 0.15, + segment_epsilon: float = 1e-3, + salient_proportion_threshold: float = 0.5, + target_lufs: float = -24 + ) -> None: + super().__init__() + + self.fs = fs + self.segment_length = int(segment_length_second * self.fs) + self.hop_length = int(hop_length_second * self.fs) + self.n_chunks = n_chunks + assert self.segment_length % self.n_chunks == 0 + self.chunk_size = self.segment_length // self.n_chunks + self.chunk_epsilon = chunk_epsilon + self.energy_threshold_quantile = energy_threshold_quantile + self.segment_epsilon = segment_epsilon + self.salient_proportion_threshold = salient_proportion_threshold + self.analysis_stem = analysis_stem + + self.meter = pyln.Meter(self.fs) + self.target_lufs = target_lufs + + self.output_path = output_path + + def forward(self, data: DataDict) -> None: + + stem_ = self.analysis_stem if ( + self.analysis_stem != "none") else "mixture" + + x = data["audio"][stem_] + + xnp = x.numpy() + loudness = self.meter.integrated_loudness(xnp.T) + + for stem in data["audio"]: + s = data["audio"][stem] + s = pyln.normalize.loudness(s.numpy().T, loudness, self.target_lufs).T + s = torch.as_tensor(s) + data["audio"][stem] = s + + if x.ndim == 3: + assert x.shape[0] == 1 + x = x[0] + + n_chan, n_samples = x.shape + + n_segments = ( + int( + np.ceil((n_samples - self.segment_length) / self.hop_length) + ) + 1 + ) + + segments = torch.zeros((n_segments, n_chan, self.segment_length)) + for i in range(n_segments): + start = i * self.hop_length + end = start + self.segment_length + end = min(end, n_samples) + + xseg = x[:, start:end] + + if end - start < self.segment_length: + xseg = F.pad( + xseg, + pad=(0, self.segment_length - (end - start)), + value=torch.nan + ) + + segments[i, :, :] = xseg + + chunks = segments.reshape( + (n_segments, n_chan, self.n_chunks, self.chunk_size) + ) + + if self.analysis_stem != "none": + chunk_energies = torch.mean(torch.square(chunks), dim=(1, 3)) + chunk_energies = torch.nan_to_num(chunk_energies, nan=0) + chunk_energies[chunk_energies == 0] = self.chunk_epsilon + + energy_threshold = torch.nanquantile( + chunk_energies, q=self.energy_threshold_quantile + ) + + if energy_threshold < self.segment_epsilon: + energy_threshold = self.segment_epsilon # type: ignore[assignment] + + chunks_above_threshold = chunk_energies > energy_threshold + n_chunks_above_threshold = torch.mean( + chunks_above_threshold.to(torch.float), dim=-1 + ) + + segment_above_threshold = ( + n_chunks_above_threshold > self.salient_proportion_threshold + ) + + if torch.sum(segment_above_threshold) == 0: + return + + else: + segment_above_threshold = torch.ones((n_segments,)) + + for i in range(n_segments): + if not segment_above_threshold[i]: + continue + + outpath = os.path.join( + self.output_path, + self.analysis_stem, + f"{data['track']} - {self.analysis_stem}{i:03d}", + ) + 
os.makedirs(outpath, exist_ok=True) + + for stem in data["audio"]: + if stem == self.analysis_stem: + segment = torch.nan_to_num(segments[i, :, :], nan=0) + else: + start = i * self.hop_length + end = start + self.segment_length + end = min(n_samples, end) + + segment = data["audio"][stem][:, start:end] + + if end - start < self.segment_length: + segment = F.pad( + segment, + (0, self.segment_length - (end - start)) + ) + + assert segment.shape[-1] == self.segment_length, segment.shape + + # ta.save(os.path.join(outpath, f"{stem}.wav"), segment, self.fs) + + np.save(os.path.join(outpath, f"{stem}.wav"), segment) + + +def preprocess( + analysis_stem: str, + output_path: str = "/data/MUSDB18/HQ/saded-np", + fs: int = 44100, + segment_length_second: float = 6.0, + hop_length_second: float = 3.0, + n_chunks: int = 10, + chunk_epsilon: float = 1e-5, + energy_threshold_quantile: float = 0.15, + segment_epsilon: float = 1e-3, + salient_proportion_threshold: float = 0.5, +) -> None: + + sad = SourceActivityDetector( + analysis_stem=analysis_stem, + output_path=output_path, + fs=fs, + segment_length_second=segment_length_second, + hop_length_second=hop_length_second, + n_chunks=n_chunks, + chunk_epsilon=chunk_epsilon, + energy_threshold_quantile=energy_threshold_quantile, + segment_epsilon=segment_epsilon, + salient_proportion_threshold=salient_proportion_threshold, + ) + + for split in ["train", "val", "test"]: + ds = MUSDB18FullTrackDataset( + data_root="/data/MUSDB18/HQ/canonical", + split=split, + ) + + tracks = [] + for i, track in enumerate(tqdm(ds, total=len(ds))): + if i % 32 == 0 and tracks: + process_map(sad, tracks, max_workers=8) + tracks = [] + tracks.append(track) + process_map(sad, tracks, max_workers=8) + +def loudness_norm_one( + inputs +): + infile, outfile, target_lufs = inputs + + audio, fs = ta.load(infile) + audio = audio.mean(dim=0, keepdim=True).numpy().T + + meter = pyln.Meter(fs) + loudness = meter.integrated_loudness(audio) + audio = pyln.normalize.loudness(audio, loudness, target_lufs) + + os.makedirs(os.path.dirname(outfile), exist_ok=True) + np.save(outfile, audio.T) + +def loudness_norm( + data_path: str, + # output_path: str, + target_lufs = -17.0, +): + files = glob.glob( + os.path.join(data_path, "**", "*.wav"), recursive=True + ) + + outfiles = [ + f.replace(".wav", ".npy").replace("saded", "saded-np") for f in files + ] + + files = [(f, o, target_lufs) for f, o in zip(files, outfiles)] + + process_map(loudness_norm_one, files, chunksize=2) + + + +if __name__ == "__main__": + + from tqdm.auto import tqdm + import fire + + fire.Fire() diff --git a/separator/models/bandit/core/data/musdb/validation.yaml b/separator/models/bandit/core/data/musdb/validation.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2f8752478d285d1d13d5e842225af1de95cae57a --- /dev/null +++ b/separator/models/bandit/core/data/musdb/validation.yaml @@ -0,0 +1,15 @@ +validation: + - 'Actions - One Minute Smile' + - 'Clara Berry And Wooldog - Waltz For My Victims' + - 'Johnny Lokke - Promises & Lies' + - 'Patrick Talbot - A Reason To Leave' + - 'Triviul - Angelsaint' + - 'Alexander Ross - Goodbye Bolero' + - 'Fergessen - Nos Palpitants' + - 'Leaf - Summerghost' + - 'Skelpolu - Human Mistakes' + - 'Young Griffo - Pennies' + - 'ANiMAL - Rockshow' + - 'James May - On The Line' + - 'Meaxic - Take A Step' + - 'Traffic Experiment - Sirens' \ No newline at end of file diff --git a/separator/models/bandit/core/loss/__init__.py b/separator/models/bandit/core/loss/__init__.py new 
file mode 100644 index 0000000000000000000000000000000000000000..0ab803aecde4f686e34d93f3f2d585e0a9867525 --- /dev/null +++ b/separator/models/bandit/core/loss/__init__.py @@ -0,0 +1,2 @@ +from ._multistem import MultiStemWrapperFromConfig +from ._timefreq import ReImL1Loss, ReImL2Loss, TimeFreqL1Loss, TimeFreqL2Loss, TimeFreqSignalNoisePNormRatioLoss diff --git a/separator/models/bandit/core/loss/_complex.py b/separator/models/bandit/core/loss/_complex.py new file mode 100644 index 0000000000000000000000000000000000000000..1d97e5d8bab3fdb095c2ba7c77faebef26e8f1ce --- /dev/null +++ b/separator/models/bandit/core/loss/_complex.py @@ -0,0 +1,34 @@ +from typing import Any + +import torch +from torch import nn +from torch.nn.modules import loss as _loss +from torch.nn.modules.loss import _Loss + + +class ReImLossWrapper(_Loss): + def __init__(self, module: _Loss) -> None: + super().__init__() + self.module = module + + def forward( + self, + preds: torch.Tensor, + target: torch.Tensor + ) -> torch.Tensor: + return self.module( + torch.view_as_real(preds), + torch.view_as_real(target) + ) + + +class ReImL1Loss(ReImLossWrapper): + def __init__(self, **kwargs: Any) -> None: + l1_loss = _loss.L1Loss(**kwargs) + super().__init__(module=(l1_loss)) + + +class ReImL2Loss(ReImLossWrapper): + def __init__(self, **kwargs: Any) -> None: + l2_loss = _loss.MSELoss(**kwargs) + super().__init__(module=(l2_loss)) diff --git a/separator/models/bandit/core/loss/_multistem.py b/separator/models/bandit/core/loss/_multistem.py new file mode 100644 index 0000000000000000000000000000000000000000..675e0ffbecf1f9f5efb0369bcb9d5c590efcfc31 --- /dev/null +++ b/separator/models/bandit/core/loss/_multistem.py @@ -0,0 +1,45 @@ +from typing import Any, Dict + +import torch +from asteroid import losses as asteroid_losses +from torch import nn +from torch.nn.modules.loss import _Loss + +from . 
import snr + + +def parse_loss(name: str, kwargs: Dict[str, Any]) -> _Loss: + + for module in [nn.modules.loss, snr, asteroid_losses, asteroid_losses.sdr]: + if name in module.__dict__: + return module.__dict__[name](**kwargs) + + raise NameError + + +class MultiStemWrapper(_Loss): + def __init__(self, module: _Loss, modality: str = "audio") -> None: + super().__init__() + self.loss = module + self.modality = modality + + def forward( + self, + preds: Dict[str, Dict[str, torch.Tensor]], + target: Dict[str, Dict[str, torch.Tensor]], + ) -> torch.Tensor: + loss = { + stem: self.loss( + preds[self.modality][stem], + target[self.modality][stem] + ) + for stem in preds[self.modality] if stem in target[self.modality] + } + + return sum(list(loss.values())) + + +class MultiStemWrapperFromConfig(MultiStemWrapper): + def __init__(self, name: str, kwargs: Any, modality: str = "audio") -> None: + loss = parse_loss(name, kwargs) + super().__init__(module=loss, modality=modality) diff --git a/separator/models/bandit/core/loss/_timefreq.py b/separator/models/bandit/core/loss/_timefreq.py new file mode 100644 index 0000000000000000000000000000000000000000..6ea9d5994ca645546b5ccb7e6eafaa3d2fbcf959 --- /dev/null +++ b/separator/models/bandit/core/loss/_timefreq.py @@ -0,0 +1,113 @@ +from typing import Any, Dict, Optional + +import torch +from torch import nn +from torch.nn.modules.loss import _Loss + +from models.bandit.core.loss._multistem import MultiStemWrapper +from models.bandit.core.loss._complex import ReImL1Loss, ReImL2Loss, ReImLossWrapper +from models.bandit.core.loss.snr import SignalNoisePNormRatio + +class TimeFreqWrapper(_Loss): + def __init__( + self, + time_module: _Loss, + freq_module: Optional[_Loss] = None, + time_weight: float = 1.0, + freq_weight: float = 1.0, + multistem: bool = True, + ) -> None: + super().__init__() + + if freq_module is None: + freq_module = time_module + + if multistem: + time_module = MultiStemWrapper(time_module, modality="audio") + freq_module = MultiStemWrapper(freq_module, modality="spectrogram") + + self.time_module = time_module + self.freq_module = freq_module + + self.time_weight = time_weight + self.freq_weight = freq_weight + + # TODO: add better type hints + def forward(self, preds: Any, target: Any) -> torch.Tensor: + + return self.time_weight * self.time_module( + preds, target + ) + self.freq_weight * self.freq_module(preds, target) + + +class TimeFreqL1Loss(TimeFreqWrapper): + def __init__( + self, + time_weight: float = 1.0, + freq_weight: float = 1.0, + tkwargs: Optional[Dict[str, Any]] = None, + fkwargs: Optional[Dict[str, Any]] = None, + multistem: bool = True, + ) -> None: + if tkwargs is None: + tkwargs = {} + if fkwargs is None: + fkwargs = {} + time_module = (nn.L1Loss(**tkwargs)) + freq_module = ReImL1Loss(**fkwargs) + super().__init__( + time_module, + freq_module, + time_weight, + freq_weight, + multistem + ) + + +class TimeFreqL2Loss(TimeFreqWrapper): + def __init__( + self, + time_weight: float = 1.0, + freq_weight: float = 1.0, + tkwargs: Optional[Dict[str, Any]] = None, + fkwargs: Optional[Dict[str, Any]] = None, + multistem: bool = True, + ) -> None: + if tkwargs is None: + tkwargs = {} + if fkwargs is None: + fkwargs = {} + time_module = nn.MSELoss(**tkwargs) + freq_module = ReImL2Loss(**fkwargs) + super().__init__( + time_module, + freq_module, + time_weight, + freq_weight, + multistem + ) + + + +class TimeFreqSignalNoisePNormRatioLoss(TimeFreqWrapper): + def __init__( + self, + time_weight: float = 1.0, + freq_weight: float = 
1.0, + tkwargs: Optional[Dict[str, Any]] = None, + fkwargs: Optional[Dict[str, Any]] = None, + multistem: bool = True, + ) -> None: + if tkwargs is None: + tkwargs = {} + if fkwargs is None: + fkwargs = {} + time_module = SignalNoisePNormRatio(**tkwargs) + freq_module = SignalNoisePNormRatio(**fkwargs) + super().__init__( + time_module, + freq_module, + time_weight, + freq_weight, + multistem + ) diff --git a/separator/models/bandit/core/loss/snr.py b/separator/models/bandit/core/loss/snr.py new file mode 100644 index 0000000000000000000000000000000000000000..2996dd57080db687599c1fd673d6807041a04b52 --- /dev/null +++ b/separator/models/bandit/core/loss/snr.py @@ -0,0 +1,146 @@ +import torch +from torch.nn.modules.loss import _Loss +from torch.nn import functional as F + +class SignalNoisePNormRatio(_Loss): + def __init__( + self, + p: float = 1.0, + scale_invariant: bool = False, + zero_mean: bool = False, + take_log: bool = True, + reduction: str = "mean", + EPS: float = 1e-3, + ) -> None: + assert reduction != "sum", NotImplementedError + super().__init__(reduction=reduction) + assert not zero_mean + + self.p = p + + self.EPS = EPS + self.take_log = take_log + + self.scale_invariant = scale_invariant + + def forward( + self, + est_target: torch.Tensor, + target: torch.Tensor + ) -> torch.Tensor: + + target_ = target + if self.scale_invariant: + ndim = target.ndim + dot = torch.sum(est_target * torch.conj(target), dim=-1, keepdim=True) + s_target_energy = ( + torch.sum(target * torch.conj(target), dim=-1, keepdim=True) + ) + + if ndim > 2: + dot = torch.sum(dot, dim=list(range(1, ndim)), keepdim=True) + s_target_energy = torch.sum(s_target_energy, dim=list(range(1, ndim)), keepdim=True) + + target_scaler = (dot + 1e-8) / (s_target_energy + 1e-8) + target = target_ * target_scaler + + if torch.is_complex(est_target): + est_target = torch.view_as_real(est_target) + target = torch.view_as_real(target) + + + batch_size = est_target.shape[0] + est_target = est_target.reshape(batch_size, -1) + target = target.reshape(batch_size, -1) + # target_ = target_.reshape(batch_size, -1) + + if self.p == 1: + e_error = torch.abs(est_target-target).mean(dim=-1) + e_target = torch.abs(target).mean(dim=-1) + elif self.p == 2: + e_error = torch.square(est_target-target).mean(dim=-1) + e_target = torch.square(target).mean(dim=-1) + else: + raise NotImplementedError + + if self.take_log: + loss = 10*(torch.log10(e_error + self.EPS) - torch.log10(e_target + self.EPS)) + else: + loss = (e_error + self.EPS)/(e_target + self.EPS) + + if self.reduction == "mean": + loss = loss.mean() + elif self.reduction == "sum": + loss = loss.sum() + + return loss + + + +class MultichannelSingleSrcNegSDR(_Loss): + def __init__( + self, + sdr_type: str, + p: float = 2.0, + zero_mean: bool = True, + take_log: bool = True, + reduction: str = "mean", + EPS: float = 1e-8, + ) -> None: + assert reduction != "sum", NotImplementedError + super().__init__(reduction=reduction) + + assert sdr_type in ["snr", "sisdr", "sdsdr"] + self.sdr_type = sdr_type + self.zero_mean = zero_mean + self.take_log = take_log + self.EPS = 1e-8 + + self.p = p + + def forward( + self, + est_target: torch.Tensor, + target: torch.Tensor + ) -> torch.Tensor: + if target.size() != est_target.size() or target.ndim != 3: + raise TypeError( + f"Inputs must be of shape [batch, time], got {target.size()} and {est_target.size()} instead" + ) + # Step 1. 
Zero-mean norm + if self.zero_mean: + mean_source = torch.mean(target, dim=[1, 2], keepdim=True) + mean_estimate = torch.mean(est_target, dim=[1, 2], keepdim=True) + target = target - mean_source + est_target = est_target - mean_estimate + # Step 2. Pair-wise SI-SDR. + if self.sdr_type in ["sisdr", "sdsdr"]: + # [batch, 1] + dot = torch.sum(est_target * target, dim=[1, 2], keepdim=True) + # [batch, 1] + s_target_energy = ( + torch.sum(target ** 2, dim=[1, 2], keepdim=True) + self.EPS + ) + # [batch, time] + scaled_target = dot * target / s_target_energy + else: + # [batch, time] + scaled_target = target + if self.sdr_type in ["sdsdr", "snr"]: + e_noise = est_target - target + else: + e_noise = est_target - scaled_target + # [batch] + + if self.p == 2.0: + losses = torch.sum(scaled_target ** 2, dim=[1, 2]) / ( + torch.sum(e_noise ** 2, dim=[1, 2]) + self.EPS + ) + else: + losses = torch.norm(scaled_target, p=self.p, dim=[1, 2]) / ( + torch.linalg.vector_norm(e_noise, p=self.p, dim=[1, 2]) + self.EPS + ) + if self.take_log: + losses = 10 * torch.log10(losses + self.EPS) + losses = losses.mean() if self.reduction == "mean" else losses + return -losses diff --git a/separator/models/bandit/core/metrics/__init__.py b/separator/models/bandit/core/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c638b4df585ad6c3c6490d9e67b7fc197f0d06f4 --- /dev/null +++ b/separator/models/bandit/core/metrics/__init__.py @@ -0,0 +1,9 @@ +from .snr import ( + ChunkMedianScaleInvariantSignalDistortionRatio, + ChunkMedianScaleInvariantSignalNoiseRatio, + ChunkMedianSignalDistortionRatio, + ChunkMedianSignalNoiseRatio, + SafeSignalDistortionRatio, +) + +# from .mushra import EstimatedMushraScore diff --git a/separator/models/bandit/core/metrics/_squim.py b/separator/models/bandit/core/metrics/_squim.py new file mode 100644 index 0000000000000000000000000000000000000000..ec76b5fb5e27d0f6a6aaa5ececc5161482150bfc --- /dev/null +++ b/separator/models/bandit/core/metrics/_squim.py @@ -0,0 +1,383 @@ +from dataclasses import dataclass + +from torchaudio._internal import load_state_dict_from_url + +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def transform_wb_pesq_range(x: float) -> float: + """The metric defined by ITU-T P.862 is often called 'PESQ score', which is defined + for narrow-band signals and has a value range of [-0.5, 4.5] exactly. Here, we use the metric + defined by ITU-T P.862.2, commonly known as 'wide-band PESQ' and will be referred to as "PESQ score". + + Args: + x (float): Narrow-band PESQ score. + + Returns: + (float): Wide-band PESQ score. + """ + return 0.999 + (4.999 - 0.999) / (1 + math.exp(-1.3669 * x + 3.8224)) + + +PESQRange: Tuple[float, float] = ( + 1.0, # P.862.2 uses a different input filter than P.862, and the lower bound of + # the raw score is not -0.5 anymore. It's hard to figure out the true lower bound. + # We are using 1.0 as a reasonable approximation. 
+ transform_wb_pesq_range(4.5), +) + + +class RangeSigmoid(nn.Module): + def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None: + super(RangeSigmoid, self).__init__() + assert isinstance(val_range, tuple) and len(val_range) == 2 + self.val_range: Tuple[float, float] = val_range + self.sigmoid: nn.modules.Module = nn.Sigmoid() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = self.sigmoid(x) * (self.val_range[1] - self.val_range[0]) + self.val_range[0] + return out + + +class Encoder(nn.Module): + """Encoder module that transform 1D waveform to 2D representations. + + Args: + feat_dim (int, optional): The feature dimension after Encoder module. (Default: 512) + win_len (int, optional): kernel size in the Conv1D layer. (Default: 32) + """ + + def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None: + super(Encoder, self).__init__() + + self.conv1d = nn.Conv1d(1, feat_dim, win_len, stride=win_len // 2, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply waveforms to convolutional layer and ReLU layer. + + Args: + x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`. + + Returns: + (torch,Tensor): Feature Tensor with dimensions `(batch, channel, frame)`. + """ + out = x.unsqueeze(dim=1) + out = F.relu(self.conv1d(out)) + return out + + +class SingleRNN(nn.Module): + def __init__(self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0) -> None: + super(SingleRNN, self).__init__() + + self.rnn_type = rnn_type + self.input_size = input_size + self.hidden_size = hidden_size + + self.rnn: nn.modules.Module = getattr(nn, rnn_type)( + input_size, + hidden_size, + 1, + dropout=dropout, + batch_first=True, + bidirectional=True, + ) + + self.proj = nn.Linear(hidden_size * 2, input_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # input shape: batch, seq, dim + out, _ = self.rnn(x) + out = self.proj(out) + return out + + +class DPRNN(nn.Module): + """*Dual-path recurrent neural networks (DPRNN)* :cite:`luo2020dual`. + + Args: + feat_dim (int, optional): The feature dimension after Encoder module. (Default: 64) + hidden_dim (int, optional): Hidden dimension in the RNN layer of DPRNN. (Default: 128) + num_blocks (int, optional): Number of DPRNN layers. (Default: 6) + rnn_type (str, optional): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"]. (Default: "LSTM") + d_model (int, optional): The number of expected features in the input. (Default: 256) + chunk_size (int, optional): Chunk size of input for DPRNN. (Default: 100) + chunk_stride (int, optional): Stride of chunk input for DPRNN. 
(Default: 50) + """ + + def __init__( + self, + feat_dim: int = 64, + hidden_dim: int = 128, + num_blocks: int = 6, + rnn_type: str = "LSTM", + d_model: int = 256, + chunk_size: int = 100, + chunk_stride: int = 50, + ) -> None: + super(DPRNN, self).__init__() + + self.num_blocks = num_blocks + + self.row_rnn = nn.ModuleList([]) + self.col_rnn = nn.ModuleList([]) + self.row_norm = nn.ModuleList([]) + self.col_norm = nn.ModuleList([]) + for _ in range(num_blocks): + self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim)) + self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim)) + self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8)) + self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8)) + self.conv = nn.Sequential( + nn.Conv2d(feat_dim, d_model, 1), + nn.PReLU(), + ) + self.chunk_size = chunk_size + self.chunk_stride = chunk_stride + + def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]: + # input shape: (B, N, T) + seq_len = x.shape[-1] + + rest = self.chunk_size - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size + out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride]) + + return out, rest + + def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]: + out, rest = self.pad_chunk(x) + batch_size, feat_dim, seq_len = out.shape + + segments1 = out[:, :, : -self.chunk_stride].contiguous().view(batch_size, feat_dim, -1, self.chunk_size) + segments2 = out[:, :, self.chunk_stride :].contiguous().view(batch_size, feat_dim, -1, self.chunk_size) + out = torch.cat([segments1, segments2], dim=3) + out = out.view(batch_size, feat_dim, -1, self.chunk_size).transpose(2, 3).contiguous() + + return out, rest + + def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor: + batch_size, dim, _, _ = x.shape + out = x.transpose(2, 3).contiguous().view(batch_size, dim, -1, self.chunk_size * 2) + out1 = out[:, :, :, : self.chunk_size].contiguous().view(batch_size, dim, -1)[:, :, self.chunk_stride :] + out2 = out[:, :, :, self.chunk_size :].contiguous().view(batch_size, dim, -1)[:, :, : -self.chunk_stride] + out = out1 + out2 + if rest > 0: + out = out[:, :, :-rest] + out = out.contiguous() + return out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, rest = self.chunking(x) + batch_size, _, dim1, dim2 = x.shape + out = x + for row_rnn, row_norm, col_rnn, col_norm in zip(self.row_rnn, self.row_norm, self.col_rnn, self.col_norm): + row_in = out.permute(0, 3, 2, 1).contiguous().view(batch_size * dim2, dim1, -1).contiguous() + row_out = row_rnn(row_in) + row_out = row_out.view(batch_size, dim2, dim1, -1).permute(0, 3, 2, 1).contiguous() + row_out = row_norm(row_out) + out = out + row_out + + col_in = out.permute(0, 2, 3, 1).contiguous().view(batch_size * dim1, dim2, -1).contiguous() + col_out = col_rnn(col_in) + col_out = col_out.view(batch_size, dim1, dim2, -1).permute(0, 3, 1, 2).contiguous() + col_out = col_norm(col_out) + out = out + col_out + out = self.conv(out) + out = self.merging(out, rest) + out = out.transpose(1, 2).contiguous() + return out + + +class AutoPool(nn.Module): + def __init__(self, pool_dim: int = 1) -> None: + super(AutoPool, self).__init__() + self.pool_dim: int = pool_dim + self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim) + self.register_parameter("alpha", nn.Parameter(torch.ones(1))) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + weight = self.softmax(torch.mul(x, self.alpha)) + out = torch.sum(torch.mul(x, weight), dim=self.pool_dim) + return out + + +class SquimObjective(nn.Module): + 
"""Speech Quality and Intelligibility Measures (SQUIM) model that predicts **objective** metric scores + for speech enhancement (e.g., STOI, PESQ, and SI-SDR). + + Args: + encoder (torch.nn.Module): Encoder module to transform 1D waveform to 2D feature representation. + dprnn (torch.nn.Module): DPRNN module to model sequential feature. + branches (torch.nn.ModuleList): Transformer branches in which each branch estimate one objective metirc score. + """ + + def __init__( + self, + encoder: nn.Module, + dprnn: nn.Module, + branches: nn.ModuleList, + ): + super(SquimObjective, self).__init__() + self.encoder = encoder + self.dprnn = dprnn + self.branches = branches + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + """ + Args: + x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`. + + Returns: + List(torch.Tensor): List of score Tenosrs. Each Tensor is with dimension `(batch,)`. + """ + if x.ndim != 2: + raise ValueError(f"The input must be a 2D Tensor. Found dimension {x.ndim}.") + x = x / (torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20) + out = self.encoder(x) + out = self.dprnn(out) + scores = [] + for branch in self.branches: + scores.append(branch(out).squeeze(dim=1)) + return scores + + +def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module: + """Create branch module after DPRNN model for predicting metric score. + + Args: + d_model (int): The number of expected features in the input. + nhead (int): Number of heads in the multi-head attention model. + metric (str): The metric name to predict. + + Returns: + (nn.Module): Returned module to predict corresponding metric score. + """ + layer1 = nn.TransformerEncoderLayer(d_model, nhead, d_model * 4, dropout=0.0, batch_first=True) + layer2 = AutoPool() + if metric == "stoi": + layer3 = nn.Sequential( + nn.Linear(d_model, d_model), + nn.PReLU(), + nn.Linear(d_model, 1), + RangeSigmoid(), + ) + elif metric == "pesq": + layer3 = nn.Sequential( + nn.Linear(d_model, d_model), + nn.PReLU(), + nn.Linear(d_model, 1), + RangeSigmoid(val_range=PESQRange), + ) + else: + layer3: nn.modules.Module = nn.Sequential(nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1)) + return nn.Sequential(layer1, layer2, layer3) + + +def squim_objective_model( + feat_dim: int, + win_len: int, + d_model: int, + nhead: int, + hidden_dim: int, + num_blocks: int, + rnn_type: str, + chunk_size: int, + chunk_stride: Optional[int] = None, +) -> SquimObjective: + """Build a custome :class:`torchaudio.prototype.models.SquimObjective` model. + + Args: + feat_dim (int, optional): The feature dimension after Encoder module. + win_len (int): Kernel size in the Encoder module. + d_model (int): The number of expected features in the input. + nhead (int): Number of heads in the multi-head attention model. + hidden_dim (int): Hidden dimension in the RNN layer of DPRNN. + num_blocks (int): Number of DPRNN layers. + rnn_type (str): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"]. + chunk_size (int): Chunk size of input for DPRNN. + chunk_stride (int or None, optional): Stride of chunk input for DPRNN. 
+ """ + if chunk_stride is None: + chunk_stride = chunk_size // 2 + encoder = Encoder(feat_dim, win_len) + dprnn = DPRNN(feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride) + branches = nn.ModuleList( + [ + _create_branch(d_model, nhead, "stoi"), + _create_branch(d_model, nhead, "pesq"), + _create_branch(d_model, nhead, "sisdr"), + ] + ) + return SquimObjective(encoder, dprnn, branches) + + +def squim_objective_base() -> SquimObjective: + """Build :class:`torchaudio.prototype.models.SquimObjective` model with default arguments.""" + return squim_objective_model( + feat_dim=256, + win_len=64, + d_model=256, + nhead=4, + hidden_dim=256, + num_blocks=2, + rnn_type="LSTM", + chunk_size=71, + ) + +@dataclass +class SquimObjectiveBundle: + + _path: str + _sample_rate: float + + def _get_state_dict(self, dl_kwargs): + url = f"https://download.pytorch.org/torchaudio/models/{self._path}" + dl_kwargs = {} if dl_kwargs is None else dl_kwargs + state_dict = load_state_dict_from_url(url, **dl_kwargs) + return state_dict + + def get_model(self, *, dl_kwargs=None) -> SquimObjective: + """Construct the SquimObjective model, and load the pretrained weight. + + The weight file is downloaded from the internet and cached with + :func:`torch.hub.load_state_dict_from_url` + + Args: + dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. + + Returns: + Variation of :py:class:`~torchaudio.models.SquimObjective`. + """ + model = squim_objective_base() + model.load_state_dict(self._get_state_dict(dl_kwargs)) + model.eval() + return model + + @property + def sample_rate(self): + """Sample rate of the audio that the model is trained on. + + :type: float + """ + return self._sample_rate + + +SQUIM_OBJECTIVE = SquimObjectiveBundle( + "squim_objective_dns2020.pth", + _sample_rate=16000, +) +SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using approach described in + :cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`. + + The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`. + The weights are under `Creative Commons Attribution 4.0 International License + `__. + + Please refer to :py:class:`SquimObjectiveBundle` for usage instructions. 
+ """ + diff --git a/separator/models/bandit/core/metrics/snr.py b/separator/models/bandit/core/metrics/snr.py new file mode 100644 index 0000000000000000000000000000000000000000..d2830b2cbecfa681c449d09e2d4c35a20fc98128 --- /dev/null +++ b/separator/models/bandit/core/metrics/snr.py @@ -0,0 +1,150 @@ +from typing import Any, Callable + +import numpy as np +import torch +import torchmetrics as tm +from torch._C import _LinAlgError +from torchmetrics import functional as tmF + + +class SafeSignalDistortionRatio(tm.SignalDistortionRatio): + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + def update(self, *args, **kwargs) -> Any: + try: + super().update(*args, **kwargs) + except: + pass + + def compute(self) -> Any: + if self.total == 0: + return torch.tensor(torch.nan) + return super().compute() + + +class BaseChunkMedianSignalRatio(tm.Metric): + def __init__( + self, + func: Callable, + window_size: int, + hop_size: int = None, + zero_mean: bool = False, + ) -> None: + super().__init__() + + # self.zero_mean = zero_mean + self.func = func + self.window_size = window_size + if hop_size is None: + hop_size = window_size + self.hop_size = hop_size + + self.add_state( + "sum_snr", + default=torch.tensor(0.0), + dist_reduce_fx="sum" + ) + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: + + n_samples = target.shape[-1] + + n_chunks = int( + np.ceil((n_samples - self.window_size) / self.hop_size) + 1 + ) + + snr_chunk = [] + + for i in range(n_chunks): + start = i * self.hop_size + + if n_samples - start < self.window_size: + continue + + end = start + self.window_size + + try: + chunk_snr = self.func( + preds[..., start:end], + target[..., start:end] + ) + + # print(preds.shape, chunk_snr.shape) + + if torch.all(torch.isfinite(chunk_snr)): + snr_chunk.append(chunk_snr) + except _LinAlgError: + pass + + snr_chunk = torch.stack(snr_chunk, dim=-1) + snr_batch, _ = torch.nanmedian(snr_chunk, dim=-1) + + self.sum_snr += snr_batch.sum() + self.total += snr_batch.numel() + + def compute(self) -> Any: + return self.sum_snr / self.total + + +class ChunkMedianSignalNoiseRatio(BaseChunkMedianSignalRatio): + def __init__( + self, + window_size: int, + hop_size: int = None, + zero_mean: bool = False + ) -> None: + super().__init__( + func=tmF.signal_noise_ratio, + window_size=window_size, + hop_size=hop_size, + zero_mean=zero_mean, + ) + + +class ChunkMedianScaleInvariantSignalNoiseRatio(BaseChunkMedianSignalRatio): + def __init__( + self, + window_size: int, + hop_size: int = None, + zero_mean: bool = False + ) -> None: + super().__init__( + func=tmF.scale_invariant_signal_noise_ratio, + window_size=window_size, + hop_size=hop_size, + zero_mean=zero_mean, + ) + + +class ChunkMedianSignalDistortionRatio(BaseChunkMedianSignalRatio): + def __init__( + self, + window_size: int, + hop_size: int = None, + zero_mean: bool = False + ) -> None: + super().__init__( + func=tmF.signal_distortion_ratio, + window_size=window_size, + hop_size=hop_size, + zero_mean=zero_mean, + ) + + +class ChunkMedianScaleInvariantSignalDistortionRatio( + BaseChunkMedianSignalRatio + ): + def __init__( + self, + window_size: int, + hop_size: int = None, + zero_mean: bool = False + ) -> None: + super().__init__( + func=tmF.scale_invariant_signal_distortion_ratio, + window_size=window_size, + hop_size=hop_size, + zero_mean=zero_mean, + ) diff --git a/separator/models/bandit/core/model/__init__.py 
b/separator/models/bandit/core/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..54ac48eb69d6f844ba5b73b213eae4cfab157cac --- /dev/null +++ b/separator/models/bandit/core/model/__init__.py @@ -0,0 +1,3 @@ +from .bsrnn.wrapper import ( + MultiMaskMultiSourceBandSplitRNNSimple, +) diff --git a/separator/models/bandit/core/model/_spectral.py b/separator/models/bandit/core/model/_spectral.py new file mode 100644 index 0000000000000000000000000000000000000000..564cd28600719579227a6085eed5e9d6ee521312 --- /dev/null +++ b/separator/models/bandit/core/model/_spectral.py @@ -0,0 +1,58 @@ +from typing import Dict, Optional + +import torch +import torchaudio as ta +from torch import nn + + +class _SpectralComponent(nn.Module): + def __init__( + self, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + **kwargs, + ) -> None: + super().__init__() + + assert power is None + + window_fn = torch.__dict__[window_fn] + + self.stft = ( + ta.transforms.Spectrogram( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + pad_mode=pad_mode, + pad=0, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + normalized=normalized, + center=center, + onesided=onesided, + ) + ) + + self.istft = ( + ta.transforms.InverseSpectrogram( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + pad_mode=pad_mode, + pad=0, + window_fn=window_fn, + wkwargs=wkwargs, + normalized=normalized, + center=center, + onesided=onesided, + ) + ) diff --git a/separator/models/bandit/core/model/bsrnn/__init__.py b/separator/models/bandit/core/model/bsrnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c27826197fc8f4eb7a7036d8037966a58d8b38d4 --- /dev/null +++ b/separator/models/bandit/core/model/bsrnn/__init__.py @@ -0,0 +1,23 @@ +from abc import ABC +from typing import Iterable, Mapping, Union + +from torch import nn + +from models.bandit.core.model.bsrnn.bandsplit import BandSplitModule +from models.bandit.core.model.bsrnn.tfmodel import ( + SeqBandModellingModule, + TransformerTimeFreqModule, +) + + +class BandsplitCoreBase(nn.Module, ABC): + band_split: nn.Module + tf_model: nn.Module + mask_estim: Union[nn.Module, Mapping[str, nn.Module], Iterable[nn.Module]] + + def __init__(self) -> None: + super().__init__() + + @staticmethod + def mask(x, m): + return x * m diff --git a/separator/models/bandit/core/model/bsrnn/bandsplit.py b/separator/models/bandit/core/model/bsrnn/bandsplit.py new file mode 100644 index 0000000000000000000000000000000000000000..63e6255857fb2d538634be317332afb2f93e145d --- /dev/null +++ b/separator/models/bandit/core/model/bsrnn/bandsplit.py @@ -0,0 +1,139 @@ +from typing import List, Tuple + +import torch +from torch import nn + +from models.bandit.core.model.bsrnn.utils import ( + band_widths_from_specs, + check_no_gap, + check_no_overlap, + check_nonzero_bandwidth, +) + + +class NormFC(nn.Module): + def __init__( + self, + emb_dim: int, + bandwidth: int, + in_channel: int, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + ) -> None: + super().__init__() + + self.treat_channel_as_feature = treat_channel_as_feature + + if normalize_channel_independently: + raise NotImplementedError + + reim = 2 + + self.norm = nn.LayerNorm(in_channel * bandwidth * 
reim) + + fc_in = bandwidth * reim + + if treat_channel_as_feature: + fc_in *= in_channel + else: + assert emb_dim % in_channel == 0 + emb_dim = emb_dim // in_channel + + self.fc = nn.Linear(fc_in, emb_dim) + + def forward(self, xb): + # xb = (batch, n_time, in_chan, reim * band_width) + + batch, n_time, in_chan, ribw = xb.shape + xb = self.norm(xb.reshape(batch, n_time, in_chan * ribw)) + # (batch, n_time, in_chan * reim * band_width) + + if not self.treat_channel_as_feature: + xb = xb.reshape(batch, n_time, in_chan, ribw) + # (batch, n_time, in_chan, reim * band_width) + + zb = self.fc(xb) + # (batch, n_time, emb_dim) + # OR + # (batch, n_time, in_chan, emb_dim_per_chan) + + if not self.treat_channel_as_feature: + batch, n_time, in_chan, emb_dim_per_chan = zb.shape + # (batch, n_time, in_chan, emb_dim_per_chan) + zb = zb.reshape((batch, n_time, in_chan * emb_dim_per_chan)) + + return zb # (batch, n_time, emb_dim) + + +class BandSplitModule(nn.Module): + def __init__( + self, + band_specs: List[Tuple[float, float]], + emb_dim: int, + in_channel: int, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + ) -> None: + super().__init__() + + check_nonzero_bandwidth(band_specs) + + if require_no_gap: + check_no_gap(band_specs) + + if require_no_overlap: + check_no_overlap(band_specs) + + self.band_specs = band_specs + # list of [fstart, fend) in index. + # Note that fend is exclusive. + self.band_widths = band_widths_from_specs(band_specs) + self.n_bands = len(band_specs) + self.emb_dim = emb_dim + + self.norm_fc_modules = nn.ModuleList( + [ # type: ignore + ( + NormFC( + emb_dim=emb_dim, + bandwidth=bw, + in_channel=in_channel, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + ) + ) + for bw in self.band_widths + ] + ) + + def forward(self, x: torch.Tensor): + # x = complex spectrogram (batch, in_chan, n_freq, n_time) + + batch, in_chan, _, n_time = x.shape + + z = torch.zeros( + size=(batch, self.n_bands, n_time, self.emb_dim), + device=x.device + ) + + xr = torch.view_as_real(x) # batch, in_chan, n_freq, n_time, 2 + xr = torch.permute( + xr, + (0, 3, 1, 4, 2) + ) # batch, n_time, in_chan, 2, n_freq + batch, n_time, in_chan, reim, band_width = xr.shape + for i, nfm in enumerate(self.norm_fc_modules): + # print(f"bandsplit/band{i:02d}") + fstart, fend = self.band_specs[i] + xb = xr[..., fstart:fend] + # (batch, n_time, in_chan, reim, band_width) + xb = torch.reshape(xb, (batch, n_time, in_chan, -1)) + # (batch, n_time, in_chan, reim * band_width) + # z.append(nfm(xb)) # (batch, n_time, emb_dim) + z[:, i, :, :] = nfm(xb.contiguous()) + + # z = torch.stack(z, dim=1) + + return z diff --git a/separator/models/bandit/core/model/bsrnn/core.py b/separator/models/bandit/core/model/bsrnn/core.py new file mode 100644 index 0000000000000000000000000000000000000000..7fd36259002a395e7b7864f605fcab5b4422e422 --- /dev/null +++ b/separator/models/bandit/core/model/bsrnn/core.py @@ -0,0 +1,661 @@ +from typing import Dict, List, Optional, Tuple + +import torch +from torch import nn +from torch.nn import functional as F + +from models.bandit.core.model.bsrnn import BandsplitCoreBase +from models.bandit.core.model.bsrnn.bandsplit import BandSplitModule +from models.bandit.core.model.bsrnn.maskestim import ( + MaskEstimationModule, + OverlappingMaskEstimationModule +) +from models.bandit.core.model.bsrnn.tfmodel import ( + 
ConvolutionalTimeFreqModule, + SeqBandModellingModule, + TransformerTimeFreqModule +) + + +class MultiMaskBandSplitCoreBase(BandsplitCoreBase): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, cond=None, compute_residual: bool = True): + # x = complex spectrogram (batch, in_chan, n_freq, n_time) + # print(x.shape) + batch, in_chan, n_freq, n_time = x.shape + x = torch.reshape(x, (-1, 1, n_freq, n_time)) + + z = self.band_split(x) # (batch, emb_dim, n_band, n_time) + + # if torch.any(torch.isnan(z)): + # raise ValueError("z nan") + + # print(z) + q = self.tf_model(z) # (batch, emb_dim, n_band, n_time) + # print(q) + + + # if torch.any(torch.isnan(q)): + # raise ValueError("q nan") + + out = {} + + for stem, mem in self.mask_estim.items(): + m = mem(q, cond=cond) + + # if torch.any(torch.isnan(m)): + # raise ValueError("m nan", stem) + + s = self.mask(x, m) + s = torch.reshape(s, (batch, in_chan, n_freq, n_time)) + out[stem] = s + + return {"spectrogram": out} + + + + def instantiate_mask_estim(self, + in_channel: int, + stems: List[str], + band_specs: List[Tuple[float, float]], + emb_dim: int, + mlp_dim: int, + cond_dim: int, + hidden_activation: str, + + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + overlapping_band: bool = False, + freq_weights: Optional[List[torch.Tensor]] = None, + n_freq: Optional[int] = None, + use_freq_weights: bool = True, + mult_add_mask: bool = False + ): + if hidden_activation_kwargs is None: + hidden_activation_kwargs = {} + + if "mne:+" in stems: + stems = [s for s in stems if s != "mne:+"] + + if overlapping_band: + assert freq_weights is not None + assert n_freq is not None + + if mult_add_mask: + + self.mask_estim = nn.ModuleDict( + { + stem: MultAddMaskEstimationModule( + band_specs=band_specs, + freq_weights=freq_weights, + n_freq=n_freq, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + use_freq_weights=use_freq_weights, + ) + for stem in stems + } + ) + else: + self.mask_estim = nn.ModuleDict( + { + stem: OverlappingMaskEstimationModule( + band_specs=band_specs, + freq_weights=freq_weights, + n_freq=n_freq, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + use_freq_weights=use_freq_weights, + ) + for stem in stems + } + ) + else: + self.mask_estim = nn.ModuleDict( + { + stem: MaskEstimationModule( + band_specs=band_specs, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + for stem in stems + } + ) + + def instantiate_bandsplit(self, + in_channel: int, + band_specs: List[Tuple[float, float]], + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + emb_dim: int = 128 + ): + self.band_split = BandSplitModule( + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, + ) + +class SingleMaskBandsplitCoreBase(BandsplitCoreBase): + def __init__(self, **kwargs) -> None: + 
super().__init__() + + def forward(self, x): + # x = complex spectrogram (batch, in_chan, n_freq, n_time) + z = self.band_split(x) # (batch, emb_dim, n_band, n_time) + q = self.tf_model(z) # (batch, emb_dim, n_band, n_time) + m = self.mask_estim(q) # (batch, in_chan, n_freq, n_time) + + s = self.mask(x, m) + + return s + + +class SingleMaskBandsplitCoreRNN( + SingleMaskBandsplitCoreBase, +): + def __init__( + self, + in_channel: int, + band_specs: List[Tuple[float, float]], + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + ) -> None: + super().__init__() + self.band_split = (BandSplitModule( + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, + )) + self.tf_model = (SeqBandModellingModule( + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + )) + self.mask_estim = (MaskEstimationModule( + in_channel=in_channel, + band_specs=band_specs, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + )) + + +class SingleMaskBandsplitCoreTransformer( + SingleMaskBandsplitCoreBase, +): + def __init__( + self, + in_channel: int, + band_specs: List[Tuple[float, float]], + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + tf_dropout: float = 0.0, + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + ) -> None: + super().__init__() + self.band_split = BandSplitModule( + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, + ) + self.tf_model = TransformerTimeFreqModule( + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + dropout=tf_dropout, + ) + self.mask_estim = MaskEstimationModule( + in_channel=in_channel, + band_specs=band_specs, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + + +class MultiSourceMultiMaskBandSplitCoreRNN(MultiMaskBandSplitCoreBase): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: List[Tuple[float, float]], + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + cond_dim: int = 0, + hidden_activation: str = "Tanh", + 
hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + overlapping_band: bool = False, + freq_weights: Optional[List[torch.Tensor]] = None, + n_freq: Optional[int] = None, + use_freq_weights: bool = True, + mult_add_mask: bool = False + ) -> None: + + super().__init__() + self.instantiate_bandsplit( + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim + ) + + + self.tf_model = ( + SeqBandModellingModule( + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ) + ) + + self.mult_add_mask = mult_add_mask + + self.instantiate_mask_estim( + in_channel=in_channel, + stems=stems, + band_specs=band_specs, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=overlapping_band, + freq_weights=freq_weights, + n_freq=n_freq, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask + ) + + @staticmethod + def _mult_add_mask(x, m): + + assert m.ndim == 5 + + mm = m[..., 0] + am = m[..., 1] + + # print(mm.shape, am.shape, x.shape, m.shape) + + return x * mm + am + + def mask(self, x, m): + if self.mult_add_mask: + + return self._mult_add_mask(x, m) + else: + return super().mask(x, m) + + +class MultiSourceMultiMaskBandSplitCoreTransformer( + MultiMaskBandSplitCoreBase, +): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: List[Tuple[float, float]], + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + tf_dropout: float = 0.0, + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + overlapping_band: bool = False, + freq_weights: Optional[List[torch.Tensor]] = None, + n_freq: Optional[int] = None, + use_freq_weights:bool=True, + rnn_type: str = "LSTM", + cond_dim: int = 0, + mult_add_mask: bool = False + ) -> None: + super().__init__() + self.instantiate_bandsplit( + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim + ) + self.tf_model = TransformerTimeFreqModule( + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + dropout=tf_dropout, + ) + + self.instantiate_mask_estim( + in_channel=in_channel, + stems=stems, + band_specs=band_specs, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=overlapping_band, + freq_weights=freq_weights, + n_freq=n_freq, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask + ) + + + +class MultiSourceMultiMaskBandSplitCoreConv( + MultiMaskBandSplitCoreBase, +): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: List[Tuple[float, float]], + require_no_overlap: bool = False, + 
require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + tf_dropout: float = 0.0, + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + overlapping_band: bool = False, + freq_weights: Optional[List[torch.Tensor]] = None, + n_freq: Optional[int] = None, + use_freq_weights:bool=True, + rnn_type: str = "LSTM", + cond_dim: int = 0, + mult_add_mask: bool = False + ) -> None: + super().__init__() + self.instantiate_bandsplit( + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim + ) + self.tf_model = ConvolutionalTimeFreqModule( + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + dropout=tf_dropout, + ) + + self.instantiate_mask_estim( + in_channel=in_channel, + stems=stems, + band_specs=band_specs, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=overlapping_band, + freq_weights=freq_weights, + n_freq=n_freq, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask + ) + + +class PatchingMaskBandsplitCoreBase(MultiMaskBandSplitCoreBase): + def __init__(self) -> None: + super().__init__() + + def mask(self, x, m): + # x.shape = (batch, n_channel, n_freq, n_time) + # m.shape = (kernel_freq, kernel_time, batch, n_channel, n_freq, n_time) + + _, n_channel, kernel_freq, kernel_time, n_freq, n_time = m.shape + padding = ((kernel_freq - 1) // 2, (kernel_time - 1) // 2) + + xf = F.unfold( + x, + kernel_size=(kernel_freq, kernel_time), + padding=padding, + stride=(1, 1), + ) + + xf = xf.view( + -1, + n_channel, + kernel_freq, + kernel_time, + n_freq, + n_time, + ) + + sf = xf * m + + sf = sf.view( + -1, + n_channel * kernel_freq * kernel_time, + n_freq * n_time, + ) + + s = F.fold( + sf, + output_size=(n_freq, n_time), + kernel_size=(kernel_freq, kernel_time), + padding=padding, + stride=(1, 1), + ).view( + -1, + n_channel, + n_freq, + n_time, + ) + + return s + + def old_mask(self, x, m): + # x.shape = (batch, n_channel, n_freq, n_time) + # m.shape = (kernel_freq, kernel_time, batch, n_channel, n_freq, n_time) + + s = torch.zeros_like(x) + + _, n_channel, n_freq, n_time = x.shape + kernel_freq, kernel_time, _, _, _, _ = m.shape + + # print(x.shape, m.shape) + + kernel_freq_half = (kernel_freq - 1) // 2 + kernel_time_half = (kernel_time - 1) // 2 + + for ifreq in range(kernel_freq): + for itime in range(kernel_time): + df, dt = kernel_freq_half - ifreq, kernel_time_half - itime + x = x.roll(shifts=(df, dt), dims=(2, 3)) + + # if `df` > 0: + # x[:, :, :df, :] = 0 + # elif `df` < 0: + # x[:, :, df:, :] = 0 + + # if `dt` > 0: + # x[:, :, :, :dt] = 0 + # elif `dt` < 0: + # x[:, :, :, dt:] = 0 + + fslice = slice(max(0, df), min(n_freq, n_freq + df)) + tslice = slice(max(0, dt), min(n_time, n_time + dt)) + + s[:, :, fslice, tslice] += x[:, :, fslice, tslice] * m[ifreq, + itime, :, + :, fslice, + tslice] + + return s + + +class MultiSourceMultiPatchingMaskBandSplitCoreRNN( + PatchingMaskBandsplitCoreBase +): + def __init__( + self, + in_channel: int, + stems: 
List[str], + band_specs: List[Tuple[float, float]], + mask_kernel_freq: int, + mask_kernel_time: int, + conv_kernel_freq: int, + conv_kernel_time: int, + kernel_norm_mlp_version: int, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + overlapping_band: bool = False, + freq_weights: Optional[List[torch.Tensor]] = None, + n_freq: Optional[int] = None, + ) -> None: + + super().__init__() + self.band_split = BandSplitModule( + in_channel=in_channel, + band_specs=band_specs, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, + ) + + self.tf_model = ( + SeqBandModellingModule( + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ) + ) + + if hidden_activation_kwargs is None: + hidden_activation_kwargs = {} + + if overlapping_band: + assert freq_weights is not None + assert n_freq is not None + self.mask_estim = nn.ModuleDict( + { + stem: PatchingMaskEstimationModule( + band_specs=band_specs, + freq_weights=freq_weights, + n_freq=n_freq, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + mask_kernel_freq=mask_kernel_freq, + mask_kernel_time=mask_kernel_time, + conv_kernel_freq=conv_kernel_freq, + conv_kernel_time=conv_kernel_time, + kernel_norm_mlp_version=kernel_norm_mlp_version + ) + for stem in stems + } + ) + else: + raise NotImplementedError diff --git a/separator/models/bandit/core/model/bsrnn/maskestim.py b/separator/models/bandit/core/model/bsrnn/maskestim.py new file mode 100644 index 0000000000000000000000000000000000000000..0b9289dfa702e02ff4d4f0dc76196fd39bb68e34 --- /dev/null +++ b/separator/models/bandit/core/model/bsrnn/maskestim.py @@ -0,0 +1,347 @@ +import warnings +from typing import Dict, List, Optional, Tuple, Type + +import torch +from torch import nn +from torch.nn.modules import activation + +from models.bandit.core.model.bsrnn.utils import ( + band_widths_from_specs, + check_no_gap, + check_no_overlap, + check_nonzero_bandwidth, +) + + +class BaseNormMLP(nn.Module): + def __init__( + self, + emb_dim: int, + mlp_dim: int, + bandwidth: int, + in_channel: Optional[int], + hidden_activation: str = "Tanh", + hidden_activation_kwargs=None, + complex_mask: bool = True, ): + + super().__init__() + if hidden_activation_kwargs is None: + hidden_activation_kwargs = {} + self.hidden_activation_kwargs = hidden_activation_kwargs + self.norm = nn.LayerNorm(emb_dim) + self.hidden = torch.jit.script(nn.Sequential( + nn.Linear(in_features=emb_dim, out_features=mlp_dim), + activation.__dict__[hidden_activation]( + **self.hidden_activation_kwargs + ), + )) + + self.bandwidth = bandwidth + self.in_channel = in_channel + + self.complex_mask = complex_mask + self.reim = 2 if complex_mask else 1 + self.glu_mult = 2 + + +class NormMLP(BaseNormMLP): + def __init__( + self, + emb_dim: int, + mlp_dim: int, + bandwidth: int, + in_channel: Optional[int], + hidden_activation: str = 
"Tanh", + hidden_activation_kwargs=None, + complex_mask: bool = True, + ) -> None: + super().__init__( + emb_dim=emb_dim, + mlp_dim=mlp_dim, + bandwidth=bandwidth, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + + self.output = torch.jit.script( + nn.Sequential( + nn.Linear( + in_features=mlp_dim, + out_features=bandwidth * in_channel * self.reim * 2, + ), + nn.GLU(dim=-1), + ) + ) + + def reshape_output(self, mb): + # print(mb.shape) + batch, n_time, _ = mb.shape + if self.complex_mask: + mb = mb.reshape( + batch, + n_time, + self.in_channel, + self.bandwidth, + self.reim + ).contiguous() + # print(mb.shape) + mb = torch.view_as_complex( + mb + ) # (batch, n_time, in_channel, bandwidth) + else: + mb = mb.reshape(batch, n_time, self.in_channel, self.bandwidth) + + mb = torch.permute( + mb, + (0, 2, 3, 1) + ) # (batch, in_channel, bandwidth, n_time) + + return mb + + def forward(self, qb): + # qb = (batch, n_time, emb_dim) + + # if torch.any(torch.isnan(qb)): + # raise ValueError("qb0") + + + qb = self.norm(qb) # (batch, n_time, emb_dim) + + # if torch.any(torch.isnan(qb)): + # raise ValueError("qb1") + + qb = self.hidden(qb) # (batch, n_time, mlp_dim) + # if torch.any(torch.isnan(qb)): + # raise ValueError("qb2") + mb = self.output(qb) # (batch, n_time, bandwidth * in_channel * reim) + # if torch.any(torch.isnan(qb)): + # raise ValueError("mb") + mb = self.reshape_output(mb) # (batch, in_channel, bandwidth, n_time) + + return mb + + +class MultAddNormMLP(NormMLP): + def __init__(self, emb_dim: int, mlp_dim: int, bandwidth: int, in_channel: "int | None", hidden_activation: str = "Tanh", hidden_activation_kwargs=None, complex_mask: bool = True) -> None: + super().__init__(emb_dim, mlp_dim, bandwidth, in_channel, hidden_activation, hidden_activation_kwargs, complex_mask) + + self.output2 = torch.jit.script( + nn.Sequential( + nn.Linear( + in_features=mlp_dim, + out_features=bandwidth * in_channel * self.reim * 2, + ), + nn.GLU(dim=-1), + ) + ) + + def forward(self, qb): + + qb = self.norm(qb) # (batch, n_time, emb_dim) + qb = self.hidden(qb) # (batch, n_time, mlp_dim) + mmb = self.output(qb) # (batch, n_time, bandwidth * in_channel * reim) + mmb = self.reshape_output(mmb) # (batch, in_channel, bandwidth, n_time) + amb = self.output2(qb) # (batch, n_time, bandwidth * in_channel * reim) + amb = self.reshape_output(amb) # (batch, in_channel, bandwidth, n_time) + + return mmb, amb + + +class MaskEstimationModuleSuperBase(nn.Module): + pass + + +class MaskEstimationModuleBase(MaskEstimationModuleSuperBase): + def __init__( + self, + band_specs: List[Tuple[float, float]], + emb_dim: int, + mlp_dim: int, + in_channel: Optional[int], + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Dict = None, + complex_mask: bool = True, + norm_mlp_cls: Type[nn.Module] = NormMLP, + norm_mlp_kwargs: Dict = None, + ) -> None: + super().__init__() + + self.band_widths = band_widths_from_specs(band_specs) + self.n_bands = len(band_specs) + + if hidden_activation_kwargs is None: + hidden_activation_kwargs = {} + + if norm_mlp_kwargs is None: + norm_mlp_kwargs = {} + + self.norm_mlp = nn.ModuleList( + [ + ( + norm_mlp_cls( + bandwidth=self.band_widths[b], + emb_dim=emb_dim, + mlp_dim=mlp_dim, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + **norm_mlp_kwargs, + ) + ) + for b in range(self.n_bands) + ] + ) + 
+ def compute_masks(self, q): + batch, n_bands, n_time, emb_dim = q.shape + + masks = [] + + for b, nmlp in enumerate(self.norm_mlp): + # print(f"maskestim/{b:02d}") + qb = q[:, b, :, :] + mb = nmlp(qb) + masks.append(mb) + + return masks + + + +class OverlappingMaskEstimationModule(MaskEstimationModuleBase): + def __init__( + self, + in_channel: int, + band_specs: List[Tuple[float, float]], + freq_weights: List[torch.Tensor], + n_freq: int, + emb_dim: int, + mlp_dim: int, + cond_dim: int = 0, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Dict = None, + complex_mask: bool = True, + norm_mlp_cls: Type[nn.Module] = NormMLP, + norm_mlp_kwargs: Dict = None, + use_freq_weights: bool = True, + ) -> None: + check_nonzero_bandwidth(band_specs) + check_no_gap(band_specs) + + # if cond_dim > 0: + # raise NotImplementedError + + super().__init__( + band_specs=band_specs, + emb_dim=emb_dim + cond_dim, + mlp_dim=mlp_dim, + in_channel=in_channel, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + norm_mlp_cls=norm_mlp_cls, + norm_mlp_kwargs=norm_mlp_kwargs, + ) + + self.n_freq = n_freq + self.band_specs = band_specs + self.in_channel = in_channel + + if freq_weights is not None: + for i, fw in enumerate(freq_weights): + self.register_buffer(f"freq_weights/{i}", fw) + + self.use_freq_weights = use_freq_weights + else: + self.use_freq_weights = False + + self.cond_dim = cond_dim + + def forward(self, q, cond=None): + # q = (batch, n_bands, n_time, emb_dim) + + batch, n_bands, n_time, emb_dim = q.shape + + if cond is not None: + print(cond) + if cond.ndim == 2: + cond = cond[:, None, None, :].expand(-1, n_bands, n_time, -1) + elif cond.ndim == 3: + assert cond.shape[1] == n_time + else: + raise ValueError(f"Invalid cond shape: {cond.shape}") + + q = torch.cat([q, cond], dim=-1) + elif self.cond_dim > 0: + cond = torch.ones( + (batch, n_bands, n_time, self.cond_dim), + device=q.device, + dtype=q.dtype, + ) + q = torch.cat([q, cond], dim=-1) + else: + pass + + mask_list = self.compute_masks( + q + ) # [n_bands * (batch, in_channel, bandwidth, n_time)] + + masks = torch.zeros( + (batch, self.in_channel, self.n_freq, n_time), + device=q.device, + dtype=mask_list[0].dtype, + ) + + for im, mask in enumerate(mask_list): + fstart, fend = self.band_specs[im] + if self.use_freq_weights: + fw = self.get_buffer(f"freq_weights/{im}")[:, None] + mask = mask * fw + masks[:, :, fstart:fend, :] += mask + + return masks + + +class MaskEstimationModule(OverlappingMaskEstimationModule): + def __init__( + self, + band_specs: List[Tuple[float, float]], + emb_dim: int, + mlp_dim: int, + in_channel: Optional[int], + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Dict = None, + complex_mask: bool = True, + **kwargs, + ) -> None: + check_nonzero_bandwidth(band_specs) + check_no_gap(band_specs) + check_no_overlap(band_specs) + super().__init__( + in_channel=in_channel, + band_specs=band_specs, + freq_weights=None, + n_freq=None, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + + def forward(self, q, cond=None): + # q = (batch, n_bands, n_time, emb_dim) + + masks = self.compute_masks( + q + ) # [n_bands * (batch, in_channel, bandwidth, n_time)] + + # TODO: currently this requires band specs to have no gap and no overlap + masks = torch.concat( + masks, + dim=2 + ) # (batch, in_channel, n_freq, n_time) + + return masks 
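Illustrative sketch (editorial note, not part of the patch): a minimal shape check for the MaskEstimationModule added above, assuming the repository root is on PYTHONPATH; the band boundaries are hypothetical [fstart, fend) bin indices chosen only to be contiguous and non-overlapping.

import torch
from models.bandit.core.model.bsrnn.maskestim import MaskEstimationModule

band_specs = [(0, 64), (64, 192), (192, 512)]  # hypothetical contiguous, non-overlapping bands
estim = MaskEstimationModule(band_specs=band_specs, emb_dim=128, mlp_dim=512, in_channel=2)
q = torch.randn(1, len(band_specs), 100, 128)  # (batch, n_bands, n_time, emb_dim) from the TF model
masks = estim(q)  # per-band complex masks concatenated along frequency
print(masks.shape, masks.dtype)  # torch.Size([1, 2, 512, 100]) torch.complex64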
diff --git a/separator/models/bandit/core/model/bsrnn/tfmodel.py b/separator/models/bandit/core/model/bsrnn/tfmodel.py new file mode 100644 index 0000000000000000000000000000000000000000..ba710798c5ab49936bd63c914f20da516cbc6af9 --- /dev/null +++ b/separator/models/bandit/core/model/bsrnn/tfmodel.py @@ -0,0 +1,317 @@ +import warnings + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.modules import rnn + +import torch.backends.cuda + + +class TimeFrequencyModellingModule(nn.Module): + def __init__(self) -> None: + super().__init__() + + +class ResidualRNN(nn.Module): + def __init__( + self, + emb_dim: int, + rnn_dim: int, + bidirectional: bool = True, + rnn_type: str = "LSTM", + use_batch_trick: bool = True, + use_layer_norm: bool = True, + ) -> None: + # n_group is the size of the 2nd dim + super().__init__() + + self.use_layer_norm = use_layer_norm + if use_layer_norm: + self.norm = nn.LayerNorm(emb_dim) + else: + self.norm = nn.GroupNorm(num_groups=emb_dim, num_channels=emb_dim) + + self.rnn = rnn.__dict__[rnn_type]( + input_size=emb_dim, + hidden_size=rnn_dim, + num_layers=1, + batch_first=True, + bidirectional=bidirectional, + ) + + self.fc = nn.Linear( + in_features=rnn_dim * (2 if bidirectional else 1), + out_features=emb_dim + ) + + self.use_batch_trick = use_batch_trick + if not self.use_batch_trick: + warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!") + + def forward(self, z): + # z = (batch, n_uncrossed, n_across, emb_dim) + + z0 = torch.clone(z) + + # print(z.device) + + if self.use_layer_norm: + z = self.norm(z) # (batch, n_uncrossed, n_across, emb_dim) + else: + z = torch.permute( + z, (0, 3, 1, 2) + ) # (batch, emb_dim, n_uncrossed, n_across) + + z = self.norm(z) # (batch, emb_dim, n_uncrossed, n_across) + + z = torch.permute( + z, (0, 2, 3, 1) + ) # (batch, n_uncrossed, n_across, emb_dim) + + batch, n_uncrossed, n_across, emb_dim = z.shape + + if self.use_batch_trick: + z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim)) + + z = self.rnn(z.contiguous())[0] # (batch * n_uncrossed, n_across, dir_rnn_dim) + + z = torch.reshape(z, (batch, n_uncrossed, n_across, -1)) + # (batch, n_uncrossed, n_across, dir_rnn_dim) + else: + # Note: this is EXTREMELY SLOW + zlist = [] + for i in range(n_uncrossed): + zi = self.rnn(z[:, i, :, :])[0] # (batch, n_across, emb_dim) + zlist.append(zi) + + z = torch.stack( + zlist, + dim=1 + ) # (batch, n_uncrossed, n_across, dir_rnn_dim) + + z = self.fc(z) # (batch, n_uncrossed, n_across, emb_dim) + + z = z + z0 + + return z + + +class SeqBandModellingModule(TimeFrequencyModellingModule): + def __init__( + self, + n_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + parallel_mode=False, + ) -> None: + super().__init__() + self.seqband = nn.ModuleList([]) + + if parallel_mode: + for _ in range(n_modules): + self.seqband.append( + nn.ModuleList( + [ResidualRNN( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ), + ResidualRNN( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + )] + ) + ) + else: + + for _ in range(2 * n_modules): + self.seqband.append( + ResidualRNN( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ) + ) + + self.parallel_mode = parallel_mode + + def forward(self, z): + # z = (batch, n_bands, n_time, emb_dim) + + if self.parallel_mode: + for sbm_pair in self.seqband: + # z: (batch, 
n_bands, n_time, emb_dim) + sbm_t, sbm_f = sbm_pair[0], sbm_pair[1] + zt = sbm_t(z) # (batch, n_bands, n_time, emb_dim) + zf = sbm_f(z.transpose(1, 2)) # (batch, n_time, n_bands, emb_dim) + z = zt + zf.transpose(1, 2) + else: + for sbm in self.seqband: + z = sbm(z) + z = z.transpose(1, 2) + + # (batch, n_bands, n_time, emb_dim) + # --> (batch, n_time, n_bands, emb_dim) + # OR + # (batch, n_time, n_bands, emb_dim) + # --> (batch, n_bands, n_time, emb_dim) + + q = z + return q # (batch, n_bands, n_time, emb_dim) + + +class ResidualTransformer(nn.Module): + def __init__( + self, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + dropout: float = 0.0, + ) -> None: + # n_group is the size of the 2nd dim + super().__init__() + + self.tf = nn.TransformerEncoderLayer( + d_model=emb_dim, + nhead=4, + dim_feedforward=rnn_dim, + batch_first=True + ) + + self.is_causal = not bidirectional + self.dropout = dropout + + def forward(self, z): + batch, n_uncrossed, n_across, emb_dim = z.shape + z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim)) + z = self.tf(z, is_causal=self.is_causal) # (batch, n_uncrossed, n_across, emb_dim) + z = torch.reshape(z, (batch, n_uncrossed, n_across, emb_dim)) + + return z + + +class TransformerTimeFreqModule(TimeFrequencyModellingModule): + def __init__( + self, + n_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + dropout: float = 0.0, + ) -> None: + super().__init__() + self.norm = nn.LayerNorm(emb_dim) + self.seqband = nn.ModuleList([]) + + for _ in range(2 * n_modules): + self.seqband.append( + ResidualTransformer( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + dropout=dropout, + ) + ) + + def forward(self, z): + # z = (batch, n_bands, n_time, emb_dim) + z = self.norm(z) # (batch, n_bands, n_time, emb_dim) + + for sbm in self.seqband: + z = sbm(z) + z = z.transpose(1, 2) + + # (batch, n_bands, n_time, emb_dim) + # --> (batch, n_time, n_bands, emb_dim) + # OR + # (batch, n_time, n_bands, emb_dim) + # --> (batch, n_bands, n_time, emb_dim) + + q = z + return q # (batch, n_bands, n_time, emb_dim) + + + +class ResidualConvolution(nn.Module): + def __init__( + self, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + dropout: float = 0.0, + ) -> None: + # n_group is the size of the 2nd dim + super().__init__() + self.norm = nn.InstanceNorm2d(emb_dim, affine=True) + + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=emb_dim, + out_channels=rnn_dim, + kernel_size=(3, 3), + padding="same", + stride=(1, 1), + ), + nn.Tanhshrink() + ) + + self.is_causal = not bidirectional + self.dropout = dropout + + self.fc = nn.Conv2d( + in_channels=rnn_dim, + out_channels=emb_dim, + kernel_size=(1, 1), + padding="same", + stride=(1, 1), + ) + + + def forward(self, z): + # z = (batch, n_uncrossed, n_across, emb_dim) + + z0 = torch.clone(z) + + z = self.norm(z) # (batch, n_uncrossed, n_across, emb_dim) + z = self.conv(z) # (batch, n_uncrossed, n_across, emb_dim) + z = self.fc(z) # (batch, n_uncrossed, n_across, emb_dim) + z = z + z0 + + return z + + +class ConvolutionalTimeFreqModule(TimeFrequencyModellingModule): + def __init__( + self, + n_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + dropout: float = 0.0, + ) -> None: + super().__init__() + self.seqband = torch.jit.script(nn.Sequential( + *[ResidualConvolution( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + dropout=dropout, + ) for _ in 
range(2 * n_modules) ])) + + def forward(self, z): + # z = (batch, n_bands, n_time, emb_dim) + + z = torch.permute(z, (0, 3, 1, 2)) # (batch, emb_dim, n_bands, n_time) + + z = self.seqband(z) # (batch, emb_dim, n_bands, n_time) + + z = torch.permute(z, (0, 2, 3, 1)) # (batch, n_bands, n_time, emb_dim) + + return z diff --git a/separator/models/bandit/core/model/bsrnn/utils.py b/separator/models/bandit/core/model/bsrnn/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bf8636e65fe9e7fdd13fa063760018df90a01cff --- /dev/null +++ b/separator/models/bandit/core/model/bsrnn/utils.py @@ -0,0 +1,583 @@ +import os +from abc import abstractmethod +from typing import Any, Callable + +import numpy as np +import torch +from librosa import hz_to_midi, midi_to_hz +from torch import Tensor +from torchaudio import functional as taF +from spafe.fbanks import bark_fbanks +from spafe.utils.converters import erb2hz, hz2bark, hz2erb +from torchaudio.functional.functional import _create_triangular_filterbank + + +def band_widths_from_specs(band_specs): + return [e - i for i, e in band_specs] + + +def check_nonzero_bandwidth(band_specs): + # pprint(band_specs) + for fstart, fend in band_specs: + if fend - fstart <= 0: + raise ValueError("Bands cannot be zero-width") + + +def check_no_overlap(band_specs): + fend_prev = -1 + for fstart_curr, fend_curr in band_specs: + if fstart_curr <= fend_prev: + raise ValueError("Bands cannot overlap") + + +def check_no_gap(band_specs): + fstart, _ = band_specs[0] + assert fstart == 0 + + fend_prev = -1 + for fstart_curr, fend_curr in band_specs: + if fstart_curr - fend_prev > 1: + raise ValueError("Bands cannot leave gap") + fend_prev = fend_curr + + +class BandsplitSpecification: + def __init__(self, nfft: int, fs: int) -> None: + self.fs = fs + self.nfft = nfft + self.nyquist = fs / 2 + self.max_index = nfft // 2 + 1 + + self.split500 = self.hertz_to_index(500) + self.split1k = self.hertz_to_index(1000) + self.split2k = self.hertz_to_index(2000) + self.split4k = self.hertz_to_index(4000) + self.split8k = self.hertz_to_index(8000) + self.split16k = self.hertz_to_index(16000) + self.split20k = self.hertz_to_index(20000) + + self.above20k = [(self.split20k, self.max_index)] + self.above16k = [(self.split16k, self.split20k)] + self.above20k + + def index_to_hertz(self, index: int): + return index * self.fs / self.nfft + + def hertz_to_index(self, hz: float, round: bool = True): + index = hz * self.nfft / self.fs + + if round: + index = int(np.round(index)) + + return index + + def get_band_specs_with_bandwidth( + self, + start_index, + end_index, + bandwidth_hz + ): + band_specs = [] + lower = start_index + + while lower < end_index: + upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz))) + upper = min(upper, end_index) + + band_specs.append((lower, upper)) + lower = upper + + return band_specs + + @abstractmethod + def get_band_specs(self): + raise NotImplementedError + + +class VocalBandsplitSpecification(BandsplitSpecification): + def __init__(self, nfft: int, fs: int, version: str = "7") -> None: + super().__init__(nfft=nfft, fs=fs) + + self.version = version + + def get_band_specs(self): + return getattr(self, f"version{self.version}")() + + @property + def version1(self): + return self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.max_index, bandwidth_hz=1000 + ) + + def version2(self): + below16k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split16k, bandwidth_hz=1000 + ) + below20k = 
self.get_band_specs_with_bandwidth( + start_index=self.split16k, + end_index=self.split20k, + bandwidth_hz=2000 + ) + + return below16k + below20k + self.above20k + + def version3(self): + below8k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split8k, bandwidth_hz=1000 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, + end_index=self.split16k, + bandwidth_hz=2000 + ) + + return below8k + below16k + self.above16k + + def version4(self): + below1k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split1k, bandwidth_hz=100 + ) + below8k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, + end_index=self.split8k, + bandwidth_hz=1000 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, + end_index=self.split16k, + bandwidth_hz=2000 + ) + + return below1k + below8k + below16k + self.above16k + + def version5(self): + below1k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split1k, bandwidth_hz=100 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, + end_index=self.split16k, + bandwidth_hz=1000 + ) + below20k = self.get_band_specs_with_bandwidth( + start_index=self.split16k, + end_index=self.split20k, + bandwidth_hz=2000 + ) + return below1k + below16k + below20k + self.above20k + + def version6(self): + below1k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split1k, bandwidth_hz=100 + ) + below4k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, + end_index=self.split4k, + bandwidth_hz=500 + ) + below8k = self.get_band_specs_with_bandwidth( + start_index=self.split4k, + end_index=self.split8k, + bandwidth_hz=1000 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, + end_index=self.split16k, + bandwidth_hz=2000 + ) + return below1k + below4k + below8k + below16k + self.above16k + + def version7(self): + below1k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split1k, bandwidth_hz=100 + ) + below4k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, + end_index=self.split4k, + bandwidth_hz=250 + ) + below8k = self.get_band_specs_with_bandwidth( + start_index=self.split4k, + end_index=self.split8k, + bandwidth_hz=500 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, + end_index=self.split16k, + bandwidth_hz=1000 + ) + below20k = self.get_band_specs_with_bandwidth( + start_index=self.split16k, + end_index=self.split20k, + bandwidth_hz=2000 + ) + return below1k + below4k + below8k + below16k + below20k + self.above20k + + +class OtherBandsplitSpecification(VocalBandsplitSpecification): + def __init__(self, nfft: int, fs: int) -> None: + super().__init__(nfft=nfft, fs=fs, version="7") + + +class BassBandsplitSpecification(BandsplitSpecification): + def __init__(self, nfft: int, fs: int, version: str = "7") -> None: + super().__init__(nfft=nfft, fs=fs) + + def get_band_specs(self): + below500 = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split500, bandwidth_hz=50 + ) + below1k = self.get_band_specs_with_bandwidth( + start_index=self.split500, + end_index=self.split1k, + bandwidth_hz=100 + ) + below4k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, + end_index=self.split4k, + bandwidth_hz=500 + ) + below8k = self.get_band_specs_with_bandwidth( + start_index=self.split4k, + end_index=self.split8k, + bandwidth_hz=1000 + ) + below16k = 
self.get_band_specs_with_bandwidth( + start_index=self.split8k, + end_index=self.split16k, + bandwidth_hz=2000 + ) + above16k = [(self.split16k, self.max_index)] + + return below500 + below1k + below4k + below8k + below16k + above16k + + +class DrumBandsplitSpecification(BandsplitSpecification): + def __init__(self, nfft: int, fs: int) -> None: + super().__init__(nfft=nfft, fs=fs) + + def get_band_specs(self): + below1k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split1k, bandwidth_hz=50 + ) + below2k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, + end_index=self.split2k, + bandwidth_hz=100 + ) + below4k = self.get_band_specs_with_bandwidth( + start_index=self.split2k, + end_index=self.split4k, + bandwidth_hz=250 + ) + below8k = self.get_band_specs_with_bandwidth( + start_index=self.split4k, + end_index=self.split8k, + bandwidth_hz=500 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, + end_index=self.split16k, + bandwidth_hz=1000 + ) + above16k = [(self.split16k, self.max_index)] + + return below1k + below2k + below4k + below8k + below16k + above16k + + + + +class PerceptualBandsplitSpecification(BandsplitSpecification): + def __init__( + self, + nfft: int, + fs: int, + fbank_fn: Callable[[int, int, float, float, int], torch.Tensor], + n_bands: int, + f_min: float = 0.0, + f_max: float = None + ) -> None: + super().__init__(nfft=nfft, fs=fs) + self.n_bands = n_bands + if f_max is None: + f_max = fs / 2 + + self.filterbank = fbank_fn( + n_bands, fs, f_min, f_max, self.max_index + ) + + weight_per_bin = torch.sum( + self.filterbank, + dim=0, + keepdim=True + ) # (1, n_freqs) + normalized_mel_fb = self.filterbank / weight_per_bin # (n_mels, n_freqs) + + freq_weights = [] + band_specs = [] + for i in range(self.n_bands): + active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist() + if isinstance(active_bins, int): + active_bins = (active_bins, active_bins) + if len(active_bins) == 0: + continue + start_index = active_bins[0] + end_index = active_bins[-1] + 1 + band_specs.append((start_index, end_index)) + freq_weights.append(normalized_mel_fb[i, start_index:end_index]) + + self.freq_weights = freq_weights + self.band_specs = band_specs + + def get_band_specs(self): + return self.band_specs + + def get_freq_weights(self): + return self.freq_weights + + def save_to_file(self, dir_path: str) -> None: + + os.makedirs(dir_path, exist_ok=True) + + import pickle + + with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f: + pickle.dump( + { + "band_specs": self.band_specs, + "freq_weights": self.freq_weights, + "filterbank": self.filterbank, + }, + f, + ) + +def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs): + fb = taF.melscale_fbanks( + n_mels=n_bands, + sample_rate=fs, + f_min=f_min, + f_max=f_max, + n_freqs=n_freqs, + ).T + + fb[0, 0] = 1.0 + + return fb + + +class MelBandsplitSpecification(PerceptualBandsplitSpecification): + def __init__( + self, + nfft: int, + fs: int, + n_bands: int, + f_min: float = 0.0, + f_max: float = None + ) -> None: + super().__init__(fbank_fn=mel_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) + +def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, + scale="constant"): + + nfft = 2 * (n_freqs - 1) + df = fs / nfft + # init freqs + f_max = f_max or fs / 2 + f_min = f_min or 0 + f_min = fs / nfft + + n_octaves = np.log2(f_max / f_min) + n_octaves_per_band = n_octaves / n_bands + bandwidth_mult = np.power(2.0, n_octaves_per_band) + + 
low_midi = max(0, hz_to_midi(f_min)) + high_midi = hz_to_midi(f_max) + midi_points = np.linspace(low_midi, high_midi, n_bands) + hz_pts = midi_to_hz(midi_points) + + low_pts = hz_pts / bandwidth_mult + high_pts = hz_pts * bandwidth_mult + + low_bins = np.floor(low_pts / df).astype(int) + high_bins = np.ceil(high_pts / df).astype(int) + + fb = np.zeros((n_bands, n_freqs)) + + for i in range(n_bands): + fb[i, low_bins[i]:high_bins[i]+1] = 1.0 + + fb[0, :low_bins[0]] = 1.0 + fb[-1, high_bins[-1]+1:] = 1.0 + + return torch.as_tensor(fb) + +class MusicalBandsplitSpecification(PerceptualBandsplitSpecification): + def __init__( + self, + nfft: int, + fs: int, + n_bands: int, + f_min: float = 0.0, + f_max: float = None + ) -> None: + super().__init__(fbank_fn=musical_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) + + +def bark_filterbank( + n_bands, fs, f_min, f_max, n_freqs +): + nfft = 2 * (n_freqs -1) + fb, _ = bark_fbanks.bark_filter_banks( + nfilts=n_bands, + nfft=nfft, + fs=fs, + low_freq=f_min, + high_freq=f_max, + scale="constant" + ) + + return torch.as_tensor(fb) + +class BarkBandsplitSpecification(PerceptualBandsplitSpecification): + def __init__( + self, + nfft: int, + fs: int, + n_bands: int, + f_min: float = 0.0, + f_max: float = None + ) -> None: + super().__init__(fbank_fn=bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) + + +def triangular_bark_filterbank( + n_bands, fs, f_min, f_max, n_freqs +): + + all_freqs = torch.linspace(0, fs // 2, n_freqs) + + # calculate mel freq bins + m_min = hz2bark(f_min) + m_max = hz2bark(f_max) + + m_pts = torch.linspace(m_min, m_max, n_bands + 2) + f_pts = 600 * torch.sinh(m_pts / 6) + + # create filterbank + fb = _create_triangular_filterbank(all_freqs, f_pts) + + fb = fb.T + + first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0] + first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0] + + fb[first_active_band, :first_active_bin] = 1.0 + + return fb + +class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification): + def __init__( + self, + nfft: int, + fs: int, + n_bands: int, + f_min: float = 0.0, + f_max: float = None + ) -> None: + super().__init__(fbank_fn=triangular_bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) + + + +def minibark_filterbank( + n_bands, fs, f_min, f_max, n_freqs +): + fb = bark_filterbank( + n_bands, + fs, + f_min, + f_max, + n_freqs + ) + + fb[fb < np.sqrt(0.5)] = 0.0 + + return fb + +class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification): + def __init__( + self, + nfft: int, + fs: int, + n_bands: int, + f_min: float = 0.0, + f_max: float = None + ) -> None: + super().__init__(fbank_fn=minibark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) + + + + + +def erb_filterbank( + n_bands: int, + fs: int, + f_min: float, + f_max: float, + n_freqs: int, +) -> Tensor: + # freq bins + A = (1000 * np.log(10)) / (24.7 * 4.37) + all_freqs = torch.linspace(0, fs // 2, n_freqs) + + # calculate mel freq bins + m_min = hz2erb(f_min) + m_max = hz2erb(f_max) + + m_pts = torch.linspace(m_min, m_max, n_bands + 2) + f_pts = (torch.pow(10, (m_pts / A)) - 1)/ 0.00437 + + # create filterbank + fb = _create_triangular_filterbank(all_freqs, f_pts) + + fb = fb.T + + + first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0] + first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0] + + fb[first_active_band, :first_active_bin] = 1.0 + + return fb + + + +class 
EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification): + def __init__( + self, + nfft: int, + fs: int, + n_bands: int, + f_min: float = 0.0, + f_max: float = None + ) -> None: + super().__init__(fbank_fn=erb_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) + +if __name__ == "__main__": + import pandas as pd + + band_defs = [] + + for bands in [VocalBandsplitSpecification]: + band_name = bands.__name__.replace("BandsplitSpecification", "") + + mbs = bands(nfft=2048, fs=44100).get_band_specs() + + for i, (f_min, f_max) in enumerate(mbs): + band_defs.append({ + "band": band_name, + "band_index": i, + "f_min": f_min, + "f_max": f_max + }) + + df = pd.DataFrame(band_defs) + df.to_csv("vox7bands.csv", index=False) \ No newline at end of file diff --git a/separator/models/bandit/core/model/bsrnn/wrapper.py b/separator/models/bandit/core/model/bsrnn/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..a31c087db33eb215effa3c3fc492999c5672c55e --- /dev/null +++ b/separator/models/bandit/core/model/bsrnn/wrapper.py @@ -0,0 +1,882 @@ +from pprint import pprint +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch import nn + +from models.bandit.core.model._spectral import _SpectralComponent +from models.bandit.core.model.bsrnn.utils import ( + BarkBandsplitSpecification, BassBandsplitSpecification, + DrumBandsplitSpecification, + EquivalentRectangularBandsplitSpecification, MelBandsplitSpecification, + MusicalBandsplitSpecification, OtherBandsplitSpecification, + TriangularBarkBandsplitSpecification, VocalBandsplitSpecification, +) +from .core import ( + MultiSourceMultiMaskBandSplitCoreConv, + MultiSourceMultiMaskBandSplitCoreRNN, + MultiSourceMultiMaskBandSplitCoreTransformer, + MultiSourceMultiPatchingMaskBandSplitCoreRNN, SingleMaskBandsplitCoreRNN, + SingleMaskBandsplitCoreTransformer, +) + +import pytorch_lightning as pl + +def get_band_specs(band_specs, n_fft, fs, n_bands=None): + if band_specs in ["dnr:speech", "dnr:vox7", "musdb:vocals", "musdb:vox7"]: + bsm = VocalBandsplitSpecification( + nfft=n_fft, fs=fs + ).get_band_specs() + freq_weights = None + overlapping_band = False + elif "tribark" in band_specs: + assert n_bands is not None + specs = TriangularBarkBandsplitSpecification( + nfft=n_fft, + fs=fs, + n_bands=n_bands + ) + bsm = specs.get_band_specs() + freq_weights = specs.get_freq_weights() + overlapping_band = True + elif "bark" in band_specs: + assert n_bands is not None + specs = BarkBandsplitSpecification( + nfft=n_fft, + fs=fs, + n_bands=n_bands + ) + bsm = specs.get_band_specs() + freq_weights = specs.get_freq_weights() + overlapping_band = True + elif "erb" in band_specs: + assert n_bands is not None + specs = EquivalentRectangularBandsplitSpecification( + nfft=n_fft, + fs=fs, + n_bands=n_bands + ) + bsm = specs.get_band_specs() + freq_weights = specs.get_freq_weights() + overlapping_band = True + elif "musical" in band_specs: + assert n_bands is not None + specs = MusicalBandsplitSpecification( + nfft=n_fft, + fs=fs, + n_bands=n_bands + ) + bsm = specs.get_band_specs() + freq_weights = specs.get_freq_weights() + overlapping_band = True + elif band_specs == "dnr:mel" or "mel" in band_specs: + assert n_bands is not None + specs = MelBandsplitSpecification( + nfft=n_fft, + fs=fs, + n_bands=n_bands + ) + bsm = specs.get_band_specs() + freq_weights = specs.get_freq_weights() + overlapping_band = True + else: + raise NameError + + return bsm, freq_weights, overlapping_band + + 
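+# get_band_specs returns a (band_specs, freq_weights, overlapping_band) triple:
+# the hand-crafted vocal presets ("dnr:speech"/"dnr:vox7", "musdb:vocals", ...)
+# yield contiguous, non-overlapping bands with freq_weights=None, while the
+# perceptual scales ("tribark", "bark", "erb", "musical", "mel") require n_bands
+# and yield overlapping bands with per-bin weights for the overlapping mask
+# estimator.
+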
+def get_band_specs_map(band_specs_map, n_fft, fs, n_bands=None): + if band_specs_map == "musdb:all": + bsm = { + "vocals": VocalBandsplitSpecification( + nfft=n_fft, fs=fs + ).get_band_specs(), + "drums": DrumBandsplitSpecification( + nfft=n_fft, fs=fs + ).get_band_specs(), + "bass": BassBandsplitSpecification( + nfft=n_fft, fs=fs + ).get_band_specs(), + "other": OtherBandsplitSpecification( + nfft=n_fft, fs=fs + ).get_band_specs(), + } + freq_weights = None + overlapping_band = False + elif band_specs_map == "dnr:vox7": + bsm_, freq_weights, overlapping_band = get_band_specs( + "dnr:speech", n_fft, fs, n_bands + ) + bsm = { + "speech": bsm_, + "music": bsm_, + "effects": bsm_ + } + elif "dnr:vox7:" in band_specs_map: + stem = band_specs_map.split(":")[-1] + bsm_, freq_weights, overlapping_band = get_band_specs( + "dnr:speech", n_fft, fs, n_bands + ) + bsm = { + stem: bsm_ + } + else: + raise NameError + + return bsm, freq_weights, overlapping_band + + +class BandSplitWrapperBase(pl.LightningModule): + bsrnn: nn.Module + + def __init__(self, **kwargs): + super().__init__() + + +class SingleMaskMultiSourceBandSplitBase( + BandSplitWrapperBase, + _SpectralComponent +): + def __init__( + self, + band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]], + fs: int = 44100, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + ) -> None: + super().__init__( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + ) + + if isinstance(band_specs_map, str): + self.band_specs_map, self.freq_weights, self.overlapping_band = get_band_specs_map( + band_specs_map, + n_fft, + fs, + n_bands=n_bands + ) + + self.stems = list(self.band_specs_map.keys()) + + def forward(self, batch): + audio = batch["audio"] + + with torch.no_grad(): + batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in + audio} + + X = batch["spectrogram"]["mixture"] + length = batch["audio"]["mixture"].shape[-1] + + output = {"spectrogram": {}, "audio": {}} + + for stem, bsrnn in self.bsrnn.items(): + S = bsrnn(X) + s = self.istft(S, length) + output["spectrogram"][stem] = S + output["audio"][stem] = s + + return batch, output + + +class MultiMaskMultiSourceBandSplitBase( + BandSplitWrapperBase, + _SpectralComponent +): + def __init__( + self, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + fs: int = 44100, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + ) -> None: + super().__init__( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + ) + + if isinstance(band_specs, str): + self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs( + band_specs, + n_fft, + fs, + n_bands + ) + + self.stems = stems + + def forward(self, batch): + # with torch.no_grad(): + audio = 
batch["audio"] + cond = batch.get("condition", None) + with torch.no_grad(): + batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in + audio} + + X = batch["spectrogram"]["mixture"] + length = batch["audio"]["mixture"].shape[-1] + + output = self.bsrnn(X, cond=cond) + output["audio"] = {} + + for stem, S in output["spectrogram"].items(): + s = self.istft(S, length) + output["audio"][stem] = s + + return batch, output + + +class MultiMaskMultiSourceBandSplitBaseSimple( + BandSplitWrapperBase, + _SpectralComponent +): + def __init__( + self, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + fs: int = 44100, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + ) -> None: + super().__init__( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + ) + + if isinstance(band_specs, str): + self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs( + band_specs, + n_fft, + fs, + n_bands + ) + + self.stems = stems + + def forward(self, batch): + with torch.no_grad(): + X = self.stft(batch) + length = batch.shape[-1] + output = self.bsrnn(X, cond=None) + res = [] + for stem, S in output["spectrogram"].items(): + s = self.istft(S, length) + res.append(s) + res = torch.stack(res, dim=1) + return res + + +class SingleMaskMultiSourceBandSplitRNN(SingleMaskMultiSourceBandSplitBase): + def __init__( + self, + in_channel: int, + band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]], + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + ) -> None: + super().__init__( + band_specs_map=band_specs_map, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + ) + + self.bsrnn = nn.ModuleDict( + { + src: SingleMaskBandsplitCoreRNN( + band_specs=specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + for src, specs in self.band_specs_map.items() + } + ) + + +class 
SingleMaskMultiSourceBandSplitTransformer( + SingleMaskMultiSourceBandSplitBase +): + def __init__( + self, + in_channel: int, + band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]], + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + tf_dropout: float = 0.0, + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + ) -> None: + super().__init__( + band_specs_map=band_specs_map, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + ) + + self.bsrnn = nn.ModuleDict( + { + src: SingleMaskBandsplitCoreTransformer( + band_specs=specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + tf_dropout=tf_dropout, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + for src, specs in self.band_specs_map.items() + } + ) + + +class MultiMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + cond_dim: int = 0, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + use_freq_weights: bool = True, + normalize_input: bool = False, + mult_add_mask: bool = False, + freeze_encoder: bool = False, + ) -> None: + super().__init__( + stems=stems, + band_specs=band_specs, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + n_bands=n_bands, + ) + + self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN( + stems=stems, + band_specs=self.band_specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + 
treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=self.overlapping_band, + freq_weights=self.freq_weights, + n_freq=n_fft // 2 + 1, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask + ) + + self.normalize_input = normalize_input + self.cond_dim = cond_dim + + if freeze_encoder: + for param in self.bsrnn.band_split.parameters(): + param.requires_grad = False + + for param in self.bsrnn.tf_model.parameters(): + param.requires_grad = False + + +class MultiMaskMultiSourceBandSplitRNNSimple(MultiMaskMultiSourceBandSplitBaseSimple): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + cond_dim: int = 0, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + use_freq_weights: bool = True, + normalize_input: bool = False, + mult_add_mask: bool = False, + freeze_encoder: bool = False, + ) -> None: + super().__init__( + stems=stems, + band_specs=band_specs, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + n_bands=n_bands, + ) + + self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN( + stems=stems, + band_specs=self.band_specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=self.overlapping_band, + freq_weights=self.freq_weights, + n_freq=n_fft // 2 + 1, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask + ) + + self.normalize_input = normalize_input + self.cond_dim = cond_dim + + if freeze_encoder: + for param in self.bsrnn.band_split.parameters(): + param.requires_grad = False + + for param in self.bsrnn.tf_model.parameters(): + param.requires_grad = False + + +class MultiMaskMultiSourceBandSplitTransformer( + MultiMaskMultiSourceBandSplitBase +): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + 
treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + cond_dim: int = 0, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + use_freq_weights: bool = True, + normalize_input: bool = False, + mult_add_mask: bool = False + ) -> None: + super().__init__( + stems=stems, + band_specs=band_specs, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + n_bands=n_bands, + ) + + self.bsrnn = MultiSourceMultiMaskBandSplitCoreTransformer( + stems=stems, + band_specs=self.band_specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=self.overlapping_band, + freq_weights=self.freq_weights, + n_freq=n_fft // 2 + 1, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask + ) + + + +class MultiMaskMultiSourceBandSplitConv( + MultiMaskMultiSourceBandSplitBase +): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + cond_dim: int = 0, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + use_freq_weights: bool = True, + normalize_input: bool = False, + mult_add_mask: bool = False + ) -> None: + super().__init__( + stems=stems, + band_specs=band_specs, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + n_bands=n_bands, + ) + + self.bsrnn = MultiSourceMultiMaskBandSplitCoreConv( + stems=stems, + band_specs=self.band_specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + 
emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + mlp_dim=mlp_dim, + cond_dim=cond_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=self.overlapping_band, + freq_weights=self.freq_weights, + n_freq=n_fft // 2 + 1, + use_freq_weights=use_freq_weights, + mult_add_mask=mult_add_mask + ) +class PatchingMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase): + def __init__( + self, + in_channel: int, + stems: List[str], + band_specs: Union[str, List[Tuple[float, float]]], + kernel_norm_mlp_version: int = 1, + mask_kernel_freq: int = 3, + mask_kernel_time: int = 3, + conv_kernel_freq: int = 1, + conv_kernel_time: int = 1, + fs: int = 44100, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + n_bands: int = None, + ) -> None: + super().__init__( + stems=stems, + band_specs=band_specs, + fs=fs, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + n_bands=n_bands, + ) + + self.bsrnn = MultiSourceMultiPatchingMaskBandSplitCoreRNN( + stems=stems, + band_specs=self.band_specs, + in_channel=in_channel, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + overlapping_band=self.overlapping_band, + freq_weights=self.freq_weights, + n_freq=n_fft // 2 + 1, + mask_kernel_freq=mask_kernel_freq, + mask_kernel_time=mask_kernel_time, + conv_kernel_freq=conv_kernel_freq, + conv_kernel_time=conv_kernel_time, + kernel_norm_mlp_version=kernel_norm_mlp_version, + ) diff --git a/separator/models/bandit/core/utils/__init__.py b/separator/models/bandit/core/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/separator/models/bandit/core/utils/audio.py b/separator/models/bandit/core/utils/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..adae756bf2b02a994a42fcc007da1e1ff7bb6cfb --- /dev/null +++ b/separator/models/bandit/core/utils/audio.py @@ -0,0 +1,463 @@ +from collections import defaultdict + +from tqdm.auto import tqdm +from typing import Callable, Dict, List, Optional, Tuple + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + + +@torch.jit.script +def merge( + combined: torch.Tensor, + original_batch_size: int, + 
n_channel: int, + n_chunks: int, + chunk_size: int, ): + combined = torch.reshape( + combined, + (original_batch_size, n_chunks, n_channel, chunk_size) + ) + combined = torch.permute(combined, (0, 2, 3, 1)).reshape( + original_batch_size * n_channel, + chunk_size, + n_chunks + ) + + return combined + + +@torch.jit.script +def unfold( + padded_audio: torch.Tensor, + original_batch_size: int, + n_channel: int, + chunk_size: int, + hop_size: int + ) -> torch.Tensor: + + unfolded_input = F.unfold( + padded_audio[:, :, None, :], + kernel_size=(1, chunk_size), + stride=(1, hop_size) + ) + + _, _, n_chunks = unfolded_input.shape + unfolded_input = unfolded_input.view( + original_batch_size, + n_channel, + chunk_size, + n_chunks + ) + unfolded_input = torch.permute( + unfolded_input, + (0, 3, 1, 2) + ).reshape( + original_batch_size * n_chunks, + n_channel, + chunk_size + ) + + return unfolded_input + + +@torch.jit.script +# @torch.compile +def merge_chunks_all( + combined: torch.Tensor, + original_batch_size: int, + n_channel: int, + n_samples: int, + n_padded_samples: int, + n_chunks: int, + chunk_size: int, + hop_size: int, + edge_frame_pad_sizes: Tuple[int, int], + standard_window: torch.Tensor, + first_window: torch.Tensor, + last_window: torch.Tensor +): + combined = merge( + combined, + original_batch_size, + n_channel, + n_chunks, + chunk_size + ) + + combined = combined * standard_window[:, None].to(combined.device) + + combined = F.fold( + combined.to(torch.float32), output_size=(1, n_padded_samples), + kernel_size=(1, chunk_size), + stride=(1, hop_size) + ) + + combined = combined.view( + original_batch_size, + n_channel, + n_padded_samples + ) + + pad_front, pad_back = edge_frame_pad_sizes + combined = combined[..., pad_front:-pad_back] + + combined = combined[..., :n_samples] + + return combined + + # @torch.jit.script + + +def merge_chunks_edge( + combined: torch.Tensor, + original_batch_size: int, + n_channel: int, + n_samples: int, + n_padded_samples: int, + n_chunks: int, + chunk_size: int, + hop_size: int, + edge_frame_pad_sizes: Tuple[int, int], + standard_window: torch.Tensor, + first_window: torch.Tensor, + last_window: torch.Tensor +): + combined = merge( + combined, + original_batch_size, + n_channel, + n_chunks, + chunk_size + ) + + combined[..., 0] = combined[..., 0] * first_window + combined[..., -1] = combined[..., -1] * last_window + combined[..., 1:-1] = combined[..., + 1:-1] * standard_window[:, None] + + combined = F.fold( + combined, output_size=(1, n_padded_samples), + kernel_size=(1, chunk_size), + stride=(1, hop_size) + ) + + combined = combined.view( + original_batch_size, + n_channel, + n_padded_samples + ) + + combined = combined[..., :n_samples] + + return combined + + +class BaseFader(nn.Module): + def __init__( + self, + chunk_size_second: float, + hop_size_second: float, + fs: int, + fade_edge_frames: bool, + batch_size: int, + ) -> None: + super().__init__() + + self.chunk_size = int(chunk_size_second * fs) + self.hop_size = int(hop_size_second * fs) + self.overlap_size = self.chunk_size - self.hop_size + self.fade_edge_frames = fade_edge_frames + self.batch_size = batch_size + + # @torch.jit.script + def prepare(self, audio): + + if self.fade_edge_frames: + audio = F.pad(audio, self.edge_frame_pad_sizes, mode="reflect") + + n_samples = audio.shape[-1] + n_chunks = int( + np.ceil((n_samples - self.chunk_size) / self.hop_size) + 1 + ) + + padded_size = (n_chunks - 1) * self.hop_size + self.chunk_size + pad_size = padded_size - n_samples + + padded_audio = 
F.pad(audio, (0, pad_size)) + + return padded_audio, n_chunks + + def forward( + self, + audio: torch.Tensor, + model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]], + ): + + original_dtype = audio.dtype + original_device = audio.device + + audio = audio.to("cpu") + + original_batch_size, n_channel, n_samples = audio.shape + padded_audio, n_chunks = self.prepare(audio) + del audio + n_padded_samples = padded_audio.shape[-1] + + if n_channel > 1: + padded_audio = padded_audio.view( + original_batch_size * n_channel, 1, n_padded_samples + ) + + unfolded_input = unfold( + padded_audio, + original_batch_size, + n_channel, + self.chunk_size, self.hop_size + ) + + n_total_chunks, n_channel, chunk_size = unfolded_input.shape + + n_batch = np.ceil(n_total_chunks / self.batch_size).astype(int) + + chunks_in = [ + unfolded_input[ + b * self.batch_size:(b + 1) * self.batch_size, ...].clone() + for b in range(n_batch) + ] + + all_chunks_out = defaultdict( + lambda: torch.zeros_like( + unfolded_input, device="cpu" + ) + ) + + # for b, cin in enumerate(tqdm(chunks_in)): + for b, cin in enumerate(chunks_in): + if torch.allclose(cin, torch.tensor(0.0)): + del cin + continue + + chunks_out = model_fn(cin.to(original_device)) + del cin + for s, c in chunks_out.items(): + all_chunks_out[s][b * self.batch_size:(b + 1) * self.batch_size, + ...] = c.cpu() + del chunks_out + + del unfolded_input + del padded_audio + + if self.fade_edge_frames: + fn = merge_chunks_all + else: + fn = merge_chunks_edge + outputs = {} + + torch.cuda.empty_cache() + + for s, c in all_chunks_out.items(): + combined: torch.Tensor = fn( + c, + original_batch_size, + n_channel, + n_samples, + n_padded_samples, + n_chunks, + self.chunk_size, + self.hop_size, + self.edge_frame_pad_sizes, + self.standard_window, + self.__dict__.get("first_window", self.standard_window), + self.__dict__.get("last_window", self.standard_window) + ) + + outputs[s] = combined.to( + dtype=original_dtype, + device=original_device + ) + + return { + "audio": outputs + } + # + # def old_forward( + # self, + # audio: torch.Tensor, + # model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]], + # ): + # + # n_samples = audio.shape[-1] + # original_batch_size = audio.shape[0] + # + # padded_audio, n_chunks = self.prepare(audio) + # + # ndim = padded_audio.ndim + # broadcaster = [1 for _ in range(ndim - 1)] + [self.chunk_size] + # + # outputs = defaultdict( + # lambda: torch.zeros_like( + # padded_audio, device=audio.device, dtype=torch.float64 + # ) + # ) + # + # all_chunks_out = [] + # len_chunks_in = [] + # + # batch_size_ = int(self.batch_size // original_batch_size) + # for b in range(int(np.ceil(n_chunks / batch_size_))): + # chunks_in = [] + # for j in range(batch_size_): + # i = b * batch_size_ + j + # if i == n_chunks: + # break + # + # start = i * hop_size + # end = start + self.chunk_size + # chunk_in = padded_audio[..., start:end] + # chunks_in.append(chunk_in) + # + # chunks_in = torch.concat(chunks_in, dim=0) + # chunks_out = model_fn(chunks_in) + # all_chunks_out.append(chunks_out) + # len_chunks_in.append(len(chunks_in)) + # + # for b, (chunks_out, lci) in enumerate( + # zip(all_chunks_out, len_chunks_in) + # ): + # for stem in chunks_out: + # for j in range(lci // original_batch_size): + # i = b * batch_size_ + j + # + # if self.fade_edge_frames: + # window = self.standard_window + # else: + # if i == 0: + # window = self.first_window + # elif i == n_chunks - 1: + # window = self.last_window + # else: + # window = self.standard_window + # + # 
start = i * hop_size + # end = start + self.chunk_size + # + # chunk_out = chunks_out[stem][j * original_batch_size: (j + 1) * original_batch_size, + # ...] + # contrib = window.view(*broadcaster) * chunk_out + # outputs[stem][..., start:end] = ( + # outputs[stem][..., start:end] + contrib + # ) + # + # if self.fade_edge_frames: + # pad_front, pad_back = self.edge_frame_pad_sizes + # outputs = {k: v[..., pad_front:-pad_back] for k, v in + # outputs.items()} + # + # outputs = {k: v[..., :n_samples].to(audio.dtype) for k, v in + # outputs.items()} + # + # return { + # "audio": outputs + # } + + +class LinearFader(BaseFader): + def __init__( + self, + chunk_size_second: float, + hop_size_second: float, + fs: int, + fade_edge_frames: bool = False, + batch_size: int = 1, + ) -> None: + + assert hop_size_second >= chunk_size_second / 2 + + super().__init__( + chunk_size_second=chunk_size_second, + hop_size_second=hop_size_second, + fs=fs, + fade_edge_frames=fade_edge_frames, + batch_size=batch_size, + ) + + in_fade = torch.linspace(0.0, 1.0, self.overlap_size + 1)[:-1] + out_fade = torch.linspace(1.0, 0.0, self.overlap_size + 1)[1:] + center_ones = torch.ones(self.chunk_size - 2 * self.overlap_size) + inout_ones = torch.ones(self.overlap_size) + + # using nn.Parameters allows lightning to take care of devices for us + self.register_buffer( + "standard_window", + torch.concat([in_fade, center_ones, out_fade]) + ) + + self.fade_edge_frames = fade_edge_frames + self.edge_frame_pad_size = (self.overlap_size, self.overlap_size) + + if not self.fade_edge_frames: + self.first_window = nn.Parameter( + torch.concat([inout_ones, center_ones, out_fade]), + requires_grad=False + ) + self.last_window = nn.Parameter( + torch.concat([in_fade, center_ones, inout_ones]), + requires_grad=False + ) + + +class OverlapAddFader(BaseFader): + def __init__( + self, + window_type: str, + chunk_size_second: float, + hop_size_second: float, + fs: int, + batch_size: int = 1, + ) -> None: + assert (chunk_size_second / hop_size_second) % 2 == 0 + assert int(chunk_size_second * fs) % 2 == 0 + + super().__init__( + chunk_size_second=chunk_size_second, + hop_size_second=hop_size_second, + fs=fs, + fade_edge_frames=True, + batch_size=batch_size, + ) + + self.hop_multiplier = self.chunk_size / (2 * self.hop_size) + # print(f"hop multiplier: {self.hop_multiplier}") + + self.edge_frame_pad_sizes = ( + 2 * self.overlap_size, + 2 * self.overlap_size + ) + + self.register_buffer( + "standard_window", torch.windows.__dict__[window_type]( + self.chunk_size, sym=False, # dtype=torch.float64 + ) / self.hop_multiplier + ) + + +if __name__ == "__main__": + import torchaudio as ta + fs = 44100 + ola = OverlapAddFader( + "hann", + 6.0, + 1.0, + fs, + batch_size=16 + ) + audio_, _ = ta.load( + "$DATA_ROOT/MUSDB18/HQ/canonical/test/BKS - Too " + "Much/vocals.wav" + ) + audio_ = audio_[None, ...] 
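+    # With the identity "model" passed below, Hann-window overlap-add (the hop
+    # splits the chunk into an even number of steps, as the constructor asserts)
+    # should reconstruct the input, so the allclose check is expected to print
+    # True up to floating-point tolerance.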
+ out = ola(audio_, lambda x: {"stem": x})["audio"]["stem"] + print(torch.allclose(out, audio_)) diff --git a/separator/models/bandit/model_from_config.py b/separator/models/bandit/model_from_config.py new file mode 100644 index 0000000000000000000000000000000000000000..00ea586d7dfdbd6b89d6b7f2f400e6c8d04da5e4 --- /dev/null +++ b/separator/models/bandit/model_from_config.py @@ -0,0 +1,31 @@ +import sys +import os.path +import torch + +code_path = os.path.dirname(os.path.abspath(__file__)) + '/' +sys.path.append(code_path) + +import yaml +from ml_collections import ConfigDict + +torch.set_float32_matmul_precision("medium") + + +def get_model( + config_path, + weights_path, + device, +): + from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple + + f = open(config_path) + config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader)) + f.close() + + model = MultiMaskMultiSourceBandSplitRNNSimple( + **config.model + ) + d = torch.load(code_path + 'model_bandit_plus_dnr_sdr_11.47.chpt') + model.load_state_dict(d) + model.to(device) + return model, config diff --git a/separator/models/bandit_v2/bandit.py b/separator/models/bandit_v2/bandit.py new file mode 100644 index 0000000000000000000000000000000000000000..ac4e13f479891065cf1f7dae0720721128347979 --- /dev/null +++ b/separator/models/bandit_v2/bandit.py @@ -0,0 +1,367 @@ +from typing import Dict, List, Optional + +import torch +import torchaudio as ta +from torch import nn +import pytorch_lightning as pl + +from .bandsplit import BandSplitModule +from .maskestim import OverlappingMaskEstimationModule +from .tfmodel import SeqBandModellingModule +from .utils import MusicalBandsplitSpecification + + + +class BaseEndToEndModule(pl.LightningModule): + def __init__( + self, + ) -> None: + super().__init__() + + +class BaseBandit(BaseEndToEndModule): + def __init__( + self, + in_channels: int, + fs: int, + band_type: str = "musical", + n_bands: int = 64, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + ): + super().__init__() + + self.in_channels = in_channels + + self.instantitate_spectral( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + normalized=normalized, + center=center, + pad_mode=pad_mode, + onesided=onesided, + ) + + self.instantiate_bandsplit( + in_channels=in_channels, + band_type=band_type, + n_bands=n_bands, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, + n_fft=n_fft, + fs=fs, + ) + + self.instantiate_tf_modelling( + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ) + + def instantitate_spectral( + self, + n_fft: int = 2048, + win_length: Optional[int] = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Optional[Dict] = None, + power: Optional[int] = None, + normalized: 
bool = True, + center: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + ): + assert power is None + + window_fn = torch.__dict__[window_fn] + + self.stft = ta.transforms.Spectrogram( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + pad_mode=pad_mode, + pad=0, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + normalized=normalized, + center=center, + onesided=onesided, + ) + + self.istft = ta.transforms.InverseSpectrogram( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + pad_mode=pad_mode, + pad=0, + window_fn=window_fn, + wkwargs=wkwargs, + normalized=normalized, + center=center, + onesided=onesided, + ) + + def instantiate_bandsplit( + self, + in_channels: int, + band_type: str = "musical", + n_bands: int = 64, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + emb_dim: int = 128, + n_fft: int = 2048, + fs: int = 44100, + ): + assert band_type == "musical" + + self.band_specs = MusicalBandsplitSpecification( + nfft=n_fft, fs=fs, n_bands=n_bands + ) + + self.band_split = BandSplitModule( + in_channels=in_channels, + band_specs=self.band_specs.get_band_specs(), + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + emb_dim=emb_dim, + ) + + def instantiate_tf_modelling( + self, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + ): + try: + self.tf_model = torch.compile( + SeqBandModellingModule( + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ), + disable=True, + ) + except Exception as e: + self.tf_model = SeqBandModellingModule( + n_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ) + + def mask(self, x, m): + return x * m + + def forward(self, batch, mode="train"): + # Model takes mono as input we give stereo, so we do process of each channel independently + init_shape = batch.shape + if not isinstance(batch, dict): + mono = batch.view(-1, 1, batch.shape[-1]) + batch = { + "mixture": { + "audio": mono + } + } + + with torch.no_grad(): + mixture = batch["mixture"]["audio"] + + x = self.stft(mixture) + batch["mixture"]["spectrogram"] = x + + if "sources" in batch.keys(): + for stem in batch["sources"].keys(): + s = batch["sources"][stem]["audio"] + s = self.stft(s) + batch["sources"][stem]["spectrogram"] = s + + batch = self.separate(batch) + + if 1: + b = [] + for s in self.stems: + # We need to obtain stereo again + r = batch['estimates'][s]['audio'].view(-1, init_shape[1], init_shape[2]) + b.append(r) + # And we need to return back tensor and not independent stems + batch = torch.stack(b, dim=1) + return batch + + def encode(self, batch): + x = batch["mixture"]["spectrogram"] + length = batch["mixture"]["audio"].shape[-1] + + z = self.band_split(x) # (batch, emb_dim, n_band, n_time) + q = self.tf_model(z) # (batch, emb_dim, n_band, n_time) + + return x, q, length + + def separate(self, batch): + raise NotImplementedError + + +class Bandit(BaseBandit): + def __init__( + self, + in_channels: int, + stems: List[str], + band_type: str = "musical", + n_bands: int = 64, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: 
bool = False, + treat_channel_as_feature: bool = True, + n_sqm_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + mlp_dim: int = 512, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Dict | None = None, + complex_mask: bool = True, + use_freq_weights: bool = True, + n_fft: int = 2048, + win_length: int | None = 2048, + hop_length: int = 512, + window_fn: str = "hann_window", + wkwargs: Dict | None = None, + power: int | None = None, + center: bool = True, + normalized: bool = True, + pad_mode: str = "constant", + onesided: bool = True, + fs: int = 44100, + stft_precisions="32", + bandsplit_precisions="bf16", + tf_model_precisions="bf16", + mask_estim_precisions="bf16", + ): + super().__init__( + in_channels=in_channels, + band_type=band_type, + n_bands=n_bands, + require_no_overlap=require_no_overlap, + require_no_gap=require_no_gap, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + n_sqm_modules=n_sqm_modules, + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window_fn=window_fn, + wkwargs=wkwargs, + power=power, + center=center, + normalized=normalized, + pad_mode=pad_mode, + onesided=onesided, + fs=fs, + ) + + self.stems = stems + + self.instantiate_mask_estim( + in_channels=in_channels, + stems=stems, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + n_freq=n_fft // 2 + 1, + use_freq_weights=use_freq_weights, + ) + + def instantiate_mask_estim( + self, + in_channels: int, + stems: List[str], + emb_dim: int, + mlp_dim: int, + hidden_activation: str, + hidden_activation_kwargs: Optional[Dict] = None, + complex_mask: bool = True, + n_freq: Optional[int] = None, + use_freq_weights: bool = False, + ): + if hidden_activation_kwargs is None: + hidden_activation_kwargs = {} + + assert n_freq is not None + + self.mask_estim = nn.ModuleDict( + { + stem: OverlappingMaskEstimationModule( + band_specs=self.band_specs.get_band_specs(), + freq_weights=self.band_specs.get_freq_weights(), + n_freq=n_freq, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + in_channels=in_channels, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + use_freq_weights=use_freq_weights, + ) + for stem in stems + } + ) + + def separate(self, batch): + batch["estimates"] = {} + + x, q, length = self.encode(batch) + + for stem, mem in self.mask_estim.items(): + m = mem(q) + + s = self.mask(x, m.to(x.dtype)) + s = torch.reshape(s, x.shape) + batch["estimates"][stem] = { + "audio": self.istft(s, length), + "spectrogram": s, + } + + return batch + diff --git a/separator/models/bandit_v2/bandsplit.py b/separator/models/bandit_v2/bandsplit.py new file mode 100644 index 0000000000000000000000000000000000000000..a14ea52bfa318264d536c9f934d0e28db63e15dc --- /dev/null +++ b/separator/models/bandit_v2/bandsplit.py @@ -0,0 +1,130 @@ +from typing import List, Tuple + +import torch +from torch import nn +from torch.utils.checkpoint import checkpoint_sequential + +from .utils import ( + band_widths_from_specs, + check_no_gap, + check_no_overlap, + check_nonzero_bandwidth, +) + + +class NormFC(nn.Module): + def __init__( + self, + emb_dim: int, + bandwidth: int, + in_channels: int, + normalize_channel_independently: 
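# Construction and shape sketch for the Bandit model above (illustrative values; the
# stem names are placeholders in the usual DnR style, and separator/ is assumed to be
# on sys.path).
#
#   import torch
#   from models.bandit_v2.bandit import Bandit
#
#   model = Bandit(in_channels=1, stems=["speech", "music", "effects"], fs=44100)
#   x = torch.randn(2, 2, 44100 * 4)   # (batch, channels, samples), stereo input is fine
#   y = model(x)                       # (batch, n_stems, channels, samples)
#
# BaseBandit.forward folds (batch, channels, samples) into (batch * channels, 1, samples),
# separates each channel as mono, then re-stacks the per-stem estimates.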
bool = False, + treat_channel_as_feature: bool = True, + ) -> None: + super().__init__() + + if not treat_channel_as_feature: + raise NotImplementedError + + self.treat_channel_as_feature = treat_channel_as_feature + + if normalize_channel_independently: + raise NotImplementedError + + reim = 2 + + norm = nn.LayerNorm(in_channels * bandwidth * reim) + + fc_in = bandwidth * reim + + if treat_channel_as_feature: + fc_in *= in_channels + else: + assert emb_dim % in_channels == 0 + emb_dim = emb_dim // in_channels + + fc = nn.Linear(fc_in, emb_dim) + + self.combined = nn.Sequential(norm, fc) + + def forward(self, xb): + return checkpoint_sequential(self.combined, 1, xb, use_reentrant=False) + + +class BandSplitModule(nn.Module): + def __init__( + self, + band_specs: List[Tuple[float, float]], + emb_dim: int, + in_channels: int, + require_no_overlap: bool = False, + require_no_gap: bool = True, + normalize_channel_independently: bool = False, + treat_channel_as_feature: bool = True, + ) -> None: + super().__init__() + + check_nonzero_bandwidth(band_specs) + + if require_no_gap: + check_no_gap(band_specs) + + if require_no_overlap: + check_no_overlap(band_specs) + + self.band_specs = band_specs + # list of [fstart, fend) in index. + # Note that fend is exclusive. + self.band_widths = band_widths_from_specs(band_specs) + self.n_bands = len(band_specs) + self.emb_dim = emb_dim + + try: + self.norm_fc_modules = nn.ModuleList( + [ # type: ignore + torch.compile( + NormFC( + emb_dim=emb_dim, + bandwidth=bw, + in_channels=in_channels, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + ), + disable=True, + ) + for bw in self.band_widths + ] + ) + except Exception as e: + self.norm_fc_modules = nn.ModuleList( + [ # type: ignore + NormFC( + emb_dim=emb_dim, + bandwidth=bw, + in_channels=in_channels, + normalize_channel_independently=normalize_channel_independently, + treat_channel_as_feature=treat_channel_as_feature, + ) + for bw in self.band_widths + ] + ) + + def forward(self, x: torch.Tensor): + # x = complex spectrogram (batch, in_chan, n_freq, n_time) + + batch, in_chan, band_width, n_time = x.shape + + z = torch.zeros( + size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device + ) + + x = torch.permute(x, (0, 3, 1, 2)).contiguous() + + for i, nfm in enumerate(self.norm_fc_modules): + fstart, fend = self.band_specs[i] + xb = x[:, :, :, fstart:fend] + xb = torch.view_as_real(xb) + xb = torch.reshape(xb, (batch, n_time, -1)) + z[:, i, :, :] = nfm(xb) + + return z diff --git a/separator/models/bandit_v2/film.py b/separator/models/bandit_v2/film.py new file mode 100644 index 0000000000000000000000000000000000000000..e30795332ea0e06865ea3d883767db17bb02353c --- /dev/null +++ b/separator/models/bandit_v2/film.py @@ -0,0 +1,25 @@ +from torch import nn +import torch + +class FiLM(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, gamma, beta): + return gamma * x + beta + + +class BTFBroadcastedFiLM(nn.Module): + def __init__(self): + super().__init__() + self.film = FiLM() + + def forward(self, x, gamma, beta): + + gamma = gamma[None, None, None, :] + beta = beta[None, None, None, :] + + return self.film(x, gamma, beta) + + + \ No newline at end of file diff --git a/separator/models/bandit_v2/maskestim.py b/separator/models/bandit_v2/maskestim.py new file mode 100644 index 0000000000000000000000000000000000000000..65215d86a5e94dafdb71744aafadf7aaab93330d --- /dev/null +++ 
b/separator/models/bandit_v2/maskestim.py @@ -0,0 +1,281 @@ +from typing import Dict, List, Optional, Tuple, Type + +import torch +from torch import nn +from torch.nn.modules import activation +from torch.utils.checkpoint import checkpoint_sequential + +from .utils import ( + band_widths_from_specs, + check_no_gap, + check_no_overlap, + check_nonzero_bandwidth, +) + + +class BaseNormMLP(nn.Module): + def __init__( + self, + emb_dim: int, + mlp_dim: int, + bandwidth: int, + in_channels: Optional[int], + hidden_activation: str = "Tanh", + hidden_activation_kwargs=None, + complex_mask: bool = True, + ): + super().__init__() + if hidden_activation_kwargs is None: + hidden_activation_kwargs = {} + self.hidden_activation_kwargs = hidden_activation_kwargs + self.norm = nn.LayerNorm(emb_dim) + self.hidden = nn.Sequential( + nn.Linear(in_features=emb_dim, out_features=mlp_dim), + activation.__dict__[hidden_activation](**self.hidden_activation_kwargs), + ) + + self.bandwidth = bandwidth + self.in_channels = in_channels + + self.complex_mask = complex_mask + self.reim = 2 if complex_mask else 1 + self.glu_mult = 2 + + +class NormMLP(BaseNormMLP): + def __init__( + self, + emb_dim: int, + mlp_dim: int, + bandwidth: int, + in_channels: Optional[int], + hidden_activation: str = "Tanh", + hidden_activation_kwargs=None, + complex_mask: bool = True, + ) -> None: + super().__init__( + emb_dim=emb_dim, + mlp_dim=mlp_dim, + bandwidth=bandwidth, + in_channels=in_channels, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + + self.output = nn.Sequential( + nn.Linear( + in_features=mlp_dim, + out_features=bandwidth * in_channels * self.reim * 2, + ), + nn.GLU(dim=-1), + ) + + try: + self.combined = torch.compile( + nn.Sequential(self.norm, self.hidden, self.output), disable=True + ) + except Exception as e: + self.combined = nn.Sequential(self.norm, self.hidden, self.output) + + def reshape_output(self, mb): + # print(mb.shape) + batch, n_time, _ = mb.shape + if self.complex_mask: + mb = mb.reshape( + batch, n_time, self.in_channels, self.bandwidth, self.reim + ).contiguous() + # print(mb.shape) + mb = torch.view_as_complex(mb) # (batch, n_time, in_channels, bandwidth) + else: + mb = mb.reshape(batch, n_time, self.in_channels, self.bandwidth) + + mb = torch.permute(mb, (0, 2, 3, 1)) # (batch, in_channels, bandwidth, n_time) + + return mb + + def forward(self, qb): + # qb = (batch, n_time, emb_dim) + # qb = self.norm(qb) # (batch, n_time, emb_dim) + # qb = self.hidden(qb) # (batch, n_time, mlp_dim) + # mb = self.output(qb) # (batch, n_time, bandwidth * in_channels * reim) + + mb = checkpoint_sequential(self.combined, 2, qb, use_reentrant=False) + mb = self.reshape_output(mb) # (batch, in_channels, bandwidth, n_time) + + return mb + + +class MaskEstimationModuleSuperBase(nn.Module): + pass + + +class MaskEstimationModuleBase(MaskEstimationModuleSuperBase): + def __init__( + self, + band_specs: List[Tuple[float, float]], + emb_dim: int, + mlp_dim: int, + in_channels: Optional[int], + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Dict = None, + complex_mask: bool = True, + norm_mlp_cls: Type[nn.Module] = NormMLP, + norm_mlp_kwargs: Dict = None, + ) -> None: + super().__init__() + + self.band_widths = band_widths_from_specs(band_specs) + self.n_bands = len(band_specs) + + if hidden_activation_kwargs is None: + hidden_activation_kwargs = {} + + if norm_mlp_kwargs is None: + norm_mlp_kwargs = {} + + self.norm_mlp = nn.ModuleList( 
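# Shape note for NormMLP above (illustrative numbers): with emb_dim=128, mlp_dim=512,
# bandwidth=5, in_channels=1 and complex_mask=True, the output Linear produces
# 5 * 1 * 2 * 2 = 20 features, the GLU halves that to 10, and reshape_output turns
# (batch, n_time, 10) into a complex mask of shape (batch, in_channels, bandwidth, n_time)
# via view_as_complex and a permute, e.g.:
#
#   mlp = NormMLP(emb_dim=128, mlp_dim=512, bandwidth=5, in_channels=1)
#   mb = mlp(torch.randn(2, 100, 128))   # complex tensor of shape (2, 1, 5, 100)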
+ [ + norm_mlp_cls( + bandwidth=self.band_widths[b], + emb_dim=emb_dim, + mlp_dim=mlp_dim, + in_channels=in_channels, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + **norm_mlp_kwargs, + ) + for b in range(self.n_bands) + ] + ) + + def compute_masks(self, q): + batch, n_bands, n_time, emb_dim = q.shape + + masks = [] + + for b, nmlp in enumerate(self.norm_mlp): + # print(f"maskestim/{b:02d}") + qb = q[:, b, :, :] + mb = nmlp(qb) + masks.append(mb) + + return masks + + def compute_mask(self, q, b): + batch, n_bands, n_time, emb_dim = q.shape + qb = q[:, b, :, :] + mb = self.norm_mlp[b](qb) + return mb + + +class OverlappingMaskEstimationModule(MaskEstimationModuleBase): + def __init__( + self, + in_channels: int, + band_specs: List[Tuple[float, float]], + freq_weights: List[torch.Tensor], + n_freq: int, + emb_dim: int, + mlp_dim: int, + cond_dim: int = 0, + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Dict = None, + complex_mask: bool = True, + norm_mlp_cls: Type[nn.Module] = NormMLP, + norm_mlp_kwargs: Dict = None, + use_freq_weights: bool = False, + ) -> None: + check_nonzero_bandwidth(band_specs) + check_no_gap(band_specs) + + if cond_dim > 0: + raise NotImplementedError + + super().__init__( + band_specs=band_specs, + emb_dim=emb_dim + cond_dim, + mlp_dim=mlp_dim, + in_channels=in_channels, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + norm_mlp_cls=norm_mlp_cls, + norm_mlp_kwargs=norm_mlp_kwargs, + ) + + self.n_freq = n_freq + self.band_specs = band_specs + self.in_channels = in_channels + + if freq_weights is not None and use_freq_weights: + for i, fw in enumerate(freq_weights): + self.register_buffer(f"freq_weights/{i}", fw) + + self.use_freq_weights = use_freq_weights + else: + self.use_freq_weights = False + + def forward(self, q): + # q = (batch, n_bands, n_time, emb_dim) + + batch, n_bands, n_time, emb_dim = q.shape + + masks = torch.zeros( + (batch, self.in_channels, self.n_freq, n_time), + device=q.device, + dtype=torch.complex64, + ) + + for im in range(n_bands): + fstart, fend = self.band_specs[im] + + mask = self.compute_mask(q, im) + + if self.use_freq_weights: + fw = self.get_buffer(f"freq_weights/{im}")[:, None] + mask = mask * fw + masks[:, :, fstart:fend, :] += mask + + return masks + + +class MaskEstimationModule(OverlappingMaskEstimationModule): + def __init__( + self, + band_specs: List[Tuple[float, float]], + emb_dim: int, + mlp_dim: int, + in_channels: Optional[int], + hidden_activation: str = "Tanh", + hidden_activation_kwargs: Dict = None, + complex_mask: bool = True, + **kwargs, + ) -> None: + check_nonzero_bandwidth(band_specs) + check_no_gap(band_specs) + check_no_overlap(band_specs) + super().__init__( + in_channels=in_channels, + band_specs=band_specs, + freq_weights=None, + n_freq=None, + emb_dim=emb_dim, + mlp_dim=mlp_dim, + hidden_activation=hidden_activation, + hidden_activation_kwargs=hidden_activation_kwargs, + complex_mask=complex_mask, + ) + + def forward(self, q, cond=None): + # q = (batch, n_bands, n_time, emb_dim) + + masks = self.compute_masks( + q + ) # [n_bands * (batch, in_channels, bandwidth, n_time)] + + # TODO: currently this requires band specs to have no gap and no overlap + masks = torch.concat(masks, dim=2) # (batch, in_channels, n_freq, n_time) + + return masks diff --git a/separator/models/bandit_v2/tfmodel.py b/separator/models/bandit_v2/tfmodel.py new file mode 100644 
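# Note on OverlappingMaskEstimationModule above: the per-band complex masks are
# accumulated with "+=" into one (batch, in_channels, n_freq, n_time) tensor, so an
# FFT bin shared by overlapping musical bands receives the sum of its band masks.
# With use_freq_weights=True each contribution is first scaled by the normalized
# filterbank weights from the bandsplit specification (e.g. a bin covered by two
# bands with weights 0.3 and 0.7 ends up with 0.3 * m1 + 0.7 * m2), which keeps
# overlapped bins from being over-counted.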
index 0000000000000000000000000000000000000000..21aef03d1f0e814c20db05fe7d14f8019f07713b --- /dev/null +++ b/separator/models/bandit_v2/tfmodel.py @@ -0,0 +1,145 @@ +import warnings + +import torch +import torch.backends.cuda +from torch import nn +from torch.nn.modules import rnn +from torch.utils.checkpoint import checkpoint_sequential + + +class TimeFrequencyModellingModule(nn.Module): + def __init__(self) -> None: + super().__init__() + + +class ResidualRNN(nn.Module): + def __init__( + self, + emb_dim: int, + rnn_dim: int, + bidirectional: bool = True, + rnn_type: str = "LSTM", + use_batch_trick: bool = True, + use_layer_norm: bool = True, + ) -> None: + # n_group is the size of the 2nd dim + super().__init__() + + assert use_layer_norm + assert use_batch_trick + + self.use_layer_norm = use_layer_norm + self.norm = nn.LayerNorm(emb_dim) + self.rnn = rnn.__dict__[rnn_type]( + input_size=emb_dim, + hidden_size=rnn_dim, + num_layers=1, + batch_first=True, + bidirectional=bidirectional, + ) + + self.fc = nn.Linear( + in_features=rnn_dim * (2 if bidirectional else 1), out_features=emb_dim + ) + + self.use_batch_trick = use_batch_trick + if not self.use_batch_trick: + warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!") + + def forward(self, z): + # z = (batch, n_uncrossed, n_across, emb_dim) + + z0 = torch.clone(z) + z = self.norm(z) + + batch, n_uncrossed, n_across, emb_dim = z.shape + z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim)) + z = self.rnn(z)[0] + z = torch.reshape(z, (batch, n_uncrossed, n_across, -1)) + + z = self.fc(z) # (batch, n_uncrossed, n_across, emb_dim) + + z = z + z0 + + return z + + +class Transpose(nn.Module): + def __init__(self, dim0: int, dim1: int) -> None: + super().__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, z): + return z.transpose(self.dim0, self.dim1) + + +class SeqBandModellingModule(TimeFrequencyModellingModule): + def __init__( + self, + n_modules: int = 12, + emb_dim: int = 128, + rnn_dim: int = 256, + bidirectional: bool = True, + rnn_type: str = "LSTM", + parallel_mode=False, + ) -> None: + super().__init__() + + self.n_modules = n_modules + + if parallel_mode: + self.seqband = nn.ModuleList([]) + for _ in range(n_modules): + self.seqband.append( + nn.ModuleList( + [ + ResidualRNN( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ), + ResidualRNN( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ), + ] + ) + ) + else: + seqband = [] + for _ in range(2 * n_modules): + seqband += [ + ResidualRNN( + emb_dim=emb_dim, + rnn_dim=rnn_dim, + bidirectional=bidirectional, + rnn_type=rnn_type, + ), + Transpose(1, 2), + ] + + self.seqband = nn.Sequential(*seqband) + + self.parallel_mode = parallel_mode + + def forward(self, z): + # z = (batch, n_bands, n_time, emb_dim) + + if self.parallel_mode: + for sbm_pair in self.seqband: + # z: (batch, n_bands, n_time, emb_dim) + sbm_t, sbm_f = sbm_pair[0], sbm_pair[1] + zt = sbm_t(z) # (batch, n_bands, n_time, emb_dim) + zf = sbm_f(z.transpose(1, 2)) # (batch, n_time, n_bands, emb_dim) + z = zt + zf.transpose(1, 2) + else: + z = checkpoint_sequential( + self.seqband, self.n_modules, z, use_reentrant=False + ) + + q = z + return q # (batch, n_bands, n_time, emb_dim) diff --git a/separator/models/bandit_v2/utils.py b/separator/models/bandit_v2/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ad4eab5d8c5b5396ed717f5b9c365a6900eddd2f --- /dev/null +++ 
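# Note on SeqBandModellingModule above (sequential mode): the stack interleaves
# ResidualRNN and Transpose(1, 2), so its 2 * n_modules RNN blocks alternately run
# along the time axis and the band axis of z = (batch, n_bands, n_time, emb_dim);
# checkpoint_sequential splits the stack into n_modules checkpointed segments, and
# the output keeps the same shape as the input.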
b/separator/models/bandit_v2/utils.py @@ -0,0 +1,523 @@ +import os +from abc import abstractmethod +from typing import Callable + +import numpy as np +import torch +from librosa import hz_to_midi, midi_to_hz +from torchaudio import functional as taF + +# from spafe.fbanks import bark_fbanks +# from spafe.utils.converters import erb2hz, hz2bark, hz2erb + + +def band_widths_from_specs(band_specs): + return [e - i for i, e in band_specs] + + +def check_nonzero_bandwidth(band_specs): + # pprint(band_specs) + for fstart, fend in band_specs: + if fend - fstart <= 0: + raise ValueError("Bands cannot be zero-width") + + +def check_no_overlap(band_specs): + fend_prev = -1 + for fstart_curr, fend_curr in band_specs: + if fstart_curr <= fend_prev: + raise ValueError("Bands cannot overlap") + + +def check_no_gap(band_specs): + fstart, _ = band_specs[0] + assert fstart == 0 + + fend_prev = -1 + for fstart_curr, fend_curr in band_specs: + if fstart_curr - fend_prev > 1: + raise ValueError("Bands cannot leave gap") + fend_prev = fend_curr + + +class BandsplitSpecification: + def __init__(self, nfft: int, fs: int) -> None: + self.fs = fs + self.nfft = nfft + self.nyquist = fs / 2 + self.max_index = nfft // 2 + 1 + + self.split500 = self.hertz_to_index(500) + self.split1k = self.hertz_to_index(1000) + self.split2k = self.hertz_to_index(2000) + self.split4k = self.hertz_to_index(4000) + self.split8k = self.hertz_to_index(8000) + self.split16k = self.hertz_to_index(16000) + self.split20k = self.hertz_to_index(20000) + + self.above20k = [(self.split20k, self.max_index)] + self.above16k = [(self.split16k, self.split20k)] + self.above20k + + def index_to_hertz(self, index: int): + return index * self.fs / self.nfft + + def hertz_to_index(self, hz: float, round: bool = True): + index = hz * self.nfft / self.fs + + if round: + index = int(np.round(index)) + + return index + + def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz): + band_specs = [] + lower = start_index + + while lower < end_index: + upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz))) + upper = min(upper, end_index) + + band_specs.append((lower, upper)) + lower = upper + + return band_specs + + @abstractmethod + def get_band_specs(self): + raise NotImplementedError + + +class VocalBandsplitSpecification(BandsplitSpecification): + def __init__(self, nfft: int, fs: int, version: str = "7") -> None: + super().__init__(nfft=nfft, fs=fs) + + self.version = version + + def get_band_specs(self): + return getattr(self, f"version{self.version}")() + + @property + def version1(self): + return self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.max_index, bandwidth_hz=1000 + ) + + def version2(self): + below16k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split16k, bandwidth_hz=1000 + ) + below20k = self.get_band_specs_with_bandwidth( + start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000 + ) + + return below16k + below20k + self.above20k + + def version3(self): + below8k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split8k, bandwidth_hz=1000 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000 + ) + + return below8k + below16k + self.above16k + + def version4(self): + below1k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split1k, bandwidth_hz=100 + ) + below8k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, 
end_index=self.split8k, bandwidth_hz=1000 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000 + ) + + return below1k + below8k + below16k + self.above16k + + def version5(self): + below1k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split1k, bandwidth_hz=100 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000 + ) + below20k = self.get_band_specs_with_bandwidth( + start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000 + ) + return below1k + below16k + below20k + self.above20k + + def version6(self): + below1k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split1k, bandwidth_hz=100 + ) + below4k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500 + ) + below8k = self.get_band_specs_with_bandwidth( + start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000 + ) + return below1k + below4k + below8k + below16k + self.above16k + + def version7(self): + below1k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split1k, bandwidth_hz=100 + ) + below4k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250 + ) + below8k = self.get_band_specs_with_bandwidth( + start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000 + ) + below20k = self.get_band_specs_with_bandwidth( + start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000 + ) + return below1k + below4k + below8k + below16k + below20k + self.above20k + + +class OtherBandsplitSpecification(VocalBandsplitSpecification): + def __init__(self, nfft: int, fs: int) -> None: + super().__init__(nfft=nfft, fs=fs, version="7") + + +class BassBandsplitSpecification(BandsplitSpecification): + def __init__(self, nfft: int, fs: int, version: str = "7") -> None: + super().__init__(nfft=nfft, fs=fs) + + def get_band_specs(self): + below500 = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split500, bandwidth_hz=50 + ) + below1k = self.get_band_specs_with_bandwidth( + start_index=self.split500, end_index=self.split1k, bandwidth_hz=100 + ) + below4k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500 + ) + below8k = self.get_band_specs_with_bandwidth( + start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000 + ) + above16k = [(self.split16k, self.max_index)] + + return below500 + below1k + below4k + below8k + below16k + above16k + + +class DrumBandsplitSpecification(BandsplitSpecification): + def __init__(self, nfft: int, fs: int) -> None: + super().__init__(nfft=nfft, fs=fs) + + def get_band_specs(self): + below1k = self.get_band_specs_with_bandwidth( + start_index=0, end_index=self.split1k, bandwidth_hz=50 + ) + below2k = self.get_band_specs_with_bandwidth( + start_index=self.split1k, end_index=self.split2k, bandwidth_hz=100 + ) + below4k = self.get_band_specs_with_bandwidth( + start_index=self.split2k, end_index=self.split4k, bandwidth_hz=250 + ) + 
below8k = self.get_band_specs_with_bandwidth( + start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500 + ) + below16k = self.get_band_specs_with_bandwidth( + start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000 + ) + above16k = [(self.split16k, self.max_index)] + + return below1k + below2k + below4k + below8k + below16k + above16k + + +class PerceptualBandsplitSpecification(BandsplitSpecification): + def __init__( + self, + nfft: int, + fs: int, + fbank_fn: Callable[[int, int, float, float, int], torch.Tensor], + n_bands: int, + f_min: float = 0.0, + f_max: float = None, + ) -> None: + super().__init__(nfft=nfft, fs=fs) + self.n_bands = n_bands + if f_max is None: + f_max = fs / 2 + + self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index) + + weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True) # (1, n_freqs) + normalized_mel_fb = self.filterbank / weight_per_bin # (n_mels, n_freqs) + + freq_weights = [] + band_specs = [] + for i in range(self.n_bands): + active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist() + if isinstance(active_bins, int): + active_bins = (active_bins, active_bins) + if len(active_bins) == 0: + continue + start_index = active_bins[0] + end_index = active_bins[-1] + 1 + band_specs.append((start_index, end_index)) + freq_weights.append(normalized_mel_fb[i, start_index:end_index]) + + self.freq_weights = freq_weights + self.band_specs = band_specs + + def get_band_specs(self): + return self.band_specs + + def get_freq_weights(self): + return self.freq_weights + + def save_to_file(self, dir_path: str) -> None: + os.makedirs(dir_path, exist_ok=True) + + import pickle + + with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f: + pickle.dump( + { + "band_specs": self.band_specs, + "freq_weights": self.freq_weights, + "filterbank": self.filterbank, + }, + f, + ) + + +def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs): + fb = taF.melscale_fbanks( + n_mels=n_bands, + sample_rate=fs, + f_min=f_min, + f_max=f_max, + n_freqs=n_freqs, + ).T + + fb[0, 0] = 1.0 + + return fb + + +class MelBandsplitSpecification(PerceptualBandsplitSpecification): + def __init__( + self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None + ) -> None: + super().__init__( + fbank_fn=mel_filterbank, + nfft=nfft, + fs=fs, + n_bands=n_bands, + f_min=f_min, + f_max=f_max, + ) + + +def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"): + nfft = 2 * (n_freqs - 1) + df = fs / nfft + # init freqs + f_max = f_max or fs / 2 + f_min = f_min or 0 + f_min = fs / nfft + + n_octaves = np.log2(f_max / f_min) + n_octaves_per_band = n_octaves / n_bands + bandwidth_mult = np.power(2.0, n_octaves_per_band) + + low_midi = max(0, hz_to_midi(f_min)) + high_midi = hz_to_midi(f_max) + midi_points = np.linspace(low_midi, high_midi, n_bands) + hz_pts = midi_to_hz(midi_points) + + low_pts = hz_pts / bandwidth_mult + high_pts = hz_pts * bandwidth_mult + + low_bins = np.floor(low_pts / df).astype(int) + high_bins = np.ceil(high_pts / df).astype(int) + + fb = np.zeros((n_bands, n_freqs)) + + for i in range(n_bands): + fb[i, low_bins[i] : high_bins[i] + 1] = 1.0 + + fb[0, : low_bins[0]] = 1.0 + fb[-1, high_bins[-1] + 1 :] = 1.0 + + return torch.as_tensor(fb) + + +class MusicalBandsplitSpecification(PerceptualBandsplitSpecification): + def __init__( + self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None + ) -> None: + super().__init__( + fbank_fn=musical_filterbank, + 
nfft=nfft, + fs=fs, + n_bands=n_bands, + f_min=f_min, + f_max=f_max, + ) + + +# def bark_filterbank( +# n_bands, fs, f_min, f_max, n_freqs +# ): +# nfft = 2 * (n_freqs -1) +# fb, _ = bark_fbanks.bark_filter_banks( +# nfilts=n_bands, +# nfft=nfft, +# fs=fs, +# low_freq=f_min, +# high_freq=f_max, +# scale="constant" +# ) + +# return torch.as_tensor(fb) + +# class BarkBandsplitSpecification(PerceptualBandsplitSpecification): +# def __init__( +# self, +# nfft: int, +# fs: int, +# n_bands: int, +# f_min: float = 0.0, +# f_max: float = None +# ) -> None: +# super().__init__(fbank_fn=bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) + + +# def triangular_bark_filterbank( +# n_bands, fs, f_min, f_max, n_freqs +# ): + +# all_freqs = torch.linspace(0, fs // 2, n_freqs) + +# # calculate mel freq bins +# m_min = hz2bark(f_min) +# m_max = hz2bark(f_max) + +# m_pts = torch.linspace(m_min, m_max, n_bands + 2) +# f_pts = 600 * torch.sinh(m_pts / 6) + +# # create filterbank +# fb = _create_triangular_filterbank(all_freqs, f_pts) + +# fb = fb.T + +# first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0] +# first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0] + +# fb[first_active_band, :first_active_bin] = 1.0 + +# return fb + +# class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification): +# def __init__( +# self, +# nfft: int, +# fs: int, +# n_bands: int, +# f_min: float = 0.0, +# f_max: float = None +# ) -> None: +# super().__init__(fbank_fn=triangular_bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) + + +# def minibark_filterbank( +# n_bands, fs, f_min, f_max, n_freqs +# ): +# fb = bark_filterbank( +# n_bands, +# fs, +# f_min, +# f_max, +# n_freqs +# ) + +# fb[fb < np.sqrt(0.5)] = 0.0 + +# return fb + +# class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification): +# def __init__( +# self, +# nfft: int, +# fs: int, +# n_bands: int, +# f_min: float = 0.0, +# f_max: float = None +# ) -> None: +# super().__init__(fbank_fn=minibark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) + + +# def erb_filterbank( +# n_bands: int, +# fs: int, +# f_min: float, +# f_max: float, +# n_freqs: int, +# ) -> Tensor: +# # freq bins +# A = (1000 * np.log(10)) / (24.7 * 4.37) +# all_freqs = torch.linspace(0, fs // 2, n_freqs) + +# # calculate mel freq bins +# m_min = hz2erb(f_min) +# m_max = hz2erb(f_max) + +# m_pts = torch.linspace(m_min, m_max, n_bands + 2) +# f_pts = (torch.pow(10, (m_pts / A)) - 1)/ 0.00437 + +# # create filterbank +# fb = _create_triangular_filterbank(all_freqs, f_pts) + +# fb = fb.T + + +# first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0] +# first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0] + +# fb[first_active_band, :first_active_bin] = 1.0 + +# return fb + + +# class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification): +# def __init__( +# self, +# nfft: int, +# fs: int, +# n_bands: int, +# f_min: float = 0.0, +# f_max: float = None +# ) -> None: +# super().__init__(fbank_fn=erb_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max) + +if __name__ == "__main__": + import pandas as pd + + band_defs = [] + + for bands in [VocalBandsplitSpecification]: + band_name = bands.__name__.replace("BandsplitSpecification", "") + + mbs = bands(nfft=2048, fs=44100).get_band_specs() + + for i, (f_min, f_max) in enumerate(mbs): + band_defs.append( + {"band": band_name, "band_index": i, "f_min": f_min, "f_max": f_max} + ) 
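# Usage sketch for the band specifications above (illustrative, assuming separator/
# is on sys.path); this is the specification Bandit v2 builds internally when
# band_type="musical".
#
#   from models.bandit_v2.utils import MusicalBandsplitSpecification
#
#   spec = MusicalBandsplitSpecification(nfft=2048, fs=44100, n_bands=64)
#   bands = spec.get_band_specs()      # (fstart, fend) FFT-bin ranges, fend exclusive
#   weights = spec.get_freq_weights()  # per-band weights used for mask overlap-add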
+ + df = pd.DataFrame(band_defs) + df.to_csv("vox7bands.csv", index=False) diff --git a/separator/models/bs_roformer/__init__.py b/separator/models/bs_roformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..980e0afa5b7b4fd66168bce6905a94e7c91c380e --- /dev/null +++ b/separator/models/bs_roformer/__init__.py @@ -0,0 +1,2 @@ +from models.bs_roformer.bs_roformer import BSRoformer +from models.bs_roformer.mel_band_roformer import MelBandRoformer diff --git a/separator/models/bs_roformer/__pycache__/__init__.cpython-310.pyc b/separator/models/bs_roformer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d472ed651a10bd378c3635857cf22a663d230a71 Binary files /dev/null and b/separator/models/bs_roformer/__pycache__/__init__.cpython-310.pyc differ diff --git a/separator/models/bs_roformer/__pycache__/attend.cpython-310.pyc b/separator/models/bs_roformer/__pycache__/attend.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2de273a01d1c192ce21eef254dbdf06077d870f Binary files /dev/null and b/separator/models/bs_roformer/__pycache__/attend.cpython-310.pyc differ diff --git a/separator/models/bs_roformer/__pycache__/bs_roformer.cpython-310.pyc b/separator/models/bs_roformer/__pycache__/bs_roformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69e516dc26693331915c784684408a92ee7e40e0 Binary files /dev/null and b/separator/models/bs_roformer/__pycache__/bs_roformer.cpython-310.pyc differ diff --git a/separator/models/bs_roformer/__pycache__/mel_band_roformer.cpython-310.pyc b/separator/models/bs_roformer/__pycache__/mel_band_roformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6861d4cc4aac714d5d1d99038f224aeed036b8d0 Binary files /dev/null and b/separator/models/bs_roformer/__pycache__/mel_band_roformer.cpython-310.pyc differ diff --git a/separator/models/bs_roformer/attend.py b/separator/models/bs_roformer/attend.py new file mode 100644 index 0000000000000000000000000000000000000000..d6dc4b3079cff5b3c8c90cea8df2301afd18918b --- /dev/null +++ b/separator/models/bs_roformer/attend.py @@ -0,0 +1,126 @@ +from functools import wraps +from packaging import version +from collections import namedtuple + +import os +import torch +from torch import nn, einsum +import torch.nn.functional as F + +from einops import rearrange, reduce + +# constants + +FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient']) + +# helpers + +def exists(val): + return val is not None + +def default(v, d): + return v if exists(v) else d + +def once(fn): + called = False + @wraps(fn) + def inner(x): + nonlocal called + if called: + return + called = True + return fn(x) + return inner + +print_once = once(print) + +# main class + +class Attend(nn.Module): + def __init__( + self, + dropout = 0., + flash = False, + scale = None + ): + super().__init__() + self.scale = scale + self.dropout = dropout + self.attn_dropout = nn.Dropout(dropout) + + self.flash = flash + assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above' + + # determine efficient attention configs for cuda and cpu + + self.cpu_config = FlashAttentionConfig(True, True, True) + self.cuda_config = None + + if not torch.cuda.is_available() or not flash: + return + + device_properties = 
torch.cuda.get_device_properties(torch.device('cuda')) + device_version = version.parse(f'{device_properties.major}.{device_properties.minor}') + + if device_version >= version.parse('8.0'): + if os.name == 'nt': + print_once('Windows OS detected, using math or mem efficient attention if input tensor is on cuda') + self.cuda_config = FlashAttentionConfig(False, True, True) + else: + print_once('GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda') + self.cuda_config = FlashAttentionConfig(True, False, False) + else: + print_once('GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda') + self.cuda_config = FlashAttentionConfig(False, True, True) + + def flash_attn(self, q, k, v): + _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device + + if exists(self.scale): + default_scale = q.shape[-1] ** -0.5 + q = q * (self.scale / default_scale) + + # Check if there is a compatible device for flash attention + + config = self.cuda_config if is_cuda else self.cpu_config + + # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale + + with torch.backends.cuda.sdp_kernel(**config._asdict()): + out = F.scaled_dot_product_attention( + q, k, v, + dropout_p = self.dropout if self.training else 0. + ) + + return out + + def forward(self, q, k, v): + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + + q_len, k_len, device = q.shape[-2], k.shape[-2], q.device + + scale = default(self.scale, q.shape[-1] ** -0.5) + + if self.flash: + return self.flash_attn(q, k, v) + + # similarity + + sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale + + # attention + + attn = sim.softmax(dim=-1) + attn = self.attn_dropout(attn) + + # aggregate values + + out = einsum(f"b h i j, b h j d -> b h i d", attn, v) + + return out diff --git a/separator/models/bs_roformer/attend_sw.py b/separator/models/bs_roformer/attend_sw.py new file mode 100644 index 0000000000000000000000000000000000000000..3b708770686a57c1759c21064813812aff6a64db --- /dev/null +++ b/separator/models/bs_roformer/attend_sw.py @@ -0,0 +1,98 @@ +import logging +import os + +import torch +import torch.nn.functional as F +from packaging import version +from torch import Tensor, einsum, nn +from torch.nn.attention import SDPBackend, sdpa_kernel + +logger = logging.getLogger(__name__) + + +class Attend(nn.Module): + def __init__(self, dropout: float = 0.0, flash: bool = False, scale=None): + super().__init__() + self.scale = scale + self.dropout = dropout + self.attn_dropout = nn.Dropout(dropout) + + self.flash = flash + assert not (flash and version.parse(torch.__version__) < version.parse("2.0.0")), ( + "expected pytorch >= 2.0.0 to use flash attention" + ) + + # determine efficient attention configs for cuda and cpu + self.cpu_backends = [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.MATH, + ] + self.cuda_backends: list | None = None + + if not torch.cuda.is_available() or not flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device("cuda")) + device_version = version.parse(f"{device_properties.major}.{device_properties.minor}") + + if device_version >= version.parse("8.0"): + if os.name == "nt": + cuda_backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH] + logger.info(f"windows detected, {cuda_backends=}") + else: + cuda_backends = [SDPBackend.FLASH_ATTENTION] + 
logger.info(f"gpu compute capability >= 8.0, {cuda_backends=}") + else: + cuda_backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH] + logger.info(f"gpu compute capability < 8.0, {cuda_backends=}") + + self.cuda_backends = cuda_backends + + def flash_attn(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + _, _heads, _q_len, _, _k_len, is_cuda, _device = ( + *q.shape, + k.shape[-2], + q.is_cuda, + q.device, + ) # type: ignore + + if self.scale is not None: + default_scale = q.shape[-1] ** -0.5 + q = q * (self.scale / default_scale) + + backends = self.cuda_backends if is_cuda else self.cpu_backends + # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale + with sdpa_kernel(backends=backends): # type: ignore + out = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout if self.training else 0.0 + ) + + return out + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + _q_len, _k_len, _device = q.shape[-2], k.shape[-2], q.device + + scale = self.scale or q.shape[-1] ** -0.5 + + if self.flash: + return self.flash_attn(q, k, v) + + # similarity + sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale + + # attention + attn = sim.softmax(dim=-1) + attn = self.attn_dropout(attn) + + # aggregate values + out = einsum("b h i j, b h j d -> b h i d", attn, v) + + return out diff --git a/separator/models/bs_roformer/bs_roformer.py b/separator/models/bs_roformer/bs_roformer.py new file mode 100644 index 0000000000000000000000000000000000000000..b2d15515d868239603c280ff46c58fdc356ff636 --- /dev/null +++ b/separator/models/bs_roformer/bs_roformer.py @@ -0,0 +1,621 @@ +from functools import partial + +import torch +from torch import nn, einsum, Tensor +from torch.nn import Module, ModuleList +import torch.nn.functional as F + +from models.bs_roformer.attend import Attend +from torch.utils.checkpoint import checkpoint + +from beartype.typing import Tuple, Optional, List, Callable +from beartype import beartype + +from rotary_embedding_torch import RotaryEmbedding + +from einops import rearrange, pack, unpack +from einops.layers.torch import Rearrange + +# helper functions + +def exists(val): + return val is not None + + +def default(v, d): + return v if exists(v) else d + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +# norm + +def l2norm(t): + return F.normalize(t, dim = -1, p = 2) + + +class RMSNorm(Module): + def __init__(self, dim): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return F.normalize(x, dim=-1) * self.scale * self.gamma + + +# attention + +class FeedForward(Module): + def __init__( + self, + dim, + mult=4, + dropout=0. 
+ ): + super().__init__() + dim_inner = int(dim * mult) + self.net = nn.Sequential( + RMSNorm(dim), + nn.Linear(dim, dim_inner), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim_inner, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.net(x) + + +class Attention(Module): + def __init__( + self, + dim, + heads=8, + dim_head=64, + dropout=0., + rotary_embed=None, + flash=True + ): + super().__init__() + self.heads = heads + self.scale = dim_head ** -0.5 + dim_inner = heads * dim_head + + self.rotary_embed = rotary_embed + + self.attend = Attend(flash=flash, dropout=dropout) + + self.norm = RMSNorm(dim) + self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False) + + self.to_gates = nn.Linear(dim, heads) + + self.to_out = nn.Sequential( + nn.Linear(dim_inner, dim, bias=False), + nn.Dropout(dropout) + ) + + def forward(self, x): + x = self.norm(x) + + q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads) + + if exists(self.rotary_embed): + q = self.rotary_embed.rotate_queries_or_keys(q) + k = self.rotary_embed.rotate_queries_or_keys(k) + + out = self.attend(q, k, v) + + gates = self.to_gates(x) + out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid() + + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class LinearAttention(Module): + """ + this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al. + """ + + @beartype + def __init__( + self, + *, + dim, + dim_head=32, + heads=8, + scale=8, + flash=False, + dropout=0. + ): + super().__init__() + dim_inner = dim_head * heads + self.norm = RMSNorm(dim) + + self.to_qkv = nn.Sequential( + nn.Linear(dim, dim_inner * 3, bias=False), + Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads) + ) + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + self.attend = Attend( + scale=scale, + dropout=dropout, + flash=flash + ) + + self.to_out = nn.Sequential( + Rearrange('b h d n -> b n (h d)'), + nn.Linear(dim_inner, dim, bias=False) + ) + + def forward( + self, + x + ): + x = self.norm(x) + + q, k, v = self.to_qkv(x) + + q, k = map(l2norm, (q, k)) + q = q * self.temperature.exp() + + out = self.attend(q, k, v) + + return self.to_out(out) + + +class Transformer(Module): + def __init__( + self, + *, + dim, + depth, + dim_head=64, + heads=8, + attn_dropout=0., + ff_dropout=0., + ff_mult=4, + norm_output=True, + rotary_embed=None, + flash_attn=True, + linear_attn=False + ): + super().__init__() + self.layers = ModuleList([]) + + for _ in range(depth): + if linear_attn: + attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn) + else: + attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, + rotary_embed=rotary_embed, flash=flash_attn) + + self.layers.append(ModuleList([ + attn, + FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout) + ])) + + self.norm = RMSNorm(dim) if norm_output else nn.Identity() + + def forward(self, x): + + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + + return self.norm(x) + + +# bandsplit module + +class BandSplit(Module): + @beartype + def __init__( + self, + dim, + dim_inputs: Tuple[int, ...] 
+ ): + super().__init__() + self.dim_inputs = dim_inputs + self.to_features = ModuleList([]) + + for dim_in in dim_inputs: + net = nn.Sequential( + RMSNorm(dim_in), + nn.Linear(dim_in, dim) + ) + + self.to_features.append(net) + + def forward(self, x): + x = x.split(self.dim_inputs, dim=-1) + + outs = [] + for split_input, to_feature in zip(x, self.to_features): + split_output = to_feature(split_input) + outs.append(split_output) + + return torch.stack(outs, dim=-2) + + +def MLP( + dim_in, + dim_out, + dim_hidden=None, + depth=1, + activation=nn.Tanh +): + dim_hidden = default(dim_hidden, dim_in) + + net = [] + dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out) + + for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])): + is_last = ind == (len(dims) - 2) + + net.append(nn.Linear(layer_dim_in, layer_dim_out)) + + if is_last: + continue + + net.append(activation()) + + return nn.Sequential(*net) + + +class MaskEstimator(Module): + @beartype + def __init__( + self, + dim, + dim_inputs: Tuple[int, ...], + depth, + mlp_expansion_factor=4 + ): + super().__init__() + self.dim_inputs = dim_inputs + self.to_freqs = ModuleList([]) + dim_hidden = dim * mlp_expansion_factor + + for dim_in in dim_inputs: + net = [] + + mlp = nn.Sequential( + MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), + nn.GLU(dim=-1) + ) + + self.to_freqs.append(mlp) + + def forward(self, x): + x = x.unbind(dim=-2) + + outs = [] + + for band_features, mlp in zip(x, self.to_freqs): + freq_out = mlp(band_features) + outs.append(freq_out) + + return torch.cat(outs, dim=-1) + + +# main class + +DEFAULT_FREQS_PER_BANDS = ( + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 12, 12, 12, 12, 12, 12, 12, 12, + 24, 24, 24, 24, 24, 24, 24, 24, + 48, 48, 48, 48, 48, 48, 48, 48, + 128, 129, +) + + +class BSRoformer(Module): + + @beartype + def __init__( + self, + dim, + *, + depth, + stereo=False, + num_stems=1, + time_transformer_depth=2, + freq_transformer_depth=2, + linear_transformer_depth=0, + freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS, + # in the paper, they divide into ~60 bands, test with 1 for starters + dim_head=64, + heads=8, + attn_dropout=0., + ff_dropout=0., + flash_attn=True, + dim_freqs_in=1025, + stft_n_fft=2048, + stft_hop_length=512, + # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction + stft_win_length=2048, + stft_normalized=False, + stft_window_fn: Optional[Callable] = None, + mask_estimator_depth=2, + multi_stft_resolution_loss_weight=1., + multi_stft_resolutions_window_sizes: Tuple[int, ...] 
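# Sanity note for DEFAULT_FREQS_PER_BANDS above: the band widths sum to
# 24*2 + 12*4 + 8*12 + 8*24 + 8*48 + 128 + 129 = 1025, matching the default
# stft_n_fft=2048 (2048 // 2 + 1 frequency bins), so the constructor's assert on
# sum(freqs_per_bands) == freqs passes; each band of width f then contributes
# 2 * f * audio_channels real-valued features to BandSplit and MaskEstimator.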
= (4096, 2048, 1024, 512, 256), + multi_stft_hop_size=147, + multi_stft_normalized=False, + multi_stft_window_fn: Callable = torch.hann_window, + mlp_expansion_factor=4, + use_torch_checkpoint=False, + skip_connection=False, + ): + super().__init__() + + self.stereo = stereo + self.audio_channels = 2 if stereo else 1 + self.num_stems = num_stems + self.use_torch_checkpoint = use_torch_checkpoint + self.skip_connection = skip_connection + + self.layers = ModuleList([]) + + transformer_kwargs = dict( + dim=dim, + heads=heads, + dim_head=dim_head, + attn_dropout=attn_dropout, + ff_dropout=ff_dropout, + flash_attn=flash_attn, + norm_output=False + ) + + time_rotary_embed = RotaryEmbedding(dim=dim_head) + freq_rotary_embed = RotaryEmbedding(dim=dim_head) + + for _ in range(depth): + tran_modules = [] + if linear_transformer_depth > 0: + tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs)) + tran_modules.append( + Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs) + ) + tran_modules.append( + Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs) + ) + self.layers.append(nn.ModuleList(tran_modules)) + + self.final_norm = RMSNorm(dim) + + self.stft_kwargs = dict( + n_fft=stft_n_fft, + hop_length=stft_hop_length, + win_length=stft_win_length, + normalized=stft_normalized + ) + + self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length) + + freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_win_length), return_complex=True).shape[1] + + assert len(freqs_per_bands) > 1 + assert sum( + freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}' + + freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands) + + self.band_split = BandSplit( + dim=dim, + dim_inputs=freqs_per_bands_with_complex + ) + + self.mask_estimators = nn.ModuleList([]) + + for _ in range(num_stems): + mask_estimator = MaskEstimator( + dim=dim, + dim_inputs=freqs_per_bands_with_complex, + depth=mask_estimator_depth, + mlp_expansion_factor=mlp_expansion_factor, + ) + + self.mask_estimators.append(mask_estimator) + + # for the multi-resolution stft loss + + self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight + self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes + self.multi_stft_n_fft = stft_n_fft + self.multi_stft_window_fn = multi_stft_window_fn + + self.multi_stft_kwargs = dict( + hop_length=multi_stft_hop_size, + normalized=multi_stft_normalized + ) + + def forward( + self, + raw_audio, + target=None, + return_loss_breakdown=False + ): + """ + einops + + b - batch + f - freq + t - time + s - audio channel (1 for mono, 2 for stereo) + n - number of 'stems' + c - complex (2) + d - feature dimension + """ + + device = raw_audio.device + + # defining whether model is loaded on MPS (MacOS GPU accelerator) + x_is_mps = True if device.type == "mps" else False + + if raw_audio.ndim == 2: + raw_audio = rearrange(raw_audio, 'b t -> b 1 t') + + channels = raw_audio.shape[1] + assert (not self.stereo and channels == 1) or ( + self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). 
also need to be False if mono (channel dimension of 1)' + + # to stft + + raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t') + + stft_window = self.stft_window_fn(device=device) + + # RuntimeError: FFT operations are only supported on MacOS 14+ + # Since it's tedious to define whether we're on correct MacOS version - simple try-catch is used + try: + stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True) + except: + stft_repr = torch.stft(raw_audio.cpu() if x_is_mps else raw_audio, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=True).to(device) + + stft_repr = torch.view_as_real(stft_repr) + + stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c') + stft_repr = rearrange(stft_repr, + 'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting + + x = rearrange(stft_repr, 'b f t c -> b t (f c)') + + if self.use_torch_checkpoint: + x = checkpoint(self.band_split, x, use_reentrant=False) + else: + x = self.band_split(x) + + # axial / hierarchical attention + + store = [None] * len(self.layers) + for i, transformer_block in enumerate(self.layers): + + if len(transformer_block) == 3: + linear_transformer, time_transformer, freq_transformer = transformer_block + + x, ft_ps = pack([x], 'b * d') + if self.use_torch_checkpoint: + x = checkpoint(linear_transformer, x, use_reentrant=False) + else: + x = linear_transformer(x) + x, = unpack(x, ft_ps, 'b * d') + else: + time_transformer, freq_transformer = transformer_block + + if self.skip_connection: + # Sum all previous + for j in range(i): + x = x + store[j] + + x = rearrange(x, 'b t f d -> b f t d') + x, ps = pack([x], '* t d') + + if self.use_torch_checkpoint: + x = checkpoint(time_transformer, x, use_reentrant=False) + else: + x = time_transformer(x) + + x, = unpack(x, ps, '* t d') + x = rearrange(x, 'b f t d -> b t f d') + x, ps = pack([x], '* f d') + + if self.use_torch_checkpoint: + x = checkpoint(freq_transformer, x, use_reentrant=False) + else: + x = freq_transformer(x) + + x, = unpack(x, ps, '* f d') + + if self.skip_connection: + store[i] = x + + x = self.final_norm(x) + + num_stems = len(self.mask_estimators) + + if self.use_torch_checkpoint: + mask = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1) + else: + mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1) + mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2) + + # modulate frequency representation + + stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c') + + # complex number multiplication + + stft_repr = torch.view_as_complex(stft_repr) + mask = torch.view_as_complex(mask) + + stft_repr = stft_repr * mask + + # istft + + stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels) + + # same as torch.stft() fix for MacOS MPS above + try: + recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=raw_audio.shape[-1]) + except: + recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False, length=raw_audio.shape[-1]).to(device) + + recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems) + + if num_stems == 1: + recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t') + + # if a target is passed in, calculate loss for 
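# Minimal usage sketch for BSRoformer (illustrative hyperparameters, assuming
# separator/ is on sys.path and the repo's dependencies are installed):
#
#   import torch
#   from models.bs_roformer import BSRoformer
#
#   model = BSRoformer(dim=512, depth=12, stereo=True, num_stems=1)
#   mix = torch.randn(1, 2, 44100 * 4)   # (batch, channels, samples)
#   est = model(mix)                     # (1, 2, samples), the separated stem
#   loss = model(mix, target=mix)        # scalar L1 + multi-resolution STFT loss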
learning + + if not exists(target): + return recon_audio + + if self.num_stems > 1: + assert target.ndim == 4 and target.shape[1] == self.num_stems + + if target.ndim == 2: + target = rearrange(target, '... t -> ... 1 t') + + target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft + + loss = F.l1_loss(recon_audio, target) + + multi_stft_resolution_loss = 0. + + for window_size in self.multi_stft_resolutions_window_sizes: + res_stft_kwargs = dict( + n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft + win_length=window_size, + return_complex=True, + window=self.multi_stft_window_fn(window_size, device=device), + **self.multi_stft_kwargs, + ) + + recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs) + target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs) + + multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y) + + weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight + + total_loss = loss + weighted_multi_resolution_loss + + if not return_loss_breakdown: + return total_loss + + return total_loss, (loss, multi_stft_resolution_loss) \ No newline at end of file diff --git a/separator/models/bs_roformer/bs_roformer_sw.py b/separator/models/bs_roformer/bs_roformer_sw.py new file mode 100644 index 0000000000000000000000000000000000000000..d85b1c1ad6503f78a0b97aba28d647c843e83c27 --- /dev/null +++ b/separator/models/bs_roformer/bs_roformer_sw.py @@ -0,0 +1,702 @@ +from __future__ import annotations + +from functools import partial + +import torch +import torch.nn.functional as F +from beartype import beartype +from beartype.typing import Callable +from einops import pack, rearrange, unpack +from einops.layers.torch import Rearrange +from torch import nn +from torch.nn import Module, ModuleList +from torch.utils.checkpoint import checkpoint + +from models.bs_roformer.attend_sw import Attend + +try: + from models.bs_roformer.attend_sage_sw import AttendSage +except ImportError: + pass + + +def l2norm(t): + return F.normalize(t, dim=-1, p=2) + + +class CustomNorm(Module): + def __init__(self, dim, eps: float = 5.960464477539063e-08): # 0x1p-24 + super().__init__() + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(dim)) + self.eps = eps + + def forward(self, x): + l2_norm = torch.linalg.norm(x, dim=-1, keepdim=True) + denom = torch.maximum(l2_norm, torch.full_like(l2_norm, self.eps)) + normalized_x = x / denom + return normalized_x * self.scale * self.gamma + + +# attention + + +class RotaryEmbedding(nn.Module): + def __init__(self, cos_emb, sin_emb): + super().__init__() + # both (seq_len_for_rotation, dim_head) + self.cos_emb = cos_emb + self.sin_emb = sin_emb + + def rotate_half(self, x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... 
(d r)") + + def forward(self, x): + # x is (batch_eff, heads, seq_len_for_rotation, dim_head) + cos_b = self.cos_emb.unsqueeze(0).unsqueeze(0).to(x.device, x.dtype) + sin_b = self.sin_emb.unsqueeze(0).unsqueeze(0).to(x.device, x.dtype) + + term1 = x * cos_b + term2 = self.rotate_half(x) * sin_b + + sum = term1.to(torch.float32) + term2.to(torch.float32) + return sum.to(x.dtype) + + +class FeedForward(Module): + def __init__(self, dim, mult=4, dropout=0.0): + super().__init__() + dim_inner = int(dim * mult) + self.net = nn.Sequential( + CustomNorm(dim), + nn.Linear(dim, dim_inner), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim_inner, dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Attention(Module): + def __init__( + self, + dim, + heads=8, + dim_head=64, + dropout=0.0, + shared_qkv_bias=None, + shared_out_bias=None, + rotary_embed: RotaryEmbedding | None = None, + flash=True, + sage_attention=False, + ): + super().__init__() + self.heads = heads + self.scale = dim_head**-0.5 + dim_inner = heads * dim_head + + self.rotary_embed = rotary_embed + + if sage_attention: + self.attend = AttendSage(flash=flash, dropout=dropout) # type: ignore + else: + self.attend = Attend(flash=flash, dropout=dropout) # type: ignore + + self.norm = CustomNorm(dim) + self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=(shared_qkv_bias is not None)) + if shared_qkv_bias is not None: + self.to_qkv.bias = shared_qkv_bias + + self.to_gates = nn.Linear(dim, heads) + + self.to_out = nn.Sequential( + nn.Linear(dim_inner, dim, bias=(shared_out_bias is not None)), + nn.Dropout(dropout), + ) + if shared_out_bias is not None: + self.to_out[0].bias = shared_out_bias + + def forward(self, x): + x = self.norm(x) + + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads) + + if self.rotary_embed is not None: + q = self.rotary_embed(q) + k = self.rotary_embed(k) + + out = self.attend(q, k, v) + + gates = self.to_gates(x) + gate_act = gates.sigmoid() + + out = out * rearrange(gate_act, "b n h -> b h n 1") + + out = rearrange(out, "b h n d -> b n (h d)") + out = self.to_out(out) + return out + + +class LinearAttention(Module): + """ + this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al. 
+ """ + + @beartype + def __init__( + self, + *, + dim, + dim_head=32, + heads=8, + scale=8, + flash=False, + dropout=0.0, + sage_attention=False, + ): + super().__init__() + dim_inner = dim_head * heads + self.norm = CustomNorm(dim) + + self.to_qkv = nn.Sequential( + nn.Linear(dim, dim_inner * 3, bias=False), + Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads), + ) + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + if sage_attention: + self.attend = AttendSage(scale=scale, dropout=dropout, flash=flash) # type: ignore + else: + self.attend = Attend(scale=scale, dropout=dropout, flash=flash) + + self.to_out = nn.Sequential( + Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False) + ) + + def forward(self, x): + x = self.norm(x) + + q, k, v = self.to_qkv(x) + + q, k = map(l2norm, (q, k)) + q = q * self.temperature.exp() + + out = self.attend(q, k, v) + + return self.to_out(out) + + +class Transformer(Module): + def __init__( + self, + *, + dim, + depth, + dim_head=64, + heads=8, + attn_dropout=0.0, + ff_dropout=0.0, + ff_mult=4, + norm_output=True, + rotary_embed: RotaryEmbedding | None = None, + flash_attn=True, + linear_attn=False, + sage_attention=False, + shared_qkv_bias=None, + shared_out_bias=None, + ): + super().__init__() + self.layers = ModuleList([]) + + for _ in range(depth): + attn: LinearAttention | Attention + if linear_attn: + attn = LinearAttention( + dim=dim, + dim_head=dim_head, + heads=heads, + dropout=attn_dropout, + flash=flash_attn, + sage_attention=sage_attention, + ) + else: + attn = Attention( + dim=dim, + dim_head=dim_head, + heads=heads, + dropout=attn_dropout, + shared_qkv_bias=shared_qkv_bias, + shared_out_bias=shared_out_bias, + rotary_embed=rotary_embed, + flash=flash_attn, + sage_attention=sage_attention, + ) + + self.layers.append( + ModuleList([attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]) + ) + + self.norm = CustomNorm(dim) if norm_output else nn.Identity() + + def forward(self, x): + for attn, ff in self.layers: # type: ignore + x = attn(x) + x + x = ff(x) + x + return self.norm(x) + + +# bandsplit module + + +class BandSplit(Module): + @beartype + def __init__(self, dim, dim_inputs: tuple[int, ...]): + super().__init__() + self.dim_inputs = dim_inputs + self.to_features = ModuleList([]) + + for dim_in in dim_inputs: + net = nn.Sequential(CustomNorm(dim_in), nn.Linear(dim_in, dim)) + + self.to_features.append(net) + + def forward(self, x): + x = x.split(self.dim_inputs, dim=-1) + + outs = [] + for split_input, to_feature in zip(x, self.to_features): + split_output = to_feature(split_input) + outs.append(split_output) + + return torch.stack(outs, dim=-2) + + +def MLP( + dim_in: int, + dim_out: int, + dim_hidden: int | None = None, + depth: int = 1, + activation=nn.Tanh, +): + dim_hidden = dim_hidden or dim_in + + net = [] + dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out) + + for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])): + is_last = ind == (len(dims) - 2) + + net.append(nn.Linear(layer_dim_in, layer_dim_out)) + + if is_last: + continue + + net.append(activation()) + + return nn.Sequential(*net) + + +class MaskEstimator(Module): + @beartype + def __init__(self, dim, dim_inputs: tuple[int, ...], depth, mlp_expansion_factor=4): + super().__init__() + self.dim_inputs = dim_inputs + self.to_freqs = ModuleList([]) + dim_hidden = dim * mlp_expansion_factor + + for dim_in in dim_inputs: + mlp = nn.Sequential( + MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), 
nn.GLU(dim=-1) + ) + + self.to_freqs.append(mlp) + + def forward(self, x): + x = x.unbind(dim=-2) + + outs = [] + + for band_features, mlp in zip(x, self.to_freqs): + freq_out = mlp(band_features) + outs.append(freq_out) + + return torch.cat(outs, dim=-1) + + +# fmt: off +DEFAULT_FREQS_PER_BANDS = ( + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 12, 12, 12, 12, 12, 12, 12, 12, + 24, 24, 24, 24, 24, 24, 24, 24, + 48, 48, 48, 48, 48, 48, 48, 48, + 128, 129 +) +# fmt: on + + +class BSRoformer_SW(Module): + @beartype + def __init__( + self, + dim, + *, + depth, + stereo=False, + num_stems=1, + time_transformer_depth=2, + freq_transformer_depth=2, + linear_transformer_depth=0, + freqs_per_bands: tuple[ + int, ... + ] = DEFAULT_FREQS_PER_BANDS, # in the paper, they divide into ~60 bands, test with 1 for starters + dim_head=64, + heads=8, + attn_dropout=0.0, + ff_dropout=0.0, + flash_attn=True, + stft_n_fft=2048, + stft_hop_length=512, # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction + stft_win_length=2048, + stft_normalized=False, + stft_window_fn: Callable | None = None, + mask_estimator_depth=2, + multi_stft_resolution_loss_weight=1.0, + multi_stft_resolutions_window_sizes: tuple[int, ...] = ( + 4096, + 2048, + 1024, + 512, + 256, + ), + multi_stft_hop_size=147, + multi_stft_normalized=False, + multi_stft_window_fn: Callable = torch.hann_window, + mlp_expansion_factor=4, + use_torch_checkpoint=False, + skip_connection=False, + sage_attention=False, + use_shared_bias=False, + chunk_size: int = 588800, + ): + super().__init__() + + self.stereo = stereo + self.audio_channels = 2 if stereo else 1 + self.num_stems = num_stems + self.use_torch_checkpoint = use_torch_checkpoint + self.skip_connection = skip_connection + + self.layers = ModuleList([]) + + if sage_attention: + print("Use Sage Attention") + + if use_shared_bias: + dim_inner = heads * dim_head + self.shared_qkv_bias = nn.Parameter(torch.ones(dim_inner * 3)) + self.shared_out_bias = nn.Parameter(torch.ones(dim)) + + transformer_kwargs = dict( + dim=dim, + heads=heads, + dim_head=dim_head, + attn_dropout=attn_dropout, + ff_dropout=ff_dropout, + flash_attn=flash_attn, + norm_output=False, + sage_attention=sage_attention, + shared_qkv_bias=self.shared_qkv_bias, + shared_out_bias=self.shared_out_bias, + ) + + t_frames = chunk_size // stft_hop_length + 1 # e.g. 588800 // 512 + 1 = 1151 + self.cos_emb_time = nn.Parameter(torch.zeros(t_frames, dim_head)) + self.sin_emb_time = nn.Parameter(torch.zeros(t_frames, dim_head)) + time_rotary_embed = RotaryEmbedding(cos_emb=self.cos_emb_time, sin_emb=self.sin_emb_time) + + num_bands = len(freqs_per_bands) # e.g. 
62 + self.cos_emb_freq = nn.Parameter(torch.zeros(num_bands, dim_head)) + self.sin_emb_freq = nn.Parameter(torch.zeros(num_bands, dim_head)) + freq_rotary_embed = RotaryEmbedding(cos_emb=self.cos_emb_freq, sin_emb=self.sin_emb_freq) + + for _ in range(depth): + tran_modules = [] + if linear_transformer_depth > 0: + tran_modules.append( + Transformer( + depth=linear_transformer_depth, + linear_attn=True, + **transformer_kwargs, + ) + ) + tran_modules.append( + Transformer( + depth=time_transformer_depth, + rotary_embed=time_rotary_embed, + **transformer_kwargs, + ) + ) + tran_modules.append( + Transformer( + depth=freq_transformer_depth, + rotary_embed=freq_rotary_embed, + **transformer_kwargs, + ) + ) + self.layers.append(nn.ModuleList(tran_modules)) + + self.final_norm = CustomNorm(dim) + + self.stft_kwargs = dict( + n_fft=stft_n_fft, + hop_length=stft_hop_length, + win_length=stft_win_length, + normalized=stft_normalized, + ) + + self.stft_window_fn = partial(stft_window_fn or torch.hann_window, stft_win_length) + + freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands) + + self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex) + + self.mask_estimators = nn.ModuleList([]) + + for _ in range(num_stems): + mask_estimator = MaskEstimator( + dim=dim, + dim_inputs=freqs_per_bands_with_complex, + depth=mask_estimator_depth, + mlp_expansion_factor=mlp_expansion_factor, + ) + + self.mask_estimators.append(mask_estimator) + + # for the multi-resolution stft loss + + self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight + self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes + self.multi_stft_n_fft = stft_n_fft + self.multi_stft_window_fn = multi_stft_window_fn + + self.multi_stft_kwargs = dict( + hop_length=multi_stft_hop_size, normalized=multi_stft_normalized + ) + + def forward(self, raw_audio, target=None, return_loss_breakdown=False): + """ + einops + + b - batch + f - freq + t - time + s - audio channel (1 for mono, 2 for stereo) + n - number of 'stems' + c - complex (2) + d - feature dimension + """ + + device = raw_audio.device + + # defining whether model is loaded on MPS (MacOS GPU accelerator) + x_is_mps = True if device.type == "mps" else False + + if raw_audio.ndim == 2: + raw_audio = rearrange(raw_audio, "b t -> b 1 t") + + channels = raw_audio.shape[1] + assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), ( + "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). 
also need to be False if mono (channel dimension of 1)" + ) + + # to stft + + raw_audio, batch_audio_channel_packed_shape = pack([raw_audio], "* t") + + stft_window = self.stft_window_fn(device=device) + + # RuntimeError: FFT operations are only supported on MacOS 14+ + # Since it's tedious to define whether we're on correct MacOS version - simple try-catch is used + try: + stft_repr = torch.stft( + raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True + ) + except Exception: + stft_repr = torch.stft( + raw_audio.cpu() if x_is_mps else raw_audio, + **self.stft_kwargs, + window=stft_window.cpu() if x_is_mps else stft_window, + return_complex=True, + ).to(device) + stft_repr = torch.view_as_real(stft_repr) + + stft_repr = unpack(stft_repr, batch_audio_channel_packed_shape, "* f t c")[0] + + # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting + stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c") + + x = rearrange(stft_repr, "b f t c -> b t (f c)") + + if torch.isnan(x).any() or torch.isinf(x).any(): + raise RuntimeError( + f"NaN/Inf in x after stft: {x.isnan().sum()} NaNs, {x.isinf().sum()} Infs" + ) + + if self.use_torch_checkpoint: + x = checkpoint(self.band_split, x, use_reentrant=False) + else: + x = self.band_split(x) + + if torch.isnan(x).any() or torch.isinf(x).any(): + raise RuntimeError( + f"NaN/Inf in x after band_split: {x.isnan().sum()} NaNs, {x.isinf().sum()} Infs" + ) + + # axial / hierarchical attention + + store = [None] * len(self.layers) + for i, transformer_block in enumerate(self.layers): + if len(transformer_block) == 3: + linear_transformer, time_transformer, freq_transformer = transformer_block + + x, ft_ps = pack([x], "b * d") + if self.use_torch_checkpoint: + x = checkpoint(linear_transformer, x, use_reentrant=False) + else: + x = linear_transformer(x) + (x,) = unpack(x, ft_ps, "b * d") + else: + time_transformer, freq_transformer = transformer_block + + if self.skip_connection: + # Sum all previous + for j in range(i): + x = x + store[j] + + x = rearrange(x, "b t f d -> b f t d") + x, ps = pack([x], "* t d") + + if self.use_torch_checkpoint: + x = checkpoint(time_transformer, x, use_reentrant=False) + else: + x = time_transformer(x) + + (x,) = unpack(x, ps, "* t d") + x = rearrange(x, "b f t d -> b t f d") + x, ps = pack([x], "* f d") + + if self.use_torch_checkpoint: + x = checkpoint(freq_transformer, x, use_reentrant=False) + else: + x = freq_transformer(x) + + (x,) = unpack(x, ps, "* f d") + + if self.skip_connection: + store[i] = x + + x = self.final_norm(x) + + num_stems = len(self.mask_estimators) + + if self.use_torch_checkpoint: + mask = torch.stack( + [checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], + dim=1, + ) + else: + mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1) + mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2) + + # modulate frequency representation + + stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c") + + # complex number multiplication + + stft_repr = torch.view_as_complex(stft_repr) + mask = torch.view_as_complex(mask) + + stft_repr = stft_repr * mask + + # istft + + stft_repr = rearrange(stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels) + + # same as torch.stft() fix for MacOS MPS above + try: + recon_audio = torch.istft( + stft_repr, + **self.stft_kwargs, + window=stft_window, + return_complex=False, + length=raw_audio.shape[-1], + ) + except Exception: + recon_audio = torch.istft( + stft_repr.cpu() if 
x_is_mps else stft_repr, + **self.stft_kwargs, + window=stft_window.cpu() if x_is_mps else stft_window, + return_complex=False, + length=raw_audio.shape[-1], + ).to(device) + + recon_audio = rearrange( + recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems + ) + + if num_stems == 1: + recon_audio = rearrange(recon_audio, "b 1 s t -> b s t") + + # if a target is passed in, calculate loss for learning + + if target is None: + return recon_audio + + if self.num_stems > 1: + assert target.ndim == 4 and target.shape[1] == self.num_stems + + if target.ndim == 2: + target = rearrange(target, "... t -> ... 1 t") + + target = target[..., : recon_audio.shape[-1]] # protect against lost length on istft + + loss = F.l1_loss(recon_audio, target) + + multi_stft_resolution_loss = 0.0 + + for window_size in self.multi_stft_resolutions_window_sizes: + res_stft_kwargs = dict( + n_fft=max( + window_size, self.multi_stft_n_fft + ), # not sure what n_fft is across multi resolution stft + win_length=window_size, + return_complex=True, + window=self.multi_stft_window_fn(window_size, device=device), + **self.multi_stft_kwargs, + ) + + recon_Y = torch.stft(rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs) + target_Y = torch.stft(rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs) + + multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y) + + weighted_multi_resolution_loss = ( + multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight + ) + + total_loss = loss + weighted_multi_resolution_loss + + if not return_loss_breakdown: + return total_loss + + return total_loss, (loss, multi_stft_resolution_loss) \ No newline at end of file diff --git a/separator/models/bs_roformer/mel_band_roformer.py b/separator/models/bs_roformer/mel_band_roformer.py new file mode 100644 index 0000000000000000000000000000000000000000..9a3c0d644cdf6b219ed27501fa3817a6ddad6bdd --- /dev/null +++ b/separator/models/bs_roformer/mel_band_roformer.py @@ -0,0 +1,667 @@ +from functools import partial + +import torch +from torch import nn, einsum, Tensor +from torch.nn import Module, ModuleList +import torch.nn.functional as F + +from models.bs_roformer.attend import Attend +from torch.utils.checkpoint import checkpoint + +from beartype.typing import Tuple, Optional, List, Callable +from beartype import beartype + +from rotary_embedding_torch import RotaryEmbedding + +from einops import rearrange, pack, unpack, reduce, repeat +from einops.layers.torch import Rearrange + +from librosa import filters + + +# helper functions + +def exists(val): + return val is not None + + +def default(v, d): + return v if exists(v) else d + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +def pad_at_dim(t, pad, dim=-1, value=0.): + dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1) + zeros = ((0, 0) * dims_from_right) + return F.pad(t, (*zeros, *pad), value=value) + + +def l2norm(t): + return F.normalize(t, dim=-1, p=2) + + +# norm + +class RMSNorm(Module): + def __init__(self, dim): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return F.normalize(x, dim=-1) * self.scale * self.gamma + + +# attention + +class FeedForward(Module): + def __init__( + self, + dim, + mult=4, + dropout=0. 
+ ): + super().__init__() + dim_inner = int(dim * mult) + self.net = nn.Sequential( + RMSNorm(dim), + nn.Linear(dim, dim_inner), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim_inner, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.net(x) + + +class Attention(Module): + def __init__( + self, + dim, + heads=8, + dim_head=64, + dropout=0., + rotary_embed=None, + flash=True + ): + super().__init__() + self.heads = heads + self.scale = dim_head ** -0.5 + dim_inner = heads * dim_head + + self.rotary_embed = rotary_embed + + self.attend = Attend(flash=flash, dropout=dropout) + + self.norm = RMSNorm(dim) + self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False) + + self.to_gates = nn.Linear(dim, heads) + + self.to_out = nn.Sequential( + nn.Linear(dim_inner, dim, bias=False), + nn.Dropout(dropout) + ) + + def forward(self, x): + x = self.norm(x) + + q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads) + + if exists(self.rotary_embed): + q = self.rotary_embed.rotate_queries_or_keys(q) + k = self.rotary_embed.rotate_queries_or_keys(k) + + out = self.attend(q, k, v) + + gates = self.to_gates(x) + out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid() + + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class LinearAttention(Module): + """ + this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al. + """ + + @beartype + def __init__( + self, + *, + dim, + dim_head=32, + heads=8, + scale=8, + flash=False, + dropout=0. + ): + super().__init__() + dim_inner = dim_head * heads + self.norm = RMSNorm(dim) + + self.to_qkv = nn.Sequential( + nn.Linear(dim, dim_inner * 3, bias=False), + Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads) + ) + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + self.attend = Attend( + scale=scale, + dropout=dropout, + flash=flash + ) + + self.to_out = nn.Sequential( + Rearrange('b h d n -> b n (h d)'), + nn.Linear(dim_inner, dim, bias=False) + ) + + def forward( + self, + x + ): + x = self.norm(x) + + q, k, v = self.to_qkv(x) + + q, k = map(l2norm, (q, k)) + q = q * self.temperature.exp() + + out = self.attend(q, k, v) + + return self.to_out(out) + + +class Transformer(Module): + def __init__( + self, + *, + dim, + depth, + dim_head=64, + heads=8, + attn_dropout=0., + ff_dropout=0., + ff_mult=4, + norm_output=True, + rotary_embed=None, + flash_attn=True, + linear_attn=False + ): + super().__init__() + self.layers = ModuleList([]) + + for _ in range(depth): + if linear_attn: + attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn) + else: + attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, + rotary_embed=rotary_embed, flash=flash_attn) + + self.layers.append(ModuleList([ + attn, + FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout) + ])) + + self.norm = RMSNorm(dim) if norm_output else nn.Identity() + + def forward(self, x): + + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + + return self.norm(x) + + +# bandsplit module + +class BandSplit(Module): + @beartype + def __init__( + self, + dim, + dim_inputs: Tuple[int, ...] 
+ ): + super().__init__() + self.dim_inputs = dim_inputs + self.to_features = ModuleList([]) + + for dim_in in dim_inputs: + net = nn.Sequential( + RMSNorm(dim_in), + nn.Linear(dim_in, dim) + ) + + self.to_features.append(net) + + def forward(self, x): + x = x.split(self.dim_inputs, dim=-1) + + outs = [] + for split_input, to_feature in zip(x, self.to_features): + split_output = to_feature(split_input) + outs.append(split_output) + + return torch.stack(outs, dim=-2) + + +def MLP( + dim_in, + dim_out, + dim_hidden=None, + depth=1, + activation=nn.Tanh +): + dim_hidden = default(dim_hidden, dim_in) + + net = [] + dims = (dim_in, *((dim_hidden,) * depth), dim_out) + + for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])): + is_last = ind == (len(dims) - 2) + + net.append(nn.Linear(layer_dim_in, layer_dim_out)) + + if is_last: + continue + + net.append(activation()) + + return nn.Sequential(*net) + + +class MaskEstimator(Module): + @beartype + def __init__( + self, + dim, + dim_inputs: Tuple[int, ...], + depth, + mlp_expansion_factor=4 + ): + super().__init__() + self.dim_inputs = dim_inputs + self.to_freqs = ModuleList([]) + dim_hidden = dim * mlp_expansion_factor + + for dim_in in dim_inputs: + net = [] + + mlp = nn.Sequential( + MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), + nn.GLU(dim=-1) + ) + + self.to_freqs.append(mlp) + + def forward(self, x): + x = x.unbind(dim=-2) + + outs = [] + + for band_features, mlp in zip(x, self.to_freqs): + freq_out = mlp(band_features) + outs.append(freq_out) + + return torch.cat(outs, dim=-1) + + +# main class + +class MelBandRoformer(Module): + + @beartype + def __init__( + self, + dim, + *, + depth, + stereo=False, + num_stems=1, + time_transformer_depth=2, + freq_transformer_depth=2, + linear_transformer_depth=0, + num_bands=60, + dim_head=64, + heads=8, + attn_dropout=0.1, + ff_dropout=0.1, + flash_attn=True, + dim_freqs_in=1025, + sample_rate=44100, # needed for mel filter bank from librosa + stft_n_fft=2048, + stft_hop_length=512, + # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction + stft_win_length=2048, + stft_normalized=False, + stft_window_fn: Optional[Callable] = None, + mask_estimator_depth=1, + multi_stft_resolution_loss_weight=1., + multi_stft_resolutions_window_sizes: Tuple[int, ...] 
= (4096, 2048, 1024, 512, 256), + multi_stft_hop_size=147, + multi_stft_normalized=False, + multi_stft_window_fn: Callable = torch.hann_window, + match_input_audio_length=False, # if True, pad output tensor to match length of input tensor + mlp_expansion_factor=4, + use_torch_checkpoint=False, + skip_connection=False, + ): + super().__init__() + + self.stereo = stereo + self.audio_channels = 2 if stereo else 1 + self.num_stems = num_stems + self.use_torch_checkpoint = use_torch_checkpoint + self.skip_connection = skip_connection + + self.layers = ModuleList([]) + + transformer_kwargs = dict( + dim=dim, + heads=heads, + dim_head=dim_head, + attn_dropout=attn_dropout, + ff_dropout=ff_dropout, + flash_attn=flash_attn + ) + + time_rotary_embed = RotaryEmbedding(dim=dim_head) + freq_rotary_embed = RotaryEmbedding(dim=dim_head) + + for _ in range(depth): + tran_modules = [] + if linear_transformer_depth > 0: + tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs)) + tran_modules.append( + Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs) + ) + tran_modules.append( + Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs) + ) + self.layers.append(nn.ModuleList(tran_modules)) + + self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length) + + self.stft_kwargs = dict( + n_fft=stft_n_fft, + hop_length=stft_hop_length, + win_length=stft_win_length, + normalized=stft_normalized + ) + + freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_n_fft), return_complex=True).shape[1] + + # create mel filter bank + # with librosa.filters.mel as in section 2 of paper + + mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands) + + mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy) + + # for some reason, it doesn't include the first freq? just force a value for now + + mel_filter_bank[0][0] = 1. + + # In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position, + # so let's force a positive value + + mel_filter_bank[-1, -1] = 1. 
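An aside on the two forced corner values above: librosa's mel filter bank can leave the first (and, on some platforms, the last) STFT bin with zero weight, which would break the band-coverage assertion used a few lines further down. A minimal standalone check, assuming only torch and librosa and mirroring this file's default constructor arguments, might look like:

# Standalone illustration (not part of the diff) of the corner-value fix above.
# Assumes torch and librosa are installed; constants mirror this file's defaults.
import torch
from librosa import filters

sample_rate, stft_n_fft, num_bands = 44100, 2048, 60
mel_fb = torch.from_numpy(filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands))

mel_fb[0, 0] = 1.0    # first bin may carry zero weight without this
mel_fb[-1, -1] = 1.0  # likewise for the last bin in some environments

freqs_per_band = mel_fb > 0                 # band membership per STFT bin
assert freqs_per_band.any(dim=0).all()      # every bin covered by at least one band
print(freqs_per_band.sum(dim=-1))           # number of STFT bins in each mel band

This is the same coverage invariant asserted right after this point in the constructor; every STFT bin has to belong to at least one mel band for the later scatter-and-average of the estimated masks to be well defined.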
+ + # binary as in paper (then estimated masks are averaged for overlapping regions) + + freqs_per_band = mel_filter_bank > 0 + assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now' + + repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands) + freq_indices = repeated_freq_indices[freqs_per_band] + + if stereo: + freq_indices = repeat(freq_indices, 'f -> f s', s=2) + freq_indices = freq_indices * 2 + torch.arange(2) + freq_indices = rearrange(freq_indices, 'f s -> (f s)') + + self.register_buffer('freq_indices', freq_indices, persistent=False) + self.register_buffer('freqs_per_band', freqs_per_band, persistent=False) + + num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum') + num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum') + + self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False) + self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False) + + # band split and mask estimator + + freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist()) + + self.band_split = BandSplit( + dim=dim, + dim_inputs=freqs_per_bands_with_complex + ) + + self.mask_estimators = nn.ModuleList([]) + + for _ in range(num_stems): + mask_estimator = MaskEstimator( + dim=dim, + dim_inputs=freqs_per_bands_with_complex, + depth=mask_estimator_depth, + mlp_expansion_factor=mlp_expansion_factor, + ) + + self.mask_estimators.append(mask_estimator) + + # for the multi-resolution stft loss + + self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight + self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes + self.multi_stft_n_fft = stft_n_fft + self.multi_stft_window_fn = multi_stft_window_fn + + self.multi_stft_kwargs = dict( + hop_length=multi_stft_hop_size, + normalized=multi_stft_normalized + ) + + self.match_input_audio_length = match_input_audio_length + + def forward( + self, + raw_audio, + target=None, + return_loss_breakdown=False + ): + """ + einops + + b - batch + f - freq + t - time + s - audio channel (1 for mono, 2 for stereo) + n - number of 'stems' + c - complex (2) + d - feature dimension + """ + + device = raw_audio.device + + if raw_audio.ndim == 2: + raw_audio = rearrange(raw_audio, 'b t -> b 1 t') + + batch, channels, raw_audio_length = raw_audio.shape + + istft_length = raw_audio_length if self.match_input_audio_length else None + + assert (not self.stereo and channels == 1) or ( + self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). 
also need to be False if mono (channel dimension of 1)' + + # to stft + + raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t') + + stft_window = self.stft_window_fn(device=device) + + stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True) + stft_repr = torch.view_as_real(stft_repr) + + stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c') + stft_repr = rearrange(stft_repr, + 'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting + + # index out all frequencies for all frequency ranges across bands ascending in one go + + batch_arange = torch.arange(batch, device=device)[..., None] + + # account for stereo + + x = stft_repr[batch_arange, self.freq_indices] + + # fold the complex (real and imag) into the frequencies dimension + + x = rearrange(x, 'b f t c -> b t (f c)') + + if self.use_torch_checkpoint: + x = checkpoint(self.band_split, x, use_reentrant=False) + else: + x = self.band_split(x) + + # axial / hierarchical attention + + store = [None] * len(self.layers) + for i, transformer_block in enumerate(self.layers): + + if len(transformer_block) == 3: + linear_transformer, time_transformer, freq_transformer = transformer_block + + x, ft_ps = pack([x], 'b * d') + if self.use_torch_checkpoint: + x = checkpoint(linear_transformer, x, use_reentrant=False) + else: + x = linear_transformer(x) + x, = unpack(x, ft_ps, 'b * d') + else: + time_transformer, freq_transformer = transformer_block + + if self.skip_connection: + # Sum all previous + for j in range(i): + x = x + store[j] + + x = rearrange(x, 'b t f d -> b f t d') + x, ps = pack([x], '* t d') + + if self.use_torch_checkpoint: + x = checkpoint(time_transformer, x, use_reentrant=False) + else: + x = time_transformer(x) + + x, = unpack(x, ps, '* t d') + x = rearrange(x, 'b f t d -> b t f d') + x, ps = pack([x], '* f d') + + if self.use_torch_checkpoint: + x = checkpoint(freq_transformer, x, use_reentrant=False) + else: + x = freq_transformer(x) + + x, = unpack(x, ps, '* f d') + + if self.skip_connection: + store[i] = x + + num_stems = len(self.mask_estimators) + if self.use_torch_checkpoint: + masks = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1) + else: + masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1) + masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2) + + # modulate frequency representation + + stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c') + + # complex number multiplication + + stft_repr = torch.view_as_complex(stft_repr) + masks = torch.view_as_complex(masks) + + masks = masks.type(stft_repr.dtype) + + # need to average the estimated mask for the overlapped frequencies + + scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1]) + + stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... 
-> b n ...', n=num_stems) + masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks) + + denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels) + + masks_averaged = masks_summed / denom.clamp(min=1e-8) + + # modulate stft repr with estimated mask + + stft_repr = stft_repr * masks_averaged + + # istft + + stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels) + + recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, + length=istft_length) + + recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems) + + if num_stems == 1: + recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t') + + # if a target is passed in, calculate loss for learning + + if not exists(target): + return recon_audio + + if self.num_stems > 1: + assert target.ndim == 4 and target.shape[1] == self.num_stems + + if target.ndim == 2: + target = rearrange(target, '... t -> ... 1 t') + + target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft + + loss = F.l1_loss(recon_audio, target) + + multi_stft_resolution_loss = 0. + + for window_size in self.multi_stft_resolutions_window_sizes: + res_stft_kwargs = dict( + n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft + win_length=window_size, + return_complex=True, + window=self.multi_stft_window_fn(window_size, device=device), + **self.multi_stft_kwargs, + ) + + recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs) + target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs) + + multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y) + + weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight + + total_loss = loss + weighted_multi_resolution_loss + + if not return_loss_breakdown: + return total_loss + + return total_loss, (loss, multi_stft_resolution_loss) diff --git a/separator/models/demucs4ht.py b/separator/models/demucs4ht.py new file mode 100644 index 0000000000000000000000000000000000000000..06c279c31a7ac7e12af4375a5715eb291ad5405c --- /dev/null +++ b/separator/models/demucs4ht.py @@ -0,0 +1,713 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +import numpy as np +import torch +import json +from omegaconf import OmegaConf +from demucs.demucs import Demucs +from demucs.hdemucs import HDemucs + +import math +from openunmix.filtering import wiener +from torch import nn +from torch.nn import functional as F +from fractions import Fraction +from einops import rearrange + +from demucs.transformer import CrossTransformerEncoder + +from demucs.demucs import rescale_module +from demucs.states import capture_init +from demucs.spec import spectro, ispectro +from demucs.hdemucs import pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer + + +class HTDemucs(nn.Module): + """ + Spectrogram and hybrid Demucs model. + The spectrogram model has the same structure as Demucs, except the first few layers are over the + frequency axis, until there is only 1 frequency, and then it moves to time convolutions. + Frequency layers can still access information across time steps thanks to the DConv residual. + + Hybrid model have a parallel time branch. At some layer, the time branch has the same stride + as the frequency branch and then the two are combined. 
The opposite happens in the decoder. + + Models can either use naive iSTFT from masking, Wiener filtering ([Ulhih et al. 2017]), + or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on + Open Unmix implementation [Stoter et al. 2019]. + + The loss is always on the temporal domain, by backpropagating through the above + output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks + a bit Wiener filtering, as doing more iteration at test time will change the spectrogram + contribution, without changing the one from the waveform, which will lead to worse performance. + I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve. + CaC on the other hand provides similar performance for hybrid, and works naturally with + hybrid models. + + This model also uses frequency embeddings are used to improve efficiency on convolutions + over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf). + + Unlike classic Demucs, there is no resampling here, and normalization is always applied. + """ + + @capture_init + def __init__( + self, + sources, + # Channels + audio_channels=2, + channels=48, + channels_time=None, + growth=2, + # STFT + nfft=4096, + num_subbands=1, + wiener_iters=0, + end_iters=0, + wiener_residual=False, + cac=True, + # Main structure + depth=4, + rewrite=True, + # Frequency branch + multi_freqs=None, + multi_freqs_depth=3, + freq_emb=0.2, + emb_scale=10, + emb_smooth=True, + # Convolutions + kernel_size=8, + time_stride=2, + stride=4, + context=1, + context_enc=0, + # Normalization + norm_starts=4, + norm_groups=4, + # DConv residual branch + dconv_mode=1, + dconv_depth=2, + dconv_comp=8, + dconv_init=1e-3, + # Before the Transformer + bottom_channels=0, + # Transformer + t_layers=5, + t_emb="sin", + t_hidden_scale=4.0, + t_heads=8, + t_dropout=0.0, + t_max_positions=10000, + t_norm_in=True, + t_norm_in_group=False, + t_group_norm=False, + t_norm_first=True, + t_norm_out=True, + t_max_period=10000.0, + t_weight_decay=0.0, + t_lr=None, + t_layer_scale=True, + t_gelu=True, + t_weight_pos_embed=1.0, + t_sin_random_shift=0, + t_cape_mean_normalize=True, + t_cape_augment=True, + t_cape_glob_loc_scale=[5000.0, 1.0, 1.4], + t_sparse_self_attn=False, + t_sparse_cross_attn=False, + t_mask_type="diag", + t_mask_random_seed=42, + t_sparse_attn_window=500, + t_global_window=100, + t_sparsity=0.95, + t_auto_sparsity=False, + # ------ Particuliar parameters + t_cross_first=False, + # Weight init + rescale=0.1, + # Metadata + samplerate=44100, + segment=10, + use_train_segment=False, + ): + """ + Args: + sources (list[str]): list of source names. + audio_channels (int): input/output audio channels. + channels (int): initial number of hidden channels. + channels_time: if not None, use a different `channels` value for the time branch. + growth: increase the number of hidden channels by this factor at each layer. + nfft: number of fft bins. Note that changing this require careful computation of + various shape parameters and will not work out of the box for hybrid models. + wiener_iters: when using Wiener filtering, number of iterations at test time. + end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`. + wiener_residual: add residual source before wiener filtering. + cac: uses complex as channels, i.e. complex numbers are 2 channels each + in input and output. no further processing is done before ISTFT. 
+ depth (int): number of layers in the encoder and in the decoder. + rewrite (bool): add 1x1 convolution to each layer. + multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`. + multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost + layers will be wrapped. + freq_emb: add frequency embedding after the first frequency layer if > 0, + the actual value controls the weight of the embedding. + emb_scale: equivalent to scaling the embedding learning rate + emb_smooth: initialize the embedding with a smooth one (with respect to frequencies). + kernel_size: kernel_size for encoder and decoder layers. + stride: stride for encoder and decoder layers. + time_stride: stride for the final time layer, after the merge. + context: context for 1x1 conv in the decoder. + context_enc: context for 1x1 conv in the encoder. + norm_starts: layer at which group norm starts being used. + decoder layers are numbered in reverse order. + norm_groups: number of groups for group norm. + dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both. + dconv_depth: depth of residual DConv branch. + dconv_comp: compression of DConv branch. + dconv_attn: adds attention layers in DConv branch starting at this layer. + dconv_lstm: adds a LSTM layer in DConv branch starting at this layer. + dconv_init: initial scale for the DConv branch LayerScale. + bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the + transformer in order to change the number of channels + t_layers: number of layers in each branch (waveform and spec) of the transformer + t_emb: "sin", "cape" or "scaled" + t_hidden_scale: the hidden scale of the Feedforward parts of the transformer + for instance if C = 384 (the number of channels in the transformer) and + t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension + 384 * 4 = 1536 + t_heads: number of heads for the transformer + t_dropout: dropout in the transformer + t_max_positions: max_positions for the "scaled" positional embedding, only + useful if t_emb="scaled" + t_norm_in: (bool) norm before addinf positional embedding and getting into the + transformer layers + t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the + timesteps (GroupNorm with group=1) + t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the + timesteps (GroupNorm with group=1) + t_norm_first: (bool) if True the norm is before the attention and before the FFN + t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer + t_max_period: (float) denominator in the sinusoidal embedding expression + t_weight_decay: (float) weight decay for the transformer + t_lr: (float) specific learning rate for the transformer + t_layer_scale: (bool) Layer Scale for the transformer + t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else + t_weight_pos_embed: (float) weighting of the positional embedding + t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings + see: https://arxiv.org/abs/2106.03143 + t_cape_augment: (bool) if t_emb="cape", must be True during training and False + during the inference, see: https://arxiv.org/abs/2106.03143 + t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters + see: https://arxiv.org/abs/2106.03143 + t_sparse_self_attn: (bool) if True, the self attentions are sparse + t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it + unless you designed really 
specific masks) + t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination + with '_' between: i.e. "diag_jmask_random" (note that this is permutation + invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag") + t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed + that generated the random part of the mask + t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and + a key (j), the mask is True id |i-j|<=t_sparse_attn_window + t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :] + and mask[:, :t_global_window] will be True + t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity + level of the random part of the mask. + t_cross_first: (bool) if True cross attention is the first layer of the + transformer (False seems to be better) + rescale: weight rescaling trick + use_train_segment: (bool) if True, the actual size that is used during the + training is used during inference. + """ + super().__init__() + self.num_subbands = num_subbands + self.cac = cac + self.wiener_residual = wiener_residual + self.audio_channels = audio_channels + self.sources = sources + self.kernel_size = kernel_size + self.context = context + self.stride = stride + self.depth = depth + self.bottom_channels = bottom_channels + self.channels = channels + self.samplerate = samplerate + self.segment = segment + self.use_train_segment = use_train_segment + self.nfft = nfft + self.hop_length = nfft // 4 + self.wiener_iters = wiener_iters + self.end_iters = end_iters + self.freq_emb = None + assert wiener_iters == end_iters + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + self.tencoder = nn.ModuleList() + self.tdecoder = nn.ModuleList() + + chin = audio_channels + chin_z = chin # number of channels for the freq branch + if self.cac: + chin_z *= 2 + if self.num_subbands > 1: + chin_z *= self.num_subbands + chout = channels_time or channels + chout_z = channels + freqs = nfft // 2 + + for index in range(depth): + norm = index >= norm_starts + freq = freqs > 1 + stri = stride + ker = kernel_size + if not freq: + assert freqs == 1 + ker = time_stride * 2 + stri = time_stride + + pad = True + last_freq = False + if freq and freqs <= kernel_size: + ker = freqs + pad = False + last_freq = True + + kw = { + "kernel_size": ker, + "stride": stri, + "freq": freq, + "pad": pad, + "norm": norm, + "rewrite": rewrite, + "norm_groups": norm_groups, + "dconv_kw": { + "depth": dconv_depth, + "compress": dconv_comp, + "init": dconv_init, + "gelu": True, + }, + } + kwt = dict(kw) + kwt["freq"] = 0 + kwt["kernel_size"] = kernel_size + kwt["stride"] = stride + kwt["pad"] = True + kw_dec = dict(kw) + multi = False + if multi_freqs and index < multi_freqs_depth: + multi = True + kw_dec["context_freq"] = False + + if last_freq: + chout_z = max(chout, chout_z) + chout = chout_z + + enc = HEncLayer( + chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw + ) + if freq: + tenc = HEncLayer( + chin, + chout, + dconv=dconv_mode & 1, + context=context_enc, + empty=last_freq, + **kwt + ) + self.tencoder.append(tenc) + + if multi: + enc = MultiWrap(enc, multi_freqs) + self.encoder.append(enc) + if index == 0: + chin = self.audio_channels * len(self.sources) + chin_z = chin + if self.cac: + chin_z *= 2 + if self.num_subbands > 1: + chin_z *= self.num_subbands + dec = HDecLayer( + chout_z, + chin_z, + dconv=dconv_mode & 2, + last=index == 0, + context=context, + **kw_dec + ) + if multi: + dec = 
MultiWrap(dec, multi_freqs) + if freq: + tdec = HDecLayer( + chout, + chin, + dconv=dconv_mode & 2, + empty=last_freq, + last=index == 0, + context=context, + **kwt + ) + self.tdecoder.insert(0, tdec) + self.decoder.insert(0, dec) + + chin = chout + chin_z = chout_z + chout = int(growth * chout) + chout_z = int(growth * chout_z) + if freq: + if freqs <= kernel_size: + freqs = 1 + else: + freqs //= stride + if index == 0 and freq_emb: + self.freq_emb = ScaledEmbedding( + freqs, chin_z, smooth=emb_smooth, scale=emb_scale + ) + self.freq_emb_scale = freq_emb + + if rescale: + rescale_module(self, reference=rescale) + + transformer_channels = channels * growth ** (depth - 1) + if bottom_channels: + self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1) + self.channel_downsampler = nn.Conv1d( + bottom_channels, transformer_channels, 1 + ) + self.channel_upsampler_t = nn.Conv1d( + transformer_channels, bottom_channels, 1 + ) + self.channel_downsampler_t = nn.Conv1d( + bottom_channels, transformer_channels, 1 + ) + + transformer_channels = bottom_channels + + if t_layers > 0: + self.crosstransformer = CrossTransformerEncoder( + dim=transformer_channels, + emb=t_emb, + hidden_scale=t_hidden_scale, + num_heads=t_heads, + num_layers=t_layers, + cross_first=t_cross_first, + dropout=t_dropout, + max_positions=t_max_positions, + norm_in=t_norm_in, + norm_in_group=t_norm_in_group, + group_norm=t_group_norm, + norm_first=t_norm_first, + norm_out=t_norm_out, + max_period=t_max_period, + weight_decay=t_weight_decay, + lr=t_lr, + layer_scale=t_layer_scale, + gelu=t_gelu, + sin_random_shift=t_sin_random_shift, + weight_pos_embed=t_weight_pos_embed, + cape_mean_normalize=t_cape_mean_normalize, + cape_augment=t_cape_augment, + cape_glob_loc_scale=t_cape_glob_loc_scale, + sparse_self_attn=t_sparse_self_attn, + sparse_cross_attn=t_sparse_cross_attn, + mask_type=t_mask_type, + mask_random_seed=t_mask_random_seed, + sparse_attn_window=t_sparse_attn_window, + global_window=t_global_window, + sparsity=t_sparsity, + auto_sparsity=t_auto_sparsity, + ) + else: + self.crosstransformer = None + + def _spec(self, x): + hl = self.hop_length + nfft = self.nfft + x0 = x # noqa + + # We re-pad the signal in order to keep the property + # that the size of the output is exactly the size of the input + # divided by the stride (here hop_length), when divisible. + # This is achieved by padding by 1/4th of the kernel size (here nfft). + # which is not supported by torch.stft. + # Having all convolution operations follow this convention allow to easily + # align the time and frequency branches later on. + assert hl == nfft // 4 + le = int(math.ceil(x.shape[-1] / hl)) + pad = hl // 2 * 3 + x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect") + + z = spectro(x, nfft, hl)[..., :-1, :] + assert z.shape[-1] == le + 4, (z.shape, x.shape, le) + z = z[..., 2: 2 + le] + return z + + def _ispec(self, z, length=None, scale=0): + hl = self.hop_length // (4**scale) + z = F.pad(z, (0, 0, 0, 1)) + z = F.pad(z, (2, 2)) + pad = hl // 2 * 3 + le = hl * int(math.ceil(length / hl)) + 2 * pad + x = ispectro(z, hl, length=le) + x = x[..., pad: pad + length] + return x + + def _magnitude(self, z): + # return the magnitude of the spectrogram, except when cac is True, + # in which case we just move the complex dimension to the channel one. 
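Before the cac branch below, the channel layout it produces is easier to see in isolation. The following is a toy round-trip sketch (standalone, torch only, shapes chosen arbitrarily) of how _magnitude folds a complex spectrogram into real channels and how the cac path of _mask folds it back:

# Toy shapes; B=batch, C=audio channels, Fr=frequency bins, T=frames, S=sources.
import torch

B, C, Fr, T, S = 1, 2, 4, 3, 1
z = torch.randn(B, C, Fr, T, dtype=torch.complex64)    # complex spectrogram

# _magnitude() with cac=True: real/imag become extra channels.
m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)       # (B, C, 2, Fr, T)
m = m.reshape(B, C * 2, Fr, T)

# cac path of _mask(): split channels back into (C, 2) and re-complexify.
out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
out = torch.view_as_complex(out.contiguous())          # (B, S, C, Fr, T)

assert torch.equal(out[:, 0], z)                       # lossless round trip

With cac the decoder output is therefore interpreted as a full complex spectrogram per source rather than a magnitude mask, which is why the mixture spectrogram z is ignored in that branch of _mask.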
+ if self.cac: + B, C, Fr, T = z.shape + m = torch.view_as_real(z).permute(0, 1, 4, 2, 3) + m = m.reshape(B, C * 2, Fr, T) + else: + m = z.abs() + return m + + def _mask(self, z, m): + # Apply masking given the mixture spectrogram `z` and the estimated mask `m`. + # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored. + niters = self.wiener_iters + if self.cac: + B, S, C, Fr, T = m.shape + out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3) + out = torch.view_as_complex(out.contiguous()) + return out + if self.training: + niters = self.end_iters + if niters < 0: + z = z[:, None] + return z / (1e-8 + z.abs()) * m + else: + return self._wiener(m, z, niters) + + def _wiener(self, mag_out, mix_stft, niters): + # apply wiener filtering from OpenUnmix. + init = mix_stft.dtype + wiener_win_len = 300 + residual = self.wiener_residual + + B, S, C, Fq, T = mag_out.shape + mag_out = mag_out.permute(0, 4, 3, 2, 1) + mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1)) + + outs = [] + for sample in range(B): + pos = 0 + out = [] + for pos in range(0, T, wiener_win_len): + frame = slice(pos, pos + wiener_win_len) + z_out = wiener( + mag_out[sample, frame], + mix_stft[sample, frame], + niters, + residual=residual, + ) + out.append(z_out.transpose(-1, -2)) + outs.append(torch.cat(out, dim=0)) + out = torch.view_as_complex(torch.stack(outs, 0)) + out = out.permute(0, 4, 3, 2, 1).contiguous() + if residual: + out = out[:, :-1] + assert list(out.shape) == [B, S, C, Fq, T] + return out.to(init) + + def valid_length(self, length: int): + """ + Return a length that is appropriate for evaluation. + In our case, always return the training length, unless + it is smaller than the given length, in which case this + raises an error. + """ + if not self.use_train_segment: + return length + training_length = int(self.segment * self.samplerate) + if training_length < length: + raise ValueError( + f"Given length {length} is longer than " + f"training length {training_length}") + return training_length + + def cac2cws(self, x): + k = self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c, k, f // k, t) + x = x.reshape(b, c * k, f // k, t) + return x + + def cws2cac(self, x): + k = self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c // k, k, f, t) + x = x.reshape(b, c // k, f * k, t) + return x + + def forward(self, mix): + length = mix.shape[-1] + length_pre_pad = None + if self.use_train_segment: + if self.training: + self.segment = Fraction(mix.shape[-1], self.samplerate) + else: + training_length = int(self.segment * self.samplerate) + # print('Training length: {} Segment: {} Sample rate: {}'.format(training_length, self.segment, self.samplerate)) + if mix.shape[-1] < training_length: + length_pre_pad = mix.shape[-1] + mix = F.pad(mix, (0, training_length - length_pre_pad)) + # print("Mix: {}".format(mix.shape)) + # print("Length: {}".format(length)) + z = self._spec(mix) + # print("Z: {} Type: {}".format(z.shape, z.dtype)) + mag = self._magnitude(z) + x = mag + # print("MAG: {} Type: {}".format(x.shape, x.dtype)) + + if self.num_subbands > 1: + x = self.cac2cws(x) + # print("After SUBBANDS: {} Type: {}".format(x.shape, x.dtype)) + + B, C, Fq, T = x.shape + + # unlike previous Demucs, we always normalize because it is easier. + mean = x.mean(dim=(1, 2, 3), keepdim=True) + std = x.std(dim=(1, 2, 3), keepdim=True) + x = (x - mean) / (1e-5 + std) + # x will be the freq. branch input. + + # Prepare the time branch input. 
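The same normalize/de-normalize convention is applied to the time branch right below. As a standalone sketch (assuming torch, with illustrative shapes), the round trip used for the frequency branch is:

# Minimal sketch of the normalization convention in forward(); shapes are illustrative.
import torch

x = torch.randn(2, 4, 2048, 336)               # (B, C, Fq, T) spectrogram features
mean = x.mean(dim=(1, 2, 3), keepdim=True)     # one statistic per example
std = x.std(dim=(1, 2, 3), keepdim=True)
x_norm = (x - mean) / (1e-5 + std)             # what the encoder consumes

# After decoding, outputs are mapped back with the same statistics
# (the extra [:, None] in forward() broadcasts over the added source dim S).
x_rec = x_norm * std + mean
print((x_rec - x).abs().max())                 # ~0, up to the 1e-5 epsilon

Note that the epsilon sits only on the denominator, so the inverse mapping is approximate rather than exact.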
+ xt = mix + meant = xt.mean(dim=(1, 2), keepdim=True) + stdt = xt.std(dim=(1, 2), keepdim=True) + xt = (xt - meant) / (1e-5 + stdt) + + # print("XT: {}".format(xt.shape)) + + # okay, this is a giant mess I know... + saved = [] # skip connections, freq. + saved_t = [] # skip connections, time. + lengths = [] # saved lengths to properly remove padding, freq branch. + lengths_t = [] # saved lengths for time branch. + for idx, encode in enumerate(self.encoder): + lengths.append(x.shape[-1]) + inject = None + if idx < len(self.tencoder): + # we have not yet merged branches. + lengths_t.append(xt.shape[-1]) + tenc = self.tencoder[idx] + xt = tenc(xt) + # print("Encode XT {}: {}".format(idx, xt.shape)) + if not tenc.empty: + # save for skip connection + saved_t.append(xt) + else: + # tenc contains just the first conv., so that now time and freq. + # branches have the same shape and can be merged. + inject = xt + x = encode(x, inject) + # print("Encode X {}: {}".format(idx, x.shape)) + if idx == 0 and self.freq_emb is not None: + # add frequency embedding to allow for non equivariant convolutions + # over the frequency axis. + frs = torch.arange(x.shape[-2], device=x.device) + emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x) + x = x + self.freq_emb_scale * emb + + saved.append(x) + if self.crosstransformer: + if self.bottom_channels: + b, c, f, t = x.shape + x = rearrange(x, "b c f t-> b c (f t)") + x = self.channel_upsampler(x) + x = rearrange(x, "b c (f t)-> b c f t", f=f) + xt = self.channel_upsampler_t(xt) + + x, xt = self.crosstransformer(x, xt) + # print("Cross Tran X {}, XT: {}".format(x.shape, xt.shape)) + + if self.bottom_channels: + x = rearrange(x, "b c f t-> b c (f t)") + x = self.channel_downsampler(x) + x = rearrange(x, "b c (f t)-> b c f t", f=f) + xt = self.channel_downsampler_t(xt) + + for idx, decode in enumerate(self.decoder): + skip = saved.pop(-1) + x, pre = decode(x, skip, lengths.pop(-1)) + # print('Decode {} X: {}'.format(idx, x.shape)) + # `pre` contains the output just before final transposed convolution, + # which is used when the freq. and time branch separate. + + offset = self.depth - len(self.tdecoder) + if idx >= offset: + tdec = self.tdecoder[idx - offset] + length_t = lengths_t.pop(-1) + if tdec.empty: + assert pre.shape[2] == 1, pre.shape + pre = pre[:, :, 0] + xt, _ = tdec(pre, None, length_t) + else: + skip = saved_t.pop(-1) + xt, _ = tdec(xt, skip, length_t) + # print('Decode {} XT: {}'.format(idx, xt.shape)) + + # Let's make sure we used all stored skip connections. 
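The assertions that follow check a simple stack invariant: every tensor pushed during the encoder loop must be popped exactly once during the decoder loop. A reduced sketch of the pattern (standalone, with hypothetical identity layers standing in for HEncLayer/HDecLayer) is:

# Toy sketch of the push/pop skip-connection bookkeeping used in forward().
import torch
from torch import nn

x = torch.randn(1, 8, 16)
encoder = [nn.Identity() for _ in range(3)]   # placeholders for HEncLayer
decoder = [nn.Identity() for _ in range(3)]   # placeholders for HDecLayer

saved = []                                    # skip connections, used as a stack
for encode in encoder:
    x = encode(x)
    saved.append(x)                           # push on the way down

for decode in decoder:
    skip = saved.pop(-1)                      # pop in reverse on the way up
    x = decode(x + skip)

assert len(saved) == 0                        # same invariant asserted below

The time branch keeps its own saved_t and lengths_t stacks, which is why the model checks that all three lists are empty once decoding finishes.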
+ assert len(saved) == 0 + assert len(lengths_t) == 0 + assert len(saved_t) == 0 + + S = len(self.sources) + + if self.num_subbands > 1: + x = x.view(B, -1, Fq, T) + # print("X view 1: {}".format(x.shape)) + x = self.cws2cac(x) + # print("X view 2: {}".format(x.shape)) + + x = x.view(B, S, -1, Fq * self.num_subbands, T) + x = x * std[:, None] + mean[:, None] + # print("X returned: {}".format(x.shape)) + + zout = self._mask(z, x) + if self.use_train_segment: + if self.training: + x = self._ispec(zout, length) + else: + x = self._ispec(zout, training_length) + else: + x = self._ispec(zout, length) + + if self.use_train_segment: + if self.training: + xt = xt.view(B, S, -1, length) + else: + xt = xt.view(B, S, -1, training_length) + else: + xt = xt.view(B, S, -1, length) + xt = xt * stdt[:, None] + meant[:, None] + x = xt + x + if length_pre_pad: + x = x[..., :length_pre_pad] + return x + + +def get_model(args): + extra = { + 'sources': list(args.training.instruments), + 'audio_channels': args.training.channels, + 'samplerate': args.training.samplerate, + # 'segment': args.model_segment or 4 * args.dset.segment, + 'segment': args.training.segment, + } + klass = { + 'demucs': Demucs, + 'hdemucs': HDemucs, + 'htdemucs': HTDemucs, + }[args.model] + kw = OmegaConf.to_container(getattr(args, args.model), resolve=True) + model = klass(**extra, **kw) + return model + + diff --git a/separator/models/ex_bi_mamba2.py b/separator/models/ex_bi_mamba2.py new file mode 100644 index 0000000000000000000000000000000000000000..6bcf797cce8b28e88c7aab9adfa2037070504763 --- /dev/null +++ b/separator/models/ex_bi_mamba2.py @@ -0,0 +1,303 @@ +# https://github.com/Human9000/nd-Mamba2-torch + +import torch +from torch import Tensor, nn +from torch.nn import functional as F +from abc import abstractmethod + + +def silu(x): + return x * F.sigmoid(x) + + +class RMSNorm(nn.Module): + def __init__(self, d: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(d)) + + def forward(self, x, z): + x = x * silu(z) + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight + + +class Mamba2(nn.Module): + def __init__(self, d_model: int, # model dimension (D) + n_layer: int = 24, # number of Mamba-2 layers in the language model + d_state: int = 128, # state dimension (N) + d_conv: int = 4, # convolution kernel size + expand: int = 2, # expansion factor (E) + headdim: int = 64, # head dimension (P) + chunk_size: int = 64, # matrix partition size (Q) + ): + super().__init__() + self.n_layer = n_layer + self.d_state = d_state + self.headdim = headdim + # self.chunk_size = torch.tensor(chunk_size, dtype=torch.int32) + self.chunk_size = chunk_size + + self.d_inner = expand * d_model + assert self.d_inner % self.headdim == 0, "self.d_inner must be divisible by self.headdim" + self.nheads = self.d_inner // self.headdim + + d_in_proj = 2 * self.d_inner + 2 * self.d_state + self.nheads + self.in_proj = nn.Linear(d_model, d_in_proj, bias=False) + + conv_dim = self.d_inner + 2 * d_state + self.conv1d = nn.Conv1d(conv_dim, conv_dim, d_conv, groups=conv_dim, padding=d_conv - 1, ) + self.dt_bias = nn.Parameter(torch.empty(self.nheads, )) + self.A_log = nn.Parameter(torch.empty(self.nheads, )) + self.D = nn.Parameter(torch.empty(self.nheads, )) + self.norm = RMSNorm(self.d_inner, ) + self.out_proj = nn.Linear(self.d_inner, d_model, bias=False, ) + + def forward(self, u: Tensor): + A = -torch.exp(self.A_log) # (nheads,) + zxbcdt = self.in_proj(u) # (batch, seqlen, 
d_in_proj) + z, xBC, dt = torch.split( + zxbcdt, + [ + self.d_inner, + self.d_inner + 2 * self.d_state, + self.nheads, + ], + dim=-1, + ) + dt = F.softplus(dt + self.dt_bias) # (batch, seqlen, nheads) + + # Pad or truncate xBC seqlen to d_conv + xBC = silu( + self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, : u.shape[1], :] + ) # (batch, seqlen, d_inner + 2 * d_state)) + x, B, C = torch.split( + xBC, [self.d_inner, self.d_state, self.d_state], dim=-1 + ) + + _b, _l, _hp = x.shape + _h = _hp // self.headdim + _p = self.headdim + x = x.reshape(_b, _l, _h, _p) + + y = self.ssd(x * dt.unsqueeze(-1), + A * dt, + B.unsqueeze(2), + C.unsqueeze(2), ) + + y = y + x * self.D.unsqueeze(-1) + + _b, _l, _h, _p = y.shape + y = y.reshape(_b, _l, _h * _p) + + y = self.norm(y, z) + y = self.out_proj(y) + + return y + + def segsum(self, x: Tensor) -> Tensor: + T = x.size(-1) + device = x.device + x = x[..., None].repeat(1, 1, 1, 1, T) + mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1) + x = x.masked_fill(~mask, 0) + x_segsum = torch.cumsum(x, dim=-2) + mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + def ssd(self, x, A, B, C): + chunk_size = self.chunk_size + # if x.shape[1] % chunk_size == 0: + # + x = x.reshape(x.shape[0], x.shape[1] // chunk_size, chunk_size, x.shape[2], x.shape[3], ) + B = B.reshape(B.shape[0], B.shape[1] // chunk_size, chunk_size, B.shape[2], B.shape[3], ) + C = C.reshape(C.shape[0], C.shape[1] // chunk_size, chunk_size, C.shape[2], C.shape[3], ) + A = A.reshape(A.shape[0], A.shape[1] // chunk_size, chunk_size, A.shape[2]) + A = A.permute(0, 3, 1, 2) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + L = torch.exp(self.segsum(A)) + Y_diag = torch.einsum("bclhn, bcshn, bhcls, bcshp -> bclhp", C, B, L, x) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum) + states = torch.einsum("bclhn, bhcl, bclhp -> bchpn", B, decay_states, x) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + + initial_states = torch.zeros_like(states[:, :1]) + states = torch.cat([initial_states, states], dim=1) + + decay_chunk = torch.exp(self.segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))[0] + new_states = torch.einsum("bhzc, bchpn -> bzhpn", decay_chunk, states) + states = new_states[:, :-1] + + # 4. 
Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + Y_off = torch.einsum("bclhn, bchpn, bhcl -> bclhp", C, states, state_decay_out) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + # Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p") + Y = Y_diag + Y_off + Y = Y.reshape(Y.shape[0], Y.shape[1] * Y.shape[2], Y.shape[3], Y.shape[4], ) + + return Y + + +class _BiMamba2(nn.Module): + def __init__(self, + cin: int, + cout: int, + d_model: int, # model dimension (D) + n_layer: int = 24, # number of Mamba-2 layers in the language model + d_state: int = 128, # state dimension (N) + d_conv: int = 4, # convolution kernel size + expand: int = 2, # expansion factor (E) + headdim: int = 64, # head dimension (P) + chunk_size: int = 64, # matrix partition size (Q) + ): + super().__init__() + self.fc_in = nn.Linear(cin, d_model, bias=False) # 调整通道数到cmid + self.mamba2_for = Mamba2(d_model, n_layer, d_state, d_conv, expand, headdim, chunk_size, ) # 正向 + self.mamba2_back = Mamba2(d_model, n_layer, d_state, d_conv, expand, headdim, chunk_size, ) # 负向 + self.fc_out = nn.Linear(d_model, cout, bias=False) # 调整通道数到cout + self.chunk_size = chunk_size + + @abstractmethod + def forward(self, x): + pass + + +class BiMamba2_1D(_BiMamba2): + def __init__(self, cin, cout, d_model, **mamba2_args): + super().__init__(cin, cout, d_model, **mamba2_args) + + def forward(self, x): + l = x.shape[2] + x = F.pad(x, (0, (64 - x.shape[2] % 64) % 64)) # 将 l , pad到4的倍数, [b, c64,l4] + x = x.transpose(1, 2) # 转成 1d 信号 [b, c64, d4*w4*h4] + x = self.fc_in(x) # 调整通道数为目标通道数 + x1 = self.mamba2_for(x) + x2 = self.mamba2_back(x.flip(1)).flip(1) + x = x1 + x2 + x = self.fc_out(x) # 调整通道数为目标通道数 + x = x.transpose(1, 2) # 转成 1d 信号 [b, c64, d4*w4*h4] ] + x = x[:, :, :l] # 截取原图大小 + return x + + +class BiMamba2_2D(_BiMamba2): + def __init__(self, cin, cout, d_model, **mamba2_args): + super().__init__(cin, cout, d_model, **mamba2_args) + + def forward(self, x): + h, w = x.shape[2:] + x = F.pad(x, (0, (8 - x.shape[3] % 8) % 8, + 0, (8 - x.shape[2] % 8) % 8) + ) # 将 h , w pad到8的倍数, [b, c64, h8, w8] + _b, _c, _h, _w = x.shape + x = x.permute(0, 2, 3, 1).reshape(_b, _h * _w, _c) + x = self.fc_in(x) # 调整通道数为目标通道数 + x1 = self.mamba2_for(x) + x2 = self.mamba2_back(x.flip(1)).flip(1) + x = x1 + x2 + x = self.fc_out(x) # 调整通道数为目标通道数 + x = x.reshape(_b, _h, _w, -1, ) + x = x.permute(0, 3, 1, 2) + x = x.reshape(_b, -1, _h, _w, ) + x = x[:, :, :h, :w] # 截取原图大小 + return x + + +class BiMamba2_3D(_BiMamba2): + def __init__(self, cin, cout, d_model, **mamba2_args): + super().__init__(cin, cout, d_model, **mamba2_args) + + def forward(self, x): + d, h, w = x.shape[2:] + x = F.pad(x, (0, (4 - x.shape[4] % 4) % 4, + 0, (4 - x.shape[3] % 4) % 4, + 0, (4 - x.shape[2] % 4) % 4) + ) # 将 d, h, w , pad到4的倍数, [b, c64,d4, h4, w4] + _b, _c, _d, _h, _w = x.shape + x = x.permute(0, 2, 3, 4, 1).reshape(_b, _d * _h * _w, _c) + x = self.fc_in(x) # 调整通道数为目标通道数 + x1 = self.mamba2_for(x) + x2 = self.mamba2_back(x.flip(1)).flip(1) + x = x1 + x2 + x = self.fc_out(x) # 调整通道数为目标通道数 + x = x.reshape(_b, _d, _h, _w, -1) + x = x.permute(0, 4, 1, 2, 3) + x=x.reshape(_b, -1, _d, _h, _w, ) + x = x[:, :, :d, :h, :w] # 截取原图大小 + return x + + +class BiMamba2(_BiMamba2): + def __init__(self, cin, cout, d_model, **mamba2_args): + super().__init__(cin, cout, d_model, **mamba2_args) + + def forward(self, x): + size = x.shape[2:] + out_size = 
list(x.shape) + out_size[1] = -1 + + x = torch.flatten(x, 2) # b c size + l = x.shape[2] + _s = self.chunk_size + x = F.pad(x, [0, (_s - x.shape[2] % _s) % _s]) # 将 l, pad到chunk_size的倍数, [b, c64,l4] + x = x.transpose(1, 2) # 转成 1d 信号 + x = self.fc_in(x) # 调整通道数为目标通道数 + x1 = self.mamba2_for(x) + x2 = self.mamba2_back(x.flip(1)).flip(1) + x = x1 + x2 + x = self.fc_out(x) # 调整通道数为目标通道数 + x = x.transpose(1, 2) # 转成 1d 信号 + x = x[:, :, :l] # 截取原图大小 + x = x.reshape(out_size) + + return x + + +def test_export_jit_script(net, x): + y = net(x) + net_script = torch.jit.script(net) + torch.jit.save(net_script, 'net.jit.script') + net2 = torch.jit.load('net.jit.script') + y = net2(x) + print(y.shape) + + +def test_export_onnx(net, x): + torch.onnx.export(net, + x, + "net.onnx", # 输出的 ONNX 文件名 + export_params=True, # 存储训练参数 + opset_version=14, # 指定 ONNX 操作集版本 + do_constant_folding=False, # 是否执行常量折叠优化 + input_names=['input'], # 输入张量的名称 + output_names=['output'], # 输出张量的名称 + dynamic_axes={'input': {0: 'batch_size'}, # 可变维度的字典 + 'output': {0: 'batch_size'}}) + + +if __name__ == '__main__': + # 通用的多维度双向mamba2 + from torchnssd import ( + export_jit_script, + export_onnx, + statistics, + test_run, + ) + + net_n = BiMamba2_1D(61, 128, 32).cuda() + net_n.eval() + x = torch.randn(1, 61, 63).cuda() + export_jit_script(net_n) + export_onnx(net_n, x) + test_run(net_n, x) + statistics(net_n, (61, 63)) diff --git a/separator/models/look2hear/models/__init__.py b/separator/models/look2hear/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b72fb1f70bbc43f10bd74583d582ddb45848e45e --- /dev/null +++ b/separator/models/look2hear/models/__init__.py @@ -0,0 +1,49 @@ +### +# Author: Kai Li +# Date: 2022-02-12 15:16:35 +# Email: lk21@mails.tsinghua.edu.cn +# LastEditTime: 2022-10-04 16:24:53 +### +from .base_model import BaseModel +from .apollo import Apollo + +__all__ = [ + "BaseModel", + "GullFullband", + "Apollo" +] + + +def register_model(custom_model): + """Register a custom model, gettable with `models.get`. + + Args: + custom_model: Custom model to register. + + """ + if ( + custom_model.__name__ in globals().keys() + or custom_model.__name__.lower() in globals().keys() + ): + raise ValueError( + f"Model {custom_model.__name__} already exists. Choose another name." + ) + globals().update({custom_model.__name__: custom_model}) + + +def get(identifier): + """Returns an model class from a string (case-insensitive). + + Args: + identifier (str): the model name. 
+ + Returns: + :class:`torch.nn.Module` + """ + if isinstance(identifier, str): + to_get = {k.lower(): v for k, v in globals().items()} + cls = to_get.get(identifier.lower()) + if cls is None: + raise ValueError(f"Could not interpret model name : {str(identifier)}") + return cls + raise ValueError(f"Could not interpret model name : {str(identifier)}") diff --git a/separator/models/look2hear/models/apollo.py b/separator/models/look2hear/models/apollo.py new file mode 100644 index 0000000000000000000000000000000000000000..5de9afd468e4635e581c0ff41dce7acc4eb249be --- /dev/null +++ b/separator/models/look2hear/models/apollo.py @@ -0,0 +1,324 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from .base_model import BaseModel + + +class RMSNorm(nn.Module): + def __init__(self, dimension, groups=1): + super().__init__() + + self.weight = nn.Parameter(torch.ones(dimension)) + self.groups = groups + self.eps = 1e-5 + + def forward(self, input): + # input size: (B, N, T) + B, N, T = input.shape + assert N % self.groups == 0 + + input_float = input.reshape(B, self.groups, -1, T).float() + input_norm = input_float * torch.rsqrt(input_float.pow(2).mean(-2, keepdim=True) + self.eps) + + return input_norm.type_as(input).reshape(B, N, T) * self.weight.reshape(1, -1, 1) + + +class RMVN(nn.Module): + """ + Rescaled MVN. + """ + + def __init__(self, dimension, groups=1): + super(RMVN, self).__init__() + + self.mean = nn.Parameter(torch.zeros(dimension)) + self.std = nn.Parameter(torch.ones(dimension)) + self.groups = groups + self.eps = 1e-5 + + def forward(self, input): + # input size: (B, N, *) + B, N = input.shape[:2] + assert N % self.groups == 0 + input_reshape = input.reshape(B, self.groups, N // self.groups, -1) + T = input_reshape.shape[-1] + + input_norm = (input_reshape - input_reshape.mean(2).unsqueeze(2)) / ( + input_reshape.var(2).unsqueeze(2) + self.eps).sqrt() + input_norm = input_norm.reshape(B, N, T) * self.std.reshape(1, -1, 1) + self.mean.reshape(1, -1, 1) + + return input_norm.reshape(input.shape) + + +class Roformer(nn.Module): + """ + Transformer with rotary positional embedding. + """ + + def __init__(self, input_size, hidden_size, num_head=8, theta=10000, window=10000, + input_drop=0., attention_drop=0., causal=True): + super().__init__() + + self.input_size = input_size + self.hidden_size = hidden_size // num_head + self.num_head = num_head + self.theta = theta # base frequency for RoPE + self.window = window + # pre-calculate rotary embeddings + cos_freq, sin_freq = self._calc_rotary_emb() + self.register_buffer("cos_freq", cos_freq) # win, N + self.register_buffer("sin_freq", sin_freq) # win, N + + self.attention_drop = attention_drop + self.causal = causal + self.eps = 1e-5 + + self.input_norm = RMSNorm(self.input_size) + self.input_drop = nn.Dropout(p=input_drop) + self.weight = nn.Conv1d(self.input_size, self.hidden_size * self.num_head * 3, 1, bias=False) + self.output = nn.Conv1d(self.hidden_size * self.num_head, self.input_size, 1, bias=False) + + self.MLP = nn.Sequential(RMSNorm(self.input_size), + nn.Conv1d(self.input_size, self.input_size * 8, 1, bias=False), + nn.SiLU() + ) + self.MLP_output = nn.Conv1d(self.input_size * 4, self.input_size, 1, bias=False) + + def _calc_rotary_emb(self): + freq = 1. 
/ (self.theta ** ( + torch.arange(0, self.hidden_size, 2)[:(self.hidden_size // 2)] / self.hidden_size)) # theta_i + freq = freq.reshape(1, -1) # 1, N//2 + pos = torch.arange(0, self.window).reshape(-1, 1) # win, 1 + cos_freq = torch.cos(pos * freq) # win, N//2 + sin_freq = torch.sin(pos * freq) # win, N//2 + cos_freq = torch.stack([cos_freq] * 2, -1).reshape(self.window, self.hidden_size) # win, N + sin_freq = torch.stack([sin_freq] * 2, -1).reshape(self.window, self.hidden_size) # win, N + + return cos_freq, sin_freq + + def _add_rotary_emb(self, feature, pos): + # feature shape: ..., N + N = feature.shape[-1] + + feature_reshape = feature.reshape(-1, N) + pos = min(pos, self.window - 1) + cos_freq = self.cos_freq[pos] + sin_freq = self.sin_freq[pos] + reverse_sign = torch.from_numpy(np.asarray([-1, 1])).to(feature.device).type(feature.dtype) + feature_reshape_neg = ( + torch.flip(feature_reshape.reshape(-1, N // 2, 2), [-1]) * reverse_sign.reshape(1, 1, 2)).reshape( + -1, N) + feature_rope = feature_reshape * cos_freq.unsqueeze(0) + feature_reshape_neg * sin_freq.unsqueeze(0) + + return feature_rope.reshape(feature.shape) + + def _add_rotary_sequence(self, feature): + # feature shape: ..., T, N + T, N = feature.shape[-2:] + feature_reshape = feature.reshape(-1, T, N) + + cos_freq = self.cos_freq[:T] + sin_freq = self.sin_freq[:T] + reverse_sign = torch.from_numpy(np.asarray([-1, 1])).to(feature.device).type(feature.dtype) + feature_reshape_neg = ( + torch.flip(feature_reshape.reshape(-1, N // 2, 2), [-1]) * reverse_sign.reshape(1, 1, 2)).reshape( + -1, T, N) + feature_rope = feature_reshape * cos_freq.unsqueeze(0) + feature_reshape_neg * sin_freq.unsqueeze(0) + + return feature_rope.reshape(feature.shape) + + def forward(self, input): + # input shape: B, N, T + + B, _, T = input.shape + + weight = self.weight(self.input_drop(self.input_norm(input))).reshape(B, self.num_head, self.hidden_size * 3, + T).mT + Q, K, V = torch.split(weight, self.hidden_size, dim=-1) # B, num_head, T, N + + # rotary positional embedding + Q_rot = self._add_rotary_sequence(Q) + K_rot = self._add_rotary_sequence(K) + + attention_output = F.scaled_dot_product_attention(Q_rot.contiguous(), K_rot.contiguous(), V.contiguous(), + dropout_p=self.attention_drop, + is_causal=self.causal) # B, num_head, T, N + attention_output = attention_output.mT.reshape(B, -1, T) + output = self.output(attention_output) + input + + gate, z = self.MLP(output).chunk(2, dim=1) + output = output + self.MLP_output(F.silu(gate) * z) + + return output, (K_rot, V) + + +class ConvActNorm1d(nn.Module): + def __init__(self, in_channel, hidden_channel, kernel=7, causal=False): + super(ConvActNorm1d, self).__init__() + + self.in_channel = in_channel + self.kernel = kernel + self.causal = causal + if not causal: + self.conv = nn.Sequential( + nn.Conv1d(in_channel, in_channel, kernel, padding=(kernel - 1) // 2, groups=in_channel), + RMSNorm(in_channel), + nn.Conv1d(in_channel, hidden_channel, 1), + nn.SiLU(), + nn.Conv1d(hidden_channel, in_channel, 1) + ) + else: + self.conv = nn.Sequential(nn.Conv1d(in_channel, in_channel, kernel, padding=kernel - 1, groups=in_channel), + RMSNorm(in_channel), + nn.Conv1d(in_channel, hidden_channel, 1), + nn.SiLU(), + nn.Conv1d(hidden_channel, in_channel, 1) + ) + + def forward(self, input): + + output = self.conv(input) + if self.causal: + output = output[..., :-self.kernel + 1] + return input + output + + +class ICB(nn.Module): + def __init__(self, in_channel, kernel=7, causal=False): + super(ICB, 
self).__init__() + + self.blocks = nn.Sequential(ConvActNorm1d(in_channel, in_channel * 4, kernel, causal=causal), + ConvActNorm1d(in_channel, in_channel * 4, kernel, causal=causal), + ConvActNorm1d(in_channel, in_channel * 4, kernel, causal=causal) + ) + + def forward(self, input): + return self.blocks(input) + + +class BSNet(nn.Module): + def __init__(self, feature_dim, kernel=7): + super(BSNet, self).__init__() + + self.feature_dim = feature_dim + + self.band_net = Roformer(self.feature_dim, self.feature_dim, num_head=8, window=100, causal=False) + self.seq_net = ICB(self.feature_dim, kernel=kernel) + + def forward(self, input): + # input shape: B, nband, N, T + + B, nband, N, T = input.shape + + # band comm + band_input = input.permute(0, 3, 2, 1).reshape(B * T, -1, nband) + band_output, _ = self.band_net(band_input) + band_output = band_output.reshape(B, T, -1, nband).permute(0, 3, 2, 1) + + # sequence modeling + output = self.seq_net(band_output.reshape(B * nband, -1, T)).reshape(B, nband, -1, T) # B, nband, N, T + + return output + + +class Apollo(BaseModel): + def __init__( + self, + sr: int, + win: int, + feature_dim: int, + layer: int + ): + super().__init__(sample_rate=sr) + + self.sr = sr + self.win = int(sr * win // 1000) + self.stride = self.win // 2 + self.enc_dim = self.win // 2 + 1 + self.feature_dim = feature_dim + self.eps = torch.finfo(torch.float32).eps + + # 80 bands + bandwidth = int(self.win / 160) + self.band_width = [bandwidth] * 79 + self.band_width.append(self.enc_dim - np.sum(self.band_width)) + self.nband = len(self.band_width) + print(self.band_width, self.nband) + + self.BN = nn.ModuleList([]) + for i in range(self.nband): + self.BN.append(nn.Sequential(RMSNorm(self.band_width[i] * 2 + 1), + nn.Conv1d(self.band_width[i] * 2 + 1, self.feature_dim, 1)) + ) + + self.net = [] + for _ in range(layer): + self.net.append(BSNet(self.feature_dim)) + self.net = nn.Sequential(*self.net) + + self.output = nn.ModuleList([]) + for i in range(self.nband): + self.output.append(nn.Sequential(RMSNorm(self.feature_dim), + nn.Conv1d(self.feature_dim, self.band_width[i] * 4, 1), + nn.GLU(dim=1) + ) + ) + + def spec_band_split(self, input): + + B, nch, nsample = input.shape + + spec = torch.stft(input.view(B * nch, nsample), n_fft=self.win, hop_length=self.stride, + window=torch.hann_window(self.win).to(input.device), return_complex=True) + + subband_spec = [] + subband_spec_norm = [] + subband_power = [] + band_idx = 0 + for i in range(self.nband): + this_spec = spec[:, band_idx:band_idx + self.band_width[i]] + subband_spec.append(this_spec) # B, BW, T + subband_power.append((this_spec.abs().pow(2).sum(1) + self.eps).sqrt().unsqueeze(1)) # B, 1, T + subband_spec_norm.append( + torch.complex(this_spec.real / subband_power[-1], this_spec.imag / subband_power[-1])) # B, BW, T + band_idx += self.band_width[i] + subband_power = torch.cat(subband_power, 1) # B, nband, T + + return subband_spec_norm, subband_power + + def feature_extractor(self, input): + + subband_spec_norm, subband_power = self.spec_band_split(input) + + # normalization and bottleneck + subband_feature = [] + for i in range(self.nband): + concat_spec = torch.cat( + [subband_spec_norm[i].real, subband_spec_norm[i].imag, torch.log(subband_power[:, i].unsqueeze(1))], 1) + subband_feature.append(self.BN[i](concat_spec)) + subband_feature = torch.stack(subband_feature, 1) # B, nband, N, T + + return subband_feature + + def forward(self, input): + + B, nch, nsample = input.shape + + subband_feature = 
self.feature_extractor(input) + feature = self.net(subband_feature) + + est_spec = [] + for i in range(self.nband): + this_RI = self.output[i](feature[:, i]).view(B * nch, 2, self.band_width[i], -1) + est_spec.append(torch.complex(this_RI[:, 0], this_RI[:, 1])) + est_spec = torch.cat(est_spec, 1) + est_spec = est_spec.to(dtype=torch.complex64) + output = torch.istft(est_spec, n_fft=self.win, hop_length=self.stride, + window=torch.hann_window(self.win).to(input.device), length=nsample).view(B, nch, -1) + + return output + + def get_model_args(self): + model_args = {"n_sample_rate": 2} + return model_args \ No newline at end of file diff --git a/separator/models/look2hear/models/base_model.py b/separator/models/look2hear/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e24b23192141634c0496be04899b6ecdb683b6c5 --- /dev/null +++ b/separator/models/look2hear/models/base_model.py @@ -0,0 +1,100 @@ +### +# Author: Kai Li +# Date: 2021-06-17 23:08:32 +# LastEditors: Please set LastEditors +# LastEditTime: 2022-05-26 18:06:22 +### +import torch +import torch.nn as nn + + +def _unsqueeze_to_3d(x): + """Normalize shape of `x` to [batch, n_chan, time].""" + if x.ndim == 1: + return x.reshape(1, 1, -1) + elif x.ndim == 2: + return x.unsqueeze(1) + else: + return x + + +def pad_to_appropriate_length(x, lcm): + values_to_pad = int(x.shape[-1]) % lcm + if values_to_pad: + appropriate_shape = x.shape + padded_x = torch.zeros( + list(appropriate_shape[:-1]) + + [appropriate_shape[-1] + lcm - values_to_pad], + dtype=torch.float32, + ).to(x.device) + padded_x[..., : x.shape[-1]] = x + return padded_x + return x + + +class BaseModel(nn.Module): + def __init__(self, sample_rate, in_chan=1): + super().__init__() + self._sample_rate = sample_rate + self._in_chan = in_chan + + def forward(self, *args, **kwargs): + raise NotImplementedError + + def sample_rate(self,): + return self._sample_rate + + @staticmethod + def load_state_dict_in_audio(model, pretrained_dict): + model_dict = model.state_dict() + update_dict = {} + for k, v in pretrained_dict.items(): + if "audio_model" in k: + update_dict[k[12:]] = v + model_dict.update(update_dict) + model.load_state_dict(model_dict) + return model + + @staticmethod + def from_pretrain(pretrained_model_conf_or_path, *args, **kwargs): + from . import get + + conf = torch.load( + pretrained_model_conf_or_path, map_location="cpu" + ) # Attempt to find the model and instantiate it. + + model_class = get(conf["model_name"]) + # model_class = get("Conv_TasNet") + model = model_class(*args, **kwargs) + model.load_state_dict(conf["state_dict"]) + return model + + def apollo(*args, **kwargs): + from . 
import get + model_class = get('Apollo') + model = model_class(*args, **kwargs) + return model + + def serialize(self): + import pytorch_lightning as pl # Not used in torch.hub + + model_conf = dict( + model_name=self.__class__.__name__, + state_dict=self.get_state_dict(), + model_args=self.get_model_args(), + ) + # Additional infos + infos = dict() + infos["software_versions"] = dict( + torch_version=torch.__version__, pytorch_lightning_version=pl.__version__, + ) + model_conf["infos"] = infos + return model_conf + + def get_state_dict(self): + """In case the state dict needs to be modified before sharing the model.""" + return self.state_dict() + + def get_model_args(self): + """Should return args to re-instantiate the class.""" + raise NotImplementedError diff --git a/separator/models/mdx23c_tfc_tdf_v3.py b/separator/models/mdx23c_tfc_tdf_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..bb7fca037728e7d7c9e6533765a314ea076c3670 --- /dev/null +++ b/separator/models/mdx23c_tfc_tdf_v3.py @@ -0,0 +1,242 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial +from msst_utils import prefer_target_instrument + +class STFT: + def __init__(self, config): + self.n_fft = config.n_fft + self.hop_length = config.hop_length + self.window = torch.hann_window(window_length=self.n_fft, periodic=True) + self.dim_f = config.dim_f + + def __call__(self, x): + window = self.window.to(x.device) + batch_dims = x.shape[:-2] + c, t = x.shape[-2:] + x = x.reshape([-1, t]) + x = torch.stft( + x, + n_fft=self.n_fft, + hop_length=self.hop_length, + window=window, + center=True, + return_complex=True + ) + x = torch.view_as_real(x) + x = x.permute([0, 3, 1, 2]) + x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]]) + return x[..., :self.dim_f, :] + + def inverse(self, x): + window = self.window.to(x.device) + batch_dims = x.shape[:-3] + c, f, t = x.shape[-3:] + n = self.n_fft // 2 + 1 + f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device) + x = torch.cat([x, f_pad], -2) + x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t]) + x = x.permute([0, 2, 3, 1]) + x = x[..., 0] + x[..., 1] * 1.j + x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True) + x = x.reshape([*batch_dims, 2, -1]) + return x + + +def get_norm(norm_type): + def norm(c, norm_type): + if norm_type == 'BatchNorm': + return nn.BatchNorm2d(c) + elif norm_type == 'InstanceNorm': + return nn.InstanceNorm2d(c, affine=True) + elif 'GroupNorm' in norm_type: + g = int(norm_type.replace('GroupNorm', '')) + return nn.GroupNorm(num_groups=g, num_channels=c) + else: + return nn.Identity() + + return partial(norm, norm_type=norm_type) + + +def get_act(act_type): + if act_type == 'gelu': + return nn.GELU() + elif act_type == 'relu': + return nn.ReLU() + elif act_type[:3] == 'elu': + alpha = float(act_type.replace('elu', '')) + return nn.ELU(alpha) + else: + raise Exception + + +class Upscale(nn.Module): + def __init__(self, in_c, out_c, scale, norm, act): + super().__init__() + self.conv = nn.Sequential( + norm(in_c), + act, + nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False) + ) + + def forward(self, x): + return self.conv(x) + + +class Downscale(nn.Module): + def __init__(self, in_c, out_c, scale, norm, act): + super().__init__() + self.conv = nn.Sequential( + norm(in_c), + act, + nn.Conv2d(in_channels=in_c, out_channels=out_c, 
kernel_size=scale, stride=scale, bias=False) + ) + + def forward(self, x): + return self.conv(x) + + +class TFC_TDF(nn.Module): + def __init__(self, in_c, c, l, f, bn, norm, act): + super().__init__() + + self.blocks = nn.ModuleList() + for i in range(l): + block = nn.Module() + + block.tfc1 = nn.Sequential( + norm(in_c), + act, + nn.Conv2d(in_c, c, 3, 1, 1, bias=False), + ) + block.tdf = nn.Sequential( + norm(c), + act, + nn.Linear(f, f // bn, bias=False), + norm(c), + act, + nn.Linear(f // bn, f, bias=False), + ) + block.tfc2 = nn.Sequential( + norm(c), + act, + nn.Conv2d(c, c, 3, 1, 1, bias=False), + ) + block.shortcut = nn.Conv2d(in_c, c, 1, 1, 0, bias=False) + + self.blocks.append(block) + in_c = c + + def forward(self, x): + for block in self.blocks: + s = block.shortcut(x) + x = block.tfc1(x) + x = x + block.tdf(x) + x = block.tfc2(x) + x = x + s + return x + + +class TFC_TDF_net(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + norm = get_norm(norm_type=config.model.norm) + act = get_act(act_type=config.model.act) + + self.num_target_instruments = len(prefer_target_instrument(config)) + self.num_subbands = config.model.num_subbands + + dim_c = self.num_subbands * config.audio.num_channels * 2 + n = config.model.num_scales + scale = config.model.scale + l = config.model.num_blocks_per_scale + c = config.model.num_channels + g = config.model.growth + bn = config.model.bottleneck_factor + f = config.audio.dim_f // self.num_subbands + + self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False) + + self.encoder_blocks = nn.ModuleList() + for i in range(n): + block = nn.Module() + block.tfc_tdf = TFC_TDF(c, c, l, f, bn, norm, act) + block.downscale = Downscale(c, c + g, scale, norm, act) + f = f // scale[1] + c += g + self.encoder_blocks.append(block) + + self.bottleneck_block = TFC_TDF(c, c, l, f, bn, norm, act) + + self.decoder_blocks = nn.ModuleList() + for i in range(n): + block = nn.Module() + block.upscale = Upscale(c, c - g, scale, norm, act) + f = f * scale[1] + c -= g + block.tfc_tdf = TFC_TDF(2 * c, c, l, f, bn, norm, act) + self.decoder_blocks.append(block) + + self.final_conv = nn.Sequential( + nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False), + act, + nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False) + ) + + self.stft = STFT(config.audio) + + def cac2cws(self, x): + k = self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c, k, f // k, t) + x = x.reshape(b, c * k, f // k, t) + return x + + def cws2cac(self, x): + k = self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c // k, k, f, t) + x = x.reshape(b, c // k, f * k, t) + return x + + def forward(self, x): + + x = self.stft(x) + + mix = x = self.cac2cws(x) + + first_conv_out = x = self.first_conv(x) + + x = x.transpose(-1, -2) + + encoder_outputs = [] + for block in self.encoder_blocks: + x = block.tfc_tdf(x) + encoder_outputs.append(x) + x = block.downscale(x) + + x = self.bottleneck_block(x) + + for block in self.decoder_blocks: + x = block.upscale(x) + x = torch.cat([x, encoder_outputs.pop()], 1) + x = block.tfc_tdf(x) + + x = x.transpose(-1, -2) + + x = x * first_conv_out # reduce artifacts + + x = self.final_conv(torch.cat([mix, x], 1)) + + x = self.cws2cac(x) + + if self.num_target_instruments > 1: + b, c, f, t = x.shape + x = x.reshape(b, self.num_target_instruments, -1, f, t) + + x = self.stft.inverse(x) + + return x diff --git a/separator/models/scnet/__init__.py b/separator/models/scnet/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..5f6ecefede9345237623066dd21ebd8253af1c60 --- /dev/null +++ b/separator/models/scnet/__init__.py @@ -0,0 +1 @@ +from .scnet import SCNet diff --git a/separator/models/scnet/scnet.py b/separator/models/scnet/scnet.py new file mode 100644 index 0000000000000000000000000000000000000000..b27704dc922eb593dc76f3b9905aa8c0ea02507f --- /dev/null +++ b/separator/models/scnet/scnet.py @@ -0,0 +1,373 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from collections import deque +from .separation import SeparationNet +import typing as tp +import math + + +class Swish(nn.Module): + def forward(self, x): + return x * x.sigmoid() + + +class ConvolutionModule(nn.Module): + """ + Convolution Module in SD block. + + Args: + channels (int): input/output channels. + depth (int): number of layers in the residual branch. Each layer has its own + compress (float): amount of channel compression. + kernel (int): kernel size for the convolutions. + """ + + def __init__(self, channels, depth=2, compress=4, kernel=3): + super().__init__() + assert kernel % 2 == 1 + self.depth = abs(depth) + hidden_size = int(channels / compress) + norm = lambda d: nn.GroupNorm(1, d) + self.layers = nn.ModuleList([]) + for _ in range(self.depth): + padding = (kernel // 2) + mods = [ + norm(channels), + nn.Conv1d(channels, hidden_size * 2, kernel, padding=padding), + nn.GLU(1), + nn.Conv1d(hidden_size, hidden_size, kernel, padding=padding, groups=hidden_size), + norm(hidden_size), + Swish(), + nn.Conv1d(hidden_size, channels, 1), + ] + layer = nn.Sequential(*mods) + self.layers.append(layer) + + def forward(self, x): + for layer in self.layers: + x = x + layer(x) + return x + + +class FusionLayer(nn.Module): + """ + A FusionLayer within the decoder. + + Args: + - channels (int): Number of input channels. + - kernel_size (int, optional): Kernel size for the convolutional layer, defaults to 3. + - stride (int, optional): Stride for the convolutional layer, defaults to 1. + - padding (int, optional): Padding for the convolutional layer, defaults to 1. + """ + + def __init__(self, channels, kernel_size=3, stride=1, padding=1): + super(FusionLayer, self).__init__() + self.conv = nn.Conv2d(channels * 2, channels * 2, kernel_size, stride=stride, padding=padding) + + def forward(self, x, skip=None): + if skip is not None: + x += skip + x = x.repeat(1, 2, 1, 1) + x = self.conv(x) + x = F.glu(x, dim=1) + return x + + +class SDlayer(nn.Module): + """ + Implements a Sparse Down-sample Layer for processing different frequency bands separately. + + Args: + - channels_in (int): Input channel count. + - channels_out (int): Output channel count. + - band_configs (dict): A dictionary containing configuration for each frequency band. + Keys are 'low', 'mid', 'high' for each band, and values are + dictionaries with keys 'SR', 'stride', and 'kernel' for proportion, + stride, and kernel size, respectively. 
+ """ + + def __init__(self, channels_in, channels_out, band_configs): + super(SDlayer, self).__init__() + + # Initializing convolutional layers for each band + self.convs = nn.ModuleList() + self.strides = [] + self.kernels = [] + for config in band_configs.values(): + self.convs.append( + nn.Conv2d(channels_in, channels_out, (config['kernel'], 1), (config['stride'], 1), (0, 0))) + self.strides.append(config['stride']) + self.kernels.append(config['kernel']) + + # Saving rate proportions for determining splits + self.SR_low = band_configs['low']['SR'] + self.SR_mid = band_configs['mid']['SR'] + + def forward(self, x): + B, C, Fr, T = x.shape + # Define splitting points based on sampling rates + splits = [ + (0, math.ceil(Fr * self.SR_low)), + (math.ceil(Fr * self.SR_low), math.ceil(Fr * (self.SR_low + self.SR_mid))), + (math.ceil(Fr * (self.SR_low + self.SR_mid)), Fr) + ] + + # Processing each band with the corresponding convolution + outputs = [] + original_lengths = [] + for conv, stride, kernel, (start, end) in zip(self.convs, self.strides, self.kernels, splits): + extracted = x[:, :, start:end, :] + original_lengths.append(end - start) + current_length = extracted.shape[2] + + # padding + if stride == 1: + total_padding = kernel - stride + else: + total_padding = (stride - current_length % stride) % stride + pad_left = total_padding // 2 + pad_right = total_padding - pad_left + + padded = F.pad(extracted, (0, 0, pad_left, pad_right)) + + output = conv(padded) + outputs.append(output) + + return outputs, original_lengths + + +class SUlayer(nn.Module): + """ + Implements a Sparse Up-sample Layer in decoder. + + Args: + - channels_in: The number of input channels. + - channels_out: The number of output channels. + - convtr_configs: Dictionary containing the configurations for transposed convolutions. + """ + + def __init__(self, channels_in, channels_out, band_configs): + super(SUlayer, self).__init__() + + # Initializing convolutional layers for each band + self.convtrs = nn.ModuleList([ + nn.ConvTranspose2d(channels_in, channels_out, [config['kernel'], 1], [config['stride'], 1]) + for _, config in band_configs.items() + ]) + + def forward(self, x, lengths, origin_lengths): + B, C, Fr, T = x.shape + # Define splitting points based on input lengths + splits = [ + (0, lengths[0]), + (lengths[0], lengths[0] + lengths[1]), + (lengths[0] + lengths[1], None) + ] + # Processing each band with the corresponding convolution + outputs = [] + for idx, (convtr, (start, end)) in enumerate(zip(self.convtrs, splits)): + out = convtr(x[:, :, start:end, :]) + # Calculate the distance to trim the output symmetrically to original length + current_Fr_length = out.shape[2] + dist = abs(origin_lengths[idx] - current_Fr_length) // 2 + + # Trim the output to the original length symmetrically + trimmed_out = out[:, :, dist:dist + origin_lengths[idx], :] + + outputs.append(trimmed_out) + + # Concatenate trimmed outputs along the frequency dimension to return the final tensor + x = torch.cat(outputs, dim=2) + + return x + + +class SDblock(nn.Module): + """ + Implements a simplified Sparse Down-sample block in encoder. + + Args: + - channels_in (int): Number of input channels. + - channels_out (int): Number of output channels. + - band_config (dict): Configuration for the SDlayer specifying band splits and convolutions. + - conv_config (dict): Configuration for convolution modules applied to each band. + - depths (list of int): List specifying the convolution depths for low, mid, and high frequency bands. 
+ """ + + def __init__(self, channels_in, channels_out, band_configs={}, conv_config={}, depths=[3, 2, 1], kernel_size=3): + super(SDblock, self).__init__() + self.SDlayer = SDlayer(channels_in, channels_out, band_configs) + + # Dynamically create convolution modules for each band based on depths + self.conv_modules = nn.ModuleList([ + ConvolutionModule(channels_out, depth, **conv_config) for depth in depths + ]) + # Set the kernel_size to an odd number. + self.globalconv = nn.Conv2d(channels_out, channels_out, kernel_size, 1, (kernel_size - 1) // 2) + + def forward(self, x): + bands, original_lengths = self.SDlayer(x) + # B, C, f, T = band.shape + bands = [ + F.gelu( + conv(band.permute(0, 2, 1, 3).reshape(-1, band.shape[1], band.shape[3])) + .view(band.shape[0], band.shape[2], band.shape[1], band.shape[3]) + .permute(0, 2, 1, 3) + ) + for conv, band in zip(self.conv_modules, bands) + + ] + lengths = [band.size(-2) for band in bands] + full_band = torch.cat(bands, dim=2) + skip = full_band + + output = self.globalconv(full_band) + + return output, skip, lengths, original_lengths + + +class SCNet(nn.Module): + """ + The implementation of SCNet: Sparse Compression Network for Music Source Separation. Paper: https://arxiv.org/abs/2401.13276.pdf + + Args: + - sources (List[str]): List of sources to be separated. + - audio_channels (int): Number of audio channels. + - nfft (int): Number of FFTs to determine the frequency dimension of the input. + - hop_size (int): Hop size for the STFT. + - win_size (int): Window size for STFT. + - normalized (bool): Whether to normalize the STFT. + - dims (List[int]): List of channel dimensions for each block. + - band_SR (List[float]): The proportion of each frequency band. + - band_stride (List[int]): The down-sampling ratio of each frequency band. + - band_kernel (List[int]): The kernel sizes for down-sampling convolution in each frequency band + - conv_depths (List[int]): List specifying the number of convolution modules in each SD block. + - compress (int): Compression factor for convolution module. + - conv_kernel (int): Kernel size for convolution layer in convolution module. + - num_dplayer (int): Number of dual-path layers. + - expand (int): Expansion factor in the dual-path RNN, default is 1. 
+ + """ + + def __init__(self, + sources=['drums', 'bass', 'other', 'vocals'], + audio_channels=2, + # Main structure + dims=[4, 32, 64, 128], # dims = [4, 64, 128, 256] in SCNet-large + # STFT + nfft=4096, + hop_size=1024, + win_size=4096, + normalized=True, + # SD/SU layer + band_SR=[0.175, 0.392, 0.433], + band_stride=[1, 4, 16], + band_kernel=[3, 4, 16], + # Convolution Module + conv_depths=[3, 2, 1], + compress=4, + conv_kernel=3, + # Dual-path RNN + num_dplayer=6, + expand=1, + ): + super().__init__() + self.sources = sources + self.audio_channels = audio_channels + self.dims = dims + band_keys = ['low', 'mid', 'high'] + self.band_configs = {band_keys[i]: {'SR': band_SR[i], 'stride': band_stride[i], 'kernel': band_kernel[i]} for i + in range(len(band_keys))} + self.hop_length = hop_size + self.conv_config = { + 'compress': compress, + 'kernel': conv_kernel, + } + + self.stft_config = { + 'n_fft': nfft, + 'hop_length': hop_size, + 'win_length': win_size, + 'center': True, + 'normalized': normalized + } + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + for index in range(len(dims) - 1): + enc = SDblock( + channels_in=dims[index], + channels_out=dims[index + 1], + band_configs=self.band_configs, + conv_config=self.conv_config, + depths=conv_depths + ) + self.encoder.append(enc) + + dec = nn.Sequential( + FusionLayer(channels=dims[index + 1]), + SUlayer( + channels_in=dims[index + 1], + channels_out=dims[index] if index != 0 else dims[index] * len(sources), + band_configs=self.band_configs, + ) + ) + self.decoder.insert(0, dec) + + self.separation_net = SeparationNet( + channels=dims[-1], + expand=expand, + num_layers=num_dplayer, + ) + + def forward(self, x): + # B, C, L = x.shape + B = x.shape[0] + # In the initial padding, ensure that the number of frames after the STFT (the length of the T dimension) is even, + # so that the RFFT operation can be used in the separation network. 
+ padding = self.hop_length - x.shape[-1] % self.hop_length + if (x.shape[-1] + padding) // self.hop_length % 2 == 0: + padding += self.hop_length + x = F.pad(x, (0, padding)) + + # STFT + L = x.shape[-1] + x = x.reshape(-1, L) + x = torch.stft(x, **self.stft_config, return_complex=True) + x = torch.view_as_real(x) + x = x.permute(0, 3, 1, 2).reshape(x.shape[0] // self.audio_channels, x.shape[3] * self.audio_channels, + x.shape[1], x.shape[2]) + + B, C, Fr, T = x.shape + + save_skip = deque() + save_lengths = deque() + save_original_lengths = deque() + # encoder + for sd_layer in self.encoder: + x, skip, lengths, original_lengths = sd_layer(x) + save_skip.append(skip) + save_lengths.append(lengths) + save_original_lengths.append(original_lengths) + + # separation + x = self.separation_net(x) + + # decoder + for fusion_layer, su_layer in self.decoder: + x = fusion_layer(x, save_skip.pop()) + x = su_layer(x, save_lengths.pop(), save_original_lengths.pop()) + + # output + n = self.dims[0] + x = x.view(B, n, -1, Fr, T) + x = x.reshape(-1, 2, Fr, T).permute(0, 2, 3, 1) + x = torch.view_as_complex(x.contiguous()) + x = torch.istft(x, **self.stft_config) + x = x.reshape(B, len(self.sources), self.audio_channels, -1) + + x = x[:, :, :, :-padding] + + return x diff --git a/separator/models/scnet/separation.py b/separator/models/scnet/separation.py new file mode 100644 index 0000000000000000000000000000000000000000..d902dac4d947123d3ba1270dd065be0d8b4c5ed9 --- /dev/null +++ b/separator/models/scnet/separation.py @@ -0,0 +1,113 @@ +import torch +import torch.nn as nn +from torch.nn.modules.rnn import LSTM + + +class FeatureConversion(nn.Module): + """ + Integrates into the adjacent Dual-Path layer. + + Args: + channels (int): Number of input channels. + inverse (bool): If True, uses ifft; otherwise, uses rfft. + """ + + def __init__(self, channels, inverse): + super().__init__() + self.inverse = inverse + self.channels = channels + + def forward(self, x): + # B, C, F, T = x.shape + if self.inverse: + x = x.float() + x_r = x[:, :self.channels // 2, :, :] + x_i = x[:, self.channels // 2:, :, :] + x = torch.complex(x_r, x_i) + x = torch.fft.irfft(x, dim=3, norm="ortho") + else: + x = x.float() + x = torch.fft.rfft(x, dim=3, norm="ortho") + x_real = x.real + x_imag = x.imag + x = torch.cat([x_real, x_imag], dim=1) + return x + + +class DualPathRNN(nn.Module): + """ + Dual-Path RNN in Separation Network. + + Args: + d_model (int): The number of expected features in the input (input_size). + expand (int): Expansion factor used to calculate the hidden_size of LSTM. + bidirectional (bool): If True, becomes a bidirectional LSTM. 
+ """ + + def __init__(self, d_model, expand, bidirectional=True): + super(DualPathRNN, self).__init__() + + self.d_model = d_model + self.hidden_size = d_model * expand + self.bidirectional = bidirectional + # Initialize LSTM layers and normalization layers + self.lstm_layers = nn.ModuleList([self._init_lstm_layer(self.d_model, self.hidden_size) for _ in range(2)]) + self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_size * 2, self.d_model) for _ in range(2)]) + self.norm_layers = nn.ModuleList([nn.GroupNorm(1, d_model) for _ in range(2)]) + + def _init_lstm_layer(self, d_model, hidden_size): + return LSTM(d_model, hidden_size, num_layers=1, bidirectional=self.bidirectional, batch_first=True) + + def forward(self, x): + B, C, F, T = x.shape + + # Process dual-path rnn + original_x = x + # Frequency-path + x = self.norm_layers[0](x) + x = x.transpose(1, 3).contiguous().view(B * T, F, C) + x, _ = self.lstm_layers[0](x) + x = self.linear_layers[0](x) + x = x.view(B, T, F, C).transpose(1, 3) + x = x + original_x + + original_x = x + # Time-path + x = self.norm_layers[1](x) + x = x.transpose(1, 2).contiguous().view(B * F, C, T).transpose(1, 2) + x, _ = self.lstm_layers[1](x) + x = self.linear_layers[1](x) + x = x.transpose(1, 2).contiguous().view(B, F, C, T).transpose(1, 2) + x = x + original_x + + return x + + +class SeparationNet(nn.Module): + """ + Implements a simplified Sparse Down-sample block in an encoder architecture. + + Args: + - channels (int): Number input channels. + - expand (int): Expansion factor used to calculate the hidden_size of LSTM. + - num_layers (int): Number of dual-path layers. + """ + + def __init__(self, channels, expand=1, num_layers=6): + super(SeparationNet, self).__init__() + + self.num_layers = num_layers + + self.dp_modules = nn.ModuleList([ + DualPathRNN(channels * (2 if i % 2 == 1 else 1), expand) for i in range(num_layers) + ]) + + self.feature_conversion = nn.ModuleList([ + FeatureConversion(channels * 2, inverse=False if i % 2 == 0 else True) for i in range(num_layers) + ]) + + def forward(self, x): + for i in range(self.num_layers): + x = self.dp_modules[i](x) + x = self.feature_conversion[i](x) + return x diff --git a/separator/models/scnet_unofficial/__init__.py b/separator/models/scnet_unofficial/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6d034d38a2ace2e81bd28d63dd8f25feb918f33d --- /dev/null +++ b/separator/models/scnet_unofficial/__init__.py @@ -0,0 +1 @@ +from models.scnet_unofficial.scnet import SCNet \ No newline at end of file diff --git a/separator/models/scnet_unofficial/modules/__init__.py b/separator/models/scnet_unofficial/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..69617bb15044d9bbfd0211fcdfa0fa605b01c048 --- /dev/null +++ b/separator/models/scnet_unofficial/modules/__init__.py @@ -0,0 +1,3 @@ +from models.scnet_unofficial.modules.dualpath_rnn import DualPathRNN +from models.scnet_unofficial.modules.sd_encoder import SDBlock +from models.scnet_unofficial.modules.su_decoder import SUBlock diff --git a/separator/models/scnet_unofficial/modules/dualpath_rnn.py b/separator/models/scnet_unofficial/modules/dualpath_rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..2dfcdbcfc102a6fde5a2ff53a2a06f2d6caae196 --- /dev/null +++ b/separator/models/scnet_unofficial/modules/dualpath_rnn.py @@ -0,0 +1,228 @@ +import torch +import torch.nn as nn +import torch.nn.functional as Func + +class RMSNorm(nn.Module): + def __init__(self, dim): + 
super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return Func.normalize(x, dim=-1) * self.scale * self.gamma + + +class MambaModule(nn.Module): + def __init__(self, d_model, d_state, d_conv, d_expand): + super().__init__() + self.norm = RMSNorm(dim=d_model) + self.mamba = Mamba( + d_model=d_model, + d_state=d_state, + d_conv=d_conv, + d_expand=d_expand + ) + + def forward(self, x): + x = x + self.mamba(self.norm(x)) + return x + + +class RNNModule(nn.Module): + """ + RNNModule class implements a recurrent neural network module with LSTM cells. + + Args: + - input_dim (int): Dimensionality of the input features. + - hidden_dim (int): Dimensionality of the hidden state of the LSTM. + - bidirectional (bool, optional): If True, uses bidirectional LSTM. Defaults to True. + + Shapes: + - Input: (B, T, D) where + B is batch size, + T is sequence length, + D is input dimensionality. + - Output: (B, T, D) where + B is batch size, + T is sequence length, + D is input dimensionality. + """ + + def __init__(self, input_dim: int, hidden_dim: int, bidirectional: bool = True): + """ + Initializes RNNModule with input dimension, hidden dimension, and bidirectional flag. + """ + super().__init__() + self.groupnorm = nn.GroupNorm(num_groups=1, num_channels=input_dim) + self.rnn = nn.LSTM( + input_dim, hidden_dim, batch_first=True, bidirectional=bidirectional + ) + self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, input_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the RNNModule. + + Args: + - x (torch.Tensor): Input tensor of shape (B, T, D). + + Returns: + - torch.Tensor: Output tensor of shape (B, T, D). + """ + x = x.transpose(1, 2) + x = self.groupnorm(x) + x = x.transpose(1, 2) + + x, (hidden, _) = self.rnn(x) + x = self.fc(x) + return x + + +class RFFTModule(nn.Module): + """ + RFFTModule class implements a module for performing real-valued Fast Fourier Transform (FFT) + or its inverse on input tensors. + + Args: + - inverse (bool, optional): If False, performs forward FFT. If True, performs inverse FFT. Defaults to False. + + Shapes: + - Input: (B, F, T, D) where + B is batch size, + F is the number of features, + T is sequence length, + D is input dimensionality. + - Output: (B, F, T // 2 + 1, D * 2) if performing forward FFT. + (B, F, T, D // 2, 2) if performing inverse FFT. + """ + + def __init__(self, inverse: bool = False): + """ + Initializes RFFTModule with inverse flag. + """ + super().__init__() + self.inverse = inverse + + def forward(self, x: torch.Tensor, time_dim: int) -> torch.Tensor: + """ + Performs forward or inverse FFT on the input tensor x. + + Args: + - x (torch.Tensor): Input tensor of shape (B, F, T, D). + - time_dim (int): Input size of time dimension. + + Returns: + - torch.Tensor: Output tensor after FFT or its inverse operation. + """ + dtype = x.dtype + B, F, T, D = x.shape + + # RuntimeError: cuFFT only supports dimensions whose sizes are powers of two when computing in half precision + x = x.float() + + if not self.inverse: + x = torch.fft.rfft(x, dim=2) + x = torch.view_as_real(x) + x = x.reshape(B, F, T // 2 + 1, D * 2) + else: + x = x.reshape(B, F, T, D // 2, 2) + x = torch.view_as_complex(x) + x = torch.fft.irfft(x, n=time_dim, dim=2) + + x = x.to(dtype) + return x + + def extra_repr(self) -> str: + """ + Returns extra representation string with module's configuration. 
+ """ + return f"inverse={self.inverse}" + + +class DualPathRNN(nn.Module): + """ + DualPathRNN class implements a neural network with alternating layers of RNNModule and RFFTModule. + + Args: + - n_layers (int): Number of layers in the network. + - input_dim (int): Dimensionality of the input features. + - hidden_dim (int): Dimensionality of the hidden state of the RNNModule. + + Shapes: + - Input: (B, F, T, D) where + B is batch size, + F is the number of features (frequency dimension), + T is sequence length (time dimension), + D is input dimensionality (channel dimension). + - Output: (B, F, T, D) where + B is batch size, + F is the number of features (frequency dimension), + T is sequence length (time dimension), + D is input dimensionality (channel dimension). + """ + + def __init__( + self, + n_layers: int, + input_dim: int, + hidden_dim: int, + + use_mamba: bool = False, + d_state: int = 16, + d_conv: int = 4, + d_expand: int = 2 + ): + """ + Initializes DualPathRNN with the specified number of layers, input dimension, and hidden dimension. + """ + super().__init__() + + if use_mamba: + from mamba_ssm.modules.mamba_simple import Mamba + net = MambaModule + dkwargs = {"d_model": input_dim, "d_state": d_state, "d_conv": d_conv, "d_expand": d_expand} + ukwargs = {"d_model": input_dim * 2, "d_state": d_state, "d_conv": d_conv, "d_expand": d_expand * 2} + else: + net = RNNModule + dkwargs = {"input_dim": input_dim, "hidden_dim": hidden_dim} + ukwargs = {"input_dim": input_dim * 2, "hidden_dim": hidden_dim * 2} + + self.layers = nn.ModuleList() + for i in range(1, n_layers + 1): + kwargs = dkwargs if i % 2 == 1 else ukwargs + layer = nn.ModuleList([ + net(**kwargs), + net(**kwargs), + RFFTModule(inverse=(i % 2 == 0)), + ]) + self.layers.append(layer) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the DualPathRNN. + + Args: + - x (torch.Tensor): Input tensor of shape (B, F, T, D). + + Returns: + - torch.Tensor: Output tensor of shape (B, F, T, D). + """ + + time_dim = x.shape[2] + + for time_layer, freq_layer, rfft_layer in self.layers: + B, F, T, D = x.shape + + x = x.reshape((B * F), T, D) + x = time_layer(x) + x = x.reshape(B, F, T, D) + x = x.permute(0, 2, 1, 3) + + x = x.reshape((B * T), F, D) + x = freq_layer(x) + x = x.reshape(B, T, F, D) + x = x.permute(0, 2, 1, 3) + + x = rfft_layer(x, time_dim) + + return x diff --git a/separator/models/scnet_unofficial/modules/sd_encoder.py b/separator/models/scnet_unofficial/modules/sd_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..742577f480693671437dc50358a1a65d251b6e9b --- /dev/null +++ b/separator/models/scnet_unofficial/modules/sd_encoder.py @@ -0,0 +1,285 @@ +from typing import List, Tuple + +import torch +import torch.nn as nn + +from models.scnet_unofficial.utils import create_intervals + + +class Downsample(nn.Module): + """ + Downsample class implements a module for downsampling input tensors using 2D convolution. + + Args: + - input_dim (int): Dimensionality of the input channels. + - output_dim (int): Dimensionality of the output channels. + - stride (int): Stride value for the convolution operation. + + Shapes: + - Input: (B, C_in, F, T) where + B is batch size, + C_in is the number of input channels, + F is the frequency dimension, + T is the time dimension. + - Output: (B, C_out, F // stride, T) where + B is batch size, + C_out is the number of output channels, + F // stride is the downsampled frequency dimension. 
+ + """ + + def __init__( + self, + input_dim: int, + output_dim: int, + stride: int, + ): + """ + Initializes Downsample with input dimension, output dimension, and stride. + """ + super().__init__() + self.conv = nn.Conv2d(input_dim, output_dim, 1, (stride, 1)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the Downsample module. + + Args: + - x (torch.Tensor): Input tensor of shape (B, C_in, F, T). + + Returns: + - torch.Tensor: Downsampled tensor of shape (B, C_out, F // stride, T). + """ + return self.conv(x) + + +class ConvolutionModule(nn.Module): + """ + ConvolutionModule class implements a module with a sequence of convolutional layers similar to Conformer. + + Args: + - input_dim (int): Dimensionality of the input features. + - hidden_dim (int): Dimensionality of the hidden features. + - kernel_sizes (List[int]): List of kernel sizes for the convolutional layers. + - bias (bool, optional): If True, adds a learnable bias to the output. Default is False. + + Shapes: + - Input: (B, T, D) where + B is batch size, + T is sequence length, + D is input dimensionality. + - Output: (B, T, D) where + B is batch size, + T is sequence length, + D is input dimensionality. + """ + + def __init__( + self, + input_dim: int, + hidden_dim: int, + kernel_sizes: List[int], + bias: bool = False, + ) -> None: + """ + Initializes ConvolutionModule with input dimension, hidden dimension, kernel sizes, and bias. + """ + super().__init__() + self.sequential = nn.Sequential( + nn.GroupNorm(num_groups=1, num_channels=input_dim), + nn.Conv1d( + input_dim, + 2 * hidden_dim, + kernel_sizes[0], + stride=1, + padding=(kernel_sizes[0] - 1) // 2, + bias=bias, + ), + nn.GLU(dim=1), + nn.Conv1d( + hidden_dim, + hidden_dim, + kernel_sizes[1], + stride=1, + padding=(kernel_sizes[1] - 1) // 2, + groups=hidden_dim, + bias=bias, + ), + nn.GroupNorm(num_groups=1, num_channels=hidden_dim), + nn.SiLU(), + nn.Conv1d( + hidden_dim, + input_dim, + kernel_sizes[2], + stride=1, + padding=(kernel_sizes[2] - 1) // 2, + bias=bias, + ), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the ConvolutionModule. + + Args: + - x (torch.Tensor): Input tensor of shape (B, T, D). + + Returns: + - torch.Tensor: Output tensor of shape (B, T, D). + """ + x = x.transpose(1, 2) + x = x + self.sequential(x) + x = x.transpose(1, 2) + return x + + +class SDLayer(nn.Module): + """ + SDLayer class implements a subband decomposition layer with downsampling and convolutional modules. + + Args: + - subband_interval (Tuple[float, float]): Tuple representing the frequency interval for subband decomposition. + - input_dim (int): Dimensionality of the input channels. + - output_dim (int): Dimensionality of the output channels after downsampling. + - downsample_stride (int): Stride value for the downsampling operation. + - n_conv_modules (int): Number of convolutional modules. + - kernel_sizes (List[int]): List of kernel sizes for the convolutional layers. + - bias (bool, optional): If True, adds a learnable bias to the convolutional layers. Default is True. + + Shapes: + - Input: (B, Fi, T, Ci) where + B is batch size, + Fi is the number of input subbands, + T is sequence length, and + Ci is the number of input channels. + - Output: (B, Fi+1, T, Ci+1) where + B is batch size, + Fi+1 is the number of output subbands, + T is sequence length, + Ci+1 is the number of output channels. 
+ """ + + def __init__( + self, + subband_interval: Tuple[float, float], + input_dim: int, + output_dim: int, + downsample_stride: int, + n_conv_modules: int, + kernel_sizes: List[int], + bias: bool = True, + ): + """ + Initializes SDLayer with subband interval, input dimension, + output dimension, downsample stride, number of convolutional modules, kernel sizes, and bias. + """ + super().__init__() + self.subband_interval = subband_interval + self.downsample = Downsample(input_dim, output_dim, downsample_stride) + self.activation = nn.GELU() + conv_modules = [ + ConvolutionModule( + input_dim=output_dim, + hidden_dim=output_dim // 4, + kernel_sizes=kernel_sizes, + bias=bias, + ) + for _ in range(n_conv_modules) + ] + self.conv_modules = nn.Sequential(*conv_modules) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the SDLayer. + + Args: + - x (torch.Tensor): Input tensor of shape (B, Fi, T, Ci). + + Returns: + - torch.Tensor: Output tensor of shape (B, Fi+1, T, Ci+1). + """ + B, F, T, C = x.shape + x = x[:, int(self.subband_interval[0] * F) : int(self.subband_interval[1] * F)] + x = x.permute(0, 3, 1, 2) + x = self.downsample(x) + x = self.activation(x) + x = x.permute(0, 2, 3, 1) + + B, F, T, C = x.shape + x = x.reshape((B * F), T, C) + x = self.conv_modules(x) + x = x.reshape(B, F, T, C) + + return x + + +class SDBlock(nn.Module): + """ + SDBlock class implements a block with subband decomposition layers and global convolution. + + Args: + - input_dim (int): Dimensionality of the input channels. + - output_dim (int): Dimensionality of the output channels. + - bandsplit_ratios (List[float]): List of ratios for splitting the frequency bands. + - downsample_strides (List[int]): List of stride values for downsampling in each subband layer. + - n_conv_modules (List[int]): List specifying the number of convolutional modules in each subband layer. + - kernel_sizes (List[int], optional): List of kernel sizes for the convolutional layers. Default is None. + + Shapes: + - Input: (B, Fi, T, Ci) where + B is batch size, + Fi is the number of input subbands, + T is sequence length, + Ci is the number of input channels. + - Output: (B, Fi+1, T, Ci+1) where + B is batch size, + Fi+1 is the number of output subbands, + T is sequence length, + Ci+1 is the number of output channels. + """ + + def __init__( + self, + input_dim: int, + output_dim: int, + bandsplit_ratios: List[float], + downsample_strides: List[int], + n_conv_modules: List[int], + kernel_sizes: List[int] = None, + ): + """ + Initializes SDBlock with input dimension, output dimension, band split ratios, downsample strides, number of convolutional modules, and kernel sizes. + """ + super().__init__() + if kernel_sizes is None: + kernel_sizes = [3, 3, 1] + assert sum(bandsplit_ratios) == 1, "The split ratios must sum up to 1." + subband_intervals = create_intervals(bandsplit_ratios) + self.sd_layers = nn.ModuleList( + SDLayer( + input_dim=input_dim, + output_dim=output_dim, + subband_interval=sbi, + downsample_stride=dss, + n_conv_modules=ncm, + kernel_sizes=kernel_sizes, + ) + for sbi, dss, ncm in zip( + subband_intervals, downsample_strides, n_conv_modules + ) + ) + self.global_conv2d = nn.Conv2d(output_dim, output_dim, 1, 1) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Performs forward pass through the SDBlock. + + Args: + - x (torch.Tensor): Input tensor of shape (B, Fi, T, Ci). 
+ + Returns: + - Tuple[torch.Tensor, torch.Tensor]: Output tensor and skip connection tensor. + """ + x_skip = torch.concat([layer(x) for layer in self.sd_layers], dim=1) + x = self.global_conv2d(x_skip.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + return x, x_skip diff --git a/separator/models/scnet_unofficial/modules/su_decoder.py b/separator/models/scnet_unofficial/modules/su_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..660c1fa6cbfd9b43bed73204a0bb6593524de272 --- /dev/null +++ b/separator/models/scnet_unofficial/modules/su_decoder.py @@ -0,0 +1,241 @@ +from typing import List, Tuple + +import torch +import torch.nn as nn + +from models.scnet_unofficial.utils import get_convtranspose_output_padding + + +class FusionLayer(nn.Module): + """ + FusionLayer class implements a module for fusing two input tensors using convolutional operations. + + Args: + - input_dim (int): Dimensionality of the input channels. + - kernel_size (int, optional): Kernel size for the convolutional layer. Default is 3. + - stride (int, optional): Stride value for the convolutional layer. Default is 1. + - padding (int, optional): Padding value for the convolutional layer. Default is 1. + + Shapes: + - Input: (B, F, T, C) and (B, F, T, C) where + B is batch size, + F is the number of features, + T is sequence length, + C is input dimensionality. + - Output: (B, F, T, C) where + B is batch size, + F is the number of features, + T is sequence length, + C is input dimensionality. + """ + + def __init__( + self, input_dim: int, kernel_size: int = 3, stride: int = 1, padding: int = 1 + ): + """ + Initializes FusionLayer with input dimension, kernel size, stride, and padding. + """ + super().__init__() + self.conv = nn.Conv2d( + input_dim * 2, + input_dim * 2, + kernel_size=(kernel_size, 1), + stride=(stride, 1), + padding=(padding, 0), + ) + self.activation = nn.GLU() + + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the FusionLayer. + + Args: + - x1 (torch.Tensor): First input tensor of shape (B, F, T, C). + - x2 (torch.Tensor): Second input tensor of shape (B, F, T, C). + + Returns: + - torch.Tensor: Output tensor of shape (B, F, T, C). + """ + x = x1 + x2 + x = x.repeat(1, 1, 1, 2) + x = self.conv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + x = self.activation(x) + return x + + +class Upsample(nn.Module): + """ + Upsample class implements a module for upsampling input tensors using transposed 2D convolution. + + Args: + - input_dim (int): Dimensionality of the input channels. + - output_dim (int): Dimensionality of the output channels. + - stride (int): Stride value for the transposed convolution operation. + - output_padding (int): Output padding value for the transposed convolution operation. + + Shapes: + - Input: (B, C_in, F, T) where + B is batch size, + C_in is the number of input channels, + F is the frequency dimension, + T is the time dimension. + - Output: (B, C_out, F * stride + output_padding, T) where + B is batch size, + C_out is the number of output channels, + F * stride + output_padding is the upsampled frequency dimension. + """ + + def __init__( + self, input_dim: int, output_dim: int, stride: int, output_padding: int + ): + """ + Initializes Upsample with input dimension, output dimension, stride, and output padding. 
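+
+        Example (illustrative; output_padding must stay below the stride, and the
+        upsampled frequency size becomes (F - 1) * stride + 1 + output_padding):
+
+            up = Upsample(input_dim=8, output_dim=4, stride=2, output_padding=1)
+            x = torch.randn(1, 8, 16, 100)   # (B, C_in, F, T)
+            y = up(x)                        # -> (1, 4, 32, 100)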
+ """ + super().__init__() + self.conv = nn.ConvTranspose2d( + input_dim, output_dim, 1, (stride, 1), output_padding=(output_padding, 0) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the Upsample module. + + Args: + - x (torch.Tensor): Input tensor of shape (B, C_in, F, T). + + Returns: + - torch.Tensor: Output tensor of shape (B, C_out, F * stride + output_padding, T). + """ + return self.conv(x) + + +class SULayer(nn.Module): + """ + SULayer class implements a subband upsampling layer using transposed convolution. + + Args: + - input_dim (int): Dimensionality of the input channels. + - output_dim (int): Dimensionality of the output channels. + - upsample_stride (int): Stride value for the upsampling operation. + - subband_shape (int): Shape of the subband. + - sd_interval (Tuple[int, int]): Start and end indices of the subband interval. + + Shapes: + - Input: (B, F, T, C) where + B is batch size, + F is the number of features, + T is sequence length, + C is input dimensionality. + - Output: (B, F, T, C) where + B is batch size, + F is the number of features, + T is sequence length, + C is input dimensionality. + """ + + def __init__( + self, + input_dim: int, + output_dim: int, + upsample_stride: int, + subband_shape: int, + sd_interval: Tuple[int, int], + ): + """ + Initializes SULayer with input dimension, output dimension, upsample stride, subband shape, and subband interval. + """ + super().__init__() + sd_shape = sd_interval[1] - sd_interval[0] + upsample_output_padding = get_convtranspose_output_padding( + input_shape=sd_shape, output_shape=subband_shape, stride=upsample_stride + ) + self.upsample = Upsample( + input_dim=input_dim, + output_dim=output_dim, + stride=upsample_stride, + output_padding=upsample_output_padding, + ) + self.sd_interval = sd_interval + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the SULayer. + + Args: + - x (torch.Tensor): Input tensor of shape (B, F, T, C). + + Returns: + - torch.Tensor: Output tensor of shape (B, F, T, C). + """ + x = x[:, self.sd_interval[0] : self.sd_interval[1]] + x = x.permute(0, 3, 1, 2) + x = self.upsample(x) + x = x.permute(0, 2, 3, 1) + return x + + +class SUBlock(nn.Module): + """ + SUBlock class implements a block with fusion layer and subband upsampling layers. + + Args: + - input_dim (int): Dimensionality of the input channels. + - output_dim (int): Dimensionality of the output channels. + - upsample_strides (List[int]): List of stride values for the upsampling operations. + - subband_shapes (List[int]): List of shapes for the subbands. + - sd_intervals (List[Tuple[int, int]]): List of intervals for subband decomposition. + + Shapes: + - Input: (B, Fi-1, T, Ci-1) and (B, Fi-1, T, Ci-1) where + B is batch size, + Fi-1 is the number of input subbands, + T is sequence length, + Ci-1 is the number of input channels. + - Output: (B, Fi, T, Ci) where + B is batch size, + Fi is the number of output subbands, + T is sequence length, + Ci is the number of output channels. + """ + + def __init__( + self, + input_dim: int, + output_dim: int, + upsample_strides: List[int], + subband_shapes: List[int], + sd_intervals: List[Tuple[int, int]], + ): + """ + Initializes SUBlock with input dimension, output dimension, + upsample strides, subband shapes, and subband intervals. 
+ """ + super().__init__() + self.fusion_layer = FusionLayer(input_dim=input_dim) + self.su_layers = nn.ModuleList( + SULayer( + input_dim=input_dim, + output_dim=output_dim, + upsample_stride=uss, + subband_shape=sbs, + sd_interval=sdi, + ) + for i, (uss, sbs, sdi) in enumerate( + zip(upsample_strides, subband_shapes, sd_intervals) + ) + ) + + def forward(self, x: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the SUBlock. + + Args: + - x (torch.Tensor): Input tensor of shape (B, Fi-1, T, Ci-1). + - x_skip (torch.Tensor): Input skip connection tensor of shape (B, Fi-1, T, Ci-1). + + Returns: + - torch.Tensor: Output tensor of shape (B, Fi, T, Ci). + """ + x = self.fusion_layer(x, x_skip) + x = torch.concat([layer(x) for layer in self.su_layers], dim=1) + return x diff --git a/separator/models/scnet_unofficial/scnet.py b/separator/models/scnet_unofficial/scnet.py new file mode 100644 index 0000000000000000000000000000000000000000..d076f85f1d5ce1345dc9a8c56b6a5aef09f2facc --- /dev/null +++ b/separator/models/scnet_unofficial/scnet.py @@ -0,0 +1,249 @@ +''' +SCNet - great paper, great implementation +https://arxiv.org/pdf/2401.13276.pdf +https://github.com/amanteur/SCNet-PyTorch +''' + +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio + +from models.scnet_unofficial.modules import DualPathRNN, SDBlock, SUBlock +from models.scnet_unofficial.utils import compute_sd_layer_shapes, compute_gcr + +from einops import rearrange, pack, unpack +from functools import partial + +from beartype.typing import Tuple, Optional, List, Callable +from beartype import beartype + +def exists(val): + return val is not None + + +def default(v, d): + return v if exists(v) else d + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +class RMSNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return F.normalize(x, dim=-1) * self.scale * self.gamma + + +class BandSplit(nn.Module): + @beartype + def __init__( + self, + dim, + dim_inputs: Tuple[int, ...] + ): + super().__init__() + self.dim_inputs = dim_inputs + self.to_features = ModuleList([]) + + for dim_in in dim_inputs: + net = nn.Sequential( + RMSNorm(dim_in), + nn.Linear(dim_in, dim) + ) + + self.to_features.append(net) + + def forward(self, x): + x = x.split(self.dim_inputs, dim=-1) + + outs = [] + for split_input, to_feature in zip(x, self.to_features): + split_output = to_feature(split_input) + outs.append(split_output) + + return torch.stack(outs, dim=-2) + + +class SCNet(nn.Module): + """ + SCNet class implements a source separation network, + which explicitly split the spectrogram of the mixture into several subbands + and introduce a sparsity-based encoder to model different frequency bands. + + Paper: "SCNET: SPARSE COMPRESSION NETWORK FOR MUSIC SOURCE SEPARATION" + Authors: Weinan Tong, Jiaxu Zhu et al. + Link: https://arxiv.org/abs/2401.13276.pdf + + Args: + - n_fft (int): Number of FFTs to determine the frequency dimension of the input. + - dims (List[int]): List of channel dimensions for each block. + - bandsplit_ratios (List[float]): List of ratios for splitting the frequency bands. + - downsample_strides (List[int]): List of stride values for downsampling in each block. 
+ - n_conv_modules (List[int]): List specifying the number of convolutional modules in each block. + - n_rnn_layers (int): Number of recurrent layers in the dual path RNN. + - rnn_hidden_dim (int): Dimensionality of the hidden state in the dual path RNN. + - n_sources (int, optional): Number of sources to be separated. Default is 4. + + Shapes: + - Input: (B, C, T) where + B is batch size, + C is channel dim (mono / stereo), + T is time dim + - Output: (B, N, C, T) where + B is batch size, + N is the number of sources. + C is channel dim (mono / stereo), + T is sequence length, + """ + @beartype + def __init__( + self, + n_fft: int, + dims: List[int], + bandsplit_ratios: List[float], + downsample_strides: List[int], + n_conv_modules: List[int], + n_rnn_layers: int, + rnn_hidden_dim: int, + n_sources: int = 4, + hop_length: int = 1024, + win_length: int = 4096, + stft_window_fn: Optional[Callable] = None, + stft_normalized: bool = False, + **kwargs + ): + """ + Initializes SCNet with input parameters. + """ + super().__init__() + self.assert_input_data( + bandsplit_ratios, + downsample_strides, + n_conv_modules, + ) + + n_blocks = len(dims) - 1 + n_freq_bins = n_fft // 2 + 1 + subband_shapes, sd_intervals = compute_sd_layer_shapes( + input_shape=n_freq_bins, + bandsplit_ratios=bandsplit_ratios, + downsample_strides=downsample_strides, + n_layers=n_blocks, + ) + self.sd_blocks = nn.ModuleList( + SDBlock( + input_dim=dims[i], + output_dim=dims[i + 1], + bandsplit_ratios=bandsplit_ratios, + downsample_strides=downsample_strides, + n_conv_modules=n_conv_modules, + ) + for i in range(n_blocks) + ) + self.dualpath_blocks = DualPathRNN( + n_layers=n_rnn_layers, + input_dim=dims[-1], + hidden_dim=rnn_hidden_dim, + **kwargs + ) + self.su_blocks = nn.ModuleList( + SUBlock( + input_dim=dims[i + 1], + output_dim=dims[i] if i != 0 else dims[i] * n_sources, + subband_shapes=subband_shapes[i], + sd_intervals=sd_intervals[i], + upsample_strides=downsample_strides, + ) + for i in reversed(range(n_blocks)) + ) + self.gcr = compute_gcr(subband_shapes) + + self.stft_kwargs = dict( + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + normalized=stft_normalized + ) + + self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), win_length) + self.n_sources = n_sources + self.hop_length = hop_length + + @staticmethod + def assert_input_data(*args): + """ + Asserts that the shapes of input features are equal. + """ + for arg1 in args: + for arg2 in args: + if len(arg1) != len(arg2): + raise ValueError( + f"Shapes of input features {arg1} and {arg2} are not equal." + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Performs forward pass through the SCNet. + + Args: + - x (torch.Tensor): Input tensor of shape (B, C, T). + + Returns: + - torch.Tensor: Output tensor of shape (B, N, C, T). 
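+
+        Example (a minimal end-to-end sketch; the hyperparameters follow the
+        SCNet paper's configuration and are illustrative only, as is the
+        expected output shape):
+
+            model = SCNet(n_fft=4096, hop_length=1024, win_length=4096,
+                          dims=[4, 32, 64, 128],
+                          bandsplit_ratios=[0.175, 0.392, 0.433],
+                          downsample_strides=[1, 4, 16],
+                          n_conv_modules=[3, 2, 1],
+                          n_rnn_layers=6, rnn_hidden_dim=128, n_sources=4)
+            mix = torch.randn(1, 2, 44100)    # (B, C, T): one second of stereo
+            stems = model(mix)                # expected: (1, 4, 2, 44100)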
+ """ + + device = x.device + stft_window = self.stft_window_fn(device=device) + + if x.ndim == 2: + x = rearrange(x, 'b t -> b 1 t') + + c = x.shape[1] + + stft_pad = self.hop_length - x.shape[-1] % self.hop_length + x = F.pad(x, (0, stft_pad)) + + # stft + x, ps = pack_one(x, '* t') + x = torch.stft(x, **self.stft_kwargs, window=stft_window, return_complex=True) + x = torch.view_as_real(x) + x = unpack_one(x, ps, '* c f t') + x = rearrange(x, 'b c f t r -> b f t (c r)') + + # encoder part + x_skips = [] + for sd_block in self.sd_blocks: + x, x_skip = sd_block(x) + x_skips.append(x_skip) + + # separation part + x = self.dualpath_blocks(x) + + # decoder part + for su_block, x_skip in zip(self.su_blocks, reversed(x_skips)): + x = su_block(x, x_skip) + + # istft + x = rearrange(x, 'b f t (c r n) -> b n c f t r', c=c, n=self.n_sources, r=2) + x = x.contiguous() + + x = torch.view_as_complex(x) + x = rearrange(x, 'b n c f t -> (b n c) f t') + x = torch.istft(x, **self.stft_kwargs, window=stft_window, return_complex=False) + x = rearrange(x, '(b n c) t -> b n c t', c=c, n=self.n_sources) + + x = x[..., :-stft_pad] + + return x diff --git a/separator/models/scnet_unofficial/utils.py b/separator/models/scnet_unofficial/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..aae1afcd52e8088926ea984e52c9b62ca68be65c --- /dev/null +++ b/separator/models/scnet_unofficial/utils.py @@ -0,0 +1,135 @@ +''' +SCNet - great paper, great implementation +https://arxiv.org/pdf/2401.13276.pdf +https://github.com/amanteur/SCNet-PyTorch +''' + +from typing import List, Tuple, Union + +import torch + + +def create_intervals( + splits: List[Union[float, int]] +) -> List[Union[Tuple[float, float], Tuple[int, int]]]: + """ + Create intervals based on splits provided. + + Args: + - splits (List[Union[float, int]]): List of floats or integers representing splits. + + Returns: + - List[Union[Tuple[float, float], Tuple[int, int]]]: List of tuples representing intervals. + """ + start = 0 + return [(start, start := start + split) for split in splits] + + +def get_conv_output_shape( + input_shape: int, + kernel_size: int = 1, + padding: int = 0, + dilation: int = 1, + stride: int = 1, +) -> int: + """ + Compute the output shape of a convolutional layer. + + Args: + - input_shape (int): Input shape. + - kernel_size (int, optional): Kernel size of the convolution. Default is 1. + - padding (int, optional): Padding size. Default is 0. + - dilation (int, optional): Dilation factor. Default is 1. + - stride (int, optional): Stride value. Default is 1. + + Returns: + - int: Output shape. + """ + return int( + (input_shape + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1 + ) + + +def get_convtranspose_output_padding( + input_shape: int, + output_shape: int, + kernel_size: int = 1, + padding: int = 0, + dilation: int = 1, + stride: int = 1, +) -> int: + """ + Compute the output padding for a convolution transpose operation. + + Args: + - input_shape (int): Input shape. + - output_shape (int): Desired output shape. + - kernel_size (int, optional): Kernel size of the convolution. Default is 1. + - padding (int, optional): Padding size. Default is 0. + - dilation (int, optional): Dilation factor. Default is 1. + - stride (int, optional): Stride value. Default is 1. + + Returns: + - int: Output padding. 
+ """ + return ( + output_shape + - (input_shape - 1) * stride + + 2 * padding + - dilation * (kernel_size - 1) + - 1 + ) + + +def compute_sd_layer_shapes( + input_shape: int, + bandsplit_ratios: List[float], + downsample_strides: List[int], + n_layers: int, +) -> Tuple[List[List[int]], List[List[Tuple[int, int]]]]: + """ + Compute the shapes for the subband layers. + + Args: + - input_shape (int): Input shape. + - bandsplit_ratios (List[float]): Ratios for splitting the frequency bands. + - downsample_strides (List[int]): Strides for downsampling in each layer. + - n_layers (int): Number of layers. + + Returns: + - Tuple[List[List[int]], List[List[Tuple[int, int]]]]: Tuple containing subband shapes and convolution shapes. + """ + bandsplit_shapes_list = [] + conv2d_shapes_list = [] + for _ in range(n_layers): + bandsplit_intervals = create_intervals(bandsplit_ratios) + bandsplit_shapes = [ + int(right * input_shape) - int(left * input_shape) + for left, right in bandsplit_intervals + ] + conv2d_shapes = [ + get_conv_output_shape(bs, stride=ds) + for bs, ds in zip(bandsplit_shapes, downsample_strides) + ] + input_shape = sum(conv2d_shapes) + bandsplit_shapes_list.append(bandsplit_shapes) + conv2d_shapes_list.append(create_intervals(conv2d_shapes)) + + return bandsplit_shapes_list, conv2d_shapes_list + + +def compute_gcr(subband_shapes: List[List[int]]) -> float: + """ + Compute the global compression ratio. + + Args: + - subband_shapes (List[List[int]]): List of subband shapes. + + Returns: + - float: Global compression ratio. + """ + t = torch.Tensor(subband_shapes) + gcr = torch.stack( + [(1 - t[i + 1] / t[i]).mean() for i in range(0, len(t) - 1)] + ).mean() + return float(gcr) \ No newline at end of file diff --git a/separator/models/segm_models.py b/separator/models/segm_models.py new file mode 100644 index 0000000000000000000000000000000000000000..1fdf11468935b74fdb825c1c422faeaef60f4339 --- /dev/null +++ b/separator/models/segm_models.py @@ -0,0 +1,255 @@ +import torch +import torch.nn as nn +import segmentation_models_pytorch as smp +from msst_utils import prefer_target_instrument + +class STFT: + def __init__(self, config): + self.n_fft = config.n_fft + self.hop_length = config.hop_length + self.window = torch.hann_window(window_length=self.n_fft, periodic=True) + self.dim_f = config.dim_f + + def __call__(self, x): + window = self.window.to(x.device) + batch_dims = x.shape[:-2] + c, t = x.shape[-2:] + x = x.reshape([-1, t]) + x = torch.stft( + x, + n_fft=self.n_fft, + hop_length=self.hop_length, + window=window, + center=True, + return_complex=True + ) + x = torch.view_as_real(x) + x = x.permute([0, 3, 1, 2]) + x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]]) + return x[..., :self.dim_f, :] + + def inverse(self, x): + window = self.window.to(x.device) + batch_dims = x.shape[:-3] + c, f, t = x.shape[-3:] + n = self.n_fft // 2 + 1 + f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device) + x = torch.cat([x, f_pad], -2) + x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t]) + x = x.permute([0, 2, 3, 1]) + x = x[..., 0] + x[..., 1] * 1.j + x = torch.istft( + x, + n_fft=self.n_fft, + hop_length=self.hop_length, + window=window, + center=True + ) + x = x.reshape([*batch_dims, 2, -1]) + return x + + +def get_act(act_type): + if act_type == 'gelu': + return nn.GELU() + elif act_type == 'relu': + return nn.ReLU() + elif act_type[:3] == 'elu': + alpha = float(act_type.replace('elu', '')) + return nn.ELU(alpha) + else: + 
raise Exception + + +def get_decoder(config, c): + decoder = None + decoder_options = dict() + if config.model.decoder_type == 'unet': + try: + decoder_options = dict(config.decoder_unet) + except: + pass + decoder = smp.Unet( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'fpn': + try: + decoder_options = dict(config.decoder_fpn) + except: + pass + decoder = smp.FPN( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'unet++': + try: + decoder_options = dict(config.decoder_unet_plus_plus) + except: + pass + decoder = smp.UnetPlusPlus( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'manet': + try: + decoder_options = dict(config.decoder_manet) + except: + pass + decoder = smp.MAnet( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'linknet': + try: + decoder_options = dict(config.decoder_linknet) + except: + pass + decoder = smp.Linknet( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'pspnet': + try: + decoder_options = dict(config.decoder_pspnet) + except: + pass + decoder = smp.PSPNet( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'pspnet': + try: + decoder_options = dict(config.decoder_pspnet) + except: + pass + decoder = smp.PSPNet( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'pan': + try: + decoder_options = dict(config.decoder_pan) + except: + pass + decoder = smp.PAN( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'deeplabv3': + try: + decoder_options = dict(config.decoder_deeplabv3) + except: + pass + decoder = smp.DeepLabV3( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'deeplabv3plus': + try: + decoder_options = dict(config.decoder_deeplabv3plus) + except: + pass + decoder = smp.DeepLabV3Plus( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + return decoder + + +class Segm_Models_Net(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + act = get_act(act_type=config.model.act) + + self.num_target_instruments = len(prefer_target_instrument(config)) + self.num_subbands = config.model.num_subbands + + dim_c = self.num_subbands * config.audio.num_channels * 2 + c = config.model.num_channels + f = config.audio.dim_f // self.num_subbands + + self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False) + + self.unet_model = get_decoder(config, c) + + self.final_conv = nn.Sequential( + nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False), + act, + nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False) + ) + + self.stft = STFT(config.audio) + + 
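+    # The two helpers below move sub-bands in and out of the channel axis:
+    # cac2cws packs the "complex-as-channels" spectrogram into channel-wise
+    # sub-bands, reshaping (B, C, F, T) to (B, C * k, F // k, T) with
+    # k = num_subbands, and cws2cac undoes it. For example, with k = 4 a
+    # (1, 4, 1024, 256) tensor becomes (1, 16, 256, 256) and back.
+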
def cac2cws(self, x): + k = self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c, k, f // k, t) + x = x.reshape(b, c * k, f // k, t) + return x + + def cws2cac(self, x): + k = self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c // k, k, f, t) + x = x.reshape(b, c // k, f * k, t) + return x + + def forward(self, x): + + x = self.stft(x) + + mix = x = self.cac2cws(x) + + first_conv_out = x = self.first_conv(x) + + x = x.transpose(-1, -2) + + x = self.unet_model(x) + + x = x.transpose(-1, -2) + + x = x * first_conv_out # reduce artifacts + + x = self.final_conv(torch.cat([mix, x], 1)) + + x = self.cws2cac(x) + + if self.num_target_instruments > 1: + b, c, f, t = x.shape + x = x.reshape(b, self.num_target_instruments, -1, f, t) + + x = self.stft.inverse(x) + return x diff --git a/separator/models/torchseg_models.py b/separator/models/torchseg_models.py new file mode 100644 index 0000000000000000000000000000000000000000..c054ce475ae122f1831b2a93c1764d6998b83b98 --- /dev/null +++ b/separator/models/torchseg_models.py @@ -0,0 +1,255 @@ +import torch +import torch.nn as nn +import torchseg as smp +from msst_utils import prefer_target_instrument + +class STFT: + def __init__(self, config): + self.n_fft = config.n_fft + self.hop_length = config.hop_length + self.window = torch.hann_window(window_length=self.n_fft, periodic=True) + self.dim_f = config.dim_f + + def __call__(self, x): + window = self.window.to(x.device) + batch_dims = x.shape[:-2] + c, t = x.shape[-2:] + x = x.reshape([-1, t]) + x = torch.stft( + x, + n_fft=self.n_fft, + hop_length=self.hop_length, + window=window, + center=True, + return_complex=True + ) + x = torch.view_as_real(x) + x = x.permute([0, 3, 1, 2]) + x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]]) + return x[..., :self.dim_f, :] + + def inverse(self, x): + window = self.window.to(x.device) + batch_dims = x.shape[:-3] + c, f, t = x.shape[-3:] + n = self.n_fft // 2 + 1 + f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device) + x = torch.cat([x, f_pad], -2) + x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t]) + x = x.permute([0, 2, 3, 1]) + x = x[..., 0] + x[..., 1] * 1.j + x = torch.istft( + x, + n_fft=self.n_fft, + hop_length=self.hop_length, + window=window, + center=True + ) + x = x.reshape([*batch_dims, 2, -1]) + return x + + +def get_act(act_type): + if act_type == 'gelu': + return nn.GELU() + elif act_type == 'relu': + return nn.ReLU() + elif act_type[:3] == 'elu': + alpha = float(act_type.replace('elu', '')) + return nn.ELU(alpha) + else: + raise Exception + + +def get_decoder(config, c): + decoder = None + decoder_options = dict() + if config.model.decoder_type == 'unet': + try: + decoder_options = dict(config.decoder_unet) + except: + pass + decoder = smp.Unet( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'fpn': + try: + decoder_options = dict(config.decoder_fpn) + except: + pass + decoder = smp.FPN( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'unet++': + try: + decoder_options = dict(config.decoder_unet_plus_plus) + except: + pass + decoder = smp.UnetPlusPlus( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'manet': + try: + 
decoder_options = dict(config.decoder_manet) + except: + pass + decoder = smp.MAnet( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'linknet': + try: + decoder_options = dict(config.decoder_linknet) + except: + pass + decoder = smp.Linknet( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'pspnet': + try: + decoder_options = dict(config.decoder_pspnet) + except: + pass + decoder = smp.PSPNet( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'pspnet': + try: + decoder_options = dict(config.decoder_pspnet) + except: + pass + decoder = smp.PSPNet( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'pan': + try: + decoder_options = dict(config.decoder_pan) + except: + pass + decoder = smp.PAN( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'deeplabv3': + try: + decoder_options = dict(config.decoder_deeplabv3) + except: + pass + decoder = smp.DeepLabV3( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + elif config.model.decoder_type == 'deeplabv3plus': + try: + decoder_options = dict(config.decoder_deeplabv3plus) + except: + pass + decoder = smp.DeepLabV3Plus( + encoder_name=config.model.encoder_name, + encoder_weights="imagenet", + in_channels=c, + classes=c, + **decoder_options, + ) + return decoder + + +class Torchseg_Net(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + act = get_act(act_type=config.model.act) + + self.num_target_instruments = len(prefer_target_instrument(config)) + self.num_subbands = config.model.num_subbands + + dim_c = self.num_subbands * config.audio.num_channels * 2 + c = config.model.num_channels + f = config.audio.dim_f // self.num_subbands + + self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False) + + self.unet_model = get_decoder(config, c) + + self.final_conv = nn.Sequential( + nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False), + act, + nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False) + ) + + self.stft = STFT(config.audio) + + def cac2cws(self, x): + k = self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c, k, f // k, t) + x = x.reshape(b, c * k, f // k, t) + return x + + def cws2cac(self, x): + k = self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c // k, k, f, t) + x = x.reshape(b, c // k, f * k, t) + return x + + def forward(self, x): + + x = self.stft(x) + + mix = x = self.cac2cws(x) + + first_conv_out = x = self.first_conv(x) + + x = x.transpose(-1, -2) + + x = self.unet_model(x) + + x = x.transpose(-1, -2) + + x = x * first_conv_out # reduce artifacts + + x = self.final_conv(torch.cat([mix, x], 1)) + + x = self.cws2cac(x) + + if self.num_target_instruments > 1: + b, c, f, t = x.shape + x = x.reshape(b, self.num_target_instruments, -1, f, t) + + x = self.stft.inverse(x) + return x diff --git a/separator/models/ts_bs_mamba2.py b/separator/models/ts_bs_mamba2.py new file mode 100644 index 
0000000000000000000000000000000000000000..a01a2a58c5418bfd846d6edb6900810392091189 --- /dev/null +++ b/separator/models/ts_bs_mamba2.py @@ -0,0 +1,319 @@ +# https://github.com/Human9000/nd-Mamba2-torch + +from __future__ import print_function + +import torch +import torch.nn as nn +import numpy as np +from torch.utils.checkpoint import checkpoint_sequential +try: + from mamba_ssm.modules.mamba2 import Mamba2 +except Exception as e: + print('Exception during load Mamba2 modules: {}'.format(str(e))) + print('Load local torch implementation!') + from .ex_bi_mamba2 import Mamba2 + + +class MambaBlock(nn.Module): + def __init__(self, in_channels): + super(MambaBlock, self).__init__() + self.forward_mamba2 = Mamba2( + d_model=in_channels, + d_state=128, + d_conv=4, + expand=4, + headdim=64, + ) + + self.backward_mamba2 = Mamba2( + d_model=in_channels, + d_state=128, + d_conv=4, + expand=4, + headdim=64, + ) + def forward(self, input): + forward_f = input + forward_f_output = self.forward_mamba2(forward_f) + backward_f = torch.flip(input, [1]) + backward_f_output = self.backward_mamba2(backward_f) + backward_f_output2 = torch.flip(backward_f_output, [1]) + output = torch.cat([forward_f_output + input, backward_f_output2+input], -1) + return output + +class TAC(nn.Module): + """ + A transform-average-concatenate (TAC) module. + """ + def __init__(self, input_size, hidden_size): + super(TAC, self).__init__() + + self.input_size = input_size + self.eps = torch.finfo(torch.float32).eps + + self.input_norm = nn.GroupNorm(1, input_size, self.eps) + self.TAC_input = nn.Sequential(nn.Linear(input_size, hidden_size), + nn.Tanh() + ) + self.TAC_mean = nn.Sequential(nn.Linear(hidden_size, hidden_size), + nn.Tanh() + ) + self.TAC_output = nn.Sequential(nn.Linear(hidden_size*2, input_size), + nn.Tanh() + ) + + def forward(self, input): + # input shape: batch, group, N, * + + batch_size, G, N = input.shape[:3] + output = self.input_norm(input.view(batch_size*G, N, -1)).view(batch_size, G, N, -1) + T = output.shape[-1] + + # transform + group_input = output # B, G, N, T + group_input = group_input.permute(0,3,1,2).contiguous().view(-1, N) # B*T*G, N + group_output = self.TAC_input(group_input).view(batch_size, T, G, -1) # B, T, G, H + + # mean pooling + group_mean = group_output.mean(2).view(batch_size*T, -1) # B*T, H + group_mean = self.TAC_mean(group_mean).unsqueeze(1).expand(batch_size*T, G, group_mean.shape[-1]).contiguous() # B*T, G, H + + # concate + group_output = group_output.view(batch_size*T, G, -1) # B*T, G, H + group_output = torch.cat([group_output, group_mean], 2) # B*T, G, 2H + group_output = self.TAC_output(group_output.view(-1, group_output.shape[-1])) # B*T*G, N + group_output = group_output.view(batch_size, T, G, -1).permute(0,2,3,1).contiguous() # B, G, N, T + output = input + group_output.view(input.shape) + + return output + +class ResMamba(nn.Module): + def __init__(self, input_size, hidden_size, dropout=0., bidirectional=True): + super(ResMamba, self).__init__() + + self.input_size = input_size + self.hidden_size = hidden_size + self.eps = torch.finfo(torch.float32).eps + + self.norm = nn.GroupNorm(1, input_size, self.eps) + self.dropout = nn.Dropout(p=dropout) + self.rnn = MambaBlock(input_size) + self.proj = nn.Linear(input_size*2 ,input_size) + # linear projection layer + + def forward(self, input): + # input shape: batch, dim, seq + rnn_output = self.rnn(self.dropout(self.norm(input)).transpose(1, 2).contiguous()) + rnn_output = self.proj(rnn_output.contiguous().view(-1, 
rnn_output.shape[2])).view(input.shape[0], + input.shape[2], + input.shape[1]) + + return input + rnn_output.transpose(1, 2).contiguous() + +class BSNet(nn.Module): + def __init__(self, in_channel, nband=7): + super(BSNet, self).__init__() + + self.nband = nband + self.feature_dim = in_channel // nband + + self.band_rnn = ResMamba(self.feature_dim, self.feature_dim*2) + self.band_comm = ResMamba(self.feature_dim, self.feature_dim*2) + self.channel_comm = TAC(self.feature_dim, self.feature_dim*3) + + def forward(self, input): + # input shape: B, nch, nband*N, T + B, nch, N, T = input.shape + + band_output = self.band_rnn(input.view(B*nch*self.nband, self.feature_dim, -1)).view(B*nch, self.nband, -1, T) + + # band comm + band_output = band_output.permute(0,3,2,1).contiguous().view(B*nch*T, -1, self.nband) + output = self.band_comm(band_output).view(B*nch, T, -1, self.nband).permute(0,3,2,1).contiguous() + + # channel comm + output = output.view(B, nch, self.nband, -1, T).transpose(1,2).contiguous().view(B*self.nband, nch, -1, T) + output = self.channel_comm(output).view(B, self.nband, nch, -1, T).transpose(1,2).contiguous() + + return output.view(B, nch, N, T) + +class Separator(nn.Module): + def __init__(self, sr=44100, win=2048, stride=512, feature_dim=128, num_repeat_mask=8, num_repeat_map=4, num_output=4): + super(Separator, self).__init__() + + self.sr = sr + self.win = win + self.stride = stride + self.group = self.win // 2 + self.enc_dim = self.win // 2 + 1 + self.feature_dim = feature_dim + self.num_output = num_output + self.eps = torch.finfo(torch.float32).eps + + # 0-1k (50 hop), 1k-2k (100 hop), 2k-4k (250 hop), 4k-8k (500 hop), 8k-16k (1k hop), 16k-20k (2k hop), 20k-inf + bandwidth_50 = int(np.floor(50 / (sr / 2.) * self.enc_dim)) + bandwidth_100 = int(np.floor(100 / (sr / 2.) * self.enc_dim)) + bandwidth_250 = int(np.floor(250 / (sr / 2.) * self.enc_dim)) + bandwidth_500 = int(np.floor(500 / (sr / 2.) * self.enc_dim)) + bandwidth_1k = int(np.floor(1000 / (sr / 2.) * self.enc_dim)) + bandwidth_2k = int(np.floor(2000 / (sr / 2.) 
* self.enc_dim)) + self.band_width = [bandwidth_50]*20 + self.band_width += [bandwidth_100]*10 + self.band_width += [bandwidth_250]*8 + self.band_width += [bandwidth_500]*8 + self.band_width += [bandwidth_1k]*8 + self.band_width += [bandwidth_2k]*2 + self.band_width.append(self.enc_dim - np.sum(self.band_width)) + self.nband = len(self.band_width) + print(self.band_width) + + self.BN_mask = nn.ModuleList([]) + for i in range(self.nband): + self.BN_mask.append(nn.Sequential(nn.GroupNorm(1, self.band_width[i]*2, self.eps), + nn.Conv1d(self.band_width[i]*2, self.feature_dim, 1) + ) + ) + + self.BN_map = nn.ModuleList([]) + for i in range(self.nband): + self.BN_map.append(nn.Sequential(nn.GroupNorm(1, self.band_width[i] * 2, self.eps), + nn.Conv1d(self.band_width[i] * 2, self.feature_dim, 1) + ) + ) + + self.separator_mask = [] + for i in range(num_repeat_mask): + self.separator_mask.append(BSNet(self.nband*self.feature_dim, self.nband)) + self.separator_mask = nn.Sequential(*self.separator_mask) + + self.separator_map = [] + for i in range(num_repeat_map): + self.separator_map.append(BSNet(self.nband * self.feature_dim, self.nband)) + self.separator_map = nn.Sequential(*self.separator_map) + + self.in_conv = nn.Conv1d(self.feature_dim*2, self.feature_dim, 1) + self.Tanh = nn.Tanh() + self.mask = nn.ModuleList([]) + self.map = nn.ModuleList([]) + for i in range(self.nband): + self.mask.append(nn.Sequential(nn.GroupNorm(1, self.feature_dim, torch.finfo(torch.float32).eps), + nn.Conv1d(self.feature_dim, self.feature_dim*1*self.num_output, 1), + nn.Tanh(), + nn.Conv1d(self.feature_dim*1*self.num_output, self.feature_dim*1*self.num_output, 1, groups=self.num_output), + nn.Tanh(), + nn.Conv1d(self.feature_dim*1*self.num_output, self.band_width[i]*4*self.num_output, 1, groups=self.num_output) + ) + ) + self.map.append(nn.Sequential(nn.GroupNorm(1, self.feature_dim, torch.finfo(torch.float32).eps), + nn.Conv1d(self.feature_dim, self.feature_dim*1*self.num_output, 1), + nn.Tanh(), + nn.Conv1d(self.feature_dim*1*self.num_output, self.feature_dim*1*self.num_output, 1, groups=self.num_output), + nn.Tanh(), + nn.Conv1d(self.feature_dim*1*self.num_output, self.band_width[i]*4*self.num_output, 1, groups=self.num_output) + ) + ) + + def pad_input(self, input, window, stride): + """ + Zero-padding input according to window/stride size. 
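+
+        Example (illustrative arithmetic for the default window/stride settings):
+
+            # window=2048, stride=512, nsample=44100:
+            # rest = 2048 - (512 + 44100 % 2048) % 2048 = 444
+            # the padded signal is 44100 + 444 + 2 * 512 = 45568 samples long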
+ """ + batch_size, nsample = input.shape + + # pad the signals at the end for matching the window/stride size + rest = window - (stride + nsample % window) % window + if rest > 0: + pad = torch.zeros(batch_size, rest).type(input.type()) + input = torch.cat([input, pad], 1) + pad_aux = torch.zeros(batch_size, stride).type(input.type()) + input = torch.cat([pad_aux, input, pad_aux], 1) + + return input, rest + + def forward(self, input): + # input shape: (B, C, T) + + batch_size, nch, nsample = input.shape + input = input.view(batch_size*nch, -1) + + # frequency-domain separation + spec = torch.stft(input, n_fft=self.win, hop_length=self.stride, + window=torch.hann_window(self.win).to(input.device).type(input.type()), + return_complex=True) + + # concat real and imag, split to subbands + spec_RI = torch.stack([spec.real, spec.imag], 1) # B*nch, 2, F, T + subband_spec_RI = [] + subband_spec = [] + band_idx = 0 + for i in range(len(self.band_width)): + subband_spec_RI.append(spec_RI[:,:,band_idx:band_idx+self.band_width[i]].contiguous()) + subband_spec.append(spec[:,band_idx:band_idx+self.band_width[i]]) # B*nch, BW, T + band_idx += self.band_width[i] + + # normalization and bottleneck + subband_feature_mask = [] + for i in range(len(self.band_width)): + subband_feature_mask.append(self.BN_mask[i](subband_spec_RI[i].view(batch_size*nch, self.band_width[i]*2, -1))) + subband_feature_mask = torch.stack(subband_feature_mask, 1) # B, nband, N, T + + subband_feature_map = [] + for i in range(len(self.band_width)): + subband_feature_map.append(self.BN_map[i](subband_spec_RI[i].view(batch_size * nch, self.band_width[i] * 2, -1))) + subband_feature_map = torch.stack(subband_feature_map, 1) # B, nband, N, T + # separator + sep_output = checkpoint_sequential(self.separator_mask, 2, subband_feature_mask.view(batch_size, nch, self.nband*self.feature_dim, -1)) # B, nband*N, T + sep_output = sep_output.view(batch_size*nch, self.nband, self.feature_dim, -1) + combined = torch.cat((subband_feature_map,sep_output), dim=2) + combined1 = combined.reshape(batch_size * nch * self.nband,self.feature_dim*2,-1) + combined2 = self.Tanh(self.in_conv(combined1)) + combined3 = combined2.reshape(batch_size * nch, self.nband,self.feature_dim,-1) + sep_output2 = checkpoint_sequential(self.separator_map, 2, combined3.view(batch_size, nch, self.nband*self.feature_dim, -1)) # 1B, nband*N, T + sep_output2 = sep_output2.view(batch_size * nch, self.nband, self.feature_dim, -1) + + sep_subband_spec = [] + sep_subband_spec_mask = [] + for i in range(self.nband): + this_output = self.mask[i](sep_output[:,i]).view(batch_size*nch, 2, 2, self.num_output, self.band_width[i], -1) + this_mask = this_output[:,0] * torch.sigmoid(this_output[:,1]) # B*nch, 2, K, BW, T + this_mask_real = this_mask[:,0] # B*nch, K, BW, T + this_mask_imag = this_mask[:,1] # B*nch, K, BW, T + # force mask sum to 1 + this_mask_real_sum = this_mask_real.sum(1).unsqueeze(1) # B*nch, 1, BW, T + this_mask_imag_sum = this_mask_imag.sum(1).unsqueeze(1) # B*nch, 1, BW, T + this_mask_real = this_mask_real - (this_mask_real_sum - 1) / self.num_output + this_mask_imag = this_mask_imag - this_mask_imag_sum / self.num_output + est_spec_real = subband_spec[i].real.unsqueeze(1) * this_mask_real - subband_spec[i].imag.unsqueeze(1) * this_mask_imag # B*nch, K, BW, T + est_spec_imag = subband_spec[i].real.unsqueeze(1) * this_mask_imag + subband_spec[i].imag.unsqueeze(1) * this_mask_real # B*nch, K, BW, T + + ################################## + this_output2 = 
self.map[i](sep_output2[:,i]).view(batch_size*nch, 2, 2, self.num_output, self.band_width[i], -1) + this_map = this_output2[:,0] * torch.sigmoid(this_output2[:,1]) # B*nch, 2, K, BW, T + this_map_real = this_map[:,0] # B*nch, K, BW, T + this_map_imag = this_map[:,1] # B*nch, K, BW, T + est_spec_real2 = est_spec_real+this_map_real + est_spec_imag2 = est_spec_imag+this_map_imag + + sep_subband_spec.append(torch.complex(est_spec_real2, est_spec_imag2)) + sep_subband_spec_mask.append(torch.complex(est_spec_real, est_spec_imag)) + + sep_subband_spec = torch.cat(sep_subband_spec, 2) + est_spec_mask = torch.cat(sep_subband_spec_mask, 2) + + output = torch.istft(sep_subband_spec.view(batch_size*nch*self.num_output, self.enc_dim, -1), + n_fft=self.win, hop_length=self.stride, + window=torch.hann_window(self.win).to(input.device).type(input.type()), length=nsample) + output_mask = torch.istft(est_spec_mask.view(batch_size*nch*self.num_output, self.enc_dim, -1), + n_fft=self.win, hop_length=self.stride, + window=torch.hann_window(self.win).to(input.device).type(input.type()), length=nsample) + + output = output.view(batch_size, nch, self.num_output, -1).transpose(1,2).contiguous() + output_mask = output_mask.view(batch_size, nch, self.num_output, -1).transpose(1,2).contiguous() + # return output, output_mask + return output + + +if __name__ == '__main__': + model = Separator().cuda() + arr = np.zeros((1, 2, 3*44100), dtype=np.float32) + x = torch.from_numpy(arr).cuda() + res = model(x) diff --git a/separator/models/upernet_swin_transformers.py b/separator/models/upernet_swin_transformers.py new file mode 100644 index 0000000000000000000000000000000000000000..520bc14ea826678320758ea2527ed48910e77680 --- /dev/null +++ b/separator/models/upernet_swin_transformers.py @@ -0,0 +1,228 @@ +from functools import partial +import torch +import torch.nn as nn +from transformers import UperNetForSemanticSegmentation +from msst_utils import prefer_target_instrument + +class STFT: + def __init__(self, config): + self.n_fft = config.n_fft + self.hop_length = config.hop_length + self.window = torch.hann_window(window_length=self.n_fft, periodic=True) + self.dim_f = config.dim_f + + def __call__(self, x): + window = self.window.to(x.device) + batch_dims = x.shape[:-2] + c, t = x.shape[-2:] + x = x.reshape([-1, t]) + x = torch.stft( + x, + n_fft=self.n_fft, + hop_length=self.hop_length, + window=window, + center=True, + return_complex=True + ) + x = torch.view_as_real(x) + x = x.permute([0, 3, 1, 2]) + x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]]) + return x[..., :self.dim_f, :] + + def inverse(self, x): + window = self.window.to(x.device) + batch_dims = x.shape[:-3] + c, f, t = x.shape[-3:] + n = self.n_fft // 2 + 1 + f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device) + x = torch.cat([x, f_pad], -2) + x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t]) + x = x.permute([0, 2, 3, 1]) + x = x[..., 0] + x[..., 1] * 1.j + x = torch.istft( + x, + n_fft=self.n_fft, + hop_length=self.hop_length, + window=window, + center=True + ) + x = x.reshape([*batch_dims, 2, -1]) + return x + + +def get_norm(norm_type): + def norm(c, norm_type): + if norm_type == 'BatchNorm': + return nn.BatchNorm2d(c) + elif norm_type == 'InstanceNorm': + return nn.InstanceNorm2d(c, affine=True) + elif 'GroupNorm' in norm_type: + g = int(norm_type.replace('GroupNorm', '')) + return nn.GroupNorm(num_groups=g, num_channels=c) + else: + return nn.Identity() + + return 
partial(norm, norm_type=norm_type) + + +def get_act(act_type): + if act_type == 'gelu': + return nn.GELU() + elif act_type == 'relu': + return nn.ReLU() + elif act_type[:3] == 'elu': + alpha = float(act_type.replace('elu', '')) + return nn.ELU(alpha) + else: + raise Exception + + +class Upscale(nn.Module): + def __init__(self, in_c, out_c, scale, norm, act): + super().__init__() + self.conv = nn.Sequential( + norm(in_c), + act, + nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False) + ) + + def forward(self, x): + return self.conv(x) + + +class Downscale(nn.Module): + def __init__(self, in_c, out_c, scale, norm, act): + super().__init__() + self.conv = nn.Sequential( + norm(in_c), + act, + nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False) + ) + + def forward(self, x): + return self.conv(x) + + +class TFC_TDF(nn.Module): + def __init__(self, in_c, c, l, f, bn, norm, act): + super().__init__() + + self.blocks = nn.ModuleList() + for i in range(l): + block = nn.Module() + + block.tfc1 = nn.Sequential( + norm(in_c), + act, + nn.Conv2d(in_c, c, 3, 1, 1, bias=False), + ) + block.tdf = nn.Sequential( + norm(c), + act, + nn.Linear(f, f // bn, bias=False), + norm(c), + act, + nn.Linear(f // bn, f, bias=False), + ) + block.tfc2 = nn.Sequential( + norm(c), + act, + nn.Conv2d(c, c, 3, 1, 1, bias=False), + ) + block.shortcut = nn.Conv2d(in_c, c, 1, 1, 0, bias=False) + + self.blocks.append(block) + in_c = c + + def forward(self, x): + for block in self.blocks: + s = block.shortcut(x) + x = block.tfc1(x) + x = x + block.tdf(x) + x = block.tfc2(x) + x = x + s + return x + + +class Swin_UperNet_Model(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + act = get_act(act_type=config.model.act) + + self.num_target_instruments = len(prefer_target_instrument(config)) + self.num_subbands = config.model.num_subbands + + dim_c = self.num_subbands * config.audio.num_channels * 2 + c = config.model.num_channels + f = config.audio.dim_f // self.num_subbands + + self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False) + + self.swin_upernet_model = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-swin-large") + + self.swin_upernet_model.auxiliary_head.classifier = nn.Conv2d(256, c, kernel_size=(1, 1), stride=(1, 1)) + self.swin_upernet_model.decode_head.classifier = nn.Conv2d(512, c, kernel_size=(1, 1), stride=(1, 1)) + self.swin_upernet_model.backbone.embeddings.patch_embeddings.projection = nn.Conv2d(c, 192, kernel_size=(4, 4), stride=(4, 4)) + + self.final_conv = nn.Sequential( + nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False), + act, + nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False) + ) + + self.stft = STFT(config.audio) + + def cac2cws(self, x): + k = self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c, k, f // k, t) + x = x.reshape(b, c * k, f // k, t) + return x + + def cws2cac(self, x): + k = self.num_subbands + b, c, f, t = x.shape + x = x.reshape(b, c // k, k, f, t) + x = x.reshape(b, c // k, f * k, t) + return x + + def forward(self, x): + + x = self.stft(x) + + mix = x = self.cac2cws(x) + + first_conv_out = x = self.first_conv(x) + + x = x.transpose(-1, -2) + + x = self.swin_upernet_model(x).logits + + x = x.transpose(-1, -2) + + x = x * first_conv_out # reduce artifacts + + x = self.final_conv(torch.cat([mix, x], 1)) + + x = self.cws2cac(x) + + if self.num_target_instruments > 1: + b, c, f, t = x.shape + x = x.reshape(b, 
self.num_target_instruments, -1, f, t) + + x = self.stft.inverse(x) + return x + + +if __name__ == "__main__": + model = UperNetForSemanticSegmentation.from_pretrained("./results/", ignore_mismatched_sizes=True) + print(model) + print(model.auxiliary_head.classifier) + print(model.decode_head.classifier) + + x = torch.zeros((2, 16, 512, 512), dtype=torch.float32) + res = model(x) + print(res.logits.shape) + model.save_pretrained('./results/') \ No newline at end of file diff --git a/separator/msst_separator.py b/separator/msst_separator.py new file mode 100644 index 0000000000000000000000000000000000000000..918f8f76c347c0d3eac360b1d7ce589e8baadbcc --- /dev/null +++ b/separator/msst_separator.py @@ -0,0 +1,405 @@ +import os +import sys +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(SCRIPT_DIR) + +import argparse +import gradio as gr +import time +import librosa +from datetime import datetime +from tqdm.auto import tqdm +import json +import gc +import glob +import yaml +import torch +import numpy as np +import soundfile as sf +import torch.nn as nn +from audio_writer import write_audio_file +from renamer_stems import output_file_template + +from msst_utils import prefer_target_instrument, demix, get_model_from_config, demix_demucs + +def normalize_peak(audio, peak): + current_peak = np.max(np.abs(audio)) + if current_peak == 0: + return audio # избегаем деления на ноль + scale_factor = peak / current_peak + return audio * scale_factor + +gc.enable() + +def cleanup_model(model): + try: + if isinstance(model, torch.nn.DataParallel): + model = model.module + + model.to('cpu') + + for name, param in list(model.named_parameters()): + del param + for name, buf in list(model.named_buffers()): + del buf + + del model + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + gc.collect() + print("Модель выгружена из памяти") + except Exception as e: + print(f"Ошибка при выгрузке модели: {str(e)}") + +def once_inference( + path, + model, + config, + device, + model_type, + extract_instrumental, + detailed_pbar, + output_format, + output_bitrate, + use_tta, + verbose, + model_name, + sample_rate, + instruments, + store_dir, + template, + selected_instruments +): + results = [] + progress = gr.Progress(track_tqdm=True) + print("Выбранное аудио:", path) + print("Выбранные стемы:", selected_instruments) + print("Стемы, которые будут сохранены:", instruments) + + try: + mix, sr = librosa.load(path, sr=sample_rate, mono=False) + if mix.ndim == 1: + mix = np.stack([mix, mix], axis=0) + except Exception as e: + print(f"Не удалось прочитать аудио: {path}\nОшибка: {e}") + return results + + mix_orig = mix.copy() + + mean = std = None + if config.inference.get('normalize', False): + mono = mix.mean(0) + mean = mono.mean() + std = mono.std() + mix = (mix - mean) / std + + if use_tta: + track_proc_list = [mix.copy(), mix[::-1].copy(), -1. 
* mix.copy()] + else: + track_proc_list = [mix.copy()] + + full_result = [] + for m in track_proc_list: + try: + if model_type != "htdemucs": + waveforms = demix(config, model, m, device, pbar=detailed_pbar, model_type=model_type) + elif model_type == "htdemucs": + waveforms = demix_demucs(config, model, m, device, pbar=detailed_pbar, model_type=model_type) + + full_result.append(waveforms) + except Exception as e: + print(f"Ошибка при демиксе: {e}") + del m + gc.collect() + + if not full_result: + print("Пустой результат демикса.") + return results + + waveforms = full_result[0] + for i in range(1, len(full_result)): + d = full_result[i] + for el in d: + if i == 2: + waveforms[el] += -1.0 * d[el] + elif i == 1: + waveforms[el] += d[el][::-1].copy() + else: + waveforms[el] += d[el] + for el in waveforms: + waveforms[el] /= len(full_result) + + if extract_instrumental and config.training.target_instrument is not None: # Если включен "Extract Instrumental / Извлечь инструментал" и найден целевой инструмент + second_stem = [s for s in config.training.instruments if s != config.training.target_instrument] + if second_stem: + second_stem_key = second_stem[0] + if second_stem_key not in instruments: + instruments.append(second_stem_key) + waveforms[second_stem_key] = mix_orig - waveforms[instruments[0]] + + elif extract_instrumental and selected_instruments and config.training.target_instrument is None: # Если включен "Extract Instrumental / Извлечь инструментал" и выбраны инструменты, то создаются стемы "inverted -" и "inverted +" (если не найден целевого инструмент) + waveforms['inverted -'] = mix_orig.copy() + for instr in instruments: + if instr in waveforms: + waveforms['inverted -'] -= waveforms[instr] # стем "inverted -": вычитание выбранного стема из оригинального сигнала (не всегда хорошо) + + if 'inverted -' not in instruments: + instruments.append('inverted -') + + all_instruments = config.training.instruments + unselected_stems = [s for s in all_instruments if s not in selected_instruments] + if unselected_stems: + waveforms['inverted +'] = np.zeros_like(mix_orig) + for stem in unselected_stems: + if stem in waveforms: + waveforms['inverted +'] += waveforms[stem] # стем "inverted +": сложение не выбранных инструментов в один стем + if 'inverted +' not in instruments: + instruments.append('inverted +') + + peak = np.max(np.abs(waveforms['inverted -'])) + waveforms['inverted +'] = normalize_peak(waveforms['inverted +'], peak) + + elif (extract_instrumental and not selected_instruments and config.training.target_instrument is None and + (all(instr in config.training.instruments for instr in ["bass", "drums", "other", "vocals"]) or + all(instr in config.training.instruments for instr in ["bass", "drums", "other", "vocals", "piano", "guitar"]))): + + waveforms['instrumental -'] = mix_orig.copy() + waveforms['instrumental -'] -= waveforms["vocals"] # стем "inverted -": вычитание выбранного стема из оригинального сигнала (не всегда хорошо) + + if 'instrumental -' not in instruments: + instruments.append('instrumental -') + + all_instruments = config.training.instruments + non_vocal_stems = [s for s in all_instruments if s not in ["vocals"]] + if non_vocal_stems: + waveforms['instrumental +'] = np.zeros_like(mix_orig) + for stem in non_vocal_stems: + if stem in waveforms: + waveforms['instrumental +'] += waveforms[stem] # стем "inverted +": сложение не выбранных инструментов в один стем + if 'instrumental +' not in instruments: + instruments.append('instrumental +') + + peak = 
np.max(np.abs(waveforms['instrumental -'])) + waveforms['instrumental +'] = normalize_peak(waveforms['instrumental +'], peak) + + for instr in instruments: + try: + estimates = waveforms[instr].T + if mean is not None and std is not None: + estimates = estimates * std + mean + + file_name = os.path.splitext(os.path.basename(path))[0] + custom_name = output_file_template(template, file_name, instr, model_name) + output_path = os.path.join(store_dir, f"{custom_name}.{output_format}") + + write_audio_file(output_path, estimates, sr, output_format, output_bitrate) # запись стема в аудио файл с помощью универсальной функции + + results.append((instr, output_path)) # запись информации о разделении: (название стема, путь к файлу) + del estimates + except Exception as e: + print(f"Ошибка при обработке {instr}: {e}") + gc.collect() + + del mix, mix_orig, waveforms, full_result + librosa.cache.clear() + gc.collect() + + return results + +def run_inference( + model, + config, + input_path, + store_dir, + device, + model_type, + extract_instrumental, + disable_detailed_pbar, + output_format, + output_bitrate, + use_tta, + verbose, + model_name, + template='NAME_STEM', + selected_instruments=None +): + start_time = time.time() + model.eval() + sample_rate = 44100 + if 'sample_rate' in config.audio: + sample_rate = config.audio['sample_rate'] + + instruments = prefer_target_instrument(config) + + if config.training.target_instrument is not None: + print("Целевой инструмент найден в конфигурации модели. Выбранные стемы будут проигнорированы.") + else: + if selected_instruments is not None and selected_instruments != []: + instruments = [instr for instr in instruments if instr in selected_instruments] + if verbose: + print(f"Выбранные стемы: {instruments}") + + os.makedirs(store_dir, exist_ok=True) + + detailed_pbar = not disable_detailed_pbar + + results = once_inference( + input_path, model, config, device, model_type, extract_instrumental, + detailed_pbar, output_format, output_bitrate, use_tta, verbose, + model_name, sample_rate, instruments, store_dir, template, selected_instruments + ) + + time.sleep(1) + print(f"Потрачено времени: {time.time() - start_time:.2f} сек.") + return results + +def load_model(model_type, config_path, start_check_point, device_ids, force_cpu=False): + device = "cpu" + if force_cpu: + device = "cpu" + elif torch.cuda.is_available(): + print('Разделение выполняется на ядрах CUDA. 
Для выполнения на процессоре установите force_cpu=True.') + device = "cuda" + + if device_ids is None: + device = "cuda:0" + elif isinstance(device_ids, (list, tuple)): + device = f'cuda:{device_ids[0]}' if device_ids else 'cuda:0' + elif isinstance(device_ids, bool): + device = "cuda:0" + else: + device = f'cuda:{int(device_ids)}' + elif torch.backends.mps.is_available(): + device = "mps" + + print(f"Используется устройство: {device}") + + model_load_start_time = time.time() + torch.backends.cudnn.benchmark = True + + model, config = get_model_from_config(model_type, config_path) + if start_check_point != '': + print(f'Выбранный чекпоинт: {start_check_point}') + if model_type in ['htdemucs', 'apollo']: + state_dict = torch.load(start_check_point, map_location=device, weights_only=False) + if 'state' in state_dict: + state_dict = state_dict['state'] + if 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + else: + state_dict = torch.load(start_check_point, map_location=device, weights_only=True) + model.load_state_dict(state_dict) + print(f"Стемы: {config.training.instruments}") + + if isinstance(device_ids, (list, tuple)) and len(device_ids) > 1 and not force_cpu and torch.cuda.is_available(): + model = nn.DataParallel(model, device_ids=[int(d) for d in device_ids]) + + model = model.to(device) + + print(f"Потрачено времени на загрузку модели: {time.time() - model_load_start_time:.2f} сек.") + + return model, config, device + +def mvsep_offline( + input_path, + store_dir, + model_type, + config_path, + start_check_point, + extract_instrumental, + output_format, + output_bitrate, + model_name, + template, + device_ids=None, + disable_detailed_pbar=False, + use_tta=False, + force_cpu=False, + verbose=False, + selected_instruments=None, + save_results_info=False +): + model, config, device = load_model(model_type, config_path, start_check_point, device_ids, force_cpu) + + results = run_inference( + model, config, input_path, store_dir, device, model_type, extract_instrumental, + disable_detailed_pbar, output_format, output_bitrate, use_tta, verbose, + model_name, template, selected_instruments + ) + + if save_results_info: + + with open(os.path.join(store_dir, "results.json"), 'w') as f: + json.dump(results, f) + + cleanup_model(model) + del config + gc.collect() + return results + + +def parse_args(): + parser = argparse.ArgumentParser(description='Модифицированный Music-Source-Separation-Training для разделения аудио на источники') + + # Обязательные аргументы + parser.add_argument('--input', type=str, help='Путь к входному файлу или папке') + parser.add_argument('--input_list', nargs='+', help='Список с путями к входным файлам') + parser.add_argument('--store_dir', type=str, required=True, help='Путь для сохранения результатов') + + # Основные параметры модели + parser.add_argument('--model_type', type=str, default='htdemucs', choices=["mel_band_roformer", "bs_roformer", "mdx23c", "scnet", "htdemucs", "bandit", "bandit_v2"], help='Тип модели (по умолчанию: htdemucs)') + parser.add_argument('--config_path', type=str, required=True, help='Путь к конфигурационному файлу модели') + parser.add_argument('--start_check_point', type=str, required=True, help='Путь к чекпоинту модели') + + # Параметры вывода + parser.add_argument('--output_format', type=str, default='wav', choices=["wav", "mp3", "flac", "m4a", "aac", "aiff", "ogg", "opus"], help='Формат выходных файлов') + parser.add_argument('--output_bitrate', type=str, required=True, help='Битрейт выходного файла') + + # 
Опциональные параметры + parser.add_argument('--batch', action='store_true', help='Обработать все файлы в папке') + parser.add_argument('--batch_list', action='store_true', help='Обработать все файлы в списке') + parser.add_argument('--selected_instruments', nargs='+', help='Список стемов для сохранения (например: vocals drums)') + parser.add_argument('--extract_instrumental', action='store_true', help='Извлечь инструментальную версию') + parser.add_argument('--template', type=str, default='NAME_STEM', help='Шаблон для имен выходных файлов') + parser.add_argument('--model_name', type=str, default='model', help='Имя модели для шаблона имен файлов') + parser.add_argument('--device_ids', nargs='+', help='ID GPU устройств для использования') + parser.add_argument('--force_cpu', action='store_true', help='Принудительно использовать CPU') + parser.add_argument('--use_tta', action='store_true', help='Использовать тестовую аугментацию') + parser.add_argument('--disable_detailed_pbar', action='store_true', help='Отключить детальный прогресс-бар') + parser.add_argument('--verbose', action='store_true', help='Подробный вывод') + parser.add_argument('--save_results_info', action='store_true', help='Сохранить данные разделения в {args.store_dir}/results.json для отображения в интерфейсе') + + return parser.parse_args() + +def main(): + args = parse_args() + + device_ids = None + if args.device_ids: + device_ids = [int(x) for x in args.device_ids] + + results = mvsep_offline( + input_path=args.input, + store_dir=args.store_dir, + model_type=args.model_type, + config_path=args.config_path, + start_check_point=args.start_check_point, + extract_instrumental=args.extract_instrumental, + output_format=args.output_format, + output_bitrate=args.output_bitrate, + model_name=args.model_name, + template=args.template, + device_ids=device_ids, + disable_detailed_pbar=args.disable_detailed_pbar, + use_tta=args.use_tta, + force_cpu=args.force_cpu, + verbose=args.verbose, + selected_instruments=args.selected_instruments, + save_results_info=args.save_results_info, + ) + +if __name__ == "__main__": + main() diff --git a/separator/msst_utils.py b/separator/msst_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..56f182a3911a2ea867ea0a31c79f049cb7baf0d6 --- /dev/null +++ b/separator/msst_utils.py @@ -0,0 +1,900 @@ +# coding: utf-8 +__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/' + +import numpy as np +import torch +import torch.nn as nn +import yaml +import librosa +import torch.nn.functional as F +from ml_collections import ConfigDict +from omegaconf import OmegaConf +from tqdm.auto import tqdm +from typing import Dict, List, Tuple, Any, List, Optional + + +def load_config(model_type: str, config_path: str) -> Any: + """ + Load the configuration from the specified path based on the model type. + + Parameters: + ---------- + model_type : str + The type of model to load (e.g., 'htdemucs', 'mdx23c', etc.). + config_path : str + The path to the YAML or OmegaConf configuration file. + + Returns: + ------- + config : Any + The loaded configuration, which can be in different formats (e.g., OmegaConf or ConfigDict). + + Raises: + ------ + FileNotFoundError: + If the configuration file at `config_path` is not found. + ValueError: + If there is an error loading the configuration file. 
+ """ + try: + with open(config_path, 'r') as f: + if model_type == 'htdemucs': + config = OmegaConf.load(config_path) + else: + config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader)) + return config + except FileNotFoundError: + raise FileNotFoundError(f"Configuration file not found at {config_path}") + except Exception as e: + raise ValueError(f"Error loading configuration: {e}") + + +def get_model_from_config(model_type: str, config_path: str) -> Tuple: + """ + Load the model specified by the model type and configuration file. + + Parameters: + ---------- + model_type : str + The type of model to load (e.g., 'mdx23c', 'htdemucs', 'scnet', etc.). + config_path : str + The path to the configuration file (YAML or OmegaConf format). + + Returns: + ------- + model : nn.Module or None + The initialized model based on the `model_type`, or None if the model type is not recognized. + config : Any + The configuration used to initialize the model. This could be in different formats + depending on the model type (e.g., OmegaConf, ConfigDict). + + Raises: + ------ + ValueError: + If the `model_type` is unknown or an error occurs during model initialization. + """ + + config = load_config(model_type, config_path) + + if model_type == 'mdx23c': + from models.mdx23c_tfc_tdf_v3 import TFC_TDF_net + model = TFC_TDF_net(config) + elif model_type == 'htdemucs': + from models.demucs4ht import get_model + model = get_model(config) + elif model_type == 'segm_models': + from models.segm_models import Segm_Models_Net + model = Segm_Models_Net(config) + elif model_type == 'torchseg': + from models.torchseg_models import Torchseg_Net + model = Torchseg_Net(config) + + elif model_type == 'mel_band_roformer': + from models.bs_roformer import MelBandRoformer + model = MelBandRoformer(**dict(config.model)) + elif model_type == 'bs_roformer': + if hasattr(config.model, 'use_shared_bias'): + from models.bs_roformer.bs_roformer_sw import BSRoformer_SW + model = BSRoformer_SW(**dict(config.model)) + else: + from models.bs_roformer import BSRoformer + model = BSRoformer(**dict(config.model)) + elif model_type == 'swin_upernet': + from models.upernet_swin_transformers import Swin_UperNet_Model + model = Swin_UperNet_Model(config) + elif model_type == 'bandit': + from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple + model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model) + elif model_type == 'bandit_v2': + from models.bandit_v2.bandit import Bandit + model = Bandit(**config.kwargs) + elif model_type == 'scnet_unofficial': + from models.scnet_unofficial import SCNet + model = SCNet(**config.model) + elif model_type == 'scnet': + from models.scnet import SCNet + model = SCNet(**config.model) + elif model_type == 'apollo': + from models.look2hear.models import BaseModel + model = BaseModel.apollo(**config.model) + elif model_type == 'bs_mamba2': + from models.ts_bs_mamba2 import Separator + model = Separator(**config.model) + else: + raise ValueError(f"Unknown model type: {model_type}") + + return model, config + + +def _getWindowingArray(window_size: int, fade_size: int) -> torch.Tensor: + """ + Generate a windowing array with a linear fade-in at the beginning and a fade-out at the end. + + This function creates a window of size `window_size` where the first `fade_size` elements + linearly increase from 0 to 1 (fade-in) and the last `fade_size` elements linearly decrease + from 1 to 0 (fade-out). The middle part of the window is filled with ones. 
+ + Parameters: + ---------- + window_size : int + The total size of the window. + fade_size : int + The size of the fade-in and fade-out regions. + + Returns: + ------- + torch.Tensor + A tensor of shape (window_size,) containing the generated windowing array. + + Example: + ------- + If `window_size=10` and `fade_size=3`, the output will be: + tensor([0.0000, 0.5000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.5000, 0.0000]) + """ + + fadein = torch.linspace(0, 1, fade_size) + fadeout = torch.linspace(1, 0, fade_size) + + window = torch.ones(window_size) + window[-fade_size:] = fadeout + window[:fade_size] = fadein + return window + + +def demix( + config: ConfigDict, + model: torch.nn.Module, + mix: torch.Tensor, + device: torch.device, + model_type: str, + pbar: bool = False +) -> Tuple[List[Dict[str, np.ndarray]], np.ndarray]: + """ + Unified function for audio source separation with support for multiple processing modes. + + This function separates audio into its constituent sources using either a generic custom logic + or a Demucs-specific logic. It supports batch processing and overlapping window-based chunking + for efficient and artifact-free separation. + + Parameters: + ---------- + config : ConfigDict + Configuration object containing audio and inference settings. + model : torch.nn.Module + The trained model used for audio source separation. + mix : torch.Tensor + Input audio tensor with shape (channels, time). + device : torch.device + The computation device (CPU or CUDA). + model_type : str, optional + Processing mode: + - "demucs" for logic specific to the Demucs model. + Default is "generic". + pbar : bool, optional + If True, displays a progress bar during chunk processing. Default is False. + + Returns: + ------- + Union[Dict[str, np.ndarray], np.ndarray] + - A dictionary mapping target instruments to separated audio sources if multiple instruments are present. + - A numpy array of the separated source if only one instrument is present. 
+ """ + + mix = torch.tensor(mix, dtype=torch.float32) + + if model_type == 'htdemucs': + mode = 'demucs' + else: + mode = 'generic' + # Define processing parameters based on the mode + if mode == 'demucs': + chunk_size = config.training.samplerate * config.training.segment + num_instruments = len(config.training.instruments) + num_overlap = config.inference.num_overlap + step = chunk_size // num_overlap + else: + chunk_size = config.audio.chunk_size + num_instruments = len(prefer_target_instrument(config)) + num_overlap = config.inference.num_overlap + + fade_size = chunk_size // 10 + step = chunk_size // num_overlap + border = chunk_size - step + length_init = mix.shape[-1] + windowing_array = _getWindowingArray(chunk_size, fade_size) + # Add padding for generic mode to handle edge artifacts + if length_init > 2 * border and border > 0: + mix = nn.functional.pad(mix, (border, border), mode="reflect") + + batch_size = config.inference.batch_size + + use_amp = getattr(config.training, 'use_amp', True) # Works for both OmegaConf and ConfigDict + + with torch.cuda.amp.autocast(enabled=use_amp): + with torch.inference_mode(): + # Initialize result and counter tensors + req_shape = (num_instruments,) + mix.shape + result = torch.zeros(req_shape, dtype=torch.float32) + counter = torch.zeros(req_shape, dtype=torch.float32) + + i = 0 + batch_data = [] + batch_locations = [] + progress_bar = tqdm( + total=mix.shape[1], desc="Обработка аудио фрагментов", leave=False + ) if pbar else None + + while i < mix.shape[1]: + # Extract chunk and apply padding if necessary + part = mix[:, i:i + chunk_size].to(device) + chunk_len = part.shape[-1] + if mode == "generic" and chunk_len > chunk_size // 2: + pad_mode = "reflect" + else: + pad_mode = "constant" + part = nn.functional.pad(part, (0, chunk_size - chunk_len), mode=pad_mode, value=0) + + batch_data.append(part) + batch_locations.append((i, chunk_len)) + i += step + + # Process batch if it's full or the end is reached + if len(batch_data) >= batch_size or i >= mix.shape[1]: + arr = torch.stack(batch_data, dim=0) + x = model(arr) + + if mode == "generic": + window = windowing_array.clone() # fix for clicks issue with batch_size=1 + if i - step == 0: # First audio chunk, no fadein + window[:fade_size] = 1 + elif i >= mix.shape[1]: # Last audio chunk, no fadeout + window[-fade_size:] = 1 + + for j, (start, seg_len) in enumerate(batch_locations): + if mode == "generic": + result[..., start:start + seg_len] += x[j, ..., :seg_len].cpu() * window[..., :seg_len] + counter[..., start:start + seg_len] += window[..., :seg_len] + else: + result[..., start:start + seg_len] += x[j, ..., :seg_len].cpu() + counter[..., start:start + seg_len] += 1.0 + + batch_data.clear() + batch_locations.clear() + + if progress_bar: + progress_bar.update(step) + + if progress_bar: + progress_bar.close() + + # Compute final estimated sources + estimated_sources = result / counter + estimated_sources = estimated_sources.cpu().numpy() + np.nan_to_num(estimated_sources, copy=False, nan=0.0) + + # Remove padding for generic mode + if mode == "generic": + if length_init > 2 * border and border > 0: + estimated_sources = estimated_sources[..., border:-border] + + # Return the result as a dictionary or a single array + if mode == "demucs": + instruments = config.training.instruments + else: + instruments = prefer_target_instrument(config) + + ret_data = {k: v for k, v in zip(instruments, estimated_sources)} + + if mode == "demucs" and num_instruments <= 1: + return estimated_sources + else: + 
return ret_data + + + + + + + + + + + +def demix_demucs(config, model, mix, device, model_type, pbar=False): + mix = torch.tensor(mix, dtype=torch.float32) + + if model_type == 'htdemucs': + mode = 'demucs' + else: + mode = 'generic' + + if mode == 'demucs': + chunk_size = config.training.samplerate * config.training.segment + num_instruments = len(config.training.instruments) + num_overlap = config.inference.num_overlap + step = chunk_size // num_overlap + fade_size = chunk_size // 10 # Добавляем fade_size для оконной функции + windowing_array = _getWindowingArray(chunk_size, fade_size) # Создаём окно + else: + chunk_size = config.audio.chunk_size + num_instruments = len(prefer_target_instrument(config)) + num_overlap = config.inference.num_overlap + fade_size = chunk_size // 10 + step = chunk_size // num_overlap + border = chunk_size - step + length_init = mix.shape[-1] + windowing_array = _getWindowingArray(chunk_size, fade_size) + if length_init > 2 * border and border > 0: + mix = nn.functional.pad(mix, (border, border), mode="reflect") + + batch_size = config.inference.batch_size + use_amp = getattr(config.training, 'use_amp', True) + + with torch.cuda.amp.autocast(enabled=use_amp): + with torch.inference_mode(): + req_shape = (num_instruments,) + mix.shape + result = torch.zeros(req_shape, dtype=torch.float32) + counter = torch.zeros(req_shape, dtype=torch.float32) + + i = 0 + batch_data = [] + batch_locations = [] + progress_bar = tqdm(total=mix.shape[1], desc="Обработка аудио фрагментов", leave=False) if pbar else None + + while i < mix.shape[1]: + part = mix[:, i:i + chunk_size].to(device) + chunk_len = part.shape[-1] + pad_mode = "reflect" if chunk_len > chunk_size // 2 else "constant" + part = nn.functional.pad(part, (0, chunk_size - chunk_len), mode=pad_mode, value=0) + + batch_data.append(part) + batch_locations.append((i, chunk_len)) + i += step + + if len(batch_data) >= batch_size or i >= mix.shape[1]: + arr = torch.stack(batch_data, dim=0) + x = model(arr) + + window = windowing_array.clone() + if i - step == 0: # Первый чанк, без fade-in + window[:fade_size] = 1 + elif i >= mix.shape[1]: # Последний чанк, без fade-out + window[-fade_size:] = 1 + + for j, (start, seg_len) in enumerate(batch_locations): + result[..., start:start + seg_len] += x[j, ..., :seg_len].cpu() * window[..., :seg_len] + counter[..., start:start + seg_len] += window[..., :seg_len] + + batch_data.clear() + batch_locations.clear() + + if progress_bar: + progress_bar.update(step) + + if progress_bar: + progress_bar.close() + + estimated_sources = result / counter + estimated_sources = estimated_sources.cpu().numpy() + np.nan_to_num(estimated_sources, copy=False, nan=0.0) + + if mode == "demucs" and num_instruments <= 1: + return estimated_sources + else: + instruments = config.training.instruments + return {k: v for k, v in zip(instruments, estimated_sources)} + + + + + + + + + + + + + + + + + + + + + + + +def sdr(references: np.ndarray, estimates: np.ndarray) -> np.ndarray: + """ + Compute Signal-to-Distortion Ratio (SDR) for one or more audio tracks. + + SDR is a measure of how well the predicted source (estimate) matches the reference source. + It is calculated as the ratio of the energy of the reference signal to the energy of the error (difference between reference and estimate). 
+ Return SDR in decibels (dB) + Parameters: + ---------- + references : np.ndarray + A 3D numpy array of shape (num_sources, num_channels, num_samples), where num_sources is the number of sources, + num_channels is the number of channels (e.g., 1 for mono, 2 for stereo), and num_samples is the length of the audio signal. + + estimates : np.ndarray + A 3D numpy array of shape (num_sources, num_channels, num_samples) representing the estimated sources. + + Returns: + ------- + np.ndarray + A 1D numpy array containing the SDR values for each source. + """ + eps = 1e-8 # to avoid numerical errors + num = np.sum(np.square(references), axis=(1, 2)) + den = np.sum(np.square(references - estimates), axis=(1, 2)) + num += eps + den += eps + return 10 * np.log10(num / den) + + +def si_sdr(reference: np.ndarray, estimate: np.ndarray) -> float: + """ + Compute Scale-Invariant Signal-to-Distortion Ratio (SI-SDR) for one or more audio tracks. + + SI-SDR is a variant of the SDR metric that is invariant to the scaling of the estimate relative to the reference. + It is calculated by scaling the estimate to match the reference signal and then computing the SDR. + + Parameters: + ---------- + reference : np.ndarray + A 3D numpy array of shape (num_sources, num_channels, num_samples), where num_sources is the number of sources, + num_channels is the number of channels (e.g., 1 for mono, 2 for stereo), and num_samples is the length of the audio signal. + + estimate : np.ndarray + A 3D numpy array of shape (num_sources, num_channels, num_samples) representing the estimated sources. + + Returns: + ------- + float + The SI-SDR value for the source. It is a scalar representing the Signal-to-Distortion Ratio in decibels (dB). + """ + eps = 1e-8 # To avoid numerical errors + scale = np.sum(estimate * reference + eps, axis=(0, 1)) / np.sum(reference ** 2 + eps, axis=(0, 1)) + scale = np.expand_dims(scale, axis=(0, 1)) # Reshape to [num_sources, 1] + + reference = reference * scale + si_sdr = np.mean(10 * np.log10( + np.sum(reference ** 2, axis=(0, 1)) / (np.sum((reference - estimate) ** 2, axis=(0, 1)) + eps) + eps)) + + return si_sdr + + +def L1Freq_metric( + reference: np.ndarray, + estimate: np.ndarray, + fft_size: int = 2048, + hop_size: int = 1024, + device: str = 'cpu' +) -> float: + """ + Compute the L1 Frequency Metric between the reference and estimated audio signals. + + This metric compares the magnitude spectrograms of the reference and estimated audio signals + using the Short-Time Fourier Transform (STFT) and calculates the L1 loss between them. The result + is scaled to the range [0, 100] where a higher value indicates better performance. + + Parameters: + ---------- + reference : np.ndarray + A 2D numpy array of shape (num_channels, num_samples) representing the reference (ground truth) audio signal. + + estimate : np.ndarray + A 2D numpy array of shape (num_channels, num_samples) representing the estimated (predicted) audio signal. + + fft_size : int, optional + The size of the FFT (Short-Time Fourier Transform). Default is 2048. + + hop_size : int, optional + The hop size between STFT frames. Default is 1024. + + device : str, optional + The device to run the computation on ('cpu' or 'cuda'). Default is 'cpu'. + + Returns: + ------- + float + The L1 Frequency Metric in the range [0, 100], where higher values indicate better performance. 
+ """ + + reference = torch.from_numpy(reference).to(device) + estimate = torch.from_numpy(estimate).to(device) + + reference_stft = torch.stft(reference, fft_size, hop_size, return_complex=True) + estimated_stft = torch.stft(estimate, fft_size, hop_size, return_complex=True) + + reference_mag = torch.abs(reference_stft) + estimate_mag = torch.abs(estimated_stft) + + loss = 10 * F.l1_loss(estimate_mag, reference_mag) + + ret = 100 / (1. + float(loss.cpu().numpy())) + + return ret + + +def LogWMSE_metric( + reference: np.ndarray, + estimate: np.ndarray, + mixture: np.ndarray, + device: str = 'cpu', +) -> float: + """ + Calculate the Log-WMSE (Logarithmic Weighted Mean Squared Error) between the reference, estimate, and mixture signals. + + This metric evaluates the quality of the estimated signal compared to the reference signal in the + context of audio source separation. The result is given in logarithmic scale, which helps in evaluating + signals with large amplitude differences. + + Parameters: + ---------- + reference : np.ndarray + The ground truth audio signal of shape (channels, time), where channels is the number of audio channels + (e.g., 1 for mono, 2 for stereo) and time is the length of the audio in samples. + + estimate : np.ndarray + The estimated audio signal of shape (channels, time). + + mixture : np.ndarray + The mixed audio signal of shape (channels, time). + + device : str, optional + The device to run the computation on, either 'cpu' or 'cuda'. Default is 'cpu'. + + Returns: + ------- + float + The Log-WMSE value, which quantifies the difference between the reference and estimated signal on a logarithmic scale. + """ + from torch_log_wmse import LogWMSE + log_wmse = LogWMSE( + audio_length=reference.shape[-1] / 44100, # audio length in seconds + sample_rate=44100, # sample rate of 44100 Hz + return_as_loss=False, # return as loss (False means return as metric) + bypass_filter=False, # bypass frequency filtering (False means apply filter) + ) + + reference = torch.from_numpy(reference).unsqueeze(0).unsqueeze(0).to(device) + estimate = torch.from_numpy(estimate).unsqueeze(0).unsqueeze(0).to(device) + mixture = torch.from_numpy(mixture).unsqueeze(0).to(device) + + res = log_wmse(mixture, reference, estimate) + + return float(res.cpu().numpy()) + + +def AuraSTFT_metric( + reference: np.ndarray, + estimate: np.ndarray, + device: str = 'cpu', +) -> float: + """ + Calculate the AuraSTFT metric, which evaluates the spectral difference between the reference and estimated + audio signals using Short-Time Fourier Transform (STFT) loss. + + The AuraSTFT metric computes the STFT loss in both logarithmic and linear magnitudes, and it is commonly used + to assess the quality of audio separation tasks. The result is returned as a value scaled to the range [0, 100]. + + Parameters: + ---------- + reference : np.ndarray + The ground truth audio signal of shape (channels, time), where channels is the number of audio channels + (e.g., 1 for mono, 2 for stereo) and time is the length of the audio in samples. + + estimate : np.ndarray + The estimated audio signal of shape (channels, time). + + device : str, optional + The device to run the computation on, either 'cpu' or 'cuda'. Default is 'cpu'. + + Returns: + ------- + float + The AuraSTFT metric value, scaled to the range [0, 100], which quantifies the difference between + the reference and estimated signal in the spectral domain. 
+ """ + + from auraloss.freq import STFTLoss + + stft_loss = STFTLoss( + w_log_mag=1.0, # weight for log magnitude + w_lin_mag=0.0, # weight for linear magnitude + w_sc=1.0, # weight for spectral centroid + device=device, + ) + + reference = torch.from_numpy(reference).unsqueeze(0).to(device) + estimate = torch.from_numpy(estimate).unsqueeze(0).to(device) + + res = 100 / (1. + 10 * stft_loss(reference, estimate)) + return float(res.cpu().numpy()) + + +def AuraMRSTFT_metric( + reference: np.ndarray, + estimate: np.ndarray, + device: str = 'cpu', +) -> float: + """ + Calculate the AuraMRSTFT metric, which evaluates the spectral difference between the reference and estimated + audio signals using Multi-Resolution Short-Time Fourier Transform (STFT) loss. + + The AuraMRSTFT metric uses multi-resolution STFT analysis, which allows better representation of both + low- and high-frequency components in the audio signals. The result is returned as a value scaled to the range [0, 100]. + + Parameters: + ---------- + reference : np.ndarray + The ground truth audio signal of shape (channels, time), where channels is the number of audio channels + (e.g., 1 for mono, 2 for stereo) and time is the length of the audio in samples. + + estimate : np.ndarray + The estimated audio signal of shape (channels, time). + + device : str, optional + The device to run the computation on, either 'cpu' or 'cuda'. Default is 'cpu'. + + Returns: + ------- + float + The AuraMRSTFT metric value, scaled to the range [0, 100], which quantifies the difference between + the reference and estimated signal in the multi-resolution spectral domain. + """ + + from auraloss.freq import MultiResolutionSTFTLoss + + mrstft_loss = MultiResolutionSTFTLoss( + fft_sizes=[1024, 2048, 4096], + hop_sizes=[256, 512, 1024], + win_lengths=[1024, 2048, 4096], + scale="mel", # mel scale for frequency resolution + n_bins=128, # number of bins for mel scale + sample_rate=44100, + perceptual_weighting=True, # apply perceptual weighting + device=device + ) + + reference = torch.from_numpy(reference).unsqueeze(0).float().to(device) + estimate = torch.from_numpy(estimate).unsqueeze(0).float().to(device) + + res = 100 / (1. + 10 * mrstft_loss(reference, estimate)) + return float(res.cpu().numpy()) + + +def bleed_full( + reference: np.ndarray, + estimate: np.ndarray, + sr: int = 44100, + n_fft: int = 4096, + hop_length: int = 1024, + n_mels: int = 512, + device: str = 'cpu', +) -> Tuple[float, float]: + """ + Calculate the 'bleed' and 'fullness' metrics between a reference and an estimated audio signal. + + The 'bleed' metric measures how much the estimated signal bleeds into the reference signal, + while the 'fullness' metric measures how much the estimated signal retains its distinctiveness + in relation to the reference signal, both using mel spectrograms and decibel scaling. + + Parameters: + ---------- + reference : np.ndarray + The reference audio signal, shape (channels, time), where channels is the number of audio channels + (e.g., 1 for mono, 2 for stereo) and time is the length of the audio in samples. + + estimate : np.ndarray + The estimated audio signal, shape (channels, time). + + sr : int, optional + The sample rate of the audio signals. Default is 44100 Hz. + + n_fft : int, optional + The FFT size used to compute the STFT. Default is 4096. + + hop_length : int, optional + The hop length for STFT computation. Default is 1024. + + n_mels : int, optional + The number of mel frequency bins. Default is 512. 
+ + device : str, optional + The device for computation, either 'cpu' or 'cuda'. Default is 'cpu'. + + Returns: + ------- + tuple + A tuple containing two values: + - `bleedless` (float): A score indicating how much 'bleeding' the estimated signal has (higher is better). + - `fullness` (float): A score indicating how 'full' the estimated signal is (higher is better). + """ + + from torchaudio.transforms import AmplitudeToDB + + reference = torch.from_numpy(reference).float().to(device) + estimate = torch.from_numpy(estimate).float().to(device) + + window = torch.hann_window(n_fft).to(device) + + # Compute STFTs with the Hann window + D1 = torch.abs(torch.stft(reference, n_fft=n_fft, hop_length=hop_length, window=window, return_complex=True, + pad_mode="constant")) + D2 = torch.abs(torch.stft(estimate, n_fft=n_fft, hop_length=hop_length, window=window, return_complex=True, + pad_mode="constant")) + + mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels) + mel_filter_bank = torch.from_numpy(mel_basis).to(device) + + S1_mel = torch.matmul(mel_filter_bank, D1) + S2_mel = torch.matmul(mel_filter_bank, D2) + + S1_db = AmplitudeToDB(stype="magnitude", top_db=80)(S1_mel) + S2_db = AmplitudeToDB(stype="magnitude", top_db=80)(S2_mel) + + diff = S2_db - S1_db + + positive_diff = diff[diff > 0] + negative_diff = diff[diff < 0] + + average_positive = torch.mean(positive_diff) if positive_diff.numel() > 0 else torch.tensor(0.0).to(device) + average_negative = torch.mean(negative_diff) if negative_diff.numel() > 0 else torch.tensor(0.0).to(device) + + bleedless = 100 * 1 / (average_positive + 1) + fullness = 100 * 1 / (-average_negative + 1) + + return bleedless.cpu().numpy(), fullness.cpu().numpy() + + +def get_metrics( + metrics: List[str], + reference: np.ndarray, + estimate: np.ndarray, + mix: np.ndarray, + device: str = 'cpu', +) -> Dict[str, float]: + """ + Calculate a list of metrics to evaluate the performance of audio source separation models. + + The function computes the specified metrics based on the reference, estimate, and mixture. + + Parameters: + ---------- + metrics : List[str] + A list of metric names to compute (e.g., ['sdr', 'si_sdr', 'l1_freq']). + + reference : np.ndarray + The reference audio (true signal) with shape (channels, length). + + estimate : np.ndarray + The estimated audio (predicted signal) with shape (channels, length). + + mix : np.ndarray + The mixed audio signal with shape (channels, length). + + device : str, optional, default='cpu' + The device ('cpu' or 'cuda') to perform the calculations on. + + Returns: + ------- + Dict[str, float] + A dictionary containing the computed metric values. 
+ """ + result = dict() + + # Adjust the length to be the same across all inputs + min_length = min(reference.shape[1], estimate.shape[1]) + reference = reference[..., :min_length] + estimate = estimate[..., :min_length] + mix = mix[..., :min_length] + + if 'sdr' in metrics: + references = np.expand_dims(reference, axis=0) + estimates = np.expand_dims(estimate, axis=0) + result['sdr'] = sdr(references, estimates)[0] + + if 'si_sdr' in metrics: + result['si_sdr'] = si_sdr(reference, estimate) + + if 'l1_freq' in metrics: + result['l1_freq'] = L1Freq_metric(reference, estimate, device=device) + + if 'log_wmse' in metrics: + result['log_wmse'] = LogWMSE_metric(reference, estimate, mix, device) + + if 'aura_stft' in metrics: + result['aura_stft'] = AuraSTFT_metric(reference, estimate, device) + + if 'aura_mrstft' in metrics: + result['aura_mrstft'] = AuraMRSTFT_metric(reference, estimate, device) + + if 'bleedless' in metrics or 'fullness' in metrics: + bleedless, fullness = bleed_full(reference, estimate, device=device) + if 'bleedless' in metrics: + result['bleedless'] = bleedless + if 'fullness' in metrics: + result['fullness'] = fullness + + return result + + +def prefer_target_instrument(config: ConfigDict) -> List[str]: + """ + Return the list of target instruments based on the configuration. + If a specific target instrument is specified in the configuration, + it returns a list with that instrument. Otherwise, it returns the list of instruments. + + Parameters: + ---------- + config : ConfigDict + Configuration object containing the list of instruments or the target instrument. + + Returns: + ------- + List[str] + A list of target instruments. + """ + if config.training.get('target_instrument'): + return [config.training.target_instrument] + else: + return config.training.instruments + +def prefer_target_instrument_test(config: ConfigDict, selected_instruments: Optional[List[str]] = None) -> List[str]: + """ + Return the list of target instruments based on the configuration and selected instruments. + If selected_instruments is specified, returns the intersection with available instruments. + Otherwise, if a target instrument is specified, returns it, else returns all instruments. + + Parameters: + ---------- + config : ConfigDict + Configuration object containing the list of instruments or the target instrument. + selected_instruments : Optional[List[str]] + List of instruments to select (optional) + + Returns: + ------- + List[str] + A list of target instruments. 
+ """ + available_instruments = config.training.instruments + + if selected_instruments is not None: + # Return only selected instruments that are available + return [instr for instr in selected_instruments if instr in available_instruments] + elif config.training.get('target_instrument'): + # Default behavior if no selection - return target instrument + return [config.training.target_instrument] + else: + # If no target and no selection, return all instruments + return available_instruments + + + + diff --git a/separator/renamer_stems.py b/separator/renamer_stems.py new file mode 100644 index 0000000000000000000000000000000000000000..ade537e876020fd5130266030b7975deaa0e9e66 --- /dev/null +++ b/separator/renamer_stems.py @@ -0,0 +1,48 @@ +import os +def output_file_template(template, input_file_name, stem, model_name): + template_name = ( + template + .replace("NAME", f"{input_file_name}") + .replace("MODEL", f"{model_name}") + .replace("STEM", f"{stem}") + ) + output_name = f"{template_name}" + return output_name + +def audio_separator_rename_stems(audio, template, name_model): + base_name = os.path.splitext(os.path.basename(audio))[0] + stems = { + "Bass": template.replace("NAME", base_name).replace("STEM", "Bass").replace("MODEL", name_model), + "Crowd": template.replace("NAME", base_name).replace("STEM", "Crowd").replace("MODEL", name_model), + "Drums": template.replace("NAME", base_name).replace("STEM", "Drums").replace("MODEL", name_model), + "Dry": template.replace("NAME", base_name).replace("STEM", "Dry").replace("MODEL", name_model), + "Breath": template.replace("NAME", base_name).replace("STEM", "Breath").replace("MODEL", name_model), + "Echo": template.replace("NAME", base_name).replace("STEM", "Echo").replace("MODEL", name_model), + "Instrumental": template.replace("NAME", base_name).replace("STEM", "Instrumental").replace("MODEL", name_model), + "No Bass": template.replace("NAME", base_name).replace("STEM", "No Bass").replace("MODEL", name_model), + "No Crowd": template.replace("NAME", base_name).replace("STEM", "No Crowd").replace("MODEL", name_model), + "No Drums": template.replace("NAME", base_name).replace("STEM", "No Drums").replace("MODEL", name_model), + "No Dry": template.replace("NAME", base_name).replace("STEM", "No Dry").replace("MODEL", name_model), + "No Echo": template.replace("NAME", base_name).replace("STEM", "No Echo").replace("MODEL", name_model), + "No Noise": template.replace("NAME", base_name).replace("STEM", "No Noise").replace("MODEL", name_model), + "No Other": template.replace("NAME", base_name).replace("STEM", "No Other").replace("MODEL", name_model), + "No Breath": template.replace("NAME", base_name).replace("STEM", "No Breath").replace("MODEL", name_model), + "No Reverb": template.replace("NAME", base_name).replace("STEM", "No Reverb").replace("MODEL", name_model), + "No Woodwinds": template.replace("NAME", base_name).replace("STEM", "No Woodwinds").replace("MODEL", name_model), + "Noise": template.replace("NAME", base_name).replace("STEM", "Noise").replace("MODEL", name_model), + "Other": template.replace("NAME", base_name).replace("STEM", "Other").replace("MODEL", name_model), + "Reverb": template.replace("NAME", base_name).replace("STEM", "Reverb").replace("MODEL", name_model), + "Vocals": template.replace("NAME", base_name).replace("STEM", "Vocals").replace("MODEL", name_model), + "Woodwinds": template.replace("NAME", base_name).replace("STEM", "Woodwinds").replace("MODEL", name_model), + "Guitar": template.replace("NAME", 
base_name).replace("STEM", "Guitar").replace("MODEL", name_model), + "Piano": template.replace("NAME", base_name).replace("STEM", "Piano").replace("MODEL", name_model) + } + return stems + +def audio_separator_vr_rename_stems(audio, template, name_model, primary_stem): + base_name = os.path.splitext(os.path.basename(audio))[0] + stems = { + f"{primary_stem}": template.replace("NAME", base_name).replace("STEM", primary_stem).replace("MODEL", name_model), + f"No {primary_stem}": template.replace("NAME", base_name).replace("STEM", f"No {primary_stem}").replace("MODEL", name_model) + } + return stems diff --git a/separator/uvr_sep.py b/separator/uvr_sep.py new file mode 100644 index 0000000000000000000000000000000000000000..ee1b9316d72b0383158e99540182eb8499c1727e --- /dev/null +++ b/separator/uvr_sep.py @@ -0,0 +1,156 @@ +import os +import argparse +import sys +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(SCRIPT_DIR) +import json + +from audio_separator.separator import Separator +from renamer_stems import audio_separator_rename_stems + +def give_vr_params(file): + path, filename = os.path.split(file) + name_without_ext = os.path.splitext(filename)[0] + vr_param = os.path.join(path, name_without_ext) + return vr_param + +def custom_vr_separate( + input_file, + ckpt_path, + config_path, + bitrate, + model_name, + template, + output_format, + primary_stem="Vocals", + aggression=5, + output_dir="./", + selected_instruments=[] +): + + separator = Separator( + output_dir=output_dir, + output_bitrate=bitrate, + use_soundfile=False, + output_format=output_format, + output_single_stem=(selected_instruments[0] if len(selected_instruments) == 1 else None) + ) + output_names = audio_separator_rename_stems(input_file, template, model_name) + + separator.load_custom_vr_model( + model_path=ckpt_path, + config_path=config_path, + params={"primary_stem": primary_stem, "vr_model_param" : give_vr_params(config_path), "window_size" : 512, "aggression": aggression}, + ) + + output_files = separator.separate(input_file, output_names) + + return output_files + +def give_full_model_name(model_type, model_name): + if model_type == "mdx": + return f"{model_name}.onnx" + elif model_type == "vr": + return f"{model_name}.pth" + + +def non_custom_uvr_inference(input_file, output_dir, template, bitrate, model_dir, model_type, model_name, output_format, aggression, selected_instruments=[]): + + separator = Separator( + output_dir=output_dir, + output_bitrate=bitrate, + model_file_dir=model_dir, + use_soundfile=False, + output_format=output_format, + output_single_stem=(selected_instruments[0] if len(selected_instruments) == 1 else None), + vr_params={"batch_size": 1, "window_size": 512, "aggression": aggression, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False}, + mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": True} + ) + separator.load_model(model_filename=give_full_model_name(model_type, model_name)) + + output_names = audio_separator_rename_stems(input_file, template, model_name) + + output_files = separator.separate(input_file, output_names) + + return output_files + + + + + + + + + + + +def main(): + parser = argparse.ArgumentParser(description='Audio separation tool') + subparsers = parser.add_subparsers(dest='command', required=True, help='Sub-command help') + + # Парсер для custom VR separation + custom_parser = subparsers.add_parser('custom_vr', help='Custom VR model 
separation') + custom_parser.add_argument('--input_file', required=True, help='Input audio file path') + custom_parser.add_argument('--ckpt_path', required=True, help='Path to model checkpoint (.pth file)') + custom_parser.add_argument('--config_path', required=True, help='Path to model config file') + custom_parser.add_argument('--bitrate', type=str, default="320k", help='Output bitrate') + custom_parser.add_argument('--model_name', required=True, help='Name of the model') + custom_parser.add_argument('--template', default="{track_name}_{stem}_{model_name}", help='Output filename template') + custom_parser.add_argument('--output_format', default="mp3", help='Output audio format') + custom_parser.add_argument('--primary_stem', default="Vocals", help='Primary stem to separate') + custom_parser.add_argument('--aggression', type=int, default=5, help='Separation aggression level') + custom_parser.add_argument('--output_dir', default="./", help='Output directory') + custom_parser.add_argument('--selected_instruments', nargs='*', default=[], help='List of instruments to separate') + + # Парсер для non-custom UVR separation + uvr_parser = subparsers.add_parser('uvr', help='Non-custom UVR separation') + uvr_parser.add_argument('--input_file', required=True, help='Input audio file path') + uvr_parser.add_argument('--output_dir', default="./", help='Output directory') + uvr_parser.add_argument('--template', default="{track_name}_{stem}_{model_name}", help='Output filename template') + uvr_parser.add_argument('--bitrate', type=str, default="320k", help='Output bitrate') + uvr_parser.add_argument('--model_dir', required=True, help='Directory containing model files') + uvr_parser.add_argument('--model_type', required=True, choices=['mdx', 'vr'], help='Model type (mdx or vr)') + uvr_parser.add_argument('--model_name', required=True, help='Name of the model') + uvr_parser.add_argument('--output_format', default="mp3", help='Output audio format') + uvr_parser.add_argument('--aggression', type=int, default=5, help='Separation aggression level (for VR models)') + uvr_parser.add_argument('--selected_instruments', nargs='*', default=[], help='List of instruments to separate') + + args = parser.parse_args() + + if args.command == 'custom_vr': + # Запуск custom VR separation + results = custom_vr_separate( + input_file=args.input_file, + ckpt_path=args.ckpt_path, + config_path=args.config_path, + bitrate=args.bitrate, + model_name=args.model_name, + template=args.template, + output_format=args.output_format, + primary_stem=args.primary_stem, + aggression=args.aggression, + output_dir=args.output_dir, + selected_instruments=args.selected_instruments + ) + with open((os.path.join(args.output_dir, "results.json")), 'w') as f: + json.dump(results, f) + + elif args.command == 'uvr': + # Запуск non-custom UVR separation + results = non_custom_uvr_inference( + input_file=args.input_file, + output_dir=args.output_dir, + template=args.template, + bitrate=args.bitrate, + model_dir=args.model_dir, + model_type=args.model_type, + model_name=args.model_name, + output_format=args.output_format, + aggression=args.aggression, + selected_instruments=args.selected_instruments + ) + with open((os.path.join(args.output_dir, "results.json")), 'w') as f: + json.dump(results, f) + +if __name__ == "__main__": + main() diff --git a/utils/download_audio.py b/utils/download_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..e302127313036c40f151d8ce8cfe93b6a4f4e982 --- /dev/null +++ b/utils/download_audio.py @@ 
-0,0 +1,124 @@
+import os
+import requests
+from mimetypes import guess_extension
+from pathlib import Path
+from filecmp import dircmp
+import yt_dlp
+import magic
+import random
+import gradio as gr
+import string
+import shutil
+import re
+
+
+def download_audio_from_url(url_or_text, type_url, cookies, output_dir="downloads"):
+    def extract_url_from_text(text):
+        """Find a URL inside arbitrary text."""
+        url_regex = r"https?://[^\s]+"
+        match = re.search(url_regex, text)
+        if match:
+            return match.group(0)
+        return None
+    if type_url == "YT Music, Soundcloud, Tiktok":
+        def is_supported_url(url):
+            """Check that the link points to a supported site."""
+            return any(domain in url for domain in ["soundcloud.com", "youtube.com", "youtu.be", "tiktok.com"])
+        def download_track(url_or_text, cookies, output_dir="downloads"):
+            """Download a track from TikTok, YouTube or SoundCloud."""
+            os.makedirs(output_dir, exist_ok=True)
+            # Validate and extract the URL
+            if "http" not in url_or_text:
+                gr.Warning("Нет ссылки в строке")
+                return None
+            url = extract_url_from_text(url_or_text)
+            if not url:
+                gr.Warning("Ссылка не найдена")
+                return None
+            if not is_supported_url(url):
+                gr.Warning(f"Сайт не поддерживается: {url}")
+                return None
+
+            ydl_opts = {
+                "format": "bestaudio/best",
+                "outtmpl": f"{output_dir}/%(title)s.%(ext)s",
+                "postprocessors": [{
+                    "key": "FFmpegExtractAudio",
+                    "preferredcodec": "mp3",
+                }],
+                "quiet": True,
+                "cookiefile": cookies,  # path to the cookies file, if any
+            }
+
+            with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+                info = ydl.extract_info(url, download=True)
+                filename = ydl.prepare_filename(info)
+                final_path = filename.rsplit(".", 1)[0] + ".mp3"
+                gr.Warning("Трек успешно скачан")
+                return final_path
+        url = extract_url_from_text(url_or_text)
+        audio = download_track(url, cookies, output_dir)
+        return audio
+    if type_url == "Прямая ссылка":
+
+        def create_unique_file_name(prefix="song_", length=15):
+            random_part = ''.join(random.choices(string.ascii_lowercase + string.digits, k=length))
+            return prefix + random_part
+
+        def download_file(url, out_dir):
+            try:
+                save_path = os.path.join(out_dir, f"{create_unique_file_name()}.bin")
+
+                response = requests.get(url, stream=True)
+                response.raise_for_status()  # raise if the request failed
+
+                # Write the response body to the file
+                with open(save_path, 'wb') as file:
+                    for chunk in response.iter_content(chunk_size=8192):
+                        file.write(chunk)
+
+                def get_correct_extension(file_path):
+                    # Detect the MIME type
+                    mime = magic.Magic(mime=True)
+                    mime_type = mime.from_file(file_path)
+
+                    # Derive the extension from the MIME type
+                    extension = guess_extension(mime_type)
+                    return extension if extension else ".bin"  # fall back to .bin for unknown types
+
+                def is_audio_file(file_path):
+                    audio_extensions = {'.mp3', '.wav', '.flac', '.aiff', '.ogg', '.opus', '.m4a', '.aac'}
+                    _, ext = os.path.splitext(file_path)
+                    return ext.lower() in audio_extensions
+
+                def rename_file_with_proper_extension(file_path):
+                    dirname = os.path.dirname(file_path)
+                    basename = os.path.basename(file_path)
+                    name_without_ext = os.path.splitext(basename)[0]
+
+                    # Get the correct extension
+                    correct_ext = get_correct_extension(file_path)
+
+                    # New file name
+                    new_name = f"{name_without_ext}{correct_ext}"
+                    new_path = os.path.join(dirname, new_name)
+
+                    # Rename the file
+                    os.rename(file_path, new_path)
+                    return new_path
+
+                save_path = rename_file_with_proper_extension(save_path)
+                is_audio = is_audio_file(save_path)
+                if not is_audio:
+
gr.Warning("Скачанный файл не является аудиофайлом") + return None + gr.Warning(f"Файл успешно скачан и сохранен как {save_path}") + return save_path + except Exception as e: + gr.Warning(f"Произошла ошибка при скачивании файла: {e}") + return None + url = extract_url_from_text(url_or_text) + audio = download_file(url, output_dir) + return audio diff --git a/utils/download_models.py b/utils/download_models.py new file mode 100644 index 0000000000000000000000000000000000000000..4682611c4404623d6706be141dccc59ef6204822 --- /dev/null +++ b/utils/download_models.py @@ -0,0 +1,67 @@ +import os + +def download_model(model_paths, model_name, model_type, ckpt_url, conf_url): + model_dir = os.path.join(model_paths, model_type) + os.makedirs(model_dir, exist_ok=True) + + # Инициализация переменных (на случай, если ни одно условие не сработает) + config_path = None + checkpoint_path = None + + if model_type == "mel_band_roformer": + config_path = os.path.join(model_dir, f"{model_name}_config.yaml") + checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt") + + elif model_type == "vr": + config_path = os.path.join(model_dir, f"{model_name}.json") + checkpoint_path = os.path.join(model_dir, f"{model_name}.pth") + + elif model_type == "bs_roformer": + config_path = os.path.join(model_dir, f"{model_name}_config.yaml") + checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt") + + elif model_type == "mdx23c": + config_path = os.path.join(model_dir, f"{model_name}_config.yaml") + checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt") + + elif model_type == "scnet": + config_path = os.path.join(model_dir, f"{model_name}_config.yaml") + checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt") + + elif model_type == "bandit": + config_path = os.path.join(model_dir, f"{model_name}_config.yaml") + checkpoint_path = os.path.join(model_dir, f"{model_name}.chpt") + + elif model_type == "bandit_v2": + config_path = os.path.join(model_dir, f"{model_name}_config.yaml") + checkpoint_path = os.path.join(model_dir, f"{model_name}.ckpt") + + elif model_type == "htdemucs": + config_path = os.path.join(model_dir, f"{model_name}_config.yaml") + checkpoint_path = os.path.join(model_dir, f"{model_name}.th") + + elif model_type == "medley_vox": + medley_vox_model_dir = os.path.join(model_dir, model_name) + os.makedirs(medley_vox_model_dir, exist_ok=True) + config_path = os.path.join(medley_vox_model_dir, f"vocals.json") + checkpoint_path = os.path.join(medley_vox_model_dir, f"vocals.pth") + + else: + raise ValueError(f"Unsupported model_type: {model_type}") + + # Проверяем, что пути заданы (на всякий случай) + if config_path is None or checkpoint_path is None: + raise RuntimeError("Failed to set model paths!") + + # Если файлы уже есть — пропускаем загрузку + if os.path.exists(checkpoint_path) and os.path.exists(config_path): + print("Model already downloaded") + else: + for local_path, url_model in [(checkpoint_path, ckpt_url), (config_path, conf_url)]: + download_cmd = f"wget -O {local_path} {url_model}" + os.system(download_cmd) + + if model_type == "medley_vox": + return model_dir + else: + return config_path, checkpoint_path \ No newline at end of file diff --git a/utils/preedit_config.py b/utils/preedit_config.py new file mode 100644 index 0000000000000000000000000000000000000000..9d73d04adc3c51fcc5c619fc4cc38c1d2907f9c2 --- /dev/null +++ b/utils/preedit_config.py @@ -0,0 +1,42 @@ +import os +import yaml + +def conf_editor(config_path): + + class IndentDumper(yaml.Dumper): + def 
increase_indent(self, flow=False, indentless=False):
+            return super(IndentDumper, self).increase_indent(flow, False)
+
+    def tuple_constructor(loader, node):
+        # Load the sequence of values from the YAML node
+        values = loader.construct_sequence(node)
+        # Return a tuple constructed from the sequence
+        return tuple(values)
+
+    # Register the constructor with PyYAML
+    yaml.SafeLoader.add_constructor('tag:yaml.org,2002:python/tuple', tuple_constructor)
+
+    def conf_edit(config_path):
+        with open(config_path, 'r') as f:
+            data = yaml.load(f, Loader=yaml.SafeLoader)
+
+        # handle cases where 'use_amp' is missing from the training section of the config:
+        if 'use_amp' not in data['training']:
+            data['training']['use_amp'] = True
+
+        if data['inference']['num_overlap'] != 2:
+            data['inference']['num_overlap'] = 2
+
+        if data['inference']['batch_size'] == 1:
+            data['inference']['batch_size'] = 2
+
+        print("Using custom num_overlap and batch_size values:")
+        print(f"num_overlap = {data['inference']['num_overlap']}")
+        print(f"batch_size = {data['inference']['batch_size']}")
+
+        with open(config_path, 'w') as f:
+            yaml.dump(data, f, default_flow_style=False, sort_keys=False, Dumper=IndentDumper, allow_unicode=True)
+
+    # apply the config normalization defined above
+    conf_edit(config_path)
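For orientation, here is a minimal usage sketch of how the modules introduced in this patch are intended to fit together; it is not part of the patch itself. Only the function signatures come from the diff above — the import layout assumes the repository root is on sys.path, and all model, config and audio paths are hypothetical placeholders.

from utils.preedit_config import conf_editor
from separator.msst_separator import mvsep_offline

config_path = "models/mel_band_roformer/vocals_config.yaml"   # hypothetical path
checkpoint_path = "models/mel_band_roformer/vocals.ckpt"      # hypothetical path

# Adjust num_overlap / batch_size in the model config before inference
# (see utils/preedit_config.py).
conf_editor(config_path)

# Run offline separation; returns a list of (stem_name, output_file) tuples,
# as produced by once_inference() in separator/msst_separator.py.
results = mvsep_offline(
    input_path="song.wav",                  # hypothetical input file
    store_dir="output/demo",
    model_type="mel_band_roformer",
    config_path=config_path,
    start_check_point=checkpoint_path,
    extract_instrumental=True,
    output_format="wav",
    output_bitrate="320k",
    model_name="vocals_model",
    template="NAME_STEM",
)
for stem, path in results:
    print(stem, "->", path)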