Skip to content

Commit

Permalink
chore: add batch mode against input dir
Browse files Browse the repository at this point in the history
  • Loading branch information
YaoYinYing committed Aug 28, 2024
1 parent 1fd11ba commit 4daee6c
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 124 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ nproc_msa:

# File paths

input: null # Corresponds to --input_json, required field
input: null # Input file or directory, required field. If it is a directory, HF3 is run on every file in it whose name ends with json, jsonl, json.gz or jsonl.gz.
output: null # Corresponds to --output_dir, required field
override: false # Set true to override existing msa output directory

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -369,12 +369,10 @@ def process_chain_msa(args: tuple[pipeline_multimer_parallel.DataPipeline, str,
return chain_id, raw_features, desc, seq


def process_input_json(all_entitys: List[Entity], ccd_preprocessed_path,
def process_input_json(all_entitys: List[Entity], ccd_preprocessed_dict,
msa_templ_data_pipeline_dict, msa_output_dir,
no_msa_templ_feats=False):

## load ccd dict.
ccd_preprocessed_dict = pipeline_conf_bonds.load_ccd_dict(ccd_preprocessed_path)
all_chain_features = {}
sequence_features = {}
num_chains = 0
Expand Down
275 changes: 155 additions & 120 deletions apps/protein_folding/helixfold3/helixfold/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Inference scripts."""
from dataclasses import dataclass
import re
import os
import copy
Expand Down Expand Up @@ -42,7 +43,7 @@
from helixfold.data.tools import hmmsearch
from helixfold.data import templates
from helixfold.utils.utils import get_custom_amp_list
from typing import Dict
from typing import Dict, Mapping
from helixfold.infer_scripts import feature_processing_aa, preprocess
from helixfold.infer_scripts.tools import mmcif_writer

Expand Down Expand Up @@ -448,153 +449,187 @@ def split_prediction(pred, rank):

return prediction

@dataclass
class HelixFold:

@hydra.main(version_base=None, config_path=os.path.join(script_path,'config',),config_name='helixfold')
def main(cfg: DictConfig):
logging.set_verbosity(cfg.logging_level)
cfg: DictConfig

if cfg.msa_only == True:
logging.warning(f'Model inference will be skipped because MSA-only mode is required.')
logging.warning(f'Use CPU only')
paddle.device.set_device("cpu")

model: RunModel =None
model_config: DictConfig=None

"""main function"""
new_einsum = os.getenv("FLAGS_new_einsum", True)
print(f'>>> PaddlePaddle commit: {paddle.version.commit}')
print(f'>>> FLAGS_new_einsum: {new_einsum}')
print(f'>>> config:\n{cfg}')
ccd_dict: Mapping=None

## check maxit binary path
maxit_binary=resolve_bin_path(cfg.other.maxit_binary,'maxit')

RCSBROOT=os.path.join(os.path.dirname(maxit_binary), '..')
os.environ['RCSBROOT']=RCSBROOT
def __post_init__(self) -> None:

## check obabel
obabel_bin=resolve_bin_path(cfg.bin.obabel,'obabel')
os.environ['OBABEL_BIN']=obabel_bin
logging.set_verbosity(self.cfg.logging_level)
ccd_preprocessed_path = self.cfg.db.ccd_preprocessed
self.ccd_dict=load_ccd_dict(ccd_preprocessed_path)

all_entitys = preprocess_json_entity(cfg.input, cfg.output)

### Set seed for reproducibility
seed = cfg.seed
if seed is None:
seed = np.random.randint(10000000)
else:
logging.warning('Seed is only used for reproduction')
init_seed(seed)

use_small_bfd = cfg.preset.preset == 'reduced_dbs'
setattr(cfg, 'use_small_bfd', use_small_bfd)
if use_small_bfd:
assert cfg.db.small_bfd is not None
else:
assert cfg.db.bfd is not None
assert cfg.db.uniclust30 is not None
if self.cfg.msa_only == True:
logging.warning(f'Model inference will be skipped because MSA-only mode is required.')
logging.warning(f'Use CPU only')
paddle.device.set_device("cpu")


"""main function"""
new_einsum = os.getenv("FLAGS_new_einsum", True)
print(f'>>> PaddlePaddle commit: {paddle.version.commit}')
print(f'>>> FLAGS_new_einsum: {new_einsum}')
print(f'>>> config:\n{self.cfg}')

logging.info('Getting MSA/Template Pipelines...')
msa_templ_data_pipeline_dict = get_msa_templates_pipeline(cfg=cfg)
## check maxit binary path
maxit_binary=resolve_bin_path(self.cfg.other.maxit_binary,'maxit')

### Create model
model_config = config.model_config(cfg.CONFIG_DIFFS)
logging.warning(f'>>> Model config: \n{model_config}\n\n')
RCSBROOT=os.path.join(os.path.dirname(maxit_binary), '..')
os.environ['RCSBROOT']=RCSBROOT

## check obabel
obabel_bin=resolve_bin_path(self.cfg.bin.obabel,'obabel')
os.environ['OBABEL_BIN']=obabel_bin

use_small_bfd = self.cfg.preset.preset == 'reduced_dbs'

# fix to small bfd setting
self.cfg.use_small_bfd=use_small_bfd

model = RunModel(model_config)
if self.cfg.use_small_bfd:
assert self.cfg.db.small_bfd is not None
else:
assert self.cfg.db.bfd is not None
assert self.cfg.db.uniclust30 is not None

### Create model
self.model_config = config.model_config(self.cfg.CONFIG_DIFFS)
logging.warning(f'>>> Model config: \n{self.model_config}\n\n')

self.model = RunModel(self.model_config)

if (not cfg.weight_path is None) and (cfg.weight_path != ""):
print(f"Load pretrain model from {cfg.weight_path}")
pd_params = paddle.load(cfg.weight_path)
if (not self.cfg.weight_path is None) and (self.cfg.weight_path != ""):
print(f"Load pretrain model from {self.cfg.weight_path}")
pd_params = paddle.load(self.cfg.weight_path)

has_opt = 'optimizer' in pd_params
if has_opt:
self.model.helixfold.set_state_dict(pd_params['model'])
else:
self.model.helixfold.set_state_dict(pd_params)


if self.cfg.precision == "bf16" and self.cfg.amp_level == "O2":
raise NotImplementedError("bf16 O2 is not supported yet.")

def fold(self, entity: str):
all_entitys = preprocess_json_entity(entity, self.cfg.output)

has_opt = 'optimizer' in pd_params
if has_opt:
model.helixfold.set_state_dict(pd_params['model'])
### Set seed for reproducibility
seed = self.cfg.seed
if seed is None:
seed = np.random.randint(10000000)
else:
model.helixfold.set_state_dict(pd_params)
logging.warning('Seed is only used for reproduction')
init_seed(seed)



if cfg.precision == "bf16" and cfg.amp_level == "O2":
raise NotImplementedError("bf16 O2 is not supported yet.")

print(f"============ Data Loading ============")
job_base = pathlib.Path(cfg.input).stem
output_dir_base = pathlib.Path(cfg.output).joinpath(job_base)
msa_output_dir = output_dir_base.joinpath('msas')
msa_output_dir.mkdir(parents=True, exist_ok=True)

features_pkl = output_dir_base.joinpath('final_features.pkl')
if features_pkl.exists() and not cfg.override:
with open(features_pkl, 'rb') as f:
logging.info(f'Load features from precomputed {features_pkl}')
feature_dict = pickle.load(f)
else:
feature_dict = feature_processing_aa.process_input_json(
all_entitys,
ccd_preprocessed_path=cfg.db.ccd_preprocessed,
msa_templ_data_pipeline_dict=msa_templ_data_pipeline_dict,
msa_output_dir=msa_output_dir)

# save features
with open(features_pkl, 'wb') as f:
pickle.dump(feature_dict, f, protocol=4)

if cfg.msa_only == True:
logging.warning(f'Model inference is skipped because MSA-only mode is required.')
exit()

feature_dict['feat'] = batch_convert(feature_dict['feat'], add_batch=True)
feature_dict['label'] = batch_convert(feature_dict['label'], add_batch=True)

print(f"============ Start Inference ============")

infer_times = cfg.infer_times
if cfg.diff_batch_size > 0:
model_config.model.heads.diffusion_module.test_diff_batch_size = cfg.diff_batch_size
diff_batch_size = model_config.model.heads.diffusion_module.test_diff_batch_size
logging.info(f'Inference {infer_times} Times...')
logging.info(f"Diffusion batch size {diff_batch_size}...\n")
all_pred_path = []
for infer_id in range(infer_times):

logging.info('Getting MSA/Template Pipelines...')
msa_templ_data_pipeline_dict = get_msa_templates_pipeline(cfg=self.cfg)

print(f"============ Data Loading ============")
job_base = pathlib.Path(entity).stem
output_dir_base = pathlib.Path(self.cfg.output).joinpath(job_base)
msa_output_dir = output_dir_base.joinpath('msas')
msa_output_dir.mkdir(parents=True, exist_ok=True)

features_pkl = output_dir_base.joinpath('final_features.pkl')
if features_pkl.exists() and not self.cfg.override:
with open(features_pkl, 'rb') as f:
logging.info(f'Load features from precomputed {features_pkl}')
feature_dict = pickle.load(f)
else:
feature_dict = feature_processing_aa.process_input_json(
all_entitys,
ccd_preprocessed_dict=self.ccd_dict,
msa_templ_data_pipeline_dict=msa_templ_data_pipeline_dict,
msa_output_dir=msa_output_dir)

# save features
with open(features_pkl, 'wb') as f:
pickle.dump(feature_dict, f, protocol=4)

if self.cfg.msa_only == True:
logging.warning(f'Model inference is skipped because MSA-only mode is required.')
return

feature_dict['feat'] = batch_convert(feature_dict['feat'], add_batch=True)
feature_dict['label'] = batch_convert(feature_dict['label'], add_batch=True)

print(f"============ Start Inference ============")

logging.info(f'Start {infer_id}-th inference...\n')
prediction = eval(cfg, model, feature_dict)
infer_times = self.cfg.infer_times
if self.cfg.diff_batch_size > 0:
self.model_config.model.heads.diffusion_module.test_diff_batch_size = self.cfg.diff_batch_size
diff_batch_size = self.model_config.model.heads.diffusion_module.test_diff_batch_size
logging.info(f'Inference {infer_times} Times...')
logging.info(f"Diffusion batch size {diff_batch_size}...\n")
all_pred_path = []
for infer_id in range(infer_times):

logging.info(f'Start {infer_id}-th inference...\n')
prediction = eval(self.cfg, self.model, feature_dict)

# save result
prediction = split_prediction(prediction, diff_batch_size)
for rank_id in range(diff_batch_size):
json_name = job_base + f'-pred-{str(infer_id + 1)}-{str(rank_id + 1)}'
output_dir = pathlib.Path(output_dir_base).joinpath(json_name)
output_dir.mkdir(parents=True, exist_ok=True)
save_result(entry_name=job_base,
feature_dict=feature_dict,
prediction=prediction[rank_id],
output_dir=output_dir,
maxit_bin=self.cfg.other.maxit_binary)
all_pred_path.append(output_dir)

# save result
prediction = split_prediction(prediction, diff_batch_size)
for rank_id in range(diff_batch_size):
json_name = job_base + f'-pred-{str(infer_id + 1)}-{str(rank_id + 1)}'
output_dir = pathlib.Path(output_dir_base).joinpath(json_name)
output_dir.mkdir(parents=True, exist_ok=True)
save_result(entry_name=job_base,
feature_dict=feature_dict,
prediction=prediction[rank_id],
output_dir=output_dir,
maxit_bin=cfg.other.maxit_binary)
all_pred_path.append(output_dir)
# final ranking
print(f'============ Ranking ! ============')
ranking_all_predictions(all_pred_path)
print(f'============ Inference finished ! ============')

# final ranking
print(f'============ Ranking ! ============')
ranking_all_predictions(all_pred_path)
print(f'============ Inference finished ! ============')


@hydra.main(version_base=None, config_path=os.path.join(script_path,'config',),config_name='helixfold')
def show_atom_id_ccd(cfg: DictConfig):
ccd_preprocessed_path = cfg.db.ccd_preprocessed
def main(cfg: DictConfig):

hf_runner=HelixFold(cfg=cfg)

ccd_id=cfg.ccd_id
if len(ccd_id) <= 3 and ccd_id in (ccd_dict:=load_ccd_dict(ccd_preprocessed_path)):
logging.info(f'Atoms in {ccd_id}: {ccd_dict[ccd_id]["atom_ids"]}')
if os.path.isfile(cfg.input):
logging.info(f'Starting inference on {cfg.input}')
hf_runner.fold(cfg.input)
return
elif os.path.isdir(cfg.input):
logging.info(f'Starting inference on all files in {cfg.input}')
for f in [i for i in os.listdir(cfg.input) if any(i.endswith(p) for p in ['json', 'jsonl', 'json.gz', 'jsonl.gz'])]:
logging.info(f'Processing {f}')
try:
hf_runner.fold(os.path.join(cfg.input,f))
except (ValueError, AssertionError, RuntimeError, FileNotFoundError, MemoryError) as e :
logging.error(f'Error processing {f}: {e}')
continue

return

raise ValueError(f'Invalid input file/dir: {cfg.input}')




@hydra.main(version_base=None, config_path=os.path.join(script_path,'config',),config_name='helixfold')
def show_atom_id_ccd(cfg: DictConfig):
    """Log the atom ids stored for ``cfg.ccd_id`` in the preprocessed CCD dictionary.

    Args:
        cfg: Hydra config. Reads ``cfg.ccd_id`` (the id to look up) and
            ``cfg.db.ccd_preprocessed`` (path handed to ``load_ccd_dict``).

    Returns:
        None. The result is reported through ``logging`` only: the atom ids at
        INFO level on success, or a warning when the id is rejected or unknown
        (previously these two cases returned without any message).
    """
    ccd_id = cfg.ccd_id

    # Only ids of at most three characters are accepted; reject longer ones
    # before paying for loading the CCD dictionary.
    if len(ccd_id) > 3:
        logging.warning(f'CCD id {ccd_id} is longer than 3 characters; nothing to show.')
        return

    ccd_preprocessed_path = cfg.db.ccd_preprocessed
    ccd_dict = load_ccd_dict(ccd_preprocessed_path)

    if ccd_id not in ccd_dict:
        logging.warning(f'CCD id {ccd_id} not found in {ccd_preprocessed_path}.')
        return

    logging.info(f'Atoms in {ccd_id}: {ccd_dict[ccd_id]["atom_ids"]}')


if __name__ == '__main__':
Expand Down

0 comments on commit 4daee6c

Please sign in to comment.