"""
MonoidForCausalLM - Causal Monoid Language Model (HuggingFace compatible)

Architecture:
    Replace softmax attention with a monoid parallel-scan recurrence.

Core idea:
    Softmax attention computes o_t = Σ_{i≤t} softmax(q_t·k_i) v_i
    - it requires an O(T) KV-cache per layer at inference.
    Monoid attention compresses the entire causal history into a
    fixed-size state matrix S_t ∈ ℝ^{d×d} per head:
        S_t = α_t · S_{t-1} + k_t ⊗ v_t    (explicit causal recurrence)
        o_t = q_t · S_t                    (state readout)
    This is a monoid because the binary operator
        (log_α, S) ⊕ (log_β, X) = (log_α + log_β, exp(log_β)·S + X)
    is associative → it enables a parallel prefix scan for training
    and an O(1) sequential update for inference.

Key properties:
    ✓ Explicit causal modeling - the gate α_t explicitly controls how fast
      past information decays, making causality a first-class citizen rather
      than a constraint imposed by a mask.
    ✓ Monoid state compression - the full causal prefix x_{1:t} is
      lossily compressed into a fixed-size (d×d) state matrix per head.
      No O(T) KV-cache is needed; inference is O(1) per token per layer.
    ✓ Parallel training - the associativity of ⊕ enables an O(T) parallel
      prefix scan (vs. O(T²) for softmax attention).

Reuses LlamaMLP + LlamaRMSNorm from HuggingFace Transformers.
"""
from __future__ import annotations
from typing import Optional, Union
import torch
import torch.nn as nn
from torch import Tensor
from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin, AutoConfig, AutoModelForCausalLM
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm
try:
from monoid_scan_cuda import parallel_scan, parallel_scan_with_state
except ImportError:
    # Pure-PyTorch fallback (sequential scan) - works on CPU / MPS / any device.
# Slower than the fused CUDA kernel but numerically identical.
def parallel_scan(log_alpha: Tensor, kv: Tensor) -> Tensor:
"""Sequential prefix scan fallback: S_t = exp(log_ฮฑ_t)ยทS_{t-1} + kv_t."""
B, H, T, d1, d2 = kv.shape
states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype)
S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype)
for t in range(T):
decay = torch.exp(log_alpha[:, :, t]) # [B, H, 1]
while decay.dim() < S.dim():
decay = decay.unsqueeze(-1)
S = S * decay + kv[:, :, t]
states[:, :, t] = S
return states
def parallel_scan_with_state(log_alpha: Tensor, kv: Tensor):
"""Sequential prefix scan that also returns the final (log_decay, S) state."""
B, H, T, d1, d2 = kv.shape
states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype)
S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype)
log_acc = torch.zeros(B, H, 1, device=log_alpha.device, dtype=log_alpha.dtype)
for t in range(T):
decay = torch.exp(log_alpha[:, :, t])
while decay.dim() < S.dim():
decay = decay.unsqueeze(-1)
S = S * decay + kv[:, :, t]
states[:, :, t] = S
log_acc = log_acc + log_alpha[:, :, t]
return states, (log_acc, S)
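# A minimal sketch (not part of the model) verifying that parallel_scan matches
# the closed form it claims to compute,
#     S_t = Σ_{i≤t} (∏_{j=i+1}^{t} α_j) · kv_i,
# on a tiny random input. It assumes the pure-PyTorch fallback above (or a
# kernel that accepts CPU tensors); sizes and the tolerance are illustrative.
def _check_scan_closed_form(atol: float = 1e-5) -> bool:
    B, H, T, d = 1, 2, 5, 3
    log_alpha = torch.log(torch.rand(B, H, T, 1) * 0.5 + 0.4)   # α ∈ (0.4, 0.9)
    kv = torch.randn(B, H, T, d, d)
    states = parallel_scan(log_alpha, kv)
    for t in range(T):
        # Naive O(T²) reference: weight each kv_i by the product of decays applied after step i.
        ref = torch.zeros(B, H, d, d)
        for i in range(t + 1):
            decay = torch.exp(log_alpha[:, :, i + 1:t + 1].sum(dim=2))  # ∏_{j=i+1}^{t} α_j
            ref = ref + decay.unsqueeze(-1) * kv[:, :, i]
        if not torch.allclose(states[:, :, t], ref, atol=atol):
            return False
    return True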
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Config
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
class MonoidConfig(PretrainedConfig):
"""
    Configuration for the Monoid causal language model.
    Mirrors LlamaConfig for the shared components (MLP, RMSNorm, embedding)
    so that weights can be transferred directly from Llama checkpoints.
"""
model_type = "monoid"
def __init__(
self,
vocab_size: int = 32000,
hidden_size: int = 576,
intermediate_size: int = 1536,
num_hidden_layers: int = 30,
num_attention_heads: int = 9,
head_dim: int = 64,
max_position_embeddings: int = 2048,
rms_norm_eps: float = 1e-5,
hidden_act: str = "silu",
mlp_bias: bool = False,
attention_bias: bool = False,
tie_word_embeddings: bool = True,
initializer_range: float = 0.041666666666666664,
        pad_token_id: int | None = None,
bos_token_id: int = 1,
eos_token_id: int = 2,
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.max_position_embeddings = max_position_embeddings
self.rms_norm_eps = rms_norm_eps
self.hidden_act = hidden_act
self.mlp_bias = mlp_bias
self.attention_bias = attention_bias
self.initializer_range = initializer_range
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Monoid Cache - O(1) state replaces the O(T) KV-cache
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
class MonoidCache:
"""
    Per-layer monoid state cache for autoregressive inference.
    Unlike a Transformer KV-cache, which stores all past keys & values (O(T) memory),
    each layer here stores exactly ONE state tuple:
        (log_decay_acc, S)   where S ∈ ℝ^{B, H, d, d}
    This is the monoid "sum" of all past (log_α_i, k_i⊗v_i) under ⊕.
    Memory is O(1) per layer regardless of sequence length.
"""
def __init__(self):
self.states: list[tuple[Tensor, Tensor] | None] = []
self.seen_tokens: int = 0
def get_seq_length(self, layer_idx: int = 0) -> int:
return self.seen_tokens
def update(self, layer_idx: int, state: tuple[Tensor, Tensor]):
"""Store the accumulated monoid state for a given layer.
ๅญ˜ๅ‚จๆŒ‡ๅฎšๅฑ‚็š„็ดฏ็งฏๅนบๅŠ็พค็Šถๆ€ใ€‚"""
while len(self.states) <= layer_idx:
self.states.append(None)
self.states[layer_idx] = state
def get_state(self, layer_idx: int) -> tuple[Tensor, Tensor] | None:
"""Retrieve the accumulated monoid state for a given layer.
่Žทๅ–ๆŒ‡ๅฎšๅฑ‚็š„็ดฏ็งฏๅนบๅŠ็พค็Šถๆ€ใ€‚"""
if layer_idx < len(self.states):
return self.states[layer_idx]
return None
def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorder cache for beam search. ไธบ beam search ้‡ๆŽ’็ผ“ๅญ˜ใ€‚"""
for i, state in enumerate(self.states):
if state is not None:
log_d, kv = state
self.states[i] = (log_d[beam_idx], kv[beam_idx])
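# A minimal usage sketch (illustrative only): the cache holds exactly one
# (log_decay, S) tuple per layer, so its footprint stays fixed no matter how
# many tokens have been folded in. The shapes below are arbitrary examples.
def _demo_monoid_cache() -> None:
    cache = MonoidCache()
    B, H, d = 2, 4, 8
    for _ in range(100):                       # pretend to fold in 100 tokens
        state = (torch.zeros(B, H, 1), torch.randn(B, H, d, d))
        cache.update(layer_idx=0, state=state)
        cache.seen_tokens += 1
    log_decay, S = cache.get_state(0)
    assert S.shape == (B, H, d, d)             # still a single fixed-size state
    cache.reorder_cache(torch.tensor([1, 0]))  # e.g. swap the two beams in beam search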
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Monoid Operator - the algebraic heart
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
def monoid_op(
a: tuple[Tensor, Tensor],
b: tuple[Tensor, Tensor],
) -> tuple[Tensor, Tensor]:
"""
    The monoid binary operator ⊕ on (log-space decay, state matrix) pairs.
    Definition:
        (log_α, S) ⊕ (log_β, X) = (log_α + log_β, exp(log_β)·S + X)
    Why this is a monoid:
        • Associativity:
            (a ⊕ b) ⊕ c = a ⊕ (b ⊕ c)  ✓
            This enables a parallel prefix scan for training (reduction tree)
            and an O(1) left fold for inference (sequential append).
        • Identity:
            e = (0, 0)  →  e ⊕ a = a ⊕ e = a  ✓
    Why log-space:
        Working in log-space for the decay factor avoids numerical
        underflow when α^T → 0 for long sequences.
    Causal semantics:
        S_t = α_t · S_{t-1} + k_t ⊗ v_t
        The decay α_t ∈ (0,1) explicitly controls how much of the past
        the model retains. This is *explicit causal modeling*: the model
        must learn to balance retention vs. novelty at every timestep.
"""
    log_a, kv_a = a
    log_b, kv_b = b
    new_log = log_a + log_b                    # log(α·β) = log_α + log_β
    decay_b = torch.exp(log_b)                 # β = exp(log_β)
    while decay_b.dim() < kv_a.dim():
        decay_b = decay_b.unsqueeze(-1)        # broadcast to [B,H,...,1,1]
    return new_log, kv_a * decay_b + kv_b      # β·S + X
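# A minimal sketch (illustrative, not part of the model): folding the per-step
# elements (log_α_t, k_t ⊗ v_t) left-to-right with monoid_op, starting from the
# identity e = (0, 0), reproduces the prefix-scan states. It assumes the
# CPU-friendly parallel_scan fallback defined above; the sizes are arbitrary.
def _demo_monoid_fold() -> None:
    B, H, T, d = 1, 2, 6, 4
    log_alpha = torch.log(torch.rand(B, H, T, 1) * 0.5 + 0.4)
    kv = torch.randn(B, H, T, d, d)
    states = parallel_scan(log_alpha, kv)                    # [B,H,T,d,d]
    acc = (torch.zeros(B, H, 1), torch.zeros(B, H, d, d))    # identity element e = (0, 0)
    for t in range(T):
        acc = monoid_op(acc, (log_alpha[:, :, t], kv[:, :, t]))
        assert torch.allclose(acc[1], states[:, :, t], atol=1e-5)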
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Monoid Attention - the core innovation
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
class MonoidAttention(nn.Module):
"""
    Monoid Causal Attention - replaces softmax attention entirely.
    Key differences from standard attention:
    ✗ No RoPE / positional encoding - position is implicitly encoded by the
      causal decay gate α_t. The model learns *when* to forget rather than
      encoding *where* tokens are.
    ✗ No KV-cache - replaced by a MonoidCache with an O(1) state per layer.
      Each state S ∈ ℝ^{H×d×d} is a compressed summary of ALL past tokens.
    ✗ No attention mask - causality is built into the recurrence itself:
      S_t only depends on S_{t-1} and the current token by construction.
    Computation:
        Training (parallel scan, O(T)):
            k_t = SiLU(k_proj(x_t))            # non-negative keys for a PSD state
            S_t = α_t · S_{t-1} + k_t ⊗ v_t    # monoid recurrence via prefix scan
            o_t = q_t · S_t                    # linear readout from the state
        Inference (RNN mode, O(1) per token):
            The same recurrence, applied one token at a time.
"""
def __init__(self, config: MonoidConfig, layer_idx: int):
super().__init__()
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.head_dim
        self.scaling = self.head_dim ** -0.5  # 1/√d, scale factor for the q·S readout
        # --- Projections (transferred from Llama) ---
        # q_proj, o_proj: dimensions identical to Llama, direct copy.
        # k_proj, v_proj: Llama GQA has fewer KV heads; we tile them up to the full head count.
self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
        # --- Decay gate (novel component, randomly initialized) ---
        # Projects hidden_size → num_heads, yielding one scalar α per head.
        # After the sigmoid, α_t ∈ (0,1) controls the per-head forgetting rate.
        # This is the key to *explicit causal modeling*: the model learns
        # a content-dependent decay, not a fixed positional bias.
self.decay_proj = nn.Linear(config.hidden_size, self.num_heads, bias=True)
        # --- QK-Norm (novel component, randomly initialized) ---
        # Stabilizes the scale of the q·S readout. Without it, the state
        # matrix S (a sum of outer products) can grow without bound.
self.q_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        # --- Learnable initial state h0 (novel component, zero-initialized) ---
        # S_0 = h0 ∈ ℝ^{1, H, d, d}, shared across the batch.
        # Zero-init means the model starts with "no memory" - a clean slate.
        # The model can learn a non-zero h0 as a kind of "system prompt" state.
self.h0 = nn.Parameter(torch.zeros(1, self.num_heads, self.head_dim, self.head_dim))
def forward(
self,
hidden_states: Tensor,
monoid_cache: MonoidCache | None = None,
use_cache: bool = False,
) -> tuple[Tensor, tuple[Tensor, Tensor] | None]:
"""
Args:
hidden_states: [B, T, hidden_size]
            monoid_cache: O(1) state cache for inference
            use_cache: whether to use/update the cache
Returns:
output: [B, T, hidden_size]
final_state: (log_decay_acc, S) or None
"""
B, T, _ = hidden_states.shape
H, d = self.num_heads, self.head_dim
        # --- Project to multi-head Q, K, V ---
q = self.q_proj(hidden_states).view(B, T, H, d).transpose(1, 2) # [B,H,T,d]
k = self.k_proj(hidden_states).view(B, T, H, d).transpose(1, 2)
v = self.v_proj(hidden_states).view(B, T, H, d).transpose(1, 2)
        # --- QK-Norm: stabilize the q·S readout scale ---
q = self.q_norm(q) * self.scaling
k = self.k_norm(k)
        # --- Non-negative keys via SiLU ---
        # Why: the state S = Σ α^{t-i} k_i⊗v_i is a sum of outer products.
        # Non-negative k ensures S is positive semi-definite (PSD),
        # preventing "feature erasure" where one token's contribution
        # cancels another's. PSD guarantees monotonic information accumulation.
k = torch.nn.functional.silu(k)
        # --- Compute the per-head decay gate α_t ---
        # The sigmoid ensures α ∈ (0,1); we then move to log-space for numerical stability.
alpha = torch.sigmoid(self.decay_proj(hidden_states)) # [B,T,H]
alpha = alpha.transpose(1, 2).unsqueeze(-1) # [B,H,T,1]
log_alpha = torch.log(alpha.clamp(min=1e-6))
        # ══════════════════════════════════════════════════════════
        # Inference path (RNN mode): O(1) per token per layer
        # ══════════════════════════════════════════════════════════
        # When generating, T=1. We apply the monoid operator once
        # to fold the new token into the accumulated state.
        # This is where "O(1) inference" materializes:
        #     S_t = α_t · S_{t-1} + k_t ⊗ v_t   (one monoid_op call)
        #     o_t = q_t · S_t                   (one matmul)
        # Total: O(H·d²) per layer - independent of sequence length.
if use_cache and T == 1:
            # Outer product: k_t ⊗ v_t ∈ ℝ^{H×d×d}
kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, 0], v[:, :, 0])
log_t = log_alpha[:, :, 0] # [B,H,1]
prev = monoid_cache.get_state(self.layer_idx) if monoid_cache else None
if prev is None:
                # First token: initialize from the learnable h0
decay_t = torch.exp(log_t)
while decay_t.dim() < self.h0.dim():
decay_t = decay_t.unsqueeze(-1)
new_state = (log_t, self.h0.expand(B, -1, -1, -1) * decay_t + kv_t)
else:
                # Subsequent tokens: fold in via monoid_op - O(1)!
new_state = monoid_op(prev, (log_t, kv_t))
if monoid_cache is not None:
monoid_cache.update(self.layer_idx, new_state)
            # Readout: o_t = q_t · S_t
o = torch.einsum('bhd, bhde -> bhe', q[:, :, 0], new_state[1])
            # Reshape [B,H,d] → [B,1,H*d] (heads contiguous, matching the scan path)
o = o.contiguous().view(B, 1, -1)
return self.o_proj(o), new_state
        # ══════════════════════════════════════════════════════════
        # Inference prefill (use_cache=True, T>1): fused scan + readout
        # ══════════════════════════════════════════════════════════
        # Avoids materializing the full [B,H,T,d,d] states tensor.
        # Peak memory: O(H·d²) instead of O(T·H·d²).
if use_cache:
S = self.h0.expand(B, -1, -1, -1).clone() # [B,H,d,d]
log_acc = torch.zeros(B, H, 1, device=hidden_states.device, dtype=q.dtype)
o_parts = []
for t in range(T):
kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, t], v[:, :, t])
decay = torch.exp(log_alpha[:, :, t]) # [B,H,1]
while decay.dim() < S.dim():
decay = decay.unsqueeze(-1)
S = S * decay + kv_t
o_parts.append(torch.einsum('bhd, bhde -> bhe', q[:, :, t], S))
log_acc = log_acc + log_alpha[:, :, t]
final_state = (log_acc, S)
if monoid_cache is not None:
monoid_cache.update(self.layer_idx, final_state)
o = torch.stack(o_parts, dim=2) # [B,H,T,d]
o = o.transpose(1, 2).contiguous().view(B, T, -1)
return self.o_proj(o), final_state
        # ══════════════════════════════════════════════════════════
        # Training path: parallel scan + vectorized readout
        # ══════════════════════════════════════════════════════════
        # Materialize the full kv tensor [B,H,T,d,d] and scan it in one pass.
        # Memory: O(B·H·T·d²) - trades memory for speed.
        # Eliminates the T × num_layers Python-loop kernel launches for the
        # outer product and readout; the scan itself runs in parallel when
        # the CUDA kernel is available.
        # Vectorized outer product: kv_t = k_t ⊗ v_t for all t at once
kv = torch.einsum('bhtd, bhte -> bhtde', k, v) # [B,H,T,d,d]
        # Parallel prefix scan: S_t = α_t·S_{t-1} + kv_t (starting from S = 0)
        # Keep log_alpha as [B,H,T,1] - the CUDA kernel's backward pass expects this shape.
states = parallel_scan(log_alpha, kv) # [B,H,T,d,d]
        # Add the h0 contribution: S_t += (∏_{i=0}^{t} α_i) · h0
cum_log_alpha = torch.cumsum(log_alpha, dim=2) # [B,H,T,1]
h0_decay = torch.exp(cum_log_alpha).unsqueeze(-1) # [B,H,T,1,1]
states = states + h0_decay * self.h0.unsqueeze(2) # broadcast h0 [1,H,1,d,d]
        # Vectorized readout: o_t = q_t · S_t for all t at once
o = torch.einsum('bhtd, bhtde -> bhte', q, states) # [B,H,T,d]
o = o.transpose(1, 2).contiguous().view(B, T, -1)
return self.o_proj(o), None
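# A minimal consistency sketch (illustrative only): the parallel path
# (use_cache=False) and the token-by-token RNN path (use_cache=True, T=1)
# implement the same recurrence, so their outputs should agree up to
# floating-point tolerance. It assumes the pure-PyTorch scan fallback and
# arbitrary small config sizes.
def _demo_attention_paths_agree(atol: float = 1e-4) -> bool:
    cfg = MonoidConfig(hidden_size=32, num_attention_heads=2, head_dim=16,
                       intermediate_size=64, num_hidden_layers=1)
    attn = MonoidAttention(cfg, layer_idx=0).eval()
    x = torch.randn(1, 7, cfg.hidden_size)
    with torch.no_grad():
        out_parallel, _ = attn(x)                            # training / parallel path
        cache = MonoidCache()
        steps = [attn(x[:, t:t + 1], monoid_cache=cache, use_cache=True)[0]
                 for t in range(x.shape[1])]                 # RNN path, one token at a time
    return torch.allclose(out_parallel, torch.cat(steps, dim=1), atol=atol)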
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Decoder Layer: MonoidAttention + LlamaMLP + LlamaRMSNorm
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
class MonoidDecoderLayer(nn.Module):
"""
    Pre-Norm Transformer block with Monoid attention.
    Data flow:
        x → RMSNorm → MonoidAttention → +residual → RMSNorm → LlamaMLP → +residual → out
    The MLP and RMSNorm are identical to Llama's (weights are transferred directly;
    a hedged transfer sketch follows this class). MonoidAttention is the only novel component.
"""
gradient_checkpointing = False
def __init__(self, config: MonoidConfig, layer_idx: int):
super().__init__()
self.self_attn = MonoidAttention(config, layer_idx)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: Tensor,
monoid_cache: MonoidCache | None = None,
use_cache: bool = False,
) -> Tensor:
        # --- Attention block with residual ---
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, _ = self.self_attn(hidden_states, monoid_cache=monoid_cache, use_cache=use_cache)
hidden_states = residual + hidden_states
        # --- FFN block with residual ---
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
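# A hedged sketch of the weight transfer described in the docstrings above (the
# actual conversion script is not part of this file): q_proj / o_proj / MLP /
# norms are copied 1:1, while Llama's GQA k_proj / v_proj have fewer KV heads
# and are assumed here to be tiled group-wise up to the full head count.
# `llama_layer` is a hypothetical transformers LlamaDecoderLayer with matching
# hidden_size and head_dim, and bias-free projections are assumed.
def _transfer_llama_layer_weights(monoid_layer: MonoidDecoderLayer, llama_layer, num_kv_heads: int) -> None:
    attn_m, attn_l = monoid_layer.self_attn, llama_layer.self_attn
    n_rep = attn_m.num_heads // num_kv_heads
    hd = attn_m.head_dim
    with torch.no_grad():
        # 1:1 copies: q/o projections, MLP, and both RMSNorms.
        attn_m.q_proj.weight.copy_(attn_l.q_proj.weight)
        attn_m.o_proj.weight.copy_(attn_l.o_proj.weight)
        monoid_layer.mlp.load_state_dict(llama_layer.mlp.state_dict())
        monoid_layer.input_layernorm.load_state_dict(llama_layer.input_layernorm.state_dict())
        monoid_layer.post_attention_layernorm.load_state_dict(llama_layer.post_attention_layernorm.state_dict())
        # GQA k/v: [num_kv_heads*hd, hidden] → repeat each KV head n_rep times → [num_heads*hd, hidden].
        for dst, src in ((attn_m.k_proj, attn_l.k_proj), (attn_m.v_proj, attn_l.v_proj)):
            w = src.weight.view(num_kv_heads, hd, -1).repeat_interleave(n_rep, dim=0)
            dst.weight.copy_(w.reshape(dst.weight.shape))
    # decay_proj, q_norm/k_norm and h0 have no Llama counterpart and keep their fresh init.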
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# MonoidModel (backbone)
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
class MonoidPreTrainedModel(PreTrainedModel):
config_class = MonoidConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["MonoidDecoderLayer"]
def _init_weights(self, module: nn.Module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
        if isinstance(module, MonoidAttention):
            # Bias the decay gate so that sigmoid(4.0) ≈ 0.98 at init: the model
            # starts with slow forgetting (long memory) and learns to decay
            # faster where useful.
            nn.init.constant_(module.decay_proj.bias, 4.0)
class MonoidModel(MonoidPreTrainedModel):
"""
    Stack of MonoidDecoderLayers with token embedding and a final norm.
    Forward: embed_tokens → N × MonoidDecoderLayer → final_norm
"""
def __init__(self, config: MonoidConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[MonoidDecoderLayer(config, i) for i in range(config.num_hidden_layers)]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
self.post_init()
def forward(
self,
input_ids: Tensor | None = None,
inputs_embeds: Tensor | None = None,
monoid_cache: MonoidCache | None = None,
use_cache: bool = False,
) -> BaseModelOutputWithPast:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
for layer in self.layers:
if self.gradient_checkpointing and self.training and not use_cache:
hidden_states = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
monoid_cache,
use_cache,
)
else:
hidden_states = layer(hidden_states, monoid_cache=monoid_cache, use_cache=use_cache)
hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=monoid_cache,
)
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# MonoidForCausalLM - the full causal LM
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
class MonoidForCausalLM(MonoidPreTrainedModel, GenerationMixin):
"""
    Monoid-based causal language model with an LM head.
    The architecture in one sentence:
        "Llama body + Monoid mind" - reuse Llama's proven MLP/embeddings and
        replace attention with monoid state compression for O(1) inference.
"""
_tied_weights_keys = ["lm_head.weight"]
    # Tell the HuggingFace GenerationMixin NOT to create a DynamicCache.
    # Monoid uses its own O(1) MonoidCache, not a KV-cache.
_is_stateful = True
def __init__(self, config: MonoidConfig):
super().__init__(config)
self.model = MonoidModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(
self,
input_ids: Tensor,
past_key_values=None,
attention_mask: Tensor | None = None,
inputs_embeds: Tensor | None = None,
**kwargs,
) -> dict:
"""
        Called by GenerationMixin at each decoding step.
        HuggingFace may pass a DynamicCache; we intercept it and replace it
        with a MonoidCache, since we do not use a standard KV-cache.
"""
        # Intercept non-MonoidCache objects (e.g. a DynamicCache from GenerationMixin)
if past_key_values is not None and not isinstance(past_key_values, MonoidCache):
past_key_values = None
if past_key_values is not None and past_key_values.seen_tokens > 0:
            # Cache exists → only feed the latest token (O(1) inference)
input_ids = input_ids[:, -1:]
model_inputs = {
"input_ids": input_ids,
"monoid_cache": past_key_values,
"use_cache": True,
}
return model_inputs
def forward(
self,
input_ids: Tensor | None = None,
        attention_mask: Tensor | None = None,   # kept for API compat; the monoid path ignores it
        position_ids: Tensor | None = None,     # kept for API compat; the monoid path ignores it
past_key_values: MonoidCache | None = None,
inputs_embeds: Tensor | None = None,
labels: Tensor | None = None,
use_cache: bool | None = None,
monoid_cache: MonoidCache | None = None,
output_attentions: bool | None = None, # kept for API compat
output_hidden_states: bool | None = None, # kept for API compat
logits_to_keep: int | Tensor = 0,
**kwargs,
) -> CausalLMOutputWithPast:
        # monoid_cache takes priority; fall back to past_key_values for GenerationMixin compatibility
cache = monoid_cache or past_key_values
        # Discard any non-MonoidCache (e.g. a DynamicCache injected by GenerationMixin)
if cache is not None and not isinstance(cache, MonoidCache):
cache = None
if use_cache and cache is None:
cache = MonoidCache()
outputs = self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
monoid_cache=cache,
use_cache=bool(use_cache),
)
hidden_states = outputs.last_hidden_state
        # Optionally compute logits only for the last K tokens (saves memory)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) and logits_to_keep > 0 else slice(None)
logits = self.lm_head(hidden_states[:, slice_indices, :])
        # Standard causal LM loss: shifted cross-entropy
loss = None
if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = nn.functional.cross_entropy(
shift_logits.view(-1, self.vocab_size),
shift_labels.view(-1),
ignore_index=-100,
)
if cache is not None:
cache.seen_tokens += (input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1])
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=cache,
)
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# AutoModel Registration
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
AutoConfig.register("monoid", MonoidConfig)
AutoModelForCausalLM.register(MonoidConfig, MonoidForCausalLM)
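# A minimal sketch of what the registration above enables (illustrative only):
# the Auto classes can now resolve the "monoid" model_type without referencing
# MonoidConfig / MonoidForCausalLM directly. The tiny sizes are arbitrary.
def _demo_auto_registration() -> None:
    cfg = AutoConfig.for_model("monoid", num_hidden_layers=2, hidden_size=64,
                               intermediate_size=128, num_attention_heads=2, head_dim=32)
    model = AutoModelForCausalLM.from_config(cfg)
    assert isinstance(cfg, MonoidConfig) and isinstance(model, MonoidForCausalLM)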
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Smoke Tests
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
if __name__ == '__main__':
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(f'Device: {device}')
config = MonoidConfig(
vocab_size=49152,
hidden_size=576,
intermediate_size=1536,
num_hidden_layers=30,
num_attention_heads=9,
head_dim=64,
rms_norm_eps=1e-5,
hidden_act="silu",
tie_word_embeddings=True,
)
model = MonoidForCausalLM(config).to(device)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Parameters: {n_params:,}')
    # -- Training smoke test --
B, T = 2, 64
ids = torch.randint(0, config.vocab_size, (B, T), device=device)
out = model(ids, labels=ids)
    print(f'Train - logits: {out.logits.shape}, loss: {out.loss:.4f}')
    # -- Inference smoke test (manual RNN loop) --
prompt = torch.randint(0, config.vocab_size, (1, 8), device=device)
cache = MonoidCache()
    # Prefill
prefill_out = model(prompt, use_cache=True, monoid_cache=cache)
    print(f'Prefill - logits: {prefill_out.logits.shape}, cache seen: {cache.seen_tokens}')
    # Decode one token
next_tok = prefill_out.logits[:, -1:].argmax(dim=-1)
step_out = model(next_tok, use_cache=True, monoid_cache=cache)
    print(f'Decode - logits: {step_out.logits.shape}, cache seen: {cache.seen_tokens}')
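    # -- generate() smoke test --
    # Illustrative: relies on GenerationMixin routing through
    # prepare_inputs_for_generation as defined above; exact behaviour may vary
    # across transformers versions.
    gen_ids = model.generate(prompt, max_new_tokens=4, do_sample=False)
    print(f'Generate - output ids: {gen_ids.shape}')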
    # -- Monoid associativity check --
    print('\nMonoid associativity check:')
a = (torch.randn(1, 1, 1), torch.randn(1, 1, 4, 4))
b = (torch.randn(1, 1, 1), torch.randn(1, 1, 4, 4))
c = (torch.randn(1, 1, 1), torch.randn(1, 1, 4, 4))
ab_c = monoid_op(monoid_op(a, b), c)
a_bc = monoid_op(a, monoid_op(b, c))
err = (ab_c[1] - a_bc[1]).abs().max().item()
    print(f'  |(a⊕b)⊕c - a⊕(b⊕c)| = {err:.2e}')
print('\nDone.')