diff --git a/ci/environment-upstream-dev.yml b/ci/environment-upstream-dev.yml
index 5e93d143..133f24f5 100644
--- a/ci/environment-upstream-dev.yml
+++ b/ci/environment-upstream-dev.yml
@@ -14,6 +14,7 @@ dependencies:
   - pyyaml>=5.3.1
   - scipy
   - toolz
+  - zarr
   - pip:
     - git+https://github.com/pydata/xarray.git#egg=xarray
     - git+https://github.com/dask/dask.git#egg=dask
diff --git a/ci/environment.yml b/ci/environment.yml
index 1d3335ff..121e919a 100644
--- a/ci/environment.yml
+++ b/ci/environment.yml
@@ -18,3 +18,4 @@ dependencies:
   - xarray>=0.16.1
   - xgcm
   - watermark
+  - zarr
diff --git a/docs/source/api.rst b/docs/source/api.rst
index f9df6e19..53effbe7 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -11,6 +11,7 @@ Grid
 ~~~~
 
 .. autosummary::
+   calc_dzu_dzt
    get_grid
 
 
@@ -42,6 +43,7 @@ Utilities
 ~~~~~~~~~
 
 .. autosummary::
+   four_point_min
    lateral_fill
 
 
@@ -54,6 +56,8 @@ xgcm utilities
 
 .. currentmodule:: pop_tools
 
+.. autofunction:: calc_dzu_dzt
+
 .. autofunction:: get_grid
 
 .. autofunction:: eos
diff --git a/pop_tools/grid.py b/pop_tools/grid.py
index a1c37aa4..93a06841 100644
--- a/pop_tools/grid.py
+++ b/pop_tools/grid.py
@@ -6,7 +6,7 @@
 import pooch
 import xarray as xr
 import yaml
-from numba import jit, prange
+from numba import double, float_, guvectorize, int_, jit, prange
 
 try:
     from tqdm import tqdm
@@ -463,3 +463,159 @@ def _compute_corners(ULAT, ULONG):
     corner_lon[0, :, 3] = corner_lon[1, :, 3] - (corner_lon[2, :, 3] - corner_lon[1, :, 3])
 
     return corner_lat, corner_lon
+
+
+@guvectorize(
+    [
+        (int_[:, :], int_[:, :]),
+        (float_[:, :], float_[:, :]),
+        (double[:, :], double[:, :]),
+    ],
+    '(n,m)->(n,m)',
+    nopython=True,
+    cache=True,
+)
+def numba_4pt_min(var, out):
+    """
+    gufunc to calculate minimum over
+    (i, j+1) ————— (i+1, j+1)
+       |               |
+       |               |
+     (i,j)   —————   (i+1, j)
+    at every depth level.
+
+    Expects and returns a 2d numpy array
+    """
+    dim1, dim0 = var.shape
+    # cells in the last row/column have no (i+1, j+1) neighbours and keep this 0
+    out[:] = 0
+
+    for j in prange(dim1 - 1):
+        for i in prange(dim0 - 1):
+            out[j, i] = np.min(
+                np.array([var[j, i], var[j + 1, i], var[j, i + 1], var[j + 1, i + 1]])
+            )
+
+
+def four_point_min(array, dims=('nlat', 'nlon')):
+    """
+    Utility function that calculates minimum at 4 surrounding points in 2D slices
+    along dimensions ``dims``.
+
+    Output at (i,j) is minimum over the following 4 points
+    (i, j+1) ————— (i+1, j+1)
+       |               |
+       |               |
+     (i,j)   —————   (i+1, j)
+
+    Parameters
+    ----------
+    array: DataArray
+        A 2D or 3D DataArray
+
+    dims: tuple or list
+        two element tuple or list of dimension names
+
+    Returns
+    -------
+    DataArray
+    """
+
+    import dask
+
+    if len(dims) != 2:
+        raise ValueError(f'Expected 2 dimensions. Received {dims} instead.')
+
+    array = array.transpose(..., *dims)
+    data = array.data
+
+    # map_overlap does not support negative axes :/
+    depth = {array.ndim - 2: (0, 1), array.ndim - 1: (0, 1)}
+
+    if dask.is_dask_collection(data):
+        result = data.map_overlap(numba_4pt_min, depth=depth, boundary='none', meta=data._meta)
+    else:
+        result = numba_4pt_min(data)
+
+    return array.copy(data=result)
+
+
+def calc_dzu_dzt(grid):
+    """
+    Calculates DZT and DZU from a dataset containing dz, KMT and DZBC
+
+    .. warning::
+
+        This function does not do the right thing at the tripole grid seam.
+
+    Parameters
+    ----------
+    grid: Dataset
+        An xarray Dataset containing grid variables. This *must* contain partial bottom
+        cell information: KMT and DZBC. Datasets with dimensions renamed for xgcm are not
+        allowed.
+
+    Returns
+    -------
+    DZT, DZU: DataArray
+
+    Notes
+    -----
+    From Frank's zulip convo
+    https://zulip.cloud.ucar.edu/#narrow/stream/9-CGD-OCE/topic/pop-tools/near/2864
+
+    DZT[:,:,k] = dz[k] if k< KMT-1 # converting from Fortran to python indexing
+    DZT[i,j,KMT[i,j]-1] = DZBC[i,j]
+    DZU = min of 4 surrounding DZT
+
+    """
+
+    if not isinstance(grid, xr.Dataset):
+        raise ValueError(
+            f'Expected xarray Dataset with grid variables. Received {type(grid).__name__} instead.'
+        )
+    expected_vars = ['dz', 'KMT', 'DZBC']
+    missing_vars = set(expected_vars) - set(grid.variables)
+    if missing_vars:
+        raise ValueError(f'Variables {missing_vars} are missing in the provided dataset.')
+
+    dz = grid.dz
+    KMT = grid.KMT
+    DZBC = grid.DZBC
+
+    dzunit = dz.attrs.get('units', None)
+    zunit = {'units': dzunit} if dzunit is not None else {}
+
+    # build a 1D DataArray of z-index value
+    fortran_zindex = dz.copy(data=np.arange(1, grid.sizes['z_t'] + 1))
+
+    # set values at KMT to DZBC, else, use existing nominal dz
+    DZT = xr.where(fortran_zindex == KMT, DZBC, dz)
+    DZT.name = 'DZT'
+    DZT.attrs = {
+        'standard_name': 'cell_thickness',
+        'long_name': 'Thickness of T cells',
+        **zunit,
+        'grid_loc': '3111',
+    }
+
+    if 'nlon_t' in DZT.dims:
+        raise ValueError('datasets renamed for xgcm are not allowed.')
+
+    # now make DZU
+    DZU = four_point_min(DZT)
+    KMU = four_point_min(KMT)
+
+    # In Fortran-like code, DZU is computed using a WORK variable that has DZT values.
+    # Then only values above KMU are modified, so we replicate that here
+    # so that we can run tests and users can check against existing code
+    DZU = xr.where(fortran_zindex >= KMU, DZT, DZU)
+    DZU.name = 'DZU'
+    DZU.attrs = {
+        'standard_name': 'cell_thickness',
+        'long_name': 'Thickness of U cells',
+        **zunit,
+        'grid_loc': '3221',
+    }
+
+    return DZT, DZU
diff --git a/pop_tools/xgcm_util.py b/pop_tools/xgcm_util.py
index 650c23e4..e1e61bf6 100644
--- a/pop_tools/xgcm_util.py
+++ b/pop_tools/xgcm_util.py
@@ -60,6 +60,7 @@ def _label_coord_grid_locs(ds):
         'DZU': '3221',
         'DZT': '3111',
         'HT': '2110',
+        'DZBC': '2110',
         'HU': '2220',
         'HTE': '2210',
         'HTN': '2120',
diff --git a/tests/test_grid.py b/tests/test_grid.py
index 6dd66ceb..c3b671d9 100644
--- a/tests/test_grid.py
+++ b/tests/test_grid.py
@@ -2,8 +2,11 @@
 
 import pytest
 import xarray as xr
+from xarray.testing import assert_equal
 
 import pop_tools
+from pop_tools import DATASETS
+from pop_tools.datasets import UnzipZarr
 
 from .util import ds_compare, is_ncar_host
 
@@ -43,3 +46,43 @@ def test_get_grid_to_netcdf():
             gridfile = f'{grid}_{format}.nc'
             ds.to_netcdf(gridfile, format=format)
             os.system(f'rm -f {gridfile}')
+
+
+def test_four_point_min_kmu():
+    zstore = DATASETS.fetch('comp-grid.tx9.1v3.20170718.zarr.zip', processor=UnzipZarr())
+    ds = xr.open_zarr(zstore)
+
+    # topmost row is wrong because we need to account for tripole seam
+    # rightmost nlon is wrong because it doesn't matter
+    expected = ds.KMU.isel(nlat=slice(-1), nlon=slice(-1))
+    actual = pop_tools.grid.four_point_min(ds.KMT).isel(nlat=slice(-1), nlon=slice(-1))
+    assert_equal(expected, actual)
+
+    # make sure dask & numpy results check out
+    actual = pop_tools.grid.four_point_min(ds.KMT.compute()).isel(nlat=slice(-1), nlon=slice(-1))
+    assert_equal(expected, actual)
+
+
+def test_dzu_dzt():
+
+    zstore = DATASETS.fetch('comp-grid.tx9.1v3.20170718.zarr.zip', processor=UnzipZarr())
+    # chunk size is 300 along nlat; make sure we cross at least
+    # one chunk boundary to test map_overlap
+    ds = xr.open_zarr(zstore).sel(nlat=slice(100, 350))
+
+    # calc_dzu_dzt returns (DZT, DZU), in that order
+    dzt, dzu = pop_tools.grid.calc_dzu_dzt(ds)
+    # northernmost row will be wrong since we are working on a subset
+    assert_equal(dzu.isel(nlat=slice(-1)), ds['DZU'].isel(nlat=slice(-1)))
+    assert_equal(dzt, ds['DZT'])
+
+    _, xds = pop_tools.to_xgcm_grid_dataset(ds)
+    with pytest.raises(ValueError):
+        pop_tools.grid.calc_dzu_dzt(xds)
+
+    expected_vars = ['dz', 'KMT', 'DZBC']
+    for var in expected_vars:
+        dsc = ds.copy()
+        del dsc[var]
+        with pytest.raises(ValueError):
+            pop_tools.grid.calc_dzu_dzt(dsc)