1+ # Copyright 2021 AlQuraishi Laboratory
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+
15+ import numpy as np
16+
17+
18+ def random_template_feats (n_templ , n , batch_size = None ):
19+ b = []
20+ if batch_size is not None :
21+ b .append (batch_size )
22+ batch = {
23+ "template_mask" : np .random .randint (0 , 2 , (* b , n_templ )),
24+ "template_pseudo_beta_mask" : np .random .randint (0 , 2 , (* b , n_templ , n )),
25+ "template_pseudo_beta" : np .random .rand (* b , n_templ , n , 3 ),
26+ "template_aatype" : np .random .randint (0 , 22 , (* b , n_templ , n )),
27+ "template_all_atom_mask" : np .random .randint (
28+ 0 , 2 , (* b , n_templ , n , 37 )
29+ ),
30+ "template_all_atom_positions" :
31+ np .random .rand (* b , n_templ , n , 37 , 3 ) * 10 ,
32+ "template_torsion_angles_sin_cos" :
33+ np .random .rand (* b , n_templ , n , 7 , 2 ),
34+ "template_alt_torsion_angles_sin_cos" :
35+ np .random .rand (* b , n_templ , n , 7 , 2 ),
36+ "template_torsion_angles_mask" :
37+ np .random .rand (* b , n_templ , n , 7 ),
38+ }
39+ batch = {k : v .astype (np .float32 ) for k , v in batch .items ()}
40+ batch ["template_aatype" ] = batch ["template_aatype" ].astype (np .int64 )
41+ return batch
42+
43+
44+ def random_extra_msa_feats (n_extra , n , batch_size = None ):
45+ b = []
46+ if batch_size is not None :
47+ b .append (batch_size )
48+ batch = {
49+ "extra_msa" : np .random .randint (0 , 22 , (* b , n_extra , n )).astype (
50+ np .int64
51+ ),
52+ "extra_has_deletion" : np .random .randint (0 , 2 , (* b , n_extra , n )).astype (
53+ np .float32
54+ ),
55+ "extra_deletion_value" : np .random .rand (* b , n_extra , n ).astype (
56+ np .float32
57+ ),
58+ "extra_msa_mask" : np .random .randint (0 , 2 , (* b , n_extra , n )).astype (
59+ np .float32
60+ ),
61+ }
62+ return batch
0 commit comments