diff --git a/cmake/oomph_ucx.cmake b/cmake/oomph_ucx.cmake
index ad6155a2..a7ed4355 100644
--- a/cmake/oomph_ucx.cmake
+++ b/cmake/oomph_ucx.cmake
@@ -26,6 +26,11 @@ if (OOMPH_WITH_UCX)
         target_compile_definitions(oomph_ucx PRIVATE OOMPH_UCX_USE_SPIN_LOCK)
     endif()
 
+    set(OOMPH_UCX_USE_MULTIPLE_ENDPOINTS ON CACHE BOOL "use one shared recv endpoint, and one send endpoint per thread")
+    if (OOMPH_UCX_USE_MULTIPLE_ENDPOINTS)
+        target_compile_definitions(oomph_ucx PRIVATE OOMPH_UCX_USE_MULTIPLE_ENDPOINTS)
+    endif()
+
     install(TARGETS oomph_ucx
         EXPORT oomph-targets
         LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
diff --git a/src/ucx/communicator.hpp b/src/ucx/communicator.hpp
index c6e03712..12521d6f 100644
--- a/src/ucx/communicator.hpp
+++ b/src/ucx/communicator.hpp
@@ -61,12 +61,14 @@ class communicator_impl : public communicator_base
 
     ~communicator_impl()
     {
+#ifdef OOMPH_UCX_USE_MULTIPLE_ENDPOINTS
         // schedule all endpoints for closing
         for (auto& kvp : m_send_worker->m_endpoint_cache)
         {
             m_send_worker->m_endpoint_handles.push_back(kvp.second.close());
             m_send_worker->m_endpoint_handles.back().progress();
         }
+#endif
     }
 
     auto& get_heap() noexcept { return m_context->get_heap(); }
@@ -74,6 +76,8 @@ class communicator_impl : public communicator_base
     void progress()
     {
         while (ucp_worker_progress(m_send_worker->get())) {}
+
+#ifdef OOMPH_UCX_USE_MULTIPLE_ENDPOINTS
         if (m_thread_safe)
         {
 #ifdef OOMPH_UCX_USE_SPIN_LOCK
@@ -90,6 +94,7 @@ class communicator_impl : public communicator_base
         {
             while (ucp_worker_progress(m_recv_worker->get())) {}
         }
+#endif
         // work through ready recv callbacks, which were pushed to the queue by other threads
         // (including this thread)
         if (m_thread_safe)
@@ -158,8 +163,10 @@ class communicator_impl : public communicator_base
         // pointer to store callback in case of early completion
         request_data::cb_ptr_t cb_ptr = nullptr;
         {
+#ifdef OOMPH_UCX_USE_MULTIPLE_ENDPOINTS
             // locked region
             if (m_thread_safe) m_mutex.lock();
+#endif
 
             ucs_status_ptr_t ret;
             {
@@ -203,7 +210,9 @@ class communicator_impl : public communicator_base
                 throw std::runtime_error("oomph: ucx error - recv operation failed");
             }
 
+#ifdef OOMPH_UCX_USE_MULTIPLE_ENDPOINTS
             if (m_thread_safe) m_mutex.unlock();
+#endif
         }
         // check for early completion
         if (cb_ptr)
@@ -283,16 +292,24 @@ class communicator_impl : public communicator_base
         auto& req_data = request_data::get(req.m_data->m_data);
         {
             // locked region
+#ifdef OOMPH_UCX_USE_MULTIPLE_ENDPOINTS
             if (m_thread_safe) m_mutex.lock();
+#endif
             ucp_request_cancel(m_recv_worker->get(), req_data.m_ucx_ptr);
+#ifdef OOMPH_UCX_USE_MULTIPLE_ENDPOINTS
             if (m_thread_safe) m_mutex.unlock();
+#endif
         }
         // The ucx callback will still be executed after the cancel. However, the status argument
         // will indicate whether the cancel was successful.
         // Progress the receive worker in order to execute the ucx callback
+#ifdef OOMPH_UCX_USE_MULTIPLE_ENDPOINTS
         if (m_thread_safe) m_mutex.lock();
+#endif
         while (ucp_worker_progress(m_recv_worker->get())) {}
+#ifdef OOMPH_UCX_USE_MULTIPLE_ENDPOINTS
         if (m_thread_safe) m_mutex.unlock();
+#endif
         // check whether the cancelled callback was enqueued by consuming all queued cancelled
         // callbacks and putting them in a temporary vector
         bool found = false;
diff --git a/src/ucx/context.hpp b/src/ucx/context.hpp
index f89d91c6..9b715d5c 100644
--- a/src/ucx/context.hpp
+++ b/src/ucx/context.hpp
@@ -126,7 +126,14 @@ class context_impl : public context_base
         // make shared worker
         // use single-threaded UCX mode, as per developer advice
         // https://github.com/openucx/ucx/issues/4609
+#ifdef OOMPH_UCX_USE_MULTIPLE_ENDPOINTS
         m_worker.reset(new worker_type{get(), m_db, UCS_THREAD_MODE_SINGLE});
+#else
+        if(this->m_thread_safe)
+            m_worker.reset(new worker_type{get(), m_db, UCS_THREAD_MODE_MULTI});
+        else
+            m_worker.reset(new worker_type{get(), m_db, UCS_THREAD_MODE_SINGLE});
+#endif
 
         // intialize database
         m_db.init(m_worker->address());
diff --git a/src/ucx/src.cpp b/src/ucx/src.cpp
index 05fd1823..b2043a6f 100644
--- a/src/ucx/src.cpp
+++ b/src/ucx/src.cpp
@@ -19,6 +19,7 @@ namespace oomph
 communicator_impl*
 context_impl::get_communicator()
 {
+#ifdef OOMPH_UCX_USE_MULTIPLE_ENDPOINTS
     auto send_worker = std::make_unique<worker_type>(get(), m_db,
         (m_thread_safe ? UCS_THREAD_MODE_SERIALIZED : UCS_THREAD_MODE_SINGLE));
     auto send_worker_ptr = send_worker.get();
@@ -33,6 +34,10 @@
     }
     auto comm =
         new communicator_impl{this, m_thread_safe, m_worker.get(), send_worker_ptr, m_mutex};
+#else
+    auto comm =
+        new communicator_impl{this, m_thread_safe, m_worker.get(), m_worker.get(), m_mutex};
+#endif
     m_comms_set.insert(comm);
     return comm;
 }
@@ -46,6 +51,7 @@ context_impl::~context_impl()
     double elapsed = 0.0;
     static constexpr double t_timeout = 1000;
 
+#ifdef OOMPH_UCX_USE_MULTIPLE_ENDPOINTS
     // close endpoints while also progressing the receive worker
     std::vector<endpoint_handle> handles;
     for (auto& w_ptr : m_workers)
@@ -75,6 +81,7 @@ context_impl::~context_impl()
         // free all requests for the unclosed endpoints
         for (auto& h : handles) ucp_request_free(h.m_status);
     }
+#endif
 
     // issue another non-blocking barrier while progressing the receive worker in order to flush all
     // remaining (remote) endpoints which are connected to this receive worker
@@ -89,7 +96,9 @@ context_impl::~context_impl()
     }
 
     // receive worker should not have connected to any endpoint
+#ifdef OOMPH_UCX_USE_MULTIPLE_ENDPOINTS
     assert(m_worker->m_endpoint_cache.size() == 0);
+#endif
 
     // another MPI barrier to be sure
     MPI_Barrier(m_mpi_comm);