import io
import random
import warnings
from pathlib import Path

import torch
import webdataset as wds
from torch.utils.data import Dataset

from src.data.data_utils import TensorDict, collate_entity
from src.constants import WEBDATASET_SHARD_SIZE, WEBDATASET_VAL_SIZE


class ProcessedLigandPocketDataset(Dataset):
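    """Map-style dataset backed by a single preprocessed .pt file.

    The file must contain a dict with 'ligands' and 'pockets' sub-dicts of
    per-sample entries. Optional transforms are applied on access; with
    catch_errors=True, a failing transform emits a warning and a random
    replacement item is returned instead.
    """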
def __init__(self, pt_path, ligand_transform=None, pocket_transform=None,
catch_errors=False):
self.ligand_transform = ligand_transform
self.pocket_transform = pocket_transform
self.catch_errors = catch_errors
self.pt_path = pt_path
self.data = torch.load(pt_path)
# add number of nodes for convenience
for entity in ['ligands', 'pockets']:
self.data[entity]['size'] = torch.tensor([len(x) for x in self.data[entity]['x']])
self.data[entity]['n_bonds'] = torch.tensor([len(x) for x in self.data[entity]['bond_one_hot']])

    def __len__(self):
return len(self.data['ligands']['name'])

    def __getitem__(self, idx):
data = {}
data['ligand'] = {key: val[idx] for key, val in self.data['ligands'].items()}
data['pocket'] = {key: val[idx] for key, val in self.data['pockets'].items()}
try:
if self.ligand_transform is not None:
data['ligand'] = self.ligand_transform(data['ligand'])
if self.pocket_transform is not None:
data['pocket'] = self.pocket_transform(data['pocket'])
except (RuntimeError, ValueError) as e:
if self.catch_errors:
warnings.warn(f"{type(e).__name__}('{e}') in data transform. "
f"Returning random item instead")
# replace bad item with a random one
rand_idx = random.randint(0, len(self) - 1)
return self[rand_idx]
else:
                raise
return data

    @staticmethod
def collate_fn(batch_pairs, ligand_transform=None):
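        """Collate a list of {'ligand', 'pocket'} samples into TensorDicts.

        An optional ligand_transform is applied per sample with the
        batch-wide maximum ligand size before collation.
        """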
out = {}
for entity in ['ligand', 'pocket']:
batch = [x[entity] for x in batch_pairs]
if entity == 'ligand' and ligand_transform is not None:
max_size = max(x['size'].item() for x in batch)
# TODO: might have to remove elements from batch if processing fails, warn user in that case
batch = [ligand_transform(x, max_size=max_size) for x in batch]
out[entity] = TensorDict(**collate_entity(batch))
return out


class ClusteredDataset(ProcessedLigandPocketDataset):
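    """Wrapper that indexes by cluster: each __getitem__ picks one member of
    the addressed cluster uniformly at random.
    """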
def __init__(self, pt_path, ligand_transform=None, pocket_transform=None,
catch_errors=False):
super().__init__(pt_path, ligand_transform, pocket_transform, catch_errors)
self.clusters = list(self.data['clusters'].values())

    def __len__(self):
return len(self.clusters)

    def __getitem__(self, cidx):
cluster_inds = self.clusters[cidx]
        idx = random.choice(cluster_inds)
return super().__getitem__(idx)


class DPODataset(ProcessedLigandPocketDataset):
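    """Preference-pair dataset (DPO): every item couples a winning ligand
    ('ligand'), a losing ligand ('ligand_l') and the shared pocket. Winning
    entities may be stored either under the plain keys or under '*_w'.
    """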
def __init__(self, pt_path, ligand_transform=None, pocket_transform=None,
catch_errors=False):
self.ligand_transform = ligand_transform
self.pocket_transform = pocket_transform
self.catch_errors = catch_errors
self.pt_path = pt_path
self.data = torch.load(pt_path)
        # fall back to the '*_w' (winning) entries when the plain keys are absent
        if 'pockets' not in self.data:
            self.data['pockets'] = self.data['pockets_w']
        if 'ligands' not in self.data:
            self.data['ligands'] = self.data['ligands_w']
        # a chained `!=` comparison only fires when both adjacent pairs
        # differ, so check equality explicitly
        if not (
            len(self.data['ligands']['name'])
            == len(self.data['ligands_l']['name'])
            == len(self.data['pockets']['name'])
        ):
            raise ValueError(
                "Error while importing DPO dataset: the numbers of winning "
                "ligands, losing ligands and pockets must match"
            )
# add number of nodes for convenience
for entity in ['ligands', 'ligands_l', 'pockets']:
self.data[entity]['size'] = torch.tensor([len(x) for x in self.data[entity]['x']])
self.data[entity]['n_bonds'] = torch.tensor([len(x) for x in self.data[entity]['bond_one_hot']])

    def __len__(self):
return len(self.data["ligands"]["name"])

    def __getitem__(self, idx):
data = {}
data['ligand'] = {key: val[idx] for key, val in self.data['ligands'].items()}
data['ligand_l'] = {key: val[idx] for key, val in self.data['ligands_l'].items()}
data['pocket'] = {key: val[idx] for key, val in self.data['pockets'].items()}
try:
if self.ligand_transform is not None:
data['ligand'] = self.ligand_transform(data['ligand'])
data['ligand_l'] = self.ligand_transform(data['ligand_l'])
if self.pocket_transform is not None:
data['pocket'] = self.pocket_transform(data['pocket'])
except (RuntimeError, ValueError) as e:
if self.catch_errors:
warnings.warn(f"{type(e).__name__}('{e}') in data transform. "
f"Returning random item instead")
# replace bad item with a random one
rand_idx = random.randint(0, len(self) - 1)
return self[rand_idx]
else:
                raise
return data

    @staticmethod
def collate_fn(batch_pairs, ligand_transform=None):
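        """Like the parent collate_fn, but additionally batches the losing
        ligand ('ligand_l'), applying the transform with that entity's own
        batch-wide maximum size.
        """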
out = {}
for entity in ['ligand', 'ligand_l', 'pocket']:
batch = [x[entity] for x in batch_pairs]
if entity in ['ligand', 'ligand_l'] and ligand_transform is not None:
max_size = max(x['size'].item() for x in batch)
batch = [ligand_transform(x, max_size=max_size) for x in batch]
out[entity] = TensorDict(**collate_entity(batch))
return out


##########################################
############### WebDatasets ##############
##########################################

class ProteinLigandWebDataset(wds.WebDataset):
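    """WebDataset pipeline that reuses the map-style dataset's collate_fn."""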
@staticmethod
def collate_fn(batch_pairs, ligand_transform=None):
return ProcessedLigandPocketDataset.collate_fn(batch_pairs, ligand_transform)


def wds_decoder(key, value):
    # webdataset feeds every (extension, bytes) pair through the decoders;
    # returning None passes keys we don't handle on to the next decoder
    if not (key == 'pt' or key.endswith('.pt')):
        return None
    return torch.load(io.BytesIO(value))


def preprocess_wds_item(data):
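    """Unpack a decoded shard sample ('pt' payload) into the
    {'ligand': ..., 'pocket': ...} layout expected by collate_fn.
    """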
out = {}
for entity in ['ligand', 'pocket']:
out[entity] = data['pt'][entity]
        for attr in ['size', 'n_bonds']:
            # these fields may be serialized as empty tensors in the shards;
            # normalize them to plain scalar zeros
            if torch.is_tensor(out[entity][attr]):
                assert len(out[entity][attr]) == 0
                out[entity][attr] = 0
return out


def get_wds(data_path, stage, ligand_transform=None, pocket_transform=None):
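    """Build a length-annotated WebDataset for one stage (e.g. 'train').

    Expects zero-padded, contiguously numbered shards at
    <data_path>/<stage>/shard-XXXXX.tar.
    """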
current_data_dir = Path(data_path, stage)
    shards = sorted(current_data_dir.glob('shard-?????.tar'),
                    key=lambda s: int(s.stem.split('-')[-1]))
    min_shard = shards[0].stem.split('-')[-1]
    max_shard = shards[-1].stem.split('-')[-1]
    # the training length is inferred from the number of fixed-size shards;
    # the validation set size is a known constant
    if stage == 'train':
        total_size = (int(max_shard) - int(min_shard) + 1) * WEBDATASET_SHARD_SIZE
    else:
        total_size = WEBDATASET_VAL_SIZE
    # brace notation lets webdataset expand the whole shard range
    url = f'{data_path}/{stage}/shard-{{{min_shard}..{max_shard}}}.tar'
    # identity wrappers are used when no transform is given
    def ligand_transform_wrapper(_data):
        return _data

    def pocket_transform_wrapper(_data):
        return _data

    if ligand_transform is not None:
        def ligand_transform_wrapper(_data):
            _data['pt']['ligand'] = ligand_transform(_data['pt']['ligand'])
            return _data

    if pocket_transform is not None:
        def pocket_transform_wrapper(_data):
            _data['pt']['pocket'] = pocket_transform(_data['pt']['pocket'])
            return _data

return (
ProteinLigandWebDataset(url, nodesplitter=wds.split_by_node)
.decode(wds_decoder)
.map(ligand_transform_wrapper)
.map(pocket_transform_wrapper)
.map(preprocess_wds_item)
.with_length(total_size)
)
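

##########################################
############# Example usage ##############
##########################################
# Illustrative sketch only: the path below is a placeholder and no
# transforms are attached; swap in the project's real data and transforms.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = ProcessedLigandPocketDataset('data/processed/train.pt',
                                           catch_errors=True)
    loader = DataLoader(dataset, batch_size=32, shuffle=True,
                        collate_fn=ProcessedLigandPocketDataset.collate_fn)
    batch = next(iter(loader))  # {'ligand': TensorDict, 'pocket': TensorDict}

    # sharded variant: get_wds returns an iterable pipeline, typically
    # wrapped in wds.WebLoader / DataLoader with the same collate_fn
    # web_ds = get_wds('data/webdataset', 'train')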