Skip to content

Commit 1cc001f

Browse files
author
Han Wang
committed
implement env, base_descriptor and exclude_mask, remove the dependency on pt backend.
1 parent b8a48ff commit 1cc001f

File tree

8 files changed

+187
-14
lines changed

8 files changed

+187
-14
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from .base_descriptor import (
3+
BaseDescriptor,
4+
)
25
from .se_e2_a import (
36
DescrptSeA,
47
)
58

69
__all__ = [
10+
"BaseDescriptor",
711
"DescrptSeA",
812
]
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import importlib
3+
4+
from deepmd.dpmodel.descriptor import (
5+
make_base_descriptor,
6+
)
7+
8+
torch = importlib.import_module("torch")
9+
10+
BaseDescriptor = make_base_descriptor(torch.Tensor, "forward")

deepmd/pt_expt/descriptor/se_e2_a.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,25 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import importlib
23
from typing import (
34
Any,
45
)
56

6-
import torch # noqa: TID253
7-
87
from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP
9-
from deepmd.pt.model.descriptor.base_descriptor import ( # noqa: TID253
8+
from deepmd.pt_expt.descriptor.base_descriptor import (
109
BaseDescriptor,
1110
)
12-
from deepmd.pt.utils import ( # noqa: TID253
11+
from deepmd.pt_expt.utils import (
1312
env,
1413
)
15-
from deepmd.pt.utils.exclude_mask import ( # noqa: TID253
14+
from deepmd.pt_expt.utils.exclude_mask import (
1615
PairExcludeMask,
1716
)
1817
from deepmd.pt_expt.utils.network import (
1918
NetworkCollection,
2019
)
2120

21+
torch = importlib.import_module("torch")
22+
2223

2324
@BaseDescriptor.register("se_e2_a_expt")
2425
@BaseDescriptor.register("se_a_expt")

deepmd/pt_expt/utils/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,11 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
3+
from .exclude_mask import (
4+
AtomExcludeMask,
5+
PairExcludeMask,
6+
)
7+
8+
__all__ = [
9+
"AtomExcludeMask",
10+
"PairExcludeMask",
11+
]

deepmd/pt_expt/utils/env.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import importlib
3+
import logging
4+
import multiprocessing
5+
import os
6+
import sys
7+
8+
import numpy as np
9+
10+
from deepmd.common import (
11+
VALID_PRECISION,
12+
)
13+
from deepmd.env import (
14+
GLOBAL_ENER_FLOAT_PRECISION,
15+
GLOBAL_NP_FLOAT_PRECISION,
16+
get_default_nthreads,
17+
set_default_nthreads,
18+
)
19+
20+
log = logging.getLogger(__name__)
21+
torch = importlib.import_module("torch")
22+
23+
if sys.platform != "win32":
24+
try:
25+
multiprocessing.set_start_method("fork", force=True)
26+
log.debug("Successfully set multiprocessing start method to 'fork'.")
27+
except (RuntimeError, ValueError) as err:
28+
log.warning(f"Could not set multiprocessing start method: {err}")
29+
else:
30+
log.debug("Skipping fork start method on Windows (not supported).")
31+
32+
SAMPLER_RECORD = os.environ.get("SAMPLER_RECORD", False)
33+
DP_DTYPE_PROMOTION_STRICT = os.environ.get("DP_DTYPE_PROMOTION_STRICT", "0") == "1"
34+
try:
35+
# only linux
36+
ncpus = len(os.sched_getaffinity(0))
37+
except AttributeError:
38+
ncpus = os.cpu_count()
39+
NUM_WORKERS = int(os.environ.get("NUM_WORKERS", min(4, ncpus)))
40+
if multiprocessing.get_start_method() != "fork":
41+
# spawn or forkserver does not support NUM_WORKERS > 0 for DataLoader
42+
log.warning(
43+
"NUM_WORKERS > 0 is not supported with spawn or forkserver start method. "
44+
"Setting NUM_WORKERS to 0."
45+
)
46+
NUM_WORKERS = 0
47+
48+
# Make sure DDP uses correct device if applicable
49+
LOCAL_RANK = os.environ.get("LOCAL_RANK")
50+
LOCAL_RANK = int(0 if LOCAL_RANK is None else LOCAL_RANK)
51+
52+
if os.environ.get("DEVICE") == "cpu" or torch.cuda.is_available() is False:
53+
DEVICE = torch.device("cpu")
54+
else:
55+
DEVICE = torch.device(f"cuda:{LOCAL_RANK}")
56+
57+
JIT = False
58+
CACHE_PER_SYS = 5 # keep at most so many sets per sys in memory
59+
ENERGY_BIAS_TRAINABLE = True
60+
CUSTOM_OP_USE_JIT = False
61+
62+
PRECISION_DICT = {
63+
"float16": torch.float16,
64+
"float32": torch.float32,
65+
"float64": torch.float64,
66+
"half": torch.float16,
67+
"single": torch.float32,
68+
"double": torch.float64,
69+
"int32": torch.int32,
70+
"int64": torch.int64,
71+
"bfloat16": torch.bfloat16,
72+
"bool": torch.bool,
73+
}
74+
GLOBAL_PT_FLOAT_PRECISION = PRECISION_DICT[np.dtype(GLOBAL_NP_FLOAT_PRECISION).name]
75+
GLOBAL_PT_ENER_FLOAT_PRECISION = PRECISION_DICT[
76+
np.dtype(GLOBAL_ENER_FLOAT_PRECISION).name
77+
]
78+
PRECISION_DICT["default"] = GLOBAL_PT_FLOAT_PRECISION
79+
assert VALID_PRECISION.issubset(PRECISION_DICT.keys())
80+
# cannot automatically generated
81+
RESERVED_PRECISION_DICT = {
82+
torch.float16: "float16",
83+
torch.float32: "float32",
84+
torch.float64: "float64",
85+
torch.int32: "int32",
86+
torch.int64: "int64",
87+
torch.bfloat16: "bfloat16",
88+
torch.bool: "bool",
89+
}
90+
assert set(PRECISION_DICT.values()) == set(RESERVED_PRECISION_DICT.keys())
91+
DEFAULT_PRECISION = "float64"
92+
93+
# throw warnings if threads not set
94+
set_default_nthreads()
95+
inter_nthreads, intra_nthreads = get_default_nthreads()
96+
if inter_nthreads > 0: # the behavior of 0 is not documented
97+
torch.set_num_interop_threads(inter_nthreads)
98+
if intra_nthreads > 0:
99+
torch.set_num_threads(intra_nthreads)
100+
101+
__all__ = [
102+
"CACHE_PER_SYS",
103+
"CUSTOM_OP_USE_JIT",
104+
"DEFAULT_PRECISION",
105+
"DEVICE",
106+
"ENERGY_BIAS_TRAINABLE",
107+
"GLOBAL_ENER_FLOAT_PRECISION",
108+
"GLOBAL_NP_FLOAT_PRECISION",
109+
"GLOBAL_PT_ENER_FLOAT_PRECISION",
110+
"GLOBAL_PT_FLOAT_PRECISION",
111+
"JIT",
112+
"LOCAL_RANK",
113+
"NUM_WORKERS",
114+
"PRECISION_DICT",
115+
"RESERVED_PRECISION_DICT",
116+
"SAMPLER_RECORD",
117+
]
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import importlib
3+
from typing import (
4+
Any,
5+
)
6+
7+
from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP
8+
from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP
9+
from deepmd.pt_expt.utils import (
10+
env,
11+
)
12+
13+
torch = importlib.import_module("torch")
14+
15+
16+
class AtomExcludeMask(AtomExcludeMaskDP):
17+
def __setattr__(self, name: str, value: Any) -> None:
18+
if name == "type_mask":
19+
value = None if value is None else torch.as_tensor(value, device=env.DEVICE)
20+
return super().__setattr__(name, value)
21+
22+
23+
class PairExcludeMask(PairExcludeMaskDP):
24+
def __setattr__(self, name: str, value: Any) -> None:
25+
if name == "type_mask":
26+
value = None if value is None else torch.as_tensor(value, device=env.DEVICE)
27+
return super().__setattr__(name, value)

deepmd/pt_expt/utils/network.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import importlib
23
from typing import (
34
Any,
45
ClassVar,
56
Self,
67
)
78

89
import numpy as np
9-
import torch # noqa: TID253
1010

1111
from deepmd.dpmodel.common import (
1212
NativeOP,
@@ -19,10 +19,12 @@
1919
make_fitting_network,
2020
make_multilayer_network,
2121
)
22-
from deepmd.pt.utils import ( # noqa: TID253
22+
from deepmd.pt_expt.utils import (
2323
env,
2424
)
2525

26+
torch = importlib.import_module("torch")
27+
2628

2729
def _to_torch_array(value: Any) -> torch.Tensor | None:
2830
if value is None:

source/tests/pt_expt/model/test_se_e2_a.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import importlib
23
import itertools
34
import unittest
45

56
import numpy as np
6-
import torch # noqa: TID253
77

88
from deepmd.dpmodel.descriptor import DescrptSeA as DPDescrptSeA
9-
from deepmd.pt.utils import ( # noqa: TID253
9+
from deepmd.pt_expt.descriptor.se_e2_a import (
10+
DescrptSeA,
11+
)
12+
from deepmd.pt_expt.utils import (
1013
env,
1114
)
12-
from deepmd.pt.utils.env import ( # noqa: TID253
15+
from deepmd.pt_expt.utils.env import (
1316
PRECISION_DICT,
1417
)
15-
from deepmd.pt.utils.exclude_mask import ( # noqa: TID253
18+
from deepmd.pt_expt.utils.exclude_mask import (
1619
PairExcludeMask,
1720
)
18-
from deepmd.pt_expt.descriptor.se_e2_a import (
19-
DescrptSeA,
20-
)
2121

2222
from ...pt.model.test_env_mat import (
2323
TestCaseSingleFrameWithNlist,
@@ -29,6 +29,8 @@
2929
GLOBAL_SEED,
3030
)
3131

32+
torch = importlib.import_module("torch")
33+
3234

3335
class TestDescrptSeA(unittest.TestCase, TestCaseSingleFrameWithNlist):
3436
def setUp(self) -> None:

0 commit comments

Comments
 (0)