import time

import esm
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import xgboost as xgb
from transformers import AutoConfig, AutoModel, AutoTokenizer, EsmModel
| |
|
class UnpooledBindingPredictor(nn.Module):
    """Predict protein/binder affinity from per-residue ESM embeddings.

    Both sequences are embedded with an ESM model loaded via ``AutoModel``.
    Each embedding goes through its own bank of 1-D convolutions and is then
    max- and mean-pooled over positions. The pooled features are projected to
    ``hidden_dim`` and mixed by cross-attention layers. ``forward`` returns
    the output of a single-unit regression head.
    """

    def __init__(self,
                 esm_model_name="facebook/esm2_t33_650M_UR50D",
                 hidden_dim=512,
                 kernel_sizes=(3, 5, 7),
                 n_heads=8,
                 n_layers=3,
                 dropout=0.1,
                 freeze_esm=True):
        """
        Args:
            esm_model_name: Hugging Face id passed to ``AutoModel`` and
                ``AutoConfig``.
            hidden_dim: width of the projected features and attention layers.
            kernel_sizes: kernel widths of each per-sequence conv bank. A tuple
                default avoids the shared-mutable-default pitfall; any iterable
                of ints is accepted.
            n_heads: attention heads per cross-attention layer.
            n_layers: number of cross-attention layers.
            dropout: dropout used in attention, the FFNs and the shared head.
            freeze_esm: if True, every ESM parameter gets
                ``requires_grad = False``.
        """
        super().__init__()

        # Affinity cut-offs used by get_binding_class.
        self.tight_threshold = 7.5
        self.weak_threshold = 6.0

        self.esm_model = AutoModel.from_pretrained(esm_model_name)
        self.config = AutoConfig.from_pretrained(esm_model_name)

        if freeze_esm:
            for param in self.esm_model.parameters():
                param.requires_grad = False

        esm_dim = self.config.hidden_size
        kernel_sizes = list(kernel_sizes)
        output_channels_per_kernel = 64

        self.protein_conv_layers = nn.ModuleList([
            nn.Conv1d(
                in_channels=esm_dim,
                out_channels=output_channels_per_kernel,
                kernel_size=k,
                padding='same'
            ) for k in kernel_sizes
        ])

        self.binder_conv_layers = nn.ModuleList([
            nn.Conv1d(
                in_channels=esm_dim,
                out_channels=output_channels_per_kernel,
                kernel_size=k,
                padding='same'
            ) for k in kernel_sizes
        ])

        # Concatenated conv channels, doubled because max and mean pooling are
        # both kept (see process_sequence).
        total_features_per_seq = output_channels_per_kernel * len(kernel_sizes) * 2

        self.protein_projection = nn.Linear(total_features_per_seq, hidden_dim)
        self.binder_projection = nn.Linear(total_features_per_seq, hidden_dim)

        self.protein_norm = nn.LayerNorm(hidden_dim)
        self.binder_norm = nn.LayerNorm(hidden_dim)

        self.cross_attention_layers = nn.ModuleList([
            nn.ModuleDict({
                'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
                'norm1': nn.LayerNorm(hidden_dim),
                'ffn': nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim * 4),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim * 4, hidden_dim)
                ),
                'norm2': nn.LayerNorm(hidden_dim)
            }) for _ in range(n_layers)
        ])

        self.shared_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        self.regression_head = nn.Linear(hidden_dim, 1)

        # Three-way head (tight / medium / weak); not used by forward().
        self.classification_head = nn.Linear(hidden_dim, 3)

    def get_binding_class(self, affinity):
        """Convert affinity values to class indices
        0: tight binding (>= 7.5)
        1: medium binding (6.0-7.5)
        2: weak binding (< 6.0)

        Tensors are classified element-wise into a ``torch.long`` tensor of the
        same shape; any other value is compared directly and yields an int.
        """
        if isinstance(affinity, torch.Tensor):
            tight_mask = affinity >= self.tight_threshold
            weak_mask = affinity < self.weak_threshold
            medium_mask = ~(tight_mask | weak_mask)

            classes = torch.zeros_like(affinity, dtype=torch.long)
            classes[medium_mask] = 1
            classes[weak_mask] = 2
            return classes
        else:
            if affinity >= self.tight_threshold:
                return 0
            elif affinity < self.weak_threshold:
                return 2
            else:
                return 1

    def compute_embeddings(self, input_ids, attention_mask=None):
        """Compute ESM embeddings on the fly (the model's ``last_hidden_state``)."""
        esm_outputs = self.esm_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        return esm_outputs.last_hidden_state

    def process_sequence(self, unpooled_emb, conv_layers, attention_mask=None):
        """Run a conv bank over per-residue embeddings and pool over length.

        Args:
            unpooled_emb: embeddings in ``(batch, seq_len, dim)`` layout. The
                ``transpose(1, 2)`` below assumes this layout.
            conv_layers: iterable of ``nn.Conv1d`` whose outputs keep the input
                length, so they can be concatenated on the channel axis.
            attention_mask: optional ``(batch, seq_len)`` mask. Positions where
                it is 0 are excluded from both the max and the mean.

        Returns:
            ``(batch, 2 * channels)`` tensor: the max-pooled features followed
            by the mean-pooled features. Rows with no valid position pool to 0
            instead of ``-inf``.
        """
        x = unpooled_emb.transpose(1, 2)  # -> (batch, dim, seq_len) for Conv1d
        conv_output = torch.cat([F.relu(conv(x)) for conv in conv_layers], dim=1)

        if attention_mask is None:
            max_pooled = torch.max(conv_output, dim=2)[0]
            avg_pooled = torch.mean(conv_output, dim=2)
        else:
            # (batch, 1, seq_len) in the conv dtype; broadcasts over channels.
            mask = attention_mask.unsqueeze(1).to(conv_output.dtype)

            masked_output = conv_output.masked_fill(mask == 0, float('-inf'))
            max_pooled = torch.max(masked_output, dim=2)[0]
            # A fully masked row would max-pool to -inf and turn every
            # downstream activation into NaN; report 0 for it instead (its
            # mean is already 0 thanks to the clamp below).
            has_valid = (mask != 0).any(dim=2)  # (batch, 1)
            max_pooled = max_pooled.masked_fill(~has_valid, 0.0)

            sum_pooled = torch.sum(conv_output * mask, dim=2)
            valid_positions = torch.clamp(torch.sum(mask, dim=2), min=1.0)
            avg_pooled = sum_pooled / valid_positions

        return torch.cat([max_pooled, avg_pooled], dim=1)

    def forward(self, protein_input_ids, binder_input_ids, protein_mask=None, binder_mask=None):
        """Predict affinity for a batch of protein/binder token-id pairs.

        Returns:
            Output of ``regression_head``: one value per example, shape
            ``(batch, 1)``.
        """
        protein_unpooled = self.compute_embeddings(protein_input_ids, protein_mask)
        binder_unpooled = self.compute_embeddings(binder_input_ids, binder_mask)

        protein_features = self.process_sequence(protein_unpooled, self.protein_conv_layers, protein_mask)
        binder_features = self.process_sequence(binder_unpooled, self.binder_conv_layers, binder_mask)

        protein = self.protein_norm(self.protein_projection(protein_features))
        binder = self.binder_norm(self.binder_projection(binder_features))

        # nn.MultiheadAttention defaults to (seq, batch, dim); each pooled
        # vector becomes a length-1 sequence.
        # NOTE(review): with a single key per query the softmax weights are
        # all 1, so this "attention" reduces to a projection of the other side.
        protein = protein.unsqueeze(0)
        binder = binder.unsqueeze(0)

        for layer in self.cross_attention_layers:
            # Protein attends to binder; the same weights serve the reverse pass.
            attended_protein = layer['attention'](protein, binder, binder)[0]
            protein = layer['norm1'](protein + attended_protein)
            protein = layer['norm2'](protein + layer['ffn'](protein))

            attended_binder = layer['attention'](binder, protein, protein)[0]
            binder = layer['norm1'](binder + attended_binder)
            binder = layer['norm2'](binder + layer['ffn'](binder))

        combined = torch.cat([protein.squeeze(0), binder.squeeze(0)], dim=-1)
        shared_features = self.shared_head(combined)
        return self.regression_head(shared_features)
| |
|
class ImprovedBindingPredictor(nn.Module):
    """Affinity regressor over precomputed protein and binder embeddings.

    Each input is linearly projected to ``hidden_dim`` and layer-normalised.
    The two token sequences then attend to each other through a stack of
    cross-attention layers; each layer's weights are shared by both
    directions. The token states are mean-pooled, concatenated and fed to a
    small MLP that emits one regression value per example.
    """

    def __init__(self,
                 esm_dim=1280,
                 smiles_dim=1280,
                 hidden_dim=512,
                 n_heads=8,
                 n_layers=5,
                 dropout=0.1):
        super().__init__()

        # Affinity cut-offs used by get_binding_class.
        self.tight_threshold = 7.5
        self.weak_threshold = 6.0

        # Input adapters. Creation order is kept stable so seeded
        # initialisation stays reproducible.
        self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
        self.protein_projection = nn.Linear(esm_dim, hidden_dim)
        self.protein_norm = nn.LayerNorm(hidden_dim)
        self.smiles_norm = nn.LayerNorm(hidden_dim)

        self.cross_attention_layers = nn.ModuleList(
            self._make_cross_layer(hidden_dim, n_heads, dropout)
            for _ in range(n_layers)
        )

        self.shared_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.regression_head = nn.Linear(hidden_dim, 1)
        # Three-way head (tight / medium / weak); not used by forward().
        self.classification_head = nn.Linear(hidden_dim, 3)

    @staticmethod
    def _make_cross_layer(hidden_dim, n_heads, dropout):
        """Build one attention + feed-forward layer with post-norm residuals."""
        ffn_width = hidden_dim * 4
        return nn.ModuleDict({
            'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
            'norm1': nn.LayerNorm(hidden_dim),
            'ffn': nn.Sequential(
                nn.Linear(hidden_dim, ffn_width),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(ffn_width, hidden_dim),
            ),
            'norm2': nn.LayerNorm(hidden_dim),
        })

    def get_binding_class(self, affinity):
        """Map affinity value(s) to class indices.

        0: tight binding (>= 7.5)
        1: medium binding (6.0-7.5)
        2: weak binding (< 6.0)

        Tensors are classified element-wise into a ``torch.long`` tensor of the
        same shape; any other value is compared directly and yields an int.
        """
        if not isinstance(affinity, torch.Tensor):
            if affinity >= self.tight_threshold:
                return 0
            return 2 if affinity < self.weak_threshold else 1

        labels = torch.ones_like(affinity, dtype=torch.long)
        labels[affinity >= self.tight_threshold] = 0
        labels[affinity < self.weak_threshold] = 2
        return labels

    def forward(self, protein_emb, binder_emb):
        """Return the regression output, one value per example.

        The inputs are presumably batch-first ``(batch, length, dim)`` tensors
        (TODO confirm with callers). They are transposed so dim 0 is the
        sequence axis, which is the default layout of ``nn.MultiheadAttention``.
        """
        prot = self.protein_norm(self.protein_projection(protein_emb)).transpose(0, 1)
        lig = self.smiles_norm(self.smiles_projection(binder_emb)).transpose(0, 1)

        for block in self.cross_attention_layers:
            attn, ffn = block['attention'], block['ffn']
            norm1, norm2 = block['norm1'], block['norm2']

            # Protein queries attend over binder tokens ...
            prot = norm1(prot + attn(prot, lig, lig)[0])
            prot = norm2(prot + ffn(prot))

            # ... then binder queries attend over the updated protein tokens.
            lig = norm1(lig + attn(lig, prot, prot)[0])
            lig = norm2(lig + ffn(lig))

        pooled = torch.cat([prot.mean(dim=0), lig.mean(dim=0)], dim=-1)
        return self.regression_head(self.shared_head(pooled))
| |
|
class PooledAffinityModel(nn.Module):
    """Score binders against one fixed target with an embedding-level predictor.

    The target and each binder are embedded with a frozen ESM-2 model. Both
    embeddings are passed to ``affinity_predictor`` as the ``protein_emb=`` and
    ``binder_emb=`` keyword arguments.
    """

    def __init__(self, affinity_predictor, target_sequence):
        """
        Args:
            affinity_predictor: callable accepting ``protein_emb`` and
                ``binder_emb`` keyword arguments.
            target_sequence: token-id tensor for the target. Its device decides
                where the ESM model is placed.
        """
        super(PooledAffinityModel, self).__init__()
        self.affinity_predictor = affinity_predictor
        if isinstance(target_sequence, torch.Tensor):
            # Non-persistent buffer: it follows .to()/.cuda() together with the
            # module (a plain attribute would be left on its old device) but
            # stays out of state_dict(), so existing checkpoints still load.
            self.register_buffer("target_sequence", target_sequence, persistent=False)
        else:
            self.target_sequence = target_sequence
        self.esm_model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(self.target_sequence.device)
        # The ESM encoder is used as a fixed feature extractor.
        for param in self.esm_model.parameters():
            param.requires_grad = False

    def compute_embeddings(self, input_ids, attention_mask=None):
        """Compute ESM embeddings on the fly (the model's ``last_hidden_state``)."""
        esm_outputs = self.esm_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        return esm_outputs.last_hidden_state

    def forward(self, x):
        """Return the predicted affinity for each binder row of ``x``.

        The predictor output has its trailing dimension squeezed.
        """
        # One copy of the target per binder row. ESM therefore re-encodes the
        # same target once per row.
        target_sequence = self.target_sequence.repeat(x.shape[0], 1)

        protein_emb = self.compute_embeddings(input_ids=target_sequence)
        binder_emb = self.compute_embeddings(input_ids=x)
        return self.affinity_predictor(protein_emb=protein_emb, binder_emb=binder_emb).squeeze(-1)
| |
|
class AffinityModel(nn.Module):
    """Score binders against one fixed target with a token-level predictor.

    ``affinity_predictor`` is called with the ``protein_input_ids=`` and
    ``binder_input_ids=`` keyword arguments. Its output is squeezed on the last
    dimension and divided by 10.
    """

    def __init__(self, affinity_predictor, target_sequence):
        """
        Args:
            affinity_predictor: callable accepting ``protein_input_ids`` and
                ``binder_input_ids`` keyword arguments.
            target_sequence: token-id tensor for the target, repeated once per
                binder row in ``forward``.
        """
        super(AffinityModel, self).__init__()
        self.affinity_predictor = affinity_predictor
        if isinstance(target_sequence, torch.Tensor):
            # Non-persistent buffer: it follows .to()/.cuda() together with the
            # module (a plain attribute would be left on its old device) but
            # stays out of state_dict(), so existing checkpoints still load.
            self.register_buffer("target_sequence", target_sequence, persistent=False)
        else:
            self.target_sequence = target_sequence

    def forward(self, x):
        """Return the rescaled affinity for each binder row of ``x``."""
        target_sequence = self.target_sequence.repeat(x.shape[0], 1)
        affinity = self.affinity_predictor(protein_input_ids=target_sequence, binder_input_ids=x).squeeze(-1)
        # Divide by 10 to rescale. Presumably this maps affinities into
        # roughly [0, 1]; confirm against the predictor's training targets.
        return affinity / 10
| |
|
class HemolysisModel:
    """Hemolysis scorer: mean-pooled ESM-2 embeddings fed to an XGBoost booster.

    Scores are ``1 - booster prediction`` per sequence, returned as a tensor on
    ``device``.
    """

    def __init__(self, device):
        self.predictor = xgb.Booster(model_file='./classifier_ckpt/best_model_hemolysis.json')

        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
        self.model.eval()

        self.device = device

    def generate_embeddings(self, sequences):
        """Generate mean-pooled ESM embeddings for protein sequences.

        ``sequences`` is passed to the model as ``input_ids`` (token ids). No
        attention mask is given, so the mean over positions includes any
        padding. Returns a NumPy array of shape ``(batch, hidden)``.
        """
        with torch.no_grad():
            embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1)
        return embeddings.cpu().numpy()

    def get_scores(self, input_seqs):
        """Return ``1 - booster prediction`` per sequence as a tensor on ``self.device``."""
        scores = np.ones(len(input_seqs))
        # Check before running ESM so an empty input never reaches the model.
        # The early return is a tensor, the same type as the normal path.
        if len(input_seqs) == 0:
            return torch.from_numpy(scores).to(self.device)

        features = self.generate_embeddings(input_seqs)
        # Sanitise features before handing them to XGBoost.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        probs = self.predictor.predict(xgb.DMatrix(features))
        return torch.from_numpy(scores - probs).to(self.device)

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
| |
|
class NonfoulingModel:
    """Non-fouling scorer: mean-pooled ESM-2 embeddings fed to an XGBoost booster.

    Scores are the raw booster predictions (presumably probabilities; confirm
    against the booster's objective), returned as a tensor on ``device``.
    """

    def __init__(self, device):
        self.predictor = xgb.Booster(model_file='./classifier_ckpt/best_model_nonfouling.json')

        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
        self.model.eval()

        self.device = device

    def generate_embeddings(self, sequences):
        """Generate mean-pooled ESM embeddings for protein sequences.

        ``sequences`` is passed to the model as ``input_ids`` (token ids). No
        attention mask is given, so the mean over positions includes any
        padding. Returns a NumPy array of shape ``(batch, hidden)``.
        """
        with torch.no_grad():
            embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1)
        return embeddings.cpu().numpy()

    def get_scores(self, input_seqs):
        """Return the booster prediction per sequence as a tensor on ``self.device``."""
        # Check before running ESM so an empty input never reaches the model.
        # The early return is a tensor, the same type as the normal path.
        if len(input_seqs) == 0:
            return torch.zeros(0, device=self.device)

        features = self.generate_embeddings(input_seqs)
        # Sanitise features before handing them to XGBoost.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        scores = self.predictor.predict(xgb.DMatrix(features))
        return torch.from_numpy(scores).to(self.device)

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
| |
|
class SolubilityModel:
    """Solubility scorer: mean-pooled ESM-2 embeddings fed to an XGBoost booster.

    Scores are the raw booster predictions (presumably probabilities; confirm
    against the booster's objective), returned as a tensor on ``device``.
    """

    def __init__(self, device):
        self.predictor = xgb.Booster(model_file='./classifier_ckpt/best_model_solubility.json')

        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
        self.model.eval()

        self.device = device

    def generate_embeddings(self, sequences):
        """Generate mean-pooled ESM embeddings for protein sequences.

        ``sequences`` is passed to the model as ``input_ids`` (token ids). No
        attention mask is given, so the mean over positions includes any
        padding. Returns a NumPy array of shape ``(batch, hidden)``.
        """
        with torch.no_grad():
            embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1)
        return embeddings.cpu().numpy()

    def get_scores(self, input_seqs: list):
        """Return the booster prediction per sequence as a tensor on ``self.device``."""
        # Check before running ESM so an empty input never reaches the model.
        # The early return is a tensor, the same type as the normal path.
        if len(input_seqs) == 0:
            return torch.zeros(0, device=self.device)

        features = self.generate_embeddings(input_seqs)
        # Sanitise features before handing them to XGBoost.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        scores = self.predictor.predict(xgb.DMatrix(features))
        return torch.from_numpy(scores).to(self.device)

    def __call__(self, input_seqs: list):
        return self.get_scores(input_seqs)
| |
|
class PeptideCNN(nn.Module):
    """Two-layer 1-D CNN regressor on top of an ESM-2 encoder.

    Args:
        input_dim: channel count of the ESM hidden states fed to ``conv1``.
        hidden_dims: output channels of ``conv1`` and ``conv2`` (first two
            entries are used).
        output_dim: width of the ``fc`` feature layer.
        dropout_rate: dropout applied after each convolution.
    """

    def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate):
        super().__init__()
        first_width, second_width = hidden_dims[0], hidden_dims[1]
        self.conv1 = nn.Conv1d(input_dim, first_width, kernel_size=3, padding=1)
        # kernel 5 with padding 1 shortens the sequence by two positions.
        self.conv2 = nn.Conv1d(first_width, second_width, kernel_size=5, padding=1)
        self.fc = nn.Linear(second_width, output_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.predictor = nn.Linear(output_dim, 1)

        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.esm_model.eval()

    def forward(self, input_ids, attention_mask=None, return_features=False):
        """Return ``predictor`` outputs, or the ``fc`` features if ``return_features``.

        The ESM pass runs under ``torch.no_grad()``. ``attention_mask`` is only
        forwarded to ESM; the mean over positions below averages every
        position, padding included.
        """
        with torch.no_grad():
            hidden = self.esm_model(input_ids, attention_mask).last_hidden_state

        h = hidden.permute(0, 2, 1)  # channels-first for Conv1d
        for conv in (self.conv1, self.conv2):
            h = self.dropout(nn.functional.relu(conv(h)))

        pooled = h.permute(0, 2, 1).mean(dim=1)  # average over the length axis
        features = self.fc(pooled)
        return features if return_features else self.predictor(features)
| |
|
class HalfLifeModel:
    """Half-life scorer wrapping a ``PeptideCNN`` restored from a checkpoint.

    Calling the instance clamps the model prediction to ``[0, 2]`` and divides
    it by 2, giving values in ``[0, 1]``.
    """

    def __init__(self, device):
        esm_width = 1280
        self.model = PeptideCNN(
            esm_width,
            [esm_width // 2, esm_width // 4],
            esm_width // 8,
            0.3,
        ).to(device)
        # NOTE: weights_only=False lets torch.load unpickle arbitrary objects;
        # only point this at a trusted checkpoint.
        state = torch.load('./classifier_ckpt/best_model_half_life.pth', map_location=device, weights_only=False)
        self.model.load_state_dict(state)
        self.model.eval()

    def __call__(self, x):
        raw = self.model(x, return_features=False).squeeze(-1)
        return raw.clamp(min=0.0, max=2.0) / 2
| |
|
| |
|
def load_bindevaluator(checkpoint_path, device):
    """Load a frozen ``BindEvaluator`` from a checkpoint.

    The architecture hyperparameters below are forwarded to
    ``load_from_checkpoint``. The model is moved to ``device`` and switched to
    eval mode, and every parameter has ``requires_grad`` turned off.

    NOTE(review): ``BindEvaluator`` is not imported in this file's import
    block; it must be brought into scope elsewhere.
    """
    architecture = dict(n_layers=8, d_model=128, d_hidden=128, n_head=8, d_k=64, d_v=128, d_inner=64)
    model = BindEvaluator.load_from_checkpoint(checkpoint_path, **architecture).to(device)
    model.eval()
    for weight in model.parameters():
        weight.requires_grad_(False)
    return model
| |
|
| |
|
| |
|
def load_pooled_affinity_predictor(checkpoint_path, device):
    """Load trained model from checkpoint.

    Builds an ``ImprovedBindingPredictor`` with its default hyperparameters on
    ``device`` and loads ``checkpoint['model_state_dict']`` into it. The model
    is returned in eval mode.

    NOTE: ``weights_only=False`` unpickles arbitrary objects; only load trusted
    checkpoints.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    predictor = ImprovedBindingPredictor().to(device)
    predictor.load_state_dict(checkpoint['model_state_dict'])
    predictor.eval()
    return predictor
| |
|
def load_affinity_predictor(checkpoint_path, device):
    """Load trained model from checkpoint.

    Builds an ``UnpooledBindingPredictor`` with fixed hyperparameters on
    ``device`` and loads ``checkpoint['model_state_dict']`` into it. The model
    is returned in eval mode.

    NOTE: ``weights_only=False`` unpickles arbitrary objects; only load trusted
    checkpoints.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Architecture hyperparameters. They have to agree with the checkpoint's
    # tensor shapes for the strict load_state_dict below to succeed.
    hparams = {
        "esm_model_name": "facebook/esm2_t33_650M_UR50D",
        "hidden_dim": 384,
        "kernel_sizes": [3, 5, 7],
        "n_heads": 8,
        "n_layers": 4,
        "dropout": 0.14561457009902096,
        "freeze_esm": True,
    }
    model = UnpooledBindingPredictor(**hparams).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model.eval()
| |
|