diff --git a/ep/include/proxy_ctx.hpp b/ep/include/proxy_ctx.hpp index f87f70ce0..744cd4662 100644 --- a/ep/include/proxy_ctx.hpp +++ b/ep/include/proxy_ctx.hpp @@ -53,6 +53,14 @@ struct ProxyCtx { uint32_t dst_ack_qpn; struct ibv_ah* dst_ah = nullptr; + // Connectionless SRD support: multiple AHs and QPNs for different remote NICs + std::vector dst_ah_per_nic; + std::vector dst_qpn_per_nic; + std::vector dst_ack_qpn_per_nic; + std::vector remote_addr_per_nic; + std::vector remote_rkey_per_nic; + std::vector remote_len_per_nic; + // Remote memory uintptr_t remote_addr = 0; // Base address of remote rdma_buffer uint64_t remote_len = 0; diff --git a/ep/include/rdma.hpp b/ep/include/rdma.hpp index 92176e1d6..acb24288c 100644 --- a/ep/include/rdma.hpp +++ b/ep/include/rdma.hpp @@ -20,7 +20,7 @@ struct RDMAConnectionInfo { uint32_t ack_qp_num; uint32_t recv_ack_qp_num; uint32_t ack_psn; - uint32_t rkey; // Memory region key + uint32_t rkey; // Memory region key uintptr_t addr; // Buffer address uint64_t len; uint16_t lid; // Local ID @@ -29,7 +29,11 @@ struct RDMAConnectionInfo { // #ifdef EFA uint32_t num_rings; uint32_t data_qp_num[kChannelPerProxy]; - // #endif + + uint32_t num_nics; + uint8_t gid_per_nic[MAX_NUM_GPUS][16]; + uint32_t qp_num_per_nic[MAX_NUM_GPUS]; + uint32_t ack_qp_num_per_nic[MAX_NUM_GPUS]; }; struct PendingUpdate { @@ -301,6 +305,9 @@ void modify_qp_to_rtr(ProxyCtx& S, RDMAConnectionInfo* remote, void modify_qp_to_rts(ProxyCtx& S, RDMAConnectionInfo* local_info); void modify_qp_to_init(ProxyCtx& S); + +struct ibv_ah* create_ah(ProxyCtx& S, uint8_t* remote_gid); + void local_poll_completions(ProxyCtx& S, std::unordered_set& acked_wrs, int thread_idx, std::vector& ctx_by_tag); diff --git a/ep/src/proxy.cpp b/ep/src/proxy.cpp index 405dc73ac..fe3f684a8 100644 --- a/ep/src/proxy.cpp +++ b/ep/src/proxy.cpp @@ -279,6 +279,7 @@ void Proxy::init_common() { } } usleep(50 * 1000); + if (cfg_.use_normal_mode) { // if (cfg_.thread_idx != 0) { // return; diff 
--git a/ep/src/rdma.cpp b/ep/src/rdma.cpp index 5d040b139..bc82f632a 100644 --- a/ep/src/rdma.cpp +++ b/ep/src/rdma.cpp @@ -420,6 +420,12 @@ void create_per_thread_qp(ProxyCtx& S, void* gpu_buffer, size_t size, local_info->psn = 0; local_info->ack_psn = 0; fill_local_gid(S, local_info); + + local_info->num_nics = 0; + memset(local_info->gid_per_nic, 0, sizeof(local_info->gid_per_nic)); + memset(local_info->qp_num_per_nic, 0, sizeof(local_info->qp_num_per_nic)); + memset(local_info->ack_qp_num_per_nic, 0, + sizeof(local_info->ack_qp_num_per_nic)); } void modify_qp_to_init(ProxyCtx& S) { @@ -495,6 +501,18 @@ void modify_qp_to_rtr(ProxyCtx& S, RDMAConnectionInfo* remote, S.dst_qpn = remote->qp_num; S.dst_ack_qpn = remote->recv_ack_qp_num; S.dst_ah = create_ah(S, remote->gid); + + if (!use_normal_mode && remote->num_nics > 0) { + S.dst_ah_per_nic.resize(remote->num_nics); + S.dst_qpn_per_nic.resize(remote->num_nics); + S.dst_ack_qpn_per_nic.resize(remote->num_nics); + + for (uint32_t nic_idx = 0; nic_idx < remote->num_nics; ++nic_idx) { + S.dst_ah_per_nic[nic_idx] = create_ah(S, remote->gid_per_nic[nic_idx]); + S.dst_qpn_per_nic[nic_idx] = remote->qp_num_per_nic[nic_idx]; + S.dst_ack_qpn_per_nic[nic_idx] = remote->ack_qp_num_per_nic[nic_idx]; + } + } #endif if (use_normal_mode) { @@ -1022,17 +1040,33 @@ static void post_rdma_async_batched_fast_mode( qpx->comp_mask = 0; qpx->wr_flags = IBV_SEND_SIGNALED; + struct ibv_ah* selected_ah = ctx->dst_ah; + uint32_t selected_qpn = ctx->dst_qpn; + uintptr_t selected_remote_addr = ctx->remote_addr; + uint32_t selected_remote_rkey = ctx->remote_rkey; + uint64_t selected_remote_len = ctx->remote_len; + + if (!ctx->dst_ah_per_nic.empty()) { + size_t nic_idx = wrs_to_post[i] % ctx->dst_ah_per_nic.size(); + selected_ah = ctx->dst_ah_per_nic[nic_idx]; + selected_qpn = ctx->dst_qpn_per_nic[nic_idx]; + selected_remote_addr = ctx->remote_addr_per_nic[nic_idx]; + selected_remote_rkey = ctx->remote_rkey_per_nic[nic_idx]; + 
selected_remote_len = ctx->remote_len_per_nic[nic_idx]; + } + uint64_t remote_addr = - ctx->remote_addr + (cmd.req_rptr ? cmd.req_rptr : 0); - uint64_t remote_end = ctx->remote_addr + ctx->remote_len; + selected_remote_addr + (cmd.req_rptr ? cmd.req_rptr : 0); + uint64_t remote_end = selected_remote_addr + selected_remote_len; - if (remote_addr < ctx->remote_addr || + if (remote_addr < selected_remote_addr || remote_addr + cmd.bytes > remote_end) { fprintf(stderr, "[ERROR] Remote write OOB: addr=0x%llx len=%u (base=0x%llx, " "size=%zu), cmd.req_rptr: 0x%llx\n", (unsigned long long)remote_addr, cmd.bytes, - (unsigned long long)ctx->remote_addr, (size_t)ctx->remote_len, + (unsigned long long)selected_remote_addr, + (size_t)selected_remote_len, (unsigned long long)cmd.req_rptr); cudaError_t err = cudaDeviceSynchronize(); if (err != cudaSuccess) { @@ -1052,7 +1086,8 @@ static void post_rdma_async_batched_fast_mode( get_low_latency(cmd.cmd_type), cmd.expert_idx, 1, my_rank) .GetImmData(); - ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, remote_addr, htonl(imm)); + ibv_wr_rdma_write_imm(qpx, selected_remote_rkey, remote_addr, + htonl(imm)); #else if (cmd.atomic_offset > 0 && cmd.atomic_val > 0) { int v = static_cast(cmd.atomic_val); @@ -1064,20 +1099,23 @@ static void post_rdma_async_batched_fast_mode( AtomicsImm::Pack(true, false, cmd.atomic_val, cmd.atomic_offset, get_low_latency(cmd.cmd_type)) .GetImmData(); - ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, remote_addr, htonl(imm)); + ibv_wr_rdma_write_imm(qpx, selected_remote_rkey, remote_addr, + htonl(imm)); } else if (j + 1 == k) { uint32_t imm = WriteImm::Pack(get_is_combine(cmd.cmd_type), get_low_latency(cmd.cmd_type), cmd.expert_idx, k, my_rank) .GetImmData(); - ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, remote_addr, htonl(imm)); + ibv_wr_rdma_write_imm(qpx, selected_remote_rkey, remote_addr, + htonl(imm)); } else { - ibv_wr_rdma_write(qpx, ctx->remote_rkey, remote_addr); + ibv_wr_rdma_write(qpx, 
selected_remote_rkey, remote_addr); } #endif uintptr_t laddr = cmd.req_lptr + reinterpret_cast(ctx->mr->addr); - ibv_wr_set_ud_addr(qpx, ctx->dst_ah, ctx->dst_qpn, QKEY); + + ibv_wr_set_ud_addr(qpx, selected_ah, selected_qpn, QKEY); ibv_wr_set_sge(qpx, ctx->mr->lkey, laddr, static_cast(cmd.bytes)); }