Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions fme/ace/inference/test_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
HybridSigmaPressureCoordinate,
LatLonCoordinates,
)
from fme.core.corrector.atmosphere import AtmosphereCorrectorConfig, EnergyBudgetConfig
from fme.core.dataset.data_typing import VariableMetadata
from fme.core.dataset.xarray import XarrayDataConfig
from fme.core.dataset_info import DatasetInfo
Expand Down Expand Up @@ -78,6 +79,7 @@ def save_plus_one_stepper(
ocean=None,
multi_call: MultiCallConfig | None = None,
derived_forcings: DerivedForcingsConfig | None = None,
total_energy_budget_correction: EnergyBudgetConfig | None = None,
):
if multi_call is None:
all_names = list(set(in_names).union(out_names))
Expand Down Expand Up @@ -121,6 +123,9 @@ def save_plus_one_stepper(
stds={name: std for name in all_names},
),
),
corrector=AtmosphereCorrectorConfig(
total_energy_budget_correction=total_energy_budget_correction,
),
ocean=ocean,
),
),
Expand Down
2 changes: 2 additions & 0 deletions fme/ace/stepper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
TrainOutput,
TrainStepper,
TrainStepperConfig,
apply_stepper_override,
apply_stepper_override_to_config,
load_stepper,
load_stepper_config,
process_prediction_generator_list,
Expand Down
60 changes: 58 additions & 2 deletions fme/ace/stepper/single_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
SerializableVerticalCoordinate,
VerticalCoordinate,
)
from fme.core.corrector.atmosphere import AtmosphereCorrectorConfig
from fme.core.corrector.atmosphere import AtmosphereCorrectorConfig, EnergyBudgetConfig
from fme.core.dataset.data_typing import VariableMetadata
from fme.core.dataset.schedule import IntSchedule
from fme.core.dataset.utils import encode_timestep
Expand Down Expand Up @@ -741,6 +741,12 @@ def replace_derived_forcings(self, derived_forcings: DerivedForcingsConfig):
self.derived_forcings.validate_replacement(derived_forcings)
self.derived_forcings = derived_forcings

def replace_total_energy_budget_correction(
self, value: EnergyBudgetConfig | None
) -> None:
"""Replace total energy budget correction (e.g. turn off during evaluation)."""
self.step.replace_total_energy_budget_correction(value)

@classmethod
def from_state(cls, state) -> "StepperConfig":
state = cls.remove_deprecated_keys(state)
Expand Down Expand Up @@ -962,6 +968,22 @@ def replace_derived_forcings(self, derived_forcings: DerivedForcingsConfig):
self._config.replace_derived_forcings(derived_forcings)
self.forcing_deriver = derived_forcings.build(self._dataset_info)

def replace_total_energy_budget_correction(
self, value: EnergyBudgetConfig | None
) -> None:
"""
Replace the total energy budget correction (e.g. turn off during evaluation).

Args:
value: The new total energy budget correction config or None to disable.
"""
self._config.replace_total_energy_budget_correction(value)
new_stepper: Stepper = self._config.get_stepper(
dataset_info=self._dataset_info,
)
new_stepper._step_obj.load_state(self._step_obj.get_state())
self._step_obj = new_stepper._step_obj

def get_base_weights(self) -> Weights | None:
"""
Get the base weights of the stepper.
Expand Down Expand Up @@ -1752,12 +1774,15 @@ class StepperOverrideConfig:
producing a serialized stepper.
prescribed_prognostic_names: List of prognostic variable names to overwrite
from forcing at each step during inference.
total_energy_budget_correction: Total energy budget correction config to use
for the atmosphere corrector. Use ``None`` to turn off during evaluation.
"""

ocean: Literal["keep"] | OceanConfig | None = "keep"
multi_call: Literal["keep"] | MultiCallConfig | None = "keep"
derived_forcings: Literal["keep"] | DerivedForcingsConfig = "keep"
prescribed_prognostic_names: Literal["keep"] | list[str] = "keep"
total_energy_budget_correction: Literal["keep"] | EnergyBudgetConfig | None = "keep"


def load_stepper_config(
Expand Down Expand Up @@ -1799,7 +1824,13 @@ def load_stepper(
checkpoint_path, map_location=get_device(), weights_only=False
)
stepper = Stepper.from_state(checkpoint["stepper"])
apply_stepper_override(stepper, override_config)
return stepper


def apply_stepper_override(
stepper: Stepper, override_config: StepperOverrideConfig
) -> None:
if override_config.ocean != "keep":
logging.info(
"Overriding training ocean configuration with a new ocean configuration."
Expand Down Expand Up @@ -1828,4 +1859,29 @@ def load_stepper(
stepper.replace_prescribed_prognostic_names(
override_config.prescribed_prognostic_names
)
return stepper

if override_config.total_energy_budget_correction != "keep":
logging.info(
"Overriding total_energy_budget_correction with %s.",
override_config.total_energy_budget_correction,
)
stepper.replace_total_energy_budget_correction(
override_config.total_energy_budget_correction
)


def apply_stepper_override_to_config(
stepper_config: StepperConfig, override_config: StepperOverrideConfig
) -> None:
if override_config.ocean != "keep":
stepper_config.replace_ocean(override_config.ocean)
if override_config.derived_forcings != "keep":
stepper_config.replace_derived_forcings(override_config.derived_forcings)
if override_config.prescribed_prognostic_names != "keep":
stepper_config.replace_prescribed_prognostic_names(
override_config.prescribed_prognostic_names
)
if override_config.total_energy_budget_correction != "keep":
stepper_config.replace_total_energy_budget_correction(
override_config.total_energy_budget_correction
)
58 changes: 58 additions & 0 deletions fme/ace/stepper/test_single_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
LatLonCoordinates,
VerticalCoordinate,
)
from fme.core.corrector.atmosphere import EnergyBudgetConfig
from fme.core.dataset_info import DatasetInfo, MissingDatasetInfo
from fme.core.device import get_device
from fme.core.generics.optimization import OptimizationABC
Expand Down Expand Up @@ -98,6 +99,8 @@
INSOLATION_CONFIG = InsolationConfig(INSOLATION_NAME, SOLAR_CONSTANT_AS_VALUE)
DERIVED_FORCINGS_CONFIG = DerivedForcingsConfig(insolation=INSOLATION_CONFIG)
EMPTY_DERIVED_FORCINGS_CONFIG = DerivedForcingsConfig()
EMPTY_ENERGY_BUDGET_CONFIG = None
DEFAULT_ENERGY_BUDGET_CONFIG = EnergyBudgetConfig(method="constant_temperature")


def get_data(names: Iterable[str], n_samples, n_time, epoch: int = 0) -> SphericalData:
Expand Down Expand Up @@ -1418,67 +1421,99 @@ def test_stepper_from_state_using_resnorm_has_correct_normalizer():
None,
None,
EMPTY_DERIVED_FORCINGS_CONFIG,
EMPTY_ENERGY_BUDGET_CONFIG,
OCEAN_CONFIG,
"keep",
"keep",
"keep",
OCEAN_CONFIG,
None,
EMPTY_DERIVED_FORCINGS_CONFIG,
EMPTY_ENERGY_BUDGET_CONFIG,
),
"persist-ocean": (
OCEAN_CONFIG,
None,
EMPTY_DERIVED_FORCINGS_CONFIG,
EMPTY_ENERGY_BUDGET_CONFIG,
"keep",
"keep",
"keep",
"keep",
OCEAN_CONFIG,
None,
EMPTY_DERIVED_FORCINGS_CONFIG,
EMPTY_ENERGY_BUDGET_CONFIG,
),
"override-multi-call": (
None,
None,
EMPTY_DERIVED_FORCINGS_CONFIG,
EMPTY_ENERGY_BUDGET_CONFIG,
"keep",
MULTI_CALL_CONFIG,
"keep",
"keep",
None,
MULTI_CALL_CONFIG,
EMPTY_DERIVED_FORCINGS_CONFIG,
EMPTY_ENERGY_BUDGET_CONFIG,
),
"persist-multi-call": (
None,
MULTI_CALL_CONFIG,
EMPTY_DERIVED_FORCINGS_CONFIG,
EMPTY_ENERGY_BUDGET_CONFIG,
"keep",
"keep",
"keep",
"keep",
None,
MULTI_CALL_CONFIG,
EMPTY_DERIVED_FORCINGS_CONFIG,
EMPTY_ENERGY_BUDGET_CONFIG,
),
"override-all": (
None,
None,
EMPTY_DERIVED_FORCINGS_CONFIG,
EMPTY_ENERGY_BUDGET_CONFIG,
OCEAN_CONFIG,
MULTI_CALL_CONFIG,
DERIVED_FORCINGS_CONFIG,
DEFAULT_ENERGY_BUDGET_CONFIG,
OCEAN_CONFIG,
MULTI_CALL_CONFIG,
DERIVED_FORCINGS_CONFIG,
DEFAULT_ENERGY_BUDGET_CONFIG,
),
"persist-all": (
OCEAN_CONFIG,
MULTI_CALL_CONFIG,
DERIVED_FORCINGS_CONFIG,
DEFAULT_ENERGY_BUDGET_CONFIG,
"keep",
"keep",
"keep",
"keep",
OCEAN_CONFIG,
MULTI_CALL_CONFIG,
DERIVED_FORCINGS_CONFIG,
DEFAULT_ENERGY_BUDGET_CONFIG,
),
"override-energy-budget-correction": (
OCEAN_CONFIG,
MULTI_CALL_CONFIG,
DERIVED_FORCINGS_CONFIG,
DEFAULT_ENERGY_BUDGET_CONFIG,
"keep",
"keep",
"keep",
EMPTY_ENERGY_BUDGET_CONFIG,
OCEAN_CONFIG,
MULTI_CALL_CONFIG,
DERIVED_FORCINGS_CONFIG,
EMPTY_ENERGY_BUDGET_CONFIG,
),
}

Expand All @@ -1488,12 +1523,15 @@ def test_stepper_from_state_using_resnorm_has_correct_normalizer():
"serialized_ocean_config",
"serialized_multi_call_config",
"serialized_derived_forcings_config",
"serialized_total_energy_budget_correction_config",
"overriding_ocean_config",
"overriding_multi_call_config",
"overriding_derived_forcings_config",
"overriding_total_energy_budget_correction_config",
"expected_ocean_config",
"expected_multi_call_config",
"expected_derived_forcings_config",
"expected_total_energy_budget_correction_config",
),
list(LOAD_STEPPER_TESTS.values()),
ids=list(LOAD_STEPPER_TESTS.keys()),
Expand All @@ -1503,12 +1541,19 @@ def test_load_stepper_and_load_stepper_config(
serialized_ocean_config: OceanConfig | None,
serialized_multi_call_config: MultiCallConfig | None,
serialized_derived_forcings_config: DerivedForcingsConfig,
serialized_total_energy_budget_correction_config: EnergyBudgetConfig | None,
overriding_ocean_config: Literal["keep"] | OceanConfig | None,
overriding_multi_call_config: Literal["keep"] | MultiCallConfig | None,
overriding_derived_forcings_config: Literal["keep"] | DerivedForcingsConfig,
overriding_total_energy_budget_correction_config: Literal["keep"]
| EnergyBudgetConfig
| None,
expected_ocean_config: OceanConfig | None,
expected_multi_call_config: MultiCallConfig | None,
expected_derived_forcings_config: DerivedForcingsConfig,
expected_total_energy_budget_correction_config: Literal["keep"]
| EnergyBudgetConfig
| None,
very_fast_only: bool,
):
if very_fast_only:
Expand Down Expand Up @@ -1545,6 +1590,7 @@ def test_load_stepper_and_load_stepper_config(
ocean=serialized_ocean_config,
multi_call=serialized_multi_call_config,
derived_forcings=serialized_derived_forcings_config,
total_energy_budget_correction=serialized_total_energy_budget_correction_config,
)

# First check that load_stepper_config and load_stepper functions load
Expand All @@ -1565,6 +1611,7 @@ def test_load_stepper_and_load_stepper_config(
ocean=overriding_ocean_config,
multi_call=overriding_multi_call_config,
derived_forcings=overriding_derived_forcings_config,
total_energy_budget_correction=overriding_total_energy_budget_correction_config,
)

stepper_config = load_stepper_config(stepper_path, stepper_override)
Expand All @@ -1576,6 +1623,10 @@ def test_load_stepper_and_load_stepper_config(
validate_stepper_ocean(stepper, expected_ocean_config)
validate_stepper_multi_call(stepper, expected_multi_call_config)
assert stepper.config.derived_forcings == expected_derived_forcings_config
assert (
_get_corrector_total_energy_budget_correction(stepper)
== expected_total_energy_budget_correction_config
)
assert isinstance(stepper.forcing_deriver, ForcingDeriver)


Expand All @@ -1589,6 +1640,13 @@ def _get_inner_single_module_config(stepper: Stepper):
return stepper._step_obj.config


def _get_corrector_total_energy_budget_correction(stepper: Stepper):
"""Get total_energy_budget_correction from the stepper's inner corrector config."""
config = _get_inner_single_module_config(stepper)
corrector = config.corrector
return corrector.total_energy_budget_correction


def validate_stepper_prescribed_prognostic_names(
stepper: Stepper, expected: list[str]
) -> None:
Expand Down
6 changes: 6 additions & 0 deletions fme/core/step/multi_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
from torch import nn

from fme.core.corrector.atmosphere import EnergyBudgetConfig
from fme.core.dataset_info import DatasetInfo
from fme.core.normalizer import StandardNormalizer
from fme.core.ocean import OceanConfig
Expand Down Expand Up @@ -197,6 +198,11 @@ def get_ocean(self) -> OceanConfig | None:
def replace_prescribed_prognostic_names(self, names: list[str]) -> None:
self.wrapped_step.replace_prescribed_prognostic_names(names)

def replace_total_energy_budget_correction(
self, value: EnergyBudgetConfig | None
) -> None:
self.wrapped_step.replace_total_energy_budget_correction(value)

def replace_multi_call(self, multi_call: MultiCallConfig | None):
self.config = multi_call

Expand Down
19 changes: 18 additions & 1 deletion fme/core/step/single_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch
from torch import nn

from fme.core.corrector.atmosphere import AtmosphereCorrectorConfig
from fme.core.corrector.atmosphere import AtmosphereCorrectorConfig, EnergyBudgetConfig
from fme.core.corrector.registry import CorrectorABC
from fme.core.dataset.utils import encode_timestep
from fme.core.dataset_info import DatasetInfo
Expand Down Expand Up @@ -194,6 +194,23 @@ def replace_prescribed_prognostic_names(self, names: list[str]) -> None:
)
self.prescribed_prognostic_names = names

def replace_total_energy_budget_correction(
self, value: EnergyBudgetConfig | None
) -> None:
"""Replace total energy budget correction."""
if isinstance(self.corrector, AtmosphereCorrectorConfig):
self.corrector = dataclasses.replace(
self.corrector, total_energy_budget_correction=value
)
else:
new_config = dict(self.corrector.config)
new_config["total_energy_budget_correction"] = (
None if value is None else dataclasses.asdict(value)
)
self.corrector = CorrectorSelector(
type=self.corrector.type, config=new_config
)

@classmethod
def _remove_deprecated_keys(cls, state: dict[str, Any]) -> dict[str, Any]:
state_copy = state.copy()
Expand Down
Loading
Loading