diff --git a/ep/src/rdma.cpp b/ep/src/rdma.cpp
index 07c7453ea..462156d06 100644
--- a/ep/src/rdma.cpp
+++ b/ep/src/rdma.cpp
@@ -377,7 +377,6 @@ void create_per_thread_qp(ProxyCtx& S, void* gpu_buffer, size_t size,
   }
 #endif
 
-#ifdef USE_NORMAL_MODE
   const size_t rings_to_create = std::min(num_rings, (size_t)kChannelPerProxy);
   S.data_qps_by_channel.resize(rings_to_create);
   for (size_t r = 0; r < rings_to_create; ++r) {
@@ -400,7 +399,6 @@ void create_per_thread_qp(ProxyCtx& S, void* gpu_buffer, size_t size,
   for (uint32_t r = local_info->num_rings; r < kChannelPerProxy; ++r) {
     local_info->data_qp_num[r] = 0;
   }
-#endif
 
   // Query port
   struct ibv_port_attr port_attr;
@@ -974,201 +972,245 @@ void post_rdma_async_batched(ProxyCtx& S, void* buf, size_t num_wrs,
     std::abort();
   }
   size_t const k = wr_ids.size();
+
+  std::unordered_map<size_t, std::vector<size_t>> ring_to_indices;
+  ring_to_indices.reserve(k);
+  for (size_t ii = 0; ii < k; ++ii) {
+    size_t global_i = wr_ids[ii];
+    size_t ring_idx =
+        static_cast<size_t>((wrs_to_post[global_i] >> 32) & 0xFFFFFFFFu);
+    ring_to_indices[ring_idx].push_back(global_i);
+  }
+
+  for (auto& [ring_idx_raw, idxs] : ring_to_indices) {
+    size_t const local_ring_count = ctx->data_qps_by_channel.size();
 #ifdef EFA
-  struct ibv_qp_ex* qpx = (struct ibv_qp_ex*)ctx->qp;
-  ibv_wr_start(qpx);
+    struct ibv_qp_ex* qpx =
+        (struct ibv_qp_ex*)(local_ring_count
+                                ? ctx->data_qps_by_channel[ring_idx_raw %
+                                                           local_ring_count]
+                                : ctx->qp);
+
+    size_t const remote_ring_count = ctx->dst_data_qpn_by_ring.size();
+    uint32_t const dst_qpn =
+        remote_ring_count
+            ? ctx->dst_data_qpn_by_ring[ring_idx_raw % remote_ring_count]
+            : ctx->dst_qpn;
+
+    ibv_wr_start(qpx);
 #ifdef USE_RECEIVER_BARRIER
-  std::unordered_map<int, std::vector<uint64_t>> dst_expert_wr_ids;
-  for (size_t j = 0; j < k; ++j) {
-    size_t i = wr_ids[j];
-    int expert_idx = cmds_to_post[i].expert_idx;
-    dst_expert_wr_ids[expert_idx].push_back(i);
-  }
+    std::unordered_map<int, std::vector<uint64_t>> dst_expert_wr_ids;
+    for (size_t j = 0; j < idxs.size(); ++j) {
+      size_t i = idxs[j];
+      int expert_idx = cmds_to_post[i].expert_idx;
+      dst_expert_wr_ids[expert_idx].push_back(i);
+    }
 #endif
 #ifdef USE_RECEIVER_BARRIER
-  for (auto& [expert_idx, expert_wr_ids] : dst_expert_wr_ids) {
-    size_t expert_k = expert_wr_ids.size();
-    for (size_t j = 0; j < expert_k; ++j) {
-      size_t i = expert_wr_ids[j];
+    for (auto& [expert_idx, expert_wr_ids] : dst_expert_wr_ids) {
+      size_t expert_k = expert_wr_ids.size();
+      for (size_t j = 0; j < expert_k; ++j) {
+        size_t i = expert_wr_ids[j];
 #else
-  for (size_t j = 0; j < k; ++j) {
-    size_t i = wr_ids[j];
+    for (size_t j = 0; j < idxs.size(); ++j) {
+      size_t i = idxs[j];
 #endif
-    auto const& cmd = cmds_to_post[i];
+      auto const& cmd = cmds_to_post[i];
 #ifdef USE_RECEIVER_BARRIER
-      expert_wr_ids[j] = wrs_to_post[i];
+        expert_wr_ids[j] = wrs_to_post[i];
 #else
-    wr_ids[j] = wrs_to_post[i];
+      idxs[j] = wrs_to_post[i];
 #endif
-    qpx->wr_id = wrs_to_post[i];
-    qpx->comp_mask = 0;
-    qpx->wr_flags = IBV_SEND_SIGNALED;
+      qpx->wr_id = wrs_to_post[i];
+      qpx->comp_mask = 0;
+      qpx->wr_flags = IBV_SEND_SIGNALED;
 
-    uint64_t remote_addr =
-        ctx->remote_addr + (cmd.req_rptr ? cmd.req_rptr : 0);
-    uint64_t remote_end = ctx->remote_addr + ctx->remote_len;
+      uint64_t remote_addr =
+          ctx->remote_addr + (cmd.req_rptr ? cmd.req_rptr : 0);
+      uint64_t remote_end = ctx->remote_addr + ctx->remote_len;
 
-    if (remote_addr < ctx->remote_addr ||
-        remote_addr + cmd.bytes > remote_end) {
-      fprintf(stderr,
-              "[ERROR] Remote write OOB: addr=0x%llx len=%u (base=0x%llx, "
-              "size=%zu), cmd.req_rptr: 0x%llx\n",
-              (unsigned long long)remote_addr, cmd.bytes,
-              (unsigned long long)ctx->remote_addr, (size_t)ctx->remote_len,
-              (unsigned long long)cmd.req_rptr);
-      cudaError_t err = cudaDeviceSynchronize();
-      if (err != cudaSuccess) {
-        fprintf(stderr, "cudaDeviceSynchronize failed: %s\n",
-                cudaGetErrorString(err));
+      if (remote_addr < ctx->remote_addr ||
+          remote_addr + cmd.bytes > remote_end) {
+        fprintf(
+            stderr,
+            "[ERROR] Remote write OOB: addr=0x%llx len=%u (base=0x%llx, "
+            "size=%zu), cmd.req_rptr: 0x%llx\n",
+            (unsigned long long)remote_addr, cmd.bytes,
+            (unsigned long long)ctx->remote_addr, (size_t)ctx->remote_len,
+            (unsigned long long)cmd.req_rptr);
+        cudaError_t err = cudaDeviceSynchronize();
+        if (err != cudaSuccess) {
+          fprintf(stderr, "cudaDeviceSynchronize failed: %s\n",
+                  cudaGetErrorString(err));
+          std::abort();
+        }
         std::abort();
       }
-      std::abort();
-    }
 #ifdef USE_SENDER_BARRIER
-    S.wr_id_to_write_struct[qpx->wr_id] = {cmd.expert_idx, dst_rank,
-                                           get_is_combine(cmd.cmd_type),
-                                           get_low_latency(cmd.cmd_type)};
+      S.wr_id_to_write_struct[qpx->wr_id] = {cmd.expert_idx, dst_rank,
+                                             get_is_combine(cmd.cmd_type),
+                                             get_low_latency(cmd.cmd_type)};
 #endif
 #ifdef USE_RECEIVER_BARRIER
-    uint32_t imm = WriteImm::Pack(get_is_combine(cmd.cmd_type),
-                                  get_low_latency(cmd.cmd_type),
-                                  cmd.expert_idx, 1, my_rank)
-                       .GetImmData();
-    ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, remote_addr, htonl(imm));
+      uint32_t imm = WriteImm::Pack(get_is_combine(cmd.cmd_type),
+                                    get_low_latency(cmd.cmd_type),
+                                    cmd.expert_idx, 1, my_rank)
+                         .GetImmData();
+      ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, remote_addr, htonl(imm));
 #else
-    if (cmd.atomic_offset > 0 && cmd.atomic_val > 0) {
-      int v = static_cast<int>(cmd.atomic_val);
-      if (v < -kMaxSendAtomicValue || v > kMaxSendAtomicValue) {
-        fprintf(stderr, "[EFA] atomic value=%d won't fit in 15 bits\n", v);
-        std::abort();
+      if (cmd.atomic_offset > 0 && cmd.atomic_val > 0) {
+        int v = static_cast<int>(cmd.atomic_val);
+        if (v < -kMaxSendAtomicValue || v > kMaxSendAtomicValue) {
+          fprintf(stderr, "[EFA] atomic value=%d won't fit in 15 bits\n", v);
+          std::abort();
+        }
+        uint32_t imm =
+            AtomicsImm::Pack(true, false, cmd.atomic_val, cmd.atomic_offset,
+                             get_low_latency(cmd.cmd_type))
+                .GetImmData();
+        ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, remote_addr, htonl(imm));
+      } else if (j + 1 == idxs.size()) {
+        uint32_t imm = WriteImm::Pack(get_is_combine(cmd.cmd_type),
+                                      get_low_latency(cmd.cmd_type),
+                                      cmd.expert_idx, idxs.size(), my_rank)
+                           .GetImmData();
+        ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, remote_addr, htonl(imm));
+      } else {
+        ibv_wr_rdma_write(qpx, ctx->remote_rkey, remote_addr);
       }
-      uint32_t imm =
-          AtomicsImm::Pack(true, false, cmd.atomic_val, cmd.atomic_offset,
-                           get_low_latency(cmd.cmd_type))
-              .GetImmData();
-      ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, remote_addr, htonl(imm));
-    } else if (j + 1 == k) {
-      uint32_t imm = WriteImm::Pack(get_is_combine(cmd.cmd_type),
-                                    get_low_latency(cmd.cmd_type),
-                                    cmd.expert_idx, k, my_rank)
-                         .GetImmData();
-      ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, remote_addr, htonl(imm));
-    } else {
-      ibv_wr_rdma_write(qpx, ctx->remote_rkey, remote_addr);
-    }
 #endif
-    uintptr_t laddr =
-        cmd.req_lptr + reinterpret_cast<uintptr_t>(ctx->mr->addr);
-    ibv_wr_set_ud_addr(qpx, ctx->dst_ah, ctx->dst_qpn, QKEY);
-    ibv_wr_set_sge(qpx, ctx->mr->lkey, laddr,
-                   static_cast<uint32_t>(cmd.bytes));
-  }
+      uintptr_t laddr =
+          cmd.req_lptr + reinterpret_cast<uintptr_t>(ctx->mr->addr);
+      ibv_wr_set_ud_addr(qpx, ctx->dst_ah, dst_qpn, QKEY);
+      ibv_wr_set_sge(qpx, ctx->mr->lkey, laddr,
+                     static_cast<uint32_t>(cmd.bytes));
+    }
 #ifdef USE_RECEIVER_BARRIER
-    uint64_t const expert_tail_wr = expert_wr_ids.back();
+      uint64_t const expert_tail_wr = expert_wr_ids.back();
+      {
+        auto [it, inserted] = S.wr_id_to_wr_ids.try_emplace(
+            expert_tail_wr, std::move(expert_wr_ids));
+        if (!inserted) {
+          fprintf(stderr,
+                  "thread_idx: %d, Error: tail wr_id %lu already exists "
+                  "(map=%p)\n",
+                  thread_idx, expert_tail_wr, (void*)&S.wr_id_to_wr_ids);
+          std::abort();
+        }
+      }
+    }
+#else
+    std::vector<uint64_t> ring_wrids;
+    ring_wrids.reserve(idxs.size());
+    for (size_t j = 0; j < idxs.size(); ++j) {
+      ring_wrids.push_back(idxs[j]);
+    }
+    uint64_t const tail_wr = ring_wrids.back();
     {
-      auto [it, inserted] = S.wr_id_to_wr_ids.try_emplace(
-          expert_tail_wr, std::move(expert_wr_ids));
+      auto [it, inserted] =
+          S.wr_id_to_wr_ids.try_emplace(tail_wr, std::move(ring_wrids));
       if (!inserted) {
         fprintf(stderr,
                 "thread_idx: %d, Error: tail wr_id %lu already exists "
                 "(map=%p)\n",
-                thread_idx, expert_tail_wr, (void*)&S.wr_id_to_wr_ids);
+                thread_idx, tail_wr, (void*)&S.wr_id_to_wr_ids);
         std::abort();
       }
     }
-  }
-#else
-  uint64_t const tail_wr = wr_ids.back();
-  {
-    auto [it, inserted] =
-        S.wr_id_to_wr_ids.try_emplace(tail_wr, std::move(wr_ids));
-    if (!inserted) {
-      fprintf(stderr,
-              "thread_idx: %d, Error: tail wr_id %lu already exists "
-              "(map=%p)\n",
-              thread_idx, tail_wr, (void*)&S.wr_id_to_wr_ids);
-      std::abort();
-    }
-  }
 #endif
-  int ret = ibv_wr_complete(qpx);
-  if (ret) {
-    fprintf(stderr, "ibv_wr_complete failed (dst=%d): %s (ret=%d)\n",
-            dst_rank, strerror(ret), ret);
-    std::abort();
-  }
+    int ret = ibv_wr_complete(qpx);
+    if (ret) {
+      fprintf(stderr, "ibv_wr_complete failed (dst=%d): %s (ret=%d)\n",
+              dst_rank, strerror(ret), ret);
+      std::abort();
+    }
 #else
-  std::vector<ibv_sge> sges(k);
-  std::vector<ibv_send_wr> wrs(k);
-  for (size_t j = 0; j < k; ++j) {
-    size_t i = wr_ids[j];
-    auto const& cmd = cmds_to_post[i];
-    wr_ids[j] = wrs_to_post[i];
-    sges[j].addr =
-        cmd.req_lptr + reinterpret_cast<uintptr_t>(ctx->mr->addr);
-    sges[j].length = static_cast<uint32_t>(cmd.bytes);
-    sges[j].lkey = ctx->mr->lkey;
-    std::memset(&wrs[j], 0, sizeof(wrs[j]));
-    wrs[j].sg_list = &sges[j];
-    wrs[j].num_sge = 1;
-    wrs[j].wr_id = wr_ids[j];
-
-    wrs[j].wr.rdma.remote_addr = ctx->remote_addr + cmd.req_rptr;
-
-    uint64_t remote_end = ctx->remote_addr + ctx->remote_len;
-    if (wrs[j].wr.rdma.remote_addr < ctx->remote_addr ||
-        wrs[j].wr.rdma.remote_addr + cmd.bytes > remote_end) {
-      fprintf(stderr,
+    {
+      size_t const local_ring_count = ctx->data_qps_by_channel.size();
+      struct ibv_qp* qp =
+          local_ring_count
+              ? ctx->data_qps_by_channel[ring_idx_raw % local_ring_count]
+              : ctx->qp;
+
+      size_t const kgroup = idxs.size();
+      std::vector<ibv_sge> sges(kgroup);
+      std::vector<ibv_send_wr> wrs(kgroup);
+      std::vector<uint64_t> ring_wrids;
+      ring_wrids.reserve(kgroup);
+
+      for (size_t j = 0; j < kgroup; ++j) {
+        size_t i = idxs[j];
+        auto const& cmd = cmds_to_post[i];
+        ring_wrids.push_back(wrs_to_post[i]);
+        sges[j].addr =
+            cmd.req_lptr + reinterpret_cast<uintptr_t>(ctx->mr->addr);
+        sges[j].length = static_cast<uint32_t>(cmd.bytes);
+        sges[j].lkey = ctx->mr->lkey;
+        std::memset(&wrs[j], 0, sizeof(wrs[j]));
+        wrs[j].sg_list = &sges[j];
+        wrs[j].num_sge = 1;
+        wrs[j].wr_id = ring_wrids[j];
+
+        wrs[j].wr.rdma.remote_addr = ctx->remote_addr + cmd.req_rptr;
+
+        uint64_t remote_end = ctx->remote_addr + ctx->remote_len;
+        if (wrs[j].wr.rdma.remote_addr < ctx->remote_addr ||
+            wrs[j].wr.rdma.remote_addr + cmd.bytes > remote_end) {
+          fprintf(
+              stderr,
              "[ERROR] Remote write OOB: addr=0x%llx len=%u (base=0x%llx, "
              "size=%zu), cmd.req_rptr: 0x%llx\n",
              (unsigned long long)wrs[j].wr.rdma.remote_addr, cmd.bytes,
              (unsigned long long)ctx->remote_addr, (size_t)ctx->remote_len,
              (unsigned long long)cmd.req_rptr);
-      cudaError_t err = cudaDeviceSynchronize();
-      if (err != cudaSuccess) {
-        fprintf(stderr, "cudaDeviceSynchronize failed: %s\n",
-                cudaGetErrorString(err));
+          cudaError_t err = cudaDeviceSynchronize();
+          if (err != cudaSuccess) {
+            fprintf(stderr, "cudaDeviceSynchronize failed: %s\n",
+                    cudaGetErrorString(err));
+            std::abort();
+          }
+          std::abort();
+        }
+
+        wrs[j].wr.rdma.rkey = ctx->remote_rkey;
+        wrs[j].opcode = IBV_WR_RDMA_WRITE;
+        wrs[j].send_flags = 0;
+        wrs[j].next = (j + 1 < kgroup) ? &wrs[j + 1] : nullptr;
+      }
+      size_t const last = kgroup - 1;
+      uint64_t const batch_tail_wr = ring_wrids[last];
+      wrs[last].send_flags |= IBV_SEND_SIGNALED;
+      wrs[last].opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
+      wrs[last].imm_data = htonl(static_cast<uint32_t>(batch_tail_wr));
+      ibv_send_wr* bad = nullptr;
+      int ret = ibv_post_send(qp, &wrs[0], &bad);
+      if (ret) {
+        fprintf(stderr, "ibv_post_send failed (dst=%d): %s (ret=%d)\n",
+                dst_rank, strerror(ret), ret);
+        if (bad)
+          fprintf(stderr, "Bad WR at %p (wr_id=%lu)\n", (void*)bad,
+                  bad->wr_id);
        std::abort();
      }
-      std::abort();
-    }
-
-    wrs[j].wr.rdma.rkey = ctx->remote_rkey;
-    wrs[j].opcode = IBV_WR_RDMA_WRITE;
-    wrs[j].send_flags = 0;
-    wrs[j].next = (j + 1 < k) ? &wrs[j + 1] : nullptr;
-  }
-  size_t const last = k - 1;
-  uint64_t const batch_tail_wr = wr_ids[last];
-  wrs[last].send_flags |= IBV_SEND_SIGNALED;
-  wrs[last].opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
-  wrs[last].imm_data = htonl(static_cast<uint32_t>(batch_tail_wr));
-  ibv_send_wr* bad = nullptr;
-  int ret = ibv_post_send(ctx->qp, &wrs[0], &bad);
-  if (ret) {
-    fprintf(stderr, "ibv_post_send failed (dst=%d): %s (ret=%d)\n",
-            dst_rank, strerror(ret), ret);
-    if (bad)
-      fprintf(stderr, "Bad WR at %p (wr_id=%lu)\n", (void*)bad, bad->wr_id);
-    std::abort();
-  }
-  {
-    auto [it, inserted] =
-        S.wr_id_to_wr_ids.try_emplace(batch_tail_wr, std::move(wr_ids));
-    if (!inserted) {
-      fprintf(stderr,
-              "thread_idx: %d, Error: tail wr_id %lu already exists "
-              "(map=%p)\n",
-              thread_idx, batch_tail_wr, (void*)&S.wr_id_to_wr_ids);
-      std::abort();
+      {
+        auto [it, inserted] = S.wr_id_to_wr_ids.try_emplace(
+            batch_tail_wr, std::move(ring_wrids));
+        if (!inserted) {
+          fprintf(stderr,
+                  "thread_idx: %d, Error: tail wr_id %lu already exists "
+                  "(map=%p)\n",
+                  thread_idx, batch_tail_wr, (void*)&S.wr_id_to_wr_ids);
+          std::abort();
+        }
+      }
     }
-  }
 #endif
+  }
 }
 }
 #endif
@@ -1993,110 +2035,139 @@ void post_atomic_operations(ProxyCtx& S,
   ProxyCtx* ctx = ctxs[dst_rank].get();
   size_t const k = wr_ids.size();
+
+  std::unordered_map<size_t, std::vector<size_t>> ring_to_indices;
+  ring_to_indices.reserve(k);
+  for (size_t ii = 0; ii < k; ++ii) {
+    size_t global_i = wr_ids[ii];
+    size_t ring_idx =
+        static_cast<size_t>((wrs_to_post[global_i] >> 32) & 0xFFFFFFFFu);
+    ring_to_indices[ring_idx].push_back(global_i);
+  }
+
+  for (auto& [ring_idx_raw, idxs] : ring_to_indices) {
+    size_t const local_ring_count = ctx->data_qps_by_channel.size();
 #ifdef EFA
-  struct ibv_qp_ex* qpx = (struct ibv_qp_ex*)ctx->qp;
-  ibv_wr_start(qpx);
-  for (size_t i = 0; i < k; ++i) {
-    auto const& cmd = cmds_to_post[wr_ids[i]];
-    auto wr_id = wrs_to_post[wr_ids[i]];
-    wr_ids[i] = wr_id;
-
-    int v = static_cast<int>(cmd.value);
-    if (v == kLargeAtomicValue) v = kMaxSendAtomicValue;
-    if (v < -kMaxSendAtomicValue || v > kMaxSendAtomicValue) {
-      fprintf(stderr,
-              "[EFA] value=%d (cmd.value: %lu) won't fit in 15 bits; "
-              "use an inline payload scheme instead.\n",
-              v, (unsigned long)cmd.value);
-      std::abort();
+    struct ibv_qp_ex* qpx =
+        (struct ibv_qp_ex*)(local_ring_count
+                                ? ctx->data_qps_by_channel[ring_idx_raw %
+                                                           local_ring_count]
+                                : ctx->qp);
+
+    size_t const remote_ring_count = ctx->dst_data_qpn_by_ring.size();
+    uint32_t const dst_qpn =
+        remote_ring_count
+            ? ctx->dst_data_qpn_by_ring[ring_idx_raw % remote_ring_count]
+            : ctx->dst_qpn;
+
+    ibv_wr_start(qpx);
+
+    std::vector<uint64_t> ring_wrids;
+    ring_wrids.reserve(idxs.size());
+
+    for (size_t j = 0; j < idxs.size(); ++j) {
+      size_t i = idxs[j];
+      auto const& cmd = cmds_to_post[i];
+      auto wr_id = wrs_to_post[i];
+      ring_wrids.push_back(wr_id);
+
+      int v = static_cast<int>(cmd.value);
+      if (v == kLargeAtomicValue) v = kMaxSendAtomicValue;
+      if (v < -kMaxSendAtomicValue || v > kMaxSendAtomicValue) {
+        fprintf(stderr,
+                "[EFA] value=%d (cmd.value: %lu) won't fit in 15 bits; "
+                "use an inline payload scheme instead.\n",
+                v, (unsigned long)cmd.value);
+        std::abort();
+      }
+      uint32_t offset = static_cast<uint32_t>(cmd.req_rptr);
+      int low_latency_buffer_idx = get_low_latency(cmd.cmd_type);
+      if (low_latency_buffer_idx < 0 || low_latency_buffer_idx > 1) {
+        fprintf(stderr, "Invalid low_latency_buffer_idx: %d\n",
+                low_latency_buffer_idx);
+        std::abort();
+      }
+      uint32_t imm = AtomicsImm::Pack(true, get_is_combine(cmd.cmd_type), v,
+                                      offset, low_latency_buffer_idx)
+                         .GetImmData();
+
+      qpx->wr_id = kAtomicWrTag | (wr_id & kAtomicMask);
+      qpx->comp_mask = 0;
+      qpx->wr_flags = IBV_SEND_SIGNALED;
+      ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, ctx->remote_addr,
+                            htonl(imm));
+
+      ibv_wr_set_ud_addr(qpx, ctx->dst_ah, dst_qpn, QKEY);
+      ibv_wr_set_sge(qpx, ctx->mr->lkey, (uintptr_t)ctx->mr->addr, 0);
    }
-    uint32_t offset = static_cast<uint32_t>(cmd.req_rptr);
-    int low_latency_buffer_idx = get_low_latency(cmd.cmd_type);
-    if (low_latency_buffer_idx < 0 || low_latency_buffer_idx > 1) {
-      fprintf(stderr, "Invalid low_latency_buffer_idx: %d\n",
-              low_latency_buffer_idx);
+    int ret = ibv_wr_complete(qpx);
+    if (ret) {
+      fprintf(stderr, "[EFA] post_send failed: %s (ret=%d)\n", strerror(ret),
+              ret);
      std::abort();
    }
-    uint32_t imm = AtomicsImm::Pack(true, get_is_combine(cmd.cmd_type), v,
-                                    offset, low_latency_buffer_idx)
-                       .GetImmData();
-
-    qpx->wr_id = kAtomicWrTag | (wr_id & kAtomicMask);
-    qpx->comp_mask = 0;
-    qpx->wr_flags = IBV_SEND_SIGNALED;
-    ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, ctx->remote_addr,
-                          htonl(imm));
-
-    ibv_wr_set_ud_addr(qpx, ctx->dst_ah, ctx->dst_qpn, QKEY);
-    ibv_wr_set_sge(qpx, ctx->mr->lkey, (uintptr_t)ctx->mr->addr, 0);
-  }
-  int ret = ibv_wr_complete(qpx);
-  if (ret) {
-    fprintf(stderr, "[EFA] post_send failed: %s (ret=%d)\n", strerror(ret),
-            ret);
-    std::abort();
-  }
 #else
-  std::vector<ibv_sge> sge(k);
-  std::vector<ibv_send_wr> wr(k);
-
-  for (size_t i = 0; i < k; ++i) {
-    auto const& cmd = cmds_to_post[wr_ids[i]];
-    uint64_t const wrid = wrs_to_post[wr_ids[i]];
-    wr_ids[i] = wrid;
-
-    int v = static_cast<int>(cmd.value);
-    if (v == kLargeAtomicValue) v = kMaxSendAtomicValue;
-    if (v < -kMaxSendAtomicValue || v > kMaxSendAtomicValue) {
-      fprintf(stderr, "value=%d won't fit in 15 bits\n", v);
-      std::abort();
-    }
-    uint32_t const off16 = static_cast<uint32_t>(cmd.req_rptr) & 0xFFFFu;
-    int low_latency_buffer_idx = get_low_latency(cmd.cmd_type);
-    uint32_t const imm = AtomicsImm::Pack(true, get_is_combine(cmd.cmd_type),
-                                          v, off16, low_latency_buffer_idx)
-                             .GetImmData();
-    sge[i].addr = reinterpret_cast<uintptr_t>(ctx->mr->addr);
-    sge[i].length = 0;
-    sge[i].lkey = ctx->mr->lkey;
-
-    std::memset(&wr[i], 0, sizeof(wr[i]));
-    wr[i].wr_id = kAtomicWrTag | (wrid & kAtomicMask);
-    wr[i].opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
-    wr[i].send_flags = (i + 1 == k) ? IBV_SEND_SIGNALED : 0;
-    wr[i].imm_data = htonl(imm);
-    wr[i].sg_list = &sge[i];
-    wr[i].num_sge = 1;
-    wr[i].wr.rdma.remote_addr = ctx->remote_addr;
-    wr[i].wr.rdma.rkey = ctx->remote_rkey;
-    wr[i].next = (i + 1 < k) ? &wr[i + 1] : nullptr;
-  }
-  {
-    ibv_send_wr* bad = nullptr;
-    int ret = ibv_post_send(ctx->qp, &wr[0], &bad);
-    if (ret) {
-      fprintf(stderr, "ibv_post_send(atomic) failed: %d (%s)\n", ret,
-              strerror(ret));
-      if (bad)
-        fprintf(stderr, "  bad wr_id=0x%llx\n",
-                (unsigned long long)bad->wr_id);
-      std::abort();
+    {
+      size_t const local_ring_count = ctx->data_qps_by_channel.size();
+      struct ibv_qp* qp =
+          local_ring_count
+              ? ctx->data_qps_by_channel[ring_idx_raw % local_ring_count]
+              : ctx->qp;
+
+      size_t const kgroup = idxs.size();
+      std::vector<ibv_sge> sge(kgroup);
+      std::vector<ibv_send_wr> wr(kgroup);
+
+      std::vector<uint64_t> ring_wrids;
+      ring_wrids.reserve(kgroup);
+
+      for (size_t j = 0; j < kgroup; ++j) {
+        size_t i = idxs[j];
+        auto const& cmd = cmds_to_post[i];
+        uint64_t const wrid = wrs_to_post[i];
+        ring_wrids.push_back(wrid);
+
+        int v = static_cast<int>(cmd.value);
+        if (v == kLargeAtomicValue) v = kMaxSendAtomicValue;
+        if (v < -kMaxSendAtomicValue || v > kMaxSendAtomicValue) {
+          fprintf(stderr, "value=%d won't fit in 15 bits\n", v);
+          std::abort();
+        }
+        uint32_t const off16 = static_cast<uint32_t>(cmd.req_rptr) & 0xFFFFu;
+        int low_latency_buffer_idx = get_low_latency(cmd.cmd_type);
+        uint32_t const imm =
+            AtomicsImm::Pack(true, get_is_combine(cmd.cmd_type), v, off16,
+                             low_latency_buffer_idx)
+                .GetImmData();
+        sge[j].addr = reinterpret_cast<uintptr_t>(ctx->mr->addr);
+        sge[j].length = 0;
+        sge[j].lkey = ctx->mr->lkey;
+
+        std::memset(&wr[j], 0, sizeof(wr[j]));
+        wr[j].wr_id = kAtomicWrTag | (wrid & kAtomicMask);
+        wr[j].opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
+        wr[j].send_flags = (j + 1 == kgroup) ? IBV_SEND_SIGNALED : 0;
+        wr[j].imm_data = htonl(imm);
+        wr[j].sg_list = &sge[j];
+        wr[j].num_sge = 1;
+        wr[j].wr.rdma.remote_addr = ctx->remote_addr;
+        wr[j].wr.rdma.rkey = ctx->remote_rkey;
+        wr[j].next = (j + 1 < kgroup) ? &wr[j + 1] : nullptr;
+      }
+      {
+        ibv_send_wr* bad = nullptr;
+        int ret = ibv_post_send(qp, &wr[0], &bad);
+        if (ret) {
+          fprintf(stderr, "ibv_post_send(atomic) failed: %d (%s)\n", ret,
+                  strerror(ret));
+          if (bad)
+            fprintf(stderr, "  bad wr_id=0x%llx\n",
+                    (unsigned long long)bad->wr_id);
+          std::abort();
+        }
+      }
    }
-  }
 #endif
-  uint64_t const batch_tail_wr = wr_ids.back();
-  {
-    auto [it, inserted] =
-        S.wr_id_to_wr_ids.try_emplace(batch_tail_wr, std::move(wr_ids));
-    if (!inserted) {
-      fprintf(stderr,
-              "thread_idx: %d, Error: tail wr_id %lu already exists "
-              "(map=%p, "
-              "size=%zu, dst_rank=%d)\n",
-              thread_idx, batch_tail_wr, (void*)&S.wr_id_to_wr_ids,
-              S.wr_id_to_wr_ids.size(), dst_rank);
-      std::abort();
-    }
   }
 }
 }