diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 17a55adf5f..349e09eb2f 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -288,9 +288,11 @@ def test_cluster_spikes_arguments(self): self._check_arguments(cluster_spikes, expected_arguments) def test_save_sorting_arguments(self): - expected_arguments = ["ops", "results_dir", "st", "clu", "tF", "Wall", "imin", "tic0", "save_extra_vars"] - - expected_arguments.append("save_preprocessed_copy") + expected_arguments = [ + "ops", "results_dir", "st", "clu", "tF", "Wall", "imin", "tic0", "save_extra_vars", "save_preprocessed_copy" + ] + if parse(kilosort.__version__) >= parse("4.0.39"): + expected_arguments.append("skip_dat_path") self._check_arguments(save_sorting, expected_arguments) diff --git a/.github/workflows/all-tests.yml b/.github/workflows/all-tests.yml index 088c703e64..9481bb74f2 100644 --- a/.github/workflows/all-tests.yml +++ b/.github/workflows/all-tests.yml @@ -11,6 +11,7 @@ on: env: KACHERY_API_KEY: ${{ secrets.KACHERY_API_KEY }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} concurrency: # Cancel previous workflows on the same pull request group: ${{ github.workflow }}-${{ github.ref }} diff --git a/.github/workflows/full-test-with-codecov.yml b/.github/workflows/full-test-with-codecov.yml index 9d56be8498..03f47bbef1 100644 --- a/.github/workflows/full-test-with-codecov.yml +++ b/.github/workflows/full-test-with-codecov.yml @@ -7,6 +7,7 @@ on: env: KACHERY_API_KEY: ${{ secrets.KACHERY_API_KEY }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} jobs: full-tests-with-codecov: diff --git a/.github/workflows/test_kilosort4.yml b/.github/workflows/test_kilosort4.yml index 42e6140917..10ec9532fe 100644 --- a/.github/workflows/test_kilosort4.yml +++ b/.github/workflows/test_kilosort4.yml @@ -7,6 +7,7 @@ on: pull_request: paths: - '**/kilosort4.py' + - '**/test_kilosort4_ci.py' jobs: versions: diff --git a/doc/api.rst b/doc/api.rst index bc685eee2a..b97df80ec1 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -76,7 +76,7 @@ Low-level .. autoclass:: ChunkRecordingExecutor -Back-compatibility with ``WaveformExtractor`` (version < 0.101.0) +Back-compatibility with ``WaveformExtractor`` (version > 0.100.0) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: spikeinterface.core @@ -179,6 +179,8 @@ spikeinterface.preprocessing .. autofunction:: correct_motion .. autofunction:: get_motion_presets .. autofunction:: get_motion_parameters_preset + .. autofunction:: load_motion_info + .. autofunction:: save_motion_info .. autofunction:: depth_order .. autofunction:: detect_bad_channels .. autofunction:: directional_derivative diff --git a/doc/get_started/import.rst b/doc/get_started/import.rst index be3f7d5afb..30e841bdd8 100644 --- a/doc/get_started/import.rst +++ b/doc/get_started/import.rst @@ -73,7 +73,7 @@ For example: .. code-block:: python from spikeinterface.preprocessing import bandpass_filter, common_reference - from spikeinterface.core import extract_waveforms + from spikeinterface.core import create_sorting_analyzer from spikeinterface.extractors import read_binary As mentioned this approach only imports exactly what you plan on using so it is the most minimalist. 
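A hedged sketch of this minimal-import style in practice (the file name and parameters below are placeholders, not part of the tutorial data):

.. code-block:: python

    from spikeinterface.preprocessing import bandpass_filter, common_reference
    from spikeinterface.extractors import read_binary

    # hypothetical binary file, for illustration only
    recording = read_binary("recording.bin", sampling_frequency=30000, dtype="int16", num_channels=64)
    recording = common_reference(bandpass_filter(recording, freq_min=300.0, freq_max=6000.0))
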
It does require diff --git a/doc/get_started/install_sorters.rst b/doc/get_started/install_sorters.rst index 61abc0f129..82bef42c14 100644 --- a/doc/get_started/install_sorters.rst +++ b/doc/get_started/install_sorters.rst @@ -12,7 +12,7 @@ and in many cases the easiest way to run them is to do so via Docker or Singular **This is the approach we recommend for all users.** To run containerized sorters see our documentation here: :ref:`containerizedsorters`. -There are some cases where users will need to install the spike sorting algorithms in their own environment. If you +There are some cases where you will need to install the spike sorting algorithms on your own computer. If you are on a system where it is infeasible to run Docker or Singularity containers, or if you are actively developing the spike sorting software, you will likely need to install each spike sorter yourself. @@ -24,7 +24,7 @@ opencl (Tridesclous) to use hardware acceleration (GPU). Here is a list of the implemented wrappers and some instructions to install them on your local machine. Installation instructions are given for an **Ubuntu** platform. Please check the documentation of the different spike sorters to retrieve installation instructions for other operating systems. -We use **pip** to install packages, but **conda** should also work in many cases. +We use **pip** to install packages, but **conda** or **uv** should also work in many cases. Some novel spike sorting algorithms are implemented directly in SpikeInterface using the :py:mod:`spikeinterface.sortingcomponents` module. Checkout the :ref:`get_started/install_sorters:SpikeInterface-based spike sorters` section of this page @@ -140,10 +140,12 @@ Kilosort4 * Python, requires CUDA for GPU acceleration (highly recommended) * Url: https://github.com/MouseLand/Kilosort -* Authors: Marius Pachitariu, Shashwat Sridhar, Carsen Stringer +* Authors: Marius Pachitariu, Shashwat Sridhar, Carsen Stringer, Jacob Pennington * Installation:: - pip install kilosort==4.0 torch + pip install kilosort + pip uninstall torch + pip install torch --index-url https://download.pytorch.org/whl/cu118 * For more installation instruction refer to https://github.com/MouseLand/Kilosort @@ -240,7 +242,7 @@ Waveclus * Also supports Snippets (waveform cutouts) objects (:py:class:`~spikeinterface.core.BaseSnippets`) * Url: https://github.com/csn-le/wave_clus/wiki * Authors: Fernando Chaure, Hernan Rey and Rodrigo Quian Quiroga -* Installation needs Matlab:: +* Installation requires Matlab:: git clone https://github.com/csn-le/wave_clus/ # provide installation path by setting the WAVECLUS_PATH environment variable @@ -270,7 +272,7 @@ with SpikeInterface. SpykingCircus2 ^^^^^^^^^^^^^^ -This is a upgraded version of SpykingCircus, natively written in SpikeInterface. +This is an upgraded version of SpykingCircus, natively written in SpikeInterface. The main differences are located in the clustering (now using on-the-fly features and less prone to finding noise clusters), and in the template-matching procedure, which is now a fully orthogonal matching pursuit, working not only at peak times but at all times, recovering more spikes close to noise thresholds. @@ -289,7 +291,7 @@ Tridesclous2 ^^^^^^^^^^^^ This is an upgraded version of Tridesclous, natively written in SpikeInterface. -#Same add his notes. 
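(Aside on the Kilosort4 installation above: after swapping in the CUDA build of PyTorch it is worth confirming that the GPU is visible before launching the sorter. A minimal check, assuming the same environment:)

.. code-block:: python

    # quick, hedged check that the CUDA-enabled torch build is usable
    import torch

    print(torch.__version__)
    print("CUDA available:", torch.cuda.is_available())
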
+ * Python * Requires: HDBSCAN and Numba @@ -314,7 +316,7 @@ Klusta (LEGACY) * Authors: Cyrille Rossant, Shabnam Kadir, Dan Goodman, Max Hunter, Kenneth Harris * Installation:: - pip install Cython h5py tqdm + pip install cython h5py tqdm pip install click klusta klustakwik2 * See also: https://github.com/kwikteam/phy diff --git a/doc/get_started/installation.rst b/doc/get_started/installation.rst index 182ce67b94..d432503098 100644 --- a/doc/get_started/installation.rst +++ b/doc/get_started/installation.rst @@ -81,14 +81,16 @@ Requirements * numpy * probeinterface - * neo>=0.9.0 - * joblib + * neo * threadpoolctl * tqdm + * zarr + * pydantic + * numcodecs + * packaging Sub-modules have more dependencies, so you should also install: - * zarr * h5py * scipy * pandas @@ -98,8 +100,13 @@ Sub-modules have more dependencies, so you should also install: * matplotlib * numba * distinctipy + * skops + * huggingface_hub * cuda-python (for non-macOS users) +For developers we offer a :code:`[dev]` option which installs testing, documentation, and linting packages necessary +for testing and building the docs. + All external spike sorters can be either run inside containers (Docker or Singularity - see :ref:`containerizedsorters`) or must be installed independently (see :ref:`get_started/install_sorters:Installing Spike Sorters`). diff --git a/doc/get_started/quickstart.rst b/doc/get_started/quickstart.rst index bfa4335ac1..1d532c9387 100644 --- a/doc/get_started/quickstart.rst +++ b/doc/get_started/quickstart.rst @@ -336,7 +336,7 @@ Alternatively we can pass a full dictionary containing the parameters: # parameters set by params dictionary sorting_TDC_2 = ss.run_sorter( - sorter_name="tridesclous", recording=recording_preprocessed, output_folder="tdc_output2", **other_params + sorter_name="tridesclous", recording=recording_preprocessed, folder="tdc_output2", **other_params ) print(sorting_TDC_2) diff --git a/doc/how_to/analyze_neuropixels.rst b/doc/how_to/analyze_neuropixels.rst index 1fe741ea48..04a9736b80 100644 --- a/doc/how_to/analyze_neuropixels.rst +++ b/doc/how_to/analyze_neuropixels.rst @@ -567,7 +567,7 @@ In this example: # run kilosort2.5 without drift correction params_kilosort2_5 = {'do_correction': False} - sorting = si.run_sorter('kilosort2_5', rec, output_folder=base_folder / 'kilosort2.5_output', + sorting = si.run_sorter('kilosort2_5', rec, folder=base_folder / 'kilosort2.5_output', docker_image=True, verbose=True, **params_kilosort2_5) .. code:: ipython3 diff --git a/doc/how_to/handle_drift.rst b/doc/how_to/handle_drift.rst index 86c146fd02..f1d4dc0a0e 100755 --- a/doc/how_to/handle_drift.rst +++ b/doc/how_to/handle_drift.rst @@ -1322,8 +1322,6 @@ A preset is a nested dict that contains theses methods/parameters. Run motion correction with one function! ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Correcting for drift is easy! You just need to run a single function. We -will try this function with some presets. Here we also save the motion correction results into a folder to be able to load them later. diff --git a/doc/how_to/process_by_channel_group.rst b/doc/how_to/process_by_channel_group.rst index 334f83b247..9ff8215aba 100644 --- a/doc/how_to/process_by_channel_group.rst +++ b/doc/how_to/process_by_channel_group.rst @@ -99,7 +99,8 @@ to any preprocessing function. 
referenced_recording = spre.common_reference(filtered_recording) good_channels_recording = spre.detect_and_remove_bad_channels(filtered_recording) -We can then aggregate the recordings back together using the ``aggregate_channels`` function +We can then aggregate the recordings back together using the ``aggregate_channels`` function. +Note that we do not need to do this to sort the data (see :ref:`sorting-by-channel-group`). .. code-block:: python @@ -134,6 +135,7 @@ back together under the hood). In general, it is not recommended to apply :py:func:`~aggregate_channels` more than once. This will slow down :py:func:`~get_traces` calls and may result in unpredictable behaviour. +.. _sorting-by-channel-group: Sorting a Recording by Channel Group ------------------------------------ @@ -141,16 +143,39 @@ Sorting a Recording by Channel Group We can also sort a recording for each channel group separately. It is not necessary to preprocess a recording by channel group in order to sort by channel group. -There are two ways to sort a recording by channel group. First, we can split the preprocessed -recording (or, if it was already split during preprocessing as above, skip the :py:func:`~aggregate_channels` step -directly use the :py:func:`~split_recording_dict`). +There are two ways to sort a recording by channel group. First, we can simply pass the output from +our preprocessing-by-group method above. Second, for more control, we can loop over the recordings +ourselves. -**Option 1: Manual splitting** +**Option 1 : Automatic splitting** -In this example, similar to above we loop over all preprocessed recordings that +Simply pass the split recording to the `run_sorter` function, as if it was a non-split recording. +This will return a dict of sortings, with the keys corresponding to the groups. + +.. code-block:: python + + split_recording = raw_recording.split_by("group") + # is a dict of recordings + + # do preprocessing if needed + pp_recording = spre.bandpass_filter(split_recording) + + dict_of_sortings = run_sorter( + sorter_name='kilosort2', + recording=pp_recording, + working_folder='working_path' + ) + + +**Option 2: Manual splitting** + +In this example, we loop over all preprocessed recordings that are grouped by channel, and apply the sorting separately. We store the sorting objects in a dictionary for later use. +You might do this if you want extra control e.g. to apply bespoke steps +to different groups. + .. code-block:: python split_preprocessed_recording = preprocessed_recording.split_by("group") @@ -160,19 +185,6 @@ sorting objects in a dictionary for later use. sorting = run_sorter( sorter_name='kilosort2', recording=split_preprocessed_recording, - output_folder=f"folder_KS2_group{group}" + folder=f"folder_KS2_group{group}" ) sortings[group] = sorting - -**Option 2 : Automatic splitting** - -Alternatively, SpikeInterface provides a convenience function to sort the recording by property: - -.. code-block:: python - - aggregate_sorting = run_sorter_by_property( - sorter_name='kilosort2', - recording=preprocessed_recording, - grouping_property='group', - working_folder='working_path' - ) diff --git a/doc/how_to/read_various_formats.rst b/doc/how_to/read_various_formats.rst new file mode 100644 index 0000000000..1e8ee0d5bc --- /dev/null +++ b/doc/how_to/read_various_formats.rst @@ -0,0 +1,173 @@ +.. 
code:: ipython3 + + %matplotlib inline + +Read various format into SpikeInterface +======================================= + +SpikeInterface can read various formats of “recording” (traces) and +“sorting” (spike train) data. + +Internally, to read different formats, SpikeInterface either uses: \* a +wrapper to ``neo ``\ \_ +rawio classes \* or a direct implementation + +Note that: + +- file formats contain a “recording”, a “sorting”, or “both” +- file formats can be file-based (NWB, …) or folder based (SpikeGLX, + OpenEphys, …) + +In this example we demonstrate how to read different file formats into +SI + +.. code:: ipython3 + + import matplotlib.pyplot as plt + + import spikeinterface.core as si + import spikeinterface.extractors as se + +Let’s download some datasets in different formats from the +``ephy_testing_data ``\ \_ +repo: + +- MEArec: a simulator format which is hdf5-based. It contains both a + “recording” and a “sorting” in the same file. +- Spike2: file from spike2 devices. It contains “recording” information + only. + +.. code:: ipython3 + + + spike2_file_path = si.download_dataset(remote_path="spike2/130322-1LY.smr") + print(spike2_file_path) + + mearec_folder_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") + print(mearec_folder_path) + + +.. parsed-literal:: + + Downloading data from 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data/raw/master/mearec/mearec_test_10s.h5' to file '/Users/christopherhalcrow/spikeinterface_datasets/ephy_testing_data/mearec/mearec_test_10s.h5'. + + +.. parsed-literal:: + + modified: spike2/130322-1LY.smr (file) + 1 annex'd file (15.8 MB recorded total size) + /Users/christopherhalcrow/spikeinterface_datasets/ephy_testing_data/spike2/130322-1LY.smr + 1 annex'd file (59.4 MB recorded total size) + nothing to save, working tree clean + + +.. parsed-literal:: + + 100%|█████████████████████████████████████| 62.3M/62.3M [00:00<00:00, 76.7GB/s] + + +.. parsed-literal:: + + /Users/christopherhalcrow/spikeinterface_datasets/ephy_testing_data/mearec/mearec_test_10s.h5 + + +Now that we have downloaded the files, let’s load them into SI. + +The :py:func:``~spikeinterface.extractors.read_spike2`` function returns +one object, a :py:class:``~spikeinterface.core.BaseRecording``. + +Note that internally this file contains 2 data streams (‘0’ and ‘1’), so +we need to specify which one we want to retrieve (‘0’ in our case). the +stream information can be retrieved by using the +:py:func:``~spikeinterface.extractors.get_neo_streams`` function. + +.. code:: ipython3 + + stream_names, stream_ids = se.get_neo_streams("spike2", spike2_file_path) + print(stream_names) + print(stream_ids) + stream_id = stream_ids[0] + print("stream_id", stream_id) + + recording = se.read_spike2(spike2_file_path, stream_id="0") + print(recording) + print(type(recording)) + print(isinstance(recording, si.BaseRecording)) + + +.. parsed-literal:: + + ['Signal stream 0', 'Signal stream 1'] + ['0', '1'] + stream_id 0 + Spike2RecordingExtractor: 1 channels - 20833.333333 Hz - 1 segments - 4,126,365 samples + 198.07s (3.30 minutes) - int16 dtype - 7.87 MiB + file_path: /Users/christopherhalcrow/spikeinterface_datasets/ephy_testing_data/spike2/130322-1LY.smr + + True + + +The +:py:func::literal:`~spikeinterface.extractors.read_spike2`\` function is equivalent to instantiating a :py:class:`\ ~spikeinterface.extractors.Spike2RecordingExtractor\` +object: + +.. code:: ipython3 + + recording = se.read_spike2(spike2_file_path, stream_id="0") + print(recording) + + +.. 
parsed-literal:: + + Spike2RecordingExtractor: 1 channels - 20833.333333 Hz - 1 segments - 4,126,365 samples + 198.07s (3.30 minutes) - int16 dtype - 7.87 MiB + file_path: /Users/christopherhalcrow/spikeinterface_datasets/ephy_testing_data/spike2/130322-1LY.smr + + +The :py:func:``~spikeinterface.extractors.read_mearec`` function returns +two objects, a :py:class:``~spikeinterface.core.BaseRecording`` and a +:py:class:``~spikeinterface.core.BaseSorting``: + +.. code:: ipython3 + + recording, sorting = se.read_mearec(mearec_folder_path) + print(recording) + print(type(recording)) + print() + print(sorting) + print(type(sorting)) + + + +.. parsed-literal:: + + MEArecRecordingExtractor: 32 channels - 32.0kHz - 1 segments - 320,000 samples - 10.00s + float32 dtype - 39.06 MiB + file_path: /Users/christopherhalcrow/spikeinterface_datasets/ephy_testing_data/mearec/mearec_test_10s.h5 + + + MEArecSortingExtractor: 10 units - 1 segments - 32.0kHz + file_path: /Users/christopherhalcrow/spikeinterface_datasets/ephy_testing_data/mearec/mearec_test_10s.h5 + + + +SI objects (:py:class:``~spikeinterface.core.BaseRecording`` and +:py:class:``~spikeinterface.core.BaseSorting``) can be plotted quickly +with the :py:mod:``spikeinterface.widgets`` submodule: + +.. code:: ipython3 + + import spikeinterface.widgets as sw + + w_ts = sw.plot_traces(recording, time_range=(0, 5)) + w_rs = sw.plot_rasters(sorting, time_range=(0, 5)) + + plt.show() + + + +.. image:: read_various_formats_files/read_various_formats_12_0.png + + + +.. image:: read_various_formats_files/read_various_formats_12_1.png diff --git a/doc/how_to/read_various_formats_files/read_various_formats_12_0.png b/doc/how_to/read_various_formats_files/read_various_formats_12_0.png new file mode 100644 index 0000000000..58ea70c4fe Binary files /dev/null and b/doc/how_to/read_various_formats_files/read_various_formats_12_0.png differ diff --git a/doc/how_to/read_various_formats_files/read_various_formats_12_1.png b/doc/how_to/read_various_formats_files/read_various_formats_12_1.png new file mode 100644 index 0000000000..789e2438a0 Binary files /dev/null and b/doc/how_to/read_various_formats_files/read_various_formats_12_1.png differ diff --git a/doc/modules/comparison.rst b/doc/modules/comparison.rst index eb8b33edd0..957c974d16 100644 --- a/doc/modules/comparison.rst +++ b/doc/modules/comparison.rst @@ -276,10 +276,8 @@ The :py:func:`~spikeinterface.comparison.compare_two_sorters()` returns the comp import spikeinterface.comparisons as sc import spikinterface.widgets as sw - # First, let's download a simulated dataset - local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5') - recording, sorting = se.read_mearec(local_path) - + # First, let's generate a simulated dataset + recording, sorting = si.generate_ground_truth_recording() # Then run two spike sorters and compare their outputs. sorting_HS = ss.run_sorter(sorter_name='herdingspikes', recording=recording) sorting_TDC = ss.run_sorter(sorter_name='tridesclous', recording=recording) @@ -332,9 +330,8 @@ Comparison of multiple sorters uses the following procedure: .. code-block:: python - # Download a simulated dataset - local_path = si.download_dataset(remote_path='mearec/mearec_test_10s.h5') - recording, sorting = se.read_mearec(local_path) + # Generate a simulated dataset + recording, sorting = si.generate_ground_truth_recording() # Then run 3 spike sorters and compare their outputs. 
sorting_MS4 = ss.run_sorter(sorter_name='mountainsort4', recording=recording) diff --git a/doc/modules/sorters.rst b/doc/modules/sorters.rst index d8a4708236..fd2423a26e 100644 --- a/doc/modules/sorters.rst +++ b/doc/modules/sorters.rst @@ -55,15 +55,15 @@ to easily run spike sorters: from spikeinterface.sorters import run_sorter # run Tridesclous - sorting_TDC = run_sorter(sorter_name="tridesclous", recording=recording, output_folder="/folder_TDC") + sorting_TDC = run_sorter(sorter_name="tridesclous", recording=recording, folder="/folder_TDC") # run Kilosort2.5 - sorting_KS2_5 = run_sorter(sorter_name="kilosort2_5", recording=recording, output_folder="/folder_KS2_5") + sorting_KS2_5 = run_sorter(sorter_name="kilosort2_5", recording=recording, folder="/folder_KS2_5") # run IronClust - sorting_IC = run_sorter(sorter_name="ironclust", recording=recording, output_folder="/folder_IC") + sorting_IC = run_sorter(sorter_name="ironclust", recording=recording, folder="/folder_IC") # run pyKilosort - sorting_pyKS = run_sorter(sorter_name="pykilosort", recording=recording, output_folder="/folder_pyKS") + sorting_pyKS = run_sorter(sorter_name="pykilosort", recording=recording, folder="/folder_pyKS") # run SpykingCircus - sorting_SC = run_sorter(sorter_name="spykingcircus", recording=recording, output_folder="/folder_SC") + sorting_SC = run_sorter(sorter_name="spykingcircus", recording=recording, folder="/folder_SC") Then the output, which is a :py:class:`~spikeinterface.core.BaseSorting` object, can be easily @@ -87,10 +87,10 @@ Spike-sorter-specific parameters can be controlled directly from the .. code-block:: python - sorting_TDC = run_sorter(sorter_name='tridesclous', recording=recording, output_folder="/folder_TDC", + sorting_TDC = run_sorter(sorter_name='tridesclous', recording=recording, folder="/folder_TDC", detect_threshold=8.) - sorting_KS2_5 = run_sorter(sorter_name="kilosort2_5", recording=recording, output_folder="/folder_KS2_5" + sorting_KS2_5 = run_sorter(sorter_name="kilosort2_5", recording=recording, folder="/folder_KS2_5" do_correction=False, preclust_threshold=6, freq_min=200.) @@ -193,7 +193,7 @@ The following code creates a test recording and runs a containerized spike sorte sorting = ss.run_sorter(sorter_name='kilosort3', recording=test_recording, - output_folder="kilosort3", + folder="kilosort3", singularity_image=True) print(sorting) @@ -208,7 +208,7 @@ To run in Docker instead of Singularity, use ``docker_image=True``. .. code-block:: python sorting = run_sorter(sorter_name='kilosort3', recording=test_recording, - output_folder="/tmp/kilosort3", docker_image=True) + folder="/tmp/kilosort3", docker_image=True) To use a specific image, set either ``docker_image`` or ``singularity_image`` to a string, e.g. ``singularity_image="spikeinterface/kilosort3-compiled-base:0.1.0"``. @@ -217,7 +217,7 @@ e.g. ``singularity_image="spikeinterface/kilosort3-compiled-base:0.1.0"``. sorting = run_sorter(sorter_name="kilosort3", recording=test_recording, - output_folder="kilosort3", + folder="kilosort3", singularity_image="spikeinterface/kilosort3-compiled-base:0.1.0") @@ -301,10 +301,10 @@ an :code:`engine` that supports parallel processing (such as :code:`joblib` or : another_recording = ... 
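    # each dict below holds run_sorter keyword arguments; the full list is then passed
    # to run_sorter_jobs together with an engine (e.g. joblib) as described above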
job_list = [ - {'sorter_name': 'tridesclous', 'recording': recording, 'output_folder': 'folder1','detect_threshold': 5.}, - {'sorter_name': 'tridesclous', 'recording': another_recording, 'output_folder': 'folder2', 'detect_threshold': 5.}, - {'sorter_name': 'herdingspikes', 'recording': recording, 'output_folder': 'folder3', 'clustering_bandwidth': 8., 'docker_image': True}, - {'sorter_name': 'herdingspikes', 'recording': another_recording, 'output_folder': 'folder4', 'clustering_bandwidth': 8., 'docker_image': True}, + {'sorter_name': 'tridesclous', 'recording': recording, 'folder': 'folder1','detect_threshold': 5.}, + {'sorter_name': 'tridesclous', 'recording': another_recording, 'folder': 'folder2', 'detect_threshold': 5.}, + {'sorter_name': 'herdingspikes', 'recording': recording, 'folder': 'folder3', 'clustering_bandwidth': 8., 'docker_image': True}, + {'sorter_name': 'herdingspikes', 'recording': another_recording, 'folder': 'folder4', 'clustering_bandwidth': 8., 'docker_image': True}, ] # run in loop @@ -339,8 +339,8 @@ Running spike sorting by group is indeed a very common need. A :py:class:`~spikeinterface.core.BaseRecording` object has the ability to split itself into a dictionary of sub-recordings given a certain property (see :py:meth:`~spikeinterface.core.BaseRecording.split_by`). So it is easy to loop over this dictionary and sequentially run spike sorting on these sub-recordings. -SpikeInterface also provides a high-level function to automate the process of splitting the -recording and then aggregating the results with the :py:func:`~spikeinterface.sorters.run_sorter_by_property` function. +The :py:func:`~spikeinterface.sorters.run_sorter` method can also accept the dictionary which is returned +by :py:meth:`~spikeinterface.core.BaseRecording.split_by` and will return a dictionary of sortings. In this example, we create a 16-channel recording with 4 tetrodes: @@ -368,7 +368,19 @@ In this example, we create a 16-channel recording with 4 tetrodes: # >>> [0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3] -**Option 1: Manual splitting** +**Option 1 : Automatic splitting** + +.. code-block:: python + + # here the result is a dict of sortings + dict_of_sortings = run_sorter( + sorter_name='kilosort2', + recording=recording_4_tetrodes, + working_folder='working_path' + ) + + +**Option 2: Manual splitting** .. code-block:: python @@ -380,18 +392,9 @@ In this example, we create a 16-channel recording with 4 tetrodes: # here the result is a dict of a sorting object sortings = {} for group, sub_recording in recordings.items(): - sorting = run_sorter(sorter_name='kilosort2', recording=recording, output_folder=f"folder_KS2_group{group}") + sorting = run_sorter(sorter_name='kilosort2', recording=recording, folder=f"folder_KS2_group{group}") sortings[group] = sorting -**Option 2 : Automatic splitting** - -.. 
code-block:: python - - # here the result is one sorting that aggregates all sub sorting objects - aggregate_sorting = run_sorter_by_property(sorter_name='kilosort2', recording=recording_4_tetrodes, - grouping_property='group', - working_folder='working_path') - Handling multi-segment recordings --------------------------------- @@ -466,6 +469,7 @@ Here is the list of external sorters accessible using the run_sorter wrapper: * **Klusta** :code:`run_sorter(sorter_name='klusta')` * **Mountainsort4** :code:`run_sorter(sorter_name='mountainsort4')` * **Mountainsort5** :code:`run_sorter(sorter_name='mountainsort5')` +* **RT-Sort** :code:`run_sorter(sorter_name='rt-sort')` * **SpyKING Circus** :code:`run_sorter(sorter_name='spykingcircus')` * **Tridesclous** :code:`run_sorter(sorter_name='tridesclous')` * **Wave clus** :code:`run_sorter(sorter_name='waveclus')` @@ -546,7 +550,7 @@ From the user's perspective, they behave exactly like the external sorters: .. code-block:: python - sorting = run_sorter(sorter_name="spykingcircus2", recording=recording, output_folder="/tmp/folder") + sorting = run_sorter(sorter_name="spykingcircus2", recording=recording, folder="/tmp/folder") Contributing diff --git a/doc/references.rst b/doc/references.rst index 6ba6efe6a1..b583df82b0 100644 --- a/doc/references.rst +++ b/doc/references.rst @@ -20,14 +20,17 @@ If you use one of the following preprocessing methods, please cite the appropria - :code:`common_reference` [Rolston]_ Motion Correction -^^^^^^^^^^^^^^^^^ +----------------- If you use the :code:`correct_motion` method in the preprocessing module, please cite [Garcia]_ as well as the references that correspond to the :code:`preset` you used: -- :code:`nonrigid_accurate` [Windolf]_ [Varol]_ -- :code:`nonrigid_fast_and_accurate` [Windolf]_ [Varol]_ [Pachitariu]_ +- :code:`nonrigid_accurate` [Windolf_a]_ [Varol]_ +- :code:`nonrigid_fast_and_accurate` [Windolf_a]_ [Varol]_ [Pachitariu]_ - :code:`rigid_fast` *no additional citation needed* - :code:`kilosort_like` [Pachitariu]_ +- :code:`dredge_ap` [Windolf_b]_ +- :code:`dredge_lfp` [Windolf_b]_ +- :code:`medicine` [Watters]_ Sorters Module -------------- @@ -40,6 +43,7 @@ please include the appropriate citation for the :code:`sorter_name` parameter yo - :code:`herdingspikes` [Muthmann]_ [Hilgen]_ - :code:`kilosort` [Pachitariu]_ - :code:`mountainsort` [Chung]_ +- :code:`rt-sort` [van_der_Molen]_ - :code:`spykingcircus` [Yger]_ - :code:`wavclus` [Chaure]_ - :code:`yass` [Lee]_ @@ -47,7 +51,14 @@ please include the appropriate citation for the :code:`sorter_name` parameter yo Postprocessing Module --------------------- -If you use the :code:`acgs_3d` extensions, (i.e. :code:`postprocessing.compute_acgs_3d`, :code:`postprocessing.ComputeACG3D`) please cite [Beau]_ +If you use the :code:`postprocessing module`, i.e. 
you use the :code:`analyzer.compute()` include the citations for the following +methods: + + - :code:`acgs_3d` [Beau]_ + - :code:`unit_locations` or :code:`spike_locations` with :code:`monopolar_triangulation` based on work from [Boussard]_ + - :code:`unit_locations` or :code:`spike_locations` with :code:`grid_convolution` based on work from [Pachitariu]_ + - :code:`template_metrics` [Jia]_ + Qualitymetrics Module --------------------- @@ -74,15 +85,20 @@ important for your research: - :code:`nearest_neighbor` or :code:`nn_isolation` or :code:`nn_noise_overlap` [Chung]_ [Siegle]_ - :code:`silhouette` [Rousseeuw]_ [Hruschka]_ + Curation Module --------------- If you use the :code:`get_potential_auto_merge` method from the curation module, please cite [Llobet]_ +If you use :code:`auto_label_units` or :code:`train_model`, please cite [Jain]_ + References ---------- .. [Beau] `A deep learning strategy to identify cell types across species from high-density extracellular recordings. 2025. `_ +.. [Boussard] `Three-dimensional spike localization and imporved motion correction for Neuropixels recordings. 2021 `_ + .. [Buccino] `SpikeInterface, a unified framework for spike sorting. 2020. `_ .. [Buzsáki] `The Log-Dynamic Brain: How Skewed Distributions Affect Network Operations. 2014. `_ @@ -107,7 +123,11 @@ References .. [IBL] `Spike sorting pipeline for the International Brain Laboratory. 2022. `_ -.. [Jackson] Quantitative assessment of extracellular multichannel recording quality using measures of cluster separation. Society of Neuroscience Abstract. 2005. +.. [Jackson] `Quantitative assessment of extracellular multichannel recording quality using measures of cluster separation. Society of Neuroscience Abstract. 2005. `_ + +.. [Jain] `UnitRefine: A Community Toolbox for Automated Spike Sorting Curation. 2025 `_ + +.. [Jia] `High-density extracellular probes reveal dendritic backpropagation and facilitate neuron classification. 2019 `_ .. [Lee] `YASS: Yet another spike sorter. 2017. `_ @@ -119,7 +139,7 @@ References .. [Niediek] `Reliable Analysis of Single-Unit Recordings from the Human Brain under Noisy Conditions: Tracking Neurons over Hours. 2016. `_ -.. [npyx] `NeuroPyxels: loading, processing and plotting Neuropixels data in python. 2021. _` +.. [npyx] `NeuroPyxels: loading, processing and plotting Neuropixels data in python. 2021. `_ .. [Pachitariu] `Spike sorting with Kilosort4. 2024. `_ @@ -135,8 +155,14 @@ References .. [UMS] `UltraMegaSort2000 - Spike sorting and quality metrics for extracellular spike data. 2011. `_ +.. [van_der_Molen] `RT-Sort: An action potential propagation-based algorithm for real time spike detection and sorting with millisecond latencies. 2024. `_ + .. [Varol] `Decentralized Motion Inference and Registration of Neuropixel Data. 2021. `_ -.. [Windolf] `Robust Online Multiband Drift Estimation in Electrophysiology Data. 2022. `_ +.. [Watters] `MEDiCINe: Motion Correction for Neural Electrophysiology Recordings. 2025. `_ + +.. [Windolf_a] `Robust Online Multiband Drift Estimation in Electrophysiology Data. 2022. `_ + +.. [Windolf_b] `DREDge: robust motion correction for high-density extracellular recordings across species. 2023 `_ .. [Yger] `A spike sorting toolbox for up to thousands of electrodes validated with ground truth recordings in vitro and in vivo. 2018. 
`_ diff --git a/doc/tutorials_custom_index.rst b/doc/tutorials_custom_index.rst index 05da8fa7dd..f146f4a4d0 100755 --- a/doc/tutorials_custom_index.rst +++ b/doc/tutorials_custom_index.rst @@ -101,9 +101,8 @@ The :py:mod:`spikeinterface.extractors` module is designed to load and save reco :gutter: 2 .. grid-item-card:: Read various formats - :link-type: ref - :link: sphx_glr_tutorials_extractors_plot_1_read_various_formats.py - :img-top: /tutorials/extractors/images/thumb/sphx_glr_plot_1_read_various_formats_thumb.png + :link: how_to/read_various_formats.html + :img-top: how_to/read_various_formats_files/read_various_formats_12_0.png :img-alt: Read various formats :class-card: gallery-card :text-align: center diff --git a/docs_rtd.yml b/docs_rtd.yml deleted file mode 100644 index c4e1fb378c..0000000000 --- a/docs_rtd.yml +++ /dev/null @@ -1,9 +0,0 @@ -channels: - - conda-forge - - defaults -dependencies: - - python=3.10 - - pip - - datalad - - pip: - - -e .[docs] diff --git a/examples/how_to/handle_drift.py b/examples/how_to/handle_drift.py index 02dd12f8a0..a38734991a 100755 --- a/examples/how_to/handle_drift.py +++ b/examples/how_to/handle_drift.py @@ -109,9 +109,6 @@ def preprocess_chain(rec): # ### Run motion correction with one function! # -# Correcting for drift is easy! You just need to run a single function. -# We will try this function with some presets. -# # Here we also save the motion correction results into a folder to be able to load them later. # lets try theses presets diff --git a/examples/tutorials/extractors/plot_1_read_various_formats.py b/examples/how_to/read_various_formats.py similarity index 100% rename from examples/tutorials/extractors/plot_1_read_various_formats.py rename to examples/how_to/read_various_formats.py diff --git a/examples/tutorials/core/plot_1_recording_extractor.py b/examples/tutorials/core/plot_1_recording_extractor.py index e7d773e9e6..39520f2195 100644 --- a/examples/tutorials/core/plot_1_recording_extractor.py +++ b/examples/tutorials/core/plot_1_recording_extractor.py @@ -122,7 +122,7 @@ ############################################################################## # You can also get a recording with a subset of channels (i.e. a channel slice): -recording4 = recording3.channel_slice(channel_ids=["a", "c", "e"]) +recording4 = recording3.select_channels(channel_ids=["a", "c", "e"]) print(recording4) print(recording4.get_channel_ids()) diff --git a/examples/tutorials/core/plot_4_sorting_analyzer.py b/examples/tutorials/core/plot_4_sorting_analyzer.py index 3b49e35fff..1d0849716e 100644 --- a/examples/tutorials/core/plot_4_sorting_analyzer.py +++ b/examples/tutorials/core/plot_4_sorting_analyzer.py @@ -27,23 +27,12 @@ import matplotlib.pyplot as plt -from spikeinterface import download_dataset -from spikeinterface import create_sorting_analyzer, load_sorting_analyzer -import spikeinterface.extractors as se +from spikeinterface import create_sorting_analyzer, load_sorting_analyzer, generate_ground_truth_recording ############################################################################## -# First let's use the repo https://gin.g-node.org/NeuralEnsemble/ephy_testing_data -# to download a MEArec dataset. 
It is a simulated dataset that contains "ground truth" -# sorting information: +# First let's generate a simulated recording and sorting -repo = "https://gin.g-node.org/NeuralEnsemble/ephy_testing_data" -remote_path = "mearec/mearec_test_10s.h5" -local_path = download_dataset(repo=repo, remote_path=remote_path, local_folder=None) - -############################################################################## -# Let's now instantiate the recording and sorting objects: - -recording, sorting = se.read_mearec(local_path) +recording, sorting = generate_ground_truth_recording() print(recording) print(sorting) diff --git a/examples/tutorials/qualitymetrics/plot_3_quality_metrics.py b/examples/tutorials/qualitymetrics/plot_3_quality_metrics.py index bfa6880cb0..fe71368845 100644 --- a/examples/tutorials/qualitymetrics/plot_3_quality_metrics.py +++ b/examples/tutorials/qualitymetrics/plot_3_quality_metrics.py @@ -8,22 +8,17 @@ """ import spikeinterface.core as si -import spikeinterface.extractors as se -from spikeinterface.postprocessing import compute_principal_components from spikeinterface.qualitymetrics import ( compute_snrs, compute_firing_rates, compute_isi_violations, - calculate_pc_metrics, compute_quality_metrics, ) ############################################################################## -# First, let's download a simulated dataset -# from the repo 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data' +# First, let's generate a simulated recording and sorting -local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") -recording, sorting = se.read_mearec(local_path) +recording, sorting = si.generate_ground_truth_recording() print(recording) print(sorting) @@ -70,7 +65,7 @@ ############################################################################## # Some metrics are based on the principal component scores, so the exwtension -# need to be computed before. For instance: +# must be computed before. For instance: analyzer.compute("principal_components", n_components=3, mode="by_channel_global", whiten=True) diff --git a/examples/tutorials/qualitymetrics/plot_4_curation.py b/examples/tutorials/qualitymetrics/plot_4_curation.py index 6a9253c093..328ebf8f2b 100644 --- a/examples/tutorials/qualitymetrics/plot_4_curation.py +++ b/examples/tutorials/qualitymetrics/plot_4_curation.py @@ -11,20 +11,15 @@ # Import the modules and/or functions necessary from spikeinterface import spikeinterface.core as si -import spikeinterface.extractors as se -from spikeinterface.postprocessing import compute_principal_components from spikeinterface.qualitymetrics import compute_quality_metrics ############################################################################## -# Let's download a simulated dataset -# from the repo 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data' -# -# Let's imagine that the ground-truth sorting is in fact the output of a sorter. +# Let's generate a simulated dataset, and imagine that the ground-truth +# sorting is in fact the output of a sorter. 
-local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") -recording, sorting = se.read_mearec(file_path=local_path) +recording, sorting = si.generate_ground_truth_recording() print(recording) print(sorting) diff --git a/examples/tutorials/widgets/plot_3_waveforms_gallery.py b/examples/tutorials/widgets/plot_3_waveforms_gallery.py index 2845dcc62c..d2f4345d14 100644 --- a/examples/tutorials/widgets/plot_3_waveforms_gallery.py +++ b/examples/tutorials/widgets/plot_3_waveforms_gallery.py @@ -9,15 +9,12 @@ import spikeinterface as si import spikeinterface.extractors as se -import spikeinterface.postprocessing as spost import spikeinterface.widgets as sw ############################################################################## -# First, let's download a simulated dataset -# from the repo 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data' +# First, let's generate a simulated dataset -local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") -recording, sorting = se.read_mearec(local_path) +recording, sorting = si.generate_ground_truth_recording() print(recording) print(sorting) diff --git a/examples/tutorials/widgets/plot_4_peaks_gallery.py b/examples/tutorials/widgets/plot_4_peaks_gallery.py index cce04ae5a0..e6e5f6cd56 100644 --- a/examples/tutorials/widgets/plot_4_peaks_gallery.py +++ b/examples/tutorials/widgets/plot_4_peaks_gallery.py @@ -14,22 +14,19 @@ import spikeinterface.full as si ############################################################################## -# First, let's download a simulated dataset -# from the repo 'https://gin.g-node.org/NeuralEnsemble/ephy_testing_data' - -local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") -rec, sorting = si.read_mearec(local_path) +# First, let's generate a simulated dataset +recording, sorting = si.generate_ground_truth_recording() ############################################################################## # Let's filter and detect peaks on it from spikeinterface.sortingcomponents.peak_detection import detect_peaks -rec_filtred = si.bandpass_filter(recording=rec, freq_min=300.0, freq_max=6000.0, margin_ms=5.0) -print(rec_filtred) +rec_filtered = si.bandpass_filter(recording=recording, freq_min=300.0, freq_max=6000.0, margin_ms=5.0) +print(rec_filtered) peaks = detect_peaks( - recording=rec_filtred, + recording=rec_filtered, method="locally_exclusive", peak_sign="neg", detect_threshold=6, @@ -53,14 +50,14 @@ # This "peaks" vector can be used in several widgets, for instance # plot_peak_activity() -si.plot_peak_activity(recording=rec_filtred, peaks=peaks) +si.plot_peak_activity(recording=rec_filtered, peaks=peaks) plt.show() ############################################################################## -# can be also animated with bin_duration_s=1. - -si.plot_peak_activity(recording=rec_filtred, peaks=peaks, bin_duration_s=1.0) +# can be also animated with bin_duration_s=1. The animation only works if you +# run this code locally +si.plot_peak_activity(recording=rec_filtered, peaks=peaks, bin_duration_s=1.0) plt.show() diff --git a/installation_tips/README.md b/installation_tips/README.md index 1257cb6786..04c93db2ae 100644 --- a/installation_tips/README.md +++ b/installation_tips/README.md @@ -1,70 +1,105 @@ ## Installation tips -If you are not (yet) an expert in Python installations (conda vs pip, mananging environements, etc.), -here we propose a simple recipe to install `spikeinterface` and several sorters inside an anaconda -environment for windows/mac users. 
- -This environment will install: - * spikeinterface full option +If you are not (yet) an expert in Python installations, the first major hurdle is choosing the installation procedure. + +Some key concepts you need to know before starting: + * Python itself can be distributed and installed many, many ways. + * Python itself does not contain many features for scientific computing, so you need to install "packages". For example + numpy, scipy, matplotlib, spikeinterface, ... These are all examples of Python packages that aid in scientific computation. + * All of these packages have their own dependencies which requires figuring out which versions of the dependencies work for + the combination of packages you as the user want to use. + * Packages can be distributed and installed in several ways (pip, conda, uv, mamba, ...) and luckily these methods of installation + typically take care of solving the dependencies for you! + * Installing many packages at once is challenging (because of their dependency graphs) so you need to do it in an "isolated environment" to not destroy any previous installation. You need to see an "environment" as a sub-installation in a dedicated folder. + +Choosing the installer + an environment manager + a package installer is a nightmare for beginners. + +The main options are: + * use "uv", a new, fast and simple package manager. We recommend this for beginners on every operating system. + * use "anaconda" (or its flavors-mamba, miniconda), which does everything. Used to be very popular but theses days it is becoming + harder to use because it is slow by default and has relatively strict licensing on the default channel (not always free anymore). + You need to play with "community channels" to make it free again, which is complicated for beginners. + This way is better for users in organizations that have specific licensing agrees with anaconda already in place. + * use Python from the system or Python.org + venv + pip: good and simple idea for Linux users, but does require familiarity with + the Python ecosystem (so good for intermediate users). + +Here we propose a step by step recipe for beginners based on [**"uv"**](https://github.com/astral-sh/uv). +We used to recommend installing with anaconda. It will be kept here for a while but we do not recommend it anymore. + + +This recipe will install: + * spikeinterface `full` option * spikeinterface-gui - * phy - * tridesclous + * kilosort4 -Kilosort, Ironclust and HDSort are MATLAB based and need to be installed from source. +into our uv venv environment. -### Quick installation -Steps: +### Quick installation using "uv" (recommended) -1. Download anaconda individual edition [here](https://www.anaconda.com/download) -2. Run the installer. Check the box “Add anaconda3 to my Path environment variable”. It makes life easier for beginners. -3. Download with right click + save the file corresponding to your OS, and put it in "Documents" folder - * [`full_spikeinterface_environment_windows.yml`](https://raw.githubusercontent.com/SpikeInterface/spikeinterface/main/installation_tips/full_spikeinterface_environment_windows.yml) - * [`full_spikeinterface_environment_mac.yml`](https://raw.githubusercontent.com/SpikeInterface/spikeinterface/main/installation_tips/full_spikeinterface_environment_mac.yml) -4. Then open the "Anaconda Command Prompt" (if Windows, search in your applications) or the Terminal (for Mac users) -5. If not in the "Documents" folder type `cd Documents` -6. 
Then run this depending on your OS: - * `conda env create --file full_spikeinterface_environment_windows.yml` - * `conda env create --file full_spikeinterface_environment_mac.yml` +1. On macOS and Linux. Open a terminal and do + `curl -LsSf https://astral.sh/uv/install.sh | sh` +2. On Windows. Open an instance of the Powershell (Windows has many options this is the recommended one from uv) + `powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"` +3. Exit the session and log in again. +4. Download with right click and save this file in your "Documents" folder: + * [`beginner_requirements_stable.txt`](https://raw.githubusercontent.com/SpikeInterface/spikeinterface/main/installation_tips/beginner_requirements_stable.txt) for stable release +5. Open terminal or powershell and run: +6. `uv venv si_env --python 3.12` +7. Activate your virtual environment by running: + - For Mac/Linux: `source si_env/bin/activate` (you should see `(si_env)` in your terminal) + - For Windows: `si_env\Scripts\activate` +8. Run `uv pip install -r Documents/beginner_requirements_stable.txt` -Done! Before running a spikeinterface script you will need to "select" this "environment" with `conda activate si_env`. +## Installing before release (from source) -Note for **linux** users : this conda recipe should work but we recommend strongly to use **pip + virtualenv**. +Some tools in the spikeinteface ecosystem are getting regular bug fixes (spikeinterface, spikeinterface-gui, probeinterface, neo). +We are making releases 2 to 4 times a year. In between releases if you want to install from source you can use the `beginner_requirements_rolling.txt` file to create the environment instead of the `beginner_requirements_stable.txt` file. This will install the packages of the ecosystem from source. +This is a good way to test if a patch fixes your issue. ### Check the installation - If you want to test the spikeinterface install you can: 1. Download with right click + save the file [`check_your_install.py`](https://raw.githubusercontent.com/SpikeInterface/spikeinterface/main/installation_tips/check_your_install.py) and put it into the "Documents" folder - -2. Open the Anaconda Command Prompt (Windows) or Terminal (Mac) -3. If not in your "Documents" folder type `cd Documents` -4. Run this: - ``` - conda activate si_env - python check_your_install.py - ``` -5. If a windows user to clean-up you will also need to right click + save [`cleanup_for_windows.py`](https://raw.githubusercontent.com/SpikeInterface/spikeinterface/main/installation_tips/cleanup_for_windows.py) -Then transfer `cleanup_for_windows.py` into your "Documents" folder. Finally run : +2. Open the CMD Prompt (Windows)[^1] or Terminal (Mac/Linux) +3. Activate your si_env : `source si_env/bin/activate` (Max/Linux), `si_env\Scripts\activate` (Windows) +4. Go to your "Documents" folder with `cd Documents` or the place where you downloaded the `check_your_install.py` +5. Run `python check_your_install.py` +6. If you are a Windows user, you should also right click + save [`cleanup_for_windows.py`](https://raw.githubusercontent.com/SpikeInterface/spikeinterface/main/installation_tips/cleanup_for_windows.py). 
Then transfer `cleanup_for_windows.py` into your "Documents" folder and finally run: ``` python cleanup_for_windows.py ``` -This script tests the following: +This script tests the following steps: * importing spikeinterface - * running tridesclous - * running spyking-circus (not on mac) - * running herdingspikes (not on windows) + * running tridesclous2 + * running kilosort4 * opening the spikeinterface-gui - * exporting to Phy -## Installing before release +### Legacy installation using Anaconda (not recommended anymore) + +Steps: + +1. Download Anaconda individual edition [here](https://www.anaconda.com/download) +2. Run the installer. Check the box “Add anaconda3 to my Path environment variable”. It makes life easier for beginners. +3. Download with right click + save the environment YAML file ([`beginner_conda_env_stable.yml`](https://raw.githubusercontent.com/SpikeInterface/spikeinterface/main/installation_tips/beginner_conda_env_stable.yml)) and put it in "Documents" folder +4. Then open the "Anaconda Command Prompt" (if Windows, search in your applications) or the Terminal (for Mac users) +5. If not in the "Documents" folder type `cd Documents` +6. Run this command to create the environment: + ```bash + conda env create --file beginner_conda_env_stable.yml + ``` + +Done! Before running a spikeinterface script you will need to "select" this "environment" with `conda activate si_env`. + +Note for **Linux** users: this conda recipe should work but we recommend strongly to use **pip + virtualenv**. + + -Some tools in the spikeinteface ecosystem are getting regular bug fixes (spikeinterface, spikeinterface-gui, probeinterface, python-neo, sortingview). -We are making releases 2 to 4 times a year. In between releases if you want to install from source you can use the `full_spikeinterface_environment_rolling_updates.yml` file to create the environment. This will install the packages of the ecosystem from source. -This is a good way to test if patch fix your issue. +[^1]: Although uv installation instructions are for the Powershell, our sorter scripts are for the CMD Prompt. After the initial installation with Powershell, any session that will have sorting requires the CMD Prompt. If you do not +plan to spike sort in a session either shell could be used. 
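A minimal Python sanity check (an informal complement to `check_your_install.py`), run inside the activated `si_env`:

```python
# confirms the main packages from beginner_requirements_stable.txt import cleanly
import spikeinterface.full as si
import spikeinterface_gui
import kilosort
import torch

print("spikeinterface:", si.__version__)
print("kilosort:", kilosort.__version__)
print("torch CUDA available:", torch.cuda.is_available())
```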
diff --git a/installation_tips/full_spikeinterface_environment_linux_dandi.yml b/installation_tips/beginner_conda_env_stable.yml similarity index 84% rename from installation_tips/full_spikeinterface_environment_linux_dandi.yml rename to installation_tips/beginner_conda_env_stable.yml index 5f276b0d20..cd2f36c1bf 100755 --- a/installation_tips/full_spikeinterface_environment_linux_dandi.yml +++ b/installation_tips/beginner_conda_env_stable.yml @@ -3,7 +3,7 @@ channels: - conda-forge - defaults dependencies: - - python=3.11 + - python=3.12 - pip - numpy - scipy @@ -13,7 +13,7 @@ dependencies: - h5py - pandas - xarray - - zarr + - zarr<3 - scikit-learn - hdbscan - networkx @@ -30,8 +30,6 @@ dependencies: - libxcb - pip: - ephyviewer - - MEArec - spikeinterface[full,widgets] - spikeinterface-gui - - tridesclous - # - phy==2.0b5 + - kilosort>4.0.30 diff --git a/installation_tips/beginner_requirements_rolling.txt b/installation_tips/beginner_requirements_rolling.txt new file mode 100644 index 0000000000..b4deca3a50 --- /dev/null +++ b/installation_tips/beginner_requirements_rolling.txt @@ -0,0 +1,10 @@ +https://github.com/NeuralEnsemble/python-neo/archive/master.zip +https://github.com/SpikeInterface/probeinterface/archive/main.zip +https://github.com/SpikeInterface/spikeinterface/archive/main.zip[full,widgets] +https://github.com/SpikeInterface/spikeinterface-gui/archive/main.zip +jupyterlab +PySide6<6.8 +hdbscan +pyqtgraph +ephyviewer +kilosort>4.0.30 diff --git a/installation_tips/beginner_requirements_stable.txt b/installation_tips/beginner_requirements_stable.txt new file mode 100644 index 0000000000..faa121f267 --- /dev/null +++ b/installation_tips/beginner_requirements_stable.txt @@ -0,0 +1,8 @@ +spikeinterface[full,widgets] +jupyterlab +PySide6<6.8 +hdbscan +pyqtgraph +ephyviewer +spikeinterface-gui +kilosort>4.0.30 diff --git a/installation_tips/check_your_install.py b/installation_tips/check_your_install.py index f3f80961e8..e8437743b8 100644 --- a/installation_tips/check_your_install.py +++ b/installation_tips/check_your_install.py @@ -1,96 +1,92 @@ from pathlib import Path import platform -import os import shutil import argparse +import warnings + +warnings.filterwarnings("ignore") + + +job_kwargs = dict(n_jobs=-1, progress_bar=False, chunk_duration="1s") + def check_import_si(): import spikeinterface as si + def check_import_si_full(): import spikeinterface.full as si def _create_recording(): import spikeinterface.full as si - rec, sorting = si.toy_example(num_segments=1, duration=200, seed=1, num_channels=16, num_columns=2) - rec.save(folder='./toy_example_recording') + + rec, _ = si.generate_ground_truth_recording( + durations=[200.0], sampling_frequency=30_000.0, num_channels=16, num_units=10, seed=2205 + ) + rec.save(folder="./toy_example_recording", verbose=False, **job_kwargs) def _run_one_sorter_and_analyzer(sorter_name): - job_kwargs = dict(n_jobs=-1, progress_bar=True, chunk_duration="1s") import spikeinterface.full as si - recording = si.load_extractor('./toy_example_recording') - sorting = si.run_sorter(sorter_name, recording, output_folder=f'./sorter_with_{sorter_name}', verbose=False) - sorting_analyzer = si.create_sorting_analyzer(sorting, recording, - format="binary_folder", folder=f"./analyzer_with_{sorter_name}", - **job_kwargs) + si.set_global_job_kwargs(**job_kwargs) + + recording = si.load("./toy_example_recording") + sorting = si.run_sorter(sorter_name, recording, folder=f"./sorter_with_{sorter_name}", verbose=False) + + sorting_analyzer = 
si.create_sorting_analyzer( + sorting, recording, format="binary_folder", folder=f"./analyzer_with_{sorter_name}" + ) sorting_analyzer.compute("random_spikes", method="uniform", max_spikes_per_unit=500) - sorting_analyzer.compute("waveforms", **job_kwargs) + sorting_analyzer.compute("waveforms") sorting_analyzer.compute("templates") sorting_analyzer.compute("noise_levels") sorting_analyzer.compute("unit_locations", method="monopolar_triangulation") - sorting_analyzer.compute("correlograms", window_ms=100, bin_ms=5.) - sorting_analyzer.compute("principal_components", n_components=3, mode='by_channel_global', whiten=True, **job_kwargs) + sorting_analyzer.compute("correlograms", window_ms=100, bin_ms=5.0) + sorting_analyzer.compute("principal_components", n_components=3, mode="by_channel_global", whiten=True) sorting_analyzer.compute("quality_metrics", metric_names=["snr", "firing_rate"]) -def run_tridesclous(): - _run_one_sorter_and_analyzer('tridesclous') - def run_tridesclous2(): - _run_one_sorter_and_analyzer('tridesclous2') + _run_one_sorter_and_analyzer("tridesclous2") +def run_kilosort4(): + _run_one_sorter_and_analyzer("kilosort4") + def open_sigui(): import spikeinterface.full as si - import spikeinterface_gui - app = spikeinterface_gui.mkQApp() - - sorter_name = "tridesclous2" - folder = f"./analyzer_with_{sorter_name}" - analyzer = si.load_sorting_analyzer(folder) + from spikeinterface_gui import run_mainwindow - win = spikeinterface_gui.MainWindow(analyzer) - win.show() - app.exec_() - -def export_to_phy(): - import spikeinterface.full as si sorter_name = "tridesclous2" folder = f"./analyzer_with_{sorter_name}" analyzer = si.load_sorting_analyzer(folder) - phy_folder = "./phy_example" - si.export_to_phy(analyzer, output_folder=phy_folder, verbose=False) - - -def open_phy(): - os.system("phy template-gui ./phy_example/params.py") + win = run_mainwindow(analyzer, start_app=True) def _clean(): # clean folders = [ "./toy_example_recording", - "./sorter_with_tridesclous", - "./analyzer_with_tridesclous", "./sorter_with_tridesclous2", "./analyzer_with_tridesclous2", - "./phy_example" + "./sorter_with_kilosort4", + "./analyzer_with_kilosort4", ] for folder in folders: if Path(folder).exists(): shutil.rmtree(folder) + parser = argparse.ArgumentParser() # add ci flag so that gui will not be used in ci # end user can ignore -parser.add_argument('--ci', action='store_false') +parser.add_argument("--ci", action="store_false") -if __name__ == '__main__': +if __name__ == "__main__": args = parser.parse_args() @@ -98,34 +94,22 @@ def _clean(): _create_recording() steps = [ - ('Import spikeinterface', check_import_si), - ('Import spikeinterface.full', check_import_si_full), - ('Run tridesclous', run_tridesclous), - ('Run tridesclous2', run_tridesclous2), - ] + ("Import spikeinterface", check_import_si), + ("Import spikeinterface.full", check_import_si_full), + ("Run tridesclous2", run_tridesclous2), + ("Run kilosort4", run_kilosort4), + ] # backwards logic because default is True for end-user if args.ci: - steps.append(('Open spikeinterface-gui', open_sigui)) - - steps.append(('Export to phy', export_to_phy)), - # phy is removed from the env because it force a pip install PyQt5 - # which break the conda env - # ('Open phy', open_phy), - - # if platform.system() == "Windows": - # pass - # elif platform.system() == "Darwin": - # pass - # else: - # pass + steps.append(("Open spikeinterface-gui", open_sigui)) for label, func in steps: try: func() - done = '...OK' + done = "...OK" except Exception as 
err: - done = f'...Fail, Error: {err}' + done = f"...Fail, Error: {err}" print(label, done) if platform.system() == "Windows": diff --git a/installation_tips/cleanup_for_windows.py b/installation_tips/cleanup_for_windows.py index 2c334f2df2..8d803ae421 100644 --- a/installation_tips/cleanup_for_windows.py +++ b/installation_tips/cleanup_for_windows.py @@ -5,15 +5,17 @@ def _clean(): # clean folders = [ - 'toy_example_recording', - "tridesclous_output", "tridesclous_waveforms", - "spykingcircus_output", "spykingcircus_waveforms", - "phy_example" + "./toy_example_recording", + "./sorter_with_tridesclous2", + "./analyzer_with_tridesclous2", + "./sorter_with_kilosort4", + "./analyzer_with_kilosort4", ] for folder in folders: if Path(folder).exists(): shutil.rmtree(folder) -if __name__ == '__main__': - _clean() +if __name__ == "__main__": + + _clean() diff --git a/installation_tips/full_spikeinterface_environment_mac.yml b/installation_tips/full_spikeinterface_environment_mac.yml deleted file mode 100755 index 2522fda78d..0000000000 --- a/installation_tips/full_spikeinterface_environment_mac.yml +++ /dev/null @@ -1,34 +0,0 @@ -name: si_env -channels: - - conda-forge - - defaults -dependencies: - - python=3.11 - - pip - - numpy - - scipy - - joblib - - tqdm - - matplotlib - - h5py - - pandas - - xarray - - zarr - - scikit-learn - - hdbscan - - networkx - - pybind11 - - loky - - numba - - jupyter - - pyqt=5 - - pyqtgraph - - ipywidgets - - ipympl - - pip: - - ephyviewer - - MEArec - - spikeinterface[full,widgets] - - spikeinterface-gui - - tridesclous - # - phy==2.0b5 diff --git a/installation_tips/full_spikeinterface_environment_rolling_updates.yml b/installation_tips/full_spikeinterface_environment_rolling_updates.yml deleted file mode 100644 index b4479aa20f..0000000000 --- a/installation_tips/full_spikeinterface_environment_rolling_updates.yml +++ /dev/null @@ -1,35 +0,0 @@ -name: si_env_rolling -channels: - - conda-forge - - defaults -dependencies: - - python=3.11 - - pip - - numpy - - scipy - - joblib - - tqdm - - matplotlib - - h5py - - pandas - - xarray - - hdbscan - - scikit-learn - - networkx - - pybind11 - - loky - - numba - - jupyter - - pyqt=5 - - pyqtgraph - - ipywidgets - - ipympl - - pip: - - ephyviewer - - docker - - https://github.com/SpikeInterface/MEArec/archive/main.zip - - https://github.com/NeuralEnsemble/python-neo/archive/master.zip - - https://github.com/SpikeInterface/probeinterface/archive/main.zip - - https://github.com/SpikeInterface/spikeinterface/archive/main.zip - - https://github.com/SpikeInterface/spikeinterface-gui/archive/main.zip - - https://github.com/magland/sortingview/archive/main.zip diff --git a/installation_tips/full_spikeinterface_environment_windows.yml b/installation_tips/full_spikeinterface_environment_windows.yml deleted file mode 100755 index 2522fda78d..0000000000 --- a/installation_tips/full_spikeinterface_environment_windows.yml +++ /dev/null @@ -1,34 +0,0 @@ -name: si_env -channels: - - conda-forge - - defaults -dependencies: - - python=3.11 - - pip - - numpy - - scipy - - joblib - - tqdm - - matplotlib - - h5py - - pandas - - xarray - - zarr - - scikit-learn - - hdbscan - - networkx - - pybind11 - - loky - - numba - - jupyter - - pyqt=5 - - pyqtgraph - - ipywidgets - - ipympl - - pip: - - ephyviewer - - MEArec - - spikeinterface[full,widgets] - - spikeinterface-gui - - tridesclous - # - phy==2.0b5 diff --git a/pyproject.toml b/pyproject.toml index 6554423890..9caf5a4784 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,13 +69,13 @@ 
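# Editorial aside (not part of the patch): the pyproject.toml pins in the hunk below gate ibllib on the
# Python version with a PEP 508 environment marker. A minimal sketch of how such a marker is evaluated,
# assuming the third-party `packaging` library is installed:
from packaging.markers import Marker

marker = Marker("python_version >= '3.10'")
print(marker.evaluate())                           # True on the current interpreter if it is 3.10+
print(marker.evaluate({"python_version": "3.9"}))  # False for an explicit 3.9 environment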
extractors = [ "sonpy;python_version<'3.10'", "lxml", # lxml for neuroscope "scipy", - "ibllib==3.3.1", # streaming IBL + "ibllib>=3.4.1;python_version>='3.10'", # streaming IBL "pymatreader>=0.0.32", # For cell explorer matlab files "zugbruecke>=0.2; sys_platform!='win32'", # For plexon2 ] streaming_extractors = [ - "ibllib==3.3.1", # streaming IBL + "ibllib>=3.4.1;python_version>='3.10'", # streaming IBL # Following dependencies are for streaming with nwb files "pynwb>=2.6.0", "fsspec", @@ -101,7 +101,7 @@ full = [ "distinctipy", "matplotlib>=3.6", # matplotlib.colormaps "cuda-python; platform_system != 'Darwin'", - "numba", + "numba>=0.59", "skops", "huggingface_hub" ] @@ -144,7 +144,7 @@ test_extractors = [ ] test_preprocessing = [ - "ibllib==3.3.1", # for IBL + "ibllib>=3.4.1;python_version>='3.10'", # streaming IBL "torch", ] @@ -156,7 +156,7 @@ test = [ "psutil", # preprocessing - "ibllib==3.3.1", # for IBL + "ibllib>=3.4.1;python_version>='3.10'", # streaming templates "s3fs", diff --git a/readthedocs.yml b/readthedocs.yml index c6c44d83a0..486ace7d53 100644 --- a/readthedocs.yml +++ b/readthedocs.yml @@ -1,14 +1,18 @@ version: 2 -sphinx: - # Path to your Sphinx configuration file. - configuration: doc/conf.py - +# Specify os and python version build: - os: ubuntu-22.04 + os: "ubuntu-24.04" tools: - python: "mambaforge-4.10" + python: "3.10" + commands: + - asdf plugin add uv + - asdf install uv latest + - asdf global uv latest + - uv venv $READTHEDOCS_VIRTUALENV_PATH + - VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH uv --preview pip install .[docs] + - python -m sphinx -T -b html -d doc/_build/doctrees -D language=en doc $READTHEDOCS_OUTPUT/html - -conda: - environment: docs_rtd.yml +sphinx: + # Path to your Sphinx configuration file. + configuration: doc/conf.py diff --git a/setup.py b/setup.py deleted file mode 100644 index 0f4aa42fb2..0000000000 --- a/setup.py +++ /dev/null @@ -1,7 +0,0 @@ -import setuptools -import warnings - -warnings.warn("Using `python setup.py` is legacy! See https://spikeinterface.readthedocs.io/en/latest/installation.html for installation") - -if __name__ == "__main__": - setuptools.setup() diff --git a/src/spikeinterface/benchmark/benchmark_motion_interpolation.py b/src/spikeinterface/benchmark/benchmark_motion_interpolation.py index ab72a1f9bd..c0969ed32a 100644 --- a/src/spikeinterface/benchmark/benchmark_motion_interpolation.py +++ b/src/spikeinterface/benchmark/benchmark_motion_interpolation.py @@ -53,7 +53,7 @@ def run(self, **job_kwargs): sorting = run_sorter( sorter_name, recording, - output_folder=self.sorter_folder, + folder=self.sorter_folder, **sorter_params, delete_output_folder=False, ) diff --git a/src/spikeinterface/benchmark/benchmark_plot_tools.py b/src/spikeinterface/benchmark/benchmark_plot_tools.py index 9370bcc0bb..b397c7ef74 100644 --- a/src/spikeinterface/benchmark/benchmark_plot_tools.py +++ b/src/spikeinterface/benchmark/benchmark_plot_tools.py @@ -374,55 +374,6 @@ def plot_agreement_matrix(study, ordered=True, case_keys=None, axs=None): return fig -def plot_performances(study, mode="ordered", performance_names=("accuracy", "precision", "recall"), case_keys=None): - """ - Plot performances over case for a study. - - Parameters - ---------- - study : BenchmarkStudy - A study object. 
- mode : "ordered" | "snr" | "swarm", default: "ordered" - Which plot mode to use: - - * "ordered": plot performance metrics vs unit indices ordered by decreasing accuracy - * "snr": plot performance metrics vs snr - * "swarm": plot performance metrics as a swarm plot (see seaborn.swarmplot for details) - performance_names : list or tuple, default: ("accuracy", "precision", "recall") - Which performances to plot ("accuracy", "precision", "recall") - case_keys : list or None - A selection of cases to plot, if None, then all. - - Returns - ------- - fig : matplotlib.figure.Figure - The resulting figure containing the plots - """ - if mode == "snr": - warnings.warn( - "Use study.plot_performances_vs_snr() instead", - DeprecationWarning, - stacklevel=2, - ) - return plot_performances_vs_snr(study, case_keys=case_keys, performance_names=performance_names) - elif mode == "ordered": - warnings.warn( - "Use study.plot_performances_ordered() instead", - DeprecationWarning, - stacklevel=2, - ) - return plot_performances_ordered(study, case_keys=case_keys, performance_names=performance_names) - elif mode == "swarm": - warnings.warn( - "Use study.plot_performances_swarm() instead", - DeprecationWarning, - stacklevel=2, - ) - return plot_performances_swarm(study, case_keys=case_keys, performance_names=performance_names) - else: - raise ValueError("plot_performances() : wrong mode ") - - def plot_performances_vs_snr( study, case_keys=None, diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py index 388e9e4b6f..2b7180117b 100644 --- a/src/spikeinterface/comparison/multicomparisons.py +++ b/src/spikeinterface/comparison/multicomparisons.py @@ -190,63 +190,6 @@ def get_agreement_sorting(self, minimum_agreement_count=1, minimum_agreement_cou ) return sorting - def save_to_folder(self, save_folder): - warnings.warn( - "save_to_folder() is deprecated. " - "You should save and load the multi sorting comparison object using pickle." - "\n>>> pickle.dump(mcmp, open('mcmp.pkl', 'wb'))\n>>> mcmp_loaded = pickle.load(open('mcmp.pkl', 'rb'))", - DeprecationWarning, - stacklevel=2, - ) - for sorting in self.object_list: - assert sorting.check_serializability( - "json" - ), "MultiSortingComparison.save_to_folder() needs json serializable sortings" - - save_folder = Path(save_folder) - save_folder.mkdir(parents=True, exist_ok=True) - filename = str(save_folder / "multicomparison.gpickle") - with open(filename, "wb") as f: - pickle.dump(self.graph, f, pickle.HIGHEST_PROTOCOL) - kwargs = { - "delta_time": float(self.delta_time), - "match_score": float(self.match_score), - "chance_score": float(self.chance_score), - } - with (save_folder / "kwargs.json").open("w") as f: - json.dump(kwargs, f) - sortings = {} - for name, sorting in zip(self.name_list, self.object_list): - sortings[name] = sorting.to_dict(recursive=True, relative_to=save_folder) - with (save_folder / "sortings.json").open("w") as f: - json.dump(sortings, f) - - @staticmethod - def load_from_folder(folder_path): - warnings.warn( - "load_from_folder() is deprecated. " - "You should save and load the multi sorting comparison object using pickle." 
- "\n>>> pickle.dump(mcmp, open('mcmp.pkl', 'wb'))\n>>> mcmp_loaded = pickle.load(open('mcmp.pkl', 'rb'))", - DeprecationWarning, - stacklevel=2, - ) - folder_path = Path(folder_path) - with (folder_path / "kwargs.json").open() as f: - kwargs = json.load(f) - with (folder_path / "sortings.json").open() as f: - dict_sortings = json.load(f) - name_list = list(dict_sortings.keys()) - sorting_list = [load(v, base_folder=folder_path) for v in dict_sortings.values()] - mcmp = MultiSortingComparison(sorting_list=sorting_list, name_list=list(name_list), do_matching=False, **kwargs) - filename = str(folder_path / "multicomparison.gpickle") - with open(filename, "rb") as f: - mcmp.graph = pickle.load(f) - # do step 3 and 4 - mcmp._clean_graph() - mcmp._do_agreement() - mcmp._populate_spiketrains() - return mcmp - class AgreementSortingExtractor(BaseSorting): def __init__( diff --git a/src/spikeinterface/comparison/tests/test_multisortingcomparison.py b/src/spikeinterface/comparison/tests/test_multisortingcomparison.py index dc769f8d59..0b26083bd9 100644 --- a/src/spikeinterface/comparison/tests/test_multisortingcomparison.py +++ b/src/spikeinterface/comparison/tests/test_multisortingcomparison.py @@ -76,10 +76,6 @@ def test_compare_multiple_sorters(setup_module): agreement_2 = msc.get_agreement_sorting(minimum_agreement_count=2, minimum_agreement_count_only=True) assert np.all([agreement_2.get_unit_property(u, "agreement_number")] == 2 for u in agreement_2.get_unit_ids()) - msc.save_to_folder(multicomparison_folder) - - msc = MultiSortingComparison.load_from_folder(multicomparison_folder) - def test_compare_multi_segment(): num_segments = 3 diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index fb2e173b3e..fdcfc73c27 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -109,7 +109,12 @@ get_chunk_with_margin, order_channels_by_depth, ) -from .sorting_tools import spike_vector_to_spike_trains, random_spikes_selection, apply_merges_to_sorting +from .sorting_tools import ( + spike_vector_to_spike_trains, + random_spikes_selection, + apply_merges_to_sorting, + apply_splits_to_sorting, +) from .waveform_tools import extract_waveforms_to_buffers, estimate_templates, estimate_templates_with_accumulator from .snippets_tools import snippets_from_sorting diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index e237342f5b..4a045c255f 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -9,6 +9,7 @@ * ComputeNoiseLevels which is very convenient to have """ +import warnings import numpy as np from .sortinganalyzer import AnalyzerExtension, register_result_extension @@ -92,6 +93,11 @@ def _merge_extension_data( new_data["random_spikes_indices"] = np.flatnonzero(selected_mask[keep_mask]) return new_data + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + new_data = dict() + new_data["random_spikes_indices"] = self.data["random_spikes_indices"].copy() + return new_data + def _get_data(self): return self.data["random_spikes_indices"] @@ -243,8 +249,6 @@ def _select_extension_data(self, unit_ids): def _merge_extension_data( self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs ): - new_data = dict() - waveforms = self.data["waveforms"] some_spikes = 
self.sorting_analyzer.get_extension("random_spikes").get_random_spikes() if keep_mask is not None: @@ -275,6 +279,11 @@ def _merge_extension_data( return dict(waveforms=waveforms) + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + # splitting only affects random spikes, not waveforms + new_data = dict(waveforms=self.data["waveforms"].copy()) + return new_data + def get_waveforms_one_unit(self, unit_id, force_dense: bool = False): """ Returns the waveforms of a unit id. @@ -554,6 +563,49 @@ def _merge_extension_data( return new_data + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + if not new_sorting_analyzer.has_extension("waveforms"): + warnings.warn( + "Splitting templates without the 'waveforms' extension will simply copy the template of the unit that " + "was split to the new split units. This is not recommended and may lead to incorrect results. It is " + "recommended to compute the 'waveforms' extension before splitting, or to use 'hard' splitting mode.", + ) + new_data = dict() + for operator, arr in self.data.items(): + # we first copy the unsplit units + new_array = np.zeros((len(new_sorting_analyzer.unit_ids), arr.shape[1], arr.shape[2]), dtype=arr.dtype) + new_analyzer_unit_ids = list(new_sorting_analyzer.unit_ids) + unsplit_unit_ids = [unit_id for unit_id in self.sorting_analyzer.unit_ids if unit_id not in split_units] + new_indices = np.array([new_analyzer_unit_ids.index(unit_id) for unit_id in unsplit_unit_ids]) + old_indices = self.sorting_analyzer.sorting.ids_to_indices(unsplit_unit_ids) + new_array[new_indices, ...] = arr[old_indices, ...] + + for split_unit_id, new_splits in zip(split_units, new_unit_ids): + if new_sorting_analyzer.has_extension("waveforms"): + for new_unit_id in new_splits: + split_unit_index = new_sorting_analyzer.sorting.id_to_index(new_unit_id) + wfs = new_sorting_analyzer.get_extension("waveforms").get_waveforms_one_unit( + new_unit_id, force_dense=True + ) + + if operator == "average": + arr = np.average(wfs, axis=0) + elif operator == "std": + arr = np.std(wfs, axis=0) + elif operator == "median": + arr = np.median(wfs, axis=0) + elif "percentile" in operator: + _, percentile = operator.split("_") + arr = np.percentile(wfs, float(percentile), axis=0) + new_array[split_unit_index, ...] = arr + else: + split_unit_index = self.sorting_analyzer.sorting.id_to_index(split_unit_id) + old_template = arr[split_unit_index, ...] + new_indices = new_sorting_analyzer.sorting.ids_to_indices(new_splits) + new_array[new_indices, ...] 
= np.tile(old_template, (len(new_splits), 1, 1)) + new_data[operator] = new_array + return new_data + def _get_data(self, operator="average", percentile=None, outputs="numpy"): if operator != "percentile": key = operator @@ -727,6 +779,10 @@ def _merge_extension_data( # this does not depend on units return self.data.copy() + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + # this does not depend on units + return self.data.copy() + def _run(self, verbose=False, **job_kwargs): self.data["noise_levels"] = get_noise_levels( self.sorting_analyzer.recording, diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index b32893981f..94a722eb15 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -2,6 +2,7 @@ import warnings from pathlib import Path +from typing import Optional import numpy as np from probeinterface import read_probeinterface, write_probeinterface @@ -369,21 +370,6 @@ def get_traces( traces = traces.astype("float32", copy=False) * gains + offsets return traces - def has_scaled_traces(self) -> bool: - """Checks if the recording has scaled traces - - Returns - ------- - bool - True if the recording has scaled traces, False otherwise - """ - warnings.warn( - "`has_scaled_traces` is deprecated and will be removed in 0.103.0. Use has_scaleable_traces() instead", - category=DeprecationWarning, - stacklevel=2, - ) - return self.has_scaled() - def get_time_info(self, segment_index=None) -> dict: """ Retrieves the timing attributes for a given segment index. As with @@ -467,7 +453,7 @@ def get_end_time(self, segment_index=None) -> float: rs = self._recording_segments[segment_index] return rs.get_end_time() - def has_time_vector(self, segment_index=None): + def has_time_vector(self, segment_index: Optional[int] = None): """Check if the segment of the recording has a time vector. 
Parameters @@ -725,17 +711,6 @@ def rename_channels(self, new_channel_ids: list | np.array | tuple) -> "BaseReco return ChannelSliceRecording(self, renamed_channel_ids=new_channel_ids) - def _channel_slice(self, channel_ids, renamed_channel_ids=None): - from .channelslice import ChannelSliceRecording - - warnings.warn( - "Recording.channel_slice will be removed in version 0.103, use `select_channels` or `rename_channels` instead.", - DeprecationWarning, - stacklevel=2, - ) - sub_recording = ChannelSliceRecording(self, channel_ids, renamed_channel_ids=renamed_channel_ids) - return sub_recording - def _remove_channels(self, remove_channel_ids): from .channelslice import ChannelSliceRecording @@ -879,7 +854,6 @@ def binary_compatible_with( file_paths_length=None, file_offset=None, file_suffix=None, - file_paths_lenght=None, ): """ Check is the recording is binary compatible with some constrain on @@ -891,14 +865,6 @@ def binary_compatible_with( * file_suffix """ - # spelling typo need to fix - if file_paths_lenght is not None: - warnings.warn( - "`file_paths_lenght` is deprecated and will be removed in 0.103.0 please use `file_paths_length`" - ) - if file_paths_length is None: - file_paths_length = file_paths_lenght - if not self.is_binary_compatible(): return False @@ -1017,7 +983,7 @@ def time_to_sample_index(self, time_s): sample_index = time_s * self.sampling_frequency else: sample_index = (time_s - self.t_start) * self.sampling_frequency - sample_index = np.round(sample_index).astype(int) + sample_index = np.round(sample_index).astype(np.int64) else: sample_index = np.searchsorted(self.time_vector, time_s, side="right") - 1 diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index ea1f9c4542..6be1766dbc 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -51,14 +51,6 @@ def has_scaleable_traces(self) -> bool: else: return True - def has_scaled(self): - warn( - "`has_scaled` has been deprecated and will be removed in 0.103.0. Please use `has_scaleable_traces()`", - category=DeprecationWarning, - stacklevel=2, - ) - return self.has_scaleable_traces() - def has_probe(self) -> bool: return "contact_vector" in self.get_property_keys() @@ -69,9 +61,6 @@ def is_filtered(self): # the is_filtered is handle with annotation return self._annotations.get("is_filtered", False) - def _channel_slice(self, channel_ids, renamed_channel_ids=None): - raise NotImplementedError - def set_probe(self, probe, group_mode="by_probe", in_place=False): """ Attach a list of Probe object to a recording. @@ -234,21 +223,6 @@ def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False return sub_recording - def set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False): - - warning_msg = ( - "`set_probes` is now a private function and the public function will be " - "removed in 0.103.0. 
Please use `set_probe` or `set_probegroup` instead" - ) - - warn(warning_msg, category=DeprecationWarning, stacklevel=2) - - sub_recording = self._set_probes( - probe_or_probegroup=probe_or_probegroup, group_mode=group_mode, in_place=in_place - ) - - return sub_recording - def get_probe(self): probes = self.get_probes() assert len(probes) == 1, "there are several probe use .get_probes() or get_probegroup()" @@ -441,25 +415,6 @@ def planarize(self, axes: str = "xy"): return recording2d - # utils - def channel_slice(self, channel_ids, renamed_channel_ids=None): - """ - Returns a new object with sliced channels. - - Parameters - ---------- - channel_ids : np.array or list - The list of channels to keep - renamed_channel_ids : np.array or list, default: None - A list of renamed channels - - Returns - ------- - BaseRecordingSnippets - The object with sliced channels - """ - return self._channel_slice(channel_ids, renamed_channel_ids=renamed_channel_ids) - def select_channels(self, channel_ids): """ Returns a new object with sliced channels. diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index 872f3fa8e1..e20fe09e11 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -79,14 +79,6 @@ def is_aligned(self): def get_num_segments(self): return len(self._snippets_segments) - def has_scaled_snippets(self): - warn( - "`has_scaled_snippets` is deprecated and will be removed in version 0.103.0. Please use `has_scaleable_traces()` instead", - category=DeprecationWarning, - stacklevel=2, - ) - return self.has_scaleable_traces() - def get_frames(self, indices=None, segment_index: Union[int, None] = None): segment_index = self._check_segment_index(segment_index) spts = self._snippets_segments[segment_index] diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index c092de3387..8a1fa9cf1b 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import Optional, Union +from typing import Optional import numpy as np @@ -131,12 +131,52 @@ def get_total_duration(self) -> float: def get_unit_spike_train( self, unit_id: str | int, - segment_index: Union[int, None] = None, - start_frame: Union[int, None] = None, - end_frame: Union[int, None] = None, + segment_index: Optional[int] = None, + start_frame: Optional[int] = None, + end_frame: Optional[int] = None, return_times: bool = False, use_cache: bool = True, - ): + ) -> np.ndarray: + """ + Get spike train for a unit. + + Parameters + ---------- + unit_id : str or int + The unit id to retrieve spike train for + segment_index : int or None, default: None + The segment index to retrieve spike train from. 
+ For multi-segment objects, it is required + start_frame : int or None, default: None + The start frame for spike train extraction + end_frame : int or None, default: None + The end frame for spike train extraction + return_times : bool, default: False + If True, returns spike times in seconds instead of frames + use_cache : bool, default: True + If True, uses cached spike trains when available + + Returns + ------- + spike_train : np.ndarray + Spike frames (or times if return_times=True) + """ + + if return_times: + start_time = ( + self.sample_index_to_time(start_frame, segment_index=segment_index) if start_frame is not None else None + ) + end_time = ( + self.sample_index_to_time(end_frame, segment_index=segment_index) if end_frame is not None else None + ) + + return self.get_unit_spike_train_in_seconds( + unit_id=unit_id, + segment_index=segment_index, + start_time=start_time, + end_time=end_time, + ) + segment_index = self._check_segment_index(segment_index) if use_cache: if segment_index not in self._cached_spike_trains: @@ -161,22 +201,101 @@ def get_unit_spike_train( unit_id=unit_id, start_frame=start_frame, end_frame=end_frame ).astype("int64") - if return_times: - if self.has_recording(): - times = self.get_times(segment_index=segment_index) - return times[spike_frames] - else: - segment = self._sorting_segments[segment_index] - t_start = segment._t_start if segment._t_start is not None else 0 - spike_times = spike_frames / self.get_sampling_frequency() - return t_start + spike_times - else: - return spike_frames + return spike_frames + + def get_unit_spike_train_in_seconds( + self, + unit_id: str | int, + segment_index: Optional[int] = None, + start_time: Optional[float] = None, + end_time: Optional[float] = None, + ) -> np.ndarray: + """ + Get spike train for a unit in seconds. + + This method uses a three-tier approach to get spike times: + 1. If the sorting has a recording, use the recording's time conversion + 2. If the segment implements get_unit_spike_train_in_seconds(), use that directly (these are the native timestamps of the format) + 3. Fall back to standard frame-to-time conversion + + This approach avoids double conversion for extractors that already store + spike times in seconds (e.g., NWB format) and ensures consistent timing + when a recording is associated with the sorting. + + Parameters + ---------- + unit_id : str or int + The unit id to retrieve spike train for + segment_index : int or None, default: None + The segment index to retrieve spike train from. 
+ For multi-segment objects, it is required + start_time : float or None, default: None + The start time in seconds for spike train extraction + end_time : float or None, default: None + The end time in seconds for spike train extraction + + Returns + ------- + spike_times : np.ndarray + Spike times in seconds + """ + segment_index = self._check_segment_index(segment_index) + segment = self._sorting_segments[segment_index] - def register_recording(self, recording, check_spike_frames=True): + # If sorting has a registered recording, get the frames and get the times from the recording + # Note that this takes into account the segment start time of the recording + if self.has_recording(): + + # Get all the spike times and then slice them + start_frame = None + end_frame = None + spike_frames = self.get_unit_spike_train( + unit_id=unit_id, + segment_index=segment_index, + start_frame=start_frame, + end_frame=end_frame, + return_times=False, + use_cache=True, + ) + + spike_times = self.sample_index_to_time(spike_frames, segment_index=segment_index) + + # Filter to return only the spikes within the specified time range + if start_time is not None: + spike_times = spike_times[spike_times >= start_time] + if end_time is not None: + spike_times = spike_times[spike_times <= end_time] + + return spike_times + + # Use the native spike times if available + # Some instances might implement a method themselves to access spike times directly without having to convert + # (e.g. NWB extractors) + if hasattr(segment, "get_unit_spike_train_in_seconds"): + return segment.get_unit_spike_train_in_seconds(unit_id=unit_id, start_time=start_time, end_time=end_time) + + # If no recording is attached, fall back to frame-based conversion + # Get spike train in frames and convert to times using traditional method + start_frame = self.time_to_sample_index(start_time, segment_index=segment_index) if start_time else None + end_frame = self.time_to_sample_index(end_time, segment_index=segment_index) if end_time else None + + spike_frames = self.get_unit_spike_train( + unit_id=unit_id, + segment_index=segment_index, + start_frame=start_frame, + end_frame=end_frame, + return_times=False, + use_cache=True, + ) + + t_start = segment._t_start if segment._t_start is not None else 0 + spike_times = spike_frames / self.get_sampling_frequency() + return t_start + spike_times + + def register_recording(self, recording, check_spike_frames: bool = True): """ Register a recording to the sorting. If the sorting and recording both contain - time information, the recording’s time information will be used. + time information, the recording's time information will be used. Parameters ---------- @@ -215,7 +334,7 @@ def set_sorting_info(self, recording_dict, params_dict, log_dict): def has_recording(self) -> bool: return self._recording is not None - def has_time_vector(self, segment_index=None) -> bool: + def has_time_vector(self, segment_index: Optional[int] = None) -> bool: """ Check if the segment of the registered recording has a time vector. 
""" @@ -520,6 +639,20 @@ def time_to_sample_index(self, time, segment_index=0): return sample_index + def sample_index_to_time( + self, sample_index: int | np.ndarray, segment_index: Optional[int] = None + ) -> float | np.ndarray: + """ + Transform sample index into time in seconds + """ + segment_index = self._check_segment_index(segment_index) + if self.has_recording(): + return self._recording.sample_index_to_time(sample_index, segment_index=segment_index) + else: + segment = self._sorting_segments[segment_index] + t_start = segment._t_start if segment._t_start is not None else 0 + return (sample_index / self.get_sampling_frequency()) + t_start + def precompute_spike_trains(self, from_spike_vector=None): """ Pre-computes and caches all spike trains for this sorting diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 5a688e4869..7640168cb7 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -423,12 +423,20 @@ def check_paths_relative(input_dict, relative_folder) -> bool: relative_folder = Path(relative_folder).resolve().absolute() not_possible = [] for p in path_list: - p = Path(p) # check path is not an URL if "http" in str(p): not_possible.append(p) continue + # check path is not a remote path, see + # https://github.com/SpikeInterface/spikeinterface/issues/4045 + if is_path_remote(p): + not_possible.append(p) + continue + + # convert to Path + p = Path(p) + # If windows path check have same drive if isinstance(p, WindowsPath) and isinstance(relative_folder, WindowsPath): # check that on same drive diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 38a08c0fab..fbca2b0aa6 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -29,13 +29,21 @@ - chunk_duration : str or float or None Chunk duration in s if float or with units if str (e.g. "1s", "500ms") * n_jobs : int | float - Number of jobs to use. With -1 the number of jobs is the same as number of cores. - Using a float between 0 and 1 will use that fraction of the total cores. + Number of workers that will be requested during multiprocessing. Note that + the OS determines how this is distributed, but for convenience one can use + * -1 the number of workers is the same as number of cores (from os.cpu_count()) + * float between 0 and 1 uses fraction of total cores (from os.cpu_count()) * progress_bar : bool If True, a progress bar is printed * mp_context : "fork" | "spawn" | None, default: None Context for multiprocessing. It can be None, "fork" or "spawn". Note that "fork" is only safely available on LINUX systems + * pool_engine : "process" | "thread", default: "process" + Whether to use a ProcessPoolExecutor or ThreadPoolExecutor for multiprocessing + * max_threads_per_worker : int | None, default: 1 + Sets the limit for the number of thread per process using threadpoolctl module + Only applies in an n_jobs>1 context + If None, then no limits are applied. """ diff --git a/src/spikeinterface/core/loading.py b/src/spikeinterface/core/loading.py index d5845f033e..09e82e1026 100644 --- a/src/spikeinterface/core/loading.py +++ b/src/spikeinterface/core/loading.py @@ -130,7 +130,7 @@ def load( def load_extractor(file_or_folder_or_dict, base_folder=None) -> "BaseExtractor": warnings.warn( - "load_extractor() is deprecated and will be removed in the future. Please use load() instead.", + "load_extractor() is deprecated and will be removed in version 0.104.0. 
Please use load() instead.", DeprecationWarning, stacklevel=2, ) @@ -196,6 +196,12 @@ def _guess_object_from_local_folder(folder): with open(folder / "spikeinterface_info.json", "r") as f: spikeinterface_info = json.load(f) return _guess_object_from_dict(spikeinterface_info) + elif ( + (folder / "sorter_output").is_dir() + and (folder / "spikeinterface_params.json").is_file() + and (folder / "spikeinterface_log.json").is_file() + ): + return "SorterFolder" elif (folder / "waveforms").is_dir(): # before the SortingAnlazer, it was WaveformExtractor (v<0.101) return "WaveformExtractor" @@ -212,13 +218,20 @@ def _guess_object_from_local_folder(folder): return "Recording|Sorting" -def _load_object_from_folder(folder, object_type, **kwargs): +def _load_object_from_folder(folder, object_type: str, **kwargs): + if object_type == "SortingAnalyzer": from .sortinganalyzer import load_sorting_analyzer analyzer = load_sorting_analyzer(folder, **kwargs) return analyzer + elif object_type == "SorterFolder": + from spikeinterface.sorters import read_sorter_folder + + sorting = read_sorter_folder(folder) + return sorting + elif object_type == "Motion": from spikeinterface.core.motion import Motion @@ -244,6 +257,16 @@ def _load_object_from_folder(folder, object_type, **kwargs): si_file = f return BaseExtractor.load(si_file, base_folder=folder) + elif object_type.startswith("Group"): + + sub_object_type = object_type.split("[")[1].split("]")[0] + with open(folder / "spikeinterface_info.json", "r") as f: + spikeinterface_info = json.load(f) + group_keys = spikeinterface_info.get("dict_keys") + + group_of_objects = {key: _load_object_from_folder(folder / str(key), sub_object_type) for key in group_keys} + return group_of_objects + def _guess_object_from_zarr(zarr_folder): # here it can be a zarr folder for Recording|Sorting|SortingAnalyzer|Template diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index d5a899b9e5..c74c76d010 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -1,11 +1,13 @@ from __future__ import annotations +import warnings import importlib.util import numpy as np -from .basesorting import BaseSorting -from .numpyextractors import NumpySorting +from spikeinterface.core.base import BaseExtractor +from spikeinterface.core.basesorting import BaseSorting +from spikeinterface.core.numpyextractors import NumpySorting numba_spec = importlib.util.find_spec("numba") if numba_spec is not None: @@ -226,6 +228,7 @@ def random_spikes_selection( return random_spikes_indices +### MERGING ZONE ### def apply_merges_to_sorting( sorting: BaseSorting, merge_unit_groups: list[list[int | str]] | list[tuple[int | str]], @@ -319,12 +322,69 @@ def apply_merges_to_sorting( keep_mask[group_indices[inds + 1]] = False spikes = spikes[keep_mask] - sorting = NumpySorting(spikes, sorting.sampling_frequency, all_unit_ids) + merge_sorting = NumpySorting(spikes, sorting.sampling_frequency, all_unit_ids) + set_properties_after_merging(merge_sorting, sorting, merge_unit_groups, new_unit_ids=new_unit_ids) if return_extra: - return sorting, keep_mask, new_unit_ids + return merge_sorting, keep_mask, new_unit_ids else: - return sorting + return merge_sorting + + +def set_properties_after_merging( + sorting_post_merge: BaseSorting, + sorting_pre_merge: BaseSorting, + merge_unit_groups: list[list[int | str]], + new_unit_ids: list[int | str], +): + """ + Add properties to the merge sorting object after merging units. 
+ The properties of the merged units are propagated only if they are the same + for all units in the merge group. + + Parameters + ---------- + sorting_post_merge : BaseSorting + The Sorting object after merging units. + sorting_pre_merge : BaseSorting + The Sorting object before merging units. + merge_unit_groups : list + The groups of unit ids that were merged. + new_unit_ids : list + A list of new unit_ids for each merge. + """ + prop_keys = sorting_pre_merge.get_property_keys() + pre_unit_ids = sorting_pre_merge.unit_ids + post_unit_ids = sorting_post_merge.unit_ids + + kept_unit_ids = post_unit_ids[np.isin(post_unit_ids, pre_unit_ids)] + keep_pre_inds = sorting_pre_merge.ids_to_indices(kept_unit_ids) + keep_post_inds = sorting_post_merge.ids_to_indices(kept_unit_ids) + + for key in prop_keys: + parent_values = sorting_pre_merge.get_property(key) + + # propagate keep values + shape = (len(sorting_post_merge.unit_ids),) + parent_values.shape[1:] + new_values = np.empty(shape=shape, dtype=parent_values.dtype) + new_values[keep_post_inds] = parent_values[keep_pre_inds] + for new_id, merge_group in zip(new_unit_ids, merge_unit_groups): + merged_indices = sorting_pre_merge.ids_to_indices(merge_group) + merge_values = parent_values[merged_indices] + same_property_values = np.all([np.array_equal(m, merge_values[0]) for m in merge_values[1:]]) + new_index = sorting_post_merge.id_to_index(new_id) + if same_property_values: + # add new values only if they are all similar + new_values[new_index] = merge_values[0] + else: + default_missing_values = BaseExtractor.default_missing_property_values + new_values[new_index] = default_missing_values[parent_values.dtype.kind] + sorting_post_merge.set_property(key, new_values) + + # set is_merged property + is_merged = np.ones(len(sorting_post_merge.unit_ids), dtype=bool) + is_merged[keep_post_inds] = False + sorting_post_merge.set_property("is_merged", is_merged) def _get_ids_after_merging(old_unit_ids, merge_unit_groups, new_unit_ids): @@ -375,7 +435,7 @@ def _get_ids_after_merging(old_unit_ids, merge_unit_groups, new_unit_ids): def generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_unit_ids=None, new_id_strategy="append"): """ - Function to generate new units ids during a merging procedure. If new_units_ids + Function to generate new unit ids during a merging procedure. If `new_unit_ids` are provided, it will return these unit ids, checking that they have the the same length as `merge_unit_groups`. @@ -440,3 +500,301 @@ def generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_unit_ raise ValueError("wrong new_id_strategy") return new_unit_ids + + +### SPLITTING ZONE ### +def apply_splits_to_sorting( + sorting: BaseSorting, + unit_splits: dict[int | str, list[list[int | str]]], + new_unit_ids: list[list[int | str]] | None = None, + return_extra: bool = False, + new_id_strategy: str = "append", +): + """ + Apply the splits to a sorting object. + + This function is not lazy and creates a new NumpySorting with a compact spike_vector as fast as possible. + The `unit_splits` should be a dict with the unit ids as keys and a list of lists of spike indices as values. + For each split, the list of spike indices should contain the indices of the spikes to be assigned to each split and + it should be complete (i.e. the sum of the lengths of the sublists must equal the number of spikes in the unit). + If `new_unit_ids` is not None, it will use these new unit ids for the split units. 
+ If `new_unit_ids` is None, it will generate new unit ids according to `new_id_strategy`. + + Parameters + ---------- + sorting : BaseSorting + The Sorting object to apply splits. + unit_splits : dict + A dictionary with the split unit id as key and a list of lists of spike indices for each split. + The split indices for each unit MUST be a list of lists, where each sublist (at least two) contains the + indices of the spikes to be assigned to each split. The sum of the lengths of the sublists must equal + the number of spikes in the unit. + new_unit_ids : list | None, default: None + List of new unit_ids for each split. If given, it needs to have the same length as `unit_splits`, + and each element must have the same length as the corresponding list of split indices. + If None, new ids will be generated. + return_extra : bool, default: False + If True, also return the new_unit_ids. + new_id_strategy : "append" | "split", default: "append" + The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids. + + * "append" : new_unit_ids will be added at the end of max(sorting.unit_ids) + * "split" : new_unit_ids will be created as {split_unit_id}-{split_number} + (e.g. when splitting unit "13" in 2: "13-0" / "13-1"). + Only works if unit_ids are str otherwise switch to "append" + + Returns + ------- + sorting : NumpySorting + The newly created sorting with the split units. + """ + check_unit_splits_consistency(unit_splits, sorting) + spikes = sorting.to_spike_vector().copy() + + # here we assume that unit_splits split_indices are already full. + # this is true when running via apply_curation + + new_unit_ids = generate_unit_ids_for_split( + sorting.unit_ids, unit_splits, new_unit_ids=new_unit_ids, new_id_strategy=new_id_strategy + ) + all_unit_ids = _get_ids_after_splitting(sorting.unit_ids, unit_splits, new_unit_ids) + all_unit_ids = list(all_unit_ids) + + num_seg = sorting.get_num_segments() + seg_lims = np.searchsorted(spikes["segment_index"], np.arange(0, num_seg + 2)) + segment_slices = [(seg_lims[i], seg_lims[i + 1]) for i in range(num_seg)] + + # using this function avoids the mask approach and simplifies the algorithm a lot + spike_vector_list = [spikes[s0:s1] for s0, s1 in segment_slices] + spike_indices = spike_vector_to_indices(spike_vector_list, sorting.unit_ids, absolute_index=True) + + for unit_id in sorting.unit_ids: + if unit_id in unit_splits: + split_indices = unit_splits[unit_id] + new_split_ids = new_unit_ids[list(unit_splits.keys()).index(unit_id)] + + for split, new_unit_id in zip(split_indices, new_split_ids): + new_unit_index = all_unit_ids.index(new_unit_id) + # split_indices are a concatenation across segments with absolute indices + # so we need to concatenate the spike indices across segments + spike_indices_unit = np.concatenate( + [spike_indices[segment_index][unit_id] for segment_index in range(num_seg)] + ) + spikes["unit_index"][spike_indices_unit[split]] = new_unit_index + else: + new_unit_index = all_unit_ids.index(unit_id) + for segment_index in range(num_seg): + spike_inds = spike_indices[segment_index][unit_id] + spikes["unit_index"][spike_inds] = new_unit_index + split_sorting = NumpySorting(spikes, sorting.sampling_frequency, all_unit_ids) + set_properties_after_splits( + split_sorting, + sorting, + list(unit_splits.keys()), + new_unit_ids=new_unit_ids, + ) + + if return_extra: + return split_sorting, new_unit_ids + else: + return split_sorting + + +def set_properties_after_splits( + sorting_post_split: BaseSorting, + 
sorting_pre_split: BaseSorting, + split_unit_ids: list[int | str], + new_unit_ids: list[list[int | str]], +): + """ + Add properties to the split sorting object after splitting units. + The properties of the split units are propagated to the new split units. + + Parameters + ---------- + sorting_post_split : BaseSorting + The Sorting object after splitting units. + sorting_pre_split : BaseSorting + The Sorting object before splitting units. + split_unit_ids : list + The unit ids that were split. + new_unit_ids : list + A list of new unit_ids for each split. + """ + prop_keys = sorting_pre_split.get_property_keys() + pre_unit_ids = sorting_pre_split.unit_ids + post_unit_ids = sorting_post_split.unit_ids + + kept_unit_ids = post_unit_ids[np.isin(post_unit_ids, pre_unit_ids)] + keep_pre_inds = sorting_pre_split.ids_to_indices(kept_unit_ids) + keep_post_inds = sorting_post_split.ids_to_indices(kept_unit_ids) + + for key in prop_keys: + parent_values = sorting_pre_split.get_property(key) + + # propagate keep values + shape = (len(sorting_post_split.unit_ids),) + parent_values.shape[1:] + new_values = np.empty(shape=shape, dtype=parent_values.dtype) + new_values[keep_post_inds] = parent_values[keep_pre_inds] + for split_unit, new_split_ids in zip(split_unit_ids, new_unit_ids): + split_index = sorting_pre_split.id_to_index(split_unit) + split_value = parent_values[split_index] + # propagate the split value to all new unit ids + new_unit_indices = sorting_post_split.ids_to_indices(new_split_ids) + new_values[new_unit_indices] = split_value + sorting_post_split.set_property(key, new_values) + + # set is_split property + is_split = np.ones(len(sorting_post_split.unit_ids), dtype=bool) + is_split[keep_post_inds] = False + sorting_post_split.set_property("is_split", is_split) + + +def generate_unit_ids_for_split(old_unit_ids, unit_splits, new_unit_ids=None, new_id_strategy="append"): + """ + Function to generate new unit ids during a splitting procedure. If `new_unit_ids` + are provided, it will return these unit ids, checking that they are consistent with + `unit_splits`. + + Parameters + ---------- + old_unit_ids : np.array + The old unit_ids. + unit_splits : dict + A dictionary with the unit ids to split as keys and the lists of split indices as values. + new_unit_ids : list | None, default: None + Optional new unit_ids for split units. If given, it needs to have the same length as `unit_splits`. + If None, new ids will be generated. + new_id_strategy : "append" | "split", default: "append" + The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids. + + * "append" : new_unit_ids will be added at the end of max(sorting.unit_ids) + * "split" : new_unit_ids will be created as {split_unit_id}-{split_number} + (e.g. when splitting unit "13" in 2: "13-0" / "13-1"). + Only works if unit_ids are str otherwise switch to "append" + + Returns + ------- + new_unit_ids : list of lists + The new unit_ids associated with the splits. 
+ """ + assert new_id_strategy in ["append", "split"], "new_id_strategy should be 'append' or 'split'" + old_unit_ids = np.asarray(old_unit_ids) + + if new_unit_ids is not None: + for split_unit, new_split_ids in zip(unit_splits.values(), new_unit_ids): + # then only doing a consistency check + assert len(split_unit) == len(new_split_ids), "new_unit_ids should have the same len as unit_splits.values" + # new_unit_ids can also be part of old_unit_ids only inside the same group: + assert all( + new_split_id not in old_unit_ids for new_split_id in new_split_ids + ), "new_unit_ids already exists but outside the split groups" + else: + dtype = old_unit_ids.dtype + if np.issubdtype(dtype, np.integer) and new_id_strategy == "split": + warnings.warn("new_id_strategy 'split' is not compatible with integer unit_ids. Switching to 'append'.") + new_id_strategy = "append" + + new_unit_ids = [] + current_unit_ids = old_unit_ids.copy() + for unit_to_split, split_indices in unit_splits.items(): + num_splits = len(split_indices) + # select new_unit_ids greater that the max id, event greater than the numerical str ids + if new_id_strategy == "append": + if np.issubdtype(dtype, np.character): + # dtype str + if all(p.isdigit() for p in current_unit_ids): + # All str are digit : we can generate a max + m = max(int(p) for p in current_unit_ids) + 1 + new_units_for_split = [str(m + i) for i in range(num_splits)] + else: + # we cannot automatically find new names + new_units_for_split = [f"{unit_to_split}-split{i}" for i in range(num_splits)] + else: + # dtype int + new_units_for_split = list(max(current_unit_ids) + 1 + np.arange(num_splits, dtype=dtype)) + # we append the new split unit ids to continue to increment the max id + current_unit_ids = np.concatenate([current_unit_ids, new_units_for_split]) + elif new_id_strategy == "split": + # we made sure that dtype is not integer + new_units_for_split = [f"{unit_to_split}-{i}" for i in np.arange(len(split_indices))] + new_unit_ids.append(new_units_for_split) + + return new_unit_ids + + +def check_unit_splits_consistency(unit_splits, sorting): + """ + Function to check the consistency of unit_splits indices with the sorting object. + It checks that the split indices for each unit are a list of lists, where each sublist (at least two) + contains the indices of the spikes to be assigned to each split. The sum of the lengths + of the sublists must equal the number of spikes in the unit. + + Parameters + ---------- + unit_splits : dict + A dictionary with the split unit id as key and a list of numpy arrays or lists of spike indices for each split. + sorting : BaseSorting + The sorting object containing spike information. + + Raises + ------ + ValueError + If the unit_splits are not in the expected format or if the total number of spikes in the splits does not match + the number of spikes in the unit. 
+ """ + num_spikes = sorting.count_num_spikes_per_unit() + for unit_id, split_indices in unit_splits.items(): + if not isinstance(split_indices, (list, np.ndarray)): + raise ValueError(f"unit_splits[{unit_id}] should be a list or numpy array, got {type(split_indices)}") + if not all(isinstance(indices, (list, np.ndarray)) for indices in split_indices): + raise ValueError(f"unit_splits[{unit_id}] should be a list of lists or numpy arrays") + if len(split_indices) < 2: + raise ValueError(f"unit_splits[{unit_id}] should have at least two splits") + total_spikes_in_split = sum(len(indices) for indices in split_indices) + if total_spikes_in_split != num_spikes[unit_id]: + raise ValueError( + f"Total spikes in unit {unit_id} split ({total_spikes_in_split}) does not match the number of spikes in the unit ({num_spikes[unit_id]})" + ) + + +def _get_ids_after_splitting(old_unit_ids, split_units, new_unit_ids): + """ + Function to get the list of unique unit_ids after some splits, with given new_units_ids would + be provided. + + Every new unit_id will be added at the end if not already present. + + Parameters + ---------- + old_unit_ids : np.array + The old unit_ids. + split_units : dict + A dict of split units. Each element needs to have at least two elements (two units to split). + new_unit_ids : list | None + A new unit_ids for split units. If given, it needs to have the same length as `split_units` values. + + Returns + ------- + + all_unit_ids : The unit ids in the split sorting + The units_ids that will be present after splits + + """ + old_unit_ids = np.asarray(old_unit_ids) + dtype = old_unit_ids.dtype + if dtype.kind == "U": + # the new dtype can be longer + dtype = "U" + + assert len(new_unit_ids) == len(split_units), "new_unit_ids should have the same len as merge_unit_groups" + for new_unit_in_split, unit_to_split in zip(new_unit_ids, split_units.keys()): + assert len(new_unit_in_split) == len( + split_units[unit_to_split] + ), "new_unit_ids should have the same len as split_units values" + + all_unit_ids = list(old_unit_ids.copy()) + for split_unit, split_new_units in zip(split_units, new_unit_ids): + all_unit_ids.remove(split_unit) + all_unit_ids.extend(split_new_units) + return np.array(all_unit_ids, dtype=dtype) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index d1a0ef7be4..f8d11dd157 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -31,7 +31,13 @@ is_path_remote, clean_zarr_folder_name, ) -from .sorting_tools import generate_unit_ids_for_merge_group, _get_ids_after_merging +from .sorting_tools import ( + generate_unit_ids_for_merge_group, + generate_unit_ids_for_split, + check_unit_splits_consistency, + _get_ids_after_merging, + _get_ids_after_splitting, +) from .job_tools import split_job_kwargs from .numpyextractors import NumpySorting from .sparsity import ChannelSparsity, estimate_sparsity @@ -888,7 +894,7 @@ def are_units_mergeable( else: return mergeable - def _save_or_select_or_merge( + def _save_or_select_or_merge_or_split( self, format="binary_folder", folder=None, @@ -897,9 +903,12 @@ def _save_or_select_or_merge( censor_ms=None, merging_mode="soft", sparsity_overlap=0.75, - verbose=False, - new_unit_ids=None, + merge_new_unit_ids=None, + split_units=None, + splitting_mode="soft", + split_new_unit_ids=None, backend_options=None, + verbose=False, **job_kwargs, ) -> "SortingAnalyzer": """ @@ -921,12 +930,19 @@ def _save_or_select_or_merge( When merging units, 
any spikes violating this refractory period will be discarded. merging_mode : "soft" | "hard", default: "soft" How merges are performed. In the "soft" mode, merges will be approximated, with no smart merging - of the extension data. + of the extension data. In the "hard" mode, the extensions for merged units will be recomputed. sparsity_overlap : float, default 0.75 The percentage of overlap that units should share in order to accept merges. If this criteria is not achieved, soft merging will not be performed. - new_unit_ids : list or None, default: None + merge_new_unit_ids : list or None, default: None The new unit ids for merged units. Required if `merge_unit_groups` is not None. + split_units : dict or None, default: None + A dictionary with the keys being the unit ids to split and the values being the split indices. + splitting_mode : "soft" | "hard", default: "soft" + How splits are performed. In the "soft" mode, splits will be approximated, with no smart splitting. + If `splitting_mode` is "hard", the extensions for split units will be recomputed. + split_new_unit_ids : list or None, default: None + The new unit ids for split units. Required if `split_units` is not None. verbose : bool, default: False If True, output is verbose. backend_options : dict | None, default: None @@ -949,36 +965,64 @@ else: recording = None - if self.sparsity is not None and unit_ids is None and merge_unit_groups is None: - sparsity = self.sparsity - elif self.sparsity is not None and unit_ids is not None and merge_unit_groups is None: - sparsity_mask = self.sparsity.mask[np.isin(self.unit_ids, unit_ids), :] - sparsity = ChannelSparsity(sparsity_mask, unit_ids, self.channel_ids) - elif self.sparsity is not None and merge_unit_groups is not None: - all_unit_ids = unit_ids - sparsity_mask = np.zeros((len(all_unit_ids), self.sparsity.mask.shape[1]), dtype=bool) - mergeable, masks = self.are_units_mergeable( - merge_unit_groups, - sparsity_overlap=sparsity_overlap, - return_masks=True, - ) + has_removed = unit_ids is not None + has_merges = merge_unit_groups is not None + has_splits = split_units is not None + assert not (has_merges and has_splits), "Cannot merge and split at the same time" + + if self.sparsity is not None: + if not has_removed and not has_merges and not has_splits: + # no changes in units + sparsity = self.sparsity + elif has_removed and not has_merges and not has_splits: + # remove units + sparsity_mask = self.sparsity.mask[np.isin(self.unit_ids, unit_ids), :] + sparsity = ChannelSparsity(sparsity_mask, unit_ids, self.channel_ids) + elif has_merges: + # merge units + all_unit_ids = unit_ids + sparsity_mask = np.zeros((len(all_unit_ids), self.sparsity.mask.shape[1]), dtype=bool) + mergeable, masks = self.are_units_mergeable( + merge_unit_groups, + sparsity_overlap=sparsity_overlap, + merging_mode=merging_mode, + return_masks=True, + ) - for unit_index, unit_id in enumerate(all_unit_ids): - if unit_id in new_unit_ids: - merge_unit_group = tuple(merge_unit_groups[new_unit_ids.index(unit_id)]) - if not mergeable[merge_unit_group]: - raise Exception( - f"The sparsity of {merge_unit_group} do not overlap enough for a soft merge using " - f"a sparsity threshold of {sparsity_overlap}. You can either lower the threshold or use " - "a hard merge." 
- ) + for unit_index, unit_id in enumerate(all_unit_ids): + if unit_id in merge_new_unit_ids: + merge_unit_group = tuple(merge_unit_groups[merge_new_unit_ids.index(unit_id)]) + if not mergeable[merge_unit_group]: + raise Exception( + f"The sparsity of {merge_unit_group} does not overlap enough for a soft merge using " + f"a sparsity threshold of {sparsity_overlap}. You can either lower the threshold or use " + "a hard merge." + ) + else: + sparsity_mask[unit_index] = masks[merge_unit_group] else: - sparsity_mask[unit_index] = masks[merge_unit_group] - else: - # This means that the unit is already in the previous sorting - index = self.sorting.id_to_index(unit_id) - sparsity_mask[unit_index] = self.sparsity.mask[index] - sparsity = ChannelSparsity(sparsity_mask, list(all_unit_ids), self.channel_ids) + # This means that the unit is already in the previous sorting + index = self.sorting.id_to_index(unit_id) + sparsity_mask[unit_index] = self.sparsity.mask[index] + sparsity = ChannelSparsity(sparsity_mask, list(all_unit_ids), self.channel_ids) + elif has_splits: + # split units + all_unit_ids = unit_ids + original_unit_ids = self.unit_ids + sparsity_mask = np.zeros((len(all_unit_ids), self.sparsity.mask.shape[1]), dtype=bool) + for unit_index, unit_id in enumerate(all_unit_ids): + if unit_id not in original_unit_ids: + # then it is a new unit + # we assign the original sparsity + for split_unit, new_unit_ids in zip(split_units, split_new_unit_ids): + if unit_id in new_unit_ids: + original_unit_index = self.sorting.id_to_index(split_unit) + sparsity_mask[unit_index] = self.sparsity.mask[original_unit_index] + break + else: + original_unit_index = self.sorting.id_to_index(unit_id) + sparsity_mask[unit_index] = self.sparsity.mask[original_unit_index] + sparsity = ChannelSparsity(sparsity_mask, list(all_unit_ids), self.channel_ids) else: sparsity = None @@ -988,25 +1032,33 @@ # if the original sorting object is not available anymore (kilosort folder deleted, ....), take the copy sorting_provenance = self.sorting - if merge_unit_groups is None: + if merge_unit_groups is None and split_units is None: # when only some unit_ids then the sorting must be sliced # TODO check that unit_ids are in same order otherwise many extension do handle it properly!!!! sorting_provenance = sorting_provenance.select_units(unit_ids) - else: + elif merge_unit_groups is not None: + assert split_units is None, "split_units must be None when merge_unit_groups is not None" from spikeinterface.core.sorting_tools import apply_merges_to_sorting sorting_provenance, keep_mask, _ = apply_merges_to_sorting( sorting=sorting_provenance, merge_unit_groups=merge_unit_groups, - new_unit_ids=new_unit_ids, + new_unit_ids=merge_new_unit_ids, censor_ms=censor_ms, return_extra=True, ) if censor_ms is None: # in this case having keep_mask None is faster instead of having a vector of ones keep_mask = None - # TODO: sam/pierre would create a curation field / curation.json with the applied merges. - # What do you think? 
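# Editorial aside (illustrative sketch, not part of the patch): the `split_units` dict consumed by the
# split branch below follows the apply_splits_to_sorting() contract added in sorting_tools.py: one entry
# per unit to split, whose sublists of spike indices must together cover every spike of that unit.
# `sorting` and the unit id "u1" are placeholders.
import numpy as np
from spikeinterface.core import apply_splits_to_sorting

n_spikes = sorting.count_num_spikes_per_unit()["u1"]
unit_splits = {"u1": [np.arange(n_spikes // 2), np.arange(n_spikes // 2, n_spikes)]}

split_sorting, new_unit_ids = apply_splits_to_sorting(
    sorting, unit_splits, return_extra=True, new_id_strategy="append"
)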
+ elif split_units is not None: + assert merge_unit_groups is None, "merge_unit_groups must be None when split_units is not None" + from spikeinterface.core.sorting_tools import apply_splits_to_sorting + + sorting_provenance = apply_splits_to_sorting( + sorting=sorting_provenance, + unit_splits=split_units, + new_unit_ids=split_new_unit_ids, + ) backend_options = {} if backend_options is None else backend_options @@ -1055,26 +1107,34 @@ def _save_or_select_or_merge( recompute_dict = {} for extension_name, extension in sorted_extensions.items(): - if merge_unit_groups is None: + if merge_unit_groups is None and split_units is None: # copy full or select new_sorting_analyzer.extensions[extension_name] = extension.copy( new_sorting_analyzer, unit_ids=unit_ids ) - else: + elif merge_unit_groups is not None: # merge if merging_mode == "soft": new_sorting_analyzer.extensions[extension_name] = extension.merge( new_sorting_analyzer, merge_unit_groups=merge_unit_groups, - new_unit_ids=new_unit_ids, + new_unit_ids=merge_new_unit_ids, keep_mask=keep_mask, verbose=verbose, **job_kwargs, ) elif merging_mode == "hard": recompute_dict[extension_name] = extension.params + else: + # split + if splitting_mode == "soft": + new_sorting_analyzer.extensions[extension_name] = extension.split( + new_sorting_analyzer, split_units=split_units, new_unit_ids=split_new_unit_ids, verbose=verbose + ) + elif splitting_mode == "hard": + recompute_dict[extension_name] = extension.params - if merge_unit_groups is not None and merging_mode == "hard" and len(recompute_dict) > 0: + if len(recompute_dict) > 0: new_sorting_analyzer.compute_several_extensions(recompute_dict, save=True, verbose=verbose, **job_kwargs) return new_sorting_analyzer @@ -1102,7 +1162,7 @@ def save_as(self, format="memory", folder=None, backend_options=None) -> "Sortin """ if format == "zarr": folder = clean_zarr_folder_name(folder) - return self._save_or_select_or_merge(format=format, folder=folder, backend_options=backend_options) + return self._save_or_select_or_merge_or_split(format=format, folder=folder, backend_options=backend_options) def select_units(self, unit_ids, format="memory", folder=None) -> "SortingAnalyzer": """ @@ -1129,7 +1189,7 @@ def select_units(self, unit_ids, format="memory", folder=None) -> "SortingAnalyz # TODO check that unit_ids are in same order otherwise many extension do handle it properly!!!! 
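# --- Editor's note: illustrative usage sketch, not part of the patch ---
# The split branch above is reached through the new SortingAnalyzer.split_units() API.
# A minimal, hypothetical example that splits one unit into two halves by spike index,
# mirroring the test added further down in this patch.
import numpy as np
from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer

recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=5, seed=0)
analyzer = create_sorting_analyzer(sorting, recording, sparse=False)

unit_to_split = analyzer.unit_ids[0]
num_spikes = analyzer.sorting.count_num_spikes_per_unit()[unit_to_split]
split_units = {unit_to_split: [np.arange(num_spikes // 2), np.arange(num_spikes // 2, num_spikes)]}

split_analyzer, new_unit_ids = analyzer.split_units(split_units=split_units, return_new_unit_ids=True)
# `unit_to_split` is replaced by two new units whose ids are returned in `new_unit_ids`;
# loaded extensions are updated through `_split_extension_data()` when splitting is done
# in "soft" mode, or recomputed in "hard" mode.
# --- end of editor's note ---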
if format == "zarr": folder = clean_zarr_folder_name(folder) - return self._save_or_select_or_merge(format=format, folder=folder, unit_ids=unit_ids) + return self._save_or_select_or_merge_or_split(format=format, folder=folder, unit_ids=unit_ids) def remove_units(self, remove_unit_ids, format="memory", folder=None) -> "SortingAnalyzer": """ @@ -1157,7 +1217,7 @@ def remove_units(self, remove_unit_ids, format="memory", folder=None) -> "Sortin unit_ids = self.unit_ids[~np.isin(self.unit_ids, remove_unit_ids)] if format == "zarr": folder = clean_zarr_folder_name(folder) - return self._save_or_select_or_merge(format=format, folder=folder, unit_ids=unit_ids) + return self._save_or_select_or_merge_or_split(format=format, folder=folder, unit_ids=unit_ids) def merge_units( self, @@ -1222,28 +1282,18 @@ def merge_units( assert merging_mode in ["soft", "hard"], "Merging mode should be either soft or hard" if len(merge_unit_groups) == 0: - # TODO I think we should raise an error or at least make a copy and not return itself - if return_new_unit_ids: - return self, [] - else: - return self + raise ValueError("Merging requires at least one group of units to merge") for units in merge_unit_groups: - # TODO more checks like one units is only in one group if len(units) < 2: raise ValueError("Merging requires at least two units to merge") - # TODO : no this function did not exists before - if not isinstance(merge_unit_groups[0], (list, tuple)): - # keep backward compatibility : the previous behavior was only one merge - merge_unit_groups = [merge_unit_groups] - new_unit_ids = generate_unit_ids_for_merge_group( self.unit_ids, merge_unit_groups, new_unit_ids, new_id_strategy ) all_unit_ids = _get_ids_after_merging(self.unit_ids, merge_unit_groups, new_unit_ids=new_unit_ids) - new_analyzer = self._save_or_select_or_merge( + new_analyzer = self._save_or_select_or_merge_or_split( format=format, folder=folder, merge_unit_groups=merge_unit_groups, @@ -1252,7 +1302,79 @@ def merge_units( merging_mode=merging_mode, sparsity_overlap=sparsity_overlap, verbose=verbose, - new_unit_ids=new_unit_ids, + merge_new_unit_ids=new_unit_ids, + **job_kwargs, + ) + if return_new_unit_ids: + return new_analyzer, new_unit_ids + else: + return new_analyzer + + def split_units( + self, + split_units: dict[list[str | int], list[int] | list[list[int]]], + new_unit_ids: list[list[int | str]] | None = None, + new_id_strategy: str = "append", + return_new_unit_ids: bool = False, + format: str = "memory", + folder: Path | str | None = None, + verbose: bool = False, + **job_kwargs, + ) -> "SortingAnalyzer | tuple[SortingAnalyzer, list[int | str]]": + """ + This method is equivalent to `save_as()` but with a list of splits that have to be achieved. + Split units by creating a new SortingAnalyzer object with the appropriate splits + + Extensions are also updated to display the split `unit_ids`. + + Parameters + ---------- + split_units : dict + A dictionary with the keys being the unit ids to split and the values being the split indices. + The split indices for each unit MUST be a list of lists, where each sublist (at least two) contains the + indices of the spikes to be assigned to the each split. The sum of the lengths of the sublists must equal + the number of spikes in the unit. + new_unit_ids : None | list, default: None + A new unit_ids for split units. If given, it needs to have the same length as `merge_unit_groups`. 
If None, + merged units will have the first unit_id of every lists of merges + new_id_strategy : "append" | "split", default: "append" + The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids. + + * "append" : new_units_ids will be added at the end of max(sorting.unit_ids) + * "split" : new_unit_ids will be the original unit_id to split with -{subsplit} + return_new_unit_ids : bool, default False + Alse return new_unit_ids which are the ids of the new units. + folder : Path | None, default: None + The new folder where the analyzer with merged units is copied for `format` "binary_folder" or "zarr" + format : "memory" | "binary_folder" | "zarr", default: "memory" + The format of SortingAnalyzer + verbose : bool, default: False + Whether to display calculations (such as sparsity estimation) + + Returns + ------- + analyzer : SortingAnalyzer + The newly create `SortingAnalyzer` with the selected units + """ + + if format == "zarr": + folder = clean_zarr_folder_name(folder) + + if len(split_units) == 0: + raise ValueError("Splitting requires at least one unit to split") + + check_unit_splits_consistency(split_units, self.sorting) + + new_unit_ids = generate_unit_ids_for_split(self.unit_ids, split_units, new_unit_ids, new_id_strategy) + all_unit_ids = _get_ids_after_splitting(self.unit_ids, split_units, new_unit_ids=new_unit_ids) + + new_analyzer = self._save_or_select_or_merge_or_split( + format=format, + folder=folder, + split_units=split_units, + unit_ids=all_unit_ids, + verbose=verbose, + split_new_unit_ids=new_unit_ids, **job_kwargs, ) if return_new_unit_ids: @@ -1264,7 +1386,7 @@ def copy(self): """ Create a a copy of SortingAnalyzer with format "memory". """ - return self._save_or_select_or_merge(format="memory", folder=None) + return self._save_or_select_or_merge_or_split(format="memory", folder=None) def is_read_only(self) -> bool: if self.format == "memory": @@ -2069,6 +2191,10 @@ def _merge_extension_data( # must be implemented in subclass raise NotImplementedError + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + # must be implemented in subclass + raise NotImplementedError + def _get_pipeline_nodes(self): # must be implemented in subclass only if use_nodepipeline=True raise NotImplementedError @@ -2219,6 +2345,7 @@ def load_data(self): ext_data_file.name == "params.json" or ext_data_file.name == "info.json" or ext_data_file.name == "run_info.json" + or str(ext_data_file.name).startswith("._") # ignore AppleDouble format files ): continue ext_data_name = ext_data_file.stem @@ -2304,6 +2431,23 @@ def merge( new_extension.save() return new_extension + def split( + self, + new_sorting_analyzer, + split_units, + new_unit_ids, + verbose=False, + **job_kwargs, + ): + new_extension = self.__class__(new_sorting_analyzer) + new_extension.params = self.params.copy() + new_extension.data = self._split_extension_data( + split_units, new_unit_ids, new_sorting_analyzer, verbose=verbose, **job_kwargs + ) + new_extension.run_info = copy(self.run_info) + new_extension.save() + return new_extension + def run(self, save=True, **kwargs): if save and not self.sorting_analyzer.is_read_only(): # NB: this call to _save_params() also resets the folder or zarr group diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 0e30760262..f760243fb5 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -21,8 +21,7 @@ * "snr" : threshold based on 
template signal-to-noise ratio. Use the "threshold" argument to specify the SNR threshold (in units of noise levels) and the "amplitude_mode" argument to specify the mode to compute the amplitude of the templates. - * "amplitude" : threshold based on the amplitude values on every channels. Use the "threshold" argument - to specify the ptp threshold (in units of amplitude) and the "amplitude_mode" argument + * "amplitude" : threshold based on the amplitude values on every channels. Use the "amplitude_mode" argument to specify the mode to compute the amplitude of the templates. * "energy" : threshold based on the expected energy that should be present on the channels, given their noise levels. Use the "threshold" argument to specify the energy threshold @@ -30,7 +29,6 @@ * "by_property" : sparsity is given by a property of the recording and sorting (e.g. "group"). In this case the sparsity for each unit is given by the channels that have the same property value as the unit. Use the "by_property" argument to specify the property name. - * "ptp: : deprecated, use the 'snr' method with the 'peak_to_peak' amplitude mode instead. peak_sign : "neg" | "pos" | "both" Sign of the template to compute best channels. @@ -39,7 +37,7 @@ radius_um : float Radius in um for "radius" method. threshold : float - Threshold for "snr", "energy" (in units of noise levels) and "ptp" methods (in units of amplitude). + Threshold for "snr" and "energy" (in units of noise levels) (in units of amplitude). For the "snr" method, the template amplitude mode is controlled by the "amplitude_mode" argument. amplitude_mode : "extremum" | "at_index" | "peak_to_peak" Mode to compute the amplitude of the templates for the "snr", "amplitude", and "best_channels" methods. @@ -454,33 +452,6 @@ def from_snr( mask[unit_ind, chan_inds] = True return cls(mask, unit_ids, channel_ids) - @classmethod - def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): - """ - Construct sparsity from a thresholds based on template peak-to-peak values. - Use the "threshold" argument to specify the peak-to-peak threshold. - - Parameters - ---------- - templates_or_sorting_analyzer : Templates | SortingAnalyzer - A Templates or a SortingAnalyzer object. - threshold : float - Threshold for "ptp" method (in units of amplitude). - - Returns - ------- - sparsity : ChannelSparsity - The estimated sparsity. - """ - warnings.warn( - "The 'ptp' method is deprecated and will be removed in version 0.103.0. 
" - "Please use the 'snr' method with the 'peak_to_peak' amplitude mode instead.", - DeprecationWarning, - ) - return cls.from_snr( - templates_or_sorting_analyzer, threshold, amplitude_mode="peak_to_peak", noise_levels=noise_levels - ) - @classmethod def from_amplitude(cls, templates_or_sorting_analyzer, threshold, amplitude_mode="extremum", peak_sign="neg"): """ @@ -635,9 +606,7 @@ def create_dense(cls, sorting_analyzer): def compute_sparsity( templates_or_sorting_analyzer: "Templates | SortingAnalyzer", noise_levels: np.ndarray | None = None, - method: ( - "radius" | "best_channels" | "closest_channels" | "snr" | "amplitude" | "energy" | "by_property" | "ptp" - ) = "radius", + method: "radius" | "best_channels" | "closest_channels" | "snr" | "amplitude" | "energy" | "by_property" = "radius", peak_sign: "neg" | "pos" | "both" = "neg", num_channels: int | None = 5, radius_um: float | None = 100.0, @@ -672,7 +641,13 @@ def compute_sparsity( # to keep backward compatibility templates_or_sorting_analyzer = templates_or_sorting_analyzer.sorting_analyzer - if method in ("best_channels", "closest_channels", "radius", "snr", "amplitude", "ptp"): + if method in ( + "best_channels", + "closest_channels", + "radius", + "snr", + "amplitude", + ): assert isinstance( templates_or_sorting_analyzer, (Templates, SortingAnalyzer) ), f"compute_sparsity(method='{method}') need Templates or SortingAnalyzer" @@ -715,14 +690,6 @@ def compute_sparsity( sparsity = ChannelSparsity.from_property( templates_or_sorting_analyzer.sorting, templates_or_sorting_analyzer.recording, by_property ) - elif method == "ptp": - # TODO: remove after deprecation - assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" - sparsity = ChannelSparsity.from_ptp( - templates_or_sorting_analyzer, - threshold, - noise_levels=noise_levels, - ) else: raise ValueError(f"compute_sparsity() method={method} does not exists") @@ -738,7 +705,7 @@ def estimate_sparsity( num_spikes_for_sparsity: int = 100, ms_before: float = 1.0, ms_after: float = 2.5, - method: "radius" | "best_channels" | "closest_channels" | "amplitude" | "snr" | "by_property" | "ptp" = "radius", + method: "radius" | "best_channels" | "closest_channels" | "amplitude" | "snr" | "by_property" = "radius", peak_sign: "neg" | "pos" | "both" = "neg", radius_um: float = 100.0, num_channels: int = 5, @@ -787,9 +754,9 @@ def estimate_sparsity( # Can't be done at module because this is a cyclic import, too bad from .template import Templates - assert method in ("radius", "best_channels", "closest_channels", "snr", "amplitude", "by_property", "ptp"), ( + assert method in ("radius", "best_channels", "closest_channels", "snr", "amplitude", "by_property"), ( f"method={method} is not available for `estimate_sparsity()`. " - "Available methods are 'radius', 'best_channels', 'snr', 'amplitude', 'by_property', 'ptp' (deprecated)" + "Available methods are 'radius', 'best_channels', 'snr', 'amplitude', 'by_property'" ) if recording.get_probes() == 1: @@ -866,14 +833,6 @@ def estimate_sparsity( sparsity = ChannelSparsity.from_amplitude( templates, threshold, amplitude_mode=amplitude_mode, peak_sign=peak_sign ) - elif method == "ptp": - # TODO: remove after deprecation - assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" - assert noise_levels is not None, ( - "For the 'snr' method, 'noise_levels' needs to be given. You can use the " - "`get_noise_levels()` function to compute them." 
- ) - sparsity = ChannelSparsity.from_ptp(templates, threshold, noise_levels=noise_levels) else: raise ValueError(f"compute_sparsity() method={method} does not exists") else: diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index c5c4e9db63..ce248f00f6 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -250,7 +250,7 @@ def test_SortingAnalyzer_tmp_recording(dataset): assert not sorting_analyzer_saved.has_temporary_recording() assert isinstance(sorting_analyzer_saved.recording, type(recording)) - recording_sliced = recording.channel_slice(recording.channel_ids[:-1]) + recording_sliced = recording.select_channels(recording.channel_ids[:-1]) # wrong channels with pytest.raises(ValueError): @@ -378,6 +378,7 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): # unit 0, 2, ... should be removed assert np.all(~np.isin(data["result_two"], [0, 2])) + # test merges if format != "memory": if format == "zarr": folder = cache_folder / f"test_SortingAnalyzer_merge_soft_with_{format}.zarr" @@ -387,10 +388,14 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): shutil.rmtree(folder) else: folder = None - sorting_analyzer4 = sorting_analyzer.merge_units(merge_unit_groups=[[0, 1]], format=format, folder=folder) + sorting_analyzer4, new_unit_ids = sorting_analyzer.merge_units( + merge_unit_groups=[[0, 1]], format=format, folder=folder, return_new_unit_ids=True + ) assert 0 not in sorting_analyzer4.unit_ids assert 1 not in sorting_analyzer4.unit_ids assert len(sorting_analyzer4.unit_ids) == len(sorting_analyzer.unit_ids) - 1 + is_merged_values = sorting_analyzer4.sorting.get_property("is_merged") + assert is_merged_values[sorting_analyzer4.sorting.ids_to_indices(new_unit_ids)][0] if format != "memory": if format == "zarr": @@ -401,13 +406,50 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): shutil.rmtree(folder) else: folder = None - sorting_analyzer5 = sorting_analyzer.merge_units( - merge_unit_groups=[[0, 1]], new_unit_ids=[50], format=format, folder=folder, merging_mode="hard" + sorting_analyzer5, new_unit_ids = sorting_analyzer.merge_units( + merge_unit_groups=[[0, 1]], + new_unit_ids=[50], + format=format, + folder=folder, + merging_mode="hard", + return_new_unit_ids=True, ) assert 0 not in sorting_analyzer5.unit_ids assert 1 not in sorting_analyzer5.unit_ids assert len(sorting_analyzer5.unit_ids) == len(sorting_analyzer.unit_ids) - 1 assert 50 in sorting_analyzer5.unit_ids + is_merged_values = sorting_analyzer5.sorting.get_property("is_merged") + assert is_merged_values[sorting_analyzer5.sorting.id_to_index(50)] + + # test splitting + if format != "memory": + if format == "zarr": + folder = cache_folder / f"test_SortingAnalyzer_split_with_{format}.zarr" + else: + folder = cache_folder / f"test_SortingAnalyzer_split_with_{format}" + if folder.exists(): + shutil.rmtree(folder) + else: + folder = None + split_units = {} + num_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit() + units_to_split = sorting_analyzer.unit_ids[:2] + for unit in units_to_split: + for unit in units_to_split: + split_units[unit] = [ + np.arange(num_spikes[unit] // 2), + np.arange(num_spikes[unit] // 2, num_spikes[unit]), + ] + + sorting_analyzer6, split_new_unit_ids = sorting_analyzer.split_units( + split_units=split_units, format=format, folder=folder, return_new_unit_ids=True + 
) + for unit_to_split in units_to_split: + assert unit_to_split not in sorting_analyzer6.unit_ids + assert len(sorting_analyzer6.unit_ids) == len(sorting_analyzer.unit_ids) + 2 + is_split_values = sorting_analyzer6.sorting.get_property("is_split") + for new_unit_ids in split_new_unit_ids: + assert all(is_split_values[sorting_analyzer6.sorting.ids_to_indices(new_unit_ids)]) # test compute with extension-specific params sorting_analyzer.compute(["dummy"], extension_params={"dummy": {"param1": 5.5}}) @@ -491,6 +533,14 @@ def _merge_extension_data( return new_data + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + new_data = dict() + new_data["result_one"] = self.data["result_one"] + spikes = new_sorting_analyzer.sorting.to_spike_vector() + new_data["result_two"] = spikes["unit_index"].copy() + new_data["result_three"] = np.zeros((len(new_sorting_analyzer.unit_ids), 2)) + return new_data + def _get_data(self): return self.data["result_one"] diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index 6ed311b5d8..c865068e4a 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -268,24 +268,8 @@ def test_estimate_sparsity(): progress_bar=True, n_jobs=1, ) - # ptp: just run it print(noise_levels) - with pytest.warns(DeprecationWarning): - sparsity = estimate_sparsity( - sorting, - recording, - num_spikes_for_sparsity=50, - ms_before=1.0, - ms_after=2.0, - method="ptp", - threshold=5, - noise_levels=noise_levels, - chunk_duration="1s", - progress_bar=True, - n_jobs=1, - ) - def test_compute_sparsity(): recording, sorting = get_dataset() @@ -310,8 +294,6 @@ def test_compute_sparsity(): sparsity = compute_sparsity(sorting_analyzer, method="amplitude", threshold=5, amplitude_mode="peak_to_peak") sparsity = compute_sparsity(sorting_analyzer, method="energy", threshold=5) sparsity = compute_sparsity(sorting_analyzer, method="by_property", by_property="group") - with pytest.warns(DeprecationWarning): - sparsity = compute_sparsity(sorting_analyzer, method="ptp", threshold=5) # using object Templates templates = sorting_analyzer.get_extension("templates").get_data(outputs="Templates") @@ -322,9 +304,6 @@ def test_compute_sparsity(): sparsity = compute_sparsity(templates, method="amplitude", threshold=5, amplitude_mode="peak_to_peak") sparsity = compute_sparsity(templates, method="closest_channels", num_channels=2) - with pytest.warns(DeprecationWarning): - sparsity = compute_sparsity(templates, method="ptp", noise_levels=noise_levels, threshold=5) - if __name__ == "__main__": # test_ChannelSparsity() diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index 2ae435ea84..f22939c33c 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -411,9 +411,10 @@ def _check_spike_times_are_correct(self, sorting, times_recording, segment_index spike_indexes = sorting.get_unit_spike_train(unit_id, segment_index=segment_index) rec_times = times_recording.get_times(segment_index=segment_index) + times_in_recording = rec_times[spike_indexes] assert np.array_equal( spike_times, - rec_times[spike_indexes], + times_in_recording, ) def _get_sorting_with_recording_attached(self, recording_for_durations, recording_to_attach): diff --git a/src/spikeinterface/core/tests/test_unitsaggregationsorting.py 
b/src/spikeinterface/core/tests/test_unitsaggregationsorting.py
index f60de5b62a..fadac094aa 100644
--- a/src/spikeinterface/core/tests/test_unitsaggregationsorting.py
+++ b/src/spikeinterface/core/tests/test_unitsaggregationsorting.py
@@ -149,5 +149,20 @@ def test_unit_aggregation_does_not_preserve_ids_not_the_same_type():
     assert list(aggregated_sorting.get_unit_ids()) == ["0", "1", "2", "3", "4"]


+def test_sampling_frequency_max_diff():
+    """Test that the sampling frequency max diff is respected."""
+    sorting1 = generate_sorting(sampling_frequency=30000, num_units=3)
+    sorting2 = generate_sorting(sampling_frequency=30000.01, num_units=3)
+    sorting3 = generate_sorting(sampling_frequency=30000.001, num_units=3)
+
+    # The default tolerance is 0, so mismatched sampling frequencies should raise an error
+    with pytest.raises(ValueError):
+        aggregate_units([sorting1, sorting2, sorting3])
+
+    # With a tolerance of 0.02 the aggregation succeeds, but a warning is raised
+    with pytest.warns(UserWarning):
+        aggregate_units([sorting1, sorting2, sorting3], sampling_frequency_max_diff=0.02)
+
+
 if __name__ == "__main__":
-    test_unitsaggregationsorting()
+    test_sampling_frequency_max_diff()
diff --git a/src/spikeinterface/core/unitsaggregationsorting.py b/src/spikeinterface/core/unitsaggregationsorting.py
index 8f4a2732c3..838660df46 100644
--- a/src/spikeinterface/core/unitsaggregationsorting.py
+++ b/src/spikeinterface/core/unitsaggregationsorting.py
@@ -1,11 +1,13 @@
 from __future__ import annotations

+import math
 import warnings

 import numpy as np

-from .core_tools import define_function_from_class
-from .base import BaseExtractor
-from .basesorting import BaseSorting, BaseSortingSegment
+from spikeinterface.core.core_tools import define_function_from_class
+from spikeinterface.core.base import BaseExtractor
+from spikeinterface.core.basesorting import BaseSorting, BaseSortingSegment
+from spikeinterface.core.segmentutils import _check_sampling_frequencies

 class UnitsAggregationSorting(BaseSorting):
@@ -18,6 +20,8 @@ class UnitsAggregationSorting(BaseSorting):
         List of BaseSorting objects to aggregate
     renamed_unit_ids: array-like
         If given, unit ids are renamed as provided. If None, unit ids are sequential integers.
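# --- Editor's note: illustrative usage sketch, not part of the patch ---
# How the new `sampling_frequency_max_diff` tolerance (documented just below) is used from
# the public aggregate_units() helper; the frequencies shown here are hypothetical.
from spikeinterface.core import generate_sorting, aggregate_units

sorting_a = generate_sorting(sampling_frequency=30000.0, num_units=3)
sorting_b = generate_sorting(sampling_frequency=30000.01, num_units=3)

# with the default tolerance of 0 this mismatch raises a ValueError;
# with an explicit tolerance the sortings are aggregated and a warning reports the difference
aggregated = aggregate_units([sorting_a, sorting_b], sampling_frequency_max_diff=0.02)
# --- end of editor's note ---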
+ sampling_frequency_max_diff : float, default: 0 + Maximum allowed difference of sampling frequencies across recordings Returns ------- @@ -25,7 +29,7 @@ class UnitsAggregationSorting(BaseSorting): The aggregated sorting object """ - def __init__(self, sorting_list, renamed_unit_ids=None): + def __init__(self, sorting_list, renamed_unit_ids=None, sampling_frequency_max_diff=0): unit_map = {} num_all_units = sum([sort.get_num_units() for sort in sorting_list]) @@ -59,13 +63,14 @@ def __init__(self, sorting_list, renamed_unit_ids=None): unit_map[unit_ids[u_id]] = {"sorting_id": s_i, "unit_id": unit_id} u_id += 1 - sampling_frequency = sorting_list[0].get_sampling_frequency() + sampling_frequencies = [sort.sampling_frequency for sort in sorting_list] num_segments = sorting_list[0].get_num_segments() - ok1 = all(sampling_frequency == sort.get_sampling_frequency() for sort in sorting_list) - ok2 = all(num_segments == sort.get_num_segments() for sort in sorting_list) - if not (ok1 and ok2): - raise ValueError("Sortings don't have the same sampling_frequency/num_segments") + _check_sampling_frequencies(sampling_frequencies, sampling_frequency_max_diff) + sampling_frequency = sampling_frequencies[0] + num_segments_ok = all(num_segments == sort.get_num_segments() for sort in sorting_list) + if not num_segments_ok: + raise ValueError("Sortings don't have the same num_segments") BaseSorting.__init__(self, sampling_frequency, unit_ids) diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py index d5697fabcb..70c1cf55e3 100644 --- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py +++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py @@ -30,7 +30,7 @@ # extract_waveforms() and WaveformExtractor() have been replaced by the `SortingAnalyzer` since version 0.101.0. # You should use `spikeinterface.create_sorting_analyzer()` instead. # `spikeinterface.extract_waveforms()` is now mocking the old behavior for backwards compatibility only, -# and will be removed with version 0.103.0 +# and may potentially be removed in a future version. ####""" diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 8447728216..f9ace831d8 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -635,8 +635,9 @@ def get_potential_auto_merge( done by Aurelien Wyngaard and Victor Llobet. https://github.com/BarbourLab/lussac/blob/v1.0.0/postprocessing/merge_units.py """ + # deprecation moved to 0.105.0 for @zm711 warnings.warn( - "get_potential_auto_merge() is deprecated. Use compute_merge_unit_groups() instead", + "get_potential_auto_merge() is deprecated and will be removed in version 0.105.0. 
Use compute_merge_unit_groups() instead", DeprecationWarning, stacklevel=2, ) diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index bcd6bd9a4b..baea4399ff 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -1,8 +1,9 @@ from __future__ import annotations import numpy as np +from itertools import chain -from spikeinterface.core import BaseSorting, SortingAnalyzer, apply_merges_to_sorting +from spikeinterface.core import BaseSorting, SortingAnalyzer, apply_merges_to_sorting, apply_splits_to_sorting from spikeinterface.curation.curation_model import CurationModel @@ -100,34 +101,45 @@ def curation_label_to_dataframe(curation_dict_or_model: dict | CurationModel): def apply_curation_labels( - sorting: BaseSorting, new_unit_ids: list[int, str], curation_dict_or_model: dict | CurationModel + sorting_or_analyzer: BaseSorting | SortingAnalyzer, curation_dict_or_model: dict | CurationModel ): """ - Apply manual labels after merges. + Apply manual labels after merges/splits. Rules: - * label for non merge is applied first + * label for non merged units is applied first * for merged group, when exclusive=True, if all have the same label then this label is applied * for merged group, when exclusive=False, if one unit has the label then the new one have also it + * for split units, the original label is applied to all split units """ if isinstance(curation_dict_or_model, dict): curation_model = CurationModel(**curation_dict_or_model) else: curation_model = curation_dict_or_model + if isinstance(sorting_or_analyzer, BaseSorting): + sorting = sorting_or_analyzer + else: + sorting = sorting_or_analyzer.sorting + # Please note that manual_labels is done on the unit_ids before the merge!!! 
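# --- Editor's note: illustrative usage sketch, not part of the patch ---
# A minimal curation description (format version "2") combining a manual label and a split,
# applied through apply_curation(); the unit ids, label names and spike indices are
# hypothetical. The label of the split unit is propagated to the new sub-units by
# apply_curation_labels().
import numpy as np
from spikeinterface.core import generate_ground_truth_recording
from spikeinterface.curation import apply_curation

recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=3, seed=0)
sorting = sorting.rename_units(np.array([1, 2, 3]))

curation = {
    "format_version": "2",
    "unit_ids": [1, 2, 3],
    "label_definitions": {"quality": {"label_options": ["good", "noise"], "exclusive": True}},
    "manual_labels": [{"unit_id": 2, "quality": ["good"]}],
    "splits": [{"unit_id": 2, "mode": "indices", "indices": [[0, 1, 2], [3, 4, 5]]}],
}

curated_sorting = apply_curation(sorting, curation)
# unit 2 is removed and replaced by new split units (plus one extra unit holding the
# remaining spikes, since the indices are incomplete); each inherits the "good" label
# --- end of editor's note ---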
manual_labels = curation_label_to_vectors(curation_model) - # apply on non merged + # apply on non merged / split + merge_new_unit_ids = [m.new_unit_id for m in curation_model.merges] + split_new_unit_ids = [m.new_unit_ids for m in curation_model.splits] + split_new_unit_ids = list(chain(*split_new_unit_ids)) + + merged_split_units = merge_new_unit_ids + split_new_unit_ids for key, values in manual_labels.items(): all_values = np.zeros(sorting.unit_ids.size, dtype=values.dtype) for unit_ind, unit_id in enumerate(sorting.unit_ids): - if unit_id not in new_unit_ids: + if unit_id not in merged_split_units: ind = list(curation_model.unit_ids).index(unit_id) all_values[unit_ind] = values[ind] sorting.set_property(key, all_values) - for new_unit_id, merge in zip(new_unit_ids, curation_model.merges): + for new_unit_id, merge in zip(merge_new_unit_ids, curation_model.merges): old_group_ids = merge.unit_ids for label_key, label_def in curation_model.label_definitions.items(): if label_def.exclusive: @@ -141,7 +153,6 @@ def apply_curation_labels( # all group has the same label or empty sorting.set_property(key, values=group_values[:1], ids=[new_unit_id]) else: - for key in label_def.label_options: group_values = [] for unit_id in old_group_ids: @@ -151,6 +162,23 @@ def apply_curation_labels( new_value = np.any(group_values) sorting.set_property(key, values=[new_value], ids=[new_unit_id]) + # splits + for split in curation_model.splits: + # propagate property of splut unit to new units + old_unit = split.unit_id + new_unit_ids = split.new_unit_ids + for label_key, label_def in curation_model.label_definitions.items(): + if label_def.exclusive: + ind = list(curation_model.unit_ids).index(old_unit) + value = manual_labels[label_key][ind] + if value != "": + sorting.set_property(label_key, values=[value] * len(new_unit_ids), ids=new_unit_ids) + else: + for key in label_def.label_options: + ind = list(curation_model.unit_ids).index(old_unit) + value = manual_labels[key][ind] + sorting.set_property(key, values=[value] * len(new_unit_ids), ids=new_unit_ids) + def apply_curation( sorting_or_analyzer: BaseSorting | SortingAnalyzer, @@ -168,7 +196,8 @@ def apply_curation( Steps are done in this order: 1. Apply removal using curation_dict["removed"] 2. Apply merges using curation_dict["merges"] - 3. Set labels using curation_dict["manual_labels"] + 3. Apply splits using curation_dict["splits"] + 4. Set labels using curation_dict["manual_labels"] A new Sorting or SortingAnalyzer (in memory) is returned. The user (an adult) has the responsability to save it somewhere (or not). @@ -202,11 +231,15 @@ def apply_curation( Returns ------- - sorting_or_analyzer : Sorting | SortingAnalyzer + curated_sorting_or_analyzer : Sorting | SortingAnalyzer The curated object. 
- - """ + assert isinstance( + sorting_or_analyzer, (BaseSorting, SortingAnalyzer) + ), f"`sorting_or_analyzer` must be a Sorting or a SortingAnalyzer, not an object of type {type(sorting_or_analyzer)}" + assert isinstance( + curation_dict_or_model, (dict, CurationModel) + ), f"`curation_dict_or_model` must be a dict or a CurationModel, not an object of type {type(curation_dict_or_model)}" if isinstance(curation_dict_or_model, dict): curation_model = CurationModel(**curation_dict_or_model) else: @@ -215,29 +248,29 @@ def apply_curation( if not np.array_equal(np.asarray(curation_model.unit_ids), sorting_or_analyzer.unit_ids): raise ValueError("unit_ids from the curation_dict do not match the one from Sorting or SortingAnalyzer") - if isinstance(sorting_or_analyzer, BaseSorting): - sorting = sorting_or_analyzer - sorting = sorting.remove_units(curation_model.removed) - if len(curation_model.merges) > 0: - sorting, _, new_unit_ids = apply_merges_to_sorting( - sorting, - merge_unit_groups=[m.unit_ids for m in curation_model.merges], + # 1. Remove units + if len(curation_model.removed) > 0: + curated_sorting_or_analyzer = sorting_or_analyzer.remove_units(curation_model.removed) + else: + curated_sorting_or_analyzer = sorting_or_analyzer + + # 2. Merge units + if len(curation_model.merges) > 0: + merge_unit_groups = [m.unit_ids for m in curation_model.merges] + merge_new_unit_ids = [m.new_unit_id for m in curation_model.merges if m.new_unit_id is not None] + if len(merge_new_unit_ids) == 0: + merge_new_unit_ids = None + if isinstance(sorting_or_analyzer, BaseSorting): + curated_sorting_or_analyzer, _, new_unit_ids = apply_merges_to_sorting( + curated_sorting_or_analyzer, + merge_unit_groups=merge_unit_groups, censor_ms=censor_ms, - return_extra=True, new_id_strategy=new_id_strategy, + return_extra=True, ) else: - new_unit_ids = [] - apply_curation_labels(sorting, new_unit_ids, curation_model) - return sorting - - elif isinstance(sorting_or_analyzer, SortingAnalyzer): - analyzer = sorting_or_analyzer - if len(curation_model.removed) > 0: - analyzer = analyzer.remove_units(curation_model.removed) - if len(curation_model.removed) > 0: - analyzer, new_unit_ids = analyzer.merge_units( - merge_unit_groups=[m.unit_ids for m in curation_model.merges], + curated_sorting_or_analyzer, new_unit_ids = curated_sorting_or_analyzer.merge_units( + merge_unit_groups=merge_unit_groups, censor_ms=censor_ms, merging_mode=merging_mode, sparsity_overlap=sparsity_overlap, @@ -247,11 +280,43 @@ def apply_curation( verbose=verbose, **job_kwargs, ) + for i, merge_unit_id in enumerate(new_unit_ids): + curation_model.merges[i].new_unit_id = merge_unit_id + + # 3. 
Split units + if len(curation_model.splits) > 0: + split_units = {} + for split in curation_model.splits: + sorting = ( + curated_sorting_or_analyzer + if isinstance(sorting_or_analyzer, BaseSorting) + else sorting_or_analyzer.sorting + ) + split_units[split.unit_id] = split.get_full_spike_indices(sorting) + split_new_unit_ids = [s.new_unit_ids for s in curation_model.splits if s.new_unit_ids is not None] + if len(split_new_unit_ids) == 0: + split_new_unit_ids = None + if isinstance(sorting_or_analyzer, BaseSorting): + curated_sorting_or_analyzer, new_unit_ids = apply_splits_to_sorting( + curated_sorting_or_analyzer, + split_units, + new_unit_ids=split_new_unit_ids, + new_id_strategy=new_id_strategy, + return_extra=True, + ) else: - new_unit_ids = [] - apply_curation_labels(analyzer.sorting, new_unit_ids, curation_model) - return analyzer - else: - raise TypeError( - f"`sorting_or_analyzer` must be a Sorting or a SortingAnalyzer, not an object of type {type(sorting_or_analyzer)}" - ) + curated_sorting_or_analyzer, new_unit_ids = curated_sorting_or_analyzer.split_units( + split_units, + new_id_strategy=new_id_strategy, + return_new_unit_ids=True, + new_unit_ids=split_new_unit_ids, + format="memory", + verbose=verbose, + ) + for i, split_unit_ids in enumerate(new_unit_ids): + curation_model.splits[i].new_unit_ids = split_unit_ids + + # 4. Apply labels + apply_curation_labels(curated_sorting_or_analyzer, curation_model) + + return curated_sorting_or_analyzer diff --git a/src/spikeinterface/curation/curation_model.py b/src/spikeinterface/curation/curation_model.py index cc5bc8e4c1..5eb025ace4 100644 --- a/src/spikeinterface/curation/curation_model.py +++ b/src/spikeinterface/curation/curation_model.py @@ -3,6 +3,8 @@ from itertools import chain, combinations import numpy as np +from spikeinterface import BaseSorting + class LabelDefinition(BaseModel): name: str = Field(description="Name of the label") @@ -31,12 +33,11 @@ class Split(BaseModel): "If labels, the split is defined by a list of labels for each spike (`labels`). " ), ) - indices: Optional[Union[List[int], List[List[int]]]] = Field( + indices: Optional[List[List[int]]] = Field( default=None, description=( - "List of indices for the split. If a list of indices, the unit is splt in 2 (provided indices/others). " - "If a list of lists, the unit is split in multiple groups (one for each list of indices), plus an optional " - "extra if the spike train has more spikes than the sum of the indices in the lists." + "List of indices for the split. The unit is split in multiple groups (one for each list of indices), " + "plus an optional extra if the spike train has more spikes than the sum of the indices in the lists." ), ) labels: Optional[List[int]] = Field(default=None, description="List of labels for the split") @@ -44,6 +45,35 @@ class Split(BaseModel): default=None, description="List of new unit IDs for each split" ) + def get_full_spike_indices(self, sorting: BaseSorting): + """ + Get the full indices of the spikes in the split for different split modes. 
+ """ + num_spikes = sorting.count_num_spikes_per_unit()[self.unit_id] + if self.mode == "indices": + # check the sum of split_indices is equal to num_spikes + num_spikes_in_split = sum(len(indices) for indices in self.indices) + if num_spikes_in_split != num_spikes: + # add remaining spike indices + full_spike_indices = list(self.indices) + existing_indices = np.concatenate(self.indices) + remaining_indices = np.setdiff1d(np.arange(num_spikes), existing_indices) + full_spike_indices.append(remaining_indices) + else: + full_spike_indices = self.indices + elif self.mode == "labels": + assert len(self.labels) == num_spikes, ( + f"In 'labels' mode, the number of.labels ({len(self.labels)}) " + f"must match the number of spikes in the unit ({num_spikes})" + ) + # convert to spike indices + full_spike_indices = [] + for label in np.unique(self.labels): + label_indices = np.where(self.labels == label)[0] + full_spike_indices.append(label_indices) + + return full_spike_indices + class CurationModel(BaseModel): supported_versions: Tuple[Literal["1"], Literal["2"]] = Field( @@ -74,7 +104,6 @@ def add_label_definition_name(cls, label_definitions): @classmethod def check_manual_labels(cls, values): - unit_ids = list(values["unit_ids"]) manual_labels = values.get("manual_labels") if manual_labels is None: @@ -106,7 +135,6 @@ def check_manual_labels(cls, values): @classmethod def check_merges(cls, values): - unit_ids = list(values["unit_ids"]) merges = values.get("merges") if merges is None: @@ -229,7 +257,10 @@ def check_splits(cls, values): # Validate new unit IDs if split.new_unit_ids is not None: if split.mode == "indices": - if len(split.new_unit_ids) != len(split.indices): + if ( + len(split.new_unit_ids) != len(split.indices) + and len(split.new_unit_ids) != len(split.indices) + 1 + ): raise ValueError( f"Number of new unit IDs does not match number of splits for unit {split.unit_id}" ) diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py index 587673b91a..9707b0f082 100644 --- a/src/spikeinterface/curation/tests/test_curation_format.py +++ b/src/spikeinterface/curation/tests/test_curation_format.py @@ -127,10 +127,25 @@ # Test dictionary format for merges with string IDs curation_ids_str_dict = {**curation_ids_str, "merges": {"u50": ["u3", "u6"], "u51": ["u10", "u14", "u20"]}} +# This is a failure example with duplicated merge +duplicate_merge = curation_ids_int.copy() +duplicate_merge["merge_unit_groups"] = [[3, 6, 10], [10, 14, 20]] + # Test with splits curation_with_splits = { - **curation_ids_int, - "splits": [{"unit_id": 2, "mode": "indices", "indices": [[0, 1, 2], [3, 4, 5]], "new_unit_ids": [50, 51]}], + "format_version": "2", + "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], + "label_definitions": { + "quality": {"label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True}, + "putative_type": { + "label_options": ["excitatory", "inhibitory", "pyramidal", "mitral"], + "exclusive": False, + }, + }, + "manual_labels": [ + {"unit_id": 2, "quality": ["good"], "putative_type": ["excitatory", "pyramidal"]}, + ], + "splits": [{"unit_id": 2, "mode": "indices", "indices": [[0, 1, 2], [3, 4, 5]]}], } # Test dictionary format for splits @@ -243,15 +258,144 @@ def test_apply_curation(): assert sorting_curated.get_property("quality", ids=[2])[0] == "noise" assert sorting_curated.get_property("excitatory", ids=[2])[0] - # Test with splits - sorting_curated = apply_curation(sorting, curation_with_splits) - assert 
sorting_curated.get_property("quality", ids=[1])[0] == "good" - # Test analyzer analyzer_curated = apply_curation(analyzer, curation_ids_int) assert "quality" in analyzer_curated.sorting.get_property_keys() +def test_apply_curation_with_split(): + recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=9, seed=2205) + sorting = sorting.rename_units(np.array([1, 2, 3, 6, 10, 14, 20, 31, 42])) + analyzer = create_sorting_analyzer(sorting, recording, sparse=False) + + sorting_curated = apply_curation(sorting, curation_with_splits) + # the split indices are not complete, so an extra unit is added + assert len(sorting_curated.unit_ids) == len(sorting.unit_ids) + 2 + + assert 2 not in sorting_curated.unit_ids + split_unit_ids = [43, 44, 45] + for unit_id in split_unit_ids: + assert unit_id in sorting_curated.unit_ids + assert sorting_curated.get_property("quality", ids=[unit_id])[0] == "good" + assert sorting_curated.get_property("excitatory", ids=[unit_id])[0] + assert sorting_curated.get_property("pyramidal", ids=[unit_id])[0] + + analyzer_curated = apply_curation(analyzer, curation_with_splits) + assert len(analyzer_curated.sorting.unit_ids) == len(analyzer.sorting.unit_ids) + 2 + + assert 2 not in analyzer_curated.unit_ids + for unit_id in split_unit_ids: + assert unit_id in analyzer_curated.unit_ids + assert analyzer_curated.sorting.get_property("quality", ids=[unit_id])[0] == "good" + assert analyzer_curated.sorting.get_property("excitatory", ids=[unit_id])[0] + assert analyzer_curated.sorting.get_property("pyramidal", ids=[unit_id])[0] + + +def test_apply_curation_with_split_multi_segment(): + recording, sorting = generate_ground_truth_recording(durations=[10.0, 10.0], num_units=9, seed=2205) + sorting = sorting.rename_units(np.array([1, 2, 3, 6, 10, 14, 20, 31, 42])) + analyzer = create_sorting_analyzer(sorting, recording, sparse=False) + num_segments = sorting.get_num_segments() + + curation_with_splits_multi_segment = curation_with_splits.copy() + + # we make a split so that each subsplit will have all spikes from different segments + split_unit_id = curation_with_splits_multi_segment["splits"][0]["unit_id"] + sv = sorting.to_spike_vector() + unit_index = sorting.id_to_index(split_unit_id) + spikes_from_split_unit = sv[sv["unit_index"] == unit_index] + + split_indices = [] + cum_spikes = 0 + for segment_index in range(num_segments): + spikes_in_segment = spikes_from_split_unit[spikes_from_split_unit["segment_index"] == segment_index] + split_indices.append(np.arange(0, len(spikes_in_segment)) + cum_spikes) + cum_spikes += len(spikes_in_segment) + + curation_with_splits_multi_segment["splits"][0]["indices"] = split_indices + + sorting_curated = apply_curation(sorting, curation_with_splits_multi_segment) + + assert len(sorting_curated.unit_ids) == len(sorting.unit_ids) + 1 + assert 2 not in sorting_curated.unit_ids + assert 43 in sorting_curated.unit_ids + assert 44 in sorting_curated.unit_ids + + # check that spike trains are correctly split across segments + for seg_index in range(num_segments): + st_43 = sorting_curated.get_unit_spike_train(43, segment_index=seg_index) + st_44 = sorting_curated.get_unit_spike_train(44, segment_index=seg_index) + if seg_index == 0: + assert len(st_43) > 0 + assert len(st_44) == 0 + else: + assert len(st_43) == 0 + assert len(st_44) > 0 + + +def test_apply_curation_splits_with_mask(): + recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=9, seed=2205) + sorting = sorting.rename_units(np.array([1, 2, 
3, 6, 10, 14, 20, 31, 42])) + analyzer = create_sorting_analyzer(sorting, recording, sparse=False) + + # Get number of spikes for unit 2 + num_spikes = sorting.count_num_spikes_per_unit()[2] + + # Create split labels that assign spikes to 3 different clusters + split_labels = np.zeros(num_spikes, dtype=int) + split_labels[: num_spikes // 3] = 0 # First third to cluster 0 + split_labels[num_spikes // 3 : 2 * num_spikes // 3] = 1 # Second third to cluster 1 + split_labels[2 * num_spikes // 3 :] = 2 # Last third to cluster 2 + + curation_with_mask_split = { + "format_version": "2", + "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], + "label_definitions": { + "quality": {"label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True}, + "putative_type": { + "label_options": ["excitatory", "inhibitory", "pyramidal", "mitral"], + "exclusive": False, + }, + }, + "manual_labels": [ + {"unit_id": 2, "quality": ["good"], "putative_type": ["excitatory", "pyramidal"]}, + ], + "splits": [ + { + "unit_id": 2, + "mode": "labels", + "labels": split_labels.tolist(), + "new_unit_ids": [43, 44, 45], + } + ], + } + + sorting_curated = apply_curation(sorting, curation_with_mask_split) + + # Check results + assert len(sorting_curated.unit_ids) == len(sorting.unit_ids) + 2 # Original units - 1 (split) + 3 (new) + assert 2 not in sorting_curated.unit_ids # Original unit should be removed + + # Check new split units + split_unit_ids = [43, 44, 45] + for unit_id in split_unit_ids: + assert unit_id in sorting_curated.unit_ids + # Check properties are propagated + assert sorting_curated.get_property("quality", ids=[unit_id])[0] == "good" + assert sorting_curated.get_property("excitatory", ids=[unit_id])[0] + assert sorting_curated.get_property("pyramidal", ids=[unit_id])[0] + + # Check analyzer + analyzer_curated = apply_curation(analyzer, curation_with_mask_split) + assert len(analyzer_curated.sorting.unit_ids) == len(analyzer.sorting.unit_ids) + 2 + + # Verify split sizes + spike_counts = analyzer_curated.sorting.count_num_spikes_per_unit() + assert spike_counts[43] == num_spikes // 3 # First third + assert spike_counts[44] == num_spikes // 3 # Second third + assert spike_counts[45] == num_spikes - 2 * (num_spikes // 3) # Remainder + + if __name__ == "__main__": test_curation_format_validation() test_to_from_json() @@ -259,3 +403,5 @@ def test_apply_curation(): test_curation_label_to_vectors() test_curation_label_to_dataframe() test_apply_curation() + test_apply_curation_with_split_multi_segment() + test_apply_curation_splits_with_mask() diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index 728d352973..c70c49e8f8 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -55,7 +55,7 @@ def __init__( raise ImportError(self.installation_mesg) if cbin_file is not None: warnings.warn( - "The `cbin_file` argument is deprecated, please use `cbin_file_path` instead", + "The `cbin_file` argument is deprecated and will be removed in version 0.104.0, please use `cbin_file_path` instead", DeprecationWarning, stacklevel=2, ) diff --git a/src/spikeinterface/extractors/extractor_classes.py b/src/spikeinterface/extractors/extractor_classes.py index 8d4aa8ed32..a975f2da9e 100644 --- a/src/spikeinterface/extractors/extractor_classes.py +++ b/src/spikeinterface/extractors/extractor_classes.py @@ -19,6 +19,7 @@ # sorting/recording/event from neo from .neoextractors import * +from .neoextractors import read_neuroscope # non-NEO 
objects implemented in neo folder # keep for reference Currently pulling from neoextractor __init__ @@ -84,7 +85,7 @@ # * A mapping from the original class to its wrapper string (because of __all__) # * A mapping from format to the class wrapper for convenience (exposed to users for ease of use) # -# To achieve these there goals we do the following: +# To achieve these three goals we do the following: # # 1) we line up each class with its wrapper that returns a snakecase version of the class (in some docs called # the "function" version, although this is just a wrapper of the underlying class) @@ -161,7 +162,7 @@ # (e.g. 'intan' , 'kilosort') and values being the appropriate Extractor class returned as its wrapper # (e.g. IntanRecordingExtractor, KiloSortSortingExtractor) # An important note is the the formats are returned after performing `.lower()` so a format like -# SpikeGLX will be a key of 'spikeglx' +# SpikeGLX will have a key of 'spikeglx' # for example if we wanted to create a recording from an intan file we could do the following: # >>> recording = se.recording_extractor_full_dict['intan'](file_path='path/to/data.rhd') @@ -198,5 +199,6 @@ "snippets_extractor_full_dict", "read_binary", # convenience function for binary formats "read_zarr", + "read_neuroscope", # convenience function for neuroscope ] ) diff --git a/src/spikeinterface/extractors/neoextractors/__init__.py b/src/spikeinterface/extractors/neoextractors/__init__.py index d781d8c5ae..a90e0954f3 100644 --- a/src/spikeinterface/extractors/neoextractors/__init__.py +++ b/src/spikeinterface/extractors/neoextractors/__init__.py @@ -1,4 +1,5 @@ from .alphaomega import AlphaOmegaRecordingExtractor, AlphaOmegaEventExtractor, read_alphaomega, read_alphaomega_event +from .axon import AxonRecordingExtractor, read_axon from .axona import AxonaRecordingExtractor, read_axona from .biocam import BiocamRecordingExtractor, read_biocam from .blackrock import BlackrockRecordingExtractor, BlackrockSortingExtractor, read_blackrock, read_blackrock_sorting @@ -44,6 +45,7 @@ neo_recording_extractors_dict = { AlphaOmegaRecordingExtractor: dict(wrapper_string="read_alphaomega", wrapper_class=read_alphaomega), + AxonRecordingExtractor: dict(wrapper_string="read_axon", wrapper_class=read_axon), AxonaRecordingExtractor: dict(wrapper_string="read_axona", wrapper_class=read_axona), BiocamRecordingExtractor: dict(wrapper_string="read_biocam", wrapper_class=read_biocam), BlackrockRecordingExtractor: dict(wrapper_string="read_blackrock", wrapper_class=read_blackrock), diff --git a/src/spikeinterface/extractors/neoextractors/axon.py b/src/spikeinterface/extractors/neoextractors/axon.py new file mode 100644 index 0000000000..b2c47b697f --- /dev/null +++ b/src/spikeinterface/extractors/neoextractors/axon.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from pathlib import Path + +from spikeinterface.core.core_tools import define_function_from_class + +from .neobaseextractor import NeoBaseRecordingExtractor + + +class AxonRecordingExtractor(NeoBaseRecordingExtractor): + """ + Class for reading Axon Binary Format (ABF) files. + + Based on :py:class:`neo.rawio.AxonRawIO` + + Supports both ABF1 (pClamp ≤9) and ABF2 (pClamp ≥10) formats. + Can read data from pCLAMP and AxoScope software. + + Parameters + ---------- + file_path : str or Path + The ABF file path to load the recordings from. + stream_id : str or None, default: None + If there are several streams, specify the stream id you want to load. 
+ stream_name : str or None, default: None + If there are several streams, specify the stream name you want to load. + all_annotations : bool, default: False + Load exhaustively all annotations from neo. + use_names_as_ids : bool, default: False + Determines the format of the channel IDs used by the extractor. + + Examples + -------- + >>> from spikeinterface.extractors import read_axon + >>> recording = read_axon(file_path='path/to/file.abf') + """ + + NeoRawIOClass = "AxonRawIO" + + def __init__( + self, + file_path, + stream_id=None, + stream_name=None, + all_annotations: bool = False, + use_names_as_ids: bool = False, + ): + neo_kwargs = self.map_to_neo_kwargs(file_path) + NeoBaseRecordingExtractor.__init__( + self, + stream_id=stream_id, + stream_name=stream_name, + all_annotations=all_annotations, + use_names_as_ids=use_names_as_ids, + **neo_kwargs, + ) + + self._kwargs.update(dict(file_path=str(Path(file_path).absolute()))) + + @classmethod + def map_to_neo_kwargs(cls, file_path): + neo_kwargs = {"filename": str(file_path)} + return neo_kwargs + + +read_axon = define_function_from_class(source_class=AxonRecordingExtractor, name="read_axon") diff --git a/src/spikeinterface/extractors/neoextractors/neuroscope.py b/src/spikeinterface/extractors/neoextractors/neuroscope.py index 70a110eced..298a2d6109 100644 --- a/src/spikeinterface/extractors/neoextractors/neuroscope.py +++ b/src/spikeinterface/extractors/neoextractors/neuroscope.py @@ -3,6 +3,7 @@ import warnings from pathlib import Path from typing import Union, Optional +from xml.etree import ElementTree as Etree import numpy as np @@ -64,6 +65,7 @@ def __init__( if xml_file_path is not None: xml_file_path = str(Path(xml_file_path).absolute()) self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), xml_file_path=xml_file_path)) + self.xml_file_path = xml_file_path if xml_file_path is not None else Path(file_path).with_suffix(".xml") @classmethod def map_to_neo_kwargs(cls, file_path, xml_file_path=None): @@ -78,6 +80,71 @@ def map_to_neo_kwargs(cls, file_path, xml_file_path=None): return neo_kwargs + def _parse_xml_file(self, xml_file_path): + """ + Comes from NeuroPhy package by Diba Lab + """ + tree = Etree.parse(xml_file_path) + myroot = tree.getroot() + + for sf in myroot.findall("acquisitionSystem"): + n_channels = int(sf.find("nChannels").text) + + channel_groups, skipped_channels, anatomycolors = [], [], {} + for x in myroot.findall("anatomicalDescription"): + for y in x.findall("channelGroups"): + for z in y.findall("group"): + chan_group = [] + for chan in z.findall("channel"): + if int(chan.attrib["skip"]) == 1: + skipped_channels.append(int(chan.text)) + + chan_group.append(int(chan.text)) + if chan_group: + channel_groups.append(np.array(chan_group)) + + for x in myroot.findall("neuroscope"): + for y in x.findall("channels"): + for i, z in enumerate(y.findall("channelColors")): + try: + channel_id = str(z.find("channel").text) + color = z.find("color").text + + except AttributeError: + channel_id = i + color = "#0080ff" + anatomycolors[channel_id] = color + + discarded_channels = [ch for ch in range(n_channels) if all(ch not in group for group in channel_groups)] + kept_channels = [ch for ch in range(n_channels) if ch not in skipped_channels and ch not in discarded_channels] + + return channel_groups, kept_channels, discarded_channels, anatomycolors + + def _set_neuroscope_groups(self): + """ + Set the group ids and colors based on the xml file. 
+ These group ids are usually different brain/body anatomical areas, or shanks from multi-shank probes. + The group ids are set as a property of the recording extractor. + """ + n = self.get_num_channels() + group_ids = np.full(n, -1, dtype=int) # Initialize all positions to -1 + + channel_groups, kept_channels, discarded_channels, colors = self._parse_xml_file(self.xml_file_path) + for group_id, numbers in enumerate(channel_groups): + group_ids[numbers] = group_id # Assign group_id to the positions in `numbers` + self.set_property("neuroscope_group", group_ids) + discarded_ppty = np.full(n, False, dtype=bool) + discarded_ppty[discarded_channels] = True + self.set_property("discarded_channels", discarded_ppty) + self.set_property("colors", values=list(colors.values()), ids=list(colors.keys())) + + def prepare_neuroscope_for_ephyviewer(self): + """ + Prepare the recording extractor for ephyviewer by setting the group ids and colors. + This function is not called when the extractor is initialized, and the user must call it manually. + """ + self._set_neuroscope_groups() + class NeuroScopeSortingExtractor(BaseSorting): """ diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index 0db3cd4426..175551d0a1 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -80,8 +80,9 @@ def __init__( ignore_timestamps_errors: bool = None, ): if ignore_timestamps_errors is not None: + dep_msg = "OpenEphysLegacyRecordingExtractor: `ignore_timestamps_errors` is deprecated. It will be removed in version 0.104.0 and is currently ignored" warnings.warn( - "OpenEphysLegacyRecordingExtractor: ignore_timestamps_errors is deprecated and is ignored", + dep_msg, DeprecationWarning, stacklevel=2, ) @@ -161,8 +162,8 @@ def __init__( if load_sync_channel: warning_message = ( - "OpenEphysBinaryRecordingExtractor: load_sync_channel is deprecated and will" - "be removed in version 0.104, use the stream_name or stream_id to load the sync stream if needed" + "OpenEphysBinaryRecordingExtractor: `load_sync_channel` is deprecated and will" + "be removed in version 0.104, use the `stream_name` or `stream_id` to load the sync stream if needed" ) warnings.warn(warning_message, DeprecationWarning, stacklevel=2) @@ -353,20 +354,20 @@ def read_openephys(folder_path, **kwargs): load_sync_channel : bool, default: False If False (default) and a SYNC channel is present (e.g. Neuropixels), this is not loaded. If True, the SYNC channel is loaded and can be accessed in the analog signals. - For open ephsy binary format only + For open ephys binary format only load_sync_timestamps : bool, default: False If True, the synchronized_timestamps are loaded and set as times to the recording. If False (default), only the t_start and sampling rate are set, and timestamps are assumed to be uniform and linearly increasing. - For open ephsy binary format only + For open ephys binary format only experiment_names : str, list, or None, default: None If multiple experiments are available, this argument allows users to select one or more experiments. If None, all experiements are loaded as blocks. E.g. 
`experiment_names="experiment2"`, `experiment_names=["experiment1", "experiment2"]` - For open ephsy binary format only + For open ephys binary format only ignore_timestamps_errors : bool, default: False Ignore the discontinuous timestamps errors in neo - For open ephsy legacy format only + For open ephys legacy format only Returns diff --git a/src/spikeinterface/extractors/neoextractors/spikeglx.py b/src/spikeinterface/extractors/neoextractors/spikeglx.py index df3728fe6a..dd51187508 100644 --- a/src/spikeinterface/extractors/neoextractors/spikeglx.py +++ b/src/spikeinterface/extractors/neoextractors/spikeglx.py @@ -75,9 +75,12 @@ def __init__( self._kwargs.update(dict(folder_path=str(Path(folder_path).absolute()), load_sync_channel=load_sync_channel)) - stream_is_nidq_or_sync = "nidq" in self.stream_id or "SYNC" in self.stream_id - if stream_is_nidq_or_sync: - # Do not add probe information for the sync or nidq stream. Early return + stream_is_nidq = "nidq" in self.stream_id + stream_is_one_box = "obx" in self.stream_id + stream_is_sync = "SYNC" in self.stream_id + + if stream_is_nidq or stream_is_one_box or stream_is_sync: + # Do not add probe information for the one box, nidq or sync streams. Early return return None # Checks if the probe information is available and adds location, shanks and sample shift if available. diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index 65125efbcc..01a6f9e335 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -53,16 +53,6 @@ def read_file_from_backend( else: raise RuntimeError(f"{file_path} is not a valid HDF5 file!") - elif stream_mode == "ros3": - import h5py - - assert file_path is not None, "file_path must be specified when using stream_mode='ros3'" - - drivers = h5py.registered_drivers() - assertion_msg = "ROS3 support not enbabled, use: install -c conda-forge h5py>=3.2 to enable streaming" - assert "ros3" in drivers, assertion_msg - open_file = h5py.File(name=file_path, mode="r", driver="ros3") - elif stream_mode == "remfile": import remfile import h5py @@ -452,8 +442,6 @@ class NwbRecordingExtractor(BaseRecording, _BaseNWBExtractor): file_path : str, Path, or None Path to the NWB file or an s3 URL. Use this parameter to specify the file location if not using the `file` parameter. - electrical_series_name : str or None, default: None - Deprecated, use `electrical_series_path` instead. electrical_series_path : str or None, default: None The name of the ElectricalSeries object within the NWB file. This parameter is crucial when the NWB file contains multiple ElectricalSeries objects. It helps in identifying @@ -521,7 +509,6 @@ class NwbRecordingExtractor(BaseRecording, _BaseNWBExtractor): def __init__( self, file_path: str | Path | None = None, # provide either this or file - electrical_series_name: str | None = None, # deprecated load_time_vector: bool = False, samples_for_rate_estimation: int = 1_000, stream_mode: Optional[Literal["fsspec", "remfile", "zarr"]] = None, @@ -535,30 +522,11 @@ def __init__( use_pynwb: bool = False, ): - if stream_mode == "ros3": - warnings.warn( - "The 'ros3' stream_mode is deprecated and will be removed in version 0.103.0. 
" - "Use 'fsspec' stream_mode instead.", - DeprecationWarning, - ) - if file_path is not None and file is not None: raise ValueError("Provide either file_path or file, not both") if file_path is None and file is None: raise ValueError("Provide either file_path or file") - if electrical_series_name is not None: - warning_msg = ( - "The `electrical_series_name` parameter is deprecated and will be removed in version 0.101.0.\n" - "Use `electrical_series_path` instead." - ) - if electrical_series_path is None: - warning_msg += f"\nSetting `electrical_series_path` to 'acquisition/{electrical_series_name}'." - electrical_series_path = f"acquisition/{electrical_series_name}" - else: - warning_msg += f"\nIgnoring `electrical_series_name` and using the provided `electrical_series_path`." - warnings.warn(warning_msg, DeprecationWarning, stacklevel=2) - self.file_path = file_path self.stream_mode = stream_mode self.stream_cache_path = stream_cache_path @@ -1062,13 +1030,6 @@ def __init__( use_pynwb: bool = False, ): - if stream_mode == "ros3": - warnings.warn( - "The 'ros3' stream_mode is deprecated and will be removed in version 0.103.0. " - "Use 'fsspec' stream_mode instead.", - DeprecationWarning, - ) - self.stream_mode = stream_mode self.stream_cache_path = stream_cache_path self.electrical_series_path = electrical_series_path @@ -1360,6 +1321,49 @@ def get_unit_spike_train( start_frame: Optional[int] = None, end_frame: Optional[int] = None, ) -> np.ndarray: + # Convert frame boundaries to time boundaries + start_time = None + end_time = None + + if start_frame is not None: + start_time = start_frame / self._sampling_frequency + self._t_start + + if end_frame is not None: + end_time = end_frame / self._sampling_frequency + self._t_start + + # Get spike times in seconds + spike_times = self.get_unit_spike_train_in_seconds(unit_id=unit_id, start_time=start_time, end_time=end_time) + + # Convert to frames + frames = np.round((spike_times - self._t_start) * self._sampling_frequency) + return frames.astype("int64", copy=False) + + def get_unit_spike_train_in_seconds( + self, + unit_id, + start_time: Optional[float] = None, + end_time: Optional[float] = None, + ) -> np.ndarray: + """Get the spike train times for a unit in seconds. + + This method returns spike times directly in seconds without conversion + to frames, avoiding double conversion for NWB files that already store + spike times as timestamps. 
+ + Parameters + ---------- + unit_id + The unit id to retrieve spike train for + start_time : float, default: None + The start time in seconds for spike train extraction + end_time : float, default: None + The end time in seconds for spike train extraction + + Returns + ------- + spike_times : np.ndarray + Spike times in seconds + """ # Extract the spike times for the unit unit_index = self.parent_extractor.id_to_index(unit_id) if unit_index == 0: @@ -1369,19 +1373,17 @@ def get_unit_spike_train( end_index = self.spike_times_index_data[unit_index] spike_times = self.spike_times_data[start_index:end_index] - # Transform spike times to frames and subset - frames = np.round((spike_times - self._t_start) * self._sampling_frequency) - + # Filter by time range if specified start_index = 0 - if start_frame is not None: - start_index = np.searchsorted(frames, start_frame, side="left") + if start_time is not None: + start_index = np.searchsorted(spike_times, start_time, side="left") - if end_frame is not None: - end_index = np.searchsorted(frames, end_frame, side="left") + if end_time is not None: + end_index = np.searchsorted(spike_times, end_time, side="left") else: - end_index = frames.size + end_index = spike_times.size - return frames[start_index:end_index].astype("int64", copy=False) + return spike_times[start_index:end_index].astype("float64", copy=False) def _find_timeseries_from_backend(group, path="", result=None, backend="hdf5"): diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index 0eb43535aa..ecab70a5b3 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -40,6 +40,7 @@ Spike2RecordingExtractor, EDFRecordingExtractor, Plexon2RecordingExtractor, + AxonRecordingExtractor, ) from spikeinterface.extractors.extractor_classes import KiloSortSortingExtractor @@ -104,6 +105,7 @@ class SpikeGLXRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ("spikeglx/Noise4Sam_g0", {"stream_id": "imec0.lf"}), ("spikeglx/Noise4Sam_g0", {"stream_id": "nidq"}), ("spikeglx/Noise4Sam_g0", {"stream_id": "imec0.ap-SYNC"}), + ("spikeglx/onebox/run_with_only_adc/myRun_g0", {"stream_id": "obx0"}), ] @@ -218,7 +220,7 @@ class NeuroNexusRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = NeuroNexusRecordingExtractor downloads = ["neuronexus"] entities = [ - ("neuronexus/allego_1/allego_2__uid0701-13-04-49.xdat.json", {"stream_id": "0"}), + ("neuronexus/allego_1/allego_2__uid0701-13-04-49.xdat.json", {"stream_id": "ai-pri"}), ] @@ -294,6 +296,12 @@ class TdTRecordingTest(RecordingCommonTestSuite, unittest.TestCase): entities = [("tdt/aep_05", {"stream_id": "1"})] +class AxonRecordingTest(RecordingCommonTestSuite, unittest.TestCase): + ExtractorClass = AxonRecordingExtractor + downloads = ["axon"] + entities = ["axon/extracellular_data/four_electrodes/24606005_SampleData.abf"] + + class AxonaRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ExtractorClass = AxonaRecordingExtractor downloads = ["axona"] diff --git a/src/spikeinterface/extractors/tests/test_nwbextractors.py b/src/spikeinterface/extractors/tests/test_nwbextractors.py index 15d3e8fee9..c2422600e4 100644 --- a/src/spikeinterface/extractors/tests/test_nwbextractors.py +++ b/src/spikeinterface/extractors/tests/test_nwbextractors.py @@ -207,24 +207,6 @@ def test_nwb_extractor_channel_ids_retrieval(generate_nwbfile, use_pynwb): assert 
np.array_equal(extracted_channel_ids, expected_channel_ids) -@pytest.mark.parametrize("use_pynwb", [True, False]) -def test_electrical_series_name_backcompatibility(generate_nwbfile, use_pynwb): - """ - Test that the channel_ids are retrieved from the electrodes table ONLY from the corresponding - region of the electrical series - """ - path_to_nwbfile, nwbfile_with_ecephys_content = generate_nwbfile - electrical_series_name_list = ["ElectricalSeries1", "ElectricalSeries2"] - for electrical_series_name in electrical_series_name_list: - with pytest.deprecated_call(): - recording_extractor = NwbRecordingExtractor( - path_to_nwbfile, - electrical_series_name=electrical_series_name, - use_pynwb=use_pynwb, - ) - assert recording_extractor.electrical_series_path == f"acquisition/{electrical_series_name}" - - @pytest.mark.parametrize("use_pynwb", [True, False]) def test_nwb_extractor_property_retrieval(generate_nwbfile, use_pynwb): """ @@ -542,6 +524,107 @@ def test_sorting_extraction_start_time_from_series(tmp_path, use_pynwb): np.testing.assert_allclose(extracted_spike_times1, expected_spike_times1) +@pytest.mark.parametrize("use_pynwb", [True, False]) +def test_get_unit_spike_train_in_seconds(tmp_path, use_pynwb): + """Test that get_unit_spike_train_in_seconds returns accurate timestamps without double conversion.""" + from pynwb import NWBHDF5IO + from pynwb.testing.mock.file import mock_NWBFile + + nwbfile = mock_NWBFile() + + # Add units with known spike times + t_start = 5.0 + sampling_frequency = 1000.0 + spike_times_unit_a = np.array([5.1, 5.2, 5.3, 6.0, 6.5]) # Absolute times + spike_times_unit_b = np.array([5.05, 5.15, 5.25, 5.35, 6.1]) # Absolute times + + nwbfile.add_unit(spike_times=spike_times_unit_a) + nwbfile.add_unit(spike_times=spike_times_unit_b) + + file_path = tmp_path / "test.nwb" + with NWBHDF5IO(path=file_path, mode="w") as io: + io.write(nwbfile) + + sorting_extractor = NwbSortingExtractor( + file_path=file_path, + sampling_frequency=sampling_frequency, + t_start=t_start, + use_pynwb=use_pynwb, + ) + + # Test full spike trains + spike_times_a_direct = sorting_extractor.get_unit_spike_train_in_seconds(unit_id=0) + spike_times_a_legacy = sorting_extractor.get_unit_spike_train(unit_id=0, return_times=True) + + spike_times_b_direct = sorting_extractor.get_unit_spike_train_in_seconds(unit_id=1) + spike_times_b_legacy = sorting_extractor.get_unit_spike_train(unit_id=1, return_times=True) + + # Both methods should return exact timestamps since return_times now uses get_unit_spike_train_in_seconds + np.testing.assert_array_equal(spike_times_a_direct, spike_times_unit_a) + np.testing.assert_array_equal(spike_times_b_direct, spike_times_unit_b) + np.testing.assert_array_equal(spike_times_a_legacy, spike_times_unit_a) + np.testing.assert_array_equal(spike_times_b_legacy, spike_times_unit_b) + + # Test time filtering + start_time = 5.2 + end_time = 6.1 + + # Direct method with time filtering + spike_times_a_filtered = sorting_extractor.get_unit_spike_train_in_seconds( + unit_id=0, start_time=start_time, end_time=end_time + ) + spike_times_b_filtered = sorting_extractor.get_unit_spike_train_in_seconds( + unit_id=1, start_time=start_time, end_time=end_time + ) + + # Expected filtered results + expected_a = spike_times_unit_a[(spike_times_unit_a >= start_time) & (spike_times_unit_a < end_time)] + expected_b = spike_times_unit_b[(spike_times_unit_b >= start_time) & (spike_times_unit_b < end_time)] + + np.testing.assert_array_equal(spike_times_a_filtered, expected_a) + 
np.testing.assert_array_equal(spike_times_b_filtered, expected_b) + + # Test edge cases + # Start time filtering only + spike_times_from_start = sorting_extractor.get_unit_spike_train_in_seconds(unit_id=0, start_time=5.25) + expected_from_start = spike_times_unit_a[spike_times_unit_a >= 5.25] + np.testing.assert_array_equal(spike_times_from_start, expected_from_start) + + # End time filtering only + spike_times_to_end = sorting_extractor.get_unit_spike_train_in_seconds(unit_id=0, end_time=6.0) + expected_to_end = spike_times_unit_a[spike_times_unit_a < 6.0] + np.testing.assert_array_equal(spike_times_to_end, expected_to_end) + + # Test that direct method avoids frame conversion rounding errors + # by comparing exact values that would be lost in frame conversion + precise_times = np.array([5.1001, 5.1002, 5.1003]) + nwbfile_precise = mock_NWBFile() + nwbfile_precise.add_unit(spike_times=precise_times) + + file_path_precise = tmp_path / "test_precise.nwb" + with NWBHDF5IO(path=file_path_precise, mode="w") as io: + io.write(nwbfile_precise) + + sorting_precise = NwbSortingExtractor( + file_path=file_path_precise, + sampling_frequency=sampling_frequency, + t_start=t_start, + use_pynwb=use_pynwb, + ) + + # Direct method should preserve exact precision + direct_precise = sorting_precise.get_unit_spike_train_in_seconds(unit_id=0) + np.testing.assert_array_equal(direct_precise, precise_times) + + # Both methods should now preserve exact precision since return_times uses get_unit_spike_train_in_seconds + legacy_precise = sorting_precise.get_unit_spike_train(unit_id=0, return_times=True) + # Both methods should be exactly equal since return_times now avoids double conversion + np.testing.assert_array_equal(direct_precise, precise_times) + np.testing.assert_array_equal(legacy_precise, precise_times) + # Verify both methods return identical results + np.testing.assert_array_equal(direct_precise, legacy_precise) + + @pytest.mark.parametrize("use_pynwb", [True, False]) def test_multiple_unit_tables(tmp_path, use_pynwb): from pynwb.misc import Units diff --git a/src/spikeinterface/extractors/tests/test_nwbextractors_streaming.py b/src/spikeinterface/extractors/tests/test_nwbextractors_streaming.py index 84ae3c03bf..404a598713 100644 --- a/src/spikeinterface/extractors/tests/test_nwbextractors_streaming.py +++ b/src/spikeinterface/extractors/tests/test_nwbextractors_streaming.py @@ -73,7 +73,7 @@ def test_recording_s3_nwb_remfile(): assert full_traces.shape == (num_frames, num_chans) assert full_traces.dtype == dtype - if rec.has_scaled(): + if rec.has_scaleable_traces(): trace_scaled = rec.get_traces(segment_index=segment_index, return_scaled=True, end_frame=2) assert trace_scaled.dtype == "float32" @@ -103,7 +103,7 @@ def test_recording_s3_nwb_remfile_file_like(tmp_path): assert full_traces.shape == (num_frames, num_chans) assert full_traces.dtype == dtype - if rec.has_scaled(): + if rec.has_scaleable_traces(): trace_scaled = rec.get_traces(segment_index=segment_index, return_scaled=True, end_frame=2) assert trace_scaled.dtype == "float32" diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 278151a930..ff926a998d 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -128,6 +128,9 @@ def _merge_extension_data( return new_data + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + return 
self.data.copy() + def _get_pipeline_nodes(self): recording = self.sorting_analyzer.recording diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index dd05bc1bd3..fee5cb4c6f 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -166,9 +166,6 @@ def _merge_extension_data( if unit_involved_in_merge is False: old_to_new_unit_index_map[old_unit_index] = new_sorting_analyzer.sorting.id_to_index(old_unit) - need_to_append = False - delete_from = 1 - correlograms, new_bins = deepcopy(self.get_data()) for new_unit_id, merge_unit_group in zip(new_unit_ids, merge_unit_groups): @@ -199,6 +196,12 @@ def _merge_extension_data( new_data = dict(ccgs=new_correlograms, bins=new_bins) return new_data + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + # TODO: for now we just copy + new_ccgs, new_bins = _compute_correlograms_on_sorting(new_sorting_analyzer.sorting, **self.params) + new_data = dict(ccgs=new_ccgs, bins=new_bins) + return new_data + def _run(self, verbose=False): ccgs, bins = _compute_correlograms_on_sorting(self.sorting_analyzer.sorting, **self.params) self.data["ccgs"] = ccgs diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index 6089d8eba7..529dbf7182 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -3,6 +3,7 @@ import importlib.util import numpy as np +from itertools import chain from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension @@ -81,6 +82,29 @@ def _merge_extension_data( new_extension_data = dict(isi_histograms=new_isi_hists, bins=new_bins) return new_extension_data + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + new_bins = self.data["bins"] + arr = self.data["isi_histograms"] + num_dims = arr.shape[1] + all_new_units = new_sorting_analyzer.unit_ids + new_isi_hists = np.zeros((len(all_new_units), num_dims), dtype=arr.dtype) + + # compute all new isi at once + new_unit_ids_f = list(chain(*new_unit_ids)) + new_sorting = new_sorting_analyzer.sorting.select_units(new_unit_ids_f) + only_new_hist, _ = _compute_isi_histograms(new_sorting, **self.params) + + for unit_ind, unit_id in enumerate(all_new_units): + if unit_id not in new_unit_ids_f: + keep_unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id) + new_isi_hists[unit_ind, :] = arr[keep_unit_index, :] + else: + new_unit_index = new_sorting.id_to_index(unit_id) + new_isi_hists[unit_ind, :] = only_new_hist[new_unit_index, :] + + new_extension_data = dict(isi_histograms=new_isi_hists, bins=new_bins) + return new_extension_data + def _run(self, verbose=False): isi_histograms, bins = _compute_isi_histograms(self.sorting_analyzer.sorting, **self.params) self.data["isi_histograms"] = isi_histograms diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 8eb343c209..f5c1a74848 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -150,6 +150,10 @@ def _merge_extension_data( new_data[k] = v return new_data + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + # splitting only changes random spikes assignments + return 
self.data.copy() + def get_pca_model(self): """ Returns the scikit-learn PCA model objects. diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index cd20f48ffc..8e95a28dd7 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -72,6 +72,10 @@ def _merge_extension_data( return new_data + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + # splitting only changes random spikes assignments + return self.data.copy() + def _get_pipeline_nodes(self): recording = self.sorting_analyzer.recording @@ -111,7 +115,7 @@ def _run(self, verbose=False, **job_kwargs): ) self.data["amplitudes"] = amps - def _get_data(self, outputs="numpy"): + def _get_data(self, outputs="numpy", concatenated=False): all_amplitudes = self.data["amplitudes"] if outputs == "numpy": return all_amplitudes @@ -125,6 +129,16 @@ def _get_data(self, outputs="numpy"): for unit_id in unit_ids: inds = spike_indices[segment_index][unit_id] amplitudes_by_units[segment_index][unit_id] = all_amplitudes[inds] + + if concatenated: + amplitudes_by_units_concatenated = { + unit_id: np.concatenate( + [amps_in_segment[unit_id] for amps_in_segment in amplitudes_by_units.values()] + ) + for unit_id in unit_ids + } + return amplitudes_by_units_concatenated + return amplitudes_by_units else: raise ValueError(f"Wrong .get_data(outputs={outputs}); possibilities are `numpy` or `by_unit`") diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 54c4a2c164..700abe2321 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -103,6 +103,10 @@ def _merge_extension_data( ### in a merged could be different. 
Should be discussed return dict(spike_locations=new_spike_locations) + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + # splitting only changes random spikes assignments + return self.data.copy() + def _get_pipeline_nodes(self): from spikeinterface.sortingcomponents.peak_localization import get_localization_pipeline_nodes @@ -142,7 +146,7 @@ def _run(self, verbose=False, **job_kwargs): ) self.data["spike_locations"] = spike_locations - def _get_data(self, outputs="numpy"): + def _get_data(self, outputs="numpy", concatenated=False): all_spike_locations = self.data["spike_locations"] if outputs == "numpy": return all_spike_locations @@ -156,6 +160,16 @@ def _get_data(self, outputs="numpy"): for unit_id in unit_ids: inds = spike_indices[segment_index][unit_id] spike_locations_by_units[segment_index][unit_id] = all_spike_locations[inds] + + if concatenated: + locations_by_units_concatenated = { + unit_id: np.concatenate( + [locs_in_segment[unit_id] for locs_in_segment in spike_locations_by_units.values()] + ) + for unit_id in unit_ids + } + return locations_by_units_concatenated + return spike_locations_by_units else: raise ValueError(f"Wrong .get_data(outputs={outputs})") diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 66e315d121..7de9db9649 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -8,6 +8,7 @@ import numpy as np import warnings +from itertools import chain from copy import deepcopy from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension @@ -193,6 +194,26 @@ def _merge_extension_data( new_data = dict(metrics=metrics) return new_data + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + import pandas as pd + + metric_names = self.params["metric_names"] + old_metrics = self.data["metrics"] + + all_unit_ids = new_sorting_analyzer.unit_ids + new_unit_ids_f = list(chain(*new_unit_ids)) + not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids_f)] + + metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns) + + metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :] + metrics.loc[new_unit_ids_f, :] = self._compute_metrics( + new_sorting_analyzer, new_unit_ids_f, verbose, metric_names, **job_kwargs + ) + + new_data = dict(metrics=metrics) + return new_data + def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metric_names=None, **job_kwargs): """ Compute template metrics. 
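The `_split_extension_data` hooks added above are driven by the user-facing split workflow exercised in the tests later in this diff. A minimal sketch of that workflow, assuming a `sorting_analyzer` whose extensions are already computed (the half-and-half split indices are illustrative only):

    import numpy as np

    # pick one unit and split its spike train into two halves (indices are illustrative)
    num_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit()
    unit_id = sorting_analyzer.unit_ids[0]
    half = num_spikes[unit_id] // 2
    splits = {unit_id: [np.arange(half), np.arange(half, num_spikes[unit_id])]}

    # split_units propagates every computed extension through its _split_extension_data hook
    analyzer_split, new_unit_ids = sorting_analyzer.split_units(
        split_units=splits, return_new_unit_ids=True
    )
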
diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 4f032dc784..2771d382b0 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -2,6 +2,7 @@ import numpy as np import warnings +from itertools import chain import importlib.util from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension @@ -129,6 +130,58 @@ def _merge_extension_data( return dict(similarity=similarity) + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + num_shifts = int(self.params["max_lag_ms"] * self.sorting_analyzer.sampling_frequency / 1000) + all_templates_array = get_dense_templates_array( + new_sorting_analyzer, return_scaled=self.sorting_analyzer.return_scaled + ) + + new_unit_ids_f = list(chain(*new_unit_ids)) + keep = np.isin(new_sorting_analyzer.unit_ids, new_unit_ids_f) + new_templates_array = all_templates_array[keep, :, :] + if new_sorting_analyzer.sparsity is None: + new_sparsity = None + else: + new_sparsity = ChannelSparsity( + new_sorting_analyzer.sparsity.mask[keep, :], new_unit_ids_f, new_sorting_analyzer.channel_ids + ) + + new_similarity = compute_similarity_with_templates_array( + new_templates_array, + all_templates_array, + method=self.params["method"], + num_shifts=num_shifts, + support=self.params["support"], + sparsity=new_sparsity, + other_sparsity=new_sorting_analyzer.sparsity, + ) + + old_similarity = self.data["similarity"] + + all_new_unit_ids = new_sorting_analyzer.unit_ids + n = all_new_unit_ids.size + similarity = np.zeros((n, n), dtype=old_similarity.dtype) + + # copy old similarity + for unit_ind1, unit_id1 in enumerate(all_new_unit_ids): + if unit_id1 not in new_unit_ids_f: + old_ind1 = self.sorting_analyzer.sorting.id_to_index(unit_id1) + for unit_ind2, unit_id2 in enumerate(all_new_unit_ids): + if unit_id2 not in new_unit_ids_f: + old_ind2 = self.sorting_analyzer.sorting.id_to_index(unit_id2) + s = self.data["similarity"][old_ind1, old_ind2] + similarity[unit_ind1, unit_ind2] = s + similarity[unit_ind2, unit_ind1] = s + + # insert new similarity both way + for unit_ind, unit_id in enumerate(all_new_unit_ids): + if unit_id in new_unit_ids_f: + new_index = list(new_unit_ids_f).index(unit_id) + similarity[unit_ind, :] = new_similarity[new_index, :] + similarity[:, unit_ind] = new_similarity[new_index, :] + + return dict(similarity=similarity) + def _run(self, verbose=False): num_shifts = int(self.params["max_lag_ms"] * self.sorting_analyzer.sampling_frequency / 1000) templates_array = get_dense_templates_array( diff --git a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py index be0070d94a..87ff3cdeb7 100644 --- a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py +++ b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py @@ -11,8 +11,59 @@ ) from spikeinterface.core.generate import inject_some_split_units +# even if this is in postprocessing, we make an extension for quality metrics +extension_dict = { + "noise_levels": dict(), + "random_spikes": dict(), + "waveforms": dict(), + "templates": dict(), + "principal_components": dict(), + "spike_amplitudes": dict(), + "template_similarity": dict(), + "correlograms": dict(), + "isi_histograms": dict(), + "amplitude_scalings": dict(handle_collisions=False), # otherwise hard mode 
could fail due to dropped spikes + "spike_locations": dict(method="center_of_mass"), # trick to avoid UserWarning + "unit_locations": dict(), + "template_metrics": dict(), + "quality_metrics": dict(metric_names=["firing_rate", "isi_violation", "snr"]), +} +extension_data_type = { + "noise_levels": None, + "templates": "unit", + "isi_histograms": "unit", + "unit_locations": "unit", + "spike_amplitudes": "spike", + "amplitude_scalings": "spike", + "spike_locations": "spike", + "quality_metrics": "pandas", + "template_metrics": "pandas", + "correlograms": "matrix", + "template_similarity": "matrix", + "principal_components": "random", + "waveforms": "random", + "random_spikes": "random_spikes", +} +data_with_miltiple_returns = ["isi_histograms", "correlograms"] +# due to incremental PCA, hard computation could result in different results for PCA +# the model is differents always +random_computation = ["principal_components"] +# for some extensions (templates, amplitude_scalings), since the templates slightly change for merges/splits +# we allow a relative tolerance +# (amplitud_scalings are the moste sensitive!) +extensions_with_rel_tolerance_merge = { + "amplitude_scalings": 1e-1, + "templates": 1e-3, + "template_similarity": 1e-3, + "unit_locations": 1e-3, + "template_metrics": 1e-3, + "quality_metrics": 1e-3, +} +extensions_with_rel_tolerance_splits = {"amplitude_scalings": 1e-1} -def get_dataset(): + +def get_dataset_to_merge(): + # generate a dataset with some split units to minimize merge errors recording, sorting = generate_ground_truth_recording( durations=[30.0], sampling_frequency=16000.0, @@ -20,6 +71,7 @@ def get_dataset(): num_units=10, generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), + generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=2.0, maximum_z=15.0, minimum_distance=20), seed=2205, ) @@ -36,70 +88,65 @@ def get_dataset(): sort_by_amp = np.argsort(list(get_template_extremum_amplitude(analyzer_raw).values()))[::-1] split_ids = sorting.unit_ids[sort_by_amp][:3] - sorting_with_splits, other_ids = inject_some_split_units( + sorting_with_splits, split_unit_ids = inject_some_split_units( sorting, num_split=3, split_ids=split_ids, output_ids=True, seed=0 ) - return recording, sorting_with_splits, other_ids + return recording, sorting_with_splits, split_unit_ids + + +def get_dataset_to_split(): + # generate a dataset and return large unit to split to minimize split errors + recording, sorting = generate_ground_truth_recording( + durations=[30.0], + sampling_frequency=16000.0, + num_channels=10, + num_units=10, + generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), + noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), + seed=2205, + ) + + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + + # since templates are going to be averaged and this might be a problem for amplitude scaling + # we select the 3 units with the largest templates to split + analyzer_raw = create_sorting_analyzer(sorting, recording, format="memory", sparse=False) + analyzer_raw.compute(["random_spikes", "templates"]) + # select 3 largest templates to split + sort_by_amp = 
np.argsort(list(get_template_extremum_amplitude(analyzer_raw).values()))[::-1] + large_units = sorting.unit_ids[sort_by_amp][:2] + + return recording, sorting, large_units @pytest.fixture(scope="module") -def dataset(): - return get_dataset() +def dataset_to_merge(): + return get_dataset_to_merge() + + +@pytest.fixture(scope="module") +def dataset_to_split(): + return get_dataset_to_split() @pytest.mark.parametrize("sparse", [False, True]) -def test_SortingAnalyzer_merge_all_extensions(dataset, sparse): +def test_SortingAnalyzer_merge_all_extensions(dataset_to_merge, sparse): set_global_job_kwargs(n_jobs=1) - recording, sorting, other_ids = dataset + recording, sorting, other_ids = dataset_to_merge sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=sparse) + extension_dict_merge = extension_dict.copy() # we apply the merges according to the artificial splits merges = [list(v) for v in other_ids.values()] split_unit_ids = np.ravel(merges) unmerged_unit_ids = sorting_analyzer.unit_ids[~np.isin(sorting_analyzer.unit_ids, split_unit_ids)] - # even if this is in postprocessing, we make an extension for quality metrics - extension_dict = { - "noise_levels": dict(), - "random_spikes": dict(), - "waveforms": dict(), - "templates": dict(), - "principal_components": dict(), - "spike_amplitudes": dict(), - "template_similarity": dict(), - "correlograms": dict(), - "isi_histograms": dict(), - "amplitude_scalings": dict(handle_collisions=False), # otherwise hard mode could fail due to dropped spikes - "spike_locations": dict(method="center_of_mass"), # trick to avoid UserWarning - "unit_locations": dict(), - "template_metrics": dict(), - "quality_metrics": dict(metric_names=["firing_rate", "isi_violation", "snr"]), - } - extension_data_type = { - "noise_levels": None, - "templates": "unit", - "isi_histograms": "unit", - "unit_locations": "unit", - "spike_amplitudes": "spike", - "amplitude_scalings": "spike", - "spike_locations": "spike", - "quality_metrics": "pandas", - "template_metrics": "pandas", - "correlograms": "matrix", - "template_similarity": "matrix", - "principal_components": "random", - "waveforms": "random", - "random_spikes": "random_spikes", - } - data_with_miltiple_returns = ["isi_histograms", "correlograms"] - - # due to incremental PCA, hard computation could result in different results for PCA - # the model is differents always - random_computation = ["principal_components"] - - sorting_analyzer.compute(extension_dict, n_jobs=1) + sorting_analyzer.compute(extension_dict_merge, n_jobs=1) # TODO: still some UserWarnings for n_jobs, where from? 
t0 = time.perf_counter() @@ -155,14 +202,95 @@ def test_SortingAnalyzer_merge_all_extensions(dataset, sparse): ) if ext not in random_computation: + if ext in extensions_with_rel_tolerance_merge: + rtol = extensions_with_rel_tolerance_merge[ext] + else: + rtol = 0 if extension_data_type[ext] == "pandas": data_hard_merged = data_hard_merged.dropna().to_numpy().astype("float") data_soft_merged = data_soft_merged.dropna().to_numpy().astype("float") if data_hard_merged.dtype.fields is None: - assert np.allclose(data_hard_merged, data_soft_merged, rtol=0.1) + if not np.allclose(data_hard_merged, data_soft_merged, rtol=rtol): + max_error = np.max(np.abs(data_hard_merged - data_soft_merged)) + raise Exception(f"Failed for {ext} - max error {max_error}") else: for f in data_hard_merged.dtype.fields: - assert np.allclose(data_hard_merged[f], data_soft_merged[f], rtol=0.1) + if not np.allclose(data_hard_merged[f], data_soft_merged[f], rtol=rtol): + max_error = np.max(np.abs(data_hard_merged[f] - data_soft_merged[f])) + raise Exception(f"Failed for {ext} - field {f} - max error {max_error}") + + +@pytest.mark.parametrize("sparse", [False, True]) +def test_SortingAnalyzer_split_all_extensions(dataset_to_split, sparse): + set_global_job_kwargs(n_jobs=1) + + recording, sorting, units_to_split = dataset_to_split + + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=sparse) + extension_dict_split = extension_dict.copy() + sorting_analyzer.compute(extension_dict, n_jobs=1) + + # we randomly apply splits (at half of spiketrain) + num_spikes = sorting.count_num_spikes_per_unit() + + unsplit_unit_ids = sorting_analyzer.unit_ids[~np.isin(sorting_analyzer.unit_ids, units_to_split)] + splits = {} + for unit in units_to_split: + splits[unit] = [np.arange(num_spikes[unit] // 2), np.arange(num_spikes[unit] // 2, num_spikes[unit])] + + analyzer_split, split_unit_ids = sorting_analyzer.split_units(split_units=splits, return_new_unit_ids=True) + split_unit_ids = list(np.concatenate(split_unit_ids)) + + # also do a full recopute + analyzer_hard = create_sorting_analyzer(analyzer_split.sorting, recording, format="memory", sparse=sparse) + # we propagate random spikes to avoid random spikes to be recomputed + extension_dict_ = extension_dict_split.copy() + extension_dict_.pop("random_spikes") + analyzer_hard.extensions["random_spikes"] = analyzer_split.extensions["random_spikes"] + analyzer_hard.compute(extension_dict_, n_jobs=1) + + for ext in extension_dict: + # 1. check that data are exactly the same for unchanged units between original/split + data_original = sorting_analyzer.get_extension(ext).get_data() + data_split = analyzer_split.get_extension(ext).get_data() + data_recompute = analyzer_hard.get_extension(ext).get_data() + if ext in data_with_miltiple_returns: + data_original = data_original[0] + data_split = data_split[0] + data_recompute = data_recompute[0] + data_original_unsplit = get_extension_data_for_units( + sorting_analyzer, data_original, unsplit_unit_ids, extension_data_type[ext] + ) + data_split_unsplit = get_extension_data_for_units( + analyzer_split, data_split, unsplit_unit_ids, extension_data_type[ext] + ) + + np.testing.assert_array_equal(data_original_unsplit, data_split_unsplit) + + # 2. 
check that split data are the same for extension split and recompute + data_split_soft = get_extension_data_for_units( + analyzer_split, data_split, split_unit_ids, extension_data_type[ext] + ) + data_split_hard = get_extension_data_for_units( + analyzer_hard, data_recompute, split_unit_ids, extension_data_type[ext] + ) + if ext not in random_computation: + if ext in extensions_with_rel_tolerance_splits: + rtol = extensions_with_rel_tolerance_splits[ext] + else: + rtol = 0 + if extension_data_type[ext] == "pandas": + data_split_soft = data_split_soft.dropna().to_numpy().astype("float") + data_split_hard = data_split_hard.dropna().to_numpy().astype("float") + if data_split_hard.dtype.fields is None: + if not np.allclose(data_split_hard, data_split_soft, rtol=rtol): + max_error = np.max(np.abs(data_split_hard - data_split_soft)) + raise Exception(f"Failed for {ext} - max error {max_error}") + else: + for f in data_split_hard.dtype.fields: + if not np.allclose(data_split_hard[f], data_split_soft[f], rtol=rtol): + max_error = np.max(np.abs(data_split_hard[f] - data_split_soft[f])) + raise Exception(f"Failed for {ext} - field {f} - max error {max_error}") def get_extension_data_for_units(sorting_analyzer, data, unit_ids, ext_data_type): @@ -191,5 +319,5 @@ def get_extension_data_for_units(sorting_analyzer, data, unit_ids, ext_data_type if __name__ == "__main__": - dataset = get_dataset() + dataset = get_dataset_to_merge() test_SortingAnalyzer_merge_all_extensions(dataset, False) diff --git a/src/spikeinterface/postprocessing/unit_locations.py b/src/spikeinterface/postprocessing/unit_locations.py index 53bebd52c5..ea297f7b6c 100644 --- a/src/spikeinterface/postprocessing/unit_locations.py +++ b/src/spikeinterface/postprocessing/unit_locations.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np +from itertools import chain from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension from .localization_tools import _unit_location_methods @@ -87,6 +88,30 @@ def _merge_extension_data( return dict(unit_locations=unit_location) + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + old_unit_locations = self.data["unit_locations"] + num_dims = old_unit_locations.shape[1] + + method = self.params.get("method") + method_kwargs = self.params.copy() + method_kwargs.pop("method") + func = _unit_location_methods[method] + new_unit_ids_f = list(chain(*new_unit_ids)) + new_unit_locations = func(new_sorting_analyzer, unit_ids=new_unit_ids_f, **method_kwargs) + assert new_unit_locations.shape[0] == len(new_unit_ids_f) + + all_new_unit_ids = new_sorting_analyzer.unit_ids + unit_location = np.zeros((len(all_new_unit_ids), num_dims), dtype=old_unit_locations.dtype) + for unit_index, unit_id in enumerate(all_new_unit_ids): + if unit_id not in new_unit_ids_f: + old_index = self.sorting_analyzer.sorting.id_to_index(unit_id) + unit_location[unit_index] = old_unit_locations[old_index] + else: + new_index = list(new_unit_ids_f).index(unit_id) + unit_location[unit_index] = new_unit_locations[new_index] + + return dict(unit_locations=unit_location) + def _run(self, verbose=False): method = self.params.get("method") method_kwargs = self.params.copy() diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index 9fcb8fdbf3..2243412663 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -568,6 +568,19 @@ def 
correct_motion( def save_motion_info(motion_info, folder, overwrite=False): + """ + Saves motion info + + Parameters + ---------- + motion_info : dict + The returned motion_info from running `compute_motion` + folder : str | Path + The path for saving the `motion_info` + overwrite : bool, default: False + Whether to overwrite the folder location when saving motion info + + """ folder = Path(folder) if folder.is_dir(): if not overwrite: @@ -590,6 +603,20 @@ def save_motion_info(motion_info, folder, overwrite=False): def load_motion_info(folder): + """ + Loads a motion info dict from folder + + Parameters + ---------- + folder : str | Path + The folder containing the motion info to load + + Notes + ----- + Loads both current Motion implementation as well as the + legacy Motion format + + """ from spikeinterface.core.motion import Motion folder = Path(folder) diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 3a3aab0bf4..a2a04a656a 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -805,17 +805,12 @@ def compute_amplitude_cv_metrics( def _get_amplitudes_by_units(sorting_analyzer, unit_ids, peak_sign): # used by compute_amplitude_cutoffs and compute_amplitude_medians - amplitudes_by_units = {} - if sorting_analyzer.has_extension("spike_amplitudes"): - spikes = sorting_analyzer.sorting.to_spike_vector() - ext = sorting_analyzer.get_extension("spike_amplitudes") - all_amplitudes = ext.get_data() - for unit_id in unit_ids: - unit_index = sorting_analyzer.sorting.id_to_index(unit_id) - spike_mask = spikes["unit_index"] == unit_index - amplitudes_by_units[unit_id] = all_amplitudes[spike_mask] + + if (spike_amplitudes_extension := sorting_analyzer.get_extension("spike_amplitudes")) is not None: + return spike_amplitudes_extension.get_data(outputs="by_unit", concatenated=True) elif sorting_analyzer.has_extension("waveforms"): + amplitudes_by_units = {} waveforms_ext = sorting_analyzer.get_extension("waveforms") before = waveforms_ext.nbefore extremum_channels_ids = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign) diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index f4e36b24c0..9b8618f990 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -229,28 +229,6 @@ def compute_pc_metrics( return pc_metrics -def calculate_pc_metrics( - sorting_analyzer, metric_names=None, metric_params=None, unit_ids=None, seed=None, n_jobs=1, progress_bar=False -): - warnings.warn( - "The `calculate_pc_metrics` function is deprecated and will be removed in 0.103.0. 
Please use compute_pc_metrics instead", - category=DeprecationWarning, - stacklevel=2, - ) - - pc_metrics = compute_pc_metrics( - sorting_analyzer, - metric_names=metric_names, - metric_params=metric_params, - unit_ids=unit_ids, - seed=seed, - n_jobs=n_jobs, - progress_bar=progress_bar, - ) - - return pc_metrics - - ################################################################# # Code from spikemetrics diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 134849e70f..055fefc78c 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -3,6 +3,7 @@ from __future__ import annotations import warnings +from itertools import chain from copy import deepcopy import numpy as np @@ -158,6 +159,33 @@ def _merge_extension_data( new_data = dict(metrics=metrics) return new_data + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + import pandas as pd + + metric_names = self.params["metric_names"] + old_metrics = self.data["metrics"] + + all_unit_ids = new_sorting_analyzer.unit_ids + new_unit_ids_f = list(chain(*new_unit_ids)) + not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids_f)] + + # this creates a new metrics dictionary, but the dtype for everything will be + # object. So we will need to fix this later after computing metrics + metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns) + metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :] + metrics.loc[new_unit_ids_f, :] = self._compute_metrics( + new_sorting_analyzer, new_unit_ids_f, verbose, metric_names, **job_kwargs + ) + + # we need to fix the dtypes after we compute everything because we have nans + # we can iterate through the columns and convert them back to the dtype + # of the original quality dataframe. + for column in old_metrics.columns: + metrics[column] = metrics[column].astype(old_metrics[column].dtype) + + new_data = dict(metrics=metrics) + return new_data + def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metric_names=None, **job_kwargs): """ Compute quality metrics. 
diff --git a/src/spikeinterface/qualitymetrics/quality_metric_list.py b/src/spikeinterface/qualitymetrics/quality_metric_list.py index 23b781eb9d..f7411f6376 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_list.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_list.py @@ -22,7 +22,6 @@ from .pca_metrics import ( compute_pc_metrics, - calculate_pc_metrics, # remove after 0.103.0 mahalanobis_metrics, lda_metrics, nearest_neighbors_metrics, diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index 287439a4f7..1491b9eac1 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -4,7 +4,7 @@ from spikeinterface.qualitymetrics import compute_pc_metrics, get_quality_pca_metric_list -def test_calculate_pc_metrics(small_sorting_analyzer): +def test_compute_pc_metrics(small_sorting_analyzer): import pandas as pd sorting_analyzer = small_sorting_analyzer diff --git a/src/spikeinterface/sorters/external/rt_sort.py b/src/spikeinterface/sorters/external/rt_sort.py new file mode 100644 index 0000000000..23e1ca5ffb --- /dev/null +++ b/src/spikeinterface/sorters/external/rt_sort.py @@ -0,0 +1,192 @@ +import importlib.util +import os +import numpy as np + +from ..basesorter import BaseSorter + +from spikeinterface.extractors import NumpySorting # TODO: Create separate sorting extractor for RT-Sort + + +class RTSortSorter(BaseSorter): + """RTSort sorter object""" + + sorter_name = "rt-sort" + + _default_params = { + "detection_model": "neuropixels", + "recording_window_ms": None, + "stringent_thresh": 0.175, + "loose_thresh": 0.075, + "inference_scaling_numerator": 15.4, + "ms_before": 0.5, + "ms_after": 0.5, + "pre_median_ms": 50, + "inner_radius": 50, + "outer_radius": 100, + "min_elecs_for_array_noise_n": 100, + "min_elecs_for_array_noise_f": 0.1, + "min_elecs_for_seq_noise_n": 50, + "min_elecs_for_seq_noise_f": 0.05, + "min_activity_root_cocs": 2, + "min_activity_hz": 0.05, + "max_n_components_latency": 4, + "min_coc_n": 10, + "min_coc_p": 10, + "min_extend_comp_p": 50, + "elec_patience": 6, + "split_coc_clusters_amps": True, + "min_amp_dist_p": 0.1, + "max_n_components_amp": 4, + "min_loose_elec_prob": 0.03, + "min_inner_loose_detections": 3, + "min_loose_detections_n": 4, + "min_loose_detections_r_spikes": 1 / 3, + "min_loose_detections_r_sequences": 1 / 3, + "max_latency_diff_spikes": 2.5, + "max_latency_diff_sequences": 2.5, + "clip_latency_diff_factor": 2, + "max_amp_median_diff_spikes": 0.45, + "max_amp_median_diff_sequences": 0.45, + "clip_amp_median_diff_factor": 2, + "max_root_amp_median_std_spikes": 2.5, + "max_root_amp_median_std_sequences": 2.5, + "repeated_detection_overlap_time": 0.2, + "min_seq_spikes_n": 10, + "min_seq_spikes_hz": 0.05, + "relocate_root_min_amp": 0.8, + "relocate_root_max_latency": -2, + "device": "cuda", + "num_processes": None, + "ignore_warnings": True, + "debug": False, + } + + _params_description = { + "detection_model": "`mea` or `neuropixels` to use the mea- or neuropixels-trained detection models. Or, path to saved detection model (see https://braingeneers.github.io/braindance/docs/RT-sort/usage/training-models) (`str`, `neuropixels`)", + "recording_window_ms": "A tuple `(start_ms, end_ms)` defining the portion of the recording (in milliseconds) to process. (`tuple`, `None`)", + "stringent_thresh": "The stringent threshold for spike detection. 
(`float`, `0.175`)", + "loose_thresh": "The loose threshold for spike detection. (`float`, `0.075`)", + "inference_scaling_numerator": "Scaling factor for inference. (`float`, `15.4`)", + "ms_before": "Time (in milliseconds) to consider before each detected spike for sequence formation. (`float`, `0.5`)", + "ms_after": "Time (in milliseconds) to consider after each detected spike for sequence formation. (`float`, `0.5`)", + "pre_median_ms": "Duration (in milliseconds) to compute the median for normalization. (`float`, `50`)", + "inner_radius": "Inner radius (in micrometers). (`float`, `50`)", + "outer_radius": "Outer radius (in micrometers). (`float`, `100`)", + "min_elecs_for_array_noise_n": "Minimum number of electrodes for array-wide noise filtering. (`int`, `100`)", + "min_elecs_for_array_noise_f": "Minimum fraction of electrodes for array-wide noise filtering. (`float`, `0.1`)", + "min_elecs_for_seq_noise_n": "Minimum number of electrodes for sequence-wide noise filtering. (`int`, `50`)", + "min_elecs_for_seq_noise_f": "Minimum fraction of electrodes for sequence-wide noise filtering. (`float`, `0.05`)", + "min_activity_root_cocs": "Minimum number of stringent spike detections on inner electrodes within the maximum propagation window that cause a stringent spike detection on a root electrode to be counted as a stringent codetection. (`int`, `2`)", + "min_activity_hz": "Minimum activity rate of root detections (in Hz) for an electrode to be used as a root electrode. (`float`, `0.05`)", + "max_n_components_latency": "Maximum number of latency components for Gaussian mixture model used for splitting latency distribution. (`int`, `4`)", + "min_coc_n": "After splitting a cluster of codetections, a cluster is discarded if it does not have at least min_coc_n codetections. (`int`, `10`)", + "min_coc_p": "After splitting a cluster of codetections, a cluster is discarded if it does not have at least (min_coc_p * the total number of codetections before splitting) codetections. (`int`, `10`)", + "min_extend_comp_p": "The required percentage of codetections before splitting that is preserved after the split in order for the inner electrodes of the current splitting electrode to be added to the total list of electrodes used to further split the cluster. (`int`, `50`)", + "elec_patience": "Number of electrodes considered for splitting that do not lead to a split before terminating the splitting process. (`int`, `6`)", + "split_coc_clusters_amps": "Whether to split clusters based on amplitude. (`bool`, `True`)", + "min_amp_dist_p": "The minimum Hartigan's dip test p-value for a distribution to be considered unimodal. (`float`, `0.1`)", + "max_n_components_amp": "Maximum number of components for Gaussian mixture model used for splitting amplitude distribution. (`int`, `4`)", + "min_loose_elec_prob": "Minimum average detection score (smaller values are set to 0) in decimal form (ranging from 0 to 1). (`float`, `0.03`)", + "min_inner_loose_detections": "Minimum inner loose electrode detections for assigning spikes / overlaps for merging. (`int`, `3`)", + "min_loose_detections_n": "Minimum loose electrode detections for assigning spikes / overlaps for merging. (`int`, `4`)", + "min_loose_detections_r_spikes": "Minimum ratio of loose electrode detections for assigning spikes. (`float`, `1/3`)", + "min_loose_detections_r_sequences": "Minimum ratio of loose electrode detections overlaps for merging. 
(`float`, `1/3`)", + "max_latency_diff_spikes": "Maximum allowed weighted latency difference for spike assignment. (`float`, `2.5`)", + "max_latency_diff_sequences": "Maximum allowed weighted latency difference for sequence merging. (`float`, `2.5`)", + "clip_latency_diff_factor": "Latency clip = clip_latency_diff_factor * max_latency_diff. (`float`, `2`)", + "max_amp_median_diff_spikes": "Maximum allowed weighted percent amplitude difference for spike assignment. (`float`, `0.45`)", + "max_amp_median_diff_sequences": "Maximum allowed weighted percent amplitude difference for sequence merging. (`float`, `0.45`)", + "clip_amp_median_diff_factor": "Amplitude clip = clip_amp_median_diff_factor * max_amp_median_diff. (`float`, `2`)", + "max_root_amp_median_std_spikes": "Maximum allowed root amplitude standard deviation for spike assignment. (`float`, `2.5`)", + "max_root_amp_median_std_sequences": "Maximum allowed root amplitude standard deviation for sequence merging. (`float`, `2.5`)", + "repeated_detection_overlap_time": "Time window (in seconds) for overlapping repeated detections. (`float`, `0.2`)", + "min_seq_spikes_n": "Minimum number of spikes required for a valid sequence. (`int`, `10`)", + "min_seq_spikes_hz": "Minimum spike rate for a valid sequence. (`float`, `0.05`)", + "relocate_root_min_amp": "Minimum amplitude ratio for relocating a root electrode before first merging. (`float`, `0.8`)", + "relocate_root_max_latency": "Maximum latency for relocating a root electrode before first merging. (`float`, `-2`)", + "device": "The device for PyTorch operations ('cuda' or 'cpu'). (`str`, `cuda`)", + "num_processes": "Number of processes to use for parallelization. (`int`, `None`)", + "ignore_warnings": "Whether to suppress warnings during execution. (`bool`, `True`)", + "debug": "Whether to enable debugging features such as saving intermediate steps. (`bool`, `False`)", + } + + sorter_description = """RT-Sort is a real-time spike sorting algorithm that enables the sorted detection of action potentials within 7.5ms±1.5ms (mean±STD) after the waveform trough while the recording remains ongoing. + It utilizes unique propagation patterns of action potentials along axons detected as high-fidelity sequential activations on adjacent electrodes, together with a convolutional neural network-based spike detection algorithm. + This implementation in SpikeInterface only implements RT-Sort's offline sorting. + For more information see https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0312438""" + + installation_mesg = f"""\nTo use RTSort run:\n + >>> pip install git+https://github.com/braingeneers/braindance#egg=braindance[rt-sort] + + Additionally, install PyTorch (https://pytorch.org/get-started/locally/) with any version of CUDA as the compute platform. + If running on a Linux machine, install Torch-TensorRT (https://pytorch.org/TensorRT/getting_started/installation.html) for faster computations. 
+ + More information on RTSort at: https://github.com/braingeneers/braindance + """ + + handle_multi_segment = False + + @classmethod + def get_sorter_version(cls): + import braindance + + return braindance.__version__ + + @classmethod + def is_installed(cls): + libraries = ["braindance", "torch" if os.name == "nt" else "torch_tensorrt", "diptest", "pynvml", "sklearn"] + + HAVE_RTSORT = True + for lib in libraries: + if importlib.util.find_spec(lib) is None: + HAVE_RTSORT = False + break + + return HAVE_RTSORT + + @classmethod + def _check_params(cls, recording, output_folder, params): + return params + + @classmethod + def _check_apply_filter_in_params(cls, params): + return False + + @classmethod + def _setup_recording(cls, recording, sorter_output_folder, params, verbose): + # nothing to copy inside the folder : RTSort uses spikeinterface natively + pass + + @classmethod + def _run_from_folder(cls, sorter_output_folder, params, verbose): + from braindance.core.spikesorter.rt_sort import detect_sequences + from braindance.core.spikedetector.model import ModelSpikeSorter + + recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) + + params = params.copy() + params["recording"] = recording + rt_sort_inter = sorter_output_folder / "rt_sort_inter" + params["inter_path"] = rt_sort_inter + params["verbose"] = verbose + + if params["detection_model"] == "mea": + params["detection_model"] = ModelSpikeSorter.load_mea() + elif params["detection_model"] == "neuropixels": + params["detection_model"] = ModelSpikeSorter.load_neuropixels() + else: + params["detection_model"] = ModelSpikeSorter.load(params["detection_model"]) + + rt_sort = detect_sequences(**params, delete_inter=False, return_spikes=False) + np_sorting = rt_sort.sort_offline( + rt_sort_inter / "scaled_traces.npy", + verbose=verbose, + recording_window_ms=params.get("recording_window_ms", None), + return_spikeinterface_sorter=True, + ) # type: NumpySorting + rt_sort.save(sorter_output_folder / "rt_sort.pickle") + np_sorting.save(folder=sorter_output_folder / "rt_sorting") + + @classmethod + def _get_result_from_folder(cls, sorter_output_folder): + return NumpySorting.load_from_folder(sorter_output_folder / "rt_sorting") diff --git a/src/spikeinterface/sorters/external/tests/test_rtsort.py b/src/spikeinterface/sorters/external/tests/test_rtsort.py new file mode 100644 index 0000000000..1eb28e7591 --- /dev/null +++ b/src/spikeinterface/sorters/external/tests/test_rtsort.py @@ -0,0 +1,16 @@ +import unittest +import pytest + +from spikeinterface.sorters import RTSortSorter +from spikeinterface.sorters.tests.common_tests import SorterCommonTestSuite + + +@pytest.mark.skipif(not RTSortSorter.is_installed(), reason="rt-sort not installed") +class RTSortSorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase): + SorterClass = RTSortSorter + + +if __name__ == "__main__": + test = RTSortSorterCommonTestSuite() + test.setUp() + test.test_with_run() diff --git a/src/spikeinterface/sorters/launcher.py b/src/spikeinterface/sorters/launcher.py index 137ff98cdb..122c5d4908 100644 --- a/src/spikeinterface/sorters/launcher.py +++ b/src/spikeinterface/sorters/launcher.py @@ -233,13 +233,11 @@ def run_sorter_by_property( recording, grouping_property, folder, - mode_if_folder_exists=None, engine="loop", engine_kwargs=None, verbose=False, docker_image=None, singularity_image=None, - working_folder: None = None, **sorter_params, ): """ @@ -261,10 +259,6 @@ def run_sorter_by_property( Property to split by before 
sorting folder : str | Path The working directory. - mode_if_folder_exists : bool or None, default: None - Must be None. This is deprecated. - If not None then a warning is raise. - Will be removed in next release. engine : "loop" | "joblib" | "dask" | "slurm", default: "loop" Which engine to use to run sorter. engine_kwargs : dict @@ -294,20 +288,6 @@ def run_sorter_by_property( engine_kwargs={"n_jobs": 4}) """ - if mode_if_folder_exists is not None: - warnings.warn( - "run_sorter_by_property(): mode_if_folder_exists is not used anymore and will be removed in 0.102.0", - DeprecationWarning, - stacklevel=2, - ) - - if working_folder is not None: - warnings.warn( - "`working_folder` is deprecated and will be removed in 0.103. Please use folder instead", - category=DeprecationWarning, - stacklevel=2, - ) - folder = working_folder working_folder = Path(folder).absolute() diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py index bd5d9b3529..2799222e28 100644 --- a/src/spikeinterface/sorters/runsorter.py +++ b/src/spikeinterface/sorters/runsorter.py @@ -70,7 +70,7 @@ ---------- sorter_name : str The sorter name - recording : RecordingExtractor + recording : RecordingExtractor | dict of RecordingExtractor The recording extractor to be spike sorted folder : str or Path Path to output folder @@ -95,21 +95,15 @@ If True, the output Sorting is returned as a Sorting delete_container_files : bool, default: True If True, the container temporary files are deleted after the sorting is done - output_folder : None, default: None - Do not use. Deprecated output function to be removed in 0.103. **sorter_params : keyword args Spike sorter specific arguments (they can be retrieved with `get_default_sorter_params(sorter_name_or_class)`) - Returns - ------- - BaseSorting | None - The spike sorted data (it `with_output` is True) or None (if `with_output` is False) """ def run_sorter( sorter_name: str, - recording: BaseRecording, + recording: BaseRecording | dict, folder: Optional[str] = None, remove_existing_folder: bool = False, delete_output_folder: bool = False, @@ -119,26 +113,21 @@ def run_sorter( singularity_image: Optional[Union[bool, str]] = False, delete_container_files: bool = True, with_output: bool = True, - output_folder: None = None, **sorter_params, ): """ Generic function to run a sorter via function approach. 
- {} + Returns + ------- + BaseSorting | dict of BaseSorting | None + The spike sorted data (it `with_output` is True) or None (if `with_output` is False) Examples -------- >>> sorting = run_sorter("tridesclous", recording) """ - if output_folder is not None and folder is None: - deprecation_msg = ( - "`output_folder` is deprecated and will be removed in version 0.103.0 Please use folder instead" - ) - folder = output_folder - warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) - common_kwargs = dict( sorter_name=sorter_name, recording=recording, @@ -151,6 +140,21 @@ def run_sorter( **sorter_params, ) + if isinstance(recording, dict): + + all_kwargs = common_kwargs + all_kwargs.update( + dict( + docker_image=docker_image, + singularity_image=singularity_image, + delete_container_files=delete_container_files, + ) + ) + all_kwargs.pop("recording") + + dict_of_sorters = _run_sorter_by_dict(dict_of_recordings=recording, **all_kwargs) + return dict_of_sorters + if docker_image or singularity_image: common_kwargs.update(dict(delete_container_files=delete_container_files)) if docker_image: @@ -201,6 +205,46 @@ def run_sorter( run_sorter.__doc__ = run_sorter.__doc__.format(_common_param_doc) +def _run_sorter_by_dict(dict_of_recordings: dict, folder: str | Path | None = None, **run_sorter_params): + """ + Applies `run_sorter` to each recording in a dict of recordings and saves + the results. + {} + Returns + ------- + dict + Dictionary of `BaseSorting`s, with the same keys as the input dict of `BaseRecording`s. + """ + + sorter_name = run_sorter_params.get("sorter_name") + remove_existing_folder = run_sorter_params.get("remove_existing_folder") + + if folder is None: + folder = Path(sorter_name + "_output") + + folder = Path(folder) + folder.mkdir(exist_ok=remove_existing_folder) + + sorter_dict = {} + for group_key, recording in dict_of_recordings.items(): + sorter_dict[group_key] = run_sorter(recording=recording, folder=folder / f"{group_key}", **run_sorter_params) + + info_file = folder / "spikeinterface_info.json" + info = dict( + version=spikeinterface.__version__, + dev_mode=spikeinterface.DEV_MODE, + object="Group[SorterFolder]", + dict_keys=list(dict_of_recordings.keys()), + ) + with open(info_file, mode="w") as f: + json.dump(check_json(info), f, indent=4) + + return sorter_dict + + +_run_sorter_by_dict.__doc__ = _run_sorter_by_dict.__doc__.format(_common_param_doc) + + def run_sorter_local( sorter_name, recording, @@ -210,7 +254,6 @@ def run_sorter_local( verbose=False, raise_error=True, with_output=True, - output_folder=None, **sorter_params, ): """ @@ -235,20 +278,11 @@ def run_sorter_local( If False, the process continues and the error is logged in the log file with_output : bool, default: True If True, the output Sorting is returned as a Sorting - output_folder : None, default: None - Do not use. Deprecated output function to be removed in 0.103. 
**sorter_params : keyword args """ if isinstance(recording, list): raise Exception("If you want to run several sorters/recordings use run_sorter_jobs(...)") - if output_folder is not None and folder is None: - deprecation_msg = ( - "`output_folder` is deprecated and will be removed in version 0.103.0 Please use folder instead" - ) - folder = output_folder - warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) - SorterClass = sorter_dict[sorter_name] # only classmethod call not instance (stateless at instance level but state is in folder) @@ -294,7 +328,6 @@ def run_sorter_container( installation_mode="auto", spikeinterface_version=None, spikeinterface_folder_source=None, - output_folder: None = None, **sorter_params, ): """ @@ -309,8 +342,6 @@ def run_sorter_container( The container mode : "docker" or "singularity" container_image : str, default: None The container image name and tag. If None, the default container image is used - output_folder : str, default: None - Path to output folder remove_existing_folder : bool, default: True If True and output_folder exists yet then delete delete_output_folder : bool, default: False @@ -345,13 +376,6 @@ def run_sorter_container( """ assert installation_mode in ("auto", "pypi", "github", "folder", "dev", "no-install") - - if output_folder is not None and folder is None: - deprecation_msg = ( - "`output_folder` is deprecated and will be removed in version 0.103.0 Please use folder instead" - ) - folder = output_folder - warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) spikeinterface_version = spikeinterface_version or si_version if extra_requirements is None: @@ -427,7 +451,7 @@ def run_sorter_container( # run in container output_folder = '{output_folder_unix}' sorting = run_sorter_local( - '{sorter_name}', recording, output_folder=output_folder, + '{sorter_name}', recording, folder=output_folder, remove_existing_folder={remove_existing_folder}, delete_output_folder=False, verbose={verbose}, raise_error={raise_error}, with_output=True, **sorter_params ) diff --git a/src/spikeinterface/sorters/sorterlist.py b/src/spikeinterface/sorters/sorterlist.py index e1a6816133..f5291f4ea4 100644 --- a/src/spikeinterface/sorters/sorterlist.py +++ b/src/spikeinterface/sorters/sorterlist.py @@ -13,6 +13,7 @@ from .external.klusta import KlustaSorter from .external.mountainsort4 import Mountainsort4Sorter from .external.mountainsort5 import Mountainsort5Sorter +from .external.rt_sort import RTSortSorter from .external.spyking_circus import SpykingcircusSorter from .external.tridesclous import TridesclousSorter from .external.waveclus import WaveClusSorter @@ -39,6 +40,7 @@ KlustaSorter, Mountainsort4Sorter, Mountainsort5Sorter, + RTSortSorter, SpykingcircusSorter, TridesclousSorter, WaveClusSorter, diff --git a/src/spikeinterface/sorters/tests/test_runsorter.py b/src/spikeinterface/sorters/tests/test_runsorter.py index b995520f26..7b88de0266 100644 --- a/src/spikeinterface/sorters/tests/test_runsorter.py +++ b/src/spikeinterface/sorters/tests/test_runsorter.py @@ -4,8 +4,10 @@ from pathlib import Path import shutil from packaging.version import parse +import json +import numpy as np -from spikeinterface import generate_ground_truth_recording +from spikeinterface import generate_ground_truth_recording, load from spikeinterface.sorters import run_sorter ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) @@ -34,7 +36,7 @@ def test_run_sorter_local(generate_recording, create_cache_folder): sorting = run_sorter( "tridesclous2", recording, - 
output_folder=cache_folder / "sorting_tdc_local", + folder=cache_folder / "sorting_tdc_local", remove_existing_folder=True, delete_output_folder=False, verbose=True, @@ -45,6 +47,54 @@ def test_run_sorter_local(generate_recording, create_cache_folder): print(sorting) +def test_run_sorter_dict(generate_recording, create_cache_folder): + recording = generate_recording + cache_folder = create_cache_folder + + recording = recording.time_slice(start_time=0, end_time=3) + + recording.set_property(key="split_property", values=[4, 4, "g", "g", 4, 4, 4, "g"]) + dict_of_recordings = recording.split_by("split_property") + + sorter_params = {"detection": {"detect_threshold": 4.9}} + + folder = cache_folder / "sorting_tdc_local_dict" + + dict_of_sortings = run_sorter( + "simple", + dict_of_recordings, + folder=folder, + remove_existing_folder=True, + delete_output_folder=False, + verbose=True, + raise_error=True, + **sorter_params, + ) + + assert set(list(dict_of_sortings.keys())) == set(["g", "4"]) + assert (folder / "g").is_dir() + assert (folder / "4").is_dir() + + assert dict_of_sortings["g"]._recording.get_num_channels() == 3 + assert dict_of_sortings["4"]._recording.get_num_channels() == 5 + + info_filepath = folder / "spikeinterface_info.json" + assert info_filepath.is_file() + + with open(info_filepath) as f: + spikeinterface_info = json.load(f) + + si_info_keys = spikeinterface_info.keys() + for key in ["version", "dev_mode", "object"]: + assert key in si_info_keys + + loaded_sortings = load(folder) + assert loaded_sortings.keys() == dict_of_sortings.keys() + for key, sorting in loaded_sortings.items(): + assert np.all(sorting.unit_ids == dict_of_sortings[key].unit_ids) + assert np.all(sorting.to_spike_vector() == dict_of_sortings[key].to_spike_vector()) + + @pytest.mark.skipif(ON_GITHUB, reason="Docker tests don't run on github: test locally") def test_run_sorter_docker(generate_recording, create_cache_folder): recording = generate_recording @@ -61,7 +111,7 @@ def test_run_sorter_docker(generate_recording, create_cache_folder): sorting = run_sorter( "tridesclous", recording, - output_folder=output_folder, + folder=output_folder, remove_existing_folder=True, delete_output_folder=False, verbose=True, @@ -96,7 +146,7 @@ def test_run_sorter_singularity(generate_recording, create_cache_folder): sorting = run_sorter( "tridesclous", recording, - output_folder=output_folder, + folder=output_folder, remove_existing_folder=True, delete_output_folder=False, verbose=True, diff --git a/src/spikeinterface/sortingcomponents/matching/base.py b/src/spikeinterface/sortingcomponents/matching/base.py index 0e60a9e864..2dc93b52aa 100644 --- a/src/spikeinterface/sortingcomponents/matching/base.py +++ b/src/spikeinterface/sortingcomponents/matching/base.py @@ -43,6 +43,14 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): def compute_matching(self, traces, start_frame, end_frame, segment_index): raise NotImplementedError + def clean(self): + """ + Clean the matching output. + This is called at the end of the matching process. 
+ """ + # can be overwritten if needed + pass + def get_extra_outputs(self): # can be overwritten if need to ouput some variables with a dict return None diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index faf73465ff..6bbcb77136 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -137,6 +137,9 @@ class CircusOMPSVDPeeler(BaseTemplateMatching): The engine to use for the convolutions torch_device : string in ["cpu", "cuda", None]. Default "cpu" Controls torch device if the torch engine is selected + shared_memory : bool, default True + If True, the overlaps are stored in shared memory, which is more efficient when + using numerous cores ----- """ @@ -167,6 +170,7 @@ def __init__( vicinity=2, precomputed=None, engine="numpy", + shared_memory=True, torch_device="cpu", ): @@ -176,6 +180,7 @@ def __init__( self.num_samples = templates.num_samples self.nbefore = templates.nbefore self.nafter = templates.nafter + self.shared_memory = shared_memory self.sampling_frequency = recording.get_sampling_frequency() self.vicinity = vicinity * self.num_samples assert engine in ["numpy", "torch", "auto"], "engine should be numpy, torch or auto" @@ -208,6 +213,18 @@ def __init__( assert precomputed[key] is not None, "If templates are provided, %d should also be there" % key setattr(self, key, precomputed[key]) + if self.shared_memory: + self.max_overlaps = max([len(o) for o in self.overlaps]) + num_samples = len(self.overlaps[0][0]) + from spikeinterface.core.core_tools import make_shared_array + + arr, shm = make_shared_array((self.num_templates, self.max_overlaps, num_samples), dtype=np.float32) + for i in range(self.num_templates): + n_overlaps = len(self.unit_overlaps_indices[i]) + arr[i, :n_overlaps] = self.overlaps[i] + self.overlaps = arr + self.shm = shm + self.ignore_inds = np.array(ignore_inds) self.unit_overlaps_tables = {} @@ -299,7 +316,10 @@ def _push_to_torch(self): def get_extra_outputs(self): output = {} for key in self._more_output_keys: - output[key] = getattr(self, key) + if key == "overlaps" and self.shared_memory: + output[key] = self.overlaps.copy() + else: + output[key] = getattr(self, key) return output def get_trace_margin(self): @@ -409,7 +429,12 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): myline = neighbor_window + delta_t[idx] myindices = selection[0, idx] - local_overlaps = self.overlaps[best_cluster_ind] + if self.shared_memory: + n_overlaps = len(self.unit_overlaps_indices[best_cluster_ind]) + local_overlaps = self.overlaps[best_cluster_ind, :n_overlaps] + else: + local_overlaps = self.overlaps[best_cluster_ind] + overlapping_templates = self.unit_overlaps_indices[best_cluster_ind] table = self.unit_overlaps_tables[best_cluster_ind] @@ -487,7 +512,11 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): for i in modified: tmp_best, tmp_peak = sub_selection[:, i] diff_amp = diff_amplitudes[i] * self.norms[tmp_best] - local_overlaps = self.overlaps[tmp_best] + if self.shared_memory: + n_overlaps = len(self.unit_overlaps_indices[tmp_best]) + local_overlaps = self.overlaps[tmp_best, :n_overlaps] + else: + local_overlaps = self.overlaps[tmp_best] overlapping_templates = self.units_overlaps[tmp_best] tmp = tmp_peak - neighbor_window idx = [max(0, tmp), min(num_peaks, tmp_peak + self.num_samples)] @@ -532,6 +561,19 @@ def compute_matching(self, traces, start_frame, 
end_frame, segment_index): return spikes + def clean(self): + if self.shared_memory and self.shm is not None: + self.overlaps = None + self.shm.close() + self.shm.unlink() + self.shm = None + + def __del__(self): + if self.shared_memory and self.shm is not None: + self.overlaps = None + self.shm.close() + self.shm = None + class CircusPeeler(BaseTemplateMatching): """ diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py index 5895dc433a..6d88396159 100644 --- a/src/spikeinterface/sortingcomponents/matching/main.py +++ b/src/spikeinterface/sortingcomponents/matching/main.py @@ -58,8 +58,13 @@ def find_spikes_from_templates( gather_mode="memory", squeeze_output=True, ) + if extra_outputs: outputs = node0.get_extra_outputs() + + node0.clean() + + if extra_outputs: return spikes, outputs else: return spikes diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py index 59e171fe52..a0c926cafa 100644 --- a/src/spikeinterface/sortingcomponents/matching/wobble.py +++ b/src/spikeinterface/sortingcomponents/matching/wobble.py @@ -54,6 +54,9 @@ class WobbleParameters: The engine to use for the convolutions torch_device : string in ["cpu", "cuda", None]. Default "cpu" Controls torch device if the torch engine is selected + shared_memory : bool, default True + If True, the overlaps are stored in shared memory, which is more efficient when + using numerous cores Notes ----- @@ -77,6 +80,7 @@ class WobbleParameters: scale_amplitudes: bool = False engine: str = "numpy" torch_device: str = "cpu" + shared_memory: bool = True def __post_init__(self): assert self.amplitude_variance >= 0, "amplitude_variance must be a non-negative scalar" @@ -361,6 +365,7 @@ def __init__( parameters={}, engine="numpy", torch_device="cpu", + shared_memory=True, ): BaseTemplateMatching.__init__(self, recording, templates, return_output=True, parents=None) @@ -399,6 +404,24 @@ def __init__( compressed_templates, params.jitter_factor, params.approx_rank, template_meta.jittered_indices, sparsity ) + self.shared_memory = shared_memory + + if self.shared_memory: + self.max_overlaps = max([len(o) for o in pairwise_convolution]) + num_samples = len(pairwise_convolution[0][0]) + num_templates = len(templates_array) + num_jittered = num_templates * params.jitter_factor + from spikeinterface.core.core_tools import make_shared_array + + arr, shm = make_shared_array((num_jittered, self.max_overlaps, num_samples), dtype=np.float32) + for jittered_index in range(num_jittered): + units_are_overlapping = sparsity.unit_overlap[jittered_index, :] + overlapping_units = np.where(units_are_overlapping)[0] + n_overlaps = len(overlapping_units) + arr[jittered_index, :n_overlaps] = pairwise_convolution[jittered_index] + pairwise_convolution = [arr] + self.shm = shm + norm_squared = compute_template_norm(sparsity.visible_channels, templates_array) spatial = np.moveaxis(spatial, [0, 1, 2], [1, 0, 2]) @@ -424,6 +447,19 @@ def __init__( # self.margin = int(buffer_ms*1e-3 * recording.sampling_frequency) self.margin = 300 # To ensure equivalence with spike-psvae version of the algorithm + def clean(self): + if self.shared_memory and self.shm is not None: + self.template_meta = None + self.shm.close() + self.shm.unlink() + self.shm = None + + def __del__(self): + if self.shared_memory and self.shm is not None: + self.template_meta = None + self.shm.close() + self.shm = None + def _push_to_torch(self): if self.engine == 
"torch": temporal, singular, spatial, temporal_jittered = self.template_data.compressed_templates @@ -644,7 +680,12 @@ def subtract_spike_train( id_scaling = scalings[id_mask] overlapping_templates = sparsity.unit_overlap[jittered_index] # Note: pairwise_conv only has overlapping template convolutions already - pconv = template_data.pairwise_convolution[jittered_index] + if params.shared_memory: + overlapping_units = np.where(overlapping_templates)[0] + n_overlaps = len(overlapping_units) + pconv = template_data.pairwise_convolution[0][jittered_index, :n_overlaps] + else: + pconv = template_data.pairwise_convolution[jittered_index] # TODO: If optimizing for speed -- check this loop for spike_start_index, spike_scaling in zip(id_spiketrain, id_scaling): spike_stop_index = spike_start_index + convolution_resolution_len diff --git a/src/spikeinterface/sortingcomponents/peak_pipeline.py b/src/spikeinterface/sortingcomponents/peak_pipeline.py deleted file mode 100644 index d9ae76f283..0000000000 --- a/src/spikeinterface/sortingcomponents/peak_pipeline.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -import copy - -from spikeinterface.core.node_pipeline import PeakRetriever, run_node_pipeline - - -def run_peak_pipeline( - recording, - peaks, - nodes, - job_kwargs, - job_name="peak_pipeline", - gather_mode="memory", - squeeze_output=True, - folder=None, - names=None, -): - # TODO remove this soon - import warnings - - warnings.warn("run_peak_pipeline() is deprecated use run_node_pipeline() instead", DeprecationWarning, stacklevel=2) - - node0 = PeakRetriever(recording, peaks) - # because nodes are modified inplace (insert parent) they need to copy incase - # the same pipeline is run several times - nodes = copy.deepcopy(nodes) - - for node in nodes: - if node.parents is None: - node.parents = [node0] - else: - node.parents = [node0] + node.parents - all_nodes = [node0] + nodes - outs = run_node_pipeline( - recording, - all_nodes, - job_kwargs, - job_name=job_name, - gather_mode=gather_mode, - squeeze_output=squeeze_output, - folder=folder, - names=names, - ) - return outs diff --git a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py b/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py deleted file mode 100644 index 79a9603b8d..0000000000 --- a/src/spikeinterface/sortingcomponents/tests/test_waveforms/test_waveform_thresholder.py +++ /dev/null @@ -1,109 +0,0 @@ -import pytest -import numpy as np -import operator - - -from spikeinterface.sortingcomponents.waveforms.waveform_thresholder import WaveformThresholder -from spikeinterface.core.node_pipeline import ExtractDenseWaveforms -from spikeinterface.sortingcomponents.peak_pipeline import run_peak_pipeline - - -@pytest.fixture(scope="module") -def extract_dense_waveforms_node(generated_recording): - # Parameters - ms_before = 1.0 - ms_after = 1.0 - - # Node initialization - return ExtractDenseWaveforms( - recording=generated_recording, ms_before=ms_before, ms_after=ms_after, return_output=True - ) - - -def test_waveform_thresholder_ptp( - extract_dense_waveforms_node, generated_recording, detected_peaks, chunk_executor_kwargs -): - recording = generated_recording - peaks = detected_peaks - - tresholded_waveforms_ptp = WaveformThresholder( - recording=recording, parents=[extract_dense_waveforms_node], feature="ptp", threshold=3, return_output=True - ) - noise_levels = tresholded_waveforms_ptp.noise_levels - - pipeline_nodes = 
[extract_dense_waveforms_node, tresholded_waveforms_ptp] - # Extract projected waveforms and compare - waveforms, tresholded_waveforms = run_peak_pipeline( - recording, peaks, nodes=pipeline_nodes, job_kwargs=chunk_executor_kwargs - ) - - data = np.ptp(tresholded_waveforms, axis=1) / noise_levels - assert np.all(data[data != 0] > 3) - - -def test_waveform_thresholder_mean( - extract_dense_waveforms_node, generated_recording, detected_peaks, chunk_executor_kwargs -): - recording = generated_recording - peaks = detected_peaks - - tresholded_waveforms_mean = WaveformThresholder( - recording=recording, parents=[extract_dense_waveforms_node], feature="mean", threshold=0, return_output=True - ) - - pipeline_nodes = [extract_dense_waveforms_node, tresholded_waveforms_mean] - # Extract projected waveforms and compare - waveforms, tresholded_waveforms = run_peak_pipeline( - recording, peaks, nodes=pipeline_nodes, job_kwargs=chunk_executor_kwargs - ) - - assert np.all(tresholded_waveforms.mean(axis=1) >= 0) - - -def test_waveform_thresholder_energy( - extract_dense_waveforms_node, generated_recording, detected_peaks, chunk_executor_kwargs -): - recording = generated_recording - peaks = detected_peaks - - tresholded_waveforms_energy = WaveformThresholder( - recording=recording, parents=[extract_dense_waveforms_node], feature="energy", threshold=3, return_output=True - ) - noise_levels = tresholded_waveforms_energy.noise_levels - - pipeline_nodes = [extract_dense_waveforms_node, tresholded_waveforms_energy] - # Extract projected waveforms and compare - waveforms, tresholded_waveforms = run_peak_pipeline( - recording, peaks, nodes=pipeline_nodes, job_kwargs=chunk_executor_kwargs - ) - - data = np.linalg.norm(tresholded_waveforms, axis=1) / noise_levels - assert np.all(data[data != 0] > 3) - - -def test_waveform_thresholder_operator( - extract_dense_waveforms_node, generated_recording, detected_peaks, chunk_executor_kwargs -): - recording = generated_recording - peaks = detected_peaks - - import operator - - tresholded_waveforms_peak = WaveformThresholder( - recording=recording, - parents=[extract_dense_waveforms_node], - feature="peak_voltage", - threshold=5, - operator=operator.ge, - return_output=True, - ) - noise_levels = tresholded_waveforms_peak.noise_levels - - pipeline_nodes = [extract_dense_waveforms_node, tresholded_waveforms_peak] - # Extract projected waveforms and compare - waveforms, tresholded_waveforms = run_peak_pipeline( - recording, peaks, nodes=pipeline_nodes, job_kwargs=chunk_executor_kwargs - ) - - data = tresholded_waveforms[:, extract_dense_waveforms_node.nbefore, :] / noise_levels - assert np.all(data[data != 0] <= 5) diff --git a/src/spikeinterface/widgets/amplitudes.py b/src/spikeinterface/widgets/amplitudes.py index 197fefbab2..c3aeb221ab 100644 --- a/src/spikeinterface/widgets/amplitudes.py +++ b/src/spikeinterface/widgets/amplitudes.py @@ -5,7 +5,7 @@ from .rasters import BaseRasterWidget from .base import BaseWidget, to_attr -from .utils import get_some_colors +from .utils import get_some_colors, validate_segment_indices, get_segment_durations from spikeinterface.core.sortinganalyzer import SortingAnalyzer @@ -25,8 +25,9 @@ class AmplitudesWidget(BaseRasterWidget): unit_colors : dict | None, default: None Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted by matplotlib. If None, default colors are chosen using the `get_some_colors` function. 
- segment_index : int or None, default: None - The segment index (or None if mono-segment) + segment_indices : list of int or None, default: None + Segment index or indices to plot. If None and there are multiple segments, defaults to 0. + If list, spike trains and amplitudes are concatenated across the specified segments. max_spikes_per_unit : int or None, default: None Number of max spikes per unit to display. Use None for all spikes y_lim : tuple or None, default: None @@ -51,7 +52,7 @@ def __init__( sorting_analyzer: SortingAnalyzer, unit_ids=None, unit_colors=None, - segment_index=None, + segment_indices=None, max_spikes_per_unit=None, y_lim=None, scatter_decimate=1, @@ -59,62 +60,96 @@ def __init__( plot_histograms=False, bins=None, plot_legend=True, + segment_index=None, backend=None, **backend_kwargs, ): + import warnings + + # Handle deprecation of segment_index parameter + if segment_index is not None: + warnings.warn( + "The 'segment_index' parameter is deprecated and will be removed in a future version. " + "Use 'segment_indices' instead.", + DeprecationWarning, + stacklevel=2, + ) + if segment_indices is None: + if isinstance(segment_index, int): + segment_indices = [segment_index] + else: + segment_indices = segment_index sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) - sorting = sorting_analyzer.sorting self.check_extensions(sorting_analyzer, "spike_amplitudes") + # Get amplitudes by segment amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data(outputs="by_unit") if unit_ids is None: unit_ids = sorting.unit_ids - if sorting.get_num_segments() > 1: - if segment_index is None: - warn("More than one segment available! Using `segment_index = 0`.") - segment_index = 0 - else: - segment_index = 0 + # Handle segment_index input + segment_indices = validate_segment_indices(segment_indices, sorting) + + # Check for SortingView backend + is_sortingview = backend == "sortingview" + + # For SortingView, ensure we're only using a single segment + if is_sortingview and len(segment_indices) > 1: + warn("SortingView backend currently supports only single segment. 
Using first segment.") + segment_indices = [segment_indices[0]] - amplitudes_segment = amplitudes[segment_index] - total_duration = sorting_analyzer.get_num_samples(segment_index) / sorting_analyzer.sampling_frequency + # Create multi-segment data structure (dict of dicts) + spiketrains_by_segment = {} + amplitudes_by_segment = {} - all_spiketrains = { - unit_id: sorting.get_unit_spike_train(unit_id, segment_index=segment_index, return_times=True) - for unit_id in sorting.unit_ids - } + for idx in segment_indices: + amplitudes_segment = amplitudes[idx] - all_amplitudes = amplitudes_segment + # Initialize for this segment + spiketrains_by_segment[idx] = {} + amplitudes_by_segment[idx] = {} + + for unit_id in unit_ids: + # Get spike times for this unit in this segment + spike_times = sorting.get_unit_spike_train(unit_id, segment_index=idx, return_times=True) + amps = amplitudes_segment[unit_id] + + # Store data in dict of dicts format + spiketrains_by_segment[idx][unit_id] = spike_times + amplitudes_by_segment[idx][unit_id] = amps + + # Apply max_spikes_per_unit limit if specified if max_spikes_per_unit is not None: - spiketrains_to_plot = dict() - amplitudes_to_plot = dict() - for unit, st in all_spiketrains.items(): - amps = all_amplitudes[unit] - if len(st) > max_spikes_per_unit: - random_idxs = np.random.choice(len(st), size=max_spikes_per_unit, replace=False) - spiketrains_to_plot[unit] = st[random_idxs] - amplitudes_to_plot[unit] = amps[random_idxs] - else: - spiketrains_to_plot[unit] = st - amplitudes_to_plot[unit] = amps - else: - spiketrains_to_plot = all_spiketrains - amplitudes_to_plot = all_amplitudes + for idx in segment_indices: + for unit_id in unit_ids: + st = spiketrains_by_segment[idx][unit_id] + amps = amplitudes_by_segment[idx][unit_id] + if len(st) > max_spikes_per_unit: + # Scale down the number of spikes proportionally per segment + # to ensure we have max_spikes_per_unit total after concatenation + segment_count = len(segment_indices) + segment_max = max(1, max_spikes_per_unit // segment_count) + + if len(st) > segment_max: + random_idxs = np.random.choice(len(st), size=segment_max, replace=False) + spiketrains_by_segment[idx][unit_id] = st[random_idxs] + amplitudes_by_segment[idx][unit_id] = amps[random_idxs] if plot_histograms and bins is None: bins = 100 + # Calculate durations for all segments for x-axis limits + durations = get_segment_durations(sorting, segment_indices) + + # Build the plot data with the full dict of dicts structure plot_data = dict( - spike_train_data=spiketrains_to_plot, - y_axis_data=amplitudes_to_plot, unit_colors=unit_colors, plot_histograms=plot_histograms, bins=bins, - total_duration=total_duration, + durations=durations, unit_ids=unit_ids, hide_unit_selector=hide_unit_selector, plot_legend=plot_legend, @@ -123,6 +158,17 @@ def __init__( scatter_decimate=scatter_decimate, ) + # If using SortingView, extract just the first segment's data as flat dicts + if is_sortingview: + first_segment = segment_indices[0] + plot_data["spike_train_data"] = spiketrains_by_segment[first_segment] + plot_data["y_axis_data"] = amplitudes_by_segment[first_segment] + else: + # Otherwise use the full dict of dicts structure with all segments + plot_data["spike_train_data"] = spiketrains_by_segment + plot_data["y_axis_data"] = amplitudes_by_segment + plot_data["segment_indices"] = segment_indices + BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs) def plot_sortingview(self, data_plot, **backend_kwargs): @@ -143,7 +189,10 @@ def 
plot_sortingview(self, data_plot, **backend_kwargs): ] self.view = vv.SpikeAmplitudes( - start_time_sec=0, end_time_sec=dp.total_duration, plots=sa_items, hide_unit_selector=dp.hide_unit_selector + start_time_sec=0, + end_time_sec=np.sum(dp.durations), + plots=sa_items, + hide_unit_selector=dp.hide_unit_selector, ) self.url = handle_display_and_url(self, self.view, **backend_kwargs) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index f28560fcd6..3c08217300 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -117,14 +117,14 @@ class DriftRasterMapWidget(BaseRasterWidget): "spike_locations" extension computed. direction : "x" or "y", default: "y" The direction to display. "y" is the depth direction. - segment_index : int, default: None - The segment index to display. recording : RecordingExtractor | None, default: None The recording extractor object (only used to get "real" times). - segment_index : int, default: 0 - The segment index to display. sampling_frequency : float, default: None The sampling frequency (needed if recording is None). + segment_indices : list of int or None, default: None + The segment index or indices to display. If None and there's only one segment, it's used. + If None and there are multiple segments, you must specify which to use. + If a list of indices is provided, peaks and locations are concatenated across the segments. depth_lim : tuple or None, default: None The min and max depth to display, if None (min and max of the recording). scatter_decimate : int, default: None @@ -149,7 +149,7 @@ def __init__( direction: str = "y", recording: BaseRecording | None = None, sampling_frequency: float | None = None, - segment_index: int | None = None, + segment_indices: list[int] | None = None, depth_lim: tuple[float, float] | None = None, color_amplitude: bool = True, scatter_decimate: int | None = None, @@ -157,10 +157,30 @@ def __init__( color: str = "Gray", clim: tuple[float, float] | None = None, alpha: float = 1, + segment_index: int | list[int] | None = None, backend: str | None = None, **backend_kwargs, ): + import warnings + from matplotlib.pyplot import colormaps + from matplotlib.colors import Normalize + + # Handle deprecation of segment_index parameter + if segment_index is not None: + warnings.warn( + "The 'segment_index' parameter is deprecated and will be removed in a future version. 
" + "Use 'segment_indices' instead.", + DeprecationWarning, + stacklevel=2, + ) + if segment_indices is None: + if isinstance(segment_index, int): + segment_indices = [segment_index] + else: + segment_indices = segment_index + assert peaks is not None or sorting_analyzer is not None + if peaks is not None: assert peak_locations is not None if recording is None: @@ -168,6 +188,7 @@ def __init__( else: sampling_frequency = recording.sampling_frequency peak_amplitudes = peaks["amplitude"] + if sorting_analyzer is not None: if sorting_analyzer.has_recording(): recording = sorting_analyzer.recording @@ -190,29 +211,56 @@ def __init__( else: peak_amplitudes = None - if segment_index is None: - assert ( - len(np.unique(peaks["segment_index"])) == 1 - ), "segment_index must be specified if there are multiple segments" - segment_index = 0 - else: - peak_mask = peaks["segment_index"] == segment_index - peaks = peaks[peak_mask] - peak_locations = peak_locations[peak_mask] - if peak_amplitudes is not None: - peak_amplitudes = peak_amplitudes[peak_mask] - - from matplotlib.pyplot import colormaps + unique_segments = np.unique(peaks["segment_index"]) - if color_amplitude: - amps = peak_amplitudes + if segment_indices is None: + if len(unique_segments) == 1: + segment_indices = [int(unique_segments[0])] + else: + raise ValueError("segment_indices must be specified if there are multiple segments") + + if not isinstance(segment_indices, list): + raise ValueError("segment_indices must be a list of ints") + + # Validate all segment indices exist in the data + for idx in segment_indices: + if idx not in unique_segments: + raise ValueError(f"segment_index {idx} not found in peaks data") + + # Filter data for the selected segments + # Note: For simplicity, we'll filter all data first, then construct dict of dicts + segment_mask = np.isin(peaks["segment_index"], segment_indices) + filtered_peaks = peaks[segment_mask] + filtered_locations = peak_locations[segment_mask] + if peak_amplitudes is not None: + filtered_amplitudes = peak_amplitudes[segment_mask] + + # Create dict of dicts structure for the base class + spike_train_data = {} + y_axis_data = {} + + # Process each segment separately + for seg_idx in segment_indices: + segment_mask = filtered_peaks["segment_index"] == seg_idx + segment_peaks = filtered_peaks[segment_mask] + segment_locations = filtered_locations[segment_mask] + + # Convert peak times to seconds + spike_times = segment_peaks["sample_index"] / sampling_frequency + + # Store in dict of dicts format (using 0 as the "unit" id) + spike_train_data[seg_idx] = {0: spike_times} + y_axis_data[seg_idx] = {0: segment_locations[direction]} + + if color_amplitude and peak_amplitudes is not None: + amps = filtered_amplitudes amps_abs = np.abs(amps) q_95 = np.quantile(amps_abs, 0.95) - cmap = colormaps[cmap] + cmap_obj = colormaps[cmap] if clim is None: amps = amps_abs amps /= q_95 - c = cmap(amps) + c = cmap_obj(amps) else: from matplotlib.colors import Normalize @@ -226,18 +274,30 @@ def __init__( else: color_kwargs = dict(color=color, c=None, alpha=alpha) - # convert data into format that `BaseRasterWidget` can take it in - spike_train_data = {0: peaks["sample_index"] / sampling_frequency} - y_axis_data = {0: peak_locations[direction]} + # Calculate segment durations for x-axis limits + if recording is not None: + durations = [recording.get_duration(seg_idx) for seg_idx in segment_indices] + else: + # Find boundaries between segments using searchsorted + segment_boundaries = [ + 
np.searchsorted(filtered_peaks["segment_index"], [seg_idx, seg_idx + 1]) for seg_idx in segment_indices + ] + + # Calculate durations from max sample in each segment + durations = [ + (filtered_peaks["sample_index"][end - 1] + 1) / sampling_frequency for (_, end) in segment_boundaries + ] plot_data = dict( spike_train_data=spike_train_data, y_axis_data=y_axis_data, + segment_indices=segment_indices, y_lim=depth_lim, color_kwargs=color_kwargs, scatter_decimate=scatter_decimate, title="Peak depth", y_label="Depth [um]", + durations=durations, ) BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs) @@ -370,10 +430,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): dp.recording, ) - commpon_drift_map_kwargs = dict( + common_drift_map_kwargs = dict( direction=dp.motion.direction, recording=dp.recording, - segment_index=dp.segment_index, + segment_indices=[dp.segment_index], depth_lim=dp.depth_lim, scatter_decimate=dp.scatter_decimate, color_amplitude=dp.color_amplitude, @@ -390,7 +450,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): dp.peak_locations, ax=ax0, immediate_plot=True, - **commpon_drift_map_kwargs, + **common_drift_map_kwargs, ) _ = DriftRasterMapWidget( @@ -398,7 +458,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): corrected_location, ax=ax1, immediate_plot=True, - **commpon_drift_map_kwargs, + **common_drift_map_kwargs, ) ax2.plot(temporal_bins_s, displacement, alpha=0.2, color="black") diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index 5c6f867f2c..757401d77c 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -4,8 +4,8 @@ from warnings import warn from spikeinterface.core import SortingAnalyzer, BaseSorting -from .base import BaseWidget, to_attr -from .utils import get_some_colors +from .base import BaseWidget, to_attr, default_backend_kwargs +from .utils import get_some_colors, validate_segment_indices, get_segment_durations class BaseRasterWidget(BaseWidget): @@ -16,14 +16,19 @@ class BaseRasterWidget(BaseWidget): Parameters ---------- - spike_train_data : dict - A dict of spike trains, indexed by the unit_id - y_axis_data : dict - A dict of the y-axis data, indexed by the unit_id + spike_train_data : dict of dicts + A dict of dicts where the structure is spike_train_data[segment_index][unit_id]. + y_axis_data : dict of dicts + A dict of dicts where the structure is y_axis_data[segment_index][unit_id]. + For backwards compatibility, a flat dict indexed by unit_id will be internally + converted to a dict of dicts with segment 0. unit_ids : array-like | None, default: None List of unit_ids to plot - total_duration : int | None, default: None - Duration of spike_train_data in seconds. + segment_indices : list | None, default: None + For multi-segment data, specifies which segment(s) to plot. If None, uses all available segments. + For single-segment data, this parameter is ignored. + durations : list | None, default: None + List of durations per segment of spike_train_data in seconds. plot_histograms : bool, default: False Plot histogram of y-axis data in another subplot bins : int | None, default: None @@ -49,6 +54,8 @@ class BaseRasterWidget(BaseWidget): Ticks on y-axis, passed to `set_yticks`. If None, default ticks are used. 
hide_unit_selector : bool, default: False
For sortingview backend, if True the unit selector is not displayed
+ segment_boundary_kwargs : dict | None, default: None
+ Additional arguments for the segment boundary lines, passed to `matplotlib.axvline`
backend : str | None, default None
Which plotting backend to use e.g. 'matplotlib', 'ipywidgets'. If None, uses default from `get_default_plotter_backend`.
@@ -59,7 +66,8 @@ def __init__(
self,
spike_train_data: dict,
y_axis_data: dict,
unit_ids: list | None = None,
- total_duration: int | None = None,
+ segment_indices: list | None = None,
+ durations: list | None = None,
plot_histograms: bool = False,
bins: int | None = None,
scatter_decimate: int = 1,
@@ -72,13 +80,72 @@
y_label: str | None = None,
y_ticks: bool = False,
hide_unit_selector: bool = True,
+ segment_boundary_kwargs: dict | None = None,
backend: str | None = None,
**backend_kwargs,
):
+ # Set default segment boundary kwargs if not provided
+ if segment_boundary_kwargs is None:
+ segment_boundary_kwargs = {"color": "gray", "linestyle": "--", "alpha": 0.7}
+
+ # Process the data
+ available_segments = list(spike_train_data.keys())
+ available_segments.sort() # Ensure consistent ordering
+
+ # Determine which segments to use
+ if segment_indices is None:
+ # Use all segments by default
+ segments_to_use = available_segments
+ elif isinstance(segment_indices, list):
+ # Multiple segments specified
+ for idx in segment_indices:
+ if idx not in available_segments:
+ raise ValueError(f"segment_index {idx} not found in available segments {available_segments}")
+ segments_to_use = segment_indices
+ else:
+ raise ValueError("segment_indices must be `list` or `None`")
+
+ # Get all unit IDs present in any segment if not specified
+ if unit_ids is None:
+ all_units = set()
+ for seg_idx in segments_to_use:
+ all_units.update(spike_train_data[seg_idx].keys())
+ unit_ids = list(all_units)
+
+ # Calculate cumulative durations for segment boundaries
+ segment_boundaries = np.cumsum(durations)
+ cumulative_durations = np.concatenate([[0], segment_boundaries])
+
+ # Concatenate data across segments with proper time offsets
+ concatenated_spike_trains = {unit_id: np.array([]) for unit_id in unit_ids}
+ concatenated_y_axis = {unit_id: np.array([]) for unit_id in unit_ids}
+
+ for offset, spike_train_segment, y_axis_segment in zip(
+ cumulative_durations,
+ [spike_train_data[idx] for idx in segments_to_use],
+ [y_axis_data[idx] for idx in segments_to_use],
+ ):
+ # Process each unit in the current segment
+ for unit_id, spike_times in spike_train_segment.items():
+ if unit_id not in unit_ids:
+ continue
+
+ # Get y-axis values for this unit
+ y_values = y_axis_segment[unit_id]
+
+ # Apply offset to spike times
+ adjusted_times = spike_times + offset
+
+ # Add to concatenated data
+ concatenated_spike_trains[unit_id] = np.concatenate(
+ [concatenated_spike_trains[unit_id], adjusted_times]
+ )
+ concatenated_y_axis[unit_id] = np.concatenate([concatenated_y_axis[unit_id], y_values])
+
plot_data = dict(
- spike_train_data=spike_train_data,
- y_axis_data=y_axis_data,
+ spike_train_data=concatenated_spike_trains,
+ y_axis_data=concatenated_y_axis,
unit_ids=unit_ids,
plot_histograms=plot_histograms,
y_lim=y_lim,
@@ -88,11 +155,13 @@
unit_colors=unit_colors,
y_label=y_label,
title=title,
- total_duration=total_duration,
+ durations=durations,
plot_legend=plot_legend,
bins=bins,
y_ticks=y_ticks,
hide_unit_selector=hide_unit_selector,
+ segment_boundaries=segment_boundaries,
+ segment_boundary_kwargs=segment_boundary_kwargs,
)
BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
@@ -135,6 +204,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
y_axis_data = dp.y_axis_data
for unit_id in unit_ids:
+ if unit_id not in spike_train_data:
+ continue # Skip this unit if not in data
unit_spike_train = spike_train_data[unit_id][:: dp.scatter_decimate]
unit_y_data = y_axis_data[unit_id][:: dp.scatter_decimate]
@@ -156,6 +227,11 @@
count, bins = np.histogram(unit_y_data, bins=bins)
ax_hist.plot(count, bins[:-1], color=unit_colors[unit_id], alpha=0.8)
+ # Add segment boundary lines if provided
+ if getattr(dp, "segment_boundaries", None) is not None:
+ for boundary in dp.segment_boundaries:
+ scatter_ax.axvline(boundary, **dp.segment_boundary_kwargs)
+
if dp.plot_histograms:
ax_hist = self.axes.flatten()[1]
ax_hist.set_ylim(scatter_ax.get_ylim())
@@ -172,7 +248,7 @@
scatter_ax.set_ylim(*dp.y_lim)
x_lim = dp.x_lim
if x_lim is None:
- x_lim = [0, dp.total_duration]
+ x_lim = [0, np.sum(dp.durations)]
scatter_ax.set_xlim(x_lim)
if dp.y_ticks:
@@ -298,16 +374,33 @@ class RasterWidget(BaseRasterWidget):
def __init__(
self,
sorting_analyzer_or_sorting: SortingAnalyzer | BaseSorting | None = None,
- segment_index: int | None = None,
+ segment_indices: list | None = None,
unit_ids: list | None = None,
time_range: list | None = None,
color="k",
backend: str | None = None,
sorting: BaseSorting | None = None,
sorting_analyzer: SortingAnalyzer | None = None,
+ segment_index: int | None = None,
**backend_kwargs,
):
+ import warnings
+
+ # Handle deprecation of segment_index parameter
+ if segment_index is not None:
+ warnings.warn(
+ "The 'segment_index' parameter is deprecated and will be removed in a future version. "
+ "Use 'segment_indices' instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ if segment_indices is None:
+ if isinstance(segment_index, int):
+ segment_indices = [segment_index]
+ else:
+ segment_indices = segment_index
+
if sorting is not None:
# When removed, make `sorting_analyzer_or_sorting` a required argument rather than None.
deprecation_msg = "`sorting` argument is deprecated and will be removed in version 0.105.0. Please use `sorting_analyzer_or_sorting` instead"
@@ -320,30 +413,40 @@
sorting = self.ensure_sorting(sorting_analyzer_or_sorting)
- if sorting.get_num_segments() > 1:
- if segment_index is None:
- warn("More than one segment available!
Using `segment_index = 0`.") - segment_index = 0 - else: - segment_index = 0 + segment_indices = validate_segment_indices(segment_indices, sorting) if unit_ids is None: unit_ids = sorting.unit_ids - all_spiketrains = { - unit_id: sorting.get_unit_spike_train(unit_id, segment_index=segment_index, return_times=True) - for unit_id in unit_ids - } + # Create dict of dicts structure + spike_train_data = {} + y_axis_data = {} - if time_range is not None: - assert len(time_range) == 2, "'time_range' should be a list with start and end time in seconds" + # Create a lookup dictionary for unit indices + unit_indices_map = {unit_id: i for i, unit_id in enumerate(unit_ids)} + + # Estimate segment duration from max spike time in each segment + durations = get_segment_durations(sorting, segment_indices) + + # Extract spike data for all segments and units at once + spike_train_data = {seg_idx: {} for seg_idx in segment_indices} + y_axis_data = {seg_idx: {} for seg_idx in segment_indices} + + for seg_idx in segment_indices: for unit_id in unit_ids: - unit_st = all_spiketrains[unit_id] - all_spiketrains[unit_id] = unit_st[(time_range[0] < unit_st) & (unit_st < time_range[1])] + # Get spikes for this segment and unit + spike_times = ( + sorting.get_unit_spike_train(unit_id=unit_id, segment_index=seg_idx) / sorting.sampling_frequency + ) + + # Store data + spike_train_data[seg_idx][unit_id] = spike_times + y_axis_data[seg_idx][unit_id] = unit_indices_map[unit_id] * np.ones(len(spike_times)) - raster_locations = { - unit_id: unit_index * np.ones(len(all_spiketrains[unit_id])) for unit_index, unit_id in enumerate(unit_ids) - } + # Apply time range filtering if specified + if time_range is not None: + assert len(time_range) == 2, "'time_range' should be a list with start and end time in seconds" + # Let BaseRasterWidget handle the filtering unit_indices = list(range(len(unit_ids))) @@ -354,14 +457,16 @@ def __init__( y_ticks = {"ticks": unit_indices, "labels": unit_ids} plot_data = dict( - spike_train_data=all_spiketrains, - y_axis_data=raster_locations, + spike_train_data=spike_train_data, + y_axis_data=y_axis_data, + segment_indices=segment_indices, x_lim=time_range, y_label="Unit id", unit_ids=unit_ids, unit_colors=unit_colors, plot_histograms=None, y_ticks=y_ticks, + durations=durations, ) BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs) diff --git a/src/spikeinterface/widgets/sorting_summary.py b/src/spikeinterface/widgets/sorting_summary.py index 67322398fb..b021f98563 100644 --- a/src/spikeinterface/widgets/sorting_summary.py +++ b/src/spikeinterface/widgets/sorting_summary.py @@ -88,7 +88,7 @@ def __init__( if unit_table_properties is not None: warnings.warn( - "plot_sorting_summary() : unit_table_properties is deprecated, use displayed_unit_properties instead", + "plot_sorting_summary() : `unit_table_properties` is deprecated and will be removed in version 0.104.0, use `displayed_unit_properties` instead", category=DeprecationWarning, stacklevel=2, ) diff --git a/src/spikeinterface/widgets/tests/test_widgets_utils.py b/src/spikeinterface/widgets/tests/test_widgets_utils.py index 2131969c2c..3adf31c189 100644 --- a/src/spikeinterface/widgets/tests/test_widgets_utils.py +++ b/src/spikeinterface/widgets/tests/test_widgets_utils.py @@ -1,4 +1,7 @@ -from spikeinterface.widgets.utils import get_some_colors +import pytest + +from spikeinterface import generate_sorting +from spikeinterface.widgets.utils import get_some_colors, validate_segment_indices, get_segment_durations def 
test_get_some_colors(): @@ -19,5 +22,76 @@ def test_get_some_colors(): # print(colors) +def test_validate_segment_indices(): + # Setup + sorting_single = generate_sorting(durations=[5]) # 1 segment + sorting_multiple = generate_sorting(durations=[5, 10, 15, 20, 25]) # 5 segments + + # Test None with single segment + assert validate_segment_indices(None, sorting_single) == [0] + + # Test None with multiple segments + with pytest.warns(UserWarning): + assert validate_segment_indices(None, sorting_multiple) == [0] + + # Test valid indices + assert validate_segment_indices([0], sorting_single) == [0] + assert validate_segment_indices([0, 1, 4], sorting_multiple) == [0, 1, 4] + + # Test invalid type + with pytest.raises(TypeError): + validate_segment_indices(0, sorting_multiple) + + # Test invalid index type + with pytest.raises(ValueError): + validate_segment_indices([0, "1"], sorting_multiple) + + # Test out of range + with pytest.raises(ValueError): + validate_segment_indices([5], sorting_multiple) + + +def test_get_segment_durations(): + from spikeinterface import generate_sorting + + # Test with a normal multi-segment sorting + durations = [5.0, 10.0, 15.0] + + # Create sorting with high fr to ensure spikes near the end segments + sorting = generate_sorting( + durations=durations, + firing_rates=15.0, + ) + + segment_indices = list(range(sorting.get_num_segments())) + + # Calculate durations + calculated_durations = get_segment_durations(sorting, segment_indices) + + # Check results + assert len(calculated_durations) == len(durations) + # Durations should be approximately correct + for calculated_duration, expected_duration in zip(calculated_durations, durations): + # Duration should be <= expected (spikes can't be after the end) + assert calculated_duration <= expected_duration + # And reasonably close + tolerance = max(0.1 * expected_duration, 0.1) + assert expected_duration - calculated_duration < tolerance + + # Test with single-segment sorting + sorting_single = generate_sorting( + durations=[7.0], + firing_rates=15.0, + ) + + single_duration = get_segment_durations(sorting_single, [0])[0] + + # Test that the calculated duration is reasonable + assert single_duration <= 7.0 + assert 7.0 - single_duration < 0.7 # Within 10% + + if __name__ == "__main__": test_get_some_colors() + test_validate_segment_indices() + test_get_segment_durations() diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index a1ac9d4af9..75fb74cfae 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -3,6 +3,8 @@ from warnings import warn import numpy as np +from spikeinterface.core import BaseSorting + def get_some_colors( keys, @@ -349,3 +351,73 @@ def make_units_table_from_analyzer( ) return units_table + + +def validate_segment_indices(segment_indices: list[int] | None, sorting: BaseSorting): + """ + Validate a list of segment indices for a sorting object. + + Parameters + ---------- + segment_indices : list of int + The segment index or indices to validate. + sorting : BaseSorting + The sorting object to validate against. + + Returns + ------- + list of int + A list of valid segment indices. + + Raises + ------ + ValueError + If the segment indices are not valid. + """ + num_segments = sorting.get_num_segments() + + # Handle segment_indices input + if segment_indices is None: + if num_segments > 1: + warn("Segment indices not specified. 
Using first available segment only.")
+ return [0]
+
+ # segment_indices must be given as a list of ints
+ if not isinstance(segment_indices, list):
+ raise TypeError(
+ f"segment_indices must be a list of ints - available segments are: {list(range(num_segments))}"
+ )
+
+ # Validate segment indices
+ for idx in segment_indices:
+ if not isinstance(idx, int):
+ raise ValueError(f"Each segment index must be an integer, got {type(idx)}")
+ if idx < 0 or idx >= num_segments:
+ raise ValueError(f"segment_index {idx} out of range (0 to {num_segments - 1})")
+
+ return segment_indices
+
+
+def get_segment_durations(sorting: BaseSorting, segment_indices: list[int]) -> list[float]:
+ """
+ Calculate the duration of each segment in a sorting object.
+
+ Parameters
+ ----------
+ sorting : BaseSorting
+ The sorting object containing spike data
+ segment_indices : list of int
+ The segment indices to compute durations for
+
+ Returns
+ -------
+ list[float]
+ List of segment durations in seconds
+ """
+ spikes = sorting.to_spike_vector()
+
+ segment_boundaries = [
+ np.searchsorted(spikes["segment_index"], [seg_idx, seg_idx + 1]) for seg_idx in segment_indices
+ ]
+
+ durations = [(spikes["sample_index"][end - 1] + 1) / sorting.sampling_frequency for (_, end) in segment_boundaries]
+
+ return durations
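As a quick illustration of the new multi-segment helpers, here is a minimal usage sketch that mirrors the added tests in test_widgets_utils.py; it assumes a spikeinterface installation that already includes the `validate_segment_indices` and `get_segment_durations` functions introduced above.

import numpy as np

from spikeinterface import generate_sorting
from spikeinterface.widgets.utils import validate_segment_indices, get_segment_durations

# Three segments of 5, 10 and 15 seconds; a high firing rate keeps spikes close to each
# segment end, so the spike-based duration estimate stays close to the true duration.
sorting = generate_sorting(durations=[5.0, 10.0, 15.0], firing_rates=15.0)

segment_indices = validate_segment_indices([0, 1, 2], sorting)  # -> [0, 1, 2]
durations = get_segment_durations(sorting, segment_indices)

# Each estimate is derived from the last spike in the segment, so it is slightly
# below (but close to) the nominal segment duration.
print(segment_indices, np.round(durations, 2))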