diff --git a/apps/protein_folding/helixfold3/README.md b/apps/protein_folding/helixfold3/README.md
index 65c8aa30..66a8756c 100644
--- a/apps/protein_folding/helixfold3/README.md
+++ b/apps/protein_folding/helixfold3/README.md
@@ -308,7 +308,7 @@ The descriptions of the above script are as follows:
 * `config-dir` - The directory that contains the alterative configuration file you would like to use.
 * `config-name` - The name of the configuration file you would like to use.
 * `input` - Input data in the form of JSON or directory that contains such JSON file(s). For file input, check content pattern in `./data/demo_*.json` for your reference.
-* `output` - Model output path. The output will be in a folder named the same as your `--input_json` under this path.
+* `output` - Model output path. The output will be in a folder named the same as your input json file under this path.
 * `CONFIG_DIFFS.preset` - Adjusted model config preset name in `./helixfold/model/config.py:CONFIG_DIFFS`. The preset will be updated into final model configuration with `CONFIG_ALLATOM`.
 * `CONFIG_DIFFS.*` - Override model any configuration in `CONFIG_ALLATOM`.
 
diff --git a/apps/protein_folding/helixfold3/helixfold/config/helixfold.yaml b/apps/protein_folding/helixfold3/helixfold/config/helixfold.yaml
index fb07421c..e353d5cf 100644
--- a/apps/protein_folding/helixfold3/helixfold/config/helixfold.yaml
+++ b/apps/protein_folding/helixfold3/helixfold/config/helixfold.yaml
@@ -13,10 +13,16 @@ infer_times: 1 # Corresponds to --infer_times
 diff_batch_size: -1 # Corresponds to --diff_batch_size
 use_small_bfd: false # Corresponds to --use_small_bfd
 msa_only: false # Only process msa
+ramdisk: # Force to load database to ram disk
+  uniprot: false # 111 GB
+  uniref90: false # 67 GB
+  mgnify: false # 64 GB
+
+
 nproc_msa:
   hhblits: 16 # Number of processors used by hhblits
-  jackhmmer: 8 # Number of processors used by jackhmmer
+  jackhmmer: 6 # Number of processors used by jackhmmer
 
 # File paths
diff --git a/apps/protein_folding/helixfold3/helixfold/inference.py b/apps/protein_folding/helixfold3/helixfold/inference.py
index dc9eb85b..ff57a044 100644
--- a/apps/protein_folding/helixfold3/helixfold/inference.py
+++ b/apps/protein_folding/helixfold3/helixfold/inference.py
@@ -42,6 +42,7 @@ from helixfold.utils.model import RunModel
 from helixfold.data.tools import hmmsearch
 from helixfold.data import templates
+from helixfold.data.tools.utils import timing
 from helixfold.utils.utils import get_custom_amp_list
 from typing import Dict, Mapping
 from helixfold.infer_scripts import feature_processing_aa, preprocess
@@ -124,8 +125,42 @@ def resolve_bin_path(cfg_path: str, default_binary_name: str)-> str:
 
     raise FileNotFoundError(f"Could not find a proper binary path for {default_binary_name}: {cfg_path}.")
 
+
+
+def load_to_dev_shm(file_path: str, ramdisk_path: str = "/dev/shm") -> str:
+    """
+    Copies a file to /dev/shm (RAM-backed filesystem) and returns the path.
+
+    :param file_path: The path to the large file on the disk.
+    :return: The path to the file in /dev/shm.
+    """
+    if not os.path.isfile(file_path):
+        raise FileNotFoundError(f"File not found: {file_path}")
+
+    # Ensure the RAM disk path exists and is a directory
+    if not os.path.isdir(ramdisk_path):
+        raise NotADirectoryError(f"RAM disk path not found or not a directory: {ramdisk_path}")
+
+
+    target_path = os.path.join(ramdisk_path, pathlib.Path(file_path).name)
+    with timing(f'loading {file_path} -> {target_path}'):
+        shutil.copy(file_path, target_path)
+
+    return target_path
+
+
 def get_msa_templates_pipeline(cfg: DictConfig) -> Dict:
     use_precomputed_msas = True  # Assuming this is a constant or should be set globally
+
+
+    if cfg.ramdisk.uniprot:
+        cfg.db.uniprot=load_to_dev_shm(cfg.db.uniprot)
+
+    if cfg.ramdisk.uniref90:
+        cfg.db.uniref90=load_to_dev_shm(cfg.db.uniref90)
+
+    if cfg.ramdisk.mgnify:
+        cfg.db.mgnify=load_to_dev_shm(cfg.db.mgnify)
 
     template_searcher = hmmsearch.Hmmsearch(
         binary_path=resolve_bin_path(cfg.bin.hmmsearch, 'hmmsearch'),
@@ -156,6 +191,8 @@ def get_msa_templates_pipeline(cfg: DictConfig) -> Dict:
         nprocs=cfg.nproc_msa,
     )
 
+
+
     prot_data_pipeline = pipeline_multimer.DataPipeline(
         monomer_data_pipeline=monomer_data_pipeline,
         jackhmmer_binary_path=resolve_bin_path(cfg.bin.jackhmmer, 'jackhmmer'),
@@ -458,6 +495,7 @@ class HelixFold:
 
     model_config: DictConfig=None
     ccd_dict: Mapping=None
+    msa_templ_data_pipeline_dict: Mapping=None
 
     def __post_init__(self) -> None:
@@ -515,6 +553,9 @@ def __post_init__(self) -> None:
         else:
             self.model.helixfold.set_state_dict(pd_params)
 
+        logging.info('Getting MSA/Template Pipelines...')
+        self.msa_templ_data_pipeline_dict = get_msa_templates_pipeline(cfg=self.cfg)
+
         if self.cfg.precision == "bf16" and self.cfg.amp_level == "O2":
             raise NotImplementedError("bf16 O2 is not supported yet.")
 
@@ -531,8 +572,7 @@ def fold(self, entity: str):
 
         init_seed(seed)
 
-        logging.info('Getting MSA/Template Pipelines...')
-        msa_templ_data_pipeline_dict = get_msa_templates_pipeline(cfg=self.cfg)
+        print(f"============ Data Loading ============")
 
         job_base = pathlib.Path(entity).stem
@@ -549,7 +589,7 @@ def fold(self, entity: str):
         feature_dict = feature_processing_aa.process_input_json(
             all_entitys,
             ccd_preprocessed_dict=self.ccd_dict,
-            msa_templ_data_pipeline_dict=msa_templ_data_pipeline_dict,
+            msa_templ_data_pipeline_dict=self.msa_templ_data_pipeline_dict,
             msa_output_dir=msa_output_dir)
 
         # save features
@@ -596,27 +636,43 @@ def fold(self, entity: str):
 
         print(f'============ Inference finished ! ============')
 
+    def cleanup_ramdisk(self, ramdisk_path: str = "/dev/shm"):
+        for db_fasta in [db for db in (self.cfg.db.uniprot, self.cfg.db.uniref90, self.cfg.db.mgnify,) if db.startswith(ramdisk_path)]:
+            try:
+                os.unlink(db_fasta)
+            except Exception as e:
+                logging.error(f"Failed to delete {db_fasta} from ram disk. Reason: {e}")
+
 @hydra.main(version_base=None, config_path=os.path.join(script_path,'config',),config_name='helixfold')
 def main(cfg: DictConfig):
     hf_runner=HelixFold(cfg=cfg)
-    if os.path.isfile(cfg.input):
-        logging.info(f'Starting inference on {cfg.input}')
-        hf_runner.fold(cfg.input)
-        return
-    elif os.path.isdir(cfg.input):
-        logging.info(f'Starting inference on all files in {cfg.input}')
-        for f in [i for i in os.listdir(cfg.input) if any(i.endswith(p) for p in ['json', 'jsonl', 'json.gz', 'jsonl.gz'])]:
-            logging.info(f'Processing {f}')
-            try:
-                hf_runner.fold(os.path.join(cfg.input,f))
-            except (ValueError, AssertionError, RuntimeError, FileNotFoundError, MemoryError) as e :
-                logging.error(f'Error processing {f}: {e}')
-                continue
+    try:
+
+        if os.path.isfile(cfg.input):
+            logging.info(f'Starting inference on {cfg.input}')
+            hf_runner.fold(cfg.input)
+            return
+        elif os.path.isdir(cfg.input):
+            logging.info(f'Starting inference on all files in {cfg.input}')
+            for f in [i for i in os.listdir(cfg.input) if any(i.endswith(p) for p in ['json', 'jsonl', 'json.gz', 'jsonl.gz'])]:
+                logging.info(f'Processing {f}')
+                try:
+                    hf_runner.fold(os.path.join(cfg.input,f))
+                except (ValueError, AssertionError, RuntimeError, FileNotFoundError, MemoryError) as e :
+                    logging.error(f'Error processing {f}: {e}')
+                    continue
+
+
+    except Exception as e:
+        logging.error(f'Error processing {cfg.input}: {e}')
+
+    finally:
+        hf_runner.cleanup_ramdisk()
+
 
-        return
     raise ValueError(f'Invalid input file/dir: {cfg.input}')