diff --git a/libs/core/algorithms/include/hpx/parallel/algorithms/nth_element.hpp b/libs/core/algorithms/include/hpx/parallel/algorithms/nth_element.hpp index 428a19e3884..a8c31f63dcb 100644 --- a/libs/core/algorithms/include/hpx/parallel/algorithms/nth_element.hpp +++ b/libs/core/algorithms/include/hpx/parallel/algorithms/nth_element.hpp @@ -139,11 +139,14 @@ namespace hpx { #include #include #include +#include #include #include +#include #include #include #include +#include #include #include @@ -180,6 +183,10 @@ namespace hpx::parallel { constexpr void nth_element_seq(RandomIt first, RandomIt nth, RandomIt end, std::uint32_t level, Compare&& comp, Proj&& proj) { + using wrapped_comp_type = + hpx::parallel::util::compare_projected, + std::decay_t>; + constexpr std::uint32_t nmin_sort = 24; auto nelem = end - first; @@ -187,8 +194,8 @@ namespace hpx::parallel { if (nth == first) { RandomIt it = detail::min_element().call( - hpx::execution::seq, first, end, HPX_FORWARD(Compare, comp), - HPX_FORWARD(Proj, proj)); + hpx::execution::seq, first, end, + wrapped_comp_type(comp, proj), hpx::identity_v); if (it != first) { @@ -206,13 +213,13 @@ namespace hpx::parallel { } if (level == 0) { - std::make_heap(first, end, comp); - std::sort_heap(first, nth, comp); + std::make_heap(first, end, wrapped_comp_type(comp, proj)); + std::sort_heap(first, nth, wrapped_comp_type(comp, proj)); return; } // Filter the range and check which part contains the nth element - RandomIt c_last = filter(first, end, comp); + RandomIt c_last = filter(first, end, wrapped_comp_type(comp, proj)); if (c_last == nth) return; @@ -263,9 +270,6 @@ namespace hpx::parallel { parallel(ExPolicy&& policy, RandomIt first, RandomIt nth, Sent last, Pred&& pred, Proj&& proj) { - using value_type = - typename std::iterator_traits::value_type; - RandomIt partition_iter, return_last; if (first == last) @@ -288,17 +292,21 @@ namespace hpx::parallel { while (first != last_iter) { - detail::pivot9(first, last_iter, pred); + detail::pivot9(first, last_iter, + hpx::parallel::util::compare_projected< + std::decay_t, std::decay_t>( + pred, proj)); partition_iter = hpx::parallel::detail::partition().call( policy(hpx::execution::non_task), first + 1, last_iter, - [val = HPX_INVOKE(proj, *first), &pred]( - value_type const& elem) { - return HPX_INVOKE(pred, elem, val); + [val = HPX_INVOKE(proj, *first), &pred, &proj]( + auto const& elem) { + return HPX_INVOKE( + pred, HPX_INVOKE(proj, elem), val); }, - proj); + hpx::identity_v); --partition_iter; diff --git a/libs/core/algorithms/tests/unit/container_algorithms/nth_element_range.cpp b/libs/core/algorithms/tests/unit/container_algorithms/nth_element_range.cpp index 29dc4132fba..a0683ee2f84 100644 --- a/libs/core/algorithms/tests/unit/container_algorithms/nth_element_range.cpp +++ b/libs/core/algorithms/tests/unit/container_algorithms/nth_element_range.cpp @@ -23,6 +23,11 @@ #include "test_utils.hpp" +struct S +{ + std::size_t val; +}; + //////////////////////////////////////////////////////////////////////////// #define SIZE 10007 @@ -198,6 +203,65 @@ void test_nth_element_async(ExPolicy policy, IteratorTag) } } +template +void test_nth_element_projection(ExPolicy policy) +{ + static_assert(hpx::is_execution_policy::value, + "hpx::is_execution_policy::value"); + + std::vector c(SIZE); + for (std::size_t i = 0; i < SIZE; ++i) + c[i].val = SIZE - i; + + auto rand_index = std::rand() % SIZE; + auto nth = std::begin(c) + rand_index; + + hpx::ranges::nth_element(policy, c, nth, std::less{}, &S::val); + + for (int k = 0; k < rand_index; k++) + { + HPX_TEST(c[k].val <= c[rand_index].val); + } + + for (int k = rand_index + 1; k < SIZE; k++) + { + HPX_TEST(c[k].val >= c[rand_index].val); + } +} + +template +void test_nth_element_sent_projection(ExPolicy policy, IteratorTag) +{ + static_assert(hpx::is_execution_policy::value, + "hpx::is_execution_policy::value"); + + using base_iterator = std::vector::iterator; + using iterator = test::test_iterator; + using sentinel = test::sentinel_from_iterator; + + std::vector c(SIZE); + for (std::size_t i = 0; i < SIZE; ++i) + c[i].val = SIZE - i; + + auto rand_index = std::rand() % SIZE; + + auto result = hpx::ranges::nth_element(policy, iterator(std::begin(c)), + iterator(std::begin(c) + rand_index), + sentinel(iterator(std::end(c) - 1)), std::less{}, &S::val); + + HPX_TEST(result == iterator(std::end(c) - 1)); + + for (int k = 0; k < rand_index; k++) + { + HPX_TEST(c[k].val <= c[rand_index].val); + } + + for (int k = rand_index + 1; k < SIZE - 1; k++) + { + HPX_TEST(c[k].val >= c[rand_index].val); + } +} + template void test_nth_element() { @@ -215,6 +279,14 @@ void test_nth_element() test_nth_element_sent(seq, IteratorTag()); test_nth_element_sent(par, IteratorTag()); test_nth_element_sent(par_unseq, IteratorTag()); + + test_nth_element_projection(seq); + test_nth_element_projection(par); + test_nth_element_projection(par_unseq); + + test_nth_element_sent_projection(seq, IteratorTag()); + test_nth_element_sent_projection(par, IteratorTag()); + test_nth_element_sent_projection(par_unseq, IteratorTag()); } void nth_element_test()