Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/spikeinterface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .core import *

import warnings

warnings.filterwarnings("ignore", message="distutils Version classes are deprecated")
warnings.filterwarnings("ignore", message="the imp module is deprecated")

Expand Down
110 changes: 96 additions & 14 deletions src/spikeinterface/preprocessing/clip.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import numpy as np

try:
from numba import njit
HAVE_NUMBA = True
except ModuleNotFoundError as err:
HAVE_NUMBA = False

from spikeinterface.core.core_tools import define_function_from_class
from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment

Expand Down Expand Up @@ -70,6 +76,10 @@ class BlankSaturationRecording(BasePreprocessor):
fill_value: float or None
The value to write instead of the saturating signal. If None, then the value is
automatically computed as the median signal value
ms_before: float (default 0)
Time (ms) to replace before the saturation signal
ms_after: float (default 0)
Time (ms) to replace after the saturation signal
num_chunks_per_segment: int (default 50)
The number of chunks per segments to consider to estimate the threshold/fill_values
chunk_size: int (default 500)
Expand All @@ -83,6 +93,13 @@ class BlankSaturationRecording(BasePreprocessor):
The filtered traces recording extractor object

"""
name = 'blank_staturation'

def __init__(self, recording, abs_threshold=None, quantile_threshold=None,
direction='upper', fill_value=None,
ms_before=0, ms_after=0,
num_chunks_per_segment=50, chunk_size=500, seed=0):


name = "blank_staturation"

Expand Down Expand Up @@ -135,41 +152,106 @@ def __init__(

BasePreprocessor.__init__(self, recording)
for parent_segment in recording._recording_segments:
rec_segment = ClipRecordingSegment(parent_segment, a_min, value_min, a_max, value_max)
rec_segment = ClipRecordingSegment(
parent_segment, a_min, value_min, a_max, value_max,
ms_before=ms_before, ms_after=ms_after
)
self.add_recording_segment(rec_segment)

self._kwargs = dict(
recording=recording,
abs_threshold=abs_threshold,
quantile_threshold=quantile_threshold,
direction=direction,
fill_value=fill_value,
num_chunks_per_segment=num_chunks_per_segment,
chunk_size=chunk_size,
seed=seed,
)
self._kwargs = dict(recording=recording, abs_threshold=abs_threshold, ms_before=ms_before, ms_after=ms_after,
quantile_threshold=quantile_threshold, direction=direction, fill_value=fill_value,
num_chunks_per_segment=num_chunks_per_segment, chunk_size=chunk_size,
seed=seed)



class ClipRecordingSegment(BasePreprocessorSegment):
def __init__(self, parent_recording_segment, a_min, value_min, a_max, value_max):
def __init__(self, parent_recording_segment, a_min, value_min, a_max, value_max,
ms_before=0, ms_after=0):
BasePreprocessorSegment.__init__(self, parent_recording_segment)

self.a_min = a_min
self.value_min = value_min
self.a_max = a_max
self.value_max = value_max
self.ms_before = ms_before
self.ms_after = ms_after


def get_traces(self, start_frame, end_frame, channel_indices):
traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices)
traces = traces.copy()
fs = self.parent_recording_segment.sampling_frequency

frames_before = int(self.ms_before * fs // 1000)
frames_after = int(self.ms_after * fs // 1000)

if self.a_min is not None:
traces[traces <= self.a_min] = self.value_min
traces = replace_slice_min(traces, self.a_min, frames_before, frames_after, self.value_min)

if self.a_max is not None:
traces[traces >= self.a_max] = self.value_max
traces = replace_slice_max(traces, self.a_max, frames_before, frames_after, self.value_max)

return traces

def replace_slice_min(traces, a_min, frames_before, frames_after, value_min):
if HAVE_NUMBA:
return _replace_slice_min_numba(traces, a_min, frames_before, frames_after, value_min)
else:
return _replace_slice_min_for_loop(traces, a_min, frames_before, frames_after, value_min)

def replace_slice_max(traces, a_max, frames_before, frames_after, value_max):
if HAVE_NUMBA:
return _replace_slice_max_numba(traces, a_max, frames_before, frames_after, value_max)
else:
return _replace_slice_max_for_loop(traces, a_max, frames_before, frames_after, value_max)

# For loops
def _replace_slice_min_for_loop(traces, a_min, frames_before, frames_after, value_min):
min_indices, channels = np.where(traces <= a_min)
for index, chan in zip(min_indices, channels):
traces[max(0, index - frames_before):min(len(traces), index + frames_after + 1), chan] = value_min
return traces

def _replace_slice_max_for_loop(traces, a_max, frames_before, frames_after, value_max):
max_indices, channels = np.where(traces >= a_max)
for index, chan in zip(max_indices, channels):
traces[max(0, index - frames_before):min(len(traces), index + frames_after + 1), chan] = value_max
return traces

if HAVE_NUMBA:
# Numba
@njit(cache=True)
def _replace_slice_max_numba(traces, a_max, frames_before, frames_after, value_max):
m, n = traces.shape
to_clear = np.zeros(m, dtype=np.bool_)
for j in range(n):
to_clear[:] = False
for i in range(m):
if traces[i, j] >= a_max:
to_clear[
max(0, i - frames_before) : min(m, i + frames_after + 1)
] = True
for i in range(m):
if to_clear[i]:
traces[i, j] = value_max
return traces

@njit(cache=True)
def _replace_slice_min_numba(traces, a_min, frames_before, frames_after, value_min):
m, n = traces.shape
to_clear = np.zeros(m, dtype=np.bool_)
for j in range(n):
to_clear[:] = False
for i in range(m):
if traces[i, j] <= a_min:
to_clear[
max(0, i - frames_before) : min(m, i + frames_after + 1)
] = True
for i in range(m):
if to_clear[i]:
traces[i, j] = value_min
return traces

clip = define_function_from_class(source_class=ClipRecording, name="clip")
blank_staturation = define_function_from_class(source_class=BlankSaturationRecording, name="blank_staturation")