Skip to content

Commit 45032c6

Browse files
committed
Fix tests
1 parent 89dffa1 commit 45032c6

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

src/spikeinterface/core/sorting_tools.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -940,10 +940,9 @@ def remap_unit_indices_in_vector(vector, all_old_unit_ids, all_new_unit_ids, kee
940940
continue
941941
if old_unit_id in all_new_unit_ids:
942942
new_unit_index = all_new_unit_ids.index(old_unit_id)
943-
if new_unit_index[new_unit_index]:
944-
mapping[old_unit_ind] = new_unit_index
945-
keep[old_unit_ind] = True
946-
keep_mask = vector["unit_index"][keep]
943+
mapping[old_unit_ind] = new_unit_index
944+
keep[old_unit_ind] = True
945+
keep_mask = keep[vector["unit_index"]]
947946
new_vector = vector[keep_mask]
948947
new_vector["unit_index"] = mapping[new_vector["unit_index"]]
949948

src/spikeinterface/postprocessing/valid_unit_periods.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -418,16 +418,16 @@ def _compute_valid_periods(self, sorting_analyzer, unit_ids=None, **job_kwargs):
418418
all_fps[np.isnan(all_fps)] = 1.0
419419
all_fns[np.isnan(all_fns)] = 1.0
420420

421-
good_period_mask = (all_fps < self.params["fp_threshold"]) & (all_fns < self.params["fn_threshold"])
422-
good_periods = all_periods[good_period_mask]
421+
valid_period_mask = (all_fps < self.params["fp_threshold"]) & (all_fns < self.params["fn_threshold"])
422+
valid_unit_periods = all_periods[valid_period_mask]
423423

424424
# Combine with user-defined periods if provided
425425
if self.params["method"] == "combined":
426426
user_defined_periods = self.user_defined_periods
427427
valid_unit_periods = np.concatenate((valid_unit_periods, user_defined_periods), axis=0)
428428

429429
# Sort good periods on segment_index, unit_index, start_sample_index
430-
valid_unit_periods, _ = self._sort_periods(good_periods)
430+
valid_unit_periods, _ = self._sort_periods(valid_unit_periods)
431431
valid_unit_periods = merge_overlapping_periods_across_units_and_segments(valid_unit_periods)
432432

433433
# Remove good periods that are too short

0 commit comments

Comments
 (0)