razmars commited on
Commit
0885b44
·
verified ·
1 Parent(s): dcf99e3

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +28 -52
modeling_super_linear.py CHANGED
@@ -293,64 +293,40 @@ class SparseNoisyMoE(nn.Module):
293
  self.gating_network = nn.Linear(input_dim, self.num_experts, bias=True)
294
 
295
  def get_periodogram(self, inputs, ker_len=50, con=1, n=10000):
296
- n_fft = 128
297
- ker_len =12
298
- if inputs.ndim == 2: # (B, L)
299
- B, L = inputs.shape
300
- C = 1
301
- x = inputs.unsqueeze(2) # → (B, L, 1)
302
- time_first = True # time is dim-1
303
- elif inputs.ndim == 3:
304
- B, d1, d2 = inputs.shape
305
- if d1 < d2: # (B, L, C)
306
- L, C = d1, d2
307
- x = inputs
308
- time_first = True
309
- else: # (B, C, L)
310
- C, L = d1, d2
311
- x = inputs.transpose(1, 2) # → (B, L, C)
312
- time_first = False
313
- else:
314
- raise ValueError("Input must be (B,L), (B,L,C) or (B,C,L)")
315
-
316
- # ---------- centre the signal ----------
317
- x = x - x.mean(dim=1, keepdim=True)
318
 
319
- # ---------- parameter defaults ----------
320
- if n_fft is None:
321
- n_fft = 1 << (L - 1).bit_length()
322
- if ker_len is None:
323
- ker_len = min(L // 4, 50)
324
- ker_half = ker_len // 2
325
 
326
- # ---------- high-pass detrend ----------
327
- if con and ker_len > 0:
328
- x_perm = x.permute(0, 2, 1) # (B, C, L)
329
- ker = torch.ones(1, 1, ker_len, device=x.device) / ker_len
330
- trend = F.conv1d(x_perm, ker, padding="same")
331
- left = min(ker_half, L - 1)
332
- right = min(ker_half, L - 1)
333
- trend[:, :, :left] = trend[:, :, left:left+1]
334
- trend[:, :, -right:] = trend[:, :, -(right+1):-right]
335
- x = (x_perm - trend).permute(0, 2, 1) # back to (B, L, C)
336
 
337
- # ---------- FFT ----------
338
- dft = torch.fft.fft(x, n=n_fft, dim=1) / np.sqrt(n_fft)
339
- I = (dft[:, : n_fft//2, :]).abs() ** 2
340
 
341
- # ---------- normalise ----------
342
- I_sum = I.sum(dim=1, keepdim=True)
343
- I_sum[I_sum == 0] = 1
344
- I /= I_sum
345
 
346
- # ---------- restore original layout ----------
347
- if inputs.ndim == 2: # wanted (B, … )
348
- return I.squeeze(2)
349
 
350
- if time_first: # original was (B, L, C)
351
- return I # already (B, F, C)
352
- else: # original was (B, C, L) → (B, C, F)
353
- return I.transpose(1, 2)
354
 
355
 
356
  def fourier_interp_dim1(self,x, target_len: int = 512):
 
293
  self.gating_network = nn.Linear(input_dim, self.num_experts, bias=True)
294
 
295
  def get_periodogram(self, inputs, ker_len=50, con=1, n=10000):
296
+ if inputs.dim() == 2:
297
+ x_0 = inputs.unsqueeze(2)
298
+ else:
299
+ x_0 = inputs
300
+ x_0 = x_0 - torch.mean(x_0, dim=1, keepdim=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
 
302
+ v = torch.arange(0, n) / n
303
+ if con:
304
+ if ker_len is None:
305
+ ker_len = n // 4
306
+ ker_len = min(ker_len, 50)
 
307
 
308
+ x_0 = x_0.permute(0, 2, 1)
309
+ ker = (torch.ones(1, 1, ker_len) / ker_len).to(x_0.device)
310
+ x_c = F.conv1d(x_0, ker, padding="same")
311
+ x_c[:, :, :ker_len // 2] = x_c[:, :, ker_len // 2:ker_len // 2 + 1]
312
+ x_c[:, :, -ker_len // 2:] = x_c[:, :, -ker_len // 2 - 1:-ker_len // 2]
313
+ x_0 = x_0 - x_c
314
+ x_0 = x_0.permute(0, 2, 1)
 
 
 
315
 
316
+ dft = torch.fft.fft(x_0, dim=1, n=n) / np.sqrt(n)
317
+ dft = dft[:, :n//2, :]
318
+ I = torch.abs(dft) ** 2
319
 
320
+ I_sum = torch.sum(I, dim=1, keepdim=True)
321
+ I_sum[I_sum == 0] = 1
322
+ I = I / I_sum
 
323
 
324
+ if torch.any(I_sum == 0):
325
+ print("Zeros in the sum")
326
+ raise ValueError
327
 
328
+ if inputs.dim() == 2:
329
+ I = I.squeeze(2)
 
 
330
 
331
 
332
  def fourier_interp_dim1(self,x, target_len: int = 512):