Skip to content

Commit 9465d71

Browse files
authored
test(consistent): use ReLU to speedup UTs (#5203)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Tests** * Updated test data configurations across descriptor and fitting test suites to include standardized activation function parameters, ensuring consistent test coverage. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 5c2ca51 commit 9465d71

File tree

14 files changed

+16
-3
lines changed

14 files changed

+16
-3
lines changed

source/tests/consistent/descriptor/test_dpa1.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def data(self) -> dict:
127127
"type_map": ["O", "H"] if use_econf_tebd else None,
128128
"seed": 1145141919810,
129129
"trainable": False,
130+
"activation_function": "relu",
130131
}
131132

132133
def is_meaningless_zero_attention_layer_tests(

source/tests/consistent/descriptor/test_dpa2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def data(self) -> dict:
132132
"tebd_dim": 4,
133133
"tebd_input_mode": repinit_tebd_input_mode,
134134
"set_davg_zero": repinit_set_davg_zero,
135-
"activation_function": "tanh",
135+
"activation_function": "relu",
136136
"type_one_side": repinit_type_one_side,
137137
"use_three_body": repinit_use_three_body,
138138
"three_body_sel": 8,
@@ -163,7 +163,7 @@ def data(self) -> dict:
163163
"attn2_hidden": 10,
164164
"attn2_nhead": 2,
165165
"attn2_has_gate": repformer_attn2_has_gate,
166-
"activation_function": "tanh",
166+
"activation_function": "relu",
167167
"update_style": repformer_update_style,
168168
"update_residual": 0.001,
169169
"update_residual_init": repformer_update_residual_init,

source/tests/consistent/descriptor/test_dpa3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def data(self) -> dict:
125125
}
126126
),
127127
# kwargs for descriptor
128-
"activation_function": "silu",
128+
"activation_function": "relu",
129129
"precision": precision,
130130
"exclude_types": exclude_types,
131131
"env_protection": 0.0,

source/tests/consistent/descriptor/test_hybrid.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def data(self) -> dict:
6161
"type_one_side": True,
6262
"precision": "float64",
6363
"seed": 20240229,
64+
"activation_function": "relu",
6465
},
6566
{
6667
"type": "se_e2_a",
@@ -73,6 +74,7 @@ def data(self) -> dict:
7374
"type_one_side": True,
7475
"precision": "float64",
7576
"seed": 20240229,
77+
"activation_function": "relu",
7678
},
7779
]
7880
}

source/tests/consistent/descriptor/test_se_atten_v2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def data(self) -> dict:
124124
"use_tebd_bias": use_tebd_bias,
125125
"type_map": ["O", "H"] if use_econf_tebd else None,
126126
"seed": 1145141919810,
127+
"activation_function": "relu",
127128
}
128129

129130
def is_meaningless_zero_attention_layer_tests(

source/tests/consistent/descriptor/test_se_e2_a.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def data(self) -> dict:
9999
"env_protection": env_protection,
100100
"precision": precision,
101101
"seed": 1145141919810,
102+
"activation_function": "relu",
102103
}
103104

104105
@property

source/tests/consistent/descriptor/test_se_r.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def data(self) -> dict:
7777
"exclude_types": excluded_types,
7878
"precision": precision,
7979
"seed": 1145141919810,
80+
"activation_function": "relu",
8081
}
8182

8283
@property

source/tests/consistent/descriptor/test_se_t.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def data(self) -> dict:
6969
"env_protection": env_protection,
7070
"precision": precision,
7171
"seed": 1145141919810,
72+
"activation_function": "relu",
7273
}
7374

7475
@property

source/tests/consistent/descriptor/test_se_t_tebd.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def data(self) -> dict:
9898
"use_tebd_bias": use_tebd_bias,
9999
"type_map": ["O", "H"] if use_econf_tebd else None,
100100
"seed": 1145141919810,
101+
"activation_function": "relu",
101102
}
102103

103104
@property

source/tests/consistent/fitting/test_dipole.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def data(self) -> dict:
7878
"precision": precision,
7979
"sel_type": sel_type,
8080
"seed": 20240217,
81+
"activation_function": "relu",
8182
}
8283
return data
8384

0 commit comments

Comments
 (0)