Skip to content

Commit 54f4f35

Browse files
Stabilze forcefield dependencies (#1452)
* opt in forcefields * minimal forcefield deps, use chgnet package for this + docs * other test data
1 parent a8bc650 commit 54f4f35

6 files changed

Lines changed: 58 additions & 26 deletions

File tree

.github/workflows/testing.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,9 @@ jobs:
6262
micromamba install -n a2 -c conda-forge packmol --yes
6363
6464
- name: Install dependencies
65-
run: | # TODO: remove setuptools pin once `mattersim` removes `pkg_resources` dependency
65+
run: |
6666
micromamba activate a2
67-
python -m pip install --upgrade pip 'setuptools<81'
67+
python -m pip install --upgrade pip
6868
mkdir -p ~/.abinit/pseudos
6969
cp -r tests/test_data/abinit/pseudos/ONCVPSP-PBE-SR-PDv0.4 ~/.abinit/pseudos
7070
uv pip install .[strict,strict-forcefields-${{ matrix.dep-group }},abinit,approxneb,aims] --group tests
@@ -207,8 +207,8 @@ jobs:
207207
run: micromamba run -n a2 pip install uv
208208

209209
- name: Install conda dependencies
210-
run: | # TODO: migrate openff tests to use non smirnoff99frosst forcefields - requires old setuptools / pkg_resources
211-
micromamba install -n a2 -c conda-forge enumlib packmol bader openbabel openff-toolkit==0.16.2 openff-interchange==0.3.22 'setuptools<81' --yes
210+
run: | # TODO: migrate openff tests to use non smirnoff99frosst forcefields as recommended by devs
211+
micromamba install -n a2 -c conda-forge enumlib packmol bader openbabel openff-toolkit==0.16.2 openff-interchange==0.3.22 --yes
212212
213213
- name: Install dependencies
214214
run: |
@@ -376,8 +376,8 @@ jobs:
376376
cache-dependency-path: pyproject.toml
377377

378378
- name: Install dependencies
379-
run: | # TODO: remove setuptools pin once `mattersim` removes `pkg_resources` dependency
380-
python -m pip install --upgrade pip 'setuptools<81'
379+
run: |
380+
python -m pip install --upgrade pip
381381
pip install .[strict,strict-forcefields-generic] --group docs
382382
383383
- name: Build

docs/user/codes/forcefields.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,13 @@
44

55
`atomate2` includes an interface to a few common machine learning interatomic potentials (MLIPs), also known variously as machine learning forcefields (MLFFs), or foundation potentials (FPs) for universal variants.
66

7+
***As of `atomate2==0.1.1`, most forcefield packages are opt-in only. You must install those forcefields which you plan to use.***
8+
Running `pip install 'atomate2[forcefields]'` will install the `chgnet` package to permit you to try the forcefield classes.
9+
You can then select other forcefields you want to use.
10+
11+
We have made this choice both to avoid the appearance of favoritism (both `chgnet` and `atomate2` are Materials Project-supported projects), and to avoid dependency conflicts between MLFF packages.
12+
If you need a sense of which forcefields are compatible, you can use the [pyproject.toml](https://github.com/materialsproject/atomate2/blob/a8bc6505e439503a114f5346aec916aafae7f27b/pyproject.toml#L90) to see which versions are grouped together for testing.
13+
714
Most of `Maker` classes using the forcefields inherit from `atomate2.forcefields.utils.ForceFieldMixin` to specify which forcefield to use.
815
The `ForceFieldMixin` mixin provides the following configurable parameters:
916

pyproject.toml

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,10 @@ defects = [
5454
"pymatgen-analysis-defects>=2024.5.11",
5555
"python-ulid>=2.7",
5656
]
57+
5758
forcefields = [
5859
"ase>=3.26.0",
59-
"calorine>=3.0",
60-
"mace-torch>=0.3.3",
61-
"matgl>=1.2.1",
62-
"torchdata<=0.7.1", # TODO: remove when issue fixed
63-
"quippy-ase>=0.9.14",
64-
"mattersim>=1.0.1",
65-
"sevenn>=0.9.3",
66-
"deepmd-kit>=2.1.4",
67-
"upet>=0.2.1",
60+
"chgnet>=0.4.2",
6861
]
6962
torchsim = [
7063
"torch-sim-atomistic==0.5.0; python_version >= '3.12'"
@@ -98,10 +91,12 @@ strict-openff = [
9891
strict-forcefields-generic = [
9992
"calorine==3.3; python_version >= '3.12'",
10093
"calorine==3.1; python_version < '3.12'",
94+
"chgnet==0.4.2",
10195
"quippy-ase==0.10.2",
10296
"sevenn==0.12.1",
10397
"deepmd-kit==3.1.3",
104-
"tensorflow-cpu==2.21.0",
98+
"tensorflow-cpu==2.21.0; sys_platform == 'linux'",
99+
"tensorflow==2.21.0; sys_platform == 'darwin' or sys_platform == 'win32'",
105100
"mattersim==1.2.1",
106101
"ase<3.28.0", # TODO: remove, required for mattersim because of ase.constraints import
107102
"wandb>=0.24.0", # required for mattersim
@@ -113,7 +108,7 @@ strict-forcefields-torch-limited = [
113108
# That enforces a simultaneous pin on torch / torchdata
114109
# Linux users can acces newer versions of dgl / torch / torchdata via conda.
115110
# Mac / Windows users will need to install from source
116-
"dgl==2.2.1; sys_platform == 'darwin' or sys_platform == 'win32'",
111+
"dgl==2.2.0; sys_platform == 'darwin' or sys_platform == 'win32'",
117112
"dgl<=2.4.0; sys_platform == 'linux'",
118113
"torch==2.2.2; sys_platform == 'darwin' or sys_platform == 'win32'",
119114
"torch==2.2.0; sys_platform == 'linux'",

src/atomate2/forcefields/utils.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from functools import cached_property
1010
from importlib import import_module
1111
from importlib.metadata import PackageNotFoundError, version
12+
from importlib.util import find_spec
1213
from pathlib import Path
1314
from typing import TYPE_CHECKING
1415

@@ -290,6 +291,15 @@ def ase_calculator(
290291
calculator = getattr(import_module(_mod), _cls, None)(**kwargs)
291292

292293
case MLFF.CHGNet | MLFF.M3GNet | MLFF.MATPES_R2SCAN | MLFF.MATPES_PBE:
294+
if calculator_name == MLFF.CHGNet:
295+
# Legacy interface to `chgnet` package
296+
try:
297+
from chgnet.model.dynamics import CHGNetCalculator
298+
299+
return CHGNetCalculator(**kwargs)
300+
except ImportError:
301+
pass
302+
293303
import matgl
294304

295305
match calculator_name:
@@ -299,13 +309,12 @@ def ase_calculator(
299309
case MLFF.CHGNet:
300310
path = kwargs.get("path", "CHGNet-MPtrj-2023.12.1-2.7M-PES")
301311
matgl.config.BACKEND = "DGL"
312+
302313
warnings.warn(
303314
"The CHGNet functionality in atomate2 has been migrated "
304315
"from the `chgnet` package to `matgl` to ensure continuing "
305316
"support. If you want to use the `chgnet` package, "
306-
"`pip install chgnet` and then specify "
307-
'`calculator_meta = {"@module": "chgnet.model.dynamics", '
308-
'"@callable": "CHGNetCalculator"}`',
317+
"`pip install chgnet`",
309318
stacklevel=2,
310319
)
311320
case MLFF.MATPES_R2SCAN | MLFF.MATPES_PBE:
@@ -444,7 +453,13 @@ def _get_pkg_name(calculator_meta: MLFF | dict[str, Any]) -> str | None:
444453
match calculator_meta:
445454
case MLFF.Allegro | MLFF.Nequip:
446455
ff_pkg = "nequip"
447-
case MLFF.CHGNet | MLFF.M3GNet | MLFF.MATPES_PBE | MLFF.MATPES_R2SCAN:
456+
case MLFF.CHGNet:
457+
# Check if CHGNet is installed
458+
try:
459+
ff_pkg = next(pkg for pkg in ("chgnet", "matgl") if find_spec(pkg))
460+
except StopIteration:
461+
ff_pkg = None
462+
case MLFF.M3GNet | MLFF.MATPES_PBE | MLFF.MATPES_R2SCAN:
448463
ff_pkg = "matgl"
449464
case MLFF.DeepMD:
450465
ff_pkg = "deepmd-kit"

tests/forcefields/flows/test_phonon.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from importlib.util import find_spec
23
from pathlib import Path
34
from tempfile import TemporaryDirectory
45

@@ -135,6 +136,8 @@ def test_phonon_wf_force_field(
135136
}
136137
)
137138

139+
is_matgl_chgnet = find_spec("matgl") is not None
140+
138141
phonon_kwargs = dict(
139142
use_symmetrized_structure="conventional",
140143
create_thermal_displacements=False,
@@ -181,7 +184,9 @@ def test_phonon_wf_force_field(
181184

182185
assert_allclose(
183186
ph_bs_dos_doc.free_energies,
184-
[4440.74345, 4172.361432, 2910.000404, 720.739896, -2194.234779],
187+
[4440.74345, 4172.361432, 2910.000404, 720.739896, -2194.234779]
188+
if is_matgl_chgnet
189+
else [5271.300306, 5162.674841, 4353.717375, 2698.616337, 343.125174],
185190
atol=1000,
186191
)
187192

@@ -215,7 +220,9 @@ def test_phonon_wf_force_field(
215220
assert ph_bs_dos_doc.phonopy_settings.kpoint_density_dos == 7_000
216221
assert_allclose(
217222
ph_bs_dos_doc.entropies,
218-
[0.0, 7.374244, 17.612124, 25.802735, 32.209433],
223+
[0.0, 7.374244, 17.612124, 25.802735, 32.209433]
224+
if is_matgl_chgnet
225+
else [0.0, 3.733666, 12.536534, 20.344558, 26.627292],
219226
atol=2,
220227
)
221228
assert_allclose(

tests/forcefields/test_jobs.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import importlib
22
from contextlib import nullcontext
33
from importlib.metadata import version as get_imported_version
4+
from importlib.util import find_spec
45
from pathlib import Path
56

67
import numpy as np
@@ -38,7 +39,10 @@ def test_maker_initialization(mlff):
3839

3940

4041
@pytest.mark.skipif(
41-
dgl is None or not mlff_is_installed("CHGNet"),
42+
# test to see if CHGNet is installed, or that matgl is installed without dgl
43+
# Note that this should be the only test for interface with
44+
# the legacy `chgnet` package
45+
not mlff_is_installed("CHGNet") or (mlff_is_installed("M3GNet") and dgl is None),
4246
reason="CHGNet requires DGL which is not installed",
4347
)
4448
def test_chgnet_static_maker(si_structure):
@@ -48,17 +52,21 @@ def test_chgnet_static_maker(si_structure):
4852
ionic_step_data=("structure", "energy"),
4953
).make(si_structure)
5054

55+
pkg_name = "matgl" if find_spec("matgl") else "chgnet"
56+
5157
# run the flow or job and ensure that it finished running successfully
5258
responses = run_locally(job, ensure_success=True)
5359

5460
# validate job outputs
5561
output1 = responses[job.uuid][1].output
5662
assert isinstance(output1, ForceFieldTaskDocument)
57-
assert output1.output.energy == approx(-10.7907495, rel=1e-4)
63+
assert output1.output.energy == approx(
64+
-10.7907495 if pkg_name == "matgl" else -10.6275053, rel=1e-4
65+
)
5866
assert output1.output.ionic_steps[-1].magmoms is None
5967
assert output1.output.n_steps == 1
6068

61-
assert output1.forcefield_version == get_imported_version("matgl")
69+
assert output1.forcefield_version == get_imported_version(pkg_name)
6270

6371

6472
@pytest.mark.skipif(

0 commit comments

Comments
 (0)