Spaces:
Sleeping
Sleeping
File size: 2,771 Bytes
92e075b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import io
import numpy as np
import soundfile as sf
import torch
from numpy.typing import NDArray
# def wav_to_bytes(
# wav: torch.Tensor | NDArray, sample_rate: int = 16_000, format: str = "wav"
# ) -> NDArray[np.int8]:
# """Convert audio tensor to bytes using soundfile directly."""
# # Convert to numpy if torch tensor
# if isinstance(wav, torch.Tensor):
# if wav.is_cuda:
# wav = wav.cpu()
# # Convert to float32 first (numpy doesn't support bfloat16)
# if wav.dtype != torch.float32:
# wav = wav.float()
# wav = wav.numpy()
# # Ensure float32 dtype for numpy arrays
# if wav.dtype != np.float32:
# wav = wav.astype(np.float32)
# # Handle shape: soundfile expects (samples,) for mono or (samples, channels) for multi-channel
# if wav.ndim == 1:
# # Already correct shape for mono
# pass
# elif wav.ndim == 2:
# # If shape is (channels, samples), transpose to (samples, channels)
# if wav.shape[0] < wav.shape[1]:
# wav = wav.T
# # Create buffer and write using soundfile directly
# buffer = io.BytesIO()
# # Map format string to soundfile format
# sf_format = format.upper() if format.lower() in ['wav', 'flac', 'ogg'] else 'WAV'
# subtype = 'PCM_16' if sf_format == 'WAV' else None
# # Write to buffer
# sf.write(buffer, wav, sample_rate, format=sf_format, subtype=subtype)
# buffer.seek(0)
# return np.frombuffer(buffer.getvalue(), dtype=np.int8)
# # return buffer.read()
def wav_to_bytes(wav: torch.Tensor | np.ndarray, sample_rate: int = 16000, format: str = "wav"):
"""Convert audio tensor to bytes using soundfile directly (safe + dtype fix)."""
# ✅ Convert to numpy if torch tensor
if isinstance(wav, torch.Tensor):
wav = wav.detach().cpu()
if wav.dtype == torch.bfloat16:
wav = wav.to(torch.float32) # FIX: convert unsupported dtype
elif wav.dtype != torch.float32:
wav = wav.float()
wav = wav.numpy()
# ✅ Handle empty or multi-dim cases
if wav.ndim > 1:
wav = wav.squeeze()
if wav.size == 0:
raise ValueError("Empty audio segment passed to wav_to_bytes")
# ✅ Ensure valid range and dtype
wav = wav.astype(np.float32)
wav = np.nan_to_num(np.clip(wav, -1.0, 1.0))
buffer = io.BytesIO()
try:
sf.write(buffer, wav, sample_rate, format="WAV", subtype="PCM_16")
except Exception as e:
print(f"[ERROR] soundfile write failed: {e}")
raise
buffer.seek(0)
return np.frombuffer(buffer.getvalue(), dtype=np.int8)
|