Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/spikeinterface/extractors/cbin_ibl.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def __init__(
):
from neo.rawio.spikeglxrawio import read_meta_file

if Path(folder_path).is_file():
folder_path = Path(folder_path).parent
try:
import mtscomp
except ImportError:
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/preprocessing/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class FilterRecording(BasePreprocessor):
def __init__(
self,
recording,
band=[300.0, 6000.0],
band=(300.0, 6000.0),
btype="bandpass",
filter_order=5,
ftype="butter",
Expand Down Expand Up @@ -370,7 +370,7 @@ def __init__(self, recording, freq=3000, q=30, margin_ms="auto", dtype=None, **f
def causal_filter(
recording,
direction="forward",
band=[300.0, 6000.0],
band=(300.0, 6000.0),
btype="bandpass",
filter_order=5,
ftype="butter",
Expand Down
11 changes: 7 additions & 4 deletions src/spikeinterface/preprocessing/highpass_spatial_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
traces = traces.copy()

# apply AGC and keep the gains
traces = traces.astype(np.float32)
if self.window is not None:
traces, agc_gains = agc(traces, window=self.window)
else:
Expand Down Expand Up @@ -255,7 +256,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
# -----------------------------------------------------------------------------------------------


def agc(traces, window, epsilon=1e-8):
def agc(traces, window, epsilon=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a problem here because the espilon will change for every chunk no ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would vote to force espilon here because it is an helper function and compute it the init of the class.
no ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes but this is a second order issue, we just need to stabilize the inverse of the gain.

To mitigate, I have exposed the epsilon as a parameter in the interface. It would be useful to have a BaseRecording.estimate_rms() method to fix a single epsilon for the full processing, but this is beyond the scope of this PR.

If everyone is on-board for this, I can create an issue and summarize this discussion

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good to me. We hav an get_noise_levels(recording, method="rms") that we could run in the init if rms levels are not estimated yet. But let's discuss specifics on the new issue

"""
Automatic gain control
w_agc, gain = agc(w, window_length=.5, si=.002, epsilon=1e-8)
Expand All @@ -268,13 +269,15 @@ def agc(traces, window, epsilon=1e-8):
"""
import scipy.signal

gain = scipy.signal.fftconvolve(np.abs(traces), window[:, None], mode="same", axes=0)
# default value for epsilon is relative to the rms, loosely matching the IBL 1e-8 for an input in Volts
if epsilon is None:
epsilon = np.std(traces - np.mean(traces)) * 0.003

gain += (np.sum(gain, axis=0) * epsilon / traces.shape[0])[np.newaxis, :]
gain = scipy.signal.fftconvolve(np.abs(traces), window[:, None], mode="same", axes=0)

dead_channels = np.sum(gain, axis=0) == 0

traces[:, ~dead_channels] = traces[:, ~dead_channels] / gain[:, ~dead_channels]
traces[:, ~dead_channels] = traces[:, ~dead_channels] / np.maximum(epsilon, gain[:, ~dead_channels])

return traces, gain

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
from copy import deepcopy

import spikeinterface as si
import spikeinterface.full as si
import spikeinterface.preprocessing as spre
import spikeinterface.extractors as se
from spikeinterface.core import generate_recording
Expand All @@ -24,7 +24,7 @@


@pytest.mark.skipif(
importlib.util.find_spec("neurodsp") is None or importlib.util.find_spec("spikeglx") is None or ON_GITHUB,
importlib.util.find_spec("ibldsp") is None or importlib.util.find_spec("spikeglx") is None or ON_GITHUB,
reason="Only local. Requires ibl-neuropixel install",
)
@pytest.mark.parametrize("lagc", [False, 1, 300])
Expand All @@ -51,32 +51,28 @@ def test_highpass_spatial_filter_real_data(lagc):
use DEBUG = true to visualise.

"""
import spikeglx
import neurodsp.voltage as voltage
import ibldsp.voltage
import neuropixel

options = dict(lagc=lagc, ntr_pad=25, ntr_tap=50, butter_kwargs=None)
print(options)

ibl_data, si_recording = get_ibl_si_data()

si_filtered, _ = run_si_highpass_filter(si_recording, **options)
local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0")
si_recording = se.read_spikeglx(local_path, stream_id="imec0.ap")
si_recording = spre.astype(si_recording, "float")
recording_ps = si.phase_shift(si_recording)
recording_hp = si.highpass_filter(recording_ps, freq_min=300, filter_order=3)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
recording_hp = si.highpass_filter(recording_ps, freq_min=300, filter_order=3)
recording_hp = spre.highpass_filter(recording_ps, freq_min=300, filter_order=3)

recording_hps = si.highpass_spatial_filter(recording_hp)
raw = si_recording.get_traces().astype(np.float32).T * neuropixel.S2V_AP
si_filtered = recording_hps.get_traces().astype(np.float32).T * neuropixel.S2V_AP

ibl_filtered = run_ibl_highpass_filter(ibl_data.copy(), **options)
destripe = ibldsp.voltage.destripe(raw, fs=30_000, neuropixel_version=1)

if DEBUG:
fig, axs = plt.subplots(ncols=4)
axs[0].imshow(si_recording.get_traces(return_in_uV=True))
axs[0].set_title("SI Raw")
axs[1].imshow(ibl_data.T)
axs[1].set_title("IBL Raw")
axs[2].imshow(si_filtered)
axs[2].set_title("SI Filtered ")
axs[3].imshow(ibl_filtered)
axs[3].set_title("IBL Filtered")
from viewephys.gui import viewephys

eqc = {}
eqc["si_filtered"] = viewephys(si_filtered, fs=30_000, title="si_filtered")
eqc["ibl_filtered"] = viewephys(destripe, fs=30_000, title="ibl_filtered")

assert np.allclose(
si_filtered, ibl_filtered * 1e6, atol=1e-01, rtol=0
) # the differences are entired due to scaling on data load.
np.testing.assert_allclose(si_filtered[12:120, 300:800], destripe[12:120, 300:800], atol=1e-05, rtol=0)


@pytest.mark.parametrize("ntr_pad", [None, 0, 31])
Expand Down Expand Up @@ -140,24 +136,6 @@ def test_dtype_stability(dtype):
# ----------------------------------------------------------------------------------------------------------------------


def get_ibl_si_data():
"""
Set fixture to session to ensure origional data is not changed.
"""
import spikeglx

local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0")
ibl_recording = spikeglx.Reader(
local_path / "Noise4Sam_g0_imec0" / "Noise4Sam_g0_t0.imec0.ap.bin", ignore_warnings=True
)
ibl_data = ibl_recording.read(slice(None), slice(None), sync=False)[:, :-1].T # cut sync channel

si_recording = se.read_spikeglx(local_path, stream_id="imec0.ap")
si_recording = spre.astype(si_recording, dtype="float32")

return ibl_data, si_recording


def process_args_for_si(si_recording, lagc):
""""""
if isinstance(lagc, bool) and not lagc:
Expand Down Expand Up @@ -215,9 +193,10 @@ def run_si_highpass_filter(si_recording, ntr_pad, ntr_tap, lagc, butter_kwargs,


def run_ibl_highpass_filter(ibl_data, ntr_pad, ntr_tap, lagc, butter_kwargs):
butter_kwargs, ntr_pad, lagc = process_args_for_ibl(butter_kwargs, ntr_pad, lagc)
import ibldsp.voltage

ibl_filtered = voltage.kfilt(ibl_data, None, ntr_pad, ntr_tap, lagc, butter_kwargs).T
butter_kwargs, ntr_pad, lagc = process_args_for_ibl(butter_kwargs, ntr_pad, lagc)
ibl_filtered = ibldsp.voltage.kfilt(ibl_data, None, ntr_pad, ntr_tap, lagc, butter_kwargs).T

return ibl_filtered

Expand Down