-
Notifications
You must be signed in to change notification settings - Fork 241
Fix bug spatial filter #4175 #4286
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
212b93c
b03d044
ca25b61
dea1b5f
1a673c4
f208481
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -207,6 +207,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): | |
| traces = traces.copy() | ||
|
|
||
| # apply AGC and keep the gains | ||
| traces = traces.astype(np.float32) | ||
| if self.window is not None: | ||
| traces, agc_gains = agc(traces, window=self.window) | ||
| else: | ||
|
|
@@ -255,7 +256,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): | |
| # ----------------------------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| def agc(traces, window, epsilon=1e-8): | ||
| def agc(traces, window, epsilon=None): | ||
|
||
| """ | ||
| Automatic gain control | ||
| w_agc, gain = agc(w, window_length=.5, si=.002, epsilon=1e-8) | ||
|
|
@@ -268,13 +269,15 @@ def agc(traces, window, epsilon=1e-8): | |
| """ | ||
| import scipy.signal | ||
|
|
||
| gain = scipy.signal.fftconvolve(np.abs(traces), window[:, None], mode="same", axes=0) | ||
| # default value for epsilon is relative to the rms, loosely matching the IBL 1e-8 for an input in Volts | ||
| if epsilon is None: | ||
| epsilon = np.std(traces - np.mean(traces)) * 0.003 | ||
alejoe91 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| gain += (np.sum(gain, axis=0) * epsilon / traces.shape[0])[np.newaxis, :] | ||
| gain = scipy.signal.fftconvolve(np.abs(traces), window[:, None], mode="same", axes=0) | ||
|
|
||
| dead_channels = np.sum(gain, axis=0) == 0 | ||
|
|
||
| traces[:, ~dead_channels] = traces[:, ~dead_channels] / gain[:, ~dead_channels] | ||
| traces[:, ~dead_channels] = traces[:, ~dead_channels] / np.maximum(epsilon, gain[:, ~dead_channels]) | ||
|
|
||
| return traces, gain | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -3,7 +3,7 @@ | |||||
| import numpy as np | ||||||
| from copy import deepcopy | ||||||
|
|
||||||
| import spikeinterface as si | ||||||
| import spikeinterface.full as si | ||||||
oliche marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||
| import spikeinterface.preprocessing as spre | ||||||
| import spikeinterface.extractors as se | ||||||
| from spikeinterface.core import generate_recording | ||||||
|
|
@@ -24,7 +24,7 @@ | |||||
|
|
||||||
|
|
||||||
| @pytest.mark.skipif( | ||||||
| importlib.util.find_spec("neurodsp") is None or importlib.util.find_spec("spikeglx") is None or ON_GITHUB, | ||||||
| importlib.util.find_spec("ibldsp") is None or importlib.util.find_spec("spikeglx") is None or ON_GITHUB, | ||||||
| reason="Only local. Requires ibl-neuropixel install", | ||||||
| ) | ||||||
| @pytest.mark.parametrize("lagc", [False, 1, 300]) | ||||||
|
|
@@ -51,32 +51,28 @@ def test_highpass_spatial_filter_real_data(lagc): | |||||
| use DEBUG = true to visualise. | ||||||
|
|
||||||
| """ | ||||||
| import spikeglx | ||||||
| import neurodsp.voltage as voltage | ||||||
| import ibldsp.voltage | ||||||
| import neuropixel | ||||||
|
|
||||||
| options = dict(lagc=lagc, ntr_pad=25, ntr_tap=50, butter_kwargs=None) | ||||||
| print(options) | ||||||
|
|
||||||
| ibl_data, si_recording = get_ibl_si_data() | ||||||
|
|
||||||
| si_filtered, _ = run_si_highpass_filter(si_recording, **options) | ||||||
| local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0") | ||||||
| si_recording = se.read_spikeglx(local_path, stream_id="imec0.ap") | ||||||
| si_recording = spre.astype(si_recording, "float") | ||||||
| recording_ps = si.phase_shift(si_recording) | ||||||
oliche marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||
| recording_hp = si.highpass_filter(recording_ps, freq_min=300, filter_order=3) | ||||||
|
||||||
| recording_hp = si.highpass_filter(recording_ps, freq_min=300, filter_order=3) | |
| recording_hp = spre.highpass_filter(recording_ps, freq_min=300, filter_order=3) |
Uh oh!
There was an error while loading. Please reload this page.