Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions src/fairchem/core/scripts/create_uma_finetune_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@ def create_yaml(
dataset_name: str,
regression_tasks: str,
base_model_name: str,
uma_finetune_yaml: Path | str | None = None,
data_task_yaml: Path | str | None = None,
):
data_task_yaml = TEMPLATE_DIR / REGRESSION_LABEL_TO_TASK_YAML[regression_tasks]
if data_task_yaml is None:
data_task_yaml = TEMPLATE_DIR / REGRESSION_LABEL_TO_TASK_YAML[regression_tasks]
with open(data_task_yaml) as file:
template = yaml.safe_load(file)
template["dataset_name"] = dataset_name
Expand All @@ -52,7 +55,8 @@ def create_yaml(
) as yaml_file:
yaml.dump(template, yaml_file, default_flow_style=False, sort_keys=False)

uma_finetune_yaml = TEMPLATE_DIR / UMA_SM_FINETUNE_YAML
if uma_finetune_yaml is None:
uma_finetune_yaml = TEMPLATE_DIR / UMA_SM_FINETUNE_YAML
with open(uma_finetune_yaml) as file:
template_ft = yaml.safe_load(file)
template_ft["base_model_name"] = base_model_name
Expand Down Expand Up @@ -115,6 +119,20 @@ def create_yaml(
default=8,
help="Number of parallel workers for processing files.",
)
parser.add_argument(
"--uma-finetune-yaml",
type=str,
default=None,
help="Path to the finetuning template file, "
"if not specified defaults to template file within the repository.",
)
parser.add_argument(
"--data-task-yaml",
type=str,
default=None,
help="Path to the template file for the dataset specification, "
"if not specified defaults to template file within the repository based on regression task.",
)
args = parser.parse_args()
assert not Path(
args.output_dir
Expand All @@ -141,6 +159,8 @@ def create_yaml(
dataset_name=args.uma_task,
regression_tasks=args.regression_tasks,
base_model_name=args.base_model,
uma_finetune_yaml=args.uma_finetune_yaml,
data_task_yaml=args.data_task_yaml
)
logging.info(f"Generated dataset and data config yaml in {args.output_dir}")
logging.info(
Expand Down