Skip to content

Commit

Permalink
revert: script links
Browse files Browse the repository at this point in the history
  • Loading branch information
YaoYinYing committed Aug 30, 2024
1 parent 6dc4554 commit 5e3824f
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ nproc_msa:
input: null # Input file or directory (required). If it is a directory, HF3 is run on every file in it whose name ends in json, jsonl, json.gz or jsonl.gz.
output: null # Corresponds to --output_dir, required field
override: false # Set true to override existing msa output directory


# Binary tool paths; leave them as null to locate suitable binaries on PATH or in the conda bin directory
bin:
Expand Down
54 changes: 27 additions & 27 deletions apps/protein_folding/helixfold3/helixfold/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@
from helixfold.data.tools.utils import timing
from helixfold.utils.utils import get_custom_amp_list
from typing import Dict, Mapping
from helixfold.infer_scripts import feature_processing_aa, preprocess
from helixfold.infer_scripts.tools import mmcif_writer
from helixfold.utils import feature_processing_aa, preprocess
from helixfold.utils import mmcif_writer


script_path=os.path.dirname(__file__)
Expand Down Expand Up @@ -495,7 +495,7 @@ def split_prediction(pred, rank):
return prediction

@dataclass
class HelixFold:
class HelixFoldRunner:

cfg: DictConfig

Expand Down Expand Up @@ -580,11 +580,16 @@ def fold(self, entity: str):
init_seed(seed)




print(f"============ Data Loading ============")
job_base = pathlib.Path(entity).stem
output_dir_base = pathlib.Path(self.cfg.output).joinpath(job_base)

expected_res=os.path.join(self.cfg.output, job_base, f'{job_base}-rank1','all_results.json')
if os.path.isfile(expected_res):
logging.warning(f'Skip {job_base} because {expected_res} exists')
return


msa_output_dir = output_dir_base.joinpath('msas')
msa_output_dir.mkdir(parents=True, exist_ok=True)

Expand Down Expand Up @@ -659,36 +664,31 @@ def cleanup_ramdisk(self, ramdisk_path: str = "/dev/shm"):
@hydra.main(version_base=None, config_path=os.path.join(script_path,'config',),config_name='helixfold')
def main(cfg: DictConfig):

hf_runner=HelixFold(cfg=cfg)
hf_runner=HelixFoldRunner(cfg=cfg)

try:

if os.path.isfile(cfg.input):
logging.info(f'Starting inference on {cfg.input}')

if os.path.isfile(cfg.input):
logging.info(f'Starting inference on {cfg.input}')
try:
hf_runner.fold(cfg.input)
return
elif os.path.isdir(cfg.input):
logging.info(f'Starting inference on all files in {cfg.input}')
for f in [i for i in os.listdir(cfg.input) if any(i.endswith(p) for p in ['json', 'jsonl', 'json.gz', 'jsonl.gz'])]:
logging.info(f'Processing {f}')
try:
hf_runner.fold(os.path.join(cfg.input,f))
except (ValueError, AssertionError, RuntimeError, FileNotFoundError, MemoryError) as e :
logging.error(f'Error processing {f}: {e}')
continue
except (ValueError, AssertionError, RuntimeError, FileNotFoundError, MemoryError) as e :
logging.error(f'Error processing {cfg.input}: {e}')
elif os.path.isdir(cfg.input):
logging.info(f'Starting inference on all files in {cfg.input}')
for f in [i for i in os.listdir(cfg.input) if any(i.endswith(p) for p in ['json', 'jsonl', 'json.gz', 'jsonl.gz'])]:
logging.info(f'Processing {f}')
try:
hf_runner.fold(os.path.join(cfg.input,f))
except (ValueError, AssertionError, RuntimeError, FileNotFoundError, MemoryError, OSError) as e :
logging.error(f'Error processing {f}: {e}')
continue


except Exception as e:
logging.error(f'Error processing {cfg.input}: {e}')

finally:
hf_runner.cleanup_ramdisk()
return hf_runner.cleanup_ramdisk()



raise ValueError(f'Invalid input file/dir: {cfg.input}')


@hydra.main(version_base=None, config_path=os.path.join(script_path,'config',),config_name='helixfold')
def show_atom_id_ccd(cfg: DictConfig):

Expand Down

0 comments on commit 5e3824f

Please sign in to comment.