diff --git a/build.sh b/build.sh
index ae079763..ae9476d6 100755
--- a/build.sh
+++ b/build.sh
@@ -398,6 +398,7 @@ def initialize():
        --exclude "libcudart.so.12" \
        --exclude "libamdhip64.so.*" \
        --exclude "libcuda.so.1" \
+       --exclude "libefa.so.1" \
        -w /io/${WHEEL_DIR}
 
   # Add backend tag to wheel filename using local version identifier
diff --git a/ep/bench/vllm/README.md b/ep/bench/vllm/README.md
new file mode 100644
index 00000000..8dd64980
--- /dev/null
+++ b/ep/bench/vllm/README.md
@@ -0,0 +1,122 @@
+# vLLM + UCCL-EP Multi-Node Expert Parallel Deployment Guide
+
+This guide provides example scripts and instructions for deploying vLLM with Expert Parallelism (EP) across multiple nodes on AWS p5en.
+
+## 🚀 Installation
+
+### 0. Install uv
+
+```bash
+curl -LsSf https://astral.sh/uv/install.sh | sh
+uv venv
+source .venv/bin/activate
+uv pip install numpy torch setuptools
+```
+
+### 1. Install vLLM with EP Support
+
+Follow the official guide:
+```bash
+# Install vLLM: latest version with the timeout fix (https://github.com/vllm-project/vllm/pull/27444)
+git clone https://github.com/vllm-project/vllm.git
+cd vllm
+# This may take 5-10 minutes.
+uv pip install -e .
+```
+
+For detailed EP setup, refer to [vLLM Expert Parallel Deployment](https://docs.vllm.ai/en/stable/serving/expert_parallel_deployment.html).
+
+### 2. Install DeepGEMM Library
+
+DeepGEMM provides optimized kernels for MoE operations:
+
+```bash
+# Clone and install DeepGEMM
+git clone --recursive https://github.com/deepseek-ai/DeepGEMM.git
+cd DeepGEMM
+cat install.sh
+# cuobjdump is used by https://github.com/deepseek-ai/DeepGEMM/blob/9b680f428484625f4f35dc3617f134187c6bcd4a/csrc/jit/kernel_runtime.hpp#L44
+# If cuobjdump is not available on your server, install it with:
+sudo apt install nvidia-cuda-toolkit -y
+# If your server's cuobjdump lives under /bin instead of $CUDA_HOME/bin, create a soft link so DeepGEMM can find it:
+sudo ln -s /bin/cuobjdump /usr/local/cuda/bin/cuobjdump
+./install.sh
+uv pip install dist/*.whl --force-reinstall
+```
+
+If you hit any issues, refer to the [DeepGEMM Installation Guide](https://github.com/deepseek-ai/DeepGEMM#installation).
+
+### 3. Install EP Kernels
+
+Refer to [../../deep_ep_wrapper/README.md](../../deep_ep_wrapper/README.md) to install UCCL-EP's drop-in replacement for DeepEP.
+
+Refer to vLLM's guide for the original DeepEP and pplx-kernels setup.
+
+### 4. (Optional) AWS EFA Setup
+
+For AWS instances with EFA, install the AWS OFI-NCCL plugin (it comes pre-installed on AWS Deep Learning AMIs).
+
+## ⚙️ Configuration
+
+### Network Interface Detection
+
+Find your network interface and IP:
+
+```bash
+# List all network interfaces
+ip addr show
+
+# Common interface names:
+# - eth0, eno1, enp0s3 (Ethernet)
+# - enp74s0, ens5 (Custom/AWS EFA)
+```
+
+### Backend Selection
+
+vLLM provides four EP communication backends:
+
+| Backend | Use Case | Features | Best For |
+|---------|----------|----------|----------|
+| `pplx` | Single node | Chunked prefill support | Development, intra-node |
+| `deepep_high_throughput` | Multi-node prefill | Grouped GEMM | High throughput, prefill-dominated |
+| `deepep_low_latency` | Multi-node decode | CUDA graph support | Low latency, decode-dominated |
+| `allgather_reducescatter` | Multi-node | NCCL-based | InfiniBand/EFA networks |
+
+### Environment Setup
+
+Edit the provided scripts (`launch_vllm_head.sh` and `launch_vllm_worker.sh`) to configure the following (a sketch of a typical block follows the list):
+
+1. **Network interfaces** - Set `GLOO_SOCKET_IFNAME`, `NCCL_SOCKET_IFNAME`
+1. **Backend** - Choose the appropriate `VLLM_ALL2ALL_BACKEND`
+1. **Model storage** - Set `HF_HOME` to a folder with large storage
+1. **DeepGEMM JIT cache** - Set `DG_JIT_CACHE_DIR` to a non-shared folder on each node
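+
+Below is a minimal sketch of such a block. The interface name, storage paths, and backend are placeholders for this guide's AWS p5en example; substitute your own values:
+
+```bash
+# Example per-node environment block (adjust every value to your cluster)
+export GLOO_SOCKET_IFNAME=ens5                     # primary Ethernet/EFA-facing interface
+export NCCL_SOCKET_IFNAME=ens5                     # same interface, used by NCCL
+export HF_HOME=/opt/dlami/nvme/hf_cache            # large local storage for model weights
+export DG_JIT_CACHE_DIR=/opt/dlami/nvme/dg_cache   # non-shared, per-node DeepGEMM JIT cache
+export VLLM_ALL2ALL_BACKEND=deepep_high_throughput # pick a backend from the table above
+```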
+
+## 🚢 Deployment
+
+### Step 1: Start Node 0 (Primary)
+
+On the **first node** (primary node that handles API requests):
+
+```bash
+bash launch_vllm_head.sh 10.4.147.22 13345 deepseek-ai/DeepSeek-V3-0324 deepep_high_throughput 2 1 8 1
+```
+
+### Step 2: Start Node 1+ (Secondary)
+
+On **each additional node** (secondary nodes in headless mode):
+
+```bash
+# Launch Node 1 (headless)
+bash launch_vllm_worker.sh 10.4.147.22 13345 deepseek-ai/DeepSeek-V3-0324 deepep_high_throughput 2 1 8 1
+```
+
+**Arguments:**
+- `10.4.147.22` - IP address of **Node 0**; it should be the IP bound to `NCCL_SOCKET_IFNAME`
+- `13345` - RPC port
+- `deepseek-ai/DeepSeek-V3-0324` - Model to serve (must be the same on every node)
+- `deepep_high_throughput` - EP communication backend (must be the same on every node)
+- `2` - Total DP size
+- `1` - Local DP size on this node
+- `8` - Local TP size on this node
+- `1` - For Node 0: number of API servers; for worker nodes: starting DP rank (= sum of previous nodes' local DP sizes)
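+
+Once both nodes are up, you can sanity-check the deployment with an OpenAI-compatible request against Node 0. This assumes vLLM's default API port 8000; adjust if you pass `--port` to `vllm serve`:
+
+```bash
+# List the served model(s)
+curl http://10.4.147.22:8000/v1/models
+
+# Send a small completion request
+curl http://10.4.147.22:8000/v1/completions \
+  -H "Content-Type: application/json" \
+  -d '{"model": "deepseek-ai/DeepSeek-V3-0324", "prompt": "Hello", "max_tokens": 16}'
+```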
diff --git a/ep/bench/vllm/launch_vllm_head.sh b/ep/bench/vllm/launch_vllm_head.sh
new file mode 100755
index 00000000..f8510a7b
--- /dev/null
+++ b/ep/bench/vllm/launch_vllm_head.sh
@@ -0,0 +1,135 @@
+#!/bin/bash
+# Node 0 (Primary) - Multi-node vLLM with Expert Parallel (EP)
+# This node handles incoming requests
+#
+# Prerequisites:
+# 1. Install vLLM with EP support: https://docs.vllm.ai/en/stable/serving/expert_parallel_deployment.html#architecture-overview
+# 2. Install DeepGEMM: https://github.com/deepseek-ai/DeepGEMM#installation
+# 3. Install EP kernels: Follow vLLM's EP installation guide
+# 4. For AWS EFA: Install AWS OFI-NCCL plugin
+
+# Example:
+# bash launch_vllm_head.sh 10.4.147.22 13345 deepseek-ai/DeepSeek-V3-0324 deepep_low_latency 2 1 8 1
+
+set -e
+
+echo "🚀 Launching vLLM Node 0 (Primary) with Expert Parallel..."
+
+# Check if IP is provided
+if [ -z "$1" ]; then
+  echo "❌ Error: Node IP address is required!"
+  echo ""
+  echo "Usage: $0 <NODE_IP> [RPC_PORT] [MODEL] [BACKEND] [TOTAL_DP_SIZE] [LOCAL_DP_SIZE] [LOCAL_TP_SIZE] [API_SERVERS]"
+  echo ""
+  echo "Example:"
+  echo "  $0 10.1.107.86 13345 deepseek-ai/DeepSeek-V3-0324 deepep_high_throughput 16 8 1 8"
+  echo ""
+  echo "💡 To find your IP address, run: hostname -I"
+  exit 1
+fi
+
+# PyTorch library path (required for DeepGEMM)
+export LD_LIBRARY_PATH=$(python3 -c "import torch; import os; print(os.path.join(torch.__path__[0], 'lib'))"):$LD_LIBRARY_PATH
+
+export VLLM_USE_DEEP_GEMM=1
+
+# ============================================================================
+# NETWORK CONFIGURATION
+# ============================================================================
+
+# For InfiniBand/EFA clusters: Prevent initialization hangs
+# This ensures torch distributed uses Ethernet for initial setup
+# Find your network interface: ip addr show | grep -E 'eth|enp'
+export GLOO_SOCKET_IFNAME=enp71s0  # Change to your primary network interface
+export NCCL_SOCKET_IFNAME=enp71s0  # Must match the interface above (used by NCCL)
+export TP_SOCKET_IFNAME=enp71s0    # Must match the interface above (used for tensor parallel)
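+
+# (Illustrative sketch, not required) If you are unsure which interface to use,
+# one way to find the interface that carries the default route is:
+#   ip route show default | awk '{print $5; exit}'
+# Verify the result matches the Ethernet/EFA interface you intend before
+# putting it into the three exports above.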
+
+# ============================================================================
+# NCCL CONFIGURATION (Optional - for advanced users)
+# ============================================================================
+
+# AWS EFA NCCL plugin (comment out if not using AWS EFA):
+export NCCL_NET_PLUGIN="/opt/amazon/ofi-nccl/lib/x86_64-linux-gnu/libnccl-net.so"
+
+# NCCL performance tuning (optional):
+export NCCL_P2P_NET_CHUNKSIZE=524288
+export NCCL_BUFFSIZE=8388608
+
+# NCCL debugging (for diagnosing connection issues):
+# export NCCL_DEBUG=INFO
+# export NCCL_DEBUG_SUBSYS=INIT,NET
+
+# https://github.com/vllm-project/vllm/pull/27444
+export VLLM_ENGINE_READY_TIMEOUT_S=3600
+# Set to local non-shared disk like "/opt/dlami/nvme"
+export DG_JIT_CACHE_DIR="/local_storage"
+
+# ============================================================================
+# ARGUMENTS PARSING
+# ============================================================================
+
+NODE1_IP="$1"                               # Node 0 IP address (REQUIRED)
+RPC_PORT="${2:-13345}"                      # RPC communication port
+MODEL="${3:-deepseek-ai/DeepSeek-V3-0324}"  # Model to serve
+BACKEND="${4:-allgather_reducescatter}"     # Backend to use
+TOTAL_DP_SIZE="${5:-16}"                    # Total DP size across all nodes
+LOCAL_DP_SIZE="${6:-8}"                     # Local DP size on this node
+LOCAL_TP_SIZE="${7:-1}"                     # Local TP size on this node
+API_SERVERS="${8:-8}"                       # Number of API servers
+
+# Recommendations:
+# - TOTAL_DP_SIZE = LOCAL_DP_SIZE * NUMBER_OF_NODES
+# - LOCAL_DP_SIZE = Number of GPUs per node (typically 8 for 8xGPU nodes)
+# - API_SERVERS = LOCAL_DP_SIZE (one server per local DP process)
+
+# ============================================================================
+# CONFIGURATION SUMMARY
+# ============================================================================
+
+echo ""
+echo "╔═══════════════════════════════════════════════════════════════╗"
+echo "║              vLLM Expert Parallel Configuration                ║"
+echo "╚═══════════════════════════════════════════════════════════════╝"
+echo ""
+echo "Backend Configuration:"
+echo "  • Backend: ${BACKEND}"
+echo "  • DeepGEMM: Enabled"
+echo ""
+echo "Node Configuration:"
+echo "  • Role: Primary (handles API requests)"
+echo "  • Model: ${MODEL}"
+echo "  • Node IP: ${NODE1_IP}"
+echo "  • RPC Port: ${RPC_PORT}"
+echo ""
+echo "Parallelism Configuration:"
+echo "  • Total Data Parallel Size: ${TOTAL_DP_SIZE} (across all nodes)"
+echo "  • Local Data Parallel Size: ${LOCAL_DP_SIZE} (this node)"
+echo "  • Local Tensor Parallel Size: ${LOCAL_TP_SIZE} (this node)"
+echo "  • API Servers: ${API_SERVERS}"
+echo "  • Expert Parallel: Enabled"
+echo ""
+echo "═══════════════════════════════════════════════════════════════"
+echo ""
+
+# ============================================================================
+# LAUNCH vLLM SERVER
+# ============================================================================
+
+vllm serve "${MODEL}" \
+  --enable-expert-parallel \
+  --all2all-backend "${BACKEND}" \
+  --tensor-parallel-size "${LOCAL_TP_SIZE}" \
+  --data-parallel-size "${TOTAL_DP_SIZE}" \
+  --data-parallel-size-local "${LOCAL_DP_SIZE}" \
+  --data-parallel-address "${NODE1_IP}" \
+  --data-parallel-rpc-port "${RPC_PORT}" \
+  --gpu-memory-utilization 0.8 \
+  --api-server-count="${API_SERVERS}"
+
+# Additional useful options (uncomment as needed):
+# --max-model-len 8192 \
+# --gpu-memory-utilization 0.9 \
+# --dtype auto \
+# --enable-chunked-prefill \
+# --port 8000 \
+
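+# Worked example (README's 2-node setup: TOTAL_DP=2, LOCAL_DP=1, TP=8, 1 API server):
+#   bash launch_vllm_head.sh 10.4.147.22 13345 deepseek-ai/DeepSeek-V3-0324 deepep_high_throughput 2 1 8 1
+# expands to roughly:
+#   vllm serve deepseek-ai/DeepSeek-V3-0324 --enable-expert-parallel \
+#     --all2all-backend deepep_high_throughput --tensor-parallel-size 8 \
+#     --data-parallel-size 2 --data-parallel-size-local 1 \
+#     --data-parallel-address 10.4.147.22 --data-parallel-rpc-port 13345 \
+#     --gpu-memory-utilization 0.8 --api-server-count=1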
diff --git a/ep/bench/vllm/launch_vllm_worker.sh b/ep/bench/vllm/launch_vllm_worker.sh
new file mode 100755
index 00000000..09c2437e
--- /dev/null
+++ b/ep/bench/vllm/launch_vllm_worker.sh
@@ -0,0 +1,136 @@
+#!/bin/bash
+# Node 1+ (Secondary) - Multi-node vLLM with Expert Parallel (EP)
+# This node runs in headless mode (no API server)
+#
+# Prerequisites: Same as Node 0
+# 1. Install vLLM with EP support
+# 2. Install DeepGEMM
+# 3. Install EP kernels
+# 4. For AWS EFA: Install AWS OFI-NCCL plugin
+#
+# IMPORTANT: All configuration must match Node 0!
+
+# Example:
+# bash launch_vllm_worker.sh 10.4.147.22 13345 deepseek-ai/DeepSeek-V3-0324 allgather_reducescatter 2 1 8 1
+
+set -e
+
+echo "🚀 Launching vLLM Secondary Node (Headless) with Expert Parallel..."
+
+# Check if primary node IP is provided
+if [ -z "$1" ]; then
+  echo "❌ Error: Primary Node (Node 0) IP address is required!"
+  echo ""
+  echo "Usage: $0 <NODE0_IP> [RPC_PORT] [MODEL] [BACKEND] [TOTAL_DP_SIZE] [LOCAL_DP_SIZE] [LOCAL_TP_SIZE] [START_RANK]"
+  echo ""
+  echo "Example:"
+  echo "  $0 10.1.107.86 13345 deepseek-ai/DeepSeek-V3-0324 deepep_high_throughput 16 8 1 8"
+  echo ""
+  echo "⚠️  Note: Use Node 0's IP address, not this node's IP!"
+  echo "💡 To find Node 0's IP, run on Node 0: hostname -I"
+  exit 1
+fi
+
+# PyTorch library path (required for DeepGEMM)
+export LD_LIBRARY_PATH=$(python3 -c "import torch; import os; print(os.path.join(torch.__path__[0], 'lib'))"):$LD_LIBRARY_PATH
+
+export VLLM_USE_DEEP_GEMM=1
+
+# ============================================================================
+# NETWORK CONFIGURATION
+# ============================================================================
+
+# CRITICAL: Must match Node 0 exactly!
+# For InfiniBand/EFA clusters
+export GLOO_SOCKET_IFNAME=enp71s0  # Change to your primary network interface
+export NCCL_SOCKET_IFNAME=enp71s0  # Must match the interface above (used by NCCL)
+export TP_SOCKET_IFNAME=enp71s0    # Must match the interface above (used for tensor parallel)
+
+# ============================================================================
+# NCCL CONFIGURATION (Optional)
+# ============================================================================
+# CRITICAL: Must match Node 0 exactly!
+
+# AWS EFA NCCL plugin (comment out if not using AWS EFA):
+export NCCL_NET_PLUGIN="/opt/amazon/ofi-nccl/lib/x86_64-linux-gnu/libnccl-net.so"
+
+# NCCL performance tuning (optional):
+export NCCL_P2P_NET_CHUNKSIZE=524288
+export NCCL_BUFFSIZE=8388608
+
+# NCCL debugging (for diagnosing connection issues):
+# export NCCL_DEBUG=INFO
+# export NCCL_DEBUG_SUBSYS=INIT,NET
+
+# https://github.com/vllm-project/vllm/pull/27444
+export VLLM_ENGINE_READY_TIMEOUT_S=3600
+# Set to local non-shared disk like "/opt/dlami/nvme"
+export DG_JIT_CACHE_DIR="/local_storage"
+
+# ============================================================================
+# ARGUMENTS PARSING
+# ============================================================================
+
+NODE1_IP="$1"                               # Primary node IP (REQUIRED)
+RPC_PORT="${2:-13345}"                      # Same RPC port as Node 0
+MODEL="${3:-deepseek-ai/DeepSeek-V3-0324}"  # Same model as Node 0
+BACKEND="${4:-allgather_reducescatter}"     # Backend to use
+TOTAL_DP_SIZE="${5:-16}"                    # Same total DP as Node 0
+LOCAL_DP_SIZE="${6:-8}"                     # Local DP on this node
+LOCAL_TP_SIZE="${7:-1}"                     # Local TP on this node
+START_RANK="${8:-8}"                        # Starting rank offset
+
+# START_RANK calculation:
+# - Node 1: LOCAL_DP_SIZE of Node 0 (e.g., 8)
+# - Node 2: LOCAL_DP_SIZE of Node 0 + Node 1 (e.g., 16)
+# - Node N: Sum of all previous nodes' LOCAL_DP_SIZE
+
+# ============================================================================
+# CONFIGURATION SUMMARY
+# ============================================================================
+
+echo ""
+echo "╔═══════════════════════════════════════════════════════════════╗"
+echo "║           vLLM Expert Parallel - Secondary Node Config         ║"
+echo "╚═══════════════════════════════════════════════════════════════╝"
+echo ""
+echo "Backend Configuration:"
+echo "  • Backend: ${BACKEND}"
+echo "  • DeepGEMM: Enabled"
+echo ""
+echo "Node Configuration:"
+echo "  • Role: Secondary (headless worker)"
+echo "  • Model: ${MODEL}"
+echo "  • Primary Node IP: ${NODE1_IP}"
+echo "  • RPC Port: ${RPC_PORT}"
+echo ""
+echo "Parallelism Configuration:"
+echo "  • Total Data Parallel Size: ${TOTAL_DP_SIZE} (across all nodes)"
+echo "  • Local Data Parallel Size: ${LOCAL_DP_SIZE} (this node)"
+echo "  • Local Tensor Parallel Size: ${LOCAL_TP_SIZE} (this node)"
+echo "  • Starting Rank: ${START_RANK}"
+echo "  • Expert Parallel: Enabled"
+echo ""
+echo "═══════════════════════════════════════════════════════════════"
+echo ""
+
+# ============================================================================
+# LAUNCH vLLM SERVER (HEADLESS MODE)
+# ============================================================================
+
+vllm serve "${MODEL}" \
+  --enable-expert-parallel \
+  --all2all-backend "${BACKEND}" \
+  --tensor-parallel-size "${LOCAL_TP_SIZE}" \
+  --data-parallel-size "${TOTAL_DP_SIZE}" \
+  --data-parallel-size-local "${LOCAL_DP_SIZE}" \
+  --data-parallel-start-rank "${START_RANK}" \
+  --data-parallel-address "${NODE1_IP}" \
+  --data-parallel-rpc-port "${RPC_PORT}" \
+  --gpu-memory-utilization 0.8 \
+  --headless
+
+# Additional useful options (uncomment as needed, must match Node 0):
+# --max-model-len 8192 \
+# --gpu-memory-utilization 0.9 \
+# --dtype auto \
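+
+# Worked example (README's 2-node setup: TOTAL_DP=2, LOCAL_DP=1, TP=8):
+#   bash launch_vllm_worker.sh 10.4.147.22 13345 deepseek-ai/DeepSeek-V3-0324 deepep_high_throughput 2 1 8 1
+# Here the final argument (START_RANK=1) is the sum of the previous nodes'
+# LOCAL_DP_SIZE values; with homogeneous nodes this is simply
+#   START_RANK = NODE_INDEX * LOCAL_DP_SIZE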
diff --git a/ep/deep_ep_wrapper/README.md b/ep/deep_ep_wrapper/README.md
index 0e2d6ac7..9ea84f5c 100644
--- a/ep/deep_ep_wrapper/README.md
+++ b/ep/deep_ep_wrapper/README.md
@@ -1,7 +1,14 @@
 ## DeepEP Wrapper of UCCL-EP
 
+First build and install UCCL-EP:
+```bash
+cd uccl
+bash build.sh cuda ep
+uv pip install wheelhouse-cuda/uccl-*.whl
 ```
+
+Then install UCCL-EP's drop-in replacement for DeepEP:
+```bash
+cd ep/deep_ep_wrapper
 python setup.py install
 ```
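+
+To sanity-check the installation (assuming the wrapper exposes the same `deep_ep` module name as the original DeepEP, as a drop-in replacement implies):
+```bash
+python -c "import deep_ep; print(deep_ep.__file__)"
+```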
-
-pip install ../../wheelhouse-cuda/uccl-0.0.1.post4-py3-none-any.whl
\ No newline at end of file
diff --git a/ep/src/internode.cu b/ep/src/internode.cu
index dd37ea1f..c7011a36 100644
--- a/ep/src/internode.cu
+++ b/ep/src/internode.cu
@@ -274,11 +274,16 @@ __global__ void notify_dispatch(
                    i)[NUM_MAX_NVL_PEERS + num_rdma_experts];
       recv_rdma_rank_prefix_sum[i] = sum;
     }
-    if (num_worst_tokens == 0) {
-      while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1)
-        ;
-      *moe_recv_rdma_counter_mapped = sum;
-    }
+    // NOTE(MaoZiming): wrapping this block in `if (num_worst_tokens == 0)`
+    // somehow deadlocks vLLM in deepep_high_throughput mode; I suspect
+    // compiler reordering, but the root cause is unknown. num_worst_tokens
+    // is 0 here anyway, so dropping the conditional does not change the
+    // behavior and is harmless. The original guard is kept, commented out:
+    // if (num_worst_tokens == 0) {
+    while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1)
+      ;
+    *moe_recv_rdma_counter_mapped = sum;
+    // }
   }
 
   // Send numbers of tokens per rank/expert to NVL ranks
@@ -306,11 +311,11 @@
       sum += nvl_recv_num_tokens_per_rank.buffer(src_nvl_rank)[src_rdma_rank];
       recv_gbl_rank_prefix_sum[i] = sum;
     }
-    if (num_worst_tokens == 0) {
-      while (ld_volatile_global(moe_recv_counter_mapped) != -1)
-        ;
-      *moe_recv_counter_mapped = sum;
-    }
+    // if (num_worst_tokens == 0) {
+    while (ld_volatile_global(moe_recv_counter_mapped) != -1)
+      ;
+    *moe_recv_counter_mapped = sum;
+    // }
   }
   if (thread_id < num_nvl_experts) {
     int sum = 0;
@@ -318,12 +323,12 @@
     for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i)
      sum += nvl_recv_num_tokens_per_expert.buffer(i)[thread_id];
     sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment;
-    if (num_worst_tokens == 0) {
-      while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) !=
-             -1)
-        ;
-      moe_recv_expert_counter_mapped[thread_id] = sum;
-    }
+    // if (num_worst_tokens == 0) {
+    while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) !=
+           -1)
+      ;
+    // }
+    moe_recv_expert_counter_mapped[thread_id] = sum;
   }
 
   // Finally barrier
@@ -1610,7 +1615,7 @@ __global__ void cached_notify(
     // Barrier for RDMA
     if (thread_id == WARP_SIZE)
       uccl::nvshmem_sync_with_same_gpu_idx(d2h_channel_addrs,
-                                           num_d2h_channel_addrs, nvl_rank, 3);
+                                           num_d2h_channel_addrs, nvl_rank);
 
     // Barrier for NVL
     barrier_block(barrier_signal_ptrs, nvl_rank);
diff --git a/ep/src/proxy.cpp b/ep/src/proxy.cpp
index 7248f030..dec8deba 100644
--- a/ep/src/proxy.cpp
+++ b/ep/src/proxy.cpp
@@ -864,6 +864,7 @@ void Proxy::post_gpu_commands_mixed(
 #ifdef USE_MSCCLPP_FIFO_BACKEND
     assert(barrier_wrs.size() == 1 && ctx_.barrier_wr == -1);
 #endif
+    assert(quiet_wrs.empty() && "quiet_wrs should be empty");
     send_barrier(barrier_wrs[0]);
     barrier_wrs.clear();
     barrier_cmds.clear();
diff --git a/ep/src/uccl_ep.cc b/ep/src/uccl_ep.cc
index 9ff983ea..81bf6787 100644
--- a/ep/src/uccl_ep.cc
+++ b/ep/src/uccl_ep.cc
@@ -611,7 +611,6 @@ class Buffer {
                               num_ranks),
         num_nvl_bytes, low_latency_mode, d_handles, num_d2h_channel_addrs,
         atomic_buffer_ptr);
-
     // Synchronize total received tokens and tokens per expert
     if (num_worst_tokens > 0) {
       num_recv_tokens = num_worst_tokens;
@@ -880,7 +879,6 @@ class Buffer {
         config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
         num_nvl_bytes, false, low_latency_mode, d_handles, num_d2h_channel_addrs,
         atomic_buffer_ptr);
-
     // Assign bias pointers
     auto bias_opts =
         std::vector<std::optional<torch::Tensor>>({bias_0, bias_1});