Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/transport/net_ib/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -567,8 +567,8 @@ ncclResult_t ncclIbPeerMemSupport();
ncclResult_t ncclIbDmaBufSupport(int dev);

void ncclIbAddEvent(struct ncclIbRequest* req, int devIndex);
ncclResult_t ncclIbGetGidIndex(struct ibv_context* context, uint8_t portNum, struct ibv_port_attr* portAttr,
int* gidIndex);
ncclResult_t ncclIbGetGidIndex(struct ibv_context *context, uint8_t portNum, struct ibv_port_attr* portAttr, int *gidIndex);
ncclResult_t ncclIbRefreshGidInfo(struct ncclIbDev* ibDev, struct ncclIbGidInfo* gidInfo);
ncclResult_t ncclIbGetRequest(struct ncclIbNetCommBase* base, struct ncclIbRequest** req);
ncclResult_t ncclIbFreeRequest(struct ncclIbRequest* r);

Expand Down
47 changes: 45 additions & 2 deletions src/transport/net_ib/connect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,51 @@ static ncclResult_t ncclUpdateGidIndex(struct ibv_context* context, uint8_t port
return ncclSuccess;
}

ncclResult_t ncclIbGetGidIndex(struct ibv_context* context, uint8_t portNum, struct ibv_port_attr* portAttr,
int* gidIndex) {
// Re-query the device port + GID table and update the cached gidInfo to point
// at a currently valid GID. Required after a port flap, since the kernel may
// re-register GIDs at a different index than at initial connect time, leaving
// the previously-cached localGidIndex pointing at an empty slot. RoCE/Ethernet
// only; for IB the GID layout is stable so no refresh is needed. The freshly
// queried port_attr is also written back into ibDev->portAttr so that callers
// reading active_mtu etc. see post-flap values.
ncclResult_t ncclIbRefreshGidInfo(struct ncclIbDev* ibDev, struct ncclIbGidInfo* gidInfo) {
if (ibDev == NULL || gidInfo == NULL) return ncclInternalError;
if (gidInfo->link_layer != IBV_LINK_LAYER_ETHERNET) return ncclSuccess;

struct ibv_port_attr portAttr;
ncclResult_t res = wrap_ibv_query_port(ibDev->context, ibDev->portNum, &portAttr);
if (res != ncclSuccess) {
INFO(NCCL_NET, "NET/IB: %s: query_port failed on %s:%d, keeping cached gidIndex=%d",
__func__, ibDev->devName, ibDev->portNum, gidInfo->localGidIndex);
return ncclSuccess;
}
int newIdx = -1;
res = ncclIbGetGidIndex(ibDev->context, ibDev->portNum, &portAttr, &newIdx);
if (res != ncclSuccess || newIdx < 0) {
INFO(NCCL_NET, "NET/IB: %s: get_gid_index failed on %s:%d, keeping cached gidIndex=%d",
__func__, ibDev->devName, ibDev->portNum, gidInfo->localGidIndex);
return ncclSuccess;
}
union ibv_gid newGid;
res = wrap_ibv_query_gid(ibDev->context, ibDev->portNum, newIdx, &newGid);
if (res != ncclSuccess) {
INFO(NCCL_NET, "NET/IB: %s: query_gid(idx=%d) failed on %s:%d, keeping cached gidIndex=%d",
__func__, newIdx, ibDev->devName, ibDev->portNum, gidInfo->localGidIndex);
return ncclSuccess;
}

ibDev->portAttr = portAttr;
if (newIdx != gidInfo->localGidIndex ||
memcmp(&newGid, &gidInfo->localGid, sizeof(newGid)) != 0) {
INFO(NCCL_NET, "NET/IB: %s: refreshed GID for %s:%d: idx %d -> %d",
__func__, ibDev->devName, ibDev->portNum, gidInfo->localGidIndex, newIdx);
gidInfo->localGid = newGid;
gidInfo->localGidIndex = newIdx;
}
return ncclSuccess;
}

ncclResult_t ncclIbGetGidIndex(struct ibv_context *context, uint8_t portNum, struct ibv_port_attr* portAttr, int *gidIndex) {
int gidTblLen = portAttr->gid_tbl_len;

// for IB, choose GID Index that will have routable FLID if present
Expand Down
141 changes: 141 additions & 0 deletions src/transport/net_ib/p2p_resiliency_recovery.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,14 @@ struct ncclIbPortRecoveryContext {
// Indicates whether an "ack" message posted to the peer was completed locally.
bool ackCompleted;

// After a port flap the kernel may re-register GIDs at a different index,
// leaving the recovery QP's cached sgid_index pointing at an empty slot.
// qpsRebuilt is set once the recovery QP has been reset and re-modified to
// RTS using a freshly queried GID. lastRebuildAttemptNs throttles retries
// when the rebuild fails (e.g. port still down or GID not yet registered).
bool qpsRebuilt;
uint64_t lastRebuildAttemptNs;

union {
struct {
bool aliveMsgPosted;
Expand Down Expand Up @@ -173,6 +181,87 @@ static inline ncclResult_t ncclIbPortRecoveryQpsToError(ncclIbPortRecoveryContex
// Predefined work ID for receive work request for port recovery messages
#define NCCL_IB_PORT_RECOVERY_WR_ID (0xAAAA)

inline static ncclResult_t ncclIbPortRecoveryPostRecvWorkRequest(struct ibv_qp* qp);

// Rebuild the recovery QP for the failed device after a port flap. Refreshes
// the cached gidInfo (so we pick up the kernel's new GID layout), then resets
// and re-modifies the recovery QP to RTS using the freshly queried sgid_index.
// On the receiver side, also re-posts the initial batch of receive WRs that
// were flushed by the reset. Sets *success=true only if all steps succeed; on
// transient failure (e.g. port still down, GID not yet registered) returns
// ncclSuccess with *success=false so the caller can retry on the next tick.
static ncclResult_t ncclIbPortRecoveryQpsRebuild(struct ncclIbPortRecoveryContext* recoveryContext, bool* success) {
*success = false;
struct ncclIbResiliency* resCtx = recoveryContext->resCtx;
int devIndex = recoveryContext->devIndex;
bool isSend = resCtx->baseComm->isSend;

struct ncclIbNetCommDevBase* devBase = ncclIbGetNetCommDevBase(resCtx->baseComm, devIndex);
if (devBase == NULL) return ncclInternalError;
ncclIbDev* ibDev = &ncclIbDevs[devBase->ibDevN];

NCCLCHECK(ncclIbRefreshGidInfo(ibDev, &devBase->gidInfo));

ncclIbQp* localQp = &resCtx->portRecoveryQps[devIndex];
if (localQp->qp == NULL) return ncclInternalError;

ncclResult_t res = ncclIbQpReset(localQp);
if (res != ncclSuccess) {
INFO(NCCL_NET, "NET/IB: %s: Reset failed for recovery QP devIndex=%d (%s comm=%p)",
__func__, devIndex, isSend ? "send" : "recv", resCtx->baseComm);
return ncclSuccess;
}
res = ncclIbQpInit(localQp);
if (res != ncclSuccess) {
INFO(NCCL_NET, "NET/IB: %s: Init failed for recovery QP devIndex=%d (%s comm=%p)",
__func__, devIndex, isSend ? "send" : "recv", resCtx->baseComm);
return ncclSuccess;
}

// Update rtrAttr to use the freshly queried GID before transitioning to RTR.
localQp->rtrAttr.localGid = devBase->gidInfo.localGid;
localQp->rtrAttr.localGidIndex = devBase->gidInfo.localGidIndex;

res = ncclIbQpRtr(localQp);
if (res != ncclSuccess) {
INFO(NCCL_NET, "NET/IB: %s: RTR failed for recovery QP devIndex=%d localGidIndex=%d (port still down or GID not yet registered) (%s comm=%p)",
__func__, devIndex, devBase->gidInfo.localGidIndex, isSend ? "send" : "recv", resCtx->baseComm);
return ncclSuccess;
}
res = ncclIbQpRts(localQp);
if (res != ncclSuccess) {
INFO(NCCL_NET, "NET/IB: %s: RTS failed for recovery QP devIndex=%d (%s comm=%p)",
__func__, devIndex, isSend ? "send" : "recv", resCtx->baseComm);
return ncclSuccess;
}

// Re-post receive WRs lost by the reset. Sender expects a single ACK; the
// receiver expects up to BATCH_SIZE_MAX alive messages.
if (!isSend) {
for (int i = 0; i < NCCL_IB_RESILIENCY_PORT_RECOVERY_ALIVE_MSG_BATCH_SIZE_MAX; i++) {
res = ncclIbPortRecoveryPostRecvWorkRequest(localQp->qp);
if (res != ncclSuccess) {
INFO(NCCL_NET, "NET/IB: %s: Post recv WR failed for recovery QP devIndex=%d (%s comm=%p)",
__func__, devIndex, isSend ? "send" : "recv", resCtx->baseComm);
return ncclSuccess;
}
}
} else {
res = ncclIbPortRecoveryPostRecvWorkRequest(localQp->qp);
if (res != ncclSuccess) {
INFO(NCCL_NET, "NET/IB: %s: Post recv WR failed for recovery QP devIndex=%d (%s comm=%p)",
__func__, devIndex, isSend ? "send" : "recv", resCtx->baseComm);
return ncclSuccess;
}
}

INFO(NCCL_NET, "NET/IB: %s: Rebuilt recovery QP devIndex=%d qp_num=%u localGidIndex=%d (%s comm=%p)",
__func__, devIndex, localQp->qp->qp_num, devBase->gidInfo.localGidIndex,
isSend ? "send" : "recv", resCtx->baseComm);
*success = true;
return ncclSuccess;
}

static struct ibv_recv_wr ncclIbResiliencyPortRecoveryRecvWr = {
.wr_id = NCCL_IB_PORT_RECOVERY_WR_ID, .next = NULL, .sg_list = NULL, .num_sge = 0
};
Expand Down Expand Up @@ -242,6 +331,8 @@ static inline ncclResult_t ncclIbPortRecoveryContextInit(struct ncclIbResiliency
recoveryCtx->ackReceived = false;
recoveryCtx->ackPosted = false;
recoveryCtx->ackCompleted = false;
recoveryCtx->qpsRebuilt = false;
recoveryCtx->lastRebuildAttemptNs = 0;

for (int i = 0; i < resCtx->ndevs; i++) {
if (i != failedDevIndex) continue;
Expand Down Expand Up @@ -475,6 +566,15 @@ ncclResult_t ncclIbPortRecoveryQpsDestroy(struct ncclIbResiliency* resCtx, int n
static inline ncclResult_t ncclIbPortRecoveryQpsRestore(ncclIbPortRecoveryContext* recoveryContext, bool* success) {
ncclResult_t res = ncclSuccess;
uint nqps = recoveryContext->resCtx->baseComm->nqps;
// The cached per-QP rtrAttr.localGid/localGidIndex were captured at initial
// connect time and may be stale after a port flap. The rebuild step has
// already refreshed devBase->gidInfo for the failed device; use that fresh
// value when modifying each data QP back to RTR.
struct ncclIbNetCommDevBase* devBase = ncclIbGetNetCommDevBase(recoveryContext->resCtx->baseComm, recoveryContext->devIndex);
if (devBase == NULL) {
*success = false;
return ncclInternalError;
}
for (int qpIndex = 0; qpIndex < nqps; qpIndex++) {
ncclIbQp* localQp = &recoveryContext->resCtx->baseComm->qps[qpIndex];
if (localQp->devIndex != recoveryContext->devIndex) {
Expand Down Expand Up @@ -510,6 +610,8 @@ static inline ncclResult_t ncclIbPortRecoveryQpsRestore(ncclIbPortRecoveryContex
return ncclSuccess;
}
}
localQp->rtrAttr.localGid = devBase->gidInfo.localGid;
localQp->rtrAttr.localGidIndex = devBase->gidInfo.localGidIndex;
res = ncclIbQpRtr(localQp);
if (res != ncclSuccess) {
INFO(NCCL_NET, "NET/IB: %s: Failed to modify to RTR QP index %d on device %d (comm=%p, devIndex=%d, qp_num=%u)",
Expand Down Expand Up @@ -572,6 +674,12 @@ static inline ncclResult_t ncclIbPortRecoveryQpsRestore(ncclIbPortRecoveryContex
*success = false;
return ncclSuccess;
}
// Flush QP uses the local device as both source and destination, so
// both rtrAttr->localGid and rtrAttr->remoteGid must point at the
// refreshed GID for the failed device.
flushQp->rtrAttr.localGid = devBase->gidInfo.localGid;
flushQp->rtrAttr.localGidIndex = devBase->gidInfo.localGidIndex;
flushQp->rtrAttr.remoteGid = devBase->gidInfo.localGid;
res = ncclIbQpRtr(flushQp);
if (res != ncclSuccess) {
INFO(NCCL_NET, "NET/IB: %s: Failed to modify to RTR Flush QP on device %d (comm=%p, devIndex=%d, qp_num=%u)",
Expand Down Expand Up @@ -1180,6 +1288,39 @@ static inline ncclResult_t ncclIbPortRecoveryContextProgress(ncclIbPortRecoveryC
recoveryContext->state = ncclIbPortRecoveryStateAliveMessages;
}

// Before the alive-message handshake can make progress, the recovery QP on
// the failed device must be rebuilt with a freshly queried sgid_index
// (the kernel may have re-registered GIDs at a different table index after
// the flap). Throttle retries by the alive-message batch interval so we
// don't busy-loop while waiting for the port to come back / GIDs to settle.
if (recoveryContext->state == ncclIbPortRecoveryStateAliveMessages && !recoveryContext->qpsRebuilt) {
uint64_t now = clockNano();
uint64_t throttleNs = ncclParamIbResiliencyPortRecoveryAliveMsgBatchInterval() * MSEC_TO_NSEC;
if (recoveryContext->lastRebuildAttemptNs == 0 ||
(now - recoveryContext->lastRebuildAttemptNs) >= throttleNs) {
bool rebuilt = false;
NCCLCHECK(ncclIbPortRecoveryQpsRebuild(recoveryContext, &rebuilt));
recoveryContext->lastRebuildAttemptNs = now;
if (rebuilt) {
recoveryContext->qpsRebuilt = true;
// Reset internal state so a fresh alive-message batch is posted on
// the rebuilt QP.
recoveryContext->timeLastMsg = 0;
recoveryContext->aliveMsgNextId = 0;
if (recoveryContext->resCtx->baseComm->isSend) {
recoveryContext->send.aliveMsgPosted = false;
recoveryContext->send.aliveMsgCompleted = false;
} else {
recoveryContext->recv.nInOrderMsgsReceived = 0;
}
}
}
if (!recoveryContext->qpsRebuilt) {
*outDone = false;
return ncclSuccess;
}
}

if (recoveryContext->state == ncclIbPortRecoveryStateAliveMessages) {
NCCLCHECK(ncclIbPortRecoveryProgressAliveMessages(recoveryContext));
}
Expand Down