diff --git a/src/fairchem/core/scripts/create_uma_finetune_dataset.py b/src/fairchem/core/scripts/create_uma_finetune_dataset.py index 95ec64c1a4..7c3fd64035 100644 --- a/src/fairchem/core/scripts/create_uma_finetune_dataset.py +++ b/src/fairchem/core/scripts/create_uma_finetune_dataset.py @@ -34,8 +34,11 @@ def create_yaml( dataset_name: str, regression_tasks: str, base_model_name: str, + uma_finetune_yaml: Path | str | None = None, + data_task_yaml: Path | str | None = None, ): - data_task_yaml = TEMPLATE_DIR / REGRESSION_LABEL_TO_TASK_YAML[regression_tasks] + if data_task_yaml is None: + data_task_yaml = TEMPLATE_DIR / REGRESSION_LABEL_TO_TASK_YAML[regression_tasks] with open(data_task_yaml) as file: template = yaml.safe_load(file) template["dataset_name"] = dataset_name @@ -52,7 +55,8 @@ def create_yaml( ) as yaml_file: yaml.dump(template, yaml_file, default_flow_style=False, sort_keys=False) - uma_finetune_yaml = TEMPLATE_DIR / UMA_SM_FINETUNE_YAML + if uma_finetune_yaml is None: + uma_finetune_yaml = TEMPLATE_DIR / UMA_SM_FINETUNE_YAML with open(uma_finetune_yaml) as file: template_ft = yaml.safe_load(file) template_ft["base_model_name"] = base_model_name @@ -115,6 +119,20 @@ def create_yaml( default=8, help="Number of parallel workers for processing files.", ) + parser.add_argument( + "--uma-finetune-yaml", + type=str, + default=None, + help="Path to the finetuning template file, " + "if not specified defaults to template file within the repository.", + ) + parser.add_argument( + "--data-task-yaml", + type=str, + default=None, + help="Path to the template file for the dataset specification, " + "if not specified defaults to template file within the repository based on regression task.", + ) args = parser.parse_args() assert not Path( args.output_dir @@ -141,6 +159,8 @@ def create_yaml( dataset_name=args.uma_task, regression_tasks=args.regression_tasks, base_model_name=args.base_model, + uma_finetune_yaml=args.uma_finetune_yaml, + data_task_yaml=args.data_task_yaml ) logging.info(f"Generated dataset and data config yaml in {args.output_dir}") logging.info(