Skip to content

Commit 30712de

Browse files
made ntop always flexible (i.e., not only when ntop >= B.shape[1])
1 parent 5a12efb commit 30712de

9 files changed

+559
-503
lines changed

sparse_dot_topn/awesome_cossim_topn.py

+213-228
Large diffs are not rendered by default.

sparse_dot_topn/sparse_dot_topn.pyx

+266-249
Large diffs are not rendered by default.

sparse_dot_topn/sparse_dot_topn_parallel.cpp

+20-7
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,7 @@ void sparse_dot_topn_extd_parallel(
419419
void inner_sparse_dot_free(
420420
job_range_type job_range,
421421
int n_col_inner,
422+
int ntop_inner,
422423
double lower_bound_inner,
423424
int Ap_copy[],
424425
int Aj_copy[],
@@ -485,18 +486,29 @@ void inner_sparse_dot_free(
485486
}
486487

487488
int len = (int) (real_candidates->size() - sz);
489+
*n_minmax = (len > *n_minmax)? len : *n_minmax;
488490

489491
candidate* candidate_arr_begin = real_candidates->data() + sz;
490-
std::sort(
491-
candidate_arr_begin,
492-
candidate_arr_begin + len,
493-
candidate_cmp
494-
);
492+
if (len > ntop_inner){
493+
std::partial_sort(
494+
candidate_arr_begin,
495+
candidate_arr_begin + ntop_inner,
496+
candidate_arr_begin + len,
497+
candidate_cmp
498+
);
499+
len = ntop_inner;
500+
}
501+
else {
502+
std::sort(
503+
candidate_arr_begin,
504+
candidate_arr_begin + len,
505+
candidate_cmp
506+
);
507+
}
495508

496509
real_candidates->resize(sz + (size_t) len);
497510
*(row_sizes_ptr++) = len;
498511
(*total) += len;
499-
*n_minmax = (len > *n_minmax)? len : *n_minmax;
500512
}
501513
real_candidates->shrink_to_fit();
502514
}
@@ -510,6 +522,7 @@ void sparse_dot_free_parallel(
510522
int Bp[],
511523
int Bj[],
512524
double Bx[], //data of B
525+
int ntop,
513526
double lower_bound,
514527
int Cp[],
515528
std::vector<int>* vCj,
@@ -536,7 +549,7 @@ void sparse_dot_free_parallel(
536549
inner_sparse_dot_free,
537550
job_ranges[job_nr],
538551
n_col,
539-
lower_bound,
552+
ntop, lower_bound,
540553
Ap, Aj, Ax, Bp, Bj, Bx,
541554
&real_candidates[job_nr],
542555
&row_sizes[job_nr],

sparse_dot_topn/sparse_dot_topn_parallel.h

+1
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ extern void sparse_dot_free_parallel(
6767
int Bp[],
6868
int Bj[],
6969
double Bx[], //data of B
70+
int ntop,
7071
double lower_bound,
7172
int Cp[],
7273
std::vector<int>* Cj,

sparse_dot_topn/sparse_dot_topn_source.cpp

+13-7
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,9 @@ void sparse_dot_topn_extd_source(
250250
C++ implementation of sparse_dot_free_source
251251
252252
This function will return a matrix C in CSR format, where
253-
C = [all results > lower_bound sorted for each row of A * B].
254-
It also returns the maximum number of elements per row of C.
253+
C = [sorted top n results > lower_bound for each row of A * B].
254+
The maximum number n_minmax of elements per row of C (assuming ntop = n_col)
255+
is also returned.
255256
256257
Input:
257258
n_row: number of rows of A matrix
@@ -260,7 +261,7 @@ void sparse_dot_topn_extd_source(
260261
Ap, Aj, Ax: CSR expression of A matrix
261262
Bp, Bj, Bx: CSR expression of B matrix
262263
263-
memory_bound: the maximum number of elements per row of C
264+
ntop: maximum number of top results kept per row of C
264265
lower_bound: a threshold that each element of A*B must be greater than
265266
266267
Output by reference:
@@ -280,6 +281,7 @@ void sparse_dot_free_source(
280281
int Bp[],
281282
int Bj[],
282283
double Bx[], //data of B
284+
int ntop,
283285
double lower_bound,
284286
int Cp[],
285287
std::vector<int>* Cj,
@@ -342,18 +344,22 @@ void sparse_dot_free_source(
342344

343345
int len = (int)candidates.size();
344346
*n_minmax = (len > *n_minmax)? len : *n_minmax;
345-
std::sort(candidates.begin(), candidates.end(), candidate_cmp);
347+
348+
if (len > ntop){
349+
std::partial_sort(candidates.begin(), candidates.begin()+ntop, candidates.end(), candidate_cmp);
350+
len = ntop;
351+
} else {
352+
std::sort(candidates.begin(), candidates.end(), candidate_cmp);
353+
}
346354

347355
for(int a=0; a < len; a++){
348356
Cj->push_back(candidates[a].index);
349357
Cx->push_back(candidates[a].value);
350358
}
351359
candidates.clear();
352360

353-
Cp[i+1] = (int) (Cj->size());
361+
Cp[i+1] = Cj->size();
354362
}
355-
Cj->shrink_to_fit();
356-
Cx->shrink_to_fit();
357363
}
358364

359365
/*

sparse_dot_topn/sparse_dot_topn_source.h

+1
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ extern void sparse_dot_free_source(
7070
int Bp[],
7171
int Bj[],
7272
double Bx[], //data of B
73+
int ntop,
7374
double lower_bound,
7475
int Cp[],
7576
std::vector<int>* Cj,

sparse_dot_topn/sparse_dot_topn_threaded.pyx

+3-1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ cdef extern from "sparse_dot_topn_parallel.h":
7575
int Bp[],
7676
int Bj[],
7777
double Bx[],
78+
int ntop,
7879
double lower_bound,
7980
int Cp[],
8081
vector[int]* Cj,
@@ -167,6 +168,7 @@ cpdef sparse_dot_free_threaded(
167168
np.ndarray[int, ndim=1] b_indptr,
168169
np.ndarray[int, ndim=1] b_indices,
169170
np.ndarray[double, ndim=1] b_data,
171+
int ntop,
170172
double lower_bound,
171173
np.ndarray[int, ndim=1] c_indptr,
172174
int n_jobs
@@ -185,7 +187,7 @@ cpdef sparse_dot_free_threaded(
185187
cdef vector[int] vCj;
186188
cdef vector[double] vCx;
187189

188-
sparse_dot_free_parallel(n_row, n_col, Ap, Aj, Ax, Bp, Bj, Bx, lower_bound, Cp, &vCj, &vCx, n_minmax, n_jobs)
190+
sparse_dot_free_parallel(n_row, n_col, Ap, Aj, Ax, Bp, Bj, Bx, ntop, lower_bound, Cp, &vCj, &vCx, n_minmax, n_jobs)
189191

190192
c_indices = np.asarray(ArrayWrapper_int(vCj)).squeeze(axis=0)
191193
c_data = np.asarray(ArrayWrapper_double(vCx)).squeeze(axis=0)

sparse_dot_topn/test/test_awesome_cossim_topn.py

+36-8
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,15 @@ def helper_awesome_cossim_topn_dense(
6262
use_threads=use_threads,
6363
n_jobs=n_jobs
6464
)
65-
awesome_result_top3 = \
66-
awesome_cossim_topn(a_csr, b_csr_t, NUM_CANDIDATES, 0.0, use_threads=use_threads, n_jobs=n_jobs)
65+
awesome_result_top3 = awesome_cossim_topn(
66+
a_csr,
67+
b_csr_t,
68+
NUM_CANDIDATES,
69+
0.0,
70+
mem_manager_is_C=mem_manager_is_C,
71+
use_threads=use_threads,
72+
n_jobs=n_jobs
73+
)
6774
awesome_result_top3 = [list(zip(row.indices, row.data)) if len(
6875
row.data) > 0 else None for row in awesome_result_top3] # make comparable, normally not needed
6976

@@ -76,8 +83,15 @@ def helper_awesome_cossim_topn_dense(
7683
use_threads=use_threads,
7784
n_jobs=n_jobs
7885
)
79-
pruned_awesome_result_top3 = \
80-
awesome_cossim_topn(a_csr, b_csr_t, NUM_CANDIDATES, PRUNE_THRESHOLD, use_threads=use_threads, n_jobs=n_jobs)
86+
pruned_awesome_result_top3 = awesome_cossim_topn(
87+
a_csr,
88+
b_csr_t,
89+
NUM_CANDIDATES,
90+
PRUNE_THRESHOLD,
91+
mem_manager_is_C=mem_manager_is_C,
92+
use_threads=use_threads,
93+
n_jobs=n_jobs
94+
)
8195
pruned_awesome_result_top3 = [list(zip(row.indices, row.data)) if len(
8296
row.data) > 0 else None for row in pruned_awesome_result_top3]
8397

@@ -131,8 +145,15 @@ def helper_awesome_cossim_topn_sparse(
131145
use_threads=use_threads,
132146
n_jobs=n_jobs
133147
)
134-
awesome_result_top3 = \
135-
awesome_cossim_topn(a_csr, b_csr_t, NUM_CANDIDATES, 0.0, use_threads=use_threads, n_jobs=n_jobs)
148+
awesome_result_top3 = awesome_cossim_topn(
149+
a_csr,
150+
b_csr_t,
151+
NUM_CANDIDATES,
152+
0.0,
153+
mem_manager_is_C=mem_manager_is_C,
154+
use_threads=use_threads,
155+
n_jobs=n_jobs
156+
)
136157
awesome_result_top3 = [list(zip(row.indices, row.data)) if len(
137158
row.data) > 0 else None for row in awesome_result_top3] # make comparable, normally not needed
138159

@@ -145,8 +166,15 @@ def helper_awesome_cossim_topn_sparse(
145166
use_threads=use_threads,
146167
n_jobs=n_jobs
147168
)
148-
pruned_awesome_result_top3 = \
149-
awesome_cossim_topn(a_csr, b_csr_t, NUM_CANDIDATES, PRUNE_THRESHOLD, use_threads=use_threads, n_jobs=n_jobs)
169+
pruned_awesome_result_top3 = awesome_cossim_topn(
170+
a_csr,
171+
b_csr_t,
172+
NUM_CANDIDATES,
173+
PRUNE_THRESHOLD,
174+
mem_manager_is_C=mem_manager_is_C,
175+
use_threads=use_threads,
176+
n_jobs=n_jobs
177+
)
150178
pruned_awesome_result_top3 = [list(zip(row.indices, row.data)) if len(
151179
row.data) > 0 else None for row in pruned_awesome_result_top3]
152180

string_grouper/string_grouper.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -218,9 +218,13 @@ def __init__(self, master: pd.Series,
218218
self._duplicates: pd.Series = duplicates if duplicates is not None else None
219219
self._master_id: pd.Series = master_id if master_id is not None else None
220220
self._duplicates_id: pd.Series = duplicates_id if duplicates_id is not None else None
221+
221222
self._config: StringGrouperConfig = StringGrouperConfig(**kwargs)
222-
self._max_n_matches = len(self._master) if self._config.max_n_matches is None \
223-
else self._config.max_n_matches
223+
if self._config.max_n_matches is None:
224+
self._max_n_matches = len(self._master) if self._duplicates is None else len(self._duplicates)
225+
else:
226+
self._max_n_matches = self._config.max_n_matches
227+
224228
self._validate_group_rep_specs()
225229
self._validate_replace_na_and_drop()
226230
self.is_build = False # indicates if the grouper was fit or not
@@ -435,7 +439,6 @@ def _build_matches(self, master_matrix: csr_matrix, duplicate_matrix: csr_matrix
435439
optional_kwargs = dict()
436440
if self._config.number_of_processes > 1:
437441
optional_kwargs = {
438-
'ntop_is_flexible': self._config.max_n_matches is None,
439442
'return_best_topn': True,
440443
'use_threads': True,
441444
'n_jobs': self._config.number_of_processes

0 commit comments

Comments
 (0)