diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index 13b6e886f0..b075704628 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -30,6 +30,7 @@ from spikeinterface.core.testing import check_sortings_equal from spikeinterface.sorters.external.kilosort4 import Kilosort4Sorter from probeinterface.io import write_prb +from spikeinterface.extractors import read_kilosort_as_analyzer import kilosort from kilosort.parameters import DEFAULT_SETTINGS @@ -405,12 +406,15 @@ def test_kilosort4_main(self, recording_and_paths, default_kilosort_sorting, tmp ops = ops.tolist() # strangely this makes a dict assert ops[param_key] == param_value - # Finally, check out test parameters actually change the output of - # KS4, ensuring our tests are actually doing something (exxcept for some params). + # Check our test parameters actually change the output of + # KS4, ensuring our tests are actually doing something (except for some params). if param_key not in PARAMETERS_NOT_AFFECTING_RESULTS: with pytest.raises(AssertionError): check_sortings_equal(default_kilosort_sorting, sorting_si) + # Check that the kilosort -> analyzer tool doesn't error + analyzer = read_kilosort_as_analyzer(kilosort_output_dir) + def test_clear_cache(self,recording_and_paths, tmp_path): """ Test clear_cache parameter in kilosort4.run_kilosort diff --git a/doc/how_to/import_kilosort_data.rst b/doc/how_to/import_kilosort_data.rst new file mode 100644 index 0000000000..dad522334a --- /dev/null +++ b/doc/how_to/import_kilosort_data.rst @@ -0,0 +1,81 @@ +Import Kilosort4 output +======================= + +If you have sorted your data with `Kilosort4 `__, your sorter output is saved in format which was +designed to be compatible with `phy `__. SpikeInterface provides a function which can be used to +transform this output into a ``SortingAnalyzer``. This is helpful if you'd like to compute some more properties of your sorting +(e.g. quality and template metrics), or if you'd like to visualize your output using `spikeinterface-gui `__. + +To create an analyzer from a Kilosort4 output folder, simply run + +.. code:: + + from spikeinterface.extractors import read_kilosort_as_analyzer + sorting_analyzer = read_kilosort_as_analyzer('path/to/output') + +The ``'path/to/output'`` should point to the Kilosort4 output folder. If you ran Kilosort4 natively, this is wherever you asked Kilosort4 to +save your output. If you ran Kilosort4 using SpikeInterface, this is in the ``sorter_output`` folder inside the ``output_folder`` created +when you ran ``run_sorter``. + +Note: the function ``read_kilosort_as_analyzer`` might work on older versions of Kilosort such as Kilosort2 and Kilosort3. +However, we do not guarantee that the results are correct. + +The ``analyzer`` object contains as much information as it can grab from the Kilosort4 output. If everything works, it should contain +information about the ``templates``, ``spike_locations`` and ``spike_amplitudes``. These are stored as ``extensions`` of the ``SortingAnalyzer``. +You can compute extra information about the sorting using the ``compute`` method. For example, + +.. code:: + + sorting_analyzer.compute({ + "unit_locations": {}, + "correlograms": {}, + "template_similarity": {}, + "isi_histograms": {}, + "template_metrics": {"include_multi_channel_metrics": True}, + "quality_metrics": {}, + }) + +widgets.html#available-plotting-functions + +Learn more about the ``SortingAnalyzer`` and its ``extensions`` `here `__. + +If you'd like to store the information you've computed, you can save the analyzer: + +.. code:: + + sorting_analyzer.save_as( + format="binary_folder", + folder="my_kilosort_analyzer" + ) + +You now have a fully functional ``SortingAnalyzer`` - congrats! You can now use `spikeinterface-gui `__. to view the results +interactively, or start manually labelling your units to `create an automated curation model `__. + +Note that if you have access to the raw recording, you can attach it to the analyzer, and re-compute extensions from the raw data. E.g. + +.. code:: + + from spikeinterface.extractors import read_kilosort_as_analyzer + import spikeinterface.extractors as se + import spikeinterface.extractors as spre + + recording = se.read_openephys('path/to/recording') + + preprocessed_recording = spre.bandpass_filter(spre.common_reference(recording)) + + sorting_analyzer = read_kilosort_as_analyzer('path/to/output') + sorting_analyzer.set_temporary_recording(preprocessed_recording) + + sorting_analyzer.compute({ + "spike_locations": {}, + "spike_amplitudes": {}, + "unit_locations": {}, + "correlograms": {}, + "template_similarity": {}, + "isi_histograms": {}, + "template_metrics": {"include_multi_channel_metrics": True}, + "quality_metrics": {}, + }) + + +This will take longer since you are dealing with the raw recording, but you do have a lot of control over how to compute the extensions. diff --git a/doc/how_to/index.rst b/doc/how_to/index.rst index 2d490c207d..db02da5cef 100644 --- a/doc/how_to/index.rst +++ b/doc/how_to/index.rst @@ -22,3 +22,4 @@ Guides on how to solve specific, short problems in SpikeInterface. Learn how to. benchmark_with_hybrid_recordings auto_curation_training auto_curation_prediction + import_kilosort_data diff --git a/src/spikeinterface/extractors/__init__.py b/src/spikeinterface/extractors/__init__.py index 604dee68f9..216c668e0e 100644 --- a/src/spikeinterface/extractors/__init__.py +++ b/src/spikeinterface/extractors/__init__.py @@ -3,10 +3,9 @@ from .toy_example import toy_example as toy_example from .bids import read_bids as read_bids - from .neuropixels_utils import get_neuropixels_channel_groups, get_neuropixels_sample_shifts - from .neoextractors import get_neo_num_blocks, get_neo_streams +from .phykilosortextractors import read_kilosort_as_analyzer from warnings import warn diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 46a8e4cecb..68b16074fb 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -2,12 +2,25 @@ from typing import Optional from pathlib import Path +import warnings import numpy as np -from spikeinterface.core import BaseSorting, BaseSortingSegment, read_python +from spikeinterface.core import ( + BaseSorting, + BaseSortingSegment, + read_python, + generate_ground_truth_recording, + ChannelSparsity, + ComputeTemplates, + create_sorting_analyzer, + SortingAnalyzer, +) from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations +from probeinterface import read_prb, Probe + class BasePhyKilosortSortingExtractor(BaseSorting): """Base SortingExtractor for Phy and Kilosort output folder. @@ -302,3 +315,177 @@ def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove read_phy = define_function_from_class(source_class=PhySortingExtractor, name="read_phy") read_kilosort = define_function_from_class(source_class=KiloSortSortingExtractor, name="read_kilosort") + + +def read_kilosort_as_analyzer(folder_path, unwhiten=True) -> SortingAnalyzer: + """ + Load Kilosort output into a SortingAnalyzer. Output from Kilosort version 4.1 and + above are supported. The function may work on older versions of Kilosort output, + but these are not carefully tested. Please check your output carefully. + + Parameters + ---------- + folder_path : str or Path + Path to the output Phy folder (containing the params.py). + unwhiten : bool, default: True + Unwhiten the templates computed by kilosort. + + Returns + ------- + sorting_analyzer : SortingAnalyzer + A SortingAnalyzer object. + """ + + phy_path = Path(folder_path) + + sorting = read_phy(phy_path) + sampling_frequency = sorting.sampling_frequency + + # kilosort occasionally contains a few spikes just beyond the recording end point, which can lead + # to errors later. To avoid this, we pad the recording with an extra second of blank time. + duration = sorting._sorting_segments[0]._all_spikes[-1] / sampling_frequency + 1 + + if (phy_path / "probe.prb").is_file(): + probegroup = read_prb(phy_path / "probe.prb") + if len(probegroup.probes) > 0: + warnings.warn("Found more than one probe. Selecting the first probe in ProbeGroup.") + probe = probegroup.probes[0] + elif (phy_path / "channel_positions.npy").is_file(): + probe = Probe(si_units="um") + channel_positions = np.load(phy_path / "channel_positions.npy") + probe.set_contacts(channel_positions) + probe.set_device_channel_indices(range(probe.get_contact_count())) + else: + AssertionError(f"Cannot read probe layout from folder {phy_path}.") + + # to make the initial analyzer, we'll use a fake recording and set it to None later + recording, _ = generate_ground_truth_recording( + probe=probe, + sampling_frequency=sampling_frequency, + durations=[duration], + num_units=1, + seed=1205, + ) + + sparsity = _make_sparsity_from_templates(sorting, recording, phy_path) + + sorting_analyzer = create_sorting_analyzer(sorting, recording, sparse=True, sparsity=sparsity) + + # first compute random spikes. These do nothing, but are needed for si-gui to run + sorting_analyzer.compute("random_spikes") + + _make_templates(sorting_analyzer, phy_path, sparsity.mask, sampling_frequency, unwhiten=unwhiten) + _make_locations(sorting_analyzer, phy_path) + + sorting_analyzer._recording = None + return sorting_analyzer + + +def _make_locations(sorting_analyzer, kilosort_output_path): + """Constructs a `spike_locations` extension from the amplitudes numpy array + in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`.""" + + locations_extension = ComputeSpikeLocations(sorting_analyzer) + + spike_locations_path = kilosort_output_path / "spike_positions.npy" + if spike_locations_path.is_file(): + locs_np = np.load(spike_locations_path) + else: + return + + # Check that the spike locations vector is the same size as the spike vector + num_spikes = len(sorting_analyzer.sorting.to_spike_vector()) + num_spike_locs = len(locs_np) + if num_spikes != num_spike_locs: + warnings.warn( + "The number of spikes does not match the number of spike locations in `spike_positions.npy`. Skipping spike locations." + ) + return + + num_dims = len(locs_np[0]) + column_names = ["x", "y", "z"][:num_dims] + dtype = [(name, locs_np.dtype) for name in column_names] + + structured_array = np.zeros(len(locs_np), dtype=dtype) + for coordinate_index, column_name in enumerate(column_names): + structured_array[column_name] = locs_np[:, coordinate_index] + + locations_extension.data = {"spike_locations": structured_array} + locations_extension.params = {} + locations_extension.run_info = {"run_completed": True} + + sorting_analyzer.extensions["spike_locations"] = locations_extension + + +def _make_sparsity_from_templates(sorting, recording, kilosort_output_path): + """Constructs the `ChannelSparsity` of from kilosort output, by seeing if the + templates output is zero or not on all channels.""" + + templates = np.load(kilosort_output_path / "templates.npy") + + unit_ids = sorting.unit_ids + channel_ids = recording.channel_ids + + # The raw templates have dense dimensions (num chan)x(num samples)x(num units) + # but are zero on many channels, which implicitly defines the sparsity + mask = np.sum(np.abs(templates), axis=1) != 0 + return ChannelSparsity(mask, unit_ids=unit_ids, channel_ids=channel_ids) + + +def _make_templates(sorting_analyzer, kilosort_output_path, mask, sampling_frequency, unwhiten=True): + """Constructs a `templates` extension from the amplitudes numpy array + in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`.""" + + template_extension = ComputeTemplates(sorting_analyzer) + + whitened_templates = np.load(kilosort_output_path / "templates.npy") + wh_inv = np.load(kilosort_output_path / "whitening_mat_inv.npy") + new_templates = _compute_unwhitened_templates(whitened_templates, wh_inv) if unwhiten else whitened_templates + + template_extension.data = {"average": new_templates} + + ops_path = kilosort_output_path / "ops.npy" + if ops_path.is_file(): + ops = np.load(ops_path, allow_pickle=True) + + number_samples_before_template_peak = ops.item(0)["nt0min"] + total_template_samples = ops.item(0)["nt"] + + number_samples_after_template_peak = total_template_samples - number_samples_before_template_peak + + ms_before = number_samples_before_template_peak / (sampling_frequency // 1000) + ms_after = number_samples_after_template_peak / (sampling_frequency // 1000) + + # Used for kilosort 2, 2.5 and 3 + else: + + warnings.warn("Can't extract `ms_before` and `ms_after` from Kilosort output. Guessing a sensible value.") + + samples_in_templates = np.shape(new_templates)[1] + template_extent_ms = (samples_in_templates + 1) / (sampling_frequency // 1000) + ms_before = template_extent_ms / 3 + ms_after = 2 * template_extent_ms / 3 + + params = { + "operators": ["average"], + "ms_before": ms_before, + "ms_after": ms_after, + "peak_sign": "both", + } + + template_extension.params = params + template_extension.run_info = {"run_completed": True} + + sorting_analyzer.extensions["templates"] = template_extension + + +def _compute_unwhitened_templates(whitened_templates, wh_inv): + """Constructs unwhitened templates from whitened_templates, by + applying an inverse whitening matrix.""" + + # templates have dimension (num units) x (num samples) x (num channels) + # whitening inverse has dimension (num units) x (num channels) + # to undo whitening, we need do matrix multiplication on the channel index + unwhitened_templates = np.einsum("ij,klj->kli", wh_inv, whitened_templates) + + return unwhitened_templates