Commit a37c8b4

Modify the data-reading path and add an inference test (#50)
1 parent 03a24ef commit a37c8b4

File tree: 3 files changed (+148, -4 lines)

  inference.py
  tests/test_data_utils.py
  tests/test_inference.py

inference.py
Lines changed: 3 additions & 4 deletions

@@ -32,6 +32,7 @@
 from fastfold.model.fastnn import set_chunk_size
 from fastfold.data import data_pipeline, feature_pipeline, templates
 from fastfold.utils import inject_fastnn
+from fastfold.data.parsers import parse_fasta
 from fastfold.utils.import_weights import import_jax_weights_
 from fastfold.utils.tensor_utils import tensor_tree_map

@@ -141,10 +142,8 @@ def main(args):

     # Gather input sequences
     with open(args.fasta_path, "r") as fp:
-        lines = [l.strip() for l in fp.readlines()]
-
-        tags, seqs = lines[::2], lines[1::2]
-        tags = [l[1:] for l in tags]
+        fasta = fp.read()
+        seqs, tags = parse_fasta(fasta)

     for tag, seq in zip(tags, seqs):
         batch = [None]
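
For context, the diff replaces the manual header/sequence slicing with `parse_fasta`, which returns the sequences and their header tags with the leading ">" already stripped (hence `seqs, tags = parse_fasta(fasta)`). A minimal sketch of a parser with the same calling convention, illustrative only and not the exact implementation in fastfold.data.parsers:

def parse_fasta_sketch(fasta: str):
    """Return (sequences, descriptions) parsed from a FASTA string."""
    # Illustrative stand-in for fastfold.data.parsers.parse_fasta, assuming
    # well-formed input where every record starts with a ">" header line.
    sequences, descriptions = [], []
    for line in fasta.splitlines():
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            descriptions.append(line[1:])  # header tag without the leading ">"
            sequences.append("")
        else:
            sequences[-1] += line  # a sequence may span multiple lines
    return sequences, descriptions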

tests/test_data_utils.py (new file)
Lines changed: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np


def random_template_feats(n_templ, n, batch_size=None):
    b = []
    if batch_size is not None:
        b.append(batch_size)
    batch = {
        "template_mask": np.random.randint(0, 2, (*b, n_templ)),
        "template_pseudo_beta_mask": np.random.randint(0, 2, (*b, n_templ, n)),
        "template_pseudo_beta": np.random.rand(*b, n_templ, n, 3),
        "template_aatype": np.random.randint(0, 22, (*b, n_templ, n)),
        "template_all_atom_mask": np.random.randint(
            0, 2, (*b, n_templ, n, 37)
        ),
        "template_all_atom_positions":
            np.random.rand(*b, n_templ, n, 37, 3) * 10,
        "template_torsion_angles_sin_cos":
            np.random.rand(*b, n_templ, n, 7, 2),
        "template_alt_torsion_angles_sin_cos":
            np.random.rand(*b, n_templ, n, 7, 2),
        "template_torsion_angles_mask":
            np.random.rand(*b, n_templ, n, 7),
    }
    batch = {k: v.astype(np.float32) for k, v in batch.items()}
    batch["template_aatype"] = batch["template_aatype"].astype(np.int64)
    return batch


def random_extra_msa_feats(n_extra, n, batch_size=None):
    b = []
    if batch_size is not None:
        b.append(batch_size)
    batch = {
        "extra_msa": np.random.randint(0, 22, (*b, n_extra, n)).astype(
            np.int64
        ),
        "extra_has_deletion": np.random.randint(0, 2, (*b, n_extra, n)).astype(
            np.float32
        ),
        "extra_deletion_value": np.random.rand(*b, n_extra, n).astype(
            np.float32
        ),
        "extra_msa_mask": np.random.randint(0, 2, (*b, n_extra, n)).astype(
            np.float32
        ),
    }
    return batch
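
These helpers build dictionaries of randomly generated template and extra-MSA features with AlphaFold-style shapes and dtypes, so the inference test can run without real data-pipeline output. A quick shape check, with argument values chosen only for illustration:

import numpy as np

from test_data_utils import random_extra_msa_feats, random_template_feats

# Illustrative check of generated shapes and dtypes (values assumed).
t_feats = random_template_feats(n_templ=4, n=64, batch_size=2)
assert t_feats["template_aatype"].shape == (2, 4, 64)
assert t_feats["template_aatype"].dtype == np.int64
assert t_feats["template_all_atom_positions"].shape == (2, 4, 64, 37, 3)

e_feats = random_extra_msa_feats(n_extra=16, n=64)
assert e_feats["extra_msa"].shape == (16, 64)
assert e_feats["extra_msa_mask"].dtype == np.float32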

tests/test_inference.py (new file)
Lines changed: 83 additions & 0 deletions

@@ -0,0 +1,83 @@
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import torch
import ml_collections as mlc

import fastfold
from fastfold.model.hub import AlphaFold
from fastfold.config import model_config
from fastfold.model.fastnn import set_chunk_size
from fastfold.utils import inject_fastnn
from test_data_utils import random_extra_msa_feats, random_template_feats
from fastfold.data import data_transforms
from fastfold.utils.tensor_utils import tensor_tree_map


consts = mlc.ConfigDict(
    {
        "n_res": 11,
        "n_seq": 13,
        "n_templ": 3,
        "n_extra": 17,
    }
)


def inference():
    fastfold.distributed.init_dap()

    n_seq = consts.n_seq
    n_templ = consts.n_templ
    n_res = consts.n_res
    n_extra_seq = consts.n_extra

    config = model_config('model_1')
    model = AlphaFold(config)
    model = inject_fastnn(model)
    model.eval()
    model.cuda()

    set_chunk_size(model.globals.chunk_size)

    batch = {}
    tf = torch.randint(config.model.input_embedder.tf_dim - 1, size=(n_res,))
    batch["target_feat"] = torch.nn.functional.one_hot(
        tf, config.model.input_embedder.tf_dim).float()
    batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
    batch["residue_index"] = torch.arange(n_res)
    batch["msa_feat"] = torch.rand((n_seq, n_res, config.model.input_embedder.msa_dim))
    t_feats = random_template_feats(n_templ, n_res)
    batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
    extra_feats = random_extra_msa_feats(n_extra_seq, n_res)
    batch.update({k: torch.tensor(v) for k, v in extra_feats.items()})
    batch["msa_mask"] = torch.randint(low=0, high=2, size=(n_seq, n_res)).float()
    batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
    batch.update(data_transforms.make_atom14_masks(batch))
    batch["no_recycling_iters"] = torch.tensor(2.)
    add_recycling_dims = lambda t: (
        t.unsqueeze(-1).expand(*t.shape, config.data.common.max_recycling_iters))
    batch = tensor_tree_map(add_recycling_dims, batch)

    with torch.no_grad():
        batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
        t = time.perf_counter()
        out = model(batch)
        print(f"Inference time: {time.perf_counter() - t}")


if __name__ == "__main__":
    inference()
    print("Inference Test Passed!")
