diff --git a/.gitignore b/.gitignore index c24ecf89b..4a8c877da 100644 --- a/.gitignore +++ b/.gitignore @@ -100,4 +100,6 @@ ep/figs ep/deep_ep_wrapper/deep_ep.egg-info/ *.json -*result.jsonl \ No newline at end of file +*result.jsonl + +ep/deep_ep_wrapper/sglang_profiles* \ No newline at end of file diff --git a/ep/Makefile b/ep/Makefile index 93aa112a6..8853e45a6 100644 --- a/ep/Makefile +++ b/ep/Makefile @@ -60,8 +60,8 @@ NVCCFLAGS += $(EFA_CFLAGS) $(GH_CFLAGS) $(NORMAL_CFLAGS) LDFLAGS += $(EFA_LDFLAGS) INCLUDES += $(EFA_CFLAGS) $(GH_CFLAGS) $(NORMAL_CFLAGS) -SRC_CPP := src/proxy.cpp src/rdma.cpp src/common.cpp src/peer_copy_worker.cpp src/uccl_proxy.cpp src/uccl_bench.cpp src/peer_copy_manager.cpp src/fifo.cpp -SRC_CU := src/bench_kernel.cu src/peer_copy.cu src/internode_ll.cu src/internode.cu src/layout.cu src/intranode.cu src/ep_runtime.cu +SRC_CPP := src/proxy.cpp src/rdma.cpp src/common.cpp src/uccl_proxy.cpp src/uccl_bench.cpp src/fifo.cpp +SRC_CU := src/bench_kernel.cu src/internode_ll.cu src/internode.cu src/layout.cu src/intranode.cu src/ep_runtime.cu OBJ_CPP := $(SRC_CPP:.cpp=.o) OBJ_CU := $(SRC_CU:.cu=.o) diff --git a/ep/README.md b/ep/README.md index b6c3fca2f..74692b1e8 100644 --- a/ep/README.md +++ b/ep/README.md @@ -80,8 +80,8 @@ combined_x, event, hook = buffer.low_latency_combine( Initialization and tear down: ```python -proxies, workers = initialize_uccl(scratch, num_rdma_bytes, rank, num_ranks, group, args.num_experts) -destroy_uccl(proxies, workers) +proxies = initialize_uccl(scratch, num_rdma_bytes, rank, num_ranks, group, args.num_experts) +destroy_uccl(proxies) ``` ## Benchmark diff --git a/ep/bench/Makefile b/ep/bench/Makefile index 2c5e4d3d1..08f224823 100644 --- a/ep/bench/Makefile +++ b/ep/bench/Makefile @@ -197,10 +197,10 @@ HEADERS += $(wildcard include/*.h include/*.cuh include/*.hpp) ifeq ($(HAS_EFA),1) SRC_CPP := ../src/proxy.cpp ../src/rdma.cpp ../src/common.cpp - SRC_CPP += ../src/peer_copy_worker.cpp ../src/uccl_proxy.cpp - SRC_CPP += ../src/uccl_bench.cpp ../src/peer_copy_manager.cpp ../src/fifo.cpp + SRC_CPP += ../src/uccl_proxy.cpp + SRC_CPP += ../src/uccl_bench.cpp ../src/fifo.cpp - SRC_CU := ../src/bench_kernel.cu ../src/peer_copy.cu + SRC_CU := ../src/bench_kernel.cu SRC_CU += ../src/internode_ll.cu ../src/internode.cu ../src/layout.cu SRC_CU += ../src/intranode.cu ../src/ep_runtime.cu diff --git a/ep/bench/benchmark_rdma_rb.py b/ep/bench/benchmark_rdma_rb.py index 1659038f0..f17b7a041 100644 --- a/ep/bench/benchmark_rdma_rb.py +++ b/ep/bench/benchmark_rdma_rb.py @@ -146,15 +146,7 @@ def run_rank1_remote( mode="remote", peers_meta_list=peers_meta_list, ) - device_index = int(os.environ.get("LOCAL_RANK", "0")) - workers = ep.PeerCopyManager(src_device=device_index) - workers.start_for_proxies(proxies) - print("[rank 1] PeerCopyManager started.", flush=True) time.sleep(5) - try: - workers.stop() - except Exception: - pass try: for p in proxies: p.stop() diff --git a/ep/bench/buffer.py b/ep/bench/buffer.py index bf12b2d5b..77cd0c389 100644 --- a/ep/bench/buffer.py +++ b/ep/bench/buffer.py @@ -96,13 +96,13 @@ def __init__( self.scratch = ep.get_rdma_buffer(num_rdma_bytes, device_index) rdma_buffer_ptr = self.scratch.data_ptr() - self.proxies, self.workers = initialize_uccl( + self.proxies = initialize_uccl( rdma_buffer_ptr, num_rdma_bytes, group.rank(), dist.get_world_size(group), group, - use_normal_mode=not low_latency_mode, + use_throughput_mode=not low_latency_mode, is_intranode=is_intranode, ) check_nvlink_connections(group) @@ -183,7 +183,7 @@ 
def destroy(self): self.runtime.destroy() self.runtime = None - destroy_uccl(self.proxies, self.workers) + destroy_uccl(self.proxies) @staticmethod def is_sm90_compiled(): diff --git a/ep/bench/test_internode_simple.py b/ep/bench/test_internode_simple.py index 1aea8e925..b7f7fa913 100644 --- a/ep/bench/test_internode_simple.py +++ b/ep/bench/test_internode_simple.py @@ -53,7 +53,7 @@ def test_simple_internode(rank: int, num_ranks: int, group: dist.ProcessGroup): scratch = torch.empty( scratch_nbytes, dtype=torch.uint8, device=f"cuda:{device_index}" ) - proxies, workers = initialize_uccl(scratch, scratch_nbytes, rank, num_ranks, group) + proxies = initialize_uccl(scratch, scratch_nbytes, rank, num_ranks, group) try: buffer = Buffer( @@ -146,7 +146,7 @@ def test_simple_internode(rank: int, num_ranks: int, group: dist.ProcessGroup): dist.barrier() print("[simple-test] ✓ Buffer destroyed", flush=True) - destroy_uccl(proxies, workers) + destroy_uccl(proxies) dist.barrier() diff --git a/ep/bench/utils.py b/ep/bench/utils.py index ef103dfce..5af36e982 100644 --- a/ep/bench/utils.py +++ b/ep/bench/utils.py @@ -171,17 +171,6 @@ def get_cpu_proxies_meta(proxies, rank, scratch_ptr, scratch_bytes, num_ranks, g torch.cuda.set_device(device_index) dist.all_gather_object(all_meta, meta, group=group) rank2meta = {m["rank"]: m for m in all_meta} - - # Debug: print IP distribution - ip_counts = {} - for m in all_meta: - ip = m["ip"] - ip_counts[ip] = ip_counts.get(ip, 0) + 1 - if rank == 0: - print(f"[DEBUG] IP distribution across {num_ranks} ranks:", flush=True) - for ip, count in ip_counts.items(): - print(f"[DEBUG] {ip}: {count} ranks", flush=True) - return rank2meta @@ -502,14 +491,8 @@ def initialize_uccl( group, num_experts=0, is_intranode=False, - use_normal_mode=False, + use_throughput_mode=False, ): - try: - for shm_file in glob.glob("/dev/shm/uccl_barrier_*"): - os.remove(shm_file) - except Exception: - pass - # Try to get local_rank from environment or infer from current device if "LOCAL_RANK" in os.environ: local_rank = int(os.environ["LOCAL_RANK"]) @@ -558,7 +541,7 @@ def initialize_uccl( num_experts=num_experts, num_ranks=num_ranks, num_nodes=num_nodes, - use_normal_mode=use_normal_mode, + use_throughput_mode=use_throughput_mode, is_intranode=is_intranode, ) proxies.append(proxy) @@ -578,35 +561,17 @@ def initialize_uccl( if not is_intranode: for proxy in proxies: proxy.start_dual() - - workers = None - # if hasattr(ep, "PeerCopyManager"): - # try: - # workers = ep.PeerCopyManager(src_device=local_rank) - # workers.start_for_proxies(proxies) - # if rank == 0: - # print("✓ PeerCopyManager started", flush=True) - # except Exception as e: - # if rank == 0: - # print(f"PeerCopyManager unavailable: {e}", flush=True) - time.sleep(3) - return proxies, workers + return proxies -def destroy_uccl(proxies, workers): +def destroy_uccl(proxies): # Use current device or fallback to LOCAL_RANK if "LOCAL_RANK" in os.environ: device_index = int(os.environ["LOCAL_RANK"]) else: device_index = torch.cuda.current_device() - if workers is not None: - try: - workers.stop() - except Exception: - pass - try: for p in proxies: p.stop() @@ -616,11 +581,6 @@ def destroy_uccl(proxies, workers): ep.unregister_proxy(device_index) except Exception: pass - try: - for shm_file in glob.glob("/dev/shm/uccl_barrier_*"): - os.remove(shm_file) - except Exception: - pass def per_token_cast_to_fp8(x: torch.Tensor): diff --git a/ep/deep_ep_wrapper/scripts/sglang_nccl_44000_prefill.sh 
b/ep/deep_ep_wrapper/scripts/sglang_nccl_44000_prefill.sh index b79408571..961110f14 100755 --- a/ep/deep_ep_wrapper/scripts/sglang_nccl_44000_prefill.sh +++ b/ep/deep_ep_wrapper/scripts/sglang_nccl_44000_prefill.sh @@ -40,6 +40,7 @@ export NCCL_SOCKET_IFNAME="^lo,docker" export FI_PROVIDER=efa export FI_EFA_USE_DEVICE_RDMA=1 export SGLANG_ENABLE_JIT_DEEPGEMM=1 +export SGLANG_TORCH_PROFILER_DIR=/workspace/uccl/ep/deep_ep_wrapper/sglang_profiles_nccl # Parameters MODEL_PATH="deepseek-ai/DeepSeek-R1-0528" diff --git a/ep/deep_ep_wrapper/scripts/sglang_uep_46000_prefill.sh b/ep/deep_ep_wrapper/scripts/sglang_uep_46000_prefill.sh index 6107a7de8..5b570e1a6 100755 --- a/ep/deep_ep_wrapper/scripts/sglang_uep_46000_prefill.sh +++ b/ep/deep_ep_wrapper/scripts/sglang_uep_46000_prefill.sh @@ -40,6 +40,7 @@ export NCCL_SOCKET_IFNAME="^lo,docker" export FI_PROVIDER=efa export FI_EFA_USE_DEVICE_RDMA=1 export SGLANG_ENABLE_JIT_DEEPGEMM=1 +export SGLANG_TORCH_PROFILER_DIR=/workspace/uccl/ep/deep_ep_wrapper/sglang_profiles # Parameters MODEL_PATH="deepseek-ai/DeepSeek-R1-0528" diff --git a/ep/include/common.hpp b/ep/include/common.hpp index 25c6ee08a..72294d800 100644 --- a/ep/include/common.hpp +++ b/ep/include/common.hpp @@ -14,22 +14,11 @@ #define MAX_IB_DEVS 32 // #define MEASURE_PER_OP_LATENCY -// #define MEASURE_PER_VERB_LATENCY - -// Barrier type selection (can be overridden at compile time) -#ifndef USE_SENDER_BARRIER -#ifdef EFA -#define USE_RECEIVER_BARRIER -#endif -#endif - -#ifdef EFA #define EFA_QP_LOW_LATENCY_SERVICE_LEVEL 8 + extern bool use_ll_sl; -#endif #define USE_MSCCLPP_FIFO_BACKEND -// #define USE_SUBSET_BARRIER #define kAtomicBufferSize 81960 #define kQueueSize 2048 #define kQueueMask (kQueueSize - 1) @@ -41,9 +30,7 @@ extern bool use_ll_sl; #define kIterations 40000 #define kNumProxyThs 4 #define kTestNumGpuThPerBlock 1 -#define kObjectSize 7168 // 7 KB -// #define kObjectSize 10752 // 10.5 KB -// #define kObjectSize 14336 // 14 KB +#define kObjectSize 7168 // 7 KB #define kMaxOutstandingSends 2048 // = max_send_wr, max_recv_wr, cq_depth / 2 #define kMaxOutstandingRecvs 2048 #define kSenderAckQueueDepth 2048 diff --git a/ep/include/peer_copy.cuh b/ep/include/peer_copy.cuh deleted file mode 100644 index 01350e9a4..000000000 --- a/ep/include/peer_copy.cuh +++ /dev/null @@ -1,37 +0,0 @@ -// peer_copy.cuh -#pragma once - -#include "ring_buffer.cuh" -#include "util/gpu_rt.h" - -template -__host__ __device__ constexpr Z divUp(X x, Y y) { - return (x + y - 1) / y; -} - -gpuError_t launch_peer_bulk_copy(void* dst_ptr, int dst_dev, void* src_ptr, - int src_dev, size_t bytes, - gpuStream_t stream = 0); - -gpuError_t launch_peer_bulk_copy2(CopyTask const* host_tasks, int num_tasks, - gpuStream_t stream, int src_device, - CopyTask*& d_tasks); - -__global__ void peer_copy_kernel_vec_batched(CopyTask const* __restrict__ tasks, - int num_tasks, - int tasks_per_block); - -template // 16 B per transaction -__global__ void peer_copy_kernel_vec_pipelined( - CopyTask const* __restrict__ tasks, int num_tasks, int tasks_per_block); - -HostToDeviceNVlinkBuffer* initialize_ring_buffer_for_nvlink_forwarding( - gpuStream_t stream); - -bool post_copy_task(HostToDeviceNVlinkBuffer* rb, CopyTask const* host_tasks, - int num_tasks, gpuStream_t stream, int src_device, - CopyTask*& d_tasks); - -extern "C" void launch_read_and_set_sys(int* addr, int new_val, - cudaStream_t stream); \ No newline at end of file diff --git a/ep/include/peer_copy_manager.hpp b/ep/include/peer_copy_manager.hpp deleted file mode 100644 
index dc2854b38..000000000 --- a/ep/include/peer_copy_manager.hpp +++ /dev/null @@ -1,20 +0,0 @@ -#pragma once - -#include "peer_copy_worker.hpp" -#include -#include -class UcclProxy; - -class PeerCopyManager { - public: - explicit PeerCopyManager(int src_device = 0); - ~PeerCopyManager(); - - void start_for_proxies(std::vector const& proxies); - void stop(); - - private: - PeerCopyShared shared_; - std::vector ctxs_; - std::vector threads_; -}; diff --git a/ep/include/peer_copy_worker.hpp b/ep/include/peer_copy_worker.hpp deleted file mode 100644 index c3e8c3524..000000000 --- a/ep/include/peer_copy_worker.hpp +++ /dev/null @@ -1,33 +0,0 @@ -#pragma once -#include "common.hpp" -#include "ring_buffer.cuh" -#include -#include -#include - -// Shared across all peer-copy workers on a process -struct PeerCopyShared { - // Controls the worker loop - std::atomic run{true}; - - // Source GPU for receiving host-side staging to device - int src_device = 0; -}; - -struct PeerWorkerCtx { - // Counters / timings - uint64_t async_memcpy_count = 0; - uint64_t prev_completed_async_memcpy_count = 0; - uint64_t highest_issued_wr_id = 0; - - // Batch buffers - CopyTask tasks[RECEIVER_BATCH_SIZE]; - uint64_t task_wrs[RECEIVER_BATCH_SIZE]; - - // CUDA resources - gpuStream_t stream = nullptr; - CopyTask* d_tasks = nullptr; // device buffer for tasks -}; - -void peer_copy_worker(PeerCopyShared& shared, PeerWorkerCtx& ctx, - CopyRingBuffer& ring, int idx); \ No newline at end of file diff --git a/ep/include/proxy.hpp b/ep/include/proxy.hpp index 8e74aaf7c..e86472267 100644 --- a/ep/include/proxy.hpp +++ b/ep/include/proxy.hpp @@ -48,8 +48,7 @@ class Proxy { int num_experts = 0; int num_ranks = 0; int num_nodes = 0; - bool use_normal_mode = - false; // Runtime flag for normal mode (batching optimization) + bool use_throughput_mode = false; bool is_intranode = false; }; @@ -70,7 +69,6 @@ class Proxy { void run_sender(); void run_remote(); - void run_local(); void run_dual(); void pin_thread_to_cpu_wrapper(); void pin_thread_to_numa_wrapper(); diff --git a/ep/include/rdma.hpp b/ep/include/rdma.hpp index 76ddc0626..e72f8559d 100644 --- a/ep/include/rdma.hpp +++ b/ep/include/rdma.hpp @@ -25,11 +25,8 @@ struct RDMAConnectionInfo { uint64_t len; uint16_t lid; // Local ID uint8_t gid[16]; // Global ID for RoCE (optional) - - // #ifdef EFA uint32_t num_rings; uint32_t data_qp_num[kChannelPerProxy]; - // #endif }; struct PendingUpdate { @@ -314,7 +311,7 @@ void send_connection_info_as_client(int my_rank, int peer, char const* peer_ip, RDMAConnectionInfo* local); void modify_qp_to_rtr(ProxyCtx& S, RDMAConnectionInfo* remote, - bool use_normal_mode); + bool use_throughput_mode); void modify_qp_to_rts(ProxyCtx& S, RDMAConnectionInfo* local_info); @@ -326,10 +323,10 @@ void remote_process_completions( ProxyCtx& S, int idx, CopyRingBuffer& ring, int ne, ibv_wc* wc, std::vector& ctx_by_tag, void* atomic_buffer_ptr, int num_ranks, int num_experts, std::set& pending_atomic_updates, - int my_rank, int num_nodes, bool use_normal_mode = false); + int my_rank, int num_nodes, bool use_throughput_mode = false); void create_per_thread_qp(ProxyCtx& S, void* gpu_buffer, size_t size, RDMAConnectionInfo* local_info, int rank, - size_t num_rings, bool use_normal_mode); + size_t num_rings, bool use_throughput_mode); ibv_cq* create_per_thread_cq(ProxyCtx& S); void remote_poll_completions(ProxyCtx& S, int idx, CopyRingBuffer& g_ring, std::vector& ctx_by_tag, @@ -337,18 +334,17 @@ void remote_poll_completions(ProxyCtx& S, int idx, 
CopyRingBuffer& g_ring, int num_experts, std::set& pending_atomic_updates, int my_rank, int num_nodes, - bool use_normal_mode = false); + bool use_throughput_mode = false); void per_thread_rdma_init(ProxyCtx& S, void* gpu_buf, size_t bytes, int rank, int thread_idx, int local_rank); -void remote_send_ack(ProxyCtx* ctx, struct ibv_qp* ack_qp, uint64_t& wr_id, - ibv_mr* local_ack_mr, uint64_t* ack_buf, int worker_idx); void local_post_ack_buf(ProxyCtx& S, int depth); void remote_reg_ack_buf(ibv_pd* pd, uint64_t* ack_buf, ibv_mr*& ack_mr); void post_rdma_async_batched(ProxyCtx& S, void* buf, size_t num_wrs, std::vector const& wrs_to_post, std::vector const& cmds_to_post, std::vector>& ctxs, - int my_rank, int thread_idx, bool use_normal_mode); + int my_rank, int thread_idx, + bool use_throughput_mode); void local_process_completions(ProxyCtx& S, std::unordered_set& acked_wrs, int thread_idx, ibv_wc* wc, int ne, @@ -358,14 +354,14 @@ void poll_cq_dual(ProxyCtx& S, std::unordered_set& acked_wrs, std::vector& ctx_by_tag, void* atomic_buffer_ptr, int num_ranks, int num_experts, std::set& pending_atomic_updates, int my_rank, - int num_nodes, bool use_normal_mode = false); + int num_nodes, bool use_throughput_mode = false); void post_atomic_operations(ProxyCtx& S, std::vector const& wrs_to_post, std::vector const& cmds_to_post, std::vector>& ctxs, int my_rank, int thread_idx, std::unordered_set& acked_wrs, - bool use_normal_mode); + bool use_throughput_mode); void apply_pending_updates(ProxyCtx& ctx, std::set& pending_atomic_updates, void* atomic_buffer_ptr, int num_experts, diff --git a/ep/include/ring_buffer.cuh b/ep/include/ring_buffer.cuh index 05c53aafa..701a29ca7 100644 --- a/ep/include/ring_buffer.cuh +++ b/ep/include/ring_buffer.cuh @@ -80,7 +80,7 @@ struct TransferCmd { union { // Low-latency mode uint16_t expert_idx; - // Normal mode + // Throughput mode uint16_t atomic_offset; }; }; diff --git a/ep/include/uccl_ibgda.cuh b/ep/include/uccl_ibgda.cuh index 41b48d1cd..05e28180f 100644 --- a/ep/include/uccl_ibgda.cuh +++ b/ep/include/uccl_ibgda.cuh @@ -23,7 +23,7 @@ namespace uccl { // Note(MaoZiming, Yang): the expert_idx here is used to tell which ring buffer // to use. The total concurrent warps can be say 64 (= number of experts), while // the number of ring buffers is small (say 6). -template +template __device__ __forceinline__ void nvshmemi_ibgda_put_nbi_warp( uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_rank, int expert_idx, int lane_id, int message_idx, @@ -47,7 +47,7 @@ __device__ __forceinline__ void nvshmemi_ibgda_put_nbi_warp( auto* h = reinterpret_cast( static_cast(d2h_channel_addrs[d2h_channel_idx])); - if constexpr (use_normal_mode) { + if constexpr (use_throughput_mode) { low_latency_buffer_idx == -1 ? expert_idx = 0 : 0; low_latency_buffer_idx == -1 ? low_latency_buffer_idx = 0 : 0; } @@ -62,7 +62,7 @@ __device__ __forceinline__ void nvshmemi_ibgda_put_nbi_warp( cmd.req_lptr = lptr_val; cmd.bytes = bytes_val; cmd.dst_rank = dst_rank; - if constexpr (use_normal_mode) { + if constexpr (use_throughput_mode) { cmd.atomic_offset = atomic_offset; cmd.atomic_val = atomic_val; } else { @@ -82,7 +82,7 @@ __device__ __forceinline__ void nvshmemi_ibgda_put_nbi_warp( cur_tail = h->tail(); inflight = cur_head - cur_tail; if (inflight < - (use_normal_mode ? kMaxInflightNormal : kMaxInflightLowLatency)) { + (use_throughput_mode ? kMaxInflightNormal : kMaxInflightLowLatency)) { uint64_t slot = cur_head; TransferCmd cmd{}; // TODO(MaoZiming): Check fields here. 
@@ -99,7 +99,7 @@ __device__ __forceinline__ void nvshmemi_ibgda_put_nbi_warp( trap(); } - if constexpr (use_normal_mode) { + if constexpr (use_throughput_mode) { if (atomic_offset >> 16) { printf( "[nvshmemi_ibgda_put_nbi_warp] atomic_offset too large: %llu\n", @@ -134,7 +134,7 @@ __device__ __forceinline__ void nvshmemi_ibgda_put_nbi_warp( // TODO(MaoZiming): Fix. This should be a non-fetch add operation. This could be // implemented with CPU proxy. -template +template __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add( uint64_t rptr, uint64_t atomic_base_addr, int const& value, int dst_rank, int warp_id, bool is_local_copy = false, @@ -145,7 +145,7 @@ __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add( atomicAdd(reinterpret_cast(rptr), static_cast(value)); } else { - if constexpr (use_normal_mode) { + if constexpr (use_throughput_mode) { if (skip_remote) return; } rptr -= atomic_base_addr; @@ -160,7 +160,7 @@ __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add( static_cast(d2h_channel_addrs[d2h_channel_idx])); auto last_print = clock64(); - if constexpr (use_normal_mode) { + if constexpr (use_throughput_mode) { /* Normal mode */ low_latency_buffer_idx == -1 ? low_latency_buffer_idx = 0 : 0; } @@ -185,7 +185,7 @@ __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add( cur_tail = h->tail(); inflight = cur_head - cur_tail; if (inflight < - (use_normal_mode ? kMaxInflightNormal : kMaxInflightLowLatency)) { + (use_throughput_mode ? kMaxInflightNormal : kMaxInflightLowLatency)) { uint64_t slot = cur_head; TransferCmd cmd{}; cmd.cmd_type = diff --git a/ep/include/uccl_proxy.hpp b/ep/include/uccl_proxy.hpp index 265ab8d3b..b3d937dcb 100644 --- a/ep/include/uccl_proxy.hpp +++ b/ep/include/uccl_proxy.hpp @@ -10,21 +10,16 @@ #include #include -class PeerCopyManager; - class UcclProxy { - friend class PeerCopyManager; - public: UcclProxy(int thread_idx, uintptr_t gpu_buffer_addr, size_t total_size, int rank, int node_idx, int local_rank, int num_experts = 0, - int num_ranks = 0, int num_nodes = 0, bool use_normal_mode = false, - bool is_intranode = false); + int num_ranks = 0, int num_nodes = 0, + bool use_throughput_mode = false, bool is_intranode = false); ~UcclProxy(); void start_sender(); void start_remote(); - void start_local(); void start_dual(); void stop(); int get_listen_port() const { return proxy_->get_listen_port(); } @@ -81,7 +76,7 @@ class UcclProxy { } private: - enum class Mode { None, Sender, Remote, Local, Dual }; + enum class Mode { None, Sender, Remote, Dual }; void start(Mode m); std::unique_ptr proxy_; diff --git a/ep/src/common.cpp b/ep/src/common.cpp index 368ab03c3..1a3ae113c 100644 --- a/ep/src/common.cpp +++ b/ep/src/common.cpp @@ -7,9 +7,7 @@ #include std::once_flag peer_ok_flag[MAX_NUM_GPUS][MAX_NUM_GPUS]; -#ifdef EFA bool use_ll_sl = false; -#endif bool pin_thread_to_cpu(int cpu) { int num_cpus = sysconf(_SC_NPROCESSORS_ONLN); diff --git a/ep/src/internode.cu b/ep/src/internode.cu index 2388bd387..921a14b5d 100644 --- a/ep/src/internode.cu +++ b/ep/src/internode.cu @@ -194,7 +194,7 @@ __global__ void notify_dispatch( rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank)); uint64_t src_ptr = reinterpret_cast( rdma_recv_num_tokens_mixed.send_buffer(i)); - uccl::nvshmemi_ibgda_put_nbi_warp( + uccl::nvshmemi_ibgda_put_nbi_warp( dst_ptr - reinterpret_cast(original_rdma_buffer_ptr), src_ptr - reinterpret_cast(original_rdma_buffer_ptr), (NUM_MAX_NVL_PEERS + num_rdma_experts + 1) * sizeof(int), @@ -689,7 +689,7 @@ __global__ void 
__launch_bounds__( if (dst_rdma_rank != rdma_rank) { // NOTE(MaoZiming): this tells the remote rank how many tokens each // local nvl_rank and expert are expected to receive. - uccl::nvshmemi_ibgda_put_nbi_warp( + uccl::nvshmemi_ibgda_put_nbi_warp( reinterpret_cast( rdma_channel_meta.recv_buffer(rdma_rank)) - reinterpret_cast(original_rdma_buffer_ptr), @@ -1019,7 +1019,7 @@ __global__ void __launch_bounds__( auto const src_ptr = reinterpret_cast( rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_token); - uccl::nvshmemi_ibgda_put_nbi_warp( + uccl::nvshmemi_ibgda_put_nbi_warp( dst_ptr - reinterpret_cast(original_rdma_buffer_ptr), src_ptr - reinterpret_cast(original_rdma_buffer_ptr), num_bytes_per_msg, @@ -1039,7 +1039,7 @@ __global__ void __launch_bounds__( if (lane_id == dst_rdma_rank) { last_issued_tail += num_tokens_to_issue; num_tokens_to_send -= num_tokens_to_issue; - uccl::nvshmemi_ibgda_amo_nonfetch_add( + uccl::nvshmemi_ibgda_amo_nonfetch_add( reinterpret_cast(rdma_channel_tail.buffer(rdma_rank)), reinterpret_cast(original_atomic_buffer_ptr), num_tokens_to_issue, @@ -1292,7 +1292,7 @@ __global__ void __launch_bounds__( if (min_head != std::numeric_limits::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) { - uccl::nvshmemi_ibgda_amo_nonfetch_add( + uccl::nvshmemi_ibgda_amo_nonfetch_add( reinterpret_cast(rdma_channel_head.buffer(rdma_rank)), reinterpret_cast(original_atomic_buffer_ptr), min_head - last_head, @@ -2603,7 +2603,7 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * WARP_SIZE, 1) auto const src_ptr = reinterpret_cast( rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_token); - uccl::nvshmemi_ibgda_put_nbi_warp( + uccl::nvshmemi_ibgda_put_nbi_warp( dst_ptr - reinterpret_cast(original_rdma_buffer_ptr), src_ptr - reinterpret_cast(original_rdma_buffer_ptr), num_bytes_per_msg, @@ -2622,7 +2622,7 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * WARP_SIZE, 1) // Write new RDMA tail __syncwarp(); if (lane_id == 0) { - uccl::nvshmemi_ibgda_amo_nonfetch_add( + uccl::nvshmemi_ibgda_amo_nonfetch_add( reinterpret_cast(rdma_channel_tail.buffer(rdma_rank)), reinterpret_cast(original_atomic_buffer_ptr), num_chunked_tokens, @@ -2785,7 +2785,7 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * WARP_SIZE, 1) if (min_head != std::numeric_limits::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) { - uccl::nvshmemi_ibgda_amo_nonfetch_add( + uccl::nvshmemi_ibgda_amo_nonfetch_add( reinterpret_cast(rdma_channel_head.buffer(rdma_rank)), reinterpret_cast(original_atomic_buffer_ptr), min_head - last_rdma_head, diff --git a/ep/src/peer_copy.cu b/ep/src/peer_copy.cu deleted file mode 100644 index 508cea811..000000000 --- a/ep/src/peer_copy.cu +++ /dev/null @@ -1,439 +0,0 @@ -// peer_copy.cu -#include "common.hpp" -#include "peer_copy.cuh" -#include "ring_buffer.cuh" -#include "util/gpu_rt.h" -#include -#ifndef __HIP_PLATFORM_AMD__ -#include -#endif - -#define NVLINK_SM_PER_PROCESS 1 - -__global__ void peer_copy_kernel(char const* __restrict__ src, - char* __restrict__ dst, size_t num_bytes) { - size_t idx = (blockIdx.y * gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x; - size_t total_threads = (gridDim.x * gridDim.y) * blockDim.x; - - for (size_t i = idx; i < num_bytes; i += total_threads) { - dst[i] = src[i]; - } -} - -gpuError_t launch_peer_bulk_copy(void* dst_ptr, int dst_dev, void* src_ptr, - int src_dev, size_t bytes, 
- gpuStream_t stream) { - constexpr int threads_per_block = 256; - size_t total_threads = (bytes + threads_per_block - 1) / threads_per_block; - dim3 blocks; - blocks.x = (total_threads > 65535) ? 65535 - : static_cast(total_threads); - blocks.y = (total_threads + 65534) / 65535; - - peer_copy_kernel<<>>( - static_cast(src_ptr), static_cast(dst_ptr), bytes); - - return gpuGetLastError(); -} - -__device__ inline void copy128(char const* __restrict__ src, - char* __restrict__ dst) { - *reinterpret_cast(dst) = *reinterpret_cast(src); -} - -__global__ void peer_copy_kernel_vec(CopyTask const* __restrict__ tasks, - int num_tasks) { - int const task_id = blockIdx.x; - if (task_id >= num_tasks) return; - - const CopyTask t = tasks[task_id]; - char const* __restrict__ src = static_cast(t.src_ptr); - char* __restrict__ dst = static_cast(t.dst_ptr); - size_t nbytes = t.bytes; - - size_t i = threadIdx.x * 16; - for (; i + 15 < nbytes; i += blockDim.x * 16) { - copy128(src + i, dst + i); - } - - if (threadIdx.x == 0) { - for (size_t j = (nbytes & ~size_t(15)); j < nbytes; ++j) dst[j] = src[j]; - } -} - -__global__ void peer_copy_kernel_vec_batched(CopyTask const* __restrict__ tasks, - int num_tasks, - int tasks_per_block) { - int block_task_start = blockIdx.x * tasks_per_block; - int tid = threadIdx.x; - - for (int i = 0; i < tasks_per_block; ++i) { - int task_id = block_task_start + i; - if (task_id >= num_tasks) return; - - const CopyTask t = tasks[task_id]; - char const* __restrict__ src = static_cast(t.src_ptr); - char* __restrict__ dst = static_cast(t.dst_ptr); - size_t nbytes = t.bytes; - - size_t offset = tid * 16; - for (; offset + 15 < nbytes; offset += blockDim.x * 16) { - copy128(src + offset, dst + offset); - } - - if (tid == 0) { - for (size_t j = (nbytes & ~size_t(15)); j < nbytes; ++j) { - dst[j] = src[j]; - } - } - - __syncthreads(); // avoid interleaved accesses when multiple tasks per - // block - } -} - -gpuError_t launch_peer_bulk_copy2(CopyTask const* host_tasks, int num_tasks, - gpuStream_t stream, int src_device, - CopyTask*& d_tasks) { - GPU_RT_CHECK(gpuMemcpyAsync(d_tasks, host_tasks, num_tasks * sizeof(CopyTask), - gpuMemcpyHostToDevice, stream)); - constexpr int threads_per_block = 256; - dim3 blocks(NVLINK_SM_PER_PROCESS); - if (false) { - peer_copy_kernel_vec<<>>(d_tasks, - num_tasks); - } else if (true) { - int tasks_per_block = num_tasks / NVLINK_SM_PER_PROCESS; - peer_copy_kernel_vec_batched<<>>( - d_tasks, num_tasks, tasks_per_block); - } else { - int tasks_per_block = num_tasks / NVLINK_SM_PER_PROCESS; - size_t shmem = threads_per_block * 2 /*PIPE_DEPTH*/ * sizeof(int4); - peer_copy_kernel_vec_pipelined<2, int4> - <<>>(d_tasks, num_tasks, - tasks_per_block); - } - return gpuGetLastError(); -} - -#ifndef __HIP_PLATFORM_AMD__ -template -__global__ void peer_copy_kernel_vec_pipelined( - CopyTask const* __restrict__ tasks, int num_tasks, int tasks_per_block) { - extern __shared__ uint8_t shmem_raw[]; - VecT* __restrict__ ring = reinterpret_cast(shmem_raw); - - int const nThreads = blockDim.x; - int const tid = threadIdx.x; - int const blockTask0 = blockIdx.x * tasks_per_block; - - for (int local = 0; local < tasks_per_block; ++local) { - int const task_id = blockTask0 + local; - if (task_id >= num_tasks) return; - - CopyTask t = tasks[task_id]; - char const* __restrict__ src = static_cast(t.src_ptr); - char* __restrict__ dst = static_cast(t.dst_ptr); - const size_t nbytes = t.bytes; - - const size_t nVec = nbytes / sizeof(VecT); - const size_t vecPerThread = divUp(nVec, 
nThreads); - const size_t myFirst = tid * vecPerThread; - const size_t myLast = min(myFirst + vecPerThread, nVec); - - /* Two-slot ring-buffer in shared memory, one slot per outstanding - transaction. Each thread owns one VecT element inside every slot. */ - - // wr: slot we will find next - // rd: slot index we will retire to global memory next - // issued: how many DMAs that are still inflight. - int wr = 0, rd = 0, issued = 0; - - for (size_t v = myFirst; v < myLast; ++v) { - /* stage 1: async L2→shmem fetch */ - void const* gptr = src + v * sizeof(VecT); - void* sptr = &ring[wr * nThreads + tid]; - __pipeline_memcpy_async(sptr, gptr, sizeof(VecT)); - __pipeline_commit(); - - ++issued; - wr = (wr + 1) % PIPE_DEPTH; - - /* stage 2: retire oldest when PIPE_DEPTH requests in flight */ - if (issued == PIPE_DEPTH) { - __pipeline_wait_prior(PIPE_DEPTH - 1); - size_t dstIdx = v - (PIPE_DEPTH - 1); - *reinterpret_cast(dst + dstIdx * sizeof(VecT)) = - ring[rd * nThreads + tid]; - rd = (rd + 1) % PIPE_DEPTH; - --issued; - } - } - - /* drain remaining inflight transactions */ - while (issued) { - --issued; - __pipeline_wait_prior(issued); - size_t dstIdx = myLast - issued; - *reinterpret_cast(dst + dstIdx * sizeof(VecT)) = - ring[rd * nThreads + tid]; - rd = (rd + 1) % PIPE_DEPTH; - } - - if (tid == 0) { - for (size_t j = nVec * sizeof(VecT); j < nbytes; ++j) dst[j] = src[j]; - } - - __syncthreads(); - } -} -#else -// Manual implementation of pipeline functionality for HIP -// Since HIP doesn't have __pipeline_* functions, we implement a simplified -// version -template -__global__ void peer_copy_kernel_vec_pipelined( - CopyTask const* __restrict__ tasks, int num_tasks, int tasks_per_block) { - extern __shared__ uint8_t shmem_raw[]; - VecT* __restrict__ ring = reinterpret_cast(shmem_raw); - - int const nThreads = blockDim.x; - int const tid = threadIdx.x; - int const blockTask0 = blockIdx.x * tasks_per_block; - - for (int local = 0; local < tasks_per_block; ++local) { - int const task_id = blockTask0 + local; - if (task_id >= num_tasks) return; - - CopyTask t = tasks[task_id]; - char const* __restrict__ src = static_cast(t.src_ptr); - char* __restrict__ dst = static_cast(t.dst_ptr); - const size_t nbytes = t.bytes; - - const size_t nVec = nbytes / sizeof(VecT); - const size_t vecPerThread = divUp(nVec, nThreads); - const size_t myFirst = tid * vecPerThread; - const size_t myLast = min(myFirst + vecPerThread, nVec); - - // Manual pipelining implementation for HIP - // Since we don't have __pipeline_* functions, we use a simple staging - // approach - int wr = 0, rd = 0, issued = 0; - - for (size_t v = myFirst; v < myLast; ++v) { - // Stage 1: Manual copy from global memory to shared memory - VecT const* gptr = reinterpret_cast(src + v * sizeof(VecT)); - VecT* sptr = &ring[wr * nThreads + tid]; - - // Manual copy instead of __pipeline_memcpy_async - *sptr = *gptr; - __threadfence_block(); // Ensure shared memory write is visible - - ++issued; - wr = (wr + 1) % PIPE_DEPTH; - - // Stage 2: retire oldest when PIPE_DEPTH requests in flight - if (issued == PIPE_DEPTH) { - __syncthreads(); // Manual synchronization instead of - // __pipeline_wait_prior - size_t dstIdx = v - (PIPE_DEPTH - 1); - *reinterpret_cast(dst + dstIdx * sizeof(VecT)) = - ring[rd * nThreads + tid]; - rd = (rd + 1) % PIPE_DEPTH; - --issued; - } - } - - // drain remaining inflight transactions - while (issued) { - --issued; - __syncthreads(); // Manual synchronization - size_t dstIdx = myLast - issued; - *reinterpret_cast(dst + dstIdx 
* sizeof(VecT)) = - ring[rd * nThreads + tid]; - rd = (rd + 1) % PIPE_DEPTH; - } - - if (tid == 0) { - for (size_t j = nVec * sizeof(VecT); j < nbytes; ++j) dst[j] = src[j]; - } - - __syncthreads(); - } -} -#endif - -__device__ __forceinline__ unsigned long long atomicSubULL( - unsigned long long* addr, unsigned long long val) { - return atomicAdd( - reinterpret_cast(addr), - static_cast(-static_cast(val))); -} - -__device__ __forceinline__ bool pop_global(HostToDeviceNVlinkBuffer* rb, - CopyTask& out) { - // Reserve a slot atomically - const uint64_t my_tail = - atomicAdd(reinterpret_cast(&rb->tail), 1ULL); - - __threadfence_system(); - - // Check if we raced past the current head - if (my_tail >= rb->head) { - // undo reservation - atomicSubULL(reinterpret_cast(&rb->tail), 1ULL); - return false; - } - - out = rb->get_entry(my_tail); - return true; -} - -__global__ void peer_copy_kernel_vec_many(HostToDeviceNVlinkBuffer* rb) { - unsigned const lane = threadIdx.x & 0x1F; // 0–31 - - while (true) { - CopyTask task; - bool have = false; - if (lane == 0) have = pop_global(rb, task); -#ifndef __HIP_PLATFORM_AMD__ - have = __shfl_sync(0xFFFFFFFF, have, 0); -#else - have = __shfl(have, 0); -#endif - if (!have) continue; - - unsigned long long src_ll = 0, dst_ll = 0; - if (lane == 0) { - src_ll = (unsigned long long)task.src_ptr; - dst_ll = (unsigned long long)task.dst_ptr; - } -#ifndef __HIP_PLATFORM_AMD__ - src_ll = __shfl_sync(0xFFFFFFFF, src_ll, 0); - dst_ll = __shfl_sync(0xFFFFFFFF, dst_ll, 0); - size_t nbytes = __shfl_sync(0xFFFFFFFF, task.bytes, 0); -#else - src_ll = __shfl(src_ll, 0); - dst_ll = __shfl(dst_ll, 0); - size_t nbytes = __shfl(task.bytes, 0); -#endif - - char const* __restrict__ src = (char const*)src_ll; - char* __restrict__ dst = (char*)dst_ll; - -#if defined(DEBUG) || !defined(NDEBUG) - if (((uintptr_t)src & 0xF) || ((uintptr_t)dst & 0xF)) { - // rare – but don’t crash the whole grid - if (lane == 0) - for (size_t i = 0; i < nbytes; ++i) dst[i] = src[i]; - continue; - } -#endif - - size_t offset = lane * 16; - for (; offset + 127 < nbytes; offset += 32 * 16) - copy128(src + offset, dst + offset); - - if (lane == 0) { - for (size_t i = (nbytes & ~size_t(127)); i < nbytes; ++i) dst[i] = src[i]; - } - } -} - -__global__ void peer_copy_kernel_vec_persistent(HostToDeviceNVlinkBuffer* rb) -// Only one thread polls task, doesn't work. 
-{ - __shared__ CopyTask sm_task; - - while (true) { - if (threadIdx.x == 0) { - if (!rb->pop(sm_task)) sm_task.bytes = 0; - } - __syncthreads(); - if (sm_task.bytes == 0) { - continue; - } - - char const* __restrict__ src = static_cast(sm_task.src_ptr); - char* __restrict__ dst = static_cast(sm_task.dst_ptr); - size_t nbytes = sm_task.bytes; - - size_t offset = threadIdx.x * 16; - for (; offset + 127 < nbytes; offset += blockDim.x * 16) - copy128(src + offset, dst + offset); - - if (threadIdx.x == 0) { - for (size_t i = (nbytes & ~size_t(127)); i < nbytes; ++i) dst[i] = src[i]; - } - __syncthreads(); - } -} - -HostToDeviceNVlinkBuffer* initialize_ring_buffer_for_nvlink_forwarding( - gpuStream_t stream) { - HostToDeviceNVlinkBuffer* rb; - gpuError_t err = - gpuHostAlloc(reinterpret_cast(&rb), - sizeof(HostToDeviceNVlinkBuffer), gpuHostAllocMapped); - if (err != gpuSuccess) { - fprintf(stderr, "Error allocating ring buffer for NVLink forwarding: %s\n", - gpuGetErrorString(err)); - std::abort(); - } - - new (rb) HostToDeviceNVlinkBuffer{}; - constexpr int threads_per_block = 256; - dim3 blocks(NVLINK_SM_PER_PROCESS); - peer_copy_kernel_vec_many<<>>(rb); - err = gpuGetLastError(); - if (err != gpuSuccess) { - fprintf(stderr, "Error launching kernel for NVLink forwarding: %s\n", - gpuGetErrorString(err)); - std::abort(); - } - return rb; -} - -bool post_copy_task(HostToDeviceNVlinkBuffer* rb, CopyTask const* host_tasks, - int num_tasks, gpuStream_t stream, int src_device, - CopyTask*& d_tasks) { - uint64_t cur_head = rb->head; - uint64_t cur_tail = rb->volatile_tail(); - - int free_slots = rb->capacity - (cur_head - cur_tail); - if (free_slots < num_tasks) { - // printf( - // "Not enough free slots in ring buffer: %d available, %d requested, " - // "rb->capacity: %d, cur_head: %lu, cur_tail: %lu\n", - // free_slots, num_tasks, rb->capacity, cur_head, cur_tail); - return false; - } - for (int i = 0; i < num_tasks; ++i) { - rb->set_buffer(cur_head + i, host_tasks[i]); - } - rb->commit_with_head(cur_head + num_tasks); - return true; -} - -__device__ __forceinline__ void st_release_sys_global(int* ptr, int val) { -#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) - __atomic_store_n(ptr, val, __ATOMIC_RELEASE); -#else - asm volatile("st.release.sys.global.s32 [%0], %1;" ::"l"(ptr), "r"(val) - : "memory"); -#endif -} - -__global__ void read_and_set_sys(int* addr, int new_val) { - if (threadIdx.x == 0 && blockIdx.x == 0) { - printf("addr: %p\n", addr); - int old = *addr; - printf("Before st_release_sys_global %p, new: %d\n", addr, old + new_val); - st_release_sys_global(addr, old + new_val); - } -} - -extern "C" void launch_read_and_set_sys(int* addr, int new_val, - cudaStream_t stream) { - read_and_set_sys<<<1, 1, 0, stream>>>(addr, new_val); -} \ No newline at end of file diff --git a/ep/src/peer_copy_manager.cpp b/ep/src/peer_copy_manager.cpp deleted file mode 100644 index 1510064bf..000000000 --- a/ep/src/peer_copy_manager.cpp +++ /dev/null @@ -1,30 +0,0 @@ -#include "peer_copy_manager.hpp" -#include "uccl_proxy.hpp" - -PeerCopyManager::PeerCopyManager(int src_device) { - shared_.src_device = src_device; - shared_.run.store(true, std::memory_order_release); -} -PeerCopyManager::~PeerCopyManager() { stop(); } - -void PeerCopyManager::start_for_proxies( - std::vector const& proxies) { - int const n = static_cast(proxies.size()); - if (n <= 0) return; - ctxs_.resize(n); - threads_.reserve(n); - for (int i = 0; i < n; ++i) { - threads_.emplace_back(peer_copy_worker, std::ref(shared_), - 
std::ref(ctxs_[i]), - std::ref(proxies[i]->proxy_->ring), i); - } -} - -void PeerCopyManager::stop() { - if (threads_.empty()) return; - shared_.run.store(false, std::memory_order_release); - for (auto& t : threads_) - if (t.joinable()) t.join(); - threads_.clear(); - ctxs_.clear(); -} diff --git a/ep/src/peer_copy_worker.cpp b/ep/src/peer_copy_worker.cpp deleted file mode 100644 index dae7cf730..000000000 --- a/ep/src/peer_copy_worker.cpp +++ /dev/null @@ -1,86 +0,0 @@ -#include "peer_copy_worker.hpp" -#include "common.hpp" -#include "peer_copy.cuh" -#include "proxy.hpp" -#include "rdma.hpp" -#include - -void sync_and_post(PeerWorkerCtx& ctx, CopyRingBuffer& ring, - gpuStream_t& stream, int idx) { - if (ctx.async_memcpy_count > ctx.prev_completed_async_memcpy_count) { - gpuError_t err = gpuStreamSynchronize(stream); - if (err != gpuSuccess) { - fprintf(stderr, "Kernel execution failed: %s\n", gpuGetErrorString(err)); - std::abort(); - } - remote_send_ack((ProxyCtx*)ring.ctx, ring.ack_qp, ctx.highest_issued_wr_id, - ring.ack_mr, ring.ack_buf, idx); - ctx.prev_completed_async_memcpy_count = ctx.async_memcpy_count; - } -} - -void peer_copy_worker(PeerCopyShared& shared, PeerWorkerCtx& ctx, - CopyRingBuffer& ring, int idx) { - pin_thread_to_cpu(idx + 1 + MAIN_THREAD_CPU_IDX); - // TODO(MaoZiming): improves pinning. - printf("Peer copy worker %d started on CPU core %d\n", idx + 1, - sched_getcpu()); - gpuStream_t stream; - GPU_RT_CHECK(gpuSetDevice(shared.src_device)); - GPU_RT_CHECK(gpuStreamCreate(&stream)); - CopyTask* d_tasks; - GPU_RT_CHECK( - gpuMallocAsync(&d_tasks, RECEIVER_BATCH_SIZE * sizeof(CopyTask), stream)); - - while (shared.run.load(std::memory_order_acquire)) { - CopyTask t; - int copy_batch_size = 0; - if (RECEIVER_BATCH_SIZE == 1) { - if (!ring.pop(t)) { - continue; - } - copy_batch_size = 1; - ctx.tasks[0] = t; - } else { - int n = ring.popN(ctx.tasks, RECEIVER_BATCH_SIZE); - if (n == 0) { - continue; - } - t = ctx.tasks[0]; - copy_batch_size = n; - } - for (int i = 0; i < copy_batch_size; ++i) { - maybe_enable_peer_access(shared.src_device, ctx.tasks[i].dst_dev); - ctx.task_wrs[i] = ctx.tasks[i].wr_id; - } - ctx.highest_issued_wr_id = - std::max(ctx.highest_issued_wr_id, ctx.task_wrs[copy_batch_size - 1]); - // NOTE(MaoZiming): peer_copy.cu has some kernels such as - // launch_peer_bulk_copy2 that might be good. 
- gpuError_t err; - if (t.dst_ptr) { - err = - gpuMemcpyPeerAsync(t.dst_ptr, t.dst_dev, t.src_ptr, shared.src_device, - t.bytes * copy_batch_size, stream); - } - std::string func_name = "gpuMemcpyPeerAsync"; - if (err != gpuSuccess) { - fprintf(stderr, - "%s failed (%s) wr_id=%llu\n" - " dst_ptr=%p dst_dev=%d\n" - " src_ptr=%p src_dev=%d\n" - " size=%zu (bytes)\n" - " stream=%p\n", - func_name.c_str(), gpuGetErrorString(err), - static_cast(t.wr_id), t.dst_ptr, t.dst_dev, - t.src_ptr, shared.src_device, - static_cast(t.bytes * copy_batch_size), (void*)stream); - std::abort(); - } - ctx.async_memcpy_count += copy_batch_size; - sync_and_post(ctx, ring, stream, idx); - } - GPU_RT_CHECK(gpuFreeAsync(d_tasks, stream)); - GPU_RT_CHECK(gpuStreamSynchronize(stream)); - GPU_RT_CHECK(gpuStreamDestroy(stream)); -} \ No newline at end of file diff --git a/ep/src/proxy.cpp b/ep/src/proxy.cpp index 7248f030c..5f8c01dd4 100644 --- a/ep/src/proxy.cpp +++ b/ep/src/proxy.cpp @@ -3,7 +3,7 @@ #include "d2h_queue_host.hpp" #include "ep_util.hpp" #include "util/util.h" -#include // for htonl, ntohl +#include #include #include #include @@ -13,71 +13,6 @@ #include #include -#ifndef USE_SUBSET_BARRIER -static std::string shm_name_for_barrier(std::string const& ip, int thread_idx) { - // Include UID to avoid cross-user collisions on /dev/shm which cause EACCES - // when a leftover object is owned by a different user. - uid_t uid = getuid(); - return "/uccl_barrier_" + ip + "_uid" + std::to_string(uid) + "_th" + - std::to_string(thread_idx); -} - -LocalBarrier* map_local_barrier_shm(std::string const& name, bool* out_owner) { - *out_owner = false; - size_t const kSize = sizeof(LocalBarrier); - mode_t const kMode = 0600; - int fd = shm_open(name.c_str(), O_RDWR | O_CREAT | O_EXCL, kMode); - if (fd >= 0) { - *out_owner = true; - if (ftruncate(fd, kSize) != 0) { - perror("ftruncate(LocalBarrier)"); - close(fd); - shm_unlink(name.c_str()); - return nullptr; - } - } else { - if (errno != EEXIST) { - perror("shm_open"); - return nullptr; - } - fd = shm_open(name.c_str(), O_RDWR, kMode); - if (fd < 0) { - perror("shm_open(existing)"); - return nullptr; - } - struct stat st {}; - int tries = 1000; - while (tries-- > 0) { - if (fstat(fd, &st) == 0 && static_cast(st.st_size) >= kSize) - break; - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } - if (tries < 0) { - fprintf(stderr, - "map_local_barrier_shm: existing shm '%s' never sized to %zu\n", - name.c_str(), kSize); - close(fd); - return nullptr; - } - } - void* p = mmap(nullptr, kSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); - int saved_errno = errno; - close(fd); - if (p == MAP_FAILED) { - errno = saved_errno; - perror("mmap(LocalBarrier)"); - return nullptr; - } - return reinterpret_cast(p); -} - -void unmap_local_barrier_shm(std::string const& name, LocalBarrier* lb, - bool owner) { - if (lb) munmap(lb, sizeof(LocalBarrier)); - if (owner) shm_unlink(name.c_str()); -} -#endif - Proxy::Proxy(Config const& cfg) : cfg_(cfg) { // Initialize state tracking for each ring buffer listen_port_ = uccl::create_listen_socket(&listen_fd_); @@ -102,7 +37,6 @@ uint64_t Proxy::completed_wr() const { return completion_count_; } void Proxy::pin_thread_to_cpu_wrapper() { if (cfg_.pin_thread) { - // TODO(MaoZiming): improves pinning. 
pin_thread_to_cpu(cfg_.thread_idx + cfg_.local_rank * kNumProxyThs); int cpu = sched_getcpu(); if (cpu == -1) { @@ -121,11 +55,7 @@ void Proxy::pin_thread_to_numa_wrapper() { assert(ctx_.numa_node != -1); pin_thread_unique(ctx_.numa_node, cfg_.local_rank, cfg_.thread_idx, kNumProxyThs); - - // Get the actual CPU this thread is running on int cpu = sched_getcpu(); - - // Get the affinity mask (optional but useful) cpu_set_t cpuset; CPU_ZERO(&cpuset); pthread_getaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset); @@ -164,11 +94,27 @@ void Proxy::set_bench_d2h_channel_addrs(std::vector const& addrs) { for (auto addr : addrs) { d2hq::HostD2HHandle h{}; - d2hq::init_from_addr(h, addr); // unified initialization + d2hq::init_from_addr(h, addr); cfg_.d2h_queues.push_back(h); } } +static void maybe_modify_qp_to_init(ProxyCtx& c) { +#ifndef EFA + modify_qp_to_init(c); +#else + (void)c; +#endif +} + +static void maybe_post_receive_buffer_for_imm(ProxyCtx& ctx) { +#ifndef EFA + post_receive_buffer_for_imm(ctx); +#else + (void)ctx; +#endif +} + void Proxy::init_common() { int const my_rank = cfg_.rank; @@ -222,12 +168,13 @@ void Proxy::init_common() { if (peer == my_rank) continue; // Skip rdma connection for intra-node. if (peers_[peer].ip == peers_[my_rank].ip) continue; - if (cfg_.use_normal_mode && std::abs(peer - my_rank) % MAX_NUM_GPUS != 0) + if (cfg_.use_throughput_mode && + std::abs(peer - my_rank) % MAX_NUM_GPUS != 0) continue; create_per_thread_qp(c, cfg_.gpu_buffer, cfg_.total_size, &local_infos_[peer], my_rank, cfg_.d2h_queues.size(), - cfg_.use_normal_mode); - modify_qp_to_init(c); + cfg_.use_throughput_mode); + maybe_modify_qp_to_init(c); } usleep(50 * 1000); @@ -237,7 +184,7 @@ void Proxy::init_common() { for (int peer = 0; peer < num_ranks; ++peer) { // Skip rdma connection for intra-node. if (peer == my_rank || peers_[peer].ip == peers_[my_rank].ip || - (cfg_.use_normal_mode && + (cfg_.use_throughput_mode && std::abs(peer - my_rank) % MAX_NUM_GPUS != 0)) continue; int actual_peer; @@ -249,7 +196,8 @@ void Proxy::init_common() { // Then send our info to all peers for (int peer = 0; peer < num_ranks; ++peer) { if (peer == my_rank || peers_[peer].ip == peers_[my_rank].ip || - (cfg_.use_normal_mode && std::abs(peer - my_rank) % MAX_NUM_GPUS != 0)) + (cfg_.use_throughput_mode && + std::abs(peer - my_rank) % MAX_NUM_GPUS != 0)) continue; char const* peer_ip = peers_[peer].ip.c_str(); int const peer_listen_port = peers_[peer].listen_ports[cfg_.thread_idx]; @@ -263,7 +211,8 @@ void Proxy::init_common() { // Verify remote info correctness for (int peer = 0; peer < num_ranks; ++peer) { if (peer == my_rank || peers_[peer].ip == peers_[my_rank].ip || - (cfg_.use_normal_mode && std::abs(peer - my_rank) % MAX_NUM_GPUS != 0)) + (cfg_.use_throughput_mode && + std::abs(peer - my_rank) % MAX_NUM_GPUS != 0)) continue; if (remote_infos_[peer].addr != peers_[peer].ptr) { fprintf(stderr, @@ -280,12 +229,13 @@ void Proxy::init_common() { if (peer == my_rank) continue; // Skip rdma connection for intra-node. if (peers_[peer].ip == peers_[my_rank].ip) continue; - if (cfg_.use_normal_mode && std::abs(peer - my_rank) % MAX_NUM_GPUS != 0) + if (cfg_.use_throughput_mode && + std::abs(peer - my_rank) % MAX_NUM_GPUS != 0) continue; auto& c = *ctxs_for_all_ranks_[peer]; // qp is different from each rank. 
- modify_qp_to_rtr(c, &remote_infos_[peer], cfg_.use_normal_mode); + modify_qp_to_rtr(c, &remote_infos_[peer], cfg_.use_throughput_mode); modify_qp_to_rts(c, &local_infos_[peer]); c.remote_addr = remote_infos_[peer].addr; @@ -300,11 +250,7 @@ void Proxy::init_common() { } } usleep(50 * 1000); - if (cfg_.use_normal_mode) { - // if (cfg_.thread_idx != 0) { - // return; - // } - // Discover local ranks (same IP as me) + if (cfg_.use_throughput_mode) { std::string const my_ip = peers_[cfg_.rank].ip; std::vector local_ranks; local_ranks.reserve(ctxs_for_all_ranks_.size()); @@ -325,26 +271,6 @@ void Proxy::init_common() { ctx_.num_local_ranks, (int)UCCL_MAX_LOCAL_RANKS); std::abort(); } -#ifndef USE_SUBSET_BARRIER - std::string const shm_name = shm_name_for_barrier(my_ip, cfg_.thread_idx); - ctx_.lb = map_local_barrier_shm(shm_name, &ctx_.lb_owner); - if (!ctx_.lb) { - fprintf(stderr, "Failed to map local barrier shm: %s\n", - shm_name.c_str()); - std::abort(); - } - if (ctx_.lb_owner) { - ctx_.lb->full_mask = (ctx_.num_local_ranks >= 64) - ? ~0ULL - : ((1ULL << ctx_.num_local_ranks) - 1ULL); - for (int i = 0; i < ctx_.num_local_ranks; ++i) { - ctx_.lb->arrive_seq[i].store(0, std::memory_order_relaxed); - ctx_.lb->release_seq[i].store(0, std::memory_order_relaxed); - } - } else { - while (ctx_.lb->full_mask == 0ULL) cpu_relax(); - } -#endif } #ifdef USE_MSCCLPP_FIFO_BACKEND @@ -367,9 +293,7 @@ void Proxy::init_remote() { local_post_ack_buf(*ctx_ptr, kSenderAckQueueDepth); remote_reg_ack_buf(ctx_ptr->pd, ring.ack_buf, ring.ack_mr); ring.ack_qp = ctx_ptr->ack_qp; -#ifndef EFA - post_receive_buffer_for_imm(*ctx_ptr); -#endif + maybe_post_receive_buffer_for_imm(*ctx_ptr); } void Proxy::run_sender() { @@ -392,13 +316,11 @@ void Proxy::run_remote() { remote_poll_completions(ctx_, cfg_.thread_idx, ring, ctx_by_tag_, atomic_buffer_ptr_, cfg_.num_ranks, cfg_.num_experts, pending_atomic_updates, cfg_.rank, - cfg_.num_nodes, cfg_.use_normal_mode); -#ifdef USE_RECEIVER_BARRIER - if (!cfg_.use_normal_mode) { + cfg_.num_nodes, cfg_.use_throughput_mode); + if (!cfg_.use_throughput_mode) { apply_pending_updates(ctx_, pending_atomic_updates, atomic_buffer_ptr_, cfg_.num_experts, cfg_.num_ranks); } -#endif } } @@ -407,16 +329,15 @@ void Proxy::run_dual() { for (int peer = 0; peer < (int)ctxs_for_all_ranks_.size(); ++peer) { if (peer == cfg_.rank) continue; if (peers_[peer].ip == peers_[cfg_.rank].ip) continue; - if (cfg_.use_normal_mode && std::abs(peer - cfg_.rank) % MAX_NUM_GPUS != 0) + if (cfg_.use_throughput_mode && + std::abs(peer - cfg_.rank) % MAX_NUM_GPUS != 0) continue; auto& ctx_ptr = ctxs_for_all_ranks_[peer]; if (!ctx_ptr) continue; local_post_ack_buf(*ctx_ptr, kSenderAckQueueDepth); remote_reg_ack_buf(ctx_ptr->pd, ring.ack_buf, ring.ack_mr); ring.ack_qp = ctx_ptr->ack_qp; -#ifndef EFA - post_receive_buffer_for_imm(*ctx_ptr); -#endif + maybe_post_receive_buffer_for_imm(*ctx_ptr); } uint64_t my_tail = 0; size_t seen = 0; @@ -425,29 +346,15 @@ void Proxy::run_dual() { poll_cq_dual(ctx_, acked_wrs_, cfg_.thread_idx, ring, ctx_by_tag_, atomic_buffer_ptr_, cfg_.num_ranks, cfg_.num_experts, pending_atomic_updates, cfg_.rank, cfg_.num_nodes, - cfg_.use_normal_mode); + cfg_.use_throughput_mode); notify_gpu_completion(my_tail); post_gpu_command(my_tail, seen); -#ifdef USE_RECEIVER_BARRIER - if (!cfg_.use_normal_mode) { + if (!cfg_.use_throughput_mode) { apply_pending_updates(ctx_, pending_atomic_updates, atomic_buffer_ptr_, cfg_.num_experts, cfg_.num_ranks); } -#endif - -#ifdef USE_SENDER_BARRIER - if 
(!cfg_.use_normal_mode) { - auto postponed_wr_ids = postponed_wr_ids_; - auto postponed_atomics = postponed_atomics_; - postponed_wr_ids_.clear(); - postponed_atomics_.clear(); - assert(postponed_wr_ids.size() == postponed_atomics.size()); - assert(postponed_wr_ids_.size() == 0); - post_gpu_commands_mixed(postponed_wr_ids, postponed_atomics); - } -#endif - if (cfg_.use_normal_mode) { + if (cfg_.use_throughput_mode) { barrier_check(); } } @@ -512,22 +419,19 @@ void Proxy::post_gpu_command(uint64_t& my_tail, size_t& seen) { for (size_t rb_idx = 0; rb_idx < cfg_.d2h_queues.size(); rb_idx++) { d2hq::HostD2HHandle* h = &cfg_.d2h_queues[rb_idx]; #ifdef USE_MSCCLPP_FIFO_BACKEND - assert(h && "h is empty!\n"); - assert(h->fifo && "h->fifo is empty!\n"); // FIFO path: one trigger == one command. Do NOT pop yet. auto* fifo = h->fifo; if (!fifo) continue; // Available budget for this FIFO. size_t pending = fifo_pending_[rb_idx].size(); size_t kMaxInflight = - cfg_.use_normal_mode ? kMaxInflightNormal : kMaxInflightLowLatency; + cfg_.use_throughput_mode ? kMaxInflightNormal : kMaxInflightLowLatency; size_t budget = (kMaxInflight > pending) ? (kMaxInflight - pending) : 0; for (size_t take = 0; take < budget; ++take) { auto trig = fifo->poll(); if (trig.fst == 0) break; TransferCmd cmd = d2hq::decode_from_trigger(trig); - - /* For some reason, this is important for correctness */ + /* This is important for correctness */ /* It cannot be if (ctx_.barrier_inflight) */ if (get_base_cmd(cmd.cmd_type) == CmdType::BARRIER && ctx_.barrier_inflight) { @@ -617,10 +521,6 @@ void Proxy::post_gpu_command(uint64_t& my_tail, size_t& seen) { uint64_t unique_wr_id = (rb_idx << 32) | i; wrs_to_post.push_back(unique_wr_id); cmds_to_post.push_back(cmd_entry); -#ifdef MEASURE_PER_VERB_LATENCY - wr_id_to_start_time_[unique_wr_id] = - std::chrono::high_resolution_clock::now(); -#endif ring_seen = i + 1; } #endif @@ -645,109 +545,6 @@ void Proxy::post_gpu_command(uint64_t& my_tail, size_t& seen) { cmds_to_post.clear(); } -void Proxy::run_local() { - pin_thread_to_cpu_wrapper(); - printf("Local CPU thread %d started with %zu ring buffers\n", cfg_.thread_idx, - cfg_.d2h_queues.size()); - - if (cfg_.d2h_queues.empty()) { - printf("Error: No ring buffers available for local mode\n"); - return; - } - - int total_seen = 0; - while (true) { - if (!ctx_.progress_run.load(std::memory_order_acquire)) { - printf("Local thread %d stopping early at total_seen=%d\n", - cfg_.thread_idx, total_seen); - return; - } - - bool found_work = false; - - // Multi-ring buffer polling (consistent with other modes) - for (size_t rb_idx = 0; rb_idx < cfg_.d2h_queues.size(); rb_idx++) { - d2hq::HostD2HHandle* h = &cfg_.d2h_queues[rb_idx]; -#ifdef USE_MSCCLPP_FIFO_BACKEND - auto* fifo = h->fifo; - if (!fifo) continue; - auto trig = fifo->poll(); - if (trig.fst != 0) { - d2hq::decode_from_trigger(trig); - fifo->pop(); - total_seen++; - found_work = true; - } -#else - uint64_t& ring_tail = ring_tails_[rb_idx]; - - // Check for new work in this ring buffer - uint64_t cur_head = h->volatile_head(); - if (cur_head == ring_tail) { - continue; // No new work in this ring - } - - // Process commands from this ring buffer - while (ring_tail < cur_head) { - uint64_t const idx = ring_tail & kQueueMask; - CmdType cmd; - auto last_print = std::chrono::steady_clock::now(); - size_t spin_count = 0; - do { - cmd = h->volatile_load_cmd_type(idx); - cpu_relax(); - - auto now = std::chrono::steady_clock::now(); - if (now - last_print > std::chrono::seconds(10)) { - printf( 
- "Still waiting at thread %d, ring %zu, total_seen=%d, " - "spin_count=%zu, ring_tail=%lu, cmd: %d\n", - cfg_.thread_idx, rb_idx, total_seen, spin_count, ring_tail, - static_cast(cmd)); - last_print = now; - spin_count++; - } - - if (!ctx_.progress_run.load(std::memory_order_acquire)) { - printf("Local thread %d stopping early at total_seen=%d\n", - cfg_.thread_idx, total_seen); - return; - } - } while (cmd == CmdType::EMPTY); - -#ifdef DEBUG_PRINT - printf( - "Local thread %d, ring %zu, total_seen=%d head=%lu tail=%lu " - "consuming cmd=%llu\n", - cfg_.thread_idx, rb_idx, total_seen, h->head, ring_tail, - static_cast(cmd)); -#endif - - std::atomic_thread_fence(std::memory_order_acquire); - - // Mark command as processed - h->volatile_clear_cmd_type(idx); - ring_tail++; - h->cpu_volatile_store_tail(ring_tail); - total_seen++; - found_work = true; - - // Break to check other ring buffers and progress_run flag - break; - } -#endif - } - - // If no work found across all ring buffers, relax CPU - if (!found_work) { - cpu_relax(); - } - } - - printf("Local thread %d finished %d commands across %zu ring buffers\n", - cfg_.thread_idx, total_seen, cfg_.d2h_queues.size()); -} - void Proxy::post_gpu_commands_mixed( std::vector const& wrs_to_post, std::vector const& cmds_to_post) { @@ -758,62 +555,8 @@ void Proxy::post_gpu_commands_mixed( for (size_t i = 0; i < cmds_to_post.size(); ++i) { switch (get_base_cmd(cmds_to_post[i].cmd_type)) { case (CmdType::ATOMIC): { -#ifdef USE_SENDER_BARRIER - if (!cfg_.use_normal_mode) { - int value = cmds_to_post[i].value; - uint32_t offset = static_cast(cmds_to_post[i].req_rptr); - uint32_t new_offset = - offset - get_low_latency(cmds_to_post[i].cmd_type) * - align(cfg_.num_experts * sizeof(int), 128); - size_t new_index = new_offset / sizeof(int); - int expected_value; - int expert_idx; - - if (get_is_combine(cmds_to_post[i].cmd_type)) { - expert_idx = new_index; - expected_value = ctx_.combine_sent_counter.Get( - {get_low_latency(cmds_to_post[i].cmd_type), expert_idx, - cmds_to_post[i].dst_rank}); - } else { - expert_idx = new_index / cfg_.num_ranks; - expected_value = ctx_.dispatch_sent_counter.Get( - {get_low_latency(cmds_to_post[i].cmd_type), expert_idx, - cmds_to_post[i].dst_rank}); - value = -value - 1; - } - if (value != expected_value) { - postponed_atomics_.push_back(cmds_to_post[i]); - postponed_wr_ids_.push_back(wrs_to_post[i]); - assert(postponed_atomics_.size() == postponed_wr_ids_.size()); - continue; - } - } -#endif - atomic_wrs.push_back(wrs_to_post[i]); atomic_cmds.push_back(cmds_to_post[i]); - -#ifdef USE_SENDER_BARRIER - if (!cfg_.use_normal_mode) { - uint32_t offset = static_cast(cmds_to_post[i].req_rptr); - uint32_t new_offset = - offset - get_low_latency(cmds_to_post[i].cmd_type) * - align(cfg_.num_experts * sizeof(int), 128); - size_t new_index = new_offset / sizeof(int); - int expert_idx; - if (get_is_combine(cmds_to_post[i].cmd_type)) { - expert_idx = new_index; - ctx_.combine_sent_counter.Reset( - {get_low_latency(cmds_to_post[i].cmd_type), expert_idx, - cmds_to_post[i].dst_rank}); - } else { - expert_idx = new_index / cfg_.num_ranks; - ctx_.dispatch_sent_counter.Reset( - {get_low_latency(cmds_to_post[i].cmd_type), expert_idx, - cmds_to_post[i].dst_rank}); - } - } -#endif break; } case (CmdType::WRITE): { @@ -847,7 +590,7 @@ void Proxy::post_gpu_commands_mixed( if (!rdma_wrs.empty()) { post_rdma_async_batched(ctx_, cfg_.gpu_buffer, rdma_wrs.size(), rdma_wrs, rdma_cmds, ctxs_for_all_ranks_, cfg_.rank, - cfg_.thread_idx, 
cfg_.use_normal_mode); + cfg_.thread_idx, cfg_.use_throughput_mode); rdma_wrs.clear(); rdma_cmds.clear(); } @@ -855,24 +598,18 @@ void Proxy::post_gpu_commands_mixed( if (!atomic_wrs.empty()) { post_atomic_operations(ctx_, atomic_wrs, atomic_cmds, ctxs_for_all_ranks_, cfg_.rank, cfg_.thread_idx, acked_wrs_, - cfg_.use_normal_mode); + cfg_.use_throughput_mode); atomic_wrs.clear(); atomic_cmds.clear(); } if (!barrier_cmds.empty()) { -#ifdef USE_MSCCLPP_FIFO_BACKEND - assert(barrier_wrs.size() == 1 && ctx_.barrier_wr == -1); -#endif send_barrier(barrier_wrs[0]); barrier_wrs.clear(); barrier_cmds.clear(); } if (!quiet_cmds.empty()) { -#ifdef USE_MSCCLPP_FIFO_BACKEND - assert(quiet_wrs.size() == 1 && ctx_.quiet_wr == -1); -#endif ctx_.quiet_wr = quiet_wrs[0]; quiet(quiet_wrs, quiet_cmds); quiet_wrs.clear(); @@ -904,13 +641,11 @@ void Proxy::quiet_cq() { remote_process_completions( ctx_, cfg_.thread_idx, ring, ne, wc, ctx_by_tag_, atomic_buffer_ptr_, cfg_.num_ranks, cfg_.num_experts, pending_atomic_updates, cfg_.rank, - cfg_.num_nodes, cfg_.use_normal_mode); -#ifdef USE_RECEIVER_BARRIER - if (!cfg_.use_normal_mode) { + cfg_.num_nodes, cfg_.use_throughput_mode); + if (!cfg_.use_throughput_mode) { apply_pending_updates(ctx_, pending_atomic_updates, atomic_buffer_ptr_, cfg_.num_experts, cfg_.num_ranks); } -#endif } else { ++empty_iters; } @@ -1012,15 +747,6 @@ void Proxy::destroy(bool free_gpu_buffer) { ibv_close_device(ctx_.context); ctx_.context = nullptr; } -#ifndef USE_SUBSET_BARRIER - std::string const my_ip = - (cfg_.rank < (int)peers_.size()) ? peers_[cfg_.rank].ip : ""; - std::string const shm_name = shm_name_for_barrier(my_ip, cfg_.thread_idx); - unmap_local_barrier_shm(shm_name, ctx_.lb, ctx_.lb_owner); - ctx_.lb = nullptr; - ctx_.lb_owner = false; -#endif - acked_wrs_.clear(); wr_id_to_start_time_.clear(); ctxs_for_all_ranks_.clear(); @@ -1029,13 +755,7 @@ void Proxy::destroy(bool free_gpu_buffer) { remote_infos_.clear(); } -void Proxy::post_barrier_msg(int dst_rank, bool ack, uint64_t seq) { - ProxyCtx* ctx = ctxs_for_all_ranks_[dst_rank].get(); - if (!ctx || !ctx->qp || !ctx->mr) { - fprintf(stderr, "barrier_msg: bad ctx for dst=%d\n", dst_rank); - std::abort(); - } - uint32_t imm = BarrierImm::Pack(ack, (uint32_t)seq, (uint8_t)cfg_.rank); +static void post_barrier_msg_efa(ProxyCtx* ctx, uint32_t imm) { #ifdef EFA auto* qpx = (struct ibv_qp_ex*)ctx->qp; int barrier_seq = 0; @@ -1054,6 +774,13 @@ void Proxy::post_barrier_msg(int dst_rank, bool ack, uint64_t seq) { std::abort(); } #else + (void)ctx; + (void)imm; +#endif +} + +static void post_barrier_msg_non_efa(ProxyCtx* ctx, uint32_t imm) { +#ifndef EFA ibv_sge sge{}; sge.addr = (uintptr_t)ctx->mr->addr; sge.length = 0; @@ -1078,6 +805,23 @@ void Proxy::post_barrier_msg(int dst_rank, bool ack, uint64_t seq) { fprintf(stderr, " bad wr_id=%llu\n", (unsigned long long)bad->wr_id); std::abort(); } +#else + (void)ctx; + (void)imm; +#endif +} + +void Proxy::post_barrier_msg(int dst_rank, bool ack, uint64_t seq) { + ProxyCtx* ctx = ctxs_for_all_ranks_[dst_rank].get(); + if (!ctx || !ctx->qp || !ctx->mr) { + fprintf(stderr, "barrier_msg: bad ctx for dst=%d\n", dst_rank); + std::abort(); + } + uint32_t imm = BarrierImm::Pack(ack, (uint32_t)seq, (uint8_t)cfg_.rank); +#ifdef EFA + post_barrier_msg_efa(ctx, imm); +#else + post_barrier_msg_non_efa(ctx, imm); #endif } @@ -1096,29 +840,16 @@ void Proxy::send_barrier(uint64_t wr) { ctx_.barrier_arrival_count = 0; } } -#ifndef USE_SUBSET_BARRIER - auto* lb = ctx_.lb; - lb->seq.store(ctx_.barrier_seq, 
std::memory_order_release); - lb->arrive_seq[ctx_.local_rank].store((uint32_t)ctx_.barrier_seq, - std::memory_order_release); - - uint64_t bit = (1ULL << (uint64_t)ctx_.local_rank); - lb->arrived_mask.fetch_or(bit, std::memory_order_acq_rel); -#endif } -#ifdef USE_SUBSET_BARRIER void Proxy::barrier_check() { if (!ctx_.barrier_inflight) return; uint64_t const seq = ctx_.barrier_seq; - // Node leader aggregates local arrivals static thread_local uint64_t last_sent_seq = 0; if (last_sent_seq != seq) { last_sent_seq = seq; if (cfg_.node_idx == 0) { - // Global leader: mark self-arrival; remote arrivals will come via - // your existing CQ handler. if (ctx_.barrier_arrived.empty()) { ctx_.barrier_arrived.assign(ctxs_for_all_ranks_.size(), 0); ctx_.barrier_arrival_count = 0; @@ -1178,106 +909,3 @@ void Proxy::barrier_check() { #endif } } -#else - -void Proxy::barrier_check() { - if (!ctx_.barrier_inflight) return; - - auto* lb = ctx_.lb; - uint64_t const seq = ctx_.barrier_seq; - - // Node leader aggregates local arrivals - if (cfg_.rank == ctx_.node_leader_rank) { - bool all_local_arrived = true; - for (int lr = 0; lr < ctx_.num_local_ranks; ++lr) { - uint32_t seen = lb->arrive_seq[lr].load(std::memory_order_acquire); - if ((uint32_t)seen != seq) { - all_local_arrived = false; - break; - } - } - if (all_local_arrived) { - static thread_local uint64_t last_sent_seq = 0; - if (last_sent_seq != seq) { - last_sent_seq = seq; - if (cfg_.rank == 0) { - // Global leader: mark self-arrival; remote arrivals will come via - // your existing CQ handler. - if (ctx_.barrier_arrived.empty()) { - ctx_.barrier_arrived.assign(ctxs_for_all_ranks_.size(), 0); - ctx_.barrier_arrival_count = 0; - } - if (!ctx_.barrier_arrived[0]) { - ctx_.barrier_arrived[0] = 1; - ++ctx_.barrier_arrival_count; - } - } else { - post_barrier_msg(/*dst=*/0, /*ack=*/false, seq); - } - } - - if (cfg_.rank == 0) { - if (ctx_.barrier_arrival_count == cfg_.num_nodes) { - std::unordered_map leader_for_ip; - for (int r = 0; r < (int)peers_.size(); ++r) { - auto it = leader_for_ip.find(peers_[r].ip); - if (it == leader_for_ip.end() || r < it->second) { - assert(r % MAX_NUM_GPUS == 0); - leader_for_ip[peers_[r].ip] = r; - } - } - for (auto const& kv : leader_for_ip) { - std::string const& ip = kv.first; - int leader_r = kv.second; - if (ip == peers_[0].ip) continue; - post_barrier_msg(leader_r, true, seq); - } - - for (int lr = 0; lr < ctx_.num_local_ranks; ++lr) { - lb->release_seq[lr].store(seq, std::memory_order_release); - } - ctx_.barrier_arrived.clear(); - ctx_.barrier_arrival_count = 0; - - acked_wrs_.insert(ctx_.barrier_wr); -#ifndef USE_MSCCLPP_FIFO_BACKEND - ctx_.barrier_inflight = false; - ctx_.barrier_wr = -1; -#endif - return; - } - } - - // When global release comes back (CQ handler should set these): - if (ctx_.barrier_released && ctx_.barrier_release_seq == seq) { - // Fan-out to local ranks via shared memory - for (int lr = 0; lr < ctx_.num_local_ranks; ++lr) { - lb->release_seq[lr].store(seq, std::memory_order_release); - } - // Reset local mask for next barrier and consume the global release - ctx_.barrier_released = false; - - // Complete WR - acked_wrs_.insert(ctx_.barrier_wr); -#ifndef USE_MSCCLPP_FIFO_BACKEND - ctx_.barrier_inflight = false; - ctx_.barrier_wr = -1; -#endif - } - } - return; - } else { - assert(!ctx_.barrier_released && - "This can only be set by local leader thread."); - } - - // Followers: wait until leader sets our release_seq - if (lb->release_seq[ctx_.local_rank].load(std::memory_order_acquire) == 
seq) { - acked_wrs_.insert(ctx_.barrier_wr); -#ifndef USE_MSCCLPP_FIFO_BACKEND - ctx_.barrier_inflight = false; - ctx_.barrier_wr = -1; -#endif - } -} -#endif \ No newline at end of file diff --git a/ep/src/rdma.cpp b/ep/src/rdma.cpp index 78a66c220..a16f5ac14 100644 --- a/ep/src/rdma.cpp +++ b/ep/src/rdma.cpp @@ -1,7 +1,5 @@ #include "rdma.hpp" #include "common.hpp" -#include "peer_copy.cuh" -#include "peer_copy_worker.hpp" #include "proxy_ctx.hpp" #include "rdma_util.hpp" #include "util/gpu_rt.h" @@ -96,6 +94,61 @@ void send_connection_info_as_client(int my_rank, int peer, char const* peer_ip, close(sockfd); } +// Helper functions to extract EFA-specific logic +static bool should_include_nic_candidate( + std::pair const& p, uint32_t min_d) { +#ifdef EFA + return (p.second == min_d && strncmp(p.first.c_str(), "rdmap", 5) == 0); +#else + if (!uccl::is_iface_up(p.first)) return false; + return (p.second == min_d); +#endif +} + +static void select_nic_for_efa(std::vector const& candidates, + int thread_idx, int local_rank, + std::string& selected_nic_name) { +#ifdef EFA + // NOTE(MaoZiming): This is a temporary hack. + if (candidates.size() == 8) { + // On p5, there are 8 NICs with the same distance. + auto half = (local_rank % 2) * 4; + // GPU0 uses candidates[0/1/2/3], GPU1 uses candidates[4/5/6/7], etc. + selected_nic_name = candidates[thread_idx % 4 + half]; + use_ll_sl = true; + } else if (candidates.size() == 4) { + // On p5e/p5en, there are 4 NICs with the same distance. + // We hardcode the first half Proxies to use the first NIC, and the + // second half to use the second NIC. + auto half = (local_rank % 2) * 2; + // GPU0 uses candidates[0/1], GPU1 uses candidates[2/3], etc. + selected_nic_name = candidates[thread_idx % 2 + half]; + use_ll_sl = true; + } else { + // On p6-b200, there is 2 NICs with the same distance. + assert(candidates.size() == 2); + auto half = (local_rank % 2) * 1; + selected_nic_name = candidates[thread_idx % 1 + half]; + use_ll_sl = true; + } +#else + (void)candidates; + (void)thread_idx; + (void)local_rank; + (void)selected_nic_name; +#endif +} + +static uint64_t get_mr_access_flags() { +#ifdef EFA + return IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | + IBV_ACCESS_RELAXED_ORDERING; +#else + return IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | + IBV_ACCESS_REMOTE_ATOMIC | IBV_ACCESS_RELAXED_ORDERING; +#endif +} + void per_thread_rdma_init(ProxyCtx& S, void* gpu_buf, size_t bytes, int rank, int thread_idx, int local_rank) { if (S.context) return; // already initialized @@ -153,13 +206,9 @@ void per_thread_rdma_init(ProxyCtx& S, void* gpu_buf, size_t bytes, int rank, // Collect all NICs with equal minimum distance std::vector candidates; for (auto& p : dist) { -#ifdef EFA - if (p.second == min_d && strncmp(p.first.c_str(), "rdmap", 5) == 0) + if (should_include_nic_candidate(p, min_d)) { candidates.push_back(p.first); -#else - if (!uccl::is_iface_up(p.first)) continue; - if (p.second == min_d) candidates.push_back(p.first); -#endif + } } if (candidates.empty()) { @@ -169,30 +218,7 @@ void per_thread_rdma_init(ProxyCtx& S, void* gpu_buf, size_t bytes, int rank, // Spread GPUs across equal-distance NICs: use local GPU index modulo // For example, pass in `local_rank` or derive gpu_index from device path selected_nic_name = candidates[thread_idx % candidates.size()]; -#ifdef EFA - // NOTE(MaoZiming): This is a temporary hack. - if (candidates.size() == 8) { - // On p5, there are 8 NICs with the same distance. 
- auto half = (local_rank % 2) * 4; - // GPU0 uses candidates[0/1/2/3], GPU1 uses candidates[4/5/6/7], etc. - selected_nic_name = candidates[thread_idx % 4 + half]; - use_ll_sl = true; - } else if (candidates.size() == 4) { - // On p5e/p5en, there are 4 NICs with the same distance. - // We hardcode the first half Proxies to use the first NIC, and the - // second half to use the second NIC. - auto half = (local_rank % 2) * 2; - // GPU0 uses candidates[0/1], GPU1 uses candidates[2/3], etc. - selected_nic_name = candidates[thread_idx % 2 + half]; - use_ll_sl = true; - } else { - // On p6-b200, there is 2 NICs with the same distance. - assert(candidates.size() == 2); - auto half = (local_rank % 2) * 1; - selected_nic_name = candidates[thread_idx % 1 + half]; - use_ll_sl = true; - } -#endif + select_nic_for_efa(candidates, thread_idx, local_rank, selected_nic_name); } } @@ -226,16 +252,7 @@ void per_thread_rdma_init(ProxyCtx& S, void* gpu_buf, size_t bytes, int rank, exit(1); } uint64_t iova = (uintptr_t)gpu_buf; -#ifndef EFA - S.mr = ibv_reg_mr_iova2(S.pd, gpu_buf, bytes, iova, - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | - IBV_ACCESS_REMOTE_ATOMIC | - IBV_ACCESS_RELAXED_ORDERING); -#else - S.mr = ibv_reg_mr_iova2(S.pd, gpu_buf, bytes, iova, - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | - IBV_ACCESS_RELAXED_ORDERING); -#endif + S.mr = ibv_reg_mr_iova2(S.pd, gpu_buf, bytes, iova, get_mr_access_flags()); if (!S.mr) { perror("ibv_reg_mr failed"); @@ -250,8 +267,7 @@ void per_thread_rdma_init(ProxyCtx& S, void* gpu_buf, size_t bytes, int rank, S.rkey = S.mr->rkey; } -ibv_cq* create_per_thread_cq(ProxyCtx& S) { - int cq_depth = kMaxOutstandingSends * 2; +static ibv_cq* create_cq_efa(ProxyCtx& S, int cq_depth) { #ifdef EFA struct ibv_cq_init_attr_ex cq_ex_attr = {}; cq_ex_attr.cqe = cq_depth; @@ -266,11 +282,32 @@ ibv_cq* create_per_thread_cq(ProxyCtx& S) { // See `efa_create_cq_ex` in rdma-core. 
cq_ex_attr.wc_flags = IBV_WC_STANDARD_FLAGS; - S.cq = (struct ibv_cq*)ibv_create_cq_ex(S.context, &cq_ex_attr); + return (struct ibv_cq*)ibv_create_cq_ex(S.context, &cq_ex_attr); +#else + (void)S; + (void)cq_depth; + return nullptr; +#endif +} + +static ibv_cq* create_cq_non_efa(ProxyCtx& S, int cq_depth) { +#ifndef EFA + return ibv_create_cq(S.context, /* cqe */ cq_depth, + /* user_context */ nullptr, + /* channel */ nullptr, /* comp_vector */ 0); +#else + (void)S; + (void)cq_depth; + return nullptr; +#endif +} + +ibv_cq* create_per_thread_cq(ProxyCtx& S) { + int cq_depth = kMaxOutstandingSends * 2; +#ifdef EFA + S.cq = create_cq_efa(S, cq_depth); #else - S.cq = - ibv_create_cq(S.context, /* cqe */ cq_depth, /* user_context */ nullptr, - /* channel */ nullptr, /* comp_vector */ 0); + S.cq = create_cq_non_efa(S, cq_depth); #endif if (!S.cq) { perror("Failed to create CQ"); @@ -350,23 +387,24 @@ struct ibv_qp* create_srd_qp_ex(ProxyCtx& S) { } #endif -void create_per_thread_qp(ProxyCtx& S, void* gpu_buffer, size_t size, - RDMAConnectionInfo* local_info, int rank, - size_t num_rings, bool use_normal_mode) { - if (S.qp) return; // Already initialized for this thread - if (S.ack_qp) return; - if (S.recv_ack_qp) return; +static void create_qps_efa(ProxyCtx& S) { #ifdef EFA S.qp = create_srd_qp_ex(S); S.ack_qp = create_srd_qp_ex(S); S.recv_ack_qp = create_srd_qp_ex(S); #else + (void)S; +#endif +} + +static void create_qps_non_efa(ProxyCtx& S) { +#ifndef EFA struct ibv_qp_init_attr qp_init_attr = {}; qp_init_attr.send_cq = S.cq; qp_init_attr.recv_cq = S.cq; - qp_init_attr.qp_type = IBV_QPT_RC; // Reliable Connection - qp_init_attr.cap.max_send_wr = kMaxOutstandingSends; // max outstanding sends - qp_init_attr.cap.max_recv_wr = kMaxOutstandingSends; // max outstanding recvs + qp_init_attr.qp_type = IBV_QPT_RC; + qp_init_attr.cap.max_send_wr = kMaxOutstandingSends; + qp_init_attr.cap.max_recv_wr = kMaxOutstandingSends; qp_init_attr.cap.max_send_sge = 1; qp_init_attr.cap.max_recv_sge = 1; qp_init_attr.sq_sig_all = 0; @@ -386,9 +424,24 @@ void create_per_thread_qp(ProxyCtx& S, void* gpu_buffer, size_t size, perror("Failed to create Receive Ack QP"); exit(1); } +#else + (void)S; +#endif +} + +void create_per_thread_qp(ProxyCtx& S, void* gpu_buffer, size_t size, + RDMAConnectionInfo* local_info, int rank, + size_t num_rings, bool use_throughput_mode) { + if (S.qp) return; + if (S.ack_qp) return; + if (S.recv_ack_qp) return; +#ifdef EFA + create_qps_efa(S); +#else + create_qps_non_efa(S); #endif - if (use_normal_mode) { + if (use_throughput_mode) { size_t const rings_to_create = std::min(num_rings, (size_t)kChannelPerProxy); S.data_qps_by_channel.resize(rings_to_create); @@ -396,6 +449,15 @@ void create_per_thread_qp(ProxyCtx& S, void* gpu_buffer, size_t size, #ifdef EFA S.data_qps_by_channel[r] = create_srd_qp_ex(S); #else + struct ibv_qp_init_attr qp_init_attr = {}; + qp_init_attr.send_cq = S.cq; + qp_init_attr.recv_cq = S.cq; + qp_init_attr.qp_type = IBV_QPT_RC; + qp_init_attr.cap.max_send_wr = kMaxOutstandingSends; + qp_init_attr.cap.max_recv_wr = kMaxOutstandingSends; + qp_init_attr.cap.max_send_sge = 1; + qp_init_attr.cap.max_recv_sge = 1; + qp_init_attr.sq_sig_all = 0; S.data_qps_by_channel[r] = ibv_create_qp(S.pd, &qp_init_attr); #endif if (!S.data_qps_by_channel[r]) { @@ -434,9 +496,6 @@ void create_per_thread_qp(ProxyCtx& S, void* gpu_buffer, size_t size, } void modify_qp_to_init(ProxyCtx& S) { -#ifdef EFA - return; -#endif struct ibv_qp_attr attr; memset(&attr, 0, sizeof(attr)); @@ -501,14 
+560,12 @@ struct ibv_ah* create_ah(ProxyCtx& S, uint8_t* remote_gid) { } void modify_qp_to_rtr(ProxyCtx& S, RDMAConnectionInfo* remote, - bool use_normal_mode) { + bool use_throughput_mode) { #ifdef EFA S.dst_qpn = remote->qp_num; S.dst_ack_qpn = remote->recv_ack_qp_num; S.dst_ah = create_ah(S, remote->gid); -#endif - - if (use_normal_mode) { + if (use_throughput_mode) { S.dst_data_qpn_by_ring.clear(); uint32_t const remote_rings = std::min(remote->num_rings, (uint32_t)kChannelPerProxy); @@ -517,10 +574,17 @@ void modify_qp_to_rtr(ProxyCtx& S, RDMAConnectionInfo* remote, S.dst_data_qpn_by_ring.push_back(remote->data_qp_num[r]); } } - -#ifdef EFA return; #endif + if (use_throughput_mode) { + S.dst_data_qpn_by_ring.clear(); + uint32_t const remote_rings = + std::min(remote->num_rings, (uint32_t)kChannelPerProxy); + S.dst_data_qpn_by_ring.reserve(remote_rings); + for (uint32_t r = 0; r < remote_rings; ++r) { + S.dst_data_qpn_by_ring.push_back(remote->data_qp_num[r]); + } + } int is_roce = 0; @@ -629,6 +693,8 @@ void modify_qp_to_rtr(ProxyCtx& S, RDMAConnectionInfo* remote, void modify_qp_to_rts(ProxyCtx& S, RDMAConnectionInfo* local_info) { #ifdef EFA + (void)S; + (void)local_info; return; #endif struct ibv_qp_attr attr; @@ -703,8 +769,211 @@ void post_receive_buffer_for_imm(ProxyCtx& S) { } } -// Normal mode implementation -static void post_rdma_async_batched_normal_mode( +static void efa_handle_ring_idx_throughput_mode( + ProxyCtx* ctx, int dst_rank, int my_rank, size_t ring_idx_raw, + std::vector const& idxs, std::vector const& wrs_to_post, + std::vector const& cmds_to_post) { + const size_t local_ring_count = ctx->data_qps_by_channel.size(); + struct ibv_qp_ex* qpx = + (struct ibv_qp_ex*)(local_ring_count + ? ctx->data_qps_by_channel[ring_idx_raw % + local_ring_count] + : ctx->ack_qp); + + size_t const remote_ring_count = ctx->dst_data_qpn_by_ring.size(); + uint32_t const dst_qpn = + remote_ring_count + ? ctx->dst_data_qpn_by_ring[ring_idx_raw % remote_ring_count] + : ctx->dst_qpn; + + ibv_wr_start(qpx); + // No receiver barrier: build a single chain for this ring group + std::vector ring_wrids; + ring_wrids.reserve(idxs.size()); + + for (size_t j = 0; j < idxs.size(); ++j) { + size_t i = idxs[j]; + auto const& cmd = cmds_to_post[i]; + + qpx->wr_id = wrs_to_post[i]; + qpx->comp_mask = 0; + qpx->wr_flags = IBV_SEND_SIGNALED; + + uint64_t remote_addr = ctx->remote_addr + (cmd.req_rptr ? 
cmd.req_rptr : 0); + uint64_t remote_end = ctx->remote_addr + ctx->remote_len; + + if (remote_addr < ctx->remote_addr || + remote_addr + cmd.bytes > remote_end) { + fprintf(stderr, + "[ERROR] Remote write OOB: addr=0x%llx len=%u (base=0x%llx, " + "size=%zu), cmd.req_rptr: 0x%llx\n", + (unsigned long long)remote_addr, cmd.bytes, + (unsigned long long)ctx->remote_addr, (size_t)ctx->remote_len, + (unsigned long long)cmd.req_rptr); + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + fprintf(stderr, "cudaDeviceSynchronize failed: %s\n", + cudaGetErrorString(err)); + } + std::abort(); + } + // Optionally send an inline "atomic" via imm, else use imm only on tail + if (cmd.atomic_offset > 0 && cmd.atomic_val > 0) { + int v = static_cast(cmd.atomic_val); + if (v < -kMaxSendAtomicValue || v > kMaxSendAtomicValue) { + fprintf(stderr, "[EFA] atomic value=%d won't fit in 15 bits\n", v); + std::abort(); + } + size_t index = static_cast(cmd.atomic_offset / sizeof(int)); + // Initialize missing entries lazily + auto key = ctx->seq_key(dst_rank, index); + if (ctx->next_seq_per_index.find(key) == ctx->next_seq_per_index.end()) + ctx->next_seq_per_index[key] = 0; + + uint8_t seq = ctx->next_seq_per_index[key]; + ctx->next_seq_per_index[key] = + (seq + 1) % kReorderingBufferSize; // 4-bit wrap (0–15) + uint32_t imm = + AtomicsImm::PackAtomicWithSeq(v, cmd.atomic_offset, seq, true) + .GetImmData(); + AtomicsImm aimm(imm); + assert(aimm.GetSeq() == seq); + + ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, remote_addr, htonl(imm)); + } else if (j + 1 == idxs.size()) { + uint32_t imm = + WriteImm::Pack(get_is_combine(cmd.cmd_type), + get_low_latency(cmd.cmd_type), cmd.expert_idx, + (uint32_t)idxs.size(), my_rank) + .GetImmData(); + ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, remote_addr, htonl(imm)); + } else { + ibv_wr_rdma_write(qpx, ctx->remote_rkey, remote_addr); + } + + uintptr_t laddr = cmd.req_lptr + reinterpret_cast(ctx->mr->addr); + ibv_wr_set_ud_addr(qpx, ctx->dst_ah, dst_qpn, QKEY); + ibv_wr_set_sge(qpx, ctx->mr->lkey, laddr, static_cast(cmd.bytes)); + + ring_wrids.push_back(wrs_to_post[i]); + } + int ret = ibv_wr_complete(qpx); + if (ret) { + fprintf(stderr, "ibv_wr_complete failed (dst=%d): %s (ret=%d)\n", dst_rank, + strerror(ret), ret); + std::abort(); + } +} + +static void non_efa_handle_ring_idx_throughput_mode( + ProxyCtx& S, ProxyCtx* ctx, int dst_rank, int my_rank, int thread_idx, + size_t ring_idx_raw, std::vector const& idxs, + std::vector const& wrs_to_post, + std::vector const& cmds_to_post) { + { + size_t const local_ring_count = ctx->data_qps_by_channel.size(); + struct ibv_qp* qp = + local_ring_count + ? ctx->data_qps_by_channel[ring_idx_raw % local_ring_count] + : ctx->ack_qp; + + size_t const kgroup = idxs.size(); + std::vector sges(kgroup); + std::vector wrs(kgroup); + std::vector ring_wrids; + ring_wrids.reserve(kgroup); + + for (size_t j = 0; j < kgroup; ++j) { + size_t i = idxs[j]; + auto const& cmd = cmds_to_post[i]; + ring_wrids.push_back(wrs_to_post[i]); + uint64_t remote_addr = + ctx->remote_addr + (cmd.req_rptr ? 
cmd.req_rptr : 0); + uint64_t remote_end = ctx->remote_addr + ctx->remote_len; + if (remote_addr < ctx->remote_addr || + remote_addr + cmd.bytes > remote_end) { + fprintf(stderr, + "[ERROR] Remote write OOB: addr=0x%llx len=%u (base=0x%llx, " + "size=%zu), cmd.req_rptr: 0x%llx\n", + (unsigned long long)remote_addr, cmd.bytes, + (unsigned long long)ctx->remote_addr, (size_t)ctx->remote_len, + (unsigned long long)cmd.req_rptr); + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + fprintf(stderr, "cudaDeviceSynchronize failed: %s\n", + cudaGetErrorString(err)); + } + std::abort(); + } + + // Local SGE + uintptr_t laddr = + cmd.req_lptr + reinterpret_cast(ctx->mr->addr); + sges[j] = { + .addr = laddr, + .length = static_cast(cmd.bytes), + .lkey = ctx->mr->lkey, + }; + + // Build WR + std::memset(&wrs[j], 0, sizeof(wrs[j])); + wrs[j].wr_id = wrs_to_post[i]; + wrs[j].sg_list = &sges[j]; + wrs[j].num_sge = 1; + wrs[j].wr.rdma.remote_addr = remote_addr; + wrs[j].wr.rdma.rkey = ctx->remote_rkey; + wrs[j].opcode = IBV_WR_RDMA_WRITE; // default + wrs[j].send_flags = (j + 1 == kgroup) ? IBV_SEND_SIGNALED : 0; + wrs[j].next = (j + 1 < kgroup) ? &wrs[j + 1] : nullptr; + + if (cmd.atomic_offset > 0 && cmd.atomic_val > 0) { + int v = static_cast(cmd.atomic_val); + if (v < -kMaxSendAtomicValue || v > kMaxSendAtomicValue) { + fprintf(stderr, "atomic value=%d won't fit in 15 bits\n", v); + std::abort(); + } + uint32_t imm = + AtomicsImm::Pack(true, false, cmd.atomic_val, cmd.atomic_offset, + get_low_latency(cmd.cmd_type)) + .GetImmData(); + wrs[j].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; + wrs[j].imm_data = htonl(imm); + + AtomicsImm aimm(imm); + assert(aimm.GetValue() == cmd.atomic_val); + assert(aimm.GetOff() == cmd.atomic_offset); + } else { + wrs[j].opcode = IBV_WR_RDMA_WRITE; + } + } + + // Post the chain + ibv_send_wr* bad = nullptr; + int ret = ibv_post_send(qp, &wrs[0], &bad); + if (ret) { + fprintf(stderr, "ibv_post_send failed (dst=%d): %s (ret=%d)\n", dst_rank, + strerror(ret), ret); + if (bad) + fprintf(stderr, "Bad WR at %p (wr_id=%lu)\n", (void*)bad, bad->wr_id); + std::abort(); + } + size_t const last = kgroup - 1; + uint64_t const batch_tail_wr = ring_wrids[last]; + { + auto [it, inserted] = + S.wr_id_to_wr_ids.try_emplace(batch_tail_wr, std::move(ring_wrids)); + if (!inserted) { + fprintf(stderr, + "thread_idx: %d, Error: tail wr_id %lu already exists " + "(map=%p)\n", + thread_idx, batch_tail_wr, (void*)&S.wr_id_to_wr_ids); + std::abort(); + } + } + } +} + +static void post_rdma_async_batched_throughput_mode( ProxyCtx& S, void* buf, size_t num_wrs, std::vector const& wrs_to_post, std::vector const& cmds_to_post, @@ -719,15 +988,11 @@ static void post_rdma_async_batched_normal_mode( std::unordered_map> dst_rank_wr_ids; for (size_t i = 0; i < num_wrs; ++i) { if (cmds_to_post[i].dst_rank == static_cast(my_rank)) { - // NOTE(MaoZiming): this should not happen. - printf("Posting rdma to itself\n"); std::abort(); continue; } else if (std::abs((int)cmds_to_post[i].dst_rank - (int)my_rank) % MAX_NUM_GPUS != 0) { - // NOTE(MaoZiming): this should not happen. - printf("Posting rdma to a different rank\n"); std::abort(); continue; } else { @@ -755,214 +1020,166 @@ static void post_rdma_async_batched_normal_mode( for (auto& [ring_idx_raw, idxs] : ring_to_indices) { #ifdef EFA - const size_t local_ring_count = ctx->data_qps_by_channel.size(); - struct ibv_qp_ex* qpx = - (struct ibv_qp_ex*)(local_ring_count - ? 
ctx->data_qps_by_channel[ring_idx_raw % - local_ring_count] - : ctx->ack_qp); - - size_t const remote_ring_count = ctx->dst_data_qpn_by_ring.size(); - uint32_t const dst_qpn = - remote_ring_count - ? ctx->dst_data_qpn_by_ring[ring_idx_raw % remote_ring_count] - : ctx->dst_qpn; - - ibv_wr_start(qpx); - // No receiver barrier: build a single chain for this ring group - std::vector ring_wrids; - ring_wrids.reserve(idxs.size()); - - for (size_t j = 0; j < idxs.size(); ++j) { - size_t i = idxs[j]; - auto const& cmd = cmds_to_post[i]; - - qpx->wr_id = wrs_to_post[i]; - qpx->comp_mask = 0; - qpx->wr_flags = IBV_SEND_SIGNALED; - - uint64_t remote_addr = - ctx->remote_addr + (cmd.req_rptr ? cmd.req_rptr : 0); - uint64_t remote_end = ctx->remote_addr + ctx->remote_len; - - if (remote_addr < ctx->remote_addr || - remote_addr + cmd.bytes > remote_end) { - fprintf(stderr, - "[ERROR] Remote write OOB: addr=0x%llx len=%u (base=0x%llx, " - "size=%zu), cmd.req_rptr: 0x%llx\n", - (unsigned long long)remote_addr, cmd.bytes, - (unsigned long long)ctx->remote_addr, (size_t)ctx->remote_len, - (unsigned long long)cmd.req_rptr); - cudaError_t err = cudaDeviceSynchronize(); - if (err != cudaSuccess) { - fprintf(stderr, "cudaDeviceSynchronize failed: %s\n", - cudaGetErrorString(err)); - } - std::abort(); - } - // Optionally send an inline "atomic" via imm, else use imm only on tail - if (cmd.atomic_offset > 0 && cmd.atomic_val > 0) { - int v = static_cast(cmd.atomic_val); - if (v < -kMaxSendAtomicValue || v > kMaxSendAtomicValue) { - fprintf(stderr, "[EFA] atomic value=%d won't fit in 15 bits\n", v); - std::abort(); - } - size_t index = static_cast(cmd.atomic_offset / sizeof(int)); - // Initialize missing entries lazily - auto key = ctx->seq_key(dst_rank, index); - if (ctx->next_seq_per_index.find(key) == - ctx->next_seq_per_index.end()) - ctx->next_seq_per_index[key] = 0; - - uint8_t seq = ctx->next_seq_per_index[key]; - ctx->next_seq_per_index[key] = - (seq + 1) % kReorderingBufferSize; // 4-bit wrap (0–15) - uint32_t imm = - AtomicsImm::PackAtomicWithSeq(v, cmd.atomic_offset, seq, true) - .GetImmData(); - AtomicsImm aimm(imm); - assert(aimm.GetSeq() == seq); - - ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, remote_addr, htonl(imm)); - } else if (j + 1 == idxs.size()) { - uint32_t imm = - WriteImm::Pack(get_is_combine(cmd.cmd_type), - get_low_latency(cmd.cmd_type), cmd.expert_idx, - (uint32_t)idxs.size(), my_rank) - .GetImmData(); - ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, remote_addr, htonl(imm)); - } else { - ibv_wr_rdma_write(qpx, ctx->remote_rkey, remote_addr); - } + efa_handle_ring_idx_throughput_mode(ctx, dst_rank, my_rank, ring_idx_raw, + idxs, wrs_to_post, cmds_to_post); +#else + non_efa_handle_ring_idx_throughput_mode(S, ctx, dst_rank, my_rank, + thread_idx, ring_idx_raw, idxs, + wrs_to_post, cmds_to_post); +#endif + } + } +} - uintptr_t laddr = - cmd.req_lptr + reinterpret_cast(ctx->mr->addr); - ibv_wr_set_ud_addr(qpx, ctx->dst_ah, dst_qpn, QKEY); - ibv_wr_set_sge(qpx, ctx->mr->lkey, laddr, - static_cast(cmd.bytes)); +static void efa_handle_dst_rank_latency_mode( + ProxyCtx& S, ProxyCtx* ctx, int dst_rank, int my_rank, int thread_idx, + size_t k, std::vector& wr_ids, + std::vector const& wrs_to_post, + std::vector const& cmds_to_post) { + struct ibv_qp_ex* qpx = (struct ibv_qp_ex*)ctx->qp; + ibv_wr_start(qpx); - ring_wrids.push_back(wrs_to_post[i]); - } - int ret = ibv_wr_complete(qpx); - if (ret) { - fprintf(stderr, "ibv_wr_complete failed (dst=%d): %s (ret=%d)\n", - dst_rank, strerror(ret), ret); - 
std::abort(); - } -#else - { - size_t const local_ring_count = ctx->data_qps_by_channel.size(); - struct ibv_qp* qp = - local_ring_count - ? ctx->data_qps_by_channel[ring_idx_raw % local_ring_count] - : ctx->ack_qp; - - size_t const kgroup = idxs.size(); - std::vector sges(kgroup); - std::vector wrs(kgroup); - std::vector ring_wrids; - ring_wrids.reserve(kgroup); - - for (size_t j = 0; j < kgroup; ++j) { - size_t i = idxs[j]; - auto const& cmd = cmds_to_post[i]; - ring_wrids.push_back(wrs_to_post[i]); - - // Remote address bounds check - uint64_t remote_addr = - ctx->remote_addr + (cmd.req_rptr ? cmd.req_rptr : 0); - uint64_t remote_end = ctx->remote_addr + ctx->remote_len; - - if (remote_addr < ctx->remote_addr || - remote_addr + cmd.bytes > remote_end) { - fprintf( - stderr, + std::unordered_map> dst_expert_wr_ids; + for (size_t j = 0; j < k; ++j) { + size_t i = wr_ids[j]; + int expert_idx = cmds_to_post[i].expert_idx; + dst_expert_wr_ids[expert_idx].push_back(i); + } + + for (auto& [expert_idx, expert_wr_ids] : dst_expert_wr_ids) { + size_t expert_k = expert_wr_ids.size(); + for (size_t j = 0; j < expert_k; ++j) { + size_t i = expert_wr_ids[j]; + auto const& cmd = cmds_to_post[i]; + expert_wr_ids[j] = wrs_to_post[i]; + qpx->wr_id = wrs_to_post[i]; + qpx->comp_mask = 0; + qpx->wr_flags = IBV_SEND_SIGNALED; + + uint64_t remote_addr = + ctx->remote_addr + (cmd.req_rptr ? cmd.req_rptr : 0); + uint64_t remote_end = ctx->remote_addr + ctx->remote_len; + + if (remote_addr < ctx->remote_addr || + remote_addr + cmd.bytes > remote_end) { + fprintf(stderr, "[ERROR] Remote write OOB: addr=0x%llx len=%u (base=0x%llx, " "size=%zu), cmd.req_rptr: 0x%llx\n", (unsigned long long)remote_addr, cmd.bytes, (unsigned long long)ctx->remote_addr, (size_t)ctx->remote_len, (unsigned long long)cmd.req_rptr); - cudaError_t err = cudaDeviceSynchronize(); - if (err != cudaSuccess) { - fprintf(stderr, "cudaDeviceSynchronize failed: %s\n", - cudaGetErrorString(err)); - } - std::abort(); - } - - // Local SGE - uintptr_t laddr = - cmd.req_lptr + reinterpret_cast(ctx->mr->addr); - sges[j] = { - .addr = laddr, - .length = static_cast(cmd.bytes), - .lkey = ctx->mr->lkey, - }; - - // Build WR - std::memset(&wrs[j], 0, sizeof(wrs[j])); - wrs[j].wr_id = wrs_to_post[i]; - wrs[j].sg_list = &sges[j]; - wrs[j].num_sge = 1; - wrs[j].wr.rdma.remote_addr = remote_addr; - wrs[j].wr.rdma.rkey = ctx->remote_rkey; - wrs[j].opcode = IBV_WR_RDMA_WRITE; // default - wrs[j].send_flags = (j + 1 == kgroup) ? IBV_SEND_SIGNALED : 0; - wrs[j].next = (j + 1 < kgroup) ? 
&wrs[j + 1] : nullptr; - - if (cmd.atomic_offset > 0 && cmd.atomic_val > 0) { - int v = static_cast(cmd.atomic_val); - if (v < -kMaxSendAtomicValue || v > kMaxSendAtomicValue) { - fprintf(stderr, "atomic value=%d won't fit in 15 bits\n", v); - std::abort(); - } - uint32_t imm = - AtomicsImm::Pack(true, false, cmd.atomic_val, cmd.atomic_offset, - get_low_latency(cmd.cmd_type)) - .GetImmData(); - wrs[j].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; - wrs[j].imm_data = htonl(imm); - - AtomicsImm aimm(imm); - assert(aimm.GetValue() == cmd.atomic_val); - assert(aimm.GetOff() == cmd.atomic_offset); - } else { - wrs[j].opcode = IBV_WR_RDMA_WRITE; - } - } - - // Post the chain - ibv_send_wr* bad = nullptr; - int ret = ibv_post_send(qp, &wrs[0], &bad); - if (ret) { - fprintf(stderr, "ibv_post_send failed (dst=%d): %s (ret=%d)\n", - dst_rank, strerror(ret), ret); - if (bad) - fprintf(stderr, "Bad WR at %p (wr_id=%lu)\n", (void*)bad, - bad->wr_id); + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + fprintf(stderr, "cudaDeviceSynchronize failed: %s\n", + cudaGetErrorString(err)); std::abort(); } - size_t const last = kgroup - 1; - uint64_t const batch_tail_wr = ring_wrids[last]; - { - auto [it, inserted] = S.wr_id_to_wr_ids.try_emplace( - batch_tail_wr, std::move(ring_wrids)); - if (!inserted) { - fprintf(stderr, - "thread_idx: %d, Error: tail wr_id %lu already exists " - "(map=%p)\n", - thread_idx, batch_tail_wr, (void*)&S.wr_id_to_wr_ids); - std::abort(); - } - } + std::abort(); } -#endif + uint32_t imm = WriteImm::Pack(get_is_combine(cmd.cmd_type), + get_low_latency(cmd.cmd_type), + cmd.expert_idx, 1, my_rank) + .GetImmData(); + ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, remote_addr, htonl(imm)); + uintptr_t laddr = + cmd.req_lptr + reinterpret_cast(ctx->mr->addr); + ibv_wr_set_ud_addr(qpx, ctx->dst_ah, ctx->dst_qpn, QKEY); + ibv_wr_set_sge(qpx, ctx->mr->lkey, laddr, + static_cast(cmd.bytes)); + } + uint64_t const expert_tail_wr = expert_wr_ids.back(); + { + auto [it, inserted] = S.wr_id_to_wr_ids.try_emplace( + expert_tail_wr, std::move(expert_wr_ids)); + if (!inserted) { + fprintf(stderr, + "thread_idx: %d, Error: tail wr_id %lu already exists " + "(map=%p)\n", + thread_idx, expert_tail_wr, (void*)&S.wr_id_to_wr_ids); + std::abort(); + } + } + } + + int ret = ibv_wr_complete(qpx); + if (ret) { + fprintf(stderr, "ibv_wr_complete failed (dst=%d): %s (ret=%d)\n", dst_rank, + strerror(ret), ret); + std::abort(); + } +} + +static void non_efa_handle_dst_rank_latency_mode( + ProxyCtx& S, ProxyCtx* ctx, int dst_rank, int thread_idx, size_t k, + std::vector& wr_ids, std::vector const& wrs_to_post, + std::vector const& cmds_to_post) { + std::vector sges(k); + std::vector wrs(k); + for (size_t j = 0; j < k; ++j) { + size_t i = wr_ids[j]; + auto const& cmd = cmds_to_post[i]; + wr_ids[j] = wrs_to_post[i]; + sges[j].addr = cmd.req_lptr + reinterpret_cast(ctx->mr->addr); + sges[j].length = static_cast(cmd.bytes); + sges[j].lkey = ctx->mr->lkey; + std::memset(&wrs[j], 0, sizeof(wrs[j])); + wrs[j].sg_list = &sges[j]; + wrs[j].num_sge = 1; + wrs[j].wr_id = wr_ids[j]; + + wrs[j].wr.rdma.remote_addr = ctx->remote_addr + cmd.req_rptr; + + uint64_t remote_end = ctx->remote_addr + ctx->remote_len; + if (wrs[j].wr.rdma.remote_addr < ctx->remote_addr || + wrs[j].wr.rdma.remote_addr + cmd.bytes > remote_end) { + fprintf(stderr, + "[ERROR] Remote write OOB: addr=0x%llx len=%u (base=0x%llx, " + "size=%zu), cmd.req_rptr: 0x%llx\n", + (unsigned long long)wrs[j].wr.rdma.remote_addr, cmd.bytes, + (unsigned 
long long)ctx->remote_addr, (size_t)ctx->remote_len, + (unsigned long long)cmd.req_rptr); + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + fprintf(stderr, "cudaDeviceSynchronize failed: %s\n", + cudaGetErrorString(err)); + std::abort(); + } + std::abort(); + } + + wrs[j].wr.rdma.rkey = ctx->remote_rkey; + wrs[j].opcode = IBV_WR_RDMA_WRITE; + wrs[j].send_flags = 0; + wrs[j].next = (j + 1 < k) ? &wrs[j + 1] : nullptr; + } + size_t const last = k - 1; + uint64_t const batch_tail_wr = wr_ids[last]; + wrs[last].send_flags |= IBV_SEND_SIGNALED; + wrs[last].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; + wrs[last].imm_data = htonl(static_cast(batch_tail_wr)); + ibv_send_wr* bad = nullptr; + int ret = ibv_post_send(ctx->qp, &wrs[0], &bad); + if (ret) { + fprintf(stderr, "ibv_post_send failed (dst=%d): %s (ret=%d)\n", dst_rank, + strerror(ret), ret); + if (bad) + fprintf(stderr, "Bad WR at %p (wr_id=%lu)\n", (void*)bad, bad->wr_id); + std::abort(); + } + { + auto [it, inserted] = + S.wr_id_to_wr_ids.try_emplace(batch_tail_wr, std::move(wr_ids)); + if (!inserted) { + fprintf(stderr, + "thread_idx: %d, Error: tail wr_id %lu already exists " + "(map=%p)\n", + thread_idx, batch_tail_wr, (void*)&S.wr_id_to_wr_ids); + std::abort(); } } } -// Fast mode implementation -static void post_rdma_async_batched_fast_mode( +static void post_rdma_async_batched_latency_mode( ProxyCtx& S, void* buf, size_t num_wrs, std::vector const& wrs_to_post, std::vector const& cmds_to_post, @@ -979,10 +1196,7 @@ static void post_rdma_async_batched_fast_mode( std::unordered_map> dst_rank_wr_ids; for (size_t i = 0; i < num_wrs; ++i) { if (cmds_to_post[i].dst_rank == static_cast(my_rank)) { - // NOTE(MaoZiming): this should not happen. - printf("Posting rdma to itself\n"); std::abort(); - continue; } else { dst_rank_wr_ids[cmds_to_post[i].dst_rank].push_back(i); } @@ -998,205 +1212,64 @@ static void post_rdma_async_batched_fast_mode( } size_t const k = wr_ids.size(); #ifdef EFA - struct ibv_qp_ex* qpx = (struct ibv_qp_ex*)ctx->qp; - ibv_wr_start(qpx); - -#ifdef USE_RECEIVER_BARRIER - std::unordered_map> dst_expert_wr_ids; - for (size_t j = 0; j < k; ++j) { - size_t i = wr_ids[j]; - int expert_idx = cmds_to_post[i].expert_idx; - dst_expert_wr_ids[expert_idx].push_back(i); - } -#endif - -#ifdef USE_RECEIVER_BARRIER - for (auto& [expert_idx, expert_wr_ids] : dst_expert_wr_ids) { - size_t expert_k = expert_wr_ids.size(); - for (size_t j = 0; j < expert_k; ++j) { - size_t i = expert_wr_ids[j]; + efa_handle_dst_rank_latency_mode(S, ctx, dst_rank, my_rank, thread_idx, k, + wr_ids, wrs_to_post, cmds_to_post); #else - for (size_t j = 0; j < k; ++j) { - size_t i = wr_ids[j]; + non_efa_handle_dst_rank_latency_mode(S, ctx, dst_rank, thread_idx, k, + wr_ids, wrs_to_post, cmds_to_post); #endif + } +} + +void post_rdma_async_batched(ProxyCtx& S, void* buf, size_t num_wrs, + std::vector const& wrs_to_post, + std::vector const& cmds_to_post, + std::vector>& ctxs, + int my_rank, int thread_idx, + bool use_throughput_mode) { + if (use_throughput_mode) { + post_rdma_async_batched_throughput_mode( + S, buf, num_wrs, wrs_to_post, cmds_to_post, ctxs, my_rank, thread_idx); + } else { + post_rdma_async_batched_latency_mode( + S, buf, num_wrs, wrs_to_post, cmds_to_post, ctxs, my_rank, thread_idx); + } +} - auto const& cmd = cmds_to_post[i]; -#ifdef USE_RECEIVER_BARRIER - expert_wr_ids[j] = wrs_to_post[i]; +static void handle_atomic_completion_ack( + ProxyCtx& S, uint64_t wrid, std::unordered_set& acked_wrs) { +#ifdef EFA + 
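// EFA posts every atomic WR signaled with its own wr_id, so the completed wr_id can be acked directly; the RC path below instead expands the signaled tail wr_id into the batched sub-WRs recorded in wr_id_to_wr_ids. +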
acked_wrs.insert(wrid); #else - wr_ids[j] = wrs_to_post[i]; + auto it = S.wr_id_to_wr_ids.find(wrid); + if (it != S.wr_id_to_wr_ids.end()) { + for (uint64_t sub_wr : it->second) { + acked_wrs.insert(sub_wr); + } + S.wr_id_to_wr_ids.erase(it); + } else { + printf("Error: Atomic ACK for unknown wr_id %lu\n", wrid); + std::abort(); + } #endif - qpx->wr_id = wrs_to_post[i]; - qpx->comp_mask = 0; - qpx->wr_flags = IBV_SEND_SIGNALED; +} - uint64_t remote_addr = - ctx->remote_addr + (cmd.req_rptr ? cmd.req_rptr : 0); - uint64_t remote_end = ctx->remote_addr + ctx->remote_len; - - if (remote_addr < ctx->remote_addr || - remote_addr + cmd.bytes > remote_end) { - fprintf(stderr, - "[ERROR] Remote write OOB: addr=0x%llx len=%u (base=0x%llx, " - "size=%zu), cmd.req_rptr: 0x%llx\n", - (unsigned long long)remote_addr, cmd.bytes, - (unsigned long long)ctx->remote_addr, (size_t)ctx->remote_len, - (unsigned long long)cmd.req_rptr); - cudaError_t err = cudaDeviceSynchronize(); - if (err != cudaSuccess) { - fprintf(stderr, "cudaDeviceSynchronize failed: %s\n", - cudaGetErrorString(err)); - std::abort(); - } - std::abort(); - } -#ifdef USE_SENDER_BARRIER - S.wr_id_to_write_struct[qpx->wr_id] = {cmd.expert_idx, dst_rank, - get_is_combine(cmd.cmd_type), - get_low_latency(cmd.cmd_type)}; -#endif -#ifdef USE_RECEIVER_BARRIER - uint32_t imm = WriteImm::Pack(get_is_combine(cmd.cmd_type), - get_low_latency(cmd.cmd_type), - cmd.expert_idx, 1, my_rank) - .GetImmData(); - ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, remote_addr, htonl(imm)); +static void handle_write_completion_ack( + ProxyCtx& S, uint64_t wr_done, std::unordered_set& acked_wrs) { +#ifdef EFA + acked_wrs.insert(wr_done); #else - if (j + 1 == k) { - uint32_t imm = WriteImm::Pack(get_is_combine(cmd.cmd_type), - get_low_latency(cmd.cmd_type), - cmd.expert_idx, k, my_rank) - .GetImmData(); - ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, remote_addr, htonl(imm)); - } else { - ibv_wr_rdma_write(qpx, ctx->remote_rkey, remote_addr); - } -#endif - uintptr_t laddr = - cmd.req_lptr + reinterpret_cast(ctx->mr->addr); - ibv_wr_set_ud_addr(qpx, ctx->dst_ah, ctx->dst_qpn, QKEY); - ibv_wr_set_sge(qpx, ctx->mr->lkey, laddr, - static_cast(cmd.bytes)); - } - -#ifdef USE_RECEIVER_BARRIER - uint64_t const expert_tail_wr = expert_wr_ids.back(); - { - auto [it, inserted] = S.wr_id_to_wr_ids.try_emplace( - expert_tail_wr, std::move(expert_wr_ids)); - if (!inserted) { - fprintf(stderr, - "thread_idx: %d, Error: tail wr_id %lu already exists " - "(map=%p)\n", - thread_idx, expert_tail_wr, (void*)&S.wr_id_to_wr_ids); - std::abort(); - } - } + auto it = S.wr_id_to_wr_ids.find(wr_done); + if (it != S.wr_id_to_wr_ids.end()) { + for (uint64_t sub_wr : it->second) { + acked_wrs.insert(sub_wr); } -#else - uint64_t const tail_wr = wr_ids.back(); - { - auto [it, inserted] = - S.wr_id_to_wr_ids.try_emplace(tail_wr, std::move(wr_ids)); - if (!inserted) { - fprintf(stderr, - "thread_idx: %d, Error: tail wr_id %lu already exists " - "(map=%p)\n", - thread_idx, tail_wr, (void*)&S.wr_id_to_wr_ids); - std::abort(); - } - } -#endif - - int ret = ibv_wr_complete(qpx); - if (ret) { - fprintf(stderr, "ibv_wr_complete failed (dst=%d): %s (ret=%d)\n", - dst_rank, strerror(ret), ret); - std::abort(); - } -#else - std::vector sges(k); - std::vector wrs(k); - for (size_t j = 0; j < k; ++j) { - size_t i = wr_ids[j]; - auto const& cmd = cmds_to_post[i]; - wr_ids[j] = wrs_to_post[i]; - sges[j].addr = cmd.req_lptr + reinterpret_cast(ctx->mr->addr); - sges[j].length = static_cast(cmd.bytes); - sges[j].lkey = 
ctx->mr->lkey; - std::memset(&wrs[j], 0, sizeof(wrs[j])); - wrs[j].sg_list = &sges[j]; - wrs[j].num_sge = 1; - wrs[j].wr_id = wr_ids[j]; - - wrs[j].wr.rdma.remote_addr = ctx->remote_addr + cmd.req_rptr; - - uint64_t remote_end = ctx->remote_addr + ctx->remote_len; - if (wrs[j].wr.rdma.remote_addr < ctx->remote_addr || - wrs[j].wr.rdma.remote_addr + cmd.bytes > remote_end) { - fprintf(stderr, - "[ERROR] Remote write OOB: addr=0x%llx len=%u (base=0x%llx, " - "size=%zu), cmd.req_rptr: 0x%llx\n", - (unsigned long long)wrs[j].wr.rdma.remote_addr, cmd.bytes, - (unsigned long long)ctx->remote_addr, (size_t)ctx->remote_len, - (unsigned long long)cmd.req_rptr); - cudaError_t err = cudaDeviceSynchronize(); - if (err != cudaSuccess) { - fprintf(stderr, "cudaDeviceSynchronize failed: %s\n", - cudaGetErrorString(err)); - std::abort(); - } - std::abort(); - } - - wrs[j].wr.rdma.rkey = ctx->remote_rkey; - wrs[j].opcode = IBV_WR_RDMA_WRITE; - wrs[j].send_flags = 0; - wrs[j].next = (j + 1 < k) ? &wrs[j + 1] : nullptr; - } - size_t const last = k - 1; - uint64_t const batch_tail_wr = wr_ids[last]; - wrs[last].send_flags |= IBV_SEND_SIGNALED; - wrs[last].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; - wrs[last].imm_data = htonl(static_cast(batch_tail_wr)); - ibv_send_wr* bad = nullptr; - int ret = ibv_post_send(ctx->qp, &wrs[0], &bad); - if (ret) { - fprintf(stderr, "ibv_post_send failed (dst=%d): %s (ret=%d)\n", dst_rank, - strerror(ret), ret); - if (bad) - fprintf(stderr, "Bad WR at %p (wr_id=%lu)\n", (void*)bad, bad->wr_id); - std::abort(); - } - { - auto [it, inserted] = - S.wr_id_to_wr_ids.try_emplace(batch_tail_wr, std::move(wr_ids)); - if (!inserted) { - fprintf(stderr, - "thread_idx: %d, Error: tail wr_id %lu already exists " - "(map=%p)\n", - thread_idx, batch_tail_wr, (void*)&S.wr_id_to_wr_ids); - std::abort(); - } - } -#endif - } -} - -// Wrapper that selects implementation based on use_normal_mode -void post_rdma_async_batched(ProxyCtx& S, void* buf, size_t num_wrs, - std::vector const& wrs_to_post, - std::vector const& cmds_to_post, - std::vector>& ctxs, - int my_rank, int thread_idx, - bool use_normal_mode) { - if (use_normal_mode) { - post_rdma_async_batched_normal_mode( - S, buf, num_wrs, wrs_to_post, cmds_to_post, ctxs, my_rank, thread_idx); + S.wr_id_to_wr_ids.erase(it); } else { - post_rdma_async_batched_fast_mode(S, buf, num_wrs, wrs_to_post, - cmds_to_post, ctxs, my_rank, thread_idx); + printf("Error: Write ACK for unknown wr_id %lu\n", wr_done); + std::abort(); } +#endif } void local_process_completions(ProxyCtx& S, @@ -1221,59 +1294,15 @@ void local_process_completions(ProxyCtx& S, uint64_t wrid = wc[i].wr_id; if ((wrid & kAtomicWrTag) == kAtomicWrTag) { wrid &= kAtomicMask; -#ifdef EFA - acked_wrs.insert(wrid); -#else - auto it = S.wr_id_to_wr_ids.find(wrid); - if (it != S.wr_id_to_wr_ids.end()) { - for (uint64_t sub_wr : it->second) { - acked_wrs.insert(sub_wr); - } - S.wr_id_to_wr_ids.erase(it); - } else { - printf("Error: Atomic ACK for unknown wr_id %lu\n", wrid); - std::abort(); - } -#endif + handle_atomic_completion_ack(S, wrid, acked_wrs); break; } if ((wrid & kBarrierWrTag) == kBarrierWrTag) { break; } -#ifdef USE_SENDER_BARRIER - { - auto it = S.wr_id_to_write_struct.find(wrid); - if (it != S.wr_id_to_write_struct.end()) { - WriteStruct const& ws = it->second; - S.wr_id_to_write_struct.erase(it); - if (ws.is_combine) { - S.combine_sent_counter.Add( - {ws.low_latency_buffer_idx, ws.expert_idx, ws.dst_rank}, 1); - } else { - S.dispatch_sent_counter.Add( - {ws.low_latency_buffer_idx, 
ws.expert_idx, ws.dst_rank}, 1); - } - } else { - assert(false && "wr_id not found in write_struct map"); - } - } -#endif { uint64_t const wr_done = wc[i].wr_id; -#ifdef EFA - acked_wrs.insert(wr_done); -#else - auto it = S.wr_id_to_wr_ids.find(wr_done); - if (it != S.wr_id_to_wr_ids.end()) { - for (uint64_t sub_wr : it->second) { - acked_wrs.insert(sub_wr); - } - S.wr_id_to_wr_ids.erase(it); - } else { - printf("Error: Write ACK for unknown wr_id %lu\n", wr_done); - std::abort(); - } -#endif + handle_write_completion_ack(S, wr_done, acked_wrs); } } break; case IBV_WC_RECV: @@ -1294,7 +1323,7 @@ void local_process_completions(ProxyCtx& S, } } -int poll_cq_once(ibv_cq* cq, ibv_wc* wc, int max_cqes) { +static int poll_cq_once_efa(ibv_cq* cq, ibv_wc* wc, int max_cqes) { #ifdef EFA auto cqx = reinterpret_cast(cq); ibv_poll_cq_attr attr{.comp_mask = 0}; @@ -1314,7 +1343,29 @@ int poll_cq_once(ibv_cq* cq, ibv_wc* wc, int max_cqes) { ibv_end_poll(cqx); return n; #else + (void)cq; + (void)wc; + (void)max_cqes; + return 0; +#endif +} + +static int poll_cq_once_non_efa(ibv_cq* cq, ibv_wc* wc, int max_cqes) { +#ifndef EFA return ibv_poll_cq(cq, max_cqes, wc); +#else + (void)cq; + (void)wc; + (void)max_cqes; + return 0; +#endif +} + +int poll_cq_once(ibv_cq* cq, ibv_wc* wc, int max_cqes) { +#ifdef EFA + return poll_cq_once_efa(cq, wc, max_cqes); +#else + return poll_cq_once_non_efa(cq, wc, max_cqes); #endif } @@ -1330,7 +1381,6 @@ void local_poll_completions(ProxyCtx& S, } }; if (S.cq) poll_one(S.cq); - // for (auto* cq : S.extra_cqs) poll_one(cq); } void poll_cq_dual(ProxyCtx& S, std::unordered_set& acked_wrs, @@ -1338,7 +1388,7 @@ void poll_cq_dual(ProxyCtx& S, std::unordered_set& acked_wrs, std::vector& ctx_by_tag, void* atomic_buffer_ptr, int num_ranks, int num_experts, std::set& pending_atomic_updates, int my_rank, - int num_nodes, bool use_normal_mode) { + int num_nodes, bool use_throughput_mode) { ibv_wc wc[kMaxOutstandingSends]; auto poll_one = [&](ibv_cq* cq) { int ne = poll_cq_once(cq, wc, kMaxOutstandingSends); @@ -1347,11 +1397,10 @@ void poll_cq_dual(ProxyCtx& S, std::unordered_set& acked_wrs, remote_process_completions(S, thread_idx, g_ring, ne, wc, ctx_by_tag, atomic_buffer_ptr, num_ranks, num_experts, pending_atomic_updates, my_rank, num_nodes, - use_normal_mode); + use_throughput_mode); } }; if (S.cq) poll_one(S.cq); - // for (auto* cq : S.extra_cqs) poll_one(cq); } void apply_pending_updates(ProxyCtx& ctx, @@ -1399,7 +1448,50 @@ ibv_qp* qp_from_qpnum(ProxyCtx& S, uint32_t qpnum) { return nullptr; } -void remote_process_completions_normal_mode( +static void replenish_recv_buffer_if_needed( + ibv_wc const& cqe, std::vector& ctx_by_tag) { +#ifndef EFA + if (cqe.opcode == IBV_WC_RECV_RDMA_WITH_IMM) { + uint32_t const tag = wr_tag(cqe.wr_id); + if (tag >= ctx_by_tag.size() || ctx_by_tag[tag] == nullptr) { + fprintf(stderr, "Invalid tag or uninitialized context for tag=%u\n", tag); + std::abort(); + } + ProxyCtx& S = *ctx_by_tag[tag]; + ibv_qp* qp = qp_from_qpnum(S, cqe.qp_num); + if (!qp) { + fprintf(stderr, "No matching QP for qp_num=0x%x (tag=%u)\n", cqe.qp_num, + tag); + std::abort(); + } + ibv_sge sge = { + .addr = reinterpret_cast(&S.ack_recv_buf[0]), + .length = sizeof(uint64_t), + .lkey = S.ack_recv_mr->lkey, + }; + ibv_recv_wr rwr{}; + S.pool_index = (S.pool_index + 1) % (kRemoteBufferSize / kObjectSize - 1); + rwr.wr_id = make_wr_id(wr_tag(cqe.wr_id), S.pool_index); + rwr.sg_list = &sge; + rwr.num_sge = 1; + ibv_recv_wr* bad = nullptr; + int ret = ibv_post_recv(qp, &rwr, &bad); + 
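// ibv_post_recv() returns 0 on success or an errno value on failure, with `bad` pointing at the first receive WR that could not be posted. +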
if (ret) { + fprintf(stderr, + "ibv_post_recv (imm replenish) failed on qp=0x%x: %s (%d)\n", + qp->qp_num, strerror(ret), ret); + if (bad) + fprintf(stderr, " bad wr_id=%llu\n", (unsigned long long)bad->wr_id); + std::abort(); + } + } +#else + (void)cqe; + (void)ctx_by_tag; +#endif +} + +void remote_process_completions_throughput_mode( ProxyCtx& S, int idx, CopyRingBuffer& g_ring, int ne, ibv_wc* wc, std::vector& ctx_by_tag, void* atomic_buffer_ptr, int num_ranks, int num_experts, std::set& pending_atomic_updates, @@ -1493,11 +1585,7 @@ void remote_process_completions_normal_mode( uint16_t src = bimm.GetRank(); // First node. // TODO(MaoZiming): pass node_idx instead. -#ifdef USE_SUBSET_BARRIER if (my_rank < MAX_NUM_GPUS) { -#else - if (my_rank == 0) { -#endif if (!is_ack) { if (S.barrier_arrived.empty()) { assert(S.barrier_arrival_count == 0 && @@ -1532,47 +1620,11 @@ void remote_process_completions_normal_mode( std::abort(); } -#ifndef EFA - if (cqe.opcode == IBV_WC_RECV_RDMA_WITH_IMM) { - uint32_t const tag = wr_tag(cqe.wr_id); - if (tag >= ctx_by_tag.size() || ctx_by_tag[tag] == nullptr) { - fprintf(stderr, "Invalid tag or uninitialized context for tag=%u\n", - tag); - std::abort(); - } - ProxyCtx& S = *ctx_by_tag[tag]; - ibv_qp* qp = qp_from_qpnum(S, cqe.qp_num); - if (!qp) { - fprintf(stderr, "No matching QP for qp_num=0x%x (tag=%u)\n", cqe.qp_num, - tag); - std::abort(); - } - ibv_sge sge = { - .addr = reinterpret_cast(&S.ack_recv_buf[0]), - .length = sizeof(uint64_t), - .lkey = S.ack_recv_mr->lkey, - }; - ibv_recv_wr rwr{}; - S.pool_index = (S.pool_index + 1) % (kRemoteBufferSize / kObjectSize - 1); - rwr.wr_id = make_wr_id(wr_tag(cqe.wr_id), S.pool_index); - rwr.sg_list = &sge; - rwr.num_sge = 1; - ibv_recv_wr* bad = nullptr; - int ret = ibv_post_recv(qp, &rwr, &bad); - if (ret) { - fprintf(stderr, - "ibv_post_recv (imm replenish) failed on qp=0x%x: %s (%d)\n", - qp->qp_num, strerror(ret), ret); - if (bad) - fprintf(stderr, " bad wr_id=%llu\n", (unsigned long long)bad->wr_id); - std::abort(); - } - } -#endif + replenish_recv_buffer_if_needed(cqe, ctx_by_tag); } } -void remote_process_completions_fast_mode( +void remote_process_completions_latency_mode( ProxyCtx& S, int idx, CopyRingBuffer& g_ring, int ne, ibv_wc* wc, std::vector& ctx_by_tag, void* atomic_buffer_ptr, int num_ranks, int num_experts, std::set& pending_atomic_updates, @@ -1596,7 +1648,6 @@ void remote_process_completions_fast_mode( int value = aimm.GetValue(); uint32_t offset = aimm.GetOff(); size_t index = offset / sizeof(int); -#ifdef USE_RECEIVER_BARRIER // ep_config.hpp bool is_combine = aimm.IsCombine(); int low_latency_buffer_idx = aimm.GetBufferIdx(); @@ -1661,43 +1712,6 @@ void remote_process_completions_fast_mode( low_latency_buffer_idx, expert_idx, is_combine, src_rank}); } -#else - auto* addr32 = - reinterpret_cast*>(atomic_buffer_ptr) + index; -#ifdef USE_SENDER_BARRIER - if (aimm.IsCombine()) value = 1; -#ifndef EFA - const uint32_t tag = wr_tag(cqe.wr_id); - ProxyCtx& S_atomic = *ctx_by_tag[tag]; - ibv_sge sge = { - .addr = reinterpret_cast(S_atomic.mr->addr), - .length = 1, - .lkey = S_atomic.mr->lkey, - }; - ibv_recv_wr rwr = {}; - S.pool_index = (S.pool_index + 1) % (kRemoteBufferSize / kObjectSize - 1); - rwr.wr_id = make_wr_id(wr_tag(cqe.wr_id), S.pool_index); - rwr.sg_list = &sge; - rwr.num_sge = 1; - ibv_recv_wr* bad = nullptr; - if (ibv_post_recv(S_atomic.qp, &rwr, &bad)) { - perror("ibv_post_recv (atomics replenish)"); - std::abort(); - } - continue; -#endif -#endif - if (value == 
kMaxSendAtomicValue) value = kLargeAtomicValue; - bool is_combine = aimm.IsCombine(); - assert(!is_combine || value >= 0); - if (is_combine) { - assert(value >= 0 && "Combine atomic value should be non-negative"); - } else { - assert(value <= -1 && "Dispatch atomic value should be <= -1"); - } - if (is_combine) value = 1; - addr32->fetch_add(value, std::memory_order_release); -#endif } else if (cqe.opcode == IBV_WC_RECV_RDMA_WITH_IMM && ImmType::IsBarrier(ntohl(cqe.imm_data))) { BarrierImm bimm(ntohl(cqe.imm_data)); @@ -1734,7 +1748,6 @@ void remote_process_completions_fast_mode( } } else if (cqe.opcode == IBV_WC_RECV_RDMA_WITH_IMM && ImmType::IsWrite(ntohl(cqe.imm_data))) { -#ifdef USE_RECEIVER_BARRIER uint32_t imm = ntohl(cqe.imm_data); WriteImm wimm(imm); bool is_combine = wimm.IsCombine(); @@ -1752,48 +1765,11 @@ void remote_process_completions_fast_mode( expert_idx < (src_rank + 1) * (num_experts / num_ranks)); S.combine_token_counter.Add({buffer_idx, expert_idx}, k); } -#endif } else if (cqe.opcode == IBV_WC_RECV_RDMA_WITH_IMM) { fprintf(stderr, "Unexpected CQE opcode: %d\n", cqe.opcode); std::abort(); } -#ifndef EFA - if (cqe.opcode == IBV_WC_RECV_RDMA_WITH_IMM) { - uint32_t const tag = wr_tag(cqe.wr_id); - if (tag >= ctx_by_tag.size() || ctx_by_tag[tag] == nullptr) { - fprintf(stderr, "Invalid tag or uninitialized context for tag=%u\n", - tag); - std::abort(); - } - ProxyCtx& S = *ctx_by_tag[tag]; - ibv_qp* qp = qp_from_qpnum(S, cqe.qp_num); - if (!qp) { - fprintf(stderr, "No matching QP for qp_num=0x%x (tag=%u)\n", cqe.qp_num, - tag); - std::abort(); - } - ibv_sge sge = { - .addr = reinterpret_cast(&S.ack_recv_buf[0]), - .length = sizeof(uint64_t), - .lkey = S.ack_recv_mr->lkey, - }; - ibv_recv_wr rwr{}; - S.pool_index = (S.pool_index + 1) % (kRemoteBufferSize / kObjectSize - 1); - rwr.wr_id = make_wr_id(wr_tag(cqe.wr_id), S.pool_index); - rwr.sg_list = &sge; - rwr.num_sge = 1; - ibv_recv_wr* bad = nullptr; - int ret = ibv_post_recv(qp, &rwr, &bad); - if (ret) { - fprintf(stderr, - "ibv_post_recv (imm replenish) failed on qp=0x%x: %s (%d)\n", - qp->qp_num, strerror(ret), ret); - if (bad) - fprintf(stderr, " bad wr_id=%llu\n", (unsigned long long)bad->wr_id); - std::abort(); - } - } -#endif + replenish_recv_buffer_if_needed(cqe, ctx_by_tag); } } @@ -1801,13 +1777,13 @@ void remote_process_completions( ProxyCtx& S, int idx, CopyRingBuffer& g_ring, int ne, ibv_wc* wc, std::vector& ctx_by_tag, void* atomic_buffer_ptr, int num_ranks, int num_experts, std::set& pending_atomic_updates, - int my_rank, int num_nodes, bool use_normal_mode) { - if (use_normal_mode) { - remote_process_completions_normal_mode( + int my_rank, int num_nodes, bool use_throughput_mode) { + if (use_throughput_mode) { + remote_process_completions_throughput_mode( S, idx, g_ring, ne, wc, ctx_by_tag, atomic_buffer_ptr, num_ranks, num_experts, pending_atomic_updates, my_rank, num_nodes); } else { - remote_process_completions_fast_mode( + remote_process_completions_latency_mode( S, idx, g_ring, ne, wc, ctx_by_tag, atomic_buffer_ptr, num_ranks, num_experts, pending_atomic_updates, my_rank, num_nodes); } @@ -1818,7 +1794,8 @@ void remote_poll_completions(ProxyCtx& S, int idx, CopyRingBuffer& g_ring, void* atomic_buffer_ptr, int num_ranks, int num_experts, std::set& pending_atomic_updates, - int my_rank, int num_nodes, bool use_normal_mode) { + int my_rank, int num_nodes, + bool use_throughput_mode) { ibv_wc wc[kMaxOutstandingRecvs]; auto poll_one = [&](ibv_cq* cq) { int ne = poll_cq_once(cq, wc, kMaxOutstandingRecvs); @@ 
-1826,17 +1803,16 @@ void remote_poll_completions(ProxyCtx& S, int idx, CopyRingBuffer& g_ring, remote_process_completions(S, idx, g_ring, ne, wc, ctx_by_tag, atomic_buffer_ptr, num_ranks, num_experts, pending_atomic_updates, my_rank, num_nodes, - use_normal_mode); + use_throughput_mode); } }; if (S.cq) poll_one(S.cq); - // for (auto* cq : S.extra_cqs) poll_one(cq); } void remote_reg_ack_buf(ibv_pd* pd, uint64_t* ack_buf, ibv_mr*& ack_mr) { if (ack_mr) return; ack_mr = ibv_reg_mr(pd, ack_buf, sizeof(uint64_t) * RECEIVER_BATCH_SIZE, - IBV_ACCESS_LOCAL_WRITE); // host-only + IBV_ACCESS_LOCAL_WRITE); if (!ack_mr) { perror("ibv_reg_mr(ack_buf)"); @@ -1844,73 +1820,6 @@ void remote_reg_ack_buf(ibv_pd* pd, uint64_t* ack_buf, ibv_mr*& ack_mr) { } } -void remote_send_ack(ProxyCtx* ctx, struct ibv_qp* ack_qp, uint64_t& wr_id, - ibv_mr* local_ack_mr, uint64_t* ack_buf, int worker_idx) { - assert(false && "ACK is disabled"); - if (!ack_qp || !local_ack_mr) { - if (!ack_qp) { - fprintf(stderr, "QP not initialised\n"); - std::abort(); - } - if (!local_ack_mr) { - fprintf(stderr, "ACK MR not initialised\n"); - std::abort(); - } - fprintf(stderr, "ACK resources not initialised\n"); - std::abort(); - } - - *reinterpret_cast(ack_buf) = wr_id; - ibv_sge sge = { - .addr = reinterpret_cast(ack_buf), - .length = sizeof(uint64_t), - .lkey = local_ack_mr->lkey, - }; - -#ifdef EFA - auto qpx = (struct ibv_qp_ex*)ack_qp; - ibv_wr_start(qpx); - - qpx->wr_flags = IBV_SEND_SIGNALED; - qpx->wr_id = wr_id; - - ibv_wr_send_imm(qpx, htonl(static_cast(wr_id))); - ibv_wr_set_ud_addr(qpx, ctx->dst_ah, ctx->dst_ack_qpn, QKEY); - ibv_wr_set_sge(qpx, sge.lkey, sge.addr, sge.length); - - auto ret = ibv_wr_complete(qpx); - if (ret) { - fprintf(stderr, "ibv_wr_complete(SEND_WITH_IMM) failed: %d (%s)\n", ret, - strerror(ret)); - std::abort(); - } - -#else - ibv_send_wr wr = {}; - ibv_send_wr* bad = nullptr; - wr.wr_id = wr_id; - wr.sg_list = &sge; - wr.num_sge = 1; - wr.opcode = IBV_WR_SEND_WITH_IMM; - wr.send_flags = IBV_SEND_SIGNALED; // generate a CQE - wr.imm_data = htonl(static_cast(wr_id)); - - int ret = ibv_post_send(ack_qp, &wr, &bad); - - if (ret) { // ret is already an errno value - fprintf(stderr, "ibv_post_send(SEND_WITH_IMM) failed: %d (%s)\n", ret, - strerror(ret)); // strerror(ret) gives the text - if (bad) { - fprintf(stderr, - " first bad WR: wr_id=%llu opcode=%u addr=0x%llx lkey=0x%x\n", - (unsigned long long)bad->wr_id, bad->opcode, - (unsigned long long)bad->sg_list[0].addr, bad->sg_list[0].lkey); - } - std::abort(); - } -#endif -} - void local_post_ack_buf(ProxyCtx& S, int depth) { if (!S.pd || !S.recv_ack_qp) { fprintf(stderr, @@ -1944,8 +1853,181 @@ void local_post_ack_buf(ProxyCtx& S, int depth) { } } -// Normal mode implementation -static void post_atomic_operations_normal_mode( +static void efa_atomic_handle_ring_idx_throughput_mode( + ProxyCtx* ctx, int dst_rank, int my_rank, size_t ring_idx_raw, + std::vector const& idxs, std::vector const& wrs_to_post, + std::vector const& cmds_to_post) { +#ifdef EFA + size_t const local_ring_count = ctx->data_qps_by_channel.size(); + struct ibv_qp_ex* qpx = + (struct ibv_qp_ex*)(local_ring_count + ? ctx->data_qps_by_channel[ring_idx_raw % + local_ring_count] + : ctx->ack_qp); + size_t const remote_ring_count = ctx->dst_data_qpn_by_ring.size(); + uint32_t const dst_qpn = + remote_ring_count + ? 
ctx->dst_data_qpn_by_ring[ring_idx_raw % remote_ring_count] + : ctx->dst_qpn; + ibv_wr_start(qpx); + std::vector group_wrids; + group_wrids.reserve(idxs.size()); + + for (size_t t = 0; t < idxs.size(); ++t) { + size_t i = idxs[t]; + auto const& cmd = cmds_to_post[i]; + auto wr_id = wrs_to_post[i]; + group_wrids.push_back(wr_id); + + int v = static_cast(cmd.value); + if (v > kLargeAtomicValue) { + v = kMaxSendAtomicValue; + } + if (v < -kMaxSendAtomicValue || v > kMaxSendAtomicValue) { + fprintf(stderr, + "[EFA] value=%d (cmd.value: %lu) won't fit in 15 bits; " + "use an inline payload scheme instead.\n", + v, (unsigned long)cmd.value); + std::abort(); + } + + uint32_t offset = static_cast(cmd.req_rptr); + int low_latency_buffer_idx = get_low_latency(cmd.cmd_type); + if (low_latency_buffer_idx < 0 || low_latency_buffer_idx > 1) { + fprintf(stderr, "Invalid low_latency_buffer_idx: %d\n", + low_latency_buffer_idx); + std::abort(); + } + + uint32_t imm = AtomicsImm::PackAtomic(v, offset).GetImmData(); + + qpx->wr_id = kAtomicWrTag | (wr_id & kAtomicMask); + qpx->comp_mask = 0; + qpx->wr_flags = IBV_SEND_SIGNALED; + + ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, ctx->remote_addr, htonl(imm)); + ibv_wr_set_ud_addr(qpx, ctx->dst_ah, dst_qpn, QKEY); + ibv_wr_set_sge(qpx, ctx->mr->lkey, (uintptr_t)ctx->mr->addr, 0); + } + + int ret = ibv_wr_complete(qpx); + if (ret) { + fprintf(stderr, "[EFA] post_send failed: %s (ret=%d)\n", strerror(ret), + ret); + std::abort(); + } +#else + (void)ctx; + (void)dst_rank; + (void)my_rank; + (void)ring_idx_raw; + (void)idxs; + (void)wrs_to_post; + (void)cmds_to_post; +#endif +} + +static void non_efa_atomic_handle_ring_idx_throughput_mode( + ProxyCtx& S, ProxyCtx* ctx, int dst_rank, int my_rank, size_t ring_idx_raw, + std::vector const& idxs, std::vector const& wrs_to_post, + std::vector const& cmds_to_post, int thread_idx) { +#ifndef EFA + size_t const local_ring_count = ctx->data_qps_by_channel.size(); + struct ibv_qp* qp = + local_ring_count + ? ctx->data_qps_by_channel[ring_idx_raw % local_ring_count] + : ctx->ack_qp; + + size_t const k = idxs.size(); + std::vector sge(k); + std::vector wr(k); + std::vector group_wrids; + group_wrids.reserve(k); + + for (size_t t = 0; t < k; ++t) { + size_t i = idxs[t]; + auto const& cmd = cmds_to_post[i]; + uint64_t const wr_id = wrs_to_post[i]; + group_wrids.push_back(wr_id); + + int v = static_cast(cmd.value); + if (v > kLargeAtomicValue) v = kMaxSendAtomicValue; + if (v < -kMaxSendAtomicValue || v > kMaxSendAtomicValue) { + fprintf(stderr, + "value=%d (cmd.value=%lu) won't fit in 15 bits for imm; " + "use a different scheme.\n", + v, (unsigned long)cmd.value); + std::abort(); + } + + uint32_t off16 = static_cast(cmd.req_rptr) & 0xFFFFu; + int low_latency_buffer_idx = get_low_latency(cmd.cmd_type); + if (low_latency_buffer_idx < 0 || low_latency_buffer_idx > 1) { + fprintf(stderr, "Invalid low_latency_buffer_idx: %d\n", + low_latency_buffer_idx); + std::abort(); + } + uint32_t imm = AtomicsImm::Pack( + /*is_atomic*/ true, + /*is_combine*/ get_is_combine(cmd.cmd_type), v, + /*offset*/ off16, low_latency_buffer_idx) + .GetImmData(); + + sge[t].addr = reinterpret_cast(ctx->mr->addr); + sge[t].length = 0; + sge[t].lkey = ctx->mr->lkey; + + std::memset(&wr[t], 0, sizeof(wr[t])); + wr[t].wr_id = kAtomicWrTag | (wr_id & kAtomicMask); + wr[t].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; + wr[t].send_flags = (t + 1 == k) ? 
IBV_SEND_SIGNALED : 0; + wr[t].imm_data = htonl(imm); + wr[t].sg_list = &sge[t]; + wr[t].num_sge = 1; + wr[t].wr.rdma.remote_addr = ctx->remote_addr; + wr[t].wr.rdma.rkey = ctx->remote_rkey; + wr[t].next = (t + 1 < k) ? &wr[t + 1] : nullptr; + } + + ibv_send_wr* bad = nullptr; + int ret = ibv_post_send(qp, &wr[0], &bad); + if (ret) { + fprintf(stderr, "[RC] post_send(atomic imm) failed: %s (ret=%d)\n", + strerror(ret), ret); + if (bad) { + fprintf(stderr, " bad wr_id=0x%llx opcode=%u\n", + (unsigned long long)bad->wr_id, bad->opcode); + } + std::abort(); + } + uint64_t const batch_tail_wr = group_wrids.back(); + { + auto [it, inserted] = + S.wr_id_to_wr_ids.try_emplace(batch_tail_wr, std::move(group_wrids)); + if (!inserted) { + fprintf(stderr, + "thread_idx: %d, Error: tail wr_id %lu already exists " + "(map=%p, " + "size=%zu, dst_rank=%d)\n", + thread_idx, batch_tail_wr, (void*)&S.wr_id_to_wr_ids, + S.wr_id_to_wr_ids.size(), dst_rank); + std::abort(); + } + } +#else + (void)S; + (void)ctx; + (void)dst_rank; + (void)my_rank; + (void)ring_idx_raw; + (void)idxs; + (void)wrs_to_post; + (void)cmds_to_post; + (void)thread_idx; +#endif +} + +static void post_atomic_operations_throughput_mode( ProxyCtx& S, std::vector const& wrs_to_post, std::vector const& cmds_to_post, std::vector>& ctxs, int my_rank, int thread_idx, @@ -1972,7 +2054,6 @@ static void post_atomic_operations_normal_mode( ProxyCtx* ctx = ctxs[dst_rank].get(); size_t const k = wr_ids.size(); - // Group by ring index (upper 32 bits in wrs_to_post) std::unordered_map> ring_to_indices; ring_to_indices.reserve(k); for (size_t ii = 0; ii < k; ++ii) { @@ -1983,161 +2064,139 @@ static void post_atomic_operations_normal_mode( } for (auto& [ring_idx_raw, idxs] : ring_to_indices) { - size_t const local_ring_count = ctx->data_qps_by_channel.size(); #ifdef EFA - struct ibv_qp_ex* qpx = - (struct ibv_qp_ex*)(local_ring_count - ? ctx->data_qps_by_channel[ring_idx_raw % - local_ring_count] - : ctx->ack_qp); - size_t const remote_ring_count = ctx->dst_data_qpn_by_ring.size(); - uint32_t const dst_qpn = - remote_ring_count - ? 
ctx->dst_data_qpn_by_ring[ring_idx_raw % remote_ring_count] - : ctx->dst_qpn; - ibv_wr_start(qpx); - - // Build the chain - std::vector group_wrids; - group_wrids.reserve(idxs.size()); - - for (size_t t = 0; t < idxs.size(); ++t) { - size_t i = idxs[t]; - auto const& cmd = cmds_to_post[i]; - auto wr_id = wrs_to_post[i]; - group_wrids.push_back(wr_id); - - int v = static_cast(cmd.value); - if (v > kLargeAtomicValue) { - // Sender-side saturation to fit imm payload - v = kMaxSendAtomicValue; - } - if (v < -kMaxSendAtomicValue || v > kMaxSendAtomicValue) { - fprintf(stderr, - "[EFA] value=%d (cmd.value: %lu) won't fit in 15 bits; " - "use an inline payload scheme instead.\n", - v, (unsigned long)cmd.value); - std::abort(); - } - - uint32_t offset = static_cast(cmd.req_rptr); - int low_latency_buffer_idx = get_low_latency(cmd.cmd_type); - if (low_latency_buffer_idx < 0 || low_latency_buffer_idx > 1) { - fprintf(stderr, "Invalid low_latency_buffer_idx: %d\n", - low_latency_buffer_idx); - std::abort(); - } - - uint32_t imm = AtomicsImm::PackAtomic(v, offset).GetImmData(); + efa_atomic_handle_ring_idx_throughput_mode(ctx, dst_rank, my_rank, + ring_idx_raw, idxs, + wrs_to_post, cmds_to_post); +#else + non_efa_atomic_handle_ring_idx_throughput_mode( + S, ctx, dst_rank, my_rank, ring_idx_raw, idxs, wrs_to_post, + cmds_to_post, thread_idx); +#endif + } + } +} - qpx->wr_id = kAtomicWrTag | (wr_id & kAtomicMask); - qpx->comp_mask = 0; - qpx->wr_flags = IBV_SEND_SIGNALED; +static void efa_atomic_handle_latency_mode( + ProxyCtx* ctx, size_t k, std::vector& wr_ids, + std::vector const& wrs_to_post, + std::vector const& cmds_to_post) { +#ifdef EFA + struct ibv_qp_ex* qpx = (struct ibv_qp_ex*)ctx->qp; + ibv_wr_start(qpx); + for (size_t i = 0; i < k; ++i) { + auto const& cmd = cmds_to_post[wr_ids[i]]; + auto wr_id = wrs_to_post[wr_ids[i]]; + wr_ids[i] = wr_id; + + int v = static_cast(cmd.value); + if (v == kLargeAtomicValue) v = kMaxSendAtomicValue; + if (v < -kMaxSendAtomicValue || v > kMaxSendAtomicValue) { + fprintf(stderr, + "[EFA] value=%d (cmd.value: %lu) won't fit in 15 bits; " + "use an inline payload scheme instead.\n", + v, (unsigned long)cmd.value); + std::abort(); + } + uint32_t offset = static_cast(cmd.req_rptr); + int low_latency_buffer_idx = get_low_latency(cmd.cmd_type); + if (low_latency_buffer_idx < 0 || low_latency_buffer_idx > 1) { + fprintf(stderr, "Invalid low_latency_buffer_idx: %d\n", + low_latency_buffer_idx); + std::abort(); + } + uint32_t imm = AtomicsImm::Pack(true, get_is_combine(cmd.cmd_type), v, + offset, low_latency_buffer_idx) + .GetImmData(); - ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, ctx->remote_addr, - htonl(imm)); - ibv_wr_set_ud_addr(qpx, ctx->dst_ah, dst_qpn, QKEY); - ibv_wr_set_sge(qpx, ctx->mr->lkey, (uintptr_t)ctx->mr->addr, 0); - } + qpx->wr_id = kAtomicWrTag | (wr_id & kAtomicMask); + qpx->comp_mask = 0; + qpx->wr_flags = IBV_SEND_SIGNALED; + ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, ctx->remote_addr, htonl(imm)); - int ret = ibv_wr_complete(qpx); - if (ret) { - fprintf(stderr, "[EFA] post_send failed: %s (ret=%d)\n", strerror(ret), - ret); - std::abort(); - } + ibv_wr_set_ud_addr(qpx, ctx->dst_ah, ctx->dst_qpn, QKEY); + ibv_wr_set_sge(qpx, ctx->mr->lkey, (uintptr_t)ctx->mr->addr, 0); + } + int ret = ibv_wr_complete(qpx); + if (ret) { + fprintf(stderr, "[EFA] post_send failed: %s (ret=%d)\n", strerror(ret), + ret); + std::abort(); + } #else - struct ibv_qp* qp = - local_ring_count - ? 
ctx->data_qps_by_channel[ring_idx_raw % local_ring_count] - : ctx->ack_qp; - - size_t const k = idxs.size(); - std::vector sge(k); - std::vector wr(k); - std::vector group_wrids; - group_wrids.reserve(k); - - for (size_t t = 0; t < k; ++t) { - size_t i = idxs[t]; - auto const& cmd = cmds_to_post[i]; - uint64_t const wr_id = wrs_to_post[i]; - group_wrids.push_back(wr_id); - - int v = static_cast(cmd.value); - if (v > kLargeAtomicValue) v = kMaxSendAtomicValue; // saturate for imm - if (v < -kMaxSendAtomicValue || v > kMaxSendAtomicValue) { - fprintf(stderr, - "value=%d (cmd.value=%lu) won't fit in 15 bits for imm; " - "use a different scheme.\n", - v, (unsigned long)cmd.value); - std::abort(); - } - - // If your AtomicsImm for non-EFA expects 16-bit offsets, keep the - // mask: - uint32_t off16 = static_cast(cmd.req_rptr) & 0xFFFFu; - int low_latency_buffer_idx = get_low_latency(cmd.cmd_type); - if (low_latency_buffer_idx < 0 || low_latency_buffer_idx > 1) { - fprintf(stderr, "Invalid low_latency_buffer_idx: %d\n", - low_latency_buffer_idx); - std::abort(); - } - uint32_t imm = AtomicsImm::Pack( - /*is_atomic*/ true, - /*is_combine*/ get_is_combine(cmd.cmd_type), v, - /*offset*/ off16, low_latency_buffer_idx) - .GetImmData(); - - // Zero-length write-with-imm on RC QP - sge[t].addr = reinterpret_cast(ctx->mr->addr); - sge[t].length = 0; - sge[t].lkey = ctx->mr->lkey; - - std::memset(&wr[t], 0, sizeof(wr[t])); - wr[t].wr_id = kAtomicWrTag | (wr_id & kAtomicMask); - wr[t].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; - wr[t].send_flags = (t + 1 == k) ? IBV_SEND_SIGNALED : 0; - wr[t].imm_data = htonl(imm); - wr[t].sg_list = &sge[t]; - wr[t].num_sge = 1; - wr[t].wr.rdma.remote_addr = ctx->remote_addr; - wr[t].wr.rdma.rkey = ctx->remote_rkey; - wr[t].next = (t + 1 < k) ? 
&wr[t + 1] : nullptr; - } - - ibv_send_wr* bad = nullptr; - int ret = ibv_post_send(qp, &wr[0], &bad); - if (ret) { - fprintf(stderr, "[RC] post_send(atomic imm) failed: %s (ret=%d)\n", - strerror(ret), ret); - if (bad) { - fprintf(stderr, " bad wr_id=0x%llx opcode=%u\n", - (unsigned long long)bad->wr_id, bad->opcode); - } - std::abort(); - } - uint64_t const batch_tail_wr = group_wrids.back(); - { - auto [it, inserted] = S.wr_id_to_wr_ids.try_emplace( - batch_tail_wr, std::move(group_wrids)); - if (!inserted) { - fprintf(stderr, - "thread_idx: %d, Error: tail wr_id %lu already exists " - "(map=%p, " - "size=%zu, dst_rank=%d)\n", - thread_idx, batch_tail_wr, (void*)&S.wr_id_to_wr_ids, - S.wr_id_to_wr_ids.size(), dst_rank); - std::abort(); - } - } + (void)ctx; + (void)k; + (void)wr_ids; + (void)wrs_to_post; + (void)cmds_to_post; #endif +} + +static void non_efa_atomic_handle_latency_mode( + ProxyCtx* ctx, size_t k, std::vector& wr_ids, + std::vector const& wrs_to_post, + std::vector const& cmds_to_post) { +#ifndef EFA + std::vector sge(k); + std::vector wr(k); + + for (size_t i = 0; i < k; ++i) { + auto const& cmd = cmds_to_post[wr_ids[i]]; + uint64_t const wrid = wrs_to_post[wr_ids[i]]; + wr_ids[i] = wrid; + + int v = static_cast(cmd.value); + if (v == kLargeAtomicValue) v = kMaxSendAtomicValue; + if (v < -kMaxSendAtomicValue || v > kMaxSendAtomicValue) { + fprintf(stderr, "value=%d won't fit in 15 bits\n", v); + std::abort(); + } + uint32_t const off16 = static_cast(cmd.req_rptr) & 0xFFFFu; + int low_latency_buffer_idx = get_low_latency(cmd.cmd_type); + if (low_latency_buffer_idx < 0 || low_latency_buffer_idx > 1) { + fprintf(stderr, "Invalid low_latency_buffer_idx: %d\n", + low_latency_buffer_idx); + std::abort(); + } + uint32_t const imm = AtomicsImm::Pack(true, get_is_combine(cmd.cmd_type), v, + off16, low_latency_buffer_idx) + .GetImmData(); + sge[i].addr = reinterpret_cast(ctx->mr->addr); + sge[i].length = 0; + sge[i].lkey = ctx->mr->lkey; + + std::memset(&wr[i], 0, sizeof(wr[i])); + wr[i].wr_id = kAtomicWrTag | (wrid & kAtomicMask); + wr[i].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; + wr[i].send_flags = (i + 1 == k) ? IBV_SEND_SIGNALED : 0; + wr[i].imm_data = htonl(imm); + wr[i].sg_list = &sge[i]; + wr[i].num_sge = 1; + wr[i].wr.rdma.remote_addr = ctx->remote_addr; + wr[i].wr.rdma.rkey = ctx->remote_rkey; + wr[i].next = (i + 1 < k) ? 
&wr[i + 1] : nullptr; + } + { + ibv_send_wr* bad = nullptr; + int ret = ibv_post_send(ctx->qp, &wr[0], &bad); + if (ret) { + fprintf(stderr, "ibv_post_send(atomic) failed: %d (%s)\n", ret, + strerror(ret)); + if (bad) + fprintf(stderr, " bad wr_id=0x%llx\n", (unsigned long long)bad->wr_id); + std::abort(); } } +#else + (void)ctx; + (void)k; + (void)wr_ids; + (void)wrs_to_post; + (void)cmds_to_post; +#endif } -// Fast mode implementation -static void post_atomic_operations_fast_mode( +static void post_atomic_operations_latency_mode( ProxyCtx& S, std::vector const& wrs_to_post, std::vector const& cmds_to_post, std::vector>& ctxs, int my_rank, int thread_idx, @@ -2165,100 +2224,10 @@ static void post_atomic_operations_fast_mode( ProxyCtx* ctx = ctxs[dst_rank].get(); size_t const k = wr_ids.size(); #ifdef EFA - struct ibv_qp_ex* qpx = (struct ibv_qp_ex*)ctx->qp; - ibv_wr_start(qpx); - for (size_t i = 0; i < k; ++i) { - auto const& cmd = cmds_to_post[wr_ids[i]]; - auto wr_id = wrs_to_post[wr_ids[i]]; - wr_ids[i] = wr_id; - - int v = static_cast(cmd.value); - if (v == kLargeAtomicValue) v = kMaxSendAtomicValue; - if (v < -kMaxSendAtomicValue || v > kMaxSendAtomicValue) { - fprintf(stderr, - "[EFA] value=%d (cmd.value: %lu) won't fit in 15 bits; " - "use an inline payload scheme instead.\n", - v, (unsigned long)cmd.value); - std::abort(); - } - uint32_t offset = static_cast(cmd.req_rptr); - int low_latency_buffer_idx = get_low_latency(cmd.cmd_type); - if (low_latency_buffer_idx < 0 || low_latency_buffer_idx > 1) { - fprintf(stderr, "Invalid low_latency_buffer_idx: %d\n", - low_latency_buffer_idx); - std::abort(); - } - uint32_t imm = AtomicsImm::Pack(true, get_is_combine(cmd.cmd_type), v, - offset, low_latency_buffer_idx) - .GetImmData(); - - qpx->wr_id = kAtomicWrTag | (wr_id & kAtomicMask); - qpx->comp_mask = 0; - qpx->wr_flags = IBV_SEND_SIGNALED; - ibv_wr_rdma_write_imm(qpx, ctx->remote_rkey, ctx->remote_addr, - htonl(imm)); - - ibv_wr_set_ud_addr(qpx, ctx->dst_ah, ctx->dst_qpn, QKEY); - ibv_wr_set_sge(qpx, ctx->mr->lkey, (uintptr_t)ctx->mr->addr, 0); - } - int ret = ibv_wr_complete(qpx); - if (ret) { - fprintf(stderr, "[EFA] post_send failed: %s (ret=%d)\n", strerror(ret), - ret); - std::abort(); - } + efa_atomic_handle_latency_mode(ctx, k, wr_ids, wrs_to_post, cmds_to_post); #else - std::vector sge(k); - std::vector wr(k); - - for (size_t i = 0; i < k; ++i) { - auto const& cmd = cmds_to_post[wr_ids[i]]; - uint64_t const wrid = wrs_to_post[wr_ids[i]]; - wr_ids[i] = wrid; - - int v = static_cast(cmd.value); - if (v == kLargeAtomicValue) v = kMaxSendAtomicValue; - if (v < -kMaxSendAtomicValue || v > kMaxSendAtomicValue) { - fprintf(stderr, "value=%d won't fit in 15 bits\n", v); - std::abort(); - } - uint32_t const off16 = static_cast(cmd.req_rptr) & 0xFFFFu; - int low_latency_buffer_idx = get_low_latency(cmd.cmd_type); - if (low_latency_buffer_idx < 0 || low_latency_buffer_idx > 1) { - fprintf(stderr, "Invalid low_latency_buffer_idx: %d\n", - low_latency_buffer_idx); - std::abort(); - } - uint32_t const imm = AtomicsImm::Pack(true, get_is_combine(cmd.cmd_type), - v, off16, low_latency_buffer_idx) - .GetImmData(); - sge[i].addr = reinterpret_cast(ctx->mr->addr); - sge[i].length = 0; - sge[i].lkey = ctx->mr->lkey; - - std::memset(&wr[i], 0, sizeof(wr[i])); - wr[i].wr_id = kAtomicWrTag | (wrid & kAtomicMask); - wr[i].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; - wr[i].send_flags = (i + 1 == k) ? 
IBV_SEND_SIGNALED : 0; - wr[i].imm_data = htonl(imm); - wr[i].sg_list = &sge[i]; - wr[i].num_sge = 1; - wr[i].wr.rdma.remote_addr = ctx->remote_addr; - wr[i].wr.rdma.rkey = ctx->remote_rkey; - wr[i].next = (i + 1 < k) ? &wr[i + 1] : nullptr; - } - { - ibv_send_wr* bad = nullptr; - int ret = ibv_post_send(ctx->qp, &wr[0], &bad); - if (ret) { - fprintf(stderr, "ibv_post_send(atomic) failed: %d (%s)\n", ret, - strerror(ret)); - if (bad) - fprintf(stderr, " bad wr_id=0x%llx\n", - (unsigned long long)bad->wr_id); - std::abort(); - } - } + non_efa_atomic_handle_latency_mode(ctx, k, wr_ids, wrs_to_post, + cmds_to_post); #endif uint64_t const batch_tail_wr = wr_ids.back(); { @@ -2277,19 +2246,18 @@ static void post_atomic_operations_fast_mode( } } -// Wrapper that selects implementation based on use_normal_mode void post_atomic_operations(ProxyCtx& S, std::vector const& wrs_to_post, std::vector const& cmds_to_post, std::vector>& ctxs, int my_rank, int thread_idx, std::unordered_set& acked_wrs, - bool use_normal_mode) { - if (use_normal_mode) { - post_atomic_operations_normal_mode(S, wrs_to_post, cmds_to_post, ctxs, - my_rank, thread_idx, acked_wrs); + bool use_throughput_mode) { + if (use_throughput_mode) { + post_atomic_operations_throughput_mode(S, wrs_to_post, cmds_to_post, ctxs, + my_rank, thread_idx, acked_wrs); } else { - post_atomic_operations_fast_mode(S, wrs_to_post, cmds_to_post, ctxs, - my_rank, thread_idx, acked_wrs); + post_atomic_operations_latency_mode(S, wrs_to_post, cmds_to_post, ctxs, + my_rank, thread_idx, acked_wrs); } } diff --git a/ep/src/uccl_ep.cc b/ep/src/uccl_ep.cc index 5ba63839e..7c1c31916 100644 --- a/ep/src/uccl_ep.cc +++ b/ep/src/uccl_ep.cc @@ -12,7 +12,6 @@ #include "internode_ll.cuh" #include "intranode.cuh" #include "layout.hpp" -#include "peer_copy_manager.hpp" #include "ring_buffer.cuh" #include "uccl_bench.hpp" #include "uccl_proxy.hpp" @@ -1993,7 +1992,6 @@ PYBIND11_MODULE(ep, m) { for (auto& proxy : proxies) { vec.push_back(std::move(proxy)); } - printf("Registered proxies for device %d\n", device_index); }, py::arg("device_index"), py::arg("proxies")); m.def( @@ -2170,10 +2168,10 @@ PYBIND11_MODULE(ep, m) { py::arg("total_size"), py::arg("rank") = 0, py::arg("node_idx") = -1, py::arg("local_rank") = 0, py::arg("num_experts") = -1, py::arg("num_ranks") = -1, py::arg("num_nodes") = 0, - py::arg("use_normal_mode") = false, py::arg("is_intranode") = false) + py::arg("use_throughput_mode") = false, + py::arg("is_intranode") = false) .def("start_sender", &UcclProxy::start_sender) .def("start_remote", &UcclProxy::start_remote) - .def("start_local", &UcclProxy::start_local) .def("start_dual", &UcclProxy::start_dual) .def("stop", &UcclProxy::stop) .def("get_listen_port", &UcclProxy::get_listen_port) @@ -2291,16 +2289,6 @@ PYBIND11_MODULE(ep, m) { .def("print_summary", &Bench::print_summary) .def("print_summary_last", &Bench::print_summary_last) .def("last_elapsed_ms", &Bench::last_elapsed_ms); - py::class_(m, "PeerCopyManager") - .def(py::init(), py::arg("src_device") = 0) - .def("start_for_proxies", - [](PeerCopyManager& mgr, py::iterable proxy_list) { - std::vector vec; - for (py::handle h : proxy_list) - vec.push_back(h.cast()); - mgr.start_for_proxies(vec); - }) - .def("stop", &PeerCopyManager::stop); // MSCCLPP Fifo class - must be registered before BenchFifo which uses it py::class_(m, "Fifo").def(py::init(), diff --git a/ep/src/uccl_proxy.cpp b/ep/src/uccl_proxy.cpp index 53bbd00fa..5eab1f229 100644 --- a/ep/src/uccl_proxy.cpp +++ b/ep/src/uccl_proxy.cpp @@ 
-13,13 +13,11 @@ UcclProxy::UcclProxy(int thread_idx, uintptr_t gpu_buffer_addr, size_t total_size, int rank, int node_idx, int local_rank, int num_experts, int num_ranks, int num_nodes, - bool use_normal_mode, bool is_intranode) + bool use_throughput_mode, bool is_intranode) : thread_{}, mode_{Mode::None}, running_{false}, is_intranode_{is_intranode} { - // EP 8 of internode_ll also need atomic_buffer_ptr - Proxy::Config cfg{}; thread_idx_ = thread_idx; gpu_buffer_addr_ = reinterpret_cast(gpu_buffer_addr); @@ -48,7 +46,7 @@ UcclProxy::UcclProxy(int thread_idx, uintptr_t gpu_buffer_addr, cfg.num_experts = num_experts; cfg.num_ranks = num_ranks; cfg.num_nodes = num_nodes; - cfg.use_normal_mode = use_normal_mode; + cfg.use_throughput_mode = use_throughput_mode; cfg.is_intranode = is_intranode; proxy_ = std::make_unique(cfg); local_rank_ = local_rank; @@ -101,7 +99,6 @@ void UcclProxy::set_peers_meta(std::vector const& peers) { void UcclProxy::start_sender() { start(Mode::Sender); } void UcclProxy::start_remote() { start(Mode::Remote); } -void UcclProxy::start_local() { start(Mode::Local); } void UcclProxy::start_dual() { start(Mode::Dual); } void UcclProxy::stop() { @@ -127,7 +124,6 @@ void UcclProxy::start(Mode m) { thread_ = std::thread([this]() { if (is_intranode_) { std::printf("UcclProxy: no peer IP set, running in local mode\n"); - proxy_->run_local(); return; } switch (mode_) { @@ -137,9 +133,6 @@ void UcclProxy::start(Mode m) { case Mode::Remote: proxy_->run_remote(); break; - case Mode::Local: - proxy_->run_local(); - break; case Mode::Dual: proxy_->run_dual(); break; @@ -242,20 +235,6 @@ void FifoProxy::run_sender() { // Process completed work requests (similar to notify_gpu_completion) while (fifo_tail_acked < fifo_head_seen && proxy_->acked_wrs_.count(fifo_tail_acked) > 0) { -#ifdef MEASURE_PER_VERB_LATENCY - // Track latency - auto it = proxy_->wr_id_to_start_time_.find(fifo_tail_acked); - if (it != proxy_->wr_id_to_start_time_.end()) { - auto duration = std::chrono::duration_cast( - std::chrono::high_resolution_clock::now() - it->second); - if (proxy_->completion_count_ > kWarmupOps) { - proxy_->wr_time_total_us_ += duration.count(); - } - proxy_->completion_count_++; - proxy_->wr_id_to_start_time_.erase(it); - } -#endif - // Remove from tracking sets proxy_->acked_wrs_.erase(fifo_tail_acked); @@ -292,13 +271,6 @@ void FifoProxy::run_sender() { // Post immediately (no batching) std::vector wrs_to_post{fifo_head_seen}; std::vector cmds_to_post{cmd}; - -#ifdef MEASURE_PER_VERB_LATENCY - // Record timestamp for latency measurement (like original proxy) - proxy_->wr_id_to_start_time_[fifo_head_seen] = - std::chrono::high_resolution_clock::now(); -#endif - proxy_->post_gpu_commands_mixed(wrs_to_post, cmds_to_post); fifo_head_seen++; } @@ -310,19 +282,6 @@ void FifoProxy::run_sender() { while (fifo_tail_acked < fifo_head_seen && proxy_->acked_wrs_.count(fifo_tail_acked) > 0) { -#ifdef MEASURE_PER_VERB_LATENCY - auto it = proxy_->wr_id_to_start_time_.find(fifo_tail_acked); - if (it != proxy_->wr_id_to_start_time_.end()) { - auto duration = std::chrono::duration_cast( - std::chrono::high_resolution_clock::now() - it->second); - if (proxy_->completion_count_ > kWarmupOps) { - proxy_->wr_time_total_us_ += duration.count(); - } - proxy_->completion_count_++; - proxy_->wr_id_to_start_time_.erase(it); - } -#endif - proxy_->acked_wrs_.erase(fifo_tail_acked); fifo_->pop(); fifo_tail_acked++; diff --git a/include/util/net.h b/include/util/net.h index fcabfb1ed..ba69c2db7 100644 --- 
a/include/util/net.h
+++ b/include/util/net.h
@@ -1,5 +1,6 @@
 #pragma once
 
+#include "util.h"
 #include
 #include
 #include
diff --git a/include/util/util.h b/include/util/util.h
index 172b99166..07517cecc 100644
--- a/include/util/util.h
+++ b/include/util/util.h
@@ -762,7 +762,7 @@ static inline int get_dev_index(char const* dev_name) {
   return ret;
 }
 
-static inline std::string get_dev_ip(char const* dev_name) {
+inline std::string get_dev_ip(char const* dev_name) {
   struct ifaddrs* ifAddrStruct = NULL;
   struct ifaddrs* ifa = NULL;
   void* tmpAddrPtr = NULL;
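
Note on the atomic-immediate path touched in the `rdma.cpp` hunks above: the sender never issues a verbs atomic. It posts a zero-length `RDMA_WRITE_WITH_IMM` whose 32-bit immediate carries the whole update, saturating the signed value so it fits the field (the `kLargeAtomicValue` / `kMaxSendAtomicValue` clamp), and the receiver decodes the immediate on the `RECV_RDMA_WITH_IMM` completion and applies it with `fetch_add` at the encoded offset. The sketch below is illustrative only: the helper names (`pack_atomic_imm`, `apply_atomic_imm`) and the exact bit widths and field order are assumptions, not the repo's `AtomicsImm` layout (the real checks suggest roughly a 15-bit signed value; a 13-bit field is used here so the assumed layout sums to exactly 32 bits).

```cpp
// Illustrative sketch only (assumed layout, not the repo's AtomicsImm):
//   [31] is_atomic | [30] is_combine | [29] buffer_idx |
//   [28:16] 13-bit signed value (two's complement) | [15:0] 16-bit offset
#include <atomic>
#include <cstdint>
#include <cstdio>

constexpr int kValueBits = 13;
constexpr int32_t kValueMax = (1 << (kValueBits - 1)) - 1;  // +4095
constexpr int32_t kValueMin = -(1 << (kValueBits - 1));     // -4096

// Sender side: clamp the delta so it fits the field (mirrors the
// kLargeAtomicValue / kMaxSendAtomicValue saturation in the diff), then pack.
uint32_t pack_atomic_imm(bool is_combine, int buffer_idx, int32_t value,
                         uint16_t offset) {
  if (value > kValueMax) value = kValueMax;
  if (value < kValueMin) value = kValueMin;
  uint32_t v_field = static_cast<uint32_t>(value) & ((1u << kValueBits) - 1);
  return (1u << 31) | (static_cast<uint32_t>(is_combine) << 30) |
         (static_cast<uint32_t>(buffer_idx & 1) << 29) | (v_field << 16) |
         offset;
}

// Receiver side: decode the immediate from the completed RECV_RDMA_WITH_IMM
// and apply it to a counter slot. In the real proxy the offset addresses the
// registered atomic buffer region, not a plain array as in this demo.
void apply_atomic_imm(uint32_t imm, std::atomic<int>* counters) {
  if (!((imm >> 31) & 1)) return;  // not an atomic-style immediate
  uint16_t offset = static_cast<uint16_t>(imm & 0xFFFFu);
  int32_t value = static_cast<int32_t>((imm >> 16) & ((1u << kValueBits) - 1));
  if (value & (1 << (kValueBits - 1))) value -= (1 << kValueBits);  // sign-extend
  counters[offset].fetch_add(value, std::memory_order_release);
}

int main() {
  std::atomic<int> counters[8] = {};
  uint32_t imm = pack_atomic_imm(/*is_combine=*/false, /*buffer_idx=*/0,
                                 /*value=*/-3, /*offset=*/5);
  apply_atomic_imm(imm, counters);
  std::printf("counters[5] = %d\n", counters[5].load());  // -3
  return 0;
}
```

Chaining many such zero-length writes on one QP and signaling only the tail work request, as the batched non-EFA throughput path above does, is what keeps completion traffic to a single CQE per batch while the tail `wr_id` is used to retire the whole group.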