diff --git a/fme/core/constants.py b/fme/core/constants.py index 944545908..a0f5e9b2e 100644 --- a/fme/core/constants.py +++ b/fme/core/constants.py @@ -16,5 +16,6 @@ DENSITY_OF_SEA_WATER_CM4 = 1035.0 # kg/m^3 FREEZING_TEMPERATURE_KELVIN = 273.15 # K +REFERENCE_SALINITY_PSU = 35.0 # g/kg EARTH_RADIUS = 6371000.0 # m diff --git a/fme/core/coordinates.py b/fme/core/coordinates.py index 68c2fbb29..208e0c0d6 100644 --- a/fme/core/coordinates.py +++ b/fme/core/coordinates.py @@ -383,11 +383,19 @@ def __post_init__(self): f"{self.mask.shape}." ) self._dz = dz_from_idepth(self.idepth, self.mask, self.deptho) + if self.deptho is not None: + self._sea_floor_depth = self.deptho + else: + self._sea_floor_depth = self._dz.sum(dim=-1) @property def dz(self) -> torch.Tensor: return self._dz + @property + def sea_floor_depth(self) -> torch.Tensor: + return self._sea_floor_depth + def __len__(self): """The number of vertical layer interfaces.""" return len(self.idepth) diff --git a/fme/core/corrector/ocean.py b/fme/core/corrector/ocean.py index eecee02f6..dd191839d 100644 --- a/fme/core/corrector/ocean.py +++ b/fme/core/corrector/ocean.py @@ -8,8 +8,10 @@ from fme.core.atmosphere_data import AtmosphereData from fme.core.constants import ( + DENSITY_OF_SEA_WATER_CM4, FREEZING_TEMPERATURE_KELVIN, LATENT_HEAT_OF_VAPORIZATION, + REFERENCE_SALINITY_PSU, SPECIFIC_HEAT_OF_SEA_WATER_CM4, ) from fme.core.corrector.registry import CorrectorABC @@ -108,6 +110,26 @@ class SurfaceEnergyFluxCorrectionConfig: method: Literal["residual_prediction", "prescribed"] +@dataclasses.dataclass +class OceanSaltContentBudgetConfig: + """Configuration for ocean salt content budget correction. + + Parameters: + method: Method to use for salt content budget correction. The available + option is "constant_salinity", which enforces conservation of salt + content by adding a uniform concentration correction (PSU) to each + ocean layer, computed by dividing the global salt deficit by the + local column mass (reference density times sea floor depth). + constant_unaccounted_salt_flux: Area-weighted global mean + column-integrated salt flux in g/m**2/s to be added to the surface + boundary flux when conserving the salt content. This can be useful + for correcting errors in salt budget in target data. + """ + + method: Literal["constant_salinity"] + constant_unaccounted_salt_flux: float = 0.0 + + @CorrectorSelector.register("ocean_corrector") @dataclasses.dataclass class OceanCorrectorConfig: @@ -115,6 +137,7 @@ class OceanCorrectorConfig: sea_ice_fraction_correction: SeaIceFractionConfig | None = None surface_energy_flux_correction: SurfaceEnergyFluxCorrectionConfig | None = None ocean_heat_content_correction: OceanHeatContentBudgetConfig | None = None + ocean_salt_content_correction: OceanSaltContentBudgetConfig | None = None @classmethod def from_state(cls, state: Mapping[str, Any]) -> "OceanCorrectorConfig": @@ -160,6 +183,20 @@ def __init__( self._gridded_operations = gridded_operations self._vertical_coordinate = vertical_coordinate self._timestep = timestep + self._global_mean_depth: torch.Tensor | None = None + + @property + def global_mean_depth(self) -> torch.Tensor: + if self._vertical_coordinate is None: + raise RuntimeError( + "Global mean ocean depth cannot be computed without the ocean depth " + "coordinate." + ) + if self._global_mean_depth is None: + self._global_mean_depth = self._gridded_operations.area_weighted_mean( + self._vertical_coordinate.sea_floor_depth, name="global_mean_depth" + ) + return self._global_mean_depth def __call__( self, @@ -171,6 +208,23 @@ def __call__( gen_data = force_positive(gen_data, self._config.force_positive_names) if self._config.sea_ice_fraction_correction is not None: gen_data = self._config.sea_ice_fraction_correction(gen_data, input_data) + if self._config.ocean_salt_content_correction is not None: + if self._vertical_coordinate is None: + raise ValueError( + "Ocean salt content correction is turned on, but no vertical " + "coordinate is available." + ) + gen_data = _force_conserve_ocean_salt_content( + input_data, + gen_data, + forcing_data, + self._gridded_operations.area_weighted_mean, + self._vertical_coordinate, + self._timestep.total_seconds(), + self.global_mean_depth, + self._config.ocean_salt_content_correction.method, + self._config.ocean_salt_content_correction.constant_unaccounted_salt_flux, + ) if self._config.surface_energy_flux_correction is not None: gen_data = _correct_hfds( input_data, @@ -342,3 +396,67 @@ def _force_conserve_ocean_heat_content( gen.data["sst"] - FREEZING_TEMPERATURE_KELVIN ) * heat_content_correction_ratio + FREEZING_TEMPERATURE_KELVIN return gen.data + + +def _force_conserve_ocean_salt_content( + input_data: TensorMapping, + gen_data: TensorMapping, + forcing_data: TensorMapping, + area_weighted_mean: AreaWeightedMean, + vertical_coordinate: HasOceanDepthIntegral, + timestep_seconds: float, + global_mean_depth: torch.Tensor, + method: Literal["constant_salinity"] = "constant_salinity", + unaccounted_salt_flux: float = 0.0, +) -> TensorDict: + if method != "constant_salinity": + raise NotImplementedError( + f"Method {method!r} not implemented for ocean salt content conservation" + ) + if "wfo" in gen_data and "wfo" in forcing_data: + raise ValueError( + "Water flux into sea water cannot be present in both gen_data and " + "forcing_data." + ) + input = OceanData(input_data, vertical_coordinate) + gen = OceanData(gen_data, vertical_coordinate) + forcing = OceanData(forcing_data) + global_gen_salt_content = area_weighted_mean( + gen.ocean_salt_content, + keepdim=True, + name="ocean_salt_content", + ) + global_input_salt_content = area_weighted_mean( + input.ocean_salt_content, + keepdim=True, + name="ocean_salt_content", + ) + try: + wfo = gen.water_flux_into_sea_water + except KeyError: + wfo = input.water_flux_into_sea_water + try: + sfdsi = gen.downward_sea_ice_basal_salt_flux + except KeyError: + sfdsi = input.downward_sea_ice_basal_salt_flux + virtual_salt_flux = -REFERENCE_SALINITY_PSU * wfo * forcing.sea_surface_fraction + salt_flux = 1000.0 * sfdsi # kg/m2/s -> g/m2/s + total_surface_flux = virtual_salt_flux + salt_flux # g/m2/s + salt_flux_global_mean = area_weighted_mean( + total_surface_flux, + keepdim=True, + name="ocean_salt_content", + ) + salt_deficit = ( + global_input_salt_content + + (salt_flux_global_mean + unaccounted_salt_flux) * timestep_seconds + - global_gen_salt_content + ) # g/m2 + salinity_correction = salt_deficit / ( + DENSITY_OF_SEA_WATER_CM4 * global_mean_depth + ) # g/kg + n_levels = gen.sea_water_salinity.shape[-1] + for k in range(n_levels): + name = f"so_{k}" + gen.data[name] = gen.data[name] + salinity_correction + return gen.data diff --git a/fme/core/corrector/test_ocean.py b/fme/core/corrector/test_ocean.py index 0794a72da..db997c6d2 100644 --- a/fme/core/corrector/test_ocean.py +++ b/fme/core/corrector/test_ocean.py @@ -4,11 +4,13 @@ import torch from fme import get_device +from fme.core.constants import DENSITY_OF_SEA_WATER_CM4, REFERENCE_SALINITY_PSU from fme.core.coordinates import DepthCoordinate from fme.core.corrector.ocean import ( OceanCorrector, OceanCorrectorConfig, OceanHeatContentBudgetConfig, + OceanSaltContentBudgetConfig, SeaIceFractionConfig, SurfaceEnergyFluxCorrectionConfig, _compute_ocean_net_surface_energy_flux, @@ -28,6 +30,9 @@ class _MockDepth: + def sea_floor_depth(self) -> torch.Tensor: + return torch.full_like(_MASK, 15) + def depth_integral(self, integrand: torch.Tensor) -> torch.Tensor: idepth = torch.tensor([0, 5, 15], device=DEVICE) thickness = idepth.diff(dim=-1) @@ -316,14 +321,14 @@ def test_ocean_heat_content_correction(hfds_type): ) timestep = datetime.timedelta(seconds=5 * 24 * 3600) nsamples, nlat, nlon, nlevels = 4, 3, 3, 2 - mask = torch.ones(nsamples, nlat, nlon, nlevels) - mask[:, 0, 0, 0] = 0.0 - mask[:, 0, 0, 1] = 0.0 - mask[:, 0, 1, 1] = 0.0 + mask = torch.ones(nlat, nlon, nlevels) + mask[0, 0, 0] = 0.0 + mask[0, 0, 1] = 0.0 + mask[0, 1, 1] = 0.0 masks = { - "mask_0": mask[:, :, :, 0], - "mask_1": mask[:, :, :, 1], - "mask_2d": mask[:, :, :, 0], + "mask_0": mask[:, :, 0], + "mask_1": mask[:, :, 1], + "mask_2d": mask[:, :, 0], } mask_provider = MaskProvider(masks) ops = LatLonOperations(torch.ones(size=[3, 3]), mask_provider) @@ -331,7 +336,7 @@ def test_ocean_heat_content_correction(hfds_type): idepth = torch.tensor([2.5, 10, 20]) depth_coordinate = DepthCoordinate(idepth, mask) - sea_surface_fraction = mask[:, :, :, 0] + sea_surface_fraction = mask[:, :, 0] input_data_dict = { "thetao_0": torch.ones(nsamples, nlat, nlon), @@ -399,3 +404,91 @@ def test_ocean_heat_content_correction(hfds_type): gen_data_corrected.ocean_heat_content, equal_nan=True, ) + + +@pytest.mark.parametrize( + "wfo_type", + [ + pytest.param("input", id="wfo_in_input"), + pytest.param("gen", id="wfo_in_gen"), + ], +) +def test_ocean_salt_content_correction(wfo_type): + unaccounted_salt_flux = 0.1 + config = OceanCorrectorConfig( + ocean_salt_content_correction=OceanSaltContentBudgetConfig( + method="constant_salinity", + constant_unaccounted_salt_flux=unaccounted_salt_flux, + ) + ) + timestep = datetime.timedelta(seconds=5 * 24 * 3600) + nsamples, nlat, nlon, nlevels = 4, 3, 3, 2 + mask = torch.ones(nlat, nlon, nlevels) + mask[0, 0, 0] = 0.0 + mask[0, 0, 1] = 0.0 + mask[0, 1, 1] = 0.0 + masks = { + "mask_0": mask[:, :, 0], + "mask_1": mask[:, :, 1], + "mask_2d": mask[:, :, 0], + } + mask_provider = MaskProvider(masks) + ops = LatLonOperations(torch.ones(size=[3, 3]), mask_provider) + + idepth = torch.tensor([2.5, 10, 20]) + depth_coordinate = DepthCoordinate(idepth, mask) + global_mean_depth = 16.25 # (7 * 17.5 m * 1 * 7.5 m) / 8 + + sea_surface_fraction = mask[:, :, 0] + + wfo_value = torch.ones(nsamples, nlat, nlon) * 0.5 + sfdsi_value = torch.ones(nsamples, nlat, nlon) * 0.001 + + input_data_dict = { + "so_0": torch.ones(nsamples, nlat, nlon), + "so_1": torch.ones(nsamples, nlat, nlon), + } + gen_data_dict = { + "so_0": torch.ones(nsamples, nlat, nlon) * 2, + "so_1": torch.ones(nsamples, nlat, nlon) * 2, + "sfdsi": sfdsi_value, + } + if wfo_type == "gen": + gen_data_dict["wfo"] = wfo_value + else: + input_data_dict["wfo"] = wfo_value + forcing_data_dict = { + "sea_surface_fraction": sea_surface_fraction, + } + input_data = OceanData(input_data_dict, depth_coordinate) + gen_data = OceanData(gen_data_dict, depth_coordinate) + corrector = OceanCorrector(config, ops, depth_coordinate, timestep) + gen_data_corrected_dict = corrector( + input_data_dict, gen_data_dict, forcing_data_dict + ) + + input_osc = input_data.ocean_salt_content.nanmean(dim=(-1, -2), keepdim=True) + gen_osc = gen_data.ocean_salt_content.nanmean(dim=(-1, -2), keepdim=True) + torch.testing.assert_close(gen_osc, input_osc * 2, equal_nan=True) + + # Total surface flux combines virtual salt flux from freshwater and sfdsi. + # All non-masked ocean points have the same flux, so the area-weighted + # mean equals the pointwise value. + total_flux_mean = -REFERENCE_SALINITY_PSU * 0.5 + 1000.0 * 0.001 + osc_change = (total_flux_mean + unaccounted_salt_flux) * timestep.total_seconds() + salt_deficit = input_osc + osc_change - gen_osc + + correction_psu = salt_deficit / (DENSITY_OF_SEA_WATER_CM4 * global_mean_depth) + + expected_gen_data_dict = { + key: value + correction_psu if key.startswith("so") else value + for key, value in gen_data_dict.items() + } + + expected_gen_data = OceanData(expected_gen_data_dict, depth_coordinate) + gen_data_corrected = OceanData(gen_data_corrected_dict, depth_coordinate) + torch.testing.assert_close( + expected_gen_data.ocean_salt_content, + gen_data_corrected.ocean_salt_content, + equal_nan=True, + ) diff --git a/fme/core/ocean_data.py b/fme/core/ocean_data.py index a299de5fe..02c944907 100644 --- a/fme/core/ocean_data.py +++ b/fme/core/ocean_data.py @@ -5,7 +5,11 @@ import torch -from fme.core.constants import DENSITY_OF_SEA_WATER_CM4, SPECIFIC_HEAT_OF_SEA_WATER_CM4 +from fme.core.constants import ( + DENSITY_OF_SEA_WATER_CM4, + REFERENCE_SALINITY_PSU, + SPECIFIC_HEAT_OF_SEA_WATER_CM4, +) from fme.core.stacker import Stacker from fme.core.typing_ import TensorDict, TensorMapping @@ -25,12 +29,17 @@ "net_downward_surface_heat_flux": ["hfds"], "net_downward_surface_heat_flux_total_area": ["hfds_total_area"], "geothermal_heat_flux": ["hfgeou"], + "water_flux_into_sea_water": ["wfo"], + "downward_sea_ice_basal_salt_flux": ["sfdsi"], "sea_surface_fraction": ["sea_surface_fraction"], } ) class HasOceanDepthIntegral(Protocol): + @property + def sea_floor_depth(self) -> torch.Tensor: ... + def depth_integral( self, integrand: torch.Tensor, @@ -148,6 +157,28 @@ def ocean_heat_content(self) -> torch.Tensor: * DENSITY_OF_SEA_WATER_CM4 ) + @property + def ocean_salt_content(self) -> torch.Tensor: + """Returns column-integrated ocean salt content in g/m2.""" + if self._depth_coordinate is None: + raise ValueError( + "Depth coordinate must be provided to compute column-integrated " + "ocean salt content." + ) + return self._depth_coordinate.depth_integral( + self.sea_water_salinity * DENSITY_OF_SEA_WATER_CM4 + ) + + @property + def water_flux_into_sea_water(self) -> torch.Tensor: + """Returns water flux into sea water in kg/m2/s.""" + return self._get("water_flux_into_sea_water") + + @property + def downward_sea_ice_basal_salt_flux(self) -> torch.Tensor: + """Returns the downward sea ice basal salt flux in kg/m2/s.""" + return self._get("downward_sea_ice_basal_salt_flux") + @property def sea_surface_fraction(self) -> torch.Tensor: """Returns the sea surface fraction.""" @@ -189,6 +220,19 @@ def net_energy_flux_into_ocean(self) -> torch.Tensor: self.net_downward_surface_heat_flux + self.geothermal_heat_flux ) * self.sea_surface_fraction + @property + def net_virtual_salt_flux_into_ocean(self) -> torch.Tensor: + """Virtual salt flux into the ocean column (g/m2/s). + + Positive wfo (freshwater in) dilutes salt, giving a negative salt flux. + Uses a fixed reference salinity for diagnostic purposes. + """ + return ( + -REFERENCE_SALINITY_PSU + * self.water_flux_into_sea_water + * (self.sea_surface_fraction) + ) + @property def sea_ice_fraction(self) -> torch.Tensor: """Returns the sea ice fraction.""" diff --git a/fme/core/ocean_derived_variables.py b/fme/core/ocean_derived_variables.py index b6d61ff9a..ebe0d40b1 100644 --- a/fme/core/ocean_derived_variables.py +++ b/fme/core/ocean_derived_variables.py @@ -167,6 +167,45 @@ def implied_tendency_of_ocean_heat_content_due_to_advection( return implied_column_heating +@register(VariableMetadata("g/m**2", "Column-integrated ocean salt content")) +def ocean_salt_content( + data: OceanData, + timestep: datetime.timedelta, +) -> torch.Tensor: + return data.ocean_salt_content + + +@register( + VariableMetadata("g/m**2/s", "Tendency of column-integrated ocean salt content") +) +def ocean_salt_content_tendency( + data: OceanData, + timestep: datetime.timedelta, +) -> torch.Tensor: + osc = data.ocean_salt_content + osc_tendency = torch.zeros_like(osc) + osc_tendency[:, 1:] = torch.diff(osc, n=1, dim=1) / timestep.total_seconds() + return osc_tendency + + +@register( + VariableMetadata( + "g/m**2/s", + "Implied advective tendency of ocean salt content assuming closed budget", + ) +) +def implied_tendency_of_ocean_salt_content_due_to_advection( + data: OceanData, + timestep: datetime.timedelta, +) -> torch.Tensor: + """Implied tendency of ocean salt content due to advection. + This is computed as a residual from the column salt budget. + """ + column_salt_tendency = ocean_salt_content_tendency(data, timestep) + flux_through_surface = data.net_virtual_salt_flux_into_ocean + return column_salt_tendency - flux_through_surface + + @register( VariableMetadata( "W/m**2", diff --git a/fme/core/test_ocean_data.py b/fme/core/test_ocean_data.py index a94226ec9..196bd7fdc 100644 --- a/fme/core/test_ocean_data.py +++ b/fme/core/test_ocean_data.py @@ -51,6 +51,50 @@ def test_column_integrated_ocean_heat_content(has_depth_coordinate: bool): _ = ocean_data.ocean_heat_content +@pytest.mark.parametrize("has_depth_coordinate", [True, False]) +def test_column_integrated_ocean_salt_content(has_depth_coordinate: bool): + """Test column-integrated ocean salt content.""" + n_samples, n_time_steps, nlat, nlon, nlevels = 2, 2, 2, 2, 2 + shape_2d = (n_samples, n_time_steps, nlat, nlon) + + data = { + "so_0": torch.ones(n_samples, n_time_steps, nlat, nlon), + "so_1": torch.ones(n_samples, n_time_steps, nlat, nlon), + } + + if has_depth_coordinate: + idepth = torch.tensor([2.5, 10, 20]) + lev_thickness = idepth.diff(dim=-1) + mask = torch.ones(n_samples, n_time_steps, nlat, nlon, nlevels) + mask[:, :, 0, 0, 0] = 0.0 + mask[:, :, 0, 0, 1] = 0.0 + mask[:, :, 0, 1, 1] = 0.0 + + expected_osc = torch.tensor( + DENSITY_OF_SEA_WATER_CM4 + * n_samples + * n_time_steps + * ( + nlat * nlon * lev_thickness.sum() + - lev_thickness[0] + - 2 * lev_thickness[1] + ) + ) + depth_coordinate = DepthCoordinate(idepth, mask) + ocean_data = OceanData(data, depth_coordinate) + assert ocean_data.ocean_salt_content.shape == shape_2d + assert torch.allclose( + ocean_data.ocean_salt_content.nansum(), + expected_osc, + atol=1e-10, + equal_nan=True, + ) + else: + ocean_data = OceanData(data) + with pytest.raises(ValueError, match="Depth coordinate must be provided"): + _ = ocean_data.ocean_salt_content + + def test_get_3d_fields(): """Test getting 3D fields (fields with vertical levels).""" n_samples, n_time_steps, nlat, nlon, nlevels = 2, 3, 4, 8, 2