diff --git a/fme/ace/inference/inference.py b/fme/ace/inference/inference.py index 5f06a39ed..9af46d5ef 100644 --- a/fme/ace/inference/inference.py +++ b/fme/ace/inference/inference.py @@ -305,7 +305,7 @@ def run_inference_from_config(config: InferenceConfig): logging.info("Loading initial condition data") initial_condition = get_initial_condition( config.initial_condition.get_dataset(), - stepper_config.prognostic_names, + stepper_config.get_prognostic_names(), labels=config.labels, n_ensemble=config.n_ensemble_per_ic, ) diff --git a/fme/ace/step/fcn3.py b/fme/ace/step/fcn3.py index 2c32d8ca5..c97507554 100644 --- a/fme/ace/step/fcn3.py +++ b/fme/ace/step/fcn3.py @@ -19,7 +19,6 @@ from fme.core.distributed import Distributed from fme.core.normalizer import NetworkAndLossNormalizationConfig, StandardNormalizer from fme.core.ocean import Ocean, OceanConfig -from fme.core.optimization import NullOptimization from fme.core.packer import Packer from fme.core.registry import CorrectorSelector from fme.core.step.args import StepArgs @@ -229,7 +228,8 @@ def get_loss_normalizer( extra_residual_scaled_names = [] return self.normalization.get_loss_normalizer( names=self._normalize_names + extra_names, - residual_scaled_names=self.prognostic_names + extra_residual_scaled_names, + residual_scaled_names=self.get_prognostic_names() + + extra_residual_scaled_names, ) @classmethod @@ -386,7 +386,6 @@ def __init__( self.module = dist.wrap_module(module) self._img_shape = dataset_info.img_shape self._config = config - self._no_optimization = NullOptimization() self._timestep = timestep @@ -483,7 +482,7 @@ def network_call(input_norm: TensorDict) -> TensorDict: corrector=self._corrector, ocean=self.ocean, residual_prediction=self._config.residual_prediction, - prognostic_names=self.prognostic_names, + prognostic_names=self.get_prognostic_names(), prescribed_prognostic_names=self._config.prescribed_prognostic_names, ) diff --git a/fme/ace/stepper/single_module.py b/fme/ace/stepper/single_module.py index afb2ea5fd..1533482d0 100644 --- a/fme/ace/stepper/single_module.py +++ b/fme/ace/stepper/single_module.py @@ -525,7 +525,7 @@ def get_evaluation_window_data_requirements( def get_prognostic_state_data_requirements(self) -> PrognosticStateDataRequirements: return PrognosticStateDataRequirements( - names=self.prognostic_names, + names=self.get_prognostic_names(), n_timesteps=self.n_ic_timesteps, ) @@ -672,10 +672,9 @@ def next_step_forcing_names(self) -> list[str]: """ return self.step.get_next_step_forcing_names() - @property - def prognostic_names(self) -> list[str]: + def get_prognostic_names(self) -> list[str]: """Names of variables which both inputs and outputs.""" - return self.step.prognostic_names + return self.step.get_prognostic_names() @property def output_names(self) -> list[str]: @@ -971,9 +970,8 @@ def get_base_weights(self) -> Weights | None: """ return self._parameter_initializer.base_weights - @property - def prognostic_names(self) -> list[str]: - return self._step_obj.prognostic_names + def get_prognostic_names(self) -> list[str]: + return self._step_obj.get_prognostic_names() @property def out_names(self) -> list[str]: @@ -1174,7 +1172,9 @@ def predict( ) .remove_initial_condition(self.n_ic_timesteps) ) - prognostic_state = data.get_end(self.prognostic_names, self.n_ic_timesteps) + prognostic_state = data.get_end( + self.get_prognostic_names(), self.n_ic_timesteps + ) data = BatchData.new_on_device( data=data.data, time=data.time, @@ -1494,7 +1494,7 @@ def __init__( self._epoch: int | None = None # to keep track of cached values - self._prognostic_names = self._stepper.prognostic_names + self._prognostic_names = self._stepper.get_prognostic_names() self._derive_func = self._stepper.derive_func self._loss_obj = self._stepper.build_loss(config.loss) diff --git a/fme/core/generics/test_looper.py b/fme/core/generics/test_looper.py index 017e69f09..5f37bb86e 100644 --- a/fme/core/generics/test_looper.py +++ b/fme/core/generics/test_looper.py @@ -173,7 +173,7 @@ def test_looper(): data={n: spherical_data.data[n][:, :1] for n in spherical_data.data}, time=time[:, 0:1], ).get_start( - prognostic_names=stepper.prognostic_names, + prognostic_names=stepper.get_prognostic_names(), n_ic_timesteps=1, ) loader = MockLoader(shape, forcing_names, 3, time=time) @@ -197,7 +197,7 @@ def test_looper_paired(): data={n: spherical_data.data[n][:, :1] for n in spherical_data.data}, time=time[:, 0:1], ).get_start( - prognostic_names=stepper.prognostic_names, + prognostic_names=stepper.get_prognostic_names(), n_ic_timesteps=1, ) loader = MockLoader(shape, forcing_names, 3, time=time) @@ -235,7 +235,7 @@ def test_looper_paired_with_derived_variables(): data={n: spherical_data.data[n][:, :1] for n in spherical_data.data}, time=time[:, 0:1], ).get_start( - prognostic_names=stepper.prognostic_names, + prognostic_names=stepper.get_prognostic_names(), n_ic_timesteps=1, ) loader = MockLoader(shape, forcing_names, 2, time=time) @@ -258,7 +258,7 @@ def test_looper_paired_with_target_data(): data={n: spherical_data.data[n][:, :1] for n in spherical_data.data}, time=time[:, 0:1], ).get_start( - prognostic_names=stepper.prognostic_names, + prognostic_names=stepper.get_prognostic_names(), n_ic_timesteps=1, ) loader = MockLoader(shape, all_names, 2, time=time) @@ -284,7 +284,7 @@ def test_looper_paired_with_target_data_and_derived_variables(): data={n: spherical_data.data[n][:, :1] for n in spherical_data.data}, time=time[:, 0:1], ).get_start( - prognostic_names=stepper.prognostic_names, + prognostic_names=stepper.get_prognostic_names(), n_ic_timesteps=1, ) loader = MockLoader(shape, all_names, 2, time=time) diff --git a/fme/core/registry/__init__.py b/fme/core/registry/__init__.py index d832a4545..0ecdbbab2 100644 --- a/fme/core/registry/__init__.py +++ b/fme/core/registry/__init__.py @@ -1,3 +1,4 @@ from .corrector import CorrectorSelector from .module import ModuleSelector from .registry import Registry +from .separated_module import SeparatedModuleSelector diff --git a/fme/core/registry/separated_module.py b/fme/core/registry/separated_module.py new file mode 100644 index 000000000..6505c6d6a --- /dev/null +++ b/fme/core/registry/separated_module.py @@ -0,0 +1,204 @@ +import abc +import dataclasses +from collections.abc import Callable, Mapping +from typing import Any, ClassVar + +import dacite +import torch +from torch import nn + +from fme.core.dataset_info import DatasetInfo +from fme.core.labels import BatchLabels, LabelEncoding + +from .registry import Registry + +SeparatedModuleConfigType = type["SeparatedModuleConfig"] + + +@dataclasses.dataclass +class SeparatedModuleConfig(abc.ABC): + """ + Builds an nn.Module that takes separate forcing and prognostic tensors + as input, and returns separate prognostic and diagnostic tensors as output. + + The built nn.Module must have the forward signature:: + + forward( + forcing: Tensor, + prognostic: Tensor, + labels: Tensor | None = None, + ) -> tuple[Tensor, Tensor] + + where the return value is (prognostic_out, diagnostic_out). + """ + + @abc.abstractmethod + def build( + self, + n_forcing_channels: int, + n_prognostic_channels: int, + n_diagnostic_channels: int, + dataset_info: DatasetInfo, + ) -> nn.Module: + """ + Build a nn.Module with separated forcing/prognostic/diagnostic channels. + + Args: + n_forcing_channels: number of input-only (forcing) channels + n_prognostic_channels: number of input-output (prognostic) channels + n_diagnostic_channels: number of output-only (diagnostic) channels + dataset_info: Information about the dataset, including img_shape, + horizontal coordinates, vertical coordinate, etc. + + Returns: + An nn.Module whose forward method takes + (forcing, prognostic, labels=None) and returns + (prognostic_out, diagnostic_out) tensors. + """ + ... + + @classmethod + def from_state(cls, state: Mapping[str, Any]) -> "SeparatedModuleConfig": + return dacite.from_dict( + data_class=cls, data=state, config=dacite.Config(strict=True) + ) + + +class SeparatedModule: + """ + Wrapper around an nn.Module with separated channel interface. + + The wrapped module takes (forcing, prognostic) tensors and returns + (prognostic_out, diagnostic_out) tensors. This wrapper handles + optional label encoding for conditional models. + """ + + def __init__(self, module: nn.Module, label_encoding: LabelEncoding | None): + self._module = module + self._label_encoding = label_encoding + + def __call__( + self, + forcing: torch.Tensor, + prognostic: torch.Tensor, + labels: BatchLabels | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if labels is not None and self._label_encoding is None: + raise ValueError( + "labels were provided but the module has no label encoding" + ) + if labels is None and self._label_encoding is not None: + raise ValueError( + "labels were not provided but the module has a label encoding" + ) + if labels is not None and self._label_encoding is not None: + encoded_labels = labels.conform_to_encoding(self._label_encoding) + return self._module(forcing, prognostic, labels=encoded_labels.tensor) + else: + return self._module(forcing, prognostic, labels=None) + + @property + def torch_module(self) -> nn.Module: + return self._module + + def get_state(self) -> dict[str, Any]: + if self._label_encoding is not None: + label_encoder_state = self._label_encoding.get_state() + else: + label_encoder_state = None + return { + **self._module.state_dict(), + "label_encoding": label_encoder_state, + } + + def load_state(self, state: dict[str, Any]) -> None: + state = state.copy() + if state.get("label_encoding") is not None: + if self._label_encoding is None: + self._label_encoding = LabelEncoding.from_state( + state.pop("label_encoding") + ) + else: + self._label_encoding.conform_to_state(state.pop("label_encoding")) + state.pop("label_encoding", None) + self._module.load_state_dict(state) + + def wrap_module( + self, callable: Callable[[nn.Module], nn.Module] + ) -> "SeparatedModule": + return SeparatedModule(callable(self._module), self._label_encoding) + + def to(self, device: torch.device) -> "SeparatedModule": + return SeparatedModule(self._module.to(device), self._label_encoding) + + +@dataclasses.dataclass +class SeparatedModuleSelector: + """ + A dataclass for building SeparatedModuleConfig instances from config dicts. + + Mirrors ModuleSelector but for the separated channel interface. + + Parameters: + type: the type of the SeparatedModuleConfig + config: data for a SeparatedModuleConfig instance of the indicated type + """ + + type: str + config: Mapping[str, Any] + registry: ClassVar[Registry[SeparatedModuleConfig]] = Registry[ + SeparatedModuleConfig + ]() + + def __post_init__(self): + if not isinstance(self.registry, Registry): + raise ValueError( + "SeparatedModuleSelector.registry should not be set manually" + ) + self._instance = self.registry.get(self.type, self.config) + + @property + def module_config(self) -> SeparatedModuleConfig: + return self._instance + + @classmethod + def register( + cls, type_name: str + ) -> Callable[[SeparatedModuleConfigType], SeparatedModuleConfigType]: + return cls.registry.register(type_name) + + def build( + self, + n_forcing_channels: int, + n_prognostic_channels: int, + n_diagnostic_channels: int, + dataset_info: DatasetInfo, + ) -> SeparatedModule: + """ + Build a SeparatedModule with separated forcing/prognostic/diagnostic + channels. + + Args: + n_forcing_channels: number of input-only (forcing) channels + n_prognostic_channels: number of input-output (prognostic) channels + n_diagnostic_channels: number of output-only (diagnostic) channels + dataset_info: Information about the dataset, including img_shape. + + Returns: + a SeparatedModule object + """ + if len(dataset_info.all_labels) > 0: + label_encoding = LabelEncoding(sorted(list(dataset_info.all_labels))) + else: + label_encoding = None + module = self._instance.build( + n_forcing_channels=n_forcing_channels, + n_prognostic_channels=n_prognostic_channels, + n_diagnostic_channels=n_diagnostic_channels, + dataset_info=dataset_info, + ) + return SeparatedModule(module, label_encoding) + + @classmethod + def get_available_types(cls): + return cls.registry._types.keys() diff --git a/fme/core/registry/test_separated_module.py b/fme/core/registry/test_separated_module.py new file mode 100644 index 000000000..009b9ffbd --- /dev/null +++ b/fme/core/registry/test_separated_module.py @@ -0,0 +1,109 @@ +import pytest +import torch + +from fme.core.dataset_info import DatasetInfo +from fme.core.labels import BatchLabels, LabelEncoding +from fme.core.registry.separated_module import SeparatedModule, SeparatedModuleSelector +from fme.core.registry.testing import _SimpleSeparatedModule, register_test_types + +register_test_types() + + +# --- Tests for SeparatedModuleSelector --- + + +def test_register_and_build(): + selector = SeparatedModuleSelector(type="test_simple", config={}) + dataset_info = DatasetInfo(img_shape=(4, 8)) + module = selector.build( + n_forcing_channels=2, + n_prognostic_channels=3, + n_diagnostic_channels=1, + dataset_info=dataset_info, + ) + assert isinstance(module, SeparatedModule) + assert isinstance(module.torch_module, _SimpleSeparatedModule) + + +def test_separated_module_forward(): + inner = _SimpleSeparatedModule(n_forcing=2, n_prognostic=3, n_diagnostic=1) + module = SeparatedModule(inner, label_encoding=None) + + forcing = torch.randn(2, 2, 4, 8) + prognostic = torch.randn(2, 3, 4, 8) + prog_out, diag_out = module(forcing, prognostic) + + assert prog_out.shape == (2, 3, 4, 8) + assert diag_out.shape == (2, 1, 4, 8) + + +def test_separated_module_get_and_load_state(): + inner = _SimpleSeparatedModule(n_forcing=2, n_prognostic=3, n_diagnostic=1) + module = SeparatedModule(inner, label_encoding=None) + + state = module.get_state() + assert "label_encoding" in state + assert state["label_encoding"] is None + + inner2 = _SimpleSeparatedModule(n_forcing=2, n_prognostic=3, n_diagnostic=1) + module2 = SeparatedModule(inner2, label_encoding=None) + module2.load_state(state) + + for p1, p2 in zip(inner.parameters(), inner2.parameters()): + assert torch.equal(p1, p2) + + +def test_separated_module_wrap_module(): + inner = _SimpleSeparatedModule(n_forcing=2, n_prognostic=3, n_diagnostic=1) + module = SeparatedModule(inner, label_encoding=None) + + wrapped = module.wrap_module(lambda m: m) + assert isinstance(wrapped, SeparatedModule) + + forcing = torch.randn(1, 2, 4, 8) + prognostic = torch.randn(1, 3, 4, 8) + out1 = module(forcing, prognostic) + out2 = wrapped(forcing, prognostic) + assert torch.equal(out1[0], out2[0]) + assert torch.equal(out1[1], out2[1]) + + +def test_separated_module_accepts_none_labels(): + inner = _SimpleSeparatedModule(n_forcing=2, n_prognostic=3, n_diagnostic=1) + module = SeparatedModule(inner, label_encoding=None) + + forcing = torch.randn(1, 2, 4, 8) + prognostic = torch.randn(1, 3, 4, 8) + + # Should work with no labels + prog_out, diag_out = module(forcing, prognostic) + assert prog_out.shape == (1, 3, 4, 8) + assert diag_out.shape == (1, 1, 4, 8) + + # Should also work with explicit None + prog_out, diag_out = module(forcing, prognostic, labels=None) + assert prog_out.shape == (1, 3, 4, 8) + + +def test_separated_module_raises_if_labels_without_encoding(): + inner = _SimpleSeparatedModule(n_forcing=2, n_prognostic=3, n_diagnostic=1) + module = SeparatedModule(inner, label_encoding=None) + + forcing = torch.randn(1, 2, 4, 8) + prognostic = torch.randn(1, 3, 4, 8) + labels = BatchLabels(tensor=torch.tensor([[0.0]]), names=["task"]) + + with pytest.raises(ValueError, match="no label encoding"): + module(forcing, prognostic, labels=labels) + + +def test_separated_module_raises_if_encoding_without_labels(): + inner = _SimpleSeparatedModule(n_forcing=2, n_prognostic=3, n_diagnostic=1) + label_encoding = LabelEncoding(["a", "b"]) + module = SeparatedModule(inner, label_encoding=label_encoding) + + forcing = torch.randn(1, 2, 4, 8) + prognostic = torch.randn(1, 3, 4, 8) + + with pytest.raises(ValueError, match="labels were not provided"): + module(forcing, prognostic, labels=None) diff --git a/fme/core/registry/testing.py b/fme/core/registry/testing.py new file mode 100644 index 000000000..07c9f9fc5 --- /dev/null +++ b/fme/core/registry/testing.py @@ -0,0 +1,47 @@ +import dataclasses + +import torch +from torch import nn + +from .separated_module import SeparatedModuleConfig, SeparatedModuleSelector + + +class _SimpleSeparatedModule(nn.Module): + """A trivial module for testing the separated interface.""" + + def __init__(self, n_forcing: int, n_prognostic: int, n_diagnostic: int): + super().__init__() + n_in = n_forcing + n_prognostic + self.prog_linear = nn.Linear(n_in, n_prognostic, bias=False) + self.diag_linear = nn.Linear(n_in, n_diagnostic, bias=False) + + def forward( + self, + forcing: torch.Tensor, + prognostic: torch.Tensor, + labels: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + combined = torch.cat([forcing, prognostic], dim=-3) + b, c, h, w = combined.shape + flat = combined.permute(0, 2, 3, 1).reshape(-1, c) + prog_out = self.prog_linear(flat).reshape(b, h, w, -1).permute(0, 3, 1, 2) + diag_out = self.diag_linear(flat).reshape(b, h, w, -1).permute(0, 3, 1, 2) + return prog_out, diag_out + + +def register_test_types() -> None: + """Register test-only module types. Call from tests before use.""" + + @SeparatedModuleSelector.register("test_simple") + @dataclasses.dataclass + class _SimpleSeparatedBuilder(SeparatedModuleConfig): + def build( + self, + n_forcing_channels, + n_prognostic_channels, + n_diagnostic_channels, + dataset_info, + ): + return _SimpleSeparatedModule( + n_forcing_channels, n_prognostic_channels, n_diagnostic_channels + ) diff --git a/fme/core/step/__init__.py b/fme/core/step/__init__.py index 081b4a30a..4e7a1f02d 100644 --- a/fme/core/step/__init__.py +++ b/fme/core/step/__init__.py @@ -1,4 +1,5 @@ from .multi_call import MultiCallStep, MultiCallStepConfig from .radiation import SeparateRadiationStep, SeparateRadiationStepConfig +from .separated_module import SeparatedModuleStep, SeparatedModuleStepConfig from .single_module import SingleModuleStep, SingleModuleStepConfig from .step import StepABC, StepConfigABC, StepSelector diff --git a/fme/core/step/radiation.py b/fme/core/step/radiation.py index f670b96bc..e52ecc863 100644 --- a/fme/core/step/radiation.py +++ b/fme/core/step/radiation.py @@ -15,7 +15,6 @@ from fme.core.distributed import Distributed from fme.core.normalizer import NetworkAndLossNormalizationConfig, StandardNormalizer from fme.core.ocean import Ocean, OceanConfig -from fme.core.optimization import NullOptimization from fme.core.packer import Packer from fme.core.registry import CorrectorSelector, ModuleSelector from fme.core.step.args import StepArgs @@ -128,7 +127,8 @@ def get_loss_normalizer( extra_residual_scaled_names = [] return self.normalization.get_loss_normalizer( names=self._normalize_names + extra_names, - residual_scaled_names=self.prognostic_names + extra_residual_scaled_names, + residual_scaled_names=self.get_prognostic_names() + + extra_residual_scaled_names, ) @classmethod @@ -288,7 +288,6 @@ def __init__( self.radiation_module = radiation_module.to(get_device()) self._img_shape = dataset_info.img_shape self._config = config - self._no_optimization = NullOptimization() init_weights(self.modules) dist = Distributed.get_instance() @@ -395,7 +394,7 @@ def network_calls(input_norm: TensorDict) -> TensorDict: corrector=self._corrector, ocean=self.ocean, residual_prediction=self._config.residual_prediction, - prognostic_names=self.prognostic_names, + prognostic_names=self.get_prognostic_names(), ) def get_regularizer_loss(self) -> torch.Tensor: diff --git a/fme/core/step/separated_module.py b/fme/core/step/separated_module.py new file mode 100644 index 000000000..2dcb1e4f8 --- /dev/null +++ b/fme/core/step/separated_module.py @@ -0,0 +1,384 @@ +import dataclasses +import logging +from collections.abc import Callable +from typing import Any + +import dacite +import torch +from torch import nn + +from fme.core.corrector.atmosphere import AtmosphereCorrectorConfig +from fme.core.corrector.registry import CorrectorABC +from fme.core.dataset_info import DatasetInfo +from fme.core.device import get_device +from fme.core.distributed import Distributed +from fme.core.normalizer import NetworkAndLossNormalizationConfig, StandardNormalizer +from fme.core.ocean import Ocean, OceanConfig +from fme.core.packer import Packer +from fme.core.registry import CorrectorSelector, SeparatedModuleSelector +from fme.core.step.args import StepArgs +from fme.core.step.secondary_decoder import ( + NoSecondaryDecoder, + SecondaryDecoder, + SecondaryDecoderConfig, +) +from fme.core.step.single_module import step_with_adjustments +from fme.core.step.step import StepABC, StepConfigABC, StepSelector +from fme.core.typing_ import TensorDict, TensorMapping + + +@StepSelector.register("separated_module") +@dataclasses.dataclass +class SeparatedModuleStepConfig(StepConfigABC): + """ + Configuration for a step using a module with separated channel interface. + + Unlike SingleModuleStepConfig which uses in_names/out_names and derives + forcing/prognostic/diagnostic from set operations, this config specifies + the three channel categories explicitly. + + Parameters: + builder: The separated module builder. + forcing_names: Names of input-only (forcing) variables. + prognostic_names: Names of input-output (prognostic) variables. + diagnostic_names: Names of output-only (diagnostic) variables. + normalization: The normalization configuration. + secondary_decoder: Configuration for the secondary decoder that computes + additional diagnostic variables from outputs. + ocean: The ocean configuration. + corrector: The corrector configuration. + next_step_forcing_names: Names of forcing variables for the next timestep. + prescribed_prognostic_names: Prognostic variable names to overwrite from + forcing data at each step (e.g. for inference with observed values). + residual_prediction: Whether to use residual prediction. + """ + + builder: SeparatedModuleSelector + prognostic_names: list[str] + normalization: NetworkAndLossNormalizationConfig + forcing_names: list[str] = dataclasses.field(default_factory=list) + diagnostic_names: list[str] = dataclasses.field(default_factory=list) + secondary_decoder: SecondaryDecoderConfig | None = None + ocean: OceanConfig | None = None + corrector: AtmosphereCorrectorConfig | CorrectorSelector = dataclasses.field( + default_factory=lambda: AtmosphereCorrectorConfig() + ) + next_step_forcing_names: list[str] = dataclasses.field(default_factory=list) + prescribed_prognostic_names: list[str] = dataclasses.field(default_factory=list) + residual_prediction: bool = False + + def __post_init__(self): + if len(self.prognostic_names) == 0: + raise ValueError("prognostic_names must not be empty") + all_names = self.forcing_names + self.prognostic_names + self.diagnostic_names + if len(all_names) != len(set(all_names)): + seen: dict[str, str] = {} + for name_list, label in ( + (self.forcing_names, "forcing_names"), + (self.prognostic_names, "prognostic_names"), + (self.diagnostic_names, "diagnostic_names"), + ): + for name in name_list: + if name in seen: + raise ValueError( + f"Name '{name}' appears in both " + f"{seen[name]} and {label}." + ) + seen[name] = label + for name in self.prescribed_prognostic_names: + if name not in self.prognostic_names: + raise ValueError( + f"prescribed_prognostic_name '{name}' must be in " + f"prognostic_names: {self.prognostic_names}" + ) + for name in self.next_step_forcing_names: + if name not in self.forcing_names: + raise ValueError( + f"next_step_forcing_name '{name}' not in " + f"forcing_names: {self.forcing_names}" + ) + if self.secondary_decoder is not None: + for name in self.secondary_decoder.secondary_diagnostic_names: + if name in self.forcing_names: + raise ValueError( + f"secondary_diagnostic_name is a forcing variable: '{name}'" + ) + if name in self.prognostic_names: + raise ValueError( + f"secondary_diagnostic_name is a prognostic variable: " + f"'{name}'" + ) + if name in self.diagnostic_names: + raise ValueError( + f"secondary_diagnostic_name is a diagnostic variable: " + f"'{name}'" + ) + + @property + def n_ic_timesteps(self) -> int: + return 1 + + def get_state(self): + return dataclasses.asdict(self) + + def get_loss_normalizer( + self, + extra_names: list[str] | None = None, + extra_residual_scaled_names: list[str] | None = None, + ) -> StandardNormalizer: + if extra_names is None: + extra_names = [] + if extra_residual_scaled_names is None: + extra_residual_scaled_names = [] + return self.normalization.get_loss_normalizer( + names=self._normalize_names + extra_names, + residual_scaled_names=(self.prognostic_names + extra_residual_scaled_names), + ) + + @classmethod + def from_state(cls, state) -> "SeparatedModuleStepConfig": + return dacite.from_dict( + data_class=cls, data=state, config=dacite.Config(strict=True) + ) + + @property + def _normalize_names(self): + """Names of variables which require normalization.""" + return list(set(self.forcing_names).union(self.output_names)) + + @property + def input_names(self) -> list[str]: + names = self.forcing_names + self.prognostic_names + if self.ocean is not None: + names = list(set(names).union(self.ocean.forcing_names)) + return names + + def get_next_step_forcing_names(self) -> list[str]: + return self.next_step_forcing_names + + @property + def output_names(self) -> list[str]: + names = self.prognostic_names + self.diagnostic_names + if self.secondary_decoder is not None: + names = list( + set(names).union(self.secondary_decoder.secondary_diagnostic_names) + ) + return names + + @property + def next_step_input_names(self) -> list[str]: + input_only_names = set(self.input_names).difference(self.output_names) + result = set(input_only_names) + if self.ocean is not None: + result = result.union(self.ocean.forcing_names) + result = result.union(self.prescribed_prognostic_names) + return list(result) + + @property + def loss_names(self) -> list[str]: + return self.output_names + + def replace_ocean(self, ocean: OceanConfig | None): + self.ocean = ocean + + def get_ocean(self) -> OceanConfig | None: + return self.ocean + + def replace_prescribed_prognostic_names(self, names: list[str]) -> None: + for name in names: + if name not in self.prognostic_names: + raise ValueError( + f"prescribed_prognostic_name '{name}' must be in " + f"prognostic_names: {self.prognostic_names}" + ) + self.prescribed_prognostic_names = names + + def get_step( + self, + dataset_info: DatasetInfo, + init_weights: Callable[[list[nn.Module]], None], + ) -> "SeparatedModuleStep": + logging.info("Initializing separated module stepper from provided config") + corrector = self.corrector.get_corrector(dataset_info) + normalizer = self.normalization.get_network_normalizer(self._normalize_names) + return SeparatedModuleStep( + config=self, + dataset_info=dataset_info, + corrector=corrector, + normalizer=normalizer, + init_weights=init_weights, + ) + + def load(self): + self.normalization.load() + + +class SeparatedModuleStep(StepABC): + """ + Step class for a module with separated forcing/prognostic/diagnostic + channel interface. + """ + + CHANNEL_DIM = -3 + + def __init__( + self, + config: SeparatedModuleStepConfig, + dataset_info: DatasetInfo, + corrector: CorrectorABC, + normalizer: StandardNormalizer, + init_weights: Callable[[list[nn.Module]], None], + ): + super().__init__() + n_forcing = len(config.forcing_names) + n_prognostic = len(config.prognostic_names) + n_diagnostic = len(config.diagnostic_names) + + self.forcing_packer = Packer(config.forcing_names) + self.prognostic_packer = Packer(config.prognostic_names) + self.diagnostic_packer = Packer(config.diagnostic_names) + + self._normalizer = normalizer + if config.ocean is not None: + self.ocean: Ocean | None = config.ocean.build( + config.forcing_names + config.prognostic_names, + config.prognostic_names + config.diagnostic_names, + dataset_info.timestep, + ) + else: + self.ocean = None + + module = config.builder.build( + n_forcing_channels=n_forcing, + n_prognostic_channels=n_prognostic, + n_diagnostic_channels=n_diagnostic, + dataset_info=dataset_info, + ) + self.module = module.to(get_device()) + + dist = Distributed.get_instance() + + if config.secondary_decoder is not None: + self.secondary_decoder: SecondaryDecoder | NoSecondaryDecoder = ( + config.secondary_decoder.build( + n_in_channels=n_prognostic + n_diagnostic, + ).to(get_device()) + ) + else: + self.secondary_decoder = NoSecondaryDecoder() + + init_weights(self.modules) + self._img_shape = dataset_info.img_shape + self._config = config + + self.module = self.module.wrap_module(dist.wrap_module) + self.secondary_decoder = self.secondary_decoder.wrap_module(dist.wrap_module) + self._timestep = dataset_info.timestep + + self._corrector = corrector + + @property + def config(self) -> SeparatedModuleStepConfig: + return self._config + + @property + def normalizer(self) -> StandardNormalizer: + return self._normalizer + + @property + def surface_temperature_name(self) -> str | None: + if self._config.ocean is not None: + return self._config.ocean.surface_temperature_name + return None + + @property + def ocean_fraction_name(self) -> str | None: + if self._config.ocean is not None: + return self._config.ocean.ocean_fraction_name + return None + + def prescribe_sst( + self, + mask_data: TensorMapping, + gen_data: TensorMapping, + target_data: TensorMapping, + ) -> TensorDict: + if self.ocean is None: + raise RuntimeError( + "The Ocean interface is missing but required to prescribe " + "sea surface temperature." + ) + return self.ocean.prescriber(mask_data, gen_data, target_data) + + @property + def modules(self) -> nn.ModuleList: + modules = [self.module.torch_module] + modules.extend(self.secondary_decoder.torch_modules) + return nn.ModuleList(modules) + + def step( + self, + args: StepArgs, + wrapper: Callable[[nn.Module], nn.Module] = lambda x: x, + ) -> TensorDict: + def network_call(input_norm: TensorDict) -> TensorDict: + prognostic_in = self.prognostic_packer.pack( + input_norm, axis=self.CHANNEL_DIM + ) + + if len(self.forcing_packer.names) > 0: + forcing = self.forcing_packer.pack(input_norm, axis=self.CHANNEL_DIM) + else: + forcing = torch.zeros( + *prognostic_in.shape[:-3], + 0, + *prognostic_in.shape[-2:], + dtype=prognostic_in.dtype, + device=prognostic_in.device, + ) + + prog_out, diag_out = self.module.wrap_module(wrapper)( + forcing, prognostic_in, labels=args.labels + ) + + output_dict: TensorDict = {} + output_dict.update( + self.prognostic_packer.unpack(prog_out, axis=self.CHANNEL_DIM) + ) + if len(self.diagnostic_packer.names) > 0: + output_dict.update( + self.diagnostic_packer.unpack(diag_out, axis=self.CHANNEL_DIM) + ) + + # Secondary decoder gets concatenated prog+diag output + combined_out = torch.cat([prog_out, diag_out], dim=self.CHANNEL_DIM) + secondary_output_dict = self.secondary_decoder.wrap_module(wrapper)( + combined_out.detach() + ) + output_dict.update(secondary_output_dict) + return output_dict + + return step_with_adjustments( + input=args.input, + next_step_input_data=args.next_step_input_data, + network_calls=network_call, + normalizer=self.normalizer, + corrector=self._corrector, + ocean=self.ocean, + residual_prediction=self._config.residual_prediction, + prognostic_names=self._config.prognostic_names, + prescribed_prognostic_names=self._config.prescribed_prognostic_names, + ) + + def get_regularizer_loss(self): + return torch.tensor(0.0) + + def get_state(self): + return { + "module": self.module.get_state(), + "secondary_decoder": self.secondary_decoder.get_module_state(), + } + + def load_state(self, state: dict[str, Any]) -> None: + self.module.load_state(state["module"]) + self.secondary_decoder.load_module_state(state["secondary_decoder"]) diff --git a/fme/core/step/single_module.py b/fme/core/step/single_module.py index 9eb94ee79..0744c896c 100644 --- a/fme/core/step/single_module.py +++ b/fme/core/step/single_module.py @@ -17,7 +17,6 @@ from fme.core.distributed import Distributed from fme.core.normalizer import NetworkAndLossNormalizationConfig, StandardNormalizer from fme.core.ocean import Ocean, OceanConfig -from fme.core.optimization import NullOptimization from fme.core.packer import Packer from fme.core.registry import CorrectorSelector, ModuleSelector from fme.core.step.args import StepArgs @@ -114,7 +113,8 @@ def get_loss_normalizer( extra_residual_scaled_names = [] return self.normalization.get_loss_normalizer( names=self._normalize_names + extra_names, - residual_scaled_names=self.prognostic_names + extra_residual_scaled_names, + residual_scaled_names=self.get_prognostic_names() + + extra_residual_scaled_names, ) @classmethod @@ -279,7 +279,6 @@ def __init__( init_weights(self.modules) self._img_shape = dataset_info.img_shape self._config = config - self._no_optimization = NullOptimization() self.module = self.module.wrap_module(dist.wrap_module) self.secondary_decoder = self.secondary_decoder.wrap_module(dist.wrap_module) @@ -369,7 +368,7 @@ def network_call(input_norm: TensorDict) -> TensorDict: corrector=self._corrector, ocean=self.ocean, residual_prediction=self._config.residual_prediction, - prognostic_names=self.prognostic_names, + prognostic_names=self.get_prognostic_names(), prescribed_prognostic_names=self._config.prescribed_prognostic_names, ) diff --git a/fme/core/step/step.py b/fme/core/step/step.py index 6c339834f..1eababa74 100644 --- a/fme/core/step/step.py +++ b/fme/core/step/step.py @@ -64,9 +64,8 @@ def next_step_input_names(self) -> list[str]: """ pass - @property @final - def prognostic_names(self) -> list[str]: + def get_prognostic_names(self) -> list[str]: return list(set(self.input_names).intersection(self.output_names)) @property @@ -251,10 +250,9 @@ def input_names(self) -> list[str]: def output_names(self) -> list[str]: return self.config.output_names - @property @final - def prognostic_names(self) -> list[str]: - return self.config.prognostic_names + def get_prognostic_names(self) -> list[str]: + return self.config.get_prognostic_names() @property @final diff --git a/fme/core/step/test_separated_module.py b/fme/core/step/test_separated_module.py new file mode 100644 index 000000000..407a2433f --- /dev/null +++ b/fme/core/step/test_separated_module.py @@ -0,0 +1,351 @@ +import dataclasses +from datetime import timedelta + +import pytest +import torch + +import fme +from fme.core.coordinates import HybridSigmaPressureCoordinate, LatLonCoordinates +from fme.core.dataset_info import DatasetInfo +from fme.core.normalizer import NetworkAndLossNormalizationConfig, NormalizationConfig +from fme.core.registry import SeparatedModuleSelector +from fme.core.registry.testing import register_test_types +from fme.core.step.args import StepArgs +from fme.core.step.separated_module import SeparatedModuleStepConfig +from fme.core.step.step import StepSelector + +register_test_types() + +IMG_SHAPE = (16, 32) +TIMESTEP = timedelta(hours=6) + + +def _get_dataset_info(): + device = fme.get_device() + return DatasetInfo( + horizontal_coordinates=LatLonCoordinates( + lat=torch.zeros(IMG_SHAPE[0], device=device), + lon=torch.zeros(IMG_SHAPE[1], device=device), + ), + vertical_coordinate=HybridSigmaPressureCoordinate( + ak=torch.arange(7, device=device), + bk=torch.arange(7, device=device), + ), + timestep=TIMESTEP, + ) + + +def _get_normalization(names): + return NetworkAndLossNormalizationConfig( + network=NormalizationConfig( + means={name: 0.0 for name in names}, + stds={name: 1.0 for name in names}, + ), + ) + + +def _get_tensor_dict(names, n_samples=2): + device = fme.get_device() + return {name: torch.rand(n_samples, *IMG_SHAPE, device=device) for name in names} + + +class TestSeparatedModuleStepConfig: + def test_duplicate_names_raises(self): + normalization = _get_normalization(["a", "b"]) + with pytest.raises(ValueError, match="appears in both"): + SeparatedModuleStepConfig( + builder=SeparatedModuleSelector( + type="test_simple", + config={}, + ), + forcing_names=["a"], + prognostic_names=["a"], + diagnostic_names=["b"], + normalization=normalization, + ) + + def test_prescribed_prognostic_not_in_prognostic_raises(self): + normalization = _get_normalization(["f", "p", "d"]) + with pytest.raises(ValueError, match="prescribed_prognostic_name"): + SeparatedModuleStepConfig( + builder=SeparatedModuleSelector( + type="test_simple", + config={}, + ), + forcing_names=["f"], + prognostic_names=["p"], + diagnostic_names=["d"], + normalization=normalization, + prescribed_prognostic_names=["d"], + ) + + def test_next_step_forcing_not_in_forcing_raises(self): + normalization = _get_normalization(["f", "p", "d"]) + with pytest.raises(ValueError, match="next_step_forcing_name"): + SeparatedModuleStepConfig( + builder=SeparatedModuleSelector( + type="test_simple", + config={}, + ), + forcing_names=["f"], + prognostic_names=["p"], + diagnostic_names=["d"], + normalization=normalization, + next_step_forcing_names=["p"], + ) + + def test_empty_prognostic_names_raises(self): + normalization = _get_normalization(["f", "d"]) + with pytest.raises(ValueError, match="prognostic_names must not be empty"): + SeparatedModuleStepConfig( + builder=SeparatedModuleSelector( + type="test_simple", + config={}, + ), + forcing_names=["f"], + prognostic_names=[], + diagnostic_names=["d"], + normalization=normalization, + ) + + def test_input_output_names(self): + normalization = _get_normalization(["f1", "f2", "p1", "p2", "d1"]) + config = SeparatedModuleStepConfig( + builder=SeparatedModuleSelector( + type="test_simple", + config={}, + ), + forcing_names=["f1", "f2"], + prognostic_names=["p1", "p2"], + diagnostic_names=["d1"], + normalization=normalization, + ) + assert set(config.input_names) == {"f1", "f2", "p1", "p2"} + assert set(config.output_names) == {"p1", "p2", "d1"} + assert set(config.get_prognostic_names()) == {"p1", "p2"} + + def test_from_state_roundtrip(self): + normalization = _get_normalization(["f", "p", "d"]) + config = SeparatedModuleStepConfig( + builder=SeparatedModuleSelector( + type="test_simple", + config={}, + ), + forcing_names=["f"], + prognostic_names=["p"], + diagnostic_names=["d"], + normalization=normalization, + ) + state = config.get_state() + config2 = SeparatedModuleStepConfig.from_state(state) + assert config2.forcing_names == config.forcing_names + assert config2.prognostic_names == config.prognostic_names + assert config2.diagnostic_names == config.diagnostic_names + + +class TestSeparatedModuleStep: + def test_step_produces_output(self): + forcing_names = ["f1", "f2"] + prognostic_names = ["p1", "p2"] + diagnostic_names = ["d1"] + all_names = forcing_names + prognostic_names + diagnostic_names + normalization = _get_normalization(all_names) + + selector = StepSelector( + type="separated_module", + config=dataclasses.asdict( + SeparatedModuleStepConfig( + builder=SeparatedModuleSelector( + type="test_simple", + config={}, + ), + forcing_names=forcing_names, + prognostic_names=prognostic_names, + diagnostic_names=diagnostic_names, + normalization=normalization, + ), + ), + ) + + step = selector.get_step(_get_dataset_info()) + input_data = _get_tensor_dict(step.input_names) + next_step_data = _get_tensor_dict(step.next_step_input_names) + + output = step.step( + StepArgs( + input=input_data, + next_step_input_data=next_step_data, + labels=None, + ) + ) + + for name in prognostic_names + diagnostic_names: + assert name in output + assert output[name].shape == (2, *IMG_SHAPE) + + def test_get_state_and_load_state(self): + forcing_names = ["f1"] + prognostic_names = ["p1"] + diagnostic_names = ["d1"] + all_names = forcing_names + prognostic_names + diagnostic_names + normalization = _get_normalization(all_names) + + config = SeparatedModuleStepConfig( + builder=SeparatedModuleSelector( + type="test_simple", + config={}, + ), + forcing_names=forcing_names, + prognostic_names=prognostic_names, + diagnostic_names=diagnostic_names, + normalization=normalization, + ) + + dataset_info = _get_dataset_info() + step1 = config.get_step(dataset_info, lambda _: None) + state = step1.get_state() + + step2 = config.get_step(dataset_info, lambda _: None) + step2.load_state(state) + + # Verify weights match using the public .modules API + for m1, m2 in zip(step1.modules, step2.modules): + for p1, p2 in zip(m1.parameters(), m2.parameters()): + assert torch.equal(p1, p2) + + def test_residual_prediction(self): + forcing_names = ["f1"] + prognostic_names = ["p1"] + diagnostic_names = ["d1"] + all_names = forcing_names + prognostic_names + diagnostic_names + normalization = _get_normalization(all_names) + + config = SeparatedModuleStepConfig( + builder=SeparatedModuleSelector( + type="test_simple", + config={}, + ), + forcing_names=forcing_names, + prognostic_names=prognostic_names, + diagnostic_names=diagnostic_names, + normalization=normalization, + residual_prediction=True, + ) + + dataset_info = _get_dataset_info() + step = config.get_step(dataset_info, lambda _: None) + input_data = _get_tensor_dict(step.input_names) + next_step_data = _get_tensor_dict(step.next_step_input_names) + + output = step.step( + StepArgs( + input=input_data, + next_step_input_data=next_step_data, + labels=None, + ) + ) + + assert "p1" in output + assert "d1" in output + + def test_step_no_forcings(self): + prognostic_names = ["p1", "p2"] + diagnostic_names = ["d1"] + all_names = prognostic_names + diagnostic_names + normalization = _get_normalization(all_names) + + config = SeparatedModuleStepConfig( + builder=SeparatedModuleSelector( + type="test_simple", + config={}, + ), + forcing_names=[], + prognostic_names=prognostic_names, + diagnostic_names=diagnostic_names, + normalization=normalization, + ) + + dataset_info = _get_dataset_info() + step = config.get_step(dataset_info, lambda _: None) + input_data = _get_tensor_dict(step.input_names) + next_step_data = _get_tensor_dict(step.next_step_input_names) + + output = step.step( + StepArgs( + input=input_data, + next_step_input_data=next_step_data, + labels=None, + ) + ) + + for name in prognostic_names + diagnostic_names: + assert name in output + assert output[name].shape == (2, *IMG_SHAPE) + + def test_step_no_diagnostics(self): + forcing_names = ["f1"] + prognostic_names = ["p1", "p2"] + all_names = forcing_names + prognostic_names + normalization = _get_normalization(all_names) + + config = SeparatedModuleStepConfig( + builder=SeparatedModuleSelector( + type="test_simple", + config={}, + ), + forcing_names=forcing_names, + prognostic_names=prognostic_names, + diagnostic_names=[], + normalization=normalization, + ) + + dataset_info = _get_dataset_info() + step = config.get_step(dataset_info, lambda _: None) + input_data = _get_tensor_dict(step.input_names) + next_step_data = _get_tensor_dict(step.next_step_input_names) + + output = step.step( + StepArgs( + input=input_data, + next_step_input_data=next_step_data, + labels=None, + ) + ) + + for name in prognostic_names: + assert name in output + assert output[name].shape == (2, *IMG_SHAPE) + assert len(output) == len(prognostic_names) + + def test_step_no_forcings_no_diagnostics(self): + prognostic_names = ["p1", "p2"] + normalization = _get_normalization(prognostic_names) + + config = SeparatedModuleStepConfig( + builder=SeparatedModuleSelector( + type="test_simple", + config={}, + ), + forcing_names=[], + prognostic_names=prognostic_names, + diagnostic_names=[], + normalization=normalization, + ) + + dataset_info = _get_dataset_info() + step = config.get_step(dataset_info, lambda _: None) + input_data = _get_tensor_dict(step.input_names) + next_step_data = _get_tensor_dict(step.next_step_input_names) + + output = step.step( + StepArgs( + input=input_data, + next_step_input_data=next_step_data, + labels=None, + ) + ) + + for name in prognostic_names: + assert name in output + assert output[name].shape == (2, *IMG_SHAPE) + assert len(output) == len(prognostic_names) diff --git a/fme/core/step/test_step.py b/fme/core/step/test_step.py index 5d98df1cf..eefb5ca08 100644 --- a/fme/core/step/test_step.py +++ b/fme/core/step/test_step.py @@ -22,16 +22,20 @@ from fme.core.distributed.non_distributed import DummyWrapper from fme.core.labels import BatchLabels from fme.core.normalizer import NetworkAndLossNormalizationConfig, NormalizationConfig -from fme.core.registry import ModuleSelector +from fme.core.registry import ModuleSelector, SeparatedModuleSelector +from fme.core.registry.testing import register_test_types from fme.core.step.args import StepArgs from fme.core.step.multi_call import MultiCallConfig, MultiCallStepConfig from fme.core.step.secondary_decoder import SecondaryDecoderConfig +from fme.core.step.separated_module import SeparatedModuleStepConfig from fme.core.step.single_module import SingleModuleStepConfig from fme.core.step.step import StepABC, StepSelector from fme.core.typing_ import TensorDict from .radiation import SeparateRadiationStepConfig +register_test_types() + DEFAULT_IMG_SHAPE = (45, 90) @@ -374,6 +378,64 @@ def get_fcn3_selector( ) +def get_separated_module_selector( + dir: pathlib.Path | None = None, +) -> StepSelector: + normalization = get_network_and_loss_normalization_config( + names=[ + "forcing_shared", + "forcing_rad", + "prog_main", + "diagnostic_rad", + ], + dir=dir, + ) + return StepSelector( + type="separated_module", + config=dataclasses.asdict( + SeparatedModuleStepConfig( + builder=SeparatedModuleSelector( + type="test_simple", + config={}, + ), + forcing_names=["forcing_shared", "forcing_rad"], + prognostic_names=["prog_main"], + diagnostic_names=["diagnostic_rad"], + normalization=normalization, + ), + ), + ) + + +def get_separated_module_with_prognostics_selector( + dir: pathlib.Path | None = None, +) -> StepSelector: + normalization = get_network_and_loss_normalization_config( + names=[ + "forcing_a", + "prog_a", + "prog_b", + "diag_a", + ], + dir=dir, + ) + return StepSelector( + type="separated_module", + config=dataclasses.asdict( + SeparatedModuleStepConfig( + builder=SeparatedModuleSelector( + type="test_simple", + config={}, + ), + forcing_names=["forcing_a"], + prognostic_names=["prog_a", "prog_b"], + diagnostic_names=["diag_a"], + normalization=normalization, + ), + ), + ) + + def get_multi_call_selector( dir: pathlib.Path | None = None, ) -> StepSelector: @@ -403,6 +465,8 @@ def get_multi_call_selector( get_separate_radiation_selector, get_single_module_selector, get_single_module_noise_conditioned_selector, + get_separated_module_selector, + get_separated_module_with_prognostics_selector, get_multi_call_selector, ] diff --git a/fme/coupled/inference/inference.py b/fme/coupled/inference/inference.py index e1aef9694..aa343d45d 100644 --- a/fme/coupled/inference/inference.py +++ b/fme/coupled/inference/inference.py @@ -229,8 +229,8 @@ def run_inference_from_config(config: InferenceConfig): ) logging.info("Loading initial condition data") initial_condition = config.initial_condition.get_initial_condition( - ocean_prognostic_names=stepper_config.ocean.stepper.prognostic_names, - atmosphere_prognostic_names=stepper_config.atmosphere.stepper.prognostic_names, + ocean_prognostic_names=stepper_config.ocean.stepper.get_prognostic_names(), + atmosphere_prognostic_names=stepper_config.atmosphere.stepper.get_prognostic_names(), n_ensemble_per_ic=config.n_ensemble_per_ic, ) stepper = config.load_stepper() diff --git a/fme/coupled/stepper.py b/fme/coupled/stepper.py index 1fa9f3efc..63143a5bc 100644 --- a/fme/coupled/stepper.py +++ b/fme/coupled/stepper.py @@ -483,7 +483,7 @@ def _validate_component_configs(self): # validate ocean_fraction_prediction if self.ocean_fraction_prediction is not None: self.ocean_fraction_prediction.validate_ocean_prognostic_names( - self.ocean.stepper.prognostic_names, + self.ocean.stepper.get_prognostic_names(), ) self.ocean_fraction_prediction.validate_atmosphere_forcing_names( self.atmosphere.stepper.input_only_names @@ -1261,11 +1261,11 @@ def predict_paired( ), CoupledPrognosticState( ocean_data=gen_data.ocean_data.get_end( - self.ocean.prognostic_names, + self.ocean.get_prognostic_names(), self.n_ic_timesteps, ), atmosphere_data=gen_data.atmosphere_data.get_end( - self.atmosphere.prognostic_names, + self.atmosphere.get_prognostic_names(), self.atmosphere.n_ic_timesteps, ), ), @@ -1284,11 +1284,11 @@ def predict( ), CoupledPrognosticState( ocean_data=gen_data.ocean_data.get_end( - self.ocean.prognostic_names, + self.ocean.get_prognostic_names(), self.n_ic_timesteps, ), atmosphere_data=gen_data.atmosphere_data.get_end( - self.atmosphere.prognostic_names, + self.atmosphere.get_prognostic_names(), self.atmosphere.n_ic_timesteps, ), ), @@ -1513,10 +1513,10 @@ def train_on_batch( # get initial condition prognostic variables input_data = CoupledPrognosticState( atmosphere_data=data.atmosphere_data.get_start( - self.atmosphere.prognostic_names, self.n_ic_timesteps + self.atmosphere.get_prognostic_names(), self.n_ic_timesteps ), ocean_data=data.ocean_data.get_start( - self.ocean.prognostic_names, self.n_ic_timesteps + self.ocean.get_prognostic_names(), self.n_ic_timesteps ), ) diff --git a/fme/coupled/test_stepper.py b/fme/coupled/test_stepper.py index 6e00684f6..a575a9011 100644 --- a/fme/coupled/test_stepper.py +++ b/fme/coupled/test_stepper.py @@ -1330,8 +1330,8 @@ def forward(self, x): ) data = coupled_data.data - atmos_prognostic_names = coupler.atmosphere.prognostic_names - ocean_prognostic_names = coupler.ocean.prognostic_names + atmos_prognostic_names = coupler.atmosphere.get_prognostic_names() + ocean_prognostic_names = coupler.ocean.get_prognostic_names() atmos_prognostic = data.atmosphere_data.get_start( atmos_prognostic_names, n_ic_timesteps=1 ) @@ -1457,8 +1457,8 @@ def test_predict_paired_with_derived_variables(): ) data = coupled_data.data - atmos_prognostic_names = coupler.atmosphere.prognostic_names - ocean_prognostic_names = coupler.ocean.prognostic_names + atmos_prognostic_names = coupler.atmosphere.get_prognostic_names() + ocean_prognostic_names = coupler.ocean.get_prognostic_names() atmos_prognostic = data.atmosphere_data.get_start( atmos_prognostic_names, n_ic_timesteps=1 )