diff --git a/src/skyborn/calc/__init__.py b/src/skyborn/calc/__init__.py index 2ba3804..4cc6f6f 100644 --- a/src/skyborn/calc/__init__.py +++ b/src/skyborn/calc/__init__.py @@ -49,5 +49,8 @@ trend_analysis, ) +# Import Standardized Precipitation Index calculation +from .spi import spi, spi_xarray, standardized_precipitation_index + # Import tropopause calculation (requires compiled extensions) from .troposphere import trop_wmo, trop_wmo_profile diff --git a/src/skyborn/calc/spi/__init__.py b/src/skyborn/calc/spi/__init__.py new file mode 100644 index 0000000..19d99c2 --- /dev/null +++ b/src/skyborn/calc/spi/__init__.py @@ -0,0 +1,11 @@ +""" +Standardized Precipitation Index (SPI) calculation module. + +This module provides efficient, vectorized calculation of the Standardized +Precipitation Index for multi-dimensional climate datasets. +""" + +from .core import standardized_precipitation_index, spi +from .xarray import spi_xarray + +__all__ = ["standardized_precipitation_index", "spi", "spi_xarray"] \ No newline at end of file diff --git a/src/skyborn/calc/spi/core.py b/src/skyborn/calc/spi/core.py new file mode 100644 index 0000000..4160e30 --- /dev/null +++ b/src/skyborn/calc/spi/core.py @@ -0,0 +1,296 @@ +""" +Core SPI calculation functions. + +This module provides the mathematical implementation of the Standardized +Precipitation Index using vectorized operations for efficient computation +on multi-dimensional datasets. +""" + +import numpy as np +from scipy import stats +from typing import Union, Optional, Tuple +import warnings + + +def _gamma_fit_vectorized(data: np.ndarray, axis: int = 0) -> Tuple[np.ndarray, np.ndarray]: + """ + Fit gamma distribution parameters to precipitation data along specified axis. + + Parameters + ---------- + data : np.ndarray + Precipitation data with shape (..., time, ...) 
+ axis : int, default 0 + Axis along which to fit the distribution (typically time axis) + + Returns + ------- + alpha : np.ndarray + Shape parameter of gamma distribution + beta : np.ndarray + Scale parameter of gamma distribution + """ + # Move time axis to the end for easier processing + data_moved = np.moveaxis(data, axis, -1) + original_shape = data_moved.shape + + # Reshape to (spatial_points, time) + data_reshaped = data_moved.reshape(-1, original_shape[-1]) + + # Initialize output arrays + n_points = data_reshaped.shape[0] + alpha = np.full(n_points, np.nan) + beta = np.full(n_points, np.nan) + + # Fit gamma distribution for each spatial point + for i in range(n_points): + series = data_reshaped[i, :] + + # Remove zeros and NaNs for gamma fitting + valid_data = series[~np.isnan(series) & (series > 0)] + + if len(valid_data) < 10: # Need sufficient data for fitting + continue + + try: + # Fit gamma distribution using method of moments as initial guess + mean_val = np.mean(valid_data) + var_val = np.var(valid_data) + + if var_val > 0 and mean_val > 0: + # Method of moments estimates + alpha_est = mean_val**2 / var_val + beta_est = var_val / mean_val + + # Use scipy's gamma fit with good initial guess + alpha_fit, _, beta_fit = stats.gamma.fit(valid_data, fa=alpha_est, scale=beta_est) + + alpha[i] = alpha_fit + beta[i] = beta_fit + + except (RuntimeError, ValueError): + # If fitting fails, use method of moments + try: + mean_val = np.mean(valid_data) + var_val = np.var(valid_data) + if var_val > 0 and mean_val > 0: + alpha[i] = mean_val**2 / var_val + beta[i] = var_val / mean_val + except: + continue + + # Reshape back to original spatial dimensions + spatial_shape = original_shape[:-1] + alpha = alpha.reshape(spatial_shape) + beta = beta.reshape(spatial_shape) + + return alpha, beta + + +def _calculate_spi_values(precip: np.ndarray, alpha: np.ndarray, beta: np.ndarray, + axis: int = 0) -> np.ndarray: + """ + Calculate SPI values using fitted gamma parameters. 
+ + Parameters + ---------- + precip : np.ndarray + Precipitation data + alpha : np.ndarray + Shape parameter of gamma distribution + beta : np.ndarray + Scale parameter of gamma distribution + axis : int, default 0 + Time axis + + Returns + ------- + spi : np.ndarray + Standardized Precipitation Index values + """ + # Move time axis to the end + precip_moved = np.moveaxis(precip, axis, -1) + original_shape = precip_moved.shape + + # Reshape for vectorized operations + precip_reshaped = precip_moved.reshape(-1, original_shape[-1]) + alpha_flat = alpha.flatten() + beta_flat = beta.flatten() + + # Initialize output + spi_reshaped = np.full_like(precip_reshaped, np.nan) + + for i in range(precip_reshaped.shape[0]): + if np.isnan(alpha_flat[i]) or np.isnan(beta_flat[i]): + continue + + series = precip_reshaped[i, :] + + # Calculate cumulative probability using gamma distribution + # Handle zeros separately + prob = np.full_like(series, np.nan) + + # For zero precipitation values + zero_mask = (series == 0) & ~np.isnan(series) + nonzero_mask = (series > 0) & ~np.isnan(series) + + if np.any(nonzero_mask): + # Probability for non-zero values + prob[nonzero_mask] = stats.gamma.cdf(series[nonzero_mask], + a=alpha_flat[i], scale=beta_flat[i]) + + if np.any(zero_mask): + # For zeros, use the probability of zero precipitation + # This is often estimated as the fraction of zero values in the dataset + zero_count = np.sum(series == 0) + total_count = np.sum(~np.isnan(series)) + if total_count > 0: + prob_zero = zero_count / total_count + prob[zero_mask] = prob_zero + + # Convert to standard normal distribution + # Ensure probabilities are in valid range (0, 1) + prob = np.clip(prob, 1e-6, 1-1e-6) + spi_reshaped[i, :] = stats.norm.ppf(prob) + + # Reshape back to original shape + spi = spi_reshaped.reshape(original_shape) + + # Move time axis back to original position + spi = np.moveaxis(spi, -1, axis) + + return spi + + +def _rolling_sum(data: np.ndarray, window: int, axis: int = 
0) -> np.ndarray: + """ + Calculate rolling sum along specified axis. + + Parameters + ---------- + data : np.ndarray + Input data + window : int + Rolling window size + axis : int, default 0 + Axis along which to calculate rolling sum + + Returns + ------- + np.ndarray + Rolling sum values (same shape as input) + """ + if window == 1: + return data.copy() + + # Move target axis to the end + data_moved = np.moveaxis(data, axis, -1) + original_shape = data_moved.shape + + # Reshape to (spatial_points, time) + reshaped = data_moved.reshape(-1, original_shape[-1]) + + # Initialize result with NaN + result = np.full(reshaped.shape, np.nan, dtype=np.float64) + + for i in range(reshaped.shape[0]): + series = reshaped[i, :].astype(np.float64) # Ensure float type + + # Calculate rolling sum for each position + for j in range(len(series)): + start_idx = max(0, j - window + 1) + end_idx = j + 1 + window_data = series[start_idx:end_idx] + + # Only calculate sum if we have complete window and all values are valid + if len(window_data) == window and np.all(np.isfinite(window_data)): + result[i, j] = np.sum(window_data) + + # Reshape back to original shape and move axis back + result = result.reshape(original_shape) + result = np.moveaxis(result, -1, axis) + + return result + + +def standardized_precipitation_index( + precipitation: np.ndarray, + time_scale: int = 3, + axis: int = 0, + distribution: str = 'gamma' +) -> np.ndarray: + """ + Calculate the Standardized Precipitation Index (SPI). + + The SPI is a widely used index to characterize meteorological drought + on a range of time scales. This implementation provides efficient + calculation for multi-dimensional datasets. + + Parameters + ---------- + precipitation : np.ndarray + Precipitation data. Can be multi-dimensional with time along any axis. + time_scale : int, default 3 + Time scale for SPI calculation in months (1, 3, 6, 12, etc.) 
+ axis : int, default 0 + Axis along which time varies + distribution : str, default 'gamma' + Distribution to fit to precipitation data. Currently only 'gamma' is supported. + + Returns + ------- + np.ndarray + Standardized Precipitation Index values with same shape as input + + Notes + ----- + The SPI calculation involves: + 1. Aggregating precipitation over the specified time scale + 2. Fitting a probability distribution (gamma) to the aggregated data + 3. Transforming to standard normal distribution + + SPI values interpretation: + - SPI ≥ 2.0: Extremely wet + - 1.5 ≤ SPI < 2.0: Very wet + - 1.0 ≤ SPI < 1.5: Moderately wet + - -1.0 < SPI < 1.0: Near normal + - -1.5 < SPI ≤ -1.0: Moderately dry + - -2.0 < SPI ≤ -1.5: Severely dry + - SPI ≤ -2.0: Extremely dry + + Examples + -------- + >>> import numpy as np + >>> from skyborn.calc.spi import standardized_precipitation_index + + # Generate sample precipitation data (time, lat, lon) + >>> precip = np.random.gamma(2, 2, size=(120, 10, 15)) # 10 years monthly data + >>> spi_3m = standardized_precipitation_index(precip, time_scale=3, axis=0) + >>> print(spi_3m.shape) + (120, 10, 15) + """ + if distribution != 'gamma': + raise ValueError("Currently only 'gamma' distribution is supported") + + if time_scale < 1: + raise ValueError("Time scale must be >= 1") + + # Step 1: Calculate rolling sum for the specified time scale + if time_scale > 1: + precip_aggregated = _rolling_sum(precipitation, time_scale, axis=axis) + else: + precip_aggregated = precipitation.copy() + + # Step 2: Fit gamma distribution parameters + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=RuntimeWarning) + alpha, beta = _gamma_fit_vectorized(precip_aggregated, axis=axis) + + # Step 3: Calculate SPI values + spi_values = _calculate_spi_values(precip_aggregated, alpha, beta, axis=axis) + + return spi_values + + +# Convenience alias +spi = standardized_precipitation_index \ No newline at end of file diff --git 
"""
Xarray interface for Standardized Precipitation Index calculation.

This module provides xarray-compatible functions for calculating SPI
with automatic dimension handling and metadata preservation.
"""

import warnings
from typing import Any, Dict, Optional, Sequence, Union

import numpy as np
import xarray as xr

from .core import standardized_precipitation_index


def spi_xarray(
    precipitation: xr.DataArray,
    time_scale: int = 3,
    time_dim: Optional[str] = None,
    distribution: str = 'gamma',
    **kwargs
) -> xr.DataArray:
    """
    Calculate Standardized Precipitation Index for xarray DataArrays.

    This function provides a convenient interface for calculating SPI on
    xarray DataArrays with automatic dimension detection and metadata
    preservation.

    Parameters
    ----------
    precipitation : xr.DataArray
        Precipitation data as xarray DataArray. Should have a time dimension.
    time_scale : int, default 3
        Time scale for SPI calculation in time units (1, 3, 6, 12, etc.)
    time_dim : str, optional
        Name of the time dimension. If None, will attempt to detect automatically.
    distribution : str, default 'gamma'
        Distribution to fit to precipitation data. Currently only 'gamma' is supported.
    **kwargs
        Additional keyword arguments passed to core SPI function

    Returns
    -------
    xr.DataArray
        SPI values with same coordinates and dimensions as input,
        with updated attributes describing the calculation

    Raises
    ------
    TypeError
        If ``precipitation`` is not an xarray DataArray
    ValueError
        If time dimension cannot be found or identified

    Examples
    --------
    >>> import numpy as np
    >>> import pandas as pd
    >>> import xarray as xr
    >>> from skyborn.calc.spi import spi_xarray

    # Create sample precipitation data
    >>> time = pd.date_range('2000-01-01', periods=120, freq='M')
    >>> lat = np.linspace(-30, 30, 10)
    >>> lon = np.linspace(0, 350, 15)
    >>> precip_data = np.random.gamma(2, 2, size=(120, 10, 15))
    >>> precip = xr.DataArray(
    ...     precip_data,
    ...     coords={'time': time, 'lat': lat, 'lon': lon},
    ...     dims=['time', 'lat', 'lon'],
    ...     attrs={'units': 'mm', 'long_name': 'precipitation'}
    ... )

    # Calculate 3-month SPI
    >>> spi_3m = spi_xarray(precip, time_scale=3)
    >>> print(spi_3m.attrs['long_name'])
    '3-month Standardized Precipitation Index'
    """
    # Validate input type up front to give a clear error.
    if not isinstance(precipitation, xr.DataArray):
        raise TypeError("precipitation must be an xarray DataArray")

    # Detect the time dimension when not supplied explicitly.
    if time_dim is None:
        time_dim = _detect_time_dimension(precipitation)

    if time_dim not in precipitation.dims:
        raise ValueError(f"Time dimension '{time_dim}' not found in DataArray dimensions: {precipitation.dims}")

    # Translate the dimension name into a positional axis for the core routine.
    time_axis = precipitation.get_axis_num(time_dim)

    # Run the numpy core on the raw values.
    precip_data = precipitation.values
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=RuntimeWarning)
        spi_data = standardized_precipitation_index(
            precip_data,
            time_scale=time_scale,
            axis=time_axis,
            distribution=distribution,
            **kwargs
        )

    # Re-wrap with the original coordinates/dimensions.
    spi_result = xr.DataArray(
        spi_data,
        coords=precipitation.coords,
        dims=precipitation.dims,
        name=f'spi_{time_scale}m'
    )

    # Attach descriptive metadata about the calculation.
    attrs = _create_spi_attributes(precipitation.attrs, time_scale, distribution)
    spi_result.attrs.update(attrs)

    return spi_result


def spi_dataset(
    dataset: xr.Dataset,
    precip_var: str,
    time_scales: Union[int, Sequence[int]] = (1, 3, 6, 12),
    time_dim: Optional[str] = None,
    distribution: str = 'gamma',
    **kwargs
) -> xr.Dataset:
    """
    Calculate SPI for multiple time scales and return as Dataset.

    Parameters
    ----------
    dataset : xr.Dataset
        Dataset containing precipitation data
    precip_var : str
        Name of precipitation variable in dataset
    time_scales : int or sequence of int, default (1, 3, 6, 12)
        Time scale(s) for SPI calculation
    time_dim : str, optional
        Name of time dimension. If None, will attempt to detect automatically.
    distribution : str, default 'gamma'
        Distribution to fit to precipitation data
    **kwargs
        Additional keyword arguments passed to SPI calculation

    Returns
    -------
    xr.Dataset
        Dataset containing SPI variables for each time scale

    Raises
    ------
    ValueError
        If ``precip_var`` is not a variable of ``dataset``

    Examples
    --------
    >>> # Calculate multiple SPI time scales
    >>> spi_ds = spi_dataset(dataset, 'precipitation', time_scales=[3, 6, 12])
    >>> print(list(spi_ds.data_vars))
    ['spi_3m', 'spi_6m', 'spi_12m']
    """
    if precip_var not in dataset:
        raise ValueError(f"Precipitation variable '{precip_var}' not found in dataset")

    precipitation = dataset[precip_var]

    # BUGFIX: the default used to be a mutable list; normalize any int or
    # sequence into a fresh list here instead.
    if isinstance(time_scales, int):
        time_scales = [time_scales]
    else:
        time_scales = list(time_scales)

    # Calculate SPI once per requested time scale.
    spi_vars = {}
    for ts in time_scales:
        spi_result = spi_xarray(
            precipitation,
            time_scale=ts,
            time_dim=time_dim,
            distribution=distribution,
            **kwargs
        )
        spi_vars[f'spi_{ts}m'] = spi_result

    result_ds = xr.Dataset(spi_vars)

    # Carry over any coordinates not already present on the SPI variables.
    for coord_name, coord_data in dataset.coords.items():
        if coord_name not in result_ds.coords:
            result_ds.coords[coord_name] = coord_data

    # Record provenance in the global attributes.
    result_ds.attrs.update(dataset.attrs)
    result_ds.attrs['spi_calculation'] = f'SPI calculated for time scales: {time_scales}'
    result_ds.attrs['spi_distribution'] = distribution

    return result_ds


def _detect_time_dimension(da: xr.DataArray) -> str:
    """
    Attempt to automatically detect the time dimension in a DataArray.

    Detection order: common dimension names, datetime64 coordinate dtypes,
    then time-related ``long_name``/``standard_name`` attributes; falls back
    to the first dimension (with a warning).

    Parameters
    ----------
    da : xr.DataArray
        Input DataArray

    Returns
    -------
    str
        Name of detected time dimension

    Raises
    ------
    ValueError
        If the DataArray has no dimensions at all
    """
    # Common time dimension names, checked first for an exact match.
    time_names = ['time', 'Time', 'TIME', 't', 'date', 'Date']
    for name in time_names:
        if name in da.dims:
            return name

    # Otherwise inspect coordinate dtypes and attributes.
    for dim_name in da.dims:
        if dim_name in da.coords:
            coord = da.coords[dim_name]

            # A datetime64 coordinate is a strong signal.
            if np.issubdtype(coord.dtype, np.datetime64):
                return dim_name

            # Fall back to CF-style attribute hints.
            if hasattr(coord, 'attrs'):
                long_name = coord.attrs.get('long_name', '').lower()
                standard_name = coord.attrs.get('standard_name', '').lower()
                if any(keyword in long_name or keyword in standard_name
                       for keyword in ['time', 'date']):
                    return dim_name

    # Last resort: assume the first dimension is time, but warn the caller.
    if da.dims:
        warnings.warn(
            f"Could not detect time dimension. Using first dimension '{da.dims[0]}' as time axis. "
            f"Specify time_dim parameter explicitly if this is incorrect.",
            UserWarning
        )
        return da.dims[0]

    raise ValueError("Cannot detect time dimension in DataArray")


def _create_spi_attributes(
    original_attrs: Dict[str, Any],
    time_scale: int,
    distribution: str
) -> Dict[str, Any]:
    """
    Create appropriate attributes for an SPI DataArray.

    Parameters
    ----------
    original_attrs : dict
        Original attributes from precipitation data
    time_scale : int
        Time scale used for SPI calculation
    distribution : str
        Distribution used for fitting

    Returns
    -------
    dict
        Attributes for SPI DataArray
    """
    # NOTE(review): 'standardized_precipitation_index' is not an official
    # CF standard name — confirm against the CF standard-name table.
    attrs = {
        'long_name': f'{time_scale}-month Standardized Precipitation Index',
        'standard_name': 'standardized_precipitation_index',
        'units': '1',  # Dimensionless
        'description': (
            f'Standardized Precipitation Index calculated over {time_scale}-month time scale. '
            f'Based on {distribution} distribution fitting to precipitation data.'
        ),
        'spi_time_scale': time_scale,
        'spi_distribution': distribution,
        'interpretation': (
            'SPI >= 2.0: Extremely wet; 1.5 <= SPI < 2.0: Very wet; '
            '1.0 <= SPI < 1.5: Moderately wet; -1.0 < SPI < 1.0: Near normal; '
            '-1.5 < SPI <= -1.0: Moderately dry; -2.0 < SPI <= -1.5: Severely dry; '
            'SPI <= -2.0: Extremely dry'
        ),
        'references': (
            'McKee, T.B., Doesken, N.J. and Kleist, J., 1993. The relationship of drought '
            'frequency and duration to time scales. In Proceedings of the 8th Conference on '
            'Applied Climatology (Vol. 17, No. 22, pp. 179-183).'
        )
    }

    # Preserve some original attributes if relevant.
    if 'source' in original_attrs:
        attrs['source'] = original_attrs['source']
    if 'history' in original_attrs:
        attrs['history'] = original_attrs['history'] + f'; SPI calculated with time_scale={time_scale}'
    else:
        attrs['history'] = f'SPI calculated with time_scale={time_scale}'

    return attrs
+""" + +import numpy as np +import pytest +import sys +import os + +# Add src to path for testing +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) + +try: + import xarray as xr + HAS_XARRAY = True +except ImportError: + HAS_XARRAY = False + xr = None + +from numpy.testing import assert_array_almost_equal, assert_allclose +from scipy import stats + +from skyborn.calc.spi.core import ( + standardized_precipitation_index, + spi, + _gamma_fit_vectorized, + _calculate_spi_values, + _rolling_sum +) + +if HAS_XARRAY: + from skyborn.calc.spi.xarray import spi_xarray, spi_dataset + + +class TestSPICore: + """Test core SPI calculation functions.""" + + def test_rolling_sum_basic(self): + """Test basic rolling sum functionality.""" + data = np.array([1, 2, 3, 4, 5]) + + # 3-period rolling sum + result = _rolling_sum(data, window=3, axis=0) + expected = np.array([np.nan, np.nan, 6, 9, 12]) # [1+2+3, 2+3+4, 3+4+5] + + # Compare non-NaN values + valid_mask = ~np.isnan(expected) + assert_allclose(result[valid_mask], expected[valid_mask]) + + # Check NaN positions + assert np.all(np.isnan(result[~valid_mask])) + + def test_rolling_sum_2d(self): + """Test rolling sum with 2D data.""" + # Create 2D data (time, space) + data = np.array([ + [1, 10], + [2, 20], + [3, 30], + [4, 40] + ]) + + result = _rolling_sum(data, window=2, axis=0) + + # Expected: first row should be NaN, then rolling sums + expected = np.array([ + [np.nan, np.nan], + [3, 30], # [1+2, 10+20] + [5, 50], # [2+3, 20+30] + [7, 70] # [3+4, 30+40] + ]) + + # Test valid values + valid_mask = ~np.isnan(expected) + assert_allclose(result[valid_mask], expected[valid_mask]) + + def test_rolling_sum_window_1(self): + """Test rolling sum with window=1 (should return original data).""" + data = np.array([1, 2, 3, 4, 5]) + result = _rolling_sum(data, window=1, axis=0) + assert_array_almost_equal(result, data) + + def test_gamma_fit_vectorized(self): + """Test gamma distribution fitting.""" + # Create 
synthetic gamma-distributed data + np.random.seed(42) + true_alpha, true_scale = 2.0, 3.0 + + # Generate data for 2 spatial points + data = np.random.gamma(true_alpha, true_scale, size=(100, 2)) + + alpha, beta = _gamma_fit_vectorized(data, axis=0) + + # Check that fitted parameters are reasonable + assert alpha.shape == (2,) + assert beta.shape == (2,) + + # Parameters should be positive and finite + assert np.all(alpha > 0) + assert np.all(beta > 0) + assert np.all(np.isfinite(alpha)) + assert np.all(np.isfinite(beta)) + + # Fitted parameters should be reasonably close to true values + # (allowing for some variation due to random sampling) + assert np.all(np.abs(alpha - true_alpha) < 1.0) + assert np.all(np.abs(beta - true_scale) < 2.0) + + def test_gamma_fit_with_zeros(self): + """Test gamma fitting with zero precipitation values.""" + # Create data with some zeros + data = np.array([ + [0, 1, 2, 0, 3, 4, 0, 5], # Mix of zeros and positive values + [1, 2, 3, 4, 5, 6, 7, 8] # All positive values + ]).T + + alpha, beta = _gamma_fit_vectorized(data, axis=0) + + # Should still produce reasonable fits + assert alpha.shape == (2,) + assert beta.shape == (2,) + assert np.all(alpha > 0) + assert np.all(beta > 0) + assert np.all(np.isfinite(alpha)) + assert np.all(np.isfinite(beta)) + + def test_spi_calculation_basic(self): + """Test basic SPI calculation.""" + # Create synthetic precipitation data + np.random.seed(42) + # Generate gamma-distributed precipitation for 5 years (60 months) + precip = np.random.gamma(2, 2, size=(60,)) + + # Calculate SPI with 1-month time scale + spi_values = standardized_precipitation_index(precip, time_scale=1, axis=0) + + # Check basic properties + assert spi_values.shape == precip.shape + assert np.isfinite(spi_values).sum() > 40 # Most values should be finite + + # SPI should have approximately mean=0, std=1 for sufficiently long series + valid_spi = spi_values[np.isfinite(spi_values)] + if len(valid_spi) > 30: + assert 
abs(np.mean(valid_spi)) < 0.5 # Mean should be close to 0 + assert abs(np.std(valid_spi) - 1.0) < 0.5 # Std should be close to 1 + + def test_spi_multidimensional(self): + """Test SPI calculation with multi-dimensional data.""" + # Create 3D precipitation data (time, lat, lon) + np.random.seed(42) + precip = np.random.gamma(2, 2, size=(48, 5, 4)) # 4 years, 5x4 grid + + # Calculate 3-month SPI + spi_values = standardized_precipitation_index(precip, time_scale=3, axis=0) + + # Check output shape + assert spi_values.shape == precip.shape + + # Check that we have reasonable number of finite values + # (some will be NaN due to rolling window at beginning) + n_finite = np.isfinite(spi_values).sum() + expected_min_finite = 40 * 5 * 4 # At least 40 time steps for each grid point + assert n_finite >= expected_min_finite + + def test_spi_different_time_scales(self): + """Test SPI calculation with different time scales.""" + np.random.seed(42) + precip = np.random.gamma(2, 2, size=(120,)) # 10 years monthly data + + # Test different time scales + for time_scale in [1, 3, 6, 12]: + spi_values = standardized_precipitation_index(precip, time_scale=time_scale, axis=0) + + assert spi_values.shape == precip.shape + + # Check that we lose some data at the beginning due to rolling window + if time_scale > 1: + # First (time_scale - 1) values should be NaN + assert np.all(np.isnan(spi_values[:time_scale-1])) + + # Later values should be mostly finite + later_values = spi_values[time_scale+10:] # Skip initial period + finite_ratio = np.isfinite(later_values).mean() + assert finite_ratio > 0.8 # At least 80% should be finite + + def test_spi_alias(self): + """Test that spi is an alias for standardized_precipitation_index.""" + np.random.seed(42) + precip = np.random.gamma(2, 2, size=(60,)) + + spi1 = standardized_precipitation_index(precip, time_scale=3) + spi2 = spi(precip, time_scale=3) + + assert_array_almost_equal(spi1, spi2) + + def test_invalid_inputs(self): + """Test handling of 
invalid inputs.""" + precip = np.random.gamma(2, 2, size=(60,)) + + # Invalid distribution + with pytest.raises(ValueError, match="only 'gamma' distribution is supported"): + standardized_precipitation_index(precip, distribution='normal') + + # Invalid time scale + with pytest.raises(ValueError, match="Time scale must be >= 1"): + standardized_precipitation_index(precip, time_scale=0) + + +@pytest.mark.skipif(not HAS_XARRAY, reason="xarray not available") +class TestSPIXarray: + """Test xarray interface for SPI calculations.""" + + def test_spi_xarray_basic(self): + """Test basic xarray SPI calculation.""" + # Create sample xarray DataArray + np.random.seed(42) + time = range(60) # 5 years monthly + lat = np.linspace(-30, 30, 5) + lon = np.linspace(0, 360, 4, endpoint=False) + + precip_data = np.random.gamma(2, 2, size=(60, 5, 4)) + precip = xr.DataArray( + precip_data, + coords={'time': time, 'lat': lat, 'lon': lon}, + dims=['time', 'lat', 'lon'], + attrs={'units': 'mm', 'long_name': 'precipitation'} + ) + + # Calculate SPI + spi_result = spi_xarray(precip, time_scale=3) + + # Check output properties + assert isinstance(spi_result, xr.DataArray) + assert spi_result.shape == precip.shape + assert spi_result.dims == precip.dims + + # Check coordinates are preserved + for coord_name in ['time', 'lat', 'lon']: + assert coord_name in spi_result.coords + assert_array_almost_equal(spi_result.coords[coord_name], precip.coords[coord_name]) + + # Check attributes + assert spi_result.attrs['units'] == '1' + assert 'Standardized Precipitation Index' in spi_result.attrs['long_name'] + assert spi_result.attrs['spi_time_scale'] == 3 + + def test_spi_xarray_time_dim_detection(self): + """Test automatic time dimension detection.""" + np.random.seed(42) + + # Test with different time dimension names + for time_name in ['time', 'Time', 't']: + coords = {time_name: range(36), 'space': range(5)} + dims = [time_name, 'space'] + + precip = xr.DataArray( + np.random.gamma(2, 2, size=(36, 
5)), + coords=coords, + dims=dims + ) + + spi_result = spi_xarray(precip, time_scale=1) + assert spi_result.shape == precip.shape + + def test_spi_xarray_explicit_time_dim(self): + """Test explicit time dimension specification.""" + np.random.seed(42) + precip = xr.DataArray( + np.random.gamma(2, 2, size=(5, 36)), # space, time + coords={'space': range(5), 'month': range(36)}, + dims=['space', 'month'] + ) + + # Specify time dimension explicitly + spi_result = spi_xarray(precip, time_scale=3, time_dim='month') + assert spi_result.shape == precip.shape + + def test_spi_dataset(self): + """Test SPI calculation for entire dataset with multiple time scales.""" + np.random.seed(42) + time = range(60) + lat = np.linspace(-30, 30, 3) + lon = np.linspace(0, 360, 4, endpoint=False) + + # Create dataset with precipitation + precip_data = np.random.gamma(2, 2, size=(60, 3, 4)) + dataset = xr.Dataset({ + 'precipitation': xr.DataArray( + precip_data, + coords={'time': time, 'lat': lat, 'lon': lon}, + dims=['time', 'lat', 'lon'] + ) + }) + + # Calculate SPI for multiple time scales + spi_ds = spi_dataset(dataset, 'precipitation', time_scales=[1, 3, 6]) + + # Check output + assert isinstance(spi_ds, xr.Dataset) + assert 'spi_1m' in spi_ds.data_vars + assert 'spi_3m' in spi_ds.data_vars + assert 'spi_6m' in spi_ds.data_vars + + # Check shapes + for var_name in ['spi_1m', 'spi_3m', 'spi_6m']: + assert spi_ds[var_name].shape == (60, 3, 4) + + def test_invalid_precipitation_variable(self): + """Test error handling for invalid precipitation variable.""" + dataset = xr.Dataset({'temperature': xr.DataArray([1, 2, 3])}) + + with pytest.raises(ValueError, match="Precipitation variable 'precip' not found"): + spi_dataset(dataset, 'precip') + + def test_non_xarray_input(self): + """Test error handling for non-xarray input.""" + with pytest.raises(TypeError, match="precipitation must be an xarray DataArray"): + spi_xarray(np.array([1, 2, 3])) + + +class TestSPIIntegration: + """Integration 
tests for SPI calculations.""" + + def test_spi_known_values(self): + """Test SPI calculation with known values for validation.""" + # Create a simple test case with known drought/wet patterns + # Alternate between low and high precipitation + precip = np.array([0.1, 5.0, 0.1, 5.0, 0.1, 5.0] * 10) # 60 months + + spi_values = standardized_precipitation_index(precip, time_scale=1, axis=0) + + # Check that low precipitation periods have negative SPI + # and high precipitation periods have positive SPI + # (after sufficient data for distribution fitting) + + valid_indices = np.where(np.isfinite(spi_values))[0] + if len(valid_indices) > 20: # Need sufficient data + # Low precip indices (even positions in pattern) + low_precip_mask = np.array([(i % 6) in [0, 2, 4] for i in valid_indices]) + high_precip_mask = np.array([(i % 6) in [1, 3, 5] for i in valid_indices]) + + if np.any(low_precip_mask) and np.any(high_precip_mask): + low_spi = spi_values[valid_indices[low_precip_mask]] + high_spi = spi_values[valid_indices[high_precip_mask]] + + # Most low precipitation should have negative SPI + assert np.mean(low_spi < 0) > 0.6 + # Most high precipitation should have positive SPI + assert np.mean(high_spi > 0) > 0.6 + + def test_spi_consistency_across_scales(self): + """Test that SPI is consistent across different time scales.""" + np.random.seed(42) + precip = np.random.gamma(2, 2, size=(120,)) # 10 years + + spi_1m = standardized_precipitation_index(precip, time_scale=1) + spi_3m = standardized_precipitation_index(precip, time_scale=3) + spi_12m = standardized_precipitation_index(precip, time_scale=12) + + # Longer time scales should be smoother (less variance) + valid_1m = spi_1m[np.isfinite(spi_1m)] + valid_3m = spi_3m[np.isfinite(spi_3m)] + valid_12m = spi_12m[np.isfinite(spi_12m)] + + if len(valid_1m) > 30 and len(valid_3m) > 30 and len(valid_12m) > 30: + var_1m = np.var(valid_1m) + var_3m = np.var(valid_3m) + var_12m = np.var(valid_12m) + + # Longer time scales should 
generally have less variance + # (though this is not strictly guaranteed for all datasets) + assert var_12m <= var_1m * 1.5 # Allow some flexibility + + @pytest.mark.skipif(not HAS_XARRAY, reason="xarray not available") + def test_xarray_numpy_consistency(self): + """Test that xarray and numpy interfaces give same results.""" + np.random.seed(42) + precip_data = np.random.gamma(2, 2, size=(60, 3, 4)) + + # Calculate with numpy interface + spi_numpy = standardized_precipitation_index(precip_data, time_scale=3, axis=0) + + # Calculate with xarray interface + precip_xr = xr.DataArray( + precip_data, + coords={'time': range(60), 'lat': range(3), 'lon': range(4)}, + dims=['time', 'lat', 'lon'] + ) + spi_xarray_result = spi_xarray(precip_xr, time_scale=3) + + # Results should be very close + assert_allclose(spi_numpy, spi_xarray_result.values, rtol=1e-10) + + +if __name__ == "__main__": + # Run basic tests if executed directly + test_core = TestSPICore() + test_core.test_rolling_sum_basic() + test_core.test_spi_calculation_basic() + print("Basic SPI tests passed!") + + if HAS_XARRAY: + test_xr = TestSPIXarray() + test_xr.test_spi_xarray_basic() + print("Xarray SPI tests passed!") + else: + print("Xarray not available, skipping xarray tests") \ No newline at end of file