diff --git a/build.sh b/build.sh
index 17d4cf02d..4707ba25e 100755
--- a/build.sh
+++ b/build.sh
@@ -105,7 +105,8 @@ build_efa() {
 
   # EFA requires a custom NCCL.
   cd thirdparty/nccl-sg
-  make src.build -j$(nproc) NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80"
+  # make src.build -j$(nproc) NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80"
+  make src.build -j$(nproc) NVCC_GENCODE="-gencode=arch=compute_90,code=sm_90"
   cd ../..
 
   echo "[container] Copying EFA .so to uccl/lib/"
diff --git a/collective/efa/run_p5en.sh b/collective/efa/run_p5en.sh
index b4416ed85..aa04e460c 100755
--- a/collective/efa/run_p5en.sh
+++ b/collective/efa/run_p5en.sh
@@ -10,12 +10,14 @@ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
 # Disable NVLink.
 NV_LINK_DISABLE=0
 MULTI_GROUP=0
-NIC=10.1.0.0/16
+# NIC=10.1.0.0/16
+NIC=172.31.0.0/16
 
 # Processes/Ranks/GPUs per node.
 PROCS_PER_NODE=8
 TEST=${1:-srd}
-NUM_PROCS=${2:-32}
+# NUM_PROCS=${2:-32}
+NUM_PROCS=${2:-16}
 PROG_NAME=${3:-0}
 
 # all_gather_perf all_reduce_perf alltoall_perf broadcast_perf gather_perf
@@ -32,7 +34,8 @@ else
     exit 1
 fi
 
-CHANNELS=32
+# CHANNELS=32
+CHANNELS=16
 CHANNELS_NET_PEER=1
 
 # UCCL optimal parameters. Yang: for allreduce with nvlink, we need to use larger buffer to catch up with NCCL with larger buffers, and avoid outliers.
@@ -55,7 +58,7 @@ if [ "$TEST" = "srd" ]; then
       >"nccl_test_outputs/output_rank_$rank.log"
   done
 
-  LIBNCCL_PATH="${UCCL_HOME}/thirdparty/nccl/build/lib/libnccl.so"
+  LIBNCCL_PATH="${UCCL_HOME}/thirdparty/nccl-sg/build/lib/libnccl.so"
   PLUGIN_PATH="/opt/amazon/ofi-nccl/lib/x86_64-linux-gnu/libnccl-net.so"
 
   mpirun --bind-to none -np ${NUM_PROCS} -N ${PROCS_PER_NODE} --hostfile $NODEFILE --map-by ppr:8:node \
@@ -75,7 +78,9 @@ if [ "$TEST" = "srd" ]; then
     -x NCCL_NCHANNELS_PER_NET_PEER=${CHANNELS_NET_PEER} \
     -x NCCL_P2P_NET_CHUNKSIZE=${CHUNK_SIZE} \
     -x NCCL_BUFFSIZE=${BUFFSIZE} \
-    ${UCCL_HOME}/thirdparty/nccl-tests/build/${PROG_NAME} \
+    -x UCCL_EFA_DEVICES=rdmap110s0,rdmap112s0,rdmap135s0,rdmap137s0,rdmap160s0,rdmap162s0,rdmap85s0,rdmap87s0,rdmap111s0,rdmap113s0,rdmap136s0,rdmap138s0,rdmap161s0,rdmap163s0,rdmap86s0,rdmap88s0 \
+    -x UCCL_ENA_DEVICES=enp71s0,enp71s0,enp71s0,enp71s0,enp71s0,enp71s0,enp71s0,enp71s0,enp71s0,enp71s0,enp71s0,enp71s0,enp71s0,enp71s0,enp71s0,enp71s0 \
+    /usr/local/cuda-12.9/efa/test-cuda-12.9/${PROG_NAME} \
     -b 1K -e 1G -f 2 -c 1 -w 5 -n 10 -t 1 -g 1 \
     2>&1 | while read -r line; do
       if [[ "$line" =~ ^\[[0-9]+,([0-9]+)\](.+) ]]; then
@@ -129,7 +134,9 @@ elif [ "$TEST" = "ud" ]; then
     -x NCCL_TOPO_FILE=${UCCL_HOME}/collective/efa/p4d-24xl-topo.xml \
     -x NCCL_PXN_DISABLE=1 \
     -x UCCL_ENGINE_QUIET=1 \
-    ${UCCL_HOME}/thirdparty/nccl-tests/build/${PROG_NAME} \
+    -x UCCL_EFA_DEVICES=rdmap110s0,rdmap112s0,rdmap135s0,rdmap137s0,rdmap160s0,rdmap162s0,rdmap85s0,rdmap87s0,rdmap111s0,rdmap113s0,rdmap136s0,rdmap138s0,rdmap161s0,rdmap163s0,rdmap86s0,rdmap88s0 \
+    -x UCCL_ENA_DEVICES=enp71s0,enp71s0,enp71s0,enp71s0,enp71s0,enp71s0,enp71s0,enp71s0,enp71s0,enp71s0,enp71s0,enp71s0,enp71s0,enp71s0,enp71s0,enp71s0 \
+    /usr/local/cuda-12.9/efa/test-cuda-12.9/${PROG_NAME} \
     -b 1K -e 1G -f 2 -c 1 -w 5 -n 10 -t 1 -g 1 \
     2>&1 | while read -r line; do
       if [[ "$line" =~ ^\[[0-9]+,([0-9]+)\](.+) ]]; then
diff --git a/docker/Dockerfile.cuda b/docker/Dockerfile.cuda
index 6ff363942..29b6d8b0d 100644
--- a/docker/Dockerfile.cuda
+++ b/docker/Dockerfile.cuda
@@ -1,6 +1,9 @@
-ARG BASE_IMAGE=nvidia/cuda:12.3.2-devel-ubuntu22.04
+# ARG BASE_IMAGE=nvidia/cuda:12.3.2-devel-ubuntu22.04
+# FROM ${BASE_IMAGE}
+# ARG PY_VER=3.13
+ARG BASE_IMAGE=nvidia/cuda:12.9.0-devel-ubuntu22.04
 FROM ${BASE_IMAGE}
-ARG PY_VER=3.13
+ARG PY_VER=3.12
 
 # Non-interactive APT
 ENV DEBIAN_FRONTEND=noninteractive
diff --git a/docker/Dockerfile.efa b/docker/Dockerfile.efa
index 09abad4d2..afffa3a68 100644
--- a/docker/Dockerfile.efa
+++ b/docker/Dockerfile.efa
@@ -1,6 +1,9 @@
-ARG BASE_IMAGE=nvidia/cuda:12.3.2-devel-ubuntu22.04
+# ARG BASE_IMAGE=nvidia/cuda:12.3.2-devel-ubuntu22.04
+# FROM ${BASE_IMAGE}
+# ARG PY_VER=3.13
+ARG BASE_IMAGE=nvidia/cuda:12.9.0-devel-ubuntu22.04
 FROM ${BASE_IMAGE}
-ARG PY_VER=3.13
+ARG PY_VER=3.12
 
 # Non-interactive APT
 ENV DEBIAN_FRONTEND=noninteractive
@@ -44,7 +47,8 @@ RUN ln -s /usr/lib/x86_64-linux-gnu/libevent_core-2.1.so.7 /usr/lib/x86_64-linux
     ln -s /usr/lib/x86_64-linux-gnu/libhwloc.so.15 /usr/lib/x86_64-linux-gnu/libhwloc15.so
 
 # Install EFA installer (without kernel driver)
-ARG EFA_VER=1.42.0
+# ARG EFA_VER=1.42.0
+ARG EFA_VER=1.43.2
 RUN curl -O https://efa-installer.amazonaws.com/aws-efa-installer-${EFA_VER}.tar.gz && \
     tar -xf aws-efa-installer-${EFA_VER}.tar.gz && \
     cd aws-efa-installer && \
diff --git a/ep/deep_ep_wrapper/README.md b/ep/deep_ep_wrapper/README.md
new file mode 100644
index 000000000..1d01bee56
--- /dev/null
+++ b/ep/deep_ep_wrapper/README.md
@@ -0,0 +1,8 @@
+## DeepEP Wrapper of UCCL-EP
+
+```
+cp ../bench/buffer.py ./ # Change `utils` to `deep_ep.utils`
+cp ../bench/utils.py ./
+
+python setup.py install
+```
\ No newline at end of file
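For orientation, a minimal usage sketch of the wrapper installed by this README, assuming the usual rank/device environment variables are set by `torchrun` (or an equivalent launcher); the buffer sizes and group handling below are illustrative placeholders modeled on `deep_ep/test_internode.py`, not fixed values required by the wrapper:

```
# Illustrative sketch only: the wrapper re-exports a DeepEP-style API as `deep_ep`.
import os
import torch
import torch.distributed as dist
from deep_ep import Buffer

local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl", device_id=torch.device(f"cuda:{local_rank}"))
group = dist.new_group(list(range(dist.get_world_size())))

# Placeholder sizes; Buffer.get_low_latency_rdma_size_hint() gives an RDMA sizing hint.
buffer = Buffer(group, num_nvl_bytes=int(2e9), num_rdma_bytes=int(1e9),
                low_latency_mode=False, explicitly_destroy=True)
# ... dispatch/combine calls as exercised in deep_ep/test_internode.py ...
buffer.destroy()
dist.destroy_process_group()
```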
diff --git a/ep/deep_ep_wrapper/deep_ep/__init__.py b/ep/deep_ep_wrapper/deep_ep/__init__.py
new file mode 100644
index 000000000..e44d6ed3a
--- /dev/null
+++ b/ep/deep_ep_wrapper/deep_ep/__init__.py
@@ -0,0 +1,15 @@
+from uccl.ep import Config, EventHandle
+
+from .utils import EventOverlap, check_nvlink_connections, initialize_uccl, destroy_uccl
+from .buffer import Buffer
+import torch.distributed as dist
+
+__all__ = [
+    'Config',
+    'EventHandle',
+    'Buffer',
+    'EventOverlap',
+    'check_nvlink_connections',
+    'initialize_uccl',
+    'destroy_uccl',
+]
diff --git a/ep/deep_ep_wrapper/deep_ep/buffer.py b/ep/deep_ep_wrapper/deep_ep/buffer.py
new file mode 100644
index 000000000..f0843fb63
--- /dev/null
+++ b/ep/deep_ep_wrapper/deep_ep/buffer.py
@@ -0,0 +1,1065 @@
+import os
+import torch
+import torch.distributed as dist
+from typing import Callable, Tuple, Optional, Union, List
+
+try:
+    from uccl import ep
+except ImportError as exc:
+    import sys
+
+    sys.stderr.write("Failed to import uccl.ep\n")
+    raise
+
+from uccl.ep import EventHandle, Config
+from deep_ep.utils import EventOverlap, check_nvlink_connections, initialize_uccl, destroy_uccl
+
+
+class Buffer:
+    """
+    The core expert-parallel (EP) communication buffers for Mixture of Experts (MoE) model, which supports:
+    - high-throughput intranode all-to-all (dispatch and combine, using NVLink)
+    - high-throughput internode all-to-all (dispatch and combine, using RDMA and NVLink)
+    - low-latency all-to-all (dispatch and combine, using RDMA)
+
+    Attributes:
+        num_sms: the SMs used in high-throughput kernels.
+        rank: the local rank number.
+        group_size: the number of ranks in the group.
+        group: the communication group.
+        num_nvl_bytes: the buffer size for intranode NVLink communication.
+        num_rdma_bytes: the buffer size for internode (also for intranode with low-latency mode) RDMA communication.
+        runtime: the C++ runtime.
+    """
+
+    # TODO(MaoZiming): Reduce SMs. UCCL Proxy should reduce the usage of SMs.
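+    # Class-level default: the number of SMs used by the high-throughput kernels;
+    # it can be overridden at runtime via `Buffer.set_num_sms()` defined below.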
+ num_sms: int = 20 + + def __init__( + self, + group: dist.ProcessGroup, + num_nvl_bytes: int = 0, + num_rdma_bytes: int = 0, + low_latency_mode: bool = False, + num_qps_per_rank: int = 24, + allow_nvlink_for_low_latency_mode: bool = True, + allow_mnnvl: bool = False, + explicitly_destroy: bool = False, + ) -> None: + """ + Initialize the communication buffer. + + Arguments: + group: the communication group. + num_nvl_bytes: the buffer size for intranode NVLink communication. + num_rdma_bytes: the buffer size for internode (also for intranode with low-latency mode) RDMA communication. + low_latency_mode: whether to enable low-latency mode. + num_qps_per_rank: the number of QPs for RDMA, the low-latency mode requires that this number equals + to the number of local experts. + allow_nvlink_for_low_latency_mode: whether allow NVLink traffic for low-latency mode, you should notice + this is somehow incompatible with the hook-based overlapping. + Warning: PCIe connections may lead to errors due to memory ordering issues, + please make sure all connections are via NVLink. + allow_mnnvl: whether to allow MNNVL + explicitly_destroy: If this flag is set to True, you need to explicitly call `destroy()` to release resources; + otherwise, the resources will be released by the destructor. + Note: Releasing resources in the destructor may cause Python's exception handling process to hang. + """ + device_index = torch.cuda.current_device() + self.scratch = torch.zeros( + num_rdma_bytes, dtype=torch.uint8, device=f"cuda:{device_index}" + ) + rdma_buffer_ptr = self.scratch.data_ptr() + self.proxies, self.workers = initialize_uccl( + rdma_buffer_ptr, + num_rdma_bytes, + group.rank(), + dist.get_world_size(group), + group, + use_normal_mode=not low_latency_mode + ) + check_nvlink_connections(group) + + # Initialize the CPP runtime + self.rank = group.rank() + self.group_size = group.size() + self.group = group + self.num_nvl_bytes = num_nvl_bytes + self.num_rdma_bytes = num_rdma_bytes + self.low_latency_mode = low_latency_mode + self.explicitly_destroy = explicitly_destroy + + if "LOCAL_WORLD_SIZE" in os.environ: + local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + else: + local_world_size = torch.cuda.device_count() + + print('rank:', self.rank, 'group_size:', self.group_size, 'group:', self.group, 'num_nvl_bytes:', self.num_nvl_bytes, 'num_rdma_bytes:', self.num_rdma_bytes, 'low_latency_mode:', self.low_latency_mode, 'explicitly_destroy:', self.explicitly_destroy, 'device_index:', device_index, 'proxies:', self.proxies, 'workers:', self.workers, 'local_world_size:', local_world_size, 'rdma_buffer_ptr:', rdma_buffer_ptr) + + self.runtime = ep.Buffer( + self.rank, + self.group_size, + num_nvl_bytes, + num_rdma_bytes, + low_latency_mode, + explicitly_destroy, + local_world_size, # int(os.environ.get("LOCAL_WORLD_SIZE", -1)), + ) + if num_rdma_bytes: + self.runtime.set_rdma_buffer_raw(rdma_buffer_ptr) + + # Synchronize device IDs + device_ids = [ + None, + ] * self.group_size + local_device_id = self.runtime.get_local_device_id() + # print("Before all_gather_object device_ids", local_device_id, flush=True) + dist.all_gather_object(device_ids, local_device_id, group) + # Synchronize IPC handles + ipc_handles = [ + None, + ] * self.group_size + local_ipc_handle = self.runtime.get_local_ipc_handle() + # print("Before all_gather_object ipc_handles", local_ipc_handle, flush=True) + dist.all_gather_object(ipc_handles, local_ipc_handle, group) + + rdma_ipc_handles = [None] * self.group_size + 
local_rdma_ipc_handle = ( + self.runtime.get_local_rdma_ipc_handle() + if self.num_rdma_bytes > 0 + else None + ) + dist.all_gather_object(rdma_ipc_handles, local_rdma_ipc_handle, group) + root_unique_id = None + # Make CPP runtime available + self.runtime.sync( + device_ids, + ipc_handles, + root_unique_id, + rdma_ipc_handles, + ) + assert self.runtime.is_available() + self.connect_atomic_buffer(self.proxies[0]) + + for proxy in self.proxies: + proxy.set_atomic_buffer_ptr(self.proxies[0].get_atomic_buffer_ptr()) + + def reset_rdma_buffer(self): + """ + Reset the RDMA buffer, this is useful when you want to reuse the RDMA buffer for another run. + + """ + self.runtime.reset_rdma_buffer() + + def connect_atomic_buffer(self, proxy: "ep.UcclProxy"): + ep.connect_atomic_buffer(proxy, self.runtime) + + def destroy(self): + """ + Destroy the cpp runtime and release resources. + + """ + + assert self.explicitly_destroy, "`explicitly_destroy` flag must be set" + + self.runtime.destroy() + self.runtime = None + destroy_uccl(self.proxies, self.workers) + + @staticmethod + def is_sm90_compiled(): + return ep.is_sm90_compiled() + + @staticmethod + def set_num_sms(new_num_sms: int) -> None: + """ + Set the number of SMs to use in high-throughput kernels. + + Arguments: + new_num_sms: the new number to be set. + """ + + assert new_num_sms % 2 == 0, "The SM count must be even" + Buffer.num_sms = new_num_sms + + @staticmethod + def capture() -> EventOverlap: + """ + Capture a CUDA event on the current stream, i.e. `torch.cuda.current_stream()`. + + Returns: + event: the captured event. + """ + return EventOverlap(EventHandle()) + + # noinspection PyTypeChecker + def low_latency_dispatch( + self, + x: torch.Tensor, + topk_idx: torch.Tensor, + num_max_dispatch_tokens_per_rank: int, + num_experts: int, + cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None, + dispatch_wait_recv_cost_stats: Optional[torch.Tensor] = None, + use_fp8: bool = True, + round_scale: bool = False, + use_ue8m0: bool = False, + async_finish: bool = False, + return_recv_hook: bool = False, + ) -> Tuple[ + Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable + ]: + """ + A low-latency implementation for dispatching with IBGDA. + This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA + (specifically, IBGDA must be enabled). + Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2 + low-latency kernels' result tensors at a single moment. + + Arguments: + x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are + supported. The number of tokens to be dispatched must be less than `num_max_dispatch_tokens_per_rank`. + topk_idx: `torch.Tensor` with `torch.int64`, shaped as `[num_tokens, num_topk]`, only several top-k shapes + are supported. `-1` indices (not selecting any expert) are supported. + num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value. + num_experts: the number of all experts. + cumulative_local_expert_recv_stats: a cumulative expert count tensor for statistics, which should have shape + `[num_local_experts]` and be typed as `torch.int`. This is useful for online service EP load balance + monitoring. 
+ dispatch_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics, + which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`. + This is useful for detecting and pre-cisely localizing slow anomalies. + use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors. + round_scale: whether round the scaling factors into power of 2. + use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`). + async_finish: the current stream will not wait for the communication kernels to be finished if set. + return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues, + but **without actually receiving the data**. You must call the received hook to make sure the data's arrival. + If you do not set this flag, the kernel will ensure the data's arrival. + + Returns: + recv_x: a tensor or tuple with received tokens for each expert. + With `use_fp8=True`: the first element is a `torch.Tensor` shaped as + `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`. + The second tensor is the corresponding scales for the first element with shape + `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`, + if `use_ue8m0=False`. With `use_ue8m0=True`, the second one is packed and shaped as + `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`. + Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility. + With `use_fp8=False`, the result would be a tensor shaped as + `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`. + Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are, + as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced). + recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each + expert receives. As mentioned before, not all tokens are valid in `recv_x`. + handle: the communication handle to be used in the `low_latency_combine` function. + event: the event after executing the kernel (valid only if `async_finish` is set). + hook: the receiving hook function (valid only if `return_recv_hook` is set). 
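+
+        Example (illustrative sketch; `buf`, `x`, and `topk_idx` are placeholders shaped as described above):
+            recv_x, recv_count, handle, event, hook = buf.low_latency_dispatch(
+                x, topk_idx, num_max_dispatch_tokens_per_rank=128, num_experts=256)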
+ """ + for proxy in self.proxies: + proxy.calculate_and_set_dispatch_recv_data_offset( + num_tokens=x.shape[0], + hidden=x.shape[1], + num_experts=num_experts, + ) + ( + packed_recv_x, + packed_recv_x_scales, + packed_recv_count, + packed_recv_src_info, + packed_recv_layout_range, + event, + hook, + ) = self.runtime.low_latency_dispatch( + x, + topk_idx, + cumulative_local_expert_recv_stats, + dispatch_wait_recv_cost_stats, + num_max_dispatch_tokens_per_rank, + num_experts, + use_fp8, + round_scale, + use_ue8m0, + async_finish, + return_recv_hook, + ) + handle = ( + packed_recv_src_info, + packed_recv_layout_range, + num_max_dispatch_tokens_per_rank, + x.size(1), + num_experts, + ) + tensors_to_record = ( + x, + topk_idx, + packed_recv_x, + packed_recv_x_scales, + packed_recv_count, + packed_recv_src_info, + packed_recv_layout_range, + cumulative_local_expert_recv_stats, + ) + return ( + (packed_recv_x, packed_recv_x_scales) if use_fp8 else packed_recv_x, + packed_recv_count, + handle, + EventOverlap(event, tensors_to_record if async_finish else None), + hook, + ) + + # noinspection PyTypeChecker + def low_latency_combine( + self, + x: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + handle: tuple, + use_logfmt: bool = False, + zero_copy: bool = False, + async_finish: bool = False, + return_recv_hook: bool = False, + out: Optional[torch.Tensor] = None, + combine_wait_recv_cost_stats: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, EventOverlap, Callable]: + """ + A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA. + This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA + (specifically, IBGDA must be enabled). + Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2 + low-latency kernels' result tensors at a single moment. + + Arguments: + x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`, + the local calculated tokens to be sent to this original rank and reduced. + topk_idx: `[num_combined_tokens, num_topk]` with `torch.int64`, the expert indices selected by the dispatched + tokens. `-1` indices (not selecting any expert) are supported. Note that, `num_combined_tokens` equals + to the number of dispatched tokens. + topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched + tokens. The received tokens will be reduced with the weights in this tensor. + handle: the communication handle given by the `dispatch` function. + use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits). + zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative + with `get_next_low_latency_combine_buffer`. + async_finish: the current stream will not wait for the communication kernels to be finished if set. + return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues, + but **without actually receiving the data**. You must call the received hook to make sure the data's arrival. + If you do not set this flag, the kernel will ensure the data's arrival. + out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly. 
+ combine_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics, + which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`. + This is useful for detecting and pre-cisely localizing slow anomalies. + + Returns: + combined_x: the reduced token tensor, with shape `[num_combined_tokens, hidden]` and type `torch.bfloat16`. + event: the event after executing the kernel (valid only if `async_finish` is set). + hook: the receiving hook function (valid only if `return_recv_hook` is set). + """ + ( + src_info, + layout_range, + num_max_dispatch_tokens_per_rank, + hidden, + num_experts, + ) = handle + combined_x, event, hook = self.runtime.low_latency_combine( + x, + topk_idx, + topk_weights, + src_info, + layout_range, + combine_wait_recv_cost_stats, + num_max_dispatch_tokens_per_rank, + num_experts, + use_logfmt, + zero_copy, + async_finish, + return_recv_hook, + out, + ) + tensors_to_record = ( + x, + topk_idx, + topk_weights, + src_info, + layout_range, + combined_x, + ) + return ( + combined_x, + EventOverlap(event, tensors_to_record if async_finish else None), + hook, + ) + + def get_next_low_latency_combine_buffer(self, handle: object): + """ + Get the raw registered RDMA buffer tensor for next low-latency combine, so that the next combine kernel can skip the copying. + + Arguments: + handle: the communication handle given by the `dispatch` function. + + Returns: + buffer: the raw RDMA low-latency buffer as a BF16 PyTorch tensor with shape + `[num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden]`, you should fill this buffer + by yourself. + """ + ( + src_info, + layout_range, + num_max_dispatch_tokens_per_rank, + hidden, + num_experts, + ) = handle + return self.runtime.get_next_low_latency_combine_buffer( + num_max_dispatch_tokens_per_rank, hidden, num_experts + ) + + @staticmethod + def get_low_latency_rdma_size_hint( + num_max_dispatch_tokens_per_rank: int, + hidden: int, + num_ranks: int, + num_experts: int, + ) -> int: + """ + Get a minimum size requirement for the RDMA buffer. The size calculation will be done with BF16. + + Arguments: + num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value. + hidden: the hidden dimension of each token. + num_ranks: the number of EP group ranks. + num_experts: the number of all experts. + + Returns: + size: the RDMA buffer size recommended. + """ + return ep.get_low_latency_rdma_size_hint( + num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts + ) + + def get_comm_stream(self) -> torch.Stream: + """ + Get the communication stream. + + Returns: + stream: the communication stream. + """ + ts: torch.Stream = self.runtime.get_comm_stream() + return torch.cuda.Stream( + stream_id=ts.stream_id, + device_index=ts.device_index, + device_type=ts.device_type, + ) + + def get_local_buffer_tensor( + self, + dtype: torch.dtype, + size: Optional[torch.Size] = None, + offset: int = 0, + use_rdma_buffer: bool = False, + ) -> torch.Tensor: + """ + Get the raw buffer (slice supported) as a PyTorch tensor. + + Argument: + dtype: the data type (PyTorch `dtype`) for the tensor. + size: the slice size (by elements) to get from the buffer. + offset: the offset of the beginning element. + use_rdma_buffer: whether to return the RDMA buffer. 
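+
+        Returns:
+            tensor: the raw buffer viewed as a tensor of `dtype`; if `size` is given, a slice reshaped to `size`.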
+ """ + tensor = self.runtime.get_local_buffer_tensor(dtype, offset, use_rdma_buffer) + if size is None: + return tensor + + assert tensor.numel() >= size.numel() + return tensor[: size.numel()].view(size) + + @staticmethod + def _unpack_bias(bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]): + bias_0, bias_1 = None, None + if isinstance(bias, torch.Tensor): + bias_0 = bias + elif isinstance(bias, tuple): + assert len(bias) == 2 + bias_0, bias_1 = bias + return bias_0, bias_1 + + @staticmethod + def get_dispatch_config(num_ranks: int) -> Config: + """ + Get a recommended dispatch config. + + Argument: + num_ranks: the number of ranks. + + Returns: + config: the recommended config. + """ + + # TODO: automatically tune + config_map = { + 2: Config(Buffer.num_sms, 24, 256, 6, 128), + 4: Config(Buffer.num_sms, 6, 256, 6, 128), + 8: Config(Buffer.num_sms, 6, 256, 6, 128), + 16: Config(Buffer.num_sms, 36, 288, 20, 128), + 24: Config(Buffer.num_sms, 8, 288, 32, 128), + 32: Config(Buffer.num_sms, 32, 288, 32, 128), + 64: Config(Buffer.num_sms, 20, 288, 28, 128), + 128: Config(Buffer.num_sms, 20, 560, 32, 128), + 144: Config(Buffer.num_sms, 32, 720, 12, 128), + 160: Config(Buffer.num_sms, 28, 720, 12, 128), + } + assert num_ranks in config_map, f"Unsupported number of EP ranks: {num_ranks}" + return config_map[num_ranks] + + @staticmethod + def get_combine_config(num_ranks: int) -> Config: + """ + Get a recommended combine config. + + Argument: + num_ranks: the number of ranks. + + Returns: + config: the recommended config. + """ + + # TODO: automatically tune + config_map = { + 2: Config(Buffer.num_sms, 10, 256, 6, 128), + 4: Config(Buffer.num_sms, 9, 256, 6, 128), + 8: Config(Buffer.num_sms, 4, 256, 6, 128), + 16: Config(Buffer.num_sms, 4, 288, 12, 128), + 24: Config(Buffer.num_sms, 1, 288, 8, 128), + 32: Config(Buffer.num_sms, 1, 288, 8, 128), + 64: Config(Buffer.num_sms, 1, 288, 20, 128), + 128: Config(Buffer.num_sms, 1, 560, 12, 128), + 144: Config(Buffer.num_sms, 2, 720, 8, 128), + 160: Config(Buffer.num_sms, 2, 720, 8, 128), + } + assert num_ranks in config_map, f"Unsupported number of EP ranks: {num_ranks}" + return config_map[num_ranks] + + # noinspection PyTypeChecker + def get_dispatch_layout( + self, + topk_idx: torch.Tensor, + num_experts: int, + previous_event: Optional[EventOverlap] = None, + async_finish: bool = False, + allocate_on_comm_stream: bool = False, + ) -> Tuple[ + torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, EventOverlap + ]: + """ + Calculate the layout required for later communication. + + Arguments: + topk_idx: `[num_tokens, num_topk]`, dtype must be `torch.int64`, the expert indices selected by each token, + `-1` means no selections. + num_experts: the number of experts. + previous_event: the event to wait before actually executing the kernel. + async_finish: the current stream will not wait for the communication kernels to be finished if set. + allocate_on_comm_stream: control whether all the allocated tensors' ownership to be on the communication stream. + + Returns: + num_tokens_per_rank: `[num_ranks]` with `torch.int`, the number of tokens to be sent to each rank. + num_tokens_per_rdma_rank: `[num_rdma_ranks]` with `torch.int`, the number of tokens to be sent to each RDMA + rank (with the same GPU index), return `None` for intranode settings. + num_tokens_per_expert: `[num_experts]` with `torch.int`, the number of tokens to be sent to each expert. 
+ is_token_in_rank: `[num_tokens, num_ranks]` with `torch.bool`, whether a token be sent to a rank. + event: the event after executing the kernel (valid only if `async_finish` is set). + """ + ( + num_tokens_per_rank, + num_tokens_per_rdma_rank, + num_tokens_per_expert, + is_token_in_rank, + event, + ) = self.runtime.get_dispatch_layout( + topk_idx, + num_experts, + getattr(previous_event, "event", None), + async_finish, + allocate_on_comm_stream, + ) + return ( + num_tokens_per_rank, + num_tokens_per_rdma_rank, + num_tokens_per_expert, + is_token_in_rank, + EventOverlap(event), + ) + + # noinspection PyTypeChecker + def dispatch( + self, + x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + handle: Optional[Tuple] = None, + num_tokens_per_rank: Optional[torch.Tensor] = None, + num_tokens_per_rdma_rank: Optional[torch.Tensor] = None, + is_token_in_rank: Optional[torch.Tensor] = None, + num_tokens_per_expert: Optional[torch.Tensor] = None, + topk_idx: Optional[torch.Tensor] = None, + topk_weights: Optional[torch.Tensor] = None, + expert_alignment: int = 1, + num_worst_tokens: int = 0, + config: Optional[Config] = None, + previous_event: Optional[EventOverlap] = None, + async_finish: bool = False, + allocate_on_comm_stream: bool = False, + ) -> Tuple[ + Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + List[int], + Tuple, + EventOverlap, + ]: + """ + Dispatch tokens to different ranks, both intranode and internode settings are supported. + Intranode kernels require all the ranks should be visible via NVLink. + Internode kernels require the ranks in a node should be visible via NVLink, while the ranks with the same GPU + index should be visible via RDMA. + + Arguments: + x: `torch.Tensor` or tuple of `torch.Tensor`, for the first type, the shape must be `[num_tokens, hidden]`, + and type must be `torch.bfloat16`; for the second type, the first element of the tuple must be shaped as + `[num_tokens, hidden]` with type `torch.float8_e4m3fn`, the second must be `[num_tokens, hidden // 128]` + (requiring divisible) with type `torch.float`. + handle: an optional communication handle, if set, the CPU will reuse the layout information to save some time. + num_tokens_per_rank: `[num_ranks]` with `torch.int`, the number of tokens to be sent to each rank. + num_tokens_per_rdma_rank: `[num_rdma_ranks]` with `torch.int`, the number of tokens to be sent to each RDMA + rank (with the same GPU index), return `None` for intranode settings. + is_token_in_rank: `[num_tokens, num_ranks]` with `torch.bool`, whether a token be sent to a rank. + num_tokens_per_expert: `[num_experts]` with `torch.int`, the number of tokens to be sent to each expert. + topk_idx: `[num_tokens, num_topk]` with `torch.int64`, the expert indices selected by each token, + `-1` means no selections. + topk_weights: `[num_tokens, num_topk]` with `torch.float`, the expert weights of each token to dispatch. + expert_alignment: align the number of tokens received by each local expert to this variable. + num_worst_tokens: the worst number of tokens to receive, if specified, there will be no CPU sync, and it + will be CUDA-graph compatible. Please also notice that this flag is for intranode only. + config: the performance tuning config. + previous_event: the event to wait before actually executing the kernel. + async_finish: the current stream will not wait for the communication kernels to be finished if set. 
+ allocate_on_comm_stream: control whether all the allocated tensors' ownership to be on the communication stream. + + Returns: + recv_x: received tokens, the same type and tuple as the input `x`, but the number of tokens equals to the + received token count. + recv_topk_idx: received expert indices. + recv_topk_weights: received expert weights. + num_recv_tokens_per_expert_list: Python list shaped `[num_local_experts]`, the received token count by + each local expert, aligned to the input `expert_alignment`. If `num_worst_tokens` is specified, the list + will be empty. + handle: the returned communication handle. + event: the event after executing the kernel (valid only if `async_finish` is set). + """ + # Default config + config = self.get_dispatch_config(self.group_size) if config is None else config + + # Internode + if self.runtime.get_num_rdma_ranks() > 1: + assert ( + num_worst_tokens == 0 + ), "Internode dispatch does not support `num_worst_tokens > 0`" + return self.internode_dispatch( + x, + handle, + num_tokens_per_rank, + num_tokens_per_rdma_rank, + is_token_in_rank, + num_tokens_per_expert, + topk_idx, + topk_weights, + expert_alignment, + config, + previous_event, + async_finish, + allocate_on_comm_stream, + ) + + # Launch the kernel with cached or non-cached mode + x, x_scales = x if isinstance(x, tuple) else (x, None) + if handle is not None: + assert topk_idx is None and topk_weights is None + ( + rank_prefix_matrix, + channel_prefix_matrix, + recv_channel_prefix_matrix, + recv_src_idx, + is_token_in_rank, + send_head, + ) = handle + num_recv_tokens = recv_src_idx.size(0) + recv_x, recv_x_scales, _, _, _, _, _, _, _, _, event = ( + self.runtime.intranode_dispatch( + x, + x_scales, + None, + None, + None, + is_token_in_rank, + None, + num_recv_tokens, + rank_prefix_matrix, + channel_prefix_matrix, + expert_alignment, + num_worst_tokens, + config, + getattr(previous_event, "event", None), + async_finish, + allocate_on_comm_stream, + ) + ) + return ( + (recv_x, recv_x_scales) if x_scales is not None else recv_x, + None, + None, + None, + None, + EventOverlap(event), + ) + else: + assert ( + num_tokens_per_rank is not None + and is_token_in_rank is not None + and num_tokens_per_expert is not None + ) + ( + recv_x, + recv_x_scales, + recv_topk_idx, + recv_topk_weights, + num_recv_tokens_per_expert_list, + rank_prefix_matrix, + channel_prefix_matrix, + recv_channel_prefix_matrix, + recv_src_idx, + send_head, + event, + ) = self.runtime.intranode_dispatch( + x, + x_scales, + topk_idx, + topk_weights, + num_tokens_per_rank, + is_token_in_rank, + num_tokens_per_expert, + 0, + None, + None, + expert_alignment, + num_worst_tokens, + config, + getattr(previous_event, "event", None), + async_finish, + allocate_on_comm_stream, + ) + handle = ( + rank_prefix_matrix, + channel_prefix_matrix, + recv_channel_prefix_matrix, + recv_src_idx, + is_token_in_rank, + send_head, + ) + return ( + (recv_x, recv_x_scales) if x_scales is not None else recv_x, + recv_topk_idx, + recv_topk_weights, + num_recv_tokens_per_expert_list, + handle, + EventOverlap(event), + ) + + # noinspection PyTypeChecker + def combine( + self, + x: torch.Tensor, + handle: Tuple, + topk_weights: Optional[torch.Tensor] = None, + bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None, + config: Optional[Config] = None, + previous_event: Optional[EventOverlap] = None, + async_finish: bool = False, + allocate_on_comm_stream: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], EventOverlap]: + """ + 
Combine (reduce) tokens (addition **without** weights) from different ranks, both intranode and internode + settings are supported. + Intranode kernels require all the ranks should be visible via NVLink. + Internode kernels require the ranks in a node should be visible via NVLink, while the ranks with the same GPU + index should be visible via RDMA. + + Arguments: + x: `[num_tokens, hidden]` with `torch.bfloat16`, the tokens to send for reducing to its original ranks. + handle: a must-set communication handle, you can obtain this from the dispatch function. + topk_weights: `[num_tokens, num_topk]` with `torch.float`, the tokens' top-k weights for reducing to its original ranks. + config: the performance tuning config. + previous_event: the event to wait before actually executing the kernel. + async_finish: the current stream will not wait for the communication kernels to be finished if set. + allocate_on_comm_stream: control whether all the allocated tensors' ownership to be on the communication stream. + + Returns: + recv_x: the reduced token from its dispatched ranks. + recv_topk_weights: the reduced top-k weights from its dispatch ranks. + event: the event after executing the kernel (valid only if `async_finish` is set). + """ + # Default config + config = self.get_combine_config(self.group_size) if config is None else config + + # Internode + if self.runtime.get_num_rdma_ranks() > 1: + return self.internode_combine( + x, + handle, + topk_weights, + bias, + config, + previous_event, + async_finish, + allocate_on_comm_stream, + ) + + # NOTES: the second `_` is for the sending side, so we should use the third one + ( + rank_prefix_matrix, + _, + channel_prefix_matrix, + src_idx, + is_recv_token_in_rank, + send_head, + ) = handle + bias_0, bias_1 = Buffer._unpack_bias(bias) + + # Launch the kernel + recv_x, recv_topk_weights, event = self.runtime.intranode_combine( + x, + topk_weights, + bias_0, + bias_1, + src_idx, + rank_prefix_matrix, + channel_prefix_matrix, + send_head, + config, + getattr(previous_event, "event", None), + async_finish, + allocate_on_comm_stream, + ) + return recv_x, recv_topk_weights, EventOverlap(event) + + # noinspection PyTypeChecker + def internode_dispatch( + self, + x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + handle: Optional[Tuple] = None, + num_tokens_per_rank: Optional[torch.Tensor] = None, + num_tokens_per_rdma_rank: Optional[torch.Tensor] = None, + is_token_in_rank: Optional[torch.Tensor] = None, + num_tokens_per_expert: Optional[torch.Tensor] = None, + topk_idx: Optional[torch.Tensor] = None, + topk_weights: Optional[torch.Tensor] = None, + expert_alignment: int = 1, + config: Optional[Config] = None, + previous_event: Optional[EventOverlap] = None, + async_finish: bool = False, + allocate_on_comm_stream: bool = False, + ) -> Tuple[ + Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + List[int], + Tuple, + EventOverlap, + ]: + """ + Internode dispatch implementation, for more details, please refer to the `dispatch` docs. + Normally, you should not directly call this function. 
+ """ + assert config is not None + + # Launch the kernel with cached or non-cached mode + x, x_scales = x if isinstance(x, tuple) else (x, None) + if handle is not None: + assert topk_idx is None and topk_weights is None + ( + is_token_in_rank, + rdma_channel_prefix_matrix, + gbl_channel_prefix_matrix, + recv_rdma_channel_prefix_matrix, + recv_rdma_rank_prefix_sum, + recv_gbl_channel_prefix_matrix, + recv_gbl_rank_prefix_sum, + recv_src_meta, + send_rdma_head, + send_nvl_head, + ) = handle + num_recv_tokens = recv_src_meta.size(0) + num_rdma_recv_tokens = send_nvl_head.size(0) + recv_x, recv_x_scales, _, _, _, _, _, _, _, _, _, _, _, _, event = ( + self.runtime.internode_dispatch( + x, + x_scales, + topk_idx, + topk_weights, + None, + None, + is_token_in_rank, + None, + num_recv_tokens, + num_rdma_recv_tokens, + rdma_channel_prefix_matrix, + recv_rdma_rank_prefix_sum, + gbl_channel_prefix_matrix, + recv_gbl_rank_prefix_sum, + expert_alignment, + config, + getattr(previous_event, "event", None), + async_finish, + allocate_on_comm_stream, + ) + ) + return ( + (recv_x, recv_x_scales) if x_scales is not None else recv_x, + None, + None, + None, + None, + EventOverlap(event), + ) + else: + assert ( + num_tokens_per_rank is not None + and is_token_in_rank is not None + and num_tokens_per_expert is not None + ) + ( + recv_x, + recv_x_scales, + recv_topk_idx, + recv_topk_weights, + num_recv_tokens_per_expert_list, + rdma_channel_prefix_matrix, + gbl_channel_prefix_matrix, + recv_rdma_channel_prefix_matrix, + recv_rdma_rank_prefix_sum, + recv_gbl_channel_prefix_matrix, + recv_gbl_rank_prefix_sum, + recv_src_meta, + send_rdma_head, + send_nvl_head, + event, + ) = self.runtime.internode_dispatch( + x, + x_scales, + topk_idx, + topk_weights, + num_tokens_per_rank, + num_tokens_per_rdma_rank, + is_token_in_rank, + num_tokens_per_expert, + 0, + 0, + None, + None, + None, + None, + expert_alignment, + config, + getattr(previous_event, "event", None), + async_finish, + allocate_on_comm_stream, + ) + handle = ( + is_token_in_rank, + rdma_channel_prefix_matrix, + gbl_channel_prefix_matrix, + recv_rdma_channel_prefix_matrix, + recv_rdma_rank_prefix_sum, + recv_gbl_channel_prefix_matrix, + recv_gbl_rank_prefix_sum, + recv_src_meta, + send_rdma_head, + send_nvl_head, + ) + return ( + (recv_x, recv_x_scales) if x_scales is not None else recv_x, + recv_topk_idx, + recv_topk_weights, + num_recv_tokens_per_expert_list, + handle, + EventOverlap(event), + ) + + # noinspection PyTypeChecker + def internode_combine( + self, + x: torch.Tensor, + handle: Union[tuple, list], + topk_weights: Optional[torch.Tensor] = None, + bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None, + config: Optional[Config] = None, + previous_event: Optional[EventOverlap] = None, + async_finish: bool = False, + allocate_on_comm_stream: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], EventOverlap]: + """ + Internode combine implementation, for more details, please refer to the `combine` docs. + Normally, you should not directly call this function. 
+ """ + assert config is not None + + # Unpack handle and bias + ( + is_combined_token_in_rank, + _, + _, + rdma_channel_prefix_matrix, + rdma_rank_prefix_sum, + gbl_channel_prefix_matrix, + gbl_rank_prefix_sum, + src_meta, + send_rdma_head, + send_nvl_head, + ) = handle + bias_0, bias_1 = Buffer._unpack_bias(bias) + + # Launch the kernel + combined_x, combined_topk_weights, event = self.runtime.internode_combine( + x, + topk_weights, + bias_0, + bias_1, + src_meta, + is_combined_token_in_rank, + rdma_channel_prefix_matrix, + rdma_rank_prefix_sum, + gbl_channel_prefix_matrix, + send_rdma_head, + send_nvl_head, + config, + getattr(previous_event, "event", None), + async_finish, + allocate_on_comm_stream, + ) + return combined_x, combined_topk_weights, EventOverlap(event) + + def clean_low_latency_buffer( + self, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int + ) -> None: + """ + As low-latency kernels require part of the buffer to be zero-initialized, so it is vital to clean the buffer + if the buffer is dirty at some time. + For example, after running the normal dispatch/combine, you must run this function before executing any + low-latency kernel. + + Arguments: + num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value. + hidden: the hidden dimension of each token. + num_experts: the number of all experts. + """ + self.runtime.clean_low_latency_buffer( + num_max_dispatch_tokens_per_rank, hidden, num_experts + ) diff --git a/ep/deep_ep_wrapper/deep_ep/test_internode.py b/ep/deep_ep_wrapper/deep_ep/test_internode.py new file mode 100644 index 000000000..632a9c8ca --- /dev/null +++ b/ep/deep_ep_wrapper/deep_ep/test_internode.py @@ -0,0 +1,549 @@ +""" +This is the same test_internode.py test in DeepEP's repo. 
+ +Build: +export OMP_NUM_THREADS=6 +export MAKE_NORMAL_MODE=1 +make clean && make -j install + +On first node: +torchrun --nnodes=2 --nproc_per_node=8 --node_rank=0 \ + --master_addr=10.1.227.34 --master_port=12355 \ + bench/test_internode.py --num-tokens=4096 \ + --hidden=7168 --num-topk=8 --num-experts=256 --test-ll-compatibility + +On second node: +torchrun --nnodes=2 --nproc_per_node=8 --node_rank=1 \ + --master_addr=10.1.227.34 --master_port=12355 \ + bench/test_internode.py --num-tokens=4096 \ + --hidden=7168 --num-topk=8 --num-experts=256 --test-ll-compatibility + +This benchmark verifies: + * Dispatch and combine correctness for BF16/FP8 + * Top-k routing and per-expert token distribution + * Compatibility with cached dispatch and low-latency kernels + * Performance tuning for NVL and RDMA chunk sizes +""" + +import argparse +import os +import time +import torch +import torch.distributed as dist + +# noinspection PyUnresolvedReferences + +from utils import ( + init_dist, + bench, + bench_kineto, + calc_diff, + create_grouped_scores, + inplace_unique, + per_token_cast_to_fp8, + per_token_cast_back, + init_dist_under_torchrun, + detect_ib_hca, +) + +# # Test compatibility with low latency functions +# from buffer import Buffer + +# try: +# from uccl.ep import Config +# except ImportError as exc: +# import sys + +# sys.stderr.write("Failed to import uccl.ep\n") +# raise + +from deep_ep import Buffer, Config + +# noinspection PyShadowingNames +def test_main( + args: argparse.Namespace, + num_sms: int, + local_rank: int, + num_local_ranks: int, + num_ranks: int, + num_nodes: int, + rank: int, + buffer: Buffer, + group: dist.ProcessGroup, +): + # Settings + num_tokens, hidden = args.num_tokens, args.hidden + num_topk_groups, num_topk, num_experts = ( + args.num_topk_groups, + args.num_topk, + args.num_experts, + ) + + assert num_experts % num_ranks == 0 and num_local_ranks == 8 + if local_rank == 0: + print( + f"[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}", + flush=True, + ) + + # Random data + x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") * rank + x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda") + x_e4m3 = per_token_cast_to_fp8(x) + x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T) + scores = ( + torch.randn((num_tokens, num_experts), dtype=torch.float32, device="cuda").abs() + + 1 + ) + group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1) + group_idx = torch.topk( + group_scores, k=num_topk_groups, dim=-1, sorted=False + ).indices + masked_scores = create_grouped_scores(scores, group_idx, num_nodes) + topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[ + 1 + ] + topk_weights = ( + torch.ones((num_tokens, num_topk), dtype=torch.float32, device="cuda") * rank + ) + topk_weights_pure_rand = torch.randn( + (num_tokens, num_topk), dtype=torch.float32, device="cuda" + ) + rank_idx = topk_idx // (num_experts // num_ranks) + rank_idx.masked_fill_(topk_idx == -1, -1) + inplace_unique(rank_idx, num_ranks) + rdma_rank_idx = rank_idx // num_local_ranks + rdma_rank_idx.masked_fill_(rank_idx == -1, -1) + inplace_unique(rdma_rank_idx, num_nodes) + + # RDMA dispatch counts + rdma_idx = topk_idx // (num_experts // num_nodes) + rdma_idx.masked_fill_(topk_idx == -1, -1) + inplace_unique(rdma_idx, num_nodes) + num_rdma_token_sent = rdma_idx.ne(-1).sum().item() + + # Expert meta + num_tokens_per_expert = torch.zeros((num_experts,), 
dtype=torch.int, device="cuda") + for i in range(num_experts): + num_tokens_per_expert[i] = (topk_idx == i).sum() + gbl_num_tokens_per_expert = num_tokens_per_expert.clone() + dist.all_reduce(gbl_num_tokens_per_expert, group=group) + + # Rank layout meta + num_tokens_per_rank = torch.empty((num_ranks,), dtype=torch.int, device="cuda") + num_tokens_per_rdma_rank = torch.empty((num_nodes,), dtype=torch.int, device="cuda") + token_idx_in_rank = torch.full( + (num_ranks, num_tokens), -1, dtype=torch.long, device="cuda" + ) + for i in range(num_ranks): + num_tokens_per_rank[i] = (rank_idx == i).sum() + token_sel = (rank_idx == i).max(dim=-1)[0] + count = token_sel.sum().item() + tokens = torch.sort(token_sel.to(torch.int), descending=True)[1] + tokens[:count] = torch.sort(tokens[:count])[0] + token_idx_in_rank[i][tokens[:count]] = torch.arange( + count, dtype=torch.long, device="cuda" + ) + for i in range(num_nodes): + num_tokens_per_rdma_rank[i] = (rdma_rank_idx == i).sum() + token_idx_in_rank = token_idx_in_rank.T.contiguous().to(torch.int) + is_token_in_rank = token_idx_in_rank >= 0 + gbl_num_tokens_per_rank = num_tokens_per_rank.clone() + dist.all_reduce(gbl_num_tokens_per_rank, group=group) + + ( + ref_num_tokens_per_rank, + ref_num_tokens_per_rdma_rank, + ref_num_tokens_per_expert, + ref_is_token_in_rank, + _, + ) = buffer.get_dispatch_layout(topk_idx, num_experts) + assert torch.allclose(ref_num_tokens_per_rank, num_tokens_per_rank) + assert torch.allclose(ref_num_tokens_per_rdma_rank, num_tokens_per_rdma_rank) + assert torch.allclose(ref_num_tokens_per_expert, num_tokens_per_expert) + assert torch.allclose(ref_is_token_in_rank, is_token_in_rank) + t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0] + if local_rank == 0: + print(f"[layout] Kernel performance: {t * 1000:.3f} ms", flush=True) + print("", flush=True) + group.barrier() + time.sleep(1) + + # Config + # This seems really high. 
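+    # NOTE: the NVL/RDMA buffer sizes below are carried over from the upstream DeepEP internode test.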
+ rdma_buffer_size, nvl_buffer_size = 512, (720 if num_ranks in (144, 160) else 512) + if num_ranks == 24: + nvl_buffer_size = 540 + config = Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size) + + # Test dispatch + # noinspection PyShadowingNames + def check_data(check_x, recv_gbl_rank_prefix_sum): + assert torch.allclose(check_x.amin(dim=1), check_x.amax(dim=1)) + check_start = 0 + for i in range(num_ranks): + check_end = recv_gbl_rank_prefix_sum[i].item() + assert (check_x[check_start:check_end, :].int() - i).sum().item() == 0 + check_start = check_end + + for previous_mode in (False, True): + for async_mode in (False, True): + for current_x in (x_pure_rand, x, x_e4m3): + for with_topk in (False, True): + if local_rank == 0: + print( + f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', + flush=True, + ) + dispatch_args = { + "x": current_x, + "num_tokens_per_rank": num_tokens_per_rank, + "num_tokens_per_rdma_rank": num_tokens_per_rdma_rank, + "is_token_in_rank": is_token_in_rank, + "num_tokens_per_expert": num_tokens_per_expert, + "config": config, + "async_finish": async_mode, + } + if with_topk: + dispatch_args.update( + { + "topk_idx": topk_idx, + "topk_weights": ( + topk_weights_pure_rand + if current_x is x_pure_rand + else topk_weights + ), + } + ) + if previous_mode: + dispatch_args.update({"previous_event": buffer.capture()}) + ( + recv_x, + recv_topk_idx, + recv_topk_weights, + recv_num_tokens_per_expert_list, + handle, + event, + ) = buffer.dispatch(**dispatch_args) + event.current_stream_wait() if async_mode else () + recv_x = ( + per_token_cast_back(*recv_x) + if isinstance(recv_x, tuple) + else recv_x + ) + + # Checks + recv_gbl_rank_prefix_sum = handle[-4] + assert gbl_num_tokens_per_rank[rank].item() == recv_x.size( + 0 + ), f"{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}" + assert ( + gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() + == recv_num_tokens_per_expert_list + ) + if current_x is not x_pure_rand: + check_data(recv_x, recv_gbl_rank_prefix_sum) + if with_topk: + # Check `topk_idx` + assert ( + recv_topk_idx.eq(-1) + | ( + (recv_topk_idx >= 0) + & (recv_topk_idx < (num_experts // num_ranks)) + ) + ).sum().item() == recv_topk_idx.numel() + for i, count in enumerate(recv_num_tokens_per_expert_list): + assert recv_topk_idx.eq(i).sum().item() == count + + # Check `topk_weights` + if current_x is not x_pure_rand: + recv_topk_weights[recv_topk_idx.eq(-1)] = ( + recv_topk_weights.amax(dim=1, keepdim=True).expand_as( + recv_topk_weights + )[recv_topk_idx.eq(-1)] + ) + check_data(recv_topk_weights, recv_gbl_rank_prefix_sum) + + # Test cached dispatch (must without top-k staffs) + if not with_topk: + dispatch_args = { + "x": current_x, + "handle": handle, + "config": config, + "async_finish": async_mode, + } + if previous_mode: + dispatch_args.update({"previous_event": buffer.capture()}) + recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args) + event.current_stream_wait() if async_mode else () + recv_x = ( + per_token_cast_back(*recv_x) + if isinstance(recv_x, tuple) + else recv_x + ) + if current_x is not x_pure_rand: + check_data(recv_x, recv_gbl_rank_prefix_sum) + + # Test combine + bias_0 = torch.ones( + (num_tokens, hidden), dtype=torch.bfloat16, device="cuda" + ) + bias_1 = torch.randn( + (num_tokens, hidden), dtype=torch.bfloat16, device="cuda" + ) + combine_args = { + "x": recv_x, + "bias": (bias_0, 
bias_1), + "handle": handle, + "config": config, + "async_finish": async_mode, + } + if with_topk: + combine_args.update({"topk_weights": recv_topk_weights}) + if previous_mode: + combine_args.update({"previous_event": buffer.capture()}) + combined_x, combined_topk_weights, event = buffer.combine( + **combine_args + ) + event.current_stream_wait() if async_mode else () + check_x = ( + combined_x.float() - bias_0.float() - bias_1.float() + ) / is_token_in_rank.sum(dim=1).unsqueeze(1) + ref_x = x_pure_rand if current_x is x_pure_rand else x + assert calc_diff(check_x, ref_x) < 5e-6 + if with_topk: + check_topk_weights = ( + combined_topk_weights + if (current_x is x_pure_rand) + else ( + combined_topk_weights + / is_token_in_rank.sum(dim=1).unsqueeze(1) + ) + ) + ref_topk_weights = ( + topk_weights_pure_rand + if current_x is x_pure_rand + else topk_weights + ) + assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9 + + # For later tuning + dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2 + dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2 + combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes + combine_bf16_rdma_recv_bytes = dispatch_bf16_rdma_send_bytes + if local_rank == 0: + print("", flush=True) + + # Tune dispatch performance + best_dispatch_results = None + fp8_factor = (1 + 4 / 128) / 2 + for current_x in (x_e4m3, x): + best_time, best_results = 1e10, None + rdma_send_bytes = ( + (dispatch_bf16_rdma_send_bytes * fp8_factor) + if isinstance(current_x, tuple) + else dispatch_bf16_rdma_send_bytes + ) + nvl_recv_bytes = ( + (dispatch_bf16_nvl_recv_bytes * fp8_factor) + if isinstance(current_x, tuple) + else dispatch_bf16_nvl_recv_bytes + ) + for nvl_chunk_size in range(4, 45, 4): + for rdma_chunk_size in range(4, 33, 4): + config = Config( + num_sms, + nvl_chunk_size, + nvl_buffer_size, + rdma_chunk_size, + rdma_buffer_size, + ) + tune_args = {"x": current_x, "handle": handle, "config": config} + t, notify_t = bench_kineto( + lambda: buffer.dispatch(**tune_args), ("dispatch", "notify") + ) + if t < best_time: + best_time, best_results = t, ( + num_sms, + nvl_chunk_size, + rdma_chunk_size, + notify_t, + ) + if local_rank == 0: + print( + f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}, transmit: {t * 1e6:.2f} us, notify: {notify_t * 1e6:.2f} us, BW: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ", + flush=True, + ) + if local_rank == 0: + print( + f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}, transmit: {best_time * 1e6:.2f} us, notify: {best_results[3] * 1e6:.2f} us, BW: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', + flush=True, + ) + print("", flush=True) + + if isinstance(current_x, tuple): + # Gather FP8 the best config from rank 0 + best_dispatch_results = torch.tensor( + [best_results[0], best_results[1], best_results[2]], + dtype=torch.int32, + device="cuda", + ) + all_best_fp8_results_list = [ + torch.zeros_like(best_dispatch_results) + for _ in range(torch.distributed.get_world_size()) + ] + dist.all_gather( + all_best_fp8_results_list, best_dispatch_results, group=group + ) + best_dispatch_results = all_best_fp8_results_list[0].tolist() + dispatch_config = Config( + best_dispatch_results[0], + best_dispatch_results[1], + nvl_buffer_size, + best_dispatch_results[2], + rdma_buffer_size, + ) + + 
dispatch_args = {
+ "x": x,
+ "num_tokens_per_rank": num_tokens_per_rank,
+ "num_tokens_per_rdma_rank": num_tokens_per_rdma_rank,
+ "is_token_in_rank": is_token_in_rank,
+ "num_tokens_per_expert": num_tokens_per_expert,
+ "config": dispatch_config if dispatch_config is not None else config,
+ }
+ recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)
+
+ # Tune combine performance
+ best_time, best_results = 1e10, None
+ for nvl_chunk_size in range(1, 8, 1):
+ for rdma_chunk_size in range(12 if num_nodes == 2 else 8, 33, 4):
+ config = Config(
+ num_sms,
+ nvl_chunk_size,
+ nvl_buffer_size,
+ rdma_chunk_size,
+ rdma_buffer_size,
+ )
+ tune_args = {"x": recv_x, "handle": handle, "config": config}
+ t, notify_t = bench_kineto(
+ lambda: buffer.combine(**tune_args), ("combine", "notify")
+ )
+ if local_rank == 0:
+ print(
+ f"[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}, transmit: {t * 1e6:.2f} us, notify: {notify_t * 1e6:.2f} us, BW: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL)",
+ flush=True,
+ )
+ if t < best_time:
+ best_time, best_results = t, (
+ num_sms,
+ nvl_chunk_size,
+ rdma_chunk_size,
+ notify_t,
+ )
+
+ if local_rank == 0:
+ print(
+ f"[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}, transmit: {best_time * 1e6:.2f} us, notify: {best_results[3] * 1e6:.2f} us, BW: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)",
+ flush=True,
+ )
+ print("", flush=True)
+
+
+# noinspection PyUnboundLocalVariable,PyShadowingNames
+def test_loop(
+ local_rank: int, num_local_ranks: int, num_nodes: int, args: argparse.Namespace
+):
+ rank, num_ranks, group = init_dist_under_torchrun(local_rank, num_local_ranks)
+ if args.test_ll_compatibility:
+ ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
+
+ num_sms = 24
+ num_qps_per_rank = max(
+ num_sms,
+ ll_num_experts // num_ranks if args.test_ll_compatibility else 0,
+ )
+ num_nvlink_bytes = int(2e9)
+ num_rdma_bytes = int(1e9)
+
+ buffer = Buffer(
+ group,
+ num_nvlink_bytes,
+ num_rdma_bytes,
+ low_latency_mode=False,
+ num_qps_per_rank=num_qps_per_rank,
+ explicitly_destroy=True,
+ )
+
+ assert num_local_ranks == 8 and num_ranks > 8
+ torch.manual_seed(rank)
+
+ for i in (num_sms,):
+ test_main(
+ args,
+ i,
+ local_rank,
+ num_local_ranks,
+ num_ranks,
+ num_nodes,
+ rank,
+ buffer,
+ group,
+ )
+ if local_rank == 0:
+ print("", flush=True)
+
+ # Destroy the buffer runtime and communication group
+ buffer.destroy()
+ dist.barrier()
+ dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Test internode EP kernels")
+ parser.add_argument(
+ "--num-processes",
+ type=int,
+ default=8,
+ help="Number of processes to spawn (default: 8)",
+ )
+ parser.add_argument(
+ "--num-tokens", type=int, default=4096, help="Number of tokens (default: 4096)"
+ )
+ parser.add_argument(
+ "--hidden", type=int, default=7168, help="Hidden dimension size (default: 7168)"
+ )
+ parser.add_argument(
+ "--num-topk-groups",
+ type=int,
+ default=None,
+ help="Number of top-k groups (default: `min(num_nodes, 4)`)",
+ )
+ parser.add_argument(
+ "--num-topk", type=int, default=8, help="Number of top-k experts (default: 8)"
+ )
+ parser.add_argument(
+ "--num-experts", type=int, default=256, help="Number of experts (default: 256)"
+ )
+ parser.add_argument(
+ "--test-ll-compatibility",
+ action="store_true",
+ help="whether to test compatibility with low-latency kernels",
+ )
+ args = parser.parse_args()
+ world_size = int(os.environ["WORLD_SIZE"])
+ local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
+ num_nodes = world_size // local_world_size
+
+ # Set default `num_topk_groups` if not provided
+ if args.num_topk_groups is None:
+ args.num_topk_groups = min(num_nodes, 4)
+
+ num_processes = args.num_processes
+ if num_processes != 8:
+ raise ValueError("Only --num-processes=8 is supported for this test.")
+ # NOTE: modified from deep_ep
+ local_rank = int(os.environ["LOCAL_RANK"])
+ num_local_ranks = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
+ test_loop(local_rank, num_local_ranks, num_nodes, args)
diff --git a/ep/deep_ep_wrapper/deep_ep/utils.py b/ep/deep_ep_wrapper/deep_ep/utils.py
new file mode 100644
index 000000000..ad0ccfcf7
--- /dev/null
+++ b/ep/deep_ep_wrapper/deep_ep/utils.py
@@ -0,0 +1,655 @@
+import inspect
+from typing import Any, Optional, Tuple, Union
+import os
+import torch
+import torch.distributed as dist
+from typing import Optional
+import glob
+import sys
+from uccl.ep import EventHandle
+import tempfile
+import json
+from pathlib import Path
+import time
+import numpy as np
+
+# import deep_ep as ep
+try:
+ from uccl import ep
+except ImportError as exc:
+ import sys
+
+ sys.stderr.write("Failed to import uccl.ep\n")
+ raise
+
+
+def calc_diff(x: torch.Tensor, y: torch.Tensor):
+ x, y = x.double() + 1, y.double() + 1
+ denominator = (x * x + y * y).sum()
+ sim = 2 * (x * y).sum() / denominator
+ return (1 - sim).item()
+
+
+def hash_tensor(t: torch.Tensor):
+ return t.view(torch.int64).sum().item()
+
+
+def init_dist(local_rank: int, num_local_ranks: int):
+ # Set device
+ device_index = int(os.environ.get("LOCAL_RANK", 0))
+ torch.cuda.set_device(device_index)
+ torch.set_default_device(f"cuda:{device_index}")
+
+ # NOTES: you may rewrite this function with your own cluster settings
+ ip = os.getenv("MASTER_ADDR", "127.0.0.1")
+ port = int(os.getenv("MASTER_PORT", "8361"))
+ world_size = int(os.getenv("WORLD_SIZE", 1))
+ node_rank = int(os.getenv("RANK", 0))
+
+ sig = inspect.signature(dist.init_process_group)
+ params = {
+ "backend": "nccl",
+ "init_method": f"tcp://{ip}:{port}",
+ "world_size": world_size,
+ "rank": node_rank,
+ }
+ print(params)
+ if "device_id" in sig.parameters:
+ # noinspection PyTypeChecker
+ params["device_id"] = torch.device(f"cuda:{local_rank}")
+ dist.init_process_group(**params)
+ torch.set_default_dtype(torch.bfloat16)
+ return (
+ dist.get_rank(),
+ dist.get_world_size(),
+ dist.new_group(list(range(world_size))),
+ )
+
+
+def init_dist_under_torchrun(local_rank: int, num_local_ranks: int):
+ # torchrun already sets RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT
+ dist.init_process_group(
+ backend="nccl", device_id=torch.device(f"cuda:{local_rank}")
+ )
+
+ torch.set_default_dtype(torch.bfloat16)
+ torch.set_default_device(f"cuda:{local_rank}")
+ torch.cuda.set_device(local_rank)
+
+ return (
+ dist.get_rank(),
+ dist.get_world_size(),
+ dist.new_group(list(range(dist.get_world_size()))),
+ )
+
+
+def _discover_local_ip():
+ # Try to infer the IP that can reach MASTER_ADDR (works in most clusters)
+ import socket, os
+
+ # Method 1: Use MASTER_ADDR if available (torchrun style)
+ if "MASTER_ADDR" in os.environ:
+ master =
os.environ["MASTER_ADDR"] + port = int(os.environ.get("MASTER_PORT", "29500")) + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + s.connect((master, port)) + return s.getsockname()[0] + except: + pass + finally: + s.close() + + # Method 2: Use hostname resolution (works in AWS and most cloud environments) + hostname = socket.gethostname() + try: + # This usually returns the private IP in cloud environments + local_ip = socket.gethostbyname(hostname) + # Skip loopback addresses + if not local_ip.startswith('127.'): + return local_ip + except: + pass + + # Method 3: Connect to a public DNS to determine outgoing interface + try: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + # Google DNS - this doesn't actually send packets + s.connect(("8.8.8.8", 80)) + local_ip = s.getsockname()[0] + s.close() + return local_ip + except: + pass + + # Last resort: return localhost + return "127.0.0.1" + + +def _gather_peer_ips(group): + # Gather local IP strings across ranks + world = dist.get_world_size(group) + my_ip = _discover_local_ip() + ips = [None] * world + dist.all_gather_object(ips, my_ip, group=group) + return ips + + +def get_peer_ip(rank: int, num_ranks: int, group: dist.ProcessGroup): + + if num_ranks == 1: + # single-process local test: okay to leave blank (or 127.0.0.1) + peer_ip = "" + else: + ips = _gather_peer_ips(group) + # simple ring: next rank is your peer + peer_ip = ips[(rank + 1) % num_ranks] + return peer_ip if peer_ip else "" + + +def get_cpu_proxies_meta(rank, scratch_ptr, scratch_bytes, num_ranks, group): + my_ip = _discover_local_ip() + meta = { + "rank": rank, + "ptr": int(scratch_ptr), + "nbytes": int(scratch_bytes), + "ip": my_ip, + } + all_meta = [None] * num_ranks + # Use current device or fallback to LOCAL_RANK or 0 + if "LOCAL_RANK" in os.environ: + device_index = int(os.environ["LOCAL_RANK"]) + else: + device_index = torch.cuda.current_device() + torch.cuda.set_device(device_index) + dist.all_gather_object(all_meta, meta, group=group) + rank2meta = {m["rank"]: m for m in all_meta} + + # Debug: print IP distribution + ip_counts = {} + for m in all_meta: + ip = m["ip"] + ip_counts[ip] = ip_counts.get(ip, 0) + 1 + if rank == 0: + print(f"[DEBUG] IP distribution across {num_ranks} ranks:", flush=True) + for ip, count in ip_counts.items(): + print(f"[DEBUG] {ip}: {count} ranks", flush=True) + + return rank2meta + + +def check_nvlink_connections(group: dist.ProcessGroup): + """ + Check NVLink connection between every pair of GPUs. + + Arguments: + group: the communication group. 
+ """ + # Check NVLink connection + # NOTES: some A100 PCIE GPUs only have pairwise NVLink connection, so that we can only use EP2 + # TODO: check all cases, all local-node GPUs in the group should be connected via NVLink + if "PCIE" in torch.cuda.get_device_name(): + assert group.size() <= 2, "PCIe GPUs only have pairwise NVLink connections" + + # noinspection PyUnresolvedReferences + import pynvml + + pynvml.nvmlInit() + + # noinspection PyTypeChecker + devices = ( + os.environ.get("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7") + .strip(",") + .split(",") + ) + physical_device_idx = int(devices[torch.cuda.current_device()]) + physical_device_indices = [ + 0, + ] * group.size() + dist.all_gather_object(physical_device_indices, physical_device_idx, group) + + # Check whether they are all connected via NVLink + # Reference: https://github.com/vllm-project/vllm/blob/b8e809a057765c574726a6077fd124db5077ce1f/vllm/platforms/cuda.py#L438 + handles = [ + pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_indices + ] + for i, handle in enumerate(handles): + for j, peer_handle in enumerate(handles): + if i >= j: + continue + status = pynvml.nvmlDeviceGetP2PStatus( + handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK + ) + assert ( + status == pynvml.NVML_P2P_STATUS_OK + ), f"GPU {physical_device_indices[i]} and GPU {physical_device_indices[j]} are not connected via NVLink" + + # Close NVML + pynvml.nvmlShutdown() + + +class EventOverlap: + """ + A wrapper class to manage CUDA events, also for better overlapping convenience. + + Attributes: + event: the CUDA event captured. + extra_tensors: an easier way to simulate PyTorch tensor `record_stream`, may be useful with CUDA graph. + """ + + def __init__( + self, + event: Optional[EventHandle] = None, + extra_tensors: Optional[Tuple[torch.Tensor]] = None, + ) -> None: + """ + Initialize the class. + + Arguments: + event: the CUDA event captured. + extra_tensors: an easier way to simulate PyTorch tensor `record_stream`, may be useful with CUDA graph. + """ + self.event = event + + # NOTES: we use extra tensors to achieve stream recording, otherwise, + # stream recording will be incompatible with CUDA graph. + self.extra_tensors = extra_tensors + + def current_stream_wait(self) -> None: + """ + The current stream `torch.cuda.current_stream()` waits for the event to be finished. + """ + assert self.event is not None + self.event.current_stream_wait() + + def __enter__(self) -> Any: + """ + Utility for overlapping and Python `with` syntax. + + You can overlap the kernels on the current stream with the following example: + ```python + event_overlap = event_after_all_to_all_kernels() + with event_overlap(): + do_something_on_current_stream() + # After exiting the `with` scope, the current stream with wait the event to be finished. + ``` + """ + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """ + Utility for overlapping and Python `with` syntax. + + Please follow the example in the `__enter__` function. 
+ """ + if self.event is not None: + self.event.current_stream_wait() + + +def detect_ib_hca(): + devices = sorted(glob.glob("/sys/class/infiniband/*")) + if not devices: + raise RuntimeError("No devices found under /sys/class/infiniband") + + ib_devs = [ + os.path.basename(d) for d in devices if os.path.basename(d).startswith("mlx5") + ] + if not ib_devs: + return None + return ib_devs[0] + + +def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor): + if x_scales.dtype == torch.int: + x_scales = x_scales.view(dtype=torch.uint8).to(torch.int) << 23 + x_scales = x_scales.view(dtype=torch.float) + x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128) + x_scales = x_scales.view(x_fp8.size(0), -1, 1) + return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16) + + +class empty_suppress: + def __enter__(self): + return self + + def __exit__(self, *_): + pass + + +class suppress_stdout_stderr: + def __enter__(self): + self.outnull_file = open(os.devnull, "w") + self.errnull_file = open(os.devnull, "w") + + self.old_stdout_fileno_undup = sys.stdout.fileno() + self.old_stderr_fileno_undup = sys.stderr.fileno() + + self.old_stdout_fileno = os.dup(sys.stdout.fileno()) + self.old_stderr_fileno = os.dup(sys.stderr.fileno()) + + self.old_stdout = sys.stdout + self.old_stderr = sys.stderr + + os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) + os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) + + sys.stdout = self.outnull_file + sys.stderr = self.errnull_file + return self + + def __exit__(self, *_): + sys.stdout = self.old_stdout + sys.stderr = self.old_stderr + + os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) + os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) + + os.close(self.old_stdout_fileno) + os.close(self.old_stderr_fileno) + + self.outnull_file.close() + self.errnull_file.close() + + +def bench(fn, num_warmups: int = 50, num_tests: int = 50, post_fn=None): + # Flush L2 cache with 256 MB data + torch.cuda.synchronize() + current_device = torch.cuda.current_device() + cache = torch.empty( + int(256e6 // 4), dtype=torch.int, device=f"cuda:{current_device}" + ) + + # Warmup + for _ in range(num_warmups): + fn() + + # Flush L2 + cache.zero_() + + # Testing + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_tests)] + for i in range(num_tests): + # Record + start_events[i].record() + fn() + end_events[i].record() + if post_fn is not None: + post_fn() + torch.cuda.synchronize() + + times = np.array( + [s.elapsed_time(e) / 1e3 for s, e in zip(start_events, end_events)] + )[1:] + return np.average(times), np.min(times), np.max(times) + + +def bench_kineto( + fn, + kernel_names: Union[str, tuple], + num_tests: int = 30, + suppress_kineto_output: bool = False, + trace_path: Optional[str] = None, + barrier_comm_profiling: bool = False, + num_kernels_per_period: int = 1, +): + # Profile + suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress + with suppress(): + schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule + ) as prof: + for i in range(2): + # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead + if barrier_comm_profiling: + current_device = torch.cuda.current_device() + lhs = torch.randn( + (8192, 8192), dtype=torch.float, 
device=f"cuda:{current_device}" + ) + rhs = torch.randn( + (8192, 8192), dtype=torch.float, device=f"cuda:{current_device}" + ) + lhs @ rhs + dist.all_reduce( + torch.ones( + 1, dtype=torch.float, device=f"cuda:{current_device}" + ) + ) + for _ in range(num_tests): + fn() + torch.cuda.synchronize() + dist.barrier() + prof.step() + + # Parse the profiling table + assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) + is_tuple = isinstance(kernel_names, tuple) + prof_lines = ( + prof.key_averages() + .table(sort_by="cuda_time_total", max_name_column_width=100) + .split("\n") + ) + kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names + assert all([isinstance(name, str) for name in kernel_names]) + for name in kernel_names: + assert ( + sum([name in line for line in prof_lines]) == 1 + ), f"Errors of the kernel {name} in the profiling table" + + # Save chrome traces + if trace_path is not None: + prof.export_chrome_trace(trace_path) + + # Return average kernel durations + units = {"ms": 1e3, "us": 1e6} + kernel_durations = [] + for name in kernel_names: + for line in prof_lines: + if name in line: + time_str = line.split()[-2] + for unit, scale in units.items(): + if unit in time_str: + kernel_durations.append( + float(time_str.replace(unit, "")) / scale + ) + break + break + + # Expand the kernels by periods + if num_kernels_per_period > 1: + with tempfile.NamedTemporaryFile(suffix=".json") as tmp: + prof.export_chrome_trace(tmp.name) + profile_data = json.loads(Path(tmp.name).read_text()) + + for i, kernel_name in enumerate(kernel_names): + events = [ + event + for event in profile_data["traceEvents"] + if f"::{kernel_name}" in event["name"] + ] + events = sorted(events, key=lambda event: event["ts"]) + durations = [event["dur"] / 1e6 for event in events] + assert len(durations) % num_kernels_per_period == 0 + num_kernel_patterns = len(durations) // num_kernels_per_period + kernel_durations[i] = [ + sum(durations[j::num_kernels_per_period]) / num_kernel_patterns + for j in range(num_kernels_per_period) + ] + + # Return execution durations + return kernel_durations if is_tuple else kernel_durations[0] + + +def initialize_uccl( + scratch_ptr, + scratch_nbytes, + rank, + num_ranks, + group, + num_experts=0, + is_intranode=False, + use_normal_mode=False +): + try: + for shm_file in glob.glob("/dev/shm/uccl_barrier_*"): + os.remove(shm_file) + except Exception: + pass + + # Try to get local_rank from environment or infer from current device + if "LOCAL_RANK" in os.environ: + local_rank = int(os.environ["LOCAL_RANK"]) + else: + # Fallback: use current CUDA device as local_rank + local_rank = torch.cuda.current_device() + + # Try to get nproc_per_node from environment + if "LOCAL_WORLD_SIZE" in os.environ: + nproc_per_node = int(os.environ["LOCAL_WORLD_SIZE"]) + else: + # Fallback: infer from is_intranode and num_ranks + if is_intranode: + # All ranks are on the same node + nproc_per_node = num_ranks + else: + # Assume uniform distribution across nodes + # If we have N GPUs, assume each node has same number of GPUs + num_gpus = torch.cuda.device_count() + nproc_per_node = num_gpus if num_gpus > 0 else 1 + + node_idx = rank // nproc_per_node if nproc_per_node > 0 else 0 + + # Only check WORLD_SIZE consistency if it's defined + if "WORLD_SIZE" in os.environ and nproc_per_node > 0: + world_size = int(os.environ.get("WORLD_SIZE")) + if world_size % nproc_per_node != 0: + raise ValueError("WORLD_SIZE must be divisible by LOCAL_WORLD_SIZE") + + proxies = [] 
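+ # The block below first all-gathers each rank's scratch-buffer metadata
+ # (rank, device pointer, size in bytes, node IP), then creates one ep.Proxy per
+ # proxy thread; for internode runs every proxy is handed the full peers-meta list
+ # before start_dual() brings up the RDMA connections.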
+ rank2meta = get_cpu_proxies_meta( + rank, scratch_ptr, scratch_nbytes, num_ranks, group + ) + peers_meta_list = [rank2meta[r] for r in range(num_ranks)] + peer_ip = rank2meta[(rank + 1) % num_ranks]["ip"] + + # Calculate num_nodes from num_ranks and nproc_per_node + if nproc_per_node > 0: + num_nodes = num_ranks // nproc_per_node + else: + num_nodes = num_ranks # Fallback: assume each rank is on a different node + + for i in range(ep.get_num_proxy_threads()): + proxy = ep.Proxy( + thread_idx=i, + gpu_buffer_addr=scratch_ptr, + total_size=scratch_nbytes, + rank=rank, + node_idx=node_idx, + local_rank=local_rank, + peer_ip="" if is_intranode else peer_ip, + num_experts=num_experts, + num_ranks=num_ranks, + num_nodes=num_nodes, + use_normal_mode=use_normal_mode, + ) + if not is_intranode: + proxy.set_peers_meta(peers_meta_list) + proxies.append(proxy) + ep.register_proxies(local_rank, proxies) + + dist.barrier(group) + if not is_intranode: + if rank == 0: + print(f"[UCCL] Starting dual mode for internode communication (num_nodes={num_nodes}, num_ranks={num_ranks})", flush=True) + for proxy in proxies: + proxy.start_dual() + if rank == 0: + print(f"[UCCL] Dual mode started, waiting for RDMA connections to establish...", flush=True) + + workers = None + # if hasattr(ep, "PeerCopyManager"): + # try: + # workers = ep.PeerCopyManager(src_device=local_rank) + # workers.start_for_proxies(proxies) + # if rank == 0: + # print("✓ PeerCopyManager started", flush=True) + # except Exception as e: + # if rank == 0: + # print(f"PeerCopyManager unavailable: {e}", flush=True) + + time.sleep(3) + return proxies, workers + + +def destroy_uccl(proxies, workers): + # Use current device or fallback to LOCAL_RANK + if "LOCAL_RANK" in os.environ: + device_index = int(os.environ["LOCAL_RANK"]) + else: + device_index = torch.cuda.current_device() + + if workers is not None: + try: + workers.stop() + except Exception: + pass + + try: + for p in proxies: + p.stop() + except Exception: + pass + try: + ep.unregister_proxy(device_index) + except Exception: + pass + try: + for shm_file in glob.glob("/dev/shm/uccl_barrier_*"): + os.remove(shm_file) + except Exception: + pass + + +def per_token_cast_to_fp8(x: torch.Tensor): + assert x.dim() == 2 and x.size(1) % 128 == 0 + m, n = x.shape + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view( + m, n + ), (x_amax / 448.0).view(m, -1) + + +def create_grouped_scores( + scores: torch.Tensor, group_idx: torch.Tensor, num_groups: int +): + num_tokens, num_experts = scores.shape + scores = scores.view(num_tokens, num_groups, -1) + mask = torch.zeros((num_tokens, num_groups), dtype=torch.bool, device=scores.device) + mask = mask.scatter_(1, group_idx, True).unsqueeze(-1).expand_as(scores) + return (scores * mask).view(num_tokens, num_experts) + + +def inplace_unique(x: torch.Tensor, num_slots: int): + assert x.dim() == 2 + mask = x < 0 + x_padded = x.masked_fill(mask, num_slots) + bin_count = torch.zeros((x.size(0), num_slots + 1), dtype=x.dtype, device=x.device) + bin_count.scatter_add_(1, x_padded, torch.ones_like(x_padded)) + bin_count = bin_count[:, :num_slots] + sorted_bin_count, sorted_bin_idx = torch.sort(bin_count, dim=-1, descending=True) + sorted_bin_idx.masked_fill_(sorted_bin_count == 0, -1) + sorted_bin_idx = torch.sort(sorted_bin_idx, descending=True, dim=-1).values + x[:, :].fill_(-1) + valid_len = min(num_slots, x.size(1)) + x[:, :valid_len] = 
sorted_bin_idx[:, :valid_len]
diff --git a/ep/deep_ep_wrapper/setup.py b/ep/deep_ep_wrapper/setup.py
new file mode 100644
index 000000000..615a3e19c
--- /dev/null
+++ b/ep/deep_ep_wrapper/setup.py
@@ -0,0 +1,13 @@
+from setuptools import setup, find_packages
+
+setup(
+ name='deep_ep',
+ version='0.1.0',
+ packages=find_packages(),
+ install_requires=[
+ 'uccl', # declare the uccl dependency
+ ],
+ author='whn09',
+ description='A wrapper package for uccl.ep with additional functionality',
+ python_requires='>=3.6',
+)
\ No newline at end of file
diff --git a/ep/install_deps.sh b/ep/install_deps.sh
index a0e33a66d..c8acabe86 100755
--- a/ep/install_deps.sh
+++ b/ep/install_deps.sh
@@ -19,7 +19,8 @@ get_cuda_version() {
 # Install PyTorch with automatic CUDA version handling
 echo "Checking CUDA environment..."
 if check_cuda; then
- sudo apt install -y nvtop libgoogle-glog-dev clang-format-14 python3-pip
+ sudo apt update
+ sudo apt install -y nvtop libgoogle-glog-dev clang-format-14 python3-pip libnuma-dev
 pip3 install pybind11 --upgrade
 pip3 install black
diff --git a/scripts/node_ips/p5en.txt b/scripts/node_ips/p5en.txt
index 2410fc7bf..72263f118 100644
--- a/scripts/node_ips/p5en.txt
+++ b/scripts/node_ips/p5en.txt
@@ -1,4 +1,2 @@
-ip-10-1-227-34
-ip-10-1-27-43
-ip-10-1-72-192
-ip-10-1-88-176
\ No newline at end of file
+ip-172-31-15-58
+ip-172-31-6-215
\ No newline at end of file
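A minimal usage sketch of the per-token FP8 helpers added in ep/deep_ep_wrapper/deep_ep/utils.py, assuming the wrapper has been installed per its README and the uccl package is importable; values and shapes below are illustrative only:

```python
import torch
from deep_ep.utils import per_token_cast_to_fp8, per_token_cast_back, calc_diff

# The hidden size must be a multiple of 128: scales are computed per 128-element block.
x = torch.randn(64, 1024, dtype=torch.bfloat16)
x_fp8, x_scales = per_token_cast_to_fp8(x)      # FP8 payload plus per-block float scales
x_back = per_token_cast_back(x_fp8, x_scales)   # reconstruct bf16 from FP8 + scales
print(f"round-trip relative error: {calc_diff(x, x_back):.3e}")
```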