Skip to content

Commit 9bfcf24

Browse files
committed
fix
1 parent e5f2ef5 commit 9bfcf24

File tree

3 files changed

+22
-17
lines changed

3 files changed

+22
-17
lines changed

deepmd/dpmodel/utils/learning_rate.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,12 +105,14 @@ def __init__(
105105
self.warmup_steps = warmup_steps
106106

107107
# === Step 5. Validate step ranges (runtime check) ===
108-
if num_steps <= 0:
109-
raise ValueError("num_steps must be positive")
108+
if num_steps < 0:
109+
raise ValueError("num_steps must be non-negative")
110110
if self.warmup_steps < 0:
111111
raise ValueError("warmup_steps must be non-negative")
112-
if self.warmup_steps >= num_steps:
112+
if num_steps > 0 and self.warmup_steps >= num_steps:
113113
raise ValueError("warmup_steps must be smaller than num_steps")
114+
if num_steps == 0 and self.warmup_steps != 0:
115+
raise ValueError("warmup_steps must be 0 when num_steps is 0")
114116

115117
# === Step 6. Compute warmup_start_lr ===
116118
self.warmup_start_lr = warmup_start_factor * start_lr
@@ -457,6 +459,9 @@ def _decay_value(self, step: int | Array) -> Array:
457459
step = np.asarray(step)
458460
xp = array_api_compat.array_namespace(step)
459461
min_lr = self.start_lr * self.lr_min_factor
462+
# Handle decay_num_steps=0 (no training steps) - return start_lr
463+
if self.decay_num_steps == 0:
464+
return xp.full_like(step, self.start_lr, dtype=xp.float64)
460465
step_lr = self.start_lr * (
461466
self.lr_min_factor
462467
+ 0.5

deepmd/tf/loss/dos.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
6565
dos_hat = label_dict["dos"]
6666
atom_dos_hat = label_dict["atom_dos"]
6767

68-
find_dos = label_dict["find_dos"]
69-
find_atom_dos = label_dict["find_atom_dos"]
68+
find_dos = global_cvt_2_tf_float(label_dict["find_dos"])
69+
find_atom_dos = global_cvt_2_tf_float(label_dict["find_atom_dos"])
7070

7171
dos_reshape = tf.reshape(dos, [-1, self.numb_dos])
7272
dos_hat_reshape = tf.reshape(dos_hat, [-1, self.numb_dos])

deepmd/tf/loss/ener.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -149,14 +149,14 @@ def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
149149
virial_hat = label_dict["virial"]
150150
atom_ener_hat = label_dict["atom_ener"]
151151
atom_pref = label_dict["atom_pref"]
152-
find_energy = label_dict["find_energy"]
153-
find_force = label_dict["find_force"]
154-
find_virial = label_dict["find_virial"]
155-
find_atom_ener = label_dict["find_atom_ener"]
156-
find_atom_pref = label_dict["find_atom_pref"]
152+
find_energy = global_cvt_2_ener_float(label_dict["find_energy"])
153+
find_force = global_cvt_2_tf_float(label_dict["find_force"])
154+
find_virial = global_cvt_2_tf_float(label_dict["find_virial"])
155+
find_atom_ener = global_cvt_2_tf_float(label_dict["find_atom_ener"])
156+
find_atom_pref = global_cvt_2_tf_float(label_dict["find_atom_pref"])
157157
if self.has_gf:
158158
drdq = label_dict["drdq"]
159-
find_drdq = label_dict["find_drdq"]
159+
find_drdq = global_cvt_2_tf_float(label_dict["find_drdq"])
160160
else:
161161
find_drdq = 0.0
162162

@@ -589,10 +589,10 @@ def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
589589
virial_label = label_dict["virial"]
590590
atom_ener_label = label_dict["atom_ener"]
591591
atom_pref = label_dict["atom_pref"]
592-
find_energy = label_dict["find_energy"]
593-
find_force = label_dict["find_force"]
594-
find_virial = label_dict["find_virial"]
595-
find_atom_ener = label_dict["find_atom_ener"]
592+
find_energy = global_cvt_2_ener_float(label_dict["find_energy"])
593+
find_force = global_cvt_2_tf_float(label_dict["find_force"])
594+
find_virial = global_cvt_2_tf_float(label_dict["find_virial"])
595+
find_atom_ener = global_cvt_2_tf_float(label_dict["find_atom_ener"])
596596

597597
if self.enable_atom_ener_coeff:
598598
# when ener_coeff (\nu) is defined, the energy is defined as
@@ -932,8 +932,8 @@ def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
932932

933933
energy_hat = label_dict["energy"]
934934
ener_dipole_hat = label_dict["energy_dipole"]
935-
find_energy = label_dict["find_energy"]
936-
find_ener_dipole = label_dict["find_energy_dipole"]
935+
find_energy = global_cvt_2_ener_float(label_dict["find_energy"])
936+
find_ener_dipole = global_cvt_2_ener_float(label_dict["find_energy_dipole"])
937937

938938
l2_ener_loss = tf.reduce_mean(
939939
tf.square(energy - energy_hat), name="l2_" + suffix

0 commit comments

Comments
 (0)