File size: 2,076 Bytes
af11ce4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import logging

import librosa
import soundfile as sf
import torch


def load_waveform(
    fname: str,
    sample_rate: int,
    dtype: str = "float32",
    device: torch.device = torch.device("cpu"),
    return_numpy: bool = False,
    max_seconds: float = None,
) -> torch.Tensor:
    """
    Load an audio file, downmix/resample/truncate it, and return the samples.

    Args:
        fname (str): Path to the audio file (passed to ``soundfile.read``).
        sample_rate (int): Target sample rate for resampling.
        dtype (str, optional): Data type to load audio as (default: "float32").
        device (torch.device, optional): Device to place the resulting tensor
            on (default: CPU). Ignored when ``return_numpy`` is True.
        return_numpy (bool): If True, returns a NumPy array instead of a
            PyTorch tensor.
        max_seconds (float, optional): Maximum length (seconds) of the
            returned audio. If the audio is longer than this, it is truncated
            and a warning is logged. ``None`` (default) disables truncation.
            Expected to be non-negative.

    Returns:
        torch.Tensor: Processed audio waveform on ``device`` with shape
            (num_samples,), or a NumPy array of the same shape when
            ``return_numpy`` is True.

    Notes:
        - If the audio is stereo, it will be converted to mono by averaging channels.
        - If the audio's sample rate differs from the target, it will be resampled.
    """
    # Load audio file with specified data type
    wav_data, sr = sf.read(fname, dtype=dtype)

    # Multi-channel data comes back 2-D; average over the channel axis (axis 1)
    if wav_data.ndim == 2:
        wav_data = wav_data.mean(1)

    # Resample to target sample rate if needed
    if sr != sample_rate:
        wav_data = librosa.resample(wav_data, orig_sr=sr, target_sr=sample_rate)

    if max_seconds is not None:
        # The sample budget must be an int: max_seconds may be fractional and
        # a float cannot be used as a slice bound.
        max_length = int(sample_rate * max_seconds)
        if len(wav_data) > max_length:
            wav_data = wav_data[:max_length]
            logging.warning(
                "Wav file %s is longer than %s seconds, "
                "truncated to %s seconds to avoid OOM.",
                fname,
                max_seconds,
                max_seconds,
            )

    if return_numpy:
        return wav_data
    return torch.from_numpy(wav_data).to(device)