diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index c70c49e8f8..a4d2816c8e 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -7,7 +7,7 @@ import probeinterface from spikeinterface.core import BaseRecording, BaseRecordingSegment -from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts +from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts_from_probe from spikeinterface.core.core_tools import define_function_from_class @@ -44,22 +44,13 @@ class CompressedBinaryIblExtractor(BaseRecording): installation_mesg = "To use the CompressedBinaryIblExtractor, install mtscomp: \n\n pip install mtscomp\n\n" - def __init__( - self, folder_path=None, load_sync_channel=False, stream_name="ap", cbin_file_path=None, cbin_file=None - ): + def __init__(self, folder_path=None, load_sync_channel=False, stream_name="ap", cbin_file_path=None): from neo.rawio.spikeglxrawio import read_meta_file try: import mtscomp except ImportError: raise ImportError(self.installation_mesg) - if cbin_file is not None: - warnings.warn( - "The `cbin_file` argument is deprecated and will be removed in version 0.104.0, please use `cbin_file_path` instead", - DeprecationWarning, - stacklevel=2, - ) - cbin_file_path = cbin_file if cbin_file_path is None: folder_path = Path(folder_path) # check bands @@ -124,8 +115,7 @@ def __init__( num_channels_per_adc = 16 else: # NP1.0 num_channels_per_adc = 12 - - sample_shifts = get_neuropixels_sample_shifts(self.get_num_channels(), num_channels_per_adc) + sample_shifts = get_neuropixels_sample_shifts_from_probe(self.get_num_channels(), num_channels_per_adc) self.set_property("inter_sample_shift", sample_shifts) self._kwargs = { diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 732b310123..f9d337fe8b 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -79,7 +79,7 @@ class FilterRecording(BasePreprocessor): def __init__( self, recording, - band=[300.0, 6000.0], + band=(300.0, 6000.0), btype="bandpass", filter_order=5, ftype="butter", @@ -370,7 +370,7 @@ def __init__(self, recording, freq=3000, q=30, margin_ms="auto", dtype=None, **f def causal_filter( recording, direction="forward", - band=[300.0, 6000.0], + band=(300.0, 6000.0), btype="bandpass", filter_order=5, ftype="butter", diff --git a/src/spikeinterface/preprocessing/highpass_spatial_filter.py b/src/spikeinterface/preprocessing/highpass_spatial_filter.py index 9228f5de12..212ff2ddec 100644 --- a/src/spikeinterface/preprocessing/highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/highpass_spatial_filter.py @@ -2,7 +2,7 @@ import numpy as np -from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment +from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment, BaseRecording from .filter import fix_dtype from spikeinterface.core import order_channels_by_depth, get_chunk_with_margin from spikeinterface.core.core_tools import define_function_handling_dict_from_class @@ -66,7 +66,7 @@ class HighpassSpatialFilterRecording(BasePreprocessor): def __init__( self, - recording, + recording: BaseRecording, n_channel_pad=60, n_channel_taper=0, direction="y", @@ -74,6 +74,7 @@ def __init__( agc_window_length_s=0.1, highpass_butter_order=3, highpass_butter_wn=0.01, + epsilon=None, dtype=None, ): BasePreprocessor.__init__(self, recording) @@ -133,6 +134,7 @@ def __init__( order_f, order_r, dtype=dtype, + epsilon=epsilon, ) self.add_recording_segment(rec_segment) @@ -161,6 +163,7 @@ def __init__( order_f, order_r, dtype, + epsilon, ): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.parent_recording_segment = parent_recording_segment @@ -207,6 +210,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces = traces.copy() # apply AGC and keep the gains + traces = traces.astype(np.float32) if self.window is not None: traces, agc_gains = agc(traces, window=self.window) else: @@ -255,7 +259,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): # ----------------------------------------------------------------------------------------------- -def agc(traces, window, epsilon=1e-8): +def agc(traces, window, epsilon=None): """ Automatic gain control w_agc, gain = agc(w, window_length=.5, si=.002, epsilon=1e-8) @@ -268,13 +272,15 @@ def agc(traces, window, epsilon=1e-8): """ import scipy.signal - gain = scipy.signal.fftconvolve(np.abs(traces), window[:, None], mode="same", axes=0) + # default value for epsilon is relative to the rms, loosely matching the IBL 1e-8 for an input in Volts + if epsilon is None: + epsilon = np.std(traces - np.mean(traces)) * 0.003 - gain += (np.sum(gain, axis=0) * epsilon / traces.shape[0])[np.newaxis, :] + gain = scipy.signal.fftconvolve(np.abs(traces), window[:, None], mode="same", axes=0) dead_channels = np.sum(gain, axis=0) == 0 - traces[:, ~dead_channels] = traces[:, ~dead_channels] / gain[:, ~dead_channels] + traces[:, ~dead_channels] = traces[:, ~dead_channels] / np.maximum(epsilon, gain[:, ~dead_channels]) return traces, gain diff --git a/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py b/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py index 8ff2aea547..4aa014bbeb 100644 --- a/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py @@ -3,7 +3,7 @@ import numpy as np from copy import deepcopy -import spikeinterface as si +import spikeinterface.core as si import spikeinterface.preprocessing as spre import spikeinterface.extractors as se from spikeinterface.core import generate_recording @@ -24,7 +24,7 @@ @pytest.mark.skipif( - importlib.util.find_spec("neurodsp") is None or importlib.util.find_spec("spikeglx") is None or ON_GITHUB, + importlib.util.find_spec("ibldsp") is None or importlib.util.find_spec("spikeglx") is None or ON_GITHUB, reason="Only local. Requires ibl-neuropixel install", ) @pytest.mark.parametrize("lagc", [False, 1, 300]) @@ -51,32 +51,28 @@ def test_highpass_spatial_filter_real_data(lagc): use DEBUG = true to visualise. """ - import spikeglx - import neurodsp.voltage as voltage + import ibldsp.voltage + import neuropixel - options = dict(lagc=lagc, ntr_pad=25, ntr_tap=50, butter_kwargs=None) - print(options) - - ibl_data, si_recording = get_ibl_si_data() - - si_filtered, _ = run_si_highpass_filter(si_recording, **options) + local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0") + si_recording = se.read_spikeglx(local_path, stream_id="imec0.ap") + si_recording = spre.astype(si_recording, "float") + recording_ps = spre.phase_shift(si_recording) + recording_hp = spre.highpass_filter(recording_ps, freq_min=300, filter_order=3) + recording_hps = spre.highpass_spatial_filter(recording_hp) + raw = si_recording.get_traces().astype(np.float32).T * neuropixel.S2V_AP + si_filtered = recording_hps.get_traces().astype(np.float32).T * neuropixel.S2V_AP - ibl_filtered = run_ibl_highpass_filter(ibl_data.copy(), **options) + destripe = ibldsp.voltage.destripe(raw, fs=30_000, neuropixel_version=1) if DEBUG: - fig, axs = plt.subplots(ncols=4) - axs[0].imshow(si_recording.get_traces(return_in_uV=True)) - axs[0].set_title("SI Raw") - axs[1].imshow(ibl_data.T) - axs[1].set_title("IBL Raw") - axs[2].imshow(si_filtered) - axs[2].set_title("SI Filtered ") - axs[3].imshow(ibl_filtered) - axs[3].set_title("IBL Filtered") + from viewephys.gui import viewephys + + eqc = {} + eqc["si_filtered"] = viewephys(si_filtered, fs=30_000, title="si_filtered") + eqc["ibl_filtered"] = viewephys(destripe, fs=30_000, title="ibl_filtered") - assert np.allclose( - si_filtered, ibl_filtered * 1e6, atol=1e-01, rtol=0 - ) # the differences are entired due to scaling on data load. + np.testing.assert_allclose(si_filtered[12:120, 300:800], destripe[12:120, 300:800], atol=1e-05, rtol=0) @pytest.mark.parametrize("ntr_pad", [None, 0, 31]) @@ -140,24 +136,6 @@ def test_dtype_stability(dtype): # ---------------------------------------------------------------------------------------------------------------------- -def get_ibl_si_data(): - """ - Set fixture to session to ensure origional data is not changed. - """ - import spikeglx - - local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0") - ibl_recording = spikeglx.Reader( - local_path / "Noise4Sam_g0_imec0" / "Noise4Sam_g0_t0.imec0.ap.bin", ignore_warnings=True - ) - ibl_data = ibl_recording.read(slice(None), slice(None), sync=False)[:, :-1].T # cut sync channel - - si_recording = se.read_spikeglx(local_path, stream_id="imec0.ap") - si_recording = spre.astype(si_recording, dtype="float32") - - return ibl_data, si_recording - - def process_args_for_si(si_recording, lagc): """""" if isinstance(lagc, bool) and not lagc: @@ -215,9 +193,10 @@ def run_si_highpass_filter(si_recording, ntr_pad, ntr_tap, lagc, butter_kwargs, def run_ibl_highpass_filter(ibl_data, ntr_pad, ntr_tap, lagc, butter_kwargs): - butter_kwargs, ntr_pad, lagc = process_args_for_ibl(butter_kwargs, ntr_pad, lagc) + import ibldsp.voltage - ibl_filtered = voltage.kfilt(ibl_data, None, ntr_pad, ntr_tap, lagc, butter_kwargs).T + butter_kwargs, ntr_pad, lagc = process_args_for_ibl(butter_kwargs, ntr_pad, lagc) + ibl_filtered = ibldsp.voltage.kfilt(ibl_data, None, ntr_pad, ntr_tap, lagc, butter_kwargs).T return ibl_filtered