Skip to content

Commit 91a0f35

Browse files
committed
clear
1 parent 9c7ce39 commit 91a0f35

File tree

3 files changed

+15
-2
lines changed

3 files changed

+15
-2
lines changed

deepmd/tf/entrypoints/change_bias.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,12 @@ def _change_bias_checkpoint_file(
196196
if stop_batch is None and num_epoch is not None:
197197
if num_epoch <= 0:
198198
raise ValueError("training.num_epoch must be positive.")
199+
# Apply sys_probs and auto_prob from original training config
200+
# to ensure stop_batch calculation matches the original training
201+
training_data_config = training_params.get("training_data", {})
202+
sys_probs = training_data_config.get("sys_probs", None)
203+
auto_prob = training_data_config.get("auto_prob", "prob_sys_size")
204+
data.set_sys_probs(sys_probs=sys_probs, auto_prob_style=auto_prob)
199205
total_numb_batch = compute_total_numb_batch(data.nbatches, data.sys_probs)
200206
if total_numb_batch <= 0:
201207
raise ValueError("Total number of training batches must be positive.")

deepmd/utils/argcheck.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3648,11 +3648,20 @@ def training_extra_check(data: dict | None) -> bool:
36483648
raise ValueError(
36493649
"training.num_epoch_dict is mutually exclusive with training.model_prob."
36503650
)
3651+
else:
3652+
if num_steps is None:
3653+
raise ValueError(
3654+
"Multi-task mode requires either training.numb_steps or training.num_epoch_dict."
3655+
)
36513656
else:
36523657
if num_steps is not None and num_epoch is not None:
36533658
raise ValueError(
36543659
"training.num_step and training.num_epoch are mutually exclusive."
36553660
)
3661+
if num_steps is None and num_epoch is None:
3662+
raise ValueError(
3663+
"Single-task mode requires either training.numb_steps or training.num_epoch."
3664+
)
36563665
return True
36573666

36583667
doc_training = "The training options."

source/tests/pt/test_sampler.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -440,8 +440,6 @@ def test_num_epoch_dict(self) -> None:
440440
sampler_2 = pt_dataloader.get_sampler_from_params(
441441
dataset_2, {"sys_probs": [0.4, 0.6], "auto_prob": "prob_sys_size"}
442442
)
443-
probs_1 = self._normalize_probs(np.asarray(sampler_1.weights))
444-
probs_2 = self._normalize_probs(np.asarray(sampler_2.weights))
445443

446444
# === Step 2. Compute per-task total_numb_batch ===
447445
per_task_total = np.array(

0 commit comments

Comments
 (0)