update
Browse files
examples/cnn_vad_by_webrtcvad/step_4_train_model.py
CHANGED
|
@@ -272,7 +272,7 @@ def main():
|
|
| 272 |
dice_loss = dice_loss_fn.forward(probs, targets)
|
| 273 |
lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
|
| 274 |
|
| 275 |
-
loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.
|
| 276 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
| 277 |
logger.info(f"find nan or inf in loss. continue.")
|
| 278 |
continue
|
|
@@ -352,7 +352,7 @@ def main():
|
|
| 352 |
dice_loss = dice_loss_fn.forward(probs, targets)
|
| 353 |
lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
|
| 354 |
|
| 355 |
-
loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.
|
| 356 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
| 357 |
logger.info(f"find nan or inf in loss. continue.")
|
| 358 |
continue
|
|
|
|
| 272 |
dice_loss = dice_loss_fn.forward(probs, targets)
|
| 273 |
lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
|
| 274 |
|
| 275 |
+
loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.03 * lsnr_loss
|
| 276 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
| 277 |
logger.info(f"find nan or inf in loss. continue.")
|
| 278 |
continue
|
|
|
|
| 352 |
dice_loss = dice_loss_fn.forward(probs, targets)
|
| 353 |
lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
|
| 354 |
|
| 355 |
+
loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.03 * lsnr_loss
|
| 356 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
| 357 |
logger.info(f"find nan or inf in loss. continue.")
|
| 358 |
continue
|
toolbox/torchaudio/models/vad/cnn_vad/modeling_cnn_vad.py
CHANGED
|
@@ -201,10 +201,6 @@ class CNNVadModel(nn.Module):
|
|
| 201 |
raise AssertionError("Input signals must have the same shape")
|
| 202 |
noise = noisy - clean
|
| 203 |
|
| 204 |
-
print(f"lsnr: {lsnr.shape}")
|
| 205 |
-
print(f"clean: {clean.shape}")
|
| 206 |
-
print(f"noisy: {noisy.shape}")
|
| 207 |
-
|
| 208 |
if clean.dim() == 2:
|
| 209 |
clean = torch.unsqueeze(clean, dim=1)
|
| 210 |
if noise.dim() == 2:
|
|
@@ -227,9 +223,6 @@ class CNNVadModel(nn.Module):
|
|
| 227 |
lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise)
|
| 228 |
# lsnr_gth shape: [b, t]
|
| 229 |
|
| 230 |
-
print(f"lsnr: {lsnr.shape}")
|
| 231 |
-
print(f"lsnr_gth: {lsnr_gth.shape}")
|
| 232 |
-
|
| 233 |
loss = F.mse_loss(lsnr, lsnr_gth)
|
| 234 |
return loss
|
| 235 |
|
|
|
|
| 201 |
raise AssertionError("Input signals must have the same shape")
|
| 202 |
noise = noisy - clean
|
| 203 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
if clean.dim() == 2:
|
| 205 |
clean = torch.unsqueeze(clean, dim=1)
|
| 206 |
if noise.dim() == 2:
|
|
|
|
| 223 |
lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise)
|
| 224 |
# lsnr_gth shape: [b, t]
|
| 225 |
|
|
|
|
|
|
|
|
|
|
| 226 |
loss = F.mse_loss(lsnr, lsnr_gth)
|
| 227 |
return loss
|
| 228 |
|