|
2 | 2 |
|
3 | 3 | from typing import Optional |
4 | 4 | from pathlib import Path |
| 5 | +import warnings |
5 | 6 |
|
6 | 7 | import numpy as np |
7 | 8 |
|
8 | | -from spikeinterface.core import BaseSorting, BaseSortingSegment, read_python |
| 9 | +from spikeinterface.core import ( |
| 10 | + BaseSorting, |
| 11 | + BaseSortingSegment, |
| 12 | + read_python, |
| 13 | + generate_ground_truth_recording, |
| 14 | + ChannelSparsity, |
| 15 | + ComputeTemplates, |
| 16 | + create_sorting_analyzer, |
| 17 | + SortingAnalyzer, |
| 18 | +) |
9 | 19 | from spikeinterface.core.core_tools import define_function_from_class |
10 | 20 |
|
| 21 | +from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations |
| 22 | +from probeinterface import read_prb, Probe |
| 23 | + |
11 | 24 |
|
12 | 25 | class BasePhyKilosortSortingExtractor(BaseSorting): |
13 | 26 | """Base SortingExtractor for Phy and Kilosort output folder. |
@@ -302,3 +315,177 @@ def __init__(self, folder_path: Path | str, keep_good_only: bool = False, remove |
302 | 315 |
|
303 | 316 | read_phy = define_function_from_class(source_class=PhySortingExtractor, name="read_phy") |
304 | 317 | read_kilosort = define_function_from_class(source_class=KiloSortSortingExtractor, name="read_kilosort") |
| 318 | + |
| 319 | + |
def read_kilosort_as_analyzer(folder_path, unwhiten=True) -> SortingAnalyzer:
    """
    Load Kilosort output into a SortingAnalyzer. Output from Kilosort version 4.1 and
    above are supported. The function may work on older versions of Kilosort output,
    but these are not carefully tested. Please check your output carefully.

    Parameters
    ----------
    folder_path : str or Path
        Path to the output Phy folder (containing the params.py).
    unwhiten : bool, default: True
        Unwhiten the templates computed by kilosort.

    Returns
    -------
    sorting_analyzer : SortingAnalyzer
        A SortingAnalyzer object with "random_spikes", "templates" and (when the
        kilosort output provides spike positions) "spike_locations" extensions
        attached. Its recording is set to None before returning.

    Raises
    ------
    AssertionError
        If neither a `probe.prb` nor a `channel_positions.npy` file is found in
        `folder_path`, so no probe layout can be constructed.
    """

    phy_path = Path(folder_path)

    sorting = read_phy(phy_path)
    sampling_frequency = sorting.sampling_frequency

    # kilosort occasionally contains a few spikes just beyond the recording end point,
    # which can lead to errors later. To avoid this, we pad the recording with an
    # extra second of blank time.
    duration = sorting._sorting_segments[0]._all_spikes[-1] / sampling_frequency + 1

    if (phy_path / "probe.prb").is_file():
        probegroup = read_prb(phy_path / "probe.prb")
        # Fix: warn only when there genuinely is more than one probe
        # (was `> 0`, which warned even for a single-probe file).
        if len(probegroup.probes) > 1:
            warnings.warn("Found more than one probe. Selecting the first probe in ProbeGroup.")
        probe = probegroup.probes[0]
    elif (phy_path / "channel_positions.npy").is_file():
        # No prb file: reconstruct a minimal probe from the saved channel positions.
        probe = Probe(si_units="um")
        channel_positions = np.load(phy_path / "channel_positions.npy")
        probe.set_contacts(channel_positions)
        probe.set_device_channel_indices(range(probe.get_contact_count()))
    else:
        # Fix: the exception was previously constructed but never raised, which let
        # execution continue with `probe` undefined (raising a NameError later).
        raise AssertionError(f"Cannot read probe layout from folder {phy_path}.")

    # to make the initial analyzer, we'll use a fake recording and set it to None later
    recording, _ = generate_ground_truth_recording(
        probe=probe,
        sampling_frequency=sampling_frequency,
        durations=[duration],
        num_units=1,
        seed=1205,
    )

    sparsity = _make_sparsity_from_templates(sorting, recording, phy_path)

    sorting_analyzer = create_sorting_analyzer(sorting, recording, sparse=True, sparsity=sparsity)

    # first compute random spikes. These do nothing, but are needed for si-gui to run
    sorting_analyzer.compute("random_spikes")

    _make_templates(sorting_analyzer, phy_path, sparsity.mask, sampling_frequency, unwhiten=unwhiten)
    _make_locations(sorting_analyzer, phy_path)

    # The fake recording was only needed to construct the analyzer; detach it.
    sorting_analyzer._recording = None
    return sorting_analyzer
| 382 | + |
| 383 | + |
def _make_locations(sorting_analyzer, kilosort_output_path):
    """Constructs a `spike_locations` extension from the `spike_positions.npy` array
    in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`.

    Returns silently (no extension attached) when `spike_positions.npy` is missing,
    and warns then returns when its length does not match the sorting's spike vector.
    """
    # Note: the original docstring said "amplitudes numpy array"; this function
    # actually reads spike *positions*.

    locations_extension = ComputeSpikeLocations(sorting_analyzer)

    # Spike positions are optional output; bail out quietly when absent.
    spike_locations_path = kilosort_output_path / "spike_positions.npy"
    if not spike_locations_path.is_file():
        return
    locs_np = np.load(spike_locations_path)

    # Check that the spike locations vector is the same size as the spike vector
    num_spikes = len(sorting_analyzer.sorting.to_spike_vector())
    num_spike_locs = len(locs_np)
    if num_spikes != num_spike_locs:
        warnings.warn(
            "The number of spikes does not match the number of spike locations in `spike_positions.npy`. Skipping spike locations."
        )
        return

    # The array is indexed (spike, coordinate); name the coordinate columns x/y/z.
    num_dims = len(locs_np[0])
    column_names = ["x", "y", "z"][:num_dims]
    dtype = [(name, locs_np.dtype) for name in column_names]

    # Repack the plain 2-D array into the structured array format the
    # spike_locations extension stores.
    structured_array = np.zeros(len(locs_np), dtype=dtype)
    for coordinate_index, column_name in enumerate(column_names):
        structured_array[column_name] = locs_np[:, coordinate_index]

    locations_extension.data = {"spike_locations": structured_array}
    locations_extension.params = {}
    locations_extension.run_info = {"run_completed": True}

    sorting_analyzer.extensions["spike_locations"] = locations_extension
| 418 | + |
| 419 | + |
def _make_sparsity_from_templates(sorting, recording, kilosort_output_path):
    """Constructs the `ChannelSparsity` from kilosort output, by checking whether
    the templates output is zero or not on each channel."""

    templates = np.load(kilosort_output_path / "templates.npy")

    unit_ids = sorting.unit_ids
    channel_ids = recording.channel_ids

    # The raw templates are dense with shape (num units) x (num samples) x (num channels)
    # but are zero on many channels, which implicitly defines the sparsity.
    # (The previous comment had the axis order reversed.) Summing |template| over the
    # samples axis leaves a (num units) x (num channels) boolean mask that is True
    # wherever a unit's template has any energy on a channel.
    mask = np.sum(np.abs(templates), axis=1) != 0
    return ChannelSparsity(mask, unit_ids=unit_ids, channel_ids=channel_ids)
| 433 | + |
| 434 | + |
def _make_templates(sorting_analyzer, kilosort_output_path, mask, sampling_frequency, unwhiten=True):
    """Constructs a `templates` extension from the templates numpy array
    in `kilosort_output_path`, and attaches the extension to the `sorting_analyzer`.

    If `unwhiten` is True, the saved inverse whitening matrix is applied to undo
    kilosort's channel whitening before the templates are stored.
    """
    # Note: the original docstring said "amplitudes numpy array"; this function
    # actually reads templates.

    template_extension = ComputeTemplates(sorting_analyzer)

    whitened_templates = np.load(kilosort_output_path / "templates.npy")
    wh_inv = np.load(kilosort_output_path / "whitening_mat_inv.npy")
    new_templates = _compute_unwhitened_templates(whitened_templates, wh_inv) if unwhiten else whitened_templates

    template_extension.data = {"average": new_templates}

    # Fix: use true division for samples-per-millisecond. The previous
    # `sampling_frequency // 1000` truncated for non-integer-kHz rates
    # (e.g. 12500 Hz -> 12 instead of 12.5), skewing ms_before/ms_after.
    # Results are identical for the common integer-kHz rates.
    samples_per_ms = sampling_frequency / 1000.0

    ops_path = kilosort_output_path / "ops.npy"
    if ops_path.is_file():
        # ops.npy stores the run options; read the exact template window from it.
        ops = np.load(ops_path, allow_pickle=True)

        number_samples_before_template_peak = ops.item(0)["nt0min"]
        total_template_samples = ops.item(0)["nt"]

        number_samples_after_template_peak = total_template_samples - number_samples_before_template_peak

        ms_before = number_samples_before_template_peak / samples_per_ms
        ms_after = number_samples_after_template_peak / samples_per_ms

    # Used for kilosort 2, 2.5 and 3, which do not save an ops.npy
    else:

        warnings.warn("Can't extract `ms_before` and `ms_after` from Kilosort output. Guessing a sensible value.")

        # Assume the template peak sits one third of the way into the window.
        samples_in_templates = np.shape(new_templates)[1]
        template_extent_ms = (samples_in_templates + 1) / samples_per_ms
        ms_before = template_extent_ms / 3
        ms_after = 2 * template_extent_ms / 3

    params = {
        "operators": ["average"],
        "ms_before": ms_before,
        "ms_after": ms_after,
        "peak_sign": "both",
    }

    template_extension.params = params
    template_extension.run_info = {"run_completed": True}

    sorting_analyzer.extensions["templates"] = template_extension
| 480 | + |
| 481 | + |
| 482 | +def _compute_unwhitened_templates(whitened_templates, wh_inv): |
| 483 | + """Constructs unwhitened templates from whitened_templates, by |
| 484 | + applying an inverse whitening matrix.""" |
| 485 | + |
| 486 | + # templates have dimension (num units) x (num samples) x (num channels) |
| 487 | + # whitening inverse has dimension (num units) x (num channels) |
| 488 | + # to undo whitening, we need do matrix multiplication on the channel index |
| 489 | + unwhitened_templates = np.einsum("ij,klj->kli", wh_inv, whitened_templates) |
| 490 | + |
| 491 | + return unwhitened_templates |
0 commit comments