diff --git a/apps/protein_folding/helixfold3/helixfold/inference.py b/apps/protein_folding/helixfold3/helixfold/inference.py index 8a862279..57fd692c 100644 --- a/apps/protein_folding/helixfold3/helixfold/inference.py +++ b/apps/protein_folding/helixfold3/helixfold/inference.py @@ -453,6 +453,12 @@ def split_prediction(pred, rank): def main(cfg: DictConfig): logging.set_verbosity(cfg.logging_level) + if cfg.msa_only == True: + logging.warning(f'Model inference will be skipped because MSA-only mode is required.') + logging.warning(f'Use CPU only') + paddle.device.set_device("cpu") + + """main function""" new_einsum = os.getenv("FLAGS_new_einsum", True) print(f'>>> PaddlePaddle commit: {paddle.version.commit}') @@ -505,6 +511,8 @@ def main(cfg: DictConfig): model.helixfold.set_state_dict(pd_params['model']) else: model.helixfold.set_state_dict(pd_params) + + if cfg.precision == "bf16" and cfg.amp_level == "O2": raise NotImplementedError("bf16 O2 is not supported yet.") @@ -531,6 +539,10 @@ def main(cfg: DictConfig): with open(features_pkl, 'wb') as f: pickle.dump(feature_dict, f, protocol=4) + if cfg.msa_only == True: + logging.warning(f'Model inference is skipped because MSA-only mode is required.') + exit() + feature_dict['feat'] = batch_convert(feature_dict['feat'], add_batch=True) feature_dict['label'] = batch_convert(feature_dict['label'], add_batch=True)