Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cmake/oomph_ucx.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ if (OOMPH_WITH_UCX)
target_compile_definitions(oomph_ucx PRIVATE OOMPH_UCX_USE_SPIN_LOCK)
endif()

# Cache option controlling the endpoint layout of the UCX backend (defaults to ON).
# When enabled, OOMPH_UCX_USE_MULTIPLE_ENDPOINTS is added as a PRIVATE compile
# definition of the oomph_ucx target, so only that target's own sources see the macro
# and take the code paths guarded by `#ifdef OOMPH_UCX_USE_MULTIPLE_ENDPOINTS`.
set(OOMPH_UCX_USE_MULTIPLE_ENDPOINTS ON CACHE BOOL "use one shared recv endpoint, and one send endpoint per thread")
if (OOMPH_UCX_USE_MULTIPLE_ENDPOINTS)
target_compile_definitions(oomph_ucx PRIVATE OOMPH_UCX_USE_MULTIPLE_ENDPOINTS)
endif()

install(TARGETS oomph_ucx
EXPORT oomph-targets
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
Expand Down
17 changes: 17 additions & 0 deletions src/ucx/communicator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,19 +61,23 @@ class communicator_impl : public communicator_base<communicator_impl>

~communicator_impl()
{
#ifdef OOMPH_UCX_USE_MULTIPLE_ENDPOINTS
// schedule all endpoints for closing
for (auto& kvp : m_send_worker->m_endpoint_cache)
{
m_send_worker->m_endpoint_handles.push_back(kvp.second.close());
m_send_worker->m_endpoint_handles.back().progress();
}
#endif
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the single-endpoint code path is missing here

}

// Forwards to the context's get_heap() and hands back the very same reference.
auto& get_heap() noexcept
{
    auto& heap = m_context->get_heap();
    return heap;
}

void progress()
{
while (ucp_worker_progress(m_send_worker->get())) {}

#ifdef OOMPH_UCX_USE_MULTIPLE_ENDPOINTS
if (m_thread_safe)
{
#ifdef OOMPH_UCX_USE_SPIN_LOCK
Expand All @@ -90,6 +94,7 @@ class communicator_impl : public communicator_base<communicator_impl>
{
while (ucp_worker_progress(m_recv_worker->get())) {}
}
#endif
// work through ready recv callbacks, which were pushed to the queue by other threads
// (including this thread)
if (m_thread_safe)
Expand Down Expand Up @@ -158,8 +163,10 @@ class communicator_impl : public communicator_base<communicator_impl>
// pointer to store callback in case of early completion
request_data::cb_ptr_t cb_ptr = nullptr;
{
#ifdef OOMPH_UCX_USE_MULTIPLE_ENDPOINTS
// locked region
if (m_thread_safe) m_mutex.lock();
#endif

ucs_status_ptr_t ret;
{
Expand Down Expand Up @@ -203,7 +210,9 @@ class communicator_impl : public communicator_base<communicator_impl>
throw std::runtime_error("oomph: ucx error - recv operation failed");
}

#ifdef OOMPH_UCX_USE_MULTIPLE_ENDPOINTS
if (m_thread_safe) m_mutex.unlock();
#endif
}
// check for early completion
if (cb_ptr)
Expand Down Expand Up @@ -283,16 +292,24 @@ class communicator_impl : public communicator_base<communicator_impl>
auto& req_data = request_data::get(req.m_data->m_data);
{
// locked region
#ifdef OOMPH_UCX_USE_MULTIPLE_ENDPOINTS
if (m_thread_safe) m_mutex.lock();
#endif
ucp_request_cancel(m_recv_worker->get(), req_data.m_ucx_ptr);
#ifdef OOMPH_UCX_USE_MULTIPLE_ENDPOINTS
if (m_thread_safe) m_mutex.unlock();
#endif
}
// The ucx callback will still be executed after the cancel. However, the status argument
// will indicate whether the cancel was successful.
// Progress the receive worker in order to execute the ucx callback
#ifdef OOMPH_UCX_USE_MULTIPLE_ENDPOINTS
if (m_thread_safe) m_mutex.lock();
#endif
while (ucp_worker_progress(m_recv_worker->get())) {}
#ifdef OOMPH_UCX_USE_MULTIPLE_ENDPOINTS
if (m_thread_safe) m_mutex.unlock();
#endif
// check whether the cancelled callback was enqueued by consuming all queued cancelled
// callbacks and putting them in a temporary vector
bool found = false;
Expand Down
7 changes: 7 additions & 0 deletions src/ucx/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,14 @@ class context_impl : public context_base
// make shared worker
// use single-threaded UCX mode, as per developer advice
// https://github.com/openucx/ucx/issues/4609
#ifdef OOMPH_UCX_USE_MULTIPLE_ENDPOINTS
m_worker.reset(new worker_type{get(), m_db, UCS_THREAD_MODE_SINGLE});
#else
if(this->m_thread_safe)
m_worker.reset(new worker_type{get(), m_db, UCS_THREAD_MODE_MULTI});
else
m_worker.reset(new worker_type{get(), m_db, UCS_THREAD_MODE_SINGLE});
#endif

// initialize database
m_db.init(m_worker->address());
Expand Down
9 changes: 9 additions & 0 deletions src/ucx/src.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace oomph
communicator_impl*
context_impl::get_communicator()
{
#ifdef OOMPH_UCX_USE_MULTIPLE_ENDPOINTS
auto send_worker = std::make_unique<worker_type>(get(), m_db,
(m_thread_safe ? UCS_THREAD_MODE_SERIALIZED : UCS_THREAD_MODE_SINGLE));
auto send_worker_ptr = send_worker.get();
Expand All @@ -33,6 +34,10 @@ context_impl::get_communicator()
}
auto comm =
new communicator_impl{this, m_thread_safe, m_worker.get(), send_worker_ptr, m_mutex};
#else
auto comm =
new communicator_impl{this, m_thread_safe, m_worker.get(), m_worker.get(), m_mutex};
#endif
m_comms_set.insert(comm);
return comm;
}
Expand All @@ -46,6 +51,7 @@ context_impl::~context_impl()
double elapsed = 0.0;
static constexpr double t_timeout = 1000;

#ifdef OOMPH_UCX_USE_MULTIPLE_ENDPOINTS
// close endpoints while also progressing the receive worker
std::vector<endpoint_t::close_handle> handles;
for (auto& w_ptr : m_workers)
Expand Down Expand Up @@ -75,6 +81,7 @@ context_impl::~context_impl()
// free all requests for the unclosed endpoints
for (auto& h : handles) ucp_request_free(h.m_status);
}
#endif

// issue another non-blocking barrier while progressing the receive worker in order to flush all
// remaining (remote) endpoints which are connected to this receive worker
Expand All @@ -89,7 +96,9 @@ context_impl::~context_impl()
}

// receive worker should not have connected to any endpoint
#ifdef OOMPH_UCX_USE_MULTIPLE_ENDPOINTS
assert(m_worker->m_endpoint_cache.size() == 0);
#endif

// another MPI barrier to be sure
MPI_Barrier(m_mpi_comm);
Expand Down