# Minimal minGRU language model: a parallel log-space scan for training and a
# recurrent step() path for incremental decoding.

import torch
import torch.nn as nn
import torch.nn.functional as F


def heinsen_associative_scan_log(log_coeffs, log_values):
    # Heinsen-style parallel scan in log space: given log(a_t) and log(b_t), computes
    # every h_t of the recurrence h_t = a_t * h_{t-1} + b_t at once, using cumsum
    # and logcumsumexp for numerical stability.
    a_star = log_coeffs.cumsum(dim=1)
    log_h0_plus_b_star = (log_values - a_star).logcumsumexp(dim=1)
    log_h = a_star + log_h0_plus_b_star
    return log_h.exp()


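# Illustrative sanity check, not part of the model: for positive a_t, b_t the scan
# above should agree with a naive sequential evaluation of h_t = a_t * h_{t-1} + b_t
# (with h_0 = 0). The helper name and default shapes are ours, a usage sketch only.
def _scan_matches_naive_loop(batch=2, seq=8, dim=4):
    a = torch.rand(batch, seq, dim) * 0.9 + 0.1   # coefficients in (0.1, 1.0)
    b = torch.rand(batch, seq, dim) * 0.9 + 0.1   # values in (0.1, 1.0)
    scanned = heinsen_associative_scan_log(a.log(), b.log())

    h = torch.zeros(batch, dim)
    steps = []
    for t in range(seq):
        h = a[:, t] * h + b[:, t]
        steps.append(h)
    looped = torch.stack(steps, dim=1)

    return torch.allclose(scanned, looped, atol=1e-4)

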
def g(x):
    # Activation applied to the candidate hidden state; keeps it strictly positive
    # so its log is defined for the log-space scan.
    return torch.where(x >= 0, x + 0.5, x.sigmoid())


def log_g(x):
    # log(g(x)), written so both branches stay numerically stable.
    return torch.where(x >= 0, (F.relu(x) + 0.5).log(), -F.softplus(-x))


class minGRU(nn.Module):
    def __init__(self, d_model, d_inner):
        super().__init__()

        self.d_model = d_model
        self.d_inner = d_inner

        # Project to the expanded inner dimension for the recurrence, then back out.
        self.hidden_proj = nn.Linear(d_model, d_inner, bias=False)
        self.gate_proj = nn.Linear(d_model, d_inner, bias=False)
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def step(self, x, h_prev=None):
        # Sequential (recurrent) mode for decoding: one chunk at a time, carrying h_prev.
        hidden = self.hidden_proj(x)
        gate = self.gate_proj(x)

        h_prev = h_prev.detach() if h_prev is not None else None

        hidden = g(hidden)
        gate = gate.sigmoid()
        # h_t = (1 - z_t) * h_{t-1} + z_t * h_tilde_t; with no previous state this
        # reduces to z_t * h_tilde_t.
        out = torch.lerp(h_prev, hidden, gate) if h_prev is not None else (hidden * gate)

        h_next = out[:, -1:]
        out = self.out_proj(out)

        return out, h_next

    def forward(self, x, h_prev=None):
        # Parallel mode for training: compute all hidden states with the log-space scan.
        seq_len = x.shape[1]
        hidden = self.hidden_proj(x)
        gate = self.gate_proj(x)

        h_prev = h_prev.detach() if h_prev is not None else None

        # log(1 - z_t), log(z_t) and log(h_tilde_t), each in a numerically stable form.
        log_coeffs = -F.softplus(gate)
        log_z = -F.softplus(-gate)
        log_tilde_h = log_g(hidden)
        log_values = log_z + log_tilde_h

        if h_prev is not None:
            # Prepend the previous hidden state as the initial value of the scan.
            log_values = torch.cat((h_prev.log(), log_values), dim=1)
            log_coeffs = F.pad(log_coeffs, (0, 0, 1, 0))

        out = heinsen_associative_scan_log(log_coeffs, log_values)
        out = out[:, -seq_len:]

        h_next = out[:, -1:]
        out = self.out_proj(out)

        return out, h_next


class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32 for stability, then cast back to the input dtype.
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class minGRULM(nn.Module):
    def __init__(self, vocab_size, d_model, d_inner, n_layers):
        super().__init__()

        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_inner = d_inner
        self.n_layers = n_layers

        self.embed = nn.Embedding(vocab_size, d_model)

        # Each block is a pre-norm residual minGRU layer.
        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            self.layers.append(nn.ModuleList([
                RMSNorm(d_model),
                minGRU(d_model, d_inner),
            ]))

        self.norm_f = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def step(self, x, h_states=None):
        # Recurrent decoding: advance every layer by one chunk, collecting per-layer states.
        x = self.embed(x)

        h_next = []
        h_states = iter(h_states if h_states is not None else [])

        for norm, mingru in self.layers:
            h_prev = next(h_states, None)
            residual = x

            x = norm(x)
            x, h_t = mingru.step(x, h_prev)
            x = x + residual

            h_next.append(h_t)

        x = self.norm_f(x)
        logits = self.lm_head(x)

        return logits, h_next

    def forward(self, x, h_states=None):
        # Teacher-forced training pass: predict each next token from the shifted input.
        x, labels = x[:, :-1], x[:, 1:]
        x = self.embed(x)

        h_next = []
        h_states = iter(h_states if h_states is not None else [])

        for norm, mingru in self.layers:
            h_prev = next(h_states, None)
            residual = x

            x = norm(x)
            x, h_t = mingru(x, h_prev)
            x = x + residual

            h_next.append(h_t)

        x = self.norm_f(x)
        logits = self.lm_head(x)
        # cross_entropy expects (batch, classes, seq), hence the transpose.
        loss = F.cross_entropy(logits.transpose(1, 2), labels)

        return logits, h_next, loss
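

# Minimal smoke test, run only when the file is executed directly. The hyperparameters
# and sequence length below are illustrative, not values prescribed by the model above.
if __name__ == "__main__":
    torch.manual_seed(0)

    model = minGRULM(vocab_size=256, d_model=64, d_inner=128, n_layers=2)
    tokens = torch.randint(0, 256, (2, 17))  # (batch, seq_len + 1) for the shifted LM objective

    # Parallel training pass: logits for the shifted sequence, per-layer states, and the loss.
    logits, h_states, loss = model(tokens)
    print(logits.shape, loss.item())  # logits: (2, 16, 256)

    # Recurrent pass over the same tokens, one position at a time, threading the states through.
    h = None
    for t in range(tokens.shape[1] - 1):
        step_logits, h = model.step(tokens[:, t:t + 1], h)
    print(step_logits.shape)  # (2, 1, 256)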