Closed

31 commits
3340d74
ep debug mem consistency issues
MaoZiming Dec 23, 2025
ce7ab42
nit
MaoZiming Dec 23, 2025
4d37b7c
used rdma atomic + cudamalloc for atomic on AMD non-EFA
MaoZiming Dec 27, 2025
565051c
revert formatting python
MaoZiming Dec 27, 2025
f0f15c8
revert formatting other python
MaoZiming Dec 27, 2025
71062d0
Minor change
MaoZiming Dec 27, 2025
5ba543c
revert some changes to rdma.cpp
MaoZiming Dec 27, 2025
760ebdb
wrap post_atomic_operations_native_rdma under ifdef AMD
MaoZiming Dec 27, 2025
b7132e5
format
MaoZiming Dec 27, 2025
f77006e
Merge branch 'main' into ep-debug-amd-mem-consistency
MaoZiming Dec 27, 2025
47f4396
combine
MaoZiming Dec 28, 2025
fb1f535
remove relaxed ordering flag
MaoZiming Dec 30, 2025
aa23177
change atomic buffer cudaMalloc to hipExtMallocWithFlags with uncache…
zhenhuang12 Jan 2, 2026
63b2203
efa stability
MaoZiming Jan 3, 2026
9f1bee7
mask seq
MaoZiming Jan 3, 2026
bfbadc1
fix
MaoZiming Jan 3, 2026
7f1da83
fix nits
MaoZiming Jan 3, 2026
a69de86
Merge branch 'ep-efa-stability' of https://github.com/uccl-project/uc…
MaoZiming Jan 3, 2026
25546e5
Merge pull request #616 from uccl-project/ep-efa-stability
MaoZiming Jan 4, 2026
5163c76
revert start barrier_seq
MaoZiming Jan 4, 2026
f857757
config rdma
YangZhou1997 Jan 5, 2026
6a38c2a
tunes
YangZhou1997 Jan 5, 2026
41bea60
nits
YangZhou1997 Jan 5, 2026
51bf4d8
add cudaMemPrefetchAsync back.
zhenhuang12 Jan 6, 2026
66e4e14
restored some paras
YangZhou1997 Jan 6, 2026
7f42d01
Merge branch 'ep-debug-amd-mem-consistency-yang' of https://github.co…
YangZhou1997 Jan 6, 2026
9c9e061
upgrade to rocm7.1
YangZhou1997 Jan 6, 2026
58e282f
fix small error of ep build
monopodium Jan 8, 2026
91e37ff
Merge remote-tracking branch 'origin/fix_small_build_error' into ep-d…
YangZhou1997 Jan 8, 2026
e76147d
Merge branch 'main' of https://github.com/uccl-project/uccl into ep-d…
YangZhou1997 Jan 9, 2026
c518f8a
tune rdma conifg
YangZhou1997 Jan 9, 2026
2 changes: 1 addition & 1 deletion docker/Dockerfile.rocm
@@ -41,7 +41,7 @@ RUN apt-get update && \
RUN python${PY_VER} -m pip install --no-cache-dir build auditwheel pybind11

RUN python${PY_VER} -m pip install --no-cache-dir --pre torch torchvision \
--index-url https://download.pytorch.org/whl/nightly/rocm7.0
--index-url https://download.pytorch.org/whl/nightly/rocm7.1

RUN python${PY_VER} -m pip install --no-cache-dir --upgrade setuptools

7 changes: 6 additions & 1 deletion ep/bench/run_ep.sh
@@ -12,10 +12,15 @@ if [ "$MODE" = "ll" ]; then
--master_addr=$MAIN_IP --master_port=12355 \
test_low_latency.py --num-tokens=128 \
--hidden=7168 --num-topk=8 --num-experts=288
else
elif [ "$MODE" = "ht" ]; then
torchrun --nnodes=$NNODES --nproc_per_node=8 --node_rank=$RANK \
--master_addr=$MAIN_IP --master_port=12355 \
test_internode.py --num-tokens=4096 \
--hidden=7168 --num-topk=8 --num-experts=288 --test-ll-compatibility
else
torchrun --nnodes=$NNODES --nproc_per_node=8 --node_rank=$RANK \
--master_addr=$MAIN_IP --master_port=12355 \
test_internode.py --num-tokens=4096 \
--hidden=7168 --num-topk=8 --num-experts=256 --pressure-test-mode=1
fi
# --log-dir=logs --redirect=3
2 changes: 1 addition & 1 deletion ep/bench/test_internode.py
@@ -538,7 +538,7 @@ def test_loop(

assert num_local_ranks == 8 and num_ranks > 8

for seed in range(int(1e9)):
for seed in range(0, int(1e9)):
if local_rank == 0:
print(f"Testing with seed {seed} ...", flush=True)
torch.manual_seed(rank + seed)
12 changes: 12 additions & 0 deletions ep/bench/utils.py
@@ -541,6 +541,18 @@ def initialize_uccl(

ep.register_proxies(local_rank, proxies)

# Set atomic buffer pointer for all proxies BEFORE starting them
# This ensures the atomic buffer info is included in connection info exchange
# Note: Only thread 0's proxy allocates the atomic buffer in its constructor
if not is_intranode and len(proxies) > 0:
# Get atomic buffer pointer from thread 0 proxy (only thread 0 allocates it)
# This must be done before start_dual() so the atomic buffer info is included
# in the connection info exchange during init_common()
atomic_buffer_ptr = proxies[0].get_atomic_buffer_ptr()
if atomic_buffer_ptr:
for proxy in proxies:
proxy.set_atomic_buffer_ptr(atomic_buffer_ptr)

dist.barrier(group)
if not is_intranode:
for proxy in proxies:
1 change: 1 addition & 0 deletions ep/include/common.hpp
@@ -12,6 +12,7 @@
#include <stdio.h>
#include <unistd.h>

// #define SOFTWARE_ORDERING
#define MAX_IB_DEVS 32
// #define MEASURE_PER_OP_LATENCY
// #define MEASURE_PER_VERB_LATENCY
13 changes: 9 additions & 4 deletions ep/include/ep_utils.cuh
@@ -744,7 +744,8 @@ __device__ __forceinline__ void trap() {
__device__ __forceinline__ int ld_volatile_global(int const* ptr) {
int ret;
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
ret = __atomic_load_n(const_cast<int*>(ptr), __ATOMIC_RELAXED);
ret = __hip_atomic_load(const_cast<int*>(ptr), __ATOMIC_RELAXED,
__HIP_MEMORY_SCOPE_SYSTEM);
#else
asm volatile("ld.volatile.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
#endif
@@ -754,7 +755,8 @@ __device__ __forceinline__ int ld_volatile_global(int const* ptr) {
__device__ __forceinline__ float ld_volatile_global(float const* ptr) {
float ret;
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
__atomic_load(const_cast<float*>(ptr), &ret, __ATOMIC_RELAXED);
ret = __hip_atomic_load(const_cast<float*>(ptr), __ATOMIC_RELAXED,
__HIP_MEMORY_SCOPE_SYSTEM);
#else
asm volatile("ld.volatile.global.f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
#endif
@@ -764,7 +766,8 @@ __device__ __forceinline__ float ld_volatile_global(float const* ptr) {
__device__ __forceinline__ int64_t ld_volatile_global(int64_t const* ptr) {
int64_t ret;
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
ret = __atomic_load_n(const_cast<int64_t*>(ptr), __ATOMIC_RELAXED);
ret = __hip_atomic_load(const_cast<int64_t*>(ptr), __ATOMIC_RELAXED,
__HIP_MEMORY_SCOPE_SYSTEM);
#else
asm volatile("ld.volatile.global.s64 %0, [%1];" : "=l"(ret) : "l"(ptr));
#endif
@@ -774,7 +777,8 @@ __device__ __forceinline__ int64_t ld_volatile_global(int64_t const* ptr) {
__device__ __forceinline__ int64_t ld_volatile_global(uint64_t const* ptr) {
int64_t ret;
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
ret = __atomic_load_n(const_cast<uint64_t*>(ptr), __ATOMIC_RELAXED);
ret = __hip_atomic_load(const_cast<uint64_t*>(ptr), __ATOMIC_RELAXED,
__HIP_MEMORY_SCOPE_SYSTEM);
#else
asm volatile("ld.volatile.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr));
#endif
@@ -836,6 +840,7 @@ __forceinline__ __device__ void barrier_block(int** barrier_signal_ptrs,
// Add self-ranks, sub other ranks
if (thread_id < kNumRanks) {
atomicAdd_system(barrier_signal_ptrs[rank] + thread_id, FINISHED_SUM_TAG);
memory_fence();
Review comment: Is the memory_fence() needed here?
atomicSub_system(barrier_signal_ptrs[thread_id] + rank, FINISHED_SUM_TAG);
}
EP_DEVICE_ASSERT(kNumRanks <= blockDim.x);
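For context, a minimal sketch (not code from this PR) of the symmetric add/sub barrier that the hunk above belongs to: each rank credits a slot of its own signal array, then debits its slot in the corresponding peer's array, so a slot drains to zero once both sides have arrived. The name barrier_block_sketch, the kTag value, the wait loop, and the use of __threadfence_system() in place of memory_fence() are assumptions and may differ from the real barrier_block().

// A simplified sketch, assuming a symmetric signal-exchange barrier.
template <int kNumRanks>
__device__ void barrier_block_sketch(int** barrier_signal_ptrs, int rank) {
  constexpr int kTag = 1024;  // stand-in for FINISHED_SUM_TAG
  int const tid = threadIdx.x;
  __syncthreads();
  if (tid < kNumRanks) {
    // Credit slot tid of my own signal array ...
    atomicAdd_system(barrier_signal_ptrs[rank] + tid, kTag);
    // ... make that visible system-wide (the fence this hunk inserts,
    // modeled here with __threadfence_system()) ...
    __threadfence_system();
    // ... then debit my slot in peer tid's array.
    atomicSub_system(barrier_signal_ptrs[tid] + rank, kTag);
    // Slot [rank][tid] receives +kTag locally and -kTag from rank tid, so it
    // drains to zero exactly when rank tid has also reached the barrier.
    while (atomicAdd_system(barrier_signal_ptrs[rank] + tid, 0) != 0) {
    }
  }
  __syncthreads();
}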
6 changes: 6 additions & 0 deletions ep/include/proxy_ctx.hpp
@@ -59,6 +59,12 @@ struct ProxyCtx {
uint32_t remote_rkey = 0;
uint32_t rkey = 0;

// Atomic buffer (separate from main RDMA buffer)
ibv_mr* atomic_buffer_mr = nullptr; // MR for local atomic_buffer_ptr
uintptr_t remote_atomic_buffer_addr = 0; // Remote atomic_buffer_ptr address
uint64_t remote_atomic_buffer_len = 0; // Remote atomic_buffer_ptr length
uint32_t remote_atomic_buffer_rkey = 0; // Remote atomic_buffer_ptr rkey

// Buffer offset within rdma_buffer for address translation
uintptr_t dispatch_recv_data_offset =
0; // offset of dispatch_rdma_recv_data_buffer from rdma_buffer base
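For illustration only, and not this PR's actual code path: a minimal ibverbs sketch of how fields like atomic_buffer_mr and the remote_atomic_buffer_* triple are typically consumed, posting a fetch-and-add into the peer's separately registered atomic buffer. The function and parameter names below are hypothetical.

#include <infiniband/verbs.h>
#include <cstdint>

// Post a remote fetch-and-add over an RC QP. 'local_result' must live inside
// a locally registered MR (identified by local_lkey); the remote_* arguments
// correspond to the remote_atomic_buffer_* fields exchanged above.
// Note: EFA has no native RDMA atomics, which is why the PR keeps a
// CPU-emulated path there and omits IBV_ACCESS_REMOTE_ATOMIC under EFA.
static int post_remote_fetch_add(ibv_qp* qp, uint64_t* local_result,
                                 uint32_t local_lkey,
                                 uint64_t remote_atomic_buffer_addr,
                                 uint64_t remote_atomic_buffer_len,
                                 uint32_t remote_atomic_buffer_rkey,
                                 uint64_t byte_offset, uint64_t add_value) {
  if (byte_offset + sizeof(uint64_t) > remote_atomic_buffer_len)
    return -1;  // stay inside the advertised remote atomic buffer

  ibv_sge sge{};
  sge.addr = reinterpret_cast<uint64_t>(local_result);
  sge.length = sizeof(uint64_t);  // RDMA atomics operate on 8-byte words
  sge.lkey = local_lkey;

  ibv_send_wr wr{}, *bad = nullptr;
  wr.opcode = IBV_WR_ATOMIC_FETCH_AND_ADD;
  wr.send_flags = IBV_SEND_SIGNALED;
  wr.sg_list = &sge;
  wr.num_sge = 1;
  wr.wr.atomic.remote_addr = remote_atomic_buffer_addr + byte_offset;
  wr.wr.atomic.rkey = remote_atomic_buffer_rkey;
  wr.wr.atomic.compare_add = add_value;  // value added at the remote address
  return ibv_post_send(qp, &wr, &bad);
}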
13 changes: 10 additions & 3 deletions ep/include/rdma.hpp
@@ -31,6 +31,11 @@ struct RDMAConnectionInfo {
uint16_t lid; // Local ID
uint8_t gid[16]; // Global ID for RoCE (optional)

// Atomic buffer info (separate from main GPU buffer)
uint32_t atomic_buffer_rkey = 0; // Atomic buffer memory region key
uintptr_t atomic_buffer_addr = 0; // Atomic buffer address
uint64_t atomic_buffer_len = 0; // Atomic buffer length

// #ifdef EFA
uint32_t num_rings;
uint32_t data_qp_num[kChannelPerProxy];
@@ -288,13 +293,14 @@ struct BarrierImm {
// [28:8]=SEQ (21 bits), [7:0]=SRC_RANK
static constexpr uint32_t kCtrlBit = 1u << 30;
static constexpr uint32_t kAckBit = 1u << 29;
static constexpr uint32_t kSeqMask = 0x1FFFFFu;
static inline bool IsAck(uint32_t imm) { return (imm & kAckBit) != 0u; }
static inline uint32_t Pack(bool ack, uint32_t seq, uint8_t src_rank) {
return kCtrlBit | (ack ? kAckBit : 0u) |
((seq & 0x1FFFFFu) << 8) // 21 bits for seq
((seq & kSeqMask) << 8) // 21 bits for seq
| uint32_t(src_rank);
}
static inline uint32_t Seq(uint32_t imm) { return (imm >> 8) & 0x1FFFFFu; }
static inline uint32_t Seq(uint32_t imm) { return (imm >> 8) & kSeqMask; }
static inline uint8_t Rank(uint32_t imm) { return imm & 0xFFu; }
explicit BarrierImm(uint32_t imm = 0) : value(imm) {}
bool GetIsAck() const { return IsAck(value); }
@@ -334,7 +340,8 @@ void remote_process_completions(
int my_rank, int num_nodes, bool use_normal_mode = false);
void create_per_thread_qp(ProxyCtx& S, void* gpu_buffer, size_t size,
RDMAConnectionInfo* local_info, int rank,
size_t num_rings, bool use_normal_mode);
size_t num_rings, bool use_normal_mode,
void* atomic_buffer_ptr = nullptr);
ibv_cq* create_per_thread_cq(ProxyCtx& S);
void remote_poll_completions(ProxyCtx& S, int idx, CopyRingBuffer& g_ring,
std::vector<ProxyCtx*>& ctx_by_tag,
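A standalone round-trip check of the BarrierImm layout above (kCtrlBit at bit 30, kAckBit at bit 29, a 21-bit sequence in [28:8], the source rank in [7:0]); it reimplements Pack/Seq/Rank locally for illustration and shows why a sequence counter wider than 21 bits has to be masked with kSeqMask before comparison, which is what the proxy.cpp change further down does.

#include <cassert>
#include <cstdint>

int main() {
  constexpr uint32_t kCtrlBit = 1u << 30;
  constexpr uint32_t kAckBit = 1u << 29;
  constexpr uint32_t kSeqMask = 0x1FFFFFu;  // 21 bits

  auto pack = [&](bool ack, uint32_t seq, uint8_t src_rank) {
    return kCtrlBit | (ack ? kAckBit : 0u) | ((seq & kSeqMask) << 8) |
           uint32_t(src_rank);
  };

  uint32_t imm = pack(/*ack=*/true, /*seq=*/0x1FFFFF, /*src_rank=*/7);
  assert((imm & kAckBit) != 0);                 // IsAck
  assert(((imm >> 8) & kSeqMask) == 0x1FFFFF);  // Seq
  assert((imm & 0xFFu) == 7);                   // Rank

  // A 22-bit sequence value silently wraps: only the low 21 bits survive,
  // which is why the sender-side counter must also be masked with kSeqMask.
  uint32_t wrapped = pack(false, kSeqMask + 1, 0);
  assert(((wrapped >> 8) & kSeqMask) == 0);
  return 0;
}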
2 changes: 1 addition & 1 deletion ep/install_deps.sh
@@ -50,7 +50,7 @@ if check_cuda; then
elif check_rocm; then
echo "Detected ROCM"
# Install Pytorch using nightly
pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm7.0
pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm7.1
else
echo "No CUDA or ROCM detected"
exit 1
43 changes: 32 additions & 11 deletions ep/src/internode.cu
@@ -734,7 +734,7 @@ __global__ void __launch_bounds__(
// Read RDMA rank existence
uint64_t is_token_in_rank_uint64 = 0;
if (lane_id < kNumRDMARanks) {
is_token_in_rank_uint64 = __ldg(reinterpret_cast<uint64_t const*>(
is_token_in_rank_uint64 = *(reinterpret_cast<uint64_t const*>(
is_token_in_rank + token_idx * num_ranks +
lane_id * NUM_MAX_NVL_PEERS));
}
@@ -1039,10 +1039,20 @@ __global__ void __launch_bounds__(
num_bytes_per_msg,
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),
channel_id, // NOTE(MaoZiming): use channel_id for rb.
lane_id, 0, d2h_channel_addrs, num_d2h_channel_addrs, false, -1,
lane_id, 0, d2h_channel_addrs, num_d2h_channel_addrs, false,
// NOTE(MaoZiming): for AMD GPUs, we directly send a subsequent RDMA
// to update the tail. For other GPUs and EFA NICs, we use the
// CPU-emulated atomics, which allows us to piggyback the atomic operation
// with the RDMA send.
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
-1,
#else
-1,
reinterpret_cast<uint64_t>(rdma_channel_tail.buffer(rdma_rank)) -
reinterpret_cast<uint64_t>(original_atomic_buffer_ptr),
num_tokens_to_issue);
num_tokens_to_issue
#endif
);
} else {
// Lighter fence for local RDMA rank
memory_fence();
@@ -1060,7 +1070,13 @@
translate_dst_rdma_rank<kLowLatencyMode>(dst_rdma_rank, nvl_rank),
channel_id, // NOTE(MaoZiming): use channel_id for rb.
dst_rdma_rank == rdma_rank, d2h_channel_addrs,
num_d2h_channel_addrs, false, -1, true);
num_d2h_channel_addrs, false, -1,
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
false
#else
true
#endif
);
}
__syncwarp();
}
@@ -1263,13 +1279,8 @@
// Move tail index
__syncwarp();
if (lane_id == 0)
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
__atomic_store_n(nvl_channel_tail.buffer(), cached_nvl_channel_tail,
__ATOMIC_RELEASE);
#else
st_release_sys_global(nvl_channel_tail.buffer(),
cached_nvl_channel_tail);
#endif
}
// Retired
__syncwarp();
@@ -2640,10 +2651,14 @@ __global__ void __launch_bounds__((kNumForwarders + 1) * WARP_SIZE, 1)
nvl_rank),
channel_id, // NOTE(MaoZiming): use channel_id for rb.
lane_id, 0, d2h_channel_addrs, num_d2h_channel_addrs, false, -1,
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
#else
reinterpret_cast<uint64_t>(
rdma_channel_tail.buffer(rdma_rank)) -
reinterpret_cast<uint64_t>(original_atomic_buffer_ptr),
num_chunked_tokens);
num_chunked_tokens
#endif
);
} else {
memory_fence();
}
@@ -2659,7 +2674,13 @@
nvl_rank),
channel_id, // NOTE(MaoZiming): use warp_id for rb.
dst_rdma_rank == rdma_rank, d2h_channel_addrs,
num_d2h_channel_addrs, false, -1, true);
num_d2h_channel_addrs, false, -1,
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
false
#else
true
#endif
);
}
}
}
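A hypothetical helper (names are illustrative, not this PR's API) spelling out what the non-AMD branches above encode: the kernel ships the tail counter's byte offset from original_atomic_buffer_ptr together with a token count, and the CPU proxy can rebase that offset onto the peer's advertised atomic-buffer address to perform the tail-advance atomic there.

#include <cstdint>

struct RemoteAtomicTarget {
  uint64_t remote_addr;  // address inside the peer's atomic buffer
  uint64_t add_value;    // number of tokens to advance the tail by
};

// atomic_offset is what the kernel computed above:
//   (uint64_t)rdma_channel_tail.buffer(rdma_rank)
//     - (uint64_t)original_atomic_buffer_ptr
// remote_atomic_base is the peer's remote_atomic_buffer_addr from the
// connection-info exchange.
inline RemoteAtomicTarget resolve_tail_update(uint64_t remote_atomic_base,
                                              uint64_t atomic_offset,
                                              uint64_t num_tokens) {
  return RemoteAtomicTarget{remote_atomic_base + atomic_offset, num_tokens};
}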
63 changes: 56 additions & 7 deletions ep/src/proxy.cpp
@@ -176,6 +176,32 @@ void Proxy::init_common() {
cfg_.thread_idx, cfg_.local_rank);
pin_thread_to_numa_wrapper();
if (!ctx_.cq) ctx_.cq = create_per_thread_cq(ctx_);

// Register atomic_buffer_ptr as a separate RDMA memory region if it was set
// This must be done after PD is initialized by per_thread_rdma_init
if (atomic_buffer_ptr_ && !ctx_.atomic_buffer_mr) {
ctx_.atomic_buffer_mr =
ibv_reg_mr(ctx_.pd, atomic_buffer_ptr_, kAtomicBufferSize,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE |
#ifdef EFA
IBV_ACCESS_REMOTE_READ
#else
IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC
#endif
);

if (!ctx_.atomic_buffer_mr) {
perror("Failed to register atomic_buffer_ptr MR");
std::abort();
}

fprintf(stderr,
"[Proxy] Registered atomic_buffer_ptr MR: addr=0x%llx, len=%zu, "
"rkey=0x%x\n",
(unsigned long long)ctx_.atomic_buffer_mr->addr,
(size_t)ctx_.atomic_buffer_mr->length, ctx_.atomic_buffer_mr->rkey);
}

if (ctxs_for_all_ranks_.empty()) {
fprintf(stderr,
"Error: peers metadata not set before init_common (peers_.size() "
@@ -195,6 +221,14 @@ void Proxy::init_common() {
ctx_.atomic_old_values_buf =
reinterpret_cast<uint32_t*>(static_cast<uint8_t*>(cfg_.gpu_buffer) +
cfg_.total_size - atomic_buf_size);
// Check alignment - only abort if not aligned (8-byte alignment required for
// 64-bit atomics)
if ((reinterpret_cast<uintptr_t>(ctx_.atomic_old_values_buf) & 0x7) != 0) {
fprintf(stderr, "Atomic buffer not 8-byte aligned: 0x%llx\n",
(unsigned long long)reinterpret_cast<uintptr_t>(
ctx_.atomic_old_values_buf));
std::abort();
}

int num_ranks = ctxs_for_all_ranks_.size();
local_infos_.assign(num_ranks, RDMAConnectionInfo{});
@@ -218,6 +252,8 @@
// NOTE(MaoZiming): each context can share the same cq, pd, mr.
// but the qp must be different.
c.cq = ctx_.cq;
// Share the atomic buffer MR with peer contexts
c.atomic_buffer_mr = ctx_.atomic_buffer_mr;

if (peer == my_rank) continue;
// Skip rdma connection for intra-node.
@@ -226,7 +262,7 @@
continue;
create_per_thread_qp(c, cfg_.gpu_buffer, cfg_.total_size,
&local_infos_[peer], my_rank, cfg_.d2h_queues.size(),
cfg_.use_normal_mode);
cfg_.use_normal_mode, atomic_buffer_ptr_);
modify_qp_to_init(c);
}

@@ -291,12 +327,24 @@ void Proxy::init_common() {
c.remote_addr = remote_infos_[peer].addr;
c.remote_rkey = remote_infos_[peer].rkey;
c.remote_len = remote_infos_[peer].len;
if (FILE* f = fopen("/tmp/uccl_debug.txt", "a")) {

// Set remote atomic buffer info from exchanged connection info
c.remote_atomic_buffer_addr = remote_infos_[peer].atomic_buffer_addr;
c.remote_atomic_buffer_len = remote_infos_[peer].atomic_buffer_len;
c.remote_atomic_buffer_rkey = remote_infos_[peer].atomic_buffer_rkey;

if (c.remote_atomic_buffer_addr == 0) {
fprintf(
f,
"[PROXY_INIT] me=%d peer=%d: remote_addr=0x%lx local_buffer=0x%lx\n",
my_rank, peer, c.remote_addr, (uintptr_t)cfg_.gpu_buffer);
fclose(f);
stderr,
"[Proxy] WARNING: Remote atomic buffer not registered for peer %d "
"(local atomic_buffer_ptr=%p, local atomic_buffer_mr=%p)\n",
peer, atomic_buffer_ptr_, (void*)ctx_.atomic_buffer_mr);
} else {
fprintf(stderr,
"[Proxy] Remote atomic buffer info for peer %d: addr=0x%llx, "
"len=%zu, rkey=0x%x\n",
peer, (unsigned long long)c.remote_atomic_buffer_addr,
(size_t)c.remote_atomic_buffer_len, c.remote_atomic_buffer_rkey);
}
}
usleep(50 * 1000);
@@ -1089,7 +1137,7 @@ void Proxy::send_barrier(uint64_t wr) {
#endif
assert(ctx_.barrier_wr == -1 && "barrier_wr should be 0");
ctx_.barrier_wr = wr;
ctx_.barrier_seq = ctx_.barrier_seq + 1;
ctx_.barrier_seq = (ctx_.barrier_seq + 1) & BarrierImm::kSeqMask;

if (cfg_.rank == ctx_.node_leader_rank) {
if (ctx_.barrier_arrived.size() != static_cast<size_t>(cfg_.num_nodes)) {
@@ -1167,6 +1215,7 @@ void Proxy::barrier_check() {
}

// When global release comes back (CQ handler should set these):
// NOTE: BarrierImm is 21 bits, so we must mask the local seq.
if (ctx_.barrier_released && ctx_.barrier_release_seq == seq) {
// Reset local mask for next barrier and consume the global release
ctx_.barrier_released = false;
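A small standalone check of why send_barrier() now masks barrier_seq with BarrierImm::kSeqMask: the immediate word carries only 21 sequence bits, so once an unmasked counter passes kSeqMask the value recovered from the wire in barrier_check() could never match it again.

#include <cassert>
#include <cstdint>

int main() {
  constexpr uint32_t kSeqMask = 0x1FFFFFu;
  uint64_t barrier_seq = 1u << 21;                       // just past 21 bits
  uint32_t on_wire = uint32_t(barrier_seq) & kSeqMask;   // what Pack() keeps
  assert(on_wire != barrier_seq);                        // unmasked compare fails
  uint64_t masked_seq = barrier_seq & kSeqMask;          // what the PR stores instead
  assert(on_wire == masked_seq);                         // masked compare matches
  return 0;
}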