Skip to content

Commit

Permalink
small changes and uploading model
Browse files Browse the repository at this point in the history
  • Loading branch information
Wengong Jin committed Feb 8, 2019
1 parent bc8d334 commit 01efeef
Show file tree
Hide file tree
Showing 9 changed files with 30,040 additions and 21 deletions.
2 changes: 1 addition & 1 deletion fast_jtnn/jtnn_dec.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def aggregate(self, hiddens, contexts, x_tree_vecs, mode):
elif mode == 'stop':
V, V_o = self.U, self.U_o
else:
raise ValueError('attention mode is wrong')
raise ValueError('aggregate mode is wrong')

tree_contexts = x_tree_vecs.index_select(0, contexts)
input_vec = torch.cat([hiddens, tree_contexts], dim=-1)
Expand Down
2 changes: 1 addition & 1 deletion fast_molvae/vae_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
meters = np.zeros(4)

for epoch in xrange(args.epoch):
loader = MolTreeFolder(args.train, vocab, args.batch_size, num_workers=5)
loader = MolTreeFolder(args.train, vocab, args.batch_size, num_workers=4)
for batch in loader:
total_step += 1
try:
Expand Down
6 changes: 3 additions & 3 deletions jtnn/jtnn_dec.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def forward(self, mol_batch, mol_vec):
stop_acc = torch.eq(stops, stop_targets).float()
stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

return pred_loss, stop_loss, pred_acc.data[0], stop_acc.data[0]
return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()

def decode(self, mol_vec, prob_decode):
stack,trace = [],[]
Expand All @@ -194,7 +194,7 @@ def decode(self, mol_vec, prob_decode):
root_hidden = nn.ReLU()(self.W(root_hidden))
root_score = self.W_o(root_hidden)
_,root_wid = torch.max(root_score, dim=1)
root_wid = root_wid.data[0]
root_wid = root_wid.item()

root = MolTreeNode(self.vocab.get_smiles(root_wid))
root.wid = root_wid
Expand Down Expand Up @@ -223,7 +223,7 @@ def decode(self, mol_vec, prob_decode):
if prob_decode:
backtrack = (torch.bernoulli(1.0 - stop_score.data)[0] == 1)
else:
backtrack = (stop_score.data[0] < 0.5)
backtrack = (stop_score.item() < 0.5)

if not backtrack: #Forward: Predict next clique
new_h = GRU(cur_x, cur_h_nei, self.W_z, self.W_r, self.U_r, self.W_h)
Expand Down
19 changes: 13 additions & 6 deletions jtnn/jtnn_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def set_batch_nodeID(mol_batch, vocab):

class JTNNVAE(nn.Module):

def __init__(self, vocab, hidden_size, latent_size, depth):
def __init__(self, vocab, hidden_size, latent_size, depth, stereo=True):
super(JTNNVAE, self).__init__()
self.vocab = vocab
self.hidden_size = hidden_size
Expand All @@ -43,7 +43,9 @@ def __init__(self, vocab, hidden_size, latent_size, depth):
self.G_var = nn.Linear(hidden_size, latent_size / 2)

self.assm_loss = nn.CrossEntropyLoss(size_average=False)
self.stereo_loss = nn.CrossEntropyLoss(size_average=False)
self.stereo = stereo
if stereo:
self.stereo_loss = nn.CrossEntropyLoss(size_average=False)

def encode(self, mol_batch):
set_batch_nodeID(mol_batch, self.vocab)
Expand Down Expand Up @@ -85,12 +87,15 @@ def forward(self, mol_batch, beta=0):

word_loss, topo_loss, word_acc, topo_acc = self.decoder(mol_batch, tree_vec)
assm_loss, assm_acc = self.assm(mol_batch, mol_vec, tree_mess)
stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)
if self.stereo:
stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)
else:
stereo_loss, stereo_acc = 0, 0

all_vec = torch.cat([tree_vec, mol_vec], dim=1)
loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss

return loss, kl_loss.data[0], word_acc, topo_acc, assm_acc, stereo_acc
return loss, kl_loss.item(), word_acc, topo_acc, assm_acc, stereo_acc

def assm(self, mol_batch, mol_vec, tree_mess):
cands = []
Expand Down Expand Up @@ -123,7 +128,7 @@ def assm(self, mol_batch, mol_vec, tree_mess):
cur_score = scores.narrow(0, tot, ncand)
tot += ncand

if cur_score.data[label] >= cur_score.max().data[0]:
if cur_score[label].item() >= cur_score.max().item():
acc += 1

label = create_var(torch.LongTensor([label]))
Expand Down Expand Up @@ -242,6 +247,8 @@ def decode(self, tree_vec, mol_vec, prob_decode):
set_atommap(cur_mol)
cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
if cur_mol is None: return None
if self.stereo == False:
return Chem.MolToSmiles(cur_mol)

smiles2D = Chem.MolToSmiles(cur_mol)
stereo_cands = decode_stereo(smiles2D)
Expand Down Expand Up @@ -285,7 +292,7 @@ def dfs_assemble(self, tree_mess, mol_vec, all_nodes, cur_mol, global_amap, fa_a
backup_mol = Chem.RWMol(cur_mol)
for i in xrange(cand_idx.numel()):
cur_mol = Chem.RWMol(backup_mol)
pred_amap = cand_amap[cand_idx[i].data[0]]
pred_amap = cand_amap[cand_idx[i].item()]
new_global_amap = copy.deepcopy(global_amap)

for nei_id,ctr_atom,nei_atom in pred_amap:
Expand Down
Binary file added molvae/moses-h450L56d3beta0.3/model.iter-2
Binary file not shown.
Loading

0 comments on commit 01efeef

Please sign in to comment.