File size: 9,459 Bytes
6e7d4ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
import warnings
from typing import Union, Iterable
import random
# import argparse
from argparse import Namespace

import numpy as np
import torch
from rdkit import Chem, RDLogger
from rdkit.Chem import KekulizeException, AtomKekulizeException
import networkx as nx
from networkx.algorithms import isomorphism
from torch_scatter import scatter_add, scatter_mean


class Queue():
    """Bounded buffer of recent values with simple summary statistics.

    New items are placed at the front of ``items``. Whenever an insertion
    makes the buffer longer than ``max_len``, the entry at the back (the
    one that has been in the buffer the longest) is discarded.
    """

    def __init__(self, max_len=50):
        self.max_len = max_len
        self.items = []

    def __len__(self):
        return len(self.items)

    def add(self, item):
        """Insert ``item`` at the front; drop the last entry on overflow."""
        self.items.insert(0, item)
        if len(self) <= self.max_len:
            return
        del self.items[-1]

    def mean(self):
        """Return ``np.mean`` over the buffered items."""
        return np.mean(self.items)

    def std(self):
        """Return ``np.std`` over the buffered items."""
        return np.std(self.items)


def reverse_tensor(x):
    """
    Return a copy of ``x`` with its entries along dimension 0 in reverse order.

    Uses ``torch.flip`` instead of gathering through an explicitly built
    index tensor, so no index tensor has to be created (or moved to the
    device of ``x``).

    :param x: tensor with at least one dimension
    :return: new tensor of the same shape, reversed along dim 0
    """
    return torch.flip(x, dims=(0,))


#####


def sum_except_batch(x, indices):
    """
    Sum ``x`` over every dimension except the first, then scatter-add the
    resulting per-row values into groups given by ``indices``.

    :param x: tensor; inputs with fewer than two dimensions get a trailing
        singleton dimension appended first
    :param indices: group index for each row of ``x`` (passed to
        ``scatter_add`` along dim 0)
    :return: result of ``scatter_add`` over the per-row sums
    """
    if x.dim() < 2:
        x = x.unsqueeze(-1)

    trailing_dims = list(range(1, x.dim()))
    row_totals = x.sum(trailing_dims)
    return scatter_add(row_totals, indices, dim=0)


def remove_mean_batch(x, batch_mask, dim_size=None):
    """
    Subtract from each row of ``x`` the mean of the group it belongs to.

    :param x: values to be centered
    :param batch_mask: group index of each row of ``x``
    :param dim_size: forwarded to ``scatter_mean`` as ``dim_size``
    :return: tuple ``(centered_x, group_means)``
    """
    group_means = scatter_mean(x, batch_mask, dim=0, dim_size=dim_size)
    centered = x - group_means[batch_mask]
    return centered, group_means


def assert_mean_zero(x, batch_mask, thresh=1e-2, eps=1e-10):
    """
    Assert that the per-group sums of ``x`` are negligible.

    The largest absolute per-group sum is compared against the largest
    absolute entry of ``x``; the ratio must stay below ``thresh``.

    :param x: values to check
    :param batch_mask: group index of each row of ``x``
    :param thresh: upper bound (exclusive) on the relative error
    :param eps: added to the denominator to avoid division by zero
    :raises AssertionError: if the relative error is not below ``thresh``
    """
    scale = x.abs().max().item()
    group_sums = scatter_add(x, batch_mask, dim=0)
    worst_sum = group_sums.abs().max().item()
    rel_error = worst_sum / (scale + eps)
    assert rel_error < thresh, f'Mean is not zero, relative_error {rel_error}'


def bvm(v, m):
    """
    Batched vector-matrix product of the form out = v @ m
    :param v: (b, n_in)
    :param m: (b, n_in, n_out)
    :return: (b, n_out)
    """
    row_vectors = v.unsqueeze(1)  # (b, 1, n_in)
    product = torch.bmm(row_vectors, m)  # (b, 1, n_out)
    # Squeeze only dim 1 so that a batch of size one keeps its batch axis.
    return product.squeeze(1)


def get_grad_norm(
        parameters: Union[torch.Tensor, Iterable[torch.Tensor]],
        norm_type: float = 2.0) -> torch.Tensor:
    """
    Compute the overall ``norm_type``-norm of the gradients of ``parameters``.

    Adapted from: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html#clip_grad_norm_

    :param parameters: a single tensor or an iterable of tensors; entries
        whose ``.grad`` is ``None`` are ignored
    :param norm_type: order of the norm
    :return: scalar tensor with the norm of the per-tensor gradient norms,
        or ``torch.tensor(0.)`` if no tensor has a gradient
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]

    norm_type = float(norm_type)

    if not grads:
        return torch.tensor(0.)

    # Collect all per-tensor norms on the device of the first gradient.
    target_device = grads[0].device
    per_tensor_norms = [
        torch.norm(g.detach(), norm_type).to(target_device) for g in grads
    ]

    return torch.norm(torch.stack(per_tensor_norms), norm_type)


def write_xyz_file(coords, atom_types, filename):
    """
    Write atoms to ``filename`` in a simple XYZ text layout.

    The file starts with the atom count and an empty line, followed by one
    line per atom: the atom type and three coordinates with three decimals.

    :param coords: 2D indexable (``coords[i, j]``) with one row per atom
    :param atom_types: sequence with one label per atom
    :param filename: path of the file to (over)write
    :raises AssertionError: if ``coords`` and ``atom_types`` differ in length
    """
    n_atoms = len(coords)
    assert n_atoms == len(atom_types)

    lines = [f"{n_atoms}\n\n"]
    lines.extend(
        f"{atom_types[i]} {coords[i, 0]:.3f} {coords[i, 1]:.3f} {coords[i, 2]:.3f}\n"
        for i in range(n_atoms)
    )

    with open(filename, 'w') as f:
        f.write("".join(lines))


def write_sdf_file(sdf_path, molecules, catch_errors=True, connected=False):
    """
    Write a sequence of RDKit molecules to a single SDF file.

    :param sdf_path: output path (converted with ``str``)
    :param molecules: iterable of molecules; ``None`` entries are treated
        as errors
    :param catch_errors: if False, the first ``RuntimeError``/``ValueError``
        is re-raised. If True, a molecule that fails with a kekulization
        error is written again with kekulization switched off; any other
        failure (including a failure of that second attempt) is replaced by
        an empty molecule so the output keeps one record per input.
    :param connected: if True, only the largest connected component of each
        molecule is written
    """
    with Chem.SDWriter(str(sdf_path)) as w:
        for mol in molecules:
            try:
                if mol is None:
                    raise ValueError("Mol is None.")
                w.write(get_largest_connected_component(mol) if connected else mol)

            except (RuntimeError, ValueError) as e:
                if not catch_errors:
                    raise

                saved_without_kekulization = False
                if isinstance(e, (KekulizeException, AtomKekulizeException)):
                    w.SetKekulize(False)
                    try:
                        w.write(get_largest_connected_component(mol) if connected else mol)
                        saved_without_kekulization = True
                    except (RuntimeError, ValueError):
                        # The retry failed as well; fall back to the dummy
                        # record below instead of aborting the whole file.
                        pass
                    finally:
                        # Always restore the writer for the next molecule.
                        w.SetKekulize(True)

                if saved_without_kekulization:
                    warnings.warn("Mol saved without kekulization.")
                else:
                    # write empty mol to preserve the original order
                    w.write(Chem.Mol())
                    warnings.warn("Erroneous mol replaced with empty dummy.")


def get_largest_connected_component(mol):
    """
    Return the fragment of ``mol`` with the most atoms.

    This is best-effort: if fragment extraction or selection fails for any
    reason, the input ``mol`` is returned unchanged.

    :param mol: RDKit molecule
    :return: the largest fragment as a molecule, or ``mol`` itself on failure
    """
    try:
        frags = Chem.GetMolFrags(mol, asMols=True)
        newmol = max(frags, key=lambda m: m.GetNumAtoms())
    except Exception:
        # Deliberately broad, but unlike a bare ``except`` this does not
        # swallow KeyboardInterrupt / SystemExit.
        newmol = mol
    return newmol


def write_chain(filename, rdmol_chain):
    """
    Write the XYZ blocks of all molecules in ``rdmol_chain`` back to back
    into ``filename``.

    :param filename: path of the file to (over)write
    :param rdmol_chain: iterable of RDKit molecules
    """
    with open(filename, 'w') as out_file:
        # The full text is assembled before anything is written.
        xyz_text = "".join(Chem.MolToXYZBlock(mol) for mol in rdmol_chain)
        out_file.write(xyz_text)


def combine_sdfs(sdf_list, out_file):
    """
    Concatenate several SDF files into ``out_file``.

    File contents are joined with a ``$$$$`` delimiter line between
    consecutive files. The contents are used verbatim; no check is made
    for whether an input already ends with such a delimiter.

    :param sdf_list: iterable of input file paths
    :param out_file: path of the combined file to (over)write
    """
    def _read_text(path):
        with open(path, 'r') as src:
            return src.read()

    # All inputs are read before the output file is opened.
    merged = '$$$$\n'.join([_read_text(path) for path in sdf_list])

    with open(out_file, 'w') as dst:
        dst.write(merged)


def batch_to_list(data, batch_mask, keep_order=True):
    """
    Split ``data`` along dim 0 into one chunk per distinct value of
    ``batch_mask``, in ascending order of those values.

    :param data: tensor whose first dimension matches ``batch_mask``
    :param batch_mask: sample index of each row of ``data``
    :param keep_order: if True, select each sample with a boolean mask so
        rows keep their relative order (returns a list). If False, sort by
        ``batch_mask`` and split into contiguous chunks (returns the tuple
        produced by ``torch.split``).
    :return: list or tuple of tensors, one per sample
    """
    if not keep_order:
        # Sorting makes every sample contiguous so one split is enough.
        order = torch.argsort(batch_mask)
        sorted_mask = batch_mask[order]
        sorted_data = data[order]

        _, counts = torch.unique(sorted_mask, return_counts=True)
        return torch.split(sorted_data, counts.tolist())

    # preserve order of elements within each sample
    sample_ids = torch.unique(batch_mask, sorted=True)
    return [data[batch_mask == sample_id] for sample_id in sample_ids]


def batch_to_list_for_indices(indices, batch_mask, offsets=None):
    """
    Split index pairs by sample and rebase each sample's indices.

    :param indices: tensor of shape (2, n)
    :param batch_mask: sample index for each of the n columns
    :param offsets: per-sample value to subtract from the indices. Typically
        this would be accumulate(sizes[:-1], initial=0). If None, the
        smallest index within each sample is used instead and a warning is
        emitted.
    :return: list with one (2, n_i) tensor per sample
    """
    # (2, n) -> (n, 2), then one chunk per sample
    per_sample = batch_to_list(indices.T, batch_mask)

    if offsets is not None:
        assert len(offsets) == len(per_sample) or indices.numel() == 0
        # rebase at the given offset & (n, 2) -> (2, n)
        return [chunk.T - shift for chunk, shift in zip(per_sample, offsets)]

    warnings.warn("Trying to infer index offset from smallest element in "
                  "batch. This might be wrong.")
    # rebase at the smallest index & (n, 2) -> (2, n)
    return [chunk.T - chunk.min() for chunk in per_sample]


def num_nodes_to_batch_mask(n_samples, num_nodes, device):
    """
    Build a batch mask in which sample index ``i`` is repeated
    ``num_nodes[i]`` times (or ``num_nodes`` times if it is a single int).

    :param n_samples: number of samples in the batch
    :param num_nodes: an int applied to every sample, or a tensor / list /
        tuple / array with one count per sample
    :param device: device on which the mask is created
    :return: 1D tensor of sample indices
    :raises AssertionError: if ``num_nodes`` is not an int and its length
        differs from ``n_samples``
    """
    assert isinstance(num_nodes, int) or len(num_nodes) == n_samples

    if not isinstance(num_nodes, int):
        # repeat_interleave only takes an int or a tensor as repeats, so
        # plain sequences are converted; tensors are moved to ``device``.
        num_nodes = torch.as_tensor(num_nodes, device=device)

    sample_inds = torch.arange(n_samples, device=device)

    return torch.repeat_interleave(sample_inds, num_nodes)


def rdmol_to_nxgraph(rdmol):
    """
    Convert an RDKit molecule into an undirected networkx graph.

    Each atom becomes a node keyed by its atom index and carrying its
    atomic number as the ``atom_type`` attribute. Each bond becomes an edge
    between its begin and end atom indices.

    :param rdmol: RDKit molecule
    :return: ``nx.Graph`` describing the molecule's connectivity
    """
    mol_graph = nx.Graph()

    # Atoms -> nodes
    mol_graph.add_nodes_from(
        (atom.GetIdx(), {'atom_type': atom.GetAtomicNum()})
        for atom in rdmol.GetAtoms()
    )

    # Bonds -> edges
    mol_graph.add_edges_from(
        (bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())
        for bond in rdmol.GetBonds()
    )

    return mol_graph


def calc_rmsd(mol_a, mol_b):
    """ Calculate RMSD of two molecules with unknown atom correspondence.

    The molecules are matched as graphs in which nodes must agree in atomic
    number. For every isomorphism found, the RMSD between the coordinates
    of ``mol_a`` and the correspondingly permuted coordinates of ``mol_b``
    is computed, and the smallest value is returned. No alignment is
    performed here.

    Isomorphisms are consumed one at a time instead of being collected in
    a list, and the per-molecule atom types and coordinates are gathered
    only once (on the first match).

    :param mol_a: RDKit molecule with a conformer
    :param mol_b: RDKit molecule with a conformer
    :return: minimum RMSD over all isomorphisms, or None if the molecular
        graphs are not isomorphic
    """
    graph_a = rdmol_to_nxgraph(mol_a)
    graph_b = rdmol_to_nxgraph(mol_b)

    gm = isomorphism.GraphMatcher(
        graph_a, graph_b,
        node_match=lambda na, nb: na['atom_type'] == nb['atom_type'])

    atom_types_a = atom_types_b = None
    coords_a = coords_b = None
    n_isomorphisms = 0
    min_rmsd = None

    for mapping in gm.isomorphisms_iter():
        if n_isomorphisms == 0:
            # Gathered lazily so that non-isomorphic inputs still return
            # None without touching the conformers.
            atom_types_a = [atom.GetAtomicNum() for atom in mol_a.GetAtoms()]
            atom_types_b = [atom.GetAtomicNum() for atom in mol_b.GetAtoms()]

            conf_a = mol_a.GetConformer()
            coords_a = np.array([conf_a.GetAtomPosition(i)
                                 for i in range(mol_a.GetNumAtoms())])
            conf_b = mol_b.GetConformer()
            coords_b = np.array([conf_b.GetAtomPosition(i)
                                 for i in range(mol_b.GetNumAtoms())])
        n_isomorphisms += 1

        # mapping[i] is the atom of mol_b matched to atom i of mol_a
        order_b = [mapping[i] for i in range(mol_b.GetNumAtoms())]
        assert atom_types_a == [atom_types_b[j] for j in order_b]

        diff = coords_a - coords_b[order_b]
        rmsd = np.sqrt(np.mean(np.sum(diff * diff, axis=1)))
        if min_rmsd is None or rmsd < min_rmsd:
            min_rmsd = rmsd

    if n_isomorphisms < 1:
        return None

    if n_isomorphisms > 1:
        print("More than one isomorphism found. Returning minimum RMSD.")

    return min_rmsd


def set_deterministic(seed):
    """
    Seed the Python, NumPy and PyTorch random number generators and switch
    cuDNN to deterministic mode.

    :param seed: seed value passed to every generator
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def disable_rdkit_logging():
    """
    Switch off the RDKit ``rdApp`` log channels for info, error and
    warning messages.
    """
    # RDLogger.DisableLog('rdApp.*')
    for channel in ('rdApp.info', 'rdApp.error', 'rdApp.warning'):
        RDLogger.DisableLog(channel)


# class Namespace(argparse.Namespace):
#     """Simple definition of a Namespace class that supports Namespace- and
#     dictionary-like access."""
#     def __getitem__(self, key):
#         # return vars(self)[key]
#         return self.__dict__[key]
#
#     def __setitem__(self, key, value):
#         self.__dict__[key] = value
#
#     def __getattr__(self, item):
#         """Supports other dictionary functionalities, e.g. get(), items(), etc."""
#         # return getattr(vars(self), item)
#         return getattr(self.__dict__, item)


def dict_to_namespace(input_dict):
    """ Recursively convert a nested dictionary into a Namespace object.

    Dictionaries become Namespace objects with one attribute per key.
    Existing Namespace objects are rebuilt so that dictionaries nested
    inside them are converted too. Any other value (including lists) is
    returned unchanged.
    """
    if isinstance(input_dict, dict):
        converted = {key: dict_to_namespace(value)
                     for key, value in input_dict.items()}
        namespace = Namespace()
        # Fill the attribute dict directly instead of passing keyword
        # arguments, which would only work for string keys.
        vars(namespace).update(converted)
        return namespace

    if isinstance(input_dict, Namespace):
        # recurse as Namespace might contain dictionaries
        return dict_to_namespace(vars(input_dict))

    return input_dict


def namespace_to_dict(x):
    """ Recursively convert a nested Namespace object into a dictionary.

    Namespace objects and dictionaries are both turned into new plain
    dictionaries, with their values converted the same way. Any other
    value (including lists) is returned unchanged.
    """
    if isinstance(x, Namespace):
        x = vars(x)
    elif not isinstance(x, dict):
        return x

    # recurse
    return {key: namespace_to_dict(value) for key, value in x.items()}