OzTianlu commited on
Commit
39fa0da
·
verified ·
1 Parent(s): b0097d1

Upload MonoidForCausalLM.py

Browse files
Files changed (1) hide show
  1. MonoidForCausalLM.py +33 -1
MonoidForCausalLM.py CHANGED
@@ -62,7 +62,39 @@ from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin, Aut
62
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
63
  from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm
64
 
65
- from monoid_scan_cuda import parallel_scan, parallel_scan_with_state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
 
68
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
 
62
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
63
  from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm
64
 
65
try:
    from monoid_scan_cuda import parallel_scan, parallel_scan_with_state
except ImportError:
    # Pure-PyTorch fallback (sequential scan) — works on CPU / MPS / any device.
    # Slower than the fused CUDA kernel but numerically identical.

    def parallel_scan_with_state(log_alpha: Tensor, kv: Tensor):
        """Sequential prefix scan that also returns the final carry state.

        Computes S_t = exp(log_alpha_t) * S_{t-1} + kv_t for t = 0..T-1,
        with S_{-1} = 0, and accumulates the total log decay.

        Args:
            log_alpha: per-step log decay factors, indexed as
                log_alpha[:, :, t]. (The original comment suggests each
                step slice is [B, H, 1], i.e. log_alpha is [B, H, T, 1] —
                TODO confirm against the CUDA kernel's contract.)
            kv: [B, H, T, d1, d2] per-step updates to accumulate.

        Returns:
            (states, (log_acc, S)) where states[:, :, t] == S_t,
            log_acc is the summed log decay over all T steps, and
            S is the final state S_{T-1} (zeros when T == 0).
        """
        B, H, T, d1, d2 = kv.shape
        states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype)
        S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype)
        log_acc = torch.zeros(B, H, 1, device=log_alpha.device, dtype=log_alpha.dtype)
        for t in range(T):
            decay = torch.exp(log_alpha[:, :, t])
            # Right-pad singleton dims so decay broadcasts against [.., d1, d2].
            while decay.dim() < S.dim():
                decay = decay.unsqueeze(-1)
            S = S * decay + kv[:, :, t]
            states[:, :, t] = S
            log_acc = log_acc + log_alpha[:, :, t]
        return states, (log_acc, S)

    def parallel_scan(log_alpha: Tensor, kv: Tensor) -> Tensor:
        """Sequential prefix scan fallback: S_t = exp(log_alpha_t)*S_{t-1} + kv_t.

        Thin wrapper over parallel_scan_with_state that discards the carry.
        Delegating (instead of duplicating the loop, as before) keeps the
        two fallbacks numerically in lockstep by construction.
        """
        states, _ = parallel_scan_with_state(log_alpha, kv)
        return states
98
 
99
 
100
  # ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━