-
Notifications
You must be signed in to change notification settings - Fork 589
feat: new backend pytorch exportable. #5194
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 7 commits
ec2e031
b8a48ff
1cc001f
f2fbe88
e2afbe9
4ba511a
fb9598a
fa03351
09b33f1
67f2e54
f7d83dd
0c96bb6
9dca912
17f0a5d
8ce93ba
3091988
85f0583
d33324d
4de9a56
f4dc0af
2384835
9646d71
f270069
57433d3
eedcbaf
d8b2cf4
aeef15a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,126 @@ | ||
| # SPDX-License-Identifier: LGPL-3.0-or-later | ||
| from collections.abc import ( | ||
| Callable, | ||
| ) | ||
| from importlib.util import ( | ||
| find_spec, | ||
| ) | ||
| from typing import ( | ||
| TYPE_CHECKING, | ||
| ClassVar, | ||
| ) | ||
|
|
||
| from deepmd.backend.backend import ( | ||
| Backend, | ||
| ) | ||
|
|
||
| if TYPE_CHECKING: | ||
| from argparse import ( | ||
| Namespace, | ||
| ) | ||
|
|
||
| from deepmd.infer.deep_eval import ( | ||
| DeepEvalBackend, | ||
| ) | ||
| from deepmd.utils.neighbor_stat import ( | ||
| NeighborStat, | ||
| ) | ||
|
|
||
|
|
||
@Backend.register("pt-expt")
@Backend.register("pytorch-exportable")
class PyTorchExportableBackend(Backend):
    """Backend exposing the exportable-PyTorch ("pt-expt") implementation.

    Registered under both the short alias ``pt-expt`` and the long alias
    ``pytorch-exportable``.  All heavy hooks import lazily from
    ``deepmd.pt`` so that merely registering the backend does not pull in
    torch.
    """

    name = "PyTorch Exportable"
    """The formal name of the backend."""
    features: ClassVar[Backend.Feature] = (
        Backend.Feature.ENTRY_POINT
        | Backend.Feature.DEEP_EVAL
        | Backend.Feature.NEIGHBOR_STAT
        | Backend.Feature.IO
    )
    """The features of the backend."""
    suffixes: ClassVar[list[str]] = [".pth", ".pt"]
    """The suffixes of the backend."""
    # NOTE(review): the suffixes overlap those of the regular PyTorch
    # backend — confirm that backend resolution by file suffix stays
    # unambiguous.

    def is_available(self) -> bool:
        """Check if the backend is available.

        Returns
        -------
        bool
            Whether the backend is available.
        """
        # Probe for the torch package without actually importing it.
        return find_spec("torch") is not None

    @property
    def entry_point_hook(self) -> Callable[["Namespace"], None]:
        """The entry point hook of the backend.

        Returns
        -------
        Callable[[Namespace], None]
            The entry point hook of the backend.
        """
        # NOTE(review): delegates to deepmd.pt, not deepmd.pt_expt —
        # confirm the shared entry point is intentional.
        from deepmd.pt.entrypoints.main import (
            main as deepmd_main,
        )

        return deepmd_main

    @property
    def deep_eval(self) -> type["DeepEvalBackend"]:
        """The Deep Eval backend of the backend.

        Returns
        -------
        type[DeepEvalBackend]
            The Deep Eval backend of the backend.
        """
        from deepmd.pt.infer.deep_eval import (
            DeepEval as DeepEvalPT,
        )

        return DeepEvalPT

    @property
    def neighbor_stat(self) -> type["NeighborStat"]:
        """The neighbor statistics of the backend.

        Returns
        -------
        type[NeighborStat]
            The neighbor statistics of the backend.
        """
        from deepmd.pt.utils.neighbor_stat import NeighborStat

        return NeighborStat

    @property
    def serialize_hook(self) -> Callable[[str], dict]:
        """The serialize hook to convert the model file to a dictionary.

        Returns
        -------
        Callable[[str], dict]
            The serialize hook of the backend.
        """
        from deepmd.pt.utils.serialization import serialize_from_file

        return serialize_from_file

    @property
    def deserialize_hook(self) -> Callable[[str, dict], None]:
        """The deserialize hook to convert the dictionary to a model file.

        Returns
        -------
        Callable[[str, dict], None]
            The deserialize hook of the backend.
        """
        from deepmd.pt.utils.serialization import deserialize_to_file

        return deserialize_to_file
wanghan-iapcm marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| # SPDX-License-Identifier: LGPL-3.0-or-later |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| # SPDX-License-Identifier: LGPL-3.0-or-later | ||
| import importlib | ||
| from typing import ( | ||
| Any, | ||
| overload, | ||
| ) | ||
|
|
||
| import numpy as np | ||
|
|
||
| from deepmd.pt_expt.utils import ( | ||
| env, | ||
| ) | ||
|
|
||
| torch = importlib.import_module("torch") | ||
wanghan-iapcm marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| @overload | ||
| def to_torch_array(array: np.ndarray) -> torch.Tensor: ... | ||
Check noticeCode scanning / CodeQL Statement has no effect Note
This statement has no effect.
|
||
|
|
||
|
|
||
| @overload | ||
| def to_torch_array(array: None) -> None: ... | ||
Check noticeCode scanning / CodeQL Statement has no effect Note
This statement has no effect.
|
||
|
|
||
|
|
||
| @overload | ||
| def to_torch_array(array: torch.Tensor) -> torch.Tensor: ... | ||
Check noticeCode scanning / CodeQL Statement has no effect Note
This statement has no effect.
|
||
|
|
||
|
|
||
| def to_torch_array(array: Any) -> torch.Tensor | None: | ||
| """Convert input to a torch tensor on the pt-expt device.""" | ||
| if array is None: | ||
| return None | ||
| if torch.is_tensor(array): | ||
| return array | ||
| return torch.as_tensor(array, device=env.DEVICE) | ||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| # SPDX-License-Identifier: LGPL-3.0-or-later | ||
| from .base_descriptor import ( | ||
| BaseDescriptor, | ||
| ) | ||
| from .se_e2_a import ( | ||
| DescrptSeA, | ||
| ) | ||
|
|
||
| __all__ = [ | ||
| "BaseDescriptor", | ||
| "DescrptSeA", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| # SPDX-License-Identifier: LGPL-3.0-or-later | ||
| import importlib | ||
|
|
||
| from deepmd.dpmodel.descriptor import ( | ||
| make_base_descriptor, | ||
| ) | ||
|
|
||
| torch = importlib.import_module("torch") | ||
|
|
||
| BaseDescriptor = make_base_descriptor(torch.Tensor, "forward") |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,102 @@ | ||
| # SPDX-License-Identifier: LGPL-3.0-or-later | ||
| import importlib | ||
| from typing import ( | ||
| Any, | ||
| ) | ||
|
|
||
| from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP | ||
| from deepmd.pt_expt.descriptor.base_descriptor import ( | ||
| BaseDescriptor, | ||
| ) | ||
| from deepmd.pt_expt.utils import ( | ||
| env, | ||
| ) | ||
| from deepmd.pt_expt.utils.exclude_mask import ( | ||
| PairExcludeMask, | ||
| ) | ||
| from deepmd.pt_expt.utils.network import ( | ||
| NetworkCollection, | ||
| ) | ||
|
|
||
| torch = importlib.import_module("torch") | ||
|
|
||
|
|
||
@BaseDescriptor.register("se_e2_a_expt")
@BaseDescriptor.register("se_a_expt")
class DescrptSeA(DescrptSeADP, torch.nn.Module):
    """se_e2_a descriptor for the exportable PyTorch backend.

    Reuses the numerical implementation of the dpmodel ``DescrptSeADP``
    while mixing in ``torch.nn.Module`` so that the statistics arrays
    (``davg``/``dstd``) live as torch buffers and the embedding networks /
    exclude mask live as submodules, making the object exportable.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # nn.Module.__init__ must run first so that _buffers/_modules exist
        # before DescrptSeADP.__init__ assigns attributes via __setattr__.
        torch.nn.Module.__init__(self)
        DescrptSeADP.__init__(self, *args, **kwargs)
        # Re-home the numpy-side state created by DescrptSeADP.__init__
        # into torch buffers/submodules.
        self._convert_state()

    def __setattr__(self, name: str, value: Any) -> None:
        # Intercept assignments so parent-class (numpy) state is stored in
        # torch form.  The "_buffers"/"_modules" in self.__dict__ guards make
        # plain assignment work before nn.Module.__init__ has run.
        # NOTE(review): mixes bare `return` with `return super().__setattr__(...)`
        # (always None) — flagged by CodeQL; behavior is unaffected.
        if name in {"davg", "dstd"} and "_buffers" in self.__dict__:
            tensor = (
                None if value is None else torch.as_tensor(value, device=env.DEVICE)
            )
            if name in self._buffers:
                # Already registered by _convert_state: update the buffer in
                # place so state_dict/export keep tracking it.
                self._buffers[name] = tensor
                return
            return super().__setattr__(name, tensor)
        if name == "embeddings" and "_modules" in self.__dict__:
            if value is not None and not isinstance(value, torch.nn.Module):
                # Round-trip a dpmodel NetworkCollection (or its serialized
                # dict form) through the pt_expt NetworkCollection module.
                if hasattr(value, "serialize"):
                    value = NetworkCollection.deserialize(value.serialize())
                elif isinstance(value, dict):
                    value = NetworkCollection.deserialize(value)
            return super().__setattr__(name, value)
        if name == "emask" and "_modules" in self.__dict__:
            if value is not None and not isinstance(value, torch.nn.Module):
                # Rebuild the dpmodel exclude mask as the torch PairExcludeMask.
                value = PairExcludeMask(
                    self.ntypes, exclude_types=list(value.get_exclude_types())
                )
            return super().__setattr__(name, value)
        return super().__setattr__(name, value)

    def _convert_state(self) -> None:
        # Convert state left by DescrptSeADP.__init__ into torch objects.
        # davg/dstd: stored as buffers so they travel with state_dict/export.
        if self.davg is not None:
            davg = torch.as_tensor(self.davg, device=env.DEVICE)
            if "davg" in self._buffers:
                self._buffers["davg"] = davg
            else:
                # Drop any plain attribute first — register_buffer raises if
                # an attribute of the same name already exists.
                if hasattr(self, "davg"):
                    delattr(self, "davg")
                self.register_buffer("davg", davg)
        if self.dstd is not None:
            dstd = torch.as_tensor(self.dstd, device=env.DEVICE)
            if "dstd" in self._buffers:
                self._buffers["dstd"] = dstd
            else:
                if hasattr(self, "dstd"):
                    delattr(self, "dstd")
                self.register_buffer("dstd", dstd)
        # Re-assigning triggers __setattr__, which performs the module
        # conversion for embeddings/emask.
        if self.embeddings is not None:
            self.embeddings = NetworkCollection.deserialize(self.embeddings.serialize())
        if self.emask is not None:
            self.emask = PairExcludeMask(
                self.ntypes, exclude_types=list(self.emask.get_exclude_types())
            )

    def forward(
        self,
        nlist: torch.Tensor,
        extended_coord: torch.Tensor,
        extended_atype: torch.Tensor,
        extended_atype_embd: torch.Tensor | None = None,
        mapping: torch.Tensor | None = None,
        type_embedding: torch.Tensor | None = None,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor | None,
        torch.Tensor | None,
        torch.Tensor | None,
        torch.Tensor | None,
    ]:
        # Thin adapter: forward(nlist, coord, atype, ...) → call(coord, atype,
        # nlist, ...).  NOTE(review): the argument order deliberately differs
        # between the two signatures — presumably to match the pt backend's
        # forward convention vs the dpmodel call convention; confirm.
        # extended_atype_embd and type_embedding are accepted for interface
        # compatibility but unused by se_e2_a.
        del extended_atype_embd, type_embedding
        descrpt, rot_mat, g2, h2, sw = self.call(
            extended_coord,
            extended_atype,
            nlist,
            mapping=mapping,
        )
        return descrpt, rot_mat, g2, h2, sw
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| # SPDX-License-Identifier: LGPL-3.0-or-later | ||
|
|
||
| from .exclude_mask import ( | ||
| AtomExcludeMask, | ||
| PairExcludeMask, | ||
| ) | ||
|
|
||
| __all__ = [ | ||
| "AtomExcludeMask", | ||
| "PairExcludeMask", | ||
| ] |
Uh oh!
There was an error while loading. Please reload this page.