@@ -370,6 +370,69 @@ def deprecate_numb_test(
370370 return jdata
371371
372372
373+ def migrate_training_warmup (
374+ jdata : dict [str , Any ], warning : bool = True
375+ ) -> dict [str , Any ]:
376+ """
377+ Migrate legacy warmup settings from training to learning_rate.
378+
379+ Parameters
380+ ----------
381+ jdata : dict[str, Any]
382+ Input configuration dictionary.
383+ warning : bool, optional
384+ Whether to show a deprecation warning, by default True.
385+
386+ Returns
387+ -------
388+ dict[str, Any]
389+ Updated configuration dictionary.
390+ """
391+ training = jdata .get ("training" )
392+ if not isinstance (training , dict ):
393+ return jdata
394+
395+ warmup_keys = ("warmup_steps" , "warmup_ratio" , "warmup_start_factor" )
396+ legacy_keys = [key for key in warmup_keys if key in training ]
397+ if not legacy_keys :
398+ return jdata
399+
400+ lr = jdata .get ("learning_rate" )
401+ if not isinstance (lr , dict ):
402+ for key in legacy_keys :
403+ training .pop (key )
404+ if warning :
405+ warnings .warn (
406+ "Found legacy warmup settings under training, but learning_rate "
407+ "is missing or invalid. The warmup keys were removed from training."
408+ )
409+ return jdata
410+
411+ moved_keys = []
412+ skipped_keys = []
413+ # === Step 1. Move legacy warmup keys ===
414+ for key in legacy_keys :
415+ value = training .pop (key )
416+ if key in lr :
417+ skipped_keys .append (key )
418+ continue
419+ lr [key ] = value
420+ moved_keys .append (key )
421+
422+ if warning :
423+ if skipped_keys :
424+ warnings .warn (
425+ "Legacy warmup settings under training were ignored because "
426+ f"learning_rate already defines them: { ', ' .join (skipped_keys )} ."
427+ )
428+ else :
429+ warnings .warn (
430+ "Legacy warmup settings under training were moved to learning_rate: "
431+ f"{ ', ' .join (moved_keys )} ."
432+ )
433+ return jdata
434+
435+
373436def update_deepmd_input (
374437 jdata : dict [str , Any ], warning : bool = True , dump : str | Path | None = None
375438) -> dict [str , Any ]:
@@ -389,4 +452,5 @@ def is_deepmd_v1_input(jdata: dict[str, Any]) -> bool:
389452 else :
390453 jdata = deprecate_numb_test (jdata , warning , dump )
391454
455+ jdata = migrate_training_warmup (jdata , warning = warning )
392456 return jdata
0 commit comments