from typing import Optional
from pathlib import Path
from contextlib import nullcontext

import torch
import torch.nn.functional as F
from torch_scatter import scatter_mean

from src.constants import atom_encoder, bond_encoder
from src.model.lightning import DrugFlow, set_default
from src.data.dataset import ProcessedLigandPocketDataset, DPODataset
from src.data.data_utils import AppendVirtualNodesInCoM, Residues, center_data

class DPO(DrugFlow):
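    """DPO fine-tuning wrapper around DrugFlow.

    Fine-tunes the `dynamics` network on ligand preference pairs against a reference
    copy of the pre-trained model (`ref_dynamics`), which is only evaluated under
    `torch.no_grad()`. The `_w` / `_l` suffixes are read here as the preferred
    ("winning") and dispreferred ("losing") ligand of each preference pair.
    """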
    def __init__(self, dpo_mode, ref_checkpoint_p, **kwargs):
        super(DPO, self).__init__(**kwargs)
        self.dpo_mode = dpo_mode

        loss_params = kwargs['loss_params']
        self.dpo_beta = loss_params.dpo_beta if 'dpo_beta' in loss_params else 100.0
        self.dpo_beta_schedule = loss_params.dpo_beta_schedule if 'dpo_beta_schedule' in loss_params else 't'
        self.clamp_dpo = loss_params.clamp_dpo if 'clamp_dpo' in loss_params else True
        self.dpo_lambda_dpo = loss_params.dpo_lambda_dpo if 'dpo_lambda_dpo' in loss_params else 1
        self.dpo_lambda_w = loss_params.dpo_lambda_w if 'dpo_lambda_w' in loss_params else 1
        self.dpo_lambda_l = loss_params.dpo_lambda_l if 'dpo_lambda_l' in loss_params else 0.2
        self.dpo_lambda_h = loss_params.dpo_lambda_h if 'dpo_lambda_h' in loss_params else loss_params.lambda_h
        self.dpo_lambda_e = loss_params.dpo_lambda_e if 'dpo_lambda_e' in loss_params else loss_params.lambda_e

        # reference model, loaded from the pre-trained checkpoint (only used under torch.no_grad())
        self.ref_dynamics = self.init_model(kwargs['predictor_params'])
        state_dict = torch.load(ref_checkpoint_p)['state_dict']
        self.ref_dynamics.load_state_dict({k.replace('dynamics.', ''): v
                                           for k, v in state_dict.items() if k.startswith('dynamics.')})
        print(f'Loaded reference model from {ref_checkpoint_p}')

        # initialize the trainable model with the reference model parameters
        self.dynamics.load_state_dict(self.ref_dynamics.state_dict())

    def get_dataset(self, stage, pocket_transform=None):
        # when sampling we don't append virtual nodes as we might need access to the ground truth size
        if self.virtual_nodes and stage == 'train':
            ligand_transform = AppendVirtualNodesInCoM(
                atom_encoder, bond_encoder, add_min=self.add_virtual_min, add_max=self.add_virtual_max)
        else:
            ligand_transform = None

        # we want to know if something goes wrong on the validation or test set, so only catch errors during training
        catch_errors = stage == 'train'

        if self.sharded_dataset:
            raise NotImplementedError('Sharded dataset not implemented for DPO')
        if self.sample_from_clusters and stage == 'train':  # val/test should be deterministic
            raise NotImplementedError('Sampling from clusters not implemented for DPO')

        if stage == 'train':
            # note: the DPO training set is loaded without the virtual-node transform
            return DPODataset(
                Path(self.datadir, 'train.pt'),
                ligand_transform=None,
                pocket_transform=pocket_transform,
                catch_errors=True,
            )
        else:
            return ProcessedLigandPocketDataset(
                pt_path=Path(self.datadir, 'val.pt' if self.debug else f'{stage}.pt'),
                ligand_transform=ligand_transform,
                pocket_transform=pocket_transform,
                catch_errors=catch_errors,
            )

    def training_step(self, data, *args):
        ligand_w, ligand_l, pocket = data['ligand'], data['ligand_l'], data['pocket']
        loss, info = self.compute_dpo_loss(pocket, ligand_w=ligand_w, ligand_l=ligand_l, return_info=True)
        if torch.isnan(loss):
            print(f'Loss is NaN at epoch {self.current_epoch}. Info: {info}')

        # only log scalar entries of the info dict
        log_dict = {k: v for k, v in info.items() if isinstance(v, float) or torch.numel(v) <= 1}
        self.log_metrics({'loss': loss, **log_dict}, 'train', batch_size=len(ligand_w['size']))

        out = {'loss': loss, **info}
        self.training_step_outputs.append(out)
        return out

    def validation_step(self, data, *args):
        return super().validation_step(data, *args)

    def compute_dpo_loss(self, pocket, ligand_w, ligand_l, return_info=False):
        t = torch.rand(ligand_w['size'].size(0), device=ligand_w['x'].device).unsqueeze(-1)

        if self.dpo_beta_schedule == 't':
            # time-dependent beta, following https://arxiv.org/pdf/2407.13981
            beta_t = (self.dpo_beta * t).squeeze()
        elif self.dpo_beta_schedule == 'const':
            beta_t = self.dpo_beta
        else:
            raise ValueError(f'Unknown DPO beta schedule: {self.dpo_beta_schedule}')

        loss_dict_w = self.compute_loss_single_pair(ligand_w, pocket, t)
        loss_dict_l = self.compute_loss_single_pair(ligand_l, pocket, t)

        info = {
            'loss_x_w': loss_dict_w['theta']['x'].mean().item(),
            'loss_h_w': loss_dict_w['theta']['h'].mean().item(),
            'loss_e_w': loss_dict_w['theta']['e'].mean().item(),
            'loss_x_l': loss_dict_l['theta']['x'].mean().item(),
            'loss_h_l': loss_dict_l['theta']['h'].mean().item(),
            'loss_e_l': loss_dict_l['theta']['e'].mean().item(),
        }
        if self.dpo_mode == 'single_dpo_comp':
            loss_w_theta = (
                loss_dict_w['theta']['x'] +
                self.dpo_lambda_h * loss_dict_w['theta']['h'] +
                self.dpo_lambda_e * loss_dict_w['theta']['e']
            )
            loss_w_ref = (
                loss_dict_w['ref']['x'] +
                self.dpo_lambda_h * loss_dict_w['ref']['h'] +
                self.dpo_lambda_e * loss_dict_w['ref']['e']
            )
            loss_l_theta = (
                loss_dict_l['theta']['x'] +
                self.dpo_lambda_h * loss_dict_l['theta']['h'] +
                self.dpo_lambda_e * loss_dict_l['theta']['e']
            )
            loss_l_ref = (
                loss_dict_l['ref']['x'] +
                self.dpo_lambda_h * loss_dict_l['ref']['h'] +
                self.dpo_lambda_e * loss_dict_l['ref']['e']
            )

            # loss gaps between the trainable and the reference model
            diff_w = loss_w_theta - loss_w_ref
            diff_l = loss_l_theta - loss_l_ref
            info['diff_w'] = diff_w.mean().item()
            info['diff_l'] = diff_l.mean().item()

            diff = -1 * beta_t * (diff_w - diff_l)
            loss = -1 * F.logsigmoid(diff)
        elif self.dpo_mode == 'single_dpo_comp_v3':
            # per-modality loss gaps between the trainable and the reference model
            diff_w_x = loss_dict_w['theta']['x'] - loss_dict_w['ref']['x']
            diff_w_h = loss_dict_w['theta']['h'] - loss_dict_w['ref']['h']
            diff_w_e = loss_dict_w['theta']['e'] - loss_dict_w['ref']['e']
            diff_l_x = loss_dict_l['theta']['x'] - loss_dict_l['ref']['x']
            diff_l_h = loss_dict_l['theta']['h'] - loss_dict_l['ref']['h']
            diff_l_e = loss_dict_l['theta']['e'] - loss_dict_l['ref']['e']
            info['diff_w_x'] = diff_w_x.mean().item()
            info['diff_w_h'] = diff_w_h.mean().item()
            info['diff_w_e'] = diff_w_e.mean().item()
            info['diff_l_x'] = diff_l_x.mean().item()
            info['diff_l_h'] = diff_l_h.mean().item()
            info['diff_l_e'] = diff_l_e.mean().item()

            # not used in the loss, just for logging
            _diff_w = diff_w_x + self.dpo_lambda_h * diff_w_h + self.dpo_lambda_e * diff_w_e
            _diff_l = diff_l_x + self.dpo_lambda_h * diff_l_h + self.dpo_lambda_e * diff_l_e
            info['diff_w'] = _diff_w.mean().item()
            info['diff_l'] = _diff_l.mean().item()

            # preferred-vs-dispreferred gap per modality
            diff_x = diff_w_x - diff_l_x
            diff_h = diff_w_h - diff_l_h
            diff_e = diff_w_e - diff_l_e
            info['diff_x'] = diff_x.mean().item()
            info['diff_h'] = diff_h.mean().item()
            info['diff_e'] = diff_e.mean().item()

            diff = -1 * beta_t * (diff_x + self.dpo_lambda_h * diff_h + self.dpo_lambda_e * diff_e)
            if self.clamp_dpo:
                diff = diff.clamp(-10, 10)
            info['dpo_arg_min'] = diff.min().item()
            info['dpo_arg_max'] = diff.max().item()
            info['dpo_arg_mean'] = diff.mean().item()

            dpo_loss = -1 * self.dpo_lambda_dpo * F.logsigmoid(diff)
            info['dpo_loss'] = dpo_loss.mean().item()

            # regularisation: the regular (non-DPO) training loss of the trainable model on both ligands
            loss_w_theta_reg = (
                loss_dict_w['theta']['x'] +
                self.lambda_h * loss_dict_w['theta']['h'] +
                self.lambda_e * loss_dict_w['theta']['e']
            )
            info['loss_w_theta_reg'] = loss_w_theta_reg.mean().item()
            loss_l_theta_reg = (
                loss_dict_l['theta']['x'] +
                self.lambda_h * loss_dict_l['theta']['h'] +
                self.lambda_e * loss_dict_l['theta']['e']
            )
            info['loss_l_theta_reg'] = loss_l_theta_reg.mean().item()

            dpo_reg = self.dpo_lambda_w * loss_w_theta_reg + self.dpo_lambda_l * loss_l_theta_reg
            info['dpo_reg'] = dpo_reg.mean().item()

            loss = dpo_loss + dpo_reg
        else:
            raise ValueError(f'Unknown DPO mode: {self.dpo_mode}')

        if self.timestep_weights is not None:
            w_t = self.timestep_weights(t).squeeze()
            loss = w_t * loss

        loss = loss.mean(0)
        return (loss, info) if return_info else loss

    def compute_loss_single_pair(self, ligand, pocket, t):
        pocket = Residues(**pocket)

        # center ligand and pocket
        ligand, pocket = center_data(ligand, pocket)
        pocket_com = scatter_mean(pocket['x'], pocket['mask'], dim=0)

        # sample the prior and interpolate to time t
        z0_x = self.module_x.sample_z0(pocket_com, ligand['mask'])
        z0_h = self.module_h.sample_z0(ligand['mask'])
        z0_e = self.module_e.sample_z0(ligand['bond_mask'])
        zt_x = self.module_x.sample_zt(z0_x, ligand['x'], t, ligand['mask'])
        zt_h = self.module_h.sample_zt(z0_h, ligand['one_hot'], t, ligand['mask'])
        zt_e = self.module_e.sample_zt(z0_e, ligand['bond_one_hot'], t, ligand['bond_mask'])

        # predictions of the trainable model
        sc_transform = self.get_sc_transform_fn(None, zt_x, t, None, ligand['mask'], pocket)
        pred_ligand, _ = self.dynamics(
            zt_x, zt_h, ligand['mask'], pocket, t,
            bonds_ligand=(ligand['bonds'], zt_e),
            sc_transform=sc_transform
        )

        # predictions of the reference model (no gradients)
        with torch.no_grad():
            ref_pred_ligand, _ = self.ref_dynamics(
                zt_x, zt_h, ligand['mask'], pocket, t,
                bonds_ligand=(ligand['bonds'], zt_e),
                sc_transform=sc_transform
            )

        # per-modality losses for both models
        loss_x = self.module_x.compute_loss(pred_ligand['vel'], z0_x, ligand['x'], t, ligand['mask'], reduce=self.loss_reduce)
        ref_loss_x = self.module_x.compute_loss(ref_pred_ligand['vel'], z0_x, ligand['x'], t, ligand['mask'], reduce=self.loss_reduce)

        t_next = torch.clamp(t + self.train_step_size, max=1.0)
        loss_h = self.module_h.compute_loss(pred_ligand['logits_h'], zt_h, ligand['one_hot'], ligand['mask'], t, t_next, reduce=self.loss_reduce)
        ref_loss_h = self.module_h.compute_loss(ref_pred_ligand['logits_h'], zt_h, ligand['one_hot'], ligand['mask'], t, t_next, reduce=self.loss_reduce)
        loss_e = self.module_e.compute_loss(pred_ligand['logits_e'], zt_e, ligand['bond_one_hot'], ligand['bond_mask'], t, t_next, reduce=self.loss_reduce)
        ref_loss_e = self.module_e.compute_loss(ref_pred_ligand['logits_e'], zt_e, ligand['bond_one_hot'], ligand['bond_mask'], t, t_next, reduce=self.loss_reduce)

        return {
            'theta': {
                'x': loss_x,
                'h': loss_h,
                'e': loss_e,
            },
            'ref': {
                'x': ref_loss_x,
                'h': ref_loss_h,
                'e': ref_loss_e,
            }
        }
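
# Illustrative usage sketch (assumed API, not part of the original module): since DPO
# subclasses the DrugFlow LightningModule, fine-tuning would typically look like
#
#   model = DPO(
#       dpo_mode='single_dpo_comp_v3',
#       ref_checkpoint_p='checkpoints/drugflow_pretrained.ckpt',  # hypothetical path
#       **pretrain_hparams,  # loss_params, predictor_params, ... of the reference run
#   )
#   trainer = pl.Trainer(...)
#   trainer.fit(model)
#
# where `pretrain_hparams` mirrors the keyword arguments used to train the reference model.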