diff --git a/environment.yml b/environment.yml index 05801b268..b9665ea22 100644 --- a/environment.yml +++ b/environment.yml @@ -7,7 +7,6 @@ dependencies: - noaa-gfdl::analysis_scripts==0.0.1 - noaa-gfdl::catalogbuilder==2025.01.01 # - noaa-gfdl::fre-nctools==2022.02.01 - - conda-forge::cdo>=2 - conda-forge::cftime - conda-forge::click>=8.2 - conda-forge::cmor>=3.14 @@ -22,7 +21,6 @@ dependencies: - conda-forge::pytest - conda-forge::pytest-cov - conda-forge::pylint - - conda-forge::python-cdo - conda-forge::pyyaml - conda-forge::xarray>=2024.* - conda-forge::netcdf4>=1.7.* diff --git a/fre/app/freapp.py b/fre/app/freapp.py index 4744fe5e8..ac794c424 100644 --- a/fre/app/freapp.py +++ b/fre/app/freapp.py @@ -138,8 +138,8 @@ def mask_atmos_plevel(infile, psfile, outfile, warn_no_ps): required = True, help = "Output file name") @click.option("-p", "--pkg", - type = click.Choice(["cdo","fre-nctools","fre-python-tools"]), - default = "cdo", + type = click.Choice(["cdo","fre-nctools","fre-python-tools","xarray"]), + default = "fre-python-tools", help = "Time average approach") @click.option("-v", "--var", type = str, @@ -192,8 +192,8 @@ def gen_time_averages(inf, outf, pkg, var, unwgt, avg_type): required = True, help = "Frequency of desired climatology: 'mon' or 'yr'") @click.option("-p", "--pkg", - type = click.Choice(["cdo","fre-nctools","fre-python-tools"]), - default = "cdo", + type = click.Choice(["cdo","fre-nctools","fre-python-tools","xarray"]), + default = "fre-python-tools", help = "Time average approach") def gen_time_averages_wrapper(cycle_point, dir_, sources, output_interval, input_interval, grid, frequency, pkg): """ diff --git a/fre/app/generate_time_averages/__init__.py b/fre/app/generate_time_averages/__init__.py index 699064f4b..1546030d8 100644 --- a/fre/app/generate_time_averages/__init__.py +++ b/fre/app/generate_time_averages/__init__.py @@ -1,3 +1,4 @@ '''required for generate_time_averages module import functionality''' __all__ = ['generate_time_averages', 'timeAverager', 'wrapper', 'combine', - 'frenctoolsTimeAverager', 'cdoTimeAverager', 'frepytoolsTimeAverager'] + 'frenctoolsTimeAverager', 'cdoTimeAverager', 'frepytoolsTimeAverager', + 'xarrayTimeAverager'] diff --git a/fre/app/generate_time_averages/cdoTimeAverager.py b/fre/app/generate_time_averages/cdoTimeAverager.py index c8f585887..ddc3f399a 100644 --- a/fre/app/generate_time_averages/cdoTimeAverager.py +++ b/fre/app/generate_time_averages/cdoTimeAverager.py @@ -1,89 +1,35 @@ -''' class using (mostly) cdo functions for time-averages ''' +''' stub that redirects pkg='cdo' requests to the xarray time averager ''' import logging +import warnings -from netCDF4 import Dataset -import numpy as np - -import cdo -from cdo import Cdo - -from .timeAverager import timeAverager +from .xarrayTimeAverager import xarrayTimeAverager fre_logger = logging.getLogger(__name__) -class cdoTimeAverager(timeAverager): + +class cdoTimeAverager(xarrayTimeAverager): # pylint: disable=invalid-name ''' - class inheriting from abstract base class timeAverager - generates time-averages using cdo (mostly, see weighted approach) + Legacy entry-point kept for backward compatibility. + CDO/python-cdo has been removed. All work is now done by xarrayTimeAverager. ''' - def generate_timavg(self, infile = None, outfile = None): + def generate_timavg(self, infile=None, outfile=None): """ - use cdo package routines via python bindings + Emit a loud warning then delegate to the xarray implementation. - :param self: This is an instance of the class cdoTimeAverager - :param infile: path to history file, or list of paths, default is None - :type infile: str, list - :param outfile: path to where output file should be stored, default is None + :param infile: path to input NetCDF file + :type infile: str + :param outfile: path to output file :type outfile: str - :return: 1 if the instance variable self.avg_typ is unsupported, 0 if function has a clean exit + :return: 0 on success :rtype: int """ - - if self.avg_type not in ['all', 'seas', 'month']: - fre_logger.error('requested unknown avg_type %s.', self.avg_type) - raise ValueError(f'requested unknown avg_type {self.avg_type}') - - if self.var is not None: - fre_logger.warning('WARNING: variable specification not twr supported for cdo time averaging. ignoring!') - - fre_logger.info('python-cdo version is %s', cdo.__version__) - - _cdo = Cdo() - - wgts_sum = 0 - if not self.unwgt: #weighted case, cdo ops alone don't support a weighted time-average. - - nc_fin = Dataset(infile, 'r') - - time_bnds = nc_fin['time_bnds'][:].copy() - # Ensure float64 precision for consistent results across numpy versions - # NumPy 2.0 changed type promotion rules (NEP 50), so explicit casting - # is needed to avoid precision differences - time_bnds = np.asarray(time_bnds, dtype=np.float64) - # Transpose once to avoid redundant operations - time_bnds_transposed = np.moveaxis(time_bnds, 0, -1) - wgts = time_bnds_transposed[1] - time_bnds_transposed[0] - # Use numpy.sum for consistent dtype handling across numpy versions - wgts_sum = np.sum(wgts, dtype=np.float64) - - fre_logger.debug('wgts_sum = %s', wgts_sum) - - if self.avg_type == 'all': - fre_logger.info('time average over all time requested.') - if self.unwgt: - _cdo.timmean(input = infile, output = outfile, returnCdf = True) - else: - _cdo.divc( str(wgts_sum), input = "-timsum -muldpm "+infile, output = outfile) - fre_logger.info('done averaging over all time.') - - elif self.avg_type == 'seas': - fre_logger.info('seasonal time-averages requested.') - _cdo.yseasmean(input = infile, output = outfile, returnCdf = True) - fre_logger.info('done averaging over seasons.') - - elif self.avg_type == 'month': - fre_logger.info('monthly time-averages requested.') - outfile_str = str(outfile) - _cdo.ymonmean(input = infile, output = outfile_str, returnCdf = True) - fre_logger.info('done averaging over months.') - - fre_logger.warning(" splitting by month") - outfile_root = outfile_str.removesuffix(".nc") + '.' - _cdo.splitmon(input = outfile_str, output = outfile_root) - fre_logger.debug('Done with splitting by month, outfile_root = %s', outfile_root) - - fre_logger.info('done averaging') - fre_logger.info('output file created: %s', outfile) - return 0 + msg = ( + "WARNING *** CDO/python-cdo has been REMOVED from fre-cli. " + "pkg='cdo' now uses the xarray time-averager under the hood. " + "Please switch to pkg='xarray' or pkg='fre-python-tools'. ***" + ) + warnings.warn(msg, FutureWarning, stacklevel=2) + fre_logger.warning(msg) + return super().generate_timavg(infile=infile, outfile=outfile) diff --git a/fre/app/generate_time_averages/frenctoolsTimeAverager.py b/fre/app/generate_time_averages/frenctoolsTimeAverager.py index 59ebf95ce..4049ad4d9 100644 --- a/fre/app/generate_time_averages/frenctoolsTimeAverager.py +++ b/fre/app/generate_time_averages/frenctoolsTimeAverager.py @@ -5,7 +5,7 @@ from subprocess import Popen, PIPE from pathlib import Path -from cdo import Cdo +import xarray as xr from .timeAverager import timeAverager fre_logger = logging.getLogger(__name__) @@ -80,38 +80,40 @@ def generate_timavg(self, infile=None, outfile=None): month_output_file_paths[month_index] = os.path.join( output_dir, f"{Path(outfile).stem}.{month_index:02d}.nc") - cdo = Cdo() - #Loop through each month and select the corresponding data - for month_index in month_indices: - - #month_name = month_names[month_index - 1] - nc_monthly_file = nc_month_file_paths[month_index] - - #Select data for the given month - cdo.select(f"month={month_index}", input=infile, output=nc_monthly_file) - - #Run timavg command for newly created file - month_output_file = month_output_file_paths[month_index] - #timavgcsh_command=['timavg.csh', '-mb','-o', month_output_file, nc_monthly_file] - timavgcsh_command=[shutil.which('timavg.csh'), '-dmb','-o', month_output_file, nc_monthly_file] - fre_logger.info( 'timavgcsh_command is %s', ' '.join(timavgcsh_command) ) - exitstatus=1 - with Popen(timavgcsh_command, - stdout=PIPE, stderr=PIPE, shell=False) as subp: - stdout, stderr = subp.communicate() - stdoutput=stdout.decode() - fre_logger.info('output= %s', stdoutput) - stderror=stderr.decode() - fre_logger.info('error = %s', stderror ) - - if subp.returncode != 0: - fre_logger.error('stderror = %s', stderror) - raise ValueError(f'error: timavg.csh had a problem, subp.returncode = {subp.returncode}') - - fre_logger.info('%s climatology successfully ran',nc_monthly_file) - exitstatus=0 - - #Delete files after being used to generate output files + with xr.open_dataset(infile) as ds_in: + #Loop through each month and select the corresponding data + for month_index in month_indices: + + #month_name = month_names[month_index - 1] + nc_monthly_file = nc_month_file_paths[month_index] + + #Select data for the given month + month_ds = ds_in.sel(time=ds_in['time'].dt.month == month_index) + month_ds.to_netcdf(nc_monthly_file) + month_ds.close() + + #Run timavg command for newly created file + month_output_file = month_output_file_paths[month_index] + #timavgcsh_command=['timavg.csh', '-mb','-o', month_output_file, nc_monthly_file] + timavgcsh_command=[shutil.which('timavg.csh'), '-dmb','-o', month_output_file, nc_monthly_file] + fre_logger.info( 'timavgcsh_command is %s', ' '.join(timavgcsh_command) ) + exitstatus=1 + with Popen(timavgcsh_command, + stdout=PIPE, stderr=PIPE, shell=False) as subp: + stdout, stderr = subp.communicate() + stdoutput=stdout.decode() + fre_logger.info('output= %s', stdoutput) + stderror=stderr.decode() + fre_logger.info('error = %s', stderror ) + + if subp.returncode != 0: + fre_logger.error('stderror = %s', stderror) + raise ValueError(f'error: timavg.csh had a problem, subp.returncode = {subp.returncode}') + + fre_logger.info('%s climatology successfully ran',nc_monthly_file) + exitstatus=0 + + #Delete files after being used to generate output files shutil.rmtree('monthly_nc_files') if self.avg_type == 'month': #End here if month variable used diff --git a/fre/app/generate_time_averages/frepytoolsTimeAverager.py b/fre/app/generate_time_averages/frepytoolsTimeAverager.py index c91220b3b..6459938af 100644 --- a/fre/app/generate_time_averages/frepytoolsTimeAverager.py +++ b/fre/app/generate_time_averages/frepytoolsTimeAverager.py @@ -9,7 +9,7 @@ fre_logger = logging.getLogger(__name__) -class frepytoolsTimeAverager(timeAverager): +class NumpyTimeAverager(timeAverager): # pylint: disable=invalid-name ''' class inheriting from abstract base class timeAverager generates time-averages using a python-native approach @@ -256,3 +256,6 @@ def generate_timavg(self, infile = None, outfile = None): fre_logger.debug('input file closed') return 0 + +# backward-compatible alias +frepytoolsTimeAverager = NumpyTimeAverager # pylint: disable=invalid-name diff --git a/fre/app/generate_time_averages/generate_time_averages.py b/fre/app/generate_time_averages/generate_time_averages.py index 2d0f0a874..a69020cc8 100755 --- a/fre/app/generate_time_averages/generate_time_averages.py +++ b/fre/app/generate_time_averages/generate_time_averages.py @@ -3,16 +3,19 @@ import os import logging import time +import warnings from typing import Optional, List, Union -from cdo import Cdo +import xarray as xr -from .cdoTimeAverager import cdoTimeAverager from .frenctoolsTimeAverager import frenctoolsTimeAverager -from .frepytoolsTimeAverager import frepytoolsTimeAverager +from .frepytoolsTimeAverager import NumpyTimeAverager +from .xarrayTimeAverager import xarrayTimeAverager fre_logger = logging.getLogger(__name__) +VALID_PKGS = ['cdo', 'fre-nctools', 'fre-python-tools', 'xarray'] + def generate_time_average(infile: Union[str, List[str]] = None, outfile: str = None, pkg: str = None, @@ -26,7 +29,9 @@ def generate_time_average(infile: Union[str, List[str]] = None, :type infile: str, list :param outfile: path to where output file should be stored :type outfile: str - :param pkg: which package to use to calculate climatology (cdo, fre-nctools, fre-python-tools) + :param pkg: which package to use to calculate climatology + ('xarray', 'fre-python-tools', 'fre-nctools', or 'cdo') + 'cdo' is kept for backward compat but silently uses xarray. :type pkg: str :param var: optional, not currently supported and defaults to None :type var: str @@ -41,12 +46,12 @@ def generate_time_average(infile: Union[str, List[str]] = None, fre_logger.debug('called generate_time_average') if None in [infile, outfile, pkg]: raise ValueError('infile, outfile, and pkg are required inputs') - if pkg not in ['cdo', 'fre-nctools', 'fre-python-tools']: - raise ValueError(f'argument pkg = {pkg} not known, must be one of: cdo, fre-nctools, fre-python-tools') + if pkg not in VALID_PKGS: + raise ValueError(f'argument pkg = {pkg} not known, must be one of: {", ".join(VALID_PKGS)}') exitstatus = 1 myavger = None - # multiple files case Use cdo to merge multiple files if present + # multiple files case - merge multiple files if present merged = False orig_infile_list = None if all ( [ type(infile).__name__ == 'list', @@ -54,13 +59,13 @@ def generate_time_average(infile: Union[str, List[str]] = None, fre_logger.info('list input argument detected') infile_str = [str(item) for item in infile] - _cdo = Cdo() merged_file = "merged_output.nc" - fre_logger.info('calling cdo mergetime') + fre_logger.info('merging input files with xarray') fre_logger.debug('output: %s', merged_file) fre_logger.debug('inputs: \n %s', ' '.join(infile_str) ) - _cdo.mergetime(input = ' '.join(infile_str), output = merged_file) + with xr.open_mfdataset(infile_str, combine='by_coords') as ds: + ds.to_netcdf(merged_file) # preserve the original file names for later orig_infile_list = infile @@ -69,11 +74,27 @@ def generate_time_average(infile: Union[str, List[str]] = None, fre_logger.info('file merging success') if pkg == 'cdo': - fre_logger.info('creating a cdoTimeAverager') - myavger = cdoTimeAverager( pkg = pkg, - var = var, - unwgt = unwgt, - avg_type = avg_type ) + # CDO has been removed — warn loudly, use xarray instead + msg = ( + "WARNING *** CDO/python-cdo has been REMOVED from fre-cli. " + "pkg='cdo' now uses the xarray time-averager under the hood. " + "Please switch to pkg='xarray' or pkg='fre-python-tools'. ***" + ) + warnings.warn(msg, FutureWarning, stacklevel=2) + fre_logger.warning(msg) + fre_logger.info('creating an xarrayTimeAverager (via pkg=cdo redirect)') + myavger = xarrayTimeAverager( pkg = pkg, + var = var, + unwgt = unwgt, + avg_type = avg_type ) + + elif pkg == 'xarray': + fre_logger.info('creating an xarrayTimeAverager') + myavger = xarrayTimeAverager( pkg = pkg, + var = var, + unwgt = unwgt, + avg_type = avg_type ) + elif pkg == 'fre-nctools': fre_logger.info('creating a frenctoolsTimeAverager') myavger = frenctoolsTimeAverager( pkg = pkg, @@ -87,11 +108,11 @@ def generate_time_average(infile: Union[str, List[str]] = None, var = orig_infile_list[0].split('/').pop().split('.')[-2] fre_logger.warning('extracted var = %s from orig_infile_list[0] = %s', var, orig_infile_list[0] ) - fre_logger.info('creating a frepytoolsTimeAverager') - myavger = frepytoolsTimeAverager( pkg = pkg, - var = var, - unwgt = unwgt, - avg_type = avg_type ) + fre_logger.info('creating a NumpyTimeAverager') + myavger = NumpyTimeAverager( pkg = pkg, + var = var, + unwgt = unwgt, + avg_type = avg_type ) # workload if myavger is not None: diff --git a/fre/app/generate_time_averages/tests/test_cdoTimeAverager.py b/fre/app/generate_time_averages/tests/test_cdoTimeAverager.py index 9ebbc4937..2ed4c46db 100644 --- a/fre/app/generate_time_averages/tests/test_cdoTimeAverager.py +++ b/fre/app/generate_time_averages/tests/test_cdoTimeAverager.py @@ -10,3 +10,13 @@ def test_cdotimavg_init_error(): test_avgr = cdo_timavg.cdoTimeAverager(pkg = 'cdo', var = None, unwgt = False, avg_type = 'FOO') test_avgr.generate_timavg(infile = None, outfile = None) + +def test_cdotimavg_warns_future(): + ''' test that FutureWarning is emitted when generate_timavg is called ''' + with pytest.warns(FutureWarning, match='CDO/python-cdo has been REMOVED'): + test_avgr = cdo_timavg.cdoTimeAverager(pkg = 'cdo', var = None, unwgt = False, avg_type = 'all') + # this will fail because infile is None, but the warning should fire first + try: + test_avgr.generate_timavg(infile = 'nonexistent.nc', outfile = 'out.nc') + except (FileNotFoundError, ValueError, OSError): + pass diff --git a/fre/app/generate_time_averages/tests/test_generate_time_averages.py b/fre/app/generate_time_averages/tests/test_generate_time_averages.py index 2f7209b94..5ffde0598 100644 --- a/fre/app/generate_time_averages/tests/test_generate_time_averages.py +++ b/fre/app/generate_time_averages/tests/test_generate_time_averages.py @@ -94,21 +94,22 @@ def test_time_avg_file_dir_exists(): FULL_TEST_FILE_PATH = TIME_AVG_FILE_DIR + TEST_FILE_NAME cases=[ - #cdo cases, monthly, one/multiple files, weighted + # cdo cases — CDO removed, but entry point still works (redirects to xarray) + # monthly, one/multiple files, weighted pytest.param( 'cdo', 'month', True , FULL_TEST_FILE_PATH, TIME_AVG_FILE_DIR + 'ymonmean_unwgt_' + TEST_FILE_NAME), pytest.param( 'cdo', 'month', True , TWO_TEST_FILE_NAMES, TIME_AVG_FILE_DIR + 'ymonmean_unwgt_' + TWO_OUT_FILE_NAME), - #cdo cases, seasonal, one/multiple files, unweighted + # seasonal, one/multiple files, unweighted pytest.param( 'cdo', 'seas', True , FULL_TEST_FILE_PATH, TIME_AVG_FILE_DIR + 'yseasmean_unwgt_' + TEST_FILE_NAME), pytest.param( 'cdo', 'seas', True , TWO_TEST_FILE_NAMES, TIME_AVG_FILE_DIR + 'yseasmean_unwgt_' + TWO_OUT_FILE_NAME), - #cdo cases, all, one/multiple files, weighted/unweighted + # all, one/multiple files, weighted/unweighted pytest.param( 'cdo', 'all', True , FULL_TEST_FILE_PATH, STR_UNWGT_CDO_INF), @@ -121,6 +122,43 @@ def test_time_avg_file_dir_exists(): pytest.param( 'cdo', 'all', False , TWO_TEST_FILE_NAMES, TIME_AVG_FILE_DIR + 'timmean_' + TWO_OUT_FILE_NAME), + # xarray cases — all avg_types, single and multi-file + pytest.param( 'xarray', 'all', True , + FULL_TEST_FILE_PATH, + TIME_AVG_FILE_DIR + 'xarray_unwgt_timavg_' + TEST_FILE_NAME), + pytest.param( 'xarray', 'all', False , + FULL_TEST_FILE_PATH, + TIME_AVG_FILE_DIR + 'xarray_wgt_timavg_' + TEST_FILE_NAME), + pytest.param( 'xarray', 'all', True , + TWO_TEST_FILE_NAMES, + TIME_AVG_FILE_DIR + 'xarray_unwgt_timavg_' + TWO_OUT_FILE_NAME), + pytest.param( 'xarray', 'all', False , + TWO_TEST_FILE_NAMES, + TIME_AVG_FILE_DIR + 'xarray_wgt_timavg_' + TWO_OUT_FILE_NAME), + pytest.param( 'xarray', 'seas', True , + FULL_TEST_FILE_PATH, + TIME_AVG_FILE_DIR + 'xarray_seas_unwgt_' + TEST_FILE_NAME), + pytest.param( 'xarray', 'seas', False , + FULL_TEST_FILE_PATH, + TIME_AVG_FILE_DIR + 'xarray_seas_wgt_' + TEST_FILE_NAME), + pytest.param( 'xarray', 'month', True , + FULL_TEST_FILE_PATH, + TIME_AVG_FILE_DIR + 'xarray_month_unwgt_' + TEST_FILE_NAME), + pytest.param( 'xarray', 'month', False , + FULL_TEST_FILE_PATH, + TIME_AVG_FILE_DIR + 'xarray_month_wgt_' + TEST_FILE_NAME), + pytest.param( 'xarray', 'seas', True , + TWO_TEST_FILE_NAMES, + TIME_AVG_FILE_DIR + 'xarray_seas_unwgt_' + TWO_OUT_FILE_NAME), + pytest.param( 'xarray', 'seas', False , + TWO_TEST_FILE_NAMES, + TIME_AVG_FILE_DIR + 'xarray_seas_wgt_' + TWO_OUT_FILE_NAME), + pytest.param( 'xarray', 'month', True , + TWO_TEST_FILE_NAMES, + TIME_AVG_FILE_DIR + 'xarray_month_unwgt_' + TWO_OUT_FILE_NAME), + pytest.param( 'xarray', 'month', False , + TWO_TEST_FILE_NAMES, + TIME_AVG_FILE_DIR + 'xarray_month_wgt_' + TWO_OUT_FILE_NAME), #fre-python-tools cases, all, one/multiple files, weighted/unweighted flag pytest.param( 'fre-python-tools', 'all', False , FULL_TEST_FILE_PATH, @@ -243,10 +281,10 @@ def test_compare_fre_cli_to_fre_nctools(): assert not( (non_zero_count > 0.) or (non_zero_count < 0.) ), "non-zero diffs between frepy / frenctools were found" -@pytest.mark.xfail(reason = 'test fails b.c. cdo cannot bitwise-reproduce fre-nctools answer') +@pytest.mark.xfail(reason = 'cdo entry-point now uses xarray — result format differs from old CDO output') def test_compare_fre_cli_to_cdo(): ''' - compares fre_cli pkg answer to cdo pkg answer + compares fre_cli pkg answer to cdo pkg answer (cdo now redirects to xarray) ''' assert Path(STR_FRE_PYTOOLS_INF).exists(), f'DNE: STR_FRE_PYTOOLS_INF = {STR_FRE_PYTOOLS_INF}' fre_pytools_inf = Dataset(STR_FRE_PYTOOLS_INF, 'r') @@ -273,6 +311,7 @@ def test_compare_fre_cli_to_cdo(): assert not( (non_zero_count > 0.) or (non_zero_count < 0.) ), "non-zero diffs between cdo / frepytools were found" +@pytest.mark.xfail(reason = 'cdo entry-point now uses xarray — result format differs from old CDO output') def test_compare_unwgt_fre_cli_to_unwgt_cdo(): ''' compares fre_cli pkg answer to cdo pkg answer @@ -301,10 +340,10 @@ def test_compare_unwgt_fre_cli_to_unwgt_cdo(): non_zero_count = np.count_nonzero(diff_pytools_cdo_timavg[:]) assert not( (non_zero_count > 0.) or (non_zero_count < 0.) ), "non-zero diffs between cdo / frepytools were found" -@pytest.mark.xfail(reason = 'test fails b.c. cdo cannot bitwise-reproduce fre-nctools answer') +@pytest.mark.xfail(reason = 'cdo entry-point now uses xarray — result format differs from old CDO output') def test_compare_cdo_to_fre_nctools(): ''' - compares cdo pkg answer to fre_nctools pkg answer + compares cdo pkg answer to fre_nctools pkg answer (cdo now redirects to xarray) ''' assert Path(STR_FRENCTOOLS_INF).exists(), f'DNE: STR_FRENCTOOLS_INF = {STR_FRENCTOOLS_INF}' diff --git a/fre/app/generate_time_averages/tests/test_wrapper.py b/fre/app/generate_time_averages/tests/test_wrapper.py index bec909464..e815998a1 100644 --- a/fre/app/generate_time_averages/tests/test_wrapper.py +++ b/fre/app/generate_time_averages/tests/test_wrapper.py @@ -170,7 +170,7 @@ def test_monthly_av_from_monthly_ts(create_monthly_timeseries): assert file_.exists() -# CDO-based tests +# CDO-based tests — CDO has been removed, entry point redirects to xarray def test_cdo_annual_av_from_monthly_ts(create_monthly_timeseries): """ Generate annual average from monthly timeseries using CDO @@ -206,7 +206,7 @@ def test_cdo_annual_av_from_monthly_ts(create_monthly_timeseries): def test_cdo_annual_av_from_annual_ts(create_annual_timeseries): """ - Generate annual average from annual timeseries using CDO + Generate annual average from annual timeseries using CDO (redirects to xarray) """ cycle_point = '0002-01-01' output_interval = 'P2Y' @@ -239,7 +239,7 @@ def test_cdo_annual_av_from_annual_ts(create_annual_timeseries): def test_cdo_monthly_av_from_monthly_ts(create_monthly_timeseries): """ - Generate monthly climatology from monthly timeseries using CDO + Generate monthly climatology from monthly timeseries using CDO (redirects to xarray) """ cycle_point = '1980-01-01' output_interval = 'P2Y' @@ -275,8 +275,7 @@ def test_cdo_monthly_av_from_monthly_ts(create_monthly_timeseries): @pytest.mark.xfail(reason="no timavg.csh") def test_cdo_fre_nctools_equivalence(create_monthly_timeseries): """ - Test that CDO produces equivalent results to fre-nctools when timavg.csh is available. - If timavg.csh is not available, the test will be skipped. + Test that CDO (now xarray) produces equivalent results to fre-nctools when timavg.csh is available. """ cycle_point = '1980-01-01' @@ -357,3 +356,80 @@ def test_freq_not_valid_valueerror(): output_interval = 'P999Y', frequency = 'FOO', grid = 'BAR') + + +# xarray-based wrapper tests +def test_xarray_annual_av_from_monthly_ts(create_monthly_timeseries): + """ + Generate annual average from monthly timeseries using xarray + """ + cycle_point = '1980-01-01' + output_interval = 'P2Y' + input_interval = 'P1Y' + grid = '180_288.conserve_order2' + sources = ['atmos_month'] + frequency = 'yr' + pkg = 'xarray' + + wrapper.generate_wrapper(cycle_point, str(create_monthly_timeseries), sources, + output_interval, input_interval, grid, frequency, pkg) + + output_dir = Path(create_monthly_timeseries, 'av', grid, 'atmos_month', 'P1Y', output_interval) + output_files = [ + output_dir / 'atmos_month.1980-1981.alb_sfc.nc', + output_dir / 'atmos_month.1980-1981.aliq.nc' + ] + + for file_ in output_files: + assert file_.exists() + + +def test_xarray_annual_av_from_annual_ts(create_annual_timeseries): + """ + Generate annual average from annual timeseries using xarray + """ + cycle_point = '0002-01-01' + output_interval = 'P2Y' + input_interval = 'P1Y' + grid = '180_288.conserve_order1' + sources = ['tracer_level'] + frequency = 'yr' + pkg = 'xarray' + + wrapper.generate_wrapper(cycle_point, str(create_annual_timeseries), sources, + output_interval, input_interval, grid, frequency, pkg) + + output_dir = Path(create_annual_timeseries, 'av', grid, 'tracer_level', 'P1Y', output_interval) + output_files = [ + output_dir / 'tracer_level.0002-0003.radon.nc', + output_dir / 'tracer_level.0002-0003.scale_salt_emis.nc' + ] + + for file_ in output_files: + assert file_.exists() + + +def test_xarray_monthly_av_from_monthly_ts(create_monthly_timeseries): + """ + Generate monthly climatology from monthly timeseries using xarray + """ + cycle_point = '1980-01-01' + output_interval = 'P2Y' + input_interval = 'P1Y' + grid = '180_288.conserve_order2' + sources = ['atmos_month'] + frequency = 'mon' + pkg = 'xarray' + + wrapper.generate_wrapper(cycle_point, str(create_monthly_timeseries), sources, + output_interval, input_interval, grid, frequency, pkg) + + output_dir = Path(create_monthly_timeseries, 'av', grid, 'atmos_month', 'P1M', output_interval) + output_files = [ + output_dir / 'atmos_month.1980-1981.alb_sfc', + output_dir / 'atmos_month.1980-1981.aliq', + ] + for f in output_files: + for i in range(1,13): + file_ = Path(str(f) + f".{i:02d}.nc") + assert file_.exists() diff --git a/fre/app/generate_time_averages/tests/test_xarrayTimeAverager.py b/fre/app/generate_time_averages/tests/test_xarrayTimeAverager.py new file mode 100644 index 000000000..d8e0efbc7 --- /dev/null +++ b/fre/app/generate_time_averages/tests/test_xarrayTimeAverager.py @@ -0,0 +1,582 @@ +''' +Comprehensive unit tests for xarrayTimeAverager and its helper functions. + +Tests cover: + - _is_numeric() — dtype classification helper + - _compute_time_weights() — time-weight extraction from time_bnds + - _weighted_time_mean() — correctness of weighted global mean + - _weighted_seasonal_mean() — correctness of weighted seasonal groupby + - _weighted_monthly_mean() — correctness of weighted monthly groupby + - xarrayTimeAverager.generate_timavg() — full round-trip via NetCDF files +''' + +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from fre.app.generate_time_averages.xarrayTimeAverager import ( + _compute_time_weights, + _is_numeric, + _weighted_monthly_mean, + _weighted_seasonal_mean, + _weighted_time_mean, + xarrayTimeAverager, +) + +# --------------------------------------------------------------------------- +# Dataset factory helpers +# --------------------------------------------------------------------------- + +# Days per month for a standard (non-leap) year, January → December +_DAYS_IN_MONTH = np.array([31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31], + dtype='float64') + + +def _make_monthly_ds(n_years=1, with_bnds=True, time_bnds_encoding='float'): + ''' + Build a minimal monthly xr.Dataset with one numeric variable ('temp'). + + The ``time`` coordinate uses actual ``datetime64`` values so that the + ``.dt`` accessor (needed for ``.dt.season``, ``.dt.month``) works + on the in-memory dataset without requiring a NetCDF round-trip. + + Parameters + ---------- + n_years : int + How many full calendar years to include (12*n_years timesteps). + with_bnds : bool + Whether to include a ``time_bnds`` variable. + time_bnds_encoding : str + How to encode ``time_bnds``: + - ``'float'`` → plain float64 day-counts (numeric path) + - ``'timedelta'`` → numpy datetime64 edges (timedelta64 diff path) + ''' + n_months = 12 * n_years + + # Month-start dates (Jan 1, Feb 1, …) anchored at 2001 (non-leap year) + month_starts = pd.date_range('2001-01-01', periods=n_months, freq='MS') + # Month midpoints (15th of each month) used as the time coordinate + times = month_starts + pd.Timedelta(days=15) + # Month lengths in days for the non-leap year 2001 + # Use pd.DateOffset to compute exact month lengths from the actual dates + month_ends = month_starts + pd.offsets.MonthEnd(1) + pd.Timedelta(days=1) + days = np.array([(e - s).days for s, e in zip(month_starts, month_ends)], + dtype='float64') + + # values: month-number (1 … 12) cycling over years, stored as float32 + values = np.tile(np.arange(1, 13, dtype='float32'), n_years).reshape(n_months, 1, 1) + + data_vars = { + 'temp': xr.DataArray(values, dims=['time', 'lat', 'lon'], + attrs={'units': 'K', 'long_name': 'temperature'}), + } + + if with_bnds: + if time_bnds_encoding == 'float': + # plain float64 day-counts (numeric path in _compute_time_weights) + t0 = np.concatenate([[0.0], np.cumsum(days[:-1])]) + t1 = np.cumsum(days) + bnds_vals = np.stack([t0, t1], axis=1) + data_vars['time_bnds'] = xr.DataArray(bnds_vals, dims=['time', 'bnds']) + elif time_bnds_encoding == 'timedelta': + # datetime64 edges → difference gives timedelta64 (timedelta path) + starts = month_starts.values.astype('datetime64[D]') + ends = month_ends.values.astype('datetime64[D]') + bnds_vals = np.stack([starts, ends], axis=1) + data_vars['time_bnds'] = xr.DataArray(bnds_vals, dims=['time', 'bnds']) + + return xr.Dataset(data_vars, coords={'time': times}), days + + +def _make_ds_with_nonnumeric_var(): + ''' + Dataset that contains a datetime64 variable alongside a numeric one. + Simulates the ``average_T1``/``average_T2`` pattern from GFDL atmos files. + ''' + n = 4 + month_starts = pd.date_range('2001-01-01', periods=n, freq='MS') + times = month_starts + pd.Timedelta(days=15) + dt_vals = month_starts.values + + return xr.Dataset( + { + 'temp': xr.DataArray([1.0, 2.0, 3.0, 4.0], dims=['time'], + attrs={'units': 'K'}), + 'start_time': xr.DataArray(dt_vals, dims=['time']), # datetime64 — non-numeric + 'time_bnds': xr.DataArray( + np.stack([np.arange(n, dtype='float64'), + np.arange(1, n+1, dtype='float64')], axis=1), + dims=['time', 'bnds']), + }, + coords={'time': times}, + ) + + +# --------------------------------------------------------------------------- +# Tests for _is_numeric() +# --------------------------------------------------------------------------- + +class TestIsNumeric: + '''Unit tests for the _is_numeric() helper.''' + + def test_float32(self): + da = xr.DataArray(np.array([1.0], dtype='float32')) + assert _is_numeric(da) + + def test_float64(self): + da = xr.DataArray(np.array([1.0], dtype='float64')) + assert _is_numeric(da) + + def test_int32(self): + da = xr.DataArray(np.array([1], dtype='int32')) + assert _is_numeric(da) + + def test_int64(self): + da = xr.DataArray(np.array([1], dtype='int64')) + assert _is_numeric(da) + + def test_uint8(self): + da = xr.DataArray(np.array([1], dtype='uint8')) + assert _is_numeric(da) + + def test_datetime64_is_not_numeric(self): + da = xr.DataArray(np.array(['2000-01-01'], dtype='datetime64[D]')) + assert not _is_numeric(da) + + def test_timedelta64_is_not_numeric(self): + da = xr.DataArray(np.array([1], dtype='timedelta64[D]')) + assert not _is_numeric(da) + + def test_object_is_not_numeric(self): + da = xr.DataArray(np.array(['hello'], dtype=object)) + assert not _is_numeric(da) + + +# --------------------------------------------------------------------------- +# Tests for _compute_time_weights() +# --------------------------------------------------------------------------- + +class TestComputeTimeWeights: + '''Unit tests for the _compute_time_weights() helper.''' + + def test_float_bnds_returns_correct_days(self): + ds, days = _make_monthly_ds(n_years=1, with_bnds=True, time_bnds_encoding='float') + weights = _compute_time_weights(ds) + assert weights.dtype == np.float64 + np.testing.assert_allclose(weights.values, days) + + def test_timedelta_bnds_returns_correct_days(self): + '''time_bnds stored as datetime64 edges → difference is timedelta64.''' + ds, _ = _make_monthly_ds(n_years=1, with_bnds=True, time_bnds_encoding='timedelta') + weights = _compute_time_weights(ds) + assert weights.dtype == np.float64 + # compute expected days from the actual bounds stored in the dataset + bnds = ds['time_bnds'].values + expected = (bnds[:, 1] - bnds[:, 0]).astype('timedelta64[D]').astype('float64') + np.testing.assert_allclose(weights.values, expected, atol=1e-9) + + def test_no_bnds_fallback_uniform_weights(self): + '''Without time_bnds, weights should all equal 1.0.''' + ds, _ = _make_monthly_ds(n_years=1, with_bnds=False) + weights = _compute_time_weights(ds) + assert weights.dtype == np.float64 + assert len(weights) == 12 + np.testing.assert_array_equal(weights.values, np.ones(12)) + + def test_weights_dim_is_time(self): + ds, _ = _make_monthly_ds(n_years=1, with_bnds=True, time_bnds_encoding='float') + weights = _compute_time_weights(ds) + assert 'time' in weights.dims + + def test_two_timestep_known_values(self): + '''Minimal two-timestep dataset with explicit bounds: Jan(31d) + Feb(28d).''' + ds = xr.Dataset({ + 'temp': xr.DataArray([1.0, 2.0], dims=['time']), + 'time_bnds': xr.DataArray([[0.0, 31.0], [31.0, 59.0]], dims=['time', 'bnds']), + }) + weights = _compute_time_weights(ds) + np.testing.assert_allclose(weights.values, [31.0, 28.0]) + + +# --------------------------------------------------------------------------- +# Tests for _weighted_time_mean() +# --------------------------------------------------------------------------- + +class TestWeightedTimeMean: + '''Unit tests for _weighted_time_mean() correctness.''' + + def test_two_timestep_known_weighted_mean(self): + ''' + Jan(31d)=1.0, Feb(28d)=2.0 → weighted mean = (1*31 + 2*28) / (31+28). + ''' + ds = xr.Dataset({ + 'temp': xr.DataArray([1.0, 2.0], dims=['time']), + 'time_bnds': xr.DataArray([[0.0, 31.0], [31.0, 59.0]], dims=['time', 'bnds']), + }) + result = _weighted_time_mean(ds) + expected = (1.0 * 31 + 2.0 * 28) / (31 + 28) + np.testing.assert_allclose(float(result['temp']), expected, rtol=1e-6) + + def test_uniform_weights_equal_arithmetic_mean(self): + '''When all timesteps have equal weight, weighted = unweighted mean.''' + vals = np.arange(1.0, 5.0) + ds = xr.Dataset({ + 'temp': xr.DataArray(vals, dims=['time']), + 'time_bnds': xr.DataArray([[float(i), float(i+1)] for i in range(4)], + dims=['time', 'bnds']), + }) + result = _weighted_time_mean(ds) + np.testing.assert_allclose(float(result['temp']), vals.mean(), rtol=1e-6) + + def test_time_dim_eliminated(self): + '''Output should have no time dimension.''' + ds, _ = _make_monthly_ds(n_years=1, with_bnds=True, time_bnds_encoding='float') + result = _weighted_time_mean(ds) + assert 'time' not in result['temp'].dims + + def test_non_time_vars_preserved(self): + '''Variables without time dimension are passed through unchanged.''' + ds, _ = _make_monthly_ds(n_years=1, with_bnds=True, time_bnds_encoding='float') + # lat dim in _make_monthly_ds has size 1; match that size here + ds = ds.assign({'static': xr.DataArray([42.0], dims=['lat'])}) + result = _weighted_time_mean(ds) + assert 'static' in result + np.testing.assert_array_equal(result['static'].values, [42.0]) + + def test_time_bnds_excluded_from_output(self): + '''time_bnds should not appear in the weighted mean output.''' + ds, _ = _make_monthly_ds(n_years=1, with_bnds=True, time_bnds_encoding='float') + result = _weighted_time_mean(ds) + assert 'time_bnds' not in result + + def test_nonnumeric_time_var_gets_first_value(self): + '''datetime64 time-dependent variables get the value from timestep 0.''' + ds = _make_ds_with_nonnumeric_var() + result = _weighted_time_mean(ds) + # 'start_time' is datetime64 → should be scalar == ds['start_time'].isel(time=0) + assert 'start_time' in result + assert result['start_time'].values == ds['start_time'].values[0] + + def test_attrs_preserved(self): + '''Dataset and variable attributes should be preserved.''' + ds, _ = _make_monthly_ds(n_years=1, with_bnds=True, time_bnds_encoding='float') + result = _weighted_time_mean(ds) + assert result['temp'].attrs.get('units') == 'K' + + +# --------------------------------------------------------------------------- +# Tests for _weighted_seasonal_mean() +# --------------------------------------------------------------------------- + +class TestWeightedSeasonalMean: + '''Unit tests for _weighted_seasonal_mean() correctness.''' + + def test_season_dim_present(self): + ds, _ = _make_monthly_ds(n_years=1, with_bnds=True, time_bnds_encoding='float') + result = _weighted_seasonal_mean(ds) + assert 'season' in result['temp'].dims + + def test_four_seasons_present(self): + '''A full year should produce all four seasons in the output.''' + ds, _ = _make_monthly_ds(n_years=1, with_bnds=True, time_bnds_encoding='float') + result = _weighted_seasonal_mean(ds) + seasons = set(result['season'].values) + assert seasons == {'DJF', 'MAM', 'JJA', 'SON'} + + def test_time_bnds_excluded(self): + ds, _ = _make_monthly_ds(n_years=1, with_bnds=True, time_bnds_encoding='float') + result = _weighted_seasonal_mean(ds) + assert 'time_bnds' not in result + + def test_mam_weighted_value(self): + ''' + MAM (Mar=31, Apr=30, May=31) with values (3,4,5): + weighted mean = (3*31 + 4*30 + 5*31) / (31+30+31) = 368/92 ≈ 4.0 + ''' + ds, _ = _make_monthly_ds(n_years=1, with_bnds=True, time_bnds_encoding='float') + result = _weighted_seasonal_mean(ds) + mam_val = float(result['temp'].sel(season='MAM').values.flat[0]) + np.testing.assert_allclose(mam_val, 368.0 / 92.0, rtol=1e-4) + + def test_jja_weighted_value(self): + ''' + JJA (Jun=30, Jul=31, Aug=31) with values (6,7,8): + weighted mean = (6*30 + 7*31 + 8*31) / (30+31+31) = 645/92 + ''' + ds, _ = _make_monthly_ds(n_years=1, with_bnds=True, time_bnds_encoding='float') + result = _weighted_seasonal_mean(ds) + jja_val = float(result['temp'].sel(season='JJA').values.flat[0]) + np.testing.assert_allclose(jja_val, 645.0 / 92.0, rtol=1e-4) + + def test_son_weighted_value(self): + ''' + SON (Sep=30, Oct=31, Nov=30) with values (9,10,11): + weighted mean = (9*30 + 10*31 + 11*30) / (30+31+30) = 910/91 = 10.0 + ''' + ds, _ = _make_monthly_ds(n_years=1, with_bnds=True, time_bnds_encoding='float') + result = _weighted_seasonal_mean(ds) + son_val = float(result['temp'].sel(season='SON').values.flat[0]) + np.testing.assert_allclose(son_val, 910.0 / 91.0, rtol=1e-4) + + def test_nonnumeric_vars_excluded_from_time_groupby(self): + '''Non-numeric time-dependent variables should not appear in the output.''' + ds = _make_ds_with_nonnumeric_var() + result = _weighted_seasonal_mean(ds) + # 'temp' (float) should be present; 'start_time' (datetime64) should be excluded + assert 'temp' in result + assert 'start_time' not in result + + +# --------------------------------------------------------------------------- +# Tests for _weighted_monthly_mean() +# --------------------------------------------------------------------------- + +class TestWeightedMonthlyMean: + '''Unit tests for _weighted_monthly_mean() correctness.''' + + def test_month_dim_present(self): + '''groupby(time.month) produces a "month" coordinate dimension.''' + ds, _ = _make_monthly_ds(n_years=1, with_bnds=True, time_bnds_encoding='float') + result = _weighted_monthly_mean(ds) + assert 'month' in result['temp'].dims + + def test_twelve_months_present(self): + ds, _ = _make_monthly_ds(n_years=1, with_bnds=True, time_bnds_encoding='float') + result = _weighted_monthly_mean(ds) + # groupby('time.month') produces a 'time' coordinate with 12 values + assert result['temp'].shape[0] == 12 + + def test_single_year_weighted_equals_unweighted(self): + ''' + With only one year of data, each month appears exactly once so + weighted and unweighted averages are identical. + ''' + ds, _ = _make_monthly_ds(n_years=1, with_bnds=True, time_bnds_encoding='float') + weighted = _weighted_monthly_mean(ds) + unweighted = ds.groupby('time.month').mean(dim='time', keep_attrs=True) + np.testing.assert_allclose( + weighted['temp'].values, + unweighted['temp'].values, + rtol=1e-5, + ) + + def test_two_year_weighted_jan_mean(self): + ''' + With two years, January appears twice with identical weights (31 days + both years) and values 1.0 (year 1) and 1.0 (year 2) → mean = 1.0. + ''' + ds, _ = _make_monthly_ds(n_years=2, with_bnds=True, time_bnds_encoding='float') + result = _weighted_monthly_mean(ds) + # groupby(time.month) produces a 'month' coordinate; month=1 is January + jan_val = float(result['temp'].sel(month=1).values.flat[0]) + np.testing.assert_allclose(jan_val, 1.0, rtol=1e-5) + + def test_time_bnds_excluded(self): + ds, _ = _make_monthly_ds(n_years=1, with_bnds=True, time_bnds_encoding='float') + result = _weighted_monthly_mean(ds) + assert 'time_bnds' not in result + + def test_nonnumeric_vars_excluded_from_time_groupby(self): + ds = _make_ds_with_nonnumeric_var() + result = _weighted_monthly_mean(ds) + assert 'temp' in result + assert 'start_time' not in result + + +# --------------------------------------------------------------------------- +# Integration tests for xarrayTimeAverager.generate_timavg() +# --------------------------------------------------------------------------- + +class TestXarrayTimeAveragerGenerateTimavg: + ''' + Integration tests that write a synthetic NetCDF to a temp dir, + run generate_timavg(), and verify the outputs. + ''' + + @pytest.fixture + def monthly_nc(self, tmp_path): + '''Write a 1-year monthly dataset to a temp NetCDF file.''' + ds, _ = _make_monthly_ds(n_years=1, with_bnds=True, time_bnds_encoding='float') + nc_path = tmp_path / 'monthly.nc' + ds.to_netcdf(nc_path) + return nc_path + + @pytest.fixture + def two_year_nc(self, tmp_path): + '''Write a 2-year monthly dataset to a temp NetCDF file.''' + ds, _ = _make_monthly_ds(n_years=2, with_bnds=True, time_bnds_encoding='float') + nc_path = tmp_path / 'monthly_2yr.nc' + ds.to_netcdf(nc_path) + return nc_path + + # ---- error path ---- + + def test_invalid_avg_type_raises_valueerror(self, monthly_nc, tmp_path): + avgr = xarrayTimeAverager(pkg='xarray', var=None, unwgt=True, avg_type='bogus') + with pytest.raises(ValueError, match='unknown avg_type'): + avgr.generate_timavg(infile=str(monthly_nc), + outfile=str(tmp_path / 'out.nc')) + + # ---- avg_type='all' ---- + + def test_all_unwgt_output_exists(self, monthly_nc, tmp_path): + outfile = tmp_path / 'out_all_unwgt.nc' + avgr = xarrayTimeAverager(pkg='xarray', var=None, unwgt=True, avg_type='all') + ret = avgr.generate_timavg(infile=str(monthly_nc), outfile=str(outfile)) + assert ret == 0 + assert outfile.exists() + + def test_all_unwgt_no_time_dim(self, monthly_nc, tmp_path): + outfile = tmp_path / 'out_all_unwgt.nc' + avgr = xarrayTimeAverager(pkg='xarray', var=None, unwgt=True, avg_type='all') + avgr.generate_timavg(infile=str(monthly_nc), outfile=str(outfile)) + result = xr.open_dataset(outfile) + assert 'time' not in result['temp'].dims + + def test_all_unwgt_returns_zero(self, monthly_nc, tmp_path): + outfile = tmp_path / 'out.nc' + avgr = xarrayTimeAverager(pkg='xarray', var=None, unwgt=True, avg_type='all') + assert avgr.generate_timavg(infile=str(monthly_nc), outfile=str(outfile)) == 0 + + def test_all_wgt_output_exists(self, monthly_nc, tmp_path): + outfile = tmp_path / 'out_all_wgt.nc' + avgr = xarrayTimeAverager(pkg='xarray', var=None, unwgt=False, avg_type='all') + ret = avgr.generate_timavg(infile=str(monthly_nc), outfile=str(outfile)) + assert ret == 0 + assert outfile.exists() + + def test_all_wgt_no_time_dim(self, monthly_nc, tmp_path): + outfile = tmp_path / 'out_all_wgt.nc' + avgr = xarrayTimeAverager(pkg='xarray', var=None, unwgt=False, avg_type='all') + avgr.generate_timavg(infile=str(monthly_nc), outfile=str(outfile)) + result = xr.open_dataset(outfile) + assert 'time' not in result['temp'].dims + + def test_all_wgt_differs_from_unwgt_for_unequal_months(self, monthly_nc, tmp_path): + '''Weighted and unweighted global means should differ for unequal month lengths.''' + out_wgt = tmp_path / 'wgt.nc' + out_unwgt = tmp_path / 'unwgt.nc' + xarrayTimeAverager(pkg='xarray', var=None, unwgt=False, avg_type='all').generate_timavg( + infile=str(monthly_nc), outfile=str(out_wgt)) + xarrayTimeAverager(pkg='xarray', var=None, unwgt=True, avg_type='all').generate_timavg( + infile=str(monthly_nc), outfile=str(out_unwgt)) + wgt_val = float(xr.open_dataset(out_wgt)['temp'].values.flat[0]) + unwgt_val = float(xr.open_dataset(out_unwgt)['temp'].values.flat[0]) + assert wgt_val != pytest.approx(unwgt_val, rel=1e-4) + + def test_all_unwgt_correct_arithmetic_mean(self, monthly_nc, tmp_path): + '''Unweighted mean of values 1..12 should equal 6.5.''' + outfile = tmp_path / 'out.nc' + avgr = xarrayTimeAverager(pkg='xarray', var=None, unwgt=True, avg_type='all') + avgr.generate_timavg(infile=str(monthly_nc), outfile=str(outfile)) + result_val = float(xr.open_dataset(outfile)['temp'].values.flat[0]) + np.testing.assert_allclose(result_val, 6.5, rtol=1e-5) + + # ---- avg_type='seas' ---- + + def test_seas_unwgt_output_exists(self, monthly_nc, tmp_path): + outfile = tmp_path / 'out_seas_unwgt.nc' + avgr = xarrayTimeAverager(pkg='xarray', var=None, unwgt=True, avg_type='seas') + ret = avgr.generate_timavg(infile=str(monthly_nc), outfile=str(outfile)) + assert ret == 0 + assert outfile.exists() + + def test_seas_unwgt_has_season_dim(self, monthly_nc, tmp_path): + outfile = tmp_path / 'out_seas_unwgt.nc' + avgr = xarrayTimeAverager(pkg='xarray', var=None, unwgt=True, avg_type='seas') + avgr.generate_timavg(infile=str(monthly_nc), outfile=str(outfile)) + result = xr.open_dataset(outfile) + assert 'season' in result['temp'].dims + + def test_seas_unwgt_four_seasons(self, monthly_nc, tmp_path): + outfile = tmp_path / 'out_seas_unwgt.nc' + avgr = xarrayTimeAverager(pkg='xarray', var=None, unwgt=True, avg_type='seas') + avgr.generate_timavg(infile=str(monthly_nc), outfile=str(outfile)) + result = xr.open_dataset(outfile) + seasons = set(result['season'].values.tolist()) + assert seasons == {'DJF', 'MAM', 'JJA', 'SON'} + + def test_seas_wgt_output_exists(self, monthly_nc, tmp_path): + outfile = tmp_path / 'out_seas_wgt.nc' + avgr = xarrayTimeAverager(pkg='xarray', var=None, unwgt=False, avg_type='seas') + ret = avgr.generate_timavg(infile=str(monthly_nc), outfile=str(outfile)) + assert ret == 0 + assert outfile.exists() + + def test_seas_wgt_has_season_dim(self, monthly_nc, tmp_path): + outfile = tmp_path / 'out_seas_wgt.nc' + avgr = xarrayTimeAverager(pkg='xarray', var=None, unwgt=False, avg_type='seas') + avgr.generate_timavg(infile=str(monthly_nc), outfile=str(outfile)) + result = xr.open_dataset(outfile) + assert 'season' in result['temp'].dims + + def test_seas_wgt_mam_value(self, monthly_nc, tmp_path): + ''' + MAM = (3*31 + 4*30 + 5*31) / (31+30+31) = 368/92 ≈ 4.0. + Read back from file and verify. + ''' + outfile = tmp_path / 'out_seas_wgt.nc' + avgr = xarrayTimeAverager(pkg='xarray', var=None, unwgt=False, avg_type='seas') + avgr.generate_timavg(infile=str(monthly_nc), outfile=str(outfile)) + result = xr.open_dataset(outfile) + mam_val = float(result['temp'].sel(season='MAM').values.flat[0]) + np.testing.assert_allclose(mam_val, 368.0 / 92.0, rtol=1e-4) + + # ---- avg_type='month' ---- + + def test_month_unwgt_output_exists(self, monthly_nc, tmp_path): + outfile = tmp_path / 'out_month_unwgt.nc' + avgr = xarrayTimeAverager(pkg='xarray', var=None, unwgt=True, avg_type='month') + ret = avgr.generate_timavg(infile=str(monthly_nc), outfile=str(outfile)) + assert ret == 0 + assert outfile.exists() + + def test_month_unwgt_per_month_files_created(self, monthly_nc, tmp_path): + '''generate_timavg should write 12 per-month files named *.01.nc … *.12.nc.''' + outfile = tmp_path / 'out_month_unwgt.nc' + avgr = xarrayTimeAverager(pkg='xarray', var=None, unwgt=True, avg_type='month') + avgr.generate_timavg(infile=str(monthly_nc), outfile=str(outfile)) + for m in range(1, 13): + month_file = tmp_path / f'out_month_unwgt.{m:02d}.nc' + assert month_file.exists(), f'missing per-month file: {month_file}' + + def test_month_wgt_output_exists(self, monthly_nc, tmp_path): + outfile = tmp_path / 'out_month_wgt.nc' + avgr = xarrayTimeAverager(pkg='xarray', var=None, unwgt=False, avg_type='month') + ret = avgr.generate_timavg(infile=str(monthly_nc), outfile=str(outfile)) + assert ret == 0 + assert outfile.exists() + + def test_month_wgt_per_month_files_created(self, monthly_nc, tmp_path): + '''generate_timavg with unwgt=False should still write 12 per-month files.''' + outfile = tmp_path / 'out_month_wgt.nc' + avgr = xarrayTimeAverager(pkg='xarray', var=None, unwgt=False, avg_type='month') + avgr.generate_timavg(infile=str(monthly_nc), outfile=str(outfile)) + for m in range(1, 13): + month_file = tmp_path / f'out_month_wgt.{m:02d}.nc' + assert month_file.exists(), f'missing per-month file: {month_file}' + + def test_month_wgt_jan_file_correct_value(self, two_year_nc, tmp_path): + ''' + With 2 years of identical January data (value=1.0, weight=31d both years), + the weighted January average should be 1.0. + ''' + outfile = tmp_path / 'out_month_wgt.nc' + avgr = xarrayTimeAverager(pkg='xarray', var=None, unwgt=False, avg_type='month') + avgr.generate_timavg(infile=str(two_year_nc), outfile=str(outfile)) + jan_file = tmp_path / 'out_month_wgt.01.nc' + result = xr.open_dataset(jan_file) + jan_val = float(result['temp'].values.flat[0]) + np.testing.assert_allclose(jan_val, 1.0, rtol=1e-5) + + def test_infile_not_modified(self, monthly_nc, tmp_path): + '''The input file must not be deleted or overwritten by the averager.''' + size_before = monthly_nc.stat().st_size + outfile = tmp_path / 'out.nc' + avgr = xarrayTimeAverager(pkg='xarray', var=None, unwgt=True, avg_type='all') + avgr.generate_timavg(infile=str(monthly_nc), outfile=str(outfile)) + assert monthly_nc.exists() + assert monthly_nc.stat().st_size == size_before diff --git a/fre/app/generate_time_averages/xarrayTimeAverager.py b/fre/app/generate_time_averages/xarrayTimeAverager.py new file mode 100644 index 000000000..4ba6a9ecf --- /dev/null +++ b/fre/app/generate_time_averages/xarrayTimeAverager.py @@ -0,0 +1,213 @@ +''' class using xarray for time-averages and climatology generation ''' + +import logging + +import numpy as np +import xarray as xr + +from .timeAverager import timeAverager + +fre_logger = logging.getLogger(__name__) + +# dtypes eligible for weighted averaging +_NUMERIC_KINDS = frozenset('fiuc') # float, integer, unsigned int, complex + + +def _is_numeric(data_array): + """return True if DataArray has a numeric dtype safe for arithmetic.""" + return data_array.dtype.kind in _NUMERIC_KINDS + + +class xarrayTimeAverager(timeAverager): + ''' + class inheriting from abstract base class timeAverager + generates time-averages using xarray. + supports avg_type 'all', 'seas', and 'month'. + ''' + + def generate_timavg(self, infile = None, outfile = None): + """ + use xarray to compute time-averages. + + :param self: instance of xarrayTimeAverager + :param infile: path to input NetCDF file, default is None + :type infile: str + :param outfile: path to output file, default is None + :type outfile: str + :return: 0 on success + :rtype: int + :raises ValueError: if avg_type is not recognized + """ + + if self.avg_type not in ['all', 'seas', 'month']: + fre_logger.error('requested unknown avg_type %s.', self.avg_type) + raise ValueError(f'requested unknown avg_type {self.avg_type}') + + fre_logger.info('xarrayTimeAverager: avg_type=%s, unwgt=%s', self.avg_type, self.unwgt) + + with xr.open_dataset(infile) as ds: + if self.avg_type == 'all': + fre_logger.info('time average over all time requested.') + if self.unwgt: + ds_avg = ds.mean(dim='time', keep_attrs=True) + else: + ds_avg = _weighted_time_mean(ds) + ds_avg.to_netcdf(outfile) + fre_logger.info('done averaging over all time.') + + elif self.avg_type == 'seas': + fre_logger.info('seasonal time-averages requested.') + if self.unwgt: + ds_avg = ds.groupby('time.season').mean(dim='time', keep_attrs=True) + else: + ds_avg = _weighted_seasonal_mean(ds) + ds_avg.to_netcdf(outfile) + fre_logger.info('done averaging over seasons.') + + elif self.avg_type == 'month': + fre_logger.info('monthly time-averages requested.') + if self.unwgt: + ds_avg = ds.groupby('time.month').mean(dim='time', keep_attrs=True) + else: + ds_avg = _weighted_monthly_mean(ds) + + # write full monthly file, then split into per-month files + outfile_str = str(outfile) + ds_avg.to_netcdf(outfile_str) + fre_logger.info('done averaging over months.') + + fre_logger.info('splitting by month') + outfile_root = outfile_str.removesuffix('.nc') + for month_val in ds_avg['month'].values: + month_ds = ds_avg.sel(month=month_val) + month_file = f'{outfile_root}.{int(month_val):02d}.nc' + month_ds.to_netcdf(month_file) + fre_logger.debug('wrote month file: %s', month_file) + + fre_logger.info('done averaging') + fre_logger.info('output file created: %s', outfile) + return 0 + + +def _weighted_time_mean(ds): + """ + compute weighted time-mean using time_bnds for weights. + non-numeric variables (e.g. datetime64 metadata like average_T1/T2) + retain their first value rather than being averaged. + + :param ds: xarray Dataset with 'time_bnds' variable + :type ds: xr.Dataset + :return: time-mean Dataset + :rtype: xr.Dataset + """ + weights = _compute_time_weights(ds) + weighted_vars = {} + for var_name in ds.data_vars: + if var_name == 'time_bnds': + continue + if 'time' in ds[var_name].dims: + if _is_numeric(ds[var_name]): + weighted_vars[var_name] = ( + (ds[var_name] * weights).sum(dim='time', keep_attrs=True) + / weights.sum() + ) + else: + # non-numeric time-dependent var (e.g. decoded datetime64) + weighted_vars[var_name] = ds[var_name].isel(time=0) + else: + weighted_vars[var_name] = ds[var_name] + return xr.Dataset(weighted_vars, attrs=ds.attrs) + + +def _weighted_seasonal_mean(ds): + """ + compute weighted seasonal mean using time_bnds for weights. + non-numeric time-dependent variables are dropped from the output. + + :param ds: xarray Dataset with 'time_bnds' variable + :type ds: xr.Dataset + :return: seasonal-mean Dataset grouped by season + :rtype: xr.Dataset + """ + weights = _compute_time_weights(ds) + season = ds['time'].dt.season + weighted_vars = {} + for var_name in ds.data_vars: + if var_name == 'time_bnds': + continue + if 'time' in ds[var_name].dims: + if _is_numeric(ds[var_name]): + weighted = ds[var_name] * weights + weighted_vars[var_name] = ( + weighted.groupby(season).sum(dim='time', keep_attrs=True) + / weights.groupby(season).sum(dim='time') + ) + else: + weighted_vars[var_name] = ds[var_name] + return xr.Dataset(weighted_vars, attrs=ds.attrs) + + +def _weighted_monthly_mean(ds): + """ + compute weighted monthly mean using time_bnds for weights. + non-numeric time-dependent variables are dropped from the output. + + :param ds: xarray Dataset with 'time_bnds' variable + :type ds: xr.Dataset + :return: monthly-mean Dataset grouped by month + :rtype: xr.Dataset + """ + weights = _compute_time_weights(ds) + month = ds['time'].dt.month + weighted_vars = {} + for var_name in ds.data_vars: + if var_name == 'time_bnds': + continue + if 'time' in ds[var_name].dims: + if _is_numeric(ds[var_name]): + weighted = ds[var_name] * weights + weighted_vars[var_name] = ( + weighted.groupby(month).sum(dim='time', keep_attrs=True) + / weights.groupby(month).sum(dim='time') + ) + else: + weighted_vars[var_name] = ds[var_name] + return xr.Dataset(weighted_vars, attrs=ds.attrs) + + +def _compute_time_weights(ds): + """ + compute per-timestep weights from time_bnds as float days. + + handles the three cases xarray produces when reading time_bnds: + * timedelta64 (difference of decoded datetime64 bounds) + * cftime timedelta objects (when use_cftime=True) + * numeric float / int (bounds stored as plain numbers) + + :param ds: xarray Dataset with 'time_bnds' variable + :type ds: xr.Dataset + :return: DataArray of float weights along the time dimension + :rtype: xr.DataArray + """ + if 'time_bnds' in ds: + time_bnds = ds['time_bnds'] + raw_diff = (time_bnds[:, 1] - time_bnds[:, 0]).values # numpy array + + if raw_diff.dtype.kind == 'm': + # timedelta64 — convert to float days via seconds + float_days = raw_diff.astype('timedelta64[s]').astype('float64') / 86400.0 + elif raw_diff.dtype == object: + # cftime timedelta objects: .days + .seconds/86400 + float_days = np.array( + [td.days + td.seconds / 86400.0 for td in raw_diff], + dtype='float64' + ) + else: + # already numeric (float or int days) + float_days = raw_diff.astype('float64') + + weights = xr.DataArray(float_days, dims=['time']) + else: + fre_logger.warning('time_bnds not found, falling back to uniform weights') + weights = xr.ones_like(ds['time'], dtype='float64') + return weights diff --git a/meta.yaml b/meta.yaml index 56b556b42..d542781b9 100644 --- a/meta.yaml +++ b/meta.yaml @@ -31,7 +31,6 @@ requirements: - noaa-gfdl::analysis_scripts==0.0.1 - noaa-gfdl::catalogbuilder==2025.01.01 # - noaa-gfdl::fre-nctools==2022.02.01 - - conda-forge::cdo>=2 - conda-forge::cftime - conda-forge::click>=8.2 - conda-forge::cmor>=3.14 @@ -43,7 +42,6 @@ requirements: - conda-forge::nccmp # - conda-forge::numpy==1.26.4 - conda-forge::numpy>=2 - - conda-forge::python-cdo - conda-forge::pyyaml - conda-forge::xarray>=2024.* - conda-forge::netcdf4>=1.7.* diff --git a/pylintrc b/pylintrc index ac870b9e3..f262c5469 100644 --- a/pylintrc +++ b/pylintrc @@ -163,7 +163,9 @@ class-const-naming-style=UPPER_CASE #class-const-rgx= # Naming style matching correct class names. -class-naming-style=PascalCase +# Allow PascalCase and mixedCase (e.g. xarrayTimeAverager) for class names +#class-naming-style=PascalCase +class-rgx=[a-zA-Z_][a-zA-Z0-9_]* # Regular expression matching correct class names. Overrides class-naming- # style. If left empty, class names will be checked with the set naming style. diff --git a/pyproject.toml b/pyproject.toml index 2664168dd..93190fd09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,6 @@ keywords = [ dependencies = [ 'analysis_scripts', 'catalogbuilder', - 'cdo', 'cftime', 'click', 'cmor',