lhallee committed · Commit 9f28fda · verified · 1 parent: d6d02ce

Upload README.md with huggingface_hub

---
library_name: transformers
tags: []
---

# NOTE
The GitHub repository with the implementation and `requirements.txt` can be found [here](https://github.com/Synthyra/FastPLMs.git).

# FastESM
FastESM is a Hugging Face-compatible, plug-in version of ESM2 rewritten with a newer PyTorch attention implementation.

Load any ESM2 model into a FastEsm model to dramatically speed up training and inference without **ANY** cost in performance.

## Attention backend defaults
`sdpa` is the default attention backend for FastESM.

To enable Flex Attention, set `attn_backend="flex"` on the config before model initialization/loading.

For throughput and memory efficiency, `torch.compile(...)` is strongly recommended, especially when using Flex Attention.

Outputting attention maps (or the contact prediction head) is not natively possible with the optimized attention backends (including Flex Attention). You can still pass `output_attentions` to have attention computed manually and returned.
Various other optimizations also make the base implementation slightly different from the one in transformers.

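Why attention maps need a manual fallback can be sketched in plain PyTorch: optimized kernels like SDPA never materialize the `(batch, heads, seq, seq)` weight matrix, while the naive computation does. This is only an illustration with made-up shapes, not the model's actual forward pass:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Illustrative shapes only: batch 2, 4 heads, 11 tokens, head dim 16
q = torch.randn(2, 4, 11, 16)
k = torch.randn(2, 4, 11, 16)
v = torch.randn(2, 4, 11, 16)

# Optimized path: the attention weights are never materialized
out_sdpa = F.scaled_dot_product_attention(q, k, v)

# Naive path (what the manual output_attentions fallback amounts to):
# the (batch, heads, seq, seq) map exists explicitly and can be returned
scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
weights = scores.softmax(dim=-1)
out_naive = weights @ v

print(torch.allclose(out_sdpa, out_naive, atol=1e-4))  # True
```

Both paths produce the same hidden states; the naive path simply pays extra memory to keep the weights around.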
## Use with 🤗 transformers

### Supported models
```python
model_dict = {
    # Synthyra/ESM2-8M
    'ESM2-8M': 'facebook/esm2_t6_8M_UR50D',
    # Synthyra/ESM2-35M
    'ESM2-35M': 'facebook/esm2_t12_35M_UR50D',
    # Synthyra/ESM2-150M
    'ESM2-150M': 'facebook/esm2_t30_150M_UR50D',
    # Synthyra/ESM2-650M
    'ESM2-650M': 'facebook/esm2_t33_650M_UR50D',
    # Synthyra/ESM2-3B
    'ESM2-3B': 'facebook/esm2_t36_3B_UR50D',
}
```
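Each key mirrors a Synthyra Hub repo of the same name (as the comments indicate); a small, purely hypothetical helper to build that repo path for `from_pretrained`:

```python
model_dict = {
    'ESM2-8M': 'facebook/esm2_t6_8M_UR50D',
    'ESM2-35M': 'facebook/esm2_t12_35M_UR50D',
    'ESM2-150M': 'facebook/esm2_t30_150M_UR50D',
    'ESM2-650M': 'facebook/esm2_t33_650M_UR50D',
    'ESM2-3B': 'facebook/esm2_t36_3B_UR50D',
}

def synthyra_path(name: str) -> str:
    """Hypothetical helper: map a short model name to its Synthyra Hub repo."""
    assert name in model_dict, f'unknown model: {name}'
    return f'Synthyra/{name}'

print(synthyra_path('ESM2-8M'))  # Synthyra/ESM2-8M
```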

### For working with embeddings
```python
import torch
from transformers import AutoModel

model_path = 'Synthyra/ESM2-8M'
model = AutoModel.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).eval()
tokenizer = model.tokenizer

sequences = ['MPRTEIN', 'MSEQWENCE']
tokenized = tokenizer(sequences, padding=True, return_tensors='pt')
with torch.no_grad():
    embeddings = model(**tokenized).last_hidden_state

print(embeddings.shape)  # (2, 11, 320)
```

### For working with sequence logits
```python
import torch
from transformers import AutoModelForMaskedLM

# model_path and tokenized are defined as in the embeddings example above
model = AutoModelForMaskedLM.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).eval()
with torch.no_grad():
    logits = model(**tokenized).logits

print(logits.shape)  # (2, 11, 33)
```

### For working with attention maps
```python
import torch
from transformers import AutoModel

# model_path and tokenized are defined as in the embeddings example above
model = AutoModel.from_pretrained(model_path, dtype=torch.float16, trust_remote_code=True).eval()
with torch.no_grad():
    # tuple of tensors, each (batch_size, num_heads, seq_len, seq_len)
    attentions = model(**tokenized, output_attentions=True).attentions

print(attentions[-1].shape)  # (2, 20, 11, 11)
```

### Contact prediction
Because attentions can be output via the naive attention implementation, contact prediction is also supported:
```python
with torch.no_grad():
    contact_map = model.predict_contacts(**tokenized).squeeze().cpu().numpy()  # (seq_len, seq_len)
```
![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/9707OSXZ3Wdgn0Ni-55T-.png)

## Embed entire datasets with no new code
To embed a list of protein sequences **fast**, just call `embed_dataset`. Sequences are sorted to reduce padding tokens, so the initial progress bar estimate is usually much longer than the actual time it will take.
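Why sorting helps can be sketched with plain Python (hypothetical batching; the real logic lives inside `embed_dataset`): batches of similar-length sequences need fewer padding tokens.

```python
def padding_tokens(batch):
    """Padding tokens needed to pad a batch to its longest sequence."""
    longest = max(len(s) for s in batch)
    return sum(longest - len(s) for s in batch)

seqs = ['MALWMRLLPLLALLALWGPDPAAA', 'MPRTEIN', 'MSEQWENCE', 'MKT']
batch_size = 2

# Batches in arrival order vs. sorted by length (descending, as embed_dataset does)
naive = [seqs[i:i + batch_size] for i in range(0, len(seqs), batch_size)]
ordered = sorted(seqs, key=len, reverse=True)
sorted_batches = [ordered[i:i + batch_size] for i in range(0, len(ordered), batch_size)]

print(sum(padding_tokens(b) for b in naive))           # 23
print(sum(padding_tokens(b) for b in sorted_batches))  # 19
```

Fewer padding tokens means less wasted compute per batch, which is where the speedup over the progress bar's initial estimate comes from.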

Example:
```python
embedding_dict = model.embed_dataset(
    sequences=[
        'MALWMRLLPLLALLALWGPDPAAA', ...  # list of protein sequences
    ],
    tokenizer=model.tokenizer,
    batch_size=2,  # adjust for your GPU memory
    max_len=512,  # adjust for your needs
    full_embeddings=False,  # if True, no pooling is performed
    embed_dtype=torch.float32,  # cast to the dtype you want
    pooling_types=['mean', 'cls'],  # multiple pooling types are concatenated together
    num_workers=0,  # with many CPU cores, we find num_workers=4 is fast for large datasets
    sql=False,  # if True, embeddings are stored in a SQLite database
    sql_db_path='embeddings.db',
    save=True,  # if True, embeddings are saved as a .pth file
    save_path='embeddings.pth',
)
# embedding_dict maps sequences to embeddings: tensors for .pth, numpy arrays for sql
```

```
model.embed_dataset()
Args:
    sequences: List of protein sequences
    batch_size: Batch size for processing
    max_len: Maximum sequence length
    full_embeddings: Whether to return full residue-wise embeddings (True) or pooled embeddings (False)
    pooling_types: List of pooling types ('mean', 'cls'); multiple types are concatenated together
    num_workers: Number of workers for data loading, 0 for the main process
    sql: Whether to store embeddings in a SQLite database - stored in float32
    sql_db_path: Path to the SQLite database

Returns:
    Dictionary mapping sequences to embeddings, or None if sql=True

Note:
    - If sql=True, embeddings can only be stored in float32
    - sql is ideal if you need to stream a very large dataset for training in real time
    - save=True is ideal if you can store the entire embedding dictionary in RAM
    - If sql=True, it takes precedence over save
    - If your SQLite database or .pth file already exists, it is scanned first for already-embedded sequences
    - Sequences are truncated to max_len and sorted by length in descending order for faster processing
```
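As a sketch of what `pooling_types=['mean', 'cls']` produces (hypothetical shapes; the actual implementation may differ), the per-type pooled vectors are concatenated along the feature dimension:

```python
import torch

hidden = torch.randn(2, 11, 320)  # e.g. a batch of ESM2-8M residue embeddings

mean_pool = hidden.mean(dim=1)  # (2, 320), average over residues
cls_pool = hidden[:, 0]         # (2, 320), embedding of the first (CLS) token
pooled = torch.cat([mean_pool, cls_pool], dim=-1)

print(pooled.shape)  # torch.Size([2, 640])
```

So with two pooling types, the pooled embedding dimension is twice the hidden size.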


### Citation
If you use this implementation or this work, please cite it (as well as the [ESM2](https://www.science.org/doi/10.1126/science.ade2574) paper).
```
@misc{FastPLMs,
  author = {Hallee, Logan and Bichara, David and Gleghorn, Jason P.},
  title = {FastPLMs: Fast, efficient, protein language model inference from Huggingface AutoModel.},
  year = {2024},
  url = {https://huggingface.co/Synthyra/ESMplusplus_small},
  doi = {10.57967/hf/3726},
  publisher = {Hugging Face}
}
```
 