File size: 7,957 Bytes
92e075b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
"""

Media to WAV Converter Module



Converts various media formats (m4a, mp3, mp4, etc.) to standardized WAV files

and PyTorch tensors for audio transcription pipelines.



Standardization:

- 16kHz sample rate

- Mono channel (merged if multi-channel)

- Layer normalized

- bfloat16 dtype tensor

- Fail-fast error handling

"""

import os
import tempfile
from pathlib import Path
from typing import Tuple, Union, Optional

import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F
from pydub import AudioSegment
from pydub.utils import which


# Constants
# Sample rate in Hz that process_wav_to_standard_format resamples audio to
# and reports back to its caller.
TARGET_SAMPLE_RATE = 16000
# Target dtype for the output tensor (bfloat16).
# NOTE(review): not referenced by any function in the visible part of this
# file - confirm whether it is meant to be used where the dtype is applied.
TARGET_DTYPE = torch.bfloat16


def verify_ffmpeg_installation():
    """Raise if the ``ffmpeg`` executable cannot be located.

    Raises:
        RuntimeError: When pydub's ``which`` lookup finds no ``ffmpeg``.
    """
    ffmpeg_location = which("ffmpeg")
    if ffmpeg_location:
        return

    raise RuntimeError(
        "FFmpeg not found. Please install FFmpeg for media format support. "
        "On Ubuntu: sudo apt install ffmpeg"
    )


def layer_norm(tensor: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    """Normalize an audio tensor to zero mean and unit variance.

    The statistics are computed over every element of ``tensor`` (a single
    global mean and standard deviation), i.e. ``(x - mean) / std``.

    Args:
        tensor: Floating-point tensor to normalize.
        shape: Accepted for interface compatibility; not used by the
            computation.

    Returns:
        A tensor of the same shape as ``tensor``. When the standard deviation
        is zero or undefined (constant signal, or fewer than two elements)
        only the mean is subtracted, so the result never contains NaN from a
        division by a degenerate standard deviation.
    """
    mean = tensor.mean()

    # torch.std() uses the sample (n - 1) estimator, which is NaN for fewer
    # than two elements; NaN != 0 would slip past the zero check below and
    # poison the whole output, so handle that case up front.
    if tensor.numel() < 2:
        return tensor - mean

    std = tensor.std()
    if std == 0:
        return tensor - mean
    return (tensor - mean) / std


def detect_media_format(file_path: str) -> str:
    """Return the media format name implied by the file's extension.

    The lookup is case-insensitive. Extensions that are not in the known
    table are returned as-is without the leading dot (an empty string when
    the path has no extension), leaving it to downstream decoding to report
    anything it cannot handle.
    """
    known_formats = {
        '.wav': 'wav',
        '.mp3': 'mp3',
        '.m4a': 'm4a',
        '.aac': 'aac',
        '.flac': 'flac',
        '.ogg': 'ogg',
        '.wma': 'wma',
        '.mp4': 'mp4',
        '.avi': 'avi',
        '.mov': 'mov',
        '.mkv': 'mkv'
    }

    suffix = Path(file_path).suffix.lower()

    if suffix in known_formats:
        return known_formats[suffix]

    # Unknown extension: hand it through without the dot.
    if suffix.startswith('.'):
        return suffix[1:]
    return suffix


def convert_to_wav_with_pydub(input_path: str, output_path: str, format_hint: str = None):
    """Decode a media file with pydub and write it out as WAV.

    Args:
        input_path: Path of the media file to decode.
        output_path: Path the WAV export is written to.
        format_hint: Format name passed to pydub; when empty or None pydub
            is left to detect the format itself.

    Raises:
        RuntimeError: If FFmpeg is not available.
    """
    verify_ffmpeg_installation()

    # Only forward an explicit format when the caller supplied one.
    load_options = {"format": format_hint} if format_hint else {}
    segment = AudioSegment.from_file(input_path, **load_options)

    # Plain WAV export; resampling and down-mixing happen in a later step.
    segment.export(output_path, format="wav")


def process_wav_to_standard_format(wav_path: str) -> Tuple[np.ndarray, int]:
    """Load a WAV file as mono float32 samples at the target sample rate.

    Args:
        wav_path: Path of the WAV file to load.

    Returns:
        Tuple of (samples, sample_rate): a 1-D float32 array and
        ``TARGET_SAMPLE_RATE``.
    """
    # mono=True is librosa's default and already averages the channels on
    # load, so no separate down-mix step is needed afterwards. sr=None keeps
    # the file's native rate so resampling only happens when required.
    samples, native_rate = librosa.load(wav_path, sr=None, mono=True)

    if native_rate != TARGET_SAMPLE_RATE:
        samples = librosa.resample(
            samples, orig_sr=native_rate, target_sr=TARGET_SAMPLE_RATE
        )

    return np.asarray(samples, dtype=np.float32), TARGET_SAMPLE_RATE


def create_normalized_tensor(audio_data: np.ndarray) -> torch.Tensor:
    """Turn raw audio samples into the normalized bfloat16 tensor.

    The samples are cast to bfloat16, passed through ``layer_norm``, given a
    leading dimension of size 1, and moved to ``cuda:0`` when CUDA is
    available (CPU otherwise).
    """
    if torch.cuda.is_available():
        target_device = torch.device("cuda:0")
    else:
        target_device = torch.device("cpu")

    samples = torch.Tensor(audio_data).to(torch.bfloat16)
    normalized = layer_norm(samples, samples.shape)

    # Add the leading dimension first, then move to the chosen device.
    return normalized.unsqueeze(0).to(target_device)


def convert_media_to_wav(
    input_path: str,
    output_dir: Optional[str] = None,
    keep_temp_wav: bool = True
) -> Tuple[str, torch.Tensor]:
    """
    Convert media file to standardized WAV file and normalized tensor.

    The input is first decoded to WAV with pydub, then reloaded and
    standardized by ``process_wav_to_standard_format``. The WAV on disk is
    finally overwritten with the standardized audio.

    Args:
        input_path: Path to input media file
        output_dir: Directory for output WAV file (default: system temp
            directory). Created if it does not exist.
        keep_temp_wav: Whether to keep the WAV file. When False the file is
            deleted before returning, so the returned path no longer exists.

    Returns:
        Tuple of (wav_file_path, normalized_tensor)

    Raises:
        FileNotFoundError: If input file doesn't exist
        RuntimeError: If FFmpeg is not available

        Decoding errors raised by the underlying libraries are not caught
        here and propagate to the caller.
    """
    # Validate input file
    if not os.path.exists(input_path):
        raise FileNotFoundError(f"Input file not found: {input_path}")

    input_path = os.path.abspath(input_path)

    # Detect format (empty string for extension-less files -> auto-detect)
    media_format = detect_media_format(input_path)

    # Setup output path
    if output_dir is None:
        output_dir = tempfile.gettempdir()
    elif output_dir:
        # A caller-supplied directory may not exist yet.
        os.makedirs(output_dir, exist_ok=True)

    input_name = Path(input_path).stem
    output_wav_path = os.path.join(output_dir, f"{input_name}_converted.wav")

    # Step 1: Decode to WAV using pydub. WAV inputs go through the same path
    # so that every input ends up as a plain WAV export.
    convert_to_wav_with_pydub(input_path, output_wav_path, media_format)

    try:
        # Step 2: Process WAV to standard format using librosa
        audio_data, sample_rate = process_wav_to_standard_format(output_wav_path)

        # Step 3: Create normalized tensor
        normalized_tensor = create_normalized_tensor(audio_data)

        # Step 4: Overwrite the WAV with the processed audio. Pointless when
        # the file is about to be discarded.
        if keep_temp_wav:
            sf.write(output_wav_path, audio_data, sample_rate)
    finally:
        # Honor keep_temp_wav=False even when a later step fails.
        if not keep_temp_wav and os.path.exists(output_wav_path):
            os.remove(output_wav_path)

    return output_wav_path, normalized_tensor


def convert_media_to_wav_from_bytes(
    media_bytes: bytes,
    original_filename: str,
    output_dir: Optional[str] = None
) -> Tuple[str, torch.Tensor]:
    """
    Convert media from bytes to WAV file and tensor.

    The bytes are written to a temporary file that carries the extension of
    ``original_filename`` and then handed to ``convert_media_to_wav``. The
    temporary input file is removed whether or not the conversion succeeds.

    Args:
        media_bytes: Raw media file bytes
        original_filename: Original filename for format detection
        output_dir: Directory for output files

    Returns:
        Tuple of (wav_file_path, normalized_tensor)
    """
    # Keep the original extension so format detection still works.
    input_extension = Path(original_filename).suffix

    # delete=False: the file must outlive the handle so the converter can
    # reopen it by path; cleanup is done explicitly below.
    temp_input = tempfile.NamedTemporaryFile(delete=False, suffix=input_extension)
    temp_input_path = temp_input.name

    try:
        with temp_input:
            temp_input.write(media_bytes)

        # Convert using the main function
        return convert_media_to_wav(temp_input_path, output_dir)
    finally:
        # Always clean up the temporary input, including on failure.
        os.unlink(temp_input_path)


# Utility function for getting audio info
def get_media_info(file_path: str) -> dict:
    """Collect basic properties of a media file.

    The file is loaded with pydub (format auto-detected) and the result
    reports its duration in seconds, frame rate, channel count, sample width
    and the format derived from the file extension.

    Raises:
        RuntimeError: If FFmpeg is not available.
    """
    verify_ffmpeg_installation()

    segment = AudioSegment.from_file(file_path)

    info = {}
    # len() of the loaded segment is divided by 1000 to report seconds.
    info["duration_seconds"] = len(segment) / 1000.0
    info["frame_rate"] = segment.frame_rate
    info["channels"] = segment.channels
    info["sample_width"] = segment.sample_width
    info["format"] = detect_media_format(file_path)
    return info


if __name__ == "__main__":
    # Example usage
    import sys

    if len(sys.argv) != 2:
        print("Usage: python convert_media_to_wav.py <input_file>")
        sys.exit(1)

    input_file = sys.argv[1]

    print(f"Converting {input_file}...")
    wav_path, tensor = convert_media_to_wav(input_file)

    print(f"βœ“ WAV file: {wav_path}")
    print(f"βœ“ Tensor shape: {tensor.shape}")
    print(f"βœ“ Tensor dtype: {tensor.dtype}")
    print(f"βœ“ Tensor device: {tensor.device}")

    # Show media info
    info = get_media_info(input_file)
    print(f"βœ“ Media info: {info}")