NOI_3_ZIP / utils.py
hynt's picture
Update utils.py
79fbe3f
raw
history blame
6.95 kB
import hashlib
import re
import tempfile

import librosa
import matplotlib.pylab as plt
import numpy as np
import torch
from pydub import AudioSegment, silence
from transformers import pipeline
# Cache of ASR transcriptions, keyed by the MD5 hex digest of the processed
# reference-audio file (filled in by preprocess_ref_audio_text).
_ref_audio_cache = {}
# Shared ASR pipeline; stays None until initialize_asr_pipeline() assigns it.
asr_pipe = None
def chunk_text(text, max_chars=135):
    """Split ``text`` into chunks along sentence and clause boundaries.

    The text is first split into sentences on ``". "``. Sentences shorter than
    4 words are glued (with ``", "``) onto a neighbouring sentence. Each
    sentence is then split on ``", "`` and its clauses are accumulated until
    the running total exceeds 20 words, at which point the accumulated clauses
    are emitted as one chunk. Finally, a trailing chunk shorter than 4 words
    is merged into the chunk before it.

    Args:
        text: The input string to split.
        max_chars: Accepted for interface compatibility; the current
            implementation sizes chunks by word count and does not read it.

    Returns:
        A list of chunk strings; an empty list when ``text`` contains no
        non-blank sentence.
    """
    min_words = 4
    max_words = 20

    # Step 1: split into sentences on ". ".
    sentences = [s.strip() for s in text.split('. ') if s.strip()]

    # Merge sentences shorter than `min_words` words into an adjacent sentence.
    i = 0
    while i < len(sentences):
        if len(sentences[i].split()) >= min_words:
            i += 1
        elif i > 0:
            # Append to the previous sentence, then re-examine the merged one.
            sentences[i - 1] = sentences[i - 1] + ', ' + sentences[i]
            del sentences[i]
            i -= 1
        elif i + 1 < len(sentences):
            # First sentence: prepend it to the following sentence.
            sentences[i + 1] = sentences[i] + ', ' + sentences[i + 1]
            del sentences[i]
        else:
            # A lone short sentence has no neighbour to merge with; move on
            # so the loop terminates instead of spinning forever.
            i += 1

    # Step 2: break over-long sentences at ", " boundaries.
    final_sentences = []
    for sentence in sentences:
        buffer = []
        word_count = 0
        for part in (p.strip() for p in sentence.split(', ')):
            buffer.append(part)
            word_count += len(part.split())
            if word_count > max_words:
                # Emit the accumulated clauses as one chunk and start afresh.
                final_sentences.append(', '.join(buffer))
                buffer = []
                word_count = 0
        if buffer:
            final_sentences.append(', '.join(buffer))

    # Fold a too-short final chunk into its predecessor. The length check
    # comes first so that an empty result does not raise IndexError.
    if len(final_sentences) >= 2 and len(final_sentences[-1].split()) < min_words:
        final_sentences[-2] = final_sentences[-2] + ", " + final_sentences[-1]
        final_sentences = final_sentences[:-1]

    return final_sentences
def initialize_asr_pipeline(device="cuda", dtype=None):
    """Create the speech-recognition pipeline and store it in ``asr_pipe``.

    Args:
        device: Device passed to the pipeline. It is also tested with
            ``"cuda" in device`` when ``dtype`` is not given.
        dtype: Torch dtype for the model. When ``None``, ``torch.float16`` is
            used if ``device`` contains ``"cuda"``, the device properties
            report ``major >= 6`` and the device name does not end with
            ``"[ZLUDA]"``; otherwise ``torch.float32`` is used.

    Returns:
        None. The pipeline is assigned to the module-level ``asr_pipe``.
    """
    global asr_pipe

    if dtype is None:
        # Half precision only on CUDA devices that pass both checks below.
        half_precision_ok = (
            "cuda" in device
            and torch.cuda.get_device_properties(device).major >= 6
            and not torch.cuda.get_device_name().endswith("[ZLUDA]")
        )
        dtype = torch.float16 if half_precision_ok else torch.float32

    asr_pipe = pipeline(
        "automatic-speech-recognition",
        model="vinai/PhoWhisper-medium",
        torch_dtype=dtype,
        device=device,
    )
# transcribe
def transcribe(ref_audio, language=None):
    """Transcribe ``ref_audio`` with the shared ASR pipeline.

    The pipeline is created on first use. It is placed on CUDA when
    ``torch.cuda.is_available()`` reports a usable device and on CPU
    otherwise, so the call also works on hosts without a GPU.

    Args:
        ref_audio: Audio input forwarded unchanged to the ASR pipeline.
        language: Optional language passed through ``generate_kwargs``; when
            falsy, no language is passed.

    Returns:
        The ``"text"`` field of the pipeline result with surrounding
        whitespace stripped.
    """
    global asr_pipe
    if asr_pipe is None:
        # Lazy initialisation; avoid hard-coding CUDA on CPU-only hosts.
        initialize_asr_pipeline(device="cuda" if torch.cuda.is_available() else "cpu")

    generate_kwargs = {"task": "transcribe"}
    if language:
        generate_kwargs["language"] = language

    result = asr_pipe(
        ref_audio,
        chunk_length_s=30,
        batch_size=128,
        generate_kwargs=generate_kwargs,
        return_timestamps=False,
    )
    return result["text"].strip()
def caculate_spec(audio, n_fft=512, hop_length=256, win_length=512):
    """Return the dB-scaled magnitude spectrogram of ``audio``.

    Relies on ``numpy`` being imported as ``np`` at module level.

    Args:
        audio: Signal passed to ``librosa.stft``.
        n_fft: FFT size for the short-time Fourier transform.
        hop_length: Hop between successive STFT frames.
        win_length: Window length of each STFT frame.

    Returns:
        The STFT magnitude converted with ``librosa.amplitude_to_db`` using
        the maximum magnitude as the reference.
    """
    # Short-time Fourier transform -> magnitude -> decibels.
    stft = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
    magnitude = np.abs(stft)
    return librosa.amplitude_to_db(magnitude, ref=np.max)
def save_spectrogram(audio, path):
    """Render the spectrogram of ``audio`` and save the figure to ``path``.

    Args:
        audio: Signal handed to ``caculate_spec``.
        path: Destination passed to ``plt.savefig``.
    """
    spec_db = caculate_spec(audio)

    figure = plt.figure(figsize=(12, 4))
    plt.imshow(spec_db, aspect="auto", origin="lower")
    plt.colorbar()
    plt.savefig(path)
    # Close the figure created above so it is not kept open after saving.
    plt.close(figure)
def remove_silence_edges(audio, silence_threshold=-42):
    """Trim leading and trailing silence from ``audio``.

    Args:
        audio: A pydub ``AudioSegment``.
        silence_threshold: Level in dBFS; a one-millisecond slice whose
            ``dBFS`` is not above this value counts as silent.

    Returns:
        The segment with silent edges removed.
    """
    # Leading edge: pydub reports where the first non-silent audio starts.
    lead_ms = silence.detect_leading_silence(audio, silence_threshold=silence_threshold)
    audio = audio[lead_ms:]

    # Trailing edge: walk backwards over the one-millisecond slices and count
    # in whole milliseconds. Integer arithmetic avoids the rounding drift of
    # repeatedly subtracting 0.001 from a float duration, which could trim one
    # millisecond too many after truncation.
    end_ms = len(audio)
    for ms_slice in reversed(audio):
        if ms_slice.dBFS > silence_threshold:
            break
        end_ms -= 1

    return audio[:end_ms]
def _collect_until_limit(segments, show_info, notice):
    """Concatenate ``segments`` in order, stopping before the 15 s limit.

    A segment is rejected (and ``notice`` is reported through ``show_info``)
    once more than 6 s has been collected and adding it would exceed 15 s.
    """
    collected = AudioSegment.silent(duration=0)
    for segment in segments:
        if len(collected) > 6000 and len(collected + segment) > 15000:
            show_info(notice)
            break
        collected += segment
    return collected


def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print, device="cuda"):
    """Prepare a reference audio clip and its transcript.

    The audio is loaded, optionally shortened to at most 15 s, stripped of
    silent edges, padded with 50 ms of silence and exported to a temporary WAV
    file. The file is created with ``delete=False`` and is not removed here.
    If ``ref_text`` is blank, the transcript comes from the cache (keyed by
    the MD5 of the exported file) or from ``transcribe``. The text is finally
    normalised to end with ``". "`` unless it already ends with ``"。"``.

    Args:
        ref_audio_orig: Audio source passed to ``AudioSegment.from_file``.
        ref_text: Reference transcript; blank means "transcribe it".
        clip_short: Whether to shorten audio longer than 15 s.
        show_info: Callable used for progress messages.
        device: Accepted but not used by this function.

    Returns:
        A ``(ref_audio, ref_text)`` tuple: the path of the exported WAV file
        and the normalised reference text.
    """
    show_info("Converting audio...")

    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_wav:
        aseg = AudioSegment.from_file(ref_audio_orig)

        if clip_short:
            # 1. Try to clip at long silences.
            clipped = _collect_until_limit(
                silence.split_on_silence(
                    aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
                ),
                show_info,
                "Audio is over 15s, clipping short. (1)",
            )

            # 2. If that still leaves more than 15 s, clip at short silences.
            if len(clipped) > 15000:
                clipped = _collect_until_limit(
                    silence.split_on_silence(
                        aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
                    ),
                    show_info,
                    "Audio is over 15s, clipping short. (2)",
                )

            aseg = clipped

            # 3. No suitable silence found: hard cut at 15 s.
            if len(aseg) > 15000:
                aseg = aseg[:15000]
                show_info("Audio is over 15s, clipping short. (3)")

        aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
        aseg.export(tmp_wav.name, format="wav")
        ref_audio = tmp_wav.name

    # Hash the exported file; the digest is the transcription-cache key.
    with open(ref_audio, "rb") as audio_file:
        audio_hash = hashlib.md5(audio_file.read()).hexdigest()

    if ref_text.strip():
        show_info("Using custom reference text...")
    elif audio_hash in _ref_audio_cache:
        # Reuse an earlier ASR transcription of the same audio.
        show_info("Using cached reference text...")
        ref_text = _ref_audio_cache[audio_hash]
    else:
        show_info("No reference text provided, transcribing reference audio...")
        ref_text = transcribe(ref_audio)
        # Only ASR output is cached, so user-supplied text can still be tweaked.
        _ref_audio_cache[audio_hash] = ref_text

    # Make sure the text ends with sentence-ending punctuation.
    if not ref_text.endswith((". ", "。")):
        ref_text += " " if ref_text.endswith(".") else ". "

    print("\nref_text ", ref_text)
    return ref_audio, ref_text