Skip to content

Commit

Permalink
chore: cpu only for msa only
Browse files Browse the repository at this point in the history
  • Loading branch information
YaoYinYing committed Aug 26, 2024
1 parent d7ee1da commit 1fd11ba
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions apps/protein_folding/helixfold3/helixfold/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,12 @@ def split_prediction(pred, rank):
def main(cfg: DictConfig):
logging.set_verbosity(cfg.logging_level)

if cfg.msa_only == True:
logging.warning(f'Model inference will be skipped because MSA-only mode is required.')
logging.warning(f'Use CPU only')
paddle.device.set_device("cpu")


"""main function"""
new_einsum = os.getenv("FLAGS_new_einsum", True)
print(f'>>> PaddlePaddle commit: {paddle.version.commit}')
Expand Down Expand Up @@ -505,6 +511,8 @@ def main(cfg: DictConfig):
model.helixfold.set_state_dict(pd_params['model'])
else:
model.helixfold.set_state_dict(pd_params)



if cfg.precision == "bf16" and cfg.amp_level == "O2":
raise NotImplementedError("bf16 O2 is not supported yet.")
Expand All @@ -531,6 +539,10 @@ def main(cfg: DictConfig):
with open(features_pkl, 'wb') as f:
pickle.dump(feature_dict, f, protocol=4)

if cfg.msa_only == True:
logging.warning(f'Model inference is skipped because MSA-only mode is required.')
exit()

feature_dict['feat'] = batch_convert(feature_dict['feat'], add_batch=True)
feature_dict['label'] = batch_convert(feature_dict['label'], add_batch=True)

Expand Down

0 comments on commit 1fd11ba

Please sign in to comment.