Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions cpp/include/rapidsmpf/communicator/ucxx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,70 @@ class UCXX final : public Communicator {
Rank rank, Tag tag, std::unique_ptr<Buffer> recv_buffer
) override;

/**
* Send a message with a callback.
*
* @param msg The message to send.
* @param rank The rank to send the message to.
* @param tag The tag to send the message with.
* @param send_cb The callback to call when the message is sent. The send buffer will
* be delivered into the callback at the end of send operation.
*
* @return A future representing the send operation.
* @note The returned future would not contain any data. Therefore, can not be used in
* `get_gpu_data` method. Use the callback to get the data.
*
* @throws std::runtime_error if the send operation fails.
*/
std::unique_ptr<Communicator::Future> send_with_cb(
std::unique_ptr<Buffer> msg,
Rank rank,
Tag tag,
std::function<void(std::unique_ptr<Buffer>)> send_cb
);

/**
* Receive a message with a callback.
*
* @param rank The rank to receive the message from.
* @param tag The tag to receive the message with.
* @param recv_buffer The buffer to receive the message into.
* @param recv_cb The callback to call when the message is received. The receive
* buffer will be delivered into the callback at the end of receive operation.
*
* @return A future representing the receive operation.
* @note The returned future would not contain any data. Therefore, can not be used in
* `get_gpu_data` method. Use the callback to get the data.
*
* @throws std::runtime_error if the receive operation fails.
*/
std::unique_ptr<Communicator::Future> recv_with_cb(
Rank rank,
Tag tag,
std::unique_ptr<Buffer> recv_buffer,
std::function<void(std::unique_ptr<Buffer>)> recv_cb
);

/**
* Receive a message with a callback.
*
* @param tag The tag to receive the message with.
* @param recv_cb The callback to call when the message is received. The receive
* buffer will be delivered into the callback at the end of receive operation.
* @param br The buffer resource to allocate the receive buffer.
*
* @return A future representing the receive operation.
* @note The returned future would not contain any data. Therefore, can not be used in
* `get_gpu_data` method. Use the callback to get the data.
*
* @throws std::runtime_error if the receive operation fails.
*/
std::unique_ptr<Communicator::Future> recv_any_with_cb(
Tag tag,
std::function<void(std::unique_ptr<Buffer>, Rank)> recv_cb,
BufferResource* br
);

/**
* @copydoc Communicator::recv_any
*
Expand Down
92 changes: 92 additions & 0 deletions cpp/src/communicator/ucxx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1288,6 +1288,98 @@ std::shared_ptr<UCXX> UCXX::split() {
return std::make_shared<UCXX>(std::move(initialized_rank), options_);
}

namespace {

/// @brief Temporary data structure to hold the message and deliver it to the callback.
struct CallbackData {
std::unique_ptr<Buffer> msg;
};

} // namespace

std::unique_ptr<Communicator::Future> UCXX::send_with_cb(
std::unique_ptr<Buffer> msg,
Rank rank,
Tag tag,
std::function<void(std::unique_ptr<Buffer>)> send_cb
) {
if (!msg->is_ready()) {
logger().warn("msg is not ready. This is irrecoverable, terminating.");
std::terminate();
}
auto data_ptr = msg->data();
auto data_size = msg->size;
auto req = get_endpoint(rank)->tagSend(
data_ptr,
data_size,
tag_with_rank(shared_resources_->rank(), tag),
false,
[cb = std::move(send_cb)](ucs_status_t status, std::shared_ptr<void> data) {
RAPIDSMPF_EXPECTS(status == UCS_OK, "UCXX send failed", std::runtime_error);
cb(std::move(static_pointer_cast<CallbackData>(data)->msg));
},
std::make_shared<CallbackData>(std::move(msg))
);
return std::make_unique<Future>(req, nullptr);
}

std::unique_ptr<Communicator::Future> UCXX::recv_with_cb(
Rank rank,
Tag tag,
std::unique_ptr<Buffer> recv_buffer,
std::function<void(std::unique_ptr<Buffer>)> recv_cb
) {
if (!recv_buffer->is_ready()) {
logger().warn("recv_buffer is not ready. This is irrecoverable, terminating.");
std::terminate();
}
auto data_ptr = recv_buffer->data();
auto data_size = recv_buffer->size;
auto req = get_endpoint(rank)->tagRecv(
data_ptr,
data_size,
tag_with_rank(rank, tag),
::ucxx::TagMaskFull,
false,
[cb = std::move(recv_cb)](ucs_status_t status, std::shared_ptr<void> data) {
RAPIDSMPF_EXPECTS(status == UCS_OK, "UCXX recv failed", std::runtime_error);
cb(std::move(static_pointer_cast<CallbackData>(data)->msg));
},
std::make_shared<CallbackData>(std::move(recv_buffer))
);
return std::make_unique<Future>(req, nullptr);
}

std::unique_ptr<Communicator::Future> UCXX::recv_any_with_cb(
Tag tag,
std::function<void(std::unique_ptr<Buffer>, Rank)> recv_cb,
BufferResource* br
) {
auto [msg_available, info] = shared_resources_->get_worker()->tagProbe(
::ucxx::Tag(static_cast<int>(tag)), UserTagMask
);
Comment on lines +1358 to +1360
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally we would use the new implementation from rapidsai/ucxx#458, thus we won't need to go through the receive queue twice, but we need someone to review that PR.


if (!msg_available) {
return nullptr;
}

auto sender_rank = static_cast<Rank>(info.senderTag >> 32);

// Create a buffer to receive the message
auto msg = std::make_unique<std::vector<uint8_t>>(info.length);

// Receive the message
return recv_with_cb(
sender_rank,
tag,
br->move(std::move(msg)),
[cb = std::move(recv_cb), sender_rank](std::unique_ptr<Buffer> msg) {
cb(std::move(msg), sender_rank);
}
);
}


} // namespace ucxx

} // namespace rapidsmpf
Loading
Loading