Skip to content

Commit

Permalink
fix: msa parallel for bfd
Browse files Browse the repository at this point in the history
  • Loading branch information
YaoYinYing committed Aug 18, 2024
1 parent e8582de commit cd92429
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def prediction_to_mmcif(pred_atom_pos: Union[np.ndarray, paddle.Tensor],
'-output', mmcif_path,
]

print('Launching subprocess "%s"', ' '.join(cmd))
print(f'Launching subprocess "{" " .join(cmd)}"', )

if os.path.exists('maxit.log'):
os.remove('maxit.log')
Expand All @@ -197,7 +197,7 @@ def prediction_to_mmcif(pred_atom_pos: Union[np.ndarray, paddle.Tensor],


if retcode:
# Logs have a 15k character limit, so log HHblits error line by line.
# Logs have a 15k character limit, so log Maxit error line by line.
print('Maxit failed. Maxit stderr begin:')
raise RuntimeError(f'Maxit failed\nstdout:\n{stdout.decode("utf-8")}\n\n'
f'stderr:\n{stderr[:500_000].decode("utf-8")}\n'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ diff_batch_size: -1 # Corresponds to --diff_batch_size
use_small_bfd: false # Corresponds to --use_small_bfd
msa_only: false # Only process msa

nproc_msa:
hhblits: 16
jackhmmer: 8

# File paths

input: null # Corresponds to --input_json, required field
Expand Down Expand Up @@ -52,7 +56,7 @@ template:

# Preset configuration
preset:
preset: reduced_dbs # Corresponds to --preset, choices=['reduced_dbs', 'full_dbs']
preset: full_dbs # Corresponds to --preset, choices=['reduced_dbs', 'full_dbs']

# Other configurations
other:
Expand Down
161 changes: 78 additions & 83 deletions apps/protein_folding/helixfold3/helixfold/data/pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Functions for building the input features for the HelixFold model."""

import os
from typing import Any, Mapping, MutableMapping, Optional, Sequence, Union
from typing import Any, Mapping, MutableMapping, Optional, Protocol, Sequence, Tuple, Union
from absl import logging
from helixfold.common import residue_constants
from helixfold.data import msa_identifiers
Expand All @@ -26,12 +26,16 @@
from helixfold.data.tools import hmmsearch
from helixfold.data.tools import jackhmmer
import numpy as np
from concurrent.futures import ProcessPoolExecutor, as_completed
from joblib import Parallel, delayed
# Internal import (7716).

FeatureDict = MutableMapping[str, np.ndarray]
TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]

class MsaRunner(Protocol):
  """Structural type for the MSA search tool objects passed to run_msa_tool."""

  # NOTE(review): run_msa_tool also calls query(input_fasta_path,
  # max_sto_sequences) for 'sto' output, a second positional argument that this
  # signature does not declare -- confirm whether an optional parameter belongs
  # here.
  def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:
    """Runs the MSA tool on the input fasta file."""
    ...

def make_sequence_features(
sequence: str, description: str, num_res: int) -> FeatureDict:
Expand Down Expand Up @@ -83,42 +87,41 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
features['msa_species_identifiers'] = np.array(species_ids, dtype=np.object_)
return features


def run_msa_tool(msa_runner, input_fasta_path: str, msa_out_path: str,
msa_format: str, use_precomputed_msas: bool,
max_sto_sequences: Optional[int] = None
) -> Mapping[str, Any]:
"""Runs an MSA tool, checking if output already exists first."""
if not use_precomputed_msas or not os.path.exists(msa_out_path):
if msa_format == 'sto' and max_sto_sequences is not None:
print('pipeline:',input_fasta_path,max_sto_sequences)
result = msa_runner.query(input_fasta_path, max_sto_sequences)[0] # pytype: disable=wrong-arg-count
# Yinying edited here: the separate positional args are packed into one tuple for multiprocessing tests
def run_msa_tool(args: Tuple[MsaRunner, str, str, str, bool, int]) -> Mapping[str, Any]:
if args == None:
return None
if (len_args:=len(args))!=6:
raise ValueError(f'MsaRunner must have exactly 6 arguments but got {len_args}')

(msa_runner, input_fasta_path, msa_out_path,
msa_format, use_precomputed_msas,
max_sto_sequences) = args
#print(args)
"""Runs an MSA tool, checking if output already exists first."""
if not use_precomputed_msas or not os.path.exists(msa_out_path):
if msa_format == 'sto' and max_sto_sequences > 0:
result = msa_runner.query(input_fasta_path, max_sto_sequences)[0] # pytype: disable=wrong-arg-count
else:
result = msa_runner.query(input_fasta_path)[0]
with open(msa_out_path, 'w') as f:
f.write(result[msa_format])
else:
result = msa_runner.query(input_fasta_path)[0]
with open(msa_out_path, 'w') as f:
f.write(result[msa_format])
else:
result=read_msa_result(msa_out_path,msa_format,max_sto_sequences)
return result

def read_msa_result(msa_out_path,msa_format,max_sto_sequences):
logging.warning('Reading MSA from file %s', msa_out_path)
if msa_format == 'sto' and max_sto_sequences is not None:
precomputed_msa = parsers.truncate_stockholm_msa(
msa_out_path, max_sto_sequences)
result = {'sto': precomputed_msa}
if msa_format == 'sto' and max_sto_sequences > 0:
precomputed_msa = parsers.truncate_stockholm_msa(
msa_out_path, max_sto_sequences)
result = {'sto': precomputed_msa}
else:
with open(msa_out_path, 'r') as f:
result = {msa_format: f.read()}
return result

def run_msa_tool_wrapper(args):
  """Unpacks `args` and forwards them to run_msa_tool as positional arguments.

  Args:
    args (tuple, list): The arguments to pass on to run_msa_tool.

  Returns:
    Whatever run_msa_tool returns for those arguments.
  """
  # NOTE(review): star-unpacking matches the run_msa_tool signature that takes
  # separate positional parameters; the signature that takes a single tuple
  # would receive too many arguments -- confirm which definition is live.
  return run_msa_tool(*args)
with open(msa_out_path, 'r') as f:
result = {msa_format: f.read()}
return result




class DataPipeline:
Expand All @@ -138,29 +141,38 @@ def __init__(self,
use_small_bfd: bool,
mgnify_max_hits: int = 501,
uniref_max_hits: int = 10000,
use_precomputed_msas: bool = False):
use_precomputed_msas: bool = False,
nprocs: Mapping[str, int] = {
'hhblits': 16,
'jackhmmer': 8,
}):
"""Initializes the data pipeline. Constructs a feature dict for a given FASTA file."""
self._use_small_bfd = use_small_bfd
self.nprocs=nprocs
self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=uniref90_database_path)
database_path=uniref90_database_path, n_cpu=self.nprocs.get('jackhmmer', 8))
if use_small_bfd:
self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
self.bfd_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=small_bfd_database_path)
database_path=small_bfd_database_path, n_cpu=self.nprocs.get('jackhmmer', 8))
else:
self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
self.bfd_runner = hhblits.HHBlits(
binary_path=hhblits_binary_path,
databases=[bfd_database_path, uniclust30_database_path])
databases=[bfd_database_path, uniclust30_database_path], n_cpu=self.nprocs.get('hhblits', 8))
self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=mgnify_database_path)
database_path=mgnify_database_path, n_cpu=self.nprocs.get('jackhmmer', 8))
self.template_searcher = template_searcher
self.template_featurizer = template_featurizer
self.mgnify_max_hits = mgnify_max_hits
self.uniref_max_hits = uniref_max_hits
self.use_precomputed_msas = use_precomputed_msas

def parallel_msa_joblib(self, func, input_args: list):
  """Runs `func` once per entry of `input_args` through joblib.

  Args:
    func: Callable taking a single argument; it is invoked as func(args) for
      every element of input_args.
    input_args: One argument object per call.

  Returns:
    A list with one result per entry of input_args (joblib's default list
    output keeps the submission order), or an empty list when input_args is
    empty.
  """
  # joblib rejects n_jobs=0, so an empty task list must never reach Parallel.
  if not input_args:
    return []
  # One worker per task so that all submitted tasks can run at the same time.
  return Parallel(n_jobs=len(input_args))(
      delayed(func)(args) for args in input_args)


def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
"""Runs alignment tools on the input sequence and creates features."""
with open(input_fasta_path) as f:
Expand Down Expand Up @@ -202,44 +214,27 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
input_fasta_path,
os.path.join(msa_output_dir, 'mgnify_hits.sto'),
'sto',
self.use_precomputed_msas))
self.use_precomputed_msas,
self.mgnify_max_hits))

if self._use_small_bfd:
msa_tasks.append((
self.jackhmmer_small_bfd_runner,
input_fasta_path,
os.path.join(msa_output_dir, 'small_bfd_hits.sto'),
'sto',
self.use_precomputed_msas))
else:
msa_tasks.append((
self.hhblits_bfd_uniclust_runner,
input_fasta_path,
os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m'),
'a3m',
self.use_precomputed_msas))

msa_results = {}
with ProcessPoolExecutor() as executor:
futures = {executor.submit(run_msa_tool_wrapper, msa_task): msa_task for msa_task in msa_tasks}

for future in as_completed(futures):
task = futures[future]
try:
result = future.result()
if 'uniref90_hits.sto' in task[2]:
msa_results['uniref90'] = result
elif 'mgnify_hits.sto' in task[2]:
msa_results['mgnify'] = result
elif 'small_bfd_hits.sto' in task[2]:
msa_results['small_bfd'] = result
elif 'bfd_uniclust_hits.a3m' in task[2]:
msa_results['bfd_uniclust'] = result

except Exception as exc:
print(f'Task {task} generated an exception : {exc}')

msa_for_templates = msa_results['uniref90']['sto']

msa_tasks.append((
self.bfd_runner,
input_fasta_path,
os.path.join(msa_output_dir, 'small_bfd_hits.sto' if self._use_small_bfd else 'bfd_uniclust_hits.a3m'),
'sto' if self._use_small_bfd else 'a3m',
self.use_precomputed_msas,
0))


[
jackhmmer_uniref90_result,
jackhmmer_mgnify_result,
bfd_result,
] = self.parallel_msa_joblib(func=run_msa_tool,
input_args=msa_tasks)

msa_for_templates = jackhmmer_uniref90_result['sto']
msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates)
msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa(msa_for_templates)

Expand All @@ -257,16 +252,16 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
with open(pdb_hits_out_path, 'w') as f:
f.write(pdb_templates_result)

uniref90_msa = parsers.parse_stockholm(msa_results['uniref90']['sto'])
mgnify_msa = parsers.parse_stockholm(msa_results['mgnify']['sto'])
uniref90_msa = parsers.parse_stockholm(jackhmmer_uniref90_result['sto'])
mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto'])

pdb_template_hits = self.template_searcher.get_template_hits(
output_string=pdb_templates_result, input_sequence=input_sequence)

if self._use_small_bfd:
bfd_msa = parsers.parse_stockholm(msa_results['small_bfd']['sto'])
bfd_msa = parsers.parse_stockholm(bfd_result['sto'])
else:
raise ValueError("Doesn't support full BFD yet.")
bfd_msa = parsers.parse_a3m(bfd_result['a3m'])

templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence,
Expand Down
4 changes: 3 additions & 1 deletion apps/protein_folding/helixfold3/helixfold/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@ def get_msa_templates_pipeline(cfg: DictConfig) -> Dict:
template_searcher=template_searcher,
template_featurizer=template_featurizer,
use_small_bfd=cfg.use_small_bfd,
use_precomputed_msas=use_precomputed_msas)
use_precomputed_msas=use_precomputed_msas,
nprocs=cfg.nproc_msa,
)

prot_data_pipeline = pipeline_multimer.DataPipeline(
monomer_data_pipeline=monomer_data_pipeline,
Expand Down
1 change: 1 addition & 0 deletions apps/protein_folding/helixfold3/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ rdkit-pypi = "2022.9.5"
posebusters = "*"
hydra-core= "^1.3.2"
omegaconf = "^2.3.0"
joblib = "1.4.2"



Expand Down

0 comments on commit cd92429

Please sign in to comment.