File size: 3,362 Bytes
6e7d4ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
from rdkit import Chem
from src import constants
def remove_dummy_atoms(rdmol, sanitize=False):
    """
    Strip all dummy ('*') atoms from a molecule.

    :param rdmol: RDKit molecule to filter
    :param sanitize: if True, run Chem.SanitizeMol on the result before
        returning it
    :return: the molecule obtained from an EditableMol built on `rdmol`
        after every atom whose symbol is '*' has been removed
    """
    dummy_indices = [atom.GetIdx()
                     for atom in rdmol.GetAtoms()
                     if atom.GetSymbol() == '*']
    # Delete from the highest index down so that the indices still waiting
    # to be removed are not shifted by earlier deletions.
    dummy_indices.sort(reverse=True)

    editable = Chem.EditableMol(rdmol)
    for atom_idx in dummy_indices:
        editable.RemoveAtom(atom_idx)

    stripped = editable.GetMol()
    if sanitize:
        Chem.SanitizeMol(stripped)
    return stripped
def build_molecule(coords, atom_types, bonds=None, bond_types=None,
                   atom_props=None, atom_decoder=None, bond_decoder=None):
    """
    Assemble an RDKit molecule from encoded atoms, coordinates and bonds.

    Atom type codes are decoded through `atom_decoder`. A decoded symbol
    longer than one character that ends in 'H' gets one explicit hydrogen,
    a trailing '+' / '-' gives a formal charge of +1 / -1, and 'NOATOM'
    becomes a dummy atom ('*'). Dummy atoms never receive bonds and are
    removed before the molecule is returned. No sanitization is performed.

    :param coords: N x 3 coordinates, indexed as coords[i, axis]
        (presumably a tensor; each entry must support .item())
    :param atom_types: N atom type codes (each entry must support .item())
    :param bonds: 2 x N_bonds atom index pairs, or None for no bonds
        (must support .size(1) and .T -- presumably a torch tensor)
    :param bond_types: N_bonds bond type codes used to index `bond_decoder`
    :param atom_props: Dict, key: property name, value: per-atom values (N,)
        stored on each atom as a double property
    :param atom_decoder: list mapping atom type code -> symbol string;
        defaults to constants.atom_decoder
    :param bond_decoder: list mapping bond type code -> bond type; entries
        equal to 'NOBOND' or None produce no bond; defaults to
        constants.bond_decoder
    :return: RDKit molecule with a single conformer and dummy atoms removed
    """
    atom_decoder = constants.atom_decoder if atom_decoder is None else atom_decoder
    bond_decoder = constants.bond_decoder if bond_decoder is None else bond_decoder

    assert len(coords) == len(atom_types)
    assert bonds is None or bonds.size(1) == len(bond_types)

    rw_mol = Chem.RWMol()

    # add atoms
    for atom_idx, type_code in enumerate(atom_types):
        symbol = atom_decoder[type_code.item()]

        formal_charge = None
        num_explicit_hs = None
        if len(symbol) > 1 and symbol.endswith('H'):
            num_explicit_hs = 1
            symbol = symbol[:-1]
        elif symbol.endswith(('+', '-')):
            formal_charge = 1 if symbol.endswith('+') else -1
            symbol = symbol[:-1]

        # placeholder entries become dummy atoms that are stripped at the end
        if symbol == 'NOATOM':
            symbol = '*'

        rd_atom = Chem.Atom(symbol)
        if num_explicit_hs is not None:
            rd_atom.SetNumExplicitHs(num_explicit_hs)
        if formal_charge is not None:
            rd_atom.SetFormalCharge(formal_charge)
        if atom_props is not None:
            for prop_name, prop_values in atom_props.items():
                rd_atom.SetDoubleProp(prop_name, prop_values[atom_idx].item())
        rw_mol.AddAtom(rd_atom)

    # add coordinates
    n_atoms = rw_mol.GetNumAtoms()
    conformer = Chem.Conformer(n_atoms)
    for atom_idx in range(n_atoms):
        position = tuple(coords[atom_idx, axis].item() for axis in range(3))
        conformer.SetAtomPosition(atom_idx, position)
    rw_mol.AddConformer(conformer)

    # add bonds
    if bonds is not None:
        for endpoints, type_code in zip(bonds.T, bond_types):
            decoded_type = bond_decoder[type_code]
            begin_idx = endpoints[0].item()
            end_idx = endpoints[1].item()

            # explicit "no bond" entries and bonds touching a dummy atom
            # are ignored
            if decoded_type == 'NOBOND':
                continue
            if rw_mol.GetAtomWithIdx(begin_idx).GetSymbol() == '*':
                continue
            if rw_mol.GetAtomWithIdx(end_idx).GetSymbol() == '*':
                continue

            # a pair that is already bonded (e.g. listed in both directions)
            # must agree on the bond type and is not added again
            existing_bond = rw_mol.GetBondBetweenAtoms(begin_idx, end_idx)
            if existing_bond is not None:
                assert existing_bond.GetBondType() == decoded_type, \
                    "Trying to assign two different types to the same bond."
                continue

            if decoded_type is None or begin_idx == end_idx:
                continue

            rw_mol.AddBond(begin_idx, end_idx, decoded_type)

    return remove_dummy_atoms(rw_mol, sanitize=False)
|