@@ -208,7 +208,9 @@ def _get_data(self):
208208compute_template_similarity = ComputeTemplateSimilarity .function_factory ()
209209
210210
211- def _compute_similarity_matrix_numpy (templates_array , other_templates_array , num_shifts , mask , method ):
211+ def _compute_similarity_matrix_numpy (
212+ templates_array , other_templates_array , num_shifts , method , sparsity_mask , other_sparsity_mask , support = "union"
213+ ):
212214
213215 num_templates = templates_array .shape [0 ]
214216 num_samples = templates_array .shape [1 ]
@@ -232,15 +234,16 @@ def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num
232234 tgt_sliced_templates = other_templates_array [:, num_shifts + shift : num_samples - num_shifts + shift ]
233235 for i in range (num_templates ):
234236 src_template = src_sliced_templates [i ]
235- overlapping_templates = np .flatnonzero (np .sum (mask [i ], 1 ))
237+ local_mask = get_overlapping_mask_for_one_template (i , sparsity_mask , other_sparsity_mask , support = support )
238+ overlapping_templates = np .flatnonzero (np .sum (local_mask , 1 ))
236239 tgt_templates = tgt_sliced_templates [overlapping_templates ]
237240 for gcount , j in enumerate (overlapping_templates ):
238241 # symmetric values are handled later
239242 if same_array and j < i :
240243 # no need exhaustive looping when same template
241244 continue
242- src = src_template [:, mask [ i , j ]].reshape (1 , - 1 )
243- tgt = (tgt_templates [gcount ][:, mask [ i , j ]]).reshape (1 , - 1 )
245+ src = src_template [:, local_mask [ j ]].reshape (1 , - 1 )
246+ tgt = (tgt_templates [gcount ][:, local_mask [ j ]]).reshape (1 , - 1 )
244247
245248 if method == "l1" :
246249 norm_i = np .sum (np .abs (src ))
@@ -273,9 +276,12 @@ def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num
273276 import numba
274277
275278 @numba .jit (nopython = True , parallel = True , fastmath = True , nogil = True )
276- def _compute_similarity_matrix_numba (templates_array , other_templates_array , num_shifts , mask , method ):
279+ def _compute_similarity_matrix_numba (
280+ templates_array , other_templates_array , num_shifts , method , sparsity_mask , other_sparsity_mask , support = "union"
281+ ):
277282 num_templates = templates_array .shape [0 ]
278283 num_samples = templates_array .shape [1 ]
284+ num_channels = templates_array .shape [2 ]
279285 other_num_templates = other_templates_array .shape [0 ]
280286
281287 num_shifts_both_sides = 2 * num_shifts + 1
@@ -284,7 +290,6 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num
284290
285291 # We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t
286292 # So the matrix can be computed only for negative lags and be transposed
287-
288293 if same_array :
289294 # optimisation when array are the same because of symetry in shift
290295 shift_loop = list (range (- num_shifts , 1 ))
@@ -304,7 +309,23 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num
304309 tgt_sliced_templates = other_templates_array [:, num_shifts + shift : num_samples - num_shifts + shift ]
305310 for i in numba .prange (num_templates ):
306311 src_template = src_sliced_templates [i ]
307- overlapping_templates = np .flatnonzero (np .sum (mask [i ], 1 ))
312+
313+ ## Ideally we would like to use this but numba does not support well function with numpy and boolean arrays
314+ ## So we inline the function here
315+ # local_mask = get_overlapping_mask_for_one_template(i, sparsity, other_sparsity, support=support)
316+
317+ if support == "intersection" :
318+ local_mask = np .logical_and (
319+ sparsity_mask [i , :], other_sparsity_mask
320+ ) # shape (other_num_templates, num_channels)
321+ elif support == "union" :
322+ local_mask = np .logical_or (
323+ sparsity_mask [i , :], other_sparsity_mask
324+ ) # shape (other_num_templates, num_channels)
325+ elif support == "dense" :
326+ local_mask = np .ones ((other_num_templates , num_channels ), dtype = np .bool_ )
327+
328+ overlapping_templates = np .flatnonzero (np .sum (local_mask , 1 ))
308329 tgt_templates = tgt_sliced_templates [overlapping_templates ]
309330 for gcount in range (len (overlapping_templates )):
310331
@@ -313,8 +334,8 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num
313334 if same_array and j < i :
314335 # no need exhaustive looping when same template
315336 continue
316- src = src_template [:, mask [ i , j ]].flatten ()
317- tgt = (tgt_templates [gcount ][:, mask [ i , j ]]).flatten ()
337+ src = src_template [:, local_mask [ j ]].flatten ()
338+ tgt = (tgt_templates [gcount ][:, local_mask [ j ]]).flatten ()
318339
319340 norm_i = 0
320341 norm_j = 0
@@ -360,6 +381,17 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num
360381 _compute_similarity_matrix = _compute_similarity_matrix_numpy
361382
362383
384+ def get_overlapping_mask_for_one_template (template_index , sparsity , other_sparsity , support = "union" ) -> np .ndarray :
385+
386+ if support == "intersection" :
387+ mask = np .logical_and (sparsity [template_index , :], other_sparsity ) # shape (other_num_templates, num_channels)
388+ elif support == "union" :
389+ mask = np .logical_or (sparsity [template_index , :], other_sparsity ) # shape (other_num_templates, num_channels)
390+ elif support == "dense" :
391+ mask = np .ones (other_sparsity .shape , dtype = bool )
392+ return mask
393+
394+
363395def compute_similarity_with_templates_array (
364396 templates_array , other_templates_array , method , support = "union" , num_shifts = 0 , sparsity = None , other_sparsity = None
365397):
@@ -369,6 +401,8 @@ def compute_similarity_with_templates_array(
369401
370402 all_metrics = ["cosine" , "l1" , "l2" ]
371403
404+ assert support in ["dense" , "union" , "intersection" ], "support should be either dense, union or intersection"
405+
372406 if method not in all_metrics :
373407 raise ValueError (f"compute_template_similarity (method { method } ) not exists" )
374408
@@ -378,29 +412,25 @@ def compute_similarity_with_templates_array(
378412 assert (
379413 templates_array .shape [2 ] == other_templates_array .shape [2 ]
380414 ), "The number of channels in the templates should be the same for both arrays"
381- num_templates = templates_array .shape [0 ]
415+ # num_templates = templates_array.shape[0]
382416 num_samples = templates_array .shape [1 ]
383- num_channels = templates_array .shape [2 ]
384- other_num_templates = other_templates_array .shape [0 ]
385-
386- mask = np .ones ((num_templates , other_num_templates , num_channels ), dtype = bool )
417+ # num_channels = templates_array.shape[2]
418+ # other_num_templates = other_templates_array.shape[0]
387419
388- if sparsity is not None and other_sparsity is not None :
389-
390- # make the input more flexible with either The object or the array mask
420+ if sparsity is not None :
391421 sparsity_mask = sparsity .mask if isinstance (sparsity , ChannelSparsity ) else sparsity
392- other_sparsity_mask = other_sparsity .mask if isinstance (other_sparsity , ChannelSparsity ) else other_sparsity
422+ else :
423+ sparsity_mask = np .ones ((templates_array .shape [0 ], templates_array .shape [2 ]), dtype = bool )
393424
394- if support == "intersection" :
395- mask = np .logical_and (sparsity_mask [:, np .newaxis , :], other_sparsity_mask [np .newaxis , :, :])
396- elif support == "union" :
397- mask = np .logical_and (sparsity_mask [:, np .newaxis , :], other_sparsity_mask [np .newaxis , :, :])
398- units_overlaps = np .sum (mask , axis = 2 ) > 0
399- mask = np .logical_or (sparsity_mask [:, np .newaxis , :], other_sparsity_mask [np .newaxis , :, :])
400- mask [~ units_overlaps ] = False
425+ if other_sparsity is not None :
426+ other_sparsity_mask = other_sparsity .mask if isinstance (other_sparsity , ChannelSparsity ) else other_sparsity
427+ else :
428+ other_sparsity_mask = np .ones ((other_templates_array .shape [0 ], other_templates_array .shape [2 ]), dtype = bool )
401429
402430 assert num_shifts < num_samples , "max_lag is too large"
403- distances = _compute_similarity_matrix (templates_array , other_templates_array , num_shifts , mask , method )
431+ distances = _compute_similarity_matrix (
432+ templates_array , other_templates_array , num_shifts , method , sparsity_mask , other_sparsity_mask , support = support
433+ )
404434
405435 distances = np .min (distances , axis = 0 )
406436 similarity = 1 - distances
0 commit comments