Spaces: Running on Zero
| #!/usr/bin/env python3 | |
| # -*- encoding: utf-8 -*- | |
| # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. | |
| # MIT License (https://opensource.org/licenses/MIT) | |
| import torch | |
| from typing import List, Optional, Tuple | |
| from funasr_detach.register import tables | |
| from funasr_detach.models.specaug.specaug import SpecAug | |
| from funasr_detach.models.transducer.beam_search_transducer import Hypothesis | |
class RNNTDecoder(torch.nn.Module):
    """RNN decoder (prediction network) for transducer models.

    Args:
        vocab_size: Vocabulary size.
        embed_size: Embedding size.
        hidden_size: Hidden size.
        rnn_type: Decoder layers type ("lstm" or "gru").
        num_layers: Number of decoder layers.
        dropout_rate: Dropout rate for decoder layers.
        embed_dropout_rate: Dropout rate for embedding layer.
        embed_pad: Embedding padding symbol ID.
        use_embed_mask: Apply SpecAug-style time masking to the label
            embeddings during training.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_size: int = 256,
        hidden_size: int = 256,
        rnn_type: str = "lstm",
        num_layers: int = 1,
        dropout_rate: float = 0.0,
        embed_dropout_rate: float = 0.0,
        embed_pad: int = 0,
        use_embed_mask: bool = False,
    ) -> None:
        """Construct a RNNTDecoder object."""
        super().__init__()

        if rnn_type not in ("lstm", "gru"):
            raise ValueError(f"Not supported: rnn_type={rnn_type}")

        self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
        self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate)

        # One single-layer RNN module per decoder layer so that dropout can be
        # applied between layers in rnn_forward.
        rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU
        self.rnn = torch.nn.ModuleList(
            [rnn_class(embed_size, hidden_size, 1, batch_first=True)]
        )
        for _ in range(1, num_layers):
            self.rnn += [rnn_class(hidden_size, hidden_size, 1, batch_first=True)]

        self.dropout_rnn = torch.nn.ModuleList(
            [torch.nn.Dropout(p=dropout_rate) for _ in range(num_layers)]
        )

        self.dlayers = num_layers
        self.dtype = rnn_type
        self.output_size = hidden_size
        self.vocab_size = vocab_size

        self.device = next(self.parameters()).device
        # NOTE(review): this cache has no eviction and grows without bound
        # across decoding calls; callers should clear it between utterances.
        self.score_cache = {}

        self.use_embed_mask = use_embed_mask
        if self.use_embed_mask:
            self._embed_mask = SpecAug(
                time_mask_width_range=3,
                num_time_mask=4,
                apply_freq_mask=False,
                apply_time_warp=False,
            )

    def forward(
        self,
        labels: torch.Tensor,
        label_lens: torch.Tensor,
        states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        """Encode source label sequences.

        Args:
            labels: Label ID sequences. (B, L)
            label_lens: Label sequence lengths. (B,)
            states: Decoder hidden states.
                ((N, B, D_dec), (N, B, D_dec) or None) or None

        Returns:
            dec_out: Decoder output sequences. (B, U, D_dec)
        """
        if states is None:
            states = self.init_state(labels.size(0))

        dec_embed = self.dropout_embed(self.embed(labels))
        # Time-mask the label embeddings only at training time.
        if self.use_embed_mask and self.training:
            dec_embed = self._embed_mask(dec_embed, label_lens)[0]
        dec_out, states = self.rnn_forward(dec_embed, states)

        return dec_out

    def rnn_forward(
        self,
        x: torch.Tensor,
        state: Tuple[torch.Tensor, Optional[torch.Tensor]],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Encode source label sequences through the stacked RNN layers.

        Args:
            x: RNN input sequences. (B, D_emb)
            state: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        Returns:
            x: RNN output sequences. (B, D_dec)
            (h_next, c_next): Decoder hidden states.
                ((N, B, D_dec), (N, B, D_dec) or None)
        """
        h_prev, c_prev = state
        h_next, c_next = self.init_state(x.size(0))

        for layer in range(self.dlayers):
            if self.dtype == "lstm":
                # Slice [layer:layer+1] keeps the leading (num_layers=1) dim
                # each single-layer LSTM expects for its hidden state.
                x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[
                    layer
                ](x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1]))
            else:
                x, h_next[layer : layer + 1] = self.rnn[layer](
                    x, hx=h_prev[layer : layer + 1]
                )

            x = self.dropout_rnn[layer](x)

        return x, (h_next, c_next)

    def score(
        self,
        label: torch.Tensor,
        label_sequence: List[int],
        dec_state: Tuple[torch.Tensor, Optional[torch.Tensor]],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """One-step forward hypothesis.

        Args:
            label: Previous label. (1, 1)
            label_sequence: Current label sequence.
            dec_state: Previous decoder hidden states.
                ((N, 1, D_dec), (N, 1, D_dec) or None)

        Returns:
            dec_out: Decoder output sequence. (1, D_dec)
            dec_state: Decoder hidden states.
                ((N, 1, D_dec), (N, 1, D_dec) or None)
        """
        # Memoize on the full label prefix so repeated hypotheses during beam
        # search reuse previous computations.
        str_labels = "_".join(map(str, label_sequence))

        if str_labels in self.score_cache:
            dec_out, dec_state = self.score_cache[str_labels]
        else:
            dec_embed = self.embed(label)
            dec_out, dec_state = self.rnn_forward(dec_embed, dec_state)

            self.score_cache[str_labels] = (dec_out, dec_state)

        return dec_out[0], dec_state

    def batch_score(
        self,
        hyps: List["Hypothesis"],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """One-step forward hypotheses.

        Args:
            hyps: Hypotheses.

        Returns:
            dec_out: Decoder output sequences. (B, D_dec)
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
        """
        # torch.LongTensor(...) does not accept a `device` keyword with
        # sequence data (raises TypeError); use the torch.tensor factory.
        labels = torch.tensor(
            [[h.yseq[-1]] for h in hyps], dtype=torch.long, device=self.device
        )
        dec_embed = self.embed(labels)

        states = self.create_batch_states([h.dec_state for h in hyps])
        dec_out, states = self.rnn_forward(dec_embed, states)

        return dec_out.squeeze(1), states

    def set_device(self, device: torch.device) -> None:
        """Set GPU device to use.

        Args:
            device: Device ID.
        """
        self.device = device

    def init_state(
        self, batch_size: int
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Initialize decoder states.

        Args:
            batch_size: Batch size.

        Returns:
            : Initial decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
        """
        h_n = torch.zeros(
            self.dlayers,
            batch_size,
            self.output_size,
            device=self.device,
        )

        # LSTM also carries a cell state; GRU only has the hidden state.
        if self.dtype == "lstm":
            c_n = torch.zeros(
                self.dlayers,
                batch_size,
                self.output_size,
                device=self.device,
            )

            return (h_n, c_n)

        return (h_n, None)

    def select_state(
        self, states: Tuple[torch.Tensor, Optional[torch.Tensor]], idx: int
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Get specified ID state from decoder hidden states.

        Args:
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
            idx: State ID to extract.

        Returns:
            : Decoder hidden state for given ID. ((N, 1, D_dec), (N, 1, D_dec) or None)
        """
        return (
            states[0][:, idx : idx + 1, :],
            states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None,
        )

    def create_batch_states(
        self,
        new_states: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Create batched decoder hidden states from per-hypothesis states.

        Args:
            new_states: Decoder hidden states. [B x ((N, 1, D_dec), (N, 1, D_dec) or None)]

        Returns:
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
        """
        return (
            torch.cat([s[0] for s in new_states], dim=1),
            (
                torch.cat([s[1] for s in new_states], dim=1)
                if self.dtype == "lstm"
                else None
            ),
        )