From 5a823db65c7dfbea072f36b0ebeaf897ca0f5a1b Mon Sep 17 00:00:00 2001 From: lbluque Date: Mon, 29 Sep 2025 17:24:26 -0700 Subject: [PATCH 01/16] generalize hf reference download --- .../core/calculate/pretrained_mlip.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/fairchem/core/calculate/pretrained_mlip.py b/src/fairchem/core/calculate/pretrained_mlip.py index 89a1eef9aa..f19b8d0ff2 100644 --- a/src/fairchem/core/calculate/pretrained_mlip.py +++ b/src/fairchem/core/calculate/pretrained_mlip.py @@ -30,6 +30,7 @@ class HuggingFaceCheckpoint: subfolder: str | None = None # specify a hf repo subfolder revision: str | None = None # specify a version tag, branch, commit hash atom_refs: dict | None = None # specify an isolated atomic reference + mp_elemental_refs: dict | None = None # mp unary compount elemental reference @dataclass @@ -107,31 +108,37 @@ def get_predict_unit( revision=model_checkpoint.revision, cache_dir=cache_dir, ) - atom_refs = get_isolated_atomic_energies(model_name, cache_dir) + atom_refs = get_reference_energies(model_name, "atom_refs", cache_dir) return load_predict_unit( checkpoint_path, inference_settings, overrides, device, atom_refs ) -def get_isolated_atomic_energies(model_name: str, cache_dir: str = CACHE_DIR) -> dict: +def get_reference_energies( + model_name: str, + reference_type: Literal["atom_refs", "bulk_refs"], + cache_dir: str = CACHE_DIR, +) -> dict: """ Retrieves the isolated atomic energies for use with single atom systems into the CACHE_DIR Args: model_name: Name of the model to load from available pretrained models. + reference_type: Type of references file to download: atom_refs or bulk_refs. cache_dir: Path to folder where files will be stored. Default is "~/.cache/fairchem" Returns: - Atomic element reference data + Atomic or bulk phase element reference data Raises: KeyError: If the specified model_name is not found in available models. """ model_checkpoint = _MODEL_CKPTS.checkpoints[model_name] - atomic_refs_path = hf_hub_download( - filename=model_checkpoint.atom_refs["filename"], + file_data = getattr(model_checkpoint, reference_type) + refs_path = hf_hub_download( + filename=file_data["filename"], repo_id=model_checkpoint.repo_id, - subfolder=model_checkpoint.atom_refs["subfolder"], + subfolder=file_data["subfolder"], revision=model_checkpoint.revision, cache_dir=cache_dir, ) - return OmegaConf.load(atomic_refs_path) + return OmegaConf.load(refs_path) From d5a76ef7fcc7a064be6b22475af386a9bceb775b Mon Sep 17 00:00:00 2001 From: lbluque Date: Tue, 30 Sep 2025 19:15:34 +0000 Subject: [PATCH 02/16] initial implementation of formation energy calc --- src/fairchem/core/calculate/ase_calculator.py | 62 +++++++++++++++++++ .../core/calculate/pretrained_mlip.py | 2 +- 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/src/fairchem/core/calculate/ase_calculator.py b/src/fairchem/core/calculate/ase_calculator.py index f3c59899b5..ffdb78f2ab 100644 --- a/src/fairchem/core/calculate/ase_calculator.py +++ b/src/fairchem/core/calculate/ase_calculator.py @@ -10,6 +10,7 @@ import logging import os from functools import partial +from collections import Counter from typing import TYPE_CHECKING, Literal import numpy as np @@ -316,6 +317,67 @@ def _validate_charge_and_spin(self, atoms: Atoms) -> None: ) +class FormationEnergyCalculator(FAIRChemCalculator): + def __init__( + self, + predict_unit: MLIPPredictUnit, + task_name: UMATask | str | None = None, + seed: int | None = None, # deprecated + element_references: dict | None = None + ): + """ + Initialize the FormationEnergyCalculator. + + Args: + predict_unit (MLIPPredictUnit): A pretrained MLIPPredictUnit. + task_name (UMATask or str, optional): Name of the task to use if using a UMA checkpoint. + seed (int, optional): Deprecated. Random seed for reproducibility. + element_references (dict): Dictionary mapping element symbols to their reference bulk phase energies. + Using the default will take the appropriate values for the dataset corresponding to the task given, this + is likely what you want to use, always. + """ + super().__init__(predict_unit=predict_unit, task_name=task_name, seed=seed) + + if element_references is None: + # get them from HF + element_references = {} + + self._element_refs = element_references + + def calculate( + self, atoms: Atoms, properties: list[str], system_changes: list[str] + ) -> None: + """ + Perform the calculation for the given atomic structure and convert total energy to formation energy. + + Args: + atoms (Atoms): The atomic structure to calculate properties for. + properties (list[str]): The list of properties to calculate. + system_changes (list[str]): The list of changes in the system. + """ + # First get the total energy from the parent calculator + super().calculate(atoms, properties, system_changes) + + # If energy was calculated, convert it to formation energy + total_energy = self.results["energy"] + + atomic_numbers = atoms.get_atomic_numbers() + element_symbols = atoms.get_chemical_symbols() + element_counts = Counter(element_symbols) + + missing_elements = set(element_symbols) - set(self._element_refs.keys()) + if missing_elements: + raise ValueError(f"Missing reference energies for elements: {missing_elements}") + + total_ref_energy = sum( + self._element_refs[element] * count + for element, count in element_counts.items() + ) + + formation_energy = (total_energy - total_ref_energy) + self.results["energy"] = formation_energy + + class MixedPBCError(ValueError): """Specific exception example.""" diff --git a/src/fairchem/core/calculate/pretrained_mlip.py b/src/fairchem/core/calculate/pretrained_mlip.py index f19b8d0ff2..4232026c66 100644 --- a/src/fairchem/core/calculate/pretrained_mlip.py +++ b/src/fairchem/core/calculate/pretrained_mlip.py @@ -116,7 +116,7 @@ def get_predict_unit( def get_reference_energies( model_name: str, - reference_type: Literal["atom_refs", "bulk_refs"], + reference_type: Literal["atom_refs"] = "atom_refs", cache_dir: str = CACHE_DIR, ) -> dict: """ From 17e087118c171f7d00ef42b88259d551d9ca8dde Mon Sep 17 00:00:00 2001 From: lbluque Date: Thu, 30 Oct 2025 21:29:00 +0000 Subject: [PATCH 03/16] use runtime re-binding instead of a new class --- src/fairchem/core/calculate/ase_calculator.py | 132 ++++++++++-------- 1 file changed, 77 insertions(+), 55 deletions(-) diff --git a/src/fairchem/core/calculate/ase_calculator.py b/src/fairchem/core/calculate/ase_calculator.py index ffdb78f2ab..1857feedbf 100644 --- a/src/fairchem/core/calculate/ase_calculator.py +++ b/src/fairchem/core/calculate/ase_calculator.py @@ -9,13 +9,14 @@ import logging import os -from functools import partial from collections import Counter +from functools import partial from typing import TYPE_CHECKING, Literal import numpy as np from ase.calculators.calculator import Calculator from ase.stress import full_3x3_to_voigt_6_stress +from monty.dev import requires from fairchem.core.calculate import pretrained_mlip from fairchem.core.datasets import data_list_collater @@ -30,6 +31,12 @@ UMATask, ) +try: + from fairchem.data.omat import data_omat_installed +except ImportError: + data_omat_installed = False + + if TYPE_CHECKING: from ase import Atoms @@ -317,65 +324,80 @@ def _validate_charge_and_spin(self, atoms: Atoms) -> None: ) -class FormationEnergyCalculator(FAIRChemCalculator): - def __init__( - self, - predict_unit: MLIPPredictUnit, - task_name: UMATask | str | None = None, - seed: int | None = None, # deprecated - element_references: dict | None = None - ): - """ - Initialize the FormationEnergyCalculator. +@requires( + data_omat_installed, + message="Formation energy functionality requires fairchem.data.omat to be installed.", +) +def _apply_mp_style_corrections(formation_energy: float, atoms: Atoms) -> float: + pass + + +def enable_formation_energy( + calculator: FAIRChemCalculator, + element_references: dict | None = None, + apply_corrections: bool | None = None, +) -> FAIRChemCalculator: + """ + Helper function to easily enable formation energy calculation on a FAIRChemCalculator instance. + + Args: + calculator (FAIRChemCalculator): The calculator instance to modify. + element_references (dict, optional): Optional dictionary of formation reference energies for each element. You likely do not want + to provide these and instead use the defaults for each UMA task. + apply_corrections (bool, optional): Whether to apply MP style corrections to the formation energies. + This is only relevant for the OMat task. Default is True if task is OMat. + + Returns: + FAIRChemCalculator: The same calculator instance but will return formation energies as the potential energy. + """ + if element_references is None: + # get these + element_references = {} + + if apply_corrections is True and calculator.task_name != UMATask.OMAT.value: + raise ValueError("MP style corrections can only be applied for the OMat task.") + + if apply_corrections is None and calculator.task_name == UMATask.OMAT.value: + apply_corrections = True + + original_calculate = calculator.calculate + + def formation_energy_calculate( + atoms: Atoms, properties: list[str], system_changes: list[str] + ) -> None: + original_calculate(atoms, properties, system_changes) - Args: - predict_unit (MLIPPredictUnit): A pretrained MLIPPredictUnit. - task_name (UMATask or str, optional): Name of the task to use if using a UMA checkpoint. - seed (int, optional): Deprecated. Random seed for reproducibility. - element_references (dict): Dictionary mapping element symbols to their reference bulk phase energies. - Using the default will take the appropriate values for the dataset corresponding to the task given, this - is likely what you want to use, always. - """ - super().__init__(predict_unit=predict_unit, task_name=task_name, seed=seed) + if "energy" in calculator.results: + total_energy = calculator.results["energy"] - if element_references is None: - # get them from HF - element_references = {} + element_symbols = atoms.get_chemical_symbols() + element_counts = Counter(element_symbols) - self._element_refs = element_references + missing_elements = set(element_symbols) - set(element_references.keys()) + if missing_elements: + raise ValueError( + f"Missing reference energies for elements: {missing_elements}" + ) - def calculate( - self, atoms: Atoms, properties: list[str], system_changes: list[str] - ) -> None: - """ - Perform the calculation for the given atomic structure and convert total energy to formation energy. + total_ref_energy = sum( + element_references[element] * count + for element, count in element_counts.items() + ) - Args: - atoms (Atoms): The atomic structure to calculate properties for. - properties (list[str]): The list of properties to calculate. - system_changes (list[str]): The list of changes in the system. - """ - # First get the total energy from the parent calculator - super().calculate(atoms, properties, system_changes) - - # If energy was calculated, convert it to formation energy - total_energy = self.results["energy"] - - atomic_numbers = atoms.get_atomic_numbers() - element_symbols = atoms.get_chemical_symbols() - element_counts = Counter(element_symbols) - - missing_elements = set(element_symbols) - set(self._element_refs.keys()) - if missing_elements: - raise ValueError(f"Missing reference energies for elements: {missing_elements}") - - total_ref_energy = sum( - self._element_refs[element] * count - for element, count in element_counts.items() - ) - - formation_energy = (total_energy - total_ref_energy) - self.results["energy"] = formation_energy + formation_energy = total_energy - total_ref_energy + + if apply_corrections: + formation_energy = _apply_mp_style_corrections(formation_energy, atoms) + + calculator.results["energy"] = formation_energy + + if "free_energy" in calculator.results: + calculator.results["free_energy"] = formation_energy + + # Replace the calculate method + calculator.calculate = formation_energy_calculate + + return calculator class MixedPBCError(ValueError): From b747b90adaeb3e5bb3f44c883b5559e0286fa026 Mon Sep 17 00:00:00 2001 From: lbluque Date: Fri, 31 Oct 2025 00:11:42 +0000 Subject: [PATCH 04/16] add corrections for omat --- src/fairchem/core/calculate/ase_calculator.py | 28 +++++++------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/src/fairchem/core/calculate/ase_calculator.py b/src/fairchem/core/calculate/ase_calculator.py index 1857feedbf..a394c75c5c 100644 --- a/src/fairchem/core/calculate/ase_calculator.py +++ b/src/fairchem/core/calculate/ase_calculator.py @@ -16,7 +16,6 @@ import numpy as np from ase.calculators.calculator import Calculator from ase.stress import full_3x3_to_voigt_6_stress -from monty.dev import requires from fairchem.core.calculate import pretrained_mlip from fairchem.core.datasets import data_list_collater @@ -31,12 +30,6 @@ UMATask, ) -try: - from fairchem.data.omat import data_omat_installed -except ImportError: - data_omat_installed = False - - if TYPE_CHECKING: from ase import Atoms @@ -324,14 +317,6 @@ def _validate_charge_and_spin(self, atoms: Atoms) -> None: ) -@requires( - data_omat_installed, - message="Formation energy functionality requires fairchem.data.omat to be installed.", -) -def _apply_mp_style_corrections(formation_energy: float, atoms: Atoms) -> float: - pass - - def enable_formation_energy( calculator: FAIRChemCalculator, element_references: dict | None = None, @@ -351,8 +336,7 @@ def enable_formation_energy( FAIRChemCalculator: The same calculator instance but will return formation energies as the potential energy. """ if element_references is None: - # get these - element_references = {} + element_references = calculator.predictor.form_elem_refs[calculator.task_name] if apply_corrections is True and calculator.task_name != UMATask.OMAT.value: raise ValueError("MP style corrections can only be applied for the OMat task.") @@ -387,7 +371,15 @@ def formation_energy_calculate( formation_energy = total_energy - total_ref_energy if apply_corrections: - formation_energy = _apply_mp_style_corrections(formation_energy, atoms) + try: + from fairchem.data.omat.entries.compatibility import ( + apply_mp_style_corrections, + ) + except ImportError as err: + raise ImportError( + "fairchem.data.omat is required to apply MP style corrections. Please install it." + ) from err + formation_energy = apply_mp_style_corrections(formation_energy, atoms) calculator.results["energy"] = formation_energy From adbbf55fc8f55499cf61b5fac9d3da829a168b85 Mon Sep 17 00:00:00 2001 From: lbluque Date: Fri, 31 Oct 2025 00:12:18 +0000 Subject: [PATCH 05/16] add formation element reference loading --- src/fairchem/core/calculate/pretrained_mlip.py | 13 +++++++++++-- src/fairchem/core/calculate/pretrained_models.json | 12 ++++++++++++ src/fairchem/core/units/mlip_unit/__init__.py | 3 +++ src/fairchem/core/units/mlip_unit/predict.py | 5 +++++ 4 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/fairchem/core/calculate/pretrained_mlip.py b/src/fairchem/core/calculate/pretrained_mlip.py index 4232026c66..5a58cb0b5a 100644 --- a/src/fairchem/core/calculate/pretrained_mlip.py +++ b/src/fairchem/core/calculate/pretrained_mlip.py @@ -109,14 +109,23 @@ def get_predict_unit( cache_dir=cache_dir, ) atom_refs = get_reference_energies(model_name, "atom_refs", cache_dir) + form_elem_refs = get_reference_energies(model_name, "form_elem_refs", cache_dir)[ + "refs" + ] + return load_predict_unit( - checkpoint_path, inference_settings, overrides, device, atom_refs + checkpoint_path, + inference_settings, + overrides, + device, + atom_refs, + form_elem_refs, ) def get_reference_energies( model_name: str, - reference_type: Literal["atom_refs"] = "atom_refs", + reference_type: Literal["atom_refs", "form_elem_refs"] = "atom_refs", cache_dir: str = CACHE_DIR, ) -> dict: """ diff --git a/src/fairchem/core/calculate/pretrained_models.json b/src/fairchem/core/calculate/pretrained_models.json index 29cfdd2e1e..a9793b6ea6 100644 --- a/src/fairchem/core/calculate/pretrained_models.json +++ b/src/fairchem/core/calculate/pretrained_models.json @@ -6,6 +6,10 @@ "atom_refs": { "subfolder": "references", "filename": "iso_atom_elem_refs.yaml" + }, + "form_elem_refs": { + "subfolder": "references", + "filename": "form_elem_refs.yaml" } }, "uma-s-1p1": { @@ -15,6 +19,10 @@ "atom_refs": { "subfolder": "references", "filename": "iso_atom_elem_refs.yaml" + }, + "form_elem_refs": { + "subfolder": "references", + "filename": "form_elem_refs.yaml" } }, "uma-m-1p1": { @@ -24,6 +32,10 @@ "atom_refs": { "subfolder": "references", "filename": "iso_atom_elem_refs.yaml" + }, + "form_elem_refs": { + "subfolder": "references", + "filename": "form_elem_refs.yaml" } }, "esen-md-direct-all-omol": { diff --git a/src/fairchem/core/units/mlip_unit/__init__.py b/src/fairchem/core/units/mlip_unit/__init__.py index d406bec7e5..17f48fa728 100644 --- a/src/fairchem/core/units/mlip_unit/__init__.py +++ b/src/fairchem/core/units/mlip_unit/__init__.py @@ -28,6 +28,7 @@ def load_predict_unit( overrides: dict | None = None, device: Literal["cuda", "cpu"] | None = None, atom_refs: dict | None = None, + form_elem_refs: dict | None = None, ) -> MLIPPredictUnit: """Load a MLIPPredictUnit from a checkpoint file. @@ -39,6 +40,7 @@ def load_predict_unit( overrides: Optional dictionary of settings to override default inference settings. device: Optional torch device to load the model onto. atom_refs: Optional dictionary of isolated atom reference energies. + form_elem_refs: Optional dictionary of element reference energies for formation energy calculations. Returns: A MLIPPredictUnit instance ready for inference @@ -57,4 +59,5 @@ def load_predict_unit( inference_settings=inference_settings, overrides=overrides, atom_refs=atom_refs, + form_elem_refs=form_elem_refs, ) diff --git a/src/fairchem/core/units/mlip_unit/predict.py b/src/fairchem/core/units/mlip_unit/predict.py index 77083b8d4a..3abcb6a06b 100644 --- a/src/fairchem/core/units/mlip_unit/predict.py +++ b/src/fairchem/core/units/mlip_unit/predict.py @@ -91,6 +91,7 @@ def __init__( inference_settings: InferenceSettings | None = None, seed: int = 41, atom_refs: dict | None = None, + form_elem_refs: dict | None = None, assert_on_nans: bool = False, ): super().__init__() @@ -103,6 +104,7 @@ def __init__( if atom_refs is not None else {} ) + self.form_elem_refs = form_elem_refs if form_elem_refs is not None else {} if inference_settings is None: inference_settings = InferenceSettings() @@ -329,6 +331,7 @@ def __init__( inference_settings: InferenceSettings | None = None, seed: int = 41, atom_refs: dict | None = None, + form_elem_refs: dict | None = None, assert_on_nans: bool = False, server_config: dict | None = None, client_config: dict | None = None, @@ -359,6 +362,7 @@ def __init__( inference_settings=inference_settings, seed=seed, atom_refs=atom_refs, + form_elem_refs=form_elem_refs, ) self._dataset_to_tasks = copy.deepcopy(_mlip_pred_unit.dataset_to_tasks) @@ -378,6 +382,7 @@ def __init__( "inference_settings": inference_settings, "seed": seed, "atom_refs": atom_refs, + "form_elem_refs": form_elem_refs, "assert_on_nans": assert_on_nans, } From 4ca5f6618c8aaffab74d7629d0c0a7531243fd75 Mon Sep 17 00:00:00 2001 From: lbluque Date: Fri, 31 Oct 2025 21:54:13 +0000 Subject: [PATCH 06/16] fix HFCheckpoint attrs --- src/fairchem/core/calculate/pretrained_mlip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fairchem/core/calculate/pretrained_mlip.py b/src/fairchem/core/calculate/pretrained_mlip.py index 5a58cb0b5a..ba33a2fd00 100644 --- a/src/fairchem/core/calculate/pretrained_mlip.py +++ b/src/fairchem/core/calculate/pretrained_mlip.py @@ -30,7 +30,7 @@ class HuggingFaceCheckpoint: subfolder: str | None = None # specify a hf repo subfolder revision: str | None = None # specify a version tag, branch, commit hash atom_refs: dict | None = None # specify an isolated atomic reference - mp_elemental_refs: dict | None = None # mp unary compount elemental reference + form_elem_refs: dict | None = None # specify a form elemental reference @dataclass From 61a0b934391d5c632f1e2322421c54df9bc13a92 Mon Sep 17 00:00:00 2001 From: lbluque Date: Fri, 31 Oct 2025 22:23:45 +0000 Subject: [PATCH 07/16] set_predict_formation_energy function --- src/fairchem/core/calculate/ase_calculator.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/fairchem/core/calculate/ase_calculator.py b/src/fairchem/core/calculate/ase_calculator.py index a394c75c5c..c52cfe5d6c 100644 --- a/src/fairchem/core/calculate/ase_calculator.py +++ b/src/fairchem/core/calculate/ase_calculator.py @@ -317,13 +317,13 @@ def _validate_charge_and_spin(self, atoms: Atoms) -> None: ) -def enable_formation_energy( +def set_predict_formation_energy( calculator: FAIRChemCalculator, element_references: dict | None = None, apply_corrections: bool | None = None, ) -> FAIRChemCalculator: """ - Helper function to easily enable formation energy calculation on a FAIRChemCalculator instance. + Adapt a calculator to predict formation energy. Args: calculator (FAIRChemCalculator): The calculator instance to modify. @@ -340,7 +340,6 @@ def enable_formation_energy( if apply_corrections is True and calculator.task_name != UMATask.OMAT.value: raise ValueError("MP style corrections can only be applied for the OMat task.") - if apply_corrections is None and calculator.task_name == UMATask.OMAT.value: apply_corrections = True From 612a62f5ad8c733f34f2f0198f7a4175e42203e6 Mon Sep 17 00:00:00 2001 From: lbluque Date: Tue, 4 Nov 2025 21:47:40 +0000 Subject: [PATCH 08/16] tests --- .../core/calculate/pretrained_mlip.py | 11 +- tests/core/calculate/test_ase_calculator.py | 160 ++++++++++++++++-- 2 files changed, 153 insertions(+), 18 deletions(-) diff --git a/src/fairchem/core/calculate/pretrained_mlip.py b/src/fairchem/core/calculate/pretrained_mlip.py index a3a86887c8..bb319e3dac 100644 --- a/src/fairchem/core/calculate/pretrained_mlip.py +++ b/src/fairchem/core/calculate/pretrained_mlip.py @@ -96,9 +96,14 @@ def get_predict_unit( """ checkpoint_path = pretrained_checkpoint_path_from_name(model_name) atom_refs = get_reference_energies(model_name, "atom_refs", cache_dir) - form_elem_refs = get_reference_energies(model_name, "form_elem_refs", cache_dir)[ - "refs" - ] + + if _MODEL_CKPTS.checkpoints[model_name].form_elem_refs is not None: + form_elem_refs = get_reference_energies( + model_name, "form_elem_refs", cache_dir + )["refs"] + else: + form_elem_refs = None + return load_predict_unit( checkpoint_path, inference_settings, diff --git a/tests/core/calculate/test_ase_calculator.py b/tests/core/calculate/test_ase_calculator.py index 26b2c18000..352ffa0bea 100644 --- a/tests/core/calculate/test_ase_calculator.py +++ b/tests/core/calculate/test_ase_calculator.py @@ -11,12 +11,12 @@ import tempfile from typing import TYPE_CHECKING +import ase.io import numpy as np import pytest import torch from ase import Atoms, units from ase.build import add_adsorbate, bulk, fcc111, molecule -from ase.io import read, write from ase.md.langevin import Langevin from ase.optimize import BFGS @@ -24,6 +24,7 @@ from fairchem.core.calculate.ase_calculator import ( AllZeroUnitCellError, MixedPBCError, + set_predict_formation_energy, ) from fairchem.core.units.mlip_unit.api.inference import InferenceSettings, UMATask @@ -32,6 +33,9 @@ from fairchem.core.calculate import pretrained_mlip +# mark all tests in this module as gpu tests +pytestmark = pytest.mark.gpu + @pytest.fixture(scope="module", params=pretrained_mlip.available_models) def mlip_predict_unit(request) -> MLIPPredictUnit: @@ -77,7 +81,15 @@ def bulk_atoms() -> Atoms: @pytest.fixture() def aperiodic_atoms() -> Atoms: - return molecule("H2O") + atoms = molecule("H2O") + atoms.info["charge"] = 0 + atoms.info["spin"] = 1 + return atoms + + +@pytest.fixture() +def custom_element_refs() -> dict: + return {"H": -0.5, "O": -1.0, "Fe": -2.0} @pytest.fixture() @@ -96,8 +108,8 @@ def periodic_h2o_from_extxyz(periodic_h2o_atoms) -> Atoms: periodic_h2o_atoms.info["charge"] = 0 # set as int here periodic_h2o_atoms.info["spin"] = 0 with tempfile.NamedTemporaryFile(suffix=".xyz") as f: - write(f.name, periodic_h2o_atoms, format="extxyz") - atoms = read(f.name, format="extxyz") # type: ignore + ase.io.write(f.name, periodic_h2o_atoms, format="extxyz") + atoms = ase.io.read(f.name, format="extxyz") # type: ignore return atoms # will be read as np.int64 @@ -148,7 +160,6 @@ def test_calculator_unknown_task_raises_error(): ) -@pytest.mark.gpu() def test_calculator_setup(all_calculators): for calc in all_calculators(): implemented_properties = ["energy", "forces"] @@ -166,7 +177,6 @@ def test_calculator_setup(all_calculators): ) -@pytest.mark.gpu() @pytest.mark.parametrize( "atoms_fixture", [ @@ -185,7 +195,6 @@ def test_energy_calculation(request, atoms_fixture, all_calculators): assert isinstance(energy, float) -@pytest.mark.gpu() def test_relaxation_final_energy(slab_atoms, mlip_predict_unit): datasets = list(mlip_predict_unit.dataset_to_tasks.keys()) calc = FAIRChemCalculator( @@ -203,7 +212,6 @@ def test_relaxation_final_energy(slab_atoms, mlip_predict_unit): assert isinstance(final_energy, float) -@pytest.mark.gpu() @pytest.mark.parametrize("inference_settings", ["default", "turbo"]) def test_calculator_configurations(inference_settings, slab_atoms): # turbo mode requires compilation and needs to reset here @@ -232,7 +240,6 @@ def test_calculator_configurations(inference_settings, slab_atoms): assert isinstance(stress, np.ndarray) -@pytest.mark.gpu() def test_large_bulk_system(large_bulk_atoms): """Test a bulk system with 1000 atoms using the small model.""" predict_unit = pretrained_mlip.get_predict_unit("uma-s-1", device="cuda") @@ -248,7 +255,6 @@ def test_large_bulk_system(large_bulk_atoms): assert isinstance(forces, np.ndarray) -@pytest.mark.gpu() @pytest.mark.parametrize( "pbc", [ @@ -277,7 +283,6 @@ def test_mixed_pbc_behavior(pbc, aperiodic_atoms, all_calculators): assert isinstance(energy, float) -@pytest.mark.gpu() def test_error_for_pbc_with_zero_cell(aperiodic_atoms, all_calculators): """Test error raised when pbc=True but atoms.cell is zero.""" aperiodic_atoms.pbc = True # Set PBC to True @@ -288,7 +293,6 @@ def test_error_for_pbc_with_zero_cell(aperiodic_atoms, all_calculators): aperiodic_atoms.get_potential_energy() -@pytest.mark.gpu() def test_omol_missing_spin_charge_logs_warning( periodic_h2o_atoms, omol_calculators, caplog ): @@ -304,7 +308,6 @@ def test_omol_missing_spin_charge_logs_warning( assert "spin multiplicity is not set in atoms.info" in caplog.text -@pytest.mark.gpu() def test_omol_energy_diff_for_charge_and_spin(aperiodic_atoms, omol_calculators): """Test that energy differs for H2O molecule with different charge and spin_multiplicity.""" @@ -364,7 +367,6 @@ def test_single_atom_system_errors(): atom.get_potential_energy() -@pytest.mark.gpu() @pytest.mark.skip( reason="the wigner matrices should be dependent on the RNG, but the energies" "are not actually different using the above seed setting code." @@ -418,7 +420,7 @@ def test_simple_md(): external_graph_gen=False, ) predictor = pretrained_mlip.get_predict_unit( - "uma-s-1p1", device="cpu", inference_settings=inference_settings + "uma-s-1p1", inference_settings=inference_settings ) calc = FAIRChemCalculator(predictor, task_name="omol") run_md_simulation(calc, steps=10) @@ -440,3 +442,131 @@ def test_parallel_md(checkpointing): calc = FAIRChemCalculator(predictor, task_name="omol") run_md_simulation(calc, steps=10) + + +def test_set_predict_formation_energy_basic_functionality(): + predict_unit = pretrained_mlip.get_predict_unit("uma-s-1") + calc = FAIRChemCalculator(predict_unit, task_name="omol") + + original_calc = FAIRChemCalculator(predict_unit, task_name="omol") + + formation_calc = set_predict_formation_energy(calc) + + assert formation_calc is calc + assert calc.calculate != original_calc.calculate + + +def test_set_predict_formation_energy_with_custom_references( + aperiodic_atoms, custom_element_refs +): + predict_unit = pretrained_mlip.get_predict_unit("uma-s-1") + calc = FAIRChemCalculator(predict_unit, task_name="omol") + + formation_calc = set_predict_formation_energy( + calc, element_references=custom_element_refs + ) + + aperiodic_atoms.calc = formation_calc + formation_energy = aperiodic_atoms.get_potential_energy() + + assert isinstance(formation_energy, float) + + +def test_set_predict_formation_energy_missing_element_raises_error(): + """Test that missing element references raise appropriate error.""" + water_molecule = molecule("H2O") + predict_unit = pretrained_mlip.get_predict_unit("uma-s-1") + calc = FAIRChemCalculator(predict_unit, task_name="omol") + + incomplete_refs = {"H": -0.5} # Missing O reference + formation_calc = set_predict_formation_energy( + calc, element_references=incomplete_refs + ) + + water_molecule.calc = formation_calc + + with pytest.raises(ValueError, match="Missing reference energies for elements"): + water_molecule.get_potential_energy() + + +def test_set_predict_formation_energy_mp_corrections_omat_task(bulk_atoms): + predict_unit = pretrained_mlip.get_predict_unit("uma-s-1") + calc = FAIRChemCalculator(predict_unit, task_name="omat") + + # Should apply corrections by default for omat + formation_calc = set_predict_formation_energy(calc, apply_corrections=None) + + bulk_atoms.calc = formation_calc + + # Should not raise error - corrections should be applied + try: + energy = bulk_atoms.get_potential_energy() + assert isinstance(energy, float) + except ImportError: + # If fairchem.data.omat is not available, should get ImportError + pytest.skip("fairchem.data.omat not available for MP corrections") + + +def test_set_predict_formation_energy_mp_corrections_non_omat_error(): + predict_unit = pretrained_mlip.get_predict_unit("uma-s-1") + calc = FAIRChemCalculator(predict_unit, task_name="omol") + + # Should raise error when trying to apply corrections to non-omat task + with pytest.raises( + ValueError, match="MP style corrections can only be applied for the OMat task" + ): + set_predict_formation_energy(calc, apply_corrections=True) + + +def test_set_predict_formation_energy_no_corrections_omat_task(bulk_atoms): + """Test that MP corrections can be explicitly disabled for OMat task.""" + predict_unit = pretrained_mlip.get_predict_unit("uma-s-1") + calc = FAIRChemCalculator(predict_unit, task_name="omat") + + formation_calc = set_predict_formation_energy(calc, apply_corrections=False) + + bulk_atoms.calc = formation_calc + + # Should work without corrections + energy = bulk_atoms.get_potential_energy() + assert isinstance(energy, float) + + +def test_set_predict_formation_energy_calculation_correctness(aperiodic_atoms): + """Test that formation energy calculation follows the correct formula.""" + predict_unit = pretrained_mlip.get_predict_unit("uma-s-1") + + total_calc = FAIRChemCalculator(predict_unit, task_name="omol") + formation_calc_instance = FAIRChemCalculator(predict_unit, task_name="omol") + + test_refs = {"H": -0.5} + formation_calc = set_predict_formation_energy( + formation_calc_instance, element_references=test_refs + ) + + aperiodic_atoms = molecule("H2") + aperiodic_atoms.info["charge"] = 0 + aperiodic_atoms.info["spin"] = 1 + + aperiodic_atoms.calc = total_calc + total_energy = aperiodic_atoms.get_potential_energy() + + aperiodic_atoms.calc = formation_calc + formation_energy = aperiodic_atoms.get_potential_energy() + + expected_formation_energy = total_energy - (2 * test_refs["H"]) + + assert np.isclose(formation_energy, expected_formation_energy, atol=1e-6) + + +def test_set_predict_formation_energy_different_task_types(): + """Test formation energy calculation works for different task types.""" + predict_unit = pretrained_mlip.get_predict_unit("uma-s-1") + + for task_name in ["omol", "omat", "oc20"]: + if task_name in predict_unit.dataset_to_tasks: + calc = FAIRChemCalculator(predict_unit, task_name=task_name) + + # Should not raise error when using default element references + formation_calc = set_predict_formation_energy(calc) + assert formation_calc is calc From 2687a146d98e17d04145f04c1daee03578926a7b Mon Sep 17 00:00:00 2001 From: lbluque Date: Tue, 4 Nov 2025 21:49:12 +0000 Subject: [PATCH 09/16] predict formation energies in inorganic bulk tutorial --- docs/_toc.yml | 2 +- .../examples_tutorials/bulk_stability.md | 77 ---------------- .../examples_tutorials/formation_energy.md | 92 +++++++++++++++++++ 3 files changed, 93 insertions(+), 78 deletions(-) delete mode 100644 docs/inorganic_materials/examples_tutorials/bulk_stability.md create mode 100644 docs/inorganic_materials/examples_tutorials/formation_energy.md diff --git a/docs/_toc.yml b/docs/_toc.yml index 217eb0fd41..94685c7fa2 100644 --- a/docs/_toc.yml +++ b/docs/_toc.yml @@ -72,7 +72,7 @@ parts: - file: inorganic_materials/models - file: inorganic_materials/examples_tutorials/summary sections: - - file: inorganic_materials/examples_tutorials/bulk_stability + - file: inorganic_materials/examples_tutorials/formation_energy - file: inorganic_materials/examples_tutorials/phonons - file: inorganic_materials/examples_tutorials/elastic - file: inorganic_materials/FAQ diff --git a/docs/inorganic_materials/examples_tutorials/bulk_stability.md b/docs/inorganic_materials/examples_tutorials/bulk_stability.md deleted file mode 100644 index ef803c4a0f..0000000000 --- a/docs/inorganic_materials/examples_tutorials/bulk_stability.md +++ /dev/null @@ -1,77 +0,0 @@ ---- -jupytext: - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.17.1 -kernelspec: - display_name: Python 3 (ipykernel) - language: python - name: python3 ---- - -Stability ------------------- - -We're going to start simple here - let's run a local relaxation (optimize the unit cell and positions) using a pre-trained EquiformerV2-31M-OMAT24-MP-sAlex checkpoint. This checkpoint has a few fun properties -1. It's a relatively small (31M) parameter model -2. It was pre-trained on the OMat24 dataset, and then fine-tuned on the MPtrj and Alexandria datasets, so it should emit energies and forces that are consistent with the MP GGA (PBE/PBE+U) level of theory - -````{admonition} Need to install fairchem-core or get UMA access or getting permissions/401 errors? -:class: dropdown - - -1. Install the necessary packages using pip, uv etc -```{code-cell} ipython3 -:tags: [skip-execution] - -! pip install fairchem-core fairchem-data-oc fairchem-applications-cattsunami -``` - -2. Get access to any necessary huggingface gated models - * Get and login to your Huggingface account - * Request access to https://huggingface.co/facebook/UMA - * Create a Huggingface token at https://huggingface.co/settings/tokens/ with the permission "Permissions: Read access to contents of all public gated repos you can access" - * Add the token as an environment variable using `huggingface-cli login` or by setting the HF_TOKEN environment variable. - -```{code-cell} ipython3 -:tags: [skip-execution] - -# Login using the huggingface-cli utility -! huggingface-cli login - -# alternatively, -import os -os.environ['HF_TOKEN'] = 'MY_TOKEN' -``` - -```` - -```{code-cell} ipython3 -from __future__ import annotations - -import pprint - -from ase.build import bulk -from ase.optimize import LBFGS -from quacc.recipes.mlp.core import relax_job - -# Make an Atoms object of a bulk Cu structure -atoms = bulk("Cu") - -# Run a structure relaxation -result = relax_job( - atoms, - method="fairchem", - name_or_path="uma-s-1p1", - task_name="omat", - opt_params={"fmax": 1e-3, "optimizer": LBFGS}, -) -``` - -```{code-cell} ipython3 -pprint.pprint(result) -``` - -Congratulations; you ran your first relaxation using an OMat24-trained checkpoint and `quacc`! diff --git a/docs/inorganic_materials/examples_tutorials/formation_energy.md b/docs/inorganic_materials/examples_tutorials/formation_energy.md new file mode 100644 index 0000000000..6eab9a495c --- /dev/null +++ b/docs/inorganic_materials/examples_tutorials/formation_energy.md @@ -0,0 +1,92 @@ +--- +jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.17.1 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +Formation energies +------------------ + +We're going to start simple here - let's run a local relaxation (optimize the unit cell and positions) using a pre-trained UMA model to compute formation energies for inorganic materials. + +Note predicting formation energy using models that models trained solely on OMat24 must use OMat24 compatible references and corrections for mixing PBE and PBE+U calculations. We use MP2020-style corrections fitted to OMat24 DFT calculations. For more information see the [documentation](https://docs.materialsproject.org/methodology/materials-methodology/thermodynamic-stability/thermodynamic-stability/anion-and-gga-gga+u-mixing) at the Materials Project. The necessary references can be found using the `fairchem.data.omat` package! + +````{admonition} Need to install fairchem-core or get UMA access or getting permissions/401 errors? +:class: dropdown + + +1. Install the necessary packages using pip, uv etc +```{code-cell} ipython3 +:tags: [skip-execution] + +! pip install fairchem-core fairchem-data-omat +``` + +2. Get access to any necessary huggingface gated models + * Get and login to your Huggingface account + * Request access to https://huggingface.co/facebook/UMA + * Create a Huggingface token at https://huggingface.co/settings/tokens/ with the permission "Permissions: Read access to contents of all public gated repos you can access" + * Add the token as an environment variable using `huggingface-cli login` or by setting the HF_TOKEN environment variable. + +```{code-cell} ipython3 +:tags: [skip-execution] + +# Login using the huggingface-cli utility +! huggingface-cli login + +# alternatively, +import os +os.environ['HF_TOKEN'] = 'MY_TOKEN' +``` +```` + +```{code-cell} ipython3 +from __future__ import annotations + +import pprint + +from ase.build import bulk +from ase.optimize import LBFGS +from quacc.recipes.mlp.core import relax_job + +from fairchem.core.calculate.ase_calculator import FAIRChemCalculator, set_predict_formation_energy + +# Make an Atoms object of a bulk MgO structure +atoms = bulk("MgO", "rocksalt", a=4.213) + +# Run a structure relaxation +result = relax_job( + atoms, + method="fairchem", + name_or_path="uma-s-1p1", + task_name="omat", + opt_params={"fmax": 1e-3, "optimizer": LBFGS}, +) + +# Get the realxed atoms! +atoms = result["atoms"] + +# Create an calculator using uma-s-1p1 +calculator = FAIRChemCalculator.from_model_checkpoint("uma-s-1p1", task_name="omat") + +# Adapt the calculator to automatically return MP-style corrected formation energies +# For the omat task, this defaults to apply MP2020 style corrections with OMat24 compatibility +calculator = set_predict_formation_energy(calculator, apply_corrections=True) + +# Predict the formation energy +atoms.calc = calculator +form_energy = atoms.get_potential_energy() +``` + +```{code-cell} ipython3 +pprint.pprint(f"Total energy: {result["results"]["energy"] eV\n Formation energy {form_energy} eV}) +``` + +Congratulations; you ran your first relaxation and predicted the formation energy of MgO using UMA and `quacc`! From 542dad23d32ffe6ca1c1f49967422f87bda9ee25 Mon Sep 17 00:00:00 2001 From: lbluque Date: Tue, 4 Nov 2025 21:55:36 +0000 Subject: [PATCH 10/16] remove test --- tests/core/calculate/test_ase_calculator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/calculate/test_ase_calculator.py b/tests/core/calculate/test_ase_calculator.py index 352ffa0bea..92e84076b1 100644 --- a/tests/core/calculate/test_ase_calculator.py +++ b/tests/core/calculate/test_ase_calculator.py @@ -207,7 +207,7 @@ def test_relaxation_final_energy(slab_atoms, mlip_predict_unit): assert isinstance(initial_energy, float) opt = BFGS(slab_atoms) - opt.run(fmax=0.05, steps=100) + opt.run(fmax=0.05, steps=10) final_energy = slab_atoms.get_potential_energy() assert isinstance(final_energy, float) From b6783d32edcfc7693c1eaff2c7dd637f0dbb850c Mon Sep 17 00:00:00 2001 From: lbluque Date: Wed, 5 Nov 2025 18:40:04 +0000 Subject: [PATCH 11/16] fix correct total energy --- src/fairchem/core/calculate/ase_calculator.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/fairchem/core/calculate/ase_calculator.py b/src/fairchem/core/calculate/ase_calculator.py index 06107df9d4..19041e0588 100644 --- a/src/fairchem/core/calculate/ase_calculator.py +++ b/src/fairchem/core/calculate/ase_calculator.py @@ -353,6 +353,17 @@ def formation_energy_calculate( if "energy" in calculator.results: total_energy = calculator.results["energy"] + if apply_corrections: + try: + from fairchem.data.omat.entries.compatibility import ( + apply_mp_style_corrections, + ) + except ImportError as err: + raise ImportError( + "fairchem.data.omat is required to apply MP style corrections. Please install it." + ) from err + total_energy = apply_mp_style_corrections(total_energy, atoms) + element_symbols = atoms.get_chemical_symbols() element_counts = Counter(element_symbols) @@ -369,17 +380,6 @@ def formation_energy_calculate( formation_energy = total_energy - total_ref_energy - if apply_corrections: - try: - from fairchem.data.omat.entries.compatibility import ( - apply_mp_style_corrections, - ) - except ImportError as err: - raise ImportError( - "fairchem.data.omat is required to apply MP style corrections. Please install it." - ) from err - formation_energy = apply_mp_style_corrections(formation_energy, atoms) - calculator.results["energy"] = formation_energy if "free_energy" in calculator.results: From 73a7e8cb40018e20531a5a3cea00201aefdf2f59 Mon Sep 17 00:00:00 2001 From: lbluque Date: Fri, 7 Nov 2025 00:25:58 +0000 Subject: [PATCH 12/16] add more tests --- src/fairchem/core/calculate/ase_calculator.py | 11 +- tests/core/calculate/test_ase_calculator.py | 217 ++++++++---------- .../test_formation_energies_omat.json | 1 + 3 files changed, 106 insertions(+), 123 deletions(-) create mode 100644 tests/core/calculate/test_formation_energies_omat.json diff --git a/src/fairchem/core/calculate/ase_calculator.py b/src/fairchem/core/calculate/ase_calculator.py index 19041e0588..1be2a63a67 100644 --- a/src/fairchem/core/calculate/ase_calculator.py +++ b/src/fairchem/core/calculate/ase_calculator.py @@ -10,6 +10,7 @@ import logging import os from collections import Counter +from contextlib import contextmanager from functools import partial from typing import TYPE_CHECKING, Literal @@ -317,6 +318,7 @@ def _validate_charge_and_spin(self, atoms: Atoms) -> None: ) +@contextmanager def set_predict_formation_energy( calculator: FAIRChemCalculator, element_references: dict | None = None, @@ -385,10 +387,11 @@ def formation_energy_calculate( if "free_energy" in calculator.results: calculator.results["free_energy"] = formation_energy - # Replace the calculate method - calculator.calculate = formation_energy_calculate - - return calculator + try: + calculator.calculate = formation_energy_calculate + yield + finally: + calculator.calculate = original_calculate class MixedPBCError(ValueError): diff --git a/tests/core/calculate/test_ase_calculator.py b/tests/core/calculate/test_ase_calculator.py index 92e84076b1..4941a2d85d 100644 --- a/tests/core/calculate/test_ase_calculator.py +++ b/tests/core/calculate/test_ase_calculator.py @@ -7,7 +7,9 @@ from __future__ import annotations +import json import logging +import os import tempfile from typing import TYPE_CHECKING @@ -17,6 +19,7 @@ import torch from ase import Atoms, units from ase.build import add_adsorbate, bulk, fcc111, molecule +from ase.io.jsonio import decode from ase.md.langevin import Langevin from ase.optimize import BFGS @@ -37,6 +40,30 @@ pytestmark = pytest.mark.gpu +@pytest.fixture(scope="session") +def atoms_with_formation_energy(): + with open( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "test_formation_energies_omat.json", + ) + ) as f: + data = json.load(f) + + atoms_with_formation_energy = {} + for comp_str, entry in data.items(): + atoms = decode(entry["atoms"]) + formation_energy_per_atom = entry["formation_energy_per_atom"] + atoms_with_formation_energy[comp_str] = (atoms, formation_energy_per_atom) + + return atoms_with_formation_energy + + +@pytest.fixture(scope="module") +def single_mlip_predict_unit(): + return pretrained_mlip.get_predict_unit("uma-s-1p1") + + @pytest.fixture(scope="module", params=pretrained_mlip.available_models) def mlip_predict_unit(request) -> MLIPPredictUnit: return pretrained_mlip.get_predict_unit(request.param) @@ -213,21 +240,20 @@ def test_relaxation_final_energy(slab_atoms, mlip_predict_unit): @pytest.mark.parametrize("inference_settings", ["default", "turbo"]) -def test_calculator_configurations(inference_settings, slab_atoms): +def test_calculator_configurations( + inference_settings, slab_atoms, single_mlip_predict_unit +): # turbo mode requires compilation and needs to reset here if inference_settings == "turbo": torch.compiler.reset() - predict_unit = pretrained_mlip.get_predict_unit( - "uma-s-1", inference_settings=inference_settings - ) - datasets = list(predict_unit.dataset_to_tasks.keys()) + datasets = list(single_mlip_predict_unit.dataset_to_tasks.keys()) calc = FAIRChemCalculator( - predict_unit, + single_mlip_predict_unit, task_name=datasets[0], ) slab_atoms.calc = calc - assert predict_unit.model.module.otf_graph is True + assert single_mlip_predict_unit.model.module.otf_graph is True # Test energy calculation energy = slab_atoms.get_potential_energy() assert isinstance(energy, float) @@ -240,10 +266,9 @@ def test_calculator_configurations(inference_settings, slab_atoms): assert isinstance(stress, np.ndarray) -def test_large_bulk_system(large_bulk_atoms): +def test_large_bulk_system(large_bulk_atoms, single_mlip_predict_unit): """Test a bulk system with 1000 atoms using the small model.""" - predict_unit = pretrained_mlip.get_predict_unit("uma-s-1", device="cuda") - calc = FAIRChemCalculator(predict_unit, task_name="omat") + calc = FAIRChemCalculator(single_mlip_predict_unit, task_name="omat") large_bulk_atoms.calc = calc # Test energy calculation @@ -332,17 +357,15 @@ def test_omol_energy_diff_for_charge_and_spin(aperiodic_atoms, omol_calculators) ), "Energy values are not unique for different charge/spin combinations" -def test_single_atom_systems(): +def test_single_atom_systems(single_mlip_predict_unit): """Test a system with a single atom. Single atoms do not currently use the model.""" - predict_unit = pretrained_mlip.get_predict_unit("uma-s-1", device="cpu") - for at_num in range(1, 84): atom = Atoms([at_num], positions=[(0.0, 0.0, 0.0)]) atom.info["charge"] = 0 atom.info["spin"] = 3 for task_name in ("omat", "omol", "oc20"): - calc = FAIRChemCalculator(predict_unit, task_name=task_name) + calc = FAIRChemCalculator(single_mlip_predict_unit, task_name=task_name) atom.calc = calc # Test energy calculation energy = atom.get_potential_energy() @@ -353,10 +376,9 @@ def test_single_atom_systems(): assert (forces == 0.0).all() -def test_single_atom_system_errors(): +def test_single_atom_system_errors(single_mlip_predict_unit): """Test that a charged system with a single atom does not work.""" - predict_unit = pretrained_mlip.get_predict_unit("uma-s-1", device="cpu") - calc = FAIRChemCalculator(predict_unit, task_name="omol") + calc = FAIRChemCalculator(single_mlip_predict_unit, task_name="omol") atom = Atoms("C", positions=[(0.0, 0.0, 0.0)]) atom.calc = calc @@ -371,12 +393,12 @@ def test_single_atom_system_errors(): reason="the wigner matrices should be dependent on the RNG, but the energies" "are not actually different using the above seed setting code." ) -def test_random_seed_final_energy(): +def test_random_seed_final_energy(single_mlip_predict_unit): seeds = [100, 200, 300, 200] results_by_seed = {} calc = FAIRChemCalculator( - pretrained_mlip.get_predict_unit("uma-s-1"), + single_mlip_predict_unit, task_name="omat", ) @@ -419,10 +441,10 @@ def test_simple_md(): internal_graph_gen_version=2, external_graph_gen=False, ) - predictor = pretrained_mlip.get_predict_unit( + predict_unit = pretrained_mlip.get_predict_unit( "uma-s-1p1", inference_settings=inference_settings ) - calc = FAIRChemCalculator(predictor, task_name="omol") + calc = FAIRChemCalculator(predict_unit, task_name="omol") run_md_simulation(calc, steps=10) @@ -444,129 +466,86 @@ def test_parallel_md(checkpointing): run_md_simulation(calc, steps=10) -def test_set_predict_formation_energy_basic_functionality(): - predict_unit = pretrained_mlip.get_predict_unit("uma-s-1") - calc = FAIRChemCalculator(predict_unit, task_name="omol") - - original_calc = FAIRChemCalculator(predict_unit, task_name="omol") - - formation_calc = set_predict_formation_energy(calc) - - assert formation_calc is calc - assert calc.calculate != original_calc.calculate - - -def test_set_predict_formation_energy_with_custom_references( - aperiodic_atoms, custom_element_refs +def test_set_predict_formation_energy_missing_element_raises_error( + single_mlip_predict_unit, ): - predict_unit = pretrained_mlip.get_predict_unit("uma-s-1") - calc = FAIRChemCalculator(predict_unit, task_name="omol") - - formation_calc = set_predict_formation_energy( - calc, element_references=custom_element_refs - ) - - aperiodic_atoms.calc = formation_calc - formation_energy = aperiodic_atoms.get_potential_energy() - - assert isinstance(formation_energy, float) - - -def test_set_predict_formation_energy_missing_element_raises_error(): - """Test that missing element references raise appropriate error.""" water_molecule = molecule("H2O") - predict_unit = pretrained_mlip.get_predict_unit("uma-s-1") - calc = FAIRChemCalculator(predict_unit, task_name="omol") - + calc = FAIRChemCalculator(single_mlip_predict_unit, task_name="omol") + water_molecule.calc = calc incomplete_refs = {"H": -0.5} # Missing O reference - formation_calc = set_predict_formation_energy( - calc, element_references=incomplete_refs - ) - - water_molecule.calc = formation_calc - - with pytest.raises(ValueError, match="Missing reference energies for elements"): + with pytest.raises( + ValueError, match="Missing reference energies for elements" + ), set_predict_formation_energy(calc, element_references=incomplete_refs): water_molecule.get_potential_energy() -def test_set_predict_formation_energy_mp_corrections_omat_task(bulk_atoms): - predict_unit = pretrained_mlip.get_predict_unit("uma-s-1") - calc = FAIRChemCalculator(predict_unit, task_name="omat") +def test_set_predict_formation_energy_mp_corrections_omat_task( + bulk_atoms, single_mlip_predict_unit +): + calc = FAIRChemCalculator(single_mlip_predict_unit, task_name="omat") + bulk_atoms.calc = calc - # Should apply corrections by default for omat - formation_calc = set_predict_formation_energy(calc, apply_corrections=None) + try: + with set_predict_formation_energy(calc, apply_corrections=None): + corrected_energy = bulk_atoms.get_potential_energy() - bulk_atoms.calc = formation_calc + with set_predict_formation_energy(calc, apply_corrections=False): + energy = bulk_atoms.get_potential_energy() - # Should not raise error - corrections should be applied - try: - energy = bulk_atoms.get_potential_energy() assert isinstance(energy, float) + assert isinstance(corrected_energy, float) + assert energy != corrected_energy + except ImportError: # If fairchem.data.omat is not available, should get ImportError pytest.skip("fairchem.data.omat not available for MP corrections") -def test_set_predict_formation_energy_mp_corrections_non_omat_error(): - predict_unit = pretrained_mlip.get_predict_unit("uma-s-1") - calc = FAIRChemCalculator(predict_unit, task_name="omol") - - # Should raise error when trying to apply corrections to non-omat task - with pytest.raises( - ValueError, match="MP style corrections can only be applied for the OMat task" - ): - set_predict_formation_energy(calc, apply_corrections=True) - - -def test_set_predict_formation_energy_no_corrections_omat_task(bulk_atoms): - """Test that MP corrections can be explicitly disabled for OMat task.""" - predict_unit = pretrained_mlip.get_predict_unit("uma-s-1") - calc = FAIRChemCalculator(predict_unit, task_name="omat") - - formation_calc = set_predict_formation_energy(calc, apply_corrections=False) - - bulk_atoms.calc = formation_calc - - # Should work without corrections - energy = bulk_atoms.get_potential_energy() - assert isinstance(energy, float) - +def test_set_predict_formation_energy_calculation_correctness( + aperiodic_atoms, single_mlip_predict_unit +): + calc = FAIRChemCalculator(single_mlip_predict_unit, task_name="omol") -def test_set_predict_formation_energy_calculation_correctness(aperiodic_atoms): - """Test that formation energy calculation follows the correct formula.""" - predict_unit = pretrained_mlip.get_predict_unit("uma-s-1") + atoms = molecule("H2") + atoms.info["charge"] = 0 + atoms.info["spin"] = 1 - total_calc = FAIRChemCalculator(predict_unit, task_name="omol") - formation_calc_instance = FAIRChemCalculator(predict_unit, task_name="omol") + atoms.calc = calc + total_energy = atoms.get_potential_energy() test_refs = {"H": -0.5} - formation_calc = set_predict_formation_energy( - formation_calc_instance, element_references=test_refs - ) - - aperiodic_atoms = molecule("H2") - aperiodic_atoms.info["charge"] = 0 - aperiodic_atoms.info["spin"] = 1 + with set_predict_formation_energy(calc, element_references=test_refs): + atoms.calc = calc + formation_energy = atoms.get_potential_energy() - aperiodic_atoms.calc = total_calc - total_energy = aperiodic_atoms.get_potential_energy() + expected_formation_energy = total_energy - (2 * test_refs["H"]) + assert np.isclose(formation_energy, expected_formation_energy, atol=1e-6) - aperiodic_atoms.calc = formation_calc - formation_energy = aperiodic_atoms.get_potential_energy() - expected_formation_energy = total_energy - (2 * test_refs["H"]) +def test_set_predict_formation_energy_different_task_types(single_mlip_predict_unit): + for task_name in ["omol", "omat", "oc20"]: + if task_name in single_mlip_predict_unit.dataset_to_tasks: + calc = FAIRChemCalculator(single_mlip_predict_unit, task_name=task_name) - assert np.isclose(formation_energy, expected_formation_energy, atol=1e-6) + # Should not raise error when using default element references + with set_predict_formation_energy(calc): + pass -def test_set_predict_formation_energy_different_task_types(): - """Test formation energy calculation works for different task types.""" +def test_formation_energy_predictions_against_known_values( + atoms_with_formation_energy, single_mlip_predict_unit +): + """Test that formation energy predictions match known values within tolerance.""" predict_unit = pretrained_mlip.get_predict_unit("uma-s-1") + calc = FAIRChemCalculator(predict_unit, task_name="omat") - for task_name in ["omol", "omat", "oc20"]: - if task_name in predict_unit.dataset_to_tasks: - calc = FAIRChemCalculator(predict_unit, task_name=task_name) + for known_formation_energy, atoms in atoms_with_formation_energy.values(): + atoms.calc = calc + with set_predict_formation_energy(calc): + predicted_formation_energy = atoms.get_potential_energy() - # Should not raise error when using default element references - formation_calc = set_predict_formation_energy(calc) - assert formation_calc is calc + assert np.isclose( + predicted_formation_energy, + known_formation_energy, + atol=1e-3, + ) diff --git a/tests/core/calculate/test_formation_energies_omat.json b/tests/core/calculate/test_formation_energies_omat.json new file mode 100644 index 0000000000..7e8fb308a6 --- /dev/null +++ b/tests/core/calculate/test_formation_energies_omat.json @@ -0,0 +1 @@ +{"MnO": {"atoms": "{\"numbers\": {\"__ndarray__\": [[4], \"int64\", [25, 25, 8, 8]]}, \"positions\": {\"__ndarray__\": [[4, 3], \"float64\", [0.0, 0.0, 0.0, -0.00037429499999996896, 9.772000000007885e-05, 3.145181315, 1.815539013211297, 1.2841631749417177, 3.1450094255952568, 0.9072525067887002, -1.2837140149417199, 4.71762057440474]]}, \"initial_magmoms\": {\"__ndarray__\": [[4], \"float64\", [4.5440000000000005, 4.5440000000000005, 0.233, 0.233]]}, \"cell\": {\"__ndarray__\": [[3, 3], \"float64\", [2.72354011, 0.00025372000000000003, 1.57226737, 0.9079176600000001, 2.56797491, 1.57256073, -0.90866625, -2.56777947, 4.7178018999999995]]}, \"pbc\": {\"__ndarray__\": [[3], \"bool\", [true, true, true]]}, \"__ase_objtype__\": \"atoms\"}", "formation_energy_per_atom": -1.9547571762931035}, "Li2CO3": {"atoms": "{\"numbers\": {\"__ndarray__\": [[12], \"int64\", [3, 3, 3, 3, 6, 6, 8, 8, 8, 8, 8, 8]]}, \"positions\": {\"__ndarray__\": [[12, 3], \"float64\", [4.249693104350961, 3.102879190596196, 3.348833019989187, 1.760827394771604, 1.4715854444393222, 3.079278267148243, 2.033495435649039, 1.015350689403804, 6.117356870010813, 4.5223611452283965, 2.6466444355606775, 6.386911622851757, 1.9507237449478476, 3.8559150337509043, 6.252134256176085, 4.332464795052153, 0.26231484624909585, 3.214055633823915, 2.6507627167209726, 2.7878647853334093, 6.2521342510361535, 3.632425823279027, 1.330365094666591, 3.214055638963846, 0.7296972899654498, 0.8641712078738811, 4.66666277809244, 5.181195655458371, 3.781845662796685, 7.837605719505678, 5.553491250034551, 3.254058672126119, 4.79952711190756, 1.1019928845416291, 0.3363842172033146, 1.628584170494322]]}, \"initial_magmoms\": {\"__ndarray__\": [[12], \"float64\", [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, -0.0, 0.0, 0.0, 0.0, 0.0]]}, \"cell\": {\"__ndarray__\": [[3, 3], \"float64\", [4.52199783, -0.00674351, 1.69501633, 1.79807432, 4.14914966, 1.69501635, -0.036883610000000004, -0.02417627, 6.07615721]]}, \"pbc\": {\"__ndarray__\": [[3], \"bool\", [true, true, true]]}, \"__ase_objtype__\": \"atoms\"}", "formation_energy_per_atom": -2.2577517413888892}, "MgBr2": {"atoms": "{\"numbers\": {\"__ndarray__\": [[3], \"int64\", [12, 35, 35]]}, \"positions\": {\"__ndarray__\": [[3, 3], \"float64\", [0.0, 0.0, 0.0, -1.9287345801243893e-06, 2.22711224022223, 1.4946971674935556, 1.9287365087345798, 1.1135544497777699, 4.907225012506444]]}, \"initial_magmoms\": {\"__ndarray__\": [[3], \"float64\", [0.0, 0.0, 0.0]]}, \"cell\": {\"__ndarray__\": [[3, 3], \"float64\", [3.85746916, -0.0, 0.0, -1.92873458, 3.34066669, 0.0, -0.0, 0.0, 6.40192218]]}, \"pbc\": {\"__ndarray__\": [[3], \"bool\", [true, true, true]]}, \"__ase_objtype__\": \"atoms\"}", "formation_energy_per_atom": -1.84848771}, "FeCO3": {"atoms": "{\"numbers\": {\"__ndarray__\": [[10], \"int64\", [26, 26, 6, 6, 8, 8, 8, 8, 8, 8]]}, \"positions\": {\"__ndarray__\": [[10, 3], \"float64\", [2.9272933314567284, 1.9114790144963563, 6.616813576605867, 0.0, 0.0, 0.0, 4.390922157480482, 2.867211082849009, 9.925171327154509, 1.4636679050765888, 0.955754779637051, 3.3084577017913417, 1.3299271279875373, 2.904766857249935, 5.625724472606166, 3.9738741971329987, 1.8243146427277472, 4.768167537990191, 3.840306918279997, 3.773437205490659, 7.085826481475668, 2.0142832591629327, 0.049528732023281, 6.147791189578482, 4.524663049455393, 0.918199080264005, 7.607893198447984, 1.880707326115722, 1.9986512572640933, 8.465459615226468]]}, \"initial_magmoms\": {\"__ndarray__\": [[10], \"float64\", [3.664, 3.6630000000000003, 0.02, 0.02, 0.07100000000000001, 0.07100000000000001, 0.07100000000000001, 0.07100000000000001, 0.07100000000000001, 0.07100000000000001]]}, \"cell\": {\"__ndarray__\": [[3, 3], \"float64\", [4.24093271, -0.03750986, 3.77733752, 1.67110034, 3.8979897, 3.77733998, -0.05744293, -0.03751394, 5.67894585]]}, \"pbc\": {\"__ndarray__\": [[3], \"bool\", [true, true, true]]}, \"__ase_objtype__\": \"atoms\"}", "formation_energy_per_atom": -1.6023929089999982}} From 2b0a525ebb6f88b376696480ff40a527201fca29 Mon Sep 17 00:00:00 2001 From: lbluque Date: Fri, 7 Nov 2025 19:42:41 +0000 Subject: [PATCH 13/16] fix tests --- tests/core/calculate/test_ase_calculator.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/core/calculate/test_ase_calculator.py b/tests/core/calculate/test_ase_calculator.py index 4941a2d85d..d53aaafe09 100644 --- a/tests/core/calculate/test_ase_calculator.py +++ b/tests/core/calculate/test_ase_calculator.py @@ -504,7 +504,9 @@ def test_set_predict_formation_energy_mp_corrections_omat_task( def test_set_predict_formation_energy_calculation_correctness( aperiodic_atoms, single_mlip_predict_unit ): - calc = FAIRChemCalculator(single_mlip_predict_unit, task_name="omol") + calc = FAIRChemCalculator.from_model_checkpoint( + "uma-s-1p1", task_name="omol" + ) # (single_mlip_predict_unit, task_name="omol") atoms = molecule("H2") atoms.info["charge"] = 0 @@ -513,13 +515,15 @@ def test_set_predict_formation_energy_calculation_correctness( atoms.calc = calc total_energy = atoms.get_potential_energy() + # reset cached atoms + calc.atoms = None + test_refs = {"H": -0.5} - with set_predict_formation_energy(calc, element_references=test_refs): - atoms.calc = calc + with set_predict_formation_energy(atoms.calc, element_references=test_refs): formation_energy = atoms.get_potential_energy() - expected_formation_energy = total_energy - (2 * test_refs["H"]) - assert np.isclose(formation_energy, expected_formation_energy, atol=1e-6) + expected_formation_energy = total_energy - (2 * test_refs["H"]) + assert np.isclose(formation_energy, expected_formation_energy, atol=1e-6) def test_set_predict_formation_energy_different_task_types(single_mlip_predict_unit): @@ -539,7 +543,7 @@ def test_formation_energy_predictions_against_known_values( predict_unit = pretrained_mlip.get_predict_unit("uma-s-1") calc = FAIRChemCalculator(predict_unit, task_name="omat") - for known_formation_energy, atoms in atoms_with_formation_energy.values(): + for atoms, known_formation_energy in atoms_with_formation_energy.values(): atoms.calc = calc with set_predict_formation_energy(calc): predicted_formation_energy = atoms.get_potential_energy() From 63aafbf643fd9a9cefe592dcae2df3297ed2862a Mon Sep 17 00:00:00 2001 From: lbluque Date: Fri, 7 Nov 2025 20:25:55 +0000 Subject: [PATCH 14/16] reset calculator cash --- src/fairchem/core/calculate/ase_calculator.py | 5 +++++ tests/core/calculate/test_ase_calculator.py | 12 ++++++------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/fairchem/core/calculate/ase_calculator.py b/src/fairchem/core/calculate/ase_calculator.py index 1be2a63a67..7f0a1ec1e8 100644 --- a/src/fairchem/core/calculate/ase_calculator.py +++ b/src/fairchem/core/calculate/ase_calculator.py @@ -323,6 +323,7 @@ def set_predict_formation_energy( calculator: FAIRChemCalculator, element_references: dict | None = None, apply_corrections: bool | None = None, + clear_calculator_cache: bool = True, ) -> FAIRChemCalculator: """ Adapt a calculator to predict formation energy. @@ -333,6 +334,7 @@ def set_predict_formation_energy( to provide these and instead use the defaults for each UMA task. apply_corrections (bool, optional): Whether to apply MP style corrections to the formation energies. This is only relevant for the OMat task. Default is True if task is OMat. + clear_calculator_cache (bool): Whether to clear the calculator cache before modifying the calculate method. Returns: FAIRChemCalculator: The same calculator instance but will return formation energies as the potential energy. @@ -347,6 +349,9 @@ def set_predict_formation_energy( original_calculate = calculator.calculate + if clear_calculator_cache is True: + calculator.atoms = None + def formation_energy_calculate( atoms: Atoms, properties: list[str], system_changes: list[str] ) -> None: diff --git a/tests/core/calculate/test_ase_calculator.py b/tests/core/calculate/test_ase_calculator.py index d53aaafe09..cfa6ce1d25 100644 --- a/tests/core/calculate/test_ase_calculator.py +++ b/tests/core/calculate/test_ase_calculator.py @@ -515,11 +515,10 @@ def test_set_predict_formation_energy_calculation_correctness( atoms.calc = calc total_energy = atoms.get_potential_energy() - # reset cached atoms - calc.atoms = None - test_refs = {"H": -0.5} - with set_predict_formation_energy(atoms.calc, element_references=test_refs): + with set_predict_formation_energy( + atoms.calc, element_references=test_refs, reset_calculator_cache=True + ): formation_energy = atoms.get_potential_energy() expected_formation_energy = total_energy - (2 * test_refs["H"]) @@ -545,11 +544,12 @@ def test_formation_energy_predictions_against_known_values( for atoms, known_formation_energy in atoms_with_formation_energy.values(): atoms.calc = calc + with set_predict_formation_energy(calc): - predicted_formation_energy = atoms.get_potential_energy() + predicted_formation_energy = atoms.get_potential_energy() / len(atoms) assert np.isclose( predicted_formation_energy, known_formation_energy, - atol=1e-3, + atol=0.3, # eV/atom tolerance ) From c47a9e9a0ff18a328b40d9ddfd0885ba8d72326e Mon Sep 17 00:00:00 2001 From: lbluque Date: Fri, 7 Nov 2025 20:34:35 +0000 Subject: [PATCH 15/16] reset cache after calculation --- src/fairchem/core/calculate/ase_calculator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/fairchem/core/calculate/ase_calculator.py b/src/fairchem/core/calculate/ase_calculator.py index 7f0a1ec1e8..ba60fe80a0 100644 --- a/src/fairchem/core/calculate/ase_calculator.py +++ b/src/fairchem/core/calculate/ase_calculator.py @@ -397,6 +397,8 @@ def formation_energy_calculate( yield finally: calculator.calculate = original_calculate + if clear_calculator_cache is True: + calculator.atoms = None class MixedPBCError(ValueError): From 460acdc8389d018ea07f5844261b154e82bad816 Mon Sep 17 00:00:00 2001 From: lbluque Date: Fri, 7 Nov 2025 20:39:45 +0000 Subject: [PATCH 16/16] add MP value to tutorial --- .../examples_tutorials/formation_energy.md | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/docs/inorganic_materials/examples_tutorials/formation_energy.md b/docs/inorganic_materials/examples_tutorials/formation_energy.md index 6eab9a495c..5f366fe6d5 100644 --- a/docs/inorganic_materials/examples_tutorials/formation_energy.md +++ b/docs/inorganic_materials/examples_tutorials/formation_energy.md @@ -67,7 +67,8 @@ result = relax_job( method="fairchem", name_or_path="uma-s-1p1", task_name="omat", - opt_params={"fmax": 1e-3, "optimizer": LBFGS}, + relax_cell=True, + opt_params={"fmax": 1e-3, "optimizer": FIRE}, ) # Get the realxed atoms! @@ -75,18 +76,19 @@ atoms = result["atoms"] # Create an calculator using uma-s-1p1 calculator = FAIRChemCalculator.from_model_checkpoint("uma-s-1p1", task_name="omat") +atoms.calc = calculator -# Adapt the calculator to automatically return MP-style corrected formation energies +# Adapt the calculation to automatically return MP-style corrected formation energies # For the omat task, this defaults to apply MP2020 style corrections with OMat24 compatibility -calculator = set_predict_formation_energy(calculator, apply_corrections=True) - -# Predict the formation energy -atoms.calc = calculator -form_energy = atoms.get_potential_energy() +with set_predict_formation_energy(atoms.calc, apply_corrections=True): + form_energy = atoms.get_potential_energy() ``` ```{code-cell} ipython3 pprint.pprint(f"Total energy: {result["results"]["energy"] eV\n Formation energy {form_energy} eV}) ``` +Compare the results to the value of [-3.038 eV/atom reported](https://next-gen.materialsproject.org/materials/mp-1265?chemsys=Mg-O#thermodynamic_stability) in the the Materials Project! +*Note that we expect differences due to the different DFT settings used to calculate the OMat24 training data.* + Congratulations; you ran your first relaxation and predicted the formation energy of MgO using UMA and `quacc`!