Skip to content

Commit a8bc650

Browse files
MLFF elastic workflow stability (#1449)
* default to fixed symmetry relaxation in elastic, throw warning to user * reintroduce strict pymatgen dep
1 parent ecaec88 commit a8bc650

3 files changed

Lines changed: 31 additions & 13 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ strict = [
134134
"atomate2[cclib, phonons, lobster, openmm, mp, defects, ase, ase-ext]",
135135
"numpy<3.0",
136136
"numba>=0.60.0", # needed to get numpy >2,<3 installed
137+
"pymatgen==2026.3.23",
137138
]
138139

139140
[project.scripts]

src/atomate2/forcefields/flows/elastic.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
_DEFAULT_RELAX_KWARGS: dict[str, Any] = {
2222
"force_field_name": "CHGNet",
2323
"relax_kwargs": {"fmax": 0.00001},
24+
"fix_symmetry": True,
2425
}
2526

2627

@@ -106,6 +107,7 @@ def from_force_field_name(
106107
cls,
107108
force_field_name: str | MLFF | dict,
108109
calculator_kwargs: dict | None = None,
110+
relax_initial_structure: bool = True,
109111
**kwargs,
110112
) -> Self:
111113
"""
@@ -117,13 +119,24 @@ def from_force_field_name(
117119
The name of the force field.
118120
calculator_kwargs : dict or None (default)
119121
calculator_kwargs to pass to `ForceFieldRelaxMaker`.
122+
relax_initial_structure : bool = True (default)
123+
Whether to relax the structure before computing
124+
the elastic tensor.
120125
**kwargs
121126
Additional kwargs to pass to ElasticMaker.
122127
123128
Returns
124129
-------
125130
ElasticMaker
126131
"""
132+
warnings.warn(
133+
"Fixed symmetry relaxations are automatically enabled "
134+
"to improve elastic tensor stability. To disable this "
135+
"specify ForceFieldRelaxMaker objects explicitly. ",
136+
category=UserWarning,
137+
stacklevel=2,
138+
)
139+
127140
if (mlff_kwargs := kwargs.pop("mlff_kwargs", None)) is not None:
128141
warnings.warn(
129142
"`mlff_kwargs` has been marked for deprecation. "
@@ -148,18 +161,22 @@ def from_force_field_name(
148161
"force_field_name": force_field_name,
149162
"calculator_kwargs": calculator_kwargs or {},
150163
}
151-
bulk_relax_maker = ForceFieldRelaxMaker(
152-
relax_cell=True,
164+
165+
elastic_relax_maker = ForceFieldRelaxMaker(
166+
relax_cell=False,
153167
**default_kwargs,
154168
)
155-
kwargs.update(
156-
bulk_relax_maker=bulk_relax_maker,
157-
elastic_relax_maker=ForceFieldRelaxMaker(
158-
relax_cell=False,
159-
**default_kwargs,
160-
),
161-
)
169+
162170
return cls(
163-
name=f"{bulk_relax_maker.mlff.name} elastic",
171+
name=f"{elastic_relax_maker.mlff.name} elastic",
164172
**kwargs,
173+
bulk_relax_maker=(
174+
ForceFieldRelaxMaker(
175+
relax_cell=True,
176+
**default_kwargs,
177+
)
178+
if relax_initial_structure
179+
else None
180+
),
181+
elastic_relax_maker=elastic_relax_maker,
165182
)

tests/forcefields/flows/test_elastic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def test_elastic_wf_with_mace(
1616
si_prim = SpacegroupAnalyzer(si_structure).get_primitive_standard_structure()
1717
model_path = f"{test_dir}/forcefields/mace/MACE.model"
1818
common_kwds = {
19-
"force_field_name": "MACE",
19+
"force_field_name": "MACE-MP-0",
2020
"calculator_kwargs": {"model": model_path, "default_dtype": "float64"},
2121
"relax_kwargs": {"fmax": 0.00001},
2222
}
@@ -29,7 +29,7 @@ def test_elastic_wf_with_mace(
2929
ValueError, match="You have specified both `calculator_kwargs` and"
3030
):
3131
ElasticMaker.from_force_field_name(
32-
force_field_name="MACE",
32+
force_field_name="MACE-MP-0",
3333
mlff_kwargs=common_kwds,
3434
calculator_kwargs=common_kwds,
3535
)
@@ -38,7 +38,7 @@ def test_elastic_wf_with_mace(
3838
UserWarning, match="`mlff_kwargs` has been marked for deprecation."
3939
):
4040
maker = ElasticMaker.from_force_field_name(
41-
force_field_name="MACE",
41+
force_field_name="MACE-MP-0",
4242
mlff_kwargs=common_kwds,
4343
)
4444
assert all(

0 commit comments

Comments
 (0)