|
15 | 15 | #include "libspu/kernel/hlo/sort.h" |
16 | 16 |
|
17 | 17 | #include <algorithm> |
| 18 | +#include <limits> |
18 | 19 | #include <random> |
19 | 20 | #include <xtensor/xadapt.hpp> |
20 | 21 | #include <xtensor/xsort.hpp> |
@@ -461,6 +462,244 @@ TEST_P(SimpleSortTest, MixVisibilityKey) { |
461 | 462 | }); |
462 | 463 | } |
463 | 464 |
|
| 465 | +// Helper template to run unsigned type sort test |
| 466 | +template <typename T> |
| 467 | +void RunUnsignedSortTest(SPUContext *ctx) { |
| 468 | + xt::xarray<T> k1 = {7, 6, 5, 4, 1, 3, 2}; |
| 469 | + xt::xarray<T> sorted_k1 = {1, 2, 3, 4, 5, 6, 7}; |
| 470 | + xt::xarray<float> payload = {1, 2, 3, 6, 7, 6, 5}; |
| 471 | + xt::xarray<float> sorted_payload = {7, 5, 6, 6, 3, 2, 1}; |
| 472 | + |
| 473 | + Value k1_v = test::makeValue(ctx, k1, VIS_SECRET); |
| 474 | + Value payload_v = test::makeValue(ctx, payload, VIS_SECRET); |
| 475 | + |
| 476 | + std::vector<spu::Value> rets = |
| 477 | + SimpleSort(ctx, {k1_v, payload_v}, 0, hal::SortDirection::Ascending, 1); |
| 478 | + |
| 479 | + EXPECT_EQ(rets.size(), 2); |
| 480 | + |
| 481 | + auto sorted_k1_hat = hal::dump_public_as<T>(ctx, hal::reveal(ctx, rets[0])); |
| 482 | + auto sorted_payload_hat = |
| 483 | + hal::dump_public_as<float>(ctx, hal::reveal(ctx, rets[1])); |
| 484 | + |
| 485 | + EXPECT_TRUE(xt::allclose(sorted_k1, sorted_k1_hat, 0.01, 0.001)) |
| 486 | + << "sort failed: " << sorted_k1 << std::endl |
| 487 | + << sorted_k1_hat << std::endl; |
| 488 | + |
| 489 | + EXPECT_TRUE(xt::allclose(sorted_payload, sorted_payload_hat, 0.01, 0.001)) |
| 490 | + << "payload failed: " << sorted_payload << std::endl |
| 491 | + << sorted_payload_hat << std::endl; |
| 492 | +} |
| 493 | + |
| 494 | +TEST_P(SimpleSortTest, UnsignedTypeSort) { |
| 495 | + size_t npc = std::get<0>(GetParam()); |
| 496 | + FieldType field = std::get<1>(GetParam()); |
| 497 | + ProtocolKind prot = std::get<2>(GetParam()); |
| 498 | + RuntimeConfig::SortMethod method = std::get<3>(GetParam()); |
| 499 | + |
| 500 | + mpc::utils::simulate(npc, |
| 501 | + [&](const std::shared_ptr<yacl::link::Context> &lctx) { |
| 502 | + RuntimeConfig cfg; |
| 503 | + cfg.protocol = prot; |
| 504 | + cfg.field = field; |
| 505 | + cfg.enable_action_trace = false; |
| 506 | + cfg.sort_method = method; |
| 507 | + SPUContext ctx = test::makeSPUContext(cfg, lctx); |
| 508 | + |
| 509 | + RunUnsignedSortTest<uint8_t>(&ctx); |
| 510 | + RunUnsignedSortTest<uint16_t>(&ctx); |
| 511 | + RunUnsignedSortTest<uint32_t>(&ctx); |
| 512 | + if (field >= FieldType::FM64) { |
| 513 | + RunUnsignedSortTest<uint64_t>(&ctx); |
| 514 | + } |
| 515 | + }); |
| 516 | +} |
| 517 | + |
| 518 | +// Helper template to test signed interpretation of unsigned-range values |
| 519 | +// When values like 255 (for int8_t) are interpreted as signed, they become -1 |
| 520 | +// So sorting should treat them as negative numbers |
| 521 | +template <typename SignedT, typename UnsignedT> |
| 522 | +void RunSignedInterpretationSortTest(SPUContext *ctx) { |
| 523 | + // Use max value of unsigned type which becomes -1 when interpreted as signed |
| 524 | + constexpr UnsignedT max_val = std::numeric_limits<UnsignedT>::max(); |
| 525 | + // Key: {0, max_val} where max_val is interpreted as -1 in signed |
| 526 | + // For ascending sort with signed interpretation: -1 < 0, so max_val comes |
| 527 | + // first |
| 528 | + xt::xarray<SignedT> k1 = {0, static_cast<SignedT>(max_val)}; |
| 529 | + // Expected: max_val (-1) < 0, so sorted order is {max_val, 0} |
| 530 | + xt::xarray<SignedT> sorted_k1 = {static_cast<SignedT>(max_val), 0}; |
| 531 | + xt::xarray<float> payload = {1.0, 2.0}; |
| 532 | + xt::xarray<float> sorted_payload = {2.0, 1.0}; |
| 533 | + |
| 534 | + Value k1_v = test::makeValue(ctx, k1, VIS_SECRET); |
| 535 | + Value payload_v = test::makeValue(ctx, payload, VIS_SECRET); |
| 536 | + |
| 537 | + std::vector<spu::Value> rets = |
| 538 | + SimpleSort(ctx, {k1_v, payload_v}, 0, hal::SortDirection::Ascending, 1); |
| 539 | + |
| 540 | + EXPECT_EQ(rets.size(), 2); |
| 541 | + |
| 542 | + auto sorted_k1_hat = |
| 543 | + hal::dump_public_as<SignedT>(ctx, hal::reveal(ctx, rets[0])); |
| 544 | + auto sorted_payload_hat = |
| 545 | + hal::dump_public_as<float>(ctx, hal::reveal(ctx, rets[1])); |
| 546 | + |
| 547 | + EXPECT_TRUE(xt::allclose(sorted_k1, sorted_k1_hat, 0.01, 0.001)) |
| 548 | + << "sort failed: expected " << sorted_k1 << ", got " << sorted_k1_hat |
| 549 | + << std::endl; |
| 550 | + |
| 551 | + EXPECT_TRUE(xt::allclose(sorted_payload, sorted_payload_hat, 0.01, 0.001)) |
| 552 | + << "payload failed: expected " << sorted_payload << ", got " |
| 553 | + << sorted_payload_hat << std::endl; |
| 554 | +} |
| 555 | + |
| 556 | +// IMPORTANT: the user should ensure that the data has the correct signed or |
| 557 | +// unsigned type. Incorrect type interpretation will result in incorrect sort |
| 558 | +// order (for example, treating signed values as unsigned may place negative |
| 559 | +// numbers at the end instead of the beginning). |
| 560 | +TEST_P(SimpleSortTest, SignedInterpretationSort) { |
| 561 | + size_t npc = std::get<0>(GetParam()); |
| 562 | + FieldType field = std::get<1>(GetParam()); |
| 563 | + ProtocolKind prot = std::get<2>(GetParam()); |
| 564 | + RuntimeConfig::SortMethod method = std::get<3>(GetParam()); |
| 565 | + |
| 566 | + mpc::utils::simulate( |
| 567 | + npc, [&](const std::shared_ptr<yacl::link::Context> &lctx) { |
| 568 | + RuntimeConfig cfg; |
| 569 | + cfg.protocol = prot; |
| 570 | + cfg.field = field; |
| 571 | + cfg.enable_action_trace = false; |
| 572 | + cfg.sort_method = method; |
| 573 | + SPUContext ctx = test::makeSPUContext(cfg, lctx); |
| 574 | + |
| 575 | + // Test: data is uint8_t range but treated as int8_t |
| 576 | + // 255 (uint8_t) -> -1 (int8_t), so -1 < 0 |
| 577 | + RunSignedInterpretationSortTest<int8_t, uint8_t>(&ctx); |
| 578 | + |
| 579 | + // Test: data is uint16_t range but treated as int16_t |
| 580 | + // 65535 (uint16_t) -> -1 (int16_t), so -1 < 0 |
| 581 | + RunSignedInterpretationSortTest<int16_t, uint16_t>(&ctx); |
| 582 | + |
| 583 | + // Test: data is uint32_t range but treated as int32_t |
| 584 | + // 4294967295 (uint32_t) -> -1 (int32_t), so -1 < 0 |
| 585 | + RunSignedInterpretationSortTest<int32_t, uint32_t>(&ctx); |
| 586 | + |
| 587 | + if (field >= FieldType::FM64) { |
| 588 | + // Test: data is uint64_t range but treated as int64_t |
| 589 | + RunSignedInterpretationSortTest<int64_t, uint64_t>(&ctx); |
| 590 | + } |
| 591 | + }); |
| 592 | +} |
| 593 | + |
| 594 | +TEST_P(SimpleSortTest, BoolKeyWithPayloads) { |
| 595 | + size_t npc = std::get<0>(GetParam()); |
| 596 | + FieldType field = std::get<1>(GetParam()); |
| 597 | + ProtocolKind prot = std::get<2>(GetParam()); |
| 598 | + RuntimeConfig::SortMethod method = std::get<3>(GetParam()); |
| 599 | + |
| 600 | + mpc::utils::simulate( |
| 601 | + npc, [&](const std::shared_ptr<yacl::link::Context> &lctx) { |
| 602 | + RuntimeConfig cfg; |
| 603 | + cfg.protocol = prot; |
| 604 | + cfg.field = field; |
| 605 | + cfg.enable_action_trace = false; |
| 606 | + cfg.sort_method = method; |
| 607 | + |
| 608 | + SPUContext ctx = test::makeSPUContext(cfg, lctx); |
| 609 | + |
| 610 | + // Bool key with two payloads |
| 611 | + xt::xarray<bool> k1 = {true, false, true, false, true}; |
| 612 | + xt::xarray<float> p1 = {1.0, 2.0, 3.0, 4.0, 5.0}; |
| 613 | + xt::xarray<int32_t> p2 = {10, 20, 30, 40, 50}; |
| 614 | + |
| 615 | + // Expected sorted keys |
| 616 | + xt::xarray<bool> sorted_k1_desc = {true, true, true, false, false}; |
| 617 | + xt::xarray<bool> sorted_k1_asc = {false, false, true, true, true}; |
| 618 | + |
| 619 | + // Expected payloads (sorted within each group since sort is unstable) |
| 620 | + // Descending: true keys first {1,3,5}, then false keys {2,4} |
| 621 | + xt::xarray<float> sorted_p1_desc = {1.0, 3.0, 5.0, 2.0, 4.0}; |
| 622 | + xt::xarray<int32_t> sorted_p2_desc = {10, 30, 50, 20, 40}; |
| 623 | + // Ascending: false keys first {2,4}, then true keys {1,3,5} |
| 624 | + xt::xarray<float> sorted_p1_asc = {2.0, 4.0, 1.0, 3.0, 5.0}; |
| 625 | + xt::xarray<int32_t> sorted_p2_asc = {20, 40, 10, 30, 50}; |
| 626 | + |
| 627 | + Value k1_v = test::makeValue(&ctx, k1, VIS_SECRET); |
| 628 | + Value p1_v = test::makeValue(&ctx, p1, VIS_SECRET); |
| 629 | + Value p2_v = test::makeValue(&ctx, p2, VIS_SECRET); |
| 630 | + |
| 631 | + // Test descending sort (true before false) |
| 632 | + { |
| 633 | + std::vector<spu::Value> rets = SimpleSort( |
| 634 | + &ctx, {k1_v, p1_v, p2_v}, 0, hal::SortDirection::Descending, 1); |
| 635 | + |
| 636 | + EXPECT_EQ(rets.size(), 3); |
| 637 | + |
| 638 | + auto sorted_k1_hat = |
| 639 | + hal::dump_public_as<bool>(&ctx, hal::reveal(&ctx, rets[0])); |
| 640 | + auto sorted_p1_hat = |
| 641 | + hal::dump_public_as<float>(&ctx, hal::reveal(&ctx, rets[1])); |
| 642 | + auto sorted_p2_hat = |
| 643 | + hal::dump_public_as<int32_t>(&ctx, hal::reveal(&ctx, rets[2])); |
| 644 | + |
| 645 | + // Check bool key is sorted correctly |
| 646 | + EXPECT_TRUE(xt::allclose(sorted_k1_desc, sorted_k1_hat, 0.01, 0.001)) |
| 647 | + << "Bool descending sort failed: " << sorted_k1_desc << std::endl |
| 648 | + << sorted_k1_hat << std::endl; |
| 649 | + |
| 650 | + // Sort each part and compare (since sort is unstable within same key) |
| 651 | + auto p1_hat_sorted = xt::concatenate( |
| 652 | + xt::xtuple(xt::sort(xt::view(sorted_p1_hat, xt::range(0, 3))), |
| 653 | + xt::sort(xt::view(sorted_p1_hat, xt::range(3, 5))))); |
| 654 | + auto p2_hat_sorted = xt::concatenate( |
| 655 | + xt::xtuple(xt::sort(xt::view(sorted_p2_hat, xt::range(0, 3))), |
| 656 | + xt::sort(xt::view(sorted_p2_hat, xt::range(3, 5))))); |
| 657 | + |
| 658 | + EXPECT_TRUE(xt::allclose(sorted_p1_desc, p1_hat_sorted, 0.01, 0.001)) |
| 659 | + << "Descending p1 failed: " << sorted_p1_desc << std::endl |
| 660 | + << p1_hat_sorted << std::endl; |
| 661 | + EXPECT_TRUE(xt::allclose(sorted_p2_desc, p2_hat_sorted, 0.01, 0.001)) |
| 662 | + << "Descending p2 failed: " << sorted_p2_desc << std::endl |
| 663 | + << p2_hat_sorted << std::endl; |
| 664 | + } |
| 665 | + |
| 666 | + // Test ascending sort (false before true) |
| 667 | + { |
| 668 | + std::vector<spu::Value> rets = SimpleSort( |
| 669 | + &ctx, {k1_v, p1_v, p2_v}, 0, hal::SortDirection::Ascending, 1); |
| 670 | + |
| 671 | + EXPECT_EQ(rets.size(), 3); |
| 672 | + |
| 673 | + auto sorted_k1_hat = |
| 674 | + hal::dump_public_as<bool>(&ctx, hal::reveal(&ctx, rets[0])); |
| 675 | + auto sorted_p1_hat = |
| 676 | + hal::dump_public_as<float>(&ctx, hal::reveal(&ctx, rets[1])); |
| 677 | + auto sorted_p2_hat = |
| 678 | + hal::dump_public_as<int32_t>(&ctx, hal::reveal(&ctx, rets[2])); |
| 679 | + |
| 680 | + // Check bool key is sorted correctly |
| 681 | + EXPECT_TRUE(xt::allclose(sorted_k1_asc, sorted_k1_hat, 0.01, 0.001)) |
| 682 | + << "Bool ascending sort failed: " << sorted_k1_asc << std::endl |
| 683 | + << sorted_k1_hat << std::endl; |
| 684 | + |
| 685 | + // Sort each part and compare (since sort is unstable within same key) |
| 686 | + auto p1_hat_sorted = xt::concatenate( |
| 687 | + xt::xtuple(xt::sort(xt::view(sorted_p1_hat, xt::range(0, 2))), |
| 688 | + xt::sort(xt::view(sorted_p1_hat, xt::range(2, 5))))); |
| 689 | + auto p2_hat_sorted = xt::concatenate( |
| 690 | + xt::xtuple(xt::sort(xt::view(sorted_p2_hat, xt::range(0, 2))), |
| 691 | + xt::sort(xt::view(sorted_p2_hat, xt::range(2, 5))))); |
| 692 | + |
| 693 | + EXPECT_TRUE(xt::allclose(sorted_p1_asc, p1_hat_sorted, 0.01, 0.001)) |
| 694 | + << "Ascending p1 failed: " << sorted_p1_asc << std::endl |
| 695 | + << p1_hat_sorted << std::endl; |
| 696 | + EXPECT_TRUE(xt::allclose(sorted_p2_asc, p2_hat_sorted, 0.01, 0.001)) |
| 697 | + << "Ascending p2 failed: " << sorted_p2_asc << std::endl |
| 698 | + << p2_hat_sorted << std::endl; |
| 699 | + } |
| 700 | + }); |
| 701 | +} |
| 702 | + |
464 | 703 | INSTANTIATE_TEST_SUITE_P( |
465 | 704 | SimpleSort2PCTestInstances, SimpleSortTest, |
466 | 705 | testing::Combine(testing::Values(2), |
|
0 commit comments