Skip to content

Commit

Permalink
Make CompetingPhasesAnalyzer object MSONable and test
Browse files Browse the repository at this point in the history
  • Loading branch information
kavanase committed Jan 20, 2025
1 parent d5c412e commit d57500e
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 8 deletions.
64 changes: 57 additions & 7 deletions doped/chemical_potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from labellines import labelLines
from matplotlib import colors
from matplotlib.ticker import AutoMinorLocator
from monty.json import MSONable
from monty.serialization import loadfn
from pymatgen.analysis.chempot_diagram import ChemicalPotentialDiagram
from pymatgen.analysis.phase_diagram import PDEntry, PhaseDiagram
Expand Down Expand Up @@ -1835,8 +1836,7 @@ def get_doped_chempots_from_entries(
return _round_floats(chempots, 4)


class CompetingPhasesAnalyzer:
# TODO: Make MSONable!
class CompetingPhasesAnalyzer(MSONable):
def __init__(
self,
composition: Union[str, Composition],
Expand Down Expand Up @@ -1946,10 +1946,13 @@ def __init__(
f"got type {type(entries)} instead!"
)

self.vasprun_paths: list[str] = []
self.parsed_folders: list[str] = []

if isinstance(entries, (str, PathLike)) or isinstance(entries[0], (str, PathLike)):
self._from_vaspruns(path=entries, subfolder=subfolder, verbose=verbose, processes=processes)
else:
self._from_entries(self.entries)
self._from_entries(entries)

def _from_entries(self, entries: list[Union[ComputedEntry, ComputedStructureEntry]]):
r"""
Expand Down Expand Up @@ -1996,7 +1999,8 @@ def _from_entries(self, entries: list[Union[ComputedEntry, ComputedStructureEntr
)
self.elements += self.extrinsic_elements

# TODO: Warn if any missing elemental phases and remove any entries with them in composition
# TODO: Warn if any missing elemental phases and remove any entries with them in composition,
# and remove from element lists?
# set(Composition(d["Formula"]).elements).issubset(self.composition.elements)
# or (
# extrinsic_elements
Expand Down Expand Up @@ -2135,9 +2139,7 @@ def _from_vaspruns(
# subfolders are found) - see how this is done in DefectsParser in analysis.py
# TODO: Add check for matching INCAR and POTCARs from these calcs - can use code/functions from
# analysis.py for this
self.vasprun_paths = []
skipped_folders = []
self.parsed_folders = []

if isinstance(path, list): # if path is just a list of all competing phases
for p in path:
Expand Down Expand Up @@ -2299,6 +2301,54 @@ def _estimate_uncompressed_vasprun_size(vasprun_path: PathLike) -> float:

return self._from_entries(self.entries)

def as_dict(self) -> dict:
"""
Returns:
JSON-serializable dict representation of ``CompetingPhasesAnalyzer``.
"""
return {
"@module": self.__class__.__module__,
"@class": self.__class__.__name__,
"composition": self.composition.as_dict(),
"entries": self.entries,
"unstable_host": self.unstable_host,
"bulk_entry": self.bulk_entry,
"parsed_folders": self.parsed_folders,
"vasprun_paths": self.vasprun_paths,
}

@classmethod
def from_dict(cls, d: dict) -> "CompetingPhasesAnalyzer":
"""
Reconstitute a ``CompetingPhasesAnalyzer`` object from a dict
representation created using ``as_dict()``.
Args:
d (dict): dict representation of ``CompetingPhasesAnalyzer``.
Returns:
``CompetingPhasesAnalyzer`` object
"""
entries = d["entries"]

def get_entry(entry_or_dict):
if isinstance(entry_or_dict, dict):
try:
return ComputedStructureEntry.from_dict(entry_or_dict)
except Exception:
return ComputedEntry.from_dict(entry_or_dict)
return entry_or_dict

cpa = cls(
composition=Composition.from_dict(d["composition"]),
entries=[get_entry(entry) for entry in entries],
)
cpa.unstable_host = d.get("unstable_host", cpa.unstable_host)
cpa.bulk_entry = get_entry(d.get("bulk_entry", cpa.bulk_entry))
cpa.parsed_folders = d.get("parsed_folders", cpa.parsed_folders)
cpa.vasprun_paths = d.get("vasprun_paths", cpa.vasprun_paths)
return cpa

def get_formation_energy_df(
self,
prune_polymorphs: bool = False,
Expand Down Expand Up @@ -2949,7 +2999,7 @@ def __repr__(self):
"""
formula = self.composition.get_reduced_formula_and_factor(iupac_ordering=True)[0]
properties, methods = _doped_obj_properties_methods(self)
joined_entry_list = "\n".join([entry.data["doped_name"] for entry in self.entries])
joined_entry_list = "\n".join([entry.data.get("doped_name", "N/A") for entry in self.entries])
return (
f"doped CompetingPhasesAnalyzer for bulk composition {formula} with {len(self.entries)} "
f"entries (in self.entries):\n{joined_entry_list}\n\n"
Expand Down
40 changes: 39 additions & 1 deletion tests/test_chemical_potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import numpy as np
import pandas as pd
import pytest
from monty.serialization import loadfn
from monty.serialization import dumpfn, loadfn
from pymatgen.core.composition import Composition
from pymatgen.core.structure import Structure
from test_analysis import if_present_rm
Expand Down Expand Up @@ -1104,6 +1104,44 @@ def test_repr(self):
assert "Available attributes:" in repr(la_cpa)
assert "Available methods:" in repr(la_cpa)

def _compare_cpas(self, cpa_a, cpa_b):
for attr in [
"entries",
"chempots",
"extrinsic_elements",
"elements",
"vasprun_paths",
"parsed_folders",
"unstable_host",
"bulk_entry",
"composition",
"phase_diagram",
"chempots_df",
]:
print(f"Checking {attr}")
if attr == "chempots_df":
assert cpa_a.chempots_df.equals(cpa_b.chempots_df)
elif attr == "phase_diagram":
assert cpa_a.phase_diagram.entries == cpa_b.phase_diagram.entries
else:
assert getattr(cpa_a, attr) == getattr(cpa_b, attr)

def _general_cpa_check(self, cpa):
cpa_dict = cpa.as_dict()
cpa_from_dict = chemical_potentials.CompetingPhasesAnalyzer.from_dict(cpa_dict)
self._compare_cpas(cpa, cpa_from_dict)

dumpfn(cpa_dict, "cpa.json")
reloaded_cpa = loadfn("cpa.json")
self._compare_cpas(cpa, reloaded_cpa)

def test_general_cpa_reloading(self):
cpa = chemical_potentials.CompetingPhasesAnalyzer(self.stable_system, self.zro2_path)
self._general_cpa_check(cpa)

la_cpa = chemical_potentials.CompetingPhasesAnalyzer(self.stable_system, self.la_zro2_path)
self._general_cpa_check(la_cpa)


class TestChemicalPotentialGrid(unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit d57500e

Please sign in to comment.