"""
monoid_scan_cuda.py — Triton CUDA JIT Accelerated Parallel Prefix Scan
monoid_scan_cuda.py — Triton CUDA JIT 加速的并行前缀扫描
This module implements the parallel prefix scan for the vector-decay monoid recurrence:
y_t[i,:] = exp(log_decay_t[i]) · y_{t-1}[i,:] + x_t[i,:]
本模块实现向量衰减幺半群递推的并行前缀扫描:
y_t[i,:] = exp(log_decay_t[i]) · y_{t-1}[i,:] + x_t[i,:]
This is the computational backbone of Monoid Attention's state compression.
这是幺半群注意力状态压缩的计算骨干。
Vector decay: each dimension of the D_k×D_v state matrix has its own
per-dimension decay rate α_t ∈ ℝ^{D_k}, enabling different feature
dimensions to have independent memory lifetimes (fast-decaying for
local syntax, slow-decaying for global entity memory).
向量衰减: D_k×D_v 状态矩阵的每个维度拥有独立的衰减率 α_t ∈ ℝ^{D_k},
使不同特征维度拥有独立的记忆生命周期 (快速衰减用于局部语法, 慢速衰减用于全局实体记忆)。
Implementation:
Forward: sequential scan along T, parallelized across B*H*D_k on GPU.
Each program handles one row of the state matrix (D_v elements)
with a scalar decay per row.
Backward: reverse-order adjoint scan for gradient computation.
Per-row reduction for log_decay gradient (no atomic_add needed).
Auto-dispatches: CUDA → Triton kernel, CPU/MPS → PyTorch fallback.
前向: 沿 T 维顺序扫描, 跨 B*H*D_k 在 GPU 上并行。
每个 program 处理状态矩阵的一行 (D_v 个元素), 每行一个标量衰减。
反向: 逆序伴随变量扫描计算梯度。
逐行归约计算 log_decay 梯度 (无需 atomic_add)。
自动分派: CUDA → Triton 核函数, CPU/MPS → PyTorch 回退。
"""
from __future__ import annotations
import torch
from torch import Tensor
from torch.autograd import Function
from typing import Tuple
try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Fallback: pure PyTorch sequential scan (CPU / MPS / no Triton)
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
def _sequential_scan(log_decays: Tensor, values: Tensor) -> Tensor:
"""
Pure PyTorch sequential scan fallback (when no CUDA / Triton available).
纯 PyTorch 串行扫描回退 (无 CUDA / Triton 时使用)。
Implements the vector-decay monoid recurrence step by step:
acc_0 = 0
acc_t[i,:] = exp(log_decay_t[i]) · acc_{t-1}[i,:] + values_t[i,:]
This is O(T) sequential — correct but slow on GPU.
逐步实现向量衰减幺半群递推:
acc_0 = 0
acc_t[i,:] = exp(log_decay_t[i]) · acc_{t-1}[i,:] + values_t[i,:]
这是 O(T) 串行的 — 结果正确但在 GPU 上较慢。
Args:
log_decays: [B, H, T, D_k] — log of per-dimension per-step decay gates
每维度每步衰减门的对数
values: [B, H, T, D_k, D_v] — outer products k_t⊗v_t to accumulate
待累积的外积 k_t⊗v_t
Returns:
output: [B, H, T, D_k, D_v] — all prefix states S_1, ..., S_T
所有前缀状态 S_1, ..., S_T
"""
B, H, T, D_k, D_v = values.shape
out = torch.empty_like(values)
    # acc represents S_t — the compressed causal state at time t
acc = torch.zeros(B, H, D_k, D_v, device=values.device, dtype=values.dtype)
for t in range(T):
        # S_t = diag(α_t) · S_{t-1} + kv_t (vector-decay monoid recurrence)
decay_t = torch.exp(log_decays[:, :, t]).unsqueeze(-1) # [B,H,D_k,1]
acc = acc * decay_t + values[:, :, t]
out[:, :, t] = acc
return out
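# Tiny self-contained sanity sketch for the fallback (illustrative helper, not part
# of the original module and not called anywhere; shapes and tolerance are arbitrary).
# For T = 2 the scan must satisfy out[:, :, 1] == exp(log_decays[:, :, 1]) · out[:, :, 0] + values[:, :, 1].
def _sanity_check_sequential_scan() -> bool:
    B, H, T, D_k, D_v = 1, 1, 2, 3, 4
    # decays in (0, 1] → log decays in (-inf, 0]
    log_decays = torch.log(torch.rand(B, H, T, D_k).clamp_min(1e-3))
    values = torch.randn(B, H, T, D_k, D_v)
    out = _sequential_scan(log_decays, values)
    expected = torch.exp(log_decays[:, :, 1]).unsqueeze(-1) * out[:, :, 0] + values[:, :, 1]
    return bool(torch.allclose(out[:, :, 1], expected, atol=1e-6))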
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Triton Kernels — GPU-accelerated scan (vector decay)
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
if HAS_TRITON:
@triton.jit
def _scan_fwd_kernel(
LD_ptr, V_ptr, O_ptr,
T, D_v,
s_ld_bhdk, s_ld_t,
s_v_bhdk, s_v_t, s_v_dv,
s_o_bhdk, s_o_t, s_o_dv,
BLOCK_DV: tl.constexpr,
):
"""
Forward scan kernel — computes all prefix states S_1..S_T (vector decay).
前向扫描核函数 — 计算所有前缀状态 S_1..S_T (向量衰减)。
Parallelization strategy / 并行化策略:
- program_id(0) = bhdk: one program per (batch, head, d_k row) triple
每个 (batch, head, d_k 行) 三元组一个 program
- program_id(1) = dvb: one program per D_v-dimension block (typically 1 block)
每个 D_v 维 block 一个 program (通常只有 1 个 block)
- Sequential loop over T (the causal recurrence is inherently sequential)
沿 T 维串行循环 (因果递推本质上是串行的)
Each program handles one row of the D_k×D_v state matrix, where the
decay is a single scalar per row. This eliminates the need for
row-index computation in the inner loop.
每个 program 处理 D_k×D_v 状态矩阵的一行, 该行的衰减是一个标量。
这消除了内循环中行索引计算的需要。
Grid: (B*H*D_k, ceil(D_v/BLOCK_DV))
网格: (B*H*D_k, ceil(D_v/BLOCK_DV))
"""
bhdk = tl.program_id(0)
dvb = tl.program_id(1)
dv_offs = dvb * BLOCK_DV + tl.arange(0, BLOCK_DV)
dv_mask = dv_offs < D_v
        # acc = S_0[row,:] = 0 (identity element of the monoid)
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
ld_base = LD_ptr + bhdk * s_ld_bhdk
v_base = V_ptr + bhdk * s_v_bhdk
o_base = O_ptr + bhdk * s_o_bhdk
for t in range(T):
            # Load the scalar log_decay for this row at time t
ld_val = tl.load(ld_base + t * s_ld_t).to(tl.float32)
decay = tl.exp(ld_val)
            # Load kv_t[row, :] (one row of the outer product)
val = tl.load(
v_base + t * s_v_t + dv_offs * s_v_dv,
mask=dv_mask, other=0.0,
).to(tl.float32)
            # Core recurrence: S_t[i,:] = α_t[i] · S_{t-1}[i,:] + kv_t[i,:]
acc = acc * decay + val
# Store S_t[row, :]
tl.store(
o_base + t * s_o_t + dv_offs * s_o_dv,
acc, mask=dv_mask,
)
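    # Index-mapping sketch for the grid above (illustrative only): a flattened row
    # index bhdk in [0, B*H*D_k) corresponds to
    #     b = bhdk // (H * D_k), h = (bhdk // D_k) % H, row i = bhdk % D_k,
    # and dv block dvb covers columns [dvb*BLOCK_DV, min((dvb+1)*BLOCK_DV, D_v)).
    # The kernels never need this decomposition explicitly, because the host code
    # pre-flattens the tensors to [B*H*D_k, T, ...] and passes strides.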
@triton.jit
def _scan_bwd_kernel(
LD_ptr, O_ptr, GO_ptr, GV_ptr, GLD_ptr,
T, D_v,
s_ld_bhdk, s_ld_t,
s_o_bhdk, s_o_t, s_o_dv,
s_go_bhdk, s_go_t, s_go_dv,
s_gv_bhdk, s_gv_t, s_gv_dv,
s_gld_bhdk, s_gld_t,
BLOCK_DV: tl.constexpr,
):
"""
Backward scan kernel — computes gradients via adjoint method (vector decay).
反向扫描核函数 — 通过伴随方法计算梯度 (向量衰减)。
Each program handles one row of the state matrix (one d_k dimension).
The decay for this row is a scalar, so the log_decay gradient is:
∂L/∂log_α_t[i] = α_t[i] · Σ_j(λ_t[i,j] · y_{t-1}[i,j])
The sum over j (D_v) is computed within this single program — no atomic_add.
每个 program 处理状态矩阵的一行 (一个 d_k 维度)。
该行的衰减是标量, 因此 log_decay 梯度为:
∂L/∂log_α_t[i] = α_t[i] · Σ_j(λ_t[i,j] · y_{t-1}[i,j])
对 j (D_v) 的求和在单个 program 内完成 — 无需 atomic_add。
"""
bhdk = tl.program_id(0)
dvb = tl.program_id(1)
dv_offs = dvb * BLOCK_DV + tl.arange(0, BLOCK_DV)
dv_mask = dv_offs < D_v
        # adj holds α_{t+1} · λ_{t+1}, initialized to 0 at the sequence end
adj = tl.zeros([BLOCK_DV], dtype=tl.float32)
for t_rev in range(T):
            t = T - 1 - t_rev  # traverse time in reverse order
            # Load ∂L/∂y_t[row, :] (upstream gradient)
go = tl.load(
GO_ptr + bhdk * s_go_bhdk + t * s_go_t + dv_offs * s_go_dv,
mask=dv_mask, other=0.0,
).to(tl.float32)
            # Adjoint: λ_t = ∂L/∂y_t + α_{t+1} · λ_{t+1}
lam = go + adj
            # ∂L/∂x_t[row,:] = λ_t (gradient of values)
tl.store(
GV_ptr + bhdk * s_gv_bhdk + t * s_gv_t + dv_offs * s_gv_dv,
lam, mask=dv_mask,
)
            # ∂L/∂log_α_t[i] = α_t[i] · Σ_j(λ_t[i,j] · y_{t-1}[i,j])
            # Per-row scalar gradient: sum over this program's D_v block, then
            # atomic_add in case D_v is split across multiple blocks.
ld_val = tl.load(LD_ptr + bhdk * s_ld_bhdk + t * s_ld_t).to(tl.float32)
a_t = tl.exp(ld_val)
if t > 0:
y_prev = tl.load(
O_ptr + bhdk * s_o_bhdk + (t - 1) * s_o_t + dv_offs * s_o_dv,
mask=dv_mask, other=0.0,
).to(tl.float32)
grad_ld = tl.sum(lam * y_prev) * a_t
tl.atomic_add(GLD_ptr + bhdk * s_gld_bhdk + t * s_gld_t, grad_ld)
            # Prepare for the next step (t-1): adj = α_t · λ_t
adj = a_t * lam
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    # Autograd Function — bridges the Triton kernels with PyTorch autograd
    # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
class _ParallelScanFn(Function):
"""
Custom autograd function for the parallel prefix scan (vector decay).
并行前缀扫描的自定义 autograd 函数 (向量衰减)。
Forward: launches _scan_fwd_kernel to compute all prefix states.
Grid: (B*H*D_k, ceil(D_v/BLOCK_DV)), one program per state row.
Backward: launches _scan_bwd_kernel to compute gradients via adjoint method.
Per-row reduction eliminates most atomic_add overhead.
前向: 启动 _scan_fwd_kernel 计算所有前缀状态。
网格: (B*H*D_k, ceil(D_v/BLOCK_DV)), 每行状态一个 program。
反向: 启动 _scan_bwd_kernel 通过伴随方法计算梯度。
逐行归约消除大部分 atomic_add 开销。
"""
@staticmethod
def forward(ctx, log_decays: Tensor, values: Tensor) -> Tensor:
B, H, T, D_k, D_v = values.shape
            # Reshape for the row-parallel kernel:
            #   log_decays: [B, H, T, D_k]      → permute to [B, H, D_k, T]      → [B*H*D_k, T]
            #   values:     [B, H, T, D_k, D_v] → permute to [B, H, D_k, T, D_v] → [B*H*D_k, T, D_v]
ld_flat = log_decays.permute(0, 1, 3, 2).contiguous().reshape(B * H * D_k, T)
v_flat = values.permute(0, 1, 3, 2, 4).contiguous().reshape(B * H * D_k, T, D_v)
o_flat = torch.empty_like(v_flat)
BHDK = B * H * D_k
BLOCK_DV = min(triton.next_power_of_2(D_v), 1024)
            # Grid: (B*H*D_k, ceil(D_v/BLOCK_DV)) — one program per (batch, head, row, dv-block)
grid = (BHDK, triton.cdiv(D_v, BLOCK_DV))
_scan_fwd_kernel[grid](
ld_flat, v_flat, o_flat,
T, D_v,
ld_flat.stride(0), ld_flat.stride(1),
v_flat.stride(0), v_flat.stride(1), v_flat.stride(2),
o_flat.stride(0), o_flat.stride(1), o_flat.stride(2),
BLOCK_DV=BLOCK_DV,
)
            # Save for backward: we need log_decays and the forward outputs y_t
ctx.save_for_backward(ld_flat, o_flat)
ctx.shape_info = (B, H, T, D_k, D_v, BHDK, BLOCK_DV)
# Reshape back: [B*H*D_k, T, D_v] → [B, H, D_k, T, D_v] → [B, H, T, D_k, D_v]
return o_flat.reshape(B, H, D_k, T, D_v).permute(0, 1, 3, 2, 4).contiguous()
@staticmethod
def backward(ctx, grad_output: Tensor):
ld_flat, o_flat = ctx.saved_tensors
B, H, T, D_k, D_v, BHDK, BLOCK_DV = ctx.shape_info
# Permute grad_output to match row-parallel layout: [B,H,T,D_k,D_v] → [B*H*D_k, T, D_v]
go_flat = grad_output.permute(0, 1, 3, 2, 4).contiguous().reshape(BHDK, T, D_v)
gv_flat = torch.empty_like(go_flat)
            # Accumulate the log_decay gradient in float32 for precision
gld_flat = torch.zeros(BHDK, T, device=ld_flat.device, dtype=torch.float32)
grid = (BHDK, triton.cdiv(D_v, BLOCK_DV))
_scan_bwd_kernel[grid](
ld_flat, o_flat, go_flat, gv_flat, gld_flat,
T, D_v,
ld_flat.stride(0), ld_flat.stride(1),
o_flat.stride(0), o_flat.stride(1), o_flat.stride(2),
go_flat.stride(0), go_flat.stride(1), go_flat.stride(2),
gv_flat.stride(0), gv_flat.stride(1), gv_flat.stride(2),
gld_flat.stride(0), gld_flat.stride(1),
BLOCK_DV=BLOCK_DV,
)
            # Reshape gradients back to the original layout
            #   gld: [B*H*D_k, T] → [B, H, D_k, T] → [B, H, T, D_k]
grad_log_decays = gld_flat.to(grad_output.dtype).reshape(B, H, D_k, T).permute(0, 1, 3, 2).contiguous()
# gv: [B*H*D_k, T, D_v] → [B, H, D_k, T, D_v] → [B, H, T, D_k, D_v]
grad_values = gv_flat.reshape(B, H, D_k, T, D_v).permute(0, 1, 3, 2, 4).contiguous()
return grad_log_decays, grad_values
def _triton_parallel_scan(log_decays: Tensor, values: Tensor) -> Tensor:
"""Triton-accelerated parallel scan entry point (vector decay).
Triton 加速的并行扫描入口 (向量衰减)。"""
return _ParallelScanFn.apply(log_decays, values)
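    # Correctness sketch comparing the Triton path against the PyTorch fallback on
    # small random CUDA tensors (illustrative helper, not part of the original module
    # and not called anywhere; shapes, gate parameterization, and tolerance are
    # assumptions made for this example).
    def _check_triton_vs_fallback(atol: float = 1e-4) -> bool:
        B, H, T, D_k, D_v = 1, 2, 8, 4, 4
        log_decays = torch.nn.functional.logsigmoid(
            torch.randn(B, H, T, D_k, device="cuda")
        )
        values = torch.randn(B, H, T, D_k, D_v, device="cuda")
        ref = _sequential_scan(log_decays, values)
        out = _ParallelScanFn.apply(log_decays, values)
        return bool(torch.allclose(out, ref, atol=atol))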
else:
_triton_parallel_scan = None
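# Pure-PyTorch adjoint reference for the backward pass (a gradient-checking sketch
# mirroring the math in _scan_bwd_kernel; it is not wired into autograd and is not
# used by the module):
#     λ_t = ∂L/∂y_t + α_{t+1}·λ_{t+1},   ∂L/∂x_t = λ_t,
#     ∂L/∂log_α_t[i] = α_t[i] · Σ_j λ_t[i,j] · y_{t-1}[i,j]
def _reference_scan_backward(
    log_decays: Tensor, values: Tensor, grad_output: Tensor,
) -> Tuple[Tensor, Tensor]:
    B, H, T, D_k, D_v = values.shape
    y = _sequential_scan(log_decays, values)                 # forward states y_1..y_T
    grad_values = torch.empty_like(values)
    grad_log_decays = torch.zeros_like(log_decays)
    adj = torch.zeros(B, H, D_k, D_v, device=values.device, dtype=values.dtype)
    for t in reversed(range(T)):
        lam = grad_output[:, :, t] + adj                     # λ_t = ∂L/∂y_t + α_{t+1}·λ_{t+1}
        grad_values[:, :, t] = lam                           # ∂L/∂x_t = λ_t
        a_t = torch.exp(log_decays[:, :, t]).unsqueeze(-1)   # [B, H, D_k, 1]
        if t > 0:
            # ∂L/∂log_α_t[i] = α_t[i] · Σ_j λ_t[i,j] · y_{t-1}[i,j]
            grad_log_decays[:, :, t] = (a_t * lam * y[:, :, t - 1]).sum(dim=-1)
        adj = a_t * lam                                      # carry α_t·λ_t to step t-1
    return grad_log_decays, grad_values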
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Public API
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
def parallel_scan(log_decays: Tensor, values: Tensor) -> Tensor:
"""
Parallel prefix scan — computes all prefix monoid sums (vector decay).
并行前缀扫描 — 计算所有前缀幺半群和 (向量衰减)。
This is the training-time workhorse of Monoid Attention.
It computes S_1, S_2, ..., S_T where
S_t[i,:] = α_t[i]·S_{t-1}[i,:] + kv_t[i,:]
for ALL timesteps simultaneously.
这是幺半群注意力训练时的主力计算。
它同时计算所有时间步的 S_1, S_2, ..., S_T,
其中 S_t[i,:] = α_t[i]·S_{t-1}[i,:] + kv_t[i,:]。
Auto-dispatches based on device:
CUDA → Triton JIT kernel (fast, with custom backward)
CPU/MPS → PyTorch sequential scan (correct, slower)
根据设备自动分派:
CUDA → Triton JIT 核函数 (快速, 带自定义反向传播)
CPU/MPS → PyTorch 串行扫描 (正确, 较慢)
Args:
log_decays: [B, H, T, D_k] — log of per-dimension decay gates α_t
每维度衰减门 α_t 的对数
values: [B, H, T, D_k, D_v] — outer products k_t⊗v_t
外积 k_t⊗v_t
Returns:
states: [B, H, T, D_k, D_v] — all prefix states S_1..S_T
所有前缀状态 S_1..S_T
"""
if _triton_parallel_scan is not None and values.is_cuda:
return _triton_parallel_scan(log_decays, values)
return _sequential_scan(log_decays, values)
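# Minimal usage sketch for parallel_scan (illustrative helper, not part of the
# original module; the shapes and the logsigmoid gate parameterization are
# assumptions made for this example, not requirements of the API).
def _example_parallel_scan() -> Tensor:
    B, H, T, D_k, D_v = 2, 4, 16, 8, 8
    k = torch.randn(B, H, T, D_k)
    v = torch.randn(B, H, T, D_v)
    # Decay gates in (0, 1): log α_t = logsigmoid(gate logits)
    log_decays = torch.nn.functional.logsigmoid(torch.randn(B, H, T, D_k))
    kv = k.unsqueeze(-1) * v.unsqueeze(-2)   # [B, H, T, D_k, D_v] outer products k_t⊗v_t
    return parallel_scan(log_decays, kv)     # [B, H, T, D_k, D_v] prefix states S_1..S_T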
def parallel_scan_with_state(
log_decays: Tensor, values: Tensor,
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
"""
Parallel prefix scan + extract final state for inference handoff (vector decay).
并行前缀扫描 + 提取最终状态用于推理切换 (向量衰减)。
Used during prefill: compute all training-time prefix states,
AND extract the final accumulated state S_T so that subsequent
tokens can be generated in O(1) RNN mode via monoid_op.
在预填充时使用: 计算所有训练时的前缀状态,
同时提取最终累积状态 S_T, 以便后续 token 可以
通过 monoid_op 以 O(1) RNN 模式生成。
This is the bridge between training mode (parallel scan)
and inference mode (sequential monoid_op).
这是训练模式 (并行扫描) 和推理模式 (串行 monoid_op) 之间的桥梁。
Args:
log_decays: [B, H, T, D_k]
values: [B, H, T, D_k, D_v]
Returns:
output: [B, H, T, D_k, D_v] — all prefix states S_1..S_T
所有前缀状态
final_state: (log_acc, S_T) where
log_acc: [B, H, D_k] — accumulated log-decay vector (for future monoid_op)
累积对数衰减向量 (供后续 monoid_op 使用)
final_state: [B, H, D_k, D_v] — S_T, the compressed causal summary
S_T, 压缩的因果摘要
"""
output = parallel_scan(log_decays, values)
    # Sum all log-decays over T to get the total accumulated decay per dimension
log_acc = log_decays.sum(dim=2) # [B, H, D_k]
    # The last timestep's state IS the full causal summary
final_state = output[:, :, -1] # [B, H, D_k, D_v]
return output, (log_acc, final_state)
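# Prefill → decode handoff sketch (illustrative helper, not part of the original
# module). The real decode path composes states with monoid_op, which lives in
# another module, so the single recurrent step below is written out explicitly;
# the function name and the new-token inputs are assumptions made for this example.
def _example_prefill_then_decode_step(
    log_decays: Tensor,     # [B, H, T, D_k]      prompt decay gates (log)
    values: Tensor,         # [B, H, T, D_k, D_v] prompt outer products
    new_log_decay: Tensor,  # [B, H, D_k]         decay gate (log) for the new token
    new_value: Tensor,      # [B, H, D_k, D_v]    outer product for the new token
) -> Tensor:
    _, (log_acc, state) = parallel_scan_with_state(log_decays, values)
    # log_acc would be carried alongside the state when composing with monoid_op;
    # here we only need S_T for one explicit O(1) recurrent step:
    #     S_{T+1} = diag(α_{T+1}) · S_T + kv_{T+1}
    new_state = torch.exp(new_log_decay).unsqueeze(-1) * state + new_value
    return new_state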