#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from collections import defaultdict
from functools import lru_cache
import json
import logging
from pathlib import Path
import shutil
import tempfile
import time
from typing import Dict, Optional, Tuple
import uuid
import zipfile
import gradio as gr
import librosa
from huggingface_hub import snapshot_download
import matplotlib.pyplot as plt
import numpy as np
from scipy.io import wavfile
from project_settings import project_path
from toolbox.torchaudio.models.vad.fsmn_vad.inference_fsmn_vad_onnx import InferenceFSMNVadOnnx
from toolbox.torchaudio.models.vad.silero_vad.inference_silero_vad import InferenceSileroVad
from toolbox.torchaudio.models.vad.native_silero_vad.inference_native_silero_vad_onnx import InferenceNativeSileroVadOnnx
from toolbox.torchaudio.utils.visualization import process_speech_probs
from toolbox.vad.utils import PostProcess
from toolbox.pydub.volume import get_volume
logger = logging.getLogger("main")
def save_input_audio(sample_rate: int, signal: np.ndarray) -> str:
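    """Write an int16 signal to a temporary wav file and return the file path."""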
    if signal.dtype != np.int16:
        raise ValueError(f"only np.int16 is supported, got: {signal.dtype}")
temp_audio_dir = Path(tempfile.gettempdir()) / "input_audio"
temp_audio_dir.mkdir(parents=True, exist_ok=True)
filename = temp_audio_dir / f"{uuid.uuid4()}.wav"
filename = filename.as_posix()
wavfile.write(
filename,
sample_rate, signal
)
return filename
def convert_sample_rate(signal: np.ndarray, sample_rate: int, target_sample_rate: int) -> np.ndarray:
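    """Resample by round-tripping through a temporary wav file; librosa loads the
    audio as float32 in [-1, 1], which is then rescaled back to int16."""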
filename = save_input_audio(sample_rate, signal)
signal, _ = librosa.load(filename, sr=target_sample_rate)
signal = np.array(signal * (1 << 15), dtype=np.int16)
return signal
def get_infer_cls_by_model_name(model_name: str):
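    """Map a pretrained model filename (stem of the zip archive) to its inference class."""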
    if "native_silero_vad" in model_name:
        infer_cls = InferenceNativeSileroVadOnnx
    elif "fsmn-vad" in model_name:
        infer_cls = InferenceFSMNVadOnnx
    elif "silero-vad" in model_name:
        infer_cls = InferenceSileroVad
    else:
        raise AssertionError(f"unknown vad model name: {model_name}")
return infer_cls
# engine name -> {"infer_cls": ..., "kwargs": ...}; populated in get_vad_tab().
vad_engines: Optional[Dict[str, dict]] = None
@lru_cache(maxsize=1)
def load_vad_model(infer_cls, **kwargs):
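    """Instantiate the inference engine; maxsize=1 caches only the most recently used engine."""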
infer_engine = infer_cls(**kwargs)
return infer_engine
def generate_image(signal: np.ndarray, speech_probs: np.ndarray, sample_rate: int = 8000, title: str = ""):
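    """Overlay a frame-level curve (vad flags / probs / lsnr) on the waveform,
    save the plot to a temporary png and return its path."""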
duration = np.arange(0, len(signal)) / sample_rate
plt.figure(figsize=(12, 5))
plt.plot(duration, signal, color='b')
plt.plot(duration, speech_probs, color='gray')
plt.title(title)
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
plt.savefig(temp_file.name, bbox_inches="tight")
plt.close()
return temp_file.name
def when_click_vad_button(audio_file_t = None, audio_microphone_t = None,
start_ring_rate: float = 0.5, end_ring_rate: float = 0.3,
ring_max_length: int = 10,
min_silence_length: int = 2,
max_speech_length: int = 10000, min_speech_length: int = 10,
engine: str = None,
):
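    """Run the selected vad engine on the uploaded or recorded audio and return the
    vad / prob / lsnr plots plus a json message with segments, timing and volume."""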
    if audio_file_t is None and audio_microphone_t is None:
        raise gr.Error("audio file and microphone are both empty.")
    if audio_file_t is not None and audio_microphone_t is not None:
        gr.Warning("both audio file and microphone are provided; the audio file takes priority.")
audio_t: Tuple = audio_file_t or audio_microphone_t
sample_rate, signal = audio_t
if sample_rate != 8000:
signal = convert_sample_rate(signal, sample_rate, 8000)
sample_rate = 8000
    # float division keeps the duration (and hence rtf) accurate for short clips
    audio_duration = signal.shape[-1] / sample_rate
audio = np.array(signal / (1 << 15), dtype=np.float32)
infer_engine_param = vad_engines.get(engine)
if infer_engine_param is None:
raise gr.Error(f"invalid denoise engine: {engine}.")
try:
infer_cls = infer_engine_param["infer_cls"]
kwargs = infer_engine_param["kwargs"]
infer_engine = load_vad_model(infer_cls=infer_cls, **kwargs)
begin = time.time()
vad_info = infer_engine.infer(audio)
time_cost = time.time() - begin
probs: np.ndarray = vad_info["probs"]
lsnr: np.ndarray = vad_info["lsnr"]
# lsnr = lsnr / np.max(np.abs(lsnr))
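        # scale lsnr into roughly [0, 1] for plotting; 30 is taken as a rough upper bound on the raw values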
lsnr = lsnr / 30
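        # hop_size: number of waveform samples per model output frame (used to map frame indices to time)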
frame_step = infer_engine.config.hop_size
# post process
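        # smooth the frame-level probabilities into speech segments; the length
        # parameters are counted in frames (about 10 ms each, per the ui labels)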
vad_post_process = PostProcess(
start_ring_rate=start_ring_rate,
end_ring_rate=end_ring_rate,
ring_max_length=ring_max_length,
min_silence_length=min_silence_length,
max_speech_length=max_speech_length,
min_speech_length=min_speech_length
)
vad_segments = vad_post_process.get_vad_segments(probs)
vad_flags = vad_post_process.get_vad_flags(probs, vad_segments)
        # vad_image
        vad_ = process_speech_probs(audio, vad_flags, frame_step)
        vad_image = generate_image(audio, vad_, sample_rate=sample_rate, title="vad")
        # probs_image
        probs_ = process_speech_probs(audio, probs, frame_step)
        probs_image = generate_image(audio, probs_, sample_rate=sample_rate, title="probs")
        # lsnr_image
        lsnr_ = process_speech_probs(audio, lsnr, frame_step)
        lsnr_image = generate_image(audio, lsnr_, sample_rate=sample_rate, title="lsnr")
# vad segment
vad_segments = [
[
v[0] * frame_step / sample_rate,
v[1] * frame_step / sample_rate
] for v in vad_segments
]
# volume
volume_map: dict = get_volume(audio, sample_rate)
# message
rtf = time_cost / audio_duration
info = {
"vad_segments": vad_segments,
"time_cost": round(time_cost, 4),
"duration": round(audio_duration, 4),
"rtf": round(rtf, 4),
**volume_map
}
message = json.dumps(info, ensure_ascii=False, indent=4)
    except Exception as e:
        logger.exception("vad failed")
        raise gr.Error(f"vad failed, error type: {type(e).__name__}, error text: {str(e)}.")
return vad_image, probs_image, lsnr_image, message
def get_vad_tab(trained_model_dir: str, examples_dir: str, models_repo_id: str, hf_token: str):
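    """Build the vad tab: download pretrained models, register the inference engines,
    unpack the bundled examples, and wire the gradio ui to when_click_vad_button."""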
examples_dir = Path(examples_dir)
trained_model_dir = Path(trained_model_dir)
# download models
if not trained_model_dir.exists():
trained_model_dir.mkdir(parents=True, exist_ok=True)
_ = snapshot_download(
repo_id=models_repo_id,
local_dir=trained_model_dir.as_posix(),
token=hf_token,
)
    # engines: build the registry from the downloaded model archives, skipping non-vad archives
global vad_engines
vad_engines = {
filename.stem: {
"infer_cls": get_infer_cls_by_model_name(filename.stem),
"kwargs": {
"pretrained_model_path_or_zip_file": filename.as_posix()
}
}
        for filename in trained_model_dir.glob("*.zip")
if filename.name not in (
# "cnn-vad-by-webrtcvad-nx-dns3.zip",
# "fsmn-vad-by-webrtcvad-nx-dns3.zip",
"examples.zip",
"sound-2-ch32.zip",
"sound-3-ch32.zip",
"sound-4-ch32.zip",
"sound-8-ch32.zip",
)
}
# choices
vad_engine_choices = list(vad_engines.keys())
# examples
if not examples_dir.exists():
example_zip_file = trained_model_dir / "examples.zip"
with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip:
out_root = examples_dir
if out_root.exists():
shutil.rmtree(out_root.as_posix())
out_root.mkdir(parents=True, exist_ok=True)
f_zip.extractall(path=out_root)
# examples
examples = defaultdict(list)
for filename in examples_dir.glob("**/*.wav"):
category = filename.parts[-2]
examples[category].append([
filename.as_posix(),
None,
vad_engine_choices[0],
])
# ui
with gr.TabItem("vad"):
with gr.Row():
with gr.Column(variant="panel", scale=5):
with gr.Tabs():
with gr.TabItem("file"):
vad_audio_file = gr.Audio(label="audio")
with gr.TabItem("microphone"):
vad_audio_microphone = gr.Audio(sources="microphone", label="audio")
with gr.Row():
vad_start_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="start_ring_rate")
vad_end_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.1, label="end_ring_rate")
with gr.Row():
vad_ring_max_length = gr.Number(value=10, label="ring_max_length (*10ms)")
vad_min_silence_length = gr.Number(value=6, label="min_silence_length (*10ms)")
with gr.Row():
vad_max_speech_length = gr.Number(value=100000, label="max_speech_length (*10ms)")
vad_min_speech_length = gr.Number(value=15, label="min_speech_length (*10ms)")
vad_engine = gr.Dropdown(choices=vad_engine_choices, value=vad_engine_choices[0], label="engine")
vad_button = gr.Button(variant="primary")
with gr.Column(variant="panel", scale=5):
vad_vad_image = gr.Image(label="vad")
vad_prob_image = gr.Image(label="prob")
vad_lsnr_image = gr.Image(label="lsnr")
vad_message = gr.Textbox(lines=1, max_lines=20, label="message")
# examples ui
with gr.Tabs():
for label, sub_examples in examples.items():
with gr.TabItem(label):
gr.Examples(
examples=sub_examples,
inputs=[vad_audio_file, vad_audio_microphone, vad_engine],
outputs=[vad_vad_image, vad_prob_image, vad_lsnr_image, vad_message],
fn=when_click_vad_button,
# cache_examples=True,
# cache_mode="lazy",
)
vad_button.click(
when_click_vad_button,
inputs=[
vad_audio_file, vad_audio_microphone,
vad_start_ring_rate, vad_end_ring_rate,
vad_ring_max_length,
vad_min_silence_length,
vad_max_speech_length, vad_min_speech_length,
vad_engine,
],
outputs=[vad_vad_image, vad_prob_image, vad_lsnr_image, vad_message],
)
return locals()
if __name__ == "__main__":
pass