razmars committed
Commit b8743eb · verified · 1 Parent(s): aa9daa3

Update modeling_super_linear.py

Files changed (1):
  1. modeling_super_linear.py +25 -1
modeling_super_linear.py CHANGED
@@ -597,6 +597,29 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
         return y


+    def fourier_downsample_dim1(self, x, target_len: int):
+        # Downsample x of shape (B, L, C) to (B, target_len, C) by spectral truncation.
+
+        # 1. Forward real FFT along dim 1
+        X = torch.fft.rfft(x, dim=1)           # shape (B, L//2 + 1, C)
+
+        # 2. Keep only the low-frequency bins needed for the shorter series
+        keep = target_len // 2 + 1             # rfft size for the target grid
+        X_crop = X[:, :keep]                   # ideal brick-wall low-pass; crop dim 1, not the last dim
+
+        # 3. Inverse FFT onto the shorter grid
+        y = torch.fft.irfft(X_crop, n=target_len, dim=1)
+
+        # 4. Renormalise amplitudes (left disabled in this commit):
+        #    irfft divides by `target_len`, whereas the forward rfft ran over length L = x.shape[1].
+        #    Multiplying by (target_len / L) makes DC and low-frequency amplitudes match the input.
+        # y *= target_len / x.shape[1]
+
+        return y
+
+
+
+
     def forward(self,
                 inputs_embeds: torch.Tensor = None,
                 attention_mask: Optional[torch.Tensor] = None,
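
The four numbered steps above downsample along dim 1 by spectral truncation: an ideal brick-wall low-pass followed by resampling onto the shorter grid. Below is a self-contained sketch of the same idea, illustrative only and not part of the commit; it enables the step-4 renormalisation so amplitudes are preserved, and the toy shapes are assumed.

import torch

def fourier_downsample_dim1(x: torch.Tensor, target_len: int) -> torch.Tensor:
    # Standalone version of the method above, with renormalisation enabled.
    X = torch.fft.rfft(x, dim=1)                    # (B, L//2 + 1, C)
    keep = target_len // 2 + 1                      # rfft bins for the target grid
    y = torch.fft.irfft(X[:, :keep], n=target_len, dim=1)
    return y * (target_len / x.shape[1])            # step 4: match input amplitudes

# A sine with 4 cycles per window survives 336 -> 96 downsampling, since
# 4 is below the target grid's Nyquist bin (96 // 2 = 48).
x = torch.sin(2 * torch.pi * 4 * torch.arange(336) / 336).reshape(1, 336, 1)
y = fourier_downsample_dim1(x, 96)
expected = torch.sin(2 * torch.pi * 4 * torch.arange(96) / 96).reshape(1, 96, 1)
print(y.shape)                                      # torch.Size([1, 96, 1])
print(torch.allclose(y, expected, atol=1e-4))       # True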
 
@@ -620,11 +643,12 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):

         self.backbone.inf_pred_len = 336

-
         # backbone returns (B, pred_len, C)

         preds = self.backbone(x_enc)
         preds = self.revin_layer(preds, 'denorm')
+        preds = self.fourier_downsample_dim1(preds, 96)
+
         return CausalLMOutputWithCrossAttentions(loss=None,logits=preds,past_key_values=None,hidden_states=None,attentions=None,)

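
Net effect of the second hunk: the backbone still forecasts 336 steps (inf_pred_len = 336), RevIN denormalises them, and the result is then spectrally resampled down to a hard-coded 96 steps before being returned. A minimal shape walk-through, with batch and channel sizes invented for illustration:

import torch

batch, channels = 4, 7                      # hypothetical sizes
preds = torch.randn(batch, 336, channels)   # backbone output at inf_pred_len = 336
keep = 96 // 2 + 1                          # 49 rfft bins kept for the 96-step grid
preds = torch.fft.irfft(torch.fft.rfft(preds, dim=1)[:, :keep], n=96, dim=1)
assert preds.shape == (batch, 96, channels) # returned as `logits` in CausalLMOutputWithCrossAttentions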