Skip to content

Commit

Permalink
fix: multimer msa parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
YaoYinYing committed Aug 18, 2024
1 parent cd92429 commit 8c22e15
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ use_small_bfd: false # Corresponds to --use_small_bfd
msa_only: false # Only process msa

nproc_msa:
hhblits: 16
jackhmmer: 8
hhblits: 16 # Number of processors used by hhblits
jackhmmer: 8 # Number of processors used by jackhmmer

# File paths

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from helixfold.data import feature_processing
from helixfold.data import msa_pairing
from helixfold.data import parsers
from helixfold.data import pipeline
#from helixfold.data import pipeline
from helixfold.data import pipeline_parallel as pipeline
from helixfold.data.tools import jackhmmer
import numpy as np
import multiprocessing
Expand Down Expand Up @@ -197,24 +198,38 @@ def _process_single_chain(
with temp_fasta_file(chain_fasta_str) as chain_fasta_path:
logging.info('Running monomer pipeline on chain %s: %s',
chain_id, description)

# We only construct the pairing features if there are 2 or more unique
# sequences.
self.jackhmmer_uniprot_args=(
self._uniprot_msa_runner,
str(chain_fasta_path),
os.path.join(chain_msa_output_dir, 'uniprot_hits.sto'),
'sto',
self.use_precomputed_msas,
0
)

chain_features = self._monomer_data_pipeline.process(
input_fasta_path=chain_fasta_path,
msa_output_dir=chain_msa_output_dir)
msa_output_dir=chain_msa_output_dir,
other_args=self.jackhmmer_uniprot_args if not is_homomer_or_monomer else None)

# We only construct the pairing features if there are 2 or more unique
# sequences.
if not is_homomer_or_monomer:
all_seq_msa_features = self._all_seq_msa_features(chain_fasta_path,
chain_msa_output_dir)
all_seq_msa_features = self._all_seq_msa_features(chain_msa_output_dir)
chain_features.update(all_seq_msa_features)
return chain_features

def _all_seq_msa_features(self, input_fasta_path, msa_output_dir):
def _all_seq_msa_features(self, msa_output_dir):
"""Get MSA features for unclustered uniprot, for pairing."""
out_path = os.path.join(msa_output_dir, 'uniprot_hits.sto')
result = pipeline.run_msa_tool(
self._uniprot_msa_runner, input_fasta_path, out_path, 'sto',
self.use_precomputed_msas)
# edited by yinying to adapt to the multiprocess version of run_msa_tool function
result = pipeline.read_msa_result(
msa_out_path=os.path.join(msa_output_dir, 'uniprot_hits.sto'),
msa_format='sto',
max_sto_sequences=0
)
msa = parsers.parse_stockholm(result['sto'])
msa = msa.truncate(max_seqs=self._max_uniprot_hits)
all_seq_features = pipeline.make_msa_features([msa])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,17 @@
TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]

class MsaRunner(Protocol):
    """Structural interface shared by the MSA search tool wrappers.

    Any object exposing an ``n_cpu`` attribute and a ``query`` method with
    the signature below satisfies this protocol; no inheritance is required.
    """

    # Presumably the number of CPUs the wrapped tool is configured to use;
    # confirm against the concrete runner classes.
    n_cpu: int

    def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:
        """Search with the MSA tool, using the given FASTA file as the query.

        Args:
          input_fasta_path: Path to the FASTA file holding the query.

        Returns:
          A sequence of mappings, one per result produced by the tool.
        """

def check_used_ncpus(used: list[int]):
    """Warn when the summed CPU requests exceed the CPUs of this machine.

    This is advisory only: a warning is logged, nothing is raised and
    nothing is returned.

    Args:
      used: CPU counts requested by the individual tools.
    """
    ncpus_sum = sum(used)
    # os.cpu_count() returns None when the count cannot be determined;
    # comparing an int with None would raise TypeError, so skip the check.
    available = os.cpu_count()
    if available is None:
        return
    if ncpus_sum > available:
        logging.warning(
            "The number of used CPUs(%s) is larger than the number of "
            "available CPUs(%s).", ncpus_sum, available)

def make_sequence_features(
sequence: str, description: str, num_res: int) -> FeatureDict:
"""Constructs a feature dict of sequence features."""
Expand Down Expand Up @@ -170,10 +177,10 @@ def __init__(self,
self.use_precomputed_msas = use_precomputed_msas

def parallel_msa_joblib(self, func, input_args: list):
    """Run ``func`` once per entry of ``input_args`` using joblib workers.

    Args:
      func: Callable invoked as ``func(args)``; each entry of ``input_args``
        is handed over as a single positional argument.
      input_args: One entry per job. Its length is also used as the number
        of joblib workers, so every job gets its own worker.

    Returns:
      The list of results produced by joblib for the submitted jobs, or an
      empty list when there are no jobs.
    """
    if not input_args:
        # A worker count of 0 is rejected by joblib; with nothing to run
        # there is nothing to parallelise.
        return []
    return Parallel(len(input_args), verbose=100)(
        delayed(func)(args) for args in input_args)


def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
def process(self, input_fasta_path: str, msa_output_dir: str,other_args: Optional[tuple] = None) -> FeatureDict:
"""Runs alignment tools on the input sequence and creates features."""
with open(input_fasta_path) as f:
input_fasta_str = f.read()
Expand All @@ -186,7 +193,7 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
num_res = len(input_sequence)


msa_tasks = []
msa_tasks: list[Tuple[MsaRunner, str, str, str, bool, int]] = []
"""uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto')
jackhmmer_uniref90_result = run_msa_tool(
msa_runner=self.jackhmmer_uniref90_runner,
Expand Down Expand Up @@ -225,12 +232,16 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
'sto' if self._use_small_bfd else 'a3m',
self.use_precomputed_msas,
0))

msa_tasks.append(other_args)

check_used_ncpus(used=[mask[0].n_cpu for mask in msa_tasks if hasattr(mask[0], 'n_cpu')])

[
jackhmmer_uniref90_result,
jackhmmer_mgnify_result,
bfd_result,

] = self.parallel_msa_joblib(func=run_msa_tool,
input_args=msa_tasks)

Expand Down

0 comments on commit 8c22e15

Please sign in to comment.