Skip to content

Commit 75b81fb

Browse files
committed
compat
1 parent b687e36 commit 75b81fb

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed

deepmd/utils/compat.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,69 @@ def deprecate_numb_test(
370370
return jdata
371371

372372

373+
def migrate_training_warmup(
374+
jdata: dict[str, Any], warning: bool = True
375+
) -> dict[str, Any]:
376+
"""
377+
Migrate legacy warmup settings from training to learning_rate.
378+
379+
Parameters
380+
----------
381+
jdata : dict[str, Any]
382+
Input configuration dictionary.
383+
warning : bool, optional
384+
Whether to show a deprecation warning, by default True.
385+
386+
Returns
387+
-------
388+
dict[str, Any]
389+
Updated configuration dictionary.
390+
"""
391+
training = jdata.get("training")
392+
if not isinstance(training, dict):
393+
return jdata
394+
395+
warmup_keys = ("warmup_steps", "warmup_ratio", "warmup_start_factor")
396+
legacy_keys = [key for key in warmup_keys if key in training]
397+
if not legacy_keys:
398+
return jdata
399+
400+
lr = jdata.get("learning_rate")
401+
if not isinstance(lr, dict):
402+
for key in legacy_keys:
403+
training.pop(key)
404+
if warning:
405+
warnings.warn(
406+
"Found legacy warmup settings under training, but learning_rate "
407+
"is missing or invalid. The warmup keys were removed from training."
408+
)
409+
return jdata
410+
411+
moved_keys = []
412+
skipped_keys = []
413+
# === Step 1. Move legacy warmup keys ===
414+
for key in legacy_keys:
415+
value = training.pop(key)
416+
if key in lr:
417+
skipped_keys.append(key)
418+
continue
419+
lr[key] = value
420+
moved_keys.append(key)
421+
422+
if warning:
423+
if skipped_keys:
424+
warnings.warn(
425+
"Legacy warmup settings under training were ignored because "
426+
f"learning_rate already defines them: {', '.join(skipped_keys)}."
427+
)
428+
else:
429+
warnings.warn(
430+
"Legacy warmup settings under training were moved to learning_rate: "
431+
f"{', '.join(moved_keys)}."
432+
)
433+
return jdata
434+
435+
373436
def update_deepmd_input(
374437
jdata: dict[str, Any], warning: bool = True, dump: str | Path | None = None
375438
) -> dict[str, Any]:
@@ -389,4 +452,5 @@ def is_deepmd_v1_input(jdata: dict[str, Any]) -> bool:
389452
else:
390453
jdata = deprecate_numb_test(jdata, warning, dump)
391454

455+
jdata = migrate_training_warmup(jdata, warning=warning)
392456
return jdata

0 commit comments

Comments
 (0)