Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,16 @@ def _repr_html_(self, display_name=True):
html_repr = html_header + html_segments + html_channel_ids + html_extra
return html_repr

def __add__(self, other):
from .operatorrecordings import AddRecordings

return AddRecordings(self, other)

def __sub__(self, other):
from .operatorrecordings import SubtractRecordings

return SubtractRecordings(self, other)

def get_num_segments(self) -> int:
"""
Returns the number of segments.
Expand Down
60 changes: 60 additions & 0 deletions src/spikeinterface/core/operatorrecordings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from spikeinterface.core import BaseRecording, BaseRecordingSegment
from spikeinterface.core.recording_tools import get_rec_attributes, do_recording_attributes_match


class BaseOperatorRecording(BaseRecording):
"""Base class for operator recordings."""

def __init__(self, recording1, recording2, operator: str):
assert operator in ["add", "subtract"], "Operator must be 'add' or 'subtract'"
assert all(
isinstance(rec, BaseRecording) for rec in [recording1, recording2]
), "'recordings' must be a list of RecordingExtractor"

rec_attrs2 = get_rec_attributes(recording2)
assert do_recording_attributes_match(
recording1, rec_attrs2
), "Both recordings must have the same sampling frequency and channel ids"

channel_ids = recording1.channel_ids
sampling_frequency = recording1.sampling_frequency
dtype = recording1.get_dtype()

BaseRecording.__init__(self, sampling_frequency, channel_ids, dtype)

for segment1, segment2 in zip(recording1._recording_segments, recording2._recording_segments):
add_segment = OperatorRecordingSegment(segment1, segment2, operator)
self.add_recording_segment(add_segment)

self._kwargs = dict(recording1=recording1, recording2=recording2, operator=operator)


class OperatorRecordingSegment(BaseRecordingSegment):
def __init__(self, segment1, segment2, operator: str):
BaseRecordingSegment.__init__(self, **segment1.get_times_kwargs())
self.segment1 = segment1
self.segment2 = segment2
self.operator = operator

def get_num_samples(self):
return self.segment1.get_num_samples()

def get_traces(self, start_frame, end_frame, channel_indices):
traces1 = self.segment1.get_traces(start_frame, end_frame, channel_indices)
traces2 = self.segment2.get_traces(start_frame, end_frame, channel_indices)
if self.operator == "add":
return traces1 + traces2
elif self.operator == "subtract":
return traces1 - traces2
else:
raise ValueError(f"Unknown operator: {self.operator}")


class AddRecordings(BaseOperatorRecording):
def __init__(self, recording1, recording2):
BaseOperatorRecording.__init__(self, recording1, recording2, operator="add")


class SubtractRecordings(BaseOperatorRecording):
def __init__(self, recording1, recording2):
BaseOperatorRecording.__init__(self, recording1, recording2, operator="subtract")
33 changes: 33 additions & 0 deletions src/spikeinterface/core/tests/test_operator_recordings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest
import numpy as np

import spikeinterface as si


@pytest.fixture
def recording():
recording = si.generate_recording(durations=[10, 20], num_channels=4, sampling_frequency=20000)
return recording


def test_sum_recordings(recording):
rec_sum = recording + recording
for seg_index in range(rec_sum.get_num_segments()):
traces_orig = recording.get_traces(segment_index=seg_index)
traces_sum = rec_sum.get_traces(segment_index=seg_index)
np.testing.assert_array_equal(traces_sum, traces_orig * 2)


def test_subtract_recordings(recording):
rec_sub = recording - recording
for seg_index in range(rec_sub.get_num_segments()):
traces_sub = rec_sub.get_traces(segment_index=seg_index)
np.testing.assert_array_equal(traces_sub, np.zeros_like(traces_sub))


def test_operator_combo(recording):
rec_combo = recording - recording + recording - recording + recording
for seg_index in range(rec_combo.get_num_segments()):
traces_orig = recording.get_traces(segment_index=seg_index)
traces_combo = rec_combo.get_traces(segment_index=seg_index)
np.testing.assert_array_equal(traces_combo, traces_orig)