
Commit 63db1a4

merge common functions
1 parent e625e39 commit 63db1a4


7 files changed: +235 -299 lines changed


deepmd/dpmodel/utils/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -36,6 +36,11 @@
     save_dp_model,
     traverse_model_dict,
 )
+from .training_utils import (
+    compute_total_numb_batch,
+    resolve_model_prob,
+    resolve_model_prob_from_epochs,
+)
 
 __all__ = [
     "AtomExcludeMask",
@@ -49,6 +54,7 @@
     "aggregate",
     "build_multiple_neighbor_list",
     "build_neighbor_list",
+    "compute_total_numb_batch",
     "extend_coord_with_ghosts",
     "get_graph_index",
     "get_multiple_nlist_key",
@@ -60,6 +66,8 @@
     "nlist_distinguish_types",
     "normalize_coord",
     "phys2inter",
+    "resolve_model_prob",
+    "resolve_model_prob_from_epochs",
     "save_dp_model",
     "to_face_distance",
     "traverse_model_dict",
deepmd/dpmodel/utils/training_utils.py

Lines changed: 188 additions & 0 deletions
@@ -0,0 +1,188 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import logging
+from collections.abc import (
+    Iterable,
+)
+
+import numpy as np
+
+log = logging.getLogger(__name__)
+
+
+def compute_total_numb_batch(
+    numb_batches: Iterable[int],
+    sampler_weights: np.ndarray,
+) -> int:
+    """Compute total number of batches considering sampler weights.
+
+    Parameters
+    ----------
+    numb_batches : Iterable[int]
+        Number of batches for each data system.
+    sampler_weights : np.ndarray
+        Sampling weights for each data system.
+
+    Returns
+    -------
+    int
+        Total number of batches.
+
+    Raises
+    ------
+    ValueError
+        If input validation fails.
+    """
+    weights = np.asarray(sampler_weights, dtype=np.float64)
+    if weights.ndim != 1:
+        raise ValueError("Sampler weights must be 1D.")
+    if weights.size == 0:
+        raise ValueError("Sampler weights are empty.")
+    if not np.all(np.isfinite(weights)):
+        raise ValueError("Sampler weights must be finite.")
+    if np.any(weights < 0.0):
+        raise ValueError("Sampler weights must be non-negative.")
+    weight_sum = float(np.sum(weights))
+    if weight_sum <= 0.0:
+        raise ValueError("Sampler weights must sum to a positive value.")
+    probs = weights / weight_sum
+    nbatches = np.asarray(numb_batches, dtype=np.float64)
+    if nbatches.ndim != 1:
+        raise ValueError("Number of batches must be 1D.")
+    if nbatches.size == 0:
+        raise ValueError("Number of batches is empty.")
+    if not np.all(np.isfinite(nbatches)):
+        raise ValueError("Number of batches must be finite.")
+    if np.any(nbatches < 0.0):
+        raise ValueError("Number of batches must be non-negative.")
+    if nbatches.shape[0] != probs.shape[0]:
+        raise ValueError("Number of batches and sampler weights must match.")
+    valid = probs > 0.0
+    if not np.any(valid):
+        raise ValueError(
+            "Sampler probabilities must contain at least one positive entry."
+        )
+    return int(np.ceil(np.max(nbatches[valid] / probs[valid])))
+
+
+def resolve_model_prob(
+    model_keys: list[str],
+    model_prob_config: dict[str, float] | None,
+    model_training_data: dict[str, object],
+    rank: int = 0,
+) -> np.ndarray:
+    """Resolve model training probability for multi-task training.
+
+    Parameters
+    ----------
+    model_keys : list[str]
+        List of model keys.
+    model_prob_config : dict[str, float] | None
+        User-specified model probabilities. If None, use data size.
+    model_training_data : dict[str, object]
+        Training data for each model.
+    rank : int, optional
+        Process rank for distributed training, by default 0.
+
+    Returns
+    -------
+    np.ndarray
+        Normalized model probabilities.
+
+    Raises
+    ------
+    ValueError
+        If input validation fails.
+    """
+    model_prob = np.zeros(len(model_keys), dtype=np.float64)
+    if model_prob_config:
+        missing = [k for k in model_keys if k not in model_prob_config]
+        if missing:
+            raise ValueError(
+                f"training.model_prob must specify all tasks; missing: {missing}"
+            )
+        for ii, model_key in enumerate(model_keys):
+            if model_key in model_prob_config:
+                model_prob[ii] = float(model_prob_config[model_key])
+    else:
+        if rank == 0:
+            log.info(
+                "training.model_prob is not set or empty; defaulting to the "
+                "number of systems per task."
+            )
+        for ii, model_key in enumerate(model_keys):
+            model_prob[ii] = float(len(model_training_data[model_key]))
+    if not np.all(np.isfinite(model_prob)):
+        raise ValueError("Model prob must be finite.")
+    if np.any(model_prob < 0.0):
+        raise ValueError("Model prob must be non-negative.")
+    sum_prob = float(np.sum(model_prob))
+    if sum_prob <= 0.0:
+        raise ValueError("Sum of model prob must be larger than 0!")
+    return model_prob / sum_prob
+
+
+def resolve_model_prob_from_epochs(
+    model_keys: list[str],
+    num_epoch_dict_config: dict[str, float],
+    per_task_total: np.ndarray,
+) -> tuple[np.ndarray, int, dict[str, float]]:
+    """Resolve model probability and training steps from epoch configuration.
+
+    Parameters
+    ----------
+    model_keys : list[str]
+        List of model keys.
+    num_epoch_dict_config : dict[str, float]
+        Target epochs for each task.
+    per_task_total : np.ndarray
+        Total batches per task.
+
+    Returns
+    -------
+    tuple[np.ndarray, int, dict[str, float]]
+        Model probabilities, total training steps, and per-task steps.
+
+    Raises
+    ------
+    ValueError
+        If input validation fails.
+    """
+    if not num_epoch_dict_config:
+        raise ValueError("training.num_epoch_dict must be set for multi-task epochs.")
+    missing = [k for k in model_keys if k not in num_epoch_dict_config]
+    if missing:
+        raise ValueError(
+            f"training.num_epoch_dict must specify all tasks; missing: {missing}"
+        )
+    epoch_targets = np.zeros(len(model_keys), dtype=np.float64)
+    for ii, model_key in enumerate(model_keys):
+        epoch_value = num_epoch_dict_config[model_key]
+        if epoch_value is None:
+            raise ValueError(
+                f"training.num_epoch_dict['{model_key}'] must be positive."
+            )
+        epoch_value = float(epoch_value)
+        if not np.isfinite(epoch_value) or epoch_value <= 0.0:
+            raise ValueError(
+                f"training.num_epoch_dict['{model_key}'] must be positive, got {epoch_value}."
+            )
+        epoch_targets[ii] = epoch_value
+    per_task_total = np.asarray(per_task_total, dtype=np.float64)
+    if per_task_total.ndim != 1:
+        raise ValueError("Per-task total batches must be 1D.")
+    if per_task_total.shape[0] != epoch_targets.shape[0]:
+        raise ValueError("Per-task totals and epoch targets must match.")
+    if not np.all(np.isfinite(per_task_total)):
+        raise ValueError("Per-task total batches must be finite.")
+    if np.any(per_task_total <= 0.0):
+        raise ValueError("Per-task total batches must be positive.")
+    per_task_steps = per_task_total * epoch_targets
+    total_target_steps = float(np.sum(per_task_steps))
+    if total_target_steps <= 0.0:
+        raise ValueError("Sum of target steps must be positive.")
+    model_prob = per_task_steps / total_target_steps
+    num_steps = int(np.ceil(total_target_steps))
+    per_task_steps_map = {
+        model_key: float(per_task_steps[ii]) for ii, model_key in enumerate(model_keys)
+    }
+    return model_prob, num_steps, per_task_steps_map
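
Minimal usage sketches for the three helpers above (task names and numbers are hypothetical; the import path follows the new deepmd.dpmodel.utils exports):

compute_total_numb_batch returns the smallest step budget that, in expectation, covers every system at its sampling probability: the maximum of numb_batches[i] / prob[i], rounded up.

    import numpy as np
    from deepmd.dpmodel.utils import compute_total_numb_batch

    # Three data systems with 100, 50, and 10 batches, sampled 2:1:1.
    # probs = [0.5, 0.25, 0.25]; max(100/0.5, 50/0.25, 10/0.25) = 200.
    assert compute_total_numb_batch([100, 50, 10], np.array([2.0, 1.0, 1.0])) == 200

resolve_model_prob normalizes user-specified weights, or falls back to the number of systems per task when training.model_prob is absent:

    import numpy as np
    from deepmd.dpmodel.utils import resolve_model_prob

    model_keys = ["water", "metal"]  # hypothetical task names
    # Explicit weights from training.model_prob:
    prob = resolve_model_prob(model_keys, {"water": 3.0, "metal": 1.0}, {})
    assert np.allclose(prob, [0.75, 0.25])
    # No config: weight each task by its number of training systems
    # (stand-in lists play the role of the per-task data here).
    data = {"water": [0, 1, 2], "metal": [0]}
    assert np.allclose(resolve_model_prob(model_keys, None, data), [0.75, 0.25])

resolve_model_prob_from_epochs converts per-task epoch targets into sampling probabilities and a total step count:

    import numpy as np
    from deepmd.dpmodel.utils import resolve_model_prob_from_epochs

    # 2 epochs over 100 batches and 4 epochs over 50 batches give
    # 200 target steps per task: equal sampling, 400 steps in total.
    prob, num_steps, steps_map = resolve_model_prob_from_epochs(
        ["water", "metal"],
        {"water": 2, "metal": 4},
        np.array([100.0, 50.0]),
    )
    assert np.allclose(prob, [0.5, 0.5])
    assert num_steps == 400
    assert steps_map == {"water": 200.0, "metal": 200.0}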

deepmd/pd/train/training.py

Lines changed: 7 additions & 110 deletions
@@ -30,8 +30,11 @@
 from deepmd.common import (
     symlink_prefix_files,
 )
-from deepmd.dpmodel.utils.learning_rate import (
-    BaseLR,
+from deepmd.dpmodel.utils.learning_rate import BaseLR
+from deepmd.dpmodel.utils import (
+    compute_total_numb_batch,
+    resolve_model_prob,
+    resolve_model_prob_from_epochs,
 )
 from deepmd.loggers.training import (
     format_training_message,
@@ -213,114 +216,6 @@ def get_dataloader_and_buffer(
         valid_numb_batch,
     )
 
-def compute_total_numb_batch(numb_batches, sampler_weights) -> int:
-    weights = np.asarray(sampler_weights, dtype=np.float64)
-    if weights.ndim != 1:
-        raise ValueError("Sampler weights must be 1D.")
-    if weights.size == 0:
-        raise ValueError("Sampler weights are empty.")
-    if not np.all(np.isfinite(weights)):
-        raise ValueError("Sampler weights must be finite.")
-    if np.any(weights < 0.0):
-        raise ValueError("Sampler weights must be non-negative.")
-    weight_sum = float(np.sum(weights))
-    if weight_sum <= 0.0:
-        raise ValueError("Sampler weights must sum to a positive value.")
-    probs = weights / weight_sum
-    nbatches = np.asarray(numb_batches, dtype=np.float64)
-    if nbatches.ndim != 1:
-        raise ValueError("Number of batches must be 1D.")
-    if nbatches.size == 0:
-        raise ValueError("Number of batches is empty.")
-    if not np.all(np.isfinite(nbatches)):
-        raise ValueError("Number of batches must be finite.")
-    if np.any(nbatches < 0.0):
-        raise ValueError("Number of batches must be non-negative.")
-    if nbatches.shape[0] != probs.shape[0]:
-        raise ValueError("Number of batches and sampler weights must match.")
-    valid = probs > 0.0
-    if not np.any(valid):
-        raise ValueError(
-            "Sampler probabilities must contain at least one positive entry."
-        )
-    return int(np.ceil(np.max(nbatches[valid] / probs[valid])))
-
-def resolve_model_prob(
-    model_keys,
-    model_prob_config,
-    model_training_data,
-) -> np.ndarray:
-    model_prob = np.zeros(len(model_keys), dtype=np.float64)
-    if model_prob_config:
-        missing = [k for k in model_keys if k not in model_prob_config]
-        if missing:
-            raise ValueError(
-                f"training.model_prob must specify all tasks; missing: {missing}"
-            )
-        for ii, model_key in enumerate(model_keys):
-            if model_key in model_prob_config:
-                model_prob[ii] = float(model_prob_config[model_key])
-    else:
-        for ii, model_key in enumerate(model_keys):
-            model_prob[ii] = float(len(model_training_data[model_key]))
-    if not np.all(np.isfinite(model_prob)):
-        raise ValueError("Model prob must be finite.")
-    if np.any(model_prob < 0.0):
-        raise ValueError("Model prob must be non-negative.")
-    sum_prob = float(np.sum(model_prob))
-    if sum_prob <= 0.0:
-        raise ValueError("Sum of model prob must be larger than 0!")
-    return model_prob / sum_prob
-
-def resolve_model_prob_from_epochs(
-    model_keys,
-    num_epoch_dict_config,
-    per_task_total,
-) -> tuple[np.ndarray, int, dict[str, float]]:
-    if not num_epoch_dict_config:
-        raise ValueError(
-            "training.num_epoch_dict must be set for multi-task epochs."
-        )
-    missing = [k for k in model_keys if k not in num_epoch_dict_config]
-    if missing:
-        raise ValueError(
-            "training.num_epoch_dict must specify all tasks; "
-            f"missing: {missing}"
-        )
-    epoch_targets = np.zeros(len(model_keys), dtype=np.float64)
-    for ii, model_key in enumerate(model_keys):
-        epoch_value = num_epoch_dict_config[model_key]
-        if epoch_value is None:
-            raise ValueError(
-                f"training.num_epoch_dict['{model_key}'] must be positive."
-            )
-        epoch_value = float(epoch_value)
-        if not np.isfinite(epoch_value) or epoch_value <= 0.0:
-            raise ValueError(
-                f"training.num_epoch_dict['{model_key}'] must be positive, got {epoch_value}."
-            )
-        epoch_targets[ii] = epoch_value
-    per_task_total = np.asarray(per_task_total, dtype=np.float64)
-    if per_task_total.ndim != 1:
-        raise ValueError("Per-task total batches must be 1D.")
-    if per_task_total.shape[0] != epoch_targets.shape[0]:
-        raise ValueError("Per-task totals and epoch targets must match.")
-    if not np.all(np.isfinite(per_task_total)):
-        raise ValueError("Per-task total batches must be finite.")
-    if np.any(per_task_total <= 0.0):
-        raise ValueError("Per-task total batches must be positive.")
-    per_task_steps = per_task_total * epoch_targets
-    total_target_steps = float(np.sum(per_task_steps))
-    if total_target_steps <= 0.0:
-        raise ValueError("Sum of target steps must be positive.")
-    model_prob = per_task_steps / total_target_steps
-    num_steps = int(np.ceil(total_target_steps))
-    per_task_steps_map = {
-        model_key: float(per_task_steps[ii])
-        for ii, model_key in enumerate(model_keys)
-    }
-    return model_prob, num_steps, per_task_steps_map
-
 def single_model_stat(
     _model: Any,
     _data_stat_nbatch: int,
@@ -567,6 +462,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
             self.model_keys,
             training_params.get("model_prob"),
             training_data,
+            rank=self.rank,
         )
 
         # Learning rate
@@ -760,6 +656,7 @@ def single_model_finetune(
             self.model_keys,
             training_params.get("model_prob"),
             training_data,
+            rank=self.rank,
         )
 
         # Multi-task share params
