Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,520 Bytes
05d6e12 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import torch
from torch import nn
class BiLSTM(nn.Module):
    """Bidirectional recurrent layer (LSTM by default, GRU optional).

    In training mode the whole sequence is processed in one call. In eval
    mode the sequence is processed in fixed-size chunks so that arbitrarily
    long sequences fit in memory: the forward direction carries its hidden
    state from chunk to chunk left-to-right, then (for bidirectional RNNs)
    the chunks are re-run right-to-left and only the reverse-direction half
    of the output is kept. This is an approximation at chunk boundaries for
    the direction opposite to the scan, but matches the full run exactly for
    sequences no longer than one chunk.
    """

    # Chunk size used for memory-bounded inference in eval mode.
    inference_chunk_length = 512

    def __init__(self, input_features, recurrent_features, use_gru=False, dropout=0.0):
        """
        Args:
            input_features: size of each input timestep.
            recurrent_features: hidden size per direction; output size is
                2 * recurrent_features.
            use_gru: build an nn.GRU instead of an nn.LSTM.
            dropout: inter-layer dropout passed to the RNN (no effect for a
                single-layer RNN; PyTorch emits a warning if nonzero).
        """
        super().__init__()
        rnn_cls = nn.GRU if use_gru else nn.LSTM
        self.rnn = rnn_cls(
            input_features,
            recurrent_features,
            batch_first=True,
            bidirectional=True,
            dropout=dropout,
        )

    def _zero_state(self, batch_size, device):
        """Initial recurrent state in the form self.rnn expects:
        an (h, c) tuple for LSTM, a single hidden tensor for GRU."""
        num_directions = 2 if self.rnn.bidirectional else 1
        h = torch.zeros(
            num_directions, batch_size, self.rnn.hidden_size, device=device
        )
        # BUG FIX: the original always passed an LSTM-style (h, c) tuple,
        # which crashes when use_gru=True — nn.GRU takes a single tensor.
        if isinstance(self.rnn, nn.LSTM):
            return (h, h.clone())
        return h

    def forward(self, x):
        """Run the RNN over x of shape (batch, seq, input_features);
        returns (batch, seq, 2 * hidden_size)."""
        if self.training:
            return self.rnn(x)[0]

        # evaluation mode: support for longer sequences that do not fit in memory
        batch_size, sequence_length, _ = x.shape
        hidden_size = self.rnn.hidden_size
        num_directions = 2 if self.rnn.bidirectional else 1

        output = torch.zeros(
            batch_size,
            sequence_length,
            num_directions * hidden_size,
            device=x.device,
        )

        slices = range(0, sequence_length, self.inference_chunk_length)

        # Forward sweep: carry state left-to-right across chunks.
        state = self._zero_state(batch_size, x.device)
        for start in slices:
            end = start + self.inference_chunk_length
            output[:, start:end, :], state = self.rnn(x[:, start:end, :], state)

        # Reverse sweep: rerun chunks right-to-left with fresh state and keep
        # only the reverse-direction half of each chunk's output.
        if self.rnn.bidirectional:
            state = self._zero_state(batch_size, x.device)
            for start in reversed(slices):
                end = start + self.inference_chunk_length
                result, state = self.rnn(x[:, start:end, :], state)
                output[:, start:end, hidden_size:] = result[:, :, hidden_size:]

        return output
class UniLSTM(nn.Module):
    """Unidirectional single-layer LSTM.

    In training mode the whole sequence is processed in one call. In eval
    mode the sequence is processed in fixed-size chunks, carrying the hidden
    state from chunk to chunk, so arbitrarily long sequences fit in memory.
    Because the LSTM is strictly left-to-right, chunked inference is exactly
    equivalent to a single full-sequence run.
    """

    # Chunk size used for memory-bounded inference in eval mode.
    inference_chunk_length = 512

    def __init__(self, input_features, recurrent_features):
        """
        Args:
            input_features: size of each input timestep.
            recurrent_features: hidden size; output size equals this.
        """
        super().__init__()
        self.rnn = nn.LSTM(input_features, recurrent_features, batch_first=True)

    def forward(self, x):
        """Run the LSTM over x of shape (batch, seq, input_features);
        returns (batch, seq, hidden_size)."""
        if self.training:
            return self.rnn(x)[0]

        # evaluation mode: support for longer sequences that do not fit in memory
        batch_size, sequence_length, _ = x.shape
        hidden_size = self.rnn.hidden_size

        # The RNN is unidirectional by construction, so the reverse-direction
        # pass that BiLSTM needs (dead code in the original) is removed here.
        h = torch.zeros(1, batch_size, hidden_size, device=x.device)
        c = torch.zeros(1, batch_size, hidden_size, device=x.device)
        output = torch.zeros(
            batch_size,
            sequence_length,
            hidden_size,
            device=x.device,
        )

        # Carry (h, c) forward across chunks; this reproduces the full run.
        for start in range(0, sequence_length, self.inference_chunk_length):
            end = start + self.inference_chunk_length
            output[:, start:end, :], (h, c) = self.rnn(x[:, start:end, :], (h, c))

        return output
|