Skip to content

Commit

Permalink
fix: config overridden
Browse files Browse the repository at this point in the history
  • Loading branch information
YaoYinYing committed Aug 17, 2024
1 parent 84a550c commit 21d4108
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 19 deletions.
14 changes: 11 additions & 3 deletions apps/protein_folding/helixfold3/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,16 +150,24 @@ helixfold --config-dir=. --config-name=myfold \
output=. CONFIG_DIFFS.preset=allatom_demo
```
##### Run with additional configuration term
```shell
LD_LIBRARY_PATH=/mnt/data/envs/conda_env/envs/helixfold/lib/:$LD_LIBRARY_PATH \
helixfold \
input=/repo/PaddleHelix/apps/protein_folding/helixfold3/data/demo_6zcy.json \
output=. \
CONFIG_DIFFS.model.heads.confidence_head.weight=0.01 \
CONFIG_DIFFS.model.global_config.subbatch_size=192
```
The descriptions of the above script are as follows:
* `LD_LIBRARY_PATH` - This is required to load the `libcudnn.so` library if you encounter issue like `RuntimeError: (PreconditionNotMet) Cannot load cudnn shared library. Cannot invoke method cudnnGetVersion.`
* `config-dir` - The directory that contains the alternative configuration file you would like to use.
* `config-name` - The name of the configuration file you would like to use.
* `input` - Input data in the form of JSON. Input pattern in `./data/demo_*.json` for your reference.
* `output` - Model output path. The output will be in a folder named the same as your `--input_json` under this path.
* `--CONFIG_DIFFS.preset` - Model name in `./helixfold/model/config.py`. Different model names specify different configurations. Minor modifications to the configuration can be specified in `CONFIG_DIFFS` in `config.py` without changing the full configuration in `CONFIG_ALLATOM`.
* `CONFIG_DIFFS.preset` - Model name in `./helixfold/model/config.py`. Different model names specify different configurations. Minor modifications to the configuration can be specified in `CONFIG_DIFFS` in `config.py` without changing the full configuration in `CONFIG_ALLATOM`.
* `CONFIG_DIFFS.*` - Override any model configuration in `CONFIG_ALLATOM`.
### Understanding Model Output
Expand Down
1 change: 1 addition & 0 deletions apps/protein_folding/helixfold3/helixfold/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,7 @@ def main(cfg: DictConfig):

### Create model
model_config = config.model_config(cfg.CONFIG_DIFFS)
logging.warning(f'>>> Model config: \n{model_config}\n\n')

model = RunModel(model_config)

Expand Down
30 changes: 14 additions & 16 deletions apps/protein_folding/helixfold3/helixfold/model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import copy
from typing import Any, Union
import ml_collections
from omegaconf import DictConfig


Expand All @@ -26,7 +25,7 @@
NUM_TEMPLATES = 'num templates placeholder'


def model_config(config_diffs: Union[str, DictConfig, dict[str, dict[str, Any]]]) -> ml_collections.ConfigDict:
def model_config(config_diffs: Union[str, DictConfig, dict[str, dict[str, Any]]]) -> DictConfig:
"""Get the ConfigDict of a model."""

cfg = copy.deepcopy(CONFIG_ALLATOM)
Expand All @@ -37,34 +36,33 @@ def model_config(config_diffs: Union[str, DictConfig, dict[str, dict[str, Any]]]
if isinstance(config_diffs, DictConfig):
if 'preset' in config_diffs and (preset_name:=config_diffs['preset']) in CONFIG_DIFFS:
updated_config=CONFIG_DIFFS[preset_name]
cfg.update_from_flattened_dict(updated_config)
cfg.merge_with_dotlist(updated_config)
print(f'Updated config from `CONFIG_DIFFS.{preset_name}`: {updated_config}')

return cfg

# update from detailed configuration
if any(root_kw in config_diffs for root_kw in CONFIG_ALLATOM):
for root_kw in CONFIG_ALLATOM:
if root_kw in config_diffs and (updated_config := config_diffs[root_kw]) is not None:
cfg.update(dict(updated_config))
print(f'Updated config from `CONFIG_DIFFS`: {updated_config}')

cfg.merge_with(config_diffs) # merge to override
print(f'Updated config from `CONFIG_DIFFS`: {config_diffs}')
return cfg

raise ValueError(f'Invalid config_diffs ({type(config_diffs)}): {config_diffs}')


# preset for runs
CONFIG_DIFFS = {
'allatom_demo': {
'model.heads.confidence_head.weight': 0.01
},
'allatom_subbatch_64_recycle_1': {
'model.global_config.subbatch_size': 64,
'model.num_recycle': 1,
},
CONFIG_DIFFS: dict[str, list[str]] = {
'allatom_demo': [
'model.heads.confidence_head.weight=0.01'
],
'allatom_subbatch_64_recycle_1': [
'model.global_config.subbatch_size=64',
'model.num_recycle=1',
]
}

CONFIG_ALLATOM = ml_collections.ConfigDict({
CONFIG_ALLATOM = DictConfig({
'data': {
'num_blocks': 5, # for msa block deletion
'randomize_num_blocks': True,
Expand Down

0 comments on commit 21d4108

Please sign in to comment.