|
1 | 1 | # SPDX-License-Identifier: LGPL-3.0-or-later |
2 | 2 | import math |
| 3 | +from collections.abc import ( |
| 4 | + Callable, |
| 5 | +) |
3 | 6 | from typing import ( |
4 | 7 | Any, |
5 | 8 | ) |
|
30 | 33 | map_atom_exclude_types, |
31 | 34 | map_pair_exclude_types, |
32 | 35 | ) |
| 36 | +from deepmd.utils.path import ( |
| 37 | + DPPath, |
| 38 | +) |
33 | 39 |
|
34 | 40 | from .make_base_atomic_model import ( |
35 | 41 | make_base_atomic_model, |
@@ -246,6 +252,180 @@ def call( |
246 | 252 | aparam=aparam, |
247 | 253 | ) |
248 | 254 |
|
| 255 | + def get_intensive(self) -> bool: |
| 256 | + """Whether the fitting property is intensive.""" |
| 257 | + return False |
| 258 | + |
| 259 | + def get_compute_stats_distinguish_types(self) -> bool: |
| 260 | + """Get whether the fitting net computes stats that are distinguished between different types of atoms.""" |
| 261 | + return True |
| 262 | + |
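# Context sketch for the two flags above (simplified, numbers made up for
# illustration): an extensive output such as the energy scales with the number
# of atoms, so its bias is fitted per atom type and summed over atoms, while an
# intensive output is size-independent. When stats distinguish types, the
# per-frame type counts act as the design matrix of that fit.
import numpy as np

per_type_bias = np.array([-3.1, -7.4])   # hypothetical bias for two atom types
type_count = np.array([[4, 4], [8, 8]])  # per-frame atom counts of each type
extensive = type_count @ per_type_bias          # grows with system size
intensive = extensive / type_count.sum(axis=1)  # per-atom, size-independent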
| 263 | + def compute_or_load_out_stat( |
| 264 | + self, |
| 265 | + merged: Callable[[], list[dict]] | list[dict], |
| 266 | + stat_file_path: DPPath | None = None, |
| 267 | + ) -> None: |
| 268 | + """ |
| 269 | + Compute the output statistics (e.g. energy bias) for the fitting net from packed data. |
| 270 | +
|
| 271 | + Parameters |
| 272 | + ---------- |
| 273 | + merged : Union[Callable[[], list[dict]], list[dict]] |
| 274 | + - list[dict]: A list of data samples from various data systems. |
| 275 | + Each element, `merged[i]`, is a data dictionary mapping data keys to |
| 276 | + `np.ndarray` values, originating from the `i`-th data system. |
| 277 | + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format |
| 278 | + only when needed. Since the sampling process can be slow and memory-intensive, |
| 279 | + the lazy function helps by only sampling once. |
| 280 | + stat_file_path : Optional[DPPath] |
| 281 | + The path to the stat file. |
| 282 | +
|
| 283 | + """ |
| 284 | + self.change_out_bias( |
| 285 | + merged, |
| 286 | + stat_file_path=stat_file_path, |
| 287 | + bias_adjust_mode="set-by-statistic", |
| 288 | + ) |
| 289 | + |
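# A minimal usage sketch for compute_or_load_out_stat; `model` (a constructed
# atomic model) and `sampled_systems` (a list of {data key: np.ndarray} dicts,
# one per data system) are hypothetical placeholders. Passing a lazy callable
# instead of the list itself defers the slow, memory-heavy sampling until the
# statistics are actually computed.
def lazy_sample() -> list[dict]:
    # In practice this would draw frames from the training data systems.
    return sampled_systems

model.compute_or_load_out_stat(
    lazy_sample,          # or the eager list `sampled_systems` itself
    stat_file_path=None,  # a DPPath may be supplied to cache/load the statistics
)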
| 290 | + def change_out_bias( |
| 291 | + self, |
| 292 | + sample_merged: Callable[[], list[dict]] | list[dict], |
| 293 | + stat_file_path: DPPath | None = None, |
| 294 | + bias_adjust_mode: str = "change-by-statistic", |
| 295 | + ) -> None: |
| 296 | + """Change the output bias according to the input data and the pretrained model. |
| 297 | +
|
| 298 | + Parameters |
| 299 | + ---------- |
| 300 | + sample_merged : Union[Callable[[], list[dict]], list[dict]] |
| 301 | + - list[dict]: A list of data samples from various data systems. |
| 302 | + Each element, `merged[i]`, is a data dictionary mapping data keys to |
| 303 | + `np.ndarray` values, originating from the `i`-th data system. |
| 304 | + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format |
| 305 | + only when needed. Since the sampling process can be slow and memory-intensive, |
| 306 | + the lazy function helps by only sampling once. |
| 307 | + bias_adjust_mode : str |
| 308 | + The mode for changing output bias: ['change-by-statistic', 'set-by-statistic'] |
| 309 | + 'change-by-statistic' : run the current model on the target dataset and |
| 310 | + perform a least-squares fit on the prediction errors to obtain the bias shift. |
| 311 | + 'set-by-statistic' : directly use the statistical output bias of the target dataset. |
| 312 | + stat_file_path : Optional[DPPath] |
| 313 | + The path to the stat file. |
| 314 | + """ |
| 315 | + from deepmd.dpmodel.utils.stat import ( |
| 316 | + compute_output_stats, |
| 317 | + ) |
| 318 | + |
| 319 | + if bias_adjust_mode == "change-by-statistic": |
| 320 | + delta_bias, out_std = compute_output_stats( |
| 321 | + sample_merged, |
| 322 | + self.get_ntypes(), |
| 323 | + keys=list(self.atomic_output_def().keys()), |
| 324 | + stat_file_path=stat_file_path, |
| 325 | + model_forward=self._get_forward_wrapper_func(), |
| 326 | + rcond=self.rcond, |
| 327 | + preset_bias=self.preset_out_bias, |
| 328 | + stats_distinguish_types=self.get_compute_stats_distinguish_types(), |
| 329 | + intensive=self.get_intensive(), |
| 330 | + ) |
| 331 | + self._store_out_stat(delta_bias, out_std, add=True) |
| 332 | + elif bias_adjust_mode == "set-by-statistic": |
| 333 | + bias_out, std_out = compute_output_stats( |
| 334 | + sample_merged, |
| 335 | + self.get_ntypes(), |
| 336 | + keys=list(self.atomic_output_def().keys()), |
| 337 | + stat_file_path=stat_file_path, |
| 338 | + rcond=self.rcond, |
| 339 | + preset_bias=self.preset_out_bias, |
| 340 | + stats_distinguish_types=self.get_compute_stats_distinguish_types(), |
| 341 | + intensive=self.get_intensive(), |
| 342 | + ) |
| 343 | + self._store_out_stat(bias_out, std_out) |
| 344 | + else: |
| 345 | + raise RuntimeError(f"Unknown bias_adjust_mode: {bias_adjust_mode}") |
| 346 | + |
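# Sketch of the idea behind "change-by-statistic" (simplified; the actual work
# is delegated to deepmd.dpmodel.utils.stat.compute_output_stats): the residual
# between the labels and the current predictions is modelled as a per-type
# bias, recovered by a least-squares fit over the sampled frames.
import numpy as np

def fit_delta_bias(type_count, labels, predictions, rcond=None):
    # type_count:   (nframes, ntypes) atoms of each type per frame
    # labels:       (nframes,) labelled extensive property, e.g. energy
    # predictions:  (nframes,) current model predictions on the same frames
    residual = labels - predictions
    # Solve type_count @ delta_bias ~= residual in the least-squares sense.
    delta_bias, *_ = np.linalg.lstsq(type_count, residual, rcond=rcond)
    return delta_bias  # (ntypes,) shift that change_out_bias adds to the stored bias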
| 347 | + def _store_out_stat( |
| 348 | + self, |
| 349 | + out_bias: dict[str, np.ndarray], |
| 350 | + out_std: dict[str, np.ndarray], |
| 351 | + add: bool = False, |
| 352 | + ) -> None: |
| 353 | + """Store output bias and std into the model.""" |
| 354 | + ntypes = self.get_ntypes() |
| 355 | + out_bias_data = np.array(to_numpy_array(self.out_bias)) |
| 356 | + out_std_data = np.array(to_numpy_array(self.out_std)) |
| 357 | + for kk in out_bias: |
| 358 | + assert kk in out_std |
| 359 | + idx = self._get_bias_index(kk) |
| 360 | + size = self._varsize(self.atomic_output_def()[kk].shape) |
| 361 | + if not add: |
| 362 | + out_bias_data[idx, :, :size] = out_bias[kk].reshape(ntypes, size) |
| 363 | + else: |
| 364 | + out_bias_data[idx, :, :size] += out_bias[kk].reshape(ntypes, size) |
| 365 | + out_std_data[idx, :, :size] = out_std[kk].reshape(ntypes, size) |
| 366 | + self.out_bias = out_bias_data |
| 367 | + self.out_std = out_std_data |
| 368 | + |
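# Illustration of the layout assumed by _store_out_stat (shapes inferred from
# the indexing above): out_bias and out_std hold one slice per output variable,
# shaped (n_output_vars, ntypes, max_var_size), with shorter outputs padded
# along the last axis. Numbers below are made up.
import numpy as np

ntypes, max_var_size = 2, 3
out_bias = np.zeros((1, ntypes, max_var_size))  # a single scalar output, e.g. energy
energy_bias = np.array([-93.2, -187.6])         # hypothetical per-type bias, size 1
out_bias[0, :, :1] = energy_bias.reshape(ntypes, 1)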
| 369 | + def _get_forward_wrapper_func(self) -> Callable[..., dict[str, np.ndarray]]: |
| 370 | + """Get a forward wrapper of the atomic model for output bias calculation.""" |
| 371 | + import array_api_compat |
| 372 | + |
| 373 | + from deepmd.dpmodel.utils.nlist import ( |
| 374 | + extend_input_and_build_neighbor_list, |
| 375 | + ) |
| 376 | + |
| 377 | + def model_forward( |
| 378 | + coord: np.ndarray, |
| 379 | + atype: np.ndarray, |
| 380 | + box: np.ndarray | None, |
| 381 | + fparam: np.ndarray | None = None, |
| 382 | + aparam: np.ndarray | None = None, |
| 383 | + ) -> dict[str, np.ndarray]: |
| 384 | + # Get reference array to determine the target array type and device |
| 385 | + # Use out_bias as reference since it's always present |
| 386 | + ref_array = self.out_bias |
| 387 | + xp = array_api_compat.array_namespace(ref_array) |
| 388 | + |
| 389 | + # Convert numpy inputs to the model's array type with correct device |
| 390 | + device = array_api_compat.device(ref_array) |
| 391 | + coord = xp.asarray(coord, device=device) |
| 392 | + atype = xp.asarray(atype, device=device) |
| 393 | + if box is not None: |
| 394 | + if np.allclose(box, 0.0): |
| 395 | + box = None |
| 396 | + else: |
| 397 | + box = xp.asarray(box, device=device) |
| 398 | + if fparam is not None: |
| 399 | + fparam = xp.asarray(fparam, device=device) |
| 400 | + if aparam is not None: |
| 401 | + aparam = xp.asarray(aparam, device=device) |
| 402 | + |
| 403 | + ( |
| 404 | + extended_coord, |
| 405 | + extended_atype, |
| 406 | + mapping, |
| 407 | + nlist, |
| 408 | + ) = extend_input_and_build_neighbor_list( |
| 409 | + coord, |
| 410 | + atype, |
| 411 | + self.get_rcut(), |
| 412 | + self.get_sel(), |
| 413 | + mixed_types=self.mixed_types(), |
| 414 | + box=box, |
| 415 | + ) |
| 416 | + atomic_ret = self.forward_common_atomic( |
| 417 | + extended_coord, |
| 418 | + extended_atype, |
| 419 | + nlist, |
| 420 | + mapping=mapping, |
| 421 | + fparam=fparam, |
| 422 | + aparam=aparam, |
| 423 | + ) |
| 424 | + # Convert outputs back to numpy arrays |
| 425 | + return {kk: to_numpy_array(vv) for kk, vv in atomic_ret.items()} |
| 426 | + |
| 427 | + return model_forward |
| 428 | + |
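# A minimal sketch of how the wrapper is consumed; `model`, `coord`, `atype`
# and `box` are hypothetical. Within this class the wrapper is only handed to
# compute_output_stats as `model_forward`; its job is to accept plain numpy
# inputs and return plain numpy outputs regardless of the model's array backend.
forward = model._get_forward_wrapper_func()
atomic_out = forward(
    coord,   # coordinates as np.ndarray, e.g. shape (nframes, natoms, 3)
    atype,   # atom types as np.ndarray, shape (nframes, natoms)
    box,     # simulation cell, or None / all zeros for non-periodic systems
)
# atomic_out maps each atomic output name, e.g. "energy", to an np.ndarray.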
249 | 429 | def serialize(self) -> dict: |
250 | 430 | return { |
251 | 431 | "type_map": self.type_map, |
|