diff --git a/espei/datasets.py b/espei/datasets.py index ffa27f98..0bf65eaf 100644 --- a/espei/datasets.py +++ b/espei/datasets.py @@ -1,259 +1,336 @@ -import fnmatch, json, os -from typing import Any, Dict, List - +from typing import Any, Literal, Union, TypeAlias, Self +import warnings +from pydantic import BaseModel, Field, model_validator, field_validator import numpy as np +import fnmatch +import json +import os from tinydb.storages import MemoryStorage from tinydb import where from espei.utils import PickleableTinyDB -# Create a type -Dataset = Dict[str, Any] +__all__ = [ + # Models + "Dataset", + "BroadcastSinglePhaseFixedConfigurationDataset", + "ActivityPropertyDataset", + "EquilibriumPropertyDataset", + "ZPFDataset", + + # Errors (when validating models) + "DatasetError", + + # User-facing API + "load_datasets", + "recursive_glob", + "apply_tags", + "to_Dataset", + + # Deprecated + "check_dataset", + "clean_dataset", +] + + +# Type aliases - used to clarify intent +# e.g. when we want a ComponentName rather than a str (even though that's what it is) +ComponentName: TypeAlias = str +PhaseName: TypeAlias = str +PhaseCompositionType: TypeAlias = Union[ + tuple[PhaseName, list[ComponentName], list[float | None]], # The usual definition ["LIQUID", ["B"], [0.5]] + tuple[PhaseName, list[ComponentName], list[float | None], bool] # Handle the disordered flag +] +PhaseRegionType: TypeAlias = list[PhaseCompositionType] + class DatasetError(Exception): """Exception raised when datasets are invalid.""" pass -def recursive_map(f, x): - """ - map, but over nested lists +# Used by BroadcastSinglePhaseFixedConfigurationDataset to define internal DOF +class Solver(BaseModel): + mode: Literal["manual"] = Field(default="manual") + sublattice_site_ratios: list[float] + # TODO: migrate to list[list[list[float]]] + sublattice_configurations: list[list[ComponentName | list[ComponentName]]] + sublattice_occupancies: list[list[float | list[float]]] | None = Field(default=None) - Parameters - ---------- - f : callable - Function to apply to x - x : list or value - Value passed to v - Returns - ------- - list or value - """ - if isinstance(x, list): - if [isinstance(xx, list) for xx in x]: - # we got a nested list - return [recursive_map(f, xx) for xx in x] - else: - # it's a list with some values inside - return list(map(f, x)) - else: - # not a list, probably just a singular value - return f(x) +# Activity dataset special case reference state +class ActivityDataReferenceState(BaseModel): + phases: list[PhaseName] = Field(min_length=1) + conditions: dict[str, float] -def check_dataset(dataset: Dataset): - """Ensure that the dataset is valid and consistent. +# More general reference states for equilibrium property datasets +class ReferenceStates(BaseModel): + phase: PhaseName + fixed_state_variables: dict[str, float] | None = Field(default=None, description="Fixed potentials for the reference state", examples=[{"T": 298.15, "P": 101325}]) - Currently supports the following validation checks: - * data shape is valid - * phases and components used match phases and components entered - * individual shapes of keys, such as ZPF, sublattice configs and site ratios - Planned validation checks: - * all required keys are present +class Dataset(BaseModel): + components: list[ComponentName] = Field(min_length=1) + phases: list[PhaseName] = Field(min_length=1) + conditions: dict[str, float | list[float]] + output: str + # TODO: weights - Note that this follows some of the implicit assumptions in ESPEI at the time - of writing, such that conditions are only P, T, configs for single phase and - essentially only T for ZPF data. + # Control + disabled: bool = Field(default=False) + tags: list[str] = Field(default_factory=list) - Parameters - ---------- - dataset : Dataset - Dictionary of the standard ESPEI dataset. + # Metadata + reference: str = Field(default="") + bibtex: str = Field(default="") + dataset_author: str = Field(default="") - Returns - ------- - None - Raises - ------ - DatasetError - If an error is found in the dataset - """ - is_equilibrium = 'solver' not in dataset.keys() and dataset['output'] != 'ZPF' - is_activity = dataset['output'].startswith('ACR') - is_zpf = dataset['output'] == 'ZPF' - is_single_phase = 'solver' in dataset.keys() - if not any((is_equilibrium, is_single_phase, is_zpf)): - raise DatasetError("Cannot determine type of dataset") - components = dataset['components'] - conditions = dataset['conditions'] - values = dataset['values'] - phases = dataset['phases'] - if is_single_phase: - solver = dataset['solver'] - sublattice_configurations = solver['sublattice_configurations'] - sublattice_site_ratios = solver['sublattice_site_ratios'] - sublattice_occupancies = solver.get('sublattice_occupancies', None) +class BroadcastSinglePhaseFixedConfigurationDataset(Dataset): + phases: list[PhaseName] = Field(min_length=1, max_length=1) + values: list[list[list[float]]] + solver: Solver + conditions: dict[str, float | list[float]] + excluded_model_contributions: list[str] = Field(default_factory=list) + + @model_validator(mode="after") + def validate_components_entered_match_components_used(self) -> Self: + components_entered = set(self.components) + components_used = set() + for config in self.solver.sublattice_configurations: + for subl in config: + if isinstance(subl, list): + components_used.update(set(subl)) + else: + components_used.add(subl) + # Don't count vacancies as a component here + components_difference = components_entered.symmetric_difference(components_used) - {"VA"} + if len(components_difference) != 0: + raise DatasetError(f'Components entered {components_entered} do not match components used {components_used} ({components_difference} different).') + return self + + @model_validator(mode="after") + def validate_condition_value_shape_agreement(self) -> Self: + values_shape = np.array(self.values).shape + num_configs = len(self.solver.sublattice_configurations) + num_temperature = np.atleast_1d(self.conditions["T"]).size + num_pressure = np.atleast_1d(self.conditions["P"]).size + conditions_shape = (num_pressure, num_temperature, num_configs) + if conditions_shape != values_shape: + raise DatasetError(f'Shape of conditions (P, T, configs): {conditions_shape} does not match the shape of the values {values_shape}.') + return self + + @model_validator(mode="after") + def validate_configuration_occupancy_shape_agreement(self) -> Self: + sublattice_configurations = self.solver.sublattice_configurations + sublattice_site_ratios = self.solver.sublattice_site_ratios + sublattice_occupancies = self.solver.sublattice_occupancies # check for mixing is_mixing = any([any([isinstance(subl, list) for subl in config]) for config in sublattice_configurations]) # pad the values of sublattice occupancies if there is no mixing + # just for the purposes of checking validity if sublattice_occupancies is None and not is_mixing: sublattice_occupancies = [None]*len(sublattice_configurations) elif sublattice_occupancies is None: - raise DatasetError('At least one sublattice in the following sublattice configurations is mixing, but the "sublattice_occupancies" key is empty: {}'.format(sublattice_configurations)) - if is_equilibrium: - conditions = dataset['conditions'] + raise DatasetError(f'At least one sublattice in the following sublattice configurations is mixing, but the "sublattice_occupancies" key is empty: {sublattice_configurations}') + + # check that the site ratios are valid as well as site occupancies, if applicable + nconfigs = len(sublattice_configurations) + noccupancies = len(sublattice_occupancies) + if nconfigs != noccupancies: + raise DatasetError(f'Number of sublattice configurations ({nconfigs}) does not match the number of sublattice occupancies ({noccupancies})') + for configuration, occupancy in zip(sublattice_configurations, sublattice_occupancies): + if len(configuration) != len(sublattice_site_ratios): + raise DatasetError(f'Sublattice configuration {configuration} and sublattice site ratio {sublattice_site_ratios} describe different numbers of sublattices ({len(configuration)} and {len(sublattice_site_ratios)}).') + if is_mixing: + configuration_shape = tuple(len(sl) if isinstance(sl, list) else 1 for sl in configuration) + occupancy_shape = tuple(len(sl) if isinstance(sl, list) else 1 for sl in occupancy) + if configuration_shape != occupancy_shape: + raise DatasetError(f'The shape of sublattice configuration {configuration} ({configuration_shape}) does not match the shape of occupancies {occupancy} ({occupancy_shape})') + # check that sublattice interactions are in sorted. Related to sorting in espei.core_utils.get_samples + for subl in configuration: + if isinstance(subl, (list, tuple)) and sorted(subl) != subl: + raise DatasetError(f'Sublattice {subl} in configuration {configuration} is must be sorted in alphabetic order ({sorted(subl)})') + return self + + +# TODO: refactor ActivityPropertyDataset to merge with EquilibriumPropertyDataset +# The validator functions are exactly duplicated in EquilibriumPropertyDataset +# The duplication simplifies the implementation since the activity special case is +# ultimately meant to be removed once activity is a PyCalphad Workspace property +class ActivityPropertyDataset(Dataset): + values: list[list[list[float]]] + reference_state: ActivityDataReferenceState + + @model_validator(mode="after") + def validate_condition_value_shape_agreement(self) -> Self: + conditions = self.conditions comp_conditions = {k: v for k, v in conditions.items() if k.startswith('X_')} - if is_activity: - ref_state = dataset['reference_state'] - elif is_equilibrium: - for el, vals in dataset.get('reference_states', {}).items(): - if 'phase' not in vals: - raise DatasetError(f'Reference state for element {el} must define the `phase` key with the reference phase name.') - - # check that the shape of conditions match the values - num_pressure = np.atleast_1d(conditions['P']).size - num_temperature = np.atleast_1d(conditions['T']).size - if is_equilibrium: - values_shape = np.array(values).shape + num_temperature = np.atleast_1d(self.conditions["T"]).size + num_pressure = np.atleast_1d(self.conditions["P"]).size # check each composition condition is the same shape - num_x_conds = [len(v) for _, v in comp_conditions.items()] + num_x_conds = [np.atleast_1d(vals).size for _, vals in comp_conditions.items()] if num_x_conds.count(num_x_conds[0]) != len(num_x_conds): - raise DatasetError('All compositions in conditions are not the same shape. Note that conditions cannot be broadcast. Composition conditions are {}'.format(comp_conditions)) + raise DatasetError(f'All compositions in conditions are not the same shape. Note that conditions cannot be broadcast. Composition conditions are {comp_conditions}') conditions_shape = (num_pressure, num_temperature, num_x_conds[0]) + values_shape = np.array(self.values).shape if conditions_shape != values_shape: - raise DatasetError('Shape of conditions (P, T, compositions): {} does not match the shape of the values {}.'.format(conditions_shape, values_shape)) - elif is_single_phase: - values_shape = np.array(values).shape - num_configs = len(dataset['solver']['sublattice_configurations']) - conditions_shape = (num_pressure, num_temperature, num_configs) + raise DatasetError(f'Shape of conditions (P, T, compositions): {conditions_shape} does not match the shape of the values {values_shape}.') + return self + + @model_validator(mode="after") + def validate_components_entered_match_components_used(self) -> Self: + conditions = self.conditions + comp_conditions = {ky: vl for ky, vl in conditions.items() if ky.startswith('X_')} + components_entered = set(self.components) + components_used = set() + components_used.update({c.split('_')[1] for c in comp_conditions.keys()}) + if not components_entered.issuperset(components_used): + raise DatasetError(f"Components were used as conditions that are not present in the specified components: {components_used - components_entered}.") + independent_components = components_entered - components_used - {'VA'} + if len(independent_components) != 1: + raise DatasetError(f"Degree of freedom error: expected 1 independent component, got {len(independent_components)} for entered components {components_entered} and {components_used} used in the conditions.") + return self + + +class EquilibriumPropertyDataset(Dataset): + values: list[list[list[float]]] + reference_states: dict[ComponentName, ReferenceStates] | None = Field(default=None) + + @model_validator(mode="after") + def validate_condition_value_shape_agreement(self) -> Self: + conditions = self.conditions + comp_conditions = {k: v for k, v in conditions.items() if k.startswith('X_')} + num_temperature = np.atleast_1d(self.conditions["T"]).size + num_pressure = np.atleast_1d(self.conditions["P"]).size + # check each composition condition is the same shape + num_x_conds = [np.atleast_1d(vals).size for _, vals in comp_conditions.items()] + if num_x_conds.count(num_x_conds[0]) != len(num_x_conds): + raise DatasetError(f'All compositions in conditions are not the same shape. Note that conditions cannot be broadcast. Composition conditions are {comp_conditions}') + conditions_shape = (num_pressure, num_temperature, num_x_conds[0]) + values_shape = np.array(self.values).shape if conditions_shape != values_shape: - raise DatasetError('Shape of conditions (P, T, configs): {} does not match the shape of the values {}.'.format(conditions_shape, values_shape)) - elif is_zpf: - values_shape = (len(values)) - conditions_shape = (num_temperature) + raise DatasetError(f'Shape of conditions (P, T, compositions): {conditions_shape} does not match the shape of the values {values_shape}.') + return self + + @model_validator(mode="after") + def validate_components_entered_match_components_used(self) -> Self: + conditions = self.conditions + comp_conditions = {ky: vl for ky, vl in conditions.items() if ky.startswith('X_')} + components_entered = set(self.components) + components_used = set() + components_used.update({c.split('_')[1] for c in comp_conditions.keys()}) + if not components_entered.issuperset(components_used): + raise DatasetError(f"Components were used as conditions that are not present in the specified components: {components_used - components_entered}.") + independent_components = components_entered - components_used - {'VA'} + if len(independent_components) != 1: + raise DatasetError(f"Degree of freedom error: expected 1 independent component, got {len(independent_components)} for entered components {components_entered} and {components_used} used in the conditions.") + return self + + @model_validator(mode="after") + def validate_reference_state_fully_specified_if_used(self) -> Self: + """If there is a reference state specified, the components in the reference state must match the dataset components""" + components_entered = set(self.components) - {"VA"} + if self.reference_states is not None: + reference_state_components = set(self.reference_states.keys()) - {"VA"} + if components_entered != reference_state_components: + raise DatasetError(f"If used, reference states in equilibrium property must define a reference state for all components in the calculation. Got {components_entered} entered components and {reference_state_components} in the reference states ({components_entered.symmetric_difference(reference_state_components)} non-matching).") + return self + + +class ZPFDataset(Dataset): + values: list[PhaseRegionType] + + @model_validator(mode="after") + def validate_condition_value_shape_agreement(self) -> Self: + values_shape = (len(self.values),) + num_temperature = np.atleast_1d(self.conditions["T"]).size + num_pressure = np.atleast_1d(self.conditions["P"]).size + if num_pressure != 1: + raise DatasetError("Non-scalar pressures are not currently supported") + conditions_shape = (num_temperature,) if conditions_shape != values_shape: - raise DatasetError('Shape of conditions (T): {} does not match the shape of the values {}.'.format(conditions_shape, values_shape)) + raise DatasetError("Shape of conditions (T): {} does not match the shape of the values {}.".format(conditions_shape, values_shape)) + return self - # check that all of the correct phases are present - if is_zpf: - phases_entered = set(phases) + @model_validator(mode="after") + def validate_phases_entered_match_phases_used(self) -> Self: + phases_entered = set(self.phases) phases_used = set() - for zpf in values: - for tieline in zpf: - phases_used.add(tieline[0]) + for phase_region in self.values: + for phase_composition in phase_region: + phases_used.add(phase_composition[0]) if len(phases_entered - phases_used) > 0: - raise DatasetError('Phases entered {} do not match phases used {}.'.format(phases_entered, phases_used)) - - # check that all of the components used match the components entered - components_entered = set(components) - components_used = set() - if is_single_phase: - for config in sublattice_configurations: - for sl in config: - if isinstance(sl, list): - components_used.update(set(sl)) - else: - components_used.add(sl) - comp_dof = 0 - elif is_equilibrium: - components_used.update({c.split('_')[1] for c in comp_conditions.keys()}) - # mass balance of components - comp_dof = len(comp_conditions.keys()) - elif is_zpf: - for zpf in values: - for tieline in zpf: - tieline_comps = set(tieline[1]) - components_used.update(tieline_comps) - if len(components_entered - tieline_comps - {'VA'}) != 1: - raise DatasetError('Degree of freedom error for entered components {} in tieline {} of ZPF {}'.format(components_entered, tieline, zpf)) - # handle special case of mass balance in ZPFs - comp_dof = 1 - if len(components_entered - components_used - {'VA'}) > comp_dof or len(components_used - components_entered) > 0: - raise DatasetError('Components entered {} do not match components used {}.'.format(components_entered, components_used)) - - # check that the ZPF values are formatted properly - if is_zpf: - for zpf in values: - for tieline in zpf: - phase = tieline[0] - component_list = tieline[1] - mole_fraction_list = tieline[2] + raise DatasetError("Phases entered {} do not match phases used {}.".format(phases_entered, phases_used)) + return self + + @model_validator(mode="after") + def validate_components_entered_match_components_used(self) -> Self: + components_entered = set(self.components) + for i, phase_region in enumerate(self.values): + for j, phase_compositions in enumerate(phase_region): + phase_composition_components = set(phase_compositions[1]) + if not components_entered.issuperset(phase_composition_components): + raise DatasetError("Components were used in phase region {} ({}) for phase composition {} ({}) that are not specified as components in the dataset ()", i,phase_region, j, phase_compositions, components_entered) + independent_components = components_entered - phase_composition_components - {'VA'} + if len(independent_components) != 1: + raise DatasetError('Degree of freedom error: expected 1 independent component, got {} for entered components {} and phase composition components {} in phase region {} ({}) for phase composition {} ({})'.format(len(independent_components), components_entered, phase_composition_components, i, phase_region, j, phase_compositions)) + return self + + @field_validator("values", mode="after") + @classmethod + def validate_phase_compositions(cls, values: list[PhaseRegionType]) -> list[PhaseRegionType]: + for i, phase_region in enumerate(values): + for j, phase_composition in enumerate(phase_region): + phase = phase_composition[0] + component_list = phase_composition[1] + mole_fraction_list = phase_composition[2] # check that the phase is a string, components a list of strings, # and the fractions are a list of float if not isinstance(phase, str): - raise DatasetError('The first element in the tieline {} for the ZPF point {} should be a string. Instead it is a {} of value {}'.format(tieline, zpf, type(phase), phase)) + raise DatasetError('The first element in phase composition {} ({}) for phase region {} ({}) should be a string. Instead it is a {} of value {}'.format(j, phase_composition, i, phase_region, type(phase), phase)) if not all([isinstance(comp, str) for comp in component_list]): - raise DatasetError('The second element in the tieline {} for the ZPF point {} should be a list of strings. Instead it is a {} of value {}'.format(tieline, zpf, type(component_list), component_list)) + raise DatasetError('The second element in phase composition {} ({}) for phase region {} ({}) should be a list of strings. Instead it is a {} of value {}'.format(j, phase_composition, i, phase_region, type(component_list), component_list)) if not all([(isinstance(mole_frac, (int, float)) or mole_frac is None) for mole_frac in mole_fraction_list]): - raise DatasetError('The last element in the tieline {} for the ZPF point {} should be a list of numbers. Instead it is a {} of value {}'.format(tieline, zpf, type(mole_fraction_list), mole_fraction_list)) + raise DatasetError('The last element in phase composition {} ({}) for phase region {} ({}) should be a list of numbers. Instead it is a {} of value {}'.format(j, phase_composition, i, phase_region, type(mole_fraction_list), mole_fraction_list)) # check that the shape of components list and mole fractions list is the same if len(component_list) != len(mole_fraction_list): - raise DatasetError('The length of the components list and mole fractions list in tieline {} for the ZPF point {} should be the same.'.format(tieline, zpf)) + raise DatasetError('The length of the components list and mole fractions list in phase composition {} ({}) for phase region {} ({}) should be the same.'.format(j, phase_composition, i, phase_region)) # check that all mole fractions are less than one mf_sum = np.nansum(np.array(mole_fraction_list, dtype=np.float64)) if any([mf is not None for mf in mole_fraction_list]) and mf_sum > 1.0: - raise DatasetError('Mole fractions for tieline {} for the ZPF point {} sum to greater than one.'.format(tieline, zpf)) + raise DatasetError('Mole fractions for phase composition {} ({}) for phase region {} ({}) sum to greater than one.'.format(j, phase_composition, i, phase_region)) + if any([(mf is not None) and (mf < 0.0) for mf in mole_fraction_list]): + raise DatasetError('Got unallowed negative mole fraction for phase composition {} ({}) for phase region {} ({}).'.format(j, phase_composition, i, phase_region)) + return values - # check that the site ratios are valid as well as site occupancies, if applicable - if is_single_phase: - nconfigs = len(sublattice_configurations) - noccupancies = len(sublattice_occupancies) - if nconfigs != noccupancies: - raise DatasetError('Number of sublattice configurations ({}) does not match the number of sublattice occupancies ({})'.format(nconfigs, noccupancies)) - for configuration, occupancy in zip(sublattice_configurations, sublattice_occupancies): - if len(configuration) != len(sublattice_site_ratios): - raise DatasetError('Sublattice configuration {} and sublattice site ratio {} describe different numbers of sublattices ({} and {}).'.format(configuration, sublattice_site_ratios, len(configuration), len(sublattice_site_ratios))) - if is_mixing: - configuration_shape = tuple(len(sl) if isinstance(sl, list) else 1 for sl in configuration) - occupancy_shape = tuple(len(sl) if isinstance(sl, list) else 1 for sl in occupancy) - if configuration_shape != occupancy_shape: - raise DatasetError('The shape of sublattice configuration {} ({}) does not match the shape of occupancies {} ({})'.format(configuration, configuration_shape, occupancy, occupancy_shape)) - # check that sublattice interactions are in sorted. Related to sorting in espei.core_utils.get_samples - for subl in configuration: - if isinstance(subl, (list, tuple)) and sorted(subl) != subl: - raise DatasetError('Sublattice {} in configuration {} is must be sorted in alphabetic order ({})'.format(subl, configuration, sorted(subl))) - -def clean_dataset(dataset: Dataset) -> Dataset: - """ - Clean an ESPEI dataset dictionary. +def to_Dataset(candidate: dict[str, Any]) -> Dataset: + """Return a validated Dataset object for a dataset dict. Raises if a validated dataset cannot be created. Parameters ---------- - dataset: Dataset - Dictionary of the standard ESPEI dataset. dataset : dic + candidate : dict[str, Any] + Dictionary describing an ESPEI dataset. Returns ------- Dataset - Modified dataset that has been cleaned - - Notes - ----- - Assumes a valid, checked dataset. Currently handles - * Converting expected numeric values to floats + Raises + ------ + DatasetError + If an error is found in the dataset """ - dataset["conditions"] = {k: recursive_map(float, v) for k, v in dataset["conditions"].items()} - - solver = dataset.get("solver") - if solver is not None: - solver["sublattice_site_ratios"] = recursive_map(float, solver["sublattice_site_ratios"]) - occupancies = solver.get("sublattice_occupancies") - if occupancies is not None: - solver["sublattice_occupancies"] = recursive_map(float, occupancies) - - if dataset["output"] == "ZPF": - values = dataset["values"] - new_values = [] - for tieline in values: - new_tieline = [] - for tieline_point in tieline: - if all([comp is None for comp in tieline_point[2]]): - # this is a null tieline point - new_tieline.append(tieline_point) - else: - new_tieline.append([tieline_point[0], tieline_point[1], recursive_map(float, tieline_point[2])]) - new_values.append(new_tieline) - dataset["values"] = new_values + if candidate["output"] == "ZPF": + return ZPFDataset.model_validate(candidate) + elif candidate['output'].startswith('ACR'): + return ActivityPropertyDataset.model_validate(candidate) + elif 'solver' in candidate.keys(): + return BroadcastSinglePhaseFixedConfigurationDataset.model_validate(candidate) else: - # values should be all numerical - dataset["values"] = recursive_map(float, dataset["values"]) - - return dataset + return EquilibriumPropertyDataset.model_validate(candidate) def apply_tags(datasets: PickleableTinyDB, tags): @@ -333,8 +410,7 @@ def load_datasets(dataset_filenames, include_disabled=False) -> PickleableTinyDB if not include_disabled and d.get('disabled', False): # The dataset is disabled and not included continue - check_dataset(d) - ds_database.insert(clean_dataset(d)) + ds_database.insert(to_Dataset(d).model_dump()) except ValueError as e: raise ValueError('JSON Error in {}: {}'.format(fname, e)) except DatasetError as e: @@ -364,3 +440,15 @@ def recursive_glob(start, pattern='*.json'): for filename in fnmatch.filter(filenames, pattern): matches.append(os.path.join(root, filename)) return sorted(matches) + + +def check_dataset(dataset: dict[str, Any]) -> dict[str, Any]: + """Ensure that the dataset is valid and consistent by round-tripping through pydantic.""" + warnings.warn("check_dataset is deprecated will be removed in ESPEI 0.11. Behavior has been migrated to the pydantic dataset implementations in espei.datasets.dataset_models. To get a Dataset object, use espei.datasets.to_Dataset.", DeprecationWarning) + return to_Dataset(dataset).model_dump() + + +def clean_dataset(dataset: dict[str, Any]) -> dict[str, Any]: + """Ensure that the dataset is valid and consistent by round-tripping through pydantic.""" + warnings.warn("clean_dataset is deprecated will be removed in ESPEI 0.11. Behavior has been migrated to the pydantic dataset implementations in espei.datasets.dataset_models. To get a Dataset object, use espei.datasets.to_Dataset.", DeprecationWarning) + return to_Dataset(dataset).model_dump() diff --git a/espei/error_functions/equilibrium_thermochemical_error.py b/espei/error_functions/equilibrium_thermochemical_error.py index bd96194c..6352811d 100644 --- a/espei/error_functions/equilibrium_thermochemical_error.py +++ b/espei/error_functions/equilibrium_thermochemical_error.py @@ -87,7 +87,7 @@ def build_eqpropdata(data: tinydb.database.Document, # Models are now modified in response to the data from this data # TODO: build a reference state MetaProperty with the reference state information, maybe just-in-time, below - if 'reference_states' in data: + if data.get("reference_states") is not None: property_output = output[:-1] if output.endswith('R') else output # unreferenced model property so we can tell shift_reference_state what to build. reference_states = [] for el, vals in data['reference_states'].items(): diff --git a/tests/test_core_utils.py b/tests/test_core_utils.py index 5c150f1e..9fd45dca 100644 --- a/tests/test_core_utils.py +++ b/tests/test_core_utils.py @@ -2,7 +2,6 @@ import tinydb from espei.core_utils import get_prop_data, filter_configurations, filter_temperatures, symmetry_filter, ravel_zpf_values -from espei.datasets import recursive_map from espei.sublattice_tools import recursive_tuplify from espei.utils import PickleableTinyDB, MemoryStorage from espei.error_functions.non_equilibrium_thermochemical_error import get_prop_samples @@ -55,19 +54,6 @@ def test_get_data_for_a_minimal_example(): assert desired_data['values'] == np.array([[[34720.0]]]) -def test_recursive_map(): - """Test that recursive map function works""" - - strings = [[["1.0"], ["5.5", "8.8"], ["10.7"]]] - floats = [[[1.0], [5.5, 8.8], [10.7]]] - - assert recursive_map(float, strings) == floats - assert recursive_map(str, floats) == strings - assert recursive_map(float, "1.234") == 1.234 - assert recursive_map(int, ["1", "2", "5"]) == [1, 2, 5] - assert recursive_map(float, ["1.0", ["0.5", "0.5"]]) == [1.0, [0.5, 0.5]] - - def test_get_prop_samples_ravels_correctly(): """get_prop_samples should ravel non-equilibrium thermochemical data correctly""" desired_data = [{ diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 65d914c2..437423bc 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,9 +1,11 @@ +from copy import deepcopy import pytest import numpy as np -from espei.datasets import DatasetError, check_dataset, clean_dataset, apply_tags +from espei.datasets import DatasetError, to_Dataset, apply_tags, BroadcastSinglePhaseFixedConfigurationDataset, ZPFDataset from .testing_data import CU_MG_EXP_ACTIVITY, CU_MG_DATASET_THERMOCHEMICAL_STRING_VALUES, CU_MG_DATASET_ZPF_STRING_VALUES, LI_SN_LIQUID_DATA, dataset_multi_valid_ternary from .fixtures import datasets_db +from pydantic import ValidationError dataset_single_valid = { "components": ["AL", "NI", "VA"], @@ -294,6 +296,21 @@ ], } +dataset_zpf_negative_mole_fraction = { + "components": ["AL", "NI", "VA"], + "phases": ["AL3NI2", "BCC_B2"], + "conditions": { + "P": 101325, + "T": [1348] + }, + "output": "ZPF", + "values": [ + [["AL3NI2", ["NI"], [-0.5]], ["BCC_B2", ["NI"], [None]]], # mole fraction is negative + ], +} + + + dataset_single_unsorted_interaction = { "components": ["AL", "NI", "VA"], "phases": ["BCC_B2"], @@ -349,91 +366,195 @@ def test_check_datasets_run_on_good_data(): """Passed valid datasets that should raise DatasetError.""" - check_dataset(dataset_single_valid) - check_dataset(dataset_multi_valid) - check_dataset(dataset_multi_valid_ternary) + to_Dataset(dataset_single_valid) + to_Dataset(dataset_multi_valid) + to_Dataset(dataset_multi_valid_ternary) def test_check_datasets_raises_on_misaligned_data(): """Passed datasets that have misaligned data and conditions should raise DatasetError.""" with pytest.raises(DatasetError): - check_dataset(dataset_single_misaligned) + to_Dataset(dataset_single_misaligned) with pytest.raises(DatasetError): - check_dataset(dataset_multi_misaligned) + to_Dataset(dataset_multi_misaligned) def test_check_datasets_raises_with_incorrect_zpf_phases(): """Passed datasets that have incorrect phases entered than used should raise.""" with pytest.raises(DatasetError): - check_dataset(dataset_multi_incorrect_phases) + to_Dataset(dataset_multi_incorrect_phases) def test_check_datasets_raises_with_incorrect_components(): """Passed datasets that have incorrect components entered vs. used should raise.""" with pytest.raises(DatasetError): - check_dataset(dataset_single_incorrect_components_overspecified) + to_Dataset(dataset_single_incorrect_components_overspecified) with pytest.raises(DatasetError): - check_dataset(dataset_single_incorrect_components_underspecified) + to_Dataset(dataset_single_incorrect_components_underspecified) + with pytest.raises(DatasetError): + to_Dataset(dataset_multi_incorrect_components_overspecified) + with pytest.raises(DatasetError): + to_Dataset(dataset_multi_incorrect_components_underspecified) + + # equilibrium datasets underspecified + ds_eq_underspecified = { + "components": ["NI"], + "phases": ["LIQUID"], + "conditions": { + "P": 101325, + "T": [1348, 1176, 977], + "X_NI": 0.5 + }, + "output": "HM", + "values": [[[-1000], [-900], [-800]]] + } with pytest.raises(DatasetError): - check_dataset(dataset_multi_incorrect_components_overspecified) + to_Dataset(ds_eq_underspecified) + + # equilibrium datasets overspecified + ds_eq_overspecified = { + "components": ["CU", "MG", "NI"], + "phases": ["LIQUID"], + "conditions": { + "P": 101325, + "T": [1348, 1176, 977], + "X_NI": 0.5 + }, + "output": "HM", + "values": [[[-1000], [-900], [-800]]] + } with pytest.raises(DatasetError): - check_dataset(dataset_multi_incorrect_components_underspecified) + to_Dataset(ds_eq_overspecified) def test_check_datasets_raises_with_malformed_zpf(): """Passed datasets that have malformed ZPF values should raise.""" + with pytest.raises((DatasetError, ValidationError)): + to_Dataset(dataset_multi_malformed_zpfs_components_not_list) with pytest.raises(DatasetError): - check_dataset(dataset_multi_malformed_zpfs_components_not_list) + to_Dataset(dataset_multi_malformed_zpfs_fractions_do_not_match_components) with pytest.raises(DatasetError): - check_dataset(dataset_multi_malformed_zpfs_fractions_do_not_match_components) - with pytest.raises(DatasetError): - check_dataset(dataset_multi_malformed_zpfs_components_do_not_match_fractions) + to_Dataset(dataset_multi_malformed_zpfs_components_do_not_match_fractions) def test_check_datasets_raises_with_malformed_sublattice_configurations(): """Passed datasets that have malformed ZPF values should raise.""" with pytest.raises(DatasetError): - check_dataset(dataset_single_malformed_site_occupancies) + to_Dataset(dataset_single_malformed_site_occupancies) + with pytest.raises(DatasetError): + to_Dataset(dataset_single_malformed_site_ratios) + + +def test_check_datasets_raises_with_equilibrium_conditions_and_values_shapes_mismatch(): + """Passed equilibrium datasets that have mismatched condition and values shapes should raise.""" + COND_VALS_SHAPE_GOOD = { + "components": ["CU", "MG"], + "phases": ["LIQUID"], + "conditions": {"P": [101325, 1e5], "T": [1400, 1500, 1600], "X_MG": [0.5, 0.6, 0.7, 0.8]}, + "reference_states": { + "CU": {"phase": "LIQUID"}, + "MG": {"phase": "LIQUID"} + }, + "output": "HMR", + "values": np.zeros((2, 3, 4)).tolist(), + "reference": "equilibrium thermochemical tests", + } + # Good shape should not raise + to_Dataset(COND_VALS_SHAPE_GOOD) + + COND_VALS_SHAPE_DISAGREEMENT_1_1_2 = { + "components": ["CU", "MG", "NI"], + "phases": ["LIQUID"], + "conditions": {"P": 101325, "T": [1400], "X_MG": [0.5, 0.6], "X_NI": [0.5, 0.6]}, + "reference_states": { + "CU": {"phase": "LIQUID"}, + "MG": {"phase": "LIQUID"}, + "NI": {"phase": "LIQUID"} + }, + "output": "HMR", + "values": [[[0]]], + "reference": "equilibrium thermochemical tests", + } + with pytest.raises(DatasetError): + to_Dataset(COND_VALS_SHAPE_DISAGREEMENT_1_1_2) + + COND_VALS_SHAPE_DISAGREEMENT_1_2_2 = { + "components": ["CU", "MG"], + "phases": ["LIQUID"], + "conditions": {"P": 101325, "T": [1400, 1500], "X_MG": [0.5, 0.6]}, + "reference_states": { + "CU": {"phase": "LIQUID"}, + "MG": {"phase": "LIQUID"} + }, + "output": "HMR", + "values": [[[0, 0]]], + "reference": "equilibrium thermochemical tests", + } + with pytest.raises(DatasetError): + to_Dataset(COND_VALS_SHAPE_DISAGREEMENT_1_2_2) + + # we don't broadcast over compositions, so composition conditions shapes need to match + MISMATCHED_COMPOSITION_CONDS = { + "components": ["CU", "MG", "NI"], + "phases": ["LIQUID"], + "conditions": {"P": 101325, "T": [1400], "X_MG": [0.5, 0.6], "X_NI": [0.5]}, + "reference_states": { + "CU": {"phase": "LIQUID"}, + "MG": {"phase": "LIQUID"}, + "NI": {"phase": "LIQUID"} + }, + "output": "HMR", + "values": [[[0, 0]]], + "reference": "equilibrium thermochemical tests", + } with pytest.raises(DatasetError): - check_dataset(dataset_single_malformed_site_ratios) + to_Dataset(MISMATCHED_COMPOSITION_CONDS) def test_check_datasets_works_on_activity_data(): """Passed activity datasets should work correctly.""" - check_dataset(CU_MG_EXP_ACTIVITY) + to_Dataset(CU_MG_EXP_ACTIVITY) def test_check_datasets_raises_with_zpf_fractions_greater_than_one(): """Passed datasets that have mole fractions greater than one should raise.""" with pytest.raises(DatasetError): - check_dataset(dataset_multi_mole_fractions_as_percents) + to_Dataset(dataset_multi_mole_fractions_as_percents) + + +def test_check_datasets_raises_with_negative_zpf_fractions(): + """Passed datasets that have negative mole fractions should raise.""" + with pytest.raises(DatasetError): + to_Dataset(dataset_zpf_negative_mole_fraction) def test_check_datasets_raises_with_unsorted_interactions(): """Passed datasets that have sublattice interactions not in sorted order should raise.""" with pytest.raises(DatasetError): - check_dataset(dataset_single_unsorted_interaction) + to_Dataset(dataset_single_unsorted_interaction) def test_datasets_convert_thermochemical_string_values_producing_correct_value(datasets_db): """Strings where floats are expected should give correct answers for thermochemical datasets""" - ds = clean_dataset(CU_MG_DATASET_THERMOCHEMICAL_STRING_VALUES) - assert np.issubdtype(np.array(ds['values']).dtype, np.number) - assert np.issubdtype(np.array(ds['conditions']['T']).dtype, np.number) - assert np.issubdtype(np.array(ds['conditions']['P']).dtype, np.number) + ds = to_Dataset(CU_MG_DATASET_THERMOCHEMICAL_STRING_VALUES) + assert isinstance(ds, BroadcastSinglePhaseFixedConfigurationDataset) + assert np.issubdtype(np.array(ds.values).dtype, np.number) + assert np.issubdtype(np.array(ds.conditions['T']).dtype, np.number) + assert np.issubdtype(np.array(ds.conditions['P']).dtype, np.number) def test_datasets_convert_zpf_string_values_producing_correct_value(datasets_db): """Strings where floats are expected should give correct answers for ZPF datasets""" - ds = clean_dataset(CU_MG_DATASET_ZPF_STRING_VALUES) - assert np.issubdtype(np.array([t[0][2] for t in ds['values']]).dtype, np.number) - assert np.issubdtype(np.array(ds['conditions']['T']).dtype, np.number) - assert np.issubdtype(np.array(ds['conditions']['P']).dtype, np.number) + ds = to_Dataset(CU_MG_DATASET_ZPF_STRING_VALUES) + assert isinstance(ds, ZPFDataset) + assert np.issubdtype(np.array([t[0][2] for t in ds.values]).dtype, np.number) + assert np.issubdtype(np.array(ds.conditions['T']).dtype, np.number) + assert np.issubdtype(np.array(ds.conditions['P']).dtype, np.number) def test_check_datasets_raises_if_configs_occupancies_not_aligned(datasets_db): """Checking datasets that don't have the same number/shape of configurations/occupancies should raise.""" with pytest.raises(DatasetError): - check_dataset(dataset_mismatched_configs_occupancies) + to_Dataset(dataset_mismatched_configs_occupancies) # Expected to fail, since the dataset checker cannot determine that species are used in the configurations and components should only contain pure elements. @@ -441,12 +562,12 @@ def test_check_datasets_raises_if_configs_occupancies_not_aligned(datasets_db): def test_non_equilibrium_thermo_data_with_species_passes_checker(): """Non-equilibrium thermochemical data that use species in the configurations should pass the dataset checker. """ - check_dataset(LI_SN_LIQUID_DATA) + to_Dataset(LI_SN_LIQUID_DATA) def test_applying_tags(datasets_db): """Test that applying tags updates the appropriate values""" - dataset = clean_dataset(CU_MG_DATASET_THERMOCHEMICAL_STRING_VALUES) + dataset = deepcopy(CU_MG_DATASET_THERMOCHEMICAL_STRING_VALUES) # overwrite tags for this test dataset["tags"] = ["testtag"] datasets_db.insert(dataset) diff --git a/tests/testing_data.py b/tests/testing_data.py index 73c9540e..dc528c19 100644 --- a/tests/testing_data.py +++ b/tests/testing_data.py @@ -584,7 +584,6 @@ "P": 101325, "T": [1337.97, 1262.238] }, - "broadcast_conditions": false, "output": "ZPF", "values": [ [["LIQUID", ["MG"], [0.0246992]], ["FCC_A1", ["MG"], [null]]], @@ -695,7 +694,6 @@ "P": "101325", "T": ["1337.97", "1262.238"] }, - "broadcast_conditions": false, "output": "ZPF", "values": [ [["LIQUID", ["MG"], ["0.0246992"]], ["FCC_A1", ["MG"], [null]]], @@ -713,7 +711,6 @@ "P": 101325, "T": [733.15] }, - "broadcast_conditions": false, "output": "ZPF", "values": [ [["__HYPERPLANE__", ["CU"], [0.05]], ["HCP_A3", ["CU"], [null]], ["CUMG2", ["CU"], [null]]] @@ -964,7 +961,6 @@ CR_NI_ZPF_DATA = { "components": ["CR", "NI", "VA"], "phases": ["BCC_A2", "FCC_A1"], - "broadcast_conditions": False, "conditions": { "T": [1073, 1173, 1273, 1373, 1548], "P": [101325.0] @@ -1478,7 +1474,6 @@ LI_SN_ZPF_DATA = { "components": ["LI", "SN"], "phases": ["LIQUID", "LI7SN2"], - "broadcast_conditions": False, "conditions": { "T": [1040], "P": [101325.0]