Update tensormapper.py
Browse files- tensormapper.py +8 -13
tensormapper.py
CHANGED
|
@@ -7,17 +7,17 @@ class VecDyT(nn.Module):
|
|
| 7 |
def __init__(self, input_shape):
|
| 8 |
|
| 9 |
super().__init__()
|
| 10 |
-
|
| 11 |
-
|
| 12 |
self.alpha = nn.Parameter(torch.randn(input_shape))
|
| 13 |
|
| 14 |
-
|
| 15 |
def forward(self, x):
|
| 16 |
x = torch.tanh(self.alpha * x)
|
| 17 |
return x
|
| 18 |
|
|
|
|
| 19 |
class GatingUnit(nn.Module):
|
| 20 |
def __init__(self,dim):
|
|
|
|
| 21 |
super().__init__()
|
| 22 |
|
| 23 |
self.proj_1 = nn.Linear(dim,dim,bias=False)
|
|
@@ -36,21 +36,16 @@ class GatingUnit(nn.Module):
|
|
| 36 |
|
| 37 |
return g
|
| 38 |
|
| 39 |
-
class TTT(nn.Module):
|
| 40 |
-
|
| 41 |
-
|
| 42 |
def __init__(self, dim: int):
|
|
|
|
| 43 |
super(TTT, self).__init__()
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
|
| 49 |
self.mapping = nn.Linear(dim,dim,bias=False)
|
| 50 |
self.State = nn.Linear(dim,dim,bias=False)
|
| 51 |
self.Probe = nn.Linear(dim,dim,bias=False)
|
| 52 |
-
|
| 53 |
-
|
| 54 |
|
| 55 |
def forward(self, in_seq: Tensor) -> Tensor:
|
| 56 |
|
|
@@ -78,8 +73,8 @@ class TTT(nn.Module):
|
|
| 78 |
|
| 79 |
|
| 80 |
class TensorMapperBlock(nn.Module):
|
| 81 |
-
|
| 82 |
def __init__(self, dim, num_patch):
|
|
|
|
| 83 |
super().__init__()
|
| 84 |
|
| 85 |
self.norm_1 = VecDyT(dim)
|
|
|
|
| 7 |
def __init__(self, input_shape):
|
| 8 |
|
| 9 |
super().__init__()
|
| 10 |
+
|
|
|
|
| 11 |
self.alpha = nn.Parameter(torch.randn(input_shape))
|
| 12 |
|
|
|
|
| 13 |
def forward(self, x):
|
| 14 |
x = torch.tanh(self.alpha * x)
|
| 15 |
return x
|
| 16 |
|
| 17 |
+
|
| 18 |
class GatingUnit(nn.Module):
|
| 19 |
def __init__(self,dim):
|
| 20 |
+
|
| 21 |
super().__init__()
|
| 22 |
|
| 23 |
self.proj_1 = nn.Linear(dim,dim,bias=False)
|
|
|
|
| 36 |
|
| 37 |
return g
|
| 38 |
|
| 39 |
+
class TTT(nn.Module):
|
|
|
|
|
|
|
| 40 |
def __init__(self, dim: int):
|
| 41 |
+
|
| 42 |
super(TTT, self).__init__()
|
| 43 |
+
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
self.mapping = nn.Linear(dim,dim,bias=False)
|
| 46 |
self.State = nn.Linear(dim,dim,bias=False)
|
| 47 |
self.Probe = nn.Linear(dim,dim,bias=False)
|
| 48 |
+
|
|
|
|
| 49 |
|
| 50 |
def forward(self, in_seq: Tensor) -> Tensor:
|
| 51 |
|
|
|
|
| 73 |
|
| 74 |
|
| 75 |
class TensorMapperBlock(nn.Module):
|
|
|
|
| 76 |
def __init__(self, dim, num_patch):
|
| 77 |
+
|
| 78 |
super().__init__()
|
| 79 |
|
| 80 |
self.norm_1 = VecDyT(dim)
|