Upload 2 files

MonoidForCausalLM.py  +46 -18  CHANGED
@@ -97,6 +97,7 @@ except ImportError:
     return states, (log_acc, S)


+
 # ─────────────────────────────────────────────
 # Config / 配置
 # ─────────────────────────────────────────────
@@ -452,6 +453,35 @@ class MonoidAttention(nn.Module):
             o = o.contiguous().view(B, 1, -1)
             return self.o_proj(o), new_state

+        # ──────────────────────────────────────────────────────────
+        # Inference prefill (use_cache=True, T>1): fused scan + readout
+        # 推理预填充 (use_cache=True, T>1): 融合扫描 + 读出
+        # ──────────────────────────────────────────────────────────
+        # Avoids materializing full [B,H,T,d,d] states tensor.
+        # Peak memory: O(H·d²) instead of O(T·H·d²).
+        # 避免实体化完整的 [B,H,T,d,d] 状态张量。
+        # 峰值内存: O(H·d²) 而非 O(T·H·d²)。
+        if use_cache:
+            S = self.h0.expand(B, -1, -1, -1).clone()  # [B,H,d,d]
+            log_acc = torch.zeros(B, H, 1, device=hidden_states.device, dtype=q.dtype)
+            o_parts = []
+            for t in range(T):
+                kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, t], v[:, :, t])
+                decay = torch.exp(log_alpha[:, :, t])  # [B,H,1]
+                while decay.dim() < S.dim():
+                    decay = decay.unsqueeze(-1)
+                S = S * decay + kv_t
+                o_parts.append(torch.einsum('bhd, bhde -> bhe', q[:, :, t], S))
+                log_acc = log_acc + log_alpha[:, :, t]
+
+            final_state = (log_acc, S)
+            if monoid_cache is not None:
+                monoid_cache.update(self.layer_idx, final_state)
+
+            o = torch.stack(o_parts, dim=2)  # [B,H,T,d]
+            o = o.transpose(1, 2).contiguous().view(B, T, -1)
+            return self.o_proj(o), final_state
+
         # ──────────────────────────────────────────────────────────
         # Training path (parallel scan): O(T) via prefix sum
         # 训练路径 (并行扫描): 通过前缀和 O(T)
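
The prefill branch above is the linear-attention recurrence S_t = exp(log_alpha_t)·S_{t-1} + k_t ⊗ v_t with readout o_t = q_t·S_t, evaluated one step at a time so that only a single [B,H,d,d] state is alive at once. Below is a minimal, self-contained sketch of that trade-off; helper names and test shapes are hypothetical, only the tensor layouts follow the comments in the hunk, and this is not the module's actual code:

import torch

def fused_prefill(q, k, v, log_alpha, h0):
    # q, k, v: [B,H,T,d]; log_alpha: [B,H,T,1] (log-decay); h0: [1,H,d,d]
    B, H, T, d = q.shape
    S = h0.expand(B, -1, -1, -1).clone()          # single running state, [B,H,d,d]
    outs = []
    for t in range(T):
        decay = torch.exp(log_alpha[:, :, t]).unsqueeze(-1)                 # [B,H,1,1]
        S = S * decay + torch.einsum('bhd,bhe->bhde', k[:, :, t], v[:, :, t])
        outs.append(torch.einsum('bhd,bhde->bhe', q[:, :, t], S))           # readout at step t
    return torch.stack(outs, dim=2), S            # outputs [B,H,T,d], final state [B,H,d,d]

def materialized_reference(q, k, v, log_alpha, h0):
    # Same math, but keeps every intermediate state: peak memory O(T·H·d²).
    B, H, T, d = q.shape
    states, S = [], h0.expand(B, -1, -1, -1).clone()
    for t in range(T):
        decay = torch.exp(log_alpha[:, :, t]).unsqueeze(-1)
        S = S * decay + torch.einsum('bhd,bhe->bhde', k[:, :, t], v[:, :, t])
        states.append(S)
    states = torch.stack(states, dim=2)                                     # [B,H,T,d,d]
    return torch.einsum('bhtd,bhtde->bhte', q, states), states[:, :, -1]

B, H, T, d = 2, 4, 16, 8
q, k, v = (torch.randn(B, H, T, d) for _ in range(3))
log_alpha = torch.log(torch.sigmoid(torch.randn(B, H, T, 1)))   # log-decay in (-inf, 0)
h0 = torch.zeros(1, H, d, d)
o1, s1 = fused_prefill(q, k, v, log_alpha, h0)
o2, s2 = materialized_reference(q, k, v, log_alpha, h0)
assert torch.allclose(o1, o2, atol=1e-5) and torch.allclose(s1, s2, atol=1e-5)

Both paths produce the same outputs; the second simply keeps all T states around, which is exactly the [B,H,T,d,d] tensor the fused loop avoids.
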
@@ -467,16 +497,9 @@ class MonoidAttention(nn.Module):
         # Batch outer product: kv_{t} = k_t ⊗ v_t for all t
         # 批量外积: kv_{t} = k_t ⊗ v_t, 对所有 t
         kv = torch.einsum('bhtd, bhte -> bhtde', k, v)  # [B,H,T,d,d]
-
-        if use_cache:
-            # Prefill with state extraction (for switching to RNN inference)
-            # 带状态提取的预填充 (用于切换到 RNN 推理)
-            states, final_state = parallel_scan_with_state(log_alpha, kv)
-        else:
-            # Pure training, no state needed
-            # 纯训练, 不需要状态
-            states = parallel_scan(log_alpha, kv)
-            final_state = None
+        states = parallel_scan(log_alpha, kv)
+        del kv  # free [B,H,T,d,d] early
+        final_state = None

         # ── Incorporate h0: make training consistent with inference ──
         # ── 融入 h0: 使训练与推理一致 ──
@@ -489,17 +512,12 @@ class MonoidAttention(nn.Module):
         cum_log_decay = torch.cumsum(log_alpha.squeeze(-1), dim=2)  # [B,H,T]
         cum_decay = torch.exp(cum_log_decay).unsqueeze(-1).unsqueeze(-1)  # [B,H,T,1,1]
         states = states + self.h0.unsqueeze(2) * cum_decay  # [B,H,T,d,d]
-
-        if use_cache:
-            # Update final_state to include h0 contribution
-            # 更新最终状态以包含 h0 的贡献
-            final_state = (final_state[0], states[:, :, -1])
-            if monoid_cache is not None:
-                monoid_cache.update(self.layer_idx, final_state)
+        del cum_decay

         # Readout: o_t = q_t · S_t for all t simultaneously
         # 读出: o_t = q_t · S_t, 对所有 t 同时计算
         o = torch.einsum('bhtd, bhtde -> bhte', q, states)
+        del states  # free [B,H,T,d,d]
         o = o.transpose(1, 2).contiguous().view(B, T, -1)
         return self.o_proj(o), final_state
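
`parallel_scan` and `parallel_scan_with_state` are defined earlier in the file and are not part of this diff, so their implementation is not shown here. As a rough reference for what the training path appears to assume, the scan plus the `h0 * cum_decay` correction above amounts to states_t = h0·exp(cum_t) + Σ_{s≤t} exp(cum_t − cum_s)·kv_s, where cum_t is the running sum of log_alpha. The following naive O(T²) loop is only an assumed sketch of that closed form, not the library's scan:

import torch

def naive_scan_with_h0(log_alpha, kv, h0):
    # log_alpha: [B,H,T,1]; kv: [B,H,T,d,d]; h0: [1,H,d,d]
    B, H, T, d, _ = kv.shape
    cum = torch.cumsum(log_alpha.squeeze(-1), dim=2)          # [B,H,T]
    states = []
    for t in range(T):
        # h0 contribution decays through every step up to t
        acc = h0 * torch.exp(cum[:, :, t]).unsqueeze(-1).unsqueeze(-1)
        for s in range(t + 1):
            w = torch.exp(cum[:, :, t] - cum[:, :, s])        # total decay from step s to t
            acc = acc + w.unsqueeze(-1).unsqueeze(-1) * kv[:, :, s]
        states.append(acc)
    return torch.stack(states, dim=2)                          # [B,H,T,d,d]

Unrolling the per-step recurrence from the prefill sketch above over the same inputs yields these same states_t, which is the "make training consistent with inference" point of the h0 correction.
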
@@ -576,6 +594,8 @@ class MonoidPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()

+        if isinstance(module, MonoidAttention):
+            nn.init.constant_(module.decay_proj.bias, 4.0)

 class MonoidModel(MonoidPreTrainedModel):
     """
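
The new `_init_weights` branch starts every attention layer with `decay_proj.bias` at 4.0. Assuming the decay gate squashes this projection through a sigmoid before taking the log (the projection itself is outside this diff, so this is an assumption), the per-step decay starts close to 1, i.e. long memory at initialization rather than a gate centered at 0.5:

import torch
# Assumption: decay = sigmoid(decay_proj(x)); with bias 4.0 and small initial weights,
# the gate starts near sigmoid(4.0) rather than sigmoid(0.0) = 0.5.
print(torch.sigmoid(torch.tensor(4.0)).item())  # ~0.982
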
@@ -609,7 +629,15 @@ class MonoidModel(MonoidPreTrainedModel):

         hidden_states = inputs_embeds
         for layer in self.layers:
-            hidden_states = layer(hidden_states, monoid_cache=monoid_cache, use_cache=use_cache)
+            if self.gradient_checkpointing and self.training and not use_cache:
+                hidden_states = self._gradient_checkpointing_func(
+                    layer.__call__,
+                    hidden_states,
+                    monoid_cache,
+                    use_cache,
+                )
+            else:
+                hidden_states = layer(hidden_states, monoid_cache=monoid_cache, use_cache=use_cache)

         hidden_states = self.norm(hidden_states)

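
With this change, each layer's activations are recomputed in the backward pass instead of being cached, but only while checkpointing is enabled, the model is training, and use_cache is off. A hedged usage sketch (the model's constructor and forward signature are not shown in this diff, so only the standard transformers API is used here):

from transformers import PreTrainedModel  # MonoidPreTrainedModel subclasses this

def enable_checkpointing(model: PreTrainedModel) -> PreTrainedModel:
    # The standard transformers API flips gradient_checkpointing on supporting
    # submodules and installs the _gradient_checkpointing_func used above.
    model.gradient_checkpointing_enable()
    model.train()  # checkpointing is skipped in eval / generation (use_cache=True)
    return model
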