Skip to content

Commit fa03351

Browse files
author
Han Wang
committed
simplify init method of se_e2_a descriptor. fig bug in consistent UT
1 parent fb9598a commit fa03351

File tree

2 files changed

+1
-26
lines changed

2 files changed

+1
-26
lines changed

deepmd/pt_expt/descriptor/se_e2_a.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ class DescrptSeA(DescrptSeADP, torch.nn.Module):
2727
def __init__(self, *args: Any, **kwargs: Any) -> None:
2828
torch.nn.Module.__init__(self)
2929
DescrptSeADP.__init__(self, *args, **kwargs)
30-
self._convert_state()
3130

3231
def __setattr__(self, name: str, value: Any) -> None:
3332
if name in {"davg", "dstd"} and "_buffers" in self.__dict__:
@@ -53,30 +52,6 @@ def __setattr__(self, name: str, value: Any) -> None:
5352
return super().__setattr__(name, value)
5453
return super().__setattr__(name, value)
5554

56-
def _convert_state(self) -> None:
57-
if self.davg is not None:
58-
davg = torch.as_tensor(self.davg, device=env.DEVICE)
59-
if "davg" in self._buffers:
60-
self._buffers["davg"] = davg
61-
else:
62-
if hasattr(self, "davg"):
63-
delattr(self, "davg")
64-
self.register_buffer("davg", davg)
65-
if self.dstd is not None:
66-
dstd = torch.as_tensor(self.dstd, device=env.DEVICE)
67-
if "dstd" in self._buffers:
68-
self._buffers["dstd"] = dstd
69-
else:
70-
if hasattr(self, "dstd"):
71-
delattr(self, "dstd")
72-
self.register_buffer("dstd", dstd)
73-
if self.embeddings is not None:
74-
self.embeddings = NetworkCollection.deserialize(self.embeddings.serialize())
75-
if self.emask is not None:
76-
self.emask = PairExcludeMask(
77-
self.ntypes, exclude_types=list(self.emask.get_exclude_types())
78-
)
79-
8055
def forward(
8156
self,
8257
nlist: torch.Tensor,

source/tests/consistent/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class CommonTest(ABC):
9292
"""Native DP model class."""
9393
pt_class: ClassVar[type | None]
9494
"""PyTorch model class."""
95-
pt_expt_class: ClassVar[type | None]
95+
pt_expt_class: ClassVar[type | None] = None
9696
"""PyTorch exportable model class."""
9797
jax_class: ClassVar[type | None]
9898
"""JAX model class."""

0 commit comments

Comments
 (0)