Skip to content

Commit

Permalink
Count how many times each method is called
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloferz committed Jan 3, 2024
1 parent c8ba9d5 commit 76f3066
Show file tree
Hide file tree
Showing 9 changed files with 55 additions and 48 deletions.
8 changes: 6 additions & 2 deletions pysages/methods/abf.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ class ABFState(NamedTuple):
Wp_: JaxArray (CV shape)
Product of W matrix and momenta matrix for the previous step.
ncalls: int
Counts the number of times the method's update has been called.
"""

xi: JaxArray
Expand All @@ -66,6 +69,7 @@ class ABFState(NamedTuple):
force: JaxArray
Wp: JaxArray
Wp_: JaxArray
ncalls: int

def __repr__(self):
return repr("PySAGES " + type(self).__name__)
Expand Down Expand Up @@ -174,7 +178,7 @@ def initialize():
force = np.zeros(dims)
Wp = np.zeros(dims)
Wp_ = np.zeros(dims)
return ABFState(xi, bias, hist, Fsum, force, Wp, Wp_)
return ABFState(xi, bias, hist, Fsum, force, Wp, Wp_, 0)

def update(state, data):
"""
Expand Down Expand Up @@ -213,7 +217,7 @@ def update(state, data):
force = estimate_force(xi, I_xi, Fsum, hist).reshape(dims)
bias = np.reshape(-Jxi.T @ force, state.bias.shape)

return ABFState(xi, bias, hist, Fsum, force, Wp, state.Wp)
return ABFState(xi, bias, hist, Fsum, force, Wp, state.Wp, state.ncalls + 1)

return snapshot, initialize, generalize(update, helpers)

Expand Down
16 changes: 8 additions & 8 deletions pysages/methods/ann.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ class ANNState(NamedTuple):
nn: NNDada
Bundle of the neural network parameters, and output scaling coefficients.
nstep: int
Count the number of times the method's update has been called.
ncalls: int
Counts the number of times the method's update has been called.
"""

xi: JaxArray
Expand All @@ -70,7 +70,7 @@ class ANNState(NamedTuple):
phi: JaxArray
prob: JaxArray
nn: NNData
nstep: int
ncalls: int

def __repr__(self):
return repr("PySAGES " + type(self).__name__)
Expand Down Expand Up @@ -148,13 +148,13 @@ def initialize():
phi = np.zeros(shape)
prob = np.ones(shape)
nn = NNData(ps, np.array(0.0), np.array(1.0))
return ANNState(xi, bias, hist, phi, prob, nn, 1)
return ANNState(xi, bias, hist, phi, prob, nn, 0)

def update(state, data):
nstep = state.nstep
in_training_regime = nstep > train_freq
ncalls = state.ncalls + 1
in_training_regime = ncalls > train_freq
# We only train every `train_freq` timesteps
in_training_step = in_training_regime & (nstep % train_freq == 1)
in_training_step = in_training_regime & (ncalls % train_freq == 1)
hist, phi, prob, nn = learn_free_energy(state, in_training_step)
# Compute the collective variable and its jacobian
xi, Jxi = cv(data)
Expand All @@ -163,7 +163,7 @@ def update(state, data):
F = estimate_force(xi, I_xi, nn, in_training_regime)
bias = np.reshape(-Jxi.T @ F, state.bias.shape)
#
return ANNState(xi, bias, hist, phi, prob, nn, nstep + 1)
return ANNState(xi, bias, hist, phi, prob, nn, ncalls)

return snapshot, initialize, generalize(update, helpers)

Expand Down
18 changes: 8 additions & 10 deletions pysages/methods/cff.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ class CFFState(NamedTuple):
nn: NNDada
Bundle of the neural network parameters, and output scaling coefficients.
nstep: int
Count the number of times the method's update has been called.
ncalls: int
Counts the number of times the method's update has been called.
"""

xi: JaxArray
Expand All @@ -93,7 +93,7 @@ class CFFState(NamedTuple):
Wp_: JaxArray
nn: NNData
fnn: NNData
nstep: int
ncalls: int

def __repr__(self):
return repr("PySAGES " + type(self).__name__)
Expand Down Expand Up @@ -209,13 +209,13 @@ def initialize():
nn = NNData(ps, np.array(0.0), np.array(1.0))
fnn = NNData(fps, np.zeros(dims), np.array(1.0))

return CFFState(xi, bias, hist, histp, prob, fe, Fsum, force, Wp, Wp_, nn, fnn, 1)
return CFFState(xi, bias, hist, histp, prob, fe, Fsum, force, Wp, Wp_, nn, fnn, 0)

def update(state, data):
# During the intial stage, when there are not enough collected samples, use ABF
nstep = state.nstep
in_training_regime = nstep > 1 * train_freq
in_training_step = in_training_regime & (nstep % train_freq == 1)
ncalls = state.ncalls + 1
in_training_regime = ncalls > train_freq
in_training_step = in_training_regime & (ncalls % train_freq == 1)
histp, fe, prob, nn, fnn = learn_free_energy(state, in_training_step)
# Compute the collective variable and its jacobian
xi, Jxi = cv(data)
Expand All @@ -232,9 +232,7 @@ def update(state, data):
force = estimate_force(PartialCFFState(xi, hist, Fsum, I_xi, fnn, in_training_regime))
bias = (-Jxi.T @ force).reshape(state.bias.shape)
#
return CFFState(
xi, bias, hist, histp, prob, fe, Fsum, force, Wp, state.Wp, nn, fnn, nstep + 1
)
return CFFState(xi, bias, hist, histp, prob, fe, Fsum, force, Wp, state.Wp, nn, fnn, ncalls)

return snapshot, initialize, generalize(update, helpers)

Expand Down
5 changes: 3 additions & 2 deletions pysages/methods/ffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
class FFSState(NamedTuple):
xi: JaxArray
bias: Optional[JaxArray]
ncalls: int

def __repr__(self):
return repr("PySAGES " + type(self).__name__)
Expand Down Expand Up @@ -210,11 +211,11 @@ def _ffs(method, snapshot, helpers):
# initialize method
def initialize():
xi = cv(helpers.query(snapshot))
return FFSState(xi, None)
return FFSState(xi, None, 0)

def update(state, data):
xi = cv(data)
return FFSState(xi, None)
return FFSState(xi, None, state.ncalls + 1)

return snapshot, initialize, generalize(update, helpers)

Expand Down
16 changes: 8 additions & 8 deletions pysages/methods/funn.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ class FUNNState(NamedTuple):
nn: NNData
Bundle of the neural network parameters, and output scaling coefficients.
nstep: int
Count the number of times the method's update has been called.
ncalls: int
Counts the number of times the method's update has been called.
"""

xi: JaxArray
Expand All @@ -78,7 +78,7 @@ class FUNNState(NamedTuple):
Wp: JaxArray
Wp_: JaxArray
nn: NNData
nstep: int
ncalls: int

def __repr__(self):
return repr("PySAGES " + type(self).__name__)
Expand Down Expand Up @@ -173,13 +173,13 @@ def initialize():
Wp = np.zeros(dims)
Wp_ = np.zeros(dims)
nn = NNData(ps, F, F)
return FUNNState(xi, bias, hist, Fsum, F, Wp, Wp_, nn, 1)
return FUNNState(xi, bias, hist, Fsum, F, Wp, Wp_, nn, 0)

def update(state, data):
# During the intial stage, when there are not enough collected samples, use ABF
nstep = state.nstep
in_training_regime = nstep > 2 * train_freq
in_training_step = in_training_regime & (nstep % train_freq == 1)
ncalls = state.ncalls + 1
in_training_regime = ncalls > 2 * train_freq
in_training_step = in_training_regime & (ncalls % train_freq == 1)
# NN training
nn = learn_free_energy_grad(state, in_training_step)
# Compute the collective variable and its jacobian
Expand All @@ -198,7 +198,7 @@ def update(state, data):
)
bias = (-Jxi.T @ F).reshape(state.bias.shape)
#
return FUNNState(xi, bias, hist, Fsum, F, Wp, state.Wp, nn, state.nstep + 1)
return FUNNState(xi, bias, hist, Fsum, F, Wp, state.Wp, nn, state.ncalls)

return snapshot, initialize, generalize(update, helpers)

Expand Down
5 changes: 3 additions & 2 deletions pysages/methods/harmonic_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class HarmonicBiasState(NamedTuple):

xi: JaxArray
bias: JaxArray
ncalls: int

def __repr__(self):
return repr("PySAGES" + type(self).__name__)
Expand Down Expand Up @@ -118,14 +119,14 @@ def _harmonic_bias(method, snapshot, helpers):
def initialize():
xi, _ = cv(helpers.query(snapshot))
bias = np.zeros((natoms, helpers.dimensionality()))
return HarmonicBiasState(xi, bias)
return HarmonicBiasState(xi, bias, 0)

def update(state, data):
xi, Jxi = cv(data)
forces = kspring @ (xi - center).flatten()
bias = -Jxi.T @ forces.flatten()
bias = bias.reshape(state.bias.shape)

return HarmonicBiasState(xi, bias)
return HarmonicBiasState(xi, bias, state.ncalls + 1)

return snapshot, initialize, generalize(update, helpers)
11 changes: 5 additions & 6 deletions pysages/methods/metad.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class MetadynamicsState(NamedTuple):
idx: int
Index of the next Gaussian to be deposited.
nstep: int
ncalls: int
Counts the number of times `method.update` has been called.
"""

Expand All @@ -64,7 +64,7 @@ class MetadynamicsState(NamedTuple):
grid_potential: Optional[JaxArray]
grid_gradient: Optional[JaxArray]
idx: int
nstep: int
ncalls: int

def __repr__(self):
return repr("PySAGES" + type(self).__name__)
Expand Down Expand Up @@ -189,8 +189,8 @@ def update(state, data):
xi, Jxi = cv(data)

# Deposit Gaussian depending on the stride
nstep = state.nstep
in_deposition_step = (nstep > 0) & (nstep % stride == 0)
ncalls = state.ncalls + 1
in_deposition_step = (ncalls > 1) & (ncalls % stride == 1)
partial_state = deposit_gaussian(xi, state, in_deposition_step)

# Evaluate gradient of biasing potential (or generalized force)
Expand All @@ -200,7 +200,7 @@ def update(state, data):
bias = -Jxi.T @ generalized_force.flatten()
bias = bias.reshape(state.bias.shape)

return MetadynamicsState(xi, bias, *partial_state[1:-1], nstep + 1)
return MetadynamicsState(xi, bias, *partial_state[1:-1], ncalls)

return snapshot, initialize, generalize(update, helpers, jit_compile=True)

Expand Down Expand Up @@ -290,7 +290,6 @@ def build_bias_grad_evaluator(method: Metadynamics):
periods = get_periods(method.cvs)
evaluate_bias_grad = jit(lambda pstate: grad(sum_of_gaussians)(*pstate[:4], periods))
else:

if restraints:

def ob_force(pstate): # out-of-bounds force
Expand Down
16 changes: 8 additions & 8 deletions pysages/methods/spectral_abf.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ class SpectralABFState(NamedTuple):
Object that holds the coefficients of the basis functions
approximation to the free energy.
nstep: int
Count the number of times the method's update has been called.
ncalls: int
Counts the number of times the method's update has been called.
"""

xi: JaxArray
Expand All @@ -78,7 +78,7 @@ class SpectralABFState(NamedTuple):
Wp: JaxArray
Wp_: JaxArray
fun: Fun
nstep: int
ncalls: int

def __repr__(self):
return repr("PySAGES " + type(self).__name__)
Expand Down Expand Up @@ -168,13 +168,13 @@ def initialize():
Wp = np.zeros(dims)
Wp_ = np.zeros(dims)
fun = fit(Fsum)
return SpectralABFState(xi, bias, hist, Fsum, force, Wp, Wp_, fun, 1)
return SpectralABFState(xi, bias, hist, Fsum, force, Wp, Wp_, fun, 0)

def update(state, data):
# During the intial stage use ABF
nstep = state.nstep
in_fitting_regime = nstep > fit_threshold
in_fitting_step = in_fitting_regime & (nstep % fit_freq == 1)
ncalls = state.ncalls + 1
in_fitting_regime = ncalls > fit_threshold
in_fitting_step = in_fitting_regime & (ncalls % fit_freq == 1)
# Fit forces
fun = fit_forces(state, in_fitting_step)
# Compute the collective variable and its jacobian
Expand All @@ -194,7 +194,7 @@ def update(state, data):
)
bias = np.reshape(-Jxi.T @ force, state.bias.shape)
#
return SpectralABFState(xi, bias, hist, Fsum, force, Wp, state.Wp, fun, state.nstep + 1)
return SpectralABFState(xi, bias, hist, Fsum, force, Wp, state.Wp, fun, state.ncalls)

return snapshot, initialize, generalize(update, helpers)

Expand Down
8 changes: 6 additions & 2 deletions pysages/methods/unbiased.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,14 @@ class UnbiasedState(NamedTuple):
bias: Optional[JaxArray]
Either None or an array with all entries equal to zero.
ncalls: int
Counts the number of times the method's update has been called.
"""

xi: JaxArray
bias: Optional[JaxArray]
ncalls: int

def __repr__(self):
return repr("PySAGES" + type(self).__name__)
Expand Down Expand Up @@ -62,10 +66,10 @@ def _unbias(method, snapshot, helpers):

def initialize():
xi = cv(helpers.query(snapshot))
return UnbiasedState(xi, None)
return UnbiasedState(xi, None, 0)

def update(state, data):
xi = cv(data)
return UnbiasedState(xi, None)
return UnbiasedState(xi, None, state.ncalls + 1)

return snapshot, initialize, generalize(update, helpers)

0 comments on commit 76f3066

Please sign in to comment.