Skip to content

Commit 6ec59ef

Browse files
Copilot, njzjz, and pre-commit-ci[bot]
authored
style(tf): add type annotations (#4945)
Addressed review feedback to replace remaining `Any` type annotations with specific types across training, fitting, and model modules. ## Changes ### Learning Rate Types - `EnerFitting.get_loss()` and `StandardModel.get_loss()`: `lr: Any` → `lr: LearningRateExp` ### Trainer Method Signatures - `build()`: `data: dict` → `data: DeepmdDataSystem | None` - `_build_loss()`: `-> Any` → `-> tuple[None, None] | tuple[tf.Tensor, dict[str, tf.Tensor]]` - `_build_network()`: `data: dict` → `data: DeepmdDataSystem` - `get_global_step()`: `-> Any` → `-> int` ### Example ```python # Before def get_loss(self, loss: dict, lr: Any) -> Loss: ... # After def get_loss(self, loss: dict, lr: LearningRateExp) -> Loss: ... ``` The `_build_loss()` return type now accurately reflects that it returns either `(None, None)` when `stop_batch == 0`, or `(tf.Tensor, dict[str, tf.Tensor])` from the loss build. <!-- START COPILOT CODING AGENT TIPS --> --- ✨ Let Copilot coding agent [set things up for you](https://github.com/deepmodeling/deepmd-kit/issues/new?title=✨+Set+up+Copilot+instructions&body=Configure%20instructions%20for%20this%20repository%20as%20documented%20in%20%5BBest%20practices%20for%20Copilot%20coding%20agent%20in%20your%20repository%5D%28https://gh.io/copilot-coding-agent-tips%29%2E%0A%0A%3COnboard%20this%20repo%3E&assignees=copilot) — coding agent works faster and does higher quality work when set up for your repo. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com> Co-authored-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8787b45 commit 6ec59ef

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

92 files changed

+1364
-784
lines changed

deepmd/tf/cluster/local.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
__all__ = ["get_gpus", "get_resource"]
1515

1616

17-
def get_gpus():
17+
def get_gpus() -> list[int] | None:
1818
"""Get available IDs of GPU cards at local.
1919
These IDs are valid when used as the TensorFlow device ID.
2020

deepmd/tf/common.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Union,
1515
)
1616

17+
import numpy as np
1718
import tensorflow
1819
from packaging.version import (
1920
Version,
@@ -118,7 +119,7 @@ def gelu_tf(x: tf.Tensor) -> tf.Tensor:
118119
https://arxiv.org/abs/1606.08415
119120
"""
120121

121-
def gelu_wrapper(x):
122+
def gelu_wrapper(x: tf.Tensor) -> tf.Tensor:
122123
try:
123124
return tensorflow.nn.gelu(x, approximate=True)
124125
except AttributeError:
@@ -146,16 +147,14 @@ def silu(x: tf.Tensor) -> tf.Tensor:
146147
return x * tf.sigmoid(x)
147148

148149

149-
def get_silut(activation_function: str = "silut"):
150-
import numpy as np
151-
152-
def sigmoid(x):
150+
def get_silut(activation_function: str = "silut") -> Callable[[tf.Tensor], tf.Tensor]:
151+
def sigmoid(x: float | np.ndarray) -> float | np.ndarray:
153152
return 1 / (1 + np.exp(-x))
154153

155-
def silu(x):
154+
def silu(x: float | np.ndarray) -> float | np.ndarray:
156155
return x * sigmoid(x)
157156

158-
def silu_grad(x):
157+
def silu_grad(x: float | np.ndarray) -> float | np.ndarray:
159158
sig = sigmoid(x)
160159
return sig + x * sig * (1 - sig)
161160

@@ -233,7 +232,7 @@ def get_activation_func(
233232
return ACTIVATION_FN_DICT[activation_fn.lower()]
234233

235234

236-
def get_precision(precision: "_PRECISION") -> Any:
235+
def get_precision(precision: "_PRECISION") -> tf.DType:
237236
"""Convert str to TF DType constant.
238237
239238
Parameters
@@ -319,7 +318,7 @@ def cast_precision(func: Callable) -> Callable:
319318
"""
320319

321320
@wraps(func)
322-
def wrapper(self, *args, **kwargs):
321+
def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
323322
# only convert tensors
324323
returned_tensor = func(
325324
self,

deepmd/tf/descriptor/descriptor.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,17 @@
33
abstractmethod,
44
)
55
from typing import (
6+
TYPE_CHECKING,
67
Any,
78
)
89

910
import numpy as np
1011

12+
if TYPE_CHECKING:
13+
from typing_extensions import (
14+
Self,
15+
)
16+
1117
from deepmd.common import (
1218
j_get_type,
1319
)
@@ -48,7 +54,7 @@ class Descriptor(PluginVariant, make_plugin_registry("descriptor")):
4854
that can be called by other classes.
4955
"""
5056

51-
def __new__(cls, *args, **kwargs):
57+
def __new__(cls, *args: Any, **kwargs: Any) -> "Self":
5258
if cls is Descriptor:
5359
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
5460
return super().__new__(cls)
@@ -132,7 +138,7 @@ def compute_input_stats(
132138
natoms_vec: list[np.ndarray],
133139
mesh: list[np.ndarray],
134140
input_dict: dict[str, list[np.ndarray]],
135-
**kwargs,
141+
**kwargs: Any,
136142
) -> None:
137143
"""Compute the statisitcs (avg and std) of the training data. The input will be
138144
normalized by the statistics.

deepmd/tf/descriptor/hybrid.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __init__(
4646
list: list[Descriptor | dict[str, Any]],
4747
ntypes: int | None = None,
4848
spin: Spin | None = None,
49-
**kwargs,
49+
**kwargs: Any,
5050
) -> None:
5151
"""Constructor."""
5252
# warning: list is conflict with built-in list
@@ -142,7 +142,7 @@ def compute_input_stats(
142142
input_dict: dict,
143143
mixed_type: bool = False,
144144
real_natoms_vec: list | None = None,
145-
**kwargs,
145+
**kwargs: Any,
146146
) -> None:
147147
"""Compute the statisitcs (avg and std) of the training data. The input will be normalized by the statistics.
148148
@@ -182,7 +182,7 @@ def compute_input_stats(
182182
**kwargs,
183183
)
184184

185-
def merge_input_stats(self, stat_dict) -> None:
185+
def merge_input_stats(self, stat_dict: dict[str, float]) -> None:
186186
"""Merge the statisitcs computed from compute_input_stats to obtain the self.davg and self.dstd.
187187
188188
Parameters

deepmd/tf/descriptor/loc_frame.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from typing import (
3+
Any,
4+
)
25

36
import numpy as np
47

@@ -58,7 +61,7 @@ def __init__(
5861
sel_a: list[int],
5962
sel_r: list[int],
6063
axis_rule: list[int],
61-
**kwargs,
64+
**kwargs: Any,
6265
) -> None:
6366
"""Constructor."""
6467
self.sel_a = sel_a
@@ -159,7 +162,7 @@ def compute_input_stats(
159162
natoms_vec: list,
160163
mesh: list,
161164
input_dict: dict,
162-
**kwargs,
165+
**kwargs: Any,
163166
) -> None:
164167
"""Compute the statisitcs (avg and std) of the training data. The input will be normalized by the statistics.
165168
@@ -372,8 +375,13 @@ def prod_force_virial(
372375
return force, virial, atom_virial
373376

374377
def _compute_dstats_sys_nonsmth(
375-
self, data_coord, data_box, data_atype, natoms_vec, mesh
376-
):
378+
self,
379+
data_coord: np.ndarray,
380+
data_box: np.ndarray,
381+
data_atype: np.ndarray,
382+
natoms_vec: np.ndarray,
383+
mesh: np.ndarray,
384+
) -> tuple[Any, Any]:
377385
dd_all = run_sess(
378386
self.sub_sess,
379387
self.stat_descrpt,
@@ -405,7 +413,7 @@ def _compute_dstats_sys_nonsmth(
405413
sysv2.append(sumv2)
406414
return sysv, sysv2, sysn
407415

408-
def _compute_std(self, sumv2, sumv, sumn):
416+
def _compute_std(self, sumv2: float, sumv: float, sumn: float) -> float:
409417
return np.sqrt(sumv2 / sumn - np.multiply(sumv / sumn, sumv / sumn))
410418

411419
def init_variables(

deepmd/tf/descriptor/se_a.py

Lines changed: 59 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22

3+
from collections.abc import (
4+
Callable,
5+
)
6+
from typing import (
7+
Any,
8+
)
9+
310
import numpy as np
411

512
from deepmd.dpmodel.utils.env_mat import (
@@ -180,7 +187,7 @@ def __init__(
180187
tebd_input_mode: str = "concat",
181188
type_map: list[str] | None = None, # to be compat with input
182189
env_protection: float = 0.0, # not implement!!
183-
**kwargs,
190+
**kwargs: Any,
184191
) -> None:
185192
"""Constructor."""
186193
if rcut < rcut_smth:
@@ -346,7 +353,7 @@ def compute_input_stats(
346353
natoms_vec: list,
347354
mesh: list,
348355
input_dict: dict,
349-
**kwargs,
356+
**kwargs: Any,
350357
) -> None:
351358
"""Compute the statisitcs (avg and std) of the training data. The input will be normalized by the statistics.
352359
@@ -393,7 +400,7 @@ def compute_input_stats(
393400
}
394401
self.merge_input_stats(stat_dict)
395402

396-
def merge_input_stats(self, stat_dict) -> None:
403+
def merge_input_stats(self, stat_dict: dict[str, Any]) -> None:
397404
"""Merge the statisitcs computed from compute_input_stats to obtain the self.davg and self.dstd.
398405
399406
Parameters
@@ -747,8 +754,15 @@ def prod_force_virial(
747754
return force, virial, atom_virial
748755

749756
def _pass_filter(
750-
self, inputs, atype, natoms, input_dict, reuse=None, suffix="", trainable=True
751-
):
757+
self,
758+
inputs: tf.Tensor,
759+
atype: tf.Tensor,
760+
natoms: tf.Tensor,
761+
input_dict: dict,
762+
reuse: bool | None = None,
763+
suffix: str = "",
764+
trainable: bool = True,
765+
) -> tuple[tf.Tensor, tf.Tensor]:
752766
if input_dict is not None:
753767
type_embedding = input_dict.get("type_embedding", None)
754768
if type_embedding is not None:
@@ -834,8 +848,13 @@ def _pass_filter(
834848
return output, output_qmat
835849

836850
def _compute_dstats_sys_smth(
837-
self, data_coord, data_box, data_atype, natoms_vec, mesh
838-
):
851+
self,
852+
data_coord: np.ndarray,
853+
data_box: np.ndarray,
854+
data_atype: np.ndarray,
855+
natoms_vec: np.ndarray,
856+
mesh: np.ndarray,
857+
) -> tuple[list[float], list[float], list[float], list[float], list[int]]:
839858
dd_all = run_sess(
840859
self.sub_sess,
841860
self.stat_descrpt,
@@ -876,7 +895,7 @@ def _compute_dstats_sys_smth(
876895
sysa2.append(suma2)
877896
return sysr, sysr2, sysa, sysa2, sysn
878897

879-
def _compute_std(self, sumv2, sumv, sumn):
898+
def _compute_std(self, sumv2: float, sumv: float, sumn: int) -> float:
880899
if sumn == 0:
881900
return 1.0 / self.rcut_r
882901
val = np.sqrt(sumv2 / sumn - np.multiply(sumv / sumn, sumv / sumn))
@@ -886,11 +905,11 @@ def _compute_std(self, sumv2, sumv, sumn):
886905

887906
def _concat_type_embedding(
888907
self,
889-
xyz_scatter,
890-
nframes,
891-
natoms,
892-
type_embedding,
893-
):
908+
xyz_scatter: tf.Tensor,
909+
nframes: tf.Tensor,
910+
natoms: tf.Tensor,
911+
type_embedding: tf.Tensor,
912+
) -> tf.Tensor:
894913
"""Concatenate `type_embedding` of neighbors and `xyz_scatter`.
895914
If not self.type_one_side, concatenate `type_embedding` of center atoms as well.
896915
@@ -939,21 +958,21 @@ def _concat_type_embedding(
939958

940959
def _filter_lower(
941960
self,
942-
type_i,
943-
type_input,
944-
start_index,
945-
incrs_index,
946-
inputs,
947-
nframes,
948-
natoms,
949-
type_embedding=None,
950-
is_exclude=False,
951-
activation_fn=None,
952-
bavg=0.0,
953-
stddev=1.0,
954-
trainable=True,
955-
suffix="",
956-
):
961+
type_i: int,
962+
type_input: int,
963+
start_index: int,
964+
incrs_index: int,
965+
inputs: tf.Tensor,
966+
nframes: tf.Tensor,
967+
natoms: tf.Tensor,
968+
type_embedding: tf.Tensor | None = None,
969+
is_exclude: bool = False,
970+
activation_fn: Callable[[tf.Tensor], tf.Tensor] | None = None,
971+
bavg: float = 0.0,
972+
stddev: float = 1.0,
973+
trainable: bool = True,
974+
suffix: str = "",
975+
) -> tf.Tensor:
957976
"""Input env matrix, returns R.G."""
958977
outputs_size = [1, *self.filter_neuron]
959978
# cut-out inputs
@@ -1159,17 +1178,17 @@ def _filter_lower(
11591178
@cast_precision
11601179
def _filter(
11611180
self,
1162-
inputs,
1163-
type_input,
1164-
natoms,
1165-
type_embedding=None,
1166-
activation_fn=tf.nn.tanh,
1167-
stddev=1.0,
1168-
bavg=0.0,
1169-
name="linear",
1170-
reuse=None,
1171-
trainable=True,
1172-
):
1181+
inputs: tf.Tensor,
1182+
type_input: int,
1183+
natoms: tf.Tensor,
1184+
type_embedding: tf.Tensor | None = None,
1185+
activation_fn: Callable[[tf.Tensor], tf.Tensor] | None = tf.nn.tanh,
1186+
stddev: float = 1.0,
1187+
bavg: float = 0.0,
1188+
name: str = "linear",
1189+
reuse: bool | None = None,
1190+
trainable: bool = True,
1191+
) -> tf.Tensor:
11731192
nframes = tf.shape(tf.reshape(inputs, [-1, natoms[0], self.ndescrpt]))[0]
11741193
# natom x (nei x 4)
11751194
shape = inputs.get_shape().as_list()
@@ -1363,7 +1382,7 @@ def explicit_ntypes(self) -> bool:
13631382
return False
13641383

13651384
@classmethod
1366-
def deserialize(cls, data: dict, suffix: str = ""):
1385+
def deserialize(cls, data: dict, suffix: str = "") -> "DescrptSeA":
13671386
"""Deserialize the model.
13681387
13691388
Parameters

0 commit comments

Comments (0)