import torch
import torch.nn as nn
import torch.nn.functional as F


def split_heads(x, num_heads):
    """Split heads.

    :param x: A tensor with shape [batch, length, channels]
    :param num_heads: An integer
    :returns: A tensor with shape [batch, heads, length, channels / heads]
    """
    assert x.shape[-1] % num_heads == 0, str(x.shape)
    # [batch, length, channels] -> [batch, length, heads, channels / heads]
    # -> [batch, heads, length, channels / heads]
    return x.reshape(x.shape[:-1] + (num_heads, x.shape[-1] // num_heads)).permute(0, 2, 1, 3)


def combine_heads(x):
    """Combine heads.

    :param x: A tensor with shape [batch, heads, length, channels]
    :returns: A tensor with shape [batch, length, heads * channels]
    """
    # [batch, heads, length, channels] -> [batch, length, heads, channels]
    x = x.permute([0, 2, 1, 3])
    # Merge the heads and channels dimensions back together.
    return x.reshape(x.shape[:-2] + (x.shape[-1] * x.shape[-2],))
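

# A minimal round-trip sanity check; a sketch, not part of the original file.
# The helper name, the [2, 5, 192] shape, and num_heads=4 are all illustrative
# assumptions. It shows that combine_heads inverts split_heads.
def _check_split_combine_roundtrip():
    x = torch.zeros(2, 5, 192)
    h = split_heads(x, num_heads=4)
    assert h.shape == (2, 4, 5, 48)  # [batch, heads, length, channels / heads]
    y = combine_heads(h)
    assert y.shape == x.shape  # back to [batch, length, channels]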


class SimpleAttention(nn.Module):
    """Single-head, unscaled dot-product attention with linear projections."""

    def __init__(self, query_size=192, key_size=192, value_size=192, num_heads=1):
        super(SimpleAttention, self).__init__()
        # Project queries, keys, and values into a common space of width
        # query_size so the dot products in forward() line up.
        self.q_transform = nn.Linear(query_size, query_size, bias=False)
        self.k_transform = nn.Linear(key_size, query_size, bias=False)
        self.v_transform = nn.Linear(value_size, query_size, bias=False)
        self.output_transform = nn.Linear(query_size, query_size, bias=False)
        self.query_size = query_size
        self.key_size = key_size
        self.value_size = value_size
        # Stored but unused by this single-head forward pass; the
        # split_heads/combine_heads helpers above would serve a multi-head variant.
        self.num_heads = num_heads

    def forward(self, query, key, value, attn_mask=None, bias=None):
        q = self.q_transform(query)
        k = self.k_transform(key)
        v = self.v_transform(value)
        # Unscaled dot-product attention scores.
        logits = torch.bmm(q, k.transpose(1, 2))  # [batch, length_q, length_k]
        if bias is not None:
            logits = logits + bias
        if attn_mask is not None:
            # attn_mask is 1.0 at masked positions; adding a large negative
            # value drives their softmax weights toward zero.
            logits = logits + attn_mask * -1e9
        weights = F.softmax(logits, dim=-1)
        out = torch.bmm(weights, v)
        out = self.output_transform(out)
        return out, weights
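

# A minimal usage sketch, not part of the original file: all sizes below are
# arbitrary example values, and the differing query/key lengths (7 vs. 9) just
# demonstrate that cross-attention shapes work.
if __name__ == "__main__":
    attn = SimpleAttention(query_size=192, key_size=192, value_size=192)
    query = torch.randn(2, 7, 192)   # [batch, length_q, query_size]
    memory = torch.randn(2, 9, 192)  # [batch, length_k, key_size]
    # Mask out the last key position for every query (1.0 = masked).
    mask = torch.zeros(2, 7, 9)
    mask[:, :, -1] = 1.0
    out, weights = attn(query, memory, memory, attn_mask=mask)
    print(out.shape)      # torch.Size([2, 7, 192])
    print(weights.shape)  # torch.Size([2, 7, 9])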