Skip to content

Commit 8cb2daf

Browse files
Add Kilosort output to SortingAnalyzer helper function (#4202)
Co-authored-by: Joe Ziminski <[email protected]>
1 parent 3db639c commit 8cb2daf

File tree

5 files changed

+277
-5
lines changed

5 files changed

+277
-5
lines changed

.github/scripts/test_kilosort4_ci.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from spikeinterface.core.testing import check_sortings_equal
3131
from spikeinterface.sorters.external.kilosort4 import Kilosort4Sorter
3232
from probeinterface.io import write_prb
33+
from spikeinterface.extractors import read_kilosort_as_analyzer
3334

3435
import kilosort
3536
from kilosort.parameters import DEFAULT_SETTINGS
@@ -405,12 +406,15 @@ def test_kilosort4_main(self, recording_and_paths, default_kilosort_sorting, tmp
405406
ops = ops.tolist() # strangely this makes a dict
406407
assert ops[param_key] == param_value
407408

408-
# Finally, check out test parameters actually change the output of
409-
# KS4, ensuring our tests are actually doing something (exxcept for some params).
409+
# Check our test parameters actually change the output of
410+
# KS4, ensuring our tests are actually doing something (except for some params).
410411
if param_key not in PARAMETERS_NOT_AFFECTING_RESULTS:
411412
with pytest.raises(AssertionError):
412413
check_sortings_equal(default_kilosort_sorting, sorting_si)
413414

415+
# Check that the kilosort -> analyzer tool doesn't error
416+
analyzer = read_kilosort_as_analyzer(kilosort_output_dir)
417+
414418
def test_clear_cache(self,recording_and_paths, tmp_path):
415419
"""
416420
Test clear_cache parameter in kilosort4.run_kilosort
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
Import Kilosort4 output
2+
=======================
3+
4+
If you have sorted your data with `Kilosort4 <https://github.com/MouseLand/Kilosort>`__, your sorter output is saved in a format which was
5+
designed to be compatible with `phy <https://github.com/cortex-lab/phy>`__. SpikeInterface provides a function which can be used to
6+
transform this output into a ``SortingAnalyzer``. This is helpful if you'd like to compute some more properties of your sorting
7+
(e.g. quality and template metrics), or if you'd like to visualize your output using `spikeinterface-gui <https://github.com/SpikeInterface/spikeinterface-gui/>`__.
8+
9+
To create an analyzer from a Kilosort4 output folder, simply run
10+
11+
.. code::
12+
13+
from spikeinterface.extractors import read_kilosort_as_analyzer
14+
sorting_analyzer = read_kilosort_as_analyzer('path/to/output')
15+
16+
The ``'path/to/output'`` should point to the Kilosort4 output folder. If you ran Kilosort4 natively, this is wherever you asked Kilosort4 to
17+
save your output. If you ran Kilosort4 using SpikeInterface, this is in the ``sorter_output`` folder inside the ``output_folder`` created
18+
when you ran ``run_sorter``.
19+
20+
Note: the function ``read_kilosort_as_analyzer`` might work on older versions of Kilosort such as Kilosort2 and Kilosort3.
21+
However, we do not guarantee that the results are correct.
22+
23+
The ``analyzer`` object contains as much information as it can grab from the Kilosort4 output. If everything works, it should contain
24+
information about the ``templates``, ``spike_locations`` and ``spike_amplitudes``. These are stored as ``extensions`` of the ``SortingAnalyzer``.
25+
You can compute extra information about the sorting using the ``compute`` method. For example,
26+
27+
.. code::
28+
29+
sorting_analyzer.compute({
30+
"unit_locations": {},
31+
"correlograms": {},
32+
"template_similarity": {},
33+
"isi_histograms": {},
34+
"template_metrics": {"include_multi_channel_metrics": True},
35+
"quality_metrics": {},
36+
})
37+
38+
You can plot these extensions using the functions listed in the `widgets module documentation <https://spikeinterface.readthedocs.io/en/stable/modules/widgets.html#available-plotting-functions>`__.
39+
40+
Learn more about the ``SortingAnalyzer`` and its ``extensions`` `here <https://spikeinterface.readthedocs.io/en/stable/modules/postprocessing.html>`__.
41+
42+
If you'd like to store the information you've computed, you can save the analyzer:
43+
44+
.. code::
45+
46+
sorting_analyzer.save_as(
47+
format="binary_folder",
48+
folder="my_kilosort_analyzer"
49+
)
50+
51+
You now have a fully functional ``SortingAnalyzer`` - congrats! You can now use `spikeinterface-gui <https://github.com/SpikeInterface/spikeinterface-gui/>`__ to view the results
52+
interactively, or start manually labelling your units to `create an automated curation model <https://spikeinterface.readthedocs.io/en/stable/tutorials_custom_index.html#automated-curation-tutorials>`__.
53+
54+
Note that if you have access to the raw recording, you can attach it to the analyzer, and re-compute extensions from the raw data. E.g.
55+
56+
.. code::
57+
58+
from spikeinterface.extractors import read_kilosort_as_analyzer
59+
import spikeinterface.extractors as se
60+
import spikeinterface.preprocessing as spre
61+
62+
recording = se.read_openephys('path/to/recording')
63+
64+
preprocessed_recording = spre.bandpass_filter(spre.common_reference(recording))
65+
66+
sorting_analyzer = read_kilosort_as_analyzer('path/to/output')
67+
sorting_analyzer.set_temporary_recording(preprocessed_recording)
68+
69+
sorting_analyzer.compute({
70+
"spike_locations": {},
71+
"spike_amplitudes": {},
72+
"unit_locations": {},
73+
"correlograms": {},
74+
"template_similarity": {},
75+
"isi_histograms": {},
76+
"template_metrics": {"include_multi_channel_metrics": True},
77+
"quality_metrics": {},
78+
})
79+
80+
81+
This will take longer since you are dealing with the raw recording, but you do have a lot of control over how to compute the extensions.

doc/how_to/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@ Guides on how to solve specific, short problems in SpikeInterface. Learn how to.
2222
benchmark_with_hybrid_recordings
2323
auto_curation_training
2424
auto_curation_prediction
25+
import_kilosort_data

src/spikeinterface/extractors/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
from .toy_example import toy_example as toy_example
44
from .bids import read_bids as read_bids
55

6-
76
from .neuropixels_utils import get_neuropixels_channel_groups, get_neuropixels_sample_shifts
8-
97
from .neoextractors import get_neo_num_blocks, get_neo_streams
8+
from .phykilosortextractors import read_kilosort_as_analyzer
109

1110
from warnings import warn
1211

src/spikeinterface/extractors/phykilosortextractors.py

Lines changed: 188 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,25 @@
22

33
from typing import Optional
44
from pathlib import Path
5+
import warnings
56

67
import numpy as np
78

8-
from spikeinterface.core import BaseSorting, BaseSortingSegment, read_python
9+
from spikeinterface.core import (
10+
BaseSorting,
11+
BaseSortingSegment,
12+
read_python,
13+
generate_ground_truth_recording,
14+
ChannelSparsity,
15+
ComputeTemplates,
16+
create_sorting_analyzer,
17+
SortingAnalyzer,
18+
)
919
from spikeinterface.core.core_tools import define_function_from_class
1020

21+
from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations
22+
from probeinterface import read_prb, Probe
23+
1124

1225
class BasePhyKilosortSortingExtractor(BaseSorting):
1326
"""Base SortingExtractor for Phy and Kilosort output folder.
@@ -302,3 +315,177 @@ def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove
302315

303316
read_phy = define_function_from_class(source_class=PhySortingExtractor, name="read_phy")
304317
read_kilosort = define_function_from_class(source_class=KiloSortSortingExtractor, name="read_kilosort")
318+
319+
320+
def read_kilosort_as_analyzer(folder_path, unwhiten=True) -> SortingAnalyzer:
    """
    Load Kilosort output into a SortingAnalyzer. Output from Kilosort version 4.1 and
    above are supported. The function may work on older versions of Kilosort output,
    but these are not carefully tested. Please check your output carefully.

    Parameters
    ----------
    folder_path : str or Path
        Path to the output Phy folder (containing the params.py).
    unwhiten : bool, default: True
        Unwhiten the templates computed by kilosort.

    Returns
    -------
    sorting_analyzer : SortingAnalyzer
        A SortingAnalyzer object.

    Raises
    ------
    AssertionError
        If neither a ``probe.prb`` nor a ``channel_positions.npy`` file is found
        in ``folder_path``, so no probe layout can be constructed.
    """

    phy_path = Path(folder_path)

    sorting = read_phy(phy_path)
    sampling_frequency = sorting.sampling_frequency

    # kilosort occasionally contains a few spikes just beyond the recording end point, which can lead
    # to errors later. To avoid this, we pad the recording with an extra second of blank time.
    duration = sorting._sorting_segments[0]._all_spikes[-1] / sampling_frequency + 1

    if (phy_path / "probe.prb").is_file():
        probegroup = read_prb(phy_path / "probe.prb")
        # BUGFIX: warn only when there really is more than one probe
        # (was `> 0`, which warned even for a single-probe file).
        if len(probegroup.probes) > 1:
            warnings.warn("Found more than one probe. Selecting the first probe in ProbeGroup.")
        probe = probegroup.probes[0]
    elif (phy_path / "channel_positions.npy").is_file():
        probe = Probe(si_units="um")
        channel_positions = np.load(phy_path / "channel_positions.npy")
        probe.set_contacts(channel_positions)
        probe.set_device_channel_indices(range(probe.get_contact_count()))
    else:
        # BUGFIX: the error was constructed but never raised, so execution fell
        # through and crashed later with an undefined `probe`. Raise explicitly.
        raise AssertionError(f"Cannot read probe layout from folder {phy_path}.")

    # to make the initial analyzer, we'll use a fake recording and set it to None later
    recording, _ = generate_ground_truth_recording(
        probe=probe,
        sampling_frequency=sampling_frequency,
        durations=[duration],
        num_units=1,
        seed=1205,
    )

    sparsity = _make_sparsity_from_templates(sorting, recording, phy_path)

    sorting_analyzer = create_sorting_analyzer(sorting, recording, sparse=True, sparsity=sparsity)

    # first compute random spikes. These do nothing, but are needed for si-gui to run
    sorting_analyzer.compute("random_spikes")

    _make_templates(sorting_analyzer, phy_path, sparsity.mask, sampling_frequency, unwhiten=unwhiten)
    _make_locations(sorting_analyzer, phy_path)

    sorting_analyzer._recording = None
    return sorting_analyzer
382+
383+
384+
def _make_locations(sorting_analyzer, kilosort_output_path):
    """Attach a `spike_locations` extension to `sorting_analyzer`, built from
    Kilosort's `spike_positions.npy` file if it exists and is consistent."""

    extension = ComputeSpikeLocations(sorting_analyzer)

    positions_file = kilosort_output_path / "spike_positions.npy"
    if not positions_file.is_file():
        # Nothing to attach: not every Kilosort version writes this file.
        return
    raw_positions = np.load(positions_file)

    # The file must contain exactly one location per spike in the sorting.
    if len(sorting_analyzer.sorting.to_spike_vector()) != len(raw_positions):
        warnings.warn(
            "The number of spikes does not match the number of spike locations in `spike_positions.npy`. Skipping spike locations."
        )
        return

    # Pack the (num_spikes, num_dims) array into a structured array with one
    # named field per spatial coordinate, as the extension expects.
    field_names = ["x", "y", "z"][: len(raw_positions[0])]
    structured = np.zeros(len(raw_positions), dtype=[(name, raw_positions.dtype) for name in field_names])
    for axis, name in enumerate(field_names):
        structured[name] = raw_positions[:, axis]

    extension.data = {"spike_locations": structured}
    extension.params = {}
    extension.run_info = {"run_completed": True}

    sorting_analyzer.extensions["spike_locations"] = extension
418+
419+
420+
def _make_sparsity_from_templates(sorting, recording, kilosort_output_path):
    """Build the `ChannelSparsity` implied by Kilosort's dense templates: a unit
    is considered present on a channel iff its template is non-zero there."""

    dense_templates = np.load(kilosort_output_path / "templates.npy")

    # Templates are dense (num units) x (num samples) x (num channels), but are
    # exactly zero on channels a unit does not reach; summing absolute values
    # over the sample axis (axis=1) exposes that implicit sparsity.
    active_mask = np.abs(dense_templates).sum(axis=1) != 0

    return ChannelSparsity(active_mask, unit_ids=sorting.unit_ids, channel_ids=recording.channel_ids)
433+
434+
435+
def _make_templates(sorting_analyzer, kilosort_output_path, mask, sampling_frequency, unwhiten=True):
    """Attach a `templates` extension to `sorting_analyzer`, built from
    Kilosort's `templates.npy` file.

    NOTE(review): `mask` is accepted for interface compatibility but is not
    read anywhere in this function.
    """

    extension = ComputeTemplates(sorting_analyzer)

    whitened = np.load(kilosort_output_path / "templates.npy")
    whitening_inverse = np.load(kilosort_output_path / "whitening_mat_inv.npy")
    if unwhiten:
        templates = _compute_unwhitened_templates(whitened, whitening_inverse)
    else:
        templates = whitened

    extension.data = {"average": templates}

    ops_path = kilosort_output_path / "ops.npy"
    if ops_path.is_file():
        # ops.npy stores the run options; recover the exact sample window
        # around the template peak from them.
        ops = np.load(ops_path, allow_pickle=True)

        samples_before_peak = ops.item(0)["nt0min"]
        total_samples = ops.item(0)["nt"]
        samples_after_peak = total_samples - samples_before_peak

        ms_before = samples_before_peak / (sampling_frequency // 1000)
        ms_after = samples_after_peak / (sampling_frequency // 1000)
    else:
        # Kilosort 2, 2.5 and 3 outputs lack ops.npy, so the window is unknown;
        # assume the peak sits one third of the way into the template.
        warnings.warn("Can't extract `ms_before` and `ms_after` from Kilosort output. Guessing a sensible value.")

        samples_in_templates = np.shape(templates)[1]
        template_extent_ms = (samples_in_templates + 1) / (sampling_frequency // 1000)
        ms_before = template_extent_ms / 3
        ms_after = 2 * template_extent_ms / 3

    extension.params = {
        "operators": ["average"],
        "ms_before": ms_before,
        "ms_after": ms_after,
        "peak_sign": "both",
    }
    extension.run_info = {"run_completed": True}

    sorting_analyzer.extensions["templates"] = extension
480+
481+
482+
def _compute_unwhitened_templates(whitened_templates, wh_inv):
483+
"""Constructs unwhitened templates from whitened_templates, by
484+
applying an inverse whitening matrix."""
485+
486+
# templates have dimension (num units) x (num samples) x (num channels)
487+
# whitening inverse has dimension (num units) x (num channels)
488+
# to undo whitening, we need do matrix multiplication on the channel index
489+
unwhitened_templates = np.einsum("ij,klj->kli", wh_inv, whitened_templates)
490+
491+
return unwhitened_templates

0 commit comments

Comments
 (0)