diff --git a/python/rapidsmpf/rapidsmpf/buffer/resource.pyi b/python/rapidsmpf/rapidsmpf/buffer/resource.pyi
index 6d7843800..818545f1e 100644
--- a/python/rapidsmpf/rapidsmpf/buffer/resource.pyi
+++ b/python/rapidsmpf/rapidsmpf/buffer/resource.pyi
@@ -20,6 +20,8 @@ class BufferResource:
     def memory_reserved(self, mem_type: MemoryType) -> int: ...
     @property
     def spill_manager(self) -> SpillManager: ...
+    @property
+    def device_mr(self) -> DeviceMemoryResource: ...
 
 class LimitAvailableMemory:
     def __init__(
diff --git a/python/rapidsmpf/rapidsmpf/buffer/resource.pyx b/python/rapidsmpf/rapidsmpf/buffer/resource.pyx
index 42a126e38..0a4dc5f21 100644
--- a/python/rapidsmpf/rapidsmpf/buffer/resource.pyx
+++ b/python/rapidsmpf/rapidsmpf/buffer/resource.pyx
@@ -114,6 +114,11 @@ cdef class BufferResource:
         """
         return self._handle.get()
 
+    @property
+    def device_mr(self):
+        """The RMM Memory Resource this BufferResource was initialized with."""
+        return self._mr
+
     def memory_reserved(self, MemoryType mem_type):
         """
         Get the current reserved memory of the specified memory type.
diff --git a/python/rapidsmpf/rapidsmpf/integrations/core.py b/python/rapidsmpf/rapidsmpf/integrations/core.py
index 9b9593164..c2fbd0bd6 100644
--- a/python/rapidsmpf/rapidsmpf/integrations/core.py
+++ b/python/rapidsmpf/rapidsmpf/integrations/core.py
@@ -28,7 +28,7 @@
 from rapidsmpf.statistics import Statistics
 
 if TYPE_CHECKING:
-    from collections.abc import Callable, Sequence
+    from collections.abc import Callable, Generator, Sequence
 
     from rapidsmpf.communicator.communicator import Communicator
 
@@ -671,6 +671,32 @@ def spill_func(
     return ctx.spill_collection.spill(amount, stream=DEFAULT_STREAM, device_mr=mr)
 
 
+def walk_rmm_resource_stack(
+    mr: rmm.mr.DeviceMemoryResource,
+) -> Generator[rmm.mr.DeviceMemoryResource, None, None]:
+    """
+    Walk the RMM Memory Resources in the given RMM resource stack.
+
+    Parameters
+    ----------
+    mr
+        The RMM Memory Resource to walk.
+
+    Yields
+    ------
+    RMM Memory Resource
+
+    Notes
+    -----
+    The initial memory resource is yielded first. The ``upstream_mr``
+    property of each memory resource is then followed recursively until
+    no more upstream memory resources are found.
+    """
+    yield mr
+    if hasattr(mr, "upstream_mr"):
+        yield from walk_rmm_resource_stack(mr.upstream_mr)
+
+
 def rmpf_worker_setup(
     worker: Any,
     option_prefix: str,
@@ -702,23 +728,29 @@ def rmpf_worker_setup(
     This function creates a new RMM memory pool, and sets it as the current
     device resource.
     """
-    # Insert RMM resource adaptor on top of the current RMM resource stack.
-    mr = RmmResourceAdaptor(
-        upstream_mr=rmm.mr.get_current_device_resource(),
-        fallback_mr=(
-            # Use a managed memory resource if OOM protection is enabled.
-            rmm.mr.ManagedMemoryResource()
-            if options.get_or_default(
-                f"{option_prefix}oom_protection", default_value=False
-            )
-            else None
-        ),
-    )
-    rmm.mr.set_current_device_resource(mr)
+    # Ensure that an RMM resource adaptor is present in the current RMM resource stack.
+    mr = rmm.mr.get_current_device_resource()
+    for child_mr in walk_rmm_resource_stack(mr):
+        if isinstance(child_mr, RmmResourceAdaptor):
+            resource_adaptor = child_mr
+            break
+    else:
+        resource_adaptor = mr = RmmResourceAdaptor(
+            upstream_mr=mr,
+            fallback_mr=(
+                # Use a managed memory resource if OOM protection is enabled.
+                rmm.mr.ManagedMemoryResource()
+                if options.get_or_default(
+                    f"{option_prefix}oom_protection", default_value=False
+                )
+                else None
+            ),
+        )
+        rmm.mr.set_current_device_resource(mr)
 
     # Print statistics at worker shutdown.
     if options.get_or_default(f"{option_prefix}statistics", default_value=False):
-        statistics = Statistics(enable=True, mr=mr)
+        statistics = Statistics(enable=True, mr=resource_adaptor)
     else:
         statistics = Statistics(enable=False)
 
@@ -740,7 +772,7 @@
     )
     memory_available = {
         MemoryType.DEVICE: LimitAvailableMemory(
-            mr, limit=int(total_memory * spill_device)
+            resource_adaptor, limit=int(total_memory * spill_device)
         )
     }
     br = BufferResource(
diff --git a/python/rapidsmpf/rapidsmpf/statistics.pyi b/python/rapidsmpf/rapidsmpf/statistics.pyi
index 6af83fa7e..bf369f45a 100644
--- a/python/rapidsmpf/rapidsmpf/statistics.pyi
+++ b/python/rapidsmpf/rapidsmpf/statistics.pyi
@@ -16,6 +16,8 @@ class Statistics:
         mr: RmmResourceAdaptor | None = None,
     ) -> None: ...
     @property
+    def mr(self) -> RmmResourceAdaptor | None: ...
+    @property
     def enabled(self) -> bool: ...
     def report(self) -> str: ...
     def get_stat(self, name: str) -> dict[str, Number]: ...
diff --git a/python/rapidsmpf/rapidsmpf/statistics.pyx b/python/rapidsmpf/rapidsmpf/statistics.pyx
index 66eaac7dd..dc1762879 100644
--- a/python/rapidsmpf/rapidsmpf/statistics.pyx
+++ b/python/rapidsmpf/rapidsmpf/statistics.pyx
@@ -80,6 +80,15 @@ cdef class Statistics:
         """
         return deref(self._handle).enabled()
 
+    @property
+    def mr(self):
+        """
+        The RMM Memory Resource this Statistics was initialized with, if enabled.
+
+        This is None if statistics are not enabled.
+        """
+        return self._mr
+
     def report(self):
         """
         Generates a report of statistics in a formatted string.
diff --git a/python/rapidsmpf/rapidsmpf/tests/integrations/test_core.py b/python/rapidsmpf/rapidsmpf/tests/integrations/test_core.py
new file mode 100644
index 000000000..8c4122689
--- /dev/null
+++ b/python/rapidsmpf/rapidsmpf/tests/integrations/test_core.py
@@ -0,0 +1,68 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import pytest
+
+import rmm.mr
+
+import rapidsmpf.communicator.single
+import rapidsmpf.config
+import rapidsmpf.integrations.core
+import rapidsmpf.rmm_resource_adaptor
+
+
+class Worker:
+    pass
+
+
+@pytest.mark.parametrize("case", ["cuda", "stats-cuda", "stats-pool-cuda"])
+def test_rmpf_worker_setup_memory_resource(
+    device_mr: rmm.mr.CudaMemoryResource, case: str
+) -> None:
+    if case == "cuda":
+        mr = rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor(device_mr)
+    elif case == "stats-cuda":
+        mr = rmm.mr.StatisticsResourceAdaptor(
+            rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor(device_mr)
+        )
+    elif case == "stats-pool-cuda":
+        mr = rmm.mr.StatisticsResourceAdaptor(
+            rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor(
+                rmm.mr.PoolMemoryResource(device_mr)
+            )
+        )
+    else:
+        raise AssertionError(f"Unknown case: {case}")
+    rmm.mr.set_current_device_resource(mr)
+
+    statistics = "stats" in case
+
+    if statistics:
+        options = rapidsmpf.config.Options({"single_statistics": "true"})
+    else:
+        options = rapidsmpf.config.Options()
+    comm = rapidsmpf.communicator.single.new_communicator(options=options)
+    # Call the worker setup under test.
+    worker = Worker()
+    worker_context = rapidsmpf.integrations.core.rmpf_worker_setup(
+        worker, "single_", comm=comm, options=options
+    )
+
+    # The current device resource should be unchanged.
+    assert rmm.mr.get_current_device_resource() is mr
+
+    assert worker_context.statistics.enabled is statistics
+
+    if statistics:
+        assert isinstance(
+            worker_context.statistics.mr,
+            rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor,
+        )
+        assert worker_context.statistics.mr is mr.get_upstream()
+    else:
+        assert worker_context.statistics.mr is None
+
+    assert worker_context.br.device_mr is mr
+    # Can't say much about worker_context.br.memory_available, since it just returns an int.
diff --git a/python/rapidsmpf/rapidsmpf/tests/test_dask.py b/python/rapidsmpf/rapidsmpf/tests/test_dask.py
index f8045743f..bc3c8a637 100644
--- a/python/rapidsmpf/rapidsmpf/tests/test_dask.py
+++ b/python/rapidsmpf/rapidsmpf/tests/test_dask.py
@@ -8,6 +8,8 @@
 import dask.dataframe as dd
 import pytest
 
+import rmm.mr
+
 import rapidsmpf.integrations.single
 from rapidsmpf.communicator import COMMUNICATORS
 from rapidsmpf.config import Options
@@ -121,14 +123,22 @@ def test_dask_cudf_integration(
 @pytest.mark.parametrize("partition_count", [None, 3])
 @pytest.mark.parametrize("sort", [True, False])
 @pytest.mark.parametrize("cluster_kind", ["auto", "single"])
+@pytest.mark.parametrize("preconfigure_mr", [False, True])
 def test_dask_cudf_integration_single(
     partition_count: int,
-    sort: bool,  # noqa: FBT001
+    *,
+    sort: bool,
     cluster_kind: Literal["distributed", "single", "auto"],
+    preconfigure_mr: bool,
+    device_mr: rmm.mr.CudaMemoryResource,
 ) -> None:
     # Test single-worker cuDF integration with Dask-cuDF
     pytest.importorskip("dask_cudf")
 
+    if preconfigure_mr:
+        mr = rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor(device_mr)
+        rmm.mr.set_current_device_resource(mr)
+
     df = (
         dask.datasets.timeseries(
             freq="3600s",