diff --git a/pyproject.toml b/pyproject.toml index adcc5448f5..59866a7e73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,7 @@ ase = ["ase>=3.23.0"] # tblite py3.12 support tracked in https://github.com/tblite/tblite/issues/198 ase-ext = ["tblite>=0.3.0; python_version < '3.12'"] openmm = [ - "mdanalysis>=2.7.0", + "mdanalysis>=2.8.0", "openmm-mdanalysis-reporter>=0.1.0", "openmm>=8.1.0", ] @@ -153,7 +153,7 @@ ignore_missing_imports = true no_strict_optional = true [tool.pytest.ini_options] -addopts = "-p no:warnings --import-mode=importlib --cov-config=pyproject.toml" +addopts = "-p no:warnings --import-mode=importlib --cov-config=pyproject.toml -m 'not openmm_mace'" filterwarnings = [ "ignore:.*POTCAR.*:UserWarning", "ignore:.*input structure.*:UserWarning", @@ -161,6 +161,9 @@ filterwarnings = [ "ignore:.*magmom.*:UserWarning", "ignore::DeprecationWarning", ] +markers = [ + "openmm_mace: tests marked openmm_mace are skipped by default because they are very slow (unskip with pytest -m openmm_mace)", +] [tool.coverage.run] include = ["src/*"] diff --git a/src/atomate2/openmm/jobs/mace.py b/src/atomate2/openmm/jobs/mace.py new file mode 100644 index 0000000000..c0020697e3 --- /dev/null +++ b/src/atomate2/openmm/jobs/mace.py @@ -0,0 +1,89 @@ +"""Run MACE on randomly packed benchmarking structures.""" + +import io +import json +from pathlib import Path + +import numpy as np +import openmm +import openmm.unit as omm_unit +from emmet.core.openmm import OpenMMInterchange, OpenMMTaskDocument +from emmet.core.vasp.task_valid import TaskState +from jobflow import Response +from mace.calculators.foundations_models import download_mace_mp_checkpoint +from monty.json import MontyEncoder +from openmm import Context, XmlSerializer +from openmm.app.pdbfile import PDBFile +from pymatgen.core import Structure + +from atomate2.openmm.jobs.base import openmm_job +from atomate2.openmm.mace_utils import MacePotential +from atomate2.openmm.utils import structure_to_topology + + +@openmm_job +def generate_mace_interchange( + structure: Structure, + ff_path: str | Path | None = None, + tags: list[str] | None = None, +) -> Response: + """Generate an OpenMMInterchange object with the MACE force-field. + + Parameters + ---------- + structure : Structure + The structure to simulate. + ff_path : str | Path, optional + The path to the MACE force-field. Must be accessible where the job is run. + Defaults to None. + tags : list[str], optional + Tags to add to the task document. Defaults to None. + + Returns + ------- + Response + The response containing the OpenMMTaskDocument. + """ + if not ff_path: + ff_path = Path(download_mace_mp_checkpoint()) + + potential = MacePotential(model_path=ff_path) + + topology = structure_to_topology(structure) + topology.setPeriodicBoxVectors(structure.lattice.matrix / 10) + system = potential.create_system(topology) + integrator = openmm.LangevinIntegrator( + 300 * omm_unit.kelvin, 10.0 / omm_unit.picoseconds, 1.0 * omm_unit.femtosecond + ) + context = Context(system, integrator) + context.setPositions(structure.cart_coords / 10) + state = context.getState(getPositions=True) + with io.StringIO() as buffer: + PDBFile.writeFile(topology, np.zeros(shape=(len(structure), 3)), file=buffer) + buffer.seek(0) + pdb = buffer.read() + + interchange = OpenMMInterchange( + system=XmlSerializer.serialize(system), + state=XmlSerializer.serialize(state), + topology=pdb, + ) + + interchange_json = interchange.model_dump_json() + + dir_name = Path.cwd() + + task_doc = OpenMMTaskDocument( + dir_name=str(dir_name), + state=TaskState.SUCCESS, + interchange=interchange_json, + structure=structure, + force_field=Path(ff_path).stem, + tags=tags, + ) + + # write out task_doc json to output dir + with open(dir_name / "taskdoc.json", "w") as file: + json.dump(task_doc.model_dump(), file, cls=MontyEncoder) + + return Response(output=task_doc) diff --git a/src/atomate2/openmm/jobs/random_structure.json b/src/atomate2/openmm/jobs/random_structure.json new file mode 100644 index 0000000000..b8852df7fb --- /dev/null +++ b/src/atomate2/openmm/jobs/random_structure.json @@ -0,0 +1 @@ +{"@module": "pymatgen.core.structure", "@class": "Structure", "charge": 0, "lattice": {"matrix": [[8.945722384641158, 0.0, 0.0], [0.0, 8.945722384641158, 0.0], [0.0, 0.0, 8.945722384641158]], "pbc": [true, true, true], "a": 8.945722384641158, "b": 8.945722384641158, "c": 8.945722384641158, "alpha": 90.0, "beta": 90.0, "gamma": 90.0, "volume": 715.8899231699995}, "properties": {}, "sites": [{"species": [{"element": "Al", "occu": 1}], "abc": [0.64659998950236, 0.38190599407216497, 0.8952527985610246], "properties": {}, "label": "Al", "xyz": [5.784304, 3.4164249999999994, 8.008683]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.8834840452425946, 0.31815720157899446, 0.8839606976376343], "properties": {}, "label": "Al", "xyz": [7.903403, 2.846146, 7.907666999999999]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.3468634355708847, 0.757030534686289, 0.817740668161081], "properties": {}, "label": "Al", "xyz": [3.1029439999999995, 6.772185, 7.315281]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.11223308267689708, 0.8869685039212494, 0.1123682310693931], "properties": {}, "label": "Al", "xyz": [1.004006, 7.934574, 1.005215]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.16107788035921722, 0.3912898086363331, 0.6561033025211006], "properties": {}, "label": "Al", "xyz": [1.4409579999999997, 3.50037, 5.869317999999999]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.68638146099213, 0.5142043093017935, 0.491245626797559], "properties": {}, "label": "Al", "xyz": [6.140178, 4.5999289999999995, 4.394547]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.5147756438212701, 0.11389152895569865, 0.3368813462369576], "properties": {}, "label": "Al", "xyz": [4.60504, 1.018842, 3.0136469999999997]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.5836308992735897, 0.3558866308511883, 0.1110339620761486], "properties": {}, "label": "Al", "xyz": [5.221, 3.183663, 0.9932789999999999]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.7667138219911492, 0.8885189656283793, 0.19472759438531093], "properties": {}, "label": "Al", "xyz": [6.858809, 7.948444, 1.7419789999999997]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.4653749379868551, 0.11650641001216443, 0.675629730068182], "properties": {}, "label": "Al", "xyz": [4.163115, 1.042234, 6.043995999999999]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.8434003063803603, 0.6355647711314559, 0.3462858410762388], "properties": {}, "label": "Al", "xyz": [7.544825, 5.685586, 3.097777]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.10461547539266058, 0.268665167178269, 0.45120213063284215], "properties": {}, "label": "Al", "xyz": [0.935861, 2.4034039999999997, 4.036329]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.10656545765792173, 0.48920040348150656, 0.34230941542043336], "properties": {}, "label": "Al", "xyz": [0.953305, 4.376251, 3.062205]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.16209926237926373, 0.532217276066407, 0.10737713050979408], "properties": {}, "label": "Al", "xyz": [1.450095, 4.761068, 0.9605659999999999]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.10987653738143902, 0.6517177427738681, 0.531342779892299], "properties": {}, "label": "Al", "xyz": [0.9829249999999999, 5.830086, 4.753245]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.6220034292091662, 0.11033888126180591, 0.8866595294325322], "properties": {}, "label": "Al", "xyz": [5.56427, 0.987061, 7.93181]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.837565227025604, 0.33900448388681453, 0.11178527088174481], "properties": {}, "label": "Al", "xyz": [7.492625999999999, 3.0326399999999993, 1.0]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.12504557506955002, 0.1168739599828236, 0.6426506158821079], "properties": {}, "label": "Al", "xyz": [1.118623, 1.045522, 5.748974]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.6789277309150062, 0.2765485998366623, 0.41951491882234837], "properties": {}, "label": "Al", "xyz": [6.073499, 2.473927, 3.752864]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.8902798072153085, 0.6174102842139062, 0.1067077603077542], "properties": {}, "label": "Al", "xyz": [7.964196, 5.523181, 0.9545779999999999]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.5625750256503042, 0.780675019827383, 0.11049593956739476], "properties": {}, "label": "Al", "xyz": [5.03264, 6.983702, 0.988466]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.5418058812468304, 0.5988211761631681, 0.8458476213157743], "properties": {}, "label": "Al", "xyz": [4.846844999999999, 5.356887999999999, 7.566718]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.6963827773926489, 0.10453331321856249, 0.5926859533560945], "properties": {}, "label": "Al", "xyz": [6.229647, 0.9351259999999999, 5.302004]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.8508692392543241, 0.7985831320079093, 0.8003875698504825], "properties": {}, "label": "Al", "xyz": [7.61164, 7.143902999999999, 7.160045]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.419600322769302, 0.24957928538373256, 0.8982807261933984], "properties": {}, "label": "Al", "xyz": [3.7536279999999995, 2.232667, 8.03577]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.3342419842061687, 0.18765795849893677, 0.4822187426033163], "properties": {}, "label": "Al", "xyz": [2.9900359999999995, 1.678736, 4.313795]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.7144616974670827, 0.8272436458240091, 0.5737339903160743], "properties": {}, "label": "Al", "xyz": [6.391376, 7.400292, 5.132464999999999]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.6282204788345254, 0.852304226776633, 0.8892708333603299], "properties": {}, "label": "Al", "xyz": [5.619886, 7.624477, 7.95517]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.8881081547578888, 0.40276344883963533, 0.4077231377381166], "properties": {}, "label": "Al", "xyz": [7.944769, 3.6030099999999994, 3.647378]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.3606916089347692, 0.1239122959933509, 0.10891205406427132], "properties": {}, "label": "Al", "xyz": [3.2266469999999994, 1.108485, 0.974297]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.13044402115624212, 0.5731628793671395, 0.8127437547673961], "properties": {}, "label": "Al", "xyz": [1.1669159999999998, 5.127356, 7.27058]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.8207171745688493, 0.11312177558040695, 0.22061840454477363], "properties": {}, "label": "Al", "xyz": [7.341908, 1.011956, 1.973591]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.3506572040940694, 0.8885147177880858, 0.17155763771910992], "properties": {}, "label": "Al", "xyz": [3.136882, 7.948406, 1.5347069999999998]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.5031197936264306, 0.8860720978340487, 0.685552461423271], "properties": {}, "label": "Al", "xyz": [4.500769999999999, 7.926555, 6.132761999999999]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.38536317714472473, 0.5908122086455747, 0.19223835997331626], "properties": {}, "label": "Al", "xyz": [3.447352, 5.285242, 1.719711]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.6037762818655432, 0.11123148664979664, 0.1103842660817839], "properties": {}, "label": "Al", "xyz": [5.401215, 0.9950459999999999, 0.9874669999999999]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.5199556614886593, 0.6769940692992937, 0.5641595818650529], "properties": {}, "label": "Al", "xyz": [4.651379, 6.056201, 5.046815]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.4035626017411689, 0.3583061112541528, 0.6792217261774252], "properties": {}, "label": "Al", "xyz": [3.6101589999999995, 3.205307, 6.076128999999999]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.12481473848517921, 0.8587373573306065, 0.8211087583727481], "properties": {}, "label": "Al", "xyz": [1.116558, 7.682026, 7.345411]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.14827310115025513, 0.3429417846978114, 0.8948110231705], "properties": {}, "label": "Al", "xyz": [1.3264099999999999, 3.067862, 8.004731]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.6645119023820963, 0.5724788652946141, 0.19265699581276838], "properties": {}, "label": "Al", "xyz": [5.944538999999999, 5.121237, 1.7234559999999999]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.8900891015431842, 0.15954295680474, 0.4518175085490461], "properties": {}, "label": "Al", "xyz": [7.96249, 1.427227, 4.041834]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.88777995320258, 0.868444119542352, 0.4076669097468631], "properties": {}, "label": "Al", "xyz": [7.941832999999999, 7.76886, 3.6468749999999996]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.6454857139222108, 0.33587248416724985, 0.6551271935357612], "properties": {}, "label": "Al", "xyz": [5.774336, 3.004622, 5.860586]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.6042424264451202, 0.7678240733015467, 0.3522360592600032], "properties": {}, "label": "Al", "xyz": [5.405385, 6.868740999999999, 3.1510059999999998]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.21920409729757778, 0.10988827483488159, 0.8689416757830466], "properties": {}, "label": "Al", "xyz": [1.960939, 0.9830299999999998, 7.773311]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.8854965154742787, 0.6515903075650629, 0.5878952837871864], "properties": {}, "label": "Al", "xyz": [7.921406, 5.828945999999999, 5.259147999999999]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.17644310119772572, 0.8866629947759297, 0.5117170870246747], "properties": {}, "label": "Al", "xyz": [1.578411, 7.931841, 4.577679]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.8902248088620346, 0.3336060378001225, 0.6406499948891374], "properties": {}, "label": "Al", "xyz": [7.963703999999999, 2.984347, 5.731077]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.3383039255941987, 0.3644458054720618, 0.11053484084166161], "properties": {}, "label": "Al", "xyz": [3.026373, 3.260231, 0.988814]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.109196882934478, 0.2875773346618552, 0.15432973891189894], "properties": {}, "label": "Al", "xyz": [0.976845, 2.5725869999999995, 1.380591]}, {"species": [{"element": "Fe", "occu": 1}], "abc": [0.5150752283472332, 0.43701378512617667, 0.33287015536190795], "properties": {}, "label": "Fe", "xyz": [4.60772, 3.909404, 2.977764]}, {"species": [{"element": "Fe", "occu": 1}], "abc": [0.3231218090293945, 0.6687096628742476, 0.4165378534882257], "properties": {}, "label": "Fe", "xyz": [2.890558, 5.982090999999999, 3.7262319999999995]}, {"species": [{"element": "Fe", "occu": 1}], "abc": [0.8868078684869924, 0.11344483501325518, 0.7480456817538979], "properties": {}, "label": "Fe", "xyz": [7.933137, 1.014846, 6.691809]}, {"species": [{"element": "Ni", "occu": 1}], "abc": [0.4182014418894879, 0.8897721903002345, 0.4561953551125879], "properties": {}, "label": "Ni", "xyz": [3.741114, 7.959655, 4.080997]}, {"species": [{"element": "Ni", "occu": 1}], "abc": [0.3127311445303946, 0.5842903205865211, 0.6508213366866673], "properties": {}, "label": "Ni", "xyz": [2.797606, 5.226899, 5.822067]}, {"species": [{"element": "Ni", "occu": 1}], "abc": [0.11719713120094273, 0.7311663294212961, 0.3001219895454755], "properties": {}, "label": "Ni", "xyz": [1.048413, 6.540811, 2.6848079999999994]}, {"species": [{"element": "Ni", "occu": 1}], "abc": [0.19971142890230265, 0.10653147493557369, 0.29398352496554625], "properties": {}, "label": "Ni", "xyz": [1.786563, 0.953001, 2.629895]}, {"species": [{"element": "Ni", "occu": 1}], "abc": [0.77572270875678, 0.5601726483937846, 0.7856451047560562], "properties": {}, "label": "Ni", "xyz": [6.9394, 5.011149, 7.028162999999999]}, {"species": [{"element": "Ni", "occu": 1}], "abc": [0.30994746771489734, 0.42839927679621675, 0.4650978223003392], "properties": {}, "label": "Ni", "xyz": [2.7727039999999996, 3.8323409999999996, 4.160636]}]} diff --git a/src/atomate2/openmm/mace_utils.py b/src/atomate2/openmm/mace_utils.py new file mode 100644 index 0000000000..08a79cb690 --- /dev/null +++ b/src/atomate2/openmm/mace_utils.py @@ -0,0 +1,583 @@ +"""Supports easy instantiation of OpenMM Systems with the Mace force field. + +This code is based off of the openmm-ml package. In particular, +it borrows from the MLPotential class written by Peter Eastman and the MACEForce +class written by Harry Moore. The nnpops_nl function +is also from openmm-ml and was written by Harry Moore. + +The original code is licensed as below + +Portions copyright (c) 2021 Stanford University and the Authors. +Authors: Peter Eastman + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE +USE OR OTHER DEALINGS IN THE SOFTWARE. +""" + +import logging +from collections.abc import Callable +from pathlib import Path + +import numpy as np +import openmm +import openmm.app +import torch +from e3nn.util import jit +from mace.tools import atomic_numbers_to_indices, to_one_hot, utils + +try: + from NNPOps.neighbors import getNeighborPairs +except ImportError: + + def getNeighborPairs(*args, **kwargs) -> None: # noqa: N802, ARG001 + """Raise ImportError if NNPOps is not installed.""" + raise ImportError( + "NNPOps is not installed. Please install it from conda-forge." + ) + + +try: + from openmmtorch import TorchForce +except ImportError: + + def TorchForce(*args, **kwargs) -> None: # noqa: N802, ARG001 + """Raise ImportError if openmmtorch is not installed.""" + raise ImportError( + "openmmtorch is not installed. Please install it from conda-forge." + ) + + +class MaceForce(torch.nn.Module): + """Computes the energy of a system using a MACE model. + + Attributes + ---------- + model (torch.nn.Module): The MACE model. + device (str): The device (CPU or GPU) on which computations are performed. + nl (Callable): The neighbor list function used for atom interactions. + periodic (bool): Whether to use periodic boundary conditions. + default_dtype (torch.dtype): The default data type for tensor operations. + r_max (float): The maximum cutoff radius for atomic interactions. + z_table (utils.AtomicNumberTable): Table for converting between + atomic numbers and indices. + """ + + def __init__( + self, + model: torch.nn.Module, + atomic_numbers: list[int], + device: torch.device | None, + nl: Callable, + *, + periodic: bool = True, + dtype: torch.dtype = torch.float32, + ) -> None: + """Initialize the MaceForce object. + + Args: + model (torch.nn.Module): The MACE neural network model. + atomic_numbers (list[int]): List of atomic numbers for the system. + device (str | None): The device to run computations on ('cuda', 'cpu', + or None for auto-detection). + nl (Callable): The neighbor list function to use. + periodic (bool, optional): Whether to use periodic boundary conditions. + Defaults to True. + dtype (torch.dtype, optional): The data type for tensor operations. + Defaults to torch.float32. + """ + super().__init__() + self.device = device or torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + self.nl = nl + self.periodic = periodic + self.default_dtype = dtype + + torch.set_default_dtype(self.default_dtype) + + logging.info( + f"Running MACEForce on device: {self.device} " + f"with dtype: {self.default_dtype} " + f"and neighbour list: {nl}" + ) + # conversion constants + self.nm_to_A = 10.0 + self.eV_to_kj = 96.48533288 + + self.model = model.to(dtype=self.default_dtype, device=self.device) + self.model.eval() + + # set model properties + self.r_max = self.model.r_max + self.z_table = utils.AtomicNumberTable( + [int(z) for z in self.model.atomic_numbers] + ) + self.model.atomic_numbers = torch.tensor( + self.model.atomic_numbers.clone(), device=self.device + ) + + # compile model + self.model = jit.compile(self.model) + + # setup system + n_atoms = len(atomic_numbers) + self.ptr = torch.tensor([0, n_atoms], dtype=torch.long, device=self.device) + self.batch = torch.zeros(n_atoms, dtype=torch.long, device=self.device) + + # one hot encoding of atomic number + self.node_attrs = to_one_hot( + torch.tensor( + atomic_numbers_to_indices(atomic_numbers, z_table=self.z_table), + dtype=torch.long, + device=self.device, + ).unsqueeze(-1), + num_classes=len(self.z_table), + ) + + if periodic: + self.pbc = torch.tensor([True, True, True], device=self.device) + else: + self.pbc = torch.tensor([False, False, False], device=self.device) + + def forward( + self, positions: torch.Tensor, boxvectors: torch.Tensor | None = None + ) -> torch.Tensor: + """Compute the energy of the system given atomic positions and box vectors. + + This method calculates the neighbor list, prepares the input for the MACE + model, and returns the computed energy. + + Args: + positions (torch.Tensor): Atomic positions in nanometers. + boxvectors (torch.Tensor | None, optional): Box vectors for + periodic systems. Defaults to None. + + Returns + ------- + torch.Tensor: The computed energy of the system in kJ/mol. + """ + positions = positions.to(device=self.device, dtype=self.default_dtype) + positions = positions * self.nm_to_A + + if boxvectors is not None: + cell = ( + boxvectors.to(device=self.device, dtype=self.default_dtype) + * self.nm_to_A + ) + else: + # TODO: it's not clear what the best fallback should be + # cell = torch.eye(3, device=self.device) + cell = torch.zeros((3, 3), device=self.device) + + # calculate neighbor list + mapping, shifts_idx = self.nl(positions, cell, self.periodic, self.r_max) + edge_index = torch.stack((mapping[0], mapping[1])) + shifts = torch.mm(shifts_idx, cell) + + # get model output + out = self.model( + dict( + ptr=self.ptr, + node_attrs=self.node_attrs, + batch=self.batch, + pbc=self.pbc, + cell=cell, + positions=positions, + edge_index=edge_index, + unit_shifts=shifts_idx, + shifts=shifts, + ), + compute_force=False, + ) + + energy = out["interaction_energy"] + if energy is None: + energy = torch.tensor(0.0, device=self.device) + + # return energy tensor + return energy * self.eV_to_kj + + +class MacePotential: + """A potential function class for molecular simulations using MACE models. + + Attributes + ---------- + model (torch.nn.Module | None): The MACE model, if provided directly. + model_path (str | Path | None): Path to the MACE model file, if the + model is to be loaded from disk. + """ + + def __init__( + self, model: torch.nn.Module | None = None, model_path: str | Path | None = None + ) -> None: + """Initialize a MacePotential object. + + Exactly one of 'model' or 'model_path' must be provided. + + Args: + model (torch.nn.Module | None, optional): The MACE model. Defaults to None. + model_path (str | Path | None, optional): Path to the MACE model file. + Defaults to None. + + Raises + ------ + ValueError: If neither model nor model_path is provided, + or if both are provided. + + """ + if (model_path is None) == (model is None): + raise ValueError( + "Exactly one of 'model_paths' or 'models' must be provided" + ) + self.model = model + self.model_path = model_path + + def create_system(self, topology: openmm.app.Topology, **kwargs) -> openmm.System: + """Create a System for running a simulation with this potential function. + + Parameters + ---------- + topology : openmm.app.Topology + The Topology for which to create a System + **kwargs : dict + Additional keyword arguments for customizing the potential function + + Returns + ------- + openmm.System + A newly created System object that uses this potential function to model + the Topology + """ + system = openmm.System() + if topology.getPeriodicBoxVectors() is not None: + system.setDefaultPeriodicBoxVectors(*topology.getPeriodicBoxVectors()) + for atom in topology.atoms(): + if atom.element is None: + system.addParticle(0) + else: + system.addParticle(atom.element.mass) + self.add_forces(topology, system, **kwargs) + return system + + def add_forces( + self, + topology: openmm.app.Topology, + system: openmm.System, + nl: Callable | None = None, + device: torch.device | None = None, + dtype: torch.dtype = torch.float32, + ) -> None: + """Add MACE forces to an existing OpenMM System. + + This method creates and adds a TorchForce to the provided System, which computes + interactions using the MACE potential. + + Args: + topology (openmm.app.Topology): The system topology. + system (openmm.System): The OpenMM System to which forces will be added. + nl (Callable | None, optional): The neighbor list method to use. + If None, an appropriate method will be chosen based on system size. + Defaults to None. + device (str | None, optional): The device to use for computations + ('cuda', 'cpu', or None for auto-detection). Defaults to None. + dtype (str, optional): The data type to use for computations. + Defaults to "float32". + """ + periodic = ( + topology.getPeriodicBoxVectors() is not None + ) or system.usesPeriodicBoundaryConditions() + + atoms = list(topology.atoms()) + atomic_numbers = [atom.element.atomic_number for atom in atoms] + + # get length of shortest box vector + box_vectors = topology.getPeriodicBoxVectors() + min_length = np.min(np.linalg.norm(box_vectors, axis=1)) + + # nnpops is both faster and O(n) but can't be used on small systems + if nl is None: + mace_cutoff = 5 + nl = nnpops_nl if min_length > (2 * mace_cutoff) else wrapping_nl + + # serialize the MACEForce into a module + maceforce = MaceForce( + self.model or torch.load(self.model_path), + atomic_numbers, + device=device, + nl=nl, + periodic=periodic, + dtype=dtype, + ) + module = torch.jit.script(maceforce) + + # Create the TorchForce and add it to the System. + force = TorchForce(module) + force.setForceGroup(0) + force.setUsesPeriodicBoundaryConditions(periodic) + system.addForce(force) + + +def nnpops_nl( + positions: torch.Tensor, + cell: torch.Tensor, + pbc: bool, + cutoff: float, + sorti: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + """Run a neighbor list computation using NNPOps. + + It outputs neighbors and shifts in the same format as ASE + https://wiki.fysik.dtu.dk/ase/ase/neighborlist.html#ase.neighborlist.primitive_neighbor_list + + neighbors, shifts = nnpops_nl(..) + is equivalent to + + [i, j], S = primitive_neighbor_list( quantities="ijS", ...) + + Parameters + ---------- + positions : torch.Tensor + Atom positions, shape (num_atoms, 3) + cell : torch.Tensor + Unit cell, shape (3, 3) + pbc : bool + Whether to use periodic boundary conditions + cutoff : float + Cutoff distance for neighbors + sorti : bool, optional + Whether to sort the neighbor list by the first index. + Defaults to False. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + A tuple containing: + - neighbors (torch.Tensor): Neighbor list, shape (2, num_neighbors) + - shifts (torch.Tensor): Shift vectors, shape (num_neighbors, 3) + """ + device = positions.device + neighbors, deltas, _, _ = getNeighborPairs( + positions, + cutoff=cutoff, + max_num_pairs=-1, + box_vectors=cell if pbc else None, + check_errors=False, + ) + + neighbors = neighbors.to(dtype=torch.long) + + # remove empty neighbors + mask = neighbors[0] > -1 + neighbors = neighbors[:, mask] + deltas = deltas[mask, :] + + # compute shifts TODO: pass deltas and distance directly to model + # From ASE docs: + # wrapped_delta = pos[i] - pos[j] - shift.cell + # => shift = ((pos[i]-pos[j]) - wrapped_delta).cell^-1 + if pbc: + shifts = torch.mm( + (positions[neighbors[0]] - positions[neighbors[1]]) - deltas, + torch.linalg.inv(cell), + ) + else: + shifts = torch.zeros(deltas.shape, device=device) + + # we have ij + neighbors = torch.hstack((neighbors, torch.stack((neighbors[1], neighbors[0])))) + shifts = torch.vstack((shifts, -shifts)) + + if sorti: + idx = torch.argsort(neighbors[0]) + neighbors = neighbors[:, idx] + shifts = shifts[idx, :] + + return neighbors, shifts + + +@torch.jit.script +def wrapping_nl( + positions: torch.Tensor, + cell: torch.Tensor, + pbc: bool, + cutoff: float, + sorti: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + """Neighbor list including self-interactions across periodic boundaries. + + Parameters + ---------- + positions : torch.Tensor + Atom positions, shape (num_atoms, 3) + cell : torch.Tensor + Unit cell, shape (3, 3) + pbc : bool + Whether to use periodic boundary conditions + cutoff : float + Cutoff distance for neighbors + sorti : bool, optional + Whether to sort the neighbor list by the first index. + Defaults to False. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + A tuple containing: + - neighbors (torch.Tensor): Neighbor list, shape (2, num_neighbors) + - shifts (torch.Tensor): Shift vectors, shape (num_neighbors, 3) + """ + num_atoms = positions.shape[0] + device = positions.device + dtype = positions.dtype + + # Get all unique pairs including self-pairs (i <= j) + uij = torch.triu_indices(num_atoms, num_atoms, offset=0, device=device) + i_indices = uij[0] + j_indices = uij[1] + + if pbc: + # Compute displacement vectors between atom pairs + deltas = positions[i_indices] - positions[j_indices] + + # Compute inverse cell matrix + inv_cell = torch.linalg.inv(cell) + + # Compute fractional coordinates of displacement vectors + frac_deltas = torch.matmul(deltas, inv_cell) + + # Determine the maximum number of shifts needed along each axis + cell_lengths = torch.linalg.norm(cell, dim=0) + n_max = torch.ceil(cutoff / cell_lengths).to(torch.int32) + + # Extract scalar values from n_max + n_max0 = int(n_max[0]) + n_max1 = int(n_max[1]) + n_max2 = int(n_max[2]) + + # Generate shift ranges + shift_range_x = torch.arange(-n_max0, n_max0 + 1, device=device, dtype=dtype) + shift_range_y = torch.arange(-n_max1, n_max1 + 1, device=device, dtype=dtype) + shift_range_z = torch.arange(-n_max2, n_max2 + 1, device=device, dtype=dtype) + + # Generate all combinations of shifts within the range [-n_max, n_max] + shift_x, shift_y, shift_z = torch.meshgrid( + shift_range_x, shift_range_y, shift_range_z, indexing="ij" + ) + + shifts_list = torch.stack( + (shift_x.reshape(-1), shift_y.reshape(-1), shift_z.reshape(-1)), dim=1 + ) + + # Total number of shifts + num_shifts = shifts_list.shape[0] + + # Expand atom pairs and shifts + num_pairs = i_indices.shape[0] + i_indices_expanded = i_indices.repeat_interleave(num_shifts) + j_indices_expanded = j_indices.repeat_interleave(num_shifts) + shifts_expanded = shifts_list.repeat(num_pairs, 1) + + # Expand fractional displacements + frac_deltas_expanded = frac_deltas.repeat_interleave(num_shifts, dim=0) + + # Apply shifts to fractional displacements + shifted_frac_deltas = frac_deltas_expanded - shifts_expanded + + # Convert back to Cartesian coordinates + shifted_deltas = torch.matmul(shifted_frac_deltas, cell) + + # Compute distances + distances = torch.linalg.norm(shifted_deltas, dim=1) + + # Apply cutoff filter + within_cutoff = distances <= cutoff + + # Exclude self-pairs where shift is zero (no periodic boundary crossing) + shift_zero = (shifts_expanded == 0).all(dim=1) + i_eq_j = i_indices_expanded == j_indices_expanded + exclude_self_zero_shift = i_eq_j & shift_zero + within_cutoff = within_cutoff & (~exclude_self_zero_shift) + + num_within_cutoff = int(within_cutoff.sum()) + + i_indices_final = i_indices_expanded[within_cutoff] + j_indices_final = j_indices_expanded[within_cutoff] + shifts_final = shifts_expanded[within_cutoff] + + # Generate neighbor pairs and shifts + neighbors = torch.stack((i_indices_final, j_indices_final), dim=0) + shifts = shifts_final + + # Add symmetric pairs (j, i) and negate shifts, + # but avoid duplicates for self-pairs + i_neq_j = i_indices_final != j_indices_final + neighbors_sym = torch.stack( + (j_indices_final[i_neq_j], i_indices_final[i_neq_j]), dim=0 + ) + shifts_sym = -shifts_final[i_neq_j] + + neighbors = torch.cat((neighbors, neighbors_sym), dim=1) + shifts = torch.cat((shifts, shifts_sym), dim=0) + + if sorti: + idx = torch.argsort(neighbors[0]) + neighbors = neighbors[:, idx] + shifts = shifts[idx, :] + + return neighbors, shifts + + # Non-periodic case + deltas = positions[i_indices] - positions[j_indices] + distances = torch.linalg.norm(deltas, dim=1) + + # Apply cutoff filter + within_cutoff = distances <= cutoff + + # Exclude self-pairs where distance is zero + i_eq_j = i_indices == j_indices + exclude_self_zero_distance = i_eq_j & (distances == 0) + within_cutoff = within_cutoff & (~exclude_self_zero_distance) + + num_within_cutoff = int(within_cutoff.sum()) + + i_indices_final = i_indices[within_cutoff] + j_indices_final = j_indices[within_cutoff] + + shifts_final = torch.zeros((num_within_cutoff, 3), device=device, dtype=dtype) + + # Generate neighbor pairs and shifts + neighbors = torch.stack((i_indices_final, j_indices_final), dim=0) + shifts = shifts_final + + # Add symmetric pairs (j, i) and shifts (only if i != j) + i_neq_j = i_indices_final != j_indices_final + neighbors_sym = torch.stack( + (j_indices_final[i_neq_j], i_indices_final[i_neq_j]), dim=0 + ) + shifts_sym = shifts_final[i_neq_j] # shifts are zero + + neighbors = torch.cat((neighbors, neighbors_sym), dim=1) + shifts = torch.cat((shifts, shifts_sym), dim=0) + + if sorti: + idx = torch.argsort(neighbors[0]) + neighbors = neighbors[:, idx] + shifts = shifts[idx, :] + + return neighbors, shifts diff --git a/src/atomate2/openmm/utils.py b/src/atomate2/openmm/utils.py index ce6fb4055d..65a303c901 100644 --- a/src/atomate2/openmm/utils.py +++ b/src/atomate2/openmm/utils.py @@ -11,10 +11,13 @@ from typing import TYPE_CHECKING import numpy as np +import openmm import openmm.unit as omm_unit from emmet.core.openmm import OpenMMInterchange from openmm import LangevinMiddleIntegrator, State, XmlSerializer -from openmm.app import PDBFile, Simulation +from openmm.app import PDBFile, Simulation, Topology +from openmm.unit import angstrom +from pymatgen.core import Structure from pymatgen.core.trajectory import Trajectory if TYPE_CHECKING: @@ -174,6 +177,55 @@ def openff_to_openmm_interchange( ) +def structure_to_topology(structure: Structure) -> Topology: + """Convert pymatgen structure to openmm topology. + + Parameters + ---------- + structure : Structure + The pymatgen structure to convert. + + Returns + ------- + openmm.app.Topology + The converted OpenMM topology. + """ + top = Topology() + chain = top.addChain() + for i, site in enumerate(structure): + res = top.addResidue(f"r{i}", chain) + element = openmm.app.element.Element.getBySymbol(site.species_string) + top.addAtom(f"{element}{i}", element, res) + return top + + +def interchange_to_structure(interchange: OpenMMInterchange) -> Structure: + """Convert an OpenMMInterchange object to a pymatgen Structure. + + Parameters + ---------- + interchange : OpenMMInterchange + The OpenMMInterchange object to convert. + + Returns + ------- + Structure + The converted pymatgen Structure. + """ + with io.StringIO(interchange.topology) as buffer: + pdb = PDBFile(buffer) + topology = pdb.getTopology() + + state = XmlSerializer.deserialize(interchange.state) + + return Structure( + lattice=state.getPeriodicBoxVectors(asNumpy=True).value_in_unit(angstrom), + species=[atom.element.symbol for atom in topology.atoms()], + coords=state.getPositions(asNumpy=True).value_in_unit(angstrom), + coords_are_cartesian=True, + ) + + class PymatgenTrajectoryReporter: """Reporter that creates a pymatgen Trajectory from an OpenMM simulation. diff --git a/tests/openmm_md/conftest.py b/tests/openmm_md/conftest.py index d8d69e2c2f..69c8c423e1 100644 --- a/tests/openmm_md/conftest.py +++ b/tests/openmm_md/conftest.py @@ -1,6 +1,18 @@ +import gzip +import json +import shutil +from pathlib import Path + import pytest -from emmet.core.openmm import OpenMMInterchange -from jobflow import run_locally +from emmet.core.openmm import OpenMMInterchange, OpenMMTaskDocument +from jobflow import Flow, JobStore, run_locally +from maggma.stores import MemoryStore +from monty.json import MontyDecoder +from pymatgen.core import Composition, Structure + +from atomate2.common.jobs.mpmorph import get_random_packed_structure +from atomate2.forcefields.utils import revert_default_dtype +from atomate2.openmm.jobs.core import NVTMaker @pytest.fixture @@ -70,6 +82,80 @@ def interchange(openmm_data): return OpenMMInterchange.model_validate_json(file.read()) -@pytest.fixture -def output_dir(test_dir): - return test_dir / "classical_md" / "output_dir" +@pytest.fixture(scope="session") +def random_structure(test_dir) -> Structure: + test_files = test_dir / "openmm" / "mlff_test_files" + test_files.mkdir(parents=True, exist_ok=True) + struct_file = test_files / "random_structure.json" + + # disable this flag to speed up local testing + regenerate_test_data = False + if regenerate_test_data: + struct_file.unlink(missing_ok=True) + composition = Composition("Al85Ni10Fe5") + + n_atoms = 60 + struct = get_random_packed_structure( + composition=composition, + target_atoms=n_atoms, + packmol_seed=1, + ) + struct.to_file(str(struct_file)) + return Structure.from_file(struct_file) + + +@pytest.mark.openmm_slow +@pytest.fixture(scope="session") +def task_doc(random_structure: Structure, test_dir: Path) -> OpenMMInterchange: + from atomate2.openmm.jobs.mace import generate_mace_interchange + + output_dir = test_dir / "openmm" / "mlff_test_files" + output_dir.mkdir(parents=True, exist_ok=True) + + json_path = output_dir / "taskdoc.json" + gz_path = output_dir / "taskdoc.json.gz" + + # disable this flag to speed up local testing + regenerate_test_data = False + if regenerate_test_data: + (output_dir / "taskdoc.json").unlink(missing_ok=True) + + with revert_default_dtype(): + generate_job = generate_mace_interchange( + random_structure, + ) + nvt_job = NVTMaker( + n_steps=2, traj_interval=1, state_interval=1, save_structure=True + ).make( + generate_job.output.interchange, prev_dir=generate_job.output.dir_name + ) + + job_store = JobStore( + MemoryStore(), additional_stores={"data": MemoryStore()} + ) + + run_locally( + Flow([generate_job, nvt_job]), + store=job_store, + ensure_success=True, + root_dir=output_dir, + ) + + # Compress the generated JSON file + with json_path.open("rb") as f_in, gzip.open(gz_path, "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + json_path.unlink() # Remove the uncompressed file + + # Read from the compressed file + with gzip.open(gz_path, "rt") as f: + task_doc_dict = json.load(f, cls=MontyDecoder) + + # task_doc_dict = json.load((output_dir / "taskdoc.json").open(), cls=MontyDecoder) + + return OpenMMTaskDocument.model_validate(task_doc_dict) + + +@pytest.mark.openmm_slow +@pytest.fixture(scope="session") +def mace_interchange(task_doc: OpenMMTaskDocument) -> OpenMMInterchange: + return OpenMMInterchange.model_validate_json(task_doc.interchange) diff --git a/tests/openmm_md/jobs/test_mace.py b/tests/openmm_md/jobs/test_mace.py new file mode 100644 index 0000000000..f855d79d1b --- /dev/null +++ b/tests/openmm_md/jobs/test_mace.py @@ -0,0 +1,103 @@ +from collections.abc import Callable +from pathlib import Path + +import openmm +import openmm.unit as omm_unit +import pytest +from emmet.core.openmm.tasks import OpenMMInterchange, OpenMMTaskDocument +from jobflow import JobStore, run_locally +from numpy.testing import assert_allclose +from pymatgen.core import Structure + +from atomate2.forcefields.utils import revert_default_dtype +from atomate2.openmm.flows import OpenMMFlowMaker +from atomate2.openmm.jobs import EnergyMinimizationMaker, NPTMaker, NVTMaker + + +@pytest.mark.openmm_mace +def test_generate_openmm_interchange( + task_doc: OpenMMTaskDocument, random_structure: Structure +) -> None: + assert_allclose( + task_doc.structure.frac_coords, random_structure.frac_coords, atol=0.01 + ) + interchange = OpenMMInterchange.model_validate_json(task_doc.interchange) + integrator = openmm.LangevinIntegrator( + 300 * omm_unit.kelvin, 10.0 / omm_unit.picoseconds, 1.0 * omm_unit.femtosecond + ) + platform = openmm.Platform.getPlatformByName("CPU") + + sim = interchange.to_openmm_simulation(integrator, platform) + assert isinstance(sim, openmm.app.Simulation) + + +@pytest.mark.openmm_mace +def test_nvt_maker(task_doc: OpenMMInterchange) -> None: + # the task document in the fixture is generated by an nvt maker + + # Test length of state attributes in calculation output + calc_output = task_doc.calcs_reversed[0].output + assert len(calc_output.steps_reported) == 2 + + # Test that the state interval is respected + assert calc_output.steps_reported == list(range(1, 3)) + + +@pytest.mark.openmm_mace +def test_npt_maker(interchange: OpenMMInterchange, run_job: Callable) -> None: + # this is validated upstream in atomate2, we are ensuring it works with mace + maker = NPTMaker(n_steps=2, state_interval=1, pressure=0) + base_job = maker.make(interchange) + with revert_default_dtype(): + run_job(base_job) + + +@pytest.mark.openmm_mace +def test_energy_minimization_maker( + interchange: OpenMMInterchange, run_job: Callable +) -> None: + # this is validated upstream in atomate2, we are ensuring it works with mace + maker = EnergyMinimizationMaker(max_iterations=1) + base_job = maker.make(interchange) + with revert_default_dtype(): + run_job(base_job) + + +@pytest.mark.openmm_mace +def test_flow_maker( + interchange: OpenMMInterchange, job_store: JobStore, tmp_path: Path +): + # this is validated upstream in atomate2, we are ensuring it works with mace + production_maker = OpenMMFlowMaker( + name="test_production", + tags=["test"], + makers=[ + NVTMaker( + n_steps=1, + state_interval=1, + traj_interval=1, + temperature=5000, + ), + NVTMaker(n_steps=1), + ], + ) + + # Run the ProductionMaker flow + production_flow = production_maker.make(interchange) + # task_doc = run_job(production_flow) + + with revert_default_dtype(): + response_dict = run_locally( + production_flow, store=job_store, ensure_success=True, root_dir=tmp_path + ) + task_doc = list(response_dict.values())[-1][1].output + + # Check the output task document + assert isinstance(task_doc, OpenMMTaskDocument) + assert task_doc.state == "successful" + assert len(task_doc.calcs_reversed) == 2 + assert task_doc.calcs_reversed[-1].task_name == "nvt simulation" + assert task_doc.calcs_reversed[0].task_name == "nvt simulation" + assert task_doc.tags == ["test"] + assert len(task_doc.job_uuids) == 2 + assert task_doc.job_uuids[0] is not None diff --git a/tests/openmm_md/test_mace_utils.py b/tests/openmm_md/test_mace_utils.py new file mode 100644 index 0000000000..fcf8799f04 --- /dev/null +++ b/tests/openmm_md/test_mace_utils.py @@ -0,0 +1,333 @@ +from pathlib import Path + +import numpy as np +import pytest +import torch +from mace.calculators.foundations_models import download_mace_mp_checkpoint +from pymatgen.core import Structure + +from atomate2.openmm.mace_utils import MacePotential, nnpops_nl, wrapping_nl +from atomate2.openmm.utils import structure_to_topology + + +@pytest.mark.openmm_mace +def test_mace_potential(random_structure: Structure): + ff_path = Path(download_mace_mp_checkpoint()) + + potential = MacePotential(model_path=ff_path) + + topology = structure_to_topology(random_structure) + topology.setPeriodicBoxVectors(random_structure.lattice.matrix / 10) + system = potential.create_system(topology) + + assert system.getNumParticles() == len(random_structure) + assert len(system.getForces()) == 1 + + +@pytest.fixture(scope="module") +def large_box() -> tuple[torch.Tensor, torch.Tensor, float, bool]: + """Fixture for a large orthorhombic box and random positions.""" + num_atoms = 50 + cell_lengths = torch.tensor([10.0, 10.0, 10.0]) + cell = torch.diag(cell_lengths) + positions = torch.rand((num_atoms, 3)) * cell_lengths + cutoff = 4.0 # Less than half of smallest box length (5.0) + pbc = True + return positions, cell, cutoff, pbc + + +@pytest.fixture(scope="module") +def small_box() -> tuple[torch.Tensor, torch.Tensor, float, bool]: + """Fixture for a small orthorhombic box and random positions.""" + num_atoms = 50 + cell_lengths = torch.tensor([5.0, 5.0, 5.0]) + cell = torch.diag(cell_lengths) + positions = torch.rand((num_atoms, 3)) * cell_lengths + cutoff = 4.9 # Greater than half of smallest box length (2.5) + pbc = True + return positions, cell, cutoff, pbc + + +@pytest.fixture(scope="module") +def triclinic_cell() -> tuple[torch.Tensor, torch.Tensor, float, bool]: + """Fixture for a triclinic cell and random positions.""" + num_atoms = 50 + a = 5.0 + b = 5.0 + c = 5.0 + alpha = 90 + beta = 90 + gamma = 90 # Non-orthogonal angle + + # Convert cell parameters to cell vectors + alpha_rad = np.radians(alpha) + beta_rad = np.radians(beta) + gamma_rad = np.radians(gamma) + + cell = torch.zeros((3, 3)) + cell[0, 0] = a + cell[1, 0] = b * np.cos(gamma_rad) + cell[1, 1] = b * np.sin(gamma_rad) + cell[2, 0] = c * np.cos(beta_rad) + cell[2, 1] = ( + c + * (np.cos(alpha_rad) - np.cos(beta_rad) * np.cos(gamma_rad)) + / np.sin(gamma_rad) + ) + cell[2, 2] = c * np.sqrt( + 1 + - np.cos(beta_rad) ** 2 + - ( + (np.cos(alpha_rad) - np.cos(beta_rad) * np.cos(gamma_rad)) + / np.sin(gamma_rad) + ) + ** 2 + ) + + positions = torch.rand((num_atoms, 3)) @ cell + cutoff = 3.0 + pbc = True + return positions, cell, cutoff, pbc + + +@pytest.mark.openmm_mace +def test_nl_agreement_in_large_box( + large_box: tuple[torch.Tensor, torch.Tensor, float, bool], +) -> None: + """Test that nnpops_nl and wrapped_nl produce the same results in a large box.""" + positions, cell, cutoff, pbc = large_box + + # Run both functions + neighbors_simple, shifts_simple = nnpops_nl(positions, cell, pbc, cutoff) + neighbors_wrapped, shifts_wrapped = wrapping_nl(positions, cell, pbc, cutoff) + + # convert neighbor lists so they can be easily compared + neighbors_simple_set = {tuple(pair) for pair in neighbors_simple.t().tolist()} + neighbors_wrapped_self_set = { + tuple(pair) for pair in neighbors_wrapped.t().tolist() + } + assert neighbors_simple_set == neighbors_wrapped_self_set + + # convert shift lists so they can be easily compared + shifts_simple_set = {tuple(pair) for pair in shifts_simple.tolist()} + shifts_wrapped_self_set = {tuple(pair) for pair in shifts_wrapped.tolist()} + assert shifts_simple_set == shifts_wrapped_self_set + + +@pytest.mark.openmm_mace +def test_nl_approximately_correct_in_small_box( + small_box: tuple[torch.Tensor, torch.Tensor, float, bool], +) -> None: + """Test that wrapped_nl works in a small box with large cutoff.""" + positions, cell, cutoff, pbc = small_box + + neighbors_wrapped, shifts_wrapped = wrapping_nl(positions, cell, pbc, cutoff) + + # Check that the function runs and returns expected types + assert neighbors_wrapped.shape[0] == 2, "Neighbors should have shape [2, N]" + assert shifts_wrapped.shape[0] == neighbors_wrapped.shape[1], ( + "Shifts should match number of neighbor pairs" + ) + assert 500 < neighbors_wrapped.shape[1] < 50_000, ( + "Shifts should be a reasonable size" + ) + + +@pytest.mark.openmm_mace +def test_nl_approximately_correct_in_triclinic_cell( + triclinic_cell: tuple[torch.Tensor, torch.Tensor, float, bool], +) -> None: + """Test that wrapped_nl works with a triclinic cell.""" + positions, cell, cutoff, pbc = triclinic_cell + + neighbors_wrapped, shifts_wrapped = wrapping_nl(positions, cell, pbc, cutoff) + + # Check that the function runs and returns expected types + assert neighbors_wrapped.shape[0] == 2, "Neighbors should have shape [2, N]" + assert shifts_wrapped.shape[0] == neighbors_wrapped.shape[1], ( + "Shifts should match number of neighbor pairs" + ) + assert 200 < neighbors_wrapped.shape[1] < 20_000, ( + "Shifts should be a reasonable size" + ) + + +@pytest.mark.openmm_mace +def test_exact_pairs_between_four_atoms_on_line() -> None: + """Test wrapped_nl with deterministically placed atoms and known cutoff.""" + + # Define cell parameters + cell_length = 10.0 + cell = torch.diag(torch.tensor([cell_length, cell_length, cell_length])) + pbc = True + cutoff = 2.5 + + # Define atom positions + positions = torch.tensor( + [ + [1.0, 1.0, 1.0], # Atom 0 + [1.0, 1.0, 3.0], # Atom 1 + [1.0, 1.0, 7.0], # Atom 2 + [1.0, 1.0, 9.0], # Atom 3 + ] + ) + + # Expected neighbor pairs (after symmetrization) + expected_pairs = { + (0, 1), + (1, 0), + (0, 3), + (3, 0), + (2, 3), + (3, 2), + } + + # Run wrapped_nl + neighbors_wrapped, shifts_wrapped = wrapping_nl(positions, cell, pbc, cutoff) + + # Extract neighbor pairs + neighbor_pairs = ( + neighbors_wrapped.t().tolist() + ) # Transpose and convert to list of pairs + + # Convert neighbor pairs to set of tuples for comparison + neighbor_pairs_set = {tuple(pair) for pair in neighbor_pairs} + + # Assert that the neighbor pairs match the expected pairs + assert neighbor_pairs_set == expected_pairs + + # Assert that the number of neighbor pairs is as expected + expected_num_pairs = len(expected_pairs) + actual_num_pairs = neighbors_wrapped.shape[1] + assert actual_num_pairs == expected_num_pairs, ( + f"Expected {expected_num_pairs} neighbor pairs, got {actual_num_pairs}." + ) + + +@pytest.mark.openmm_mace +@pytest.mark.parametrize( + "cutoff, n_pairs", + [ + (1.1, 2), + (1.6, 6), + (2.6, 10), + (3.1, 12), + ], +) +def test_n_neighbors_between_three_atoms_on_line(cutoff: float, n_pairs: int) -> None: + """Test wrapped_nl with deterministically placed atoms and known cutoff.""" + # Define cell parameters + cell_length = 4.0 + cell = torch.diag(torch.tensor([100, 100, cell_length])) + pbc = True + + # Define atom positions + positions = torch.tensor( + [ + [1.0, 1.0, 0.5], + [1.0, 1.0, 2.0], + [1.0, 1.0, 3.5], + ] # Atom 0 # Atom 1 # Atom 2 + ) + + # Run wrapped_nl + neighbors_wrapped, shifts_wrapped = wrapping_nl(positions, cell, pbc, cutoff) + + # Assert that the number of neighbor pairs is as expected + expected_num_pairs = n_pairs + actual_num_pairs = neighbors_wrapped.shape[1] + assert actual_num_pairs == expected_num_pairs, ( + f"Expected {expected_num_pairs} neighbor pairs, got {actual_num_pairs}." + ) + + +@pytest.mark.openmm_mace +@pytest.mark.parametrize( + "cutoff, n_pairs", + [ + (1.1, 2), + (2.9, 4), + (3.9, 8), + (4.1, 10), + ], +) +def test_n_neighbors_between_two_atoms_on_line(cutoff: float, n_pairs: int) -> None: + """Test wrapped_nl with deterministically placed atoms and known cutoff.""" + # Define cell parameters + cell_length = 3.0 + cell = torch.diag(torch.tensor([100, 100, cell_length])) + pbc = True + + # Define atom positions + positions = torch.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 2.0]]) # Atom 0 # Atom 1 + + # Run wrapped_nl + neighbors_wrapped, shifts_wrapped = wrapping_nl(positions, cell, pbc, cutoff) + + # Assert that the number of neighbor pairs is as expected + expected_num_pairs = n_pairs + actual_num_pairs = neighbors_wrapped.shape[1] + assert actual_num_pairs == expected_num_pairs, ( + f"Expected {expected_num_pairs} neighbor pairs, got {actual_num_pairs}." + ) + + +@pytest.mark.openmm_mace +@pytest.mark.parametrize( + "cutoff, n_pairs", + [ + (1.1, 2), + (2.1, 4), + (3.1, 6), + ], +) +def test_n_neighbors_with_one_atom_on_line(cutoff: float, n_pairs: int) -> None: + """Test wrapped_nl with deterministically placed atoms and known cutoff.""" + # Define cell parameters + cell_length = 1.0 + cell = torch.diag(torch.tensor([100, 100, cell_length])) + pbc = True + + # Define atom positions + positions = torch.tensor([[1.0, 1.0, 0.5]]) # Atom 0 + + # Run wrapped_nl + neighbors_wrapped, shifts_wrapped = wrapping_nl(positions, cell, pbc, cutoff) + + # Assert that the number of neighbor pairs is as expected + expected_num_pairs = n_pairs + actual_num_pairs = neighbors_wrapped.shape[1] + assert actual_num_pairs == expected_num_pairs, ( + f"Expected {expected_num_pairs} neighbor pairs, got {actual_num_pairs}." + ) + + +@pytest.mark.openmm_mace +@pytest.mark.parametrize( + "cutoff, n_pairs", + [ + (1.1, 4), + (1.42, 8), + (2.01, 12), + (2.4, 20), + ], +) +def test_n_neighbors_with_one_atom_on_grid(cutoff: float, n_pairs: int) -> None: + """Test wrapped_nl with deterministically placed atoms and known cutoff.""" + # Define cell parameters + cell_length = 1.0 + cell = torch.diag(torch.tensor([100, cell_length, cell_length])) + pbc = True + + # Define atom positions + positions = torch.tensor([[1.0, 1.0, 0.5]]) # Atom 0 + + # Run wrapped_nl + neighbors_wrapped, shifts_wrapped = wrapping_nl(positions, cell, pbc, cutoff) + + # Assert that the number of neighbor pairs is as expected + expected_num_pairs = n_pairs + actual_num_pairs = neighbors_wrapped.shape[1] + assert actual_num_pairs == expected_num_pairs, ( + f"Expected {expected_num_pairs} neighbor pairs, got {actual_num_pairs}." + ) diff --git a/tests/openmm_md/test_utils.py b/tests/openmm_md/test_utils.py index 43da835feb..aa38f4bbfd 100644 --- a/tests/openmm_md/test_utils.py +++ b/tests/openmm_md/test_utils.py @@ -1,17 +1,21 @@ from pathlib import Path +import numpy as np import pytest from emmet.core.openmm import OpenMMInterchange +from pymatgen.core import Structure from atomate2.openmm.jobs.base import BaseOpenMMMaker from atomate2.openmm.utils import ( PymatgenTrajectoryReporter, download_opls_xml, increment_name, + interchange_to_structure, + structure_to_topology, ) -@pytest.mark.skip("annoying test") +@pytest.mark.skip("Unreliable test, needs browser to run successfully.") def test_download_xml(tmp_path: Path) -> None: pytest.importorskip("selenium") @@ -92,3 +96,20 @@ def test_trajectory_reporter(interchange: OpenMMInterchange, tmp_path: Path) -> # check that file was written assert (tmp_path / "trajectory.json").exists() + + +@pytest.mark.openmm_slow +def test_structure_to_topology(random_structure: Structure) -> None: + topology = structure_to_topology(random_structure) + assert topology is not None, "Topology should not be None." + num_atoms_in_topology = sum(1 for _ in topology.atoms()) + assert num_atoms_in_topology == len(random_structure), ( + "Number of atoms in topology should match structure." + ) + + +@pytest.mark.openmm_slow +def test_interchange_to_structure(interchange: OpenMMInterchange) -> None: + structure = interchange_to_structure(interchange) + assert len(structure) == 1170 + assert 4 < np.max(structure.cart_coords) < 16 diff --git a/tests/test_data/openmm/mlff_test_files/random_structure.json b/tests/test_data/openmm/mlff_test_files/random_structure.json new file mode 100644 index 0000000000..b8852df7fb --- /dev/null +++ b/tests/test_data/openmm/mlff_test_files/random_structure.json @@ -0,0 +1 @@ +{"@module": "pymatgen.core.structure", "@class": "Structure", "charge": 0, "lattice": {"matrix": [[8.945722384641158, 0.0, 0.0], [0.0, 8.945722384641158, 0.0], [0.0, 0.0, 8.945722384641158]], "pbc": [true, true, true], "a": 8.945722384641158, "b": 8.945722384641158, "c": 8.945722384641158, "alpha": 90.0, "beta": 90.0, "gamma": 90.0, "volume": 715.8899231699995}, "properties": {}, "sites": [{"species": [{"element": "Al", "occu": 1}], "abc": [0.64659998950236, 0.38190599407216497, 0.8952527985610246], "properties": {}, "label": "Al", "xyz": [5.784304, 3.4164249999999994, 8.008683]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.8834840452425946, 0.31815720157899446, 0.8839606976376343], "properties": {}, "label": "Al", "xyz": [7.903403, 2.846146, 7.907666999999999]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.3468634355708847, 0.757030534686289, 0.817740668161081], "properties": {}, "label": "Al", "xyz": [3.1029439999999995, 6.772185, 7.315281]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.11223308267689708, 0.8869685039212494, 0.1123682310693931], "properties": {}, "label": "Al", "xyz": [1.004006, 7.934574, 1.005215]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.16107788035921722, 0.3912898086363331, 0.6561033025211006], "properties": {}, "label": "Al", "xyz": [1.4409579999999997, 3.50037, 5.869317999999999]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.68638146099213, 0.5142043093017935, 0.491245626797559], "properties": {}, "label": "Al", "xyz": [6.140178, 4.5999289999999995, 4.394547]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.5147756438212701, 0.11389152895569865, 0.3368813462369576], "properties": {}, "label": "Al", "xyz": [4.60504, 1.018842, 3.0136469999999997]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.5836308992735897, 0.3558866308511883, 0.1110339620761486], "properties": {}, "label": "Al", "xyz": [5.221, 3.183663, 0.9932789999999999]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.7667138219911492, 0.8885189656283793, 0.19472759438531093], "properties": {}, "label": "Al", "xyz": [6.858809, 7.948444, 1.7419789999999997]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.4653749379868551, 0.11650641001216443, 0.675629730068182], "properties": {}, "label": "Al", "xyz": [4.163115, 1.042234, 6.043995999999999]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.8434003063803603, 0.6355647711314559, 0.3462858410762388], "properties": {}, "label": "Al", "xyz": [7.544825, 5.685586, 3.097777]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.10461547539266058, 0.268665167178269, 0.45120213063284215], "properties": {}, "label": "Al", "xyz": [0.935861, 2.4034039999999997, 4.036329]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.10656545765792173, 0.48920040348150656, 0.34230941542043336], "properties": {}, "label": "Al", "xyz": [0.953305, 4.376251, 3.062205]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.16209926237926373, 0.532217276066407, 0.10737713050979408], "properties": {}, "label": "Al", "xyz": [1.450095, 4.761068, 0.9605659999999999]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.10987653738143902, 0.6517177427738681, 0.531342779892299], "properties": {}, "label": "Al", "xyz": [0.9829249999999999, 5.830086, 4.753245]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.6220034292091662, 0.11033888126180591, 0.8866595294325322], "properties": {}, "label": "Al", "xyz": [5.56427, 0.987061, 7.93181]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.837565227025604, 0.33900448388681453, 0.11178527088174481], "properties": {}, "label": "Al", "xyz": [7.492625999999999, 3.0326399999999993, 1.0]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.12504557506955002, 0.1168739599828236, 0.6426506158821079], "properties": {}, "label": "Al", "xyz": [1.118623, 1.045522, 5.748974]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.6789277309150062, 0.2765485998366623, 0.41951491882234837], "properties": {}, "label": "Al", "xyz": [6.073499, 2.473927, 3.752864]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.8902798072153085, 0.6174102842139062, 0.1067077603077542], "properties": {}, "label": "Al", "xyz": [7.964196, 5.523181, 0.9545779999999999]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.5625750256503042, 0.780675019827383, 0.11049593956739476], "properties": {}, "label": "Al", "xyz": [5.03264, 6.983702, 0.988466]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.5418058812468304, 0.5988211761631681, 0.8458476213157743], "properties": {}, "label": "Al", "xyz": [4.846844999999999, 5.356887999999999, 7.566718]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.6963827773926489, 0.10453331321856249, 0.5926859533560945], "properties": {}, "label": "Al", "xyz": [6.229647, 0.9351259999999999, 5.302004]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.8508692392543241, 0.7985831320079093, 0.8003875698504825], "properties": {}, "label": "Al", "xyz": [7.61164, 7.143902999999999, 7.160045]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.419600322769302, 0.24957928538373256, 0.8982807261933984], "properties": {}, "label": "Al", "xyz": [3.7536279999999995, 2.232667, 8.03577]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.3342419842061687, 0.18765795849893677, 0.4822187426033163], "properties": {}, "label": "Al", "xyz": [2.9900359999999995, 1.678736, 4.313795]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.7144616974670827, 0.8272436458240091, 0.5737339903160743], "properties": {}, "label": "Al", "xyz": [6.391376, 7.400292, 5.132464999999999]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.6282204788345254, 0.852304226776633, 0.8892708333603299], "properties": {}, "label": "Al", "xyz": [5.619886, 7.624477, 7.95517]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.8881081547578888, 0.40276344883963533, 0.4077231377381166], "properties": {}, "label": "Al", "xyz": [7.944769, 3.6030099999999994, 3.647378]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.3606916089347692, 0.1239122959933509, 0.10891205406427132], "properties": {}, "label": "Al", "xyz": [3.2266469999999994, 1.108485, 0.974297]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.13044402115624212, 0.5731628793671395, 0.8127437547673961], "properties": {}, "label": "Al", "xyz": [1.1669159999999998, 5.127356, 7.27058]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.8207171745688493, 0.11312177558040695, 0.22061840454477363], "properties": {}, "label": "Al", "xyz": [7.341908, 1.011956, 1.973591]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.3506572040940694, 0.8885147177880858, 0.17155763771910992], "properties": {}, "label": "Al", "xyz": [3.136882, 7.948406, 1.5347069999999998]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.5031197936264306, 0.8860720978340487, 0.685552461423271], "properties": {}, "label": "Al", "xyz": [4.500769999999999, 7.926555, 6.132761999999999]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.38536317714472473, 0.5908122086455747, 0.19223835997331626], "properties": {}, "label": "Al", "xyz": [3.447352, 5.285242, 1.719711]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.6037762818655432, 0.11123148664979664, 0.1103842660817839], "properties": {}, "label": "Al", "xyz": [5.401215, 0.9950459999999999, 0.9874669999999999]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.5199556614886593, 0.6769940692992937, 0.5641595818650529], "properties": {}, "label": "Al", "xyz": [4.651379, 6.056201, 5.046815]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.4035626017411689, 0.3583061112541528, 0.6792217261774252], "properties": {}, "label": "Al", "xyz": [3.6101589999999995, 3.205307, 6.076128999999999]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.12481473848517921, 0.8587373573306065, 0.8211087583727481], "properties": {}, "label": "Al", "xyz": [1.116558, 7.682026, 7.345411]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.14827310115025513, 0.3429417846978114, 0.8948110231705], "properties": {}, "label": "Al", "xyz": [1.3264099999999999, 3.067862, 8.004731]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.6645119023820963, 0.5724788652946141, 0.19265699581276838], "properties": {}, "label": "Al", "xyz": [5.944538999999999, 5.121237, 1.7234559999999999]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.8900891015431842, 0.15954295680474, 0.4518175085490461], "properties": {}, "label": "Al", "xyz": [7.96249, 1.427227, 4.041834]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.88777995320258, 0.868444119542352, 0.4076669097468631], "properties": {}, "label": "Al", "xyz": [7.941832999999999, 7.76886, 3.6468749999999996]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.6454857139222108, 0.33587248416724985, 0.6551271935357612], "properties": {}, "label": "Al", "xyz": [5.774336, 3.004622, 5.860586]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.6042424264451202, 0.7678240733015467, 0.3522360592600032], "properties": {}, "label": "Al", "xyz": [5.405385, 6.868740999999999, 3.1510059999999998]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.21920409729757778, 0.10988827483488159, 0.8689416757830466], "properties": {}, "label": "Al", "xyz": [1.960939, 0.9830299999999998, 7.773311]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.8854965154742787, 0.6515903075650629, 0.5878952837871864], "properties": {}, "label": "Al", "xyz": [7.921406, 5.828945999999999, 5.259147999999999]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.17644310119772572, 0.8866629947759297, 0.5117170870246747], "properties": {}, "label": "Al", "xyz": [1.578411, 7.931841, 4.577679]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.8902248088620346, 0.3336060378001225, 0.6406499948891374], "properties": {}, "label": "Al", "xyz": [7.963703999999999, 2.984347, 5.731077]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.3383039255941987, 0.3644458054720618, 0.11053484084166161], "properties": {}, "label": "Al", "xyz": [3.026373, 3.260231, 0.988814]}, {"species": [{"element": "Al", "occu": 1}], "abc": [0.109196882934478, 0.2875773346618552, 0.15432973891189894], "properties": {}, "label": "Al", "xyz": [0.976845, 2.5725869999999995, 1.380591]}, {"species": [{"element": "Fe", "occu": 1}], "abc": [0.5150752283472332, 0.43701378512617667, 0.33287015536190795], "properties": {}, "label": "Fe", "xyz": [4.60772, 3.909404, 2.977764]}, {"species": [{"element": "Fe", "occu": 1}], "abc": [0.3231218090293945, 0.6687096628742476, 0.4165378534882257], "properties": {}, "label": "Fe", "xyz": [2.890558, 5.982090999999999, 3.7262319999999995]}, {"species": [{"element": "Fe", "occu": 1}], "abc": [0.8868078684869924, 0.11344483501325518, 0.7480456817538979], "properties": {}, "label": "Fe", "xyz": [7.933137, 1.014846, 6.691809]}, {"species": [{"element": "Ni", "occu": 1}], "abc": [0.4182014418894879, 0.8897721903002345, 0.4561953551125879], "properties": {}, "label": "Ni", "xyz": [3.741114, 7.959655, 4.080997]}, {"species": [{"element": "Ni", "occu": 1}], "abc": [0.3127311445303946, 0.5842903205865211, 0.6508213366866673], "properties": {}, "label": "Ni", "xyz": [2.797606, 5.226899, 5.822067]}, {"species": [{"element": "Ni", "occu": 1}], "abc": [0.11719713120094273, 0.7311663294212961, 0.3001219895454755], "properties": {}, "label": "Ni", "xyz": [1.048413, 6.540811, 2.6848079999999994]}, {"species": [{"element": "Ni", "occu": 1}], "abc": [0.19971142890230265, 0.10653147493557369, 0.29398352496554625], "properties": {}, "label": "Ni", "xyz": [1.786563, 0.953001, 2.629895]}, {"species": [{"element": "Ni", "occu": 1}], "abc": [0.77572270875678, 0.5601726483937846, 0.7856451047560562], "properties": {}, "label": "Ni", "xyz": [6.9394, 5.011149, 7.028162999999999]}, {"species": [{"element": "Ni", "occu": 1}], "abc": [0.30994746771489734, 0.42839927679621675, 0.4650978223003392], "properties": {}, "label": "Ni", "xyz": [2.7727039999999996, 3.8323409999999996, 4.160636]}]} diff --git a/tests/test_data/openmm/mlff_test_files/taskdoc.json.gz b/tests/test_data/openmm/mlff_test_files/taskdoc.json.gz new file mode 100644 index 0000000000..115f3289a9 Binary files /dev/null and b/tests/test_data/openmm/mlff_test_files/taskdoc.json.gz differ