Upload 11 files

Files changed:
- MonoidForCausalLM.py +93 -46
- README.md +88 -116
- model.safetensors +2 -2
- monoid_scan_cuda.py +164 -164
- training_args.bin +2 -2
MonoidForCausalLM.py
CHANGED
|
@@ -14,11 +14,13 @@ Architecture / 架构概要:
|
|
| 14 |
|
| 15 |
Monoid attention compresses the entire causal history into a
|
| 16 |
fixed-size state matrix S_t ∈ ℝ^{d×d} per head:
|
| 17 |
-
S_t = α_t · S_{t-1} + k_t ⊗ v_t
|
| 18 |
-
o_t = q_t · S_t
|
|
|
|
| 19 |
幺半群注意力将完整因果历史压缩到每个头一个固定大小的状态矩阵 S_t:
|
| 20 |
-
S_t = α_t · S_{t-1} + k_t ⊗ v_t
|
| 21 |
-
o_t = q_t · S_t
|
|
|
|
| 22 |
|
| 23 |
This is a monoid because the binary operator:
|
| 24 |
(log_α, S) ⊕ (log_β, X) = (log_α + log_β, exp(log_β)·S + X)
|
|
@@ -69,12 +71,12 @@ except ImportError:
|
|
| 69 |
# Slower than the fused CUDA kernel but numerically identical.
|
| 70 |
|
| 71 |
def parallel_scan(log_alpha: Tensor, kv: Tensor) -> Tensor:
|
| 72 |
-
"""Sequential prefix scan fallback: S_t = exp(log_α_t)·S_{t-1} + kv_t."""
|
| 73 |
B, H, T, d1, d2 = kv.shape
|
| 74 |
states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype)
|
| 75 |
S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype)
|
| 76 |
for t in range(T):
|
| 77 |
-
decay = torch.exp(log_alpha[:, :, t]) # [B, H,
|
| 78 |
while decay.dim() < S.dim():
|
| 79 |
decay = decay.unsqueeze(-1)
|
| 80 |
S = S * decay + kv[:, :, t]
|
|
@@ -86,7 +88,7 @@ except ImportError:
|
|
| 86 |
B, H, T, d1, d2 = kv.shape
|
| 87 |
states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype)
|
| 88 |
S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype)
|
| 89 |
-
log_acc = torch.zeros(B, H,
|
| 90 |
for t in range(T):
|
| 91 |
decay = torch.exp(log_alpha[:, :, t])
|
| 92 |
while decay.dim() < S.dim():
|
|
@@ -217,11 +219,12 @@ def monoid_op(
|
|
| 217 |
b: tuple[Tensor, Tensor],
|
| 218 |
) -> tuple[Tensor, Tensor]:
|
| 219 |
"""
|
| 220 |
-
The monoid binary operator ⊕ on (log-space decay, state matrix) pairs.
|
| 221 |
-
幺半群二元算子 ⊕,作用于 (
|
| 222 |
|
| 223 |
Definition / 定义:
|
| 224 |
-
(log_α, S) ⊕ (log_β, X) = (log_α + log_β, exp(log_β)·S + X)
|
|
|
|
| 225 |
|
| 226 |
Why this is a monoid / 为什么这是幺半群:
|
| 227 |
• Associativity / 结合律:
|
|
@@ -326,15 +329,19 @@ class MonoidAttention(nn.Module):
|
|
| 326 |
|
| 327 |
# --- Decay gate (novel component, randomly initialized) ---
|
| 328 |
# --- 衰减门 (全新组件, 随机初始化) ---
|
| 329 |
-
# Projects hidden_size → num_heads, yielding
|
| 330 |
-
#
|
| 331 |
-
#
|
| 332 |
-
#
|
| 333 |
-
#
|
| 334 |
-
#
|
| 335 |
-
#
|
| 336 |
-
#
|
| 337 |
-
|
|
| 338 |
|
| 339 |
# --- QK-Norm (novel component, randomly initialized) ---
|
| 340 |
# --- QK 归一化 (全新组件, 随机初始化) ---
|
|
@@ -358,12 +365,17 @@ class MonoidAttention(nn.Module):
|
|
| 358 |
def forward(
|
| 359 |
self,
|
| 360 |
hidden_states: Tensor,
|
|
|
|
| 361 |
monoid_cache: MonoidCache | None = None,
|
| 362 |
use_cache: bool = False,
|
| 363 |
) -> tuple[Tensor, tuple[Tensor, Tensor] | None]:
|
| 364 |
"""
|
| 365 |
Args:
|
| 366 |
hidden_states: [B, T, hidden_size]
|
|
|
| 367 |
monoid_cache: O(1) state cache for inference
|
| 368 |
推理用 O(1) 状态缓存
|
| 369 |
use_cache: whether to use/update the cache
|
|
@@ -399,13 +411,40 @@ class MonoidAttention(nn.Module):
|
|
| 399 |
# PSD 保证信息单调积累。
|
| 400 |
k = torch.nn.functional.silu(k)
|
| 401 |
|
| 402 |
-
# --- Compute per-
|
| 403 |
-
# ---
|
| 404 |
-
#
|
| 405 |
-
#
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
|
|
|
| 409 |
|
| 410 |
# ══════════════════════════════════════════════════════════
|
| 411 |
# Inference path (RNN mode): O(1) per token per layer
|
|
@@ -427,7 +466,7 @@ class MonoidAttention(nn.Module):
|
|
| 427 |
# Outer product: k_t ⊗ v_t ∈ ℝ^{H×d×d}
|
| 428 |
# 外积: k_t ⊗ v_t ∈ ℝ^{H×d×d}
|
| 429 |
kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, 0], v[:, :, 0])
|
| 430 |
-
log_t = log_alpha[:, :, 0] # [B,H,
|
| 431 |
|
| 432 |
prev = monoid_cache.get_state(self.layer_idx) if monoid_cache else None
|
| 433 |
if prev is None:
|
|
@@ -467,16 +506,16 @@ class MonoidAttention(nn.Module):
|
|
| 467 |
kv = torch.einsum('bhtd, bhte -> bhtde', k, v) # [B,H,T,d,d]
|
| 468 |
states, (log_acc, S_T) = parallel_scan_with_state(log_alpha, kv)
|
| 469 |
|
| 470 |
-
# Add h0 contribution: S_t += (∏_{i=0}^{t} α_i) · h0
|
| 471 |
-
# 叠加 h0 贡献: S_t += (∏_{i=0}^{t} α_i) · h0
|
| 472 |
-
cum_log_alpha = torch.cumsum(log_alpha, dim=2) # [B,H,T,
|
| 473 |
-
h0_decay = torch.exp(cum_log_alpha).unsqueeze(-1) # [B,H,T,
|
| 474 |
states = states + h0_decay * self.h0.unsqueeze(2) # broadcast h0 [1,H,1,d,d]
|
| 475 |
|
| 476 |
# Final state includes h0 contribution
|
| 477 |
# 最终状态包含 h0 贡献
|
| 478 |
-
total_h0_decay = torch.exp(log_acc).unsqueeze(-1) # [B,H,
|
| 479 |
-
S_final = S_T + total_h0_decay * self.h0.squeeze(0) # [B,H,d,d]
|
| 480 |
# h0 is [1,H,d,d]; squeeze(0) drops the leading dim for clarity (broadcasting would also work)
|
| 481 |
final_state = (log_acc, S_final)
|
| 482 |
|
|
@@ -507,16 +546,16 @@ class MonoidAttention(nn.Module):
|
|
| 507 |
# 向量化外积: 一次性计算所有 t 的 k_t ⊗ v_t
|
| 508 |
kv = torch.einsum('bhtd, bhte -> bhtde', k, v) # [B,H,T,d,d]
|
| 509 |
|
| 510 |
-
# Parallel prefix scan: S_t = α_t·S_{t-1} + kv_t (from S=0)
|
| 511 |
-
# 并行前缀扫描: S_t = α_t·S_{t-1} + kv_t (从 S=0 开始)
|
| 512 |
-
#
|
| 513 |
-
#
|
| 514 |
states = parallel_scan(log_alpha, kv) # [B,H,T,d,d]
|
| 515 |
|
| 516 |
-
# Add h0 contribution: S_t += (∏_{i=0}^{t} α_i) · h0
|
| 517 |
-
# 叠加 h0 贡献: S_t += (∏_{i=0}^{t} α_i) · h0
|
| 518 |
-
cum_log_alpha = torch.cumsum(log_alpha, dim=2) # [B,H,T,
|
| 519 |
-
h0_decay = torch.exp(cum_log_alpha).unsqueeze(-1) # [B,H,T,
|
| 520 |
states = states + h0_decay * self.h0.unsqueeze(2) # broadcast h0 [1,H,1,d,d]
|
| 521 |
|
| 522 |
# Vectorized readout: o_t = q_t · S_t for all t at once
|
|
@@ -557,6 +596,7 @@ class MonoidDecoderLayer(nn.Module):
|
|
| 557 |
def forward(
|
| 558 |
self,
|
| 559 |
hidden_states: Tensor,
|
|
|
|
| 560 |
monoid_cache: MonoidCache | None = None,
|
| 561 |
use_cache: bool = False,
|
| 562 |
) -> Tensor:
|
|
@@ -564,7 +604,7 @@ class MonoidDecoderLayer(nn.Module):
|
|
| 564 |
# --- 注意力块 + 残差连接 ---
|
| 565 |
residual = hidden_states
|
| 566 |
hidden_states = self.input_layernorm(hidden_states)
|
| 567 |
-
hidden_states, _ = self.self_attn(hidden_states, monoid_cache=monoid_cache, use_cache=use_cache)
|
| 568 |
hidden_states = residual + hidden_states
|
| 569 |
|
| 570 |
# --- FFN block with residual ---
|
|
@@ -600,7 +640,7 @@ class MonoidPreTrainedModel(PreTrainedModel):
|
|
| 600 |
module.weight.data[module.padding_idx].zero_()
|
| 601 |
|
| 602 |
if isinstance(module, MonoidAttention):
|
| 603 |
-
nn.init.constant_(module.decay_proj.bias,
|
| 604 |
|
| 605 |
class MonoidModel(MonoidPreTrainedModel):
|
| 606 |
"""
|
|
@@ -625,6 +665,7 @@ class MonoidModel(MonoidPreTrainedModel):
|
|
| 625 |
def forward(
|
| 626 |
self,
|
| 627 |
input_ids: Tensor | None = None,
|
|
|
|
| 628 |
inputs_embeds: Tensor | None = None,
|
| 629 |
monoid_cache: MonoidCache | None = None,
|
| 630 |
use_cache: bool = False,
|
|
@@ -638,11 +679,12 @@ class MonoidModel(MonoidPreTrainedModel):
|
|
| 638 |
hidden_states = self._gradient_checkpointing_func(
|
| 639 |
layer.__call__,
|
| 640 |
hidden_states,
|
|
|
|
| 641 |
monoid_cache,
|
| 642 |
use_cache,
|
| 643 |
)
|
| 644 |
else:
|
| 645 |
-
hidden_states = layer(hidden_states, monoid_cache=monoid_cache, use_cache=use_cache)
|
| 646 |
|
| 647 |
hidden_states = self.norm(hidden_states)
|
| 648 |
|
|
@@ -723,9 +765,13 @@ class MonoidForCausalLM(MonoidPreTrainedModel, GenerationMixin):
|
|
| 723 |
# Cache exists → only feed the latest token (O(1) inference)
|
| 724 |
# 缓存已存在 → 只需输入最新的 token (O(1) 推理)
|
| 725 |
input_ids = input_ids[:, -1:]
|
|
|
| 726 |
|
| 727 |
model_inputs = {
|
| 728 |
"input_ids": input_ids,
|
|
|
|
| 729 |
"monoid_cache": past_key_values,
|
| 730 |
"use_cache": True,
|
| 731 |
}
|
|
@@ -734,8 +780,8 @@ class MonoidForCausalLM(MonoidPreTrainedModel, GenerationMixin):
|
|
| 734 |
def forward(
|
| 735 |
self,
|
| 736 |
input_ids: Tensor | None = None,
|
| 737 |
-
attention_mask: Tensor | None = None, #
|
| 738 |
-
#
|
| 739 |
position_ids: Tensor | None = None, # kept for API compat; monoid ignores this
|
| 740 |
# 保留 API 兼容性; 幺半群不使用
|
| 741 |
past_key_values: MonoidCache | None = None,
|
|
@@ -762,6 +808,7 @@ class MonoidForCausalLM(MonoidPreTrainedModel, GenerationMixin):
|
|
| 762 |
|
| 763 |
outputs = self.model(
|
| 764 |
input_ids=input_ids,
|
|
|
|
| 765 |
inputs_embeds=inputs_embeds,
|
| 766 |
monoid_cache=cache,
|
| 767 |
use_cache=bool(use_cache),
|
|
|
|
| 14 |
|
| 15 |
Monoid attention compresses the entire causal history into a
|
| 16 |
fixed-size state matrix S_t ∈ ℝ^{d×d} per head:
|
| 17 |
+
S_t = diag(α_t) · S_{t-1} + k_t ⊗ v_t (vector decay recurrence)
|
| 18 |
+
o_t = q_t · S_t (state readout)
|
| 19 |
+
where α_t ∈ ℝ^d is a per-dimension vector decay gate.
|
| 20 |
幺半群注意力将完整因果历史压缩到每个头一个固定大小的状态矩阵 S_t:
|
| 21 |
+
S_t = diag(α_t) · S_{t-1} + k_t ⊗ v_t (向量衰减递推)
|
| 22 |
+
o_t = q_t · S_t (状态读出)
|
| 23 |
+
其中 α_t ∈ ℝ^d 是逐维度的向量衰减门。
|
| 24 |
|
| 25 |
This is a monoid because the binary operator:
|
| 26 |
(log_α, S) ⊕ (log_β, X) = (log_α + log_β, exp(log_β)·S + X)
|
|
|
|
| 71 |
# Slower than the fused CUDA kernel but numerically identical.
|
| 72 |
|
| 73 |
def parallel_scan(log_alpha: Tensor, kv: Tensor) -> Tensor:
|
| 74 |
+
"""Sequential prefix scan fallback: S_t[i,:] = exp(log_α_t[i])·S_{t-1}[i,:] + kv_t[i,:]."""
|
| 75 |
B, H, T, d1, d2 = kv.shape
|
| 76 |
states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype)
|
| 77 |
S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype)
|
| 78 |
for t in range(T):
|
| 79 |
+
decay = torch.exp(log_alpha[:, :, t]) # [B, H, d]
|
| 80 |
while decay.dim() < S.dim():
|
| 81 |
decay = decay.unsqueeze(-1)
|
| 82 |
S = S * decay + kv[:, :, t]
|
|
|
|
| 88 |
B, H, T, d1, d2 = kv.shape
|
| 89 |
states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype)
|
| 90 |
S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype)
|
| 91 |
+
log_acc = torch.zeros(B, H, d1, device=log_alpha.device, dtype=log_alpha.dtype)
|
| 92 |
for t in range(T):
|
| 93 |
decay = torch.exp(log_alpha[:, :, t])
|
| 94 |
while decay.dim() < S.dim():
|
|
|
|
| 219 |
b: tuple[Tensor, Tensor],
|
| 220 |
) -> tuple[Tensor, Tensor]:
|
| 221 |
"""
|
| 222 |
+
The monoid binary operator ⊕ on (log-space vector decay, state matrix) pairs.
|
| 223 |
+
幺半群二元算子 ⊕,作用于 (对数向量衰减, 状态矩阵) 对。
|
| 224 |
|
| 225 |
Definition / 定义:
|
| 226 |
+
(log_α, S) ⊕ (log_β, X) = (log_α + log_β, diag(exp(log_β))·S + X)
|
| 227 |
+
where log_α, log_β ∈ ℝ^d are per-dimension log decay vectors.
|
| 228 |
|
| 229 |
Why this is a monoid / 为什么这是幺半群:
|
| 230 |
• Associativity / 结合律:
|
|
|
|
| 329 |
|
| 330 |
# --- Decay gate (novel component, randomly initialized) ---
|
| 331 |
# --- 衰减门 (全新组件, 随机初始化) ---
|
| 332 |
+
# Projects hidden_size → num_heads * head_dim, yielding a VECTOR per head.
|
| 333 |
+
# Activation: log_α = -softplus(Wx + b), giving α ∈ (0, 1].
|
| 334 |
+
# Vector decay: S_t = diag(α_t) · S_{t-1} + k_t ⊗ v_t
|
| 335 |
+
# Different feature dimensions can have independent lifetimes:
|
| 336 |
+
# - fast-decaying dims for local syntax
|
| 337 |
+
# - slow-decaying dims for global entity/fact memory
|
| 338 |
+
# 将 hidden_size 投影到 num_heads * head_dim, 每个头产生一个向量。
|
| 339 |
+
# 激活: log_α = -softplus(Wx + b), 使 α ∈ (0, 1]。
|
| 340 |
+
# 向量衰减: S_t = diag(α_t) · S_{t-1} + k_t ⊗ v_t
|
| 341 |
+
# 不同特征维度拥有独立的生命周期:
|
| 342 |
+
# - 快速衰减的维度负责局部语法结构
|
| 343 |
+
# - 慢速衰减的维度负责全局实体和事实记忆
|
| 344 |
+
self.decay_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=True)
|
| 345 |
|
| 346 |
# --- QK-Norm (novel component, randomly initialized) ---
|
| 347 |
# --- QK 归一化 (全新组件, 随机初始化) ---
|
|
|
|
| 365 |
def forward(
|
| 366 |
self,
|
| 367 |
hidden_states: Tensor,
|
| 368 |
+
attention_mask: Tensor | None = None,
|
| 369 |
monoid_cache: MonoidCache | None = None,
|
| 370 |
use_cache: bool = False,
|
| 371 |
) -> tuple[Tensor, tuple[Tensor, Tensor] | None]:
|
| 372 |
"""
|
| 373 |
Args:
|
| 374 |
hidden_states: [B, T, hidden_size]
|
| 375 |
+
attention_mask: [B, T] with 1=real token, 0=pad.
|
| 376 |
+
For PAD positions: α=1 (preserve state), kv=0 (no contribution).
|
| 377 |
+
掩码: 1=真实token, 0=填充。
|
| 378 |
+
填充位置: α=1 (保持状态不变), kv=0 (无贡献)。
|
| 379 |
monoid_cache: O(1) state cache for inference
|
| 380 |
推理用 O(1) 状态缓存
|
| 381 |
use_cache: whether to use/update the cache
|
|
|
|
| 411 |
# PSD 保证信息单调积累。
|
| 412 |
k = torch.nn.functional.silu(k)
|
| 413 |
|
| 414 |
+
# --- Compute per-dimension vector decay gate α_t ---
|
| 415 |
+
# --- 计算每维度向量衰减门 α_t ---
|
| 416 |
+
# Negative Softplus: log_α = -softplus(Wx + b)
|
| 417 |
+
# Value range: log_α ∈ (-∞, 0), i.e. α ∈ (0, 1].
|
| 418 |
+
# When Wx → -∞: softplus → 0, α → 1 (perfect memory, no forgetting)
|
| 419 |
+
# When Wx → +∞: softplus → Wx, α → 0 (complete forgetting)
|
| 420 |
+
# This avoids α > 1 explosion (unlike SiLU) while still allowing
|
| 421 |
+
# α = 1 for lossless memory (unlike Sigmoid which caps at <1).
|
| 422 |
+
# Each dimension of the d-vector decays independently:
|
| 423 |
+
# S_t[i,j] = α_t[i] · S_{t-1}[i,j] + k_t[i] · v_t[j]
|
| 424 |
+
#
|
| 425 |
+
# 负 Softplus: log_α = -softplus(Wx + b)
|
| 426 |
+
# 值域: log_α ∈ (-∞, 0), 即 α ∈ (0, 1]。
|
| 427 |
+
# 当 Wx → -∞: softplus → 0, α → 1 (完美记忆, 不遗忘)
|
| 428 |
+
# 当 Wx → +∞: softplus → Wx, α → 0 (完全遗忘)
|
| 429 |
+
# 避免了 SiLU 的 α > 1 爆炸, 同时允许 α = 1 无损记忆 (Sigmoid 无法做到)。
|
| 430 |
+
# d-向量的每个维度独立衰减:
|
| 431 |
+
# S_t[i,j] = α_t[i] · S_{t-1}[i,j] + k_t[i] · v_t[j]
|
| 432 |
+
raw = self.decay_proj(hidden_states) # [B,T,H*d]
|
| 433 |
+
log_alpha = -torch.nn.functional.softplus(raw) # [B,T,H*d]
|
| 434 |
+
log_alpha = log_alpha.view(B, T, H, d).transpose(1, 2) # [B,H,T,d]
|
| 435 |
+
|
| 436 |
+
# --- Apply attention_mask: PAD tokens must be invisible to the recurrence ---
|
| 437 |
+
# --- 应用注意力掩码: PAD token 必须对递推不可见 ---
|
| 438 |
+
# For PAD positions (mask=0): set log_α=0 (α=1, preserve state) and kv=0 (no contribution).
|
| 439 |
+
# This makes S_t = 1·S_{t-1} + 0 = S_{t-1}, i.e. PAD is a no-op on the state.
|
| 440 |
+
# 对于 PAD 位置 (mask=0): 设 log_α=0 (α=1, 保持状态) 且 kv=0 (无贡献)。
|
| 441 |
+
# 这使得 S_t = 1·S_{t-1} + 0 = S_{t-1}, 即 PAD 对状态是空操作。
|
| 442 |
+
if attention_mask is not None:
|
| 443 |
+
# attention_mask: [B, T] → [B, 1, T, 1] for broadcasting with [B, H, T, d]
|
| 444 |
+
mask = attention_mask[:, None, :, None].to(log_alpha.dtype) # [B,1,T,1]
|
| 445 |
+
log_alpha = log_alpha * mask # PAD → log_α=0 → α=1
|
| 446 |
+
k = k * mask # PAD → k=0
|
| 447 |
+
v = v * mask # PAD → v=0 → kv=0
|
| 448 |
|
| 449 |
# ══════════════════════════════════════════════════════════
|
| 450 |
# Inference path (RNN mode): O(1) per token per layer
|
|
|
|
| 466 |
# Outer product: k_t ⊗ v_t ∈ ℝ^{H×d×d}
|
| 467 |
# 外积: k_t ⊗ v_t ∈ ℝ^{H×d×d}
|
| 468 |
kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, 0], v[:, :, 0])
|
| 469 |
+
log_t = log_alpha[:, :, 0] # [B,H,d]
|
| 470 |
|
| 471 |
prev = monoid_cache.get_state(self.layer_idx) if monoid_cache else None
|
| 472 |
if prev is None:
|
|
|
|
| 506 |
kv = torch.einsum('bhtd, bhte -> bhtde', k, v) # [B,H,T,d,d]
|
| 507 |
states, (log_acc, S_T) = parallel_scan_with_state(log_alpha, kv)
|
| 508 |
|
| 509 |
+
# Add h0 contribution: S_t += diag(∏_{i=0}^{t} α_i) · h0
|
| 510 |
+
# 叠加 h0 贡献: S_t += diag(∏_{i=0}^{t} α_i) · h0
|
| 511 |
+
cum_log_alpha = torch.cumsum(log_alpha, dim=2) # [B,H,T,d]
|
| 512 |
+
h0_decay = torch.exp(cum_log_alpha).unsqueeze(-1) # [B,H,T,d,1]
|
| 513 |
states = states + h0_decay * self.h0.unsqueeze(2) # broadcast h0 [1,H,1,d,d]
|
| 514 |
|
| 515 |
# Final state includes h0 contribution
|
| 516 |
# 最终状态包含 h0 贡献
|
| 517 |
+
total_h0_decay = torch.exp(log_acc).unsqueeze(-1) # [B,H,d,1]
|
| 518 |
+
S_final = S_T + total_h0_decay * self.h0.squeeze(0) # [B,H,d,d]
|
| 519 |
# h0 is [1,H,d,d]; squeeze(0) drops the leading dim for clarity (broadcasting would also work)
|
| 520 |
final_state = (log_acc, S_final)
|
| 521 |
|
|
|
|
| 546 |
# 向量化外积: 一次性计算所有 t 的 k_t ⊗ v_t
|
| 547 |
kv = torch.einsum('bhtd, bhte -> bhtde', k, v) # [B,H,T,d,d]
|
| 548 |
|
| 549 |
+
# Parallel prefix scan: S_t = diag(α_t)·S_{t-1} + kv_t (from S=0)
|
| 550 |
+
# 并行前缀扫描: S_t = diag(α_t)·S_{t-1} + kv_t (从 S=0 开始)
|
| 551 |
+
# log_alpha is [B,H,T,d] — vector decay per dimension.
|
| 552 |
+
# log_alpha 为 [B,H,T,d] — 每维度向量衰减。
|
| 553 |
states = parallel_scan(log_alpha, kv) # [B,H,T,d,d]
|
| 554 |
|
| 555 |
+
# Add h0 contribution: S_t += diag(∏_{i=0}^{t} α_i) · h0
|
| 556 |
+
# 叠加 h0 贡献: S_t += diag(∏_{i=0}^{t} α_i) · h0
|
| 557 |
+
cum_log_alpha = torch.cumsum(log_alpha, dim=2) # [B,H,T,d]
|
| 558 |
+
h0_decay = torch.exp(cum_log_alpha).unsqueeze(-1) # [B,H,T,d,1]
|
| 559 |
states = states + h0_decay * self.h0.unsqueeze(2) # broadcast h0 [1,H,1,d,d]
|
| 560 |
|
| 561 |
# Vectorized readout: o_t = q_t · S_t for all t at once
|
|
|
|
| 596 |
def forward(
|
| 597 |
self,
|
| 598 |
hidden_states: Tensor,
|
| 599 |
+
attention_mask: Tensor | None = None,
|
| 600 |
monoid_cache: MonoidCache | None = None,
|
| 601 |
use_cache: bool = False,
|
| 602 |
) -> Tensor:
|
|
|
|
| 604 |
# --- 注意力块 + 残差连接 ---
|
| 605 |
residual = hidden_states
|
| 606 |
hidden_states = self.input_layernorm(hidden_states)
|
| 607 |
+
hidden_states, _ = self.self_attn(hidden_states, attention_mask=attention_mask, monoid_cache=monoid_cache, use_cache=use_cache)
|
| 608 |
hidden_states = residual + hidden_states
|
| 609 |
|
| 610 |
# --- FFN block with residual ---
|
|
|
|
| 640 |
module.weight.data[module.padding_idx].zero_()
|
| 641 |
|
| 642 |
if isinstance(module, MonoidAttention):
|
| 643 |
+
nn.init.constant_(module.decay_proj.bias, 1.0)
|
| 644 |
|
| 645 |
class MonoidModel(MonoidPreTrainedModel):
|
| 646 |
"""
|
|
|
|
| 665 |
def forward(
|
| 666 |
self,
|
| 667 |
input_ids: Tensor | None = None,
|
| 668 |
+
attention_mask: Tensor | None = None,
|
| 669 |
inputs_embeds: Tensor | None = None,
|
| 670 |
monoid_cache: MonoidCache | None = None,
|
| 671 |
use_cache: bool = False,
|
|
|
|
| 679 |
hidden_states = self._gradient_checkpointing_func(
|
| 680 |
layer.__call__,
|
| 681 |
hidden_states,
|
| 682 |
+
attention_mask,
|
| 683 |
monoid_cache,
|
| 684 |
use_cache,
|
| 685 |
)
|
| 686 |
else:
|
| 687 |
+
hidden_states = layer(hidden_states, attention_mask=attention_mask, monoid_cache=monoid_cache, use_cache=use_cache)
|
| 688 |
|
| 689 |
hidden_states = self.norm(hidden_states)
|
| 690 |
|
|
|
|
| 765 |
# Cache exists → only feed the latest token (O(1) inference)
|
| 766 |
# 缓存已存在 → 只需输入最新的 token (O(1) 推理)
|
| 767 |
input_ids = input_ids[:, -1:]
|
| 768 |
+
# Decode step: single real token, no PAD → mask not needed
|
| 769 |
+
# 解码步: 单个真实token, 无PAD → 不需要掩码
|
| 770 |
+
attention_mask = None
|
| 771 |
|
| 772 |
model_inputs = {
|
| 773 |
"input_ids": input_ids,
|
| 774 |
+
"attention_mask": attention_mask,
|
| 775 |
"monoid_cache": past_key_values,
|
| 776 |
"use_cache": True,
|
| 777 |
}
|
|
|
|
| 780 |
def forward(
|
| 781 |
self,
|
| 782 |
input_ids: Tensor | None = None,
|
| 783 |
+
attention_mask: Tensor | None = None, # [B,T] 1=real, 0=pad — used to mask PAD from recurrence
|
| 784 |
+
# [B,T] 1=真实token, 0=填充 — 用于屏蔽PAD对递推的影响
|
| 785 |
position_ids: Tensor | None = None, # kept for API compat; monoid ignores this
|
| 786 |
# 保留 API 兼容性; 幺半群不使用
|
| 787 |
past_key_values: MonoidCache | None = None,
|
|
|
|
| 808 |
|
| 809 |
outputs = self.model(
|
| 810 |
input_ids=input_ids,
|
| 811 |
+
attention_mask=attention_mask,
|
| 812 |
inputs_embeds=inputs_embeds,
|
| 813 |
monoid_cache=cache,
|
| 814 |
use_cache=bool(use_cache),
|
README.md
CHANGED
|
@@ -9,6 +9,7 @@ tags:
|
|
| 9 |
- linear-attention
|
| 10 |
- state-space
|
| 11 |
- O(1)-inference
|
|
|
|
| 12 |
- reasoning
|
| 13 |
pipeline_tag: text-generation
|
| 14 |
model-index:
|
|
@@ -20,124 +21,113 @@ model-index:
|
|
| 20 |
|
| 21 |
A 1.3B parameter language model that replaces softmax attention with **causal monoid state compression**, achieving **O(1) time per token** and **O(1) memory** at inference — regardless of sequence length.
|
| 22 |
|
| 23 |
-
Fine-tuned for enhanced reasoning with structured chain-of-thought data.
|
| 24 |
-
|
| 25 |
## Monoid Attention — Internal Structure
|
| 26 |
|
| 27 |
```
|
| 28 |
MonoidAttention (per layer, per head)
|
| 29 |
-
|
| 30 |
-
│
|
| 31 |
-
│ x_t ∈ R^{2048}
|
| 32 |
-
│ │
|
| 33 |
-
│ ├──> q_proj ──> RMSNorm ──> q_t ∈ R^
|
| 34 |
-
│ │
|
| 35 |
-
│ ├──> k_proj ──> RMSNorm ──> SiLU ──> k_t ∈ R^
|
| 36 |
-
│ │
|
| 37 |
-
│ ├──> v_proj ──> v_t ∈ R^
|
| 38 |
-
│ │
|
| 39 |
-
│ └──> decay_proj ──>
|
| 40 |
-
│
|
| 41 |
-
│ k_t
|
| 42 |
-
│ │
|
| 43 |
-
│ │
|
| 44 |
-
│ v
|
| 45 |
-
│ S_t =
|
| 46 |
-
│ │
|
| 47 |
-
│ │
|
| 48 |
-
│ v
|
| 49 |
-
│ o_t = q_t
|
| 50 |
-
│
|
| 51 |
-
|
| 52 |
-
```
|
| 53 |
-
|
| 54 |
-
## Monoid State Diagonal — O(1) Compression Contour
|
| 55 |
-
|
| 56 |
-
The state matrix `S_t` accumulates causal history along its diagonal. Each head maintains an independent `d x d` state that compresses ALL past tokens into a fixed footprint:
|
| 57 |
-
|
| 58 |
-
```
|
| 59 |
-
State Matrix S_t ∈ R^{64 x 64} (one per head, 32 heads per layer)
|
| 60 |
-
|
| 61 |
-
k-dim -->
|
| 62 |
-
0 8 16 24 32 40 48 56 63
|
| 63 |
-
┌───┬───┬───┬───┬───┬───┬───┬───┐ 0
|
| 64 |
-
│***│** │* │ │ │ │ │ │ v-dim
|
| 65 |
-
│***│** │* │. │ │ │ │ │ |
|
| 66 |
-
├───┼───┼───┼───┼───┼───┼───┼───┤ 8 |
|
| 67 |
-
│** │***│** │* │. │ │ │ │ v
|
| 68 |
-
│* │***│** │* │. │ │ │ │
|
| 69 |
-
├───┼───┼───┼───┼───┼───┼───┼───┤ 16
|
| 70 |
-
│* │** │***│** │* │. │ │ │
|
| 71 |
-
│. │* │***│** │* │. │ │ │
|
| 72 |
-
├───┼───┼───┼───┼───┼───┼───┼───┤ 24
|
| 73 |
-
│ │. │** │***│** │* │. │ │
|
| 74 |
-
│ │ │* │***│** │* │. │ │
|
| 75 |
-
├───┼───┼───┼───┼───┼───┼───┼───┤ 32
|
| 76 |
-
│ │ │. │** │***│** │* │. │
|
| 77 |
-
│ │ │ │* │***│** │* │. │
|
| 78 |
-
├───┼───┼───┼───┼───┼───┼───┼───┤ 40
|
| 79 |
-
│ │ │ │. │** │***│** │* │
|
| 80 |
-
│ │ │ │ │* │***│** │* │
|
| 81 |
-
├───┼───┼───┼───┼───┼───┼───┼───┤ 48
|
| 82 |
-
│ │ │ │ │. │** │***│** │
|
| 83 |
-
│ │ │ │ │ │* │***│** │
|
| 84 |
-
├───┼───┼───┼───┼───┼───┼───┼───┤ 56
|
| 85 |
-
│ │ │ │ │ │. │** │***│
|
| 86 |
-
│ │ │ │ │ │ │* │***│
|
| 87 |
-
└───┴───┴───┴───┴───┴───┴───┴───┘ 63
|
| 88 |
-
|
| 89 |
-
Legend: *** = high activation (recent tokens, alpha^0 ~ alpha^2)
|
| 90 |
-
** = medium (alpha^3 ~ alpha^5)
|
| 91 |
-
* = fading (alpha^6 ~ alpha^10)
|
| 92 |
-
. = near-zero (alpha^11+, effectively forgotten)
|
| 93 |
-
= zero (never reached or fully decayed)
|
| 94 |
-
|
| 95 |
-
The diagonal band emerges because S_t = SUM_{i<=t} alpha^{t-i} * k_i (x) v_i.
|
| 96 |
-
Recent outer products dominate near the diagonal; older ones decay
|
| 97 |
-
exponentially via alpha, creating this characteristic contour.
|
| 98 |
```
|
| 99 |
|
| 100 |
-
|
| 101 |
## Key Properties
|
| 102 |
|
| 103 |
| Property | Transformer (Llama) | Spartacus (Monoid) |
|
| 104 |
|---|---|---|
|
| 105 |
-
| Inference time per token | O(T)
|
| 106 |
-
| Inference memory per layer | O(T)
|
| 107 |
-
| Sequence length extrapolation | Degrades beyond training length | **Unlimited**
|
| 108 |
| Causality | Imposed via attention mask | **Built into the recurrence** |
|
| 109 |
-
| Training complexity | O(T
|
| 110 |
|
| 111 |
## The Monoid Recurrence
|
| 112 |
|
| 113 |
Standard attention computes:
|
| 114 |
|
| 115 |
```
|
| 116 |
-
o_t =
|
| 117 |
```
|
| 118 |
|
| 119 |
Monoid attention compresses the entire causal history into a **fixed-size state matrix** S_t per head:
|
| 120 |
|
| 121 |
```
|
| 122 |
-
S_t =
|
| 123 |
-
o_t = q_t
|
|
|
|
|
| 124 |
```
|
| 125 |
|
| 126 |
-
|
|
|
|
| 127 |
|
| 128 |
-
|
| 129 |
|
| 130 |
-
|
|
|
| 131 |
|
| 132 |
-
-
|
| 133 |
-
- The model learns **when to forget** rather than encoding **where tokens are** (no positional encoding needed)
|
| 134 |
-
- No attention mask required -- causality is structural, not enforced
|
| 135 |
|
| 136 |
## Design Choices
|
| 137 |
|
| 138 |
-
- **SiLU-activated keys**: `k = SiLU(k_proj(x))` ensures non-negative keys, making the state matrix
|
| 139 |
-
- **
|
| 140 |
-
- **
|
|
|
|
| 141 |
|
| 142 |
## Model Details
|
| 143 |
|
|
@@ -151,7 +141,8 @@ Unlike Transformers where causality is a constraint imposed by masking, Spartacu
|
|
| 151 |
| Layers | 16 |
|
| 152 |
| Attention heads | 32 |
|
| 153 |
| Head dimension | 64 |
|
| 154 |
-
|
|
|
|
|
| 155 |
| Vocabulary | 128,256 (Llama-3.2 tokenizer) |
|
| 156 |
| Precision | bfloat16 |
|
| 157 |
|
|
@@ -167,7 +158,7 @@ Unlike Transformers where causality is a constraint imposed by masking, Spartacu
|
|
| 167 |
|
| 168 |
### Comparison with ~1B Baselines (acc_norm, 0-shot)
|
| 169 |
|
| 170 |
-
| Task | Spartacus-1B
|
| 171 |
|---|---|---|---|---|---|
|
| 172 |
| ARC-C | **0.3063** | 0.3268 | ~0.359 | 0.284 | ~0.301 |
|
| 173 |
| ARC-E | **0.5518** | 0.5547 | ~0.752 | 0.512 | ~0.530 |
|
|
@@ -177,34 +168,15 @@ Unlike Transformers where causality is a constraint imposed by masking, Spartacu
|
|
| 177 |
|
| 178 |
> Spartacus achieves performance competitive with other sub-quadratic models (Mamba, RWKV) while maintaining **O(1) inference time and memory per token**. Scores marked with ~ are approximate community-reported values.
|
| 179 |
|
| 180 |
-
## Training
|
| 181 |
-
|
| 182 |
-
### Stage 1: General SFT
|
| 183 |
-
|
| 184 |
-
- **Base weights**: Transferred from Llama-3.2-1B-Instruct (embeddings, MLP, norms)
|
| 185 |
-
- **Data**: Capybara + smol-smoltalk (general conversation)
|
| 186 |
-
- **Training**: Full-parameter SFT
|
| 187 |
-
|
| 188 |
-
### Stage 2: Reasoning Enhancement
|
| 189 |
-
|
| 190 |
-
- **Data mix**: 60% Qwen3-Short-Reasoning + 20% Capybara + 20% smol-smoltalk
|
| 191 |
-
- **Steps**: 2,000
|
| 192 |
-
- **Learning rate**: 2e-5 (cosine schedule, 50 warmup steps)
|
| 193 |
-
- **Batch size**: 8
|
| 194 |
-
- **Sequence length**: 2,048
|
| 195 |
-
- **Precision**: bfloat16
|
| 196 |
-
- **Optimizer**: AdamW (weight decay 0.01, max grad norm 1.0)
|
| 197 |
-
|
| 198 |
-
The reasoning data uses structured "Thought + Solution" format to strengthen chain-of-thought capabilities while the general data prevents catastrophic forgetting.
|
| 199 |
-
|
| 200 |
## Parallel Scan Implementation
|
| 201 |
|
| 202 |
-
The `monoid_scan_cuda.py` module provides a Triton JIT-compiled parallel prefix scan:
|
| 203 |
|
| 204 |
-
- **
|
| 205 |
-
- **
|
|
|
|
| 206 |
- **Fallback**: Pure PyTorch sequential scan for CPU/MPS
|
| 207 |
-
- **Auto-dispatch**: CUDA
|
| 208 |
|
| 209 |
## Usage
|
| 210 |
|
|
@@ -231,7 +203,7 @@ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
|
| 231 |
|
| 232 |
```
|
| 233 |
MonoidForCausalLM.py # Model architecture (MonoidConfig, MonoidAttention, MonoidForCausalLM)
|
| 234 |
-
monoid_scan_cuda.py # Triton JIT parallel prefix scan + PyTorch fallback
|
| 235 |
model.safetensors # Model weights (bfloat16)
|
| 236 |
config.json # Model configuration
|
| 237 |
tokenizer.json # Llama-3.2 tokenizer
|
|
@@ -245,7 +217,7 @@ tokenizer.json # Llama-3.2 tokenizer
|
|
| 245 |
author={NoesisLab},
|
| 246 |
year={2025},
|
| 247 |
url={https://huggingface.co/NoesisLab/Spartacus-1B-Instruct},
|
| 248 |
-
description={Replaces softmax attention with monoid state compression for constant-time, constant-memory autoregressive generation}
|
| 249 |
}
|
| 250 |
```
|
| 251 |
|
|
|
|
| 9 |
- linear-attention
|
| 10 |
- state-space
|
| 11 |
- O(1)-inference
|
| 12 |
+
- vector-decay
|
| 13 |
- reasoning
|
| 14 |
pipeline_tag: text-generation
|
| 15 |
model-index:
|
|
|
|
| 21 |
|
| 22 |
A 1.3B parameter language model that replaces softmax attention with **causal monoid state compression**, achieving **O(1) time per token** and **O(1) memory** at inference — regardless of sequence length.
|
| 23 |
|
|
|
| 24 |
## Monoid Attention — Internal Structure
|
| 25 |
|
| 26 |
```
|
| 27 |
MonoidAttention (per layer, per head)
|
| 28 |
+
┌─────────────────────────────────────────────────────────────────────────┐
|
| 29 |
+
│ │
|
| 30 |
+
│ x_t ∈ R^{2048} │
|
| 31 |
+
│ │ │
|
| 32 |
+
│ ├──> q_proj ──> RMSNorm ──> q_t ∈ R^d (query, scaled 1/√d) │
|
| 33 |
+
│ │ │
|
| 34 |
+
│ ├──> k_proj ──> RMSNorm ──> SiLU ──> k_t ∈ R^d (key, non-negative) │
|
| 35 |
+
│ │ │
|
| 36 |
+
│ ├──> v_proj ──> v_t ∈ R^d (value) │
|
| 37 |
+
│ │ │
|
| 38 |
+
│ └──> decay_proj ──> -Softplus ──> log α_t ∈ R^d (vector decay gate) │
|
| 39 |
+
│ │
|
| 40 |
+
│ k_t ⊗ v_t │
|
| 41 |
+
│ │ ┌─────────────────────────────────┐ │
|
| 42 |
+
│ │ │ State Matrix S_t ∈ R^{d x d} │ │
|
| 43 |
+
│ v │ "Compressed causal history" │ │
|
| 44 |
+
│ S_t = diag(α_t) · S_{t-1} + k_t ⊗ v_t │ │
|
| 45 |
+
│ │ │ α_t ∈ (0,1]^d per dimension │ │
|
| 46 |
+
│ │ └─────────────────────────────────┘ │
|
| 47 |
+
│ v │
|
| 48 |
+
│ o_t = q_t · S_t ──> o_proj ──> output │
|
| 49 |
+
│ │
|
| 50 |
+
└─────────────────────────────────────────────────────────────────────────┘
|
|
|
|
|
|
|
|
|
|
| 51 |
```
|
| 52 |
|
|
|
|
| 53 |
## Key Properties
|
| 54 |
|
| 55 |
| Property | Transformer (Llama) | Spartacus (Monoid) |
|
| 56 |
|---|---|---|
|
| 57 |
+
| Inference time per token | O(T) — scans full KV-cache | **O(1)** — single state update |
|
| 58 |
+
| Inference memory per layer | O(T) — stores all past K,V | **O(1)** — fixed d×d state matrix |
|
| 59 |
+
| Sequence length extrapolation | Degrades beyond training length | **Unlimited** — state size is constant |
|
| 60 |
| Causality | Imposed via attention mask | **Built into the recurrence** |
|
| 61 |
+
| Training complexity | O(T²) | **O(T)** via parallel prefix scan |
|
| 62 |
|
| 63 |
## The Monoid Recurrence
|
| 64 |
|
| 65 |
Standard attention computes:
|
| 66 |
|
| 67 |
```
|
| 68 |
+
o_t = Σ_{i≤t} softmax(q_t · k_i) v_i — requires O(T) KV-cache
|
| 69 |
```
|
| 70 |
|
| 71 |
Monoid attention compresses the entire causal history into a **fixed-size state matrix** S_t per head:
|
| 72 |
|
| 73 |
```
|
| 74 |
+
S_t = diag(α_t) · S_{t-1} + k_t ⊗ v_t — vector decay monoid recurrence
|
| 75 |
+
o_t = q_t · S_t — state readout
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
This is a monoid because the binary operator `(log_α, S) ⊕ (log_β, X) = (log_α + log_β, exp(log_β)·S + X)` is **associative**, enabling O(T) parallel prefix scan for training and O(1) sequential update for inference.
|
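A minimal standalone sketch of this operator and an associativity check (illustrative shapes only; the repository's `monoid_op` in `MonoidForCausalLM.py` operates on batched per-head tensors):

```
import torch

def monoid_op(a, b):
    # (log_a, S) ⊕ (log_b, X) = (log_a + log_b, diag(exp(log_b))·S + X)
    log_a, S = a
    log_b, X = b
    return log_a + log_b, log_b.exp().unsqueeze(-1) * S + X

d = 4
mk = lambda: (-torch.rand(d), torch.randn(d, d))   # (log-decay vector, state matrix)
x, y, z = mk(), mk(), mk()

left  = monoid_op(monoid_op(x, y), z)   # (x ⊕ y) ⊕ z
right = monoid_op(x, monoid_op(y, z))   # x ⊕ (y ⊕ z)
assert all(torch.allclose(l, r, atol=1e-6) for l, r in zip(left, right))
```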
| 79 |
+
|
| 80 |
+
## Vector Decay — Per-Dimension Memory Lifetimes
|
| 81 |
+
|
| 82 |
+
Unlike scalar decay (one α per head), Spartacus uses **vector decay**: each dimension of the d-vector has its own independent decay rate α_t[i] ∈ (0, 1]:
|
| 83 |
+
|
| 84 |
+
```
|
| 85 |
+
S_t[i,j] = α_t[i] · S_{t-1}[i,j] + k_t[i] · v_t[j]
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
This allows different feature dimensions to specialize:
|
| 89 |
+
- **Fast-decaying dimensions** (α ≈ 0) — local syntax, punctuation, function words
|
| 90 |
+
- **Slow-decaying dimensions** (α ≈ 1) — entity memory, topic tracking, long-range facts
|
| 91 |
+
|
| 92 |
+
The decay gate uses **Negative Softplus** activation:
|
| 93 |
+
|
| 94 |
+
```
|
| 95 |
+
log α_t = -softplus(W·x_t + b)
|
| 96 |
```
|
| 97 |
|
| 98 |
+
| Property | Value |
|
| 99 |
+
|---|---|
|
| 100 |
+
| Range | α ∈ (0, 1] — bounded, no explosion |
|
| 101 |
+
| Perfect memory | W·x → -∞ ⟹ softplus → 0 ⟹ α → 1 (lossless retention) |
|
| 102 |
+
| Full forgetting | W·x → +∞ ⟹ softplus → ∞ ⟹ α → 0 (complete reset) |
|
| 103 |
+
| Stability | α ≤ 1 by construction — no divergence regardless of input magnitude |
|
| 104 |
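As a rough illustration of the gate's range (a sketch, not the repository's exact module wiring; the sizes follow the Model Details table):

```
import torch
import torch.nn.functional as F

B, T, H, d = 2, 5, 32, 64                      # batch, tokens, heads, head_dim
decay_proj = torch.nn.Linear(2048, H * d)      # hidden_size → num_heads * head_dim

x = torch.randn(B, T, 2048)
log_alpha = -F.softplus(decay_proj(x))                    # log α ∈ (-inf, 0)
log_alpha = log_alpha.view(B, T, H, d).transpose(1, 2)    # [B, H, T, d]

alpha = log_alpha.exp()                        # α ∈ (0, 1]: bounded, never explodes
assert alpha.gt(0).all() and alpha.le(1).all()
```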
+
|
| 105 |
+
## Attention Mask — Padding-Aware Recurrence
|
| 106 |
|
| 107 |
+
The monoid recurrence correctly handles `attention_mask` for padded batches (e.g., left-padding during `generate()`). For PAD positions (mask=0):
|
| 108 |
|
| 109 |
+
```
|
| 110 |
+
log_α = 0 → α = 1 (preserve state unchanged)
|
| 111 |
+
k = 0, v = 0 → kv = 0 (no information injected)
|
| 112 |
+
```
|
| 113 |
|
| 114 |
+
Net effect: `S_t = 1·S_{t-1} + 0 = S_{t-1}` — PAD acts as the **monoid identity element**, completely invisible to the recurrence. This ensures identical outputs whether inputs are padded or not.
|
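A small self-contained check of this no-op property (hypothetical shapes; mirrors the masking logic in `MonoidAttention.forward`):

```
import torch

B, H, T, d = 1, 2, 6, 4
mask = torch.tensor([[1, 1, 1, 0, 0, 0]])            # last three positions are PAD

log_alpha = -torch.rand(B, H, T, d)
k = torch.randn(B, H, T, d)
v = torch.randn(B, H, T, d)

m = mask[:, None, :, None].float()                   # [B,1,T,1] for broadcasting
log_alpha, k, v = log_alpha * m, k * m, v * m        # PAD → α=1, k=v=0

S = torch.zeros(B, H, d, d)
S_after_last_real = None
for t in range(T):
    kv_t = torch.einsum('bhd,bhe->bhde', k[:, :, t], v[:, :, t])
    S = S * log_alpha[:, :, t].exp().unsqueeze(-1) + kv_t
    if t == 2:
        S_after_last_real = S.clone()

assert torch.allclose(S, S_after_last_real)          # PAD steps left the state untouched
```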
|
| 115 |
|
| 116 |
## Design Choices
|
| 117 |
|
| 118 |
+
- **SiLU-activated keys**: `k = SiLU(k_proj(x))` ensures non-negative keys, making the state matrix S positive semi-definite (PSD). This prevents "feature erasure" where one token's contribution cancels another's
|
| 119 |
+
- **QK-Norm**: RMSNorm on both q and k before readout, stabilizing the scale of q·S when the state matrix accumulates many outer products
|
| 120 |
+
- **Log-space decay**: Working in log-space `log(α)` avoids numerical underflow when α^T → 0 for long sequences
|
| 121 |
+
- **Learnable h0**: The initial state S₀ = h0 is a learnable parameter (zero-initialized), acting as a compressed "system prompt"
|
| 122 |
+
- **Negative Softplus gate**: Ensures α ∈ (0, 1] by construction — allows perfect memory (α=1) while preventing state explosion (α>1)
|
| 123 |
+
|
| 124 |
+
## Three Forward Paths
|
| 125 |
+
|
| 126 |
+
| Path | Condition | Complexity | Description |
|
| 127 |
+
|---|---|---|---|
|
| 128 |
+
| Training | `use_cache=False` | O(T) parallel scan | Vectorized outer products → parallel prefix scan → vectorized readout |
|
| 129 |
+
| Inference prefill | `use_cache=True, T>1` | O(T) parallel scan | Same as training + extracts final state S_T for cache |
|
| 130 |
+
| Inference decode | `use_cache=True, T=1` | **O(1)** monoid_op | Single `monoid_op` to fold new token into state → one matmul readout |
|
| 131 |
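A sketch of why the decode path agrees with prefill: folding tokens one at a time with the monoid operator reproduces the closed-form prefix state (illustrative shapes; the real model uses `monoid_op` and `parallel_scan_with_state` for this):

```
import torch

def monoid_op(acc, new):
    log_acc, S = acc
    log_t, kv_t = new
    return log_acc + log_t, log_t.exp().unsqueeze(-1) * S + kv_t

B, H, T, d = 1, 2, 8, 4
log_alpha = -torch.rand(B, H, T, d)
kv = torch.randn(B, H, T, d, d)

# O(1)-per-token decode: fold each token into the running state.
state = (torch.zeros(B, H, d), torch.zeros(B, H, d, d))
for t in range(T):
    state = monoid_op(state, (log_alpha[:, :, t], kv[:, :, t]))

# Closed form of the same prefix: S_T = Σ_t diag(∏_{i>t} α_i) · kv_t
ref = torch.zeros(B, H, d, d)
for t in range(T):
    remaining = log_alpha[:, :, t + 1:].sum(dim=2).exp().unsqueeze(-1)
    ref = ref + remaining * kv[:, :, t]

assert torch.allclose(state[1], ref, atol=1e-5)
```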
|
| 132 |
## Model Details
|
| 133 |
|
|
|
|
| 141 |
| Layers | 16 |
|
| 142 |
| Attention heads | 32 |
|
| 143 |
| Head dimension | 64 |
|
| 144 |
+
| Decay gate | Vector decay, d=64 per head |
|
| 145 |
+
| State matrix per head | 64 × 64 = 4,096 floats |
|
| 146 |
| Vocabulary | 128,256 (Llama-3.2 tokenizer) |
|
| 147 |
| Precision | bfloat16 |
|
| 148 |
|
|
|
|
| 158 |
|
| 159 |
### Comparison with ~1B Baselines (acc_norm, 0-shot)
|
| 160 |
|
| 161 |
+
| Task | Spartacus-1B | TinyLlama-1.1B | Llama 3.2-1B | Mamba-1.4B | RWKV-6-1.6B |
|
| 162 |
|---|---|---|---|---|---|
|
| 163 |
| ARC-C | **0.3063** | 0.3268 | ~0.359 | 0.284 | ~0.301 |
|
| 164 |
| ARC-E | **0.5518** | 0.5547 | ~0.752 | 0.512 | ~0.530 |
|
|
|
|
| 168 |
|
| 169 |
> Spartacus achieves performance competitive with other sub-quadratic models (Mamba, RWKV) while maintaining **O(1) inference time and memory per token**. Scores marked with ~ are approximate community-reported values.
|
| 170 |
|
|
|
|
|
| 171 |
## Parallel Scan Implementation
|
| 172 |
|
| 173 |
+
The `monoid_scan_cuda.py` module provides a Triton JIT-compiled parallel prefix scan for the vector-decay monoid:
|
| 174 |
|
| 175 |
+
- **Grid**: `(B*H*D_k, ceil(D_v/BLOCK_DV))` — one program per state matrix row
|
| 176 |
+
- **Forward**: Sequential scan along T per row, parallelized across all (batch, head, d_k) dimensions
|
| 177 |
+
- **Backward**: Reverse-order adjoint scan with per-row D_v reduction (minimal atomic_add)
|
| 178 |
- **Fallback**: Pure PyTorch sequential scan for CPU/MPS
|
| 179 |
+
- **Auto-dispatch**: CUDA → Triton kernel, otherwise → PyTorch fallback
|
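A hypothetical quick check of the dispatch and of the fallback's agreement with the kernel (assumes a CUDA device, Triton installed, and the repository files on `PYTHONPATH`; function names are taken from `monoid_scan_cuda.py`):

```
import torch
from monoid_scan_cuda import parallel_scan, _sequential_scan

B, H, T, D_k, D_v = 2, 4, 16, 64, 64
log_decays = -torch.rand(B, H, T, D_k, device='cuda')
values = torch.randn(B, H, T, D_k, D_v, device='cuda')

out_kernel = parallel_scan(log_decays, values)        # Triton path on CUDA
out_ref = _sequential_scan(log_decays, values)        # pure PyTorch reference
print(torch.allclose(out_kernel, out_ref, atol=1e-4))
```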
| 180 |
|
| 181 |
## Usage
|
| 182 |
|
|
|
|
| 203 |
|
| 204 |
```
|
| 205 |
MonoidForCausalLM.py # Model architecture (MonoidConfig, MonoidAttention, MonoidForCausalLM)
|
| 206 |
+
monoid_scan_cuda.py # Triton JIT parallel prefix scan (vector decay) + PyTorch fallback
|
| 207 |
model.safetensors # Model weights (bfloat16)
|
| 208 |
config.json # Model configuration
|
| 209 |
tokenizer.json # Llama-3.2 tokenizer
|
|
|
|
| 217 |
author={NoesisLab},
|
| 218 |
year={2025},
|
| 219 |
url={https://huggingface.co/NoesisLab/Spartacus-1B-Instruct},
|
| 220 |
+
description={Replaces softmax attention with vector-decay monoid state compression for constant-time, constant-memory autoregressive generation}
|
| 221 |
}
|
| 222 |
```
|
| 223 |
|
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d5cd463898c4ce262d12fe56c6227d0c1117680aa13892f9cac6e100a1db9077
|
| 3 |
+
size 2811462896
|
monoid_scan_cuda.py
CHANGED
|
@@ -2,36 +2,33 @@
|
|
| 2 |
monoid_scan_cuda.py — Triton CUDA JIT Accelerated Parallel Prefix Scan
|
| 3 |
monoid_scan_cuda.py — Triton CUDA JIT 加速的并行前缀扫描
|
| 4 |
|
| 5 |
-
This module implements the parallel prefix scan for the monoid recurrence:
|
| 6 |
-
y_t = exp(log_decay_t) · y_{t-1} + x_t
|
| 7 |
-
|
| 8 |
-
y_t = exp(log_decay_t) · y_{t-1} + x_t
|
| 9 |
|
| 10 |
This is the computational backbone of Monoid Attention's state compression.
|
| 11 |
这是幺半群注意力状态压缩的计算骨干。
|
| 12 |
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
但因为 (log_α, S) ⊕ (log_β, X) = (log_α+log_β, exp(log_β)·S+X)
|
| 20 |
-
满足结合律, 我们可以通过并行归约树在 O(log T) 深度内计算所有前缀和 S_1..S_T,
|
| 21 |
-
而非 O(T) 的串行步骤。
|
| 22 |
-
|
| 23 |
-
Training uses O(T) parallel scan (this file).
|
| 24 |
-
Inference uses O(1) sequential monoid_op (in MonoidForCausalLM.py).
|
| 25 |
-
训练使用 O(T) 并行扫描 (本文件)。
|
| 26 |
-
推理使用 O(1) 串行 monoid_op (在 MonoidForCausalLM.py 中)。
|
| 27 |
|
| 28 |
Implementation:
|
| 29 |
-
Forward: sequential scan along T, parallelized across B*H*
|
|
|
|
|
|
|
| 30 |
Backward: reverse-order adjoint scan for gradient computation.
|
|
|
|
| 31 |
Auto-dispatches: CUDA → Triton kernel, CPU/MPS → PyTorch fallback.
|
| 32 |
|
| 33 |
-
前向: 沿 T 维顺序扫描, 跨 B*H*
|
|
|
|
| 34 |
反向: 逆序伴随变量扫描计算梯度。
|
|
|
|
| 35 |
自动分派: CUDA → Triton 核函数, CPU/MPS → PyTorch 回退。
|
| 36 |
"""
|
| 37 |
|
|
@@ -60,18 +57,18 @@ def _sequential_scan(log_decays: Tensor, values: Tensor) -> Tensor:
|
|
| 60 |
Pure PyTorch sequential scan fallback (when no CUDA / Triton available).
|
| 61 |
纯 PyTorch 串行扫描回退 (无 CUDA / Triton 时使用)。
|
| 62 |
|
| 63 |
-
Implements the monoid recurrence step by step:
|
| 64 |
acc_0 = 0
|
| 65 |
-
acc_t = exp(log_decay_t) · acc_{t-1} + values_t
|
| 66 |
This is O(T) sequential — correct but slow on GPU.
|
| 67 |
-
|
| 68 |
acc_0 = 0
|
| 69 |
-
acc_t = exp(log_decay_t) · acc_{t-1} + values_t
|
| 70 |
这是 O(T) 串行的 — 结果正确但在 GPU 上较慢。
|
| 71 |
|
| 72 |
Args:
|
| 73 |
-
log_decays: [B, H, T,
|
| 74 |
-
|
| 75 |
values: [B, H, T, D_k, D_v] — outer products k_t⊗v_t to accumulate
|
| 76 |
待累积的外积 k_t⊗v_t
|
| 77 |
Returns:
|
|
@@ -84,17 +81,17 @@ def _sequential_scan(log_decays: Tensor, values: Tensor) -> Tensor:
|
|
| 84 |
# acc 代表 S_t — 时刻 t 的压缩因果状态
|
| 85 |
acc = torch.zeros(B, H, D_k, D_v, device=values.device, dtype=values.dtype)
|
| 86 |
for t in range(T):
|
| 87 |
-
# S_t = α_t · S_{t-1} + kv_t (
|
| 88 |
-
# S_t = α_t · S_{t-1} + kv_t (
|
| 89 |
-
decay_t = torch.exp(log_decays[:, :, t]).unsqueeze(-1) # [B,H,
|
| 90 |
acc = acc * decay_t + values[:, :, t]
|
| 91 |
out[:, :, t] = acc
|
| 92 |
return out
|
| 93 |
|
| 94 |
|
| 95 |
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 96 |
-
# Triton Kernels — GPU-accelerated scan
|
| 97 |
-
# Triton 核函数 — GPU 加速扫描
|
| 98 |
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 99 |
|
| 100 |
if HAS_TRITON:
|
|
@@ -102,147 +99,136 @@ if HAS_TRITON:
|
|
| 102 |
@triton.jit
|
| 103 |
def _scan_fwd_kernel(
|
| 104 |
LD_ptr, V_ptr, O_ptr,
|
| 105 |
-
T,
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
):
|
| 111 |
"""
|
| 112 |
-
Forward scan kernel — computes all prefix states S_1..S_T.
|
| 113 |
-
前向扫描核函数 — 计算所有前缀状态 S_1..S_T。
|
| 114 |
|
| 115 |
Parallelization strategy / 并行化策略:
|
| 116 |
-
- program_id(0) =
|
| 117 |
-
每个 (batch, head)
|
| 118 |
-
- program_id(1) =
|
| 119 |
-
每个
|
| 120 |
- Sequential loop over T (the causal recurrence is inherently sequential)
|
| 121 |
沿 T 维串行循环 (因果递推本质上是串行的)
|
| 122 |
|
| 123 |
-
Each program
|
| 124 |
-
|
| 125 |
-
|
|
|
|
|
|
|
| 126 |
|
| 127 |
-
|
| 128 |
-
B*H*ceil(
|
| 129 |
-
注意: 虽然 T 循环在每个 program 内是串行的,
|
| 130 |
-
但 B*H*ceil(D/BLOCK_D) 个 program 在 GPU 上并行运行。
|
| 131 |
"""
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
|
| 137 |
-
# acc = S_0 = 0 (identity element of the monoid)
|
| 138 |
-
# acc = S_0 = 0 (幺半群的单位元)
|
| 139 |
-
acc = tl.zeros([
|
| 140 |
|
| 141 |
-
ld_base = LD_ptr +
|
| 142 |
-
v_base = V_ptr +
|
| 143 |
-
o_base = O_ptr +
|
| 144 |
|
| 145 |
for t in range(T):
|
| 146 |
-
# Load log_decay
|
| 147 |
-
#
|
| 148 |
ld_val = tl.load(ld_base + t * s_ld_t).to(tl.float32)
|
| 149 |
decay = tl.exp(ld_val)
|
| 150 |
|
| 151 |
-
# Load kv_t (
|
| 152 |
-
# 加载 kv_t (
|
| 153 |
val = tl.load(
|
| 154 |
-
v_base + t * s_v_t +
|
| 155 |
-
mask=
|
| 156 |
).to(tl.float32)
|
| 157 |
|
| 158 |
-
# Core recurrence: S_t = α_t · S_{t-1} + kv_t
|
| 159 |
-
# 核心递推: S_t = α_t · S_{t-1} + kv_t
|
| 160 |
acc = acc * decay + val
|
| 161 |
|
| 162 |
-
# Store S_t
|
| 163 |
tl.store(
|
| 164 |
-
o_base + t * s_o_t +
|
| 165 |
-
acc, mask=
|
| 166 |
)
|
| 167 |
|
| 168 |
@triton.jit
|
| 169 |
def _scan_bwd_kernel(
|
| 170 |
LD_ptr, O_ptr, GO_ptr, GV_ptr, GLD_ptr,
|
| 171 |
-
T,
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
):
|
| 179 |
"""
|
| 180 |
-
Backward scan kernel — computes gradients via adjoint method.
|
| 181 |
-
反向扫描核函数 —
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
The
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
Gradients / 梯度:
|
| 192 |
-
∂L/∂x_t = λ_t (gradient w.r.t. input values)
|
| 193 |
-
(对输入值的梯度)
|
| 194 |
-
∂L/∂log_a_t = a_t · Σ_D(λ_t · y_{t-1}) (gradient w.r.t. log-decay)
|
| 195 |
-
(对对数衰减的梯度)
|
| 196 |
-
|
| 197 |
-
The gradient of log_decay is critical for training the decay gate:
|
| 198 |
-
it tells the model how to adjust each head's forgetting rate.
|
| 199 |
-
log_decay 的梯度对训练衰减门至关重要:
|
| 200 |
-
它告诉模型如何调整每个头的遗忘速率。
|
| 201 |
"""
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
|
| 207 |
# adj holds a_{t+1} · λ_{t+1}, initialized to 0 at the sequence end
|
| 208 |
# adj 保存 a_{t+1} · λ_{t+1}, 在序列末尾初始化为 0
|
| 209 |
-
adj = tl.zeros([
|
| 210 |
|
| 211 |
for t_rev in range(T):
|
| 212 |
t = T - 1 - t_rev # reverse time / 逆序时间
|
| 213 |
|
| 214 |
-
# Load ∂L/∂y_t (upstream gradient)
|
| 215 |
-
# 加载 ∂L/∂y_t (上游梯度)
|
| 216 |
go = tl.load(
|
| 217 |
-
GO_ptr +
|
| 218 |
-
mask=
|
| 219 |
).to(tl.float32)
|
| 220 |
|
| 221 |
# Adjoint: λ_t = ∂L/∂y_t + a_{t+1} · λ_{t+1}
|
| 222 |
# 伴随: λ_t = ∂L/∂y_t + a_{t+1} · λ_{t+1}
|
| 223 |
lam = go + adj
|
| 224 |
|
| 225 |
-
# ∂L/∂x_t = λ_t (gradient of values
|
|
|
|
| 226 |
tl.store(
|
| 227 |
-
GV_ptr +
|
| 228 |
-
lam, mask=
|
| 229 |
)
|
| 230 |
|
| 231 |
-
# ∂L/∂
|
| 232 |
-
#
|
| 233 |
-
#
|
| 234 |
-
|
| 235 |
-
# 教模型如何控制因果信息的保留。
|
| 236 |
-
ld_val = tl.load(LD_ptr + bh * s_ld_bh + t * s_ld_t).to(tl.float32)
|
| 237 |
a_t = tl.exp(ld_val)
|
| 238 |
|
| 239 |
if t > 0:
|
| 240 |
y_prev = tl.load(
|
| 241 |
-
O_ptr +
|
| 242 |
-
mask=
|
| 243 |
).to(tl.float32)
|
| 244 |
-
|
| 245 |
-
tl.atomic_add(GLD_ptr +
|
| 246 |
|
| 247 |
# Prepare for next step (t-1): adj = a_t · λ_t
|
| 248 |
# 为下一步 (t-1) 准备: adj = a_t · λ_t
|
|
@@ -255,78 +241,91 @@ if HAS_TRITON:
|
|
| 255 |
|
| 256 |
class _ParallelScanFn(Function):
|
| 257 |
"""
|
| 258 |
-
Custom autograd function for the parallel prefix scan.
|
| 259 |
-
并行前缀扫描的自定义 autograd
|
| 260 |
|
| 261 |
Forward: launches _scan_fwd_kernel to compute all prefix states.
|
|
|
|
| 262 |
Backward: launches _scan_bwd_kernel to compute gradients via adjoint method.
|
|
|
|
| 263 |
|
| 264 |
前向: 启动 _scan_fwd_kernel 计算所有前缀状态。
|
|
|
|
| 265 |
反向: 启动 _scan_bwd_kernel 通过伴随方法计算梯度。
|
|
|
|
| 266 |
"""
|
| 267 |
@staticmethod
|
| 268 |
def forward(ctx, log_decays: Tensor, values: Tensor) -> Tensor:
|
| 269 |
B, H, T, D_k, D_v = values.shape
|
| 270 |
-
D = D_k * D_v # flattened state dimension / 展平的状态维度
|
| 271 |
|
| 272 |
-
#
|
| 273 |
-
#
|
| 274 |
-
|
| 275 |
-
|
|
|
| 276 |
o_flat = torch.empty_like(v_flat)
|
| 277 |
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
# Grid: (
|
| 281 |
-
# 网格: (
|
| 282 |
-
grid = (
|
| 283 |
|
| 284 |
_scan_fwd_kernel[grid](
|
| 285 |
ld_flat, v_flat, o_flat,
|
| 286 |
-
T,
|
| 287 |
ld_flat.stride(0), ld_flat.stride(1),
|
| 288 |
v_flat.stride(0), v_flat.stride(1), v_flat.stride(2),
|
| 289 |
o_flat.stride(0), o_flat.stride(1), o_flat.stride(2),
|
| 290 |
-
|
| 291 |
)
|
| 292 |
|
| 293 |
# Save for backward: need log_decays and forward outputs y_t
|
| 294 |
# 为反向传播保存: 需要 log_decays 和前向输出 y_t
|
| 295 |
ctx.save_for_backward(ld_flat, o_flat)
|
| 296 |
-
ctx.shape_info = (B, H, T, D_k, D_v,
|
| 297 |
-
|
|
|
|
| 298 |
|
| 299 |
@staticmethod
|
| 300 |
def backward(ctx, grad_output: Tensor):
|
| 301 |
ld_flat, o_flat = ctx.saved_tensors
|
| 302 |
-
B, H, T, D_k, D_v,
|
| 303 |
|
| 304 |
-
|
|
|
|
| 305 |
gv_flat = torch.empty_like(go_flat)
|
| 306 |
-
# Use f32 for
|
| 307 |
-
# 使用 f32
|
| 308 |
-
gld_flat = torch.zeros(
|
| 309 |
|
| 310 |
-
grid = (
|
| 311 |
|
| 312 |
_scan_bwd_kernel[grid](
|
| 313 |
ld_flat, o_flat, go_flat, gv_flat, gld_flat,
|
| 314 |
-
T,
|
| 315 |
ld_flat.stride(0), ld_flat.stride(1),
|
| 316 |
o_flat.stride(0), o_flat.stride(1), o_flat.stride(2),
|
| 317 |
go_flat.stride(0), go_flat.stride(1), go_flat.stride(2),
|
| 318 |
gv_flat.stride(0), gv_flat.stride(1), gv_flat.stride(2),
|
| 319 |
gld_flat.stride(0), gld_flat.stride(1),
|
| 320 |
-
|
| 321 |
)
|
| 322 |
|
| 323 |
-
|
| 324 |
-
|
|
|
|
| 325 |
return grad_log_decays, grad_values
|
| 326 |
|
| 327 |
def _triton_parallel_scan(log_decays: Tensor, values: Tensor) -> Tensor:
|
| 328 |
-
"""Triton-accelerated parallel scan entry point.
|
| 329 |
-
Triton
|
| 330 |
return _ParallelScanFn.apply(log_decays, values)
|
| 331 |
|
| 332 |
else:
|
|
@@ -339,15 +338,16 @@ else:
|
|
| 339 |
|
| 340 |
def parallel_scan(log_decays: Tensor, values: Tensor) -> Tensor:
|
| 341 |
"""
|
| 342 |
-
Parallel prefix scan — computes all prefix monoid sums.
|
| 343 |
-
并行前缀扫描 —
|
| 344 |
|
| 345 |
This is the training-time workhorse of Monoid Attention.
|
| 346 |
-
It computes S_1, S_2, ..., S_T where
|
|
|
|
| 347 |
for ALL timesteps simultaneously.
|
| 348 |
这是幺半群注意力训练时的主力计算。
|
| 349 |
它同时计算所有时间步的 S_1, S_2, ..., S_T,
|
| 350 |
-
其中 S_t = α_t·S_{t-1} + kv_t。
|
| 351 |
|
| 352 |
Auto-dispatches based on device:
|
| 353 |
CUDA → Triton JIT kernel (fast, with custom backward)
|
|
@@ -357,8 +357,8 @@ def parallel_scan(log_decays: Tensor, values: Tensor) -> Tensor:
|
|
| 357 |
CPU/MPS → PyTorch 串行扫描 (正确, 较慢)
|
| 358 |
|
| 359 |
Args:
|
| 360 |
-
log_decays: [B, H, T,
|
| 361 |
-
|
| 362 |
values: [B, H, T, D_k, D_v] — outer products k_t⊗v_t
|
| 363 |
外积 k_t⊗v_t
|
| 364 |
Returns:
|
|
@@ -374,8 +374,8 @@ def parallel_scan_with_state(
|
|
| 374 |
log_decays: Tensor, values: Tensor,
|
| 375 |
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
|
| 376 |
"""
|
| 377 |
-
Parallel prefix scan + extract final state for inference handoff.
|
| 378 |
-
并行前缀扫描 +
|
| 379 |
|
| 380 |
Used during prefill: compute all training-time prefix states,
|
| 381 |
AND extract the final accumulated state S_T so that subsequent
|
|
@@ -389,22 +389,22 @@ def parallel_scan_with_state(
|
|
| 389 |
这是训练模式 (并行扫描) 和推理模式 (串行 monoid_op) 之间的桥梁。
|
| 390 |
|
| 391 |
Args:
|
| 392 |
-
log_decays: [B, H, T,
|
| 393 |
values: [B, H, T, D_k, D_v]
|
| 394 |
|
| 395 |
Returns:
|
| 396 |
output: [B, H, T, D_k, D_v] — all prefix states S_1..S_T
|
| 397 |
所有前缀状态
|
| 398 |
final_state: (log_acc, S_T) where
|
| 399 |
-
log_acc: [B, H,
|
| 400 |
-
|
| 401 |
-
final_state: [B, H, D_k, D_v]
|
| 402 |
-
|
| 403 |
"""
|
| 404 |
output = parallel_scan(log_decays, values)
|
| 405 |
-
# Sum all log-decays to get the total accumulated decay
|
| 406 |
-
# 对所有 log-decay
|
| 407 |
-
log_acc = log_decays.
|
| 408 |
# The last timestep's state IS the full causal summary
|
| 409 |
# 最后一个时间步的状态就是完整的因果摘要
|
| 410 |
final_state = output[:, :, -1] # [B, H, D_k, D_v]
|
|
|
|
| 2 |
monoid_scan_cuda.py — Triton CUDA JIT Accelerated Parallel Prefix Scan
|
| 3 |
monoid_scan_cuda.py — Triton CUDA JIT 加速的并行前缀扫描
|
| 4 |
|
| 5 |
+
This module implements the parallel prefix scan for the vector-decay monoid recurrence:
|
| 6 |
+
y_t[i,:] = exp(log_decay_t[i]) · y_{t-1}[i,:] + x_t[i,:]
|
| 7 |
+
本模块实现向量衰减幺半群递推的并行前缀扫描:
|
| 8 |
+
y_t[i,:] = exp(log_decay_t[i]) · y_{t-1}[i,:] + x_t[i,:]
|
| 9 |
|
| 10 |
This is the computational backbone of Monoid Attention's state compression.
|
| 11 |
这是幺半群注意力状态压缩的计算骨干。
|
| 12 |
|
| 13 |
+
Vector decay: each dimension of the D_k×D_v state matrix has its own
|
| 14 |
+
per-dimension decay rate α_t ∈ ℝ^{D_k}, enabling different feature
|
| 15 |
+
dimensions to have independent memory lifetimes (fast-decaying for
|
| 16 |
+
local syntax, slow-decaying for global entity memory).
|
| 17 |
+
向量衰减: D_k×D_v 状态矩阵的每个维度拥有独立的衰减率 α_t ∈ ℝ^{D_k},
|
| 18 |
+
使不同特征维度拥有独立的记忆生命周期 (快速衰减用于局部语法, 慢速衰减用于全局实体记忆)。
|
|
|
|
|
| 19 |
|
| 20 |
Implementation:
|
| 21 |
+
Forward: sequential scan along T, parallelized across B*H*D_k on GPU.
|
| 22 |
+
Each program handles one row of the state matrix (D_v elements)
|
| 23 |
+
with a scalar decay per row.
|
| 24 |
Backward: reverse-order adjoint scan for gradient computation.
|
| 25 |
+
Per-row D_v reduction for the log_decay gradient (a single atomic_add per row per step).
|
| 26 |
Auto-dispatches: CUDA → Triton kernel, CPU/MPS → PyTorch fallback.
|
| 27 |
|
| 28 |
+
前向: 沿 T 维顺序扫描, 跨 B*H*D_k 在 GPU 上并行。
|
| 29 |
+
每个 program 处理状态矩阵的一行 (D_v 个元素), 每行一个标量衰减。
|
| 30 |
反向: 逆序伴随变量扫描计算梯度。
|
| 31 |
+
逐行对 D_v 归约计算 log_decay 梯度 (每行每步仅一次 atomic_add)。
|
| 32 |
自动分派: CUDA → Triton 核函数, CPU/MPS → PyTorch 回退。
|
| 33 |
"""
|
| 34 |
|
|
|
|
| 57 |
Pure PyTorch sequential scan fallback (when no CUDA / Triton available).
|
| 58 |
纯 PyTorch 串行扫描回退 (无 CUDA / Triton 时使用)。
|
| 59 |
|
| 60 |
+
Implements the vector-decay monoid recurrence step by step:
|
| 61 |
acc_0 = 0
|
| 62 |
+
acc_t[i,:] = exp(log_decay_t[i]) · acc_{t-1}[i,:] + values_t[i,:]
|
| 63 |
This is O(T) sequential — correct but slow on GPU.
|
| 64 |
+
逐步实现向量衰减幺半群递推:
|
| 65 |
acc_0 = 0
|
| 66 |
+
acc_t[i,:] = exp(log_decay_t[i]) · acc_{t-1}[i,:] + values_t[i,:]
|
| 67 |
这是 O(T) 串行的 — 结果正确但在 GPU 上较慢。
|
| 68 |
|
| 69 |
Args:
|
| 70 |
+
log_decays: [B, H, T, D_k] — log of per-dimension per-step decay gates
|
| 71 |
+
每维度每步衰减门的对数
|
| 72 |
values: [B, H, T, D_k, D_v] — outer products k_t⊗v_t to accumulate
|
| 73 |
待累积的外积 k_t⊗v_t
|
| 74 |
Returns:
|
|
|
|
| 81 |
# acc 代表 S_t — 时刻 t 的压缩因果状态
|
| 82 |
acc = torch.zeros(B, H, D_k, D_v, device=values.device, dtype=values.dtype)
|
| 83 |
for t in range(T):
|
| 84 |
+
# S_t = diag(α_t) · S_{t-1} + kv_t (vector decay monoid recurrence)
|
| 85 |
+
# S_t = diag(α_t) · S_{t-1} + kv_t (向量衰减幺半群递推)
|
| 86 |
+
decay_t = torch.exp(log_decays[:, :, t]).unsqueeze(-1) # [B,H,D_k,1]
|
| 87 |
acc = acc * decay_t + values[:, :, t]
|
| 88 |
out[:, :, t] = acc
|
| 89 |
return out
|
| 90 |
|
| 91 |
|
| 92 |
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 93 |
+
# Triton Kernels — GPU-accelerated scan (vector decay)
|
| 94 |
+
# Triton 核函数 — GPU 加速扫描 (向量衰减)
|
| 95 |
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
| 96 |
|
| 97 |
if HAS_TRITON:
|
|
|
|
| 99 |
@triton.jit
|
| 100 |
def _scan_fwd_kernel(
|
| 101 |
LD_ptr, V_ptr, O_ptr,
|
| 102 |
+
T, D_v,
|
| 103 |
+
s_ld_bhdk, s_ld_t,
|
| 104 |
+
s_v_bhdk, s_v_t, s_v_dv,
|
| 105 |
+
s_o_bhdk, s_o_t, s_o_dv,
|
| 106 |
+
BLOCK_DV: tl.constexpr,
|
| 107 |
):
|
| 108 |
"""
|
| 109 |
+
Forward scan kernel — computes all prefix states S_1..S_T (vector decay).
|
| 110 |
+
前向扫描核函数 — 计算所有前缀状态 S_1..S_T (向量衰减)。
|
| 111 |
|
| 112 |
Parallelization strategy / 并行化策略:
|
| 113 |
+
- program_id(0) = bhdk: one program per (batch, head, d_k row) triple
|
| 114 |
+
每个 (batch, head, d_k 行) 三元组一个 program
|
| 115 |
+
- program_id(1) = dvb: one program per D_v-dimension block (typically 1 block)
|
| 116 |
+
每个 D_v 维 block 一个 program (通常只有 1 个 block)
|
| 117 |
- Sequential loop over T (the causal recurrence is inherently sequential)
|
| 118 |
沿 T 维串行循环 (因果递推本质上是串行的)
|
| 119 |
|
| 120 |
+
Each program handles one row of the D_k×D_v state matrix, where the
|
| 121 |
+
decay is a single scalar per row. This eliminates the need for
|
| 122 |
+
row-index computation in the inner loop.
|
| 123 |
+
每个 program 处理 D_k×D_v 状态矩阵的一行, 该行的衰减是一个标量。
|
| 124 |
+
这消除了内循环中行索引计算的需要。
|
| 125 |
|
| 126 |
+
Grid: (B*H*D_k, ceil(D_v/BLOCK_DV))
|
| 127 |
+
网格: (B*H*D_k, ceil(D_v/BLOCK_DV))
|
|
|
|
|
|
|
| 128 |
"""
|
| 129 |
+
bhdk = tl.program_id(0)
|
| 130 |
+
dvb = tl.program_id(1)
|
| 131 |
+
dv_offs = dvb * BLOCK_DV + tl.arange(0, BLOCK_DV)
|
| 132 |
+
dv_mask = dv_offs < D_v
|
| 133 |
|
| 134 |
+
# acc = S_0[row,:] = 0 (identity element of the monoid)
|
| 135 |
+
# acc = S_0[行,:] = 0 (幺半群的单位元)
|
| 136 |
+
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
|
| 137 |
|
| 138 |
+
ld_base = LD_ptr + bhdk * s_ld_bhdk
|
| 139 |
+
v_base = V_ptr + bhdk * s_v_bhdk
|
| 140 |
+
o_base = O_ptr + bhdk * s_o_bhdk
|
| 141 |
|
| 142 |
for t in range(T):
|
| 143 |
+
# Load scalar log_decay for this row at time t
|
| 144 |
+
# 加载此行在时刻 t 的标量 log_decay
|
| 145 |
ld_val = tl.load(ld_base + t * s_ld_t).to(tl.float32)
|
| 146 |
decay = tl.exp(ld_val)
|
| 147 |
|
| 148 |
+
# Load kv_t[row, :] (one row of the outer product)
|
| 149 |
+
# 加载 kv_t[行, :] (外积的一行)
|
| 150 |
val = tl.load(
|
| 151 |
+
v_base + t * s_v_t + dv_offs * s_v_dv,
|
| 152 |
+
mask=dv_mask, other=0.0,
|
| 153 |
).to(tl.float32)
|
| 154 |
|
| 155 |
+
# Core recurrence: S_t[i,:] = α_t[i] · S_{t-1}[i,:] + kv_t[i,:]
|
| 156 |
+
# 核心递推: S_t[i,:] = α_t[i] · S_{t-1}[i,:] + kv_t[i,:]
|
| 157 |
acc = acc * decay + val
|
| 158 |
|
| 159 |
+
# Store S_t[row, :]
|
| 160 |
tl.store(
|
| 161 |
+
o_base + t * s_o_t + dv_offs * s_o_dv,
|
| 162 |
+
acc, mask=dv_mask,
|
| 163 |
)
|
| 164 |
|
| 165 |
@triton.jit
|
| 166 |
def _scan_bwd_kernel(
|
| 167 |
LD_ptr, O_ptr, GO_ptr, GV_ptr, GLD_ptr,
|
| 168 |
+
T, D_v,
|
| 169 |
+
s_ld_bhdk, s_ld_t,
|
| 170 |
+
s_o_bhdk, s_o_t, s_o_dv,
|
| 171 |
+
s_go_bhdk, s_go_t, s_go_dv,
|
| 172 |
+
s_gv_bhdk, s_gv_t, s_gv_dv,
|
| 173 |
+
s_gld_bhdk, s_gld_t,
|
| 174 |
+
BLOCK_DV: tl.constexpr,
|
| 175 |
):
|
| 176 |
"""
|
| 177 |
+
Backward scan kernel — computes gradients via adjoint method (vector decay).
|
| 178 |
+
反向扫描核函数 — 通过伴随方法计算梯度 (向量衰减)。
|
| 179 |
+
|
| 180 |
+
Each program handles one row of the state matrix (one d_k dimension).
|
| 181 |
+
The decay for this row is a scalar, so the log_decay gradient is:
|
| 182 |
+
∂L/∂log_α_t[i] = α_t[i] · Σ_j(λ_t[i,j] · y_{t-1}[i,j])
|
| 183 |
+
The sum over j (D_v) is computed within this single program, so only one atomic_add per row per timestep is needed.
|
| 184 |
+
每个 program 处理状态矩阵的一行 (一个 d_k 维度)。
|
| 185 |
+
该行的衰减是标量, 因此 log_decay 梯度为:
|
| 186 |
+
∂L/∂log_α_t[i] = α_t[i] · Σ_j(λ_t[i,j] · y_{t-1}[i,j])
|
| 187 |
+
对 j (D_v) 的求和在单个 program 内完成, 每行每个时间步仅需一次 atomic_add。
|
|
|
|
|
| 188 |
"""
|
| 189 |
+
bhdk = tl.program_id(0)
|
| 190 |
+
dvb = tl.program_id(1)
|
| 191 |
+
dv_offs = dvb * BLOCK_DV + tl.arange(0, BLOCK_DV)
|
| 192 |
+
dv_mask = dv_offs < D_v
|
| 193 |
|
| 194 |
# adj holds a_{t+1} · λ_{t+1}, initialized to 0 at the sequence end
|
| 195 |
# adj 保存 a_{t+1} · λ_{t+1}, 在序列末尾初始化为 0
|
| 196 |
+
adj = tl.zeros([BLOCK_DV], dtype=tl.float32)
|
| 197 |
|
| 198 |
for t_rev in range(T):
|
| 199 |
t = T - 1 - t_rev # reverse time / 逆序时间
|
| 200 |
|
| 201 |
+
# Load ∂L/∂y_t[row, :] (upstream gradient)
|
| 202 |
+
# 加载 ∂L/∂y_t[行, :] (上游梯度)
|
| 203 |
go = tl.load(
|
| 204 |
+
GO_ptr + bhdk * s_go_bhdk + t * s_go_t + dv_offs * s_go_dv,
|
| 205 |
+
mask=dv_mask, other=0.0,
|
| 206 |
).to(tl.float32)
|
| 207 |
|
| 208 |
# Adjoint: λ_t = ∂L/∂y_t + a_{t+1} · λ_{t+1}
|
| 209 |
# 伴随: λ_t = ∂L/∂y_t + a_{t+1} · λ_{t+1}
|
| 210 |
lam = go + adj
|
| 211 |
|
| 212 |
+
# ∂L/∂x_t[row,:] = λ_t (gradient of values)
|
| 213 |
+
# ∂L/∂x_t[行,:] = λ_t (值的梯度)
|
| 214 |
tl.store(
|
| 215 |
+
GV_ptr + bhdk * s_gv_bhdk + t * s_gv_t + dv_offs * s_gv_dv,
|
| 216 |
+
lam, mask=dv_mask,
|
| 217 |
)
|
| 218 |
|
| 219 |
+
# ∂L/∂log_α_t[i] = α_t[i] · Σ_j(λ_t[i,j] · y_{t-1}[i,j])
|
| 220 |
+
# Per-row scalar gradient: sum over D_v within this program.
|
| 221 |
+
# 逐行标量梯度: 在此 program 内对 D_v 求和。
|
| 222 |
+
ld_val = tl.load(LD_ptr + bhdk * s_ld_bhdk + t * s_ld_t).to(tl.float32)
|
|
|
|
|
|
|
| 223 |
a_t = tl.exp(ld_val)
|
| 224 |
|
| 225 |
if t > 0:
|
| 226 |
y_prev = tl.load(
|
| 227 |
+
O_ptr + bhdk * s_o_bhdk + (t - 1) * s_o_t + dv_offs * s_o_dv,
|
| 228 |
+
mask=dv_mask, other=0.0,
|
| 229 |
).to(tl.float32)
|
| 230 |
+
grad_ld = tl.sum(lam * y_prev) * a_t
|
| 231 |
+
tl.atomic_add(GLD_ptr + bhdk * s_gld_bhdk + t * s_gld_t, grad_ld)
|
| 232 |
|
| 233 |
# Prepare for next step (t-1): adj = a_t · λ_t
|
| 234 |
# 为下一步 (t-1) 准备: adj = a_t · λ_t
|
|
|
|
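    # --- Illustrative reference (not in the original upload) ---
    # A plain-PyTorch sketch of the adjoint recurrence the backward kernel implements,
    # for a single (batch, head, d_k) row. The name _reference_scan_bwd_row is
    # hypothetical; it is only here to make the gradient formulas above concrete:
    #     λ_t = ∂L/∂y_t + α_{t+1}·λ_{t+1},   ∂L/∂kv_t = λ_t,
    #     ∂L/∂log_α_t = α_t · Σ_j λ_t[j]·y_{t-1}[j]   (zero at t = 0, since y_{-1} = 0).
    def _reference_scan_bwd_row(log_decay_row: Tensor, states_row: Tensor, grad_out_row: Tensor):
        """log_decay_row: [T], states_row: [T, D_v] (forward outputs y_t), grad_out_row: [T, D_v]."""
        T = log_decay_row.shape[0]
        grad_kv = torch.empty_like(grad_out_row)
        grad_log_decay = torch.zeros_like(log_decay_row)
        adj = torch.zeros_like(grad_out_row[0])   # holds a_{t+1}·λ_{t+1}, zero past the end
        for t in range(T - 1, -1, -1):
            lam = grad_out_row[t] + adj
            grad_kv[t] = lam
            a_t = torch.exp(log_decay_row[t])
            if t > 0:
                grad_log_decay[t] = a_t * (lam * states_row[t - 1]).sum()
            adj = a_t * lam
        return grad_log_decay, grad_kv
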
    class _ParallelScanFn(Function):
        """
        Custom autograd function for the parallel prefix scan (vector decay).
        并行前缀扫描的自定义 autograd 函数 (向量衰减)。

        Forward: launches _scan_fwd_kernel to compute all prefix states.
            Grid: (B*H*D_k, ceil(D_v/BLOCK_DV)) — one program per (state row, D_v block).
        Backward: launches _scan_bwd_kernel to compute gradients via the adjoint method.
            Per-row reduction eliminates most atomic_add overhead.

        前向: 启动 _scan_fwd_kernel 计算所有前缀状态。
            网格: (B*H*D_k, ceil(D_v/BLOCK_DV)), 每个 (状态行, D_v 块) 一个 program。
        反向: 启动 _scan_bwd_kernel 通过伴随方法计算梯度。
            逐行归约消除大部分 atomic_add 开销。
        """
        @staticmethod
        def forward(ctx, log_decays: Tensor, values: Tensor) -> Tensor:
            B, H, T, D_k, D_v = values.shape

            # Reshape for the row-parallel kernel:
            # log_decays: [B, H, T, D_k]      → permute to [B, H, D_k, T]      → [B*H*D_k, T]
            # values:     [B, H, T, D_k, D_v] → permute to [B, H, D_k, T, D_v] → [B*H*D_k, T, D_v]
            # 为行并行核函数重塑:
            # log_decays: [B, H, T, D_k]      → 转置为 [B, H, D_k, T]      → [B*H*D_k, T]
            # values:     [B, H, T, D_k, D_v] → 转置为 [B, H, D_k, T, D_v] → [B*H*D_k, T, D_v]
            ld_flat = log_decays.permute(0, 1, 3, 2).contiguous().reshape(B * H * D_k, T)
            v_flat = values.permute(0, 1, 3, 2, 4).contiguous().reshape(B * H * D_k, T, D_v)
            o_flat = torch.empty_like(v_flat)

            BHDK = B * H * D_k
            BLOCK_DV = min(triton.next_power_of_2(D_v), 1024)
            # Grid: (B*H*D_k, ceil(D_v/BLOCK_DV)) — one program per (batch, head, row, dv-block)
            # 网格: (B*H*D_k, ceil(D_v/BLOCK_DV))
            grid = (BHDK, triton.cdiv(D_v, BLOCK_DV))

            _scan_fwd_kernel[grid](
                ld_flat, v_flat, o_flat,
                T, D_v,
                ld_flat.stride(0), ld_flat.stride(1),
                v_flat.stride(0), v_flat.stride(1), v_flat.stride(2),
                o_flat.stride(0), o_flat.stride(1), o_flat.stride(2),
                BLOCK_DV=BLOCK_DV,
            )

            # Save for backward: need log_decays and the forward outputs y_t
            # 为反向传播保存: 需要 log_decays 和前向输出 y_t
            ctx.save_for_backward(ld_flat, o_flat)
            ctx.shape_info = (B, H, T, D_k, D_v, BHDK, BLOCK_DV)
            # Reshape back: [B*H*D_k, T, D_v] → [B, H, D_k, T, D_v] → [B, H, T, D_k, D_v]
            return o_flat.reshape(B, H, D_k, T, D_v).permute(0, 1, 3, 2, 4).contiguous()

        @staticmethod
        def backward(ctx, grad_output: Tensor):
            ld_flat, o_flat = ctx.saved_tensors
            B, H, T, D_k, D_v, BHDK, BLOCK_DV = ctx.shape_info

            # Permute grad_output to match the row-parallel layout: [B, H, T, D_k, D_v] → [B*H*D_k, T, D_v]
            go_flat = grad_output.permute(0, 1, 3, 2, 4).contiguous().reshape(BHDK, T, D_v)
            gv_flat = torch.empty_like(go_flat)
            # Use f32 for gradient accumulation precision
            # 使用 f32 保证梯度累积的精度
            gld_flat = torch.zeros(BHDK, T, device=ld_flat.device, dtype=torch.float32)

            grid = (BHDK, triton.cdiv(D_v, BLOCK_DV))

            _scan_bwd_kernel[grid](
                ld_flat, o_flat, go_flat, gv_flat, gld_flat,
                T, D_v,
                ld_flat.stride(0), ld_flat.stride(1),
                o_flat.stride(0), o_flat.stride(1), o_flat.stride(2),
                go_flat.stride(0), go_flat.stride(1), go_flat.stride(2),
                gv_flat.stride(0), gv_flat.stride(1), gv_flat.stride(2),
                gld_flat.stride(0), gld_flat.stride(1),
                BLOCK_DV=BLOCK_DV,
            )

            # Reshape gradients back to the original layout
            # 重塑梯度回原始布局
            # gld: [B*H*D_k, T]      → [B, H, D_k, T]      → [B, H, T, D_k]
            grad_log_decays = gld_flat.to(grad_output.dtype).reshape(B, H, D_k, T).permute(0, 1, 3, 2).contiguous()
            # gv:  [B*H*D_k, T, D_v] → [B, H, D_k, T, D_v] → [B, H, T, D_k, D_v]
            grad_values = gv_flat.reshape(B, H, D_k, T, D_v).permute(0, 1, 3, 2, 4).contiguous()
            return grad_log_decays, grad_values

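    # --- Illustrative check (not in the original upload) ---
    # A hedged sketch of how the Triton path could be sanity-checked against a naive
    # Python loop on small shapes. The helper name _check_parallel_scan and the toy
    # sizes are assumptions, not part of the uploaded file.
    def _check_parallel_scan() -> None:
        B, H, T, D_k, D_v = 1, 2, 5, 3, 4
        log_a = -torch.rand(B, H, T, D_k, device="cuda")
        kv = torch.randn(B, H, T, D_k, D_v, device="cuda")
        fast = _ParallelScanFn.apply(log_a, kv)
        # Naive reference: S_t = diag(α_t)·S_{t-1} + kv_t, materialized step by step
        S = torch.zeros(B, H, D_k, D_v, device="cuda")
        ref = []
        for t in range(T):
            S = torch.exp(log_a[:, :, t]).unsqueeze(-1) * S + kv[:, :, t]
            ref.append(S)
        assert torch.allclose(fast, torch.stack(ref, dim=2), atol=1e-5)
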
    def _triton_parallel_scan(log_decays: Tensor, values: Tensor) -> Tensor:
        """Triton-accelerated parallel scan entry point (vector decay).
        Triton 加速的并行扫描入口 (向量衰减)。"""
        return _ParallelScanFn.apply(log_decays, values)

else:

def parallel_scan(log_decays: Tensor, values: Tensor) -> Tensor:
    """
    Parallel prefix scan — computes all prefix monoid sums (vector decay).
    并行前缀扫描 — 计算所有前缀幺半群和 (向量衰减)。

    This is the training-time workhorse of Monoid Attention.
    It computes S_1, S_2, ..., S_T, where
        S_t[i,:] = α_t[i]·S_{t-1}[i,:] + kv_t[i,:]
    for ALL timesteps simultaneously.
    这是幺半群注意力训练时的主力计算。
    它同时计算所有时间步的 S_1, S_2, ..., S_T,
    其中 S_t[i,:] = α_t[i]·S_{t-1}[i,:] + kv_t[i,:]。

    Auto-dispatches based on device:
        CUDA    → Triton JIT kernel (fast, with custom backward)
        CPU/MPS → PyTorch sequential scan (correct, slower)
        CPU/MPS → PyTorch 串行扫描 (正确, 较慢)

    Args:
        log_decays: [B, H, T, D_k] — log of the per-dimension decay gates α_t
                    每维度衰减门 α_t 的对数
        values: [B, H, T, D_k, D_v] — outer products k_t⊗v_t
                外积 k_t⊗v_t
    Returns:

    log_decays: Tensor, values: Tensor,
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
    """
    Parallel prefix scan + extract final state for inference handoff (vector decay).
    并行前缀扫描 + 提取最终状态用于推理切换 (向量衰减)。

    Used during prefill: compute all training-time prefix states,
    AND extract the final accumulated state S_T so that subsequent

    这是训练模式 (并行扫描) 和推理模式 (串行 monoid_op) 之间的桥梁。

    Args:
        log_decays: [B, H, T, D_k]
        values: [B, H, T, D_k, D_v]

    Returns:
        output: [B, H, T, D_k, D_v] — all prefix states S_1..S_T
                所有前缀状态
        final_state: (log_acc, S_T) where
            log_acc: [B, H, D_k] — accumulated log-decay vector (for future monoid_op)
                     累积对数衰减向量 (供后续 monoid_op 使用)
            final_state: [B, H, D_k, D_v] — S_T, the compressed causal summary
                         S_T, 压缩的因果摘要
    """
    output = parallel_scan(log_decays, values)
    # Sum all log-decays over T to get the total accumulated decay per dimension
    # 对所有 log-decay 沿 T 求和得到每个维度的总累积衰减
    log_acc = log_decays.sum(dim=2)  # [B, H, D_k]
    # The last timestep's state IS the full causal summary
    # 最后一个时间步的状态就是完整的因果摘要
    final_state = output[:, :, -1]  # [B, H, D_k, D_v]

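# --- Illustrative usage (not in the original upload) ---
# A hedged sketch of the prefill → decode handoff this module enables. Shapes follow the
# docstrings above; the single-step combine is written out inline rather than calling the
# model's monoid_op helper, whose import path is not assumed here. The function name
# _prefill_then_decode_step_sketch is hypothetical.
def _prefill_then_decode_step_sketch():
    B, H, T, D_k, D_v = 1, 2, 16, 8, 8
    log_a = -torch.rand(B, H, T, D_k)       # log of per-dimension decay gates, ≤ 0
    kv = torch.randn(B, H, T, D_k, D_v)     # outer products k_t ⊗ v_t

    # Prefill: one parallel scan over the prompt, then keep only the final state
    states = parallel_scan(log_a, kv)       # [B, H, T, D_k, D_v]
    log_acc = log_a.sum(dim=2)              # [B, H, D_k]
    S = states[:, :, -1]                    # [B, H, D_k, D_v], the compressed causal summary

    # One decode step: fold a new (log_α, k⊗v) pair into the running state,
    # i.e. (log_acc, S) ⊕ (log_α_new, kv_new) under the monoid operator.
    log_a_new = -torch.rand(B, H, D_k)
    kv_new = torch.randn(B, H, D_k, D_v)
    S = torch.exp(log_a_new).unsqueeze(-1) * S + kv_new
    log_acc = log_acc + log_a_new
    return S, log_acc
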
training_args.bin
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:32938ebb880dc58fa7d6f8e45383c55e1d5d4352618531d62a28069918595445
+size 6417