Upload 2 files
MonoidForCausalLM.py  +21 -34  CHANGED
@@ -483,43 +483,30 @@ class MonoidAttention(nn.Module):
         return self.o_proj(o), final_state
 
         # ──────────────────────────────────────────────────────────
-        # Training path
-        #
+        # Training path: memory-efficient sequential scan + inline readout
+        # Training path: memory-efficient serial scan with inline readout
         # ──────────────────────────────────────────────────────────
-        #
-        #
-        #
+        # Loop token-by-token with running state S=[B,H,d,d].
+        # Peak memory: O(B·H·d²) instead of O(B·H·T·d²).
+        # Autograd records each step for correct gradient computation.
         #
-        #
-        #
-        #
-
-
-
-
-
-
-
-
-
-
-
-
-        # The cumulative decay Π_{i=1}^{t} α_i = exp(Σ_{i=1}^{t} log_α_i).
-        # parallel_scan starts from S_0 = 0, but inference starts from S_0 = h0.
-        # Fix: S_t(with h0) = h0 · Π_{i=1}^{t} α_i + S_t(scan result)
-        # Cumulative decay: Π_{i=1}^{t} α_i = exp(Σ_{i=1}^{t} log_α_i).
-        cum_log_decay = torch.cumsum(log_alpha.squeeze(-1), dim=2)  # [B,H,T]
-        cum_decay = torch.exp(cum_log_decay).unsqueeze(-1).unsqueeze(-1)  # [B,H,T,1,1]
-        states = states + self.h0.unsqueeze(2) * cum_decay  # [B,H,T,d,d]
-        del cum_decay
-
-        # Readout: o_t = q_t · S_t for all t simultaneously
-        # Readout: o_t = q_t · S_t, computed for all t at once
-        o = torch.einsum('bhtd, bhtde -> bhte', q, states)
-        del states  # free [B,H,T,d,d]
+        # Loop over tokens with a running state S=[B,H,d,d].
+        # Peak memory is O(B·H·d²) rather than O(B·H·T·d²).
+        # Autograd records every step so gradients are computed correctly.
+
+        S = self.h0.expand(B, -1, -1, -1).clone()  # [B,H,d,d]
+        o_parts = []
+        for t in range(T):
+            kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, t], v[:, :, t])
+            decay = torch.exp(log_alpha[:, :, t])  # [B,H,1]
+            while decay.dim() < S.dim():
+                decay = decay.unsqueeze(-1)
+            S = S * decay + kv_t
+            o_parts.append(torch.einsum('bhd, bhde -> bhe', q[:, :, t], S))
+
+        o = torch.stack(o_parts, dim=2)  # [B,H,T,d]
         o = o.transpose(1, 2).contiguous().view(B, T, -1)
-        return self.o_proj(o),
+        return self.o_proj(o), None
 
 
     # ─────────────────────────────────────────────
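A few notes on this hunk, with runnable sketches below (none of this code is part of the commit; tensor names and sizes are illustrative). First, the removed comments state the initial-state correction for a zero-initialised scan: a scan started from S_0 = 0 becomes one started from S_0 = h0 by adding h0 · Π_{i=1}^{t} α_i = h0 · exp(Σ_{i=1}^{t} log_α_i) to every step's result. A minimal check of that identity:

import torch

B, H, T, d = 2, 3, 5, 4
log_alpha = -torch.rand(B, H, T, 1)   # log α_t < 0, so 0 < α_t < 1
kv = torch.randn(B, H, T, d, d)       # stand-ins for the k_t v_t^T outer products
h0 = torch.randn(B, H, d, d)          # initial state

def scan(S0):
    # Sequential reference: S_t = α_t · S_{t-1} + k_t v_t^T.
    S, out = S0.clone(), []
    for t in range(T):
        S = S * torch.exp(log_alpha[:, :, t]).unsqueeze(-1) + kv[:, :, t]
        out.append(S)
    return torch.stack(out, dim=2)    # [B,H,T,d,d]

# Zero-initialised scan plus the correction h0 · Π_{i<=t} α_i ...
cum_decay = torch.exp(torch.cumsum(log_alpha.squeeze(-1), dim=2))  # [B,H,T]
fixed = scan(torch.zeros_like(h0)) + h0.unsqueeze(2) * cum_decay[..., None, None]

# ... matches the scan that actually starts from S_0 = h0.
print(torch.allclose(scan(h0), fixed, atol=1e-5))  # True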
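Second, the added training path keeps one running [B,H,d,d] state and folds the readout o_t = q_t · S_t into the loop, instead of materialising a [B,H,T,d,d] tensor of all per-step states. The same recurrence as a standalone function (the name sequential_scan_readout and the free-standing signature are assumptions for illustration; in the commit this logic lives inside MonoidAttention and uses self.h0 and self.o_proj):

import torch

def sequential_scan_readout(q, k, v, log_alpha, h0):
    # q, k, v: [B,H,T,d]; log_alpha: [B,H,T,1]; h0: [1,H,d,d] -> o: [B,H,T,d]
    B, H, T, d = q.shape
    S = h0.expand(B, -1, -1, -1).clone()                 # running state, [B,H,d,d]
    o_parts = []
    for t in range(T):
        # S_t = α_t · S_{t-1} + k_t v_t^T, with the readout done inline.
        kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, t], v[:, :, t])
        decay = torch.exp(log_alpha[:, :, t]).unsqueeze(-1)  # [B,H,1,1]
        S = S * decay + kv_t
        o_parts.append(torch.einsum('bhd, bhde -> bhe', q[:, :, t], S))
    return torch.stack(o_parts, dim=2)                   # [B,H,T,d]

The forward pass never holds more than one d×d state per head; autograd still records all T steps, as the new comments note, so gradients flow through the whole recurrence.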
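Finally, a quick numerical check (reusing sequential_scan_readout from the sketch above, again with illustrative shapes) that the new sequential loop reproduces the removed cumulative-decay formulation, i.e. that the refactor is numerically neutral:

import torch

B, H, T, d = 2, 3, 5, 4
q, k, v = (torch.randn(B, H, T, d) for _ in range(3))
log_alpha = -torch.rand(B, H, T, 1)
h0 = torch.randn(1, H, d, d)

o_new = sequential_scan_readout(q, k, v, log_alpha, h0)

# Removed path: materialise every state from a zero-initialised scan,
# add the h0 · Π_{i<=t} α_i correction, then read out all t at once.
S, states = torch.zeros(B, H, d, d), []
for t in range(T):
    S = (S * torch.exp(log_alpha[:, :, t]).unsqueeze(-1)
         + torch.einsum('bhd, bhe -> bhde', k[:, :, t], v[:, :, t]))
    states.append(S)
states = torch.stack(states, dim=2)                                # [B,H,T,d,d]
cum_decay = torch.exp(torch.cumsum(log_alpha.squeeze(-1), dim=2))  # [B,H,T]
states = states + h0.unsqueeze(2) * cum_decay[..., None, None]
o_old = torch.einsum('bhtd, bhtde -> bhte', q, states)

print(torch.allclose(o_new, o_old, atol=1e-5))  # True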