diff --git a/doc/modules/curation.rst b/doc/modules/curation.rst
index 300c68a5e4..ef8fc54f2a 100644
--- a/doc/modules/curation.rst
+++ b/doc/modules/curation.rst
@@ -206,7 +206,7 @@ Curation format
 SpikeInterface internally supports a JSON-based manual curation format.
 When manual curation is necessary, modifying a dataset in place is a bad practice.
 Instead, to ensure the reproducibility of the spike sorting pipelines, we have introduced a simple and JSON-based manual curation format.
-This format defines at the moment : merges + deletions + manual tags.
+This format currently defines: manual labelling, removal, merging and splitting of units, and removal of spikes from a unit.
 The simple file can be kept along side the output of a sorter and applied on the result to have a "clean" result.
 
 This format has two part:
@@ -216,21 +216,26 @@ This format has two part:
   * "format_version" : format specification
   * "unit_ids" : the list of unit_ds
   * "label_definitions" : list of label categories and possible labels per category.
-    Every category can be *exclusive=True* onely one label or *exclusive=False* several labels possible
+    If a unit can only have one label, the category can be set to *exclusive=True*. If several labels can be used at once, the category can be set to *exclusive=False*.
 
 * **manual output** curation with the folowing keys:
 
   * "manual_labels"
-  * "merge_unit_groups"
-  * "removed_units"
+  * "merges"
+  * "removed"
+  * "splits"
+  * "discard_spikes"
 
-Here is the description of the format with a simple example (the first part of the
-format is the definition; the second part of the format is manual action):
+The first three ("manual_labels", "merges" and "removed") act at the unit level: they label, merge or remove whole units. The final two
+("splits" and "discard_spikes") act at the spike level: we need to define which spikes from a unit are being split into a new unit, or which
+spikes from a unit are to be discarded. Note that all spike indices are with respect to the original analyzer.
+
+Here is a simple example of the format:
 
 .. code-block:: json
 
     {
-        "format_version": "1",
+        "format_version": "3",
         "unit_ids": [
             "u1",
             "u2",
@@ -266,25 +271,31 @@ format is the definition; the second part of the format is manual action):
         "manual_labels": [
             {
                 "unit_id": "u1",
-                "quality": [
-                    "good"
-                ]
+                "labels": {
+                    "quality": [
+                        "good"
+                    ]
+                }
             },
             {
                 "unit_id": "u2",
-                "quality": [
-                    "noise"
-                ],
-                "putative_type": [
-                    "excitatory",
-                    "pyramidal"
-                ]
+                "labels": {
+                    "quality": [
+                        "noise"
+                    ],
+                    "putative_type": [
+                        "excitatory",
+                        "pyramidal"
+                    ]
+                }
             },
             {
                 "unit_id": "u3",
-                "putative_type": [
-                    "inhibitory"
-                ]
+                "labels": {
+                    "putative_type": [
+                        "inhibitory"
+                    ]
+                }
             }
         ],
         "merge_unit_groups": [
@@ -301,9 +312,48 @@ format is the definition; the second part of the format is manual action):
         "removed_units": [
             "u31",
             "u42"
+        ],
+        "splits": [
+            {
+                "unit_id": "u1",
+                "mode": "indices",
+                "indices": [
+                    [
+                        10,
+                        20,
+                        30
+                    ]
+                ],
+                "new_unit_ids": [
+                    "u1-1",
+                    "u1-2"
+                ]
+            }
+        ],
+        "discard_spikes": [
+            {
+                "unit_id": "u10",
+                "indices": [
+                    56,
+                    57,
+                    59,
+                    60
+                ]
+            },
+            {
+                "unit_id": "u14",
+                "indices": [
+                    123,
+                    321
+                ]
+            }
         ]
     }
 
+Note that you cannot both split and merge the same unit.
+
+We do not expect users to create their own curation JSON files. Instead, our internal curation algorithms will output
+results that can easily be transformed into this format. We also hope that external packages can use our format.
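As an illustration of the new spike-level fields described above, here is a minimal sketch of building such a dictionary and applying it in Python; the unit ids, spike indices and the ``sorting_or_analyzer`` object are placeholders, not part of this patch.

.. code-block:: python

    from spikeinterface.curation import apply_curation

    curation = {
        "format_version": "3",
        "unit_ids": ["u1", "u2", "u3"],
        # split "u1": the listed spike indices form one new unit, the remaining spikes the other
        "splits": [
            {"unit_id": "u1", "mode": "indices", "indices": [[10, 20, 30]]}
        ],
        # drop individual spikes from "u2"; indices refer to the original spike train
        "discard_spikes": [
            {"unit_id": "u2", "indices": [56, 57, 59]}
        ],
    }

    # sorting_or_analyzer is an existing BaseSorting or SortingAnalyzer
    curated = apply_curation(sorting_or_analyzer, curation)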
The curation format can be loaded into a dictionary and directly applied to a ``BaseSorting`` or ``SortingAnalyzer`` object using the :py:func:`~spikeinterface.curation.apply_curation` function. diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 74ef52e258..b6c16a4728 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -604,8 +604,9 @@ def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, arr = np.std(wfs, axis=0) elif operator == "median": arr = np.median(wfs, axis=0) - elif "percentile" in operator: - _, percentile = operator.splot("_") + # old versions have spelling error in "percentile" + elif "percentile" in operator or "pencentile" in operator: + _, percentile = operator.split("_") arr = np.percentile(wfs, float(percentile), axis=0) new_array[split_unit_index, ...] = arr else: @@ -1126,6 +1127,7 @@ def _compute_metrics( column_names = list(metric.metric_columns.keys()) try: metric_params = self.params["metric_params"].get(metric_name, {}) + # TODO: deal with number_of_peaks and velocity_fits here res = metric.compute( sorting_analyzer, unit_ids=unit_ids, diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 90c7e18a99..c755125a60 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -517,6 +517,7 @@ def apply_splits_to_sorting( sorting: BaseSorting, unit_splits: dict[int | str, list[list[int | str]]], new_unit_ids: list[list[int | str]] | None = None, + discard_spikes_unit_ids: list[list[int | str]] | None = None, return_extra: bool = False, new_id_strategy: str = "append", ): @@ -543,6 +544,8 @@ def apply_splits_to_sorting( List of new unit_ids for each split. If given, it needs to have the same length as `unit_splits`. and each element must have the same length as the corresponding list of split indices. If None, new ids will be generated. + discard_spikes_unit_ids : list | None, default: None + List of units which contain spikes to discard. return_extra : bool, default: False If True, also return the new_unit_ids. new_id_strategy : "append" | "split", default: "append" @@ -565,7 +568,11 @@ def apply_splits_to_sorting( # this is true when running via apply_curation new_unit_ids = generate_unit_ids_for_split( - sorting.unit_ids, unit_splits, new_unit_ids=new_unit_ids, new_id_strategy=new_id_strategy + sorting.unit_ids, + unit_splits, + new_unit_ids=new_unit_ids, + new_id_strategy=new_id_strategy, + discard_spikes_unit_ids=discard_spikes_unit_ids, ) all_unit_ids = _get_ids_after_splitting(sorting.unit_ids, unit_splits, new_unit_ids) all_unit_ids = list(all_unit_ids) @@ -660,7 +667,9 @@ def set_properties_after_splits( sorting_post_split.set_property("is_split", is_split) -def generate_unit_ids_for_split(old_unit_ids, unit_splits, new_unit_ids=None, new_id_strategy="append"): +def generate_unit_ids_for_split( + old_unit_ids, unit_splits, new_unit_ids=None, new_id_strategy="append", discard_spikes_unit_ids=None +): """ Function to generate new units ids during a splitting procedure. If `new_units_ids` are provided, it will return these unit ids, checking that they are consistent with @@ -682,6 +691,8 @@ def generate_unit_ids_for_split(old_unit_ids, unit_splits, new_unit_ids=None, ne * "split" : new_unit_ids will be the created as {split_unit_id]-{split_number} (e.g. when splitting unit "13" in 2: "13-0" / "13-1"). 
Only works if unit_ids are str otherwise switch to "append" + discard_spikes_unit_ids : list | None, default: None + List of units which contain spikes to discard. Returns ------- @@ -691,46 +702,92 @@ def generate_unit_ids_for_split(old_unit_ids, unit_splits, new_unit_ids=None, ne assert new_id_strategy in ["append", "split"], "new_id_strategy should be 'append' or 'split'" old_unit_ids = np.asarray(old_unit_ids) - if new_unit_ids is not None: - for split_unit, new_split_ids in zip(unit_splits.values(), new_unit_ids): - # then only doing a consistency check - assert len(split_unit) == len(new_split_ids), "new_unit_ids should have the same len as unit_splits.values" - # new_unit_ids can also be part of old_unit_ids only inside the same group: - assert all( - new_split_id not in old_unit_ids for new_split_id in new_split_ids - ), "new_unit_ids already exists but outside the split groups" - else: - dtype = old_unit_ids.dtype - if np.issubdtype(dtype, np.integer) and new_id_strategy == "split": - warnings.warn("new_id_strategy 'split' is not compatible with integer unit_ids. Switching to 'append'.") - new_id_strategy = "append" - - new_unit_ids = [] - current_unit_ids = old_unit_ids.copy() - for unit_to_split, split_indices in unit_splits.items(): - num_splits = len(split_indices) - # select new_unit_ids greater that the max id, event greater than the numerical str ids - if new_id_strategy == "append": - if np.issubdtype(dtype, np.character): - # dtype str - if all(p.isdigit() for p in current_unit_ids): - # All str are digit : we can generate a max - m = max(int(p) for p in current_unit_ids) + 1 - new_units_for_split = [str(m + i) for i in range(num_splits)] - else: - # we cannot automatically find new names - new_units_for_split = [f"{unit_to_split}-split{i}" for i in range(num_splits)] - else: - # dtype int - new_units_for_split = list(max(current_unit_ids) + 1 + np.arange(num_splits, dtype=dtype)) - # we append the new split unit ids to continue to increment the max id - current_unit_ids = np.concatenate([current_unit_ids, new_units_for_split]) - elif new_id_strategy == "split": - # we made sure that dtype is not integer + new_new_unit_ids = [] + if discard_spikes_unit_ids is None: + discard_spikes_unit_ids = [] + + if new_unit_ids is not None and len(new_unit_ids) > 0: + for split_unit, new_split_ids in zip(unit_splits.keys(), new_unit_ids): + if split_unit in discard_spikes_unit_ids: + # make the discard unit have the original unit_id, which we'll delete later + new_new_unit_ids.append([split_unit] + new_split_ids) + else: + new_new_unit_ids.append(new_split_ids) + + return new_new_unit_ids + + dtype = old_unit_ids.dtype + if np.issubdtype(dtype, np.integer) and new_id_strategy == "split": + warnings.warn("new_id_strategy 'split' is not compatible with integer unit_ids. Switching to 'append'.") + new_id_strategy = "append" + + # If unit_ids are list of numeric strings, we will convert to int, use "append" code with int unit_ids, then convert back to string + all_strings = False + if np.issubdtype(dtype, np.character): + if all(p.isdigit() for p in old_unit_ids): + all_strings = True + old_unit_ids = [int(unit_id) for unit_id in old_unit_ids] + else: + # we cannot automatically find new names + warnings.warn( + "new_id_strategy 'append' is not compatible with non-numeric string unit_ids. Switching to 'split'." 
+ ) + new_id_strategy = "split" + + if new_id_strategy == "append": + next_max_unit_id = max(old_unit_ids) + 1 + highest_possible_unit_id = max(old_unit_ids) + for split_indices in unit_splits.values(): + highest_possible_unit_id += len(split_indices) + + for unit_to_split, split_indices in unit_splits.items(): + + num_splits = len(split_indices) + + # decide if unit is a simple discard, a simple split or a discard and split + just_discard = False + discard_and_split = False + if unit_to_split in discard_spikes_unit_ids: + if num_splits == 2: + just_discard = True + elif num_splits > 2: + discard_and_split = True + + # the new units made from the discard_spikes must be at the *front* of the new_units_for_split list. + if new_id_strategy == "append": + if just_discard: + # give the discard unit the highest possible id, leaving the cleaned unit's id untouched. + new_units_for_split = [highest_possible_unit_id, unit_to_split] + highest_possible_unit_id -= 1 + elif discard_and_split: + # give the discard unit the id `unit_to_split`, and the other units their expected append ids + new_units_for_split = [unit_to_split] + list(next_max_unit_id + np.arange(num_splits - 1)) + next_max_unit_id += len(new_units_for_split) - 1 + else: + # there is no discard unit + new_units_for_split = list(next_max_unit_id + np.arange(num_splits)) + next_max_unit_id += len(new_units_for_split) + + elif new_id_strategy == "split": + if just_discard: + new_units_for_split = [f"{unit_to_split}-{i}" for i in np.arange(len(split_indices) - 1)] + [ + unit_to_split + ] + elif discard_and_split: + new_units_for_split = [unit_to_split] + [ + f"{unit_to_split}-{i}" for i in np.arange(len(split_indices) - 1) + ] + else: new_units_for_split = [f"{unit_to_split}-{i}" for i in np.arange(len(split_indices))] - new_unit_ids.append(new_units_for_split) - return new_unit_ids + new_new_unit_ids.append(new_units_for_split) + + if all_strings: + new_new_unit_ids = [ + [str(unit_id) for unit_id in new_units_for_split] for new_units_for_split in new_new_unit_ids + ] + + return new_new_unit_ids def check_unit_splits_consistency(unit_splits, sorting): diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 1870c24e7a..d6fb931aa9 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1386,6 +1386,7 @@ def split_units( new_unit_ids: list[list[int | str]] | None = None, new_id_strategy: str = "append", return_new_unit_ids: bool = False, + discard_spikes_unit_ids=None, format: str = "memory", folder: Path | str | None = None, verbose: bool = False, @@ -1435,7 +1436,9 @@ def split_units( check_unit_splits_consistency(split_units, self.sorting) - new_unit_ids = generate_unit_ids_for_split(self.unit_ids, split_units, new_unit_ids, new_id_strategy) + new_unit_ids = generate_unit_ids_for_split( + self.unit_ids, split_units, new_unit_ids, new_id_strategy, discard_spikes_unit_ids + ) all_unit_ids = _get_ids_after_splitting(self.unit_ids, split_units, new_unit_ids=new_unit_ids) new_analyzer = self._save_or_select_or_merge_or_split( diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index 840ef07353..a8fe2afc22 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -3,10 +3,10 @@ from pathlib import Path import json import numpy as np -from itertools import chain from spikeinterface.core import BaseSorting, SortingAnalyzer, 
apply_merges_to_sorting, apply_splits_to_sorting from spikeinterface.curation.curation_model import CurationModel +from spikeinterface.core.sorting_tools import generate_unit_ids_for_split def validate_curation_dict(curation_dict: dict): @@ -153,9 +153,9 @@ def apply_curation( Steps are done in this order: 1. Apply labels using curation_dict["manual_labels"] - 2. Apply removal using curation_dict["removed"] - 3. Apply merges using curation_dict["merges"] - 4. Apply splits using curation_dict["splits"] + 2. Remove whole units using curation_dict["removed"] + 3. Apply splits using curation_dict["splits"] and remove spikes from units using curation_dict["discard_spikes"] + 4. Apply merges using curation_dict["merges"] A new Sorting or SortingAnalyzer (in memory) is returned. The user (an adult) has the responsability to save it somewhere (or not). @@ -218,7 +218,85 @@ def apply_curation( else: curated_sorting_or_analyzer = sorting_or_analyzer - # 3. Merge units + # 3. Split and discard_spikes from units + # Do this at the same time, otherwise have to do a lot of spike index shuffling. + # Strategy: put the discarded spikes in a new unit when splitting, then remove them at the end. + if len(curation_model.splits) > 0 or len(curation_model.discard_spikes) > 0: + + split_spikes_unit_ids = [] + split_new_unit_ids_from_user = None + if len(curation_model.splits) > 0: + split_spikes_unit_ids = [split.unit_id for split in curation_model.splits] + split_new_unit_ids_from_user = [s.new_unit_ids for s in curation_model.splits if s.new_unit_ids is not None] + + discard_spikes_unit_ids = [] + if len(curation_model.discard_spikes) > 0: + discard_spikes_unit_ids = [discard_spike.unit_id for discard_spike in curation_model.discard_spikes] + + split_units = {} + + sorting = ( + curated_sorting_or_analyzer if isinstance(sorting_or_analyzer, BaseSorting) else sorting_or_analyzer.sorting + ) + + num_spikes_per_unit = sorting.count_num_spikes_per_unit() + + for unit_id in curation_model.unit_ids: + + if unit_id in split_spikes_unit_ids: + split_spikes_list_index = np.where(np.array(split_spikes_unit_ids) == unit_id)[0][0] + split = curation_model.splits[split_spikes_list_index] + + split_units[unit_id] = split.get_full_spike_indices(sorting) + + # If the unit is not split but does contain spikes to discard, make an initial "split" + # unit containing the indices of the entire spike train. + elif unit_id in discard_spikes_unit_ids: + split_units[unit_id] = [np.arange(num_spikes_per_unit[unit_id])] + + # If there are spikes to discard, find these, and remove them from the units-to-split. + # Put the discarded spikes in their own unit, at the start of the list of split units. 
+ if unit_id in discard_spikes_unit_ids: + + discard_spikes_list_index = np.where(np.array(discard_spikes_unit_ids) == unit_id)[0][0] + discard_spikes_indices = np.array(curation_model.discard_spikes[discard_spikes_list_index].indices) + + split_units_with_discard = [discard_spikes_indices] + for split_spike_train in split_units[unit_id]: + split_spike_train_cleaned = np.setdiff1d(split_spike_train, discard_spikes_indices) + split_units_with_discard.append(split_spike_train_cleaned) + + split_units[unit_id] = split_units_with_discard + + if isinstance(sorting_or_analyzer, BaseSorting): + curated_sorting_or_analyzer, new_ids = apply_splits_to_sorting( + curated_sorting_or_analyzer, + split_units, + new_id_strategy=new_id_strategy, + new_unit_ids=split_new_unit_ids_from_user, + discard_spikes_unit_ids=discard_spikes_unit_ids, + return_extra=True, + ) + else: + curated_sorting_or_analyzer, new_ids = curated_sorting_or_analyzer.split_units( + split_units, + new_id_strategy=new_id_strategy, + new_unit_ids=split_new_unit_ids_from_user, + discard_spikes_unit_ids=discard_spikes_unit_ids, + format="memory", + verbose=verbose, + return_new_unit_ids=True, + ) + + if len(discard_spikes_unit_ids) > 0: + ids_to_remove = [] + for new_id_set in new_ids: + if new_id_set[0] in discard_spikes_unit_ids or new_id_set[1] in discard_spikes_unit_ids: + ids_to_remove.append(new_id_set[0]) + + curated_sorting_or_analyzer = curated_sorting_or_analyzer.remove_units(ids_to_remove) + + # 4. Merge units if len(curation_model.merges) > 0: merge_unit_groups = [m.unit_ids for m in curation_model.merges] merge_new_unit_ids = [m.new_unit_id for m in curation_model.merges if m.new_unit_id is not None] @@ -246,37 +324,6 @@ def apply_curation( **job_kwargs, ) - # 4. Split units - if len(curation_model.splits) > 0: - split_units = {} - for split in curation_model.splits: - sorting = ( - curated_sorting_or_analyzer - if isinstance(sorting_or_analyzer, BaseSorting) - else sorting_or_analyzer.sorting - ) - split_units[split.unit_id] = split.get_full_spike_indices(sorting) - split_new_unit_ids = [s.new_unit_ids for s in curation_model.splits if s.new_unit_ids is not None] - if len(split_new_unit_ids) == 0: - split_new_unit_ids = None - if isinstance(sorting_or_analyzer, BaseSorting): - curated_sorting_or_analyzer, _ = apply_splits_to_sorting( - curated_sorting_or_analyzer, - split_units, - new_unit_ids=split_new_unit_ids, - new_id_strategy=new_id_strategy, - return_extra=True, - ) - else: - curated_sorting_or_analyzer, _ = curated_sorting_or_analyzer.split_units( - split_units, - new_id_strategy=new_id_strategy, - return_new_unit_ids=True, - new_unit_ids=split_new_unit_ids, - format="memory", - verbose=verbose, - ) - return curated_sorting_or_analyzer diff --git a/src/spikeinterface/curation/curation_model.py b/src/spikeinterface/curation/curation_model.py index ac89fde04a..de4bad25c5 100644 --- a/src/spikeinterface/curation/curation_model.py +++ b/src/spikeinterface/curation/curation_model.py @@ -75,9 +75,17 @@ def get_full_spike_indices(self, sorting: BaseSorting): return full_spike_indices +class DiscardSpikes(BaseModel): + unit_id: Union[int, str] = Field(description="ID of the unit") + indices: Optional[List[int]] = Field( + default=None, + description=("List of indices of the spikes to discard."), + ) + + class CurationModel(BaseModel): - supported_versions: Tuple[Literal["1"], Literal["2"]] = Field( - default=["1", "2"], description="Supported versions of the curation format" + supported_versions: Tuple[Literal["1"], 
Literal["2"], Literal["3"]] = Field( + default=["1", "2", "3"], description="Supported versions of the curation format" ) format_version: str = Field(description="Version of the curation format") unit_ids: List[Union[int, str]] = Field(description="List of unit IDs") @@ -88,6 +96,9 @@ class CurationModel(BaseModel): removed: Optional[List[Union[int, str]]] = Field(default=None, description="List of removed unit IDs") merges: Optional[List[Merge]] = Field(default=None, description="List of merges") splits: Optional[List[Split]] = Field(default=None, description="List of splits") + discard_spikes: Optional[List[DiscardSpikes]] = Field( + default=None, description="List of spikes to discard for each unit" + ) @field_validator("label_definitions", mode="before") def add_label_definition_name(cls, label_definitions): @@ -238,11 +249,11 @@ def check_splits(cls, values): for i, split in enumerate(splits): if isinstance(split, dict): split = dict(split) - if "indices" in split: + if split.get("indices") is not None: split["indices"] = [list(indices) for indices in split["indices"]] - if "labels" in split: + if split.get("labels") is not None: split["labels"] = list(split["labels"]) - if "new_unit_ids" in split: + if split.get("new_unit_ids") is not None: split["new_unit_ids"] = list(split["new_unit_ids"]) splits[i] = Split(**split) @@ -292,6 +303,39 @@ def check_splits(cls, values): values["splits"] = splits return values + @classmethod + def check_discard_spikes(cls, values): + """ + Checks and validates the discard_spikes in the curation model. + + * Checks if the unit_id exists in the unit_ids list. + * Checks that indices are defined and that there are no duplicate indices. + + """ + unit_ids = list(values["unit_ids"]) + discard_spikes = values.get("discard_spikes") + if discard_spikes is None: + values["discard_spikes"] = [] + return values + + # Convert items to DiscardSpikes objects + discard_spikes_objects = [DiscardSpikes(**discard_spike) for discard_spike in discard_spikes] + + # Validate discard spikes + for discard_spike in discard_spikes_objects: + # Check unit exists + if discard_spike.unit_id not in unit_ids: + raise ValueError(f"DiscardSpikes unit_id {discard_spike.unit_id} is not in the unit list") + if discard_spike.indices is None: + raise ValueError(f"DiscardSpikes unit {discard_spike.unit_id} has no indices defined") + # Check no duplicate indices + all_indices = discard_spike.indices + if len(all_indices) != len(set(all_indices)): + raise ValueError(f"DiscardSpikes unit {discard_spike.unit_id} has duplicate indices") + + values["discard_spikes"] = discard_spikes_objects + return values + @classmethod def check_removed(cls, values): """ @@ -314,13 +358,14 @@ def check_removed(cls, values): @classmethod def convert_old_format(cls, values): """ - Converts old curation formats (v0 and v1) to the current format (v2). + Converts old curation formats (v0, v1 and v2) to the current format (v3). v0 (sortingview) format is converted to v2 by extracting labels, merges, and unit IDs. v1 format is updated to v2 by renaming fields and ensuring the structure matches the v2 format. + v2 format is updated to v3 by updating the `format_version`. 
""" format_version = values.get("format_version", "0") if format_version == "0": - print("Conversion from format version v0 (sortingview) to v2") + print("Conversion from format version v0 (sortingview) to v3") if "mergeGroups" not in values.keys(): values["mergeGroups"] = [] merge_groups = values["mergeGroups"] @@ -349,7 +394,7 @@ def convert_old_format(cls, values): all_units = list(set(all_units)) values = { - "format_version": "2", + "format_version": "3", "unit_ids": values.get("unit_ids", all_units), "label_definitions": labels_def, "manual_labels": list(manual_labels), @@ -364,6 +409,11 @@ def convert_old_format(cls, values): removed_units = values.get("removed_units") if removed_units is not None: values["removed"] = list(removed_units) + values["supported_versions"] = ["1", "2", "3"] + elif values["format_version"] == "2": + # TODO: check this - should we update the format version when we load an old version? + values["format_version"] = "3" + values["supported_versions"] = ["1", "2", "3"] return values @model_validator(mode="before") @@ -375,6 +425,7 @@ def validate_fields(cls, values): values = cls.check_merges(values) values = cls.check_splits(values) values = cls.check_removed(values) + values = cls.check_discard_spikes(values) return values @model_validator(mode="after") @@ -387,6 +438,7 @@ def validate_curation_dict(self): labeled_unit_set = set([lbl.unit_id for lbl in self.manual_labels]) if self.manual_labels else set() merged_units_set = set(chain.from_iterable(merge.unit_ids for merge in self.merges)) if self.merges else set() split_units_set = set(split.unit_id for split in self.splits) if self.splits else set() + discard_spikes_units_set = set(split.unit_id for split in self.discard_spikes) if self.splits else set() removed_set = set(self.removed) if self.removed else set() unit_ids = self.unit_ids @@ -399,7 +451,8 @@ def validate_curation_dict(self): raise ValueError("Curation format: some split units are not in the unit list") if not removed_set.issubset(unit_set): raise ValueError("Curation format: some removed units are not in the unit list") - + if not discard_spikes_units_set.issubset(unit_set): + raise ValueError("Curation format: some discard_spikes units are not in the unit list") # Check for units being merged multiple times all_merging_groups = [set(merge.unit_ids) for merge in self.merges] if self.merges else [] for gp_1, gp_2 in combinations(all_merging_groups, 2): diff --git a/src/spikeinterface/curation/tests/test_curation_model.py b/src/spikeinterface/curation/tests/test_curation_model.py index 951f33a300..6bf8895260 100644 --- a/src/spikeinterface/curation/tests/test_curation_model.py +++ b/src/spikeinterface/curation/tests/test_curation_model.py @@ -13,7 +13,7 @@ def test_format_version(): # Invalid format version with pytest.raises(ValidationError): - CurationModel(format_version="3", unit_ids=[1, 2, 3]) + CurationModel(format_version="4", unit_ids=[1, 2, 3]) with pytest.raises(ValidationError): CurationModel(format_version="0.1", unit_ids=[1, 2, 3]) @@ -213,6 +213,34 @@ def test_split_units(): CurationModel(**invalid_new_ids) +# Test discard_spikes functionality +def test_discard_spikes(): + # Test indices mode with list format + valid_discard_spikes_indices = { + "format_version": "3", + "unit_ids": [1, 2, 3], + "discard_spikes": [ + { + "unit_id": 1, + "indices": [0, 1, 2], + } + ], + } + + model = CurationModel(**valid_discard_spikes_indices) + assert len(model.discard_spikes) == 1 + assert len(model.discard_spikes[0].indices) == 3 + + # Test 
invalid unit ID + invalid_unit_id = { + "format_version": "3", + "unit_ids": [1, 2, 3], + "discard_spikes": [{"unit_id": 4, "indices": [0, 1]}], # Non-existent unit + } + with pytest.raises(ValidationError): + CurationModel(**invalid_unit_id) + + # Test removed units def test_removed_units(): valid_remove = {"format_version": "2", "unit_ids": [1, 2, 3], "removed": [2]} @@ -252,7 +280,7 @@ def test_complete_model(): } model = CurationModel(**complete_model) - assert model.format_version == "2" + assert model.format_version == "3" assert len(model.unit_ids) == 5 assert len(model.label_definitions) == 2 assert len(model.manual_labels) == 1 @@ -275,7 +303,7 @@ def test_complete_model(): } model = CurationModel(**complete_model_dict) - assert model.format_version == "2" + assert model.format_version == "3" assert len(model.unit_ids) == 5 assert len(model.label_definitions) == 2 assert len(model.manual_labels) == 1 diff --git a/src/spikeinterface/curation/tests/test_discard_spikes.py b/src/spikeinterface/curation/tests/test_discard_spikes.py new file mode 100644 index 0000000000..0d1b044ccb --- /dev/null +++ b/src/spikeinterface/curation/tests/test_discard_spikes.py @@ -0,0 +1,413 @@ +import numpy as np +from spikeinterface.core import NumpySorting +from spikeinterface.curation import apply_curation +from numpy.random import default_rng + + +def test_discard_and_split(): + rng = default_rng() + spike_indices = [ + { + 0: np.sort(rng.choice(100, size=20, replace=False)), + 1: np.arange(17), + 2: np.arange(17) + 5, + 4: np.concatenate([np.arange(10), np.arange(20, 30)]), + 5: np.arange(9), + }, + {0: np.arange(15), 1: np.arange(17), 2: np.arange(40, 140), 4: np.arange(40, 140), 5: np.arange(40, 140)}, + ] + original_sort = NumpySorting.from_unit_dict(spike_indices, sampling_frequency=1000) # to have 1 sample=1ms + + discard_spikes = [4, 11, 13, 22, 30] + discard_spikes_segs = [[4, 11, 13], np.array([22, 30]) - len(spike_indices[0][0])] + + discard_spikes_curation = { + "format_version": "3", + "unit_ids": original_sort.unit_ids, + "discard_spikes": [ + { + "unit_id": 0, + "indices": discard_spikes, + } + ], + } + + curated_sort = apply_curation(original_sort, discard_spikes_curation) + + for segment_index in [0, 1]: + + original_spike_train = original_sort.get_unit_spike_train(unit_id=0, segment_index=segment_index) + curated_spike_train = curated_sort.get_unit_spike_train(unit_id=0, segment_index=segment_index) + + discard_spike_times = set(original_spike_train).difference(set(curated_spike_train)) + spike_times_of_discarded_spikes = original_spike_train[discard_spikes_segs[segment_index]] + + assert set(discard_spike_times) == set(spike_times_of_discarded_spikes) + + split_indices = [2, 4, 5, 10, 13, 18, 22, 31, 33] + split_spikes_segs_50 = [ + [2, 4, 5, 10, 13, 18], + [22 - len(spike_indices[0][0]), 31 - len(spike_indices[0][0]), 33 - len(spike_indices[0][0])], + ] + indices_in_51_seg_2 = np.array([20, 21, 23, 24, 25, 26, 27, 28, 29, 30, 32, 34]) + split_spikes_segs_51 = [ + [0, 1, 3, 6, 7, 8, 9, 11, 12, 14, 15, 16, 17, 19], + indices_in_51_seg_2 - len(spike_indices[0][0]), + ] + split_spikes_segs = [split_spikes_segs_50, split_spikes_segs_51] + + discard_spikes_curation_with_split = { + "format_version": "3", + "unit_ids": original_sort.unit_ids, + "discard_spikes": [ + { + "unit_id": 0, + "indices": discard_spikes, + } + ], + "splits": [{"unit_id": 0, "mode": "indices", "indices": [split_indices], "new_unit_ids": [50, 51]}], + } + + # 4 and 13 should be discarded + should_be_discarded_spikes 
= [[[4, 13], [22 - len(spike_indices[0][0])]], [[11], [30 - len(spike_indices[0][0])]]] + + curated_sort_with_split = apply_curation(original_sort, discard_spikes_curation_with_split) + + for new_unit_id, discarded_spikes_each_unit, split_spikes_seg in zip( + [50, 51], should_be_discarded_spikes, split_spikes_segs + ): + for segment_index in [0, 1]: + + original_spike_train = original_sort.get_unit_spike_train(unit_id=0, segment_index=segment_index) + curated_spike_train = curated_sort_with_split.get_unit_spike_train( + unit_id=new_unit_id, segment_index=segment_index + ) + + discard_spike_times = set(original_spike_train[split_spikes_seg[segment_index]]).difference( + set(curated_spike_train) + ) + + spike_times_of_discarded_spikes = original_spike_train[discarded_spikes_each_unit[segment_index]] + + assert set(discard_spike_times) == set(spike_times_of_discarded_spikes) + + # test with "append" unit_id strategy + + split_indices = [2, 4, 5, 10, 13, 18, 22, 31, 33] + split_spikes_segs_50 = [[2, 4, 5, 10, 13, 18], np.array([22, 31, 33]) - len(spike_indices[0][0])] + indices_in_51_seg_2 = np.array([20, 21, 23, 24, 25, 26, 27, 28, 29, 30, 32, 34]) + split_spikes_segs_51 = [ + [0, 1, 3, 6, 7, 8, 9, 11, 12, 14, 15, 16, 17, 19], + indices_in_51_seg_2 - len(spike_indices[0][0]), + ] + split_spikes_segs = [split_spikes_segs_50, split_spikes_segs_51] + + discard_spikes_curation_with_split_append = { + "format_version": "3", + "unit_ids": original_sort.unit_ids, + "discard_spikes": [ + { + "unit_id": 0, + "indices": discard_spikes, + } + ], + "splits": [{"unit_id": 0, "mode": "indices", "indices": [split_indices]}], + } + + # since we're using append, the new units will have id 6 and 7 + curated_sort_with_split_append = apply_curation(original_sort, discard_spikes_curation_with_split_append) + + print(f"{curated_sort_with_split_append.unit_ids=}", flush=True) + + for new_unit_id, discarded_spikes_each_unit, split_spikes_seg in zip( + [6, 7], should_be_discarded_spikes, split_spikes_segs + ): + for segment_index in [0, 1]: + + original_spike_train = original_sort.get_unit_spike_train(unit_id=0, segment_index=segment_index) + curated_spike_train = curated_sort_with_split_append.get_unit_spike_train( + unit_id=new_unit_id, segment_index=segment_index + ) + + discard_spike_times = set(original_spike_train[split_spikes_seg[segment_index]]).difference( + set(curated_spike_train) + ) + + spike_times_of_discarded_spikes = original_spike_train[discarded_spikes_each_unit[segment_index]] + + assert set(discard_spike_times) == set(spike_times_of_discarded_spikes) + + +def test_discard_and_split_string_ids(): + rng = default_rng() + spike_indices = [ + { + "0": np.sort(rng.choice(100, size=20, replace=False)), + "1": np.arange(17), + "2": np.arange(17) + 5, + "4": np.concatenate([np.arange(10), np.arange(20, 30)]), + "5": np.arange(9), + }, + { + "0": np.arange(15), + "1": np.arange(17), + "2": np.arange(40, 140), + "4": np.arange(40, 140), + "5": np.arange(40, 140), + }, + ] + original_sort = NumpySorting.from_unit_dict(spike_indices, sampling_frequency=1000) # to have 1 sample=1ms + + discard_spikes = [4, 11, 13, 22, 30] + discard_spikes_segs = [[4, 11, 13], np.array([22, 30]) - len(spike_indices[0]["0"])] + + discard_spikes_curation = { + "format_version": "3", + "unit_ids": original_sort.unit_ids, + "discard_spikes": [ + { + "unit_id": "0", + "indices": discard_spikes, + } + ], + } + + curated_sort = apply_curation(original_sort, discard_spikes_curation) + + for segment_index in [0, 1]: + + 
original_spike_train = original_sort.get_unit_spike_train(unit_id="0", segment_index=segment_index) + curated_spike_train = curated_sort.get_unit_spike_train(unit_id="0", segment_index=segment_index) + + discard_spike_times = set(original_spike_train).difference(set(curated_spike_train)) + spike_times_of_discarded_spikes = original_spike_train[discard_spikes_segs[segment_index]] + + assert set(discard_spike_times) == set(spike_times_of_discarded_spikes) + + split_indices = [2, 4, 5, 10, 13, 18, 22, 31, 33] + split_spikes_segs_50 = [[2, 4, 5, 10, 13, 18], np.array([22, 31, 33]) - len(spike_indices[0]["0"])] + indices_in_51_seg_2 = np.array([20, 21, 23, 24, 25, 26, 27, 28, 29, 30, 32, 34]) + split_spikes_segs_51 = [ + [0, 1, 3, 6, 7, 8, 9, 11, 12, 14, 15, 16, 17, 19], + indices_in_51_seg_2 - len(spike_indices[0]["0"]), + ] + split_spikes_segs = [split_spikes_segs_50, split_spikes_segs_51] + + discard_spikes_curation_with_split = { + "format_version": "3", + "unit_ids": original_sort.unit_ids, + "discard_spikes": [ + { + "unit_id": "0", + "indices": discard_spikes, + } + ], + "splits": [{"unit_id": "0", "mode": "indices", "indices": [split_indices], "new_unit_ids": ["50", "51"]}], + } + + # 4 and 13 should be discarded + should_be_discarded_spikes = [ + [[4, 13], [22 - len(spike_indices[0]["0"])]], + [[11], [30 - len(spike_indices[0]["0"])]], + ] + + curated_sort_with_split = apply_curation(original_sort, discard_spikes_curation_with_split) + + for new_unit_id, discarded_spikes_each_unit, split_spikes_seg in zip( + ["50", "51"], should_be_discarded_spikes, split_spikes_segs + ): + for segment_index in [0, 1]: + + original_spike_train = original_sort.get_unit_spike_train(unit_id="0", segment_index=segment_index) + curated_spike_train = curated_sort_with_split.get_unit_spike_train( + unit_id=new_unit_id, segment_index=segment_index + ) + + discard_spike_times = set(original_spike_train[split_spikes_seg[segment_index]]).difference( + set(curated_spike_train) + ) + + spike_times_of_discarded_spikes = original_spike_train[discarded_spikes_each_unit[segment_index]] + + assert set(discard_spike_times) == set(spike_times_of_discarded_spikes) + + # test with "append" unit_id strategy + + split_indices = [2, 4, 5, 10, 13, 18, 22, 31, 33] + split_spikes_segs_50 = [[2, 4, 5, 10, 13, 18], np.array([22, 31, 33]) - len(spike_indices[0]["0"])] + indices_in_51_seg_2 = np.array([20, 21, 23, 24, 25, 26, 27, 28, 29, 30, 32, 34]) + split_spikes_segs_51 = [ + [0, 1, 3, 6, 7, 8, 9, 11, 12, 14, 15, 16, 17, 19], + indices_in_51_seg_2 - len(spike_indices[0]["0"]), + ] + split_spikes_segs = [split_spikes_segs_50, split_spikes_segs_51] + + discard_spikes_curation_with_split_append = { + "format_version": "3", + "unit_ids": original_sort.unit_ids, + "discard_spikes": [ + { + "unit_id": "0", + "indices": discard_spikes, + } + ], + "splits": [{"unit_id": "0", "mode": "indices", "indices": [split_indices]}], + } + + # since we're using append, the new units will have id 6 and 7 + curated_sort_with_split_append = apply_curation(original_sort, discard_spikes_curation_with_split_append) + + for new_unit_id, discarded_spikes_each_unit, split_spikes_seg in zip( + ["6", "7"], should_be_discarded_spikes, split_spikes_segs + ): + for segment_index in [0, 1]: + + original_spike_train = original_sort.get_unit_spike_train(unit_id="0", segment_index=segment_index) + curated_spike_train = curated_sort_with_split_append.get_unit_spike_train( + unit_id=new_unit_id, segment_index=segment_index + ) + + discard_spike_times = 
set(original_spike_train[split_spikes_seg[segment_index]]).difference( + set(curated_spike_train) + ) + + spike_times_of_discarded_spikes = original_spike_train[discarded_spikes_each_unit[segment_index]] + + assert set(discard_spike_times) == set(spike_times_of_discarded_spikes) + + # since we're using append, the new units will have id 6 and 7 + curated_sort_with_split_split = apply_curation( + original_sort, discard_spikes_curation_with_split_append, new_id_strategy="split" + ) + + for new_unit_id, discarded_spikes_each_unit, split_spikes_seg in zip( + ["0-0", "0-1"], should_be_discarded_spikes, split_spikes_segs + ): + for segment_index in [0, 1]: + + original_spike_train = original_sort.get_unit_spike_train(unit_id="0", segment_index=segment_index) + curated_spike_train = curated_sort_with_split_split.get_unit_spike_train( + unit_id=new_unit_id, segment_index=segment_index + ) + + discard_spike_times = set(original_spike_train[split_spikes_seg[segment_index]]).difference( + set(curated_spike_train) + ) + + spike_times_of_discarded_spikes = original_spike_train[discarded_spikes_each_unit[segment_index]] + + assert set(discard_spike_times) == set(spike_times_of_discarded_spikes) + + +def test_discard_and_split_several_units(): + + spike_indices = [{unit_id: np.arange(10) for unit_id in [0, 1, 3, 6]}] + original_sort = NumpySorting.from_unit_dict(spike_indices, sampling_frequency=1000) # to have 1 sample=1ms + + discard_spikes = {0: [2, 4, 6, 8], 3: [1, 3, 5, 7, 9]} + remaining_spikes = { + 0: [0, 1, 3, 5, 7, 9], + 3: [0, 2, 4, 6, 8], + } + + split_spikes = {1: [5, 6, 8], 3: [2, 3, 4, 7, 8, 9]} + + expected_new_spikes = { + 0: [0, 1, 3, 5, 7, 9], + 6: np.arange(10), + # 1 should split into 7 and 8 + 7: [5, 6, 8], + 8: [0, 1, 2, 3, 4, 7, 9], + # 3 should split into 9 and 10, without discarded spikes + 9: [2, 4, 8], + 10: [0, 6], + } + + discard_spikes_curation = { + "format_version": "3", + "unit_ids": original_sort.unit_ids, + "splits": [ + {"unit_id": 1, "mode": "indices", "indices": [split_spikes[1]]}, + {"unit_id": 3, "mode": "indices", "indices": [split_spikes[3]]}, + ], + "discard_spikes": [ + { + "unit_id": 0, + "indices": discard_spikes[0], + }, + { + "unit_id": 3, + "indices": discard_spikes[3], + }, + ], + } + + curated_sort = apply_curation(original_sort, discard_spikes_curation) + + for expected_unit_id, spikes in expected_new_spikes.items(): + assert np.all(curated_sort.get_unit_spike_train(unit_id=expected_unit_id) == spikes) + + +def test_discard_and_split_several_units_string_ids(): + + spike_indices = [{unit_id: np.arange(10) for unit_id in ["0", "1", "3", "6"]}] + original_sort = NumpySorting.from_unit_dict(spike_indices, sampling_frequency=1000) # to have 1 sample=1ms + + discard_spikes = {0: [2, 4, 6, 8], 3: [1, 3, 5, 7, 9]} + remaining_spikes = { + 0: [0, 1, 3, 5, 7, 9], + 3: [0, 2, 4, 6, 8], + } + + split_spikes = {1: [5, 6, 8], 3: [2, 3, 4, 7, 8, 9]} + + expected_new_spikes = { + "0": [0, 1, 3, 5, 7, 9], + "6": np.arange(10), + # 1 should split into 7 and 8 + "7": [5, 6, 8], + "8": [0, 1, 2, 3, 4, 7, 9], + # 3 should split into 9 and 10, without discarded spikes + "9": [2, 4, 8], + "10": [0, 6], + } + + discard_spikes_curation = { + "format_version": "3", + "unit_ids": original_sort.unit_ids, + "splits": [ + {"unit_id": "1", "mode": "indices", "indices": [split_spikes[1]]}, + {"unit_id": "3", "mode": "indices", "indices": [split_spikes[3]]}, + ], + "discard_spikes": [ + { + "unit_id": "0", + "indices": discard_spikes[0], + }, + { + "unit_id": "3", + "indices": 
discard_spikes[3], + }, + ], + } + + curated_sort = apply_curation(original_sort, discard_spikes_curation) + + for expected_unit_id, spikes in expected_new_spikes.items(): + assert np.all(curated_sort.get_unit_spike_train(unit_id=expected_unit_id) == spikes) + + expected_new_spikes_split = { + "0": [0, 1, 3, 5, 7, 9], + "6": np.arange(10), + # 1 should split into "1-0" and "1-1" + "1-0": [5, 6, 8], + "1-1": [0, 1, 2, 3, 4, 7, 9], + # 3 should split into "3-0" and "3-1", without discarded spikes + "3-0": [2, 4, 8], + "3-1": [0, 6], + } + + curated_sort_split = apply_curation(original_sort, discard_spikes_curation, new_id_strategy="split") + + for expected_split_unit_id, spikes in expected_new_spikes_split.items(): + assert np.all(curated_sort_split.get_unit_spike_train(unit_id=expected_split_unit_id) == spikes) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index c6b07da52e..5baeb901f0 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -326,10 +326,13 @@ def compute_refrac_period_violations( res = namedtuple("rp_violations", ["rp_contamination", "rp_violations"]) + if unit_ids is None: + unit_ids = sorting_analyzer.unit_ids + if not HAVE_NUMBA: warnings.warn("Error: numba is not installed.") warnings.warn("compute_refrac_period_violations cannot run without numba.") - return None + return {unit_id: np.nan for unit_id in unit_ids} sorting = sorting_analyzer.sorting fs = sorting_analyzer.sampling_frequency @@ -338,9 +341,6 @@ def compute_refrac_period_violations( spikes = sorting.to_spike_vector(concatenated=False) - if unit_ids is None: - unit_ids = sorting_analyzer.unit_ids - num_spikes = compute_num_spikes(sorting_analyzer) t_c = int(round(censored_period_ms * fs * 1e-3)) diff --git a/src/spikeinterface/metrics/template/template_metrics.py b/src/spikeinterface/metrics/template/template_metrics.py index 83a9048a64..87d5a64a1e 100644 --- a/src/spikeinterface/metrics/template/template_metrics.py +++ b/src/spikeinterface/metrics/template/template_metrics.py @@ -244,7 +244,7 @@ def _prepare_data(self, sorting_analyzer, unit_ids): # templates_multi is a list of 2D arrays of shape (n_times, n_channels) tmp_data["templates_multi"] = templates_multi tmp_data["channel_locations_multi"] = channel_locations_multi - tmp_data["depth_direction"] = self.params["depth_direction"] + tmp_data["depth_direction"] = self.params.get("depth_direction") return tmp_data
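To make the expected interplay of ``splits`` and ``discard_spikes`` concrete, here is a minimal end-to-end sketch distilled from the tests above; the toy sorting and the resulting unit ids mirror ``test_discard_and_split_several_units`` and are illustrative only.

.. code-block:: python

    import numpy as np
    from spikeinterface.core import NumpySorting
    from spikeinterface.curation import apply_curation

    # toy sorting with four units of 10 spikes each (as in the tests above)
    spike_indices = [{unit_id: np.arange(10) for unit_id in [0, 1, 3, 6]}]
    sorting = NumpySorting.from_unit_dict(spike_indices, sampling_frequency=1000)

    curation = {
        "format_version": "3",
        "unit_ids": sorting.unit_ids,
        "splits": [{"unit_id": 1, "mode": "indices", "indices": [[5, 6, 8]]}],
        "discard_spikes": [{"unit_id": 0, "indices": [2, 4, 6, 8]}],
    }

    curated = apply_curation(sorting, curation)
    # With the default new_id_strategy="append":
    #   unit 0 keeps its id and simply loses spikes 2, 4, 6 and 8,
    #   unit 1 is replaced by appended units 7 and 8 ([5, 6, 8] and the remaining spikes),
    #   units 3 and 6 are untouched.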