Update rmsnorm.py
Browse filesModifies the forward pass of RMSNorm to avoid mixed precision issues as described in https://github.com/chandar-lab/AMPLIFY/issues/19
- rmsnorm.py +5 -1
rmsnorm.py
CHANGED
|
@@ -20,6 +20,9 @@ class RMSNorm(nn.Module):
|
|
| 20 |
self.eps = eps
|
| 21 |
self.weight = nn.Parameter(torch.ones(dim))
|
| 22 |
|
|
|
|
|
|
|
|
|
|
| 23 |
def forward(self, x):
|
| 24 |
"""
|
| 25 |
Forward pass through the RMSNorm layer.
|
|
@@ -31,4 +34,5 @@ class RMSNorm(nn.Module):
|
|
| 31 |
torch.Tensor: The output tensor after applying RMSNorm.
|
| 32 |
|
| 33 |
"""
|
| 34 |
-
|
|
|
|
|
|
| 20 |
self.eps = eps
|
| 21 |
self.weight = nn.Parameter(torch.ones(dim))
|
| 22 |
|
| 23 |
+
def _norm(self, x):
|
| 24 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
| 25 |
+
|
| 26 |
def forward(self, x):
|
| 27 |
"""
|
| 28 |
Forward pass through the RMSNorm layer.
|
|
|
|
| 34 |
torch.Tensor: The output tensor after applying RMSNorm.
|
| 35 |
|
| 36 |
"""
|
| 37 |
+
output = self._norm(x.float()).type_as(x) # Avoids mixed precision issues as in https://github.com/chandar-lab/AMPLIFY/issues/19
|
| 38 |
+
return output * self.weight
|