Abdullah-Nazhat committed on
Commit
ea63f8b
·
verified ·
1 Parent(s): ae3c6b6

Update tensormapper.py

Browse files
Files changed (1) hide show
  1. tensormapper.py +8 -13
tensormapper.py CHANGED
class VecDyT(nn.Module):
    """Dynamic-tanh (DyT) style normalization with a learnable element-wise scale.

    Computes ``y = tanh(alpha * x)`` where ``alpha`` is a learnable tensor of
    shape ``input_shape``, initialized from a standard normal distribution.
    """

    def __init__(self, input_shape):
        """
        Args:
            input_shape: shape of the learnable scale ``alpha``; it must
                broadcast against the inputs passed to ``forward``.
                # assumes callers pass the feature dim here — TODO confirm
        """
        super().__init__()
        # NOTE(review): randn init (not ones) — presumably intentional; verify.
        self.alpha = nn.Parameter(torch.randn(input_shape))

    def forward(self, x):
        """Return ``tanh(alpha * x)``; output shape follows broadcasting."""
        return torch.tanh(self.alpha * x)
18
 
 
19
  class GatingUnit(nn.Module):
20
  def __init__(self,dim):
 
21
  super().__init__()
22
 
23
  self.proj_1 = nn.Linear(dim,dim,bias=False)
@@ -36,21 +36,16 @@ class GatingUnit(nn.Module):
36
 
37
  return g
38
 
39
- class TTT(nn.Module):
40
-
41
-
42
  def __init__(self, dim: int):
 
43
  super(TTT, self).__init__()
44
-
45
-
46
-
47
-
48
 
49
  self.mapping = nn.Linear(dim,dim,bias=False)
50
  self.State = nn.Linear(dim,dim,bias=False)
51
  self.Probe = nn.Linear(dim,dim,bias=False)
52
-
53
-
54
 
55
  def forward(self, in_seq: Tensor) -> Tensor:
56
 
@@ -78,8 +73,8 @@ class TTT(nn.Module):
78
 
79
 
80
  class TensorMapperBlock(nn.Module):
81
-
82
  def __init__(self, dim, num_patch):
 
83
  super().__init__()
84
 
85
  self.norm_1 = VecDyT(dim)
 
7
  def __init__(self, input_shape):
8
 
9
  super().__init__()
10
+
 
11
  self.alpha = nn.Parameter(torch.randn(input_shape))
12
 
 
13
  def forward(self, x):
14
  x = torch.tanh(self.alpha * x)
15
  return x
16
 
17
+
18
  class GatingUnit(nn.Module):
19
  def __init__(self,dim):
20
+
21
  super().__init__()
22
 
23
  self.proj_1 = nn.Linear(dim,dim,bias=False)
 
36
 
37
  return g
38
 
39
+ class TTT(nn.Module):
 
 
40
  def __init__(self, dim: int):
41
+
42
  super(TTT, self).__init__()
43
+
 
 
 
44
 
45
  self.mapping = nn.Linear(dim,dim,bias=False)
46
  self.State = nn.Linear(dim,dim,bias=False)
47
  self.Probe = nn.Linear(dim,dim,bias=False)
48
+
 
49
 
50
  def forward(self, in_seq: Tensor) -> Tensor:
51
 
 
73
 
74
 
75
  class TensorMapperBlock(nn.Module):
 
76
  def __init__(self, dim, num_patch):
77
+
78
  super().__init__()
79
 
80
  self.norm_1 = VecDyT(dim)