Add stable_sort_by_key (uxlfoundation#1692)
Signed-off-by: Dmitriy Sobolev <[email protected]>
Co-authored-by: Alexey Kukanov <[email protected]>
Co-authored-by: MikeDvorskiy <[email protected]>
3 people authored Aug 15, 2024
1 parent 120fce5 commit 91cf500
Showing 15 changed files with 501 additions and 223 deletions.
24 changes: 17 additions & 7 deletions documentation/library_guide/parallel_api/additional_algorithms.rst
@@ -76,19 +76,29 @@ header. All algorithms are implemented in the ``oneapi::dpl`` namespace.
search sequence: [0, 2, 4, 7, 6]
result sequence: [1, 4, 8, 10, 10]

* ``sort_by_key``: performs a stable key-value sort. The algorithm sorts the sequence's keys according to
a comparioson operator. If no comparator is provided, then the elements are compared with ``operator<``.
The sequence's values are permutated according to the sorted sequence's keys. The prerequisite for correct
behavior is that the size for both keys sequence and values sequence shall be the same.
* ``sort_by_key``: performs a key-value sort.
The algorithm sorts a sequence of keys using a given comparison function object.
If it is not provided, the elements are compared with ``operator<``.
A sequence of values is simultaneously permuted according to the sorted order of keys.
There must be at least as many values as the keys, otherwise the behavior is undefined.

For example::

keys: [3, 5, 0, 4, 3, 0]
values: ['a', 'b', 'c', 'd', 'e', 'f']
output_keys: [0, 0, 3, 3, 4, 5]
output_values: ['c', 'f', 'a', 'e', 'd', 'b']

.. note::
``sort_by_key`` currently implements a stable sort for device execution policies,
but may implement an unstable sort in the future.
Use ``stable_sort_by_key`` if stability is essential.

* ``stable_sort_by_key``: performs a key-value sort similar to ``sort_by_key``,
but with the added guarantee of stability.

* ``transform_if``: performs a transform on the input sequence(s) elements and stores the result into the
corresponding position in the output sequence at each position for which the predicate applied to the
corresponding position in the output sequence at each position for which the predicate applied to the
element(s) evaluates to ``true``. If the predicate evaluates to ``false``, the transform is not applied for
the elements(s), and the output sequence's corresponding position is left unmodified. There are two overloads
of this function, one for a single input sequence with a unary transform and a unary predicate, and another
@@ -118,11 +128,11 @@ header. All algorithms are implemented in the ``oneapi::dpl`` namespace.
The first overload takes as input the number of bins, range minimum, and range maximum, then evenly
divides bins within that range. An input element ``a`` maps to a bin ``i`` such that
``i = floor((a - minimum) / ((maximum - minimum) / num_bins))``.

The other overload defines ``m`` bins from a sorted sequence of ``m + 1`` user-provided boundaries
where an input element ``a`` maps to a bin ``i`` if and only if
``__boundary_first[i] <= a < __boundary_first[i + 1]``.

Input values which do not map to a defined bin are skipped silently. The algorithm counts the number of
input elements which map to each bin and outputs the result to a user-provided sequence of ``m`` output
bin counts. The user must provide sufficient output data to store each bin, and the type of the output
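
The ``sort_by_key`` and ``stable_sort_by_key`` descriptions in this documentation hunk can be modeled with a small serial sketch. It uses only the C++ standard library and is not the oneDPL implementation; the helper name ``stable_sort_by_key_sketch`` is hypothetical and exists only for illustration. It assumes keys and values are held in ``std::vector`` and that there are at least as many values as keys.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical helper (not part of oneDPL): a serial model of a stable key-value sort.
// Keys are ordered by `comp`; values are permuted the same way; ties keep their input order.
template <typename Key, typename Value, typename Compare = std::less<Key>>
void stable_sort_by_key_sketch(std::vector<Key>& keys, std::vector<Value>& values, Compare comp = Compare{})
{
    // Documented precondition: at least as many values as keys.
    assert(values.size() >= keys.size());

    // Pair each key with its value so both move together.
    std::vector<std::pair<Key, Value>> zipped;
    zipped.reserve(keys.size());
    for (std::size_t i = 0; i < keys.size(); ++i)
        zipped.emplace_back(keys[i], values[i]);

    // Compare keys only; std::stable_sort preserves the order of equal keys.
    std::stable_sort(zipped.begin(), zipped.end(),
                     [&comp](const auto& a, const auto& b) { return comp(a.first, b.first); });

    // Write the sorted pairs back into the original sequences.
    for (std::size_t i = 0; i < keys.size(); ++i)
    {
        keys[i] = zipped[i].first;
        values[i] = zipped[i].second;
    }
}
```

Calling the sketch on the documented input (keys ``[3, 5, 0, 4, 3, 0]``, values ``['a', 'b', 'c', 'd', 'e', 'f']``) reproduces the documented output, since equal keys ``0`` and ``3`` keep their values in input order.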
18 changes: 17 additions & 1 deletion include/oneapi/dpl/pstl/algorithm_fwd.h
@@ -766,7 +766,7 @@ __pattern_stable_sort(__parallel_tag<_IsVector>, _ExecutionPolicy&&, _RandomAcce
// sort_by_key
//------------------------------------------------------------------------

template <class _Tag, typename _ExecutionPolicy, typename _RandomAccessIterator1, typename _RandomAccessIterator2,
template <typename _Tag, typename _ExecutionPolicy, typename _RandomAccessIterator1, typename _RandomAccessIterator2,
typename _Compare>
void
__pattern_sort_by_key(_Tag, _ExecutionPolicy&&, _RandomAccessIterator1, _RandomAccessIterator1, _RandomAccessIterator2,
@@ -778,6 +778,22 @@ void
__pattern_sort_by_key(__parallel_tag<_IsVector>, _ExecutionPolicy&&, _RandomAccessIterator1, _RandomAccessIterator1,
_RandomAccessIterator2, _Compare);

//------------------------------------------------------------------------
// stable_sort_by_key
//------------------------------------------------------------------------

template <typename _Tag, typename _ExecutionPolicy, typename _RandomAccessIterator1, typename _RandomAccessIterator2,
typename _Compare>
void
__pattern_stable_sort_by_key(_Tag, _ExecutionPolicy&&, _RandomAccessIterator1, _RandomAccessIterator1,
_RandomAccessIterator2, _Compare) noexcept;

template <typename _IsVector, typename _ExecutionPolicy, typename _RandomAccessIterator1,
typename _RandomAccessIterator2, typename _Compare>
void
__pattern_stable_sort_by_key(__parallel_tag<_IsVector>, _ExecutionPolicy&&, _RandomAccessIterator1,
_RandomAccessIterator1, _RandomAccessIterator2, _Compare);

//------------------------------------------------------------------------
// partial_sort
//------------------------------------------------------------------------
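
The forward declarations in this hunk pair a generic ``_Tag`` overload with a ``__parallel_tag<_IsVector>`` overload. A minimal sketch of that tag-dispatch idea follows. All names in it (``serial_tag_sketch``, ``parallel_tag_sketch``, ``pattern_sketch``) are hypothetical stand-ins, not oneDPL types, and the functions only report which overload was chosen.

```cpp
#include <cassert>
#include <string>
#include <type_traits>

// Hypothetical stand-ins for the dispatch tags; not oneDPL types.
struct serial_tag_sketch {};
template <typename IsVector>
struct parallel_tag_sketch {};

// Generic overload: accepts any tag type. In this sketch it is only valid for the serial tag,
// loosely mirroring the static_assert on the tag kind in the generic pattern.
template <typename Tag>
std::string pattern_sketch(Tag)
{
    static_assert(std::is_same<Tag, serial_tag_sketch>::value, "generic overload expects the serial tag");
    return "serial";
}

// More specialized overload: partial ordering of function templates prefers it
// whenever the argument is some parallel_tag_sketch<...>.
template <typename IsVector>
std::string pattern_sketch(parallel_tag_sketch<IsVector>)
{
    return "parallel";
}
```

With these definitions, ``pattern_sketch(serial_tag_sketch{})`` selects the generic overload, while ``pattern_sketch(parallel_tag_sketch<std::true_type>{})`` selects the parallel one, because the second template is more specialized for that argument.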
56 changes: 43 additions & 13 deletions include/oneapi/dpl/pstl/algorithm_impl.h
@@ -2465,11 +2465,9 @@ __pattern_sort_by_key(_Tag, _ExecutionPolicy&& __exec, _RandomAccessIterator1 __

auto __beg = oneapi::dpl::make_zip_iterator(__keys_first, __values_first);
auto __end = __beg + (__keys_last - __keys_first);
auto __cmp_f = [__comp](const auto& __a, const auto& __b) {
return __comp(::std::get<0>(__a), ::std::get<0>(__b));
};
auto __cmp_f = [__comp](const auto& __a, const auto& __b) { return __comp(std::get<0>(__a), std::get<0>(__b)); };

::std::sort(__beg, __end, __cmp_f);
std::sort(__beg, __end, __cmp_f);
}

template <typename _IsVector, typename _ExecutionPolicy, typename _RandomAccessIterator1,
@@ -2478,23 +2476,55 @@ void
__pattern_sort_by_key(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _RandomAccessIterator1 __keys_first,
_RandomAccessIterator1 __keys_last, _RandomAccessIterator2 __values_first, _Compare __comp)
{
static_assert(
::std::is_move_constructible_v<typename ::std::iterator_traits<_RandomAccessIterator1>::value_type> &&
::std::is_move_constructible_v<typename ::std::iterator_traits<_RandomAccessIterator2>::value_type>,
"The keys and values should be move constructible in case of parallel execution.");

auto __beg = oneapi::dpl::make_zip_iterator(__keys_first, __values_first);
auto __end = __beg + (__keys_last - __keys_first);
auto __cmp_f = [__comp](const auto& __a, const auto& __b) {
return __comp(::std::get<0>(__a), ::std::get<0>(__b));
};
auto __cmp_f = [__comp](const auto& __a, const auto& __b) { return __comp(std::get<0>(__a), std::get<0>(__b)); };

using __backend_tag = typename __parallel_tag<_IsVector>::__backend_tag;

__internal::__except_handler([&]() {
__par_backend::__parallel_stable_sort(
__backend_tag{}, ::std::forward<_ExecutionPolicy>(__exec), __beg, __end, __cmp_f,
[](auto __first, auto __last, auto __cmp) { ::std::sort(__first, __last, __cmp); }, __end - __beg);
[](auto __first, auto __last, auto __cmp) { std::sort(__first, __last, __cmp); }, __end - __beg);
});
}

//------------------------------------------------------------------------
// stable_sort_by_key
//------------------------------------------------------------------------

template <typename _Tag, typename _ExecutionPolicy, typename _RandomAccessIterator1, typename _RandomAccessIterator2,
typename _Compare>
void
__pattern_stable_sort_by_key(_Tag, _ExecutionPolicy&& __exec, _RandomAccessIterator1 __keys_first,
_RandomAccessIterator1 __keys_last, _RandomAccessIterator2 __values_first,
_Compare __comp) noexcept
{
static_assert(__is_serial_tag_v<_Tag> || __is_parallel_forward_tag_v<_Tag>);

auto __beg = oneapi::dpl::make_zip_iterator(__keys_first, __values_first);
auto __end = __beg + (__keys_last - __keys_first);
auto __cmp_f = [__comp](const auto& __a, const auto& __b) { return __comp(std::get<0>(__a), std::get<0>(__b)); };

std::stable_sort(__beg, __end, __cmp_f);
}

template <typename _IsVector, typename _ExecutionPolicy, typename _RandomAccessIterator1,
typename _RandomAccessIterator2, typename _Compare>
void
__pattern_stable_sort_by_key(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _RandomAccessIterator1 __keys_first,
_RandomAccessIterator1 __keys_last, _RandomAccessIterator2 __values_first, _Compare __comp)
{
auto __beg = oneapi::dpl::make_zip_iterator(__keys_first, __values_first);
auto __end = __beg + (__keys_last - __keys_first);
auto __cmp_f = [__comp](const auto& __a, const auto& __b) { return __comp(std::get<0>(__a), std::get<0>(__b)); };

using __backend_tag = typename __parallel_tag<_IsVector>::__backend_tag;

__internal::__except_handler([&]() {
__par_backend::__parallel_stable_sort(
__backend_tag{}, std::forward<_ExecutionPolicy>(__exec), __beg, __end, __cmp_f,
[](auto __first, auto __last, auto __cmp) { std::stable_sort(__first, __last, __cmp); }, __end - __beg);
});
}

13 changes: 13 additions & 0 deletions include/oneapi/dpl/pstl/glue_algorithm_defs.h
@@ -344,6 +344,19 @@ oneapi::dpl::__internal::__enable_if_execution_policy<_ExecutionPolicy>
sort_by_key(_ExecutionPolicy&& __exec, _RandomAccessIterator1 __keys_first, _RandomAccessIterator1 __keys_last,
_RandomAccessIterator2 __values_first);

// oneapi::dpl::stable_sort_by_key

template <typename _ExecutionPolicy, typename _RandomAccessIterator1, typename _RandomAccessIterator2,
typename _Compare>
oneapi::dpl::__internal::__enable_if_execution_policy<_ExecutionPolicy>
stable_sort_by_key(_ExecutionPolicy&& __exec, _RandomAccessIterator1 __keys_first, _RandomAccessIterator1 __keys_last,
_RandomAccessIterator2 __values_first, _Compare __comp);

template <typename _ExecutionPolicy, typename _RandomAccessIterator1, typename _RandomAccessIterator2>
oneapi::dpl::__internal::__enable_if_execution_policy<_ExecutionPolicy>
stable_sort_by_key(_ExecutionPolicy&& __exec, _RandomAccessIterator1 __keys_first, _RandomAccessIterator1 __keys_last,
_RandomAccessIterator2 __values_first);

// [mismatch]

template <class _ExecutionPolicy, class _ForwardIterator1, class _ForwardIterator2, class _BinaryPredicate>
23 changes: 23 additions & 0 deletions include/oneapi/dpl/pstl/glue_algorithm_impl.h
@@ -719,6 +719,29 @@ sort_by_key(_ExecutionPolicy&& __exec, _RandomAccessIterator1 __keys_first, _Ran
oneapi::dpl::__internal::__pstl_less());
}

// oneapi::dpl::stable_sort_by_key

template <typename _ExecutionPolicy, typename _RandomAccessIterator1, typename _RandomAccessIterator2,
typename _Compare>
oneapi::dpl::__internal::__enable_if_execution_policy<_ExecutionPolicy>
stable_sort_by_key(_ExecutionPolicy&& __exec, _RandomAccessIterator1 __keys_first, _RandomAccessIterator1 __keys_last,
_RandomAccessIterator2 __values_first, _Compare __comp)
{
const auto __dispatch_tag = oneapi::dpl::__internal::__select_backend(__exec, __keys_first, __values_first);

oneapi::dpl::__internal::__pattern_stable_sort_by_key(__dispatch_tag, ::std::forward<_ExecutionPolicy>(__exec),
__keys_first, __keys_last, __values_first, __comp);
}

template <typename _ExecutionPolicy, typename _RandomAccessIterator1, typename _RandomAccessIterator2>
oneapi::dpl::__internal::__enable_if_execution_policy<_ExecutionPolicy>
stable_sort_by_key(_ExecutionPolicy&& __exec, _RandomAccessIterator1 __keys_first, _RandomAccessIterator1 __keys_last,
_RandomAccessIterator2 __values_first)
{
oneapi::dpl::stable_sort_by_key(::std::forward<_ExecutionPolicy>(__exec), __keys_first, __keys_last, __values_first,
oneapi::dpl::__internal::__pstl_less());
}

// [mismatch]

template <class _ExecutionPolicy, class _ForwardIterator1, class _ForwardIterator2, class _BinaryPredicate>
35 changes: 28 additions & 7 deletions include/oneapi/dpl/pstl/hetero/algorithm_impl_hetero.h
@@ -1270,21 +1270,42 @@ __pattern_stable_sort(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec
oneapi::dpl::identity{});
}

//------------------------------------------------------------------------
// sort_by_key
//------------------------------------------------------------------------

template <typename _BackendTag, typename _ExecutionPolicy, typename _Iterator1, typename _Iterator2, typename _Compare>
void
__pattern_sort_by_key(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _Iterator1 __keys_first,
_Iterator1 __keys_last, _Iterator2 __values_first, _Compare __comp)
__pattern_stable_sort_by_key(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _Iterator1 __keys_first,
_Iterator1 __keys_last, _Iterator2 __values_first, _Compare __comp)
{
static_assert(::std::is_move_constructible_v<typename ::std::iterator_traits<_Iterator1>::value_type> &&
::std::is_move_constructible_v<typename ::std::iterator_traits<_Iterator2>::value_type>,
static_assert(std::is_move_constructible_v<typename std::iterator_traits<_Iterator1>::value_type> &&
std::is_move_constructible_v<typename std::iterator_traits<_Iterator2>::value_type>,
"The keys and values should be move constructible in case of parallel execution.");

auto __beg = oneapi::dpl::make_zip_iterator(__keys_first, __values_first);
auto __end = __beg + (__keys_last - __keys_first);
__stable_sort_with_projection(__tag, ::std::forward<_ExecutionPolicy>(__exec), __beg, __end, __comp,
[](const auto& __a) { return ::std::get<0>(__a); });
__stable_sort_with_projection(__tag, std::forward<_ExecutionPolicy>(__exec), __beg, __end, __comp,
[](const auto& __a) { return std::get<0>(__a); });
}

//------------------------------------------------------------------------
// stable_sort_by_key
//------------------------------------------------------------------------

template <typename _BackendTag, typename _ExecutionPolicy, typename _Iterator1, typename _Iterator2, typename _Compare>
void
__pattern_sort_by_key(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _Iterator1 __keys_first,
_Iterator1 __keys_last, _Iterator2 __values_first, _Compare __comp)
{
__pattern_stable_sort_by_key(__tag, std::forward<_ExecutionPolicy>(__exec), __keys_first, __keys_last,
__values_first, __comp);
}

//------------------------------------------------------------------------
// stable_partition
//------------------------------------------------------------------------

template <typename _BackendTag, typename _ExecutionPolicy, typename _Iterator, typename _UnaryPredicate>
_Iterator
__pattern_stable_partition(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _Iterator __first,
@@ -1490,7 +1511,7 @@ __pattern_partial_sort_copy(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&&

// TODO: __pattern_walk2 is a blocking call here, so there is a synchronization between the patterns.
// But, when the input iterators are a kind of hetero iterator on top of sycl::buffer, SYCL
// runtime makes a dependency graph. In that case the call of __pattern_walk2 could be changed to
// runtime makes a dependency graph. In that case the call of __pattern_walk2 could be changed to
// be asynchronous for better performance.

// Use regular sort as partial_sort isn't required to be stable.
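
The hetero pattern in this file passes the user comparator together with a projection that extracts the key (``std::get<0>``) from each zipped element. The sketch below models that comparator-plus-projection idea serially with the standard library. The name ``stable_sort_with_projection_sketch`` is hypothetical; it is not the signature of the internal ``__stable_sort_with_projection`` and says nothing about how the device backend is implemented.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <tuple>
#include <vector>

// Hypothetical serial model: sort elements by comparing only their projected keys.
// Elements whose projected keys are equivalent keep their relative input order.
template <typename Iterator, typename Compare, typename Projection>
void stable_sort_with_projection_sketch(Iterator first, Iterator last, Compare comp, Projection proj)
{
    std::stable_sort(first, last,
                     [&](const auto& a, const auto& b) { return comp(proj(a), proj(b)); });
}

// Usage sketch: sort (key, value) tuples by key in descending order.
inline std::vector<std::tuple<int, char>> sorted_descending_example()
{
    std::vector<std::tuple<int, char>> data{
        {3, 'a'}, {5, 'b'}, {0, 'c'}, {4, 'd'}, {3, 'e'}, {0, 'f'}};
    stable_sort_with_projection_sketch(data.begin(), data.end(), std::greater<int>{},
                                       [](const auto& t) { return std::get<0>(t); });
    return data;
}
```

Keeping the projection separate from the comparator means the user-supplied comparator only ever sees keys, so the values never influence the ordering and ties stay in input order.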
12 changes: 6 additions & 6 deletions test/kt/esimd_radix_sort.cpp
@@ -44,7 +44,7 @@ test_all_view(sycl::queue q, std::size_t size, KernelParam param)
std::cout << "\ttest_all_view(" << size << ") : " << TypeInfo().name<T>() << std::endl;
#endif
std::vector<T> input(size);
generate_data(input.data(), size, 42);
TestUtils::generate_arithmetic_data(input.data(), size, 42);
std::vector<T> ref(input);
std::stable_sort(std::begin(ref), std::end(ref), Compare<T, IsAscending>{});
{
@@ -66,7 +66,7 @@ test_subrange_view(sycl::queue q, std::size_t size, KernelParam param)
<< std::endl;
#endif
std::vector<T> expected(size);
generate_data(expected.data(), size, 42);
TestUtils::generate_arithmetic_data(expected.data(), size, 42);

TestUtils::usm_data_transfer<sycl::usm::alloc::device, T> dt_input(q, expected.begin(), expected.end());

@@ -93,7 +93,7 @@ test_usm(sycl::queue q, std::size_t size, KernelParam param)
<< IsAscending << ">(" << size << ");" << std::endl;
#endif
std::vector<T> expected(size);
generate_data(expected.data(), size, 42);
TestUtils::generate_arithmetic_data(expected.data(), size, 42);

TestUtils::usm_data_transfer<_alloc_type, T> dt_input(q, expected.begin(), expected.end());

@@ -118,7 +118,7 @@ test_sycl_iterators(sycl::queue q, std::size_t size, KernelParam param)
std::cout << "\t\ttest_sycl_iterators<" << TypeInfo().name<T>() << ">(" << size << ");" << std::endl;
#endif
std::vector<T> input(size);
generate_data(input.data(), size, 42);
TestUtils::generate_arithmetic_data(input.data(), size, 42);
std::vector<T> ref(input);
std::stable_sort(std::begin(ref), std::end(ref), Compare<T, IsAscending>{});
{
@@ -149,7 +149,7 @@ test_sycl_buffer(sycl::queue q, std::size_t size, KernelParam param)
std::cout << "\t\ttest_sycl_buffer<" << TypeInfo().name<T>() << ">(" << size << ");" << std::endl;
#endif
std::vector<T> input(size);
generate_data(input.data(), size, 42);
TestUtils::generate_arithmetic_data(input.data(), size, 42);
std::vector<T> ref(input);
std::stable_sort(std::begin(ref), std::end(ref), Compare<T, IsAscending>{});
{
@@ -167,7 +167,7 @@ test_small_sizes(sycl::queue q, KernelParam param)
{
constexpr int size = 8;
std::vector<T> input(size);
generate_data(input.data(), size, 42);
TestUtils::generate_arithmetic_data(input.data(), size, 42);
std::vector<T> ref(input);

oneapi::dpl::experimental::kt::gpu::esimd::radix_sort<IsAscending, RadixBits>(q, oneapi::dpl::begin(input),
8 changes: 4 additions & 4 deletions test/kt/esimd_radix_sort_by_key.cpp
@@ -32,8 +32,8 @@ void test_sycl_buffer(sycl::queue q, std::size_t size, KernelParam param)
{
std::vector<KeyT> expected_keys(size);
std::vector<ValueT> expected_values(size);
generate_data(expected_keys.data(), size, 6);
generate_data(expected_values.data(), size, 7);
TestUtils::generate_arithmetic_data(expected_keys.data(), size, 6);
TestUtils::generate_arithmetic_data(expected_values.data(), size, 7);

std::vector<KeyT> actual_keys(expected_keys);
std::vector<ValueT> actual_values(expected_values);
@@ -62,8 +62,8 @@ void test_usm(sycl::queue q, std::size_t size, KernelParam param)
{
std::vector<KeyT> expected_keys(size);
std::vector<ValueT> expected_values(size);
generate_data(expected_keys.data(), size, 6);
generate_data(expected_values.data(), size, 7);
TestUtils::generate_arithmetic_data(expected_keys.data(), size, 6);
TestUtils::generate_arithmetic_data(expected_values.data(), size, 7);

TestUtils::usm_data_transfer<_alloc_type, KeyT> keys(q, expected_keys.begin(), expected_keys.end());
TestUtils::usm_data_transfer<_alloc_type, ValueT> values(q, expected_values.begin(), expected_values.end());