Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fme/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@
DENSITY_OF_SEA_WATER_CM4 = 1035.0 # kg/m^3

FREEZING_TEMPERATURE_KELVIN = 273.15 # K
REFERENCE_SALINITY_PSU = 35.0 # g/kg

EARTH_RADIUS = 6371000.0 # m
8 changes: 8 additions & 0 deletions fme/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,11 +383,19 @@ def __post_init__(self):
f"{self.mask.shape}."
)
self._dz = dz_from_idepth(self.idepth, self.mask, self.deptho)
if self.deptho is not None:
self._sea_floor_depth = self.deptho
else:
self._sea_floor_depth = self._dz.sum(dim=-1)

@property
def dz(self) -> torch.Tensor:
return self._dz

@property
def sea_floor_depth(self) -> torch.Tensor:
return self._sea_floor_depth

def __len__(self):
"""The number of vertical layer interfaces."""
return len(self.idepth)
Expand Down
118 changes: 118 additions & 0 deletions fme/core/corrector/ocean.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@

from fme.core.atmosphere_data import AtmosphereData
from fme.core.constants import (
DENSITY_OF_SEA_WATER_CM4,
FREEZING_TEMPERATURE_KELVIN,
LATENT_HEAT_OF_VAPORIZATION,
REFERENCE_SALINITY_PSU,
SPECIFIC_HEAT_OF_SEA_WATER_CM4,
)
from fme.core.corrector.registry import CorrectorABC
Expand Down Expand Up @@ -108,13 +110,34 @@ class SurfaceEnergyFluxCorrectionConfig:
method: Literal["residual_prediction", "prescribed"]


@dataclasses.dataclass
class OceanSaltContentBudgetConfig:
"""Configuration for ocean salt content budget correction.

Parameters:
method: Method to use for salt content budget correction. The available
option is "constant_salinity", which enforces conservation of salt
content by adding a uniform concentration correction (PSU) to each
ocean layer, computed by dividing the global salt deficit by the
local column mass (reference density times sea floor depth).
constant_unaccounted_salt_flux: Area-weighted global mean
column-integrated salt flux in g/m**2/s to be added to the surface
boundary flux when conserving the salt content. This can be useful
for correcting errors in salt budget in target data.
"""

method: Literal["constant_salinity"]
constant_unaccounted_salt_flux: float = 0.0


@CorrectorSelector.register("ocean_corrector")
@dataclasses.dataclass
class OceanCorrectorConfig:
force_positive_names: list[str] = dataclasses.field(default_factory=list)
sea_ice_fraction_correction: SeaIceFractionConfig | None = None
surface_energy_flux_correction: SurfaceEnergyFluxCorrectionConfig | None = None
ocean_heat_content_correction: OceanHeatContentBudgetConfig | None = None
ocean_salt_content_correction: OceanSaltContentBudgetConfig | None = None

@classmethod
def from_state(cls, state: Mapping[str, Any]) -> "OceanCorrectorConfig":
Expand Down Expand Up @@ -160,6 +183,20 @@ def __init__(
self._gridded_operations = gridded_operations
self._vertical_coordinate = vertical_coordinate
self._timestep = timestep
self._global_mean_depth: torch.Tensor | None = None

@property
def global_mean_depth(self) -> torch.Tensor:
if self._vertical_coordinate is None:
raise RuntimeError(
"Global mean ocean depth cannot be computed without the ocean depth "
"coordinate."
)
if self._global_mean_depth is None:
self._global_mean_depth = self._gridded_operations.area_weighted_mean(
self._vertical_coordinate.sea_floor_depth, name="global_mean_depth"
)
return self._global_mean_depth

def __call__(
self,
Expand All @@ -171,6 +208,23 @@ def __call__(
gen_data = force_positive(gen_data, self._config.force_positive_names)
if self._config.sea_ice_fraction_correction is not None:
gen_data = self._config.sea_ice_fraction_correction(gen_data, input_data)
if self._config.ocean_salt_content_correction is not None:
if self._vertical_coordinate is None:
raise ValueError(
"Ocean salt content correction is turned on, but no vertical "
"coordinate is available."
)
gen_data = _force_conserve_ocean_salt_content(
input_data,
gen_data,
forcing_data,
self._gridded_operations.area_weighted_mean,
self._vertical_coordinate,
self._timestep.total_seconds(),
self.global_mean_depth,
self._config.ocean_salt_content_correction.method,
self._config.ocean_salt_content_correction.constant_unaccounted_salt_flux,
)
if self._config.surface_energy_flux_correction is not None:
gen_data = _correct_hfds(
input_data,
Expand Down Expand Up @@ -342,3 +396,67 @@ def _force_conserve_ocean_heat_content(
gen.data["sst"] - FREEZING_TEMPERATURE_KELVIN
) * heat_content_correction_ratio + FREEZING_TEMPERATURE_KELVIN
return gen.data


def _force_conserve_ocean_salt_content(
input_data: TensorMapping,
gen_data: TensorMapping,
forcing_data: TensorMapping,
area_weighted_mean: AreaWeightedMean,
vertical_coordinate: HasOceanDepthIntegral,
timestep_seconds: float,
global_mean_depth: torch.Tensor,
method: Literal["constant_salinity"] = "constant_salinity",
unaccounted_salt_flux: float = 0.0,
) -> TensorDict:
if method != "constant_salinity":
raise NotImplementedError(
f"Method {method!r} not implemented for ocean salt content conservation"
)
if "wfo" in gen_data and "wfo" in forcing_data:
raise ValueError(
"Water flux into sea water cannot be present in both gen_data and "
"forcing_data."
)
input = OceanData(input_data, vertical_coordinate)
gen = OceanData(gen_data, vertical_coordinate)
forcing = OceanData(forcing_data)
global_gen_salt_content = area_weighted_mean(
gen.ocean_salt_content,
keepdim=True,
name="ocean_salt_content",
)
global_input_salt_content = area_weighted_mean(
input.ocean_salt_content,
keepdim=True,
name="ocean_salt_content",
)
try:
wfo = gen.water_flux_into_sea_water
except KeyError:
wfo = input.water_flux_into_sea_water
try:
sfdsi = gen.downward_sea_ice_basal_salt_flux
except KeyError:
sfdsi = input.downward_sea_ice_basal_salt_flux
virtual_salt_flux = -REFERENCE_SALINITY_PSU * wfo * forcing.sea_surface_fraction
salt_flux = 1000.0 * sfdsi # kg/m2/s -> g/m2/s
total_surface_flux = virtual_salt_flux + salt_flux # g/m2/s
salt_flux_global_mean = area_weighted_mean(
total_surface_flux,
keepdim=True,
name="ocean_salt_content",
)
salt_deficit = (
global_input_salt_content
+ (salt_flux_global_mean + unaccounted_salt_flux) * timestep_seconds
- global_gen_salt_content
) # g/m2
salinity_correction = salt_deficit / (
DENSITY_OF_SEA_WATER_CM4 * global_mean_depth
) # g/kg
n_levels = gen.sea_water_salinity.shape[-1]
for k in range(n_levels):
name = f"so_{k}"
gen.data[name] = gen.data[name] + salinity_correction
return gen.data
109 changes: 101 additions & 8 deletions fme/core/corrector/test_ocean.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import torch

from fme import get_device
from fme.core.constants import DENSITY_OF_SEA_WATER_CM4, REFERENCE_SALINITY_PSU
from fme.core.coordinates import DepthCoordinate
from fme.core.corrector.ocean import (
OceanCorrector,
OceanCorrectorConfig,
OceanHeatContentBudgetConfig,
OceanSaltContentBudgetConfig,
SeaIceFractionConfig,
SurfaceEnergyFluxCorrectionConfig,
_compute_ocean_net_surface_energy_flux,
Expand All @@ -28,6 +30,9 @@


class _MockDepth:
def sea_floor_depth(self) -> torch.Tensor:
return torch.full_like(_MASK, 15)

def depth_integral(self, integrand: torch.Tensor) -> torch.Tensor:
idepth = torch.tensor([0, 5, 15], device=DEVICE)
thickness = idepth.diff(dim=-1)
Expand Down Expand Up @@ -316,22 +321,22 @@ def test_ocean_heat_content_correction(hfds_type):
)
timestep = datetime.timedelta(seconds=5 * 24 * 3600)
nsamples, nlat, nlon, nlevels = 4, 3, 3, 2
mask = torch.ones(nsamples, nlat, nlon, nlevels)
mask[:, 0, 0, 0] = 0.0
mask[:, 0, 0, 1] = 0.0
mask[:, 0, 1, 1] = 0.0
mask = torch.ones(nlat, nlon, nlevels)
mask[0, 0, 0] = 0.0
mask[0, 0, 1] = 0.0
mask[0, 1, 1] = 0.0
masks = {
"mask_0": mask[:, :, :, 0],
"mask_1": mask[:, :, :, 1],
"mask_2d": mask[:, :, :, 0],
"mask_0": mask[:, :, 0],
"mask_1": mask[:, :, 1],
"mask_2d": mask[:, :, 0],
}
mask_provider = MaskProvider(masks)
ops = LatLonOperations(torch.ones(size=[3, 3]), mask_provider)

idepth = torch.tensor([2.5, 10, 20])
depth_coordinate = DepthCoordinate(idepth, mask)

sea_surface_fraction = mask[:, :, :, 0]
sea_surface_fraction = mask[:, :, 0]

input_data_dict = {
"thetao_0": torch.ones(nsamples, nlat, nlon),
Expand Down Expand Up @@ -399,3 +404,91 @@ def test_ocean_heat_content_correction(hfds_type):
gen_data_corrected.ocean_heat_content,
equal_nan=True,
)


@pytest.mark.parametrize(
"wfo_type",
[
pytest.param("input", id="wfo_in_input"),
pytest.param("gen", id="wfo_in_gen"),
],
)
def test_ocean_salt_content_correction(wfo_type):
unaccounted_salt_flux = 0.1
config = OceanCorrectorConfig(
ocean_salt_content_correction=OceanSaltContentBudgetConfig(
method="constant_salinity",
constant_unaccounted_salt_flux=unaccounted_salt_flux,
)
)
timestep = datetime.timedelta(seconds=5 * 24 * 3600)
nsamples, nlat, nlon, nlevels = 4, 3, 3, 2
mask = torch.ones(nlat, nlon, nlevels)
mask[0, 0, 0] = 0.0
mask[0, 0, 1] = 0.0
mask[0, 1, 1] = 0.0
masks = {
"mask_0": mask[:, :, 0],
"mask_1": mask[:, :, 1],
"mask_2d": mask[:, :, 0],
}
mask_provider = MaskProvider(masks)
ops = LatLonOperations(torch.ones(size=[3, 3]), mask_provider)

idepth = torch.tensor([2.5, 10, 20])
depth_coordinate = DepthCoordinate(idepth, mask)
global_mean_depth = 16.25 # (7 * 17.5 m * 1 * 7.5 m) / 8

sea_surface_fraction = mask[:, :, 0]

wfo_value = torch.ones(nsamples, nlat, nlon) * 0.5
sfdsi_value = torch.ones(nsamples, nlat, nlon) * 0.001

input_data_dict = {
"so_0": torch.ones(nsamples, nlat, nlon),
"so_1": torch.ones(nsamples, nlat, nlon),
}
gen_data_dict = {
"so_0": torch.ones(nsamples, nlat, nlon) * 2,
"so_1": torch.ones(nsamples, nlat, nlon) * 2,
"sfdsi": sfdsi_value,
}
if wfo_type == "gen":
gen_data_dict["wfo"] = wfo_value
else:
input_data_dict["wfo"] = wfo_value
forcing_data_dict = {
"sea_surface_fraction": sea_surface_fraction,
}
input_data = OceanData(input_data_dict, depth_coordinate)
gen_data = OceanData(gen_data_dict, depth_coordinate)
corrector = OceanCorrector(config, ops, depth_coordinate, timestep)
gen_data_corrected_dict = corrector(
input_data_dict, gen_data_dict, forcing_data_dict
)

input_osc = input_data.ocean_salt_content.nanmean(dim=(-1, -2), keepdim=True)
gen_osc = gen_data.ocean_salt_content.nanmean(dim=(-1, -2), keepdim=True)
torch.testing.assert_close(gen_osc, input_osc * 2, equal_nan=True)

# Total surface flux combines virtual salt flux from freshwater and sfdsi.
# All non-masked ocean points have the same flux, so the area-weighted
# mean equals the pointwise value.
total_flux_mean = -REFERENCE_SALINITY_PSU * 0.5 + 1000.0 * 0.001
osc_change = (total_flux_mean + unaccounted_salt_flux) * timestep.total_seconds()
salt_deficit = input_osc + osc_change - gen_osc

correction_psu = salt_deficit / (DENSITY_OF_SEA_WATER_CM4 * global_mean_depth)

expected_gen_data_dict = {
key: value + correction_psu if key.startswith("so") else value
for key, value in gen_data_dict.items()
}

expected_gen_data = OceanData(expected_gen_data_dict, depth_coordinate)
gen_data_corrected = OceanData(gen_data_corrected_dict, depth_coordinate)
torch.testing.assert_close(
expected_gen_data.ocean_salt_content,
gen_data_corrected.ocean_salt_content,
equal_nan=True,
)
Loading
Loading