Update modeling_super_linear.py
Browse files- modeling_super_linear.py +64 -105
modeling_super_linear.py
CHANGED
|
@@ -11,84 +11,7 @@ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
|
| 11 |
from .configuration_super_linear import SuperLinearConfig
|
| 12 |
|
| 13 |
|
| 14 |
-
"-------------------------------------------------------------------------------------------------------------------"
|
| 15 |
-
class RevIN(nn.Module):
    """Reversible instance normalization.

    ``forward(x, 'norm')`` records statistics of ``x`` on the module and
    returns the normalized tensor; ``forward(y, 'denorm')`` applies the inverse
    transform using the statistics stored by the most recent ``'norm'`` call.
    Statistics are detached, so no gradient flows through them.
    """

    def __init__(self, num_features: int, eps=1e-5, affine=True, norm_type=None, subtract_last=False):
        """
        :param num_features: the number of features or channels (only used to
            size the affine parameters when ``affine`` is True)
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        :param norm_type: ``None`` for the variance-based scale, ``"l1"`` for
            mean absolute deviation, ``"l2"`` for root mean squared deviation
        :param subtract_last: if True, centre on the last time step instead of
            the mean
        """
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        self.norm_type = norm_type
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        """Normalize (``mode='norm'``) or de-normalize (``mode='denorm'``) ``x``.

        :raises NotImplementedError: for any other ``mode``
        """
        if mode == 'norm':
            self._get_statistics(x)
            return self._normalize(x)
        if mode == 'denorm':
            return self._denormalize(x)
        raise NotImplementedError(f"Unsupported RevIN mode: {mode!r}")

    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        """Store ``mean`` (and ``last`` when requested) and ``stdev`` for ``x``."""
        # Reduce over every axis except the first and the last.
        dim2reduce = tuple(range(1, x.ndim - 1))

        if self.subtract_last:
            self.last = x[:, -1:, :].detach()
            # The mean deliberately excludes the final step in this mode.
            self.mean = torch.mean(x[:, :-1, :], dim=dim2reduce, keepdim=True).detach()
        else:
            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()

        if self.norm_type == "l1":
            # eps keeps the scale strictly positive so a constant series does
            # not produce 0/0 = NaN in _normalize.
            self.stdev = (
                torch.mean(torch.abs(x - self.mean), dim=dim2reduce, keepdim=True) + self.eps
            ).detach()
        elif self.norm_type == "l2":
            self.stdev = torch.sqrt(
                torch.mean((x - self.mean) ** 2, dim=dim2reduce, keepdim=True) + self.eps
            ).detach()
        else:
            self.stdev = torch.sqrt(
                torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps
            ).detach()

    def _normalize(self, x):
        """Centre and scale ``x`` with the stored statistics."""
        x = x - (self.last if self.subtract_last else self.mean)
        x = x / self.stdev

        # NOTE(review): for "l1"/"l2" the scale is applied a second time.
        # _denormalize mirrors this, so the round trip is exact; kept as-is
        # to preserve existing numerical behaviour -- confirm it is intended.
        if self.norm_type in ["l1", "l2"]:
            x = x / self.stdev

        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        """Invert :meth:`_normalize` using the stored statistics."""
        if self.affine:
            x = x - self.affine_bias
            # eps**2 guards against division by an exactly-zero weight.
            x = x / (self.affine_weight + self.eps * self.eps)

        # Mirror of the second scaling step in _normalize.
        if self.norm_type in ["l1", "l2"]:
            x = x * self.stdev

        x = x * self.stdev
        x = x + (self.last if self.subtract_last else self.mean)
        return x
|
| 92 |
"-------------------------------------------------------------------------------------------------------------------"
|
| 93 |
class Linear(nn.Module):
|
| 94 |
"""Simple linear layer expert."""
|
|
@@ -124,27 +47,6 @@ class Mean(nn.Module):
|
|
| 124 |
x = x.mean(dim=1).unsqueeze(1).repeat(1, self.output_len)
|
| 125 |
return x # to [Batch, Output length, Channel]
|
| 126 |
|
| 127 |
-
class RLinear(nn.Module):
    """Linear expert wrapped in reversible instance normalization (RevIN).

    The input is normalized, projected along the time axis by a single
    ``nn.Linear``, and then de-normalized with the statistics of the input.
    """

    def __init__(self, input_len, output_len):
        super(RLinear, self).__init__()
        self.Linear = nn.Linear(input_len, output_len)
        # No affine parameters are created, so num_features is not needed.
        self.revin_layer = RevIN(num_features=None, affine=False, norm_type=None, subtract_last=False)

    def forward(self, x):
        # x: [Batch, Input length, Channel]; a 2-D input gets a temporary
        # trailing channel axis that is removed again before returning.
        had_no_channel_axis = len(x.shape) == 2
        series = x.unsqueeze(-1) if had_no_channel_axis else x

        normalized = self.revin_layer(series.clone(), 'norm')

        # nn.Linear acts on the last axis, so move time there and back.
        time_last = normalized.permute(0, 2, 1)
        projected = self.Linear(time_last).permute(0, 2, 1).clone()

        restored = self.revin_layer(projected, 'denorm')
        if had_no_channel_axis:
            restored = restored.squeeze(-1)
        return restored  # to [Batch, Output length, Channel]
|
| 147 |
-
|
| 148 |
"-------------------------------------------------------------------------------------------------------------------"
|
| 149 |
class SparseMoE(nn.Module):
|
| 150 |
"""
|
|
@@ -171,6 +73,24 @@ class SparseMoE(nn.Module):
|
|
| 171 |
self.use_fft = configs.use_fft
|
| 172 |
self.fft_len = configs.fft_len
|
| 173 |
self.moe_norm = configs.moe_norm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
# Initialize gating network based on configuration
|
| 176 |
if self.use_fft:
|
|
@@ -181,6 +101,18 @@ class SparseMoE(nn.Module):
|
|
| 181 |
if self.moe_norm:
|
| 182 |
self.batch_norm = nn.BatchNorm1d(self.num_experts)
|
| 183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
def get_periodogram(self, inputs, n=10000):
|
| 185 |
"""
|
| 186 |
Calculate the periodogram (power spectral density) of input time series.
|
|
@@ -252,9 +184,32 @@ class SparseMoE(nn.Module):
|
|
| 252 |
|
| 253 |
# Normalize the gate values with softmax
|
| 254 |
topk_gates = F.softmax(topk_values, dim=1)
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
|
| 259 |
# Select only the outputs from the top-k experts
|
| 260 |
topk_indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, expert_outputs.size(2))
|
|
@@ -262,7 +217,10 @@ class SparseMoE(nn.Module):
|
|
| 262 |
|
| 263 |
# Combine expert outputs using the gate values
|
| 264 |
output = torch.sum(topk_gates.unsqueeze(2) * sparse_expert_outputs, dim=1)
|
| 265 |
-
|
|
|
|
|
|
|
|
|
|
| 266 |
if get_prob:
|
| 267 |
expert_probs = F.softmax(gate_outputs, dim=1)
|
| 268 |
return output, expert_probs
|
|
@@ -324,7 +282,7 @@ class Model(nn.Module):
|
|
| 324 |
elif expert_freq.lower() == "mean":
|
| 325 |
self.experts[expert_freq] = Mean(self.train_seq_len, self.train_pred_len)
|
| 326 |
else:
|
| 327 |
-
self.experts[expert_freq] =
|
| 328 |
self.n_experts = len(self.experts)
|
| 329 |
else:
|
| 330 |
raise ValueError("Please specify experts in the configuration.")
|
|
@@ -334,11 +292,11 @@ class Model(nn.Module):
|
|
| 334 |
if comp_moe > 0:
|
| 335 |
if comp_moe == 1:
|
| 336 |
print("Creating complementary expert")
|
| 337 |
-
self.experts["comp"] =
|
| 338 |
else:
|
| 339 |
for i in range(comp_moe):
|
| 340 |
print(f"Creating complementary expert {i}")
|
| 341 |
-
self.experts["comp_"+str(i)] =
|
| 342 |
|
| 343 |
# Initialize the MoE layer and dropout
|
| 344 |
self.moe = SparseMoE(configs, experts=self.experts.values())
|
|
@@ -619,6 +577,7 @@ class Model(nn.Module):
|
|
| 619 |
|
| 620 |
if x_in.dim() == 2:
|
| 621 |
out = out.squeeze(1)
|
|
|
|
| 622 |
|
| 623 |
if get_prob:
|
| 624 |
expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
|
|
|
|
| 11 |
from .configuration_super_linear import SuperLinearConfig
|
| 12 |
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
"-------------------------------------------------------------------------------------------------------------------"
|
| 16 |
class Linear(nn.Module):
|
| 17 |
"""Simple linear layer expert."""
|
|
|
|
| 47 |
x = x.mean(dim=1).unsqueeze(1).repeat(1, self.output_len)
|
| 48 |
return x # to [Batch, Output length, Channel]
|
| 49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
"-------------------------------------------------------------------------------------------------------------------"
|
| 51 |
class SparseMoE(nn.Module):
|
| 52 |
"""
|
|
|
|
| 73 |
self.use_fft = configs.use_fft
|
| 74 |
self.fft_len = configs.fft_len
|
| 75 |
self.moe_norm = configs.moe_norm
|
| 76 |
+
|
| 77 |
+
# Cache for batched expert parameters
|
| 78 |
+
self.stacked_weights = None
|
| 79 |
+
self.stacked_biases = None
|
| 80 |
+
|
| 81 |
+
# Separate linear and non-linear experts
|
| 82 |
+
self.linear_expert_types = ['Linear']
|
| 83 |
+
self.linear_experts = []
|
| 84 |
+
self.nonlinear_experts = []
|
| 85 |
+
|
| 86 |
+
for idx, expert in enumerate(self.experts):
|
| 87 |
+
expert_type = type(expert).__name__
|
| 88 |
+
if expert_type in self.linear_expert_types:
|
| 89 |
+
self.linear_experts.append(idx)
|
| 90 |
+
else:
|
| 91 |
+
self.nonlinear_experts.append(idx)
|
| 92 |
+
self.num_linear_experts = len(self.linear_experts)
|
| 93 |
+
self.num_nonlinear_experts = len(self.nonlinear_experts)
|
| 94 |
|
| 95 |
# Initialize gating network based on configuration
|
| 96 |
if self.use_fft:
|
|
|
|
| 101 |
if self.moe_norm:
|
| 102 |
self.batch_norm = nn.BatchNorm1d(self.num_experts)
|
| 103 |
|
| 104 |
+
def _get_stacked_expert_params(self):
    """Return the stacked ``(weights, biases)`` of all linear experts.

    Each expert listed in ``self.linear_experts`` contributes its
    ``.Linear.weight`` and ``.Linear.bias``; they are stacked along a new
    leading axis in that order.

    The stacked tensors are reused only in pure inference (module in eval
    mode with autograd disabled) and only while they still match the device
    and dtype of the live parameters. Otherwise they are rebuilt on every
    call, because a cached copy would go stale after an optimizer step and
    would tie later iterations to the autograd graph of the first call.

    NOTE(review): in inference the cache is not invalidated by in-place
    weight changes (e.g. ``load_state_dict`` after the first forward pass);
    reset ``stacked_weights``/``stacked_biases`` to ``None`` in that case.
    """
    linear_layers = [self.experts[i].Linear for i in self.linear_experts]
    inference_only = not self.training and not torch.is_grad_enabled()

    cached_weights = self.stacked_weights
    cached_biases = self.stacked_biases
    if (
        inference_only
        and cached_weights is not None
        and cached_biases is not None
        and linear_layers
        and cached_weights.device == linear_layers[0].weight.device
        and cached_weights.dtype == linear_layers[0].weight.dtype
    ):
        return cached_weights, cached_biases

    # Stack all linear expert weights along a new leading expert axis.
    weights = torch.stack([layer.weight for layer in linear_layers], dim=0)
    # Stack all linear expert biases the same way.
    biases = torch.stack([layer.bias for layer in linear_layers], dim=0)

    if inference_only:
        self.stacked_weights = weights
        self.stacked_biases = biases
    else:
        # Drop any inference-time cache so it cannot be served stale later.
        self.stacked_weights = None
        self.stacked_biases = None
    return weights, biases
|
| 115 |
+
|
| 116 |
def get_periodogram(self, inputs, n=10000):
|
| 117 |
"""
|
| 118 |
Calculate the periodogram (power spectral density) of input time series.
|
|
|
|
| 184 |
|
| 185 |
# Normalize the gate values with softmax
|
| 186 |
topk_gates = F.softmax(topk_values, dim=1)
|
| 187 |
+
|
| 188 |
+
batch_size = x.size(0)
|
| 189 |
+
|
| 190 |
+
# RLinear (REVIN) normalization
|
| 191 |
+
x_mean, x_std = torch.mean(x, dim=1, keepdim=True), torch.std(x, dim=1, keepdim=True)
|
| 192 |
+
x_norm = (x - x_mean) / (x_std + 1e-5)
|
| 193 |
+
|
| 194 |
+
# Initialize expert_outputs tensor with correct shape (infer pred_len from first expert)
|
| 195 |
+
pred_len = self.experts[0](x_norm[:1]).shape[-1]
|
| 196 |
+
expert_outputs = torch.zeros(batch_size, self.num_experts, pred_len, device=x.device)
|
| 197 |
+
|
| 198 |
+
# Process linear experts using batched operations
|
| 199 |
+
if self.num_linear_experts > 0:
|
| 200 |
+
all_weights, all_biases = self._get_stacked_expert_params()
|
| 201 |
+
# Batched matrix multiplication: [n_linear_experts, pred_len, seq_len] @ [B, seq_len]
|
| 202 |
+
# Using einsum: expert, pred, seq @ batch, seq -> batch, expert, pred
|
| 203 |
+
linear_expert_outputs = torch.einsum('epd,bd->bep', all_weights, x_norm)
|
| 204 |
+
# Add biases: [n_linear_experts, pred_len] -> [1, n_linear_experts, pred_len]
|
| 205 |
+
linear_expert_outputs = linear_expert_outputs + all_biases.unsqueeze(0)
|
| 206 |
+
# Place linear expert outputs in their correct positions
|
| 207 |
+
for i, expert_idx in enumerate(self.linear_experts):
|
| 208 |
+
expert_outputs[:, expert_idx, :] = linear_expert_outputs[:, i, :]
|
| 209 |
+
|
| 210 |
+
# Process non-linear experts separately and place in correct positions
|
| 211 |
+
for expert_idx in self.nonlinear_experts:
|
| 212 |
+
expert_outputs[:, expert_idx, :] = self.experts[expert_idx](x_norm)
|
| 213 |
|
| 214 |
# Select only the outputs from the top-k experts
|
| 215 |
topk_indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, expert_outputs.size(2))
|
|
|
|
| 217 |
|
| 218 |
# Combine expert outputs using the gate values
|
| 219 |
output = torch.sum(topk_gates.unsqueeze(2) * sparse_expert_outputs, dim=1)
|
| 220 |
+
|
| 221 |
+
# RLinear (REVIN) denormalization
|
| 222 |
+
output = output * (x_std + 1e-5) + x_mean
|
| 223 |
+
|
| 224 |
if get_prob:
|
| 225 |
expert_probs = F.softmax(gate_outputs, dim=1)
|
| 226 |
return output, expert_probs
|
|
|
|
| 282 |
elif expert_freq.lower() == "mean":
|
| 283 |
self.experts[expert_freq] = Mean(self.train_seq_len, self.train_pred_len)
|
| 284 |
else:
|
| 285 |
+
self.experts[expert_freq] = Linear(self.train_seq_len, self.train_pred_len)
|
| 286 |
self.n_experts = len(self.experts)
|
| 287 |
else:
|
| 288 |
raise ValueError("Please specify experts in the configuration.")
|
|
|
|
| 292 |
if comp_moe > 0:
|
| 293 |
if comp_moe == 1:
|
| 294 |
print("Creating complementary expert")
|
| 295 |
+
self.experts["comp"] = Linear(self.train_seq_len, self.train_pred_len)
|
| 296 |
else:
|
| 297 |
for i in range(comp_moe):
|
| 298 |
print(f"Creating complementary expert {i}")
|
| 299 |
+
self.experts["comp_"+str(i)] = Linear(self.train_seq_len, self.train_pred_len)
|
| 300 |
|
| 301 |
# Initialize the MoE layer and dropout
|
| 302 |
self.moe = SparseMoE(configs, experts=self.experts.values())
|
|
|
|
| 577 |
|
| 578 |
if x_in.dim() == 2:
|
| 579 |
out = out.squeeze(1)
|
| 580 |
+
|
| 581 |
|
| 582 |
if get_prob:
|
| 583 |
expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
|