From 49d0698788db452524b3e9d3a0c7f427c08302b2 Mon Sep 17 00:00:00 2001 From: JulesLebert Date: Fri, 14 Apr 2023 15:16:55 +0100 Subject: [PATCH 1/4] allow to replace saturation with a window --- src/spikeinterface/preprocessing/clip.py | 130 ++++++++++++++++++++++- 1 file changed, 125 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/preprocessing/clip.py b/src/spikeinterface/preprocessing/clip.py index 909c0c9171..cc0ab9a475 100644 --- a/src/spikeinterface/preprocessing/clip.py +++ b/src/spikeinterface/preprocessing/clip.py @@ -1,5 +1,11 @@ import numpy as np +try: + from numba import njit, guvectorize, float64 + HAVE_NUMBA = True +except ModuleNotFoundError as err: + HAVE_NUMBA = False + from spikeinterface.core.core_tools import define_function_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment @@ -88,6 +94,7 @@ class BlankSaturationRecording(BasePreprocessor): def __init__(self, recording, abs_threshold=None, quantile_threshold=None, direction='upper', fill_value=None, + ms_before=0, ms_after=0, num_chunks_per_segment=50, chunk_size=500, seed=0): assert direction in ('upper', 'lower', 'both') @@ -130,36 +137,149 @@ def __init__(self, recording, abs_threshold=None, quantile_threshold=None, BasePreprocessor.__init__(self, recording) for parent_segment in recording._recording_segments: rec_segment = ClipRecordingSegment( - parent_segment, a_min, value_min, a_max, value_max) + parent_segment, a_min, value_min, a_max, value_max, + ms_before=ms_before, ms_after=ms_after + ) self.add_recording_segment(rec_segment) - self._kwargs = dict(recording=recording, abs_threshold=abs_threshold, + self._kwargs = dict(recording=recording, abs_threshold=abs_threshold, ms_before=ms_before, ms_after=ms_after, quantile_threshold=quantile_threshold, direction=direction, fill_value=fill_value, num_chunks_per_segment=num_chunks_per_segment, chunk_size=chunk_size, seed=seed) class ClipRecordingSegment(BasePreprocessorSegment): - def __init__(self, parent_recording_segment, a_min, value_min, a_max, value_max): + def __init__(self, parent_recording_segment, a_min, value_min, a_max, value_max, + ms_before=0, ms_after=0): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.a_min = a_min self.value_min = value_min self.a_max = a_max self.value_max = value_max + self.ms_before = ms_before + self.ms_after = ms_after + def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces( start_frame, end_frame, channel_indices) traces = traces.copy() + fs = self.parent_recording_segment.sampling_frequency + + frames_before = int(self.ms_before * fs // 1000) + frames_after = int(self.ms_after * fs // 1000) if self.a_min is not None: - traces[traces <= self.a_min] = self.value_min + traces = replace_slice_min(traces, self.a_min, frames_before, frames_after, self.value_min) + if self.a_max is not None: - traces[traces >= self.a_max] = self.value_max + traces = replace_slice_max(traces, self.a_max, frames_before, frames_after, self.value_max) return traces +def replace_slice_min(traces, a_min, frames_before, frames_after, value_min): + print('numba: ', HAVE_NUMBA) + if HAVE_NUMBA: + return _replace_slice_min_numba(traces, a_min, frames_before, frames_after, value_min) + else: + return _replace_slice_min_for_loop(traces, a_min, frames_before, frames_after, value_min) + +def replace_slice_max(traces, a_max, frames_before, frames_after, value_max): + print('numba: ', HAVE_NUMBA) + if HAVE_NUMBA: + return _replace_slice_max_numba(traces, a_max, frames_before, frames_after, value_max) + else: + return _replace_slice_max_for_loop(traces, a_max, frames_before, frames_after, value_max) + +# For loops +def _replace_slice_min_for_loop(traces, a_min, frames_before, frames_after, value_min): + min_indices, channels = np.where(traces <= a_min) + for index, chan in zip(min_indices, channels): + traces[max(0, index - frames_before):min(len(traces), index + frames_after + 1), chan] = value_min + return traces + +def _replace_slice_max_for_loop(traces, a_max, frames_before, frames_after, value_max): + max_indices, channels = np.where(traces >= a_max) + for index, chan in zip(max_indices, channels): + traces[max(0, index - frames_before):min(len(traces), index + frames_after + 1), chan] = value_max + return traces + +if HAVE_NUMBA: + # Numba + # @njit(cache=True) + # def _replace_slice_min_numba(traces, res, a_min, frames_before, frames_after, value_min): + # m, n = traces.shape + # for i in range(m): + # for j in range(n): + # if traces[i, j] <= a_min: + # res[max(0, i - frames_before):min(m, i + frames_after + 1), j] = value_min + # return res + + # @njit(cache=True) + # def _replace_slice_max_numba(traces, res, a_max, frames_before, frames_after, value_max): + # m, n = traces.shape + # for i in range(m): + # for j in range(n): + # if traces[i, j] >= a_max: + # res[max(0, i - frames_before):min(m, i + frames_after + 1), j] = value_max + # return res + + @njit(cache=True) + def _replace_slice_max_numba(traces, a_max, frames_before, frames_after, value_max): + m, n = traces.shape + to_clear = np.zeros(m, dtype=np.bool_) + for j in range(n): + to_clear[:] = False + for i in range(m): + if traces[i, j] >= a_max: + to_clear[ + max(0, i - frames_before) : min(m, i + frames_after + 1) + ] = True + for i in range(m): + if to_clear[i]: + traces[i, j] = value_max + return traces + + @njit(cache=True) + def _replace_slice_min_numba(traces, a_min, frames_before, frames_after, value_min): + m, n = traces.shape + to_clear = np.zeros(m, dtype=np.bool_) + for j in range(n): + to_clear[:] = False + for i in range(m): + if traces[i, j] <= a_min: + to_clear[ + max(0, i - frames_before) : min(m, i + frames_after + 1) + ] = True + for i in range(m): + if to_clear[i]: + traces[i, j] = value_min + return traces + # @guvectorize( + # [(float64[:], float64, float64, float64, float64, float64[:])], + # "(n),(),(),(),()->(n)", cache=True, nopython=True + # ) + # def _replace_slice_max_numba(traces, a_max, frames_before, frames_after, value_max, res): + # m = traces.shape[0] + # res[:] = traces + # for i in range(m): + # if traces[i] >= a_max: + # res[max(0, i - frames_before) : min(m, i + frames_after + 1)] = value_max + # return + + # @guvectorize( + # [(float64[:], float64, float64, float64, float64, float64[:])], + # "(n),(),(),(),()->(n)", cache=True, nopython=True + # ) + # def _replace_slice_min_numba(traces, a_min, frames_before, frames_after, value_min, res): + # m = traces.shape[0] + # res[:] = traces + # for i in range(m): + # if traces[i] <= a_min: + # res[max(0, i - frames_before) : min(m, i + frames_after + 1)] = value_min + # return + clip = define_function_from_class(source_class=ClipRecording, name="clip") blank_staturation = define_function_from_class(source_class=BlankSaturationRecording, name="blank_staturation") From 12e8e919921b53a1a1340521a0ef98436499a12d Mon Sep 17 00:00:00 2001 From: JulesLebert Date: Fri, 14 Apr 2023 15:25:49 +0100 Subject: [PATCH 2/4] clean code --- src/spikeinterface/preprocessing/clip.py | 50 +++--------------------- 1 file changed, 5 insertions(+), 45 deletions(-) diff --git a/src/spikeinterface/preprocessing/clip.py b/src/spikeinterface/preprocessing/clip.py index cc0ab9a475..7699c116ed 100644 --- a/src/spikeinterface/preprocessing/clip.py +++ b/src/spikeinterface/preprocessing/clip.py @@ -1,7 +1,7 @@ import numpy as np try: - from numba import njit, guvectorize, float64 + from numba import njit HAVE_NUMBA = True except ModuleNotFoundError as err: HAVE_NUMBA = False @@ -77,6 +77,10 @@ class BlankSaturationRecording(BasePreprocessor): fill_value: float or None The value to write instead of the saturating signal. If None, then the value is automatically computed as the median signal value + ms_before: float (default 0) + Time (ms) to replace before the saturation signal + ms_after: float (default 0) + Time (ms) to replace after the saturation signal num_chunks_per_segment: int (default 50) The number of chunks per segments to consider to estimate the threshold/fill_values chunk_size: int (default 500) @@ -179,14 +183,12 @@ def get_traces(self, start_frame, end_frame, channel_indices): return traces def replace_slice_min(traces, a_min, frames_before, frames_after, value_min): - print('numba: ', HAVE_NUMBA) if HAVE_NUMBA: return _replace_slice_min_numba(traces, a_min, frames_before, frames_after, value_min) else: return _replace_slice_min_for_loop(traces, a_min, frames_before, frames_after, value_min) def replace_slice_max(traces, a_max, frames_before, frames_after, value_max): - print('numba: ', HAVE_NUMBA) if HAVE_NUMBA: return _replace_slice_max_numba(traces, a_max, frames_before, frames_after, value_max) else: @@ -207,24 +209,6 @@ def _replace_slice_max_for_loop(traces, a_max, frames_before, frames_after, valu if HAVE_NUMBA: # Numba - # @njit(cache=True) - # def _replace_slice_min_numba(traces, res, a_min, frames_before, frames_after, value_min): - # m, n = traces.shape - # for i in range(m): - # for j in range(n): - # if traces[i, j] <= a_min: - # res[max(0, i - frames_before):min(m, i + frames_after + 1), j] = value_min - # return res - - # @njit(cache=True) - # def _replace_slice_max_numba(traces, res, a_max, frames_before, frames_after, value_max): - # m, n = traces.shape - # for i in range(m): - # for j in range(n): - # if traces[i, j] >= a_max: - # res[max(0, i - frames_before):min(m, i + frames_after + 1), j] = value_max - # return res - @njit(cache=True) def _replace_slice_max_numba(traces, a_max, frames_before, frames_after, value_max): m, n = traces.shape @@ -256,30 +240,6 @@ def _replace_slice_min_numba(traces, a_min, frames_before, frames_after, value_m if to_clear[i]: traces[i, j] = value_min return traces - - # @guvectorize( - # [(float64[:], float64, float64, float64, float64, float64[:])], - # "(n),(),(),(),()->(n)", cache=True, nopython=True - # ) - # def _replace_slice_max_numba(traces, a_max, frames_before, frames_after, value_max, res): - # m = traces.shape[0] - # res[:] = traces - # for i in range(m): - # if traces[i] >= a_max: - # res[max(0, i - frames_before) : min(m, i + frames_after + 1)] = value_max - # return - - # @guvectorize( - # [(float64[:], float64, float64, float64, float64, float64[:])], - # "(n),(),(),(),()->(n)", cache=True, nopython=True - # ) - # def _replace_slice_min_numba(traces, a_min, frames_before, frames_after, value_min, res): - # m = traces.shape[0] - # res[:] = traces - # for i in range(m): - # if traces[i] <= a_min: - # res[max(0, i - frames_before) : min(m, i + frames_after + 1)] = value_min - # return clip = define_function_from_class(source_class=ClipRecording, name="clip") blank_staturation = define_function_from_class(source_class=BlankSaturationRecording, name="blank_staturation") From 2852da99f2ae34c81d8f9879dd1ed1de9d40aabf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 May 2023 10:32:10 +0000 Subject: [PATCH 3/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/__init__.py | 1 + src/spikeinterface/preprocessing/clip.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/__init__.py b/src/spikeinterface/__init__.py index 5f2086ede7..99325d7d65 100644 --- a/src/spikeinterface/__init__.py +++ b/src/spikeinterface/__init__.py @@ -9,6 +9,7 @@ from .core import * import warnings + warnings.filterwarnings("ignore", message="distutils Version classes are deprecated") warnings.filterwarnings("ignore", message="the imp module is deprecated") diff --git a/src/spikeinterface/preprocessing/clip.py b/src/spikeinterface/preprocessing/clip.py index 6ac4d16451..f13f988bdc 100644 --- a/src/spikeinterface/preprocessing/clip.py +++ b/src/spikeinterface/preprocessing/clip.py @@ -205,7 +205,7 @@ def replace_slice_max(traces, a_max, frames_before, frames_after, value_max): return _replace_slice_max_numba(traces, a_max, frames_before, frames_after, value_max) else: return _replace_slice_max_for_loop(traces, a_max, frames_before, frames_after, value_max) - + # For loops def _replace_slice_min_for_loop(traces, a_min, frames_before, frames_after, value_min): min_indices, channels = np.where(traces <= a_min) @@ -252,6 +252,6 @@ def _replace_slice_min_numba(traces, a_min, frames_before, frames_after, value_m if to_clear[i]: traces[i, j] = value_min return traces - + clip = define_function_from_class(source_class=ClipRecording, name="clip") blank_staturation = define_function_from_class(source_class=BlankSaturationRecording, name="blank_staturation") From dc8d49dbe0f964c30285e103da63987322be6b17 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 23 May 2023 12:38:15 +0200 Subject: [PATCH 4/4] Update src/spikeinterface/preprocessing/clip.py --- src/spikeinterface/preprocessing/clip.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/clip.py b/src/spikeinterface/preprocessing/clip.py index f13f988bdc..e48d60ecc8 100644 --- a/src/spikeinterface/preprocessing/clip.py +++ b/src/spikeinterface/preprocessing/clip.py @@ -101,7 +101,6 @@ def __init__(self, recording, abs_threshold=None, quantile_threshold=None, num_chunks_per_segment=50, chunk_size=500, seed=0): - name = "blank_staturation" def __init__( self,