From 98f9e76dfc1900087320e5f58e50bb6515fee569 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Tue, 23 Sep 2025 18:38:31 -0700 Subject: [PATCH 1/6] Reuse existing MR when valid. --- .../rapidsmpf/rapidsmpf/buffer/resource.pyi | 2 + .../rapidsmpf/rapidsmpf/buffer/resource.pyx | 5 ++ .../rapidsmpf/rapidsmpf/integrations/core.py | 40 +++++++++----- python/rapidsmpf/rapidsmpf/statistics.pyi | 2 + python/rapidsmpf/rapidsmpf/statistics.pyx | 7 +++ .../rapidsmpf/tests/integrations/test_core.py | 55 +++++++++++++++++++ 6 files changed, 97 insertions(+), 14 deletions(-) create mode 100644 python/rapidsmpf/rapidsmpf/tests/integrations/test_core.py diff --git a/python/rapidsmpf/rapidsmpf/buffer/resource.pyi b/python/rapidsmpf/rapidsmpf/buffer/resource.pyi index 6d7843800..818545f1e 100644 --- a/python/rapidsmpf/rapidsmpf/buffer/resource.pyi +++ b/python/rapidsmpf/rapidsmpf/buffer/resource.pyi @@ -20,6 +20,8 @@ class BufferResource: def memory_reserved(self, mem_type: MemoryType) -> int: ... @property def spill_manager(self) -> SpillManager: ... + @property + def device_mr(self) -> DeviceMemoryResource: ... class LimitAvailableMemory: def __init__( diff --git a/python/rapidsmpf/rapidsmpf/buffer/resource.pyx b/python/rapidsmpf/rapidsmpf/buffer/resource.pyx index 42a126e38..0a4dc5f21 100644 --- a/python/rapidsmpf/rapidsmpf/buffer/resource.pyx +++ b/python/rapidsmpf/rapidsmpf/buffer/resource.pyx @@ -114,6 +114,11 @@ cdef class BufferResource: """ return self._handle.get() + @property + def device_mr(self): + """The RMM Memory Resource this BufferResource was initialized with.""" + return self._mr + def memory_reserved(self, MemoryType mem_type): """ Get the current reserved memory of the specified memory type. 
diff --git a/python/rapidsmpf/rapidsmpf/integrations/core.py b/python/rapidsmpf/rapidsmpf/integrations/core.py index 5285edbc8..f94213f32 100644 --- a/python/rapidsmpf/rapidsmpf/integrations/core.py +++ b/python/rapidsmpf/rapidsmpf/integrations/core.py @@ -440,23 +440,35 @@ def rmpf_worker_setup( This function creates a new RMM memory pool, and sets it as the current device resource. """ - # Insert RMM resource adaptor on top of the current RMM resource stack. - mr = RmmResourceAdaptor( - upstream_mr=rmm.mr.get_current_device_resource(), - fallback_mr=( - # Use a managed memory resource if OOM protection is enabled. - rmm.mr.ManagedMemoryResource() - if options.get_or_default( - f"{option_prefix}oom_protection", default_value=False - ) - else None - ), - ) + # Ensure that an RMM resource adaptor (or a StatisticsResourceAdaptor wrapping one) + # is on top of the current RMM resource stack. + upstream_mr = rmm.mr.get_current_device_resource() + + if isinstance(upstream_mr, RmmResourceAdaptor): + resource_adaptor = mr = upstream_mr + elif isinstance(upstream_mr, rmm.mr.StatisticsResourceAdaptor) and isinstance( + upstream_mr.upstream_mr, RmmResourceAdaptor + ): + mr = upstream_mr + resource_adaptor = mr.upstream_mr + else: + resource_adaptor = mr = RmmResourceAdaptor( + upstream_mr=upstream_mr, + fallback_mr=( + # Use a managed memory resource if OOM protection is enabled. + rmm.mr.ManagedMemoryResource() + if options.get_or_default( + f"{option_prefix}oom_protection", default_value=False + ) + else None + ), + ) + rmm.mr.set_current_device_resource(mr) # Print statistics at worker shutdown. 
if options.get_or_default(f"{option_prefix}statistics", default_value=False): - statistics = Statistics(enable=True, mr=mr) + statistics = Statistics(enable=True, mr=resource_adaptor) else: statistics = Statistics(enable=False) @@ -478,7 +490,7 @@ def rmpf_worker_setup( ) memory_available = { MemoryType.DEVICE: LimitAvailableMemory( - mr, limit=int(total_memory * spill_device) + resource_adaptor, limit=int(total_memory * spill_device) ) } br = BufferResource( diff --git a/python/rapidsmpf/rapidsmpf/statistics.pyi b/python/rapidsmpf/rapidsmpf/statistics.pyi index 6af83fa7e..bf369f45a 100644 --- a/python/rapidsmpf/rapidsmpf/statistics.pyi +++ b/python/rapidsmpf/rapidsmpf/statistics.pyi @@ -16,6 +16,8 @@ class Statistics: mr: RmmResourceAdaptor | None = None, ) -> None: ... @property + def mr(self) -> RmmResourceAdaptor | None: ... + @property def enabled(self) -> bool: ... def report(self) -> str: ... def get_stat(self, name: str) -> dict[str, Number]: ... diff --git a/python/rapidsmpf/rapidsmpf/statistics.pyx b/python/rapidsmpf/rapidsmpf/statistics.pyx index 66eaac7dd..02d4fefc9 100644 --- a/python/rapidsmpf/rapidsmpf/statistics.pyx +++ b/python/rapidsmpf/rapidsmpf/statistics.pyx @@ -80,6 +80,13 @@ cdef class Statistics: """ return deref(self._handle).enabled() + @property + def mr(self): + """ + The RMM Memory Resource this Statistics was initialized with. + """ + return self._mr + def report(self): """ Generates a report of statistics in a formatted string. diff --git a/python/rapidsmpf/rapidsmpf/tests/integrations/test_core.py b/python/rapidsmpf/rapidsmpf/tests/integrations/test_core.py new file mode 100644 index 000000000..3d76eb7de --- /dev/null +++ b/python/rapidsmpf/rapidsmpf/tests/integrations/test_core.py @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +import rmm.mr + +import rapidsmpf.communicator.single +import rapidsmpf.config +import rapidsmpf.integrations.core +import rapidsmpf.rmm_resource_adaptor + + +class Worker: + pass + + +@pytest.mark.parametrize("statistics", [False, True]) +def test_rmpf_worker_setup_memory_resource(*, statistics: bool) -> None: + # setup + upstream_mr = rmm.mr.CudaMemoryResource() + mr = rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor(upstream_mr) + if statistics: + mr = rmm.mr.StatisticsResourceAdaptor(mr) + rmm.mr.set_current_device_resource(mr) + + if statistics: + options = rapidsmpf.config.Options({"single_statistics": "true"}) + else: + options = rapidsmpf.config.Options() + comm = rapidsmpf.communicator.single.new_communicator(options=options) + # call + worker = Worker() + worker_context = rapidsmpf.integrations.core.rmpf_worker_setup( + worker, "single_", comm=comm, options=options + ) + + # The global is set + assert rmm.mr.get_current_device_resource() is mr + + assert worker_context.statistics.enabled is statistics + + if statistics: + assert isinstance( + worker_context.statistics.mr, + rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor, + ) + assert worker_context.statistics.mr is mr.get_upstream() + else: + assert worker_context.statistics.mr is None + + assert worker_context.br.device_mr is mr + # Can't say much about worker_context.br.memory_available, since it just returns an int From 77d5f6045373bc5a86ab5dcd7902a1bd2a9e373f Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Wed, 24 Sep 2025 07:15:10 -0700 Subject: [PATCH 2/6] Walk the RMM resource stack to find adaptors --- .../rapidsmpf/rapidsmpf/integrations/core.py | 51 +++++++++++++------ .../rapidsmpf/tests/integrations/test_core.py | 27 +++++++--- python/rapidsmpf/rapidsmpf/tests/test_dask.py | 13 ++++- 3 files changed, 69 insertions(+), 22 deletions(-) diff --git 
a/python/rapidsmpf/rapidsmpf/integrations/core.py b/python/rapidsmpf/rapidsmpf/integrations/core.py index f94213f32..e888a09f4 100644 --- a/python/rapidsmpf/rapidsmpf/integrations/core.py +++ b/python/rapidsmpf/rapidsmpf/integrations/core.py @@ -28,7 +28,7 @@ from rapidsmpf.statistics import Statistics if TYPE_CHECKING: - from collections.abc import Callable, Sequence + from collections.abc import Callable, Generator, Sequence from rapidsmpf.communicator.communicator import Communicator @@ -409,6 +409,32 @@ def spill_func( return ctx.spill_collection.spill(amount, stream=DEFAULT_STREAM, device_mr=mr) +def walk_rmm_resource_stack( + mr: rmm.mr.DeviceMemoryResource, +) -> Generator[rmm.mr.DeviceMemoryResource, None, None]: + """ + Walk the RMM Memory Resources in the given RMM resource stack. + + Parameters + ---------- + mr + The RMM Memory Resource to walk. + + Yields + ------ + RMM Memory Resource + + Notes + ----- + The initial memory resource is yielded first. The ``upstream_mr`` + property of upstream memory resources are recursively yielded until + no more upstream memory resources are found. + """ + yield mr + if hasattr(mr, "upstream_mr"): + yield from walk_rmm_resource_stack(mr.upstream_mr) + + def rmpf_worker_setup( worker: Any, option_prefix: str, @@ -440,20 +466,16 @@ def rmpf_worker_setup( This function creates a new RMM memory pool, and sets it as the current device resource. """ - # Ensure that an RMM resource adaptor (or a StatisticsResourceAdaptor wrapping one) - # is on top of the current RMM resource stack. - upstream_mr = rmm.mr.get_current_device_resource() - - if isinstance(upstream_mr, RmmResourceAdaptor): - resource_adaptor = mr = upstream_mr - elif isinstance(upstream_mr, rmm.mr.StatisticsResourceAdaptor) and isinstance( - upstream_mr.upstream_mr, RmmResourceAdaptor - ): - mr = upstream_mr - resource_adaptor = mr.upstream_mr + # Ensure that an RMM resource adaptor is present in the current RMM resource stack. 
+ mr = rmm.mr.get_current_device_resource() + + for child_mr in walk_rmm_resource_stack(mr): + if isinstance(child_mr, RmmResourceAdaptor): + resource_adaptor = child_mr + break else: resource_adaptor = mr = RmmResourceAdaptor( - upstream_mr=upstream_mr, + upstream_mr=mr, fallback_mr=( # Use a managed memory resource if OOM protection is enabled. rmm.mr.ManagedMemoryResource() @@ -463,8 +485,7 @@ def rmpf_worker_setup( else None ), ) - - rmm.mr.set_current_device_resource(mr) + rmm.mr.set_current_device_resource(mr) # Print statistics at worker shutdown. if options.get_or_default(f"{option_prefix}statistics", default_value=False): diff --git a/python/rapidsmpf/rapidsmpf/tests/integrations/test_core.py b/python/rapidsmpf/rapidsmpf/tests/integrations/test_core.py index 3d76eb7de..386676158 100644 --- a/python/rapidsmpf/rapidsmpf/tests/integrations/test_core.py +++ b/python/rapidsmpf/rapidsmpf/tests/integrations/test_core.py @@ -17,15 +17,30 @@ class Worker: pass -@pytest.mark.parametrize("statistics", [False, True]) -def test_rmpf_worker_setup_memory_resource(*, statistics: bool) -> None: +# @pytest.mark.parametrize("statistics", [False, True]) +@pytest.mark.parametrize("case", ["cuda", "stats-cuda", "stats-pool-cuda"]) +def test_rmpf_worker_setup_memory_resource(case: str) -> None: # setup - upstream_mr = rmm.mr.CudaMemoryResource() - mr = rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor(upstream_mr) - if statistics: - mr = rmm.mr.StatisticsResourceAdaptor(mr) + if case == "cuda": + mr = rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor( + rmm.mr.CudaMemoryResource() + ) + elif case == "stats-cuda": + mr = rmm.mr.StatisticsResourceAdaptor( + rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor( + rmm.mr.CudaMemoryResource() + ) + ) + elif case == "stats-pool-cuda": + mr = rmm.mr.StatisticsResourceAdaptor( + rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor( + rmm.mr.PoolMemoryResource(rmm.mr.CudaMemoryResource()) + ) + ) rmm.mr.set_current_device_resource(mr) + 
statistics = "stats" in case + if statistics: options = rapidsmpf.config.Options({"single_statistics": "true"}) else: diff --git a/python/rapidsmpf/rapidsmpf/tests/test_dask.py b/python/rapidsmpf/rapidsmpf/tests/test_dask.py index 877fbfc20..217a358b2 100644 --- a/python/rapidsmpf/rapidsmpf/tests/test_dask.py +++ b/python/rapidsmpf/rapidsmpf/tests/test_dask.py @@ -8,6 +8,8 @@ import dask.dataframe as dd import pytest +import rmm.mr + import rapidsmpf.integrations.single from rapidsmpf.communicator import COMMUNICATORS from rapidsmpf.config import Options @@ -117,14 +119,23 @@ def test_dask_cudf_integration( @pytest.mark.parametrize("partition_count", [None, 3]) @pytest.mark.parametrize("sort", [True, False]) @pytest.mark.parametrize("cluster_kind", ["auto", "single"]) +@pytest.mark.parametrize("preconfigure_mr", [False, True]) def test_dask_cudf_integration_single( partition_count: int, - sort: bool, # noqa: FBT001 + *, + sort: bool, cluster_kind: Literal["distributed", "single", "auto"], + preconfigure_mr: bool, ) -> None: # Test single-worker cuDF integration with Dask-cuDF pytest.importorskip("dask_cudf") + if preconfigure_mr: + mr = rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor( + rmm.mr.CudaMemoryResource() + ) + rmm.mr.set_current_device_resource(mr) + df = ( dask.datasets.timeseries( freq="3600s", From ac51bf8ddab996cea99fd895d0edf173672c8db4 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Wed, 24 Sep 2025 08:02:33 -0700 Subject: [PATCH 3/6] Test fixes --- .../rapidsmpf/tests/integrations/test_core.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/python/rapidsmpf/rapidsmpf/tests/integrations/test_core.py b/python/rapidsmpf/rapidsmpf/tests/integrations/test_core.py index 386676158..8c4122689 100644 --- a/python/rapidsmpf/rapidsmpf/tests/integrations/test_core.py +++ b/python/rapidsmpf/rapidsmpf/tests/integrations/test_core.py @@ -17,26 +17,24 @@ class Worker: pass -# @pytest.mark.parametrize("statistics", 
[False, True]) @pytest.mark.parametrize("case", ["cuda", "stats-cuda", "stats-pool-cuda"]) -def test_rmpf_worker_setup_memory_resource(case: str) -> None: - # setup +def test_rmpf_worker_setup_memory_resource( + device_mr: rmm.mr.CudaMemoryResource, case: str +) -> None: if case == "cuda": - mr = rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor( - rmm.mr.CudaMemoryResource() - ) + mr = rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor(device_mr) elif case == "stats-cuda": mr = rmm.mr.StatisticsResourceAdaptor( - rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor( - rmm.mr.CudaMemoryResource() - ) + rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor(device_mr) ) elif case == "stats-pool-cuda": mr = rmm.mr.StatisticsResourceAdaptor( rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor( - rmm.mr.PoolMemoryResource(rmm.mr.CudaMemoryResource()) + rmm.mr.PoolMemoryResource(device_mr) ) ) + else: + raise AssertionError(f"Unknown case: {case}") rmm.mr.set_current_device_resource(mr) statistics = "stats" in case From 8a09c7f6939ad0613c6b412463dec258e0864359 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Wed, 24 Sep 2025 08:04:24 -0700 Subject: [PATCH 4/6] format --- python/rapidsmpf/rapidsmpf/integrations/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/rapidsmpf/rapidsmpf/integrations/core.py b/python/rapidsmpf/rapidsmpf/integrations/core.py index e888a09f4..bee985a81 100644 --- a/python/rapidsmpf/rapidsmpf/integrations/core.py +++ b/python/rapidsmpf/rapidsmpf/integrations/core.py @@ -468,7 +468,6 @@ def rmpf_worker_setup( """ # Ensure that an RMM resource adaptor is present in the current RMM resource stack. 
mr = rmm.mr.get_current_device_resource() - for child_mr in walk_rmm_resource_stack(mr): if isinstance(child_mr, RmmResourceAdaptor): resource_adaptor = child_mr From 442709d00d035c144d47e715d3e45e795db9649d Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Wed, 24 Sep 2025 08:05:30 -0700 Subject: [PATCH 5/6] doc fix --- python/rapidsmpf/rapidsmpf/statistics.pyx | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/rapidsmpf/rapidsmpf/statistics.pyx b/python/rapidsmpf/rapidsmpf/statistics.pyx index 02d4fefc9..dc1762879 100644 --- a/python/rapidsmpf/rapidsmpf/statistics.pyx +++ b/python/rapidsmpf/rapidsmpf/statistics.pyx @@ -83,7 +83,9 @@ cdef class Statistics: @property def mr(self): """ - The RMM Memory Resource this Statistics was initialized with. + The RMM Memory Resource this Statistics was initialized with, if enabled. + + This is None if statistics are not enabled. """ return self._mr From 4386505a93a236c400b5878ca1ef0645e92c7453 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Wed, 24 Sep 2025 10:09:40 -0700 Subject: [PATCH 6/6] Use device_mr --- python/rapidsmpf/rapidsmpf/tests/test_dask.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/rapidsmpf/rapidsmpf/tests/test_dask.py b/python/rapidsmpf/rapidsmpf/tests/test_dask.py index 217a358b2..bbac48317 100644 --- a/python/rapidsmpf/rapidsmpf/tests/test_dask.py +++ b/python/rapidsmpf/rapidsmpf/tests/test_dask.py @@ -126,14 +126,13 @@ def test_dask_cudf_integration_single( sort: bool, cluster_kind: Literal["distributed", "single", "auto"], preconfigure_mr: bool, + device_mr: rmm.mr.CudaMemoryResource, ) -> None: # Test single-worker cuDF integration with Dask-cuDF pytest.importorskip("dask_cudf") if preconfigure_mr: - mr = rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor( - rmm.mr.CudaMemoryResource() - ) + mr = rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor(device_mr) rmm.mr.set_current_device_resource(mr) df = (