from functools import partial
from itertools import islice, cycle

from torch import nn

from text2punks.attention import Attention, SparseAxialCausalAttention
# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def cast_tuple(val, depth = 1):
    if isinstance(val, list):
        val = tuple(val)
    return val if isinstance(val, tuple) else (val,) * depth
# classes

class SequentialSequence(nn.Module):
    def __init__(self, layers):
        super().__init__()
        self.layers = layers

    def forward(self, x):
        # each entry is an (attention, feed-forward) pair, applied with residual connections
        for (f, g) in self.layers:
            x = x + f(x)
            x = x + g(x)
        return x
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        # normalize before the wrapped module (pre-norm transformer block)
        return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
    def __init__(self, dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 4, dim)
        )
        # note: other implementations order these as nn.Linear(4 * n_embd, n_embd)
        # followed by nn.Dropout(resid_pdrop); here dropout sits before the second projection

    def forward(self, x):
        return self.net(x)
class Transformer(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        seq_len,
        causal = True,
        heads = 8,
        dim_head = 64,
        attn_dropout = 0.,
        resid_dropout = 0.,
        embd_dropout = 0.,
        ff_dropout = 0.,
        image_size = 24,
        attn_types = None,
    ):
        super().__init__()
        layers = nn.ModuleList([])

        # cycle the requested attention types over `depth` layers
        attn_types = default(attn_types, ('full',))
        attn_types = cast_tuple(attn_types)
        attn_type_layer = islice(cycle(attn_types), depth)

        for attn_type in attn_type_layer:
            if attn_type == 'full':
                attn_class = partial(Attention, causal = causal)
            elif attn_type == 'axial_row':
                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 0, image_size = image_size)
            elif attn_type == 'axial_col':
                attn_class = partial(SparseAxialCausalAttention, seq_len = seq_len, axis = 1, image_size = image_size)
            else:
                raise ValueError(f'attention type "{attn_type}" is not valid')

            attn = attn_class(dim, seq_len = seq_len, heads = heads, dim_head = dim_head, attn_dropout = attn_dropout, resid_dropout = resid_dropout)

            layers.append(nn.ModuleList([
                PreNorm(dim, attn),
                PreNorm(dim, FeedForward(dim, dropout = ff_dropout))
            ]))

        # full attention in the last layer
        attn_class = partial(Attention, causal = causal)
        attn = attn_class(dim, seq_len = seq_len, heads = heads, dim_head = dim_head, attn_dropout = attn_dropout, resid_dropout = resid_dropout)

        layers.append(nn.ModuleList([
            PreNorm(dim, attn),
            PreNorm(dim, FeedForward(dim, dropout = ff_dropout))
        ]))

        self.layers = SequentialSequence(layers)
        self.embd_drop = nn.Dropout(embd_dropout)

    def forward(self, x):
        x = self.embd_drop(x)
        return self.layers(x)
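
# ----------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original file).
# The sizes below are made-up placeholders, and the input is assumed to
# already be a (batch, seq_len, dim) tensor of token embeddings; the
# token / positional embedding layers live outside this module. It also
# assumes text2punks.attention.Attention accepts exactly the keyword
# arguments it is called with above.
if __name__ == '__main__':
    import torch

    model = Transformer(
        dim = 256,       # model width (placeholder)
        depth = 4,       # 4 cycled layers + 1 extra full-attention layer appended above
        seq_len = 640,   # e.g. text tokens + 24 * 24 image tokens (placeholder)
    )                    # attn_types defaults to ('full',)

    x = torch.randn(2, 640, 256)   # (batch, seq_len, dim)
    out = model(x)                 # output keeps the input shape: (2, 640, 256)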