Spaces:
Running
on
Zero
Running
on
Zero
| """Sequential implementation of Recurrent Neural Network Language Model.""" | |
| from typing import Tuple | |
| from typing import Union | |
| import torch | |
| import torch.nn as nn | |
| from funasr_detach.train.abs_model import AbsLM | |
| class SequentialRNNLM(AbsLM): | |
| """Sequential RNNLM. | |
| See also: | |
| https://github.com/pytorch/examples/blob/4581968193699de14b56527296262dd76ab43557/word_language_model/model.py | |
| """ | |
| def __init__( | |
| self, | |
| vocab_size: int, | |
| unit: int = 650, | |
| nhid: int = None, | |
| nlayers: int = 2, | |
| dropout_rate: float = 0.0, | |
| tie_weights: bool = False, | |
| rnn_type: str = "lstm", | |
| ignore_id: int = 0, | |
| ): | |
| super().__init__() | |
| ninp = unit | |
| if nhid is None: | |
| nhid = unit | |
| rnn_type = rnn_type.upper() | |
| self.drop = nn.Dropout(dropout_rate) | |
| self.encoder = nn.Embedding(vocab_size, ninp, padding_idx=ignore_id) | |
| if rnn_type in ["LSTM", "GRU"]: | |
| rnn_class = getattr(nn, rnn_type) | |
| self.rnn = rnn_class( | |
| ninp, nhid, nlayers, dropout=dropout_rate, batch_first=True | |
| ) | |
| else: | |
| try: | |
| nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type] | |
| except KeyError: | |
| raise ValueError( | |
| """An invalid option for `--model` was supplied, | |
| options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']""" | |
| ) | |
| self.rnn = nn.RNN( | |
| ninp, | |
| nhid, | |
| nlayers, | |
| nonlinearity=nonlinearity, | |
| dropout=dropout_rate, | |
| batch_first=True, | |
| ) | |
| self.decoder = nn.Linear(nhid, vocab_size) | |
| # Optionally tie weights as in: | |
| # "Using the Output Embedding to Improve Language Models" | |
| # (Press & Wolf 2016) https://arxiv.org/abs/1608.05859 | |
| # and | |
| # "Tying Word Vectors and Word Classifiers: | |
| # A Loss Framework for Language Modeling" (Inan et al. 2016) | |
| # https://arxiv.org/abs/1611.01462 | |
| if tie_weights: | |
| if nhid != ninp: | |
| raise ValueError( | |
| "When using the tied flag, nhid must be equal to emsize" | |
| ) | |
| self.decoder.weight = self.encoder.weight | |
| self.rnn_type = rnn_type | |
| self.nhid = nhid | |
| self.nlayers = nlayers | |
| def zero_state(self): | |
| """Initialize LM state filled with zero values.""" | |
| if isinstance(self.rnn, torch.nn.LSTM): | |
| h = torch.zeros((self.nlayers, self.nhid), dtype=torch.float) | |
| c = torch.zeros((self.nlayers, self.nhid), dtype=torch.float) | |
| state = h, c | |
| else: | |
| state = torch.zeros((self.nlayers, self.nhid), dtype=torch.float) | |
| return state | |
| def forward( | |
| self, input: torch.Tensor, hidden: torch.Tensor | |
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |
| emb = self.drop(self.encoder(input)) | |
| output, hidden = self.rnn(emb, hidden) | |
| output = self.drop(output) | |
| decoded = self.decoder( | |
| output.contiguous().view(output.size(0) * output.size(1), output.size(2)) | |
| ) | |
| return ( | |
| decoded.view(output.size(0), output.size(1), decoded.size(1)), | |
| hidden, | |
| ) | |
| def score( | |
| self, | |
| y: torch.Tensor, | |
| state: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], | |
| x: torch.Tensor, | |
| ) -> Tuple[torch.Tensor, Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]: | |
| """Score new token. | |
| Args: | |
| y: 1D torch.int64 prefix tokens. | |
| state: Scorer state for prefix tokens | |
| x: 2D encoder feature that generates ys. | |
| Returns: | |
| Tuple of | |
| torch.float32 scores for next token (n_vocab) | |
| and next state for ys | |
| """ | |
| y, new_state = self(y[-1].view(1, 1), state) | |
| logp = y.log_softmax(dim=-1).view(-1) | |
| return logp, new_state | |
| def batch_score( | |
| self, ys: torch.Tensor, states: torch.Tensor, xs: torch.Tensor | |
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |
| """Score new token batch. | |
| Args: | |
| ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). | |
| states (List[Any]): Scorer states for prefix tokens. | |
| xs (torch.Tensor): | |
| The encoder feature that generates ys (n_batch, xlen, n_feat). | |
| Returns: | |
| tuple[torch.Tensor, List[Any]]: Tuple of | |
| batchfied scores for next token with shape of `(n_batch, n_vocab)` | |
| and next state list for ys. | |
| """ | |
| if states[0] is None: | |
| states = None | |
| elif isinstance(self.rnn, torch.nn.LSTM): | |
| # states: Batch x 2 x (Nlayers, Dim) -> 2 x (Nlayers, Batch, Dim) | |
| h = torch.stack([h for h, c in states], dim=1) | |
| c = torch.stack([c for h, c in states], dim=1) | |
| states = h, c | |
| else: | |
| # states: Batch x (Nlayers, Dim) -> (Nlayers, Batch, Dim) | |
| states = torch.stack(states, dim=1) | |
| ys, states = self(ys[:, -1:], states) | |
| # ys: (Batch, 1, Nvocab) -> (Batch, NVocab) | |
| assert ys.size(1) == 1, ys.shape | |
| ys = ys.squeeze(1) | |
| logp = ys.log_softmax(dim=-1) | |
| # state: Change to batch first | |
| if isinstance(self.rnn, torch.nn.LSTM): | |
| # h, c: (Nlayers, Batch, Dim) | |
| h, c = states | |
| # states: Batch x 2 x (Nlayers, Dim) | |
| states = [(h[:, i], c[:, i]) for i in range(h.size(1))] | |
| else: | |
| # states: (Nlayers, Batch, Dim) -> Batch x (Nlayers, Dim) | |
| states = [states[:, i] for i in range(states.size(1))] | |
| return logp, states | |