# mingru / model.py
#@title Architecture implementation
# TODO: comment and rename variables / clean code
# minGRU, from "Were RNNs All We Needed?": https://arxiv.org/abs/2410.01201v1
import torch
import torch.nn as nn
import torch.nn.functional as F

# Parallel prefix scan in log space (paper appendix B), after:
# https://github.com/glassroom/heinsen_sequence
def heinsen_associative_scan_log(log_coeffs, log_values):
    # Solves the linear recurrence h_t = a_t * h_{t-1} + b_t (with h_0 = 0) in
    # parallel, given log(a_t) and log(b_t) of shape (batch, seq_len, dim).
    # Returns h_t itself, not its log.
    a_star = log_coeffs.cumsum(dim=1)
    log_h0_plus_b_star = (log_values - a_star).logcumsumexp(dim=1)
    log_h = a_star + log_h0_plus_b_star
    return log_h.exp()
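
# Reference check (added sketch, not from the original file): the scan above
# solves h_t = a_t * h_{t-1} + b_t given log(a_t) and log(b_t); the naive loop
# below evaluates the same recurrence directly. Shapes and the helper name are
# arbitrary choices for illustration.
def _check_scan_against_reference(batch=2, seq_len=5, dim=3):
    a = torch.randn(batch, seq_len, dim).sigmoid()  # coefficients in (0, 1), like 1 - z_t
    b = torch.randn(batch, seq_len, dim).sigmoid()  # positive values, like z_t * g(...)
    h_parallel = heinsen_associative_scan_log(a.log(), b.log())
    h, outs = torch.zeros_like(b[:, 0]), []
    for t in range(seq_len):                        # sequential evaluation of the recurrence
        h = a[:, t] * h + b[:, t]
        outs.append(h)
    h_sequential = torch.stack(outs, dim=1)
    assert torch.allclose(h_parallel, h_sequential, atol=1e-5)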

# appendix B.3: g keeps the candidate hidden states strictly positive so that
# they have a well-defined log; log_g(x) is a numerically stable log(g(x)).
def g(x): return torch.where(x >= 0, x + 0.5, x.sigmoid())
def log_g(x): return torch.where(x >= 0, (F.relu(x) + 0.5).log(), -F.softplus(-x))

# Log-space minGRU (paper B.3.1). Hidden states are kept strictly positive
# (via g / log_g) so the recurrence can be evaluated with the log-space scan.
class minGRU(nn.Module):
    def __init__(self, d_model, d_inner):
        super().__init__()
        self.d_model = d_model
        self.d_inner = d_inner
        self.hidden_proj = nn.Linear(d_model, d_inner, bias=False)
        self.gate_proj = nn.Linear(d_model, d_inner, bias=False)
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    # Sequential (recurrent) mode: advance the recurrence token by token.
    def step(self, x, h_prev=None):
        hidden = self.hidden_proj(x)  # candidate hidden state h~_t (pre-activation)
        gate = self.gate_proj(x)      # update gate z_t (pre-sigmoid)
        h_prev = h_prev.detach() if h_prev is not None else None  # truncate backprop through the carried state
        hidden = g(hidden)
        gate = gate.sigmoid()
        # h_t = (1 - z_t) * h_{t-1} + z_t * h~_t; with no previous state, h_0 = 0.
        out = torch.lerp(h_prev, hidden, gate) if h_prev is not None else (hidden * gate)
        h_next = out[:, -1:]          # last hidden state (in d_inner space, before out_proj)
        out = self.out_proj(out)
        return out, h_next

    # Parallel mode: evaluate the whole sequence at once with the log-space scan.
    def forward(self, x, h_prev=None):
        seq_len = x.shape[1]
        hidden = self.hidden_proj(x)
        gate = self.gate_proj(x)
        h_prev = h_prev.detach() if h_prev is not None else None
        log_coeffs = -F.softplus(gate)    # log(1 - z_t)
        log_z = -F.softplus(-gate)        # log(z_t)
        log_tilde_h = log_g(hidden)       # log(h~_t)
        log_values = log_z + log_tilde_h  # log(z_t * h~_t)
        if h_prev is not None:
            # Prepend h_prev as an extra scan value with coefficient 1
            # (log 1 = 0), so the scan continues from the previous state.
            log_values = torch.cat((h_prev.log(), log_values), dim=1)
            log_coeffs = F.pad(log_coeffs, (0, 0, 1, 0))
        out = heinsen_associative_scan_log(log_coeffs, log_values)
        out = out[:, -seq_len:]           # drop the prepended h_prev position
        h_next = out[:, -1:]
        out = self.out_proj(out)
        return out, h_next
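
# Parity check (added sketch, not from the original file): the parallel
# forward() and the sequential step() implement the same recurrence and should
# agree up to numerical error. Sizes below are arbitrary assumptions.
def _check_parallel_vs_sequential(d_model=16, d_inner=32, batch=2, seq_len=8):
    cell = minGRU(d_model, d_inner)
    x = torch.randn(batch, seq_len, d_model)
    y_parallel, _ = cell(x)                  # whole sequence via the log-space scan
    h, outs = None, []
    for t in range(seq_len):                 # one token at a time
        y_t, h = cell.step(x[:, t:t + 1], h)
        outs.append(y_t)
    y_sequential = torch.cat(outs, dim=1)
    assert torch.allclose(y_parallel, y_sequential, atol=1e-4)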

# Root mean square layer normalization (no mean subtraction, no bias).
class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
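
# Quick illustration (added, commented out): with its default unit weight,
# RMSNorm rescales each feature vector to roughly unit root mean square.
#   norm = RMSNorm(8)
#   y = norm(torch.randn(4, 8) * 10.0)
#   print(y.pow(2).mean(-1))   # every entry is ~1.0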

# Minimal language model: token embedding, a stack of pre-norm minGRU blocks
# with residual connections, a final RMSNorm, and a linear LM head.
class minGRULM(nn.Module):
    def __init__(self, vocab_size, d_model, d_inner, n_layers):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_inner = d_inner
        self.n_layers = n_layers
        self.embed = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            self.layers.append(nn.ModuleList([
                RMSNorm(d_model),
                minGRU(d_model, d_inner)
            ]))
        self.norm_f = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    # Single recurrent step: forward one token per sequence and return the
    # next-token logits together with the updated per-layer hidden states.
    def step(self, x, h_states=None):
        x = self.embed(x)
        h_next = []
        h_states = iter(h_states if h_states is not None else [])
        for norm, mingru in self.layers:
            h_prev = next(h_states, None)
            residual = x
            x = norm(x)
            x, h_t = mingru.step(x, h_prev)
            x = x + residual
            h_next.append(h_t)
        x = self.norm_f(x)
        logits = self.lm_head(x)
        return logits, h_next

    # Parallel pass over a full token sequence; x holds both inputs and
    # targets (shifted by one), and the cross-entropy loss is computed here.
    def forward(self, x, h_states=None):
        x, labels = x[:, :-1], x[:, 1:]
        x = self.embed(x)
        h_next = []
        h_states = iter(h_states if h_states is not None else [])
        for norm, mingru in self.layers:
            h_prev = next(h_states, None)
            residual = x
            x = norm(x)
            x, h_t = mingru(x, h_prev)
            x = x + residual
            h_next.append(h_t)
        x = self.norm_f(x)
        logits = self.lm_head(x)
        loss = F.cross_entropy(logits.transpose(1, 2), labels)
        return logits, h_next, loss
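
# Usage sketch (added for illustration; the hyperparameters and token counts
# below are arbitrary assumptions, not values from the original repository).
def _example_train_and_generate():
    model = minGRULM(vocab_size=256, d_model=64, d_inner=128, n_layers=2)

    # Training-style call: forward() takes a token sequence, shifts it into
    # inputs/targets internally, and returns logits, hidden states, and loss.
    tokens = torch.randint(0, 256, (4, 33))  # (batch, seq_len + 1)
    logits, _, loss = model(tokens)
    loss.backward()

    # Inference: feed one token at a time with step(), carrying the per-layer
    # hidden states, and decode greedily from the logits.
    prompt = torch.randint(0, 256, (1, 1))
    h_states, generated, x = None, [prompt], prompt
    with torch.no_grad():
        for _ in range(16):
            logits, h_states = model.step(x, h_states)
            x = logits[:, -1:].argmax(dim=-1)
            generated.append(x)
    return torch.cat(generated, dim=1)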