diff --git a/ep/bench/test_internode.py b/ep/bench/test_internode.py index c1a8f3001..62754a50a 100644 --- a/ep/bench/test_internode.py +++ b/ep/bench/test_internode.py @@ -185,9 +185,11 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0 check_start = check_end + do_combine = False + for previous_mode in (False, True): for async_mode in (False, True): - for current_x in (x_pure_rand, x, x_e4m3): + for current_x in (x_pure_rand, x): for with_topk in (False, True): if local_rank == 0: print( @@ -225,6 +227,7 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): event, ) = buffer.dispatch(**dispatch_args) event.current_stream_wait() if async_mode else () + print("recv_x: ", recv_x) recv_x = ( per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) @@ -284,47 +287,51 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): check_data(recv_x, recv_gbl_rank_prefix_sum) # Test combine - bias_0 = torch.ones( - (num_tokens, hidden), dtype=torch.bfloat16, device="cuda" - ) - bias_1 = torch.randn( - (num_tokens, hidden), dtype=torch.bfloat16, device="cuda" - ) - combine_args = { - "x": recv_x, - "bias": (bias_0, bias_1), - "handle": handle, - "config": config, - "async_finish": async_mode, - } - if with_topk: - combine_args.update({"topk_weights": recv_topk_weights}) - if previous_mode: - combine_args.update({"previous_event": buffer.capture()}) - combined_x, combined_topk_weights, event = buffer.combine( - **combine_args - ) - event.current_stream_wait() if async_mode else () - check_x = ( - combined_x.float() - bias_0.float() - bias_1.float() - ) / is_token_in_rank.sum(dim=1).unsqueeze(1) - ref_x = x_pure_rand if current_x is x_pure_rand else x - assert calc_diff(check_x, ref_x) < 5e-6 - if with_topk: - check_topk_weights = ( - combined_topk_weights - if (current_x is x_pure_rand) - else ( - combined_topk_weights - / is_token_in_rank.sum(dim=1).unsqueeze(1) - ) + + if do_combine: + bias_0 = torch.ones( + (num_tokens, hidden), dtype=torch.bfloat16, device="cuda" + ) + bias_1 = torch.randn( + (num_tokens, hidden), dtype=torch.bfloat16, device="cuda" ) - ref_topk_weights = ( - topk_weights_pure_rand - if current_x is x_pure_rand - else topk_weights + combine_args = { + "x": recv_x, + "bias": (bias_0, bias_1), + "handle": handle, + "config": config, + "async_finish": async_mode, + } + if with_topk: + combine_args.update({"topk_weights": recv_topk_weights}) + if previous_mode: + combine_args.update({"previous_event": buffer.capture()}) + combined_x, combined_topk_weights, event = buffer.combine( + **combine_args ) - assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9 + event.current_stream_wait() if async_mode else () + check_x = ( + combined_x.float() - bias_0.float() - bias_1.float() + ) / is_token_in_rank.sum(dim=1).unsqueeze(1) + ref_x = x_pure_rand if current_x is x_pure_rand else x + assert calc_diff(check_x, ref_x) < 5e-6 + if with_topk: + check_topk_weights = ( + combined_topk_weights + if (current_x is x_pure_rand) + else ( + combined_topk_weights + / is_token_in_rank.sum(dim=1).unsqueeze(1) + ) + ) + ref_topk_weights = ( + topk_weights_pure_rand + if current_x is x_pure_rand + else topk_weights + ) + assert ( + calc_diff(check_topk_weights, ref_topk_weights) < 1e-9 + ) # For later tuning dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2 @@ -334,23 +341,93 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): if local_rank == 0: print("", flush=True) + # will stuck # Tune dispatch performance - 
best_dispatch_results = None - fp8_factor = (1 + 4 / 128) / 2 - for current_x in (x_e4m3, x): - best_time, best_results = 1e10, None - rdma_send_bytes = ( - (dispatch_bf16_rdma_send_bytes * fp8_factor) - if isinstance(current_x, tuple) - else dispatch_bf16_rdma_send_bytes - ) - nvl_recv_bytes = ( - (dispatch_bf16_nvl_recv_bytes * fp8_factor) - if isinstance(current_x, tuple) - else dispatch_bf16_nvl_recv_bytes + # best_dispatch_results = None + # fp8_factor = (1 + 4 / 128) / 2 + # for current_x in (x, ): + # best_time, best_results = 1e10, None + # rdma_send_bytes = ( + # (dispatch_bf16_rdma_send_bytes * fp8_factor) + # if isinstance(current_x, tuple) + # else dispatch_bf16_rdma_send_bytes + # ) + # nvl_recv_bytes = ( + # (dispatch_bf16_nvl_recv_bytes * fp8_factor) + # if isinstance(current_x, tuple) + # else dispatch_bf16_nvl_recv_bytes + # ) + # for nvl_chunk_size in range(4, 45, 4): + # for rdma_chunk_size in range(4, 33, 4): + # config = Config( + # num_sms, + # nvl_chunk_size, + # nvl_buffer_size, + # rdma_chunk_size, + # rdma_buffer_size, + # ) + # tune_args = {"x": current_x, "handle": handle, "config": config} + # t, notify_t = bench_kineto( + # lambda: buffer.dispatch(**tune_args), ("dispatch", "notify") + # ) + # if t < best_time: + # best_time, best_results = t, ( + # num_sms, + # nvl_chunk_size, + # rdma_chunk_size, + # notify_t, + # ) + # if local_rank == 0: + # print( + # f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}, transmit: {t * 1e6:.2f} us, notify: {notify_t * 1e6:.2f} us, BW: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ", + # flush=True, + # ) + # if local_rank == 0: + # print( + # f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}, transmit: {best_time * 1e6:.2f} us, notify: {best_results[3] * 1e6:.2f} us, BW: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', + # flush=True, + # ) + # print("", flush=True) + + # if isinstance(current_x, tuple): + # # Gather FP8 the best config from rank 0 + # best_dispatch_results = torch.tensor( + # [best_results[0], best_results[1], best_results[2]], + # dtype=torch.int32, + # device="cuda", + # ) + # all_best_fp8_results_list = [ + # torch.zeros_like(best_dispatch_results) + # for _ in range(torch.distributed.get_world_size()) + # ] + # dist.all_gather( + # all_best_fp8_results_list, best_dispatch_results, group=group + # ) + # best_dispatch_results = all_best_fp8_results_list[0].tolist() + + if do_combine: + dispatch_config = Config( + best_dispatch_results[0], + best_dispatch_results[1], + nvl_buffer_size, + best_dispatch_results[2], + rdma_buffer_size, ) - for nvl_chunk_size in range(4, 45, 4): - for rdma_chunk_size in range(4, 33, 4): + + dispatch_args = { + "x": x, + "num_tokens_per_rank": num_tokens_per_rank, + "num_tokens_per_rdma_rank": num_tokens_per_rdma_rank, + "is_token_in_rank": is_token_in_rank, + "num_tokens_per_expert": num_tokens_per_expert, + "config": dispatch_config if dispatch_config is not None else config, + } + recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args) + + # Tune combine performance + best_time, best_results = 1e10, None + for nvl_chunk_size in range(1, 8, 1): + for rdma_chunk_size in range(12 if num_nodes == 2 else 8, 33, 4): config = Config( num_sms, nvl_chunk_size, @@ -358,97 +435,30 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): 
rdma_chunk_size, rdma_buffer_size, ) - tune_args = {"x": current_x, "handle": handle, "config": config} + tune_args = {"x": recv_x, "handle": handle, "config": config} t, notify_t = bench_kineto( - lambda: buffer.dispatch(**tune_args), ("dispatch", "notify") + lambda: buffer.combine(**tune_args), ("combine", "notify") ) - if t < best_time: - best_time, best_results = t, ( - num_sms, - nvl_chunk_size, - rdma_chunk_size, - notify_t, - ) if local_rank == 0: print( - f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}, transmit: {t * 1e6:.2f} us, notify: {notify_t * 1e6:.2f} us, BW: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ", + f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}, transmit: {t * 1e6:.2f} us, notify: {notify_t * 1e6:.2f} us, BW: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ", flush=True, ) + if t < best_time: + best_time, best_results = t, ( + num_sms, + nvl_chunk_size, + rdma_chunk_size, + notify_t, + ) + if local_rank == 0: print( - f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}, transmit: {best_time * 1e6:.2f} us, notify: {best_results[3] * 1e6:.2f} us, BW: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', + f"[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}, transmit: {best_time * 1e6:.2f} us, notify: {best_results[3] * 1e6:.2f} us, BW: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)", flush=True, ) print("", flush=True) - if isinstance(current_x, tuple): - # Gather FP8 the best config from rank 0 - best_dispatch_results = torch.tensor( - [best_results[0], best_results[1], best_results[2]], - dtype=torch.int32, - device="cuda", - ) - all_best_fp8_results_list = [ - torch.zeros_like(best_dispatch_results) - for _ in range(torch.distributed.get_world_size()) - ] - dist.all_gather( - all_best_fp8_results_list, best_dispatch_results, group=group - ) - best_dispatch_results = all_best_fp8_results_list[0].tolist() - dispatch_config = Config( - best_dispatch_results[0], - best_dispatch_results[1], - nvl_buffer_size, - best_dispatch_results[2], - rdma_buffer_size, - ) - - dispatch_args = { - "x": x, - "num_tokens_per_rank": num_tokens_per_rank, - "num_tokens_per_rdma_rank": num_tokens_per_rdma_rank, - "is_token_in_rank": is_token_in_rank, - "num_tokens_per_expert": num_tokens_per_expert, - "config": dispatch_config if dispatch_config is not None else config, - } - recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args) - - # Tune combine performance - best_time, best_results = 1e10, None - for nvl_chunk_size in range(1, 8, 1): - for rdma_chunk_size in range(12 if num_nodes == 2 else 8, 33, 4): - config = Config( - num_sms, - nvl_chunk_size, - nvl_buffer_size, - rdma_chunk_size, - rdma_buffer_size, - ) - tune_args = {"x": recv_x, "handle": handle, "config": config} - t, notify_t = bench_kineto( - lambda: buffer.combine(**tune_args), ("combine", "notify") - ) - if local_rank == 0: - print( - f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}, transmit: {t * 1e6:.2f} us, notify: {notify_t * 1e6:.2f} us, BW: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), 
{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ", - flush=True, - ) - if t < best_time: - best_time, best_results = t, ( - num_sms, - nvl_chunk_size, - rdma_chunk_size, - notify_t, - ) - - if local_rank == 0: - print( - f"[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}, transmit: {best_time * 1e6:.2f} us, notify: {best_results[3] * 1e6:.2f} us, BW: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)", - flush=True, - ) - print("", flush=True) - # noinspection PyUnboundLocalVariable,PyShadowingNames def test_loop( @@ -458,7 +468,7 @@ def test_loop( if args.test_ll_compatibility: ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9 - num_sms = 24 + num_sms = 4 num_qps_per_rank = max( num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0, diff --git a/ep/include/ep_configs.cuh b/ep/include/ep_configs.cuh index d2c5f9010..bdac8336a 100644 --- a/ep/include/ep_configs.cuh +++ b/ep/include/ep_configs.cuh @@ -12,11 +12,19 @@ // #define ENABLE_FAST_DEBUG #ifndef ENABLE_FAST_DEBUG #define NUM_CPU_TIMEOUT_SECS 100 +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) +#define NUM_TIMEOUT_CYCLES 20000000000ull +#else #define NUM_TIMEOUT_CYCLES 200000000000ull // 200G cycles ~= 100s +#endif #else #define NUM_CPU_TIMEOUT_SECS 10 +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) +#define NUM_TIMEOUT_CYCLES 2000000000ull +#else #define NUM_TIMEOUT_CYCLES 20000000000ull // 20G cycles ~= 10s #endif +#endif #define LOW_LATENCY_SEND_PHASE 1 #define LOW_LATENCY_RECV_PHASE 2 diff --git a/ep/include/ep_utils.cuh b/ep/include/ep_utils.cuh index 10c40d276..97caf19cf 100644 --- a/ep/include/ep_utils.cuh +++ b/ep/include/ep_utils.cuh @@ -347,6 +347,29 @@ __forceinline__ __device__ float fast_pow2(int x) { return *reinterpret_cast(&bits_x); } +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) +__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, + float& scale_inv, + bool round_scale) { + if (!isfinite(amax) || amax <= 0.0f) { + scale = 1.0f; + scale_inv = 1.0f; + return; + } + float t = amax * kFinfoAmaxInvE4M3; + if (round_scale) { + int e; + frexpf(t, &e); + scale_inv = ldexpf(1.0f, e); + scale = ldexpf(1.0f, -e); + } else { + scale_inv = t; + scale = kFinfoAmaxE4M3 / amax; + } + if (!isfinite(scale) || scale <= 0.0f) scale = 1.0f; + if (!isfinite(scale_inv) || scale_inv <= 0.0f) scale_inv = 1.0f; +} +#else __forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv, bool round_scale) { @@ -359,6 +382,7 @@ __forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, scale = kFinfoAmaxE4M3 / amax; } } +#endif // `ld.global.nc.L1::no_allocate` will be translated into // `LDG.E.NA.[width].CONSTANT` in SASS @@ -901,7 +925,7 @@ __device__ __forceinline__ void st_relaxed_sys_global(int const* ptr, int val) { __device__ __forceinline__ int ld_acquire_cta(int const* ptr) { int ret; #if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) - HIP_ATOMIC_LOAD(ptr, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_WORKGROUP); + ret = HIP_ATOMIC_LOAD(ptr, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_WORKGROUP); #else asm volatile("ld.acquire.cta.s32 %0, [%1];" : "=r"(ret) : "l"(ptr)); #endif diff --git a/ep/setup.py b/ep/setup.py index 3413a89d6..ef2d0ae0a 100644 --- a/ep/setup.py +++ b/ep/setup.py @@ -61,6 +61,9 @@ cxx_flags.append("-DDISABLE_AGGRESSIVE_ATOMIC") 
nvcc_flags.append("-DDISABLE_AGGRESSIVE_ATOMIC") + cxx_flags.append("-DENABLE_FAST_DEBUG") + nvcc_flags.append("-DENABLE_FAST_DEBUG") + device_arch = os.getenv("TORCH_CUDA_ARCH_LIST", "gfx942") os.environ["PYTORCH_ROCM_ARCH"] = device_arch diff --git a/ep/src/internode.cu b/ep/src/internode.cu index 034b4e11d..43a864894 100644 --- a/ep/src/internode.cu +++ b/ep/src/internode.cu @@ -819,7 +819,7 @@ __global__ void __launch_bounds__( acquire_lock(rdma_send_channel_lock + lane_id); auto latest_tail = rdma_send_channel_tail[lane_id]; auto offset = rdma_tail_idx - latest_tail; - while (offset >= WARP_SIZE) { + while (offset >= 32) { release_lock(rdma_send_channel_lock + lane_id); acquire_lock(rdma_send_channel_lock + lane_id); latest_tail = rdma_send_channel_tail[lane_id]; @@ -830,8 +830,7 @@ __global__ void __launch_bounds__( // Add the bit and move the ones if possible auto window = rdma_send_channel_window[lane_id] | (1u << offset); if (offset == 0) { - auto num_empty_slots = - (~window) == 0 ? WARP_SIZE : __ffs(~window) - 1; + auto num_empty_slots = (~window) == 0 ? 32 : __ffs(~window) - 1; st_release_cta(rdma_send_channel_tail + lane_id, latest_tail + num_empty_slots); window >>= num_empty_slots; @@ -1114,9 +1113,10 @@ __global__ void __launch_bounds__( // Copy data #if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) - UNROLLED_WARP_COPY( - 5, lane_id, hidden_int4, reinterpret_cast(dst_shifted), - reinterpret_cast(shifted), ld_nc_global, st_na_global); + UNROLLED_WARP_COPY(5, lane_id, num_bytes_per_token / sizeof(int4), + reinterpret_cast(dst_shifted), + reinterpret_cast(shifted), ld_nc_global, + st_na_global); #else if (lane_id == 0) { tma_load_1d(tma_buffer, shifted, tma_mbarrier, num_bytes_per_token, @@ -1240,6 +1240,7 @@ __global__ void __launch_bounds__( } } num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset); + auto num_tokens_to_recv_original = num_tokens_to_recv; // Save for combine usage if (lane_id < kNumRDMARanks and not kCachedMode) diff --git a/ep/src/internode_ll.cu b/ep/src/internode_ll.cu index 281f1f1fc..c77e26bea 100644 --- a/ep/src/internode_ll.cu +++ b/ep/src/internode_ll.cu @@ -18,6 +18,12 @@ constexpr int kNumMaxWarpGroups = 16; constexpr int kNumMaxWarpGroups = 32; #endif +#ifndef UCCL_DEBUG_NAN +#define UCCL_DEBUG_NAN 1 +#endif +#define DBG_ROOT (blockIdx.x == 0 && threadIdx.x == 0) +#define DBG_WARP0 (blockIdx.x == 0 && (threadIdx.x / 32) == 0) + template __launch_bounds__(kNumThreads, 1) __global__ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0, @@ -84,15 +90,17 @@ __global__ __launch_bounds__(1024, 1) void dispatch( size_t const hidden_bytes = kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16)); size_t const hidden_int4 = hidden_bytes / sizeof(int4); + EP_DEVICE_ASSERT(hidden_int4 > 0); // Message package: hidden data, FP8 scales, index at source // NOTES: currently we have 3 reserved int fields for future use using vec_t = typename std::conditional::type; size_t const num_bytes_per_msg = - sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) + sizeof(int4) + (kUseFP8 ? 
(hidden_bytes + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16))); size_t const num_int4_per_msg = num_bytes_per_msg / sizeof(int4); EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0); + EP_DEVICE_ASSERT(num_int4_per_msg > 0); // Expert counts __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups]; @@ -127,6 +135,17 @@ __global__ __launch_bounds__(1024, 1) void dispatch( reinterpret_cast(rdma_x_src_idx) + sizeof(int4)); auto const rdma_x_scales = reinterpret_cast( reinterpret_cast(rdma_x_vec) + hidden_bytes); +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) + if constexpr (kUseFP8) { + // One warp cooperatively initializes the per-token scales buffer. + if (warp_id == 0) { + for (int s = lane_id; s < num_scales; s += WARP_SIZE) { + rdma_x_scales[s] = 1.0f; // inverse scale default + } + } + __syncwarp(); + } +#endif // Overlap top-k index read and source token index writes auto dst_expert_idx = @@ -166,8 +185,14 @@ __global__ __launch_bounds__(1024, 1) void dispatch( amax = warp_reduce_max<16>(amax); calculate_fp8_scales(amax, scale, scale_inv, round_scale); if (lane_id == 0 or lane_id == 16) -#endif rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv; +#endif + if (blockIdx.x == 0 && threadIdx.x == 0) { + float s_chk = reinterpret_cast( + rdma_x_scales)[i * kNumElemsPerRead / 128]; + printf("[dispatch] token=%d lane=%d i=%d scale_inv(sample)=%e\n", + token_idx, lane_id, i, (double)s_chk); + } // Cast into send buffer vec_t int2_value; @@ -179,8 +204,44 @@ __global__ __launch_bounds__(1024, 1) void dispatch( fp32_values[j + 1] * scale}; fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3); +#if UCCL_DEBUG_NAN + if (blockIdx.x == 0 && lane_id == 0 && i == 0 && j == 0) { + printf( + "[dispatch][FP8-PACK] token=%d i=%d j=%d a=%e b=%e " + "scale=%e inv=%e amax=%e\n", + token_idx, i, j, (double)fp32x2.x, (double)fp32x2.y, + (double)scale, (double)scale_inv, (double)amax); + unsigned short raw16 = + *reinterpret_cast(&fp8x2_values[0]); + printf("[dispatch][FP8-PACK] token=%d raw16(first)=0x%04hx\n", + token_idx, raw16); + } +#endif } +#if UCCL_DEBUG_NAN + if (blockIdx.x == 0 && lane_id == 0 && i == 0) { + unsigned long long raw8 = + *reinterpret_cast(&rdma_x_vec[0]); + printf( + "[dispatch][SEND] token=%d raw8(fp8 first 8B)=0x%016llx " + "scale=%e inv=%e amax=%e\n", + token_idx, raw8, (double)scale, (double)scale_inv, + (double)amax); + } +#endif rdma_x_vec[i] = int2_value; +#if UCCL_DEBUG_NAN + if (DBG_WARP0 && i == 0 && lane_id == 0) { + // dump first 8 FP8 bytes of the message payload we just wrote + unsigned long long raw8 = + *reinterpret_cast(&rdma_x_vec[0]); + printf( + "[dispatch][SEND] token=%d raw8(fp8 bytes)=0x%016llx scale=%e " + "inv=%e amax=%e\n", + token_idx, raw8, (double)scale, (double)scale_inv, + (double)amax); + } +#endif } else { // Reinterpret-cast is for C++14 compatibility rdma_x_vec[i] = *reinterpret_cast(&int4_value); @@ -356,7 +417,15 @@ LOW_LATENCY_DISPATCH_RECV: local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales; - +#if UCCL_DEBUG_NAN + if (DBG_ROOT) { + printf( + "[dispatch][RECV-SETUP] expert=%d src_rank=%d local_expert=%d " + "hidden_bytes=%zu num_scales=%d aligned=%d\n", + responsible_expert_idx, src_rank, local_expert_idx, + (size_t)hidden_bytes, num_scales, num_aligned_scales); + } +#endif // Shared between sub-warps in warp groups __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups]; @@ -425,6 +494,15 @@ 
LOW_LATENCY_DISPATCH_RECV: recv_range[src_rank] = pack2(num_recv_tokens, recv_token_begin_idx); // Add stats for diagnosis +#if UCCL_DEBUG_NAN + if (DBG_ROOT) { + printf( + "[dispatch][RECV-COUNT] src_rank=%d num_recv_tokens=%d begin=%d " + "(ipc=%d, ib=%d)\n", + src_rank, num_recv_tokens, recv_token_begin_idx, + num_recv_tokens_ipc, num_recv_tokens_internode); + } +#endif if (cumulative_local_expert_recv_stats != nullptr) atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx, num_recv_tokens); @@ -453,9 +531,36 @@ LOW_LATENCY_DISPATCH_RECV: reinterpret_cast(src_src_idx) + sizeof(int4)); auto const dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4; +#if UCCL_DEBUG_NAN + if (DBG_WARP0 && i == 0 && lane_id == 0) { + // peek first 16B (int4) of FP8 payload at source before copy + int4 peek_src = src_data[0]; + unsigned long long lo = + *reinterpret_cast(&peek_src); + unsigned long long hi = + *reinterpret_cast(((char*)&peek_src) + 8); + printf( + "[dispatch][RECV-BEFORE] token_off=%d src_raw16=0x%016llx " + "0x%016llx\n", + recv_token_begin_idx, lo, hi); + } +#endif UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global); - +#if UCCL_DEBUG_NAN + if (DBG_WARP0 && i == 0 && lane_id == 0) { + // peek first 16B (int4) of packed_recv_x after copy + int4 peek_dst = dst_data[0]; + unsigned long long lo = + *reinterpret_cast(&peek_dst); + unsigned long long hi = + *reinterpret_cast(((char*)&peek_dst) + 8); + printf( + "[dispatch][RECV-AFTER] token_off=%d dst_raw16=0x%016llx " + "0x%016llx\n", + recv_token_begin_idx, lo, hi); + } +#endif // Copy scales if constexpr (kUseFP8) { // Equivalent CuTe layout: @@ -476,6 +581,13 @@ LOW_LATENCY_DISPATCH_RECV: ld_nc_global(src_scales + lane_id)); recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale; +#if UCCL_DEBUG_NAN + if (DBG_WARP0 && token_idx == recv_token_begin_idx && pack_idx == 0 && + elem_idx == 0) { + float sc = (float)scale; + printf("[dispatch][RECV-SCALE] first_scale=%e\n", (double)sc); + } +#endif } if (lane_id + WARP_SIZE < num_scales) { auto const pack_idx = (lane_id + WARP_SIZE) / num_elems_per_pack; @@ -484,6 +596,14 @@ LOW_LATENCY_DISPATCH_RECV: ld_nc_global(src_scales + lane_id + WARP_SIZE)); recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale; +#if UCCL_DEBUG_NAN + if (DBG_WARP0 && token_idx == recv_token_begin_idx && pack_idx == 0 && + elem_idx == 0) { + float sc2 = (float)scale; + printf("[dispatch][RECV-SCALE-2] first_scale_lane+W=%e\n", + (double)sc2); + } +#endif } } } diff --git a/ep/src/proxy.cpp b/ep/src/proxy.cpp index 20c544e85..311728b09 100644 --- a/ep/src/proxy.cpp +++ b/ep/src/proxy.cpp @@ -840,6 +840,12 @@ void Proxy::post_gpu_commands_mixed( 0) { return; } + + printf( + "[post_gpu_commands_mixed] thread %d: Posting %zu RDMA writes, %zu " + "atomics, %zu barriers, %zu quiets\n", + cfg_.thread_idx, rdma_wrs.size(), atomic_wrs.size(), barrier_cmds.size(), + quiet_cmds.size()); // Handle regular RDMA writes if (!rdma_wrs.empty()) { post_rdma_async_batched(ctx_, cfg_.gpu_buffer, rdma_wrs.size(), rdma_wrs, diff --git a/ep/src/rdma.cpp b/ep/src/rdma.cpp index 3e588f600..25c603dfe 100644 --- a/ep/src/rdma.cpp +++ b/ep/src/rdma.cpp @@ -920,6 +920,15 @@ static void post_rdma_async_batched_normal_mode( .GetImmData(); wrs[j].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; wrs[j].imm_data = htonl(imm); + printf( + "Posting AtomicsImm with imm=0x%08x, atomic_offset: %d, " + "atomic_val: %d\n", + imm, cmd.atomic_offset, 
cmd.atomic_val); + + AtomicsImm aimm(imm); + assert(aimm.GetValue() == cmd.atomic_val); + assert(aimm.GetOff() == cmd.atomic_offset); + } else if (j + 1 == kgroup) { // Put WriteImm only on the tail WR uint32_t imm = @@ -929,6 +938,7 @@ static void post_rdma_async_batched_normal_mode( .GetImmData(); wrs[j].opcode = IBV_WR_RDMA_WRITE_WITH_IMM; wrs[j].imm_data = htonl(imm); + printf("Posting WriteImm with imm=0x%08x\n", imm); } else { wrs[j].opcode = IBV_WR_RDMA_WRITE; } @@ -950,6 +960,8 @@ static void post_rdma_async_batched_normal_mode( { auto [it, inserted] = S.wr_id_to_wr_ids.try_emplace( batch_tail_wr, std::move(ring_wrids)); + printf("pushed tail wr_id %lu into map (map=%p)\n", batch_tail_wr, + (void*)&S.wr_id_to_wr_ids); if (!inserted) { fprintf(stderr, "thread_idx: %d, Error: tail wr_id %lu already exists " @@ -1245,7 +1257,7 @@ void local_process_completions(ProxyCtx& S, } S.wr_id_to_wr_ids.erase(it); } else { - printf("Error: ACK for unknown wr_id %lu\n", wrid); + printf("Error: Atomic ACK for unknown wr_id %lu\n", wrid); std::abort(); } #endif @@ -1284,7 +1296,7 @@ void local_process_completions(ProxyCtx& S, } S.wr_id_to_wr_ids.erase(it); } else { - printf("Error: ACK for unknown wr_id %lu\n", wr_done); + printf("Error: Write ACK for unknown wr_id %lu\n", wr_done); std::abort(); } #endif @@ -1422,6 +1434,8 @@ void remote_process_completions_normal_mode( std::unordered_map> per_tag; per_tag.reserve(8); + printf("Remote thread %d: processing %d completions\n", idx, ne); + for (int i = 0; i < ne; ++i) { ibv_wc const& cqe = wc[i]; if (cqe.status != IBV_WC_SUCCESS) { @@ -1444,8 +1458,11 @@ void remote_process_completions_normal_mode( if (value == kMaxSendAtomicValue) value = kLargeAtomicValue; if (!aimm.IsReorderable()) { + printf("Applying non-reorderable atomic at index %zu: +%d\n", index, + value); addr32->fetch_add(value, std::memory_order_release); } else { + printf("Applying reorderable atomic at index %zu: +%d\n", index, value); struct SeqBuf { uint8_t expected = 0; // next seq expected uint16_t present_mask = 0; // bitmask of buffered seqs @@ -2138,6 +2155,20 @@ static void post_atomic_operations_normal_mode( } std::abort(); } + uint64_t const batch_tail_wr = group_wrids.back(); + { + auto [it, inserted] = S.wr_id_to_wr_ids.try_emplace( + batch_tail_wr, std::move(group_wrids)); + if (!inserted) { + fprintf(stderr, + "thread_idx: %d, Error: tail wr_id %lu already exists " + "(map=%p, " + "size=%zu, dst_rank=%d)\n", + thread_idx, batch_tail_wr, (void*)&S.wr_id_to_wr_ids, + S.wr_id_to_wr_ids.size(), dst_rank); + std::abort(); + } + } #endif } } diff --git a/ep/src/uccl_ep.cc b/ep/src/uccl_ep.cc index afaa789b3..f965afb9b 100644 --- a/ep/src/uccl_ep.cc +++ b/ep/src/uccl_ep.cc @@ -682,6 +682,8 @@ class Buffer { // Launch data dispatch // NOTES: the buffer size checks are moved into the `.cu` file + + printf("Before dispatch internode\n"); uccl::internode::dispatch( recv_x.data_ptr(), recv_x_scales_ptr, recv_topk_idx_ptr, recv_topk_weights_ptr, @@ -733,6 +735,11 @@ class Buffer { // Switch back compute stream if (allocate_on_comm_stream) at::cuda::setCurrentCUDAStream(compute_stream); + // synchronize + CUDA_CHECK(cudaStreamSynchronize(comm_stream)); + // print error + CUDA_CHECK(cudaGetLastError()); + // Return values return {recv_x, recv_x_scales,
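
For reference, the HIP-specific `calculate_fp8_scales` added in `ep/include/ep_utils.cuh` rounds `scale_inv` up to a power of two when `round_scale` is set (via `frexpf`/`ldexpf`) and falls back to `scale = scale_inv = 1.0f` on non-finite or non-positive inputs. Below is a minimal host-side sketch of that rounding branch, not part of the patch, assuming `kFinfoAmaxE4M3` is the E4M3 max finite value 448.0f (the constant's value is not shown in the diff). It checks that the rounded pair keeps `amax * scale` inside the E4M3 range and that `scale * scale_inv` is exactly 1:

    // Host-side sketch of the round_scale branch of the HIP calculate_fp8_scales.
    // Assumption: kFinfoAmaxE4M3 == 448.0f (E4M3 max finite value).
    #include <cassert>
    #include <cmath>
    #include <cstdio>

    int main() {
      const float kFinfoAmaxE4M3 = 448.0f;
      const float kFinfoAmaxInvE4M3 = 1.0f / kFinfoAmaxE4M3;
      const float amaxes[] = {1e-3f, 0.5f, 3.7f, 448.0f, 12345.0f};

      for (float amax : amaxes) {
        float t = amax * kFinfoAmaxInvE4M3;     // "ideal" scale_inv
        int e;
        std::frexp(t, &e);                      // t = m * 2^e, m in [0.5, 1)
        float scale_inv = std::ldexp(1.0f, e);  // round scale_inv up to 2^e
        float scale = std::ldexp(1.0f, -e);     // exact reciprocal 2^-e

        // amax * scale = kFinfoAmaxE4M3 * m, so it stays strictly in range,
        // and the two power-of-two factors multiply back to exactly 1.
        assert(amax * scale < kFinfoAmaxE4M3);
        assert(scale * scale_inv == 1.0f);
        std::printf("amax=%g scale=%g scale_inv=%g amax*scale=%g\n",
                    amax, scale, scale_inv, amax * scale);
      }
      return 0;
    }

Because both factors are exact powers of two, multiplying by `scale` and later by `scale_inv` adds no mantissa rounding of its own, which narrows the search for where NaNs enter the FP8 dispatch path that the added `[dispatch][FP8-PACK]` / `[dispatch][RECV-*]` prints are tracing.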