Update chatNT.py
chatNT.py CHANGED

@@ -975,7 +975,7 @@ class TorchGptDecoder(nn.Module):
         self, embeddings: torch.Tensor, attention_mask: torch.Tensor = None
     ) -> torch.Tensor:
         if attention_mask is None:
-            attention_mask = build_causal_attention_mask(1, embeddings.shape[1])
+            attention_mask = build_causal_attention_mask(1, embeddings.shape[1], device=embeddings.device)
         for layer in self.layers:
             embeddings = layer(embeddings, attention_mask)
 
@@ -985,7 +985,7 @@ class TorchGptDecoder(nn.Module):
         self, token_ids: torch.Tensor, attention_mask: torch.Tensor = None
     ) -> dict[str, torch.Tensor]:
         if attention_mask is None:
-            attention_mask = build_causal_attention_mask(1, token_ids.shape[1])
+            attention_mask = build_causal_attention_mask(1, token_ids.shape[1], device=token_ids.device)
 
         tokens_embeddings = self.token_embed(token_ids)
 
@@ -1127,7 +1127,7 @@ def get_activation_fn(activation_name: str): # type: ignore
     return activations.get(activation_name, nn.functional.relu)
 
 
-def build_causal_attention_mask(batch_size: int, seq_len: int) -> torch.Tensor:
+def build_causal_attention_mask(batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
     """
     Builds a batch of causal masks of shape (batch_size, 1, seq_len, seq_len) to feed
     to an attention layer.
@@ -1139,7 +1139,7 @@ def build_causal_attention_mask(batch_size: int, seq_len: int) -> torch.Tensor:
     Returns:
         Batch of causal masks.
     """
-    mask = torch.ones((batch_size, 1, seq_len, seq_len))
+    mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device)
     causal_mask = torch.tril(mask)
     return causal_mask
 
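The point of threading a device argument through is that torch.ones allocates on the CPU by default; if the decoder's inputs live on a GPU, the attention layers would then mix CPU and CUDA tensors and raise a device-mismatch error. Below is a minimal standalone sketch of the patched helper and how the decoder's call sites use it, assuming only the updated signature shown in this diff; the tensor shapes and the device selection are illustrative, not part of chatNT.py.

import torch

def build_causal_attention_mask(
    batch_size: int, seq_len: int, device: torch.device
) -> torch.Tensor:
    # Lower-triangular (causal) mask of shape (batch_size, 1, seq_len, seq_len),
    # allocated directly on the requested device instead of the CPU default.
    mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device)
    return torch.tril(mask)

# Illustrative usage: pick the GPU if one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embeddings = torch.randn(1, 16, 64, device=device)  # hypothetical (batch, seq_len, hidden)

# Mirrors the patched call sites: the mask is built on the embeddings' own device,
# so downstream attention ops see matching devices.
mask = build_causal_attention_mask(1, embeddings.shape[1], device=embeddings.device)
assert mask.device == embeddings.device
assert mask.shape == (1, 1, 16, 16)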