ep/bench/buffer.py (4 additions, 3 deletions)
@@ -679,9 +679,6 @@ def dispatch(
 
         # Internode
         if self.runtime.get_num_rdma_ranks() > 1:
-            assert (
-                num_worst_tokens == 0
-            ), "Internode dispatch does not support `num_worst_tokens > 0`"
             return self.internode_dispatch(
                 x,
                 handle,
@@ -692,6 +689,7 @@ def dispatch(
                 topk_idx,
                 topk_weights,
                 expert_alignment,
+                num_worst_tokens,
                 config,
                 previous_event,
                 async_finish,
@@ -881,6 +879,7 @@ def internode_dispatch(
         topk_idx: Optional[torch.Tensor] = None,
         topk_weights: Optional[torch.Tensor] = None,
         expert_alignment: int = 1,
+        num_worst_tokens: int = 0,
         config: Optional[Config] = None,
         previous_event: Optional[EventOverlap] = None,
         async_finish: bool = False,
@@ -934,6 +933,7 @@ def internode_dispatch(
                 gbl_channel_prefix_matrix,
                 recv_gbl_rank_prefix_sum,
                 expert_alignment,
+                num_worst_tokens,
                 config,
                 getattr(previous_event, "event", None),
                 async_finish,
@@ -986,6 +986,7 @@ def internode_dispatch(
                 None,
                 None,
                 expert_alignment,
+                num_worst_tokens,
                 config,
                 getattr(previous_event, "event", None),
                 async_finish,
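
The change above removes the internode guard and threads `num_worst_tokens` through to `internode_dispatch`. For context, here is a minimal caller-side sketch of how the padded dispatch could be used; the helper name `dispatch_padded`, the argument plumbing, and the return-tuple layout are illustrative assumptions (mirroring the test below), not part of this PR:

```python
# Hypothetical usage sketch (not from this PR): pad internode dispatch output
# to a fixed worst-case row count so downstream code sees static shapes.
def dispatch_padded(buffer, num_tokens, num_ranks, **dispatch_args):
    # Worst case: every rank routes all of its tokens to this rank.
    worst = num_tokens * num_ranks
    recv_x, recv_topk_idx, recv_topk_weights, num_recv_list, handle, event = (
        buffer.dispatch(num_worst_tokens=worst, **dispatch_args)
    )
    # Rows beyond the valid region are padding; their top-k indices are -1
    # (asserted by the new test in test_internode.py below).
    valid = recv_topk_idx.ne(-1).any(dim=1)
    return recv_x, recv_topk_idx, recv_topk_weights, valid, handle, event
```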
ep/bench/test_internode.py (36 additions, 0 deletions)
@@ -252,6 +252,7 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):
         )
         if current_x is not x_pure_rand:
             check_data(recv_x, recv_gbl_rank_prefix_sum)
+        recv_topk_weights_clone = None
         if with_topk:
             # Check `topk_idx`
             assert (
@@ -265,6 +266,7 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):
                 assert recv_topk_idx.eq(i).sum().item() == count
 
             # Check `topk_weights`
+            recv_topk_weights_clone = recv_topk_weights.clone()
             if current_x is not x_pure_rand:
                 recv_topk_weights[recv_topk_idx.eq(-1)] = (
                     recv_topk_weights.amax(dim=1, keepdim=True).expand_as(
@@ -273,6 +275,40 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):
                 )
                 check_data(recv_topk_weights, recv_gbl_rank_prefix_sum)
 
+        # Test `num_worst_tokens != 0`
+        if with_topk:
+            num_worst_tokens = num_tokens * num_ranks
+            dispatch_args.update({"num_worst_tokens": num_worst_tokens})
+            (
+                recv_worst_x,
+                recv_worst_topk_idx,
+                recv_worst_topk_weights,
+                empty_list,
+                _,
+                event,
+            ) = buffer.dispatch(**dispatch_args)
+            event.current_stream_wait() if async_mode else ()
+            recv_worst_x = (
+                per_token_cast_back(*recv_worst_x)
+                if isinstance(recv_worst_x, tuple)
+                else recv_worst_x
+            )
+            assert len(empty_list) == 0
+            assert num_worst_tokens == recv_worst_x.size(0)
+            assert num_worst_tokens == recv_worst_topk_idx.size(0)
+            assert num_worst_tokens == recv_worst_topk_weights.size(0)
+            assert torch.equal(recv_x, recv_worst_x[: recv_x.size(0)])
+            assert torch.equal(
+                recv_topk_idx, recv_worst_topk_idx[: recv_x.size(0)]
+            )
+            assert torch.equal(
+                recv_topk_weights_clone,
+                recv_worst_topk_weights[: recv_x.size(0)],
+            )
+            assert torch.all(
+                recv_worst_topk_idx[recv_x.size(0) :] == -1
+            ).item()
+
         # Test cached dispatch (must be without top-k stuff)
         if not with_topk:
             dispatch_args = {
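
The assertions above pin down the padding contract: valid rows come first, and every padded row carries `-1` top-k indices. A self-contained toy illustration of that layout (pure PyTorch; all sizes here are made up for the example):

```python
import torch

# First `num_valid` rows are real tokens, the rest is -1-padded, as the new
# internode test asserts for `recv_worst_topk_idx`.
num_valid, num_worst_tokens, num_topk = 3, 8, 2
topk_idx = torch.full((num_worst_tokens, num_topk), -1, dtype=torch.int64)
topk_idx[:num_valid] = torch.randint(0, 4, (num_valid, num_topk))

valid = topk_idx.ne(-1).any(dim=1)  # mask selecting the real tokens
assert valid[:num_valid].all() and not valid[num_valid:].any()
```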
ep/bench/utils.py (2 additions, 45 deletions)
@@ -89,53 +89,10 @@ def init_dist_under_torchrun(local_rank: int, num_local_ranks: int):
     )
 
 
-def _discover_local_ip():
-    # Try to infer the IP that can reach MASTER_ADDR (works in most clusters)
-    import socket, os
-
-    # Method 1: Use MASTER_ADDR if available (torchrun style)
-    if "MASTER_ADDR" in os.environ:
-        master = os.environ["MASTER_ADDR"]
-        port = int(os.environ.get("MASTER_PORT", "29500"))
-        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-        try:
-            s.connect((master, port))
-            return s.getsockname()[0]
-        except:
-            pass
-        finally:
-            s.close()
-
-    # Method 2: Use hostname resolution (works in AWS and most cloud environments)
-    hostname = socket.gethostname()
-    try:
-        # This usually returns the private IP in cloud environments
-        local_ip = socket.gethostbyname(hostname)
-        # Skip loopback addresses
-        if not local_ip.startswith("127."):
-            return local_ip
-    except:
-        pass
-
-    # Method 3: Connect to a public DNS to determine outgoing interface
-    try:
-        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-        # Google DNS - this doesn't actually send packets
-        s.connect(("8.8.8.8", 80))
-        local_ip = s.getsockname()[0]
-        s.close()
-        return local_ip
-    except:
-        pass
-
-    # Last resort: return localhost
-    return "127.0.0.1"
-
-
 def _gather_peer_ips(group):
     # Gather local IP strings across ranks
     world = dist.get_world_size(group)
-    my_ip = _discover_local_ip()
+    my_ip = ep.get_oob_ip()
     ips = [None] * world
     dist.all_gather_object(ips, my_ip, group=group)
     return ips
@@ -154,7 +111,7 @@ def get_peer_ip(rank: int, num_ranks: int, group: dist.ProcessGroup):
 
 
 def get_cpu_proxies_meta(proxies, rank, scratch_ptr, scratch_bytes, num_ranks, group):
-    my_ip = _discover_local_ip()
+    my_ip = ep.get_oob_ip()
     meta = {
         "rank": rank,
         "ptr": int(scratch_ptr),
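
The hand-rolled `_discover_local_ip` heuristics are replaced by `ep.get_oob_ip()`. For reference, the core trick behind this kind of discovery (and behind the deleted "Method 3") is a connected UDP socket; a minimal standalone version, under the assumption that `get_oob_ip` does something similar internally:

```python
import socket

def outgoing_ip(probe=("8.8.8.8", 80)):
    # connect() on a UDP socket transmits nothing, but the OS still selects
    # the outgoing route, so getsockname() reveals the local interface IP.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(probe)
        return s.getsockname()[0]
```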
ep/include/internode.cuh (19 additions, 21 deletions)
@@ -31,8 +31,8 @@ void notify_dispatch(
     int const* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped,
     int const* num_tokens_per_expert, int* moe_recv_expert_counter_mapped,
     int num_experts, bool const* is_token_in_rank, int num_tokens,
-    int num_channels, int hidden_int4, int num_scales, int num_topk,
-    int expert_alignment, int* rdma_channel_prefix_matrix,
+    int num_worst_tokens, int num_channels, int hidden_int4, int num_scales,
+    int num_topk, int expert_alignment, int* rdma_channel_prefix_matrix,
     int* recv_rdma_rank_prefix_sum, int* gbl_channel_prefix_matrix,
     int* recv_gbl_rank_prefix_sum, void* rdma_buffer_ptr,
     int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs,
@@ -54,25 +54,23 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx,
                    uint64_t const* d2h_channel_addrs, int num_d2h_channel_addrs,
                    void* atomic_buffer_ptr);
 
-void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx,
-              float* recv_topk_weights, void* recv_src_meta, void const* x,
-              float const* x_scales, int64_t const* topk_idx,
-              float const* topk_weights, int* send_rdma_head,
-              int* send_nvl_head, int* recv_rdma_channel_prefix_matrix,
-              int* recv_gbl_channel_prefix_matrix,
-              int const* rdma_channel_prefix_matrix,
-              int const* recv_rdma_rank_prefix_sum,
-              int const* gbl_channel_prefix_matrix,
-              int const* recv_gbl_rank_prefix_sum, bool const* is_token_in_rank,
-              int num_tokens, int hidden_int4, int num_scales, int num_topk,
-              int num_experts, int scale_token_stride, int scale_hidden_stride,
-              void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens,
-              int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs,
-              int num_max_nvl_chunked_send_tokens,
-              int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks,
-              bool is_cached_dispatch, cudaStream_t stream, int num_channels,
-              bool low_latency_mode, uint64_t const* d2h_channel_addrs,
-              int num_d2h_channel_addrs, void* atomic_buffer_ptr);
+void dispatch(
+    void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx,
+    float* recv_topk_weights, void* recv_src_meta, void const* x,
+    float const* x_scales, int64_t const* topk_idx, float const* topk_weights,
+    int* send_rdma_head, int* send_nvl_head,
+    int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix,
+    int const* rdma_channel_prefix_matrix, int const* recv_rdma_rank_prefix_sum,
+    int const* gbl_channel_prefix_matrix, int const* recv_gbl_rank_prefix_sum,
+    bool const* is_token_in_rank, int num_tokens, int num_worst_tokens,
+    int hidden_int4, int num_scales, int num_topk, int num_experts,
+    int scale_token_stride, int scale_hidden_stride, void* rdma_buffer_ptr,
+    int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
+    void** buffer_ptrs, int num_max_nvl_chunked_send_tokens,
+    int num_max_nvl_chunked_recv_tokens, int rank, int num_ranks,
+    bool is_cached_dispatch, cudaStream_t stream, int num_channels,
+    bool low_latency_mode, uint64_t const* d2h_channel_addrs,
+    int num_d2h_channel_addrs, void* atomic_buffer_ptr);
 
 void combine(cudaDataType_t type, void* combined_x,
              float* combined_topk_weights,

void combine(cudaDataType_t type, void* combined_x,
float* combined_topk_weights,
ep/include/rdma.hpp (5 additions, 0 deletions)
@@ -2,6 +2,11 @@
 #define RDMA_HPP
 #include "common.hpp"
 #include "proxy_ctx.hpp"
+// clang-format off
+// prevent clang-format reordering net.h before util.h
+#include "util/util.h"
+#include "util/net.h"
+// clang-format on
 #include "ring_buffer.cuh"
 #include "unistd.h"
 #include <infiniband/efadv.h>