diff --git a/proto/v1/sidecar.proto b/proto/v1/sidecar.proto index 1900ce17..ae351638 100644 --- a/proto/v1/sidecar.proto +++ b/proto/v1/sidecar.proto @@ -24,11 +24,20 @@ service Sidecar { // Mark a chunk of forwarded data as done. rpc MarkDone(MarkDoneRequest) returns (MarkDoneResponse); + // Consumer sidecar client --> Its sidecar server (==> Producer sidecar server). + // Mark all chunks of forwarded data as done for a given data_id. + rpc MarkDoneAll(MarkDoneAllRequest) returns (MarkDoneAllResponse); + // Consumer sidecar server --> Producer sidecar server. // Used exclusively during intra-node /dev/shm ping-pong. // Decrement the reference count of a chunk. rpc Unlink(UnlinkRequest) returns (UnlinkResponse); + // Consumer sidecar server --> Producer sidecar server. + // Used exclusively during intra-node /dev/shm ping-pong. + // Decrement the reference count of all chunks for a given data_id. + rpc UnlinkAll(UnlinkAllRequest) returns (UnlinkAllResponse); + // Producer sidecar server --> Consumer sidecar server. // Ask the receiver sidecar to prepare for receiving a chunk of data, // e.g., allocate shared memory. 
@@ -108,6 +117,15 @@ message MarkDoneResponse { common.Status status = 1; } +message MarkDoneAllRequest { + string id = 1; +} + +message MarkDoneAllResponse { + common.Status status = 1; + int32 num_chunks_marked = 2; // number of chunks that were marked done +} + message UnlinkRequest { string id = 1; int32 chunk_id = 2; @@ -117,6 +135,15 @@ message UnlinkResponse { common.Status status = 1; } +message UnlinkAllRequest { + string id = 1; +} + +message UnlinkAllResponse { + common.Status status = 1; + int32 num_chunks_unlinked = 2; // number of chunks that were unlinked +} + message PrepareReceiveRequest { string id = 1; diff --git a/python/cornserve/services/sidecar/receiver.py b/python/cornserve/services/sidecar/receiver.py index d2b6f0c1..0e129593 100644 --- a/python/cornserve/services/sidecar/receiver.py +++ b/python/cornserve/services/sidecar/receiver.py @@ -22,6 +22,9 @@ SharedMemoryManager, ) from cornserve.sidecar.constants import ( + GRPC_RETRY_BACKOFF_MULTIPLIER, + GRPC_RETRY_INITIAL_BACKOFF_SECONDS, + GRPC_RETRY_MAX_ATTEMPTS, chunk_tag, grpc_url_from_rank, shm_filename, @@ -396,7 +399,35 @@ async def mark_done( ) stub = self._get_grpc_stub(chunk_state.intra_node_rank) unlink_req = sidecar_pb2.UnlinkRequest(id=mark_done_req.id, chunk_id=mark_done_req.chunk_id) - res = await stub.Unlink(unlink_req) + res = None + for attempt in range(GRPC_RETRY_MAX_ATTEMPTS): + try: + res = await stub.Unlink(unlink_req) + break + except grpc.RpcError as e: + code = e.code() + if ( + code in (grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE) + and attempt < GRPC_RETRY_MAX_ATTEMPTS - 1 + ): + backoff_delay = GRPC_RETRY_INITIAL_BACKOFF_SECONDS * (GRPC_RETRY_BACKOFF_MULTIPLIER**attempt) + logger.warning( + "Unlink retry %d for req %s chunk %d in rank %d due to %s, waiting %.3fs", + attempt + 1, + mark_done_req.id, + mark_done_req.chunk_id, + chunk_state.intra_node_rank, + code.name, + backoff_delay, + ) + await asyncio.sleep(backoff_delay) + continue + raise + if res is 
None: + await context.abort( + grpc.StatusCode.INTERNAL, + f"Failed to unlink for id {mark_done_req.id}: no response received", + ) if res.status != common_pb2.Status.STATUS_OK: await context.abort(grpc.StatusCode.INTERNAL, "Failed to unlink intra node memory") else: @@ -413,3 +444,117 @@ async def mark_done( # last chunk del self.ledger[mark_done_req.id] return sidecar_pb2.MarkDoneResponse(status=common_pb2.Status.STATUS_OK) + + @tracer.start_as_current_span(name="SidecarReceiver.mark_done_all") + async def mark_done_all( + self, + mark_done_all_req: sidecar_pb2.MarkDoneAllRequest, + context: grpc.aio.ServicerContext, + ) -> sidecar_pb2.MarkDoneAllResponse: + """Mark all chunks of a request as done to free the shared memory buffers. + + This marks all chunks of a given data_id as done in a single call; intra-node chunks are unlinked via one UnlinkAll RPC per producer rank rather than one Unlink per chunk. + It handles both tensor chunks (RecvTensorState) and object chunks (RecvObjState). + """ + span = trace.get_current_span() + span.set_attribute("SidecarReceiver.mark_done_all.id", mark_done_all_req.id) + + if mark_done_all_req.id not in self.ledger: + logger.error("mark_done_all: %s not found", mark_done_all_req.id) + await context.abort(grpc.StatusCode.NOT_FOUND, "mark_done_all_req not found") + + req_state = self.ledger[mark_done_all_req.id] + num_chunks = req_state.num_chunks + num_chunks_marked = 0 + + logger.info("mark_done_all: marking %d chunks for id %s", num_chunks, mark_done_all_req.id) + + # Group chunks by type and intra_node_rank for efficient processing + intra_node_ranks: set[int] = set() + local_tensor_buffers: list[tuple[int, RecvTensorState]] = [] + + # First pass: categorize chunks + for chunk_id in range(num_chunks): + if chunk_id not in req_state.chunks: + logger.warning("mark_done_all: chunk %d not found for id %s, skipping", chunk_id, mark_done_all_req.id) + continue + + chunk_state = req_state.chunks[chunk_id] + + if isinstance(chunk_state, RecvObjState): + # Object chunk - just delete it + del req_state.chunks[chunk_id] + 
num_chunks_marked += 1 + elif isinstance(chunk_state, RecvTensorState): + # Tensor chunk - categorize by whether it needs unlinking + if chunk_state.intra_node_rank >= 0: + intra_node_ranks.add(chunk_state.intra_node_rank) + else: + local_tensor_buffers.append((chunk_id, chunk_state)) + num_chunks_marked += 1 + + # Unlink all intra-node tensors per rank + for rank in intra_node_ranks: + logger.debug( + "mark_done_all: calling UnlinkAll for id %s in rank %d", + mark_done_all_req.id, + rank, + ) + stub = self._get_grpc_stub(rank) + unlink_all_req = sidecar_pb2.UnlinkAllRequest(id=mark_done_all_req.id) + res = None + for attempt in range(GRPC_RETRY_MAX_ATTEMPTS): + try: + res = await stub.UnlinkAll(unlink_all_req) + break + except grpc.RpcError as e: + code = e.code() + if ( + code in (grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE) + and attempt < GRPC_RETRY_MAX_ATTEMPTS - 1 + ): + backoff_delay = GRPC_RETRY_INITIAL_BACKOFF_SECONDS * (GRPC_RETRY_BACKOFF_MULTIPLIER**attempt) + logger.warning( + "UnlinkAll retry %d for req %s in rank %d due to %s, waiting %.3fs", + attempt + 1, + mark_done_all_req.id, + rank, + code.name, + backoff_delay, + ) + await asyncio.sleep(backoff_delay) + continue + raise + + if res is None: + await context.abort( + grpc.StatusCode.INTERNAL, + f"Failed to unlink all for id {mark_done_all_req.id} in rank {rank}: no response received", + ) + if res.status != common_pb2.Status.STATUS_OK: + await context.abort(grpc.StatusCode.INTERNAL, f"Failed to unlink all intra node memory in rank {rank}") + logger.debug( + "mark_done_all: UnlinkAll succeeded for id %s in rank %d, unlinked %d chunks", + mark_done_all_req.id, + rank, + res.num_chunks_unlinked, + ) + + # Third pass: free local buffers + for chunk_id, chunk_state in local_tensor_buffers: + await self._free(chunk_state.buffer) + logger.debug( + "mark_done_all: Freed up %d slots from %s chunk %d", + len(chunk_state.buffer.slots), + mark_done_all_req.id, + chunk_id, + ) + + # Clean up the 
ledger entry + del self.ledger[mark_done_all_req.id] + logger.info("mark_done_all: marked %d/%d chunks for id %s", num_chunks_marked, num_chunks, mark_done_all_req.id) + + return sidecar_pb2.MarkDoneAllResponse( + status=common_pb2.Status.STATUS_OK, + num_chunks_marked=num_chunks_marked, + ) diff --git a/python/cornserve/services/sidecar/sender.py b/python/cornserve/services/sidecar/sender.py index 10c1888d..601c96bb 100644 --- a/python/cornserve/services/sidecar/sender.py +++ b/python/cornserve/services/sidecar/sender.py @@ -17,6 +17,9 @@ SharedMemoryManager, ) from cornserve.sidecar.constants import ( + GRPC_RETRY_BACKOFF_MULTIPLIER, + GRPC_RETRY_INITIAL_BACKOFF_SECONDS, + GRPC_RETRY_MAX_ATTEMPTS, chunk_tag, grpc_url_from_rank, shm_filename, @@ -205,6 +208,75 @@ async def unlink( ) return sidecar_pb2.UnlinkResponse(status=common_pb2.Status.STATUS_OK) + @tracer.start_as_current_span(name="SidecarSender.unlink_all") + async def unlink_all( + self, + request: sidecar_pb2.UnlinkAllRequest, + context: grpc.aio.ServicerContext, + ) -> sidecar_pb2.UnlinkAllResponse: + """Mark all chunks of a tensor as consumed. + + Exclusively from intra-node send. This call will decrement the ref count for all chunks + of a given data_id. If any ref count reaches 0, the underlying buffer will be freed. + + Args: + request: The unlink all request containing the data_id. + context: The gRPC context. + + Returns: + UnlinkAllResponse with status and number of chunks unlinked. 
+ """ + span = trace.get_current_span() + span.set_attribute("SidecarSender.unlink_all.id", request.id) + + num_chunks_unlinked = 0 + base_id = request.id + + # Find all chunks for this data_id + # The id format is "{data_id}-{chunk_id}" + chunks_to_unlink = [] + for buffer_id in list(self.saved_buffers.keys()): + if buffer_id.startswith(base_id + "-"): + chunks_to_unlink.append(buffer_id) + if not chunks_to_unlink: + logger.warning("unlink_all: No buffers found for id %s", base_id) + return sidecar_pb2.UnlinkAllResponse( + status=common_pb2.Status.STATUS_OK, + num_chunks_unlinked=0, + ) + + logger.info("unlink_all: Unlinking %d chunks for id %s", len(chunks_to_unlink), base_id) + + for buffer_id in chunks_to_unlink: + if buffer_id not in self.ref_counts: + logger.warning("unlink_all: Buffer %s not in ref_counts, skipping", buffer_id) + continue + self.ref_counts[buffer_id] -= 1 + if self.ref_counts[buffer_id] == 0: + del self.ref_counts[buffer_id] + buffer = self.saved_buffers[buffer_id] + logger.debug("unlink_all: Freeing buffer %s", buffer_id) + await self._free(buffer) + del self.saved_buffers[buffer_id] + else: + logger.debug( + "unlink_all: Decremented ref count for %s, remaining: %d", + buffer_id, + self.ref_counts[buffer_id], + ) + num_chunks_unlinked += 1 + + logger.info( + "unlink_all: Unlinked %d chunks for id %s, used slots: %d", + num_chunks_unlinked, + base_id, + self.shm_manager.used_slots, + ) + return sidecar_pb2.UnlinkAllResponse( + status=common_pb2.Status.STATUS_OK, + num_chunks_unlinked=num_chunks_unlinked, + ) + async def send( self, request: sidecar_pb2.SendRequest, @@ -286,8 +358,31 @@ async def _send_intra_node_buffer( num_chunks=request.num_chunks, ) stub = self._get_grpc_stub(dst_rank) - res = await stub.PrepareReceive(req) - if res.status != common_pb2.Status.STATUS_OK: + res = None + for attempt in range(GRPC_RETRY_MAX_ATTEMPTS): + try: + res = await stub.PrepareReceive(req) + break + except grpc.RpcError as e: + code = e.code() + if ( 
+ code in (grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE) + and attempt < GRPC_RETRY_MAX_ATTEMPTS - 1 + ): + backoff_delay = GRPC_RETRY_INITIAL_BACKOFF_SECONDS * (GRPC_RETRY_BACKOFF_MULTIPLIER**attempt) + logger.warning( + "PrepareReceive retry %d for req %s chunk %d to rank %d due to %s, waiting %.3fs", + attempt + 1, + request.id, + request.chunk_id, + dst_rank, + code.name, + backoff_delay, + ) + await asyncio.sleep(backoff_delay) + continue + raise + if res is None or res.status != common_pb2.Status.STATUS_OK: logger.error("Failed to prepare receive") return sidecar_pb2.SendResponse(status=common_pb2.Status.STATUS_ERROR) return sidecar_pb2.SendResponse(status=common_pb2.Status.STATUS_OK) @@ -331,8 +426,34 @@ async def _send_inter_node_buffer( num_chunks=request.num_chunks, ) stub = self._get_grpc_stub(dst_rank) - res = await stub.PrepareReceive(req) - if res.status != common_pb2.Status.STATUS_OK: + res = None + for attempt in range(GRPC_RETRY_MAX_ATTEMPTS): + try: + res = await stub.PrepareReceive(req) + break + except grpc.RpcError as e: + code = e.code() + if ( + code in (grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE) + and attempt < GRPC_RETRY_MAX_ATTEMPTS - 1 + ): + backoff_delay = GRPC_RETRY_INITIAL_BACKOFF_SECONDS * (GRPC_RETRY_BACKOFF_MULTIPLIER**attempt) + logger.warning( + ( + "PrepareReceive (concurrent) retry %d for req %s " + "chunk %d to rank %d due to %s, waiting %.3fs" + ), + attempt + 1, + request.id, + request.chunk_id, + dst_rank, + code.name, + backoff_delay, + ) + await asyncio.sleep(backoff_delay) + continue + raise + if res is None or res.status != common_pb2.Status.STATUS_OK: logger.error("Failed to prepare receive") return sidecar_pb2.SendResponse(status=common_pb2.Status.STATUS_ERROR) tag = chunk_tag(request.id, self.sidecar_rank, request.chunk_id, obj.shard_rank) @@ -357,8 +478,34 @@ async def _send_inter_node_buffer( num_chunks=request.num_chunks, ) stub = self._get_grpc_stub(dst_rank) - res = await 
stub.PrepareReceive(req) - if res.status != common_pb2.Status.STATUS_OK: + res = None + for attempt in range(GRPC_RETRY_MAX_ATTEMPTS): + try: + res = await stub.PrepareReceive(req) + break + except grpc.RpcError as e: + code = e.code() + if ( + code in (grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE) + and attempt < GRPC_RETRY_MAX_ATTEMPTS - 1 + ): + backoff_delay = GRPC_RETRY_INITIAL_BACKOFF_SECONDS * (GRPC_RETRY_BACKOFF_MULTIPLIER**attempt) + logger.warning( + ( + "PrepareReceive (non-concurrent) retry %d for req %s " + "chunk %d to rank %d due to %s, waiting %.3fs" + ), + attempt + 1, + request.id, + request.chunk_id, + dst_rank, + code.name, + backoff_delay, + ) + await asyncio.sleep(backoff_delay) + continue + raise + if res is None or res.status != common_pb2.Status.STATUS_OK: logger.error("Failed to prepare receive") return sidecar_pb2.SendResponse(status=common_pb2.Status.STATUS_ERROR) # do send diff --git a/python/cornserve/services/sidecar/server.py b/python/cornserve/services/sidecar/server.py index 3584f779..61fcc91d 100644 --- a/python/cornserve/services/sidecar/server.py +++ b/python/cornserve/services/sidecar/server.py @@ -33,6 +33,9 @@ from cornserve.services.sidecar.sender import SidecarSender from cornserve.sidecar.constants import ( GRPC_BASE_PORT, + GRPC_RETRY_BACKOFF_MULTIPLIER, + GRPC_RETRY_INITIAL_BACKOFF_SECONDS, + GRPC_RETRY_MAX_ATTEMPTS, UCX_BASE_PORT, grpc_url_from_rank, ucx_port_from_rank, @@ -110,8 +113,30 @@ async def _reachable(self, sidecar_rank: int) -> bool: async with grpc.aio.insecure_channel(grpc_url_from_rank(sidecar_rank)) as channel: req = sidecar_pb2.CheckHealthRequest() stub = sidecar_pb2_grpc.SidecarStub(channel) - _ = await stub.CheckHealth(req) - return True + for attempt in range(GRPC_RETRY_MAX_ATTEMPTS): + try: + _ = await stub.CheckHealth(req) + return True + except grpc.RpcError as e: + code = e.code() + if ( + code in (grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE) + and attempt < 
GRPC_RETRY_MAX_ATTEMPTS - 1 + ): + backoff_delay = GRPC_RETRY_INITIAL_BACKOFF_SECONDS * ( + GRPC_RETRY_BACKOFF_MULTIPLIER**attempt + ) + logger.warning( + "CheckHealth retry %d for sidecar rank %d due to %s, waiting %.3fs", + attempt + 1, + sidecar_rank, + code.name, + backoff_delay, + ) + await asyncio.sleep(backoff_delay) + continue + raise + return False except Exception: return False @@ -374,6 +399,25 @@ async def MarkDone( # noqa: N802 logger.exception("Error in MarkDone") await context.abort(grpc.StatusCode.INTERNAL, f"Error in MarkDone: {e} \n {tb_str}") + async def MarkDoneAll( # noqa: N802 + self, + request: sidecar_pb2.MarkDoneAllRequest, + context: grpc.aio.ServicerContext, + ) -> sidecar_pb2.MarkDoneAllResponse: + """Called by the receiver server to mark all chunks of a request as done.""" + try: + if not self.live: + logger.error("Sidecar not online") + await context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Sidecar not online") + if self.receiver is None: + logger.error("Sidecar not registered") + await context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Sidecar not registered") + return await self.receiver.mark_done_all(request, context) + except Exception as e: + tb_str = traceback.format_exc() + logger.exception("Error in MarkDoneAll") + await context.abort(grpc.StatusCode.INTERNAL, f"Error in MarkDoneAll: {e} \n {tb_str}") + async def Unlink( # noqa: N802 self, request: sidecar_pb2.UnlinkRequest, @@ -394,6 +438,25 @@ async def Unlink( # noqa: N802 logger.exception("Error in Unlink") await context.abort(grpc.StatusCode.INTERNAL, f"Error in Unlink: {e} \n {tb_str}") + async def UnlinkAll( # noqa: N802 + self, + request: sidecar_pb2.UnlinkAllRequest, + context: grpc.aio.ServicerContext, + ) -> sidecar_pb2.UnlinkAllResponse: + """Called by the receiver server to mark all chunks of a request as done.""" + try: + if not self.live: + logger.error("Sidecar not online") + await context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Sidecar not online") + if 
self.sender is None: + logger.error("Sidecar not registered") + await context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Sidecar not registered") + return await self.sender.unlink_all(request, context) + except Exception as e: + tb_str = traceback.format_exc() + logger.exception("Error in UnlinkAll") + await context.abort(grpc.StatusCode.INTERNAL, f"Error in UnlinkAll: {e} \n {tb_str}") + async def shutdown(self) -> None: """Shutdown the sidecar.""" if self.sender is not None: diff --git a/python/cornserve/sidecar/api.py b/python/cornserve/sidecar/api.py index f7a3ed4d..5f46c345 100644 --- a/python/cornserve/sidecar/api.py +++ b/python/cornserve/sidecar/api.py @@ -7,6 +7,7 @@ import ctypes import json import os +import time import weakref from concurrent.futures import ThreadPoolExecutor from functools import lru_cache @@ -24,7 +25,13 @@ from cornserve.logging import get_logger from cornserve.services.pb import common_pb2, sidecar_pb2, sidecar_pb2_grpc -from cornserve.sidecar.constants import grpc_url_from_rank, shm_filename +from cornserve.sidecar.constants import ( + GRPC_RETRY_BACKOFF_MULTIPLIER, + GRPC_RETRY_INITIAL_BACKOFF_SECONDS, + GRPC_RETRY_MAX_ATTEMPTS, + grpc_url_from_rank, + shm_filename, +) from cornserve.sidecar.schema import SidecarConfig from cornserve.sidecar.serde import MsgpackDecoder, MsgpackEncoder, SharedTensorHandle from cornserve.sidecar.utils import device_from_rank, init_shmem @@ -111,7 +118,27 @@ def __init__( concurrent_copy=config.concurrent_copy, ) - response = self.stub.Register(request) + response = None + for attempt in range(GRPC_RETRY_MAX_ATTEMPTS): + try: + response = self.stub.Register(request) + break + except grpc.RpcError as e: + code = e.code() + if ( + code in (grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE) + and attempt < GRPC_RETRY_MAX_ATTEMPTS - 1 + ): + logger.warning( + "Register retry %d for sidecar rank %d due to %s", + attempt + 1, + self.sidecar_rank, + code.name, + ) + continue + raise + if response is 
None: + raise RuntimeError(f"Failed to register sidecar rank {self.sidecar_rank}: no response received") assert response.shm_size > 0, "Failed to register sidecar" self.shard_rank = response.local_rank @@ -254,7 +281,29 @@ def _send_worker( chunk_id=chunk_id, num_chunks=num_chunks, ) - response = self.stub.Send(request) + response = None + for attempt in range(GRPC_RETRY_MAX_ATTEMPTS): + try: + response = self.stub.Send(request) + break + except grpc.RpcError as e: + code = e.code() + if ( + code in (grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE) + and attempt < GRPC_RETRY_MAX_ATTEMPTS - 1 + ): + logger.warning( + "Send retry %d for shard %d chunk %d in req %s due to %s", + attempt + 1, + self.shard_rank, + chunk_id, + id, + code.name, + ) + continue + raise + if response is None: + raise RuntimeError(f"Failed to send req {id}: no response received.") if response.status == common_pb2.Status.STATUS_OK: logger.info("Sent shard %d of chunk %d in req %s successfully", self.shard_rank, chunk_id, id) if isinstance(obj, torch.Tensor): @@ -315,7 +364,27 @@ def _close_stream_worker( num_chunks=num_chunks, ) grpc_stub = self._get_grpc_stub(sidecar_rank) - response = grpc_stub.CloseStream(request) + response = None + for attempt in range(GRPC_RETRY_MAX_ATTEMPTS): + try: + response = grpc_stub.CloseStream(request) + break + except grpc.RpcError as e: + code = e.code() + if ( + code in (grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE) + and attempt < GRPC_RETRY_MAX_ATTEMPTS - 1 + ): + logger.warning( + "CloseStream retry %d for stream %s due to %s", + attempt + 1, + id, + code.name, + ) + continue + raise + if response is None: + raise RuntimeError(f"Failed to close stream {id}: no response received") if response.status == common_pb2.Status.STATUS_OK: logger.info("Closed stream %s successfully with %d chunks", id, num_chunks) else: @@ -358,7 +427,31 @@ async def recv(self, id: str, chunk_id: int = 0) -> Any: span.set_attribute("sidecar.recv.id", id) 
span.set_attribute("sidecar.recv.chunk_id", chunk_id) request = sidecar_pb2.ReceiveRequest(id=id, chunk_id=chunk_id) - response = await self.aio_stub.Receive(request) + response = None + for attempt in range(GRPC_RETRY_MAX_ATTEMPTS): + try: + response = await self.aio_stub.Receive(request) + break + except grpc.RpcError as e: + code = e.code() + if ( + code in (grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE) + and attempt < GRPC_RETRY_MAX_ATTEMPTS - 1 + ): + backoff_delay = GRPC_RETRY_INITIAL_BACKOFF_SECONDS * (GRPC_RETRY_BACKOFF_MULTIPLIER**attempt) + logger.warning( + "Receive retry %d for chunk %d in req %s due to %s, waiting %.3fs", + attempt + 1, + chunk_id, + id, + code.name, + backoff_delay, + ) + await asyncio.sleep(backoff_delay) + continue + raise + if response is None: + raise RuntimeError(f"Failed to receive data with id {id}: no response received") if response.status != common_pb2.Status.STATUS_OK: raise ValueError(f"Failed to receive data with id {id}") @@ -396,7 +489,31 @@ def recv_sync(self, id: str, chunk_id: int = 0) -> Any: span.set_attribute("sidecar.read.id", id) span.set_attribute("sidecar.read.chunk_id", chunk_id) request = sidecar_pb2.ReceiveRequest(id=id, chunk_id=chunk_id) - response = self.stub.Receive(request) + response = None + for attempt in range(GRPC_RETRY_MAX_ATTEMPTS): + try: + response = self.stub.Receive(request) + break + except grpc.RpcError as e: + code = e.code() + if ( + code in (grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE) + and attempt < GRPC_RETRY_MAX_ATTEMPTS - 1 + ): + backoff_delay = GRPC_RETRY_INITIAL_BACKOFF_SECONDS * (GRPC_RETRY_BACKOFF_MULTIPLIER**attempt) + logger.warning( + "Receive (sync) retry %d for chunk %d in req %s due to %s, waiting %.3fs", + attempt + 1, + chunk_id, + id, + code.name, + backoff_delay, + ) + time.sleep(backoff_delay) + continue + raise + if response is None: + raise RuntimeError(f"Failed to receive data with id {id}: no response received") if response.status != 
common_pb2.Status.STATUS_OK: raise ValueError(f"Failed to receive data with id {id}") @@ -430,9 +547,78 @@ async def mark_done(self, id: str, chunk_id: int = 0) -> None: span = trace.get_current_span() span.set_attribute("sidecar.mark_done.id", id) request = sidecar_pb2.MarkDoneRequest(id=id, chunk_id=chunk_id) - response = await self.aio_stub.MarkDone(request) + response = None + for attempt in range(GRPC_RETRY_MAX_ATTEMPTS): + try: + response = await self.aio_stub.MarkDone(request) + break + except grpc.RpcError as e: + code = e.code() + if ( + code in (grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE) + and attempt < GRPC_RETRY_MAX_ATTEMPTS - 1 + ): + backoff_delay = GRPC_RETRY_INITIAL_BACKOFF_SECONDS * (GRPC_RETRY_BACKOFF_MULTIPLIER**attempt) + logger.warning( + "MarkDone retry %d for req %s chunk %d due to %s, waiting %.3fs", + attempt + 1, + id, + chunk_id, + code.name, + backoff_delay, + ) + await asyncio.sleep(backoff_delay) + continue + raise + if response is None: + raise RuntimeError(f"Failed to mark done for id {id}: no response received") if response.status == common_pb2.Status.STATUS_OK: logger.debug("Request %s marked done", id) + else: + logger.error("Failed to mark request %s done", id) + + @tracer.start_as_current_span(name="Sidecar.mark_done_all") + async def mark_done_all(self, id: str) -> int: + """Mark all chunks of a tensor as done in the sidecar server to free all shared memory buffers. + + Returns the number of chunks marked. 
+ """ + if _is_mocking(): + return 0 + + span = trace.get_current_span() + span.set_attribute("sidecar.mark_done_all.id", id) + request = sidecar_pb2.MarkDoneAllRequest(id=id) + response = None + for attempt in range(GRPC_RETRY_MAX_ATTEMPTS): + try: + response = await self.aio_stub.MarkDoneAll(request) + break + except grpc.RpcError as e: + code = e.code() + if ( + code in (grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE) + and attempt < GRPC_RETRY_MAX_ATTEMPTS - 1 + ): + backoff_delay = GRPC_RETRY_INITIAL_BACKOFF_SECONDS * (GRPC_RETRY_BACKOFF_MULTIPLIER**attempt) + logger.warning( + "MarkDoneAll retry %d for req %s due to %s, waiting %.3fs", + attempt + 1, + id, + code.name, + backoff_delay, + ) + await asyncio.sleep(backoff_delay) + continue + raise + if response is None: + raise RuntimeError(f"Failed to mark done all for id {id}: no response received") + if response.status == common_pb2.Status.STATUS_OK: + logger.info("Request %s marked all %d chunks done", id, response.num_chunks_marked) + return response.num_chunks_marked + else: + logger.error("Failed to mark request %s all chunks done", id) + raise RuntimeError(f"Failed to mark done all for id {id}") def _mark_done_worker(self, id: str, chunk_id: int = 0) -> None: """Mark a tensor as done in the sidecar server to free the shared memory buffer. 
@@ -444,7 +630,31 @@ def _mark_done_worker(self, id: str, chunk_id: int = 0) -> None: span = trace.get_current_span() span.set_attribute("sidecar.mark_done.id", id) request = sidecar_pb2.MarkDoneRequest(id=id, chunk_id=chunk_id) - response = self.stub.MarkDone(request) + response = None + for attempt in range(GRPC_RETRY_MAX_ATTEMPTS): + try: + response = self.stub.MarkDone(request) + break + except grpc.RpcError as e: + code = e.code() + if ( + code in (grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE) + and attempt < GRPC_RETRY_MAX_ATTEMPTS - 1 + ): + backoff_delay = GRPC_RETRY_INITIAL_BACKOFF_SECONDS * (GRPC_RETRY_BACKOFF_MULTIPLIER**attempt) + logger.warning( + "MarkDone (worker) retry %d for req %s chunk %d due to %s, waiting %.3fs", + attempt + 1, + id, + chunk_id, + code.name, + backoff_delay, + ) + time.sleep(backoff_delay) + continue + raise + if response is None: + raise RuntimeError(f"Failed to mark done for id {id}: no response received") if response.status == common_pb2.Status.STATUS_OK: logger.debug("Request %s marked done", id) else: diff --git a/python/cornserve/sidecar/constants.py b/python/cornserve/sidecar/constants.py index 1945e6b3..93168030 100644 --- a/python/cornserve/sidecar/constants.py +++ b/python/cornserve/sidecar/constants.py @@ -11,6 +11,11 @@ GRPC_BASE_PORT = 10000 UCX_BASE_PORT = 12000 +# Retry configuration for gRPC calls +GRPC_RETRY_MAX_ATTEMPTS = 3 # 3 total attempts (1 initial + 2 retries) +GRPC_RETRY_INITIAL_BACKOFF_SECONDS = 0.05 # Initial backoff delay in seconds +GRPC_RETRY_BACKOFF_MULTIPLIER = 2.0 # Exponential backoff multiplier + def chunk_tag(id: str, rank: int, chunk_id: int, shard_rank: int) -> int: """Generate a tag for the chunk. 
diff --git a/python/cornserve/task_executors/geri/engine/client.py b/python/cornserve/task_executors/geri/engine/client.py index 4e2e6374..4cb11caf 100644 --- a/python/cornserve/task_executors/geri/engine/client.py +++ b/python/cornserve/task_executors/geri/engine/client.py @@ -257,7 +257,8 @@ async def generate_streaming( else: logger.info("Error detected for request %s", request_id) break - + # Attempt to free all sidecar buffers associated with this request + await self.sidecar.mark_done_all(request.embedding_data_id) # Cleanup self.pending_streams.pop(request_id) diff --git a/tasklib/cornserve_tasklib/task_executors/descriptor/omni.py b/tasklib/cornserve_tasklib/task_executors/descriptor/omni.py index b00a4ad7..68f9711c 100644 --- a/tasklib/cornserve_tasklib/task_executors/descriptor/omni.py +++ b/tasklib/cornserve_tasklib/task_executors/descriptor/omni.py @@ -27,7 +27,6 @@ async def parse_stream_to_audio_chunks( response: aiohttp.ClientResponse, ) -> AsyncGenerator[str]: """Parse the streaming response into audio chunks.""" - assert not response.closed, "Response must not be closed when parsing." try: buffer = b"" # Read in larger chunks to avoid "Chunk too big" error with large base64-encoded audio @@ -57,7 +56,8 @@ async def parse_stream_to_audio_chunks( chunk = OpenAIChatCompletionChunk.model_validate_json(line) yield chunk.model_dump_json() finally: - response.close() + if not response.closed: + response.close() class OmniTalkerVocoderDescriptor(