Skip to content

Commit

Permalink
fix:config
Browse files Browse the repository at this point in the history
  • Loading branch information
YaoYinYing committed Aug 17, 2024
1 parent a3d62e2 commit 6e39caa
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ CONFIG_DIFFS:
model:
global_config:
subbatch_size: 96 # model.global_config.subbatch_size
num_recycle: 3 # model.num_recycle
heads:
confidence_head:
weight: 0.0 # model.heads.confidence_head.weight
num_recycle: 3 # model.num_recycle
heads:
confidence_head:
weight: 0.0 # model.heads.confidence_head.weight
3 changes: 1 addition & 2 deletions apps/protein_folding/helixfold3/helixfold/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,8 +490,7 @@ def main(cfg: DictConfig):
msa_templ_data_pipeline_dict = get_msa_templates_pipeline(cfg=cfg)

### Create model
model_config = config.model_config(cfg.job_id)
#print(f'>>> model_config:\n{model_config}')
model_config = config.model_config(cfg.CONFIG_DIFFS)

model = RunModel(model_config)

Expand Down
15 changes: 7 additions & 8 deletions apps/protein_folding/helixfold3/helixfold/model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,17 @@ def model_config(config_diffs: Union[str, DictConfig, dict[str, dict[str, Any]]]
print(f'Updated config from `CONFIG_DIFFS.{preset_name}`: {updated_config}')

return cfg
if 'model' in config_diffs and (updated_config := config_diffs['model']) is not None:
cfg.update(dict(updated_config))
print(f'Updated config from `CONFIG_DIFFS`: {updated_config}')


# update from detailed configuration
if any(root_kw in config_diffs for root_kw in CONFIG_ALLATOM):
for root_kw in CONFIG_ALLATOM:
if root_kw in config_diffs and (updated_config := config_diffs[root_kw]) is not None:
cfg.update(dict(updated_config))
print(f'Updated config from `CONFIG_DIFFS`: {updated_config}')
return cfg

raise ValueError(f'Invalid config_diffs ({type(config_diffs)}): {config_diffs}')






# preset for runs
CONFIG_DIFFS = {
Expand Down

0 comments on commit 6e39caa

Please sign in to comment.