-
Notifications
You must be signed in to change notification settings - Fork 32
First Pass Zonal Average #82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 8 commits
e2a6d50
c09eadf
ff4bc68
6fcd961
3e6b24c
8a0bd6d
4732c90
15488ca
f950017
acf689c
db5c0d2
886790c
d4bb79c
31698c1
4fceca6
2a9a0a9
6e14014
b6623f1
284997e
8a831a3
9726975
eba88b3
424c05f
69e4193
2d017cc
68a2a10
bbeeecb
4ff8a80
d415190
8d54ba7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,245 @@ | ||
| import os | ||
| import warnings | ||
|
|
||
| import numpy as np | ||
| import xarray as xr | ||
|
|
||
| try: | ||
| import xesmf as xe | ||
| except ImportError: | ||
| message = 'Zonal averaging requires xesmf package.\n\n' | ||
| 'Please conda install as follows:\n\n' | ||
| ' conda install -c conda-forge xesmf>=0.4.0' | ||
|
|
||
| raise ImportError(message) | ||
|
|
||
| from tqdm import tqdm | ||
|
|
||
| from .. import get_grid, region_mask_3d | ||
|
|
||
|
|
||
| def _generate_dest_grid(dy=None, dx=None, method_gen_grid='regular_lat_lon'): | ||
| """ | ||
| Generates the destination grid | ||
|
|
||
| Parameters | ||
| ---------- | ||
| dy: float | ||
| Horizontal grid spacing in y-direction (latitudinal) | ||
|
|
||
| dy: float | ||
| Horizontal grid spcaing in x-direction (longitudinal) | ||
| """ | ||
|
|
||
| # Use regular lat/lon with regular spacing | ||
| if method_gen_grid == 'regular_lat_lon': | ||
| if dy is None: | ||
| dy = 0.25 | ||
|
|
||
| if dx is None: | ||
| dx = dy | ||
|
|
||
| # Able to add other options at a later point | ||
| else: | ||
| raise ValueError(f'Input method_gen_grid: {method_gen_grid} is not supported.') | ||
|
|
||
| # Use xESMF to generate the destination grid | ||
| return xe.util.grid_global(dx, dy) | ||
|
|
||
|
|
||
| def _get_default_filename(src_grid, dst_grid, method): | ||
|
|
||
| # Get the source grid shape | ||
| src_shape = src_grid.lat.shape | ||
|
|
||
| # Get the destination grid shape | ||
| dst_shape = dst_grid.lat.shape | ||
|
|
||
| filename = f'{method}_{src_shape[0]}x{src_shape[1]}_{dst_shape[0]}x{dst_shape[1]}.nc' | ||
|
|
||
| return filename | ||
|
|
||
|
|
||
| def _convert_to_xesmf(data_ds, grid_ds): | ||
| """ | ||
| Format xarray datasets to be read in easily to xESMF | ||
|
|
||
| Parameters | ||
| ---------- | ||
| data_ds : `xarray.Dataset` | ||
| Dataset which includes fields to regrid | ||
|
|
||
| grid_ds : `xarray.Dataset` | ||
| Dataset including the POP grid | ||
|
|
||
| Returns | ||
| ------- | ||
|
|
||
| out_ds : `xarray.Dataset` | ||
| Clipped dataset including fields to regrid with grid | ||
|
|
||
| """ | ||
|
|
||
| # Merge datasets into single dataset | ||
| data_ds = xr.merge( | ||
| [grid_ds.reset_coords(), data_ds.reset_coords()], compat='override', join='right' | ||
| ).rename({'TLAT': 'lat', 'TLONG': 'lon'}) | ||
|
|
||
| # Inlcude only points that will have surrounding corners | ||
| data_ds = data_ds.isel({'nlon': data_ds.nlon[1:], 'nlat': data_ds.nlat[1:]}) | ||
|
|
||
| # Use ulat and ulong values as grid corners, rename variables to match xESMF syntax | ||
| grid_corners = grid_ds[['ULAT', 'ULONG']].rename( | ||
|
||
| {'nlat': 'nlat_b', 'nlon': 'nlon_b', 'ULAT': 'lat_b', 'ULONG': 'lon_b'} | ||
| ) | ||
|
|
||
| # Merge datasets with data and grid corner information | ||
| out_ds = xr.merge([data_ds, grid_corners]) | ||
|
|
||
| return out_ds | ||
|
|
||
|
|
||
| def _generate_weights(src_grid, dst_grid, method, weight_file=None): | ||
| """ | ||
| Generate regridding weights by calling xESMF | ||
| """ | ||
|
|
||
| # Allow user to input weights file, if there is not one, use default check | ||
| if weight_file is None: | ||
| weight_file = _get_default_filename(src_grid, dst_grid, method) | ||
|
|
||
| # Check to see if the weights file already exists - if not, generate weights | ||
| if not os.path.exists(weight_file): | ||
| regridder = xe.Regridder(src_grid, dst_grid, method) | ||
| print(f'Saving weights file: {os.path.abspath(weight_file)}') | ||
| regridder.to_netcdf(weight_file) | ||
|
|
||
| else: | ||
| regridder = xe.Regridder(src_grid, dst_grid, method, weights=weight_file) | ||
|
|
||
| return regridder | ||
|
|
||
|
|
||
| class Regridder: | ||
| def __init__( | ||
| self, | ||
| grid_name=None, | ||
| grid=None, | ||
| dx=None, | ||
| dy=None, | ||
| mask=True, | ||
| regrid_method='conservative', | ||
| method_gen_grid='regular_lat_lon', | ||
| ): | ||
| """ | ||
| A regridding class which uses xESMF and Xarray tools to both regrid and | ||
| calculate a zonal averge. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| grid_name: string | ||
| POP grid name (ex. 'POP_gx1v6') | ||
|
|
||
| grid: `xarray.Dataset` | ||
| User defined grid containing metadata typically found in POP grid | ||
|
|
||
| dx: float | ||
| Horizontal grid spacing in x-direction for output grid in degrees | ||
|
|
||
| dy: float | ||
| Horizontal grid spacing in y-direction for output grid in degrees | ||
|
|
||
| mask: `xarray.Dataarray` | ||
| Flag whether to use mask, can also be user defined region mask | ||
|
|
||
| regrid_method: string | ||
| Regridding method to be used within xESMF (default is conservative) | ||
|
|
||
| method_gen_grid: string | ||
| Method used to generate the output grid - default is a regular lat/lon grid | ||
| """ | ||
| if grid_name is not None: | ||
| self.grid_name = grid_name | ||
|
|
||
| # Use pop-tools to retrieve the grid | ||
| self.grid = get_grid(grid_name) | ||
|
|
||
| elif grid is not None: | ||
| self.grid = grid | ||
|
|
||
| else: | ||
| raise ValueError('Failed to input grid name or grid dataset') | ||
|
|
||
| # Set the dx/dy parameters for generating the grid | ||
| self.dx = dx | ||
| self.dy = dy | ||
|
|
||
| # Set the regridding method | ||
| self.regrid_method = regrid_method | ||
|
|
||
| # Set the grid generation method | ||
| self.method_gen_grid = method_gen_grid | ||
|
|
||
| # Use the region 3d mask provided in pop-tools | ||
|
|
||
| if mask: | ||
| self.mask = region_mask_3d(grid_name, mask_name='default') | ||
| self.mask.name = 'region_mask' | ||
|
|
||
| else: | ||
|
||
| return ValueError('Failed to specify whether to use mask') | ||
|
|
||
| # Setup method for regridding a dataarray | ||
| def _regrid_dataarray(self, da_in, regrid_mask=False, regrid_method=None): | ||
|
|
||
| src_grid = _convert_to_xesmf(da_in, self.grid) | ||
| dst_grid = _generate_dest_grid(self.dy, self.dx, self.method_gen_grid) | ||
|
|
||
| # If the user does not specify a regridding method, use default conservative | ||
| if regrid_method is None: | ||
| regridder = _generate_weights(src_grid, dst_grid, self.regrid_method) | ||
|
|
||
| else: | ||
| regridder = _generate_weights(src_grid, dst_grid, regrid_method) | ||
|
|
||
| # Regrid the input data array, assigning the original attributes | ||
| da_out = regridder(src_grid[da_in.name]) | ||
| da_out.attrs = da_in.attrs | ||
|
|
||
| return da_out | ||
|
|
||
| def regrid(self, obj, **kwargs): | ||
| """generic interface for regridding DataArray or Dataset""" | ||
| if isinstance(obj, xr.Dataset): | ||
| var_list = list([]) | ||
| for var in obj: # only data variables | ||
|
|
||
| # Make sure the variable has the correct dimensions, is not a coordinate, and is not a velocity | ||
| if ('nlat' in obj[var].dims and 'nlon' in obj[var].dims) and ( | ||
| 'ULONG' not in obj[var].cf.coords['longitude'].name | ||
| and 'ULAT' not in obj[var].cf.coords['latitude'].name | ||
| ): | ||
| var_list.append(var) | ||
| return obj[var_list].map(self._regrid_dataarray, keep_attrs=True, **kwargs) | ||
| elif isinstance(obj, xr.DataArray): | ||
| return self._regrid_dataarray(obj, **kwargs) | ||
| raise TypeError('input data must be xarray DataArray or xarray Dataset!') | ||
|
|
||
| def zonal_average(self, obj, vertical_average=False, **kwargs): | ||
|
|
||
| data = self.regrid(obj, **kwargs) | ||
| mask = self.regrid(self.mask, regrid_method='nearest_s2d', **kwargs) | ||
|
|
||
| # Attach a name to the mask | ||
| mask.name = self.mask.name | ||
|
|
||
| # Replace zeros with nans and group into regions | ||
| out = mask.where(mask > 0) * data | ||
|
|
||
| # Check to see if a weighted vertical average is needed | ||
| if vertical_average: | ||
|
|
||
| # Calculate the vertical weighte average based on depth of the layer | ||
| out = out.weighted(out.z_t.diff(dim='z_t')).mean(dim='z_t') | ||
|
|
||
| return out | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This won't work for datasets returned by
to_xgcm_grid_and_datasetwhich renames tonlon_t, nlon_u, nlat_t, nlat_u.I think the solution here is to use cf_xarray
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you have a suggestion for adding a new X/Y coordinate? I noticed that the only cf axis within the POP grids is
timeThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ugh yeah I forgot. We need to make
nlat,nlondimension coordinates for this to work. That allows us to assign attributes likeaxis: "X"and similarly for
nlat_*,nlon_*into_xgcm_grid_dataset. This too seems like a followup PR>