# Copyright 2025 Pathway Technology, Inc.

import dataclasses
import math

import torch
import torch.nn.functional as F
from torch import nn


@dataclasses.dataclass
class BDHConfig:
    n_layer: int = 6
    n_embd: int = 256
    dropout: float = 0.1
    n_head: int = 4
    mlp_internal_dim_multiplier: int = 128
    vocab_size: int = 256


def get_freqs(n, theta, dtype):
    # Rotary-style inverse frequencies, expressed in cycles (the 2*pi is
    # reapplied later). quantize() makes channel pairs (2i, 2i + 1) share a
    # frequency so they can be rotated together.
    def quantize(t, q=2):
        return (t / q).floor() * q

    return (
        1.0
        / (theta ** (quantize(torch.arange(0, n, 1, dtype=dtype)) / n))
        / (2 * math.pi)
    )


class Attention(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        nh = config.n_head
        D = config.n_embd
        # Per-head internal (latent) dimension.
        N = config.mlp_internal_dim_multiplier * D // nh
        self.freqs = torch.nn.Buffer(
            get_freqs(N, theta=2**16, dtype=torch.float32).view(1, 1, 1, N)
        )

    @staticmethod
    def phases_cos_sin(phases):
        # Phases are stored in cycles; convert to radians before cos/sin.
        phases = (phases % 1) * (2 * math.pi)
        phases_cos = torch.cos(phases)
        phases_sin = torch.sin(phases)
        return phases_cos, phases_sin

    @staticmethod
    def rope(phases, v):
        # Rotary position embedding: rotate each (even, odd) channel pair by its phase.
        v_rot = torch.stack((-v[..., 1::2], v[..., ::2]), dim=-1).view(*v.size())
        phases_cos, phases_sin = Attention.phases_cos_sin(phases)
        return (v * phases_cos).to(v.dtype) + (v_rot * phases_sin).to(v.dtype)

    def forward(self, Q, K, V):
        assert self.freqs.dtype == torch.float32
        assert K is Q
        _, _, T, _ = Q.size()
        r_phases = (
            torch.arange(
                0,
                T,
                device=self.freqs.device,
                dtype=self.freqs.dtype,
            ).view(1, 1, -1, 1)
        ) * self.freqs
        QR = self.rope(r_phases, Q)
        KR = QR  # K is Q, so the rotated keys are the rotated queries

        # Current attention: linear (no softmax) and strictly causal;
        # tril(diagonal=-1) lets each position attend only to earlier ones.
        scores = (QR @ KR.mT).tril(diagonal=-1)
        return scores @ V


class BDH(nn.Module):
    def __init__(self, config: BDHConfig):
        super().__init__()
        assert config.vocab_size is not None
        self.config = config
        nh = config.n_head
        D = config.n_embd
        N = config.mlp_internal_dim_multiplier * D // nh

        # Shared across all layers: the encoders map D -> N per head, the
        # decoder maps the concatenated heads (nh * N) back to D.
        self.decoder = nn.Parameter(torch.zeros((nh * N, D)).normal_(std=0.02))
        self.encoder = nn.Parameter(torch.zeros((nh, D, N)).normal_(std=0.02))
        self.attn = Attention(config)
        self.ln = nn.LayerNorm(D, elementwise_affine=False, bias=False)
        self.embed = nn.Embedding(config.vocab_size, D)
        self.drop = nn.Dropout(config.dropout)
        self.encoder_v = nn.Parameter(torch.zeros((nh, D, N)).normal_(std=0.02))
        self.lm_head = nn.Parameter(
            torch.zeros((D, config.vocab_size)).normal_(std=0.02)
        )

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        C = self.config
        B, T = idx.size()
        D = C.n_embd
        nh = C.n_head
        N = D * C.mlp_internal_dim_multiplier // nh

        x = self.embed(idx).unsqueeze(1)  # B, 1, T, D
        # actually helps with training
        x = self.ln(x)  # B, 1, T, D

        # Every layer reuses the same encoder/decoder/attention parameters.
        for level in range(C.n_layer):
            x_latent = x @ self.encoder
            x_sparse = F.relu(x_latent)  # B, nh, T, N

            yKV = self.attn(
                Q=x_sparse,
                K=x_sparse,
                V=x,
            )
            yKV = self.ln(yKV)

            y_latent = yKV @ self.encoder_v
            y_sparse = F.relu(y_latent)
            xy_sparse = x_sparse * y_sparse  # B, nh, T, N
            xy_sparse = self.drop(xy_sparse)

            yMLP = (
                xy_sparse.transpose(1, 2).reshape(B, 1, T, N * nh) @ self.decoder
            )  # B, 1, T, D
            y = self.ln(yMLP)

            # Residual update, renormalized.
            x = self.ln(x + y)

        logits = x.view(B, T, D) @ self.lm_head

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: int | None = None,
    ) -> torch.Tensor:
        for _ in range(max_new_tokens):
            idx_cond = idx
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < values[:, [-1]]] = float("-inf")
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
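

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): the hyperparameters, sequence
# lengths, and sampling settings below are arbitrary toy values for a quick
# smoke test, not a configuration recommended by the authors.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    config = BDHConfig(
        n_layer=2,
        n_embd=64,
        n_head=4,
        mlp_internal_dim_multiplier=16,
        vocab_size=256,
    )
    model = BDH(config)

    # Random byte-level token ids: a batch of 2 sequences of length 32.
    idx = torch.randint(0, config.vocab_size, (2, 32))
    targets = torch.randint(0, config.vocab_size, (2, 32))

    logits, loss = model(idx, targets)
    print("logits:", tuple(logits.shape))  # (2, 32, 256)
    print("loss:", loss.item())

    model.eval()
    with torch.no_grad():
        out = model.generate(idx[:, :8], max_new_tokens=16, temperature=1.0, top_k=50)
    print("generated:", tuple(out.shape))  # (2, 8 + 16) = (2, 24)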