126 changes: 126 additions & 0 deletions deepmd/backend/pt_expt.py
@@ -0,0 +1,126 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from collections.abc import (
Callable,
)
from importlib.util import (
find_spec,
)
from typing import (
TYPE_CHECKING,
ClassVar,
)

from deepmd.backend.backend import (
Backend,
)

if TYPE_CHECKING:
from argparse import (
Namespace,
)

from deepmd.infer.deep_eval import (
DeepEvalBackend,
)
from deepmd.utils.neighbor_stat import (
NeighborStat,
)


@Backend.register("pt-expt")
@Backend.register("pytorch-exportable")
class PyTorchExportableBackend(Backend):
"""PyTorch exportable backend."""

name = "PyTorch Exportable"
"""The formal name of the backend."""
features: ClassVar[Backend.Feature] = (
Backend.Feature.ENTRY_POINT
| Backend.Feature.DEEP_EVAL
| Backend.Feature.NEIGHBOR_STAT
| Backend.Feature.IO
)
"""The features of the backend."""
suffixes: ClassVar[list[str]] = [".pth", ".pt"]
"""The suffixes of the backend."""

def is_available(self) -> bool:
"""Check if the backend is available.

Returns
-------
bool
Whether the backend is available.
"""
return find_spec("torch") is not None

@property
def entry_point_hook(self) -> Callable[["Namespace"], None]:
"""The entry point hook of the backend.

Returns
-------
Callable[[Namespace], None]
The entry point hook of the backend.
"""
from deepmd.pt.entrypoints.main import main as deepmd_main

return deepmd_main

@property
def deep_eval(self) -> type["DeepEvalBackend"]:
"""The Deep Eval backend of the backend.

Returns
-------
type[DeepEvalBackend]
The Deep Eval backend of the backend.
"""
from deepmd.pt.infer.deep_eval import DeepEval as DeepEvalPT

return DeepEvalPT

@property
def neighbor_stat(self) -> type["NeighborStat"]:
"""The neighbor statistics of the backend.

Returns
-------
type[NeighborStat]
The neighbor statistics of the backend.
"""
from deepmd.pt.utils.neighbor_stat import (
NeighborStat,
)

return NeighborStat

@property
def serialize_hook(self) -> Callable[[str], dict]:
"""The serialize hook to convert the model file to a dictionary.

Returns
-------
Callable[[str], dict]
The serialize hook of the backend.
"""
from deepmd.pt.utils.serialization import (
serialize_from_file,
)

return serialize_from_file

@property
def deserialize_hook(self) -> Callable[[str, dict], None]:
"""The deserialize hook to convert the dictionary to a model file.

Returns
-------
Callable[[str, dict], None]
The deserialize hook of the backend.
"""
from deepmd.pt.utils.serialization import (
deserialize_to_file,
)

return deserialize_to_file
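A hedged usage sketch (not part of the diff): the two register() calls above alias the same class in the backend registry, so either name should resolve to it. Backend.get_backend is assumed here to be the registry lookup deepmd uses for its other backends.

```python
from deepmd.backend.backend import Backend

# Both registered aliases are assumed to resolve to the same class.
backend = Backend.get_backend("pt-expt")
assert backend is Backend.get_backend("pytorch-exportable")
print(backend.name)      # "PyTorch Exportable"
print(backend.suffixes)  # [".pth", ".pt"]
```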
6 changes: 5 additions & 1 deletion deepmd/dpmodel/descriptor/se_e2_a.py
@@ -607,7 +607,11 @@ def call(
sec = self.sel_cumsum

ng = self.neuron[-1]
-        gr = xp.zeros([nf * nloc, ng, 4], dtype=self.dstd.dtype)
+        gr = xp.zeros(
+            [nf * nloc, ng, 4],
+            dtype=input_dtype,
+            device=array_api_compat.device(coord_ext),
+        )
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
# merge nf and nloc axis, so for type_one_side == False,
# we don't require atype is the same in all frames
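The allocation above now takes its dtype from the input (`input_dtype`, defined earlier in `call`) and its device from `coord_ext`, instead of `self.dstd.dtype`, which keeps the array-API path portable when the inputs are, say, GPU torch tensors. A minimal sketch of the pattern, independent of this class:

```python
import array_api_compat
import numpy as np

def zeros_matching(x, shape):
    # Allocate on the same namespace, dtype, and device as the input,
    # so the helper works unchanged for NumPy, torch, JAX, and so on.
    xp = array_api_compat.array_namespace(x)
    return xp.zeros(shape, dtype=x.dtype, device=array_api_compat.device(x))

out = zeros_matching(np.ones((2, 3), dtype=np.float32), (4, 4))
assert out.dtype == np.float32
```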
1 change: 1 addition & 0 deletions deepmd/pt_expt/__init__.py
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
12 changes: 12 additions & 0 deletions deepmd/pt_expt/descriptor/__init__.py
@@ -0,0 +1,12 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .base_descriptor import (
BaseDescriptor,
)
from .se_e2_a import (
DescrptSeA,
)

__all__ = [
"BaseDescriptor",
"DescrptSeA",
]
10 changes: 10 additions & 0 deletions deepmd/pt_expt/descriptor/base_descriptor.py
@@ -0,0 +1,10 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import importlib

from deepmd.dpmodel.descriptor import (
make_base_descriptor,
)

torch = importlib.import_module("torch")

BaseDescriptor = make_base_descriptor(torch.Tensor, "forward")
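`make_base_descriptor` parameterizes the abstract descriptor interface by array type and entry-point name, so this `BaseDescriptor` expects `torch.Tensor` arguments and a `forward` method. A hedged sketch of how the registry on the generated base is typically used (the key and class are hypothetical):

```python
import torch

from deepmd.pt_expt.descriptor import BaseDescriptor

@BaseDescriptor.register("my_toy_descriptor")  # hypothetical key
class ToyDescriptor(BaseDescriptor, torch.nn.Module):
    # A real subclass would implement the abstract interface
    # (forward, serialize/deserialize, ...), as DescrptSeA does below.
    ...
```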
102 changes: 102 additions & 0 deletions deepmd/pt_expt/descriptor/se_e2_a.py
@@ -0,0 +1,102 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import importlib
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP
from deepmd.pt_expt.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.pt_expt.utils import (
env,
)
from deepmd.pt_expt.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.pt_expt.utils.network import (
NetworkCollection,
)

torch = importlib.import_module("torch")


@BaseDescriptor.register("se_e2_a_expt")
@BaseDescriptor.register("se_a_expt")
class DescrptSeA(DescrptSeADP, torch.nn.Module):
def __init__(self, *args: Any, **kwargs: Any) -> None:
torch.nn.Module.__init__(self)
DescrptSeADP.__init__(self, *args, **kwargs)
self._convert_state()

def __setattr__(self, name: str, value: Any) -> None:

[CodeQL notice, code scanning: Explicit returns mixed with implicit (fall through) returns. Mixing implicit and explicit returns may indicate an error, as implicit returns always return None.]
if name in {"davg", "dstd"} and "_buffers" in self.__dict__:
tensor = (
None if value is None else torch.as_tensor(value, device=env.DEVICE)
)
if name in self._buffers:
self._buffers[name] = tensor
return
return super().__setattr__(name, tensor)
if name == "embeddings" and "_modules" in self.__dict__:
if value is not None and not isinstance(value, torch.nn.Module):
if hasattr(value, "serialize"):
value = NetworkCollection.deserialize(value.serialize())
elif isinstance(value, dict):
value = NetworkCollection.deserialize(value)
return super().__setattr__(name, value)
if name == "emask" and "_modules" in self.__dict__:
if value is not None and not isinstance(value, torch.nn.Module):
value = PairExcludeMask(
self.ntypes, exclude_types=list(value.get_exclude_types())
)
return super().__setattr__(name, value)
return super().__setattr__(name, value)

def _convert_state(self) -> None:
if self.davg is not None:
davg = torch.as_tensor(self.davg, device=env.DEVICE)
if "davg" in self._buffers:
self._buffers["davg"] = davg
else:
if hasattr(self, "davg"):
delattr(self, "davg")
self.register_buffer("davg", davg)
if self.dstd is not None:
dstd = torch.as_tensor(self.dstd, device=env.DEVICE)
if "dstd" in self._buffers:
self._buffers["dstd"] = dstd
else:
if hasattr(self, "dstd"):
delattr(self, "dstd")
self.register_buffer("dstd", dstd)
if self.embeddings is not None:
self.embeddings = NetworkCollection.deserialize(self.embeddings.serialize())
if self.emask is not None:
self.emask = PairExcludeMask(
self.ntypes, exclude_types=list(self.emask.get_exclude_types())
)

def forward(
self,
nlist: torch.Tensor,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
extended_atype_embd: torch.Tensor | None = None,
mapping: torch.Tensor | None = None,
type_embedding: torch.Tensor | None = None,
) -> tuple[
torch.Tensor,
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor | None,
]:
del extended_atype_embd, type_embedding
descrpt, rot_mat, g2, h2, sw = self.call(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
)
return descrpt, rot_mat, g2, h2, sw
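A hedged usage sketch (constructor arguments and shapes are illustrative, not from the diff): because the class mixes in `torch.nn.Module` and `_convert_state` turns its statistics into buffers, it can be called like an ordinary module, which is the prerequisite for tracing with `torch.export`:

```python
import torch

from deepmd.pt_expt.descriptor import DescrptSeA

descrpt = DescrptSeA(rcut=6.0, rcut_smth=0.5, sel=[46, 92])
nf, nloc, nall = 1, 3, 5
nnei = sum([46, 92])
coord_ext = torch.rand(nf, nall * 3, dtype=torch.float64)
atype_ext = torch.zeros(nf, nall, dtype=torch.long)
nlist = torch.full((nf, nloc, nnei), -1, dtype=torch.long)  # -1 = no neighbor

descrpt_out, rot_mat, g2, h2, sw = descrpt(nlist, coord_ext, atype_ext)
# If the forward graph is export-friendly, one would then try:
# exported = torch.export.export(descrpt, (nlist, coord_ext, atype_ext))
```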
11 changes: 11 additions & 0 deletions deepmd/pt_expt/utils/__init__.py
@@ -0,0 +1,11 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

from .exclude_mask import (
AtomExcludeMask,
PairExcludeMask,
)

__all__ = [
"AtomExcludeMask",
"PairExcludeMask",
]
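A hedged sketch of the pair-exclusion helper re-exported here, assuming it mirrors the `deepmd.pt` API (a mask shaped like `nlist` that zeroes out excluded type pairs):

```python
import torch

from deepmd.pt_expt.utils import PairExcludeMask

emask = PairExcludeMask(ntypes=2, exclude_types=[(0, 1)])
nlist = torch.tensor([[[1, 2, -1]]])   # (nf=1, nloc=1, nnei=3)
atype_ext = torch.tensor([[0, 1, 0]])  # (nf=1, nall=3)
mask = emask(nlist, atype_ext)         # 0 where the (0, 1) pair is excluded
```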
117 changes: 117 additions & 0 deletions deepmd/pt_expt/utils/env.py
@@ -0,0 +1,117 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import importlib
import logging
import multiprocessing
import os
import sys

import numpy as np

from deepmd.common import (
VALID_PRECISION,
)
from deepmd.env import (
GLOBAL_ENER_FLOAT_PRECISION,
GLOBAL_NP_FLOAT_PRECISION,
get_default_nthreads,
set_default_nthreads,
)

log = logging.getLogger(__name__)
torch = importlib.import_module("torch")

if sys.platform != "win32":
try:
multiprocessing.set_start_method("fork", force=True)
log.debug("Successfully set multiprocessing start method to 'fork'.")
except (RuntimeError, ValueError) as err:
log.warning(f"Could not set multiprocessing start method: {err}")
else:
log.debug("Skipping fork start method on Windows (not supported).")

SAMPLER_RECORD = os.environ.get("SAMPLER_RECORD", False)
DP_DTYPE_PROMOTION_STRICT = os.environ.get("DP_DTYPE_PROMOTION_STRICT", "0") == "1"

[CodeQL notice, code scanning: Unused global variable. The global variable 'DP_DTYPE_PROMOTION_STRICT' is not used.]
try:
    # os.sched_getaffinity is only available on Linux
ncpus = len(os.sched_getaffinity(0))
except AttributeError:
ncpus = os.cpu_count()
NUM_WORKERS = int(os.environ.get("NUM_WORKERS", min(4, ncpus)))
if multiprocessing.get_start_method() != "fork":
# spawn or forkserver does not support NUM_WORKERS > 0 for DataLoader
log.warning(
"NUM_WORKERS > 0 is not supported with spawn or forkserver start method. "
"Setting NUM_WORKERS to 0."
)
NUM_WORKERS = 0

# Make sure DDP uses correct device if applicable
LOCAL_RANK = os.environ.get("LOCAL_RANK")
LOCAL_RANK = int(0 if LOCAL_RANK is None else LOCAL_RANK)

if os.environ.get("DEVICE") == "cpu" or torch.cuda.is_available() is False:
DEVICE = torch.device("cpu")
else:
DEVICE = torch.device(f"cuda:{LOCAL_RANK}")

JIT = False
CACHE_PER_SYS = 5 # keep at most so many sets per sys in memory
ENERGY_BIAS_TRAINABLE = True
CUSTOM_OP_USE_JIT = False

PRECISION_DICT = {
"float16": torch.float16,
"float32": torch.float32,
"float64": torch.float64,
"half": torch.float16,
"single": torch.float32,
"double": torch.float64,
"int32": torch.int32,
"int64": torch.int64,
"bfloat16": torch.bfloat16,
"bool": torch.bool,
}
GLOBAL_PT_FLOAT_PRECISION = PRECISION_DICT[np.dtype(GLOBAL_NP_FLOAT_PRECISION).name]
GLOBAL_PT_ENER_FLOAT_PRECISION = PRECISION_DICT[
np.dtype(GLOBAL_ENER_FLOAT_PRECISION).name
]
PRECISION_DICT["default"] = GLOBAL_PT_FLOAT_PRECISION
assert VALID_PRECISION.issubset(PRECISION_DICT.keys())
# cannot be automatically generated
RESERVED_PRECISION_DICT = {
torch.float16: "float16",
torch.float32: "float32",
torch.float64: "float64",
torch.int32: "int32",
torch.int64: "int64",
torch.bfloat16: "bfloat16",
torch.bool: "bool",
}
assert set(PRECISION_DICT.values()) == set(RESERVED_PRECISION_DICT.keys())
DEFAULT_PRECISION = "float64"

# throw warnings if threads not set
set_default_nthreads()
inter_nthreads, intra_nthreads = get_default_nthreads()
if inter_nthreads > 0: # the behavior of 0 is not documented
torch.set_num_interop_threads(inter_nthreads)
if intra_nthreads > 0:
torch.set_num_threads(intra_nthreads)

__all__ = [
"CACHE_PER_SYS",
"CUSTOM_OP_USE_JIT",
"DEFAULT_PRECISION",
"DEVICE",
"ENERGY_BIAS_TRAINABLE",
"GLOBAL_ENER_FLOAT_PRECISION",
"GLOBAL_NP_FLOAT_PRECISION",
"GLOBAL_PT_ENER_FLOAT_PRECISION",
"GLOBAL_PT_FLOAT_PRECISION",
"JIT",
"LOCAL_RANK",
"NUM_WORKERS",
"PRECISION_DICT",
"RESERVED_PRECISION_DICT",
"SAMPLER_RECORD",
]
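A hedged illustration of the intended relationship between the two precision tables above: every string alias resolves to a torch dtype, and `RESERVED_PRECISION_DICT` maps each dtype back to a single canonical name.

```python
from deepmd.pt_expt.utils import env

dt = env.PRECISION_DICT["single"]            # torch.float32
canonical = env.RESERVED_PRECISION_DICT[dt]  # "float32", not "single"
assert env.PRECISION_DICT[canonical] is dt
```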