Skip to content
Open
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions deepmd/backend/pt_expt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from collections.abc import (
Callable,
)
from importlib.util import (
find_spec,
)
from typing import (
TYPE_CHECKING,
ClassVar,
)

from deepmd.backend.backend import (
Backend,
)

if TYPE_CHECKING:
from argparse import (
Namespace,
)

from deepmd.infer.deep_eval import (
DeepEvalBackend,
)
from deepmd.utils.neighbor_stat import (
NeighborStat,
)


@Backend.register("pt-expt")
@Backend.register("pytorch-exportable")
class PyTorchExportableBackend(Backend):
    """PyTorch exportable backend (``.pte`` model files).

    All heavy imports are deferred into the hook properties so that merely
    registering the backend does not require ``torch`` to be importable.
    """

    name = "PyTorch Exportable"
    """The formal name of the backend."""
    features: ClassVar[Backend.Feature] = (
        Backend.Feature.ENTRY_POINT
        | Backend.Feature.DEEP_EVAL
        | Backend.Feature.NEIGHBOR_STAT
        | Backend.Feature.IO
    )
    """The features of the backend."""
    suffixes: ClassVar[list[str]] = [".pte"]
    """The suffixes of the backend."""

    def is_available(self) -> bool:
        """Check if the backend is available.

        Returns
        -------
        bool
            True when the ``torch`` package is importable.
        """
        return find_spec("torch") is not None

    @property
    def entry_point_hook(self) -> Callable[["Namespace"], None]:
        """The entry point hook of the backend.

        Returns
        -------
        Callable[[Namespace], None]
            The ``deepmd.pt`` command-line entry point.
        """
        from deepmd.pt.entrypoints.main import main as deepmd_main

        return deepmd_main

    @property
    def deep_eval(self) -> type["DeepEvalBackend"]:
        """The Deep Eval backend of the backend.

        Returns
        -------
        type[DeepEvalBackend]
            The ``deepmd.pt`` DeepEval implementation.
        """
        from deepmd.pt.infer.deep_eval import DeepEval as DeepEvalPT

        return DeepEvalPT

    @property
    def neighbor_stat(self) -> type["NeighborStat"]:
        """The neighbor statistics of the backend.

        Returns
        -------
        type[NeighborStat]
            The ``deepmd.pt`` NeighborStat implementation.
        """
        from deepmd.pt.utils.neighbor_stat import NeighborStat

        return NeighborStat

    @property
    def serialize_hook(self) -> Callable[[str], dict]:
        """The serialize hook to convert the model file to a dictionary.

        Returns
        -------
        Callable[[str], dict]
            The serialize hook of the backend.
        """
        from deepmd.pt.utils.serialization import serialize_from_file

        return serialize_from_file

    @property
    def deserialize_hook(self) -> Callable[[str, dict], None]:
        """The deserialize hook to convert the dictionary to a model file.

        Returns
        -------
        Callable[[str, dict], None]
            The deserialize hook of the backend.
        """
        from deepmd.pt.utils.serialization import deserialize_to_file

        return deserialize_to_file
6 changes: 5 additions & 1 deletion deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,11 @@ def call(
sec = self.sel_cumsum

ng = self.neuron[-1]
gr = xp.zeros([nf * nloc, ng, 4], dtype=self.dstd.dtype)
gr = xp.zeros(
[nf * nloc, ng, 4],
dtype=input_dtype,
device=array_api_compat.device(coord_ext),
)
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
# merge nf and nloc axis, so for type_one_side == False,
# we don't require atype is the same in all frames
Expand Down
4 changes: 3 additions & 1 deletion deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,9 @@ def call(

ng = self.neuron[-1]
xyz_scatter = xp.zeros(
[nf, nloc, ng], dtype=get_xp_precision(xp, self.precision)
[nf, nloc, ng],
dtype=get_xp_precision(xp, self.precision),
device=array_api_compat.device(coord_ext),
)
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
rr = xp.astype(rr, xyz_scatter.dtype)
Expand Down
2 changes: 1 addition & 1 deletion deepmd/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def get_default_nthreads() -> tuple[int, int]:
), int(
os.environ.get(
"DP_INTER_OP_PARALLELISM_THREADS",
os.environ.get("TF_INTRA_OP_PARALLELISM_THREADS", "0"),
os.environ.get("TF_INTER_OP_PARALLELISM_THREADS", "0"),
)
)

Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@

# throw warnings if threads not set
set_default_nthreads()
inter_nthreads, intra_nthreads = get_default_nthreads()
intra_nthreads, inter_nthreads = get_default_nthreads()
if inter_nthreads > 0: # the behavior of 0 is not documented
torch.set_num_interop_threads(inter_nthreads)
if intra_nthreads > 0:
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt_expt/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
35 changes: 35 additions & 0 deletions deepmd/pt_expt/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import importlib
from typing import (
Any,
overload,
)

import numpy as np

from deepmd.pt_expt.utils import (
env,
)

torch = importlib.import_module("torch")


@overload
def to_torch_array(array: np.ndarray) -> torch.Tensor: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.


@overload
def to_torch_array(array: None) -> None: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.


@overload
def to_torch_array(array: torch.Tensor) -> torch.Tensor: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.


def to_torch_array(array: Any) -> torch.Tensor | None:
"""Convert input to a torch tensor on the pt-expt device."""
if array is None:
return None
if torch.is_tensor(array):
return array.to(device=env.DEVICE)
return torch.as_tensor(array, device=env.DEVICE)
16 changes: 16 additions & 0 deletions deepmd/pt_expt/descriptor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .base_descriptor import (
BaseDescriptor,
)
from .se_e2_a import (
DescrptSeA,
)
from .se_r import (
DescrptSeR,
)

__all__ = [
"BaseDescriptor",
"DescrptSeA",
"DescrptSeR",
]
10 changes: 10 additions & 0 deletions deepmd/pt_expt/descriptor/base_descriptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import importlib

from deepmd.dpmodel.descriptor import (
    make_base_descriptor,
)

# NOTE(review): torch is imported via importlib rather than a plain
# ``import torch`` — presumably to keep the dependency lazy/opaque to static
# tooling, matching the other pt_expt modules; confirm the intended reason.
torch = importlib.import_module("torch")

# Base descriptor class for the pt-expt backend: the generic descriptor
# interface specialized to ``torch.Tensor`` arrays, with "forward" as the
# name of the compute method.
BaseDescriptor = make_base_descriptor(torch.Tensor, "forward")
83 changes: 83 additions & 0 deletions deepmd/pt_expt/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import importlib
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP
from deepmd.pt_expt.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.pt_expt.utils import (
env,
)
from deepmd.pt_expt.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.pt_expt.utils.network import (
NetworkCollection,
)

torch = importlib.import_module("torch")


@BaseDescriptor.register("se_e2_a_expt")
@BaseDescriptor.register("se_a_expt")
class DescrptSeA(DescrptSeADP, torch.nn.Module):
    """se_e2_a descriptor for the pt-expt backend.

    Wraps the dpmodel (array-API) implementation in a ``torch.nn.Module`` so
    that its statistics become registered buffers, its networks become
    submodules, and calls are routed through ``Module.__call__`` for
    export/tracing.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Module.__init__ must run first so that ``_buffers``/``_modules``
        # exist when DescrptSeADP.__init__ assigns attributes intercepted
        # by __setattr__ below.
        torch.nn.Module.__init__(self)
        DescrptSeADP.__init__(self, *args, **kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Ensure torch.nn.Module.__call__ drives forward() for export/tracing.
        return torch.nn.Module.__call__(self, *args, **kwargs)

    def __setattr__(self, name: str, value: Any) -> None:
        # Re-route assignments from the dpmodel base class into torch-native
        # containers:
        #   davg/dstd  -> registered buffers (tracked by state_dict / .to())
        #   embeddings -> torch-backed NetworkCollection submodule
        #   emask      -> torch-backed PairExcludeMask submodule
        # The "_buffers"/"_modules" membership checks guard against
        # assignments made before torch.nn.Module.__init__ has run.
        if name in {"davg", "dstd"} and "_buffers" in self.__dict__:
            tensor = (
                None if value is None else torch.as_tensor(value, device=env.DEVICE)
            )
            if name in self._buffers:
                self._buffers[name] = tensor
                return
            # Register on first assignment so buffers are in state_dict and moved by .to().
            self.register_buffer(name, tensor)
            return
        if name == "embeddings" and "_modules" in self.__dict__:
            if value is not None and not isinstance(value, torch.nn.Module):
                # Convert a dpmodel collection (or its serialized dict form)
                # into the torch-backed NetworkCollection.
                if hasattr(value, "serialize"):
                    value = NetworkCollection.deserialize(value.serialize())
                elif isinstance(value, dict):
                    value = NetworkCollection.deserialize(value)
            return super().__setattr__(name, value)
        if name == "emask" and "_modules" in self.__dict__:
            if value is not None and not isinstance(value, torch.nn.Module):
                # Rebuild the pair-exclusion mask as a torch module,
                # preserving the configured excluded type pairs.
                value = PairExcludeMask(
                    self.ntypes, exclude_types=list(value.get_exclude_types())
                )
            return super().__setattr__(name, value)
        return super().__setattr__(name, value)

    def forward(
        self,
        extended_coord: torch.Tensor,
        extended_atype: torch.Tensor,
        nlist: torch.Tensor,
        extended_atype_embd: torch.Tensor | None = None,
        mapping: torch.Tensor | None = None,
        type_embedding: torch.Tensor | None = None,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor | None,
        torch.Tensor | None,
        torch.Tensor | None,
        torch.Tensor | None,
    ]:
        """Compute the descriptor by delegating to ``DescrptSeADP.call``.

        ``extended_atype_embd`` and ``type_embedding`` are accepted for
        interface compatibility but are unused here.
        """
        del extended_atype_embd, type_embedding
        descrpt, rot_mat, g2, h2, sw = self.call(
            extended_coord,
            extended_atype,
            nlist,
            mapping=mapping,
        )
        return descrpt, rot_mat, g2, h2, sw
Loading
Loading