Spaces:
Runtime error
Runtime error
| import argparse | |
| from tqdm import tqdm | |
| from utils import load_wav, collect_valentini_paths | |
| from metrics import Metrics | |
| from denoisers.SpectralGating import SpectralGating | |
| PARSERS = { | |
| 'valentini': collect_valentini_paths | |
| } | |
| MODELS = { | |
| 'baseline': SpectralGating | |
| } | |
| def evaluate_on_dataset(model_name, dataset_path, dataset_type): | |
| if model_name is not None: | |
| model = MODELS[model_name]() | |
| parser = PARSERS[dataset_type] | |
| clean_wavs, noisy_wavs = parser(dataset_path) | |
| metrics = Metrics() | |
| mean_scores = {'PESQ': 0, 'STOI': 0} | |
| for clean_path, noisy_path in tqdm(zip(clean_wavs, noisy_wavs), total=len(clean_wavs)): | |
| clean_wav = load_wav(clean_path) | |
| noisy_wav = load_wav(noisy_path) | |
| if model_name is None: | |
| scores = metrics.calculate(denoised=noisy_wav, clean=clean_wav) | |
| else: | |
| denoised_wav = model(noisy_wav) | |
| scores = metrics.calculate(denoised=denoised_wav, clean=clean_wav) | |
| mean_scores['PESQ'] += scores['PESQ'] | |
| mean_scores['STOI'] += scores['STOI'] | |
| mean_scores['PESQ'] = mean_scores['PESQ'].numpy() / len(clean_wavs) | |
| mean_scores['STOI'] = mean_scores['STOI'].numpy() / len(clean_wavs) | |
| return mean_scores | |
| if __name__ == '__main__': | |
| parser = argparse.ArgumentParser(prog='Program to evaluate denoising') | |
| parser.add_argument('--dataset_path', type=str, | |
| default='/media/public/dataset/denoising/DS_10283_2791/', | |
| help='Path to dataset folder') | |
| parser.add_argument('--dataset_type', type=str, required=True, | |
| choices=['valentini']) | |
| parser.add_argument('--model_name', type=str, | |
| choices=['baseline']) | |
| args = parser.parse_args() | |
| mean_scores = evaluate_on_dataset(model_name=args.model_name, | |
| dataset_path=args.dataset_path, | |
| dataset_type=args.dataset_type) | |
| print(f"Metrics on {args.dataset_type} dataset with " | |
| f"{args.model_name if args.model_name is not None else 'ideal denoising'} = {mean_scores}") | |