# Spaces: Running on Zero
import argparse
import glob
import os
import sys
import time
import warnings

import librosa
from tqdm import tqdm
import torch
import numpy as np
import soundfile as sf
import torch.nn as nn

# Append this file's own directory to the import path so that ``utils``
# (presumably the sibling utils.py) can be imported below.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
from utils import demix_track, demix_track_demucs, get_model_from_config  # noqa: E402

# Suppress every warning from here on (installed after the imports above).
warnings.filterwarnings("ignore")
# Test-time augmentations used when ``args.use_tta`` is set, in order: identity,
# channel swap (reverses axis 0) and polarity inversion. Each one is its own
# inverse, so the same function both builds the augmented input and maps the
# model's per-stem output back to the original domain before averaging.
_TTA_TRANSFORMS = (
    lambda x: x,
    lambda x: x[::-1],
    lambda x: -x,
)

# Tracks whose mono standard deviation is at or below this value (silence or a
# constant signal) are not normalized: dividing by ~0 would yield inf/NaN input.
_NORMALIZE_EPS = 1e-8


def _average_tta_results(results):
    """Undo each TTA transform on its prediction and average the stems.

    ``results[i]`` is the stem -> array dict predicted for the input transformed
    by ``_TTA_TRANSFORMS[i]``; all dicts are expected to share the keys of
    ``results[0]``. Returns a new dict of averaged arrays.
    """
    averaged = {}
    for stem, total in results[0].items():
        for undo, prediction in zip(_TTA_TRANSFORMS[1:], results[1:]):
            total = total + undo(prediction[stem])
        averaged[stem] = total / len(results)
    return averaged


def _write_audio(base_path, data, sr, args):
    """Write ``data`` to ``base_path`` + '.flac' (PCM_16/PCM_24) or + '.wav' (FLOAT)."""
    if args.flac_file:
        subtype = 'PCM_16' if args.pcm_type == 'PCM_16' else 'PCM_24'
        sf.write(base_path + '.flac', data, sr, subtype=subtype)
    else:
        sf.write(base_path + '.wav', data, sr, subtype='FLOAT')


def run_folder(model, args, config, device, verbose=False):
    """Separate every audio file in ``args.input_folder`` into stems and save them.

    Each regular file matching ``*.*`` (sorted by path) is loaded at 44.1 kHz
    without downmixing, duplicated to two channels if mono, optionally
    normalized, run through the model (once per TTA transform when
    ``args.use_tta`` is set, then averaged) and written to
    ``<store_dir>/<track name>/<model_type>/<track name>_<instrument>.wav|.flac``.
    Files that cannot be read are reported and skipped.

    Args:
        model: separation model; switched to eval mode here.
        args: namespace read for ``input_folder``, ``store_dir``, ``model_type``,
            ``use_tta``, ``disable_detailed_pbar``, ``flac_file``, ``pcm_type``
            and ``extract_instrumental``.
        config: model config; reads ``training.instruments``,
            ``training.target_instrument`` and ``inference['normalize']``.
        device: forwarded to ``demix_track`` / ``demix_track_demucs``.
        verbose: if False, show a tqdm progress bar over the files.
    """
    start_time = time.time()
    model.eval()

    # Only regular files: a directory whose name contains a dot also matches '*.*'.
    tracks = sorted(
        p for p in glob.glob(os.path.join(args.input_folder, '*.*')) if os.path.isfile(p)
    )
    print('Total files found: {}'.format(len(tracks)))

    instruments = config.training.instruments
    if config.training.target_instrument is not None:
        instruments = [config.training.target_instrument]

    # makedirs (unlike mkdir) also creates missing parent directories.
    os.makedirs(args.store_dir, exist_ok=True)

    normalize = 'normalize' in config.inference and config.inference['normalize'] is True
    detailed_pbar = not args.disable_detailed_pbar
    transforms = _TTA_TRANSFORMS if args.use_tta else _TTA_TRANSFORMS[:1]

    progress = tracks if verbose else tqdm(tracks, desc="Total progress")
    for path in progress:
        print("Starting processing track: ", path)
        if not verbose:
            progress.set_postfix({'track': os.path.basename(path)})
        try:
            mix, sr = librosa.load(path, sr=44100, mono=False)
        except Exception as e:
            print('Cannot read track: {}'.format(path))
            print('Error message: {}'.format(str(e)))
            continue

        # Convert mono to stereo if needed
        if mix.ndim == 1:
            mix = np.stack([mix, mix], axis=0)
        mix_orig = mix.copy()

        track_normalized = False
        if normalize:
            mono = mix.mean(0)
            mean = mono.mean()
            std = mono.std()
            if std > _NORMALIZE_EPS:
                mix = (mix - mean) / std
                track_normalized = True

        full_result = []
        for transform in transforms:
            # ascontiguousarray: the channel swap is a negative-stride view.
            mixture = torch.tensor(np.ascontiguousarray(transform(mix)), dtype=torch.float32)
            if args.model_type == 'htdemucs':
                waveforms = demix_track_demucs(config, model, mixture, device, pbar=detailed_pbar)
            else:
                waveforms = demix_track(config, model, mixture, device, pbar=detailed_pbar)
            full_result.append(waveforms)
        waveforms = _average_tta_results(full_result)

        file_name = os.path.splitext(os.path.basename(path))[0]
        model_dir = os.path.join(args.store_dir, file_name, args.model_type)
        os.makedirs(model_dir, exist_ok=True)

        for instr in instruments:
            estimates = waveforms[instr].T
            if track_normalized:
                estimates = estimates * std + mean
            _write_audio(os.path.join(model_dir, f"{file_name}_{instr}"), estimates, sr, args)

        # "instrumental" = original mix minus the 'vocals' stem (or minus the
        # first instrument when 'vocals' is not among the outputs).
        if args.extract_instrumental:
            reference = 'vocals' if 'vocals' in instruments else instruments[0]
            estimates = waveforms[reference].T
            if track_normalized:
                estimates = estimates * std + mean
            _write_audio(os.path.join(model_dir, f"{file_name}_instrumental"),
                         mix_orig.T - estimates, sr, args)

    print("Elapsed time: {:.2f} sec".format(time.time() - start_time))
def _select_device(force_cpu, device_ids):
    """Pick the inference device: CPU if forced, else CUDA, else MPS, else CPU.

    On CUDA the device index is ``device_ids`` itself when it is an int,
    otherwise its first element.
    """
    if force_cpu:
        return "cpu"
    if torch.cuda.is_available():
        print('CUDA is available, use --force_cpu to disable it.')
        first_id = device_ids if isinstance(device_ids, int) else device_ids[0]
        return f'cuda:{first_id}'
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def proc_folder_direct(model_type, config_path, start_check_point, input_folder, store_dir,
                       device_ids=None, extract_instrumental=False, disable_detailed_pbar=False,
                       force_cpu=False, flac_file=False, pcm_type='PCM_24', use_tta=False):
    """Build a model, load its checkpoint and separate every file in ``input_folder``.

    Args:
        model_type: architecture name passed to ``get_model_from_config``;
            ``'htdemucs'`` checkpoints are loaded with ``weights_only=False``.
        config_path: path of the model config file.
        start_check_point: checkpoint path; empty string or None keeps the
            weights returned by ``get_model_from_config``.
        input_folder: folder whose files are separated.
        store_dir: root folder for the separated stems.
        device_ids: one CUDA index (int) or a list of indices (default ``[0]``).
            A list wraps the model in ``nn.DataParallel`` when running on CUDA.
        extract_instrumental, disable_detailed_pbar, flac_file, pcm_type, use_tta:
            forwarded unchanged to ``run_folder``.
        force_cpu: run on CPU even if CUDA/MPS is available.
    """
    if device_ids is None:
        device_ids = [0]  # avoid a shared mutable default argument
    device = _select_device(force_cpu, device_ids)
    print("Using device: ", device)

    model_load_start_time = time.time()
    torch.backends.cudnn.benchmark = True
    model, config = get_model_from_config(model_type, config_path)
    if start_check_point:
        print('Start from checkpoint: {}'.format(start_check_point))
        if model_type == 'htdemucs':
            # SECURITY NOTE(review): weights_only=False unpickles arbitrary
            # objects and can execute code stored in the file; only load
            # trusted htdemucs checkpoints.
            state_dict = torch.load(start_check_point, map_location=device, weights_only=False)
            # The checkpoint may wrap the weights under a 'state' key.
            if 'state' in state_dict:
                state_dict = state_dict['state']
        else:
            state_dict = torch.load(start_check_point, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
    print("Instruments: {}".format(config.training.instruments))

    # DataParallel targets CUDA devices only: wrapping a CPU-resident model on a
    # CUDA host (e.g. force_cpu=True) makes its forward() fail the device check.
    if not isinstance(device_ids, int) and device.startswith('cuda'):
        model = nn.DataParallel(model, device_ids=device_ids)
    model = model.to(device)
    print("Model load time: {:.2f} sec".format(time.time() - model_load_start_time))

    args = argparse.Namespace(
        model_type=model_type,
        config_path=config_path,
        start_check_point=start_check_point,
        input_folder=input_folder,
        store_dir=store_dir,
        device_ids=device_ids,
        extract_instrumental=extract_instrumental,
        disable_detailed_pbar=disable_detailed_pbar,
        force_cpu=force_cpu,
        flac_file=flac_file,
        pcm_type=pcm_type,
        use_tta=use_tta,
    )
    run_folder(model, args, config, device, verbose=True)