File size: 3,362 Bytes
6e7d4ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from rdkit import Chem

from src import constants


def remove_dummy_atoms(rdmol, sanitize=False):
    """
    Return a copy of ``rdmol`` with every dummy ('*') atom deleted.

    :param rdmol: RDKit molecule that may contain '*' placeholder atoms
    :param sanitize: if True, run ``Chem.SanitizeMol`` on the result
    :return: new RDKit molecule without dummy atoms
    """
    # Delete from the highest index downwards so that removing one atom
    # never shifts the index of a dummy atom still waiting to be removed.
    doomed = sorted(
        (atom.GetIdx() for atom in rdmol.GetAtoms() if atom.GetSymbol() == '*'),
        reverse=True,
    )

    editable = Chem.EditableMol(rdmol)
    for atom_idx in doomed:
        editable.RemoveAtom(atom_idx)

    result = editable.GetMol()
    if sanitize:
        Chem.SanitizeMol(result)
    return result


def build_molecule(coords, atom_types, bonds=None, bond_types=None,
                   atom_props=None, atom_decoder=None, bond_decoder=None):
    """
    Build an RDKit molecule with given atoms, 3D coordinates and bonds.

    Atom labels returned by ``atom_decoder`` may carry a suffix: a trailing
    'H' (only on labels longer than one character) adds one explicit
    hydrogen, and a trailing '+' / '-' sets a formal charge of +1 / -1.
    The label 'NOATOM' becomes a dummy atom ('*'). Dummy atoms, and any
    bonds touching them, are removed from the returned molecule.

    Bonds whose type decodes to 'NOBOND' or None, self-loops, and bonds to
    dummy atoms are skipped. A bond listed more than once (for example in
    both directions) must decode to the same type every time.

    :param coords: N x 3 (torch tensor, numpy array or nested sequence)
    :param atom_types: N integer indices into ``atom_decoder``
    :param bonds: 2 x N_bonds integer array of (src, dst) atom indices
        (torch tensor or numpy array; must support ``.shape`` and ``.T``)
    :param bond_types: N_bonds integer indices into ``bond_decoder``
    :param atom_props: Dict, key: property name, value: sequence of N float
        values, stored on each atom as a double property
    :param atom_decoder: list, defaults to ``constants.atom_decoder``
    :param bond_decoder: list, defaults to ``constants.bond_decoder``
    :return: RDKit molecule (not sanitized)
    :raises AssertionError: if input lengths disagree, or if one bond is
        assigned two different types
    """
    if atom_decoder is None:
        atom_decoder = constants.atom_decoder
    if bond_decoder is None:
        bond_decoder = constants.bond_decoder
    assert len(coords) == len(atom_types)
    # .shape works for both torch tensors and numpy arrays (.size(1) is torch-only)
    assert bonds is None or bonds.shape[1] == len(bond_types)

    mol = Chem.RWMol()
    for i, atom_type in enumerate(atom_types):
        # int() accepts Python ints as well as numpy / torch scalars
        element = atom_decoder[int(atom_type)]
        charge = None
        explicit_hs = None

        # The length check keeps a bare 'H' (hydrogen) from being stripped.
        if len(element) > 1 and element.endswith('H'):
            explicit_hs = 1
            element = element[:-1]
        elif element.endswith('+'):
            charge = 1
            element = element[:-1]
        elif element.endswith('-'):
            charge = -1
            element = element[:-1]

        if element == 'NOATOM':
            # placeholder atom; removed again before returning
            element = '*'

        atom = Chem.Atom(element)
        if explicit_hs is not None:
            atom.SetNumExplicitHs(explicit_hs)
        if charge is not None:
            atom.SetFormalCharge(charge)
        if atom_props is not None:
            for key, vals in atom_props.items():
                atom.SetDoubleProp(key, float(vals[i]))

        mol.AddAtom(atom)

    # add coordinates
    conf = Chem.Conformer(mol.GetNumAtoms())
    for i, xyz in enumerate(coords):
        conf.SetAtomPosition(i, (float(xyz[0]), float(xyz[1]), float(xyz[2])))
    mol.AddConformer(conf)

    # add bonds
    if bonds is not None:
        for bond, type_idx in zip(bonds.T, bond_types):
            bond_type = bond_decoder[int(type_idx)]
            src = int(bond[0])
            dst = int(bond[1])

            # Skip non-bonds and self-loops first. None is treated exactly
            # like 'NOBOND' so it never trips the duplicate-type check below.
            if bond_type is None or bond_type == 'NOBOND' or src == dst:
                continue
            # Bonds to placeholder atoms are dropped along with those atoms.
            if (mol.GetAtomWithIdx(src).GetSymbol() == '*'
                    or mol.GetAtomWithIdx(dst).GetSymbol() == '*'):
                continue

            existing = mol.GetBondBetweenAtoms(src, dst)
            if existing is not None:
                # e.g. the same edge listed in both directions
                assert existing.GetBondType() == bond_type, \
                    "Trying to assign two different types to the same bond."
                continue

            mol.AddBond(src, dst, bond_type)

    mol = remove_dummy_atoms(mol, sanitize=False)
    return mol