diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index bbb6e5b9f3..7bb6546664 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -161,6 +161,16 @@ def _repr_html_(self, display_name=True): html_repr = html_header + html_segments + html_channel_ids + html_extra return html_repr + def __add__(self, other): + from .operatorrecordings import AddRecordings + + return AddRecordings(self, other) + + def __sub__(self, other): + from .operatorrecordings import SubtractRecordings + + return SubtractRecordings(self, other) + def get_num_segments(self) -> int: """ Returns the number of segments. diff --git a/src/spikeinterface/core/operatorrecordings.py b/src/spikeinterface/core/operatorrecordings.py new file mode 100644 index 0000000000..6ffb7d9fa3 --- /dev/null +++ b/src/spikeinterface/core/operatorrecordings.py @@ -0,0 +1,82 @@ +from spikeinterface.core import BaseRecording, BaseRecordingSegment +from spikeinterface.core.recording_tools import get_rec_attributes, do_recording_attributes_match + + +class BaseOperatorRecording(BaseRecording): + """Base class for operator recordings.""" + + def __init__(self, recording1, recording2, operator: str): + assert operator in ["add", "subtract"], "Operator must be 'add' or 'subtract'" + assert all( + isinstance(rec, BaseRecording) for rec in [recording1, recording2] + ), "'recordings' must be a list of RecordingExtractor" + + rec_attrs2 = get_rec_attributes(recording2) + assert do_recording_attributes_match( + recording1, rec_attrs2 + ), "Both recordings must have the same sampling frequency and channel ids" + assert self.are_times_kwargs_compatible( + recording1, recording2 + ), "Both recordings must have the same time parameters" + + channel_ids = recording1.channel_ids + sampling_frequency = recording1.sampling_frequency + dtype = recording1.get_dtype() + + BaseRecording.__init__(self, sampling_frequency, channel_ids, dtype) + + for segment1, segment2 in zip(recording1._recording_segments, recording2._recording_segments): + add_segment = OperatorRecordingSegment(segment1, segment2, operator) + self.add_recording_segment(add_segment) + + self._kwargs = dict(recording1=recording1, recording2=recording2, operator=operator) + + def are_times_kwargs_compatible(self, recording1, recording2) -> bool: + import numpy as np + + for segment_index in range(recording1.get_num_segments()): + time_kwargs1 = recording1._recording_segments[segment_index].get_times_kwargs() + time_kwargs2 = recording2._recording_segments[segment_index].get_times_kwargs() + for key in time_kwargs1.keys(): + val1 = time_kwargs1[key] + val2 = time_kwargs2[key] + if (val1 is None and val2 is not None) or (val1 is not None and val2 is None): + return False + if isinstance(val1, np.ndarray) and isinstance(val2, np.ndarray): + if not np.array_equal(val1, val2): + return False + else: + if val1 != val2: + return False + return True + + +class OperatorRecordingSegment(BaseRecordingSegment): + def __init__(self, segment1, segment2, operator: str): + BaseRecordingSegment.__init__(self, **segment1.get_times_kwargs()) + self.segment1 = segment1 + self.segment2 = segment2 + self.operator = operator + + def get_num_samples(self): + return self.segment1.get_num_samples() + + def get_traces(self, start_frame, end_frame, channel_indices): + traces1 = self.segment1.get_traces(start_frame, end_frame, channel_indices) + traces2 = self.segment2.get_traces(start_frame, end_frame, channel_indices) + if self.operator == "add": + return traces1 + traces2 + elif self.operator == "subtract": + return traces1 - traces2 + else: + raise ValueError(f"Unknown operator: {self.operator}") + + +class AddRecordings(BaseOperatorRecording): + def __init__(self, recording1, recording2): + BaseOperatorRecording.__init__(self, recording1, recording2, operator="add") + + +class SubtractRecordings(BaseOperatorRecording): + def __init__(self, recording1, recording2): + BaseOperatorRecording.__init__(self, recording1, recording2, operator="subtract") diff --git a/src/spikeinterface/core/tests/test_operator_recordings.py b/src/spikeinterface/core/tests/test_operator_recordings.py new file mode 100644 index 0000000000..d8be5e9886 --- /dev/null +++ b/src/spikeinterface/core/tests/test_operator_recordings.py @@ -0,0 +1,50 @@ +import pytest +import numpy as np + +import spikeinterface as si + + +@pytest.fixture +def recording(): + recording = si.generate_recording(durations=[10, 20], num_channels=4, sampling_frequency=20000) + return recording + + +def test_sum_recordings(recording): + rec_sum = recording + recording + for seg_index in range(rec_sum.get_num_segments()): + traces_orig = recording.get_traces(segment_index=seg_index) + traces_sum = rec_sum.get_traces(segment_index=seg_index) + np.testing.assert_array_equal(traces_sum, traces_orig * 2) + + +def test_subtract_recordings(recording): + rec_sub = recording - recording + for seg_index in range(rec_sub.get_num_segments()): + traces_sub = rec_sub.get_traces(segment_index=seg_index) + np.testing.assert_array_equal(traces_sub, np.zeros_like(traces_sub)) + + +def test_operator_combo(recording): + rec_combo = recording - recording + recording - recording + recording + for seg_index in range(rec_combo.get_num_segments()): + traces_orig = recording.get_traces(segment_index=seg_index) + traces_combo = rec_combo.get_traces(segment_index=seg_index) + np.testing.assert_array_equal(traces_combo, traces_orig) + + +def test_errors(recording): + recording2 = si.generate_recording(durations=[10, 20], num_channels=4, sampling_frequency=10000) + with pytest.raises(AssertionError): + _ = recording + recording2 + with pytest.raises(AssertionError): + _ = recording - recording2 + + recording_times = recording.clone() + for segment_index in range(recording_times.get_num_segments()): + recording_times.set_times( + recording_times.get_times(segment_index=segment_index) + (segment_index + 1) * 5, + segment_index=segment_index, + ) + with pytest.raises(AssertionError): + _ = recording + recording_times