Skip to content

Commit 7984900

Browse files
committed
change source of torch_admp from github to pypi in pyproject.toml; minor code improvement
1 parent 3d085c9 commit 7984900

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

deepmd/pt/modifier/dipole_charge.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ def __init__(
6363
self._model_charge_map = model_charge_map
6464
self._sys_charge_map = sys_charge_map
6565

66+
# Validate that model_charge_map and sel_type have matching lengths
67+
if len(model_charge_map) != len(sel_type):
68+
raise ValueError(
69+
f"model_charge_map length ({len(model_charge_map)}) must match "
70+
f"sel_type length ({len(sel_type)})"
71+
)
72+
6673
# init ewald recp
6774
self.ewald_h = ewald_h
6875
self.ewald_beta = ewald_beta
@@ -344,8 +351,6 @@ def make_mask(
344351
sel_type = sel_type.to(torch.long)
345352
atype = atype.to(torch.long)
346353

347-
# Create mask using broadcasting
348-
mask = torch.zeros_like(atype, dtype=torch.bool)
349-
for t in sel_type:
350-
mask = mask | (atype == t)
354+
# Create mask using broadcasting for JIT compatibility
355+
mask = (atype.unsqueeze(-1) == sel_type).any(dim=-1)
351356
return mask

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,11 +164,11 @@ pin_pytorch_cpu = [
164164
# macos x86 has been deprecated
165165
"torch>=2.8,<2.10; platform_machine!='x86_64' or platform_system != 'Darwin'",
166166
"torch; platform_machine=='x86_64' and platform_system == 'Darwin'",
167-
"torch_admp @ git+https://github.com/chiahsinchu/torch-admp.git@v1.1.1",
167+
"torch-admp==1.1.1",
168168
]
169169
pin_pytorch_gpu = [
170170
"torch>=2.7,<2.10",
171-
"torch_admp @ git+https://github.com/chiahsinchu/torch-admp.git@v1.1.1",
171+
"torch-admp==1.1.1",
172172
]
173173
pin_jax = [
174174
"jax==0.5.0;python_version>='3.10'",

0 commit comments

Comments
 (0)