diff --git a/ep/bench/test_internode.py b/ep/bench/test_internode.py index b20c86fb4..4a732b136 100644 --- a/ep/bench/test_internode.py +++ b/ep/bench/test_internode.py @@ -418,7 +418,7 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): # Tune combine performance best_time, best_results = 1e10, None - for nvl_chunk_size in range(1, 8, 1): + for nvl_chunk_size in range(1, 20, 2): for rdma_chunk_size in range(12 if num_nodes == 2 else 8, 33, 4): config = Config( num_sms, diff --git a/ep/include/common.hpp b/ep/include/common.hpp index 367c59cba..098808230 100644 --- a/ep/include/common.hpp +++ b/ep/include/common.hpp @@ -34,7 +34,7 @@ // imm for reordering buffer sequence tracking. #ifdef USE_MSCCLPP_FIFO_BACKEND #ifdef USE_NORMAL_MODE -#define kMaxInflight 8 +#define kMaxInflight 16 // Increased for better combine throughput #else #define kMaxInflight 32 #endif diff --git a/ep/include/proxy_ctx.hpp b/ep/include/proxy_ctx.hpp index f87f70ce0..977a63a27 100644 --- a/ep/include/proxy_ctx.hpp +++ b/ep/include/proxy_ctx.hpp @@ -2,10 +2,16 @@ #include "barrier_local.hpp" #include "util/gpu_rt.h" #include +#include #include #include #include #include +#ifdef USE_NORMAL_MODE +#include +#endif + +struct TransferCmd; // forward declaration template class TokenCounter { @@ -23,6 +29,51 @@ class TokenCounter { MapType counter_; }; +// Fast array-based token counter for combine operations +// Bounded to kMaxBuffers buffers and kMaxExperts experts; out-of-range keys fall back to a map +class FastCombineTokenCounter { + public: + static constexpr size_t kMaxBuffers = 8; // 2x headroom over the 4 buffers assumed in practice + static constexpr size_t kMaxExperts = 512; + + void Add(int buffer_idx, int expert_idx, size_t k) { + if (buffer_idx >= 0 && buffer_idx < kMaxBuffers && expert_idx >= 0 && + expert_idx < kMaxExperts) { + counters_[buffer_idx][expert_idx] += k; + } else { + // Fallback to map for out-of-range + fallback_[{buffer_idx, expert_idx}] += k; + } + } + + size_t Get(int buffer_idx, int expert_idx) const { + if (buffer_idx >= 0 && buffer_idx < kMaxBuffers && expert_idx >= 0 && + expert_idx < kMaxExperts) { + return counters_[buffer_idx][expert_idx]; + } + auto it = fallback_.find({buffer_idx, expert_idx}); + return (it == fallback_.end()) ? 
0 : it->second; } + + void Reset(int buffer_idx, int expert_idx) { + if (buffer_idx >= 0 && buffer_idx < kMaxBuffers && expert_idx >= 0 && + expert_idx < kMaxExperts) { + counters_[buffer_idx][expert_idx] = 0; + } else { + fallback_[{buffer_idx, expert_idx}] = 0; + } + } + + void Clear() { + memset(counters_, 0, sizeof(counters_)); + fallback_.clear(); + } + + private: + size_t counters_[kMaxBuffers][kMaxExperts] = {}; + mutable std::map<std::pair<int, int>, size_t> fallback_; // For out-of-range keys +}; + using DispatchTokenKey = std::tuple; using CombineTokenKey = std::pair; using NormalTokenKey = std::pair; @@ -91,7 +142,7 @@ struct ProxyCtx { uint32_t tag = 0; TokenCounter dispatch_token_counter; - TokenCounter combine_token_counter; + FastCombineTokenCounter combine_token_counter; // Optimized for fast lookups TokenCounter normal_token_counter; /* low_latency_buffer_idx, expert_idx, dst_rank */ @@ -130,4 +181,44 @@ struct ProxyCtx { return (static_cast<uint64_t>(static_cast<uint32_t>(dst_rank)) << 32) ^ static_cast<uint64_t>(static_cast<uint32_t>(index)); } + +#ifdef USE_NORMAL_MODE + // Batching state: sends are held per destination until kMaxBatchSize is reached + struct BatchState { + std::vector<uint64_t> wrs; + std::vector<TransferCmd> cmds; + std::chrono::steady_clock::time_point first_cmd_time; + bool has_pending = false; + }; + std::unordered_map<int, BatchState> pending_batches; // per dst_rank + static constexpr size_t kMaxBatchSize = 64; // Sweet spot for EFA UD mode + static constexpr int64_t kMaxBatchDelayUs = + 10; // Not used (size-only batching) + + // Pre-allocated buffers to avoid allocation in hot path + std::unordered_map<int, std::vector<size_t>> reusable_dst_rank_wr_ids; + std::unordered_map<size_t, std::vector<size_t>> reusable_ring_to_indices; + std::vector<uint64_t> reusable_ring_wrids; + std::vector<ibv_sge> reusable_sges; + std::vector<ibv_send_wr> reusable_wrs; + + // Cache for EFA UD addressing - avoid repeated ibv_wr_set_ud_addr + struct UDAddrCache { + ibv_ah* ah = nullptr; + uint32_t qpn = 0; + uint32_t qkey = QKEY; + }; + std::unordered_map<size_t, UDAddrCache> ud_addr_cache; // ring_idx -> cache + + // Pre-allocated array for sequence numbers (replacing map) + static constexpr size_t kSeqArraySize = + 16384; // 16K hash slots for (dst_rank, index) pairs; colliding pairs share a slot + std::array<std::atomic<uint8_t>, kSeqArraySize> seq_array{}; + + // Hash function for sequence array indexing + inline size_t seq_hash(int dst_rank, size_t index) const { + // Simple hash combining dst_rank and index + return (static_cast<size_t>(dst_rank) * 4096 + index) % kSeqArraySize; + } +#endif }; diff --git a/ep/include/rdma.hpp b/ep/include/rdma.hpp index cbba199ef..ef04c9622 100644 --- a/ep/include/rdma.hpp +++ b/ep/include/rdma.hpp @@ -331,6 +331,14 @@ void post_rdma_async_batched(ProxyCtx& S, void* buf, size_t num_wrs, std::vector const& cmds_to_post, std::vector>& ctxs, int my_rank, int thread_idx); +#ifdef USE_NORMAL_MODE +void flush_pending_batch_for_dst(ProxyCtx& S, int dst_rank, void* buf, + std::vector>& ctxs, + int my_rank, int thread_idx); +void flush_all_pending_batches(ProxyCtx& S, void* buf, + std::vector>& ctxs, + int my_rank, int thread_idx); +#endif void local_process_completions(ProxyCtx& S, std::unordered_set& acked_wrs, int thread_idx, ibv_wc* wc, int ne, diff --git a/ep/src/proxy.cpp b/ep/src/proxy.cpp index 22b88c071..83cc09632 100644 --- a/ep/src/proxy.cpp +++ b/ep/src/proxy.cpp @@ -393,6 +393,9 @@ void Proxy::run_dual() { uint64_t my_tail = 0; size_t seen = 0; std::set pending_atomic_updates; +#ifdef USE_NORMAL_MODE + auto last_flush_check = std::chrono::steady_clock::now(); +#endif while (ctx_.progress_run.load(std::memory_order_acquire)) { poll_cq_dual(ctx_, acked_wrs_, cfg_.thread_idx, ring, ctx_by_tag_, atomic_buffer_ptr_, 
cfg_.num_ranks, cfg_.num_experts, @@ -416,6 +419,10 @@ void Proxy::run_dual() { #ifdef USE_NORMAL_MODE barrier_check(); + + // Note: Periodic flush removed - using size-only batching + // (kMaxBatchSize=64) for maximum throughput. Time-based flushing overhead + // was degrading performance. #endif } } diff --git a/ep/src/rdma.cpp b/ep/src/rdma.cpp index 07c7453ea..14aa605db 100644 --- a/ep/src/rdma.cpp +++ b/ep/src/rdma.cpp @@ -689,82 +689,184 @@ void post_receive_buffer_for_imm(ProxyCtx& S) { } #ifdef USE_NORMAL_MODE -void post_rdma_async_batched(ProxyCtx& S, void* buf, size_t num_wrs, - std::vector const& wrs_to_post, - std::vector const& cmds_to_post, - std::vector>& ctxs, - int my_rank, int thread_idx) { - if (num_wrs == 0) return; - if (wrs_to_post.size() != num_wrs || cmds_to_post.size() != num_wrs) { - fprintf(stderr, "Size mismatch (num_wrs=%zu, wr_ids=%zu, cmds=%zu)\n", - num_wrs, wrs_to_post.size(), cmds_to_post.size()); +// Forward declaration +void flush_pending_batch_for_dst(ProxyCtx& S, int dst_rank, void* buf, + std::vector>& ctxs, + int my_rank, int thread_idx); + +void flush_all_pending_batches(ProxyCtx& S, void* buf, + std::vector>& ctxs, + int my_rank, int thread_idx) { + for (auto& [dst_rank, batch_state] : S.pending_batches) { + if (batch_state.has_pending && !batch_state.wrs.empty()) { + flush_pending_batch_for_dst(S, dst_rank, buf, ctxs, my_rank, thread_idx); + } + } +} + +void flush_pending_batch_for_dst(ProxyCtx& S, int dst_rank, void* buf, + std::vector>& ctxs, + int my_rank, int thread_idx) { + auto& batch = S.pending_batches[dst_rank]; + if (!batch.has_pending || batch.wrs.empty()) { + return; + } + + ProxyCtx* ctx = ctxs[dst_rank].get(); + if (!ctx || !ctx->qp || !ctx->mr) { + fprintf(stderr, "Destination ctx missing fields for dst=%d\n", dst_rank); std::abort(); } - std::unordered_map> dst_rank_wr_ids; - for (size_t i = 0; i < num_wrs; ++i) { - if (cmds_to_post[i].dst_rank == static_cast(my_rank)) { - // NOTE(MaoZiming): this should not happen. - printf("Posting rdma to itself\n"); - std::abort(); - continue; - } else if (std::abs((int)cmds_to_post[i].dst_rank - (int)my_rank) % - MAX_NUM_GPUS != - 0) { - // NOTE(MaoZiming): this should not happen. 
- printf("Posting rdma to a different rank\n"); - std::abort(); - continue; - } else { - dst_rank_wr_ids[cmds_to_post[i].dst_rank].push_back(i); - } + size_t const num_wrs = batch.wrs.size(); + + // Group by ring index - use pre-allocated buffer (avoid repeated allocation) + auto& ring_to_indices = S.reusable_ring_to_indices; + // Fast clear: just mark vectors as empty without deallocating + for (auto& [key, vec] : ring_to_indices) { + vec.clear(); + } + // Only clear map if it's getting too large (avoid rehashing overhead) + if (ring_to_indices.size() > 16) { + ring_to_indices.clear(); } - for (auto& [dst_rank, wr_ids] : dst_rank_wr_ids) { - if (wr_ids.empty()) continue; + for (size_t j = 0; j < num_wrs; ++j) { + size_t ring_idx = static_cast((batch.wrs[j] >> 32) & 0xFFFFFFFFu); + ring_to_indices[ring_idx].push_back(j); + } - ProxyCtx* ctx = ctxs[dst_rank].get(); - if (!ctx || !ctx->qp || !ctx->mr) { - fprintf(stderr, "Destination ctx missing fields for dst=%d\n", dst_rank); - std::abort(); - } - size_t const k = wr_ids.size(); - std::unordered_map> ring_to_indices; - ring_to_indices.reserve(k); - for (size_t j = 0; j < k; ++j) { - size_t i = wr_ids[j]; - size_t ring_idx = - static_cast((wrs_to_post[i] >> 32) & 0xFFFFFFFFu); - ring_to_indices[ring_idx].push_back(i); + for (auto& [ring_idx_raw, idxs] : ring_to_indices) { +#ifdef EFA + const size_t local_ring_count = ctx->data_qps_by_channel.size(); + struct ibv_qp_ex* qpx = + (struct ibv_qp_ex*)(local_ring_count + ? ctx->data_qps_by_channel[ring_idx_raw % + local_ring_count] + : ctx->ack_qp); + + size_t const remote_ring_count = ctx->dst_data_qpn_by_ring.size(); + uint32_t const dst_qpn = + remote_ring_count + ? ctx->dst_data_qpn_by_ring[ring_idx_raw % remote_ring_count] + : ctx->dst_qpn; + + ibv_wr_start(qpx); + // Use pre-allocated buffer + auto& ring_wrids = S.reusable_ring_wrids; + ring_wrids.clear(); + ring_wrids.reserve(idxs.size()); + + // Cache UD address once per ring batch + auto& ud_cache = S.ud_addr_cache[ring_idx_raw]; + if (ud_cache.ah != ctx->dst_ah || ud_cache.qpn != dst_qpn) { + ud_cache.ah = ctx->dst_ah; + ud_cache.qpn = dst_qpn; } - for (auto& [ring_idx_raw, idxs] : ring_to_indices) { -#ifdef EFA - const size_t local_ring_count = ctx->data_qps_by_channel.size(); - struct ibv_qp_ex* qpx = - (struct ibv_qp_ex*)(local_ring_count - ? ctx->data_qps_by_channel[ring_idx_raw % - local_ring_count] - : ctx->ack_qp); + // Set UD address ONCE for entire batch (persists across WRs on same QP) + // This eliminates 64 redundant calls per batch + ibv_wr_set_ud_addr(qpx, ud_cache.ah, ud_cache.qpn, ud_cache.qkey); - size_t const remote_ring_count = ctx->dst_data_qpn_by_ring.size(); - uint32_t const dst_qpn = - remote_ring_count - ? ctx->dst_data_qpn_by_ring[ring_idx_raw % remote_ring_count] - : ctx->dst_qpn; + for (size_t j = 0; j < idxs.size(); ++j) { + size_t i = idxs[j]; + auto const& cmd = batch.cmds[i]; - ibv_wr_start(qpx); - // No receiver barrier: build a single chain for this ring group - std::vector ring_wrids; - ring_wrids.reserve(idxs.size()); + qpx->wr_id = batch.wrs[i]; + qpx->comp_mask = 0; + qpx->wr_flags = IBV_SEND_SIGNALED; - for (size_t j = 0; j < idxs.size(); ++j) { - size_t i = idxs[j]; - auto const& cmd = cmds_to_post[i]; + uint64_t remote_addr = + ctx->remote_addr + (cmd.req_rptr ? 
cmd.req_rptr : 0); + uint64_t remote_end = ctx->remote_addr + ctx->remote_len; - qpx->wr_id = wrs_to_post[i]; - qpx->comp_mask = 0; - qpx->wr_flags = IBV_SEND_SIGNALED; + if (remote_addr < ctx->remote_addr || + remote_addr + cmd.bytes > remote_end) { + fprintf(stderr, + "[ERROR] Remote write OOB: addr=0x%llx len=%u (base=0x%llx, " + "size=%zu), cmd.req_rptr: 0x%llx\n", + (unsigned long long)remote_addr, cmd.bytes, + (unsigned long long)ctx->remote_addr, (size_t)ctx->remote_len, + (unsigned long long)cmd.req_rptr); + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + fprintf(stderr, "cudaDeviceSynchronize failed: %s\n", + cudaGetErrorString(err)); + } + std::abort(); + } + + // Optionally send an inline "atomic" via imm, else use imm only on tail + if (cmd.atomic_offset > 0 && cmd.atomic_val > 0) { + int v = static_cast(cmd.atomic_val); + if (v < -kMaxSendAtomicValue || v > kMaxSendAtomicValue) { + fprintf(stderr, "[EFA] atomic value=%d won't fit in 15 bits\n", v); + std::abort(); + } + size_t index = static_cast(cmd.atomic_offset / sizeof(int)); + + // Use pre-allocated array instead of map + size_t seq_idx = S.seq_hash(dst_rank, index); + uint8_t seq = + S.seq_array[seq_idx].fetch_add(1, std::memory_order_relaxed) % + kReorderingBufferSize; + + uint32_t imm = + AtomicsImm::PackAtomicWithSeq(v, cmd.atomic_offset, seq, true) + .GetImmData(); + AtomicsImm aimm(imm); + assert(aimm.GetSeq() == seq); + + ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, remote_addr, htonl(imm)); + } else if (j + 1 == idxs.size()) { + uint32_t imm = + WriteImm::Pack(get_is_combine(cmd.cmd_type), + get_low_latency(cmd.cmd_type), cmd.expert_idx, + (uint32_t)idxs.size(), my_rank) + .GetImmData(); + ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, remote_addr, htonl(imm)); + } else { + ibv_wr_rdma_write(qpx, ctx->remote_rkey, remote_addr); + } + + uintptr_t laddr = + cmd.req_lptr + reinterpret_cast(ctx->mr->addr); + // Set SGE for this WR + ibv_wr_set_sge(qpx, ctx->mr->lkey, laddr, + static_cast(cmd.bytes)); + + ring_wrids.push_back(batch.wrs[i]); + } + int ret = ibv_wr_complete(qpx); + if (ret) { + fprintf(stderr, "ibv_wr_complete failed (dst=%d): %s (ret=%d)\n", + dst_rank, strerror(ret), ret); + std::abort(); + } +#else + { + size_t const local_ring_count = ctx->data_qps_by_channel.size(); + struct ibv_qp* qp = + local_ring_count + ? ctx->data_qps_by_channel[ring_idx_raw % local_ring_count] + : ctx->ack_qp; + + size_t const kgroup = idxs.size(); + // Use pre-allocated buffers + auto& sges = S.reusable_sges; + auto& wrs = S.reusable_wrs; + auto& ring_wrids = S.reusable_ring_wrids; + sges.clear(); + wrs.clear(); + ring_wrids.clear(); + sges.resize(kgroup); + wrs.resize(kgroup); + ring_wrids.reserve(kgroup); + + for (size_t j = 0; j < kgroup; ++j) { + size_t i = idxs[j]; + auto const& cmd = batch.cmds[i]; + ring_wrids.push_back(batch.wrs[i]); uint64_t remote_addr = ctx->remote_addr + (cmd.req_rptr ? 
cmd.req_rptr : 0); @@ -785,154 +887,127 @@ void post_rdma_async_batched(ProxyCtx& S, void* buf, size_t num_wrs, } std::abort(); } - // Optionally send an inline "atomic" via imm, else use imm only on tail + + uintptr_t laddr = + cmd.req_lptr + reinterpret_cast(ctx->mr->addr); + sges[j] = { + .addr = laddr, + .length = static_cast(cmd.bytes), + .lkey = ctx->mr->lkey, + }; + + std::memset(&wrs[j], 0, sizeof(wrs[j])); + wrs[j].wr_id = batch.wrs[i]; + wrs[j].sg_list = &sges[j]; + wrs[j].num_sge = 1; + wrs[j].wr.rdma.remote_addr = remote_addr; + wrs[j].wr.rdma.rkey = ctx->remote_rkey; + wrs[j].opcode = IBV_WR_RDMA_WRITE; + wrs[j].send_flags = (j + 1 == kgroup) ? IBV_SEND_SIGNALED : 0; + wrs[j].next = (j + 1 < kgroup) ? &wrs[j + 1] : nullptr; + if (cmd.atomic_offset > 0 && cmd.atomic_val > 0) { int v = static_cast(cmd.atomic_val); if (v < -kMaxSendAtomicValue || v > kMaxSendAtomicValue) { - fprintf(stderr, "[EFA] atomic value=%d won't fit in 15 bits\n", v); + fprintf(stderr, "atomic value=%d won't fit in 15 bits\n", v); std::abort(); } - size_t index = static_cast(cmd.atomic_offset / sizeof(int)); - // Initialize missing entries lazily - auto key = ctx->seq_key(dst_rank, index); - if (ctx->next_seq_per_index.find(key) == - ctx->next_seq_per_index.end()) - ctx->next_seq_per_index[key] = 0; - - uint8_t seq = ctx->next_seq_per_index[key]; - ctx->next_seq_per_index[key] = - (seq + 1) % kReorderingBufferSize; // 4-bit wrap (0–15) uint32_t imm = - AtomicsImm::PackAtomicWithSeq(v, cmd.atomic_offset, seq, true) + AtomicsImm::Pack(true, false, cmd.atomic_val, cmd.atomic_offset, + get_low_latency(cmd.cmd_type)) .GetImmData(); - AtomicsImm aimm(imm); - assert(aimm.GetSeq() == seq); - - ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, remote_addr, htonl(imm)); - } else if (j + 1 == idxs.size()) { + wrs[j].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; + wrs[j].imm_data = htonl(imm); + } else if (j + 1 == kgroup) { uint32_t imm = WriteImm::Pack(get_is_combine(cmd.cmd_type), get_low_latency(cmd.cmd_type), cmd.expert_idx, - (uint32_t)idxs.size(), my_rank) + static_cast(kgroup), my_rank) .GetImmData(); - ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, remote_addr, htonl(imm)); + wrs[j].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; + wrs[j].imm_data = htonl(imm); } else { - ibv_wr_rdma_write(qpx, ctx->remote_rkey, remote_addr); + wrs[j].opcode = IBV_WR_RDMA_WRITE; } - - uintptr_t laddr = - cmd.req_lptr + reinterpret_cast(ctx->mr->addr); - ibv_wr_set_ud_addr(qpx, ctx->dst_ah, dst_qpn, QKEY); - ibv_wr_set_sge(qpx, ctx->mr->lkey, laddr, - static_cast(cmd.bytes)); - - ring_wrids.push_back(wrs_to_post[i]); } - int ret = ibv_wr_complete(qpx); + + ibv_send_wr* bad = nullptr; + int ret = ibv_post_send(qp, &wrs[0], &bad); if (ret) { - fprintf(stderr, "ibv_wr_complete failed (dst=%d): %s (ret=%d)\n", + fprintf(stderr, "ibv_post_send failed (dst=%d): %s (ret=%d)\n", dst_rank, strerror(ret), ret); + if (bad) + fprintf(stderr, "Bad WR at %p (wr_id=%lu)\n", (void*)bad, bad->wr_id); std::abort(); } -#else - { - size_t const local_ring_count = ctx->data_qps_by_channel.size(); - struct ibv_qp* qp = - local_ring_count - ? 
ctx->data_qps_by_channel[ring_idx_raw % local_ring_count] - : ctx->ack_qp; - - size_t const kgroup = idxs.size(); - std::vector sges(kgroup); - std::vector wrs(kgroup); - std::vector ring_wrids; - ring_wrids.reserve(kgroup); - - for (size_t j = 0; j < kgroup; ++j) { - size_t i = idxs[j]; - auto const& cmd = cmds_to_post[i]; - ring_wrids.push_back(wrs_to_post[i]); - - // Remote address bounds check - uint64_t remote_addr = - ctx->remote_addr + (cmd.req_rptr ? cmd.req_rptr : 0); - uint64_t remote_end = ctx->remote_addr + ctx->remote_len; - - if (remote_addr < ctx->remote_addr || - remote_addr + cmd.bytes > remote_end) { - fprintf( - stderr, - "[ERROR] Remote write OOB: addr=0x%llx len=%u (base=0x%llx, " - "size=%zu), cmd.req_rptr: 0x%llx\n", - (unsigned long long)remote_addr, cmd.bytes, - (unsigned long long)ctx->remote_addr, (size_t)ctx->remote_len, - (unsigned long long)cmd.req_rptr); - cudaError_t err = cudaDeviceSynchronize(); - if (err != cudaSuccess) { - fprintf(stderr, "cudaDeviceSynchronize failed: %s\n", - cudaGetErrorString(err)); - } - std::abort(); - } + } +#endif + } - // Local SGE - uintptr_t laddr = - cmd.req_lptr + reinterpret_cast(ctx->mr->addr); - sges[j] = { - .addr = laddr, - .length = static_cast(cmd.bytes), - .lkey = ctx->mr->lkey, - }; - - // Build WR - std::memset(&wrs[j], 0, sizeof(wrs[j])); - wrs[j].wr_id = wrs_to_post[i]; - wrs[j].sg_list = &sges[j]; - wrs[j].num_sge = 1; - wrs[j].wr.rdma.remote_addr = remote_addr; - wrs[j].wr.rdma.rkey = ctx->remote_rkey; - wrs[j].opcode = IBV_WR_RDMA_WRITE; // default - wrs[j].send_flags = (j + 1 == kgroup) ? IBV_SEND_SIGNALED : 0; - wrs[j].next = (j + 1 < kgroup) ? &wrs[j + 1] : nullptr; - - if (cmd.atomic_offset > 0 && cmd.atomic_val > 0) { - int v = static_cast(cmd.atomic_val); - if (v < -kMaxSendAtomicValue || v > kMaxSendAtomicValue) { - fprintf(stderr, "atomic value=%d won't fit in 15 bits\n", v); - std::abort(); - } - uint32_t imm = - AtomicsImm::Pack(true, false, cmd.atomic_val, cmd.atomic_offset, - get_low_latency(cmd.cmd_type)) - .GetImmData(); - wrs[j].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; - wrs[j].imm_data = htonl(imm); - } else if (j + 1 == kgroup) { - // Put WriteImm only on the tail WR - uint32_t imm = - WriteImm::Pack(get_is_combine(cmd.cmd_type), - get_low_latency(cmd.cmd_type), cmd.expert_idx, - static_cast(kgroup), my_rank) - .GetImmData(); - wrs[j].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; - wrs[j].imm_data = htonl(imm); - } else { - wrs[j].opcode = IBV_WR_RDMA_WRITE; - } - } + // Clear the batch + batch.wrs.clear(); + batch.cmds.clear(); + batch.has_pending = false; +} - // Post the chain - ibv_send_wr* bad = nullptr; - int ret = ibv_post_send(qp, &wrs[0], &bad); - if (ret) { - fprintf(stderr, "ibv_post_send failed (dst=%d): %s (ret=%d)\n", - dst_rank, strerror(ret), ret); - if (bad) - fprintf(stderr, "Bad WR at %p (wr_id=%lu)\n", (void*)bad, - bad->wr_id); - std::abort(); - } -#endif +void post_rdma_async_batched(ProxyCtx& S, void* buf, size_t num_wrs, + std::vector const& wrs_to_post, + std::vector const& cmds_to_post, + std::vector>& ctxs, + int my_rank, int thread_idx) { + if (num_wrs == 0) return; + if (wrs_to_post.size() != num_wrs || cmds_to_post.size() != num_wrs) { + fprintf(stderr, "Size mismatch (num_wrs=%zu, wr_ids=%zu, cmds=%zu)\n", + num_wrs, wrs_to_post.size(), cmds_to_post.size()); + std::abort(); + } + + // Group by destination rank - use pre-allocated buffer + auto& dst_rank_wr_ids = S.reusable_dst_rank_wr_ids; + // Clear existing entries and their vectors + for (auto& [key, vec] : 
dst_rank_wr_ids) { + vec.clear(); + } + dst_rank_wr_ids.clear(); + for (size_t i = 0; i < num_wrs; ++i) { + if (cmds_to_post[i].dst_rank == static_cast(my_rank)) { + printf("Posting rdma to itself\n"); + std::abort(); + continue; + } else if (std::abs((int)cmds_to_post[i].dst_rank - (int)my_rank) % + MAX_NUM_GPUS != + 0) { + printf("Posting rdma to a different rank\n"); + std::abort(); + continue; + } else { + dst_rank_wr_ids[cmds_to_post[i].dst_rank].push_back(i); + } + } + + // Add to pending batches and flush if needed + auto now = std::chrono::steady_clock::now(); + + for (auto& [dst_rank, wr_ids] : dst_rank_wr_ids) { + if (wr_ids.empty()) continue; + + auto& batch = S.pending_batches[dst_rank]; + + // Initialize batch if empty + if (!batch.has_pending) { + batch.first_cmd_time = now; + batch.has_pending = true; + } + + // Add commands to batch + for (size_t idx : wr_ids) { + batch.wrs.push_back(wrs_to_post[idx]); + batch.cmds.push_back(cmds_to_post[idx]); + } + + // Flush only when batch is full (remove time-based check for performance) + if (batch.wrs.size() >= ProxyCtx::kMaxBatchSize) { + flush_pending_batch_for_dst(S, dst_rank, buf, ctxs, my_rank, thread_idx); } } } @@ -1103,71 +1178,70 @@ void post_rdma_async_batched(ProxyCtx& S, void* buf, size_t num_wrs, std::abort(); } #else - std::vector sges(k); - std::vector wrs(k); - for (size_t j = 0; j < k; ++j) { - size_t i = wr_ids[j]; - auto const& cmd = cmds_to_post[i]; - wr_ids[j] = wrs_to_post[i]; - sges[j].addr = - cmd.req_lptr + reinterpret_cast(ctx->mr->addr); - sges[j].length = static_cast(cmd.bytes); - sges[j].lkey = ctx->mr->lkey; - std::memset(&wrs[j], 0, sizeof(wrs[j])); - wrs[j].sg_list = &sges[j]; - wrs[j].num_sge = 1; - wrs[j].wr_id = wr_ids[j]; - - wrs[j].wr.rdma.remote_addr = ctx->remote_addr + cmd.req_rptr; - - uint64_t remote_end = ctx->remote_addr + ctx->remote_len; - if (wrs[j].wr.rdma.remote_addr < ctx->remote_addr || - wrs[j].wr.rdma.remote_addr + cmd.bytes > remote_end) { - fprintf(stderr, - "[ERROR] Remote write OOB: addr=0x%llx len=%u (base=0x%llx, " - "size=%zu), cmd.req_rptr: 0x%llx\n", - (unsigned long long)wrs[j].wr.rdma.remote_addr, cmd.bytes, - (unsigned long long)ctx->remote_addr, (size_t)ctx->remote_len, - (unsigned long long)cmd.req_rptr); - cudaError_t err = cudaDeviceSynchronize(); - if (err != cudaSuccess) { - fprintf(stderr, "cudaDeviceSynchronize failed: %s\n", - cudaGetErrorString(err)); - std::abort(); - } + std::vector sges(k); + std::vector wrs(k); + for (size_t j = 0; j < k; ++j) { + size_t i = wr_ids[j]; + auto const& cmd = cmds_to_post[i]; + wr_ids[j] = wrs_to_post[i]; + sges[j].addr = cmd.req_lptr + reinterpret_cast(ctx->mr->addr); + sges[j].length = static_cast(cmd.bytes); + sges[j].lkey = ctx->mr->lkey; + std::memset(&wrs[j], 0, sizeof(wrs[j])); + wrs[j].sg_list = &sges[j]; + wrs[j].num_sge = 1; + wrs[j].wr_id = wr_ids[j]; + + wrs[j].wr.rdma.remote_addr = ctx->remote_addr + cmd.req_rptr; + + uint64_t remote_end = ctx->remote_addr + ctx->remote_len; + if (wrs[j].wr.rdma.remote_addr < ctx->remote_addr || + wrs[j].wr.rdma.remote_addr + cmd.bytes > remote_end) { + fprintf(stderr, + "[ERROR] Remote write OOB: addr=0x%llx len=%u (base=0x%llx, " + "size=%zu), cmd.req_rptr: 0x%llx\n", + (unsigned long long)wrs[j].wr.rdma.remote_addr, cmd.bytes, + (unsigned long long)ctx->remote_addr, (size_t)ctx->remote_len, + (unsigned long long)cmd.req_rptr); + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + fprintf(stderr, "cudaDeviceSynchronize failed: %s\n", + 
cudaGetErrorString(err)); std::abort(); } - - wrs[j].wr.rdma.rkey = ctx->remote_rkey; - wrs[j].opcode = IBV_WR_RDMA_WRITE; - wrs[j].send_flags = 0; - wrs[j].next = (j + 1 < k) ? &wrs[j + 1] : nullptr; - } - size_t const last = k - 1; - uint64_t const batch_tail_wr = wr_ids[last]; - wrs[last].send_flags |= IBV_SEND_SIGNALED; - wrs[last].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; - wrs[last].imm_data = htonl(static_cast(batch_tail_wr)); - ibv_send_wr* bad = nullptr; - int ret = ibv_post_send(ctx->qp, &wrs[0], &bad); - if (ret) { - fprintf(stderr, "ibv_post_send failed (dst=%d): %s (ret=%d)\n", - dst_rank, strerror(ret), ret); - if (bad) - fprintf(stderr, "Bad WR at %p (wr_id=%lu)\n", (void*)bad, bad->wr_id); std::abort(); } - { - auto [it, inserted] = - S.wr_id_to_wr_ids.try_emplace(batch_tail_wr, std::move(wr_ids)); - if (!inserted) { - fprintf(stderr, - "thread_idx: %d, Error: tail wr_id %lu already exists " - "(map=%p)\n", - thread_idx, batch_tail_wr, (void*)&S.wr_id_to_wr_ids); - std::abort(); - } + + wrs[j].wr.rdma.rkey = ctx->remote_rkey; + wrs[j].opcode = IBV_WR_RDMA_WRITE; + wrs[j].send_flags = 0; + wrs[j].next = (j + 1 < k) ? &wrs[j + 1] : nullptr; + } + size_t const last = k - 1; + uint64_t const batch_tail_wr = wr_ids[last]; + wrs[last].send_flags |= IBV_SEND_SIGNALED; + wrs[last].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; + wrs[last].imm_data = htonl(static_cast(batch_tail_wr)); + ibv_send_wr* bad = nullptr; + int ret = ibv_post_send(ctx->qp, &wrs[0], &bad); + if (ret) { + fprintf(stderr, "ibv_post_send failed (dst=%d): %s (ret=%d)\n", dst_rank, + strerror(ret), ret); + if (bad) + fprintf(stderr, "Bad WR at %p (wr_id=%lu)\n", (void*)bad, bad->wr_id); + std::abort(); + } + { + auto [it, inserted] = + S.wr_id_to_wr_ids.try_emplace(batch_tail_wr, std::move(wr_ids)); + if (!inserted) { + fprintf(stderr, + "thread_idx: %d, Error: tail wr_id %lu already exists " + "(map=%p)\n", + thread_idx, batch_tail_wr, (void*)&S.wr_id_to_wr_ids); + std::abort(); } + } #endif } } @@ -1346,11 +1420,11 @@ void apply_pending_updates(ProxyCtx& ctx, } } else { int combine_num_tokens = ctx.combine_token_counter.Get( - {upd.low_latency_buffer_idx, upd.expert_idx}); + upd.low_latency_buffer_idx, upd.expert_idx); if (value == combine_num_tokens) { is_atomic_ready = true; - ctx.combine_token_counter.Reset( - {upd.low_latency_buffer_idx, upd.expert_idx}); + ctx.combine_token_counter.Reset(upd.low_latency_buffer_idx, + upd.expert_idx); } } if (is_atomic_ready) { @@ -1438,7 +1512,7 @@ void remote_process_completions(ProxyCtx& S, int idx, CopyRingBuffer& g_ring, std::abort(); } int combine_num_tokens = - S.combine_token_counter.Get({low_latency_buffer_idx, expert_idx}); + S.combine_token_counter.Get(low_latency_buffer_idx, expert_idx); if (value == combine_num_tokens) { is_atomic_ready = true; } @@ -1450,7 +1524,7 @@ void remote_process_completions(ProxyCtx& S, int idx, CopyRingBuffer& g_ring, value, combine_num_tokens, expert_idx); } if (is_atomic_ready) { - S.combine_token_counter.Reset({low_latency_buffer_idx, expert_idx}); + S.combine_token_counter.Reset(low_latency_buffer_idx, expert_idx); } } auto* addr32 = @@ -1608,7 +1682,7 @@ void remote_process_completions(ProxyCtx& S, int idx, CopyRingBuffer& g_ring, /* expert_idx here is the global expert index of the sender. 
*/ assert(expert_idx >= src_rank * (num_experts / num_ranks) && expert_idx < (src_rank + 1) * (num_experts / num_ranks)); - S.combine_token_counter.Add({buffer_idx, expert_idx}, k); + S.combine_token_counter.Add(buffer_idx, expert_idx, k); } #endif } else if (cqe.opcode == IBV_WC_RECV_RDMA_WITH_IMM) { @@ -1889,77 +1963,76 @@ void post_atomic_operations(ProxyCtx& S, std::abort(); } #else - struct ibv_qp* qp = - local_ring_count - ? ctx->data_qps_by_channel[ring_idx_raw % local_ring_count] - : ctx->ack_qp; - - size_t const k = idxs.size(); - std::vector sge(k); - std::vector wr(k); - std::vector group_wrids; - group_wrids.reserve(k); - - for (size_t t = 0; t < k; ++t) { - size_t i = idxs[t]; - auto const& cmd = cmds_to_post[i]; - uint64_t const wr_id = wrs_to_post[i]; - group_wrids.push_back(wr_id); - - int v = static_cast(cmd.value); - if (v > kLargeAtomicValue) - v = kMaxSendAtomicValue; // saturate for imm - if (v < -kMaxSendAtomicValue || v > kMaxSendAtomicValue) { - fprintf(stderr, - "value=%d (cmd.value=%lu) won't fit in 15 bits for imm; " - "use a different scheme.\n", - v, (unsigned long)cmd.value); - std::abort(); - } + struct ibv_qp* qp = + local_ring_count + ? ctx->data_qps_by_channel[ring_idx_raw % local_ring_count] + : ctx->ack_qp; + + size_t const k = idxs.size(); + std::vector sge(k); + std::vector wr(k); + std::vector group_wrids; + group_wrids.reserve(k); - // If your AtomicsImm for non-EFA expects 16-bit offsets, keep the - // mask: - uint32_t off16 = static_cast(cmd.req_rptr) & 0xFFFFu; - int low_latency_buffer_idx = get_low_latency(cmd.cmd_type); - if (low_latency_buffer_idx < 0 || low_latency_buffer_idx > 1) { - fprintf(stderr, "Invalid low_latency_buffer_idx: %d\n", - low_latency_buffer_idx); - std::abort(); - } - uint32_t imm = AtomicsImm::Pack( - /*is_atomic*/ true, - /*is_combine*/ get_is_combine(cmd.cmd_type), v, - /*offset*/ off16, low_latency_buffer_idx) - .GetImmData(); - - // Zero-length write-with-imm on RC QP - sge[t].addr = reinterpret_cast(ctx->mr->addr); - sge[t].length = 0; - sge[t].lkey = ctx->mr->lkey; - - std::memset(&wr[t], 0, sizeof(wr[t])); - wr[t].wr_id = kAtomicWrTag | (wr_id & kAtomicMask); - wr[t].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; - wr[t].send_flags = (t + 1 == k) ? IBV_SEND_SIGNALED : 0; - wr[t].imm_data = htonl(imm); - wr[t].sg_list = &sge[t]; - wr[t].num_sge = 1; - wr[t].wr.rdma.remote_addr = ctx->remote_addr; - wr[t].wr.rdma.rkey = ctx->remote_rkey; - wr[t].next = (t + 1 < k) ? 
&wr[t + 1] : nullptr; + for (size_t t = 0; t < k; ++t) { + size_t i = idxs[t]; + auto const& cmd = cmds_to_post[i]; + uint64_t const wr_id = wrs_to_post[i]; + group_wrids.push_back(wr_id); + + int v = static_cast(cmd.value); + if (v > kLargeAtomicValue) v = kMaxSendAtomicValue; // saturate for imm + if (v < -kMaxSendAtomicValue || v > kMaxSendAtomicValue) { + fprintf(stderr, + "value=%d (cmd.value=%lu) won't fit in 15 bits for imm; " + "use a different scheme.\n", + v, (unsigned long)cmd.value); + std::abort(); } - ibv_send_wr* bad = nullptr; - int ret = ibv_post_send(qp, &wr[0], &bad); - if (ret) { - fprintf(stderr, "[RC] post_send(atomic imm) failed: %s (ret=%d)\n", - strerror(ret), ret); - if (bad) { - fprintf(stderr, " bad wr_id=0x%llx opcode=%u\n", - (unsigned long long)bad->wr_id, bad->opcode); - } + // If your AtomicsImm for non-EFA expects 16-bit offsets, keep the + // mask: + uint32_t off16 = static_cast(cmd.req_rptr) & 0xFFFFu; + int low_latency_buffer_idx = get_low_latency(cmd.cmd_type); + if (low_latency_buffer_idx < 0 || low_latency_buffer_idx > 1) { + fprintf(stderr, "Invalid low_latency_buffer_idx: %d\n", + low_latency_buffer_idx); std::abort(); } + uint32_t imm = AtomicsImm::Pack( + /*is_atomic*/ true, + /*is_combine*/ get_is_combine(cmd.cmd_type), v, + /*offset*/ off16, low_latency_buffer_idx) + .GetImmData(); + + // Zero-length write-with-imm on RC QP + sge[t].addr = reinterpret_cast(ctx->mr->addr); + sge[t].length = 0; + sge[t].lkey = ctx->mr->lkey; + + std::memset(&wr[t], 0, sizeof(wr[t])); + wr[t].wr_id = kAtomicWrTag | (wr_id & kAtomicMask); + wr[t].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; + wr[t].send_flags = (t + 1 == k) ? IBV_SEND_SIGNALED : 0; + wr[t].imm_data = htonl(imm); + wr[t].sg_list = &sge[t]; + wr[t].num_sge = 1; + wr[t].wr.rdma.remote_addr = ctx->remote_addr; + wr[t].wr.rdma.rkey = ctx->remote_rkey; + wr[t].next = (t + 1 < k) ? &wr[t + 1] : nullptr; + } + + ibv_send_wr* bad = nullptr; + int ret = ibv_post_send(qp, &wr[0], &bad); + if (ret) { + fprintf(stderr, "[RC] post_send(atomic imm) failed: %s (ret=%d)\n", + strerror(ret), ret); + if (bad) { + fprintf(stderr, " bad wr_id=0x%llx opcode=%u\n", + (unsigned long long)bad->wr_id, bad->opcode); + } + std::abort(); + } #endif } }
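
The notes below are review sketches of the techniques this patch relies on; they are illustrative standalone code, not part of the diff. First, the array-backed combine counter: the hot path indexes a fixed counters_[buffer][expert] array instead of hashing a tuple key, and anything outside the fixed bounds falls back to a map. A minimal sketch of that idea with illustrative names, including the compare-then-reset usage that the receiver side (apply_pending_updates) performs:

```cpp
#include <cstddef>
#include <cstdio>
#include <map>
#include <utility>

// Illustrative flat counter: O(1) array access for in-range keys,
// std::map fallback for anything outside the fixed bounds.
class FlatCounterSketch {
 public:
  static constexpr int kBuffers = 8;
  static constexpr int kExperts = 512;

  void Add(int b, int e, std::size_t k) {
    if (in_range(b, e)) counters_[b][e] += k;
    else fallback_[{b, e}] += k;
  }
  std::size_t Get(int b, int e) const {
    if (in_range(b, e)) return counters_[b][e];
    auto it = fallback_.find({b, e});
    return it == fallback_.end() ? 0 : it->second;
  }
  void Reset(int b, int e) {
    if (in_range(b, e)) counters_[b][e] = 0;
    else fallback_.erase({b, e});
  }

 private:
  static bool in_range(int b, int e) {
    return b >= 0 && b < kBuffers && e >= 0 && e < kExperts;
  }
  std::size_t counters_[kBuffers][kExperts] = {};
  std::map<std::pair<int, int>, std::size_t> fallback_;
};

int main() {
  FlatCounterSketch c;
  c.Add(/*buffer=*/0, /*expert=*/42, /*tokens=*/3);
  c.Add(0, 42, 5);
  // The receiver compares an arriving atomic value against the accumulated
  // token count before releasing the combine, then resets the counter.
  std::size_t expected = 8;
  if (c.Get(0, 42) == expected) c.Reset(0, 42);
  std::printf("count after reset: %zu\n", c.Get(0, 42));  // prints 0
}
```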
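Second, the send path now defers work requests per destination rank and flushes a batch only when it reaches kMaxBatchSize (64); the time-based flush was intentionally dropped, as the proxy.cpp comment notes. A self-contained sketch of that size-only policy, with a stand-in post_batch callback in place of the real ibv_post_send path (all names here are illustrative):

```cpp
#include <cstdint>
#include <cstdio>
#include <unordered_map>
#include <vector>

struct Cmd { int dst_rank; uint64_t wr_id; };

// Illustrative size-only batcher: commands accumulate per destination and are
// handed to post_batch() only once the batch reaches kMaxBatchSize. Nothing is
// flushed on a timer; callers that need a drain call flush_all() explicitly.
class SizeOnlyBatcher {
 public:
  static constexpr std::size_t kMaxBatchSize = 64;

  explicit SizeOnlyBatcher(void (*post_batch)(int, const std::vector<Cmd>&))
      : post_batch_(post_batch) {}

  void enqueue(const Cmd& cmd) {
    auto& pending = pending_[cmd.dst_rank];
    pending.push_back(cmd);
    if (pending.size() >= kMaxBatchSize) flush(cmd.dst_rank);
  }
  void flush(int dst_rank) {
    auto& pending = pending_[dst_rank];
    if (pending.empty()) return;
    post_batch_(dst_rank, pending);
    pending.clear();  // keeps capacity; avoids reallocating on the hot path
  }
  void flush_all() {
    for (auto& [dst, pending] : pending_) {
      if (!pending.empty()) flush(dst);
    }
  }

 private:
  void (*post_batch_)(int, const std::vector<Cmd>&);
  std::unordered_map<int, std::vector<Cmd>> pending_;
};

int main() {
  auto post = [](int dst, const std::vector<Cmd>& cmds) {
    std::printf("posting %zu WRs to rank %d\n", cmds.size(), dst);
  };
  SizeOnlyBatcher batcher(+post);
  for (uint64_t i = 0; i < 130; ++i) batcher.enqueue({/*dst_rank=*/1, i});
  batcher.flush_all();  // drains the 2 leftover WRs
}
```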
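Third, flush_pending_batch_for_dst reuses scratch containers stored on ProxyCtx (reusable_ring_to_indices and friends): the inner vectors are cleared so their capacity survives, and the map itself is only wiped once it grows past a small bound. A short sketch of that reuse pattern under the same assumptions (names are illustrative):

```cpp
#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <vector>

// Scratch map reused across calls: clearing the inner vectors keeps their
// capacity, so steady-state iterations allocate nothing. The map is only
// wiped once it accumulates more distinct keys than we expect to reuse.
struct RingScratch {
  std::unordered_map<std::size_t, std::vector<std::size_t>> ring_to_indices;

  void begin_iteration() {
    for (auto& [ring, indices] : ring_to_indices) indices.clear();
    if (ring_to_indices.size() > 16) ring_to_indices.clear();  // bound growth
  }
};

// Usage: group work-request slots by ring index without per-call allocation.
// The ring index is carried in the high 32 bits of the wr_id, as in the patch.
void group_by_ring(RingScratch& scratch,
                   const std::vector<uint64_t>& wr_ids) {
  scratch.begin_iteration();
  for (std::size_t i = 0; i < wr_ids.size(); ++i) {
    std::size_t ring_idx =
        static_cast<std::size_t>((wr_ids[i] >> 32) & 0xFFFFFFFFu);
    scratch.ring_to_indices[ring_idx].push_back(i);
  }
}
```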
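Fourth, the per-(dst_rank, index) sequence counters move from a lazily populated map into a fixed array of atomics indexed by seq_hash. One consequence worth noting in review: colliding pairs share a slot, which the 16K sizing is meant to make rare but does not rule out. A sketch assuming a small unsigned atomic element and a 16-entry reordering window (kReorderingBufferSize is assumed to be 16 here):

```cpp
#include <array>
#include <atomic>
#include <cstddef>
#include <cstdint>

// Illustrative hash-indexed sequence allocator. Each (dst_rank, index) pair
// maps to one of kSlots atomic counters; fetch_add gives a lock-free,
// monotonically increasing value that is wrapped to the reordering window.
constexpr std::size_t kSlots = 16384;
constexpr uint32_t kReorderWindow = 16;  // assumption: kReorderingBufferSize

std::array<std::atomic<uint32_t>, kSlots> g_seq{};

inline std::size_t seq_slot(int dst_rank, std::size_t index) {
  // Same shape as the patch's seq_hash: scale the rank, add the index, wrap.
  return (static_cast<std::size_t>(dst_rank) * 4096 + index) % kSlots;
}

inline uint8_t next_seq(int dst_rank, std::size_t index) {
  uint32_t raw =
      g_seq[seq_slot(dst_rank, index)].fetch_add(1, std::memory_order_relaxed);
  return static_cast<uint8_t>(raw % kReorderWindow);
}
```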
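Finally, the non-EFA path posts each ring group as one linked chain of RDMA writes where only the tail WR is signaled and carries the immediate, so the sender gets one CQE per batch and the receiver one IMM completion. A compile-level sketch of that chain construction against <infiniband/verbs.h>; QP and MR setup, bounds checks, and the project's imm encoding are elided, and all parameters are illustrative:

```cpp
#include <arpa/inet.h>
#include <infiniband/verbs.h>
#include <cstring>
#include <vector>

// Build and post a chain of RDMA WRITEs in one ibv_post_send() call.
// Only the last WR requests a completion and carries the immediate.
int post_write_chain(ibv_qp* qp, ibv_mr* mr, uint32_t rkey,
                     const std::vector<uint64_t>& local_offsets,
                     const std::vector<uint64_t>& remote_addrs,
                     const std::vector<uint32_t>& lengths,
                     uint32_t tail_imm_host_order) {
  std::size_t n = local_offsets.size();
  if (n == 0) return 0;  // nothing to post
  std::vector<ibv_sge> sges(n);
  std::vector<ibv_send_wr> wrs(n);
  for (std::size_t j = 0; j < n; ++j) {
    sges[j].addr = reinterpret_cast<uintptr_t>(mr->addr) + local_offsets[j];
    sges[j].length = lengths[j];
    sges[j].lkey = mr->lkey;

    std::memset(&wrs[j], 0, sizeof(wrs[j]));
    wrs[j].wr_id = j;
    wrs[j].sg_list = &sges[j];
    wrs[j].num_sge = 1;
    wrs[j].wr.rdma.remote_addr = remote_addrs[j];
    wrs[j].wr.rdma.rkey = rkey;
    wrs[j].opcode = IBV_WR_RDMA_WRITE;
    wrs[j].send_flags = 0;
    wrs[j].next = (j + 1 < n) ? &wrs[j + 1] : nullptr;
  }
  // Tail-only completion + immediate: one CQE per chain on the sender,
  // one IMM on the receiver describing the whole batch.
  wrs[n - 1].send_flags = IBV_SEND_SIGNALED;
  wrs[n - 1].opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
  wrs[n - 1].imm_data = htonl(tail_imm_host_order);

  ibv_send_wr* bad = nullptr;
  return ibv_post_send(qp, &wrs[0], &bad);  // nonzero return means failure
}
```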