OzTianlu committed on
Commit b6c0790 · verified · 1 Parent(s): ff3696b

Upload 11 files

Files changed (5)
  1. MonoidForCausalLM.py +93 -46
  2. README.md +88 -116
  3. model.safetensors +2 -2
  4. monoid_scan_cuda.py +164 -164
  5. training_args.bin +2 -2
MonoidForCausalLM.py CHANGED
@@ -14,11 +14,13 @@ Architecture / 架构概要:
14
 
15
  Monoid attention compresses the entire causal history into a
16
  fixed-size state matrix S_t ∈ ℝ^{d×d} per head:
17
- S_t = α_t · S_{t-1} + k_t ⊗ v_t (explicit causal recurrence)
18
- o_t = q_t · S_t (state readout)
 
19
  幺半群注意力将完整因果历史压缩到每个头一个固定大小的状态矩阵 S_t:
20
- S_t = α_t · S_{t-1} + k_t ⊗ v_t (显式因果递推)
21
- o_t = q_t · S_t (状态读出)
 
22
 
23
  This is a monoid because the binary operator:
24
  (log_α, S) ⊕ (log_β, X) = (log_α + log_β, exp(log_β)·S + X)
@@ -69,12 +71,12 @@ except ImportError:
69
  # Slower than the fused CUDA kernel but numerically identical.
70
 
71
  def parallel_scan(log_alpha: Tensor, kv: Tensor) -> Tensor:
72
- """Sequential prefix scan fallback: S_t = exp(log_α_t)·S_{t-1} + kv_t."""
73
  B, H, T, d1, d2 = kv.shape
74
  states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype)
75
  S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype)
76
  for t in range(T):
77
- decay = torch.exp(log_alpha[:, :, t]) # [B, H, 1]
78
  while decay.dim() < S.dim():
79
  decay = decay.unsqueeze(-1)
80
  S = S * decay + kv[:, :, t]
@@ -86,7 +88,7 @@ except ImportError:
86
  B, H, T, d1, d2 = kv.shape
87
  states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype)
88
  S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype)
89
- log_acc = torch.zeros(B, H, 1, device=log_alpha.device, dtype=log_alpha.dtype)
90
  for t in range(T):
91
  decay = torch.exp(log_alpha[:, :, t])
92
  while decay.dim() < S.dim():
@@ -217,11 +219,12 @@ def monoid_op(
217
  b: tuple[Tensor, Tensor],
218
  ) -> tuple[Tensor, Tensor]:
219
  """
220
- The monoid binary operator ⊕ on (log-space decay, state matrix) pairs.
221
- 幺半群二元算子 ⊕,作用于 (对数衰减, 状态矩阵) 对。
222
 
223
  Definition / 定义:
224
- (log_α, S) ⊕ (log_β, X) = (log_α + log_β, exp(log_β)·S + X)
 
225
 
226
  Why this is a monoid / 为什么这是幺半群:
227
  • Associativity / 结合律:
@@ -326,15 +329,19 @@ class MonoidAttention(nn.Module):
326
 
327
  # --- Decay gate (novel component, randomly initialized) ---
328
  # --- 衰减门 (全新组件, 随机初始化) ---
329
- # Projects hidden_size → num_heads, yielding one scalar α per head.
330
- # After sigmoid: α_t ∈ (0,1) controls per-head forgetting rate.
331
- # This is the key to *explicit causal modeling*: the model learns
332
- # a content-dependent decay, not a fixed positional bias.
333
- # hidden_size 投影到 num_heads, 每个头产生一个标量 α。
334
- # 经过 sigmoid 后: α_t ∈ (0,1) 控制每个头的遗忘速率。
335
- # 这是 *显式因果建模* 的关键: 模型学习的是内容相关的衰减,
336
- # 而非固定的位置偏置。
337
- self.decay_proj = nn.Linear(config.hidden_size, self.num_heads, bias=True)
 
 
 
 
338
 
339
  # --- QK-Norm (novel component, randomly initialized) ---
340
  # --- QK 归一化 (全新组件, 随机初始化) ---
@@ -358,12 +365,17 @@ class MonoidAttention(nn.Module):
358
  def forward(
359
  self,
360
  hidden_states: Tensor,
 
361
  monoid_cache: MonoidCache | None = None,
362
  use_cache: bool = False,
363
  ) -> tuple[Tensor, tuple[Tensor, Tensor] | None]:
364
  """
365
  Args:
366
  hidden_states: [B, T, hidden_size]
 
 
 
 
367
  monoid_cache: O(1) state cache for inference
368
  推理用 O(1) 状态缓存
369
  use_cache: whether to use/update the cache
@@ -399,13 +411,40 @@ class MonoidAttention(nn.Module):
399
  # PSD 保证信息单调积累。
400
  k = torch.nn.functional.silu(k)
401
 
402
- # --- Compute per-head decay gate α_t ---
403
- # --- 计算每头衰减门 α_t ---
404
- # sigmoid ensures α ∈ (0,1), then log-space for numerical stability.
405
- # sigmoid 保证 α ∈ (0,1), 然后转到对数空间保证数值稳定性。
406
- alpha = torch.sigmoid(self.decay_proj(hidden_states)) # [B,T,H]
407
- alpha = alpha.transpose(1, 2).unsqueeze(-1) # [B,H,T,1]
408
- log_alpha = torch.log(alpha.clamp(min=1e-6))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409
 
410
  # ══════════════════════════════════════════════════════════
411
  # Inference path (RNN mode): O(1) per token per layer
@@ -427,7 +466,7 @@ class MonoidAttention(nn.Module):
427
  # Outer product: k_t ⊗ v_t ∈ ℝ^{H×d×d}
428
  # 外积: k_t ⊗ v_t ∈ ℝ^{H×d×d}
429
  kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, 0], v[:, :, 0])
430
- log_t = log_alpha[:, :, 0] # [B,H,1]
431
 
432
  prev = monoid_cache.get_state(self.layer_idx) if monoid_cache else None
433
  if prev is None:
@@ -467,16 +506,16 @@ class MonoidAttention(nn.Module):
467
  kv = torch.einsum('bhtd, bhte -> bhtde', k, v) # [B,H,T,d,d]
468
  states, (log_acc, S_T) = parallel_scan_with_state(log_alpha, kv)
469
 
470
- # Add h0 contribution: S_t += (∏_{i=0}^{t} α_i) · h0
471
- # 叠加 h0 贡献: S_t += (∏_{i=0}^{t} α_i) · h0
472
- cum_log_alpha = torch.cumsum(log_alpha, dim=2) # [B,H,T,1]
473
- h0_decay = torch.exp(cum_log_alpha).unsqueeze(-1) # [B,H,T,1,1]
474
  states = states + h0_decay * self.h0.unsqueeze(2) # broadcast h0 [1,H,1,d,d]
475
 
476
  # Final state includes h0 contribution
477
  # 最终状态包含 h0 贡献
478
- total_h0_decay = torch.exp(log_acc).unsqueeze(-1) # [B,H,1,1]
479
- S_final = S_T + total_h0_decay * self.h0.squeeze(0) # [B,H,d,d] (squeeze batch dim of h0)
480
  # h0 is [1,H,d,d], squeeze(0) removed for clarity but expand also works
481
  final_state = (log_acc, S_final)
482
 
@@ -507,16 +546,16 @@ class MonoidAttention(nn.Module):
507
  # 向量化外积: 一次性计算所有 t 的 k_t ⊗ v_t
508
  kv = torch.einsum('bhtd, bhte -> bhtde', k, v) # [B,H,T,d,d]
509
 
510
- # Parallel prefix scan: S_t = α_t·S_{t-1} + kv_t (from S=0)
511
- # 并行前缀扫描: S_t = α_t·S_{t-1} + kv_t (从 S=0 开始)
512
- # Keep log_alpha as [B,H,T,1] — CUDA kernel backward expects this shape.
513
- # 保持 log_alpha 为 [B,H,T,1] — CUDA kernel 反向传播需要此形状。
514
  states = parallel_scan(log_alpha, kv) # [B,H,T,d,d]
515
 
516
- # Add h0 contribution: S_t += (∏_{i=0}^{t} α_i) · h0
517
- # 叠加 h0 贡献: S_t += (∏_{i=0}^{t} α_i) · h0
518
- cum_log_alpha = torch.cumsum(log_alpha, dim=2) # [B,H,T,1]
519
- h0_decay = torch.exp(cum_log_alpha).unsqueeze(-1) # [B,H,T,1,1]
520
  states = states + h0_decay * self.h0.unsqueeze(2) # broadcast h0 [1,H,1,d,d]
521
 
522
  # Vectorized readout: o_t = q_t · S_t for all t at once
@@ -557,6 +596,7 @@ class MonoidDecoderLayer(nn.Module):
557
  def forward(
558
  self,
559
  hidden_states: Tensor,
 
560
  monoid_cache: MonoidCache | None = None,
561
  use_cache: bool = False,
562
  ) -> Tensor:
@@ -564,7 +604,7 @@ class MonoidDecoderLayer(nn.Module):
564
  # --- 注意力块 + 残差连接 ---
565
  residual = hidden_states
566
  hidden_states = self.input_layernorm(hidden_states)
567
- hidden_states, _ = self.self_attn(hidden_states, monoid_cache=monoid_cache, use_cache=use_cache)
568
  hidden_states = residual + hidden_states
569
 
570
  # --- FFN block with residual ---
@@ -600,7 +640,7 @@ class MonoidPreTrainedModel(PreTrainedModel):
600
  module.weight.data[module.padding_idx].zero_()
601
 
602
  if isinstance(module, MonoidAttention):
603
- nn.init.constant_(module.decay_proj.bias, 4.0)
604
 
605
  class MonoidModel(MonoidPreTrainedModel):
606
  """
@@ -625,6 +665,7 @@ class MonoidModel(MonoidPreTrainedModel):
625
  def forward(
626
  self,
627
  input_ids: Tensor | None = None,
 
628
  inputs_embeds: Tensor | None = None,
629
  monoid_cache: MonoidCache | None = None,
630
  use_cache: bool = False,
@@ -638,11 +679,12 @@ class MonoidModel(MonoidPreTrainedModel):
638
  hidden_states = self._gradient_checkpointing_func(
639
  layer.__call__,
640
  hidden_states,
 
641
  monoid_cache,
642
  use_cache,
643
  )
644
  else:
645
- hidden_states = layer(hidden_states, monoid_cache=monoid_cache, use_cache=use_cache)
646
 
647
  hidden_states = self.norm(hidden_states)
648
 
@@ -723,9 +765,13 @@ class MonoidForCausalLM(MonoidPreTrainedModel, GenerationMixin):
723
  # Cache exists → only feed the latest token (O(1) inference)
724
  # 缓存已存在 → 只需输入最新的 token (O(1) 推理)
725
  input_ids = input_ids[:, -1:]
 
 
 
726
 
727
  model_inputs = {
728
  "input_ids": input_ids,
 
729
  "monoid_cache": past_key_values,
730
  "use_cache": True,
731
  }
@@ -734,8 +780,8 @@ class MonoidForCausalLM(MonoidPreTrainedModel, GenerationMixin):
734
  def forward(
735
  self,
736
  input_ids: Tensor | None = None,
737
- attention_mask: Tensor | None = None, # kept for API compat; monoid ignores this
738
- # 保留 API 兼容性; 幺半群不使用
739
  position_ids: Tensor | None = None, # kept for API compat; monoid ignores this
740
  # 保留 API 兼容性; 幺半群不使用
741
  past_key_values: MonoidCache | None = None,
@@ -762,6 +808,7 @@ class MonoidForCausalLM(MonoidPreTrainedModel, GenerationMixin):
762
 
763
  outputs = self.model(
764
  input_ids=input_ids,
 
765
  inputs_embeds=inputs_embeds,
766
  monoid_cache=cache,
767
  use_cache=bool(use_cache),
 
14
 
15
  Monoid attention compresses the entire causal history into a
16
  fixed-size state matrix S_t ∈ ℝ^{d×d} per head:
17
+ S_t = diag(α_t) · S_{t-1} + k_t ⊗ v_t (vector decay recurrence)
18
+ o_t = q_t · S_t (state readout)
19
+ where α_t ∈ ℝ^d is a per-dimension vector decay gate.
20
  幺半群注意力将完整因果历史压缩到每个头一个固定大小的状态矩阵 S_t:
21
+ S_t = diag(α_t) · S_{t-1} + k_t ⊗ v_t (向量衰减递推)
22
+ o_t = q_t · S_t (状态读出)
23
+ 其中 α_t ∈ ℝ^d 是逐维度的向量衰减门。
24
 
25
  This is a monoid because the binary operator:
26
  (log_α, S) ⊕ (log_β, X) = (log_α + log_β, exp(log_β)·S + X)
 
71
  # Slower than the fused CUDA kernel but numerically identical.
72
 
73
  def parallel_scan(log_alpha: Tensor, kv: Tensor) -> Tensor:
74
+ """Sequential prefix scan fallback: S_t[i,:] = exp(log_α_t[i])·S_{t-1}[i,:] + kv_t[i,:]."""
75
  B, H, T, d1, d2 = kv.shape
76
  states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype)
77
  S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype)
78
  for t in range(T):
79
+ decay = torch.exp(log_alpha[:, :, t]) # [B, H, d]
80
  while decay.dim() < S.dim():
81
  decay = decay.unsqueeze(-1)
82
  S = S * decay + kv[:, :, t]
 
88
  B, H, T, d1, d2 = kv.shape
89
  states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype)
90
  S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype)
91
+ log_acc = torch.zeros(B, H, d1, device=log_alpha.device, dtype=log_alpha.dtype)
92
  for t in range(T):
93
  decay = torch.exp(log_alpha[:, :, t])
94
  while decay.dim() < S.dim():
 
219
  b: tuple[Tensor, Tensor],
220
  ) -> tuple[Tensor, Tensor]:
221
  """
222
+ The monoid binary operator ⊕ on (log-space vector decay, state matrix) pairs.
223
+ 幺半群二元算子 ⊕,作用于 (对数向量衰减, 状态矩阵) 对。
224
 
225
  Definition / 定义:
226
+ (log_α, S) ⊕ (log_β, X) = (log_α + log_β, diag(exp(log_β))·S + X)
227
+ where log_α, log_β ∈ ℝ^d are per-dimension log decay vectors.
228
 
229
  Why this is a monoid / 为什么这是幺半群:
230
  • Associativity / 结合律:
 
329
 
330
  # --- Decay gate (novel component, randomly initialized) ---
331
  # --- 衰减门 (全新组件, 随机初始化) ---
332
+ # Projects hidden_size → num_heads * head_dim, yielding a VECTOR per head.
333
+ # Activation: log_α = -softplus(Wx + b), giving α ∈ (0, 1].
334
+ # Vector decay: S_t = diag(α_t) · S_{t-1} + k_t ⊗ v_t
335
+ # Different feature dimensions can have independent lifetimes:
336
+ # - fast-decaying dims for local syntax
337
+ # - slow-decaying dims for global entity/fact memory
338
+ # hidden_size 投影到 num_heads * head_dim, 每个头产生一个向量。
339
+ # 激活: log_α = -softplus(Wx + b), 使 α ∈ (0, 1]。
340
+ # 向量衰减: S_t = diag(α_t) · S_{t-1} + k_t ⊗ v_t
341
+ # 不同特征维度拥有独立的生命周期:
342
+ # - 快速衰减的维度负责局部语法结构
343
+ # - 慢速衰减的维度负责全局实体和事实记忆
344
+ self.decay_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=True)
345
 
346
  # --- QK-Norm (novel component, randomly initialized) ---
347
  # --- QK 归一化 (全新组件, 随机初始化) ---
 
365
  def forward(
366
  self,
367
  hidden_states: Tensor,
368
+ attention_mask: Tensor | None = None,
369
  monoid_cache: MonoidCache | None = None,
370
  use_cache: bool = False,
371
  ) -> tuple[Tensor, tuple[Tensor, Tensor] | None]:
372
  """
373
  Args:
374
  hidden_states: [B, T, hidden_size]
375
+ attention_mask: [B, T] with 1=real token, 0=pad.
376
+ For PAD positions: α=1 (preserve state), kv=0 (no contribution).
377
+ 掩码: 1=真实token, 0=填充。
378
+ 填充位置: α=1 (保持状态不变), kv=0 (无贡献)。
379
  monoid_cache: O(1) state cache for inference
380
  推理用 O(1) 状态缓存
381
  use_cache: whether to use/update the cache
 
411
  # PSD 保证信息单调积累。
412
  k = torch.nn.functional.silu(k)
413
 
414
+ # --- Compute per-dimension vector decay gate α_t ---
415
+ # --- 计算每维度向量衰减门 α_t ---
416
+ # Negative Softplus: log_α = -softplus(Wx + b)
417
+ # Value range: log_α ∈ (-∞, 0), i.e. α ∈ (0, 1].
418
+ # When Wx → -∞: softplus → 0, α → 1 (perfect memory, no forgetting)
419
+ # When Wx → +∞: softplus → Wx, α → 0 (complete forgetting)
420
+ # This avoids α > 1 explosion (unlike SiLU) while still allowing
421
+ # α = 1 for lossless memory (unlike Sigmoid which caps at <1).
422
+ # Each dimension of the d-vector decays independently:
423
+ # S_t[i,j] = α_t[i] · S_{t-1}[i,j] + k_t[i] · v_t[j]
424
+ #
425
+ # 负 Softplus: log_α = -softplus(Wx + b)
426
+ # 值域: log_α ∈ (-∞, 0), 即 α ∈ (0, 1]。
427
+ # 当 Wx → -∞: softplus → 0, α → 1 (完美记忆, 不遗忘)
428
+ # 当 Wx → +∞: softplus → Wx, α → 0 (完全遗忘)
429
+ # 避免了 SiLU 的 α > 1 爆炸, 同时允许 α = 1 无损记忆 (Sigmoid 无法做到)。
430
+ # d-向量的每个维度独立衰减:
431
+ # S_t[i,j] = α_t[i] · S_{t-1}[i,j] + k_t[i] · v_t[j]
432
+ raw = self.decay_proj(hidden_states) # [B,T,H*d]
433
+ log_alpha = -torch.nn.functional.softplus(raw) # [B,T,H*d]
434
+ log_alpha = log_alpha.view(B, T, H, d).transpose(1, 2) # [B,H,T,d]
435
+
436
+ # --- Apply attention_mask: PAD tokens must be invisible to the recurrence ---
437
+ # --- 应用注意力掩码: PAD token 必须对递推不可见 ---
438
+ # For PAD positions (mask=0): set log_α=0 (α=1, preserve state) and kv=0 (no contribution).
439
+ # This makes S_t = 1·S_{t-1} + 0 = S_{t-1}, i.e. PAD is a no-op on the state.
440
+ # 对于 PAD 位置 (mask=0): 设 log_α=0 (α=1, 保持状态) 且 kv=0 (无贡献)。
441
+ # 这使得 S_t = 1·S_{t-1} + 0 = S_{t-1}, 即 PAD 对状态是空操作。
442
+ if attention_mask is not None:
443
+ # attention_mask: [B, T] → [B, 1, T, 1] for broadcasting with [B, H, T, d]
444
+ mask = attention_mask[:, None, :, None].to(log_alpha.dtype) # [B,1,T,1]
445
+ log_alpha = log_alpha * mask # PAD → log_α=0 → α=1
446
+ k = k * mask # PAD → k=0
447
+ v = v * mask # PAD → v=0 → kv=0
448
 
449
  # ══════════════════════════════════════════════════════════
450
  # Inference path (RNN mode): O(1) per token per layer
 
466
  # Outer product: k_t ⊗ v_t ∈ ℝ^{H×d×d}
467
  # 外积: k_t ⊗ v_t ∈ ℝ^{H×d×d}
468
  kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, 0], v[:, :, 0])
469
+ log_t = log_alpha[:, :, 0] # [B,H,d]
470
 
471
  prev = monoid_cache.get_state(self.layer_idx) if monoid_cache else None
472
  if prev is None:
 
506
  kv = torch.einsum('bhtd, bhte -> bhtde', k, v) # [B,H,T,d,d]
507
  states, (log_acc, S_T) = parallel_scan_with_state(log_alpha, kv)
508
 
509
+ # Add h0 contribution: S_t += diag(∏_{i=0}^{t} α_i) · h0
510
+ # 叠加 h0 贡献: S_t += diag(∏_{i=0}^{t} α_i) · h0
511
+ cum_log_alpha = torch.cumsum(log_alpha, dim=2) # [B,H,T,d]
512
+ h0_decay = torch.exp(cum_log_alpha).unsqueeze(-1) # [B,H,T,d,1]
513
  states = states + h0_decay * self.h0.unsqueeze(2) # broadcast h0 [1,H,1,d,d]
514
 
515
  # Final state includes h0 contribution
516
  # 最终状态包含 h0 贡献
517
+ total_h0_decay = torch.exp(log_acc).unsqueeze(-1) # [B,H,d,1]
518
+ S_final = S_T + total_h0_decay * self.h0.squeeze(0) # [B,H,d,d]
519
  # h0 is [1,H,d,d], squeeze(0) removed for clarity but expand also works
520
  final_state = (log_acc, S_final)
521
 
 
546
  # 向量化外积: 一次性计算所有 t 的 k_t ⊗ v_t
547
  kv = torch.einsum('bhtd, bhte -> bhtde', k, v) # [B,H,T,d,d]
548
 
549
+ # Parallel prefix scan: S_t = diag(α_t)·S_{t-1} + kv_t (from S=0)
550
+ # 并行前缀扫描: S_t = diag(α_t)·S_{t-1} + kv_t (从 S=0 开始)
551
+ # log_alpha is [B,H,T,d] — vector decay per dimension.
552
+ # log_alpha 为 [B,H,T,d] — 每维度向量衰减。
553
  states = parallel_scan(log_alpha, kv) # [B,H,T,d,d]
554
 
555
+ # Add h0 contribution: S_t += diag(∏_{i=0}^{t} α_i) · h0
556
+ # 叠加 h0 贡献: S_t += diag(∏_{i=0}^{t} α_i) · h0
557
+ cum_log_alpha = torch.cumsum(log_alpha, dim=2) # [B,H,T,d]
558
+ h0_decay = torch.exp(cum_log_alpha).unsqueeze(-1) # [B,H,T,d,1]
559
  states = states + h0_decay * self.h0.unsqueeze(2) # broadcast h0 [1,H,1,d,d]
560
 
561
  # Vectorized readout: o_t = q_t · S_t for all t at once
 
596
  def forward(
597
  self,
598
  hidden_states: Tensor,
599
+ attention_mask: Tensor | None = None,
600
  monoid_cache: MonoidCache | None = None,
601
  use_cache: bool = False,
602
  ) -> Tensor:
 
604
  # --- 注意力块 + 残差连接 ---
605
  residual = hidden_states
606
  hidden_states = self.input_layernorm(hidden_states)
607
+ hidden_states, _ = self.self_attn(hidden_states, attention_mask=attention_mask, monoid_cache=monoid_cache, use_cache=use_cache)
608
  hidden_states = residual + hidden_states
609
 
610
  # --- FFN block with residual ---
 
640
  module.weight.data[module.padding_idx].zero_()
641
 
642
  if isinstance(module, MonoidAttention):
643
+ nn.init.constant_(module.decay_proj.bias, 1.0)
644
 
645
  class MonoidModel(MonoidPreTrainedModel):
646
  """
 
665
  def forward(
666
  self,
667
  input_ids: Tensor | None = None,
668
+ attention_mask: Tensor | None = None,
669
  inputs_embeds: Tensor | None = None,
670
  monoid_cache: MonoidCache | None = None,
671
  use_cache: bool = False,
 
679
  hidden_states = self._gradient_checkpointing_func(
680
  layer.__call__,
681
  hidden_states,
682
+ attention_mask,
683
  monoid_cache,
684
  use_cache,
685
  )
686
  else:
687
+ hidden_states = layer(hidden_states, attention_mask=attention_mask, monoid_cache=monoid_cache, use_cache=use_cache)
688
 
689
  hidden_states = self.norm(hidden_states)
690
 
 
765
  # Cache exists → only feed the latest token (O(1) inference)
766
  # 缓存已存在 → 只需输入最新的 token (O(1) 推理)
767
  input_ids = input_ids[:, -1:]
768
+ # Decode step: single real token, no PAD → mask not needed
769
+ # 解码步: 单个真实token, 无PAD → 不需要掩码
770
+ attention_mask = None
771
 
772
  model_inputs = {
773
  "input_ids": input_ids,
774
+ "attention_mask": attention_mask,
775
  "monoid_cache": past_key_values,
776
  "use_cache": True,
777
  }
 
780
  def forward(
781
  self,
782
  input_ids: Tensor | None = None,
783
+ attention_mask: Tensor | None = None, # [B,T] 1=real, 0=pad; used to mask PAD from the recurrence
784
+ # [B,T] 1=真实token, 0=填充 — 用于屏蔽PAD对递推的影响
785
  position_ids: Tensor | None = None, # kept for API compat; monoid ignores this
786
  # 保留 API 兼容性; 幺半群不使用
787
  past_key_values: MonoidCache | None = None,
 
808
 
809
  outputs = self.model(
810
  input_ids=input_ids,
811
+ attention_mask=attention_mask,
812
  inputs_embeds=inputs_embeds,
813
  monoid_cache=cache,
814
  use_cache=bool(use_cache),
README.md CHANGED
@@ -9,6 +9,7 @@ tags:
9
  - linear-attention
10
  - state-space
11
  - O(1)-inference
 
12
  - reasoning
13
  pipeline_tag: text-generation
14
  model-index:
@@ -20,124 +21,113 @@ model-index:
20
 
21
  A 1.3B parameter language model that replaces softmax attention with **causal monoid state compression**, achieving **O(1) time per token** and **O(1) memory** at inference — regardless of sequence length.
22
 
23
- Fine-tuned for enhanced reasoning with structured chain-of-thought data.
24
-
25
  ## Monoid Attention — Internal Structure
26
 
27
  ```
28
  MonoidAttention (per layer, per head)
29
- ┌─────────────────────────────────────────────────────────────────────┐
30
-
31
- │ x_t ∈ R^{2048}
32
- │ │
33
- │ ├──> q_proj ──> RMSNorm ──> q_t ∈ R^{d} (query)
34
- │ │
35
- │ ├──> k_proj ──> RMSNorm ──> SiLU ──> k_t ∈ R^{d} (key, >= 0) │
36
- │ │
37
- │ ├──> v_proj ──> v_t ∈ R^{d} (value)
38
- │ │
39
- │ └──> decay_proj ──> sigmoid ──> alpha_t ∈ (0,1) (decay gate)
40
-
41
- │ k_t (x) v_t
42
- │ │ ┌──────────────────────────────┐
43
- │ │ │ State Matrix S_t ∈ R^{d x d} │ │
44
- │ v │ │
45
- │ S_t = alpha_t * S_{t-1} + k_t (x) v_t │ │
46
- │ │ "Compressed causal history" │ │
47
- │ │ └──────────────────────────────┘
48
- │ v
49
- │ o_t = q_t . S_t ──> o_proj ──> output
50
-
51
- └─────────────────────────────────────────────────────────────────────┘
52
- ```
53
-
54
- ## Monoid State Diagonal — O(1) Compression Contour
55
-
56
- The state matrix `S_t` accumulates causal history along its diagonal. Each head maintains an independent `d x d` state that compresses ALL past tokens into a fixed footprint:
57
-
58
- ```
59
- State Matrix S_t ∈ R^{64 x 64} (one per head, 32 heads per layer)
60
-
61
- k-dim -->
62
- 0 8 16 24 32 40 48 56 63
63
- ┌───┬───┬───┬───┬───┬───┬───┬───┐ 0
64
- │***│** │* │ │ │ │ │ │ v-dim
65
- │***│** │* │. │ │ │ │ │ |
66
- ├───┼───┼───┼───┼───┼───┼───┼───┤ 8 |
67
- │** │***│** │* │. │ │ │ │ v
68
- │* │***│** │* │. │ │ │ │
69
- ├───┼───┼───┼───┼───┼───┼───┼───┤ 16
70
- │* │** │***│** │* │. │ │ │
71
- │. │* │***│** │* │. │ │ │
72
- ├───┼───┼───┼───┼───┼───┼───┼───┤ 24
73
- │ │. │** │***│** │* │. │ │
74
- │ │ │* │***│** │* │. │ │
75
- ├───┼───┼───┼───┼───┼───┼───┼───┤ 32
76
- │ │ │. │** │***│** │* │. │
77
- │ │ │ │* │***│** │* │. │
78
- ├───┼───┼───┼───┼───┼───┼───┼───┤ 40
79
- │ │ │ │. │** │***│** │* │
80
- │ │ │ │ │* │***│** │* │
81
- ├───┼───┼───┼───┼───┼───┼───┼───┤ 48
82
- │ │ │ │ │. │** │***│** │
83
- │ │ │ │ │ │* │***│** │
84
- ├───┼───┼───┼───┼───┼───┼───┼───┤ 56
85
- │ │ │ │ │ │. │** │***│
86
- │ │ │ │ │ │ │* │***│
87
- └───┴───┴───┴───┴───┴───┴───┴───┘ 63
88
-
89
- Legend: *** = high activation (recent tokens, alpha^0 ~ alpha^2)
90
- ** = medium (alpha^3 ~ alpha^5)
91
- * = fading (alpha^6 ~ alpha^10)
92
- . = near-zero (alpha^11+, effectively forgotten)
93
- = zero (never reached or fully decayed)
94
-
95
- The diagonal band emerges because S_t = SUM_{i<=t} alpha^{t-i} * k_i (x) v_i.
96
- Recent outer products dominate near the diagonal; older ones decay
97
- exponentially via alpha, creating this characteristic contour.
98
  ```
99
 
100
-
101
  ## Key Properties
102
 
103
  | Property | Transformer (Llama) | Spartacus (Monoid) |
104
  |---|---|---|
105
- | Inference time per token | O(T) -- scans full KV-cache | **O(1)** -- single state update |
106
- | Inference memory per layer | O(T) -- stores all past K,V | **O(1)** -- fixed d x d state matrix |
107
- | Sequence length extrapolation | Degrades beyond training length | **Unlimited** -- state size is constant |
108
  | Causality | Imposed via attention mask | **Built into the recurrence** |
109
- | Training complexity | O(T^2) | **O(T)** via parallel prefix scan |
110
 
111
  ## The Monoid Recurrence
112
 
113
  Standard attention computes:
114
 
115
  ```
116
- o_t = sum_{i<=t} softmax(q_t . k_i) v_i -- requires O(T) KV-cache
117
  ```
118
 
119
  Monoid attention compresses the entire causal history into a **fixed-size state matrix** S_t per head:
120
 
121
  ```
122
- S_t = alpha_t * S_{t-1} + k_t (x) v_t -- explicit causal recurrence
123
- o_t = q_t . S_t -- state readout
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  ```
125
 
126
- where `alpha_t = sigmoid(decay_proj(x_t))` is a learned, content-dependent decay gate that controls how fast past information fades.
 
 
 
 
 
 
 
127
 
128
- ## Explicit Causal Modeling
129
 
130
- Unlike Transformers where causality is a constraint imposed by masking, Spartacus makes causality a **first-class citizen**:
 
 
 
131
 
132
- - The decay gate `alpha_t` explicitly controls per-head information retention at every timestep
133
- - The model learns **when to forget** rather than encoding **where tokens are** (no positional encoding needed)
134
- - No attention mask required -- causality is structural, not enforced
135
 
136
  ## Design Choices
137
 
138
- - **SiLU-activated keys**: `k = SiLU(k_proj(x))` ensures non-negative keys, making the state matrix `S` positive semi-definite (PSD). This prevents "feature erasure" where one token's contribution cancels another's
139
- - **Log-space decay**: Working in log-space `log(alpha)` avoids numerical underflow when `alpha^T -> 0` for long sequences
140
- - **Learnable h0**: The initial state `S_0 = h0` is a learnable parameter (zero-initialized), acting as a compressed "system prompt"
 
 
 
 
 
 
 
 
 
 
141
 
142
  ## Model Details
143
 
@@ -151,7 +141,8 @@ Unlike Transformers where causality is a constraint imposed by masking, Spartacu
151
  | Layers | 16 |
152
  | Attention heads | 32 |
153
  | Head dimension | 64 |
154
- | State matrix per head | 64 x 64 = 4096 floats |
 
155
  | Vocabulary | 128,256 (Llama-3.2 tokenizer) |
156
  | Precision | bfloat16 |
157
 
@@ -167,7 +158,7 @@ Unlike Transformers where causality is a constraint imposed by masking, Spartacu
167
 
168
  ### Comparison with ~1B Baselines (acc_norm, 0-shot)
169
 
170
- | Task | Spartacus-1B-Instruct | TinyLlama-1.1B | Llama 3.2-1B | Mamba-1.4B | RWKV-6-1.6B |
171
  |---|---|---|---|---|---|
172
  | ARC-C | **0.3063** | 0.3268 | ~0.359 | 0.284 | ~0.301 |
173
  | ARC-E | **0.5518** | 0.5547 | ~0.752 | 0.512 | ~0.530 |
@@ -177,34 +168,15 @@ Unlike Transformers where causality is a constraint imposed by masking, Spartacu
177
 
178
  > Spartacus achieves competitive performance with sub-quadratic models (Mamba, RWKV) while maintaining **O(1) inference time and memory per token**. Scores marked with ~ are approximate community-reported values.
179
 
180
- ## Training
181
-
182
- ### Stage 1: General SFT
183
-
184
- - **Base weights**: Transferred from Llama-3.2-1B-Instruct (embeddings, MLP, norms)
185
- - **Data**: Capybara + smol-smoltalk (general conversation)
186
- - **Training**: Full-parameter SFT
187
-
188
- ### Stage 2: Reasoning Enhancement
189
-
190
- - **Data mix**: 60% Qwen3-Short-Reasoning + 20% Capybara + 20% smol-smoltalk
191
- - **Steps**: 2,000
192
- - **Learning rate**: 2e-5 (cosine schedule, 50 warmup steps)
193
- - **Batch size**: 8
194
- - **Sequence length**: 2,048
195
- - **Precision**: bfloat16
196
- - **Optimizer**: AdamW (weight decay 0.01, max grad norm 1.0)
197
-
198
- The reasoning data uses structured "Thought + Solution" format to strengthen chain-of-thought capabilities while the general data prevents catastrophic forgetting.
199
-
200
  ## Parallel Scan Implementation
201
 
202
- The `monoid_scan_cuda.py` module provides a Triton JIT-compiled parallel prefix scan:
203
 
204
- - **Forward**: Sequential scan along T, parallelized across B x H x D on GPU via Triton kernels
205
- - **Backward**: Reverse-order adjoint scan computes gradients for both values and log-decay gates
 
206
  - **Fallback**: Pure PyTorch sequential scan for CPU/MPS
207
- - **Auto-dispatch**: CUDA -> Triton kernel, otherwise -> PyTorch fallback
208
 
209
  ## Usage
210
 
@@ -231,7 +203,7 @@ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
231
 
232
  ```
233
  MonoidForCausalLM.py # Model architecture (MonoidConfig, MonoidAttention, MonoidForCausalLM)
234
- monoid_scan_cuda.py # Triton JIT parallel prefix scan + PyTorch fallback
235
  model.safetensors # Model weights (bfloat16)
236
  config.json # Model configuration
237
  tokenizer.json # Llama-3.2 tokenizer
@@ -245,7 +217,7 @@ tokenizer.json # Llama-3.2 tokenizer
245
  author={NoesisLab},
246
  year={2025},
247
  url={https://huggingface.co/NoesisLab/Spartacus-1B-Instruct},
248
- description={Replaces softmax attention with monoid state compression for constant-time, constant-memory autoregressive generation}
249
  }
250
  ```
251
 
 
9
  - linear-attention
10
  - state-space
11
  - O(1)-inference
12
+ - vector-decay
13
  - reasoning
14
  pipeline_tag: text-generation
15
  model-index:
 
21
 
22
  A 1.3B parameter language model that replaces softmax attention with **causal monoid state compression**, achieving **O(1) time per token** and **O(1) memory** at inference — regardless of sequence length.
23
 
 
 
24
  ## Monoid Attention — Internal Structure
25
 
26
  ```
27
  MonoidAttention (per layer, per head)
28
+ ┌─────────────────────────────────────────────────────────────────────────┐
29
+
30
+ │ x_t ∈ R^{2048}
31
+ │ │
32
+ │ ├──> q_proj ──> RMSNorm ──> q_t ∈ R^d (query, scaled 1/√d)
33
+ │ │
34
+ │ ├──> k_proj ──> RMSNorm ──> SiLU ──> k_t ∈ R^d (key, non-negative) │
35
+ │ │
36
+ │ ├──> v_proj ──> v_t ∈ R^d (value)
37
+ │ │
38
+ │ └──> decay_proj ──> -Softplus ──> log α_t ∈ R^d (vector decay gate)
39
+
40
+ │ k_t ⊗ v_t
41
+ │ │ ┌─────────────────────────────────┐
42
+ │ │ │ State Matrix S_t ∈ R^{d x d} │ │
43
+ │ v "Compressed causal history" │ │
44
+ │ S_t = diag(α_t) · S_{t-1} + k_t ⊗ v_t │ │
45
+ │ │ α_t ∈ (0,1]^d per dimension │ │
46
+ │ │ └─────────────────────────────────┘
47
+ │ v
48
+ │ o_t = q_t · S_t ──> o_proj ──> output
49
+
50
+ └─────────────────────────────────────────────────────────────────────────┘
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  ```
52
 
 
53
  ## Key Properties
54
 
55
  | Property | Transformer (Llama) | Spartacus (Monoid) |
56
  |---|---|---|
57
+ | Inference time per token | O(T) -- scans full KV-cache | **O(1)** -- single state update |
58
+ | Inference memory per layer | O(T) -- stores all past K,V | **O(1)** -- fixed d×d state matrix |
59
+ | Sequence length extrapolation | Degrades beyond training length | **Unlimited** -- state size is constant |
60
  | Causality | Imposed via attention mask | **Built into the recurrence** |
61
+ | Training complexity | O(T²) | **O(T)** via parallel prefix scan |
62
 
63
  ## The Monoid Recurrence
64
 
65
  Standard attention computes:
66
 
67
  ```
68
+ o_t = Σ_{i≤t} softmax(q_t · k_i) v_i -- requires O(T) KV-cache
69
  ```
70
 
71
  Monoid attention compresses the entire causal history into a **fixed-size state matrix** S_t per head:
72
 
73
  ```
74
+ S_t = diag(α_t) · S_{t-1} + k_t ⊗ v_t -- vector decay monoid recurrence
75
+ o_t = q_t · S_t -- state readout
76
+ ```
77
+
78
+ This is a monoid because the binary operator `(log_α, S) ⊕ (log_β, X) = (log_α + log_β, exp(log_β)·S + X)` is **associative**, enabling O(T) parallel prefix scan for training and O(1) sequential update for inference.
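
A minimal PyTorch sketch of this operator, with a numerical check of associativity and the identity element (an illustrative re-implementation using toy shapes, not the `monoid_op` defined in `MonoidForCausalLM.py`):

```python
import torch

def monoid_op(a, b):
    """(log_a, S) ⊕ (log_b, X) = (log_a + log_b, exp(log_b)·S + X), with vector log-decays."""
    log_a, S = a
    log_b, X = b
    # exp(log_b) has shape [d]; unsqueeze so it scales the rows of the [d, d] state.
    return log_a + log_b, torch.exp(log_b).unsqueeze(-1) * S + X

d = 4
def rand_elem():
    # log-decay in (-1, 0), i.e. decay in (exp(-1), 1), plus a random state matrix
    return -torch.rand(d), torch.randn(d, d)

x, y, z = rand_elem(), rand_elem(), rand_elem()

left  = monoid_op(monoid_op(x, y), z)   # (x ⊕ y) ⊕ z
right = monoid_op(x, monoid_op(y, z))   # x ⊕ (y ⊕ z)
assert torch.allclose(left[0], right[0]) and torch.allclose(left[1], right[1], atol=1e-6)

# Identity element (log_α = 0, S = 0): folding it in from either side is a no-op.
e = (torch.zeros(d), torch.zeros(d, d))
for a, b in [(e, x), (x, e)]:
    out = monoid_op(a, b)
    assert torch.allclose(out[0], x[0]) and torch.allclose(out[1], x[1])
```

Associativity is what lets training rearrange the fold into a parallel prefix scan, while inference folds one token at a time with the same operator.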
79
+
80
+ ## Vector Decay — Per-Dimension Memory Lifetimes
81
+
82
+ Unlike scalar decay (one α per head), Spartacus uses **vector decay**: each dimension of the d-vector has its own independent decay rate α_t[i] ∈ (0, 1]:
83
+
84
+ ```
85
+ S_t[i,j] = α_t[i] · S_{t-1}[i,j] + k_t[i] · v_t[j]
86
+ ```
87
+
88
+ This allows different feature dimensions to specialize:
89
+ - **Fast-decaying dimensions** (α ≈ 0) — local syntax, punctuation, function words
90
+ - **Slow-decaying dimensions** (α ≈ 1) — entity memory, topic tracking, long-range facts
91
+
92
+ The decay gate uses **Negative Softplus** activation:
93
+
94
+ ```
95
+ log α_t = -softplus(W·x_t + b)
96
  ```
97
 
98
+ | Property | Value |
99
+ |---|---|
100
+ | Range | α ∈ (0, 1] — bounded, no explosion |
101
+ | Perfect memory | W·x → -∞ ⟹ softplus → 0 ⟹ α → 1 (lossless retention) |
102
+ | Full forgetting | W·x → +∞ ⟹ softplus → ∞ ⟹ α → 0 (complete reset) |
103
+ | Stability | α ≤ 1 by construction — no divergence regardless of input magnitude |
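
As a rough sketch (single head, random stand-in tensors; the real module projects to `num_heads * head_dim` at once), the gate and one vector-decay update look like:

```python
import torch
import torch.nn.functional as F

hidden, d = 2048, 64
decay_proj = torch.nn.Linear(hidden, d)              # one head shown for brevity
x_t = torch.randn(hidden)

log_alpha = -F.softplus(decay_proj(x_t))             # log α ∈ (-∞, 0)  =>  α ∈ (0, 1]
alpha = log_alpha.exp()                              # per-dimension decay vector, shape [d]

S = torch.zeros(d, d)                                # compressed causal state for this head
k_t = F.silu(torch.randn(d))                         # SiLU-activated key, as in the model
v_t = torch.randn(d)
S = alpha.unsqueeze(-1) * S + torch.outer(k_t, v_t)  # S[i, :] decays by α[i], then adds k⊗v
```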
104
+
105
+ ## Attention Mask — Padding-Aware Recurrence
106
 
107
+ The monoid recurrence correctly handles `attention_mask` for padded batches (e.g., left-padding during `generate()`). For PAD positions (mask=0):
108
 
109
+ ```
110
+ log_α = 0 → α = 1 (preserve state unchanged)
111
+ k = 0, v = 0 → kv = 0 (no information injected)
112
+ ```
113
 
114
+ Net effect: `S_t = 1·S_{t-1} + 0 = S_{t-1}`, so PAD acts as the **monoid identity element**, completely invisible to the recurrence. This ensures identical outputs whether inputs are padded or not.
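
A small sketch of the masking rule with toy shapes (illustrative only):

```python
import torch

B, H, T, d = 2, 4, 8, 16
log_alpha = -torch.rand(B, H, T, d)        # real log-decays, strictly negative
k = torch.randn(B, H, T, d)
v = torch.randn(B, H, T, d)

attention_mask = torch.ones(B, T)
attention_mask[1, :3] = 0                  # left-pad the second sequence with 3 PAD tokens

m = attention_mask[:, None, :, None]       # [B, 1, T, 1] broadcasts over heads and dims
log_alpha = log_alpha * m                  # PAD: log_α = 0  ->  α = 1, state preserved
k, v = k * m, v * m                        # PAD: k = v = 0  ->  k⊗v = 0, nothing injected
# Each PAD step therefore folds in the identity (log_α = 0, kv = 0) and leaves S unchanged.
```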
 
 
115
 
116
  ## Design Choices
117
 
118
+ - **SiLU-activated keys**: `k = SiLU(k_proj(x))` ensures non-negative keys, making the state matrix S positive semi-definite (PSD). This prevents "feature erasure" where one token's contribution cancels another's
119
+ - **QK-Norm**: RMSNorm on both q and k before readout, stabilizing the scale of q·S when the state matrix accumulates many outer products
120
+ - **Log-space decay**: Working in log-space `log(α)` avoids numerical underflow when α^T → 0 for long sequences
121
+ - **Learnable h0**: The initial state S₀ = h0 is a learnable parameter (zero-initialized), acting as a compressed "system prompt"
122
+ - **Negative Softplus gate**: Ensures α ∈ (0, 1] by construction — allows perfect memory (α=1) while preventing state explosion (α>1)
123
+
124
+ ## Three Forward Paths
125
+
126
+ | Path | Condition | Complexity | Description |
127
+ |---|---|---|---|
128
+ | Training | `use_cache=False` | O(T) parallel scan | Vectorized outer products → parallel prefix scan → vectorized readout |
129
+ | Inference prefill | `use_cache=True, T>1` | O(T) parallel scan | Same as training + extracts final state S_T for cache |
130
+ | Inference decode | `use_cache=True, T=1` | **O(1)** monoid_op | Single `monoid_op` to fold new token into state → one matmul readout |
131
 
132
  ## Model Details
133
 
 
141
  | Layers | 16 |
142
  | Attention heads | 32 |
143
  | Head dimension | 64 |
144
+ | Decay gate | Vector decay, d=64 per head |
145
+ | State matrix per head | 64 × 64 = 4,096 floats |
146
  | Vocabulary | 128,256 (Llama-3.2 tokenizer) |
147
  | Precision | bfloat16 |
148
 
 
158
 
159
  ### Comparison with ~1B Baselines (acc_norm, 0-shot)
160
 
161
+ | Task | Spartacus-1B | TinyLlama-1.1B | Llama 3.2-1B | Mamba-1.4B | RWKV-6-1.6B |
162
  |---|---|---|---|---|---|
163
  | ARC-C | **0.3063** | 0.3268 | ~0.359 | 0.284 | ~0.301 |
164
  | ARC-E | **0.5518** | 0.5547 | ~0.752 | 0.512 | ~0.530 |
 
168
 
169
  > Spartacus achieves competitive performance with sub-quadratic models (Mamba, RWKV) while maintaining **O(1) inference time and memory per token**. Scores marked with ~ are approximate community-reported values.
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  ## Parallel Scan Implementation
172
 
173
+ The `monoid_scan_cuda.py` module provides a Triton JIT-compiled parallel prefix scan for the vector-decay monoid:
174
 
175
+ - **Grid**: `(B*H*D_k, ceil(D_v/BLOCK_DV))` -- one program per state matrix row
176
+ - **Forward**: Sequential scan along T per row, parallelized across all (batch, head, d_k) dimensions
177
+ - **Backward**: Reverse-order adjoint scan with per-row D_v reduction (minimal atomic_add)
178
  - **Fallback**: Pure PyTorch sequential scan for CPU/MPS
179
+ - **Auto-dispatch**: CUDA → Triton kernel, otherwise → PyTorch fallback
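
As a sanity-check sketch, a plain PyTorch reference with the documented shapes; on a CUDA machine the Triton-backed entry point should reproduce it (the commented import assumes `monoid_scan_cuda.py` is on the path):

```python
import torch

def reference_scan(log_decays: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    """log_decays: [B,H,T,D_k], values: [B,H,T,D_k,D_v] -> all prefix states [B,H,T,D_k,D_v]."""
    B, H, T, D_k, D_v = values.shape
    out = torch.empty_like(values)
    S = torch.zeros(B, H, D_k, D_v, dtype=values.dtype, device=values.device)
    for t in range(T):
        alpha = log_decays[:, :, t].exp().unsqueeze(-1)  # [B,H,D_k,1] row-wise decay
        S = alpha * S + values[:, :, t]
        out[:, :, t] = S
    return out

B, H, T, D_k, D_v = 2, 4, 16, 64, 64
log_decays = -torch.rand(B, H, T, D_k)
values = torch.randn(B, H, T, D_k, D_v)
ref = reference_scan(log_decays, values)

# from monoid_scan_cuda import parallel_scan   # on CUDA, should match the reference:
# assert torch.allclose(parallel_scan(log_decays.cuda(), values.cuda()).cpu(), ref, atol=1e-4)
```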
180
 
181
  ## Usage
182
 
 
203
 
204
  ```
205
  MonoidForCausalLM.py # Model architecture (MonoidConfig, MonoidAttention, MonoidForCausalLM)
206
+ monoid_scan_cuda.py # Triton JIT parallel prefix scan (vector decay) + PyTorch fallback
207
  model.safetensors # Model weights (bfloat16)
208
  config.json # Model configuration
209
  tokenizer.json # Llama-3.2 tokenizer
 
217
  author={NoesisLab},
218
  year={2025},
219
  url={https://huggingface.co/NoesisLab/Spartacus-1B-Instruct},
220
+ description={Replaces softmax attention with vector-decay monoid state compression for constant-time, constant-memory autoregressive generation}
221
  }
222
  ```
223
 
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a32bdaef8ff9aa5cba602faa280ab5d3526515a6abd97d411c3448f9b9ebcdc7
3
- size 2679277744
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d5cd463898c4ce262d12fe56c6227d0c1117680aa13892f9cac6e100a1db9077
3
+ size 2811462896
monoid_scan_cuda.py CHANGED
@@ -2,36 +2,33 @@
2
  monoid_scan_cuda.py — Triton CUDA JIT Accelerated Parallel Prefix Scan
3
  monoid_scan_cuda.py — Triton CUDA JIT 加速的并行前缀扫描
4
 
5
- This module implements the parallel prefix scan for the monoid recurrence:
6
- y_t = exp(log_decay_t) · y_{t-1} + x_t
7
- 本模块实现幺半群递推的并行前缀扫描:
8
- y_t = exp(log_decay_t) · y_{t-1} + x_t
9
 
10
  This is the computational backbone of Monoid Attention's state compression.
11
  这是幺半群注意力状态压缩的计算骨干。
12
 
13
- Why parallel prefix scan matters / 并行前缀扫描为什么重要:
14
- The monoid recurrence S_t = α_t·S_{t-1} + kv_t is inherently sequential.
15
- However, because (log_α, S) ⊕ (log_β, X) = (log_α+log_β, exp(log_β)·S+X)
16
- is ASSOCIATIVE, we can compute all prefix sums S_1..S_T via a parallel
17
- reduction tree in O(log T) depth instead of O(T) sequential steps.
18
- 幺半群递推 S_t = α_t·S_{t-1} + kv_t 本质上是串行的。
19
- 但因为 (log_α, S) ⊕ (log_β, X) = (log_α+log_β, exp(log_β)·S+X)
20
- 满足结合律, 我们可以通过并行归约树在 O(log T) 深度内计算所有前缀和 S_1..S_T,
21
- 而非 O(T) 的串行步骤。
22
-
23
- Training uses O(T) parallel scan (this file).
24
- Inference uses O(1) sequential monoid_op (in MonoidForCausalLM.py).
25
- 训练使用 O(T) 并行扫描 (本文件)。
26
- 推理使用 O(1) 串行 monoid_op (在 MonoidForCausalLM.py 中)。
27
 
28
  Implementation:
29
- Forward: sequential scan along T, parallelized across B*H*D on GPU.
 
 
30
  Backward: reverse-order adjoint scan for gradient computation.
 
31
  Auto-dispatches: CUDA → Triton kernel, CPU/MPS → PyTorch fallback.
32
 
33
- 前向: 沿 T 维顺序扫描, 跨 B*H*D 在 GPU 上并行。
 
34
  反向: 逆序伴随变量扫描计算梯度。
 
35
  自动分派: CUDA → Triton 核函数, CPU/MPS → PyTorch 回退。
36
  """
37
 
@@ -60,18 +57,18 @@ def _sequential_scan(log_decays: Tensor, values: Tensor) -> Tensor:
60
  Pure PyTorch sequential scan fallback (when no CUDA / Triton available).
61
  纯 PyTorch 串行扫描回退 (无 CUDA / Triton 时使用)。
62
 
63
- Implements the monoid recurrence step by step:
64
  acc_0 = 0
65
- acc_t = exp(log_decay_t) · acc_{t-1} + values_t
66
  This is O(T) sequential — correct but slow on GPU.
67
- 逐步实现幺半群递推:
68
  acc_0 = 0
69
- acc_t = exp(log_decay_t) · acc_{t-1} + values_t
70
  这是 O(T) 串行的 — 结果正确但在 GPU 上较慢。
71
 
72
  Args:
73
- log_decays: [B, H, T, 1] — log of per-head per-step decay gates
74
- 每头每步衰减门的对数
75
  values: [B, H, T, D_k, D_v] — outer products k_t⊗v_t to accumulate
76
  待累积的外积 k_t⊗v_t
77
  Returns:
@@ -84,17 +81,17 @@ def _sequential_scan(log_decays: Tensor, values: Tensor) -> Tensor:
84
  # acc 代表 S_t — 时刻 t 的压缩因果状态
85
  acc = torch.zeros(B, H, D_k, D_v, device=values.device, dtype=values.dtype)
86
  for t in range(T):
87
- # S_t = α_t · S_{t-1} + kv_t (the core monoid recurrence)
88
- # S_t = α_t · S_{t-1} + kv_t (核心幺半群递推)
89
- decay_t = torch.exp(log_decays[:, :, t]).unsqueeze(-1) # [B,H,1,1]
90
  acc = acc * decay_t + values[:, :, t]
91
  out[:, :, t] = acc
92
  return out
93
 
94
 
95
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
96
- # Triton Kernels — GPU-accelerated scan
97
- # Triton 核函数 — GPU 加速扫描
98
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
99
 
100
  if HAS_TRITON:
@@ -102,147 +99,136 @@ if HAS_TRITON:
102
  @triton.jit
103
  def _scan_fwd_kernel(
104
  LD_ptr, V_ptr, O_ptr,
105
- T, D,
106
- s_ld_bh, s_ld_t,
107
- s_v_bh, s_v_t, s_v_d,
108
- s_o_bh, s_o_t, s_o_d,
109
- BLOCK_D: tl.constexpr,
110
  ):
111
  """
112
- Forward scan kernel — computes all prefix states S_1..S_T.
113
- 前向扫描核函数 — 计算所有前缀状态 S_1..S_T。
114
 
115
  Parallelization strategy / 并行化策略:
116
- - program_id(0) = bh: one program per (batch, head) pair
117
- 每个 (batch, head) 对一个 program
118
- - program_id(1) = db: one program per D-dimension block
119
- 每个 D 维 block 一个 program
120
  - Sequential loop over T (the causal recurrence is inherently sequential)
121
  沿 T 维串行循环 (因果递推本质上是串行的)
122
 
123
- Each program computes: acc_t = exp(ld_t) * acc_{t-1} + val_t
124
- for a BLOCK_D-wide slice of the flattened d_k*d_v state matrix.
125
- 每个 program 计算展平的 d_k*d_v 状态矩阵的一个 BLOCK_D 宽的切片。
 
 
126
 
127
- Note: while the T-loop is sequential within each program,
128
- B*H*ceil(D/BLOCK_D) programs run in parallel on the GPU.
129
- 注意: 虽然 T 循环在每个 program 内是串行的,
130
- 但 B*H*ceil(D/BLOCK_D) 个 program 在 GPU 上并行运行。
131
  """
132
- bh = tl.program_id(0)
133
- db = tl.program_id(1)
134
- d_offs = db * BLOCK_D + tl.arange(0, BLOCK_D)
135
- d_mask = d_offs < D
136
 
137
- # acc = S_0 = 0 (identity element of the monoid)
138
- # acc = S_0 = 0 (幺半群的单位元)
139
- acc = tl.zeros([BLOCK_D], dtype=tl.float32)
140
 
141
- ld_base = LD_ptr + bh * s_ld_bh
142
- v_base = V_ptr + bh * s_v_bh
143
- o_base = O_ptr + bh * s_o_bh
144
 
145
  for t in range(T):
146
- # Load log_decay and compute decay = exp(log_α_t)
147
- # 加载 log_decay 并计算 decay = exp(log_α_t)
148
  ld_val = tl.load(ld_base + t * s_ld_t).to(tl.float32)
149
  decay = tl.exp(ld_val)
150
 
151
- # Load kv_t (a slice of the outer product k_t⊗v_t)
152
- # 加载 kv_t (外积 k_t⊗v_t 的一个切片)
153
  val = tl.load(
154
- v_base + t * s_v_t + d_offs * s_v_d,
155
- mask=d_mask, other=0.0,
156
  ).to(tl.float32)
157
 
158
- # Core recurrence: S_t = α_t · S_{t-1} + kv_t
159
- # 核心递推: S_t = α_t · S_{t-1} + kv_t
160
  acc = acc * decay + val
161
 
162
- # Store S_t
163
  tl.store(
164
- o_base + t * s_o_t + d_offs * s_o_d,
165
- acc, mask=d_mask,
166
  )
167
 
168
  @triton.jit
169
  def _scan_bwd_kernel(
170
  LD_ptr, O_ptr, GO_ptr, GV_ptr, GLD_ptr,
171
- T, D,
172
- s_ld_bh, s_ld_t,
173
- s_o_bh, s_o_t, s_o_d,
174
- s_go_bh, s_go_t, s_go_d,
175
- s_gv_bh, s_gv_t, s_gv_d,
176
- s_gld_bh, s_gld_t,
177
- BLOCK_D: tl.constexpr,
178
  ):
179
  """
180
- Backward scan kernel — computes gradients via adjoint method.
181
- 反向扫描核函数 — 通过伴随方法计算梯度。
182
-
183
- The forward recurrence is: y_t = a_t * y_{t-1} + x_t
184
- 前向递推: y_t = a_t * y_{t-1} + x_t
185
-
186
- The adjoint (reverse-time) recurrence for the Lagrange multiplier λ:
187
- λ_t = ∂L/∂y_t + a_{t+1} · λ_{t+1} (backward in time)
188
- 伴随 (逆时间) 递推的拉格朗日乘子 λ:
189
- λ_t = ∂L/∂y_t + a_{t+1} · λ_{t+1} (时间反向)
190
-
191
- Gradients / 梯度:
192
- ∂L/∂x_t = λ_t (gradient w.r.t. input values)
193
- (对输入值的梯度)
194
- ∂L/∂log_a_t = a_t · Σ_D(λ_t · y_{t-1}) (gradient w.r.t. log-decay)
195
- (对对数衰减的梯度)
196
-
197
- The gradient of log_decay is critical for training the decay gate:
198
- it tells the model how to adjust each head's forgetting rate.
199
- log_decay 的梯度对训练衰减门至关重要:
200
- 它告诉模型如何调整每个头的遗忘速率。
201
  """
202
- bh = tl.program_id(0)
203
- db = tl.program_id(1)
204
- d_offs = db * BLOCK_D + tl.arange(0, BLOCK_D)
205
- d_mask = d_offs < D
206
 
207
  # adj holds a_{t+1} · λ_{t+1}, initialized to 0 at the sequence end
208
  # adj 保存 a_{t+1} · λ_{t+1}, 在序列末尾初始化为 0
209
- adj = tl.zeros([BLOCK_D], dtype=tl.float32)
210
 
211
  for t_rev in range(T):
212
  t = T - 1 - t_rev # reverse time / 逆序时间
213
 
214
- # Load ∂L/∂y_t (upstream gradient)
215
- # 加载 ∂L/∂y_t (上游梯度)
216
  go = tl.load(
217
- GO_ptr + bh * s_go_bh + t * s_go_t + d_offs * s_go_d,
218
- mask=d_mask, other=0.0,
219
  ).to(tl.float32)
220
 
221
  # Adjoint: λ_t = ∂L/∂y_t + a_{t+1} · λ_{t+1}
222
  # 伴随: λ_t = ∂L/∂y_t + a_{t+1} · λ_{t+1}
223
  lam = go + adj
224
 
225
- # ∂L/∂x_t = λ_t (gradient of values / 值的梯度)
 
226
  tl.store(
227
- GV_ptr + bh * s_gv_bh + t * s_gv_t + d_offs * s_gv_d,
228
- lam, mask=d_mask,
229
  )
230
 
231
- # ∂L/∂log_a_t = a_t · Σ_D(λ_t · y_{t-1})
232
- # This gradient flows back to the decay gate (decay_proj),
233
- # teaching the model how to control causal information retention.
234
- # 此梯度回传到衰减门 (decay_proj),
235
- # 教模型如何控制因果信息的保留。
236
- ld_val = tl.load(LD_ptr + bh * s_ld_bh + t * s_ld_t).to(tl.float32)
237
  a_t = tl.exp(ld_val)
238
 
239
  if t > 0:
240
  y_prev = tl.load(
241
- O_ptr + bh * s_o_bh + (t - 1) * s_o_t + d_offs * s_o_d,
242
- mask=d_mask, other=0.0,
243
  ).to(tl.float32)
244
- grad_ld_partial = tl.sum(lam * y_prev) * a_t
245
- tl.atomic_add(GLD_ptr + bh * s_gld_bh + t * s_gld_t, grad_ld_partial)
246
 
247
  # Prepare for next step (t-1): adj = a_t · λ_t
248
  # 为下一步 (t-1) 准备: adj = a_t · λ_t
@@ -255,78 +241,91 @@ if HAS_TRITON:
255
 
256
  class _ParallelScanFn(Function):
257
  """
258
- Custom autograd function for the parallel prefix scan.
259
- 并行前缀扫描的自定义 autograd 函数。
260
 
261
  Forward: launches _scan_fwd_kernel to compute all prefix states.
 
262
  Backward: launches _scan_bwd_kernel to compute gradients via adjoint method.
 
263
 
264
  前向: 启动 _scan_fwd_kernel 计算所有前缀状态。
 
265
  反向: 启动 _scan_bwd_kernel 通过伴随方法计算梯度。
 
266
  """
267
  @staticmethod
268
  def forward(ctx, log_decays: Tensor, values: Tensor) -> Tensor:
269
  B, H, T, D_k, D_v = values.shape
270
- D = D_k * D_v # flattened state dimension / 展平的状态维度
271
 
272
- # Flatten: [B,H,T,1] → [BH, T], [B,H,T,Dk,Dv] → [BH, T, D]
273
- # 展平: [B,H,T,1] → [BH, T], [B,H,T,Dk,Dv] → [BH, T, D]
274
- ld_flat = log_decays.squeeze(-1).contiguous().reshape(B * H, T)
275
- v_flat = values.reshape(B * H, T, D).contiguous()
 
 
 
 
276
  o_flat = torch.empty_like(v_flat)
277
 
278
- BH = B * H
279
- BLOCK_D = min(triton.next_power_of_2(D), 1024)
280
- # Grid: (BH, ceil(D/BLOCK_D)) — one program per (batch*head, D-block)
281
- # 网格: (BH, ceil(D/BLOCK_D)) — 每个 (batch*head, D-block) 一个 program
282
- grid = (BH, triton.cdiv(D, BLOCK_D))
283
 
284
  _scan_fwd_kernel[grid](
285
  ld_flat, v_flat, o_flat,
286
- T, D,
287
  ld_flat.stride(0), ld_flat.stride(1),
288
  v_flat.stride(0), v_flat.stride(1), v_flat.stride(2),
289
  o_flat.stride(0), o_flat.stride(1), o_flat.stride(2),
290
- BLOCK_D=BLOCK_D,
291
  )
292
 
293
  # Save for backward: need log_decays and forward outputs y_t
294
  # 为反向传播保存: 需要 log_decays 和前向输出 y_t
295
  ctx.save_for_backward(ld_flat, o_flat)
296
- ctx.shape_info = (B, H, T, D_k, D_v, D, BH, BLOCK_D)
297
- return o_flat.reshape(B, H, T, D_k, D_v)
 
298
 
299
  @staticmethod
300
  def backward(ctx, grad_output: Tensor):
301
  ld_flat, o_flat = ctx.saved_tensors
302
- B, H, T, D_k, D_v, D, BH, BLOCK_D = ctx.shape_info
303
 
304
- go_flat = grad_output.reshape(BH, T, D).contiguous()
 
305
  gv_flat = torch.empty_like(go_flat)
306
- # Use f32 for atomic_add precision in gradient accumulation
307
- # 使用 f32 保证 atomic_add 梯度累积的精度
308
- gld_flat = torch.zeros(BH, T, device=ld_flat.device, dtype=torch.float32)
309
 
310
- grid = (BH, triton.cdiv(D, BLOCK_D))
311
 
312
  _scan_bwd_kernel[grid](
313
  ld_flat, o_flat, go_flat, gv_flat, gld_flat,
314
- T, D,
315
  ld_flat.stride(0), ld_flat.stride(1),
316
  o_flat.stride(0), o_flat.stride(1), o_flat.stride(2),
317
  go_flat.stride(0), go_flat.stride(1), go_flat.stride(2),
318
  gv_flat.stride(0), gv_flat.stride(1), gv_flat.stride(2),
319
  gld_flat.stride(0), gld_flat.stride(1),
320
- BLOCK_D=BLOCK_D,
321
  )
322
 
323
- grad_log_decays = gld_flat.to(grad_output.dtype).reshape(B, H, T, 1)
324
- grad_values = gv_flat.reshape(B, H, T, D_k, D_v)
 
 
 
 
325
  return grad_log_decays, grad_values
326
 
327
  def _triton_parallel_scan(log_decays: Tensor, values: Tensor) -> Tensor:
328
- """Triton-accelerated parallel scan entry point.
329
- Triton 加速的并行扫描入口。"""
330
  return _ParallelScanFn.apply(log_decays, values)
331
 
332
  else:
@@ -339,15 +338,16 @@ else:
339
 
340
  def parallel_scan(log_decays: Tensor, values: Tensor) -> Tensor:
341
  """
342
- Parallel prefix scan — computes all prefix monoid sums.
343
- 并行前缀扫描 — 计算所有前缀幺半群和。
344
 
345
  This is the training-time workhorse of Monoid Attention.
346
- It computes S_1, S_2, ..., S_T where S_t = α_t·S_{t-1} + kv_t,
 
347
  for ALL timesteps simultaneously.
348
  这是幺半群注意力训练时的主力计算。
349
  它同时计算所有时间步的 S_1, S_2, ..., S_T,
350
- 其中 S_t = α_t·S_{t-1} + kv_t。
351
 
352
  Auto-dispatches based on device:
353
  CUDA → Triton JIT kernel (fast, with custom backward)
@@ -357,8 +357,8 @@ def parallel_scan(log_decays: Tensor, values: Tensor) -> Tensor:
357
  CPU/MPS → PyTorch 串行扫描 (正确, 较慢)
358
 
359
  Args:
360
- log_decays: [B, H, T, 1] — log of decay gates α_t
361
- 衰减门 α_t 的对数
362
  values: [B, H, T, D_k, D_v] — outer products k_t⊗v_t
363
  外积 k_t⊗v_t
364
  Returns:
@@ -374,8 +374,8 @@ def parallel_scan_with_state(
374
  log_decays: Tensor, values: Tensor,
375
  ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
376
  """
377
- Parallel prefix scan + extract final state for inference handoff.
378
- 并行前缀扫描 + 提取最终状态用于推理切换。
379
 
380
  Used during prefill: compute all training-time prefix states,
381
  AND extract the final accumulated state S_T so that subsequent
@@ -389,22 +389,22 @@ def parallel_scan_with_state(
389
  这是训练模式 (并行扫描) 和推理模式 (串行 monoid_op) 之间的桥梁。
390
 
391
  Args:
392
- log_decays: [B, H, T, 1]
393
  values: [B, H, T, D_k, D_v]
394
 
395
  Returns:
396
  output: [B, H, T, D_k, D_v] — all prefix states S_1..S_T
397
  所有前缀状态
398
  final_state: (log_acc, S_T) where
399
- log_acc: [B, H, 1] — accumulated log-decay (for future monoid_op)
400
- 累积对数衰减 (供后续 monoid_op 使用)
401
- final_state: [B, H, D_k, D_v] — S_T, the compressed causal summary
402
- S_T, 压缩的因果摘要
403
  """
404
  output = parallel_scan(log_decays, values)
405
- # Sum all log-decays to get the total accumulated decay
406
- # 对所有 log-decay 求和得到总累积衰减
407
- log_acc = log_decays.squeeze(-1).sum(dim=2, keepdim=True) # [B, H, 1]
408
  # The last timestep's state IS the full causal summary
409
  # 最后一个时间步的状态就是完整的因果摘要
410
  final_state = output[:, :, -1] # [B, H, D_k, D_v]
 
2
  monoid_scan_cuda.py — Triton CUDA JIT Accelerated Parallel Prefix Scan
3
  monoid_scan_cuda.py — Triton CUDA JIT 加速的并行前缀扫描
4
 
5
+ This module implements the parallel prefix scan for the vector-decay monoid recurrence:
6
+ y_t[i,:] = exp(log_decay_t[i]) · y_{t-1}[i,:] + x_t[i,:]
7
+ 本模块实现向量衰减幺半群递推的并行前缀扫描:
8
+ y_t[i,:] = exp(log_decay_t[i]) · y_{t-1}[i,:] + x_t[i,:]
9
 
10
  This is the computational backbone of Monoid Attention's state compression.
11
  这是幺半群注意力状态压缩的计算骨干。
12
 
13
+ Vector decay: each dimension of the D_k×D_v state matrix has its own
14
+ per-dimension decay rate α_t ∈ ℝ^{D_k}, enabling different feature
15
+ dimensions to have independent memory lifetimes (fast-decaying for
16
+ local syntax, slow-decaying for global entity memory).
17
+ 向量衰减: D_k×D_v 状态矩阵的每个维度拥有独立的衰减率 α_t ∈ ℝ^{D_k},
18
+ 使不同特征维度拥有独立的记忆生命周期 (快速衰减用于局部语法, 慢速衰减用于全局实体记忆)。
 
 
 
 
 
 
 
 
19
 
20
  Implementation:
21
+ Forward: sequential scan along T, parallelized across B*H*D_k on GPU.
22
+ Each program handles one row of the state matrix (D_v elements)
23
+ with a scalar decay per row.
24
  Backward: reverse-order adjoint scan for gradient computation.
25
+ Per-row reduction for log_decay gradient (no atomic_add needed).
26
  Auto-dispatches: CUDA → Triton kernel, CPU/MPS → PyTorch fallback.
27
 
28
+ 前向: 沿 T 维顺序扫描, 跨 B*H*D_k 在 GPU 上并行。
29
+ 每个 program 处理状态矩阵的一行 (D_v 个元素), 每行一个标量衰减。
30
  反向: 逆序伴随变量扫描计算梯度。
31
+ 逐行归约计算 log_decay 梯度 (无需 atomic_add)。
32
  自动分派: CUDA → Triton 核函数, CPU/MPS → PyTorch 回退。
33
  """
34
 
 
57
  Pure PyTorch sequential scan fallback (when no CUDA / Triton available).
58
  纯 PyTorch 串行扫描回退 (无 CUDA / Triton 时使用)。
59
 
60
+ Implements the vector-decay monoid recurrence step by step:
61
  acc_0 = 0
62
+ acc_t[i,:] = exp(log_decay_t[i]) · acc_{t-1}[i,:] + values_t[i,:]
63
  This is O(T) sequential — correct but slow on GPU.
64
+ 逐步实现向量衰减幺半群递推:
65
  acc_0 = 0
66
+ acc_t[i,:] = exp(log_decay_t[i]) · acc_{t-1}[i,:] + values_t[i,:]
67
  这是 O(T) 串行的 — 结果正确但在 GPU 上较慢。
68
 
69
  Args:
70
+ log_decays: [B, H, T, D_k] — log of per-dimension per-step decay gates
71
+ 每维度每步衰减门的对数
72
  values: [B, H, T, D_k, D_v] — outer products k_t⊗v_t to accumulate
73
  待累积的外积 k_t⊗v_t
74
  Returns:
 
81
  # acc 代表 S_t — 时刻 t 的压缩因果状态
82
  acc = torch.zeros(B, H, D_k, D_v, device=values.device, dtype=values.dtype)
83
  for t in range(T):
84
+ # S_t = diag(α_t) · S_{t-1} + kv_t (vector decay monoid recurrence)
85
+ # S_t = diag(α_t) · S_{t-1} + kv_t (向量衰减幺半群递推)
86
+ decay_t = torch.exp(log_decays[:, :, t]).unsqueeze(-1) # [B,H,D_k,1]
87
  acc = acc * decay_t + values[:, :, t]
88
  out[:, :, t] = acc
89
  return out
90
 
91
 
92
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
93
+ # Triton Kernels — GPU-accelerated scan (vector decay)
94
+ # Triton 核函数 — GPU 加速扫描 (向量衰减)
95
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
96
 
97
  if HAS_TRITON:
 
99
  @triton.jit
100
  def _scan_fwd_kernel(
101
  LD_ptr, V_ptr, O_ptr,
102
+ T, D_v,
103
+ s_ld_bhdk, s_ld_t,
104
+ s_v_bhdk, s_v_t, s_v_dv,
105
+ s_o_bhdk, s_o_t, s_o_dv,
106
+ BLOCK_DV: tl.constexpr,
107
  ):
108
  """
109
+ Forward scan kernel — computes all prefix states S_1..S_T (vector decay).
110
+ 前向扫描核函数 — 计算所有前缀状态 S_1..S_T (向量衰减)
111
 
112
  Parallelization strategy / 并行化策略:
113
+ - program_id(0) = bhdk: one program per (batch, head, d_k row) triple
114
+ 每个 (batch, head, d_k 行) 三元组一个 program
115
+ - program_id(1) = dvb: one program per D_v-dimension block (typically 1 block)
116
+ 每个 D_v 维 block 一个 program (通常只有 1 个 block)
117
  - Sequential loop over T (the causal recurrence is inherently sequential)
118
  沿 T 维串行循环 (因果递推本质上是串行的)
119
 
120
+ Each program handles one row of the D_k×D_v state matrix, where the
121
+ decay is a single scalar per row. This eliminates the need for
122
+ row-index computation in the inner loop.
123
+ 每个 program 处理 D_k×D_v 状态矩阵的一行, 该行的衰减是一个标量。
124
+ 这消除了内循环中行索引计算的需要。
125
 
126
+ Grid: (B*H*D_k, ceil(D_v/BLOCK_DV))
127
+ 网格: (B*H*D_k, ceil(D_v/BLOCK_DV))
 
 
128
  """
129
+ bhdk = tl.program_id(0)
130
+ dvb = tl.program_id(1)
131
+ dv_offs = dvb * BLOCK_DV + tl.arange(0, BLOCK_DV)
132
+ dv_mask = dv_offs < D_v
133
 
134
+ # acc = S_0[row,:] = 0 (identity element of the monoid)
135
+ # acc = S_0[行,:] = 0 (幺半群的单位元)
136
+ acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
137
 
138
+ ld_base = LD_ptr + bhdk * s_ld_bhdk
139
+ v_base = V_ptr + bhdk * s_v_bhdk
140
+ o_base = O_ptr + bhdk * s_o_bhdk
141
 
142
  for t in range(T):
143
+ # Load scalar log_decay for this row at time t
144
+ # 加载此行在时刻 t 的标量 log_decay
145
  ld_val = tl.load(ld_base + t * s_ld_t).to(tl.float32)
146
  decay = tl.exp(ld_val)
147
 
148
+ # Load kv_t[row, :] (one row of the outer product)
149
+ # 加载 kv_t[行, :] (外积的一行)
150
  val = tl.load(
151
+ v_base + t * s_v_t + dv_offs * s_v_dv,
152
+ mask=dv_mask, other=0.0,
153
  ).to(tl.float32)
154
 
155
+ # Core recurrence: S_t[i,:] = α_t[i] · S_{t-1}[i,:] + kv_t[i,:]
156
+ # 核心递推: S_t[i,:] = α_t[i] · S_{t-1}[i,:] + kv_t[i,:]
157
  acc = acc * decay + val
158
 
159
+ # Store S_t[row, :]
160
  tl.store(
161
+ o_base + t * s_o_t + dv_offs * s_o_dv,
162
+ acc, mask=dv_mask,
163
  )
164
 
165
  @triton.jit
166
  def _scan_bwd_kernel(
167
  LD_ptr, O_ptr, GO_ptr, GV_ptr, GLD_ptr,
168
+ T, D_v,
169
+ s_ld_bhdk, s_ld_t,
170
+ s_o_bhdk, s_o_t, s_o_dv,
171
+ s_go_bhdk, s_go_t, s_go_dv,
172
+ s_gv_bhdk, s_gv_t, s_gv_dv,
173
+ s_gld_bhdk, s_gld_t,
174
+ BLOCK_DV: tl.constexpr,
175
  ):
176
  """
177
+ Backward scan kernel — computes gradients via adjoint method (vector decay).
178
+ 反向扫描核函数 — 通过伴随方法计算梯度 (向量衰减)。
179
+
180
+ Each program handles one row of the state matrix (one d_k dimension).
181
+ The decay for this row is a scalar, so the log_decay gradient is:
182
+ ∂L/∂log_α_t[i] = α_t[i] · Σ_j(λ_t[i,j] · y_{t-1}[i,j])
183
+ The sum over j (D_v) is reduced within this single program, so the atomic_add on the per-row gradient has minimal contention (a single writer when one D_v block covers the row).
184
+ 每个 program 处理状态矩阵的一行 (一个 d_k 维度)
185
+ 该行的衰减是标量, 因此 log_decay 梯度为:
186
+ ∂L/∂log_α_t[i] = α_t[i] · Σ_j(λ_t[i,j] · y_{t-1}[i,j])
187
+ 对 j (D_v) 的求和在单个 program 内完成, atomic_add 几乎没有竞争。
 
 
 
 
 
 
 
 
 
 
188
  """
189
+ bhdk = tl.program_id(0)
190
+ dvb = tl.program_id(1)
191
+ dv_offs = dvb * BLOCK_DV + tl.arange(0, BLOCK_DV)
192
+ dv_mask = dv_offs < D_v
193
 
194
  # adj holds a_{t+1} · λ_{t+1}, initialized to 0 at the sequence end
195
  # adj 保存 a_{t+1} · λ_{t+1}, 在序列末尾初始化为 0
196
+ adj = tl.zeros([BLOCK_DV], dtype=tl.float32)
197
 
198
  for t_rev in range(T):
199
  t = T - 1 - t_rev # reverse time / 逆序时间
200
 
201
+ # Load ∂L/∂y_t[row, :] (upstream gradient)
202
+ # 加载 ∂L/∂y_t[行, :] (上游梯度)
203
  go = tl.load(
204
+ GO_ptr + bhdk * s_go_bhdk + t * s_go_t + dv_offs * s_go_dv,
205
+ mask=dv_mask, other=0.0,
206
  ).to(tl.float32)
207
 
208
  # Adjoint: λ_t = ∂L/∂y_t + a_{t+1} · λ_{t+1}
209
  # 伴随: λ_t = ∂L/∂y_t + a_{t+1} · λ_{t+1}
210
  lam = go + adj
211
 
212
+ # ∂L/∂x_t[row,:] = λ_t (gradient of values)
213
+ # ∂L/∂x_t[行,:] = λ_t (值的梯度)
214
  tl.store(
215
+ GV_ptr + bhdk * s_gv_bhdk + t * s_gv_t + dv_offs * s_gv_dv,
216
+ lam, mask=dv_mask,
217
  )
218
 
219
+ # ∂L/∂log_α_t[i] = α_t[i] · Σ_j(λ_t[i,j] · y_{t-1}[i,j])
220
+ # Per-row scalar gradient: sum over D_v within this program.
221
+ # 逐行标量梯度: 在此 program 内对 D_v 求和。
222
+ ld_val = tl.load(LD_ptr + bhdk * s_ld_bhdk + t * s_ld_t).to(tl.float32)
 
 
223
  a_t = tl.exp(ld_val)
224
 
225
  if t > 0:
226
  y_prev = tl.load(
227
+ O_ptr + bhdk * s_o_bhdk + (t - 1) * s_o_t + dv_offs * s_o_dv,
228
+ mask=dv_mask, other=0.0,
229
  ).to(tl.float32)
230
+ grad_ld = tl.sum(lam * y_prev) * a_t
231
+ tl.atomic_add(GLD_ptr + bhdk * s_gld_bhdk + t * s_gld_t, grad_ld)
232
 
233
  # Prepare for next step (t-1): adj = a_t · λ_t
234
  # 为下一步 (t-1) 准备: adj = a_t · λ_t
 
241
 
242
  class _ParallelScanFn(Function):
243
  """
244
+ Custom autograd function for the parallel prefix scan (vector decay).
245
+ 并行前缀扫描的自定义 autograd 函数 (向量衰减)。
246
 
247
  Forward: launches _scan_fwd_kernel to compute all prefix states.
248
+ Grid: (B*H*D_k, ceil(D_v/BLOCK_DV)), one program per state row.
249
  Backward: launches _scan_bwd_kernel to compute gradients via adjoint method.
250
+ Per-row reduction eliminates most atomic_add overhead.
251
 
252
  前向: 启动 _scan_fwd_kernel 计算所有前缀状态。
253
+ 网格: (B*H*D_k, ceil(D_v/BLOCK_DV)), 每行状态一个 program。
254
  反向: 启动 _scan_bwd_kernel 通过伴随方法计算梯度。
255
+ 逐行归约消除大部分 atomic_add 开销。
256
  """
257
  @staticmethod
258
  def forward(ctx, log_decays: Tensor, values: Tensor) -> Tensor:
259
  B, H, T, D_k, D_v = values.shape
 
260
 
261
+ # Reshape for row-parallel kernel:
262
+ # log_decays: [B, H, T, D_k] → permute to [B, H, D_k, T] → [B*H*D_k, T]
263
+ # values: [B, H, T, D_k, D_v] → permute to [B, H, D_k, T, D_v] → [B*H*D_k, T, D_v]
264
+ # 为行并行核函数重塑:
265
+ # log_decays: [B, H, T, D_k] → 转置为 [B, H, D_k, T] → [B*H*D_k, T]
266
+ # values: [B, H, T, D_k, D_v] → 转置为 [B, H, D_k, T, D_v] → [B*H*D_k, T, D_v]
267
+ ld_flat = log_decays.permute(0, 1, 3, 2).contiguous().reshape(B * H * D_k, T)
268
+ v_flat = values.permute(0, 1, 3, 2, 4).contiguous().reshape(B * H * D_k, T, D_v)
269
  o_flat = torch.empty_like(v_flat)
270
 
271
+ BHDK = B * H * D_k
272
+ BLOCK_DV = min(triton.next_power_of_2(D_v), 1024)
273
+ # Grid: (B*H*D_k, ceil(D_v/BLOCK_DV)) — one program per (batch, head, row, dv-block)
274
+ # 网格: (B*H*D_k, ceil(D_v/BLOCK_DV))
275
+ grid = (BHDK, triton.cdiv(D_v, BLOCK_DV))
276
 
277
  _scan_fwd_kernel[grid](
278
  ld_flat, v_flat, o_flat,
279
+ T, D_v,
280
  ld_flat.stride(0), ld_flat.stride(1),
281
  v_flat.stride(0), v_flat.stride(1), v_flat.stride(2),
282
  o_flat.stride(0), o_flat.stride(1), o_flat.stride(2),
283
+ BLOCK_DV=BLOCK_DV,
284
  )
285
 
286
  # Save for backward: need log_decays and forward outputs y_t
287
  # 为反向传播保存: 需要 log_decays 和前向输出 y_t
288
  ctx.save_for_backward(ld_flat, o_flat)
289
+ ctx.shape_info = (B, H, T, D_k, D_v, BHDK, BLOCK_DV)
290
+ # Reshape back: [B*H*D_k, T, D_v] → [B, H, D_k, T, D_v] → [B, H, T, D_k, D_v]
291
+ return o_flat.reshape(B, H, D_k, T, D_v).permute(0, 1, 3, 2, 4).contiguous()
292
 
293
  @staticmethod
294
  def backward(ctx, grad_output: Tensor):
295
  ld_flat, o_flat = ctx.saved_tensors
296
+ B, H, T, D_k, D_v, BHDK, BLOCK_DV = ctx.shape_info
297
 
298
+ # Permute grad_output to match row-parallel layout: [B,H,T,D_k,D_v] → [B*H*D_k, T, D_v]
299
+ go_flat = grad_output.permute(0, 1, 3, 2, 4).contiguous().reshape(BHDK, T, D_v)
300
  gv_flat = torch.empty_like(go_flat)
301
+ # Use f32 for gradient accumulation precision
302
+ # 使用 f32 保证梯度累积的精度
303
+ gld_flat = torch.zeros(BHDK, T, device=ld_flat.device, dtype=torch.float32)
304
 
305
+ grid = (BHDK, triton.cdiv(D_v, BLOCK_DV))
306
 
307
  _scan_bwd_kernel[grid](
308
  ld_flat, o_flat, go_flat, gv_flat, gld_flat,
309
+ T, D_v,
310
  ld_flat.stride(0), ld_flat.stride(1),
311
  o_flat.stride(0), o_flat.stride(1), o_flat.stride(2),
312
  go_flat.stride(0), go_flat.stride(1), go_flat.stride(2),
313
  gv_flat.stride(0), gv_flat.stride(1), gv_flat.stride(2),
314
  gld_flat.stride(0), gld_flat.stride(1),
315
+ BLOCK_DV=BLOCK_DV,
316
  )
317
 
318
+ # Reshape gradients back to original layout
319
+ # 重塑梯度回原始布局
320
+ # gld: [B*H*D_k, T] → [B, H, D_k, T] → [B, H, T, D_k]
321
+ grad_log_decays = gld_flat.to(grad_output.dtype).reshape(B, H, D_k, T).permute(0, 1, 3, 2).contiguous()
322
+ # gv: [B*H*D_k, T, D_v] → [B, H, D_k, T, D_v] → [B, H, T, D_k, D_v]
323
+ grad_values = gv_flat.reshape(B, H, D_k, T, D_v).permute(0, 1, 3, 2, 4).contiguous()
324
  return grad_log_decays, grad_values
325
 
326
  def _triton_parallel_scan(log_decays: Tensor, values: Tensor) -> Tensor:
327
+ """Triton-accelerated parallel scan entry point (vector decay).
328
+ Triton 加速的并行扫描入口 (向量衰减)。"""
329
  return _ParallelScanFn.apply(log_decays, values)
330
 
331
  else:
 
338
 
339
  def parallel_scan(log_decays: Tensor, values: Tensor) -> Tensor:
340
  """
341
+ Parallel prefix scan — computes all prefix monoid sums (vector decay).
342
+ 并行前缀扫描 — 计算所有前缀幺半群和 (向量衰减)。
343
 
344
  This is the training-time workhorse of Monoid Attention.
345
+ It computes S_1, S_2, ..., S_T where
346
+ S_t[i,:] = α_t[i]·S_{t-1}[i,:] + kv_t[i,:]
347
  for ALL timesteps simultaneously.
348
  这是幺半群注意力训练时的主力计算。
349
  它同时计算所有时间步的 S_1, S_2, ..., S_T,
350
+ 其中 S_t[i,:] = α_t[i]·S_{t-1}[i,:] + kv_t[i,:]
351
 
352
  Auto-dispatches based on device:
353
  CUDA → Triton JIT kernel (fast, with custom backward)
 
357
  CPU/MPS → PyTorch 串行扫描 (正确, 较慢)
358
 
359
  Args:
360
+ log_decays: [B, H, T, D_k] — log of per-dimension decay gates α_t
361
+ 每维度衰减门 α_t 的对数
362
  values: [B, H, T, D_k, D_v] — outer products k_t⊗v_t
363
  外积 k_t⊗v_t
364
  Returns:
 
374
  log_decays: Tensor, values: Tensor,
375
  ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
376
  """
377
+ Parallel prefix scan + extract final state for inference handoff (vector decay).
378
+ 并行前缀扫描 + 提取最终状态用于推理切换 (向量衰减)。
379
 
380
  Used during prefill: compute all training-time prefix states,
381
  AND extract the final accumulated state S_T so that subsequent
 
389
  这是训练模式 (并行扫描) 和推理模式 (串行 monoid_op) 之间的桥梁。
390
 
391
  Args:
392
+ log_decays: [B, H, T, D_k]
393
  values: [B, H, T, D_k, D_v]
394
 
395
  Returns:
396
  output: [B, H, T, D_k, D_v] — all prefix states S_1..S_T
397
  所有前缀状态
398
  final_state: (log_acc, S_T) where
399
+ log_acc: [B, H, D_k] — accumulated log-decay vector (for future monoid_op)
400
+ 累积对数衰减向量 (供后续 monoid_op 使用)
401
+ final_state: [B, H, D_k, D_v] — S_T, the compressed causal summary
402
+ S_T, 压缩的因果摘要
403
  """
404
  output = parallel_scan(log_decays, values)
405
+ # Sum all log-decays over T to get the total accumulated decay per dimension
406
+ # 对所有 log-decay 沿 T 求和得到每个维度的总累积衰减
407
+ log_acc = log_decays.sum(dim=2) # [B, H, D_k]
408
  # The last timestep's state IS the full causal summary
409
  # 最后一个时间步的状态就是完整的因果摘要
410
  final_state = output[:, :, -1] # [B, H, D_k, D_v]
training_args.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:785ade819d78869db7f691e74573b3ce4183646d4e45ad2e1d0a940564b6b20f
3
- size 6289
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:32938ebb880dc58fa7d6f8e45383c55e1d5d4352618531d62a28069918595445
3
+ size 6417