# Copyright 2025 Pathway Technology, Inc.
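#
# Single-file PyTorch implementation of the BDH language model defined below:
# per-head sparse ReLU latents, encoder/decoder weights shared across every
# layer, and strictly causal linear attention with rotary position embeddings.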
import dataclasses
import math

import torch
import torch.nn.functional as F
from torch import nn


@dataclasses.dataclass
class BDHConfig:
    n_layer: int = 6
    n_embd: int = 256
    dropout: float = 0.1
    n_head: int = 4
    # Per-head latent width is N = mlp_internal_dim_multiplier * n_embd // n_head.
    mlp_internal_dim_multiplier: int = 128
    vocab_size: int = 256


def get_freqs(n, theta, dtype):
    # RoPE-style inverse frequencies; indices are quantized in steps of 2 so that
    # each even/odd pair of dimensions rotates at the same frequency.
    def quantize(t, q=2):
        return (t / q).floor() * q

    return (
        1.0
        / (theta ** (quantize(torch.arange(0, n, 1, dtype=dtype)) / n))
        / (2 * math.pi)
    )


class Attention(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        nh = config.n_head
        D = config.n_embd
        N = config.mlp_internal_dim_multiplier * D // nh
        self.freqs = torch.nn.Buffer(
            get_freqs(N, theta=2**16, dtype=torch.float32).view(1, 1, 1, N)
        )

    @staticmethod
    def phases_cos_sin(phases):
        phases = (phases % 1) * (2 * math.pi)
        phases_cos = torch.cos(phases)
        phases_sin = torch.sin(phases)
        return phases_cos, phases_sin

    @staticmethod
    def rope(phases, v):
        # Rotate each even/odd pair of dimensions: (v_even, v_odd) -> (-v_odd, v_even).
        v_rot = torch.stack((-v[..., 1::2], v[..., ::2]), dim=-1).view(*v.size())
        phases_cos, phases_sin = Attention.phases_cos_sin(phases)
        return (v * phases_cos).to(v.dtype) + (v_rot * phases_sin).to(v.dtype)

    def forward(self, Q, K, V):
        assert self.freqs.dtype == torch.float32
        assert K is Q  # queries and keys are the same tensor
        _, _, T, _ = Q.size()
        r_phases = (
            torch.arange(
                0,
                T,
                device=self.freqs.device,
                dtype=self.freqs.dtype,
            ).view(1, 1, -1, 1)
        ) * self.freqs
        QR = self.rope(r_phases, Q)
        KR = QR
        # Current attention: linear (no softmax), strictly causal (diagonal excluded).
        scores = (QR @ KR.mT).tril(diagonal=-1)
        return scores @ V


class BDH(nn.Module):
    def __init__(self, config: BDHConfig):
        super().__init__()
        assert config.vocab_size is not None
        self.config = config
        nh = config.n_head
        D = config.n_embd
        N = config.mlp_internal_dim_multiplier * D // nh

        # A single set of encoder/decoder/attention weights is reused by every layer.
        self.decoder = nn.Parameter(torch.zeros((nh * N, D)).normal_(std=0.02))
        self.encoder = nn.Parameter(torch.zeros((nh, D, N)).normal_(std=0.02))
        self.attn = Attention(config)
        self.ln = nn.LayerNorm(D, elementwise_affine=False, bias=False)
        self.embed = nn.Embedding(config.vocab_size, D)
        self.drop = nn.Dropout(config.dropout)
        self.encoder_v = nn.Parameter(torch.zeros((nh, D, N)).normal_(std=0.02))
        self.lm_head = nn.Parameter(
            torch.zeros((D, config.vocab_size)).normal_(std=0.02)
        )
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        C = self.config
        B, T = idx.size()
        D = C.n_embd
        nh = C.n_head
        N = D * C.mlp_internal_dim_multiplier // nh

        x = self.embed(idx).unsqueeze(1)
        # actually helps with training
        x = self.ln(x)  # B, 1, T, D

        for level in range(C.n_layer):
            # Project into the per-head sparse latent space.
            x_latent = x @ self.encoder
            x_sparse = F.relu(x_latent)  # B, nh, T, N

            # Linear causal attention over the sparse latents; values stay in D dims.
            yKV = self.attn(
                Q=x_sparse,
                K=x_sparse,
                V=x,
            )
            yKV = self.ln(yKV)

            # Gate the input latents with the attention-output latents.
            y_latent = yKV @ self.encoder_v
            y_sparse = F.relu(y_latent)
            xy_sparse = x_sparse * y_sparse  # B, nh, T, N
            xy_sparse = self.drop(xy_sparse)

            # Decode all heads back to the model dimension and apply the residual.
            yMLP = (
                xy_sparse.transpose(1, 2).reshape(B, 1, T, N * nh) @ self.decoder
            )  # B, 1, T, D
            y = self.ln(yMLP)
            x = self.ln(x + y)

        logits = x.view(B, T, D) @ self.lm_head

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: int | None = None,
    ) -> torch.Tensor:
        for _ in range(max_new_tokens):
            # No KV cache or context cropping: the full sequence is re-run each step.
            idx_cond = idx
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # Keep only the top_k most likely tokens before sampling.
                values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < values[:, [-1]]] = float("-inf")
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
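

if __name__ == "__main__":
    # Minimal smoke test, added as a hedged usage sketch (not part of the original
    # file). The tiny hyperparameters, batch size, sequence length, and random
    # byte-level tokens below are arbitrary illustrative choices.
    config = BDHConfig(n_layer=2, n_embd=64, n_head=4, mlp_internal_dim_multiplier=16)
    model = BDH(config)

    batch, seq_len = 2, 32
    idx = torch.randint(0, config.vocab_size, (batch, seq_len))
    targets = torch.randint(0, config.vocab_size, (batch, seq_len))

    # Forward pass with targets returns logits of shape (B, T, vocab_size) and a loss.
    logits, loss = model(idx, targets)
    print(logits.shape, loss.item())

    # Sample a short continuation from a 4-token prompt.
    sample = model.generate(idx[:, :4], max_new_tokens=8, temperature=1.0, top_k=50)
    print(sample.shape)  # expected: torch.Size([2, 12])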