Spaces:
Running
on
Zero
Running
on
Zero
| import yt_dlp | |
| import re | |
| import subprocess | |
| import os | |
| import shutil | |
| from pydub import AudioSegment | |
| import gradio as gr | |
| import traceback | |
| import logging | |
| from inference import proc_folder_direct | |
| from pathlib import Path | |
# Root directory for separation results. Paths below are built by plain string
# concatenation (f'{OUTPUT_FOLDER}{title}/...'), so the trailing slash is required.
OUTPUT_FOLDER = "separation_results/"
# Root directory for downloaded audio; the WAV files live in its "wav" subfolder.
INPUT_FOLDER = "input"
# Not referenced by any code visible in this file.
download_path = ""
def sanitize_filename(filename):
    """Return *filename* with path-hostile characters replaced by underscores.

    The characters replaced are: backslash, slash, asterisk, question mark,
    colon, double quote, angle brackets and the pipe.
    """
    forbidden = re.compile(r'[\\/*?:"<>|]')
    return forbidden.sub('_', filename)
def delete_input_files(input_dir):
    """Delete every ``*.wav`` file found directly in ``<input_dir>/wav``.

    Files with other extensions and files in nested folders are not touched.
    Each deleted path is printed to stdout.
    """
    wav_folder = Path(input_dir).joinpath("wav")
    for stale_wav in wav_folder.glob("*.wav"):
        os.remove(stale_wav)
        print(f"Deleted {stale_wav}")
def download_youtube_audio_by_title(query, state=True):
    """Look up *query* on YouTube via yt-dlp and optionally download it as WAV.

    Args:
        query: Search text handed to yt-dlp (``default_search`` is
            ``ytsearch``, so plain text is treated as a YouTube search).
        state: When true, previously downloaded WAV files are removed, the
            best match is downloaded and converted to WAV, and the path of
            the resulting file is returned. When false, nothing is deleted
            or downloaded and only the sanitized title is returned.

    Returns:
        ``./<INPUT_FOLDER>/wav/<title>.wav`` when *state* is true, otherwise
        the sanitized title string used for that file name.

    Raises:
        ValueError: If *query* is empty/blank or the lookup yields no video.
    """
    if not query or not query.strip():
        raise ValueError("Please enter a song title.")

    wav_dir = f'./{INPUT_FOLDER}/wav'
    ydl_opts = {
        'quiet': True,
        'default_search': 'ytsearch',
        'noplaylist': True,
        'format': 'bestaudio/best',
        'outtmpl': f'{wav_dir}/%(title)s.%(ext)s',
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'wav',
        }],
    }

    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        search_results = ydl.extract_info(query, download=False)

    # A search produces a result dict carrying an 'entries' list, whereas a
    # lookup that resolves to a single video has no 'entries' key at all
    # (presumably the direct-URL case — confirm against yt-dlp docs).
    if search_results is None:
        video_info = None
    elif 'entries' in search_results:
        video_info = next(iter(search_results['entries'] or ()), None)
    else:
        video_info = search_results
    if not video_info:
        raise ValueError(f"No results found for '{query}'.")

    video_url = video_info['webpage_url']
    video_title = video_info['title']

    # Keep "Artist - Title" and drop one trailing " [..]" / "(..)" suffix;
    # titles without " - " are used unchanged. strip() removes the space
    # that is left behind in front of a dropped "(..)" suffix.
    match = re.match(r'^(.*? - .*?)(?: \[.*\]|\(.*\))?$', video_title)
    formatted_title = match.group(1) if match else video_title
    formatted_title = sanitize_filename(formatted_title.strip())

    if not state:
        return formatted_title

    # Only discard the previous download once the new lookup has succeeded.
    delete_input_files(INPUT_FOLDER)
    ydl_opts['outtmpl'] = f'{wav_dir}/{formatted_title}.%(ext)s'
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        ydl.download([video_url])
    return f'{wav_dir}/{formatted_title}.wav'
def run_inference(model_type, config_path, start_check_point, input_dir, output_dir, device_ids="0"):
    """Run ``inference.py`` in a child ``python`` process.

    The arguments are forwarded as command-line options. stdout/stderr are
    captured as text, and a non-zero exit status raises
    ``subprocess.CalledProcessError`` (``check=True``).

    Returns:
        The ``subprocess.CompletedProcess`` of the finished run.
    """
    # NOTE(review): the input flag is spelled ``--INPUT_FOLDER`` (upper case),
    # unlike the other options — confirm inference.py really accepts it.
    cli_options = (
        ("--model_type", model_type),
        ("--config_path", config_path),
        ("--start_check_point", start_check_point),
        ("--INPUT_FOLDER", input_dir),
        ("--store_dir", output_dir),
        ("--device_ids", device_ids),
    )
    argv = ["python", "inference.py"]
    for flag, value in cli_options:
        argv.extend((flag, value))
    return subprocess.run(argv, check=True, capture_output=True, text=True)
def move_stems_to_parent(input_dir):
    """Move each model's stem file up into its song folder under a short name.

    Every directory below *input_dir* is inspected. The first rule whose
    model name occurs anywhere in the directory path decides which file is
    looked for (``<song>_<stem>.wav``, where ``<song>`` is the name of the
    directory's parent) and what it is called after being moved into that
    parent. Missing files are reported on stdout and skipped.
    """
    # (model name in path, stem suffix of the source file, new file name, label)
    rules = (
        ("htdemucs", "bass", "bass.wav", "Bass"),
        ("mel_band_roformer", "vocals", "vocals.wav", "Vocals"),
        ("scnet", "other", "other.wav", "Other"),
        ("bs_roformer", "other", "instrumental.wav", "Instrumental"),
    )

    for current, _dirs, _files in os.walk(input_dir):
        if current == input_dir:
            continue

        song_dir = os.path.dirname(current)
        song = os.path.basename(song_dir)

        rule = next((r for r in rules if r[0] in current), None)
        if rule is None:
            continue
        model, stem, new_name, label = rule

        print(f"Processing {model} in {current}")
        source = os.path.join(current, f"{song}_{stem}.wav")
        if not os.path.exists(source):
            print(f"{label} file not found: {source}")
            continue

        target = os.path.join(song_dir, new_name)
        print(f"Moving {source} to {target}")
        shutil.move(source, target)
def combine_stems_for_all(input_dir):
    """Overlay the four stems of every song folder into ``<song>.MDS.wav``.

    Each directory below *input_dir* must contain ``vocals.wav``,
    ``bass.wav``, ``other.wav`` and ``instrumental.wav``; directories missing
    any of them are reported on stdout and skipped. The vocals segment is the
    base onto which the other three are overlaid, in that order, and the
    result is exported as WAV into the same directory.
    """
    for folder, _dirs, _files in os.walk(input_dir):
        if folder == input_dir:
            continue

        track = os.path.basename(folder)
        print(f"Processing {folder}")

        stem_paths = {
            "vocals": os.path.join(folder, "vocals.wav"),
            "bass": os.path.join(folder, "bass.wav"),
            "others": os.path.join(folder, "other.wav"),
            "instrumental": os.path.join(folder, "instrumental.wav"),
        }
        if any(not os.path.exists(location) for location in stem_paths.values()):
            print(f"Skipping {folder}, not all stems are present.")
            continue

        stems = {}
        for label, location in stem_paths.items():
            stems[label] = AudioSegment.from_file(location)

        mix = stems["vocals"]
        for layer in ("bass", "others", "instrumental"):
            mix = mix.overlay(stems[layer])

        target = os.path.join(folder, f"{track}.MDS.wav")
        mix.export(target, format="wav")
        print(f"Exported combined stems to {target}")
def delete_folders_and_files(input_dir):
    """Remove intermediate model folders and stem files after combining.

    Pass 1 walks *input_dir* bottom-up and, in every directory except
    *input_dir* itself, deletes the per-model folders (``htdemucs``,
    ``mel_band_roformer``, ``scnet``, ``bs_roformer``) and the individual
    stem files (``bass.wav``, ``vocals.wav``, ``other.wav``,
    ``instrumental.wav``).

    Pass 2 walks the module-level ``OUTPUT_FOLDER`` (not *input_dir*) and
    deletes every directory whose name ends in ``_vocals``.

    Every deletion is printed to stdout.
    """
    model_folders = ('htdemucs', 'mel_band_roformer', 'scnet', 'bs_roformer')
    stem_files = ('bass.wav', 'vocals.wav', 'other.wav', 'instrumental.wav')

    for current, _subdirs, _names in os.walk(input_dir, topdown=False):
        if current != input_dir:
            for model in model_folders:
                model_path = os.path.join(current, model)
                if os.path.isdir(model_path):
                    print(f"Deleting folder: {model_path}")
                    shutil.rmtree(model_path)

            for stem in stem_files:
                stem_path = os.path.join(current, stem)
                if os.path.isfile(stem_path):
                    print(f"Deleting file: {stem_path}")
                    os.remove(stem_path)

    for current, subdirs, _names in os.walk(OUTPUT_FOLDER):
        for leftover in (name for name in subdirs if name.endswith('_vocals')):
            leftover_path = os.path.join(current, leftover)
            print(f"Deleting folder: {leftover_path}")
            shutil.rmtree(leftover_path)

    print("Cleanup completed.")
def process_audio(song_title):
    """Run the whole separation pipeline for *song_title*, streaming progress.

    This is a generator so a UI can show each stage as it starts. It expects
    the song's WAV to already be in ``<INPUT_FOLDER>/wav``: the title lookup
    here is done with ``state=False``, which does not download anything.

    Stages: SCNet, Mel Band Roformer (with instrumental) and HTDemucs
    run on the input folder; the Mel Band Roformer instrumental is renamed to
    ``<title>.wav`` and fed to BS Roformer; then the inputs are deleted, the
    stems are moved, combined into ``<title>.MDS.wav`` and intermediates are
    cleaned up.

    Args:
        song_title: Title text entered by the user.

    Yields:
        ``(status_message, audio_path)`` tuples. ``audio_path`` is ``None``
        for every update except the final success, where it is the path of
        the combined ``.MDS.wav`` file. Any exception is logged and reported
        as a final ``(error_message, None)`` update instead of propagating.
    """
    try:
        yield "Finding audio...", None
        # Validate the submitted text itself; the Textbox component object
        # never compares equal to "", so checking it would never fire.
        if not song_title or not song_title.strip():
            raise ValueError("Please enter a song title.")
        formatted_title = download_youtube_audio_by_title(song_title, False)

        input_wav_dir = f"{INPUT_FOLDER}/wav"
        song_dir = f'{OUTPUT_FOLDER}{formatted_title}'
        mel_band_dir = f'{song_dir}/mel_band_roformer'

        yield "Starting SCNet inference...", None
        proc_folder_direct("scnet", "configs/config_scnet_other.yaml", "results/model_scnet_other.ckpt", input_wav_dir, OUTPUT_FOLDER)

        yield "Starting Mel Band Roformer inference...", None
        proc_folder_direct("mel_band_roformer", "configs/config_mel_band_roformer_vocals.yaml", "results/model_mel_band_roformer_vocals.ckpt", input_wav_dir, OUTPUT_FOLDER, extract_instrumental=True)

        yield "Starting HTDemucs inference...", None
        proc_folder_direct("htdemucs", "configs/config_htdemucs_bass.yaml", "results/model_htdemucs_bass.th", input_wav_dir, OUTPUT_FOLDER)

        # Rename the instrumental to "<title>.wav" before the folder is used
        # as BS Roformer's input below.
        source_path = f'{mel_band_dir}/{formatted_title}_instrumental.wav'
        destination_path = f'{mel_band_dir}/{formatted_title}.wav'
        os.rename(source_path, destination_path)

        yield "Starting BS Roformer inference...", None
        proc_folder_direct("bs_roformer", "configs/config_bs_roformer_instrumental.yaml", "results/model_bs_roformer_instrumental.ckpt", mel_band_dir, OUTPUT_FOLDER)

        # The inputs are removed here, not moved, so the status says so.
        yield "Deleting input files...", None
        delete_input_files(INPUT_FOLDER)

        yield "Moving stems to parent...", None
        move_stems_to_parent(OUTPUT_FOLDER)

        yield "Combining stems...", None
        combine_stems_for_all(OUTPUT_FOLDER)

        yield "Cleaning up...", None
        delete_folders_and_files(OUTPUT_FOLDER)

        yield "Audio processing completed successfully.", f'{song_dir}/{formatted_title}.MDS.wav'
    except Exception as e:
        # Top-level boundary for the UI: report the failure in the log box
        # rather than letting the generator die.
        error_msg = f"An error occurred: {str(e)}\n{traceback.format_exc()}"
        logging.error(error_msg)
        yield error_msg, None
# Gradio UI: a title box with a Play button (download + preview) and a
# Process button that runs the separation pipeline and streams its log.
# NOTE(review): the original indentation was lost; the extent of the
# gr.Row() body below (title box + Play button) is reconstructed — confirm.
with gr.Blocks() as demo:
    gr.Markdown("# Music Player and Processor")
    with gr.Row():
        title_input = gr.Textbox(label="Enter Song Title")
        play_button = gr.Button("Play")
    audio_output = gr.Audio(label="Audio Player")
    process_button = gr.Button("Process Audio")
    log_output = gr.Textbox(label="Processing Log", interactive=False)
    processed_audio_output = gr.Audio(label="Processed Audio")
    # Only the title is passed, so `state` keeps its default (True): the
    # song is downloaded and the returned WAV path feeds the audio player.
    play_button.click(
        fn=download_youtube_audio_by_title,
        inputs=title_input,
        outputs=audio_output
    )
    # process_audio is a generator; each yielded (status, audio) pair is
    # mapped onto the log box and the processed-audio player.
    process_button.click(
        fn=process_audio,
        inputs=title_input,
        outputs=[log_output, processed_audio_output],
        show_progress=True
    )
# Runs at import time (there is no __main__ guard).
demo.launch()