OzTianlu committed · verified
Commit 471fab3 · 1 Parent(s): 39fa0da

Upload 2 files

Files changed (1):
  1. MonoidForCausalLM.py +46 -18
MonoidForCausalLM.py CHANGED
@@ -97,6 +97,7 @@ except ImportError:
    return states, (log_acc, S)


+
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Config / 配置
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
@@ -452,6 +453,35 @@ class MonoidAttention(nn.Module):
        o = o.contiguous().view(B, 1, -1)
        return self.o_proj(o), new_state

+        # ═════════════════════════════════════════════════════════
+        # Inference prefill (use_cache=True, T>1): fused scan + readout
+        # 推理预填充 (use_cache=True, T>1): 融合扫描 + 读出
+        # ═════════════════════════════════════════════════════════
+        # Avoids materializing full [B,H,T,d,d] states tensor.
+        # Peak memory: O(H·d²) instead of O(T·H·d²).
+        # 避免实体化完整的 [B,H,T,d,d] 状态张量。
+        # 峰值内存: O(H·d²) 而非 O(T·H·d²)。
+        if use_cache:
+            S = self.h0.expand(B, -1, -1, -1).clone()  # [B,H,d,d]
+            log_acc = torch.zeros(B, H, 1, device=hidden_states.device, dtype=q.dtype)
+            o_parts = []
+            for t in range(T):
+                kv_t = torch.einsum('bhd, bhe -> bhde', k[:, :, t], v[:, :, t])
+                decay = torch.exp(log_alpha[:, :, t])  # [B,H,1]
+                while decay.dim() < S.dim():
+                    decay = decay.unsqueeze(-1)
+                S = S * decay + kv_t
+                o_parts.append(torch.einsum('bhd, bhde -> bhe', q[:, :, t], S))
+                log_acc = log_acc + log_alpha[:, :, t]
+
+            final_state = (log_acc, S)
+            if monoid_cache is not None:
+                monoid_cache.update(self.layer_idx, final_state)
+
+            o = torch.stack(o_parts, dim=2)  # [B,H,T,d]
+            o = o.transpose(1, 2).contiguous().view(B, T, -1)
+            return self.o_proj(o), final_state
+
        # ═════════════════════════════════════════════════════════
        # Training path (parallel scan): O(T) via prefix sum
        # 训练路径 (并行扫描): 通过前缀和 O(T)
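For readers tracking the change: the new prefill branch keeps a single [B,H,d,d] running state per head instead of going through the old parallel_scan_with_state path. Below is a minimal standalone sketch of the math it implements, checked against the discounted-prefix-sum closed form that the training path (parallel_scan plus the h0 term) evaluates; the shapes, the decay parameterization, and the helper names reference_prefill / reference_parallel are illustrative assumptions, not this repo's API.

import torch

def reference_prefill(q, k, v, log_alpha, h0):
    # Per-step recurrence (what the fused prefill loop does):
    #   S_t = exp(log_alpha_t) * S_{t-1} + k_t ⊗ v_t,  with S_0 = h0
    #   o_t = q_t · S_t
    # Assumed shapes: q, k, v [B,H,T,d]; log_alpha [B,H,T,1]; h0 [B,H,d,d].
    T = q.shape[2]
    S = h0.clone()
    outs = []
    for t in range(T):
        decay = torch.exp(log_alpha[:, :, t]).unsqueeze(-1)              # [B,H,1,1]
        S = S * decay + torch.einsum('bhd, bhe -> bhde', k[:, :, t], v[:, :, t])
        outs.append(torch.einsum('bhd, bhde -> bhe', q[:, :, t], S))
    return torch.stack(outs, dim=2)                                      # [B,H,T,d]

def reference_parallel(q, k, v, log_alpha, h0):
    # Closed form the training path evaluates (the repo's parallel_scan
    # presumably computes it more cleverly):
    #   S_t = sum_{s<=t} exp(cum_t - cum_s) k_s ⊗ v_s + exp(cum_t) h0
    cum = torch.cumsum(log_alpha.squeeze(-1), dim=2)                     # [B,H,T]
    w = torch.exp(cum)[..., None, None]                                  # [B,H,T,1,1]
    kv = torch.einsum('bhtd, bhte -> bhtde', k, v)
    states = torch.cumsum(kv / w, dim=2) * w + h0.unsqueeze(2) * w
    return torch.einsum('bhtd, bhtde -> bhte', q, states)

B, H, T, d = 2, 3, 5, 4
q, k, v = (torch.randn(B, H, T, d) for _ in range(3))
log_alpha = -torch.rand(B, H, T, 1)      # log of a decay in (0, 1]
h0 = torch.randn(B, H, d, d)
assert torch.allclose(reference_prefill(q, k, v, log_alpha, h0),
                      reference_parallel(q, k, v, log_alpha, h0), atol=1e-4)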
@@ -467,16 +497,9 @@
        # Batch outer product: kv_{t} = k_t ⊗ v_t for all t
        # 批量外积: kv_{t} = k_t ⊗ v_t, 对所有 t
        kv = torch.einsum('bhtd, bhte -> bhtde', k, v)  # [B,H,T,d,d]
-
-        if use_cache:
-            # Prefill with state extraction (for switching to RNN inference)
-            # 带状态提取的预填充 (用于切换到 RNN 推理)
-            states, final_state = parallel_scan_with_state(log_alpha, kv)
-        else:
-            # Pure training, no state needed
-            # 纯训练, 不需要状态
-            states = parallel_scan(log_alpha, kv)
-            final_state = None
+        states = parallel_scan(log_alpha, kv)
+        del kv  # free [B,H,T,d,d] early
+        final_state = None

        # ── Incorporate h0: make training consistent with inference ──
        # ── 融入 h0: 使训练与推理一致 ──
@@ -489,17 +512,12 @@
        cum_log_decay = torch.cumsum(log_alpha.squeeze(-1), dim=2)  # [B,H,T]
        cum_decay = torch.exp(cum_log_decay).unsqueeze(-1).unsqueeze(-1)  # [B,H,T,1,1]
        states = states + self.h0.unsqueeze(2) * cum_decay  # [B,H,T,d,d]
-
-        if use_cache:
-            # Update final_state to include h0 contribution
-            # 更新最终状态以包含 h0 的贡献
-            final_state = (final_state[0], states[:, :, -1])
-            if monoid_cache is not None:
-                monoid_cache.update(self.layer_idx, final_state)
+        del cum_decay

        # Readout: o_t = q_t · S_t for all t simultaneously
        # 读出: o_t = q_t · S_t, 对所有 t 同时计算
        o = torch.einsum('bhtd, bhtde -> bhte', q, states)
+        del states  # free [B,H,T,d,d]
        o = o.transpose(1, 2).contiguous().view(B, T, -1)
        return self.o_proj(o), final_state

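To put numbers on why the two del statements (and the separate prefill branch above) matter: the training path still materializes both kv and states at [B,H,T,d,d]. A back-of-the-envelope check with purely illustrative dimensions, not taken from this repo's config:

# Hypothetical sizes for illustration only
B, H, T, d = 8, 16, 2048, 64
full = B * H * T * d * d           # 1,073,741,824 elements per [B,H,T,d,d] tensor
print(full * 4 / 2**30)            # 4.0 GiB in fp32; kv and states are each this size
running = B * H * d * d            # 524,288 elements
print(running * 4 / 2**20)         # 2.0 MiB: the [B,H,d,d] state the fused prefill keeps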
 
@@ -576,6 +594,8 @@ class MonoidPreTrainedModel(PreTrainedModel):
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

+        if isinstance(module, MonoidAttention):
+            nn.init.constant_(module.decay_proj.bias, 4.0)

class MonoidModel(MonoidPreTrainedModel):
    """
@@ -609,7 +629,15 @@ class MonoidModel(MonoidPreTrainedModel):

        hidden_states = inputs_embeds
        for layer in self.layers:
-            hidden_states = layer(hidden_states, monoid_cache=monoid_cache, use_cache=use_cache)
+            if self.gradient_checkpointing and self.training and not use_cache:
+                hidden_states = self._gradient_checkpointing_func(
+                    layer.__call__,
+                    hidden_states,
+                    monoid_cache,
+                    use_cache,
+                )
+            else:
+                hidden_states = layer(hidden_states, monoid_cache=monoid_cache, use_cache=use_cache)

        hidden_states = self.norm(hidden_states)
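The new branch plugs into the standard Hugging Face checkpointing machinery: in recent transformers versions, self.gradient_checkpointing and self._gradient_checkpointing_func are set by PreTrainedModel.gradient_checkpointing_enable(), which also requires the class to declare supports_gradient_checkpointing = True (not visible in this diff). A usage sketch with a placeholder repo id:

from transformers import AutoModelForCausalLM

# Placeholder repo id; substitute the actual model repository.
model = AutoModelForCausalLM.from_pretrained("OzTianlu/Monoid", trust_remote_code=True)
model.gradient_checkpointing_enable()   # enables the checkpointed branch in MonoidModel
model.train()                           # training mode; keep use_cache=False so it is taken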
 
 