Skip to content

Commit

Permalink
chore: db to ramdisk
Browse files Browse the repository at this point in the history
  • Loading branch information
YaoYinYing committed Aug 28, 2024
1 parent b23e7eb commit 7dc26a7
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 19 deletions.
2 changes: 1 addition & 1 deletion apps/protein_folding/helixfold3/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ The descriptions of the above script are as follows:
* `config-dir` - The directory that contains the alterative configuration file you would like to use.
* `config-name` - The name of the configuration file you would like to use.
* `input` - Input data in the form of JSON or directory that contains such JSON file(s). For file input, check content pattern in `./data/demo_*.json` for your reference.
* `output` - Model output path. The output will be in a folder named the same as your `--input_json` under this path.
* `output` - Model output path. The output will be in a folder named the same as your input json file under this path.
* `CONFIG_DIFFS.preset` - Adjusted model config preset name in `./helixfold/model/config.py:CONFIG_DIFFS`. The preset will be updated into final model configuration with `CONFIG_ALLATOM`.
* `CONFIG_DIFFS.*` - Override model any configuration in `CONFIG_ALLATOM`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,16 @@ infer_times: 1 # Corresponds to --infer_times
diff_batch_size: -1 # Corresponds to --diff_batch_size
use_small_bfd: false # Corresponds to --use_small_bfd
msa_only: false # Only process msa
ramdisk: # Force loading of the selected databases into a RAM disk (/dev/shm)
uniprot: false # 111 GB
uniref90: false # 67 GB
mgnify: false # 64 GB



nproc_msa:
hhblits: 16 # Number of processors used by hhblits
jackhmmer: 8 # Number of processors used by jackhmmer
jackhmmer: 6 # Number of processors used by jackhmmer

# File paths

Expand Down
90 changes: 73 additions & 17 deletions apps/protein_folding/helixfold3/helixfold/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from helixfold.utils.model import RunModel
from helixfold.data.tools import hmmsearch
from helixfold.data import templates
from helixfold.data.tools.utils import timing
from helixfold.utils.utils import get_custom_amp_list
from typing import Dict, Mapping
from helixfold.infer_scripts import feature_processing_aa, preprocess
Expand Down Expand Up @@ -124,8 +125,42 @@ def resolve_bin_path(cfg_path: str, default_binary_name: str)-> str:

raise FileNotFoundError(f"Could not find a proper binary path for {default_binary_name}: {cfg_path}.")



def load_to_dev_shm(file_path: str, ramdisk_path: str = "/dev/shm") -> str:
    """
    Copy a database file into a RAM-backed filesystem and return the new path.

    :param file_path: The path to the large file on the disk.
    :param ramdisk_path: Mount point of the RAM disk (default: /dev/shm).
    :return: The path to the file in the RAM disk.
    :raises FileNotFoundError: if ``file_path`` does not exist.
    :raises NotADirectoryError: if ``ramdisk_path`` is not a directory.
    :raises OSError: if the RAM disk lacks enough free space for the copy.
    """
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    # Ensure the RAM disk path exists and is a directory
    if not os.path.isdir(ramdisk_path):
        raise NotADirectoryError(f"RAM disk path not found or not a directory: {ramdisk_path}")

    target_path = os.path.join(ramdisk_path, pathlib.Path(file_path).name)
    src_size = os.path.getsize(file_path)

    # Skip the (potentially 100+ GB) copy when an identical-sized copy is
    # already present, so repeated runs are cheap and idempotent.
    if os.path.isfile(target_path) and os.path.getsize(target_path) == src_size:
        logging.info(f'{target_path} already present in RAM disk; skipping copy.')
        return target_path

    # Refuse to copy a file that cannot fit: filling a RAM-backed filesystem
    # can drive the host out of memory.
    free_space = shutil.disk_usage(ramdisk_path).free
    if free_space < src_size:
        raise OSError(
            f"Not enough space on {ramdisk_path}: "
            f"need {src_size} bytes, only {free_space} available.")

    with timing(f'loading {file_path} -> {target_path}'):
        shutil.copy(file_path, target_path)

    return target_path


def get_msa_templates_pipeline(cfg: DictConfig) -> Dict:
use_precomputed_msas = True # Assuming this is a constant or should be set globally


if cfg.ramdisk.uniprot:
cfg.db.uniprot=load_to_dev_shm(cfg.db.uniprot)

if cfg.ramdisk.uniref90:
cfg.db.uniref90=load_to_dev_shm(cfg.db.uniref90)

if cfg.ramdisk.mgnify:
cfg.db.mgnify=load_to_dev_shm(cfg.db.mgnify)

template_searcher = hmmsearch.Hmmsearch(
binary_path=resolve_bin_path(cfg.bin.hmmsearch, 'hmmsearch'),
Expand Down Expand Up @@ -156,6 +191,8 @@ def get_msa_templates_pipeline(cfg: DictConfig) -> Dict:
nprocs=cfg.nproc_msa,
)



prot_data_pipeline = pipeline_multimer.DataPipeline(
monomer_data_pipeline=monomer_data_pipeline,
jackhmmer_binary_path=resolve_bin_path(cfg.bin.jackhmmer, 'jackhmmer'),
Expand Down Expand Up @@ -458,6 +495,7 @@ class HelixFold:
model_config: DictConfig=None

ccd_dict: Mapping=None
msa_templ_data_pipeline_dict: Mapping=None

def __post_init__(self) -> None:

Expand Down Expand Up @@ -515,6 +553,9 @@ def __post_init__(self) -> None:
else:
self.model.helixfold.set_state_dict(pd_params)

logging.info('Getting MSA/Template Pipelines...')
self.msa_templ_data_pipeline_dict = get_msa_templates_pipeline(cfg=self.cfg)


if self.cfg.precision == "bf16" and self.cfg.amp_level == "O2":
raise NotImplementedError("bf16 O2 is not supported yet.")
Expand All @@ -531,8 +572,7 @@ def fold(self, entity: str):
init_seed(seed)


logging.info('Getting MSA/Template Pipelines...')
msa_templ_data_pipeline_dict = get_msa_templates_pipeline(cfg=self.cfg)


print(f"============ Data Loading ============")
job_base = pathlib.Path(entity).stem
Expand All @@ -549,7 +589,7 @@ def fold(self, entity: str):
feature_dict = feature_processing_aa.process_input_json(
all_entitys,
ccd_preprocessed_dict=self.ccd_dict,
msa_templ_data_pipeline_dict=msa_templ_data_pipeline_dict,
msa_templ_data_pipeline_dict=self.msa_templ_data_pipeline_dict,
msa_output_dir=msa_output_dir)

# save features
Expand Down Expand Up @@ -596,27 +636,43 @@ def fold(self, entity: str):
print(f'============ Inference finished ! ============')


def cleanup_ramdisk(self, ramdisk_path: str = "/dev/shm"):
    """
    Delete any database files that were copied into the RAM disk.

    Only paths that actually live *under* ``ramdisk_path`` are removed, so
    databases still on regular disk are never touched. Failures are logged
    rather than raised so cleanup never masks an earlier error.

    :param ramdisk_path: Mount point of the RAM disk (default: /dev/shm).
    """
    # Normalize with a trailing separator so that e.g. "/dev/shm-backup/x"
    # is not mistaken for a path inside "/dev/shm".
    prefix = os.path.join(ramdisk_path, '')
    candidates = (self.cfg.db.uniprot, self.cfg.db.uniref90, self.cfg.db.mgnify)
    for db_fasta in [db for db in candidates
                     if isinstance(db, str) and db.startswith(prefix)]:
        try:
            os.unlink(db_fasta)
        except FileNotFoundError:
            # Already gone (e.g. an earlier partial cleanup); nothing to do.
            pass
        except Exception as e:
            logging.error(f"Failed to delete {db_fasta} from ram disk. Reason: {e}")


@hydra.main(version_base=None, config_path=os.path.join(script_path,'config',),config_name='helixfold')
def main(cfg: DictConfig):
    """
    Entry point: run HelixFold inference on a single JSON file, or on every
    JSON file inside a directory, always releasing RAM-disk copies at the end.

    :param cfg: Hydra-resolved configuration.
    :raises ValueError: if ``cfg.input`` is neither a file nor a directory.
    """
    # Validate the input up front so an obvious mistake fails loudly before
    # the (expensive) model/pipeline setup. In the previous version the
    # `raise ValueError` sat after an unconditional `return` and was
    # unreachable, so a bad path was silently ignored.
    if not os.path.isfile(cfg.input) and not os.path.isdir(cfg.input):
        raise ValueError(f'Invalid input file/dir: {cfg.input}')

    hf_runner = HelixFold(cfg=cfg)

    try:
        if os.path.isfile(cfg.input):
            logging.info(f'Starting inference on {cfg.input}')
            hf_runner.fold(cfg.input)
            return

        logging.info(f'Starting inference on all files in {cfg.input}')
        for f in [i for i in os.listdir(cfg.input)
                  if any(i.endswith(p) for p in ['json', 'jsonl', 'json.gz', 'jsonl.gz'])]:
            logging.info(f'Processing {f}')
            try:
                hf_runner.fold(os.path.join(cfg.input, f))
            except (ValueError, AssertionError, RuntimeError, FileNotFoundError, MemoryError) as e:
                # Keep going: one bad input must not abort the whole batch.
                logging.error(f'Error processing {f}: {e}')
                continue

    except Exception as e:
        # Last-resort guard so the cleanup below always runs; the error is
        # logged (not re-raised) to preserve the original best-effort behavior.
        logging.error(f'Error processing {cfg.input}: {e}')

    finally:
        # Databases copied to /dev/shm must be removed even on failure,
        # otherwise they keep occupying RAM after the process exits.
        hf_runner.cleanup_ramdisk()

Expand Down

0 comments on commit 7dc26a7

Please sign in to comment.