razmars commited on
Commit
1812686
·
verified ·
1 Parent(s): c800a8a

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +4 -12
modeling_super_linear.py CHANGED
@@ -98,16 +98,8 @@ class moving_avg(nn.Module):
98
  super(moving_avg, self).__init__()
99
  self.kernel_size = kernel_size
100
  self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)
101
- """
102
- def forward(self, x):
103
- # padding on the both ends of time series
104
- front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
105
- end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
106
- x = torch.cat([front, x, end], dim=1)
107
- x = self.avg(x.permute(0, 2, 1))
108
- x = x.permute(0, 2, 1)
109
- return x
110
- """
111
  def forward(self, x):
112
  # x: [Batch, Input length]
113
  # padding on the both ends of time series
@@ -236,7 +228,7 @@ class RLinear(nn.Module):
236
  def forward(self, x):
237
  # x: [Batch, Input length,Channel]
238
  x_shape = x.shape
239
- if x.shape[1] < self.seq_len:
240
  #if self.zero_shot_Linear is None:
241
  #print(F"new Lookkback : {x.shape[1]}")
242
 
@@ -244,7 +236,7 @@ class RLinear(nn.Module):
244
  x = self.revin_layer(x, 'norm')
245
  x = F.linear(x, self.zero_shot_Linear)
246
  x = self.revin_layer(x, 'denorm')
247
- return x
248
 
249
 
250
  if len(x_shape) == 2:
 
98
  super(moving_avg, self).__init__()
99
  self.kernel_size = kernel_size
100
  self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)
101
+
102
+
 
 
 
 
 
 
 
 
103
  def forward(self, x):
104
  # x: [Batch, Input length]
105
  # padding on the both ends of time series
 
228
  def forward(self, x):
229
  # x: [Batch, Input length,Channel]
230
  x_shape = x.shape
231
+ ''''if x.shape[1] < self.seq_len:
232
  #if self.zero_shot_Linear is None:
233
  #print(F"new Lookkback : {x.shape[1]}")
234
 
 
236
  x = self.revin_layer(x, 'norm')
237
  x = F.linear(x, self.zero_shot_Linear)
238
  x = self.revin_layer(x, 'denorm')
239
+ return x'''
240
 
241
 
242
  if len(x_shape) == 2: