lhallee committed on
Commit
cc05fd6
·
verified ·
1 Parent(s): 1dfe9a7

Upload modeling_fastesm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_fastesm.py +16 -73
modeling_fastesm.py CHANGED
@@ -29,14 +29,11 @@ except ImportError:
29
  flex_attention = None
30
 
31
  try:
32
- # when used from AutoModel, these are in the same directory
33
  from .embedding_mixin import EmbeddingMixin
34
- except:
35
  try:
36
- # when importing as a submodule, embedding mixin is in the FastPLMs directory
37
  from ..embedding_mixin import EmbeddingMixin
38
- except:
39
- # when running from our repo, these are in the base directory
40
  from embedding_mixin import EmbeddingMixin
41
 
42
 
@@ -236,12 +233,13 @@ class EsmEmbeddings(nn.Module):
236
 
237
  def __init__(self, config):
238
  super().__init__()
239
- self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
 
240
  if config.emb_layer_norm_before:
241
  self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
242
  else:
243
  self.layer_norm = None
244
- self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
245
  self.register_buffer(
246
  "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
247
  )
@@ -583,11 +581,6 @@ class FastEsmPreTrainedModel(PreTrainedModel):
583
  module.bias.data.zero_()
584
  module.weight.data.fill_(1.0)
585
 
586
- def get_input_embeddings(self) -> nn.Module:
587
- try:
588
- return self.embeddings.word_embeddings
589
- except AttributeError:
590
- return self.esm.embeddings.word_embeddings
591
 
592
 
593
  class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
@@ -727,14 +720,13 @@ class FastEsmModel(FastEsmPreTrainedModel, EmbeddingMixin):
727
  self.config = config
728
  self.esm = FAST_ESM_ENCODER(config)
729
  self.pooler = EsmPooler(config) if add_pooling_layer else None
730
- # Initialize weights and apply final processing
731
  self.post_init()
732
 
733
  def get_input_embeddings(self):
734
- return self.embeddings.word_embeddings
735
 
736
  def set_input_embeddings(self, value):
737
- self.embeddings.word_embeddings = value
738
 
739
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
740
  return self.esm._embed(input_ids, attention_mask)
@@ -806,6 +798,9 @@ class FastEsmForMaskedLM(FastEsmPreTrainedModel, EmbeddingMixin):
806
  self.loss_fct = nn.CrossEntropyLoss()
807
  self.init_weights()
808
 
 
 
 
809
  def get_output_embeddings(self):
810
  return self.lm_head.decoder
811
 
@@ -867,6 +862,9 @@ class FastEsmForSequenceClassification(FastEsmPreTrainedModel, EmbeddingMixin):
867
  self.bce = nn.BCEWithLogitsLoss()
868
  self.init_weights()
869
 
 
 
 
870
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
871
  return self.esm._embed(input_ids, attention_mask)
872
 
@@ -935,6 +933,9 @@ class FastEsmForTokenClassification(FastEsmPreTrainedModel, EmbeddingMixin):
935
  self.loss_fct = nn.CrossEntropyLoss()
936
  self.init_weights()
937
 
 
 
 
938
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
939
  return self.esm._embed(input_ids, attention_mask)
940
 
@@ -978,61 +979,3 @@ class FastEsmForTokenClassification(FastEsmPreTrainedModel, EmbeddingMixin):
978
  )
979
 
980
 
981
- if __name__ == "__main__":
982
- """
983
- Test the hidden state differences between the FastEsmModel and the HF EsmModel.
984
- In full precision, the differences are very very small, but nonzero due to floating point issues with F.scaled_dot_product_attention.
985
- In Pytorch 2.5+ (and linux kernel), this implementation is very fast and uses less memory than the HF implementation.
986
- """
987
- import random
988
- from transformers import EsmForMaskedLM as TransformersEsmModel, EsmTokenizer
989
-
990
- model_paths = [
991
- "facebook/esm2_t6_8M_UR50D",
992
- "facebook/esm2_t12_35M_UR50D",
993
- #"facebook/esm2_t30_150M_UR50D",
994
- #"facebook/esm2_t33_650M_UR50D",
995
- ]
996
- canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"
997
- length = 64
998
- seq_count = 100
999
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1000
- tolerances = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]
1001
-
1002
- def generate_random_sequence(length: int) -> str:
1003
- return 'M' + "".join(random.choices(canonical_amino_acids, k=length))
1004
-
1005
- print("Percentage of hidden states that are within the tolerance:")
1006
- for model_path in model_paths:
1007
- print(f"Testing {model_path}...")
1008
- tokenizer = EsmTokenizer.from_pretrained(model_path)
1009
- config = FastEsmConfig.from_pretrained(model_path)
1010
- fast_model = FastEsmForMaskedLM(config).from_pretrained(model_path).to(device)
1011
- print('fast model')
1012
- print(fast_model)
1013
- model = TransformersEsmModel.from_pretrained(model_path, token_dropout=False).to(device)
1014
- print('transformers model')
1015
- print(model)
1016
-
1017
- counts = [0] * len(tolerances)
1018
- for _ in range(seq_count):
1019
- example_seq = generate_random_sequence(length)
1020
- fast_tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
1021
- fast_output = fast_model(fast_tokens, output_hidden_states=True).hidden_states[-1].detach().cpu()
1022
-
1023
- model_tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
1024
- model_output = model(model_tokens, output_hidden_states=True).hidden_states[-1].detach().cpu()
1025
-
1026
- for i, atol in enumerate(tolerances):
1027
- if torch.allclose(fast_output, model_output, atol=atol):
1028
- counts[i] += 1
1029
-
1030
- print(f"{model_path}:")
1031
- for i, atol in enumerate(tolerances):
1032
- print(f" tolerance={atol}: {counts[i] / seq_count * 100}%")
1033
-
1034
- model.cpu()
1035
- fast_model.cpu()
1036
- del model
1037
- del fast_model
1038
- torch.cuda.empty_cache()
 
29
  flex_attention = None
30
 
31
  try:
 
32
  from .embedding_mixin import EmbeddingMixin
33
+ except ImportError:
34
  try:
 
35
  from ..embedding_mixin import EmbeddingMixin
36
+ except ImportError:
 
37
  from embedding_mixin import EmbeddingMixin
38
 
39
 
 
233
 
234
  def __init__(self, config):
235
  super().__init__()
236
+ self.padding_idx = config.pad_token_id
237
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx)
238
  if config.emb_layer_norm_before:
239
  self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
240
  else:
241
  self.layer_norm = None
242
+ self.position_embedding_type = config.position_embedding_type
243
  self.register_buffer(
244
  "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
245
  )
 
581
  module.bias.data.zero_()
582
  module.weight.data.fill_(1.0)
583
 
 
 
 
 
 
584
 
585
 
586
  class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
 
720
  self.config = config
721
  self.esm = FAST_ESM_ENCODER(config)
722
  self.pooler = EsmPooler(config) if add_pooling_layer else None
 
723
  self.post_init()
724
 
725
  def get_input_embeddings(self):
726
+ return self.esm.embeddings.word_embeddings
727
 
728
  def set_input_embeddings(self, value):
729
+ self.esm.embeddings.word_embeddings = value
730
 
731
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
732
  return self.esm._embed(input_ids, attention_mask)
 
798
  self.loss_fct = nn.CrossEntropyLoss()
799
  self.init_weights()
800
 
801
+ def get_input_embeddings(self):
802
+ return self.esm.embeddings.word_embeddings
803
+
804
  def get_output_embeddings(self):
805
  return self.lm_head.decoder
806
 
 
862
  self.bce = nn.BCEWithLogitsLoss()
863
  self.init_weights()
864
 
865
+ def get_input_embeddings(self):
866
+ return self.esm.embeddings.word_embeddings
867
+
868
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
869
  return self.esm._embed(input_ids, attention_mask)
870
 
 
933
  self.loss_fct = nn.CrossEntropyLoss()
934
  self.init_weights()
935
 
936
+ def get_input_embeddings(self):
937
+ return self.esm.embeddings.word_embeddings
938
+
939
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
940
  return self.esm._embed(input_ids, attention_mask)
941
 
 
979
  )
980
 
981