import io
import random
import warnings
import torch
import webdataset as wds

from pathlib import Path
from torch.utils.data import Dataset

from src.data.data_utils import TensorDict, collate_entity
from src.constants import WEBDATASET_SHARD_SIZE, WEBDATASET_VAL_SIZE


class ProcessedLigandPocketDataset(Dataset):
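    """Map-style dataset over preprocessed ligand/pocket tensors stored in a single
    .pt file. Optional per-entity transforms run in __getitem__; with
    catch_errors=True, items whose transform fails are swapped for a random item.
    """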
    def __init__(self, pt_path, ligand_transform=None, pocket_transform=None,
                 catch_errors=False):

        self.ligand_transform = ligand_transform
        self.pocket_transform = pocket_transform
        self.catch_errors = catch_errors
        self.pt_path = pt_path

        self.data = torch.load(pt_path)

        # add the number of nodes and bonds per item for convenience
        for entity in ['ligands', 'pockets']:
            self.data[entity]['size'] = torch.tensor([len(x) for x in self.data[entity]['x']])
            self.data[entity]['n_bonds'] = torch.tensor([len(x) for x in self.data[entity]['bond_one_hot']])

    def __len__(self):
        return len(self.data['ligands']['name'])

    def __getitem__(self, idx):
        data = {}
        data['ligand'] = {key: val[idx] for key, val in self.data['ligands'].items()}
        data['pocket'] = {key: val[idx] for key, val in self.data['pockets'].items()}
        try:
            if self.ligand_transform is not None:
                data['ligand'] = self.ligand_transform(data['ligand'])
            if self.pocket_transform is not None:
                data['pocket'] = self.pocket_transform(data['pocket'])
        except (RuntimeError, ValueError) as e:
            if self.catch_errors:
                warnings.warn(f"{type(e).__name__}('{e}') in data transform. "
                              f"Returning random item instead")
                # replace bad item with a random one
                rand_idx = random.randint(0, len(self) - 1)
                return self[rand_idx]
            else:
                raise e
        return data

    @staticmethod
    def collate_fn(batch_pairs, ligand_transform=None):

        out = {}
        for entity in ['ligand', 'pocket']:
            batch = [x[entity] for x in batch_pairs]

            if entity == 'ligand' and ligand_transform is not None:
                max_size = max(x['size'].item() for x in batch)
                # TODO: might have to remove elements from batch if processing fails, warn user in that case
                batch = [ligand_transform(x, max_size=max_size) for x in batch]

            out[entity] = TensorDict(**collate_entity(batch))

        return out


class ClusteredDataset(ProcessedLigandPocketDataset):
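    """Cluster-level dataset: each index selects a cluster, from which one member
    is drawn uniformly at random."""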
    def __init__(self, pt_path, ligand_transform=None, pocket_transform=None,
                 catch_errors=False):
        super().__init__(pt_path, ligand_transform, pocket_transform, catch_errors)
        self.clusters = list(self.data['clusters'].values())

    def __len__(self):
        return len(self.clusters)

    def __getitem__(self, cidx):
        cluster_inds = self.clusters[cidx]
        # sample a random member of the selected cluster
        idx = random.choice(cluster_inds)
        return super().__getitem__(idx)


class DPODataset(ProcessedLigandPocketDataset):
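    """Dataset of preference pairs: a winning ligand ('ligands'/'ligands_w') and a
    losing ligand ('ligands_l') per pocket, e.g. for DPO-style fine-tuning."""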
    def __init__(self, pt_path, ligand_transform=None, pocket_transform=None,
                 catch_errors=False):
        self.ligand_transform = ligand_transform
        self.pocket_transform = pocket_transform
        self.catch_errors = catch_errors
        self.pt_path = pt_path

        self.data = torch.load(pt_path)

        # fall back to the winning ('*_w') entries when the canonical keys are absent
        if 'pockets' not in self.data:
            self.data['pockets'] = self.data['pockets_w']
        if 'ligands' not in self.data:
            self.data['ligands'] = self.data['ligands_w']

        # all three entity lists must be aligned one-to-one
        if not (
            len(self.data["ligands"]["name"])
            == len(self.data["ligands_l"]["name"])
            == len(self.data["pockets"]["name"])
        ):
            raise ValueError(
                "Error while importing DPO dataset: the numbers of winning ligands, "
                "losing ligands and pockets must be equal"
            )

        # add the number of nodes and bonds per item for convenience
        for entity in ['ligands', 'ligands_l', 'pockets']:
            self.data[entity]['size'] = torch.tensor([len(x) for x in self.data[entity]['x']])
            self.data[entity]['n_bonds'] = torch.tensor([len(x) for x in self.data[entity]['bond_one_hot']])

    def __len__(self):
        return len(self.data["ligands"]["name"])

    def __getitem__(self, idx):
        data = {}
        data['ligand'] = {key: val[idx] for key, val in self.data['ligands'].items()}
        data['ligand_l'] = {key: val[idx] for key, val in self.data['ligands_l'].items()}
        data['pocket'] = {key: val[idx] for key, val in self.data['pockets'].items()}
        try:
            if self.ligand_transform is not None:
                data['ligand'] = self.ligand_transform(data['ligand'])
                data['ligand_l'] = self.ligand_transform(data['ligand_l'])
            if self.pocket_transform is not None:
                data['pocket'] = self.pocket_transform(data['pocket'])
        except (RuntimeError, ValueError) as e:
            if self.catch_errors:
                warnings.warn(f"{type(e).__name__}('{e}') in data transform. "
                              f"Returning random item instead")
                # replace bad item with a random one
                rand_idx = random.randint(0, len(self) - 1)
                return self[rand_idx]
            else:
                raise e
        return data
    
    @staticmethod
    def collate_fn(batch_pairs, ligand_transform=None):

        out = {}
        for entity in ['ligand', 'ligand_l', 'pocket']:
            batch = [x[entity] for x in batch_pairs]

            if entity in ['ligand', 'ligand_l'] and ligand_transform is not None:
                max_size = max(x['size'].item() for x in batch)
                batch = [ligand_transform(x, max_size=max_size) for x in batch]

            out[entity] = TensorDict(**collate_entity(batch))

        return out

##########################################
############### WebDatasets ##############
##########################################

class ProteinLigandWebDataset(wds.WebDataset):
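    """WebDataset variant of ProcessedLigandPocketDataset that reuses its collate_fn."""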
    @staticmethod
    def collate_fn(batch_pairs, ligand_transform=None):
        return ProcessedLigandPocketDataset.collate_fn(batch_pairs, ligand_transform)


def wds_decoder(key, value):
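    """WebDataset decoder: every entry payload is assumed to be a torch-serialized
    bytestring; the key is unused."""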
    return torch.load(io.BytesIO(value))


def preprocess_wds_item(data):
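    """Unpack the torch payload of a WebDataset sample into the {'ligand', 'pocket'}
    layout expected by the collate function; empty placeholder tensors for 'size'
    and 'n_bonds' are replaced by the scalar 0.
    """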
    out = {}
    for entity in ['ligand', 'pocket']:
        out[entity] = data['pt'][entity]
        for attr in ['size', 'n_bonds']:
            if torch.is_tensor(out[entity][attr]):
                assert len(out[entity][attr]) == 0
                out[entity][attr] = 0

    return out


def get_wds(data_path, stage, ligand_transform=None, pocket_transform=None):
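    """Build a decoded, transformed WebDataset for the given stage.

    Expects shards named 'shard-00000.tar', ... under data_path/stage. The nominal
    length is shard count * WEBDATASET_SHARD_SIZE for 'train', else WEBDATASET_VAL_SIZE.
    """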
    current_data_dir = Path(data_path, stage)
    shards = sorted(current_data_dir.glob('shard-?????.tar'),
                    key=lambda s: int(s.name.split('-')[-1].split('.')[0]))
    min_shard = shards[0].name.split('-')[-1].split('.')[0]
    max_shard = shards[-1].name.split('-')[-1].split('.')[0]
    total_size = ((int(max_shard) - int(min_shard) + 1) * WEBDATASET_SHARD_SIZE
                  if stage == 'train' else WEBDATASET_VAL_SIZE)

    url = f'{data_path}/{stage}/shard-{{{min_shard}..{max_shard}}}.tar'
    # identity wrappers by default; replaced below when transforms are provided
    def ligand_transform_wrapper(_data):
        return _data

    def pocket_transform_wrapper(_data):
        return _data

    if ligand_transform is not None:
        def ligand_transform_wrapper(_data):
            _data['pt']['ligand'] = ligand_transform(_data['pt']['ligand'])
            return _data
        
    if pocket_transform is not None:
        def pocket_transform_wrapper(_data):
            _data['pt']['pocket'] = pocket_transform(_data['pt']['pocket'])
            return _data

    return (
        ProteinLigandWebDataset(url, nodesplitter=wds.split_by_node)
        .decode(wds_decoder)
        .map(ligand_transform_wrapper)
        .map(pocket_transform_wrapper)
        .map(preprocess_wds_item)
        .with_length(total_size)
    )
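

# Minimal usage sketch (illustrative: the .pt path and DataLoader settings are
# assumptions, not part of this module):
#
#     from torch.utils.data import DataLoader
#     dataset = ProcessedLigandPocketDataset('data/train.pt', catch_errors=True)
#     loader = DataLoader(dataset, batch_size=32, shuffle=True,
#                         collate_fn=ProcessedLigandPocketDataset.collate_fn)
#     batch = next(iter(loader))  # {'ligand': TensorDict, 'pocket': TensorDict}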