Skip to content

Commit 9231baf

Browse files
ZhouGengmozhougengmo
andauthored
add mol repr demo (#60)
* add mol repr demo * add news Co-authored-by: zhougengmo <[email protected]>
1 parent 157d70e commit 9231baf

File tree

5 files changed

+300
-11
lines changed

5 files changed

+300
-11
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ Uni-Mol is composed of two models: a molecular pretraining model trained by 209M
1515

1616
News
1717
----
18+
**Oct 12 2022**: Provide a demo to get Uni-Mol molecular representation.
19+
1820
**Sep 20 2022**: Provide Uni-Mol based IFD scoring function baseline for [AIAC 2022 Competition Prediction of protein binding ability of drug molecules](http://www.aiinnovation.com.cn/#/aiaeDetail?id=560).
1921

2022
**Sep 9 2022**: Provide Uni-Mol binding pose prediction (docking) demo on Colab.
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "4f0f701f-c552-4ca1-8188-2cdfc1362f6b",
6+
"metadata": {},
7+
"source": [
8+
"# Uni-Mol Molecular Representation"
9+
]
10+
},
11+
{
12+
"cell_type": "markdown",
13+
"id": "d3449ed8-2a57-4e62-9163-e32baf66e828",
14+
"metadata": {},
15+
"source": [
16+
"**Licenses**\n",
17+
"\n",
18+
"Copyright (c) DP Technology.\n",
19+
"\n",
20+
"This source code is licensed under the MIT license found in the\n",
21+
"LICENSE file in the root directory of this source tree.\n",
22+
"\n",
23+
"**Citations**\n",
24+
"\n",
25+
"Please cite the following papers if you use this notebook:\n",
26+
"\n",
27+
"- Gengmo Zhou, Zhifeng Gao, Qiankun Ding, Hang Zheng, Hongteng Xu, Zhewei Wei, Linfeng Zhang, Guolin Ke. \"[Uni-Mol: A Universal 3D Molecular Representation Learning Framework.](https://chemrxiv.org/engage/chemrxiv/article-details/6318b529bada388485bc8361)\"\n",
28+
"ChemRxiv (2022)"
29+
]
30+
},
31+
{
32+
"cell_type": "code",
33+
"execution_count": null,
34+
"id": "6d51f850-76cd-4801-bf2e-a4c53221d586",
35+
"metadata": {},
36+
"outputs": [],
37+
"source": [
38+
"import os\n",
39+
"import numpy as np\n",
40+
"import pandas as pd\n",
41+
"import lmdb\n",
42+
"from rdkit import Chem\n",
43+
"from rdkit.Chem import AllChem\n",
44+
"from tqdm import tqdm\n",
45+
"import pickle\n",
46+
"import glob"
47+
]
48+
},
49+
{
50+
"cell_type": "markdown",
51+
"id": "89c70ab0-da59-459d-bf1c-ac307e9e7ae5",
52+
"metadata": {},
53+
"source": [
54+
"### Your SMILES list"
55+
]
56+
},
57+
{
58+
"cell_type": "code",
59+
"execution_count": null,
60+
"id": "bfa0ce2a-b7aa-4cae-81ba-27b91c0591e4",
61+
"metadata": {},
62+
"outputs": [],
63+
"source": [
# Example input: the SMILES strings whose representations will be extracted.
# Raw string literals are used because SMILES writes bond stereochemistry with
# backslashes (see the seventh entry); in a normal literal Python would read
# the backslash-C pair as an invalid escape sequence.
smi_list = [
    r'CC1=C(C(=O)OC2CCCC2)[C@H](c2ccccc2OC(C)C)C2=C(O)CC(C)(C)CC2=[N+]1',
    r'COc1cccc(-c2nc(C(=O)NC[C@H]3CCCO3)cc3c2[nH]c2ccccc23)c1',
    r'O=C1c2ccccc2C(=O)c2c1ccc(C(=O)n1nc3c4c(cccc41)C(=O)c1ccccc1-3)c2[N+](=O)[O-]',
    r'COc1cc(/C=N/c2nonc2NC(C)=O)ccc1OC(C)C',
    r'CCC[C@@H]1CN(Cc2ccc3nsnc3c2)C[C@H]1NS(C)(=O)=O',
    r'CCc1nnc(N/C(O)=C/CCOc2ccc(OC)cc2)s1',
    r'CC(C)(C)SCCN/C=C1\C(=O)NC(=O)N(c2ccc(Br)cc2)C1=O',
    r'CC(C)(C)c1nc(COc2ccc3c(c2)CCn2c-3cc(OCC3COCCO3)nc2=O)no1',
    r'N#CCCNS(=O)(=O)c1ccc(/C(O)=N/c2ccccc2Oc2ccccc2Cl)cc1',
    r'O=C(Nc1ncc(Cl)s1)c1cccc(S(=O)(=O)Nc2ccc(Br)cc2)c1',
]
76+
]
77+
},
78+
{
79+
"cell_type": "markdown",
80+
"id": "b109d84a-8d59-445b-9997-d1383ee24079",
81+
"metadata": {},
82+
"source": [
83+
"### Generate conformations from SMILES and save to .lmdb"
84+
]
85+
},
86+
{
87+
"cell_type": "code",
88+
"execution_count": null,
89+
"id": "ea582d7d-8851-4d46-880e-54867737b232",
90+
"metadata": {},
91+
"outputs": [],
92+
"source": [
def smi2coords(smi, seed):
    """Embed one conformer for a SMILES string and serialize it with pickle.

    Args:
        smi: SMILES string of the molecule.
        seed: Random seed forwarded to RDKit's ``EmbedMolecule``.

    Returns:
        bytes: A pickle (highest protocol) of a dict with the keys ``'atoms'``
        (element symbols of the hydrogen-added molecule), ``'coordinates'``
        (a one-element list holding the conformer positions as float32) and
        ``'smi'`` (the input string).

    Raises:
        ValueError: If RDKit cannot parse ``smi`` or no conformer can be
            embedded, even after the extended retry.
        AssertionError: If the number of atoms and coordinates disagree.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; fail with a
        # clear message instead of an opaque error inside AddHs.
        raise ValueError("RDKit failed to parse SMILES: {}".format(smi))
    mol = AllChem.AddHs(mol)
    atoms = [atom.GetSymbol() for atom in mol.GetAtoms()]

    if AllChem.EmbedMolecule(mol, randomSeed=seed) == -1:
        # First attempt failed: retry on the molecule without explicit
        # hydrogens using many more attempts, then add hydrogens with
        # coordinates so the atom count matches `atoms` again.
        mol = Chem.MolFromSmiles(smi)
        retry_status = AllChem.EmbedMolecule(mol, maxAttempts=5000, randomSeed=seed)
        if retry_status == -1:
            raise ValueError("RDKit failed to embed a conformer for SMILES: {}".format(smi))
        mol = AllChem.AddHs(mol, addCoords=True)

    try:
        AllChem.MMFFOptimizeMolecule(mol)
    except Exception:
        # Force-field refinement is best effort; keep the embedded geometry.
        pass
    coordinates = mol.GetConformer().GetPositions()

    assert len(atoms) == len(coordinates), "coordinates shape is not align with {}".format(smi)
    coordinate_list = [coordinates.astype(np.float32)]
    return pickle.dumps({'atoms': atoms, 'coordinates': coordinate_list, 'smi': smi}, protocol=-1)
117+
"\n",
def write_lmdb(smiles_list, job_name, seed=42, outpath='./results'):
    """Generate a conformer for every SMILES and store the records in an LMDB file.

    The database is written to ``<outpath>/<job_name>.lmdb``. Any existing file
    of that name is removed first. Record ``i`` is stored under the ASCII key
    ``str(i)`` and holds the bytes returned by ``smi2coords``.

    Args:
        smiles_list: Iterable of SMILES strings.
        job_name: Base name of the output file (without the ``.lmdb`` suffix).
        seed: Random seed forwarded to ``smi2coords``.
        outpath: Output directory; created if it does not exist.

    Raises:
        Exception: Whatever ``smi2coords`` raises for a molecule is propagated.
            In that case the transaction is aborted and the environment is
            still closed.
    """
    os.makedirs(outpath, exist_ok=True)
    output_name = os.path.join(outpath, '{}.lmdb'.format(job_name))
    # Start from a clean file so records of a previous run cannot linger.
    # Only "file does not exist" is expected here; other errors are real.
    try:
        os.remove(output_name)
    except FileNotFoundError:
        pass
    env_new = lmdb.open(
        output_name,
        subdir=False,
        readonly=False,
        lock=False,
        readahead=False,
        meminit=False,
        max_readers=1,
        map_size=int(100e9),
    )
    try:
        # The transaction context manager commits on success and aborts if an
        # exception escapes the loop.
        with env_new.begin(write=True) as txn_write:
            # Wrap the list itself (not the enumerate object) so tqdm can read
            # its length and show a total.
            for i, smiles in enumerate(tqdm(smiles_list)):
                inner_output = smi2coords(smiles, seed=seed)
                txn_write.put(f"{i}".encode("ascii"), inner_output)
    finally:
        env_new.close()
141+
]
142+
},
143+
{
144+
"cell_type": "code",
145+
"execution_count": null,
146+
"id": "dad25a1a-f93e-4fdf-b389-2a3fe61a40ee",
147+
"metadata": {},
148+
"outputs": [],
149+
"source": [
# Run configuration: adjust these values for your own job. The inference
# command in a later cell reads the same variables.
seed = 42
job_name = 'get_mol_repr'                      # replace with your custom name
data_path = './results'                        # replace with your data path
weight_path = '../ckp/mol_pre_no_h_220816.pt'  # replace with your ckpt path
only_polar = 0                                 # 0: remove all hydrogens (no h)
dict_name = 'dict.txt'
batch_size = 16
results_path = data_path                       # replace with your save path

# Build <data_path>/<job_name>.lmdb from the SMILES list.
write_lmdb(smi_list, job_name=job_name, seed=seed, outpath=data_path)
159+
]
160+
},
161+
{
162+
"cell_type": "markdown",
163+
"id": "12284210-7f86-4062-b291-7c077ef6f83a",
164+
"metadata": {},
165+
"source": [
166+
"### Infer from ckpt"
167+
]
168+
},
169+
{
170+
"cell_type": "code",
171+
"execution_count": null,
172+
"id": "9fb2391b-81b0-4b11-95ea-3b7855db9bc6",
173+
"metadata": {},
174+
"outputs": [],
175+
"source": [
176+
"!cp ../example_data/molecule/$dict_name $data_path\n",
177+
"!python ../unimol/infer.py --user-dir ../unimol $data_path --valid-subset $job_name \\\n",
178+
" --results-path $results_path \\\n",
179+
" --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \\\n",
180+
" --task unimol --loss unimol_infer --arch unimol_base \\\n",
181+
" --path $weight_path \\\n",
182+
" --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \\\n",
183+
" --only-polar $only_polar --dict-name $dict_name \\\n",
184+
" --log-interval 50 --log-format simple --random-token-prob 0 --leave-unmasked-prob 1.0 --mode infer"
185+
]
186+
},
187+
{
188+
"cell_type": "markdown",
189+
"id": "d8421258-eca6-4801-aadd-fc67fd928cb1",
190+
"metadata": {},
191+
"source": [
192+
"### Read .pkl and save results to .csv"
193+
]
194+
},
195+
{
196+
"cell_type": "code",
197+
"execution_count": null,
198+
"id": "c456f31e-94fc-4593-97c9-1db7182465aa",
199+
"metadata": {},
200+
"outputs": [],
201+
"source": [
def get_csv_results(predict_path, results_path):
    """Flatten the pickled inference output into a table and save it as CSV.

    Args:
        predict_path: Path to the pickle written by the inference step. It is
            read as an iterable of per-batch dicts with the keys ``"bsz"``,
            ``"smi_name"``, ``"mol_repr_cls"`` and ``"pair_repr"``.
        results_path: Directory in which ``mol_repr.csv`` is written.

    Returns:
        pandas.DataFrame: One row per molecule with the columns ``SMILES``,
        ``mol_repr`` and ``pair_repr``.
    """
    # NOTE(security): unpickling can execute arbitrary code; only load files
    # produced by your own inference run.
    predict = pd.read_pickle(predict_path)
    smi_list, mol_repr_list, pair_repr_list = [], [], []
    for batch in predict:
        for i in range(batch["bsz"]):
            smi_list.append(batch["smi_name"][i])
            mol_repr_list.append(batch["mol_repr_cls"][i])
            pair_repr_list.append(batch["pair_repr"][i])
    predict_df = pd.DataFrame({"SMILES": smi_list, "mol_repr": mol_repr_list, "pair_repr": pair_repr_list})
    # DataFrame.info() prints its own report and returns None, so it must not
    # be passed to print() (that would emit a stray "None").
    print(predict_df.head(1))
    predict_df.info()
    # NOTE(review): the CSV stores each array as its string form, and NumPy
    # abbreviates large arrays with "..." when printing. Use the returned
    # DataFrame (or the pickle) when the full values are needed.
    predict_df.to_csv(os.path.join(results_path, 'mol_repr.csv'), index=False)
    return predict_df
214+
"\n",
# Locate the pickle written by the inference cell. Sorting makes the choice
# deterministic when several files match, and the explicit check replaces a
# bare IndexError with an actionable message when inference produced nothing.
pkl_matches = sorted(glob.glob(f'{results_path}/*_{job_name}.out.pkl'))
if not pkl_matches:
    raise FileNotFoundError(
        f"No '*_{job_name}.out.pkl' file found in {results_path!r}; "
        "run the inference cell first and check that it finished without errors."
    )
pkl_path = pkl_matches[0]
get_csv_results(pkl_path, results_path)
217+
]
218+
}
219+
],
220+
"metadata": {
221+
"kernelspec": {
222+
"display_name": "Python 3 (ipykernel)",
223+
"language": "python",
224+
"name": "python3"
225+
},
226+
"language_info": {
227+
"codemirror_mode": {
228+
"name": "ipython",
229+
"version": 3
230+
},
231+
"file_extension": ".py",
232+
"mimetype": "text/x-python",
233+
"name": "python",
234+
"nbconvert_exporter": "python",
235+
"pygments_lexer": "ipython3",
236+
"version": "3.8.13"
237+
}
238+
},
239+
"nbformat": 4,
240+
"nbformat_minor": 5
241+
}

unimol/losses/unimol.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,37 @@ def cal_dist_loss(self, sample, dist, masked_tokens, target_key, normalize=False
174174
beta=1.0,
175175
)
176176
return masked_dist_loss
177+
178+
179+
@register_loss("unimol_infer")
class UniMolInferLoss(UnicoreLoss):
    """Inference-only "loss" that exports model representations.

    No loss value is computed. ``forward`` runs the model with
    ``features_only=True`` and hands the resulting representations back through
    the logging output.
    """

    def __init__(self, task):
        """Keep the padding index reported by the task dictionary."""
        super().__init__(task)
        self.padding_idx = task.dictionary.pad()

    def forward(self, model, sample, reduce=True):
        """Run the model on ``sample`` and collect its representations.

        ``reduce`` is accepted but not used.

        Returns a tuple with three elements:
        1) the constant 0, standing in for a loss
        2) the sample size, taken from dimension 0 of ``src_tokens``
        3) a logging-output dict with the keys ``mol_repr_cls``, ``pair_repr``,
           ``smi_name`` and ``bsz``
        """
        net_input = sample["net_input"]
        tokens = net_input["src_tokens"]
        # True wherever the token differs from the padding index.
        not_padding = tokens.ne(self.padding_idx)
        encoder_rep, encoder_pair_rep = model(**net_input, features_only=True)
        sample_size = tokens.size(0)
        # For each sample keep only the rows and columns of non-padding tokens.
        pair_reprs = [
            encoder_pair_rep[row][not_padding[row], :][:, not_padding[row]].data.cpu().numpy()
            for row in range(sample_size)
        ]
        logging_output = {
            # Position 0 of the sequence is taken as the molecule-level
            # representation (the CLS token).
            "mol_repr_cls": encoder_rep[:, 0, :].data.cpu().numpy(),
            "pair_repr": pair_reprs,
            "smi_name": sample["target"]["smi_name"],
            "bsz": sample_size,
        }
        return 0, sample_size, logging_output

unimol/models/unimol.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -238,14 +238,16 @@ def get_dist_features(dist, et):
238238

239239
if classification_head_name is not None:
240240
logits = self.classification_heads[classification_head_name](encoder_rep)
241-
242-
return (
243-
logits,
244-
encoder_distance,
245-
encoder_coord,
246-
x_norm,
247-
delta_encoder_pair_rep_norm,
248-
)
241+
if self.args.mode == 'infer':
242+
return encoder_rep, encoder_pair_rep
243+
else:
244+
return (
245+
logits,
246+
encoder_distance,
247+
encoder_coord,
248+
x_norm,
249+
delta_encoder_pair_rep_norm,
250+
)
249251

250252
def register_classification_head(
251253
self, name, num_classes=None, inner_dim=None, **kwargs

unimol/tasks/unimol.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
TokenizeDataset,
1717
RightPadDataset2D,
1818
FromNumpyDataset,
19+
RawArrayDataset,
1920
)
2021
from unimol.data import (
2122
KeyDataset,
@@ -106,6 +107,12 @@ def add_args(parser):
106107
type=int,
107108
help="1: only polar hydrogen ; -1: all hydrogen ; 0: remove all hydrogen ",
108109
)
110+
parser.add_argument(
111+
"--mode",
112+
type=str,
113+
default="train",
114+
choices=["train", "infer"],
115+
)
109116

110117
def __init__(self, args, dictionary):
111118
super().__init__(args)
@@ -136,9 +143,11 @@ def load_dataset(self, split, combine=False, **kwargs):
136143
raw_dataset = LMDBDataset(split_path)
137144

138145
def one_dataset(raw_dataset, coord_seed, mask_seed):
139-
raw_dataset = Add2DConformerDataset(
140-
raw_dataset, "smi", "atoms", "coordinates"
141-
)
146+
if self.args.mode =='train':
147+
raw_dataset = Add2DConformerDataset(
148+
raw_dataset, "smi", "atoms", "coordinates"
149+
)
150+
smi_dataset = KeyDataset(raw_dataset, "smi")
142151
dataset = ConformerSampleDataset(
143152
raw_dataset, coord_seed, "atoms", "coordinates"
144153
)
@@ -217,6 +226,7 @@ def PrependAndAppend(dataset, pre_token, app_token):
217226
),
218227
"distance_target": RightPadDataset2D(distance_dataset, pad_idx=0),
219228
"coord_target": RightPadDatasetCoord(coord_dataset, pad_idx=0),
229+
"smi_name": RawArrayDataset(smi_dataset),
220230
}
221231

222232
net_input, target = one_dataset(raw_dataset, self.args.seed, self.args.seed)

0 commit comments

Comments
 (0)