diff --git a/ep/bench/test_internode.py b/ep/bench/test_internode.py index 3ea106ce..222604ef 100644 --- a/ep/bench/test_internode.py +++ b/ep/bench/test_internode.py @@ -502,7 +502,7 @@ def test_loop( assert num_local_ranks == 8 and num_ranks > 8 - for seed in range(int(1e9)): + for seed in range(650, int(1e9)): if local_rank == 0: print(f"Testing with seed {seed} ...", flush=True) torch.manual_seed(rank + seed) diff --git a/ep/bench/utils.py b/ep/bench/utils.py index 732c2d15..26c49321 100644 --- a/ep/bench/utils.py +++ b/ep/bench/utils.py @@ -584,6 +584,18 @@ def initialize_uccl( ep.register_proxies(local_rank, proxies) + # Set atomic buffer pointer for all proxies BEFORE starting them + # This ensures the atomic buffer info is included in connection info exchange + # Note: Only thread 0's proxy allocates the atomic buffer in its constructor + if not is_intranode and len(proxies) > 0: + # Get atomic buffer pointer from thread 0 proxy (only thread 0 allocates it) + # This must be done before start_dual() so the atomic buffer info is included + # in the connection info exchange during init_common() + atomic_buffer_ptr = proxies[0].get_atomic_buffer_ptr() + if atomic_buffer_ptr: + for proxy in proxies: + proxy.set_atomic_buffer_ptr(atomic_buffer_ptr) + dist.barrier(group) if not is_intranode: for proxy in proxies: diff --git a/ep/include/barrier_local.hpp b/ep/include/barrier_local.hpp index 2fb409b7..fb0256fb 100644 --- a/ep/include/barrier_local.hpp +++ b/ep/include/barrier_local.hpp @@ -10,6 +10,6 @@ struct LocalBarrier { std::atomic arrive_seq[UCCL_MAX_LOCAL_RANKS]; std::atomic release_seq[UCCL_MAX_LOCAL_RANKS]; std::atomic seq; - std::atomic full_mask; // unchanged; still used for size/info - std::atomic arrived_mask; // optional: keep only for debug prints + std::atomic full_mask; + std::atomic arrived_mask; }; \ No newline at end of file diff --git a/ep/include/common.hpp b/ep/include/common.hpp index 25c6ee08..18533612 100644 --- a/ep/include/common.hpp +++ b/ep/include/common.hpp @@ -12,6 +12,7 @@ #include #include +// #define SOFTWARE_ORDERING #define MAX_IB_DEVS 32 // #define MEASURE_PER_OP_LATENCY // #define MEASURE_PER_VERB_LATENCY @@ -61,6 +62,20 @@ extern bool use_ll_sl; #define kBarrierWrTag 0xbaba000000000000ULL #define kBarrierMask 0x0000FFFFFFFFFFFFULL #define kPrintCycleInterval 100000000000ULL +#define kRingIdxBits 10 +#define kRingIdxMask ((1ULL << kRingIdxBits) - 1ULL) +#define kRingIdxShift kRingIdxBits + +// WR ID helpers for ring-indexed work requests: +// - lower kRingIdxBits bits: ring index +// - upper bits: sequence / command index +inline uint64_t make_ring_wr_id(uint64_t seq, uint64_t ring_idx) { + return (seq << kRingIdxBits) | (ring_idx & kRingIdxMask); +} +inline uint64_t ring_wr_seq(uint64_t wrid) { return wrid >> kRingIdxBits; } +inline uint32_t ring_wr_idx(uint64_t wrid) { + return static_cast(wrid & kRingIdxMask); +} #define MAX_RETRIES 100 #define RETRY_DELAY_MS 50 #define QKEY 0x11111111u diff --git a/ep/include/ep_utils.cuh b/ep/include/ep_utils.cuh index 0c0e548b..a9829365 100644 --- a/ep/include/ep_utils.cuh +++ b/ep/include/ep_utils.cuh @@ -821,7 +821,7 @@ __forceinline__ __device__ int atomic_exch_cta_release(int* addr, int x) { return ret; } -template +template __forceinline__ __device__ void barrier_block(int** barrier_signal_ptrs, int rank) { auto thread_id = static_cast(threadIdx.x); @@ -849,10 +849,17 @@ __forceinline__ __device__ void barrier_block(int** barrier_signal_ptrs, if (__all_sync(WARP_MASK, value <= 0)) break; if (clock64() - 
start_time > NUM_TIMEOUT_CYCLES and thread_id < kNumRanks) { - printf( - "DeepEP timeout check failed: rank = %d, thread = %d, value = " - "%d)\n", - rank, thread_id, value); + if (label == 0) { + printf( + "DeepEP timeout check failed: rank = %d, thread = %d, value = " + "%d)\n", + rank, thread_id, value); + } else { + printf( + "DeepEP timeout check failed: rank = %d, thread = %d, value = %d, " + "label = %d)\n", + rank, thread_id, value, label); + } trap(); } } @@ -919,8 +926,7 @@ __device__ __forceinline__ int ld_acquire_cta(int const* ptr) { __forceinline__ __device__ void acquire_lock(int* mutex) { // To make later memory operations valid, we must use `acquire` for memory // semantics - while (atomic_cas_cta_acquire(mutex, 0, 1) != 0) - ; + while (atomic_cas_cta_acquire(mutex, 0, 1) != 0); } __forceinline__ __device__ void release_lock(int* mutex) { diff --git a/ep/include/proxy.hpp b/ep/include/proxy.hpp index 8e74aaf7..e14ef213 100644 --- a/ep/include/proxy.hpp +++ b/ep/include/proxy.hpp @@ -94,7 +94,7 @@ class Proxy { void init_remote(); void notify_gpu_completion(uint64_t& my_tail); - void post_gpu_command(uint64_t& my_tail, size_t& seen); + void post_gpu_command(uint64_t& my_tail, uint64_t& seen); void post_gpu_commands_mixed(std::vector const& wrs_to_post, std::vector const& cmds_to_post); void post_barrier_msg(int dst_rank, bool ack, uint64_t seq); diff --git a/ep/include/proxy_ctx.hpp b/ep/include/proxy_ctx.hpp index f87f70ce..0b5595e5 100644 --- a/ep/include/proxy_ctx.hpp +++ b/ep/include/proxy_ctx.hpp @@ -59,6 +59,12 @@ struct ProxyCtx { uint32_t remote_rkey = 0; uint32_t rkey = 0; + // Atomic buffer (separate from main RDMA buffer) + ibv_mr* atomic_buffer_mr = nullptr; // MR for local atomic_buffer_ptr + uintptr_t remote_atomic_buffer_addr = 0; // Remote atomic_buffer_ptr address + uint64_t remote_atomic_buffer_len = 0; // Remote atomic_buffer_ptr length + uint32_t remote_atomic_buffer_rkey = 0; // Remote atomic_buffer_ptr rkey + // Buffer offset within rdma_buffer for address translation uintptr_t dispatch_recv_data_offset = 0; // offset of dispatch_rdma_recv_data_buffer from rdma_buffer base @@ -102,7 +108,8 @@ struct ProxyCtx { // Async-barrier state (single inflight assumed) bool barrier_inflight = false; - uint64_t barrier_seq = 0; + // For debuigging only. 
+ uint64_t barrier_seq = 1; int barrier_wr = -1; bool quiet_inflight = false; @@ -124,7 +131,7 @@ struct ProxyCtx { int local_rank = -1; // convenience mirror of cfg_.local_rank int thread_idx = -1; // thread index used in shm name - std::unordered_map next_seq_per_index; + std::unordered_map next_seq_per_index; inline uint64_t seq_key(int dst_rank, size_t index) { // assumes dst_rank fits 32 bits; if index > 32 bits, prefer Pair Hash below return (static_cast(static_cast(dst_rank)) << 32) ^ diff --git a/ep/include/rdma.hpp b/ep/include/rdma.hpp index 76ddc062..6f229d2a 100644 --- a/ep/include/rdma.hpp +++ b/ep/include/rdma.hpp @@ -26,6 +26,11 @@ struct RDMAConnectionInfo { uint16_t lid; // Local ID uint8_t gid[16]; // Global ID for RoCE (optional) + // Atomic buffer info (separate from main GPU buffer) + uint32_t atomic_buffer_rkey = 0; // Atomic buffer memory region key + uintptr_t atomic_buffer_addr = 0; // Atomic buffer address + uint64_t atomic_buffer_len = 0; // Atomic buffer length + // #ifdef EFA uint32_t num_rings; uint32_t data_qp_num[kChannelPerProxy]; @@ -283,13 +288,14 @@ struct BarrierImm { // [28:8]=SEQ (21 bits), [7:0]=SRC_RANK static constexpr uint32_t kCtrlBit = 1u << 30; static constexpr uint32_t kAckBit = 1u << 29; + static constexpr uint32_t kSeqMask = 0x1FFFFFu; static inline bool IsAck(uint32_t imm) { return (imm & kAckBit) != 0u; } static inline uint32_t Pack(bool ack, uint32_t seq, uint8_t src_rank) { return kCtrlBit | (ack ? kAckBit : 0u) | - ((seq & 0x1FFFFFu) << 8) // 21 bits for seq + ((seq & kSeqMask) << 8) // 21 bits for seq | uint32_t(src_rank); } - static inline uint32_t Seq(uint32_t imm) { return (imm >> 8) & 0x1FFFFFu; } + static inline uint32_t Seq(uint32_t imm) { return (imm >> 8) & kSeqMask; } static inline uint8_t Rank(uint32_t imm) { return imm & 0xFFu; } explicit BarrierImm(uint32_t imm = 0) : value(imm) {} bool GetIsAck() const { return IsAck(value); } @@ -329,7 +335,8 @@ void remote_process_completions( int my_rank, int num_nodes, bool use_normal_mode = false); void create_per_thread_qp(ProxyCtx& S, void* gpu_buffer, size_t size, RDMAConnectionInfo* local_info, int rank, - size_t num_rings, bool use_normal_mode); + size_t num_rings, bool use_normal_mode, + void* atomic_buffer_ptr = nullptr); ibv_cq* create_per_thread_cq(ProxyCtx& S); void remote_poll_completions(ProxyCtx& S, int idx, CopyRingBuffer& g_ring, std::vector& ctx_by_tag, diff --git a/ep/include/uccl_ibgda.cuh b/ep/include/uccl_ibgda.cuh index 41b48d1c..a44eb824 100644 --- a/ep/include/uccl_ibgda.cuh +++ b/ep/include/uccl_ibgda.cuh @@ -320,6 +320,7 @@ __device__ static __forceinline__ void nvshmemi_ibgda_quiet( } } #endif + // All proxy threads of the GPU will post the quiet command. break; } @@ -369,6 +370,8 @@ __forceinline__ __device__ void nvshmem_sync_with_same_gpu_idx( } } #endif + // Only one thread of the GPU will post the barrier command. + // This is because as long as proxy thread reaches the barrier, we are sure that other GPUs managed by the thread have reached the barrier. 
break; } diff --git a/ep/src/internode.cu b/ep/src/internode.cu index 2388bd38..74b6a3ed 100644 --- a/ep/src/internode.cu +++ b/ep/src/internode.cu @@ -135,7 +135,7 @@ __global__ void notify_dispatch( uccl::nvshmem_sync_with_same_gpu_idx(d2h_channel_addrs, num_d2h_channel_addrs, nvl_rank); } - barrier_block(barrier_signal_ptrs, nvl_rank); + barrier_block(barrier_signal_ptrs, nvl_rank); // Send numbers of tokens per rank/expert to RDMA ranks auto rdma_buffer_ptr_int = static_cast(rdma_buffer_ptr); @@ -290,7 +290,7 @@ __global__ void notify_dispatch( nvl_send_num_tokens_per_expert.buffer(nvl_rank)[i] = nvl_reduced_num_tokens_per_expert[thread_id * num_nvl_experts + i]; } - barrier_block(barrier_signal_ptrs, nvl_rank); + barrier_block(barrier_signal_ptrs, nvl_rank); // Reduce the number of tokens per rank/expert EP_DEVICE_ASSERT(num_nvl_experts <= num_threads); @@ -323,7 +323,7 @@ __global__ void notify_dispatch( if (thread_id == WARP_SIZE) uccl::nvshmem_sync_with_same_gpu_idx(d2h_channel_addrs, num_d2h_channel_addrs, nvl_rank); - barrier_block(barrier_signal_ptrs, nvl_rank); + barrier_block(barrier_signal_ptrs, nvl_rank); } else { // Calculate meta data int dst_rdma_rank = sm_id - 1; @@ -1025,10 +1025,21 @@ __global__ void __launch_bounds__( num_bytes_per_msg, translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, // NOTE(MaoZiming): use channel_id for rb. - lane_id, 0, d2h_channel_addrs, num_d2h_channel_addrs, false, -1, + lane_id, 0, d2h_channel_addrs, num_d2h_channel_addrs, false, + // NOTE(MaoZiming): for AMD GPUs, we directly send a subsequent RDMA + // to update the tail. For other GPUs and EFA NICs, we use the + // CPU-emulated atomics, allow us to piggyback the atomic operation + // with the RDMA send. +#if (defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)) && !defined(SOFTWARE_ORDERING) + + -1, +#else + -1, reinterpret_cast(rdma_channel_tail.buffer(rdma_rank)) - reinterpret_cast(original_atomic_buffer_ptr), - num_tokens_to_issue); + num_tokens_to_issue +#endif + ); } else { // Lighter fence for local RDMA rank memory_fence(); @@ -1046,7 +1057,13 @@ __global__ void __launch_bounds__( translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, // NOTE(MaoZiming): use channel_id for rb. 
dst_rdma_rank == rdma_rank, d2h_channel_addrs, - num_d2h_channel_addrs, false, -1, true); + num_d2h_channel_addrs, false, -1, +#if (defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)) && !defined(SOFTWARE_ORDERING) + false +#else + true +#endif + ); } __syncwarp(); } @@ -1152,7 +1169,7 @@ __global__ void __launch_bounds__( if (__shfl_sync(WARP_MASK, num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) { if (lane_id == src_rdma_rank) -#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) +#if (defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)) && !defined(SOFTWARE_ORDERING) cached_rdma_channel_tail = static_cast(__atomic_load_n( rdma_channel_tail.buffer(src_rdma_rank), __ATOMIC_SEQ_CST)); #else @@ -1178,7 +1195,7 @@ __global__ void __launch_bounds__( trap(); } } -#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) +#if (defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)) && defined(SOFTWARE_ORDERING) memory_fence(); #endif auto src_rdma_head = @@ -1194,7 +1211,15 @@ __global__ void __launch_bounds__( int seen_bits = ld_nc_global(reinterpret_cast( shifted + hidden_bytes + scale_bytes)) .is_token_in_nvl_rank_bits; - if (seen_bits == 0) trap(); + if (seen_bits == 0) { + printf("DeepEP dispatch forwarder timeout (RDMA check), channel: %d, " + "RDMA: %d, nvl: %d, dst NVL: %d, src RDMA lane: %d, head: %d, " + "tail: %d, expected: %d\n", + channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, + cached_rdma_channel_head, cached_rdma_channel_tail, + num_tokens_to_recv_from_rdma); + trap(); + } lane_id == src_rdma_rank ? (num_tokens_to_recv_from_rdma -= 1) : 0; bool is_in_dst_nvl_rank = (seen_bits >> dst_nvl_rank) & 1; if (lane_id == src_rdma_rank) { @@ -1249,13 +1274,8 @@ __global__ void __launch_bounds__( // Move tail index __syncwarp(); if (lane_id == 0) -#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) - __atomic_store_n(nvl_channel_tail.buffer(), cached_nvl_channel_tail, - __ATOMIC_RELEASE); -#else st_release_sys_global(nvl_channel_tail.buffer(), cached_nvl_channel_tail); -#endif } // Retired __syncwarp(); @@ -1589,7 +1609,7 @@ __global__ void cached_notify( num_d2h_channel_addrs, nvl_rank, 3); // Barrier for NVL - barrier_block(barrier_signal_ptrs, nvl_rank); + barrier_block(barrier_signal_ptrs, nvl_rank); // Clean RDMA buffer auto rdma_buffer_ptr_int = static_cast(rdma_buffer_ptr); @@ -1626,7 +1646,7 @@ __global__ void cached_notify( uccl::nvshmem_sync_with_same_gpu_idx(d2h_channel_addrs, num_d2h_channel_addrs, nvl_rank); - barrier_block(barrier_signal_ptrs, nvl_rank); + barrier_block(barrier_signal_ptrs, nvl_rank); } else if (sm_id == 1) { if (is_cached_dispatch) return; @@ -2611,10 +2631,14 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * WARP_SIZE, 1) nvl_rank), channel_id, // NOTE(MaoZiming): use channel_id for rb. lane_id, 0, d2h_channel_addrs, num_d2h_channel_addrs, false, -1, +#if (defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)) && !defined(SOFTWARE_ORDERING) +#else reinterpret_cast( rdma_channel_tail.buffer(rdma_rank)) - reinterpret_cast(original_atomic_buffer_ptr), - num_chunked_tokens); + num_chunked_tokens +#endif + ); } else { memory_fence(); } @@ -2630,7 +2654,13 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * WARP_SIZE, 1) nvl_rank), channel_id, // NOTE(MaoZiming): use warp_id for rb. 
dst_rdma_rank == rdma_rank, d2h_channel_addrs, - num_d2h_channel_addrs, false, -1, true); + num_d2h_channel_addrs, false, -1, + #if (defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)) && !defined(SOFTWARE_ORDERING) + false +#else + true +#endif + ); } } } diff --git a/ep/src/proxy.cpp b/ep/src/proxy.cpp index 7248f030..5a2c6128 100644 --- a/ep/src/proxy.cpp +++ b/ep/src/proxy.cpp @@ -45,7 +45,7 @@ LocalBarrier* map_local_barrier_shm(std::string const& name, bool* out_owner) { perror("shm_open(existing)"); return nullptr; } - struct stat st {}; + struct stat st{}; int tries = 1000; while (tries-- > 0) { if (fstat(fd, &st) == 0 && static_cast(st.st_size) >= kSize) @@ -176,6 +176,32 @@ void Proxy::init_common() { cfg_.thread_idx, cfg_.local_rank); pin_thread_to_numa_wrapper(); if (!ctx_.cq) ctx_.cq = create_per_thread_cq(ctx_); + + // Register atomic_buffer_ptr as a separate RDMA memory region if it was set + // This must be done after PD is initialized by per_thread_rdma_init + if (atomic_buffer_ptr_ && !ctx_.atomic_buffer_mr) { + ctx_.atomic_buffer_mr = + ibv_reg_mr(ctx_.pd, atomic_buffer_ptr_, kAtomicBufferSize, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | +#ifdef EFA + IBV_ACCESS_REMOTE_READ +#else + IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC +#endif + ); + + if (!ctx_.atomic_buffer_mr) { + perror("Failed to register atomic_buffer_ptr MR"); + std::abort(); + } + + fprintf(stderr, + "[Proxy] Registered atomic_buffer_ptr MR: addr=0x%llx, len=%zu, " + "rkey=0x%x\n", + (unsigned long long)ctx_.atomic_buffer_mr->addr, + (size_t)ctx_.atomic_buffer_mr->length, ctx_.atomic_buffer_mr->rkey); + } + if (ctxs_for_all_ranks_.empty()) { fprintf(stderr, "Error: peers metadata not set before init_common (peers_.size() " @@ -195,6 +221,14 @@ void Proxy::init_common() { ctx_.atomic_old_values_buf = reinterpret_cast(static_cast(cfg_.gpu_buffer) + cfg_.total_size - atomic_buf_size); + // Check alignment - only abort if not aligned (8-byte alignment required for + // 64-bit atomics) + if ((reinterpret_cast(ctx_.atomic_old_values_buf) & 0x7) != 0) { + fprintf(stderr, "Atomic buffer not 8-byte aligned: 0x%llx\n", + (unsigned long long)reinterpret_cast( + ctx_.atomic_old_values_buf)); + std::abort(); + } int num_ranks = ctxs_for_all_ranks_.size(); local_infos_.assign(num_ranks, RDMAConnectionInfo{}); @@ -218,6 +252,8 @@ void Proxy::init_common() { // NOTE(MaoZiming): each context can share the same cq, pd, mr. // but the qp must be different. c.cq = ctx_.cq; + // Share the atomic buffer MR with peer contexts + c.atomic_buffer_mr = ctx_.atomic_buffer_mr; if (peer == my_rank) continue; // Skip rdma connection for intra-node. 
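// A minimal sketch (illustrative only; `AtomicBufferInfo` and
// `register_atomic_buffer` are hypothetical helper names, not structures used
// in this patch) of the pattern init_common() follows above: register the
// atomic buffer as its own MR, then export (addr, len, rkey) so peers can
// target it through the connection-info exchange. The access flags are
// conditional because the EFA registration above omits
// IBV_ACCESS_REMOTE_ATOMIC.
#include <infiniband/verbs.h>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

struct AtomicBufferInfo {  // mirrors the new RDMAConnectionInfo fields
  uint32_t rkey = 0;
  uintptr_t addr = 0;
  uint64_t len = 0;
};

inline AtomicBufferInfo register_atomic_buffer(ibv_pd* pd, void* atomic_buf,
                                               size_t bytes, ibv_mr** out_mr) {
  int access = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE |
               IBV_ACCESS_REMOTE_READ;
#ifndef EFA
  access |= IBV_ACCESS_REMOTE_ATOMIC;  // needed for IBV_WR_ATOMIC_FETCH_AND_ADD
#endif
  ibv_mr* mr = ibv_reg_mr(pd, atomic_buf, bytes, access);
  if (!mr) {
    std::perror("ibv_reg_mr(atomic buffer)");
    std::abort();
  }
  *out_mr = mr;
  AtomicBufferInfo info;
  info.rkey = mr->rkey;
  info.addr = reinterpret_cast<uintptr_t>(mr->addr);
  info.len = mr->length;
  return info;
}
// Each peer stores the received triple as remote_atomic_buffer_{addr,len,rkey}
// and later checks that every atomic target lies inside [addr, addr + len).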
@@ -226,7 +262,7 @@ void Proxy::init_common() { continue; create_per_thread_qp(c, cfg_.gpu_buffer, cfg_.total_size, &local_infos_[peer], my_rank, cfg_.d2h_queues.size(), - cfg_.use_normal_mode); + cfg_.use_normal_mode, atomic_buffer_ptr_); modify_qp_to_init(c); } @@ -291,12 +327,24 @@ void Proxy::init_common() { c.remote_addr = remote_infos_[peer].addr; c.remote_rkey = remote_infos_[peer].rkey; c.remote_len = remote_infos_[peer].len; - if (FILE* f = fopen("/tmp/uccl_debug.txt", "a")) { + + // Set remote atomic buffer info from exchanged connection info + c.remote_atomic_buffer_addr = remote_infos_[peer].atomic_buffer_addr; + c.remote_atomic_buffer_len = remote_infos_[peer].atomic_buffer_len; + c.remote_atomic_buffer_rkey = remote_infos_[peer].atomic_buffer_rkey; + + if (c.remote_atomic_buffer_addr == 0) { fprintf( - f, - "[PROXY_INIT] me=%d peer=%d: remote_addr=0x%lx local_buffer=0x%lx\n", - my_rank, peer, c.remote_addr, (uintptr_t)cfg_.gpu_buffer); - fclose(f); + stderr, + "[Proxy] WARNING: Remote atomic buffer not registered for peer %d " + "(local atomic_buffer_ptr=%p, local atomic_buffer_mr=%p)\n", + peer, atomic_buffer_ptr_, (void*)ctx_.atomic_buffer_mr); + } else { + fprintf(stderr, + "[Proxy] Remote atomic buffer info for peer %d: addr=0x%llx, " + "len=%zu, rkey=0x%x\n", + peer, (unsigned long long)c.remote_atomic_buffer_addr, + (size_t)c.remote_atomic_buffer_len, c.remote_atomic_buffer_rkey); } } usleep(50 * 1000); @@ -375,7 +423,7 @@ void Proxy::init_remote() { void Proxy::run_sender() { printf("CPU sender thread %d started\n", cfg_.thread_idx); init_sender(); - size_t seen = 0; + uint64_t seen = 0; uint64_t my_tail = 0; while (ctx_.progress_run.load(std::memory_order_acquire)) { local_poll_completions(ctx_, acked_wrs_, cfg_.thread_idx, ctx_by_tag_); @@ -419,7 +467,7 @@ void Proxy::run_dual() { #endif } uint64_t my_tail = 0; - size_t seen = 0; + uint64_t seen = 0; std::set pending_atomic_updates; while (ctx_.progress_run.load(std::memory_order_acquire)) { poll_cq_dual(ctx_, acked_wrs_, cfg_.thread_idx, ring, ctx_by_tag_, @@ -456,7 +504,7 @@ void Proxy::run_dual() { void Proxy::notify_gpu_completion(uint64_t& my_tail) { if (acked_wrs_.empty()) return; - // Mark all acked command slots in each ring's bitmask + // Mark all acked command slots in each ring's bitmask #ifdef USE_MSCCLPP_FIFO_BACKEND // FIFO path: pop in order using the pending deque and the completion set. 
for (size_t rb_idx = 0; rb_idx < cfg_.d2h_queues.size(); ++rb_idx) { @@ -484,8 +532,8 @@ void Proxy::notify_gpu_completion(uint64_t& my_tail) { } #else for (auto wr_id : acked_wrs_) { - size_t const rb_idx = (wr_id >> 32) & 0xFFFFFFFF; - size_t const cmd_idx = wr_id & 0xFFFFFFFF; + size_t const rb_idx = ring_wr_idx(wr_id); + size_t const cmd_idx = ring_wr_seq(wr_id); if (rb_idx >= cfg_.d2h_queues.size()) { fprintf(stderr, "Invalid rb_idx %zu in acked_wrs_\n", rb_idx); @@ -506,7 +554,7 @@ void Proxy::notify_gpu_completion(uint64_t& my_tail) { #endif } -void Proxy::post_gpu_command(uint64_t& my_tail, size_t& seen) { +void Proxy::post_gpu_command(uint64_t& my_tail, uint64_t& seen) { // Multi-ring buffer processing: collect commands from all ring buffers // Process each ring buffer (similar to test_multi_ring_throughput.cu) for (size_t rb_idx = 0; rb_idx < cfg_.d2h_queues.size(); rb_idx++) { @@ -538,8 +586,8 @@ void Proxy::post_gpu_command(uint64_t& my_tail, size_t& seen) { break; } - uint64_t unique_wr_id = (static_cast(rb_idx) << 32) | - (fifo_seq_[rb_idx]++ & 0xFFFFFFFFULL); + uint64_t seq = fifo_seq_[rb_idx]++; + uint64_t unique_wr_id = make_ring_wr_id(seq, rb_idx); wrs_to_post.push_back(unique_wr_id); cmds_to_post.push_back(cmd); fifo_pending_[rb_idx].push_back(unique_wr_id); @@ -612,9 +660,7 @@ void Proxy::post_gpu_command(uint64_t& my_tail, size_t& seen) { std::abort(); } } - - // Use a unique ID combining ring buffer index and command index - uint64_t unique_wr_id = (rb_idx << 32) | i; + uint64_t unique_wr_id = make_ring_wr_id(i, rb_idx); wrs_to_post.push_back(unique_wr_id); cmds_to_post.push_back(cmd_entry); #ifdef MEASURE_PER_VERB_LATENCY @@ -860,24 +906,27 @@ void Proxy::post_gpu_commands_mixed( atomic_cmds.clear(); } - if (!barrier_cmds.empty()) { -#ifdef USE_MSCCLPP_FIFO_BACKEND - assert(barrier_wrs.size() == 1 && ctx_.barrier_wr == -1); -#endif - send_barrier(barrier_wrs[0]); - barrier_wrs.clear(); - barrier_cmds.clear(); - } - + /* NOTE: quiet before barrier. */ if (!quiet_cmds.empty()) { #ifdef USE_MSCCLPP_FIFO_BACKEND assert(quiet_wrs.size() == 1 && ctx_.quiet_wr == -1); + assert(barrier_wrs.empty() && ctx_.barrier_wr == -1); #endif ctx_.quiet_wr = quiet_wrs[0]; quiet(quiet_wrs, quiet_cmds); quiet_wrs.clear(); quiet_cmds.clear(); } + + if (!barrier_cmds.empty()) { +#ifdef USE_MSCCLPP_FIFO_BACKEND + assert(barrier_wrs.size() == 1 && ctx_.barrier_wr == -1); + assert(quiet_wrs.empty() && ctx_.quiet_wr == -1); +#endif + send_barrier(barrier_wrs[0]); + barrier_wrs.clear(); + barrier_cmds.clear(); + } } void Proxy::quiet_cq() { @@ -917,17 +966,19 @@ void Proxy::quiet_cq() { if (outstanding_batches() == 0 && empty_iters >= kConsecutiveEmptyToExit) { break; } - auto now = clock::now(); - if (now - last_log > std::chrono::milliseconds(1000)) { + // auto now = clock::now(); + if (clock::now() - last_log > std::chrono::milliseconds(1000)) { fprintf(stderr, "[quiet] polling... 
outstanding=%zu\n", outstanding_batches()); - last_log = now; + // last_log = now; } } } void Proxy::quiet(std::vector wrs, std::vector cmds) { assert(cmds.size() == 1 && "quiet size must be 1"); + assert(wrs.size() == 1 && "wrs size must be 1"); + // printf("rank: %d, thread: %d, quiet\n", cfg_.rank, cfg_.thread_idx); quiet_cq(); acked_wrs_.insert(wrs[0]); } @@ -1035,6 +1086,8 @@ void Proxy::post_barrier_msg(int dst_rank, bool ack, uint64_t seq) { fprintf(stderr, "barrier_msg: bad ctx for dst=%d\n", dst_rank); std::abort(); } + // assert seq smaller than 22 bits + assert(seq <= 0x1FFFFFu); uint32_t imm = BarrierImm::Pack(ack, (uint32_t)seq, (uint8_t)cfg_.rank); #ifdef EFA auto* qpx = (struct ibv_qp_ex*)ctx->qp; @@ -1082,13 +1135,14 @@ void Proxy::post_barrier_msg(int dst_rank, bool ack, uint64_t seq) { } void Proxy::send_barrier(uint64_t wr) { + // printf("rank: %d, thread: %d, send_barrier\n", cfg_.rank, cfg_.thread_idx); #ifndef USE_MSCCLPP_FIFO_BACKEND assert(!ctx_.barrier_inflight && "only one barrier at a time"); ctx_.barrier_inflight = true; #endif assert(ctx_.barrier_wr == -1 && "barrier_wr should be 0"); ctx_.barrier_wr = wr; - ctx_.barrier_seq = ctx_.barrier_seq + 1; + ctx_.barrier_seq = (ctx_.barrier_seq + 1) & BarrierImm::kSeqMask; if (cfg_.rank == ctx_.node_leader_rank) { if (ctx_.barrier_arrived.size() != static_cast(cfg_.num_nodes)) { @@ -1166,7 +1220,9 @@ void Proxy::barrier_check() { } // When global release comes back (CQ handler should set these): - if (ctx_.barrier_released && ctx_.barrier_release_seq == seq) { + // NOTE: BarrierImm is 21 bits, so we must mask the local seq. + if (ctx_.barrier_released && + ctx_.barrier_release_seq == seq) { // Reset local mask for next barrier and consume the global release ctx_.barrier_released = false; @@ -1190,8 +1246,8 @@ void Proxy::barrier_check() { if (cfg_.rank == ctx_.node_leader_rank) { bool all_local_arrived = true; for (int lr = 0; lr < ctx_.num_local_ranks; ++lr) { - uint32_t seen = lb->arrive_seq[lr].load(std::memory_order_acquire); - if ((uint32_t)seen != seq) { + uint64_t seen = lb->arrive_seq[lr].load(std::memory_order_acquire); + if (seen != seq) { all_local_arrived = false; break; } diff --git a/ep/src/rdma.cpp b/ep/src/rdma.cpp index 78a66c22..69ffbe2f 100644 --- a/ep/src/rdma.cpp +++ b/ep/src/rdma.cpp @@ -96,6 +96,57 @@ void send_connection_info_as_client(int my_rank, int peer, char const* peer_ip, close(sockfd); } +// Helper function to get the root complex PCI bus identifier from a device path +// Returns something like "pci0000:00" or "pci0000:10" to match the terminal output format +static std::string get_root_complex_id(std::filesystem::path const& dev_path) { + static const std::regex pci_bus_re(R"(pci0000:[0-9a-fA-F]+)", std::regex_constants::icase); + + try { + // Canonicalize the path to get the full sysfs path + // e.g., /sys/devices/pci0000:00/0000:00:02.0/0000:02:00.0/... + std::filesystem::path canonical_path; + try { + canonical_path = std::filesystem::canonical(dev_path); + } catch (...) 
{ + // If canonical fails, try as-is + canonical_path = dev_path; + } + + // Convert to string and search for pci0000:XX pattern + std::string path_str = canonical_path.string(); + std::smatch match; + if (std::regex_search(path_str, match, pci_bus_re)) { + std::string result = match.str(); + // Normalize to lowercase for consistent comparison + std::transform(result.begin(), result.end(), result.begin(), ::tolower); + return result; + } + + // Fallback: walk up the path looking for a component that matches + std::filesystem::path p = canonical_path; + while (p != p.root_path() && p != p.parent_path()) { + std::string component = p.filename().string(); + if (std::regex_match(component, pci_bus_re)) { + // Normalize to lowercase for consistent comparison + std::transform(component.begin(), component.end(), component.begin(), ::tolower); + return component; + } + p = p.parent_path(); + } + } catch (...) { + // If everything fails, try the original path + std::string path_str = dev_path.string(); + std::smatch match; + if (std::regex_search(path_str, match, pci_bus_re)) { + std::string result = match.str(); + std::transform(result.begin(), result.end(), result.begin(), ::tolower); + return result; + } + } + + return ""; // Not found +} + void per_thread_rdma_init(ProxyCtx& S, void* gpu_buf, size_t bytes, int rank, int thread_idx, int local_rank) { if (S.context) return; // already initialized @@ -115,9 +166,19 @@ void per_thread_rdma_init(ProxyCtx& S, void* gpu_buf, size_t bytes, int rank, auto ib_nics = uccl::get_rdma_nics(); // Get GPU pcie path auto gpu_device_path = gpu_cards[gpu_idx]; + // Get GPU root complex PCI bus identifier (e.g., "pci0000:00") + std::string gpu_root_complex = get_root_complex_id(gpu_device_path); + fprintf(stderr, "[DEBUG] GPU %d path: %s, root complex: %s\n", + gpu_idx, gpu_device_path.c_str(), + gpu_root_complex.empty() ? "(empty)" : gpu_root_complex.c_str()); + // Find the RDMA NIC that is closest to the GPU. std::vector> dist; dist.reserve(ib_nics.size()); + + // Separate NICs by root complex: same root complex vs different + std::vector> same_root_complex_dist; + std::vector> different_root_complex_dist; // Conforming to UCCL_IB_HCA filter. char* ib_hca = getenv("UCCL_IB_HCA"); @@ -136,23 +197,49 @@ void per_thread_rdma_init(ProxyCtx& S, void* gpu_buf, size_t bytes, int rank, continue; } uint32_t d = uccl::safe_pcie_distance(gpu_device_path, nic.second); + std::string nic_root_complex = get_root_complex_id(nic.second); + + // Prioritize NICs on the same root complex + if (!gpu_root_complex.empty() && !nic_root_complex.empty() && + gpu_root_complex == nic_root_complex) { + same_root_complex_dist.emplace_back(nic.first, d); + fprintf(stderr, "[DEBUG] NIC %s matches GPU root complex %s (distance %u)\n", + nic.first.c_str(), gpu_root_complex.c_str(), d); + } else { + different_root_complex_dist.emplace_back(nic.first, d); + if (!nic_root_complex.empty()) { + fprintf(stderr, "[DEBUG] NIC %s root complex %s != GPU root complex %s (distance %u)\n", + nic.first.c_str(), nic_root_complex.c_str(), + gpu_root_complex.empty() ? "(empty)" : gpu_root_complex.c_str(), d); + } + } dist.emplace_back(nic.first, d); } - // Find the NIC with the minimum distance. + // Find the NIC with the minimum distance, prioritizing same root complex. 
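// A rough illustration (hypothetical helper names: root_complex_of,
// NicCandidate, pick_nic) of the selection policy implemented here: derive a
// "pci0000:xx" root-complex token from both the GPU's and each NIC's sysfs
// path, then restrict the distance comparison to NICs sharing the GPU's root
// complex whenever that set is non-empty.
#include <algorithm>
#include <cstdint>
#include <filesystem>
#include <regex>
#include <string>
#include <system_error>
#include <vector>

static std::string root_complex_of(std::filesystem::path const& dev_path) {
  static std::regex const re(R"(pci0000:[0-9a-fA-F]+)");
  std::error_code ec;
  std::filesystem::path canon = std::filesystem::canonical(dev_path, ec);
  std::string s = (ec ? dev_path : canon).string();
  std::smatch m;
  if (!std::regex_search(s, m, re)) return "";
  std::string id = m.str();
  std::transform(id.begin(), id.end(), id.begin(), ::tolower);
  return id;
}

struct NicCandidate {
  std::string name;
  uint32_t distance;  // PCIe hop distance to the GPU
  std::string root_complex;
};

static std::string pick_nic(std::string const& gpu_rc,
                            std::vector<NicCandidate> nics) {
  std::vector<NicCandidate> same;
  for (auto const& n : nics)
    if (!gpu_rc.empty() && n.root_complex == gpu_rc) same.push_back(n);
  std::vector<NicCandidate>& pool = same.empty() ? nics : same;
  auto it = std::min_element(
      pool.begin(), pool.end(),
      [](auto const& a, auto const& b) { return a.distance < b.distance; });
  return it == pool.end() ? std::string{} : it->name;
}
// The chosen name remains subject to the UCCL_IB_HCA filter and the EFA
// "rdmap" prefix check applied in the surrounding code.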
if (dist.empty()) { fprintf(stderr, "[WARN] no NIC found, defaulting to empty\n"); selected_nic_name.clear(); } else { - // Find the minimum distance + // Use same root complex NICs if available, otherwise fall back to all NICs + std::vector>* dist_to_use = &dist; + if (!same_root_complex_dist.empty()) { + dist_to_use = &same_root_complex_dist; + fprintf(stderr, "[DEBUG] Using %zu same-root-complex NICs (out of %zu total)\n", + same_root_complex_dist.size(), dist.size()); + } else { + fprintf(stderr, "[DEBUG] No same-root-complex NICs found, using all %zu NICs\n", + dist.size()); + } + auto min_it = std::min_element( - dist.begin(), dist.end(), + dist_to_use->begin(), dist_to_use->end(), [](auto const& a, auto const& b) { return a.second < b.second; }); auto min_d = min_it->second; // Collect all NICs with equal minimum distance std::vector candidates; - for (auto& p : dist) { + for (auto& p : *dist_to_use) { #ifdef EFA if (p.second == min_d && strncmp(p.first.c_str(), "rdmap", 5) == 0) candidates.push_back(p.first); @@ -164,7 +251,7 @@ void per_thread_rdma_init(ProxyCtx& S, void* gpu_buf, size_t bytes, int rank, if (candidates.empty()) { fprintf(stderr, "[WARN] no candidate NIC found, defaulting to first\n"); - selected_nic_name = dist.front().first; + selected_nic_name = dist_to_use->front().first; } else { // Spread GPUs across equal-distance NICs: use local GPU index modulo // For example, pass in `local_rank` or derive gpu_index from device path @@ -226,15 +313,14 @@ void per_thread_rdma_init(ProxyCtx& S, void* gpu_buf, size_t bytes, int rank, exit(1); } uint64_t iova = (uintptr_t)gpu_buf; -#ifndef EFA +#if !defined(EFA) && !defined(SOFTWARE_ORDERING) S.mr = ibv_reg_mr_iova2(S.pd, gpu_buf, bytes, iova, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | - IBV_ACCESS_REMOTE_ATOMIC | - IBV_ACCESS_RELAXED_ORDERING); + IBV_ACCESS_REMOTE_ATOMIC); #else S.mr = ibv_reg_mr_iova2(S.pd, gpu_buf, bytes, iova, - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | - IBV_ACCESS_RELAXED_ORDERING); + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | + IBV_ACCESS_RELAXED_ORDERING); #endif if (!S.mr) { @@ -352,7 +438,8 @@ struct ibv_qp* create_srd_qp_ex(ProxyCtx& S) { void create_per_thread_qp(ProxyCtx& S, void* gpu_buffer, size_t size, RDMAConnectionInfo* local_info, int rank, - size_t num_rings, bool use_normal_mode) { + size_t num_rings, bool use_normal_mode, + void* atomic_buffer_ptr) { if (S.qp) return; // Already initialized for this thread if (S.ack_qp) return; if (S.recv_ack_qp) return; @@ -421,6 +508,7 @@ void create_per_thread_qp(ProxyCtx& S, void* gpu_buffer, size_t size, exit(1); } ncclIbGetGidIndex(S.context, 1, &port_attr, &S.gid_index); + S.gid_index = 3; local_info->qp_num = S.qp->qp_num; local_info->ack_qp_num = S.ack_qp->qp_num; local_info->recv_ack_qp_num = S.recv_ack_qp->qp_num; @@ -430,6 +518,30 @@ void create_per_thread_qp(ProxyCtx& S, void* gpu_buffer, size_t size, local_info->len = size; local_info->psn = 0; local_info->ack_psn = 0; + + // Populate atomic buffer info if available + // Use S.atomic_buffer_mr if it exists (even if atomic_buffer_ptr is nullptr + // for this thread) This ensures all threads exchange the same atomic buffer + // info + if (S.atomic_buffer_mr) { +#ifdef EFA + assert(false && "This path should not happen for EFA"); +#endif + local_info->atomic_buffer_rkey = S.atomic_buffer_mr->rkey; + local_info->atomic_buffer_addr = + reinterpret_cast(S.atomic_buffer_mr->addr); + local_info->atomic_buffer_len = S.atomic_buffer_mr->length; + fprintf(stderr, + 
"[create_per_thread_qp] Populated atomic buffer info: addr=0x%llx, " + "len=%zu, rkey=0x%x\n", + (unsigned long long)local_info->atomic_buffer_addr, + (size_t)local_info->atomic_buffer_len, + local_info->atomic_buffer_rkey); + } else { + // TODO(MaoZiming): Only for non-EFA case. + assert(false && "Atomic buffer is not registered"); + } + fill_local_gid(S, local_info); } @@ -486,7 +598,7 @@ struct ibv_ah* create_ah(ProxyCtx& S, uint8_t* remote_gid) { struct ibv_ah_attr ah_attr = {}; ah_attr.is_global = 1; // Enable Global Routing Header (GRH) ah_attr.port_num = 1; - ah_attr.grh.sgid_index = 0; // Local GID index + ah_attr.grh.sgid_index = 3; // Local GID index memcpy(&ah_attr.grh.dgid, remote_gid, 16); ah_attr.grh.flow_label = 0; ah_attr.grh.hop_limit = 255; @@ -547,13 +659,13 @@ void modify_qp_to_rtr(ProxyCtx& S, RDMAConnectionInfo* remote, attr.path_mtu = port_attr.active_mtu; attr.dest_qp_num = remote->qp_num; attr.rq_psn = remote->psn; - attr.max_dest_rd_atomic = 1; + attr.max_dest_rd_atomic = 32; attr.min_rnr_timer = 12; if (is_roce) { attr.ah_attr.is_global = 1; attr.ah_attr.port_num = 1; - attr.ah_attr.sl = 135; + attr.ah_attr.sl = 0; attr.ah_attr.src_path_bits = 0; attr.ah_attr.grh.traffic_class = 3; attr.ah_attr.grh.hop_limit = 64; @@ -638,12 +750,12 @@ void modify_qp_to_rts(ProxyCtx& S, RDMAConnectionInfo* local_info) { attr.retry_cnt = 7; attr.rnr_retry = 7; attr.sq_psn = local_info->psn; - attr.max_rd_atomic = 1; - attr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_ATOMIC; + attr.max_rd_atomic = 32; + // attr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_ATOMIC; int flags = IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC | - IBV_QP_ACCESS_FLAGS; + 0; if (ibv_modify_qp(S.qp, &attr, flags)) { perror("Failed to modify QP to RTS"); @@ -748,8 +860,7 @@ static void post_rdma_async_batched_normal_mode( ring_to_indices.reserve(k); for (size_t j = 0; j < k; ++j) { size_t i = wr_ids[j]; - size_t ring_idx = - static_cast((wrs_to_post[i] >> 32) & 0xFFFFFFFFu); + size_t ring_idx = static_cast(ring_wr_idx(wrs_to_post[i])); ring_to_indices[ring_idx].push_back(i); } @@ -849,6 +960,126 @@ static void post_rdma_async_batched_normal_mode( dst_rank, strerror(ret), ret); std::abort(); } +#elif defined(SOFTWARE_ORDERING) + { + size_t const local_ring_count = ctx->data_qps_by_channel.size(); + struct ibv_qp* qp = + local_ring_count + ? ctx->data_qps_by_channel[ring_idx_raw % local_ring_count] + : ctx->ack_qp; + + size_t const kgroup = idxs.size(); + std::vector sges(kgroup); + std::vector wrs(kgroup); + std::vector ring_wrids; + ring_wrids.reserve(kgroup); + + for (size_t j = 0; j < kgroup; ++j) { + size_t i = idxs[j]; + auto const& cmd = cmds_to_post[i]; + ring_wrids.push_back(wrs_to_post[i]); + + // Remote address bounds check + uint64_t remote_addr = + ctx->remote_addr + (cmd.req_rptr ? 
cmd.req_rptr : 0); + uint64_t remote_end = ctx->remote_addr + ctx->remote_len; + + if (remote_addr < ctx->remote_addr || + remote_addr + cmd.bytes > remote_end) { + fprintf( + stderr, + "[ERROR] Remote write OOB: addr=0x%llx len=%u (base=0x%llx, " + "size=%zu), cmd.req_rptr: 0x%llx\n", + (unsigned long long)remote_addr, cmd.bytes, + (unsigned long long)ctx->remote_addr, (size_t)ctx->remote_len, + (unsigned long long)cmd.req_rptr); + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + fprintf(stderr, "cudaDeviceSynchronize failed: %s\n", + cudaGetErrorString(err)); + } + std::abort(); + } + + // Local SGE + uintptr_t laddr = + cmd.req_lptr + reinterpret_cast(ctx->mr->addr); + sges[j] = { + .addr = laddr, + .length = static_cast(cmd.bytes), + .lkey = ctx->mr->lkey, + }; + + // Build WR + std::memset(&wrs[j], 0, sizeof(wrs[j])); + wrs[j].wr_id = wrs_to_post[i]; + wrs[j].sg_list = &sges[j]; + wrs[j].num_sge = 1; + wrs[j].wr.rdma.remote_addr = remote_addr; + wrs[j].wr.rdma.rkey = ctx->remote_rkey; + wrs[j].opcode = IBV_WR_RDMA_WRITE; // default + wrs[j].send_flags = (j + 1 == kgroup) ? IBV_SEND_SIGNALED : 0; + wrs[j].next = (j + 1 < kgroup) ? &wrs[j + 1] : nullptr; + + if (cmd.atomic_offset > 0 && cmd.atomic_val > 0) { + int v = static_cast(cmd.atomic_val); + if (v < -kMaxSendAtomicValue || v > kMaxSendAtomicValue) { + fprintf(stderr, "atomic value=%d won't fit in 15 bits\n", v); + std::abort(); + } + size_t index = static_cast(cmd.atomic_offset / sizeof(int)); + // Initialize missing entries lazily + auto key = ctx->seq_key(dst_rank, index); + if (ctx->next_seq_per_index.find(key) == + ctx->next_seq_per_index.end()) + ctx->next_seq_per_index[key] = 0; + + uint8_t seq = ctx->next_seq_per_index[key]; + ctx->next_seq_per_index[key] = + (seq + 1) % kReorderingBufferSize; // 4-bit wrap (0–15) + uint32_t imm = + AtomicsImm::PackAtomicWithSeq(v, cmd.atomic_offset, seq, true) + .GetImmData(); + AtomicsImm aimm(imm); + assert(aimm.GetSeq() == seq); + + wrs[j].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; + wrs[j].imm_data = htonl(imm); + + assert(aimm.GetValue() == cmd.atomic_val); + assert(aimm.GetOff() == cmd.atomic_offset); + } else { + wrs[j].opcode = IBV_WR_RDMA_WRITE; + } + } + + // Post the chain + ibv_send_wr* bad = nullptr; + int ret = ibv_post_send(qp, &wrs[0], &bad); + if (ret) { + fprintf(stderr, "ibv_post_send failed (dst=%d): %s (ret=%d)\n", + dst_rank, strerror(ret), ret); + if (bad) + fprintf(stderr, "Bad WR at %p (wr_id=%lu)\n", (void*)bad, + bad->wr_id); + std::abort(); + } + + // Track wr_id mappings for SOFTWARE_ORDERING + size_t const last = kgroup - 1; + uint64_t const batch_tail_wr = ring_wrids[last]; + { + auto [it, inserted] = S.wr_id_to_wr_ids.try_emplace( + batch_tail_wr, std::move(ring_wrids)); + if (!inserted) { + fprintf(stderr, + "thread_idx: %d, Error: tail wr_id %lu already exists " + "(map=%p)\n", + thread_idx, batch_tail_wr, (void*)&S.wr_id_to_wr_ids); + std::abort(); + } + } + } #else { size_t const local_ring_count = ctx->data_qps_by_channel.size(); @@ -1284,9 +1515,18 @@ void local_process_completions(ProxyCtx& S, break; case IBV_WC_FETCH_ADD: { uint64_t wrid = wc[i].wr_id; - printf("Local thread %d: atomic completed (wr_id=0x%lx)\n", thread_idx, - wrid); - assert(false && "Atomic not expected on local proxy"); + auto it = S.wr_id_to_wr_ids.find(wrid); + if (it != S.wr_id_to_wr_ids.end()) { + for (uint64_t sub_wr : it->second) { + acked_wrs.insert(sub_wr); + } + S.wr_id_to_wr_ids.erase(it); + } else { + fprintf(stderr, + "[Atomic] No batch found for 
wr_id=0x%lx, treating as single " + "(map_size=%zu)\n", + wrid, S.wr_id_to_wr_ids.size()); + } } break; default: break; @@ -1423,7 +1663,6 @@ void remote_process_completions_normal_mode( int value = aimm.GetValue(); uint32_t offset = aimm.GetOff(); size_t index = offset / sizeof(int); - auto* addr32 = reinterpret_cast*>(atomic_buffer_ptr) + index; @@ -1432,6 +1671,10 @@ void remote_process_completions_normal_mode( if (!aimm.IsReorderable()) { addr32->fetch_add(value, std::memory_order_release); } else { +#if !defined(EFA) && !defined(SOFTWARE_ORDERING) + assert(false && + "Reorderable atomic operations should not be triggered"); +#endif struct SeqBuf { uint8_t expected = 0; // next seq expected uint16_t present_mask = 0; // bitmask of buffered seqs @@ -1972,13 +2215,11 @@ static void post_atomic_operations_normal_mode( ProxyCtx* ctx = ctxs[dst_rank].get(); size_t const k = wr_ids.size(); - // Group by ring index (upper 32 bits in wrs_to_post) std::unordered_map> ring_to_indices; ring_to_indices.reserve(k); for (size_t ii = 0; ii < k; ++ii) { size_t global_i = wr_ids[ii]; - size_t ring_idx = - static_cast((wrs_to_post[global_i] >> 32) & 0xFFFFFFFFu); + size_t ring_idx = static_cast(ring_wr_idx(wrs_to_post[global_i])); ring_to_indices[ring_idx].push_back(global_i); } @@ -2065,7 +2306,7 @@ static void post_atomic_operations_normal_mode( group_wrids.push_back(wr_id); int v = static_cast(cmd.value); - if (v > kLargeAtomicValue) v = kMaxSendAtomicValue; // saturate for imm + if (v > kLargeAtomicValue) v = kMaxSendAtomicValue; if (v < -kMaxSendAtomicValue || v > kMaxSendAtomicValue) { fprintf(stderr, "value=%d (cmd.value=%lu) won't fit in 15 bits for imm; " @@ -2277,6 +2518,215 @@ static void post_atomic_operations_fast_mode( } } +// Native RDMA implementation (non-EFA) +static void post_atomic_operations_native_rdma( + ProxyCtx& S, std::vector const& wrs_to_post, + std::vector const& cmds_to_post, + std::vector>& ctxs, int my_rank, int thread_idx, + std::unordered_set& acked_wrs) { + if (cmds_to_post.size() > ProxyCtx::kMaxAtomicOps) { + fprintf(stderr, "Too many atomic operations: %zu > %zu\n", + cmds_to_post.size(), ProxyCtx::kMaxAtomicOps); + std::abort(); + } + + std::unordered_map> dst_rank_wr_ids; + dst_rank_wr_ids.reserve(cmds_to_post.size()); + for (size_t i = 0; i < wrs_to_post.size(); ++i) { + int dst = static_cast(cmds_to_post[i].dst_rank); + if (dst == my_rank) { + fprintf(stderr, "Posting atomic to itself\n"); + std::abort(); + } + dst_rank_wr_ids[dst].push_back(i); + } + + for (auto& [dst_rank, wr_ids] : dst_rank_wr_ids) { + if (wr_ids.empty()) continue; + + ProxyCtx* ctx = ctxs[dst_rank].get(); + size_t const k = wr_ids.size(); + + std::unordered_map> ring_to_indices; + ring_to_indices.reserve(k); + for (size_t ii = 0; ii < k; ++ii) { + size_t global_i = wr_ids[ii]; + size_t ring_idx = static_cast(ring_wr_idx(wrs_to_post[global_i])); + ring_to_indices[ring_idx].push_back(global_i); + } + + for (auto& [ring_idx_raw, idxs] : ring_to_indices) { + size_t const local_ring_count = ctx->data_qps_by_channel.size(); + struct ibv_qp* qp = + local_ring_count + ? 
ctx->data_qps_by_channel[ring_idx_raw % local_ring_count] + : ctx->ack_qp; + + size_t const k = idxs.size(); + std::vector sge(k); + std::vector wr(k); + std::vector group_wrids; + group_wrids.reserve(k); + + // Verify atomic buffer is properly aligned (8-byte for 64-bit atomics) + uintptr_t atomic_buf_addr = + reinterpret_cast(S.atomic_old_values_buf); + if ((atomic_buf_addr & 0x7) != 0) { + fprintf( + stderr, + "[Native RDMA] atomic_old_values_buf not 8-byte aligned: 0x%llx\n", + (unsigned long long)atomic_buf_addr); + std::abort(); + } + + // Verify atomic buffer is within registered memory region + uintptr_t mr_addr = reinterpret_cast(S.mr->addr); + uintptr_t mr_end = mr_addr + S.mr->length; + uintptr_t atomic_buf_end = atomic_buf_addr + k * sizeof(uint64_t); + if (atomic_buf_addr < mr_addr || atomic_buf_end > mr_end) { + fprintf(stderr, + "[Native RDMA] atomic buffer out of bounds: buf=0x%llx-0x%llx, " + "mr=0x%llx-0x%llx\n", + (unsigned long long)atomic_buf_addr, + (unsigned long long)atomic_buf_end, (unsigned long long)mr_addr, + (unsigned long long)mr_end); + std::abort(); + } + + uint64_t* atomic_old_values_64 = + reinterpret_cast(S.atomic_old_values_buf); + + for (size_t t = 0; t < k; ++t) { + size_t i = idxs[t]; + auto const& cmd = cmds_to_post[i]; + uint64_t const wr_id = wrs_to_post[i]; + group_wrids.push_back(wr_id); + + int v = static_cast(cmd.value); + if (v >= kLargeAtomicValue) { + printf("Large atomic value: %d\n", v); + // v = kMaxSendAtomicValue; + } + + // Convert 32-bit signed int to 64-bit for RDMA atomics + // IBV_WR_ATOMIC_FETCH_AND_ADD requires 64-bit operations + int64_t v64 = static_cast(static_cast(v)); + + // Calculate remote address - must be 8-byte aligned for 64-bit RDMA + // atomics cmd.req_rptr is an offset relative to atomic_base_addr (local + // atomic_buffer_ptr) Use the remote atomic buffer address if available, + // otherwise fall back to remote_addr + if ((cmd.req_rptr & 0x7) != 0) { + fprintf(stderr, "[Native RDMA] req_rptr not 8-byte aligned: 0x%x\n", + cmd.req_rptr); + std::abort(); + } + uint64_t remote_atomic_addr; + if (ctx->remote_atomic_buffer_addr != 0) { + // Use registered atomic buffer + remote_atomic_addr = ctx->remote_atomic_buffer_addr + cmd.req_rptr; + } else { + assert(false && "Atomic buffer is not registered"); + } + + // Verify final address alignment (should always be true if req_rptr is + // aligned) + assert((remote_atomic_addr & 0x7) == 0 && + "Remote atomic address must be 8-byte aligned"); + + // Verify remote address is within bounds of the atomic buffer + if (ctx->remote_atomic_buffer_addr == 0) { + fprintf(stderr, + "[Native RDMA] Remote atomic buffer not registered\n"); + std::abort(); + } + if (remote_atomic_addr < ctx->remote_atomic_buffer_addr || + remote_atomic_addr + sizeof(uint64_t) > + ctx->remote_atomic_buffer_addr + + ctx->remote_atomic_buffer_len) { + fprintf(stderr, + "[Native RDMA] Remote atomic address out of bounds: 0x%llx " + "(base=0x%llx, len=%zu)\n", + (unsigned long long)remote_atomic_addr, + (unsigned long long)ctx->remote_atomic_buffer_addr, + (size_t)ctx->remote_atomic_buffer_len); + std::abort(); + } + + // Local SGE: point to local buffer where old value will be stored + // (64-bit) Use local context S's memory region, not destination ctx + uintptr_t local_addr = + reinterpret_cast(&atomic_old_values_64[t]); + // Double-check address is within bounds + if (local_addr < mr_addr || local_addr + sizeof(uint64_t) > mr_end) { + fprintf(stderr, + "[Native RDMA] Local atomic address out of bounds: 
0x%llx\n", + (unsigned long long)local_addr); + std::abort(); + } + + sge[t].addr = local_addr; + sge[t].length = sizeof(uint64_t); + sge[t].lkey = S.mr->lkey; + + std::memset(&wr[t], 0, sizeof(wr[t])); + wr[t].wr_id = wr_id; + wr[t].opcode = IBV_WR_ATOMIC_FETCH_AND_ADD; + wr[t].send_flags = (t + 1 == k) ? IBV_SEND_SIGNALED : 0; + wr[t].sg_list = &sge[t]; + wr[t].num_sge = 1; + wr[t].wr.atomic.remote_addr = remote_atomic_addr; + // Use remote atomic buffer rkey if available, otherwise use main buffer + // rkey + assert(ctx->remote_atomic_buffer_rkey != 0); + wr[t].wr.atomic.rkey = ctx->remote_atomic_buffer_rkey; + wr[t].wr.atomic.compare_add = v64; // 64-bit value to add + wr[t].next = (t + 1 < k) ? &wr[t + 1] : nullptr; + } + + ibv_send_wr* bad = nullptr; + int ret = ibv_post_send(qp, &wr[0], &bad); + if (ret) { + fprintf(stderr, "[Native RDMA] ibv_post_send(atomic) failed: %d (%s)\n", + ret, strerror(ret)); + if (bad) { + fprintf(stderr, " bad wr_id=0x%llx opcode=%u\n", + (unsigned long long)bad->wr_id, bad->opcode); + } + std::abort(); + } + // The completion will have wr_id = wr[k-1].wr_id (the last signaled WR) + // Store this exact value to ensure lookup matches + uint64_t const batch_tail_wr = wr[k - 1].wr_id; + // Verify this matches group_wrids.back() (should always be true) + if (batch_tail_wr != group_wrids.back()) { + fprintf( + stderr, + "[Native RDMA] ERROR: batch_tail_wr (0x%lx) != group_wrids.back() " + "(0x%lx)\n", + batch_tail_wr, group_wrids.back()); + std::abort(); + } + { + auto [it, inserted] = S.wr_id_to_wr_ids.try_emplace( + batch_tail_wr, std::move(group_wrids)); + + // printf("[Native RDMA] batch_tail_wr: 0x%lx, map_size: %zu, dst_rank: + // %d\n", batch_tail_wr, it->second.size(), dst_rank); + if (!inserted) { + fprintf(stderr, + "thread_idx: %d, Error: tail wr_id %lu already exists " + "(map=%p, " + "size=%zu, dst_rank=%d)\n", + thread_idx, batch_tail_wr, (void*)&S.wr_id_to_wr_ids, + S.wr_id_to_wr_ids.size(), dst_rank); + std::abort(); + } + } + } + } +} + // Wrapper that selects implementation based on use_normal_mode void post_atomic_operations(ProxyCtx& S, std::vector const& wrs_to_post, @@ -2286,8 +2736,13 @@ void post_atomic_operations(ProxyCtx& S, std::unordered_set& acked_wrs, bool use_normal_mode) { if (use_normal_mode) { +#if (defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)) && !defined(SOFTWARE_ORDERING) + post_atomic_operations_native_rdma(S, wrs_to_post, cmds_to_post, ctxs, + my_rank, thread_idx, acked_wrs); +#else post_atomic_operations_normal_mode(S, wrs_to_post, cmds_to_post, ctxs, my_rank, thread_idx, acked_wrs); +#endif } else { post_atomic_operations_fast_mode(S, wrs_to_post, cmds_to_post, ctxs, my_rank, thread_idx, acked_wrs); diff --git a/ep/src/uccl_proxy.cpp b/ep/src/uccl_proxy.cpp index 53bbd00f..6380d126 100644 --- a/ep/src/uccl_proxy.cpp +++ b/ep/src/uccl_proxy.cpp @@ -57,10 +57,12 @@ UcclProxy::UcclProxy(int thread_idx, uintptr_t gpu_buffer_addr, if (thread_idx == 0) { #ifdef USE_GRACE_HOPPER cudaMallocManaged(&atomic_buffer_ptr_, kAtomicBufferSize); -#elif defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) +#elif defined(SOFTWARE_ORDERING) cudaHostAlloc(&atomic_buffer_ptr_, kAtomicBufferSize, - cudaHostAllocMapped | cudaHostAllocWriteCombined | - hipHostMallocUncached); + cudaHostAllocMapped | cudaHostAllocWriteCombined); +#elif defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) + hipExtMallocWithFlags(&atomic_buffer_ptr_, kAtomicBufferSize, + hipDeviceMallocUncached); #else cudaHostAlloc(&atomic_buffer_ptr_, kAtomicBufferSize, 
cudaHostAllocMapped | cudaHostAllocWriteCombined);
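// A condensed sketch of the native fetch-and-add path added in
// post_atomic_operations_native_rdma(): one 64-bit IBV_WR_ATOMIC_FETCH_AND_ADD
// aimed at the peer's registered atomic buffer, with the previous value
// landing in a local 8-byte slot. PeerAtomicTarget and post_fetch_add are
// hypothetical stand-ins, not the structures used in this patch.
#include <infiniband/verbs.h>
#include <cassert>
#include <cstdint>

struct PeerAtomicTarget {
  ibv_qp* qp;            // data QP for the chosen ring/channel
  uint64_t remote_base;  // peer's atomic_buffer_addr from the exchange
  uint64_t remote_len;   // peer's atomic_buffer_len
  uint32_t remote_rkey;  // peer's atomic_buffer_rkey
};

// Adds `value` to the 64-bit counter at `offset` inside the peer's atomic
// buffer; `old_slot` must live inside a locally registered MR (`lkey`).
inline void post_fetch_add(PeerAtomicTarget const& peer, uint64_t offset,
                           int64_t value, uint64_t* old_slot, uint32_t lkey,
                           uint64_t wr_id) {
  assert((offset & 0x7) == 0 && "64-bit RDMA atomics need 8-byte alignment");
  assert(offset + sizeof(uint64_t) <= peer.remote_len);

  ibv_sge sge{};
  sge.addr = reinterpret_cast<uintptr_t>(old_slot);
  sge.length = sizeof(uint64_t);
  sge.lkey = lkey;

  ibv_send_wr wr{};
  wr.wr_id = wr_id;
  wr.opcode = IBV_WR_ATOMIC_FETCH_AND_ADD;
  wr.send_flags = IBV_SEND_SIGNALED;  // completion carries this wr_id
  wr.sg_list = &sge;
  wr.num_sge = 1;
  wr.wr.atomic.remote_addr = peer.remote_base + offset;
  wr.wr.atomic.rkey = peer.remote_rkey;
  wr.wr.atomic.compare_add = static_cast<uint64_t>(value);  // amount to add

  ibv_send_wr* bad = nullptr;
  int ret = ibv_post_send(peer.qp, &wr, &bad);
  assert(ret == 0 && "ibv_post_send(IBV_WR_ATOMIC_FETCH_AND_ADD) failed");
  (void)ret;
}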