Skip to content

Commit 8f225c2

Browse files
committed
Move unused rescale_coordinates to attic
1 parent da4770b commit 8f225c2

File tree

3 files changed

+19
-16
lines changed

3 files changed

+19
-16
lines changed

attic/jax_tricks/rescale.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import jax
2+
from jax import numpy as np
3+
import numpy as onp
4+
5+
6+
def rescale_coordinates(conf, indices, box, scales):
7+
"""Note: scales unused"""
8+
9+
mol_sizes = np.expand_dims(onp.bincount(indices), axis=1)
10+
mol_centers = jax.ops.segment_sum(conf, indices) / mol_sizes
11+
12+
new_centers = mol_centers - box[2] * np.floor(np.expand_dims(mol_centers[..., 2], axis=-1) / box[2][2])
13+
new_centers -= box[1] * np.floor(np.expand_dims(new_centers[..., 1], axis=-1) / box[1][1])
14+
new_centers -= box[0] * np.floor(np.expand_dims(new_centers[..., 0], axis=-1) / box[0][0])
15+
16+
offset = new_centers - mol_centers
17+
18+
return conf + offset[indices]

attic/readme.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ Code, documentation, experiments we want to retain for reference, but that we're
22

33
### Contents
44
* `docs/` -- write-up of initial vision for `timemachine`, involving efficient backpropagation through MD trajectories
5+
* `jax_tricks` -- misc. Jax functions
56
* `thermo_deriv/` -- numerical experiments with "thermodynamic derivative" estimators, adjusting LJ parameters to match observables
67
* note: currently missing dependencies `thermo_deriv.lj_non_periodic.lennard_jones`, `thermo_deriv.lj.lennard_jones`.
78
* note: `langevin_coefficients` dependency has since changed -- some scripts rely on a version of `langevin_coefficients` prior to PR #459

timemachine/potentials/jax_utils.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
from typing import Tuple
2121

22-
import jax
2322
import jax.numpy as np
2423
import numpy as onp
2524
from jax import vmap
@@ -88,21 +87,6 @@ def convert_to_4d(x3, lamb, lambda_plane_idxs, lambda_offset_idxs, cutoff):
8887
return augment_dim(x3, w)
8988

9089

91-
def rescale_coordinates(conf, indices, box, scales):
92-
"""Note: scales unused"""
93-
94-
mol_sizes = np.expand_dims(onp.bincount(indices), axis=1)
95-
mol_centers = jax.ops.segment_sum(conf, indices) / mol_sizes
96-
97-
new_centers = mol_centers - box[2] * np.floor(np.expand_dims(mol_centers[..., 2], axis=-1) / box[2][2])
98-
new_centers -= box[1] * np.floor(np.expand_dims(new_centers[..., 1], axis=-1) / box[1][1])
99-
new_centers -= box[0] * np.floor(np.expand_dims(new_centers[..., 0], axis=-1) / box[0][0])
100-
101-
offset = new_centers - mol_centers
102-
103-
return conf + offset[indices]
104-
105-
10690
def delta_r(ri, rj, box=None):
10791
diff = ri - rj # this can be either N,N,3 or B,3
10892

0 commit comments

Comments
 (0)