Skip to content

Commit 5897cfe

Browse files
committed
output ForceFieldStructureTaskDocument or ForceFieldMoleculeTaskDocument based on the input type of mol_or_struct.
change name ForceFieldTaskDocument => ForceFieldStructureTaskDocument output ForceFieldStructureTaskDocument or ForceFieldMoleculeTaskDocument based on type of mol_or_struct update ForceFieldTaskDocument => ForceFieldStructureTaskDocument in the tests import Union from typing include Union in forcefield/md.py take the suggestions from the formatter take ruff's suggestions try again with ruff format ruff format again try again ruff ruff again fix the mypy error Take inputs of both Molecule and Structure update docstring add molecule test for forcefield
1 parent b8ff9f4 commit 5897cfe

File tree

7 files changed

+200
-79
lines changed

7 files changed

+200
-79
lines changed

src/atomate2/ase/schemas.py

-21
Original file line numberDiff line numberDiff line change
@@ -233,27 +233,6 @@ class AseStructureTaskDoc(StructureMetadata):
233233

234234
tags: Optional[list[str]] = Field(None, description="List of tags for the task.")
235235

236-
@classmethod
237-
def from_ase_task_doc(
238-
cls, ase_task_doc: AseTaskDoc, **task_document_kwargs
239-
) -> AseStructureTaskDoc:
240-
"""Create an AseStructureTaskDoc for a task that has ASE-compatible outputs.
241-
242-
Parameters
243-
----------
244-
ase_task_doc : AseTaskDoc
245-
Task doc for the calculation
246-
task_document_kwargs : dict
247-
Additional keyword args passed to :obj:`.AseStructureTaskDoc()`.
248-
"""
249-
task_document_kwargs.update(
250-
{k: getattr(ase_task_doc, k) for k in _task_doc_translation_keys},
251-
structure=ase_task_doc.mol_or_struct,
252-
)
253-
return cls.from_structure(
254-
meta_structure=ase_task_doc.mol_or_struct, **task_document_kwargs
255-
)
256-
257236

258237
class AseMoleculeTaskDoc(MoleculeMetadata):
259238
"""Document containing information on molecule manipulation using ASE."""

src/atomate2/forcefields/jobs.py

+30-27
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,19 @@
1515

1616
from atomate2.ase.jobs import AseRelaxMaker
1717
from atomate2.forcefields import MLFF, _get_formatted_ff_name
18-
from atomate2.forcefields.schemas import ForceFieldTaskDocument
18+
from atomate2.forcefields.schemas import (
19+
ForceFieldMoleculeTaskDocument,
20+
ForceFieldStructureTaskDocument,
21+
ForceFieldTaskDocument,
22+
)
1923
from atomate2.forcefields.utils import ase_calculator, revert_default_dtype
2024

2125
if TYPE_CHECKING:
2226
from collections.abc import Callable
2327
from pathlib import Path
2428

2529
from ase.calculators.calculator import Calculator
26-
from pymatgen.core.structure import Structure
30+
from pymatgen.core.structure import Molecule, Structure
2731

2832
logger = logging.getLogger(__name__)
2933

@@ -48,7 +52,8 @@ def forcefield_job(method: Callable) -> job:
4852
This is a thin wrapper around :obj:`~jobflow.core.job.Job` that configures common
4953
settings for all forcefield jobs. For example, it ensures that large data objects
5054
(currently only trajectories) are all stored in the atomate2 data store.
51-
It also configures the output schema to be a ForceFieldTaskDocument :obj:`.TaskDoc`.
55+
It also configures the output schema to be a
56+
ForceFieldStructureTaskDocument :obj:`.TaskDoc`.
5257
5358
Any makers that return forcefield jobs (not flows) should decorate the
5459
``make`` method with @forcefield_job. For example:
@@ -72,9 +77,7 @@ def make(structure):
7277
callable
7378
A decorated version of the make function that will generate forcefield jobs.
7479
"""
75-
return job(
76-
method, data=_FORCEFIELD_DATA_OBJECTS, output_schema=ForceFieldTaskDocument
77-
)
80+
return job(method, data=_FORCEFIELD_DATA_OBJECTS)
7881

7982

8083
@dataclass
@@ -118,7 +121,7 @@ class ForceFieldRelaxMaker(AseRelaxMaker):
118121
tags : list[str] or None
119122
A list of tags for the task.
120123
task_document_kwargs : dict (deprecated)
121-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
124+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
122125
"""
123126

124127
name: str = "Force field relax"
@@ -146,15 +149,15 @@ def __post_init__(self) -> None:
146149

147150
@forcefield_job
148151
def make(
149-
self, structure: Structure, prev_dir: str | Path | None = None
150-
) -> ForceFieldTaskDocument:
152+
self, structure: Molecule | Structure, prev_dir: str | Path | None = None
153+
) -> ForceFieldStructureTaskDocument | ForceFieldMoleculeTaskDocument:
151154
"""
152155
Perform a relaxation of a structure using a force field.
153156
154157
Parameters
155158
----------
156-
structure: .Structure
157-
pymatgen structure.
159+
structure: .Structure or Molecule
160+
pymatgen structure or molecule.
158161
prev_dir : str or Path or None
159162
A previous calculation directory to copy output files from. Unused, just
160163
added to match the method signature of other makers.
@@ -170,7 +173,7 @@ def make(
170173
stacklevel=1,
171174
)
172175

173-
return ForceFieldTaskDocument.from_ase_compatible_result(
176+
return ForceFieldTaskDocument.from_ase_compatible_result_forcefield(
174177
str(self.force_field_name), # make mypy happy
175178
ase_result,
176179
self.steps,
@@ -212,7 +215,7 @@ class ForceFieldStaticMaker(ForceFieldRelaxMaker):
212215
calculator_kwargs : dict
213216
Keyword arguments that will get passed to the ASE calculator.
214217
task_document_kwargs : dict (deprecated)
215-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
218+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
216219
"""
217220

218221
name: str = "Force field static"
@@ -255,7 +258,7 @@ class CHGNetRelaxMaker(ForceFieldRelaxMaker):
255258
calculator_kwargs : dict
256259
Keyword arguments that will get passed to the ASE calculator.
257260
task_document_kwargs : dict (deprecated)
258-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
261+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
259262
"""
260263

261264
name: str = f"{MLFF.CHGNet} relax"
@@ -291,7 +294,7 @@ class CHGNetStaticMaker(ForceFieldStaticMaker):
291294
calculator_kwargs : dict
292295
Keyword arguments that will get passed to the ASE calculator.
293296
task_document_kwargs : dict (deprecated)
294-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
297+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
295298
"""
296299

297300
name: str = f"{MLFF.CHGNet} static"
@@ -334,7 +337,7 @@ class M3GNetRelaxMaker(ForceFieldRelaxMaker):
334337
calculator_kwargs : dict
335338
Keyword arguments that will get passed to the ASE calculator.
336339
task_document_kwargs : dict (deprecated)
337-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
340+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
338341
"""
339342

340343
name: str = f"{MLFF.M3GNet} relax"
@@ -372,7 +375,7 @@ class M3GNetStaticMaker(ForceFieldStaticMaker):
372375
calculator_kwargs : dict
373376
Keyword arguments that will get passed to the ASE calculator.
374377
task_document_kwargs : dict (deprecated)
375-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
378+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
376379
"""
377380

378381
name: str = f"{MLFF.M3GNet} static"
@@ -415,7 +418,7 @@ class NEPRelaxMaker(ForceFieldRelaxMaker):
415418
calculator_kwargs : dict
416419
Keyword arguments that will get passed to the ASE calculator.
417420
task_document_kwargs : dict (deprecated)
418-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
421+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
419422
"""
420423

421424
name: str = f"{MLFF.NEP} relax"
@@ -451,7 +454,7 @@ class NEPStaticMaker(ForceFieldStaticMaker):
451454
calculator_kwargs : dict
452455
Keyword arguments that will get passed to the ASE calculator.
453456
task_document_kwargs : dict (deprecated)
454-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
457+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
455458
"""
456459

457460
name: str = f"{MLFF.NEP} static"
@@ -494,7 +497,7 @@ class NequipRelaxMaker(ForceFieldRelaxMaker):
494497
calculator_kwargs : dict
495498
Keyword arguments that will get passed to the ASE calculator.
496499
task_document_kwargs : dict (deprecated)
497-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
500+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
498501
"""
499502

500503
name: str = f"{MLFF.Nequip} relax"
@@ -529,7 +532,7 @@ class NequipStaticMaker(ForceFieldStaticMaker):
529532
calculator_kwargs : dict
530533
Keyword arguments that will get passed to the ASE calculator.
531534
task_document_kwargs : dict (deprecated)
532-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
535+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
533536
"""
534537

535538
name: str = f"{MLFF.Nequip} static"
@@ -576,7 +579,7 @@ class MACERelaxMaker(ForceFieldRelaxMaker):
576579
trained for Matbench Discovery on the MPtrj dataset available at
577580
https://figshare.com/articles/dataset/22715158.
578581
task_document_kwargs : dict (deprecated)
579-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
582+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
580583
"""
581584

582585
name: str = f"{MLFF.MACE_MP_0} relax"
@@ -616,7 +619,7 @@ class MACEStaticMaker(ForceFieldStaticMaker):
616619
trained for Matbench Discovery on the MPtrj dataset available at
617620
https://figshare.com/articles/dataset/22715158.
618621
task_document_kwargs : dict (deprecated)
619-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
622+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
620623
"""
621624

622625
name: str = f"{MLFF.MACE_MP_0} static"
@@ -665,7 +668,7 @@ class SevenNetRelaxMaker(ForceFieldRelaxMaker):
665668
trained for Matbench Discovery on the MPtrj dataset available at
666669
https://figshare.com/articles/dataset/22715158.
667670
task_document_kwargs : dict (deprecated)
668-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
671+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
669672
"""
670673

671674
name: str = f"{MLFF.SevenNet} relax"
@@ -707,7 +710,7 @@ class SevenNetStaticMaker(ForceFieldStaticMaker):
707710
trained for Matbench Discovery on the MPtrj dataset available at
708711
https://figshare.com/articles/dataset/22715158.
709712
task_document_kwargs : dict (deprecated)
710-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
713+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
711714
"""
712715

713716
name: str = f"{MLFF.SevenNet} static"
@@ -747,7 +750,7 @@ class GAPRelaxMaker(ForceFieldRelaxMaker):
747750
calculator_kwargs : dict
748751
Keyword arguments that will get passed to the ASE calculator.
749752
task_document_kwargs : dict (deprecated)
750-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
753+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
751754
"""
752755

753756
name: str = f"{MLFF.GAP} relax"
@@ -783,7 +786,7 @@ class GAPStaticMaker(ForceFieldStaticMaker):
783786
calculator_kwargs : dict
784787
Keyword arguments that will get passed to the ASE calculator.
785788
task_document_kwargs : dict (deprecated)
786-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
789+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
787790
"""
788791

789792
name: str = f"{MLFF.GAP} static"

src/atomate2/forcefields/md.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,18 @@
1515
_DEFAULT_CALCULATOR_KWARGS,
1616
_FORCEFIELD_DATA_OBJECTS,
1717
)
18-
from atomate2.forcefields.schemas import ForceFieldTaskDocument
18+
from atomate2.forcefields.schemas import (
19+
ForceFieldMoleculeTaskDocument,
20+
ForceFieldStructureTaskDocument,
21+
ForceFieldTaskDocument,
22+
)
1923
from atomate2.forcefields.utils import ase_calculator, revert_default_dtype
2024

2125
if TYPE_CHECKING:
2226
from pathlib import Path
2327

2428
from ase.calculators.calculator import Calculator
25-
from pymatgen.core.structure import Structure
29+
from pymatgen.core.structure import Molecule, Structure
2630

2731

2832
@dataclass
@@ -126,19 +130,18 @@ def __post_init__(self) -> None:
126130

127131
@job(
128132
data=[*_FORCEFIELD_DATA_OBJECTS, "ionic_steps"],
129-
output_schema=ForceFieldTaskDocument,
130133
)
131134
def make(
132135
self,
133-
structure: Structure,
136+
structure: Molecule | Structure,
134137
prev_dir: str | Path | None = None,
135-
) -> ForceFieldTaskDocument:
138+
) -> ForceFieldStructureTaskDocument | ForceFieldMoleculeTaskDocument:
136139
"""
137140
Perform MD on a structure using forcefields and jobflow.
138141
139142
Parameters
140143
----------
141-
structure: .Structure
144+
structure: .Structure or Molecule
142145
pymatgen structure.
143146
prev_dir : str or Path or None
144147
A previous calculation directory to copy output files from. Unused, just
@@ -156,7 +159,7 @@ def make(
156159
stacklevel=1,
157160
)
158161

159-
return ForceFieldTaskDocument.from_ase_compatible_result(
162+
return ForceFieldTaskDocument.from_ase_compatible_result_forcefield(
160163
str(self.force_field_name), # make mypy happy
161164
md_result,
162165
relax_cell=(self.ensemble == MDEnsemble.npt),

0 commit comments

Comments
 (0)