|
|
from typing import Optional |
|
|
from pathlib import Path |
|
|
from contextlib import nullcontext |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch_scatter import scatter_mean |
|
|
|
|
|
from src.constants import atom_encoder, bond_encoder |
|
|
from src.model.lightning import DrugFlow, set_default |
|
|
from src.data.dataset import ProcessedLigandPocketDataset, DPODataset |
|
|
from src.data.data_utils import AppendVirtualNodesInCoM, Residues, center_data |
|
|
|
|
|
class DPO(DrugFlow):
    """DrugFlow model fine-tuned with a DPO-style preference objective.

    Two copies of the network are kept: ``self.dynamics`` (trained) and
    ``self.ref_dynamics`` (reference, only ever evaluated under
    ``torch.no_grad()``). Each training batch provides a pocket together with
    two ligands, ``data['ligand']`` (the "w" sample) and ``data['ligand_l']``
    (the "l" sample). The objective rewards the trained network for lowering
    its loss relative to the reference on the "w" ligand more than on the "l"
    ligand.
    """

    # Prefix of the trainable network's entries in a checkpoint state dict.
    _DYNAMICS_PREFIX = 'dynamics.'

    # Loss components returned by ``compute_loss_single_pair``.
    _LOSS_KEYS = ('x', 'h', 'e')

    def __init__(self, dpo_mode, ref_checkpoint_p, **kwargs):
        """Build the model and initialise both networks from a checkpoint.

        Args:
            dpo_mode: Objective variant; ``compute_dpo_loss`` accepts
                ``'single_dpo_comp'`` and ``'single_dpo_comp_v3'``.
            ref_checkpoint_p: Path to a checkpoint whose ``'state_dict'``
                contains the network weights under the ``'dynamics.'`` prefix.
            **kwargs: Forwarded to ``DrugFlow.__init__``. ``loss_params`` and
                ``predictor_params`` are additionally read here.
        """
        super(DPO, self).__init__(**kwargs)
        loss_params = kwargs['loss_params']

        def optional(name, default):
            # DPO options may be absent from older configs; fall back then.
            return getattr(loss_params, name) if name in loss_params else default

        self.dpo_mode = dpo_mode
        self.dpo_beta = optional('dpo_beta', 100.0)
        self.dpo_beta_schedule = optional('dpo_beta_schedule', 't')
        self.clamp_dpo = optional('clamp_dpo', True)
        self.dpo_lambda_dpo = optional('dpo_lambda_dpo', 1)
        self.dpo_lambda_w = optional('dpo_lambda_w', 1)
        self.dpo_lambda_l = optional('dpo_lambda_l', 0.2)
        # These two fall back to the regular loss weights; the fallback is
        # only evaluated when the DPO-specific option is missing.
        self.dpo_lambda_h = loss_params.dpo_lambda_h if 'dpo_lambda_h' in loss_params else loss_params.lambda_h
        self.dpo_lambda_e = loss_params.dpo_lambda_e if 'dpo_lambda_e' in loss_params else loss_params.lambda_e

        self.ref_dynamics = self.init_model(kwargs['predictor_params'])

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        # map_location='cpu' avoids depending on the device the checkpoint was
        # saved from; load_state_dict copies the values into the existing
        # parameters regardless of where the loaded tensors live.
        state_dict = torch.load(ref_checkpoint_p, map_location='cpu')['state_dict']

        # Strip only the leading prefix (str.replace would also rewrite any
        # later occurrence of 'dynamics.' inside a nested key).
        prefix = self._DYNAMICS_PREFIX
        self.ref_dynamics.load_state_dict({
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        })

        # The reference network is only evaluated under torch.no_grad(), so
        # its parameters never receive gradients; mark them as frozen.
        self.ref_dynamics.requires_grad_(False)
        print(f'Loaded reference model from {ref_checkpoint_p}')

        # Start the trainable network from the reference weights.
        self.dynamics.load_state_dict(self.ref_dynamics.state_dict())

    def get_dataset(self, stage, pocket_transform=None):
        """Return the dataset for ``stage``.

        ``'train'`` yields a ``DPODataset`` read from ``train.pt``; any other
        stage yields a ``ProcessedLigandPocketDataset`` read from
        ``{stage}.pt`` (or ``val.pt`` when ``self.debug`` is set).

        Raises:
            NotImplementedError: If ``self.sharded_dataset`` is set, or if
                ``self.sample_from_clusters`` is set for the training stage.
        """
        is_train = stage == 'train'

        # NOTE(review): this transform is only ever built for the training
        # stage, but the training branch below passes ligand_transform=None to
        # DPODataset, so it is currently never applied. Confirm whether that
        # is intentional before forwarding it.
        if self.virtual_nodes and is_train:
            ligand_transform = AppendVirtualNodesInCoM(
                atom_encoder, bond_encoder,
                add_min=self.add_virtual_min, add_max=self.add_virtual_max)
        else:
            ligand_transform = None

        if self.sharded_dataset:
            raise NotImplementedError('Sharded dataset not implemented for DPO')

        if self.sample_from_clusters and is_train:
            raise NotImplementedError('Sampling from clusters not implemented for DPO')

        if is_train:
            return DPODataset(
                Path(self.datadir, 'train.pt'),
                ligand_transform=None,
                pocket_transform=pocket_transform,
                catch_errors=True,
            )

        return ProcessedLigandPocketDataset(
            pt_path=Path(self.datadir, 'val.pt' if self.debug else f'{stage}.pt'),
            ligand_transform=ligand_transform,
            pocket_transform=pocket_transform,
            # Errors are only caught for the training stage.
            catch_errors=is_train,
        )

    def training_step(self, data, *args):
        """Compute and log the DPO loss for one batch of ligand pairs.

        Returns:
            A dict with the ``'loss'`` tensor plus the scalar diagnostics
            produced by ``compute_dpo_loss``; the same dict is appended to
            ``self.training_step_outputs``.
        """
        ligand_w, ligand_l, pocket = data['ligand'], data['ligand_l'], data['pocket']
        loss, info = self.compute_dpo_loss(pocket, ligand_w=ligand_w, ligand_l=ligand_l, return_info=True)

        if torch.isnan(loss):
            print(f'For ligand pair , loss is NaN at epoch {self.current_epoch}. Info: {info}')

        # Only scalar entries can be logged as metrics.
        log_dict = {k: v for k, v in info.items() if isinstance(v, float) or torch.numel(v) <= 1}
        self.log_metrics({'loss': loss, **log_dict}, 'train', batch_size=len(ligand_w['size']))

        out = {'loss': loss, **info}
        self.training_step_outputs.append(out)
        return out

    def validation_step(self, data, *args):
        """Delegate validation unchanged to the parent implementation."""
        return super().validation_step(data, *args)

    @staticmethod
    def _weighted_sum(terms, lambda_h, lambda_e):
        """Return ``terms['x'] + lambda_h * terms['h'] + lambda_e * terms['e']``."""
        return terms['x'] + lambda_h * terms['h'] + lambda_e * terms['e']

    def _single_dpo_comp_loss(self, loss_dict_w, loss_dict_l, beta_t, info):
        """Loss for ``dpo_mode == 'single_dpo_comp'``.

        Combines the x/h/e losses of each network into one value per ligand
        and applies ``-logsigmoid`` to the beta-scaled preference margin.
        Adds ``diff_w`` and ``diff_l`` to ``info`` in place.
        """
        lambda_h, lambda_e = self.dpo_lambda_h, self.dpo_lambda_e

        loss_w_theta = self._weighted_sum(loss_dict_w['theta'], lambda_h, lambda_e)
        loss_w_ref = self._weighted_sum(loss_dict_w['ref'], lambda_h, lambda_e)
        loss_l_theta = self._weighted_sum(loss_dict_l['theta'], lambda_h, lambda_e)
        loss_l_ref = self._weighted_sum(loss_dict_l['ref'], lambda_h, lambda_e)

        # How much worse (positive) or better (negative) the trained network
        # is than the reference on each ligand.
        diff_w = loss_w_theta - loss_w_ref
        diff_l = loss_l_theta - loss_l_ref
        info['diff_w'] = diff_w.mean().item()
        info['diff_l'] = diff_l.mean().item()

        diff = -1 * beta_t * (diff_w - diff_l)
        return -1 * F.logsigmoid(diff)

    def _single_dpo_comp_v3_loss(self, loss_dict_w, loss_dict_l, beta_t, info):
        """Loss for ``dpo_mode == 'single_dpo_comp_v3'``.

        Same preference term as ``'single_dpo_comp'`` but tracked per
        component, optionally clamped to ``[-10, 10]`` and scaled by
        ``dpo_lambda_dpo``, plus a regulariser made of the trained network's
        own losses on both ligands (weighted by ``dpo_lambda_w`` and
        ``dpo_lambda_l``). Adds its diagnostics to ``info`` in place.
        """
        keys = self._LOSS_KEYS
        lambda_h, lambda_e = self.dpo_lambda_h, self.dpo_lambda_e

        # Per-component (trained - reference) loss for each ligand.
        diff_w = {k: loss_dict_w['theta'][k] - loss_dict_w['ref'][k] for k in keys}
        diff_l = {k: loss_dict_l['theta'][k] - loss_dict_l['ref'][k] for k in keys}
        for tag, diffs in (('w', diff_w), ('l', diff_l)):
            for k in keys:
                info[f'diff_{tag}_{k}'] = diffs[k].mean().item()

        info['diff_w'] = self._weighted_sum(diff_w, lambda_h, lambda_e).mean().item()
        info['diff_l'] = self._weighted_sum(diff_l, lambda_h, lambda_e).mean().item()

        # Per-component preference margin between the two ligands.
        margin = {k: diff_w[k] - diff_l[k] for k in keys}
        for k in keys:
            info[f'diff_{k}'] = margin[k].mean().item()

        diff = -1 * beta_t * self._weighted_sum(margin, lambda_h, lambda_e)
        if self.clamp_dpo:
            diff = diff.clamp(-10, 10)
        info['dpo_arg_min'] = diff.min().item()
        info['dpo_arg_max'] = diff.max().item()
        info['dpo_arg_mean'] = diff.mean().item()

        dpo_loss = -1 * self.dpo_lambda_dpo * F.logsigmoid(diff)
        info['dpo_loss'] = dpo_loss.mean().item()

        # Regulariser: the trained network's regular loss on both ligands,
        # using the regular (non-DPO) component weights.
        loss_w_theta_reg = self._weighted_sum(loss_dict_w['theta'], self.lambda_h, self.lambda_e)
        info['loss_w_theta_reg'] = loss_w_theta_reg.mean().item()
        loss_l_theta_reg = self._weighted_sum(loss_dict_l['theta'], self.lambda_h, self.lambda_e)
        info['loss_l_theta_reg'] = loss_l_theta_reg.mean().item()

        dpo_reg = self.dpo_lambda_w * loss_w_theta_reg + self.dpo_lambda_l * loss_l_theta_reg
        info['dpo_reg'] = dpo_reg.mean().item()

        return dpo_loss + dpo_reg

    def compute_dpo_loss(self, pocket, ligand_w, ligand_l, return_info=False):
        """Compute the DPO loss for a batch of (w, l) ligand pairs.

        One time ``t`` is drawn per batch entry and shared by both ligands of
        the pair.

        Args:
            pocket: Pocket batch, passed on to ``compute_loss_single_pair``.
            ligand_w: The "w" ligand batch; its loss is pushed down relative
                to the reference.
            ligand_l: The "l" ligand batch.
            return_info: If true, also return the dict of scalar diagnostics.

        Returns:
            The loss averaged over dimension 0, or ``(loss, info)`` when
            ``return_info`` is true.

        Raises:
            ValueError: For an unknown ``dpo_beta_schedule`` or ``dpo_mode``.
        """
        t = torch.rand(ligand_w['size'].size(0), device=ligand_w['x'].device).unsqueeze(-1)

        if self.dpo_beta_schedule == 't':
            # Beta grows linearly with t.
            beta_t = (self.dpo_beta * t).squeeze()
        elif self.dpo_beta_schedule == 'const':
            beta_t = self.dpo_beta
        else:
            raise ValueError(f'Unknown DPO beta schedule: {self.dpo_beta_schedule}')

        # Validate the mode before running the networks so a misconfiguration
        # fails fast instead of after four forward passes.
        if self.dpo_mode == 'single_dpo_comp':
            mode_loss = self._single_dpo_comp_loss
        elif self.dpo_mode == 'single_dpo_comp_v3':
            mode_loss = self._single_dpo_comp_v3_loss
        else:
            raise ValueError(f'Unknown DPO mode: {self.dpo_mode}')

        loss_dict_w = self.compute_loss_single_pair(ligand_w, pocket, t)
        loss_dict_l = self.compute_loss_single_pair(ligand_l, pocket, t)

        # Scalar diagnostics: loss_{x,h,e}_w followed by loss_{x,h,e}_l.
        info = {
            f'loss_{key}_{tag}': losses['theta'][key].mean().item()
            for tag, losses in (('w', loss_dict_w), ('l', loss_dict_l))
            for key in self._LOSS_KEYS
        }

        loss = mode_loss(loss_dict_w, loss_dict_l, beta_t, info)

        if self.timestep_weights is not None:
            w_t = self.timestep_weights(t).squeeze()
            loss = w_t * loss

        loss = loss.mean(0)

        # NOTE(review): unconditional per-step print; looks like a debugging
        # leftover (the same values are logged by training_step).
        print(f'Loss is {loss}, info is {info}')

        return (loss, info) if return_info else loss

    def compute_loss_single_pair(self, ligand, pocket, t):
        """Evaluate both networks on one noised ligand/pocket batch.

        A prior sample ``z0`` and an intermediate state ``zt`` at time ``t``
        are drawn once for coordinates (``x``), atom features (``h``, from
        ``ligand['one_hot']``) and bond features (``e``, from
        ``ligand['bond_one_hot']``). The trainable and the reference network
        are then run on that same state, so their losses are directly
        comparable.

        Returns:
            ``{'theta': {'x', 'h', 'e'}, 'ref': {'x', 'h', 'e'}}`` holding the
            losses of the trainable and the reference network respectively.
        """
        pocket = Residues(**pocket)

        ligand, pocket = center_data(ligand, pocket)
        pocket_com = scatter_mean(pocket['x'], pocket['mask'], dim=0)

        # Prior samples, then states at time t. The order of these sampling
        # calls is kept fixed so random-number consumption is unchanged.
        z0_x = self.module_x.sample_z0(pocket_com, ligand['mask'])
        z0_h = self.module_h.sample_z0(ligand['mask'])
        z0_e = self.module_e.sample_z0(ligand['bond_mask'])
        zt_x = self.module_x.sample_zt(z0_x, ligand['x'], t, ligand['mask'])
        zt_h = self.module_h.sample_zt(z0_h, ligand['one_hot'], t, ligand['mask'])
        zt_e = self.module_e.sample_zt(z0_e, ligand['bond_one_hot'], t, ligand['bond_mask'])

        sc_transform = self.get_sc_transform_fn(None, zt_x, t, None, ligand['mask'], pocket)

        # Trainable network: gradients flow through this prediction.
        pred_ligand, _ = self.dynamics(
            zt_x, zt_h, ligand['mask'], pocket, t,
            bonds_ligand=(ligand['bonds'], zt_e),
            sc_transform=sc_transform
        )

        # Reference network: identical inputs, no gradient tracking.
        with torch.no_grad():
            ref_pred_ligand, _ = self.ref_dynamics(
                zt_x, zt_h, ligand['mask'], pocket, t,
                bonds_ligand=(ligand['bonds'], zt_e),
                sc_transform=sc_transform
            )

        loss_x = self.module_x.compute_loss(pred_ligand['vel'], z0_x, ligand['x'], t, ligand['mask'], reduce=self.loss_reduce)
        ref_loss_x = self.module_x.compute_loss(ref_pred_ligand['vel'], z0_x, ligand['x'], t, ligand['mask'], reduce=self.loss_reduce)

        # Time after one training step, capped at 1.0.
        t_next = torch.clamp(t + self.train_step_size, max=1.0)

        loss_h = self.module_h.compute_loss(pred_ligand['logits_h'], zt_h, ligand['one_hot'], ligand['mask'], t, t_next, reduce=self.loss_reduce)
        ref_loss_h = self.module_h.compute_loss(ref_pred_ligand['logits_h'], zt_h, ligand['one_hot'], ligand['mask'], t, t_next, reduce=self.loss_reduce)
        loss_e = self.module_e.compute_loss(pred_ligand['logits_e'], zt_e, ligand['bond_one_hot'], ligand['bond_mask'], t, t_next, reduce=self.loss_reduce)
        ref_loss_e = self.module_e.compute_loss(ref_pred_ligand['logits_e'], zt_e, ligand['bond_one_hot'], ligand['bond_mask'], t, t_next, reduce=self.loss_reduce)

        return {
            'theta': {'x': loss_x, 'h': loss_h, 'e': loss_e},
            'ref': {'x': ref_loss_x, 'h': ref_loss_h, 'e': ref_loss_e},
        }
|
|
|