import torch
from torch import nn


class BiLSTM(nn.Module):
    """Bidirectional LSTM/GRU that chunks long sequences during inference."""

    # maximum number of frames processed per chunk in evaluation mode
    inference_chunk_length = 512

    def __init__(self, input_features, recurrent_features, use_gru=False, dropout=0.0):
        super().__init__()
        # note: PyTorch only applies dropout between stacked recurrent layers,
        # so a non-zero dropout on this single-layer RNN has no effect and
        # triggers a UserWarning
        self.rnn = (nn.LSTM if not use_gru else nn.GRU)(
            input_features,
            recurrent_features,
            batch_first=True,
            bidirectional=True,
            dropout=dropout,
        )

    def forward(self, x):
        if self.training:
            return self.rnn(x)[0]
        else:
            # evaluation mode: process the sequence in fixed-size chunks,
            # carrying the recurrent state across chunk boundaries so that
            # sequences too long to fit in memory produce the same output
            # as a single full-length pass
            batch_size, sequence_length, _ = x.shape

            hidden_size = self.rnn.hidden_size
            num_directions = 2 if self.rnn.bidirectional else 1
            # nn.LSTM carries a (h, c) state tuple while nn.GRU carries a
            # single hidden tensor, so build the zero state accordingly
            is_lstm = isinstance(self.rnn, nn.LSTM)

            def zero_state():
                h = torch.zeros(num_directions, batch_size, hidden_size, device=x.device)
                return (h, torch.zeros_like(h)) if is_lstm else h

            output = torch.zeros(
                batch_size,
                sequence_length,
                num_directions * hidden_size,
                device=x.device,
            )

            # forward direction: left-to-right over the chunks; only the first
            # `hidden_size` output channels are valid here, and the reverse
            # pass below overwrites the remaining ones
            state = zero_state()
            slices = range(0, sequence_length, self.inference_chunk_length)
            for start in slices:
                end = start + self.inference_chunk_length
                output[:, start:end, :], state = self.rnn(x[:, start:end, :], state)

            # reverse direction: right-to-left over the chunks with a freshly
            # zeroed state, keeping only the backward half of each output
            if self.rnn.bidirectional:
                state = zero_state()

                for start in reversed(slices):
                    end = start + self.inference_chunk_length
                    result, state = self.rnn(x[:, start:end, :], state)
                    output[:, start:end, hidden_size:] = result[:, :, hidden_size:]

            return output


class UniLSTM(nn.Module):
    """Unidirectional LSTM with the same chunked-inference scheme as BiLSTM."""

    # maximum number of frames processed per chunk in evaluation mode
    inference_chunk_length = 512

    def __init__(self, input_features, recurrent_features):
        super().__init__()
        self.rnn = nn.LSTM(input_features, recurrent_features, batch_first=True)

    def forward(self, x):
        if self.training:
            return self.rnn(x)[0]
        else:
            # evaluation mode: the RNN is unidirectional, so a single
            # left-to-right pass over the chunks suffices
            batch_size, sequence_length, _ = x.shape
            hidden_size = self.rnn.hidden_size

            h = torch.zeros(1, batch_size, hidden_size, device=x.device)
            c = torch.zeros(1, batch_size, hidden_size, device=x.device)
            output = torch.zeros(batch_size, sequence_length, hidden_size, device=x.device)

            for start in range(0, sequence_length, self.inference_chunk_length):
                end = start + self.inference_chunk_length
                output[:, start:end, :], (h, c) = self.rnn(x[:, start:end, :], (h, c))

            return output
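

if __name__ == "__main__":
    # Illustrative sanity check (an editorial addition, not part of the
    # original module): because the recurrent state is carried across chunk
    # boundaries in both directions, the chunked evaluation path should
    # reproduce the output of a single full-length pass up to floating-point
    # error. The shapes and tolerances below are arbitrary choices.
    torch.manual_seed(0)
    model = BiLSTM(input_features=8, recurrent_features=16)
    x = torch.randn(2, 1500, 8)  # longer than inference_chunk_length

    with torch.no_grad():
        model.train()
        full = model(x)      # single pass over the whole sequence
        model.eval()
        chunked = model(x)   # chunked pass with carried recurrent state

    assert torch.allclose(full, chunked, atol=1e-5), "chunked output diverged"
    print("BiLSTM chunked inference matches the full-sequence output")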