diff --git a/ep/bench/buffer.py b/ep/bench/buffer.py
index bf12b2d5..b49a528f 100644
--- a/ep/bench/buffer.py
+++ b/ep/bench/buffer.py
@@ -679,9 +679,6 @@ def dispatch(
 
         # Internode
         if self.runtime.get_num_rdma_ranks() > 1:
-            assert (
-                num_worst_tokens == 0
-            ), "Internode dispatch does not support `num_worst_tokens > 0`"
             return self.internode_dispatch(
                 x,
                 handle,
@@ -692,6 +689,7 @@ def dispatch(
                 topk_idx,
                 topk_weights,
                 expert_alignment,
+                num_worst_tokens,
                 config,
                 previous_event,
                 async_finish,
@@ -881,6 +879,7 @@ def internode_dispatch(
         topk_idx: Optional[torch.Tensor] = None,
         topk_weights: Optional[torch.Tensor] = None,
         expert_alignment: int = 1,
+        num_worst_tokens: int = 0,
         config: Optional[Config] = None,
         previous_event: Optional[EventOverlap] = None,
         async_finish: bool = False,
@@ -934,6 +933,7 @@ def internode_dispatch(
             gbl_channel_prefix_matrix,
             recv_gbl_rank_prefix_sum,
             expert_alignment,
+            num_worst_tokens,
             config,
             getattr(previous_event, "event", None),
             async_finish,
@@ -986,6 +986,7 @@ def internode_dispatch(
             None,
             None,
             expert_alignment,
+            num_worst_tokens,
             config,
             getattr(previous_event, "event", None),
             async_finish,
diff --git a/ep/bench/test_internode.py b/ep/bench/test_internode.py
index 3ea106ce..b7b9577d 100644
--- a/ep/bench/test_internode.py
+++ b/ep/bench/test_internode.py
@@ -252,6 +252,7 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):
         )
         if current_x is not x_pure_rand:
             check_data(recv_x, recv_gbl_rank_prefix_sum)
+        recv_topk_weights_clone = None
         if with_topk:
             # Check `topk_idx`
             assert (
@@ -265,6 +266,7 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):
                 assert recv_topk_idx.eq(i).sum().item() == count
 
             # Check `topk_weights`
+            recv_topk_weights_clone = recv_topk_weights.clone()
             if current_x is not x_pure_rand:
                 recv_topk_weights[recv_topk_idx.eq(-1)] = (
                     recv_topk_weights.amax(dim=1, keepdim=True).expand_as(
@@ -273,6 +275,40 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):
             )
             check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)
 
+        # Test `num_worst_tokens != 0`
+        if with_topk:
+            num_worst_tokens = num_tokens * num_ranks
+            dispatch_args.update({"num_worst_tokens": num_worst_tokens})
+            (
+                recv_worst_x,
+                recv_worst_topk_idx,
+                recv_worst_topk_weights,
+                empty_list,
+                _,
+                event,
+            ) = buffer.dispatch(**dispatch_args)
+            event.current_stream_wait() if async_mode else ()
+            recv_worst_x = (
+                per_token_cast_back(*recv_worst_x)
+                if isinstance(recv_worst_x, tuple)
+                else recv_worst_x
+            )
+            assert len(empty_list) == 0
+            assert num_worst_tokens == recv_worst_x.size(0)
+            assert num_worst_tokens == recv_worst_topk_idx.size(0)
+            assert num_worst_tokens == recv_worst_topk_weights.size(0)
+            assert torch.equal(recv_x, recv_worst_x[: recv_x.size(0)])
+            assert torch.equal(
+                recv_topk_idx, recv_worst_topk_idx[: recv_x.size(0)]
+            )
+            assert torch.equal(
+                recv_topk_weights_clone,
+                recv_worst_topk_weights[: recv_x.size(0)],
+            )
+            assert torch.all(
+                recv_worst_topk_idx[recv_x.size(0) :] == -1
+            ).item()
+
         # Test cached dispatch (must without top-k staffs)
         if not with_topk:
             dispatch_args = {
diff --git a/ep/bench/utils.py b/ep/bench/utils.py
index 732c2d15..72747ccc 100644
--- a/ep/bench/utils.py
+++ b/ep/bench/utils.py
@@ -89,53 +89,10 @@ def init_dist_under_torchrun(local_rank: int, num_local_ranks: int):
     )
 
 
-def _discover_local_ip():
-    # Try to infer the IP that can reach MASTER_ADDR (works in most clusters)
-    import socket, os
-
-    # Method 1: Use MASTER_ADDR if available (torchrun style)
-    if "MASTER_ADDR" in os.environ:
-        master = os.environ["MASTER_ADDR"]
-        port = int(os.environ.get("MASTER_PORT", "29500"))
-        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-        try:
-            s.connect((master, port))
-            return s.getsockname()[0]
-        except:
-            pass
-        finally:
-            s.close()
-
-    # Method 2: Use hostname resolution (works in AWS and most cloud environments)
-    hostname = socket.gethostname()
-    try:
-        # This usually returns the private IP in cloud environments
-        local_ip = socket.gethostbyname(hostname)
-        # Skip loopback addresses
-        if not local_ip.startswith("127."):
-            return local_ip
-    except:
-        pass
-
-    # Method 3: Connect to a public DNS to determine outgoing interface
-    try:
-        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-        # Google DNS - this doesn't actually send packets
-        s.connect(("8.8.8.8", 80))
-        local_ip = s.getsockname()[0]
-        s.close()
-        return local_ip
-    except:
-        pass
-
-    # Last resort: return localhost
-    return "127.0.0.1"
-
-
 def _gather_peer_ips(group):
     # Gather local IP strings across ranks
     world = dist.get_world_size(group)
-    my_ip = _discover_local_ip()
+    my_ip = ep.get_oob_ip()
     ips = [None] * world
     dist.all_gather_object(ips, my_ip, group=group)
     return ips
@@ -154,7 +111,7 @@ def get_peer_ip(rank: int, num_ranks: int, group: dist.ProcessGroup):
 
 
 def get_cpu_proxies_meta(proxies, rank, scratch_ptr, scratch_bytes, num_ranks, group):
-    my_ip = _discover_local_ip()
+    my_ip = ep.get_oob_ip()
     meta = {
         "rank": rank,
         "ptr": int(scratch_ptr),
diff --git a/ep/include/internode.cuh b/ep/include/internode.cuh
index 3df61ca4..a4f04d1a 100644
--- a/ep/include/internode.cuh
+++ b/ep/include/internode.cuh
@@ -31,8 +31,8 @@ void notify_dispatch(
     int const* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped,
     int const* num_tokens_per_expert, int* moe_recv_expert_counter_mapped,
     int num_experts, bool const* is_token_in_rank, int num_tokens,
-    int num_channels, int hidden_int4, int num_scales, int num_topk,
-    int expert_alignment, int* rdma_channel_prefix_matrix,
+    int num_worst_tokens, int num_channels, int hidden_int4, int num_scales,
+    int num_topk, int expert_alignment, int* rdma_channel_prefix_matrix,
     int* recv_rdma_rank_prefix_sum, int* gbl_channel_prefix_matrix,
     int* recv_gbl_rank_prefix_sum, void* rdma_buffer_ptr,
     int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs,
@@ -54,25 +54,23 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx,
                    uint64_t const* d2h_channel_addrs, int num_d2h_channel_addrs,
                    void* atomic_buffer_ptr);
 
-void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx,
-              float* recv_topk_weights, void* recv_src_meta, void const* x,
-              float const* x_scales, int64_t const* topk_idx,
-              float const* topk_weights, int* send_rdma_head,
-              int* send_nvl_head, int* recv_rdma_channel_prefix_matrix,
-              int* recv_gbl_channel_prefix_matrix,
-              int const* rdma_channel_prefix_matrix,
-              int const* recv_rdma_rank_prefix_sum,
-              int const* gbl_channel_prefix_matrix,
-              int const* recv_gbl_rank_prefix_sum, bool const* is_token_in_rank,
-              int num_tokens, int hidden_int4, int num_scales, int num_topk,
-              int num_experts, int scale_token_stride, int scale_hidden_stride,
-              void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
-              int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs,
-              int num_max_nvl_chunked_send_tokens,
-              int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks,
-              bool is_cached_dispatch, cudaStream_t stream, int num_channels,
-              bool low_latency_mode, uint64_t const* d2h_channel_addrs,
-              int num_d2h_channel_addrs, void* atomic_buffer_ptr);
+void dispatch(
+    void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx,
+    float* recv_topk_weights, void* recv_src_meta, void const* x,
+    float const* x_scales, int64_t const* topk_idx, float const* topk_weights,
+    int* send_rdma_head, int* send_nvl_head,
+    int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
+    int const* rdma_channel_prefix_matrix, int const* recv_rdma_rank_prefix_sum,
+    int const* gbl_channel_prefix_matrix, int const* recv_gbl_rank_prefix_sum,
+    bool const* is_token_in_rank, int num_tokens, int num_worst_tokens,
+    int hidden_int4, int num_scales, int num_topk, int num_experts,
+    int scale_token_stride, int scale_hidden_stride, void* rdma_buffer_ptr,
+    int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
+    void** buffer_ptrs, int num_max_nvl_chunked_send_tokens,
+    int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks,
+    bool is_cached_dispatch, cudaStream_t stream, int num_channels,
+    bool low_latency_mode, uint64_t const* d2h_channel_addrs,
+    int num_d2h_channel_addrs, void* atomic_buffer_ptr);
 
 void combine(cudaDataType_t type, void* combined_x,
              float* combined_topk_weights,
diff --git a/ep/include/rdma.hpp b/ep/include/rdma.hpp
index 76ddc062..6ab3c208 100644
--- a/ep/include/rdma.hpp
+++ b/ep/include/rdma.hpp
@@ -2,6 +2,11 @@
 #define RDMA_HPP
 #include "common.hpp"
 #include "proxy_ctx.hpp"
+// clang-format off
+// prevent clang-format reordering net.h before util.h
+#include "util/util.h"
+#include "util/net.h"
+// clang-format on
 #include "ring_buffer.cuh"
 #include "unistd.h"
 #include
diff --git a/ep/src/internode.cu b/ep/src/internode.cu
index 2388bd38..dd37ea1f 100644
--- a/ep/src/internode.cu
+++ b/ep/src/internode.cu
@@ -97,13 +97,14 @@ __global__ void notify_dispatch(
     int const* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped,
     int const* num_tokens_per_expert, int* moe_recv_expert_counter_mapped,
     int num_experts, bool const* is_token_in_rank, int num_tokens,
-    int num_channels, int expert_alignment, int const rdma_clean_offset,
-    int const rdma_num_int_clean, int const nvl_clean_offset,
-    int const nvl_num_int_clean, int* rdma_channel_prefix_matrix,
-    int* recv_rdma_rank_prefix_sum, int* gbl_channel_prefix_matrix,
-    int* recv_gbl_rank_prefix_sum, void* rdma_buffer_ptr, void** buffer_ptrs,
-    int** barrier_signal_ptrs, int rank, uint64_t const* d2h_channel_addrs,
-    int num_d2h_channel_addrs, void* atomic_buffer_ptr) {
+    int num_worst_tokens, int num_channels, int expert_alignment,
+    int const rdma_clean_offset, int const rdma_num_int_clean,
+    int const nvl_clean_offset, int const nvl_num_int_clean,
+    int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum,
+    int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum,
+    void* rdma_buffer_ptr, void** buffer_ptrs, int** barrier_signal_ptrs,
+    int rank, uint64_t const* d2h_channel_addrs, int num_d2h_channel_addrs,
+    void* atomic_buffer_ptr) {
   void* original_rdma_buffer_ptr = rdma_buffer_ptr;
   auto sm_id = static_cast<int>(blockIdx.x);
   auto thread_id = static_cast<int>(threadIdx.x),
@@ -273,9 +274,11 @@ __global__ void notify_dispatch(
           i)[NUM_MAX_NVL_PEERS + num_rdma_experts];
       recv_rdma_rank_prefix_sum[i] = sum;
     }
-    while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1)
-      ;
-    *moe_recv_rdma_counter_mapped = sum;
+    if (num_worst_tokens == 0) {
+      while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1)
+        ;
+      *moe_recv_rdma_counter_mapped = sum;
+    }
   }
 
   // Send numbers of tokens per rank/expert to NVL ranks
@@ -303,9 +306,11 @@ __global__ void notify_dispatch(
       sum += nvl_recv_num_tokens_per_rank.buffer(src_nvl_rank)[src_rdma_rank];
       recv_gbl_rank_prefix_sum[i] = sum;
     }
-    while (ld_volatile_global(moe_recv_counter_mapped) != -1)
-      ;
-    *moe_recv_counter_mapped = sum;
+    if (num_worst_tokens == 0) {
+      while (ld_volatile_global(moe_recv_counter_mapped) != -1)
+        ;
+      *moe_recv_counter_mapped = sum;
+    }
   }
   if (thread_id < num_nvl_experts) {
     int sum = 0;
@@ -313,10 +318,12 @@ __global__ void notify_dispatch(
     for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
       sum += nvl_recv_num_tokens_per_expert.buffer(i)[thread_id];
     sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment;
-    while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) !=
-           -1)
-      ;
-    moe_recv_expert_counter_mapped[thread_id] = sum;
+    if (num_worst_tokens == 0) {
+      while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) !=
+             -1)
+        ;
+      moe_recv_expert_counter_mapped[thread_id] = sum;
+    }
   }
 
   // Finally barrier
@@ -394,8 +401,8 @@ void notify_dispatch(
     int const* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped,
     int const* num_tokens_per_expert, int* moe_recv_expert_counter_mapped,
    int num_experts, bool const* is_token_in_rank, int num_tokens,
-    int num_channels, int hidden_int4, int num_scales, int num_topk,
-    int expert_alignment, int* rdma_channel_prefix_matrix,
+    int num_worst_tokens, int num_channels, int hidden_int4, int num_scales,
+    int num_topk, int expert_alignment, int* rdma_channel_prefix_matrix,
     int* recv_rdma_rank_prefix_sum, int* gbl_channel_prefix_matrix,
     int* recv_gbl_rank_prefix_sum, void* rdma_buffer_ptr,
     int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs,
@@ -403,23 +410,24 @@ void notify_dispatch(
     cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes,
     bool low_latency_mode, uint64_t const* d2h_channel_addrs,
     int num_d2h_channel_addrs, void* atomic_buffer_ptr) {
-#define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks) \
-  { \
-    auto notify_dispatch_func = low_latency_mode \
-                                    ? notify_dispatch<true, num_rdma_ranks> \
-                                    : notify_dispatch<false, num_rdma_ranks>; \
-    LAUNCH_KERNEL( \
-        &cfg, notify_dispatch_func, num_tokens_per_rank, \
-        moe_recv_counter_mapped, num_ranks, num_tokens_per_rdma_rank, \
-        moe_recv_rdma_counter_mapped, num_tokens_per_expert, \
-        moe_recv_expert_counter_mapped, num_experts, is_token_in_rank, \
-        num_tokens, num_channels, expert_alignment, rdma_clean_meta.first, \
-        rdma_clean_meta.second, nvl_clean_meta.first, nvl_clean_meta.second, \
-        rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \
-        gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, rdma_buffer_ptr, \
-        buffer_ptrs, barrier_signal_ptrs, rank, d2h_channel_addrs, \
-        num_d2h_channel_addrs, atomic_buffer_ptr); \
-  } \
+#define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks) \
+  { \
+    auto notify_dispatch_func = low_latency_mode \
+                                    ? notify_dispatch<true, num_rdma_ranks> \
+                                    : notify_dispatch<false, num_rdma_ranks>; \
+    LAUNCH_KERNEL(&cfg, notify_dispatch_func, num_tokens_per_rank, \
+                  moe_recv_counter_mapped, num_ranks, \
+                  num_tokens_per_rdma_rank, moe_recv_rdma_counter_mapped, \
+                  num_tokens_per_expert, moe_recv_expert_counter_mapped, \
+                  num_experts, is_token_in_rank, num_tokens, num_worst_tokens, \
+                  num_channels, expert_alignment, rdma_clean_meta.first, \
+                  rdma_clean_meta.second, nvl_clean_meta.first, \
+                  nvl_clean_meta.second, rdma_channel_prefix_matrix, \
+                  recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \
+                  recv_gbl_rank_prefix_sum, rdma_buffer_ptr, buffer_ptrs, \
+                  barrier_signal_ptrs, rank, d2h_channel_addrs, \
+                  num_d2h_channel_addrs, atomic_buffer_ptr); \
+  } \
   break
 
 #if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
@@ -475,8 +483,9 @@ __global__ void __launch_bounds__(
     int const* recv_rdma_rank_prefix_sum,
     int const* gbl_channel_prefix_matrix,
     int const* recv_gbl_rank_prefix_sum, bool const* is_token_in_rank,
-    int num_tokens, int hidden_int4, int num_scales, int num_topk,
-    int num_experts, int scale_token_stride, int scale_hidden_stride,
+    int num_tokens, int num_worst_tokens, int hidden_int4,
+    int num_scales, int num_topk, int num_experts,
+    int scale_token_stride, int scale_hidden_stride,
     void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
     int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs,
     int num_max_nvl_chunked_send_tokens,
@@ -1476,27 +1485,42 @@ __global__ void __launch_bounds__(
               cached_channel_head_idx);
     }
   }
+
+  // Clean unused `recv_topk_idx` as -1
+  if (num_worst_tokens > 0) {
+    if (is_forwarder) return;
+    // Get the actual number of received tokens on the current rank
+    int num_recv_tokens = recv_gbl_rank_prefix_sum[num_ranks - 1];
+    // Some ForwarderCoordinator warps exit early, so only the non-forwarder
+    // SMs take part in this clean-up; `channel_id * num_threads` is the
+    // starting offset of the current non-forwarder SM
+    auto const clean_start =
+        num_recv_tokens * num_topk + channel_id * num_threads;
+    auto const clean_end = num_worst_tokens * num_topk;
+    auto const clean_stride = num_channels * num_threads;
+#pragma unroll
+    for (int i = clean_start + thread_id; i < clean_end; i += clean_stride)
+      recv_topk_idx[i] = -1;
+  }
 }
 
-void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx,
-              float* recv_topk_weights, void* recv_src_meta, void const* x,
-              float const* x_scales, int64_t const* topk_idx,
-              float const* topk_weights, int* send_rdma_head,
-              int* send_nvl_head, int* recv_rdma_channel_prefix_matrix,
-              int* recv_gbl_channel_prefix_matrix,
-              int const* rdma_channel_prefix_matrix,
-              int const* recv_rdma_rank_prefix_sum,
-              int const* gbl_channel_prefix_matrix,
-              int const* recv_gbl_rank_prefix_sum, bool const* is_token_in_rank,
-              int num_tokens, int hidden_int4, int num_scales, int num_topk,
-              int num_experts, int scale_token_stride, int scale_hidden_stride,
-              void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
-              int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs,
-              int num_max_nvl_chunked_send_tokens,
-              int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks,
-              bool is_cached_dispatch, cudaStream_t stream, int num_channels,
-              bool low_latency_mode, uint64_t const* d2h_channel_addrs,
-              int num_d2h_channel_addrs, void* atomic_buffer_ptr) {
+void dispatch(
+    void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx,
+    float* recv_topk_weights, void* recv_src_meta, void const* x,
+    float const* x_scales, int64_t const* topk_idx, float const* topk_weights,
+    int* send_rdma_head, int* send_nvl_head,
+    int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
+    int const* rdma_channel_prefix_matrix, int const* recv_rdma_rank_prefix_sum,
+    int const* gbl_channel_prefix_matrix, int const* recv_gbl_rank_prefix_sum,
+    bool const* is_token_in_rank, int num_tokens, int num_worst_tokens,
+    int hidden_int4, int num_scales, int num_topk, int num_experts,
+    int scale_token_stride, int scale_hidden_stride, void* rdma_buffer_ptr,
+    int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
+    void** buffer_ptrs, int num_max_nvl_chunked_send_tokens,
+    int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks,
+    bool is_cached_dispatch, cudaStream_t stream, int num_channels,
+    bool low_latency_mode, uint64_t const* d2h_channel_addrs,
+    int num_d2h_channel_addrs, void* atomic_buffer_ptr) {
   constexpr int kNumDispatchRDMASenderWarps = 7;
   constexpr int kNumTMABytesPerWarp = 16384;
   constexpr int smem_size = kNumTMABytesPerWarp * NUM_MAX_NVL_PEERS;
@@ -1535,9 +1559,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx,
       send_rdma_head, send_nvl_head, recv_rdma_channel_prefix_matrix, \
       recv_gbl_channel_prefix_matrix, rdma_channel_prefix_matrix, \
       recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \
-      recv_gbl_rank_prefix_sum, is_token_in_rank, num_tokens, hidden_int4, \
-      num_scales, num_topk, num_experts, scale_token_stride, \
-      scale_hidden_stride, rdma_buffer_ptr, \
+      recv_gbl_rank_prefix_sum, is_token_in_rank, num_tokens, \
+      num_worst_tokens, hidden_int4, num_scales, num_topk, num_experts, \
+      scale_token_stride, scale_hidden_stride, rdma_buffer_ptr, \
       num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, \
       buffer_ptrs, num_max_nvl_chunked_send_tokens, \
       num_max_nvl_chunked_recv_tokens, rank, num_ranks, d2h_channel_addrs, \
diff --git a/ep/src/rdma.cpp b/ep/src/rdma.cpp
index 78a66c22..096dddc2 100644
--- a/ep/src/rdma.cpp
+++ b/ep/src/rdma.cpp
@@ -5,9 +5,6 @@
 #include "proxy_ctx.hpp"
 #include "rdma_util.hpp"
 #include "util/gpu_rt.h"
-#include "util/util.h"
-// net.h should be included after util.h
-#include "util/net.h"
 #include
 #include
 #include
diff --git a/ep/src/uccl_ep.cc b/ep/src/uccl_ep.cc
index 5ba63839..9ff983ea 100644
--- a/ep/src/uccl_ep.cc
+++ b/ep/src/uccl_ep.cc
@@ -416,7 +416,7 @@ class Buffer {
       std::optional<torch::Tensor> const& cached_recv_rdma_rank_prefix_sum,
      std::optional<torch::Tensor> const& cached_gbl_channel_prefix_matrix,
       std::optional<torch::Tensor> const& cached_recv_gbl_rank_prefix_sum,
-      int expert_alignment, uccl::Config const& config,
+      int expert_alignment, int num_worst_tokens, uccl::Config const& config,
       std::optional<EventHandle>& previous_event, bool async,
       bool allocate_on_comm_stream) {
     // In dispatch, CPU will busy-wait until GPU receive tensor size metadata
@@ -598,8 +598,8 @@ class Buffer {
         num_ranks, num_tokens_per_rdma_rank->data_ptr<int>(),
         moe_recv_rdma_counter_mapped, num_tokens_per_expert->data_ptr<int>(),
         moe_recv_expert_counter_mapped, num_experts,
-        is_token_in_rank.data_ptr<bool>(), num_tokens, num_channels,
-        hidden_int4, num_scales, num_topk, expert_alignment,
+        is_token_in_rank.data_ptr<bool>(), num_tokens, num_worst_tokens,
+        num_channels, hidden_int4, num_scales, num_topk, expert_alignment,
         rdma_channel_prefix_matrix.data_ptr<int>(),
         recv_rdma_rank_prefix_sum.data_ptr<int>(),
         gbl_channel_prefix_matrix.data_ptr<int>(),
@@ -613,35 +613,41 @@ class Buffer {
         atomic_buffer_ptr);
 
     // Synchronize total received tokens and tokens per expert
-    auto start_time = std::chrono::high_resolution_clock::now();
-    while (true) {
-      // Read total count
-      num_recv_tokens = static_cast<int>(*moe_recv_counter);
-      num_rdma_recv_tokens = static_cast<int>(*moe_recv_rdma_counter);
-
-      // Read per-expert count
-      bool ready = (num_recv_tokens >= 0) and (num_rdma_recv_tokens >= 0);
-      for (int i = 0; i < num_local_experts and ready; ++i)
-        ready &= moe_recv_expert_counter[i] >= 0;
-
-      if (ready) break;
-
-      // Timeout check
-      if (std::chrono::duration_cast<std::chrono::seconds>(
-              std::chrono::high_resolution_clock::now() - start_time)
-              .count() > NUM_CPU_TIMEOUT_SECS) {
-        printf(
-            "Global rank: %d, num_recv_tokens: %d, num_rdma_recv_tokens: "
-            "%d\n",
-            rank, num_recv_tokens, num_rdma_recv_tokens);
-        for (int i = 0; i < num_local_experts; ++i)
-          printf("moe_recv_expert_counter[%d]: %d\n", i,
-                 moe_recv_expert_counter[i]);
-        throw std::runtime_error("DeepEP error: timeout (dispatch CPU)");
+    if (num_worst_tokens > 0) {
+      num_recv_tokens = num_worst_tokens;
+      num_rdma_recv_tokens = num_worst_tokens;
+    } else {
+      auto start_time = std::chrono::high_resolution_clock::now();
+      while (true) {
+        // Read total count
+        num_recv_tokens = static_cast<int>(*moe_recv_counter);
+        num_rdma_recv_tokens = static_cast<int>(*moe_recv_rdma_counter);
+
+        // Read per-expert count
+        bool ready = (num_recv_tokens >= 0) and (num_rdma_recv_tokens >= 0);
+        for (int i = 0; i < num_local_experts and ready; ++i)
+          ready &= moe_recv_expert_counter[i] >= 0;
+
+        if (ready) break;
+
+        // Timeout check
+        if (std::chrono::duration_cast<std::chrono::seconds>(
+                std::chrono::high_resolution_clock::now() - start_time)
+                .count() > NUM_CPU_TIMEOUT_SECS) {
+          printf(
+              "Global rank: %d, num_recv_tokens: %d, num_rdma_recv_tokens: "
+              "%d\n",
+              rank, num_recv_tokens, num_rdma_recv_tokens);
+          for (int i = 0; i < num_local_experts; ++i)
+            printf("moe_recv_expert_counter[%d]: %d\n", i,
+                   moe_recv_expert_counter[i]);
+          throw std::runtime_error("DeepEP error: timeout (dispatch CPU)");
+        }
       }
+      num_recv_tokens_per_expert_list =
+          std::vector<int>(moe_recv_expert_counter,
+                           moe_recv_expert_counter + num_local_experts);
     }
-    num_recv_tokens_per_expert_list = std::vector<int>(
-        moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts);
   }
 
   // Allocate new tensors
@@ -705,9 +711,10 @@ class Buffer {
         recv_rdma_rank_prefix_sum.data_ptr<int>(),
         gbl_channel_prefix_matrix.data_ptr<int>(),
         recv_gbl_rank_prefix_sum.data_ptr<int>(),
-        is_token_in_rank.data_ptr<bool>(), num_tokens, hidden_int4, num_scales,
-        num_topk, num_experts, scale_token_stride, scale_hidden_stride,
-        rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens,
+        is_token_in_rank.data_ptr<bool>(), num_tokens, num_worst_tokens,
+        hidden_int4, num_scales, num_topk, num_experts, scale_token_stride,
+        scale_hidden_stride, rdma_buffer_ptr,
+        config.num_max_rdma_chunked_send_tokens,
         config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu,
         config.num_max_nvl_chunked_send_tokens,
         config.num_max_nvl_chunked_recv_tokens, rank, num_ranks, cached_mode,
@@ -2024,6 +2031,8 @@ PYBIND11_MODULE(ep, m) {
     uccl::g_proxies_by_dev.clear();
   });
 
+  m.def("get_oob_ip", &uccl::get_oob_ip, "Get the OOB IP address");
+
   m.def("get_rdma_buffer", [](int64_t num_rdma_bytes, int device_index) {
     void* ptr;
     CUDA_CHECK(cudaSetDevice(device_index));