Upload modeling_esm_plusplus.py with huggingface_hub
modeling_esm_plusplus.py  CHANGED  (+26 -2)
@@ -886,7 +886,7 @@ class TransformerStack(nn.Module):
         attn_backend: str = "sdpa",
     ):
         super().__init__()
-        self.
+        self._attn_backend = attn_backend
         self.blocks = nn.ModuleList(
             [
                 UnifiedTransformerBlock(
@@ -901,6 +901,18 @@ class TransformerStack(nn.Module):
         )
         self.norm = nn.LayerNorm(d_model, bias=False)
         self.gradient_checkpointing = False
+        self.attn_backend = attn_backend
+
+    @property
+    def attn_backend(self) -> str:
+        return self._attn_backend
+
+    @attn_backend.setter
+    def attn_backend(self, backend: str) -> None:
+        assert backend in ("sdpa", "flex"), f"Unsupported attn_backend: {backend}"
+        self._attn_backend = backend
+        for block in self.blocks:
+            block.attn.attn_backend = backend
 
     def forward(
         self,
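With this property in place, the backend can be switched on an existing stack: the setter validates the value and pushes it into every block's attention module. A minimal usage sketch; the constructor sizes below are placeholders, not the model's real configuration:

    # Placeholder sizes for illustration only; attn_backend, blocks, and block.attn come from the diff above.
    stack = TransformerStack(d_model=960, n_heads=15, n_layers=30, attn_backend="sdpa")

    stack.attn_backend = "flex"  # setter validates, then propagates to each block.attn.attn_backend
    assert all(block.attn.attn_backend == "flex" for block in stack.blocks)

    # Any value outside ("sdpa", "flex") trips the assert in the setter:
    # stack.attn_backend = "flash"  ->  AssertionError: Unsupported attn_backend: flash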
@@ -924,7 +936,7 @@ class TransformerStack(nn.Module):
 
         # move to 4D attention mask or flex block mask
         attention_mask, flex_block_mask = get_attention_mask(
-            attn_backend=self.
+            attn_backend=self._attn_backend,
             batch_size=x.shape[0],
             seq_len=x.shape[1],
             device=x.device,
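The get_attention_mask helper itself is not shown in this diff. The sketch below illustrates the dispatch the comment describes (a 4D boolean mask for SDPA versus a flex-attention BlockMask), assuming a [batch, seq_len] boolean padding mask and PyTorch 2.5 or newer; the real helper in this file may differ:

    from torch.nn.attention.flex_attention import create_block_mask

    def get_attention_mask_sketch(attn_backend, padding_mask, batch_size, seq_len, device):
        """Hypothetical stand-in for get_attention_mask; signature and body are assumptions."""
        if attn_backend == "sdpa":
            # Boolean [B, 1, Lq, Lk] mask, broadcast over heads; True means "attend".
            mask_4d = padding_mask[:, None, None, :] & padding_mask[:, None, :, None]
            return mask_4d, None
        # "flex": build a BlockMask that drops padded key positions.
        def mask_mod(b, h, q_idx, kv_idx):
            return padding_mask[b, kv_idx]
        block_mask = create_block_mask(mask_mod, batch_size, None, seq_len, seq_len, device=device)
        return None, block_mask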
@@ -997,6 +1009,18 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
             nn.init.zeros_(module.bias)
             nn.init.ones_(module.weight)
 
+    @property
+    def attn_backend(self) -> str:
+        return self.config.attn_backend
+
+    @attn_backend.setter
+    def attn_backend(self, backend: str) -> None:
+        assert backend in ("sdpa", "flex"), f"Unsupported attn_backend: {backend}"
+        self.config.attn_backend = backend
+        for module in self.modules():
+            if isinstance(module, TransformerStack):
+                module.attn_backend = backend
+
     def _reset_rotary_embeddings(self):
         """Refresh non-persistent rotary buffers after checkpoint loading."""
         for module in self.modules():
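At the model level, the property round-trips through the config, so one assignment on a loaded checkpoint updates the config and every TransformerStack it contains. A usage sketch; the repository id is only an example:

    from transformers import AutoModel

    # Any ESM++ checkpoint that ships this modeling file; the repo id here is illustrative.
    model = AutoModel.from_pretrained("Synthyra/ESMplusplus_small", trust_remote_code=True)

    print(model.attn_backend)    # reads config.attn_backend, e.g. "sdpa"
    model.attn_backend = "flex"  # validates, writes the config, and propagates to all TransformerStack modules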