27 changes: 27 additions & 0 deletions proto/v1/sidecar.proto
@@ -24,11 +24,20 @@ service Sidecar {
// Mark a chunk of forwarded data as done.
rpc MarkDone(MarkDoneRequest) returns (MarkDoneResponse);

// Consumer sidecar client --> Its sidecar server (==> Producer sidecar server).
// Mark all chunks of forwarded data as done for a given data_id.
rpc MarkDoneAll(MarkDoneAllRequest) returns (MarkDoneAllResponse);

// Consumer sidecar server --> Producer sidecar server.
// Used exclusively during intra-node /dev/shm ping-pong.
// Decrement the reference count of a chunk.
rpc Unlink(UnlinkRequest) returns (UnlinkResponse);

// Consumer sidecar server --> Producer sidecar server.
// Used exclusively during intra-node /dev/shm ping-pong.
// Decrement the reference count of all chunks for a given data_id.
rpc UnlinkAll(UnlinkAllRequest) returns (UnlinkAllResponse);

// Producer sidecar server --> Consumer sidecar server.
// Ask the receiver sidecar to prepare for receiving a chunk of data,
// e.g., allocate shared memory.
@@ -108,6 +117,15 @@ message MarkDoneResponse {
common.Status status = 1;
}

message MarkDoneAllRequest {
string id = 1;
}

message MarkDoneAllResponse {
common.Status status = 1;
int32 num_chunks_marked = 2; // number of chunks that were marked done
}

message UnlinkRequest {
string id = 1;
int32 chunk_id = 2;
@@ -117,6 +135,15 @@ message UnlinkResponse {
common.Status status = 1;
}

message UnlinkAllRequest {
string id = 1;
}

message UnlinkAllResponse {
common.Status status = 1;
int32 num_chunks_unlinked = 2; // number of chunks that were unlinked
}


message PrepareReceiveRequest {
string id = 1;
147 changes: 146 additions & 1 deletion python/cornserve/services/sidecar/receiver.py
@@ -22,6 +22,9 @@
    SharedMemoryManager,
)
from cornserve.sidecar.constants import (
    GRPC_RETRY_BACKOFF_MULTIPLIER,
    GRPC_RETRY_INITIAL_BACKOFF_SECONDS,
    GRPC_RETRY_MAX_ATTEMPTS,
    chunk_tag,
    grpc_url_from_rank,
    shm_filename,
@@ -396,7 +399,35 @@ async def mark_done(
            )
            stub = self._get_grpc_stub(chunk_state.intra_node_rank)
            unlink_req = sidecar_pb2.UnlinkRequest(id=mark_done_req.id, chunk_id=mark_done_req.chunk_id)
            res = await stub.Unlink(unlink_req)
            res = None
            for attempt in range(GRPC_RETRY_MAX_ATTEMPTS):
                try:
                    res = await stub.Unlink(unlink_req)
                    break
                except grpc.RpcError as e:
                    code = e.code()
                    if (
                        code in (grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE)
                        and attempt < GRPC_RETRY_MAX_ATTEMPTS - 1
                    ):
                        backoff_delay = GRPC_RETRY_INITIAL_BACKOFF_SECONDS * (GRPC_RETRY_BACKOFF_MULTIPLIER**attempt)
                        logger.warning(
                            "Unlink retry %d for req %s chunk %d in rank %d due to %s, waiting %.3fs",
                            attempt + 1,
                            mark_done_req.id,
                            mark_done_req.chunk_id,
                            chunk_state.intra_node_rank,
                            code.name,
                            backoff_delay,
                        )
                        await asyncio.sleep(backoff_delay)
                        continue
                    raise
            if res is None:
                await context.abort(
                    grpc.StatusCode.INTERNAL,
                    f"Failed to unlink for id {mark_done_req.id}: no response received",
                )
Comment on lines +402 to +430
high

This retry logic with exponential backoff is duplicated in many places across the codebase (e.g., mark_done_all in this file, and in sender.py, server.py, and api.py). The duplication makes the code harder to maintain and prone to errors, such as the missing time.sleep call in some of the synchronous versions.

Consider extracting this logic into a reusable helper function for both asynchronous and synchronous gRPC calls. This would centralize the retry mechanism, improve readability, and ensure consistency.

For example, you could create an async helper like this:

async def grpc_retry_async(stub_call, request, log_message_prefix: str):
    response = None
    for attempt in range(GRPC_RETRY_MAX_ATTEMPTS):
        try:
            response = await stub_call(request)
            break
        except grpc.RpcError as e:
            code = e.code()
            if (
                code in (grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE)
                and attempt < GRPC_RETRY_MAX_ATTEMPTS - 1
            ):
                backoff_delay = GRPC_RETRY_INITIAL_BACKOFF_SECONDS * (GRPC_RETRY_BACKOFF_MULTIPLIER**attempt)
                logger.warning(
                    f"{log_message_prefix} retry %d due to %s, waiting %.3fs",
                    attempt + 1,
                    code.name,
                    backoff_delay,
                )
                await asyncio.sleep(backoff_delay)
                continue
            raise
    return response

And then use it like this:

log_prefix = f"Unlink for req {mark_done_req.id} chunk {mark_done_req.chunk_id} in rank {chunk_state.intra_node_rank}"
res = await grpc_retry_async(stub.Unlink, unlink_req, log_prefix)
if res is None:
    await context.abort(
        grpc.StatusCode.INTERNAL,
        f"Failed to unlink for id {mark_done_req.id}: no response received",
    )

A similar helper grpc_retry_sync could be created for synchronous calls.
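
A minimal sketch of what that synchronous counterpart could look like (an illustration, not part of this PR), assuming the same retry constants from cornserve.sidecar.constants and a module-level logger; the name grpc_retry_sync and its signature simply mirror the async helper above:

import logging
import time

import grpc

from cornserve.sidecar.constants import (
    GRPC_RETRY_BACKOFF_MULTIPLIER,
    GRPC_RETRY_INITIAL_BACKOFF_SECONDS,
    GRPC_RETRY_MAX_ATTEMPTS,
)

logger = logging.getLogger(__name__)


def grpc_retry_sync(stub_call, request, log_message_prefix: str):
    """Call a blocking gRPC stub method, retrying CANCELLED/UNAVAILABLE with exponential backoff."""
    response = None
    for attempt in range(GRPC_RETRY_MAX_ATTEMPTS):
        try:
            response = stub_call(request)
            break
        except grpc.RpcError as e:
            code = e.code()
            if (
                code in (grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE)
                and attempt < GRPC_RETRY_MAX_ATTEMPTS - 1
            ):
                backoff_delay = GRPC_RETRY_INITIAL_BACKOFF_SECONDS * (GRPC_RETRY_BACKOFF_MULTIPLIER**attempt)
                logger.warning(
                    "%s retry %d due to %s, waiting %.3fs",
                    log_message_prefix,
                    attempt + 1,
                    code.name,
                    backoff_delay,
                )
                # This blocking sleep is the step missing from some existing synchronous retry loops.
                time.sleep(backoff_delay)
                continue
            raise
    return response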

Collaborator: @yizhuoliang Please deduplicate this part.

if res.status != common_pb2.Status.STATUS_OK:
await context.abort(grpc.StatusCode.INTERNAL, "Failed to unlink intra node memory")
else:
@@ -413,3 +444,117 @@ async def mark_done(
# last chunk
del self.ledger[mark_done_req.id]
return sidecar_pb2.MarkDoneResponse(status=common_pb2.Status.STATUS_OK)

    @tracer.start_as_current_span(name="SidecarReceiver.mark_done_all")
    async def mark_done_all(
        self,
        mark_done_all_req: sidecar_pb2.MarkDoneAllRequest,
        context: grpc.aio.ServicerContext,
    ) -> sidecar_pb2.MarkDoneAllResponse:
        """Mark all chunks of a request as done to free the shared memory buffers.

        This is a convenience method that calls mark_done for all chunks of a given data_id.
        It handles both tensor chunks (RecvTensorState) and object chunks (RecvObjState).
        """
        span = trace.get_current_span()
        span.set_attribute("SidecarReceiver.mark_done_all.id", mark_done_all_req.id)

        if mark_done_all_req.id not in self.ledger:
            logger.error("mark_done_all: %s not found", mark_done_all_req.id)
            await context.abort(grpc.StatusCode.NOT_FOUND, "mark_done_all_req not found")

        req_state = self.ledger[mark_done_all_req.id]
        num_chunks = req_state.num_chunks
        num_chunks_marked = 0

        logger.info("mark_done_all: marking %d chunks for id %s", num_chunks, mark_done_all_req.id)

        # Group chunks by type and intra_node_rank for efficient processing
        intra_node_ranks: set[int] = set()
        local_tensor_buffers: list[tuple[int, RecvTensorState]] = []

        # First pass: categorize chunks
        for chunk_id in range(num_chunks):
            if chunk_id not in req_state.chunks:
                logger.warning("mark_done_all: chunk %d not found for id %s, skipping", chunk_id, mark_done_all_req.id)
                continue

            chunk_state = req_state.chunks[chunk_id]

            if isinstance(chunk_state, RecvObjState):
                # Object chunk - just delete it
                del req_state.chunks[chunk_id]
                num_chunks_marked += 1
            elif isinstance(chunk_state, RecvTensorState):
                # Tensor chunk - categorize by whether it needs unlinking
                if chunk_state.intra_node_rank >= 0:
                    intra_node_ranks.add(chunk_state.intra_node_rank)
                else:
                    local_tensor_buffers.append((chunk_id, chunk_state))
                num_chunks_marked += 1

        # Unlink all intra-node tensors per rank
        for rank in intra_node_ranks:
            logger.debug(
                "mark_done_all: calling UnlinkAll for id %s in rank %d",
                mark_done_all_req.id,
                rank,
            )
            stub = self._get_grpc_stub(rank)
            unlink_all_req = sidecar_pb2.UnlinkAllRequest(id=mark_done_all_req.id)
            res = None
            for attempt in range(GRPC_RETRY_MAX_ATTEMPTS):
                try:
                    res = await stub.UnlinkAll(unlink_all_req)
                    break
                except grpc.RpcError as e:
                    code = e.code()
                    if (
                        code in (grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE)
                        and attempt < GRPC_RETRY_MAX_ATTEMPTS - 1
                    ):
                        backoff_delay = GRPC_RETRY_INITIAL_BACKOFF_SECONDS * (GRPC_RETRY_BACKOFF_MULTIPLIER**attempt)
                        logger.warning(
                            "UnlinkAll retry %d for req %s in rank %d due to %s, waiting %.3fs",
                            attempt + 1,
                            mark_done_all_req.id,
                            rank,
                            code.name,
                            backoff_delay,
                        )
                        await asyncio.sleep(backoff_delay)
                        continue
                    raise

            if res is None:
                await context.abort(
                    grpc.StatusCode.INTERNAL,
                    f"Failed to unlink all for id {mark_done_all_req.id} in rank {rank}: no response received",
                )
Comment on lines +505 to +533
medium

This is another instance of duplicated retry logic. To improve maintainability and prevent inconsistencies, please see my other comment in this file on the mark_done method about refactoring this into a reusable helper function.
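
With such a helper (using the hypothetical grpc_retry_async name from that comment), this call site would reduce to roughly:

log_prefix = f"UnlinkAll for req {mark_done_all_req.id} in rank {rank}"
res = await grpc_retry_async(stub.UnlinkAll, unlink_all_req, log_prefix)
if res is None:
    await context.abort(
        grpc.StatusCode.INTERNAL,
        f"Failed to unlink all for id {mark_done_all_req.id} in rank {rank}: no response received",
    )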

            if res.status != common_pb2.Status.STATUS_OK:
                await context.abort(grpc.StatusCode.INTERNAL, f"Failed to unlink all intra node memory in rank {rank}")
            logger.debug(
                "mark_done_all: UnlinkAll succeeded for id %s in rank %d, unlinked %d chunks",
                mark_done_all_req.id,
                rank,
                res.num_chunks_unlinked,
            )

        # Third pass: free local buffers
        for chunk_id, chunk_state in local_tensor_buffers:
            await self._free(chunk_state.buffer)
            logger.debug(
                "mark_done_all: Freed up %d slots from %s chunk %d",
                len(chunk_state.buffer.slots),
                mark_done_all_req.id,
                chunk_id,
            )

        # Clean up the ledger entry
        del self.ledger[mark_done_all_req.id]
        logger.info("mark_done_all: marked %d/%d chunks for id %s", num_chunks_marked, num_chunks, mark_done_all_req.id)

        return sidecar_pb2.MarkDoneAllResponse(
            status=common_pb2.Status.STATUS_OK,
            num_chunks_marked=num_chunks_marked,
        )