-
Notifications
You must be signed in to change notification settings - Fork 237
Add Kilosort output to SortingAnalyzer helper function #4202
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
1f1c282
1162428
65d6b21
8105912
39aa298
10ed2c4
8fe91b7
0659891
0f1a2c5
369f3ee
14a432c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| Import Kilosort4 output | ||
| ======================= | ||
|
|
||
| If you have sorted your data with `Kilosort4 <https://github.com/MouseLand/Kilosort>`__, your sorter output is saved in format which was | ||
| designed to be compatible with `phy <https://github.com/cortex-lab/phy>`__. SpikeInterface provides a function which can be used to | ||
| transform this output into a ``SortingAnalyzer``. This is helpful if you'd like to compute some more properties of your sorting | ||
| (e.g. quality and template metrics), or if you'd like to visualize your output using `spikeinterface-gui <https://github.com/SpikeInterface/spikeinterface-gui/>`__. | ||
|
|
||
| To create an analyzer from a Kilosort4 output folder, simply run | ||
|
|
||
| .. code:: | ||
|
|
||
| from spikeinterface.extractors import kilosort_output_to_analyzer | ||
| sorting_analyzer = kilosort_output_to_analyzer('path/to/output') | ||
|
|
||
| The ``'path/to/output'`` should point to the Kilosort4 output folder. If you ran Kilosort4 natively, this is wherever you asked Kilosort4 to | ||
| save your output. If you ran Kilosort4 using SpikeInterface, this is in the ``sorter_output`` folder inside the ``output_folder`` created | ||
| when you ran ``run_sorter``. | ||
|
|
||
| The ``analyzer`` object contains as much information as it can grab from the Kilosort4 output. If everything works, it should contain | ||
| information about the ``templates``, ``spike_locations`` and ``spike_amplitudes``. These are stored as ``extensions`` of the ``SortingAnalyzer``. | ||
| You can compute extra information about the sorting using the ``compute`` method. For example, | ||
|
|
||
| .. code:: | ||
|
|
||
| sorting_analyzer.compute({ | ||
| "unit_locations": {}, | ||
| "correlograms": {}, | ||
| "template_similarity": {}, | ||
| "isi_histograms": {}, | ||
| "template_metrics": {include_multi_channel_metrics: True}, | ||
| "quality_metrics": {}, | ||
| }) | ||
|
|
||
| widgets.html#available-plotting-functions | ||
|
|
||
| Learn more about the ``SortingAnalyzer`` and its ``extensions`` `here <https://spikeinterface.readthedocs.io/en/stable/modules/postprocessing.html>`__. | ||
|
|
||
| If you'd like to store the information you've computed, you can save the analyzer: | ||
|
|
||
| .. code:: | ||
|
|
||
| sorting_analyzer.save_as( | ||
| format="binary_folder", | ||
| folder="my_kilosort_analyzer" | ||
| ) | ||
|
|
||
| You now have a fully functional ``SortingAnalyzer`` - congrats! You can now use `spikeinterface-gui <https://github.com/SpikeInterface/spikeinterface-gui/>`__. to view the results | ||
| interactively, or start manually labelling your units to `create an automated curation model <https://spikeinterface.readthedocs.io/en/stable/tutorials_custom_index.html#automated-curation-tutorials>`__. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,12 +2,25 @@ | |
|
|
||
| from typing import Optional | ||
| from pathlib import Path | ||
| import warnings | ||
|
|
||
| import numpy as np | ||
|
|
||
| from spikeinterface.core import BaseSorting, BaseSortingSegment, read_python | ||
| from spikeinterface.core import ( | ||
| BaseSorting, | ||
| BaseSortingSegment, | ||
| read_python, | ||
| generate_ground_truth_recording, | ||
| ChannelSparsity, | ||
| ComputeTemplates, | ||
| create_sorting_analyzer, | ||
| SortingAnalyzer, | ||
| ) | ||
| from spikeinterface.core.core_tools import define_function_from_class | ||
|
|
||
| from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations | ||
| from probeinterface import read_prb, Probe | ||
|
|
||
|
|
||
| class BasePhyKilosortSortingExtractor(BaseSorting): | ||
| """Base SortingExtractor for Phy and Kilosort output folder. | ||
|
|
@@ -302,3 +315,195 @@ def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove | |
|
|
||
| read_phy = define_function_from_class(source_class=PhySortingExtractor, name="read_phy") | ||
| read_kilosort = define_function_from_class(source_class=KiloSortSortingExtractor, name="read_kilosort") | ||
|
|
||
|
|
||
| def kilosort_output_to_analyzer(folder_path, compute_extras=False, unwhiten=True) -> SortingAnalyzer: | ||
| """ | ||
| Load kilosort output into a SortingAnalyzer. | ||
| Output from kilosort version 4.1 and above are supported. | ||
|
||
|
|
||
| Parameters | ||
| ---------- | ||
| folder_path : str or Path | ||
| Path to the output Phy folder (containing the params.py). | ||
| unwhiten : bool, default: True | ||
| Unwhiten the templates computed by kilosort. | ||
|
|
||
| Returns | ||
| ------- | ||
| sorting_analyzer : SortingAnalyzer | ||
| A SortingAnalyzer object. | ||
| """ | ||
|
|
||
| phy_path = Path(folder_path) | ||
|
|
||
| guessed_kilosort_version = _guess_kilosort_version(phy_path) | ||
|
|
||
| sorting = read_phy(phy_path) | ||
| sampling_frequency = sorting.sampling_frequency | ||
|
|
||
| # kilosort occasionally contains a few spikes beyond the recording end point, which can lead | ||
| # to errors later. To avoid this, we pad the recording with an extra second of blank time. | ||
| duration = sorting._sorting_segments[0]._all_spikes[-1] / sampling_frequency + 1 | ||
chrishalcrow marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| if (phy_path / "probe.prb").is_file(): | ||
| probegroup = read_prb(phy_path / "probe.prb") | ||
| if len(probegroup.probes) > 0: | ||
| warnings.warn("Found more than one probe. Selecting the first probe in ProbeGroup.") | ||
| probe = probegroup.probes[0] | ||
chrishalcrow marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| elif (phy_path / "channel_positions.npy").is_file(): | ||
| probe = Probe(si_units="um") | ||
| channel_positions = np.load(phy_path / "channel_positions.npy") | ||
| probe.set_contacts(channel_positions) | ||
| probe.set_device_channel_indices(range(probe.get_contact_count())) | ||
| else: | ||
| AssertionError(f"Cannot read probe layout from folder {phy_path}.") | ||
|
|
||
| # to make the initial analyzer, we'll use a fake recording and set it to None later | ||
| recording, _ = generate_ground_truth_recording( | ||
| probe=probe, sampling_frequency=sampling_frequency, durations=[duration] | ||
| ) | ||
chrishalcrow marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| sparsity = _make_sparsity_from_templates(sorting, recording, phy_path) | ||
|
|
||
| sorting_analyzer = create_sorting_analyzer(sorting, recording, sparse=True, sparsity=sparsity) | ||
|
|
||
| # first compute random spikes. These do nothing, but are needed for si-gui to run | ||
| sorting_analyzer.compute("random_spikes") | ||
chrishalcrow marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| _make_templates(sorting_analyzer, phy_path, sparsity.mask, sampling_frequency, unwhiten=unwhiten) | ||
| _make_locations(sorting_analyzer, phy_path) | ||
| _make_amplitudes(sorting_analyzer, phy_path) | ||
|
|
||
| sorting_analyzer._recording = None | ||
| return sorting_analyzer | ||
|
|
||
|
|
||
| def _guess_kilosort_version(kilosort_path) -> tuple: | ||
| """ | ||
| Guesses the kilosort version based on the files which exist in folder `kilosort_path`. | ||
| If unknown, returns minimum guessed version. | ||
|
|
||
| Returns | ||
| ------- | ||
| version_number : tuple | ||
| Version number in the form (major, minor, patch) | ||
| """ | ||
|
|
||
| kilosort_log_file = Path(kilosort_path / "kilosort4.log") | ||
|
|
||
| if kilosort_log_file.is_file(): | ||
| return (4, 0, 33) | ||
|
||
| else: | ||
| return (2, 0, 0) | ||
|
|
||
|
|
||
| def _make_amplitudes(sorting_analyzer, kilosort_output_path): | ||
| """Constructs a `spike_amplitudes` extension from the amplitudes numpy array | ||
| in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`.""" | ||
|
|
||
| amplitudes_extension = ComputeSpikeAmplitudes(sorting_analyzer) | ||
|
|
||
| amps_np = np.load(kilosort_output_path / "amplitudes.npy") | ||
|
|
||
| amplitudes_extension.data = {"amplitudes": amps_np} | ||
| amplitudes_extension.params = {"peak_sign": "neg"} | ||
chrishalcrow marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| amplitudes_extension.run_info = {"run_completed": True} | ||
|
|
||
| sorting_analyzer.extensions["spike_amplitudes"] = amplitudes_extension | ||
|
|
||
|
|
||
| def _make_locations(sorting_analyzer, kilosort_output_path): | ||
| """Constructs a `spike_locations` extension from the amplitudes numpy array | ||
| in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`.""" | ||
|
|
||
| locations_extension = ComputeSpikeLocations(sorting_analyzer) | ||
|
|
||
| locs_np = np.load(kilosort_output_path / "spike_positions.npy") | ||
|
|
||
| num_dims = len(locs_np[0]) | ||
| column_names = ["x", "y", "z"][:num_dims] | ||
| dtype = [(name, locs_np.dtype) for name in column_names] | ||
|
|
||
| structured_array = np.zeros(len(locs_np), dtype=dtype) | ||
| for coordinate_index, column_name in enumerate(column_names): | ||
| structured_array[column_name] = locs_np[:, coordinate_index] | ||
|
|
||
| locations_extension.data = {"spike_locations": structured_array} | ||
| locations_extension.params = {} | ||
| locations_extension.run_info = {"run_completed": True} | ||
chrishalcrow marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| sorting_analyzer.extensions["spike_locations"] = locations_extension | ||
|
|
||
|
|
||
| def _make_sparsity_from_templates(sorting, recording, kilosort_output_path): | ||
| """Constructs the `ChannelSparsity` of from kilosort output, by seeing if the | ||
| templates output is zero or not on all channels.""" | ||
|
|
||
| templates = np.load(kilosort_output_path / "templates.npy") | ||
|
|
||
| unit_ids = sorting.unit_ids | ||
| channel_ids = recording.channel_ids | ||
|
|
||
| # The raw templates have dense dimensions (num chan)x(num units) | ||
chrishalcrow marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| # but are zero on many channels, which implicitly defines the sparsity | ||
| mask = np.sum(np.abs(templates), axis=1) != 0 | ||
| return ChannelSparsity(mask, unit_ids=unit_ids, channel_ids=channel_ids) | ||
|
|
||
|
|
||
| def _make_templates(sorting_analyzer, kilosort_output_path, mask, sampling_frequency, unwhiten=True): | ||
| """Constructs a `templates` extension from the amplitudes numpy array | ||
| in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`.""" | ||
|
|
||
| template_extension = ComputeTemplates(sorting_analyzer) | ||
|
|
||
| whitened_templates = np.load(kilosort_output_path / "templates.npy") | ||
| wh_inv = np.load(kilosort_output_path / "whitening_mat_inv.npy") | ||
| new_templates = _compute_unwhitened_templates(whitened_templates, wh_inv, mask) if unwhiten else whitened_templates | ||
|
|
||
| template_extension.data = {"average": new_templates} | ||
|
|
||
| ops_path = kilosort_output_path / "ops.npy" | ||
| if ops_path.is_file(): | ||
chrishalcrow marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ops = np.load(ops_path, allow_pickle=True) | ||
|
|
||
| number_samples_before_template_peak = ops.item(0)["nt0min"] | ||
| total_template_samples = ops.item(0)["nt"] | ||
|
|
||
| number_samples_after_template_peak = total_template_samples - number_samples_before_template_peak | ||
|
|
||
| ms_before = number_samples_before_template_peak / (sampling_frequency // 1000) | ||
| ms_after = number_samples_after_template_peak / (sampling_frequency // 1000) | ||
|
|
||
| params = { | ||
| "operators": ["average"], | ||
| "ms_before": ms_before, | ||
| "ms_after": ms_after, | ||
| "peak_sign": "neg", | ||
| } | ||
|
|
||
| template_extension.params = params | ||
| template_extension.run_info = {"run_completed": True} | ||
|
|
||
| sorting_analyzer.extensions["templates"] = template_extension | ||
|
|
||
|
|
||
| def _compute_unwhitened_templates(whitened_templates, wh_inv, mask): | ||
| """Constructs unwhitened templates from whitened_templates, by | ||
| applying an inverse whitening matrix.""" | ||
|
|
||
| template_shape = np.shape(whitened_templates) | ||
| new_templates = np.zeros(template_shape) | ||
|
|
||
| sparsity_channel_ids = [np.arange(template_shape[-1])[unit_sparsity] for unit_sparsity in mask] | ||
|
|
||
| for unit_index, channel_indices in enumerate(sparsity_channel_ids): | ||
chrishalcrow marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| for channel_index_1 in channel_indices: | ||
| for channel_index_2 in channel_indices: | ||
| # templates have dimension unit_index x sample_index x channel_index | ||
| # to undo whitening, we need do matrix multiplication on the channel_index | ||
| new_templates[unit_index, :, channel_index_1] += ( | ||
| wh_inv[channel_index_1, channel_index_2] * whitened_templates[unit_index, :, channel_index_2] | ||
| ) | ||
chrishalcrow marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| return new_templates | ||
Uh oh!
There was an error while loading. Please reload this page.