forked from wengong-jin/icml18-jtnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmol_tree.py
142 lines (111 loc) · 4.39 KB
/
mol_tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import rdkit
import rdkit.Chem as Chem
import copy
from chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, set_atommap, enum_assemble, decode_stereo
def get_slots(smiles):
    """Describe every atom of a SMILES string as a slot tuple.

    Args:
        smiles: SMILES string, parsed with ``Chem.MolFromSmiles``.

    Returns:
        A list with one ``(symbol, formal_charge, total_num_Hs)`` tuple per
        atom, in RDKit atom order.

    NOTE(review): RDKit's ``MolFromSmiles`` returns ``None`` for unparsable
    input, so an invalid SMILES surfaces here as an AttributeError.
    """
    molecule = Chem.MolFromSmiles(smiles)
    slots = []
    for atom in molecule.GetAtoms():
        slots.append((atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs()))
    return slots
class Vocab(object):
    """Vocabulary of clique SMILES strings addressed by integer id.

    Attributes:
        vocab: the list given to the constructor (id -> SMILES).
        vmap: inverse mapping (SMILES -> id). A repeated SMILES maps to its
            last position in ``vocab``.
        slots: cached ``get_slots`` output for every entry, indexed by id.
    """

    def __init__(self, smiles_list):
        self.vocab = smiles_list
        index_of = {}
        for position, smi in enumerate(self.vocab):
            index_of[smi] = position
        self.vmap = index_of
        # Module-level get_slots (not the method below) computes each entry.
        self.slots = list(map(get_slots, self.vocab))

    def get_index(self, smiles):
        """Return the id of ``smiles``; raises KeyError if it is unknown."""
        return self.vmap[smiles]

    def get_smiles(self, idx):
        """Return the SMILES string stored at id ``idx``."""
        return self.vocab[idx]

    def get_slots(self, idx):
        """Return a deep copy of the cached slot list for id ``idx``.

        The copy keeps callers from mutating the shared cache.
        """
        cached = self.slots[idx]
        return copy.deepcopy(cached)

    def size(self):
        """Return the number of vocabulary entries."""
        return len(self.vocab)
class MolTreeNode(object):
    """A junction-tree node: one clique of atoms from the parent molecule.

    Attributes set by ``__init__``:
        smiles: SMILES string of the clique.
        mol: result of ``get_mol(smiles)``.
        clique: list of atom indices into the parent molecule (a copy).
        neighbors: adjacent ``MolTreeNode`` objects, appended by ``add_neighbor``.

    Attributes that must be assigned externally before ``recover`` is called
    (``MolTree.__init__`` does this): ``nid`` and ``is_leaf``.

    Attributes produced by methods: ``label`` / ``label_mol`` by ``recover``;
    ``cands`` / ``cand_mols`` by ``assemble``.
    """

    def __init__(self, smiles, clique=None):
        """Create a node for ``smiles`` covering the atom indices in ``clique``.

        Args:
            smiles: SMILES string of the clique.
            clique: iterable of atom indices into the parent molecule. It is
                copied. ``None`` (the default) means an empty clique; a
                mutable ``[]`` default is deliberately avoided.
        """
        self.smiles = smiles
        self.mol = get_mol(self.smiles)
        # Copy so later changes to the caller's list cannot leak into this node.
        self.clique = list(clique) if clique is not None else []
        self.neighbors = []

    def add_neighbor(self, nei_node):
        """Append ``nei_node`` to this node's neighbor list (one direction only)."""
        self.neighbors.append(nei_node)

    def recover(self, original_mol):
        """Compute the label SMILES of this node together with its neighbors.

        The atoms of this clique and all neighbor cliques are cut out of
        ``original_mol``. Before the cut, atom-map numbers are temporarily
        written onto ``original_mol``:

        - This clique is marked with ``self.nid``, unless this node is a leaf.
        - Each non-leaf neighbor's clique is marked with that neighbor's
          ``nid``. Atoms shared with this clique are skipped, except that a
          single-atom neighbor may override them.

        Afterwards the map numbers of every atom in the union are reset to 0.

        Args:
            original_mol: the molecule the clique atom indices refer to. It is
                mutated temporarily, as described above.

        Returns:
            The label SMILES, which is also stored in ``self.label``.
            ``self.label_mol`` is set to ``get_mol(self.label)``.
        """
        clique = list(self.clique)
        if not self.is_leaf:
            for cidx in self.clique:
                original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(self.nid)

        for nei_node in self.neighbors:
            clique.extend(nei_node.clique)
            if nei_node.is_leaf:  # Leaf node, no need to mark
                continue
            for cidx in nei_node.clique:
                # Allow a singleton node to override the atom mapping.
                if cidx not in self.clique or len(nei_node.clique) == 1:
                    atom = original_mol.GetAtomWithIdx(cidx)
                    atom.SetAtomMapNum(nei_node.nid)

        clique = list(set(clique))
        label_mol = get_clique_mol(original_mol, clique)
        # Round-trip through a parse so the label is RDKit-canonical SMILES.
        self.label = Chem.MolToSmiles(Chem.MolFromSmiles(get_smiles(label_mol)))
        self.label_mol = get_mol(self.label)

        # Undo the temporary marks on the caller's molecule.
        for cidx in clique:
            original_mol.GetAtomWithIdx(cidx).SetAtomMapNum(0)
        return self.label

    def assemble(self):
        """Enumerate candidate attachments of this node to its neighbors.

        Neighbors are handed to ``enum_assemble`` in a fixed order:

        - single-atom neighbors first, in their original order;
        - then multi-atom neighbors, sorted by atom count from largest to
          smallest (the sort is stable).

        The first two fields of each candidate tuple are stored in
        ``self.cands`` and ``self.cand_mols``; the third field is discarded.
        Presumably these are SMILES strings and Mol objects, respectively;
        confirm against ``chemutils.enum_assemble``. Both lists are empty when
        there are no candidates.
        """
        neighbors = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() > 1]
        neighbors = sorted(neighbors, key=lambda x: x.mol.GetNumAtoms(), reverse=True)
        singletons = [nei for nei in self.neighbors if nei.mol.GetNumAtoms() == 1]
        neighbors = singletons + neighbors

        cands = enum_assemble(self, neighbors)
        if len(cands) > 0:
            self.cands, self.cand_mols, _ = zip(*cands)
            self.cands = list(self.cands)
            self.cand_mols = list(self.cand_mols)
        else:
            self.cands = []
            self.cand_mols = []
class MolTree(object):
    """Junction-tree decomposition of a molecule given as SMILES.

    Attributes:
        smiles: the input SMILES.
        mol: ``get_mol(smiles)``, which is the molecule that is decomposed.
        smiles3D: canonical SMILES with stereochemistry.
        smiles2D: canonical SMILES without stereochemistry.
        stereo_cands: ``decode_stereo(smiles2D)``.
        nodes: ``MolTreeNode`` objects, one per clique from ``tree_decomp``.
            The clique containing atom 0 is at index 0; if several cliques
            contain atom 0, the last one wins. Node ``nid`` values are
            1-based positions in this list.
    """

    def __init__(self, smiles):
        self.smiles = smiles
        self.mol = get_mol(smiles)

        # Stereo generation.
        mol = Chem.MolFromSmiles(smiles)
        self.smiles3D = Chem.MolToSmiles(mol, isomericSmiles=True)
        # isomericSmiles must be explicit: RDKit >= 2019.03 defaults it to
        # True, which would make smiles2D identical to smiles3D.
        self.smiles2D = Chem.MolToSmiles(mol, isomericSmiles=False)
        self.stereo_cands = decode_stereo(self.smiles2D)

        cliques, edges = tree_decomp(self.mol)
        self.nodes = []
        root = 0
        for i, c in enumerate(cliques):
            cmol = get_clique_mol(self.mol, c)
            node = MolTreeNode(get_smiles(cmol), c)
            self.nodes.append(node)
            if min(c) == 0:
                root = i

        # Neighbor links are object references, so the root swap below
        # does not invalidate them.
        for x, y in edges:
            self.nodes[x].add_neighbor(self.nodes[y])
            self.nodes[y].add_neighbor(self.nodes[x])

        if root > 0:
            self.nodes[0], self.nodes[root] = self.nodes[root], self.nodes[0]

        for i, node in enumerate(self.nodes):
            node.nid = i + 1
            if len(node.neighbors) > 1:  # Leaf node mol is not marked
                set_atommap(node.mol, node.nid)
            # A node with no neighbors at all is not treated as a leaf.
            node.is_leaf = (len(node.neighbors) == 1)

    def size(self):
        """Return the number of nodes in the tree."""
        return len(self.nodes)

    def recover(self):
        """Call ``recover(self.mol)`` on every node, in node order."""
        for node in self.nodes:
            node.recover(self.mol)

    def assemble(self):
        """Call ``assemble()`` on every node, in node order."""
        for node in self.nodes:
            node.assemble()
if __name__ == "__main__":
    # Build a clique vocabulary:
    # - read molecules from stdin, taking the first whitespace-separated
    #   token of each line as SMILES;
    # - decompose each molecule with MolTree;
    # - print every distinct clique SMILES once.
    import sys
    from rdkit import RDLogger  # explicit import; `import rdkit` alone may not load it

    # Suppress RDKit log output below CRITICAL.
    lg = RDLogger.logger()
    lg.setLevel(RDLogger.CRITICAL)

    cset = set()
    for line in sys.stdin:
        fields = line.split()
        if not fields:  # skip blank lines instead of raising IndexError
            continue
        smiles = fields[0]
        mol = MolTree(smiles)
        for c in mol.nodes:
            cset.add(c.smiles)

    # Sorted so the output is reproducible (set order of str varies with
    # hash randomization on Python 3).
    for x in sorted(cset):
        print(x)