@@ -27,7 +27,6 @@ class DescrptSeA(DescrptSeADP, torch.nn.Module):
2727 def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
2828 torch .nn .Module .__init__ (self )
2929 DescrptSeADP .__init__ (self , * args , ** kwargs )
30- self ._convert_state ()
3130
3231 def __setattr__ (self , name : str , value : Any ) -> None :
3332 if name in {"davg" , "dstd" } and "_buffers" in self .__dict__ :
@@ -53,30 +52,6 @@ def __setattr__(self, name: str, value: Any) -> None:
5352 return super ().__setattr__ (name , value )
5453 return super ().__setattr__ (name , value )
5554
56- def _convert_state (self ) -> None :
57- if self .davg is not None :
58- davg = torch .as_tensor (self .davg , device = env .DEVICE )
59- if "davg" in self ._buffers :
60- self ._buffers ["davg" ] = davg
61- else :
62- if hasattr (self , "davg" ):
63- delattr (self , "davg" )
64- self .register_buffer ("davg" , davg )
65- if self .dstd is not None :
66- dstd = torch .as_tensor (self .dstd , device = env .DEVICE )
67- if "dstd" in self ._buffers :
68- self ._buffers ["dstd" ] = dstd
69- else :
70- if hasattr (self , "dstd" ):
71- delattr (self , "dstd" )
72- self .register_buffer ("dstd" , dstd )
73- if self .embeddings is not None :
74- self .embeddings = NetworkCollection .deserialize (self .embeddings .serialize ())
75- if self .emask is not None :
76- self .emask = PairExcludeMask (
77- self .ntypes , exclude_types = list (self .emask .get_exclude_types ())
78- )
79-
8055 def forward (
8156 self ,
8257 nlist : torch .Tensor ,
0 commit comments