Skip to content

Commit a0bd530

Browse files
wanghan-iapcmHan Wang
andauthored
feat(pt_expt): atomic model (#5220)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added end-to-end bias/statistics workflow for atomic models (compute, load, apply, and update output biases). * Introduced PyTorch-experimental atomic model wrappers with serialization/export compatibility. * Added comprehensive statistics utilities for global and per-atom outputs. * **Bug Fixes** * Improved tensor→array conversion to handle gradient-enabled tensors robustly. * **Tests** * Added extensive tests covering stats, bias workflows, serialization, export, and consistency. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent 4f182bc commit a0bd530

File tree

11 files changed

+2359
-2
lines changed

11 files changed

+2359
-2
lines changed

deepmd/dpmodel/atomic_model/base_atomic_model.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import math
3+
from collections.abc import (
4+
Callable,
5+
)
36
from typing import (
47
Any,
58
)
@@ -30,6 +33,9 @@
3033
map_atom_exclude_types,
3134
map_pair_exclude_types,
3235
)
36+
from deepmd.utils.path import (
37+
DPPath,
38+
)
3339

3440
from .make_base_atomic_model import (
3541
make_base_atomic_model,
@@ -246,6 +252,180 @@ def call(
246252
aparam=aparam,
247253
)
248254

255+
def get_intensive(self) -> bool:
256+
"""Whether the fitting property is intensive."""
257+
return False
258+
259+
def get_compute_stats_distinguish_types(self) -> bool:
260+
"""Get whether the fitting net computes stats which are not distinguished between different types of atoms."""
261+
return True
262+
263+
def compute_or_load_out_stat(
264+
self,
265+
merged: Callable[[], list[dict]] | list[dict],
266+
stat_file_path: DPPath | None = None,
267+
) -> None:
268+
"""
269+
Compute the output statistics (e.g. energy bias) for the fitting net from packed data.
270+
271+
Parameters
272+
----------
273+
merged : Union[Callable[[], list[dict]], list[dict]]
274+
- list[dict]: A list of data samples from various data systems.
275+
Each element, `merged[i]`, is a data dictionary containing `keys`: `np.ndarray`
276+
originating from the `i`-th data system.
277+
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
278+
only when needed. Since the sampling process can be slow and memory-intensive,
279+
the lazy function helps by only sampling once.
280+
stat_file_path : Optional[DPPath]
281+
The path to the stat file.
282+
283+
"""
284+
self.change_out_bias(
285+
merged,
286+
stat_file_path=stat_file_path,
287+
bias_adjust_mode="set-by-statistic",
288+
)
289+
290+
def change_out_bias(
291+
self,
292+
sample_merged: Callable[[], list[dict]] | list[dict],
293+
stat_file_path: DPPath | None = None,
294+
bias_adjust_mode: str = "change-by-statistic",
295+
) -> None:
296+
"""Change the output bias according to the input data and the pretrained model.
297+
298+
Parameters
299+
----------
300+
sample_merged : Union[Callable[[], list[dict]], list[dict]]
301+
- list[dict]: A list of data samples from various data systems.
302+
Each element, `merged[i]`, is a data dictionary containing `keys`: `np.ndarray`
303+
originating from the `i`-th data system.
304+
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
305+
only when needed. Since the sampling process can be slow and memory-intensive,
306+
the lazy function helps by only sampling once.
307+
bias_adjust_mode : str
308+
The mode for changing output bias : ['change-by-statistic', 'set-by-statistic']
309+
'change-by-statistic' : perform predictions on labels of target dataset,
310+
and do least square on the errors to obtain the target shift as bias.
311+
'set-by-statistic' : directly use the statistic output bias in the target dataset.
312+
stat_file_path : Optional[DPPath]
313+
The path to the stat file.
314+
"""
315+
from deepmd.dpmodel.utils.stat import (
316+
compute_output_stats,
317+
)
318+
319+
if bias_adjust_mode == "change-by-statistic":
320+
delta_bias, out_std = compute_output_stats(
321+
sample_merged,
322+
self.get_ntypes(),
323+
keys=list(self.atomic_output_def().keys()),
324+
stat_file_path=stat_file_path,
325+
model_forward=self._get_forward_wrapper_func(),
326+
rcond=self.rcond,
327+
preset_bias=self.preset_out_bias,
328+
stats_distinguish_types=self.get_compute_stats_distinguish_types(),
329+
intensive=self.get_intensive(),
330+
)
331+
self._store_out_stat(delta_bias, out_std, add=True)
332+
elif bias_adjust_mode == "set-by-statistic":
333+
bias_out, std_out = compute_output_stats(
334+
sample_merged,
335+
self.get_ntypes(),
336+
keys=list(self.atomic_output_def().keys()),
337+
stat_file_path=stat_file_path,
338+
rcond=self.rcond,
339+
preset_bias=self.preset_out_bias,
340+
stats_distinguish_types=self.get_compute_stats_distinguish_types(),
341+
intensive=self.get_intensive(),
342+
)
343+
self._store_out_stat(bias_out, std_out)
344+
else:
345+
raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode)
346+
347+
def _store_out_stat(
348+
self,
349+
out_bias: dict[str, np.ndarray],
350+
out_std: dict[str, np.ndarray],
351+
add: bool = False,
352+
) -> None:
353+
"""Store output bias and std into the model."""
354+
ntypes = self.get_ntypes()
355+
out_bias_data = np.array(to_numpy_array(self.out_bias))
356+
out_std_data = np.array(to_numpy_array(self.out_std))
357+
for kk in out_bias.keys():
358+
assert kk in out_std.keys()
359+
idx = self._get_bias_index(kk)
360+
size = self._varsize(self.atomic_output_def()[kk].shape)
361+
if not add:
362+
out_bias_data[idx, :, :size] = out_bias[kk].reshape(ntypes, size)
363+
else:
364+
out_bias_data[idx, :, :size] += out_bias[kk].reshape(ntypes, size)
365+
out_std_data[idx, :, :size] = out_std[kk].reshape(ntypes, size)
366+
self.out_bias = out_bias_data
367+
self.out_std = out_std_data
368+
369+
def _get_forward_wrapper_func(self) -> Callable[..., dict[str, np.ndarray]]:
370+
"""Get a forward wrapper of the atomic model for output bias calculation."""
371+
import array_api_compat
372+
373+
from deepmd.dpmodel.utils.nlist import (
374+
extend_input_and_build_neighbor_list,
375+
)
376+
377+
def model_forward(
378+
coord: np.ndarray,
379+
atype: np.ndarray,
380+
box: np.ndarray | None,
381+
fparam: np.ndarray | None = None,
382+
aparam: np.ndarray | None = None,
383+
) -> dict[str, np.ndarray]:
384+
# Get reference array to determine the target array type and device
385+
# Use out_bias as reference since it's always present
386+
ref_array = self.out_bias
387+
xp = array_api_compat.array_namespace(ref_array)
388+
389+
# Convert numpy inputs to the model's array type with correct device
390+
device = array_api_compat.device(ref_array)
391+
coord = xp.asarray(coord, device=device)
392+
atype = xp.asarray(atype, device=device)
393+
if box is not None:
394+
if np.allclose(box, 0.0):
395+
box = None
396+
else:
397+
box = xp.asarray(box, device=device)
398+
if fparam is not None:
399+
fparam = xp.asarray(fparam, device=device)
400+
if aparam is not None:
401+
aparam = xp.asarray(aparam, device=device)
402+
403+
(
404+
extended_coord,
405+
extended_atype,
406+
mapping,
407+
nlist,
408+
) = extend_input_and_build_neighbor_list(
409+
coord,
410+
atype,
411+
self.get_rcut(),
412+
self.get_sel(),
413+
mixed_types=self.mixed_types(),
414+
box=box,
415+
)
416+
atomic_ret = self.forward_common_atomic(
417+
extended_coord,
418+
extended_atype,
419+
nlist,
420+
mapping=mapping,
421+
fparam=fparam,
422+
aparam=aparam,
423+
)
424+
# Convert outputs back to numpy arrays
425+
return {kk: to_numpy_array(vv) for kk, vv in atomic_ret.items()}
426+
427+
return model_forward
428+
249429
def serialize(self) -> dict:
250430
return {
251431
"type_map": self.type_map,

deepmd/dpmodel/common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,8 @@ def to_numpy_array(x: Optional["Array"]) -> np.ndarray | None:
121121
try:
122122
# asarray is not within Array API standard, so may fail
123123
return np.asarray(x)
124-
except (ValueError, AttributeError, TypeError):
124+
except (ValueError, AttributeError, TypeError, RuntimeError):
125+
# RuntimeError: handles torch tensors with requires_grad=True
125126
xp = array_api_compat.array_namespace(x)
126127
# to fix BufferError: Cannot export readonly array since signalling readonly is unsupported by DLPack.
127128
# Move to CPU device to ensure numpy compatibility

0 commit comments

Comments
 (0)