diff --git a/build.sh b/build.sh index 335dfd985..248fea9a1 100755 --- a/build.sh +++ b/build.sh @@ -145,6 +145,13 @@ build_p2p() { else echo "[container] USE_TCPX=1, skipping copying p2p runtime files" fi + if [[ "$TARGET" == rocm* ]]; then + cd thirdparty/dietgpu + rm -rf build/ + python3 setup.py build + cd ../.. + cp thirdparty/dietgpu/build/**/*.so uccl/ + fi } build_ep() { @@ -271,9 +278,10 @@ echo "[2/3] Running build inside container..." # Auto-detect CUDA architecture for ep build DETECTED_GPU_ARCH="" -if [[ "$BUILD_TYPE" =~ (ep|all) ]];then +if [[ "$BUILD_TYPE" =~ (ep|all|p2p) ]];then if [[ "$TARGET" == cuda* ]] && command -v nvidia-smi &> /dev/null; then - DETECTED_GPU_ARCH=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null | head -n1 | tr -d ' ') + DETECTED_GPU_ARCH="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null | head -n1 | tr -d ' ' || true)" + if [[ -n "$DETECTED_GPU_ARCH" ]]; then echo "Auto-detected CUDA compute capability: ${DETECTED_GPU_ARCH}" fi @@ -283,10 +291,23 @@ if [[ "$BUILD_TYPE" =~ (ep|all) ]];then echo "jq not found, installing via pip..." pip install jq fi - DETECTED_GPU_ARCH=$(amd-smi static -g 0 --asic --json | jq -r '.[].asic.target_graphics_version') - if [[ -n "$DETECTED_GPU_ARCH" ]]; then - echo "Auto-detected ROCm architecture: ${DETECTED_GPU_ARCH}" + DETECTED_GPU_ARCH="$( + PYTHONWARNINGS=ignore \ + amd-smi static -g 0 --asic --json 2>/dev/null \ + | jq -r ' + if .gpu_data and (.gpu_data | length > 0) then + .gpu_data[0].asic.target_graphics_version + else + empty + end + ' \ + || true + )" + if [[ -n "$DETECTED_GPU_ARCH" ]]; then + echo "Auto-detected ROCm architecture: ${DETECTED_GPU_ARCH}" fi + else + echo "[INFO] No compatible GPU detection tool found, skipping auto-detect" fi fi @@ -306,6 +327,7 @@ docker run --rm --user "$(id -u):$(id -g)" \ -e BUILD_TYPE="${BUILD_TYPE}" \ -e USE_TCPX="${USE_TCPX:-0}" \ -e USE_EFA="${USE_EFA:-0}" \ + -e USE_IB="${USE_IB:-0}" \ -e MAKE_NORMAL_MODE="${MAKE_NORMAL_MODE:-}" \ -e TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-}" \ -e FUNCTION_DEF="$(declare -f build_rccl_nccl_h build_ccl_rdma build_ccl_efa build_p2p build_ep build_eccl)" \ @@ -397,6 +419,7 @@ def initialize(): --exclude "libcudart.so.12" \ --exclude "libamdhip64.so.*" \ --exclude "libcuda.so.1" \ + --exclude "libefa.so.1" \ -w /io/${WHEEL_DIR} # Add backend tag to wheel filename using local version identifier diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index 6114cf49a..af00471a3 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -14,7 +14,8 @@ RUN apt-get update && \ rdma-core libibverbs-dev libnuma-dev \ libgoogle-glog-dev libgflags-dev libgtest-dev libelf-dev \ pkg-config zlib1g-dev curl unzip \ - software-properties-common && \ + software-properties-common \ + hipcub && \ \ # ───── Add Python ${PY_VER} PPA & install Python ${PY_VER} + setuptools ───── add-apt-repository ppa:deadsnakes/ppa && \ @@ -40,7 +41,7 @@ RUN apt-get update && \ RUN python${PY_VER} -m pip install --no-cache-dir build auditwheel pybind11 RUN python${PY_VER} -m pip install --no-cache-dir --pre torch torchvision \ - --index-url https://download.pytorch.org/whl/nightly/rocm7.0 + --index-url https://download.pytorch.org/whl/nightly/rocm7.1 RUN python${PY_VER} -m pip install --no-cache-dir --upgrade setuptools diff --git a/ep/bench/buffer.py b/ep/bench/buffer.py index bf12b2d5b..b49a528f9 100644 --- a/ep/bench/buffer.py +++ b/ep/bench/buffer.py @@ -679,9 +679,6 @@ def dispatch( # Internode if self.runtime.get_num_rdma_ranks() > 1: - assert ( - num_worst_tokens == 0 - ), "Internode dispatch does not support `num_worst_tokens > 0`" return self.internode_dispatch( x, handle, @@ -692,6 +689,7 @@ def dispatch( topk_idx, topk_weights, expert_alignment, + num_worst_tokens, config, previous_event, async_finish, @@ -881,6 +879,7 @@ def internode_dispatch( topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, expert_alignment: int = 1, + num_worst_tokens: int = 0, config: Optional[Config] = None, previous_event: Optional[EventOverlap] = None, async_finish: bool = False, @@ -934,6 +933,7 @@ def internode_dispatch( gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, expert_alignment, + num_worst_tokens, config, getattr(previous_event, "event", None), async_finish, @@ -986,6 +986,7 @@ def internode_dispatch( None, None, expert_alignment, + num_worst_tokens, config, getattr(previous_event, "event", None), async_finish, diff --git a/ep/bench/run_ep.sh b/ep/bench/run_ep.sh index 024bdff7b..2db0346f3 100755 --- a/ep/bench/run_ep.sh +++ b/ep/bench/run_ep.sh @@ -12,10 +12,15 @@ if [ "$MODE" = "ll" ]; then --master_addr=$MAIN_IP --master_port=12355 \ test_low_latency.py --num-tokens=128 \ --hidden=7168 --num-topk=8 --num-experts=288 -else +elif [ "$MODE" = "ht" ]; then torchrun --nnodes=$NNODES --nproc_per_node=8 --node_rank=$RANK \ --master_addr=$MAIN_IP --master_port=12355 \ test_internode.py --num-tokens=4096 \ --hidden=7168 --num-topk=8 --num-experts=288 --test-ll-compatibility +else + torchrun --nnodes=$NNODES --nproc_per_node=8 --node_rank=$RANK \ + --master_addr=$MAIN_IP --master_port=12355 \ + test_internode.py --num-tokens=4096 \ + --hidden=7168 --num-topk=8 --num-experts=256 --pressure-test-mode=1 fi # --log-dir=logs --redirect=3 \ No newline at end of file diff --git a/ep/deep_ep_wrapper/scripts/GLM-4.6_nccl.sh b/ep/bench/sglang/GLM-4.6_nccl.sh similarity index 100% rename from ep/deep_ep_wrapper/scripts/GLM-4.6_nccl.sh rename to ep/bench/sglang/GLM-4.6_nccl.sh diff --git a/ep/deep_ep_wrapper/scripts/GLM-4.6_uep.sh b/ep/bench/sglang/GLM-4.6_uep.sh similarity index 100% rename from ep/deep_ep_wrapper/scripts/GLM-4.6_uep.sh rename to ep/bench/sglang/GLM-4.6_uep.sh diff --git a/ep/deep_ep_wrapper/scripts/Kimi_uep.sh b/ep/bench/sglang/Kimi_uep.sh similarity index 100% rename from ep/deep_ep_wrapper/scripts/Kimi_uep.sh rename to ep/bench/sglang/Kimi_uep.sh diff --git a/ep/deep_ep_wrapper/scripts/Qwen3-235B_nccl.sh b/ep/bench/sglang/Qwen3-235B_nccl.sh similarity index 100% rename from ep/deep_ep_wrapper/scripts/Qwen3-235B_nccl.sh rename to ep/bench/sglang/Qwen3-235B_nccl.sh diff --git a/ep/deep_ep_wrapper/scripts/Qwen3-235B_uep.sh b/ep/bench/sglang/Qwen3-235B_uep.sh similarity index 100% rename from ep/deep_ep_wrapper/scripts/Qwen3-235B_uep.sh rename to ep/bench/sglang/Qwen3-235B_uep.sh diff --git a/ep/deep_ep_wrapper/scripts/Qwen3-30B_nccl.sh b/ep/bench/sglang/Qwen3-30B_nccl.sh similarity index 100% rename from ep/deep_ep_wrapper/scripts/Qwen3-30B_nccl.sh rename to ep/bench/sglang/Qwen3-30B_nccl.sh diff --git a/ep/deep_ep_wrapper/scripts/Qwen3-30B_uep.sh b/ep/bench/sglang/Qwen3-30B_uep.sh similarity index 100% rename from ep/deep_ep_wrapper/scripts/Qwen3-30B_uep.sh rename to ep/bench/sglang/Qwen3-30B_uep.sh diff --git a/ep/deep_ep_wrapper/scripts/common_deepep_config.sh b/ep/bench/sglang/common_deepep_config.sh similarity index 100% rename from ep/deep_ep_wrapper/scripts/common_deepep_config.sh rename to ep/bench/sglang/common_deepep_config.sh diff --git a/ep/deep_ep_wrapper/scripts/common_env.sh b/ep/bench/sglang/common_env.sh similarity index 100% rename from ep/deep_ep_wrapper/scripts/common_env.sh rename to ep/bench/sglang/common_env.sh diff --git a/ep/deep_ep_wrapper/scripts/common_launch.sh b/ep/bench/sglang/common_launch.sh similarity index 100% rename from ep/deep_ep_wrapper/scripts/common_launch.sh rename to ep/bench/sglang/common_launch.sh diff --git a/ep/deep_ep_wrapper/scripts/deepseek_r1_bf16_nccl.sh b/ep/bench/sglang/deepseek_r1_bf16_nccl.sh similarity index 100% rename from ep/deep_ep_wrapper/scripts/deepseek_r1_bf16_nccl.sh rename to ep/bench/sglang/deepseek_r1_bf16_nccl.sh diff --git a/ep/deep_ep_wrapper/scripts/deepseek_r1_nccl.sh b/ep/bench/sglang/deepseek_r1_nccl.sh similarity index 100% rename from ep/deep_ep_wrapper/scripts/deepseek_r1_nccl.sh rename to ep/bench/sglang/deepseek_r1_nccl.sh diff --git a/ep/deep_ep_wrapper/scripts/deepseek_r1_uep_ll.sh b/ep/bench/sglang/deepseek_r1_uep_ll.sh similarity index 100% rename from ep/deep_ep_wrapper/scripts/deepseek_r1_uep_ll.sh rename to ep/bench/sglang/deepseek_r1_uep_ll.sh diff --git a/ep/deep_ep_wrapper/scripts/deepseek_v3_nccl.sh b/ep/bench/sglang/deepseek_v3_nccl.sh similarity index 100% rename from ep/deep_ep_wrapper/scripts/deepseek_v3_nccl.sh rename to ep/bench/sglang/deepseek_v3_nccl.sh diff --git a/ep/deep_ep_wrapper/scripts/launch_sglang_docker.sh b/ep/bench/sglang/launch_sglang_docker.sh similarity index 100% rename from ep/deep_ep_wrapper/scripts/launch_sglang_docker.sh rename to ep/bench/sglang/launch_sglang_docker.sh diff --git a/ep/deep_ep_wrapper/scripts/sglang_nccl_44000_prefill.sh b/ep/bench/sglang/sglang_nccl_44000_prefill.sh similarity index 100% rename from ep/deep_ep_wrapper/scripts/sglang_nccl_44000_prefill.sh rename to ep/bench/sglang/sglang_nccl_44000_prefill.sh diff --git a/ep/deep_ep_wrapper/scripts/sglang_uep_46000_prefill.sh b/ep/bench/sglang/sglang_uep_46000_prefill.sh similarity index 100% rename from ep/deep_ep_wrapper/scripts/sglang_uep_46000_prefill.sh rename to ep/bench/sglang/sglang_uep_46000_prefill.sh diff --git a/ep/deep_ep_wrapper/scripts/sglang_uep_75000_prefill.sh b/ep/bench/sglang/sglang_uep_75000_prefill.sh similarity index 100% rename from ep/deep_ep_wrapper/scripts/sglang_uep_75000_prefill.sh rename to ep/bench/sglang/sglang_uep_75000_prefill.sh diff --git a/ep/bench/test_internode.py b/ep/bench/test_internode.py index 3ea106ce7..3622470ba 100644 --- a/ep/bench/test_internode.py +++ b/ep/bench/test_internode.py @@ -252,6 +252,7 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): ) if current_x is not x_pure_rand: check_data(recv_x, recv_gbl_rank_prefix_sum) + recv_topk_weights_clone = None if with_topk: # Check `topk_idx` assert ( @@ -265,6 +266,7 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): assert recv_topk_idx.eq(i).sum().item() == count # Check `topk_weights` + recv_topk_weights_clone = recv_topk_weights.clone() if current_x is not x_pure_rand: recv_topk_weights[recv_topk_idx.eq(-1)] = ( recv_topk_weights.amax(dim=1, keepdim=True).expand_as( @@ -273,6 +275,40 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): ) check_data(recv_topk_weights, recv_gbl_rank_prefix_sum) + # Test `num_worst_tokens != 0` + if with_topk: + num_worst_tokens = num_tokens * num_ranks + dispatch_args.update({"num_worst_tokens": num_worst_tokens}) + ( + recv_worst_x, + recv_worst_topk_idx, + recv_worst_topk_weights, + empty_list, + _, + event, + ) = buffer.dispatch(**dispatch_args) + event.current_stream_wait() if async_mode else () + recv_worst_x = ( + per_token_cast_back(*recv_worst_x) + if isinstance(recv_worst_x, tuple) + else recv_worst_x + ) + assert len(empty_list) == 0 + assert num_worst_tokens == recv_worst_x.size(0) + assert num_worst_tokens == recv_worst_topk_idx.size(0) + assert num_worst_tokens == recv_worst_topk_weights.size(0) + assert torch.equal(recv_x, recv_worst_x[: recv_x.size(0)]) + assert torch.equal( + recv_topk_idx, recv_worst_topk_idx[: recv_x.size(0)] + ) + assert torch.equal( + recv_topk_weights_clone, + recv_worst_topk_weights[: recv_x.size(0)], + ) + assert torch.all( + recv_worst_topk_idx[recv_x.size(0) :] == -1 + ).item() + # Test cached dispatch (must without top-k staffs) if not with_topk: dispatch_args = { @@ -502,7 +538,7 @@ def test_loop( assert num_local_ranks == 8 and num_ranks > 8 - for seed in range(int(1e9)): + for seed in range(0, int(1e9)): if local_rank == 0: print(f"Testing with seed {seed} ...", flush=True) torch.manual_seed(rank + seed) diff --git a/ep/bench/utils.py b/ep/bench/utils.py index 26c493212..0826a49c6 100644 --- a/ep/bench/utils.py +++ b/ep/bench/utils.py @@ -89,53 +89,10 @@ def init_dist_under_torchrun(local_rank: int, num_local_ranks: int): ) -def _discover_local_ip(): - # Try to infer the IP that can reach MASTER_ADDR (works in most clusters) - import socket, os - - # Method 1: Use MASTER_ADDR if available (torchrun style) - if "MASTER_ADDR" in os.environ: - master = os.environ["MASTER_ADDR"] - port = int(os.environ.get("MASTER_PORT", "29500")) - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - try: - s.connect((master, port)) - return s.getsockname()[0] - except: - pass - finally: - s.close() - - # Method 2: Use hostname resolution (works in AWS and most cloud environments) - hostname = socket.gethostname() - try: - # This usually returns the private IP in cloud environments - local_ip = socket.gethostbyname(hostname) - # Skip loopback addresses - if not local_ip.startswith("127."): - return local_ip - except: - pass - - # Method 3: Connect to a public DNS to determine outgoing interface - try: - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - # Google DNS - this doesn't actually send packets - s.connect(("8.8.8.8", 80)) - local_ip = s.getsockname()[0] - s.close() - return local_ip - except: - pass - - # Last resort: return localhost - return "127.0.0.1" - - def _gather_peer_ips(group): # Gather local IP strings across ranks world = dist.get_world_size(group) - my_ip = _discover_local_ip() + my_ip = ep.get_oob_ip() ips = [None] * world dist.all_gather_object(ips, my_ip, group=group) return ips @@ -154,7 +111,7 @@ def get_peer_ip(rank: int, num_ranks: int, group: dist.ProcessGroup): def get_cpu_proxies_meta(proxies, rank, scratch_ptr, scratch_bytes, num_ranks, group): - my_ip = _discover_local_ip() + my_ip = ep.get_oob_ip() meta = { "rank": rank, "ptr": int(scratch_ptr), diff --git a/ep/bench/vllm/README.md b/ep/bench/vllm/README.md new file mode 100644 index 000000000..8dd649803 --- /dev/null +++ b/ep/bench/vllm/README.md @@ -0,0 +1,122 @@ +# vLLM + UCCL-EP Multi-Node Expert Parallel Deployment Guide + +This guide provides example scripts and instructions for deploying vLLM with Expert Parallelism (EP) across multiple nodes on AWS p5en. + +## 🚀 Installation + +### 0. Install uv + +```bash +curl -LsSf https://astral.sh/uv/install.sh | sh +uv venv +source .venv/bin/activate +uv pip install numpy torch setuptools +``` + +### 1. Install vLLM with EP Support + +Follow the official guide: +```bash +# Install vLLM: latest version with timeout fix (https://github.com/vllm-project/vllm/pull/27444) +git clone https://github.com/vllm-project/vllm.git +cd vllm +# This may take 5-10 minutes. +uv pip install -e . +``` + +For detailed EP setup, refer to [vLLM Expert Parallel Deployment](https://docs.vllm.ai/en/stable/serving/expert_parallel_deployment.html) + +### 2. Install DeepGEMM Library + +DeepGEMM provides optimized kernels for MoE operations: + +```bash +# Clone and install DeepGEMM +git clone --recursive https://github.com/deepseek-ai/DeepGEMM.git +cd DeepGEMM +cat install.sh +# cuobjdump used by https://github.com/deepseek-ai/DeepGEMM/blob/9b680f428484625f4f35dc3617f134187c6bcd4a/csrc/jit/kernel_runtime.hpp#L44 +# If you could not find cuobjdump in your servers, install it by: +sudo apt install nvidia-cuda-toolkit -y +# If your server's cuobjdump is under /bin instead of $CUDA_HOME/bin, set soft link to make DeepGEMM happy: +sudo ln -s /bin/cuobjdump /usr/local/cuda/bin/cuobjdump +./install.sh +uv pip install dist/*.whl --force-reinstall +``` + +Refer to [DeepGEMM Installation Guide](https://github.com/deepseek-ai/DeepGEMM#installation), if hitting any issues. + +### 3. Install EP Kernels + +Refer to [../../deep_ep_wrapper/README.md](../../deep_ep_wrapper/README.md) to install UCCL-EP's drop-in replacement for DeepEP. + +Refer to vLLM's guide for the original DeepEP and pplx-kernels setup. + +### 4. (Optional) AWS EFA Setup + +For AWS instances with EFA, install AWS OFI-NCCL plugin, which is pre-installed on AWS Deep Learning AMIs + +## ⚙️ Configuration + +### Network Interface Detection + +Find your network interface and IP: + +```bash +# List all network interfaces +ip addr show + +# Common interface names: +# - eth0, eno1, enp0s3 (Ethernet) +# - enp74s0, ens5 (Custom/AWS EFA) +``` + +### Backend Selection + +vLLM provides three EP communication backends: + +| Backend | Use Case | Features | Best For | +|---------|----------|----------|----------| +| `pplx` | Single node | Chunked prefill support | Development, intra-node | +| `deepep_high_throughput` | Multi-node prefill | Grouped GEMM | High throughput, prefill-dominated | +| `deepep_low_latency` | Multi-node decode | CUDA graph support | Low latency, decode-dominated | +| `allgather_reducescatter` | Multi-node | NCCL-based | InfiniBand/EFA networks | + +### Environment Setup + +Edit the provided scripts (`launch_vllm_head.sh` and `launch_vllm_worker.sh`) to configure: + +1. **Network interfaces** - Set `GLOO_SOCKET_IFNAME`, `NCCL_SOCKET_IFNAME` +1. **Backend** - Choose appropriate `VLLM_ALL2ALL_BACKEND` +1. **Model storage** - Set `HF_HOME` to some folder with large storage +1. **DeepGEMM JIT cache** - Set `DG_JIT_CACHE_DIR` to some non-shared folder on each node + + +## 🚢 Deployment + +### Step 1: Start Node 0 (Primary) + +On the **first node** (primary node that handles API requests): + +```bash +bash launch_vllm_head.sh 10.4.147.22 13345 deepseek-ai/DeepSeek-V3-0324 deepep_high_throughput 2 1 8 1 +``` + +### Step 2: Start Node 1+ (Secondary) + +On **each additional node** (secondary nodes in headless mode): + +```bash +# Launch Node 1 (headless) +bash launch_vllm_worker.sh 10.4.147.22 13345 deepseek-ai/DeepSeek-V3-0324 deepep_high_throughput 2 1 8 1 +``` + +**Arguments:** +- `10.4.147.22` - IP address of **Node 0**, should be the IP of the `NCCL_SOCKET_IFNAME` +- `13345` - RPC port +- `deepseek-ai/DeepSeek-V3-0324` - Same model as Node 1 +- `allgather_reducescatter` - EP communication backend +- `2` - Total DP size +- `1` - Local DP size on this node +- `8` - Local TP size on this node +- `1` - For node 0, number of API servers; for others, starting rank (= sum of previous nodes' local DP) diff --git a/ep/bench/vllm/launch_vllm_head.sh b/ep/bench/vllm/launch_vllm_head.sh new file mode 100755 index 000000000..f8510a7ba --- /dev/null +++ b/ep/bench/vllm/launch_vllm_head.sh @@ -0,0 +1,135 @@ +#!/bin/bash +# Node 0 (Primary) - Multi-node vLLM with Expert Parallel (EP) +# This node handles incoming requests +# +# Prerequisites: +# 1. Install vLLM with EP support: https://docs.vllm.ai/en/stable/serving/expert_parallel_deployment.html#architecture-overview +# 2. Install DeepGEMM: https://github.com/deepseek-ai/DeepGEMM#installation +# 3. Install EP kernels: Follow vLLM's EP installation guide +# 4. For AWS EFA: Install AWS OFI-NCCL plugin + +# Example: +# bash launch_vllm_head.sh 10.4.147.22 13345 deepseek-ai/DeepSeek-V3-0324 deepep_low_latency 2 1 8 1 + +set -e + +echo "🚀 Launching vLLM Node 0 (Primary) with Expert Parallel..." + +# Check if IP is provided +if [ -z "$1" ]; then + echo "❌ Error: Node IP address is required!" + echo "" + echo "Usage: $0 [RPC_PORT] [MODEL] [TOTAL_DP_SIZE] [LOCAL_DP_SIZE] [API_SERVERS]" + echo "" + echo "Example:" + echo " $0 10.1.107.86 13345 deepseek-ai/DeepSeek-V3-0324 16 8 8" + echo "" + echo "💡 To find your IP address, run: hostname -I" + exit 1 +fi + +# PyTorch library path (required for DeepGEMM) +export LD_LIBRARY_PATH=$(python3 -c "import torch; import os; print(os.path.join(torch.__path__[0], 'lib'))"):$LD_LIBRARY_PATH + +export VLLM_USE_DEEP_GEMM=1 + +# ============================================================================ +# NETWORK CONFIGURATION +# ============================================================================ + +# For InfiniBand/EFA clusters: Prevent initialization hangs +# This ensures torch distributed uses Ethernet for initial setup +# Find your network interface: ip addr show | grep -E 'eth|enp' +export GLOO_SOCKET_IFNAME=enp71s0 # Change to your primary network interface +export NCCL_SOCKET_IFNAME=enp71s0 # Uncomment if using NCCL +export TP_SOCKET_IFNAME=enp71s0 # Uncomment if using tensor parallel + +# ============================================================================ +# NCCL CONFIGURATION (Optional - for advanced users) +# ============================================================================ + +# AWS EFA NCCL plugin (uncomment if using AWS EFA): +export NCCL_NET_PLUGIN="/opt/amazon/ofi-nccl/lib/x86_64-linux-gnu/libnccl-net.so" + +# NCCL performance tuning (optional): +export NCCL_P2P_NET_CHUNKSIZE=524288 +export NCCL_BUFFSIZE=8388608 + +# NCCL debugging (for diagnosing connection issues): +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=INIT,NET + +# https://github.com/vllm-project/vllm/pull/27444 +export VLLM_ENGINE_READY_TIMEOUT_S=3600 +# Set to local non-shared disk like "/opt/dlami/nvme" +export DG_JIT_CACHE_DIR="/local_storage" + +# ============================================================================ +# ARGUMENTS PARSING +# ============================================================================ + +NODE1_IP="$1" # Node 0 IP address (REQUIRED) +RPC_PORT="${2:-13345}" # RPC communication port +MODEL="${3:-deepseek-ai/DeepSeek-V3-0324}" # Model to serve +BACKEND="${4:-allgather_reducescatter}" # Backend to use +TOTAL_DP_SIZE="${5:-16}" # Total DP size across all nodes +LOCAL_DP_SIZE="${6:-8}" # Local DP size on this node +LOCAL_TP_SIZE="${7:-1}" # Local TP size on this node +API_SERVERS="${8:-8}" # Number of API servers + +# Recommendations: +# - TOTAL_DP_SIZE = LOCAL_DP_SIZE * NUMBER_OF_NODES +# - LOCAL_DP_SIZE = Number of GPUs per node (typically 8 for 8xGPU nodes) +# - API_SERVERS = LOCAL_DP_SIZE (one server per local DP process) + +# ============================================================================ +# CONFIGURATION SUMMARY +# ============================================================================ + +echo "" +echo "╔═══════════════════════════════════════════════════════════════╗" +echo "║ vLLM Expert Parallel Configuration ║" +echo "╚═══════════════════════════════════════════════════════════════╝" +echo "" +echo "Backend Configuration:" +echo " • Backend: ${BACKEND}" +echo " • DeepGEMM: Enabled" +echo "" +echo "Node Configuration:" +echo " • Role: Primary (handles API requests)" +echo " • Model: ${MODEL}" +echo " • Node IP: ${NODE1_IP}" +echo " • RPC Port: ${RPC_PORT}" +echo "" +echo "Parallelism Configuration:" +echo " • Total Data Parallel Size: ${TOTAL_DP_SIZE} (across all nodes)" +echo " • Local Data Parallel Size: ${LOCAL_DP_SIZE} (this node)" +echo " • Local Tensor Parallel Size: ${LOCAL_TP_SIZE} (this node)" +echo " • API Servers: ${API_SERVERS}" +echo " • Expert Parallel: Enabled" +echo "" +echo "═══════════════════════════════════════════════════════════════" +echo "" + +# ============================================================================ +# LAUNCH vLLM SERVER +# ============================================================================ + +vllm serve "${MODEL}" \ + --enable-expert-parallel \ + --all2all-backend "${BACKEND}" \ + --tensor-parallel-size "${LOCAL_TP_SIZE}" \ + --data-parallel-size "${TOTAL_DP_SIZE}" \ + --data-parallel-size-local "${LOCAL_DP_SIZE}" \ + --data-parallel-address "${NODE1_IP}" \ + --data-parallel-rpc-port "${RPC_PORT}" \ + --gpu-memory-utilization 0.8 \ + --api-server-count="${API_SERVERS}" + +# Additional useful options (uncomment as needed): +# --max-model-len 8192 \ +# --gpu-memory-utilization 0.9 \ +# --dtype auto \ +# --enable-chunked-prefill \ +# --port 8000 \ + diff --git a/ep/bench/vllm/launch_vllm_worker.sh b/ep/bench/vllm/launch_vllm_worker.sh new file mode 100755 index 000000000..09c2437ef --- /dev/null +++ b/ep/bench/vllm/launch_vllm_worker.sh @@ -0,0 +1,136 @@ +#!/bin/bash +# Node 1+ (Secondary) - Multi-node vLLM with Expert Parallel (EP) +# This node runs in headless mode (no API server) +# +# Prerequisites: Same as Node 0 +# 1. Install vLLM with EP support +# 2. Install DeepGEMM +# 3. Install EP kernels +# 4. For AWS EFA: Install AWS OFI-NCCL plugin +# +# IMPORTANT: All configuration must match Node 0! + +# Example: +# bash launch_vllm_worker.sh 10.4.147.22 13345 deepseek-ai/DeepSeek-V3-0324 allgather_reducescatter 2 1 8 1 + +set -e + +echo "🚀 Launching vLLM Secondary Node (Headless) with Expert Parallel..." + +# Check if primary node IP is provided +if [ -z "$1" ]; then + echo "❌ Error: Primary Node (Node 0) IP address is required!" + echo "" + echo "Usage: $0 [RPC_PORT] [MODEL] [TOTAL_DP_SIZE] [LOCAL_DP_SIZE] [START_RANK]" + echo "" + echo "Example:" + echo " $0 10.1.107.86 13345 deepseek-ai/DeepSeek-V3-0324 16 8 8" + echo "" + echo "⚠️ Note: Use Node 0's IP address, not this node's IP!" + echo "💡 To find Node 0's IP, run on Node 0: hostname -I" + exit 1 +fi + +# PyTorch library path (required for DeepGEMM) +export LD_LIBRARY_PATH=$(python3 -c "import torch; import os; print(os.path.join(torch.__path__[0], 'lib'))"):$LD_LIBRARY_PATH + +export VLLM_USE_DEEP_GEMM=1 + +# ============================================================================ +# NETWORK CONFIGURATION +# ============================================================================ + +# CRITICAL: Must match Node 0 exactly! +# For InfiniBand/EFA clusters +export GLOO_SOCKET_IFNAME=enp71s0 # Change to your primary network interface +export NCCL_SOCKET_IFNAME=enp71s0 # Uncomment if using NCCL +export TP_SOCKET_IFNAME=enp71s0 # Uncomment if using tensor parallel + +# ============================================================================ +# NCCL CONFIGURATION (Optional) +# ============================================================================ +# CRITICAL: Must match Node 0 exactly! + +# AWS EFA NCCL plugin (uncomment if using AWS EFA): +export NCCL_NET_PLUGIN="/opt/amazon/ofi-nccl/lib/x86_64-linux-gnu/libnccl-net.so" + +# NCCL performance tuning (optional): +export NCCL_P2P_NET_CHUNKSIZE=524288 +export NCCL_BUFFSIZE=8388608 + +# NCCL debugging (for diagnosing connection issues): +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=INIT,NET + +# https://github.com/vllm-project/vllm/pull/27444 +export VLLM_ENGINE_READY_TIMEOUT_S=3600 +# Set to local non-shared disk like "/opt/dlami/nvme" +export DG_JIT_CACHE_DIR="/local_storage" + +# ============================================================================ +# ARGUMENTS PARSING +# ============================================================================ + +NODE1_IP="$1" # Primary node IP (REQUIRED) +RPC_PORT="${2:-13345}" # Same RPC port as Node 0 +MODEL="${3:-deepseek-ai/DeepSeek-V3-0324}" # Same model as Node 0 +BACKEND="${4:-allgather_reducescatter}" # Backend to use +TOTAL_DP_SIZE="${5:-16}" # Same total DP as Node 0 +LOCAL_DP_SIZE="${6:-8}" # Local DP on this node +LOCAL_TP_SIZE="${7:-1}" # Local TP on this node +START_RANK="${8:-8}" # Starting rank offset + +# START_RANK calculation: +# - Node 1: LOCAL_DP_SIZE of Node 0 (e.g., 8) +# - Node 2: LOCAL_DP_SIZE of Node 0 + Node 1 (e.g., 16) +# - Node N: Sum of all previous nodes' LOCAL_DP_SIZE + +# ============================================================================ +# CONFIGURATION SUMMARY +# ============================================================================ + +echo "" +echo "╔═══════════════════════════════════════════════════════════════╗" +echo "║ vLLM Expert Parallel - Secondary Node Config ║" +echo "╚═══════════════════════════════════════════════════════════════╝" +echo "" +echo "Backend Configuration:" +echo " • Backend: ${BACKEND}" +echo " • DeepGEMM: Enabled" +echo "" +echo "Node Configuration:" +echo " • Role: Secondary (headless worker)" +echo " • Model: ${MODEL}" +echo " • Primary Node IP: ${NODE1_IP}" +echo " • RPC Port: ${RPC_PORT}" +echo "" +echo "Parallelism Configuration:" +echo " • Total Data Parallel Size: ${TOTAL_DP_SIZE} (across all nodes)" +echo " • Local Data Parallel Size: ${LOCAL_DP_SIZE} (this node)" +echo " • Local Tensor Parallel Size: ${LOCAL_TP_SIZE} (this node)" +echo " • Starting Rank: ${START_RANK}" +echo " • Expert Parallel: Enabled" +echo "" +echo "═══════════════════════════════════════════════════════════════" +echo "" + +# ============================================================================ +# LAUNCH vLLM SERVER (HEADLESS MODE) +# ============================================================================ + +vllm serve "${MODEL}" \ + --enable-expert-parallel \ + --all2all-backend "${BACKEND}" \ + --tensor-parallel-size "${LOCAL_TP_SIZE}" \ + --data-parallel-size "${TOTAL_DP_SIZE}" \ + --data-parallel-size-local "${LOCAL_DP_SIZE}" \ + --data-parallel-start-rank "${START_RANK}" \ + --data-parallel-address "${NODE1_IP}" \ + --data-parallel-rpc-port "${RPC_PORT}" \ + --gpu-memory-utilization 0.8 \ + --headless + +# Additional useful options (uncomment as needed, must match Node 0): +# --max-model-len 8192 \ +# --gpu-memory-utilization 0.9 \ +# --dtype auto \ diff --git a/ep/deep_ep_wrapper/README.md b/ep/deep_ep_wrapper/README.md index 0e2d6ac7c..9ea84f5c8 100644 --- a/ep/deep_ep_wrapper/README.md +++ b/ep/deep_ep_wrapper/README.md @@ -1,7 +1,14 @@ ## DeepEP Wrapper of UCCL-EP +First build and install UCCL-EP: +```bash +cd uccl +bash build.sh cuda ep +uv pip install wheelhouse-cuda/uccl-*.whl ``` + +Then install UCCL-EP's drop-in replacement for DeepEP: +```bash +cd ep/deep_ep_wrapper python setup.py install ``` - -pip install ../../wheelhouse-cuda/uccl-0.0.1.post4-py3-none-any.whl \ No newline at end of file diff --git a/ep/include/ep_utils.cuh b/ep/include/ep_utils.cuh index 0c0e548bf..bad082d65 100644 --- a/ep/include/ep_utils.cuh +++ b/ep/include/ep_utils.cuh @@ -744,7 +744,8 @@ __device__ __forceinline__ void trap() { __device__ __forceinline__ int ld_volatile_global(int const* ptr) { int ret; #if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) - ret = __atomic_load_n(const_cast(ptr), __ATOMIC_RELAXED); + ret = __hip_atomic_load(const_cast(ptr), __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); #else asm volatile("ld.volatile.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr)); #endif @@ -754,7 +755,8 @@ __device__ __forceinline__ int ld_volatile_global(int const* ptr) { __device__ __forceinline__ float ld_volatile_global(float const* ptr) { float ret; #if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) - __atomic_load(const_cast(ptr), &ret, __ATOMIC_RELAXED); + ret = __hip_atomic_load(const_cast(ptr), __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); #else asm volatile("ld.volatile.global.f32 %0, [%1];" : "=f"(ret) : "l"(ptr)); #endif @@ -764,7 +766,8 @@ __device__ __forceinline__ float ld_volatile_global(float const* ptr) { __device__ __forceinline__ int64_t ld_volatile_global(int64_t const* ptr) { int64_t ret; #if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) - ret = __atomic_load_n(const_cast(ptr), __ATOMIC_RELAXED); + ret = __hip_atomic_load(const_cast(ptr), __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); #else asm volatile("ld.volatile.global.s64 %0, [%1];" : "=l"(ret) : "l"(ptr)); #endif @@ -774,7 +777,8 @@ __device__ __forceinline__ int64_t ld_volatile_global(int64_t const* ptr) { __device__ __forceinline__ int64_t ld_volatile_global(uint64_t const* ptr) { int64_t ret; #if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) - ret = __atomic_load_n(const_cast(ptr), __ATOMIC_RELAXED); + ret = __hip_atomic_load(const_cast(ptr), __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_SYSTEM); #else asm volatile("ld.volatile.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr)); #endif @@ -836,6 +840,7 @@ __forceinline__ __device__ void barrier_block(int** barrier_signal_ptrs, // Add self-ranks, sub other ranks if (thread_id < kNumRanks) { atomicAdd_system(barrier_signal_ptrs[rank] + thread_id, FINISHED_SUM_TAG); + memory_fence(); atomicSub_system(barrier_signal_ptrs[thread_id] + rank, FINISHED_SUM_TAG); } EP_DEVICE_ASSERT(kNumRanks <= blockDim.x); diff --git a/ep/include/internode.cuh b/ep/include/internode.cuh index 3df61ca46..a4f04d1ad 100644 --- a/ep/include/internode.cuh +++ b/ep/include/internode.cuh @@ -31,8 +31,8 @@ void notify_dispatch( int const* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped, int const* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, bool const* is_token_in_rank, int num_tokens, - int num_channels, int hidden_int4, int num_scales, int num_topk, - int expert_alignment, int* rdma_channel_prefix_matrix, + int num_worst_tokens, int num_channels, int hidden_int4, int num_scales, + int num_topk, int expert_alignment, int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum, int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, @@ -54,25 +54,23 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, uint64_t const* d2h_channel_addrs, int num_d2h_channel_addrs, void* atomic_buffer_ptr); -void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, - float* recv_topk_weights, void* recv_src_meta, void const* x, - float const* x_scales, int64_t const* topk_idx, - float const* topk_weights, int* send_rdma_head, - int* send_nvl_head, int* recv_rdma_channel_prefix_matrix, - int* recv_gbl_channel_prefix_matrix, - int const* rdma_channel_prefix_matrix, - int const* recv_rdma_rank_prefix_sum, - int const* gbl_channel_prefix_matrix, - int const* recv_gbl_rank_prefix_sum, bool const* is_token_in_rank, - int num_tokens, int hidden_int4, int num_scales, int num_topk, - int num_experts, int scale_token_stride, int scale_hidden_stride, - void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, - int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, - int num_max_nvl_chunked_send_tokens, - int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks, - bool is_cached_dispatch, cudaStream_t stream, int num_channels, - bool low_latency_mode, uint64_t const* d2h_channel_addrs, - int num_d2h_channel_addrs, void* atomic_buffer_ptr); +void dispatch( + void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, + float* recv_topk_weights, void* recv_src_meta, void const* x, + float const* x_scales, int64_t const* topk_idx, float const* topk_weights, + int* send_rdma_head, int* send_nvl_head, + int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix, + int const* rdma_channel_prefix_matrix, int const* recv_rdma_rank_prefix_sum, + int const* gbl_channel_prefix_matrix, int const* recv_gbl_rank_prefix_sum, + bool const* is_token_in_rank, int num_tokens, int num_worst_tokens, + int hidden_int4, int num_scales, int num_topk, int num_experts, + int scale_token_stride, int scale_hidden_stride, void* rdma_buffer_ptr, + int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, + int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks, + bool is_cached_dispatch, cudaStream_t stream, int num_channels, + bool low_latency_mode, uint64_t const* d2h_channel_addrs, + int num_d2h_channel_addrs, void* atomic_buffer_ptr); void combine(cudaDataType_t type, void* combined_x, float* combined_topk_weights, diff --git a/ep/include/rdma.hpp b/ep/include/rdma.hpp index 6f229d2aa..14a4e6760 100644 --- a/ep/include/rdma.hpp +++ b/ep/include/rdma.hpp @@ -2,6 +2,11 @@ #define RDMA_HPP #include "common.hpp" #include "proxy_ctx.hpp" +// clang-format off +// prevent clang-format reordering net.h before util.h +#include "util/util.h" +#include "util/net.h" +// clang-format on #include "ring_buffer.cuh" #include "unistd.h" #include diff --git a/ep/install_deps.sh b/ep/install_deps.sh index 6883916b0..1a3ad05ed 100755 --- a/ep/install_deps.sh +++ b/ep/install_deps.sh @@ -50,7 +50,7 @@ if check_cuda; then elif check_rocm; then echo "Detected ROCM" # Install Pytorch using nightly - pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm7.0 + pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm7.1 else echo "No CUDA or ROCM detected" exit 1 diff --git a/ep/src/internode.cu b/ep/src/internode.cu index d17506056..0f7a141e4 100644 --- a/ep/src/internode.cu +++ b/ep/src/internode.cu @@ -97,13 +97,14 @@ __global__ void notify_dispatch( int const* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped, int const* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, bool const* is_token_in_rank, int num_tokens, - int num_channels, int expert_alignment, int const rdma_clean_offset, - int const rdma_num_int_clean, int const nvl_clean_offset, - int const nvl_num_int_clean, int* rdma_channel_prefix_matrix, - int* recv_rdma_rank_prefix_sum, int* gbl_channel_prefix_matrix, - int* recv_gbl_rank_prefix_sum, void* rdma_buffer_ptr, void** buffer_ptrs, - int** barrier_signal_ptrs, int rank, uint64_t const* d2h_channel_addrs, - int num_d2h_channel_addrs, void* atomic_buffer_ptr) { + int num_worst_tokens, int num_channels, int expert_alignment, + int const rdma_clean_offset, int const rdma_num_int_clean, + int const nvl_clean_offset, int const nvl_num_int_clean, + int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum, + int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, + void* rdma_buffer_ptr, void** buffer_ptrs, int** barrier_signal_ptrs, + int rank, uint64_t const* d2h_channel_addrs, int num_d2h_channel_addrs, + void* atomic_buffer_ptr) { void* original_rdma_buffer_ptr = rdma_buffer_ptr; auto sm_id = static_cast(blockIdx.x); auto thread_id = static_cast(threadIdx.x), @@ -273,9 +274,16 @@ __global__ void notify_dispatch( i)[NUM_MAX_NVL_PEERS + num_rdma_experts]; recv_rdma_rank_prefix_sum[i] = sum; } + // NOTE(MaoZiming): if I wrap this code with if (num_worst_tokens == 0), + // it will somehow cause deadlock on vllm with deepep_high_throughput + // mode. I suspect it is because some compiler reordering, but I don't + // know why. num_worst_tokens = 0, but somehow wrapping it with the + // conditional will cause deadlock. Removing the ``if" is logically + // redundant but harmless. if (num_worst_tokens == 0) { while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1) ; *moe_recv_rdma_counter_mapped = sum; + // } } // Send numbers of tokens per rank/expert to NVL ranks @@ -303,9 +311,11 @@ __global__ void notify_dispatch( sum += nvl_recv_num_tokens_per_rank.buffer(src_nvl_rank)[src_rdma_rank]; recv_gbl_rank_prefix_sum[i] = sum; } + // if (num_worst_tokens == 0) { while (ld_volatile_global(moe_recv_counter_mapped) != -1) ; *moe_recv_counter_mapped = sum; + // } } if (thread_id < num_nvl_experts) { int sum = 0; @@ -313,9 +323,11 @@ __global__ void notify_dispatch( for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) sum += nvl_recv_num_tokens_per_expert.buffer(i)[thread_id]; sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment; + // if (num_worst_tokens == 0) { while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) != -1) ; + // } moe_recv_expert_counter_mapped[thread_id] = sum; } @@ -394,8 +406,8 @@ void notify_dispatch( int const* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped, int const* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, bool const* is_token_in_rank, int num_tokens, - int num_channels, int hidden_int4, int num_scales, int num_topk, - int expert_alignment, int* rdma_channel_prefix_matrix, + int num_worst_tokens, int num_channels, int hidden_int4, int num_scales, + int num_topk, int expert_alignment, int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum, int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, @@ -403,23 +415,24 @@ void notify_dispatch( cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, bool low_latency_mode, uint64_t const* d2h_channel_addrs, int num_d2h_channel_addrs, void* atomic_buffer_ptr) { -#define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks) \ - { \ - auto notify_dispatch_func = low_latency_mode \ - ? notify_dispatch \ - : notify_dispatch; \ - LAUNCH_KERNEL( \ - &cfg, notify_dispatch_func, num_tokens_per_rank, \ - moe_recv_counter_mapped, num_ranks, num_tokens_per_rdma_rank, \ - moe_recv_rdma_counter_mapped, num_tokens_per_expert, \ - moe_recv_expert_counter_mapped, num_experts, is_token_in_rank, \ - num_tokens, num_channels, expert_alignment, rdma_clean_meta.first, \ - rdma_clean_meta.second, nvl_clean_meta.first, nvl_clean_meta.second, \ - rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \ - gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, rdma_buffer_ptr, \ - buffer_ptrs, barrier_signal_ptrs, rank, d2h_channel_addrs, \ - num_d2h_channel_addrs, atomic_buffer_ptr); \ - } \ +#define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks) \ + { \ + auto notify_dispatch_func = low_latency_mode \ + ? notify_dispatch \ + : notify_dispatch; \ + LAUNCH_KERNEL(&cfg, notify_dispatch_func, num_tokens_per_rank, \ + moe_recv_counter_mapped, num_ranks, \ + num_tokens_per_rdma_rank, moe_recv_rdma_counter_mapped, \ + num_tokens_per_expert, moe_recv_expert_counter_mapped, \ + num_experts, is_token_in_rank, num_tokens, num_worst_tokens, \ + num_channels, expert_alignment, rdma_clean_meta.first, \ + rdma_clean_meta.second, nvl_clean_meta.first, \ + nvl_clean_meta.second, rdma_channel_prefix_matrix, \ + recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \ + recv_gbl_rank_prefix_sum, rdma_buffer_ptr, buffer_ptrs, \ + barrier_signal_ptrs, rank, d2h_channel_addrs, \ + num_d2h_channel_addrs, atomic_buffer_ptr); \ + } \ break #if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) @@ -475,8 +488,9 @@ __global__ void __launch_bounds__( int const* recv_rdma_rank_prefix_sum, int const* gbl_channel_prefix_matrix, int const* recv_gbl_rank_prefix_sum, bool const* is_token_in_rank, - int num_tokens, int hidden_int4, int num_scales, int num_topk, - int num_experts, int scale_token_stride, int scale_hidden_stride, + int num_tokens, int num_worst_tokens, int hidden_int4, + int num_scales, int num_topk, int num_experts, + int scale_token_stride, int scale_hidden_stride, void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, @@ -720,7 +734,7 @@ __global__ void __launch_bounds__( // Read RDMA rank existence uint64_t is_token_in_rank_uint64 = 0; if (lane_id < kNumRDMARanks) { - is_token_in_rank_uint64 = __ldg(reinterpret_cast( + is_token_in_rank_uint64 = *(reinterpret_cast( is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS)); } @@ -1194,6 +1208,9 @@ __global__ void __launch_bounds__( trap(); } } +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) + memory_fence(); +#endif auto src_rdma_head = __shfl_sync(WARP_MASK, cached_rdma_channel_head, src_rdma_rank); auto src_rdma_tail = @@ -1484,27 +1501,42 @@ __global__ void __launch_bounds__( cached_channel_head_idx); } } + + // Clean unused `recv_topk_idx` as -1 + if (num_worst_tokens > 0) { + if (is_forwarder) return; + // get the actual number of num_recv_tokens on the current rank + int num_recv_tokens = recv_gbl_rank_prefix_sum[num_ranks - 1]; + // some ForwarderCoordinator threads exit early, so we only use + // non-forwarder in clean-up channel_id * num_threads is the offset of the + // current non-forwarder sms + auto const clean_start = + num_recv_tokens * num_topk + channel_id * num_threads; + auto const clean_end = num_worst_tokens * num_topk; + auto const clean_stride = num_channels * num_threads; +#pragma unroll + for (int i = clean_start + thread_id; i < clean_end; i += clean_stride) + recv_topk_idx[i] = -1; + } } -void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, - float* recv_topk_weights, void* recv_src_meta, void const* x, - float const* x_scales, int64_t const* topk_idx, - float const* topk_weights, int* send_rdma_head, - int* send_nvl_head, int* recv_rdma_channel_prefix_matrix, - int* recv_gbl_channel_prefix_matrix, - int const* rdma_channel_prefix_matrix, - int const* recv_rdma_rank_prefix_sum, - int const* gbl_channel_prefix_matrix, - int const* recv_gbl_rank_prefix_sum, bool const* is_token_in_rank, - int num_tokens, int hidden_int4, int num_scales, int num_topk, - int num_experts, int scale_token_stride, int scale_hidden_stride, - void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, - int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, - int num_max_nvl_chunked_send_tokens, - int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks, - bool is_cached_dispatch, cudaStream_t stream, int num_channels, - bool low_latency_mode, uint64_t const* d2h_channel_addrs, - int num_d2h_channel_addrs, void* atomic_buffer_ptr) { +void dispatch( + void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, + float* recv_topk_weights, void* recv_src_meta, void const* x, + float const* x_scales, int64_t const* topk_idx, float const* topk_weights, + int* send_rdma_head, int* send_nvl_head, + int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix, + int const* rdma_channel_prefix_matrix, int const* recv_rdma_rank_prefix_sum, + int const* gbl_channel_prefix_matrix, int const* recv_gbl_rank_prefix_sum, + bool const* is_token_in_rank, int num_tokens, int num_worst_tokens, + int hidden_int4, int num_scales, int num_topk, int num_experts, + int scale_token_stride, int scale_hidden_stride, void* rdma_buffer_ptr, + int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, + int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks, + bool is_cached_dispatch, cudaStream_t stream, int num_channels, + bool low_latency_mode, uint64_t const* d2h_channel_addrs, + int num_d2h_channel_addrs, void* atomic_buffer_ptr) { constexpr int kNumDispatchRDMASenderWarps = 7; constexpr int kNumTMABytesPerWarp = 16384; constexpr int smem_size = kNumTMABytesPerWarp * NUM_MAX_NVL_PEERS; @@ -1543,9 +1575,9 @@ void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, send_rdma_head, send_nvl_head, recv_rdma_channel_prefix_matrix, \ recv_gbl_channel_prefix_matrix, rdma_channel_prefix_matrix, \ recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \ - recv_gbl_rank_prefix_sum, is_token_in_rank, num_tokens, hidden_int4, \ - num_scales, num_topk, num_experts, scale_token_stride, \ - scale_hidden_stride, rdma_buffer_ptr, \ + recv_gbl_rank_prefix_sum, is_token_in_rank, num_tokens, \ + num_worst_tokens, hidden_int4, num_scales, num_topk, num_experts, \ + scale_token_stride, scale_hidden_stride, rdma_buffer_ptr, \ num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, \ buffer_ptrs, num_max_nvl_chunked_send_tokens, \ num_max_nvl_chunked_recv_tokens, rank, num_ranks, d2h_channel_addrs, \ @@ -1594,7 +1626,7 @@ __global__ void cached_notify( // Barrier for RDMA if (thread_id == WARP_SIZE) uccl::nvshmem_sync_with_same_gpu_idx(d2h_channel_addrs, - num_d2h_channel_addrs, nvl_rank, 3); + num_d2h_channel_addrs, nvl_rank); // Barrier for NVL barrier_block(barrier_signal_ptrs, nvl_rank); diff --git a/ep/src/proxy.cpp b/ep/src/proxy.cpp index ee52ab5d2..e7dc0829f 100644 --- a/ep/src/proxy.cpp +++ b/ep/src/proxy.cpp @@ -45,7 +45,7 @@ LocalBarrier* map_local_barrier_shm(std::string const& name, bool* out_owner) { perror("shm_open(existing)"); return nullptr; } - struct stat st{}; + struct stat st {}; int tries = 1000; while (tries-- > 0) { if (fstat(fd, &st) == 0 && static_cast(st.st_size) >= kSize) @@ -186,9 +186,9 @@ void Proxy::init_common() { #ifdef EFA IBV_ACCESS_REMOTE_READ #else - IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC + IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC #endif - ); + ); if (!ctx_.atomic_buffer_mr) { perror("Failed to register atomic_buffer_ptr MR"); @@ -504,7 +504,7 @@ void Proxy::run_dual() { void Proxy::notify_gpu_completion(uint64_t& my_tail) { if (acked_wrs_.empty()) return; - // Mark all acked command slots in each ring's bitmask + // Mark all acked command slots in each ring's bitmask #ifdef USE_MSCCLPP_FIFO_BACKEND // FIFO path: pop in order using the pending deque and the completion set. for (size_t rb_idx = 0; rb_idx < cfg_.d2h_queues.size(); ++rb_idx) { @@ -912,6 +912,7 @@ void Proxy::post_gpu_commands_mixed( #ifdef USE_MSCCLPP_FIFO_BACKEND assert(barrier_wrs.size() == 1 && ctx_.barrier_wr == -1); #endif + assert(quiet_wrs.empty() && "quiet_wrs should be empty"); send_barrier(barrier_wrs[0]); barrier_wrs.clear(); barrier_cmds.clear(); @@ -1215,8 +1216,7 @@ void Proxy::barrier_check() { // When global release comes back (CQ handler should set these): // NOTE: BarrierImm is 21 bits, so we must mask the local seq. - if (ctx_.barrier_released && - ctx_.barrier_release_seq == seq) { + if (ctx_.barrier_released && ctx_.barrier_release_seq == seq) { // Reset local mask for next barrier and consume the global release ctx_.barrier_released = false; diff --git a/ep/src/rdma.cpp b/ep/src/rdma.cpp index 1f9ee21b1..b8b802c3b 100644 --- a/ep/src/rdma.cpp +++ b/ep/src/rdma.cpp @@ -5,9 +5,6 @@ #include "proxy_ctx.hpp" #include "rdma_util.hpp" #include "util/gpu_rt.h" -#include "util/util.h" -// net.h should be included after util.h -#include "util/net.h" #include #include #include @@ -232,8 +229,8 @@ void per_thread_rdma_init(ProxyCtx& S, void* gpu_buf, size_t bytes, int rank, IBV_ACCESS_REMOTE_ATOMIC); #else S.mr = ibv_reg_mr_iova2(S.pd, gpu_buf, bytes, iova, - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | - IBV_ACCESS_RELAXED_ORDERING); + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | + IBV_ACCESS_RELAXED_ORDERING); #endif if (!S.mr) { @@ -554,6 +551,13 @@ void modify_qp_to_rtr(ProxyCtx& S, RDMAConnectionInfo* remote, exit(1); } + // Query device attributes to get max_dest_rd_atomic + struct ibv_device_attr dev_attr; + if (ibv_query_device(S.context, &dev_attr)) { + perror("Failed to query device attributes"); + exit(1); + } + if (port_attr.link_layer == IBV_LINK_LAYER_ETHERNET) { printf("RoCE detected (Ethernet)\n"); is_roce = 1; @@ -571,16 +575,16 @@ void modify_qp_to_rtr(ProxyCtx& S, RDMAConnectionInfo* remote, attr.path_mtu = port_attr.active_mtu; attr.dest_qp_num = remote->qp_num; attr.rq_psn = remote->psn; - attr.max_dest_rd_atomic = 1; + attr.max_dest_rd_atomic = dev_attr.max_qp_init_rd_atom; attr.min_rnr_timer = 12; if (is_roce) { attr.ah_attr.is_global = 1; attr.ah_attr.port_num = 1; - attr.ah_attr.sl = 135; + attr.ah_attr.sl = 1; attr.ah_attr.src_path_bits = 0; - attr.ah_attr.grh.traffic_class = 3; - attr.ah_attr.grh.hop_limit = 64; + attr.ah_attr.grh.traffic_class = 1; + attr.ah_attr.grh.hop_limit = 255; // Fill GID from remote_info memcpy(&attr.ah_attr.grh.dgid, remote->gid, 16); attr.ah_attr.grh.sgid_index = S.gid_index; @@ -655,14 +659,21 @@ void modify_qp_to_rts(ProxyCtx& S, RDMAConnectionInfo* local_info) { #ifdef EFA return; #endif + // Query device attributes to get max_rd_atomic + struct ibv_device_attr dev_attr; + if (ibv_query_device(S.context, &dev_attr)) { + perror("Failed to query device attributes"); + exit(1); + } + struct ibv_qp_attr attr; memset(&attr, 0, sizeof(attr)); attr.qp_state = IBV_QPS_RTS; - attr.timeout = 14; + attr.timeout = 20; attr.retry_cnt = 7; attr.rnr_retry = 7; attr.sq_psn = local_info->psn; - attr.max_rd_atomic = 1; + attr.max_rd_atomic = dev_attr.max_qp_rd_atom; attr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_ATOMIC; int flags = IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | diff --git a/ep/src/uccl_ep.cc b/ep/src/uccl_ep.cc index 5ba63839e..c57314704 100644 --- a/ep/src/uccl_ep.cc +++ b/ep/src/uccl_ep.cc @@ -138,10 +138,6 @@ class Buffer { #endif } -#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) - // Note(huangzhen): It will make d_handles turn to nullptr in rocm7.0, - // so we don't prefetch d_handles. -#else // Prefetch so the device immediately sees initialized contents CUDA_CHECK(cudaMemPrefetchAsync( d_handle_objs, num_d2h_channel_addrs * sizeof(d2hq::D2HHandle), @@ -150,7 +146,6 @@ class Buffer { d_handles, num_d2h_channel_addrs * sizeof(uint64_t), device_index)); CUDA_CHECK(cudaDeviceSynchronize()); -#endif } // Allocate device memory for IPC base pointers CUDA_CHECK( @@ -416,7 +411,7 @@ class Buffer { std::optional const& cached_recv_rdma_rank_prefix_sum, std::optional const& cached_gbl_channel_prefix_matrix, std::optional const& cached_recv_gbl_rank_prefix_sum, - int expert_alignment, uccl::Config const& config, + int expert_alignment, int num_worst_tokens, uccl::Config const& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { // In dispatch, CPU will busy-wait until GPU receive tensor size metadata @@ -598,8 +593,8 @@ class Buffer { num_ranks, num_tokens_per_rdma_rank->data_ptr(), moe_recv_rdma_counter_mapped, num_tokens_per_expert->data_ptr(), moe_recv_expert_counter_mapped, num_experts, - is_token_in_rank.data_ptr(), num_tokens, num_channels, - hidden_int4, num_scales, num_topk, expert_alignment, + is_token_in_rank.data_ptr(), num_tokens, num_worst_tokens, + num_channels, hidden_int4, num_scales, num_topk, expert_alignment, rdma_channel_prefix_matrix.data_ptr(), recv_rdma_rank_prefix_sum.data_ptr(), gbl_channel_prefix_matrix.data_ptr(), @@ -611,37 +606,42 @@ class Buffer { num_ranks), num_nvl_bytes, low_latency_mode, d_handles, num_d2h_channel_addrs, atomic_buffer_ptr); - // Synchronize total received tokens and tokens per expert - auto start_time = std::chrono::high_resolution_clock::now(); - while (true) { - // Read total count - num_recv_tokens = static_cast(*moe_recv_counter); - num_rdma_recv_tokens = static_cast(*moe_recv_rdma_counter); - - // Read per-expert count - bool ready = (num_recv_tokens >= 0) and (num_rdma_recv_tokens >= 0); - for (int i = 0; i < num_local_experts and ready; ++i) - ready &= moe_recv_expert_counter[i] >= 0; - - if (ready) break; - - // Timeout check - if (std::chrono::duration_cast( - std::chrono::high_resolution_clock::now() - start_time) - .count() > NUM_CPU_TIMEOUT_SECS) { - printf( - "Global rank: %d, num_recv_tokens: %d, num_rdma_recv_tokens: " - "%d\n", - rank, num_recv_tokens, num_rdma_recv_tokens); - for (int i = 0; i < num_local_experts; ++i) - printf("moe_recv_expert_counter[%d]: %d\n", i, - moe_recv_expert_counter[i]); - throw std::runtime_error("DeepEP error: timeout (dispatch CPU)"); + if (num_worst_tokens > 0) { + num_recv_tokens = num_worst_tokens; + num_rdma_recv_tokens = num_worst_tokens; + } else { + auto start_time = std::chrono::high_resolution_clock::now(); + while (true) { + // Read total count + num_recv_tokens = static_cast(*moe_recv_counter); + num_rdma_recv_tokens = static_cast(*moe_recv_rdma_counter); + + // Read per-expert count + bool ready = (num_recv_tokens >= 0) and (num_rdma_recv_tokens >= 0); + for (int i = 0; i < num_local_experts and ready; ++i) + ready &= moe_recv_expert_counter[i] >= 0; + + if (ready) break; + + // Timeout check + if (std::chrono::duration_cast( + std::chrono::high_resolution_clock::now() - start_time) + .count() > NUM_CPU_TIMEOUT_SECS) { + printf( + "Global rank: %d, num_recv_tokens: %d, num_rdma_recv_tokens: " + "%d\n", + rank, num_recv_tokens, num_rdma_recv_tokens); + for (int i = 0; i < num_local_experts; ++i) + printf("moe_recv_expert_counter[%d]: %d\n", i, + moe_recv_expert_counter[i]); + throw std::runtime_error("DeepEP error: timeout (dispatch CPU)"); + } } + num_recv_tokens_per_expert_list = + std::vector(moe_recv_expert_counter, + moe_recv_expert_counter + num_local_experts); } - num_recv_tokens_per_expert_list = std::vector( - moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts); } // Allocate new tensors @@ -705,9 +705,10 @@ class Buffer { recv_rdma_rank_prefix_sum.data_ptr(), gbl_channel_prefix_matrix.data_ptr(), recv_gbl_rank_prefix_sum.data_ptr(), - is_token_in_rank.data_ptr(), num_tokens, hidden_int4, num_scales, - num_topk, num_experts, scale_token_stride, scale_hidden_stride, - rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, + is_token_in_rank.data_ptr(), num_tokens, num_worst_tokens, + hidden_int4, num_scales, num_topk, num_experts, scale_token_stride, + scale_hidden_stride, rdma_buffer_ptr, + config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens, rank, num_ranks, cached_mode, @@ -873,7 +874,6 @@ class Buffer { config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes, false, low_latency_mode, d_handles, num_d2h_channel_addrs, atomic_buffer_ptr); - // Assign bias pointers auto bias_opts = std::vector>({bias_0, bias_1}); @@ -2024,6 +2024,8 @@ PYBIND11_MODULE(ep, m) { uccl::g_proxies_by_dev.clear(); }); + m.def("get_oob_ip", &uccl::get_oob_ip, "Get the OOB IP address"); + m.def("get_rdma_buffer", [](int64_t num_rdma_bytes, int device_index) { void* ptr; CUDA_CHECK(cudaSetDevice(device_index)); diff --git a/experimental/eccl/Makefile b/experimental/eccl/Makefile index 6635c6124..7881223f8 100644 --- a/experimental/eccl/Makefile +++ b/experimental/eccl/Makefile @@ -19,7 +19,7 @@ NVCCFLAGS := -g -O0 -std=c++17 \ LDFLAGS := -L$(CONDA_LIB_HOME) \ -L$(CUDA_PATH)/lib64 -Wl,-rpath,$(CUDA_PATH)/lib64 \ - -lcudart \ + -lcudart -lcuda \ -lgflags -lgtest -lz -lelf -libverbs -lnl-3 -lnl-route-3 -lpthread -lglog INCLUDES := -Iinclude -I$(CUDA_PATH)/include -I/usr/include -I../../include diff --git a/experimental/eccl/README.md b/experimental/eccl/README.md index 9f7d2ece8..545b8a61c 100644 --- a/experimental/eccl/README.md +++ b/experimental/eccl/README.md @@ -1,5 +1,10 @@ # ECCL +``` +sudo apt-get update +sudo apt-get install -y libelf-dev +``` + ## develpment node: on AMD @@ -28,4 +33,15 @@ cd experimental/eccl/src/device make clean -f Makefile && make -j$(nproc) -f Makefile CUDA_VISIBLE_DEVICES=5 ./test_persistent +``` + + +## test +``` +# test communicator +CUDA_VISIBLE_DEVICES=5 ./test_main --role=server +CUDA_VISIBLE_DEVICES=5 ./test_main --role=client + +# test others +CUDA_VISIBLE_DEVICES=5 ./test_main --role=server ``` \ No newline at end of file diff --git a/experimental/eccl/include/oob.h b/experimental/eccl/include/oob.h index 6098eba5a..6105f92df 100644 --- a/experimental/eccl/include/oob.h +++ b/experimental/eccl/include/oob.h @@ -1,5 +1,6 @@ #pragma once +#include "util/gpu_rt.h" #include #include @@ -24,6 +25,7 @@ #include #include +// --- Group Exchanger --- struct Exchangeable { virtual std::map to_map() const = 0; virtual void from_map(std::map const& kv) = 0; @@ -234,3 +236,111 @@ class SockExchanger : public Exchanger { std::mutex&, std::atomic&, size_t, std::function); }; + +// --- P2P Exchanger --- + +struct IpcCache { + gpuIpcMemHandle_t handle; + bool is_send; + void* direct_ptr; // ptr of remote, local get it by mapping from handle + uintptr_t offset; + size_t size; +}; + +// Only message type needed for now +static constexpr uint16_t kTypeIpcCache = 1; +static constexpr uint16_t kTypeAck = 2; + +// uds payload +#pragma pack(push, 1) +struct IpcCacheWire { + gpuIpcMemHandle_t handle; + uint8_t is_send; + uint64_t offset; + uint64_t size; + uint32_t remote_gpu_idx_; +}; +#pragma pack(pop) + +struct AckWire { + uint32_t status; // 0=fail, 1=ok, or extend + uint32_t reserved; // keep 8B aligned +}; + +class UdsExchanger { + public: + UdsExchanger(int self_rank); + ~UdsExchanger(); + + // Lazy start local server (bind/listen) once. + bool ensure_server_started(); + + // Client: connect to peer's UDS with retry until timeout. Idempotent. + bool connect_to(int peer_rank, int timeout_ms = 30000); + + // Server: accept connections until we receive one from peer_rank (or + // timeout). Idempotent: returns true immediately if already connected. + bool accept_from(int peer_rank, int timeout_ms = 30000); + + // Generic framed send + bool send(int peer_rank, uint16_t type, uint64_t seq, void const* payload, + uint32_t bytes); + + // Convenience: send IPC cache + bool send_ipc_cache(int peer_rank, uint64_t seq, IpcCacheWire const& cache); + bool recv_ipc_cache(int peer_rank, IpcCacheWire& out_cache, + uint64_t* out_seq = nullptr, int timeout_ms = 30000); + bool send_ack(int peer_rank, uint64_t seq, uint32_t status = 1); + bool recv_ack(int peer_rank, uint32_t* out_status = nullptr, + uint64_t* out_seq = nullptr, int timeout_ms = 30000, + uint64_t expected_seq = UINT64_MAX); + + int get_fd(int peer_rank) const; + void close_peer(int peer_rank); + + private: + struct Hello { + uint32_t magic; + int32_t from_rank; + int32_t to_rank; + uint32_t version; + }; + + struct MsgHdr { + uint32_t magic; + uint16_t version; + uint16_t type; + uint32_t bytes; + int32_t from_rank; + int32_t to_rank; + uint64_t seq; + }; + + private: + std::string path_for_rank(int rank); + bool connect_once(std::string const& peer_path, int& out_fd); + + bool send_all(int fd, char const* buf, size_t len); + bool recv_all(int fd, char* buf, size_t len); + + // Accept one connection with a poll-like timeout. + // Returns: fd>=0 on success, -1 on timeout, -2 on fatal error. + int accept_with_timeout(int timeout_ms); + + private: + int self_rank_; + + // server state + std::atomic running_{false}; + int listen_fd_{-1}; + std::string local_path_; + + // connections + mutable std::mutex mu_; + std::unordered_map rank_to_fd_; + std::unordered_map> rank_send_mu_; + std::unordered_map> rank_recv_mu_; + + // Ensure only one thread is doing accept() at a time. + mutable std::mutex accept_mu_; +}; \ No newline at end of file diff --git a/experimental/eccl/include/transport.h b/experimental/eccl/include/transport.h index 5b2464062..7b9b4e1ad 100644 --- a/experimental/eccl/include/transport.h +++ b/experimental/eccl/include/transport.h @@ -6,6 +6,8 @@ #include "util/gpu_rt.h" #include "util/jring.h" #include +#include +#include #include #include #include @@ -14,6 +16,8 @@ #include #include +enum class EndpointType { RDMA, IPC }; + class Communicator; class CQPoller; class EndpointBase { @@ -24,6 +28,8 @@ class EndpointBase { virtual bool recv_async(int from_rank, std::shared_ptr creq) = 0; std::atomic next_send_seq_{0}; std::atomic next_recv_seq_{0}; + + EndpointType type; }; class RDMAEndpoint : public EndpointBase { @@ -59,13 +65,9 @@ class RDMAEndpoint : public EndpointBase { Communicator* comm_; }; -struct IPCcontext { - gpuIpcMemHandle_t handle; - bool is_send; - void* direct_ptr; - uintptr_t offset; - size_t size; -}; +static constexpr size_t kTaskRingSize = 1024; +static constexpr size_t kIpcAlignment = 1ul << 20; +static constexpr size_t kIpcSizePerEngine = 1ul << 20; class IPCEndpoint : public EndpointBase { public: @@ -79,19 +81,32 @@ class IPCEndpoint : public EndpointBase { bool recv_async(int from_rank, std::shared_ptr creq) override; private: - std::shared_ptr ipc_context; + enum class IpcTaskType : uint8_t { SEND, RECV }; + struct IpcTask { + IpcTaskType type; + int peer_rank; + std::shared_ptr req; + uint64_t enqueue_ns; + uint32_t retry; + }; + + bool send_(int to_rank, std::shared_ptr creq); + bool recv_(int from_rank, std::shared_ptr creq); + + jring_t* task_ring_; + std::atomic stop_{false}; + std::thread proxy_thread_; + std::mutex cv_mu_; + std::condition_variable cv_; + std::atomic pending_{0}; // number of queued tasks + void proxy_thread_func(); + + std::vector ipc_streams_; // n_streams + std::shared_ptr config_; Communicator* comm_; }; -struct IpcCache { - gpuIpcMemHandle_t handle; - bool is_send; - void* direct_ptr; // for remote - uintptr_t offset; - size_t size; -}; - // one gpu with the best nic class Communicator { public: @@ -126,16 +141,14 @@ class Communicator { MR reg_mr(void* local_buf, size_t len); bool dereg_mr(void* local_buf); bool notify_mr(int remote_rank, MR& mr); - MR wait_mr_notify(int remote_rank); + bool wait_mr_notify(int remote_rank, MR& mr); MR get_local_mr(void* local_buf); MR get_local_mr(uint16_t mr_id); MR get_remote_mr(int remote_rank, uint16_t mr_id); - bool register_local_ipc_cache(void* local_buf); - bool register_remote_ipc_cache(int remote_rank, void* local_buf, + bool register_remote_ipc_cache(int remote_rank, gpuIpcMemHandle_t handle, IpcCache const& cache); - IpcCache get_local_ipc_cache(void* local_buf); - IpcCache get_remote_ipc_cache(int remote_rank, void* local_buf); + IpcCache get_remote_ipc_cache(int remote_rank, gpuIpcMemHandle_t handle); ibv_cq* get_cq_by_index(int index); @@ -149,8 +162,8 @@ class Communicator { mutable std::mutex meta_mu_; // ---------- GPU / NIC info -------- - int gpu_id_; // todo, this is true local_rank_ - int local_rank_; // todo, replace with rank_ + int local_rank_; // gpu_id_ + int global_rank_; int world_size_; bool support_rdma; bool support_rdma_roce; @@ -173,13 +186,27 @@ class Communicator { std::atomic next_mr_id{0}; // ---------- IPC resources --------- - std::unordered_map ptr_to_local_ipc_cache_; - std::unordered_map> - rank_ptr_to_ipc_cache_; + using HandleKey = std::array; + static inline HandleKey MakeHandleKey(gpuIpcMemHandle_t const& h) { + HandleKey k{}; + std::memcpy(k.data(), &h, k.size()); + return k; + } + struct HandleKeyHash { + size_t operator()(HandleKey const& k) const noexcept { + uint64_t hash = 1469598103934665603ull; + for (uint8_t b : k) { + hash ^= b; + hash *= 1099511628211ull; + } + return (size_t)hash; + } + }; + std::shared_ptr uds_; + using HandleCacheMap = std::unordered_map; + std::unordered_map rank_handle_to_ipc_cache_; mutable std::mutex local_ipc_cache_mu_; mutable std::mutex remote_ipc_cache_mu_; - // ipc_stream_[gpu_num], for ipc_send/recv - // uds_fd_ init with rank_, per communicator; connect with rank_ // ---------- Config & Redis -------- std::shared_ptr config_; diff --git a/experimental/eccl/src/communicator.cc b/experimental/eccl/src/communicator.cc index 9eb67bb18..66eb03571 100644 --- a/experimental/eccl/src/communicator.cc +++ b/experimental/eccl/src/communicator.cc @@ -31,12 +31,12 @@ std::string get_local_ip() { Communicator::Communicator(int gpu_id, int rank, int world_size, std::shared_ptr config) - : gpu_id_(gpu_id), - local_rank_(rank), + : local_rank_(gpu_id), + global_rank_(rank), world_size_(world_size), config_(config) { // Find best NIC for current gpu - auto [nic_id, nic_name] = find_best_rdma_for_gpu(gpu_id_); + auto [nic_id, nic_name] = find_best_rdma_for_gpu(gpu_id); std::cout << "[INFO] Using RDMA NIC " << nic_name << std::endl; if (nic_id != -1) { // Support RDMA struct ibv_device** dev_list = ibv_get_device_list(nullptr); @@ -65,8 +65,8 @@ Communicator::Communicator(int gpu_id, int rank, int world_size, } ibv_free_device_list(dev_list); - std::cout << "[INFO] Communicator " << local_rank_ << " initialized: GPU " - << gpu_id_ << " map to RDMA NIC " << nic_name << std::endl; + std::cout << "[INFO] Communicator " << global_rank_ << " initialized: GPU " + << gpu_id << " map to RDMA NIC " << nic_name << std::endl; support_rdma = true; // Init RAMD resource @@ -119,31 +119,33 @@ Communicator::Communicator(int gpu_id, int rank, int world_size, uccl::create_ring(sizeof(unsigned), 16); // change num later } else { // Does not support RDMA - // TODO: if we can't find any rdma nic, we still can do ipc comm on local + // If we can't find any rdma nic, we still can do ipc comm on local // host. support_rdma = false; } + uds_ = std::make_shared(global_rank_); + // Initialize communicator meta CommunicatorMeta local{}; local.host_id = generate_host_id(); local.is_ready = true; local.ip = get_local_ip(); - set_communicator_meta_with_rank(local_rank_, local); + set_communicator_meta_with_rank(global_rank_, local); // Initialize Redis client #ifdef USE_REDIS_OOB exchanger_client_ = std::make_shared(config_->exchanger_ip, config_->exchanger_port); #else - bool is_server = (local_rank_ == 0); + bool is_server = (global_rank_ == 0); if (!is_server && config_->exchanger_ip == "0.0.0.0") config_->exchanger_ip = "127.0.0.1"; std::cout << "[INFO] Using socket-based exchanger as " << (is_server ? "server" : "client") << " " << config_->exchanger_ip << std::endl; exchanger_client_ = std::make_shared( - (local_rank_ == 0), config_->exchanger_ip, config_->exchanger_port); + (global_rank_ == 0), config_->exchanger_ip, config_->exchanger_port); #endif if (!exchanger_client_->valid()) { fprintf(stderr, "[ERROR] Failed to connect to Exchanger\n"); @@ -151,7 +153,7 @@ Communicator::Communicator(int gpu_id, int rank, int world_size, } // Exchange communicator meta - std::string meta_key = "meta:" + std::to_string(local_rank_); + std::string meta_key = "meta:" + std::to_string(global_rank_); if (!exchanger_client_->publish(meta_key, local)) { fprintf(stderr, "[ERROR] Failed to publish local CommunicatorMeta \n"); } @@ -159,7 +161,7 @@ Communicator::Communicator(int gpu_id, int rank, int world_size, // Get all others meta CommunicatorMeta remote{}; for (int i = 0; i < world_size_; i++) { - if (i == local_rank_) continue; + if (i == global_rank_) continue; std::string key = "meta:" + std::to_string(i); if (exchanger_client_->wait_and_fetch(key, remote, -1)) { set_communicator_meta_with_rank(i, remote); @@ -167,7 +169,7 @@ Communicator::Communicator(int gpu_id, int rank, int world_size, fprintf(stderr, "[WARN] Timeout waiting for remote CommunicatorMeta \n"); } } - std::cout << "[INFO] Communicator " << local_rank_ + std::cout << "[INFO] Communicator " << global_rank_ << " initialized: rank_to_comm_meta_ success" << std::endl; } @@ -194,12 +196,19 @@ Communicator::~Communicator() { } // Deregister local memory regions + std::vector bufs; { std::lock_guard lk(local_mr_mu_); - for (auto& [ptr, mr] : ptr_to_local_ibv_mr_) { - dereg_mr(ptr); + bufs.reserve(ptr_to_local_ibv_mr_.size()); + for (auto& kv : ptr_to_local_ibv_mr_) { + bufs.push_back(kv.first); // ptr } - ptr_to_local_ibv_mr_.clear(); + } + for (auto* p : bufs) { + dereg_mr(p); + } + { + std::lock_guard lk(local_mr_mu_); mr_id_to_local_mr_.clear(); } @@ -208,7 +217,7 @@ Communicator::~Communicator() { for (auto& cq : cq_list_) { if (cq) { if (ibv_destroy_cq(cq)) { - std::cerr << "[WARN] Communicator " << local_rank_ + std::cerr << "[WARN] Communicator " << global_rank_ << " Failed to destroy CQ" << std::endl; } } @@ -219,7 +228,7 @@ Communicator::~Communicator() { // Deallocate PD if (pd_) { if (ibv_dealloc_pd(pd_)) { - std::cerr << "[WARN] Communicator " << local_rank_ + std::cerr << "[WARN] Communicator " << global_rank_ << " Failed to deallocate PD" << std::endl; } pd_ = nullptr; @@ -228,7 +237,7 @@ Communicator::~Communicator() { // Close device if (nic_ibv_ctx_) { if (ibv_close_device(nic_ibv_ctx_)) { - std::cerr << "[WARN] Communicator " << local_rank_ + std::cerr << "[WARN] Communicator " << global_rank_ << " Failed to close IBV device context" << std::endl; } nic_ibv_ctx_ = nullptr; @@ -241,13 +250,9 @@ Communicator::~Communicator() { } // Clear IPC caches - { - std::lock_guard lk(local_ipc_cache_mu_); - ptr_to_local_ipc_cache_.clear(); - } { std::lock_guard lk(remote_ipc_cache_mu_); - rank_ptr_to_ipc_cache_.clear(); + rank_handle_to_ipc_cache_.clear(); } // Free pending_req_id_to_deal_ buffer @@ -256,13 +261,13 @@ Communicator::~Communicator() { pending_req_id_to_deal_ = nullptr; } - std::cout << "[INFO] Communicator " << local_rank_ << " resources released" + std::cout << "[INFO] Communicator " << global_rank_ << " resources released" << std::endl; } bool Communicator::connect_to(int rank) { if (!check_ready()) { - std::cerr << "[WARN] Communicator " << local_rank_ + std::cerr << "[WARN] Communicator " << global_rank_ << " not ready, cannot connect to rank " << rank << std::endl; return false; } @@ -270,52 +275,60 @@ bool Communicator::connect_to(int rank) { auto [existing_ep, ok] = get_endpoint_by_rank(rank); if (ok && existing_ep) return true; // already - if (rank == local_rank_) { + if (rank == global_rank_) { return true; } if (rank < 0 || rank >= world_size_) { - std::cerr << "[ERROR] Communicator " << local_rank_ << " invalid rank " + std::cerr << "[ERROR] Communicator " << global_rank_ << " invalid rank " << rank << ", world_size=" << world_size_ << std::endl; return false; } auto meta = get_communicator_meta_by_rank(rank); - auto local_meta = get_communicator_meta_by_rank(local_rank_); + auto local_meta = get_communicator_meta_by_rank(global_rank_); if (!meta) { - std::cerr << "[ERROR] Communicator " << local_rank_ + std::cerr << "[ERROR] Communicator " << global_rank_ << " CommunicatorMeta not found for rank " << rank << std::endl; return false; } - // We use RDMA for test now. TODO: support IPC bool same_host = meta->host_id == local_meta->host_id; - same_host = false; // only RDMA now + // same_host = false; // force RDMA std::shared_ptr ep; bool ret = false; if (same_host) { - // std::cout << "[INFO] Communicator " << local_rank_ + // std::cout << "[INFO] Communicator " << global_rank_ // << " same host detected, using IPC endpoint" << std::endl; ep = std::make_shared(config_, this); - // ep->connect_to(rank); // optional if IPCEndpoint needs explicit connect - ret = true; + ret = ep->connect_to(rank); + if (ret) { + std::cout << "[INFO] Communicator " << global_rank_ + << " IPC connect_to succeeded to rank " << rank << std::endl; + } else { + std::cerr << "[ERROR] Communicator " << global_rank_ + << " IPC connect_to failed to rank " << rank << std::endl; + return false; + } + ep->type = EndpointType::IPC; } else { - // std::cout << "[INFO] Communicator " << local_rank_ + // std::cout << "[INFO] Communicator " << global_rank_ // << " different host detected, using RDMA endpoint" << // std::endl; ep = std::make_shared(config_, this); ret = ep->connect_to(rank); if (ret) { - // std::cout << "[INFO] Communicator " << local_rank_ - // << " RDMA connect_to succeeded to rank " << rank << - // std::endl; + std::cout << "[INFO] Communicator " << global_rank_ + << " RDMA connect_to succeeded to rank " << rank << std::endl; } else { - std::cerr << "[ERROR] Communicator " << local_rank_ + std::cerr << "[ERROR] Communicator " << global_rank_ << " RDMA connect_to failed to rank " << rank << std::endl; + return false; } + ep->type = EndpointType::RDMA; } { @@ -325,7 +338,52 @@ bool Communicator::connect_to(int rank) { return ret; } -bool Communicator::accept_from(int rank) { return connect_to(rank); } +bool Communicator::accept_from(int rank) { + if (!check_ready()) return false; + if (rank == global_rank_) return true; + + auto [existing_ep, ok] = get_endpoint_by_rank(rank); + if (ok && existing_ep) return true; + + auto meta = get_communicator_meta_by_rank(rank); + auto local_meta = get_communicator_meta_by_rank(global_rank_); + if (!meta || !local_meta) return false; + + bool same_host = meta->host_id == local_meta->host_id; + // same_host = false; // force RDMA + + std::shared_ptr ep; + bool ret = false; + + if (same_host) { + ep = std::make_shared(config_, this); + ret = ep->accept_from(rank); + if (ret) { + std::cout << "[INFO] Communicator " << global_rank_ + << " IPC accept_from succeeded from rank " << rank << std::endl; + } else { + std::cerr << "[ERROR] Communicator " << global_rank_ + << " IPC accept_from failed from rank " << rank << std::endl; + } + } else { + // RDMA: accept == connect + ep = std::make_shared(config_, this); + ret = ep->connect_to(rank); + if (ret) { + std::cout << "[INFO] Communicator " << global_rank_ + << " RDMA accept succeeded from rank " << rank << std::endl; + } else { + std::cerr << "[ERROR] Communicator " << global_rank_ + << " RDMA accept failed from rank " << rank << std::endl; + } + } + + { + std::lock_guard lk(ep_mu_); + rank_to_endpoint_[rank] = ep; + } + return ret; +} std::tuple, bool> Communicator::get_endpoint_by_rank(int rank) { @@ -376,7 +434,7 @@ unsigned Communicator::irecv(int rank, void* ptr, size_t offset, size_t len, uint16_t seq_val = ep->next_recv_seq_.fetch_add(1, std::memory_order_relaxed) % 4095; uint16_t safe_seq = seq_val + 1; // [1, 4095] - unsigned rid = make_request_id(local_rank_, local_mr.id, safe_seq); + unsigned rid = make_request_id(global_rank_, local_mr.id, safe_seq); auto req = std::make_shared(rid, ptr, offset, len, -1, -1, on_gpu, RequestType::RECV); @@ -411,7 +469,7 @@ unsigned Communicator::irecv_red(int rank, void* ptr, size_t offset, size_t len, uint16_t seq_val = ep->next_recv_seq_.fetch_add(1, std::memory_order_relaxed) % 4095; uint16_t safe_seq = seq_val + 1; // [1, 4095] - unsigned rid = make_request_id(local_rank_, local_mr.id, safe_seq); + unsigned rid = make_request_id(global_rank_, local_mr.id, safe_seq); auto req = std::make_shared(rid, ptr, offset, len, -1, -1, on_gpu, RequestType::RECV, true, red_op); @@ -546,7 +604,7 @@ bool Communicator::check_ready() { // Check meta map size if (static_cast(rank_to_comm_meta_.size()) < world_size_) { - std::cerr << "[WARN] Communicator " << local_rank_ + std::cerr << "[WARN] Communicator " << global_rank_ << " check_ready: rank_to_comm_meta_ size " << rank_to_comm_meta_.size() << " < world_size " << world_size_ << std::endl; @@ -557,13 +615,13 @@ bool Communicator::check_ready() { for (int i = 0; i < world_size_; i++) { auto it = rank_to_comm_meta_.find(i); if (it == rank_to_comm_meta_.end()) { - std::cerr << "[WARN] Communicator " << local_rank_ + std::cerr << "[WARN] Communicator " << global_rank_ << " check_ready: missing CommunicatorMeta for rank " << i << std::endl; return false; } if (!it->second->is_ready) { - std::cerr << "[WARN] Communicator " << local_rank_ + std::cerr << "[WARN] Communicator " << global_rank_ << " check_ready: CommunicatorMeta for rank " << i << " is not ready" << std::endl; return false; @@ -573,12 +631,12 @@ bool Communicator::check_ready() { // Check RDMA NIC context if supported if (support_rdma) { if (!nic_ibv_ctx_) { - std::cerr << "[WARN] Communicator " << local_rank_ + std::cerr << "[WARN] Communicator " << global_rank_ << " check_ready: nic_ibv_ctx_ is nullptr" << std::endl; return false; } if (!pd_) { - std::cerr << "[WARN] Communicator " << local_rank_ + std::cerr << "[WARN] Communicator " << global_rank_ << " check_ready: pd_ is nullptr" << std::endl; return false; } @@ -587,7 +645,7 @@ bool Communicator::check_ready() { } } - std::cerr << "[INFO] Communicator " << local_rank_ << " is ready" + std::cerr << "[INFO] Communicator " << global_rank_ << " is ready" << std::endl; return true; } @@ -633,7 +691,7 @@ bool Communicator::dereg_mr(void* local_buf) { if (mr) { if (ibv_dereg_mr(mr) != 0) { - std::cerr << "[WARN] Communicator " << local_rank_ + std::cerr << "[WARN] Communicator " << global_rank_ << " Failed to deregister local MR" << std::endl; return false; } else { @@ -647,44 +705,65 @@ bool Communicator::dereg_mr(void* local_buf) { bool Communicator::notify_mr(int remote_rank, MR& mr) { if (!exchanger_client_ || !exchanger_client_->valid()) return false; + // we assume that user will connect to remote before notify MR. + auto [ep, ok] = get_endpoint_by_rank(remote_rank); + if (!ok || !ep) { + throw std::runtime_error("Endpoint is not valid"); + return false; + } + if (ep->type != EndpointType::RDMA) { + std::cout << "MR only support for EndpointRDMA, skip notify mr" + << std::endl; + return true; + } + MRInfos wrapper; wrapper.mrs.push_back(mr); std::cout << "[notify MR to rank " << remote_rank << "] addr=" << mr.address << " length=" << mr.length << " key=" << mr.key << std::endl; std::string key = - "mr:" + std::to_string(local_rank_) + "->" + std::to_string(remote_rank); + "mr:" + std::to_string(global_rank_) + "->" + std::to_string(remote_rank); return exchanger_client_->publish(key, wrapper); } -MR Communicator::wait_mr_notify(int remote_rank) { +bool Communicator::wait_mr_notify(int remote_rank, MR& mr) { if (!exchanger_client_ || !exchanger_client_->valid()) { - throw std::runtime_error("Redis client not valid"); + throw std::runtime_error("Exchanger client is not valid"); + } + + auto [ep, ok] = get_endpoint_by_rank(remote_rank); + if (!ok || !ep) { + throw std::runtime_error("Endpoint is not valid"); + } + if (ep->type != EndpointType::RDMA) { + std::cout << "MR only support for EndpointRDMA, skip wait_mr_notify" + << std::endl; + return true; } std::string key = - "mr:" + std::to_string(remote_rank) + "->" + std::to_string(local_rank_); + "mr:" + std::to_string(remote_rank) + "->" + std::to_string(global_rank_); MRInfos wrapper; - bool ok = exchanger_client_->wait_and_fetch(key, wrapper); + ok = exchanger_client_->wait_and_fetch(key, wrapper); if (!ok || wrapper.mrs.empty()) { throw std::runtime_error("Failed to fetch MR from remote rank=" + std::to_string(remote_rank)); } - MR remote_mr = wrapper.mrs[0]; // only support one mr now + mr = wrapper.mrs[0]; // only support one mr now { std::lock_guard lk(remote_mr_mu_); - rank_mr_id_to_remote_mr_[remote_rank][remote_mr.id] = remote_mr; + rank_mr_id_to_remote_mr_[remote_rank][mr.id] = mr; } - std::cout << "[recv MR from rank " << remote_rank - << "] addr=" << remote_mr.address << " length=" << remote_mr.length - << " key=" << remote_mr.key << std::endl; + std::cout << "[recv MR from rank " << remote_rank << "] addr=" << mr.address + << " length=" << mr.length << " key=" << mr.key << std::endl; - return remote_mr; + return true; } MR Communicator::get_local_mr(void* local_buf) { @@ -730,39 +809,25 @@ MR Communicator::get_remote_mr(int remote_rank, uint16_t mr_id) { return it_mr->second; } -// Register a local IPC cache for a buffer -bool Communicator::register_local_ipc_cache(void* local_buf) { - std::lock_guard lock(local_ipc_cache_mu_); - // TODO: open gpu ipc handle - // ptr_to_local_ipc_cache_[local_buf] = cache; - return true; -} - -// Get the IPC cache of a local buffer -IpcCache Communicator::get_local_ipc_cache(void* local_buf) { - std::lock_guard lock(local_ipc_cache_mu_); - auto it = ptr_to_local_ipc_cache_.find(local_buf); - if (it != ptr_to_local_ipc_cache_.end()) return it->second; - return IpcCache{}; -} - // Register a remote IPC cache for a given rank and buffer -bool Communicator::register_remote_ipc_cache(int remote_rank, void* local_buf, +bool Communicator::register_remote_ipc_cache(int remote_rank, + gpuIpcMemHandle_t handle, IpcCache const& cache) { std::lock_guard lock(remote_ipc_cache_mu_); - rank_ptr_to_ipc_cache_[remote_rank][local_buf] = cache; + rank_handle_to_ipc_cache_[remote_rank][MakeHandleKey(handle)] = cache; return true; } // Get the remote IPC cache of a buffer from a given rank -IpcCache Communicator::get_remote_ipc_cache(int remote_rank, void* local_buf) { +IpcCache Communicator::get_remote_ipc_cache(int remote_rank, + gpuIpcMemHandle_t handle) { std::lock_guard lock(remote_ipc_cache_mu_); - auto it_rank = rank_ptr_to_ipc_cache_.find(remote_rank); - if (it_rank != rank_ptr_to_ipc_cache_.end()) { - auto it_buf = it_rank->second.find(local_buf); - if (it_buf != it_rank->second.end()) return it_buf->second; - } - return IpcCache{}; + auto it_rank = rank_handle_to_ipc_cache_.find(remote_rank); + if (it_rank == rank_handle_to_ipc_cache_.end()) return IpcCache{}; + + auto it = it_rank->second.find(MakeHandleKey(handle)); + if (it == it_rank->second.end()) return IpcCache{}; + return it->second; } ibv_cq* Communicator::get_cq_by_index(int index) { diff --git a/experimental/eccl/src/cq_poller.cc b/experimental/eccl/src/cq_poller.cc index 8d2d7f9e0..56905293c 100644 --- a/experimental/eccl/src/cq_poller.cc +++ b/experimental/eccl/src/cq_poller.cc @@ -18,7 +18,7 @@ void CQPoller::start() { if (!running_.compare_exchange_strong(expected, true)) return; // already started thr_ = std::thread(&CQPoller::run_loop, this); - std::cout << "Communicator " << comm_->local_rank_ << " CQPoller with cq " + std::cout << "Communicator " << comm_->global_rank_ << " CQPoller with cq " << cq_ << " started!" << std::endl; } @@ -53,7 +53,7 @@ void CQPoller::stop() { if (!running_.compare_exchange_strong(expected, false)) return; // already stopped or not started if (thr_.joinable()) thr_.join(); - std::cout << "Communicator " << comm_->local_rank_ << " CQPoller with cq " + std::cout << "Communicator " << comm_->global_rank_ << " CQPoller with cq " << cq_ << " Closed!" << std::endl; } diff --git a/experimental/eccl/src/device/Makefile b/experimental/eccl/src/device/Makefile index 248b52353..b24823bee 100644 --- a/experimental/eccl/src/device/Makefile +++ b/experimental/eccl/src/device/Makefile @@ -29,4 +29,4 @@ $(TARGET): $(OBJ_CC) $(OBJ_CU) $(NVCC) $(OBJ_CC) $(OBJ_CU) -o $(TARGET) $(LIBS) $(RPATH) $(NVCCFLAGS) clean: - rm -f $(OBJ_CC) $(OBJ_CU) $(TARGET) + rm -f $(OBJ_CC) $(OBJ_CU) $(TARGET) *.d diff --git a/experimental/eccl/src/device/test_persistent b/experimental/eccl/src/device/test_persistent deleted file mode 100755 index 7cb3bc7da..000000000 Binary files a/experimental/eccl/src/device/test_persistent and /dev/null differ diff --git a/experimental/eccl/src/ipc_endpoint.cc b/experimental/eccl/src/ipc_endpoint.cc index 84ed66988..158fd6157 100644 --- a/experimental/eccl/src/ipc_endpoint.cc +++ b/experimental/eccl/src/ipc_endpoint.cc @@ -1,18 +1,220 @@ #include "transport.h" +#include "util/util.h" IPCEndpoint::IPCEndpoint(std::shared_ptr config, Communicator* comm) - : config_(config), comm_(comm) {} + : config_(config), comm_(comm) { + task_ring_ = uccl::create_ring(sizeof(IpcTask), kTaskRingSize); + stop_.store(false); + proxy_thread_ = std::thread([this] { proxy_thread_func(); }); -IPCEndpoint::~IPCEndpoint() {} + // int n_streams = std::max(1, (int)ucclParamNumGpuRtStreams()); // + // ?ucclParamNumGpuRtStreams + int n_streams = 2; + GPU_RT_CHECK(gpuSetDevice(comm->local_rank_)); + ipc_streams_.resize(n_streams); + for (int i = 0; i < n_streams; ++i) { + GPU_RT_CHECK( + gpuStreamCreateWithFlags(&ipc_streams_[i], gpuStreamNonBlocking)); + } +} + +IPCEndpoint::~IPCEndpoint() { + stop_.store(true); + cv_.notify_all(); + if (proxy_thread_.joinable()) proxy_thread_.join(); + if (task_ring_) { + free(task_ring_); + } +} -bool IPCEndpoint::connect_to(int rank) { return true; } +bool IPCEndpoint::connect_to(int rank) { + return comm_->uds_->connect_to(rank, 30000); +} -bool IPCEndpoint::accept_from(int rank) { return true; } +bool IPCEndpoint::accept_from(int rank) { + return comm_->uds_->accept_from(rank, 30000); +} bool IPCEndpoint::send_async(int to_rank, std::shared_ptr creq) { + if (!creq || creq->len == 0) return false; + creq->pending_signaled.store(1, std::memory_order_relaxed); + creq->running.store(true, std::memory_order_release); + + IpcTask t{IpcTaskType::SEND, to_rank, creq, 0, 0}; + + while (jring_mp_enqueue_bulk(task_ring_, &t, 1, nullptr) != 1) { + std::this_thread::yield(); + } + pending_.fetch_add(1, std::memory_order_relaxed); + cv_.notify_one(); + + // std::cout << "produce IPC send creq to task_ring_" << std::endl; return true; } bool IPCEndpoint::recv_async(int from_rank, std::shared_ptr creq) { + if (!creq || creq->len == 0) return false; + creq->pending_signaled.store(1, std::memory_order_relaxed); + creq->running.store(true, std::memory_order_release); + + IpcTask t{IpcTaskType::RECV, from_rank, creq, 0, 0}; + + while (jring_mp_enqueue_bulk(task_ring_, &t, 1, nullptr) != 1) { + std::this_thread::yield(); + } + pending_.fetch_add(1, std::memory_order_relaxed); + cv_.notify_one(); + + // std::cout << "produce IPC recv creq to task_ring_" << std::endl; return true; } + +bool IPCEndpoint::send_(int to_rank, std::shared_ptr creq) { + CHECK(creq && creq->buf != nullptr) << "send_ipc: data pointer is null!"; + + int orig_device; + GPU_RT_CHECK(gpuGetDevice(&orig_device)); + auto dev_reset = + uccl::finally([&]() { GPU_RT_CHECK(gpuSetDevice(orig_device)); }); + + IpcCacheWire got{}; + uint64_t seq = 0; + if (!comm_->uds_->recv_ipc_cache(to_rank, got, &seq, 50000)) { + std::cerr << "[ERROR] recv_ipc_cache(" << to_rank << ") failed\n"; + return false; + } + + GPU_RT_CHECK(gpuSetDevice(got.remote_gpu_idx_)); + + IpcCache cache = comm_->get_remote_ipc_cache(to_rank, got.handle); + void* base = cache.direct_ptr; + if (base == nullptr) { + // miss: open + register + GPU_RT_CHECK( + gpuIpcOpenMemHandle(&base, got.handle, gpuIpcMemLazyEnablePeerAccess)); + + IpcCache new_cache{}; + new_cache.handle = got.handle; + new_cache.is_send = got.is_send; + new_cache.direct_ptr = base; + new_cache.offset = got.offset; + new_cache.size = got.size; + + comm_->register_remote_ipc_cache(to_rank, got.handle, new_cache); + } + + void* dst_ptr = + reinterpret_cast(reinterpret_cast(base) + got.offset); + + int num_streams = std::min(ipc_streams_.size(), + creq->len < kIpcSizePerEngine + ? 1 + : (size_t)creq->len / kIpcSizePerEngine); + size_t chunk_size = creq->len / num_streams; + + GPU_RT_CHECK(gpuSetDevice(comm_->local_rank_)); + for (int i = 0; i < num_streams; ++i) { + // Split data and dst_ptr into n_streams chunks + void* chunk_data = reinterpret_cast( + reinterpret_cast(creq->buf) + i * chunk_size); + void* chunk_dst_ptr = reinterpret_cast( + reinterpret_cast(dst_ptr) + i * chunk_size); + auto copy_size = + i == num_streams - 1 ? creq->len - i * chunk_size : chunk_size; + // Works for both intra-GPU and inter-GPU copy + GPU_RT_CHECK(gpuMemcpyAsync(chunk_dst_ptr, chunk_data, copy_size, + gpuMemcpyDeviceToDevice, ipc_streams_[i])); + } + + for (auto& stream : ipc_streams_) { + GPU_RT_CHECK(gpuStreamSynchronize(stream)); + } + + // Notify receiver of completion + comm_->uds_->send_ack(to_rank, 0, 1); + + // We close all IPC memory handles later when releasing this endpoint. + return true; +} + +bool IPCEndpoint::recv_(int from_rank, std::shared_ptr creq) { + CHECK(creq && creq->buf != nullptr) << "recv_ipc: data pointer is null!"; + + int orig_device; + GPU_RT_CHECK(gpuGetDevice(&orig_device)); + auto dev_reset = + uccl::finally([&]() { GPU_RT_CHECK(gpuSetDevice(orig_device)); }); + + GPU_RT_CHECK(gpuSetDevice(comm_->local_rank_)); + // Generate IPC memory handle for our receive buffer + IpcCacheWire transfer_info = {}; // Initialize to zero + transfer_info.size = creq->len; + transfer_info.is_send = 0; + transfer_info.remote_gpu_idx_ = comm_->local_rank_; + GPU_RT_CHECK(gpuIpcGetMemHandle(&transfer_info.handle, + reinterpret_cast(creq->buf))); + + // Getting the base address. + void* base = nullptr; + size_t base_size = 0; + GPU_RT_CHECK(gpuMemGetAddressRange(&base, &base_size, creq->buf)); + transfer_info.offset = reinterpret_cast(creq->buf) - + reinterpret_cast(base); + + comm_->uds_->send_ipc_cache(from_rank, 0, transfer_info); + + // Wait Notify of sender's completion + uint32_t status = 0; + uint64_t out_seq = 0; + uint64_t expect_seq = 0; + comm_->uds_->recv_ack(from_rank, &status, &out_seq, 5000, expect_seq); + CHECK_EQ(out_seq, expect_seq) << "Sender reported failure"; + + return true; +} + +void IPCEndpoint::proxy_thread_func() { + while (!stop_.load(std::memory_order_relaxed)) { + // Wait until there is work + { + std::unique_lock lk(cv_mu_); + cv_.wait(lk, [&] { + return stop_.load(std::memory_order_relaxed) || + pending_.load(std::memory_order_relaxed) > 0; + }); + } + if (stop_.load(std::memory_order_relaxed)) break; + + // pop one and run it to completion. + IpcTask t; + if (jring_sc_dequeue_bulk(task_ring_, &t, 1, nullptr) != 1) { + // just continue to next wait. + continue; + } + pending_.fetch_sub(1, std::memory_order_relaxed); + + bool ok = false; + if (t.type == IpcTaskType::SEND) { + std::cout << "consume a IPC send creq from task_ring_" << std::endl; + ok = send_(t.peer_rank, t.req); + } else { + std::cout << "consume a IPC recv creq from task_ring_" << std::endl; + ok = recv_(t.peer_rank, t.req); + } + + // Tasks on an IPCEndpoint are strictly serialized and order-preserving, + // so it is safe to process the current task directly. + if (t.req) { + if (!ok) t.req->failed.store(true, std::memory_order_release); + t.req->running.store(false, std::memory_order_release); + t.req->on_comm_done(ok); + } + } + + // Drain leftover tasks to avoid leaks + while (true) { + IpcTask t; + if (jring_mc_dequeue_bulk(task_ring_, &t, 1, nullptr) != 1) break; + if (t.req) t.req->on_comm_done(false); + } +} diff --git a/experimental/eccl/src/oob_uds.cc b/experimental/eccl/src/oob_uds.cc new file mode 100644 index 000000000..5b8519aee --- /dev/null +++ b/experimental/eccl/src/oob_uds.cc @@ -0,0 +1,559 @@ +#include "oob.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { +constexpr uint32_t kHelloMagic = 0xC0DEF00D; +constexpr uint32_t kMsgMagic = 0x55445331; // "UDS1" +constexpr uint16_t kVersion = 1; + +static inline std::string dirname_of(std::string const& path) { + auto pos = path.find_last_of('/'); + if (pos == std::string::npos) return {}; + return path.substr(0, pos); +} + +static inline void mkdir_best_effort(std::string const& dir) { + if (dir.empty()) return; + ::mkdir(dir.c_str(), 0700); +} + +} // namespace + +UdsExchanger::UdsExchanger(int self_rank) : self_rank_(self_rank) {} + +UdsExchanger::~UdsExchanger() { + running_.store(false, std::memory_order_relaxed); + + // close peers + { + std::lock_guard lk(mu_); + for (auto& kv : rank_to_fd_) { + ::shutdown(kv.second, SHUT_RDWR); + ::close(kv.second); + } + rank_to_fd_.clear(); + rank_send_mu_.clear(); + } + + // close listen + if (listen_fd_ != -1) { + ::shutdown(listen_fd_, SHUT_RDWR); + ::close(listen_fd_); + listen_fd_ = -1; + } + + // unlink sock path + if (!local_path_.empty()) { + ::unlink(local_path_.c_str()); + } +} + +bool UdsExchanger::ensure_server_started() { + if (running_.load(std::memory_order_relaxed)) return true; + + std::lock_guard lk(mu_); + if (running_.load(std::memory_order_relaxed)) return true; + + local_path_ = path_for_rank(self_rank_); + mkdir_best_effort(dirname_of(local_path_)); + ::unlink(local_path_.c_str()); + + listen_fd_ = ::socket(AF_UNIX, SOCK_STREAM, 0); + if (listen_fd_ < 0) { + std::cerr << "[UDS] socket() failed: " << std::strerror(errno) << "\n"; + listen_fd_ = -1; + return false; + } + + sockaddr_un addr{}; + addr.sun_family = AF_UNIX; + std::snprintf(addr.sun_path, sizeof(addr.sun_path), "%s", + local_path_.c_str()); + + if (::bind(listen_fd_, reinterpret_cast(&addr), sizeof(addr)) < + 0) { + std::cerr << "[UDS] bind(" << local_path_ + << ") failed: " << std::strerror(errno) << "\n"; + ::close(listen_fd_); + listen_fd_ = -1; + return false; + } + + if (::listen(listen_fd_, 128) < 0) { + std::cerr << "[UDS] listen() failed: " << std::strerror(errno) << "\n"; + ::close(listen_fd_); + listen_fd_ = -1; + return false; + } + + running_.store(true, std::memory_order_relaxed); + + std::cout << "[UDS] listen() at: " << local_path_ << std::endl; + return true; +} + +bool UdsExchanger::connect_to(int peer_rank, int timeout_ms) { + if (peer_rank == self_rank_) return true; + if (!ensure_server_started()) return false; + + { + std::lock_guard lk(mu_); + if (rank_to_fd_.count(peer_rank)) return true; + } + + const std::string peer_path = path_for_rank(peer_rank); + auto deadline = + std::chrono::steady_clock::now() + std::chrono::milliseconds(timeout_ms); + + while (std::chrono::steady_clock::now() < deadline) { + int fd = -1; + if (connect_once(peer_path, fd)) { + Hello h{kHelloMagic, self_rank_, peer_rank, kVersion}; + if (!send_all(fd, reinterpret_cast(&h), sizeof(h))) { + ::close(fd); + return false; + } + + { + std::lock_guard lk(mu_); + // In case another thread already connected + if (rank_to_fd_.count(peer_rank)) { + ::shutdown(fd, SHUT_RDWR); + ::close(fd); + return true; + } + rank_to_fd_[peer_rank] = fd; + rank_send_mu_[peer_rank] = std::make_unique(); + rank_recv_mu_[peer_rank] = std::make_unique(); + } + return true; + } + + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + } + + return false; +} + +bool UdsExchanger::accept_from(int peer_rank, int timeout_ms) { + if (peer_rank == self_rank_) return true; + if (!ensure_server_started()) return false; + + { + std::lock_guard lk(mu_); + if (rank_to_fd_.count(peer_rank)) return true; + } + + // Ensure only one accept loop at a time + std::unique_lock accept_lk(accept_mu_); + + // re-check after acquiring lock + { + std::lock_guard lk(mu_); + if (rank_to_fd_.count(peer_rank)) return true; + } + + auto deadline = + std::chrono::steady_clock::now() + std::chrono::milliseconds(timeout_ms); + + while (std::chrono::steady_clock::now() < deadline) { + int remaining_ms = + (int)std::chrono::duration_cast( + deadline - std::chrono::steady_clock::now()) + .count(); + if (remaining_ms <= 0) break; + + int fd = accept_with_timeout(remaining_ms); + if (fd == -1) { // timeout + break; + } + if (fd == -2) { // fatal + return false; + } + + // read hello + Hello h{}; + if (!recv_all(fd, reinterpret_cast(&h), sizeof(h))) { + ::close(fd); + continue; + } + if (h.magic != kHelloMagic || h.version != kVersion || + h.to_rank != self_rank_) { + ::close(fd); + continue; + } + + // cache connection (even if not the expected peer) + { + std::lock_guard lk(mu_); + auto it = rank_to_fd_.find(h.from_rank); + if (it != rank_to_fd_.end()) { + // already have; keep existing + ::shutdown(fd, SHUT_RDWR); + ::close(fd); + } else { + rank_to_fd_[h.from_rank] = fd; + rank_send_mu_[h.from_rank] = std::make_unique(); + rank_recv_mu_[h.from_rank] = std::make_unique(); + } + } + + // done if expected is present + { + std::lock_guard lk(mu_); + if (rank_to_fd_.count(peer_rank)) return true; + } + } + + return false; +} + +bool UdsExchanger::send(int peer_rank, uint16_t type, uint64_t seq, + void const* payload, uint32_t bytes) { + int fd = -1; + std::mutex* smu = nullptr; + + { + std::lock_guard lk(mu_); + auto it = rank_to_fd_.find(peer_rank); + if (it == rank_to_fd_.end()) return false; + fd = it->second; + + auto itmu = rank_send_mu_.find(peer_rank); + if (itmu == rank_send_mu_.end() || !itmu->second) return false; + smu = itmu->second.get(); + } + + MsgHdr hdr{}; + hdr.magic = kMsgMagic; + hdr.version = kVersion; + hdr.type = type; + hdr.bytes = bytes; + hdr.from_rank = self_rank_; + hdr.to_rank = peer_rank; + hdr.seq = seq; + + std::lock_guard lk(*smu); + if (!send_all(fd, reinterpret_cast(&hdr), sizeof(hdr))) + return false; + if (bytes > 0 && payload != nullptr) { + if (!send_all(fd, reinterpret_cast(payload), bytes)) + return false; + } + return true; +} + +bool UdsExchanger::send_ipc_cache(int peer_rank, uint64_t seq, + IpcCacheWire const& cache) { + return send(peer_rank, kTypeIpcCache, seq, &cache, + (uint32_t)sizeof(IpcCacheWire)); +} + +bool UdsExchanger::recv_ipc_cache(int peer_rank, IpcCacheWire& out_cache, + uint64_t* out_seq, int timeout_ms) { + int fd = -1; + std::mutex* rmu = nullptr; + + { + std::lock_guard lk(mu_); + auto it = rank_to_fd_.find(peer_rank); + if (it == rank_to_fd_.end()) return false; + fd = it->second; + + auto itmu = rank_recv_mu_.find(peer_rank); + if (itmu == rank_recv_mu_.end() || !itmu->second) return false; + rmu = itmu->second.get(); + } + + // Serialize all recv operations for this peer. + std::unique_lock rlk(*rmu); + + auto deadline = std::chrono::steady_clock::now() + + std::chrono::milliseconds(timeout_ms < 0 ? 0 : timeout_ms); + + auto drain_bytes = [&](uint32_t nbytes) -> bool { + // Drain without allocating huge buffers. + char buf[4096]; + uint32_t left = nbytes; + while (left > 0) { + uint32_t chunk = left > sizeof(buf) ? (uint32_t)sizeof(buf) : left; + if (!recv_all(fd, buf, chunk)) return false; + left -= chunk; + } + return true; + }; + + while (true) { + // wait for readable with timeout + int wait_ms = 0; + if (timeout_ms < 0) { + wait_ms = 1000; // "forever" mode, 1s polling chunks + } else { + auto now = std::chrono::steady_clock::now(); + if (now >= deadline) return false; + wait_ms = (int)std::chrono::duration_cast( + deadline - now) + .count(); + if (wait_ms <= 0) return false; + } + + fd_set rfds; + FD_ZERO(&rfds); + FD_SET(fd, &rfds); + + timeval tv{}; + tv.tv_sec = wait_ms / 1000; + tv.tv_usec = (wait_ms % 1000) * 1000; + + int r = ::select(fd + 1, &rfds, nullptr, nullptr, &tv); + if (r == 0) { + if (timeout_ms < 0) continue; // keep waiting + return false; // timeout + } + if (r < 0) { + if (errno == EINTR) continue; + return false; + } + + // read header + MsgHdr hdr{}; + if (!recv_all(fd, reinterpret_cast(&hdr), sizeof(hdr))) { + return false; + } + + // Only accept IPC cache messages; otherwise drain and continue. + if (hdr.type != kTypeIpcCache) { + if (hdr.bytes > 0 && !drain_bytes(hdr.bytes)) return false; + continue; + } + + if (hdr.bytes != sizeof(IpcCacheWire)) { + if (hdr.bytes > 0 && !drain_bytes(hdr.bytes)) return false; + continue; + } + + // read payload + if (!recv_all(fd, reinterpret_cast(&out_cache), sizeof(out_cache))) { + return false; + } + + if (out_seq) *out_seq = hdr.seq; + return true; + } +} + +bool UdsExchanger::send_ack(int peer_rank, uint64_t seq, uint32_t status) { + AckWire ack{}; + ack.status = status; + ack.reserved = 0; + return send(peer_rank, kTypeAck, seq, &ack, (uint32_t)sizeof(AckWire)); +} + +bool UdsExchanger::recv_ack(int peer_rank, uint32_t* out_status, + uint64_t* out_seq, int timeout_ms, + uint64_t expected_seq) { + int fd = -1; + std::mutex* rmu = nullptr; + + { + std::lock_guard lk(mu_); + auto it = rank_to_fd_.find(peer_rank); + if (it == rank_to_fd_.end()) return false; + fd = it->second; + + auto itmu = rank_recv_mu_.find(peer_rank); + if (itmu == rank_recv_mu_.end() || !itmu->second) return false; + rmu = itmu->second.get(); + } + + // Serialize all recv operations for this peer. + std::unique_lock rlk(*rmu); + + auto deadline = std::chrono::steady_clock::now() + + std::chrono::milliseconds(timeout_ms < 0 ? 0 : timeout_ms); + + auto drain_bytes = [&](uint32_t nbytes) -> bool { + char buf[4096]; + uint32_t left = nbytes; + while (left > 0) { + uint32_t chunk = left > sizeof(buf) ? (uint32_t)sizeof(buf) : left; + if (!recv_all(fd, buf, chunk)) return false; + left -= chunk; + } + return true; + }; + + while (true) { + // wait readable with timeout + int wait_ms = 0; + if (timeout_ms < 0) { + wait_ms = 1000; // forever mode + } else { + auto now = std::chrono::steady_clock::now(); + if (now >= deadline) return false; + wait_ms = (int)std::chrono::duration_cast( + deadline - now) + .count(); + if (wait_ms <= 0) return false; + } + + fd_set rfds; + FD_ZERO(&rfds); + FD_SET(fd, &rfds); + + timeval tv{}; + tv.tv_sec = wait_ms / 1000; + tv.tv_usec = (wait_ms % 1000) * 1000; + + int r = ::select(fd + 1, &rfds, nullptr, nullptr, &tv); + if (r == 0) { + if (timeout_ms < 0) continue; + return false; // timeout + } + if (r < 0) { + if (errno == EINTR) continue; + return false; + } + + // read header + MsgHdr hdr{}; + if (!recv_all(fd, reinterpret_cast(&hdr), sizeof(hdr))) return false; + + // If not ack, drain and continue + if (hdr.type != kTypeAck) { + if (hdr.bytes > 0 && !drain_bytes(hdr.bytes)) return false; + continue; + } + + // optional seq check + if (expected_seq != UINT64_MAX && hdr.seq != expected_seq) { + if (hdr.bytes > 0 && !drain_bytes(hdr.bytes)) return false; + continue; + } + + // validate payload size + if (hdr.bytes != sizeof(AckWire)) { + if (hdr.bytes > 0 && !drain_bytes(hdr.bytes)) return false; + continue; + } + + AckWire ack{}; + if (!recv_all(fd, reinterpret_cast(&ack), sizeof(ack))) return false; + + if (out_status) *out_status = ack.status; + if (out_seq) *out_seq = hdr.seq; + return true; + } +} + +int UdsExchanger::get_fd(int peer_rank) const { + std::lock_guard lk(mu_); + auto it = rank_to_fd_.find(peer_rank); + return it == rank_to_fd_.end() ? -1 : it->second; +} + +void UdsExchanger::close_peer(int peer_rank) { + std::lock_guard lk(mu_); + auto it = rank_to_fd_.find(peer_rank); + if (it == rank_to_fd_.end()) return; + int fd = it->second; + ::shutdown(fd, SHUT_RDWR); + ::close(fd); + rank_to_fd_.erase(it); + rank_send_mu_.erase(peer_rank); + rank_recv_mu_.erase(peer_rank); +} + +bool UdsExchanger::connect_once(std::string const& peer_path, int& out_fd) { + out_fd = -1; + int fd = ::socket(AF_UNIX, SOCK_STREAM, 0); + if (fd < 0) return false; + + sockaddr_un addr{}; + addr.sun_family = AF_UNIX; + std::snprintf(addr.sun_path, sizeof(addr.sun_path), "%s", peer_path.c_str()); + + if (::connect(fd, reinterpret_cast(&addr), sizeof(addr)) == 0) { + out_fd = fd; + return true; + } + + ::close(fd); + return false; +} + +bool UdsExchanger::send_all(int fd, char const* buf, size_t len) { + size_t off = 0; + while (off < len) { + ssize_t n = ::send(fd, buf + off, len - off, 0); + if (n > 0) { + off += (size_t)n; + continue; + } + if (n == 0) return false; + if (errno == EINTR) continue; + return false; + } + return true; +} + +bool UdsExchanger::recv_all(int fd, char* buf, size_t len) { + size_t off = 0; + while (off < len) { + ssize_t n = ::recv(fd, buf + off, len - off, 0); + if (n > 0) { + off += (size_t)n; + continue; + } + if (n == 0) return false; + if (errno == EINTR) continue; + return false; + } + return true; +} + +int UdsExchanger::accept_with_timeout(int timeout_ms) { + if (listen_fd_ < 0) return -2; + + fd_set rfds; + FD_ZERO(&rfds); + FD_SET(listen_fd_, &rfds); + + timeval tv{}; + tv.tv_sec = timeout_ms / 1000; + tv.tv_usec = (timeout_ms % 1000) * 1000; + + int r = ::select(listen_fd_ + 1, &rfds, nullptr, nullptr, &tv); + if (r == 0) return -1; // timeout + if (r < 0) { + if (errno == EINTR) return -1; + return -2; + } + + int fd = ::accept(listen_fd_, nullptr, nullptr); + if (fd < 0) { + if (errno == EINTR) return -1; + return -2; + } + return fd; +} + +std::string UdsExchanger::path_for_rank(int rank) { + // Keep path short: /tmp/eccl_oob_uds/r.sock + char buf[128]; + std::snprintf(buf, sizeof(buf), "/tmp/eccl_oob_uds/r%d.sock", rank); + + std::string path(buf); + mkdir_best_effort(dirname_of(path)); + return path; +} \ No newline at end of file diff --git a/experimental/eccl/src/rdma_endpoint.cc b/experimental/eccl/src/rdma_endpoint.cc index e3a7f9db4..5d80cc76f 100644 --- a/experimental/eccl/src/rdma_endpoint.cc +++ b/experimental/eccl/src/rdma_endpoint.cc @@ -5,17 +5,17 @@ RDMAEndpoint::RDMAEndpoint(std::shared_ptr config, Communicator* comm) : config_(config), comm_(comm) {} RDMAEndpoint::~RDMAEndpoint() { - int local_rank = comm_ ? comm_->local_rank_ : -1; + int rank = comm_ ? comm_->global_rank_ : -1; { std::lock_guard lk(qp_list_mu_); for (size_t i = 0; i < qp_list_.size(); i++) { ibv_qp* qp = qp_list_[i]; if (qp) { if (ibv_destroy_qp(qp)) { - std::cerr << "[WARN] Communicator " << local_rank + std::cerr << "[WARN] Communicator " << rank << " Failed to destroy QP[" << i << "]" << std::endl; } else { - std::cout << "[INFO] Communicator " << local_rank << " QP[" << i + std::cout << "[INFO] Communicator " << rank << " QP[" << i << "] destroyed" << std::endl; } } @@ -31,12 +31,12 @@ RDMAEndpoint::~RDMAEndpoint() { remote_qp_info_list_.clear(); } - std::cout << "[INFO] Communicator " << local_rank + std::cout << "[INFO] Communicator " << rank << " RDMAEndpoint resources released" << std::endl; } bool RDMAEndpoint::connect_to(int peer_rank) { - int local_rank = comm_->local_rank_; + int rank = comm_->global_rank_; // Create QPs for (int i = 0; i < config_->qp_count_per_ep; i++) { @@ -55,8 +55,8 @@ bool RDMAEndpoint::connect_to(int peer_rank) { std::lock_guard lock(qp_list_mu_); ibv_qp* new_qp = ibv_create_qp(comm_->pd_, &qp_init_attr); if (!new_qp) { - std::cerr << "[ERROR] Communicator " << local_rank - << " Failed to create QP" << std::endl; + std::cerr << "[ERROR] Communicator " << rank << " Failed to create QP" + << std::endl; return false; } qp_list_.push_back(new_qp); @@ -82,7 +82,7 @@ bool RDMAEndpoint::connect_to(int peer_rank) { for (int i = 0; i < config_->qp_count_per_ep; i++) { if (ibv_modify_qp(qp_list_[i], &attr, flags)) { - std::cerr << "[ERROR] Communicator " << local_rank + std::cerr << "[ERROR] Communicator " << rank << " Failed to modify QP to INIT: " << strerror(errno) << std::endl; return false; @@ -96,9 +96,9 @@ bool RDMAEndpoint::connect_to(int peer_rank) { local_info.qps = qp_info_list_; } std::string key = - "qpinfo:" + std::to_string(local_rank) + "->" + std::to_string(peer_rank); + "qpinfo:" + std::to_string(rank) + "->" + std::to_string(peer_rank); if (!comm_->exchanger_client_->publish(key, local_info)) { - std::cerr << "[ERROR] Communicator " << local_rank + std::cerr << "[ERROR] Communicator " << rank << " Failed to publish QP info for peer " << peer_rank << std::endl; return false; @@ -106,9 +106,9 @@ bool RDMAEndpoint::connect_to(int peer_rank) { RDMAInfo remote_info; std::string peer_key = - "qpinfo:" + std::to_string(peer_rank) + "->" + std::to_string(local_rank); + "qpinfo:" + std::to_string(peer_rank) + "->" + std::to_string(rank); if (!comm_->exchanger_client_->wait_and_fetch(peer_key, remote_info, -1)) { - std::cerr << "[ERROR] Communicator " << local_rank + std::cerr << "[ERROR] Communicator " << rank << " Timeout waiting QP info from peer " << peer_rank << std::endl; return false; @@ -124,9 +124,8 @@ bool RDMAEndpoint::connect_to(int peer_rank) { } } - std::cout << "[INFO] Communicator " << local_rank - << " QP info exchanged with rank " << peer_rank - << ", local_qps=" << local_info.qps.size() + std::cout << "[INFO] Communicator " << rank << " QP info exchanged with rank " + << peer_rank << ", local_qps=" << local_info.qps.size() << ", remote_qps=" << remote_info.qps.size() << std::endl; // Modify QPs to RTR @@ -159,12 +158,12 @@ bool RDMAEndpoint::connect_to(int peer_rank) { IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER; if (ibv_modify_qp(qp_list_[i], &attr, rtr_flags)) { - std::cerr << "[ERROR] Communicator " << local_rank + std::cerr << "[ERROR] Communicator " << rank << " Failed to modify QP to RTR: " << strerror(errno) << std::endl; return false; } - std::cout << "[INFO] Communicator " << local_rank << " QP[" << i + std::cout << "[INFO] Communicator " << rank << " QP[" << i << "] modified to RTR state" << std::endl; } @@ -181,12 +180,12 @@ bool RDMAEndpoint::connect_to(int peer_rank) { int rts_flags = IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC; if (ibv_modify_qp(qp_list_[i], &attr, rts_flags)) { - std::cerr << "[ERROR] Communicator " << local_rank + std::cerr << "[ERROR] Communicator " << rank << " Failed to modify QP to RTS: " << strerror(errno) << std::endl; return false; } - std::cout << "[INFO] Communicator " << local_rank << " QP[" << i + std::cout << "[INFO] Communicator " << rank << " QP[" << i << "] modified to RTS state" << std::endl; } diff --git a/experimental/eccl/tests/test.h b/experimental/eccl/tests/test.h index 80cee8139..cfed84292 100644 --- a/experimental/eccl/tests/test.h +++ b/experimental/eccl/tests/test.h @@ -2,10 +2,11 @@ #include -void test_communicator(); +void test_communicator(int argc, char** argv); void test_cq_poller(); void test_find_best_rdma_for_gpu(int gpu_id); void test_redis_oob(); void test_generate_host_id(); +void test_uds_oob(); void test_redis_meta_exchange_multi_threads(int world_size); void test_socket_meta_exchange_multi_threads(int world_size); \ No newline at end of file diff --git a/experimental/eccl/tests/test_communicator.cc b/experimental/eccl/tests/test_communicator.cc index 9077dfc44..25a334952 100644 --- a/experimental/eccl/tests/test_communicator.cc +++ b/experimental/eccl/tests/test_communicator.cc @@ -1,60 +1,158 @@ -#include "test.h" #include "transport.h" +#include "util/util.h" +#include +#include #include +#include #include +#include -void communicator_client_thread() { - { - auto comm = std::make_shared( - 0, 0, 2, - std::make_shared()); // gpu_id=0, local_rank=0, world_size=2 - std::cout << "[CLIENT] Communicator for rank 0 created." << std::endl; - - int peer_rank = 1; - if (comm->connect_to(peer_rank)) { - std::cout << "[CLIENT] Successfully connected to rank " << peer_rank - << std::endl; - } else { - std::cerr << "[CLIENT] Failed to connect to rank " << peer_rank - << std::endl; - } +static constexpr int kWorldSize = 2; +static constexpr int client_gpu = 0; +static constexpr int server_gpu = 0; +static constexpr int client_rank = 1; +static constexpr int server_rank = 0; + +namespace { +constexpr size_t kBytes = 4 * 1024; - comm.reset(); - std::cout << "[CLIENT] Communicator destroyed and resources released." - << std::endl; +static void fill_pattern(std::vector& buf) { + for (size_t i = 0; i < buf.size(); ++i) + buf[i] = static_cast(i & 0xFF); +} + +static bool check_pattern(std::vector const& buf) { + for (size_t i = 0; i < buf.size(); ++i) { + if (buf[i] != static_cast(i & 0xFF)) { + std::cerr << "[SERVER] mismatch at " << i << "\n"; + return false; + } } + return true; } -void communicator_server_thread() { - { - auto comm = std::make_shared( - 0, 1, 2, - std::make_shared()); // gpu_id=0, local_rank=1, world_size=2 - std::cout << "[SERVER] Communicator for rank 1 created." << std::endl; - - int peer_rank = 0; - if (comm->accept_from(peer_rank)) { - std::cout << "[SERVER] Successfully accepted connection from rank " - << peer_rank << std::endl; - } else { - std::cerr << "[SERVER] Failed to accept connection from rank " - << peer_rank << std::endl; +static std::string get_arg(int argc, char** argv, char const* key, + char const* def) { + // --role=server --role server + for (int i = 1; i < argc; ++i) { + if (std::strncmp(argv[i], key, std::strlen(key)) == 0) { + char const* p = argv[i] + std::strlen(key); + if (*p == '=') return std::string(p + 1); + if (*p == '\0' && i + 1 < argc) return std::string(argv[i + 1]); } + } + return std::string(def); +} +} // namespace + +static int run_client() { + auto cfg = std::make_shared(); + auto comm = + std::make_shared(client_gpu, client_rank, kWorldSize, cfg); + + int peer_rank = server_rank; + if (!comm->connect_to(peer_rank)) { + std::cerr << "[CLIENT] connect_to failed\n"; + return 2; + } + + // host pattern + std::vector sendbuf_h(kBytes); + fill_pattern(sendbuf_h); + + // device buffer + GPU_RT_CHECK(gpuSetDevice(client_gpu)); + void* sendbuf_d = nullptr; + GPU_RT_CHECK(gpuMalloc(&sendbuf_d, kBytes)); + auto send_free = uccl::finally([&] { + if (sendbuf_d) GPU_RT_CHECK(gpuFree(sendbuf_d)); + }); + + GPU_RT_CHECK( + gpuMemcpy(sendbuf_d, sendbuf_h.data(), kBytes, gpuMemcpyHostToDevice)); + + // only RDMA + MR local_mr = comm->reg_mr(sendbuf_d, kBytes); + if (!comm->notify_mr(peer_rank, local_mr)) { + return 3; + } + MR remote_mr; + if (!comm->wait_mr_notify(peer_rank, remote_mr)) { + return 3; + } + + unsigned sreq = + comm->isend(peer_rank, sendbuf_d, 0, kBytes, local_mr.id, remote_mr.id, + /*on_gpu*/ true); - comm.reset(); - std::cout << "[SERVER] Communicator destroyed and resources released." - << std::endl; + if (!comm->wait_finish(sreq)) { + std::cerr << "[CLIENT] wait_finish(send) failed\n"; + return 4; } + + std::cout << "[CLIENT] Send done\n"; + comm.reset(); + return 0; +} + +static int run_server() { + auto cfg = std::make_shared(); + auto comm = + std::make_shared(server_gpu, server_rank, kWorldSize, cfg); + + int peer_rank = client_rank; + if (!comm->accept_from(peer_rank)) { + std::cerr << "[SERVER] accept_from failed\n"; + return 2; + } + + GPU_RT_CHECK(gpuSetDevice(server_gpu)); + void* recvbuf_d = nullptr; + GPU_RT_CHECK(gpuMalloc(&recvbuf_d, kBytes)); + auto recv_free = uccl::finally([&] { + if (recvbuf_d) GPU_RT_CHECK(gpuFree(recvbuf_d)); + }); + + MR local_mr = comm->reg_mr(recvbuf_d, kBytes); + if (!comm->notify_mr(peer_rank, local_mr)) { + return 3; + } + MR remote_mr; + if (!comm->wait_mr_notify(peer_rank, remote_mr)) { + return 3; + } + + unsigned rreq = comm->irecv(peer_rank, recvbuf_d, 0, kBytes, + /*on_gpu*/ true); + + if (!comm->wait_finish(rreq)) { + std::cerr << "[SERVER] wait_finish(recv) failed\n"; + return 4; + } + + // copy back and check + std::vector recvbuf_h(kBytes, 0); + GPU_RT_CHECK( + gpuMemcpy(recvbuf_h.data(), recvbuf_d, kBytes, gpuMemcpyDeviceToHost)); + + bool ok = check_pattern(recvbuf_h); + std::cout << (ok ? "[SERVER] OK\n" : "[SERVER] FAILED\n"); + + comm.reset(); + return ok ? 0 : 5; } -void test_communicator() { - std::thread t_client(communicator_client_thread); - std::thread t_server(communicator_server_thread); +int test_communicator(int argc, char** argv) { + std::string role = get_arg(argc, argv, "--role", ""); + if (role.empty()) role = get_arg(argc, argv, "-r", ""); - t_client.join(); - t_server.join(); + if (role != "server" && role != "client") { + std::cerr << "Usage:\n" + << " " << argv[0] << " --role=server\n" + << " " << argv[0] << " --role=client\n"; + return 1; + } - std::cout - << "[TEST] RDMA QP connection test finished, all threads exited cleanly." - << std::endl; + if (role == "server") return run_server(); + return run_client(); } diff --git a/experimental/eccl/tests/test_cq_poller.cc b/experimental/eccl/tests/test_cq_poller.cc index 60fee47f3..dd15f35f5 100644 --- a/experimental/eccl/tests/test_cq_poller.cc +++ b/experimental/eccl/tests/test_cq_poller.cc @@ -29,7 +29,10 @@ void cqpoller_client_thread(std::shared_ptr comm, int peer_rank) { GPU_RT_CHECK(gpuMemcpy(d_data, h_data.data(), size, gpuMemcpyHostToDevice)); auto mr = comm->reg_mr(d_data, size); - auto remote_mr = comm->wait_mr_notify(peer_rank); + MR remote_mr; + if (!comm->wait_mr_notify(peer_rank, remote_mr)) { + return; + } std::cout << "[CLIENT] Got remote MR id=" << remote_mr.id << " addr=0x" << std::hex << remote_mr.address << " len=" << std::dec @@ -103,10 +106,8 @@ void cqpoller_server_thread(std::shared_ptr comm, int peer_rank) { } void test_cq_poller() { - auto comm0 = - std::make_shared(0, 0, 2); // gpu_id=0, local_rank=0 - auto comm1 = - std::make_shared(0, 1, 2); // gpu_id=0, local_rank=1 + auto comm0 = std::make_shared(0, 0, 2); // gpu_id=0, rank=0 + auto comm1 = std::make_shared(0, 1, 2); // gpu_id=0, rank=1 std::thread t_client(cqpoller_client_thread, comm0, 1); std::thread t_server(cqpoller_server_thread, comm1, 0); diff --git a/experimental/eccl/tests/test_main.cc b/experimental/eccl/tests/test_main.cc index fb4a909cd..db50136cd 100644 --- a/experimental/eccl/tests/test_main.cc +++ b/experimental/eccl/tests/test_main.cc @@ -1,14 +1,15 @@ #include "test.h" #include -int main() { +int main(int argc, char** argv) { // test_find_best_rdma_for_gpu(0); // test_find_best_rdma_for_gpu(2); // test_find_best_rdma_for_gpu(3); // test_find_best_rdma_for_gpu(5); - test_communicator(); + test_communicator(argc, argv); // test_redis_oob(); + // test_uds_oob(); // test_generate_host_id(); diff --git a/experimental/eccl/tests/test_oob_redis.cc b/experimental/eccl/tests/test_oob_redis.cc index 6edd8a08f..ced015f69 100644 --- a/experimental/eccl/tests/test_oob_redis.cc +++ b/experimental/eccl/tests/test_oob_redis.cc @@ -83,44 +83,44 @@ void test_redis_oob() { << std::endl; } -void rank_thread(int local_rank, int world_size, - std::string const& exchanger_ip, int exchanger_port) { +void rank_thread(int rank, int world_size, std::string const& exchanger_ip, + int exchanger_port) { auto ex = std::make_shared(exchanger_ip, exchanger_port); if (!ex->valid()) { - std::cerr << "[ERROR] Rank " << local_rank << " failed to connect to Redis" + std::cerr << "[ERROR] Rank " << rank << " failed to connect to Redis" << std::endl; return; } CommunicatorMeta local; - local.host_id = generate_host_id() + "_" + std::to_string(local_rank); + local.host_id = generate_host_id() + "_" + std::to_string(rank); local.is_ready = true; - std::string key = "meta:" + std::to_string(local_rank); + std::string key = "meta:" + std::to_string(rank); if (!ex->publish(key, local)) { - std::cerr << "[ERROR] Rank " << local_rank - << " failed to publish meta to key " << key << std::endl; + std::cerr << "[ERROR] Rank " << rank << " failed to publish meta to key " + << key << std::endl; return; } - std::cout << "[INFO] Rank " << local_rank << " published meta to key " << key + std::cout << "[INFO] Rank " << rank << " published meta to key " << key << std::endl; for (int r = 0; r < world_size; ++r) { - if (r == local_rank) continue; + if (r == rank) continue; std::string remote_key = "meta:" + std::to_string(r); CommunicatorMeta remote; if (ex->wait_and_fetch(remote_key, remote, 50, 100)) { - std::cout << "[INFO] Rank " << local_rank << " fetched meta for rank " - << r << ", host_id=" << remote.host_id + std::cout << "[INFO] Rank " << rank << " fetched meta for rank " << r + << ", host_id=" << remote.host_id << ", is_ready=" << remote.is_ready << std::endl; } else { - std::cerr << "[WARN] Rank " << local_rank + std::cerr << "[WARN] Rank " << rank << " timeout waiting for meta of rank " << r << std::endl; } } - std::cout << "[INFO] Rank " << local_rank << " completed meta exchange" + std::cout << "[INFO] Rank " << rank << " completed meta exchange" << std::endl; } diff --git a/experimental/eccl/tests/test_oob_socket.cc b/experimental/eccl/tests/test_oob_socket.cc index 6b7db8bb5..0ae3e152b 100644 --- a/experimental/eccl/tests/test_oob_socket.cc +++ b/experimental/eccl/tests/test_oob_socket.cc @@ -79,44 +79,43 @@ void test_socket_oob() { std::cout << "[INFO] RDMA socket OOB test complete\n"; } -void rank_thread_socket(int local_rank, int world_size, std::string const& ip, +void rank_thread_socket(int rank, int world_size, std::string const& ip, int port) { - bool is_server = (local_rank == 0); + bool is_server = (rank == 0); auto ex = std::make_shared(is_server, ip, port); if (!ex->valid()) { - std::cerr << "[ERROR] Rank " << local_rank - << " failed to init SockExchanger\n"; + std::cerr << "[ERROR] Rank " << rank << " failed to init SockExchanger\n"; return; } CommunicatorMeta local; - local.host_id = generate_host_id() + "_" + std::to_string(local_rank); + local.host_id = generate_host_id() + "_" + std::to_string(rank); local.ip = "127.0.0.1"; local.is_ready = true; - std::string key = "meta:" + std::to_string(local_rank); + std::string key = "meta:" + std::to_string(rank); if (!ex->publish(key, local)) { - std::cerr << "[ERROR] Rank " << local_rank << " failed to publish meta\n"; + std::cerr << "[ERROR] Rank " << rank << " failed to publish meta\n"; return; } - std::cout << "[INFO] Rank " << local_rank << " published meta (" - << local.host_id << ")\n"; + std::cout << "[INFO] Rank " << rank << " published meta (" << local.host_id + << ")\n"; for (int r = 0; r < world_size; ++r) { - if (r == local_rank) continue; + if (r == rank) continue; std::string remote_key = "meta:" + std::to_string(r); CommunicatorMeta remote; if (ex->wait_and_fetch(remote_key, remote, 50, 100)) { - std::cout << "[INFO] Rank " << local_rank << " fetched meta for rank " - << r << " host_id=" << remote.host_id << " ip=" << remote.ip + std::cout << "[INFO] Rank " << rank << " fetched meta for rank " << r + << " host_id=" << remote.host_id << " ip=" << remote.ip << " ready=" << remote.is_ready << "\n"; } else { - std::cerr << "[WARN] Rank " << local_rank + std::cerr << "[WARN] Rank " << rank << " timeout waiting for meta of rank " << r << "\n"; } } - std::cout << "[INFO] Rank " << local_rank << " completed meta exchange\n"; + std::cout << "[INFO] Rank " << rank << " completed meta exchange\n"; } void test_socket_meta_exchange_multi_threads(int world_size) { diff --git a/experimental/eccl/tests/test_oob_uds.cc b/experimental/eccl/tests/test_oob_uds.cc new file mode 100644 index 000000000..5f1f3e4db --- /dev/null +++ b/experimental/eccl/tests/test_oob_uds.cc @@ -0,0 +1,89 @@ +#include "oob.h" +#include "test.h" +#include +#include +#include +#include + +void test_uds_oob() { + std::cout << "[TEST] UDS OOB test start\n"; + + std::thread rank0([&]() { + UdsExchanger uds0(/*self_rank=*/0); + + if (!uds0.accept_from(/*peer_rank=*/1, /*timeout_ms=*/5000)) { + std::cerr << "[ERROR] rank0 accept_from(1) failed\n"; + return; + } + std::cout << "[INFO] rank0 accepted connection from rank1\n"; + + IpcCacheWire got{}; + uint64_t seq = 0; + if (!uds0.recv_ipc_cache(/*peer_rank=*/1, got, &seq, /*timeout_ms=*/5000)) { + std::cerr << "[ERROR] rank0 recv_ipc_cache(1) failed\n"; + return; + } + + std::cout << "[INFO] rank0 received IpcCacheWire" + << " seq=" << seq << " is_send=" << (int)got.is_send + << " offset=" << got.offset << " size=" << got.size << "\n"; + + bool ok = true; + ok &= (seq == 42); + ok &= (got.is_send == 1); + ok &= (got.offset == 0x12345678ULL); + ok &= (got.size == 4096ULL); + + // handle is opaque; here we just sanity-check that it matches the pattern + // we set + const uint8_t* hb = reinterpret_cast(&got.handle); + ok &= (hb[0] == 0xA0); + ok &= (hb[1] == 0xA1); + + if (!ok) { + std::cerr << "[ERROR] rank0 validation failed\n"; + return; + } + + std::cout << "[INFO] rank0 validation OK\n"; + }); + + std::thread rank1([&]() { + // small delay so rank0 can start listening; connect_to has retry anyway + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + + UdsExchanger uds1(/*self_rank=*/1); + + if (!uds1.connect_to(/*peer_rank=*/0, /*timeout_ms=*/5000)) { + std::cerr << "[ERROR] rank1 connect_to(0) failed\n"; + return; + } + std::cout << "[INFO] rank1 connected to rank0\n"; + + IpcCacheWire w{}; + std::memset(&w.handle, 0, sizeof(w.handle)); + // Fill recognizable pattern into handle bytes (no CUDA call needed for unit + // test) + uint8_t* hb = reinterpret_cast(&w.handle); + for (size_t i = 0; i < sizeof(gpuIpcMemHandle_t); ++i) { + hb[i] = static_cast(0xA0 + (i & 0x0F)); + } + + w.is_send = 1; + w.offset = 0x12345678ULL; + w.size = 4096ULL; + + const uint64_t seq = 42; + if (!uds1.send_ipc_cache(/*peer_rank=*/0, seq, w)) { + std::cerr << "[ERROR] rank1 send_ipc_cache failed\n"; + return; + } + + std::cout << "[INFO] rank1 sent IpcCacheWire seq=" << seq << "\n"; + }); + + rank0.join(); + rank1.join(); + + std::cout << "[TEST] UDS OOB test completed\n"; +} diff --git a/p2p/Makefile b/p2p/Makefile index 4ef3137cd..7c2004690 100644 --- a/p2p/Makefile +++ b/p2p/Makefile @@ -4,6 +4,8 @@ USE_TCPX ?= $(shell echo $${USE_TCPX:-0}) # EFA optional integration USE_EFA ?= $(shell echo $${USE_EFA:-0}) +# IB optional integration +USE_IB ?= $(shell echo $${USE_IB:-0}) # Compiler and flags CUDA_HOME ?= /usr/local/cuda @@ -13,22 +15,25 @@ RDMA_HOME := ../collective/rdma EFA_HOME ?= /opt/amazon/efa # Building with the following settings: -INC = -I./ -I$(CUDA_HOME)/include +INC = -I./ -I$(CUDA_HOME)/include LIBS = -L ${CUDA_HOME}/lib64 -libverbs -lcudart -lcuda -lpthread -lglog -lgflags -lgtest -lz -lelf LIBS2 = CXX := g++ -CXXFLAGS := -O3 -shared -std=c++17 -fPIC -I../include -I$(RDMA_HOME) -I$(CUDA_INC) \ +CXXFLAGS := -O3 -shared -std=c++17 -fPIC -I. -I../include -I$(RDMA_HOME) -I$(CUDA_INC) \ -Wno-pointer-arith -Wno-sign-compare -Wno-unused-variable \ - -Wl,-rpath=/usr/lib/x86_64-linux-gnu $(CXXFLAGS_EFA) + -Wl,-rpath=/usr/lib/x86_64-linux-gnu ifeq ($(USE_TCPX),1) CXXFLAGS += -DUCCL_P2P_USE_TCPX LIBS2 += -lglog -lgflags -lpthread -ldl else ifeq ($(USE_EFA),1) - CXXFLAGS += -DUCCL_P2P_USE_EFA + CXXFLAGS += -DUCCL_P2P_USE_NATIVE_RDMA INC += -I${EFA_HOME}/include LIBS += -L ${EFA_HOME}/lib -lefa +else ifeq ($(USE_IB),1) + CXXFLAGS += -DUCCL_P2P_USE_NATIVE_RDMA -DUCCL_P2P_USE_IB + LIBS += -libverbs else LIBS2 += -lglog -lgflags -lgtest -lz -lelf -lpthread -libverbs endif diff --git a/p2p/Makefile.rocm b/p2p/Makefile.rocm index a22459125..8633611bc 100644 --- a/p2p/Makefile.rocm +++ b/p2p/Makefile.rocm @@ -1,5 +1,10 @@ # Makefile for UCCL P2P Engine pybind11 project +# EFA optional integration +USE_EFA ?= $(shell echo $${USE_EFA:-0}) +# IB optional integration +USE_IB ?= $(shell echo $${USE_IB:-0}) + # Compiler and flags HIP_HOME?=/opt/rocm HIP_INC := $(HIP_HOME)/include @@ -7,9 +12,17 @@ HIP_LIB := $(HIP_HOME)/lib CONDA_LIB_HOME?=/usr/lib RDMA_HOME := ../collective/rdma CXX := g++ -CXXFLAGS := -O3 -shared -std=c++17 -fPIC -I../include -I$(RDMA_HOME) -I$(HIP_INC) \ + +ifeq ($(USE_EFA),1) + CXXFLAGS_TRANS := -DUCCL_P2P_USE_NATIVE_RDMA +else ifeq ($(USE_IB),1) + CXXFLAGS_TRANS := -DUCCL_P2P_USE_NATIVE_RDMA -DUCCL_P2P_USE_IB +endif + +CXXFLAGS := -O3 -shared -std=c++17 -fPIC -I. -I../include -I$(RDMA_HOME) -I$(HIP_INC) \ -Wno-pointer-arith -Wno-sign-compare -Wno-unused-variable \ - -Wl,-rpath=/usr/lib/x86_64-linux-gnu -Wl,-rpath=${CONDA_LIB_HOME} -I${CONDA_LIB_HOME}/../include -L${CONDA_LIB_HOME} -lglog -lgflags -lgtest -lz -lelf -libverbs -lpthread + -Wl,-rpath=/usr/lib/x86_64-linux-gnu -Wl,-rpath=${CONDA_LIB_HOME} -I${CONDA_LIB_HOME}/../include \ + $(CXXFLAGS_TRANS) # Python and pybind11 configuration PYTHON ?= python3 @@ -27,7 +40,7 @@ INCDIR ?= $(PREFIX)/include CXXFLAGS += -D__HIP_PLATFORM_AMD__ LDFLAGS = -L$(HIP_LIB) -lamdhip64 \ - -Wl,-rpath,$(HIP_LIB) -I${CONDA_LIB_HOME}/../include -L${CONDA_LIB_HOME} -lglog -lgflags -lgtest \ + -Wl,-rpath,$(HIP_LIB) -L${CONDA_LIB_HOME} -lglog -lgflags -lgtest \ -lz -lelf -libverbs -lpthread # Target and source files @@ -54,7 +67,7 @@ $(P2P_SHARED_LIB): $(CAPI_OBJECT) $(CORE_OBJECT) $(RDMA_PLUGIN_LIB) $(CXX) $(CAPI_OBJECT) $(CORE_OBJECT) $(RDMA_PLUGIN_LIB) \ -L$(HIP_LIB) -lamdhip64 \ -o $@ $(LDFLAGS) $(PYTHON_LDFLAGS) $(CXXFLAGS) \ - -Wl,-rpath,$(HIP_LIB) $(PYTHON_LIB) -libverbs -lz -lelf + -Wl,-rpath,$(HIP_LIB) $(PYTHON_LIB) $(RDMA_HOME)/librdma_hip.a: $(filter-out $(RDMA_HOME)/nccl_plugin.cc %_test.cc, $(wildcard $(RDMA_HOME)/*.cc)) $(RDMA_HOME)/*.h make -C $(RDMA_HOME) -f Makefile.rocm librdma_hip.a -j$(nproc) diff --git a/p2p/README.md b/p2p/README.md index 40b1610d5..53e284dfd 100644 --- a/p2p/README.md +++ b/p2p/README.md @@ -71,18 +71,16 @@ To build GCP TCPX support, you can refer to [NIXL_plugin_readme.md](./NIXL_plugi ## Performance Benchmarks -Navigate to `benchmarks` directory: - ### Running UCCL P2P On client: ```bash -torchrun --nnodes=2 --nproc_per_node=1 --node-rank=0 --master_addr= benchmark_uccl.py +torchrun --nnodes=2 --nproc_per_node=1 --node-rank=0 --master_addr= benchmarks/benchmark_uccl.py ``` On server: ```bash -torchrun --nnodes=2 --nproc_per_node=1 --node-rank=1 --master_addr= benchmark_uccl.py +torchrun --nnodes=2 --nproc_per_node=1 --node-rank=1 --master_addr= benchmarks/benchmark_uccl.py ``` Notes: @@ -90,28 +88,26 @@ Notes: * You may consider exporting `UCCL_IB_GID_INDEX` if your cluster requires it for NCCL to run (usually 1, or 3 in some testbed). * **You must first import `torch` before importing `uccl.p2p` for AMD GPUs**, otherwise, `RuntimeError: No HIP GPUs are available` will occur. We guess this is because torch does some extra init for AMD GPUs, in order for Pybind-C++ code to work. * To benchmark dual direction transfer, `benchmark_uccl.py --dual`. -* To benchmark intra-node transfer via CUDA/HIP IPC, `torchrun --nproc_per_node=2 benchmark_uccl.py --ipc`. +* To benchmark intra-node transfer via CUDA/HIP IPC, `torchrun --nproc_per_node=2 benchmarks/benchmark_uccl.py --ipc`. * To benchmark one-sided READ/WRITE transfer, `benchmark_uccl_readwrite.py`. * To benchmark UCCL copy-only collectives (eg, sendrecv, allgather), `benchmark_uccl_collective.py`. You can also run ring-like communication pattern with `--ring`. * From CollectiveContext, the default parameter `use_copy_engine_for_intra` is `False`, which means it will use NCCL/RCCL via `torch.distributed` for intra-node communication; if setting to `True`, it will use GPU copy engine (eg, `cudaMemcpy`) via UCCL for intranode communication. | Environment Variable | Description | Default Value | |---------------------|-------------|---------------| -| UCCL_IB_HCA | The names of IB devices used | (null) | -| UCCL_IB_GID_INDEX | Global ID index used in RoCE mode | -1 | -| UCCL_PORT_ENTROPY | Paths/QPs per engine | 32 | -| UCCL_CHUNK_SIZE_KB | Maximum chunk size for each WQE | 64 | +| UCCL_P2P_LOG_LEVEL | Logging level | WARNING (others: INFO, ERROR, FATAL) | +| UCCL_IB_GID_INDEX | Global ID index in RDMA network | 0 | ### Running NCCL On Client: ```bash -NCCL_NCHANNELS_PER_NET_PEER=4 torchrun --nnodes=2 --nproc_per_node=1 --node-rank=0 --master_addr= benchmark_nccl.py +NCCL_NCHANNELS_PER_NET_PEER=4 torchrun --nnodes=2 --nproc_per_node=1 --node-rank=0 --master_addr= benchmarks/benchmark_nccl.py ``` On Server: ```bash -NCCL_NCHANNELS_PER_NET_PEER=4 torchrun --nnodes=2 --nproc_per_node=1 --node-rank=1 --master_addr= benchmark_nccl.py +NCCL_NCHANNELS_PER_NET_PEER=4 torchrun --nnodes=2 --nproc_per_node=1 --node-rank=1 --master_addr= benchmarks/benchmark_nccl.py ``` Notes: diff --git a/p2p/benchmarks/benchmark_uccl_alltoall.py b/p2p/benchmarks/benchmark_uccl_alltoall.py index 7cedcb460..2c858e8c6 100644 --- a/p2p/benchmarks/benchmark_uccl_alltoall.py +++ b/p2p/benchmarks/benchmark_uccl_alltoall.py @@ -12,7 +12,7 @@ """ Benchmark UCCL Collective for Alltoall -NCCL_IB_GID_INDEX=3 UCCL_ENTROPY=2 UCCL_CHUNK_SIZE_KB=64 torchrun --nnodes=2 --nproc_per_node=1 --node-rank=1 --master_addr=10.21.9.41 --master_port=19999 benchmark_uccl_alltoall.py --block-size 1024 4096 16384 65536 264114 --num-qo-heads 32 --gqa-group-size 4 --head-dim 128 --num-iters 100 +UCCL_IB_GID_INDEX=3 torchrun --nnodes=2 --nproc_per_node=1 --node-rank=1 --master_addr=10.21.9.41 --master_port=19999 benchmark_uccl_alltoall.py --block-size 1024 4096 16384 65536 264114 --num-qo-heads 32 --gqa-group-size 4 --head-dim 128 --num-iters 100 """ diff --git a/p2p/benchmarks/pynccl.py b/p2p/benchmarks/pynccl.py index f0a0eb1a3..e1e9bba9e 100644 --- a/p2p/benchmarks/pynccl.py +++ b/p2p/benchmarks/pynccl.py @@ -515,7 +515,6 @@ def ncclGroupEnd(self) -> None: class PyNcclCommunicator: - def __init__( self, group: ProcessGroup, diff --git a/p2p/benchmarks/worker.sh b/p2p/benchmarks/worker.sh index 18c286e56..6097d859b 100644 --- a/p2p/benchmarks/worker.sh +++ b/p2p/benchmarks/worker.sh @@ -60,9 +60,6 @@ export NCCL_BUFFSIZE=1048576 export NCCL_IB_SPLIT_DATA_ON_QPS=0 export NCCL_IB_PCI_RELAXED_ORDERING=1 -# UCCL envs -export UCCL_ENTROPY=16 -export UCCL_CHUNK_SIZE_KB=16 # 32 # Parameters MASTER_ADDR=${MASTER_ADDR:?} MASTER_PORT=${MASTER_PORT:-19999} diff --git a/p2p/efa/efa_channel.h b/p2p/efa/efa_channel.h deleted file mode 100644 index 50cf485bf..000000000 --- a/p2p/efa/efa_channel.h +++ /dev/null @@ -1,315 +0,0 @@ -#pragma once -#include "define.h" -#include "rdma_context.h" -#include "seq_num.h" -#include "util/util.h" -#include - -class EFAChannel { - public: - enum class QPXType { - WriteImm, - Write, - READ, - }; - - explicit EFAChannel(std::shared_ptr ctx, uint32_t channel_id = 0) - : ctx_(ctx), - qp_(nullptr), - cq_ex_(nullptr), - ah_(nullptr), - channel_id_(channel_id), - local_meta_(std::make_shared()), - remote_meta_(std::make_shared()) { - initQP(); - } - - explicit EFAChannel(std::shared_ptr ctx, - ChannelMetaData const& remote_meta, - uint32_t channel_id = 0) - : ctx_(ctx), - qp_(nullptr), - cq_ex_(nullptr), - ah_(nullptr), - channel_id_(channel_id), - local_meta_(std::make_shared()), - remote_meta_(std::make_shared(remote_meta)) { - initQP(); - ah_ = ctx_->createAH(remote_meta_->gid); - UCCL_LOG_EP << "EFAChannel connected to remote qpn=" << remote_meta.qpn; - } - - EFAChannel(EFAChannel const&) = delete; - EFAChannel& operator=(EFAChannel const&) = delete; - - void connect(ChannelMetaData const& remote_meta) { - remote_meta_ = std::make_shared(remote_meta); - ah_ = ctx_->createAH(remote_meta_->gid); - UCCL_LOG_EP << "EFAChannel connected to remote qpn=" << remote_meta.qpn; - } - - int64_t submitRequest(std::shared_ptr req) { - return postRequest(req); - } - - int64_t read(std::shared_ptr req) { - int ret = postRequest(req); - if (ret != 0) { - LOG(ERROR) << "Failed to post read request, wr_id=" << req->wr_id; - return -1; - } - return req->wr_id; - } - - int64_t send(std::shared_ptr req) { - int ret = postRequest(req); - if (ret != 0) { - LOG(ERROR) << "Failed to post send request, wr_id=" << req->wr_id; - return -1; - } - return req->wr_id; - } - - int64_t recv(std::shared_ptr req) { - struct ibv_sge sge = { - .addr = (uintptr_t)req->getLocalAddress(), - .length = (uint32_t)req->getLocalLen(), - .lkey = req->getLocalKey(), - }; - struct ibv_recv_wr wr = {0}, *bad_wr = nullptr; - int64_t wr_id = req->wr_id; - wr.wr_id = wr_id; - wr.sg_list = &sge; - wr.num_sge = 1; - - if (ibv_post_recv(qp_, &wr, &bad_wr)) { - LOG(ERROR) << "ibv_post_recv failed: " << strerror(errno); - } - return wr_id; - } - - bool poll_once(std::vector& cq_datas) { - if (!cq_ex_) { - LOG(INFO) << "poll_once - channel_id: " << channel_id_ - << ", cq_ex_ is null"; - return false; - } - - struct ibv_poll_cq_attr attr = {}; - int ret = ibv_start_poll(cq_ex_, &attr); - - if (ret == ENOENT) { - return false; - } - if (ret) { - LOG(ERROR) << "poll_once - channel_id: " << channel_id_ - << ", ibv_start_poll error: " << ret << " (" << strerror(ret) - << ")"; - return false; - } - - do { - uint64_t wr_id = cq_ex_->wr_id; - auto status = cq_ex_->status; - if (unlikely(status != IBV_WC_SUCCESS)) { - LOG(WARNING) << "poll_once - channel_id: " << channel_id_ - << ", CQE error, wr_id=" << wr_id << ", status=" << status - << " (" << ibv_wc_status_str(status) << ")"; - } else { - CQMeta cq_data{}; - cq_data.wr_id = wr_id; - cq_data.op_code = ibv_wc_read_opcode(cq_ex_); - cq_data.len = ibv_wc_read_byte_len(cq_ex_); - - if (cq_data.op_code == IBV_WC_RECV_RDMA_WITH_IMM) { - cq_data.imm = ibv_wc_read_imm_data(cq_ex_); - } else { - cq_data.imm = 0; - } - - cq_datas.emplace_back(cq_data); - } - - ret = ibv_next_poll(cq_ex_); - } while (ret == 0); - - ibv_end_poll(cq_ex_); - - if (ret != ENOENT) { - LOG(ERROR) << "poll_once - channel_id: " << channel_id_ - << ", ibv_next_poll error: " << ret << " (" << strerror(ret) - << ")"; - } - - return !cq_datas.empty(); - } - - // Get local metadata - std::shared_ptr get_local_meta() const { - return local_meta_; - } - - // Get remote metadata - std::shared_ptr get_remote_meta() const { - return remote_meta_; - } - - // Get RdmaContext - inline std::shared_ptr const getContext() const { return ctx_; } - - inline uint64_t const getContextID() const { return ctx_->getContextID(); } - - inline uint32_t getChannelID() const { return channel_id_; } - - private: - std::shared_ptr ctx_; - uint32_t channel_id_; - - struct ibv_cq_ex* cq_ex_; - struct ibv_qp* qp_; - struct ibv_ah* ah_; - - std::shared_ptr local_meta_; - std::shared_ptr remote_meta_; - - std::shared_ptr tracker_; - - struct ibv_cq_ex* getCQ() const { - return cq_ex_; - } - - struct ibv_qp* getQP() const { - return qp_; - } - - // Post send request based on send_type - // Returns 0 on success, error code on failure - inline int postRequest(std::shared_ptr req) { - auto* qpx = ibv_qp_to_qp_ex(qp_); - ibv_wr_start(qpx); - LOG(INFO) << *req; - qpx->wr_id = req->wr_id; - qpx->comp_mask = 0; - qpx->wr_flags = IBV_SEND_SIGNALED; - - if (req->send_type == SendType::Send) { - ibv_wr_rdma_write_imm(qpx, req->getRemoteKey(), req->getRemoteAddress(), - req->imm_data); - } else if (req->send_type == SendType::Write) { - ibv_wr_rdma_write(qpx, req->getRemoteKey(), req->getRemoteAddress()); - } else if (req->send_type == SendType::Read) { - ibv_wr_rdma_read(qpx, req->getRemoteKey(), req->getRemoteAddress()); - } else { - LOG(ERROR) << "Unknown SendType in EFAChannel::postRequest"; - return -1; - } - - struct ibv_sge sge[1]; - int num_sge = prepareSGEList(sge, req); - ibv_wr_set_sge_list(qpx, num_sge, sge); - ibv_wr_set_ud_addr(qpx, ah_, remote_meta_->qpn, kQKey); - - int ret = ibv_wr_complete(qpx); - if (ret) { - std::ostringstream sge_info; - sge_info << "["; - for (int i = 0; i < num_sge; ++i) { - if (i > 0) sge_info << ", "; - sge_info << "{addr:0x" << std::hex << sge[i].addr - << ", len:" << std::dec << sge[i].length << ", lkey:0x" - << std::hex << sge[i].lkey << std::dec << "}"; - } - sge_info << "]"; - - LOG(ERROR) << "ibv_wr_complete failed in postRequest: " << ret << " " - << strerror(ret) << ", ah_=" << (void*)ah_ - << ", remote_qpn=" << remote_meta_->qpn - << ", local_qpn=" << qp_->qp_num << ", wr_id=" << req->wr_id - << ", remote_key=" << req->getRemoteKey() << ", remote_addr=0x" - << std::hex << req->getRemoteAddress() - << ", local_key=" << req->getLocalKey() - << ", num_sge=" << num_sge << ", sge_list=" << sge_info.str() - << std::dec; - } - return ret; - } - - void initQP() { - struct ibv_cq_init_attr_ex cq_attr = {0}; - cq_attr.cqe = 1024; - cq_attr.wc_flags = IBV_WC_STANDARD_FLAGS; - cq_attr.comp_mask = 0; - - cq_ex_ = ibv_create_cq_ex(ctx_->getCtx(), &cq_attr); - assert(cq_ex_); - - struct ibv_qp_init_attr_ex qp_attr = {0}; - qp_attr.comp_mask = IBV_QP_INIT_ATTR_PD | IBV_QP_INIT_ATTR_SEND_OPS_FLAGS; - qp_attr.send_ops_flags = IBV_QP_EX_WITH_RDMA_WRITE | - IBV_QP_EX_WITH_RDMA_WRITE_WITH_IMM | - IBV_QP_EX_WITH_RDMA_READ; - - qp_attr.cap.max_send_wr = kMaxSendWr; - qp_attr.cap.max_recv_wr = kMaxRecvWr; - qp_attr.cap.max_send_sge = kMaxSendSeg; - qp_attr.cap.max_recv_sge = kMaxRecvSeg; - qp_attr.cap.max_inline_data = 0; - - qp_attr.send_cq = ibv_cq_ex_to_cq(cq_ex_); - qp_attr.recv_cq = ibv_cq_ex_to_cq(cq_ex_); - - qp_attr.pd = ctx_->getPD(); - qp_attr.qp_context = ctx_->getCtx(); - qp_attr.sq_sig_all = 0; - - qp_attr.qp_type = IBV_QPT_DRIVER; - - struct efadv_qp_init_attr efa_attr = {}; - efa_attr.driver_qp_type = EFADV_QP_DRIVER_TYPE_SRD; - efa_attr.sl = kEfaQpLowLatencyServiceLevel; - efa_attr.flags = 0; - // If set, Receive WRs will not be consumed for RDMA write with imm. - efa_attr.flags |= EFADV_QP_FLAGS_UNSOLICITED_WRITE_RECV; - - qp_ = efadv_create_qp_ex(ctx_->getCtx(), &qp_attr, &efa_attr, - sizeof(efa_attr)); - assert(qp_); - - struct ibv_qp_attr attr = {}; - memset(&attr, 0, sizeof(attr)); - attr.qp_state = IBV_QPS_INIT; - attr.port_num = kPortNum; - attr.qkey = kQKey; - attr.pkey_index = 0; - assert(ibv_modify_qp(qp_, &attr, - IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | - IBV_QP_QKEY) == 0); - - memset(&attr, 0, sizeof(attr)); - attr.qp_state = IBV_QPS_RTR; - assert(ibv_modify_qp(qp_, &attr, IBV_QP_STATE) == 0); - - memset(&attr, 0, sizeof(attr)); - attr.qp_state = IBV_QPS_RTS; - // attr.rnr_retry = 10; - // attr.min_rnr_timer = 10; - attr.rnr_retry = kEfaRdmDefaultRnrRetry; - assert(ibv_modify_qp(qp_, &attr, - IBV_QP_STATE | IBV_QP_SQ_PSN | IBV_QP_RNR_RETRY) == 0); - - local_meta_->gid = ctx_->queryGid(kGidIndex); - local_meta_->qpn = qp_->qp_num; - } - - // Prepare SGE list for send request - // Returns the number of SGE entries filled - int prepareSGEList(struct ibv_sge* sge, std::shared_ptr req) { - uint32_t total_len = req->getLocalLen(); - uint64_t local_addr = req->getLocalAddress(); - uint32_t local_key = req->getLocalKey(); - sge[0].addr = local_addr; - sge[0].length = total_len; - sge[0].lkey = local_key; - return 1; - } -}; diff --git a/p2p/endpoint_wrapper.h b/p2p/endpoint_wrapper.h index 8887f2f60..0a6e18018 100644 --- a/p2p/endpoint_wrapper.h +++ b/p2p/endpoint_wrapper.h @@ -1,8 +1,6 @@ #pragma once #include "engine.h" -// #define UCCL_P2P_USE_EFA - namespace unified { template @@ -17,8 +15,8 @@ inline void delete_ep(RDMAEndPoint const& s) { // raw pointer: we own it → delete delete ep; } -#ifdef UCCL_P2P_USE_EFA - else if constexpr (std::is_same_v>) { +#ifdef UCCL_P2P_USE_NATIVE_RDMA + else if constexpr (std::is_same_v>) { // shared_ptr: do nothing (shared_ptr handles lifetime) } #endif @@ -30,8 +28,8 @@ inline void delete_ep(RDMAEndPoint const& s) { s); } -#ifdef UCCL_P2P_USE_EFA -inline int set_request(std::shared_ptr const& obj, Conn* conn, +#ifdef UCCL_P2P_USE_NATIVE_RDMA +inline int set_request(std::shared_ptr const& obj, Conn* conn, unified::P2PMhandle* local_mh, void* src, size_t size, FifoItem const& slot_item, uccl::ucclRequest* ureq) { // Create RemoteMemInfo from FifoItem @@ -44,7 +42,7 @@ inline int set_request(std::shared_ptr const& obj, Conn* conn, auto local_mem = std::make_shared(src, size, MemoryType::GPU); local_mem->mr_array = local_mh->mr_array; - auto req = std::make_shared(local_mem, remote_mem); + auto req = std::make_shared(local_mem, remote_mem); req->to_rank_id = conn->uccl_conn_id_.flow_id; req->send_type = SendType::Write; @@ -111,8 +109,8 @@ inline bool uccl_regmr(RDMAEndPoint const& s, int dev, void* data, size_t len, return false; } } -#ifdef UCCL_P2P_USE_EFA - else if constexpr (std::is_same_v>) { +#ifdef UCCL_P2P_USE_NATIVE_RDMA + else if constexpr (std::is_same_v>) { if (obj->uccl_regmr(data, len, mhandle->mr_array) < 0) { return false; } @@ -141,13 +139,13 @@ inline int uccl_send_async(RDMAEndPoint const& s, Conn* conn, static_cast(conn->uccl_conn_id_.context), mh, data, size, ureq); } -#ifdef UCCL_P2P_USE_EFA - else if constexpr (std::is_same_v>) { +#ifdef UCCL_P2P_USE_NATIVE_RDMA + else if constexpr (std::is_same_v>) { auto send_mem = std::make_shared(const_cast(data), size, MemoryType::GPU); send_mem->mr_array = mhandle->mr_array; auto remote_mem_placeholder = std::make_shared(); - auto send_req = std::make_shared( + auto send_req = std::make_shared( send_mem, remote_mem_placeholder); ureq->type = uccl::ReqType::ReqTx; send_req->to_rank_id = conn->uccl_conn_id_.flow_id; @@ -181,12 +179,12 @@ inline int uccl_recv_async(RDMAEndPoint const& s, Conn* conn, static_cast(conn->uccl_conn_id_.context), &(mhandles->mhandle_), data, size, n, ureq); } -#ifdef UCCL_P2P_USE_EFA - else if constexpr (std::is_same_v>) { +#ifdef UCCL_P2P_USE_NATIVE_RDMA + else if constexpr (std::is_same_v>) { auto recv_mem = std::make_shared(data[0], size[0], MemoryType::GPU); recv_mem->mr_array = mhandles->mr_array; - auto recv_req = std::make_shared(recv_mem); + auto recv_req = std::make_shared(recv_mem); ureq->type = uccl::ReqType::ReqRx; ureq->engine_idx = obj->recv(conn->uccl_conn_id_.flow_id, recv_req); ureq->n = conn->uccl_conn_id_.flow_id; @@ -210,8 +208,8 @@ inline bool uccl_poll_ureq_once(RDMAEndPoint const& s, if constexpr (std::is_pointer_v) { return obj->uccl_poll_ureq_once(ureq); } -#ifdef UCCL_P2P_USE_EFA - else if constexpr (std::is_same_v>) { +#ifdef UCCL_P2P_USE_NATIVE_RDMA + else if constexpr (std::is_same_v>) { if (ureq->type == uccl::ReqType::ReqTx || ureq->type == uccl::ReqType::ReqWrite || ureq->type == uccl::ReqType::ReqRead) { @@ -245,8 +243,8 @@ inline int uccl_read_async(RDMAEndPoint const& s, Conn* conn, local_mh->mhandle_, dst, size, static_cast(slot_item), ureq); } -#ifdef UCCL_P2P_USE_EFA - else if constexpr (std::is_same_v>) { +#ifdef UCCL_P2P_USE_NATIVE_RDMA + else if constexpr (std::is_same_v>) { ureq->type = uccl::ReqType::ReqRead; return set_request(obj, conn, local_mh, dst, size, slot_item, ureq); } @@ -273,8 +271,8 @@ inline int uccl_write_async(RDMAEndPoint const& s, Conn* conn, local_mh->mhandle_, src, size, static_cast(slot_item), ureq); } -#ifdef UCCL_P2P_USE_EFA - else if constexpr (std::is_same_v>) { +#ifdef UCCL_P2P_USE_NATIVE_RDMA + else if constexpr (std::is_same_v>) { ureq->type = uccl::ReqType::ReqWrite; return set_request(obj, conn, local_mh, src, size, slot_item, ureq); } @@ -300,8 +298,8 @@ inline int prepare_fifo_metadata(RDMAEndPoint const& s, Conn* conn, static_cast(conn->uccl_conn_id_.context), &(mhandle->mhandle_), data, size, out_buf); } -#ifdef UCCL_P2P_USE_EFA - else if constexpr (std::is_same_v>) { +#ifdef UCCL_P2P_USE_NATIVE_RDMA + else if constexpr (std::is_same_v>) { FifoItem remote_mem_info; remote_mem_info.addr = reinterpret_cast(data); remote_mem_info.size = size; @@ -332,8 +330,8 @@ inline void uccl_deregmr(RDMAEndPoint const& s, P2PMhandle* mhandle) { // raw pointer: call with mhandle_ obj->uccl_deregmr(mhandle->mhandle_); } -#ifdef UCCL_P2P_USE_EFA - else if constexpr (std::is_same_v>) { +#ifdef UCCL_P2P_USE_NATIVE_RDMA + else if constexpr (std::is_same_v>) { obj->uccl_deregmr(mhandle->mr_array); } #endif diff --git a/p2p/engine.cc b/p2p/engine.cc index 39353d426..2f70dca48 100644 --- a/p2p/engine.cc +++ b/p2p/engine.cc @@ -65,9 +65,9 @@ Endpoint::Endpoint(uint32_t const local_gpu_idx, uint32_t const num_cpus) google::InstallFailureSignalHandler(); // Initialize the RDMA endpoint with lazy creation. -#ifdef UCCL_P2P_USE_EFA - ep_ = std::shared_ptr( - new EFAEndpoint(local_gpu_idx_, INVALID_RANK_ID, 0, false)); +#ifdef UCCL_P2P_USE_NATIVE_RDMA + ep_ = std::shared_ptr( + new NICEndpoint(local_gpu_idx_, INVALID_RANK_ID, 0, false)); numa_node_ = 0; #else ep_ = new uccl::RDMAEndpoint(num_cpus_); @@ -127,9 +127,9 @@ Endpoint::Endpoint(uint32_t const num_cpus) : num_cpus_(num_cpus) { []() { google::InitGoogleLogging("uccl_p2p"); }); google::InstallFailureSignalHandler(); -#ifdef UCCL_P2P_USE_EFA - ep_ = std::shared_ptr( - new EFAEndpoint(local_gpu_idx_, INVALID_RANK_ID, 0, false)); +#ifdef UCCL_P2P_USE_NATIVE_RDMA + ep_ = std::shared_ptr( + new NICEndpoint(local_gpu_idx_, INVALID_RANK_ID, 0, false)); #else // Initialize the RDMA endpoint with lazy creation. ep_ = new uccl::RDMAEndpoint(num_cpus_); @@ -205,7 +205,7 @@ void Endpoint::initialize_engine() { std::cout << "Lazy creation of engine, GPU index: " << local_gpu_idx_ << std::endl; // Initialize engine by fixed engine offset since we did lazy initialization -#ifndef UCCL_P2P_USE_EFA +#ifndef UCCL_P2P_USE_NATIVE_RDMA unified::initialize_engine_by_dev(ep_, gpu_to_dev[local_gpu_idx_], false); std::cout << "Engine initialized for GPU " << local_gpu_idx_ << std::endl; #endif diff --git a/p2p/engine.h b/p2p/engine.h index 63a96a40d..374196a60 100644 --- a/p2p/engine.h +++ b/p2p/engine.h @@ -1,22 +1,22 @@ #pragma once -#include "efa/define.h" -#ifdef UCCL_P2P_USE_EFA -#include "efa/efa_endpoint.h" -#endif #include "transport.h" #include "util/gpu_rt.h" #include "util/jring.h" #include "util/net.h" #include "util/shared_pool.h" #include "util/util.h" +#include #include #include #include +#include +#include #include #include #include #include +#include #include #include #include @@ -24,6 +24,97 @@ namespace py = pybind11; extern thread_local bool inside_python; + +// Forward declaration +struct ibv_mr; + +inline int parseLogLevelFromEnv() { + char const* env = std::getenv("UCCL_P2P_LOG_LEVEL"); + if (!env) { + return google::WARNING; + } + + if (!strcasecmp(env, "INFO")) return google::INFO; + if (!strcasecmp(env, "WARNING")) return google::WARNING; + if (!strcasecmp(env, "ERROR")) return google::ERROR; + if (!strcasecmp(env, "FATAL")) return google::FATAL; + + char* end = nullptr; + long val = std::strtol(env, &end, 10); + if (end != env && val >= 0 && val <= 3) { + return static_cast(val); + } + + return google::WARNING; +} + +namespace unified { +static constexpr int kNICContextNumber = 4; + +inline size_t channelIdToContextId(uint32_t channel_id) { + return (channel_id == 0) ? 0 : (channel_id - 1) % kNICContextNumber; +} + +template +struct RKeyArrayT { + T keys[kNICContextNumber]; + + RKeyArrayT() { std::memset(keys, 0, sizeof(keys)); } + + inline void copyFrom(RKeyArrayT const& other) { + static_assert(std::is_trivially_copyable_v, + "RKeyArrayT::copyFrom requires trivially copyable T"); + std::memcpy(keys, other.keys, sizeof(keys)); + } + + inline void copyFrom(char const* other) { + static_assert(std::is_trivially_copyable_v, + "RKeyArrayT::copyFrom requires trivially copyable T"); + std::memcpy(keys, other, sizeof(keys)); + } + + inline T getKeyByChannelID(uint32_t channel_id) const { + return getKeyByContextID(channelIdToContextId(channel_id)); + } + + inline T getKeyByContextID(size_t context_id) const { + return keys[context_id]; + } + + inline void setKeyByContextID(uint32_t context_id, T key) { + keys[context_id] = key; + } + + inline void setKeyByChannelID(uint32_t channel_id, T key) { + setKeyByContextID(channelIdToContextId(channel_id), key); + } + + T& operator[](int index) { return keys[index]; } + T const& operator[](int index) const { return keys[index]; } + + friend std::ostream& operator<<(std::ostream& os, RKeyArrayT const& arr) { + os << "RKeyArrayT{"; + for (int i = 0; i < kNICContextNumber; ++i) { + if (i > 0) os << ", "; + if constexpr (std::is_integral_v) { + os << "0x" << std::hex << arr.keys[i] << std::dec; + } else { + os << arr.keys[i]; + } + } + os << "}"; + return os; + } +}; + +using MRArray = RKeyArrayT; +} // namespace unified + +#ifdef UCCL_P2P_USE_NATIVE_RDMA +#include "rdma/define.h" +#include "rdma/rdma_endpoint.h" +#endif + namespace unified { struct P2PMhandle { @@ -31,9 +122,9 @@ struct P2PMhandle { MRArray mr_array; }; -#ifdef UCCL_P2P_USE_EFA +#ifdef UCCL_P2P_USE_NATIVE_RDMA using RDMAEndPoint = - std::variant>; + std::variant>; #else using RDMAEndPoint = std::variant; #endif @@ -64,8 +155,7 @@ using FifoItem = uccl::FifoItem; #endif class Endpoint { - uint64_t const kRTTBytes = 1024 * 1024; -#ifdef UCCL_P2P_USE_EFA +#ifdef UCCL_P2P_USE_NATIVE_RDMA uint64_t const kChunkSize = 1024 * 1024 * 1024; #else uint64_t const kChunkSize = 1024 * 1024; diff --git a/p2p/efa/define.h b/p2p/rdma/define.h similarity index 85% rename from p2p/efa/define.h rename to p2p/rdma/define.h index a7f67ddda..18f436e62 100644 --- a/p2p/efa/define.h +++ b/p2p/rdma/define.h @@ -36,12 +36,8 @@ #include #include -static constexpr int kGidIndex = 0; static constexpr int kRankIDPlaceHolder = 9999; -static constexpr int kEfaQpLowLatencyServiceLevel = 8; -static constexpr uint32_t kQKey = 0x15695; static constexpr uint8_t kPortNum = 1; -static constexpr uint8_t kEfaRdmDefaultRnrRetry = 3; static constexpr int kQpNumPerChannel = 2; static constexpr int kNICContextNumber = 2; @@ -53,6 +49,8 @@ static constexpr int kMaxSendSeg = 2; static constexpr int kMaxRecvSeg = 2; static constexpr uint64_t kMessageChunkSizeKB = 256; // 1 MB static constexpr uint64_t kMaxSplitNum = 16; +static constexpr uint32_t kBatchPostRecvWr = 32; +static constexpr uint32_t kBatchPollCqe = 32; static constexpr size_t kTaskRingSize = 1024; static constexpr size_t kRingCapacity = 16384; // Must be power of 2 @@ -67,26 +65,6 @@ inline size_t channelIdToContextId(uint32_t channel_id) { return (channel_id == 0) ? 0 : (channel_id - 1) % kNICContextNumber; } -inline int parseLogLevelFromEnv() { - char const* env = std::getenv("EFA_LOG_LEVEL"); - if (!env) { - return google::WARNING; - } - - if (!strcasecmp(env, "INFO")) return google::INFO; - if (!strcasecmp(env, "WARNING")) return google::WARNING; - if (!strcasecmp(env, "ERROR")) return google::ERROR; - if (!strcasecmp(env, "FATAL")) return google::FATAL; - - char* end = nullptr; - long val = std::strtol(env, &end, 10); - if (end != env && val >= 0 && val <= 3) { - return static_cast(val); - } - - return google::WARNING; -} - struct MessageChunk { uint64_t offset; // Offset from the start of the message size_t size; // Size of this chunk in bytes @@ -155,62 +133,9 @@ enum class MemoryType { HOST, GPU }; enum class ChannelType : int16_t { Control, Normal }; template -struct RKeyArrayT { - T keys[kNICContextNumber]; - - RKeyArrayT() { std::memset(keys, 0, sizeof(keys)); } - - inline void copyFrom(RKeyArrayT const& other) { - static_assert(std::is_trivially_copyable_v, - "RKeyArrayT::copyFrom requires trivially copyable T"); - - std::memcpy(keys, other.keys, sizeof(keys)); - } - - inline void copyFrom(char const* other) { - static_assert(std::is_trivially_copyable_v, - "RKeyArrayT::copyFrom requires trivially copyable T"); - - std::memcpy(keys, other, sizeof(keys)); - } - - inline T getKeyByChannelID(uint32_t channel_id) const { - return getKeyByContextID(channelIdToContextId(channel_id)); - } - - inline T getKeyByContextID(size_t context_id) const { - return keys[context_id]; - } - - inline void setKeyByContextID(uint32_t context_id, T key) { - keys[context_id] = key; - } - - inline void setKeyByChannelID(uint32_t channel_id, T key) { - setKeyByContextID(channelIdToContextId(channel_id), key); - } - - T& operator[](int index) { return keys[index]; } - T const& operator[](int index) const { return keys[index]; } - - friend std::ostream& operator<<(std::ostream& os, RKeyArrayT const& arr) { - os << "RKeyArrayT{"; - for (int i = 0; i < kNICContextNumber; ++i) { - if (i > 0) os << ", "; - if constexpr (std::is_integral_v) { - os << "0x" << std::hex << arr.keys[i] << std::dec; - } else { - os << arr.keys[i]; - } - } - os << "}"; - return os; - } -}; - -// Type alias for backward compatibility and convenience -using RKeyArray = RKeyArrayT; -using MRArray = RKeyArrayT; +using RKeyArrayT = unified::RKeyArrayT; +using RKeyArray = unified::RKeyArrayT; +using MRArray = unified::MRArray; inline void copyRKeyArrayFromMRArray(MRArray const& mr_array, RKeyArray& rkey_array) { @@ -377,7 +302,7 @@ typedef struct RemoteMemInfo { } } RemoteMemInfo; -typedef struct EFARecvRequest { +typedef struct RDMARecvRequest { uint32_t from_rank_id; uint32_t to_rank_id; uint32_t channel_id; @@ -385,7 +310,7 @@ typedef struct EFARecvRequest { std::shared_ptr local_mem; // Constructor - EFARecvRequest(std::shared_ptr local) : local_mem(local) {} + RDMARecvRequest(std::shared_ptr local) : local_mem(local) {} // Getter methods inline uint32_t getLocalKey() const { @@ -398,8 +323,9 @@ typedef struct EFARecvRequest { inline uint32_t getLocalLen() const { return local_mem->size; } - friend std::ostream& operator<<(std::ostream& os, EFARecvRequest const& req) { - os << "EFARecvRequest{"; + friend std::ostream& operator<<(std::ostream& os, + RDMARecvRequest const& req) { + os << "RDMARecvRequest{"; os << "from_rank_id: " << req.from_rank_id << ", to_rank_id: " << req.to_rank_id << ", channel_id: " << req.channel_id; @@ -411,7 +337,7 @@ typedef struct EFARecvRequest { os << "}"; return os; } -} EFARecvRequest; +} RDMARecvRequest; enum class ReqFlag : int16_t { PENDING = 2, IN_PROGRESS = 3, IS_DONE = 4 }; @@ -437,7 +363,7 @@ struct alignas(64) SendReqMeta { expected_chunk_count(expected), received_chunk_count(received) {} - SendReqMeta(std::shared_ptr rev_req) { + SendReqMeta(std::shared_ptr rev_req) { rank_id = rev_req->from_rank_id; channel_id = rev_req->channel_id; remote_mem = rev_req->local_mem; @@ -531,7 +457,7 @@ inline auto from_ring_meta = [](SendReqMetaOnRing const& src, SendReqMeta& dst) { dst = src.meta; }; enum class SendType { Send, Write, Read }; -struct EFASendRequest { +struct RDMASendRequest { std::shared_ptr local_mem; std::shared_ptr remote_mem; uint32_t from_rank_id; @@ -543,27 +469,27 @@ struct EFASendRequest { SendType send_type = SendType::Send; // Constructor - EFASendRequest(std::shared_ptr local, - std::shared_ptr remote, uint32_t imm = 0, - bool signaled = true) + RDMASendRequest(std::shared_ptr local, + std::shared_ptr remote, uint32_t imm = 0, + bool signaled = true) : local_mem(local), remote_mem(remote), imm_data(imm), need_signaled(signaled) {} - // Constructor from shared_ptr - EFASendRequest(std::shared_ptr other, - std::shared_ptr local, uint32_t imm = 0, - bool signaled = true) + // Constructor from shared_ptr + RDMASendRequest(std::shared_ptr other, + std::shared_ptr local, uint32_t imm = 0, + bool signaled = true) : local_mem(local), remote_mem(other->remote_mem), imm_data(imm), need_signaled(signaled) {} - // Constructor from const EFASendRequest& - EFASendRequest(EFASendRequest const& other, - std::shared_ptr local, uint32_t imm = 0, - bool signaled = true) + // Constructor from const RDMASendRequest& + RDMASendRequest(RDMASendRequest const& other, + std::shared_ptr local, uint32_t imm = 0, + bool signaled = true) : local_mem(local), remote_mem(other.remote_mem), imm_data(imm), @@ -588,8 +514,9 @@ struct EFASendRequest { inline uint32_t getLocalLen() const { return local_mem->size; } - friend std::ostream& operator<<(std::ostream& os, EFASendRequest const& req) { - os << "EFASendRequest{"; + friend std::ostream& operator<<(std::ostream& os, + RDMASendRequest const& req) { + os << "RDMASendRequest{"; os << "from_rank_id: " << req.from_rank_id << ", to_rank_id: " << req.to_rank_id << ", channel_id: " << req.channel_id << ", imm_data: " << req.imm_data diff --git a/p2p/efa/epoll_client.h b/p2p/rdma/epoll_client.h similarity index 100% rename from p2p/efa/epoll_client.h rename to p2p/rdma/epoll_client.h diff --git a/p2p/efa/epoll_server.h b/p2p/rdma/epoll_server.h similarity index 100% rename from p2p/efa/epoll_server.h rename to p2p/rdma/epoll_server.h diff --git a/p2p/efa/memory_allocator.h b/p2p/rdma/memory_allocator.h similarity index 100% rename from p2p/efa/memory_allocator.h rename to p2p/rdma/memory_allocator.h diff --git a/p2p/rdma/providers/efa/rdma_channel_impl_efa.cc b/p2p/rdma/providers/efa/rdma_channel_impl_efa.cc new file mode 100644 index 000000000..f1c7237b5 --- /dev/null +++ b/p2p/rdma/providers/efa/rdma_channel_impl_efa.cc @@ -0,0 +1,172 @@ +#ifndef RDMA_CHANNEL_IMPL_EFA_CC_INCLUDED +#define RDMA_CHANNEL_IMPL_EFA_CC_INCLUDED + +#include "rdma_channel_impl_efa.h" +#include +#include +#include + +#define GID_INDEX 0 +#define MAX_INLINE_DATA 0 +#define SERVICE_LEVEL 8 +#define QKEY 0x15695 + +#define RNR_RETRY 3 + +#define MAX_CQE 1024 + +inline void EFAChannelImpl::initQP(std::shared_ptr ctx, + struct ibv_cq_ex** cq_ex, struct ibv_qp** qp, + ChannelMetaData* local_meta) { + struct ibv_cq_init_attr_ex cq_attr = {0}; + cq_attr.cqe = MAX_CQE; + cq_attr.wc_flags = IBV_WC_STANDARD_FLAGS; + cq_attr.comp_mask = 0; + + *cq_ex = ibv_create_cq_ex(ctx->getCtx(), &cq_attr); + assert(*cq_ex); + + struct ibv_qp_init_attr_ex qp_attr = {0}; + qp_attr.comp_mask = IBV_QP_INIT_ATTR_PD | IBV_QP_INIT_ATTR_SEND_OPS_FLAGS; + qp_attr.send_ops_flags = IBV_QP_EX_WITH_RDMA_WRITE | + IBV_QP_EX_WITH_RDMA_WRITE_WITH_IMM | + IBV_QP_EX_WITH_RDMA_READ; + + qp_attr.cap.max_send_wr = kMaxSendWr; + qp_attr.cap.max_recv_wr = kMaxRecvWr; + qp_attr.cap.max_send_sge = kMaxSendSeg; + qp_attr.cap.max_recv_sge = kMaxRecvSeg; + qp_attr.cap.max_inline_data = getMaxInlineData(); + + qp_attr.send_cq = ibv_cq_ex_to_cq(*cq_ex); + qp_attr.recv_cq = ibv_cq_ex_to_cq(*cq_ex); + + qp_attr.pd = ctx->getPD(); + qp_attr.qp_context = ctx->getCtx(); + qp_attr.sq_sig_all = 0; + + qp_attr.qp_type = IBV_QPT_DRIVER; + + struct efadv_qp_init_attr efa_attr = {}; + efa_attr.driver_qp_type = EFADV_QP_DRIVER_TYPE_SRD; + efa_attr.sl = SERVICE_LEVEL; + efa_attr.flags = 0; + // If set, Receive WRs will not be consumed for RDMA write with imm. + efa_attr.flags |= EFADV_QP_FLAGS_UNSOLICITED_WRITE_RECV; + + *qp = + efadv_create_qp_ex(ctx->getCtx(), &qp_attr, &efa_attr, sizeof(efa_attr)); + + assert(*qp); + + struct ibv_qp_attr attr = {}; + memset(&attr, 0, sizeof(attr)); + attr.qp_state = IBV_QPS_INIT; + attr.port_num = kPortNum; + attr.qkey = QKEY; + attr.pkey_index = 0; + assert(ibv_modify_qp(*qp, &attr, + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | + IBV_QP_QKEY) == 0); + + memset(&attr, 0, sizeof(attr)); + attr.qp_state = IBV_QPS_RTR; + assert(ibv_modify_qp(*qp, &attr, IBV_QP_STATE) == 0); + + memset(&attr, 0, sizeof(attr)); + attr.qp_state = IBV_QPS_RTS; + attr.rnr_retry = RNR_RETRY; + assert(ibv_modify_qp(*qp, &attr, + IBV_QP_STATE | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN) == 0); + + local_meta->gid = ctx->queryGid(GID_INDEX); + local_meta->qpn = (*qp)->qp_num; +} + +inline void EFAChannelImpl::connectQP(struct ibv_qp* qp, + std::shared_ptr ctx, + ChannelMetaData const& remote_meta) { + (void)qp; + (void)ctx; + (void)remote_meta; +} + +inline bool EFAChannelImpl::poll_once(struct ibv_cq_ex* cq_ex, + std::vector& cq_datas, + uint32_t channel_id, + uint32_t& nb_post_recv) { + nb_post_recv = 0; + if (!cq_ex) { + LOG(INFO) << "poll_once - channel_id: " << channel_id << ", cq_ex_ is null"; + return false; + } + + struct ibv_poll_cq_attr attr = {}; + int ret = ibv_start_poll(cq_ex, &attr); + + if (ret == ENOENT) { + return false; + } + if (ret) { + LOG(ERROR) << "poll_once - channel_id: " << channel_id + << ", ibv_start_poll error: " << ret << " (" << strerror(ret) + << ")"; + return false; + } + + do { + uint64_t wr_id = cq_ex->wr_id; + auto status = cq_ex->status; + if (unlikely(status != IBV_WC_SUCCESS)) { + LOG(WARNING) << "poll_once - channel_id: " << channel_id + << ", CQE error, wr_id=" << wr_id << ", status=" << status + << " (" << ibv_wc_status_str(status) << ")"; + } else { + CQMeta cq_data{}; + cq_data.wr_id = wr_id; + cq_data.op_code = ibv_wc_read_opcode(cq_ex); + cq_data.len = ibv_wc_read_byte_len(cq_ex); + + if (cq_data.op_code == IBV_WC_RECV_RDMA_WITH_IMM) { + cq_data.imm = ibv_wc_read_imm_data(cq_ex); + } else { + cq_data.imm = 0; + } + + cq_datas.emplace_back(cq_data); + } + + ret = ibv_next_poll(cq_ex); + } while (ret == 0); + + ibv_end_poll(cq_ex); + + if (ret != ENOENT) { + LOG(ERROR) << "poll_once - channel_id: " << channel_id + << ", ibv_next_poll error: " << ret << " (" << strerror(ret) + << ")"; + } + + return !cq_datas.empty(); +} + +inline void EFAChannelImpl::lazy_post_recv_wrs_n(struct ibv_qp* qp, uint32_t n, + bool force) { + (void)qp; + (void)n; + (void)force; +} + +inline void EFAChannelImpl::setDstAddress(struct ibv_qp_ex* qpx, + struct ibv_ah* ah, + uint32_t remote_qpn) { + ibv_wr_set_ud_addr(qpx, ah, remote_qpn, QKEY); +} + +inline void EFAChannelImpl::initPreAllocResources() {} + +inline uint32_t EFAChannelImpl::getMaxInlineData() const { + return MAX_INLINE_DATA; +} + +#endif // RDMA_CHANNEL_IMPL_EFA_CC_INCLUDED diff --git a/p2p/rdma/providers/efa/rdma_channel_impl_efa.h b/p2p/rdma/providers/efa/rdma_channel_impl_efa.h new file mode 100644 index 000000000..84c11536e --- /dev/null +++ b/p2p/rdma/providers/efa/rdma_channel_impl_efa.h @@ -0,0 +1,31 @@ +#pragma once +#include "rdma/define.h" +#include "rdma/rdma_channel_impl.h" +#include + +class EFAChannelImpl : public RDMAChannelImpl { + public: + EFAChannelImpl() = default; + ~EFAChannelImpl() override = default; + + void initQP(std::shared_ptr ctx, struct ibv_cq_ex** cq_ex, + struct ibv_qp** qp, ChannelMetaData* local_meta) override; + + void connectQP(struct ibv_qp* qp, std::shared_ptr ctx, + ChannelMetaData const& remote_meta) override; + + bool poll_once(struct ibv_cq_ex* cq_ex, std::vector& cq_datas, + uint32_t channel_id, uint32_t& nb_post_recv) override; + + void lazy_post_recv_wrs_n(struct ibv_qp* qp, uint32_t n, bool force) override; + + void setDstAddress(struct ibv_qp_ex* qpx, struct ibv_ah* ah, + uint32_t remote_qpn) override; + + uint32_t getMaxInlineData() const override; + + void initPreAllocResources() override; +}; + +// Implementation (inline to avoid separate .cc file) +#include "rdma_channel_impl_efa.cc" diff --git a/p2p/rdma/providers/efa/rdma_device_selection_efa.h b/p2p/rdma/providers/efa/rdma_device_selection_efa.h new file mode 100644 index 000000000..31f646157 --- /dev/null +++ b/p2p/rdma/providers/efa/rdma_device_selection_efa.h @@ -0,0 +1,29 @@ +#pragma once +#include +#include +#include +#include + +// Forward declaration +class RDMADeviceSelectionStrategy; + +// EFA device selection strategy +class EFADeviceSelectionStrategy : public RDMADeviceSelectionStrategy { + public: + std::vector selectNICs( + std::vector const& candidates, int gpu_idx) override { + // NOTE(xzhiying): This is a temporary hack. + // On p5/p5en/p6-b200, there are 8/4/2 NICs with the same distance. + // E.g., on p5en, GPU0 uses candidates[0/1], GPU1 uses candidates[2/3], etc. + assert(candidates.size() == 8 || candidates.size() == 4 || + candidates.size() == 2); + int half_size = candidates.size() / 2; + int start_idx = (gpu_idx % 2 == 0) ? 0 : half_size; + int end_idx = start_idx + half_size; + std::vector selected; + for (int i = start_idx; i < end_idx; i++) { + selected.push_back(candidates[i]); + } + return selected; + } +}; diff --git a/p2p/rdma/providers/ib/rdma_channel_impl_ib.cc b/p2p/rdma/providers/ib/rdma_channel_impl_ib.cc new file mode 100644 index 000000000..fc40c512b --- /dev/null +++ b/p2p/rdma/providers/ib/rdma_channel_impl_ib.cc @@ -0,0 +1,227 @@ +#ifndef RDMA_CHANNEL_IMPL_IB_CC_INCLUDED +#define RDMA_CHANNEL_IMPL_IB_CC_INCLUDED + +#include "rdma_channel_impl_ib.h" +#include +#include +#include +#include + +#define GID_INDEX_DEFAULT 0 +#define MAX_INLINE_DATA 128 +#define SERVICE_LEVEL 135 +#define MIN_RNR_TIMER 12 +#define TRAFFIC_CLASS 3 + +#define RNR_RETRY 7 +#define RETRY_CNT 7 +#define TIMEOUT 14 +#define MAX_RD_ATOMIC 1 +#define MAX_DEST_RD_ATOMIC 1 +#define MAX_CQE 1024 + +static inline int get_gid_index_from_env() { + static int gid_index = -1; + if (gid_index == -1) { + char const* env = getenv("UCCL_IB_GID_INDEX"); + if (env) + gid_index = std::atoi(env); + else + gid_index = GID_INDEX_DEFAULT; + } + return gid_index; +} + +inline void IBChannelImpl::initQP(std::shared_ptr ctx, + struct ibv_cq_ex** cq_ex, struct ibv_qp** qp, + ChannelMetaData* local_meta) { + *cq_ex = (struct ibv_cq_ex*)ibv_create_cq(ctx->getCtx(), MAX_CQE, nullptr, + nullptr, 0); + assert(*cq_ex); + + struct ibv_qp_init_attr_ex qp_attr = {}; + memset(&qp_attr, 0, sizeof(qp_attr)); + qp_attr.comp_mask = IBV_QP_INIT_ATTR_PD | IBV_QP_INIT_ATTR_SEND_OPS_FLAGS; + qp_attr.send_ops_flags = IBV_QP_EX_WITH_RDMA_WRITE | + IBV_QP_EX_WITH_RDMA_WRITE_WITH_IMM | + IBV_QP_EX_WITH_RDMA_READ; + + qp_attr.cap.max_send_wr = kMaxSendWr; + qp_attr.cap.max_recv_wr = kMaxRecvWr; + qp_attr.cap.max_send_sge = kMaxSendSeg; + qp_attr.cap.max_recv_sge = kMaxRecvSeg; + qp_attr.cap.max_inline_data = getMaxInlineData(); + + qp_attr.send_cq = ibv_cq_ex_to_cq(*cq_ex); + qp_attr.recv_cq = ibv_cq_ex_to_cq(*cq_ex); + + qp_attr.pd = ctx->getPD(); + qp_attr.qp_context = ctx->getCtx(); + qp_attr.sq_sig_all = 0; + + qp_attr.qp_type = IBV_QPT_RC; + *qp = ibv_create_qp_ex(ctx->getCtx(), &qp_attr); + assert(*qp); + + struct ibv_qp_attr attr = {}; + memset(&attr, 0, sizeof(attr)); + attr.qp_state = IBV_QPS_INIT; + attr.port_num = kPortNum; + attr.pkey_index = 0; + attr.qp_access_flags = + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_WRITE; + assert(ibv_modify_qp(*qp, &attr, + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | + IBV_QP_ACCESS_FLAGS) == 0); + + local_meta->gid = ctx->queryGid(get_gid_index_from_env()); + local_meta->qpn = (*qp)->qp_num; +} + +inline void IBChannelImpl::connectQP(struct ibv_qp* qp, + std::shared_ptr ctx, + ChannelMetaData const& remote_meta) { + ibrcQP_rtr_rts(qp, ctx, remote_meta); +} + +inline void IBChannelImpl::ibrcQP_rtr_rts(struct ibv_qp* qp, + std::shared_ptr ctx, + ChannelMetaData const& remote_meta) { + int flags = 0; + struct ibv_qp_attr attr = {}; + struct ibv_port_attr port_attr; + assert(ibv_query_port(ctx->getCtx(), kPortNum, &port_attr) == 0); + + // RTR + memset(&attr, 0, sizeof(attr)); + attr.qp_state = IBV_QPS_RTR; + attr.path_mtu = port_attr.active_mtu; + attr.dest_qp_num = remote_meta.qpn; + attr.rq_psn = 0; + attr.max_dest_rd_atomic = MAX_DEST_RD_ATOMIC; + attr.min_rnr_timer = MIN_RNR_TIMER; + // RoCE + // TODO: Infiniband + attr.ah_attr.is_global = 1; + attr.ah_attr.port_num = 1; + attr.ah_attr.sl = SERVICE_LEVEL; + attr.ah_attr.src_path_bits = 0; + attr.ah_attr.grh.traffic_class = TRAFFIC_CLASS; + attr.ah_attr.grh.hop_limit = 64; + memcpy(&attr.ah_attr.grh.dgid, remote_meta.gid.raw, 16); + attr.ah_attr.grh.sgid_index = get_gid_index_from_env(); + flags = IBV_QP_STATE | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | + IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER | IBV_QP_AV; + assert(ibv_modify_qp(qp, &attr, flags) == 0); + + lazy_post_recv_wrs_n(qp, kMaxRecvWr, true); + + // RTS + memset(&attr, 0, sizeof(attr)); + attr.qp_state = IBV_QPS_RTS; + attr.timeout = TIMEOUT; + attr.retry_cnt = RETRY_CNT; + attr.rnr_retry = RNR_RETRY; + attr.sq_psn = 0; + attr.max_rd_atomic = MAX_RD_ATOMIC; + flags = IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | + IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC; + assert(ibv_modify_qp(qp, &attr, flags) == 0); +} + +inline bool IBChannelImpl::poll_once(struct ibv_cq_ex* cq_ex, + std::vector& cq_datas, + uint32_t channel_id, + uint32_t& nb_post_recv) { + nb_post_recv = 0; + if (!cq_ex) { + LOG(INFO) << "poll_once - channel_id: " << channel_id << ", cq_ex_ is null"; + return false; + } + + struct ibv_wc pre_alloc_wcs[kBatchPollCqe]; + auto cq = ibv_cq_ex_to_cq(cq_ex); + int ret = ibv_poll_cq(cq, kBatchPollCqe, pre_alloc_wcs); + + if (ret <= 0) { + return false; + } + + for (int i = 0; i < ret; i++) { + auto wc = &pre_alloc_wcs[i]; + uint64_t wr_id = wc->wr_id; + auto status = wc->status; + if (unlikely(status != IBV_WC_SUCCESS)) { + LOG(WARNING) << "poll_once - channel_id: " << channel_id + << ", CQE error, wr_id=" << wr_id << ", status=" << status + << " (" << ibv_wc_status_str(status) << ")"; + } else { + CQMeta cq_data{}; + cq_data.wr_id = wr_id; + cq_data.op_code = wc->opcode; + cq_data.len = wc->byte_len; + + if (cq_data.op_code == IBV_WC_RECV_RDMA_WITH_IMM) { + cq_data.imm = wc->imm_data; + nb_post_recv++; + } else { + cq_data.imm = 0; + } + + cq_datas.emplace_back(cq_data); + } + } + + return !cq_datas.empty(); +} + +inline void IBChannelImpl::lazy_post_recv_wrs_n(struct ibv_qp* qp, uint32_t n, + bool force) { + pending_post_recv_ += n; + while (pending_post_recv_ >= kBatchPostRecvWr) { + struct ibv_recv_wr* bad_wr = nullptr; + pre_alloc_recv_wrs_[kBatchPostRecvWr - 1].next = nullptr; + assert(ibv_post_recv(qp, pre_alloc_recv_wrs_, &bad_wr) == 0); + pre_alloc_recv_wrs_[kBatchPostRecvWr - 1].next = + (kBatchPostRecvWr == kMaxRecvWr) + ? nullptr + : &pre_alloc_recv_wrs_[kBatchPostRecvWr]; + pending_post_recv_ -= kBatchPostRecvWr; + } + + if (force && pending_post_recv_) { + struct ibv_recv_wr* bad_wr = nullptr; + pre_alloc_recv_wrs_[pending_post_recv_ - 1].next = nullptr; + assert(ibv_post_recv(qp, pre_alloc_recv_wrs_, &bad_wr) == 0); + pre_alloc_recv_wrs_[pending_post_recv_ - 1].next = + (pending_post_recv_ == kMaxRecvWr) + ? nullptr + : &pre_alloc_recv_wrs_[pending_post_recv_]; + pending_post_recv_ = 0; + } +} + +inline void IBChannelImpl::setDstAddress(struct ibv_qp_ex* qpx, + struct ibv_ah* ah, + uint32_t remote_qpn) { + // IB RC doesn't need UD address setup + (void)qpx; + (void)ah; + (void)remote_qpn; +} + +inline void IBChannelImpl::initPreAllocResources() { + pre_alloc_recv_wrs_ = new struct ibv_recv_wr[kMaxRecvWr]; + pending_post_recv_ = 0; + for (int i = 0; i < kMaxRecvWr; i++) { + pre_alloc_recv_wrs_[i] = {}; + pre_alloc_recv_wrs_[i].next = + (i == kMaxRecvWr - 1) ? nullptr : &pre_alloc_recv_wrs_[i + 1]; + } +} + +inline uint32_t IBChannelImpl::getMaxInlineData() const { + return MAX_INLINE_DATA; +} + +#endif // RDMA_CHANNEL_IMPL_IB_CC_INCLUDED diff --git a/p2p/rdma/providers/ib/rdma_channel_impl_ib.h b/p2p/rdma/providers/ib/rdma_channel_impl_ib.h new file mode 100644 index 000000000..eb9b8b0de --- /dev/null +++ b/p2p/rdma/providers/ib/rdma_channel_impl_ib.h @@ -0,0 +1,35 @@ +#pragma once +#include "rdma/define.h" +#include "rdma/rdma_channel_impl.h" +#include + +class IBChannelImpl : public RDMAChannelImpl { + public: + IBChannelImpl() = default; + ~IBChannelImpl() override = default; + + void initQP(std::shared_ptr ctx, struct ibv_cq_ex** cq_ex, + struct ibv_qp** qp, ChannelMetaData* local_meta) override; + + void connectQP(struct ibv_qp* qp, std::shared_ptr ctx, + ChannelMetaData const& remote_meta) override; + + bool poll_once(struct ibv_cq_ex* cq_ex, std::vector& cq_datas, + uint32_t channel_id, uint32_t& nb_post_recv) override; + + void lazy_post_recv_wrs_n(struct ibv_qp* qp, uint32_t n, bool force) override; + + void setDstAddress(struct ibv_qp_ex* qpx, struct ibv_ah* ah, + uint32_t remote_qpn) override; + + uint32_t getMaxInlineData() const override; + + void initPreAllocResources() override; + + private: + void ibrcQP_rtr_rts(struct ibv_qp* qp, std::shared_ptr ctx, + ChannelMetaData const& remote_meta); +}; + +// Implementation (inline to avoid separate .cc file) +#include "rdma_channel_impl_ib.cc" diff --git a/p2p/rdma/providers/ib/rdma_device_selection_ib.h b/p2p/rdma/providers/ib/rdma_device_selection_ib.h new file mode 100644 index 000000000..887e8d2b5 --- /dev/null +++ b/p2p/rdma/providers/ib/rdma_device_selection_ib.h @@ -0,0 +1,22 @@ +#pragma once +#include +#include +#include +#include + +// Forward declaration +class RDMADeviceSelectionStrategy; + +// IB device selection strategy +class IBDeviceSelectionStrategy : public RDMADeviceSelectionStrategy { + public: + std::vector selectNICs( + std::vector const& candidates, int gpu_idx) override { + (void)gpu_idx; + std::vector selected; + if (!candidates.empty()) { + selected.push_back(candidates.front()); + } + return selected; + } +}; diff --git a/p2p/rdma/rdma_channel.h b/p2p/rdma/rdma_channel.h new file mode 100644 index 000000000..147560ee5 --- /dev/null +++ b/p2p/rdma/rdma_channel.h @@ -0,0 +1,229 @@ +#pragma once +#include "define.h" +#include "rdma_channel_impl.h" +#include "rdma_context.h" +#include "seq_num.h" +#include "util/util.h" +#include + +#ifdef UCCL_P2P_USE_IB +#include "providers/ib/rdma_channel_impl_ib.h" +#else +#include "providers/efa/rdma_channel_impl_efa.h" +#endif + +// Factory function implementation (inline, defined after including impl +// headers) +inline std::unique_ptr createRDMAChannelImpl() { +#ifdef UCCL_P2P_USE_IB + return std::make_unique(); +#else + return std::make_unique(); +#endif +} + +class RDMAChannel { + public: + explicit RDMAChannel(std::shared_ptr ctx, + uint32_t channel_id = 0) + : ctx_(ctx), + qp_(nullptr), + cq_ex_(nullptr), + ah_(nullptr), + channel_id_(channel_id), + local_meta_(std::make_shared()), + remote_meta_(std::make_shared()), + impl_(createRDMAChannelImpl()) { + initQP(); + } + + explicit RDMAChannel(std::shared_ptr ctx, + ChannelMetaData const& remote_meta, + uint32_t channel_id = 0) + : ctx_(ctx), + qp_(nullptr), + cq_ex_(nullptr), + ah_(nullptr), + channel_id_(channel_id), + local_meta_(std::make_shared()), + remote_meta_(std::make_shared(remote_meta)), + impl_(createRDMAChannelImpl()) { + initQP(); + ah_ = ctx_->createAH(remote_meta_->gid); + impl_->connectQP(qp_, ctx_, *remote_meta_); + UCCL_LOG_EP << "RDMAChannel connected to remote qpn=" << remote_meta.qpn; + } + + RDMAChannel(RDMAChannel const&) = delete; + RDMAChannel& operator=(RDMAChannel const&) = delete; + + void connect(ChannelMetaData const& remote_meta) { + remote_meta_ = std::make_shared(remote_meta); + ah_ = ctx_->createAH(remote_meta_->gid); + impl_->connectQP(qp_, ctx_, *remote_meta_); + UCCL_LOG_EP << "RDMAChannel connected to remote qpn=" << remote_meta.qpn; + } + + int64_t submitRequest(std::shared_ptr req) { + return postRequest(req); + } + + int64_t read(std::shared_ptr req) { + int ret = postRequest(req); + if (ret != 0) { + LOG(ERROR) << "Failed to post read request, wr_id=" << req->wr_id; + return -1; + } + return req->wr_id; + } + + int64_t send(std::shared_ptr req) { + int ret = postRequest(req); + if (ret != 0) { + LOG(ERROR) << "Failed to post send request, wr_id=" << req->wr_id; + return -1; + } + return req->wr_id; + } + + int64_t recv(std::shared_ptr req) { + struct ibv_sge sge = { + .addr = (uintptr_t)req->getLocalAddress(), + .length = (uint32_t)req->getLocalLen(), + .lkey = req->getLocalKey(), + }; + struct ibv_recv_wr wr = {0}, *bad_wr = nullptr; + int64_t wr_id = req->wr_id; + wr.wr_id = wr_id; + wr.sg_list = &sge; + wr.num_sge = 1; + if (ibv_post_recv(qp_, &wr, &bad_wr)) { + LOG(ERROR) << "ibv_post_recv failed: " << strerror(errno); + } + return wr_id; + } + + bool poll_once(std::vector& cq_datas) { + uint32_t nb_post_recv = 0; + bool result = impl_->poll_once(cq_ex_, cq_datas, channel_id_, nb_post_recv); + impl_->lazy_post_recv_wrs_n(qp_, nb_post_recv, false); + return result; + } + + // Get local metadata + std::shared_ptr get_local_meta() const { + return local_meta_; + } + + // Get remote metadata + std::shared_ptr get_remote_meta() const { + return remote_meta_; + } + + // Get RdmaContext + inline std::shared_ptr const getContext() const { return ctx_; } + + inline uint64_t const getContextID() const { return ctx_->getContextID(); } + + inline uint32_t getChannelID() const { return channel_id_; } + + private: + std::shared_ptr ctx_; + uint32_t channel_id_; + + struct ibv_cq_ex* cq_ex_; + struct ibv_qp* qp_; + struct ibv_ah* ah_; + + std::shared_ptr local_meta_; + std::shared_ptr remote_meta_; + + std::shared_ptr tracker_; + std::unique_ptr impl_; + + struct ibv_cq_ex* getCQ() const { + return cq_ex_; + } + + struct ibv_qp* getQP() const { + return qp_; + } + + // Post send request based on send_type + // Returns 0 on success, error code on failure + inline int postRequest(std::shared_ptr req) { + auto* qpx = ibv_qp_to_qp_ex(qp_); + ibv_wr_start(qpx); + LOG(INFO) << *req; + qpx->wr_id = req->wr_id; + qpx->comp_mask = 0; + qpx->wr_flags = IBV_SEND_SIGNALED; + + if (req->send_type == SendType::Send) { + ibv_wr_rdma_write_imm(qpx, req->getRemoteKey(), req->getRemoteAddress(), + req->imm_data); + } else if (req->send_type == SendType::Write) { + ibv_wr_rdma_write(qpx, req->getRemoteKey(), req->getRemoteAddress()); + } else if (req->send_type == SendType::Read) { + ibv_wr_rdma_read(qpx, req->getRemoteKey(), req->getRemoteAddress()); + } else { + LOG(ERROR) << "Unknown SendType in RDMAChannel::postRequest"; + return -1; + } + + struct ibv_sge sge[1]; + int num_sge = prepareSGEList(sge, req); + uint32_t max_inline = impl_->getMaxInlineData(); + if (req->getLocalLen() <= max_inline) { + qpx->wr_flags |= IBV_SEND_INLINE; + ibv_wr_set_inline_data(qpx, (void*)req->getLocalAddress(), + req->getLocalLen()); + } else { + ibv_wr_set_sge_list(qpx, num_sge, sge); + } + + impl_->setDstAddress(qpx, ah_, remote_meta_->qpn); + + int ret = ibv_wr_complete(qpx); + if (ret) { + std::ostringstream sge_info; + sge_info << "["; + for (int i = 0; i < num_sge; ++i) { + if (i > 0) sge_info << ", "; + sge_info << "{addr:0x" << std::hex << sge[i].addr + << ", len:" << std::dec << sge[i].length << ", lkey:0x" + << std::hex << sge[i].lkey << std::dec << "}"; + } + sge_info << "]"; + + LOG(ERROR) << "ibv_wr_complete failed in postRequest: " << ret << " " + << strerror(ret) << ", ah_=" << (void*)ah_ + << ", remote_qpn=" << remote_meta_->qpn + << ", local_qpn=" << qp_->qp_num << ", wr_id=" << req->wr_id + << ", remote_key=" << req->getRemoteKey() << ", remote_addr=0x" + << std::hex << req->getRemoteAddress() + << ", local_key=" << req->getLocalKey() + << ", num_sge=" << num_sge << ", sge_list=" << sge_info.str() + << std::dec; + } + return ret; + } + + void initQP() { + impl_->initQP(ctx_, &cq_ex_, &qp_, local_meta_.get()); + impl_->initPreAllocResources(); + } + + // Prepare SGE list for send request + // Returns the number of SGE entries filled + inline int prepareSGEList(struct ibv_sge* sge, + std::shared_ptr req) { + uint32_t total_len = req->getLocalLen(); + uint64_t local_addr = req->getLocalAddress(); + uint32_t local_key = req->getLocalKey(); + sge[0].addr = local_addr; + sge[0].length = total_len; + sge[0].lkey = local_key; + return 1; + } +}; diff --git a/p2p/efa/efa_channel_group.h b/p2p/rdma/rdma_channel_group.h similarity index 91% rename from p2p/efa/efa_channel_group.h rename to p2p/rdma/rdma_channel_group.h index 45fccd41d..1db2567e6 100644 --- a/p2p/efa/efa_channel_group.h +++ b/p2p/rdma/rdma_channel_group.h @@ -1,7 +1,7 @@ #pragma once #include "define.h" -#include "efa_channel.h" -#include "efa_ctrl_channel.h" +#include "rdma_channel.h" +#include "rdma_ctrl_channel.h" #include class ChannelGroup { @@ -10,7 +10,7 @@ class ChannelGroup { virtual ~ChannelGroup() = default; virtual void addChannel(uint32_t channel_id, - std::shared_ptr channel) { + std::shared_ptr channel) { if (!channel) { throw std::invalid_argument("addChannel called with null channel"); } @@ -19,7 +19,7 @@ class ChannelGroup { channels_[channel_id] = std::move(channel); } - virtual std::shared_ptr getChannel(uint32_t channel_id) const { + virtual std::shared_ptr getChannel(uint32_t channel_id) const { std::shared_lock lock(mutex_); auto it = channels_.find(channel_id); if (it == channels_.end()) return nullptr; @@ -31,7 +31,7 @@ class ChannelGroup { return channels_.size(); } - virtual std::unordered_map> const& + virtual std::unordered_map> const& channels() const { mutex_.lock_shared(); mutex_.unlock_shared(); // just to annotate read lock expected @@ -96,7 +96,7 @@ class ChannelGroup { protected: mutable std::shared_mutex mutex_; - std::unordered_map> channels_; + std::unordered_map> channels_; std::atomic last_channel_id_; }; @@ -108,17 +108,17 @@ class SendChannelGroup : public ChannelGroup { auto_start_polling_(auto_start_polling) { tracker_ = std::make_shared(); request_queue_ = std::make_unique< - RingBuffer, kRingCapacity>>(); + RingBuffer, kRingCapacity>>(); } ~SendChannelGroup() { stopPolling(); } void addChannel(uint32_t channel_id, - std::shared_ptr channel) override { + std::shared_ptr channel) override { ChannelGroup::addChannel(channel_id, channel); } - std::shared_ptr getChannel(uint32_t channel_id) const override { + std::shared_ptr getChannel(uint32_t channel_id) const override { auto result = ChannelGroup::getChannel(channel_id); return result; } @@ -134,7 +134,7 @@ class SendChannelGroup : public ChannelGroup { size_t normalChannelCount() const { return ChannelGroup::channelCount(); } - std::unordered_map> const& channels() + std::unordered_map> const& channels() const override { return ChannelGroup::channels(); } @@ -153,7 +153,7 @@ class SendChannelGroup : public ChannelGroup { } } - int64_t send(std::shared_ptr req) { + int64_t send(std::shared_ptr req) { int64_t wr_id = tracker_->sendPacket(req->getLocalLen()); req->wr_id = wr_id; if (unlikely(request_queue_->push(req) < 0)) { @@ -164,7 +164,7 @@ class SendChannelGroup : public ChannelGroup { return wr_id; } - int64_t postWriteOrRead(std::shared_ptr req) { + int64_t postWriteOrRead(std::shared_ptr req) { if (unlikely(req->send_type != SendType::Write && req->send_type != SendType::Read)) { LOG(ERROR) << "SendChannelGroup::write - Invalid send_type, expected " @@ -187,7 +187,7 @@ class SendChannelGroup : public ChannelGroup { return wr_id; } - int64_t read(std::shared_ptr req) { + int64_t read(std::shared_ptr req) { if (unlikely(req->send_type != SendType::Read)) { LOG(ERROR) << "SendChannelGroup::read - Invalid send_type, expected " "SendType::Read"; @@ -239,7 +239,7 @@ class SendChannelGroup : public ChannelGroup { << "SendChannelGroup::pollingLoop - Still running"; } - int processSendRequests(std::shared_ptr req) { + int processSendRequests(std::shared_ptr req) { pollControlChannel(); if (unlikely(ctrl_channel_ == nullptr)) { return -1; @@ -262,14 +262,14 @@ class SendChannelGroup : public ChannelGroup { mutable std::shared_mutex ctrl_channel_mutex_; std::atomic running_; std::unique_ptr poll_thread_; - std::unique_ptr, kRingCapacity>> + std::unique_ptr, kRingCapacity>> request_queue_; std::shared_ptr tracker_; bool auto_start_polling_; // Send a request through the appropriate channel // Returns true on success, false on failure - bool postRequestOnChannel(std::shared_ptr req) { + bool postRequestOnChannel(std::shared_ptr req) { auto channel = getChannel(req->channel_id); if (unlikely(!channel)) { LOG(WARNING) << "SendChannelGroup: Channel not found for channel_id " @@ -297,7 +297,7 @@ class SendChannelGroup : public ChannelGroup { } } - void postChunkedRequest(std::shared_ptr req) { + void postChunkedRequest(std::shared_ptr req) { // Split message into chunks size_t message_size = req->local_mem->size; auto chunks = splitMessageToChunks(message_size); @@ -326,7 +326,7 @@ class SendChannelGroup : public ChannelGroup { // Create send request for this chunk // Only the last chunk needs signaled for completion notification bool is_last_chunk = (i == chunks.size() - 1); - auto chunk_req = std::make_shared( + auto chunk_req = std::make_shared( chunk_local_mem, chunk_remote_mem, req->imm_data, is_last_chunk); chunk_req->channel_id = chunk_channel_id; chunk_req->from_rank_id = req->from_rank_id; @@ -358,7 +358,7 @@ class SendChannelGroup : public ChannelGroup { has_meta = ctrl_channel_->hasSendRequest(); // } while (has_meta) { - std::shared_ptr req; + std::shared_ptr req; if (tracker_->getTotalInflightBytes() > kInFlightMaxSizeKB * 1024 || !request_queue_->pop(req)) { if (tracker_->getTotalInflightBytes() > kInFlightMaxSizeKB * 1024) { @@ -376,7 +376,7 @@ class SendChannelGroup : public ChannelGroup { } } - inline void processOnceSendRequests(std::shared_ptr req, + inline void processOnceSendRequests(std::shared_ptr req, SendReqMeta& meta, int index) { req->imm_data = index; req->channel_id = meta.channel_id; @@ -427,11 +427,11 @@ class RecvChannelGroup : public ChannelGroup { ~RecvChannelGroup() { stopPolling(); } void addChannel(uint32_t channel_id, - std::shared_ptr channel) override { + std::shared_ptr channel) override { ChannelGroup::addChannel(channel_id, channel); } - std::shared_ptr getChannel(uint32_t channel_id) const override { + std::shared_ptr getChannel(uint32_t channel_id) const override { auto result = ChannelGroup::getChannel(channel_id); return result; } @@ -446,7 +446,7 @@ class RecvChannelGroup : public ChannelGroup { size_t normalChannelCount() const { return ChannelGroup::channelCount(); } - std::unordered_map> const& channels() + std::unordered_map> const& channels() const override { return ChannelGroup::channels(); } @@ -485,7 +485,7 @@ class RecvChannelGroup : public ChannelGroup { } } - int64_t recv(std::shared_ptr req) { + int64_t recv(std::shared_ptr req) { if (unlikely(!setupRecvRequestChannelAndMemoryRegion(req))) { LOG(WARNING) << "RecvChannelGroup: Failed to setup recv request with round robin"; @@ -573,7 +573,7 @@ class RecvChannelGroup : public ChannelGroup { // Round-robin channel selection and MR setup bool setupRecvRequestChannelAndMemoryRegion( - std::shared_ptr req) { + std::shared_ptr req) { if (unlikely(!req || !req->local_mem)) { return false; } diff --git a/p2p/rdma/rdma_channel_impl.h b/p2p/rdma/rdma_channel_impl.h new file mode 100644 index 000000000..2df1af4a7 --- /dev/null +++ b/p2p/rdma/rdma_channel_impl.h @@ -0,0 +1,61 @@ +#pragma once +#include "define.h" +#include "rdma_context.h" +#include + +// Forward declarations +struct ibv_qp; +struct ibv_qp_ex; +struct ibv_cq_ex; +struct ibv_ah; +struct ibv_sge; +struct ibv_recv_wr; +struct ibv_wc; + +// Base class for RDMA channel implementations +class RDMAChannelImpl { + public: + virtual ~RDMAChannelImpl() = default; + + // Initialize QP and CQ + virtual void initQP(std::shared_ptr ctx, + struct ibv_cq_ex** cq_ex, struct ibv_qp** qp, + ChannelMetaData* local_meta) = 0; + + // Connect QP to remote + virtual void connectQP(struct ibv_qp* qp, std::shared_ptr ctx, + ChannelMetaData const& remote_meta) = 0; + + // Poll completion queue + virtual bool poll_once(struct ibv_cq_ex* cq_ex, std::vector& cq_datas, + uint32_t channel_id, uint32_t& nb_post_recv) = 0; + + // Post receive work request + virtual void lazy_post_recv_wrs_n(struct ibv_qp* qp, uint32_t n, + bool force) = 0; + + // Setup Destination address + virtual void setDstAddress(struct ibv_qp_ex* qpx, struct ibv_ah* ah, + uint32_t remote_qpn) = 0; + + // Get max inline data size + virtual uint32_t getMaxInlineData() const = 0; + + // Initialize pre-allocated resources + virtual void initPreAllocResources() = 0; + + protected: + struct ibv_recv_wr* pre_alloc_recv_wrs_; + uint32_t pending_post_recv_; +}; + +// Forward declarations for implementations +#ifdef UCCL_P2P_USE_IB +class IBChannelImpl; +#else +class EFAChannelImpl; +#endif + +// Factory function declaration (implementation in rdma_channel.h after +// including impl headers) +std::unique_ptr createRDMAChannelImpl(); diff --git a/p2p/efa/rdma_context.h b/p2p/rdma/rdma_context.h similarity index 100% rename from p2p/efa/rdma_context.h rename to p2p/rdma/rdma_context.h diff --git a/p2p/efa/efa_ctrl_channel.h b/p2p/rdma/rdma_ctrl_channel.h similarity index 87% rename from p2p/efa/efa_ctrl_channel.h rename to p2p/rdma/rdma_ctrl_channel.h index 52b6d4fe7..e82191765 100644 --- a/p2p/efa/efa_ctrl_channel.h +++ b/p2p/rdma/rdma_ctrl_channel.h @@ -1,23 +1,23 @@ #pragma once #include "define.h" -#include "efa_channel.h" +#include "rdma_channel.h" #include "ring_spsc.h" -class SendControlChannel : public EFAChannel { +class SendControlChannel : public RDMAChannel { public: explicit SendControlChannel(std::shared_ptr ctx, uint32_t channel_id = 0) - : EFAChannel(ctx, channel_id) {} + : RDMAChannel(ctx, channel_id) {} explicit SendControlChannel(std::shared_ptr ctx, ChannelMetaData const& remote_meta, uint32_t channel_id = 0) - : EFAChannel(ctx, remote_meta, channel_id) {} + : RDMAChannel(ctx, remote_meta, channel_id) {} explicit SendControlChannel(std::shared_ptr ctx, std::shared_ptr mem_block, uint32_t channel_id = 0) - : EFAChannel(ctx, channel_id) { + : RDMAChannel(ctx, channel_id) { rb_ = std::make_unique>( mem_block); } @@ -26,7 +26,7 @@ class SendControlChannel : public EFAChannel { ChannelMetaData const& remote_meta, std::shared_ptr mem_block, uint32_t channel_id = 0) - : EFAChannel(ctx, remote_meta, channel_id) { + : RDMAChannel(ctx, remote_meta, channel_id) { rb_ = std::make_unique>( mem_block); } @@ -36,7 +36,7 @@ class SendControlChannel : public EFAChannel { // Initialize rb_ with the shared_ptr rb_ = std::make_unique>( mem_block); - EFAChannel::connect(remote_meta); + RDMAChannel::connect(remote_meta); } int getOneSendRequestMeta(SendReqMeta& meta) { @@ -47,7 +47,7 @@ class SendControlChannel : public EFAChannel { inline bool hasSendRequest() { return !rb_->empty(); } // not thread safe - bool getOneSendRequest(std::shared_ptr& req) { + bool getOneSendRequest(std::shared_ptr& req) { // Pop from rb_ and generate req, return false if empty SendReqMeta meta; int index = getOneSendRequestMeta(meta); @@ -72,7 +72,7 @@ class SendControlChannel : public EFAChannel { bool noblockingPoll() { std::vector cq_datas; - if (EFAChannel::poll_once(cq_datas)) { + if (RDMAChannel::poll_once(cq_datas)) { for (auto const& cq_data : cq_datas) { LOG(INFO) << "SendControlChannel::noblockingPoll - Polled completion: " << cq_data; @@ -90,12 +90,12 @@ class SendControlChannel : public EFAChannel { std::unique_ptr> rb_; }; -class RecvControlChannel : public EFAChannel { +class RecvControlChannel : public RDMAChannel { public: explicit RecvControlChannel(std::shared_ptr ctx, std::shared_ptr mem_block, uint32_t channel_id = 0) - : EFAChannel(ctx, channel_id) { + : RDMAChannel(ctx, channel_id) { local_info_ = mem_block; rb_ = std::make_unique>( local_info_); @@ -105,7 +105,7 @@ class RecvControlChannel : public EFAChannel { MetaInfoToExchange const& remote_meta, std::shared_ptr mem_block, uint32_t channel_id = 0) - : EFAChannel(ctx, remote_meta.channel_meta, channel_id) { + : RDMAChannel(ctx, remote_meta.channel_meta, channel_id) { local_info_ = mem_block; rb_ = std::make_unique>( local_info_); @@ -122,10 +122,10 @@ class RecvControlChannel : public EFAChannel { empty_rb_ = std::make_unique>( reinterpret_cast(remote_meta.mem_meta.addr)); - EFAChannel::connect(remote_meta.channel_meta); + RDMAChannel::connect(remote_meta.channel_meta); } - int postSendReq(std::shared_ptr rev_req) { + int postSendReq(std::shared_ptr rev_req) { SendReqMeta req_meta(rev_req); LOG(INFO) << "postSendReq - Created SendReqMeta: " << req_meta; @@ -156,10 +156,11 @@ class RecvControlChannel : public EFAChannel { local_mem_ptr_->size = rb_->elementSize(); } - std::shared_ptr send_ptr = std::make_shared( - local_mem_ptr_, remote_mem_ptr_, index); + std::shared_ptr send_ptr = + std::make_shared(local_mem_ptr_, remote_mem_ptr_, + index); send_ptr->channel_id = kControlChannelID; - EFAChannel::send(send_ptr); + RDMAChannel::send(send_ptr); return index; } @@ -177,7 +178,7 @@ class RecvControlChannel : public EFAChannel { bool noblockingPoll() { std::vector cq_datas; - if (EFAChannel::poll_once(cq_datas)) { + if (RDMAChannel::poll_once(cq_datas)) { for (auto const& cq_data : cq_datas) { LOG(INFO) << "RecvControlChannel::noblockingPoll - Polled completion: " << cq_data; diff --git a/p2p/efa/rdma_device.h b/p2p/rdma/rdma_device.h similarity index 81% rename from p2p/efa/rdma_device.h rename to p2p/rdma/rdma_device.h index a9c1ce830..159010d6b 100644 --- a/p2p/efa/rdma_device.h +++ b/p2p/rdma/rdma_device.h @@ -1,5 +1,34 @@ #pragma once #include "define.h" +#include +#include +#include + +// Base class for device selection strategy +class RDMADeviceSelectionStrategy { + public: + virtual ~RDMADeviceSelectionStrategy() = default; + + // Select NIC names from candidates based on GPU index + virtual std::vector selectNICs( + std::vector const& candidates, int gpu_idx) = 0; +}; + +// Include device selection strategy based on build configuration +#ifdef UCCL_P2P_USE_IB +#include "providers/ib/rdma_device_selection_ib.h" +#else +#include "providers/efa/rdma_device_selection_efa.h" +#endif + +inline std::unique_ptr +createDeviceSelectionStrategy() { +#ifdef UCCL_P2P_USE_IB + return std::make_unique(); +#else + return std::make_unique(); +#endif +} class RdmaDevice { public: @@ -87,16 +116,10 @@ class RdmaDeviceManager { LOG(WARNING) << "no candidate NIC found, defaulting to first"; selected_nic_names.push_back(dist.front().first); } else { - // NOTE(xzhiying): This is a temporary hack. - // On p5en, there are 4 NICs with the same distance. - // GPU0 uses candidates[0/1], GPU1 uses candidates[2/3], etc. - assert(candidates.size() == 4); - int half_size = candidates.size() / 2; - int start_idx = (gpu_idx % 2 == 0) ? 0 : half_size; - int end_idx = start_idx + half_size; - for (int i = start_idx; i < end_idx; i++) { - selected_nic_names.push_back(candidates[i]); - } + auto strategy = createDeviceSelectionStrategy(); + auto selected = strategy->selectNICs(candidates, gpu_idx); + selected_nic_names.insert(selected_nic_names.end(), selected.begin(), + selected.end()); } } diff --git a/p2p/efa/efa_endpoint.h b/p2p/rdma/rdma_endpoint.h similarity index 95% rename from p2p/efa/efa_endpoint.h rename to p2p/rdma/rdma_endpoint.h index 8a96e973c..00a964ed0 100644 --- a/p2p/efa/efa_endpoint.h +++ b/p2p/rdma/rdma_endpoint.h @@ -1,18 +1,18 @@ #pragma once #include "define.h" -#include "efa_channel_group.h" -#include "efa_ctrl_channel.h" #include "epoll_client.h" #include "epoll_server.h" #include "memory_allocator.h" +#include "rdma_channel_group.h" #include "rdma_context.h" +#include "rdma_ctrl_channel.h" #include "rdma_device.h" #include "util/net.h" #include -class EFAEndpoint { +class NICEndpoint { public: - explicit EFAEndpoint( + explicit NICEndpoint( int gpu_index, uint64_t rank_id = INVALID_RANK_ID, uint64_t port = 0, bool auto_start_polling = true, std::vector const& device_ids = std::vector()) @@ -29,7 +29,7 @@ class EFAEndpoint { actual_device_ids = device_ids; } initializeContexts(actual_device_ids); - LOG(INFO) << "EFAEndpoint initialized with " << contexts_.size() + LOG(INFO) << "NICEndpoint initialized with " << contexts_.size() << " context(s) for GPU " << gpu_index_; oob_server_ = std::make_shared( @@ -45,7 +45,7 @@ class EFAEndpoint { } // Destructor - ~EFAEndpoint() { + ~NICEndpoint() { if (oob_client_) { oob_client_->stop(); } @@ -178,7 +178,7 @@ class EFAEndpoint { << ", index: " << index; } - int64_t writeOrRead(std::shared_ptr req) { + int64_t writeOrRead(std::shared_ptr req) { uint64_t rank_id = req->to_rank_id; auto it = send_channel_groups_.find(rank_id); if (it == send_channel_groups_.end()) { @@ -191,7 +191,7 @@ class EFAEndpoint { // Blocking call until send succeeds while (wr_id < 0) { - LOG(INFO) << "EFAEndpoint::write - Attempting to send to rank_id: " + LOG(INFO) << "NICEndpoint::write - Attempting to send to rank_id: " << rank_id << ", peer rank_id " << rank_id; wr_id = send_group->postWriteOrRead(req); @@ -205,7 +205,7 @@ class EFAEndpoint { // Blocking send: wraps SendChannelGroup::send with rank_id parameter // Returns wr_id for checking completion later - int64_t send(uint64_t rank_id, std::shared_ptr req) { + int64_t send(uint64_t rank_id, std::shared_ptr req) { auto it = send_channel_groups_.find(rank_id); if (it == send_channel_groups_.end()) { throw std::runtime_error("Send channel group not found for rank_id: " + @@ -217,7 +217,7 @@ class EFAEndpoint { // Blocking call until send succeeds while (wr_id < 0) { - LOG(INFO) << "EFAEndpoint::send - Attempting to send to rank_id: " + LOG(INFO) << "NICEndpoint::send - Attempting to send to rank_id: " << rank_id << ", peer rank_id " << rank_id; wr_id = send_group->send(req); @@ -231,7 +231,7 @@ class EFAEndpoint { // Blocking recv: wraps RecvChannelGroup::recv with rank_id parameter // Returns index for checking completion later - int64_t recv(uint64_t rank_id, std::shared_ptr req) { + int64_t recv(uint64_t rank_id, std::shared_ptr req) { auto it = recv_channel_groups_.find(rank_id); if (it == recv_channel_groups_.end()) { throw std::runtime_error("Recv channel group not found for rank_id: " + @@ -243,7 +243,7 @@ class EFAEndpoint { // Blocking call until recv succeeds while (index < 0) { index = recv_group->recv(req); - LOG(INFO) << "EFAEndpoint::recv - Attempting to recv from rank_id: " + LOG(INFO) << "NICEndpoint::recv - Attempting to recv from rank_id: " << rank_id << ", peer rank_id " << rank_id; if (index < 0) { std::this_thread::sleep_for(std::chrono::microseconds(10)); @@ -415,12 +415,12 @@ class EFAEndpoint { } // Manual polling routine for send channels when auto_start_polling_ is false - int sendWithoutInnerQueue(std::shared_ptr req) { + int sendWithoutInnerQueue(std::shared_ptr req) { if (auto_start_polling_) { return -1; // Do nothing if auto polling is enabled } if (!req) { - LOG(WARNING) << "EFAEndpoint::sendRoutine - null request"; + LOG(WARNING) << "NICEndpoint::sendRoutine - null request"; return -1; } @@ -428,7 +428,7 @@ class EFAEndpoint { std::shared_lock lock(send_channel_mutex_); auto it = send_channel_groups_.find(rank_id); if (it == send_channel_groups_.end()) { - LOG(WARNING) << "EFAEndpoint::sendRoutine - Send channel group not found " + LOG(WARNING) << "NICEndpoint::sendRoutine - Send channel group not found " "for rank_id: " << rank_id; return -1; @@ -436,7 +436,7 @@ class EFAEndpoint { auto send_group = it->second; if (!send_group) { - LOG(WARNING) << "EFAEndpoint::sendRoutine - Send channel group is null " + LOG(WARNING) << "NICEndpoint::sendRoutine - Send channel group is null " "for rank_id: " << rank_id; return -1; @@ -466,7 +466,7 @@ class EFAEndpoint { } auto context = std::make_shared(device, contexts_.size()); contexts_.push_back(context); - LOG(INFO) << "EFAEndpoint: Created context " << i << " for device " + LOG(INFO) << "NICEndpoint: Created context " << i << " for device " << device_id << " (" << device->name() << ")"; } @@ -539,7 +539,7 @@ class EFAEndpoint { } } - std::shared_ptr new_channel = std::make_shared( + std::shared_ptr new_channel = std::make_shared( ctx_ptr, meta.channel_meta, meta.channel_id); // Create response (echo back the same data) MetaInfoToExchange response(rank_id_, meta.channel_id, @@ -552,7 +552,7 @@ class EFAEndpoint { } // Handle response from send_meta operation - uint64_t handle_send_meta_response(std::shared_ptr channel, + uint64_t handle_send_meta_response(std::shared_ptr channel, std::string const& response) { // Deserialize response as MetaInfoToExchange MetaInfoToExchange response_meta = @@ -583,7 +583,7 @@ class EFAEndpoint { } void addOneRecvChannel(uint64_t rank_id, uint32_t channel_id, - std::shared_ptr new_channel) { + std::shared_ptr new_channel) { std::shared_ptr group_ptr = getOrCreateRecvGroup(rank_id); group_ptr->addChannel(channel_id, new_channel); } @@ -610,7 +610,7 @@ class EFAEndpoint { } void addOneSendChannel(uint64_t rank_id, uint32_t channel_id, - std::shared_ptr new_channel) { + std::shared_ptr new_channel) { auto group_ptr = getOrCreateSendGroup(rank_id); group_ptr->addChannel(channel_id, new_channel); } @@ -698,7 +698,7 @@ class EFAEndpoint { for (int i = 0; i < kQpNumPerChannel; i++) { uint32_t channel_id = i + 1; - auto channel = std::make_shared( + auto channel = std::make_shared( getContextByChannelId(channel_id), channel_id); MetaInfoToExchange meta(rank_id_, channel_id, channel->get_local_meta(), diff --git a/p2p/efa/ring_spsc.h b/p2p/rdma/ring_spsc.h similarity index 100% rename from p2p/efa/ring_spsc.h rename to p2p/rdma/ring_spsc.h diff --git a/p2p/efa/seq_num.h b/p2p/rdma/seq_num.h similarity index 100% rename from p2p/efa/seq_num.h rename to p2p/rdma/seq_num.h diff --git a/p2p/efa/tests/Makefile b/p2p/rdma/tests/Makefile similarity index 100% rename from p2p/efa/tests/Makefile rename to p2p/rdma/tests/Makefile diff --git a/p2p/efa/tests/test_efa_endpoint.cpp b/p2p/rdma/tests/test_efa_endpoint.cpp similarity index 94% rename from p2p/efa/tests/test_efa_endpoint.cpp rename to p2p/rdma/tests/test_efa_endpoint.cpp index d020bd21f..02fa47f78 100644 --- a/p2p/efa/tests/test_efa_endpoint.cpp +++ b/p2p/rdma/tests/test_efa_endpoint.cpp @@ -49,7 +49,7 @@ DEFINE_uint64(buffer_size, 1024 * 1024, "Buffer size in bytes"); // --iterations=100 --buffer_size=104857600 // Correctness test: perform 100 send/recv operations and verify results -void correctness_test(EFAEndpoint& endpoint, MemoryAllocator& allocator) { +void correctness_test(NICEndpoint& endpoint, MemoryAllocator& allocator) { std::cout << "\n=== Starting Correctness Test (100 iterations) ===\n" << std::flush; @@ -104,8 +104,8 @@ void correctness_test(EFAEndpoint& endpoint, MemoryAllocator& allocator) { // Create requests auto remote_mem_placeholder = std::make_shared(); auto send_req = - std::make_shared(send_mem, remote_mem_placeholder); - auto recv_req = std::make_shared(recv_mem); + std::make_shared(send_mem, remote_mem_placeholder); + auto recv_req = std::make_shared(recv_mem); // Post recv first int64_t recv_index = endpoint.recv(FLAGS_remote_rank, recv_req); @@ -194,7 +194,7 @@ void correctness_test(EFAEndpoint& endpoint, MemoryAllocator& allocator) { } // Unidirectional bandwidth test: rank 0 only sends, rank 1 only receives -void unidirectional_test(EFAEndpoint& endpoint, MemoryAllocator& allocator, +void unidirectional_test(NICEndpoint& endpoint, MemoryAllocator& allocator, int iterations) { std::cout << "\n=== Starting Unidirectional Bandwidth Test (" << iterations << " iterations) ===\n"; @@ -231,12 +231,12 @@ void unidirectional_test(EFAEndpoint& endpoint, MemoryAllocator& allocator, // Rank 0: only send auto remote_mem_placeholder = std::make_shared(); auto send_req = - std::make_shared(send_mem, remote_mem_placeholder); + std::make_shared(send_mem, remote_mem_placeholder); int64_t send_wr_id = endpoint.send(FLAGS_remote_rank, send_req); endpoint.checkSendComplete(FLAGS_remote_rank, send_wr_id); } else { // Rank 1: only receive - auto recv_req = std::make_shared(recv_mem); + auto recv_req = std::make_shared(recv_mem); int64_t recv_index = endpoint.recv(FLAGS_remote_rank, recv_req); endpoint.checkRecvComplete(FLAGS_remote_rank, recv_index); } @@ -257,7 +257,7 @@ void unidirectional_test(EFAEndpoint& endpoint, MemoryAllocator& allocator, for (int i = 0; i < iterations; i++) { auto remote_mem_placeholder = std::make_shared(); auto send_req = - std::make_shared(send_mem, remote_mem_placeholder); + std::make_shared(send_mem, remote_mem_placeholder); int64_t send_wr_id = endpoint.send(FLAGS_remote_rank, send_req); send_infos.push_back({send_req->channel_id, send_wr_id}); @@ -279,7 +279,7 @@ void unidirectional_test(EFAEndpoint& endpoint, MemoryAllocator& allocator, // First, recv all messages for (int i = 0; i < iterations; i++) { - auto recv_req = std::make_shared(recv_mem); + auto recv_req = std::make_shared(recv_mem); int64_t recv_index = endpoint.recv(FLAGS_remote_rank, recv_req); recv_indices.push_back(recv_index); @@ -325,7 +325,7 @@ void unidirectional_test(EFAEndpoint& endpoint, MemoryAllocator& allocator, } // Bandwidth test: perform N send/recv operations and measure bandwidth -void bandwidth_test(EFAEndpoint& endpoint, MemoryAllocator& allocator, +void bandwidth_test(NICEndpoint& endpoint, MemoryAllocator& allocator, int iterations) { std::cout << "\n=== Starting Bandwidth Test (" << iterations << " iterations) ===\n" @@ -359,8 +359,8 @@ void bandwidth_test(EFAEndpoint& endpoint, MemoryAllocator& allocator, for (int i = 0; i < 50; i++) { auto remote_mem_placeholder = std::make_shared(); auto send_req = - std::make_shared(send_mem, remote_mem_placeholder); - auto recv_req = std::make_shared(recv_mem); + std::make_shared(send_mem, remote_mem_placeholder); + auto recv_req = std::make_shared(recv_mem); int64_t recv_index = endpoint.recv(FLAGS_remote_rank, recv_req); int64_t send_wr_id = endpoint.send(FLAGS_remote_rank, send_req); @@ -383,8 +383,8 @@ void bandwidth_test(EFAEndpoint& endpoint, MemoryAllocator& allocator, for (int i = 0; i < iterations; i++) { auto remote_mem_placeholder = std::make_shared(); auto send_req = - std::make_shared(send_mem, remote_mem_placeholder); - auto recv_req = std::make_shared(recv_mem); + std::make_shared(send_mem, remote_mem_placeholder); + auto recv_req = std::make_shared(recv_mem); int64_t recv_index = endpoint.recv(FLAGS_remote_rank, recv_req); recv_indices.push_back(recv_index); @@ -465,7 +465,7 @@ int main(int argc, char* argv[]) { FLAGS_logtostderr = true; // Parse command line flags - gflags::SetUsageMessage("EFAEndpoint usage example"); + gflags::SetUsageMessage("NICEndpoint usage example"); gflags::ParseCommandLineFlags(&argc, &argv, true); // Validate required flags @@ -475,7 +475,7 @@ int main(int argc, char* argv[]) { return 1; } - std::cout << "=== EFAEndpoint Usage Example ===\n"; + std::cout << "=== NICEndpoint Usage Example ===\n"; std::cout << "GPU Index: " << FLAGS_gpu_index << "\n"; std::cout << "Rank ID: " << FLAGS_rank_id << "\n"; std::cout << "Port: " << FLAGS_port << "\n"; @@ -490,7 +490,7 @@ int main(int argc, char* argv[]) { // std::cout << "Allocated " << gpu_mem->size << " bytes of GPU memory at " // << gpu_mem->addr << std::endl; // RemoteMemInfo info(gpu_mem); - // recv_test_ = std::make_shared(gpu_mem); + // recv_test_ = std::make_shared(gpu_mem); try { // Set GPU device for the entire process cudaError_t cuda_err = cudaSetDevice(FLAGS_gpu_index); @@ -517,11 +517,11 @@ int main(int argc, char* argv[]) { std::cout << "Found " << device_manager.deviceCount() << " RDMA device(s)\n\n"; - // Create EFAEndpoint with device_ids = {0} - std::cout << "Creating EFAEndpoint...\n"; + // Create NICEndpoint with device_ids = {0} + std::cout << "Creating NICEndpoint...\n"; std::vector device_ids = {0, 1}; - EFAEndpoint endpoint(FLAGS_gpu_index, FLAGS_rank_id, FLAGS_port); - std::cout << "EFAEndpoint created successfully\n\n"; + NICEndpoint endpoint(FLAGS_gpu_index, FLAGS_rank_id, FLAGS_port); + std::cout << "NICEndpoint created successfully\n\n"; // Create OOBMetaData for remote rank std::cout << "Setting up remote rank metadata...\n"; diff --git a/p2p/transfer.py b/p2p/transfer.py index 620be4d3b..28a9d3e4d 100644 --- a/p2p/transfer.py +++ b/p2p/transfer.py @@ -9,7 +9,6 @@ class TransferManager: - class ConnState: def __init__( self, diff --git a/thirdparty/dietgpu/.gitignore b/thirdparty/dietgpu/.gitignore new file mode 100644 index 000000000..01b7c130c --- /dev/null +++ b/thirdparty/dietgpu/.gitignore @@ -0,0 +1,17 @@ +build/ +CMakeFiles/ +*.cmake +__pycache__/ +*.pyc +*.pyo +*.pyd +*.pyw +*.pyz +*.pywz +*.pyzw +Makefile +detect_cuda_version.cc +detect_cuda_compute_capabilities.cu +*.hip +*_hip.* +*.egg-info \ No newline at end of file diff --git a/thirdparty/dietgpu/CMakeLists.txt b/thirdparty/dietgpu/CMakeLists.txt new file mode 100644 index 000000000..bd102a112 --- /dev/null +++ b/thirdparty/dietgpu/CMakeLists.txt @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the MIT-style license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.10 FATAL_ERROR) +project(dietgpu LANGUAGES CUDA CXX VERSION 1.0) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_C_STANDARD 11) + +include(CheckLanguage) +check_language(CUDA) + +if(NOT DEFINED CMAKE_CUDA_STANDARD) + set(CMAKE_CUDA_STANDARD 17) + set(CMAKE_CUDA_STANDARD_REQUIRED ON) +endif() + +find_package(CUDA REQUIRED) +find_package(Torch REQUIRED) +include_directories(${TORCH_INCLUDE_DIRS}) + +add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0") + +if(${CMAKE_VERSION} VERSION_LESS_EQUAL "3.13.4") + cuda_select_nvcc_arch_flags(ARCH_FLAGS "Auto") + message("ARCH_FLAGS = ${ARCH_FLAGS}") + string(REPLACE "-gencode;" "--generate-code=" ARCH_FLAGS "${ARCH_FLAGS}") + string(APPEND CMAKE_CUDA_FLAGS "${ARCH_FLAGS}") +else() + include(FindCUDA/select_compute_arch) + CUDA_DETECT_INSTALLED_GPUS(INSTALLED_GPU_CCS_1) + string(STRIP "${INSTALLED_GPU_CCS_1}" INSTALLED_GPU_CCS_2) + string(REPLACE " " ";" INSTALLED_GPU_CCS_3 "${INSTALLED_GPU_CCS_2}") + string(REPLACE "." "" CUDA_ARCH_LIST "${INSTALLED_GPU_CCS_3}") + set(CMAKE_CUDA_ARCHITECTURES ${CUDA_ARCH_LIST}) + set_property(GLOBAL PROPERTY CUDA_ARCHITECTURES "${CUDA_ARCH_LIST}") +endif() + +# Set default build type. +if(NOT CMAKE_BUILD_TYPE) + message(STATUS "Setting build type to 'RelWithDebInfo' as none was specified.") + set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING + "Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel." + FORCE + ) +endif() + +set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) + +add_subdirectory(third_party/glog) +add_subdirectory(third_party/googletest) + +add_subdirectory(dietgpu) +add_subdirectory(dietgpu/utils) +add_subdirectory(dietgpu/ans) +add_subdirectory(dietgpu/float) diff --git a/thirdparty/dietgpu/CODE_OF_CONDUCT.md b/thirdparty/dietgpu/CODE_OF_CONDUCT.md new file mode 100644 index 000000000..15ef8e829 --- /dev/null +++ b/thirdparty/dietgpu/CODE_OF_CONDUCT.md @@ -0,0 +1,2 @@ +# Code of Conduct +Facebook has adopted a Code of Conduct that we expect project participants to adhere to. Please [read the full text](https://code.fb.com/codeofconduct) so that you can understand what actions will and will not be tolerated. diff --git a/thirdparty/dietgpu/CONTRIBUTING.md b/thirdparty/dietgpu/CONTRIBUTING.md new file mode 100644 index 000000000..b9c8fbe5d --- /dev/null +++ b/thirdparty/dietgpu/CONTRIBUTING.md @@ -0,0 +1,11 @@ +# Contributing to DietGPU + +DietGPU is still in a fairly early stage, and it is being rapidly iterated on to support integration into networked collective communication frameworks for distributed ML computation. + +The underlying rANS codec itself is fairly stable, but extensions to the library will include fused all-reduce kernel implementations, and specializations for more structured data (sparse data, dimensional correlations, etc). A CUB-like device library for fused kernel rANS usage is on the table as well. + +Contributions are very welcome, but preferably once the library achieves some stability, as this is a very early release of the code. + +## License + +By contributing to DietGPU, you agree that your contributions will be licensed under the LICENSE file in the root directory of this source tree. diff --git a/thirdparty/dietgpu/Dockerfile b/thirdparty/dietgpu/Dockerfile new file mode 100644 index 000000000..e338baa0b --- /dev/null +++ b/thirdparty/dietgpu/Dockerfile @@ -0,0 +1,25 @@ +FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 as dev-base + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + ca-certificates \ + curl \ + git && \ + rm -rf /var/lib/apt/lists/* +ENV PATH /opt/conda/bin:$PATH + +FROM dev-base as conda +ENV PYTHON_VER=3.10 +RUN curl -fsSL -v -o ~/miniconda.sh -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ + chmod +x ~/miniconda.sh && \ + ~/miniconda.sh -b -p /opt/conda && \ + rm ~/miniconda.sh && \ + /opt/conda/bin/conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main && \ + /opt/conda/bin/conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r && \ + /opt/conda/bin/conda install -y python=${PYTHON_VER} pytorch cmake ninja -c pytorch && \ + /opt/conda/bin/pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124 && \ + /opt/conda/bin/conda clean -ya + +ENV CMAKE_PREFIX_PATH="/opt/conda/bin/../:/opt/conda/lib/python3.9/site-packages/torch/share/cmake" + +CMD bash diff --git a/thirdparty/dietgpu/LICENSE.md b/thirdparty/dietgpu/LICENSE.md new file mode 100644 index 000000000..b93be9051 --- /dev/null +++ b/thirdparty/dietgpu/LICENSE.md @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) Meta Platforms, Inc. and affiliates. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/thirdparty/dietgpu/README.md b/thirdparty/dietgpu/README.md new file mode 100644 index 000000000..ba265dff6 --- /dev/null +++ b/thirdparty/dietgpu/README.md @@ -0,0 +1,131 @@ +# DietGPU: GPU-based lossless compression for numerical data + +Author: Jeff Johnson (@wickedfoo), `jhj _at_ fb.com` + +*(NOTE: very early alpha preview of the library; it is still under rapid development)* + +DietGPU is a library for fast specialized lossless compression of data on Nvidia GPUs, meant for ML/HPC applications. It also contains the first publicly available GPU-based generalized [asymmetric numeral system (ANS)](https://en.wikipedia.org/wiki/Asymmetric_numeral_systems) compressor and decompressor. It is a GPU analogue to Yann Collet's [FSE (Finite State Entropy)](https://github.com/Cyan4973/FiniteStateEntropy) ANS library. + +It currently consists of two parts: + +- **ANS entropy codec**: a generalized byte-oriented range-based ANS (rANS) entropy encoder and decoder, that operates at throughputs around 250-410 GB/s for reasonable data sizes on an A100 GPU. +- **Floating point codec**: an extension to the above to handle fast lossless compression and decompression of unstructured floating point data, for use in ML and HPC applications, especially in communicating over local interconnects (PCIe / NVLink) and remote interconnects (Ethernet / InfiniBand). This operates at around 250-600 GB/s for reasonable data sizes on an A100 GPU. + +Both APIs are available in both C++ (raw device pointers) and Python/PyTorch (PyTorch tensor) API forms. + +## Documentation + +Documentation is [available in the wiki](https://github.com/facebookresearch/dietgpu/wiki). + +## Building + +Clone this repo using + +```shell +git clone --recursive https://github.com/facebookresearch/dietgpu +cd dietgpu +``` + +Then the simplest way is to use the included Dockerfile, which installs the PyTorch dependencies *and* uses NVIDIA's dev image as a base (for the CUDA dependencies): + +```shell +docker build -t dietgpu . +docker run --privileged --runtime=nvidia --rm -v $(pwd):/dietgpu -it dietgpu:latest +``` + +Note you need NVIDIA's container runtime installed (if on Fedora consult this [Github issue](https://github.com/NVIDIA/nvidia-docker/issues/706#issuecomment-851816502)). + +Then do the standard CMake thing: + +```shell +cd dietgpu; mkdir build; cd build; +cmake .. -G Ninja +cmake --build . --target all +``` + +If you get complaints about `TorchConfig.cmake` then your `CMAKE_PREFIX_PATH` doesn't have the right paths; run + +```shell +python -c 'import torch;print(torch.utils.cmake_prefix_path)' +``` + +to discover where `TorchConfig.cmake` lives (and add that path to your `CMAKE_PREFIX_PATH`). +In general, you can run +```shell +export CMAKE_PREFIX_PATH="$(dirname $(which conda))/../:$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')" +``` + +If you get complaints about `/dietgpu/third_party/glog... does not contain a CMakeLists.txt file.` then you didn't pull the submodules; run + +```shell +git submodule sync +git submodule update --init --recursive --jobs 0 +``` +and try again. + +## Library rationale + +As on-device global memory / HBM bandwidth continues to improve at a faster rate than CPU/GPU interconnect or server-to-server networking bandwidth, spending GPU compute and gmem bandwidth to save on data sent over interconnects is becoming more advantageous. DietGPU aims to target this gap. + +One can imagine a Pareto-optimal tradeoff curve between realizable compression ratios versus speed. On one end of the curve exists algorithms for supporting arbitrary data using dictionary/LZ type compression like some of the techniques in [Nvidia's nvCOMP](https://github.com/NVIDIA/nvcomp) at potentially high compression rates. At another end of the curve, one can imagine use completely on-device as something like a 1990s-style virtual RAM extender, where achievable compression is only 0.6x-0.9x or so, but compression can operate at around 1/4x to 1/2x the peak global memory bandwidth of the GPU. We emphasize the latter, where speed rather than compression ratio is important, where we can compress data that is even sent between GPUs in a single server over NVLink or PCIe. The savings may be low, but the effective network speed can be increased by 10-30%. For large-scale neural network training on hundreds of GPUs, this could translate into an additional 5-10% end-to-end performance increase. + +The initial focus of this library will be in HPC/ML distributed collective communications libraries, for primitives such as all-to-all, all-gather, reduce-scatter and all-reduce. Right now no off the shelf integration is provided (in progress), but the basics of the C++ API are available for use, as are Python-level PyTorch tensor-based APIs. + +## ANS codec + +The rANS codec operates on 8 bit bytes. It can compress arbitrary data, but using statistics gathered on a bytewise basis, so data highly structured or redundant at a level above byte level will typically not compress well. This codec however is meant to be applicable for any number of lossless compression applications, including usage as an entropy coder for LZ or RLE type matches for a fully-formed compression system. Symbol probability precisions supported are 9, 10 and 11 bits (i.e., symbol occurances are quantized to the nearest 1/512, 1/1024 or 1/2048). + +## Float codec + +The floating point compressor at the moment uses the rANS codec to handle compression of floating point exponents, as typically in ML/HPC data a very limited exponent dynamic range is used and is highly compressible. Floating point sign and significand values tend to be less compressible / fairly high entropy in practice, though sparse data or presence of functions like ReLU in neural networks can result in a lot of outright zero values which are very compressible. A future extension to the library will allow for specialized compression of sparse or semi-sparse data, specializing compression of zeros. At the moment only float16 (IEEE 754 binary16) and bfloat16 (fields of the most significant 16 bits of a IEEE 754 binary32 word) are supported, with float32 (IEEE 754 binary32) support coming shortly. + +## API design + +The basics of the design should work on CC 3.5+ (Kepler class) GPUs or later, though it has been primarily developed for and has only been tested on V100/A100 GPUs. + +Both APIs are available in both C++ (raw pointers) and Python/PyTorch (PyTorch tensor) API forms. It is a batch oriented API; both compression and decompression operate in batches of independent arrays of data which are independently compressed or decompressed, though with the floating point compressor, all arrays in the batch must be of the same data type. ANS compression symbol probabilities are calculated independently for each array in the batch, and each produced output compressed tensor in a batch is independently decompressible (and the ANS statistics are tailored to each individual array in the batch). See the wiki for details. + +The APIs are oriented around batching, though providing a large batch size of 1 also results in good performance (in fact, bs > 1 has somewhat worse performance than bs = 1 for sufficiently large data sizes at the moment, due to work imbalance issues). Arrays in the batch can be of arbitrary, varying sizes. The library treats all data as unstructured 1 dimensional arrays, so the PyTorch API does not really care about dimensionality. The primitive unit of compression are 4 KiB segments of the input data, which are assigned to individual warps. Typically, it is not worth using DietGPU unless one has at least 512 KiB of data or so due to compression overheads, and poor performance will be seen unless the total data size (whether bs = 1 or a large batch) is enough such that (total size in bytes / 4 KiB) is on par with the number of concurrently running warps that will saturate a GPUs SMs. + +All computation takes place completely on device. The design of the library pays special attention to avoiding memory allocations/deallocations and spurious device-to-host/host-to-device interactions and synchronizations where possible. Assuming inputs and outputs are properly sized and if enough temporary memory scratch space is provided up front, compression and decompression can run completely asynchronously on the GPU without CPU intervention. However, only the GPU during compression knows the actual final compressed size, and a typical application will need to copy the output size buffer containing the final compressed sizes per compression job in the batch in bytes back to the host for use in relocating compressed data elsewhere (in local memory or over the network), so we know how much data to send or copy. As the final output size cannot be predicted in advance, a function is provided to bound the maximum possible compressed output size (which is in fact larger than the input data size) which can be used to allocate an appropriate region of memory for the output. Realizing actual compression savings for applications other than networking would involve an additional memory allocation and memcpy to a new exactly sized buffer. + +## Performance + +Performance depends upon many factors, including entropy of the input data (higher entropy = more ANS stack memory operations = lower performance), number of SMs on the device and batch/data sizes. Here are some sample runs using an A100 GPU and the sync/alloc-free API on a batch size of 1 from the python PyTorch API, using `torch.normal(0, 1.0, [size], dtype=dt, ...)` to approximate a typical quasi-Gaussian data distribution as seen in real ML data. The float codec for bfloat16 extracts and compresses just the 8 bit exponent, while for float16 it currently operates on the most significant byte of the float word (containing the sign bit, 5 bits of exponent and 2 bits of significand). Typical ML float data might only have 2.7 bits of entropy in the exponent, so the savings ((8 + 2.7) / 16 ~= 0.67x for bfloat16, (11 + 2.7) / 16 ~= 0.85x for float16) is what is seen in the exponent-only strategy. + +![non-batch bfloat16 performance](images/dietgpu_bfloat16_nb.png) +![non-batch float16 performance](images/dietgpu_float16_nb.png) + +## Planned extensions + +- float32 support, possibly float64 support +- compression options to expect semi-sparse floating point data for higher compression (>10% zero values) +- a fused kernel implementation (likely using CUDA cooperative groups) to support single-kernel compression and decompression minimizing temporary memory usage +- a fused kernel implementation using the above to support persistent NCCL-like all-reduce for collective communications libraries +- CUB-like APIs for fusing warp-oriented ANS compression and decompression into arbitrary user kernels +- int32/int64 compression using a fixed-word size LZ-type window history +- support for embedding table compression with sparse reads/row gathers + +## References + +Prior GPU-based ANS implementations [to my knowledge](https://encode.su/threads/2078-List-of-Asymmetric-Numeral-Systems-implementations) include: + +- [GST: GPU-decodable Supercompressed Textures](https://gamma.cs.unc.edu/GST/) (not a generalized ANS codec; meant as part of a texture compression scheme) +- Weissenberger and Schmidt, [Massively Parallel ANS Decoding on GPUs](https://dl.acm.org/doi/10.1145/3337821.3337888) (a decoder only) + +Related GPU entropy coder works include: + +- Yamamoto et al., [Huffman Coding with Gap Arrays for GPU Acceleration](https://dl.acm.org/doi/10.1145/3404397.3404429) + +Related lossless floating point compression works include: + +- Lindstrom and Isenburg, [Fast and Efficient Compression of Floating-Point Data](https://computing.llnl.gov/projects/fpzip) (CPU-based) +- Various GPU works from Martin Burtscher's group at Texas State such as Yang et al., [MPC: A Massively Parallel Compression Algorithm for Scientific Data](https://www.semanticscholar.org/paper/MPC%3A-A-Massively-Parallel-Compression-Algorithm-for-Yang-Mukka/1ab6910c90ad714e29954ccd69d569eb2003eb20) + +These works are sometimes oriented at compressing HPC-type data (e.g., 2d/3d/Nd grid data) where there may be local/dimensional correlations that can be exploited. + +[nvCOMP](https://github.com/NVIDIA/nvcomp), Nvidia's GPU lossless compression library. + +## License + +DietGPU is licensed with the MIT license, available in the LICENSE file at the top level. diff --git a/thirdparty/dietgpu/dietgpu/CMakeLists.txt b/thirdparty/dietgpu/dietgpu/CMakeLists.txt new file mode 100644 index 000000000..57a98755e --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/CMakeLists.txt @@ -0,0 +1,17 @@ +find_package(Torch REQUIRED) + +add_library(dietgpu SHARED + DietGpu.cpp +) +add_dependencies(dietgpu + gpu_float_compress +) +target_link_libraries(dietgpu PRIVATE + gpu_float_compress + glog::glog + ${TORCH_LIBRARIES} +) +target_include_directories(dietgpu PRIVATE + $ + "${TORCH_INCLUDE_DIRS}" +) diff --git a/thirdparty/dietgpu/dietgpu/DietGpu.cpp b/thirdparty/dietgpu/dietgpu/DietGpu.cpp new file mode 100644 index 000000000..c283da2c3 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/DietGpu.cpp @@ -0,0 +1,971 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include "dietgpu/ans/GpuANSCodec.h" +#include "dietgpu/float/GpuFloatCodec.h" +#include "dietgpu/utils/StackDeviceMemory.h" + +namespace dietgpu { + +namespace { + +FloatType getFloatTypeFromDtype(at::ScalarType t) { + switch (t) { + case at::ScalarType::Half: + return FloatType::kFloat16; + case at::ScalarType::BFloat16: + return FloatType::kBFloat16; + case at::ScalarType::Float: + return FloatType::kFloat32; + default: + TORCH_CHECK( + t == at::ScalarType::Half || t == at::ScalarType::BFloat16 || + t == at::ScalarType::Float); + return FloatType::kUndefined; + } +} + +at::ScalarType getDtypeFromFloatType(FloatType ft) { + switch (ft) { + case FloatType::kFloat16: + return at::ScalarType::Half; + case FloatType::kBFloat16: + return at::ScalarType::BFloat16; + case FloatType::kFloat32: + return at::ScalarType::Float; + default: + TORCH_CHECK( + ft == FloatType::kFloat16 || ft == FloatType::kBFloat16 || + ft == FloatType::kFloat32); + return at::ScalarType::Half; + } +} + +FloatType getFloatTypeFromTensor(const torch::Tensor& t) { + return getFloatTypeFromDtype(t.dtype().toScalarType()); +} + +// returns (totalSize, maxSize) +std::tuple getTotalAndMaxSize( + const std::vector& tIns) { + int64_t totalSize = 0; + int64_t maxSize = 0; + + for (auto& t : tIns) { + auto curSize = t.numel(); + // FIXME: due to int indexing, it's really total size + TORCH_CHECK( + curSize * t.element_size() <= std::numeric_limits::max()); + + totalSize += curSize; + maxSize = std::max(maxSize, curSize); + } + + TORCH_CHECK(maxSize <= std::numeric_limits::max()); + + return std::make_tuple(totalSize, maxSize); +} + +// Convert a compressed matrix into a list of tensors that are views into the +// compressed row pieces +std::vector compressedMatrixToTensors( + int numInBatch, + torch::Tensor& matrix_dev, + torch::Tensor& sizes_dev) { + auto stream = at::cuda::getCurrentCUDAStream(); + + // We wish to return narrowed tensors with a view into the matrix + auto sizes_host = std::vector(numInBatch); + + CUDA_VERIFY(cudaMemcpyAsync( + sizes_host.data(), + sizes_dev.data_ptr(), + sizeof(uint32_t) * numInBatch, + cudaMemcpyDeviceToHost, + stream)); + + auto out = std::vector(numInBatch); + + auto matrix1d = matrix_dev.view({matrix_dev.numel()}); + auto cols = matrix_dev.size(1); + + for (int i = 0; i < numInBatch; ++i) { + out[i] = matrix1d.narrow(0, i * cols, sizes_host[i]); + } + + return out; +} + +} // namespace + +// +// External API +// + +constexpr int kDefaultPrecision = 10; + +std::tuple max_float_compressed_output_size( + const std::vector& ts) { + auto sizes = getTotalAndMaxSize(ts); + + auto maxCompSize = getMaxFloatCompressedSize( + getFloatTypeFromTensor(ts[0]), std::get<1>(sizes)); + + return std::make_tuple(ts.size(), maxCompSize); +} + +// FIXME: can we pass a dtype somehow instead? +int64_t max_float_compressed_size(const torch::Tensor& dtype, int64_t size) { + return getMaxFloatCompressedSize(getFloatTypeFromTensor(dtype), size); +} + +std::tuple max_any_compressed_output_size( + const std::vector& ts) { + auto sizes = getTotalAndMaxSize(ts); + int64_t maxBytes = std::get<1>(sizes) * ts[0].element_size(); + + return std::make_tuple(ts.size(), getMaxCompressedSize(maxBytes)); +} + +int64_t max_any_compressed_size(int64_t bytes) { + return getMaxCompressedSize(bytes); +} + +////////////////////// +// +// Compress +// +////////////////////// + +std::tuple compress_data_res( + bool compressAsFloat, + StackDeviceMemory& res, + const std::vector& tIns, + bool checksum, + const std::optional& outCompressed, + const std::optional& outCompressedSizes) { + TORCH_CHECK(!tIns.empty()); + + // All computation will take place on this device + int dev = tIns.front().get_device(); + DeviceScope device(dev); + + auto maxOutputSize = compressAsFloat ? max_float_compressed_output_size(tIns) + : max_any_compressed_output_size(tIns); + + // + // Validate input and validate / construct output + // + for (auto& t : tIns) { + TORCH_CHECK(t.device().type() == at::kCUDA); + TORCH_CHECK(t.is_contiguous()); + + // device must be consistent + TORCH_CHECK(t.get_device() == dev); + + // must be all the same type unless we are compressing bytewise + if (compressAsFloat) { + TORCH_CHECK(t.dtype() == tIns[0].dtype()); + + // must be a supported float type + TORCH_CHECK( + getFloatTypeFromDtype(t.dtype().toScalarType()) != + FloatType::kUndefined); + } + } + + torch::Tensor comp; + if (outCompressed) { + TORCH_CHECK(outCompressed->dtype() == torch::kByte); + TORCH_CHECK(outCompressed->device().type() == at::kCUDA); + TORCH_CHECK(outCompressed->is_contiguous()); + TORCH_CHECK(outCompressed->dim() == 2); + TORCH_CHECK(outCompressed->size(0) >= tIns.size()); + TORCH_CHECK(outCompressed->size(1) >= std::get<1>(maxOutputSize)); + TORCH_CHECK(outCompressed->get_device() == dev); + + comp = *outCompressed; + } else { + comp = torch::empty( + {(int64_t)tIns.size(), std::get<1>(maxOutputSize)}, + at::TensorOptions() + .device(tIns[0].device()) + .dtype(at::ScalarType::Byte)); + } + + auto inPtrs = std::vector(tIns.size()); + auto inSize = std::vector(tIns.size()); + auto compPtrs = std::vector(tIns.size()); + + for (size_t i = 0; i < tIns.size(); ++i) { + auto& t = tIns[i]; + + inPtrs[i] = t.data_ptr(); + inSize[i] = compressAsFloat ? t.numel() : (t.numel() * t.element_size()); + compPtrs[i] = (uint8_t*)comp.data_ptr() + i * comp.size(1); + } + + // + // Validate / construct output sizes + // + torch::Tensor sizes; + if (outCompressedSizes) { + TORCH_CHECK(outCompressedSizes->dtype() == torch::kInt); + TORCH_CHECK(outCompressedSizes->device().type() == at::kCUDA); + TORCH_CHECK(outCompressedSizes->dim() == 1); + TORCH_CHECK(outCompressedSizes->is_contiguous()); + TORCH_CHECK(outCompressedSizes->size(0) >= tIns.size()); + TORCH_CHECK(outCompressedSizes->get_device() == dev); + + sizes = *outCompressedSizes; + } else { + // FIXME: no uint32 in torch + sizes = torch::empty( + {(int64_t)tIns.size()}, + at::TensorOptions() + .device(tIns[0].device()) + .dtype(at::ScalarType::Int)); + } + + if (compressAsFloat) { + auto config = FloatCompressConfig( + getFloatTypeFromTensor(tIns[0]), + ANSCodecConfig(kDefaultPrecision, false), + false /* we'll figure this out later */, + checksum); + + floatCompress( + res, + config, + tIns.size(), + inPtrs.data(), + inSize.data(), + compPtrs.data(), + // FIXME: int32_t versus uint32_t + (uint32_t*)sizes.data_ptr(), + at::cuda::getCurrentCUDAStream()); + } else { + auto config = ANSCodecConfig(kDefaultPrecision, checksum); + + ansEncodeBatchPointer( + res, + config, + tIns.size(), + inPtrs.data(), + inSize.data(), + nullptr, + compPtrs.data(), + // FIXME: int32_t versus uint32_t + (uint32_t*)sizes.data_ptr(), + at::cuda::getCurrentCUDAStream()); + } + + // how much temporary memory we actually used + int64_t tempMemUsage = res.getMaxMemoryUsage(); + return std::make_tuple(std::move(comp), std::move(sizes), tempMemUsage); +} + +std::tuple compress_data( + bool compressAsFloat, + const std::vector& tIns, + bool checksum, + const std::optional& tempMem, + const std::optional& outCompressed, + const std::optional& outCompressedSizes) { + TORCH_CHECK(!tIns.empty()); + + // All computation will take place on this device; set before creating the + // GpuResources object + int dev = tIns.front().get_device(); + DeviceScope device(dev); + + // Validate temp memory if passed + if (tempMem) { + TORCH_CHECK(tempMem->device().type() == at::kCUDA); + TORCH_CHECK(tempMem->is_contiguous()); + + // Should be on the same device as the first tensor passed + TORCH_CHECK(tempMem->get_device() == tIns.front().get_device()); + } + + auto res = StackDeviceMemory( + getCurrentDevice(), + tempMem ? tempMem->data_ptr() : nullptr, + tempMem ? tempMem->numel() * tempMem->element_size() : 0); + + // The rest of the validation takes place here + return compress_data_res( + compressAsFloat, res, tIns, checksum, outCompressed, outCompressedSizes); +} + +std::tuple, torch::Tensor, int64_t> +compress_data_split_size( + bool compressAsFloat, + const torch::Tensor& tIn, + const torch::Tensor& tSplitSizes, + bool checksum, + const std::optional& tempMem, + const std::optional& outCompressed, + const std::optional& outCompressedSizes) { + // All computation will take place on this device; set before creating the + // GpuResources object + int dev = tIn.get_device(); + DeviceScope device(dev); + auto stream = at::cuda::getCurrentCUDAStream(); + + // Validate temp memory if passed + if (tempMem) { + TORCH_CHECK(tempMem->device().type() == at::kCUDA); + TORCH_CHECK(tempMem->is_contiguous()); + + // Should be on the same device as the first tensor passed + TORCH_CHECK(tempMem->get_device() == dev); + } + + // Validate input + auto floatType = compressAsFloat + ? getFloatTypeFromDtype(tIn.dtype().toScalarType()) + : FloatType::kUndefined; + + TORCH_CHECK(tIn.device().type() == at::kCUDA); + TORCH_CHECK(tIn.is_contiguous()); + TORCH_CHECK(tIn.get_device() == dev); + if (compressAsFloat) { + TORCH_CHECK(floatType != FloatType::kUndefined); + } else { + // All input data must meet ANS alignment + TORCH_CHECK( + uintptr_t(tIn.data_ptr()) % kANSRequiredAlignment == 0, + "All splits should start on a 16 byte boundary; " + "start pointer is not aligned"); + } + + // Validate split sizes + auto numInBatch = tSplitSizes.numel(); + TORCH_CHECK(tSplitSizes.is_contiguous()); + TORCH_CHECK(tSplitSizes.device().type() == at::kCPU); + TORCH_CHECK(tSplitSizes.dtype() == torch::kInt); + + uint32_t maxSize = 0; + for (size_t i = 0; i < numInBatch; ++i) { + auto size = ((const int32_t*)tSplitSizes.data_ptr())[i]; + TORCH_CHECK(size > 0); + maxSize = std::max((uint32_t)size, maxSize); + + if (!compressAsFloat && i != (numInBatch - 1)) { + // All input data starts for direct ANS compression must meet ANS + // alignment + TORCH_CHECK( + size % kANSRequiredAlignment == 0, + "All splits should start on a 16 byte boundary; the size of an interior " + "split is not a multiple of 16 bytes"); + } + } + + auto maxCompressedBytes = compressAsFloat + ? getMaxFloatCompressedSize(floatType, maxSize) + : getMaxCompressedSize(maxSize); + + // Validate output + torch::Tensor comp; + if (outCompressed) { + TORCH_CHECK(outCompressed->dtype() == torch::kByte); + TORCH_CHECK(outCompressed->device().type() == at::kCUDA); + TORCH_CHECK(outCompressed->is_contiguous()); + TORCH_CHECK(outCompressed->dim() == 2); + TORCH_CHECK(outCompressed->size(0) >= numInBatch); + TORCH_CHECK(outCompressed->size(1) >= maxCompressedBytes); + TORCH_CHECK(outCompressed->get_device() == dev); + + comp = *outCompressed; + } else { + comp = torch::empty( + {(int64_t)numInBatch, maxCompressedBytes}, + at::TensorOptions().device(tIn.device()).dtype(at::ScalarType::Byte)); + } + + torch::Tensor sizes; + if (outCompressedSizes) { + TORCH_CHECK(outCompressedSizes->dtype() == torch::kInt); + TORCH_CHECK(outCompressedSizes->device().type() == at::kCUDA); + TORCH_CHECK(outCompressedSizes->dim() == 1); + TORCH_CHECK(outCompressedSizes->is_contiguous()); + TORCH_CHECK(outCompressedSizes->size(0) >= numInBatch); + TORCH_CHECK(outCompressedSizes->get_device() == dev); + + sizes = *outCompressedSizes; + } else { + // FIXME: no uint32 in torch + sizes = torch::empty( + {(int64_t)numInBatch}, + at::TensorOptions().device(tIn.device()).dtype(at::ScalarType::Int)); + } + + auto res = StackDeviceMemory( + getCurrentDevice(), + tempMem ? tempMem->data_ptr() : nullptr, + tempMem ? tempMem->numel() * tempMem->element_size() : 0); + + if (compressAsFloat) { + auto config = FloatCompressConfig( + floatType, + ANSCodecConfig(kDefaultPrecision, false), + false /* we'll figure this out later */, + checksum); + + floatCompressSplitSize( + res, + config, + numInBatch, + tIn.data_ptr(), + // FIXME: int32_t versus uint32_t + (const uint32_t*)tSplitSizes.data_ptr(), + comp.data_ptr(), + maxCompressedBytes, + // FIXME: int32_t versus uint32_t + (uint32_t*)sizes.data_ptr(), + stream); + } else { + auto config = ANSCodecConfig(kDefaultPrecision, checksum); + + ansEncodeBatchSplitSize( + res, + config, + numInBatch, + tIn.data_ptr(), + // FIXME: int32_t versus uint32_t + (const uint32_t*)tSplitSizes.data_ptr(), + nullptr, + comp.data_ptr(), + maxCompressedBytes, + // FIXME: int32_t versus uint32_t + (uint32_t*)sizes.data_ptr(), + stream); + } + + auto compList = compressedMatrixToTensors(numInBatch, comp, sizes); + + // how much temporary memory we actually used + int64_t tempMemUsage = res.getMaxMemoryUsage(); + return std::make_tuple(std::move(compList), std::move(sizes), tempMemUsage); +} + +std::vector compress_data_simple( + bool compressAsFloat, + const std::vector& tIns, + bool checksum, + const std::optional& tempMem) { + TORCH_CHECK(!tIns.empty()); + + std::tuple comp; + + if (tempMem && *tempMem > 0) { + torch::Tensor scratch = torch::empty( + {*tempMem}, + at::TensorOptions() + .device(tIns[0].device()) + .dtype(at::ScalarType::Byte)); + + // rest of validation takes place here + comp = compress_data( + compressAsFloat, tIns, checksum, scratch, std::nullopt, std::nullopt); + } else { + // rest of validation takes place here + comp = compress_data( + compressAsFloat, + tIns, + checksum, + std::nullopt, + std::nullopt, + std::nullopt); + } + + auto& compMatrix_dev = std::get<0>(comp); + auto& size_dev = std::get<1>(comp); + + torch::Tensor size_host = size_dev.to(torch::kCPU); + TORCH_CHECK(size_host.size(0) == tIns.size()); + + auto compMatrixRowStride = compMatrix_dev.size(1); + auto stream = at::cuda::getCurrentCUDAStream(); + + auto out = std::vector(); + for (int i = 0; i < tIns.size(); ++i) { + auto compSize = ((int32_t*)size_host.data_ptr())[i]; + + out.emplace_back(torch::empty( + {compSize}, + at::TensorOptions() + .device(tIns[0].device()) + .dtype(at::ScalarType::Byte))); + + // FIXME: custom batch kernel to avoid N cudaMemcpy calls? + CUDA_VERIFY(cudaMemcpyAsync( + out[i].data_ptr(), + (uint8_t*)compMatrix_dev.data_ptr() + i * compMatrixRowStride, + compSize, + cudaMemcpyDeviceToDevice, + stream)); + } + + return out; +} + +////////////////////// +// +// Decompress +// +////////////////////// + +int64_t decompress_data_res( + bool compressAsFloat, + StackDeviceMemory& res, + const std::vector& tIns, + const std::vector& tOuts, + bool checksum, + const std::optional& outStatus, + const std::optional& outSizes) { + TORCH_CHECK(!tIns.empty()); + TORCH_CHECK(tIns.size() == tOuts.size()); + + // All computation will take place on this device + int dev = tIns.front().get_device(); + DeviceScope device(dev); + + // Validate input and output + auto inPtrs = std::vector(tIns.size()); + auto outPtrs = std::vector(tIns.size()); + auto outCapacity = std::vector(tOuts.size()); + + for (size_t i = 0; i < tIns.size(); ++i) { + auto& tIn = tIns[i]; + auto& tOut = tOuts[i]; + + TORCH_CHECK(tIn.device().type() == at::kCUDA); + TORCH_CHECK(tIn.get_device() == dev); + TORCH_CHECK(tIn.is_contiguous()); + + TORCH_CHECK(tOut.device().type() == at::kCUDA); + TORCH_CHECK(tOut.get_device() == dev); + TORCH_CHECK(tOut.is_contiguous()); + + TORCH_CHECK(tIn.dtype() == torch::kByte); + if (compressAsFloat) { + TORCH_CHECK( + tOut.dtype() == torch::kFloat16 || tOut.dtype() == torch::kBFloat16 || + tOut.dtype() == torch::kFloat32); + } + + inPtrs[i] = tIn.data_ptr(); + outPtrs[i] = tOut.data_ptr(); + + auto outSize = + compressAsFloat ? tOut.numel() : (tOut.numel() * tOut.element_size()); + + // FIXME: total range checking + TORCH_CHECK(outSize <= std::numeric_limits::max()); + outCapacity[i] = outSize; + } + + // Validate outStatus, if passed + if (outStatus) { + TORCH_CHECK(outStatus->is_contiguous()); + TORCH_CHECK(outStatus->device().type() == at::kCUDA); + TORCH_CHECK(outStatus->dtype() == torch::kByte); + TORCH_CHECK(outStatus->numel() == tIns.size()); + TORCH_CHECK(outStatus->get_device() == dev); + } + + // Validate outSizes, if passed + if (outSizes) { + TORCH_CHECK(outSizes->is_contiguous()); + TORCH_CHECK(outSizes->device().type() == at::kCUDA); + TORCH_CHECK(outSizes->dtype() == torch::kInt32); + TORCH_CHECK(outSizes->numel() == tIns.size()); + TORCH_CHECK(outSizes->get_device() == dev); + } + + if (compressAsFloat) { + auto config = FloatDecompressConfig( + getFloatTypeFromTensor(tOuts[0]), + ANSCodecConfig(kDefaultPrecision, false), + false /* we'll figure this out later */, + checksum); + + auto decStatus = floatDecompress( + res, + config, + tIns.size(), + inPtrs.data(), + outPtrs.data(), + outCapacity.data(), + outStatus ? (uint8_t*)outStatus->data_ptr() : nullptr, + // FIXME: int32_t versus uint32_t + outSizes ? (uint32_t*)outSizes->data_ptr() : nullptr, + at::cuda::getCurrentCUDAStream()); + + TORCH_CHECK( + decStatus.error != FloatDecompressError::ChecksumMismatch, + "floatDecompress: checksum mismatch seen on decoded data; " + "archive cannot be unpacked"); + } else { + auto config = ANSCodecConfig(kDefaultPrecision, checksum); + + auto decStatus = ansDecodeBatchPointer( + res, + config, + tIns.size(), + inPtrs.data(), + outPtrs.data(), + outCapacity.data(), + outStatus ? (uint8_t*)outStatus->data_ptr() : nullptr, + // FIXME: int32_t versus uint32_t + outSizes ? (uint32_t*)outSizes->data_ptr() : nullptr, + at::cuda::getCurrentCUDAStream()); + + TORCH_CHECK( + decStatus.error != ANSDecodeError::ChecksumMismatch, + "ANSDecode: checksum mismatch seen on decoded data; " + "archive cannot be unpacked"); + } + + // how much temporary memory we actually used + return res.getMaxMemoryUsage(); +} + +int64_t decompress_data( + bool compressAsFloat, + const std::vector& tIns, + const std::vector& tOuts, + bool checksum, + const std::optional& tempMem, + const std::optional& outStatus, + const std::optional& outSizes) { + TORCH_CHECK(!tIns.empty()); + + // All computation will take place on this device; set before creating the + // GpuResources object + int dev = tIns.front().get_device(); + DeviceScope device(dev); + + // Validate temp memory if passed + if (tempMem) { + TORCH_CHECK(tempMem->device().type() == at::kCUDA); + TORCH_CHECK(tempMem->is_contiguous()); + TORCH_CHECK(tempMem->get_device() == tIns.front().get_device()); + // we don't care about data type, we just care about memory + } + + auto res = StackDeviceMemory( + getCurrentDevice(), + tempMem ? tempMem->data_ptr() : nullptr, + tempMem ? tempMem->numel() * tempMem->element_size() : 0); + + // Rest of validation happens here + return decompress_data_res( + compressAsFloat, res, tIns, tOuts, checksum, outStatus, outSizes); +} + +int64_t decompress_data_split_size( + bool compressAsFloat, + const std::vector& tIns, + torch::Tensor& tOut, + const torch::Tensor& tSplitSizes, + bool checksum, + const std::optional& tempMem, + const std::optional& outStatus, + const std::optional& outSizes) { + TORCH_CHECK(!tIns.empty()); + + // All computation will take place on this device; set before creating the + // GpuResources object + int dev = tIns.front().get_device(); + DeviceScope device(dev); + + auto numInBatch = tSplitSizes.numel(); + + // Validate temp memory if passed + if (tempMem) { + TORCH_CHECK(tempMem->device().type() == at::kCUDA); + TORCH_CHECK(tempMem->is_contiguous()); + TORCH_CHECK(tempMem->get_device() == tIns.front().get_device()); + // we don't care about data type, we just care about memory + } + + // Validate input, split sizes and output + auto inPtrs = std::vector(tIns.size()); + auto splitSizes = std::vector(tIns.size()); + + // Validate split sizes + TORCH_CHECK(tSplitSizes.is_contiguous()); + TORCH_CHECK(tSplitSizes.device().type() == at::kCPU); + TORCH_CHECK(tSplitSizes.dtype() == torch::kInt); + // Should be a size for each of the input tensors + TORCH_CHECK(numInBatch == tIns.size()); + + for (size_t i = 0; i < numInBatch; ++i) { + auto& tIn = tIns[i]; + + TORCH_CHECK(tIn.device().type() == at::kCUDA); + TORCH_CHECK(tIn.get_device() == dev); + TORCH_CHECK(tIn.is_contiguous()); + + TORCH_CHECK(tIn.dtype() == torch::kByte); + + inPtrs[i] = tIn.data_ptr(); + + auto size = ((const int32_t*)tSplitSizes.data_ptr())[i]; + TORCH_CHECK(size > 0); + splitSizes[i] = size; + } + + // Validate output + TORCH_CHECK(tOut.device().type() == at::kCUDA); + TORCH_CHECK(tOut.get_device() == dev); + TORCH_CHECK(tOut.is_contiguous()); + if (compressAsFloat) { + TORCH_CHECK( + tOut.dtype() == torch::kFloat16 || tOut.dtype() == torch::kBFloat16 || + tOut.dtype() == torch::kFloat32); + } + + auto outSize = + compressAsFloat ? tOut.numel() : (tOut.numel() * tOut.element_size()); + + // FIXME: total range checking + TORCH_CHECK(outSize <= std::numeric_limits::max()); + + // Validate outStatus, if passed + if (outStatus) { + TORCH_CHECK(outStatus->is_contiguous()); + TORCH_CHECK(outStatus->device().type() == at::kCUDA); + TORCH_CHECK(outStatus->dtype() == torch::kByte); + TORCH_CHECK(outStatus->numel() == numInBatch); + TORCH_CHECK(outStatus->get_device() == dev); + } + + // Validate outSizes, if passed + if (outSizes) { + TORCH_CHECK(outSizes->is_contiguous()); + TORCH_CHECK(outSizes->device().type() == at::kCUDA); + TORCH_CHECK(outSizes->dtype() == torch::kInt32); + TORCH_CHECK(outSizes->numel() == numInBatch); + TORCH_CHECK(outSizes->get_device() == dev); + } + + auto stream = at::cuda::getCurrentCUDAStream(); + + auto res = StackDeviceMemory( + getCurrentDevice(), + tempMem ? tempMem->data_ptr() : nullptr, + tempMem ? tempMem->numel() * tempMem->element_size() : 0); + + if (compressAsFloat) { + auto config = FloatDecompressConfig( + getFloatTypeFromTensor(tOut), + ANSCodecConfig(kDefaultPrecision, false), + false /* we figure this out later */, + checksum); + + auto decStatus = floatDecompressSplitSize( + res, + config, + numInBatch, + (const void**)inPtrs.data(), + tOut.data_ptr(), + splitSizes.data(), + (uint8_t*)(outStatus ? outStatus->data_ptr() : nullptr), + // FIXME: int32_t vs uint32_t + (uint32_t*)(outSizes ? outSizes->data_ptr() : nullptr), + stream); + + TORCH_CHECK( + decStatus.error != FloatDecompressError::ChecksumMismatch, + "floatDecompress: checksum mismatch seen on decoded data; " + "archive cannot be unpacked"); + } else { + auto config = ANSCodecConfig(kDefaultPrecision, checksum); + + auto decStatus = ansDecodeBatchSplitSize( + res, + config, + numInBatch, + (const void**)inPtrs.data(), + tOut.data_ptr(), + splitSizes.data(), + (uint8_t*)(outStatus ? outStatus->data_ptr() : nullptr), + // FIXME: int32_t vs uint32_t + (uint32_t*)(outSizes ? outSizes->data_ptr() : nullptr), + stream); + + TORCH_CHECK( + decStatus.error != ANSDecodeError::ChecksumMismatch, + "ANSDecode: checksum mismatch seen on decoded data; " + "archive cannot be unpacked"); + } + + // how much temporary memory we actually used + return res.getMaxMemoryUsage(); +} + +std::vector decompress_data_simple( + bool compressAsFloat, + const std::vector& tIns, + bool checksum, + const std::optional& tempMem) { + TORCH_CHECK(!tIns.empty()); + auto stream = at::cuda::getCurrentCUDAStream(); + + // All computation will take place on this device + int dev = tIns.front().get_device(); + DeviceScope device(dev); + + size_t tempMemToUse = 0; + if (tempMem && *tempMem >= kSDMAlignment) { + tempMemToUse = *tempMem; + } + + torch::Tensor scratch; + if (tempMemToUse) { + scratch = torch::empty( + {(int64_t)tempMemToUse}, + at::TensorOptions() + .device(tIns[0].device()) + .dtype(at::ScalarType::Byte)); + } + + auto res = StackDeviceMemory( + getCurrentDevice(), + tempMemToUse ? scratch.data_ptr() : nullptr, + tempMemToUse); + + auto sizes_dev = res.alloc(stream, tIns.size()); + auto types_dev = res.alloc(stream, tIns.size()); + + auto inPtrs = std::vector(tIns.size()); + for (int i = 0; i < tIns.size(); ++i) { + auto& tIn = tIns[i]; + + inPtrs[i] = tIn.data_ptr(); + + TORCH_CHECK(tIn.device().type() == at::kCUDA); + TORCH_CHECK(tIn.get_device() == dev); + TORCH_CHECK(tIn.is_contiguous()); + } + + if (compressAsFloat) { + floatGetCompressedInfo( + res, + inPtrs.data(), + tIns.size(), + sizes_dev.data(), + types_dev.data(), + nullptr, + stream); + } else { + ansGetCompressedInfo( + res, inPtrs.data(), tIns.size(), sizes_dev.data(), nullptr, stream); + } + + auto sizes = sizes_dev.copyToHost(stream); + auto types = types_dev.copyToHost(stream); + + auto tOuts = std::vector(); + for (int i = 0; i < tIns.size(); ++i) { + auto size = sizes[i]; + auto type = types[i]; + + torch::Tensor tOut; + + if (compressAsFloat) { + TORCH_CHECK(type == types[0]); // must be consistent dtype + + tOut = torch::empty( + {static_cast(size)}, + at::TensorOptions() + .device(tIns[0].device()) + .dtype(getDtypeFromFloatType((FloatType)type))); + } else { + tOut = torch::empty( + {static_cast(size)}, + at::TensorOptions().device(tIns[0].device()).dtype(torch::kByte)); + } + + tOuts.emplace_back(std::move(tOut)); + } + + decompress_data_res( + compressAsFloat, res, tIns, tOuts, checksum, std::nullopt, std::nullopt); + + return tOuts; +} + +} // namespace dietgpu + +TORCH_LIBRARY_FRAGMENT(dietgpu, m) { + // compression sizes + m.def("max_float_compressed_output_size(Tensor[] ts) -> (int, int)"); + m.def("max_float_compressed_size(Tensor dtype, int size) -> int"); + m.def("max_any_compressed_output_size(Tensor[] ts) -> (int, int)"); + m.def("max_any_compressed_size(int bytes) -> int"); + + // data compress + m.def( + "compress_data(bool compress_as_float, Tensor[] ts_in, bool checksum=False, Tensor? temp_mem=None, Tensor? out_compressed=None, Tensor? out_compressed_bytes=None) -> (Tensor, Tensor, int)"); + m.def( + "compress_data_split_size(bool compress_as_float, Tensor t_in, Tensor t_in_split_sizes, bool checksum=False, Tensor? temp_mem=None, Tensor? out_compressed=None, Tensor? out_compressed_bytes=None) -> (Tensor[], Tensor, int)"); + m.def( + "compress_data_simple(bool compress_as_float, Tensor[] ts_in, bool checksum=False, int? temp_mem=67108864) -> Tensor[]"); + + // data decompress + m.def( + "decompress_data(bool compress_as_float, Tensor[] ts_in, Tensor[] ts_out, bool checksum=False, Tensor? temp_mem=None, Tensor? out_status=None, Tensor? out_decompressed_words=None) -> (int)"); + m.def( + "decompress_data_split_size(bool compress_as_float, Tensor[] ts_in, Tensor t_out, Tensor t_out_split_sizes, bool checksum=False, Tensor? temp_mem=None, Tensor? out_status=None, Tensor? out_decompressed_words=None) -> (int)"); + m.def( + "decompress_data_simple(bool compress_as_float, Tensor[] ts_in, bool checksum=False, int? temp_mem=67108864) -> Tensor[]"); +} + +TORCH_LIBRARY(dietgpu, m) { + m.impl( + TORCH_SELECTIVE_NAME("dietgpu::max_float_compressed_output_size"), + TORCH_FN(dietgpu::max_float_compressed_output_size)); + m.impl( + TORCH_SELECTIVE_NAME("dietgpu::max_float_compressed_size"), + TORCH_FN(dietgpu::max_float_compressed_size)); + m.impl( + TORCH_SELECTIVE_NAME("dietgpu::max_any_compressed_output_size"), + TORCH_FN(dietgpu::max_any_compressed_output_size)); + m.impl( + TORCH_SELECTIVE_NAME("dietgpu::max_any_compressed_size"), + TORCH_FN(dietgpu::max_any_compressed_size)); + + m.impl( + TORCH_SELECTIVE_NAME("dietgpu::compress_data"), + TORCH_FN(dietgpu::compress_data)); + m.impl( + TORCH_SELECTIVE_NAME("dietgpu::compress_data_split_size"), + TORCH_FN(dietgpu::compress_data_split_size)); + m.impl( + TORCH_SELECTIVE_NAME("dietgpu::compress_data_simple"), + TORCH_FN(dietgpu::compress_data_simple)); + + m.impl( + TORCH_SELECTIVE_NAME("dietgpu::decompress_data"), + TORCH_FN(dietgpu::decompress_data)); + m.impl( + TORCH_SELECTIVE_NAME("dietgpu::decompress_data_split_size"), + TORCH_FN(dietgpu::decompress_data_split_size)); + m.impl( + TORCH_SELECTIVE_NAME("dietgpu::decompress_data_simple"), + TORCH_FN(dietgpu::decompress_data_simple)); +} diff --git a/thirdparty/dietgpu/dietgpu/ans/BatchPrefixSum.cuh b/thirdparty/dietgpu/dietgpu/ans/BatchPrefixSum.cuh new file mode 100644 index 000000000..943640f07 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/ans/BatchPrefixSum.cuh @@ -0,0 +1,196 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include "dietgpu/utils/StaticUtils.h" + +namespace dietgpu { + +// FIXME: at some point, batchExclusivePrefixSum1 can no longer be run with +// 1024 threads. Restrict our max threads to 512 +constexpr int kMaxBEPSThreads = 512; + +// +// Quick and dirty means of performing a batched exclusive prefix sum in two +// passes :( +// + +template +struct NoTransform { + __host__ __device__ __forceinline__ T operator()(const T& v) const { + return v; + } +}; + +template +__global__ void batchExclusivePrefixSum1( + const T* __restrict__ in, + T* __restrict__ out, + void* __restrict__ blockTotal, + uint32_t batchSize, + TransformFn fn) { + uint32_t batch = blockIdx.y; + uint32_t block = blockIdx.x; + uint32_t blocksInBatch = gridDim.x; + uint32_t tid = threadIdx.x; + + int batchIdx = block * Threads + tid; + bool valid = batchIdx < batchSize; + + int totalIdx = batch * batchSize + batchIdx; + auto v = valid ? fn(in[totalIdx]) : T(0); + + using Scan = cub::BlockScan; + __shared__ typename Scan::TempStorage smem; + + T prefix = 0; + T total = 0; + Scan(smem).ExclusiveSum(v, prefix, total); + + if (valid) { + out[totalIdx] = prefix; + } + + // Only if this is not provided is 1 level of the tree enough + if (threadIdx.x == 0 && blockTotal) { + ((T*)blockTotal)[batch * blocksInBatch + block] = total; + } +} + +// Single block that performs the cross-block prefix sum +template +__global__ void batchExclusivePrefixSum2( + void* __restrict__ blockTotal, + uint32_t batchSize, + uint32_t blocksInBatch) { + uint32_t batch = blockIdx.x; + uint32_t tid = threadIdx.x; + + bool valid = tid < blocksInBatch; + auto v = valid ? ((T*)blockTotal)[batch * blocksInBatch + tid] : 0; + + using Scan = cub::BlockScan; + __shared__ typename Scan::TempStorage smem; + + Scan(smem).ExclusiveSum(v, v); + + if (valid) { + ((T*)blockTotal)[batch * blocksInBatch + tid] = v; + } +} + +template +__global__ void batchExclusivePrefixSum3( + T* __restrict__ out, + const void* __restrict__ blockTotal, + uint32_t batchSize) { + uint32_t batch = blockIdx.y; + uint32_t block = blockIdx.x; + uint32_t blocksInBatch = gridDim.x; + uint32_t tid = threadIdx.x; + + auto vBlock = ((const T*)blockTotal)[batch * blocksInBatch + block]; + + int batchIdx = block * Threads + tid; + bool valid = batchIdx < batchSize; + + int totalIdx = batch * batchSize + batchIdx; + + if (valid) { + out[totalIdx] += vBlock; + } +} + +inline size_t getBatchExclusivePrefixSumTempSize( + uint32_t numInBatch, + uint32_t batchSize) { + if (batchSize <= kMaxBEPSThreads) { + return 0; + } else { + // number of blocks required + return divUp(batchSize, kMaxBEPSThreads); + } +} + +// Perform a batched exclusive prefix sum, comprising +// numInBatch x batchSize data +template +void batchExclusivePrefixSum( + const T* in_dev, + T* out_dev, + void* temp_dev, + uint32_t numInBatch, + uint32_t batchSize, + const TransformFn& fn, + cudaStream_t stream) { + // maximum size we can handle with a two-level reduction + assert(batchSize <= kMaxBEPSThreads * kMaxBEPSThreads); + +#define BPS_LEVEL_1(THREADS, TEMP) \ + batchExclusivePrefixSum1 \ + <<>>( \ + in_dev, out_dev, TEMP, batchSize, fn) + +#define BPS_LEVEL_2(THREADS) \ + batchExclusivePrefixSum2 \ + <<>>(temp_dev, batchSize, blocks) + +#define BPS_LEVEL_3(THREADS) \ + batchExclusivePrefixSum3 \ + <<>>( \ + out_dev, temp_dev, batchSize) + + if (batchSize > kMaxBEPSThreads) { + // multi-level reduction required + uint32_t blocks = divUp(batchSize, kMaxBEPSThreads); + assert(blocks > 1); + assert(temp_dev); // must have this allocated + + BPS_LEVEL_1(kMaxBEPSThreads, temp_dev); + + if (blocks <= 32) { + BPS_LEVEL_2(32); + } else if (blocks <= 64) { + BPS_LEVEL_2(64); + } else if (blocks <= 128) { + BPS_LEVEL_2(128); + } else if (blocks <= 256) { + BPS_LEVEL_2(256); + } else { + assert(blocks <= kMaxBEPSThreads); + BPS_LEVEL_2(kMaxBEPSThreads); + } + + BPS_LEVEL_3(kMaxBEPSThreads); + } else { + // single-level reduction + uint32_t blocks = 1; + + if (batchSize <= 32) { + BPS_LEVEL_1(32, (T*)nullptr); + } else if (batchSize <= 64) { + BPS_LEVEL_1(64, (T*)nullptr); + } else if (batchSize <= 128) { + BPS_LEVEL_1(128, (T*)nullptr); + } else if (batchSize <= 256) { + BPS_LEVEL_1(256, (T*)nullptr); + } else { + assert(batchSize <= kMaxBEPSThreads); + BPS_LEVEL_1(kMaxBEPSThreads, (T*)nullptr); + } + } + +#undef BPS_LEVEL_3 +#undef BPS_LEVEL_2 +#undef BPS_LEVEL_1 +} + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/dietgpu/ans/BatchProvider.cuh b/thirdparty/dietgpu/dietgpu/ans/BatchProvider.cuh new file mode 100644 index 000000000..18e21d4a7 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/ans/BatchProvider.cuh @@ -0,0 +1,196 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include "dietgpu/ans/GpuANSUtils.cuh" + +namespace dietgpu { + +struct BatchWriter { + inline __device__ BatchWriter(void* out) + : out_((uint8_t*)out), outBlock_(nullptr) {} + + inline __device__ void setBlock(uint32_t block) { + outBlock_ = out_ + block * kDefaultBlockSize; + } + + inline __device__ void write(uint32_t offset, uint8_t sym) { + outBlock_[offset] = sym; + } + + // template + // inline __device__ void writeVec(uint32_t offset, Vec symV) { + // ((Vec*)outBlock_)[offset] = symV; + // } + + // __device__ void preload(uint32_t offset) {} + + uint8_t* out_; + uint8_t* outBlock_; +}; + +struct BatchProviderStride { + using Writer = BatchWriter; + + __host__ BatchProviderStride( + void* ptr_dev, + uint32_t batchStride, + uint32_t batchCapacity = 0) + : ptr_dev_(ptr_dev), + batchStride_(batchStride), + batchCapacity_(batchCapacity) {} + + __device__ void* getBatchStart(uint32_t batch) { + return ((uint8_t*)ptr_dev_) + batchStride_ * batch; + } + + __device__ const void* getBatchStart(uint32_t batch) const { + return ((uint8_t*)ptr_dev_) + batchStride_ * batch; + } + + __device__ BatchWriter getWriter(uint32_t batch) { + return BatchWriter(getBatchStart(batch)); + } + + __device__ uint32_t getBatchSize(uint32_t batch) { + return batchCapacity_; + } + + void* ptr_dev_; + uint32_t batchStride_; + uint32_t batchCapacity_; +}; + +struct BatchProviderSplitSize { + using Writer = BatchWriter; + + __host__ BatchProviderSplitSize( + void* ptr_dev, + const uint32_t* splitSize_dev, + // Exclusive prefix sum of splitSize_dev + const uint32_t* splitSizePrefix_dev, + uint32_t wordSize) + : ptr_dev_(ptr_dev), + splitSize_dev_(splitSize_dev), + splitSizePrefix_dev_(splitSizePrefix_dev), + wordSize_(wordSize) {} + + __device__ void* getBatchStart(uint32_t batch) { + return ((uint8_t*)ptr_dev_) + splitSizePrefix_dev_[batch] * wordSize_; + } + + __device__ const void* getBatchStart(uint32_t batch) const { + return ((uint8_t*)ptr_dev_) + splitSizePrefix_dev_[batch] * wordSize_; + } + + __device__ BatchWriter getWriter(uint32_t batch) { + return BatchWriter(getBatchStart(batch)); + } + + __device__ uint32_t getBatchSize(uint32_t batch) { + return splitSize_dev_[batch]; + } + + void* ptr_dev_; + const uint32_t* splitSize_dev_; + const uint32_t* splitSizePrefix_dev_; + uint32_t wordSize_; +}; + +struct BatchProviderPointer { + using Writer = BatchWriter; + + __host__ BatchProviderPointer( + void** ptr_dev, + const uint32_t* capacity_dev = nullptr) + : ptr_dev_(ptr_dev), capacity_dev_(capacity_dev) {} + + __device__ void* getBatchStart(uint32_t batch) { + return ptr_dev_[batch]; + } + + __device__ const void* getBatchStart(uint32_t batch) const { + return ptr_dev_[batch]; + } + + __device__ BatchWriter getWriter(uint32_t batch) { + return BatchWriter(getBatchStart(batch)); + } + + __device__ uint32_t getBatchSize(uint32_t batch) { + return capacity_dev_[batch]; + } + + void** ptr_dev_; + const uint32_t* capacity_dev_; +}; + +template +struct BatchProviderInlinePointer { + using Writer = BatchWriter; + + __host__ BatchProviderInlinePointer(int num, void** ptr_host) { + CHECK_LE(num, N); + for (int i = 0; i < num; ++i) { + ptr_dev_[i] = ptr_host[i]; + } + } + + __device__ void* getBatchStart(uint32_t batch) { + return ptr_dev_[batch]; + } + + __device__ const void* getBatchStart(uint32_t batch) const { + return ptr_dev_[batch]; + } + + __device__ BatchWriter getWriter(uint32_t batch) { + return BatchWriter(getBatchStart(batch)); + } + + void* ptr_dev_[N]; +}; + +template +struct BatchProviderInlinePointerCapacity { + using Writer = BatchWriter; + + __host__ BatchProviderInlinePointerCapacity( + int num, + void** ptr_host, + const uint32_t* capacity_host) { + CHECK_LE(num, N); + for (int i = 0; i < num; ++i) { + ptr_dev_[i] = ptr_host[i]; + capacity_dev_[i] = capacity_host[i]; + } + } + + __device__ void* getBatchStart(uint32_t batch) { + return ptr_dev_[batch]; + } + + __device__ const void* getBatchStart(uint32_t batch) const { + return ptr_dev_[batch]; + } + + __device__ BatchWriter getWriter(uint32_t batch) { + return BatchWriter(getBatchStart(batch)); + } + + __device__ uint32_t getBatchSize(uint32_t batch) { + return capacity_dev_[batch]; + } + + void* ptr_dev_[N]; + uint32_t capacity_dev_[N]; +}; + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/dietgpu/ans/CMakeLists.txt b/thirdparty/dietgpu/dietgpu/ans/CMakeLists.txt new file mode 100644 index 000000000..1f2e167a2 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/ans/CMakeLists.txt @@ -0,0 +1,52 @@ +add_library(gpu_ans SHARED + GpuANSDecode.cu + GpuANSEncode.cu + GpuANSInfo.cu +) +add_dependencies(gpu_ans + dietgpu_utils +) +target_include_directories(gpu_ans PUBLIC + $ +) +target_link_libraries(gpu_ans PUBLIC + dietgpu_utils +) +target_link_libraries(gpu_ans PRIVATE + glog::glog +) +target_compile_options(gpu_ans PRIVATE $<$: + --generate-line-info + #--device-debug +>) + + +enable_testing() +include(GoogleTest) + +add_executable(ans_test ANSTest.cu) +target_link_libraries(ans_test + gpu_ans + gtest_main +) +gtest_discover_tests(ans_test) + +add_executable(ans_statistics_test ANSStatisticsTest.cu) +target_link_libraries(ans_statistics_test + gpu_ans + gtest_main + dietgpu_utils +) +gtest_discover_tests(ans_statistics_test) + +add_executable(batch_prefix_sum_test BatchPrefixSumTest.cu) +target_link_libraries(batch_prefix_sum_test + gpu_ans + gtest_main +) +gtest_discover_tests(batch_prefix_sum_test) + +get_property(GLOBAL_CUDA_ARCHITECTURES GLOBAL PROPERTY CUDA_ARCHITECTURES) +set_target_properties(gpu_ans ans_test ans_statistics_test batch_prefix_sum_test + PROPERTIES CUDA_ARCHITECTURES "${GLOBAL_CUDA_ARCHITECTURES}" +) diff --git a/thirdparty/dietgpu/dietgpu/ans/GpuANSCodec.h b/thirdparty/dietgpu/dietgpu/ans/GpuANSCodec.h new file mode 100644 index 000000000..24ad34dc4 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/ans/GpuANSCodec.h @@ -0,0 +1,343 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include "dietgpu/utils/StackDeviceMemory.h" + +namespace dietgpu { + +// Required minimum alignment in bytes of all data to be compressed in the batch +constexpr int kANSRequiredAlignment = 4; + +// Default number of probability quantization bits to use, if an alternative is +// not specified +constexpr int kANSDefaultProbBits = 10; + +uint32_t getMaxCompressedSize(uint32_t uncompressedBytes); + +struct ANSCodecConfig { + inline ANSCodecConfig() : probBits(kANSDefaultProbBits), useChecksum(false) {} + + explicit inline ANSCodecConfig(int pb, bool checksum = false) + : probBits(pb), useChecksum(checksum) {} + + // What the ANS probability accuracy is; all symbols have quantized + // probabilities of 1/2^probBits. + // 9, 10, 11 are only valid values. When in doubt, use 10 (e.g., all symbol + // probabilities are one of {1/1024, 2/1024, ..., 1023/1024, 1024/1024}) + int probBits; + + // If true, we calculate a checksum on the uncompressed input data to + // compression and store it in the archive, and on the decompression side + // post-decompression, we calculate a checksum on the decompressed data which + // is compared with the original stored in the archive. + // This is an optional feature useful if DietGPU data will be stored + // persistently on disk. + bool useChecksum; +}; + +enum class ANSDecodeError : uint32_t { + None = 0, + ChecksumMismatch = 1, +}; + +// Error status for decompression +struct ANSDecodeStatus { + inline ANSDecodeStatus() : error(ANSDecodeError::None) {} + + // Overall error status + ANSDecodeError error; + + // Error-specific information for the batch + std::vector> errorInfo; +}; + +// +// Encode +// + +void ansEncodeBatchStride( + StackDeviceMemory& res, + // Compression configuration + const ANSCodecConfig& config, + + // Number of separate, independent compression problems + uint32_t numInBatch, + + // Region in device memory of size at least + // numInBatch * inPerBatchSize + max(numInBatch - 1, 0) * inPerBatchStride + const void* in_dev, + // Bytes per batch member for compression + uint32_t inPerBatchSize, + // Stride per separate input compression problem (must be >= inPerBatchSize) + uint32_t inPerBatchStride, + + // Optional (can be null): region in device memory of size 256 words + // containing pre-calculated symbol counts (histogram) of the data to be + // compressed + const uint32_t* histogram_dev, + + // Region in device memory of size at least + // numInBatch * getMaxCompressedSize(inPerBatchSize) + + // max(numInBatch - 1, 0) * outPerBatchStride + void* out_dev, + // Stride per separate output compression problem, which must be + // >= getMaxCompressedSize(inPerBatchSize) + uint32_t outPerBatchStride, + // Device memory array of size numInBatch (optional) + // Provides the size of actual used memory in each output compressed batch + uint32_t* outBatchSize_dev, + + // stream on the current device on which this runs + cudaStream_t stream); + +void ansEncodeBatchPointer( + StackDeviceMemory& res, + // Compression configuration + const ANSCodecConfig& config, + + // Number of separate, independent compression problems + uint32_t numInBatch, + + // Host array with addresses of device pointers comprising the input batch + // to compress + const void** in, + // Host array with sizes of batch members + const uint32_t* inSize, + + // Optional (can be null): region in device memory of size 256 words + // containing pre-calculated symbol counts (histogram) of the data to be + // compressed + const uint32_t* histogram_dev, + + // Host array with addresses of device pointers for the compressed output + // arrays. Each out[i] must be a region of memory of size at least + // getMaxCompressedSize(inSize[i]) + void** out, + // Device memory array of size numInBatch (optional) + // Provides the size of actual used memory in each output compressed batch + uint32_t* outSize_dev, + + // stream on the current device on which this runs + cudaStream_t stream); + +void ansEncodeBatchSplitSize( + StackDeviceMemory& res, + + // Compression configuration + const ANSCodecConfig& config, + + // Number of separate, independent compression problems + uint32_t numInBatch, + + // Device pointer into a valid region of memory of size at least + // sum_i(inSplitSizes[i]) bytes + const void* in_dev, + + // Host array with the size (in bytes) of the input arrays in the batch. + // Each array in the batch is read starting at offset inSplitSizes[i]. + const uint32_t* inSplitSizes, + + // Optional (can be null): region in device memory of size 256 words + // containing pre-calculated symbol counts (histogram) of the data to be + // compressed + const uint32_t* histogram_dev, + + // Device pointer to a matrix of at least size + // numInBatch x getMaxCompressedSize(max(inSplitSizes[i])) + void* out_dev, + + // Stride between rows in bytes + uint32_t outStride, + + // Device memory array of size numInBatch (optional) + // Provides the size of actual used memory in each output compressed batch + uint32_t* outSize_dev, + + // stream on the current device on which this runs + cudaStream_t stream); + +// +// Decode +// + +ANSDecodeStatus ansDecodeBatchStride( + StackDeviceMemory& res, + + // Expected compression configuration (we verify this upon decompression) + const ANSCodecConfig& config, + + // Number of separate, independent decompression problems + uint32_t numInBatch, + + // start of compressed input data (device pointer) + const void* in_dev, + // stride in in_dev between separate compressed inputs; e.g., the regions of + // memory located at the following byte offset ranges from in_dev contain + // the input per-batch element compressed data: + // + // [b * inPerBatchStride, b * inPerBatchStride + compressed_size[b] - 1] + // where compressed_size[b] is the compressed size indicated in the header + // metadata for each per-batch compressed data. + // + // The kernel will not access memory beyond the per-batch member compressed + // size in each of these regions, thus the stride should be at least the + // maximum of all of the individual per-batch compressed sizes. + // If the stride is not sufficient, then the kernel may segfault. + uint32_t inPerBatchStride, + + // start of decompressed output data (device pointer) + void* out_dev, + // Stride between each decompressed output, which must be greater than the + // uncompressed size for each decompressed output and outPerBatchCapacity. + uint32_t outPerBatchStride, + // Overall space available for each decompression batch member; e.g., the + // regions of memory located at the following byte offset ranges from + // out_dev: + // + // [b * outPerBatchStride, b * outPerBatchStride + outPerBatchCapacity - 1] + // + // for all b \in [0, numInBatch - 1] are valid. + // If the seen decompressed size for any individual batch member is less + // than outBatchCapacity, that particular batch member will fail to + // decompress, and the reported size for that batch member will be in + // status_dev. + uint32_t outPerBatchCapacity, + + // Decode success/fail status (optional, can be nullptr) + // If present, this is a device pointer to an array of length numInBatch, + // with true/false for whether or not decompression status was successful + // FIXME: not bool due to issues with __nv_bool + uint8_t* outSuccess_dev, + + // Decode size status (optional, can be nullptr) + // If present, this is a device pointer to an array of length numInBatch, + // with either the size decompressed reported if successful, or the required + // size reported if our outPerBatchCapacity was insufficient + uint32_t* outSize_dev, + + // stream on the current device on which this runs + cudaStream_t stream); + +ANSDecodeStatus ansDecodeBatchPointer( + StackDeviceMemory& res, + + // Expected compression configuration (we verify this upon decompression) + const ANSCodecConfig& config, + + // Number of separate, independent decompression problems + uint32_t numInBatch, + + // Host array with addresses of device pointers corresponding to compressed + // inputs + const void** in, + + // Host array with addresses of device pointers corresponding to + // uncompressed outputs + void** out, + + // Host array with size of memory regions provided in out; if the seen + // decompressed size is greater than this, then there will be an error in + // decompression + const uint32_t* outCapacity, + + // Decode success/fail status (optional, can be nullptr) + // If present, this is a device pointer to an array of length numInBatch, + // with true/false for whether or not decompression status was successful + // FIXME: not bool due to issues with __nv_bool + uint8_t* outSuccess_dev, + + // Decode size status (optional, can be nullptr) + // If present, this is a device pointer to an array of length numInBatch, + // with either the size decompressed reported if successful, or the required + // size reported if our outPerBatchCapacity was insufficient + uint32_t* outSize_dev, + + // stream on the current device on which this runs + cudaStream_t stream); + +ANSDecodeStatus ansDecodeBatchSplitSize( + StackDeviceMemory& res, + + // Expected compression configuration (we verify this upon decompression) + const ANSCodecConfig& config, + + // Number of separate, independent compression problems + uint32_t numInBatch, + + // Host array with addresses of device pointers comprising the batch + const void** in, + + // Device pointer into a valid region of memory of size at least + // sum_i(outSplitSizes[i]) bytes + void* out_dev, + + // Host array with the size (in bytes) of the output + // decompressed arrays in the batch. + // Each decompressed array in the batch is written at offset + // outSplitSizes[i]. + // The decompressed size must match exactly these sizes, otherwise there's a + // decompression error + const uint32_t* outSplitSizes, + + // Decode success/fail status (optional, can be nullptr) + // If present, this is a device pointer to an array of length numInBatch, + // with true/false for whether or not decompression status was successful + // FIXME: not bool due to issues with __nv_bool + uint8_t* outSuccess_dev, + + // Decode size status (optional, can be nullptr) + // If present, this is a device pointer to an array of length numInBatch, + // with either the size decompressed reported if successful, or the required + // size reported if our outPerBatchCapacity was insufficient. Size reported + // is in float words + uint32_t* outSize_dev, + + // stream on the current device on which this runs + cudaStream_t stream); + +// +// Information +// + +void ansGetCompressedInfo( + StackDeviceMemory& res, + // Host array with addresses of device pointers comprising the batch of + // compressed ANS data + const void** in, + // Number of compressed arrays in the batch + uint32_t numInBatch, + // Optional device array to receive the resulting sizes. 0 is reported if + // the compresed data is not as expected, otherwise the size is reported in + // bytes + uint32_t* outSizes_dev, + // Optional device array to receive pre-compression checksums stored in the + // archive, if the checksum feature was enabled. + uint32_t* outChecksum_dev, + // stream on the current device on which this runs + cudaStream_t stream); + +void ansGetCompressedInfoDevice( + StackDeviceMemory& res, + // Device array with addresses of device pointers comprising the batch of + // compressed ANS data + const void** in_dev, + // Number of compressed arrays in the batch + uint32_t numInBatch, + // Optional device array to receive the resulting sizes. 0 is reported if + // the compresed data is not as expected, otherwise the size is reported in + // bytes + uint32_t* outSizes_dev, + // Optional device array to receive pre-compression checksums stored in the + // archive, if the checksum feature was enabled. + uint32_t* outChecksum_dev, + // stream on the current device on which this runs + cudaStream_t stream); + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/dietgpu/ans/GpuANSDecode.cu b/thirdparty/dietgpu/dietgpu/ans/GpuANSDecode.cu new file mode 100644 index 000000000..45cbe6ba9 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/ans/GpuANSDecode.cu @@ -0,0 +1,195 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "dietgpu/ans/GpuANSCodec.h" +#include "dietgpu/ans/GpuANSDecode.cuh" +#include "dietgpu/ans/GpuANSUtils.cuh" +#include "dietgpu/utils/StackDeviceMemory.h" + +#include +#include +#include +#include + +namespace dietgpu { + +ANSDecodeStatus ansDecodeBatchStride( + StackDeviceMemory& res, + const ANSCodecConfig& config, + uint32_t numInBatch, + const void* in_dev, + uint32_t inPerBatchStride, + void* out_dev, + uint32_t outPerBatchStride, + uint32_t outPerBatchCapacity, + uint8_t* outSuccess_dev, + uint32_t* outSize_dev, + cudaStream_t stream) { + auto inProvider = BatchProviderStride((void*)in_dev, inPerBatchStride); + auto outProvider = + BatchProviderStride(out_dev, outPerBatchStride, outPerBatchCapacity); + + return ansDecodeBatch( + res, + config, + numInBatch, + inProvider, + outProvider, + outSuccess_dev, + outSize_dev, + stream); +} + +ANSDecodeStatus ansDecodeBatchPointer( + StackDeviceMemory& res, + const ANSCodecConfig& config, + uint32_t numInBatch, + const void** in, + void** out, + const uint32_t* outCapacity, + uint8_t* outSuccess_dev, + uint32_t* outSize_dev, + cudaStream_t stream) { + // If the batch size is <= kBSLimit, we avoid cudaMemcpy and send all data at + // kernel launch + constexpr int kBSLimit = 128; + + if (numInBatch <= kBSLimit) { + auto inProvider = + BatchProviderInlinePointer(numInBatch, (void**)in); + auto outProvider = BatchProviderInlinePointerCapacity( + numInBatch, out, outCapacity); + + return ansDecodeBatch( + res, + config, + numInBatch, + inProvider, + outProvider, + outSuccess_dev, + outSize_dev, + stream); + } + + // Otherwise, we have to perform h2d copies + auto in_dev = res.alloc(stream, numInBatch); + + CUDA_VERIFY(cudaMemcpyAsync( + in_dev.data(), + in, + numInBatch * sizeof(void*), + cudaMemcpyHostToDevice, + stream)); + + auto out_dev = res.alloc(stream, numInBatch); + + CUDA_VERIFY(cudaMemcpyAsync( + out_dev.data(), + out, + numInBatch * sizeof(void*), + cudaMemcpyHostToDevice, + stream)); + + auto outCapacity_dev = res.alloc(stream, numInBatch); + + CUDA_VERIFY(cudaMemcpyAsync( + outCapacity_dev.data(), + outCapacity, + numInBatch * sizeof(uint32_t), + cudaMemcpyHostToDevice, + stream)); + + // Data is now on the device + auto inProvider = BatchProviderPointer(in_dev.data()); + auto outProvider = + BatchProviderPointer(out_dev.data(), outCapacity_dev.data()); + + return ansDecodeBatch( + res, + config, + numInBatch, + inProvider, + outProvider, + outSuccess_dev, + outSize_dev, + stream); +} + +ANSDecodeStatus ansDecodeBatchSplitSize( + StackDeviceMemory& res, + const ANSCodecConfig& config, + uint32_t numInBatch, + const void** in, + void* out_dev, + const uint32_t* outSplitSizes, + uint8_t* outSuccess_dev, + uint32_t* outSize_dev, + cudaStream_t stream) { + auto splitSizeHost = std::vector(numInBatch * 2); + auto splitSize = splitSizeHost.data(); + auto splitSizePrefix = splitSizeHost.data() + numInBatch; + uint32_t maxSplitSize = 0; + + // check alignment + CHECK_EQ(uintptr_t(out_dev) % kANSRequiredAlignment, 0); + + for (uint32_t i = 0; i < numInBatch; ++i) { + auto size = outSplitSizes[i]; + + if (i != (numInBatch - 1)) { + // check alignment (internal splits affect alignment of things after it) + CHECK_EQ(size % kANSRequiredAlignment, 0); + } + + splitSize[i] = size; + if (i > 0) { + splitSizePrefix[i] = splitSizePrefix[i - 1] + splitSize[i - 1]; + } + + maxSplitSize = std::max(size, maxSplitSize); + } + + // Concatenate splitSize and splitSizePrefix together for a single h2d copy + auto sizes_dev = res.alloc(stream, splitSizeHost.size()); + + CUDA_VERIFY(cudaMemcpyAsync( + sizes_dev.data(), + splitSizeHost.data(), + splitSizeHost.size() * sizeof(uint32_t), + cudaMemcpyHostToDevice, + stream)); + + // FIXME: combine with above for a single h2d copy + auto in_dev = res.alloc(stream, numInBatch); + + CUDA_VERIFY(cudaMemcpyAsync( + in_dev.data(), + in, + numInBatch * sizeof(void*), + cudaMemcpyHostToDevice, + stream)); + + auto inProvider = BatchProviderPointer(in_dev.data()); + + auto outProvider = BatchProviderSplitSize( + out_dev, + sizes_dev.data(), + sizes_dev.data() + numInBatch, + sizeof(uint8_t)); + + return ansDecodeBatch( + res, + config, + numInBatch, + inProvider, + outProvider, + outSuccess_dev, + outSize_dev, + stream); +} + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/dietgpu/ans/GpuANSDecode.cuh b/thirdparty/dietgpu/dietgpu/ans/GpuANSDecode.cuh new file mode 100644 index 000000000..41a76da9a --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/ans/GpuANSDecode.cuh @@ -0,0 +1,615 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include "dietgpu/ans/BatchProvider.cuh" +#include "dietgpu/ans/GpuANSCodec.h" +#include "dietgpu/ans/GpuANSInfo.cuh" +#include "dietgpu/ans/GpuANSUtils.cuh" +#include "dietgpu/ans/GpuChecksum.cuh" +#include "dietgpu/utils/DeviceDefs.cuh" +#include "dietgpu/utils/DeviceUtils.h" +#include "dietgpu/utils/PtxUtils.cuh" +#include "dietgpu/utils/StackDeviceMemory.h" +#include "dietgpu/utils/StaticUtils.h" + +#include +#include +#include +#include +#include +#include + +namespace dietgpu { + +using TableT = uint32_t; + +// We are limited to 11 bits of probability resolution +// (worst case, prec = 12, pdf == 2^12, single symbol. 2^12 cannot be +// represented in 12 bits) +inline __device__ TableT +packDecodeLookup(uint32_t sym, uint32_t pdf, uint32_t cdf) { + static_assert(sizeof(ANSDecodedT) == 1, ""); + // [31:20] cdf + // [19:8] pdf + // [7:0] symbol + return (cdf << 20) | (pdf << 8) | sym; +} + +inline __device__ void +unpackDecodeLookup(TableT v, uint32_t& sym, uint32_t& pdf, uint32_t& cdf) { + // [31:20] cdf + // [19:8] pdf + // [7:0] symbol + sym = v & 0xffU; + v >>= 8; + pdf = v & 0xfffU; + v >>= 12; + cdf = v; +} + +template +__device__ void decodeOneWarp( + ANSStateT& state, + + // Start offset where this warp is reading from the + // compressed input. As a variable number of lanes + // wish to read from the compressed offset each + // iteration, this offset upon calling is one after + // the last offset, if any, this warp will be reading rom. + uint32_t compressedOffset, + + const ANSEncodedT* __restrict__ in, + + // Shared memory LUTs + const TableT* lookup, + + // Output: number of words read from compressed input + uint32_t& outNumRead, + + // Output: decoded symbol for this iteration + ANSDecodedT& outSym) { + constexpr ANSStateT StateMask = (ANSStateT(1) << ProbBits) - ANSStateT(1); + + auto s_bar = state & StateMask; + + uint32_t sym; + uint32_t pdf; + uint32_t sMinusCdf; + unpackDecodeLookup(lookup[s_bar], sym, pdf, sMinusCdf); + + // We always write a decoded value + outSym = sym; + state = pdf * (state >> ProbBits) + ANSStateT(sMinusCdf); + + // We only sometimes read a new encoded value + bool read = state < kANSMinState; + auto vote = __ballot_sync(kFullMask, read); + // We are reading in the same order as we wrote, except by decrementing from + // compressedOffset, so we need to count down from the highest lane in the + // warp + #if defined(__HIP_PLATFORM_AMD__) + auto prefix = __popcll(vote & getLaneMaskGe()); + #else + auto prefix = __popc(vote & getLaneMaskGe()); + #endif + + if (read) { + auto v = in[compressedOffset - prefix]; + // auto v = in[-prefix]; + state = (state << kANSEncodedBits) + ANSStateT(v); + } + + // how many values we actually read from the compressed input + #if defined(__HIP_PLATFORM_AMD__) + outNumRead = __popcll(vote); + #else + outNumRead = __popc(vote); + #endif +} + +template +__device__ void decodeOnePartialWarp( + bool valid, + ANSStateT& state, + + // Start offset where this warp is reading from the + // compressed input. As a variable number of lanes + // wish to read from the compressed offset each + // iteration, this offset upon calling is one after + // the last offset, if any, this warp will be reading rom. + uint32_t compressedOffset, + + const ANSEncodedT* __restrict__ in, + + // Shared memory LUTs + const TableT* lookup, + + // Output: number of words read from compressed input + uint32_t& outNumRead, + + // Output: decoded symbol for this iteration (only if valid) + ANSDecodedT& outSym) { + constexpr ANSStateT StateMask = (ANSStateT(1) << ProbBits) - ANSStateT(1); + + auto s_bar = state & StateMask; + + uint32_t sym; + uint32_t pdf; + uint32_t sMinusCdf; + unpackDecodeLookup(lookup[s_bar], sym, pdf, sMinusCdf); + + if (valid) { + outSym = sym; + state = pdf * (state >> ProbBits) + ANSStateT(sMinusCdf); + } + + // We only sometimes read a new encoded value + bool read = valid && (state < kANSMinState); + auto vote = __ballot_sync(kFullMask, read); + // We are reading in the same order as we wrote, except by decrementing from + // compressedOffset, so we need to count down from the highest lane in the + // warp + #if defined(__HIP_PLATFORM_AMD__) + auto prefix = __popcll(vote & getLaneMaskGe()); + #else + auto prefix = __popc(vote & getLaneMaskGe()); + #endif + if (read) { + // auto in_new = in - compressedOffset; + auto v = in[compressedOffset - prefix]; + // auto v = in[-prefix]; + state = (state << kANSEncodedBits) + ANSStateT(v); + } + + // how many values we actually read from the compressed input + #if defined(__HIP_PLATFORM_AMD__) + outNumRead = __popcll(vote); + #else + outNumRead = __popc(vote); + #endif +} + +template +__device__ void ansDecodeWarpBlock( + int laneId, + ANSStateT state, + uint32_t uncompressedWords, + uint32_t compressedWords, + const ANSEncodedT* __restrict__ in, + Writer& writer, + const TableT* __restrict__ table) { + // The compressed input may not be a whole multiple of a warp. + // In this case, only the lanes that cover the remainder could have read a + // value in the input, and thus, only they can write a value in the output. + // We handle this partial data first. + uint32_t remainder = uncompressedWords % kWarpSize; + + // A fixed number of uncompressed elements are written each iteration + int uncompressedOffset = uncompressedWords - remainder; + + // A variable number of compressed elements are read each iteration + uint32_t compressedOffset = compressedWords; + + + + // Partial warp handling the end of the data + if (remainder) { + bool valid = laneId < remainder; + + uint32_t numCompressedRead; + ANSDecodedT sym; + + decodeOnePartialWarp( + valid, state, compressedOffset, in, table, numCompressedRead, sym); + + if (valid) { + writer.write(uncompressedOffset + laneId, sym); + } + + // compressedOffset -= numCompressedRead; + in -= numCompressedRead; + } + + // Full warp handling + while (uncompressedOffset > 0) { + uncompressedOffset -= kWarpSize; + + uint32_t numCompressedRead; + ANSDecodedT sym; + + decodeOneWarp( + state, compressedOffset, in, table, numCompressedRead, sym); + + writer.write(uncompressedOffset + laneId, sym); + + // compressedOffset -= numCompressedRead; + in -= numCompressedRead; + } + in += compressedOffset; +} + +template +struct ANSDecodeWarpFullBlock; + +// template +// struct ANSDecodeWarpFullBlock { +// static __device__ void decode( +// int laneId, +// ANSStateT state, +// uint32_t compressedWords, +// const ANSEncodedT* __restrict__ in, +// Writer& writer, +// const TableT* __restrict__ table) { +// // A variable number of compressed elements are read each iteration +// using VecT = ANSDecodedTx4; + +// in += compressedWords; + +// // 2: 252.16 us +// // 3: 246.62 us +// // 4: 254.91 us +// constexpr int kCacheLinesAhead = 3; + +// for (int i = (BlockSize / sizeof(VecT)) - kWarpSize + laneId; i >= 0; +// i -= kWarpSize) { +// VecT symV; +// // Assuming no compression, we load 2 * sizeof(ANSEncodedT) * +// // kWarpSize = 128 bytes per iteration +// asm volatile("prefetch.global.L1 [%0];" +// : +// : "l"(in - (kCacheLinesAhead * 128) / +// sizeof(ANSEncodedT))); + +// // writer.preload(i + laneId); +// writer.preload(i); + +// #pragma unroll +// for (int j = sizeof(VecT) - 1; j >= 0; --j) { +// ANSDecodedT sym; +// uint32_t numCompressedRead; + +// decodeOneWarp( +// state, compressedWords, in, table, numCompressedRead, sym); + +// symV.x[j] = sym; +// // compressedWords -= numCompressedRead; +// in -= numCompressedRead; +// } + +// // writer.writeVec(i + laneId, symV); +// writer.writeVec(i, symV); +// } +// } +// }; + +// Non-vectorized full block implementation +template +struct ANSDecodeWarpFullBlock { + static __device__ void decode( + int laneId, + ANSStateT state, + uint32_t compressedWords, + const ANSEncodedT* __restrict__ in, + Writer& writer, + const TableT* __restrict__ table) { + + for (int i = BlockSize - kWarpSize + laneId; i >= 0; i -= kWarpSize) { + ANSDecodedT sym; + uint32_t numCompressedRead; + + decodeOneWarp( + state, compressedWords, in, table, numCompressedRead, sym); + + in -= numCompressedRead; + + writer.write(i, sym); + } + in += compressedWords; + } +}; + +template < + typename InProvider, + typename OutProvider, + int Threads, + int ProbBits, + int BlockSize> +__global__ __launch_bounds__(128) void ansDecodeKernel( + InProvider inProvider, + const TableT* __restrict__ table, + OutProvider outProvider, + uint8_t* __restrict__ outSuccess, + uint32_t* __restrict__ outSize) { + auto tid = threadIdx.x; + auto batch = blockIdx.y; + + // Interpret header as uint4 + auto headerIn = (const ANSCoalescedHeader*)inProvider.getBatchStart(batch); + headerIn->checkMagicAndVersion(); + + auto header = *headerIn; + auto numBlocks = header.getNumBlocks(); + auto totalUncompressedWords = header.getTotalUncompressedWords(); + + // Is the data what we expect? + assert(ProbBits == header.getProbBits()); + + // Do we have enough space for the decompressed data? + auto uncompressedBytes = totalUncompressedWords * sizeof(ANSDecodedT); + bool success = outProvider.getBatchSize(batch) >= uncompressedBytes; + + if (blockIdx.x == 0 && tid == 0) { + if (outSuccess) { + outSuccess[batch] = success; + } + + if (outSize) { + outSize[batch] = uncompressedBytes; + } + } + + if (!success) { + return; + } + + // Initialize symbol, pdf, cdf tables + constexpr int kBuckets = 1 << ProbBits; + __shared__ TableT lookup[kBuckets]; + + { + uint4* lookup4 = (uint4*)lookup; + const uint4* table4 = (const uint4*)(table + batch * (1 << ProbBits)); + + static_assert(isEvenDivisor(kBuckets, Threads * 4), ""); + for (int j = 0; + // loading by uint4 words + j < kBuckets / (Threads * (sizeof(uint4) / sizeof(TableT))); + ++j) { + lookup4[j * Threads + tid] = table4[j * Threads + tid]; + } + } + + __syncthreads(); + + auto writer = outProvider.getWriter(batch); + + // warp id taking into account warps in the current block + // do this so the compiler knows it is warp uniform + int globalWarpId = + __shfl_sync(kFullMask, (blockIdx.x * blockDim.x + tid) / kWarpSize, 0); + + auto warpsPerGrid = gridDim.x * Threads / kWarpSize; + int laneId = getLaneId(); + + for (int block = globalWarpId; block < numBlocks; block += warpsPerGrid) { + // Load state + ANSStateT state = headerIn->getWarpStates()[block].warpState[laneId]; + + // Load per-block size data + auto blockWords = headerIn->getBlockWords(numBlocks)[block]; + uint32_t uncompressedWords = (blockWords.x >> 16); + uint32_t compressedWords = (blockWords.x & 0xffff); + uint32_t blockCompressedWordStart = blockWords.y; + + // Get block addresses for encoded/decoded data + auto blockDataIn = + headerIn->getBlockDataStart(numBlocks) + blockCompressedWordStart; + + writer.setBlock(block); + + using Writer = typename OutProvider::Writer; + if (uncompressedWords == BlockSize) { + ANSDecodeWarpFullBlock::decode( + laneId, state, compressedWords, blockDataIn, writer, lookup); + } else { + ansDecodeWarpBlock( + laneId, + state, + uncompressedWords, + compressedWords, + blockDataIn, + writer, + lookup); + } + } +} + +template +__global__ void ansDecodeTable( + BatchProvider inProvider, + uint32_t probBits, + TableT* __restrict__ table) { + auto batch = blockIdx.x; + auto tid = threadIdx.x; + int warpId = tid / kWarpSize; + int laneId = getLaneId(); + + table += batch * (1 << probBits); + auto headerIn = (const ANSCoalescedHeader*)inProvider.getBatchStart(batch); + + auto header = *headerIn; + + // Is this an expected header? + header.checkMagicAndVersion(); + + // Is our probability resolution what we expected? + assert(header.getProbBits() == probBits); + + if (header.getTotalUncompressedWords() == 0) { + // nothing to do; compressed empty array + return; + } + + // Skip to pdf table + auto probs = headerIn->getSymbolProbs(); + + static_assert(Threads >= kNumSymbols, ""); + uint32_t pdf = tid < kNumSymbols ? probs[tid] : 0; + uint32_t cdf = 0; + + // Get the CDF from the PDF + using BlockScan = cub::BlockScan; + __shared__ typename BlockScan::TempStorage tempStorage; + + uint32_t total = 0; + // FIXME: don't use cub, we can write both the pdf and cdf to smem with a + // single syncthreads + BlockScan(tempStorage).ExclusiveSum(pdf, cdf, total); + + uint32_t totalProb = 1 << probBits; + assert(totalProb == total); // should be a power of 2 + + // Broadcast the pdf/cdf values + __shared__ uint2 smemPdfCdf[kNumSymbols]; + + if (tid < kNumSymbols) { + smemPdfCdf[tid] = uint2{pdf, cdf}; + } + + __syncthreads(); + + // Build the table for each pdf/cdf bucket + constexpr int kWarpsPerBlock = Threads / kWarpSize; + + for (int i = warpId; i < kNumSymbols; i += kWarpsPerBlock) { + auto v = smemPdfCdf[i]; + + auto pdf = v.x; + auto begin = v.y; + auto end = begin + pdf; + + for (int j = begin + laneId; j < end; j += kWarpSize) { + table[j] = packDecodeLookup( + i, // symbol + pdf, // bucket pdf + j - begin); // within-bucket cdf + } + } +} + +template +ANSDecodeStatus ansDecodeBatch( + StackDeviceMemory& res, + const ANSCodecConfig& config, + uint32_t numInBatch, + const InProvider& inProvider, + OutProvider& outProvider, + uint8_t* outSuccess_dev, + uint32_t* outSize_dev, + cudaStream_t stream) { + auto table_dev = + res.alloc(stream, numInBatch * (1 << config.probBits)); + + // Build the rANS decoding table from the compression header + { + constexpr int kThreads = 512; + ansDecodeTable<<>>( + inProvider, config.probBits, table_dev.data()); + } + + // Perform decoding + { + // FIXME: We have no idea how large the decompression job is, as the + // relevant information is on the device. + // Just launch a grid that is sufficiently large enough to saturate the GPU; + // blocks will exit if there isn't enough work, or will loop if there is + // more work. We aim for a grid >4x larger than what the device can sustain, + // to help cover up tail effects and unequal provisioning across the batch +#define RUN_DECODE(BITS) \ + do { \ + constexpr int kThreads = 128; \ + auto& props = getCurrentDeviceProperties(); \ + int maxBlocksPerSM = 0; \ + CUDA_VERIFY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( \ + &maxBlocksPerSM, \ + ansDecodeKernel< \ + InProvider, \ + OutProvider, \ + kThreads, \ + BITS, \ + kDefaultBlockSize>, \ + kThreads, \ + 0)); \ + uint32_t maxGrid = maxBlocksPerSM * props.multiProcessorCount; \ + uint32_t perBatchGrid = divUp(maxGrid, numInBatch) * 4; \ + auto grid = dim3(perBatchGrid, numInBatch); \ + \ + ansDecodeKernel< \ + InProvider, \ + OutProvider, \ + kThreads, \ + BITS, \ + kDefaultBlockSize><<>>( \ + inProvider, \ + table_dev.data(), \ + outProvider, \ + outSuccess_dev, \ + outSize_dev); \ + } while (false) + + switch (config.probBits) { + case 9: + RUN_DECODE(9); + break; + case 10: + RUN_DECODE(10); + break; + case 11: + RUN_DECODE(11); + break; + default: + CHECK(false) << "unhandled pdf precision " << config.probBits; + } + +#undef RUN_DECODE + } + + ANSDecodeStatus status; + + // Perform optional checksum, if desired + if (config.useChecksum) { + auto checksum_dev = res.alloc(stream, numInBatch); + auto sizes_dev = res.alloc(stream, numInBatch); + auto archiveChecksum_dev = res.alloc(stream, numInBatch); + + // Checksum the output data + checksumBatch(numInBatch, outProvider, checksum_dev.data(), stream); + + // Get prior checksum from the ANS headers + ansGetCompressedInfo( + inProvider, + numInBatch, + sizes_dev.data(), + archiveChecksum_dev.data(), + stream); + + // Compare against previously seen checksums on the host + auto sizes = sizes_dev.copyToHost(stream); + auto newChecksums = checksum_dev.copyToHost(stream); + auto oldChecksums = archiveChecksum_dev.copyToHost(stream); + + std::stringstream errStr; + + for (int i = 0; i < numInBatch; ++i) { + if (oldChecksums[i] != newChecksums[i]) { + status.error = ANSDecodeError::ChecksumMismatch; + + errStr << "Checksum mismatch in batch member " << i + << ": expected checksum " << std::hex << oldChecksums[i] + << " got " << newChecksums[i] << "\n"; + status.errorInfo.push_back(std::make_pair(i, errStr.str())); + } + } + } + + CUDA_TEST_ERROR(); + + return status; +} + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/dietgpu/ans/GpuANSEncode.cu b/thirdparty/dietgpu/dietgpu/ans/GpuANSEncode.cu new file mode 100644 index 000000000..4d58be154 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/ans/GpuANSEncode.cu @@ -0,0 +1,181 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "dietgpu/ans/BatchProvider.cuh" +#include "dietgpu/ans/GpuANSEncode.cuh" + +namespace dietgpu { + +uint32_t getMaxCompressedSize(uint32_t uncompressedBytes) { + uint32_t blocks = divUp(uncompressedBytes, kDefaultBlockSize); + + size_t rawSize = ANSCoalescedHeader::getCompressedOverhead(kDefaultBlockSize); + rawSize += (size_t)getMaxBlockSizeCoalesced(kDefaultBlockSize) * blocks; + + // When used in batches, we must align everything to 16 byte boundaries (due + // to uint4 read/writes) + rawSize = roundUp(rawSize, sizeof(uint4)); + CHECK_LE(rawSize, std::numeric_limits::max()); + + return rawSize; +} + +void ansEncodeBatchStride( + StackDeviceMemory& res, + const ANSCodecConfig& config, + uint32_t numInBatch, + const void* in_dev, + uint32_t inPerBatchSize, + uint32_t inPerBatchStride, + const uint32_t* histogram_dev, + void* out_dev, + uint32_t outPerBatchStride, + uint32_t* outSize_dev, + cudaStream_t stream) { + auto inProvider = + BatchProviderStride((void*)in_dev, inPerBatchStride, inPerBatchSize); + auto outProvider = BatchProviderStride(out_dev, outPerBatchStride); + + ansEncodeBatchDevice( + res, + config, + numInBatch, + inProvider, + histogram_dev, + inPerBatchSize, // max size + outProvider, + outSize_dev, + stream); +} + +void ansEncodeBatchPointer( + StackDeviceMemory& res, + const ANSCodecConfig& config, + uint32_t numInBatch, + const void** in, + const uint32_t* inSize, + const uint32_t* histogram_dev, + void** out, + uint32_t* outSize_dev, + cudaStream_t stream) { + // Get the total and maximum input size + uint32_t maxSize = 0; + + for (uint32_t i = 0; i < numInBatch; ++i) { + uint32_t curSize = inSize[i] / sizeof(ANSDecodedT); + maxSize = std::max(maxSize, curSize); + } + + // Copy data to device + auto in_dev = res.alloc(stream, numInBatch); + auto inSize_dev = res.alloc(stream, numInBatch); + auto out_dev = res.alloc(stream, numInBatch); + + CUDA_VERIFY(cudaMemcpyAsync( + in_dev.data(), + in, + numInBatch * sizeof(void*), + cudaMemcpyHostToDevice, + stream)); + + CUDA_VERIFY(cudaMemcpyAsync( + inSize_dev.data(), + inSize, + numInBatch * sizeof(uint32_t), + cudaMemcpyHostToDevice, + stream)); + + CUDA_VERIFY(cudaMemcpyAsync( + out_dev.data(), + out, + numInBatch * sizeof(void*), + cudaMemcpyHostToDevice, + stream)); + + auto inProvider = + BatchProviderPointer((void**)in_dev.data(), inSize_dev.data()); + auto outProvider = BatchProviderPointer(out_dev.data()); + + ansEncodeBatchDevice( + res, + config, + numInBatch, + inProvider, + histogram_dev, + maxSize, + outProvider, + outSize_dev, + stream); +} + +void ansEncodeBatchSplitSize( + StackDeviceMemory& res, + const ANSCodecConfig& config, + uint32_t numInBatch, + const void* in_dev, + const uint32_t* inSplitSizes, + const uint32_t* histogram_dev, + void* out_dev, + uint32_t outStride, + uint32_t* outSize_dev, + cudaStream_t stream) { + auto splitSizeHost = std::vector(numInBatch * 2); + auto splitSize = splitSizeHost.data(); + auto splitSizePrefix = splitSizeHost.data() + numInBatch; + uint32_t maxSplitSize = 0; + + // check alignment + CHECK_EQ(uintptr_t(in_dev) % kANSRequiredAlignment, 0); + + for (uint32_t i = 0; i < numInBatch; ++i) { + auto size = inSplitSizes[i]; + + if (i != (numInBatch - 1)) { + // check alignment (internal splits affect alignment of things after it) + CHECK_EQ(size % kANSRequiredAlignment, 0); + } + + splitSize[i] = size; + if (i > 0) { + splitSizePrefix[i] = splitSizePrefix[i - 1] + splitSize[i - 1]; + } + + maxSplitSize = std::max(size, maxSplitSize); + } + + // Copy data to device + // splitSize, splitSizePrefix + auto sizes_dev = res.alloc(stream, splitSizeHost.size()); + + CUDA_VERIFY(cudaMemcpyAsync( + sizes_dev.data(), + splitSizeHost.data(), + splitSizeHost.size() * sizeof(uint32_t), + cudaMemcpyHostToDevice, + stream)); + + auto inProvider = BatchProviderSplitSize( + (void*)in_dev, + sizes_dev.data(), + sizes_dev.data() + numInBatch, + sizeof(uint8_t)); + + auto outProvider = BatchProviderStride(out_dev, outStride); + + ansEncodeBatchDevice( + res, + config, + numInBatch, + inProvider, + histogram_dev, + maxSplitSize, + outProvider, + outSize_dev, + stream); +} + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/dietgpu/ans/GpuANSEncode.cuh b/thirdparty/dietgpu/dietgpu/ans/GpuANSEncode.cuh new file mode 100644 index 000000000..8b0ebda03 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/ans/GpuANSEncode.cuh @@ -0,0 +1,869 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include "dietgpu/ans/BatchPrefixSum.cuh" +#include "dietgpu/ans/GpuANSCodec.h" +#include "dietgpu/ans/GpuANSStatistics.cuh" +#include "dietgpu/ans/GpuANSUtils.cuh" +#include "dietgpu/ans/GpuChecksum.cuh" +#include "dietgpu/utils/DeviceDefs.cuh" +#include "dietgpu/utils/DeviceUtils.h" +#include "dietgpu/utils/PtxUtils.cuh" +#include "dietgpu/utils/StackDeviceMemory.h" +#include "dietgpu/utils/StaticUtils.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace dietgpu { +#if defined(__HIP_PLATFORM_AMD__) + // HIP does not support thrust exec check disabling + #define THRUST_DISABLE_EXEC_CHECK +#else + #define THRUST_DISABLE_EXEC_CHECK __thrust_exec_check_disable__ +#endif + +// maximum raw compressed data block size in bytes +constexpr __host__ __device__ uint32_t +getRawCompBlockMaxSize(uint32_t uncompressedBlockBytes) { + // (an estimate from zstd) + return roundUp( + uncompressedBlockBytes + (uncompressedBlockBytes / 4), kBlockAlignment); +} + +inline uint32_t getMaxBlockSizeUnCoalesced(uint32_t uncompressedBlockBytes) { + // uncoalesced data has a warp state header + return sizeof(ANSWarpState) + getRawCompBlockMaxSize(uncompressedBlockBytes); +} + +inline uint32_t getMaxBlockSizeCoalesced(uint32_t uncompressedBlockBytes) { + return getRawCompBlockMaxSize(uncompressedBlockBytes); +} + +// Returns number of values written to the compressed output +// Assumes all lanes in the warp are presented valid input symbols +template +__device__ __forceinline__ uint32_t encodeOneWarp( + ANSStateT& state, + ANSDecodedT sym, + uint32_t outOffset, + ANSEncodedT* __restrict__ out, + const uint4* __restrict__ smemLookup) { + auto lookup = smemLookup[sym]; + + uint32_t pdf = lookup.x; + uint32_t cdf = lookup.y; + uint32_t div_m1 = lookup.z; + uint32_t div_shift = lookup.w; + + constexpr ANSStateT kStateCheckMul = 1 << (kANSStateBits - ProbBits); + + ANSStateT maxStateCheck = pdf * kStateCheckMul; + bool write = (state >= maxStateCheck); + + auto vote = __ballot_sync(kFullMask, write); + #if defined(__HIP_PLATFORM_AMD__) + auto prefix = __popcll(vote & getLaneMaskLt()); + #else + auto prefix = __popc(vote & getLaneMaskLt()); + #endif + // Some lanes wish to write out their data + if (write) { + out[outOffset + prefix] = state & kANSEncodedMask; + state >>= kANSEncodedBits; + } + + constexpr uint32_t kProbBitsMul = 1 << ProbBits; + + uint32_t t = __umulhi(state, div_m1); + // We prevent addition overflow here by restricting `state` to < 2^31 + // (kANSStateBits) + uint32_t div = (t + state) >> div_shift; + auto mod = state - (div * pdf); + + // calculating ((state / pdf) << ProbBits) + (state % pdf) + cdf + state = div * kProbBitsMul + mod + cdf; + + // how many values we actually write to the compressed output + #if defined(__HIP_PLATFORM_AMD__) + return __popcll(vote); + #else + return __popc(vote); + #endif +} + +// Returns number of values written to the compressed output +// Assumes only some lanes in the warp are presented valid input symbols +template +__device__ __forceinline__ uint32_t encodeOnePartialWarp( + // true for the lanes in the warp for which data read is valid + bool valid, + ANSStateT& state, + ANSDecodedT sym, + uint32_t outOffset, + ANSEncodedT* __restrict__ out, + const uint4* __restrict__ smemLookup) { + auto lookup = smemLookup[sym]; + + uint32_t pdf = lookup.x; + uint32_t cdf = lookup.y; + uint32_t div_m1 = lookup.z; + uint32_t div_shift = lookup.w; + + constexpr ANSStateT kStateCheckMul = 1 << (kANSStateBits - ProbBits); + + ANSStateT maxStateCheck = pdf * kStateCheckMul; + bool write = valid && (state >= maxStateCheck); + + auto vote = __ballot_sync(kFullMask, write); + #if defined(__HIP_PLATFORM_AMD__) + auto prefix = __popcll(vote & getLaneMaskLt()); + #else + auto prefix = __popc(vote & getLaneMaskLt()); + #endif + // Some lanes wish to write out their data + if (write) { + out[outOffset + prefix] = state & kANSEncodedMask; + state >>= kANSEncodedBits; + } + + uint32_t t = __umulhi(state, div_m1); + // We prevent addition overflow here by restricting `state` to < 2^31 + // (kANSStateBits) + uint32_t div = (t + state) >> div_shift; + auto mod = state - (div * pdf); + + // calculating ((state / pdf) << ProbBits) + (state % pdf) + cdf + constexpr uint32_t kProbBitsMul = 1 << ProbBits; + state = valid ? div * kProbBitsMul + mod + cdf : state; + + // how many values we actually write to the compressed output + #if defined(__HIP_PLATFORM_AMD__) + return __popcll(vote); + #else + return __popc(vote); + #endif +} + +// Fully encode a single block of data, along with the state for that block as +// the initial header. +// Returns the number of compressed words (ANSEncodedT) written +template +__device__ uint32_t ansEncodeWarpBlock( + // Current lane ID in the warp + uint32_t laneId, + // Input for this block + const ANSDecodedT* __restrict__ in, + // Number of ANSDecodedT words in this block + uint32_t inWords, + // encoded table in smem + const uint4* __restrict__ table, + // Output for this block + ANSWarpState* __restrict__ out) { + // where we write the compressed words + ANSEncodedT* outWords = (ANSEncodedT*)(out + 1); + + // Start state value for this warp + ANSStateT state = kANSStartState; + + uint32_t inOffset = laneId; + uint32_t outOffset = 0; + + constexpr int kUnroll = 8; + + // Unrolled iterations + uint32_t limit = roundDown(inWords, kWarpSize * kUnroll); + { + ANSDecodedT sym[kUnroll]; + + for (; inOffset < limit; inOffset += kWarpSize * kUnroll) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + sym[j] = in[inOffset + j * kWarpSize]; + } + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + outOffset += + encodeOneWarp(state, sym[j], outOffset, outWords, table); + } + } + } + + if (limit != inWords) { + // Remainder iterations + limit = roundDown(inWords, kWarpSize); + + // Whole warp iterations + for (; inOffset < limit; inOffset += kWarpSize) { + ANSDecodedT sym = in[inOffset]; + + outOffset += + encodeOneWarp(state, sym, outOffset, outWords, table); + } + + // Partial warp iteration + if (limit != inWords) { + // Last iteration may not be a full warp + bool valid = inOffset < inWords; + ANSDecodedT sym = valid ? in[inOffset] : ANSDecodedT(0); + + outOffset += encodeOnePartialWarp( + valid, state, sym, outOffset, outWords, table); + } + } + + // Write final state at the beginning (aligned addresses) + out->warpState[laneId] = state; + + // Number of compressed words written + return outOffset; +} + +// Fully encode a single, full sized block of data, along with the state for +// that block as the initial header. +// Returns the number of compressed words (ANSEncodedT) written +// UseVec4 means that each lane in the warp loads 4 bytes of the input at a +// time, with each byte compressed by the warp in an interleaved fashion. +// Such vectorization must match with decom +template +struct ANSEncodeWarpFullBlock; + +// template +// struct ANSEncodeWarpFullBlock { +// static __device__ uint32_t encode( +// // Current lane ID in the warp +// uint32_t laneId, +// // Input for this block +// const ANSDecodedT* __restrict__ in, +// // encoded table in smem +// const uint4* __restrict__ table, +// // Output for this block +// ANSWarpState* __restrict__ out) { +// // where we write the compressed words +// ANSEncodedT* outWords = (ANSEncodedT*)(out + 1); + +// // Start state value for this warp +// ANSStateT state = kANSStartState; + +// uint32_t outOffset = 0; + +// using VecT = uint32_t; + +// auto inV = (const VecT*)in; +// inV += laneId; + +// // kUnroll 4, unroll 2 164.93 us +// // kUnroll 8, unroll 0 161.86 us +// constexpr int kUnroll = 16; + +// static_assert( +// isEvenDivisor((int)BlockSize, (int)(kUnroll * kWarpSize * +// sizeof(VecT))), +// ""); + +// for (int i = 0; i < BlockSize / (kWarpSize * sizeof(VecT)); +// i += kUnroll, inV += kUnroll * kWarpSize) { +// VecT symV[kUnroll]; + +// #pragma unroll +// for (int j = 0; j < kUnroll; ++j) { +// symV[j] = inV[j * kWarpSize]; +// } + +// #pragma unroll +// for (int j = 0; j < kUnroll; ++j) { +// asm volatile("prefetch.global.L2 [%0];" : : "l"(inV + 128 + j * 32)); +// #pragma unroll +// for (int k = 0; k < 4; ++k) { +// outOffset += encodeOneWarp( +// state, symV[j] & 0xff, outOffset, outWords, table); + +// symV[j] >>= 8; +// } +// } +// } + +// // Write final state at the beginning (aligned addresses) +// out->warpState[laneId] = state; + +// // Number of compressed words written +// return outOffset; +// } +// }; + +template +struct ANSEncodeWarpFullBlock { + static __device__ uint32_t encode( + // Current lane ID in the warp + uint32_t laneId, + // Input for this block + const ANSDecodedT* __restrict__ in, + // encoded table in smem + const uint4* __restrict__ table, + // Output for this block + ANSWarpState* __restrict__ out) { + // Just use the normal implementation + return ansEncodeWarpBlock(laneId, in, BlockSize, table, out); + } +}; + +template +__device__ void ansEncodeBlocksFull( + // input data for all blocks + const ANSDecodedT* __restrict__ in, + // length in ANSDecodedT words + uint32_t uncompressedWords, + // number of blocks that different warps will process + uint32_t numBlocks, + // the stride of each encoded output block + uint32_t outBlockStride, + // address of the output for all blocks + uint8_t* __restrict__ out, + // output array of per-block sizes of number of ANSEncodedT words per block + uint32_t* __restrict__ compressedWords, + // the encoding table that we will load into smem + const uint4* __restrict__ table) { + // grid-wide warp id + auto tid = threadIdx.x; + // so we know the block is warp uniform + int block = + __shfl_sync(kFullMask, (blockIdx.x * blockDim.x + tid) / kWarpSize, 0); + int laneId = getLaneId(); + + __shared__ uint4 smemLookup[kNumSymbols]; + + // we always have at least 256 threads + if (tid < kNumSymbols) { + smemLookup[tid] = table[tid]; + } + + __syncthreads(); + + // How big is this block? + uint32_t start = block * BlockSize; + uint32_t end = min(start + BlockSize, uncompressedWords); + + auto blockSize = end - start; + + // Either the warp is an excess one, or the last block is not a full block and + // needs to be processed using the partial kernel + if (block >= numBlocks || blockSize != BlockSize) { + return; + } + + auto inBlock = in + start; + auto outBlock = (ANSWarpState*)(out + block * outBlockStride); + + // all input blocks must meet alignment requirements + assert(isPointerAligned(inBlock, kANSRequiredAlignment)); + + auto outWords = ANSEncodeWarpFullBlock::encode( + laneId, inBlock, smemLookup, outBlock); + + if (laneId == 0) { + // If the bound on max compressed size is not correct, this assert will go + // off. This block of data was then somewhat adversarial in terms of + // incompressibility. In this case, the getRawCompBlockMaxSize max estimate + // needs to increase. + assert(outWords <= getRawCompBlockMaxSize(BlockSize) / sizeof(ANSEncodedT)); + compressedWords[block] = outWords; + } +} + +template +__device__ void ansEncodeBlocksPartial( + // input data for all blocks + const ANSDecodedT* __restrict__ in, + // length in ANSDecodedT words + uint32_t uncompressedWords, + // number of blocks that different warps will process + uint32_t numBlocks, + // the stride of each encoded output block + uint32_t outBlockStride, + // address of the output for all blocks + uint8_t* __restrict__ out, + // output array of per-block sizes of number of ANSEncodedT words per block + uint32_t* __restrict__ compressedWords, + // the encoding table that we will load into smem + const uint4* __restrict__ table) { + int block = numBlocks - 1; + uint32_t tid = threadIdx.x; + int laneId = getLaneId(); + + __shared__ uint4 smemLookup[kNumSymbols]; + + // we always have at least 256 threads + if (tid < kNumSymbols) { + smemLookup[tid] = table[tid]; + } + + __syncthreads(); + + // We only have a single partial block to handle + if (tid >= kWarpSize) { + return; + } + + // How big is this block? + uint32_t start = block * BlockSize; + uint32_t end = min(start + BlockSize, uncompressedWords); + + auto blockSize = end - start; + + // If the end block is a full block, it would have been handled by the full + // block kernel + if (blockSize == BlockSize) { + return; + } + + auto inBlock = in + start; + auto outBlock = (ANSWarpState*)(out + block * outBlockStride); + + // all input blocks must meet required alignment + assert(isPointerAligned(inBlock, kANSRequiredAlignment)); + + auto outWords = ansEncodeWarpBlock( + laneId, inBlock, blockSize, smemLookup, outBlock); + + if (laneId == 0) { + // If the bound on max compressed size is not correct, this assert will go + // off. This block of data was then somewhat adversarial in terms of + // incompressibility. In this case, the getRawCompBlockMaxSize max estimate + // needs to increase. + assert(outWords <= getRawCompBlockMaxSize(BlockSize) / sizeof(ANSEncodedT)); + compressedWords[block] = outWords; + } +} + +template +__global__ void ansEncodeBatchFull( + // Input data for all blocks + InProvider inProvider, + // maximum number of blocks across all the batch + uint32_t maxNumCompressedBlocks, + // maximum size of a compressed block + uint32_t maxCompressedBlockSize, + // address of the output for all blocks + uint8_t* __restrict__ out, + // output array of per-block sizes of number of ANSEncodedT words per block + // per batch + // [batch][numBlocks] + uint32_t* __restrict__ compressedWords, + // the encoding table that we will load into smem + // [batch][kNumSymbols] + const uint4* __restrict__ table) { + // which batch element we are processing + auto batch = blockIdx.y; + + // Number of blocks for the current problem + uint32_t curSize = inProvider.getBatchSize(batch); + uint32_t numBlocks = divUp(curSize, BlockSize); + + ansEncodeBlocksFull( + (const ANSDecodedT*)inProvider.getBatchStart(batch), + curSize, + numBlocks, + maxCompressedBlockSize, + out + batch * maxNumCompressedBlocks * maxCompressedBlockSize, + compressedWords + batch * maxNumCompressedBlocks, + table + batch * kNumSymbols); +} + +template +__global__ void ansEncodeBatchPartial( + // input data for all blocks + InProvider inProvider, + // maximum number of blocks across all the batch + uint32_t maxNumCompressedBlocks, + // maximum size of a compressed block + uint32_t maxCompressedBlockSize, + // address of the output for all blocks + uint8_t* __restrict__ out, + // output array of per-block sizes of number of ANSEncodedT words per block + // per batch + // [batch][numBlocks] + uint32_t* __restrict__ compressedWords, + // the encoding table that we will load into smem + // [batch][kNumSymbols] + const uint4* __restrict__ table) { + // which batch element we are processing + auto batch = blockIdx.y; + + // Number of blocks for the current problem + uint32_t curSize = inProvider.getBatchSize(batch); + uint32_t numBlocks = divUp(curSize, BlockSize); + + ansEncodeBlocksPartial( + (const ANSDecodedT*)inProvider.getBatchStart(batch), + inProvider.getBatchSize(batch), + numBlocks, + maxCompressedBlockSize, + out + batch * maxNumCompressedBlocks * maxCompressedBlockSize, + compressedWords + batch * maxNumCompressedBlocks, + table + batch * kNumSymbols); +} + +template +struct Align { + typedef uint32_t argument_type; + typedef uint32_t result_type; + + THRUST_DISABLE_EXEC_CHECK template + __host__ __device__ uint32_t operator()(T x) const { + constexpr int kDiv = B / sizeof(A); + constexpr int kSize = kDiv < 1 ? 1 : kDiv; + + return roundUp(x, T(kSize)); + } +}; + +template +__device__ void ansEncodeCoalesce( + const uint8_t* __restrict__ inUncoalescedBlocks, + uint32_t uncoalescedBlockStride, + const uint32_t* __restrict__ compressedWords, + const uint32_t* __restrict__ compressedWordsPrefix, + const uint32_t* __restrict__ checksum, + const uint4* __restrict__ table, + uint32_t probBits, + bool useChecksum, + uint32_t numBlocks, + uint32_t uncompressedWords, + uint8_t* __restrict__ out, + uint32_t* __restrict__ compressedBytes) { + auto block = blockIdx.x; + auto tid = threadIdx.x; + + ANSCoalescedHeader* headerOut = (ANSCoalescedHeader*)out; + + // The first block will be responsible for the coalesced header + if (block == 0) { + if (tid == 0) { + uint32_t totalCompressedWords = 0; + + // Could be a header for a zero sized array + if (numBlocks > 0) { + totalCompressedWords = + // total number of compressed words in all blocks + // this is already a multiple of kBlockAlignment / + // sizeof(ANSEncodedT) + compressedWordsPrefix[numBlocks - 1] + + // this is not yet a multiple of kBlockAlignment / + // sizeof(ANSEncodedT), but needs to be + roundUp( + compressedWords[numBlocks - 1], + kBlockAlignment / sizeof(ANSEncodedT)); + } + + ANSCoalescedHeader header; + header.setMagicAndVersion(); + header.setNumBlocks(numBlocks); + header.setTotalUncompressedWords(uncompressedWords); + header.setTotalCompressedWords(totalCompressedWords); + header.setProbBits(probBits); + header.setUseChecksum(useChecksum); + + if (useChecksum) { + header.setChecksum(*checksum); + } + + if (compressedBytes) { + *compressedBytes = header.getTotalCompressedSize(); + } + + *headerOut = header; + } + + auto probsOut = headerOut->getSymbolProbs(); + + // Write out pdf + for (int i = tid; i < kNumSymbols; i += Threads) { + probsOut[i] = table[i].x; + } + } + + if (block >= numBlocks) { + return; + } + + // where our per-warp data lies + auto uncoalescedBlock = inUncoalescedBlocks + block * uncoalescedBlockStride; + + // Write per-block warp state + if (tid < kWarpSize) { + auto warpStateIn = (ANSWarpState*)uncoalescedBlock; + + headerOut->getWarpStates()[block].warpState[tid] = + warpStateIn->warpState[tid]; + } + + auto blockWordsOut = headerOut->getBlockWords(numBlocks); + + // Write out per-block word length + for (auto i = blockIdx.x * Threads + tid; i < numBlocks; + i += gridDim.x * Threads) { + uint32_t lastBlockWords = uncompressedWords % kDefaultBlockSize; + lastBlockWords = lastBlockWords == 0 ? kDefaultBlockSize : lastBlockWords; + + uint32_t blockWords = + (i == numBlocks - 1) ? lastBlockWords : kDefaultBlockSize; + + blockWordsOut[i] = uint2{ + (blockWords << 16) | compressedWords[i], compressedWordsPrefix[i]}; + } + + // Number of compressed words in this block + uint32_t numWords = compressedWords[block]; + + // We always have a valid multiple of kBlockAlignment bytes on both + // uncoalesced src and coalesced dest, even though numWords (actual encoded + // words) may be less than that + using LoadT = uint4; + static_assert(sizeof(LoadT) == kBlockAlignment, ""); + + uint32_t limitEnd = divUp(numWords, kBlockAlignment / sizeof(ANSEncodedT)); + + auto inT = (const LoadT*)(uncoalescedBlock + sizeof(ANSWarpState)); + auto outT = (LoadT*)(headerOut->getBlockDataStart(numBlocks) + + compressedWordsPrefix[block]); + + for (uint32_t i = tid; i < limitEnd; i += Threads) { + outT[i] = inT[i]; + } +} + +template +__global__ void ansEncodeCoalesceBatch( + const uint8_t* __restrict__ inUncoalescedBlocks, + SizeProvider sizeProvider, + uint32_t maxNumCompressedBlocks, + uint32_t uncoalescedBlockStride, + const uint32_t* __restrict__ compressedWords, + const uint32_t* __restrict__ compressedWordsPrefix, + const uint32_t* __restrict__ checksum, + const uint4* __restrict__ table, + uint32_t probBits, + bool useChecksum, + OutProvider outProvider, + uint32_t* __restrict__ compressedBytes) { + auto batch = blockIdx.y; + auto uncompressedWords = sizeProvider.getBatchSize(batch); + + // Number of compressed blocks in this batch element + auto numBlocks = divUp(uncompressedWords, kDefaultBlockSize); + + // Advance all pointers to handle our specific batch member + inUncoalescedBlocks += + batch * uncoalescedBlockStride * maxNumCompressedBlocks; + compressedWords += batch * maxNumCompressedBlocks; + compressedWordsPrefix += batch * maxNumCompressedBlocks; + compressedBytes += batch; + checksum += batch; + table += batch * kNumSymbols; + + ansEncodeCoalesce( + inUncoalescedBlocks, + uncoalescedBlockStride, + compressedWords, + compressedWordsPrefix, + checksum, + table, + probBits, + useChecksum, + numBlocks, + uncompressedWords, + (uint8_t*)outProvider.getBatchStart(batch), + compressedBytes); +} + +template +void ansEncodeBatchDevice( + StackDeviceMemory& res, + const ANSCodecConfig& config, + uint32_t numInBatch, + InProvider inProvider, + const uint32_t* histogram_dev, + uint32_t maxSize, + OutProvider outProvider, + uint32_t* outSize_dev, + cudaStream_t stream) { + auto maxUncompressedWords = maxSize / sizeof(ANSDecodedT); + uint32_t maxNumCompressedBlocks = + divUp(maxUncompressedWords, kDefaultBlockSize); + + // 1. Compute symbol statistics + auto table_dev = res.alloc(stream, numInBatch * kNumSymbols); + + if (histogram_dev) { + // use pre-calculated histogram + ansCalcWeights( + numInBatch, + config.probBits, + inProvider, + histogram_dev, + table_dev.data(), + stream); + } else { + auto tempHistogram_dev = + res.alloc(stream, numInBatch * kNumSymbols); + + // need to calculate a histogram + ansHistogramBatch(numInBatch, inProvider, tempHistogram_dev.data(), stream); + + ansCalcWeights( + numInBatch, + config.probBits, + inProvider, + tempHistogram_dev.data(), + table_dev.data(), + stream); + } + + // 2. Compute checksum on input data (optional) + auto checksum_dev = res.alloc(stream, numInBatch); + if (config.useChecksum) { + checksumBatch(numInBatch, inProvider, checksum_dev.data(), stream); + } + + // 3. Allocate memory for the per-warp results + // How much space in bytes we need to reserve for each warp's output + uint32_t uncoalescedBlockStride = + getMaxBlockSizeUnCoalesced(kDefaultBlockSize); + + auto compressedBlocks_dev = res.alloc( + stream, numInBatch * maxNumCompressedBlocks * uncoalescedBlockStride); + + // +1 in order to get the final sum as well + auto compressedWords_dev = + res.alloc(stream, numInBatch * maxNumCompressedBlocks); + + // Exclusive prefix sum of the compressed sizes (so we know where to write in + // the contiguous output). The offsets are aligned to a multple of 4 + auto compressedWordsPrefix_dev = + res.alloc(stream, numInBatch * maxNumCompressedBlocks); + + // Run per-warp encoding + // (only if we have blocks to compress) + if (maxNumCompressedBlocks > 0) { + constexpr int kThreads = 256; + + // The grid for the full block kernel + auto gridFull = dim3( + divUp((int)maxNumCompressedBlocks, kThreads / kWarpSize), numInBatch); + + // The grid for the partial block kernel; at most 1 partial block per input + // in the batch + auto gridPartial = dim3(1, numInBatch); + +#define RUN_ENCODE(BITS) \ + do { \ + ansEncodeBatchFull \ + <<>>( \ + inProvider, \ + maxNumCompressedBlocks, \ + uncoalescedBlockStride, \ + compressedBlocks_dev.data(), \ + compressedWords_dev.data(), \ + table_dev.data()); \ + \ + ansEncodeBatchPartial \ + <<>>( \ + inProvider, \ + maxNumCompressedBlocks, \ + uncoalescedBlockStride, \ + compressedBlocks_dev.data(), \ + compressedWords_dev.data(), \ + table_dev.data()); \ + } while (false) + + switch (config.probBits) { + case 9: + RUN_ENCODE(9); + break; + case 10: + RUN_ENCODE(10); + break; + case 11: + RUN_ENCODE(11); + break; + default: + CHECK(false) << "unhandled pdf precision " << config.probBits; + } + +#undef RUN_ENCODE + } + + // Perform exclusive prefix sum of the number of compressed words per block, + // so we know where to write the output. We align the blocks so that we can + // write state values at 4 byte alignment at the beginning. + // FIXME: probably some way to do this via thrust::exclusive_scan_by_key with + // transform iterators and what not + if (maxNumCompressedBlocks > 0) { + auto sizeRequired = + getBatchExclusivePrefixSumTempSize(numInBatch, maxNumCompressedBlocks); + + // FIXME: we can run a more minimal segmented prefix sum instead of using + // maxNumCompressedBlocks + if (sizeRequired == 0) { + batchExclusivePrefixSum>( + compressedWords_dev.data(), + compressedWordsPrefix_dev.data(), + nullptr, + numInBatch, + maxNumCompressedBlocks, + Align(), + stream); + } else { + auto tempPrefixSum_dev = res.alloc(stream, sizeRequired); + + batchExclusivePrefixSum>( + compressedWords_dev.data(), + compressedWordsPrefix_dev.data(), + tempPrefixSum_dev.data(), + numInBatch, + maxNumCompressedBlocks, + Align(), + stream); + } + } + + // Coalesce the data into one contiguous buffer + // Even if there is nothing to compress, we still need to create a compression + // header + { + constexpr int kThreads = 64; + auto grid = dim3(std::max(maxNumCompressedBlocks, 1U), numInBatch); + + ansEncodeCoalesceBatch + <<>>( + compressedBlocks_dev.data(), + inProvider, + maxNumCompressedBlocks, + uncoalescedBlockStride, + compressedWords_dev.data(), + compressedWordsPrefix_dev.data(), + checksum_dev.data(), + table_dev.data(), + config.probBits, + config.useChecksum, + outProvider, + outSize_dev); + } + + CUDA_TEST_ERROR(); +} + +} // namespace dietgpu + +#undef RUN_ENCODE_ALL diff --git a/thirdparty/dietgpu/dietgpu/ans/GpuANSInfo.cu b/thirdparty/dietgpu/dietgpu/ans/GpuANSInfo.cu new file mode 100644 index 000000000..fcdcbe937 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/ans/GpuANSInfo.cu @@ -0,0 +1,51 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "dietgpu/ans/BatchProvider.cuh" +#include "dietgpu/ans/GpuANSCodec.h" +#include "dietgpu/ans/GpuANSInfo.cuh" + +namespace dietgpu { + +void ansGetCompressedInfo( + StackDeviceMemory& res, + const void** in, + uint32_t numInBatch, + uint32_t* outSizes_dev, + uint32_t* outChecksum_dev, + cudaStream_t stream) { + if (!outSizes_dev && !outChecksum_dev) { + return; + } + + auto in_dev = res.copyAlloc(stream, (void**)in, numInBatch); + ansGetCompressedInfoDevice( + res, + (const void**)in_dev.data(), + numInBatch, + outSizes_dev, + outChecksum_dev, + stream); +} + +void ansGetCompressedInfoDevice( + StackDeviceMemory& res, + const void** in_dev, + uint32_t numInBatch, + uint32_t* outSizes_dev, + uint32_t* outChecksum_dev, + cudaStream_t stream) { + if (!outSizes_dev && !outChecksum_dev) { + return; + } + + auto inProvider = BatchProviderPointer((void**)in_dev); + ansGetCompressedInfo( + inProvider, numInBatch, outSizes_dev, outChecksum_dev, stream); +} + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/dietgpu/ans/GpuANSInfo.cuh b/thirdparty/dietgpu/dietgpu/ans/GpuANSInfo.cuh new file mode 100644 index 000000000..7d7988fc7 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/ans/GpuANSInfo.cuh @@ -0,0 +1,59 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include "dietgpu/ans/GpuANSUtils.cuh" +#include "dietgpu/utils/DeviceUtils.h" +#include "dietgpu/utils/StackDeviceMemory.h" +#include "dietgpu/utils/StaticUtils.h" + +namespace dietgpu { + +template +__global__ void ansGetCompressedInfoKernel( + InProvider inProvider, + uint32_t numInBatch, + uint32_t* outSizes, + uint32_t* outChecksum) { + auto batch = blockIdx.x * blockDim.x + threadIdx.x; + if (batch < numInBatch) { + auto header = (const ANSCoalescedHeader*)inProvider.getBatchStart(batch); + // Make sure it is valid + header->checkMagicAndVersion(); + + if (outSizes) { + outSizes[batch] = header->getTotalUncompressedWords(); + } + + if (outChecksum) { + assert(header->getUseChecksum()); + outChecksum[batch] = header->getChecksum(); + } + } +} + +template +void ansGetCompressedInfo( + InProvider& inProvider, + uint32_t numInBatch, + uint32_t* outSizes_dev, + uint32_t* outChecksum_dev, + cudaStream_t stream) { + if (!outSizes_dev && !outChecksum_dev) { + return; + } + + auto block = 128; + auto grid = divUp(numInBatch, block); + + ansGetCompressedInfoKernel<<>>( + inProvider, numInBatch, outSizes_dev, outChecksum_dev); + + CUDA_TEST_ERROR(); +} + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/dietgpu/ans/GpuANSStatistics.cuh b/thirdparty/dietgpu/dietgpu/ans/GpuANSStatistics.cuh new file mode 100644 index 000000000..845328f11 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/ans/GpuANSStatistics.cuh @@ -0,0 +1,432 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "dietgpu/ans/BatchProvider.cuh" +#include "dietgpu/ans/GpuANSUtils.cuh" +#include "dietgpu/utils/DeviceDefs.cuh" +#include "dietgpu/utils/DeviceUtils.h" +#include "dietgpu/utils/PtxUtils.cuh" +#include "dietgpu/utils/StaticUtils.h" + +#include +#include +#include + +namespace dietgpu { + +template +__device__ void histogramSingle( + const ANSDecodedT* __restrict__ in, + uint32_t size, + uint32_t* __restrict__ out) { + constexpr int kWarps = Threads / kWarpSize; + static_assert(Threads == kNumSymbols, ""); + + // +1 in order to force very common symbols that could overlap into different + // banks? + __shared__ uint32_t buckets[kWarps][kNumSymbols + 1]; + + auto warpId = threadIdx.x / kWarpSize; + +#pragma unroll + for (int i = 0; i < kWarps; ++i) { + buckets[i][threadIdx.x] = 0; + } + + __syncthreads(); + + uint32_t* warpBucket = buckets[warpId]; + + // If the size of batch is smaller than the increment for alignment, we only + // handle the batch + auto roundUp4 = min(size, getAlignmentRoundUp(in)); + + // The size of data that remains after alignment + auto remaining = size - roundUp4; + + // The size of data (in uint4 words) that we can process with alignment + uint32_t numU4 = divDown(remaining, sizeof(uint4)); + + auto inAligned = in + roundUp4; + auto inAligned4 = (const uint4*)inAligned; + + // Handle the non-aligned portion that we have to load as single words, if any + if (blockIdx.x == 0 && threadIdx.x < roundUp4) { + static_assert(sizeof(uint4) <= Threads, ""); + atomicAdd(&warpBucket[in[threadIdx.x]], 1); + } + + // Handle the portion that is aligned and uint4 vectorizable + // 37.60 us / 80.76% gmem / 51.29% smem for uint4 on A100 + for (uint32_t i = blockIdx.x * Threads + threadIdx.x; i < numU4; + i += gridDim.x * Threads) { + uint4 v = inAligned4[i]; + + { + uint32_t x = v.x; + atomicAdd(&warpBucket[x & 0xff], 1); + x >>= 8; + atomicAdd(&warpBucket[x & 0xff], 1); + x >>= 8; + atomicAdd(&warpBucket[x & 0xff], 1); + x >>= 8; + atomicAdd(&warpBucket[x], 1); + } + + { + uint32_t y = v.y; + atomicAdd(&warpBucket[y & 0xff], 1); + y >>= 8; + atomicAdd(&warpBucket[y & 0xff], 1); + y >>= 8; + atomicAdd(&warpBucket[y & 0xff], 1); + y >>= 8; + atomicAdd(&warpBucket[y], 1); + } + + { + uint32_t z = v.z; + atomicAdd(&warpBucket[z & 0xff], 1); + z >>= 8; + atomicAdd(&warpBucket[z & 0xff], 1); + z >>= 8; + atomicAdd(&warpBucket[z & 0xff], 1); + z >>= 8; + atomicAdd(&warpBucket[z], 1); + } + + { + uint32_t w = v.w; + atomicAdd(&warpBucket[w & 0xff], 1); + w >>= 8; + atomicAdd(&warpBucket[w & 0xff], 1); + w >>= 8; + atomicAdd(&warpBucket[w & 0xff], 1); + w >>= 8; + atomicAdd(&warpBucket[w], 1); + } + } + + if (blockIdx.x == 0) { + // Handle the remainder portion that doesn't comprise full words + int i = numU4 * sizeof(uint4) + threadIdx.x; + if (i < remaining) { + atomicAdd(&warpBucket[inAligned[i]], 1); + } + } + + __syncthreads(); + + uint32_t sum = buckets[0][threadIdx.x]; +#pragma unroll + for (int j = 1; j < kWarps; ++j) { + sum += buckets[j][threadIdx.x]; + } + + // The count for the thread's bucket could be 0 + if (sum) { + atomicAdd(&out[threadIdx.x], sum); + } +} + +template +__global__ void histogramBatch(InProvider in, uint32_t* out) { + auto batch = blockIdx.y; + out += batch * kNumSymbols; + + histogramSingle( + (const ANSDecodedT*)in.getBatchStart(batch), in.getBatchSize(batch), out); +} + +// sum that allows passing in smem for usage, so as to avoid a trailing +// syncthreads and associated latency +template +__device__ inline int +blockSum(int warpId, int laneId, int valForSum, int* smem) { + static_assert(isEvenDivisor(Threads, kWarpSize), ""); + constexpr int kWarps = Threads / kWarpSize; + + auto allSum = warpReduceAllSum(valForSum); + + if (laneId == 0) { + smem[warpId] = allSum; + } + __syncthreads(); + + if (warpId == 0) { + int v = laneId < kWarps ? smem[laneId] : 0; + v = warpReduceAllSum(v); + + if (laneId == 0) { + smem[0] = v; + } + } + + __syncthreads(); + + // trailing syncthreads is elsewhere + return smem[0]; +} + +// Function that allows normalization of symbol probabilities with a varying +// (statically known) number of threads, to allow for kernel fusion as needed +// Stand-alone normalization will use Threads == kNumSymbols (256) +template +__device__ void normalizeProbabilitiesFromHistogram( + // Size 256 histogram in gmem + const uint32_t* __restrict__ counts, + uint32_t totalNum, + int probBits, + uint4* __restrict__ table) { + static_assert( + kNumSymbols == Threads || isEvenDivisor(kNumSymbols, uint32_t(Threads)), + ""); + + constexpr int kNumSymPerThread = + kNumSymbols == Threads ? 1 : (kNumSymbols / Threads); + + // There's nothing to do if the input array in the batch was of zero size + if (totalNum == 0) { + return; + } + + constexpr int kWarps = Threads / kWarpSize; + uint32_t kProbWeight = 1 << probBits; + auto tid = threadIdx.x; + int warpId = tid / kWarpSize; + int laneId = getLaneId(); + + // Load the current count and compute the min/max non-zero values, then + // perform an approximate quantization + uint32_t qProb[kNumSymPerThread]; + + int qProbSum = 0; + +#pragma unroll + for (int i = 0; i < kNumSymPerThread; ++i) { + int curSym = i * Threads + tid; + uint32_t count = counts[curSym]; + + // Rough initial quantization + qProb[i] = kProbWeight * ((float)count / (float)totalNum); + + // All weights for symbols present must be > 0 + qProb[i] = (count > 0 && qProb[i] == 0) ? 1 : qProb[i]; + + qProbSum += qProb[i]; + } + + // Sum qProbSym across all threads + __shared__ int smemSum[kWarps]; + qProbSum = blockSum(warpId, laneId, qProbSum, smemSum); + + // In order to use radix sorting, and also in order to only sort a single + // word, pack both the weight and index into a single integer + uint32_t sortedPair[kNumSymPerThread]; + +#pragma unroll + for (int i = 0; i < kNumSymPerThread; ++i) { + int curSym = i * Threads + tid; + sortedPair[i] = (qProb[i] << 16) | curSym; + } + + // The sort assumes a blocked arrangement as input, which we don't have, but + // this doesn't matter as we only care about the arrangement post-sort + using Sort = cub::BlockRadixSort; + __shared__ typename Sort::TempStorage smemSort; + Sort(smemSort).SortDescending(sortedPair); + + // The (prob, symbol) pair that each thread is considered to + // hold is the following: + uint32_t tidSymbol[kNumSymPerThread]; + + // Recover the values +#pragma unroll + for (int i = 0; i < kNumSymPerThread; ++i) { + tidSymbol[i] = sortedPair[i] & 0xffffU; + qProb[i] = sortedPair[i] >> 16; + } + + // How far below (positive) or above (negative) our current first-pass + // quantization is from our target sum 2^probBits + int diff = (int)kProbWeight - (int)qProbSum; + + if (diff > 0) { + // We are below our total sum target; add 1 to largest values + // FIXME: use div/mod to avoid iterations + while (diff > 0) { + int iterToApply = diff < kNumSymbols ? diff : kNumSymbols; + +#pragma unroll + for (int i = 0; i < kNumSymPerThread; ++i) { + int curSym = tidSymbol[i]; + if (curSym < iterToApply) { + qProb[i] += 1; + } + } + + diff -= iterToApply; + } + } else if (diff < 0) { + // We are above our total sum target; subtract 1 from the smallest values + // that are > 1 (all symbols with a weight of 1 cannot go to zero as they + // are assumed present in the input) + diff = -diff; + + while (diff > 0) { + // Need to determine the number of + int qNumGt1s = 0; + +#pragma unroll + for (int i = 0; i < kNumSymPerThread; ++i) { + qNumGt1s += (int)(qProb[i] > 1); + } + + // We need to determine the remaining number of >1 values + // We reuse smemSum but there is a syncthreads in the sort above, and at + // the end of the loop below + qNumGt1s = blockSum(warpId, laneId, qNumGt1s, smemSum); + __syncthreads(); // FIXME: not needed? + + // subtract from smallest >1 values + // This should be the index of the first 1 value + // FIXME: use div/mod to avoid iterations + int iterToApply = diff < qNumGt1s ? diff : qNumGt1s; + assert(iterToApply > 0); + int startIndex = qNumGt1s - iterToApply; + +#pragma unroll + for (int i = 0; i < kNumSymPerThread; ++i) { + // Post-sort, the data is in a blocked arrangement + int curSym = tid * kNumSymPerThread + i; + if (curSym >= startIndex && curSym < qNumGt1s) { + qProb[i] -= 1; + } + } + + diff -= iterToApply; + + __syncthreads(); + } + } + + // Recover the pre-sort order + __shared__ uint32_t smemPdf[kNumSymbols]; + +#pragma unroll + for (int i = 0; i < kNumSymPerThread; ++i) { + smemPdf[tidSymbol[i]] = qProb[i]; + } + + __syncthreads(); + + // NOTE: we need to have a contiguous blocked arrangement for cub::BlockScan + // when kNumSymPerThread > 1, so the order is now tid * kNumSymPerThread + reg + uint32_t symPdf[kNumSymPerThread]; +#pragma unroll + for (int i = 0; i < kNumSymPerThread; ++i) { + int curSym = tid * kNumSymPerThread + i; + symPdf[i] = smemPdf[curSym]; + } + + using Scan = cub::BlockScan; + __shared__ typename Scan::TempStorage smemScan; + + // FIXME: initialize to 0? + uint32_t symCdf[kNumSymPerThread]; + Scan(smemScan).ExclusiveSum(symPdf, symCdf); + + // Compute divisor information (constant division via integer + // multiplication + shift) + uint32_t shift[kNumSymPerThread]; + uint32_t magic[kNumSymPerThread]; + +#pragma unroll + for (int i = 0; i < kNumSymPerThread; ++i) { + shift[i] = 32 - __clz(symPdf[i] - 1); + + constexpr uint64_t one = 1; + uint64_t magic64 = + ((one << 32) * ((one << shift[i]) - symPdf[i])) / symPdf[i] + 1; + + // should not overflow + magic[i] = (uint32_t)magic64; + } + +#pragma unroll + for (int i = 0; i < kNumSymPerThread; ++i) { + // Same blocked contiguous ordering as before + // Note that this is no longer a coalesced write + int curSym = tid * kNumSymPerThread + i; + table[curSym] = uint4{symPdf[i], symCdf[i], magic[i], shift[i]}; + } +} + +template +__global__ void quantizeWeights( + const uint32_t* __restrict__ counts, + SizeProvider sizeProvider, + int probBits, + uint4* __restrict__ table) { + auto batch = blockIdx.x; + + normalizeProbabilitiesFromHistogram( + counts + batch * kNumSymbols, + sizeProvider.getBatchSize(batch), + probBits, + table + batch * kNumSymbols); +} + +template +void ansHistogramBatch( + uint32_t numInBatch, + InProvider inProvider, + // size numInBatch * kNumSymbols + uint32_t* histogram_dev, + cudaStream_t stream) { + // 1. Compute symbol histogram + // zero out buckets before proceeding, as we aggregate with atomic adds + CUDA_VERIFY(cudaMemsetAsync( + histogram_dev, 0, sizeof(uint32_t) * kNumSymbols * numInBatch, stream)); + + { + constexpr uint32_t kThreads = kNumSymbols; + + // What is the maximum number of blocks to saturate the GPU? + int maxBlocks = 0; + CUDA_VERIFY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &maxBlocks, histogramBatch, kThreads, 0)); + maxBlocks *= getCurrentDeviceProperties().multiProcessorCount; + + // The y block dimension will be for each batch element + uint32_t xBlocks = divUp(maxBlocks, numInBatch); + auto grid = dim3(xBlocks, numInBatch); + + histogramBatch + <<>>(inProvider, histogram_dev); + } +} + +template +inline void ansCalcWeights( + uint32_t numInBatch, + int probBits, + // we only use this for sizes (of each input batch member) + SizeProvider sizeProvider, + // size numInBatch * kNumSymbols + const uint32_t* histogram_dev, + // size numInBatch * kNumSymbols + uint4* table_dev, + cudaStream_t stream) { + // Quantize weights and determine integer ANS division factors + constexpr int kThreads = kNumSymbols; + + quantizeWeights<<>>( + histogram_dev, sizeProvider, probBits, table_dev); +} + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/dietgpu/ans/GpuANSUtils.cuh b/thirdparty/dietgpu/dietgpu/ans/GpuANSUtils.cuh new file mode 100644 index 000000000..6fe6f92c9 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/ans/GpuANSUtils.cuh @@ -0,0 +1,233 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include "dietgpu/utils/DeviceDefs.cuh" +#include "dietgpu/utils/StaticUtils.h" + +namespace dietgpu { + +using ANSStateT = uint32_t; +using ANSEncodedT = uint16_t; +using ANSDecodedT = uint8_t; + +struct __align__(16) ANSDecodedTx16 { + ANSDecodedT x[16]; +}; + +struct __align__(8) ANSDecodedTx8 { + ANSDecodedT x[8]; +}; + +struct __align__(4) ANSDecodedTx4 { + ANSDecodedT x[4]; +}; + +constexpr uint32_t kNumSymbols = 1 << (sizeof(ANSDecodedT) * 8); +static_assert(kNumSymbols > 1, ""); + +// Default block size for compression (in bytes) +constexpr uint32_t kDefaultBlockSize = 4096; + +// limit state to 2^31 - 1, so as to prevent addition overflow in the integer +// division via mul and shift by constants +constexpr int kANSStateBits = sizeof(ANSStateT) * 8 - 1; +constexpr int kANSEncodedBits = sizeof(ANSEncodedT) * 8; // out bits +constexpr ANSStateT kANSEncodedMask = + (ANSStateT(1) << kANSEncodedBits) - ANSStateT(1); + +constexpr ANSStateT kANSStartState = ANSStateT(1) + << (kANSStateBits - kANSEncodedBits); +constexpr ANSStateT kANSMinState = ANSStateT(1) + << (kANSStateBits - kANSEncodedBits); + +// magic number to verify archive integrity +constexpr uint32_t kANSMagic = 0xd00d; + +// current DietGPU version number +constexpr uint32_t kANSVersion = 0x0001; + +// Each block of compressed data (either coalesced or uncoalesced) is aligned to +// this number of bytes and has a valid (if not all used) segment with this +// multiple of bytes +constexpr uint32_t kBlockAlignment = 16; + +struct ANSWarpState { + // The ANS state data for this warp + ANSStateT warpState[kWarpSize]; +}; + +struct __align__(32) ANSCoalescedHeader { + static __host__ __device__ uint32_t getCompressedOverhead( + uint32_t numBlocks) { + constexpr int kAlignment = kBlockAlignment / sizeof(uint2) == 0 + ? 1 + : kBlockAlignment / sizeof(uint2); + + return sizeof(ANSCoalescedHeader) + + // probs + sizeof(uint16_t) * kNumSymbols + + // states + sizeof(ANSWarpState) * numBlocks + + // block words + sizeof(uint2) * roundUp(numBlocks, kAlignment); + } + + __host__ __device__ uint32_t getTotalCompressedSize() const { + return getCompressedOverhead() + + getTotalCompressedWords() * sizeof(ANSEncodedT); + } + + __host__ __device__ uint32_t getCompressedOverhead() const { + return getCompressedOverhead(getNumBlocks()); + } + + __host__ __device__ float getCompressionRatio() const { + return (float)getTotalCompressedSize() / + (float)getTotalUncompressedWords() * sizeof(ANSDecodedT); + } + + __host__ __device__ uint32_t getNumBlocks() const { + return numBlocks; + } + + __host__ __device__ void setNumBlocks(uint32_t nb) { + numBlocks = nb; + } + + __host__ __device__ void setMagicAndVersion() { + magicAndVersion = (kANSMagic << 16) | kANSVersion; + } + + __host__ __device__ void checkMagicAndVersion() const { + assert((magicAndVersion >> 16) == kANSMagic); + assert((magicAndVersion & 0xffffU) == kANSVersion); + } + + __host__ __device__ uint32_t getTotalUncompressedWords() const { + return totalUncompressedWords; + } + + __host__ __device__ void setTotalUncompressedWords(uint32_t words) { + totalUncompressedWords = words; + } + + __host__ __device__ uint32_t getTotalCompressedWords() const { + return totalCompressedWords; + } + + __host__ __device__ void setTotalCompressedWords(uint32_t words) { + totalCompressedWords = words; + } + + __host__ __device__ uint32_t getProbBits() const { + return options & 0xf; + } + + __host__ __device__ void setProbBits(uint32_t bits) { + assert(bits <= 0xf); + options = (options & 0xfffffff0U) | bits; + } + + __host__ __device__ bool getUseChecksum() const { + return options & 0x10; + } + + __host__ __device__ void setUseChecksum(bool uc) { + options = (options & 0xffffffef) | (uint32_t(uc) << 4); + } + + __host__ __device__ uint32_t getChecksum() const { + return checksum; + } + + __host__ __device__ void setChecksum(uint32_t c) { + checksum = c; + } + + __device__ uint16_t* getSymbolProbs() { + return (uint16_t*)(this + 1); + } + + __device__ const uint16_t* getSymbolProbs() const { + return (const uint16_t*)(this + 1); + } + + __device__ ANSWarpState* getWarpStates() { + return (ANSWarpState*)(getSymbolProbs() + kNumSymbols); + } + + __device__ const ANSWarpState* getWarpStates() const { + return (const ANSWarpState*)(getSymbolProbs() + kNumSymbols); + } + + __device__ uint2* getBlockWords(uint32_t numBlocks) { + // All of the ANSWarpStates are already kBlockAlignment aligned + return (uint2*)(getWarpStates() + numBlocks); + } + + __device__ const uint2* getBlockWords(uint32_t numBlocks) const { + // All of the ANSWarpStates are already kBlockAlignment aligned + return (const uint2*)(getWarpStates() + numBlocks); + } + + __device__ ANSEncodedT* getBlockDataStart(uint32_t numBlocks) { + constexpr int kAlignment = kBlockAlignment / sizeof(uint2) == 0 + ? 1 + : kBlockAlignment / sizeof(uint2); + + return (ANSEncodedT*)(getBlockWords(numBlocks) + + roundUp(numBlocks, kAlignment)); + } + + __device__ const ANSEncodedT* getBlockDataStart(uint32_t numBlocks) const { + constexpr int kAlignment = kBlockAlignment / sizeof(uint2) == 0 + ? 1 + : kBlockAlignment / sizeof(uint2); + + return (const ANSEncodedT*)(getBlockWords(numBlocks) + + roundUp(numBlocks, kAlignment)); + } + + // (16: magic)(16: version) + uint32_t magicAndVersion; + uint32_t numBlocks; + uint32_t totalUncompressedWords; + uint32_t totalCompressedWords; + + // (27: unused)(1: use checksum)(4: probBits) + uint32_t options; + uint32_t checksum; + uint32_t unused0; + uint32_t unused1; + + // Data that follows after the header (some of which is variable length): + + // Fixed length array + // uint16_t probs[kNumSymbols]; + + // Variable length array: + // ANSWarpState states[numBlocks]; + + // Per-block information: + // (uint16: uncompressedWords, uint16: compressedWords) + // uint32: blockCompressedWordStart + // + // Variable length array: + // uint2 blockWords[roundUp(numBlocks, kBlockAlignment / sizeof(uint2))]; + + // Then follows the compressed per-warp/block data for each segment +}; + +static_assert(sizeof(ANSCoalescedHeader) == 32, ""); + +static_assert(isEvenDivisor(sizeof(ANSCoalescedHeader), sizeof(uint4)), ""); + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/dietgpu/ans/GpuChecksum.cuh b/thirdparty/dietgpu/dietgpu/ans/GpuChecksum.cuh new file mode 100644 index 000000000..754e367aa --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/ans/GpuChecksum.cuh @@ -0,0 +1,135 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include "dietgpu/utils/DeviceDefs.cuh" +#include "dietgpu/utils/DeviceUtils.h" +#include "dietgpu/utils/PtxUtils.cuh" +#include "dietgpu/utils/StaticUtils.h" + +#include + +namespace dietgpu { + +template +struct ReduceXor { + __host__ __device__ __forceinline__ T + operator()(const T& a, const T& b) const { + return a ^ b; + } +}; + +template +__device__ void checksumSingle( + const uint8_t* __restrict__ in, + uint32_t size, + uint32_t* __restrict__ out) { + // FIXME: general presumption in dietgpu that input data for ANS is only byte + // aligned, while float data is only float word aligned, whereas ideally we + // would like a 32 bit checksum. Since there is ultimately no guarantee of + // anything but byte alignment and we wish to compute the same checksum + // regardless of memory placement, the only checksum that makes sense to + // produce is uint8. + // We can fix this to compute a full 32-bit checksum by keeping track of + // initial alignment and shuffling data around I think. + uint32_t checksum32 = 0; + + // If the size of batch is smaller than the increment for alignment, we only + // handle the batch + auto roundUp4 = min(size, getAlignmentRoundUp(in)); + + // The size of data that remains after alignment + auto remaining = size - roundUp4; + + // The size of data (in uint4 words) that we can process with alignment + uint32_t numU4 = divDown(remaining, sizeof(uint4)); + + auto inAligned = in + roundUp4; + auto inAligned4 = (const uint4*)inAligned; + + // Handle the non-aligned portion that we have to load as single bytes, if any + if (blockIdx.x == 0 && threadIdx.x < roundUp4) { + static_assert(sizeof(uint4) <= Threads, ""); + checksum32 ^= in[threadIdx.x]; + } + + // Handle the portion that is aligned and uint4 vectorizable + // 37.60 us / 80.76% gmem / 51.29% smem for uint4 on A100 + for (uint32_t i = blockIdx.x * Threads + threadIdx.x; i < numU4; + i += gridDim.x * Threads) { + uint4 v = inAligned4[i]; + + checksum32 ^= v.x; + checksum32 ^= v.y; + checksum32 ^= v.z; + checksum32 ^= v.w; + } + + if (blockIdx.x == 0) { + // Handle the remainder portion that doesn't comprise full words + int i = numU4 * sizeof(uint4) + threadIdx.x; + if (i < remaining) { + checksum32 ^= inAligned[i]; + } + } + + // Fold the bytes of checksum32 + checksum32 = (checksum32 & 0xffU) ^ ((checksum32 >> 8) & 0xffU) ^ + ((checksum32 >> 16) & 0xffU) ^ ((checksum32 >> 24) & 0xffU); + + // Reduce within a warp + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage smem; + + checksum32 = BlockReduce(smem).Reduce(checksum32, ReduceXor()); + + if (threadIdx.x == 0) { + atomicXor(out, checksum32); + } +} + +template +__global__ void checksumBatch(InProvider in, uint32_t* out) { + auto batch = blockIdx.y; + out += batch; + + checksumSingle( + (const uint8_t*)in.getBatchStart(batch), in.getBatchSize(batch), out); +} + +template +void checksumBatch( + uint32_t numInBatch, + InProvider inProvider, + // size numInBatch + uint32_t* checksum_dev, + cudaStream_t stream) { + // zero out checksum before proceeding, as we aggregate with atomic xor + CUDA_VERIFY( + cudaMemsetAsync(checksum_dev, 0, sizeof(uint32_t) * numInBatch, stream)); + + constexpr uint32_t kThreads = 256; + + // We unfortunately don't know the per-batch element sizes in advance + // What is the maximum number of blocks to saturate the GPU just in case some + // per-batch members are big? + int maxBlocks = 0; + CUDA_VERIFY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &maxBlocks, checksumBatch, kThreads, 0)); + maxBlocks *= getCurrentDeviceProperties().multiProcessorCount; + + // The y block dimension will be for each batch element + uint32_t xBlocks = divUp(maxBlocks, numInBatch); + auto grid = dim3(xBlocks, numInBatch); + + checksumBatch + <<>>(inProvider, checksum_dev); + + CUDA_TEST_ERROR(); +} + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/dietgpu/ans_test.py b/thirdparty/dietgpu/dietgpu/ans_test.py new file mode 100644 index 000000000..7fa5daab9 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/ans_test.py @@ -0,0 +1,139 @@ +# Copyright (c) (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +import random +import unittest + +import torch + +torch.ops.load_library("//dietgpu:dietgpu") + + +def run_test(dev, ts, checksum, temp_mem=None): + comp, sizes, _ = torch.ops.dietgpu.compress_data(False, ts, checksum, temp_mem) + for s, t in zip(sizes, ts): + t_bytes = t.numel() * t.element_size() + print(f"{t_bytes} bytes -> {s.item()} bytes ({s.item() / t_bytes}x)") + + # Truncate the output data to exactly the sizes that are used + # (i.e., we are testing that the byte sizes we report in compression are accurate) + truncated_comp = [] + for size, t in zip(sizes, [*comp]): + truncated_t = t.narrow(0, 0, size.item()).clone() + truncated_comp.append(truncated_t) + + out_ts = [] + for t in ts: + out_ts.append(torch.empty(t.size(), dtype=t.dtype, device=t.device)) + + if temp_mem is not None: + out_status = torch.empty([len(ts)], dtype=torch.uint8, device=dev) + out_sizes = torch.empty([len(ts)], dtype=torch.int32, device=dev) + + torch.ops.dietgpu.decompress_data( + False, truncated_comp, out_ts, checksum, temp_mem, out_status, out_sizes + ) + + for t, status, size in zip(ts, out_status, out_sizes): + assert status.item() + assert t.numel() * t.element_size() == size.item() + else: + torch.ops.dietgpu.decompress_data(False, truncated_comp, out_ts, checksum) + + for a, b in zip(ts, out_ts): + assert torch.equal(a, b) + + +class TestANSCodec(unittest.TestCase): + def test_codec(self): + dev = torch.device("cuda:0") + temp_mem = torch.empty([64 * 1024 * 1024], dtype=torch.uint8, device=dev) + + for dt in [torch.float32]: + for tm in [False, True]: + for checksum in [False, True]: + ts = [ + torch.normal(0, 1.0, [10000], dtype=dt, device=dev), + torch.normal(0, 1.0, [100000], dtype=dt, device=dev), + torch.normal(0, 1.0, [1000000], dtype=dt, device=dev), + ] + if tm: + run_test(dev, ts, checksum, temp_mem) + else: + run_test(dev, ts, checksum) + + def test_empty(self): + dev = torch.device("cuda:0") + ts = [torch.empty([0], dtype=torch.uint8, device=dev)] + comp_ts = torch.ops.dietgpu.compress_data_simple(False, ts, True) + + # should have a header + assert comp_ts[0].numel() > 0 + + decomp_ts = torch.ops.dietgpu.decompress_data_simple(False, comp_ts, True) + assert torch.equal(ts[0], decomp_ts[0]) + + def test_split_compress(self): + dev = torch.device("cuda:0") + temp_mem = torch.empty([64 * 1024 * 1024], dtype=torch.uint8, device=dev) + + for tries in range(5): + batch_size = random.randrange(1, 15) + sizes = [] + + sum_sizes = 0 + max_size = 0 + for i in range(batch_size): + size = random.randrange(1, 10000) + # meet required alignment + size += 4 - (size % 4) + + sizes.append(size) + sum_sizes += size + if size > max_size: + max_size = size + + t = torch.randint(0, 65, [sum_sizes], dtype=torch.uint8, device=dev) + sizes_t = torch.IntTensor(sizes) + splits = torch.split(t, sizes) + + comp_ts, _, _ = torch.ops.dietgpu.compress_data_split_size( + False, t, sizes_t, True, temp_mem + ) + decomp_ts = torch.ops.dietgpu.decompress_data_simple(False, comp_ts, True) + + for orig, decomp in zip(splits, decomp_ts): + assert torch.equal(orig, decomp) + + def test_split_decompress(self): + dev = torch.device("cuda:0") + temp_mem = torch.empty([64 * 1024 * 1024], dtype=torch.uint8, device=dev) + + for tries in range(5): + batch_size = random.randrange(1, 15) + sizes = [] + + sum_sizes = 0 + for i in range(batch_size): + size = random.randrange(1, 10000) + # meet required alignment + size += 4 - (size % 4) + + sizes.append(size) + sum_sizes += size + + t = torch.randint(0, 65, [sum_sizes], dtype=torch.uint8, device=dev) + sizes_t = torch.IntTensor(sizes) + + splits = torch.split(t, sizes) + comp_ts = torch.ops.dietgpu.compress_data_simple(False, splits, True) + + decomp_t = torch.empty([sum_sizes], dtype=torch.uint8, device=dev) + torch.ops.dietgpu.decompress_data_split_size( + False, comp_ts, decomp_t, sizes_t, True, temp_mem + ) + + assert torch.equal(t, decomp_t) diff --git a/thirdparty/dietgpu/dietgpu/benchmark.py b/thirdparty/dietgpu/dietgpu/benchmark.py new file mode 100644 index 000000000..807718a55 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/benchmark.py @@ -0,0 +1,223 @@ +# Copyright (c) (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# Simple benchmarking script for both float and raw byte-wise ANS codecs in +# PyTorch using the asynchronous API, as applied to floating point data +# ~ N(0, 1) + +import torch + +torch.ops.load_library("/dietgpu/build/lib/libdietgpu.so") +dev = torch.device("cuda:0") + + +def calc_comp_ratio(input_ts, out_sizes): + total_input_size = 0 + total_comp_size = 0 + + for t, s in zip(input_ts, out_sizes): + total_input_size += t.numel() * t.element_size() + total_comp_size += s + + return total_input_size, total_comp_size, total_comp_size / total_input_size + + +def get_float_comp_timings(ts, num_runs=3): + tempMem = torch.empty([384 * 1024 * 1024], dtype=torch.uint8, device=dev) + + comp_time = 0 + decomp_time = 0 + total_size = 0 + comp_size = 0 + + # ignore first run timings + for i in range(1 + num_runs): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + rows, cols = torch.ops.dietgpu.max_float_compressed_output_size(ts) + + comp = torch.empty([rows, cols], dtype=torch.uint8, device=dev) + sizes = torch.zeros([len(ts)], dtype=torch.int, device=dev) + + start.record() + comp, sizes, memUsed = torch.ops.dietgpu.compress_data( + True, ts, False, tempMem, comp, sizes + ) + end.record() + + comp_size = 0 + + torch.cuda.synchronize() + if i > 0: + comp_time += start.elapsed_time(end) + + total_size, comp_size, _ = calc_comp_ratio(ts, sizes) + + out_ts = [] + for t in ts: + out_ts.append(torch.empty(t.size(), dtype=t.dtype, device=t.device)) + + # this takes a while + comp_ts = [*comp] + + out_status = torch.empty([len(ts)], dtype=torch.uint8, device=dev) + out_sizes = torch.empty([len(ts)], dtype=torch.int32, device=dev) + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + torch.ops.dietgpu.decompress_data( + True, comp_ts, out_ts, False, tempMem, out_status, out_sizes + ) + end.record() + + torch.cuda.synchronize() + if i > 0: + decomp_time += start.elapsed_time(end) + + # validate + for a, b in zip(ts, out_ts): + assert torch.equal(a, b) + + comp_time /= num_runs + decomp_time /= num_runs + + return comp_time, decomp_time, total_size, comp_size + + +def get_any_comp_timings(ts, num_runs=3): + tempMem = torch.empty([384 * 1024 * 1024], dtype=torch.uint8, device=dev) + + comp_time = 0 + decomp_time = 0 + total_size = 0 + comp_size = 0 + + # ignore first run timings + for i in range(1 + num_runs): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + rows, cols = torch.ops.dietgpu.max_any_compressed_output_size(ts) + + comp = torch.empty([rows, cols], dtype=torch.uint8, device=dev) + sizes = torch.zeros([len(ts)], dtype=torch.int, device=dev) + + start.record() + comp, sizes, memUsed = torch.ops.dietgpu.compress_data( + False, ts, False, tempMem, comp, sizes + ) + end.record() + + comp_size = 0 + + torch.cuda.synchronize() + comp_time = start.elapsed_time(end) + + total_size, comp_size, _ = calc_comp_ratio(ts, sizes) + + out_ts = [] + for t in ts: + out_ts.append(torch.empty(t.size(), dtype=t.dtype, device=t.device)) + + # this takes a while + comp_ts = [*comp] + + out_status = torch.empty([len(ts)], dtype=torch.uint8, device=dev) + out_sizes = torch.empty([len(ts)], dtype=torch.int32, device=dev) + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + torch.ops.dietgpu.decompress_data( + False, comp_ts, out_ts, False, tempMem, out_status, out_sizes + ) + end.record() + + torch.cuda.synchronize() + decomp_time = start.elapsed_time(end) + + for a, b in zip(ts, out_ts): + assert torch.equal(a, b) + + return comp_time, decomp_time, total_size, comp_size + + +for dt in [torch.bfloat16, torch.float16, torch.float32]: + # Non-batched + ts = [] + ts.append(torch.normal(0, 1.0, [128 * 512 * 1024], dtype=dt, device=dev)) + + c, dc, total_size, comp_size = get_float_comp_timings(ts) + ratio = comp_size / total_size + c_bw = (total_size / 1e9) / (c * 1e-3) + dc_bw = (total_size / 1e9) / (dc * 1e-3) + + print(f"Float codec non-batched perf [128 * 512 * 1024] {dt}") + print( + "comp time {:.3f} ms B/W {:.1f} GB/s, compression {} -> {} bytes ({:.4f}x) ".format( + c, c_bw, total_size, comp_size, ratio + ) + ) + print(f"decomp time {dc:.3f} ms B/W {dc_bw:.1f} GB/s") + + # Batched + ts = [] + for i in range(128): + ts.append(torch.normal(0, 1.0, [512 * 1024], dtype=dt, device=dev)) + + c, dc, total_size, comp_size = get_float_comp_timings(ts) + ratio = comp_size / total_size + bw = (total_size / 1e9) / (c * 1e-3) + dc_bw = (total_size / 1e9) / (dc * 1e-3) + + print(f"Float codec batched perf [128, [512 * 1024]] {dt}") + print( + "comp time {:.3f} ms B/W {:.1f} GB/s, compression {} -> {} bytes ({:.4f}x) ".format( + c, c_bw, total_size, comp_size, ratio + ) + ) + print(f"decomp time {dc:.3f} ms B/W {dc_bw:.1f} GB/s") + +print("\n") + +for dt in [torch.bfloat16, torch.float16, torch.float32]: + # Non-batched + ts = [] + ts.append(torch.normal(0, 1.0, [128 * 512 * 1024], dtype=dt, device=dev)) + + c, dc, total_size, comp_size = get_any_comp_timings(ts) + ratio = comp_size / total_size + c_bw = (total_size / 1e9) / (c * 1e-3) + dc_bw = (total_size / 1e9) / (dc * 1e-3) + + print(f"Raw ANS byte-wise non-batched perf [128 * 512 * 1024] {dt}") + print( + "comp time {:.3f} ms B/W {:.1f} GB/s, compression {} -> {} bytes ({:.4f}x) ".format( + c, c_bw, total_size, comp_size, ratio + ) + ) + print(f"decomp time {dc:.3f} ms B/W {dc_bw:.1f} GB/s") + + # Batched + ts = [] + for i in range(128): + ts.append(torch.normal(0, 1.0, [512 * 1024], dtype=dt, device=dev)) + + c, dc, total_size, comp_size = get_any_comp_timings(ts) + ratio = comp_size / total_size + c_bw = (total_size / 1e9) / (c * 1e-3) + dc_bw = (total_size / 1e9) / (dc * 1e-3) + + print(f"Raw ANS byte-wise batched perf [128, [512 * 1024]] {dt}") + print( + "comp time {:.3f} ms B/W {:.1f} GB/s, compression {} -> {} bytes ({:.4f}x) ".format( + c, c_bw, total_size, comp_size, ratio + ) + ) + print(f"decomp time {dc:.3f} ms B/W {dc_bw:.1f} GB/s") diff --git a/thirdparty/dietgpu/dietgpu/float/CMakeLists.txt b/thirdparty/dietgpu/dietgpu/float/CMakeLists.txt new file mode 100644 index 000000000..9b506f83e --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/float/CMakeLists.txt @@ -0,0 +1,40 @@ +add_library(gpu_float_compress SHARED + GpuFloatCompress.cu + GpuFloatDecompress.cu + GpuFloatInfo.cu +) +add_dependencies(gpu_float_compress + gpu_ans + dietgpu_utils +) + +target_include_directories(gpu_float_compress PUBLIC + $ +) +target_link_libraries(gpu_float_compress PUBLIC + gpu_ans + dietgpu_utils +) +target_link_libraries(gpu_float_compress PRIVATE + glog::glog +) +target_compile_options(gpu_float_compress PRIVATE $<$: + --generate-line-info + #--device-debug +>) + +enable_testing() +include(GoogleTest) + +add_executable(float_test FloatTest.cu) +target_link_libraries(float_test + gpu_float_compress + gtest_main +) +gtest_discover_tests(float_test) + + +get_property(GLOBAL_CUDA_ARCHITECTURES GLOBAL PROPERTY CUDA_ARCHITECTURES) +set_target_properties(gpu_float_compress float_test PROPERTIES + CUDA_ARCHITECTURES "${GLOBAL_CUDA_ARCHITECTURES}" +) diff --git a/thirdparty/dietgpu/dietgpu/float/GpuFloatCodec.h b/thirdparty/dietgpu/dietgpu/float/GpuFloatCodec.h new file mode 100644 index 000000000..0acb7abaa --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/float/GpuFloatCodec.h @@ -0,0 +1,294 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include "dietgpu/ans/GpuANSCodec.h" + +namespace dietgpu { + +class StackDeviceMemory; + +// The various floating point types we support for compression +enum class FloatType : uint32_t { + kUndefined = 0, + kFloat16 = 1, + kBFloat16 = 2, + kFloat32 = 3, +}; + +// Returns the maximum possible compressed size in bytes of an array of `size` +// float words of type `floatType`. Note that this will in fact be larger than +// size * sizeof(the float word type), as if something is uncompressible it will +// be expanded during compression. +// This can be used to bound memory consumption for the destination compressed +// buffer +uint32_t getMaxFloatCompressedSize(FloatType floatType, uint32_t size); + +struct FloatCodecConfig { + inline FloatCodecConfig() + : floatType(FloatType::kFloat16), + useChecksum(false), + is16ByteAligned(false) {} + + inline FloatCodecConfig( + FloatType ft, + const ANSCodecConfig& ansConf, + bool align, + bool checksum = false) + : floatType(ft), + useChecksum(checksum), + ansConfig(ansConf), + is16ByteAligned(align) { + // ANS-level checksumming is not allowed in float mode, only float level + // checksumming + assert(!ansConf.useChecksum); + } + + // What kind of floats are we compressing/decompressing? + FloatType floatType; + + // If true, we calculate a checksum on the uncompressed float input data to + // compression and store it in the archive, and on the decompression side + // post-decompression, we calculate a checksum on the decompressed data which + // is compared with the original stored in the archive. + // This is an optional feature useful if DietGPU data will be stored + // persistently on disk. + bool useChecksum; + + // ANS entropy coder parameters + // Checksumming will happen at the float level, not the ANS level, as + // decompression from ANS is immediately consumed by the float layer, + // so ansConfig.useChecksum being true is an error. + ANSCodecConfig ansConfig; + + // Are all all float input pointers/offsets (compress) or output + // pointers/offsets (decompress) are aligned to 16 bytes? + // + // If so, we can accelerate the decompression. If not, the float addresses + // should be aligned to the floating point word size (e.g., + // FloatType::kFloat16, all are assumed sizeof(float16) == 2 byte aligned) + bool is16ByteAligned; +}; + +// Same config options for compression and decompression for now +using FloatCompressConfig = FloatCodecConfig; +using FloatDecompressConfig = FloatCodecConfig; + +enum class FloatDecompressError : uint32_t { + None = 0, + ChecksumMismatch = 1, +}; + +// Error status for decompression +struct FloatDecompressStatus { + inline FloatDecompressStatus() : error(FloatDecompressError::None) {} + + // Overall error status + FloatDecompressError error; + + // Error-specific information for the batch + std::vector> errorInfo; +}; + +// +// Encode +// + +void floatCompress( + StackDeviceMemory& res, + // How should we compress our data? + const FloatCompressConfig& config, + + // Optional region of device temporary memory provided for our use + // Usage of this region of memory is ordered with respect to `stream`, + // so can be reused after execution of the kernels that we launch on + // that stream. + // If either nullptr is passed, or if the size is not sufficient for our + // needs, we will internally call cudaMalloc and cudaFree and will + // print warnings to stderr in this case. Providing a sufficient sized chunk + // of temp memory avoids the h2d synchronization overhead of + // cudaMalloc/cudaFree. + // The base address should be aligned to 16 bytes + // void* tempMem_dev, + // // The size in bytes of tempMem + // size_t tempMemBytes, + + // Number of separate, independent compression problems + uint32_t numInBatch, + + // Host array with addresses of device pointers comprising the batch + const void** in, + // Host array with sizes of batch members (in float words, NOT bytes) + const uint32_t* inSize, + + // Host array with addresses of device pointers of outputs, each pointing + // to a valid region of memory of at least size + // getMaxFloatCompressedSize(ft, inSize[i]) + void** out, + // Device memory array of size numInBatch (optional) + // Provides the size of actual used memory in bytes for each batch element + uint32_t* outSize_dev, + + // stream on the current device on which this runs + cudaStream_t stream); + +void floatCompressSplitSize( + StackDeviceMemory& res, + // How should we compress our data? + const FloatCompressConfig& config, + + // Number of separate, independent compression problems + uint32_t numInBatch, + + // Device pointer into a valid region of memory of size at least + // sum_i(inSplitSizes[i]) float words. + const void* in_dev, + + // Host array with the size (in floating point words) of the input + // floating point arrays in the batch. + // Each array in the batch is read starting at offset splitSize[i]. + const uint32_t* inSplitSizes, + + // Device pointer to a matrix of at least size + // numInBatch x getMaxFloatCompressedSize(ft, max(inSplitSizes[i])) + void* out_dev, + + // Stride between rows in bytes + uint32_t outStride, + + // Device memory array of size numInBatch (optional) + // Provides the size of actual used memory in bytes for each batch element + uint32_t* outSize_dev, + + // stream on the current device on which this runs + cudaStream_t stream); + +// +// Decode +// + +FloatDecompressStatus floatDecompress( + StackDeviceMemory& res, + // How should we decompress our data? + const FloatDecompressConfig& config, + // Number of separate, independent compression problems + uint32_t numInBatch, + + // Host array with addresses of device pointers comprising the batch + const void** in, + + // Host array with addresses of device pointers of outputs, each pointing + // to a valid region of memory of at least size outCapacity[i] + void** out, + // Host memory array of size numInBatch (optional) + // Provides the maximum amount of space present for decopressing each batch + // problem + const uint32_t* outCapacity, + + // Decode success/fail status (optional, can be nullptr) + // If present, this is a device pointer to an array of length numInBatch, + // with true/false for whether or not decompression status was successful + // FIXME: not bool due to issues with __nv_bool + uint8_t* outSuccess_dev, + + // Decode size status (optional, can be nullptr) + // If present, this is a device pointer to an array of length numInBatch, + // with either the size decompressed reported if successful, or the required + // size reported if our outPerBatchCapacity was insufficient. Size reported + // is in float words + uint32_t* outSize_dev, + + // stream on the current device on which this runs + cudaStream_t stream); + +FloatDecompressStatus floatDecompressSplitSize( + StackDeviceMemory& res, + // How should we decompress our data? + const FloatDecompressConfig& config, + // Number of separate, independent compression problems + uint32_t numInBatch, + + // Host array with addresses of device pointers comprising the batch + const void** in, + + // Device pointer into a valid region of memory of size at least + // sum_i(outSplitSizes[i]) float words + void* out_dev, + + // Host array with the size (in floating point words) of the output + // decompressed floating point arrays in the batch. + // Each decompressed array in the batch is written at offset + // outSplitSizes[i]. + // The decompressed size must match exactly these sizes, otherwise there's a + // decompression error + const uint32_t* outSplitSizes, + + // Decode success/fail status (optional, can be nullptr) + // If present, this is a device pointer to an array of length numInBatch, + // with true/false for whether or not decompression status was successful + // FIXME: not bool due to issues with __nv_bool + uint8_t* outSuccess_dev, + + // Decode size status (optional, can be nullptr) + // If present, this is a device pointer to an array of length numInBatch, + // with either the size decompressed reported if successful, or the required + // size reported if our outPerBatchCapacity was insufficient. Size reported + // is in float words + uint32_t* outSize_dev, + + // stream on the current device on which this runs + cudaStream_t stream); + +// +// Information +// + +void floatGetCompressedInfo( + StackDeviceMemory& res, + // Host array with addresses of device pointers comprising the batch of + // compressed float data + const void** in, + // Number of compressed arrays in the batch + uint32_t numInBatch, + // Optional device array to receive the resulting sizes. 0 is reported if + // the compresed data is not as expected, otherwise the size is reported in + // floating point words + uint32_t* outSizes_dev, + // Optional device array to receive the resulting FloatTypes. + // FloatType::kUndefined is reported if the compresed data is not as + // expected, otherwise the size is reported in floating point words + uint32_t* outTypes_dev, + // Optional device array to receive pre-compression checksums stored in the + // archive, if the checksum feature was enabled. + uint32_t* outChecksum_dev, + // stream on the current device on which this runs + cudaStream_t stream); + +void floatGetCompressedInfoDevice( + StackDeviceMemory& res, + // Device array with addresses of device pointers comprising the batch of + // compressed float data + const void** in_dev, + // Number of compressed arrays in the batch + uint32_t numInBatch, + // Optional device array to receive the resulting sizes. 0 is reported if + // the compresed data is not as expected, otherwise the size is reported in + // floating point words + uint32_t* outSizes_dev, + // Optional device array to receive the resulting FloatTypes. + // FloatType::kUndefined is reported if the compresed data is not as + // expected, otherwise the size is reported in floating point words + uint32_t* outTypes_dev, + // Optional device array to receive pre-compression checksums stored in the + // archive, if the checksum feature was enabled. + uint32_t* outChecksum_dev, + // stream on the current device on which this runs + cudaStream_t stream); + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/dietgpu/float/GpuFloatCompress.cu b/thirdparty/dietgpu/dietgpu/float/GpuFloatCompress.cu new file mode 100644 index 000000000..45cd22488 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/float/GpuFloatCompress.cu @@ -0,0 +1,161 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "dietgpu/ans/BatchProvider.cuh" +#include "dietgpu/float/GpuFloatCodec.h" +#include "dietgpu/float/GpuFloatCompress.cuh" +#include "dietgpu/float/GpuFloatUtils.cuh" +#include "dietgpu/utils/DeviceUtils.h" +#include "dietgpu/utils/StackDeviceMemory.h" +#include "dietgpu/utils/StaticUtils.h" + +#include +#include +#include +#include + +namespace dietgpu { + +uint32_t getMaxFloatCompressedSize(FloatType floatType, uint32_t size) { + // kNotCompressed bytes per float are simply stored uncompressed + // rounded up to 16 bytes to ensure alignment of the following ANS data + // portion + uint32_t baseSize = sizeof(GpuFloatHeader) + getMaxCompressedSize(size); + + switch (floatType) { + case FloatType::kFloat16: + baseSize += FloatTypeInfo::getUncompDataSize(size); + break; + case FloatType::kBFloat16: + baseSize += FloatTypeInfo::getUncompDataSize(size); + break; + case FloatType::kFloat32: + baseSize += FloatTypeInfo::getUncompDataSize(size); + break; + default: + CHECK(false); + break; + } + + return baseSize; +} + +void floatCompress( + StackDeviceMemory& res, + const FloatCompressConfig& config, + uint32_t numInBatch, + const void** in, + const uint32_t* inSize, + void** out, + uint32_t* outSize_dev, + cudaStream_t stream) { + // Get the total and maximum input size + uint32_t maxSize = 0; + + for (uint32_t i = 0; i < numInBatch; ++i) { + maxSize = std::max(maxSize, inSize[i]); + } + + // Copy data to device + // To reduce latency, we prefer to coalesce all data together and copy as one + // contiguous chunk + static_assert(sizeof(void*) == sizeof(uintptr_t), ""); + static_assert(sizeof(uint32_t) <= sizeof(uintptr_t), ""); + + // in, inSize, out + auto params_dev = res.alloc(stream, numInBatch * 3); + auto params_host = + std::unique_ptr(new uintptr_t[3 * numInBatch]); + + std::memcpy(¶ms_host[0], in, numInBatch * sizeof(void*)); + std::memcpy(¶ms_host[numInBatch], inSize, numInBatch * sizeof(uint32_t)); + std::memcpy(¶ms_host[2 * numInBatch], out, numInBatch * sizeof(void*)); + + CUDA_VERIFY(cudaMemcpyAsync( + params_dev.data(), + params_host.get(), + 3 * numInBatch * sizeof(uintptr_t), + cudaMemcpyHostToDevice, + stream)); + + auto in_dev = (const void**)params_dev.data(); + auto inSize_dev = (const uint32_t*)(params_dev.data() + numInBatch); + auto out_dev = (void**)(params_dev.data() + 2 * numInBatch); + + auto inProvider = BatchProviderPointer((void**)in_dev, inSize_dev); + auto outProvider = BatchProviderPointer(out_dev); + + floatCompressDevice( + res, + config, + numInBatch, + inProvider, + maxSize, + outProvider, + outSize_dev, + stream); +} + +void floatCompressSplitSize( + StackDeviceMemory& res, + const FloatCompressConfig& config, + uint32_t numInBatch, + const void* in_dev, + const uint32_t* inSplitSizes, + void* out_dev, + uint32_t outStride, + uint32_t* outSize_dev, + cudaStream_t stream) { + auto floatWordSize = getWordSizeFromFloatType(config.floatType); + + auto splitSizeHost = std::vector(numInBatch * 2); + auto splitSize = splitSizeHost.data(); + auto splitSizePrefix = splitSizeHost.data() + numInBatch; + uint32_t maxSplitSize = 0; + + for (uint32_t i = 0; i < numInBatch; ++i) { + auto size = inSplitSizes[i]; + + splitSize[i] = size; + if (i > 0) { + splitSizePrefix[i] = splitSizePrefix[i - 1] + splitSize[i - 1]; + } + + maxSplitSize = std::max(size, maxSplitSize); + } + + // Copy data to device + // splitSize, splitSizePrefix + auto sizes_dev = res.alloc(stream, splitSizeHost.size()); + + CUDA_VERIFY(cudaMemcpyAsync( + sizes_dev.data(), + splitSizeHost.data(), + splitSizeHost.size() * sizeof(uint32_t), + cudaMemcpyHostToDevice, + stream)); + + auto inProvider = BatchProviderSplitSize( + (void*)in_dev, + sizes_dev.data(), + sizes_dev.data() + numInBatch, + floatWordSize); + + auto outProvider = BatchProviderStride(out_dev, outStride); + + floatCompressDevice( + res, + config, + numInBatch, + inProvider, + maxSplitSize, + outProvider, + outSize_dev, + stream); +} + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/dietgpu/float/GpuFloatCompress.cuh b/thirdparty/dietgpu/dietgpu/float/GpuFloatCompress.cuh new file mode 100644 index 000000000..5f93e71ba --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/float/GpuFloatCompress.cuh @@ -0,0 +1,581 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "dietgpu/ans/BatchProvider.cuh" +#include "dietgpu/ans/GpuANSEncode.cuh" +#include "dietgpu/ans/GpuANSUtils.cuh" +#include "dietgpu/ans/GpuChecksum.cuh" +#include "dietgpu/float/GpuFloatCodec.h" +#include "dietgpu/float/GpuFloatUtils.cuh" +#include "dietgpu/utils/DeviceDefs.cuh" +#include "dietgpu/utils/DeviceUtils.h" +#include "dietgpu/utils/StackDeviceMemory.h" +#include "dietgpu/utils/StaticUtils.h" + +#include +#include +#include +#include + +namespace dietgpu { + +template +struct SplitFloatNonAligned { + static __device__ void split( + const typename FloatTypeInfo::WordT* in, + uint32_t size, + typename FloatTypeInfo::CompT* compOut, + typename FloatTypeInfo::NonCompT* nonCompOut, + uint32_t* warpHistogram) { + using FTI = FloatTypeInfo; + using CompT = typename FTI::CompT; + using NonCompT = typename FTI::NonCompT; + + for (uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; + i += gridDim.x * blockDim.x) { + CompT comp; + NonCompT nonComp; + FTI::split(in[i], comp, nonComp); + + atomicAdd(&warpHistogram[comp], 1); + + compOut[i] = comp; + nonCompOut[i] = nonComp; + } + } +}; + +template +struct SplitFloatNonAligned { + static __device__ void split( + const typename FloatTypeInfo::WordT* in, + uint32_t size, + typename FloatTypeInfo::CompT* compOut, + typename FloatTypeInfo::NonCompT* nonCompOut, + uint32_t* warpHistogram) { + using FTI = FloatTypeInfo; + using CompT = typename FTI::CompT; + using NonCompT = typename FTI::NonCompT; + + // Where the low order 2 bytes are written + uint16_t* nonComp2Out = (uint16_t*)nonCompOut; + + // Where the high order byte is written + uint8_t* nonComp1Out = (uint8_t*)(nonComp2Out + roundUp(size, 8)); + + for (uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; + i += gridDim.x * blockDim.x) { + CompT comp; + NonCompT nonComp; + FTI::split(in[i], comp, nonComp); + + nonComp2Out[i] = nonComp & 0xffffU; + nonComp1Out[i] = nonComp >> 16; + compOut[i] = comp; + + atomicAdd(&warpHistogram[comp], 1); + } + } +}; + +template +struct SplitFloatAligned16 { + static __device__ void split( + const typename FloatTypeInfo::WordT* __restrict__ in, + uint32_t size, + typename FloatTypeInfo::CompT* __restrict__ compOut, + typename FloatTypeInfo::NonCompT* __restrict__ nonCompOut, + uint32_t* warpHistogram) { + using FTI = FloatTypeInfo; + + using WordT = typename FTI::WordT; + using CompT = typename FTI::CompT; + using NonCompT = typename FTI::NonCompT; + using VecT = typename FTI::VecT; + using CompVecT = typename FTI::CompVecT; + using NonCompVecT = typename FTI::NonCompVecT; + + constexpr int kOuterUnroll = 2; + constexpr int kInnerUnroll = sizeof(VecT) / sizeof(WordT); + + const VecT* inV = (const VecT*)in; + CompVecT* compOutV = (CompVecT*)compOut; + NonCompVecT* nonCompOutV = (NonCompVecT*)nonCompOut; + + // Each block handles Threads * kOuterUnroll * kInnerUnroll inputs/outputs + // at a time, or Threads * kOuterUnroll 16-byte words at a time + + constexpr int kWordsPerBlock = Threads * kOuterUnroll; + constexpr int kFloatsPerBlock = kWordsPerBlock * kInnerUnroll; + uint32_t fullBlocks = divDown(size, kFloatsPerBlock); + + // Handle by block + uint32_t startBlock = blockIdx.x * kWordsPerBlock; + inV += startBlock + threadIdx.x; + compOutV += startBlock + threadIdx.x; + nonCompOutV += startBlock + threadIdx.x; + + for (uint32_t b = blockIdx.x; b < fullBlocks; b += gridDim.x, + inV += gridDim.x * kWordsPerBlock, + compOutV += gridDim.x * kWordsPerBlock, + nonCompOutV += gridDim.x * kWordsPerBlock) { + VecT v[kOuterUnroll]; + +#pragma unroll + for (uint32_t i = 0; i < kOuterUnroll; ++i) { + v[i] = inV[i * Threads]; + } + + CompVecT compV[kOuterUnroll]; + NonCompVecT nonCompV[kOuterUnroll]; + +#pragma unroll + for (uint32_t i = 0; i < kOuterUnroll; ++i) { +#pragma unroll + for (int j = 0; j < kInnerUnroll; ++j) { + CompT comp; + NonCompT nonComp; + FTI::split(v[i].x[j], comp, nonComp); + + atomicAdd(&warpHistogram[comp], 1); + + compV[i].x[j] = comp; + nonCompV[i].x[j] = nonComp; + } + } + +#pragma unroll + for (uint32_t i = 0; i < kOuterUnroll; ++i) { + compOutV[i * Threads] = compV[i]; + nonCompOutV[i * Threads] = nonCompV[i]; + } + } + + // Handle last (partial) block + for (uint32_t i = + fullBlocks * kFloatsPerBlock + blockIdx.x * Threads + threadIdx.x; + i < size; + i += gridDim.x * Threads) { + CompT comp; + NonCompT nonComp; + FTI::split(in[i], comp, nonComp); + + atomicAdd(&warpHistogram[comp], 1); + + compOut[i] = comp; + nonCompOut[i] = nonComp; + } + } +}; + +// float32 specialization +template +struct SplitFloatAligned16 { + static __device__ void split( + const typename FloatTypeInfo::WordT* __restrict__ in, + uint32_t size, + typename FloatTypeInfo::CompT* __restrict__ compOut, + typename FloatTypeInfo< + FloatType::kFloat32>::NonCompT* __restrict__ nonCompOut, + uint32_t* warpHistogram) { + using FTI = FloatTypeInfo; + + using WordT = typename FTI::WordT; + using CompT = typename FTI::CompT; + using NonCompT = typename FTI::NonCompT; + + constexpr int kOuterUnroll = 1; + constexpr int kInnerUnroll = sizeof(uint32x4) / sizeof(uint32_t); + + auto inV = (const uint32x4*)in; + auto compOutV = (uint8x4*)compOut; + + auto nonCompOut2 = (uint16_t*)nonCompOut; + auto nonCompOut1 = (uint8_t*)(nonCompOut2 + roundUp(size, 8)); + + auto nonCompOutV2 = (uint16x4*)nonCompOut2; + auto nonCompOutV1 = (uint8x4*)nonCompOut1; + + // Each block handles Threads * kOuterUnroll * kInnerUnroll inputs/outputs + // at a time, or Threads * kOuterUnroll 16-byte words at a time + constexpr int kWordsPerBlock = Threads * kOuterUnroll; + constexpr int kFloatsPerBlock = kWordsPerBlock * kInnerUnroll; + uint32_t fullBlocks = divDown(size, kFloatsPerBlock); + + // Handle by block + uint32_t startBlock = blockIdx.x * kWordsPerBlock; + inV += startBlock + threadIdx.x; + compOutV += startBlock + threadIdx.x; + nonCompOutV2 += startBlock + threadIdx.x; + nonCompOutV1 += startBlock + threadIdx.x; + + for (uint32_t b = blockIdx.x; b < fullBlocks; b += gridDim.x, + inV += gridDim.x * kWordsPerBlock, + compOutV += gridDim.x * kWordsPerBlock, + nonCompOutV2 += gridDim.x * kWordsPerBlock, + nonCompOutV1 += gridDim.x * kWordsPerBlock) { + uint32x4 v[kOuterUnroll]; + +#pragma unroll + for (uint32_t i = 0; i < kOuterUnroll; ++i) { + v[i] = inV[i * Threads]; + } + + uint8x4 compV[kOuterUnroll]; + uint32x4 nonCompV[kOuterUnroll]; + +#pragma unroll + for (uint32_t i = 0; i < kOuterUnroll; ++i) { +#pragma unroll + for (int j = 0; j < kInnerUnroll; ++j) { + CompT comp; + NonCompT nonComp; + FTI::split(v[i].x[j], comp, nonComp); + + atomicAdd(&warpHistogram[comp], 1); + + compV[i].x[j] = comp; + nonCompV[i].x[j] = nonComp; + } + } + +#pragma unroll + for (uint32_t i = 0; i < kOuterUnroll; ++i) { + compOutV[i * Threads] = compV[i]; + + uint16x4 nonCompV2; + uint8x4 nonCompV1; + for (int j = 0; j < kInnerUnroll; ++j) { + nonCompV2.x[j] = nonCompV[i].x[j] & 0xffffU; + nonCompV1.x[j] = nonCompV[i].x[j] >> 16; + } + + nonCompOutV2[i * Threads] = nonCompV2; + nonCompOutV1[i * Threads] = nonCompV1; + } + } + + // Handle last (partial) block + for (uint32_t i = + fullBlocks * kFloatsPerBlock + blockIdx.x * Threads + threadIdx.x; + i < size; + i += gridDim.x * Threads) { + CompT comp; + NonCompT nonComp; + FTI::split(in[i], comp, nonComp); + + atomicAdd(&warpHistogram[comp], 1); + + compOut[i] = comp; + nonCompOut2[i] = nonComp & 0xffffU; + nonCompOut1[i] = nonComp >> 16; + } + } +}; + +template < + typename InProvider, + typename NonCompProvider, + FloatType FT, + int Threads> +__global__ void splitFloat( + InProvider inProvider, + bool useChecksum, + const uint32_t* __restrict__ checksum, + void* __restrict__ compOut, + uint32_t compOutStride, + NonCompProvider nonCompProvider, + uint32_t* __restrict__ histogramOut) { + using WordT = typename FloatTypeInfo::WordT; + using CompT = typename FloatTypeInfo::CompT; + using NonCompT = typename FloatTypeInfo::NonCompT; + + constexpr int kWarps = Threads / kWarpSize; + static_assert(Threads == kNumSymbols, ""); + + auto batch = blockIdx.y; + auto warpId = threadIdx.x / kWarpSize; + + histogramOut += batch * kNumSymbols; + checksum += batch; + + // +1 in order to force very common symbols that could overlap into different + // banks between different warps + __shared__ uint32_t histogram[kWarps][kNumSymbols + 1]; + +#pragma unroll + for (int i = 0; i < kWarps; ++i) { + histogram[i][threadIdx.x] = 0; + } + + __syncthreads(); + + uint32_t* warpHistogram = histogram[warpId]; + + auto curIn = (const WordT*)inProvider.getBatchStart(batch); + auto headerOut = (GpuFloatHeader*)nonCompProvider.getBatchStart(batch); + auto curCompOut = (CompT*)compOut + compOutStride * batch; + auto curSize = inProvider.getBatchSize(batch); + + // Write size as a header + if (blockIdx.x == 0 && threadIdx.x == 0) { + GpuFloatHeader h; + h.setMagicAndVersion(); + h.size = curSize; + h.setFloatType(FT); + h.setUseChecksum(useChecksum); + + if (useChecksum) { + h.setChecksum(*checksum); + } + + *headerOut = h; + } + + auto curNonCompOut = (NonCompT*)(headerOut + 1); + + // How many bytes are before the point where we are 16 byte aligned? + auto nonAlignedBytes = getAlignmentRoundUp(curIn); + + if (nonAlignedBytes > 0) { + SplitFloatNonAligned::split( + curIn, curSize, curCompOut, curNonCompOut, warpHistogram); + } else { + SplitFloatAligned16::split( + curIn, curSize, curCompOut, curNonCompOut, warpHistogram); + } + + // Accumulate warp histogram data and write into the gmem histogram + __syncthreads(); + + uint32_t sum = histogram[0][threadIdx.x]; +#pragma unroll + for (int j = 1; j < kWarps; ++j) { + sum += histogram[j][threadIdx.x]; + } + + // The count for the thread's bucket could be 0 + if (sum) { + atomicAdd(&histogramOut[threadIdx.x], sum); + } +} + +// Update the final byte counts for the batch to take into account the +// uncompressed and compressed portions +template +__global__ void +incOutputSizes(InProvider inProvider, uint32_t* outSize, uint32_t numInBatch) { + uint32_t batch = blockIdx.x * blockDim.x + threadIdx.x; + if (batch < numInBatch) { + outSize[batch] += sizeof(GpuFloatHeader) + + FloatTypeInfo::getUncompDataSize(inProvider.getBatchSize(batch)); + } +} + +// Provides the input data to ANS compression +template +struct FloatANSInProvider { + using Writer = BatchWriter; + + __host__ + FloatANSInProvider(void* ptr_dev, uint32_t stride, SizeProvider& sizeProvider) + : ptr_dev_(ptr_dev), stride_(stride), sizeProvider_(sizeProvider) {} + + __device__ void* getBatchStart(uint32_t batch) { + return (uint8_t*)ptr_dev_ + batch * stride_; + } + + __device__ const void* getBatchStart(uint32_t batch) const { + return (uint8_t*)ptr_dev_ + batch * stride_; + } + + __device__ BatchWriter getWriter(uint32_t batch) { + return BatchWriter(getBatchStart(batch)); + } + + __device__ uint32_t getBatchSize(uint32_t batch) { + return sizeProvider_.getBatchSize(batch); + } + + void* ptr_dev_; + uint32_t stride_; + SizeProvider sizeProvider_; +}; + +// Provides the output data to ANS compression +template +struct FloatANSOutProvider { + using Writer = BatchWriter; + using FTI = FloatTypeInfo; + + __host__ FloatANSOutProvider( + OutProvider& outProvider, + SizeProvider& sizeProvider) + : outProvider_(outProvider), sizeProvider_(sizeProvider) {} + + __device__ void* getBatchStart(uint32_t batch) { + uint8_t* p = (uint8_t*)outProvider_.getBatchStart(batch); + + // Increment the pointer to past the floating point data + ((GpuFloatHeader*)p)->checkMagicAndVersion(); + return p + sizeof(GpuFloatHeader) + + FTI::getUncompDataSize(sizeProvider_.getBatchSize(batch)); + } + + __device__ const void* getBatchStart(uint32_t batch) const { + const uint8_t* p = (const uint8_t*)outProvider_.getBatchStart(batch); + + // Increment the pointer to past the floating point data + ((GpuFloatHeader*)p)->checkMagicAndVersion(); + return p + sizeof(GpuFloatHeader) + + FTI::getUncompDataSize(sizeProvider_.getBatchSize(batch)); + } + + __device__ BatchWriter getWriter(uint32_t batch) { + return BatchWriter(getBatchStart(batch)); + } + + OutProvider outProvider_; + SizeProvider sizeProvider_; +}; + +template +void floatCompressDevice( + StackDeviceMemory& res, + const FloatCompressConfig& config, + uint32_t numInBatch, + InProvider& inProvider, + uint32_t maxSize, + OutProvider& outProvider, + uint32_t* outSize_dev, + cudaStream_t stream) { + auto maxUncompressedWords = maxSize / sizeof(ANSDecodedT); + uint32_t maxNumCompressedBlocks = + divUp(maxUncompressedWords, kDefaultBlockSize); + + // Compute checksum on input data (optional) + auto checksum_dev = res.alloc(stream, numInBatch); + + // not allowed in float mode + assert(!config.ansConfig.useChecksum); + + if (config.useChecksum) { + checksumBatch(numInBatch, inProvider, checksum_dev.data(), stream); + } + + // Temporary space for the extracted exponents; all rows must be 16 byte + // aligned + uint32_t compRowStride = roundUp(maxSize, sizeof(uint4)); + auto toComp_dev = res.alloc(stream, numInBatch * compRowStride); + + // We calculate a histogram of the symbols to be compressed as part of + // extracting the compressible symbol from the float + auto histogram_dev = res.alloc(stream, numInBatch * kNumSymbols); + + // zero out buckets before proceeding, as we aggregate with atomic adds + CUDA_VERIFY(cudaMemsetAsync( + histogram_dev.data(), + 0, + sizeof(uint32_t) * numInBatch * kNumSymbols, + stream)); + +#define RUN_SPLIT(FLOAT_TYPE) \ + do { \ + constexpr int kBlock = 256; \ + auto& props = getCurrentDeviceProperties(); \ + int maxBlocksPerSM = 0; \ + CUDA_VERIFY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( \ + &maxBlocksPerSM, \ + splitFloat, \ + kBlock, \ + 0)); \ + uint32_t maxGrid = maxBlocksPerSM * props.multiProcessorCount; \ + uint32_t perBatchGrid = 4 * divUp(maxGrid, numInBatch); \ + auto grid = dim3(perBatchGrid, numInBatch); \ + \ + splitFloat \ + <<>>( \ + inProvider, \ + config.useChecksum, \ + checksum_dev.data(), \ + toComp_dev.data(), \ + compRowStride, \ + outProvider, \ + histogram_dev.data()); \ + } while (false) + + switch (config.floatType) { + case FloatType::kFloat16: + RUN_SPLIT(FloatType::kFloat16); + break; + case FloatType::kBFloat16: + RUN_SPLIT(FloatType::kBFloat16); + break; + case FloatType::kFloat32: + RUN_SPLIT(FloatType::kFloat32); + break; + default: + assert(false); + break; + } + +#undef RUN_SPLIT + + // outSize as reported by ansEncode is just the ANS-encoded portion of the + // data. + // We need to increment the sizes by the uncompressed portion (header plus + // uncompressed float data) with incOutputSizes +#define RUN_ANS(FT) \ + do { \ + auto inProviderANS = FloatANSInProvider( \ + toComp_dev.data(), compRowStride, inProvider); \ + \ + auto outProviderANS = FloatANSOutProvider( \ + outProvider, inProvider); \ + \ + ansEncodeBatchDevice( \ + res, \ + config.ansConfig, \ + numInBatch, \ + inProviderANS, \ + histogram_dev.data(), \ + maxSize, \ + outProviderANS, \ + outSize_dev, \ + stream); \ + \ + incOutputSizes<<>>( \ + inProvider, outSize_dev, numInBatch); \ + \ + } while (false) + + // We have written the non-compressed portions of the floats into the output, + // along with a header that indicates how many floats there are. + // For compression, we need to increment the address in which the compressed + // outputs are written. + + switch (config.floatType) { + case FloatType::kFloat16: + RUN_ANS(FloatType::kFloat16); + break; + case FloatType::kBFloat16: + RUN_ANS(FloatType::kBFloat16); + break; + case FloatType::kFloat32: + RUN_ANS(FloatType::kFloat32); + break; + default: + assert(false); + break; + } + +#undef RUN_ANS + + CUDA_TEST_ERROR(); +} + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/dietgpu/float/GpuFloatDecompress.cu b/thirdparty/dietgpu/dietgpu/float/GpuFloatDecompress.cu new file mode 100644 index 000000000..3c83cd553 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/float/GpuFloatDecompress.cu @@ -0,0 +1,181 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "dietgpu/ans/BatchProvider.cuh" +#include "dietgpu/float/GpuFloatCodec.h" +#include "dietgpu/float/GpuFloatDecompress.cuh" +#include "dietgpu/utils/DeviceUtils.h" +#include "dietgpu/utils/StackDeviceMemory.h" + +#include +#include +#include +#include +#include + +namespace dietgpu { + +FloatDecompressStatus floatDecompress( + StackDeviceMemory& res, + const FloatDecompressConfig& config, + uint32_t numInBatch, + const void** in, + void** out, + const uint32_t* outCapacity, + uint8_t* outSuccess_dev, + uint32_t* outSize_dev, + cudaStream_t stream) { + // If the batch size is <= kBSLimit, we avoid cudaMemcpy and send all data at + // kernel launch + constexpr int kLimit = 128; + + // Investigate all of the output pointers; are they 16 byte aligned? If so, we + // can do everything in a single pass + bool is16ByteAligned = true; + for (int i = 0; i < numInBatch; ++i) { + if (reinterpret_cast(out[i]) % 16 != 0) { + is16ByteAligned = false; + break; + } + } + + auto updatedConfig = config; + updatedConfig.is16ByteAligned = is16ByteAligned; + + // We need a max capacity estimate before proceeding, for temporary memory + // allocations + uint32_t maxCapacity = 0; + for (uint32_t i = 0; i < numInBatch; ++i) { + maxCapacity = std::max(maxCapacity, outCapacity[i]); + } + + if (numInBatch <= kLimit) { + // We can do everything in a single pass without a h2d memcpy + auto inProvider = + BatchProviderInlinePointer(numInBatch, (void**)in); + auto outProvider = BatchProviderInlinePointerCapacity( + numInBatch, out, outCapacity); + + return floatDecompressDevice( + res, + updatedConfig, + numInBatch, + inProvider, + outProvider, + maxCapacity, + outSuccess_dev, + outSize_dev, + stream); + } + + // Copy data to device + // To reduce latency, we prefer to coalesce all data together and copy as one + // contiguous chunk + static_assert(sizeof(void*) == sizeof(uintptr_t)); + static_assert(sizeof(uint32_t) <= sizeof(uintptr_t)); + + // in, out, outCapacity + auto params_dev = res.alloc(stream, numInBatch * 3); + auto params_host = + std::unique_ptr(new uintptr_t[3 * numInBatch]); + + std::memcpy(¶ms_host[0], in, numInBatch * sizeof(void*)); + std::memcpy(¶ms_host[numInBatch], out, numInBatch * sizeof(void*)); + std::memcpy( + ¶ms_host[2 * numInBatch], outCapacity, numInBatch * sizeof(uint32_t)); + + CUDA_VERIFY(cudaMemcpyAsync( + params_dev.data(), + params_host.get(), + 3 * numInBatch * sizeof(uintptr_t), + cudaMemcpyHostToDevice, + stream)); + + auto in_dev = params_dev.data(); + auto out_dev = params_dev.data() + numInBatch; + auto outCapacity_dev = (const uint32_t*)(params_dev.data() + 2 * numInBatch); + + auto inProvider = BatchProviderPointer((void**)in_dev); + auto outProvider = BatchProviderPointer((void**)out_dev, outCapacity_dev); + + return floatDecompressDevice( + res, + updatedConfig, + numInBatch, + inProvider, + outProvider, + maxCapacity, + outSuccess_dev, + outSize_dev, + stream); +} + +FloatDecompressStatus floatDecompressSplitSize( + StackDeviceMemory& res, + const FloatDecompressConfig& config, + uint32_t numInBatch, + const void** in, + void* out, + const uint32_t* outSplitSizes, + uint8_t* outSuccess_dev, + uint32_t* outSize_dev, + cudaStream_t stream) { + auto floatWordSize = getWordSizeFromFloatType(config.floatType); + + // Concatenate splitSize and splitSizePrefix together for a single h2d copy + auto splitSizeHost = std::vector(numInBatch * 2); + auto splitSize = splitSizeHost.data(); + auto splitSizePrefix = splitSizeHost.data() + numInBatch; + uint32_t maxSplitSize = 0; + + bool is16ByteAligned = isPointerAligned(out, 16); + + for (uint32_t i = 0; i < numInBatch; ++i) { + auto size = outSplitSizes[i]; + + // If we only have one tensor in the batch, we only care if the start + // pointer is 16 byte aligned. + // Otherwise, all sizes except for the final one must ensure that all + // splits give a 16 byte alignment. + if ((i != numInBatch - 1) && (size % (16 / floatWordSize) != 0)) { + is16ByteAligned = false; + } + + splitSize[i] = size; + if (i > 0) { + splitSizePrefix[i] = splitSizePrefix[i - 1] + splitSize[i - 1]; + } + + maxSplitSize = std::max(size, maxSplitSize); + } + + auto sizes_dev = res.copyAlloc(stream, splitSizeHost); + + // FIXME: combine with above for a single h2d copy + auto in_dev = res.copyAlloc(stream, (void**)in, numInBatch); + + auto updatedConfig = config; + updatedConfig.is16ByteAligned = is16ByteAligned; + + auto inProvider = BatchProviderPointer(in_dev.data()); + + auto outProvider = BatchProviderSplitSize( + out, sizes_dev.data(), sizes_dev.data() + numInBatch, floatWordSize); + + return floatDecompressDevice( + res, + updatedConfig, + numInBatch, + inProvider, + outProvider, + maxSplitSize, + outSuccess_dev, + outSize_dev, + stream); +} + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/dietgpu/float/GpuFloatDecompress.cuh b/thirdparty/dietgpu/dietgpu/float/GpuFloatDecompress.cuh new file mode 100644 index 000000000..12fd014b6 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/float/GpuFloatDecompress.cuh @@ -0,0 +1,740 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "dietgpu/ans/GpuANSCodec.h" +#include "dietgpu/ans/GpuANSDecode.cuh" +#include "dietgpu/ans/GpuANSUtils.cuh" +#include "dietgpu/float/GpuFloatCodec.h" +#include "dietgpu/float/GpuFloatInfo.cuh" +#include "dietgpu/float/GpuFloatUtils.cuh" +#include "dietgpu/utils/DeviceUtils.h" +#include "dietgpu/utils/StackDeviceMemory.h" +#include "dietgpu/utils/StaticUtils.h" + +#include +#include +#include +#include + +namespace dietgpu { + +template +struct JoinFloatNonAligned { + static __device__ void join( + const typename FloatTypeInfo::CompT* __restrict__ compIn, + const typename FloatTypeInfo::NonCompT* __restrict__ nonCompIn, + uint32_t size, + typename FloatTypeInfo::WordT* __restrict__ out) { + for (uint32_t i = blockIdx.x * Threads + threadIdx.x; i < size; + i += gridDim.x * Threads) { + out[i] = FloatTypeInfo::join(compIn[i], nonCompIn[i]); + } + } +}; + +template +struct JoinFloatNonAligned { + static __device__ void join( + const typename FloatTypeInfo< + FloatType::kFloat32>::CompT* __restrict__ compIn, + const typename FloatTypeInfo< + FloatType::kFloat32>::NonCompT* __restrict__ nonCompIn, + uint32_t size, + typename FloatTypeInfo::WordT* __restrict__ out) { + using FTI = FloatTypeInfo; + using CompT = typename FTI::CompT; + using NonCompT = typename FTI::NonCompT; + + // Where the low order 2 bytes are read + uint16_t* nonComp2In = (uint16_t*)nonCompIn; + + // Where the high order byte is read + uint8_t* nonComp1In = (uint8_t*)(nonComp2In + roundUp(size, 8)); + + for (uint32_t i = blockIdx.x * Threads + threadIdx.x; i < size; + i += gridDim.x * Threads) { + uint32_t nc = + (uint32_t(nonComp1In[i]) * 65536U) + uint32_t(nonComp2In[i]); + + out[i] = FTI::join(compIn[i], nc); + } + } +}; + +template +struct JoinFloatAligned16 { + static __device__ void join( + const typename FloatTypeInfo::CompT* __restrict__ compIn, + const typename FloatTypeInfo::NonCompT* __restrict__ nonCompIn, + uint32_t size, + typename FloatTypeInfo::WordT* __restrict__ out) { + using FTI = FloatTypeInfo; + + using WordT = typename FTI::WordT; + using CompT = typename FTI::CompT; + using NonCompT = typename FTI::NonCompT; + using VecT = typename FTI::VecT; + using CompVecT = typename FTI::CompVecT; + using NonCompVecT = typename FTI::NonCompVecT; + + constexpr int kOuterUnroll = 2; + constexpr int kInnerUnroll = sizeof(VecT) / sizeof(WordT); + + const CompVecT* compInV = (const CompVecT*)compIn; + const NonCompVecT* nonCompInV = (const NonCompVecT*)nonCompIn; + VecT* outV = (VecT*)out; + + // Each block handles Threads * kOuterUnroll * kInnerUnroll inputs/outputs + // at a time, or Threads * kOuterUnroll 16-byte words at a time + + constexpr int kWordsPerBlock = Threads * kOuterUnroll; + constexpr int kFloatsPerBlock = kWordsPerBlock * kInnerUnroll; + uint32_t fullBlocks = divDown(size, kFloatsPerBlock); + + // Handle by block + uint32_t startBlock = blockIdx.x * kWordsPerBlock; + compInV += startBlock + threadIdx.x; + nonCompInV += startBlock + threadIdx.x; + outV += startBlock + threadIdx.x; + + for (uint32_t b = blockIdx.x; b < fullBlocks; b += gridDim.x, + compInV += gridDim.x * kWordsPerBlock, + nonCompInV += gridDim.x * kWordsPerBlock, + outV += gridDim.x * kWordsPerBlock) { + CompVecT comp[kOuterUnroll]; + NonCompVecT nonComp[kOuterUnroll]; + +#pragma unroll + for (uint32_t i = 0; i < kOuterUnroll; ++i) { + comp[i] = compInV[i * Threads]; + nonComp[i] = nonCompInV[i * Threads]; + } + + VecT v[kOuterUnroll]; + +#pragma unroll + for (uint32_t i = 0; i < kOuterUnroll; ++i) { +#pragma unroll + for (int j = 0; j < kInnerUnroll; ++j) { + v[i].x[j] = FTI::join(comp[i].x[j], nonComp[i].x[j]); + } + } + +#pragma unroll + for (uint32_t i = 0; i < kOuterUnroll; ++i) { + outV[i * Threads] = v[i]; + } + } + + // Handle last (partial) block + for (uint32_t i = + fullBlocks * kFloatsPerBlock + blockIdx.x * Threads + threadIdx.x; + i < size; + i += blockDim.x) { + out[i] = FTI::join(compIn[i], nonCompIn[i]); + } + } +}; + +// float32 specialization +template +struct JoinFloatAligned16 { + static __device__ void join( + const typename FloatTypeInfo< + FloatType::kFloat32>::CompT* __restrict__ compIn, + const typename FloatTypeInfo< + FloatType::kFloat32>::NonCompT* __restrict__ nonCompIn, + uint32_t size, + typename FloatTypeInfo::WordT* __restrict__ out) { + using FTI = FloatTypeInfo; + + using WordT = typename FTI::WordT; + using CompT = typename FTI::CompT; + using NonCompT = typename FTI::NonCompT; + + constexpr int kOuterUnroll = 1; + constexpr int kInnerUnroll = sizeof(uint32x4) / sizeof(uint32_t); + + auto compInV = (const uint8x4*)compIn; + auto nonCompIn2 = (const uint16_t*)nonCompIn; + auto nonCompIn1 = (const uint8_t*)(nonCompIn2 + roundUp(size, 8)); + + auto nonCompInV2 = (uint16x4*)nonCompIn2; + auto nonCompInV1 = (uint8x4*)nonCompIn1; + + auto outV = (uint32x4*)out; + + // Each block handles Threads * kOuterUnroll * kInnerUnroll inputs/outputs + // at a time, or Threads * kOuterUnroll 16-byte words at a time + constexpr int kWordsPerBlock = Threads * kOuterUnroll; + constexpr int kFloatsPerBlock = kWordsPerBlock * kInnerUnroll; + uint32_t fullBlocks = divDown(size, kFloatsPerBlock); + + // Handle by block + uint32_t startBlock = blockIdx.x * kWordsPerBlock; + compInV += startBlock + threadIdx.x; + nonCompInV2 += startBlock + threadIdx.x; + nonCompInV1 += startBlock + threadIdx.x; + outV += startBlock + threadIdx.x; + + for (uint32_t b = blockIdx.x; b < fullBlocks; b += gridDim.x, + compInV += gridDim.x * kWordsPerBlock, + nonCompInV2 += gridDim.x * kWordsPerBlock, + nonCompInV1 += gridDim.x * kWordsPerBlock, + outV += gridDim.x * kWordsPerBlock) { + uint8x4 comp[kOuterUnroll]; + uint16x4 nonComp2[kOuterUnroll]; + uint8x4 nonComp1[kOuterUnroll]; + +#pragma unroll + for (uint32_t i = 0; i < kOuterUnroll; ++i) { + comp[i] = compInV[i * Threads]; + nonComp2[i] = nonCompInV2[i * Threads]; + nonComp1[i] = nonCompInV1[i * Threads]; + } + + uint32x4 nonComp[kOuterUnroll]; +#pragma unroll + for (uint32_t i = 0; i < kOuterUnroll; ++i) { +#pragma unroll + for (int j = 0; j < kInnerUnroll; ++j) { + nonComp[i].x[j] = nonComp1[i].x[j] * 65536U + nonComp2[i].x[j]; + } + } + + uint32x4 v[kOuterUnroll]; + +#pragma unroll + for (uint32_t i = 0; i < kOuterUnroll; ++i) { +#pragma unroll + for (int j = 0; j < kInnerUnroll; ++j) { + v[i].x[j] = FTI::join(comp[i].x[j], nonComp[i].x[j]); + } + } + +#pragma unroll + for (uint32_t i = 0; i < kOuterUnroll; ++i) { + outV[i * Threads] = v[i]; + } + } + + // Handle last (partial) block + for (uint32_t i = + fullBlocks * kFloatsPerBlock + blockIdx.x * Threads + threadIdx.x; + i < size; + i += blockDim.x) { + uint32_t nc2 = nonCompIn2[i]; + uint32_t nc1 = nonCompIn1[i]; + uint32_t nc = nc1 * 65536U + nc2; + + out[i] = FTI::join(compIn[i], nc); + } + } +}; + +template +struct JoinFloatImpl { + static __device__ void join( + const typename FloatTypeInfo::CompT* compIn, + const typename FloatTypeInfo::NonCompT* nonCompIn, + uint32_t size, + typename FloatTypeInfo::WordT* out) { + // compIn should always be aligned, as we decompress into temporary memory + auto compUnalignedBytes = getAlignmentRoundUp(compIn); + auto nonCompUnalignedBytes = getAlignmentRoundUp(nonCompIn); + auto outUnalignedBytes = getAlignmentRoundUp(out); + + if (compUnalignedBytes || nonCompUnalignedBytes || outUnalignedBytes) { + JoinFloatNonAligned::join(compIn, nonCompIn, size, out); + } else { + JoinFloatAligned16::join(compIn, nonCompIn, size, out); + } + } +}; + +template +struct JoinFloatImpl { + static __device__ void join( + const typename FloatTypeInfo::CompT* compIn, + const typename FloatTypeInfo::NonCompT* nonCompIn, + uint32_t size, + typename FloatTypeInfo::WordT* out) { + // FIXME: implement vectorization + JoinFloatNonAligned::join( + compIn, nonCompIn, size, out); + } +}; + +template < + typename InProviderComp, + typename InProviderNonComp, + typename OutProvider, + FloatType FT, + int Threads> +__global__ void joinFloat( + InProviderComp inProviderComp, + InProviderNonComp inProviderNonComp, + OutProvider outProvider, + uint8_t* __restrict__ outSuccess, + uint32_t* __restrict__ outSize) { + using FTI = FloatTypeInfo; + using WordT = typename FTI::WordT; + using CompT = typename FTI::CompT; + using NonCompT = typename FTI::NonCompT; + + auto batch = blockIdx.y; + + auto curCompIn = (const CompT*)inProviderComp.getBatchStart(batch); + auto curHeaderIn = + (const GpuFloatHeader*)inProviderNonComp.getBatchStart(batch); + auto curOut = (WordT*)outProvider.getBatchStart(batch); + + // FIXME: test out capacity + + if (outSuccess && !outSuccess[batch]) { + // ANS decompression failed, so nothing for us to do + return; + } + + // Get size as a header + GpuFloatHeader h = *curHeaderIn; + h.checkMagicAndVersion(); + + auto curSize = h.size; + + if (outSize && (curSize != outSize[batch])) { + // Reported size mismatch between ANS decompression and fp unpacking + assert(false); + return; + } + + auto curNonCompIn = (const NonCompT*)(curHeaderIn + 1); + + JoinFloatImpl::join(curCompIn, curNonCompIn, curSize, curOut); +} + +template +struct FloatANSProvider { + using FTI = FloatTypeInfo; + + __host__ FloatANSProvider(InProvider& provider) : inProvider_(provider) {} + + __device__ void* getBatchStart(uint32_t batch) { + uint8_t* p = (uint8_t*)inProvider_.getBatchStart(batch); + + // This is the first place that touches the header + GpuFloatHeader h = *((GpuFloatHeader*)p); + h.checkMagicAndVersion(); + assert(FT == h.getFloatType()); + + // Increment the pointer to past the floating point data + return p + sizeof(GpuFloatHeader) + FTI::getUncompDataSize(h.size); + } + + __device__ const void* getBatchStart(uint32_t batch) const { + const uint8_t* p = (const uint8_t*)inProvider_.getBatchStart(batch); + + // This is the first place that touches the header + GpuFloatHeader h = *((const GpuFloatHeader*)p); + h.checkMagicAndVersion(); + assert(FT == h.getFloatType()); + + // Increment the pointer to past the floating point data + return p + sizeof(GpuFloatHeader) + FTI::getUncompDataSize(h.size); + } + + InProvider inProvider_; +}; + +template +struct FloatANSProviderInline { + using FTI = FloatTypeInfo; + + __host__ FloatANSProviderInline(int num, const void** in) { + CHECK_LE(num, N); + for (int i = 0; i < num; ++i) { + in_[i] = in[i]; + } + } + + __device__ void* getBatchStart(uint32_t batch) { + uint8_t* p = (uint8_t*)in_[batch]; + + // This is the first place that touches the header + GpuFloatHeader h = *((GpuFloatHeader*)p); + h.checkMagicAndVersion(); + assert(FT == h.getFloatType()); + + // Increment the pointer to past the floating point data + return p + sizeof(GpuFloatHeader) + FTI::getUncompDataSize(h.size); + } + + __device__ const void* getBatchStart(uint32_t batch) const { + const uint8_t* p = (const uint8_t*)in_[batch]; + + // This is the first place that touches the header + GpuFloatHeader h = *((const GpuFloatHeader*)p); + h.checkMagicAndVersion(); + assert(FT == h.getFloatType()); + + // Increment the pointer to past the floating point data + return p + sizeof(GpuFloatHeader) + FTI::getUncompDataSize(h.size); + } + + const void* in_[N]; +}; + +template +struct JoinFloatWriter { + using FTI = FloatTypeInfo; + + __host__ __device__ JoinFloatWriter( + uint32_t size, + typename FTI::WordT* out, + const typename FTI::NonCompT* nonComp) + : out_(out), + nonComp_(nonComp), + outBlock_(nullptr), + nonCompBlock_(nullptr) {} + + __host__ __device__ void setBlock(uint32_t block) { + outBlock_ = out_ + block * BlockSize; + nonCompBlock_ = nonComp_ + block * BlockSize; + } + + __device__ void write(uint32_t offset, uint8_t sym) { + auto nonComp = nonCompBlock_[offset]; + outBlock_[offset] = FTI::join(sym, nonComp); + } + + // // The preload is an offset of a NonCompVec4 + // __device__ void preload(uint32_t offset) { + // // We can preload this before decompressing all of the ANS compressed + // data + // // to hide memory latency + // preload_ = ((typename FTI::NonCompVec4*)nonCompBlock_)[offset]; + // } + + // __device__ void writeVec(uint32_t offset, ANSDecodedTx4 symV) { + // typename FTI::Vec4 outV; + // #pragma unroll + // // We always receive 4 decoded values each iteration + // // FIXME: this is hacky + // for (int i = 0; i < 4; ++i) { + // outV.x[i] = JoinFloat::join(symV.x[i], preload_.x[i]); + // } + + // ((typename FTI::Vec4*)outBlock_)[offset] = outV; + // } + + // typename FTI::NonCompVec4 preload_; + typename FTI::WordT* out_; + const typename FTI::NonCompT* nonComp_; + typename FTI::WordT* outBlock_; + const typename FTI::NonCompT* nonCompBlock_; +}; + +template +struct JoinFloatWriter { + static constexpr bool kVectorize = false; + using FTI = FloatTypeInfo; + + __host__ __device__ JoinFloatWriter( + uint32_t size, + typename FTI::WordT* out, + const typename FTI::NonCompT* nonComp) + : size_(size), + out_(out), + nonComp_(nonComp), + outBlock_(nullptr), + nonCompBlock2_(nullptr), + nonCompBlock1_(nullptr) {} + + __host__ __device__ void setBlock(uint32_t block) { + nonCompBlock2_ = (const uint16_t*)nonComp_ + block * BlockSize; + nonCompBlock1_ = + (const uint8_t*)((const uint16_t*)nonComp_ + roundUp(size_, 8U)) + + block * BlockSize; + outBlock_ = out_ + block * BlockSize; + } + + __device__ void write(uint32_t offset, uint8_t sym) { + uint32_t nc = uint32_t(nonCompBlock1_[offset]) * 65536U + + uint32_t(nonCompBlock2_[offset]); + + outBlock_[offset] = FTI::join(sym, nc); + } + + // // This implementation does not preload + // __device__ void preload(uint32_t offset) { + // } + + // // This implementation does not vectorize + // __device__ void writeVec(uint32_t offset, ANSDecodedTx4 symV) { + // } + + uint32_t size_; + typename FTI::WordT* out_; + const typename FTI::NonCompT* nonComp_; + typename FTI::WordT* outBlock_; + const uint16_t* nonCompBlock2_; + const uint8_t* nonCompBlock1_; +}; + +template < + typename InProvider, + typename OutProvider, + FloatType FT, + uint32_t BlockSize> +struct FloatOutProvider { + using Writer = JoinFloatWriter; + using FTI = FloatTypeInfo; + + __host__ FloatOutProvider(InProvider& inProvider, OutProvider& outProvider) + : inProvider_(inProvider), outProvider_(outProvider) {} + + __device__ void* getBatchStart(uint32_t batch) { + return inProvider_.getBatchStart(batch); + } + + __device__ uint32_t getBatchSize(uint32_t batch) { + return outProvider_.getBatchSize(batch); + } + + __device__ Writer getWriter(uint32_t batch) { + // Get float header + auto h = (const GpuFloatHeader*)getBatchStart(batch); + + return Writer( + h->size, + (typename FTI::WordT*)outProvider_.getBatchStart(batch), + // advance past the header + (const typename FTI::NonCompT*)(h + 1)); + } + + InProvider inProvider_; + OutProvider outProvider_; +}; + +template +struct FloatOutProviderInline { + using FTI = FloatTypeInfo; + using Writer = JoinFloatWriter; + + __host__ FloatOutProviderInline( + int num, + const void** in, + void** out, + const uint32_t* outCapacity) { + CHECK_LE(num, N); + for (int i = 0; i < num; ++i) { + in_[i] = in[i]; + out_[i] = out[i]; + outCapacity_[i] = outCapacity[i]; + } + } + + __device__ void* getBatchStart(uint32_t batch) { + return in_[batch]; + } + + __device__ uint32_t getBatchSize(uint32_t batch) { + return outCapacity_[batch]; + } + + __device__ Writer getWriter(uint32_t batch) { + // Get float header + auto h = (const GpuFloatHeader*)getBatchStart(batch); + + return Writer( + h->size, + (typename FTI::WordT*)out_[batch], + // advance past the header + (const typename FTI::NonCompT*)(h + 1)); + } + + const void* in_[N]; + void* out_[N]; + uint32_t outCapacity_[N]; +}; + +template +FloatDecompressStatus floatDecompressDevice( + StackDeviceMemory& res, + const FloatDecompressConfig& config, + uint32_t numInBatch, + InProvider& inProvider, + OutProvider& outProvider, + uint32_t maxCapacity, + uint8_t* outSuccess_dev, + uint32_t* outSize_dev, + cudaStream_t stream) { + // not allowed in float mode + assert(!config.ansConfig.useChecksum); + + // We can perform decoding in a single pass if all input data is 16 byte + // aligned + if (config.is16ByteAligned) { + // + // Fused kernel: perform decompression in a single pass + // + +#define RUN_FUSED(FT) \ + do { \ + auto inProviderANS = FloatANSProvider(inProvider); \ + auto outProviderANS = \ + FloatOutProvider( \ + inProvider, outProvider); \ + \ + ansDecodeBatch( \ + res, \ + config.ansConfig, \ + numInBatch, \ + inProviderANS, \ + outProviderANS, \ + outSuccess_dev, \ + outSize_dev, \ + stream); \ + } while (false) + + switch (config.floatType) { + case FloatType::kFloat16: + RUN_FUSED(FloatType::kFloat16); + break; + case FloatType::kBFloat16: + RUN_FUSED(FloatType::kBFloat16); + break; + case FloatType::kFloat32: + RUN_FUSED(FloatType::kFloat32); + break; + default: + CHECK(false); + break; + } + +#undef RUN_FUSED + } + + else { + // + // Two pass kernel: decompress the ANS compressed data, then rejoin with + // uncompressed data + // + + // Temporary space for the decompressed exponents + // We need to ensure 16 byte alignment for the decompressed data due to + // vectorization + uint32_t maxCapacityAligned = roundUp(maxCapacity, sizeof(uint4)); + + auto exp_dev = res.alloc(stream, numInBatch * maxCapacityAligned); + +#define RUN_DECODE(FT) \ + do { \ + using InProviderANS = FloatANSProvider; \ + auto inProviderANS = InProviderANS(inProvider); \ + \ + using OutProviderANS = BatchProviderStride; \ + auto outProviderANS = OutProviderANS( \ + exp_dev.data(), maxCapacityAligned, maxCapacityAligned); \ + \ + ansDecodeBatch( \ + res, \ + config.ansConfig, \ + numInBatch, \ + inProviderANS, \ + outProviderANS, \ + outSuccess_dev, \ + outSize_dev, \ + stream); \ + \ + constexpr int kThreads = 256; \ + auto& props = getCurrentDeviceProperties(); \ + int maxBlocksPerSM = 0; \ + CUDA_VERIFY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( \ + &maxBlocksPerSM, \ + joinFloat, \ + kThreads, \ + 0)); \ + uint32_t maxGrid = maxBlocksPerSM * props.multiProcessorCount; \ + uint32_t perBatchGrid = divUp(maxGrid, numInBatch); \ + if ((perBatchGrid * numInBatch > maxGrid) && perBatchGrid > 1) { \ + perBatchGrid -= 1; \ + } \ + auto grid = dim3(perBatchGrid, numInBatch); \ + \ + joinFloat \ + <<>>( \ + outProviderANS, \ + inProvider, \ + outProvider, \ + outSuccess_dev, \ + outSize_dev); \ + } while (false) + + switch (config.floatType) { + case FloatType::kFloat16: + RUN_DECODE(FloatType::kFloat16); + break; + case FloatType::kBFloat16: + RUN_DECODE(FloatType::kBFloat16); + break; + case FloatType::kFloat32: + RUN_DECODE(FloatType::kFloat32); + break; + default: + CHECK(false); + break; + } + +#undef RUN_DECODE + } + + FloatDecompressStatus status; + + // Perform optional checksum, if desired + if (config.useChecksum) { + auto checksum_dev = res.alloc(stream, numInBatch); + auto sizes_dev = res.alloc(stream, numInBatch); + auto archiveChecksum_dev = res.alloc(stream, numInBatch); + + // Checksum the output data + checksumBatch(numInBatch, outProvider, checksum_dev.data(), stream); + + // Get prior checksum from the float headers + floatGetCompressedInfo( + inProvider, + numInBatch, + sizes_dev.data(), + nullptr, + archiveChecksum_dev.data(), + stream); + + // Compare against previously seen checksums on the host + auto sizes = sizes_dev.copyToHost(stream); + auto newChecksums = checksum_dev.copyToHost(stream); + auto oldChecksums = archiveChecksum_dev.copyToHost(stream); + + std::stringstream errStr; + + for (int i = 0; i < numInBatch; ++i) { + if (oldChecksums[i] != newChecksums[i]) { + status.error = FloatDecompressError::ChecksumMismatch; + + errStr << "Checksum mismatch in batch member " << i + << ": expected checksum " << std::hex << oldChecksums[i] + << " got " << newChecksums[i] << "\n"; + status.errorInfo.push_back(std::make_pair(i, errStr.str())); + } + } + } + + CUDA_TEST_ERROR(); + + return status; +} + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/dietgpu/float/GpuFloatInfo.cu b/thirdparty/dietgpu/dietgpu/float/GpuFloatInfo.cu new file mode 100644 index 000000000..4abe12708 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/float/GpuFloatInfo.cu @@ -0,0 +1,66 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "dietgpu/ans/BatchProvider.cuh" +#include "dietgpu/float/GpuFloatCodec.h" +#include "dietgpu/float/GpuFloatInfo.cuh" +#include "dietgpu/utils/DeviceUtils.h" +#include "dietgpu/utils/StackDeviceMemory.h" +#include "dietgpu/utils/StaticUtils.h" + +namespace dietgpu { + +void floatGetCompressedInfo( + StackDeviceMemory& res, + const void** in, + uint32_t numInBatch, + uint32_t* outSizes_dev, + uint32_t* outTypes_dev, + uint32_t* outChecksum_dev, + cudaStream_t stream) { + if (!outSizes_dev && !outTypes_dev && !outChecksum_dev) { + return; + } + + auto in_dev = res.copyAlloc(stream, in, numInBatch); + + floatGetCompressedInfoDevice( + res, + in_dev.data(), + numInBatch, + outSizes_dev, + outTypes_dev, + outChecksum_dev, + stream); + + CUDA_TEST_ERROR(); +} + +void floatGetCompressedInfoDevice( + StackDeviceMemory& res, + const void** in_dev, + uint32_t numInBatch, + uint32_t* outSizes_dev, + uint32_t* outTypes_dev, + uint32_t* outChecksum_dev, + cudaStream_t stream) { + if (!outSizes_dev && !outTypes_dev && !outChecksum_dev) { + return; + } + + auto inProvider = BatchProviderPointer((void**)in_dev); + + floatGetCompressedInfo( + inProvider, + numInBatch, + outSizes_dev, + outTypes_dev, + outChecksum_dev, + stream); +} + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/dietgpu/float/GpuFloatInfo.cuh b/thirdparty/dietgpu/dietgpu/float/GpuFloatInfo.cuh new file mode 100644 index 000000000..3a4503847 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/float/GpuFloatInfo.cuh @@ -0,0 +1,64 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include "dietgpu/float/GpuFloatCodec.h" +#include "dietgpu/float/GpuFloatUtils.cuh" +#include "dietgpu/utils/DeviceUtils.h" +#include "dietgpu/utils/StackDeviceMemory.h" +#include "dietgpu/utils/StaticUtils.h" + +namespace dietgpu { + +template +__global__ void floatGetCompressedInfoKernel( + InProvider inProvider, + uint32_t numInBatch, + uint32_t* outSizes, + uint32_t* outTypes, + uint32_t* outChecksum) { + auto batch = blockIdx.x * blockDim.x + threadIdx.x; + if (batch < numInBatch) { + auto header = (const GpuFloatHeader*)inProvider.getBatchStart(batch); + header->checkMagicAndVersion(); + + if (outSizes) { + outSizes[batch] = header->size; + } + if (outTypes) { + outTypes[batch] = uint32_t(header->getFloatType()); + } + if (outChecksum) { + assert(header->getUseChecksum()); + outChecksum[batch] = header->getChecksum(); + } + } +} + +template +void floatGetCompressedInfo( + InProvider& inProvider, + uint32_t numInBatch, + uint32_t* outSizes_dev, + uint32_t* outTypes_dev, + uint32_t* outChecksum_dev, + cudaStream_t stream) { + if (!outSizes_dev && !outTypes_dev && !outTypes_dev) { + return; + } + + auto block = 128; + auto grid = divUp(numInBatch, block); + + floatGetCompressedInfoKernel<<>>( + inProvider, numInBatch, outSizes_dev, outTypes_dev, outChecksum_dev); + + CUDA_TEST_ERROR(); +} + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/dietgpu/float/GpuFloatUtils.cuh b/thirdparty/dietgpu/dietgpu/float/GpuFloatUtils.cuh new file mode 100644 index 000000000..cf98ea409 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/float/GpuFloatUtils.cuh @@ -0,0 +1,227 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include "dietgpu/utils/DeviceDefs.cuh" +#include "dietgpu/utils/PtxUtils.cuh" +#include "dietgpu/utils/StaticUtils.h" + +#include +#include + +namespace dietgpu { + +// magic number to verify archive integrity +constexpr uint32_t kFloatMagic = 0xf00f; + +// current DietGPU version number +constexpr uint32_t kFloatVersion = 0x0001; + +// Header on our compressed floating point data +struct __align__(16) GpuFloatHeader { + __host__ __device__ void setMagicAndVersion() { + magicAndVersion = (kFloatMagic << 16) | kFloatVersion; + } + + __host__ __device__ void checkMagicAndVersion() const { + assert((magicAndVersion >> 16) == kFloatMagic); + assert((magicAndVersion & 0xffffU) == kFloatVersion); + } + + __host__ __device__ FloatType getFloatType() const { + return FloatType(options & 0xf); + } + + __host__ __device__ void setFloatType(FloatType ft) { + assert(uint32_t(ft) <= 0xf); + options = (options & 0xfffffff0U) | uint32_t(ft); + } + + __host__ __device__ bool getUseChecksum() const { + return options & 0x10; + } + + __host__ __device__ void setUseChecksum(bool uc) { + options = (options & 0xffffffef) | (uint32_t(uc) << 4); + } + + __host__ __device__ uint32_t getChecksum() const { + return checksum; + } + + __host__ __device__ void setChecksum(uint32_t c) { + checksum = c; + } + + // (16: magic)(16: version) + uint32_t magicAndVersion; + + // Number of floating point words of the given float type in the archive + uint32_t size; + + // (27: unused)(1: use checksum)(4: float type) + uint32_t options; + + // Optional checksum computed on the input data + uint32_t checksum; +}; + +static_assert(sizeof(GpuFloatHeader) == 16, ""); + +struct __align__(16) uint32x4 { + uint32_t x[4]; +}; + +struct __align__(16) uint16x8 { + uint16_t x[8]; +}; + +struct __align__(8) uint16x4 { + uint16_t x[4]; +}; + +struct __align__(8) uint8x8 { + uint8_t x[8]; +}; + +struct __align__(4) uint8x4 { + uint8_t x[4]; +}; + +// Convert FloatType to word size/type +template +struct FloatTypeInfo; + +template <> +struct FloatTypeInfo { + using WordT = uint16_t; + using CompT = uint8_t; + using NonCompT = uint8_t; + + // 16 byte vector type + using VecT = uint16x8; + using CompVecT = uint8x8; + using NonCompVecT = uint8x8; + + static __device__ void split(WordT in, CompT& comp, NonCompT& nonComp) { + // don't bother extracting the specific exponent + comp = in >> 8; + nonComp = in & 0xff; + } + + static __device__ WordT join(CompT comp, NonCompT nonComp) { + return WordT(comp) * WordT(256) + WordT(nonComp); + } + + // How many bytes of data are in the non-compressed portion past the float + // header? + static __host__ __device__ uint32_t getUncompDataSize(uint32_t size) { + // The size of the uncompressed data is always a multiple of 16 bytes, to + // guarantee alignment for proceeding data segments + return roundUp(size, 16 / sizeof(NonCompT)); + } +}; + +template <> +struct FloatTypeInfo { + using WordT = uint16_t; + using CompT = uint8_t; + using NonCompT = uint8_t; + + // 16 byte vector type + using VecT = uint16x8; + using CompVecT = uint8x8; + using NonCompVecT = uint8x8; + + static __device__ void split(WordT in, CompT& comp, NonCompT& nonComp) { + uint32_t v = uint32_t(in) * 65536U + uint32_t(in); + + v = rotateLeft(v, 1); + comp = v >> 24; + nonComp = v & 0xff; + } + + static __device__ WordT join(CompT comp, NonCompT nonComp) { + uint32_t lo = uint32_t(comp) * 256U + uint32_t(nonComp); + lo <<= 16; + uint32_t hi = nonComp; + + uint32_t out; + +#if defined(__HIP_PLATFORM_AMD__) + out = (lo >> 1) | (hi << 31); + // Emulate funnel shift right: concatenate lo:hi and shift right by 1 + // out = static_cast(((uint64_t(lo) << 32) | uint64_t(hi)) >> 1); +#else + asm("shf.r.clamp.b32 %0, %1, %2, %3;" + : "=r"(out) + : "r"(lo), "r"(hi), "r"(1)); +#endif + + return out >>= 16; + } + + // How many bytes of data are in the non-compressed portion past the float + // header? + static __host__ __device__ uint32_t getUncompDataSize(uint32_t size) { + // The size of the uncompressed data is always a multiple of 16 bytes, to + // guarantee alignment for proceeding data segments + return roundUp(size, 16 / sizeof(NonCompT)); + } +}; + +template <> +struct FloatTypeInfo { + using WordT = uint32_t; + using CompT = uint8_t; + using NonCompT = uint32_t; + + // 16 byte vector type + using VecT = uint32x4; + using CompVecT = uint8x4; + using NonCompVecT = uint32x4; + + static __device__ void split(WordT in, CompT& comp, NonCompT& nonComp) { + auto v = rotateLeft(in, 1); + comp = v >> 24; + nonComp = v & 0xffffffU; + } + + static __device__ WordT join(CompT comp, NonCompT nonComp) { + uint32_t v = (uint32_t(comp) * 16777216U) + uint32_t(nonComp); + return rotateRight(v, 1); + } + + // How many bytes of data are in the non-compressed portion past the float + // header? + static __host__ __device__ uint32_t getUncompDataSize(uint32_t size) { + // The size of the uncompressed data is always a multiple of 16 bytes, to + // guarantee alignment for proceeding data segments + // We store the low order 2 bytes first, then the high order uncompressed + // byte afterwards. + // Both sections should be 16 byte aligned + return 2 * roundUp(size, 8) + // low order 2 bytes + roundUp(size, 16); // high order 1 byte, starting at an aligned address + // after the low 2 byte segment + } +}; + +inline size_t getWordSizeFromFloatType(FloatType ft) { + switch (ft) { + case FloatType::kFloat16: + case FloatType::kBFloat16: + return sizeof(uint16_t); + case FloatType::kFloat32: + return sizeof(uint32_t); + default: + CHECK(false); + return 0; + } +} + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/dietgpu/float_test.py b/thirdparty/dietgpu/dietgpu/float_test.py new file mode 100644 index 000000000..650c8afa9 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/float_test.py @@ -0,0 +1,178 @@ +# Copyright (c) (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +import random +import unittest + +import torch + +torch.ops.load_library("//dietgpu:dietgpu") + + +def run_test(dev, ts, temp_mem=None): + comp, sizes, _ = torch.ops.dietgpu.compress_data(True, ts, True, temp_mem) + for s, t in zip(sizes, ts): + t_bytes = t.numel() * t.element_size() + print(f"{t_bytes} bytes -> {s.item()} bytes ({s.item() / t_bytes}x)") + + # Truncate the output data to exactly the sizes that are used + # (i.e., we are testing that the byte sizes we report in compression are accurate) + truncated_comp = [] + for size, t in zip(sizes, [*comp]): + truncated_t = t.narrow(0, 0, size.item()).clone() + truncated_comp.append(truncated_t) + + out_ts = [] + for t in ts: + out_ts.append(torch.empty(t.size(), dtype=t.dtype, device=t.device)) + + if temp_mem is not None: + out_status = torch.empty([len(ts)], dtype=torch.uint8, device=dev) + out_sizes = torch.empty([len(ts)], dtype=torch.int32, device=dev) + + torch.ops.dietgpu.decompress_data( + True, truncated_comp, out_ts, True, temp_mem, out_status, out_sizes + ) + + for t, status, size in zip(ts, out_status, out_sizes): + assert status.item() + assert t.numel() == size.item() + else: + torch.ops.dietgpu.decompress_data(True, truncated_comp, out_ts, True) + + for a, b in zip(ts, out_ts): + assert torch.equal(a, b) + + +class TestFloatCodec(unittest.TestCase): + def test_codec(self): + dev = torch.device("cuda:0") + temp_mem = torch.empty([64 * 1024 * 1024], dtype=torch.uint8, device=dev) + + for dt in [torch.bfloat16, torch.float16, torch.float32]: + for tm in [False, True]: + ts = [ + torch.normal(0, 1.0, [i], dtype=dt, device=dev) + for i in [10000, 100000, 1000000] + ] + if tm: + run_test(dev, ts, temp_mem) + else: + run_test(dev, ts) + + def test_large(self): + dev = torch.device("cuda:0") + temp_mem = torch.empty([64 * 1024 * 1024], dtype=torch.uint8, device=dev) + + for dt in [torch.bfloat16, torch.float16, torch.float32]: + for tm in [False, True]: + ts = [torch.normal(0, 1.0, [123456789], dtype=dt, device=dev)] + if tm: + run_test(dev, ts, temp_mem) + else: + run_test(dev, ts) + + def test_simple(self): + dev = torch.device("cuda:0") + for dt in [torch.bfloat16, torch.float16, torch.float32]: + ts = [ + torch.normal(0, 1.0, [i], dtype=dt, device=dev) + for i in [10000, 100000, 1000000] + ] + + cts = torch.ops.dietgpu.compress_data_simple(True, ts, True) + for before, after in zip(ts, cts): + # We should actually be compressing data + assert ( + before.numel() * before.element_size() + > after.numel() * after.element_size() + ) + + dts = torch.ops.dietgpu.decompress_data_simple(True, cts, True) + for orig, after in zip(ts, dts): + assert torch.equal(orig, after) + + def test_empty(self): + dev = torch.device("cuda:0") + for dt in [torch.bfloat16, torch.float16, torch.float32]: + ts = [torch.empty([0], dtype=dt, device=dev)] + comp_ts = torch.ops.dietgpu.compress_data_simple(True, ts, True) + + # should have a header + assert comp_ts[0].numel() > 0 + + decomp_ts = torch.ops.dietgpu.decompress_data_simple(True, comp_ts, True) + assert torch.equal(ts[0], decomp_ts[0]) + + def test_split_compress(self): + dev = torch.device("cuda:0") + temp_mem = torch.empty([64 * 1024 * 1024], dtype=torch.uint8, device=dev) + + for dt in [torch.bfloat16, torch.float16, torch.float32]: + for align16 in [True, False]: + for tries in range(5): + batch_size = random.randrange(1, 15) + sizes = [] + + sum_sizes = 0 + max_size = 0 + for i in range(batch_size): + size = random.randrange(1, 10000) + if align16: + # 2 bytes per float, ensure 16 byte alignment + size *= 8 + + sizes.append(size) + sum_sizes += size + if size > max_size: + max_size = size + + t = torch.normal(0, 1.0, [sum_sizes], dtype=dt, device=dev) + sizes_t = torch.IntTensor(sizes) + splits = torch.split(t, sizes) + + comp_ts, _, _ = torch.ops.dietgpu.compress_data_split_size( + True, t, sizes_t, True, temp_mem + ) + decomp_ts = torch.ops.dietgpu.decompress_data_simple( + True, comp_ts, True + ) + + for orig, decomp in zip(splits, decomp_ts): + assert torch.equal(orig, decomp) + + def test_split_decompress(self): + dev = torch.device("cuda:0") + temp_mem = torch.empty([64 * 1024 * 1024], dtype=torch.uint8, device=dev) + + for dt in [torch.bfloat16, torch.float16, torch.float32]: + for align16 in [True, False]: + for tries in range(5): + batch_size = random.randrange(1, 15) + sizes = [] + + sum_sizes = 0 + for i in range(batch_size): + size = random.randrange(1, 10000) + if align16: + # 2 bytes per float, ensure 16 byte alignment + size *= 8 + + sizes.append(size) + sum_sizes += size + + t = torch.normal(0, 1.0, [sum_sizes], dtype=dt, device=dev) + sizes_t = torch.IntTensor(sizes) + + splits = torch.split(t, sizes) + comp_ts = torch.ops.dietgpu.compress_data_simple(True, splits, True) + + decomp_t = torch.empty([sum_sizes], dtype=dt, device=dev) + torch.ops.dietgpu.decompress_data_split_size( + True, comp_ts, decomp_t, sizes_t, True, temp_mem + ) + + assert torch.equal(t, decomp_t) diff --git a/thirdparty/dietgpu/dietgpu/test/ANSStatisticsTest.cu b/thirdparty/dietgpu/dietgpu/test/ANSStatisticsTest.cu new file mode 100644 index 000000000..54b275d8f --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/test/ANSStatisticsTest.cu @@ -0,0 +1,207 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#include "dietgpu/ans/GpuANSStatistics.cuh" +#include "dietgpu/ans/GpuANSUtils.cuh" +#include "dietgpu/utils/StackDeviceMemory.h" + +using namespace dietgpu; + +std::vector generateSymbols(int num, float lambda = 20.0f) { + std::random_device rd; + std::mt19937 gen(10); + std::exponential_distribution dist(lambda); + + auto out = std::vector(num); + for (auto& v : out) { + auto sample = std::min(dist(gen), 1.0f); + + v = sample * 256.0; + } + + return out; +} + +std::vector histogram(const std::vector& data) { + auto counts = std::vector(256); + + for (auto v : data) { + counts[v]++; + } + + return counts; +} + +TEST(ANSStatisticsTest, Histogram) { + auto res = makeStackMemory(); + // run on a different stream to test stream assignment + auto stream = CudaStream::makeNonBlocking(); + + for (auto size : + {1, + 2, + 11, + 32, + 55, + 1000, + 1001, + 1000000, + 1024 * 1024, + 1000001, + 12345677}) { + int numInBatch = 3; + + auto data = std::vector(); + + auto histograms = std::vector>(); + + int stride = 11; + auto strideData = std::vector(stride); + + for (int b = 0; b < numInBatch; ++b) { + auto gen = generateSymbols(size, 20.0 + b * 2); + histograms.push_back(histogram(gen)); + + data.insert(data.end(), gen.begin(), gen.end()); + + // Add some stride padding + data.insert(data.end(), strideData.begin(), strideData.end()); + } + + auto data_dev = res.copyAlloc(stream, data); + auto hist_dev = res.alloc(stream, numInBatch * kNumSymbols); + + auto inProvider = BatchProviderStride(data_dev.data(), size + stride, size); + + ansHistogramBatch(numInBatch, inProvider, hist_dev.data(), stream); + + auto hist_host = hist_dev.copyToHost(stream); + + for (int b = 0; b < numInBatch; ++b) { + for (int i = 0; i < kNumSymbols; ++i) { + EXPECT_EQ(histograms[b][i], hist_host[b * kNumSymbols + i]); + } + } + } +} + +std::vector dataToANSTable( + const std::vector& data, + int probBits = 10) { + auto res = makeStackMemory(); + // run on a different stream to test stream assignment + auto stream = CudaStream::makeNonBlocking(); + + auto data_dev = res.copyAlloc(stream, data); + + // Get histogram + auto hist_dev = res.alloc(stream, kNumSymbols); + auto inProvider = + BatchProviderStride(data_dev.data(), data.size(), data.size()); + + ansHistogramBatch(1, inProvider, hist_dev.data(), stream); + + // Get ANS table from histogram (post-normalization) + auto table_dev = res.alloc(stream, kNumSymbols); + + ansCalcWeights( + 1, + probBits, + BatchProviderStride(hist_dev.data(), data.size(), data.size()), + hist_dev.data(), + table_dev.data(), + stream); + + return table_dev.copyToHost(stream); +} + +TEST(ANSStatisticsTest, Normalization_NonZero) { + // Ensure that non-zero count symbols get non-zero weight + auto data = std::vector(10000); + + for (int i = 0; i < 256; ++i) { + data[i] = uint8_t(i); + } + + for (int i = 256; i < data.size(); ++i) { + data[i] = 1; + } + + int probBits = 10; + auto table = dataToANSTable(data, probBits); + + for (int i = 0; i < kNumSymbols; ++i) { + if (i != 1) { + EXPECT_EQ(table[i].x, 1); + } else { + EXPECT_EQ(table[i].x, (1 << probBits) - 255); + } + } +} + +TEST(ANSStatisticsTest, Normalization_EqualWeight) { + // Ensure that non-zero count symbols get non-zero weight + auto data = std::vector(kNumSymbols * 64); + + for (int i = 0; i < 64; ++i) { + for (int j = 0; j < kNumSymbols; ++j) { + data[i * kNumSymbols + j] = uint8_t(j); + } + } + + int probBits = 10; + auto table = dataToANSTable(data, probBits); + + for (int i = 0; i < kNumSymbols; ++i) { + EXPECT_EQ(table[i].x, (1 << probBits) / kNumSymbols); + } +} + +TEST(ANSStatisticsTest, Normalization) { + auto data = generateSymbols(12345, 40.0f); + + // Count true distribution + auto hist = histogram(data); + + int probBits = 11; + auto table = dataToANSTable(data, probBits); + + uint32_t totalSum = 0; + uint32_t totalWeight = 1 << probBits; + + for (int i = 0; i < kNumSymbols; ++i) { + auto count = hist[i]; + auto pdf = table[i].x; + + totalSum += pdf; + + if (count == 0) { + EXPECT_EQ(pdf, 0) << "failed on " << i; + } else if (count > 0) { + EXPECT_GT(pdf, 0); + // The normalized prob should be within some small factor of the real + // count + float prob = float(count) / float(data.size()); + float normalizedProb = float(pdf) / float(totalWeight); + + EXPECT_GE(normalizedProb, prob * 0.5f); + + // Only relevant if the prob is > 1/totalWeight (i.e., we + // weren't rounded up to be a non-zero weight) + if (prob > 1.0f / float(totalWeight)) { + EXPECT_LE(normalizedProb, prob * 2.0f); + } + } + } + + EXPECT_EQ(totalSum, totalWeight); +} diff --git a/thirdparty/dietgpu/dietgpu/test/ANSTest.cu b/thirdparty/dietgpu/dietgpu/test/ANSTest.cu new file mode 100644 index 000000000..5d4458dfd --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/test/ANSTest.cu @@ -0,0 +1,282 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#include "dietgpu/ans/GpuANSCodec.h" +#include "dietgpu/utils/StackDeviceMemory.h" + +using namespace dietgpu; + +std::vector generateSymbols(int num, float lambda = 20.0f) { + std::random_device rd; + std::mt19937 gen(10); + std::exponential_distribution dist(lambda); + + auto out = std::vector(num); + for (auto& v : out) { + auto sample = std::min(dist(gen), 1.0f); + + v = sample * 256.0; + } + + return out; +} + +std::vector> toDevice( + StackDeviceMemory& res, + const std::vector>& vs, + cudaStream_t stream) { + auto out = std::vector>(); + + for (auto& v : vs) { + out.emplace_back(res.copyAlloc(stream, v, AllocType::Permanent)); + } + + return out; +} + +std::vector> toHost( + StackDeviceMemory& res, + const std::vector>& vs, + cudaStream_t stream) { + auto out = std::vector>(); + + for (auto& v : vs) { + out.emplace_back(v.copyToHost(stream)); + } + + return out; +} + +std::vector> buffersToDevice( + StackDeviceMemory& res, + const std::vector& sizes, + cudaStream_t stream) { + auto out = std::vector>(); + + for (auto& s : sizes) { + out.emplace_back(res.alloc(stream, s, AllocType::Permanent)); + } + + return out; +} + +std::vector> genBatch( + const std::vector& sizes, + double lambda) { + auto out = std::vector>(); + + for (auto s : sizes) { + out.push_back(generateSymbols(s, lambda)); + } + + return out; +} + +void runBatchPointer( + StackDeviceMemory& res, + int prec, + const std::vector& batchSizes, + double lambda = 100.0) { + // run on a different stream to test stream assignment + auto stream = CudaStream::makeNonBlocking(); + + int numInBatch = batchSizes.size(); + uint32_t maxSize = 0; + for (auto v : batchSizes) { + maxSize = std::max(maxSize, v); + } + + auto outBatchStride = getMaxCompressedSize(maxSize); + + auto batch_host = genBatch(batchSizes, lambda); + auto batch_dev = toDevice(res, batch_host, stream); + + auto inPtrs = std::vector(batchSizes.size()); + { + for (int i = 0; i < inPtrs.size(); ++i) { + inPtrs[i] = batch_dev[i].data(); + } + } + + auto enc_dev = res.alloc(stream, numInBatch * outBatchStride); + + auto encPtrs = std::vector(batchSizes.size()); + for (int i = 0; i < inPtrs.size(); ++i) { + encPtrs[i] = (uint8_t*)enc_dev.data() + i * outBatchStride; + } + + auto outCompressedSize_dev = res.alloc(stream, numInBatch); + + ansEncodeBatchPointer( + res, + ANSCodecConfig(prec, true), + numInBatch, + inPtrs.data(), + batchSizes.data(), + nullptr, + encPtrs.data(), + outCompressedSize_dev.data(), + stream); + + auto encSize = outCompressedSize_dev.copyToHost(stream); + for (auto v : encSize) { + // Reported compressed sizes in bytes should be a multiple of 16 for aligned + // packing + EXPECT_EQ(v % 16, 0); + } + + // Decode data + auto dec_dev = buffersToDevice(res, batchSizes, stream); + + auto decPtrs = std::vector(batchSizes.size()); + for (int i = 0; i < inPtrs.size(); ++i) { + decPtrs[i] = dec_dev[i].data(); + } + + auto outSuccess_dev = res.alloc(stream, numInBatch); + auto outSize_dev = res.alloc(stream, numInBatch); + + ansDecodeBatchPointer( + res, + ANSCodecConfig(prec, true), + numInBatch, + (const void**)encPtrs.data(), + decPtrs.data(), + batchSizes.data(), + outSuccess_dev.data(), + outSize_dev.data(), + stream); + + auto outSuccess = outSuccess_dev.copyToHost(stream); + auto outSize = outSize_dev.copyToHost(stream); + + for (int i = 0; i < outSuccess.size(); ++i) { + EXPECT_TRUE(outSuccess[i]); + EXPECT_EQ(outSize[i], batchSizes[i]); + } + + auto dec_host = toHost(res, dec_dev, stream); + EXPECT_EQ(batch_host, dec_host); +} + +void runBatchStride( + StackDeviceMemory& res, + int prec, + int numInBatch, + int inBatchSize, + double lambda = 100.0) { + // run on a different stream to test stream assignment + auto stream = CudaStream::makeNonBlocking(); + + auto orig = generateSymbols(numInBatch * inBatchSize, lambda); + auto orig_dev = res.copyAlloc(stream, orig); + + int outBatchStride = getMaxCompressedSize(inBatchSize); + + auto enc_dev = res.alloc(stream, numInBatch * outBatchStride); + + auto outCompressedSize_dev = res.alloc(stream, numInBatch); + + ansEncodeBatchStride( + res, + ANSCodecConfig(prec, true), + numInBatch, + orig_dev.data(), + inBatchSize, + inBatchSize, + nullptr, + enc_dev.data(), + outBatchStride, + outCompressedSize_dev.data(), + stream); + + auto encSize = outCompressedSize_dev.copyToHost(stream); + for (auto v : encSize) { + // Reported compressed sizes in bytes should be a multiple of 16 for aligned + // packing + EXPECT_EQ(v % 16, 0); + } + + auto dec_dev = res.alloc(stream, numInBatch * inBatchSize); + auto outSuccess_dev = res.alloc(stream, numInBatch); + auto outSize_dev = res.alloc(stream, numInBatch); + + // FIXME: Copy the compressed data to the host and truncate it to make + // sure the compressed size is accurate + ansDecodeBatchStride( + res, + ANSCodecConfig(prec, true), + numInBatch, + enc_dev.data(), + outBatchStride, + dec_dev.data(), + inBatchSize, + inBatchSize, + outSuccess_dev.data(), + outSize_dev.data(), + stream); + + auto outSuccess = outSuccess_dev.copyToHost(stream); + auto outSize = outSize_dev.copyToHost(stream); + + for (auto s : outSuccess) { + EXPECT_TRUE(s); + } + + for (auto s : outSize) { + EXPECT_EQ(s, inBatchSize); + } + + auto dec = dec_dev.copyToHost(stream); + EXPECT_EQ(orig, dec); +} + +TEST(ANSTest, ZeroSized) { + auto res = makeStackMemory(); + runBatchPointer(res, 10, {0}, 10.0); +} + +TEST(ANSTest, BatchPointer) { + auto res = makeStackMemory(); + + for (auto prec : {9, 10, 11}) { + for (auto lambda : {1.0, 10.0, 100.0, 1000.0}) { + runBatchPointer(res, prec, {1}, lambda); + runBatchPointer(res, prec, {1, 1}, lambda); + runBatchPointer(res, prec, {4096, 4095, 4096}, lambda); + runBatchPointer(res, prec, {1234, 2345, 3456}, lambda); + runBatchPointer(res, prec, {10000, 10013, 10000}, lambda); + } + } +} + +TEST(ANSTest, BatchPointerLarge) { + auto res = makeStackMemory(); + + std::random_device rd; + std::mt19937 gen(10); + std::uniform_int_distribution dist(100, 10000); + + std::vector sizes; + for (int i = 0; i < 100; ++i) { + sizes.push_back(dist(gen)); + } + + runBatchPointer(res, 10, sizes); +} + +TEST(ANSTest, BatchStride) { + auto res = makeStackMemory(); + + // FIXME: 16 byte alignment required + runBatchStride(res, 10, 13, 8192 + 16); +} diff --git a/thirdparty/dietgpu/dietgpu/test/BatchPrefixSumTest.cu b/thirdparty/dietgpu/dietgpu/test/BatchPrefixSumTest.cu new file mode 100644 index 000000000..3551b1519 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/test/BatchPrefixSumTest.cu @@ -0,0 +1,169 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#include "dietgpu/ans/BatchPrefixSum.cuh" +#include "dietgpu/utils/StackDeviceMemory.h" + +using namespace dietgpu; + +std::vector +makeSequence(uint32_t numInBatch, uint32_t batchSize, int seed = 10) { + auto gen = std::mt19937(10); + auto dist = std::uniform_int_distribution(0, 20); + + auto out = std::vector(numInBatch * batchSize); + + for (auto& v : out) { + v = dist(gen); + } + + return out; +} + +std::vector exclusivePrefixSum( + const std::vector& in, + uint32_t numInBatch, + uint32_t batchSize) { + auto out = std::vector(numInBatch * batchSize); + + for (uint32_t b = 0; b < numInBatch; ++b) { + uint32_t sum = 0; + for (uint32_t i = 0; i < batchSize; ++i) { + auto v = in[b * batchSize + i]; + out[b * batchSize + i] = sum; + sum += v; + } + } + + return out; +} + +TEST(BatchPrefixSum, OneLevel) { + auto res = makeStackMemory(); + // run on a different stream to test stream assignment + auto stream = CudaStream::makeNonBlocking(); + + auto gen = std::mt19937(10); + auto nbDist = std::uniform_int_distribution(1, 20); + + for (auto batchSize : {1, 10, 32, 33, 64, 65, 128, 129, 256, 257, 512}) { + auto numInBatch = nbDist(gen); + + auto data = makeSequence(numInBatch, batchSize, nbDist(gen)); + auto dataPrefix = exclusivePrefixSum(data, numInBatch, batchSize); + + auto data_dev = res.copyAlloc(stream, data); + + auto tempSize = getBatchExclusivePrefixSumTempSize(numInBatch, batchSize); + EXPECT_EQ(tempSize, 0); + + auto prefix_dev = res.alloc(stream, numInBatch * batchSize); + + batchExclusivePrefixSum>( + data_dev.data(), + prefix_dev.data(), + nullptr, + numInBatch, + batchSize, + NoTransform(), + stream); + + auto gpuDataPrefix = prefix_dev.copyToHost(stream); + EXPECT_EQ(dataPrefix, gpuDataPrefix); + } +} + +TEST(BatchPrefixSum, TwoLevel) { + auto res = makeStackMemory(); + // run on a different stream to test stream assignment + auto stream = CudaStream::makeNonBlocking(); + + auto batchSizes = std::vector{ + 513, 1024, 2047, 2048, 4096, 4097, 10000, 100000, 512 * 512}; + + auto gen = std::mt19937(10); + auto nbDist = std::uniform_int_distribution(1, 20); + auto bsDist = std::uniform_int_distribution(513, 512 * 512); + + for (int i = 0; i < 10; ++i) { + batchSizes.push_back(bsDist(gen)); + } + + for (auto batchSize : batchSizes) { + auto numInBatch = nbDist(gen); + + auto data = makeSequence(numInBatch, batchSize, nbDist(gen)); + auto dataPrefix = exclusivePrefixSum(data, numInBatch, batchSize); + + auto data_dev = res.copyAlloc(stream, data); + + auto tempSize = getBatchExclusivePrefixSumTempSize(numInBatch, batchSize); + EXPECT_GT(tempSize, 0); + + auto prefix_dev = res.alloc(stream, numInBatch * batchSize); + auto temp_dev = res.alloc(stream, tempSize); + + batchExclusivePrefixSum>( + data_dev.data(), + prefix_dev.data(), + temp_dev.data(), + numInBatch, + batchSize, + NoTransform(), + stream); + + auto gpuDataPrefix = prefix_dev.copyToHost(stream); + EXPECT_EQ(dataPrefix, gpuDataPrefix); + } +} + +// TEST(BatchPrefixSum, Perf) { +// StandardGpuResources res; +// auto stream = res.getDefaultStreamCurrentDevice(); + +// int numInBatch = 128; +// int batchSize = 4000; + +// auto data = makeSequence(numInBatch, batchSize); +// auto dataPrefix = exclusivePrefixSum(data, numInBatch, batchSize); + +// auto data_dev = toDeviceNonTemporary(&res, data, stream); + +// auto tempSize = getBatchExclusivePrefixSumTempSize(numInBatch, batchSize); +// EXPECT_GT(tempSize, 0); + +// auto prefix_dev = DeviceTensor( +// &res, makeDevAlloc(AllocType::Other, stream), +// {(int) (numInBatch * batchSize)}); + +// auto temp_dev = DeviceTensor( +// &res, makeDevAlloc(AllocType::Other, stream), +// {(int) tempSize}); + +// batchExclusivePrefixSum(data_dev.data(), +// prefix_dev.data(), +// temp_dev.data(), +// numInBatch, +// batchSize, +// stream); + +// auto gpuDataPrefix = prefix_dev.copyToVector(stream); + +// for (int i = 0; i < dataPrefix.size(); ++i) { +// if (dataPrefix[i] != gpuDataPrefix[i]) { +// printf("mismatch on %d: %u %u\n", i, dataPrefix[i], gpuDataPrefix[i]); +// break; +// } +// } + +// // EXPECT_EQ(dataPrefix, gpuDataPrefix); +// } diff --git a/thirdparty/dietgpu/dietgpu/test/FloatTest.cu b/thirdparty/dietgpu/dietgpu/test/FloatTest.cu new file mode 100644 index 000000000..4f811b259 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/test/FloatTest.cu @@ -0,0 +1,311 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include + +#include "dietgpu/float/GpuFloatCodec.h" +#include "dietgpu/float/GpuFloatUtils.cuh" +#include "dietgpu/utils/StackDeviceMemory.h" + +using namespace dietgpu; + +uint16_t float32ToBFloat16(float f) { + // FIXME: does not round to nearest even + static_assert(sizeof(float) == sizeof(uint32_t), ""); + uint32_t x; + std::memcpy(&x, &f, sizeof(float)); + + x >>= 16; + return x; +} + +uint16_t float32ToFloat16(float f) { + static_assert(sizeof(float) == sizeof(uint32_t), ""); + uint32_t x; + std::memcpy(&x, &f, sizeof(float)); + + uint32_t u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1; + uint32_t sign, exponent, mantissa; + + // Get rid of +NaN/-NaN case first. + if (u > 0x7f800000U) { + return 0x7fffU; + } + + sign = ((x >> 16) & 0x8000); + + // Get rid of +Inf/-Inf, +0/-0. + if (u > 0x477fefffU) { + return sign | 0x7c00U; + } + if (u < 0x33000001U) { + return (sign | 0x0000); + } + + exponent = ((u >> 23) & 0xff); + mantissa = (u & 0x7fffff); + + if (exponent > 0x70) { + shift = 13; + exponent -= 0x70; + } else { + shift = 0x7e - exponent; + exponent = 0; + mantissa |= 0x800000; + } + lsb = (1 << shift); + lsb_s1 = (lsb >> 1); + lsb_m1 = (lsb - 1); + + // Round to nearest even. + remainder = (mantissa & lsb_m1); + mantissa >>= shift; + if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) { + ++mantissa; + if (!(mantissa & 0x3ff)) { + ++exponent; + mantissa = 0; + } + } + + return (sign | (exponent << 10) | mantissa); +} + +template +struct GenerateFloat; + +template <> +struct GenerateFloat { + static FloatTypeInfo::WordT gen(float v) { + return float32ToFloat16(v); + } +}; + +template <> +struct GenerateFloat { + static FloatTypeInfo::WordT gen(float v) { + return float32ToBFloat16(v); + } +}; + +template <> +struct GenerateFloat { + static FloatTypeInfo::WordT gen(float v) { + FloatTypeInfo::WordT out; + std::memcpy(&out, &v, sizeof(float)); + return out; + } +}; + +template +std::vector::WordT> generateFloats(int num) { + std::mt19937 gen(10 + num); + std::normal_distribution dist; + + auto out = std::vector::WordT>(num); + for (auto& v : out) { + v = GenerateFloat::gen(dist(gen)); + } + + return out; +} + +template +void runBatchPointerTest( + StackDeviceMemory& res, + int probBits, + const std::vector& batchSizes) { + using FTI = FloatTypeInfo; + + // run on a different stream to test stream assignment + auto stream = CudaStream::makeNonBlocking(); + + int numInBatch = batchSizes.size(); + uint32_t totalSize = 0; + uint32_t maxSize = 0; + for (auto v : batchSizes) { + totalSize += v; + maxSize = std::max(maxSize, v); + } + + auto maxCompressedSize = getMaxFloatCompressedSize(FT, maxSize); + + auto orig = generateFloats(totalSize); + auto orig_dev = res.copyAlloc(stream, orig); + + auto inPtrs = std::vector(batchSizes.size()); + { + uint32_t curOffset = 0; + for (int i = 0; i < inPtrs.size(); ++i) { + inPtrs[i] = (const typename FTI::WordT*)orig_dev.data() + curOffset; + curOffset += batchSizes[i]; + } + } + + auto enc_dev = res.alloc(stream, numInBatch * maxCompressedSize); + + auto encPtrs = std::vector(batchSizes.size()); + { + for (int i = 0; i < inPtrs.size(); ++i) { + encPtrs[i] = (uint8_t*)enc_dev.data() + i * maxCompressedSize; + } + } + + auto outBatchSize_dev = res.alloc(stream, numInBatch); + + auto compConfig = + FloatCompressConfig(FT, ANSCodecConfig(probBits), false, true); + + floatCompress( + res, + compConfig, + numInBatch, + inPtrs.data(), + batchSizes.data(), + encPtrs.data(), + outBatchSize_dev.data(), + stream); + + // Decode data + auto dec_dev = res.alloc(stream, totalSize); + + auto decPtrs = std::vector(batchSizes.size()); + { + uint32_t curOffset = 0; + for (int i = 0; i < inPtrs.size(); ++i) { + decPtrs[i] = (typename FTI::WordT*)dec_dev.data() + curOffset; + curOffset += batchSizes[i]; + } + } + + auto outSuccess_dev = res.alloc(stream, numInBatch); + auto outSize_dev = res.alloc(stream, numInBatch); + + auto decompConfig = + FloatDecompressConfig(FT, ANSCodecConfig(probBits), false, true); + + floatDecompress( + res, + decompConfig, + numInBatch, + (const void**)encPtrs.data(), + decPtrs.data(), + batchSizes.data(), + outSuccess_dev.data(), + outSize_dev.data(), + stream); + + auto outSuccess = outSuccess_dev.copyToHost(stream); + auto outSize = outSize_dev.copyToHost(stream); + + for (int i = 0; i < outSuccess.size(); ++i) { + EXPECT_TRUE(outSuccess[i]); + EXPECT_EQ(outSize[i], batchSizes[i]); + } + + auto dec = dec_dev.copyToHost(stream); + + for (int i = 0; i < orig.size(); ++i) { + if (orig[i] != dec[i]) { + printf( + "mismatch at %d / %d: 0x%08X 0x%08X\n", + i, + (int)orig.size(), + orig[i], + dec[i]); + break; + } + } + + EXPECT_EQ(orig, dec); +} + +void runBatchPointerTest( + StackDeviceMemory& res, + FloatType ft, + int probBits, + const std::vector& batchSizes) { + switch (ft) { + case FloatType::kFloat16: + runBatchPointerTest(res, probBits, batchSizes); + break; + case FloatType::kBFloat16: + runBatchPointerTest(res, probBits, batchSizes); + break; + case FloatType::kFloat32: + runBatchPointerTest(res, probBits, batchSizes); + break; + default: + CHECK(false); + break; + } +} + +void runBatchPointerTest( + StackDeviceMemory& res, + FloatType ft, + int probBits, + int numInBatch, + uint32_t multipleOf = 1) { + std::mt19937 gen(10 + numInBatch); + std::uniform_int_distribution dist(1, 10000); + + auto batchSizes = std::vector(numInBatch); + for (auto& v : batchSizes) { + v = roundUp(dist(gen), multipleOf); + } + + runBatchPointerTest(res, ft, probBits, batchSizes); +} + +TEST(FloatTest, Batch) { + auto res = makeStackMemory(); + + for (auto ft : + {FloatType::kFloat16, FloatType::kBFloat16, FloatType::kFloat32}) { + for (auto probBits : {9, 10}) { + for (auto numInBatch : {1, 3, 16, 23}) { + runBatchPointerTest(res, ft, probBits, numInBatch); + // Also test the case where there is uniform 16 byte alignment across + // all batches + runBatchPointerTest(res, ft, probBits, numInBatch, 16); + } + } + } +} + +TEST(FloatTest, LargeBatch) { + auto res = makeStackMemory(); + + auto batchSizes = std::vector(256); + for (auto& v : batchSizes) { + v = 512 * 1024; + } + + for (auto ft : + {FloatType::kFloat16, FloatType::kBFloat16, FloatType::kFloat32}) { + runBatchPointerTest(res, ft, 10, batchSizes); + } +} + +TEST(FloatTest, BatchSize1) { + auto res = makeStackMemory(); + + for (auto ft : + {FloatType::kFloat16, FloatType::kBFloat16, FloatType::kFloat32}) { + for (auto probBits : {9, 10}) { + runBatchPointerTest(res, ft, probBits, {1}); + runBatchPointerTest(res, ft, probBits, {13, 1}); + runBatchPointerTest(res, ft, probBits, {12345, 1, 8083, 1, 17}); + } + } +} diff --git a/thirdparty/dietgpu/dietgpu/utils/CMakeLists.txt b/thirdparty/dietgpu/dietgpu/utils/CMakeLists.txt new file mode 100644 index 000000000..86924f1d2 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/utils/CMakeLists.txt @@ -0,0 +1,22 @@ +add_library(dietgpu_utils SHARED + DeviceUtils.cpp + StackDeviceMemory.cpp +) + +target_include_directories(dietgpu_utils PUBLIC + $ + "${CUDA_INCLUDE_DIRS}" +) +target_link_libraries(dietgpu_utils PUBLIC + ${CUDA_LIBRARIES} + glog::glog +) +target_compile_options(dietgpu_utils PRIVATE $<$: + --generate-line-info + #--device-debug +>) + +get_property(GLOBAL_CUDA_ARCHITECTURES GLOBAL PROPERTY CUDA_ARCHITECTURES) +set_target_properties(dietgpu_utils PROPERTIES + CUDA_ARCHITECTURES "${GLOBAL_CUDA_ARCHITECTURES}" +) diff --git a/thirdparty/dietgpu/dietgpu/utils/DeviceDefs.cuh b/thirdparty/dietgpu/dietgpu/utils/DeviceDefs.cuh new file mode 100644 index 000000000..aff1c9d23 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/utils/DeviceDefs.cuh @@ -0,0 +1,19 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace dietgpu { +#if defined(__HIP_PLATFORM_AMD__) +constexpr int kWarpSize = 64; +#else +constexpr int kWarpSize = 32; +#endif + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/dietgpu/utils/DeviceUtils.cpp b/thirdparty/dietgpu/dietgpu/utils/DeviceUtils.cpp new file mode 100644 index 000000000..a133a2d53 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/utils/DeviceUtils.cpp @@ -0,0 +1,246 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "dietgpu/utils/DeviceUtils.h" +#include +#include +#include + +namespace dietgpu { + +std::string errorToString(cudaError_t err) { + return std::string(cudaGetErrorString(err)); +} + +std::string errorToName(cudaError_t err) { + return std::string(cudaGetErrorName(err)); +} + +int getCurrentDevice() { + int dev = -1; + CUDA_VERIFY(cudaGetDevice(&dev)); + CHECK_NE(dev, -1); + + return dev; +} + +void setCurrentDevice(int device) { + CUDA_VERIFY(cudaSetDevice(device)); +} + +int getNumDevices() { + int numDev = -1; + cudaError_t err = cudaGetDeviceCount(&numDev); + if (cudaErrorNoDevice == err) { + numDev = 0; + } else { + CUDA_VERIFY(err); + } + CHECK_NE(numDev, -1); + + return numDev; +} + +void profilerStart() { + CUDA_VERIFY(cudaProfilerStart()); +} + +void profilerStop() { + CUDA_VERIFY(cudaProfilerStop()); +} + +void synchronizeAllDevices() { + for (int i = 0; i < getNumDevices(); ++i) { + DeviceScope scope(i); + + CUDA_VERIFY(cudaDeviceSynchronize()); + } +} + +const cudaDeviceProp& getDeviceProperties(int device) { + static std::mutex mutex; + static std::unordered_map properties; + + std::lock_guard guard(mutex); + + auto it = properties.find(device); + if (it == properties.end()) { + cudaDeviceProp prop; + CUDA_VERIFY(cudaGetDeviceProperties(&prop, device)); + + properties[device] = prop; + it = properties.find(device); + } + + return it->second; +} + +const cudaDeviceProp& getCurrentDeviceProperties() { + return getDeviceProperties(getCurrentDevice()); +} + +int getMaxThreads(int device) { + return getDeviceProperties(device).maxThreadsPerBlock; +} + +int getMaxThreadsCurrentDevice() { + return getMaxThreads(getCurrentDevice()); +} + +size_t getMaxSharedMemPerBlock(int device) { + return getDeviceProperties(device).sharedMemPerBlock; +} + +size_t getMaxSharedMemPerBlockCurrentDevice() { + return getMaxSharedMemPerBlock(getCurrentDevice()); +} + +int getDeviceForAddress(const void* p) { + if (!p) { + return -1; + } + + cudaPointerAttributes att; + cudaError_t err = cudaPointerGetAttributes(&att, p); + CHECK(err == cudaSuccess || err == cudaErrorInvalidValue) + << "unknown error " << static_cast(err); + + if (err == cudaErrorInvalidValue) { + // Make sure the current thread error status has been reset + err = cudaGetLastError(); + CHECK_EQ(err, cudaErrorInvalidValue) + << "unknown error " << static_cast(err); + + return -1; + } + + // memoryType is deprecated for CUDA 10.0+ +#if defined(__HIP_PLATFORM_AMD__) + if (att.type == hipMemoryTypeHost) { + return -1; + } else { + return att.device; + } +#else +#if CUDA_VERSION < 10000 + if (att.memoryType == cudaMemoryTypeHost) { + return -1; + } else { + return att.device; + } +#else + // FIXME: what to use for managed memory? + if (att.type == cudaMemoryTypeDevice) { + return att.device; + } else { + return -1; + } +#endif +#endif +} + +bool getFullUnifiedMemSupport(int device) { + const auto& prop = getDeviceProperties(device); + return (prop.major >= 6); +} + +bool getFullUnifiedMemSupportCurrentDevice() { + return getFullUnifiedMemSupport(getCurrentDevice()); +} + +DeviceScope::DeviceScope(int device) { + if (device >= 0) { + int curDevice = getCurrentDevice(); + + if (curDevice != device) { + prevDevice_ = curDevice; + setCurrentDevice(device); + return; + } + } + + // Otherwise, we keep the current device + prevDevice_ = -1; +} + +DeviceScope::~DeviceScope() { + if (prevDevice_ != -1) { + setCurrentDevice(prevDevice_); + } +} + +CudaEvent::CudaEvent(cudaStream_t stream, bool timer) : event_(nullptr) { + CUDA_VERIFY(cudaEventCreateWithFlags( + &event_, timer ? cudaEventDefault : cudaEventDisableTiming)); + CUDA_VERIFY(cudaEventRecord(event_, stream)); +} + +CudaEvent::CudaEvent(CudaEvent&& event) noexcept + : event_(std::move(event.event_)) { + event.event_ = nullptr; +} + +CudaEvent::~CudaEvent() { + if (event_) { + CUDA_VERIFY(cudaEventDestroy(event_)); + } +} + +CudaEvent& CudaEvent::operator=(CudaEvent&& event) noexcept { + event_ = std::move(event.event_); + event.event_ = nullptr; + + return *this; +} + +void CudaEvent::streamWaitOnEvent(cudaStream_t stream) { + CUDA_VERIFY(cudaStreamWaitEvent(stream, event_, 0)); +} + +void CudaEvent::cpuWaitOnEvent() { + CUDA_VERIFY(cudaEventSynchronize(event_)); +} + +float CudaEvent::timeFrom(CudaEvent& from) { + cpuWaitOnEvent(); + float ms = 0; + CUDA_VERIFY(cudaEventElapsedTime(&ms, from.event_, event_)); + + return ms; +} + +CudaStream::CudaStream(int flags) : stream_(nullptr) { + CUDA_VERIFY(cudaStreamCreateWithFlags(&stream_, flags)); +} + +CudaStream::CudaStream(CudaStream&& stream) noexcept + : stream_(std::move(stream.stream_)) { + stream.stream_ = nullptr; +} + +CudaStream::~CudaStream() { + if (stream_) { + CUDA_VERIFY(cudaStreamDestroy(stream_)); + } +} + +CudaStream& CudaStream::operator=(CudaStream&& stream) noexcept { + stream_ = std::move(stream.stream_); + stream.stream_ = nullptr; + + return *this; +} + +CudaStream CudaStream::make() { + return CudaStream(); +} + +CudaStream CudaStream::makeNonBlocking() { + return CudaStream(cudaStreamNonBlocking); +} + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/dietgpu/utils/DeviceUtils.h b/thirdparty/dietgpu/dietgpu/utils/DeviceUtils.h new file mode 100644 index 000000000..4aa201be1 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/utils/DeviceUtils.h @@ -0,0 +1,218 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include + +/// Wrapper to test return status of CUDA functions +#define CUDA_VERIFY(X) \ + do { \ + auto err__ = (X); \ + CHECK_EQ(err__, cudaSuccess) \ + << "CUDA error " << dietgpu::errorToName(err__) << " " \ + << dietgpu::errorToString(err__); \ + } while (0) + +#define CURAND_VERIFY(X) \ + do { \ + auto err__ = (X); \ + CHECK_EQ(err__, CURAND_STATUS_SUCCESS) << "cuRAND error " << (int)err__; \ + } while (0) + +#ifdef __CUDA_ARCH__ +#define GPU_ASSERT(X) assert(X) +#else +#define GPU_ASSERT(X) CHECK(X) +#endif // __CUDA_ARCH__ + +/// Wrapper to synchronously probe for CUDA errors +// #define GPU_SYNC_ERROR 1 + +#ifdef GPU_SYNC_ERROR +#define CUDA_TEST_ERROR() \ + do { \ + CUDA_VERIFY(cudaDeviceSynchronize()); \ + } while (0) +#else +#define CUDA_TEST_ERROR() \ + do { \ + CUDA_VERIFY(cudaGetLastError()); \ + } while (0) +#endif + +namespace dietgpu { + +/// std::string wrapper around cudaGetErrorString +std::string errorToString(cudaError_t err); + +/// std::string wrapper around cudaGetErrorName +std::string errorToName(cudaError_t err); + +/// Returns the current thread-local GPU device +int getCurrentDevice(); + +/// Sets the current thread-local GPU device +void setCurrentDevice(int device); + +/// Returns the number of available GPU devices +int getNumDevices(); + +/// Starts the CUDA profiler (exposed via SWIG) +void profilerStart(); + +/// Stops the CUDA profiler (exposed via SWIG) +void profilerStop(); + +/// Synchronizes the CPU against all devices (equivalent to +/// cudaDeviceSynchronize for each device) +void synchronizeAllDevices(); + +/// Returns a cached cudaDeviceProp for the given device +const cudaDeviceProp& getDeviceProperties(int device); + +/// Returns the cached cudaDeviceProp for the current device +const cudaDeviceProp& getCurrentDeviceProperties(); + +/// Returns the maximum number of threads available for the given GPU +/// device +int getMaxThreads(int device); + +/// Equivalent to getMaxThreads(getCurrentDevice()) +int getMaxThreadsCurrentDevice(); + +/// Returns the maximum smem available for the given GPU device +size_t getMaxSharedMemPerBlock(int device); + +/// Equivalent to getMaxSharedMemPerBlock(getCurrentDevice()) +size_t getMaxSharedMemPerBlockCurrentDevice(); + +/// For a given pointer, returns whether or not it is located on +/// a device (deviceId >= 0) or the host (-1). +int getDeviceForAddress(const void* p); + +/// Does the given device support full unified memory sharing host +/// memory? +bool getFullUnifiedMemSupport(int device); + +/// Equivalent to getFullUnifiedMemSupport(getCurrentDevice()) +bool getFullUnifiedMemSupportCurrentDevice(); + +/// RAII object to set the current device, and restore the previous +/// device upon destruction +class DeviceScope { + public: + explicit DeviceScope(int device); + ~DeviceScope(); + + private: + int prevDevice_; +}; + +// RAII object to manage a cudaEvent_t +class CudaEvent { + public: + /// Creates an event and records it in this stream + explicit CudaEvent(cudaStream_t stream, bool timer = false); + CudaEvent(const CudaEvent& event) = delete; + CudaEvent(CudaEvent&& event) noexcept; + ~CudaEvent(); + + CudaEvent& operator=(CudaEvent&& event) noexcept; + CudaEvent& operator=(CudaEvent& event) = delete; + + inline cudaEvent_t get() { + return event_; + } + + /// Wait on this event in this stream + void streamWaitOnEvent(cudaStream_t stream); + + /// Have the CPU wait for the completion of this event + void cpuWaitOnEvent(); + + /// Returns the elapsed time from the other event + float timeFrom(CudaEvent& from); + + private: + cudaEvent_t event_; +}; + +// RAII object to manage a cudaStream_t +class CudaStream { + public: + /// Creates a stream on the current device + CudaStream(int flags = cudaStreamDefault); + CudaStream(const CudaStream& stream) = delete; + CudaStream(CudaStream&& stream) noexcept; + ~CudaStream(); + + CudaStream& operator=(CudaStream&& stream) noexcept; + CudaStream& operator=(CudaStream& stream) = delete; + + inline cudaStream_t get() { + return stream_; + } + + operator cudaStream_t() { + return stream_; + } + + static CudaStream make(); + static CudaStream makeNonBlocking(); + + private: + cudaStream_t stream_; +}; + +/// Call for a collection of streams to wait on +template +void streamWaitBase(const L1& listWaiting, const L2& listWaitOn) { + // For all the streams we are waiting on, create an event + std::vector events; + for (auto& stream : listWaitOn) { + cudaEvent_t event; + CUDA_VERIFY(cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); + CUDA_VERIFY(cudaEventRecord(event, stream)); + events.push_back(event); + } + + // For all the streams that are waiting, issue a wait + for (auto& stream : listWaiting) { + for (auto& event : events) { + CUDA_VERIFY(cudaStreamWaitEvent(stream, event, 0)); + } + } + + for (auto& event : events) { + CUDA_VERIFY(cudaEventDestroy(event)); + } +} + +/// These versions allow usage of initializer_list as arguments, since +/// otherwise {...} doesn't have a type +template +void streamWait(const L1& a, const std::initializer_list& b) { + streamWaitBase(a, b); +} + +template +void streamWait(const std::initializer_list& a, const L2& b) { + streamWaitBase(a, b); +} + +inline void streamWait( + const std::initializer_list& a, + const std::initializer_list& b) { + streamWaitBase(a, b); +} + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/dietgpu/utils/PtxUtils.cuh b/thirdparty/dietgpu/dietgpu/utils/PtxUtils.cuh new file mode 100644 index 000000000..0c164e21e --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/utils/PtxUtils.cuh @@ -0,0 +1,216 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace dietgpu { + +#if defined(__HIP_PLATFORM_AMD__) +// using WarpMaskT = unsigned long long; +using WarpMaskT = uint64_t; // wave64 +inline constexpr WarpMaskT kFullMask = 0xffffffffffffffffull; +#else +// using WarpMaskT = unsigned int; +using WarpMaskT = uint32_t; // warp32 +inline constexpr WarpMaskT kFullMask = 0xffffffffu; +#endif + +__device__ __forceinline__ WarpMaskT fullMask() { +#if defined(__HIP_PLATFORM_AMD__) + return 0xffffffffffffffffull; +#else + return 0xffffffffu; +#endif +} + + +__device__ __forceinline__ unsigned int +getBitfield(uint8_t val, int pos, int len) { +#if defined(__HIP_PLATFORM_AMD__) + return ((uint32_t)val >> pos) & ((1u << len) - 1u); +#else + unsigned int ret; + asm("bfe.u32 %0, %1, %2, %3;" + : "=r"(ret) + : "r"((uint32_t)val), "r"(pos), "r"(len)); + return ret; +#endif +} + +__device__ __forceinline__ unsigned int +getBitfield(uint16_t val, int pos, int len) { +#if defined(__HIP_PLATFORM_AMD__) + return ((uint32_t)val >> pos) & ((1u << len) - 1u); +#else + unsigned int ret; + asm("bfe.u32 %0, %1, %2, %3;" + : "=r"(ret) + : "r"((uint32_t)val), "r"(pos), "r"(len)); + return ret; +#endif +} + +__device__ __forceinline__ unsigned int +getBitfield(unsigned int val, int pos, int len) { +#if defined(__HIP_PLATFORM_AMD__) + return (val >> pos) & ((1u << len) - 1u); +#else + unsigned int ret; + asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len)); + return ret; +#endif +} + +__device__ __forceinline__ uint64_t +getBitfield(uint64_t val, int pos, int len) { +#if defined(__HIP_PLATFORM_AMD__) + return (val >> pos) & ((1ull << len) - 1ull); +#else + uint64_t ret; + asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len)); + return ret; +#endif +} + +__device__ __forceinline__ unsigned int +setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) { +#if defined(__HIP_PLATFORM_AMD__) + unsigned int mask = ((1u << len) - 1u) << pos; + return (val & ~mask) | ((toInsert << pos) & mask); +#else + unsigned int ret; + asm("bfi.b32 %0, %1, %2, %3, %4;" + : "=r"(ret) + : "r"(toInsert), "r"(val), "r"(pos), "r"(len)); + return ret; +#endif +} + +__device__ __forceinline__ uint32_t rotateLeft(uint32_t v, uint32_t shift) { +#if defined(__HIP_PLATFORM_AMD__) + return (v << shift) | (v >> (32u - shift)); +#else + uint32_t out; + asm("shf.l.clamp.b32 %0, %1, %2, %3;" + : "=r"(out) + : "r"(v), "r"(v), "r"(shift)); + return out; +#endif +} + +__device__ __forceinline__ uint32_t rotateRight(uint32_t v, uint32_t shift) { +#if defined(__HIP_PLATFORM_AMD__) + return (v >> shift) | (v << (32u - shift)); +#else + uint32_t out; + asm("shf.r.clamp.b32 %0, %1, %2, %3;" + : "=r"(out) + : "r"(v), "r"(v), "r"(shift)); + return out; +#endif +} + +__device__ __forceinline__ int getLaneId() { +#if defined(__HIP_PLATFORM_AMD__) + return __lane_id(); +#else + int laneId; + asm("mov.u32 %0, %%laneid;" : "=r"(laneId)); + return laneId; +#endif +} + +__device__ __forceinline__ WarpMaskT getLaneMaskLt() { +#if defined(__HIP_PLATFORM_AMD__) + int lane = __lane_id(); // 0..63 + return (lane == 0) ? 0ull : ((1ull << lane) - 1ull); +#else + WarpMaskT mask; + asm("mov.u32 %0, %%lanemask_lt;" : "=r"(mask)); + return mask; +#endif +} + +__device__ __forceinline__ WarpMaskT getLaneMaskLe() { +#if defined(__HIP_PLATFORM_AMD__) + int lane = __lane_id(); + return (lane == 63) ? 0xffffffffffffffffull : ((1ull << (lane + 1)) - 1ull); +#else + WarpMaskT mask; + asm("mov.u32 %0, %%lanemask_le;" : "=r"(mask)); + return mask; +#endif +} + +__device__ __forceinline__ WarpMaskT getLaneMaskGt() { +#if defined(__HIP_PLATFORM_AMD__) + int lane = __lane_id(); + return ~getLaneMaskLe(); +#else + WarpMaskT mask; + asm("mov.u32 %0, %%lanemask_gt;" : "=r"(mask)); + return mask; +#endif +} + +__device__ __forceinline__ WarpMaskT getLaneMaskGe() { +#if defined(__HIP_PLATFORM_AMD__) + int lane = __lane_id(); + return ~getLaneMaskLt(); +#else + WarpMaskT mask; + asm("mov.u32 %0, %%lanemask_ge;" : "=r"(mask)); + return mask; +#endif +} + + +template +__device__ inline T warpReduceAllMin(T val) { +#if __CUDA_ARCH__ >= 800 + return __reduce_min_sync(kFullMask, val); +#else +#pragma unroll + for (int mask = kWarpSize / 2; mask > 0; mask >>= 1) { + val = min(val, __shfl_xor_sync(kFullMask, val, mask, kWarpSize)); + } + + return val; +#endif +} + +template +__device__ inline T warpReduceAllMax(T val) { +#if __CUDA_ARCH__ >= 800 + return __reduce_max_sync(kFullMask, val); +#else +#pragma unroll + for (int mask = Width / 2; mask > 0; mask >>= 1) { + val = max(val, __shfl_xor_sync(kFullMask, val, mask, kWarpSize)); + } + + return val; +#endif +} + +template +__device__ inline T warpReduceAllSum(T val) { +#if __CUDA_ARCH__ >= 800 + return __reduce_add_sync(kFullMask, val); +#else +#pragma unroll + for (int mask = Width / 2; mask > 0; mask >>= 1) { + val += __shfl_xor_sync(kFullMask, val, mask, kWarpSize); + } + + return val; +#endif +} + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/dietgpu/utils/StackDeviceMemory.cpp b/thirdparty/dietgpu/dietgpu/utils/StackDeviceMemory.cpp new file mode 100644 index 000000000..9e105d383 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/utils/StackDeviceMemory.cpp @@ -0,0 +1,252 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "dietgpu/utils/StackDeviceMemory.h" +#include +#include +#include +#include "dietgpu/utils/DeviceUtils.h" + +namespace dietgpu { + +namespace { + +size_t adjustStackSize(size_t sz) { + if (sz == 0) { + return 0; + } else { + // ensure that we have at least kSDMAlignment bytes, as all allocations are + // bumped up to it + return std::max(sz, kSDMAlignment); + } +} + +} // namespace + +// +// StackDeviceMemory +// + +StackDeviceMemory::Stack::Stack(int d, size_t sz) + : device_(d), + alloc_(nullptr), + allocSize_(adjustStackSize(sz)), + start_(nullptr), + end_(nullptr), + head_(nullptr), + overflowSize_(0), + maxSeenSize_(0) { + if (allocSize_ == 0) { + return; + } + + DeviceScope s(device_); + CUDA_VERIFY(cudaMalloc(&alloc_, allocSize_)); + CHECK(alloc_); + + // In order to disambiguate between our entire region of temporary memory + // versus the first allocation in the temporary memory region, ensure that the + // first address returned is +kSDMAlignment bytes from the beginning + start_ = alloc_; + head_ = start_; + end_ = alloc_ + allocSize_; +} + +StackDeviceMemory::Stack::Stack(int device, void* p, size_t size) + : device_(device), + alloc_(nullptr), + allocSize_(adjustStackSize(size)), + start_(nullptr), + end_(nullptr), + head_(nullptr), + overflowSize_(0), + maxSeenSize_(0) { + CHECK(p || size == 0); + + // the minimum size that can be provided (see adjustStackSize), if we are + // allocating memory internally + CHECK(size == 0 || size >= kSDMAlignment); + + // alloc_ is not used, as we don't own this allocation + start_ = (char*)p; + head_ = start_; + end_ = p ? (char*)p + allocSize_ : nullptr; +} + +StackDeviceMemory::Stack::~Stack() { + // Make sure there are no outstanding memory allocations + CHECK_EQ(head_, start_); + CHECK(overflowAllocs_.empty()); + CHECK_EQ(overflowSize_, 0); + + // Did we own the stack buffer? + if (alloc_) { + DeviceScope s(device_); + CUDA_VERIFY(cudaFree(alloc_)); + } +} + +size_t StackDeviceMemory::Stack::getSizeAvailable() const { + return (end_ - head_); +} + +size_t StackDeviceMemory::Stack::getSizeTotal() const { + return (end_ - start_); +} + +size_t StackDeviceMemory::Stack::getStackSizeUsed() const { + return (head_ - start_); +} + +void* StackDeviceMemory::Stack::getAlloc( + size_t size, + cudaStream_t stream, + AllocType type) { + // All allocations should have been adjusted to a multiple of kSDMAlignment + // bytes + CHECK_GE(size, kSDMAlignment); + CHECK_EQ(size % kSDMAlignment, 0); + + void* out = nullptr; + + size_t stackMemUsed = head_ - start_; + auto sizeRemaining = getSizeAvailable(); + + if (size > sizeRemaining || type == AllocType::Permanent) { + // No space in the stack, fallback to cudaMalloc + if (type == AllocType::Temporary) { + // Current memory used after this allocation + size_t curUsed = overflowSize_ + size + stackMemUsed; + + std::cerr << "WARNING: StackDeviceMemory: attempting to allocate " << size + << " bytes with " << sizeRemaining + << " bytes available; calling cudaMalloc. " + << "Resize temp memory to >= " + << std::max(maxSeenSize_, curUsed) + << " bytes to avoid performance problems. " + << "(Current usage: " << getStackSizeUsed() << " bytes stack " + << overflowSize_ << " bytes overflow)\n"; + } + + CUDA_VERIFY(cudaMalloc(&out, size)); + CHECK(out); + + overflowAllocs_[out] = size; + overflowSize_ += size; + } else { + // Space is available in the stack + CHECK(head_); + out = head_; + + head_ = head_ + size; + CHECK_LE(head_, end_); + } + + maxSeenSize_ = std::max(maxSeenSize_, stackMemUsed + overflowSize_); + + return out; +} + +void StackDeviceMemory::Stack::returnAlloc( + void* p, + size_t size, + cudaStream_t stream) { + auto it = overflowAllocs_.find(p); + if (it != overflowAllocs_.end()) { + // This allocation was not made on the stack + CHECK_EQ(it->second, size); + + CUDA_VERIFY(cudaFree(p)); + overflowAllocs_.erase(it); + CHECK_GE(overflowSize_, size); + overflowSize_ -= size; + + return; + } + + // Otherwise, this is on our stack + char* pc = static_cast(p); + + // Otherwise, this allocation should be within ourselves + CHECK(pc >= start_ && pc < end_); + + // All allocations should have been adjusted + CHECK_EQ(size % kSDMAlignment, 0); + + // Allocations should be freed in the reverse order they are made + CHECK_EQ(pc + size, head_); + + head_ = pc; +} + +std::string StackDeviceMemory::Stack::toString() const { + std::stringstream s; + + s << "SDM device " << device_ << ": Total memory " << allocSize_ << " [" + << (void*)start_ << ", " << (void*)end_ << ")\n"; + s << " Available memory " << (size_t)(end_ - head_) << " [" + << (void*)head_ << ", " << (void*)end_ << ")\n"; + s << " Maximum seen mem usage " << maxSeenSize_ << "\n"; + + return s.str(); +} + +StackDeviceMemory::StackDeviceMemory(int device, size_t allocPerDevice) + : device_(device), stack_(device, allocPerDevice) {} + +StackDeviceMemory::StackDeviceMemory(int device, void* p, size_t size) + : device_(device), stack_(device, p, size) {} + +StackDeviceMemory::~StackDeviceMemory() = default; + +int StackDeviceMemory::getDevice() const { + return device_; +} + +size_t StackDeviceMemory::getSizeAvailable() const { + return stack_.getSizeAvailable(); +} + +size_t StackDeviceMemory::getSizeTotal() const { + return stack_.getSizeTotal(); +} + +size_t StackDeviceMemory::getMaxMemoryUsage() const { + return stack_.maxSeenSize_; +} + +void StackDeviceMemory::resetMaxMemoryUsage() { + stack_.maxSeenSize_ = 0; +} + +std::string StackDeviceMemory::toString() const { + return stack_.toString(); +} + +void* StackDeviceMemory::allocPointer( + cudaStream_t stream, + size_t size, + AllocType type) { + return stack_.getAlloc(size, stream, type); +} + +void StackDeviceMemory::deallocPointer( + int device, + cudaStream_t stream, + size_t size, + void* p) { + CHECK(p); + CHECK_EQ(device, device_); + + stack_.returnAlloc(p, size, stream); +} + +StackDeviceMemory makeStackMemory(size_t bytes) { + return StackDeviceMemory(getCurrentDevice(), bytes); +} + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/dietgpu/utils/StackDeviceMemory.h b/thirdparty/dietgpu/dietgpu/utils/StackDeviceMemory.h new file mode 100644 index 000000000..5d2ebde3c --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/utils/StackDeviceMemory.h @@ -0,0 +1,300 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace dietgpu { + +// All memory allocations are aligned to this boundary and are a multiple of +// this size in bytes +constexpr size_t kSDMAlignment = 256; + +class StackDeviceMemory; + +enum class AllocType { + Temporary, + Permanent, +}; + +/// A RAII object that manages a temporary memory request +template +struct GpuMemoryReservation { + GpuMemoryReservation() + : res(nullptr), + device(0), + stream(nullptr), + ptr(nullptr), + num(0), + sizeAllocated(0) {} + + GpuMemoryReservation( + StackDeviceMemory* r, + int dev, + cudaStream_t str, + void* p, + size_t n, + size_t szAlloc) + : res(r), + device(dev), + stream(str), + ptr(p), + num(n), + sizeAllocated(szAlloc) {} + + GpuMemoryReservation(GpuMemoryReservation&& m) noexcept { + res = m.res; + m.res = nullptr; + device = m.device; + m.device = 0; + stream = m.stream; + m.stream = nullptr; + ptr = m.ptr; + m.ptr = nullptr; + num = m.num; + m.num = 0; + sizeAllocated = m.sizeAllocated; + m.sizeAllocated = 0; + } + + ~GpuMemoryReservation(); + + GpuMemoryReservation& operator=(GpuMemoryReservation&& m) { + // Can't be both a valid allocation and the same allocation + CHECK(!(res && res == m.res && device == m.device && ptr == m.ptr)); + + release(); + res = m.res; + m.res = nullptr; + device = m.device; + m.device = 0; + stream = m.stream; + m.stream = nullptr; + ptr = m.ptr; + m.ptr = nullptr; + num = m.num; + m.num = 0; + sizeAllocated = m.sizeAllocated; + m.sizeAllocated = 0; + + return *this; + } + + T* data() { + return reinterpret_cast(ptr); + } + + const T* data() const { + return reinterpret_cast(ptr); + } + + // Copy from the device to a host std::vector, ordered wrt stream + std::vector copyToHost(cudaStream_t stream_2) const { + auto out = std::vector(num); + + CUDA_VERIFY(cudaMemcpyAsync( + out.data(), data(), num * sizeof(T), cudaMemcpyDeviceToHost, stream_2)); + + return out; + } + + void release(); + + StackDeviceMemory* res; + int device; + cudaStream_t stream; + void* ptr; + // number of valid sizeof(T) words available + size_t num; + // size allocated in bytes + size_t sizeAllocated; +}; + +/// Device memory manager that provides temporary memory allocations +/// out of a region of memory, for a single device +class StackDeviceMemory { + public: + /// Allocate a new region of memory that we manage + StackDeviceMemory(int device, size_t allocPerDevice); + + /// Manage a region of memory for a particular device, without ownership + StackDeviceMemory(int device, void* p, size_t size); + ~StackDeviceMemory(); + + int getDevice() const; + + // Allocate a chunk of memory on our device ordered wrt the given stream + // of size sizeof(T) * num bytes + template + GpuMemoryReservation alloc( + cudaStream_t stream, + size_t num, + AllocType type = AllocType::Temporary) { + // All allocations are aligned to this size/boundary + size_t sizeToAlloc = roundUp(num * sizeof(T), kSDMAlignment); + sizeToAlloc = std::max(sizeToAlloc, kSDMAlignment); + + return GpuMemoryReservation( + this, + device_, + stream, + allocPointer(stream, sizeToAlloc, type), + num, + sizeToAlloc); + } + + // Copy a T* array from the host to our device, with the memory allocated + // from ourselves, ordered wrt the given stream + template + GpuMemoryReservation copyAlloc( + cudaStream_t stream, + const T* ptr, + size_t num, + AllocType type = AllocType::Temporary) { + auto size = num * sizeof(T); + auto mem = alloc(stream, size, type); + + CUDA_VERIFY( + cudaMemcpyAsync(mem.data(), ptr, size, cudaMemcpyHostToDevice, stream)); + + return mem; + } + + // Copy a std::vector from the host to our device, with the memory allocated + // from ourselves, ordered wrt the given stream + template + GpuMemoryReservation copyAlloc( + cudaStream_t stream, + const std::vector& v, + AllocType type = AllocType::Temporary) { + return copyAlloc(stream, v.data(), v.size(), type); + } + + /// All allocations requested should be a multiple of kSDMAlignment bytes + void* allocPointer(cudaStream_t stream, size_t size, AllocType type); + void deallocPointer(int device, cudaStream_t, size_t size, void* p); + + size_t getSizeAvailable() const; + size_t getSizeTotal() const; + std::string toString() const; + + size_t getMaxMemoryUsage() const; + void resetMaxMemoryUsage(); + + protected: + /// Previous allocation ranges and the streams for which + /// synchronization is required + struct Range { + inline Range(char* s, char* e, cudaStream_t str) + : start_(s), end_(e), stream_(str) {} + + // References a memory range [start, end) + char* start_; + char* end_; + cudaStream_t stream_; + }; + + struct Stack { + /// Constructor that allocates memory via cudaMalloc + Stack(int device, size_t size); + + /// Constructor that uses an externally-provided region of memory + Stack(int device, void* p, size_t size); + + ~Stack(); + + /// Returns how much size is available for an allocation without + /// calling cudaMalloc + size_t getSizeAvailable() const; + + /// Returns how large our temporary buffer is in total + size_t getSizeTotal() const; + + /// Returns how much stack memory is in use + size_t getStackSizeUsed() const; + + /// Obtains an allocation; all allocations are guaranteed to be 16 + /// byte aligned + void* getAlloc(size_t size, cudaStream_t stream, AllocType type); + + /// Returns an allocation + void returnAlloc(void* p, size_t size, cudaStream_t stream); + + /// Returns the stack state + std::string toString() const; + + /// Device this allocation is on + int device_; + + /// Where our temporary memory buffer is allocated; we allocate starting 16 + /// bytes into this + char* alloc_; + + /// Total size of our allocation + size_t allocSize_; + + /// Our temporary memory region; [start_, end_) is valid + char* start_; + char* end_; + + /// Stack head within [start, end) + char* head_; + + /// Free allocations via cudaMalloc that we made that couldn't fit inside + /// our stack + std::unordered_map overflowAllocs_; + + /// How much memory we currently have in overflowAllocs_ + size_t overflowSize_; + + /// The current maximum seen memory usage, including both stack usage and + /// overflow allocations + size_t maxSeenSize_; + }; + + /// Our device + int device_; + + /// Memory stack + Stack stack_; +}; + +template +GpuMemoryReservation::~GpuMemoryReservation() { + if (ptr) { + CHECK(res); + res->deallocPointer(device, stream, sizeAllocated, ptr); + } +} + +template +void GpuMemoryReservation::release() { + if (ptr) { + CHECK(res); + res->deallocPointer(device, stream, sizeAllocated, ptr); + res = nullptr; + device = 0; + stream = nullptr; + ptr = nullptr; + num = 0; + sizeAllocated = 0; + } +} + +// Construct a StackDeviceMemory for the current device pre-allocating the given +// amount of memory +StackDeviceMemory makeStackMemory(size_t bytes = 256 * 1024 * 1024); + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/dietgpu/utils/StaticUtils.h b/thirdparty/dietgpu/dietgpu/utils/StaticUtils.h new file mode 100644 index 000000000..f4e5d1aa1 --- /dev/null +++ b/thirdparty/dietgpu/dietgpu/utils/StaticUtils.h @@ -0,0 +1,122 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +// allow usage for non-CUDA files +#ifndef __host__ +#define __host__ +#define __device__ +#endif + +namespace dietgpu { + +template +constexpr __host__ __device__ auto divDown(U a, V b) -> decltype(a + b) { + return (a / b); +} + +template +constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) { + return (a + b - 1) / b; +} + +template +constexpr __host__ __device__ auto roundDown(U a, V b) -> decltype(a + b) { + return divDown(a, b) * b; +} + +template +constexpr __host__ __device__ auto roundUp(U a, V b) -> decltype(a + b) { + return divUp(a, b) * b; +} + +template +constexpr __host__ __device__ bool isEvenDivisor(T a, T b) { + return (a % b == 0) && ((a / b) >= 1); +} + +template +constexpr __host__ __device__ T pow(T n, T power) { + return (power > 0 ? n * pow(n, power - 1) : 1); +} + +template +constexpr __host__ __device__ T pow2(T n) { + return pow(2, (T)n); +} + +static_assert(pow2(8) == 256, "pow2"); + +template +constexpr __host__ __device__ int log2(T n, int p = 0) { + return (n <= 1) ? p : log2(n / 2, p + 1); +} + +static_assert(log2(2) == 1, "log2"); +static_assert(log2(3) == 1, "log2"); +static_assert(log2(4) == 2, "log2"); + +template +constexpr __host__ __device__ bool isPowerOf2(T v) { + return (v && !(v & (v - 1))); +} + +static_assert(isPowerOf2(2048), "isPowerOf2"); +static_assert(!isPowerOf2(3333), "isPowerOf2"); + +template +constexpr __host__ __device__ T nextHighestPowerOf2(T v) { + return (isPowerOf2(v) ? (T)2 * v : ((T)1 << (log2(v) + 1))); +} + +static_assert(nextHighestPowerOf2(1) == 2, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(2) == 4, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(3) == 4, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(4) == 8, "nextHighestPowerOf2"); + +static_assert(nextHighestPowerOf2(15) == 16, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(16) == 32, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(17) == 32, "nextHighestPowerOf2"); + +static_assert( + nextHighestPowerOf2(1536000000u) == 2147483648u, + "nextHighestPowerOf2"); +static_assert( + nextHighestPowerOf2((size_t)2147483648ULL) == (size_t)4294967296ULL, + "nextHighestPowerOf2"); + +template +constexpr __host__ __device__ T nextLowestPowerOf2(T v) { + return (isPowerOf2(v) ? v / (T)2 : ((T)1 << (log2(v)))); +} + +static_assert(nextLowestPowerOf2(1) == 0, "nextLowestPowerOf2"); +static_assert(nextLowestPowerOf2(2) == 1, "nextLowestPowerOf2"); +static_assert(nextLowestPowerOf2(3) == 2, "nextLowestPowerOf2"); +static_assert(nextLowestPowerOf2(4) == 2, "nextLowestPowerOf2"); + +static_assert(nextLowestPowerOf2(15) == 8, "nextLowestPowerOf2"); +static_assert(nextLowestPowerOf2(16) == 8, "nextLowestPowerOf2"); +static_assert(nextLowestPowerOf2(17) == 16, "nextLowestPowerOf2"); + +inline __host__ __device__ bool isPointerAligned(const void* p, int align) { + return reinterpret_cast(p) % align == 0; +} + +// Returns the increment needed to aligned the pointer to the next highest +// aligned address +template +inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) { + static_assert(isPowerOf2(Align)); + uint32_t diff = uint32_t(uintptr_t(p) & uintptr_t(Align - 1)); + return diff == 0 ? 0 : uint32_t(Align) - diff; +} + +} // namespace dietgpu diff --git a/thirdparty/dietgpu/setup.py b/thirdparty/dietgpu/setup.py new file mode 100644 index 000000000..dd93430ab --- /dev/null +++ b/thirdparty/dietgpu/setup.py @@ -0,0 +1,270 @@ +import os +import subprocess +import setuptools +from glob import glob +import shutil +import site +from pathlib import Path + +import torch +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +from setuptools.command.install import install + +PROJECT_ROOT = Path(os.path.dirname(__file__)).resolve() + + +class CustomInstall(install): + """Custom install command that installs .so file to INSTALL_DIR""" + + def run(self): + # Run the standard build first + self.run_command("build_ext") + + # Get the install directory + python_site_packages = site.getsitepackages()[0] + install_dir = os.getenv( + "INSTALL_DIR", os.path.join(python_site_packages, "uccl") + ) + os.makedirs(install_dir, exist_ok=True) + + # Find the built .so file + build_lib = self.get_finalized_command("build_ext").build_lib + so_files = list(Path(build_lib).glob("p2p_dietgpu*.so")) + + if not so_files: + raise RuntimeError(f"Could not find built .so file in {build_lib}") + + so_file = so_files[0] + dest_path = os.path.join(install_dir, so_file.name) + + # Copy the .so file to the install directory + print(f"Installing {so_file.name} to {install_dir}") + shutil.copy2(so_file, dest_path) + print(f"Installation complete. Module installed as: {dest_path}") + + +if __name__ == "__main__": + cxx_flags = [ + "-O3", + "-Wno-deprecated-declarations", + "-Wno-unused-variable", + "-Wno-sign-compare", + "-Wno-reorder", + "-Wno-attributes", + "-Wno-unused-result", + "-Wno-unused-function", + ] + nvcc_flags = ["-O3", "-Xcompiler", "-O3"] + + base = "./dietgpu" + dirs = ["utils", "float", "ans"] + exts = ("*.cu", "*.cpp", "*.cc") + sources = [os.path.join(base, "DietGpu.cpp")] + + for d in dirs: + for ext in exts: + sources += glob(os.path.join(base, d, ext)) + libraries = ["ibverbs", "glog", "nl-3", "nl-route-3", "numa"] + include_dirs = [PROJECT_ROOT, PROJECT_ROOT / ".." / "include"] + library_dirs = [] + nvcc_dlink = [] + extra_link_args = [] + + if torch.version.cuda: + # Add CUDA library directory to library_dirs + cuda_home = os.getenv("CUDA_HOME", "/usr/local/cuda") + library_dirs.append(str(Path(cuda_home) / "lib64")) + + # EFA (Elastic Fabric Adapter) Detection + efa_home = os.getenv("EFA_HOME", "/opt/amazon/efa") + has_efa = os.path.exists(efa_home) + if has_efa: + print("EFA detected, building with EFA support") + else: + print("EFA not detected, building without EFA") + + # Architecture Detection + arch = os.uname().machine + cpu_is_arm64 = arch == "aarch64" + + # GPU Detection + gpu_name = "" + gpu_is_hopper = False + detected_compute_cap = None + try: + gpu_query = ( + subprocess.check_output( + ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"], + stderr=subprocess.DEVNULL, + ) + .decode("ascii") + .strip() + .split("\n")[0] + ) + gpu_name = gpu_query + gpu_is_hopper = "GH200" in gpu_name + + # Auto-detect compute capability + compute_cap_query = ( + subprocess.check_output( + [ + "nvidia-smi", + "--query-gpu=compute_cap", + "--format=csv,noheader", + ], + stderr=subprocess.DEVNULL, + ) + .decode("ascii") + .strip() + .split("\n")[0] + ) + detected_compute_cap = compute_cap_query.strip() + print(f"Detected GPU compute capability: {detected_compute_cap}") + except Exception as e: + print(f"Warning: Could not detect GPU info via nvidia-smi: {e}") + + # GH200 (Grace Hopper) Detection + has_gh200 = cpu_is_arm64 and gpu_is_hopper + if has_gh200: + print( + f"GH200 detected (GPU: {gpu_name}, CPU: {arch}), building with GH200 support" + ) + else: + print("GH200 not detected, building without GH200 support") + + # Add EFA flags if detected + if has_efa: + cxx_flags.append("-DEFA") + nvcc_flags.append("-DEFA") + include_dirs.append(Path(efa_home) / "include") + library_dirs.append(Path(efa_home) / "lib") + libraries.append("efa") + + # Add GH200 flags if detected + if has_gh200: + cxx_flags.append("-DUSE_GRACE_HOPPER") + nvcc_flags.append("-DUSE_GRACE_HOPPER") + + # Use auto-detected compute capability if available + if detected_compute_cap: + default_arch = detected_compute_cap + else: + # Fallback to 9.0 if detection failed + default_arch = "9.0" + + if int(os.getenv("DISABLE_SM90_FEATURES", 0)): + # Force A100 architecture + default_arch = "8.0" + # Disable some SM90 features: FP8, launch methods, and TMA + cxx_flags.append("-DDISABLE_SM90_FEATURES") + nvcc_flags.append("-DDISABLE_SM90_FEATURES") + else: + # For SM90 and above, add register usage optimization + if float(default_arch) >= 9.0: + nvcc_flags.extend(["--ptxas-options=--register-usage-level=10"]) + + # Set architecture environment variable before creating CUDAExtension + device_arch = os.getenv("TORCH_CUDA_ARCH_LIST", default_arch) + os.environ["TORCH_CUDA_ARCH_LIST"] = device_arch + else: + print("+++++++++++++++++++++++++++++++++++++++++++++") + device_arch = os.getenv("TORCH_CUDA_ARCH_LIST", "gfx942") + include_dirs.append("/opt/rocm/include") + for arch in device_arch.split(","): + nvcc_flags.append(f"--offload-arch={arch.lower()}") + + # Disable SM90 features on AMD + cxx_flags.append("-DDISABLE_SM90_FEATURES") + nvcc_flags.append("-DDISABLE_SM90_FEATURES") + + # Enable HIP DSA support + cxx_flags.append("-DTORCH_USE_HIP_DSA") + nvcc_flags.append("-DTORCH_USE_HIP_DSA") + + if int(os.getenv("DISABLE_AGGRESSIVE_ATOMIC", 0)): + cxx_flags.append("-DDISABLE_AGGRESSIVE_ATOMIC") + nvcc_flags.append("-DDISABLE_AGGRESSIVE_ATOMIC") + + # Disable LD/ST tricks, as some CUDA version does not support `.L1::no_allocate` + # Only enable aggressive PTX instructions for SM 9.0+ (H100/H800/B200) + try: + arch_version = float(device_arch.strip()) + if arch_version < 9.0: + os.environ["DISABLE_AGGRESSIVE_PTX_INSTRS"] = "1" + else: + # Enable aggressive PTX instructions for SM 9.0+ + os.environ.setdefault("DISABLE_AGGRESSIVE_PTX_INSTRS", "0") + except (ValueError, AttributeError): + os.environ.setdefault("DISABLE_AGGRESSIVE_PTX_INSTRS", "1") + + # Apply aggressive PTX instruction flag + if int(os.getenv("DISABLE_AGGRESSIVE_PTX_INSTRS", "0")): + cxx_flags.append("-DDISABLE_AGGRESSIVE_PTX_INSTRS") + nvcc_flags.append("-DDISABLE_AGGRESSIVE_PTX_INSTRS") + + # Put them together + extra_compile_args = { + "cxx": cxx_flags, + "nvcc": nvcc_flags, + } + if len(nvcc_dlink) > 0: + extra_compile_args["nvcc_dlink"] = nvcc_dlink + + # Convert Path objects to strings for include_dirs and library_dirs + include_dirs = [str(d) for d in include_dirs] + library_dirs = [str(d) for d in library_dirs] + + # Summary + print("\n" + "=" * 60) + print("Build Summary") + print("=" * 60) + print(f" > Platform: {'ROCm' if torch.version.hip else 'CUDA'}") + if torch.version.cuda: + print(f" > Architecture: {arch}") + if gpu_name: + print(f" > GPU: {gpu_name}") + print(f" > EFA Support: {'Yes' if has_efa else 'No'}") + print(f" > GH200 Support: {'Yes' if has_gh200 else 'No'}") + print(f" > Device Arch: {device_arch}") + print(f" > Sources: {len(sources)} files") + print(f" > Include Dirs: {include_dirs}") + print(f" > Library Dirs: {library_dirs}") + print(f" > Libraries: {libraries}") + print(f" > CXX Flags: {cxx_flags}") + print(f" > NVCC Flags: {nvcc_flags}") + print(f" > Link Flags: {extra_link_args}") + print("=" * 60 + "\n") + + # noinspection PyBroadException + try: + cmd = ["git", "rev-parse", "--short", "HEAD"] + revision = "+" + subprocess.check_output(cmd).decode("ascii").rstrip() + except Exception as _: + revision = "" + + has_hip = any(src.endswith(".hip") for src in sources) + + if has_hip: + print("[INFO] HIP source detected (.hip files found)") + else: + print("[INFO] No HIP source files detected") + + setuptools.setup( + name="p2p_dietgpu", + version="0.0.1" + revision, + ext_modules=[ + CUDAExtension( + name="p2p_dietgpu", + include_dirs=include_dirs, + library_dirs=library_dirs, + sources=sources, + libraries=libraries, + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, + ) + ], + cmdclass={ + "build_ext": BuildExtension, + "install": CustomInstall, + }, + ) \ No newline at end of file