diff --git a/cpp/include/rapidsmpf/communicator/ucxx.hpp b/cpp/include/rapidsmpf/communicator/ucxx.hpp index aa06911a7..8322d2256 100644 --- a/cpp/include/rapidsmpf/communicator/ucxx.hpp +++ b/cpp/include/rapidsmpf/communicator/ucxx.hpp @@ -169,6 +169,70 @@ class UCXX final : public Communicator { Rank rank, Tag tag, std::unique_ptr recv_buffer ) override; + /** + * Send a message with a callback. + * + * @param msg The message to send. It must be ready; otherwise a warning is logged and std::terminate is called. + * @param rank The rank to send the message to. + * @param tag The tag to send the message with. + * @param send_cb The callback to call when the message is sent. The send buffer is + * handed back to the callback once the send operation completes. + * + * @return A future representing the send operation. + * @note The returned future does not contain any data and therefore cannot be used in the + * `get_gpu_data` method. Use the callback to get the data. + * + * @throws std::runtime_error if the send operation fails (the completion callback checks for UCS_OK). + */ + std::unique_ptr send_with_cb( + std::unique_ptr msg, + Rank rank, + Tag tag, + std::function)> send_cb + ); + + /** + * Receive a message with a callback. + * + * @param rank The rank to receive the message from. + * @param tag The tag to receive the message with. + * @param recv_buffer The buffer to receive the message into. It must be ready; otherwise a warning is logged and std::terminate is called. + * @param recv_cb The callback to call when the message is received. The receive + * buffer is handed to the callback once the receive operation completes. + * + * @return A future representing the receive operation. + * @note The returned future does not contain any data and therefore cannot be used in the + * `get_gpu_data` method. Use the callback to get the data. + * + * @throws std::runtime_error if the receive operation fails (the completion callback checks for UCS_OK). + */ + std::unique_ptr recv_with_cb( + Rank rank, + Tag tag, + std::unique_ptr recv_buffer, + std::function)> recv_cb + ); + + /** + * Probe for a message with the given tag and, if one is available, receive it with a callback. + * + * @param tag The tag to receive the message with. 
+ * @param recv_cb The callback to call when the message is received. The receive + * buffer and the rank of the sender are passed to the callback at the end of the receive operation. + * @param br The buffer resource to allocate the receive buffer. + * + * @return A future representing the receive operation, or nullptr if the probe finds no matching message. + * @note The returned future does not contain any data and therefore cannot be used in the + * `get_gpu_data` method. Use the callback to get the data. + * + * @throws std::runtime_error if the receive operation fails. + */ + std::unique_ptr recv_any_with_cb( + Tag tag, + std::function, Rank)> recv_cb, + BufferResource* br + ); + /** * @copydoc Communicator::recv_any * diff --git a/cpp/src/communicator/ucxx.cpp b/cpp/src/communicator/ucxx.cpp index e883f458e..e8d04de49 100644 --- a/cpp/src/communicator/ucxx.cpp +++ b/cpp/src/communicator/ucxx.cpp @@ -1288,6 +1288,98 @@ std::shared_ptr UCXX::split() { return std::make_shared(std::move(initialized_rank), options_); } +namespace { + +/// @brief Temporary data structure to hold the message and deliver it to the callback. +struct CallbackData { + std::unique_ptr msg; +}; + +} // namespace + +std::unique_ptr UCXX::send_with_cb( + std::unique_ptr msg, + Rank rank, + Tag tag, + std::function)> send_cb +) { + if (!msg->is_ready()) { + logger().warn("msg is not ready. 
This is irrecoverable, terminating."); + std::terminate(); + } + auto data_ptr = msg->data(); + auto data_size = msg->size; + auto req = get_endpoint(rank)->tagSend( + data_ptr, + data_size, + tag_with_rank(shared_resources_->rank(), tag), + false, + [cb = std::move(send_cb)](ucs_status_t status, std::shared_ptr data) { + RAPIDSMPF_EXPECTS(status == UCS_OK, "UCXX send failed", std::runtime_error); + cb(std::move(static_pointer_cast(data)->msg)); + }, + std::make_shared(std::move(msg)) + ); + return std::make_unique(req, nullptr); +} + +std::unique_ptr UCXX::recv_with_cb( + Rank rank, + Tag tag, + std::unique_ptr recv_buffer, + std::function)> recv_cb +) { + if (!recv_buffer->is_ready()) { + logger().warn("recv_buffer is not ready. This is irrecoverable, terminating."); + std::terminate(); + } + auto data_ptr = recv_buffer->data(); + auto data_size = recv_buffer->size; + auto req = get_endpoint(rank)->tagRecv( + data_ptr, + data_size, + tag_with_rank(rank, tag), + ::ucxx::TagMaskFull, + false, + [cb = std::move(recv_cb)](ucs_status_t status, std::shared_ptr data) { + RAPIDSMPF_EXPECTS(status == UCS_OK, "UCXX recv failed", std::runtime_error); + cb(std::move(static_pointer_cast(data)->msg)); + }, + std::make_shared(std::move(recv_buffer)) + ); + return std::make_unique(req, nullptr); +} + +std::unique_ptr UCXX::recv_any_with_cb( + Tag tag, + std::function, Rank)> recv_cb, + BufferResource* br +) { + auto [msg_available, info] = shared_resources_->get_worker()->tagProbe( + ::ucxx::Tag(static_cast(tag)), UserTagMask + ); + + if (!msg_available) { + return nullptr; + } + + auto sender_rank = static_cast(info.senderTag >> 32); + + // Create a buffer sized to the probed message length + auto msg = std::make_unique>(info.length); + + // Receive the message, forwarding the probed sender rank to the user callback + return recv_with_cb( + sender_rank, + tag, + br->move(std::move(msg)), + [cb = std::move(recv_cb), sender_rank](std::unique_ptr msg) { + cb(std::move(msg), sender_rank); + } + ); +} + + } // namespace ucxx } // namespace rapidsmpf 
diff --git a/cpp/tests/test_communicator.cpp b/cpp/tests/test_communicator.cpp index 25da4382b..31b9a203b 100644 --- a/cpp/tests/test_communicator.cpp +++ b/cpp/tests/test_communicator.cpp @@ -5,6 +5,8 @@ #include #include +#include +#include #include #include @@ -13,6 +15,8 @@ #include #include #include +#include +#include #include "environment.hpp" #include "utils.hpp" @@ -33,7 +37,9 @@ class BaseCommunicatorTest : public ::testing::Test { GlobalEnvironment->barrier(); } - virtual rapidsmpf::MemoryType memory_type() = 0; + virtual rapidsmpf::MemoryType memory_type() { + return rapidsmpf::MemoryType::DEVICE; + } rapidsmpf::Communicator* comm; std::unique_ptr mr; @@ -88,3 +94,225 @@ TEST_P(BasicCommunicatorTest, SendToSelf) { stream.synchronize(); EXPECT_EQ(send_data_h, *recv_data_h); } + +using namespace rapidsmpf; + +TEST_F(BaseCommunicatorTest, UcxxTagSendRecvCb) { + if (GlobalEnvironment->type() != TestEnvironmentType::UCXX) { + GTEST_SKIP() << "UCXX only"; + } + + if (comm->nranks() < 2) { + GTEST_SKIP() << "Need at least 2 ranks"; + } + auto ucx_comm = static_cast(comm); + Tag const ready_for_data_tag{0, 1}; + Tag const metadata_tag{0, 2}; + Tag const gpu_data_tag{0, 3}; + + constexpr size_t nelems{8}; + constexpr shuffler::detail::ChunkID chunk_id{100}; + constexpr shuffler::PartID part_id{100}; + + // Create dummy metadata and data + auto metadata = iota_vector(nelems, 100); // Start from 100 + auto data = iota_vector(nelems, 200); // Start from 200 + + // Create PackedData using the helper function + auto packed_data = create_packed_data(metadata, data, stream, br.get()); + + auto chunk = shuffler::detail::Chunk::from_packed_data( + chunk_id, part_id, std::move(packed_data) + ); + + std::vector> futures; + + if (comm->rank() == 0) { + // serialize the chunk for sending to rank 1 + auto serialized_metadata = chunk.serialize(); + + // send metadata to rank 1 + futures.emplace_back(ucx_comm->send( + std::move(serialized_metadata), Rank(1), metadata_tag, br.get() + )); + + // 
recive ready for data from rank 1 + auto ready_for_data = std::make_unique>( + shuffler::detail::ReadyForDataMessage::byte_size + ); + + futures.emplace_back(ucx_comm->recv_with_cb( + Rank(1), + ready_for_data_tag, + br->move(std::move(ready_for_data)), + [&](std::unique_ptr buf) { + auto const& host_buf = br->move_to_host_vector(std::move(buf)); + EXPECT_EQ( + host_buf->size(), shuffler::detail::ReadyForDataMessage::byte_size + ); + + shuffler::detail::ChunkID cid; + std::memcpy(&cid, host_buf->data(), sizeof(cid)); + EXPECT_EQ(cid, chunk_id); + + auto data_buf = chunk.release_data_buffer(); + data_buf->wait_for_ready(); + futures.emplace_back( /* NOTE(review): this callback appends to futures, captured by reference; confirm it cannot run while test_some(futures) is still using the vector */ + ucx_comm->send(std::move(data_buf), Rank(1), gpu_data_tag) + ); + } + )); + } else if (comm->rank() == 1) { + std::unique_ptr> recv_buf; + Rank sender_rank; + + // poll recv_any until the metadata from rank 0 arrives + while (!recv_buf) { + std::tie(recv_buf, sender_rank) = ucx_comm->recv_any(metadata_tag); + } + EXPECT_EQ(sender_rank, 0); + auto chunk = shuffler::detail::Chunk::deserialize(*recv_buf); + EXPECT_EQ(chunk.chunk_id(), chunk_id); + + // reserve and allocate a device buffer for the incoming data + auto [reservation, ob] = + br->reserve(MemoryType::DEVICE, chunk.concat_data_size(), false); + auto data_buf = br->allocate(chunk.concat_data_size(), stream, reservation); + data_buf->wait_for_ready(); + + // post the receive for the data from rank 0 before announcing readiness + futures.emplace_back(ucx_comm->recv_with_cb( + sender_rank, + gpu_data_tag, + std::move(data_buf), + [data_size = chunk.concat_data_size()](std::unique_ptr buf) { + EXPECT_EQ(data_size, buf->size); + } + )); + + // send the ready-for-data message (carrying chunk_id) to rank 0 + auto ready_for_data = std::make_unique>( + shuffler::detail::ReadyForDataMessage::byte_size + ); + std::memcpy(ready_for_data->data(), &chunk_id, sizeof(chunk_id)); + futures.emplace_back(ucx_comm->send( + br->move(std::move(ready_for_data)), Rank(0), ready_for_data_tag + )); + } // other ranks do nothing + + while (!futures.empty()) { // loop until futures is empty (presumably test_some removes completed ones; confirm) + std::ignore = ucx_comm->test_some(futures); + } +} + 
+TEST_F(BaseCommunicatorTest, UcxxTagSendRecvCb2) { + if (GlobalEnvironment->type() != TestEnvironmentType::UCXX) { + GTEST_SKIP() << "UCXX only"; + } + + if (comm->nranks() < 2) { + GTEST_SKIP() << "Need at least 2 ranks"; + } + auto ucx_comm = static_cast(comm); + Tag const ready_for_data_tag{0, 1}; + Tag const metadata_tag{0, 2}; + Tag const gpu_data_tag{0, 3}; + + constexpr size_t nelems{8}; + constexpr shuffler::detail::ChunkID chunk_id{100}; + constexpr shuffler::PartID part_id{100}; + + // Create dummy metadata and data + auto metadata = iota_vector(nelems, 100); // Start from 100 + auto data = iota_vector(nelems, 200); // Start from 200 + + // Create PackedData using the helper function + auto packed_data = create_packed_data(metadata, data, stream, br.get()); + + auto chunk = shuffler::detail::Chunk::from_packed_data( + chunk_id, part_id, std::move(packed_data) + ); + + std::vector> futures; + + if (comm->rank() == 0) { + // serialize the chunk for sending to rank 1 + auto serialized_metadata = chunk.serialize(); + + // send metadata to rank 1 + futures.emplace_back(ucx_comm->send( + std::move(serialized_metadata), Rank(1), metadata_tag, br.get() + )); + + // receive ready for data from rank 1 + auto ready_for_data = std::make_unique>( + shuffler::detail::ReadyForDataMessage::byte_size + ); + + futures.emplace_back(ucx_comm->recv_with_cb( + Rank(1), + ready_for_data_tag, + br->move(std::move(ready_for_data)), + [&](std::unique_ptr buf) { + auto const& host_buf = br->move_to_host_vector(std::move(buf)); + EXPECT_EQ( + host_buf->size(), shuffler::detail::ReadyForDataMessage::byte_size + ); + + shuffler::detail::ChunkID cid; + std::memcpy(&cid, host_buf->data(), sizeof(cid)); + EXPECT_EQ(cid, chunk_id); + + auto data_buf = chunk.release_data_buffer(); + data_buf->wait_for_ready(); + futures.emplace_back( /* NOTE(review): this callback appends to futures, captured by reference; confirm it cannot run while test_some(futures) is still using the vector */ + ucx_comm->send(std::move(data_buf), Rank(1), gpu_data_tag) + ); + } + )); + } else if (comm->rank() == 1) { + auto recv_any_cb = [&](std::unique_ptr buf, Rank sender_rank) { 
+ auto const& recv_buf = br->move_to_host_vector(std::move(buf)); + EXPECT_EQ(sender_rank, 0); + auto chunk = shuffler::detail::Chunk::deserialize(*recv_buf); + EXPECT_EQ(chunk.chunk_id(), chunk_id); + + // reserve and allocate a device buffer for the incoming data + auto [reservation, ob] = + br->reserve(MemoryType::DEVICE, chunk.concat_data_size(), false); + auto data_buf = br->allocate(chunk.concat_data_size(), stream, reservation); + data_buf->wait_for_ready(); + + // post the receive for the data from rank 0 before announcing readiness + futures.emplace_back(ucx_comm->recv_with_cb( + sender_rank, + gpu_data_tag, + std::move(data_buf), + [data_size = chunk.concat_data_size()](std::unique_ptr buf) { + EXPECT_EQ(data_size, buf->size); + } + )); + + // send the ready-for-data message (carrying chunk_id) to rank 0 + auto ready_for_data = std::make_unique>( + shuffler::detail::ReadyForDataMessage::byte_size + ); + std::memcpy(ready_for_data->data(), &chunk_id, sizeof(chunk_id)); + futures.emplace_back(ucx_comm->send( + br->move(std::move(ready_for_data)), Rank(0), ready_for_data_tag + )); + }; + + while (true) { // poll until recv_any_with_cb returns a non-null future + auto fut = ucx_comm->recv_any_with_cb(metadata_tag, recv_any_cb, br.get()); + if (fut) { + futures.emplace_back(std::move(fut)); + break; + } + } + } // other ranks do nothing + + while (!futures.empty()) { // loop until futures is empty (presumably test_some removes completed ones; confirm) + std::ignore = ucx_comm->test_some(futures); + } +}