File tree Expand file tree Collapse file tree 2 files changed +11
-6
lines changed
Expand file tree Collapse file tree 2 files changed +11
-6
lines changed Original file line number Diff line number Diff line change @@ -63,6 +63,13 @@ def __init__(
6363 self ._model_charge_map = model_charge_map
6464 self ._sys_charge_map = sys_charge_map
6565
66+ # Validate that model_charge_map and sel_type have matching lengths
67+ if len (model_charge_map ) != len (sel_type ):
68+ raise ValueError (
69+ f"model_charge_map length ({ len (model_charge_map )} ) must match "
70+ f"sel_type length ({ len (sel_type )} )"
71+ )
72+
6673 # init ewald recp
6774 self .ewald_h = ewald_h
6875 self .ewald_beta = ewald_beta
@@ -344,8 +351,6 @@ def make_mask(
344351 sel_type = sel_type .to (torch .long )
345352 atype = atype .to (torch .long )
346353
347- # Create mask using broadcasting
348- mask = torch .zeros_like (atype , dtype = torch .bool )
349- for t in sel_type :
350- mask = mask | (atype == t )
354+ # Create mask using broadcasting for JIT compatibility
355+ mask = (atype .unsqueeze (- 1 ) == sel_type ).any (dim = - 1 )
351356 return mask
Original file line number Diff line number Diff line change @@ -164,11 +164,11 @@ pin_pytorch_cpu = [
164164 # macos x86 has been deprecated
165165 " torch>=2.8,<2.10; platform_machine!='x86_64' or platform_system != 'Darwin'" ,
166166 " torch; platform_machine=='x86_64' and platform_system == 'Darwin'" ,
167- " torch_admp @ git+https://github.com/chiahsinchu/ torch-admp.git@v1 .1.1" ,
167+ " torch-admp==1 .1.1" ,
168168]
169169pin_pytorch_gpu = [
170170 " torch>=2.7,<2.10" ,
171- " torch_admp @ git+https://github.com/chiahsinchu/ torch-admp.git@v1 .1.1" ,
171+ " torch-admp==1 .1.1" ,
172172]
173173pin_jax = [
174174 " jax==0.5.0;python_version>='3.10'" ,
You can’t perform that action at this time.
0 commit comments