Skip to content
49 changes: 49 additions & 0 deletions doc/how_to/import_kilosort_data.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
Import Kilosort4 output
=======================

If you have sorted your data with `Kilosort4 <https://github.com/MouseLand/Kilosort>`__, your sorter output is saved in format which was
designed to be compatible with `phy <https://github.com/cortex-lab/phy>`__. SpikeInterface provides a function which can be used to
transform this output into a ``SortingAnalyzer``. This is helpful if you'd like to compute some more properties of your sorting
(e.g. quality and template metrics), or if you'd like to visualize your output using `spikeinterface-gui <https://github.com/SpikeInterface/spikeinterface-gui/>`__.

To create an analyzer from a Kilosort4 output folder, simply run

.. code::

from spikeinterface.extractors import kilosort_output_to_analyzer
sorting_analyzer = kilosort_output_to_analyzer('path/to/output')

The ``'path/to/output'`` should point to the Kilosort4 output folder. If you ran Kilosort4 natively, this is wherever you asked Kilosort4 to
save your output. If you ran Kilosort4 using SpikeInterface, this is in the ``sorter_output`` folder inside the ``output_folder`` created
when you ran ``run_sorter``.

The ``analyzer`` object contains as much information as it can grab from the Kilosort4 output. If everything works, it should contain
information about the ``templates``, ``spike_locations`` and ``spike_amplitudes``. These are stored as ``extensions`` of the ``SortingAnalyzer``.
You can compute extra information about the sorting using the ``compute`` method. For example,

.. code::

sorting_analyzer.compute({
"unit_locations": {},
"correlograms": {},
"template_similarity": {},
"isi_histograms": {},
"template_metrics": {include_multi_channel_metrics: True},
"quality_metrics": {},
})

widgets.html#available-plotting-functions

Learn more about the ``SortingAnalyzer`` and its ``extensions`` `here <https://spikeinterface.readthedocs.io/en/stable/modules/postprocessing.html>`__.

If you'd like to store the information you've computed, you can save the analyzer:

.. code::

sorting_analyzer.save_as(
format="binary_folder",
folder="my_kilosort_analyzer"
)

You now have a fully functional ``SortingAnalyzer`` - congrats! You can now use `spikeinterface-gui <https://github.com/SpikeInterface/spikeinterface-gui/>`__. to view the results
interactively, or start manually labelling your units to `create an automated curation model <https://spikeinterface.readthedocs.io/en/stable/tutorials_custom_index.html#automated-curation-tutorials>`__.
1 change: 1 addition & 0 deletions doc/how_to/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ Guides on how to solve specific, short problems in SpikeInterface. Learn how to.
benchmark_with_hybrid_recordings
auto_curation_training
auto_curation_prediction
import_kilosort_data
207 changes: 206 additions & 1 deletion src/spikeinterface/extractors/phykilosortextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,25 @@

from typing import Optional
from pathlib import Path
import warnings

import numpy as np

from spikeinterface.core import BaseSorting, BaseSortingSegment, read_python
from spikeinterface.core import (
BaseSorting,
BaseSortingSegment,
read_python,
generate_ground_truth_recording,
ChannelSparsity,
ComputeTemplates,
create_sorting_analyzer,
SortingAnalyzer,
)
from spikeinterface.core.core_tools import define_function_from_class

from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations
from probeinterface import read_prb, Probe


class BasePhyKilosortSortingExtractor(BaseSorting):
"""Base SortingExtractor for Phy and Kilosort output folder.
Expand Down Expand Up @@ -302,3 +315,195 @@ def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove

read_phy = define_function_from_class(source_class=PhySortingExtractor, name="read_phy")
read_kilosort = define_function_from_class(source_class=KiloSortSortingExtractor, name="read_kilosort")


def kilosort_output_to_analyzer(folder_path, compute_extras=False, unwhiten=True) -> SortingAnalyzer:
"""
Load kilosort output into a SortingAnalyzer.
Output from kilosort version 4.1 and above are supported.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any way to check the version from the output? If not we should ask them to add on the KS repo as this would be a useful general addition. But maybe we could check that the kilosortX.log is not < 4? (IIRC that the logs are formatted in this way)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have asked on KiloSort

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you've run directly from kilosort, you'll have the version in kilosort4.log. So we could check there, and if it's not there, we have a guess...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great that sounds good, if this tool is only for kilosort4 then just checking for the existing of that log file should do (unless it's extended to other versions)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello, the log files only appeared at v4.0.33. Thinking of other ways to check...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know of any files which KS2/2.5 defo don't have in their output, that KS4 does?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello, I've added a _guess_kilosort_version function, to isolate this logic. Let's compare some outputs and see if we can make it reasonable.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like everything works fine for all versions (i.e. the probes, templates and spike locations are saved in the same way for all versions). So in the end we don't need to change behavior depending on version. We would have to if we got pcas working, but that's not gonna happen in this PR.


Parameters
----------
folder_path : str or Path
Path to the output Phy folder (containing the params.py).
unwhiten : bool, default: True
Unwhiten the templates computed by kilosort.

Returns
-------
sorting_analyzer : SortingAnalyzer
A SortingAnalyzer object.
"""

phy_path = Path(folder_path)

guessed_kilosort_version = _guess_kilosort_version(phy_path)

sorting = read_phy(phy_path)
sampling_frequency = sorting.sampling_frequency

# kilosort occasionally contains a few spikes beyond the recording end point, which can lead
# to errors later. To avoid this, we pad the recording with an extra second of blank time.
duration = sorting._sorting_segments[0]._all_spikes[-1] / sampling_frequency + 1

if (phy_path / "probe.prb").is_file():
probegroup = read_prb(phy_path / "probe.prb")
if len(probegroup.probes) > 0:
warnings.warn("Found more than one probe. Selecting the first probe in ProbeGroup.")
probe = probegroup.probes[0]
elif (phy_path / "channel_positions.npy").is_file():
probe = Probe(si_units="um")
channel_positions = np.load(phy_path / "channel_positions.npy")
probe.set_contacts(channel_positions)
probe.set_device_channel_indices(range(probe.get_contact_count()))
else:
AssertionError(f"Cannot read probe layout from folder {phy_path}.")

# to make the initial analyzer, we'll use a fake recording and set it to None later
recording, _ = generate_ground_truth_recording(
probe=probe, sampling_frequency=sampling_frequency, durations=[duration]
)

sparsity = _make_sparsity_from_templates(sorting, recording, phy_path)

sorting_analyzer = create_sorting_analyzer(sorting, recording, sparse=True, sparsity=sparsity)

# first compute random spikes. These do nothing, but are needed for si-gui to run
sorting_analyzer.compute("random_spikes")

_make_templates(sorting_analyzer, phy_path, sparsity.mask, sampling_frequency, unwhiten=unwhiten)
_make_locations(sorting_analyzer, phy_path)
_make_amplitudes(sorting_analyzer, phy_path)

sorting_analyzer._recording = None
return sorting_analyzer


def _guess_kilosort_version(kilosort_path) -> tuple:
"""
Guesses the kilosort version based on the files which exist in folder `kilosort_path`.
If unknown, returns minimum guessed version.

Returns
-------
version_number : tuple
Version number in the form (major, minor, patch)
"""

kilosort_log_file = Path(kilosort_path / "kilosort4.log")

if kilosort_log_file.is_file():
return (4, 0, 33)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you want to parse the file here no ?
Also we check that the kilosort folder was lauch by spikeinterface and use the spikeinterface log to gete the version no ?

Copy link
Member Author

@chrishalcrow chrishalcrow Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, the idea is that we use whatever we can to figure out the kilosort version.

However, it seems like everything works fine for all versions (i.e. the probes, templates and spike locations are saved in the same way for all versions). So in the end we don't need to change behavior depending on version.

else:
return (2, 0, 0)


def _make_amplitudes(sorting_analyzer, kilosort_output_path):
"""Constructs a `spike_amplitudes` extension from the amplitudes numpy array
in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`."""

amplitudes_extension = ComputeSpikeAmplitudes(sorting_analyzer)

amps_np = np.load(kilosort_output_path / "amplitudes.npy")

amplitudes_extension.data = {"amplitudes": amps_np}
amplitudes_extension.params = {"peak_sign": "neg"}
amplitudes_extension.run_info = {"run_completed": True}

sorting_analyzer.extensions["spike_amplitudes"] = amplitudes_extension


def _make_locations(sorting_analyzer, kilosort_output_path):
"""Constructs a `spike_locations` extension from the amplitudes numpy array
in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`."""

locations_extension = ComputeSpikeLocations(sorting_analyzer)

locs_np = np.load(kilosort_output_path / "spike_positions.npy")

num_dims = len(locs_np[0])
column_names = ["x", "y", "z"][:num_dims]
dtype = [(name, locs_np.dtype) for name in column_names]

structured_array = np.zeros(len(locs_np), dtype=dtype)
for coordinate_index, column_name in enumerate(column_names):
structured_array[column_name] = locs_np[:, coordinate_index]

locations_extension.data = {"spike_locations": structured_array}
locations_extension.params = {}
locations_extension.run_info = {"run_completed": True}

sorting_analyzer.extensions["spike_locations"] = locations_extension


def _make_sparsity_from_templates(sorting, recording, kilosort_output_path):
"""Constructs the `ChannelSparsity` of from kilosort output, by seeing if the
templates output is zero or not on all channels."""

templates = np.load(kilosort_output_path / "templates.npy")

unit_ids = sorting.unit_ids
channel_ids = recording.channel_ids

# The raw templates have dense dimensions (num chan)x(num units)
# but are zero on many channels, which implicitly defines the sparsity
mask = np.sum(np.abs(templates), axis=1) != 0
return ChannelSparsity(mask, unit_ids=unit_ids, channel_ids=channel_ids)


def _make_templates(sorting_analyzer, kilosort_output_path, mask, sampling_frequency, unwhiten=True):
"""Constructs a `templates` extension from the amplitudes numpy array
in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`."""

template_extension = ComputeTemplates(sorting_analyzer)

whitened_templates = np.load(kilosort_output_path / "templates.npy")
wh_inv = np.load(kilosort_output_path / "whitening_mat_inv.npy")
new_templates = _compute_unwhitened_templates(whitened_templates, wh_inv, mask) if unwhiten else whitened_templates

template_extension.data = {"average": new_templates}

ops_path = kilosort_output_path / "ops.npy"
if ops_path.is_file():
ops = np.load(ops_path, allow_pickle=True)

number_samples_before_template_peak = ops.item(0)["nt0min"]
total_template_samples = ops.item(0)["nt"]

number_samples_after_template_peak = total_template_samples - number_samples_before_template_peak

ms_before = number_samples_before_template_peak / (sampling_frequency // 1000)
ms_after = number_samples_after_template_peak / (sampling_frequency // 1000)

params = {
"operators": ["average"],
"ms_before": ms_before,
"ms_after": ms_after,
"peak_sign": "neg",
}

template_extension.params = params
template_extension.run_info = {"run_completed": True}

sorting_analyzer.extensions["templates"] = template_extension


def _compute_unwhitened_templates(whitened_templates, wh_inv, mask):
"""Constructs unwhitened templates from whitened_templates, by
applying an inverse whitening matrix."""

template_shape = np.shape(whitened_templates)
new_templates = np.zeros(template_shape)

sparsity_channel_ids = [np.arange(template_shape[-1])[unit_sparsity] for unit_sparsity in mask]

for unit_index, channel_indices in enumerate(sparsity_channel_ids):
for channel_index_1 in channel_indices:
for channel_index_2 in channel_indices:
# templates have dimension unit_index x sample_index x channel_index
# to undo whitening, we need do matrix multiplication on the channel_index
new_templates[unit_index, :, channel_index_1] += (
wh_inv[channel_index_1, channel_index_2] * whitened_templates[unit_index, :, channel_index_2]
)

return new_templates