import torch
import torch.nn as nn
import torch.nn.functional as F


class LSA(nn.Module):
    """Location-Sensitive Attention: scores each encoder timestep from the
    decoder query plus convolutional features of the cumulative attention."""

    def __init__(self, attn_dim, kernel_size=31, filters=32):
        super().__init__()
        # 1-D conv over the cumulative attention weights -> location features
        self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
        self.L = nn.Linear(filters, attn_dim, bias=False)
        self.W = nn.Linear(attn_dim, attn_dim, bias=True)  # Include the attention bias in this term
        self.v = nn.Linear(attn_dim, 1, bias=False)
        self.cumulative = None
        self.attention = None

    def init_attention(self, encoder_seq_proj):
        device = encoder_seq_proj.device  # use same device as parameters
        b, t, c = encoder_seq_proj.size()
        self.cumulative = torch.zeros(b, t, device=device)
        self.attention = torch.zeros(b, t, device=device)

    def forward(self, encoder_seq_proj, query, times, chars):
        # Reset the attention state on the first decoder step
        if times == 0: self.init_attention(encoder_seq_proj)

        processed_query = self.W(query).unsqueeze(1)

        # Location features: (b, t) cumulative weights -> (b, t, attn_dim)
        location = self.cumulative.unsqueeze(1)
        processed_loc = self.L(self.conv(location).transpose(1, 2))

        # Additive (Bahdanau-style) energies over the encoder time axis
        u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
        u = u.squeeze(-1)

        # Mask zero padding chars
        u = u * (chars != 0).float()

        # Smooth Attention
        # scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
        scores = F.softmax(u, dim=1)
        self.attention = scores
        self.cumulative = self.cumulative + self.attention

        # (b, t) -> (b, 1, t) attention weights
        return scores.unsqueeze(-1).transpose(1, 2)
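
A minimal smoke test of the module follows. It is a sketch, not part of the original listing: the batch size, text length, attention dimension, and character ids below are assumed values chosen only to exercise the shapes.

# Illustrative usage sketch; all sizes and dummy tensors below are assumptions.
if __name__ == "__main__":
    b, t, attn_dim = 2, 50, 128                      # hypothetical batch, text length, attention dim
    attn = LSA(attn_dim)

    encoder_seq_proj = torch.randn(b, t, attn_dim)   # encoder outputs already projected to attn_dim
    chars = torch.randint(1, 60, (b, t))             # dummy character ids; 0 would mark padding
    query = torch.randn(b, attn_dim)                 # decoder state for the current step

    for step in range(3):                            # a few decoder steps
        scores = attn(encoder_seq_proj, query, step, chars)
        print(scores.shape)                          # torch.Size([2, 1, 50])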