|
|
""" |
|
|
MonoidForCausalLM โ Causal Monoid Language Model (HuggingFace Compatible) |
|
|
MonoidForCausalLM โ ๅนบๅ็พคๅ ๆ่ฏญ่จๆจกๅ (ๅ
ผๅฎน HuggingFace) |
|
|
|
|
|
Architecture / ๆถๆๆฆ่ฆ: |
|
|
Replace softmax attention with a monoid parallel-scan recurrence. |
|
|
็จๅนบๅ็พคๅนถ่กๆซๆ้ๆจๆฟไปฃ softmax ๆณจๆๅใ |
|
|
|
|
|
Core idea / ๆ ธๅฟๆๆณ: |
|
|
Softmax attention computes o_t = ฮฃ_{iโคt} softmax(q_tยทk_i) v_i |
|
|
โ requires O(T) KV-cache per layer at inference. |
|
|
Softmax ๆณจๆๅ่ฎก็ฎ o_t = ฮฃ_{iโคt} softmax(q_tยทk_i) v_i |
|
|
โ ๆจ็ๆถๆฏๅฑ้่ฆ O(T) ็ KV ็ผๅญใ |
|
|
|
|
|
Monoid attention compresses the entire causal history into a |
|
|
fixed-size state matrix S_t โ โ^{dรd} per head: |
|
|
S_t = ฮฑ_t ยท S_{t-1} + k_t โ v_t (explicit causal recurrence) |
|
|
o_t = q_t ยท S_t (state readout) |
|
|
ๅนบๅ็พคๆณจๆๅๅฐๅฎๆดๅ ๆๅๅฒๅ็ผฉๅฐๆฏไธชๅคดไธไธชๅบๅฎๅคงๅฐ็็ถๆ็ฉ้ต S_t: |
|
|
S_t = ฮฑ_t ยท S_{t-1} + k_t โ v_t (ๆพๅผๅ ๆ้ๆจ) |
|
|
o_t = q_t ยท S_t (็ถๆ่ฏปๅบ) |
|
|
|
|
|
This is a monoid because the binary operator: |
|
|
(log_ฮฑ, S) โ (log_ฮฒ, X) = (log_ฮฑ + log_ฮฒ, exp(log_ฮฒ)ยทS + X) |
|
|
is associative โ enables parallel prefix scan for training, |
|
|
and O(1) sequential update for inference. |
|
|
่ฟๆฏไธไธชๅนบๅ็พค๏ผๅ ไธบไบๅ
็ฎๅญ: |
|
|
(log_ฮฑ, S) โ (log_ฮฒ, X) = (log_ฮฑ + log_ฮฒ, exp(log_ฮฒ)ยทS + X) |
|
|
ๆปก่ถณ็ปๅๅพ โ ่ฎญ็ปๆถๅฏ็จๅนถ่กๅ็ผๆซๆ๏ผๆจ็ๆถ O(1) ้ๆญฅ้ๆจใ |
|
|
|
|
|
Key properties / ๅ
ณ้ฎ็นๆง: |
|
|
โ Explicit causal modeling โ ฮฑ_t gate explicitly controls how fast |
|
|
past information decays, making causality a first-class citizen. |
|
|
ๆพๅผๅ ๆๅปบๆจก โ ฮฑ_t ่กฐๅ้จๆพๅผๆงๅถๅๅฒไฟกๆฏ็้ๅฟ้็๏ผ |
|
|
ๅ ๆๆงๆฏไธ็ญๅ
ฌๆฐ่้้ mask ๆฝๅ ็็บฆๆใ |
|
|
|
|
|
โ Monoid state compression โ the full causal prefix x_{1:t} is |
|
|
lossily compressed into a fixed-size (dรd) state matrix per head. |
|
|
No O(T) KV-cache needed; inference is O(1) per token per layer. |
|
|
ๅนบๅ็พค็ถๆๅ็ผฉ โ ๅฎๆดๅ ๆๅ็ผ x_{1:t} ่ขซๆๆๅ็ผฉๅฐๆฏไธชๅคด |
|
|
ๅบๅฎๅคงๅฐ็ (dรd) ็ถๆ็ฉ้ตไธญใๆ ้ O(T) KV ็ผๅญ๏ผ |
|
|
ๆจ็ๆถๆฏๅฑๆฏ token O(1)ใ |
|
|
|
|
|
โ Parallel training โ associativity of โ enables O(T) parallel |
|
|
prefix scan (vs O(Tยฒ) for softmax attention). |
|
|
ๅนถ่ก่ฎญ็ป โ โ ็็ปๅๅพไฝฟ O(T) ๅนถ่กๅ็ผๆซๆๆไธบๅฏ่ฝ |
|
|
(ๅฏนๆฏ softmax ๆณจๆๅ็ O(Tยฒ))ใ |
|
|
|
|
|
Reuses LlamaMLP + LlamaRMSNorm from HuggingFace Transformers. |
|
|
ๅค็จ HuggingFace Transformers ็ LlamaMLP + LlamaRMSNormใ |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from typing import Optional, Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch import Tensor |
|
|
|
|
|
from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin, AutoConfig, AutoModelForCausalLM |
|
|
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
|
|
from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm |
|
|
|
|
|
try: |
|
|
from monoid_scan_cuda import parallel_scan, parallel_scan_with_state |
|
|
except ImportError: |
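    # Pure-PyTorch fallback when the fused CUDA kernel is not installed.
    # It runs the same recurrence as an O(T) sequential loop, so it is
    # functionally equivalent but slow for long sequences.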
def parallel_scan(log_alpha: Tensor, kv: Tensor) -> Tensor: |
|
|
"""Sequential prefix scan fallback: S_t = exp(log_ฮฑ_t)ยทS_{t-1} + kv_t.""" |
|
|
B, H, T, d1, d2 = kv.shape |
|
|
states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype) |
|
|
S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype) |
|
|
for t in range(T): |
|
|
decay = torch.exp(log_alpha[:, :, t]) |
|
|
while decay.dim() < S.dim(): |
|
|
decay = decay.unsqueeze(-1) |
|
|
S = S * decay + kv[:, :, t] |
|
|
states[:, :, t] = S |
|
|
return states |
|
|
|
|
|
def parallel_scan_with_state(log_alpha: Tensor, kv: Tensor): |
|
|
"""Sequential prefix scan that also returns the final (log_decay, S) state.""" |
|
|
B, H, T, d1, d2 = kv.shape |
|
|
states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype) |
|
|
S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype) |
|
|
log_acc = torch.zeros(B, H, 1, device=log_alpha.device, dtype=log_alpha.dtype) |
|
|
for t in range(T): |
|
|
decay = torch.exp(log_alpha[:, :, t]) |
|
|
while decay.dim() < S.dim(): |
|
|
decay = decay.unsqueeze(-1) |
|
|
S = S * decay + kv[:, :, t] |
|
|
states[:, :, t] = S |
|
|
log_acc = log_acc + log_alpha[:, :, t] |
|
|
return states, (log_acc, S) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MonoidConfig(PretrainedConfig): |
|
|
""" |
|
|
Configuration for the Monoid causal language model. |
|
|
ๅนบๅ็พคๅ ๆ่ฏญ่จๆจกๅ็้
็ฝฎใ |
|
|
|
|
|
Mirrors LlamaConfig for the shared components (MLP, RMSNorm, embedding) |
|
|
so that weights can be directly transferred from Llama checkpoints. |
|
|
ไธ LlamaConfig ็ๅ
ฑไบซ็ปไปถ (MLP, RMSNorm, embedding) ไฟๆไธ่ด, |
|
|
ไปฅไพฟไป Llama ๆฃๆฅ็น็ดๆฅ่ฟ็งปๆ้ใ |
|
|
""" |
|
|
model_type = "monoid" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
vocab_size: int = 32000, |
|
|
hidden_size: int = 576, |
|
|
intermediate_size: int = 1536, |
|
|
num_hidden_layers: int = 30, |
|
|
num_attention_heads: int = 9, |
|
|
head_dim: int = 64, |
|
|
max_position_embeddings: int = 2048, |
|
|
rms_norm_eps: float = 1e-5, |
|
|
hidden_act: str = "silu", |
|
|
mlp_bias: bool = False, |
|
|
attention_bias: bool = False, |
|
|
tie_word_embeddings: bool = True, |
|
|
initializer_range: float = 0.041666666666666664, |
|
|
        pad_token_id: int | None = None,
|
|
bos_token_id: int = 1, |
|
|
eos_token_id: int = 2, |
|
|
**kwargs, |
|
|
): |
|
|
super().__init__( |
|
|
pad_token_id=pad_token_id, |
|
|
bos_token_id=bos_token_id, |
|
|
eos_token_id=eos_token_id, |
|
|
tie_word_embeddings=tie_word_embeddings, |
|
|
**kwargs, |
|
|
) |
|
|
self.vocab_size = vocab_size |
|
|
self.hidden_size = hidden_size |
|
|
self.intermediate_size = intermediate_size |
|
|
self.num_hidden_layers = num_hidden_layers |
|
|
self.num_attention_heads = num_attention_heads |
|
|
self.head_dim = head_dim |
|
|
self.max_position_embeddings = max_position_embeddings |
|
|
self.rms_norm_eps = rms_norm_eps |
|
|
self.hidden_act = hidden_act |
|
|
self.mlp_bias = mlp_bias |
|
|
self.attention_bias = attention_bias |
|
|
self.initializer_range = initializer_range |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MonoidCache: |
|
|
""" |
|
|
Per-layer monoid state cache for autoregressive inference. |
|
|
่ชๅๅฝๆจ็็้ๅฑๅนบๅ็พค็ถๆ็ผๅญใ |
|
|
|
|
|
Unlike Transformer KV-Cache that stores all past keys & values (O(T) memory), |
|
|
each layer here stores exactly ONE state tuple: |
|
|
(log_decay_acc, S) where S โ โ^{B, H, d, d} |
|
|
This is the monoid "sum" of all past (log_ฮฑ_i, k_iโv_i) via โ. |
|
|
Memory is O(1) per layer regardless of sequence length. |
|
|
|
|
|
ไธๅไบ Transformer ็ KV-Cache (ๅญๅจๆๆ่ฟๅป็ key ๅ value, O(T) ๅ
ๅญ), |
|
|
่ฟ้ๆฏๅฑไป
ๅญๅจไธไธช็ถๆๅ
็ป: |
|
|
(log_decay_acc, S) ๅ
ถไธญ S โ โ^{B, H, d, d} |
|
|
่ฟๆฏๆๆ่ฟๅป็ (log_ฮฑ_i, k_iโv_i) ้่ฟ โ ็ดฏ็งฏ็ๅนบๅ็พค "ๅ"ใ |
|
|
ๆ ่ฎบๅบๅๅค้ฟ๏ผๆฏๅฑๅ
ๅญ O(1)ใ |
|
|
""" |
|
|
|
|
|
def __init__(self): |
|
|
self.states: list[tuple[Tensor, Tensor] | None] = [] |
|
|
self.seen_tokens: int = 0 |
|
|
|
|
|
def get_seq_length(self, layer_idx: int = 0) -> int: |
|
|
return self.seen_tokens |
|
|
|
|
|
def update(self, layer_idx: int, state: tuple[Tensor, Tensor]): |
|
|
"""Store the accumulated monoid state for a given layer. |
|
|
ๅญๅจๆๅฎๅฑ็็ดฏ็งฏๅนบๅ็พค็ถๆใ""" |
|
|
while len(self.states) <= layer_idx: |
|
|
self.states.append(None) |
|
|
self.states[layer_idx] = state |
|
|
|
|
|
def get_state(self, layer_idx: int) -> tuple[Tensor, Tensor] | None: |
|
|
"""Retrieve the accumulated monoid state for a given layer. |
|
|
่ทๅๆๅฎๅฑ็็ดฏ็งฏๅนบๅ็พค็ถๆใ""" |
|
|
if layer_idx < len(self.states): |
|
|
return self.states[layer_idx] |
|
|
return None |
|
|
|
|
|
def reorder_cache(self, beam_idx: torch.LongTensor): |
|
|
"""Reorder cache for beam search. ไธบ beam search ้ๆ็ผๅญใ""" |
|
|
for i, state in enumerate(self.states): |
|
|
if state is not None: |
|
|
log_d, kv = state |
|
|
self.states[i] = (log_d[beam_idx], kv[beam_idx]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def monoid_op( |
|
|
a: tuple[Tensor, Tensor], |
|
|
b: tuple[Tensor, Tensor], |
|
|
) -> tuple[Tensor, Tensor]: |
|
|
""" |
|
|
The monoid binary operator โ on (log-space decay, state matrix) pairs. |
|
|
ๅนบๅ็พคไบๅ
็ฎๅญ โ๏ผไฝ็จไบ (ๅฏนๆฐ่กฐๅ, ็ถๆ็ฉ้ต) ๅฏนใ |
|
|
|
|
|
Definition / ๅฎไน: |
|
|
(log_ฮฑ, S) โ (log_ฮฒ, X) = (log_ฮฑ + log_ฮฒ, exp(log_ฮฒ)ยทS + X) |
|
|
|
|
|
Why this is a monoid / ไธบไปไน่ฟๆฏๅนบๅ็พค: |
|
|
โข Associativity / ็ปๅๅพ: |
|
|
(a โ b) โ c = a โ (b โ c) โ |
|
|
This enables parallel prefix scan for training (reduce tree) |
|
|
and O(1) left-fold for inference (sequential append). |
|
|
็ปๅๅพไฝฟ่ฎญ็ปๆถๅฏไปฅ็จๅนถ่กๅ็ผๆซๆ (ๅฝ็บฆๆ ), |
|
|
ๆจ็ๆถๅฏไปฅ O(1) ๅทฆๆๅ (้ๆญฅ่ฟฝๅ )ใ |
|
|
|
|
|
โข Identity / ๅไฝๅ
: |
|
|
e = (0, 0) โ e โ a = a โ e = a โ |
|
|
|
|
|
Why log-space / ไธบไปไน็จๅฏนๆฐ็ฉบ้ด: |
|
|
Working in log-space for the decay factor avoids numerical |
|
|
underflow when ฮฑ^T โ 0 for long sequences. |
|
|
่กฐๅๅ ๅญๅจๅฏนๆฐ็ฉบ้ดไธญ่ฟ็ฎ๏ผ้ฟๅ
้ฟๅบๅไธ ฮฑ^T โ 0 ็ๆฐๅผไธๆบขใ |
|
|
|
|
|
Causal semantics / ๅ ๆ่ฏญไน: |
|
|
S_t = ฮฑ_t ยท S_{t-1} + k_t โ v_t |
|
|
The decay ฮฑ_t โ (0,1) explicitly controls how much of the past |
|
|
the model retains. This is *explicit causal modeling* โ the model |
|
|
must learn to balance retention vs novelty at every timestep. |
|
|
่กฐๅ ฮฑ_t โ (0,1) ๆพๅผๆงๅถๆจกๅไฟ็ๅคๅฐ่ฟๅปไฟกๆฏใ |
|
|
่ฟๅฐฑๆฏ *ๆพๅผๅ ๆๅปบๆจก* โ ๆจกๅๅฟ
้กปๅจๆฏไธชๆถ้ดๆญฅๅญฆไน ๅฆไฝ |
|
|
ๅนณ่กกไฟ็ๆงไฟกๆฏไธๅธๆถๆฐไฟกๆฏใ |
|
|
""" |
|
|
log_a, kv_a = a |
|
|
log_b, kv_b = b |
|
|
|
|
|
new_log = log_a + log_b |
|
|
decay_b = torch.exp(log_b) |
|
|
while decay_b.dim() < kv_a.dim(): |
|
|
decay_b = decay_b.unsqueeze(-1) |
|
|
|
|
|
return new_log, kv_a * decay_b + kv_b |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MonoidAttention(nn.Module): |
|
|
""" |
|
|
Monoid Causal Attention โ replaces softmax attention entirely. |
|
|
ๅนบๅ็พคๅ ๆๆณจๆๅ โ ๅฎๅ
จๆฟไปฃ softmax ๆณจๆๅใ |
|
|
|
|
|
Key differences from standard attention / ไธๆ ๅๆณจๆๅ็ๅ
ณ้ฎๅบๅซ: |
|
|
โ No RoPE / positional encoding โ position is implicitly encoded |
|
|
by the causal decay gate ฮฑ_t. The model learns *when* to forget |
|
|
rather than encoding *where* tokens are. |
|
|
ไธไฝฟ็จ RoPE / ไฝ็ฝฎ็ผ็ โ ไฝ็ฝฎไฟกๆฏ็ฑๅ ๆ่กฐๅ้จ ฮฑ_t ้ๅผ็ผ็ ใ |
|
|
ๆจกๅๅญฆไน *ไฝๆถ้ๅฟ* ่้็ผ็ token *ๅจๅช้*ใ |
|
|
|
|
|
โ No KV-Cache โ replaced by MonoidCache with O(1) state per layer. |
|
|
Each state S โ โ^{Hรdรd} is a compressed summary of ALL past tokens. |
|
|
ไธไฝฟ็จ KV ็ผๅญ โ ็ฑ O(1) ็ MonoidCache ็ถๆๆฟไปฃใ |
|
|
ๆฏไธช็ถๆ S โ โ^{Hรdรd} ๆฏๆๆ่ฟๅป token ็ๅ็ผฉๆ่ฆใ |
|
|
|
|
|
โ No attention mask โ causality is built into the recurrence itself. |
|
|
S_t only depends on S_{t-1} and the current token by construction. |
|
|
ไธไฝฟ็จๆณจๆๅๆฉ็ โ ๅ ๆๆงๅ
ๅปบไบ้ๆจ็ปๆๆฌ่บซใ |
|
|
S_t ไป
ไพ่ต S_{t-1} ๅๅฝๅ token๏ผ็ปๆไธไฟ่ฏๅ ๆๆงใ |
|
|
|
|
|
Computation / ่ฎก็ฎ: |
|
|
Training (parallel scan, O(T)): |
|
|
k_t = SiLU(k_proj(x_t)) # non-negative keys for PSD state |
|
|
S_t = ฮฑ_t ยท S_{t-1} + k_t โ v_t # monoid recurrence via prefix scan |
|
|
o_t = q_t ยท S_t # linear readout from state |
|
|
|
|
|
Inference (RNN mode, O(1) per token): |
|
|
Same recurrence, but applied one token at a time. |
|
|
|
|
|
่ฎญ็ป (ๅนถ่กๆซๆ, O(T)): |
|
|
k_t = SiLU(k_proj(x_t)) # ้่ด key ไฟ่ฏ็ถๆ็ฉ้ตๅๆญฃๅฎ |
|
|
S_t = ฮฑ_t ยท S_{t-1} + k_t โ v_t # ้่ฟๅ็ผๆซๆๅฎ็ฐๅนบๅ็พค้ๆจ |
|
|
o_t = q_t ยท S_t # ไป็ถๆไธญ็บฟๆง่ฏปๅบ |
|
|
|
|
|
ๆจ็ (RNN ๆจกๅผ, ๆฏ token O(1)): |
|
|
ๅไธ้ๆจๅ
ฌๅผ, ไฝ้ token ้กบๅบๅบ็จใ |
|
|
""" |
|
|
|
|
|
def __init__(self, config: MonoidConfig, layer_idx: int): |
|
|
super().__init__() |
|
|
self.layer_idx = layer_idx |
|
|
self.hidden_size = config.hidden_size |
|
|
self.num_heads = config.num_attention_heads |
|
|
self.head_dim = config.head_dim |
|
|
self.scaling = self.head_dim ** -0.5 |
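
        # Per-head q/k/v/o projections, with the same shapes as standard
        # multi-head attention (hidden_size -> num_heads * head_dim and back).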
self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) |
|
|
self.k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) |
|
|
self.v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) |
|
|
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias) |
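
        # Per-head decay gate: α_t = sigmoid(decay_proj(x_t)) ∈ (0, 1) controls
        # how quickly past information is forgotten.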
self.decay_proj = nn.Linear(config.hidden_size, self.num_heads, bias=True) |
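
        # Per-head RMSNorm applied to q and k before the recurrence.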
self.q_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps) |
|
|
self.k_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps) |
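
        # Learnable initial state S_0: one (head_dim × head_dim) matrix per head,
        # broadcast over the batch.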
self.h0 = nn.Parameter(torch.zeros(1, self.num_heads, self.head_dim, self.head_dim)) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: Tensor, |
|
|
monoid_cache: MonoidCache | None = None, |
|
|
use_cache: bool = False, |
|
|
) -> tuple[Tensor, tuple[Tensor, Tensor] | None]: |
|
|
""" |
|
|
Args: |
|
|
hidden_states: [B, T, hidden_size] |
|
|
monoid_cache: O(1) state cache for inference |
|
|
ๆจ็็จ O(1) ็ถๆ็ผๅญ |
|
|
use_cache: whether to use/update the cache |
|
|
ๆฏๅฆไฝฟ็จ/ๆดๆฐ็ผๅญ |
|
|
|
|
|
Returns: |
|
|
output: [B, T, hidden_size] |
|
|
final_state: (log_decay_acc, S) or None |
|
|
""" |
|
|
B, T, _ = hidden_states.shape |
|
|
H, d = self.num_heads, self.head_dim |
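
        # Project to per-head queries/keys/values: [B, T, H*d] -> [B, H, T, d].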
q = self.q_proj(hidden_states).view(B, T, H, d).transpose(1, 2) |
|
|
k = self.k_proj(hidden_states).view(B, T, H, d).transpose(1, 2) |
|
|
v = self.v_proj(hidden_states).view(B, T, H, d).transpose(1, 2) |
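
        # Normalize q and k per head; fold the 1/sqrt(d) scaling into q.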
q = self.q_norm(q) * self.scaling |
|
|
k = self.k_norm(k) |
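
        # SiLU feature map on the keys before forming the outer products k_t ⊗ v_t.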
k = torch.nn.functional.silu(k) |
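
        # Per-head decay gate α_t ∈ (0, 1), reshaped to [B, H, T, 1] and kept in
        # log-space to avoid underflow when many decays are multiplied together.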
alpha = torch.sigmoid(self.decay_proj(hidden_states)) |
|
|
alpha = alpha.transpose(1, 2).unsqueeze(-1) |
|
|
log_alpha = torch.log(alpha.clamp(min=1e-6)) |
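
        # --- Decode path: single-token RNN update, O(1) per token ---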
if use_cache and T == 1: |
|
|
|
|
|
|
|
|
kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, 0], v[:, :, 0]) |
|
|
log_t = log_alpha[:, :, 0] |
|
|
|
|
|
prev = monoid_cache.get_state(self.layer_idx) if monoid_cache else None |
|
|
if prev is None: |
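                # First token ever seen: fold this update into the learnable
                # initial state, S_1 = α_1 · S_0 + k_1 ⊗ v_1.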
decay_t = torch.exp(log_t) |
|
|
while decay_t.dim() < self.h0.dim(): |
|
|
decay_t = decay_t.unsqueeze(-1) |
|
|
new_state = (log_t, self.h0.expand(B, -1, -1, -1) * decay_t + kv_t) |
|
|
else: |
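                # One monoid step against the cached prefix state.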
new_state = monoid_op(prev, (log_t, kv_t)) |
|
|
|
|
|
if monoid_cache is not None: |
|
|
monoid_cache.update(self.layer_idx, new_state) |
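
            # Readout from the updated state: o_t = q_t · S_t.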
o = torch.einsum('bhd, bhde -> bhe', q[:, :, 0], new_state[1]) |
|
|
|
|
|
|
|
|
o = o.contiguous().view(B, 1, -1) |
|
|
return self.o_proj(o), new_state |
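
        # --- Prefill path (use_cache=True, T > 1): run the recurrence
        # sequentially over the chunk and cache the final (log_decay_acc, S) ---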
if use_cache: |
|
|
S = self.h0.expand(B, -1, -1, -1).clone() |
|
|
log_acc = torch.zeros(B, H, 1, device=hidden_states.device, dtype=q.dtype) |
|
|
o_parts = [] |
|
|
for t in range(T): |
|
|
kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, t], v[:, :, t]) |
|
|
decay = torch.exp(log_alpha[:, :, t]) |
|
|
while decay.dim() < S.dim(): |
|
|
decay = decay.unsqueeze(-1) |
|
|
S = S * decay + kv_t |
|
|
o_parts.append(torch.einsum('bhd, bhde -> bhe', q[:, :, t], S)) |
|
|
log_acc = log_acc + log_alpha[:, :, t] |
|
|
|
|
|
final_state = (log_acc, S) |
|
|
if monoid_cache is not None: |
|
|
monoid_cache.update(self.layer_idx, final_state) |
|
|
|
|
|
o = torch.stack(o_parts, dim=2) |
|
|
o = o.transpose(1, 2).contiguous().view(B, T, -1) |
|
|
return self.o_proj(o), final_state |
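
        # --- Training path (no cache): parallel prefix scan over the whole
        # sequence, O(T) work instead of O(T²) attention ---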
kv = torch.einsum('bhtd, bhte -> bhtde', k, v) |
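
        # Prefix-scan the per-token updates (log_α_t, k_t ⊗ v_t);
        # states[:, :, t] holds S_t without the initial-state term.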
states = parallel_scan(log_alpha, kv) |
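
        # Add the decayed contribution of the learnable initial state:
        #     S_t += (∏_{i≤t} α_i) · S_0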
cum_log_alpha = torch.cumsum(log_alpha, dim=2) |
|
|
h0_decay = torch.exp(cum_log_alpha).unsqueeze(-1) |
|
|
states = states + h0_decay * self.h0.unsqueeze(2) |
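
        # Readout o_t = q_t · S_t for every position in parallel.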
o = torch.einsum('bhtd, bhtde -> bhte', q, states) |
|
|
|
|
|
o = o.transpose(1, 2).contiguous().view(B, T, -1) |
|
|
return self.o_proj(o), None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MonoidDecoderLayer(nn.Module): |
|
|
""" |
|
|
Pre-Norm Transformer block with Monoid attention. |
|
|
ไฝฟ็จๅนบๅ็พคๆณจๆๅ็ Pre-Norm Transformer ๅใ |
|
|
|
|
|
Data flow / ๆฐๆฎๆต: |
|
|
x โ RMSNorm โ MonoidAttn โ +residual โ RMSNorm โ LlamaMLP โ +residual โ out |
|
|
|
|
|
The MLP and RMSNorm are identical to Llama (weights transferred directly). |
|
|
Only MonoidAttention is the novel component. |
|
|
MLP ๅ RMSNorm ไธ Llama ๅฎๅ
จ็ธๅ (ๆ้็ดๆฅ่ฟ็งป)ใ |
|
|
ไป
MonoidAttention ๆฏๅ
จๆฐ็ปไปถใ |
|
|
""" |
|
|
gradient_checkpointing = False |
|
|
|
|
|
def __init__(self, config: MonoidConfig, layer_idx: int): |
|
|
super().__init__() |
|
|
self.self_attn = MonoidAttention(config, layer_idx) |
|
|
self.mlp = LlamaMLP(config) |
|
|
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
|
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: Tensor, |
|
|
monoid_cache: MonoidCache | None = None, |
|
|
use_cache: bool = False, |
|
|
) -> Tensor: |
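        # Monoid attention sub-block (pre-norm + residual).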
residual = hidden_states |
|
|
hidden_states = self.input_layernorm(hidden_states) |
|
|
hidden_states, _ = self.self_attn(hidden_states, monoid_cache=monoid_cache, use_cache=use_cache) |
|
|
hidden_states = residual + hidden_states |
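
        # Feed-forward sub-block (pre-norm + residual).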
residual = hidden_states |
|
|
hidden_states = self.post_attention_layernorm(hidden_states) |
|
|
hidden_states = self.mlp(hidden_states) |
|
|
hidden_states = residual + hidden_states |
|
|
|
|
|
return hidden_states |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MonoidPreTrainedModel(PreTrainedModel): |
|
|
config_class = MonoidConfig |
|
|
base_model_prefix = "model" |
|
|
supports_gradient_checkpointing = True |
|
|
_no_split_modules = ["MonoidDecoderLayer"] |
|
|
|
|
|
def _init_weights(self, module: nn.Module): |
|
|
std = self.config.initializer_range |
|
|
if isinstance(module, nn.Linear): |
|
|
module.weight.data.normal_(mean=0.0, std=std) |
|
|
if module.bias is not None: |
|
|
module.bias.data.zero_() |
|
|
elif isinstance(module, nn.Embedding): |
|
|
module.weight.data.normal_(mean=0.0, std=std) |
|
|
if module.padding_idx is not None: |
|
|
module.weight.data[module.padding_idx].zero_() |
|
|
|
|
|
if isinstance(module, MonoidAttention): |
|
|
nn.init.constant_(module.decay_proj.bias, 4.0) |
|
|
|
|
|
class MonoidModel(MonoidPreTrainedModel): |
|
|
""" |
|
|
Stack of MonoidDecoderLayers with token embedding and final norm. |
|
|
ๅนบๅ็พค่งฃ็ ๅฑๅ ๅ , ๅธฆ token ๅตๅ
ฅๅๆ็ปๅฝไธๅใ |
|
|
|
|
|
Forward: embed_tokens โ N ร MonoidDecoderLayer โ final_norm |
|
|
ๅๅ: embed_tokens โ N ร MonoidDecoderLayer โ final_norm |
|
|
""" |
|
|
|
|
|
def __init__(self, config: MonoidConfig): |
|
|
super().__init__(config) |
|
|
self.padding_idx = config.pad_token_id |
|
|
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) |
|
|
self.layers = nn.ModuleList( |
|
|
[MonoidDecoderLayer(config, i) for i in range(config.num_hidden_layers)] |
|
|
) |
|
|
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
|
self.gradient_checkpointing = False |
|
|
self.post_init() |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids: Tensor | None = None, |
|
|
inputs_embeds: Tensor | None = None, |
|
|
monoid_cache: MonoidCache | None = None, |
|
|
use_cache: bool = False, |
|
|
) -> BaseModelOutputWithPast: |
|
|
if inputs_embeds is None: |
|
|
inputs_embeds = self.embed_tokens(input_ids) |
|
|
|
|
|
hidden_states = inputs_embeds |
|
|
for layer in self.layers: |
|
|
if self.gradient_checkpointing and self.training and not use_cache: |
|
|
hidden_states = self._gradient_checkpointing_func( |
|
|
layer.__call__, |
|
|
hidden_states, |
|
|
monoid_cache, |
|
|
use_cache, |
|
|
) |
|
|
else: |
|
|
hidden_states = layer(hidden_states, monoid_cache=monoid_cache, use_cache=use_cache) |
|
|
|
|
|
hidden_states = self.norm(hidden_states) |
|
|
|
|
|
return BaseModelOutputWithPast( |
|
|
last_hidden_state=hidden_states, |
|
|
past_key_values=monoid_cache, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MonoidForCausalLM(MonoidPreTrainedModel, GenerationMixin): |
|
|
""" |
|
|
Monoid-based causal language model with LM head. |
|
|
ๅบไบๅนบๅ็พค็ๅ ๆ่ฏญ่จๆจกๅ, ๅธฆ่ฏญ่จๆจกๅๅคดใ |
|
|
|
|
|
The architecture in one sentence: |
|
|
"Llama body + Monoid mind" โ reuse Llama's proven MLP/embeddings, |
|
|
replace attention with monoid state compression for O(1) inference. |
|
|
|
|
|
ไธๅฅ่ฏๆฆๆฌๆถๆ: |
|
|
"Llama ็่บซไฝ + ๅนบๅ็พค็ๆ็ปด" โ ๅค็จ Llama ๆ็็ MLP/ๅตๅ
ฅๅฑ, |
|
|
็จๅนบๅ็พค็ถๆๅ็ผฉๆฟๆขๆณจๆๅ, ๅฎ็ฐ O(1) ๆจ็ใ |
|
|
""" |
|
|
_tied_weights_keys = ["lm_head.weight"] |
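
    # Signal to HuggingFace generation utilities that this model carries
    # recurrent state across decoding steps.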
_is_stateful = True |
|
|
|
|
|
def __init__(self, config: MonoidConfig): |
|
|
super().__init__(config) |
|
|
self.model = MonoidModel(config) |
|
|
self.vocab_size = config.vocab_size |
|
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
|
|
self.post_init() |
|
|
|
|
|
def get_input_embeddings(self): |
|
|
return self.model.embed_tokens |
|
|
|
|
|
def set_input_embeddings(self, value): |
|
|
self.model.embed_tokens = value |
|
|
|
|
|
def get_output_embeddings(self): |
|
|
return self.lm_head |
|
|
|
|
|
def set_output_embeddings(self, new_embeddings): |
|
|
self.lm_head = new_embeddings |
|
|
|
|
|
def prepare_inputs_for_generation( |
|
|
self, |
|
|
input_ids: Tensor, |
|
|
past_key_values=None, |
|
|
attention_mask: Tensor | None = None, |
|
|
inputs_embeds: Tensor | None = None, |
|
|
**kwargs, |
|
|
) -> dict: |
|
|
""" |
|
|
Called by GenerationMixin at each decoding step. |
|
|
GenerationMixin ๅจๆฏไธช่งฃ็ ๆญฅ่ฐ็จๆญคๆนๆณใ |
|
|
|
|
|
HuggingFace may pass a DynamicCache; we intercept and replace |
|
|
it with MonoidCache since we don't use standard KV-cache. |
|
|
HuggingFace ๅฏ่ฝไผ ๅ
ฅ DynamicCache; ๆไปฌๆฆๆชๅนถๆฟๆขไธบ |
|
|
MonoidCache, ๅ ไธบๆไปฌไธไฝฟ็จๆ ๅ KV ็ผๅญใ |
|
|
""" |
|
|
|
|
|
|
|
|
if past_key_values is not None and not isinstance(past_key_values, MonoidCache): |
|
|
past_key_values = None |
|
|
|
|
|
if past_key_values is not None and past_key_values.seen_tokens > 0: |
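            # With a warm cache the state already summarizes the prefix, so only
            # the newest token needs to be fed to the model.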
input_ids = input_ids[:, -1:] |
|
|
|
|
|
model_inputs = { |
|
|
"input_ids": input_ids, |
|
|
"monoid_cache": past_key_values, |
|
|
"use_cache": True, |
|
|
} |
|
|
return model_inputs |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids: Tensor | None = None, |
|
|
attention_mask: Tensor | None = None, |
|
|
|
|
|
position_ids: Tensor | None = None, |
|
|
|
|
|
past_key_values: MonoidCache | None = None, |
|
|
inputs_embeds: Tensor | None = None, |
|
|
labels: Tensor | None = None, |
|
|
use_cache: bool | None = None, |
|
|
monoid_cache: MonoidCache | None = None, |
|
|
output_attentions: bool | None = None, |
|
|
output_hidden_states: bool | None = None, |
|
|
logits_to_keep: int | Tensor = 0, |
|
|
**kwargs, |
|
|
) -> CausalLMOutputWithPast: |
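
        # Accept the cache under either keyword; drop anything that is not a
        # MonoidCache (e.g. a DynamicCache injected by generate()).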
cache = monoid_cache or past_key_values |
|
|
|
|
|
|
|
|
|
|
|
if cache is not None and not isinstance(cache, MonoidCache): |
|
|
cache = None |
|
|
|
|
|
if use_cache and cache is None: |
|
|
cache = MonoidCache() |
|
|
|
|
|
outputs = self.model( |
|
|
input_ids=input_ids, |
|
|
inputs_embeds=inputs_embeds, |
|
|
monoid_cache=cache, |
|
|
use_cache=bool(use_cache), |
|
|
) |
|
|
|
|
|
hidden_states = outputs.last_hidden_state |
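
        # Compute logits only for the trailing `logits_to_keep` positions
        # (0 means keep all positions).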
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) and logits_to_keep > 0 else slice(None) |
|
|
logits = self.lm_head(hidden_states[:, slice_indices, :]) |
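
        # Standard causal-LM loss: predict token t+1 from position t.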
loss = None |
|
|
if labels is not None: |
|
|
shift_logits = logits[..., :-1, :].contiguous() |
|
|
shift_labels = labels[..., 1:].contiguous() |
|
|
loss = nn.functional.cross_entropy( |
|
|
shift_logits.view(-1, self.vocab_size), |
|
|
shift_labels.view(-1), |
|
|
ignore_index=-100, |
|
|
) |
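
        # Track how many tokens this cache has absorbed; used by
        # prepare_inputs_for_generation to decide when to feed only the last token.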
if cache is not None: |
|
|
cache.seen_tokens += (input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]) |
|
|
|
|
|
return CausalLMOutputWithPast( |
|
|
loss=loss, |
|
|
logits=logits, |
|
|
past_key_values=cache, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
AutoConfig.register("monoid", MonoidConfig) |
|
|
AutoModelForCausalLM.register(MonoidConfig, MonoidForCausalLM) |
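
# Usage sketch (illustrative, not executed here): once registered, the model can
# be built through the Auto API and driven by GenerationMixin.generate(), e.g.
#
#   config = MonoidConfig(vocab_size=32000)
#   model = AutoModelForCausalLM.from_config(config)
#   out_ids = model.generate(input_ids, max_new_tokens=16)  # input_ids: LongTensor of token ids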
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu') |
|
|
print(f'Device: {device}') |
|
|
|
|
|
config = MonoidConfig( |
|
|
vocab_size=49152, |
|
|
hidden_size=576, |
|
|
intermediate_size=1536, |
|
|
num_hidden_layers=30, |
|
|
num_attention_heads=9, |
|
|
head_dim=64, |
|
|
rms_norm_eps=1e-5, |
|
|
hidden_act="silu", |
|
|
tie_word_embeddings=True, |
|
|
) |
|
|
model = MonoidForCausalLM(config).to(device) |
|
|
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) |
|
|
print(f'Parameters: {n_params:,}') |
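
    # Training-style forward pass (no cache): exercises the parallel-scan path.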
B, T = 2, 64 |
|
|
ids = torch.randint(0, config.vocab_size, (B, T), device=device) |
|
|
out = model(ids, labels=ids) |
|
|
    print(f'Train → logits: {out.logits.shape}, loss: {out.loss:.4f}')
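
    # Inference: prefill the prompt, then decode one token with the O(1) MonoidCache.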
prompt = torch.randint(0, config.vocab_size, (1, 8), device=device) |
|
|
cache = MonoidCache() |
|
|
|
|
|
prefill_out = model(prompt, use_cache=True, monoid_cache=cache) |
|
|
    print(f'Prefill → logits: {prefill_out.logits.shape}, cache seen: {cache.seen_tokens}')
|
|
|
|
|
next_tok = prefill_out.logits[:, -1:].argmax(dim=-1) |
|
|
step_out = model(next_tok, use_cache=True, monoid_cache=cache) |
|
|
    print(f'Decode → logits: {step_out.logits.shape}, cache seen: {cache.seen_tokens}')
|
|
|
|
|
|
|
|
    print('\nMonoid associativity check:')
|
|
a = (torch.randn(1, 1, 1), torch.randn(1, 1, 4, 4)) |
|
|
b = (torch.randn(1, 1, 1), torch.randn(1, 1, 4, 4)) |
|
|
c = (torch.randn(1, 1, 1), torch.randn(1, 1, 4, 4)) |
|
|
ab_c = monoid_op(monoid_op(a, b), c) |
|
|
a_bc = monoid_op(a, monoid_op(b, c)) |
|
|
err = (ab_c[1] - a_bc[1]).abs().max().item() |
|
|
    print(f'  |(a⊕b)⊕c - a⊕(b⊕c)| = {err:.2e}')
|
|
|
|
|
print('\nDone.') |
|
|
|