@@ -412,53 +412,71 @@ def convert_optimizer_to_new_format(
     dict[str, Any]
         converted output with optimizer section
     """
+    # Default optimizer values (must match argcheck.py defaults)
+    default_optimizer = {
+        "type": "Adam",
+        "adam_beta1": 0.9,
+        "adam_beta2": 0.999,
+        "weight_decay": 0.0,
+    }
+
     training_cfg = jdata.get("training", {})
-    if "opt_type" not in training_cfg:
-        # No conversion needed
-        return jdata
-
-    # Optimizer parameters that may be in the training section
-    optimizer_keys = [
-        "opt_type",
-        "kf_blocksize",
-        "kf_start_pref_e",
-        "kf_limit_pref_e",
-        "kf_start_pref_f",
-        "kf_limit_pref_f",
-        "weight_decay",
-        "momentum",
-        "muon_momentum",
-        "adam_beta1",
-        "adam_beta2",
-        "lr_adjust",
-        "lr_adjust_coeff",
-        "muon_2d_only",
-        "min_2d_dim",
-    ]
-
-    # Extract optimizer parameters from training section
-    optimizer_cfg = {}
-    for key in optimizer_keys:
-        if key in training_cfg:
-            optimizer_cfg[key] = training_cfg.pop(key)
-
-    # Convert opt_type to type for new format
-    if "opt_type" in optimizer_cfg:
-        optimizer_cfg["type"] = optimizer_cfg.pop("opt_type")
-
-    # Set the optimizer section if not already present
-    if "optimizer" not in jdata:
-        jdata["optimizer"] = optimizer_cfg
-    else:
-        # Merge with existing optimizer config (new config from conversion takes precedence)
-        jdata["optimizer"].update(optimizer_cfg)
+    optimizer_cfg = jdata.get("optimizer", {})
+
+    # Case 1: Old format - optimizer params in training section
+    if "opt_type" in training_cfg:
+        # Optimizer parameters that may be in the training section
+        optimizer_keys = [
+            "opt_type",
+            "kf_blocksize",
+            "kf_start_pref_e",
+            "kf_limit_pref_e",
+            "kf_start_pref_f",
+            "kf_limit_pref_f",
+            "weight_decay",
+            "momentum",
+            "muon_momentum",
+            "adam_beta1",
+            "adam_beta2",
+            "lr_adjust",
+            "lr_adjust_coeff",
+            "muon_2d_only",
+            "min_2d_dim",
+        ]
+
+        # Extract optimizer parameters from training section
+        extracted_cfg = {}
+        for key in optimizer_keys:
+            if key in training_cfg:
+                extracted_cfg[key] = training_cfg.pop(key)
+
+        # Convert opt_type to type for new format
+        if "opt_type" in extracted_cfg:
+            extracted_cfg["type"] = extracted_cfg.pop("opt_type")
+
+        # Merge with existing optimizer config (conversion takes precedence)
+        optimizer_cfg = {**optimizer_cfg, **extracted_cfg}

-    if warning:
-        warnings.warn(
-            "Placing optimizer parameters (opt_type, kf_blocksize, etc.) in the training section "
-            "is deprecated. Use a separate 'optimizer' section with 'type' field instead.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
+        if warning:
+            warnings.warn(
+                "Placing optimizer parameters (opt_type, kf_blocksize, etc.) in the training section "
+                "is deprecated. Use a separate 'optimizer' section with 'type' field instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
+    # Case 2: Fill in missing defaults
+    # If type is not specified, default to Adam
+    if "type" not in optimizer_cfg:
+        optimizer_cfg["type"] = default_optimizer["type"]
+
+    # Fill in defaults for Adam optimizer type
+    if optimizer_cfg["type"] in ("Adam", "AdamW"):
+        for key, value in default_optimizer.items():
+            if key not in optimizer_cfg:
+                optimizer_cfg[key] = value
+
+    # Set/update the optimizer section
+    jdata["optimizer"] = optimizer_cfg

     return jdata
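
For reference, a minimal sketch of what the converted function does to an old-format config. The call below assumes the function is invoked directly with the config dict (its import path and full signature are not shown in this hunk), and the "numb_steps" key is only an illustrative placeholder for unrelated training settings:

# Old format: optimizer settings embedded in the "training" section.
old = {"training": {"opt_type": "AdamW", "weight_decay": 0.01, "numb_steps": 1000}}

# Emits a DeprecationWarning for the old-format keys (assuming `warning` is on).
new = convert_optimizer_to_new_format(old)

# "opt_type" is renamed to "type", the keys are moved out of "training",
# and the remaining Adam/AdamW defaults from argcheck.py are filled in.
assert new["optimizer"] == {
    "type": "AdamW",
    "weight_decay": 0.01,
    "adam_beta1": 0.9,
    "adam_beta2": 0.999,
}
assert "opt_type" not in new["training"]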