Update chatNT.py
Browse files
chatNT.py
CHANGED
|
@@ -736,9 +736,9 @@ class TorchRotaryEmbedding(torch.nn.Module):
|
|
| 736 |
self.max_seq_len = config.max_seq_len
|
| 737 |
self.dim = config.dim
|
| 738 |
self.theta = config.theta
|
| 739 |
-
self.sincos_cache =
|
| 740 |
|
| 741 |
-
def _create_sinusoidal_positions(self) -> torch.Tensor:
|
| 742 |
"""
|
| 743 |
Create the sines and cosines for the RoPE.
|
| 744 |
|
|
@@ -747,19 +747,19 @@ class TorchRotaryEmbedding(torch.nn.Module):
|
|
| 747 |
"""
|
| 748 |
# Create the inverse frequency based on theta and dim
|
| 749 |
inv_freq = 1.0 / (
|
| 750 |
-
self.theta ** (torch.arange(0, self.dim, 2).float() / self.dim)
|
| 751 |
)
|
| 752 |
|
| 753 |
# Compute sinusoidal input using the broadcasting
|
| 754 |
sinusoid_inp = torch.einsum(
|
| 755 |
-
"i,j->ij", torch.arange(self.max_seq_len).float(), inv_freq
|
| 756 |
)
|
| 757 |
|
| 758 |
# Apply sin and cos to the sinusoidal input
|
| 759 |
sin, cos = sinusoid_inp.sin(), sinusoid_inp.cos()
|
| 760 |
|
| 761 |
# Allocate a tensor for the final sin-cos values
|
| 762 |
-
sincos = torch.zeros((self.max_seq_len, self.dim), dtype=torch.float32)
|
| 763 |
|
| 764 |
# Fill the sincos tensor with sin and cos values
|
| 765 |
sentinel = self.dim // 2 + self.dim % 2
|
|
@@ -824,6 +824,10 @@ class TorchRotaryEmbedding(torch.nn.Module):
|
|
| 824 |
Returns:
|
| 825 |
RoPE embeddings for the keys and values.
|
| 826 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 827 |
batch_size, seq_len, num_heads, head_dim = k.shape
|
| 828 |
|
| 829 |
# Generate position ids
|
|
|
|
| 736 |
self.max_seq_len = config.max_seq_len
|
| 737 |
self.dim = config.dim
|
| 738 |
self.theta = config.theta
|
| 739 |
+
self.sincos_cache = None
|
| 740 |
|
| 741 |
+
def _create_sinusoidal_positions(self, device: torch.device) -> torch.Tensor:
|
| 742 |
"""
|
| 743 |
Create the sines and cosines for the RoPE.
|
| 744 |
|
|
|
|
| 747 |
"""
|
| 748 |
# Create the inverse frequency based on theta and dim
|
| 749 |
inv_freq = 1.0 / (
|
| 750 |
+
self.theta ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim)
|
| 751 |
)
|
| 752 |
|
| 753 |
# Compute sinusoidal input using the broadcasting
|
| 754 |
sinusoid_inp = torch.einsum(
|
| 755 |
+
"i,j->ij", torch.arange(self.max_seq_len, device=device).float(), inv_freq
|
| 756 |
)
|
| 757 |
|
| 758 |
# Apply sin and cos to the sinusoidal input
|
| 759 |
sin, cos = sinusoid_inp.sin(), sinusoid_inp.cos()
|
| 760 |
|
| 761 |
# Allocate a tensor for the final sin-cos values
|
| 762 |
+
sincos = torch.zeros((self.max_seq_len, self.dim), dtype=torch.float32, device=device)
|
| 763 |
|
| 764 |
# Fill the sincos tensor with sin and cos values
|
| 765 |
sentinel = self.dim // 2 + self.dim % 2
|
|
|
|
| 824 |
Returns:
|
| 825 |
RoPE embeddings for the keys and values.
|
| 826 |
"""
|
| 827 |
+
if self.sincos_cache is None:
|
| 828 |
+
device = k.device
|
| 829 |
+
self.sincos_cache = self._create_sinusoidal_positions(device=device)
|
| 830 |
+
|
| 831 |
batch_size, seq_len, num_heads, head_dim = k.shape
|
| 832 |
|
| 833 |
# Generate position ids
|