Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 3 additions & 12 deletions src/spikeinterface/extractors/cbin_ibl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import probeinterface

from spikeinterface.core import BaseRecording, BaseRecordingSegment
from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts
from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts_from_probe
from spikeinterface.core.core_tools import define_function_from_class


Expand Down Expand Up @@ -45,21 +45,13 @@ class CompressedBinaryIblExtractor(BaseRecording):
installation_mesg = "To use the CompressedBinaryIblExtractor, install mtscomp: \n\n pip install mtscomp\n\n"

def __init__(
self, folder_path=None, load_sync_channel=False, stream_name="ap", cbin_file_path=None, cbin_file=None
self, folder_path=None, load_sync_channel=False, stream_name="ap", cbin_file_path=None
):
from neo.rawio.spikeglxrawio import read_meta_file

try:
import mtscomp
except ImportError:
raise ImportError(self.installation_mesg)
if cbin_file is not None:
warnings.warn(
"The `cbin_file` argument is deprecated and will be removed in version 0.104.0, please use `cbin_file_path` instead",
DeprecationWarning,
stacklevel=2,
)
cbin_file_path = cbin_file
if cbin_file_path is None:
folder_path = Path(folder_path)
# check bands
Expand Down Expand Up @@ -124,8 +116,7 @@ def __init__(
num_channels_per_adc = 16
else: # NP1.0
num_channels_per_adc = 12

sample_shifts = get_neuropixels_sample_shifts(self.get_num_channels(), num_channels_per_adc)
sample_shifts = get_neuropixels_sample_shifts_from_probe(self.get_num_channels(), num_channels_per_adc)
self.set_property("inter_sample_shift", sample_shifts)

self._kwargs = {
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/preprocessing/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class FilterRecording(BasePreprocessor):
def __init__(
self,
recording,
band=[300.0, 6000.0],
band=(300.0, 6000.0),
btype="bandpass",
filter_order=5,
ftype="butter",
Expand Down Expand Up @@ -370,7 +370,7 @@ def __init__(self, recording, freq=3000, q=30, margin_ms="auto", dtype=None, **f
def causal_filter(
recording,
direction="forward",
band=[300.0, 6000.0],
band=(300.0, 6000.0),
btype="bandpass",
filter_order=5,
ftype="butter",
Expand Down
18 changes: 12 additions & 6 deletions src/spikeinterface/preprocessing/highpass_spatial_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment
from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment, BaseRecording
from .filter import fix_dtype
from spikeinterface.core import order_channels_by_depth, get_chunk_with_margin
from spikeinterface.core.core_tools import define_function_handling_dict_from_class
Expand Down Expand Up @@ -66,14 +66,15 @@ class HighpassSpatialFilterRecording(BasePreprocessor):

def __init__(
self,
recording,
recording: BaseRecording,
n_channel_pad=60,
n_channel_taper=0,
direction="y",
apply_agc=True,
agc_window_length_s=0.1,
highpass_butter_order=3,
highpass_butter_wn=0.01,
epsilon=None,
dtype=None,
):
BasePreprocessor.__init__(self, recording)
Expand Down Expand Up @@ -133,6 +134,7 @@ def __init__(
order_f,
order_r,
dtype=dtype,
epsilon=epsilon,
)
self.add_recording_segment(rec_segment)

Expand Down Expand Up @@ -161,6 +163,7 @@ def __init__(
order_f,
order_r,
dtype,
epsilon,
):
BasePreprocessorSegment.__init__(self, parent_recording_segment)
self.parent_recording_segment = parent_recording_segment
Expand Down Expand Up @@ -207,6 +210,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
traces = traces.copy()

# apply AGC and keep the gains
traces = traces.astype(np.float32)
if self.window is not None:
traces, agc_gains = agc(traces, window=self.window)
else:
Expand Down Expand Up @@ -255,7 +259,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
# -----------------------------------------------------------------------------------------------


def agc(traces, window, epsilon=1e-8):
def agc(traces, window, epsilon=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a problem here because the espilon will change for every chunk no ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would vote to force espilon here because it is an helper function and compute it the init of the class.
no ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes but this is a second order issue, we just need to stabilize the inverse of the gain.

To mitigate, I have exposed the epsilon as a parameter in the interface. It would be useful to have a BaseRecording.estimate_rms() method to fix a single epsilon for the full processing, but this is beyond the scope of this PR.

If everyone is on-board for this, I can create an issue and summarize this discussion

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good to me. We hav an get_noise_levels(recording, method="rms") that we could run in the init if rms levels are not estimated yet. But let's discuss specifics on the new issue

"""
Automatic gain control
w_agc, gain = agc(w, window_length=.5, si=.002, epsilon=1e-8)
Expand All @@ -268,13 +272,15 @@ def agc(traces, window, epsilon=1e-8):
"""
import scipy.signal

gain = scipy.signal.fftconvolve(np.abs(traces), window[:, None], mode="same", axes=0)
# default value for epsilon is relative to the rms, loosely matching the IBL 1e-8 for an input in Volts
if epsilon is None:
epsilon = np.std(traces - np.mean(traces)) * 0.003

gain += (np.sum(gain, axis=0) * epsilon / traces.shape[0])[np.newaxis, :]
gain = scipy.signal.fftconvolve(np.abs(traces), window[:, None], mode="same", axes=0)

dead_channels = np.sum(gain, axis=0) == 0

traces[:, ~dead_channels] = traces[:, ~dead_channels] / gain[:, ~dead_channels]
traces[:, ~dead_channels] = traces[:, ~dead_channels] / np.maximum(epsilon, gain[:, ~dead_channels])

return traces, gain

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
from copy import deepcopy

import spikeinterface as si
import spikeinterface.core as si
import spikeinterface.preprocessing as spre
import spikeinterface.extractors as se
from spikeinterface.core import generate_recording
Expand All @@ -24,7 +24,7 @@


@pytest.mark.skipif(
importlib.util.find_spec("neurodsp") is None or importlib.util.find_spec("spikeglx") is None or ON_GITHUB,
importlib.util.find_spec("ibldsp") is None or importlib.util.find_spec("spikeglx") is None or ON_GITHUB,
reason="Only local. Requires ibl-neuropixel install",
)
@pytest.mark.parametrize("lagc", [False, 1, 300])
Expand All @@ -51,32 +51,28 @@ def test_highpass_spatial_filter_real_data(lagc):
use DEBUG = true to visualise.

"""
import spikeglx
import neurodsp.voltage as voltage
import ibldsp.voltage
import neuropixel

options = dict(lagc=lagc, ntr_pad=25, ntr_tap=50, butter_kwargs=None)
print(options)

ibl_data, si_recording = get_ibl_si_data()

si_filtered, _ = run_si_highpass_filter(si_recording, **options)
local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0")
si_recording = se.read_spikeglx(local_path, stream_id="imec0.ap")
si_recording = spre.astype(si_recording, "float")
recording_ps = spre.phase_shift(si_recording)
recording_hp = spre.highpass_filter(recording_ps, freq_min=300, filter_order=3)
recording_hps = spre.highpass_spatial_filter(recording_hp)
raw = si_recording.get_traces().astype(np.float32).T * neuropixel.S2V_AP
si_filtered = recording_hps.get_traces().astype(np.float32).T * neuropixel.S2V_AP

ibl_filtered = run_ibl_highpass_filter(ibl_data.copy(), **options)
destripe = ibldsp.voltage.destripe(raw, fs=30_000, neuropixel_version=1)

if DEBUG:
fig, axs = plt.subplots(ncols=4)
axs[0].imshow(si_recording.get_traces(return_in_uV=True))
axs[0].set_title("SI Raw")
axs[1].imshow(ibl_data.T)
axs[1].set_title("IBL Raw")
axs[2].imshow(si_filtered)
axs[2].set_title("SI Filtered ")
axs[3].imshow(ibl_filtered)
axs[3].set_title("IBL Filtered")
from viewephys.gui import viewephys

eqc = {}
eqc["si_filtered"] = viewephys(si_filtered, fs=30_000, title="si_filtered")
eqc["ibl_filtered"] = viewephys(destripe, fs=30_000, title="ibl_filtered")

assert np.allclose(
si_filtered, ibl_filtered * 1e6, atol=1e-01, rtol=0
) # the differences are entired due to scaling on data load.
np.testing.assert_allclose(si_filtered[12:120, 300:800], destripe[12:120, 300:800], atol=1e-05, rtol=0)


@pytest.mark.parametrize("ntr_pad", [None, 0, 31])
Expand Down Expand Up @@ -140,24 +136,6 @@ def test_dtype_stability(dtype):
# ----------------------------------------------------------------------------------------------------------------------


def get_ibl_si_data():
"""
Set fixture to session to ensure origional data is not changed.
"""
import spikeglx

local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0")
ibl_recording = spikeglx.Reader(
local_path / "Noise4Sam_g0_imec0" / "Noise4Sam_g0_t0.imec0.ap.bin", ignore_warnings=True
)
ibl_data = ibl_recording.read(slice(None), slice(None), sync=False)[:, :-1].T # cut sync channel

si_recording = se.read_spikeglx(local_path, stream_id="imec0.ap")
si_recording = spre.astype(si_recording, dtype="float32")

return ibl_data, si_recording


def process_args_for_si(si_recording, lagc):
""""""
if isinstance(lagc, bool) and not lagc:
Expand Down Expand Up @@ -215,9 +193,10 @@ def run_si_highpass_filter(si_recording, ntr_pad, ntr_tap, lagc, butter_kwargs,


def run_ibl_highpass_filter(ibl_data, ntr_pad, ntr_tap, lagc, butter_kwargs):
butter_kwargs, ntr_pad, lagc = process_args_for_ibl(butter_kwargs, ntr_pad, lagc)
import ibldsp.voltage

ibl_filtered = voltage.kfilt(ibl_data, None, ntr_pad, ntr_tap, lagc, butter_kwargs).T
butter_kwargs, ntr_pad, lagc = process_args_for_ibl(butter_kwargs, ntr_pad, lagc)
ibl_filtered = ibldsp.voltage.kfilt(ibl_data, None, ntr_pad, ntr_tap, lagc, butter_kwargs).T

return ibl_filtered

Expand Down