Spaces:
Running
on
Zero
Running
on
Zero
| from contextlib import contextmanager | |
| from distutils.version import LooseVersion | |
| from typing import Dict, List, Tuple, Optional | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from funasr_detach.frontends.wav_frontend import WavFrontendMel23 | |
| from funasr_detach.models.eend.encoder import EENDOLATransformerEncoder | |
| from funasr_detach.models.eend.encoder_decoder_attractor import EncoderDecoderAttractor | |
| from funasr_detach.models.eend.utils.losses import ( | |
| standard_loss, | |
| cal_power_loss, | |
| fast_batch_pit_n_speaker_loss, | |
| ) | |
| from funasr_detach.models.eend.utils.power import create_powerlabel | |
| from funasr_detach.models.eend.utils.power import generate_mapping_dict | |
| from funasr_detach.train_utils.device_funcs import force_gatherable | |
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    # Nothing is defined by this module for torch>=1.6.0.
    pass
else:
    # Nothing to do if torch<1.6.0: expose a no-op stand-in named ``autocast``.
    # The decorator is required so the generator below can actually be used
    # in a ``with autocast(...):`` statement.
    @contextmanager
    def autocast(enabled=True):
        """No-op context manager; ``enabled`` is accepted and ignored."""
        yield
def pad_attractor(att, max_n_speakers):
    """Zero-pad a 2-D attractor matrix along dim 0 up to ``max_n_speakers`` rows.

    Args:
        att: 2-D tensor; its first dimension is the one that gets padded.
        max_n_speakers: Target number of rows.

    Returns:
        ``att`` itself when it already has at least ``max_n_speakers`` rows,
        otherwise a new tensor with float32 zero rows appended on ``att``'s
        device.
    """
    n_rows, n_cols = att.shape
    n_missing = max_n_speakers - n_rows
    if n_missing <= 0:
        return att
    filler = torch.zeros(
        n_missing, n_cols, dtype=torch.float32, device=att.device
    )
    return torch.cat([att, filler], dim=0)
def pad_labels(ts, out_size):
    """Right-pad every label tensor in ``ts`` with zeros along dim 1.

    The list is modified in place: entries whose second dimension is smaller
    than ``out_size`` are replaced by a zero-padded copy; the others are left
    untouched.

    Args:
        ts: Mutable sequence of tensors with at least two dimensions.
        out_size: Target size of dimension 1.

    Returns:
        The same ``ts`` object that was passed in.
    """
    for idx, label in enumerate(ts):
        width = label.shape[1]
        if width >= out_size:
            continue
        # F.pad takes pads for the last dimension first: (left, right) for the
        # last dim, then (left, right) for the one before it.
        pad_spec = (0, out_size - width, 0, 0)
        ts[idx] = F.pad(label, pad_spec, mode="constant", value=0.0)
    return ts
def pad_results(ys, out_size):
    """Return a new list where every tensor is zero-padded along dim 1.

    Args:
        ys: Iterable of 2-D tensors.
        out_size: Target size of dimension 1.

    Returns:
        A list with one entry per input tensor. Tensors whose second
        dimension is already at least ``out_size`` are passed through
        unchanged; the others get float32 zero columns appended on the
        tensor's own device.
    """
    padded = []
    for y in ys:
        deficit = out_size - y.shape[1]
        if deficit <= 0:
            padded.append(y)
            continue
        filler = torch.zeros(
            y.shape[0], deficit, dtype=torch.float32, device=y.device
        )
        padded.append(torch.cat([y, filler], dim=1))
    return padded
class DiarEENDOLAModel(nn.Module):
    """EEND-OLA diarization model.

    Wires together an encoder, an encoder-decoder attractor module and an
    LSTM + linear "PostNet" whose output dimension is
    ``mapping_dict["oov"] + 1``.
    """

    def __init__(
        self,
        frontend: Optional[WavFrontendMel23],
        encoder: EENDOLATransformerEncoder,
        encoder_decoder_attractor: EncoderDecoderAttractor,
        n_units: int = 256,
        max_n_speaker: int = 8,
        attractor_loss_weight: float = 1.0,
        mapping_dict=None,
        **kwargs,
    ):
        """Build the model.

        Args:
            frontend: Stored on the instance; not used by the methods of this
                class.
            encoder: Called as ``encoder(xs, xs_mask)`` in ``forward_encoder``.
            encoder_decoder_attractor: Called directly in ``forward`` and via
                its ``estimate`` method in ``estimate_sequential``.
            n_units: Hidden size of the PostNet LSTM.
            max_n_speaker: Number of attractor rows after padding/truncation;
                also the PostNet LSTM input size.
            attractor_loss_weight: Weight of the attractor loss in the total
                loss.
            mapping_dict: Dict with at least the keys ``"oov"`` and
                ``"label2dec"``. Generated with ``generate_mapping_dict`` when
                ``None``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.frontend = frontend
        self.enc = encoder
        self.encoder_decoder_attractor = encoder_decoder_attractor
        self.attractor_loss_weight = attractor_loss_weight
        self.max_n_speaker = max_n_speaker
        if mapping_dict is None:
            mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker)
        self.mapping_dict = mapping_dict
        # PostNet
        self.postnet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
        self.output_layer = nn.Linear(n_units, mapping_dict["oov"] + 1)

    def forward_encoder(self, xs, ilens):
        """Run the encoder on a batch of variable-length sequences.

        Args:
            xs: List of tensors, padded here to a common length with -1.
            ilens: Per-sequence lengths used to build the mask and to trim
                the outputs.

        Returns:
            List with one embedding tensor per input, trimmed to its length.
        """
        xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1)
        pad_shape = xs.shape
        # Mask is 1 on valid frames and 0 on padding, with an extra axis
        # inserted before the time axis.
        xs_mask = [torch.ones(ilen).to(xs.device) for ilen in ilens]
        xs_mask = torch.nn.utils.rnn.pad_sequence(
            xs_mask, batch_first=True, padding_value=0
        ).unsqueeze(-2)
        emb = self.enc(xs, xs_mask)
        # Reshape to (batch, padded_len, -1), then split into per-item tensors.
        emb = torch.split(emb.view(pad_shape[0], pad_shape[1], -1), 1, dim=0)
        emb = [e[0][:ilen] for e, ilen in zip(emb, ilens)]
        return emb

    def forward_post_net(self, logits, ilens):
        """Apply the PostNet LSTM and output layer to each sequence.

        Args:
            logits: List of per-sequence tensors fed to the LSTM.
            ilens: Tensor of sequence lengths.

        Returns:
            List of output-layer results, one per sequence, trimmed to the
            sequence length.
        """
        maxlen = torch.max(ilens).to(torch.int).item()
        logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
        # Packing lets the LSTM skip padded frames; lengths must live on CPU.
        logits = nn.utils.rnn.pack_padded_sequence(
            logits, ilens.cpu().to(torch.int64), batch_first=True, enforce_sorted=False
        )
        outputs, (_, _) = self.postnet(logits)
        outputs = nn.utils.rnn.pad_packed_sequence(
            outputs, batch_first=True, padding_value=-1, total_length=maxlen
        )[0]
        outputs = [
            output[: ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)
        ]
        outputs = [self.output_layer(output) for output in outputs]
        return outputs

    def forward(
        self,
        speech: List[torch.Tensor],
        speaker_labels: List[torch.Tensor],
        orders: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Compute the training loss for a batch.

        Args:
            speech: List of input tensors; ``len(sph)`` is used as each
                item's length.
            speaker_labels: List of label tensors; ``spk.shape[-1]`` of each
                is passed to the attractor module.
            orders: Per-item index used to reorder the encoder output before
                it is fed to the attractor module.

        Returns:
            ``(loss, stats, weight)`` as produced by ``force_gatherable``,
            where ``loss`` is ``pse_loss + pit_loss +
            attractor_loss_weight * attractor_loss`` and ``weight`` is derived
            from the batch size.
        """
        # Check that batch_size is unified
        assert len(speech) == len(speaker_labels), (len(speech), len(speaker_labels))
        speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64)
        speaker_labels_lengths = torch.tensor(
            [spk.shape[-1] for spk in speaker_labels]
        ).to(torch.int64)
        batch_size = len(speech)

        # Encoder
        encoder_out = self.forward_encoder(speech, speech_lengths)

        # Encoder-decoder attractor
        attractor_loss, attractors = self.encoder_decoder_attractor(
            [e[order] for e, order in zip(encoder_out, orders)], speaker_labels_lengths
        )
        speaker_logits = [
            torch.matmul(e, att.permute(1, 0))
            for e, att in zip(encoder_out, attractors)
        ]

        # pit loss
        pit_speaker_labels = fast_batch_pit_n_speaker_loss(
            speaker_logits, speaker_labels
        )
        pit_loss = standard_loss(speaker_logits, pit_speaker_labels)

        # pse loss
        with torch.no_grad():
            power_ts = [
                create_powerlabel(
                    label.cpu().numpy(), self.mapping_dict, self.max_n_speaker
                ).to(encoder_out[0].device, non_blocking=True)
                for label in pit_speaker_labels
            ]
        pad_attractors = [pad_attractor(att, self.max_n_speaker) for att in attractors]
        pse_speaker_logits = [
            torch.matmul(e, pad_att.permute(1, 0))
            for e, pad_att in zip(encoder_out, pad_attractors)
        ]
        pse_speaker_logits = self.forward_post_net(pse_speaker_logits, speech_lengths)
        pse_loss = cal_power_loss(pse_speaker_logits, power_ts)

        loss = pse_loss + pit_loss + self.attractor_loss_weight * attractor_loss

        stats = dict()
        stats["pse_loss"] = pse_loss.detach()
        stats["pit_loss"] = pit_loss.detach()
        stats["attractor_loss"] = attractor_loss.detach()
        stats["batch_size"] = batch_size

        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def estimate_sequential(
        self,
        speech: torch.Tensor,
        n_speakers: Optional[int] = None,
        shuffle: bool = True,
        threshold: float = 0.5,
        **kwargs,
    ):
        """Estimate per-frame speaker decisions for a batch.

        Args:
            speech: Iterable of input tensors; ``len(sph)`` is used as each
                item's length.
            n_speakers: When truthy and non-negative, keep exactly this many
                leading attractors per item.
            shuffle: When true, the frames of each embedding are randomly
                permuted (``np.random.shuffle``) before attractor estimation.
            threshold: Used when ``n_speakers`` is not applied. Attractors are
                kept up to (not including) the first one whose probability is
                below ``threshold``; all are kept when none is below it.
            **kwargs: Accepted and ignored.

        Returns:
            ``(ys, emb, attractors, raw_n_speakers)``: the decisions from
            ``recover_y_from_powerlabel``, the encoder embeddings, the
            attractors padded or truncated to ``max_n_speaker`` rows, and the
            number of active attractors per item before padding/truncation.

        Raises:
            NotImplementedError: If neither ``n_speakers`` nor ``threshold``
                can be used to select the attractors.
        """
        speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64)
        emb = self.forward_encoder(speech, speech_lengths)
        if shuffle:
            orders = [np.arange(e.shape[0]) for e in emb]
            for order in orders:
                np.random.shuffle(order)
            attractors, probs = self.encoder_decoder_attractor.estimate(
                [
                    e[torch.from_numpy(order).to(torch.long).to(speech[0].device)]
                    for e, order in zip(emb, orders)
                ]
            )
        else:
            attractors, probs = self.encoder_decoder_attractor.estimate(emb)

        attractors_active = []
        for p, att, _ in zip(probs, attractors, emb):
            if n_speakers and n_speakers >= 0:
                att = att[:n_speakers,]
            elif threshold is not None:
                # Index of the first attractor whose probability falls below
                # the threshold. torch.nonzero returns an empty tensor when
                # there is none, in which case all attractors are kept.
                below = torch.nonzero(p < threshold)
                n_spk = int(below[0][0]) if below.numel() > 0 else None
                att = att[:n_spk,]
            else:
                raise NotImplementedError("n_speakers or threshold has to be given.")
            attractors_active.append(att)

        raw_n_speakers = [att.shape[0] for att in attractors_active]
        # The PostNet expects exactly max_n_speaker columns: pad or truncate.
        attractors = [
            (
                pad_attractor(att, self.max_n_speaker)
                if att.shape[0] <= self.max_n_speaker
                else att[: self.max_n_speaker]
            )
            for att in attractors_active
        ]
        ys = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(emb, attractors)]
        logits = self.forward_post_net(ys, speech_lengths)
        ys = [
            self.recover_y_from_powerlabel(logit, raw_n_speaker)
            for logit, raw_n_speaker in zip(logits, raw_n_speakers)
        ]
        return ys, emb, attractors, raw_n_speakers

    def recover_y_from_powerlabel(self, logit, n_speaker):
        """Turn PostNet logits into per-speaker binary decisions.

        Args:
            logit: Tensor whose last dimension indexes the output classes.
            n_speaker: Number of leading speaker columns to keep.

        Returns:
            float32 tensor on ``logit``'s device with one row per prediction
            and at most ``n_speaker`` columns of 0/1 values.
        """
        pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1)
        # Replace "oov" predictions by the previous frame's prediction
        # (or by 0 for the very first frame).
        oov_index = torch.where(pred == self.mapping_dict["oov"])[0]
        for i in oov_index:
            if i > 0:
                pred[i] = pred[i - 1]
            else:
                pred[i] = 0
        pred = [self.inv_mapping_func(i) for i in pred]
        # Binary expansion, zero-filled to max_n_speaker digits and reversed
        # so that position k holds bit k (least significant bit first).
        decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
        decisions = (
            torch.from_numpy(
                np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)
            )
            .to(logit.device)
            .to(torch.float32)
        )
        decisions = decisions[:, :n_speaker]
        return decisions

    def inv_mapping_func(self, label):
        """Map a class label to its value in ``mapping_dict["label2dec"]``.

        Args:
            label: Label; converted with ``int()`` when it is not an ``int``.

        Returns:
            The mapped value, or ``-1`` when the label is not in the mapping.
        """
        if not isinstance(label, int):
            label = int(label)
        if label in self.mapping_dict["label2dec"].keys():
            num = self.mapping_dict["label2dec"][label]
        else:
            num = -1
        return num

    def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Placeholder: does nothing and returns ``None``."""
        pass