Description
NumPy provides an indirect way to compute the indices of the k smallest (or largest) values of an array: numpy.argpartition.
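As a minimal sketch of that indirect approach (the array `x` and the value of `k` are made up for illustration), partitioning around position `k - 1` or `-k` and then slicing yields the wanted indices, in no particular order:

```python
import numpy as np

x = np.array([7, 1, 9, 3, 8, 2])  # illustrative data, not from the issue
k = 3

# After partitioning around position k - 1, the first k entries are the
# indices of the k smallest values (their relative order is unspecified).
smallest_idx = np.argpartition(x, k - 1)[:k]

# Partitioning around position -k puts the indices of the k largest
# values in the last k entries (again in unspecified order).
largest_idx = np.argpartition(x, -k)[-k:]

smallest_vals = x[smallest_idx]
largest_vals = x[largest_idx]
```

Note that the slices are not sorted; an extra sort over the k selected elements is needed if ordered output is required.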
There is also a proposal to provide a higher level API, namely (arg)topk in numpy:
This PR relies on numpy.argpartition internally but it can probably later be optimized to avoid allocating a result array of the size of the input array when k is small.
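To make the idea concrete, here is a rough sketch of what a top-k helper built on numpy.argpartition could look like. The function name `topk`, its signature, and its descending output order are assumptions made for illustration; they are not taken from the NumPy proposal.

```python
import numpy as np

def topk(x, k, axis=-1):
    """Hypothetical helper: (values, indices) of the k largest entries
    along `axis`, sorted in descending order."""
    x = np.asarray(x)
    n = x.shape[axis]
    # argpartition returns an index array as large as the input,
    # even when k is small; this is the allocation mentioned above.
    part = np.argpartition(x, n - k, axis=axis)
    # Keep only the last k positions, which hold the k largest entries.
    idx = np.take(part, np.arange(n - k, n), axis=axis)
    vals = np.take_along_axis(x, idx, axis=axis)
    # Sort just the k selected entries, largest first.
    order = np.flip(np.argsort(vals, axis=axis), axis=axis)
    idx = np.take_along_axis(idx, order, axis=axis)
    vals = np.take_along_axis(vals, order, axis=axis)
    return vals, idx

# Usage on a small 2-D array, selecting along the last axis.
a = np.array([[1, 5, 3, 9],
              [4, 2, 8, 6]])
values, indices = topk(a, 2)
```

Returning both values and indices mirrors what torch.topk and jax.lax.top_k do, as reviewed below.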
Here is a quick review of some available implementations in related libraries:
- torch.topk (no such thing as torch.argpartition)
  - returns a tuple of values and indices
- jax.lax.top_k
  - returns a tuple of values and indices
  - apparently it is quite slow on GPU
- dask.array.topk
  - returns only the values, I did not find a way to get the indices :(
- cupy.argpartition
  - internally computes a full cupy.argsort, which makes it very inefficient for large arrays and small k: O(n log(n)) instead of O(n)
Motivation: (arg)topk is needed by popular baseline data-science workloads (e.g. k-nearest neighbors classification in scikit-learn) and is surprisingly non-trivial to implement efficiently. For instance, on GPUs the fastest implementations are based on some kind of partial radix sort, while CPU implementations would use more traditional partial sorting algorithms (as implemented in std::partial_sort or std::nth_element).