Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 191 additions & 5 deletions src/xrsignal/spectral_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def __csd_chunk(data, dim, **kwargs):
return np.abs(Pxy_x)


def welch(data, dim, dB=False, **kwargs):
def welch(data, dim, dB=False, nan=False, **kwargs):
'''
Estimate power spectral density using welch method
For now, an integer number of chunks in PSD dimension is required
Expand All @@ -183,12 +183,14 @@ def welch(data, dim, dB=False, **kwargs):
dimension to calculate PSD over
dB : bool
if True, return PSD in dB
nan : bool
if True, use welch_nan instead of welch. This throws out segments with NAN and still calculates the PSD
'''

if isinstance(data, xr.DataArray):
Sxx = __welch_da(data, dim, dB=dB, **kwargs)
Sxx = __welch_da(data, dim, dB=dB, nan=nan, **kwargs)
elif isinstance(data, xr.Dataset):
Sxx = data.map(__welch_da, dim=dim, dB=dB, **kwargs)
Sxx = data.map(__welch_da, dim=dim, dB=dB, nan=nan, **kwargs)
else:
raise Exception('data must be xr.DataArray or xr.Dataset')

Expand All @@ -207,6 +209,8 @@ def __welch_chunk(da, dim, **kwargs):
**kwargs
passed to scipy.signal.welch
'''
# unpack nan kwarg
nan = kwargs.pop('nan', False)

# Create new dimensions of PSD object
original_dims = list(da.dims)
Expand All @@ -217,14 +221,18 @@ def __welch_chunk(da, dim, **kwargs):
new_dims.append(dim)

# Estimate PSD and convert to xarray.DataArray
f, P = signal.welch(da.values, axis=psd_dim_idx, **kwargs)
if nan:
f, P = welch_nan(da.values, axis=psd_dim_idx, **kwargs)
else:
f, P = signal.welch(da.values, axis=psd_dim_idx, **kwargs)

P = np.expand_dims(P, -1)

Px = xr.DataArray(P, dims=new_dims, coords={f'{dim}_frequency': f})

return Px

def __welch_da(da, dim, dB=False, **kwargs):
def __welch_da(da, dim, dB=False, nan=False, **kwargs):
'''
Estimate power spectral density using welch method

Expand All @@ -238,6 +246,8 @@ def __welch_da(da, dim, dB=False, **kwargs):
dimension to calculate PSD over
dB : bool
if True, return PSD in dB
nan : bool
if True, use welch_nan instead of welch
'''

## Parse Kwargs
Expand Down Expand Up @@ -329,9 +339,185 @@ def __welch_da(da, dim, dB=False, **kwargs):
name=f'psd across {dim} dimension')

kwargs['dim'] = dim
kwargs['nan'] = nan

Pxx = xr.map_blocks(__welch_chunk, da, template=template, kwargs=kwargs)

if dB:
return 10*np.log10(Pxx)
else:
return Pxx

def welch_nan(x, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
              detrend='constant', return_onesided=True, scaling='density',
              axis=-1, average='mean'):
    """
    Estimate power spectral density using Welch's method, skipping NaN segments.

    Mirrors the interface of ``scipy.signal.welch``. When ``x`` contains no
    NaN the call is delegated to ``scipy.signal.welch`` unchanged. Otherwise
    the data are cut into segments along ``axis`` and every segment that
    contains at least one NaN is dropped before the periodograms are averaged.

    For N-D input a segment is dropped as a whole: a NaN in any series
    inside a segment's time window removes that window for all series.

    Parameters
    ----------
    x : array_like
        Time series of measurement values
    fs : float, optional
        Sampling frequency of the `x` time series. Defaults to 1.0.
    window : str or tuple or array_like, optional
        Desired window to use. A str/tuple is passed to
        ``scipy.signal.get_window``; an array is used directly and must be
        1-D with length `nperseg`. Defaults to 'hann'.
    nperseg : int, optional
        Length of each segment. Defaults to ``min(256, x.shape[axis])`` for
        str/tuple windows and to the window length for array windows.
    noverlap : int, optional
        Number of points to overlap between segments. Defaults to
        ``nperseg // 2``. Must be smaller than `nperseg`.
    nfft : int, optional
        Length of the FFT used. If `None`, the default is `nperseg`.
        Must not be smaller than `nperseg`.
    detrend : str or function or `False`, optional
        Specifies how to detrend each segment. A string is passed as the
        `type` argument to ``scipy.signal.detrend``. A function receives the
        segment with the analysed axis moved last and must return the
        detrended segment. A falsy value disables detrending.
        Defaults to 'constant'.
    return_onesided : bool, optional
        If `True`, return a one-sided spectrum for real data. If `False`
        return a two-sided spectrum. For complex data a two-sided spectrum
        is always returned (a warning is issued if `True` was requested).
    scaling : { 'density', 'spectrum' }, optional
        'density' returns the power spectral density (V**2/Hz), 'spectrum'
        the power spectrum (V**2), if `x` is measured in V and `fs` in Hz.
        Defaults to 'density'.
    axis : int, optional
        Axis along which the periodogram is computed; the default is over the
        last axis (i.e., axis=-1).
    average : { 'mean', 'median' }, optional
        Method to use when averaging periodograms. Defaults to 'mean'.

    Returns
    -------
    f : ndarray
        Array of sample frequencies.
    Pxx : ndarray
        Power spectral density or power spectrum of x. The frequency axis
        sits at position `axis`; all other axes keep the shape of `x`.
        Filled with NaN when no NaN-free segment exists.

    Raises
    ------
    ValueError
        For an unknown `scaling` or `average`, a window that does not match
        `nperseg`, ``noverlap >= nperseg`` or ``nfft < nperseg``.
    """
    import warnings

    x = np.asarray(x)

    # Fast path: without NaN the reference implementation is used directly.
    if not np.any(np.isnan(x)):
        return signal.welch(x, fs=fs, window=window, nperseg=nperseg,
                            noverlap=noverlap, nfft=nfft, detrend=detrend,
                            return_onesided=return_onesided, scaling=scaling,
                            axis=axis, average=average)

    # Validate the option strings up front so a typo cannot go unnoticed
    # (e.g. when no valid segment exists and averaging is never reached).
    if scaling not in ('density', 'spectrum'):
        raise ValueError(f"Unknown scaling: {scaling}")
    if average not in ('mean', 'median'):
        raise ValueError(f"Unknown average: {average}")

    # Work with a non-negative axis index from here on
    if axis < 0:
        axis = x.ndim + axis
    n_samples = x.shape[axis]

    # Resolve the window and the segment length together: an explicit window
    # array defines nperseg when the caller did not give one.
    if isinstance(window, (str, tuple)):
        if nperseg is None:
            nperseg = min(256, n_samples)
        nperseg = int(nperseg)
        if nperseg < 1:
            raise ValueError('nperseg must be a positive integer')
        win = signal.get_window(window, nperseg)
    else:
        win = np.asarray(window)
        if win.ndim != 1:
            raise ValueError('window must be 1-D')
        if nperseg is None:
            nperseg = win.shape[0]
        elif win.shape[0] != nperseg:
            raise ValueError('window must have length of nperseg')
        nperseg = int(nperseg)

    if noverlap is None:
        noverlap = nperseg // 2
    else:
        noverlap = int(noverlap)
    # A non-positive step would never advance through the data
    if noverlap >= nperseg:
        raise ValueError('noverlap must be less than nperseg.')

    if nfft is None:
        nfft = nperseg
    else:
        nfft = int(nfft)
        # A shorter FFT would silently truncate every segment
        if nfft < nperseg:
            raise ValueError('nfft must be greater than or equal to nperseg.')

    # One-sided spectra are only meaningful for real input
    is_complex = np.iscomplexobj(x)
    if return_onesided and is_complex:
        warnings.warn('return_onesided=True is ignored for complex input. '
                      'Computing two-sided spectrum.', stacklevel=2)
    onesided = return_onesided and not is_complex

    if onesided:
        freqs = np.fft.rfftfreq(nfft, 1.0 / fs)
        fft_func = np.fft.rfft
    else:
        freqs = np.fft.fftfreq(nfft, 1.0 / fs)
        fft_func = np.fft.fft

    # Normalisation is identical for every segment, so compute it once
    if scaling == 'density':
        scale = 1.0 / (fs * (win ** 2).sum())
    else:
        scale = 1.0 / win.sum() ** 2

    # Window reshaped so it broadcasts along `axis`
    win_shape = [1] * x.ndim
    win_shape[axis] = nperseg
    win_b = win.reshape(win_shape)

    # Bins doubled for a one-sided spectrum: everything except DC and, for
    # even nfft, the Nyquist bin. The slice must address `axis`, not the
    # last axis, because the frequency axis stays where the time axis was.
    doubled = [slice(None)] * x.ndim
    doubled[axis] = slice(1, -1) if nfft % 2 == 0 else slice(1, None)
    doubled = tuple(doubled)

    # Build the per-segment detrending step
    if not detrend:
        def apply_detrend(seg):
            return seg
    elif callable(detrend):
        def apply_detrend(seg):
            # User functions act on the last axis; copy so the caller's
            # data cannot be modified through the view.
            seg = np.moveaxis(seg, axis, -1).copy()
            return np.moveaxis(detrend(seg), -1, axis)
    else:
        def apply_detrend(seg):
            return signal.detrend(seg, type=detrend, axis=axis)

    # One periodogram per NaN-free segment
    step = nperseg - noverlap
    index = [slice(None)] * x.ndim
    segment_psds = []
    for start in range(0, n_samples - nperseg + 1, step):
        index[axis] = slice(start, start + nperseg)
        segment = x[tuple(index)]

        if np.isnan(segment).any():
            continue

        segment = apply_detrend(segment) * win_b
        segment_psd = np.abs(fft_func(segment, n=nfft, axis=axis)) ** 2 * scale
        if onesided:
            segment_psd[doubled] *= 2

        segment_psds.append(segment_psd)

    # No usable segment: NaN result with the same shape a valid result has
    if not segment_psds:
        out_shape = list(x.shape)
        out_shape[axis] = freqs.size
        return freqs, np.full(out_shape, np.nan)

    stacked = np.stack(segment_psds, axis=0)

    if average == 'mean':
        Pxx = stacked.mean(axis=0)
    else:
        # The median of chi-square distributed periodograms is biased
        # relative to the mean; apply the same correction scipy uses so both
        # code paths (with and without NaN) are on the same scale.
        n_seg = stacked.shape[0]
        even_terms = 2 * np.arange(1.0, (n_seg - 1) // 2 + 1)
        bias = 1 + np.sum(1.0 / (even_terms + 1) - 1.0 / even_terms)
        Pxx = np.median(stacked, axis=0) / bias

    return freqs, Pxx
Loading