forked from wengong-jin/icml18-jtnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdatautils.py
37 lines (28 loc) · 985 Bytes
/
datautils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from torch.utils.data import Dataset
from mol_tree import MolTree
import numpy as np
class MoleculeDataset(Dataset):
    """Dataset of molecules read from a text file of SMILES strings.

    Every non-blank line of the input file contributes one entry: the first
    whitespace-separated token on that line. Anything after the first token
    is ignored. ``MolTree`` objects are built on demand when an item is
    requested, not when the file is loaded.
    """

    def __init__(self, data_file):
        """Read the first token of each non-blank line of ``data_file``.

        Args:
            data_file: Path to a text file with one SMILES string per line.
        """
        with open(data_file) as f:
            # split() with no arguments already drops leading/trailing
            # whitespace (including "\r\n"), so no explicit strip is needed.
            # Blank or whitespace-only lines yield an empty token list and
            # are skipped instead of raising IndexError on tokens[0].
            tokens_per_line = (line.split() for line in f)
            self.data = [tokens[0] for tokens in tokens_per_line if tokens]

    def __len__(self):
        """Return the number of SMILES strings that were loaded."""
        return len(self.data)

    def __getitem__(self, idx):
        """Build and return the ``MolTree`` for the SMILES string at ``idx``.

        The tree is constructed from the stored string, then its
        ``recover()`` and ``assemble()`` methods are called (in that order)
        before it is returned. A new tree is built on every access.
        """
        smiles = self.data[idx]
        mol_tree = MolTree(smiles)
        mol_tree.recover()
        mol_tree.assemble()
        return mol_tree
class PropDataset(Dataset):
    """Dataset pairing each molecule with an entry from a property file.

    SMILES strings are read from ``data_file`` (first whitespace-separated
    token of every non-blank line) and numeric values are read from
    ``prop_file`` with ``np.loadtxt``. Item ``idx`` pairs the ``idx``-th
    SMILES string with ``prop_data[idx]``.

    NOTE(review): the two files are not checked for equal length; a
    mismatch only shows up as an IndexError on access (or as unused
    trailing property rows) -- confirm callers keep them aligned.
    """

    def __init__(self, data_file, prop_file):
        """Load property values and SMILES strings.

        Args:
            data_file: Path to a text file with one SMILES string per line.
            prop_file: Path to a numeric text file readable by
                ``np.loadtxt``.
        """
        # ndmin=1 keeps a one-value file as a length-1 array; without it
        # np.loadtxt returns a 0-d array and prop_data[idx] cannot be
        # indexed for a single-molecule dataset.
        self.prop_data = np.loadtxt(prop_file, ndmin=1)
        with open(data_file) as f:
            # Blank or whitespace-only lines are skipped rather than raising
            # IndexError; np.loadtxt likewise ignores empty lines in the
            # property file, so trailing blank lines in either file do not
            # shift the pairing.
            tokens_per_line = (line.split() for line in f)
            self.data = [tokens[0] for tokens in tokens_per_line if tokens]

    def __len__(self):
        """Return the number of SMILES strings that were loaded."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(mol_tree, prop)`` for the entry at ``idx``.

        The ``MolTree`` is built from the stored SMILES string and its
        ``recover()`` and ``assemble()`` methods are called (in that order);
        ``prop`` is ``self.prop_data[idx]``.
        """
        smiles = self.data[idx]
        mol_tree = MolTree(smiles)
        mol_tree.recover()
        mol_tree.assemble()
        return mol_tree, self.prop_data[idx]