|
|
"""Preprocessing and normalization to prepare audio for Kintsugi Depression and Anxiety model.""" |
|
|
from typing import Union, BinaryIO |
|
|
import numpy as np |
|
|
import os |
|
|
import torch |
|
|
import torchaudio |
|
|
from transformers import AutoFeatureExtractor |
|
|
|
|
|
from config import EXPECTED_SAMPLE_RATE, logmel_energies |
|
|
|
|
|
|
|
|
def load_audio(source: Union[BinaryIO, str, os.PathLike]) -> torch.Tensor:
    """Read an audio file into a mono waveform tensor at the expected sample rate.

    Parameters
    ----------
    source: open file or path to file

    Returns
    -------
    Time domain audio samples as a 1 x num_samples float tensor sampled at 16 kHz.

    Raises
    ------
    ValueError
        If the recording contains more than one channel.
    """
    waveform, sample_rate = torchaudio.load(source)
    num_channels = waveform.shape[0]
    if num_channels != 1:
        raise ValueError(f"Provided audio has {num_channels} != 1 channels.")
    # Only resample when the file's native rate differs from the model's rate.
    if sample_rate == EXPECTED_SAMPLE_RATE:
        return waveform
    return torchaudio.functional.resample(waveform, sample_rate, EXPECTED_SAMPLE_RATE)
|
|
|
|
|
|
|
|
class Preprocessor:
    def __init__(self,
                 normalize_features: bool = True,
                 chunk_seconds: int = 30,
                 max_overlap_frac: float = 0.0,
                 pad_last_chunk_to_full: bool = True,
                 ):
        """Create preprocessor object.

        Parameters
        ----------
        normalize_features: Whether the Whisper preprocessor should normalize features
        chunk_seconds: Size of model's receptive field in seconds
        max_overlap_frac: Fraction of each chunk allowed to overlap previous chunk
            for inputs longer than chunk_seconds. Must be in [0, 1): a value of 1
            or more would imply a non-positive hop between windows.
        pad_last_chunk_to_full: Whether to pad audio to an integer multiple of chunk_seconds

        Raises
        ------
        ValueError
            If max_overlap_frac is outside [0, 1).
        """
        # Validate up front: a fraction >= 1 makes the hop size computed in
        # preprocess_with_audio_normalization zero (or negative) and previously
        # surfaced there as an opaque ZeroDivisionError.
        if not 0.0 <= max_overlap_frac < 1.0:
            raise ValueError(
                f"max_overlap_frac must be in [0, 1), got {max_overlap_frac}."
            )
        self.preprocessor = AutoFeatureExtractor.from_pretrained("openai/whisper-small.en")
        self.normalize_features = normalize_features
        self.chunk_seconds = chunk_seconds
        self.max_overlap_frac = max_overlap_frac
        self.pad_last_chunk_to_full = pad_last_chunk_to_full

    def preprocess_with_audio_normalization(
        self,
        audio: torch.Tensor,
    ) -> torch.Tensor:
        """Run Whisper preprocessor and normalization expected by the model.

        Note: some normalization steps can be avoided, but are included to match
        feature extraction used during training.

        Parameters
        ----------
        audio: Raw audio samples as a 1 x num_samples float tensor sampled at 16 kHz

        Returns
        -------
        Normalized mel filter bank features as a float tensor of shape
        num_chunks x 80 mel filter bands x 3000 time frames

        Raises
        ------
        ValueError
            If overlapping windows are requested together with padding, or if
            the feature extractor output exposes no recognized feature tensor.
        """
        # Drop the channel dimension, then normalize to zero mean / unit peak.
        # NOTE(review): an all-zero (silent) input produces NaNs here (0 / 0);
        # callers presumably supply non-silent audio — confirm upstream.
        audio = torch.squeeze(audio, 0)
        audio = audio - torch.mean(audio)
        audio = audio / torch.max(torch.abs(audio))

        chunk_samples = EXPECTED_SAMPLE_RATE * self.chunk_seconds

        if self.pad_last_chunk_to_full:
            if self.max_overlap_frac > 0:
                raise ValueError(
                    "pad_last_chunk_to_full is only supported for non-overlapping windows"
                )
            # Pad up to the next multiple of chunk_samples. Modular arithmetic
            # gives the exact pad length without a float round-trip through
            # np.ceil (which the previous implementation used).
            pad_size = (-len(audio)) % chunk_samples
            audio = torch.nn.functional.pad(audio, (0, pad_size))

        overflow_len = len(audio) - chunk_samples

        # Smallest hop that respects the overlap budget; strictly positive
        # because max_overlap_frac < 1 is enforced in __init__.
        min_hop_samples = int(
            (1 - self.max_overlap_frac) * chunk_samples
        )

        n_windows = 1 + overflow_len // min_hop_samples
        # Spread window starts evenly over [0, overflow_len]. For audio no
        # longer than one chunk (overflow_len <= 0) this yields a single
        # window starting at sample 0.
        window_starts = np.linspace(0, overflow_len, max(n_windows, 1)).astype(int)

        features = self.preprocessor(
            [
                audio[start : start + chunk_samples].numpy(force=True)
                for start in window_starts
            ],
            return_tensors="pt",
            sampling_rate=EXPECTED_SAMPLE_RATE,
            do_normalize=self.normalize_features,
        )
        # Different extractor versions expose the tensor under different
        # attribute names; fail loudly if neither is present rather than
        # letting the raw BatchFeature flow into the tensor math below.
        for key in ("input_features", "input_values"):
            if hasattr(features, key):
                features = getattr(features, key)
                break
        else:
            raise ValueError(
                "Feature extractor output has neither 'input_features' nor 'input_values'."
            )

        # Shift each chunk's per-band log-mel features so their time-mean
        # matches the reference energies from config (normalization used
        # during training, per the note above).
        mean_features = torch.mean(features, dim=-1)

        rescale_factor = logmel_energies.unsqueeze(0) - mean_features
        rescale_factor = rescale_factor.unsqueeze(2)
        features += rescale_factor
        return features
|
|
|