From 99016f665324244dd00e6e4aa04aa71ce7f24c6d Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Fri, 10 Oct 2025 07:06:03 +0000 Subject: [PATCH 01/11] enable rdma by default --- torchstore/transport/buffers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchstore/transport/buffers.py b/torchstore/transport/buffers.py index f29579a..4c61bd4 100644 --- a/torchstore/transport/buffers.py +++ b/torchstore/transport/buffers.py @@ -32,7 +32,7 @@ def RDMABuffer(*args: Any, **kwargs: Any) -> Any: def rdma_available() -> bool: rdma_enabled = ( - os.environ.get("TORCHSTORE_RDMA_ENABLED", "0") == "1" + os.environ.get("TORCHSTORE_RDMA_ENABLED", "1") == "1" ) # TODO: enable on this build return rdma_enabled and monarch_rdma_available() From 37c02ba9c4414f84f46971938e810cc960556beb Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Sun, 12 Oct 2025 23:43:32 -0700 Subject: [PATCH 02/11] get --- torchstore/storage_volume.py | 24 +++++++++++------------- torchstore/transport/buffers.py | 26 ++++++++++++++++++++++---- torchstore/transport/pipe.py | 26 +++++--------------------- 3 files changed, 38 insertions(+), 38 deletions(-) diff --git a/torchstore/storage_volume.py b/torchstore/storage_volume.py index 355fea5..cabf1fe 100644 --- a/torchstore/storage_volume.py +++ b/torchstore/storage_volume.py @@ -11,8 +11,10 @@ import torch from monarch.actor import Actor, endpoint -from torchstore.transport.buffers import TransportBuffer - +from torchstore.transport.buffers import ( + create_default_transport_buffer, + TransportBuffer, +) from torchstore.transport.pipe import Request, TensorSlice from torchstore.utils import assemble_global_tensor, spawn_actors @@ -59,10 +61,8 @@ async def put( await self.store.put(key, transport_buffer, request) @endpoint - async def get( - self, key: str, transport_buffer: TransportBuffer, request: Request - ) -> TransportBuffer: - return await self.store.get(key, transport_buffer, request) + async def get(self, key: str, request: Request) -> TransportBuffer: + return await self.store.get(key, request) @endpoint async def get_meta( @@ -86,9 +86,7 @@ async def put( """Store data in the storage backend.""" raise NotImplementedError() - async def get( - self, key: str, transport_buffer: TransportBuffer, request: Request - ) -> TransportBuffer: + async def get(self, key: str, request: Request) -> TransportBuffer: """Retrieve data from the storage backend.""" raise NotImplementedError() @@ -201,13 +199,13 @@ async def put( self.kv[key] = tensor - async def get( - self, key: str, transport_buffer: TransportBuffer, request: Request - ) -> TransportBuffer: + async def get(self, key: str, request: Request) -> TransportBuffer: if key not in self.kv: raise KeyError(f"Key '{key}' not found. {list(self.kv.keys())=}") + transport_buffer = create_default_transport_buffer() + # TODO: clean up val = self.kv[key] if isinstance(val, dict) and "obj" in val: @@ -216,7 +214,7 @@ async def get( return transport_buffer if request.tensor_slice is None: - await transport_buffer.write_from(self.kv[key]) + transport_buffer.from_contiguous_tensor(self.kv[key]) return transport_buffer # TODO: diff --git a/torchstore/transport/buffers.py b/torchstore/transport/buffers.py index 4c61bd4..a682887 100644 --- a/torchstore/transport/buffers.py +++ b/torchstore/transport/buffers.py @@ -4,6 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import functools + import logging import os from typing import Any, Dict, List, Optional, Tuple, Union @@ -30,6 +34,7 @@ def RDMABuffer(*args: Any, **kwargs: Any) -> Any: # assert RDMA_CHUNK_SIZE_MB <= 1024, "Monarch does not support 1gb chunks via rdma" +@functools.cache def rdma_available() -> bool: rdma_enabled = ( os.environ.get("TORCHSTORE_RDMA_ENABLED", "1") == "1" @@ -37,6 +42,13 @@ def rdma_available() -> bool: return rdma_enabled and monarch_rdma_available() +def create_default_transport_buffer() -> TransportBuffer: + if rdma_available(): + return RDMATransportBuffer() + else: + return MonarchTransportBuffer() + + class TransportBuffer: finalize: bool = False is_object: bool = False @@ -49,10 +61,7 @@ def update(self, other_buffer: "TransportBuffer") -> None: self.objects = other_buffer.objects self.requires_meta = other_buffer.requires_meta - def allocate(self, tensor_like: Union[torch.Tensor, Tuple]) -> None: - """Allocates internal buffers based on either an existing tensor - or a Tuple of (shape, dtype) - """ + def from_contiguous_tensor(self, tensor: torch.Tensor) -> None: raise NotImplementedError() async def read_into(self, tensor: Optional[torch.Tensor]) -> torch.Tensor: @@ -190,6 +199,12 @@ async def write_from(self, tensor: Optional[torch.Tensor]) -> None: for idx, chunk in enumerate(chunked_byte_view): await self.rdma_buffers[idx].write_from(chunk) + def from_contiguous_tensor(self, tensor: torch.Tensor) -> None: + assert tensor.is_contiguous(), "Tensor must be contiguous" + byte_view_chunks = self._create_byte_views_from_tensor(tensor) + self.tensor_refs = [torch.empty_like(chunk) for chunk in byte_view_chunks] + self.rdma_buffers = [RDMABuffer(chunk) for chunk in self.tensor_refs] + class MonarchTransportBuffer(TransportBuffer): """This interface is mostly a noop, intended to be used with Monarch's regular RPC. @@ -224,3 +239,6 @@ async def write_from(self, tensor: Optional[torch.Tensor]) -> None: def update(self, other_buffer: "TransportBuffer") -> None: super().update(other_buffer) self.tensor = other_buffer.tensor + + def from_contiguous_tensor(self, tensor: torch.Tensor) -> None: + self.tensor = tensor diff --git a/torchstore/transport/pipe.py b/torchstore/transport/pipe.py index f0d94fd..6f7a3bd 100644 --- a/torchstore/transport/pipe.py +++ b/torchstore/transport/pipe.py @@ -145,9 +145,9 @@ def create_transport_buffer(self) -> TransportBuffer: async def put_to_storage_volume(self, key, request: Request): transport_buffer = self.create_transport_buffer() tensor = request.tensor_val - - transport_buffer.allocate(tensor) - await transport_buffer.write_from(tensor) + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + transport_buffer.from_contiguous_tensor(tensor) # transporting tensors is handled by the buffer, so we don't want to send it # via monarch RPC since that would generate considerable overhead @@ -156,24 +156,8 @@ async def put_to_storage_volume(self, key, request: Request): ) async def get_from_storage_volume(self, key, request: Request): - - transport_buffer = self.create_transport_buffer() - - # Certain buffers (RDMA) need to know the size of the tensor - # so we can allocate the right amount of memory locally. - # This can be avoided if the request contains a tensor slice. - # Could likely be optimized away in the future. - if transport_buffer.requires_meta and request.tensor_val is None: - meta = await self.storage_volume.get_meta.call_one(key, request.meta_only()) - transport_buffer.allocate(meta) - else: - transport_buffer.allocate(request.tensor_val) - - # TODO: consider placing the buffer inside the request or vice versa - transport_buffer.update( - await self.storage_volume.get.call_one( - key, transport_buffer, request.meta_only() - ) + transport_buffer = await self.storage_volume.get.call_one( + key, request.meta_only() ) if transport_buffer.is_object: From 86fbaeac83d3d76709c0c0238b0f3078243194d1 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Sun, 12 Oct 2025 23:54:02 -0700 Subject: [PATCH 03/11] fix metadata --- torchstore/transport/buffers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchstore/transport/buffers.py b/torchstore/transport/buffers.py index a682887..d0afe19 100644 --- a/torchstore/transport/buffers.py +++ b/torchstore/transport/buffers.py @@ -201,6 +201,9 @@ async def write_from(self, tensor: Optional[torch.Tensor]) -> None: def from_contiguous_tensor(self, tensor: torch.Tensor) -> None: assert tensor.is_contiguous(), "Tensor must be contiguous" + self.shape = tensor.shape + self.dtype = tensor.dtype + self.dim = tensor.dim() byte_view_chunks = self._create_byte_views_from_tensor(tensor) self.tensor_refs = [torch.empty_like(chunk) for chunk in byte_view_chunks] self.rdma_buffers = [RDMABuffer(chunk) for chunk in self.tensor_refs] From e73de6db222ccc3362e80deff38204dc6be27d0c Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Mon, 13 Oct 2025 00:05:04 -0700 Subject: [PATCH 04/11] fix --- torchstore/storage_volume.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchstore/storage_volume.py b/torchstore/storage_volume.py index cabf1fe..b3c4f8a 100644 --- a/torchstore/storage_volume.py +++ b/torchstore/storage_volume.py @@ -225,7 +225,7 @@ async def get(self, key: str, request: Request) -> TransportBuffer: for shard in self.kv[key].values(): if shard["slice"] == request.tensor_slice: - await transport_buffer.write_from(shard["tensor"]) + transport_buffer.from_contiguous_tensor(shard["tensor"]) return transport_buffer raise RuntimeError(f"Tensor slice {request.tensor_slice} not found in {key}") From 6c3412a57502688b9af8a0527d1cbf48819467a9 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Mon, 13 Oct 2025 14:36:49 -0700 Subject: [PATCH 05/11] fix put for obj --- torchstore/storage_volume.py | 3 +++ torchstore/transport/pipe.py | 7 ++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/torchstore/storage_volume.py b/torchstore/storage_volume.py index b3c4f8a..7222329 100644 --- a/torchstore/storage_volume.py +++ b/torchstore/storage_volume.py @@ -196,6 +196,9 @@ async def put( if request.tensor_slice is not None: self._handle_dtensor(key, request.tensor_slice, tensor) return + print( + f"putting {key} {tensor.shape=}, {tensor.dtype=}, is tensor zero? {torch.allclose(tensor, torch.zeros_like(tensor))}" + ) self.kv[key] = tensor diff --git a/torchstore/transport/pipe.py b/torchstore/transport/pipe.py index 6f7a3bd..14a017c 100644 --- a/torchstore/transport/pipe.py +++ b/torchstore/transport/pipe.py @@ -145,9 +145,10 @@ def create_transport_buffer(self) -> TransportBuffer: async def put_to_storage_volume(self, key, request: Request): transport_buffer = self.create_transport_buffer() tensor = request.tensor_val - if not tensor.is_contiguous(): - tensor = tensor.contiguous() - transport_buffer.from_contiguous_tensor(tensor) + if tensor is not None: + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + transport_buffer.from_contiguous_tensor(tensor) # transporting tensors is handled by the buffer, so we don't want to send it # via monarch RPC since that would generate considerable overhead From 255e89ca38dabc7b63c42c838f3085ca6fe0cb39 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Mon, 13 Oct 2025 14:38:47 -0700 Subject: [PATCH 06/11] fix bug in from_contiguous_tensor --- torchstore/transport/buffers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchstore/transport/buffers.py b/torchstore/transport/buffers.py index d0afe19..20d2b46 100644 --- a/torchstore/transport/buffers.py +++ b/torchstore/transport/buffers.py @@ -204,8 +204,7 @@ def from_contiguous_tensor(self, tensor: torch.Tensor) -> None: self.shape = tensor.shape self.dtype = tensor.dtype self.dim = tensor.dim() - byte_view_chunks = self._create_byte_views_from_tensor(tensor) - self.tensor_refs = [torch.empty_like(chunk) for chunk in byte_view_chunks] + self.tensor_refs = self._create_byte_views_from_tensor(tensor) self.rdma_buffers = [RDMABuffer(chunk) for chunk in self.tensor_refs] From 1733b0a2306c69b4a1b5d80f4c90563b175563d9 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Mon, 13 Oct 2025 14:46:43 -0700 Subject: [PATCH 07/11] fix --- torchstore/transport/buffers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchstore/transport/buffers.py b/torchstore/transport/buffers.py index 20d2b46..02b2458 100644 --- a/torchstore/transport/buffers.py +++ b/torchstore/transport/buffers.py @@ -200,12 +200,13 @@ async def write_from(self, tensor: Optional[torch.Tensor]) -> None: await self.rdma_buffers[idx].write_from(chunk) def from_contiguous_tensor(self, tensor: torch.Tensor) -> None: + """It is the caller's responsibility to ensure that the tensor lives long enough until the buffer is used.""" assert tensor.is_contiguous(), "Tensor must be contiguous" self.shape = tensor.shape self.dtype = tensor.dtype self.dim = tensor.dim() - self.tensor_refs = self._create_byte_views_from_tensor(tensor) - self.rdma_buffers = [RDMABuffer(chunk) for chunk in self.tensor_refs] + tensor_refs = self._create_byte_views_from_tensor(tensor) + self.rdma_buffers = [RDMABuffer(chunk) for chunk in tensor_refs] class MonarchTransportBuffer(TransportBuffer): From 7ff69b723ef98c5271fef0fc0df7b8326ca51153 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Mon, 13 Oct 2025 14:55:57 -0700 Subject: [PATCH 08/11] clean up --- torchstore/storage_volume.py | 3 --- torchstore/transport/buffers.py | 7 ++++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/torchstore/storage_volume.py b/torchstore/storage_volume.py index 7222329..b3c4f8a 100644 --- a/torchstore/storage_volume.py +++ b/torchstore/storage_volume.py @@ -196,9 +196,6 @@ async def put( if request.tensor_slice is not None: self._handle_dtensor(key, request.tensor_slice, tensor) return - print( - f"putting {key} {tensor.shape=}, {tensor.dtype=}, is tensor zero? {torch.allclose(tensor, torch.zeros_like(tensor))}" - ) self.kv[key] = tensor diff --git a/torchstore/transport/buffers.py b/torchstore/transport/buffers.py index 02b2458..25407af 100644 --- a/torchstore/transport/buffers.py +++ b/torchstore/transport/buffers.py @@ -200,13 +200,14 @@ async def write_from(self, tensor: Optional[torch.Tensor]) -> None: await self.rdma_buffers[idx].write_from(chunk) def from_contiguous_tensor(self, tensor: torch.Tensor) -> None: - """It is the caller's responsibility to ensure that the tensor lives long enough until the buffer is used.""" + """The caller must ensure that the tensor lives long enough until the buffer is used.""" assert tensor.is_contiguous(), "Tensor must be contiguous" self.shape = tensor.shape self.dtype = tensor.dtype self.dim = tensor.dim() - tensor_refs = self._create_byte_views_from_tensor(tensor) - self.rdma_buffers = [RDMABuffer(chunk) for chunk in tensor_refs] + self.rdma_buffers = [ + RDMABuffer(chunk) for chunk in self._create_byte_views_from_tensor(tensor) + ] class MonarchTransportBuffer(TransportBuffer): From df6335ef2c421aaa97008e1e9de1bd22dcdfa34c Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Mon, 13 Oct 2025 16:36:49 -0700 Subject: [PATCH 09/11] move to cpu first --- torchstore/transport/pipe.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchstore/transport/pipe.py b/torchstore/transport/pipe.py index 14a017c..4d1897a 100644 --- a/torchstore/transport/pipe.py +++ b/torchstore/transport/pipe.py @@ -146,6 +146,8 @@ async def put_to_storage_volume(self, key, request: Request): transport_buffer = self.create_transport_buffer() tensor = request.tensor_val if tensor is not None: + # TODO: investigate why RDMA fails on CUDA tensors + tensor = tensor.cpu() if not tensor.is_contiguous(): tensor = tensor.contiguous() transport_buffer.from_contiguous_tensor(tensor) From ba294def0cddbac02d0d77bd0e94dc0f24f26574 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Sat, 18 Oct 2025 14:12:31 -0700 Subject: [PATCH 10/11] drop --- torchstore/transport/buffers.py | 9 +++++++++ torchstore/transport/pipe.py | 6 +++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/torchstore/transport/buffers.py b/torchstore/transport/buffers.py index 3cd1a04..3a897b3 100644 --- a/torchstore/transport/buffers.py +++ b/torchstore/transport/buffers.py @@ -68,6 +68,9 @@ async def read_into(self, tensor: Optional[torch.Tensor]) -> torch.Tensor: async def write_from(self, tensor: Optional[torch.Tensor]) -> None: raise NotImplementedError() + async def drop(self) -> None: + pass + class RDMATransportBuffer(TransportBuffer): # TODO: when we try this with rdma, I should be able to write rdma directly to the tensor @@ -182,6 +185,12 @@ async def read_into(self, tensor: Optional[torch.Tensor] = None) -> torch.Tensor return tensor + async def drop(self) -> None: + if self.rdma_buffers is not None: + for buffer in self.rdma_buffers: + await buffer.drop() + self.tensor_refs = None + # recv async def write_from(self, tensor: Optional[torch.Tensor]) -> None: if tensor is None: diff --git a/torchstore/transport/pipe.py b/torchstore/transport/pipe.py index 4d1897a..ab0ac70 100644 --- a/torchstore/transport/pipe.py +++ b/torchstore/transport/pipe.py @@ -158,6 +158,8 @@ async def put_to_storage_volume(self, key, request: Request): key, transport_buffer, request.meta_only() ) + await transport_buffer.drop() + async def get_from_storage_volume(self, key, request: Request): transport_buffer = await self.storage_volume.get.call_one( key, request.meta_only() @@ -166,4 +168,6 @@ async def get_from_storage_volume(self, key, request: Request): if transport_buffer.is_object: return transport_buffer.objects - return await transport_buffer.read_into(request.tensor_val) + ret = await transport_buffer.read_into(request.tensor_val) + transport_buffer.drop() + return ret From f43bc8fa2b2fb2413d2c211ba27b456ad2542c3c Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Sat, 18 Oct 2025 14:35:34 -0700 Subject: [PATCH 11/11] await --- torchstore/transport/pipe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchstore/transport/pipe.py b/torchstore/transport/pipe.py index ab0ac70..9c77133 100644 --- a/torchstore/transport/pipe.py +++ b/torchstore/transport/pipe.py @@ -169,5 +169,5 @@ async def get_from_storage_volume(self, key, request: Request): return transport_buffer.objects ret = await transport_buffer.read_into(request.tensor_val) - transport_buffer.drop() + await transport_buffer.drop() return ret