From 4daee6ccdd349dd8dab61f1039ed8ef5c5606eb1 Mon Sep 17 00:00:00 2001
From: YaoYinYing <33014714+YaoYinYing@users.noreply.github.com>
Date: Wed, 28 Aug 2024 13:03:58 +0800
Subject: [PATCH] chore: add batch mode against input dir

---
 .../helixfold/config/helixfold.yaml          |   2 +-
 .../infer_scripts/feature_processing_aa.py   |   4 +-
 .../helixfold3/helixfold/inference.py        | 275 ++++++++++--------
 3 files changed, 157 insertions(+), 124 deletions(-)

diff --git a/apps/protein_folding/helixfold3/helixfold/config/helixfold.yaml b/apps/protein_folding/helixfold3/helixfold/config/helixfold.yaml
index 6b40ea23..fb07421c 100644
--- a/apps/protein_folding/helixfold3/helixfold/config/helixfold.yaml
+++ b/apps/protein_folding/helixfold3/helixfold/config/helixfold.yaml
@@ -20,7 +20,7 @@ nproc_msa:
 
 # File paths
-input: null # Corresponds to --input_json, required field
+input: null # Input file or directory, required. If a directory is given, HF3 runs against every supported input file in it.
 output: null # Corresponds to --output_dir, required field
 override: false # Set true to override existing msa output directory
 
diff --git a/apps/protein_folding/helixfold3/helixfold/infer_scripts/feature_processing_aa.py b/apps/protein_folding/helixfold3/helixfold/infer_scripts/feature_processing_aa.py
index 31272c32..28adea9b 100644
--- a/apps/protein_folding/helixfold3/helixfold/infer_scripts/feature_processing_aa.py
+++ b/apps/protein_folding/helixfold3/helixfold/infer_scripts/feature_processing_aa.py
@@ -369,12 +369,10 @@ def process_chain_msa(args: tuple[pipeline_multimer_parallel.DataPipeline, str,
     return chain_id, raw_features, desc, seq
 
 
-def process_input_json(all_entitys: List[Entity], ccd_preprocessed_path,
+def process_input_json(all_entitys: List[Entity], ccd_preprocessed_dict,
                         msa_templ_data_pipeline_dict, msa_output_dir,
                         no_msa_templ_feats=False):
-    ## load ccd dict.
-    ccd_preprocessed_dict = pipeline_conf_bonds.load_ccd_dict(ccd_preprocessed_path)
     all_chain_features = {}
     sequence_features = {}
     num_chains = 0
diff --git a/apps/protein_folding/helixfold3/helixfold/inference.py b/apps/protein_folding/helixfold3/helixfold/inference.py
index 57fd692c..dc9eb85b 100644
--- a/apps/protein_folding/helixfold3/helixfold/inference.py
+++ b/apps/protein_folding/helixfold3/helixfold/inference.py
@@ -13,6 +13,7 @@
 # limitations under the License.
"""Inference scripts.""" +from dataclasses import dataclass import re import os import copy @@ -42,7 +43,7 @@ from helixfold.data.tools import hmmsearch from helixfold.data import templates from helixfold.utils.utils import get_custom_amp_list -from typing import Dict +from typing import Dict, Mapping from helixfold.infer_scripts import feature_processing_aa, preprocess from helixfold.infer_scripts.tools import mmcif_writer @@ -448,153 +449,187 @@ def split_prediction(pred, rank): return prediction +@dataclass +class HelixFold: -@hydra.main(version_base=None, config_path=os.path.join(script_path,'config',),config_name='helixfold') -def main(cfg: DictConfig): - logging.set_verbosity(cfg.logging_level) + cfg: DictConfig - if cfg.msa_only == True: - logging.warning(f'Model inference will be skipped because MSA-only mode is required.') - logging.warning(f'Use CPU only') - paddle.device.set_device("cpu") - + model: RunModel =None + model_config: DictConfig=None - """main function""" - new_einsum = os.getenv("FLAGS_new_einsum", True) - print(f'>>> PaddlePaddle commit: {paddle.version.commit}') - print(f'>>> FLAGS_new_einsum: {new_einsum}') - print(f'>>> config:\n{cfg}') + ccd_dict: Mapping=None - ## check maxit binary path - maxit_binary=resolve_bin_path(cfg.other.maxit_binary,'maxit') - - RCSBROOT=os.path.join(os.path.dirname(maxit_binary), '..') - os.environ['RCSBROOT']=RCSBROOT + def __post_init__(self) -> None: - ## check obabel - obabel_bin=resolve_bin_path(cfg.bin.obabel,'obabel') - os.environ['OBABEL_BIN']=obabel_bin + logging.set_verbosity(self.cfg.logging_level) + ccd_preprocessed_path = self.cfg.db.ccd_preprocessed + self.ccd_dict=load_ccd_dict(ccd_preprocessed_path) - all_entitys = preprocess_json_entity(cfg.input, cfg.output) - - ### Set seed for reproducibility - seed = cfg.seed - if seed is None: - seed = np.random.randint(10000000) - else: - logging.warning('Seed is only used for reproduction') - init_seed(seed) - use_small_bfd = cfg.preset.preset == 'reduced_dbs' - setattr(cfg, 'use_small_bfd', use_small_bfd) - if use_small_bfd: - assert cfg.db.small_bfd is not None - else: - assert cfg.db.bfd is not None - assert cfg.db.uniclust30 is not None + if self.cfg.msa_only == True: + logging.warning(f'Model inference will be skipped because MSA-only mode is required.') + logging.warning(f'Use CPU only') + paddle.device.set_device("cpu") + + + """main function""" + new_einsum = os.getenv("FLAGS_new_einsum", True) + print(f'>>> PaddlePaddle commit: {paddle.version.commit}') + print(f'>>> FLAGS_new_einsum: {new_einsum}') + print(f'>>> config:\n{self.cfg}') - logging.info('Getting MSA/Template Pipelines...') - msa_templ_data_pipeline_dict = get_msa_templates_pipeline(cfg=cfg) + ## check maxit binary path + maxit_binary=resolve_bin_path(self.cfg.other.maxit_binary,'maxit') - ### Create model - model_config = config.model_config(cfg.CONFIG_DIFFS) - logging.warning(f'>>> Model config: \n{model_config}\n\n') + RCSBROOT=os.path.join(os.path.dirname(maxit_binary), '..') + os.environ['RCSBROOT']=RCSBROOT + + ## check obabel + obabel_bin=resolve_bin_path(self.cfg.bin.obabel,'obabel') + os.environ['OBABEL_BIN']=obabel_bin + + use_small_bfd = self.cfg.preset.preset == 'reduced_dbs' + + # fix to small bfd setting + self.cfg.use_small_bfd=use_small_bfd - model = RunModel(model_config) + if self.cfg.use_small_bfd: + assert self.cfg.db.small_bfd is not None + else: + assert self.cfg.db.bfd is not None + assert self.cfg.db.uniclust30 is not None + + ### Create model + self.model_config = 
config.model_config(self.cfg.CONFIG_DIFFS) + logging.warning(f'>>> Model config: \n{self.model_config}\n\n') + + self.model = RunModel(self.model_config) - if (not cfg.weight_path is None) and (cfg.weight_path != ""): - print(f"Load pretrain model from {cfg.weight_path}") - pd_params = paddle.load(cfg.weight_path) + if (not self.cfg.weight_path is None) and (self.cfg.weight_path != ""): + print(f"Load pretrain model from {self.cfg.weight_path}") + pd_params = paddle.load(self.cfg.weight_path) + + has_opt = 'optimizer' in pd_params + if has_opt: + self.model.helixfold.set_state_dict(pd_params['model']) + else: + self.model.helixfold.set_state_dict(pd_params) + + + if self.cfg.precision == "bf16" and self.cfg.amp_level == "O2": + raise NotImplementedError("bf16 O2 is not supported yet.") + + def fold(self, entity: str): + all_entitys = preprocess_json_entity(entity, self.cfg.output) - has_opt = 'optimizer' in pd_params - if has_opt: - model.helixfold.set_state_dict(pd_params['model']) + ### Set seed for reproducibility + seed = self.cfg.seed + if seed is None: + seed = np.random.randint(10000000) else: - model.helixfold.set_state_dict(pd_params) + logging.warning('Seed is only used for reproduction') + init_seed(seed) - - - if cfg.precision == "bf16" and cfg.amp_level == "O2": - raise NotImplementedError("bf16 O2 is not supported yet.") - - print(f"============ Data Loading ============") - job_base = pathlib.Path(cfg.input).stem - output_dir_base = pathlib.Path(cfg.output).joinpath(job_base) - msa_output_dir = output_dir_base.joinpath('msas') - msa_output_dir.mkdir(parents=True, exist_ok=True) - - features_pkl = output_dir_base.joinpath('final_features.pkl') - if features_pkl.exists() and not cfg.override: - with open(features_pkl, 'rb') as f: - logging.info(f'Load features from precomputed {features_pkl}') - feature_dict = pickle.load(f) - else: - feature_dict = feature_processing_aa.process_input_json( - all_entitys, - ccd_preprocessed_path=cfg.db.ccd_preprocessed, - msa_templ_data_pipeline_dict=msa_templ_data_pipeline_dict, - msa_output_dir=msa_output_dir) - - # save features - with open(features_pkl, 'wb') as f: - pickle.dump(feature_dict, f, protocol=4) - - if cfg.msa_only == True: - logging.warning(f'Model inference is skipped because MSA-only mode is required.') - exit() - - feature_dict['feat'] = batch_convert(feature_dict['feat'], add_batch=True) - feature_dict['label'] = batch_convert(feature_dict['label'], add_batch=True) - - print(f"============ Start Inference ============") - - infer_times = cfg.infer_times - if cfg.diff_batch_size > 0: - model_config.model.heads.diffusion_module.test_diff_batch_size = cfg.diff_batch_size - diff_batch_size = model_config.model.heads.diffusion_module.test_diff_batch_size - logging.info(f'Inference {infer_times} Times...') - logging.info(f"Diffusion batch size {diff_batch_size}...\n") - all_pred_path = [] - for infer_id in range(infer_times): + + logging.info('Getting MSA/Template Pipelines...') + msa_templ_data_pipeline_dict = get_msa_templates_pipeline(cfg=self.cfg) + + print(f"============ Data Loading ============") + job_base = pathlib.Path(entity).stem + output_dir_base = pathlib.Path(self.cfg.output).joinpath(job_base) + msa_output_dir = output_dir_base.joinpath('msas') + msa_output_dir.mkdir(parents=True, exist_ok=True) + + features_pkl = output_dir_base.joinpath('final_features.pkl') + if features_pkl.exists() and not self.cfg.override: + with open(features_pkl, 'rb') as f: + logging.info(f'Load features from precomputed 
{features_pkl}') + feature_dict = pickle.load(f) + else: + feature_dict = feature_processing_aa.process_input_json( + all_entitys, + ccd_preprocessed_dict=self.ccd_dict, + msa_templ_data_pipeline_dict=msa_templ_data_pipeline_dict, + msa_output_dir=msa_output_dir) + + # save features + with open(features_pkl, 'wb') as f: + pickle.dump(feature_dict, f, protocol=4) + + if self.cfg.msa_only == True: + logging.warning(f'Model inference is skipped because MSA-only mode is required.') + return + + feature_dict['feat'] = batch_convert(feature_dict['feat'], add_batch=True) + feature_dict['label'] = batch_convert(feature_dict['label'], add_batch=True) + + print(f"============ Start Inference ============") - logging.info(f'Start {infer_id}-th inference...\n') - prediction = eval(cfg, model, feature_dict) + infer_times = self.cfg.infer_times + if self.cfg.diff_batch_size > 0: + self.model_config.model.heads.diffusion_module.test_diff_batch_size = self.cfg.diff_batch_size + diff_batch_size = self.model_config.model.heads.diffusion_module.test_diff_batch_size + logging.info(f'Inference {infer_times} Times...') + logging.info(f"Diffusion batch size {diff_batch_size}...\n") + all_pred_path = [] + for infer_id in range(infer_times): + + logging.info(f'Start {infer_id}-th inference...\n') + prediction = eval(self.cfg, self.model, feature_dict) + + # save result + prediction = split_prediction(prediction, diff_batch_size) + for rank_id in range(diff_batch_size): + json_name = job_base + f'-pred-{str(infer_id + 1)}-{str(rank_id + 1)}' + output_dir = pathlib.Path(output_dir_base).joinpath(json_name) + output_dir.mkdir(parents=True, exist_ok=True) + save_result(entry_name=job_base, + feature_dict=feature_dict, + prediction=prediction[rank_id], + output_dir=output_dir, + maxit_bin=self.cfg.other.maxit_binary) + all_pred_path.append(output_dir) - # save result - prediction = split_prediction(prediction, diff_batch_size) - for rank_id in range(diff_batch_size): - json_name = job_base + f'-pred-{str(infer_id + 1)}-{str(rank_id + 1)}' - output_dir = pathlib.Path(output_dir_base).joinpath(json_name) - output_dir.mkdir(parents=True, exist_ok=True) - save_result(entry_name=job_base, - feature_dict=feature_dict, - prediction=prediction[rank_id], - output_dir=output_dir, - maxit_bin=cfg.other.maxit_binary) - all_pred_path.append(output_dir) + # final ranking + print(f'============ Ranking ! ============') + ranking_all_predictions(all_pred_path) + print(f'============ Inference finished ! ============') - # final ranking - print(f'============ Ranking ! ============') - ranking_all_predictions(all_pred_path) - print(f'============ Inference finished ! 
============') @hydra.main(version_base=None, config_path=os.path.join(script_path,'config',),config_name='helixfold') -def show_atom_id_ccd(cfg: DictConfig): - ccd_preprocessed_path = cfg.db.ccd_preprocessed +def main(cfg: DictConfig): + hf_runner=HelixFold(cfg=cfg) - ccd_id=cfg.ccd_id - if len(ccd_id) <= 3 and ccd_id in (ccd_dict:=load_ccd_dict(ccd_preprocessed_path)): - logging.info(f'Atoms in {ccd_id}: {ccd_dict[ccd_id]["atom_ids"]}') + if os.path.isfile(cfg.input): + logging.info(f'Starting inference on {cfg.input}') + hf_runner.fold(cfg.input) return + elif os.path.isdir(cfg.input): + logging.info(f'Starting inference on all files in {cfg.input}') + for f in [i for i in os.listdir(cfg.input) if any(i.endswith(p) for p in ['json', 'jsonl', 'json.gz', 'jsonl.gz'])]: + logging.info(f'Processing {f}') + try: + hf_runner.fold(os.path.join(cfg.input,f)) + except (ValueError, AssertionError, RuntimeError, FileNotFoundError, MemoryError) as e : + logging.error(f'Error processing {f}: {e}') + continue + return + raise ValueError(f'Invalid input file/dir: {cfg.input}') - - +@hydra.main(version_base=None, config_path=os.path.join(script_path,'config',),config_name='helixfold') +def show_atom_id_ccd(cfg: DictConfig): + + ccd_preprocessed_path = cfg.db.ccd_preprocessed + ccd_id=cfg.ccd_id + if len(ccd_id) <= 3 and ccd_id in (ccd_dict:=load_ccd_dict(ccd_preprocessed_path)): + logging.info(f'Atoms in {ccd_id}: {ccd_dict[ccd_id]["atom_ids"]}') + return if __name__ == '__main__':
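
Note (not part of the patch): the batch mode treats `input` as either a single job file or a directory of job files. The sketch below is a minimal, self-contained mirror of the file-selection logic that the patched `main()` applies, so readers can see exactly which files a directory run would pick up. `collect_jobs` and `SUPPORTED_SUFFIXES` are hypothetical names introduced here for illustration only; they do not exist in the codebase.

import os

# Suffixes accepted by the batch loop in the patched main().
SUPPORTED_SUFFIXES = ('json', 'jsonl', 'json.gz', 'jsonl.gz')

def collect_jobs(input_path: str) -> list[str]:
    """Return the job files a run would fold, mirroring the patched main()."""
    if os.path.isfile(input_path):
        # A single file is folded directly.
        return [input_path]
    if os.path.isdir(input_path):
        # A directory is scanned for supported job files; each one is folded
        # in its own try/except, so a failing job is logged and skipped
        # rather than aborting the whole batch.
        return [os.path.join(input_path, f)
                for f in os.listdir(input_path)
                if any(f.endswith(s) for s in SUPPORTED_SUFFIXES)]
    raise ValueError(f'Invalid input file/dir: {input_path}')

if __name__ == '__main__':
    # e.g. python collect_jobs_sketch.py ./my_jobs_dir
    import sys
    print(collect_jobs(sys.argv[1]))

Design-wise, the refactor moves the one-time setup (loading the CCD dictionary, building the model config, and loading weights) into HelixFold.__post_init__, so a directory run pays that cost once and then calls fold() per job. With the Hydra config in helixfold.yaml, a batch run would presumably be requested by overriding `input` with a directory path (e.g. `input=./my_jobs output=./results`), while a single-file path keeps the previous behaviour.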