#@title Architecture implementation
# TODO: rename variables / clean code
# minGRU architecture from "Were RNNs All We Needed?"
# https://arxiv.org/abs/2410.01201v1

import torch
import torch.nn as nn
import torch.nn.functional as F


# Parallel log-space associative scan (appendix B)
# https://github.com/glassroom/heinsen_sequence
def heinsen_associative_scan_log(log_coeffs, log_values):
    a_star = log_coeffs.cumsum(dim=1)
    log_h0_plus_b_star = (log_values - a_star).logcumsumexp(dim=1)
    log_h = a_star + log_h0_plus_b_star
    return log_h.exp()


# appendix B.3: continuous activation g(.) and its log.
# g keeps the candidate hidden states strictly positive so the
# recurrence can be computed entirely in log-space.
def g(x):
    return torch.where(x >= 0, x + 0.5, x.sigmoid())


def log_g(x):
    return torch.where(x >= 0, (F.relu(x) + 0.5).log(), -F.softplus(-x))


# Log-space version of minGRU (appendix B.3.1).
# The hidden states are enforced to be positive via g / log_g.
class minGRU(nn.Module):
    def __init__(self, d_model, d_inner):
        super().__init__()
        self.d_model = d_model
        self.d_inner = d_inner
        self.hidden_proj = nn.Linear(d_model, d_inner, bias=False)  # candidate hidden state h_tilde
        self.gate_proj = nn.Linear(d_model, d_inner, bias=False)    # update gate z
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)     # back to model dimension

    # Sequential (recurrent) mode: one step of
    # h_t = (1 - z_t) * h_{t-1} + z_t * h_tilde_t, with h_0 = 0 when no state is given.
    def step(self, x, h_prev=None):
        hidden = self.hidden_proj(x)
        gate = self.gate_proj(x)
        h_prev = h_prev.detach() if h_prev is not None else None

        hidden = g(hidden)
        gate = gate.sigmoid()
        out = torch.lerp(h_prev, hidden, gate) if h_prev is not None else (hidden * gate)
        h_next = out[:, -1:]

        out = self.out_proj(out)
        return out, h_next

    # Parallel mode: compute all hidden states at once with the log-space scan.
    def forward(self, x, h_prev=None):
        seq_len = x.shape[1]
        hidden = self.hidden_proj(x)
        gate = self.gate_proj(x)
        h_prev = h_prev.detach() if h_prev is not None else None

        # log(1 - z_t), log(z_t) and log(h_tilde_t) for the recurrence
        # h_t = (1 - z_t) * h_{t-1} + z_t * h_tilde_t
        log_coeffs = -F.softplus(gate)
        log_z = -F.softplus(-gate)
        log_tilde_h = log_g(hidden)
        log_values = log_z + log_tilde_h

        if h_prev is not None:
            # Prepend the previous hidden state with coefficient 1 (log 1 = 0)
            log_values = torch.cat((h_prev.log(), log_values), dim=1)
            log_coeffs = F.pad(log_coeffs, (0, 0, 1, 0))

        out = heinsen_associative_scan_log(log_coeffs, log_values)
        out = out[:, -seq_len:]  # drop the prepended initial state, if any
        h_next = out[:, -1:]

        out = self.out_proj(out)
        return out, h_next


class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class minGRULM(nn.Module):
    def __init__(self, vocab_size, d_model, d_inner, n_layers):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_inner = d_inner
        self.n_layers = n_layers

        self.embed = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            self.layers.append(nn.ModuleList([
                RMSNorm(d_model),
                minGRU(d_model, d_inner)
            ]))
        self.norm_f = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    # One single step of the LM: forward one token and output the logits
    # for the next one, carrying one hidden state per layer.
    def step(self, x, h_states=None):
        x = self.embed(x)
        h_next = []
        h_states = iter(h_states if h_states is not None else [])
        for norm, mingru in self.layers:
            h_prev = next(h_states, None)
            residual = x
            x = norm(x)
            x, h_t = mingru.step(x, h_prev)
            x = x + residual
            h_next.append(h_t)
        x = self.norm_f(x)
        logits = self.lm_head(x)
        return logits, h_next

    # Parallel training forward: next-token prediction over the whole sequence.
    def forward(self, x, h_states=None):
        x, labels = x[:, :-1], x[:, 1:]
        x = self.embed(x)
        h_next = []
        h_states = iter(h_states if h_states is not None else [])
        for norm, mingru in self.layers:
            h_prev = next(h_states, None)
            residual = x
            x = norm(x)
            x, h_t = mingru(x, h_prev)
            x = x + residual
            h_next.append(h_t)
        x = self.norm_f(x)
        logits = self.lm_head(x)
        loss = F.cross_entropy(logits.transpose(1, 2), labels)
        return logits, h_next, loss
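

# Minimal usage sketch (not from the paper): the toy hyperparameters, the random
# token batch, and the greedy decoding loop below are illustrative assumptions only.
# Training uses the parallel forward (log-space scan); generation uses the recurrent
# .step(), threading the per-layer hidden states returned by the model.

vocab_size, d_model, d_inner, n_layers = 256, 64, 128, 2  # assumed toy sizes
model = minGRULM(vocab_size, d_model, d_inner, n_layers)

# Parallel training forward on a random batch (batch=4, seq_len=32)
tokens = torch.randint(0, vocab_size, (4, 32))
logits, h_states, loss = model(tokens)
loss.backward()  # standard next-token cross-entropy

# Sequential generation, one token at a time, reusing the hidden states
model.eval()
with torch.no_grad():
    seq = torch.randint(0, vocab_size, (1, 1))  # assumed prompt: a single random token
    h_states = None
    for _ in range(16):
        logits, h_states = model.step(seq[:, -1:], h_states)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy decoding for simplicity
        seq = torch.cat((seq, next_token), dim=1)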