Spaces:
Running
on
Zero
Running
on
Zero
| from typing import Tuple, Optional | |
| import numpy as np | |
| import torch | |
| from torch.nn import functional as F | |
| import torch.nn as nn | |
| class ProfileAug(nn.Module): | |
| """ | |
| Implement the augmentation for profiles including: | |
| - Split aug: split one profile into two profiles, i.e., main and inaccurate, labels assigned to main | |
| - Merge aug: merge two profiles into one, labels are also merged into one, the other set to zero | |
| - Disturb aug: disturb some profile with others to simulate the inaccurate clustering centroids. | |
| """ | |
| def __init__( | |
| self, | |
| apply_split_aug: bool = True, | |
| split_aug_prob: float = 0.05, | |
| apply_merge_aug: bool = True, | |
| merge_aug_prob: float = 0.2, | |
| apply_disturb_aug: bool = True, | |
| disturb_aug_prob: float = 0.4, | |
| disturb_alpha: float = 0.2, | |
| ) -> None: | |
| super().__init__() | |
| self.apply_split_aug = apply_split_aug | |
| self.split_aug_prob = split_aug_prob | |
| self.apply_merge_aug = apply_merge_aug | |
| self.merge_aug_prob = merge_aug_prob | |
| self.apply_disturb_aug = apply_disturb_aug | |
| self.disturb_aug_prob = disturb_aug_prob | |
| self.disturb_alpha = disturb_alpha | |
| def split_aug( | |
| self, profile: torch.Tensor, binary_labels: torch.Tensor, mask: torch.Tensor | |
| ): | |
| # B, N | |
| bsz, dim = profile.shape[0], profile.shape[-1] | |
| profile_norm = torch.linalg.norm(profile, dim=-1, keepdim=False) | |
| spk_count = binary_labels.sum(dim=1) | |
| prob = np.random.rand(bsz) | |
| batch_indices = np.nonzero(prob < self.split_aug_prob)[0] | |
| for idx in batch_indices: | |
| valid_spk_idx = torch.nonzero(spk_count[idx] * mask[idx]) | |
| pad_spk_idx = torch.nonzero((spk_count[idx] == 0) * mask[idx]) | |
| if len(valid_spk_idx) == 0 or len(pad_spk_idx) == 0: | |
| continue | |
| split_spk_idx = valid_spk_idx[torch.randint(len(valid_spk_idx), ())] | |
| to_cover_idx = pad_spk_idx[torch.randint(len(pad_spk_idx), ())] | |
| disturb_vec = torch.randn((dim,)).to(profile) | |
| disturb_vec = F.normalize(disturb_vec, dim=-1) | |
| profile[idx, to_cover_idx] = F.normalize( | |
| profile[idx, split_spk_idx] + self.disturb_alpha * disturb_vec | |
| ) | |
| mask[idx, split_spk_idx] = 0 | |
| mask[idx, to_cover_idx] = 0 | |
| return profile, binary_labels, mask | |
| def merge_aug( | |
| self, profile: torch.Tensor, binary_labels: torch.Tensor, mask: torch.Tensor | |
| ): | |
| bsz, dim = profile.shape[0], profile.shape[-1] | |
| profile_norm = torch.linalg.norm(profile, dim=-1, keepdim=False) | |
| spk_count = binary_labels.sum(dim=1) | |
| prob = np.random.rand(bsz) | |
| batch_indices = np.nonzero(prob < self.merge_aug_prob)[0] | |
| for idx in batch_indices: | |
| valid_spk_idx = torch.nonzero(profile_norm[idx] * mask[idx]) | |
| if len(valid_spk_idx) == 0: | |
| continue | |
| to_merge = torch.randint(len(valid_spk_idx), (2,)) | |
| spk_idx_1, spk_idx_2 = ( | |
| valid_spk_idx[to_merge[0]], | |
| valid_spk_idx[to_merge[1]], | |
| ) | |
| # merge profile | |
| profile[idx, spk_idx_1] = profile[idx, spk_idx_1] + profile[idx, spk_idx_2] | |
| profile[idx, spk_idx_1] = F.normalize(profile[idx, spk_idx_1], dim=-1) | |
| profile[idx, spk_idx_2] = 0 | |
| # merge binary labels | |
| binary_labels[idx, :, spk_idx_1] = ( | |
| binary_labels[idx, :, spk_idx_1] + binary_labels[idx, :, spk_idx_2] | |
| ) | |
| binary_labels[idx, :, spk_idx_1] = ( | |
| binary_labels[idx, :, spk_idx_1] > 0 | |
| ).to(binary_labels) | |
| binary_labels[idx, :, spk_idx_2] = 0 | |
| mask[idx, spk_idx_1] = 0 | |
| mask[idx, spk_idx_2] = 0 | |
| return profile, binary_labels, mask | |
| def disturb_aug( | |
| self, profile: torch.Tensor, binary_labels: torch.Tensor, mask: torch.Tensor | |
| ): | |
| bsz, dim = profile.shape[0], profile.shape[-1] | |
| profile_norm = torch.linalg.norm(profile, dim=-1, keepdim=False) | |
| spk_count = binary_labels.sum(dim=1) | |
| prob = np.random.rand(bsz) | |
| batch_indices = np.nonzero(prob < self.disturb_aug_prob)[0] | |
| for idx in batch_indices: | |
| pos_spk_idx = torch.nonzero(spk_count[idx] * mask[idx]) | |
| valid_spk_idx = torch.nonzero(profile_norm[idx] * mask[idx]) | |
| if len(pos_spk_idx) == 0 or len(valid_spk_idx) == 0: | |
| continue | |
| to_disturb_idx = pos_spk_idx[torch.randint(len(pos_spk_idx), ())] | |
| disturb_idx = valid_spk_idx[torch.randint(len(valid_spk_idx), ())] | |
| alpha = self.disturb_alpha * torch.rand(()).item() | |
| profile[idx, to_disturb_idx] = (1 - alpha) * profile[ | |
| idx, to_disturb_idx | |
| ] + alpha * profile[idx, disturb_idx] | |
| profile[idx, to_disturb_idx] = F.normalize( | |
| profile[idx, to_disturb_idx], dim=-1 | |
| ) | |
| mask[idx, to_disturb_idx] = 0 | |
| return profile, binary_labels, mask | |
| def forward( | |
| self, | |
| speech: torch.Tensor, | |
| speech_lengths: torch.Tensor = None, | |
| profile: torch.Tensor = None, | |
| profile_lengths: torch.Tensor = None, | |
| binary_labels: torch.Tensor = None, | |
| labels_length: torch.Tensor = None, | |
| ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: | |
| # copy inputs to avoid inplace-operation | |
| speech, profile, binary_labels = ( | |
| torch.clone(speech), | |
| torch.clone(profile), | |
| torch.clone(binary_labels), | |
| ) | |
| profile = F.normalize(profile, dim=-1) | |
| profile_mask = torch.ones(profile.shape[:2]).to(profile) | |
| if self.apply_disturb_aug: | |
| profile, binary_labels, profile_mask = self.disturb_aug( | |
| profile, binary_labels, profile_mask | |
| ) | |
| if self.apply_split_aug: | |
| profile, binary_labels, profile_mask = self.split_aug( | |
| profile, binary_labels, profile_mask | |
| ) | |
| if self.apply_merge_aug: | |
| profile, binary_labels, profile_mask = self.merge_aug( | |
| profile, binary_labels, profile_mask | |
| ) | |
| return speech, profile, binary_labels | |