import hashlib
import tempfile

import librosa
import matplotlib.pyplot as plt
import numpy as np
import torch
from pydub import AudioSegment, silence
from transformers import pipeline

# Module-level state referenced below; the CPU fallback for `device` is an assumption
device = "cuda" if torch.cuda.is_available() else "cpu"
asr_pipe = None  # lazily initialized ASR pipeline
_ref_audio_cache = {}  # audio hash -> transcribed reference text


def initialize_asr_pipeline(device: str = device, dtype=None):
    if dtype is None:
        dtype = (
            torch.float16
            if "cuda" in device
            and torch.cuda.get_device_properties(device).major >= 6
            and not torch.cuda.get_device_name().endswith("[ZLUDA]")
            else torch.float32
        )
    global asr_pipe
    asr_pipe = pipeline(
        "automatic-speech-recognition",
        model="vinai/PhoWhisper-medium",
        torch_dtype=dtype,
        device=device,
    )


def transcribe(ref_audio, language=None):
    # Transcribe the reference audio, initializing the ASR pipeline on first use
    global asr_pipe
    if asr_pipe is None:
        initialize_asr_pipeline(device=device)
    return asr_pipe(
        ref_audio,
        chunk_length_s=30,
        batch_size=128,
        generate_kwargs={"task": "transcribe", "language": language} if language else {"task": "transcribe"},
        return_timestamps=False,
    )["text"].strip()


def calculate_spec(audio):
    # Compute spectrogram (Short-Time Fourier Transform)
    stft = librosa.stft(audio, n_fft=512, hop_length=256, win_length=512)
    spectrogram = np.abs(stft)
    # Convert to dB
    spectrogram_db = librosa.amplitude_to_db(spectrogram, ref=np.max)
    return spectrogram_db


def save_spectrogram(audio, path):
    spectrogram = calculate_spec(audio)
    plt.figure(figsize=(12, 4))
    plt.imshow(spectrogram, origin="lower", aspect="auto")
    plt.colorbar()
    plt.savefig(path)
    plt.close()


def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print, device=device):
    show_info("Converting audio...")
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        aseg = AudioSegment.from_file(ref_audio_orig)

        if clip_short:
            # 1. try to find long silence for clipping
            non_silent_segs = silence.split_on_silence(
                aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
            )
            non_silent_wave = AudioSegment.silent(duration=0)
            for non_silent_seg in non_silent_segs:
                if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
                    show_info("Audio is over 15s, clipping short. (1)")
                    break
                non_silent_wave += non_silent_seg

            # 2. try to find short silence for clipping if 1. failed
            if len(non_silent_wave) > 15000:
                non_silent_segs = silence.split_on_silence(
                    aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
                )
                non_silent_wave = AudioSegment.silent(duration=0)
                for non_silent_seg in non_silent_segs:
                    if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
                        show_info("Audio is over 15s, clipping short. (2)")
                        break
                    non_silent_wave += non_silent_seg

            aseg = non_silent_wave

            # 3. if no proper silence found for clipping, hard-truncate at 15s
            if len(aseg) > 15000:
                aseg = aseg[:15000]
                show_info("Audio is over 15s, clipping short. (3)")

        aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
        aseg.export(f.name, format="wav")
        ref_audio = f.name

    # Compute a hash of the reference audio file
    with open(ref_audio, "rb") as audio_file:
        audio_data = audio_file.read()
        audio_hash = hashlib.md5(audio_data).hexdigest()

    if not ref_text.strip():
        global _ref_audio_cache
        if audio_hash in _ref_audio_cache:
            # Use cached ASR transcription
            show_info("Using cached reference text...")
            ref_text = _ref_audio_cache[audio_hash]
        else:
            show_info("No reference text provided, transcribing reference audio...")
            ref_text = transcribe(ref_audio)
            # Cache the transcribed text (custom ref_text is not cached, so users can tweak it manually)
            _ref_audio_cache[audio_hash] = ref_text
    else:
        show_info("Using custom reference text...")

    # Ensure ref_text ends with proper sentence-ending punctuation
    if not ref_text.endswith(". ") and not ref_text.endswith("。"):
        if ref_text.endswith("."):
            ref_text += " "
        else:
            ref_text += ". "

    print("\nref_text ", ref_text)

    return ref_audio, ref_text