razmars commited on
Commit
43cc2dd
·
verified ·
1 Parent(s): d1351d6

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +44 -34
modeling_super_linear.py CHANGED
@@ -293,41 +293,51 @@ class SparseNoisyMoE(nn.Module):
293
  self.gating_network = nn.Linear(input_dim, self.num_experts, bias=True)
294
 
295
  def get_periodogram(self, inputs, ker_len=50, con=1, n=10000):
296
- if inputs.dim() == 2:
297
- x_0 = inputs.unsqueeze(2)
298
- else:
299
- x_0 = inputs
300
- x_0 = x_0 - torch.mean(x_0, dim=1, keepdim=True)
301
-
302
- v = torch.arange(0, n) / n
303
- if con:
304
- if ker_len is None:
305
- ker_len = n // 4
306
- ker_len = min(ker_len, 50)
307
-
308
- x_0 = x_0.permute(0, 2, 1)
309
- ker = (torch.ones(1, 1, ker_len) / ker_len).to(x_0.device)
310
- x_c = F.conv1d(x_0, ker, padding="same")
311
- x_c[:, :, :ker_len // 2] = x_c[:, :, ker_len // 2:ker_len // 2 + 1]
312
- x_c[:, :, -ker_len // 2:] = x_c[:, :, -ker_len // 2 - 1:-ker_len // 2]
313
- x_0 = x_0 - x_c
314
- x_0 = x_0.permute(0, 2, 1)
315
-
316
- dft = torch.fft.fft(x_0, dim=1, n=n) / np.sqrt(n)
317
- dft = dft[:, :n//2, :]
318
- I = torch.abs(dft) ** 2
319
-
320
- I_sum = torch.sum(I, dim=1, keepdim=True)
321
- I_sum[I_sum == 0] = 1
322
- I = I / I_sum
323
-
324
- if torch.any(I_sum == 0):
325
- print("Zeros in the sum")
326
- raise ValueError
327
-
328
- if inputs.dim() == 2:
 
 
 
 
 
 
 
 
 
 
329
  I = I.squeeze(2)
330
-
331
  return I
332
 
333
 
 
293
  self.gating_network = nn.Linear(input_dim, self.num_experts, bias=True)
294
 
295
  def get_periodogram(self, inputs, ker_len=50, con=1, n=10000):
296
+ n_fft = 128
297
+ ker_len =12
298
+ if inputs.ndim == 2: # (B, L) → (B, L, 1)
299
+ x = inputs.unsqueeze(2)
300
+ else: # already (B, L, C)
301
+ x = inputs
302
+
303
+ B, L, C = x.shape
304
+ x = x - x.mean(dim=1, keepdim=True)
305
+
306
+ # ------------------------------------------------------------------ parameters
307
+ if n_fft is None:
308
+ n_fft = 1 << (L - 1).bit_length() # next power-of-two ≥ L
309
+ if ker_len is None:
310
+ ker_len = min(L // 4, 50) # never larger than the signal
311
+ ker_half = ker_len // 2
312
+
313
+ # ------------------------------------------------------------------ detrend
314
+ if con and ker_len > 0:
315
+ # (B, L, C) → (B, C, L) for conv1d
316
+ x_perm = x.permute(0, 2, 1)
317
+ ker = torch.ones(1, 1, ker_len, device=x.device) / ker_len
318
+ trend = F.conv1d(x_perm, ker, padding="same")
319
+ # Clamp boundary copies so we don’t index out of range for short signals
320
+ left = min(ker_half, L - 1)
321
+ right = min(ker_half, L - 1)
322
+ trend[:, :, :left] = trend[:, :, left:left+1]
323
+ trend[:, :, -right:] = trend[:, :, -(right+1):-right]
324
+ x_detrended = x_perm - trend
325
+ x = x_detrended.permute(0, 2, 1) # back to (B, L, C)
326
+
327
+ # ------------------------------------------------------------------ FFT
328
+ dft = torch.fft.fft(x, n=n_fft, dim=1) / np.sqrt(n_fft)
329
+ dft = dft[:, : n_fft // 2, :] # keep positive freqs
330
+ I = torch.abs(dft) ** 2 # periodogram
331
+
332
+ # ------------------------------------------------------------------ normalise
333
+ I_sum = I.sum(dim=1, keepdim=True)
334
+ I_sum[I_sum == 0] = 1 # avoid /0
335
+ I /= I_sum
336
+
337
+ # ------------------------------------------------------------------ squeeze back if original was 2-D
338
+ if inputs.ndim == 2:
339
  I = I.squeeze(2)
340
+
341
  return I
342
 
343