Upload modeling_fastesm.py with huggingface_hub
Browse files- modeling_fastesm.py +16 -73
modeling_fastesm.py
CHANGED
|
@@ -29,14 +29,11 @@ except ImportError:
|
|
| 29 |
flex_attention = None
|
| 30 |
|
| 31 |
try:
|
| 32 |
-
# when used from AutoModel, these are in the same directory
|
| 33 |
from .embedding_mixin import EmbeddingMixin
|
| 34 |
-
except:
|
| 35 |
try:
|
| 36 |
-
# whem importing as a submodule, embedding mixin is in the FastPLMs directory
|
| 37 |
from ..embedding_mixin import EmbeddingMixin
|
| 38 |
-
except:
|
| 39 |
-
# when running from our repo, these are in the base directory
|
| 40 |
from embedding_mixin import EmbeddingMixin
|
| 41 |
|
| 42 |
|
|
@@ -236,12 +233,13 @@ class EsmEmbeddings(nn.Module):
|
|
| 236 |
|
| 237 |
def __init__(self, config):
|
| 238 |
super().__init__()
|
| 239 |
-
self.
|
|
|
|
| 240 |
if config.emb_layer_norm_before:
|
| 241 |
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 242 |
else:
|
| 243 |
self.layer_norm = None
|
| 244 |
-
self.position_embedding_type =
|
| 245 |
self.register_buffer(
|
| 246 |
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
| 247 |
)
|
|
@@ -583,11 +581,6 @@ class FastEsmPreTrainedModel(PreTrainedModel):
|
|
| 583 |
module.bias.data.zero_()
|
| 584 |
module.weight.data.fill_(1.0)
|
| 585 |
|
| 586 |
-
def get_input_embeddings(self) -> nn.Module:
|
| 587 |
-
try:
|
| 588 |
-
return self.embeddings.word_embeddings
|
| 589 |
-
except AttributeError:
|
| 590 |
-
return self.esm.embeddings.word_embeddings
|
| 591 |
|
| 592 |
|
| 593 |
class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
|
|
@@ -727,14 +720,13 @@ class FastEsmModel(FastEsmPreTrainedModel, EmbeddingMixin):
|
|
| 727 |
self.config = config
|
| 728 |
self.esm = FAST_ESM_ENCODER(config)
|
| 729 |
self.pooler = EsmPooler(config) if add_pooling_layer else None
|
| 730 |
-
# Initialize weights and apply final processing
|
| 731 |
self.post_init()
|
| 732 |
|
| 733 |
def get_input_embeddings(self):
|
| 734 |
-
return self.embeddings.word_embeddings
|
| 735 |
|
| 736 |
def set_input_embeddings(self, value):
|
| 737 |
-
self.embeddings.word_embeddings = value
|
| 738 |
|
| 739 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 740 |
return self.esm._embed(input_ids, attention_mask)
|
|
@@ -806,6 +798,9 @@ class FastEsmForMaskedLM(FastEsmPreTrainedModel, EmbeddingMixin):
|
|
| 806 |
self.loss_fct = nn.CrossEntropyLoss()
|
| 807 |
self.init_weights()
|
| 808 |
|
|
|
|
|
|
|
|
|
|
| 809 |
def get_output_embeddings(self):
|
| 810 |
return self.lm_head.decoder
|
| 811 |
|
|
@@ -867,6 +862,9 @@ class FastEsmForSequenceClassification(FastEsmPreTrainedModel, EmbeddingMixin):
|
|
| 867 |
self.bce = nn.BCEWithLogitsLoss()
|
| 868 |
self.init_weights()
|
| 869 |
|
|
|
|
|
|
|
|
|
|
| 870 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 871 |
return self.esm._embed(input_ids, attention_mask)
|
| 872 |
|
|
@@ -935,6 +933,9 @@ class FastEsmForTokenClassification(FastEsmPreTrainedModel, EmbeddingMixin):
|
|
| 935 |
self.loss_fct = nn.CrossEntropyLoss()
|
| 936 |
self.init_weights()
|
| 937 |
|
|
|
|
|
|
|
|
|
|
| 938 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 939 |
return self.esm._embed(input_ids, attention_mask)
|
| 940 |
|
|
@@ -978,61 +979,3 @@ class FastEsmForTokenClassification(FastEsmPreTrainedModel, EmbeddingMixin):
|
|
| 978 |
)
|
| 979 |
|
| 980 |
|
| 981 |
-
if __name__ == "__main__":
|
| 982 |
-
"""
|
| 983 |
-
Test the hidden state differences between the FastEsmModel and the HF EsmModel.
|
| 984 |
-
In full precision, the differences are very very small, but nonzero due to floating point issues with F.scaled_dot_product_attention.
|
| 985 |
-
In Pytorch 2.5+ (and linux kernel), this implementation is very fast and uses less memory than the HF implementation.
|
| 986 |
-
"""
|
| 987 |
-
import random
|
| 988 |
-
from transformers import EsmForMaskedLM as TransformersEsmModel, EsmTokenizer
|
| 989 |
-
|
| 990 |
-
model_paths = [
|
| 991 |
-
"facebook/esm2_t6_8M_UR50D",
|
| 992 |
-
"facebook/esm2_t12_35M_UR50D",
|
| 993 |
-
#"facebook/esm2_t30_150M_UR50D",
|
| 994 |
-
#"facebook/esm2_t33_650M_UR50D",
|
| 995 |
-
]
|
| 996 |
-
canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"
|
| 997 |
-
length = 64
|
| 998 |
-
seq_count = 100
|
| 999 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 1000 |
-
tolerances = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]
|
| 1001 |
-
|
| 1002 |
-
def generate_random_sequence(length: int) -> str:
|
| 1003 |
-
return 'M' + "".join(random.choices(canonical_amino_acids, k=length))
|
| 1004 |
-
|
| 1005 |
-
print("Percentage of hidden states that are within the tolerance:")
|
| 1006 |
-
for model_path in model_paths:
|
| 1007 |
-
print(f"Testing {model_path}...")
|
| 1008 |
-
tokenizer = EsmTokenizer.from_pretrained(model_path)
|
| 1009 |
-
config = FastEsmConfig.from_pretrained(model_path)
|
| 1010 |
-
fast_model = FastEsmForMaskedLM(config).from_pretrained(model_path).to(device)
|
| 1011 |
-
print('fast model')
|
| 1012 |
-
print(fast_model)
|
| 1013 |
-
model = TransformersEsmModel.from_pretrained(model_path, token_dropout=False).to(device)
|
| 1014 |
-
print('transformers model')
|
| 1015 |
-
print(model)
|
| 1016 |
-
|
| 1017 |
-
counts = [0] * len(tolerances)
|
| 1018 |
-
for _ in range(seq_count):
|
| 1019 |
-
example_seq = generate_random_sequence(length)
|
| 1020 |
-
fast_tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
|
| 1021 |
-
fast_output = fast_model(fast_tokens, output_hidden_states=True).hidden_states[-1].detach().cpu()
|
| 1022 |
-
|
| 1023 |
-
model_tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
|
| 1024 |
-
model_output = model(model_tokens, output_hidden_states=True).hidden_states[-1].detach().cpu()
|
| 1025 |
-
|
| 1026 |
-
for i, atol in enumerate(tolerances):
|
| 1027 |
-
if torch.allclose(fast_output, model_output, atol=atol):
|
| 1028 |
-
counts[i] += 1
|
| 1029 |
-
|
| 1030 |
-
print(f"{model_path}:")
|
| 1031 |
-
for i, atol in enumerate(tolerances):
|
| 1032 |
-
print(f" tolerance={atol}: {counts[i] / seq_count * 100}%")
|
| 1033 |
-
|
| 1034 |
-
model.cpu()
|
| 1035 |
-
fast_model.cpu()
|
| 1036 |
-
del model
|
| 1037 |
-
del fast_model
|
| 1038 |
-
torch.cuda.empty_cache()
|
|
|
|
| 29 |
flex_attention = None
|
| 30 |
|
| 31 |
try:
|
|
|
|
| 32 |
from .embedding_mixin import EmbeddingMixin
|
| 33 |
+
except ImportError:
|
| 34 |
try:
|
|
|
|
| 35 |
from ..embedding_mixin import EmbeddingMixin
|
| 36 |
+
except ImportError:
|
|
|
|
| 37 |
from embedding_mixin import EmbeddingMixin
|
| 38 |
|
| 39 |
|
|
|
|
| 233 |
|
| 234 |
def __init__(self, config):
|
| 235 |
super().__init__()
|
| 236 |
+
self.padding_idx = config.pad_token_id
|
| 237 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx)
|
| 238 |
if config.emb_layer_norm_before:
|
| 239 |
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 240 |
else:
|
| 241 |
self.layer_norm = None
|
| 242 |
+
self.position_embedding_type = config.position_embedding_type
|
| 243 |
self.register_buffer(
|
| 244 |
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
| 245 |
)
|
|
|
|
| 581 |
module.bias.data.zero_()
|
| 582 |
module.weight.data.fill_(1.0)
|
| 583 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 584 |
|
| 585 |
|
| 586 |
class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
|
|
|
|
| 720 |
self.config = config
|
| 721 |
self.esm = FAST_ESM_ENCODER(config)
|
| 722 |
self.pooler = EsmPooler(config) if add_pooling_layer else None
|
|
|
|
| 723 |
self.post_init()
|
| 724 |
|
| 725 |
def get_input_embeddings(self):
|
| 726 |
+
return self.esm.embeddings.word_embeddings
|
| 727 |
|
| 728 |
def set_input_embeddings(self, value):
|
| 729 |
+
self.esm.embeddings.word_embeddings = value
|
| 730 |
|
| 731 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 732 |
return self.esm._embed(input_ids, attention_mask)
|
|
|
|
| 798 |
self.loss_fct = nn.CrossEntropyLoss()
|
| 799 |
self.init_weights()
|
| 800 |
|
| 801 |
+
def get_input_embeddings(self):
|
| 802 |
+
return self.esm.embeddings.word_embeddings
|
| 803 |
+
|
| 804 |
def get_output_embeddings(self):
|
| 805 |
return self.lm_head.decoder
|
| 806 |
|
|
|
|
| 862 |
self.bce = nn.BCEWithLogitsLoss()
|
| 863 |
self.init_weights()
|
| 864 |
|
| 865 |
+
def get_input_embeddings(self):
|
| 866 |
+
return self.esm.embeddings.word_embeddings
|
| 867 |
+
|
| 868 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 869 |
return self.esm._embed(input_ids, attention_mask)
|
| 870 |
|
|
|
|
| 933 |
self.loss_fct = nn.CrossEntropyLoss()
|
| 934 |
self.init_weights()
|
| 935 |
|
| 936 |
+
def get_input_embeddings(self):
|
| 937 |
+
return self.esm.embeddings.word_embeddings
|
| 938 |
+
|
| 939 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 940 |
return self.esm._embed(input_ids, attention_mask)
|
| 941 |
|
|
|
|
| 979 |
)
|
| 980 |
|
| 981 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|