From 7e46478cc3977d04854e09d07a54347207b48323 Mon Sep 17 00:00:00 2001 From: James Duncan Date: Wed, 25 Feb 2026 22:45:51 -0800 Subject: [PATCH 1/3] Add OceanSaltContentBudgetConfig --- fme/core/constants.py | 1 + fme/core/corrector/ocean.py | 93 ++++++++++++++++++++++++++++- fme/core/corrector/test_ocean.py | 86 ++++++++++++++++++++++++++ fme/core/ocean_data.py | 37 +++++++++++- fme/core/ocean_derived_variables.py | 39 ++++++++++++ fme/core/test_ocean_data.py | 44 ++++++++++++++ 6 files changed, 298 insertions(+), 2 deletions(-) diff --git a/fme/core/constants.py b/fme/core/constants.py index 117ae813f..6248d4d1d 100644 --- a/fme/core/constants.py +++ b/fme/core/constants.py @@ -16,5 +16,6 @@ DENSITY_OF_WATER_CM4 = 1035.0 # kg/m^3 FREEZING_TEMPERATURE_KELVIN = 273.15 # K +REFERENCE_SALINITY_PSU = 35 # g/kg EARTH_RADIUS = 6371000.0 # m diff --git a/fme/core/corrector/ocean.py b/fme/core/corrector/ocean.py index ef9a915de..da60558a0 100644 --- a/fme/core/corrector/ocean.py +++ b/fme/core/corrector/ocean.py @@ -6,7 +6,7 @@ import dacite import torch -from fme.core.constants import FREEZING_TEMPERATURE_KELVIN +from fme.core.constants import FREEZING_TEMPERATURE_KELVIN, REFERENCE_SALINITY_PSU from fme.core.corrector.registry import CorrectorABC from fme.core.corrector.utils import force_positive from fme.core.gridded_ops import GriddedOperations @@ -79,12 +79,32 @@ class OceanHeatContentBudgetConfig: constant_unaccounted_heating: float = 0.0 +@dataclasses.dataclass +class OceanSaltContentBudgetConfig: + """Configuration for ocean salt content budget correction. + + Parameters: + method: Method to use for salt content budget correction. The available + option is "scaled_salinity", which enforces conservation of salt + content by scaling the predicted salinity by a vertically and + horizontally uniform correction factor. + constant_unaccounted_salt_flux: Area-weighted global mean + column-integrated salt flux in g/m**2/s to be added to the virtual + salt flux when conserving the salt content. This can be useful for + correcting errors in salt budget in target data. + """ + + method: Literal["scaled_salinity"] + constant_unaccounted_salt_flux: float = 0.0 + + @CorrectorSelector.register("ocean_corrector") @dataclasses.dataclass class OceanCorrectorConfig: force_positive_names: list[str] = dataclasses.field(default_factory=list) sea_ice_fraction_correction: SeaIceFractionConfig | None = None ocean_heat_content_correction: OceanHeatContentBudgetConfig | None = None + ocean_salt_content_correction: OceanSaltContentBudgetConfig | None = None @classmethod def from_state(cls, state: Mapping[str, Any]) -> "OceanCorrectorConfig": @@ -141,6 +161,22 @@ def __call__( gen_data = force_positive(gen_data, self._config.force_positive_names) if self._config.sea_ice_fraction_correction is not None: gen_data = self._config.sea_ice_fraction_correction(gen_data, input_data) + if self._config.ocean_salt_content_correction is not None: + if self._vertical_coordinate is None: + raise ValueError( + "Ocean salt content correction is turned on, but no vertical " + "coordinate is available." + ) + gen_data = _force_conserve_ocean_salt_content( + input_data, + gen_data, + forcing_data, + self._gridded_operations.area_weighted_mean, + self._vertical_coordinate, + self._timestep.total_seconds(), + self._config.ocean_salt_content_correction.method, + self._config.ocean_salt_content_correction.constant_unaccounted_salt_flux, + ) if self._config.ocean_heat_content_correction is not None: if self._vertical_coordinate is None: raise ValueError( @@ -240,3 +276,58 @@ def _force_conserve_ocean_heat_content( gen.data["sst"] - FREEZING_TEMPERATURE_KELVIN ) * heat_content_correction_ratio + FREEZING_TEMPERATURE_KELVIN return gen.data + + +def _force_conserve_ocean_salt_content( + input_data: TensorMapping, + gen_data: TensorMapping, + forcing_data: TensorMapping, + area_weighted_mean: AreaWeightedMean, + vertical_coordinate: HasOceanDepthIntegral, + timestep_seconds: float, + method: Literal["scaled_salinity"] = "scaled_salinity", + unaccounted_salt_flux: float = 0.0, +) -> TensorDict: + if method != "scaled_salinity": + raise NotImplementedError( + f"Method {method!r} not implemented for ocean salt content conservation" + ) + if "wfo" in gen_data and "wfo" in forcing_data: + raise ValueError( + "Water flux into sea water cannot be present in both gen_data and " + "forcing_data." + ) + input = OceanData(input_data, vertical_coordinate) + gen = OceanData(gen_data, vertical_coordinate) + forcing = OceanData(forcing_data) + global_gen_salt_content = area_weighted_mean( + gen.ocean_salt_content, + keepdim=True, + name="ocean_salt_content", + ) + global_input_salt_content = area_weighted_mean( + input.ocean_salt_content, + keepdim=True, + name="ocean_salt_content", + ) + try: + wfo = gen.water_flux_into_sea_water + except KeyError: + wfo = input.water_flux_into_sea_water + virtual_salt_flux = -REFERENCE_SALINITY_PSU * wfo * forcing.sea_surface_fraction + salt_flux_global_mean = area_weighted_mean( + virtual_salt_flux, + keepdim=True, + name="ocean_salt_content", + ) + expected_change_salt_content = ( + salt_flux_global_mean + unaccounted_salt_flux + ) * timestep_seconds + salt_content_correction_ratio = ( + global_input_salt_content + expected_change_salt_content + ) / global_gen_salt_content + n_levels = gen.sea_water_salinity.shape[-1] + for k in range(n_levels): + name = f"so_{k}" + gen.data[name] = gen.data[name] * salt_content_correction_ratio + return gen.data diff --git a/fme/core/corrector/test_ocean.py b/fme/core/corrector/test_ocean.py index bffb8ba73..15fdb069f 100644 --- a/fme/core/corrector/test_ocean.py +++ b/fme/core/corrector/test_ocean.py @@ -4,11 +4,13 @@ import torch from fme import get_device +from fme.core.constants import REFERENCE_SALINITY_PSU from fme.core.coordinates import DepthCoordinate from fme.core.corrector.ocean import ( OceanCorrector, OceanCorrectorConfig, OceanHeatContentBudgetConfig, + OceanSaltContentBudgetConfig, SeaIceFractionConfig, ) from fme.core.gridded_ops import LatLonOperations @@ -320,3 +322,87 @@ def test_ocean_heat_content_correction(hfds_type): gen_data_corrected.ocean_heat_content, equal_nan=True, ) + + +@pytest.mark.parametrize( + "wfo_type", + [ + pytest.param("input", id="wfo_in_input"), + pytest.param("gen", id="wfo_in_gen"), + ], +) +def test_ocean_salt_content_correction(wfo_type): + unaccounted_salt_flux = 0.1 + config = OceanCorrectorConfig( + ocean_salt_content_correction=OceanSaltContentBudgetConfig( + method="scaled_salinity", + constant_unaccounted_salt_flux=unaccounted_salt_flux, + ) + ) + timestep = datetime.timedelta(seconds=5 * 24 * 3600) + nsamples, nlat, nlon, nlevels = 4, 3, 3, 2 + mask = torch.ones(nsamples, nlat, nlon, nlevels) + mask[:, 0, 0, 0] = 0.0 + mask[:, 0, 0, 1] = 0.0 + mask[:, 0, 1, 1] = 0.0 + masks = { + "mask_0": mask[:, :, :, 0], + "mask_1": mask[:, :, :, 1], + "mask_2d": mask[:, :, :, 0], + } + mask_provider = MaskProvider(masks) + ops = LatLonOperations(torch.ones(size=[3, 3]), mask_provider) + + idepth = torch.tensor([2.5, 10, 20]) + depth_coordinate = DepthCoordinate(idepth, mask) + + sea_surface_fraction = mask[:, :, :, 0] + + wfo_value = torch.ones(nsamples, nlat, nlon) * 0.5 + + input_data_dict = { + "so_0": torch.ones(nsamples, nlat, nlon), + "so_1": torch.ones(nsamples, nlat, nlon), + } + gen_data_dict = { + "so_0": torch.ones(nsamples, nlat, nlon) * 2, + "so_1": torch.ones(nsamples, nlat, nlon) * 2, + } + if wfo_type == "gen": + gen_data_dict["wfo"] = wfo_value + else: + input_data_dict["wfo"] = wfo_value + forcing_data_dict = { + "sea_surface_fraction": sea_surface_fraction, + } + input_data = OceanData(input_data_dict, depth_coordinate) + gen_data = OceanData(gen_data_dict, depth_coordinate) + corrector = OceanCorrector(config, ops, depth_coordinate, timestep) + gen_data_corrected_dict = corrector( + input_data_dict, gen_data_dict, forcing_data_dict + ) + + input_osc = input_data.ocean_salt_content.nanmean(dim=(-1, -2), keepdim=True) + gen_osc = gen_data.ocean_salt_content.nanmean(dim=(-1, -2), keepdim=True) + torch.testing.assert_close(gen_osc, input_osc * 2, equal_nan=True) + + # -reference_salinity * wfo at all non-masked points, plus unaccounted flux; + # masked points are excluded by the area_weighted_mean so the raw wfo value + # (0.5 everywhere) directly determines the mean. + osc_change = ( + -REFERENCE_SALINITY_PSU * 0.5 + unaccounted_salt_flux + ) * timestep.total_seconds() + corrector_ratio = (input_osc + osc_change) / gen_osc + + expected_gen_data_dict = { + key: value * corrector_ratio if key.startswith("so") else value + for key, value in gen_data_dict.items() + } + + expected_gen_data = OceanData(expected_gen_data_dict, depth_coordinate) + gen_data_corrected = OceanData(gen_data_corrected_dict, depth_coordinate) + torch.testing.assert_close( + expected_gen_data.ocean_salt_content, + gen_data_corrected.ocean_salt_content, + equal_nan=True, + ) diff --git a/fme/core/ocean_data.py b/fme/core/ocean_data.py index c0060ea01..fa13e073e 100644 --- a/fme/core/ocean_data.py +++ b/fme/core/ocean_data.py @@ -5,7 +5,11 @@ import torch -from fme.core.constants import DENSITY_OF_WATER_CM4, SPECIFIC_HEAT_OF_WATER_CM4 +from fme.core.constants import ( + DENSITY_OF_WATER_CM4, + REFERENCE_SALINITY_PSU, + SPECIFIC_HEAT_OF_WATER_CM4, +) from fme.core.stacker import Stacker from fme.core.typing_ import TensorDict, TensorMapping @@ -25,6 +29,7 @@ "net_downward_surface_heat_flux": ["hfds"], "net_downward_surface_heat_flux_total_area": ["hfds_total_area"], "geothermal_heat_flux": ["hfgeou"], + "water_flux_into_sea_water": ["wfo"], "sea_surface_fraction": ["sea_surface_fraction"], } ) @@ -162,6 +167,23 @@ def ocean_heat_content(self) -> torch.Tensor: * DENSITY_OF_WATER_CM4 ) + @property + def ocean_salt_content(self) -> torch.Tensor: + """Returns column-integrated ocean salt content.""" + if self._depth_coordinate is None: + raise ValueError( + "Depth coordinate must be provided to compute column-integrated " + "ocean salt content." + ) + return self._depth_coordinate.depth_integral( + self.sea_water_salinity * DENSITY_OF_WATER_CM4 + ) + + @property + def water_flux_into_sea_water(self) -> torch.Tensor: + """Returns water flux into sea water (wfo).""" + return self._get("water_flux_into_sea_water") + @property def sea_surface_fraction(self) -> torch.Tensor: """Returns the sea surface fraction.""" @@ -203,6 +225,19 @@ def net_energy_flux_into_ocean(self) -> torch.Tensor: self.net_downward_surface_heat_flux + self.geothermal_heat_flux ) * self.sea_surface_fraction + @property + def net_virtual_salt_flux_into_ocean(self) -> torch.Tensor: + """Virtual salt flux into the ocean column (g/m2/s). + + Positive wfo (freshwater in) dilutes salt, giving a negative salt flux. + Uses a fixed reference salinity for diagnostic purposes. + """ + return ( + -REFERENCE_SALINITY_PSU + * self.water_flux_into_sea_water + * (self.sea_surface_fraction) + ) + @property def sea_ice_fraction(self) -> torch.Tensor: """Returns the sea ice fraction.""" diff --git a/fme/core/ocean_derived_variables.py b/fme/core/ocean_derived_variables.py index b6d61ff9a..ebe0d40b1 100644 --- a/fme/core/ocean_derived_variables.py +++ b/fme/core/ocean_derived_variables.py @@ -167,6 +167,45 @@ def implied_tendency_of_ocean_heat_content_due_to_advection( return implied_column_heating +@register(VariableMetadata("g/m**2", "Column-integrated ocean salt content")) +def ocean_salt_content( + data: OceanData, + timestep: datetime.timedelta, +) -> torch.Tensor: + return data.ocean_salt_content + + +@register( + VariableMetadata("g/m**2/s", "Tendency of column-integrated ocean salt content") +) +def ocean_salt_content_tendency( + data: OceanData, + timestep: datetime.timedelta, +) -> torch.Tensor: + osc = data.ocean_salt_content + osc_tendency = torch.zeros_like(osc) + osc_tendency[:, 1:] = torch.diff(osc, n=1, dim=1) / timestep.total_seconds() + return osc_tendency + + +@register( + VariableMetadata( + "g/m**2/s", + "Implied advective tendency of ocean salt content assuming closed budget", + ) +) +def implied_tendency_of_ocean_salt_content_due_to_advection( + data: OceanData, + timestep: datetime.timedelta, +) -> torch.Tensor: + """Implied tendency of ocean salt content due to advection. + This is computed as a residual from the column salt budget. + """ + column_salt_tendency = ocean_salt_content_tendency(data, timestep) + flux_through_surface = data.net_virtual_salt_flux_into_ocean + return column_salt_tendency - flux_through_surface + + @register( VariableMetadata( "W/m**2", diff --git a/fme/core/test_ocean_data.py b/fme/core/test_ocean_data.py index 3d22de437..0438731f8 100644 --- a/fme/core/test_ocean_data.py +++ b/fme/core/test_ocean_data.py @@ -51,6 +51,50 @@ def test_column_integrated_ocean_heat_content(has_depth_coordinate: bool): _ = ocean_data.ocean_heat_content +@pytest.mark.parametrize("has_depth_coordinate", [True, False]) +def test_column_integrated_ocean_salt_content(has_depth_coordinate: bool): + """Test column-integrated ocean salt content.""" + n_samples, n_time_steps, nlat, nlon, nlevels = 2, 2, 2, 2, 2 + shape_2d = (n_samples, n_time_steps, nlat, nlon) + + data = { + "so_0": torch.ones(n_samples, n_time_steps, nlat, nlon), + "so_1": torch.ones(n_samples, n_time_steps, nlat, nlon), + } + + if has_depth_coordinate: + idepth = torch.tensor([2.5, 10, 20]) + lev_thickness = idepth.diff(dim=-1) + mask = torch.ones(n_samples, n_time_steps, nlat, nlon, nlevels) + mask[:, :, 0, 0, 0] = 0.0 + mask[:, :, 0, 0, 1] = 0.0 + mask[:, :, 0, 1, 1] = 0.0 + + expected_osc = torch.tensor( + DENSITY_OF_WATER_CM4 + * n_samples + * n_time_steps + * ( + nlat * nlon * lev_thickness.sum() + - lev_thickness[0] + - 2 * lev_thickness[1] + ) + ) + depth_coordinate = DepthCoordinate(idepth, mask) + ocean_data = OceanData(data, depth_coordinate) + assert ocean_data.ocean_salt_content.shape == shape_2d + assert torch.allclose( + ocean_data.ocean_salt_content.nansum(), + expected_osc, + atol=1e-10, + equal_nan=True, + ) + else: + ocean_data = OceanData(data) + with pytest.raises(ValueError, match="Depth coordinate must be provided"): + _ = ocean_data.ocean_salt_content + + def test_get_3d_fields(): """Test getting 3D fields (fields with vertical levels).""" n_samples, n_time_steps, nlat, nlon, nlevels = 2, 3, 4, 8, 2 From b7b47301a0260252362600a7876206674ec54e73 Mon Sep 17 00:00:00 2001 From: James Duncan Date: Tue, 17 Mar 2026 13:23:41 -0700 Subject: [PATCH 2/3] Replace "scaled_salinity" with "constant_salinity" --- fme/core/coordinates.py | 8 ++++++ fme/core/corrector/ocean.py | 43 +++++++++++++++++++------------- fme/core/corrector/test_ocean.py | 27 ++++++++++++-------- fme/core/ocean_data.py | 9 +++++++ 4 files changed, 60 insertions(+), 27 deletions(-) diff --git a/fme/core/coordinates.py b/fme/core/coordinates.py index 68c2fbb29..208e0c0d6 100644 --- a/fme/core/coordinates.py +++ b/fme/core/coordinates.py @@ -383,11 +383,19 @@ def __post_init__(self): f"{self.mask.shape}." ) self._dz = dz_from_idepth(self.idepth, self.mask, self.deptho) + if self.deptho is not None: + self._sea_floor_depth = self.deptho + else: + self._sea_floor_depth = self._dz.sum(dim=-1) @property def dz(self) -> torch.Tensor: return self._dz + @property + def sea_floor_depth(self) -> torch.Tensor: + return self._sea_floor_depth + def __len__(self): """The number of vertical layer interfaces.""" return len(self.idepth) diff --git a/fme/core/corrector/ocean.py b/fme/core/corrector/ocean.py index f509185d3..39feb6e95 100644 --- a/fme/core/corrector/ocean.py +++ b/fme/core/corrector/ocean.py @@ -8,6 +8,7 @@ from fme.core.atmosphere_data import AtmosphereData from fme.core.constants import ( + DENSITY_OF_SEA_WATER_CM4, FREEZING_TEMPERATURE_KELVIN, LATENT_HEAT_OF_VAPORIZATION, REFERENCE_SALINITY_PSU, @@ -115,16 +116,17 @@ class OceanSaltContentBudgetConfig: Parameters: method: Method to use for salt content budget correction. The available - option is "scaled_salinity", which enforces conservation of salt - content by scaling the predicted salinity by a vertically and - horizontally uniform correction factor. + option is "constant_salinity", which enforces conservation of salt + content by adding a uniform concentration correction (PSU) to each + ocean layer, computed by dividing the global salt deficit by the + local column mass (reference density times sea floor depth). constant_unaccounted_salt_flux: Area-weighted global mean - column-integrated salt flux in g/m**2/s to be added to the virtual - salt flux when conserving the salt content. This can be useful for - correcting errors in salt budget in target data. + column-integrated salt flux in g/m**2/s to be added to the surface + boundary flux when conserving the salt content. This can be useful + for correcting errors in salt budget in target data. """ - method: Literal["scaled_salinity"] + method: Literal["constant_salinity"] constant_unaccounted_salt_flux: float = 0.0 @@ -388,10 +390,10 @@ def _force_conserve_ocean_salt_content( area_weighted_mean: AreaWeightedMean, vertical_coordinate: HasOceanDepthIntegral, timestep_seconds: float, - method: Literal["scaled_salinity"] = "scaled_salinity", + method: Literal["constant_salinity"] = "constant_salinity", unaccounted_salt_flux: float = 0.0, ) -> TensorDict: - if method != "scaled_salinity": + if method != "constant_salinity": raise NotImplementedError( f"Method {method!r} not implemented for ocean salt content conservation" ) @@ -417,20 +419,27 @@ def _force_conserve_ocean_salt_content( wfo = gen.water_flux_into_sea_water except KeyError: wfo = input.water_flux_into_sea_water + try: + sfdsi = gen.downward_sea_ice_basal_salt_flux + except KeyError: + sfdsi = input.downward_sea_ice_basal_salt_flux virtual_salt_flux = -REFERENCE_SALINITY_PSU * wfo * forcing.sea_surface_fraction + total_surface_flux = virtual_salt_flux + 1000.0 * sfdsi salt_flux_global_mean = area_weighted_mean( - virtual_salt_flux, + total_surface_flux, keepdim=True, name="ocean_salt_content", ) - expected_change_salt_content = ( - salt_flux_global_mean + unaccounted_salt_flux - ) * timestep_seconds - salt_content_correction_ratio = ( - global_input_salt_content + expected_change_salt_content - ) / global_gen_salt_content + salt_deficit = ( + global_input_salt_content + + (salt_flux_global_mean + unaccounted_salt_flux) * timestep_seconds + - global_gen_salt_content + ) + correction_psu = salt_deficit / ( + DENSITY_OF_SEA_WATER_CM4 * vertical_coordinate.sea_floor_depth + ) n_levels = gen.sea_water_salinity.shape[-1] for k in range(n_levels): name = f"so_{k}" - gen.data[name] = gen.data[name] * salt_content_correction_ratio + gen.data[name] = gen.data[name] + correction_psu return gen.data diff --git a/fme/core/corrector/test_ocean.py b/fme/core/corrector/test_ocean.py index da396d4a6..828f5e754 100644 --- a/fme/core/corrector/test_ocean.py +++ b/fme/core/corrector/test_ocean.py @@ -4,7 +4,7 @@ import torch from fme import get_device -from fme.core.constants import REFERENCE_SALINITY_PSU +from fme.core.constants import DENSITY_OF_SEA_WATER_CM4, REFERENCE_SALINITY_PSU from fme.core.coordinates import DepthCoordinate from fme.core.corrector.ocean import ( OceanCorrector, @@ -30,6 +30,9 @@ class _MockDepth: + def sea_floor_depth(self) -> torch.Tensor: + return torch.full_like(_MASK, 15) + def depth_integral(self, integrand: torch.Tensor) -> torch.Tensor: idepth = torch.tensor([0, 5, 15], device=DEVICE) thickness = idepth.diff(dim=-1) @@ -414,7 +417,7 @@ def test_ocean_salt_content_correction(wfo_type): unaccounted_salt_flux = 0.1 config = OceanCorrectorConfig( ocean_salt_content_correction=OceanSaltContentBudgetConfig( - method="scaled_salinity", + method="constant_salinity", constant_unaccounted_salt_flux=unaccounted_salt_flux, ) ) @@ -438,6 +441,7 @@ def test_ocean_salt_content_correction(wfo_type): sea_surface_fraction = mask[:, :, :, 0] wfo_value = torch.ones(nsamples, nlat, nlon) * 0.5 + sfdsi_value = torch.ones(nsamples, nlat, nlon) * 0.001 input_data_dict = { "so_0": torch.ones(nsamples, nlat, nlon), @@ -446,6 +450,7 @@ def test_ocean_salt_content_correction(wfo_type): gen_data_dict = { "so_0": torch.ones(nsamples, nlat, nlon) * 2, "so_1": torch.ones(nsamples, nlat, nlon) * 2, + "sfdsi": sfdsi_value, } if wfo_type == "gen": gen_data_dict["wfo"] = wfo_value @@ -465,16 +470,18 @@ def test_ocean_salt_content_correction(wfo_type): gen_osc = gen_data.ocean_salt_content.nanmean(dim=(-1, -2), keepdim=True) torch.testing.assert_close(gen_osc, input_osc * 2, equal_nan=True) - # -reference_salinity * wfo at all non-masked points, plus unaccounted flux; - # masked points are excluded by the area_weighted_mean so the raw wfo value - # (0.5 everywhere) directly determines the mean. - osc_change = ( - -REFERENCE_SALINITY_PSU * 0.5 + unaccounted_salt_flux - ) * timestep.total_seconds() - corrector_ratio = (input_osc + osc_change) / gen_osc + # Total surface flux combines virtual salt flux from freshwater and sfdsi. + # All non-masked ocean points have the same flux, so the area-weighted + # mean equals the pointwise value. + total_flux_mean = -REFERENCE_SALINITY_PSU * 0.5 + 1000.0 * 0.001 + osc_change = (total_flux_mean + unaccounted_salt_flux) * timestep.total_seconds() + salt_deficit = input_osc + osc_change - gen_osc + correction_psu = salt_deficit / ( + DENSITY_OF_SEA_WATER_CM4 * depth_coordinate.sea_floor_depth + ) expected_gen_data_dict = { - key: value * corrector_ratio if key.startswith("so") else value + key: value + correction_psu if key.startswith("so") else value for key, value in gen_data_dict.items() } diff --git a/fme/core/ocean_data.py b/fme/core/ocean_data.py index 2a271492c..602c9c88c 100644 --- a/fme/core/ocean_data.py +++ b/fme/core/ocean_data.py @@ -30,12 +30,16 @@ "net_downward_surface_heat_flux_total_area": ["hfds_total_area"], "geothermal_heat_flux": ["hfgeou"], "water_flux_into_sea_water": ["wfo"], + "downward_sea_ice_basal_salt_flux": ["sfdsi"], "sea_surface_fraction": ["sea_surface_fraction"], } ) class HasOceanDepthIntegral(Protocol): + @property + def sea_floor_depth(self) -> torch.Tensor: ... + def depth_integral( self, integrand: torch.Tensor, @@ -170,6 +174,11 @@ def water_flux_into_sea_water(self) -> torch.Tensor: """Returns water flux into sea water (wfo).""" return self._get("water_flux_into_sea_water") + @property + def downward_sea_ice_basal_salt_flux(self) -> torch.Tensor: + """Returns the downward sea ice basal salt flux in kg/m2/s.""" + return self._get("downward_sea_ice_basal_salt_flux") + @property def sea_surface_fraction(self) -> torch.Tensor: """Returns the sea surface fraction.""" From 78beea896d6b93f953b8f413a48b0e0aa84691ce Mon Sep 17 00:00:00 2001 From: James Duncan Date: Tue, 17 Mar 2026 16:21:10 -0700 Subject: [PATCH 3/3] Normalize by global mean depth --- fme/core/constants.py | 2 +- fme/core/corrector/ocean.py | 29 +++++++++++++++++++----- fme/core/corrector/test_ocean.py | 38 ++++++++++++++++---------------- fme/core/ocean_data.py | 4 ++-- 4 files changed, 45 insertions(+), 28 deletions(-) diff --git a/fme/core/constants.py b/fme/core/constants.py index 4efd095ac..a0f5e9b2e 100644 --- a/fme/core/constants.py +++ b/fme/core/constants.py @@ -16,6 +16,6 @@ DENSITY_OF_SEA_WATER_CM4 = 1035.0 # kg/m^3 FREEZING_TEMPERATURE_KELVIN = 273.15 # K -REFERENCE_SALINITY_PSU = 35 # g/kg +REFERENCE_SALINITY_PSU = 35.0 # g/kg EARTH_RADIUS = 6371000.0 # m diff --git a/fme/core/corrector/ocean.py b/fme/core/corrector/ocean.py index 39feb6e95..dd191839d 100644 --- a/fme/core/corrector/ocean.py +++ b/fme/core/corrector/ocean.py @@ -183,6 +183,20 @@ def __init__( self._gridded_operations = gridded_operations self._vertical_coordinate = vertical_coordinate self._timestep = timestep + self._global_mean_depth: torch.Tensor | None = None + + @property + def global_mean_depth(self) -> torch.Tensor: + if self._vertical_coordinate is None: + raise RuntimeError( + "Global mean ocean depth cannot be computed without the ocean depth " + "coordinate." + ) + if self._global_mean_depth is None: + self._global_mean_depth = self._gridded_operations.area_weighted_mean( + self._vertical_coordinate.sea_floor_depth, name="global_mean_depth" + ) + return self._global_mean_depth def __call__( self, @@ -207,6 +221,7 @@ def __call__( self._gridded_operations.area_weighted_mean, self._vertical_coordinate, self._timestep.total_seconds(), + self.global_mean_depth, self._config.ocean_salt_content_correction.method, self._config.ocean_salt_content_correction.constant_unaccounted_salt_flux, ) @@ -390,6 +405,7 @@ def _force_conserve_ocean_salt_content( area_weighted_mean: AreaWeightedMean, vertical_coordinate: HasOceanDepthIntegral, timestep_seconds: float, + global_mean_depth: torch.Tensor, method: Literal["constant_salinity"] = "constant_salinity", unaccounted_salt_flux: float = 0.0, ) -> TensorDict: @@ -424,7 +440,8 @@ def _force_conserve_ocean_salt_content( except KeyError: sfdsi = input.downward_sea_ice_basal_salt_flux virtual_salt_flux = -REFERENCE_SALINITY_PSU * wfo * forcing.sea_surface_fraction - total_surface_flux = virtual_salt_flux + 1000.0 * sfdsi + salt_flux = 1000.0 * sfdsi # kg/m2/s -> g/m2/s + total_surface_flux = virtual_salt_flux + salt_flux # g/m2/s salt_flux_global_mean = area_weighted_mean( total_surface_flux, keepdim=True, @@ -434,12 +451,12 @@ def _force_conserve_ocean_salt_content( global_input_salt_content + (salt_flux_global_mean + unaccounted_salt_flux) * timestep_seconds - global_gen_salt_content - ) - correction_psu = salt_deficit / ( - DENSITY_OF_SEA_WATER_CM4 * vertical_coordinate.sea_floor_depth - ) + ) # g/m2 + salinity_correction = salt_deficit / ( + DENSITY_OF_SEA_WATER_CM4 * global_mean_depth + ) # g/kg n_levels = gen.sea_water_salinity.shape[-1] for k in range(n_levels): name = f"so_{k}" - gen.data[name] = gen.data[name] + correction_psu + gen.data[name] = gen.data[name] + salinity_correction return gen.data diff --git a/fme/core/corrector/test_ocean.py b/fme/core/corrector/test_ocean.py index 828f5e754..db997c6d2 100644 --- a/fme/core/corrector/test_ocean.py +++ b/fme/core/corrector/test_ocean.py @@ -321,14 +321,14 @@ def test_ocean_heat_content_correction(hfds_type): ) timestep = datetime.timedelta(seconds=5 * 24 * 3600) nsamples, nlat, nlon, nlevels = 4, 3, 3, 2 - mask = torch.ones(nsamples, nlat, nlon, nlevels) - mask[:, 0, 0, 0] = 0.0 - mask[:, 0, 0, 1] = 0.0 - mask[:, 0, 1, 1] = 0.0 + mask = torch.ones(nlat, nlon, nlevels) + mask[0, 0, 0] = 0.0 + mask[0, 0, 1] = 0.0 + mask[0, 1, 1] = 0.0 masks = { - "mask_0": mask[:, :, :, 0], - "mask_1": mask[:, :, :, 1], - "mask_2d": mask[:, :, :, 0], + "mask_0": mask[:, :, 0], + "mask_1": mask[:, :, 1], + "mask_2d": mask[:, :, 0], } mask_provider = MaskProvider(masks) ops = LatLonOperations(torch.ones(size=[3, 3]), mask_provider) @@ -336,7 +336,7 @@ def test_ocean_heat_content_correction(hfds_type): idepth = torch.tensor([2.5, 10, 20]) depth_coordinate = DepthCoordinate(idepth, mask) - sea_surface_fraction = mask[:, :, :, 0] + sea_surface_fraction = mask[:, :, 0] input_data_dict = { "thetao_0": torch.ones(nsamples, nlat, nlon), @@ -423,22 +423,23 @@ def test_ocean_salt_content_correction(wfo_type): ) timestep = datetime.timedelta(seconds=5 * 24 * 3600) nsamples, nlat, nlon, nlevels = 4, 3, 3, 2 - mask = torch.ones(nsamples, nlat, nlon, nlevels) - mask[:, 0, 0, 0] = 0.0 - mask[:, 0, 0, 1] = 0.0 - mask[:, 0, 1, 1] = 0.0 + mask = torch.ones(nlat, nlon, nlevels) + mask[0, 0, 0] = 0.0 + mask[0, 0, 1] = 0.0 + mask[0, 1, 1] = 0.0 masks = { - "mask_0": mask[:, :, :, 0], - "mask_1": mask[:, :, :, 1], - "mask_2d": mask[:, :, :, 0], + "mask_0": mask[:, :, 0], + "mask_1": mask[:, :, 1], + "mask_2d": mask[:, :, 0], } mask_provider = MaskProvider(masks) ops = LatLonOperations(torch.ones(size=[3, 3]), mask_provider) idepth = torch.tensor([2.5, 10, 20]) depth_coordinate = DepthCoordinate(idepth, mask) + global_mean_depth = 16.25 # (7 * 17.5 m * 1 * 7.5 m) / 8 - sea_surface_fraction = mask[:, :, :, 0] + sea_surface_fraction = mask[:, :, 0] wfo_value = torch.ones(nsamples, nlat, nlon) * 0.5 sfdsi_value = torch.ones(nsamples, nlat, nlon) * 0.001 @@ -476,9 +477,8 @@ def test_ocean_salt_content_correction(wfo_type): total_flux_mean = -REFERENCE_SALINITY_PSU * 0.5 + 1000.0 * 0.001 osc_change = (total_flux_mean + unaccounted_salt_flux) * timestep.total_seconds() salt_deficit = input_osc + osc_change - gen_osc - correction_psu = salt_deficit / ( - DENSITY_OF_SEA_WATER_CM4 * depth_coordinate.sea_floor_depth - ) + + correction_psu = salt_deficit / (DENSITY_OF_SEA_WATER_CM4 * global_mean_depth) expected_gen_data_dict = { key: value + correction_psu if key.startswith("so") else value diff --git a/fme/core/ocean_data.py b/fme/core/ocean_data.py index 602c9c88c..02c944907 100644 --- a/fme/core/ocean_data.py +++ b/fme/core/ocean_data.py @@ -159,7 +159,7 @@ def ocean_heat_content(self) -> torch.Tensor: @property def ocean_salt_content(self) -> torch.Tensor: - """Returns column-integrated ocean salt content.""" + """Returns column-integrated ocean salt content in g/m2.""" if self._depth_coordinate is None: raise ValueError( "Depth coordinate must be provided to compute column-integrated " @@ -171,7 +171,7 @@ def ocean_salt_content(self) -> torch.Tensor: @property def water_flux_into_sea_water(self) -> torch.Tensor: - """Returns water flux into sea water (wfo).""" + """Returns water flux into sea water in kg/m2/s.""" return self._get("water_flux_into_sea_water") @property