lhallee committed on
Commit 1dfe9a7 · verified · 1 Parent(s): a70d778

Upload README.md with huggingface_hub

Files changed (1)
  1. README.md +25 -33
README.md CHANGED
@@ -11,40 +11,22 @@ FastESM is a Huggingface-compatible plug-in version of ESM2 rewritten with a new
 
 Load any ESM2 model into a FastEsm model to dramatically speed up training and inference without **ANY** cost in performance.
 
-## Attention backend defaults
-`sdpa` is the default attention backend for FastESM.
-
-To enable Flex Attention, set `attn_backend="flex"` on the config before model initialization/loading.
-
-For throughput and memory efficiency, `torch.compile(...)` is heavily recommended, especially when using Flex Attention.
-
-Outputting attention maps (or the contact prediction head) is not natively possible with the optimized attention backends (including Flex Attention). You can still pass `output_attentions` to have attention calculated manually and returned.
-Various other optimizations also make the base implementation slightly different from the one in transformers.
 
 ## Use with 🤗 transformers
 
-### Supported models
-```python
-model_dict = {
-    # Synthyra/ESM2-8M
-    'ESM2-8M': 'facebook/esm2_t6_8M_UR50D',
-    # Synthyra/ESM2-35M
-    'ESM2-35M': 'facebook/esm2_t12_35M_UR50D',
-    # Synthyra/ESM2-150M
-    'ESM2-150M': 'facebook/esm2_t30_150M_UR50D',
-    # Synthyra/ESM2-650M
-    'ESM2-650M': 'facebook/esm2_t33_650M_UR50D',
-    # Synthyra/ESM2-3B
-    'ESM2-3B': 'facebook/esm2_t36_3B_UR50D',
-}
-```
-
 ### For working with embeddings
 ```python
 import torch
 from transformers import AutoModel, AutoTokenizer
 
-model_path = 'Synthyra/ESM2-8M'
 model = AutoModel.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).eval()
 tokenizer = model.tokenizer
 
@@ -80,14 +62,6 @@ with torch.no_grad():
 print(attentions[-1].shape) # (2, 20, 11, 11)
 ```
 
-### Contact prediction
-Because we can output attentions using the naive attention implementation, contact prediction is also supported:
-```python
-with torch.no_grad():
-    contact_map = model.predict_contacts(**tokenized).squeeze().cpu().numpy() # (seq_len, seq_len)
-```
-![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/9707OSXZ3Wdgn0Ni-55T-.png)
-
 ## Embed entire datasets with no new code
 To embed a list of protein sequences **fast**, just call embed_dataset. Sequences are sorted to reduce padding tokens, so the initial progress bar estimate is usually much longer than the actual time it will take.
 
@@ -136,6 +110,24 @@ Note:
 - Sequences will be truncated to max_len and sorted by length in descending order for faster processing
 ```
 
 ### Citation
 If you use any of this implementation or work, please cite it (as well as the [ESM2](https://www.science.org/doi/10.1126/science.ade2574) paper).
@@ -148,4 +140,4 @@ If you use any of this implementation or work please cite it (as well as the [ES
 DOI = { 10.57967/hf/3726 },
 publisher = { Hugging Face }
 }
-```
 
 
 Load any ESM2 model into a FastEsm model to dramatically speed up training and inference without **ANY** cost in performance.
 
+Outputting attention maps (or the contact prediction head) is not natively possible with SDPA. You can still pass `output_attentions` to have attention calculated manually and returned.
+Various other optimizations also make the base implementation slightly different from the one in transformers.
 
+# FastESM2-650
 
+## A faster half-precision version of ESM2-650 with FlashAttention2 and longer context
+To enhance the weights with longer context and better fp16 support, we trained ESM2-650 for 50,000 additional steps with a traditional MLM objective (20% masking) in fp16 mixed precision on [OMGprot50](https://huggingface.co/datasets/tattabio/OMG_prot50) up to a sequence length of **2048**.
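The 20% masking step described above can be sketched without the real tokenizer. In this toy illustration the mask token id (32), the vocabulary range, and the handling of special tokens are all invented stand-ins; the actual training code may differ:

```python
import numpy as np

# Minimal sketch of 20% MLM masking. Token id 32 as <mask> is a
# hypothetical stand-in, and special tokens are ignored for brevity.
rng = np.random.default_rng(0)
token_ids = rng.integers(4, 24, size=64)   # fake amino-acid token ids
mask = rng.random(token_ids.shape) < 0.20  # ~20% of positions masked
labels = np.where(mask, token_ids, -100)   # loss computed only on masked positions
inputs = np.where(mask, 32, token_ids)     # masked positions become <mask>

print(mask.mean())  # fraction of positions actually masked, around 0.2
```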
 
 ## Use with 🤗 transformers
 
 ### For working with embeddings
 ```python
 import torch
 from transformers import AutoModel, AutoTokenizer
 
+model_path = 'Synthyra/FastESM2_650'
 model = AutoModel.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).eval()
 tokenizer = model.tokenizer
 
 print(attentions[-1].shape) # (2, 20, 11, 11)
 ```
 
 ## Embed entire datasets with no new code
 To embed a list of protein sequences **fast**, just call embed_dataset. Sequences are sorted to reduce padding tokens, so the initial progress bar estimate is usually much longer than the actual time it will take.
 
 - Sequences will be truncated to max_len and sorted by length in descending order for faster processing
 ```
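Why length-sorting cuts padding can be seen without any model. A toy sketch, with sequences and batch size invented for illustration (embed_dataset's real batching may differ):

```python
# Toy illustration of why length-sorting reduces padding: each batch is
# padded to its longest member, so grouping similar lengths wastes less.
seqs = ["MKT", "MKTAYIAKQR", "MK", "MKTAYI", "MKTAYIAKQRQISFVK", "M"]

def padding_tokens(seqs, batch_size=2):
    # Count tokens spent on padding across all batches.
    total = 0
    for i in range(0, len(seqs), batch_size):
        batch = seqs[i:i + batch_size]
        longest = max(len(s) for s in batch)
        total += sum(longest - len(s) for s in batch)
    return total

unsorted_cost = padding_tokens(seqs)
sorted_cost = padding_tokens(sorted(seqs, key=len, reverse=True))
print(unsorted_cost, sorted_cost)  # sorted batching pads far fewer tokens
```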
 
+## Model probes
+We employ linear probing techniques on various PLMs and standard datasets, similar to our previous [paper](https://www.biorxiv.org/content/10.1101/2024.07.30.605924v1), to assess the intrinsic correlation between pooled hidden states and valuable properties. FastESM performs very well.
+
+The plot below showcases performance normalized between the negative control (random vector embeddings) and the best performer. Classification task scores are averaged between MCC and F1 (or F1max for multilabel), and regression tasks are averaged between Spearman rho and R2.
+![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/d1Xi6k1Q4-9By_MtzTvdV.png)
+
+## Comparison of half precisions
+Presumably because we trained in fp16 mixed precision, fp16 outputs are closer to the fp32 weights than bf16 outputs. Therefore, we recommend loading in fp16.
+
+When summing the MSE over 1000 sequences vs. the fp32 weights:
+
+Average MSE for FP16: 0.00000140
+
+Average MSE for BF16: 0.00004125
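Part of this gap follows from format width alone: fp16 keeps 10 mantissa bits while bf16 keeps only 7, so for activations of typical magnitude fp16 rounds more finely. A numpy sketch that emulates bf16 by truncating float32 mantissas (an approximation; real bf16 conversion usually rounds to nearest):

```python
import numpy as np

def to_bf16(x):
    # Emulate bfloat16 by zeroing the low 16 bits of each float32
    # (truncation; hardware bf16 conversion typically rounds to nearest).
    bits = x.astype(np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)

rng = np.random.default_rng(0)
x = rng.standard_normal(10_000).astype(np.float32)  # stand-in for hidden states

mse_fp16 = float(np.mean((x - x.astype(np.float16).astype(np.float32)) ** 2))
mse_bf16 = float(np.mean((x - to_bf16(x)) ** 2))
print(f"fp16 MSE: {mse_fp16:.2e}, bf16 MSE: {mse_bf16:.2e}")
```

For unit-scale values, fp16's extra mantissa bits give it the lower error, matching the ordering measured above (the absolute numbers depend on the data, so they will not match the README's figures).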
+
+### Inference speed
+We look at various ESM models and their throughput on an H100. FastESM is over twice as fast as ESM2-650 for longer sequences. It requires PyTorch 2.5+ for the most savings; see [SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
+![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/PvaBGfuJXEW2v_WLkt63y.png)
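The fused SDPA kernels compute the same attention math as a naive implementation, just without materializing the full attention matrix in global memory, which is where the speed and memory savings come from. A reference sketch of that math (shapes arbitrary; numpy used so it runs anywhere):

```python
import numpy as np

def naive_sdpa(q, k, v):
    # Reference math behind scaled_dot_product_attention:
    # softmax(Q K^T / sqrt(d)) V, materializing the full weight matrix.
    d = q.shape[-1]
    scores = q @ k.swapaxes(-1, -2) / np.sqrt(d)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)
    return w @ v

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((2, 4, 8, 16)) for _ in range(3))  # (batch, heads, seq, head_dim)
out = naive_sdpa(q, k, v)
print(out.shape)  # (2, 4, 8, 16)
```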
 
 ### Citation
 If you use any of this implementation or work, please cite it (as well as the [ESM2](https://www.science.org/doi/10.1126/science.ade2574) paper).
 
 DOI = { 10.57967/hf/3726 },
 publisher = { Hugging Face }
 }
+```