
Commit

add configs
YaoYinYing committed Aug 17, 2024
1 parent d528300 commit a3d62e2
Showing 2 changed files with 37 additions and 4 deletions.
12 changes: 12 additions & 0 deletions apps/protein_folding/helixfold3/helixfold/config/helixfold.yaml
@@ -57,3 +57,15 @@ preset:
 # Other configurations
 other:
   maxit_binary: /mnt/data/yinying/software/maxit/maxit-v11.100-prod-src/bin/maxit # Corresponds to --maxit_binary
+
+
+# CONFIG_DIFFS for advanced configuration
+CONFIG_DIFFS:
+  preset: null # choices=['null','allatom_demo', 'allatom_subbatch_64_recycle_1']
+  model:
+    global_config:
+      subbatch_size: 96 # model.global_config.subbatch_size
+    num_recycle: 3 # model.num_recycle
+    heads:
+      confidence_head:
+        weight: 0.0 # model.heads.confidence_head.weight
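
An aside on the new block: the inline comments spell out the flattened dotted paths (model.global_config.subbatch_size, model.num_recycle, model.heads.confidence_head.weight) that the presets in config.py also use. A minimal, hypothetical sketch of that nesting-to-dotted-path correspondence, using plain dicts rather than the project's own loader:

    # Hypothetical helper, not part of this commit: flatten a nested mapping into
    # dotted paths such as 'model.global_config.subbatch_size'.
    from typing import Any, Dict

    def flatten(nested: Dict[str, Any], prefix: str = '') -> Dict[str, Any]:
        """Turn {'global_config': {'subbatch_size': 96}} into {'global_config.subbatch_size': 96}."""
        flat: Dict[str, Any] = {}
        for key, value in nested.items():
            path = f'{prefix}.{key}' if prefix else key
            if isinstance(value, dict):
                flat.update(flatten(value, path))
            else:
                flat[path] = value
        return flat

    # The 'model' node from the CONFIG_DIFFS block above, written as a plain dict.
    overrides = {
        'global_config': {'subbatch_size': 96},
        'num_recycle': 3,
        'heads': {'confidence_head': {'weight': 0.0}},
    }
    print(flatten(overrides, prefix='model'))
    # -> {'model.global_config.subbatch_size': 96, 'model.num_recycle': 3,
    #     'model.heads.confidence_head.weight': 0.0}
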
29 changes: 25 additions & 4 deletions apps/protein_folding/helixfold3/helixfold/model/config.py
@@ -15,7 +15,9 @@
 """Model config."""
 
 import copy
+from typing import Any, Union
 import ml_collections
+from omegaconf import DictConfig
 
 
 NUM_RES = 'num residues placeholder'
@@ -24,16 +26,35 @@
 NUM_TEMPLATES = 'num templates placeholder'
 
 
-def model_config(name: str) -> ml_collections.ConfigDict:
+def model_config(config_diffs: Union[str, DictConfig, dict[str, dict[str, Any]]]) -> ml_collections.ConfigDict:
     """Get the ConfigDict of a model."""
 
     cfg = copy.deepcopy(CONFIG_ALLATOM)
-    if name in CONFIG_DIFFS:
-        cfg.update_from_flattened_dict(CONFIG_DIFFS[name])
+    if config_diffs is None or config_diffs=='':
+        # early return if nothing is changed
+        return cfg
+
-    return cfg
+    if isinstance(config_diffs, DictConfig):
+        if 'preset' in config_diffs and (preset_name:=config_diffs['preset']) in CONFIG_DIFFS:
+            updated_config=CONFIG_DIFFS[preset_name]
+            cfg.update_from_flattened_dict(updated_config)
+            print(f'Updated config from `CONFIG_DIFFS.{preset_name}`: {updated_config}')
+
+            return cfg
+        if 'model' in config_diffs and (updated_config := config_diffs['model']) is not None:
+            cfg.update(dict(updated_config))
+            print(f'Updated config from `CONFIG_DIFFS`: {updated_config}')
+
+            return cfg
+
+    raise ValueError(f'Invalid config_diffs ({type(config_diffs)}): {config_diffs}')
+
+
+
+
 
 
 # preset for runs
 CONFIG_DIFFS = {
     'allatom_demo': {
         'model.heads.confidence_head.weight': 0.01
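For context, a hedged usage sketch of the reworked model_config entry point. It assumes the caller hands in an omegaconf DictConfig shaped like the CONFIG_DIFFS node from helixfold.yaml; the call sites below are illustrative, not the actual HelixFold3 runner.

    # Illustrative only: exercising the new model_config signature with omegaconf.
    from omegaconf import OmegaConf

    from helixfold.model.config import model_config

    # Path 1: pick a named preset; resolves against CONFIG_DIFFS in config.py.
    cfg = model_config(OmegaConf.create({'preset': 'allatom_demo'}))

    # Path 2: override individual model fields, mirroring the CONFIG_DIFFS.model
    # block in helixfold.yaml (no 'preset' key, so the 'model' branch is taken).
    cfg = model_config(OmegaConf.create({
        'model': {
            'global_config': {'subbatch_size': 96},
            'num_recycle': 3,
        }
    }))

Either branch returns early with the updated ConfigDict; anything else, including a DictConfig with neither a recognized preset nor a model node, falls through to the ValueError.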
