""" MonoidForCausalLM — Causal Monoid Language Model (HuggingFace Compatible) MonoidForCausalLM — 幺半群因果语言模型 (兼容 HuggingFace) Architecture / 架构概要: Replace softmax attention with a monoid parallel-scan recurrence. 用幺半群并行扫描递推替代 softmax 注意力。 Core idea / 核心思想: Softmax attention computes o_t = Σ_{i≤t} softmax(q_t·k_i) v_i — requires O(T) KV-cache per layer at inference. Softmax 注意力计算 o_t = Σ_{i≤t} softmax(q_t·k_i) v_i — 推理时每层需要 O(T) 的 KV 缓存。 Monoid attention compresses the entire causal history into a fixed-size state matrix S_t ∈ ℝ^{d×d} per head: S_t = α_t · S_{t-1} + k_t ⊗ v_t (explicit causal recurrence) o_t = q_t · S_t (state readout) 幺半群注意力将完整因果历史压缩到每个头一个固定大小的状态矩阵 S_t: S_t = α_t · S_{t-1} + k_t ⊗ v_t (显式因果递推) o_t = q_t · S_t (状态读出) This is a monoid because the binary operator: (log_α, S) ⊕ (log_β, X) = (log_α + log_β, exp(log_β)·S + X) is associative → enables parallel prefix scan for training, and O(1) sequential update for inference. 这是一个幺半群,因为二元算子: (log_α, S) ⊕ (log_β, X) = (log_α + log_β, exp(log_β)·S + X) 满足结合律 → 训练时可用并行前缀扫描,推理时 O(1) 逐步递推。 Key properties / 关键特性: ✓ Explicit causal modeling — α_t gate explicitly controls how fast past information decays, making causality a first-class citizen. 显式因果建模 — α_t 衰减门显式控制历史信息的遗忘速率, 因果性是一等公民而非靠 mask 施加的约束。 ✓ Monoid state compression — the full causal prefix x_{1:t} is lossily compressed into a fixed-size (d×d) state matrix per head. No O(T) KV-cache needed; inference is O(1) per token per layer. 幺半群状态压缩 — 完整因果前缀 x_{1:t} 被有损压缩到每个头 固定大小的 (d×d) 状态矩阵中。无需 O(T) KV 缓存; 推理时每层每 token O(1)。 ✓ Parallel training — associativity of ⊕ enables O(T) parallel prefix scan (vs O(T²) for softmax attention). 并行训练 — ⊕ 的结合律使 O(T) 并行前缀扫描成为可能 (对比 softmax 注意力的 O(T²))。 Reuses LlamaMLP + LlamaRMSNorm from HuggingFace Transformers. 复用 HuggingFace Transformers 的 LlamaMLP + LlamaRMSNorm。 """ from __future__ import annotations from typing import Optional, Union import torch import torch.nn as nn from torch import Tensor from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin, AutoConfig, AutoModelForCausalLM from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm try: from monoid_scan_cuda import parallel_scan, parallel_scan_with_state except ImportError: # Pure-PyTorch fallback (sequential scan) — works on CPU / MPS / any device. # Slower than the fused CUDA kernel but numerically identical. 
def parallel_scan(log_alpha: Tensor, kv: Tensor) -> Tensor: """Sequential prefix scan fallback: S_t = exp(log_α_t)·S_{t-1} + kv_t.""" B, H, T, d1, d2 = kv.shape states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype) S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype) for t in range(T): decay = torch.exp(log_alpha[:, :, t]) # [B, H, 1] while decay.dim() < S.dim(): decay = decay.unsqueeze(-1) S = S * decay + kv[:, :, t] states[:, :, t] = S return states def parallel_scan_with_state(log_alpha: Tensor, kv: Tensor): """Sequential prefix scan that also returns the final (log_decay, S) state.""" B, H, T, d1, d2 = kv.shape states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype) S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype) log_acc = torch.zeros(B, H, 1, device=log_alpha.device, dtype=log_alpha.dtype) for t in range(T): decay = torch.exp(log_alpha[:, :, t]) while decay.dim() < S.dim(): decay = decay.unsqueeze(-1) S = S * decay + kv[:, :, t] states[:, :, t] = S log_acc = log_acc + log_alpha[:, :, t] return states, (log_acc, S) # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ # Config / 配置 # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ class MonoidConfig(PretrainedConfig): """ Configuration for the Monoid causal language model. 幺半群因果语言模型的配置。 Mirrors LlamaConfig for the shared components (MLP, RMSNorm, embedding) so that weights can be directly transferred from Llama checkpoints. 与 LlamaConfig 的共享组件 (MLP, RMSNorm, embedding) 保持一致, 以便从 Llama 检查点直接迁移权重。 """ model_type = "monoid" def __init__( self, vocab_size: int = 32000, hidden_size: int = 576, intermediate_size: int = 1536, num_hidden_layers: int = 30, num_attention_heads: int = 9, head_dim: int = 64, max_position_embeddings: int = 2048, rms_norm_eps: float = 1e-5, hidden_act: str = "silu", mlp_bias: bool = False, attention_bias: bool = False, tie_word_embeddings: bool = True, initializer_range: float = 0.041666666666666664, pad_token_id: int = None, bos_token_id: int = 1, eos_token_id: int = 2, **kwargs, ): super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs, ) self.vocab_size = vocab_size self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.head_dim = head_dim self.max_position_embeddings = max_position_embeddings self.rms_norm_eps = rms_norm_eps self.hidden_act = hidden_act self.mlp_bias = mlp_bias self.attention_bias = attention_bias self.initializer_range = initializer_range # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ # Monoid Cache — O(1) state replaces O(T) KV-Cache # 幺半群缓存 — O(1) 状态替代 O(T) KV 缓存 # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ class MonoidCache: """ Per-layer monoid state cache for autoregressive inference. 自回归推理的逐层幺半群状态缓存。 Unlike Transformer KV-Cache that stores all past keys & values (O(T) memory), each layer here stores exactly ONE state tuple: (log_decay_acc, S) where S ∈ ℝ^{B, H, d, d} This is the monoid "sum" of all past (log_α_i, k_i⊗v_i) via ⊕. Memory is O(1) per layer regardless of sequence length. 
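    Illustrative sketch (placeholder names and shapes, mirroring how this file uses the cache):

        cache = MonoidCache()
        cache.update(layer_idx=0, state=(log_decay_acc, S))  # log_decay_acc: [B,H,1], S: [B,H,d,d]
        log_decay_acc, S = cache.get_state(0)                # O(1) lookup, independent of seen_tokens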
不同于 Transformer 的 KV-Cache (存储所有过去的 key 和 value, O(T) 内存), 这里每层仅存储一个状态元组: (log_decay_acc, S) 其中 S ∈ ℝ^{B, H, d, d} 这是所有过去的 (log_α_i, k_i⊗v_i) 通过 ⊕ 累积的幺半群 "和"。 无论序列多长,每层内存 O(1)。 """ def __init__(self): self.states: list[tuple[Tensor, Tensor] | None] = [] self.seen_tokens: int = 0 def get_seq_length(self, layer_idx: int = 0) -> int: return self.seen_tokens def update(self, layer_idx: int, state: tuple[Tensor, Tensor]): """Store the accumulated monoid state for a given layer. 存储指定层的累积幺半群状态。""" while len(self.states) <= layer_idx: self.states.append(None) self.states[layer_idx] = state def get_state(self, layer_idx: int) -> tuple[Tensor, Tensor] | None: """Retrieve the accumulated monoid state for a given layer. 获取指定层的累积幺半群状态。""" if layer_idx < len(self.states): return self.states[layer_idx] return None def reorder_cache(self, beam_idx: torch.LongTensor): """Reorder cache for beam search. 为 beam search 重排缓存。""" for i, state in enumerate(self.states): if state is not None: log_d, kv = state self.states[i] = (log_d[beam_idx], kv[beam_idx]) # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ # Monoid Operator — the algebraic heart # 幺半群算子 — 代数核心 # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ def monoid_op( a: tuple[Tensor, Tensor], b: tuple[Tensor, Tensor], ) -> tuple[Tensor, Tensor]: """ The monoid binary operator ⊕ on (log-space decay, state matrix) pairs. 幺半群二元算子 ⊕,作用于 (对数衰减, 状态矩阵) 对。 Definition / 定义: (log_α, S) ⊕ (log_β, X) = (log_α + log_β, exp(log_β)·S + X) Why this is a monoid / 为什么这是幺半群: • Associativity / 结合律: (a ⊕ b) ⊕ c = a ⊕ (b ⊕ c) ✓ This enables parallel prefix scan for training (reduce tree) and O(1) left-fold for inference (sequential append). 结合律使训练时可以用并行前缀扫描 (归约树), 推理时可以 O(1) 左折叠 (逐步追加)。 • Identity / 单位元: e = (0, 0) → e ⊕ a = a ⊕ e = a ✓ Why log-space / 为什么用对数空间: Working in log-space for the decay factor avoids numerical underflow when α^T → 0 for long sequences. 衰减因子在对数空间中运算,避免长序列下 α^T → 0 的数值下溢。 Causal semantics / 因果语义: S_t = α_t · S_{t-1} + k_t ⊗ v_t The decay α_t ∈ (0,1) explicitly controls how much of the past the model retains. This is *explicit causal modeling* — the model must learn to balance retention vs novelty at every timestep. 衰减 α_t ∈ (0,1) 显式控制模型保留多少过去信息。 这就是 *显式因果建模* — 模型必须在每个时间步学习如何 平衡保留旧信息与吸收新信息。 """ log_a, kv_a = a log_b, kv_b = b new_log = log_a + log_b # log(α·β) = log_α + log_β decay_b = torch.exp(log_b) # β = exp(log_β) while decay_b.dim() < kv_a.dim(): decay_b = decay_b.unsqueeze(-1) # broadcast to [B,H,...,1,1] return new_log, kv_a * decay_b + kv_b # β·S + X # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ # Monoid Attention — the core innovation # 幺半群注意力 — 核心创新层 # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ class MonoidAttention(nn.Module): """ Monoid Causal Attention — replaces softmax attention entirely. 幺半群因果注意力 — 完全替代 softmax 注意力。 Key differences from standard attention / 与标准注意力的关键区别: ✗ No RoPE / positional encoding — position is implicitly encoded by the causal decay gate α_t. The model learns *when* to forget rather than encoding *where* tokens are. 不使用 RoPE / 位置编码 — 位置信息由因果衰减门 α_t 隐式编码。 模型学习 *何时遗忘* 而非编码 token *在哪里*。 ✗ No KV-Cache — replaced by MonoidCache with O(1) state per layer. Each state S ∈ ℝ^{H×d×d} is a compressed summary of ALL past tokens. 不使用 KV 缓存 — 由 O(1) 的 MonoidCache 状态替代。 每个状态 S ∈ ℝ^{H×d×d} 是所有过去 token 的压缩摘要。 ✗ No attention mask — causality is built into the recurrence itself. S_t only depends on S_{t-1} and the current token by construction. 
不使用注意力掩码 — 因果性内建于递推结构本身。 S_t 仅依赖 S_{t-1} 和当前 token,结构上保证因果性。 Computation / 计算: Training (parallel scan, O(T)): k_t = SiLU(k_proj(x_t)) # non-negative keys for PSD state S_t = α_t · S_{t-1} + k_t ⊗ v_t # monoid recurrence via prefix scan o_t = q_t · S_t # linear readout from state Inference (RNN mode, O(1) per token): Same recurrence, but applied one token at a time. 训练 (并行扫描, O(T)): k_t = SiLU(k_proj(x_t)) # 非负 key 保证状态矩阵半正定 S_t = α_t · S_{t-1} + k_t ⊗ v_t # 通过前缀扫描实现幺半群递推 o_t = q_t · S_t # 从状态中线性读出 推理 (RNN 模式, 每 token O(1)): 同一递推公式, 但逐 token 顺序应用。 """ def __init__(self, config: MonoidConfig, layer_idx: int): super().__init__() self.layer_idx = layer_idx self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = config.head_dim self.scaling = self.head_dim ** -0.5 # 1/√d, scale factor for q·S readout # q·S 读出的缩放因子 # --- Projections (transferred from Llama) --- # --- 投影层 (从 Llama 迁移) --- # q_proj, o_proj: identical dims to Llama, direct copy # k_proj, v_proj: Llama GQA has fewer KV heads; we tile to full heads # q_proj, o_proj: 维度与 Llama 一致, 直接复制 # k_proj, v_proj: Llama GQA 的 KV 头更少; 我们重复到全头数 self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) self.k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) self.v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias) # --- Decay gate (novel component, randomly initialized) --- # --- 衰减门 (全新组件, 随机初始化) --- # Projects hidden_size → num_heads, yielding one scalar α per head. # After sigmoid: α_t ∈ (0,1) controls per-head forgetting rate. # This is the key to *explicit causal modeling*: the model learns # a content-dependent decay, not a fixed positional bias. # 将 hidden_size 投影到 num_heads, 每个头产生一个标量 α。 # 经过 sigmoid 后: α_t ∈ (0,1) 控制每个头的遗忘速率。 # 这是 *显式因果建模* 的关键: 模型学习的是内容相关的衰减, # 而非固定的位置偏置。 self.decay_proj = nn.Linear(config.hidden_size, self.num_heads, bias=True) # --- QK-Norm (novel component, randomly initialized) --- # --- QK 归一化 (全新组件, 随机初始化) --- # Stabilizes the scale of q·S readout. Without this, the state # matrix S (sum of outer products) can grow unboundedly. # 稳定 q·S 读出的尺度。没有这个, 状态矩阵 S (外积之和) # 可能无界增长。 self.q_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps) # --- Learnable initial state h0 (novel component, zero-initialized) --- # --- 可学习初始状态 h0 (全新组件, 零初始化) --- # S_0 = h0 ∈ ℝ^{1, H, d, d}, shared across batch. # Zero-init means the model starts with "no memory" — a clean slate. # The model can learn a non-zero h0 as a kind of "system prompt" state. 
# S_0 = h0 ∈ ℝ^{1, H, d, d}, 跨 batch 共享。 # 零初始化意味着模型从"无记忆"开始 — 一张白纸。 # 模型可以学习非零的 h0 作为一种"系统提示"状态。 self.h0 = nn.Parameter(torch.zeros(1, self.num_heads, self.head_dim, self.head_dim)) def forward( self, hidden_states: Tensor, monoid_cache: MonoidCache | None = None, use_cache: bool = False, ) -> tuple[Tensor, tuple[Tensor, Tensor] | None]: """ Args: hidden_states: [B, T, hidden_size] monoid_cache: O(1) state cache for inference 推理用 O(1) 状态缓存 use_cache: whether to use/update the cache 是否使用/更新缓存 Returns: output: [B, T, hidden_size] final_state: (log_decay_acc, S) or None """ B, T, _ = hidden_states.shape H, d = self.num_heads, self.head_dim # --- Project to multi-head Q, K, V --- # --- 投影到多头 Q, K, V --- q = self.q_proj(hidden_states).view(B, T, H, d).transpose(1, 2) # [B,H,T,d] k = self.k_proj(hidden_states).view(B, T, H, d).transpose(1, 2) v = self.v_proj(hidden_states).view(B, T, H, d).transpose(1, 2) # --- QK-Norm: stabilize q·S readout scale --- # --- QK 归一化: 稳定 q·S 读出尺度 --- q = self.q_norm(q) * self.scaling k = self.k_norm(k) # --- Non-negative keys via SiLU --- # --- 通过 SiLU 保证 key 非负 --- # Why: the state S = Σ α^{t-i} k_i⊗v_i is a sum of outer products. # Non-negative k ensures S is positive semi-definite (PSD), # preventing "feature erasure" where one token's contribution # cancels another's. PSD guarantees monotonic information accumulation. # 原因: 状态 S = Σ α^{t-i} k_i⊗v_i 是外积之和。 # 非负的 k 保证 S 半正定 (PSD), 防止一个 token 的贡献 # 抵消另一个 token 的"特征擦除"现象。 # PSD 保证信息单调积累。 k = torch.nn.functional.silu(k) # --- Compute per-head decay gate α_t --- # --- 计算每头衰减门 α_t --- # sigmoid ensures α ∈ (0,1), then log-space for numerical stability. # sigmoid 保证 α ∈ (0,1), 然后转到对数空间保证数值稳定性。 alpha = torch.sigmoid(self.decay_proj(hidden_states)) # [B,T,H] alpha = alpha.transpose(1, 2).unsqueeze(-1) # [B,H,T,1] log_alpha = torch.log(alpha.clamp(min=1e-6)) # ══════════════════════════════════════════════════════════ # Inference path (RNN mode): O(1) per token per layer # 推理路径 (RNN 模式): 每层每 token O(1) # ══════════════════════════════════════════════════════════ # When generating, T=1. We apply the monoid operator once # to fold the new token into the accumulated state. # This is where "O(1) inference" materializes: # S_t = α_t · S_{t-1} + k_t ⊗ v_t (one monoid_op call) # o_t = q_t · S_t (one matmul) # Total: O(H·d²) per layer — independent of sequence length. # # 生成时 T=1。我们调用一次幺半群算子将新 token 折叠进累积状态。 # 这就是 "O(1) 推理" 的具体体现: # S_t = α_t · S_{t-1} + k_t ⊗ v_t (一次 monoid_op) # o_t = q_t · S_t (一次矩阵乘法) # 总计: 每层 O(H·d²) — 与序列长度无关。 if use_cache and T == 1: # Outer product: k_t ⊗ v_t ∈ ℝ^{H×d×d} # 外积: k_t ⊗ v_t ∈ ℝ^{H×d×d} kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, 0], v[:, :, 0]) log_t = log_alpha[:, :, 0] # [B,H,1] prev = monoid_cache.get_state(self.layer_idx) if monoid_cache else None if prev is None: # First token: initialize from learnable h0 # 第一个 token: 从可学习的 h0 初始化 decay_t = torch.exp(log_t) while decay_t.dim() < self.h0.dim(): decay_t = decay_t.unsqueeze(-1) new_state = (log_t, self.h0.expand(B, -1, -1, -1) * decay_t + kv_t) else: # Subsequent tokens: fold via monoid_op — O(1)! # 后续 token: 通过 monoid_op 折叠 — O(1)! 
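                # Algebraic sketch of this fold (writing prev = (log_p, S_p); names illustrative):
                #   monoid_op(prev, (log_t, kv_t)) == (log_p + log_t, exp(log_t) * S_p + kv_t)
                # i.e. exactly S_t = α_t · S_{t-1} + k_t ⊗ v_t, with the decay tracked in log-space.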
new_state = monoid_op(prev, (log_t, kv_t)) if monoid_cache is not None: monoid_cache.update(self.layer_idx, new_state) # Readout: o_t = q_t · S_t # 读出: o_t = q_t · S_t o = torch.einsum('bhd, bhde -> bhe', q[:, :, 0], new_state[1]) # Reshape [B,H,d] → [B,1,H*d] (heads contiguous, matching scan path) # 重塑 [B,H,d] → [B,1,H*d] (头连续排列, 与扫描路径一致) o = o.contiguous().view(B, 1, -1) return self.o_proj(o), new_state # ══════════════════════════════════════════════════════════ # Inference prefill (use_cache=True, T>1): fused scan + readout # 推理预填充 (use_cache=True, T>1): 融合扫描 + 读出 # ══════════════════════════════════════════════════════════ # Avoids materializing full [B,H,T,d,d] states tensor. # Peak memory: O(H·d²) instead of O(T·H·d²). # 避免实体化完整的 [B,H,T,d,d] 状态张量。 # 峰值内存: O(H·d²) 而非 O(T·H·d²)。 if use_cache: S = self.h0.expand(B, -1, -1, -1).clone() # [B,H,d,d] log_acc = torch.zeros(B, H, 1, device=hidden_states.device, dtype=q.dtype) o_parts = [] for t in range(T): kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, t], v[:, :, t]) decay = torch.exp(log_alpha[:, :, t]) # [B,H,1] while decay.dim() < S.dim(): decay = decay.unsqueeze(-1) S = S * decay + kv_t o_parts.append(torch.einsum('bhd, bhde -> bhe', q[:, :, t], S)) log_acc = log_acc + log_alpha[:, :, t] final_state = (log_acc, S) if monoid_cache is not None: monoid_cache.update(self.layer_idx, final_state) o = torch.stack(o_parts, dim=2) # [B,H,T,d] o = o.transpose(1, 2).contiguous().view(B, T, -1) return self.o_proj(o), final_state # ══════════════════════════════════════════════════════════ # Training path: parallel scan + vectorized readout # 训练路径: 并行扫描 + 向量化读出 # ══════════════════════════════════════════════════════════ # Materialize full kv tensor [B,H,T,d,d] and scan in one pass. # Memory: O(B·H·T·d²) — trades memory for speed. # Eliminates T×30 Python-loop kernel launches for outer product # and readout; scan itself is parallel when CUDA kernel available. # # 物化完整 kv 张量 [B,H,T,d,d] 并一次性扫描。 # 内存: O(B·H·T·d²) — 以内存换速度。 # 消除外积和读出的 T×30 次 Python 循环 kernel launch; # 当 CUDA kernel 可用时扫描本身也是并行的。 # Vectorized outer product: kv_t = k_t ⊗ v_t for all t at once # 向量化外积: 一次性计算所有 t 的 k_t ⊗ v_t kv = torch.einsum('bhtd, bhte -> bhtde', k, v) # [B,H,T,d,d] # Parallel prefix scan: S_t = α_t·S_{t-1} + kv_t (from S=0) # 并行前缀扫描: S_t = α_t·S_{t-1} + kv_t (从 S=0 开始) # Keep log_alpha as [B,H,T,1] — CUDA kernel backward expects this shape. # 保持 log_alpha 为 [B,H,T,1] — CUDA kernel 反向传播需要此形状。 states = parallel_scan(log_alpha, kv) # [B,H,T,d,d] # Add h0 contribution: S_t += (∏_{i=0}^{t} α_i) · h0 # 叠加 h0 贡献: S_t += (∏_{i=0}^{t} α_i) · h0 cum_log_alpha = torch.cumsum(log_alpha, dim=2) # [B,H,T,1] h0_decay = torch.exp(cum_log_alpha).unsqueeze(-1) # [B,H,T,1,1] states = states + h0_decay * self.h0.unsqueeze(2) # broadcast h0 [1,H,1,d,d] # Vectorized readout: o_t = q_t · S_t for all t at once # 向量化读出: 一次性计算所有 t 的 q_t · S_t o = torch.einsum('bhtd, bhtde -> bhte', q, states) # [B,H,T,d] o = o.transpose(1, 2).contiguous().view(B, T, -1) return self.o_proj(o), None # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ # Decoder Layer: MonoidAttn + LlamaMLP + LlamaRMSNorm # 解码层: 幺半群注意力 + LlamaMLP + LlamaRMSNorm # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ class MonoidDecoderLayer(nn.Module): """ Pre-Norm Transformer block with Monoid attention. 使用幺半群注意力的 Pre-Norm Transformer 块。 Data flow / 数据流: x → RMSNorm → MonoidAttn → +residual → RMSNorm → LlamaMLP → +residual → out The MLP and RMSNorm are identical to Llama (weights transferred directly). 
Only MonoidAttention is the novel component. MLP 和 RMSNorm 与 Llama 完全相同 (权重直接迁移)。 仅 MonoidAttention 是全新组件。 """ gradient_checkpointing = False def __init__(self, config: MonoidConfig, layer_idx: int): super().__init__() self.self_attn = MonoidAttention(config, layer_idx) self.mlp = LlamaMLP(config) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: Tensor, monoid_cache: MonoidCache | None = None, use_cache: bool = False, ) -> Tensor: # --- Attention block with residual --- # --- 注意力块 + 残差连接 --- residual = hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states, _ = self.self_attn(hidden_states, monoid_cache=monoid_cache, use_cache=use_cache) hidden_states = residual + hidden_states # --- FFN block with residual --- # --- 前馈网络块 + 残差连接 --- residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ # MonoidModel (backbone) # MonoidModel (骨干网络) # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ class MonoidPreTrainedModel(PreTrainedModel): config_class = MonoidConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["MonoidDecoderLayer"] def _init_weights(self, module: nn.Module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() if isinstance(module, MonoidAttention): nn.init.constant_(module.decay_proj.bias, 4.0) class MonoidModel(MonoidPreTrainedModel): """ Stack of MonoidDecoderLayers with token embedding and final norm. 
幺半群解码层堆叠, 带 token 嵌入和最终归一化。 Forward: embed_tokens → N × MonoidDecoderLayer → final_norm 前向: embed_tokens → N × MonoidDecoderLayer → final_norm """ def __init__(self, config: MonoidConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( [MonoidDecoderLayer(config, i) for i in range(config.num_hidden_layers)] ) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False self.post_init() def forward( self, input_ids: Tensor | None = None, inputs_embeds: Tensor | None = None, monoid_cache: MonoidCache | None = None, use_cache: bool = False, ) -> BaseModelOutputWithPast: if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) hidden_states = inputs_embeds for layer in self.layers: if self.gradient_checkpointing and self.training and not use_cache: hidden_states = self._gradient_checkpointing_func( layer.__call__, hidden_states, monoid_cache, use_cache, ) else: hidden_states = layer(hidden_states, monoid_cache=monoid_cache, use_cache=use_cache) hidden_states = self.norm(hidden_states) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=monoid_cache, ) # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ # MonoidForCausalLM — the full causal LM # MonoidForCausalLM — 完整因果语言模型 # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ class MonoidForCausalLM(MonoidPreTrainedModel, GenerationMixin): """ Monoid-based causal language model with LM head. 基于幺半群的因果语言模型, 带语言模型头。 The architecture in one sentence: "Llama body + Monoid mind" — reuse Llama's proven MLP/embeddings, replace attention with monoid state compression for O(1) inference. 一句话概括架构: "Llama 的身体 + 幺半群的思维" — 复用 Llama 成熟的 MLP/嵌入层, 用幺半群状态压缩替换注意力, 实现 O(1) 推理。 """ _tied_weights_keys = ["lm_head.weight"] # Tell HuggingFace GenerationMixin NOT to create DynamicCache. # Monoid uses its own O(1) MonoidCache, not KV-Cache. # 告诉 HuggingFace 不要创建 DynamicCache。 # Monoid 使用自己的 O(1) MonoidCache, 不是 KV 缓存。 _is_stateful = True def __init__(self, config: MonoidConfig): super().__init__(config) self.model = MonoidModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def prepare_inputs_for_generation( self, input_ids: Tensor, past_key_values=None, attention_mask: Tensor | None = None, inputs_embeds: Tensor | None = None, **kwargs, ) -> dict: """ Called by GenerationMixin at each decoding step. GenerationMixin 在每个解码步调用此方法。 HuggingFace may pass a DynamicCache; we intercept and replace it with MonoidCache since we don't use standard KV-cache. HuggingFace 可能传入 DynamicCache; 我们拦截并替换为 MonoidCache, 因为我们不使用标准 KV 缓存。 """ # Intercept non-MonoidCache objects (e.g. 
DynamicCache from GenerationMixin) # 拦截非 MonoidCache 对象 (如 GenerationMixin 创建的 DynamicCache) if past_key_values is not None and not isinstance(past_key_values, MonoidCache): past_key_values = None if past_key_values is not None and past_key_values.seen_tokens > 0: # Cache exists → only feed the latest token (O(1) inference) # 缓存已存在 → 只需输入最新的 token (O(1) 推理) input_ids = input_ids[:, -1:] model_inputs = { "input_ids": input_ids, "monoid_cache": past_key_values, "use_cache": True, } return model_inputs def forward( self, input_ids: Tensor | None = None, attention_mask: Tensor | None = None, # kept for API compat; monoid ignores this # 保留 API 兼容性; 幺半群不使用 position_ids: Tensor | None = None, # kept for API compat; monoid ignores this # 保留 API 兼容性; 幺半群不使用 past_key_values: MonoidCache | None = None, inputs_embeds: Tensor | None = None, labels: Tensor | None = None, use_cache: bool | None = None, monoid_cache: MonoidCache | None = None, output_attentions: bool | None = None, # kept for API compat output_hidden_states: bool | None = None, # kept for API compat logits_to_keep: int | Tensor = 0, **kwargs, ) -> CausalLMOutputWithPast: # monoid_cache takes priority; fall back to past_key_values for GenerationMixin compat # monoid_cache 优先; 兼容 GenerationMixin 传入的 past_key_values cache = monoid_cache or past_key_values # Discard any non-MonoidCache (e.g. DynamicCache injected by GenerationMixin) # 丢弃任何非 MonoidCache 对象 (如 GenerationMixin 注入的 DynamicCache) if cache is not None and not isinstance(cache, MonoidCache): cache = None if use_cache and cache is None: cache = MonoidCache() outputs = self.model( input_ids=input_ids, inputs_embeds=inputs_embeds, monoid_cache=cache, use_cache=bool(use_cache), ) hidden_states = outputs.last_hidden_state # Optionally only compute logits for the last K tokens (memory saving) # 可选仅计算最后 K 个 token 的 logits (节省内存) slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) and logits_to_keep > 0 else slice(None) logits = self.lm_head(hidden_states[:, slice_indices, :]) # Standard causal LM loss: cross-entropy with shift # 标准因果语言模型损失: 带偏移的交叉熵 loss = None if labels is not None: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss = nn.functional.cross_entropy( shift_logits.view(-1, self.vocab_size), shift_labels.view(-1), ignore_index=-100, ) if cache is not None: cache.seen_tokens += (input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]) return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=cache, ) # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ # AutoModel Registration / 自动注册 # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ AutoConfig.register("monoid", MonoidConfig) AutoModelForCausalLM.register(MonoidConfig, MonoidForCausalLM) # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ # Smoke Tests / 验证 # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ if __name__ == '__main__': device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu') print(f'Device: {device}') config = MonoidConfig( vocab_size=49152, hidden_size=576, intermediate_size=1536, num_hidden_layers=30, num_attention_heads=9, head_dim=64, rms_norm_eps=1e-5, hidden_act="silu", tie_word_embeddings=True, ) model = MonoidForCausalLM(config).to(device) n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f'Parameters: {n_params:,}') # -- Training smoke test / 训练冒烟测试 -- B, T = 2, 64 ids = torch.randint(0, config.vocab_size, (B, T), device=device) out = model(ids, labels=ids) 
    print(f'Train — logits: {out.logits.shape}, loss: {out.loss:.4f}')

    # -- Inference smoke test (manual RNN loop) / 推理冒烟测试 (手动 RNN 循环) --
    prompt = torch.randint(0, config.vocab_size, (1, 8), device=device)
    cache = MonoidCache()

    # Prefill / 预填充
    prefill_out = model(prompt, use_cache=True, monoid_cache=cache)
    print(f'Prefill — logits: {prefill_out.logits.shape}, cache seen: {cache.seen_tokens}')

    # Decode 1 token / 解码 1 个 token
    next_tok = prefill_out.logits[:, -1:].argmax(dim=-1)
    step_out = model(next_tok, use_cache=True, monoid_cache=cache)
    print(f'Decode — logits: {step_out.logits.shape}, cache seen: {cache.seen_tokens}')

    # -- Monoid associativity check / 幺半群结合律验证 --
    print('\nMonoid associativity check / 幺半群结合律验证:')
    a = (torch.randn(1, 1, 1), torch.randn(1, 1, 4, 4))
    b = (torch.randn(1, 1, 1), torch.randn(1, 1, 4, 4))
    c = (torch.randn(1, 1, 1), torch.randn(1, 1, 4, 4))
    ab_c = monoid_op(monoid_op(a, b), c)
    a_bc = monoid_op(a, monoid_op(b, c))
    err = (ab_c[1] - a_bc[1]).abs().max().item()
    print(f'  |(a⊕b)⊕c - a⊕(b⊕c)| = {err:.2e}')
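    # -- Parallel-scan vs RNN-path consistency check / 并行扫描与 RNN 路径一致性验证 --
    # A minimal sketch: the training path (parallel scan, no cache) and the O(1) RNN path
    # (token-by-token with MonoidCache) implement the same recurrence, so their logits
    # should agree up to floating-point accumulation error. No hard tolerance is asserted.
    # 最小示意: 训练路径 (并行扫描, 无缓存) 与 O(1) RNN 路径 (逐 token + MonoidCache)
    # 实现同一递推, 因此 logits 应在浮点累积误差范围内一致。此处不设硬性容差。
    print('\nScan vs RNN consistency check / 扫描与 RNN 一致性验证:')
    seq = torch.randint(0, config.vocab_size, (1, 16), device=device)
    with torch.no_grad():
        full_logits = model(seq).logits  # parallel-scan path over the whole sequence
        rnn_cache = MonoidCache()
        step_logits = torch.cat(
            [model(seq[:, t:t + 1], use_cache=True, monoid_cache=rnn_cache).logits
             for t in range(seq.shape[1])],
            dim=1,
        )  # O(1)-per-token RNN path
    max_diff = (full_logits - step_logits).abs().max().item()
    print(f'  max |logit diff| = {max_diff:.2e}')

    print('\nDone.')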