razmars commited on
Commit
7f83ae8
Β·
verified Β·
1 Parent(s): 0c42e00

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +14 -15
modeling_super_linear.py CHANGED
@@ -593,22 +593,21 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
593
  return y
594
 
595
  def upsample_interpolate(self, x,scale_factor, target_len: int = 512):
596
- # Add channel dimension if input is 2D
597
- size = len(x.shape)
598
- if size == 2:
599
- x = x.unsqueeze(1)
600
-
601
- print(target_len)
602
- upsample = F.interpolate(x, size=target_len, mode='linear', align_corners=False)
603
 
604
- # If input was 2D, remove the channel dimension
605
- if size == 2:
606
- upsample = upsample.squeeze(1)
607
-
608
- upsample = upsample.float()
609
-
610
- #print(f"Upsampled shape: {upsample.shape}")
611
- return upsample
612
 
613
 
614
 
 
593
  return y
594
 
595
  def upsample_interpolate(self, x,scale_factor, target_len: int = 512):
596
+ was_2d = x.dim() == 2
597
+
598
+ if was_2d: # [B, L] -> [B, 1, L]
599
+ x = x.unsqueeze(1)
600
+ else: # [B, L, C] -> [B, C, L]
601
+ x = x.permute(0, 2, 1)
602
+
603
 
604
+ x_up = F.interpolate(x, size=target_len, mode='linear', align_corners=False)
605
+
606
+ # ── restore original layout ─────────────────────────────────────────────────
607
+ if was_2d: # back to [B, target_len]
608
+ return x_up.squeeze(1).float()
609
+ else: # back to [B, target_len, C]
610
+ return x_up.permute(0, 2, 1).float()
 
611
 
612
 
613