update
Browse files
- examples/cnn_vad_by_webrtcvad/run.sh +10 -8
- examples/cnn_vad_by_webrtcvad/step_1_prepare_data.py +96 -90
- examples/cnn_vad_by_webrtcvad/step_4_train_model.py +0 -2
- examples/fsmn_vad_by_webrtcvad/run.sh +11 -9
- examples/fsmn_vad_by_webrtcvad/step_1_prepare_data.py +97 -91
- examples/fsmn_vad_by_webrtcvad/step_2_make_vad_segments.py +59 -2
- examples/fsmn_vad_by_webrtcvad/step_4_train_model.py +34 -10
- examples/fsmn_vad_by_webrtcvad/yaml/config.yaml +6 -2
- examples/silero_vad_by_webrtcvad/run.sh +10 -8
- examples/silero_vad_by_webrtcvad/step_1_prepare_data.py +97 -91
- examples/silero_vad_by_webrtcvad/step_4_train_model.py +10 -4
- examples/silero_vad_by_webrtcvad/yaml/config.yaml +6 -0
- toolbox/torchaudio/models/vad/cnn_vad/inference_cnn_vad.py +138 -0
- toolbox/torchaudio/models/vad/fsmn_vad/configuration_fsmn_vad.py +10 -3
- toolbox/torchaudio/models/vad/fsmn_vad/modeling_fsmn_vad.py +149 -32
- toolbox/torchaudio/models/vad/fsmn_vad/yaml/{config-sigmoid.yaml → config.yaml} +6 -2
- toolbox/torchaudio/models/vad/silero_vad/configuration_silero_vad.py +11 -0
- toolbox/torchaudio/models/vad/silero_vad/modeling_silero_vad.py +114 -46
- toolbox/torchaudio/models/vad/silero_vad/yaml/config.yaml +6 -0
examples/cnn_vad_by_webrtcvad/run.sh
CHANGED

@@ -5,14 +5,16 @@
  bash run.sh --stage 2 --stop_stage 2 --system_version centos \
  --file_folder_name cnn-vad-by-webrtcvad-nx-dns3 \
  --final_model_name cnn-vad-by-webrtcvad-nx-dns3 \
- --
- --
+ --noise_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
+ --speech_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
+ /data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"

  bash run.sh --stage 3 --stop_stage 3 --system_version centos \
  --file_folder_name cnn-vad-by-webrtcvad-nx-dns3 \
  --final_model_name cnn-vad-by-webrtcvad-nx-dns3 \
- --
- --
+ --noise_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
+ --speech_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
+ /data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"


  END

@@ -30,8 +32,8 @@ final_model_name=final_model_name
  config_file="yaml/config.yaml"
  limit=10

- …
- …
+ noise_patterns=/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav
+ speech_patterns=/data/tianxing/HuggingDatasets/nx_noise/data/speech/**/*.wav

  max_count=-1

@@ -98,8 +100,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  $verbose && echo "stage 1: prepare data"
  cd "${work_dir}" || exit 1
  python3 step_1_prepare_data.py \
- --
- --
+ --noise_patterns "${noise_patterns}" \
+ --speech_patterns "${speech_patterns}" \
  --train_dataset "${train_dataset}" \
  --valid_dataset "${valid_dataset}" \
  --max_count "${max_count}" \
examples/cnn_vad_by_webrtcvad/step_1_prepare_data.py
CHANGED

@@ -1,12 +1,14 @@
  #!/usr/bin/python3
  # -*- coding: utf-8 -*-
  import argparse
+ from glob import glob
  import json
  import os
  from pathlib import Path
  import random
  import sys
  import time
+ from typing import List

  pwd = os.path.abspath(os.path.dirname(__file__))
  sys.path.append(os.path.join(pwd, "../../"))

@@ -19,13 +21,13 @@ from tqdm import tqdm
  def get_args():
      parser = argparse.ArgumentParser()
      parser.add_argument(
-         "--noise_dir",
-         default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
+         "--noise_patterns",
+         default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\**\*.wav",
          type=str
      )
      parser.add_argument(
-         "--speech_dir",
-         default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech",
+         "--speech_patterns",
+         default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\**\*.wav",
          type=str
      )

@@ -46,108 +48,112 @@ def get_args():
      return args


- def target_second_noise_signal_generator(data_dir: str,
+ def target_second_noise_signal_generator(filename_patterns: List[str],
                                           duration: int = 4,
                                           sample_rate: int = 8000, max_epoch: int = 20000):
      noise_list = list()
      wait_duration = duration

-     data_dir = Path(data_dir)
      for epoch_idx in range(max_epoch):
-         for filename in data_dir.glob("**/*.wav"):
-             signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
-             raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
-
-             if signal.ndim != 1:
-                 raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
-
-             offset = 0.
-             rest_duration = raw_duration
-
-             for _ in range(1000):
-                 if rest_duration <= 0:
-                     break
-                 if rest_duration <= wait_duration:
-                     noise_list.append({
-                         "epoch_idx": epoch_idx,
-                         "filename": filename.as_posix(),
-                         "raw_duration": round(raw_duration, 4),
-                         "offset": round(offset, 4),
-                         "duration": round(rest_duration, 4),
-                     })
-                     wait_duration -= rest_duration
-                     offset = 0
-                     rest_duration = 0
-                 elif rest_duration > wait_duration:
-                     noise_list.append({
-                         "epoch_idx": epoch_idx,
-                         "filename": filename.as_posix(),
-                         "raw_duration": round(raw_duration, 4),
-                         "offset": round(offset, 4),
-                         "duration": round(wait_duration, 4),
-                     })
-                     offset += wait_duration
-                     rest_duration -= wait_duration
-                     wait_duration = 0
-                 else:
-                     raise AssertionError
-
-             if wait_duration <= 0:
-                 yield noise_list
-                 noise_list = list()
-                 wait_duration = duration
+         for filename_pattern in filename_patterns:
+             for filename in glob(filename_pattern):
+                 signal, _ = librosa.load(filename, sr=sample_rate)
+
+                 if signal.ndim != 1:
+                     raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
+
+                 raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
+
+                 offset = 0.
+                 rest_duration = raw_duration
+
+                 for _ in range(1000):
+                     if rest_duration <= 0:
+                         break
+                     if rest_duration <= wait_duration:
+                         noise_list.append({
+                             "epoch_idx": epoch_idx,
+                             "filename": filename,
+                             "raw_duration": round(raw_duration, 4),
+                             "offset": round(offset, 4),
+                             "duration": None,
+                             "duration_": round(rest_duration, 4),
+                         })
+                         wait_duration -= rest_duration
+                         offset = 0
+                         rest_duration = 0
+                     elif rest_duration > wait_duration:
+                         noise_list.append({
+                             "epoch_idx": epoch_idx,
+                             "filename": filename,
+                             "raw_duration": round(raw_duration, 4),
+                             "offset": round(offset, 4),
+                             "duration": round(wait_duration, 4),
+                             "duration_": round(wait_duration, 4),
+                         })
+                         offset += wait_duration
+                         rest_duration -= wait_duration
+                         wait_duration = 0
+                     else:
+                         raise AssertionError
+
+                 if wait_duration <= 0:
+                     yield noise_list
+                     noise_list = list()
+                     wait_duration = duration


- def target_second_speech_signal_generator(data_dir: str,
-                                           min_duration: int = 4,
-                                           max_duration: int = 6,
-                                           sample_rate: int = 8000, max_epoch: int = 1):
-     data_dir = Path(data_dir)
-     for epoch_idx in range(max_epoch):
-         for filename in data_dir.glob("**/*.wav"):
-             signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
-             raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
-
-             if signal.ndim != 1:
-                 raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
-
-             if raw_duration < min_duration:
-                 # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
-                 continue
-
-             if raw_duration < max_duration:
-                 row = {
-                     "epoch_idx": epoch_idx,
-                     "filename": filename.as_posix(),
-                     "raw_duration": round(raw_duration, 4),
-                     "offset": 0.,
-                     "duration": round(raw_duration, 4),
-                 }
-                 yield row
-
-             signal_length = len(signal)
-             win_size = int(max_duration * sample_rate)
-             for begin in range(0, signal_length - win_size, win_size):
-                 if np.sum(signal[begin: begin+win_size]) == 0:
-                     continue
-                 row = {
-                     "epoch_idx": epoch_idx,
-                     "filename": filename.as_posix(),
-                     "raw_duration": round(raw_duration, 4),
-                     "offset": round(begin / sample_rate, 4),
-                     "duration": round(max_duration, 4),
-                 }
-                 yield row
+ def target_second_speech_signal_generator(filename_patterns: List[str],
+                                           min_duration: int = 4,
+                                           max_duration: int = 6,
+                                           sample_rate: int = 8000, max_epoch: int = 1):
+     for epoch_idx in range(max_epoch):
+         for filename_pattern in filename_patterns:
+             for filename in glob(filename_pattern):
+                 signal, _ = librosa.load(filename, sr=sample_rate)
+                 raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
+
+                 if signal.ndim != 1:
+                     raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
+
+                 if raw_duration < min_duration:
+                     # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
+                     continue
+
+                 if raw_duration < max_duration:
+                     row = {
+                         "epoch_idx": epoch_idx,
+                         "filename": filename,
+                         "raw_duration": round(raw_duration, 4),
+                         "offset": 0.,
+                         "duration": round(raw_duration, 4),
+                     }
+                     yield row
+
+                 signal_length = len(signal)
+                 win_size = int(max_duration * sample_rate)
+                 for begin in range(0, signal_length - win_size, win_size):
+                     if np.sum(signal[begin: begin+win_size]) == 0:
+                         continue
+                     row = {
+                         "epoch_idx": epoch_idx,
+                         "filename": filename,
+                         "raw_duration": round(raw_duration, 4),
+                         "offset": round(begin / sample_rate, 4),
+                         "duration": round(max_duration, 4),
+                     }
+                     yield row


  def main():
      args = get_args()

- …
- …
+     noise_patterns = args.noise_patterns
+     noise_patterns = noise_patterns.split(" ")
+     print(f"noise_patterns: {noise_patterns}")
+     speech_patterns = args.speech_patterns
+     speech_patterns = speech_patterns.split(" ")
+     print(f"speech_patterns: {speech_patterns}")

      train_dataset = Path(args.train_dataset)
      valid_dataset = Path(args.valid_dataset)

@@ -155,13 +161,13 @@ def main():
      valid_dataset.parent.mkdir(parents=True, exist_ok=True)

      noise_generator = target_second_noise_signal_generator(
- …
+         noise_patterns,
          duration=args.duration,
          sample_rate=args.target_sample_rate,
          max_epoch=100000,
      )
      speech_generator = target_second_speech_signal_generator(
- …
+         speech_patterns,
          min_duration=args.min_speech_duration,
          max_duration=args.max_speech_duration,
          sample_rate=args.target_sample_rate,
examples/cnn_vad_by_webrtcvad/step_4_train_model.py
CHANGED

@@ -17,8 +17,6 @@ sys.path.append(os.path.join(pwd, "../../"))

  import numpy as np
  import torch
- import torch.nn as nn
- from torch.nn import functional as F
  from torch.utils.data.dataloader import DataLoader
  from tqdm import tqdm

examples/fsmn_vad_by_webrtcvad/run.sh
CHANGED

@@ -2,17 +2,19 @@

  : <<'END'

- bash run.sh --stage
+ bash run.sh --stage 1 --stop_stage 1 --system_version centos \
  --file_folder_name fsmn-vad-by-webrtcvad-nx2-dns3 \
  --final_model_name fsmn-vad-by-webrtcvad-nx2-dns3 \
- --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise
- --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
+ /data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"

  bash run.sh --stage 3 --stop_stage 3 --system_version centos \
  --file_folder_name fsmn-vad-by-webrtcvad-nx2-dns3 \
  --final_model_name fsmn-vad-by-webrtcvad-nx2-dns3 \
- --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise
- --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
+ /data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"


  END

@@ -30,8 +32,8 @@ final_model_name=final_model_name
  config_file="yaml/config.yaml"
  limit=10

- …
- …
+ noise_patterns=/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav
+ speech_patterns=/data/tianxing/HuggingDatasets/nx_noise/data/speech/**/*.wav

  max_count=-1

@@ -98,8 +100,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  $verbose && echo "stage 1: prepare data"
  cd "${work_dir}" || exit 1
  python3 step_1_prepare_data.py \
- --
- --
+ --noise_patterns "${noise_patterns}" \
+ --speech_patterns "${speech_patterns}" \
  --train_dataset "${train_dataset}" \
  --valid_dataset "${valid_dataset}" \
  --max_count "${max_count}" \
examples/fsmn_vad_by_webrtcvad/step_1_prepare_data.py
CHANGED

@@ -1,12 +1,14 @@
  #!/usr/bin/python3
  # -*- coding: utf-8 -*-
  import argparse
+ from glob import glob
  import json
  import os
  from pathlib import Path
  import random
  import sys
  import time
+ from typing import List

  pwd = os.path.abspath(os.path.dirname(__file__))
  sys.path.append(os.path.join(pwd, "../../"))

@@ -19,13 +21,13 @@ from tqdm import tqdm
  def get_args():
      parser = argparse.ArgumentParser()
      parser.add_argument(
-         "--noise_dir",
-         default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
+         "--noise_patterns",
+         default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\**\*.wav",
          type=str
      )
      parser.add_argument(
-         "--speech_dir",
-         default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech",
+         "--speech_patterns",
+         default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\**\*.wav",
          type=str
      )

@@ -46,108 +48,112 @@ def get_args():
      return args


- def target_second_noise_signal_generator(data_dir: str,
+ def target_second_noise_signal_generator(filename_patterns: List[str],
                                           duration: int = 4,
                                           sample_rate: int = 8000, max_epoch: int = 20000):
      noise_list = list()
      wait_duration = duration

-     data_dir = Path(data_dir)
      for epoch_idx in range(max_epoch):
-         for filename in data_dir.glob("**/*.wav"):
-             signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
-             raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
-
-             if signal.ndim != 1:
-                 raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
-
-             offset = 0.
-             rest_duration = raw_duration
-
-             for _ in range(1000):
-                 if rest_duration <= 0:
-                     break
-                 if rest_duration <= wait_duration:
-                     noise_list.append({
-                         "epoch_idx": epoch_idx,
-                         "filename": filename.as_posix(),
-                         "raw_duration": round(raw_duration, 4),
-                         "offset": round(offset, 4),
-                         "duration": round(rest_duration, 4),
-                     })
-                     wait_duration -= rest_duration
-                     offset = 0
-                     rest_duration = 0
-                 elif rest_duration > wait_duration:
-                     noise_list.append({
-                         "epoch_idx": epoch_idx,
-                         "filename": filename.as_posix(),
-                         "raw_duration": round(raw_duration, 4),
-                         "offset": round(offset, 4),
-                         "duration": round(wait_duration, 4),
-                     })
-                     offset += wait_duration
-                     rest_duration -= wait_duration
-                     wait_duration = 0
-                 else:
-                     raise AssertionError
-
-             if wait_duration <= 0:
-                 yield noise_list
-                 noise_list = list()
-                 wait_duration = duration
+         for filename_pattern in filename_patterns:
+             for filename in glob(filename_pattern):
+                 signal, _ = librosa.load(filename, sr=sample_rate)
+
+                 if signal.ndim != 1:
+                     raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
+
+                 raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
+
+                 offset = 0.
+                 rest_duration = raw_duration
+
+                 for _ in range(1000):
+                     if rest_duration <= 0:
+                         break
+                     if rest_duration <= wait_duration:
+                         noise_list.append({
+                             "epoch_idx": epoch_idx,
+                             "filename": filename,
+                             "raw_duration": round(raw_duration, 4),
+                             "offset": round(offset, 4),
+                             "duration": None,
+                             "duration_": round(rest_duration, 4),
+                         })
+                         wait_duration -= rest_duration
+                         offset = 0
+                         rest_duration = 0
+                     elif rest_duration > wait_duration:
+                         noise_list.append({
+                             "epoch_idx": epoch_idx,
+                             "filename": filename,
+                             "raw_duration": round(raw_duration, 4),
+                             "offset": round(offset, 4),
+                             "duration": round(wait_duration, 4),
+                             "duration_": round(wait_duration, 4),
+                         })
+                         offset += wait_duration
+                         rest_duration -= wait_duration
+                         wait_duration = 0
+                     else:
+                         raise AssertionError
+
+                 if wait_duration <= 0:
+                     yield noise_list
+                     noise_list = list()
+                     wait_duration = duration


- def target_second_speech_signal_generator(data_dir: str,
-                                           min_duration: int = 4,
-                                           max_duration: int = 6,
-                                           sample_rate: int = 8000, max_epoch: int = 1):
-     data_dir = Path(data_dir)
-     for epoch_idx in range(max_epoch):
-         for filename in data_dir.glob("**/*.wav"):
-             signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
-             raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
-
-             if signal.ndim != 1:
-                 raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
-
-             if raw_duration < min_duration:
-                 # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
-                 continue
-
-             if raw_duration < max_duration:
-                 row = {
-                     "epoch_idx": epoch_idx,
-                     "filename": filename.as_posix(),
-                     "raw_duration": round(raw_duration, 4),
-                     "offset": 0.,
-                     "duration": round(raw_duration, 4),
-                 }
-                 yield row
-
-             signal_length = len(signal)
-             win_size = int(max_duration * sample_rate)
-             for begin in range(0, signal_length - win_size, win_size):
-                 if np.sum(signal[begin: begin+win_size]) == 0:
-                     continue
-                 row = {
-                     "epoch_idx": epoch_idx,
-                     "filename": filename.as_posix(),
-                     "raw_duration": round(raw_duration, 4),
-                     "offset": round(begin / sample_rate, 4),
-                     "duration": round(max_duration, 4),
-                 }
-                 yield row
+ def target_second_speech_signal_generator(filename_patterns: List[str],
+                                           min_duration: int = 4,
+                                           max_duration: int = 6,
+                                           sample_rate: int = 8000, max_epoch: int = 1):
+     for epoch_idx in range(max_epoch):
+         for filename_pattern in filename_patterns:
+             for filename in glob(filename_pattern):
+                 signal, _ = librosa.load(filename, sr=sample_rate)
+                 raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
+
+                 if signal.ndim != 1:
+                     raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
+
+                 if raw_duration < min_duration:
+                     # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
+                     continue
+
+                 if raw_duration < max_duration:
+                     row = {
+                         "epoch_idx": epoch_idx,
+                         "filename": filename,
+                         "raw_duration": round(raw_duration, 4),
+                         "offset": 0.,
+                         "duration": round(raw_duration, 4),
+                     }
+                     yield row
+
+                 signal_length = len(signal)
+                 win_size = int(max_duration * sample_rate)
+                 for begin in range(0, signal_length - win_size, win_size):
+                     if np.sum(signal[begin: begin+win_size]) == 0:
+                         continue
+                     row = {
+                         "epoch_idx": epoch_idx,
+                         "filename": filename,
+                         "raw_duration": round(raw_duration, 4),
+                         "offset": round(begin / sample_rate, 4),
+                         "duration": round(max_duration, 4),
+                     }
+                     yield row


  def main():
      args = get_args()

- …
- …
+     noise_patterns = args.noise_patterns
+     noise_patterns = noise_patterns.split(" ")
+     print(f"noise_patterns: {noise_patterns}")
+     speech_patterns = args.speech_patterns
+     speech_patterns = speech_patterns.split(" ")
+     print(f"speech_patterns: {speech_patterns}")

      train_dataset = Path(args.train_dataset)
      valid_dataset = Path(args.valid_dataset)

@@ -155,13 +161,13 @@ def main():
      valid_dataset.parent.mkdir(parents=True, exist_ok=True)

      noise_generator = target_second_noise_signal_generator(
- …
+         noise_patterns,
          duration=args.duration,
          sample_rate=args.target_sample_rate,
          max_epoch=100000,
      )
      speech_generator = target_second_speech_signal_generator(
- …
+         speech_patterns,
          min_duration=args.min_speech_duration,
          max_duration=args.max_speech_duration,
          sample_rate=args.target_sample_rate,

@@ -210,7 +216,7 @@ def main():
              "random1": random1,
          }
          row = json.dumps(row, ensure_ascii=False)
-         if random2 < (
+         if random2 < (2 / 300):
              fvalid.write(f"{row}\n")
          else:
              ftrain.write(f"{row}\n")
examples/fsmn_vad_by_webrtcvad/step_2_make_vad_segments.py
CHANGED

@@ -4,6 +4,7 @@ import argparse
  import json
  import os
  import sys
+ from typing import List, Tuple

  pwd = os.path.abspath(os.path.dirname(__file__))
  sys.path.append(os.path.join(pwd, "../../"))

@@ -42,6 +43,54 @@ def get_args():
      return args


+ def get_non_silence_segments(waveform: np.ndarray, sample_rate: int = 8000):
+     non_silent_intervals = librosa.effects.split(
+         waveform,
+         top_db=40,         # silence threshold (in dB)
+         frame_length=512,  # analysis frame length
+         hop_length=128     # hop length
+     )
+
+     # return the non-silent intervals as times (in seconds)
+     result = [(start / sample_rate, end / sample_rate) for (start, end) in non_silent_intervals]
+     return result
+
+
+ def get_intersection(non_silence: list[tuple[float, float]],
+                      speech: list[tuple[float, float]]) -> list[tuple[float, float]]:
+     """
+     Compute the intersection of speech segments and non-silence segments.
+     :param non_silence: non-silence segments, format [(start1, end1), ...]
+     :param speech: detected speech segments, format [(start2, end2), ...]
+     :return: intersection segments, format [(start, end), ...]
+     """
+     # sort by start time (may be skipped if the inputs are already sorted)
+     non_silence = sorted(non_silence, key=lambda x: x[0])
+     speech = sorted(speech, key=lambda x: x[0])
+
+     result = []
+     i = j = 0
+
+     while i < len(non_silence) and j < len(speech):
+         ns_start, ns_end = non_silence[i]
+         sp_start, sp_end = speech[j]
+
+         # compute the overlapping interval
+         overlap_start = max(ns_start, sp_start)
+         overlap_end = min(ns_end, sp_end)
+
+         if overlap_start < overlap_end:
+             result.append((overlap_start, overlap_end))
+
+         # pointer strategy: advance whichever interval ends first
+         if ns_end < sp_end:
+             i += 1  # the non-silence segment ends first
+         else:
+             j += 1  # the speech segment ends first
+
+     return result
+
+
  def main():
      args = get_args()

@@ -68,8 +117,8 @@ def main():
          end_ring_rate=0.1,
          frame_size_ms=30,
          frame_step_ms=30,
-         padding_length_ms=
-         max_silence_length_ms=
+         padding_length_ms=30,
+         max_silence_length_ms=0,
          max_speech_length_s=100,
          min_speech_length_s=0.1,
          sample_rate=args.expected_sample_rate,

@@ -114,6 +163,9 @@ def main():
      )
      waveform = np.array(waveform * (1 << 15), dtype=np.int16)

+     # non_silence_segments
+     non_silence_segments = get_non_silence_segments(waveform, sample_rate=args.expected_sample_rate)
+
      # vad
      vad_segments = list()
      segments = w_vad.vad(waveform)

@@ -122,6 +174,7 @@ def main():
      vad_segments += segments
      w_vad.reset()

+     vad_segments = get_intersection(non_silence_segments, vad_segments)
      row["vad_segments"] = vad_segments

      row = json.dumps(row, ensure_ascii=False)

@@ -168,6 +221,9 @@ def main():
      )
      waveform = np.array(waveform * (1 << 15), dtype=np.int16)

+     # non_silence_segments
+     non_silence_segments = get_non_silence_segments(waveform, sample_rate=args.expected_sample_rate)
+
      # vad
      vad_segments = list()
      segments = w_vad.vad(waveform)

@@ -176,6 +232,7 @@ def main():
      vad_segments += segments
      w_vad.reset()

+     vad_segments = get_intersection(non_silence_segments, vad_segments)
      row["vad_segments"] = vad_segments

      row = json.dumps(row, ensure_ascii=False)
examples/fsmn_vad_by_webrtcvad/step_4_train_model.py
CHANGED

@@ -17,8 +17,6 @@ sys.path.append(os.path.join(pwd, "../../"))

  import numpy as np
  import torch
- import torch.nn as nn
- from torch.nn import functional as F
  from torch.utils.data.dataloader import DataLoader
  from tqdm import tqdm

@@ -38,7 +36,7 @@ def get_args():
      parser.add_argument("--valid_dataset", default="valid-vad.jsonl", type=str)

      parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
-     parser.add_argument("--patience", default=
+     parser.add_argument("--patience", default=10, type=int)
      parser.add_argument("--serialization_dir", default="serialization_dir", type=str)

      parser.add_argument("--config_file", default="yaml/config.yaml", type=str)

@@ -74,22 +72,28 @@ class CollateFunction(object):

      def __call__(self, batch: List[dict]):
          noisy_audios = list()
+         clean_audios = list()
          batch_vad_segments = list()

          for sample in batch:
              noisy_wave: torch.Tensor = sample["noisy_wave"]
+             clean_wave: torch.Tensor = sample["clean_wave"]
              vad_segments: List[Tuple[float, float]] = sample["vad_segments"]

              noisy_audios.append(noisy_wave)
+             clean_audios.append(clean_wave)
              batch_vad_segments.append(vad_segments)

          noisy_audios = torch.stack(noisy_audios)
+         clean_audios = torch.stack(clean_audios)

          # assert
          if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
              raise AssertionError("nan or inf in noisy_audios")
+         if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
+             raise AssertionError("nan or inf in clean_audios")

-         return noisy_audios, batch_vad_segments
+         return noisy_audios, clean_audios, batch_vad_segments


  collate_fn = CollateFunction()

@@ -214,6 +218,7 @@ def main():
      average_loss = 1000000000
      average_bce_loss = 1000000000
      average_dice_loss = 1000000000
+     average_lsnr_loss = 1000000000

      accuracy = -1
      f1 = -1

@@ -242,6 +247,7 @@ def main():
          total_loss = 0.
          total_bce_loss = 0.
          total_dice_loss = 0.
+         total_lsnr_loss = 0.
          total_batches = 0.

          progress_bar_train = tqdm(

@@ -249,19 +255,22 @@ def main():
              desc="Training; epoch-{}".format(epoch_idx),
          )
          for train_batch in train_data_loader:
-             noisy_audios, batch_vad_segments = train_batch
+             noisy_audios, clean_audios, batch_vad_segments = train_batch
              noisy_audios: torch.Tensor = noisy_audios.to(device)
+             clean_audios: torch.Tensor = clean_audios.to(device)
              # noisy_audios shape: [b, num_samples]
              num_samples = noisy_audios.shape[-1]

-             logits, probs = model.forward(noisy_audios)
+             logits, probs, lsnr = model.forward(noisy_audios)
+             lsnr = torch.squeeze(lsnr, dim=-1)

              targets = BaseVadLoss.get_targets(probs, batch_vad_segments, duration=num_samples / config.sample_rate)

              bce_loss = bce_loss_fn.forward(probs, targets)
              dice_loss = dice_loss_fn.forward(probs, targets)
+             lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)

-             loss = 1.0 * bce_loss + 1.0 * dice_loss
+             loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.03 * lsnr_loss
              if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                  logger.info(f"find nan or inf in loss. continue.")
                  continue

@@ -278,11 +287,13 @@ def main():
              total_loss += loss.item()
              total_bce_loss += bce_loss.item()
              total_dice_loss += dice_loss.item()
+             total_lsnr_loss += lsnr_loss.item()
              total_batches += 1

              average_loss = round(total_loss / total_batches, 4)
              average_bce_loss = round(total_bce_loss / total_batches, 4)
              average_dice_loss = round(total_dice_loss / total_batches, 4)
+             average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)

              metrics = vad_accuracy_metrics_fn.get_metric()
              accuracy = metrics["accuracy"]

@@ -297,6 +308,7 @@ def main():
                  "loss": average_loss,
                  "bce_loss": average_bce_loss,
                  "dice_loss": average_dice_loss,
+                 "lsnr_loss": average_lsnr_loss,
                  "accuracy": accuracy,
                  "f1": f1,
                  "precision": precision,

@@ -316,6 +328,7 @@ def main():
              total_loss = 0.
              total_bce_loss = 0.
              total_dice_loss = 0.
+             total_lsnr_loss = 0.
              total_batches = 0.

              progress_bar_train.close()

@@ -323,19 +336,22 @@ def main():
              desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
          )
          for eval_batch in valid_data_loader:
-             noisy_audios, batch_vad_segments = eval_batch
+             noisy_audios, clean_audios, batch_vad_segments = eval_batch
              noisy_audios: torch.Tensor = noisy_audios.to(device)
+             clean_audios: torch.Tensor = clean_audios.to(device)
              # noisy_audios shape: [b, num_samples]
              num_samples = noisy_audios.shape[-1]

-             logits, probs = model.forward(noisy_audios)
+             logits, probs, lsnr = model.forward(noisy_audios)
+             lsnr = torch.squeeze(lsnr, dim=-1)

              targets = BaseVadLoss.get_targets(probs, batch_vad_segments, duration=num_samples / config.sample_rate)

              bce_loss = bce_loss_fn.forward(probs, targets)
              dice_loss = dice_loss_fn.forward(probs, targets)
+             lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)

-             loss = 1.0 * bce_loss + 1.0 * dice_loss
+             loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.03 * lsnr_loss
              if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                  logger.info(f"find nan or inf in loss. continue.")
                  continue

@@ -346,11 +362,13 @@ def main():
              total_loss += loss.item()
              total_bce_loss += bce_loss.item()
              total_dice_loss += dice_loss.item()
+             total_lsnr_loss += lsnr_loss.item()
              total_batches += 1

              average_loss = round(total_loss / total_batches, 4)
              average_bce_loss = round(total_bce_loss / total_batches, 4)
              average_dice_loss = round(total_dice_loss / total_batches, 4)
+             average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)

              metrics = vad_accuracy_metrics_fn.get_metric()
              accuracy = metrics["accuracy"]

@@ -365,6 +383,7 @@ def main():
                  "loss": average_loss,
                  "bce_loss": average_bce_loss,
                  "dice_loss": average_dice_loss,
+                 "lsnr_loss": average_lsnr_loss,
                  "accuracy": accuracy,
                  "f1": f1,
                  "precision": precision,

@@ -378,6 +397,7 @@ def main():
              total_loss = 0.
              total_bce_loss = 0.
              total_dice_loss = 0.
+             total_lsnr_loss = 0.
              total_batches = 0.

              progress_bar_eval.close()

@@ -419,8 +439,12 @@ def main():
              "loss": average_loss,
              "bce_loss": average_bce_loss,
              "dice_loss": average_dice_loss,
+             "lsnr_loss": average_lsnr_loss,

              "accuracy": accuracy,
+             "f1": f1,
+             "precision": precision,
+             "recall": recall,
          }
          metrics_filename = save_dir / "metrics_epoch.json"
          with open(metrics_filename, "w", encoding="utf-8") as f:
examples/fsmn_vad_by_webrtcvad/yaml/config.yaml
CHANGED

@@ -18,9 +18,13 @@ fsmn_basic_block_rorder: 0
  fsmn_basic_block_lstride: 1
  fsmn_basic_block_rstride: 0
  fsmn_output_affine_size: 140
- fsmn_output_size:
+ fsmn_output_size: 2

- …
+ # lsnr
+ n_frame: 3
+ min_local_snr_db: -15
+ max_local_snr_db: 30
+ norm_tau: 1.

  # data
  min_snr_db: -10
examples/silero_vad_by_webrtcvad/run.sh
CHANGED

@@ -5,14 +5,16 @@
  bash run.sh --stage 2 --stop_stage 2 --system_version centos \
  --file_folder_name silero-vad-by-webrtcvad-nx2-dns3 \
  --final_model_name silero-vad-by-webrtcvad-nx2-dns3 \
- --
- --
+ --noise_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
+ --speech_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
+ /data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"

  bash run.sh --stage 3 --stop_stage 3 --system_version centos \
  --file_folder_name silero-vad-by-webrtcvad-nx2-dns3 \
  --final_model_name silero-vad-by-webrtcvad-nx2-dns3 \
- --
- --
+ --noise_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav" \
+ --speech_patterns "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech/**/*.wav \
+ /data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2/**/*.wav"


  END

@@ -30,8 +32,8 @@ final_model_name=final_model_name
  config_file="yaml/config.yaml"
  limit=10

- …
- …
+ noise_patterns=/data/tianxing/HuggingDatasets/nx_noise/data/noise/**/*.wav
+ speech_patterns=/data/tianxing/HuggingDatasets/nx_noise/data/speech/**/*.wav

  max_count=-1

@@ -98,8 +100,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  $verbose && echo "stage 1: prepare data"
  cd "${work_dir}" || exit 1
  python3 step_1_prepare_data.py \
- --
- --
+ --noise_patterns "${noise_patterns}" \
+ --speech_patterns "${speech_patterns}" \
  --train_dataset "${train_dataset}" \
  --valid_dataset "${valid_dataset}" \
  --max_count "${max_count}" \
examples/silero_vad_by_webrtcvad/step_1_prepare_data.py
CHANGED

@@ -1,12 +1,14 @@
  #!/usr/bin/python3
  # -*- coding: utf-8 -*-
  import argparse
+ from glob import glob
  import json
  import os
  from pathlib import Path
  import random
  import sys
  import time
+ from typing import List

  pwd = os.path.abspath(os.path.dirname(__file__))
  sys.path.append(os.path.join(pwd, "../../"))

@@ -19,13 +21,13 @@ from tqdm import tqdm
  def get_args():
      parser = argparse.ArgumentParser()
      parser.add_argument(
-         "--noise_dir",
-         default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
+         "--noise_patterns",
+         default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\**\*.wav",
          type=str
      )
      parser.add_argument(
-         "--speech_dir",
-         default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech",
+         "--speech_patterns",
+         default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\**\*.wav",
          type=str
      )

@@ -46,108 +48,112 @@ def get_args():
      return args


- def target_second_noise_signal_generator(data_dir: str,
+ def target_second_noise_signal_generator(filename_patterns: List[str],
                                           duration: int = 4,
                                           sample_rate: int = 8000, max_epoch: int = 20000):
      noise_list = list()
      wait_duration = duration

-     data_dir = Path(data_dir)
      for epoch_idx in range(max_epoch):
-         for filename in data_dir.glob("**/*.wav"):
-             signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
-             raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
-
-             if signal.ndim != 1:
-                 raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
-
-             offset = 0.
-             rest_duration = raw_duration
-
-             for _ in range(1000):
-                 if rest_duration <= 0:
-                     break
-                 if rest_duration <= wait_duration:
-                     noise_list.append({
-                         "epoch_idx": epoch_idx,
-                         "filename": filename.as_posix(),
-                         "raw_duration": round(raw_duration, 4),
-                         "offset": round(offset, 4),
-                         "duration": round(rest_duration, 4),
-                     })
-                     wait_duration -= rest_duration
-                     offset = 0
-                     rest_duration = 0
-                 elif rest_duration > wait_duration:
-                     noise_list.append({
-                         "epoch_idx": epoch_idx,
-                         "filename": filename.as_posix(),
-                         "raw_duration": round(raw_duration, 4),
-                         "offset": round(offset, 4),
-                         "duration": round(wait_duration, 4),
-                     })
-                     offset += wait_duration
-                     rest_duration -= wait_duration
-                     wait_duration = 0
-                 else:
-                     raise AssertionError
-
-             if wait_duration <= 0:
-                 yield noise_list
-                 noise_list = list()
-                 wait_duration = duration
+         for filename_pattern in filename_patterns:
+             for filename in glob(filename_pattern):
+                 signal, _ = librosa.load(filename, sr=sample_rate)
+
+                 if signal.ndim != 1:
+                     raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
+
+                 raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
+
+                 offset = 0.
+                 rest_duration = raw_duration
+
+                 for _ in range(1000):
+                     if rest_duration <= 0:
+                         break
+                     if rest_duration <= wait_duration:
+                         noise_list.append({
+                             "epoch_idx": epoch_idx,
+                             "filename": filename,
+                             "raw_duration": round(raw_duration, 4),
+                             "offset": round(offset, 4),
+                             "duration": None,
+                             "duration_": round(rest_duration, 4),
+                         })
+                         wait_duration -= rest_duration
+                         offset = 0
+                         rest_duration = 0
+                     elif rest_duration > wait_duration:
+                         noise_list.append({
+                             "epoch_idx": epoch_idx,
+                             "filename": filename,
+                             "raw_duration": round(raw_duration, 4),
+                             "offset": round(offset, 4),
+                             "duration": round(wait_duration, 4),
+                             "duration_": round(wait_duration, 4),
+                         })
+                         offset += wait_duration
+                         rest_duration -= wait_duration
+                         wait_duration = 0
+                     else:
+                         raise AssertionError
+
+                 if wait_duration <= 0:
+                     yield noise_list
+                     noise_list = list()
+                     wait_duration = duration


- def target_second_speech_signal_generator(data_dir: str,
-                                           min_duration: int = 4,
-                                           max_duration: int = 6,
-                                           sample_rate: int = 8000, max_epoch: int = 1):
-     data_dir = Path(data_dir)
-     for epoch_idx in range(max_epoch):
-         for filename in data_dir.glob("**/*.wav"):
-             signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
-             raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
-
-             if signal.ndim != 1:
-                 raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
-
-             if raw_duration < min_duration:
-                 # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
-                 continue
-
-             if raw_duration < max_duration:
-                 row = {
-                     "epoch_idx": epoch_idx,
-                     "filename": filename.as_posix(),
-                     "raw_duration": round(raw_duration, 4),
-                     "offset": 0.,
-                     "duration": round(raw_duration, 4),
-                 }
-                 yield row
-
-             signal_length = len(signal)
-             win_size = int(max_duration * sample_rate)
-             for begin in range(0, signal_length - win_size, win_size):
-                 if np.sum(signal[begin: begin+win_size]) == 0:
-                     continue
-                 row = {
-                     "epoch_idx": epoch_idx,
-                     "filename": filename.as_posix(),
-                     "raw_duration": round(raw_duration, 4),
-                     "offset": round(begin / sample_rate, 4),
-                     "duration": round(max_duration, 4),
-                 }
-                 yield row
+ def target_second_speech_signal_generator(filename_patterns: List[str],
+                                           min_duration: int = 4,
+                                           max_duration: int = 6,
+                                           sample_rate: int = 8000, max_epoch: int = 1):
+     for epoch_idx in range(max_epoch):
+         for filename_pattern in filename_patterns:
+             for filename in glob(filename_pattern):
+                 signal, _ = librosa.load(filename, sr=sample_rate)
+                 raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
+
+                 if signal.ndim != 1:
+                     raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
+
+                 if raw_duration < min_duration:
+                     # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
+                     continue
+
+                 if raw_duration < max_duration:
+                     row = {
+                         "epoch_idx": epoch_idx,
+                         "filename": filename,
+                         "raw_duration": round(raw_duration, 4),
+                         "offset": 0.,
+                         "duration": round(raw_duration, 4),
+                     }
+                     yield row
+
+                 signal_length = len(signal)
+                 win_size = int(max_duration * sample_rate)
+                 for begin in range(0, signal_length - win_size, win_size):
+                     if np.sum(signal[begin: begin+win_size]) == 0:
+                         continue
+                     row = {
+                         "epoch_idx": epoch_idx,
+                         "filename": filename,
+                         "raw_duration": round(raw_duration, 4),
+                         "offset": round(begin / sample_rate, 4),
+                         "duration": round(max_duration, 4),
+                     }
+                     yield row


  def main():
      args = get_args()

- …
- …
+     noise_patterns = args.noise_patterns
+     noise_patterns = noise_patterns.split(" ")
+     print(f"noise_patterns: {noise_patterns}")
+     speech_patterns = args.speech_patterns
+     speech_patterns = speech_patterns.split(" ")
+     print(f"speech_patterns: {speech_patterns}")

      train_dataset = Path(args.train_dataset)
      valid_dataset = Path(args.valid_dataset)

@@ -155,13 +161,13 @@ def main():
      valid_dataset.parent.mkdir(parents=True, exist_ok=True)

      noise_generator = target_second_noise_signal_generator(
- …
+         noise_patterns,
          duration=args.duration,
          sample_rate=args.target_sample_rate,
          max_epoch=100000,
      )
      speech_generator = target_second_speech_signal_generator(
- …
+         speech_patterns,
          min_duration=args.min_speech_duration,
          max_duration=args.max_speech_duration,
          sample_rate=args.target_sample_rate,

@@ -210,7 +216,7 @@ def main():
              "random1": random1,
          }
          row = json.dumps(row, ensure_ascii=False)
-         if random2 < (
+         if random2 < (2 / 300):
              fvalid.write(f"{row}\n")
          else:
              ftrain.write(f"{row}\n")
|
| 121 |
+
continue
|
| 122 |
|
| 123 |
+
if raw_duration < max_duration:
|
| 124 |
+
row = {
|
|
|
|
|
|
|
|
|
|
| 125 |
"epoch_idx": epoch_idx,
|
| 126 |
+
"filename": filename,
|
| 127 |
"raw_duration": round(raw_duration, 4),
|
| 128 |
+
"offset": 0.,
|
| 129 |
+
"duration": round(raw_duration, 4),
|
| 130 |
+
}
|
| 131 |
+
yield row
|
| 132 |
+
|
| 133 |
+
signal_length = len(signal)
|
| 134 |
+
win_size = int(max_duration * sample_rate)
|
| 135 |
+
for begin in range(0, signal_length - win_size, win_size):
|
| 136 |
+
if np.sum(signal[begin: begin+win_size]) == 0:
|
| 137 |
+
continue
|
| 138 |
+
row = {
|
| 139 |
"epoch_idx": epoch_idx,
|
| 140 |
+
"filename": filename,
|
| 141 |
"raw_duration": round(raw_duration, 4),
|
| 142 |
+
"offset": round(begin / sample_rate, 4),
|
| 143 |
+
"duration": round(max_duration, 4),
|
| 144 |
+
}
|
| 145 |
+
yield row
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
|
| 148 |
def main():
|
| 149 |
args = get_args()
|
| 150 |
|
| 151 |
+
noise_patterns = args.noise_patterns
|
| 152 |
+
noise_patterns = noise_patterns.split(" ")
|
| 153 |
+
print(f"noise_patterns: {noise_patterns}")
|
| 154 |
+
speech_patterns = args.speech_patterns
|
| 155 |
+
speech_patterns = speech_patterns.split(" ")
|
| 156 |
+
print(f"speech_patterns: {speech_patterns}")
|
| 157 |
|
| 158 |
train_dataset = Path(args.train_dataset)
|
| 159 |
valid_dataset = Path(args.valid_dataset)
|
|
|
|
| 161 |
valid_dataset.parent.mkdir(parents=True, exist_ok=True)
|
| 162 |
|
| 163 |
noise_generator = target_second_noise_signal_generator(
|
| 164 |
+
noise_patterns,
|
| 165 |
duration=args.duration,
|
| 166 |
sample_rate=args.target_sample_rate,
|
| 167 |
max_epoch=100000,
|
| 168 |
)
|
| 169 |
speech_generator = target_second_speech_signal_generator(
|
| 170 |
+
speech_patterns,
|
| 171 |
min_duration=args.min_speech_duration,
|
| 172 |
max_duration=args.max_speech_duration,
|
| 173 |
sample_rate=args.target_sample_rate,
|
|
|
|
| 216 |
"random1": random1,
|
| 217 |
}
|
| 218 |
row = json.dumps(row, ensure_ascii=False)
|
| 219 |
+
if random2 < (2 / 300):
|
| 220 |
fvalid.write(f"{row}\n")
|
| 221 |
else:
|
| 222 |
ftrain.write(f"{row}\n")
|
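The bookkeeping in target_second_noise_signal_generator above is easiest to follow on plain numbers. Below is a minimal sketch, not part of the commit: pack_durations is a hypothetical stand-in that applies the same offset/rest/wait arithmetic to bare durations instead of wav files, merging the two branches via min() and dropping the bounded retry loop and the librosa I/O.

# Sketch only: mirrors the duration-packing arithmetic of
# target_second_noise_signal_generator without any audio I/O.
def pack_durations(raw_durations, duration=4.0):
    group, wait = [], duration
    for raw in raw_durations:
        offset, rest = 0.0, raw
        while rest > 0:
            take = min(rest, wait)          # either the file's remainder or the gap left in the group
            group.append((offset, take))    # (offset into file, seconds taken)
            offset += take
            rest -= take
            wait -= take
            if wait <= 0:                   # group is full: emit and start a new one
                yield group
                group, wait = [], duration
    # a trailing partial group is dropped, as in the original generator


if __name__ == "__main__":
    # Files of 2.5 s, 3.0 s and 7.0 s packed into 4-second noise groups.
    for g in pack_durations([2.5, 3.0, 7.0], duration=4.0):
        print(g)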
examples/silero_vad_by_webrtcvad/step_4_train_model.py
CHANGED
@@ -17,8 +17,6 @@ sys.path.append(os.path.join(pwd, "../../"))
 
 import numpy as np
 import torch
-import torch.nn as nn
-from torch.nn import functional as F
 from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm
 
@@ -38,7 +36,7 @@
     parser.add_argument("--valid_dataset", default="valid-vad.jsonl", type=str)
 
     parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
-    parser.add_argument("--patience", default=
+    parser.add_argument("--patience", default=10, type=int)
     parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
 
     parser.add_argument("--config_file", default="yaml/config.yaml", type=str)
@@ -74,22 +72,28 @@ class CollateFunction(object):
 
     def __call__(self, batch: List[dict]):
         noisy_audios = list()
+        clean_audios = list()
         batch_vad_segments = list()
 
         for sample in batch:
            noisy_wave: torch.Tensor = sample["noisy_wave"]
+           clean_wave: torch.Tensor = sample["clean_wave"]
            vad_segments: List[Tuple[float, float]] = sample["vad_segments"]
 
            noisy_audios.append(noisy_wave)
+           clean_audios.append(clean_wave)
            batch_vad_segments.append(vad_segments)
 
        noisy_audios = torch.stack(noisy_audios)
+       clean_audios = torch.stack(clean_audios)
 
        # assert
        if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
            raise AssertionError("nan or inf in noisy_audios")
+       if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
+           raise AssertionError("nan or inf in clean_audios")
 
-       return noisy_audios, batch_vad_segments
+       return noisy_audios, clean_audios, batch_vad_segments
 
 
 collate_fn = CollateFunction()
@@ -214,6 +218,7 @@
     average_loss = 1000000000
     average_bce_loss = 1000000000
     average_dice_loss = 1000000000
+    average_lsnr_loss = 1000000000
 
     accuracy = -1
     f1 = -1
@@ -242,6 +247,7 @@
         total_loss = 0.
         total_bce_loss = 0.
         total_dice_loss = 0.
+        total_lsnr_loss = 0.
         total_batches = 0.
 
         progress_bar_train = tqdm(
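A small sketch of the new collate contract (keys and shapes assumed from the diff above): the loader now yields a clean batch alongside the noisy batch, so the model's LSNR target can be computed during training.

# Sketch: what the updated CollateFunction hands the training loop.
import torch
from typing import List, Tuple

batch = [
    {
        "noisy_wave": torch.randn(16000),     # assumed 1-D waveforms of equal length
        "clean_wave": torch.randn(16000),
        "vad_segments": [(0.5, 1.2)],         # (start_s, end_s) speech spans
    }
    for _ in range(4)
]

noisy = torch.stack([s["noisy_wave"] for s in batch])    # [b, num_samples]
clean = torch.stack([s["clean_wave"] for s in batch])    # [b, num_samples]
segments: List[List[Tuple[float, float]]] = [s["vad_segments"] for s in batch]
print(noisy.shape, clean.shape, len(segments))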
examples/silero_vad_by_webrtcvad/yaml/config.yaml
CHANGED
@@ -11,6 +11,12 @@ win_type: hann
 in_channels: 64
 hidden_size: 128
 
+# lsnr
+n_frame: 3
+min_local_snr_db: -15
+max_local_snr_db: 30
+norm_tau: 1.
+
 # data
 min_snr_db: -10
 max_snr_db: 20
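These new bounds parameterize the model's local-SNR head, not the data-mixing SNR below them. A one-liner sketch of the mapping used further down in the modeling changes (lsnr = sigmoid(x) * (max - min) + min; constants mirror this config):

# Sketch: a sigmoid output p in [0, 1] maps linearly onto
# [min_local_snr_db, max_local_snr_db] dB.
min_db, max_db = -15.0, 30.0
for p in (0.0, 0.5, 1.0):
    print(p, p * (max_db - min_db) + min_db)   # -> -15.0 dB, 7.5 dB, 30.0 dB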
toolbox/torchaudio/models/vad/cnn_vad/inference_cnn_vad.py
ADDED
@@ -0,0 +1,138 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import argparse
+import logging
+from pathlib import Path
+import shutil
+import tempfile, time
+from typing import List
+import zipfile
+
+from scipy.io import wavfile
+import librosa
+import numpy as np
+import torch
+import torchaudio
+
+torch.set_num_threads(1)
+
+from project_settings import project_path
+from toolbox.torchaudio.models.vad.cnn_vad.configuration_cnn_vad import CNNVadConfig
+from toolbox.torchaudio.models.vad.cnn_vad.modeling_cnn_vad import CNNVadPretrainedModel, MODEL_FILE
+from toolbox.torchaudio.utils.visualization import process_speech_probs, make_visualization
+
+
+logger = logging.getLogger("toolbox")
+
+
+class InferenceSileroVad(object):
+    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
+        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
+        self.device = torch.device(device)
+
+        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
+        config, model = self.load_models(self.pretrained_model_path_or_zip_file)
+        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")
+
+        self.config = config
+        self.model = model
+        self.model.to(device)
+        self.model.eval()
+
+    def load_models(self, model_path: str):
+        model_path = Path(model_path)
+        if model_path.name.endswith(".zip"):
+            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
+                out_root = Path(tempfile.gettempdir()) / "cc_vad"
+                out_root.mkdir(parents=True, exist_ok=True)
+                f_zip.extractall(path=out_root)
+            model_path = out_root / model_path.stem
+
+        config = CNNVadConfig.from_pretrained(
+            pretrained_model_name_or_path=model_path.as_posix(),
+        )
+        model = CNNVadPretrainedModel.from_pretrained(
+            pretrained_model_name_or_path=model_path.as_posix(),
+        )
+        model.to(self.device)
+        model.eval()
+
+        shutil.rmtree(model_path)
+        return config, model
+
+    def infer(self, signal: torch.Tensor) -> float:
+        # signal shape: [num_samples,], value between -1 and 1.
+
+        inputs = torch.tensor(signal, dtype=torch.float32)
+        inputs = torch.unsqueeze(inputs, dim=0)
+        # inputs shape: [1, num_samples,]
+
+        with torch.no_grad():
+            logits, probs, lsnr = self.model.forward(inputs)
+
+        # probs shape: [b, t, 1]
+        probs = torch.squeeze(probs, dim=-1)
+        # probs shape: [b, t]
+
+        probs = probs.numpy()
+        probs = probs[0]
+        probs = probs.tolist()
+        return probs
+
+    def post_process(self, probs: List[float]):
+        return
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--wav_file",
+        # default=(project_path / "data/examples/ai_agent/chinese-4.wav").as_posix(),
+        # default=(project_path / "data/examples/ai_agent/chinese-5.wav").as_posix(),
+        # default=(project_path / "data/examples/hado/b556437e-c68b-4f6d-9eed-2977c29db887.wav").as_posix(),
+        # default=(project_path / "data/examples/hado/eae93a33-8ee0-4d86-8f85-cac5116ae6ef.wav").as_posix(),
+        # default=(project_path / "data/examples/speech/active_media_r_0ba69730-66a4-4ecd-8929-ef58f18f4612_2.wav").as_posix(),
+        # default=(project_path / "data/examples/speech/active_media_r_2a2f472b-a0b8-4fd5-b1c4-1aedc5d2ce57_0.wav").as_posix(),
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0b8a8e80-52af-423b-8877-03a78b1e6e43_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0ebffb68-6490-4a8b-8eb6-eb82443d7d75_0.wav",
+        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0f6ec933-90df-447b-aca4-6ddc149452ab_0.wav",
+        type=str,
+    )
+    args = parser.parse_args()
+    return args
+
+
+SAMPLE_RATE = 8000
+
+
+def main():
+    args = get_args()
+
+    sample_rate, signal = wavfile.read(args.wav_file)
+    if SAMPLE_RATE != sample_rate:
+        raise AssertionError
+    signal = signal / (1 << 15)
+
+    infer = InferenceSileroVad(
+        pretrained_model_path_or_zip_file=(project_path / "trained_models/cnn-vad-by-webrtcvad-nx-dns3.zip").as_posix()
+        # pretrained_model_path_or_zip_file=(project_path / "trained_models/silero-vad-webrtcvad-nx2-dns3.zip").as_posix()
+    )
+    frame_step = infer.model.hop_size
+
+    speech_probs = infer.infer(signal)
+
+    # print(speech_probs)
+
+    speech_probs = process_speech_probs(
+        signal=signal,
+        speech_probs=speech_probs,
+        frame_step=frame_step,
+    )
+
+    # plot
+    make_visualization(signal, speech_probs, SAMPLE_RATE)
+    return
+
+
+if __name__ == "__main__":
+    main()
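A hypothetical usage sketch for the class added above. The wav path is a placeholder, and the input must be 8 kHz 16-bit mono to pass the script's sample-rate assertion; note the class committed here is named InferenceSileroVad even though it wraps the CNN VAD model.

# Sketch only; "example-8k.wav" is a placeholder path.
from scipy.io import wavfile

from project_settings import project_path
from toolbox.torchaudio.models.vad.cnn_vad.inference_cnn_vad import InferenceSileroVad

sample_rate, signal = wavfile.read("example-8k.wav")
signal = signal / (1 << 15)          # int16 -> float in [-1, 1]

infer = InferenceSileroVad(
    pretrained_model_path_or_zip_file=(project_path / "trained_models/cnn-vad-by-webrtcvad-nx-dns3.zip").as_posix()
)
speech_probs = infer.infer(signal)   # one speech probability per frame
print(len(speech_probs))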
toolbox/torchaudio/models/vad/fsmn_vad/configuration_fsmn_vad.py
CHANGED
@@ -23,9 +23,12 @@ class FSMNVadConfig(PretrainedConfig):
                  fsmn_basic_block_lstride: int = 1,
                  fsmn_basic_block_rstride: int = 0,
                  fsmn_output_affine_size: int = 140,
-                 fsmn_output_size: int =
+                 fsmn_output_size: int = 2,
 
+                 n_frame: int = 3,
+                 min_local_snr_db: float = -15,
+                 max_local_snr_db: float = 30,
+                 norm_tau: float = 1.,
+
                  min_snr_db: float = -10,
                  max_snr_db: float = 20,
@@ -65,7 +68,11 @@ class FSMNVadConfig(PretrainedConfig):
         self.fsmn_output_affine_size = fsmn_output_affine_size
         self.fsmn_output_size = fsmn_output_size
 
+        # lsnr
+        self.n_frame = n_frame
+        self.min_local_snr_db = min_local_snr_db
+        self.max_local_snr_db = max_local_snr_db
+        self.norm_tau = norm_tau
 
         # data snr
         self.min_snr_db = min_snr_db
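For context, the new fields ride on the config object like any other hyper-parameter; a small sketch (defaults mirror yaml/config.yaml):

# Sketch: reading back the new lsnr hyper-parameters.
from toolbox.torchaudio.models.vad.fsmn_vad.configuration_fsmn_vad import FSMNVadConfig

config = FSMNVadConfig()
print(config.fsmn_output_size)     # 2: one VAD channel plus one LSNR channel
print(config.n_frame, config.min_local_snr_db, config.max_local_snr_db, config.norm_tau)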
toolbox/torchaudio/models/vad/fsmn_vad/modeling_fsmn_vad.py
CHANGED
@@ -15,48 +15,111 @@ from typing import Optional, Union
 
 import torch
 import torch.nn as nn
+from torch.nn import functional as F
 
 from toolbox.torchaudio.configuration_utils import CONFIG_FILE
 from toolbox.torchaudio.models.vad.fsmn_vad.configuration_fsmn_vad import FSMNVadConfig
 from toolbox.torchaudio.modules.conv_stft import ConvSTFT
 from toolbox.torchaudio.models.vad.fsmn_vad.fsmn_encoder import FSMN
+from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
 
 
 MODEL_FILE = "model.pt"
 
 
 class FSMNVadModel(nn.Module):
     def __init__(self,
+                 sample_rate: int,
+                 nfft: int,
+                 win_size: int,
+                 hop_size: int,
+                 win_type: int,
+
+                 fsmn_input_size: int,
+                 fsmn_input_affine_size: int,
+                 fsmn_hidden_size: int,
+                 fsmn_basic_block_layers: int,
+                 fsmn_basic_block_hidden_size: int,
+                 fsmn_basic_block_lorder: int,
+                 fsmn_basic_block_rorder: int,
+                 fsmn_basic_block_lstride: int,
+                 fsmn_basic_block_rstride: int,
+                 fsmn_output_affine_size: int,
+
+                 n_frame: int,
+                 min_local_snr_db: float,
+                 max_local_snr_db: float,
+                 ):
         super(FSMNVadModel, self).__init__()
+        self.sample_rate = sample_rate
+        self.nfft = nfft
+        self.win_size = win_size
+        self.hop_size = hop_size
+        self.win_type = win_type
+
+        self.fsmn_input_size = fsmn_input_size
+        self.fsmn_input_affine_size = fsmn_input_affine_size
+        self.fsmn_hidden_size = fsmn_hidden_size
+        self.fsmn_basic_block_layers = fsmn_basic_block_layers
+        self.fsmn_basic_block_hidden_size = fsmn_basic_block_hidden_size
+        self.fsmn_basic_block_lorder = fsmn_basic_block_lorder
+        self.fsmn_basic_block_rorder = fsmn_basic_block_rorder
+        self.fsmn_basic_block_lstride = fsmn_basic_block_lstride
+        self.fsmn_basic_block_rstride = fsmn_basic_block_rstride
+        self.fsmn_output_affine_size = fsmn_output_affine_size
+
+        self.n_frame = n_frame
+        self.min_local_snr_db = min_local_snr_db
+        self.max_local_snr_db = max_local_snr_db
+
         self.eps = 1e-12
 
         self.stft = ConvSTFT(
+            nfft=self.nfft,
+            win_size=self.win_size,
+            hop_size=self.hop_size,
+            win_type=self.win_type,
            power=1,
            requires_grad=False
        )
+        self.complex_stft = ConvSTFT(
+            nfft=self.nfft,
+            win_size=self.win_size,
+            hop_size=self.hop_size,
+            win_type=self.win_type,
+            power=None,
+            requires_grad=False
+        )
 
         self.fsmn_encoder = FSMN(
+            input_size=self.fsmn_input_size,
+            input_affine_size=self.fsmn_input_affine_size,
+            hidden_size=self.fsmn_hidden_size,
+            basic_block_layers=self.fsmn_basic_block_layers,
+            basic_block_hidden_size=self.fsmn_basic_block_hidden_size,
+            basic_block_lorder=self.fsmn_basic_block_lorder,
+            basic_block_rorder=self.fsmn_basic_block_rorder,
+            basic_block_lstride=self.fsmn_basic_block_lstride,
+            basic_block_rstride=self.fsmn_basic_block_rstride,
+            output_affine_size=self.fsmn_output_affine_size,
+            output_size=2,
+            # output_size=self.fsmn_output_size,
        )
 
+        # lsnr
+        self.lsnr_scale = self.max_local_snr_db - self.min_local_snr_db
+        self.lsnr_offset = self.min_local_snr_db
+
+        self.lsnr_fn = LocalSnrTarget(
+            sample_rate=self.sample_rate,
+            nfft=self.nfft,
+            win_size=self.win_size,
+            hop_size=self.hop_size,
+            n_frame=self.n_frame,
+            min_local_snr=self.min_local_snr_db,
+            max_local_snr=self.max_local_snr_db,
+            db=True,
+        )
 
     def forward(self, signal: torch.Tensor):
         if signal.dim() == 2:
@@ -71,14 +134,49 @@ class FSMNVadModel(nn.Module):
         # x shape: [b, t, f]
 
         logits, _ = self.fsmn_encoder.forward(x)
+        # logits shape: [b, t, 2]
 
+        splits = torch.split(logits, split_size_or_sections=[1, 1], dim=-1)
+        vad_logits = splits[0]
+        snr_logits = splits[1]
+        # shape: [b, t, 1]
+        vad_probs = F.sigmoid(vad_logits)
+        # vad_probs shape: [b, t, 1]
+
+        lsnr = F.sigmoid(snr_logits) * self.lsnr_scale + self.lsnr_offset
+        # lsnr shape: [b, t, 1]
+
+        return vad_logits, vad_probs, lsnr
+
+    def lsnr_loss_fn(self, lsnr: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
+        if noisy.shape != clean.shape:
+            raise AssertionError("Input signals must have the same shape")
+        noise = noisy - clean
+
+        if clean.dim() == 2:
+            clean = torch.unsqueeze(clean, dim=1)
+        if noise.dim() == 2:
+            noise = torch.unsqueeze(noise, dim=1)
+
+        stft_clean = self.complex_stft.forward(clean)
+        stft_noise = self.complex_stft.forward(noise)
+        # shape: [b, f, t]
+        stft_clean = torch.transpose(stft_clean, dim0=1, dim1=2)
+        stft_noise = torch.transpose(stft_noise, dim0=1, dim1=2)
+        # shape: [b, t, f]
+        stft_clean = torch.unsqueeze(stft_clean, dim=1)
+        stft_noise = torch.unsqueeze(stft_noise, dim=1)
+        # shape: [b, 1, t, f]
+
+        # lsnr shape: [b, 1, t]
+        lsnr = lsnr.squeeze(1)
+        # lsnr shape: [b, t]
+
+        lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise)
+        # lsnr_gth shape: [b, t]
+
+        loss = F.mse_loss(lsnr, lsnr_gth)
+        return loss
 
 
 class FSMNVadPretrainedModel(FSMNVadModel):
@@ -86,8 +184,26 @@ class FSMNVadPretrainedModel(FSMNVadModel):
                  config: FSMNVadConfig,
                  ):
         super(FSMNVadPretrainedModel, self).__init__(
+            sample_rate=config.sample_rate,
+            nfft=config.nfft,
+            win_size=config.win_size,
+            hop_size=config.hop_size,
+            win_type=config.win_type,
+            fsmn_input_size=config.fsmn_input_size,
+            fsmn_input_affine_size=config.fsmn_input_affine_size,
+            fsmn_hidden_size=config.fsmn_hidden_size,
+            fsmn_basic_block_layers=config.fsmn_basic_block_layers,
+            fsmn_basic_block_hidden_size=config.fsmn_basic_block_hidden_size,
+            fsmn_basic_block_lorder=config.fsmn_basic_block_lorder,
+            fsmn_basic_block_rorder=config.fsmn_basic_block_rorder,
+            fsmn_basic_block_lstride=config.fsmn_basic_block_lstride,
+            fsmn_basic_block_rstride=config.fsmn_basic_block_rstride,
+            fsmn_output_affine_size=config.fsmn_output_affine_size,
+            n_frame=config.n_frame,
+            min_local_snr_db=config.min_local_snr_db,
+            max_local_snr_db=config.max_local_snr_db,
        )
+        self.config = config
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
@@ -133,10 +249,11 @@ def main():
 
     noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
 
-    logits, probs = model.forward(noisy)
+    logits, probs, lsnr = model.forward(noisy)
+    print(f"logits.shape: {logits.shape}")
+    print(f"probs.shape: {probs.shape}")
+    print(f"lsnr.shape: {lsnr.shape}")
+
     return
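The modeling change above routes the FSMN's two output channels to two heads. A small self-contained sketch of that split and the sigmoid scaling (shapes taken from the forward comments; the -15/30 constants from the config defaults):

# Sketch: channel 0 is the VAD logit, channel 1 the LSNR logit.
import torch

logits = torch.randn(2, 100, 2)                        # [b, t, 2] from the FSMN encoder
vad_logits, snr_logits = torch.split(logits, [1, 1], dim=-1)
vad_probs = torch.sigmoid(vad_logits)                  # [b, t, 1], speech probability
lsnr = torch.sigmoid(snr_logits) * (30.0 - (-15.0)) + (-15.0)   # [b, t, 1], dB in [-15, 30]
print(vad_probs.shape, float(lsnr.min()), float(lsnr.max()))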
toolbox/torchaudio/models/vad/fsmn_vad/yaml/{config-sigmoid.yaml → config.yaml}
RENAMED
@@ -18,9 +18,13 @@ fsmn_basic_block_rorder: 0
 fsmn_basic_block_lstride: 1
 fsmn_basic_block_rstride: 0
 fsmn_output_affine_size: 140
-fsmn_output_size:
+fsmn_output_size: 2
 
+# lsnr
+n_frame: 3
+min_local_snr_db: -15
+max_local_snr_db: 30
+norm_tau: 1.
 
 # data
 min_snr_db: -10
toolbox/torchaudio/models/vad/silero_vad/configuration_silero_vad.py
CHANGED
@@ -16,6 +16,11 @@ class SileroVadConfig(PretrainedConfig):
                  in_channels: int = 64,
                  hidden_size: int = 128,
 
+                 n_frame: int = 3,
+                 min_local_snr_db: float = -15,
+                 max_local_snr_db: float = 30,
+                 norm_tau: float = 1.,
+
                  min_snr_db: float = -10,
                  max_snr_db: float = 20,
 
@@ -45,6 +50,12 @@ class SileroVadConfig(PretrainedConfig):
         self.in_channels = in_channels
         self.hidden_size = hidden_size
 
+        # lsnr
+        self.n_frame = n_frame
+        self.min_local_snr_db = min_local_snr_db
+        self.max_local_snr_db = max_local_snr_db
+        self.norm_tau = norm_tau
+
         # data snr
         self.min_snr_db = min_snr_db
         self.max_snr_db = max_snr_db
toolbox/torchaudio/models/vad/silero_vad/modeling_silero_vad.py
CHANGED
@@ -13,10 +13,12 @@ from typing import Optional, Union
 
 import torch
 import torch.nn as nn
+from torch.nn import functional as F
 
 from toolbox.torchaudio.configuration_utils import CONFIG_FILE
 from toolbox.torchaudio.models.vad.silero_vad.configuration_silero_vad import SileroVadConfig
 from toolbox.torchaudio.modules.conv_stft import ConvSTFT
+from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
 
 
 MODEL_FILE = "model.pt"
@@ -80,50 +82,99 @@ class Encoder(nn.Module):
 
 
 class SileroVadModel(nn.Module):
     def __init__(self,
+                 sample_rate: int,
+                 nfft: int,
+                 win_size: int,
+                 hop_size: int,
+                 win_type: int,
+
+                 in_channels: int,
+                 hidden_size: int,
+
+                 n_frame: int,
+                 min_local_snr_db: float,
+                 max_local_snr_db: float,
+
+                 ):
         super(SileroVadModel, self).__init__()
+        self.sample_rate = sample_rate
+        self.nfft = nfft
+        self.win_size = win_size
+        self.hop_size = hop_size
+        self.win_type = win_type
+
+        self.in_channels = in_channels
+        self.hidden_size = hidden_size
+
+        self.n_frame = n_frame
+        self.min_local_snr_db = min_local_snr_db
+        self.max_local_snr_db = max_local_snr_db
 
-        self.config = config
         self.eps = 1e-12
 
         self.stft = ConvSTFT(
+            nfft=nfft,
+            win_size=win_size,
+            hop_size=hop_size,
+            win_type=win_type,
             power=1,
             requires_grad=False
         )
+        self.complex_stft = ConvSTFT(
+            nfft=nfft,
+            win_size=win_size,
+            hop_size=hop_size,
+            win_type=win_type,
+            power=None,
+            requires_grad=False
+        )
 
         self.linear = nn.Linear(
+            in_features=(self.nfft // 2 + 1),
+            out_features=self.in_channels,
         )
 
         self.encoder = Encoder(
+            in_channels=self.in_channels,
+            out_channels=self.hidden_size,
        )
 
         self.lstm = nn.LSTM(
+            input_size=self.hidden_size,
+            hidden_size=self.hidden_size,
             bidirectional=False,
             batch_first=True
         )
 
+        # vad
+        self.vad_fc = nn.Sequential(
+            nn.Linear(self.hidden_size, 32),
             nn.ReLU(),
             nn.Linear(32, 1),
         )
         self.sigmoid = nn.Sigmoid()
 
+        # lsnr
+        self.lsnr_fc = nn.Sequential(
+            nn.Linear(self.hidden_size, 1),
+            nn.Sigmoid()
+        )
+        self.lsnr_scale = self.max_local_snr_db - self.min_local_snr_db
+        self.lsnr_offset = self.min_local_snr_db
+
+        # lsnr
+        self.lsnr_fn = LocalSnrTarget(
+            sample_rate=self.sample_rate,
+            nfft=self.nfft,
+            win_size=self.win_size,
+            hop_size=self.hop_size,
+            n_frame=self.n_frame,
+            min_local_snr=self.min_local_snr_db,
+            max_local_snr=self.max_local_snr_db,
+            db=True,
+        )
+
     def forward(self, signal: torch.Tensor):
         if signal.dim() == 2:
             signal = torch.unsqueeze(signal, dim=1)
@@ -143,40 +194,46 @@ class SileroVadModel(nn.Module):
         # x shape: [b, t, f]
 
         x, _ = self.lstm.forward(x)
+
+        logits = self.vad_fc.forward(x)
         # logits shape: [b, t, 1]
         probs = self.sigmoid.forward(logits)
         # probs shape: [b, t, 1]
-        return logits, probs
-
-    # 
-        # mags shape: [b, f, t]
-        _, _, num_samples = signal.shape
-        # signal shape [b, 1, num_samples]
-        for i in range(int(t)):
-            begin = i * self.hop_size
-            end = begin + self.win_size
-            sub_signal = signal[:, :, begin: end]
 
+        lsnr = self.lsnr_fc.forward(x) * self.lsnr_scale + self.lsnr_offset
+        # lsnr shape: [b, t, 1]
 
+        return logits, probs, lsnr
+
+    def lsnr_loss_fn(self, lsnr: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
+        if noisy.shape != clean.shape:
+            raise AssertionError("Input signals must have the same shape")
+        noise = noisy - clean
+
+        if clean.dim() == 2:
+            clean = torch.unsqueeze(clean, dim=1)
+        if noise.dim() == 2:
+            noise = torch.unsqueeze(noise, dim=1)
+
+        stft_clean = self.complex_stft.forward(clean)
+        stft_noise = self.complex_stft.forward(noise)
+        # shape: [b, f, t]
+        stft_clean = torch.transpose(stft_clean, dim0=1, dim1=2)
+        stft_noise = torch.transpose(stft_noise, dim0=1, dim1=2)
+        # shape: [b, t, f]
+        stft_clean = torch.unsqueeze(stft_clean, dim=1)
+        stft_noise = torch.unsqueeze(stft_noise, dim=1)
+        # shape: [b, 1, t, f]
+
+        # lsnr shape: [b, 1, t]
+        lsnr = lsnr.squeeze(1)
+        # lsnr shape: [b, t]
+
+        lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise)
+        # lsnr_gth shape: [b, t]
+
+        loss = F.mse_loss(lsnr, lsnr_gth)
+        return loss
 
 
 class SileroVadPretrainedModel(SileroVadModel):
@@ -184,8 +241,18 @@ class SileroVadPretrainedModel(SileroVadModel):
                  config: SileroVadConfig,
                  ):
         super(SileroVadPretrainedModel, self).__init__(
+            sample_rate=config.sample_rate,
+            nfft=config.nfft,
+            win_size=config.win_size,
+            hop_size=config.hop_size,
+            win_type=config.win_type,
+            in_channels=config.in_channels,
+            hidden_size=config.hidden_size,
+            n_frame=config.n_frame,
+            min_local_snr_db=config.min_local_snr_db,
+            max_local_snr_db=config.max_local_snr_db,
        )
+        self.config = config
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
@@ -227,13 +294,14 @@ class SileroVadPretrainedModel(SileroVadModel):
 
 def main():
     config = SileroVadConfig()
+    model = SileroVadPretrainedModel(config=config)
 
     noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
 
-    logits, probs = model.forward(noisy)
-    print(f"logits: {probs}")
+    logits, probs, lsnr = model.forward(noisy)
     print(f"logits.shape: {logits.shape}")
+    print(f"probs.shape: {probs.shape}")
+    print(f"lsnr.shape: {lsnr.shape}")
 
     return
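LocalSnrTarget itself is not part of this diff, so the following is only an illustrative sketch of the kind of target it supplies to lsnr_loss_fn: per-frame SNR in dB, smoothed over n_frame frames and clamped to the configured range. All internals here are assumptions, not the module's actual implementation.

# Hypothetical, simplified stand-in for a local-SNR target.
import torch
import torch.nn.functional as F

def local_snr_db(stft_clean, stft_noise, n_frame=3, lo=-15.0, hi=30.0, eps=1e-12):
    # stft_clean, stft_noise: complex STFTs, shape [b, t, f]
    e_clean = stft_clean.abs().pow(2).sum(dim=-1)     # [b, t] frame energies
    e_noise = stft_noise.abs().pow(2).sum(dim=-1)
    kernel = torch.ones(1, 1, n_frame) / n_frame      # moving average over n_frame frames
    pad = n_frame // 2
    e_clean = F.conv1d(e_clean.unsqueeze(1), kernel, padding=pad).squeeze(1)
    e_noise = F.conv1d(e_noise.unsqueeze(1), kernel, padding=pad).squeeze(1)
    snr = 10.0 * torch.log10((e_clean + eps) / (e_noise + eps))
    return snr.clamp(lo, hi)                          # [b, t], dB in [lo, hi]

b, t, f = 1, 50, 129
print(local_snr_db(torch.randn(b, t, f, dtype=torch.complex64),
                   torch.randn(b, t, f, dtype=torch.complex64)).shape)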
toolbox/torchaudio/models/vad/silero_vad/yaml/config.yaml
CHANGED
@@ -11,6 +11,12 @@ win_type: hann
 in_channels: 64
 hidden_size: 128
 
+# lsnr
+n_frame: 3
+min_local_snr_db: -15
+max_local_snr_db: 30
+norm_tau: 1.
+
 # data
 min_snr_db: -10
 max_snr_db: 20