Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,13 @@ build_p2p() {
else
echo "[container] USE_TCPX=1, skipping copying p2p runtime files"
fi
if [[ "$TARGET" == rocm* ]]; then
cd thirdparty/dietgpu
rm -rf build/
python3 setup.py build
cd ../..
cp thirdparty/dietgpu/build/**/*.so uccl/
fi
}

build_ep() {
Expand Down Expand Up @@ -271,9 +278,10 @@ echo "[2/3] Running build inside container..."

# Auto-detect CUDA architecture for ep build
DETECTED_GPU_ARCH=""
if [[ "$BUILD_TYPE" =~ (ep|all) ]];then
if [[ "$BUILD_TYPE" =~ (ep|all|p2p) ]];then
if [[ "$TARGET" == cuda* ]] && command -v nvidia-smi &> /dev/null; then
DETECTED_GPU_ARCH=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null | head -n1 | tr -d ' ')
DETECTED_GPU_ARCH="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null | head -n1 | tr -d ' ' || true)"

if [[ -n "$DETECTED_GPU_ARCH" ]]; then
echo "Auto-detected CUDA compute capability: ${DETECTED_GPU_ARCH}"
fi
Expand All @@ -283,10 +291,23 @@ if [[ "$BUILD_TYPE" =~ (ep|all) ]];then
echo "jq not found, installing via pip..."
pip install jq
fi
DETECTED_GPU_ARCH=$(amd-smi static -g 0 --asic --json | jq -r '.[].asic.target_graphics_version')
if [[ -n "$DETECTED_GPU_ARCH" ]]; then
echo "Auto-detected ROCm architecture: ${DETECTED_GPU_ARCH}"
DETECTED_GPU_ARCH="$(
PYTHONWARNINGS=ignore \
amd-smi static -g 0 --asic --json 2>/dev/null \
| jq -r '
if .gpu_data and (.gpu_data | length > 0) then
.gpu_data[0].asic.target_graphics_version
else
empty
end
' \
|| true
)"
if [[ -n "$DETECTED_GPU_ARCH" ]]; then
echo "Auto-detected ROCm architecture: ${DETECTED_GPU_ARCH}"
fi
else
echo "[INFO] No compatible GPU detection tool found, skipping auto-detect"
fi
fi

Expand All @@ -306,6 +327,7 @@ docker run --rm --user "$(id -u):$(id -g)" \
-e BUILD_TYPE="${BUILD_TYPE}" \
-e USE_TCPX="${USE_TCPX:-0}" \
-e USE_EFA="${USE_EFA:-0}" \
-e USE_IB="${USE_IB:-0}" \
-e MAKE_NORMAL_MODE="${MAKE_NORMAL_MODE:-}" \
-e TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-}" \
-e FUNCTION_DEF="$(declare -f build_rccl_nccl_h build_ccl_rdma build_ccl_efa build_p2p build_ep build_eccl)" \
Expand Down Expand Up @@ -397,6 +419,7 @@ def initialize():
--exclude "libcudart.so.12" \
--exclude "libamdhip64.so.*" \
--exclude "libcuda.so.1" \
--exclude "libefa.so.1" \
-w /io/${WHEEL_DIR}

# Add backend tag to wheel filename using local version identifier
Expand Down
5 changes: 3 additions & 2 deletions docker/Dockerfile.rocm
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ RUN apt-get update && \
rdma-core libibverbs-dev libnuma-dev \
libgoogle-glog-dev libgflags-dev libgtest-dev libelf-dev \
pkg-config zlib1g-dev curl unzip \
software-properties-common && \
software-properties-common \
hipcub && \
\
# ───── Add Python ${PY_VER} PPA & install Python ${PY_VER} + setuptools ─────
add-apt-repository ppa:deadsnakes/ppa && \
Expand All @@ -40,7 +41,7 @@ RUN apt-get update && \
RUN python${PY_VER} -m pip install --no-cache-dir build auditwheel pybind11

RUN python${PY_VER} -m pip install --no-cache-dir --pre torch torchvision \
--index-url https://download.pytorch.org/whl/nightly/rocm7.0
--index-url https://download.pytorch.org/whl/nightly/rocm7.1

RUN python${PY_VER} -m pip install --no-cache-dir --upgrade setuptools

Expand Down
7 changes: 4 additions & 3 deletions ep/bench/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,9 +679,6 @@ def dispatch(

# Internode
if self.runtime.get_num_rdma_ranks() > 1:
assert (
num_worst_tokens == 0
), "Internode dispatch does not support `num_worst_tokens > 0`"
return self.internode_dispatch(
x,
handle,
Expand All @@ -692,6 +689,7 @@ def dispatch(
topk_idx,
topk_weights,
expert_alignment,
num_worst_tokens,
config,
previous_event,
async_finish,
Expand Down Expand Up @@ -881,6 +879,7 @@ def internode_dispatch(
topk_idx: Optional[torch.Tensor] = None,
topk_weights: Optional[torch.Tensor] = None,
expert_alignment: int = 1,
num_worst_tokens: int = 0,
config: Optional[Config] = None,
previous_event: Optional[EventOverlap] = None,
async_finish: bool = False,
Expand Down Expand Up @@ -934,6 +933,7 @@ def internode_dispatch(
gbl_channel_prefix_matrix,
recv_gbl_rank_prefix_sum,
expert_alignment,
num_worst_tokens,
config,
getattr(previous_event, "event", None),
async_finish,
Expand Down Expand Up @@ -986,6 +986,7 @@ def internode_dispatch(
None,
None,
expert_alignment,
num_worst_tokens,
config,
getattr(previous_event, "event", None),
async_finish,
Expand Down
7 changes: 6 additions & 1 deletion ep/bench/run_ep.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,15 @@ if [ "$MODE" = "ll" ]; then
--master_addr=$MAIN_IP --master_port=12355 \
test_low_latency.py --num-tokens=128 \
--hidden=7168 --num-topk=8 --num-experts=288
else
elif [ "$MODE" = "ht" ]; then
torchrun --nnodes=$NNODES --nproc_per_node=8 --node_rank=$RANK \
--master_addr=$MAIN_IP --master_port=12355 \
test_internode.py --num-tokens=4096 \
--hidden=7168 --num-topk=8 --num-experts=288 --test-ll-compatibility
else
torchrun --nnodes=$NNODES --nproc_per_node=8 --node_rank=$RANK \
--master_addr=$MAIN_IP --master_port=12355 \
test_internode.py --num-tokens=4096 \
--hidden=7168 --num-topk=8 --num-experts=256 --pressure-test-mode=1
fi
# --log-dir=logs --redirect=3
File renamed without changes.
File renamed without changes.
38 changes: 37 additions & 1 deletion ep/bench/test_internode.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):
)
if current_x is not x_pure_rand:
check_data(recv_x, recv_gbl_rank_prefix_sum)
recv_topk_weights_clone = None
if with_topk:
# Check `topk_idx`
assert (
Expand All @@ -265,6 +266,7 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):
assert recv_topk_idx.eq(i).sum().item() == count

# Check `topk_weights`
recv_topk_weights_clone = recv_topk_weights.clone()
if current_x is not x_pure_rand:
recv_topk_weights[recv_topk_idx.eq(-1)] = (
recv_topk_weights.amax(dim=1, keepdim=True).expand_as(
Expand All @@ -273,6 +275,40 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):
)
check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)

# Test `num_worst_tokens != 0`
if with_topk:
num_worst_tokens = num_tokens * num_ranks
dispatch_args.update({"num_worst_tokens": num_worst_tokens})
(
recv_worst_x,
recv_worst_topk_idx,
recv_worst_topk_weights,
empty_list,
_,
event,
) = buffer.dispatch(**dispatch_args)
event.current_stream_wait() if async_mode else ()
recv_worst_x = (
per_token_cast_back(*recv_worst_x)
if isinstance(recv_worst_x, tuple)
else recv_worst_x
)
assert len(empty_list) == 0
assert num_worst_tokens == recv_worst_x.size(0)
assert num_worst_tokens == recv_worst_topk_idx.size(0)
assert num_worst_tokens == recv_worst_topk_weights.size(0)
assert torch.equal(recv_x, recv_worst_x[: recv_x.size(0)])
assert torch.equal(
recv_topk_idx, recv_worst_topk_idx[: recv_x.size(0)]
)
assert torch.equal(
recv_topk_weights_clone,
recv_worst_topk_weights[: recv_x.size(0)],
)
assert torch.all(
recv_worst_topk_idx[recv_x.size(0) :] == -1
).item()

# Test cached dispatch (must be without top-k stuff)
if not with_topk:
dispatch_args = {
Expand Down Expand Up @@ -502,7 +538,7 @@ def test_loop(

assert num_local_ranks == 8 and num_ranks > 8

for seed in range(int(1e9)):
for seed in range(0, int(1e9)):
if local_rank == 0:
print(f"Testing with seed {seed} ...", flush=True)
torch.manual_seed(rank + seed)
Expand Down
47 changes: 2 additions & 45 deletions ep/bench/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,53 +89,10 @@ def init_dist_under_torchrun(local_rank: int, num_local_ranks: int):
)


def _discover_local_ip():
# Try to infer the IP that can reach MASTER_ADDR (works in most clusters)
import socket, os

# Method 1: Use MASTER_ADDR if available (torchrun style)
if "MASTER_ADDR" in os.environ:
master = os.environ["MASTER_ADDR"]
port = int(os.environ.get("MASTER_PORT", "29500"))
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
s.connect((master, port))
return s.getsockname()[0]
except:
pass
finally:
s.close()

# Method 2: Use hostname resolution (works in AWS and most cloud environments)
hostname = socket.gethostname()
try:
# This usually returns the private IP in cloud environments
local_ip = socket.gethostbyname(hostname)
# Skip loopback addresses
if not local_ip.startswith("127."):
return local_ip
except:
pass

# Method 3: Connect to a public DNS to determine outgoing interface
try:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# Google DNS - this doesn't actually send packets
s.connect(("8.8.8.8", 80))
local_ip = s.getsockname()[0]
s.close()
return local_ip
except:
pass

# Last resort: return localhost
return "127.0.0.1"


def _gather_peer_ips(group):
    """Gather the local IP string of every rank in ``group``.

    Each rank contributes the value returned by ``ep.get_oob_ip()``; the
    contributions are collected with ``dist.all_gather_object``.

    Args:
        group: process group passed through to ``dist.get_world_size`` and
            ``dist.all_gather_object``.

    Returns:
        A list with one entry per rank in ``group``.
    """
    world = dist.get_world_size(group)
    # NOTE(review): the address is obtained only once, from ep.get_oob_ip();
    # the earlier `_discover_local_ip()` assignment was immediately
    # overwritten and has been dropped as a dead store.
    my_ip = ep.get_oob_ip()
    # Pre-sized output list: all_gather_object writes into it in place.
    ips = [None] * world
    dist.all_gather_object(ips, my_ip, group=group)
    return ips
Expand All @@ -154,7 +111,7 @@ def get_peer_ip(rank: int, num_ranks: int, group: dist.ProcessGroup):


def get_cpu_proxies_meta(proxies, rank, scratch_ptr, scratch_bytes, num_ranks, group):
my_ip = _discover_local_ip()
my_ip = ep.get_oob_ip()
meta = {
"rank": rank,
"ptr": int(scratch_ptr),
Expand Down
Loading