From cbefd605d8f4491c2b542ad383fa3a2d10a4d40d Mon Sep 17 00:00:00 2001 From: maggiezimon Date: Mon, 11 Mar 2024 19:39:17 +0000 Subject: [PATCH 1/6] Enable the continuous GeM calc. --- pysages/colvars/patterns.py | 72 +++++++++++++++++++++++++++++++++++-- pysages/utils/__init__.py | 2 +- 2 files changed, 70 insertions(+), 4 deletions(-) diff --git a/pysages/colvars/patterns.py b/pysages/colvars/patterns.py index 0d71e48e..267f38dd 100644 --- a/pysages/colvars/patterns.py +++ b/pysages/colvars/patterns.py @@ -11,7 +11,9 @@ from jaxopt import GradientDescent as minimize from pysages.colvars.core import CollectiveVariable -from pysages.utils import gaussian, quaternion_from_euler, quaternion_matrix +from pysages.utils import ( + gaussian, row_sum, quaternion_from_euler, + quaternion_matrix) def rotate_pattern_with_quaternions(rot_q, pattern): @@ -42,6 +44,9 @@ def __init__( centre_j_id, standard_deviation, mesh_size, + number_of_added_sites=0, + width_of_switch_func=None, + scale_for_radial_distance=None ): self.characteristic_distance = characteristic_distance @@ -58,6 +63,18 @@ def __init__( self.centre_j_coords = self.positions[self.centre_j_id] self.standard_deviation = standard_deviation self.mesh_size = mesh_size + # These settings are needed if continuous LoM is to be used + self.number_of_added_sites = number_of_added_sites + if self.number_of_added_sites > 0: + if width_of_switch_func is None: + self.width_of_switch_func = self.standard_deviation/2 + else: + self.width_of_switch_func = width_of_switch_func + + if scale_for_radial_distance is None: + self.scale_for_radial_distance = 0.9 + else: + self.scale_for_radial_distance = scale_for_radial_distance def comp_pair_distance_squared(self, pos1): displacement_fn, shift_fn = space.periodic(np.diag(self.simulation_box)) @@ -79,6 +96,13 @@ def _generate_neighborhood(self): ids_of_neighbors = np.argsort(distances)[: len(self.reference)] + if self.number_of_added_sites > 0: + ids_of_neighbors_2nd_shell = ids_of_neighbors[ + -self.number_of_added_sites:] + self.shell_distance = self.scale_for_radial_distance*np.mean( + distances[ids_of_neighbors_2nd_shell]) + self._neighborhood_distances = distances[ids_of_neighbors] + coordinates = mic_vectors[ids_of_neighbors] + self.centre_j_coords # Step 1: Translate to origin; coordinates = coordinates.at[:].set(coordinates - np.mean(coordinates, axis=0)) @@ -95,9 +119,33 @@ def _generate_neighborhood(self): self._neighbor_coords = np.array([n["coordinates"] for n in self._neighborhood]) self._orig_neighbor_coords = positions_of_all_nbrs[ids_of_neighbors] + def _switching_function(self, distance, width): + result = 0.5*lax.erfc( + (distance - self.shell_distance) / width) + return result + def compute_score(self, optim_reference): r = self._neighbor_coords - optim_reference - return np.prod(gaussian(1, self.standard_deviation, r)) + + if self.number_of_added_sites != 0: + width = self.width_of_switch_func + squared_dist = row_sum(r**2) + return np.exp( + -np.sum( + self._switching_function( + self._neighborhood_distances, + width)*squared_dist + ) / ( + 2*(self.standard_deviation**2)*np.sum( + self._switching_function( + self._neighborhood_distances, width)) + ) + ) + else: + return np.prod( + gaussian(1, + self.standard_deviation*np.sqrt( + len(self.reference)), r)) def rotate_reference(self, random_euler_point): # Perform rotation of the reference pattern; @@ -153,7 +201,7 @@ def return_close(_, n): close_sites, ) # Return the locations of settled nighbours in the neighborhood; - # Settlled site should have a unique neighbor + # Settled site should have a unique neighbor settled_neighbor_indices = np.where(np.sum(indices, axis=0) >= 1, 1, 0) return settled_neighbor_indices @@ -281,6 +329,9 @@ def calculate_lom(all_positions: np.array, neighborlist, simulation_box, params) i, params.standard_deviation, params.mesh_size, + params.number_of_added_sites, + params.width_of_switch_func, + params.scale_for_radial_distance ).driver_match( params.number_of_rotations, params.number_of_opt_it, @@ -339,6 +390,14 @@ class GeM(CollectiveVariable): fractional_coords: bool Set to True if NPT simulation is considered and the box size changes; use periodic_general for constructing the neighborlist. + number_of_added_sites: int + Specify additional sites to the main reference for the continuous + calculation (skip if the continuous LoM is not needed). + width_of_switch_func: float + Width of the switching function for the continuous score function. + scale_for_radial_distance: float + Scaling factor for the mean radial distance of added sites + used in the continuous score function calculation. Returns ------- calculate_lom: float @@ -357,6 +416,9 @@ def __init__( mesh_size, nbrs, fractional_coords, + number_of_added_sites=0, + width_of_switch_func=None, + scale_for_radial_distance=None ): super().__init__(indices, group_length=None) @@ -369,6 +431,10 @@ def __init__( self.mesh_size = mesh_size self.nbrs = nbrs self.fractional_coords = fractional_coords + # The parameters below are only used in the continuous version + self.number_of_added_sites = number_of_added_sites + self.width_of_switch_func = width_of_switch_func + self.scale_for_radial_distance = scale_for_radial_distance @property def function(self): diff --git a/pysages/utils/__init__.py b/pysages/utils/__init__.py index 00279a4c..81b04b0d 100644 --- a/pysages/utils/__init__.py +++ b/pysages/utils/__init__.py @@ -17,5 +17,5 @@ solve_pos_def, try_import, ) -from .core import ToCPU, copy, dispatch, eps, first_or_all, gaussian, identity +from .core import ToCPU, copy, dispatch, eps, first_or_all, gaussian, identity, row_sum from .transformations import quaternion_from_euler, quaternion_matrix From 192fbd8ae64fd6b667a1db65d78cb16b93905abf Mon Sep 17 00:00:00 2001 From: maggiezimon Date: Mon, 11 Mar 2024 19:55:59 +0000 Subject: [PATCH 2/6] Add missing whitespaces. --- pysages/colvars/patterns.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pysages/colvars/patterns.py b/pysages/colvars/patterns.py index 267f38dd..a67542e4 100644 --- a/pysages/colvars/patterns.py +++ b/pysages/colvars/patterns.py @@ -67,7 +67,7 @@ def __init__( self.number_of_added_sites = number_of_added_sites if self.number_of_added_sites > 0: if width_of_switch_func is None: - self.width_of_switch_func = self.standard_deviation/2 + self.width_of_switch_func = self.standard_deviation / 2 else: self.width_of_switch_func = width_of_switch_func @@ -99,7 +99,7 @@ def _generate_neighborhood(self): if self.number_of_added_sites > 0: ids_of_neighbors_2nd_shell = ids_of_neighbors[ -self.number_of_added_sites:] - self.shell_distance = self.scale_for_radial_distance*np.mean( + self.shell_distance = self.scale_for_radial_distance * np.mean( distances[ids_of_neighbors_2nd_shell]) self._neighborhood_distances = distances[ids_of_neighbors] @@ -120,7 +120,7 @@ def _generate_neighborhood(self): self._orig_neighbor_coords = positions_of_all_nbrs[ids_of_neighbors] def _switching_function(self, distance, width): - result = 0.5*lax.erfc( + result = 0.5 * lax.erfc( (distance - self.shell_distance) / width) return result @@ -131,12 +131,12 @@ def compute_score(self, optim_reference): width = self.width_of_switch_func squared_dist = row_sum(r**2) return np.exp( - -np.sum( - self._switching_function( - self._neighborhood_distances, - width)*squared_dist + - np.sum( + self._switching_function( + self._neighborhood_distances, + width) * squared_dist ) / ( - 2*(self.standard_deviation**2)*np.sum( + 2 * (self.standard_deviation ** 2) * np.sum( self._switching_function( self._neighborhood_distances, width)) ) @@ -144,7 +144,7 @@ def compute_score(self, optim_reference): else: return np.prod( gaussian(1, - self.standard_deviation*np.sqrt( + self.standard_deviation * np.sqrt( len(self.reference)), r)) def rotate_reference(self, random_euler_point): From 4ceafe301005d90ec44644b5771b0dc42b86e84c Mon Sep 17 00:00:00 2001 From: maggiezimon Date: Mon, 11 Mar 2024 20:02:51 +0000 Subject: [PATCH 3/6] Fix indentation. --- pysages/colvars/patterns.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pysages/colvars/patterns.py b/pysages/colvars/patterns.py index a67542e4..92ce46d0 100644 --- a/pysages/colvars/patterns.py +++ b/pysages/colvars/patterns.py @@ -136,9 +136,10 @@ def compute_score(self, optim_reference): self._neighborhood_distances, width) * squared_dist ) / ( - 2 * (self.standard_deviation ** 2) * np.sum( - self._switching_function( - self._neighborhood_distances, width)) + 2 * (self.standard_deviation ** 2) * np.sum( + self._switching_function( + self._neighborhood_distances, width) + ) ) ) else: From ee62d8d108dd4e40a7374c533b4b6bda600f11ca Mon Sep 17 00:00:00 2001 From: maggiezimon <74198137+maggiezimon@users.noreply.github.com> Date: Tue, 12 Mar 2024 11:32:47 +0000 Subject: [PATCH 4/6] Add more comments. A comment is provided explaining that the additional sites should be part of the reference. --- pysages/colvars/patterns.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/pysages/colvars/patterns.py b/pysages/colvars/patterns.py index 92ce46d0..87b088fe 100644 --- a/pysages/colvars/patterns.py +++ b/pysages/colvars/patterns.py @@ -27,7 +27,7 @@ def func_to_optimise(Q, modified_pattern, local_pattern): # Main class implementing the GeM CV class Pattern: """ - For determining nearest neighbors, + For determining the nearest neighbors, [JAX MD](https://jax-md.readthedocs.io/en/main/jax_md.partition.html) neighborlist library is utilized. This requires the user to define the indices of all the atoms in the system and a JAX MD @@ -350,14 +350,14 @@ class GeM(CollectiveVariable): an atomic or a molecular site is described in [Martelli2018](https://journals.aps.org/prb/abstract/10.1103/PhysRevB.97.064105). - Given a pattern, the algorithm is returning an average score (from 0 to 1), + Given a pattern, the algorithm returns an average score (from 0 to 1), denoting how closely the atomic neighbors resemble the reference. - For determining nearest neighbors, + For determining the nearest neighbors, [JAX MD](https://jax-md.readthedocs.io/en/main/jax_md.partition.html) neighborlist library is utilized. This requires the user to define the indices of all the atoms in the system and a JAX MD - neighbor list callable for updating the state. + neighbor list which is callable for updating the state. Matching a neighborhood to the pattern is an optimization process. Based on the number of initial rotations of the reference structure @@ -378,11 +378,11 @@ class GeM(CollectiveVariable): box: JaxArray Definition of the simulation box. number_of_rotations: integer - Number of initial rotated structures for the optimization study. - number_of_opt_it: iteger - Number of iterations for gradient descent. + A number of initial rotated structures for the optimization study. + number_of_opt_it: integer + A number of iterations for gradient descent. standard_deviation: float - Parameter that controls the spread of the Gaussian function. + A parameter that controls the spread of the Gaussian function. mesh_size: integer Defines the size of the angular grid from which we draw random Euler angles. @@ -392,8 +392,11 @@ class GeM(CollectiveVariable): Set to True if NPT simulation is considered and the box size changes; use periodic_general for constructing the neighborlist. number_of_added_sites: int - Specify additional sites to the main reference for the continuous - calculation (skip if the continuous LoM is not needed). + Specify the number of additional sites to the main reference for the continuous + calculation (skip if the continuous LoM is not needed). The additional atoms should + already be added to the reference (reference_positions). + In other words, the reference should have elements corresponding to the original reference and + additional coordinates representing the extra atoms. width_of_switch_func: float Width of the switching function for the continuous score function. scale_for_radial_distance: float From 9845020b5ea7fc62c1e16d60543491716fe033d6 Mon Sep 17 00:00:00 2001 From: Pablo Zubieta <8410335+pabloferz@users.noreply.github.com> Date: Wed, 20 Mar 2024 13:41:32 -0500 Subject: [PATCH 5/6] Format with black and improve readability --- pysages/colvars/patterns.py | 103 ++++++++++++++++-------------------- 1 file changed, 46 insertions(+), 57 deletions(-) diff --git a/pysages/colvars/patterns.py b/pysages/colvars/patterns.py index 87b088fe..0b405458 100644 --- a/pysages/colvars/patterns.py +++ b/pysages/colvars/patterns.py @@ -12,8 +12,12 @@ from pysages.colvars.core import CollectiveVariable from pysages.utils import ( - gaussian, row_sum, quaternion_from_euler, - quaternion_matrix) + gaussian, + identity, + quaternion_from_euler, + quaternion_matrix, + row_sum, +) def rotate_pattern_with_quaternions(rot_q, pattern): @@ -46,9 +50,8 @@ def __init__( mesh_size, number_of_added_sites=0, width_of_switch_func=None, - scale_for_radial_distance=None + scale_for_radial_distance=None, ): - self.characteristic_distance = characteristic_distance self.reference = reference self.neighborlist = neighborlist @@ -76,15 +79,15 @@ def __init__( else: self.scale_for_radial_distance = scale_for_radial_distance + self._neighborhood = [] + def comp_pair_distance_squared(self, pos1): - displacement_fn, shift_fn = space.periodic(np.diag(self.simulation_box)) + displacement_fn, _ = space.periodic(np.diag(self.simulation_box)) mic_vector = displacement_fn(self.centre_j_coords, pos1) mic_norm = linalg.norm(mic_vector) return mic_norm, mic_vector def _generate_neighborhood(self): - self._neighborhood = [] - positions_of_all_nbrs = self.positions[self.neighborlist.idx[self.centre_j_id]] distances, mic_vectors = vmap(self.comp_pair_distance_squared)(positions_of_all_nbrs) # remove the same atom from the neighborhood @@ -96,11 +99,12 @@ def _generate_neighborhood(self): ids_of_neighbors = np.argsort(distances)[: len(self.reference)] - if self.number_of_added_sites > 0: - ids_of_neighbors_2nd_shell = ids_of_neighbors[ - -self.number_of_added_sites:] + n_added_sites = self.number_of_added_sites + if n_added_sites > 0: + ids_of_neighbors_2nd_shell = ids_of_neighbors[-n_added_sites:] self.shell_distance = self.scale_for_radial_distance * np.mean( - distances[ids_of_neighbors_2nd_shell]) + distances[ids_of_neighbors_2nd_shell] + ) self._neighborhood_distances = distances[ids_of_neighbors] coordinates = mic_vectors[ids_of_neighbors] + self.centre_j_coords @@ -120,33 +124,20 @@ def _generate_neighborhood(self): self._orig_neighbor_coords = positions_of_all_nbrs[ids_of_neighbors] def _switching_function(self, distance, width): - result = 0.5 * lax.erfc( - (distance - self.shell_distance) / width) + result = 0.5 * lax.erfc((distance - self.shell_distance) / width) return result def compute_score(self, optim_reference): r = self._neighbor_coords - optim_reference + std = self.standard_deviation if self.number_of_added_sites != 0: width = self.width_of_switch_func squared_dist = row_sum(r**2) - return np.exp( - - np.sum( - self._switching_function( - self._neighborhood_distances, - width) * squared_dist - ) / ( - 2 * (self.standard_deviation ** 2) * np.sum( - self._switching_function( - self._neighborhood_distances, width) - ) - ) - ) - else: - return np.prod( - gaussian(1, - self.standard_deviation * np.sqrt( - len(self.reference)), r)) + x = self._switching_function(self._neighborhood_distances, width) + return np.exp(-np.sum(x * squared_dist) / (2 * (std**2) * np.sum(x))) + + return np.prod(gaussian(1, std * np.sqrt(len(self.reference)), r)) def rotate_reference(self, random_euler_point): # Perform rotation of the reference pattern; @@ -196,7 +187,7 @@ def return_close(_, n): _, indices = lax.scan( lambda _, sites: ( None, - lax.cond(np.sum(sites) == 1, lambda s: s, lambda s: np.zeros_like(s), sites), + lax.cond(np.sum(sites) == 1, identity, np.zeros_like, sites), ), None, close_sites, @@ -207,11 +198,10 @@ def return_close(_, n): return settled_neighbor_indices def driver_match(self, number_of_rotations, number_of_opt_steps, num): - self._generate_neighborhood() - """Step2: Scale the reference so that the spread matches - with the current local pattern""" + # STEP 2: + # Scale the reference so that the spread matches with the current local pattern. local_distance = 0.0 reference_distance = 0.0 for n_index, neighbor in enumerate(self._neighborhood): @@ -220,17 +210,18 @@ def driver_match(self, number_of_rotations, number_of_opt_steps, num): self.reference *= np.sqrt(local_distance / reference_distance) - """Step3: mesh-loop -> Define angles in reduced Euler domain, - and for each rotate, resort and score the pattern - - The implementation below follows the article Martelli et al. 2018 - - - (a) Randomly with uniform probability pick a point in the Euler domain, - (b) Rotate the reference - (c) Resort the local pattern and assign the closest reference sites, - (d) Perform the optimisation step (conjugate gradient), - and (e) store the score with (f) the final settled status""" + # STEP 3: + # + # mesh-loop -> Define angles in reduced Euler domain, and for each rotate, + # resort and score the pattern. + # + # The implementation below follows the article Martelli et al. 2018 + # + # (a) Randomly with uniform probability pick a point in the Euler domain, + # (b) Rotate the reference + # (c) Resort the local pattern and assign the closest reference sites, + # (d) Perform the optimisation step (conjugate gradient), and + # (e) store the score with (f) the final settled status def get_all_scores(newkey, euler_point): # b. Rotate the reference pattern @@ -239,8 +230,7 @@ def get_all_scores(newkey, euler_point): # and assign ids to the closest reference sites newkey, newsubkey = random.split(random.PRNGKey(newkey)) reshuffled_reference, random_indices = self.resort(rotated_reference, newsubkey) - # d. Find the best rotation that aligns the settled sites - # in both patterns; + # d. Find the best rotation that aligns the settled sites in both patterns. # Here, ‘optimal’ or ‘best’ is in terms of least squares errors solver = minimize(fun=func_to_optimise, maxiter=number_of_opt_steps) # We are fixing the initial guess for the quaternions; @@ -266,7 +256,7 @@ def get_all_scores(newkey, euler_point): # a. Randomly pick a point in the Euler domain - key, subkey = random.split(random.PRNGKey(num)) + _, subkey = random.split(random.PRNGKey(num)) mesh_size = self.mesh_size grid_dimension = np.pi / mesh_size euler_angles = np.arange( @@ -305,14 +295,13 @@ def get_all_scores(newkey, euler_point): def calculate_lom(all_positions: np.array, neighborlist, simulation_box, params): - if params.fractional_coords: update_neighborlist = neighborlist.update(np.divide(all_positions, np.diag(simulation_box))) else: update_neighborlist = neighborlist.update(all_positions) - """Step1: Move the reference and - local patterns so that their centers coincide with the origin""" + # STEP 1: + # Move the reference and local patterns so that their centers coincide with the origin. reference_positions = params.reference_positions.at[:].set( params.reference_positions - np.mean(params.reference_positions, axis=0) @@ -332,7 +321,7 @@ def calculate_lom(all_positions: np.array, neighborlist, simulation_box, params) params.mesh_size, params.number_of_added_sites, params.width_of_switch_func, - params.scale_for_radial_distance + params.scale_for_radial_distance, ).driver_match( params.number_of_rotations, params.number_of_opt_it, @@ -393,10 +382,10 @@ class GeM(CollectiveVariable): changes; use periodic_general for constructing the neighborlist. number_of_added_sites: int Specify the number of additional sites to the main reference for the continuous - calculation (skip if the continuous LoM is not needed). The additional atoms should - already be added to the reference (reference_positions). - In other words, the reference should have elements corresponding to the original reference and - additional coordinates representing the extra atoms. + calculation (skip if the continuous LoM is not needed). The additional atoms should + already be added to the reference (reference_positions). + In other words, the reference should have elements corresponding to the + original reference and additional coordinates representing the extra atoms. width_of_switch_func: float Width of the switching function for the continuous score function. scale_for_radial_distance: float @@ -422,7 +411,7 @@ def __init__( fractional_coords, number_of_added_sites=0, width_of_switch_func=None, - scale_for_radial_distance=None + scale_for_radial_distance=None, ): super().__init__(indices, group_length=None) From 28a77c9deff019747ce3ddc6d3f924d71c3d542e Mon Sep 17 00:00:00 2001 From: maggiezimon <74198137+maggiezimon@users.noreply.github.com> Date: Fri, 29 Mar 2024 13:22:09 +0000 Subject: [PATCH 6/6] Pass positions directly patterns.py To enable grad, we need to make sure that the output depends on the input. The neighbor list dilutes that dependency. So we are now directly passing the positions. It is redundant, though. --- pysages/colvars/patterns.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pysages/colvars/patterns.py b/pysages/colvars/patterns.py index 0b405458..43f125ba 100644 --- a/pysages/colvars/patterns.py +++ b/pysages/colvars/patterns.py @@ -40,6 +40,7 @@ class Pattern: def __init__( self, + positions, simulation_box, fractional_coords, reference, @@ -59,10 +60,11 @@ def __init__( self.centre_j_id = centre_j_id # This is added to handle neighborlists with fractional coordinates # (needed for NPT simulations) - if fractional_coords: - self.positions = self.neighborlist.reference_position * np.diag(self.simulation_box) - else: - self.positions = self.neighborlist.reference_position + # if fractional_coords: + # self.positions = self.neighborlist.reference_position * np.diag(self.simulation_box) + # else: + # self.positions = self.neighborlist.reference_position + self.positions = positions self.centre_j_coords = self.positions[self.centre_j_id] self.standard_deviation = standard_deviation self.mesh_size = mesh_size @@ -311,6 +313,7 @@ def calculate_lom(all_positions: np.array, neighborlist, simulation_box, params) seed = np.int64(time.process_time() * 1e5) optimal_results = vmap( lambda i: Pattern( + all_positions, params.box, params.fractional_coords, reference_positions,