Commit 37b5eba

fix

1 parent 019a841 commit 37b5eba

10 files changed: 108 additions & 61 deletions

deepmd/pd/entrypoints/main.py

Lines changed: 3 additions & 0 deletions

@@ -64,6 +64,7 @@
     normalize,
 )
 from deepmd.utils.compat import (
+    convert_optimizer_to_new_format,
     update_deepmd_input,
 )
 from deepmd.utils.data_system import (
@@ -292,6 +293,8 @@ def train(

     # argcheck
     config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
+    # Backward compatibility: convert old optimizer format
+    config = convert_optimizer_to_new_format(config)
     config = normalize(config, multi_task=multi_task)

     # do neighbor stat
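
The entry-point placement (mirrored in the pt and tf backends below) runs the conversion before normalize, so argcheck already sees a config with an optimizer section. One consequence, sketched below as a minimal example assuming the diff's default_optimizer table is complete: even a config with no legacy keys comes out with a fully defaulted Adam section, which is why the call is unconditional.

    from deepmd.utils.compat import convert_optimizer_to_new_format

    # Hypothetical minimal config: no legacy "opt_type", no "optimizer" section.
    cfg = {"training": {"numb_steps": 1000}}
    cfg = convert_optimizer_to_new_format(cfg, warning=False)

    # Case 1 is skipped (no "opt_type"); Case 2 fills the Adam defaults.
    assert cfg["optimizer"] == {
        "type": "Adam",
        "adam_beta1": 0.9,
        "adam_beta2": 0.999,
        "weight_decay": 0.0,
    }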

deepmd/pd/train/training.py

Lines changed: 0 additions & 5 deletions

@@ -72,9 +72,6 @@
     nvprof_context,
     to_numpy_array,
 )
-from deepmd.utils.compat import (
-    convert_optimizer_to_new_format,
-)
 from deepmd.utils.data import (
     DataRequirementItem,
 )
@@ -116,8 +113,6 @@ def __init__(
             resume_model = None
         resuming = resume_model is not None
         self.restart_training = restart_model is not None
-        # Backward compatibility: convert old optimizer format
-        config = convert_optimizer_to_new_format(config)
         model_params = config["model"]
         training_params = config["training"]
         optimizer_params = config.get("optimizer", {})

deepmd/pt/entrypoints/main.py

Lines changed: 3 additions & 0 deletions

@@ -83,6 +83,7 @@
     normalize,
 )
 from deepmd.utils.compat import (
+    convert_optimizer_to_new_format,
     update_deepmd_input,
 )
 from deepmd.utils.data_system import (
@@ -325,6 +326,8 @@ def train(

     # argcheck
     config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
+    # Backward compatibility: convert old optimizer format
+    config = convert_optimizer_to_new_format(config)
     config = normalize(config, multi_task=multi_task)

     # do neighbor stat

deepmd/pt/train/training.py

Lines changed: 0 additions & 5 deletions

@@ -73,9 +73,6 @@
 from deepmd.pt.utils.utils import (
     to_numpy_array,
 )
-from deepmd.utils.compat import (
-    convert_optimizer_to_new_format,
-)
 from deepmd.utils.data import (
     DataRequirementItem,
 )
@@ -126,8 +123,6 @@ def __init__(
             resume_model = None
         resuming = resume_model is not None
         self.restart_training = restart_model is not None
-        # Backward compatibility: convert old optimizer format
-        config = convert_optimizer_to_new_format(config)
         model_params = config["model"]
         training_params = config["training"]
         optimizer_params = config.get("optimizer", {})

deepmd/tf/entrypoints/train.py

Lines changed: 3 additions & 0 deletions

@@ -37,6 +37,7 @@
     normalize,
 )
 from deepmd.tf.utils.compat import (
+    convert_optimizer_to_new_format,
     update_deepmd_input,
 )
 from deepmd.tf.utils.finetune import (
@@ -162,6 +163,8 @@ def train(
         jdata["model"] = json.loads(t_training_script)["model"]

     jdata = update_deepmd_input(jdata, warning=True, dump="input_v2_compat.json")
+    # Backward compatibility: convert old optimizer format
+    jdata = convert_optimizer_to_new_format(jdata)

     jdata = normalize(jdata)

deepmd/tf/train/trainer.py

Lines changed: 0 additions & 5 deletions

@@ -57,9 +57,6 @@
 from deepmd.tf.utils.sess import (
     run_sess,
 )
-from deepmd.utils.compat import (
-    convert_optimizer_to_new_format,
-)
 from deepmd.utils.data import (
     DataRequirementItem,
 )
@@ -124,8 +121,6 @@ def get_lr_and_coef(lr_param):
         lr_param = jdata["learning_rate"]
         self.lr, self.scale_lr_coef = get_lr_and_coef(lr_param)
         # optimizer
-        # Backward compatibility: convert old optimizer format
-        jdata = convert_optimizer_to_new_format(jdata)
         # Note: Default values are already filled by argcheck.normalize()
         optimizer_param = jdata.get("optimizer", {})
         self.optimizer_type = optimizer_param.get("type", "Adam")

deepmd/tf/utils/compat.py

Lines changed: 2 additions & 0 deletions

@@ -4,13 +4,15 @@
 from deepmd.utils.compat import (
     convert_input_v0_v1,
     convert_input_v1_v2,
+    convert_optimizer_to_new_format,
     deprecate_numb_test,
     update_deepmd_input,
 )

 __all__ = [
     "convert_input_v0_v1",
     "convert_input_v1_v2",
+    "convert_optimizer_to_new_format",
     "deprecate_numb_test",
     "update_deepmd_input",
 ]

deepmd/utils/compat.py

Lines changed: 64 additions & 46 deletions

@@ -412,53 +412,71 @@ def convert_optimizer_to_new_format(
     dict[str, Any]
         converted output with optimizer section
     """
+    # Default optimizer values (must match argcheck.py defaults)
+    default_optimizer = {
+        "type": "Adam",
+        "adam_beta1": 0.9,
+        "adam_beta2": 0.999,
+        "weight_decay": 0.0,
+    }
+
     training_cfg = jdata.get("training", {})
-    if "opt_type" not in training_cfg:
-        # No conversion needed
-        return jdata
-
-    # Optimizer parameters that may be in the training section
-    optimizer_keys = [
-        "opt_type",
-        "kf_blocksize",
-        "kf_start_pref_e",
-        "kf_limit_pref_e",
-        "kf_start_pref_f",
-        "kf_limit_pref_f",
-        "weight_decay",
-        "momentum",
-        "muon_momentum",
-        "adam_beta1",
-        "adam_beta2",
-        "lr_adjust",
-        "lr_adjust_coeff",
-        "muon_2d_only",
-        "min_2d_dim",
-    ]
-
-    # Extract optimizer parameters from training section
-    optimizer_cfg = {}
-    for key in optimizer_keys:
-        if key in training_cfg:
-            optimizer_cfg[key] = training_cfg.pop(key)
-
-    # Convert opt_type to type for new format
-    if "opt_type" in optimizer_cfg:
-        optimizer_cfg["type"] = optimizer_cfg.pop("opt_type")
-
-    # Set the optimizer section if not already present
-    if "optimizer" not in jdata:
-        jdata["optimizer"] = optimizer_cfg
-    else:
-        # Merge with existing optimizer config (new config from conversion takes precedence)
-        jdata["optimizer"].update(optimizer_cfg)
+    optimizer_cfg = jdata.get("optimizer", {})
+
+    # Case 1: Old format - optimizer params in training section
+    if "opt_type" in training_cfg:
+        # Optimizer parameters that may be in the training section
+        optimizer_keys = [
+            "opt_type",
+            "kf_blocksize",
+            "kf_start_pref_e",
+            "kf_limit_pref_e",
+            "kf_start_pref_f",
+            "kf_limit_pref_f",
+            "weight_decay",
+            "momentum",
+            "muon_momentum",
+            "adam_beta1",
+            "adam_beta2",
+            "lr_adjust",
+            "lr_adjust_coeff",
+            "muon_2d_only",
+            "min_2d_dim",
+        ]
+
+        # Extract optimizer parameters from training section
+        extracted_cfg = {}
+        for key in optimizer_keys:
+            if key in training_cfg:
+                extracted_cfg[key] = training_cfg.pop(key)
+
+        # Convert opt_type to type for new format
+        if "opt_type" in extracted_cfg:
+            extracted_cfg["type"] = extracted_cfg.pop("opt_type")
+
+        # Merge with existing optimizer config (conversion takes precedence)
+        optimizer_cfg = {**optimizer_cfg, **extracted_cfg}

-    if warning:
-        warnings.warn(
-            "Placing optimizer parameters (opt_type, kf_blocksize, etc.) in the training section "
-            "is deprecated. Use a separate 'optimizer' section with 'type' field instead.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
+        if warning:
+            warnings.warn(
+                "Placing optimizer parameters (opt_type, kf_blocksize, etc.) in the training section "
+                "is deprecated. Use a separate 'optimizer' section with 'type' field instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
+    # Case 2: Fill in missing defaults
+    # If type is not specified, default to Adam
+    if "type" not in optimizer_cfg:
+        optimizer_cfg["type"] = default_optimizer["type"]
+
+    # Fill in defaults for Adam optimizer type
+    if optimizer_cfg["type"] in ("Adam", "AdamW"):
+        for key, value in default_optimizer.items():
+            if key not in optimizer_cfg:
+                optimizer_cfg[key] = value
+
+    # Set/update the optimizer section
+    jdata["optimizer"] = optimizer_cfg

     return jdata
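
To make the two code paths concrete, here is a hedged walk-through of the legacy branch with a made-up config (the key names come from the diff; the values are illustrative):

    from deepmd.utils.compat import convert_optimizer_to_new_format

    # Old-style input: optimizer keys live in the "training" section.
    cfg = {
        "training": {"opt_type": "Adam", "adam_beta1": 0.85, "numb_steps": 1000},
    }
    cfg = convert_optimizer_to_new_format(cfg, warning=False)

    # The legacy keys are popped out of "training", "opt_type" becomes
    # "type", and the remaining Adam defaults are filled in:
    #   cfg["training"]  -> {"numb_steps": 1000}
    #   cfg["optimizer"] -> {"type": "Adam", "adam_beta1": 0.85,
    #                        "adam_beta2": 0.999, "weight_decay": 0.0}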

source/tests/pd/test_training.py

Lines changed: 8 additions & 0 deletions

@@ -21,6 +21,9 @@
 from deepmd.pd.utils.finetune import (
     get_finetune_rules,
 )
+from deepmd.utils.compat import (
+    convert_optimizer_to_new_format,
+)

 from .model.test_permutation import (
     model_dpa1,
@@ -151,6 +154,7 @@ def setUp(self) -> None:
         input_json = str(Path(__file__).parent / "water/se_atten.json")
         with open(input_json) as f:
             self.config = json.load(f)
+        self.config = convert_optimizer_to_new_format(self.config, warning=False)
         data_file = [str(Path(__file__).parent / "water/data/data_0")]
         self.config["training"]["training_data"]["systems"] = data_file
         self.config["training"]["validation_data"]["systems"] = data_file
@@ -168,6 +172,7 @@ def setUp(self) -> None:
         input_json = str(Path(__file__).parent / "water/se_atten.json")
         with open(input_json) as f:
             self.config = json.load(f)
+        self.config = convert_optimizer_to_new_format(self.config, warning=False)
         data_file = [str(Path(__file__).parent / "water/data/data_0")]
         self.config["training"]["training_data"]["systems"] = data_file
         self.config["training"]["validation_data"]["systems"] = data_file
@@ -188,6 +193,7 @@ def setUp(self) -> None:
         input_json = str(Path(__file__).parent / "water/se_atten.json")
         with open(input_json) as f:
             self.config = json.load(f)
+        self.config = convert_optimizer_to_new_format(self.config, warning=False)
         data_file = [str(Path(__file__).parent / "water/data/data_0")]
         self.config["training"]["training_data"]["systems"] = data_file
         self.config["training"]["validation_data"]["systems"] = data_file
@@ -209,6 +215,7 @@ def setUp(self) -> None:
         input_json = str(Path(__file__).parent / "water/se_atten.json")
         with open(input_json) as f:
             self.config = json.load(f)
+        self.config = convert_optimizer_to_new_format(self.config, warning=False)
         data_file = [str(Path(__file__).parent / "water/data/data_0")]
         self.config["training"]["training_data"]["systems"] = data_file
         self.config["training"]["validation_data"]["systems"] = data_file
@@ -225,6 +232,7 @@ def setUp(self) -> None:
         input_json = str(Path(__file__).parent / "water/se_atten.json")
        with open(input_json) as f:
             self.config = json.load(f)
+        self.config = convert_optimizer_to_new_format(self.config, warning=False)
         data_file = [str(Path(__file__).parent / "water/data/data_0")]
         self.config["training"]["training_data"]["systems"] = data_file
         self.config["training"]["validation_data"]["systems"] = data_file
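
The repeated setUp change reflects the relocation above: conversion now happens in the entry points, so tests that build a Trainer directly from a raw JSON file must convert the config themselves. A sketch of that pattern, assuming a local se_atten.json as in the diff:

    import json

    from deepmd.utils.compat import convert_optimizer_to_new_format

    with open("water/se_atten.json") as f:
        config = json.load(f)
    # The Trainer no longer converts internally, so do it before handing
    # the config over; warning=False keeps test logs quiet.
    config = convert_optimizer_to_new_format(config, warning=False)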
