Update rmsnorm.py
Browse files
Modifies the forward pass of RMSNorm to avoid mixed-precision issues, as described in https://github.com/chandar-lab/AMPLIFY/issues/19
- rmsnorm.py +5 -7
rmsnorm.py
CHANGED
|
@@ -6,29 +6,27 @@ class RMSNorm(nn.Module):
|
|
| 6 |
def __init__(self, dim: int, eps: float = 1e-6):
    """
    Set up the RMSNorm layer.

    Args:
        dim (int): Length of the learnable scale vector (the original
            documentation describes it as the dimension of the input tensor).
        eps (float, optional): Small constant added to the denominator for
            numerical stability. Default is 1e-6.

    Attributes:
        eps (float): The stability constant passed in as ``eps``.
        weight (nn.Parameter): Learnable scale, initialised to ones of
            length ``dim``.
    """
    super().__init__()
    # Start from an all-ones scale so the layer initially applies no rescaling
    # beyond the normalization itself.
    initial_scale = torch.ones(dim)
    self.weight = nn.Parameter(initial_scale)
    self.eps = eps
|
| 22 |
|
|
|
|
|
|
|
|
|
|
| 23 |
def forward(self, x):
|
| 24 |
"""
|
| 25 |
Forward pass through the RMSNorm layer.
|
| 26 |
-
|
| 27 |
Args:
|
| 28 |
x (torch.Tensor): The input tensor.
|
| 29 |
-
|
| 30 |
Returns:
|
| 31 |
torch.Tensor: The output tensor after applying RMSNorm.
|
| 32 |
-
|
| 33 |
"""
|
| 34 |
-
|
|
|
|
|
|
| 6 |
def __init__(self, dim: int, eps: float = 1e-6):
    """
    Set up the RMSNorm layer.

    Args:
        dim (int): Length of the learnable scale vector (the original
            documentation describes it as the dimension of the input tensor).
        eps (float, optional): Small constant added to the denominator for
            numerical stability. Default is 1e-6.

    Attributes:
        eps (float): The stability constant passed in as ``eps``.
        weight (nn.Parameter): Learnable scale, initialised to ones of
            length ``dim``.
    """
    super().__init__()
    # Start from an all-ones scale so the layer initially applies no rescaling
    # beyond the normalization itself.
    initial_scale = torch.ones(dim)
    self.weight = nn.Parameter(initial_scale)
    self.eps = eps
|
| 19 |
|
| 20 |
+
def _norm(self, x):
    """
    Divide ``x`` by the root mean square taken over its last dimension.

    Args:
        x (torch.Tensor): The tensor to normalize.

    Returns:
        torch.Tensor: ``x`` multiplied by ``rsqrt(mean(x**2, dim=-1) + eps)``.
    """
    # keepdim=True keeps the reduced axis so the factor broadcasts against x.
    mean_square = x.pow(2).mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + self.eps)
    return x * inv_rms
|
| 22 |
+
|
| 23 |
def forward(self, x):
    """
    Apply RMSNorm to ``x`` and scale the result by the learnable weight.

    Args:
        x (torch.Tensor): The input tensor.

    Returns:
        torch.Tensor: The normalized input multiplied by ``self.weight``.
    """
    # Normalize in float32 and cast back to the input's type afterwards; this
    # avoids the mixed precision issues described in
    # https://github.com/chandar-lab/AMPLIFY/issues/19
    upcast = x.float()
    normalized = self._norm(upcast).type_as(x)
    # NOTE: the scale is applied after the cast back, so the final dtype is
    # subject to PyTorch's usual type promotion with ``self.weight``.
    return normalized * self.weight
|