File size: 43,968 Bytes
b0097d1 b6c0790 b0097d1 b6c0790 b0097d1 39fa0da b6c0790 39fa0da b6c0790 39fa0da b6c0790 39fa0da b0097d1 471fab3 b0097d1 b6c0790 b0097d1 b6c0790 b0097d1 b6c0790 b0097d1 b6c0790 b0097d1 b6c0790 b0097d1 b6c0790 b0097d1 b6c0790 b0097d1 471fab3 ff3696b 471fab3 ff3696b 471fab3 ff3696b b6c0790 ff3696b b6c0790 ff3696b 471fab3 ff3696b 471fab3 b0097d1 1924b81 b0097d1 1924b81 b0097d1 1924b81 b6c0790 1924b81 b6c0790 1924b81 64510f9 b0097d1 64510f9 b0097d1 b6c0790 b0097d1 b6c0790 b0097d1 471fab3 b6c0790 b0097d1 b6c0790 b0097d1 471fab3 b6c0790 471fab3 b6c0790 b0097d1 b6c0790 b0097d1 b6c0790 b0097d1 b6c0790 b0097d1 b6c0790 b0097d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 |
"""
MonoidForCausalLM โ Causal Monoid Language Model (HuggingFace Compatible)
MonoidForCausalLM โ ๅนบๅ็พคๅ ๆ่ฏญ่จๆจกๅ (ๅ
ผๅฎน HuggingFace)
Architecture / ๆถๆๆฆ่ฆ:
Replace softmax attention with a monoid parallel-scan recurrence.
็จๅนบๅ็พคๅนถ่กๆซๆ้ๆจๆฟไปฃ softmax ๆณจๆๅใ
Core idea / ๆ ธๅฟๆๆณ:
Softmax attention computes o_t = ฮฃ_{iโคt} softmax(q_tยทk_i) v_i
โ requires O(T) KV-cache per layer at inference.
Softmax ๆณจๆๅ่ฎก็ฎ o_t = ฮฃ_{iโคt} softmax(q_tยทk_i) v_i
โ ๆจ็ๆถๆฏๅฑ้่ฆ O(T) ็ KV ็ผๅญใ
Monoid attention compresses the entire causal history into a
fixed-size state matrix S_t โ โ^{dรd} per head:
S_t = diag(ฮฑ_t) ยท S_{t-1} + k_t โ v_t (vector decay recurrence)
o_t = q_t ยท S_t (state readout)
where ฮฑ_t โ โ^d is a per-dimension vector decay gate.
ๅนบๅ็พคๆณจๆๅๅฐๅฎๆดๅ ๆๅๅฒๅ็ผฉๅฐๆฏไธชๅคดไธไธชๅบๅฎๅคงๅฐ็็ถๆ็ฉ้ต S_t:
S_t = diag(ฮฑ_t) ยท S_{t-1} + k_t โ v_t (ๅ้่กฐๅ้ๆจ)
o_t = q_t ยท S_t (็ถๆ่ฏปๅบ)
ๅ
ถไธญ ฮฑ_t โ โ^d ๆฏ้็ปดๅบฆ็ๅ้่กฐๅ้จใ
This is a monoid because the binary operator:
(log_ฮฑ, S) โ (log_ฮฒ, X) = (log_ฮฑ + log_ฮฒ, exp(log_ฮฒ)ยทS + X)
is associative โ enables parallel prefix scan for training,
and O(1) sequential update for inference.
่ฟๆฏไธไธชๅนบๅ็พค๏ผๅ ไธบไบๅ
็ฎๅญ:
(log_ฮฑ, S) โ (log_ฮฒ, X) = (log_ฮฑ + log_ฮฒ, exp(log_ฮฒ)ยทS + X)
ๆปก่ถณ็ปๅๅพ โ ่ฎญ็ปๆถๅฏ็จๅนถ่กๅ็ผๆซๆ๏ผๆจ็ๆถ O(1) ้ๆญฅ้ๆจใ
Key properties / ๅ
ณ้ฎ็นๆง:
โ Explicit causal modeling โ ฮฑ_t gate explicitly controls how fast
past information decays, making causality a first-class citizen.
ๆพๅผๅ ๆๅปบๆจก โ ฮฑ_t ่กฐๅ้จๆพๅผๆงๅถๅๅฒไฟกๆฏ็้ๅฟ้็๏ผ
ๅ ๆๆงๆฏไธ็ญๅ
ฌๆฐ่้้ mask ๆฝๅ ็็บฆๆใ
โ Monoid state compression โ the full causal prefix x_{1:t} is
lossily compressed into a fixed-size (dรd) state matrix per head.
No O(T) KV-cache needed; inference is O(1) per token per layer.
ๅนบๅ็พค็ถๆๅ็ผฉ โ ๅฎๆดๅ ๆๅ็ผ x_{1:t} ่ขซๆๆๅ็ผฉๅฐๆฏไธชๅคด
ๅบๅฎๅคงๅฐ็ (dรd) ็ถๆ็ฉ้ตไธญใๆ ้ O(T) KV ็ผๅญ๏ผ
ๆจ็ๆถๆฏๅฑๆฏ token O(1)ใ
โ Parallel training โ associativity of โ enables O(T) parallel
prefix scan (vs O(Tยฒ) for softmax attention).
ๅนถ่ก่ฎญ็ป โ โ ็็ปๅๅพไฝฟ O(T) ๅนถ่กๅ็ผๆซๆๆไธบๅฏ่ฝ
(ๅฏนๆฏ softmax ๆณจๆๅ็ O(Tยฒ))ใ
Reuses LlamaMLP + LlamaRMSNorm from HuggingFace Transformers.
ๅค็จ HuggingFace Transformers ็ LlamaMLP + LlamaRMSNormใ
"""
from __future__ import annotations
from typing import Optional, Union
import torch
import torch.nn as nn
from torch import Tensor
from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin, AutoConfig, AutoModelForCausalLM
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm
try:
from monoid_scan_cuda import parallel_scan, parallel_scan_with_state
except ImportError:
# Pure-PyTorch fallback (sequential scan) โ works on CPU / MPS / any device.
# Slower than the fused CUDA kernel but numerically identical.
def parallel_scan(log_alpha: Tensor, kv: Tensor) -> Tensor:
"""Sequential prefix scan fallback: S_t[i,:] = exp(log_ฮฑ_t[i])ยทS_{t-1}[i,:] + kv_t[i,:]."""
B, H, T, d1, d2 = kv.shape
states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype)
S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype)
for t in range(T):
decay = torch.exp(log_alpha[:, :, t]) # [B, H, d]
while decay.dim() < S.dim():
decay = decay.unsqueeze(-1)
S = S * decay + kv[:, :, t]
states[:, :, t] = S
return states
def parallel_scan_with_state(log_alpha: Tensor, kv: Tensor):
"""Sequential prefix scan that also returns the final (log_decay, S) state."""
B, H, T, d1, d2 = kv.shape
states = torch.zeros(B, H, T, d1, d2, device=kv.device, dtype=kv.dtype)
S = torch.zeros(B, H, d1, d2, device=kv.device, dtype=kv.dtype)
log_acc = torch.zeros(B, H, d1, device=log_alpha.device, dtype=log_alpha.dtype)
for t in range(T):
decay = torch.exp(log_alpha[:, :, t])
while decay.dim() < S.dim():
decay = decay.unsqueeze(-1)
S = S * decay + kv[:, :, t]
states[:, :, t] = S
log_acc = log_acc + log_alpha[:, :, t]
return states, (log_acc, S)
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
# Config / ้
็ฝฎ
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
class MonoidConfig(PretrainedConfig):
"""
Configuration for the Monoid causal language model.
ๅนบๅ็พคๅ ๆ่ฏญ่จๆจกๅ็้
็ฝฎใ
Mirrors LlamaConfig for the shared components (MLP, RMSNorm, embedding)
so that weights can be directly transferred from Llama checkpoints.
ไธ LlamaConfig ็ๅ
ฑไบซ็ปไปถ (MLP, RMSNorm, embedding) ไฟๆไธ่ด,
ไปฅไพฟไป Llama ๆฃๆฅ็น็ดๆฅ่ฟ็งปๆ้ใ
"""
model_type = "monoid"
def __init__(
self,
vocab_size: int = 32000,
hidden_size: int = 576,
intermediate_size: int = 1536,
num_hidden_layers: int = 30,
num_attention_heads: int = 9,
head_dim: int = 64,
max_position_embeddings: int = 2048,
rms_norm_eps: float = 1e-5,
hidden_act: str = "silu",
mlp_bias: bool = False,
attention_bias: bool = False,
tie_word_embeddings: bool = True,
initializer_range: float = 0.041666666666666664,
pad_token_id: int = None,
bos_token_id: int = 1,
eos_token_id: int = 2,
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.max_position_embeddings = max_position_embeddings
self.rms_norm_eps = rms_norm_eps
self.hidden_act = hidden_act
self.mlp_bias = mlp_bias
self.attention_bias = attention_bias
self.initializer_range = initializer_range
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
# Monoid Cache โ O(1) state replaces O(T) KV-Cache
# ๅนบๅ็พค็ผๅญ โ O(1) ็ถๆๆฟไปฃ O(T) KV ็ผๅญ
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
class MonoidCache:
"""
Per-layer monoid state cache for autoregressive inference.
่ชๅๅฝๆจ็็้ๅฑๅนบๅ็พค็ถๆ็ผๅญใ
Unlike Transformer KV-Cache that stores all past keys & values (O(T) memory),
each layer here stores exactly ONE state tuple:
(log_decay_acc, S) where S โ โ^{B, H, d, d}
This is the monoid "sum" of all past (log_ฮฑ_i, k_iโv_i) via โ.
Memory is O(1) per layer regardless of sequence length.
ไธๅไบ Transformer ็ KV-Cache (ๅญๅจๆๆ่ฟๅป็ key ๅ value, O(T) ๅ
ๅญ),
่ฟ้ๆฏๅฑไป
ๅญๅจไธไธช็ถๆๅ
็ป:
(log_decay_acc, S) ๅ
ถไธญ S โ โ^{B, H, d, d}
่ฟๆฏๆๆ่ฟๅป็ (log_ฮฑ_i, k_iโv_i) ้่ฟ โ ็ดฏ็งฏ็ๅนบๅ็พค "ๅ"ใ
ๆ ่ฎบๅบๅๅค้ฟ๏ผๆฏๅฑๅ
ๅญ O(1)ใ
"""
def __init__(self):
self.states: list[tuple[Tensor, Tensor] | None] = []
self.seen_tokens: int = 0
def get_seq_length(self, layer_idx: int = 0) -> int:
return self.seen_tokens
def update(self, layer_idx: int, state: tuple[Tensor, Tensor]):
"""Store the accumulated monoid state for a given layer.
ๅญๅจๆๅฎๅฑ็็ดฏ็งฏๅนบๅ็พค็ถๆใ"""
while len(self.states) <= layer_idx:
self.states.append(None)
self.states[layer_idx] = state
def get_state(self, layer_idx: int) -> tuple[Tensor, Tensor] | None:
"""Retrieve the accumulated monoid state for a given layer.
่ทๅๆๅฎๅฑ็็ดฏ็งฏๅนบๅ็พค็ถๆใ"""
if layer_idx < len(self.states):
return self.states[layer_idx]
return None
def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorder cache for beam search. ไธบ beam search ้ๆ็ผๅญใ"""
for i, state in enumerate(self.states):
if state is not None:
log_d, kv = state
self.states[i] = (log_d[beam_idx], kv[beam_idx])
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
# Monoid Operator โ the algebraic heart
# ๅนบๅ็พค็ฎๅญ โ ไปฃๆฐๆ ธๅฟ
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
def monoid_op(
a: tuple[Tensor, Tensor],
b: tuple[Tensor, Tensor],
) -> tuple[Tensor, Tensor]:
"""
The monoid binary operator โ on (log-space vector decay, state matrix) pairs.
ๅนบๅ็พคไบๅ
็ฎๅญ โ๏ผไฝ็จไบ (ๅฏนๆฐๅ้่กฐๅ, ็ถๆ็ฉ้ต) ๅฏนใ
Definition / ๅฎไน:
(log_ฮฑ, S) โ (log_ฮฒ, X) = (log_ฮฑ + log_ฮฒ, diag(exp(log_ฮฒ))ยทS + X)
where log_ฮฑ, log_ฮฒ โ โ^d are per-dimension log decay vectors.
Why this is a monoid / ไธบไปไน่ฟๆฏๅนบๅ็พค:
โข Associativity / ็ปๅๅพ:
(a โ b) โ c = a โ (b โ c) โ
This enables parallel prefix scan for training (reduce tree)
and O(1) left-fold for inference (sequential append).
็ปๅๅพไฝฟ่ฎญ็ปๆถๅฏไปฅ็จๅนถ่กๅ็ผๆซๆ (ๅฝ็บฆๆ ),
ๆจ็ๆถๅฏไปฅ O(1) ๅทฆๆๅ (้ๆญฅ่ฟฝๅ )ใ
โข Identity / ๅไฝๅ
:
e = (0, 0) โ e โ a = a โ e = a โ
Why log-space / ไธบไปไน็จๅฏนๆฐ็ฉบ้ด:
Working in log-space for the decay factor avoids numerical
underflow when ฮฑ^T โ 0 for long sequences.
่กฐๅๅ ๅญๅจๅฏนๆฐ็ฉบ้ดไธญ่ฟ็ฎ๏ผ้ฟๅ
้ฟๅบๅไธ ฮฑ^T โ 0 ็ๆฐๅผไธๆบขใ
Causal semantics / ๅ ๆ่ฏญไน:
S_t = ฮฑ_t ยท S_{t-1} + k_t โ v_t
The decay ฮฑ_t โ (0,1) explicitly controls how much of the past
the model retains. This is *explicit causal modeling* โ the model
must learn to balance retention vs novelty at every timestep.
่กฐๅ ฮฑ_t โ (0,1) ๆพๅผๆงๅถๆจกๅไฟ็ๅคๅฐ่ฟๅปไฟกๆฏใ
่ฟๅฐฑๆฏ *ๆพๅผๅ ๆๅปบๆจก* โ ๆจกๅๅฟ
้กปๅจๆฏไธชๆถ้ดๆญฅๅญฆไน ๅฆไฝ
ๅนณ่กกไฟ็ๆงไฟกๆฏไธๅธๆถๆฐไฟกๆฏใ
"""
log_a, kv_a = a
log_b, kv_b = b
new_log = log_a + log_b # log(ฮฑยทฮฒ) = log_ฮฑ + log_ฮฒ
decay_b = torch.exp(log_b) # ฮฒ = exp(log_ฮฒ)
while decay_b.dim() < kv_a.dim():
decay_b = decay_b.unsqueeze(-1) # broadcast to [B,H,...,1,1]
return new_log, kv_a * decay_b + kv_b # ฮฒยทS + X
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
# Monoid Attention โ the core innovation
# ๅนบๅ็พคๆณจๆๅ โ ๆ ธๅฟๅๆฐๅฑ
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
class MonoidAttention(nn.Module):
"""
Monoid Causal Attention โ replaces softmax attention entirely.
ๅนบๅ็พคๅ ๆๆณจๆๅ โ ๅฎๅ
จๆฟไปฃ softmax ๆณจๆๅใ
Key differences from standard attention / ไธๆ ๅๆณจๆๅ็ๅ
ณ้ฎๅบๅซ:
โ No RoPE / positional encoding โ position is implicitly encoded
by the causal decay gate ฮฑ_t. The model learns *when* to forget
rather than encoding *where* tokens are.
ไธไฝฟ็จ RoPE / ไฝ็ฝฎ็ผ็ โ ไฝ็ฝฎไฟกๆฏ็ฑๅ ๆ่กฐๅ้จ ฮฑ_t ้ๅผ็ผ็ ใ
ๆจกๅๅญฆไน *ไฝๆถ้ๅฟ* ่้็ผ็ token *ๅจๅช้*ใ
โ No KV-Cache โ replaced by MonoidCache with O(1) state per layer.
Each state S โ โ^{Hรdรd} is a compressed summary of ALL past tokens.
ไธไฝฟ็จ KV ็ผๅญ โ ็ฑ O(1) ็ MonoidCache ็ถๆๆฟไปฃใ
ๆฏไธช็ถๆ S โ โ^{Hรdรd} ๆฏๆๆ่ฟๅป token ็ๅ็ผฉๆ่ฆใ
โ No attention mask โ causality is built into the recurrence itself.
S_t only depends on S_{t-1} and the current token by construction.
ไธไฝฟ็จๆณจๆๅๆฉ็ โ ๅ ๆๆงๅ
ๅปบไบ้ๆจ็ปๆๆฌ่บซใ
S_t ไป
ไพ่ต S_{t-1} ๅๅฝๅ token๏ผ็ปๆไธไฟ่ฏๅ ๆๆงใ
Computation / ่ฎก็ฎ:
Training (parallel scan, O(T)):
k_t = SiLU(k_proj(x_t)) # non-negative keys for PSD state
S_t = ฮฑ_t ยท S_{t-1} + k_t โ v_t # monoid recurrence via prefix scan
o_t = q_t ยท S_t # linear readout from state
Inference (RNN mode, O(1) per token):
Same recurrence, but applied one token at a time.
่ฎญ็ป (ๅนถ่กๆซๆ, O(T)):
k_t = SiLU(k_proj(x_t)) # ้่ด key ไฟ่ฏ็ถๆ็ฉ้ตๅๆญฃๅฎ
S_t = ฮฑ_t ยท S_{t-1} + k_t โ v_t # ้่ฟๅ็ผๆซๆๅฎ็ฐๅนบๅ็พค้ๆจ
o_t = q_t ยท S_t # ไป็ถๆไธญ็บฟๆง่ฏปๅบ
ๆจ็ (RNN ๆจกๅผ, ๆฏ token O(1)):
ๅไธ้ๆจๅ
ฌๅผ, ไฝ้ token ้กบๅบๅบ็จใ
"""
def __init__(self, config: MonoidConfig, layer_idx: int):
super().__init__()
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.head_dim
self.scaling = self.head_dim ** -0.5 # 1/โd, scale factor for qยทS readout
# qยทS ่ฏปๅบ็็ผฉๆพๅ ๅญ
# --- Projections (transferred from Llama) ---
# --- ๆๅฝฑๅฑ (ไป Llama ่ฟ็งป) ---
# q_proj, o_proj: identical dims to Llama, direct copy
# k_proj, v_proj: Llama GQA has fewer KV heads; we tile to full heads
# q_proj, o_proj: ็ปดๅบฆไธ Llama ไธ่ด, ็ดๆฅๅคๅถ
# k_proj, v_proj: Llama GQA ็ KV ๅคดๆดๅฐ; ๆไปฌ้ๅคๅฐๅ
จๅคดๆฐ
self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
# --- Decay gate (novel component, randomly initialized) ---
# --- ่กฐๅ้จ (ๅ
จๆฐ็ปไปถ, ้ๆบๅๅงๅ) ---
# Projects hidden_size โ num_heads * head_dim, yielding a VECTOR per head.
# Activation: log_ฮฑ = -softplus(Wx + b), giving ฮฑ โ (0, 1].
# Vector decay: S_t = diag(ฮฑ_t) ยท S_{t-1} + k_t โ v_t
# Different feature dimensions can have independent lifetimes:
# - fast-decaying dims for local syntax
# - slow-decaying dims for global entity/fact memory
# ๅฐ hidden_size ๆๅฝฑๅฐ num_heads * head_dim, ๆฏไธชๅคดไบง็ไธไธชๅ้ใ
# ๆฟๆดป: log_ฮฑ = -softplus(Wx + b), ไฝฟ ฮฑ โ (0, 1]ใ
# ๅ้่กฐๅ: S_t = diag(ฮฑ_t) ยท S_{t-1} + k_t โ v_t
# ไธๅ็นๅพ็ปดๅบฆๆฅๆ็ฌ็ซ็็ๅฝๅจๆ:
# - ๅฟซ้่กฐๅ็็ปดๅบฆ่ด่ดฃๅฑ้จ่ฏญๆณ็ปๆ
# - ๆ
ข้่กฐๅ็็ปดๅบฆ่ด่ดฃๅ
จๅฑๅฎไฝๅไบๅฎ่ฎฐๅฟ
self.decay_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=True)
# --- QK-Norm (novel component, randomly initialized) ---
# --- QK ๅฝไธๅ (ๅ
จๆฐ็ปไปถ, ้ๆบๅๅงๅ) ---
# Stabilizes the scale of qยทS readout. Without this, the state
# matrix S (sum of outer products) can grow unboundedly.
# ็จณๅฎ qยทS ่ฏปๅบ็ๅฐบๅบฆใๆฒกๆ่ฟไธช, ็ถๆ็ฉ้ต S (ๅค็งฏไนๅ)
# ๅฏ่ฝๆ ็ๅข้ฟใ
self.q_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
# --- Learnable initial state h0 (novel component, zero-initialized) ---
# --- ๅฏๅญฆไน ๅๅง็ถๆ h0 (ๅ
จๆฐ็ปไปถ, ้ถๅๅงๅ) ---
# S_0 = h0 โ โ^{1, H, d, d}, shared across batch.
# Zero-init means the model starts with "no memory" โ a clean slate.
# The model can learn a non-zero h0 as a kind of "system prompt" state.
# S_0 = h0 โ โ^{1, H, d, d}, ่ทจ batch ๅ
ฑไบซใ
# ้ถๅๅงๅๆๅณ็ๆจกๅไป"ๆ ่ฎฐๅฟ"ๅผๅง โ ไธๅผ ็ฝ็บธใ
# ๆจกๅๅฏไปฅๅญฆไน ้้ถ็ h0 ไฝไธบไธ็ง"็ณป็ปๆ็คบ"็ถๆใ
self.h0 = nn.Parameter(torch.zeros(1, self.num_heads, self.head_dim, self.head_dim))
def forward(
self,
hidden_states: Tensor,
attention_mask: Tensor | None = None,
monoid_cache: MonoidCache | None = None,
use_cache: bool = False,
) -> tuple[Tensor, tuple[Tensor, Tensor] | None]:
"""
Args:
hidden_states: [B, T, hidden_size]
attention_mask: [B, T] with 1=real token, 0=pad.
For PAD positions: ฮฑ=1 (preserve state), kv=0 (no contribution).
ๆฉ็ : 1=็ๅฎtoken, 0=ๅกซๅ
ใ
ๅกซๅ
ไฝ็ฝฎ: ฮฑ=1 (ไฟๆ็ถๆไธๅ), kv=0 (ๆ ่ดก็ฎ)ใ
monoid_cache: O(1) state cache for inference
ๆจ็็จ O(1) ็ถๆ็ผๅญ
use_cache: whether to use/update the cache
ๆฏๅฆไฝฟ็จ/ๆดๆฐ็ผๅญ
Returns:
output: [B, T, hidden_size]
final_state: (log_decay_acc, S) or None
"""
B, T, _ = hidden_states.shape
H, d = self.num_heads, self.head_dim
# --- Project to multi-head Q, K, V ---
# --- ๆๅฝฑๅฐๅคๅคด Q, K, V ---
q = self.q_proj(hidden_states).view(B, T, H, d).transpose(1, 2) # [B,H,T,d]
k = self.k_proj(hidden_states).view(B, T, H, d).transpose(1, 2)
v = self.v_proj(hidden_states).view(B, T, H, d).transpose(1, 2)
# --- QK-Norm: stabilize qยทS readout scale ---
# --- QK ๅฝไธๅ: ็จณๅฎ qยทS ่ฏปๅบๅฐบๅบฆ ---
q = self.q_norm(q) * self.scaling
k = self.k_norm(k)
# --- Non-negative keys via SiLU ---
# --- ้่ฟ SiLU ไฟ่ฏ key ้่ด ---
# Why: the state S = ฮฃ ฮฑ^{t-i} k_iโv_i is a sum of outer products.
# Non-negative k ensures S is positive semi-definite (PSD),
# preventing "feature erasure" where one token's contribution
# cancels another's. PSD guarantees monotonic information accumulation.
# ๅๅ : ็ถๆ S = ฮฃ ฮฑ^{t-i} k_iโv_i ๆฏๅค็งฏไนๅใ
# ้่ด็ k ไฟ่ฏ S ๅๆญฃๅฎ (PSD), ้ฒๆญขไธไธช token ็่ดก็ฎ
# ๆตๆถๅฆไธไธช token ็"็นๅพๆฆ้ค"็ฐ่ฑกใ
# PSD ไฟ่ฏไฟกๆฏๅ่ฐ็งฏ็ดฏใ
k = torch.nn.functional.silu(k)
# --- Compute per-dimension vector decay gate ฮฑ_t ---
# --- ่ฎก็ฎๆฏ็ปดๅบฆๅ้่กฐๅ้จ ฮฑ_t ---
# Negative Softplus: log_ฮฑ = -softplus(Wx + b)
# Value range: log_ฮฑ โ (-โ, 0), i.e. ฮฑ โ (0, 1].
# When Wx โ -โ: softplus โ 0, ฮฑ โ 1 (perfect memory, no forgetting)
# When Wx โ +โ: softplus โ Wx, ฮฑ โ 0 (complete forgetting)
# This avoids ฮฑ > 1 explosion (unlike SiLU) while still allowing
# ฮฑ = 1 for lossless memory (unlike Sigmoid which caps at <1).
# Each dimension of the d-vector decays independently:
# S_t[i,j] = ฮฑ_t[i] ยท S_{t-1}[i,j] + k_t[i] ยท v_t[j]
#
# ่ด Softplus: log_ฮฑ = -softplus(Wx + b)
# ๅผๅ: log_ฮฑ โ (-โ, 0), ๅณ ฮฑ โ (0, 1]ใ
# ๅฝ Wx โ -โ: softplus โ 0, ฮฑ โ 1 (ๅฎ็พ่ฎฐๅฟ, ไธ้ๅฟ)
# ๅฝ Wx โ +โ: softplus โ Wx, ฮฑ โ 0 (ๅฎๅ
จ้ๅฟ)
# ้ฟๅ
ไบ SiLU ็ ฮฑ > 1 ็็ธ, ๅๆถๅ
่ฎธ ฮฑ = 1 ๆ ๆ่ฎฐๅฟ (Sigmoid ๆ ๆณๅๅฐ)ใ
# d-ๅ้็ๆฏไธช็ปดๅบฆ็ฌ็ซ่กฐๅ:
# S_t[i,j] = ฮฑ_t[i] ยท S_{t-1}[i,j] + k_t[i] ยท v_t[j]
raw = self.decay_proj(hidden_states) # [B,T,H*d]
log_alpha = -torch.nn.functional.softplus(raw) # [B,T,H*d]
log_alpha = log_alpha.view(B, T, H, d).transpose(1, 2) # [B,H,T,d]
# --- Apply attention_mask: PAD tokens must be invisible to the recurrence ---
# --- ๅบ็จๆณจๆๅๆฉ็ : PAD token ๅฟ
้กปๅฏน้ๆจไธๅฏ่ง ---
# For PAD positions (mask=0): set log_ฮฑ=0 (ฮฑ=1, preserve state) and kv=0 (no contribution).
# This makes S_t = 1ยทS_{t-1} + 0 = S_{t-1}, i.e. PAD is a no-op on the state.
# ๅฏนไบ PAD ไฝ็ฝฎ (mask=0): ่ฎพ log_ฮฑ=0 (ฮฑ=1, ไฟๆ็ถๆ) ไธ kv=0 (ๆ ่ดก็ฎ)ใ
# ่ฟไฝฟๅพ S_t = 1ยทS_{t-1} + 0 = S_{t-1}, ๅณ PAD ๅฏน็ถๆๆฏ็ฉบๆไฝใ
if attention_mask is not None:
# attention_mask: [B, T] โ [B, 1, T, 1] for broadcasting with [B, H, T, d]
mask = attention_mask[:, None, :, None].to(log_alpha.dtype) # [B,1,T,1]
log_alpha = log_alpha * mask # PAD โ log_ฮฑ=0 โ ฮฑ=1
k = k * mask # PAD โ k=0
v = v * mask # PAD โ v=0 โ kv=0
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
# Inference path (RNN mode): O(1) per token per layer
# ๆจ็่ทฏๅพ (RNN ๆจกๅผ): ๆฏๅฑๆฏ token O(1)
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
# When generating, T=1. We apply the monoid operator once
# to fold the new token into the accumulated state.
# This is where "O(1) inference" materializes:
# S_t = ฮฑ_t ยท S_{t-1} + k_t โ v_t (one monoid_op call)
# o_t = q_t ยท S_t (one matmul)
# Total: O(Hยทdยฒ) per layer โ independent of sequence length.
#
# ็ๆๆถ T=1ใๆไปฌ่ฐ็จไธๆฌกๅนบๅ็พค็ฎๅญๅฐๆฐ token ๆๅ ่ฟ็ดฏ็งฏ็ถๆใ
# ่ฟๅฐฑๆฏ "O(1) ๆจ็" ็ๅ
ทไฝไฝ็ฐ:
# S_t = ฮฑ_t ยท S_{t-1} + k_t โ v_t (ไธๆฌก monoid_op)
# o_t = q_t ยท S_t (ไธๆฌก็ฉ้ตไนๆณ)
# ๆป่ฎก: ๆฏๅฑ O(Hยทdยฒ) โ ไธๅบๅ้ฟๅบฆๆ ๅ
ณใ
if use_cache and T == 1:
# Outer product: k_t โ v_t โ โ^{Hรdรd}
# ๅค็งฏ: k_t โ v_t โ โ^{Hรdรd}
kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, 0], v[:, :, 0])
log_t = log_alpha[:, :, 0] # [B,H,d]
prev = monoid_cache.get_state(self.layer_idx) if monoid_cache else None
if prev is None:
# First token: initialize from learnable h0
# ็ฌฌไธไธช token: ไปๅฏๅญฆไน ็ h0 ๅๅงๅ
decay_t = torch.exp(log_t)
while decay_t.dim() < self.h0.dim():
decay_t = decay_t.unsqueeze(-1)
new_state = (log_t, self.h0.expand(B, -1, -1, -1) * decay_t + kv_t)
else:
# Subsequent tokens: fold via monoid_op โ O(1)!
# ๅ็ปญ token: ้่ฟ monoid_op ๆๅ โ O(1)!
new_state = monoid_op(prev, (log_t, kv_t))
if monoid_cache is not None:
monoid_cache.update(self.layer_idx, new_state)
# Readout: o_t = q_t ยท S_t
# ่ฏปๅบ: o_t = q_t ยท S_t
o = torch.einsum('bhd, bhde -> bhe', q[:, :, 0], new_state[1])
# Reshape [B,H,d] โ [B,1,H*d] (heads contiguous, matching scan path)
# ้ๅก [B,H,d] โ [B,1,H*d] (ๅคด่ฟ็ปญๆๅ, ไธๆซๆ่ทฏๅพไธ่ด)
o = o.contiguous().view(B, 1, -1)
return self.o_proj(o), new_state
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
# Inference prefill (use_cache=True, T>1): parallel scan + readout
# ๆจ็้ขๅกซๅ
(use_cache=True, T>1): ๅนถ่กๆซๆ + ่ฏปๅบ
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
# Uses the same parallel_scan_with_state as training to leverage
# Triton CUDA kernel acceleration instead of O(T) Python loop.
# Memory: O(BยทHยทTยทdยฒ) โ same as training path.
# ไฝฟ็จไธ่ฎญ็ป็ธๅ็ parallel_scan_with_state ๆฅๅฉ็จ
# Triton CUDA ๆ ธๅฝๆฐๅ ้, ่้ O(T) ็ Python ๅพช็ฏใ
# ๅ
ๅญ: O(BยทHยทTยทdยฒ) โ ไธ่ฎญ็ป่ทฏๅพ็ธๅใ
if use_cache:
kv = torch.einsum('bhtd, bhte -> bhtde', k, v) # [B,H,T,d,d]
states, (log_acc, S_T) = parallel_scan_with_state(log_alpha, kv)
# Add h0 contribution: S_t += diag(โ_{i=0}^{t} ฮฑ_i) ยท h0
# ๅ ๅ h0 ่ดก็ฎ: S_t += diag(โ_{i=0}^{t} ฮฑ_i) ยท h0
cum_log_alpha = torch.cumsum(log_alpha, dim=2) # [B,H,T,d]
h0_decay = torch.exp(cum_log_alpha).unsqueeze(-1) # [B,H,T,d,1]
states = states + h0_decay * self.h0.unsqueeze(2) # broadcast h0 [1,H,1,d,d]
# Final state includes h0 contribution
# ๆ็ป็ถๆๅ
ๅซ h0 ่ดก็ฎ
total_h0_decay = torch.exp(log_acc).unsqueeze(-1) # [B,H,d,1]
S_final = S_T + total_h0_decay * self.h0.squeeze(0) # [B,H,d,d]
# h0 is [1,H,d,d], squeeze(0) removed for clarity but expand also works
final_state = (log_acc, S_final)
if monoid_cache is not None:
monoid_cache.update(self.layer_idx, final_state)
# Vectorized readout: o_t = q_t ยท S_t for all t
# ๅ้ๅ่ฏปๅบ: ไธๆฌกๆง่ฎก็ฎๆๆ t ็ o_t = q_t ยท S_t
o = torch.einsum('bhtd, bhtde -> bhte', q, states) # [B,H,T,d]
o = o.transpose(1, 2).contiguous().view(B, T, -1)
return self.o_proj(o), final_state
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
# Training path: parallel scan + vectorized readout
# ่ฎญ็ป่ทฏๅพ: ๅนถ่กๆซๆ + ๅ้ๅ่ฏปๅบ
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
# Materialize full kv tensor [B,H,T,d,d] and scan in one pass.
# Memory: O(BยทHยทTยทdยฒ) โ trades memory for speed.
# Eliminates Tร30 Python-loop kernel launches for outer product
# and readout; scan itself is parallel when CUDA kernel available.
#
# ็ฉๅๅฎๆด kv ๅผ ้ [B,H,T,d,d] ๅนถไธๆฌกๆงๆซๆใ
# ๅ
ๅญ: O(BยทHยทTยทdยฒ) โ ไปฅๅ
ๅญๆข้ๅบฆใ
# ๆถ้คๅค็งฏๅ่ฏปๅบ็ Tร30 ๆฌก Python ๅพช็ฏ kernel launch;
# ๅฝ CUDA kernel ๅฏ็จๆถๆซๆๆฌ่บซไนๆฏๅนถ่ก็ใ
# Vectorized outer product: kv_t = k_t โ v_t for all t at once
# ๅ้ๅๅค็งฏ: ไธๆฌกๆง่ฎก็ฎๆๆ t ็ k_t โ v_t
kv = torch.einsum('bhtd, bhte -> bhtde', k, v) # [B,H,T,d,d]
# Parallel prefix scan: S_t = diag(ฮฑ_t)ยทS_{t-1} + kv_t (from S=0)
# ๅนถ่กๅ็ผๆซๆ: S_t = diag(ฮฑ_t)ยทS_{t-1} + kv_t (ไป S=0 ๅผๅง)
# log_alpha is [B,H,T,d] โ vector decay per dimension.
# log_alpha ไธบ [B,H,T,d] โ ๆฏ็ปดๅบฆๅ้่กฐๅใ
states = parallel_scan(log_alpha, kv) # [B,H,T,d,d]
# Add h0 contribution: S_t += diag(โ_{i=0}^{t} ฮฑ_i) ยท h0
# ๅ ๅ h0 ่ดก็ฎ: S_t += diag(โ_{i=0}^{t} ฮฑ_i) ยท h0
cum_log_alpha = torch.cumsum(log_alpha, dim=2) # [B,H,T,d]
h0_decay = torch.exp(cum_log_alpha).unsqueeze(-1) # [B,H,T,d,1]
states = states + h0_decay * self.h0.unsqueeze(2) # broadcast h0 [1,H,1,d,d]
# Vectorized readout: o_t = q_t ยท S_t for all t at once
# ๅ้ๅ่ฏปๅบ: ไธๆฌกๆง่ฎก็ฎๆๆ t ็ q_t ยท S_t
o = torch.einsum('bhtd, bhtde -> bhte', q, states) # [B,H,T,d]
o = o.transpose(1, 2).contiguous().view(B, T, -1)
return self.o_proj(o), None
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
# Decoder Layer: MonoidAttn + LlamaMLP + LlamaRMSNorm
# ่งฃ็ ๅฑ: ๅนบๅ็พคๆณจๆๅ + LlamaMLP + LlamaRMSNorm
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
class MonoidDecoderLayer(nn.Module):
"""
Pre-Norm Transformer block with Monoid attention.
ไฝฟ็จๅนบๅ็พคๆณจๆๅ็ Pre-Norm Transformer ๅใ
Data flow / ๆฐๆฎๆต:
x โ RMSNorm โ MonoidAttn โ +residual โ RMSNorm โ LlamaMLP โ +residual โ out
The MLP and RMSNorm are identical to Llama (weights transferred directly).
Only MonoidAttention is the novel component.
MLP ๅ RMSNorm ไธ Llama ๅฎๅ
จ็ธๅ (ๆ้็ดๆฅ่ฟ็งป)ใ
ไป
MonoidAttention ๆฏๅ
จๆฐ็ปไปถใ
"""
gradient_checkpointing = False
def __init__(self, config: MonoidConfig, layer_idx: int):
super().__init__()
self.self_attn = MonoidAttention(config, layer_idx)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: Tensor,
attention_mask: Tensor | None = None,
monoid_cache: MonoidCache | None = None,
use_cache: bool = False,
) -> Tensor:
# --- Attention block with residual ---
# --- ๆณจๆๅๅ + ๆฎๅทฎ่ฟๆฅ ---
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, _ = self.self_attn(hidden_states, attention_mask=attention_mask, monoid_cache=monoid_cache, use_cache=use_cache)
hidden_states = residual + hidden_states
# --- FFN block with residual ---
# --- ๅ้ฆ็ฝ็ปๅ + ๆฎๅทฎ่ฟๆฅ ---
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
# MonoidModel (backbone)
# MonoidModel (้ชจๅนฒ็ฝ็ป)
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
class MonoidPreTrainedModel(PreTrainedModel):
config_class = MonoidConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["MonoidDecoderLayer"]
def _init_weights(self, module: nn.Module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if isinstance(module, MonoidAttention):
nn.init.constant_(module.decay_proj.bias, 1.0)
class MonoidModel(MonoidPreTrainedModel):
"""
Stack of MonoidDecoderLayers with token embedding and final norm.
ๅนบๅ็พค่งฃ็ ๅฑๅ ๅ , ๅธฆ token ๅตๅ
ฅๅๆ็ปๅฝไธๅใ
Forward: embed_tokens โ N ร MonoidDecoderLayer โ final_norm
ๅๅ: embed_tokens โ N ร MonoidDecoderLayer โ final_norm
"""
def __init__(self, config: MonoidConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[MonoidDecoderLayer(config, i) for i in range(config.num_hidden_layers)]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
self.post_init()
def forward(
self,
input_ids: Tensor | None = None,
attention_mask: Tensor | None = None,
inputs_embeds: Tensor | None = None,
monoid_cache: MonoidCache | None = None,
use_cache: bool = False,
) -> BaseModelOutputWithPast:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
for layer in self.layers:
if self.gradient_checkpointing and self.training and not use_cache:
hidden_states = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
monoid_cache,
use_cache,
)
else:
hidden_states = layer(hidden_states, attention_mask=attention_mask, monoid_cache=monoid_cache, use_cache=use_cache)
hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=monoid_cache,
)
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
# MonoidForCausalLM โ the full causal LM
# MonoidForCausalLM โ ๅฎๆดๅ ๆ่ฏญ่จๆจกๅ
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
class MonoidForCausalLM(MonoidPreTrainedModel, GenerationMixin):
"""
Monoid-based causal language model with LM head.
ๅบไบๅนบๅ็พค็ๅ ๆ่ฏญ่จๆจกๅ, ๅธฆ่ฏญ่จๆจกๅๅคดใ
The architecture in one sentence:
"Llama body + Monoid mind" โ reuse Llama's proven MLP/embeddings,
replace attention with monoid state compression for O(1) inference.
ไธๅฅ่ฏๆฆๆฌๆถๆ:
"Llama ็่บซไฝ + ๅนบๅ็พค็ๆ็ปด" โ ๅค็จ Llama ๆ็็ MLP/ๅตๅ
ฅๅฑ,
็จๅนบๅ็พค็ถๆๅ็ผฉๆฟๆขๆณจๆๅ, ๅฎ็ฐ O(1) ๆจ็ใ
"""
_tied_weights_keys = ["lm_head.weight"]
# Tell HuggingFace GenerationMixin NOT to create DynamicCache.
# Monoid uses its own O(1) MonoidCache, not KV-Cache.
# ๅ่ฏ HuggingFace ไธ่ฆๅๅปบ DynamicCacheใ
# Monoid ไฝฟ็จ่ชๅทฑ็ O(1) MonoidCache, ไธๆฏ KV ็ผๅญใ
_is_stateful = True
def __init__(self, config: MonoidConfig):
super().__init__(config)
self.model = MonoidModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(
self,
input_ids: Tensor,
past_key_values=None,
attention_mask: Tensor | None = None,
inputs_embeds: Tensor | None = None,
**kwargs,
) -> dict:
"""
Called by GenerationMixin at each decoding step.
GenerationMixin ๅจๆฏไธช่งฃ็ ๆญฅ่ฐ็จๆญคๆนๆณใ
HuggingFace may pass a DynamicCache; we intercept and replace
it with MonoidCache since we don't use standard KV-cache.
HuggingFace ๅฏ่ฝไผ ๅ
ฅ DynamicCache; ๆไปฌๆฆๆชๅนถๆฟๆขไธบ
MonoidCache, ๅ ไธบๆไปฌไธไฝฟ็จๆ ๅ KV ็ผๅญใ
"""
# Intercept non-MonoidCache objects (e.g. DynamicCache from GenerationMixin)
# ๆฆๆช้ MonoidCache ๅฏน่ฑก (ๅฆ GenerationMixin ๅๅปบ็ DynamicCache)
if past_key_values is not None and not isinstance(past_key_values, MonoidCache):
past_key_values = None
if past_key_values is not None and past_key_values.seen_tokens > 0:
# Cache exists โ only feed the latest token (O(1) inference)
# ็ผๅญๅทฒๅญๅจ โ ๅช้่พๅ
ฅๆๆฐ็ token (O(1) ๆจ็)
input_ids = input_ids[:, -1:]
# Decode step: single real token, no PAD โ mask not needed
# ่งฃ็ ๆญฅ: ๅไธช็ๅฎtoken, ๆ PAD โ ไธ้่ฆๆฉ็
attention_mask = None
model_inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"monoid_cache": past_key_values,
"use_cache": True,
}
return model_inputs
def forward(
self,
input_ids: Tensor | None = None,
attention_mask: Tensor | None = None, # [B,T] 1=real, 0=pad โ used to mask PAD from recurrence
# [B,T] 1=็ๅฎtoken, 0=ๅกซๅ
โ ็จไบๅฑ่ฝPADๅฏน้ๆจ็ๅฝฑๅ
position_ids: Tensor | None = None, # kept for API compat; monoid ignores this
# ไฟ็ API ๅ
ผๅฎนๆง; ๅนบๅ็พคไธไฝฟ็จ
past_key_values: MonoidCache | None = None,
inputs_embeds: Tensor | None = None,
labels: Tensor | None = None,
use_cache: bool | None = None,
monoid_cache: MonoidCache | None = None,
output_attentions: bool | None = None, # kept for API compat
output_hidden_states: bool | None = None, # kept for API compat
logits_to_keep: int | Tensor = 0,
**kwargs,
) -> CausalLMOutputWithPast:
# monoid_cache takes priority; fall back to past_key_values for GenerationMixin compat
# monoid_cache ไผๅ
; ๅ
ผๅฎน GenerationMixin ไผ ๅ
ฅ็ past_key_values
cache = monoid_cache or past_key_values
# Discard any non-MonoidCache (e.g. DynamicCache injected by GenerationMixin)
# ไธขๅผไปปไฝ้ MonoidCache ๅฏน่ฑก (ๅฆ GenerationMixin ๆณจๅ
ฅ็ DynamicCache)
if cache is not None and not isinstance(cache, MonoidCache):
cache = None
if use_cache and cache is None:
cache = MonoidCache()
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
monoid_cache=cache,
use_cache=bool(use_cache),
)
hidden_states = outputs.last_hidden_state
# Optionally only compute logits for the last K tokens (memory saving)
# ๅฏ้ไป
่ฎก็ฎๆๅ K ไธช token ็ logits (่็ๅ
ๅญ)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) and logits_to_keep > 0 else slice(None)
logits = self.lm_head(hidden_states[:, slice_indices, :])
# Standard causal LM loss: cross-entropy with shift
# ๆ ๅๅ ๆ่ฏญ่จๆจกๅๆๅคฑ: ๅธฆๅ็งป็ไบคๅ็ต
loss = None
if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = nn.functional.cross_entropy(
shift_logits.view(-1, self.vocab_size),
shift_labels.view(-1),
ignore_index=-100,
)
if cache is not None:
cache.seen_tokens += (input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1])
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=cache,
)
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
# AutoModel Registration / ่ชๅจๆณจๅ
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
AutoConfig.register("monoid", MonoidConfig)
AutoModelForCausalLM.register(MonoidConfig, MonoidForCausalLM)
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
# Smoke Tests / ้ช่ฏ
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
if __name__ == '__main__':
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(f'Device: {device}')
config = MonoidConfig(
vocab_size=49152,
hidden_size=576,
intermediate_size=1536,
num_hidden_layers=30,
num_attention_heads=9,
head_dim=64,
rms_norm_eps=1e-5,
hidden_act="silu",
tie_word_embeddings=True,
)
model = MonoidForCausalLM(config).to(device)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Parameters: {n_params:,}')
# -- Training smoke test / ่ฎญ็ปๅ็ๆต่ฏ --
B, T = 2, 64
ids = torch.randint(0, config.vocab_size, (B, T), device=device)
out = model(ids, labels=ids)
print(f'Train โ logits: {out.logits.shape}, loss: {out.loss:.4f}')
# -- Inference smoke test (manual RNN loop) / ๆจ็ๅ็ๆต่ฏ (ๆๅจ RNN ๅพช็ฏ) --
prompt = torch.randint(0, config.vocab_size, (1, 8), device=device)
cache = MonoidCache()
# Prefill / ้ขๅกซๅ
prefill_out = model(prompt, use_cache=True, monoid_cache=cache)
print(f'Prefill โ logits: {prefill_out.logits.shape}, cache seen: {cache.seen_tokens}')
# Decode 1 token / ่งฃ็ 1 ไธช token
next_tok = prefill_out.logits[:, -1:].argmax(dim=-1)
step_out = model(next_tok, use_cache=True, monoid_cache=cache)
print(f'Decode โ logits: {step_out.logits.shape}, cache seen: {cache.seen_tokens}')
# -- Monoid associativity check / ๅนบๅ็พค็ปๅๅพ้ช่ฏ --
print('\nMonoid associativity check / ๅนบๅ็พค็ปๅๅพ้ช่ฏ:')
a = (torch.randn(1, 1, 1), torch.randn(1, 1, 4, 4))
b = (torch.randn(1, 1, 1), torch.randn(1, 1, 4, 4))
c = (torch.randn(1, 1, 1), torch.randn(1, 1, 4, 4))
ab_c = monoid_op(monoid_op(a, b), c)
a_bc = monoid_op(a, monoid_op(b, c))
err = (ab_c[1] - a_bc[1]).abs().max().item()
print(f' |(aโb)โc - aโ(bโc)| = {err:.2e}')
print('\nDone.')
|