fix: nested parallel runs for protein chains
YaoYinYing committed Aug 19, 2024
1 parent 8548570 commit d2feac5
Showing 5 changed files with 91 additions and 36 deletions.
@@ -0,0 +1,19 @@
{
  "entities": [
    {
      "type": "protein",
      "sequence": "GPHMATGQDRVVALVDMDCFFVQVEQRQNPHLRNKPCAVVQYKSWKGGGIIAVSYEARAFGVTRSMWADDAKKLCPDLLLAQVRESRGKANLTKYREASVEVMEIMSRFAVIERASIDEAYVDLTSAVQERLQKLQGQPISADLLPSTYIEGLPQGPTTAEETVQKEGMRKQGLFQWLDSLQIDNLTSPDLQLTVGAVIVEEMRAAIERETGFQCSAGISHNKVLAKLACGLNKPNRQTLVSHGSVPQLFSQMPIRKIRSLGGKLGASVIEILGIEYMGELTQFTESQLQSHFGEKNGSWLYAMCRGIEHDPVKPRQLPKTIGCSKNFPGKTALATREQVQWWLLQLAQELEERLTKDRNDNDRVATQLVVSIRVQGDKRLSSLRRCCALTRYDAHKMSHDAFTVIKNCNTSGIQTEWSPPLTMLFLCATKFSAS",
      "count": 1
    },
    {
      "type": "dna",
      "sequence": "CATTATGACGCT",
      "count": 1
    },
    {
      "type": "dna",
      "sequence": "AGCGTCAT",
      "count": 1
    }
  ]
}
@@ -0,0 +1,24 @@
{
  "entities": [
    {
      "type": "protein",
      "sequence": "GPHMATGQDRVVALVDMDCFFVQVEQRQNPHLRNKPCAVVQYKSWKGGGIIAVSYEARAFGVTRSMWADDAKKLCPDLLLAQVRESRGKANLTKYREASVEVMEIMSRFAVIERASIDEAYVDLTSAVQERLQKLQGQPISADLLPSTYIEGLPQGPTTAEETVQKEGMRKQGLFQWLDSLQIDNLTSPDLQLTVGAVIVEEMRAAIERETGFQCSAGISHNKVLAKLACGLNKPNRQTLVSHGSVPQLFSQMPIRKIRSLGGKLGASVIEILGIEYMGELTQFTESQLQSHFGEKNGSWLYAMCRGIEHDPVKPRQLPKTIGCSKNFPGKTALATREQVQWWLLQLAQELEERLTKDRNDNDRVATQLVVSIRVQGDKRLSSLRRCCALTRYDAHKMSHDAFTVIKNCNTSGIQTEWSPPLTMLFLCATKFSAS",
      "count": 1
    },
    {
      "type": "dna",
      "sequence": "CATTATGACGCT",
      "count": 1
    },
    {
      "type": "dna",
      "sequence": "AGCGTCAT",
      "count": 1
    },
    {
      "type": "ligand",
      "ccd": "XG4",
      "count": 1
    }
  ]
}
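Taken together, the two new example inputs differ only in the trailing ligand entity, referenced by CCD code XG4. For orientation, a minimal sketch of consuming an entities JSON like these (file name and tallying logic are illustrative, not part of the commit):

import json
from collections import Counter

with open('entities.json') as f:  # hypothetical path
    spec = json.load(f)

# Tally entity types; the second example above gives
# Counter({'dna': 2, 'protein': 1, 'ligand': 1}).
counts = Counter(entity['type'] for entity in spec['entities'])
print(counts)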
@@ -107,9 +107,9 @@ def make_ccd_conf_features(all_chain_info, ccd_preprocessed_dict,
features[k] = np.concatenate(v, axis=0)
features['ref_atom_count'] = np.bincount(features['ref_token2atom_idx'])

- assert np.max(features['ref_element']) < 128
- assert np.max(features['ref_atom_name_chars']) < 64
- assert len(set([len(v) for k, v in features.items() if k != 'ref_atom_count'])) == 1 ## To check same Atom-level features.
+ assert np.max(features['ref_element']) < 128 # WTF?
+ assert np.max(features['ref_atom_name_chars']) < 64 # WTF?
+ assert len(set([len(v) for k, v in features.items() if k != 'ref_atom_count'])) == 1 ## To check same Atom-level features. # WTF?
return features


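A hedged reading of the two asserted bounds, which the diff itself does not explain: ref_element plausibly must fit a 128-wide element vocabulary, and ref_atom_name_chars plausibly encodes each character of a space-padded four-character atom name as ord(c) - 32, which keeps uppercase letters, digits, and the pad space below 64. A sketch of that assumed encoding (helper name invented):

import numpy as np

def encode_atom_name(atom_name: str) -> np.ndarray:
    # Assumed encoding: pad/truncate to 4 chars, then ord(c) - 32 per char.
    # Uppercase letters, digits, and the pad space all land in [0, 64).
    padded = atom_name.ljust(4)[:4]
    return np.array([ord(c) - 32 for c in padded], dtype=np.int32)

print(encode_atom_name('CA'))  # [35 33  0  0]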
@@ -3,16 +3,17 @@
import collections
import os
import time
- from typing import MutableMapping, Optional, List
+ from typing import Any, MutableMapping, Optional, List
from absl import logging
from helixfold.common import residue_constants
from helixfold.data import parsers
from helixfold.data.tools import utils
import numpy as np
import json
import gzip
import pickle
from rdkit import Chem

FeatureDict = MutableMapping[str, np.ndarray]
ELEMENT_MAPPING = Chem.GetPeriodicTable()

@@ -56,6 +57,15 @@ def flatten_is_protein_features(is_protein_feats: np.ndarray) -> FeatureDict:

return res


+ def dump_all_ccd_keys(ccd_data: dict[str, Any]) -> str:
+     ccd_keys_file = 'all_ccd_keys.txt'
+     if not os.path.isfile(ccd_keys_file):
+         # Context manager ensures the file handle is closed promptly.
+         with open(ccd_keys_file, 'w') as f:
+             f.write('\n'.join(ccd_data.keys()))
+         logging.warning(f'All ccd keys are dumped to {ccd_keys_file}')
+     return ccd_keys_file


def make_sequence_features(
all_chain_info, ccd_preprocessed_dict,
extra_feats: Optional[dict]=None) -> FeatureDict:
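A quick usage sketch of the new dump_all_ccd_keys helper (dictionary contents invented): it writes the key list once per working directory, logs a warning on that first write, and always returns the same path, so the error paths below can point users at it cheaply.

keys_file = dump_all_ccd_keys({'ALA': {}, 'HEM': {}, 'XG4': {}})
with open(keys_file) as f:
    print(f.read())  # ALA, HEM and XG4 on separate lines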
@@ -106,13 +116,18 @@ def make_sequence_features(
sym_id = chainid_to_sym_id[_alphabet_chain_id]
for residue_id, ccd_id in enumerate(ccd_seq):
if ccd_id not in ccd_preprocessed_dict:
- assert not extra_feats is None and ccd_id in extra_feats,\
-     f'<{ccd_id}> not in ccd_preprocessed_dict, But got extra_feats is None'
+ if extra_feats is None:
+     ccd_kf = dump_all_ccd_keys(ccd_preprocessed_dict)
+     raise ValueError(f'<{ccd_id}> not in ccd_preprocessed_dict, but extra_feats is None. See all keys in {ccd_kf}')
+ if ccd_id not in extra_feats:
+     ccd_kf = dump_all_ccd_keys(ccd_preprocessed_dict)
+     raise ValueError(f'<{ccd_id}> not in ccd_preprocessed_dict or extra_feats. See all keys in {ccd_kf}')
_ccd_feats = extra_feats[ccd_id]
else:
_ccd_feats = ccd_preprocessed_dict[ccd_id]
num_atoms = len(_ccd_feats['position'])
- assert num_atoms > 0, f'TODO filter - Got CCD <{ccd_id}>: 0 atom nums.'
+ if num_atoms == 0:
+     raise NotImplementedError(f'TODO filter - Got CCD <{ccd_id}>: 0 atom nums.')

if ccd_id not in residue_constants.STANDARD_LIST:
features['asym_id'].append(np.array([chain_num_id] * num_atoms, dtype=np.int32))
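The net effect of this hunk: an unknown CCD code now raises a ValueError that names the dumped key file, instead of tripping an assert that would silently disappear under python -O. A hedged sketch of a pre-flight check a user could run against the same pickled dictionary (path and helper name are hypothetical; the loading convention mirrors DataPipeline.process below):

import gzip
import pickle

def ccd_id_is_known(ccd_preprocessed_path: str, ccd_id: str) -> bool:
    with gzip.open(ccd_preprocessed_path, 'rb') as fp:
        ccd_dict = pickle.load(fp)
    return ccd_id in ccd_dict

print(ccd_id_is_known('ccd_preprocessed.pkl.gz', 'XG4'))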
@@ -198,12 +213,14 @@ def process(self,

if ccd_preprocessed_dict is None:
ccd_preprocessed_dict = {}
- st_1 = time.time()
- if 'pkl.gz' in self.ccd_preprocessed_path:
-     with gzip.open(self.ccd_preprocessed_path, "rb") as fp:
-         ccd_preprocessed_dict = pickle.load(fp)
- logging.info(f'load ccd dataset done. use {time.time()-st_1}s')
+ if 'pkl.gz' not in self.ccd_preprocessed_path:
+     raise ValueError(f'Invalid ccd_preprocessed_path: {self.ccd_preprocessed_path}')
+
+ with utils.timing('Loading CCD dataset'):
+     with gzip.open(self.ccd_preprocessed_path, "rb") as fp:
+         ccd_preprocessed_dict = pickle.load(fp)

if select_mmcif_chainID is not None:
select_mmcif_chainID = set(select_mmcif_chainID)

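utils.timing replaces the manual time.time() bookkeeping here and in the hunks below. Assuming it follows the AlphaFold-style context manager of the same name (not confirmed by this diff), a minimal sketch would be:

import contextlib
import time
from absl import logging

@contextlib.contextmanager
def timing(msg: str):
    # Log entry, run the wrapped block, then log the elapsed time.
    logging.info('Started %s', msg)
    tic = time.time()
    yield
    logging.info('Finished %s in %.3f seconds', msg, time.time() - tic)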
@@ -4,15 +4,19 @@
import os
from pathlib import Path
import time, gzip, pickle
+ from typing import Optional, Tuple
import numpy as np
from absl import logging


from helixfold.common import residue_constants
from helixfold.data import parsers
- from helixfold.data import pipeline_multimer
+ from helixfold.data.tools import utils
+ from helixfold.data import pipeline_multimer, pipeline_multimer_parallel
from helixfold.data import pipeline_rna_multimer
from helixfold.data import pipeline_conf_bonds, pipeline_token_feature, pipeline_hybrid
from helixfold.data import label_utils
from concurrent.futures import ProcessPoolExecutor, as_completed

from .preprocess import digit2alphabet


@@ -330,7 +334,7 @@ def add_assembly_features(all_chain_features, ccd_preprocessed_dict, no_msa_temp
"label": label,}


- def process_chain_msa(args):
+ def process_chain_msa(args: Tuple[pipeline_multimer_parallel.DataPipeline, str, Optional[str], Optional[str], os.PathLike, os.PathLike]) -> Tuple[str, dict, str, str]:
"""
Process one chain: if a cached features file exists, use it directly; otherwise generate a new features file.
@@ -360,12 +364,12 @@ def process_chain_msa(args):
with open(features_pkl, 'rb') as f:
raw_features = pickle.load(f)
else:
- t0 = time.time()
- raw_features = data_pipeline._process_single_chain(
-     chain_id, sequence=seq, description=desc,
-     msa_output_dir=msa_output_dir,
-     is_homomer_or_monomer=False)
- print(f'[MSA/Template] {desc}; seq length: {len(seq)}; use: {time.time() - t0}')
+ with utils.timing(f'[MSA/Template]({desc}) with seq length: {len(seq)}'):
+     raw_features = data_pipeline._process_single_chain(
+         chain_id, sequence=seq, description=desc,
+         msa_output_dir=msa_output_dir,
+         is_homomer_or_monomer=False)


with open(features_pkl, 'wb') as f:
pickle.dump(raw_features, f, protocol=4)
@@ -450,19 +454,10 @@ def process_input_json(all_entitys, ccd_preprocessed_path,

## 2. multiprocessing for protein/rna MSA/Template search.
seqs_to_msa_features = {}
- logging.info('[Multiprocess] starting MSA/Template search...')
- t0 = time.time()
- with ProcessPoolExecutor() as executor:
-     futures = [executor.submit(process_chain_msa, task) for task in tasks]
-
-     for future in as_completed(futures):
-         try:
-             _, raw_features, type_chain_id, seqs = future.result()
-             seqs_to_msa_features[seqs] = raw_features
-         except Exception as exc:
-             import traceback; traceback.print_exc()
-             logging.error(f'Task generated an exception : {exc}')
- logging.info(f'[Multiprocess] All msa/template use: {time.time() - t0}')
+ with utils.timing('MSA/Template search'):
+     for task in tasks:
+         _, raw_features, type_chain_id, seqs = process_chain_msa(task)
+         seqs_to_msa_features[seqs] = raw_features

## 3. add msa_templ_feats to all_chain_features.
for type_chain_id in all_chain_features.keys():
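This last hunk is the substance of the commit title: process_chain_msa tasks previously fanned out through a ProcessPoolExecutor, but the type hint added above (pipeline_multimer_parallel.DataPipeline) suggests each task drives a pipeline that already parallelizes internally, so the pool nested one layer of multiprocessing inside another; running the top-level loop serially avoids that. A minimal sketch of the hazard with multiprocessing.Pool (illustrative only, not code from this repository):

import multiprocessing as mp

def inner(x):
    return x * x

def outer(n):
    # Pool workers are daemonic; a daemonic process may not spawn
    # children, so this nested Pool raises "daemonic processes are
    # not allowed to have children".
    with mp.Pool(2) as pool:
        return sum(pool.map(inner, range(n)))

if __name__ == '__main__':
    with mp.Pool(2) as pool:
        print(pool.map(outer, [3, 4]))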
