Skip to content

Commit 9efde62

Browse files
Patch for memory usage in template_similarity (#4152)
* WIP * WIP * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixing tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * WIP * WIP * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * WIP * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * WIP * WIP * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 813d964 commit 9efde62

File tree

3 files changed

+73
-31
lines changed

3 files changed

+73
-31
lines changed

src/spikeinterface/postprocessing/template_similarity.py

Lines changed: 56 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,9 @@ def _get_data(self):
208208
compute_template_similarity = ComputeTemplateSimilarity.function_factory()
209209

210210

211-
def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num_shifts, mask, method):
211+
def _compute_similarity_matrix_numpy(
212+
templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support="union"
213+
):
212214

213215
num_templates = templates_array.shape[0]
214216
num_samples = templates_array.shape[1]
@@ -232,15 +234,16 @@ def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num
232234
tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
233235
for i in range(num_templates):
234236
src_template = src_sliced_templates[i]
235-
overlapping_templates = np.flatnonzero(np.sum(mask[i], 1))
237+
local_mask = get_overlapping_mask_for_one_template(i, sparsity_mask, other_sparsity_mask, support=support)
238+
overlapping_templates = np.flatnonzero(np.sum(local_mask, 1))
236239
tgt_templates = tgt_sliced_templates[overlapping_templates]
237240
for gcount, j in enumerate(overlapping_templates):
238241
# symmetric values are handled later
239242
if same_array and j < i:
240243
# no need exhaustive looping when same template
241244
continue
242-
src = src_template[:, mask[i, j]].reshape(1, -1)
243-
tgt = (tgt_templates[gcount][:, mask[i, j]]).reshape(1, -1)
245+
src = src_template[:, local_mask[j]].reshape(1, -1)
246+
tgt = (tgt_templates[gcount][:, local_mask[j]]).reshape(1, -1)
244247

245248
if method == "l1":
246249
norm_i = np.sum(np.abs(src))
@@ -273,9 +276,12 @@ def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num
273276
import numba
274277

275278
@numba.jit(nopython=True, parallel=True, fastmath=True, nogil=True)
276-
def _compute_similarity_matrix_numba(templates_array, other_templates_array, num_shifts, mask, method):
279+
def _compute_similarity_matrix_numba(
280+
templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support="union"
281+
):
277282
num_templates = templates_array.shape[0]
278283
num_samples = templates_array.shape[1]
284+
num_channels = templates_array.shape[2]
279285
other_num_templates = other_templates_array.shape[0]
280286

281287
num_shifts_both_sides = 2 * num_shifts + 1
@@ -284,7 +290,6 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num
284290

285291
# We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t
286292
# So the matrix can be computed only for negative lags and be transposed
287-
288293
if same_array:
289294
# optimisation when array are the same because of symetry in shift
290295
shift_loop = list(range(-num_shifts, 1))
@@ -304,7 +309,23 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num
304309
tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
305310
for i in numba.prange(num_templates):
306311
src_template = src_sliced_templates[i]
307-
overlapping_templates = np.flatnonzero(np.sum(mask[i], 1))
312+
313+
## Ideally we would like to use this but numba does not support well function with numpy and boolean arrays
314+
## So we inline the function here
315+
# local_mask = get_overlapping_mask_for_one_template(i, sparsity, other_sparsity, support=support)
316+
317+
if support == "intersection":
318+
local_mask = np.logical_and(
319+
sparsity_mask[i, :], other_sparsity_mask
320+
) # shape (other_num_templates, num_channels)
321+
elif support == "union":
322+
local_mask = np.logical_or(
323+
sparsity_mask[i, :], other_sparsity_mask
324+
) # shape (other_num_templates, num_channels)
325+
elif support == "dense":
326+
local_mask = np.ones((other_num_templates, num_channels), dtype=np.bool_)
327+
328+
overlapping_templates = np.flatnonzero(np.sum(local_mask, 1))
308329
tgt_templates = tgt_sliced_templates[overlapping_templates]
309330
for gcount in range(len(overlapping_templates)):
310331

@@ -313,8 +334,8 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num
313334
if same_array and j < i:
314335
# no need exhaustive looping when same template
315336
continue
316-
src = src_template[:, mask[i, j]].flatten()
317-
tgt = (tgt_templates[gcount][:, mask[i, j]]).flatten()
337+
src = src_template[:, local_mask[j]].flatten()
338+
tgt = (tgt_templates[gcount][:, local_mask[j]]).flatten()
318339

319340
norm_i = 0
320341
norm_j = 0
@@ -360,6 +381,17 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num
360381
_compute_similarity_matrix = _compute_similarity_matrix_numpy
361382

362383

384+
def get_overlapping_mask_for_one_template(template_index, sparsity, other_sparsity, support="union") -> np.ndarray:
385+
386+
if support == "intersection":
387+
mask = np.logical_and(sparsity[template_index, :], other_sparsity) # shape (other_num_templates, num_channels)
388+
elif support == "union":
389+
mask = np.logical_or(sparsity[template_index, :], other_sparsity) # shape (other_num_templates, num_channels)
390+
elif support == "dense":
391+
mask = np.ones(other_sparsity.shape, dtype=bool)
392+
return mask
393+
394+
363395
def compute_similarity_with_templates_array(
364396
templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None
365397
):
@@ -369,6 +401,8 @@ def compute_similarity_with_templates_array(
369401

370402
all_metrics = ["cosine", "l1", "l2"]
371403

404+
assert support in ["dense", "union", "intersection"], "support should be either dense, union or intersection"
405+
372406
if method not in all_metrics:
373407
raise ValueError(f"compute_template_similarity (method {method}) not exists")
374408

@@ -378,29 +412,25 @@ def compute_similarity_with_templates_array(
378412
assert (
379413
templates_array.shape[2] == other_templates_array.shape[2]
380414
), "The number of channels in the templates should be the same for both arrays"
381-
num_templates = templates_array.shape[0]
415+
# num_templates = templates_array.shape[0]
382416
num_samples = templates_array.shape[1]
383-
num_channels = templates_array.shape[2]
384-
other_num_templates = other_templates_array.shape[0]
385-
386-
mask = np.ones((num_templates, other_num_templates, num_channels), dtype=bool)
417+
# num_channels = templates_array.shape[2]
418+
# other_num_templates = other_templates_array.shape[0]
387419

388-
if sparsity is not None and other_sparsity is not None:
389-
390-
# make the input more flexible with either The object or the array mask
420+
if sparsity is not None:
391421
sparsity_mask = sparsity.mask if isinstance(sparsity, ChannelSparsity) else sparsity
392-
other_sparsity_mask = other_sparsity.mask if isinstance(other_sparsity, ChannelSparsity) else other_sparsity
422+
else:
423+
sparsity_mask = np.ones((templates_array.shape[0], templates_array.shape[2]), dtype=bool)
393424

394-
if support == "intersection":
395-
mask = np.logical_and(sparsity_mask[:, np.newaxis, :], other_sparsity_mask[np.newaxis, :, :])
396-
elif support == "union":
397-
mask = np.logical_and(sparsity_mask[:, np.newaxis, :], other_sparsity_mask[np.newaxis, :, :])
398-
units_overlaps = np.sum(mask, axis=2) > 0
399-
mask = np.logical_or(sparsity_mask[:, np.newaxis, :], other_sparsity_mask[np.newaxis, :, :])
400-
mask[~units_overlaps] = False
425+
if other_sparsity is not None:
426+
other_sparsity_mask = other_sparsity.mask if isinstance(other_sparsity, ChannelSparsity) else other_sparsity
427+
else:
428+
other_sparsity_mask = np.ones((other_templates_array.shape[0], other_templates_array.shape[2]), dtype=bool)
401429

402430
assert num_shifts < num_samples, "max_lag is too large"
403-
distances = _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, mask, method)
431+
distances = _compute_similarity_matrix(
432+
templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support=support
433+
)
404434

405435
distances = np.min(distances, axis=0)
406436
similarity = 1 - distances

src/spikeinterface/postprocessing/tests/test_template_similarity.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,23 @@ def test_equal_results_numba(params):
107107
rng = np.random.default_rng(seed=2205)
108108
templates_array = rng.random(size=(4, 20, 5), dtype=np.float32)
109109
other_templates_array = rng.random(size=(2, 20, 5), dtype=np.float32)
110-
mask = np.ones((4, 2, 5), dtype=bool)
111-
112-
result_numpy = _compute_similarity_matrix_numba(templates_array, other_templates_array, mask=mask, **params)
113-
result_numba = _compute_similarity_matrix_numpy(templates_array, other_templates_array, mask=mask, **params)
110+
sparsity_mask = np.ones((4, 5), dtype=bool)
111+
other_sparsity_mask = np.ones((2, 5), dtype=bool)
112+
113+
result_numpy = _compute_similarity_matrix_numba(
114+
templates_array,
115+
other_templates_array,
116+
sparsity_mask=sparsity_mask,
117+
other_sparsity_mask=other_sparsity_mask,
118+
**params,
119+
)
120+
result_numba = _compute_similarity_matrix_numpy(
121+
templates_array,
122+
other_templates_array,
123+
sparsity_mask=sparsity_mask,
124+
other_sparsity_mask=other_sparsity_mask,
125+
**params,
126+
)
114127

115128
assert np.allclose(result_numpy, result_numba, 1e-3)
116129

src/spikeinterface/sortingcomponents/clustering/merging_tools.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,6 @@ def merge_peak_labels_from_templates(
541541
assert len(unit_ids) == templates_array.shape[0]
542542

543543
from spikeinterface.postprocessing.template_similarity import compute_similarity_with_templates_array
544-
from scipy.sparse.csgraph import connected_components
545544

546545
similarity = compute_similarity_with_templates_array(
547546
templates_array,

0 commit comments

Comments
 (0)