diff --git a/pop_tools/grid.py b/pop_tools/grid.py index b6cdd749..f06e77bf 100644 --- a/pop_tools/grid.py +++ b/pop_tools/grid.py @@ -132,6 +132,10 @@ def get_grid(grid_name, scrip=False): assert kmt_flat.max() <= len(z_t), 'Max KMT > length z_t' KMT = kmt_flat.reshape(grid_attrs['lateral_dims']).astype(np.int32) + # derive KMU + KMU = np.empty_like(KMT) + KMU = _generate_KMU(KMT, KMU) + # read REGION_MASK region_mask_fname = INPUTDATA.fetch(grid_attrs['region_mask_fname'], downloader=downloader) region_mask_flat = np.fromfile(region_mask_fname, dtype='>i4', count=-1) @@ -140,6 +144,15 @@ def get_grid(grid_name, scrip=False): ), f'unexpected dims in region_mask file: {grid_attrs["region_mask_fname"]}' REGION_MASK = region_mask_flat.reshape(grid_attrs['lateral_dims']).astype(np.int32) + # derive depth of columns of ocean T-points and U-points + KMT_reidx = KMT - 1 + KMT_reidx[KMT_reidx == -1] = 0 + HT = z_w[KMT_reidx] + + KMU_reidx = KMU - 1 + KMU_reidx[KMU_reidx == -1] = 0 + HU = z_w[KMU_reidx] + # output dataset dso = xr.Dataset() if scrip: @@ -267,6 +280,35 @@ def get_grid(grid_name, scrip=False): }, ) + dso['KMU'] = xr.DataArray( + KMU, + dims=('nlat', 'nlon'), + attrs={ + 'long_name': 'k Index of Deepest Grid Cell on U Grid', + 'coordinates': 'ULONG ULAT', + }, + ) + + dso['HT'] = xr.DataArray( + HT, + dims=('nlat', 'nlon'), + attrs={ + 'units': 'cm', + 'long_name': 'depth of ocean column on T grid', + 'coordinates': 'TLONG TLAT', + }, + ) + + dso['HU'] = xr.DataArray( + HU, + dims=('nlat', 'nlon'), + attrs={ + 'units': 'cm', + 'long_name': 'depth of ocean column on U grid', + 'coordinates': 'ULONG ULAT', + }, + ) + dso['REGION_MASK'] = xr.DataArray( REGION_MASK, dims=('nlat', 'nlon'), @@ -320,6 +362,15 @@ def get_grid(grid_name, scrip=False): return dso +@jit(nopython=True, parallel=True) +def _generate_KMU(KMT, KMU): + """Computes KMU from KMT.""" + for i in prange(KMT.shape[0]): + for j in prange(KMT.shape[1]): + KMU[i, j] = min(KMT[i, j], KMT[i - 1, j], KMT[i, j - 1], KMT[i - 1, j - 1]) + return KMU + + @jit(nopython=True, parallel=True) def _compute_TLAT_TLONG(ULAT, ULONG, TLAT, TLONG, nlat, nlon): """Compute TLAT and TLONG from ULAT, ULONG""" diff --git a/tests/test_grid.py b/tests/test_grid.py index 4b56b41d..6d642a89 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -1,5 +1,6 @@ import os +import pytest import xarray as xr import pop_tools @@ -26,3 +27,12 @@ def test_get_grid_scrip(): ds_test = pop_tools.get_grid('POP_gx3v7', scrip=True) ds_ref = xr.open_dataset(DATASETS.fetch('POP_gx3v7.nc')) assert ds_compare(ds_test, ds_ref, assertion='allclose', rtol=1e-14, atol=1e-14) + + +@pytest.mark.parametrize('grid', pop_tools.grid_defs.keys()) +def test_HT_HU_KMU_in_grid(grid): + print(grid) + ds = pop_tools.get_grid(grid) + assert 'HT' in ds, 'Missing variable HT' + assert 'HU' in ds, 'Missing variable HU' + assert 'KMU' in ds, 'Missing variable KMU'