diff --git a/fme/ace/inference/test_evaluator.py b/fme/ace/inference/test_evaluator.py index 1736d2a0a..cf4b9eca7 100644 --- a/fme/ace/inference/test_evaluator.py +++ b/fme/ace/inference/test_evaluator.py @@ -42,6 +42,7 @@ HybridSigmaPressureCoordinate, LatLonCoordinates, ) +from fme.core.corrector.atmosphere import AtmosphereCorrectorConfig, EnergyBudgetConfig from fme.core.dataset.data_typing import VariableMetadata from fme.core.dataset.xarray import XarrayDataConfig from fme.core.dataset_info import DatasetInfo @@ -78,6 +79,7 @@ def save_plus_one_stepper( ocean=None, multi_call: MultiCallConfig | None = None, derived_forcings: DerivedForcingsConfig | None = None, + total_energy_budget_correction: EnergyBudgetConfig | None = None, ): if multi_call is None: all_names = list(set(in_names).union(out_names)) @@ -121,6 +123,9 @@ def save_plus_one_stepper( stds={name: std for name in all_names}, ), ), + corrector=AtmosphereCorrectorConfig( + total_energy_budget_correction=total_energy_budget_correction, + ), ocean=ocean, ), ), diff --git a/fme/ace/stepper/__init__.py b/fme/ace/stepper/__init__.py index f66686e0b..8c7509e5e 100644 --- a/fme/ace/stepper/__init__.py +++ b/fme/ace/stepper/__init__.py @@ -6,6 +6,8 @@ TrainOutput, TrainStepper, TrainStepperConfig, + apply_stepper_override, + apply_stepper_override_to_config, load_stepper, load_stepper_config, process_prediction_generator_list, diff --git a/fme/ace/stepper/single_module.py b/fme/ace/stepper/single_module.py index a991d2742..c2c2f7fd5 100644 --- a/fme/ace/stepper/single_module.py +++ b/fme/ace/stepper/single_module.py @@ -33,7 +33,7 @@ SerializableVerticalCoordinate, VerticalCoordinate, ) -from fme.core.corrector.atmosphere import AtmosphereCorrectorConfig +from fme.core.corrector.atmosphere import AtmosphereCorrectorConfig, EnergyBudgetConfig from fme.core.dataset.data_typing import VariableMetadata from fme.core.dataset.schedule import IntSchedule from fme.core.dataset.utils import encode_timestep @@ -741,6 +741,12 @@ def replace_derived_forcings(self, derived_forcings: DerivedForcingsConfig): self.derived_forcings.validate_replacement(derived_forcings) self.derived_forcings = derived_forcings + def replace_total_energy_budget_correction( + self, value: EnergyBudgetConfig | None + ) -> None: + """Replace total energy budget correction (e.g. turn off during evaluation).""" + self.step.replace_total_energy_budget_correction(value) + @classmethod def from_state(cls, state) -> "StepperConfig": state = cls.remove_deprecated_keys(state) @@ -962,6 +968,22 @@ def replace_derived_forcings(self, derived_forcings: DerivedForcingsConfig): self._config.replace_derived_forcings(derived_forcings) self.forcing_deriver = derived_forcings.build(self._dataset_info) + def replace_total_energy_budget_correction( + self, value: EnergyBudgetConfig | None + ) -> None: + """ + Replace the total energy budget correction (e.g. turn off during evaluation). + + Args: + value: The new total energy budget correction config or None to disable. + """ + self._config.replace_total_energy_budget_correction(value) + new_stepper: Stepper = self._config.get_stepper( + dataset_info=self._dataset_info, + ) + new_stepper._step_obj.load_state(self._step_obj.get_state()) + self._step_obj = new_stepper._step_obj + def get_base_weights(self) -> Weights | None: """ Get the base weights of the stepper. @@ -1752,12 +1774,15 @@ class StepperOverrideConfig: producing a serialized stepper. prescribed_prognostic_names: List of prognostic variable names to overwrite from forcing at each step during inference. + total_energy_budget_correction: Total energy budget correction config to use + for the atmosphere corrector. Use ``None`` to turn off during evaluation. """ ocean: Literal["keep"] | OceanConfig | None = "keep" multi_call: Literal["keep"] | MultiCallConfig | None = "keep" derived_forcings: Literal["keep"] | DerivedForcingsConfig = "keep" prescribed_prognostic_names: Literal["keep"] | list[str] = "keep" + total_energy_budget_correction: Literal["keep"] | EnergyBudgetConfig | None = "keep" def load_stepper_config( @@ -1799,7 +1824,13 @@ def load_stepper( checkpoint_path, map_location=get_device(), weights_only=False ) stepper = Stepper.from_state(checkpoint["stepper"]) + apply_stepper_override(stepper, override_config) + return stepper + +def apply_stepper_override( + stepper: Stepper, override_config: StepperOverrideConfig +) -> None: if override_config.ocean != "keep": logging.info( "Overriding training ocean configuration with a new ocean configuration." @@ -1828,4 +1859,29 @@ def load_stepper( stepper.replace_prescribed_prognostic_names( override_config.prescribed_prognostic_names ) - return stepper + + if override_config.total_energy_budget_correction != "keep": + logging.info( + "Overriding total_energy_budget_correction with %s.", + override_config.total_energy_budget_correction, + ) + stepper.replace_total_energy_budget_correction( + override_config.total_energy_budget_correction + ) + + +def apply_stepper_override_to_config( + stepper_config: StepperConfig, override_config: StepperOverrideConfig +) -> None: + if override_config.ocean != "keep": + stepper_config.replace_ocean(override_config.ocean) + if override_config.derived_forcings != "keep": + stepper_config.replace_derived_forcings(override_config.derived_forcings) + if override_config.prescribed_prognostic_names != "keep": + stepper_config.replace_prescribed_prognostic_names( + override_config.prescribed_prognostic_names + ) + if override_config.total_energy_budget_correction != "keep": + stepper_config.replace_total_energy_budget_correction( + override_config.total_energy_budget_correction + ) diff --git a/fme/ace/stepper/test_single_module.py b/fme/ace/stepper/test_single_module.py index d36463895..7905443ee 100644 --- a/fme/ace/stepper/test_single_module.py +++ b/fme/ace/stepper/test_single_module.py @@ -58,6 +58,7 @@ LatLonCoordinates, VerticalCoordinate, ) +from fme.core.corrector.atmosphere import EnergyBudgetConfig from fme.core.dataset_info import DatasetInfo, MissingDatasetInfo from fme.core.device import get_device from fme.core.generics.optimization import OptimizationABC @@ -98,6 +99,8 @@ INSOLATION_CONFIG = InsolationConfig(INSOLATION_NAME, SOLAR_CONSTANT_AS_VALUE) DERIVED_FORCINGS_CONFIG = DerivedForcingsConfig(insolation=INSOLATION_CONFIG) EMPTY_DERIVED_FORCINGS_CONFIG = DerivedForcingsConfig() +EMPTY_ENERGY_BUDGET_CONFIG = None +DEFAULT_ENERGY_BUDGET_CONFIG = EnergyBudgetConfig(method="constant_temperature") def get_data(names: Iterable[str], n_samples, n_time, epoch: int = 0) -> SphericalData: @@ -1418,67 +1421,99 @@ def test_stepper_from_state_using_resnorm_has_correct_normalizer(): None, None, EMPTY_DERIVED_FORCINGS_CONFIG, + EMPTY_ENERGY_BUDGET_CONFIG, OCEAN_CONFIG, "keep", "keep", + "keep", OCEAN_CONFIG, None, EMPTY_DERIVED_FORCINGS_CONFIG, + EMPTY_ENERGY_BUDGET_CONFIG, ), "persist-ocean": ( OCEAN_CONFIG, None, EMPTY_DERIVED_FORCINGS_CONFIG, + EMPTY_ENERGY_BUDGET_CONFIG, + "keep", "keep", "keep", "keep", OCEAN_CONFIG, None, EMPTY_DERIVED_FORCINGS_CONFIG, + EMPTY_ENERGY_BUDGET_CONFIG, ), "override-multi-call": ( None, None, EMPTY_DERIVED_FORCINGS_CONFIG, + EMPTY_ENERGY_BUDGET_CONFIG, "keep", MULTI_CALL_CONFIG, "keep", + "keep", None, MULTI_CALL_CONFIG, EMPTY_DERIVED_FORCINGS_CONFIG, + EMPTY_ENERGY_BUDGET_CONFIG, ), "persist-multi-call": ( None, MULTI_CALL_CONFIG, EMPTY_DERIVED_FORCINGS_CONFIG, + EMPTY_ENERGY_BUDGET_CONFIG, + "keep", "keep", "keep", "keep", None, MULTI_CALL_CONFIG, EMPTY_DERIVED_FORCINGS_CONFIG, + EMPTY_ENERGY_BUDGET_CONFIG, ), "override-all": ( None, None, EMPTY_DERIVED_FORCINGS_CONFIG, + EMPTY_ENERGY_BUDGET_CONFIG, OCEAN_CONFIG, MULTI_CALL_CONFIG, DERIVED_FORCINGS_CONFIG, + DEFAULT_ENERGY_BUDGET_CONFIG, OCEAN_CONFIG, MULTI_CALL_CONFIG, DERIVED_FORCINGS_CONFIG, + DEFAULT_ENERGY_BUDGET_CONFIG, ), "persist-all": ( OCEAN_CONFIG, MULTI_CALL_CONFIG, DERIVED_FORCINGS_CONFIG, + DEFAULT_ENERGY_BUDGET_CONFIG, + "keep", "keep", "keep", "keep", OCEAN_CONFIG, MULTI_CALL_CONFIG, DERIVED_FORCINGS_CONFIG, + DEFAULT_ENERGY_BUDGET_CONFIG, + ), + "override-energy-budget-correction": ( + OCEAN_CONFIG, + MULTI_CALL_CONFIG, + DERIVED_FORCINGS_CONFIG, + DEFAULT_ENERGY_BUDGET_CONFIG, + "keep", + "keep", + "keep", + EMPTY_ENERGY_BUDGET_CONFIG, + OCEAN_CONFIG, + MULTI_CALL_CONFIG, + DERIVED_FORCINGS_CONFIG, + EMPTY_ENERGY_BUDGET_CONFIG, ), } @@ -1488,12 +1523,15 @@ def test_stepper_from_state_using_resnorm_has_correct_normalizer(): "serialized_ocean_config", "serialized_multi_call_config", "serialized_derived_forcings_config", + "serialized_total_energy_budget_correction_config", "overriding_ocean_config", "overriding_multi_call_config", "overriding_derived_forcings_config", + "overriding_total_energy_budget_correction_config", "expected_ocean_config", "expected_multi_call_config", "expected_derived_forcings_config", + "expected_total_energy_budget_correction_config", ), list(LOAD_STEPPER_TESTS.values()), ids=list(LOAD_STEPPER_TESTS.keys()), @@ -1503,12 +1541,19 @@ def test_load_stepper_and_load_stepper_config( serialized_ocean_config: OceanConfig | None, serialized_multi_call_config: MultiCallConfig | None, serialized_derived_forcings_config: DerivedForcingsConfig, + serialized_total_energy_budget_correction_config: EnergyBudgetConfig | None, overriding_ocean_config: Literal["keep"] | OceanConfig | None, overriding_multi_call_config: Literal["keep"] | MultiCallConfig | None, overriding_derived_forcings_config: Literal["keep"] | DerivedForcingsConfig, + overriding_total_energy_budget_correction_config: Literal["keep"] + | EnergyBudgetConfig + | None, expected_ocean_config: OceanConfig | None, expected_multi_call_config: MultiCallConfig | None, expected_derived_forcings_config: DerivedForcingsConfig, + expected_total_energy_budget_correction_config: Literal["keep"] + | EnergyBudgetConfig + | None, very_fast_only: bool, ): if very_fast_only: @@ -1545,6 +1590,7 @@ def test_load_stepper_and_load_stepper_config( ocean=serialized_ocean_config, multi_call=serialized_multi_call_config, derived_forcings=serialized_derived_forcings_config, + total_energy_budget_correction=serialized_total_energy_budget_correction_config, ) # First check that load_stepper_config and load_stepper functions load @@ -1565,6 +1611,7 @@ def test_load_stepper_and_load_stepper_config( ocean=overriding_ocean_config, multi_call=overriding_multi_call_config, derived_forcings=overriding_derived_forcings_config, + total_energy_budget_correction=overriding_total_energy_budget_correction_config, ) stepper_config = load_stepper_config(stepper_path, stepper_override) @@ -1576,6 +1623,10 @@ def test_load_stepper_and_load_stepper_config( validate_stepper_ocean(stepper, expected_ocean_config) validate_stepper_multi_call(stepper, expected_multi_call_config) assert stepper.config.derived_forcings == expected_derived_forcings_config + assert ( + _get_corrector_total_energy_budget_correction(stepper) + == expected_total_energy_budget_correction_config + ) assert isinstance(stepper.forcing_deriver, ForcingDeriver) @@ -1589,6 +1640,13 @@ def _get_inner_single_module_config(stepper: Stepper): return stepper._step_obj.config +def _get_corrector_total_energy_budget_correction(stepper: Stepper): + """Get total_energy_budget_correction from the stepper's inner corrector config.""" + config = _get_inner_single_module_config(stepper) + corrector = config.corrector + return corrector.total_energy_budget_correction + + def validate_stepper_prescribed_prognostic_names( stepper: Stepper, expected: list[str] ) -> None: diff --git a/fme/core/step/multi_call.py b/fme/core/step/multi_call.py index ec3614d1b..fda6edb92 100644 --- a/fme/core/step/multi_call.py +++ b/fme/core/step/multi_call.py @@ -6,6 +6,7 @@ import torch from torch import nn +from fme.core.corrector.atmosphere import EnergyBudgetConfig from fme.core.dataset_info import DatasetInfo from fme.core.normalizer import StandardNormalizer from fme.core.ocean import OceanConfig @@ -197,6 +198,11 @@ def get_ocean(self) -> OceanConfig | None: def replace_prescribed_prognostic_names(self, names: list[str]) -> None: self.wrapped_step.replace_prescribed_prognostic_names(names) + def replace_total_energy_budget_correction( + self, value: EnergyBudgetConfig | None + ) -> None: + self.wrapped_step.replace_total_energy_budget_correction(value) + def replace_multi_call(self, multi_call: MultiCallConfig | None): self.config = multi_call diff --git a/fme/core/step/single_module.py b/fme/core/step/single_module.py index 592bfbcff..f23811ae3 100644 --- a/fme/core/step/single_module.py +++ b/fme/core/step/single_module.py @@ -8,7 +8,7 @@ import torch from torch import nn -from fme.core.corrector.atmosphere import AtmosphereCorrectorConfig +from fme.core.corrector.atmosphere import AtmosphereCorrectorConfig, EnergyBudgetConfig from fme.core.corrector.registry import CorrectorABC from fme.core.dataset.utils import encode_timestep from fme.core.dataset_info import DatasetInfo @@ -194,6 +194,23 @@ def replace_prescribed_prognostic_names(self, names: list[str]) -> None: ) self.prescribed_prognostic_names = names + def replace_total_energy_budget_correction( + self, value: EnergyBudgetConfig | None + ) -> None: + """Replace total energy budget correction.""" + if isinstance(self.corrector, AtmosphereCorrectorConfig): + self.corrector = dataclasses.replace( + self.corrector, total_energy_budget_correction=value + ) + else: + new_config = dict(self.corrector.config) + new_config["total_energy_budget_correction"] = ( + None if value is None else dataclasses.asdict(value) + ) + self.corrector = CorrectorSelector( + type=self.corrector.type, config=new_config + ) + @classmethod def _remove_deprecated_keys(cls, state: dict[str, Any]) -> dict[str, Any]: state_copy = state.copy() diff --git a/fme/core/step/step.py b/fme/core/step/step.py index 62cdf4064..bf1e20c0e 100644 --- a/fme/core/step/step.py +++ b/fme/core/step/step.py @@ -9,6 +9,7 @@ import torch from torch import nn +from fme.core.corrector.atmosphere import EnergyBudgetConfig from fme.core.dataset_info import DatasetInfo from fme.core.normalizer import StandardNormalizer from fme.core.ocean import OceanConfig @@ -112,6 +113,12 @@ def replace_prescribed_prognostic_names(self, names: list[str]) -> None: """Replace prescribed prognostic names (e.g. when loading from checkpoint).""" pass + def replace_total_energy_budget_correction( + self, value: EnergyBudgetConfig | None + ) -> None: + """Replace total energy budget correction (e.g. turn off during evaluation).""" + pass + @abc.abstractmethod def load(self): """ @@ -219,6 +226,12 @@ def replace_prescribed_prognostic_names(self, names: list[str]) -> None: self._step_config_instance.replace_prescribed_prognostic_names(names) self.config = dataclasses.asdict(self._step_config_instance) + def replace_total_energy_budget_correction( + self, value: EnergyBudgetConfig | None + ) -> None: + self._step_config_instance.replace_total_energy_budget_correction(value) + self.config = dataclasses.asdict(self._step_config_instance) + def load(self): self._step_config_instance.load() self.config = dataclasses.asdict(self._step_config_instance) diff --git a/fme/coupled/inference/evaluator.py b/fme/coupled/inference/evaluator.py index e05aae27f..3b82acfb4 100644 --- a/fme/coupled/inference/evaluator.py +++ b/fme/coupled/inference/evaluator.py @@ -7,6 +7,11 @@ import torch import fme +from fme.ace.stepper import ( + StepperOverrideConfig, + apply_stepper_override, + apply_stepper_override_to_config, +) from fme.ace.stepper import load_stepper as load_single_stepper from fme.ace.stepper import load_stepper_config as load_single_stepper_config from fme.core.cli import prepare_config, prepare_directory @@ -37,6 +42,21 @@ ) +@dataclasses.dataclass +class CoupledStepperOverrideConfig: + """ + Override config for each component when loading a coupled stepper from + separate ocean/atmosphere checkpoints. + + Parameters: + ocean: Override config for the ocean stepper (optional). + atmosphere: Override config for the atmosphere stepper (optional). + """ + + ocean: StepperOverrideConfig | None = None + atmosphere: StepperOverrideConfig | None = None + + @dataclasses.dataclass class StandaloneComponentConfig: """ @@ -69,30 +89,47 @@ class StandaloneComponentCheckpointsConfig: sst_name: str = "sst" ocean_fraction_prediction: CoupledOceanFractionConfig | None = None - def load_stepper_config(self) -> CoupledStepperConfig: + def load_stepper_config( + self, + stepper_override_ocean: StepperOverrideConfig | None = None, + stepper_override_atmosphere: StepperOverrideConfig | None = None, + ) -> CoupledStepperConfig: return CoupledStepperConfig( ocean=ComponentConfig( timedelta=self.ocean.timedelta, - stepper=load_single_stepper_config(self.ocean.path), + stepper=load_single_stepper_config( + self.ocean.path, stepper_override_ocean + ), ), atmosphere=ComponentConfig( timedelta=self.atmosphere.timedelta, - stepper=load_single_stepper_config(self.atmosphere.path), + stepper=load_single_stepper_config( + self.atmosphere.path, stepper_override_atmosphere + ), ), sst_name=self.sst_name, ocean_fraction_prediction=self.ocean_fraction_prediction, ) - def load_stepper(self) -> CoupledStepper: - ocean = load_single_stepper(self.ocean.path) - atmosphere = load_single_stepper(self.atmosphere.path) + def load_stepper( + self, + stepper_override_ocean: StepperOverrideConfig | None = None, + stepper_override_atmosphere: StepperOverrideConfig | None = None, + ) -> CoupledStepper: + ocean = load_single_stepper(self.ocean.path, stepper_override_ocean) + atmosphere = load_single_stepper( + self.atmosphere.path, stepper_override_atmosphere + ) dataset_info = CoupledDatasetInfo( ocean=ocean.training_dataset_info, atmosphere=atmosphere.training_dataset_info, ) return CoupledStepper( - config=self.load_stepper_config(), + config=self.load_stepper_config( + stepper_override_ocean=stepper_override_ocean, + stepper_override_atmosphere=stepper_override_atmosphere, + ), ocean=ocean, atmosphere=atmosphere, dataset_info=dataset_info, @@ -101,12 +138,16 @@ def load_stepper(self) -> CoupledStepper: def load_stepper_config( checkpoint_path: str | pathlib.Path | StandaloneComponentCheckpointsConfig, + stepper_override_ocean: StepperOverrideConfig | None = None, + stepper_override_atmosphere: StepperOverrideConfig | None = None, ) -> CoupledStepperConfig: """Load a coupled stepper configuration. Args: checkpoint_path: The path to the serialized CoupledStepper checkpoint, or a StandaloneComponentCheckpointsConfig. + stepper_override_ocean: Override config for the ocean stepper. + stepper_override_atmosphere: Override config for the atmosphere stepper. Returns: The CoupledStepperConfig from the serialized checkpoint or constructed from the @@ -120,25 +161,37 @@ def load_stepper_config( "Loading atmosphere model checkpoint from " f"{checkpoint_path.atmosphere.path}" ) - return checkpoint_path.load_stepper_config() + return checkpoint_path.load_stepper_config( + stepper_override_ocean=stepper_override_ocean, + stepper_override_atmosphere=stepper_override_atmosphere, + ) logging.info(f"Loading trained coupled model checkpoint from {checkpoint_path}") checkpoint = torch.load( checkpoint_path, map_location=fme.get_device(), weights_only=False ) config = CoupledStepperConfig.from_state(checkpoint["stepper"]["config"]) - + if stepper_override_ocean is not None: + apply_stepper_override_to_config(config.ocean.stepper, stepper_override_ocean) + if stepper_override_atmosphere is not None: + apply_stepper_override_to_config( + config.atmosphere.stepper, stepper_override_atmosphere + ) return config def load_stepper( checkpoint_path: str | pathlib.Path | StandaloneComponentCheckpointsConfig, + stepper_override_ocean: StepperOverrideConfig | None = None, + stepper_override_atmosphere: StepperOverrideConfig | None = None, ) -> CoupledStepper: """Load a coupled stepper. Args: checkpoint_path: The path to the serialized CoupledStepper checkpoint, or a StandaloneComponentCheckpointsConfig. + stepper_override_ocean: Override config for the ocean stepper. + stepper_override_atmosphere: Override config for the atmosphere stepper. Returns: The CoupledStepper serialized in the checkpoint or constructed from the @@ -152,9 +205,17 @@ def load_stepper( "Loading atmosphere model checkpoint from " f"{checkpoint_path.atmosphere.path}" ) - return checkpoint_path.load_stepper() + return checkpoint_path.load_stepper( + stepper_override_ocean=stepper_override_ocean, + stepper_override_atmosphere=stepper_override_atmosphere, + ) - return load_coupled_stepper(checkpoint_path) + stepper = load_coupled_stepper(checkpoint_path) + if stepper_override_ocean is not None: + apply_stepper_override(stepper.ocean, stepper_override_ocean) + if stepper_override_atmosphere is not None: + apply_stepper_override(stepper.atmosphere, stepper_override_atmosphere) + return stepper @dataclasses.dataclass @@ -177,6 +238,10 @@ class InferenceEvaluatorConfig: prediction_loader: Configuration for prediction data to evaluate. If given, model evaluation will not run, and instead predictions will be evaluated. Model checkpoint will still be used to determine inputs and outputs. + stepper_override: Override config for ocean and atmosphere steppers at + inference time (optional). Contains ``ocean`` and ``atmosphere`` + fields, each a StepperOverrideConfig or None. Applied when loading + from either a single coupled checkpoint path or separate checkpoints. """ experiment_dir: str @@ -192,6 +257,7 @@ class InferenceEvaluatorConfig: default_factory=lambda: InferenceEvaluatorAggregatorConfig() ) prediction_loader: InferenceDataLoaderConfig | None = None + stepper_override: CoupledStepperOverrideConfig | None = None def configure_logging(self, log_filename: str): config = dataclasses.asdict(self) @@ -200,10 +266,20 @@ def configure_logging(self, log_filename: str): ) def load_stepper(self) -> CoupledStepper: - return load_stepper(self.checkpoint_path) + override = self.stepper_override + return load_stepper( + self.checkpoint_path, + stepper_override_ocean=override.ocean if override else None, + stepper_override_atmosphere=override.atmosphere if override else None, + ) def load_stepper_config(self) -> CoupledStepperConfig: - return load_stepper_config(self.checkpoint_path) + override = self.stepper_override + return load_stepper_config( + self.checkpoint_path, + stepper_override_ocean=override.ocean if override else None, + stepper_override_atmosphere=override.atmosphere if override else None, + ) def get_data_writer( self, diff --git a/fme/coupled/inference/inference.py b/fme/coupled/inference/inference.py index a1d1f0181..6912a0aae 100644 --- a/fme/coupled/inference/inference.py +++ b/fme/coupled/inference/inference.py @@ -33,6 +33,7 @@ from fme.coupled.stepper import CoupledStepper, CoupledStepperConfig from .evaluator import ( + CoupledStepperOverrideConfig, StandaloneComponentCheckpointsConfig, load_stepper, load_stepper_config, @@ -122,7 +123,11 @@ class InferenceConfig: at a time, will load one more step for initial condition. data_writer: Configuration for data writers. aggregator: Configuration for inference aggregator. - n_ensemble_per_ic: Number of ensemble members per initial condition + n_ensemble_per_ic: Number of ensemble members per initial condition. + stepper_override: Override config for ocean and atmosphere steppers at + inference time (optional). Contains ``ocean`` and ``atmosphere`` + components, each a StepperOverrideConfig or None. Applied when loading + from either a single coupled checkpoint path or separate checkpoints. """ experiment_dir: str @@ -139,6 +144,7 @@ class InferenceConfig: default_factory=lambda: InferenceAggregatorConfig() ) n_ensemble_per_ic: int = 1 + stepper_override: CoupledStepperOverrideConfig | None = None def configure_logging(self, log_filename: str): config = dataclasses.asdict(self) @@ -147,10 +153,20 @@ def configure_logging(self, log_filename: str): ) def load_stepper(self) -> CoupledStepper: - return load_stepper(self.checkpoint_path) + override = self.stepper_override + return load_stepper( + self.checkpoint_path, + stepper_override_ocean=override.ocean if override else None, + stepper_override_atmosphere=override.atmosphere if override else None, + ) def load_stepper_config(self) -> CoupledStepperConfig: - return load_stepper_config(self.checkpoint_path) + override = self.stepper_override + return load_stepper_config( + self.checkpoint_path, + stepper_override_ocean=override.ocean if override else None, + stepper_override_atmosphere=override.atmosphere if override else None, + ) def get_data_writer( self, diff --git a/fme/coupled/inference/test_evaluator.py b/fme/coupled/inference/test_evaluator.py index 7b6ed770b..b819dc36e 100644 --- a/fme/coupled/inference/test_evaluator.py +++ b/fme/coupled/inference/test_evaluator.py @@ -3,6 +3,7 @@ import os import pathlib import shutil +from typing import cast import pytest import torch @@ -10,8 +11,11 @@ import yaml from fme.ace.inference.data_writer.main import DataWriterConfig +from fme.ace.stepper import Stepper, StepperOverrideConfig +from fme.core.corrector.atmosphere import EnergyBudgetConfig from fme.core.dataset.xarray import XarrayDataConfig from fme.core.logging_utils import LoggingConfig +from fme.core.step.single_module import SingleModuleStepConfig from fme.core.testing import mock_wandb from fme.coupled.data_loading.config import CoupledDatasetWithOptionalOceanConfig from fme.coupled.data_loading.inference import ( @@ -25,6 +29,7 @@ InferenceEvaluatorConfig, StandaloneComponentCheckpointsConfig, StandaloneComponentConfig, + load_stepper, main, ) from fme.coupled.stepper import CoupledStepperConfig @@ -49,6 +54,82 @@ def test_standalone_checkpoints_config_init_args(): ) +def _get_corrector_total_energy_budget_correction_from_stepper(stepper: Stepper): + from fme.core.step.multi_call import MultiCallStep + + if isinstance(stepper._step_obj, MultiCallStep): + config = stepper._step_obj._wrapped_step.config + else: + config = stepper._step_obj.config + config = cast(SingleModuleStepConfig, config) + corrector = config.corrector + if hasattr(corrector, "total_energy_budget_correction"): + return corrector.total_energy_budget_correction + return corrector.config.get("total_energy_budget_correction") + + +def test_stepper_override_total_energy_budget_correction_to_none( + tmp_path: pathlib.Path, very_fast_only: bool +): + """StepperOverrideConfig correctly overrides total_energy_budget_correction from + EnergyBudgetConfig to None (no energy corrector) for the atmosphere component.""" + if very_fast_only: + pytest.skip("Skipping non-fast tests") + ocean_in_names = ["o_prog", "sst", "mask_0", "a_diag"] + ocean_out_names = ["o_prog", "sst", "o_diag"] + atmos_in_names = ["a_prog", "surface_temperature", "ocean_fraction"] + atmos_out_names = ["a_prog", "surface_temperature", "a_diag"] + n_coupled_steps = 2 + n_initial_conditions = 1 + + stepper_data_dir = tmp_path / "stepper_data" + dataset_info, _ = _create_dataset_info_for_stepper( + ocean_in_names=ocean_in_names, + ocean_out_names=ocean_out_names, + atmos_in_names=atmos_in_names, + atmos_out_names=atmos_out_names, + n_coupled_steps=n_coupled_steps, + n_initial_conditions=n_initial_conditions, + data_dir=stepper_data_dir, + ) + # Save coupled stepper with atmosphere total_energy_budget_correction set + default_energy_config = EnergyBudgetConfig("constant_temperature", 0.0) + checkpoint_path = save_coupled_stepper( + tmp_path, + ocean_in_names=ocean_in_names, + ocean_out_names=ocean_out_names, + atmos_in_names=atmos_in_names, + atmos_out_names=atmos_out_names, + dataset_info=dataset_info, + atmosphere_total_energy_budget_correction=default_energy_config, + ) + assert isinstance(checkpoint_path, str) + + # Load without override: atmosphere should still have the energy corrector + stepper_no_override = load_stepper(checkpoint_path) + assert ( + _get_corrector_total_energy_budget_correction_from_stepper( + stepper_no_override.atmosphere + ) + is not None + ) + + # Load with override: atmosphere should have total_energy_budget_correction=None + stepper_with_override = load_stepper( + checkpoint_path, + stepper_override_ocean=None, + stepper_override_atmosphere=StepperOverrideConfig( + total_energy_budget_correction=None + ), + ) + assert ( + _get_corrector_total_energy_budget_correction_from_stepper( + stepper_with_override.atmosphere + ) + is None + ) + + def save_coupled_stepper( base_dir: pathlib.Path, ocean_in_names: list[str], @@ -62,6 +143,7 @@ def save_coupled_stepper( save_standalone_component_checkpoints: bool = False, ocean_timedelta: str = "2D", atmosphere_timedelta: str = "1D", + atmosphere_total_energy_budget_correction: EnergyBudgetConfig | None = None, ) -> str | StandaloneComponentCheckpointsConfig: config = get_stepper_config( ocean_in_names=ocean_in_names, @@ -73,6 +155,7 @@ def save_coupled_stepper( ocean_fraction_name=ocean_fraction_name, ocean_timedelta=ocean_timedelta, atmosphere_timedelta=atmosphere_timedelta, + atmosphere_total_energy_budget_correction=atmosphere_total_energy_budget_correction, ) if save_standalone_component_checkpoints: ocean_stepper = config.ocean.stepper.get_stepper(dataset_info.ocean) diff --git a/fme/coupled/test_stepper.py b/fme/coupled/test_stepper.py index 6f6571e61..2a59e190b 100644 --- a/fme/coupled/test_stepper.py +++ b/fme/coupled/test_stepper.py @@ -20,6 +20,7 @@ NullVerticalCoordinate, VerticalCoordinate, ) +from fme.core.corrector.atmosphere import AtmosphereCorrectorConfig, EnergyBudgetConfig from fme.core.dataset_info import DatasetInfo from fme.core.loss import StepLossConfig from fme.core.mask_provider import MaskProvider @@ -902,6 +903,7 @@ def get_stepper_config( ocean_timedelta: str = OCEAN_TIMEDELTA, atmosphere_timedelta: str = ATMOS_TIMEDELTA, ocean_fraction_prediction: CoupledOceanFractionConfig | None = None, + atmosphere_total_energy_budget_correction: EnergyBudgetConfig | None = None, ): # CoupledStepper requires that both component datasets include prognostic # surface temperature variables and that the atmosphere data includes an @@ -923,6 +925,25 @@ def get_stepper_config( if ocean_builder is None: ocean_builder = ModuleSelector(type="prebuilt", config={"module": TimesTwo()}) + atmosphere_step_kwargs: dict = { + "builder": atmosphere_builder, + "in_names": atmosphere_in_names, + "out_names": atmosphere_out_names, + "normalization": NetworkAndLossNormalizationConfig( + network=NormalizationConfig( + means={name: 0.0 for name in atmos_norm_names}, + stds={name: 1.0 for name in atmos_norm_names}, + ), + ), + "ocean": OceanConfig( + surface_temperature_name=sfc_temp_name_in_atmosphere_data, + ocean_fraction_name=ocean_fraction_name, + ), + } + if atmosphere_total_energy_budget_correction is not None: + atmosphere_step_kwargs["corrector"] = AtmosphereCorrectorConfig( + total_energy_budget_correction=atmosphere_total_energy_budget_correction + ) config = CoupledStepperConfig( atmosphere=ComponentConfig( timedelta=atmosphere_timedelta, @@ -930,21 +951,7 @@ def get_stepper_config( step=StepSelector( type="single_module", config=dataclasses.asdict( - SingleModuleStepConfig( - builder=atmosphere_builder, - in_names=atmosphere_in_names, - out_names=atmosphere_out_names, - normalization=NetworkAndLossNormalizationConfig( - network=NormalizationConfig( - means={name: 0.0 for name in atmos_norm_names}, - stds={name: 1.0 for name in atmos_norm_names}, - ), - ), - ocean=OceanConfig( - surface_temperature_name=sfc_temp_name_in_atmosphere_data, - ocean_fraction_name=ocean_fraction_name, - ), - ), + SingleModuleStepConfig(**atmosphere_step_kwargs), ), ), ),