lirannoc commited on
Commit
6128d0f
·
verified ·
1 Parent(s): cecd232

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +64 -105
modeling_super_linear.py CHANGED
@@ -11,84 +11,7 @@ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
11
  from .configuration_super_linear import SuperLinearConfig
12
 
13
 
14
- "-------------------------------------------------------------------------------------------------------------------"
15
- class RevIN(nn.Module):
16
- def __init__(self, num_features: int, eps=1e-5, affine=True, norm_type=None, subtract_last=False):
17
- """
18
- :param num_features: the number of features or channels
19
- :param eps: a value added for numerical stability
20
- :param affine: if True, RevIN has learnable affine parameters
21
- """
22
- super(RevIN, self).__init__()
23
- self.num_features = num_features
24
- self.eps = eps
25
- self.affine = affine
26
- self.subtract_last = subtract_last
27
- self.norm_type = norm_type
28
- if self.affine:
29
- self._init_params()
30
-
31
- def forward(self, x, mode: str):
32
- if mode == 'norm':
33
- self._get_statistics(x)
34
- x = self._normalize(x)
35
- elif mode == 'denorm':
36
- x = self._denormalize(x)
37
- else:
38
- raise NotImplementedError
39
- return x
40
-
41
- def _init_params(self):
42
- # initialize RevIN params: (C,)
43
- self.affine_weight = nn.Parameter(torch.ones(self.num_features))
44
- self.affine_bias = nn.Parameter(torch.zeros(self.num_features))
45
-
46
- def _get_statistics(self, x):
47
- dim2reduce = tuple(range(1, x.ndim-1))
48
-
49
- if self.subtract_last:
50
- self.last = x[:, -1:, :].detach()
51
- self.mean = torch.mean(x[:, :-1, :], dim=dim2reduce, keepdim=True).detach()
52
- else:
53
- self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
54
-
55
- self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()
56
-
57
- if self.norm_type == "l1":
58
- self.stdev = torch.mean(torch.abs(x - self.mean), dim=dim2reduce, keepdim=True).detach()
59
- elif self.norm_type == "l2":
60
- self.stdev = torch.sqrt(torch.mean((x - self.mean) ** 2, dim=dim2reduce, keepdim=True) + self.eps).detach()
61
-
62
- def _normalize(self, x):
63
- if self.subtract_last:
64
- x = x - self.last
65
- else:
66
- x = x - self.mean
67
- x = x / self.stdev
68
 
69
- if self.norm_type in ["l1", "l2"]:
70
- x = x / self.stdev
71
-
72
- if self.affine:
73
- x = x * self.affine_weight
74
- x = x + self.affine_bias
75
- return x
76
-
77
- def _denormalize(self, x):
78
- if self.affine:
79
- x = x - self.affine_bias
80
- x = x / (self.affine_weight + self.eps*self.eps)
81
-
82
- if self.norm_type in ["l1", "l2"]:
83
- x = x * self.stdev
84
-
85
- x = x * self.stdev
86
- if self.subtract_last:
87
- x = x + self.last
88
- else:
89
- x = x + self.mean
90
-
91
- return x
92
  "-------------------------------------------------------------------------------------------------------------------"
93
  class Linear(nn.Module):
94
  """Simple linear layer expert."""
@@ -124,27 +47,6 @@ class Mean(nn.Module):
124
  x = x.mean(dim=1).unsqueeze(1).repeat(1, self.output_len)
125
  return x # to [Batch, Output length, Channel]
126
 
127
- class RLinear(nn.Module):
128
- """Reversible Instance Normalization Linear layer expert."""
129
- def __init__(self, input_len, output_len):
130
- super(RLinear, self).__init__()
131
- self.Linear = nn.Linear(input_len, output_len)
132
- self.revin_layer = RevIN(num_features = None, affine=False, norm_type = None, subtract_last = False)
133
-
134
- def forward(self, x):
135
- # x: [Batch, Input length,Channel]
136
- x_shape = x.shape
137
- if len(x_shape) == 2:
138
- x = x.unsqueeze(-1)
139
- x = x.clone()
140
- x = self.revin_layer(x, 'norm')
141
-
142
- x = self.Linear(x.permute(0,2,1)).permute(0,2,1).clone()
143
- x = self.revin_layer(x, 'denorm')
144
- if len(x_shape) == 2:
145
- x = x.squeeze(-1)
146
- return x # to [Batch, Output length, Channel]
147
-
148
  "-------------------------------------------------------------------------------------------------------------------"
149
  class SparseMoE(nn.Module):
150
  """
@@ -171,6 +73,24 @@ class SparseMoE(nn.Module):
171
  self.use_fft = configs.use_fft
172
  self.fft_len = configs.fft_len
173
  self.moe_norm = configs.moe_norm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
  # Initialize gating network based on configuration
176
  if self.use_fft:
@@ -181,6 +101,18 @@ class SparseMoE(nn.Module):
181
  if self.moe_norm:
182
  self.batch_norm = nn.BatchNorm1d(self.num_experts)
183
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  def get_periodogram(self, inputs, n=10000):
185
  """
186
  Calculate the periodogram (power spectral density) of input time series.
@@ -252,9 +184,32 @@ class SparseMoE(nn.Module):
252
 
253
  # Normalize the gate values with softmax
254
  topk_gates = F.softmax(topk_values, dim=1)
255
-
256
- # Get outputs from all experts
257
- expert_outputs = torch.stack([self.experts[i](x) for i in range(self.num_experts)], dim=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
  # Select only the outputs from the top-k experts
260
  topk_indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, expert_outputs.size(2))
@@ -262,7 +217,10 @@ class SparseMoE(nn.Module):
262
 
263
  # Combine expert outputs using the gate values
264
  output = torch.sum(topk_gates.unsqueeze(2) * sparse_expert_outputs, dim=1)
265
-
 
 
 
266
  if get_prob:
267
  expert_probs = F.softmax(gate_outputs, dim=1)
268
  return output, expert_probs
@@ -324,7 +282,7 @@ class Model(nn.Module):
324
  elif expert_freq.lower() == "mean":
325
  self.experts[expert_freq] = Mean(self.train_seq_len, self.train_pred_len)
326
  else:
327
- self.experts[expert_freq] = RLinear(self.train_seq_len, self.train_pred_len)
328
  self.n_experts = len(self.experts)
329
  else:
330
  raise ValueError("Please specify experts in the configuration.")
@@ -334,11 +292,11 @@ class Model(nn.Module):
334
  if comp_moe > 0:
335
  if comp_moe == 1:
336
  print("Creating complementary expert")
337
- self.experts["comp"] = RLinear(self.train_seq_len, self.train_pred_len)
338
  else:
339
  for i in range(comp_moe):
340
  print(f"Creating complementary expert {i}")
341
- self.experts["comp_"+str(i)] = RLinear(self.train_seq_len, self.train_pred_len)
342
 
343
  # Initialize the MoE layer and dropout
344
  self.moe = SparseMoE(configs, experts=self.experts.values())
@@ -619,6 +577,7 @@ class Model(nn.Module):
619
 
620
  if x_in.dim() == 2:
621
  out = out.squeeze(1)
 
622
 
623
  if get_prob:
624
  expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
 
11
  from .configuration_super_linear import SuperLinearConfig
12
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  "-------------------------------------------------------------------------------------------------------------------"
16
  class Linear(nn.Module):
17
  """Simple linear layer expert."""
 
47
  x = x.mean(dim=1).unsqueeze(1).repeat(1, self.output_len)
48
  return x # to [Batch, Output length, Channel]
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  "-------------------------------------------------------------------------------------------------------------------"
51
  class SparseMoE(nn.Module):
52
  """
 
73
  self.use_fft = configs.use_fft
74
  self.fft_len = configs.fft_len
75
  self.moe_norm = configs.moe_norm
76
+
77
+ # Cache for batched expert parameters
78
+ self.stacked_weights = None
79
+ self.stacked_biases = None
80
+
81
+ # Separate linear and non-linear experts
82
+ self.linear_expert_types = ['Linear']
83
+ self.linear_experts = []
84
+ self.nonlinear_experts = []
85
+
86
+ for idx, expert in enumerate(self.experts):
87
+ expert_type = type(expert).__name__
88
+ if expert_type in self.linear_expert_types:
89
+ self.linear_experts.append(idx)
90
+ else:
91
+ self.nonlinear_experts.append(idx)
92
+ self.num_linear_experts = len(self.linear_experts)
93
+ self.num_nonlinear_experts = len(self.nonlinear_experts)
94
 
95
  # Initialize gating network based on configuration
96
  if self.use_fft:
 
101
  if self.moe_norm:
102
  self.batch_norm = nn.BatchNorm1d(self.num_experts)
103
 
104
+ def _get_stacked_expert_params(self):
105
+ """Get batched parameters for linear experts."""
106
+ if self.stacked_weights is None:
107
+ # Stack all linear expert weights: [n_linear_experts, pred_len, seq_len]
108
+ weights = torch.stack([self.experts[i].Linear.weight for i in self.linear_experts], dim=0)
109
+ # Stack all linear expert biases: [n_linear_experts, pred_len]
110
+ biases = torch.stack([self.experts[i].Linear.bias for i in self.linear_experts], dim=0)
111
+
112
+ self.stacked_weights = weights
113
+ self.stacked_biases = biases
114
+ return self.stacked_weights, self.stacked_biases
115
+
116
  def get_periodogram(self, inputs, n=10000):
117
  """
118
  Calculate the periodogram (power spectral density) of input time series.
 
184
 
185
  # Normalize the gate values with softmax
186
  topk_gates = F.softmax(topk_values, dim=1)
187
+
188
+ batch_size = x.size(0)
189
+
190
+ # RLinear (REVIN) normalization
191
+ x_mean, x_std = torch.mean(x, dim=1, keepdim=True), torch.std(x, dim=1, keepdim=True)
192
+ x_norm = (x - x_mean) / (x_std + 1e-5)
193
+
194
+ # Initialize expert_outputs tensor with correct shape (infer pred_len from first expert)
195
+ pred_len = self.experts[0](x_norm[:1]).shape[-1]
196
+ expert_outputs = torch.zeros(batch_size, self.num_experts, pred_len, device=x.device)
197
+
198
+ # Process linear experts using batched operations
199
+ if self.num_linear_experts > 0:
200
+ all_weights, all_biases = self._get_stacked_expert_params()
201
+ # Batched matrix multiplication: [n_linear_experts, pred_len, seq_len] @ [B, seq_len]
202
+ # Using einsum: expert, pred, seq @ batch, seq -> batch, expert, pred
203
+ linear_expert_outputs = torch.einsum('epd,bd->bep', all_weights, x_norm)
204
+ # Add biases: [n_linear_experts, pred_len] -> [1, n_linear_experts, pred_len]
205
+ linear_expert_outputs = linear_expert_outputs + all_biases.unsqueeze(0)
206
+ # Place linear expert outputs in their correct positions
207
+ for i, expert_idx in enumerate(self.linear_experts):
208
+ expert_outputs[:, expert_idx, :] = linear_expert_outputs[:, i, :]
209
+
210
+ # Process non-linear experts separately and place in correct positions
211
+ for expert_idx in self.nonlinear_experts:
212
+ expert_outputs[:, expert_idx, :] = self.experts[expert_idx](x_norm)
213
 
214
  # Select only the outputs from the top-k experts
215
  topk_indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, expert_outputs.size(2))
 
217
 
218
  # Combine expert outputs using the gate values
219
  output = torch.sum(topk_gates.unsqueeze(2) * sparse_expert_outputs, dim=1)
220
+
221
+ # RLinear (REVIN) denormalization
222
+ output = output * (x_std + 1e-5) + x_mean
223
+
224
  if get_prob:
225
  expert_probs = F.softmax(gate_outputs, dim=1)
226
  return output, expert_probs
 
282
  elif expert_freq.lower() == "mean":
283
  self.experts[expert_freq] = Mean(self.train_seq_len, self.train_pred_len)
284
  else:
285
+ self.experts[expert_freq] = Linear(self.train_seq_len, self.train_pred_len)
286
  self.n_experts = len(self.experts)
287
  else:
288
  raise ValueError("Please specify experts in the configuration.")
 
292
  if comp_moe > 0:
293
  if comp_moe == 1:
294
  print("Creating complementary expert")
295
+ self.experts["comp"] = Linear(self.train_seq_len, self.train_pred_len)
296
  else:
297
  for i in range(comp_moe):
298
  print(f"Creating complementary expert {i}")
299
+ self.experts["comp_"+str(i)] = Linear(self.train_seq_len, self.train_pred_len)
300
 
301
  # Initialize the MoE layer and dropout
302
  self.moe = SparseMoE(configs, experts=self.experts.values())
 
577
 
578
  if x_in.dim() == 2:
579
  out = out.squeeze(1)
580
+
581
 
582
  if get_prob:
583
  expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])