diff --git a/ep/include/ep_config.hpp b/ep/include/ep_config.hpp
index b9fbdf51c..523fef3b0 100644
--- a/ep/include/ep_config.hpp
+++ b/ep/include/ep_config.hpp
@@ -195,7 +195,9 @@ struct LowLatencyLayout {
   // Send buffer
   size_t dispatch_send_buffer_bytes =
-      num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
+      (static_cast<size_t>(num_experts) + 1) *
+      static_cast<size_t>(num_max_dispatch_tokens_per_rank) *
+      num_bytes_per_dispatch_msg;
   size_t combine_send_buffer_bytes = num_experts *
       num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
diff --git a/ep/src/internode_ll.cu b/ep/src/internode_ll.cu
index 281f1f1fc..3e149c7be 100644
--- a/ep/src/internode_ll.cu
+++ b/ep/src/internode_ll.cu
@@ -18,6 +18,8 @@ constexpr int kNumMaxWarpGroups = 16;
 constexpr int kNumMaxWarpGroups = 32;
 #endif
 
+constexpr int kDispatchChunkSize = 8;
+
 template <int kNumThreads>
 __launch_bounds__(kNumThreads, 1) __global__ void clean_low_latency_buffer(
     int* clean_0, int num_clean_int_0,
@@ -53,9 +55,10 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
     int64_t* dispatch_wait_recv_cost_stats, void* rdma_recv_x,
     int* rdma_recv_count, void* rdma_x, void const* x, int64_t const* topk_idx,
     int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
-    int* next_clean, int* next_clean_second, int num_next_clean_int,
-    int num_tokens, int num_max_dispatch_tokens_per_rank, int num_topk,
-    int num_experts, int rank, int num_ranks, int num_warp_groups,
+    int* chunk_fill_counters, int* next_clean, int* next_clean_second,
+    int num_next_clean_int, int num_tokens,
+    int num_max_dispatch_tokens_per_rank, int num_chunks_per_expert,
+    int num_topk, int num_experts, int rank, int num_ranks, int num_warp_groups,
     int num_warps_per_group, bool round_scale, int phases,
     uint64_t const* d2h_channel_addrs, int num_d2h_channel_addrs,
     int max_nvl_peers, int low_latency_buffer_idx,
@@ -94,6 +97,12 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
   size_t const num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
   EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
 
+  auto const rdma_x_uint8 = static_cast<uint8_t*>(rdma_x);
+  auto const rdma_x_chunk_uint8 =
+      rdma_x_uint8 +
+      static_cast<size_t>(num_max_dispatch_tokens_per_rank) * num_bytes_per_msg;
+  auto const rdma_x_chunk_int4 = reinterpret_cast<int4*>(rdma_x_chunk_uint8);
+
   // Expert counts
   __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];
@@ -121,12 +130,13 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
     for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
       auto const x_int4 =
           static_cast<int4 const*>(x) + token_idx * hidden_bf16_int4;
-      auto const rdma_x_src_idx = reinterpret_cast<int*>(
-          static_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
-      auto const rdma_x_vec = reinterpret_cast<vec_t*>(
-          reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
+      auto const rdma_x_token_uint8 =
+          rdma_x_uint8 + token_idx * num_bytes_per_msg;
+      auto const rdma_x_src_idx = reinterpret_cast<int*>(rdma_x_token_uint8);
+      auto const rdma_x_vec =
+          reinterpret_cast<vec_t*>(rdma_x_token_uint8 + sizeof(int4));
       auto const rdma_x_scales = reinterpret_cast<float*>(
-          reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
+          rdma_x_token_uint8 + sizeof(int4) + hidden_bytes);
 
       // Overlap top-k index read and source token index writes
       auto dst_expert_idx =
@@ -188,7 +198,7 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
       }
       sync_barrier_1(num_threads);
 
-      // Issue IBGDA sends
+      // Issue IBGDA sends in chunks
      if (dst_expert_idx >= 0) {
        int slot_idx =
            lane_id == 0
@@ -197,40 +207,85 @@
         slot_idx = __shfl_sync(WARP_MASK, slot_idx, 0);
         auto const dst_rank = dst_expert_idx / num_local_experts;
         auto const dst_expert_local_idx = dst_expert_idx % num_local_experts;
-        auto const src_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx);
-        auto const dst_ptr =
-            reinterpret_cast<uint64_t>(rdma_recv_x) +
-            dst_expert_local_idx * num_ranks *
-                num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
-            rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
-            slot_idx * num_bytes_per_msg;
-        auto const dst_p2p_ptr =
-            ipc_rdma_base_ptrs
-                ? uccl::get_ipc_p2p_ptr(dst_ptr, ipc_rdma_base_ptrs, rank,
-                                        dst_rank, max_nvl_peers, 0)
-                : 0;
-        if (dst_p2p_ptr == 0) {
-          __threadfence_system();
-          uccl::nvshmemi_ibgda_put_nbi_warp(
-              dst_ptr - reinterpret_cast<uint64_t>(rdma_buffer_ptr),
-              src_ptr - reinterpret_cast<uint64_t>(rdma_buffer_ptr),
-              num_bytes_per_msg, dst_rank,
-              /*warp_id=*/dst_expert_local_idx,  // NOTE(Yang): for selecting
-                                                 // rb.
-              lane_id, slot_idx, d2h_channel_addrs, num_d2h_channel_addrs,
-              false, low_latency_buffer_idx);
-        } else {
-          // Intra-node: use direct memory copy via IPC
-          auto const* src_int4_ptr = reinterpret_cast<int4 const*>(src_ptr);
-          auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_p2p_ptr);
-          UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr,
-                             src_int4_ptr, ld_nc_global, st_na_global);
-        }
-        // Increase counter after finishing
+        auto const token_msg_int4 =
+            reinterpret_cast<int4 const*>(rdma_x_src_idx);
+        auto const chunk_slot_linear =
+            static_cast<size_t>(dst_expert_idx) *
+                static_cast<size_t>(num_max_dispatch_tokens_per_rank) +
+            static_cast<size_t>(slot_idx);
+        auto* chunk_msg_int4 =
+            rdma_x_chunk_int4 + chunk_slot_linear * num_int4_per_msg;
+        UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, chunk_msg_int4,
+                           token_msg_int4, ld_nc_global, st_na_global);
         __syncwarp();
-        lane_id == 0 ? atomic_add_release_global(
-                           atomic_finish_counter_per_expert + dst_expert_idx, 1)
-                     : 0;
+        __threadfence_system();
+
+        int chunk_id = slot_idx / kDispatchChunkSize;
+        int chunk_index = dst_expert_idx * num_chunks_per_expert + chunk_id;
+        int prev_fill =
+            lane_id == 0 ? atomicAdd(chunk_fill_counters + chunk_index, 1) : 0;
+        prev_fill = __shfl_sync(WARP_MASK, prev_fill, 0);
+        bool chunk_ready = (prev_fill + 1) == kDispatchChunkSize;
+
+        if (chunk_ready) {
+          int const chunk_base_slot = chunk_id * kDispatchChunkSize;
+          size_t const chunk_bytes =
+              static_cast<size_t>(kDispatchChunkSize) * num_bytes_per_msg;
+          auto const chunk_src_ptr =
+              reinterpret_cast<uint64_t>(rdma_x_chunk_uint8) +
+              (static_cast<size_t>(dst_expert_idx) *
+                   static_cast<size_t>(num_max_dispatch_tokens_per_rank) +
+               static_cast<size_t>(chunk_base_slot)) *
+                  num_bytes_per_msg;
+          auto const chunk_dst_ptr =
+              reinterpret_cast<uint64_t>(rdma_recv_x) +
+              static_cast<size_t>(dst_expert_local_idx) * num_ranks *
+                  static_cast<size_t>(num_max_dispatch_tokens_per_rank) *
+                  num_bytes_per_msg +
+              static_cast<size_t>(rank) *
+                  static_cast<size_t>(num_max_dispatch_tokens_per_rank) *
+                  num_bytes_per_msg +
+              static_cast<size_t>(chunk_base_slot) * num_bytes_per_msg;
+
+          uint64_t chunk_dst_p2p_ptr = 0;
+          if (ipc_rdma_base_ptrs && lane_id == 0)
+            chunk_dst_p2p_ptr =
+                uccl::get_ipc_p2p_ptr(chunk_dst_ptr, ipc_rdma_base_ptrs, rank,
+                                      dst_rank, max_nvl_peers, 0);
+          auto chunk_dst_p2p_lo = static_cast<uint32_t>(chunk_dst_p2p_ptr);
+          auto chunk_dst_p2p_hi =
+              static_cast<uint32_t>(chunk_dst_p2p_ptr >> 32);
+          chunk_dst_p2p_lo = __shfl_sync(WARP_MASK, chunk_dst_p2p_lo, 0);
+          chunk_dst_p2p_hi = __shfl_sync(WARP_MASK, chunk_dst_p2p_hi, 0);
+          chunk_dst_p2p_ptr = (static_cast<uint64_t>(chunk_dst_p2p_hi) << 32) |
+                              chunk_dst_p2p_lo;
+
+          if (chunk_dst_p2p_ptr == 0) {
+            __threadfence_system();
+            uccl::nvshmemi_ibgda_put_nbi_warp(
+                chunk_dst_ptr - reinterpret_cast<uint64_t>(rdma_buffer_ptr),
+                chunk_src_ptr - reinterpret_cast<uint64_t>(rdma_buffer_ptr),
+                chunk_bytes, dst_rank,
+                /*warp_id=*/dst_expert_local_idx, lane_id, chunk_base_slot,
+                d2h_channel_addrs, num_d2h_channel_addrs, false,
+                low_latency_buffer_idx);
+          } else {
+            auto const* chunk_src_int4 =
+                reinterpret_cast<int4 const*>(chunk_src_ptr);
+            auto* chunk_dst_int4 = reinterpret_cast<int4*>(chunk_dst_p2p_ptr);
+            UNROLLED_WARP_COPY(
+                8, lane_id, num_int4_per_msg * kDispatchChunkSize,
+                chunk_dst_int4, chunk_src_int4, ld_nc_global, st_na_global);
+          }
+
+          __syncwarp();
+          if (lane_id == 0) {
+            st_release_sys_global(chunk_fill_counters + chunk_index, 0);
+            atomic_add_release_global(
+                atomic_finish_counter_per_expert + dst_expert_idx,
+                kDispatchChunkSize);
+          }
+        }
       }
     }
   } else if (warp_id == num_warps - 1) {
@@ -269,11 +324,87 @@ __global__ __launch_bounds__(1024, 1) void dispatch(
 #pragma unroll
     for (int i = expert_begin_idx; i < expert_end_idx; ++i) {
       auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
+      sum = __shfl_sync(WARP_MASK, sum, 0);
       if (lane_id == 0) {
         shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
         atomic_add_release_global(atomic_finish_counter_per_expert + i,
                                   FINISHED_SUM_TAG - sum);
       }
+
+      if (sum > 0) {
+        int remainder = sum % kDispatchChunkSize;
+        if (remainder != 0) {
+          int const chunk_id = sum / kDispatchChunkSize;
+          int const chunk_index = i * num_chunks_per_expert + chunk_id;
+          if (lane_id == 0) {
+            while (ld_acquire_global(atomic_counter_per_expert + i) < sum)
+              ;
+            while (ld_acquire_global(chunk_fill_counters + chunk_index) <
+                   remainder)
+              ;
+          }
+
+          auto const chunk_base_slot = chunk_id * kDispatchChunkSize;
+          size_t const chunk_bytes =
+              static_cast<size_t>(remainder) * num_bytes_per_msg;
+          auto const chunk_src_ptr =
+              reinterpret_cast<uint64_t>(rdma_x_chunk_uint8) +
+              (static_cast<size_t>(i) *
+                   static_cast<size_t>(num_max_dispatch_tokens_per_rank) +
+               static_cast<size_t>(chunk_base_slot)) *
+                  num_bytes_per_msg;
+          auto const dst_rank = i / num_local_experts;
+          auto const dst_expert_local_idx = i % num_local_experts;
+          auto const chunk_dst_ptr =
+              reinterpret_cast<uint64_t>(rdma_recv_x) +
+              static_cast<size_t>(dst_expert_local_idx) *
+                  static_cast<size_t>(num_ranks) *
+                  static_cast<size_t>(num_max_dispatch_tokens_per_rank) *
+                  num_bytes_per_msg +
+              static_cast<size_t>(rank) *
+                  static_cast<size_t>(num_max_dispatch_tokens_per_rank) *
+                  num_bytes_per_msg +
+              static_cast<size_t>(chunk_base_slot) * num_bytes_per_msg;
+
+          uint64_t chunk_dst_p2p_ptr = 0;
+          if (ipc_rdma_base_ptrs && lane_id == 0)
+            chunk_dst_p2p_ptr =
+                uccl::get_ipc_p2p_ptr(chunk_dst_ptr, ipc_rdma_base_ptrs, rank,
+                                      dst_rank, max_nvl_peers, 0);
+          auto chunk_dst_p2p_lo = static_cast<uint32_t>(chunk_dst_p2p_ptr);
+          auto chunk_dst_p2p_hi =
+              static_cast<uint32_t>(chunk_dst_p2p_ptr >> 32);
+          chunk_dst_p2p_lo = __shfl_sync(WARP_MASK, chunk_dst_p2p_lo, 0);
+          chunk_dst_p2p_hi = __shfl_sync(WARP_MASK, chunk_dst_p2p_hi, 0);
+          chunk_dst_p2p_ptr = (static_cast<uint64_t>(chunk_dst_p2p_hi) << 32) |
+                              chunk_dst_p2p_lo;
+
+          if (chunk_dst_p2p_ptr == 0) {
+            __threadfence_system();
+            uccl::nvshmemi_ibgda_put_nbi_warp(
+                chunk_dst_ptr - reinterpret_cast<uint64_t>(rdma_buffer_ptr),
+                chunk_src_ptr - reinterpret_cast<uint64_t>(rdma_buffer_ptr),
+                chunk_bytes, dst_rank,
+                /*warp_id=*/dst_expert_local_idx, lane_id, chunk_base_slot,
+                d2h_channel_addrs, num_d2h_channel_addrs, false,
+                low_latency_buffer_idx);
+          } else {
+            auto const* chunk_src_int4 =
+                reinterpret_cast<int4 const*>(chunk_src_ptr);
+            auto* chunk_dst_int4 = reinterpret_cast<int4*>(chunk_dst_p2p_ptr);
+            UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg * remainder,
+                               chunk_dst_int4, chunk_src_int4, ld_nc_global,
+                               st_na_global);
+          }
+
+          __syncwarp();
+          if (lane_id == 0) {
+            st_release_sys_global(chunk_fill_counters + chunk_index, 0);
+            atomic_add_release_global(atomic_finish_counter_per_expert + i,
+                                      remainder);
+          }
+        }
+      }
     }
   }
   __syncthreads();
@@ -521,31 +652,39 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
   auto atomic_counter_per_expert = static_cast<int*>(workspace);
   auto atomic_finish_counter_per_expert =
       atomic_counter_per_expert + num_experts;
-  EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
+  int const num_chunks_per_expert =
+      ceil_div(num_max_dispatch_tokens_per_rank, kDispatchChunkSize);
+  auto chunk_fill_counters = atomic_finish_counter_per_expert + num_experts;
+  auto const required_workspace_ints =
+      static_cast<size_t>(num_experts) *
+      static_cast<size_t>(2 + num_chunks_per_expert);
+  EP_HOST_ASSERT(required_workspace_ints * sizeof(int) <=
+                 static_cast<size_t>(NUM_WORKSPACE_BYTES));
 
   // FP8 checks
   if (use_ue8m0)
    EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`");
 
-#define DISPATCH_LAUNCH_CASE(hidden) \
   { \
     auto dispatch_func = dispatch
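For intuition, below is a minimal single-GPU sketch of the chunk-completion protocol this patch builds around `chunk_fill_counters`: each token claims a staging slot through a per-expert atomic counter, each arrival bumps the fill counter of its (expert, chunk) pair, the arrival that completes a chunk of `kDispatchChunkSize` tokens forwards the whole chunk, and a finalizer that knows the per-expert totals flushes the partial tail chunk. Everything here is illustrative: the kernel names (`toy_dispatch`, `toy_flush_tail`), the per-thread rather than per-warp token handling, and the plain copy standing in for `uccl::nvshmemi_ibgda_put_nbi_warp` are assumptions made for the sketch, not the patch's code.

```cuda
// toy_chunked_dispatch.cu -- illustrative sketch only, not the patch itself.
#include <cuda_runtime.h>
#include <cstdio>

constexpr int kDispatchChunkSize = 8;
constexpr int kNumExperts = 4;
constexpr int kMaxTokensPerExpert = 64;  // toy stand-in for num_max_dispatch_tokens_per_rank
constexpr int kNumChunksPerExpert =
    (kMaxTokensPerExpert + kDispatchChunkSize - 1) / kDispatchChunkSize;

__global__ void toy_dispatch(int const* token_expert, int num_tokens,
                             int* staging, int* recv, int* slot_counters,
                             int* chunk_fill) {
  int const token = blockIdx.x * blockDim.x + threadIdx.x;
  if (token >= num_tokens) return;
  int const expert = token_expert[token];

  // Claim a staging slot (the patch's atomic_counter_per_expert).
  int const slot = atomicAdd(&slot_counters[expert], 1);
  staging[expert * kMaxTokensPerExpert + slot] = token + 1;  // non-zero payload
  __threadfence();  // release: make the payload visible before the counter bump

  // The arrival that completes a chunk of kDispatchChunkSize forwards it.
  int const chunk = slot / kDispatchChunkSize;
  int const prev =
      atomicAdd(&chunk_fill[expert * kNumChunksPerExpert + chunk], 1);
  if (prev + 1 == kDispatchChunkSize) {
    __threadfence();  // pair with writers' fences before reading peers' slots
    int const base = expert * kMaxTokensPerExpert + chunk * kDispatchChunkSize;
    for (int i = 0; i < kDispatchChunkSize; ++i)  // stands in for one RDMA put
      recv[base + i] = staging[base + i];
  }
}

// Flush each expert's partial tail chunk (the patch's last warp does this once
// it knows the per-expert token totals).
__global__ void toy_flush_tail(int const* staging, int* recv,
                               int const* slot_counters) {
  int const expert = threadIdx.x;
  if (expert >= kNumExperts) return;
  int const sent = slot_counters[expert];
  int const base = expert * kMaxTokensPerExpert +
                   (sent / kDispatchChunkSize) * kDispatchChunkSize;
  for (int i = 0; i < sent % kDispatchChunkSize; ++i)
    recv[base + i] = staging[base + i];
}

int main() {
  constexpr int kNumTokens = 50;
  int h_token_expert[kNumTokens];
  for (int i = 0; i < kNumTokens; ++i) h_token_expert[i] = i % kNumExperts;

  int *d_token_expert, *d_staging, *d_recv, *d_slots, *d_fill;
  size_t const buf_bytes = kNumExperts * kMaxTokensPerExpert * sizeof(int);
  cudaMalloc(&d_token_expert, sizeof(h_token_expert));
  cudaMalloc(&d_staging, buf_bytes);
  cudaMalloc(&d_recv, buf_bytes);
  cudaMalloc(&d_slots, kNumExperts * sizeof(int));
  cudaMalloc(&d_fill, kNumExperts * kNumChunksPerExpert * sizeof(int));
  cudaMemcpy(d_token_expert, h_token_expert, sizeof(h_token_expert),
             cudaMemcpyHostToDevice);
  cudaMemset(d_recv, 0, buf_bytes);
  cudaMemset(d_slots, 0, kNumExperts * sizeof(int));
  cudaMemset(d_fill, 0, kNumExperts * kNumChunksPerExpert * sizeof(int));

  toy_dispatch<<<(kNumTokens + 255) / 256, 256>>>(
      d_token_expert, kNumTokens, d_staging, d_recv, d_slots, d_fill);
  toy_flush_tail<<<1, kNumExperts>>>(d_staging, d_recv, d_slots);

  int h_recv[kNumExperts * kMaxTokensPerExpert];
  cudaMemcpy(h_recv, d_recv, buf_bytes, cudaMemcpyDeviceToHost);
  for (int e = 0; e < kNumExperts; ++e) {
    int n = 0;
    while (n < kMaxTokensPerExpert && h_recv[e * kMaxTokensPerExpert + n]) ++n;
    printf("expert %d received %d tokens\n", e, n);  // expect 13/13/12/12
  }
  return 0;
}
```

The invariant mirrored from the patch: every staged token increments exactly one fill counter, so exactly one arrival observes `prev + 1 == kDispatchChunkSize` and owns the send for that chunk, while chunks that can never fill are swept by the finalizer once the per-expert totals are known. Build with e.g. `nvcc -arch=native toy_chunked_dispatch.cu -o toy`.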