Spaces:
Runtime error
Runtime error
| import pesq | |
| from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality | |
| from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility | |
| from torchaudio.transforms import Resample | |
| import torch | |
| import torchaudio | |
| from torchmetrics import SignalNoiseRatio | |
class Metrics(torch.nn.Module):
    """Sum PESQ and STOI scores over pairs of denoised and clean waveforms.

    Each pair is resampled from ``source_rate`` to ``target_rate`` when the two
    rates differ, then scored with the torchmetrics PESQ and STOI metrics.
    """

    def __init__(self, source_rate, target_rate=16000, *args, **kwargs):
        """Build the resampler and the metric objects.

        Args:
            source_rate: Sample rate of the waveforms passed to ``forward``.
            target_rate: Sample rate the metrics are configured with.
            *args: Forwarded to ``torch.nn.Module.__init__``.
            **kwargs: Forwarded to ``torch.nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.source_rate = source_rate
        self.target_rate = target_rate
        self.resampler = Resample(orig_freq=source_rate, new_freq=target_rate)
        # NOTE: despite the ``nb_`` prefix, this metric is configured with
        # mode 'wb'. The attribute name is kept so existing users of it do
        # not break.
        self.nb_pesq = PerceptualEvaluationSpeechQuality(target_rate, 'wb')
        self.stoi = ShortTimeObjectiveIntelligibility(target_rate, False)
        # Not used by forward(); kept for code that reads the attribute.
        self.snr = SignalNoiseRatio()

    def forward(self, denoised, clean):
        """Score every (denoised, clean) pair and return the summed scores.

        A pair is counted only when BOTH metrics succeed, so the 'PESQ' and
        'STOI' totals always cover the same set of pairs. A pair for which
        either metric raises ``ValueError`` or ``pesq.NoUtterancesError`` is
        skipped and the error is printed.

        Args:
            denoised: Iterable of denoised waveforms (presumably a batch
                tensor iterated along its first dimension; confirm with the
                caller).
            clean: Iterable of reference waveforms, paired with ``denoised``
                by position.

        Returns:
            dict with keys 'PESQ' and 'STOI' holding the SUM (not the mean)
            of the per-pair scores; both are ``0`` when no pair was scored.
        """
        pesq_total, stoi_total = 0, 0
        # The rates are fixed per instance, so decide once, outside the loop.
        needs_resampling = self.source_rate != self.target_rate
        for denoised_wav, clean_wav in zip(denoised, clean):
            if needs_resampling:
                denoised_wav = self.resampler(denoised_wav)
                clean_wav = self.resampler(clean_wav)
            try:
                # Compute both scores before touching the totals: if the
                # second metric fails, the first must not be counted either.
                pesq_score = self.nb_pesq(denoised_wav, clean_wav).item()
                stoi_score = self.stoi(denoised_wav, clean_wav).item()
            except ValueError as e:
                print(e)
            except pesq.NoUtterancesError as e:
                print(e)
            else:
                pesq_total += pesq_score
                stoi_total += stoi_score
        return {'PESQ': pesq_total,
                'STOI': stoi_total}