Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion ep/bench/run_ep.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,15 @@ if [ "$MODE" = "ll" ]; then
--master_addr=$MAIN_IP --master_port=12355 \
test_low_latency.py --num-tokens=128 \
--hidden=7168 --num-topk=8 --num-experts=288
else
elif [ "$MODE" = "ht" ]; then
torchrun --nnodes=$NNODES --nproc_per_node=8 --node_rank=$RANK \
--master_addr=$MAIN_IP --master_port=12355 \
test_internode.py --num-tokens=4096 \
--hidden=7168 --num-topk=8 --num-experts=288 --test-ll-compatibility
else
torchrun --nnodes=$NNODES --nproc_per_node=8 --node_rank=$RANK \
--master_addr=$MAIN_IP --master_port=12355 \
test_internode.py --num-tokens=4096 \
--hidden=7168 --num-topk=8 --num-experts=256 --pressure-test-mode=1
fi
# --log-dir=logs --redirect=3
2 changes: 1 addition & 1 deletion ep/bench/test_internode.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ def test_loop(

assert num_local_ranks == 8 and num_ranks > 8

for seed in range(int(1e9)):
for seed in range(676, int(1e9)):
if local_rank == 0:
print(f"Testing with seed {seed} ...", flush=True)
torch.manual_seed(rank + seed)
Expand Down
13 changes: 9 additions & 4 deletions ep/include/ep_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,8 @@ __device__ __forceinline__ void trap() {
__device__ __forceinline__ int ld_volatile_global(int const* ptr) {
int ret;
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
ret = __atomic_load_n(const_cast<int*>(ptr), __ATOMIC_RELAXED);
ret = __hip_atomic_load(const_cast<int*>(ptr), __ATOMIC_RELAXED,
__HIP_MEMORY_SCOPE_SYSTEM);
#else
asm volatile("ld.volatile.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
#endif
Expand All @@ -754,7 +755,8 @@ __device__ __forceinline__ int ld_volatile_global(int const* ptr) {
__device__ __forceinline__ float ld_volatile_global(float const* ptr) {
float ret;
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
__atomic_load(const_cast<float*>(ptr), &ret, __ATOMIC_RELAXED);
ret = __hip_atomic_load(const_cast<float*>(ptr), __ATOMIC_RELAXED,
__HIP_MEMORY_SCOPE_SYSTEM);
#else
asm volatile("ld.volatile.global.f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
#endif
Expand All @@ -764,7 +766,8 @@ __device__ __forceinline__ float ld_volatile_global(float const* ptr) {
__device__ __forceinline__ int64_t ld_volatile_global(int64_t const* ptr) {
int64_t ret;
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
ret = __atomic_load_n(const_cast<int64_t*>(ptr), __ATOMIC_RELAXED);
ret = __hip_atomic_load(const_cast<int64_t*>(ptr), __ATOMIC_RELAXED,
__HIP_MEMORY_SCOPE_SYSTEM);
#else
asm volatile("ld.volatile.global.s64 %0, [%1];" : "=l"(ret) : "l"(ptr));
#endif
Expand All @@ -774,7 +777,8 @@ __device__ __forceinline__ int64_t ld_volatile_global(int64_t const* ptr) {
__device__ __forceinline__ int64_t ld_volatile_global(uint64_t const* ptr) {
int64_t ret;
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
ret = __atomic_load_n(const_cast<uint64_t*>(ptr), __ATOMIC_RELAXED);
ret = __hip_atomic_load(const_cast<uint64_t*>(ptr), __ATOMIC_RELAXED,
__HIP_MEMORY_SCOPE_SYSTEM);
#else
asm volatile("ld.volatile.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr));
#endif
Expand Down Expand Up @@ -836,6 +840,7 @@ __forceinline__ __device__ void barrier_block(int** barrier_signal_ptrs,
// Add self-ranks, sub other ranks
if (thread_id < kNumRanks) {
atomicAdd_system(barrier_signal_ptrs[rank] + thread_id, FINISHED_SUM_TAG);
memory_fence();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the memory_fence() needed here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

atomicSub_system(barrier_signal_ptrs[thread_id] + rank, FINISHED_SUM_TAG);
}
EP_DEVICE_ASSERT(kNumRanks <= blockDim.x);
Expand Down
3 changes: 3 additions & 0 deletions ep/src/internode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1194,6 +1194,9 @@ __global__ void __launch_bounds__(
trap();
}
}
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
memory_fence();
#endif
auto src_rdma_head =
__shfl_sync(WARP_MASK, cached_rdma_channel_head, src_rdma_rank);
auto src_rdma_tail =
Expand Down
11 changes: 5 additions & 6 deletions ep/src/proxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ LocalBarrier* map_local_barrier_shm(std::string const& name, bool* out_owner) {
perror("shm_open(existing)");
return nullptr;
}
struct stat st{};
struct stat st {};
int tries = 1000;
while (tries-- > 0) {
if (fstat(fd, &st) == 0 && static_cast<size_t>(st.st_size) >= kSize)
Expand Down Expand Up @@ -186,9 +186,9 @@ void Proxy::init_common() {
#ifdef EFA
IBV_ACCESS_REMOTE_READ
#else
IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC
IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC
#endif
);
);

if (!ctx_.atomic_buffer_mr) {
perror("Failed to register atomic_buffer_ptr MR");
Expand Down Expand Up @@ -504,7 +504,7 @@ void Proxy::run_dual() {
void Proxy::notify_gpu_completion(uint64_t& my_tail) {
if (acked_wrs_.empty()) return;

// Mark all acked command slots in each ring's bitmask
// Mark all acked command slots in each ring's bitmask
#ifdef USE_MSCCLPP_FIFO_BACKEND
// FIFO path: pop in order using the pending deque and the completion set.
for (size_t rb_idx = 0; rb_idx < cfg_.d2h_queues.size(); ++rb_idx) {
Expand Down Expand Up @@ -1215,8 +1215,7 @@ void Proxy::barrier_check() {

// When global release comes back (CQ handler should set these):
// NOTE: BarrierImm is 21 bits, so we must mask the local seq.
if (ctx_.barrier_released &&
ctx_.barrier_release_seq == seq) {
if (ctx_.barrier_released && ctx_.barrier_release_seq == seq) {
// Reset local mask for next barrier and consume the global release
ctx_.barrier_released = false;

Expand Down
30 changes: 22 additions & 8 deletions ep/src/rdma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,8 @@ void per_thread_rdma_init(ProxyCtx& S, void* gpu_buf, size_t bytes, int rank,
IBV_ACCESS_REMOTE_ATOMIC);
#else
S.mr = ibv_reg_mr_iova2(S.pd, gpu_buf, bytes, iova,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE |
IBV_ACCESS_RELAXED_ORDERING);
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE |
IBV_ACCESS_RELAXED_ORDERING);
#endif

if (!S.mr) {
Expand Down Expand Up @@ -554,6 +554,13 @@ void modify_qp_to_rtr(ProxyCtx& S, RDMAConnectionInfo* remote,
exit(1);
}

// Query device attributes to get max_dest_rd_atomic
struct ibv_device_attr dev_attr;
if (ibv_query_device(S.context, &dev_attr)) {
perror("Failed to query device attributes");
exit(1);
}

if (port_attr.link_layer == IBV_LINK_LAYER_ETHERNET) {
printf("RoCE detected (Ethernet)\n");
is_roce = 1;
Expand All @@ -571,16 +578,16 @@ void modify_qp_to_rtr(ProxyCtx& S, RDMAConnectionInfo* remote,
attr.path_mtu = port_attr.active_mtu;
attr.dest_qp_num = remote->qp_num;
attr.rq_psn = remote->psn;
attr.max_dest_rd_atomic = 1;
attr.max_dest_rd_atomic = dev_attr.max_qp_init_rd_atom;
attr.min_rnr_timer = 12;

if (is_roce) {
attr.ah_attr.is_global = 1;
attr.ah_attr.port_num = 1;
attr.ah_attr.sl = 135;
attr.ah_attr.sl = 0;
attr.ah_attr.src_path_bits = 0;
attr.ah_attr.grh.traffic_class = 3;
attr.ah_attr.grh.hop_limit = 64;
attr.ah_attr.grh.traffic_class = 0;
attr.ah_attr.grh.hop_limit = 255;
// Fill GID from remote_info
memcpy(&attr.ah_attr.grh.dgid, remote->gid, 16);
attr.ah_attr.grh.sgid_index = S.gid_index;
Expand Down Expand Up @@ -655,14 +662,21 @@ void modify_qp_to_rts(ProxyCtx& S, RDMAConnectionInfo* local_info) {
#ifdef EFA
return;
#endif
// Query device attributes to get max_rd_atomic
struct ibv_device_attr dev_attr;
if (ibv_query_device(S.context, &dev_attr)) {
perror("Failed to query device attributes");
exit(1);
}

struct ibv_qp_attr attr;
memset(&attr, 0, sizeof(attr));
attr.qp_state = IBV_QPS_RTS;
attr.timeout = 14;
attr.timeout = 20;
attr.retry_cnt = 7;
attr.rnr_retry = 7;
attr.sq_psn = local_info->psn;
attr.max_rd_atomic = 1;
attr.max_rd_atomic = dev_attr.max_qp_rd_atom;
attr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_ATOMIC;

int flags = IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT |
Expand Down