diff --git a/.gitignore b/.gitignore index c24ecf89..4a8c877d 100644 --- a/.gitignore +++ b/.gitignore @@ -100,4 +100,6 @@ ep/figs ep/deep_ep_wrapper/deep_ep.egg-info/ *.json -*result.jsonl \ No newline at end of file +*result.jsonl + +ep/deep_ep_wrapper/sglang_profiles* \ No newline at end of file diff --git a/ep/bench/buffer.py b/ep/bench/buffer.py index bf12b2d5..f5942151 100644 --- a/ep/bench/buffer.py +++ b/ep/bench/buffer.py @@ -123,6 +123,7 @@ def __init__( low_latency_mode, explicitly_destroy, int(os.environ.get("LOCAL_WORLD_SIZE", -1)), + Buffer.disable_ll_layered(), ) if num_rdma_bytes: self.runtime.set_rdma_buffer_raw(rdma_buffer_ptr) @@ -173,6 +174,13 @@ def reset_rdma_buffer(self): def connect_atomic_buffer(self, proxy: "ep.UcclProxy"): ep.connect_atomic_buffer(proxy, self.runtime) + @staticmethod + def disable_ll_layered() -> bool: + disable_ll_layered = False + if int(os.environ.get("DEEPEP_DISABLE_LL_DISPATCH_OPT", "0")) == 1: + disable_ll_layered = True + return disable_ll_layered + def destroy(self): """ Destroy the cpp runtime and release resources. @@ -453,7 +461,11 @@ def get_low_latency_rdma_size_hint( size: the RDMA buffer size recommended. """ return ep.get_low_latency_rdma_size_hint( - num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts + Buffer.disable_ll_layered(), + num_max_dispatch_tokens_per_rank, + hidden, + num_ranks, + num_experts, ) def get_comm_stream(self) -> torch.Stream: diff --git a/ep/include/ep_config.hpp b/ep/include/ep_config.hpp index b9fbdf51..18043ad4 100644 --- a/ep/include/ep_config.hpp +++ b/ep/include/ep_config.hpp @@ -174,10 +174,12 @@ struct LowLatencyLayout { count); } - LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, - int hidden, int num_ranks, int num_experts, + LowLatencyLayout(bool disable_ll_layered, void* rdma_buffer, + int num_max_dispatch_tokens_per_rank, int hidden, + int num_ranks, int num_experts, void* atomic_buffer_ptr = nullptr) { int const num_scales = hidden / 128; + int const num_nodes = num_ranks / NUM_MAX_NVL_PEERS; // Dispatch and combine layout: // - 2 symmetric odd/even send buffer @@ -188,9 +190,18 @@ struct LowLatencyLayout { // NOTES: you should add a control `int4` for combine messages if you want // to do data transformation EP_HOST_ASSERT(num_scales * sizeof(float) <= static_cast(hidden)); + size_t per_meta_data_size = sizeof(int4); + size_t per_token_size_unaligned = std::max(hidden * sizeof(nv_bfloat16), + hidden + num_scales * sizeof(float)) + + sizeof(int); // Flag at end of data + // Align to sizeof(int4) for efficient vectorized copies + size_t per_token_size = align(per_token_size_unaligned, sizeof(int4)); size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float)); + if (!disable_ll_layered) { + num_bytes_per_dispatch_msg = per_meta_data_size + per_token_size; + } size_t num_bytes_per_combine_msg = hidden * sizeof(nv_bfloat16); // Send buffer @@ -209,6 +220,12 @@ struct LowLatencyLayout { size_t dispatch_recv_data_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg; + if (!disable_ll_layered) { + dispatch_recv_data_buffer_bytes = + num_experts * num_max_dispatch_tokens_per_rank * per_meta_data_size + + num_nodes * num_max_dispatch_tokens_per_rank * per_token_size; + } + size_t combine_recv_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg; @@ -219,6 +236,12 @@ struct LowLatencyLayout { // Symmetric signaling buffers 
size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int); + if (!disable_ll_layered) { + dispatch_recv_count_buffer_bytes += NUM_MAX_NVL_PEERS * num_nodes * + num_max_dispatch_tokens_per_rank * + sizeof(int) + + NUM_MAX_NVL_PEERS * sizeof(int); + } size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes; size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes); @@ -261,11 +284,13 @@ struct LowLatencyLayout { } }; -size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, +size_t get_low_latency_rdma_size_hint(bool dispatch_ll_dispatch_opt, + int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) { - auto num_bytes = LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, - hidden, num_ranks, num_experts, nullptr) + auto num_bytes = LowLatencyLayout(dispatch_ll_dispatch_opt, nullptr, + num_max_dispatch_tokens_per_rank, hidden, + num_ranks, num_experts, nullptr) .total_bytes; return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * diff --git a/ep/include/internode_ll.cuh b/ep/include/internode_ll.cuh index 602e36d9..f02fe938 100644 --- a/ep/include/internode_ll.cuh +++ b/ep/include/internode_ll.cuh @@ -11,9 +11,10 @@ namespace internode_ll { void clean_low_latency_buffer(int* clean_0, int num_clean_int_0, int* clean_1, int num_clean_int_1, cudaStream_t stream); // Dummy host launcher declaration -void dispatch(void* packed_recv_x, void* packed_recv_x_scales, - int* packed_recv_src_info, int64_t* packed_recv_layout_range, - int* packed_recv_count, int* cumulative_local_expert_recv_stats, +void dispatch(bool dispatch_ll_dispatch_opt, void* packed_recv_x, + void* packed_recv_x_scales, int* packed_recv_src_info, + int64_t* packed_recv_layout_range, int* packed_recv_count, + int* cumulative_local_expert_recv_stats, int64_t* dispatch_wait_recv_cost_stats, void* rdma_recv_x, int* rdma_recv_count, void* rdma_x, void const* x, int64_t const* topk_idx, int* next_clean, int* next_clean_second, @@ -29,11 +30,12 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, void* atomic_buffer_ptr = nullptr, int* rdma_recv_count_internode = nullptr); -void combine(void* combined_x, void* rdma_recv_x, int* rdma_recv_flag, - void* rdma_send_x, void const* x, int64_t const* topk_idx, - float const* topk_weights, int const* src_info, - int64_t const* layout_range, int64_t* combine_wait_recv_cost_stats, - int* next_clean, int* next_clean_second, int num_next_clean_int, +void combine(bool dispatch_ll_dispatch_opt, void* combined_x, void* rdma_recv_x, + int* rdma_recv_flag, void* rdma_send_x, void const* x, + int64_t const* topk_idx, float const* topk_weights, + int const* src_info, int64_t const* layout_range, + int64_t* combine_wait_recv_cost_stats, int* next_clean, + int* next_clean_second, int num_next_clean_int, int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank, int num_ranks, bool use_logfmt, diff --git a/ep/src/internode_ll.cu b/ep/src/internode_ll.cu index 5e2becc3..23a33a0c 100644 --- a/ep/src/internode_ll.cu +++ b/ep/src/internode_ll.cu @@ -5,9 +5,11 @@ #include "ep_utils.cuh" #include "internode_ll.cuh" #include "uccl_ibgda.cuh" +#include #include #include #include + namespace cg = cooperative_groups; namespace uccl { namespace internode_ll { @@ -47,9 +49,9 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0, int* clean_1, template __global__ __launch_bounds__(1024, 1) 
void dispatch( - void* packed_recv_x, void* packed_recv_x_scales, int* packed_recv_src_info, - int64_t* packed_recv_layout_range, int* packed_recv_count, - int* cumulative_local_expert_recv_stats, + bool disable_ll_layered, void* packed_recv_x, void* packed_recv_x_scales, + int* packed_recv_src_info, int64_t* packed_recv_layout_range, + int* packed_recv_count, int* cumulative_local_expert_recv_stats, int64_t* dispatch_wait_recv_cost_stats, void* rdma_recv_x, int* rdma_recv_count, void* rdma_x, void const* x, int64_t const* topk_idx, int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert, @@ -72,6 +74,24 @@ __global__ __launch_bounds__(1024, 1) void dispatch( auto const sub_warp_id = warp_id % num_warps_per_group; auto const responsible_expert_idx = sm_id * num_warp_groups + warp_group_id; + auto const num_nvl_ranks = NUM_MAX_NVL_PEERS; + auto const num_nodes = num_ranks / num_nvl_ranks; + int* data_ready_counter = + reinterpret_cast(rdma_recv_count + num_experts); + int* next_clean_data_ready_counter = + reinterpret_cast(next_clean + num_experts); + auto* data_ready_send_buffer = + reinterpret_cast(data_ready_counter) + + num_nodes * num_max_dispatch_tokens_per_rank * num_nvl_ranks; + if (!disable_ll_layered) { + if (thread_id < num_nvl_ranks) { + st_na_global(reinterpret_cast(data_ready_send_buffer) + thread_id, + 2); // set to 2 + } + __syncthreads(); + EP_DEVICE_ASSERT(num_ranks % num_nvl_ranks == 0); + } + // May extract UE8M0 from the scales using scale_t = std::conditional_t; using packed_t = std::conditional_t; @@ -94,6 +114,24 @@ __global__ __launch_bounds__(1024, 1) void dispatch( size_t const num_int4_per_msg = num_bytes_per_msg / sizeof(int4); EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0); + size_t const num_bytes_per_meta = sizeof(int4); + size_t const num_bytes_per_data_unaligned = + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) + : (kHidden * sizeof(nv_bfloat16))) + + sizeof(int); // Flag at end of data + // Align to sizeof(int4) + size_t const num_bytes_per_data = + ((num_bytes_per_data_unaligned + sizeof(int4) - 1) / sizeof(int4)) * + sizeof(int4); + size_t const num_bytes_per_msg_new = num_bytes_per_meta + num_bytes_per_data; + EP_DEVICE_ASSERT(num_bytes_per_msg_new % sizeof(int4) == 0); + + void* rdma_recv_x_meta = rdma_recv_x; + void* rdma_recv_x_data = + (void*)(uint64_t(rdma_recv_x) + num_experts * + num_max_dispatch_tokens_per_rank * + num_bytes_per_meta); + // Expert counts __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups]; @@ -103,7 +141,16 @@ __global__ __launch_bounds__(1024, 1) void dispatch( #endif // Sending phase + if (blockIdx.x == 0 && threadIdx.x == 0) { + printf( + "[internode_ll] Start of kernel: rank=%d, phases=0x%x, " + "has_send_phase=%d\n", + rank, phases, (phases & LOW_LATENCY_SEND_PHASE) != 0); + } if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_DISPATCH_RECV; + if (blockIdx.x == 0 && threadIdx.x == 0) { + printf("[internode_ll] Entering SEND phase: rank=%d\n", rank); + } // There are 2 kinds of warps in this part: // 1. 
The first-kind warps for FP8 cast and sending top-k tokens @@ -121,8 +168,13 @@ __global__ __launch_bounds__(1024, 1) void dispatch( for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) { auto const x_int4 = static_cast(x) + token_idx * hidden_bf16_int4; - auto const rdma_x_src_idx = reinterpret_cast( + auto rdma_x_src_idx = reinterpret_cast( static_cast(rdma_x) + token_idx * num_bytes_per_msg); + + if (!disable_ll_layered) { + rdma_x_src_idx = reinterpret_cast( + static_cast(rdma_x) + token_idx * num_bytes_per_msg_new); + } auto const rdma_x_vec = reinterpret_cast( reinterpret_cast(rdma_x_src_idx) + sizeof(int4)); auto const rdma_x_scales = reinterpret_cast( @@ -190,47 +242,182 @@ __global__ __launch_bounds__(1024, 1) void dispatch( // Issue IBGDA sends if (dst_expert_idx >= 0) { + int send_node_id = + dst_expert_idx >= 0 + ? dst_expert_idx / num_local_experts / num_nvl_ranks + : -1; int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0; slot_idx = __shfl_sync(WARP_MASK, slot_idx, 0); + + /* dst_rank is the logical rank that owns the expert. The final + * destination where the expert lives. */ auto const dst_rank = dst_expert_idx / num_local_experts; auto const dst_expert_local_idx = dst_expert_idx % num_local_experts; - auto const src_ptr = reinterpret_cast(rdma_x_src_idx); - auto const dst_ptr = - reinterpret_cast(rdma_recv_x) + - dst_expert_local_idx * num_ranks * - num_max_dispatch_tokens_per_rank * num_bytes_per_msg + - rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + - slot_idx * num_bytes_per_msg; - auto const dst_p2p_ptr = - ipc_rdma_base_ptrs - ? uccl::get_ipc_p2p_ptr(dst_ptr, ipc_rdma_base_ptrs, rank, - dst_rank, max_nvl_peers, 0) - : 0; - if (dst_p2p_ptr == 0) { - __threadfence_system(); - uccl::nvshmemi_ibgda_put_nbi_warp( - dst_ptr - reinterpret_cast(rdma_buffer_ptr), - src_ptr - reinterpret_cast(rdma_buffer_ptr), - num_bytes_per_msg, dst_rank, - /*warp_id=*/dst_expert_local_idx, // NOTE(Yang): for selecting - // rb. - lane_id, slot_idx, d2h_channel_addrs, num_d2h_channel_addrs, - false, low_latency_buffer_idx); - } else { - // Intra-node: use direct memory copy via IPC - auto const* src_int4_ptr = reinterpret_cast(src_ptr); - auto* dst_int4_ptr = reinterpret_cast(dst_p2p_ptr); - UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, - src_int4_ptr, ld_nc_global, st_na_global); + /* real_write_dst_rank has the same NVLink rail position as the sender. + */ + auto real_write_dst_rank = + (dst_rank / num_nvl_ranks) * num_nvl_ranks + (rank % num_nvl_ranks); + auto real_dst_expert_id = + real_write_dst_rank * num_local_experts + dst_expert_local_idx; + + if (!disable_ll_layered) { + { + EP_DEVICE_ASSERT(num_topk <= 32); + int tmp_dst_expert_id = + (lane_id < num_topk) + ? static_cast( + __ldg(topk_idx + token_idx * num_topk + lane_id)) + : -1; + int tmp_dst_node_id = + (tmp_dst_expert_id >= 0) + ? 
(tmp_dst_expert_id / num_local_experts / num_nvl_ranks) + : -1; + + /* This check if the token is already sent to the same node */ +#pragma unroll + for (int i = 0; i < warp_id; ++i) { + int dst_node_id = __shfl_sync(0xffffffff, tmp_dst_node_id, i); + if (dst_node_id == send_node_id) { + send_node_id = -1; + break; + } + } + } + + if (send_node_id != -1) { + /* Send the token to the destination node (data + flag) */ + auto const src_ptr = + reinterpret_cast(rdma_x_src_idx) + num_bytes_per_meta; + auto const dst_ptr = reinterpret_cast(rdma_recv_x_data) + + (rank / num_nvl_ranks) * + num_max_dispatch_tokens_per_rank * + num_bytes_per_data + + token_idx * num_bytes_per_data; + + // Write flag value (2) at end of source data buffer with release + // semantics + // Flag is at the end of actual data (before padding) + size_t num_bytes_data_only = + num_bytes_per_data_unaligned - sizeof(int); + if (lane_id == 0) { + auto* flag_ptr = reinterpret_cast( + reinterpret_cast(src_ptr) + num_bytes_data_only); + st_release_sys_global(flag_ptr, 2); + } + __syncwarp(); + + auto const dst_p2p_ptr = + ipc_rdma_base_ptrs ? uccl::get_ipc_p2p_ptr( + dst_ptr, ipc_rdma_base_ptrs, rank, + real_write_dst_rank, max_nvl_peers, 0) + : 0; + + if (dst_p2p_ptr == 0) { + __threadfence_system(); + uccl::nvshmemi_ibgda_put_nbi_warp( + /*dst_off=*/dst_ptr - + reinterpret_cast(rdma_buffer_ptr), + /*src_off=*/src_ptr - + reinterpret_cast(rdma_buffer_ptr), + /*bytes=*/num_bytes_per_data, // Includes flag + /*dst_rank=*/real_write_dst_rank, + /*warp_id=*/dst_expert_local_idx, lane_id, + /*slot_idx=*/slot_idx, d2h_channel_addrs, + num_d2h_channel_addrs, + /*is_atomic=*/false, low_latency_buffer_idx); + } else { + auto const* src_int4_ptr = reinterpret_cast(src_ptr); + auto* dst_int4_ptr = reinterpret_cast(dst_p2p_ptr); + UNROLLED_WARP_COPY(7, lane_id, num_bytes_per_data / sizeof(int4), + dst_int4_ptr, src_int4_ptr, ld_nc_global, + st_na_global); + } + } + + { // send meta + auto const src_ptr = reinterpret_cast(rdma_x_src_idx); + + auto const dst_ptr = + reinterpret_cast(rdma_recv_x_meta) + + dst_expert_local_idx * num_ranks * + num_max_dispatch_tokens_per_rank * num_bytes_per_meta + + rank * num_max_dispatch_tokens_per_rank * num_bytes_per_meta + + slot_idx * num_bytes_per_meta; + + auto const dst_p2p_ptr = + ipc_rdma_base_ptrs + ? uccl::get_ipc_p2p_ptr(dst_ptr, ipc_rdma_base_ptrs, rank, + dst_rank, max_nvl_peers, 0) + : 0; + + if (dst_p2p_ptr == 0) { + __threadfence_system(); + uccl::nvshmemi_ibgda_put_nbi_warp( + /*dst_off=*/dst_ptr - + reinterpret_cast(rdma_buffer_ptr), + /*src_off=*/src_ptr - + reinterpret_cast(rdma_buffer_ptr), + /*bytes=*/num_bytes_per_meta, + /*dst_rank=*/dst_rank, + /*warp_id=*/dst_expert_local_idx, lane_id, slot_idx, + d2h_channel_addrs, num_d2h_channel_addrs, + /*is_atomic=*/false, low_latency_buffer_idx); + } else { + auto const* src_int4_ptr = reinterpret_cast(src_ptr); + auto* dst_int4_ptr = reinterpret_cast(dst_p2p_ptr); + UNROLLED_WARP_COPY(1, lane_id, num_bytes_per_meta / sizeof(int4), + dst_int4_ptr, src_int4_ptr, ld_nc_global, + st_na_global); + } + } } + + if (disable_ll_layered) { + auto const src_ptr = reinterpret_cast(rdma_x_src_idx); + auto const dst_ptr = + reinterpret_cast(rdma_recv_x) + + dst_expert_local_idx * num_ranks * + num_max_dispatch_tokens_per_rank * num_bytes_per_msg + + rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + + slot_idx * num_bytes_per_msg; + + auto const dst_p2p_ptr = + ipc_rdma_base_ptrs + ? 
uccl::get_ipc_p2p_ptr(dst_ptr, ipc_rdma_base_ptrs, rank, + dst_rank, max_nvl_peers, 0) + : 0; + + if (dst_p2p_ptr == 0) { + __threadfence_system(); + uccl::nvshmemi_ibgda_put_nbi_warp( + dst_ptr - reinterpret_cast(rdma_buffer_ptr), + src_ptr - reinterpret_cast(rdma_buffer_ptr), + num_bytes_per_msg, dst_rank, + /*warp_id=*/dst_expert_local_idx, lane_id, slot_idx, + d2h_channel_addrs, num_d2h_channel_addrs, false, + low_latency_buffer_idx); + } else { + auto const* src_int4_ptr = reinterpret_cast(src_ptr); + auto* dst_int4_ptr = reinterpret_cast(dst_p2p_ptr); + UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, + src_int4_ptr, ld_nc_global, st_na_global); + } + } + // Increase counter after finishing __syncwarp(); lane_id == 0 ? atomic_add_release_global( atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0; + if (!disable_ll_layered) { + lane_id == 0 + ? atomic_add_release_global( + atomic_finish_counter_per_expert + real_dst_expert_id, 1) + : 0; + } } } } else if (warp_id == num_warps - 1) { @@ -238,21 +425,24 @@ __global__ __launch_bounds__(1024, 1) void dispatch( EP_DEVICE_ASSERT(num_sms > 1); if (sm_id == 0) { // The first SM is also responsible for cleaning the next buffer + if (disable_ll_layered) { +// The first SM is also responsible for cleaning the next buffer #pragma unroll - for (int i = lane_id; i < num_next_clean_int; i += WARP_SIZE) { - next_clean[i] = 0; - next_clean_second[i] = 0; - } - // Notify before executing `int_p` - __syncwarp(); + for (int i = lane_id; i < num_next_clean_int; i += 32) + next_clean[i] = 0; + + // Notify before executing `int_p` + __syncwarp(); #pragma unroll - for (int i = lane_id; i < num_experts; i += WARP_SIZE) - atomic_add_release_global(atomic_finish_counter_per_expert + i, - FINISHED_SUM_TAG); + for (int i = lane_id; i < num_experts; i += 32) + atomic_add_release_global(atomic_finish_counter_per_expert + i, + FINISHED_SUM_TAG); + } } // This SM should be responsible for some destination experts, read // `topk_idx` for them int expert_count[kNumMaxWarpGroups] = {0}; + int waiting_flag[kNumMaxWarpGroups] = {0}; auto const expert_begin_idx = sm_id * num_warp_groups; auto const expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts); @@ -261,22 +451,98 @@ __global__ __launch_bounds__(1024, 1) void dispatch( #pragma unroll 8 for (int i = lane_id; i < num_tokens * num_topk; i += WARP_SIZE) { auto idx = static_cast(__ldg(topk_idx + i)); - if (idx >= expert_begin_idx and idx < expert_end_idx) + if (idx >= expert_begin_idx and idx < expert_end_idx) { expert_count[idx - expert_begin_idx]++; + if (!disable_ll_layered) { // only open ll dispatch opt, should do + if (idx < 0) continue; + auto const dst_rank = idx / num_local_experts; + auto const dst_expert_local_idx = idx % num_local_experts; + auto real_write_dst_rank = + dst_rank / num_nvl_ranks * num_nvl_ranks + rank % num_nvl_ranks; + auto real_dst_expert_id = + real_write_dst_rank * num_local_experts + dst_expert_local_idx; + if (real_dst_expert_id >= expert_begin_idx and + real_dst_expert_id < expert_end_idx) + waiting_flag[real_dst_expert_id - expert_begin_idx]++; + } + } } // Warp reduce #pragma unroll for (int i = expert_begin_idx; i < expert_end_idx; ++i) { auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]); + auto waiting_flag_sum = 0; + if (!disable_ll_layered) { // only open ll dispatch opt, should do + waiting_flag_sum = warp_reduce_sum(waiting_flag[i - expert_begin_idx]); + } if (lane_id == 0) { shared_num_tokens_sent_per_expert[i - 
expert_begin_idx] = sum; + int increment = FINISHED_SUM_TAG - waiting_flag_sum - sum; + printf( + "[internode_ll] Last warp: rank=%d, sm_id=%d, expert_idx=%d, " + "sum=%d, waiting_flag_sum=%d, increment=%d\n", + rank, sm_id, i, sum, waiting_flag_sum, increment); atomic_add_release_global(atomic_finish_counter_per_expert + i, - FINISHED_SUM_TAG - sum); + increment); } } } + if (!disable_ll_layered and + sm_id == num_sms - 1) { // only open ll dispatch opt, should do +// The first SM is also responsible for cleaning the next buffer +#pragma unroll + for (int i = thread_id; i < num_experts; + i += blockDim.x) // clean for combine + next_clean[i] = 0; +// clean data ready flag +#pragma unroll 8 + for (int i = thread_id; i < num_max_dispatch_tokens_per_rank * num_ranks; + i += blockDim.x) { + int token_idx = i / num_ranks; + int rank_id = i % num_ranks; + { + auto node_id = rank_id / num_nvl_ranks; + auto nvl_rank_id = rank_id % num_nvl_ranks; + auto* data_ready_flag_ptr = + reinterpret_cast(next_clean_data_ready_counter) + + node_id * num_max_dispatch_tokens_per_rank * num_nvl_ranks + + token_idx * num_nvl_ranks + rank % num_nvl_ranks; + EP_DEVICE_ASSERT(data_ready_flag_ptr - next_clean_data_ready_counter < + num_max_dispatch_tokens_per_rank * num_nodes * + num_nvl_ranks * sizeof(int)); + auto const data_ready_p2p_src_ptr = + ipc_rdma_base_ptrs + ? uccl::get_ipc_p2p_ptr( + /*dst_ptr=*/uint64_t(data_ready_flag_ptr), + /*ipc_rdma_base_ptrs=*/ipc_rdma_base_ptrs, + /*my_rank=*/rank, + /*peer_rank=*/(rank / num_nvl_ranks) * num_nvl_ranks + + nvl_rank_id, + /*max_nvl_peers=*/max_nvl_peers, + /*peer_base_idx=*/0) + : 0; + + reinterpret_cast(data_ready_p2p_src_ptr)[0] = 0; + } + } + __syncthreads(); + if (thread_id == 0) { + printf( + "[internode_ll] Last SM (sm_id=%d, rank=%d) adding FINISHED_SUM_TAG " + "to all experts\n", + sm_id, rank); + } +#pragma unroll + for (int i = thread_id; i < num_experts; i += blockDim.x) + atomic_add_release_global(atomic_finish_counter_per_expert + i, + FINISHED_SUM_TAG); + } + __syncthreads(); + if (blockIdx.x == 0 && threadIdx.x == 0) { + printf("[internode_ll] After last SM sync: rank=%d\n", rank); + } // Issue count sends if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) { @@ -287,18 +553,40 @@ __global__ __launch_bounds__(1024, 1) void dispatch( shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * num_warp_groups]; // Wait local sends issued and send expert counts - while (ld_acquire_global(atomic_finish_counter_per_expert + - responsible_expert_idx) != FINISHED_SUM_TAG * 2) + printf( + "[internode_ll] Before wait: rank=%d, responsible_expert_idx=%d, " + "num_tokens_sent=%d, FINISHED_SUM_TAG*2=%d\n", + rank, responsible_expert_idx, num_tokens_sent, FINISHED_SUM_TAG * 2); + int counter_value = ld_acquire_global(atomic_finish_counter_per_expert + + responsible_expert_idx); + printf( + "[internode_ll] Initial counter: rank=%d, responsible_expert_idx=%d, " + "counter_value=%d\n", + rank, responsible_expert_idx, counter_value); + // Wait until counter reaches at least FINISHED_SUM_TAG * 2 + // Use < instead of != because counter can occasionally exceed target due to + // race conditions or timing differences between per-token increments and + // last warp adjustments + while (counter_value < FINISHED_SUM_TAG * 2) { + counter_value = ld_acquire_global(atomic_finish_counter_per_expert + + responsible_expert_idx); #if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) __builtin_amdgcn_s_sleep(1); #else ; #endif - + } + printf( 
+ "[internode_ll] After wait: rank=%d, responsible_expert_idx=%d, " + "counter_value=%d\n", + rank, responsible_expert_idx, counter_value); auto dst_ptr = reinterpret_cast( rdma_recv_count + dst_expert_local_idx * num_ranks + rank); auto dst_ptr_internode = reinterpret_cast( rdma_recv_count_internode + dst_expert_local_idx * num_ranks + rank); + // Ensure all metadata writes are flushed before sending atomic to maintain + // ordering on EFA + __threadfence_system(); // Try to use IPC for intra-node atomic operations auto const dst_p2p_ptr = ipc_rdma_base_ptrs @@ -324,33 +612,73 @@ __global__ __launch_bounds__(1024, 1) void dispatch( // Clean `packed_recv_count` if (dst_rank == 0) packed_recv_count[dst_expert_local_idx] = 0; + printf( + "[internode_ll] Finished SEND phase: rank=%d, " + "responsible_expert_idx=%d\n", + rank, responsible_expert_idx); } __syncwarp(); // Receiving phase LOW_LATENCY_DISPATCH_RECV: + if (blockIdx.x == 0 && threadIdx.x == 0) { + printf( + "[internode_ll] LOW_LATENCY_DISPATCH_RECV: rank=%d, phases=0x%x, " + "LOW_LATENCY_RECV_PHASE=0x%x, has_recv_phase=%d\n", + rank, phases, LOW_LATENCY_RECV_PHASE, + (phases & LOW_LATENCY_RECV_PHASE) != 0); + } if ((phases & LOW_LATENCY_RECV_PHASE) == 0) { return; } // For send-and-recv kernels, we need a grid sync for making // `packed_recv_count` visible + if (blockIdx.x == 0 && threadIdx.x == 0) { + printf("[internode_ll] Before grid sync: phases=0x%x, has_send_phase=%d\n", + phases, (phases & LOW_LATENCY_SEND_PHASE) != 0); + } if (phases & LOW_LATENCY_SEND_PHASE) #if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) amd::grid_sync(grid_sync_barrier_ptr, num_sms); #else cg::this_grid().sync(); #endif + if (blockIdx.x == 0 && threadIdx.x == 0) { + printf("[internode_ll] After grid sync\n"); + } // Receiving and packing + if (blockIdx.x == 0 && threadIdx.x == 0) { + printf( + "[internode_ll] Before responsible_expert_idx check: " + "responsible_expert_idx=%d, num_experts=%d\n", + responsible_expert_idx, num_experts); + } if (responsible_expert_idx < num_experts) { + if (sub_warp_id == 1 and lane_id == 0) { + printf( + "[internode_ll] Entering recv phase: responsible_expert_idx=%d, " + "num_experts=%d, sub_warp_id=%d, lane_id=%d\n", + responsible_expert_idx, num_experts, sub_warp_id, lane_id); + } auto const src_rank = responsible_expert_idx / num_local_experts; auto const local_expert_idx = responsible_expert_idx % num_local_experts; - auto const rdma_recv_x_uint8 = - static_cast(rdma_recv_x) + - local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * - num_bytes_per_msg + - src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg; + uint8_t* rdma_recv_x_uint8 = nullptr; + if (disable_ll_layered) { + rdma_recv_x_uint8 = + static_cast(rdma_recv_x) + + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * + num_bytes_per_msg + + src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg; + } + if (!disable_ll_layered) { + rdma_recv_x_uint8 = + static_cast(rdma_recv_x_meta) + + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * + num_bytes_per_meta + + src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_meta; + } auto const recv_x_int4 = static_cast(packed_recv_x) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4; @@ -379,8 +707,28 @@ LOW_LATENCY_DISPATCH_RECV: #else EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 15); #endif + // Debug: print from any lane 0 to see if we reach here + if (sub_warp_id == 1 and lane_id == 0) { + 
printf( + "[internode_ll] DEBUG: sub_warp_id=1 reached before wait section " + "print, " + "responsible_expert_idx=%d, src_rank=%d, rank=%d, " + "local_expert_idx=%d\n", + responsible_expert_idx, src_rank, rank, local_expert_idx); + } + if (lane_id == 0) { + printf( + "[internode_ll] Reached wait section: sub_warp_id=%d, lane_id=%d, " + "src_rank=%d, rank=%d, local_expert_idx=%d\n", + sub_warp_id, lane_id, src_rank, rank, local_expert_idx); + } if (sub_warp_id == 1 and lane_id == 0) { auto start_time = clock64(); + printf( + "[internode_ll] Before IPC wait: src_rank=%d, rank=%d, " + "local_expert_idx=%d, max_nvl_peers=%d, sub_warp_id=%d, lane_id=%d\n", + src_rank, rank, local_expert_idx, max_nvl_peers, sub_warp_id, + lane_id); while ((src_rank / max_nvl_peers == rank / max_nvl_peers) && (num_recv_tokens_ipc = ld_acquire_sys_global( rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == @@ -390,7 +738,13 @@ LOW_LATENCY_DISPATCH_RECV: #else ; #endif + printf("[internode_ll] After IPC wait: num_recv_tokens_ipc=%d\n", + num_recv_tokens_ipc); + printf( + "[internode_ll] Before internode wait: src_rank=%d, rank=%d, " + "local_expert_idx=%d, max_nvl_peers=%d\n", + src_rank, rank, local_expert_idx, max_nvl_peers); while ((src_rank / max_nvl_peers != rank / max_nvl_peers) && (num_recv_tokens_internode = ld_acquire_sys_global( rdma_recv_count_internode + local_expert_idx * num_ranks + @@ -400,6 +754,9 @@ LOW_LATENCY_DISPATCH_RECV: #else ; #endif + printf( + "[internode_ll] After internode wait: num_recv_tokens_internode=%d\n", + num_recv_tokens_internode); if (src_rank / max_nvl_peers == rank / max_nvl_peers) { if (ld_acquire_sys_global(rdma_recv_count_internode + @@ -459,21 +816,89 @@ LOW_LATENCY_DISPATCH_RECV: #endif num_recv_tokens = shared_num_recv_tokens[warp_group_id]; recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id]; - + auto const real_read_src_rank = + src_rank % num_nvl_ranks + rank / num_nvl_ranks * num_nvl_ranks; // Copy tokens EP_DEVICE_ASSERT(num_scales <= 64); for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) { // Copy source info - auto const src_src_idx = - reinterpret_cast(rdma_recv_x_uint8 + i * num_bytes_per_msg); - if (lane_id == 0) - recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx); - __syncwarp(); + int4* src_data = nullptr; + if (!disable_ll_layered) { + auto const src_src_idx = + reinterpret_cast(rdma_recv_x_uint8 + i * num_bytes_per_meta); + int src_token_idx = 0; + if (lane_id == 0) { + src_token_idx = ld_nc_global(src_src_idx); + recv_src_info[recv_token_begin_idx + i] = + pack2(src_token_idx, src_rank); + } + src_token_idx = __shfl_sync(0xffffffff, src_token_idx, 0); + auto const src_ptr = reinterpret_cast(rdma_recv_x_data) + + (src_rank / num_nvl_ranks) * + num_max_dispatch_tokens_per_rank * + num_bytes_per_data + + src_token_idx * num_bytes_per_data; + // Flag is at end of actual data (before padding) + size_t num_bytes_data_only = num_bytes_per_data_unaligned - sizeof(int); + auto const data_ready_flag_ptr = + reinterpret_cast(src_ptr) + num_bytes_data_only; + auto const src_data_ready_flag_p2p_ptr = reinterpret_cast( + ipc_rdma_base_ptrs ? 
uccl::get_ipc_p2p_ptr( + /*dst_ptr=*/data_ready_flag_ptr, + /*ipc_rdma_base_ptrs=*/ipc_rdma_base_ptrs, + /*my_rank=*/rank, + /*peer_rank=*/real_read_src_rank, + /*max_nvl_peers=*/max_nvl_peers, + /*peer_base_idx=*/0) + : 0); + + if (lane_id == 0) { + int tmp = 0; + auto start_time = clock64(); + while (tmp != 2) { // wait for data to be ready + if (src_data_ready_flag_p2p_ptr != 0) { + tmp = ld_acquire_sys_global(src_data_ready_flag_p2p_ptr); + } else { + // RDMA case: read flag directly from data buffer address + // The flag is at the end of the data buffer written via RDMA + auto* flag_ptr = reinterpret_cast( + reinterpret_cast(rdma_recv_x_data) + + (src_rank / num_nvl_ranks) * + num_max_dispatch_tokens_per_rank * num_bytes_per_data + + src_token_idx * num_bytes_per_data + num_bytes_data_only); + tmp = ld_acquire_sys_global(flag_ptr); + } + if (clock64() - start_time >= NUM_TIMEOUT_CYCLES) { + printf( + "DeepEP ll dispatch recv data timeout,src_rank:%d, dst_rank: " + "%d, real_read_src_rank:%d,src_token_idx:%d " + "dst RDMA lane: %d, num_recv_tokens: %d\n", + src_rank, rank, real_read_src_rank, src_token_idx, lane_id, + num_recv_tokens); + trap(); + } + } + } + __syncwarp(); + src_data = reinterpret_cast( + ipc_rdma_base_ptrs + ? uccl::get_ipc_p2p_ptr(src_ptr, ipc_rdma_base_ptrs, rank, + real_read_src_rank, max_nvl_peers, 0) + : 0); + } + if (disable_ll_layered) { + auto const src_src_idx = + reinterpret_cast(rdma_recv_x_uint8 + i * num_bytes_per_msg); + if (lane_id == 0) + recv_src_info[recv_token_begin_idx + i] = + pack2(ld_nc_global(src_src_idx), src_rank); + __syncwarp(); + src_data = reinterpret_cast( + reinterpret_cast(src_src_idx) + sizeof(int4)); + } // Copy data // NOTES: only 2 load iterations for 7K hidden with 7 unrolls - auto const src_data = reinterpret_cast( - reinterpret_cast(src_src_idx) + sizeof(int4)); auto const dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4; UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, @@ -515,9 +940,10 @@ LOW_LATENCY_DISPATCH_RECV: } } -void dispatch(void* packed_recv_x, void* packed_recv_x_scales, - int* packed_recv_src_info, int64_t* packed_recv_layout_range, - int* packed_recv_count, int* cumulative_local_expert_recv_stats, +void dispatch(bool disable_ll_layered, void* packed_recv_x, + void* packed_recv_x_scales, int* packed_recv_src_info, + int64_t* packed_recv_layout_range, int* packed_recv_count, + int* cumulative_local_expert_recv_stats, int64_t* dispatch_wait_recv_cost_stats, void* rdma_recv_x, int* rdma_recv_count, void* rdma_x, void const* x, int64_t const* topk_idx, int* next_clean, int* next_clean_second, @@ -558,17 +984,18 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, dispatch_func = dispatch
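
The patch above gates a new "layered" low-latency dispatch layout behind an environment variable: `Buffer.disable_ll_layered()` in `buffer.py` returns true only when `DEEPEP_DISABLE_LL_DISPATCH_OPT=1`, and that boolean is threaded into the C++ runtime constructor, `get_low_latency_rdma_size_hint`, and the `dispatch`/`combine` launchers. The notes below restate the core pieces of the change as small standalone C++ sketches; they are illustrations of the logic in the diff, not code from it. First, the opt-out switch itself — in the patch the check lives in Python and the resulting bool is passed down, so reading `getenv` directly on the C++ side, as done here, is only for illustration:

```cpp
// Sketch of the opt-out switch. DEEPEP_DISABLE_LL_DISPATCH_OPT=1 turns the
// layered dispatch path off; any other value (or an unset variable) keeps it
// enabled. The real check is Buffer.disable_ll_layered() in buffer.py.
#include <cstdio>
#include <cstdlib>
#include <cstring>

bool disable_ll_layered_from_env() {
  const char* v = std::getenv("DEEPEP_DISABLE_LL_DISPATCH_OPT");
  return v != nullptr && std::strcmp(v, "1") == 0;
}

int main() {
  std::printf("layered dispatch %s\n",
              disable_ll_layered_from_env() ? "disabled" : "enabled");
}
```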
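
`LowLatencyLayout` sizes the layered receive region from two quantities: a fixed per-token metadata slot (`sizeof(int4)`) and an int4-aligned per-token payload slot that appends a readiness flag (`sizeof(int)`) after the hidden data. Metadata is allocated per (expert, token) pair as before, but the payload is allocated only per (node, token) pair, since it is shipped once per destination node; the count buffer also grows by `NUM_MAX_NVL_PEERS * num_nodes * tokens * sizeof(int) + NUM_MAX_NVL_PEERS * sizeof(int)` for the per-rank readiness flags. The sketch below redoes the payload/metadata arithmetic on the host; the configuration values and `NUM_MAX_NVL_PEERS = 8` are assumptions for illustration:

```cpp
// Host-side sketch of the layered dispatch sizing added to LowLatencyLayout.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>

constexpr size_t kNumMaxNvlPeers = 8;  // assumed value of NUM_MAX_NVL_PEERS

constexpr size_t align_up(size_t x, size_t a) { return (x + a - 1) / a * a; }

int main() {
  // Example configuration (assumed, not from the patch).
  size_t hidden = 7168, num_ranks = 32, num_experts = 288,
         num_max_dispatch_tokens_per_rank = 128;
  size_t num_scales = hidden / 128;
  size_t num_nodes = num_ranks / kNumMaxNvlPeers;

  // Per-token metadata slot: the leading int4 carrying the source index.
  size_t per_meta = sizeof(int32_t) * 4;  // == sizeof(int4)

  // Per-token payload: max(bf16 payload, fp8 payload + scales) plus a trailing
  // readiness flag, rounded up to int4 so warp copies stay vectorized.
  size_t per_token_unaligned =
      std::max(hidden * 2 /* nv_bfloat16 */,
               hidden + num_scales * sizeof(float)) +
      sizeof(int);  // flag at end of data
  size_t per_token = align_up(per_token_unaligned, per_meta);

  // Layered receive region: metadata per (expert, token), payload per
  // (node, token), because the payload is sent once per destination node.
  size_t recv_bytes =
      num_experts * num_max_dispatch_tokens_per_rank * per_meta +
      num_nodes * num_max_dispatch_tokens_per_rank * per_token;

  std::printf("per_meta=%zu per_token=%zu recv_bytes=%zu\n", per_meta,
              per_token, recv_bytes);
}
```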
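
For the payload copy, the kernel does not write to the rank that owns the expert; it writes to the peer on that node sitting on the same NVLink position ("rail") as the sender: `real_write_dst_rank = (dst_rank / num_nvl_ranks) * num_nvl_ranks + (rank % num_nvl_ranks)`. A tiny worked example, assuming 8 NVL peers per node:

```cpp
// Worked example of the rail-aligned destination used for the payload copy.
// Sender rank 3 dispatching to an expert owned by rank 13 writes the payload
// through rank 11: same node as rank 13 (node 1), same NVL position as the
// sender (position 3).
#include <cassert>

int rail_aligned_dst(int rank, int dst_rank, int num_nvl_ranks) {
  return (dst_rank / num_nvl_ranks) * num_nvl_ranks + (rank % num_nvl_ranks);
}

int main() {
  assert(rail_aligned_dst(/*rank=*/3, /*dst_rank=*/13, /*num_nvl_ranks=*/8) == 11);
  assert(rail_aligned_dst(/*rank=*/3, /*dst_rank=*/5, /*num_nvl_ranks=*/8) == 3);
  return 0;
}
```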
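
Because the payload is addressed per node, a token whose top-k experts live on the same node should only ship its hidden data once. In the kernel each lane holds the node id of one top-k slot and exchanges it with `__shfl_sync`; a slot cancels its payload send (`send_node_id = -1`) if any lower-indexed slot already targets the same node, and then sends only the metadata. The host-side restatement below replaces the warp shuffle with a plain loop; the expert-count configuration is an assumption:

```cpp
// Host-side restatement of the per-node de-duplication: a token's payload is
// sent to a node only by the first top-k slot that targets that node.
#include <cstdio>
#include <vector>

int main() {
  // Assumed example: 9 local experts per rank, 8 NVL peers per node, so each
  // node owns 72 consecutive expert ids.
  const int num_local_experts = 9, num_nvl_ranks = 8;
  std::vector<long> topk_idx = {5, 70, 100, 150};  // expert ids for one token

  for (size_t slot = 0; slot < topk_idx.size(); ++slot) {
    int send_node_id =
        static_cast<int>(topk_idx[slot]) / num_local_experts / num_nvl_ranks;
    // Cancel the payload send if an earlier slot already targets this node.
    for (size_t prev = 0; prev < slot; ++prev) {
      int prev_node =
          static_cast<int>(topk_idx[prev]) / num_local_experts / num_nvl_ranks;
      if (prev_node == send_node_id) { send_node_id = -1; break; }
    }
    // Experts 5 and 70 both map to node 0, so slot 1 prints -1 (metadata only).
    std::printf("slot %zu -> payload to node %d\n", slot, send_node_id);
  }
}
```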
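
Since metadata (routed per expert) and payload (routed per node, possibly through a different rank) arrive independently, the sender stores the value 2 into the trailing int of the payload slot with release semantics before the copy goes out, and the receiver spins with an acquire load on that trailing int until it reads 2 before copying the hidden data (with a `NUM_TIMEOUT_CYCLES` trap guarding the wait). The sketch below is a single-process `std::atomic` analog of the `st_release_sys_global` / `ld_acquire_sys_global` pair, for illustration only:

```cpp
// Host-side analog of the payload readiness handshake used in the kernel.
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
  std::vector<int> payload(16, 0);  // stands in for one token's hidden data
  std::atomic<int> ready_flag{0};   // stands in for the trailing flag int

  std::thread sender([&] {
    for (size_t i = 0; i < payload.size(); ++i) payload[i] = int(i);  // "copy"
    ready_flag.store(2, std::memory_order_release);  // flag value 2, as in the patch
  });

  // Receiver: spin until the flag reads 2, then the payload is safe to read.
  while (ready_flag.load(std::memory_order_acquire) != 2) { /* spin */ }
  std::printf("payload[15]=%d\n", payload[15]);
  sender.join();
}
```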
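
Finally, the completion accounting: with the layered path each token send bumps two counters (the destination expert's own entry and the entry of the rail-aligned rank the payload was written through), the last warp compensates by adding `FINISHED_SUM_TAG - waiting_flag_sum - sum` instead of `FINISHED_SUM_TAG - sum`, and the count-sending warp now waits for the counter to reach at least `FINISHED_SUM_TAG * 2` (the diff's comment notes the counter can transiently exceed the target, hence `<` rather than `!=`). A small arithmetic sketch of the invariant, with a placeholder tag value and made-up token counts:

```cpp
// Accounting sketch for atomic_finish_counter_per_expert on the layered path.
#include <cassert>

int main() {
  const int kFinishedSumTag = 1024;         // assumed placeholder value
  const int tokens_sent_to_expert = 37;     // "sum" in the kernel
  const int payloads_routed_via_rank = 12;  // "waiting_flag_sum" in the kernel

  int counter = kFinishedSumTag;             // contribution of the cleanup SM
  counter += tokens_sent_to_expert;          // per-token increments
  counter += payloads_routed_via_rank;       // rail-aligned payload increments
  counter += kFinishedSumTag - tokens_sent_to_expert - payloads_routed_via_rank;

  assert(counter >= 2 * kFinishedSumTag);  // condition the count sender waits on
  return 0;
}
```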