Skip to content

Commit feb185a

Browse files
deadlywing and Copilot authored
radix sort support unsigned/signed type (#1367)
# Pull Request

## What problem does this PR solve?

Issue Number: Fixed #

## Possible side effects?

- Performance:
- Backward compatibility:

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent b30ac5b commit feb185a

File tree

6 files changed

+273
-11
lines changed

6 files changed

+273
-11
lines changed

src/libspu/core/type_util.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,14 @@ bool isInteger(DataType dtype) {
5252
}
5353
}
5454

55+
bool isUInteger(DataType dtype) {
56+
switch (dtype) {
57+
FOREACH_UINT_DTYPES(CASE)
58+
default:
59+
return false;
60+
}
61+
}
62+
5563
bool isFixedPoint(DataType dtype) {
5664
switch (dtype) {
5765
FOREACH_FXP_DTYPES(CASE)

src/libspu/core/type_util.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ std::ostream& operator<<(std::ostream& os, const Visibility& vtype);
5050
FN(DT_I64, I64, 64) \
5151
FN(DT_U64, U64, 64)
5252

53+
#define FOREACH_UINT_DTYPES(FN) \
54+
FN(DT_I1, I1, 1) \
55+
FN(DT_U8, U8, 8) \
56+
FN(DT_U16, U16, 16) \
57+
FN(DT_U32, U32, 32) \
58+
FN(DT_U64, U64, 64)
59+
5360
#define FOREACH_FXP_DTYPES(FN) \
5461
FN(DT_F16, F16, 16) \
5562
FN(DT_F32, F32, 32) \
@@ -60,6 +67,7 @@ std::ostream& operator<<(std::ostream& os, const Visibility& vtype);
6067
FOREACH_FXP_DTYPES(FN)
6168

6269
bool isInteger(DataType dtype);
70+
bool isUInteger(DataType dtype);
6371
bool isFixedPoint(DataType dtype);
6472
size_t getWidth(DataType dtype);
6573

src/libspu/core/value.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ class Value final {
8080
// Get dtype.
8181
DataType dtype() const { return dtype_; }
8282
bool isInt() const { return isInteger(dtype()); }
83+
bool isUInt() const { return isUInteger(dtype()); }
8384
bool isFxp() const { return isFixedPoint(dtype()); }
8485
bool isComplex() const { return imag_.has_value(); }
8586

src/libspu/kernel/hal/permute.cc

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -984,25 +984,27 @@ std::vector<spu::Value> _gen_bv_vector(SPUContext *ctx,
984984
int64_t valid_bits) {
985985
std::vector<spu::Value> ret;
986986
const auto k1 = _constant(ctx, 1U, keys[0].shape());
987+
const bool is_descending = (direction == SortDirection::Descending);
988+
987989
// keys[0] is the most significant key
988990
for (size_t i = keys.size(); i > 0; --i) {
989991
const auto t = _bit_decompose(ctx, keys[i - 1], valid_bits);
992+
const bool is_unsigned = keys[i - 1].isUInt();
990993

991994
SPU_ENFORCE(t.size() > 0);
995+
996+
// Process non-sign bits (same logic for both signed and unsigned types)
997+
// Radix sort is stable and produces ascending order; flipping each key
998+
// bit before sorting therefore yields a stable descending order.
992999
for (size_t j = 0; j < t.size() - 1; j++) {
993-
// Radix sort is a stable sorting algorithm for the ascending order, if
994-
// we flip the bit, then we can get the descending order for stable sort
995-
if (direction == SortDirection::Descending) {
996-
ret.emplace_back(_sub(ctx, k1, t[j]));
997-
} else {
998-
ret.emplace_back(t[j]);
999-
}
1000+
ret.emplace_back(is_descending ? _sub(ctx, k1, t[j]) : t[j]);
10001001
}
1001-
// The sign bit is opposite
1002-
if (direction == SortDirection::Descending) {
1003-
ret.emplace_back(t.back());
1002+
1003+
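// Last decomposed bit: for signed keys it is handled as the sign bit and
// gets the opposite flip from the other bits; for unsigned keys it follows
// the same rule as the other bits.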
1004+
if (!is_unsigned) {
1005+
ret.emplace_back(is_descending ? t.back() : _sub(ctx, k1, t.back()));
10041006
} else {
1005-
ret.emplace_back(_sub(ctx, k1, t.back()));
1007+
ret.emplace_back(is_descending ? _sub(ctx, k1, t.back()) : t.back());
10061008
}
10071009
}
10081010
return ret;

src/libspu/kernel/hal/permute.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ std::vector<spu::Value> sort1d(SPUContext *ctx,
5858
// - direction: sorting order
5959
// - num_keys: the number of operands to treat as keys (count from index 0)
6060
// - valid_bits: indicates the numeric range of keys for performance hint
61+
//
62+
// Important notes:
63+
// - for radix sort, the caller must give each key the correct signed or
64+
// unsigned dtype; the wrong dtype produces an incorrect sort order.
6165
std::vector<spu::Value> simple_sort1d(SPUContext *ctx,
6266
absl::Span<spu::Value const> inputs,
6367
SortDirection direction, int64_t num_keys,

src/libspu/kernel/hlo/sort_test.cc

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "libspu/kernel/hlo/sort.h"
1616

1717
#include <algorithm>
18+
#include <limits>
1819
#include <random>
1920
#include <xtensor/xadapt.hpp>
2021
#include <xtensor/xsort.hpp>
@@ -461,6 +462,244 @@ TEST_P(SimpleSortTest, MixVisibilityKey) {
461462
});
462463
}
463464

465+
// Helper template to run unsigned type sort test
466+
template <typename T>
467+
void RunUnsignedSortTest(SPUContext *ctx) {
468+
xt::xarray<T> k1 = {7, 6, 5, 4, 1, 3, 2};
469+
xt::xarray<T> sorted_k1 = {1, 2, 3, 4, 5, 6, 7};
470+
xt::xarray<float> payload = {1, 2, 3, 6, 7, 6, 5};
471+
xt::xarray<float> sorted_payload = {7, 5, 6, 6, 3, 2, 1};
472+
473+
Value k1_v = test::makeValue(ctx, k1, VIS_SECRET);
474+
Value payload_v = test::makeValue(ctx, payload, VIS_SECRET);
475+
476+
std::vector<spu::Value> rets =
477+
SimpleSort(ctx, {k1_v, payload_v}, 0, hal::SortDirection::Ascending, 1);
478+
479+
EXPECT_EQ(rets.size(), 2);
480+
481+
auto sorted_k1_hat = hal::dump_public_as<T>(ctx, hal::reveal(ctx, rets[0]));
482+
auto sorted_payload_hat =
483+
hal::dump_public_as<float>(ctx, hal::reveal(ctx, rets[1]));
484+
485+
EXPECT_TRUE(xt::allclose(sorted_k1, sorted_k1_hat, 0.01, 0.001))
486+
<< "sort failed: " << sorted_k1 << std::endl
487+
<< sorted_k1_hat << std::endl;
488+
489+
EXPECT_TRUE(xt::allclose(sorted_payload, sorted_payload_hat, 0.01, 0.001))
490+
<< "payload failed: " << sorted_payload << std::endl
491+
<< sorted_payload_hat << std::endl;
492+
}
493+
494+
TEST_P(SimpleSortTest, UnsignedTypeSort) {
495+
size_t npc = std::get<0>(GetParam());
496+
FieldType field = std::get<1>(GetParam());
497+
ProtocolKind prot = std::get<2>(GetParam());
498+
RuntimeConfig::SortMethod method = std::get<3>(GetParam());
499+
500+
mpc::utils::simulate(npc,
501+
[&](const std::shared_ptr<yacl::link::Context> &lctx) {
502+
RuntimeConfig cfg;
503+
cfg.protocol = prot;
504+
cfg.field = field;
505+
cfg.enable_action_trace = false;
506+
cfg.sort_method = method;
507+
SPUContext ctx = test::makeSPUContext(cfg, lctx);
508+
509+
RunUnsignedSortTest<uint8_t>(&ctx);
510+
RunUnsignedSortTest<uint16_t>(&ctx);
511+
RunUnsignedSortTest<uint32_t>(&ctx);
512+
if (field >= FieldType::FM64) {
513+
RunUnsignedSortTest<uint64_t>(&ctx);
514+
}
515+
});
516+
}
517+
518+
// Helper template to test signed interpretation of unsigned-range values
519+
// When a value like 255 (uint8_t) is reinterpreted as int8_t, it becomes -1,
520+
// so the sort must order it as a negative number.
521+
template <typename SignedT, typename UnsignedT>
522+
void RunSignedInterpretationSortTest(SPUContext *ctx) {
523+
// Use max value of unsigned type which becomes -1 when interpreted as signed
524+
constexpr UnsignedT max_val = std::numeric_limits<UnsignedT>::max();
525+
// Key: {0, max_val} where max_val is interpreted as -1 in signed
526+
// For ascending sort with signed interpretation: -1 < 0, so max_val comes
527+
// first
528+
xt::xarray<SignedT> k1 = {0, static_cast<SignedT>(max_val)};
529+
// Expected: max_val (-1) < 0, so sorted order is {max_val, 0}
530+
xt::xarray<SignedT> sorted_k1 = {static_cast<SignedT>(max_val), 0};
531+
xt::xarray<float> payload = {1.0, 2.0};
532+
xt::xarray<float> sorted_payload = {2.0, 1.0};
533+
534+
Value k1_v = test::makeValue(ctx, k1, VIS_SECRET);
535+
Value payload_v = test::makeValue(ctx, payload, VIS_SECRET);
536+
537+
std::vector<spu::Value> rets =
538+
SimpleSort(ctx, {k1_v, payload_v}, 0, hal::SortDirection::Ascending, 1);
539+
540+
EXPECT_EQ(rets.size(), 2);
541+
542+
auto sorted_k1_hat =
543+
hal::dump_public_as<SignedT>(ctx, hal::reveal(ctx, rets[0]));
544+
auto sorted_payload_hat =
545+
hal::dump_public_as<float>(ctx, hal::reveal(ctx, rets[1]));
546+
547+
EXPECT_TRUE(xt::allclose(sorted_k1, sorted_k1_hat, 0.01, 0.001))
548+
<< "sort failed: expected " << sorted_k1 << ", got " << sorted_k1_hat
549+
<< std::endl;
550+
551+
EXPECT_TRUE(xt::allclose(sorted_payload, sorted_payload_hat, 0.01, 0.001))
552+
<< "payload failed: expected " << sorted_payload << ", got "
553+
<< sorted_payload_hat << std::endl;
554+
}
555+
556+
// IMPORTANT: the user should ensure that the data has the correct signed or
557+
// unsigned type. Incorrect type interpretation will result in incorrect sort
558+
// order (for example, treating signed values as unsigned may place negative
559+
// numbers at the end instead of the beginning).
560+
TEST_P(SimpleSortTest, SignedInterpretationSort) {
561+
size_t npc = std::get<0>(GetParam());
562+
FieldType field = std::get<1>(GetParam());
563+
ProtocolKind prot = std::get<2>(GetParam());
564+
RuntimeConfig::SortMethod method = std::get<3>(GetParam());
565+
566+
mpc::utils::simulate(
567+
npc, [&](const std::shared_ptr<yacl::link::Context> &lctx) {
568+
RuntimeConfig cfg;
569+
cfg.protocol = prot;
570+
cfg.field = field;
571+
cfg.enable_action_trace = false;
572+
cfg.sort_method = method;
573+
SPUContext ctx = test::makeSPUContext(cfg, lctx);
574+
575+
// Test: data is uint8_t range but treated as int8_t
576+
// 255 (uint8_t) -> -1 (int8_t), so -1 < 0
577+
RunSignedInterpretationSortTest<int8_t, uint8_t>(&ctx);
578+
579+
// Test: data is uint16_t range but treated as int16_t
580+
// 65535 (uint16_t) -> -1 (int16_t), so -1 < 0
581+
RunSignedInterpretationSortTest<int16_t, uint16_t>(&ctx);
582+
583+
// Test: data is uint32_t range but treated as int32_t
584+
// 4294967295 (uint32_t) -> -1 (int32_t), so -1 < 0
585+
RunSignedInterpretationSortTest<int32_t, uint32_t>(&ctx);
586+
587+
if (field >= FieldType::FM64) {
588+
// Test: data is uint64_t range but treated as int64_t
589+
RunSignedInterpretationSortTest<int64_t, uint64_t>(&ctx);
590+
}
591+
});
592+
}
593+
594+
TEST_P(SimpleSortTest, BoolKeyWithPayloads) {
595+
size_t npc = std::get<0>(GetParam());
596+
FieldType field = std::get<1>(GetParam());
597+
ProtocolKind prot = std::get<2>(GetParam());
598+
RuntimeConfig::SortMethod method = std::get<3>(GetParam());
599+
600+
mpc::utils::simulate(
601+
npc, [&](const std::shared_ptr<yacl::link::Context> &lctx) {
602+
RuntimeConfig cfg;
603+
cfg.protocol = prot;
604+
cfg.field = field;
605+
cfg.enable_action_trace = false;
606+
cfg.sort_method = method;
607+
608+
SPUContext ctx = test::makeSPUContext(cfg, lctx);
609+
610+
// Bool key with two payloads
611+
xt::xarray<bool> k1 = {true, false, true, false, true};
612+
xt::xarray<float> p1 = {1.0, 2.0, 3.0, 4.0, 5.0};
613+
xt::xarray<int32_t> p2 = {10, 20, 30, 40, 50};
614+
615+
// Expected sorted keys
616+
xt::xarray<bool> sorted_k1_desc = {true, true, true, false, false};
617+
xt::xarray<bool> sorted_k1_asc = {false, false, true, true, true};
618+
619+
// Expected payloads (sorted within each group since sort is unstable)
620+
// Descending: true keys first {1,3,5}, then false keys {2,4}
621+
xt::xarray<float> sorted_p1_desc = {1.0, 3.0, 5.0, 2.0, 4.0};
622+
xt::xarray<int32_t> sorted_p2_desc = {10, 30, 50, 20, 40};
623+
// Ascending: false keys first {2,4}, then true keys {1,3,5}
624+
xt::xarray<float> sorted_p1_asc = {2.0, 4.0, 1.0, 3.0, 5.0};
625+
xt::xarray<int32_t> sorted_p2_asc = {20, 40, 10, 30, 50};
626+
627+
Value k1_v = test::makeValue(&ctx, k1, VIS_SECRET);
628+
Value p1_v = test::makeValue(&ctx, p1, VIS_SECRET);
629+
Value p2_v = test::makeValue(&ctx, p2, VIS_SECRET);
630+
631+
// Test descending sort (true before false)
632+
{
633+
std::vector<spu::Value> rets = SimpleSort(
634+
&ctx, {k1_v, p1_v, p2_v}, 0, hal::SortDirection::Descending, 1);
635+
636+
EXPECT_EQ(rets.size(), 3);
637+
638+
auto sorted_k1_hat =
639+
hal::dump_public_as<bool>(&ctx, hal::reveal(&ctx, rets[0]));
640+
auto sorted_p1_hat =
641+
hal::dump_public_as<float>(&ctx, hal::reveal(&ctx, rets[1]));
642+
auto sorted_p2_hat =
643+
hal::dump_public_as<int32_t>(&ctx, hal::reveal(&ctx, rets[2]));
644+
645+
// Check bool key is sorted correctly
646+
EXPECT_TRUE(xt::allclose(sorted_k1_desc, sorted_k1_hat, 0.01, 0.001))
647+
<< "Bool descending sort failed: " << sorted_k1_desc << std::endl
648+
<< sorted_k1_hat << std::endl;
649+
650+
// Sort each part and compare (since sort is unstable within same key)
651+
auto p1_hat_sorted = xt::concatenate(
652+
xt::xtuple(xt::sort(xt::view(sorted_p1_hat, xt::range(0, 3))),
653+
xt::sort(xt::view(sorted_p1_hat, xt::range(3, 5)))));
654+
auto p2_hat_sorted = xt::concatenate(
655+
xt::xtuple(xt::sort(xt::view(sorted_p2_hat, xt::range(0, 3))),
656+
xt::sort(xt::view(sorted_p2_hat, xt::range(3, 5)))));
657+
658+
EXPECT_TRUE(xt::allclose(sorted_p1_desc, p1_hat_sorted, 0.01, 0.001))
659+
<< "Descending p1 failed: " << sorted_p1_desc << std::endl
660+
<< p1_hat_sorted << std::endl;
661+
EXPECT_TRUE(xt::allclose(sorted_p2_desc, p2_hat_sorted, 0.01, 0.001))
662+
<< "Descending p2 failed: " << sorted_p2_desc << std::endl
663+
<< p2_hat_sorted << std::endl;
664+
}
665+
666+
// Test ascending sort (false before true)
667+
{
668+
std::vector<spu::Value> rets = SimpleSort(
669+
&ctx, {k1_v, p1_v, p2_v}, 0, hal::SortDirection::Ascending, 1);
670+
671+
EXPECT_EQ(rets.size(), 3);
672+
673+
auto sorted_k1_hat =
674+
hal::dump_public_as<bool>(&ctx, hal::reveal(&ctx, rets[0]));
675+
auto sorted_p1_hat =
676+
hal::dump_public_as<float>(&ctx, hal::reveal(&ctx, rets[1]));
677+
auto sorted_p2_hat =
678+
hal::dump_public_as<int32_t>(&ctx, hal::reveal(&ctx, rets[2]));
679+
680+
// Check bool key is sorted correctly
681+
EXPECT_TRUE(xt::allclose(sorted_k1_asc, sorted_k1_hat, 0.01, 0.001))
682+
<< "Bool ascending sort failed: " << sorted_k1_asc << std::endl
683+
<< sorted_k1_hat << std::endl;
684+
685+
// Sort each part and compare (since sort is unstable within same key)
686+
auto p1_hat_sorted = xt::concatenate(
687+
xt::xtuple(xt::sort(xt::view(sorted_p1_hat, xt::range(0, 2))),
688+
xt::sort(xt::view(sorted_p1_hat, xt::range(2, 5)))));
689+
auto p2_hat_sorted = xt::concatenate(
690+
xt::xtuple(xt::sort(xt::view(sorted_p2_hat, xt::range(0, 2))),
691+
xt::sort(xt::view(sorted_p2_hat, xt::range(2, 5)))));
692+
693+
EXPECT_TRUE(xt::allclose(sorted_p1_asc, p1_hat_sorted, 0.01, 0.001))
694+
<< "Ascending p1 failed: " << sorted_p1_asc << std::endl
695+
<< p1_hat_sorted << std::endl;
696+
EXPECT_TRUE(xt::allclose(sorted_p2_asc, p2_hat_sorted, 0.01, 0.001))
697+
<< "Ascending p2 failed: " << sorted_p2_asc << std::endl
698+
<< p2_hat_sorted << std::endl;
699+
}
700+
});
701+
}
702+
464703
INSTANTIATE_TEST_SUITE_P(
465704
SimpleSort2PCTestInstances, SimpleSortTest,
466705
testing::Combine(testing::Values(2),

0 commit comments

Comments
 (0)