razmars commited on
Commit
f0f208d
·
verified ·
1 Parent(s): 0885b44

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +34 -34
modeling_super_linear.py CHANGED
@@ -293,40 +293,40 @@ class SparseNoisyMoE(nn.Module):
293
  self.gating_network = nn.Linear(input_dim, self.num_experts, bias=True)
294
 
295
  def get_periodogram(self, inputs, ker_len=50, con=1, n=10000):
296
- if inputs.dim() == 2:
297
- x_0 = inputs.unsqueeze(2)
298
- else:
299
- x_0 = inputs
300
- x_0 = x_0 - torch.mean(x_0, dim=1, keepdim=True)
301
-
302
- v = torch.arange(0, n) / n
303
- if con:
304
- if ker_len is None:
305
- ker_len = n // 4
306
- ker_len = min(ker_len, 50)
307
-
308
- x_0 = x_0.permute(0, 2, 1)
309
- ker = (torch.ones(1, 1, ker_len) / ker_len).to(x_0.device)
310
- x_c = F.conv1d(x_0, ker, padding="same")
311
- x_c[:, :, :ker_len // 2] = x_c[:, :, ker_len // 2:ker_len // 2 + 1]
312
- x_c[:, :, -ker_len // 2:] = x_c[:, :, -ker_len // 2 - 1:-ker_len // 2]
313
- x_0 = x_0 - x_c
314
- x_0 = x_0.permute(0, 2, 1)
315
-
316
- dft = torch.fft.fft(x_0, dim=1, n=n) / np.sqrt(n)
317
- dft = dft[:, :n//2, :]
318
- I = torch.abs(dft) ** 2
319
-
320
- I_sum = torch.sum(I, dim=1, keepdim=True)
321
- I_sum[I_sum == 0] = 1
322
- I = I / I_sum
323
-
324
- if torch.any(I_sum == 0):
325
- print("Zeros in the sum")
326
- raise ValueError
327
-
328
- if inputs.dim() == 2:
329
- I = I.squeeze(2)
330
 
331
 
332
  def fourier_interp_dim1(self,x, target_len: int = 512):
 
293
  self.gating_network = nn.Linear(input_dim, self.num_experts, bias=True)
294
 
295
  def get_periodogram(self, inputs, ker_len=50, con=1, n=10000):
296
+ if inputs.dim() == 2:
297
+ x_0 = inputs.unsqueeze(2)
298
+ else:
299
+ x_0 = inputs
300
+ x_0 = x_0 - torch.mean(x_0, dim=1, keepdim=True)
301
+
302
+ v = torch.arange(0, n) / n
303
+ if con:
304
+ if ker_len is None:
305
+ ker_len = n // 4
306
+ ker_len = min(ker_len, 50)
307
+
308
+ x_0 = x_0.permute(0, 2, 1)
309
+ ker = (torch.ones(1, 1, ker_len) / ker_len).to(x_0.device)
310
+ x_c = F.conv1d(x_0, ker, padding="same")
311
+ x_c[:, :, :ker_len // 2] = x_c[:, :, ker_len // 2:ker_len // 2 + 1]
312
+ x_c[:, :, -ker_len // 2:] = x_c[:, :, -ker_len // 2 - 1:-ker_len // 2]
313
+ x_0 = x_0 - x_c
314
+ x_0 = x_0.permute(0, 2, 1)
315
+
316
+ dft = torch.fft.fft(x_0, dim=1, n=n) / np.sqrt(n)
317
+ dft = dft[:, :n//2, :]
318
+ I = torch.abs(dft) ** 2
319
+
320
+ I_sum = torch.sum(I, dim=1, keepdim=True)
321
+ I_sum[I_sum == 0] = 1
322
+ I = I / I_sum
323
+
324
+ if torch.any(I_sum == 0):
325
+ print("Zeros in the sum")
326
+ raise ValueError
327
+
328
+ if inputs.dim() == 2:
329
+ I = I.squeeze(2)
330
 
331
 
332
  def fourier_interp_dim1(self,x, target_len: int = 512):