Skip to content
Open
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions deepmd/backend/pt_expt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from collections.abc import (
Callable,
)
from importlib.util import (
find_spec,
)
from typing import (
TYPE_CHECKING,
ClassVar,
)

from deepmd.backend.backend import (
Backend,
)

if TYPE_CHECKING:
from argparse import (
Namespace,
)

from deepmd.infer.deep_eval import (
DeepEvalBackend,
)
from deepmd.utils.neighbor_stat import (
NeighborStat,
)


@Backend.register("pt-expt")
@Backend.register("pytorch-exportable")
class PyTorchExportableBackend(Backend):
"""PyTorch exportable backend."""

name = "PyTorch Exportable"
"""The formal name of the backend."""
features: ClassVar[Backend.Feature] = (
Backend.Feature.ENTRY_POINT
| Backend.Feature.DEEP_EVAL
| Backend.Feature.NEIGHBOR_STAT
| Backend.Feature.IO
)
"""The features of the backend."""
suffixes: ClassVar[list[str]] = [".pte"]
"""The suffixes of the backend."""

def is_available(self) -> bool:
"""Check if the backend is available.

Returns
-------
bool
Whether the backend is available.
"""
return find_spec("torch") is not None

@property
def entry_point_hook(self) -> Callable[["Namespace"], None]:
"""The entry point hook of the backend.

Returns
-------
Callable[[Namespace], None]
The entry point hook of the backend.
"""
from deepmd.pt.entrypoints.main import main as deepmd_main

return deepmd_main

@property
def deep_eval(self) -> type["DeepEvalBackend"]:
"""The Deep Eval backend of the backend.

Returns
-------
type[DeepEvalBackend]
The Deep Eval backend of the backend.
"""
raise NotImplementedError

@property
def neighbor_stat(self) -> type["NeighborStat"]:
"""The neighbor statistics of the backend.

Returns
-------
type[NeighborStat]
The neighbor statistics of the backend.
"""
raise NotImplementedError

@property
def serialize_hook(self) -> Callable[[str], dict]:
"""The serialize hook to convert the model file to a dictionary.

Returns
-------
Callable[[str], dict]
The serialize hook of the backend.
"""
raise NotImplementedError

@property
def deserialize_hook(self) -> Callable[[str, dict], None]:
"""The deserialize hook to convert the dictionary to a model file.

Returns
-------
Callable[[str, dict], None]
The deserialize hook of the backend.
"""
raise NotImplementedError
9 changes: 7 additions & 2 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,9 +909,14 @@ def compute_input_stats(
self.stats = env_mat_stat.stats
mean, stddev = env_mat_stat()
xp = array_api_compat.array_namespace(self.stddev)
device = array_api_compat.device(self.stddev)
if not self.set_davg_zero:
self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True)
self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, copy=True)
self.mean = xp.asarray(
mean, dtype=self.mean.dtype, copy=True, device=device
)
self.stddev = xp.asarray(
stddev, dtype=self.stddev.dtype, copy=True, device=device
)

def get_stats(self) -> dict[str, StatItem]:
"""Get the statistics of the descriptor."""
Expand Down
9 changes: 7 additions & 2 deletions deepmd/dpmodel/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,9 +453,14 @@ def compute_input_stats(
self.stats = env_mat_stat.stats
mean, stddev = env_mat_stat()
xp = array_api_compat.array_namespace(self.stddev)
device = array_api_compat.device(self.stddev)
if not self.set_davg_zero:
self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True)
self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, copy=True)
self.mean = xp.asarray(
mean, dtype=self.mean.dtype, copy=True, device=device
)
self.stddev = xp.asarray(
stddev, dtype=self.stddev.dtype, copy=True, device=device
)

def get_stats(self) -> dict[str, StatItem]:
"""Get the statistics of the descriptor."""
Expand Down
9 changes: 7 additions & 2 deletions deepmd/dpmodel/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,9 +417,14 @@ def compute_input_stats(
self.stats = env_mat_stat.stats
mean, stddev = env_mat_stat()
xp = array_api_compat.array_namespace(self.stddev)
device = array_api_compat.device(self.stddev)
if not self.set_davg_zero:
self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True)
self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, copy=True)
self.mean = xp.asarray(
mean, dtype=self.mean.dtype, copy=True, device=device
)
self.stddev = xp.asarray(
stddev, dtype=self.stddev.dtype, copy=True, device=device
)

def get_stats(self) -> dict[str, StatItem]:
"""Get the statistics of the descriptor."""
Expand Down
13 changes: 10 additions & 3 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,9 +350,12 @@ def compute_input_stats(
self.stats = env_mat_stat.stats
mean, stddev = env_mat_stat()
xp = array_api_compat.array_namespace(self.dstd)
device = array_api_compat.device(self.dstd)
if not self.set_davg_zero:
self.davg = xp.asarray(mean, dtype=self.davg.dtype, copy=True)
self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True)
self.davg = xp.asarray(
mean, dtype=self.davg.dtype, copy=True, device=device
)
self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True, device=device)

def set_stat_mean_and_stddev(
self,
Expand Down Expand Up @@ -607,7 +610,11 @@ def call(
sec = self.sel_cumsum

ng = self.neuron[-1]
gr = xp.zeros([nf * nloc, ng, 4], dtype=self.dstd.dtype)
gr = xp.zeros(
[nf * nloc, ng, 4],
dtype=input_dtype,
device=array_api_compat.device(coord_ext),
)
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
# merge nf and nloc axis, so for type_one_side == False,
# we don't require atype is the same in all frames
Expand Down
11 changes: 8 additions & 3 deletions deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,12 @@ def compute_input_stats(
self.stats = env_mat_stat.stats
mean, stddev = env_mat_stat()
xp = array_api_compat.array_namespace(self.dstd)
device = array_api_compat.device(self.dstd)
if not self.set_davg_zero:
self.davg = xp.asarray(mean, dtype=self.davg.dtype, copy=True)
self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True)
self.davg = xp.asarray(
mean, dtype=self.davg.dtype, copy=True, device=device
)
self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True, device=device)

def set_stat_mean_and_stddev(
self,
Expand Down Expand Up @@ -391,7 +394,9 @@ def call(

ng = self.neuron[-1]
xyz_scatter = xp.zeros(
[nf, nloc, ng], dtype=get_xp_precision(xp, self.precision)
[nf, nloc, ng],
dtype=get_xp_precision(xp, self.precision),
device=array_api_compat.device(coord_ext),
)
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
rr = xp.astype(rr, xyz_scatter.dtype)
Expand Down
7 changes: 5 additions & 2 deletions deepmd/dpmodel/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,12 @@ def compute_input_stats(
self.stats = env_mat_stat.stats
mean, stddev = env_mat_stat()
xp = array_api_compat.array_namespace(self.dstd)
device = array_api_compat.device(self.dstd)
if not self.set_davg_zero:
self.davg = xp.asarray(mean, dtype=self.davg.dtype, copy=True)
self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True)
self.davg = xp.asarray(
mean, dtype=self.davg.dtype, copy=True, device=device
)
self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True, device=device)

def set_stat_mean_and_stddev(
self,
Expand Down
9 changes: 7 additions & 2 deletions deepmd/dpmodel/descriptor/se_t_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,9 +694,14 @@ def compute_input_stats(
self.stats = env_mat_stat.stats
mean, stddev = env_mat_stat()
xp = array_api_compat.array_namespace(self.stddev)
device = array_api_compat.device(self.stddev)
if not self.set_davg_zero:
self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True)
self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, copy=True)
self.mean = xp.asarray(
mean, dtype=self.mean.dtype, copy=True, device=device
)
self.stddev = xp.asarray(
stddev, dtype=self.stddev.dtype, copy=True, device=device
)

def get_stats(self) -> dict[str, StatItem]:
"""Get the statistics of the descriptor."""
Expand Down
2 changes: 1 addition & 1 deletion deepmd/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def get_default_nthreads() -> tuple[int, int]:
), int(
os.environ.get(
"DP_INTER_OP_PARALLELISM_THREADS",
os.environ.get("TF_INTRA_OP_PARALLELISM_THREADS", "0"),
os.environ.get("TF_INTER_OP_PARALLELISM_THREADS", "0"),
)
)

Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@

# throw warnings if threads not set
set_default_nthreads()
inter_nthreads, intra_nthreads = get_default_nthreads()
intra_nthreads, inter_nthreads = get_default_nthreads()
if inter_nthreads > 0: # the behavior of 0 is not documented
torch.set_num_interop_threads(inter_nthreads)
if intra_nthreads > 0:
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt_expt/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
35 changes: 35 additions & 0 deletions deepmd/pt_expt/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import importlib
from typing import (
Any,
overload,
)

import numpy as np

from deepmd.pt_expt.utils import (
env,
)

torch = importlib.import_module("torch")


@overload
def to_torch_array(array: np.ndarray) -> torch.Tensor: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.


@overload
def to_torch_array(array: None) -> None: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.


@overload
def to_torch_array(array: torch.Tensor) -> torch.Tensor: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.


def to_torch_array(array: Any) -> torch.Tensor | None:
"""Convert input to a torch tensor on the pt-expt device."""
if array is None:
return None
if torch.is_tensor(array):
return array.to(device=env.DEVICE)
return torch.as_tensor(array, device=env.DEVICE)
16 changes: 16 additions & 0 deletions deepmd/pt_expt/descriptor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .base_descriptor import (
BaseDescriptor,
)
from .se_e2_a import (
DescrptSeA,
)
from .se_r import (
DescrptSeR,
)

__all__ = [
"BaseDescriptor",
"DescrptSeA",
"DescrptSeR",
]
10 changes: 10 additions & 0 deletions deepmd/pt_expt/descriptor/base_descriptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import importlib

from deepmd.dpmodel.descriptor import (
make_base_descriptor,
)

torch = importlib.import_module("torch")

BaseDescriptor = make_base_descriptor(torch.Tensor, "forward")
Loading