2 changes: 2 additions & 0 deletions python/rapidsmpf/rapidsmpf/buffer/resource.pyi
@@ -20,6 +20,8 @@ class BufferResource:
def memory_reserved(self, mem_type: MemoryType) -> int: ...
@property
def spill_manager(self) -> SpillManager: ...
@property
def device_mr(self) -> DeviceMemoryResource: ...

class LimitAvailableMemory:
def __init__(
5 changes: 5 additions & 0 deletions python/rapidsmpf/rapidsmpf/buffer/resource.pyx
@@ -114,6 +114,11 @@ cdef class BufferResource:
"""
return self._handle.get()

@property
def device_mr(self):
"""The RMM Memory Resource this BufferResource was initialized with."""
return self._mr

def memory_reserved(self, MemoryType mem_type):
"""
Get the current reserved memory of the specified memory type.
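A minimal usage sketch of the new ``device_mr`` property — not part of this diff, and it assumes ``BufferResource`` accepts the device memory resource as its first constructor argument:

import rmm.mr

from rapidsmpf.buffer.resource import BufferResource
from rapidsmpf.rmm_resource_adaptor import RmmResourceAdaptor

# Wrap the current device resource in RapidsMPF's resource adaptor.
mr = RmmResourceAdaptor(upstream_mr=rmm.mr.get_current_device_resource())

# The new property hands back the exact resource the BufferResource was built with.
br = BufferResource(mr)
assert br.device_mr is mr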
64 changes: 48 additions & 16 deletions python/rapidsmpf/rapidsmpf/integrations/core.py
@@ -28,7 +28,7 @@
from rapidsmpf.statistics import Statistics

if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from collections.abc import Callable, Generator, Sequence

from rapidsmpf.communicator.communicator import Communicator

@@ -671,6 +671,32 @@ def spill_func(
return ctx.spill_collection.spill(amount, stream=DEFAULT_STREAM, device_mr=mr)


def walk_rmm_resource_stack(
mr: rmm.mr.DeviceMemoryResource,
) -> Generator[rmm.mr.DeviceMemoryResource, None, None]:
"""
Walk the RMM Memory Resources in the given RMM resource stack.

Parameters
----------
mr
The RMM Memory Resource to walk.

Yields
------
RMM Memory Resource

Notes
-----
The initial memory resource is yielded first. Each ``upstream_mr``
property is then followed recursively, yielding every upstream memory
resource until no more upstream memory resources are found.
"""
yield mr
if hasattr(mr, "upstream_mr"):
yield from walk_rmm_resource_stack(mr.upstream_mr)


def rmpf_worker_setup(
worker: Any,
option_prefix: str,
@@ -702,23 +728,29 @@ def rmpf_worker_setup(
This function creates a new RMM memory pool, and
sets it as the current device resource.
"""
# Insert RMM resource adaptor on top of the current RMM resource stack.
mr = RmmResourceAdaptor(
upstream_mr=rmm.mr.get_current_device_resource(),
fallback_mr=(
# Use a managed memory resource if OOM protection is enabled.
rmm.mr.ManagedMemoryResource()
if options.get_or_default(
f"{option_prefix}oom_protection", default_value=False
)
else None
),
)
rmm.mr.set_current_device_resource(mr)
# Ensure that an RMM resource adaptor is present in the current RMM resource stack.
mr = rmm.mr.get_current_device_resource()
for child_mr in walk_rmm_resource_stack(mr):
if isinstance(child_mr, RmmResourceAdaptor):
resource_adaptor = child_mr
break
else:
resource_adaptor = mr = RmmResourceAdaptor(
upstream_mr=mr,
fallback_mr=(
# Use a managed memory resource if OOM protection is enabled.
rmm.mr.ManagedMemoryResource()
if options.get_or_default(
f"{option_prefix}oom_protection", default_value=False
)
else None
),
)
rmm.mr.set_current_device_resource(mr)

# Print statistics at worker shutdown.
if options.get_or_default(f"{option_prefix}statistics", default_value=False):
statistics = Statistics(enable=True, mr=mr)
statistics = Statistics(enable=True, mr=resource_adaptor)
else:
statistics = Statistics(enable=False)

@@ -740,7 +772,7 @@
)
memory_available = {
MemoryType.DEVICE: LimitAvailableMemory(
mr, limit=int(total_memory * spill_device)
resource_adaptor, limit=int(total_memory * spill_device)
)
}
br = BufferResource(
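The detection loop above can be exercised on its own; a hedged sketch of ``walk_rmm_resource_stack`` locating an existing ``RmmResourceAdaptor`` (the stacked resources below are illustrative):

import rmm.mr

from rapidsmpf.integrations.core import walk_rmm_resource_stack
from rapidsmpf.rmm_resource_adaptor import RmmResourceAdaptor

# Bury the adaptor underneath a statistics adaptor, as in the new tests.
stack = rmm.mr.StatisticsResourceAdaptor(
    RmmResourceAdaptor(rmm.mr.CudaMemoryResource())
)

# The walk yields the top-most resource first, then follows ``upstream_mr``
# down the stack, so the buried adaptor is found without re-wrapping it.
for candidate in walk_rmm_resource_stack(stack):
    if isinstance(candidate, RmmResourceAdaptor):
        resource_adaptor = candidate
        break
else:
    resource_adaptor = None  # rmpf_worker_setup would insert a new adaptor here

assert resource_adaptor is stack.get_upstream()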
2 changes: 2 additions & 0 deletions python/rapidsmpf/rapidsmpf/statistics.pyi
@@ -16,6 +16,8 @@ class Statistics:
mr: RmmResourceAdaptor | None = None,
) -> None: ...
@property
def mr(self) -> RmmResourceAdaptor | None: ...
@property
def enabled(self) -> bool: ...
def report(self) -> str: ...
def get_stat(self, name: str) -> dict[str, Number]: ...
9 changes: 9 additions & 0 deletions python/rapidsmpf/rapidsmpf/statistics.pyx
@@ -80,6 +80,15 @@ cdef class Statistics:
"""
return deref(self._handle).enabled()

@property
def mr(self):
"""
The RMM Memory Resource this Statistics was initialized with, if enabled.

This is None if statistics are not enabled.
"""
return self._mr

def report(self):
"""
Generates a report of statistics in a formatted string.
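A short sketch of the new ``Statistics.mr`` accessor, based on the constructor shown in ``statistics.pyi`` (the CUDA resource below is illustrative):

import rmm.mr

from rapidsmpf.rmm_resource_adaptor import RmmResourceAdaptor
from rapidsmpf.statistics import Statistics

adaptor = RmmResourceAdaptor(rmm.mr.CudaMemoryResource())

# When enabled with a resource, the adaptor is exposed for later inspection.
enabled = Statistics(enable=True, mr=adaptor)
assert enabled.mr is adaptor

# When disabled, no resource is tracked and the property returns None.
disabled = Statistics(enable=False)
assert disabled.mr is None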
68 changes: 68 additions & 0 deletions python/rapidsmpf/rapidsmpf/tests/integrations/test_core.py
@@ -0,0 +1,68 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import pytest

import rmm.mr

import rapidsmpf.communicator.single
import rapidsmpf.config
import rapidsmpf.integrations.core
import rapidsmpf.rmm_resource_adaptor


class Worker:
pass


@pytest.mark.parametrize("case", ["cuda", "stats-cuda", "stats-pool-cuda"])
def test_rmpf_worker_setup_memory_resource(
device_mr: rmm.mr.CudaMemoryResource, case: str
) -> None:
if case == "cuda":
mr = rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor(device_mr)
elif case == "stats-cuda":
mr = rmm.mr.StatisticsResourceAdaptor(
rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor(device_mr)
)
elif case == "stats-pool-cuda":
mr = rmm.mr.StatisticsResourceAdaptor(
rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor(
rmm.mr.PoolMemoryResource(device_mr)
)
)
else:
raise AssertionError(f"Unknown case: {case}")
rmm.mr.set_current_device_resource(mr)

statistics = "stats" in case

if statistics:
options = rapidsmpf.config.Options({"single_statistics": "true"})
else:
options = rapidsmpf.config.Options()
comm = rapidsmpf.communicator.single.new_communicator(options=options)
# Call the worker-setup function under test.
worker = Worker()
worker_context = rapidsmpf.integrations.core.rmpf_worker_setup(
worker, "single_", comm=comm, options=options
)

# The global is set
assert rmm.mr.get_current_device_resource() is mr

assert worker_context.statistics.enabled is statistics

if statistics:
assert isinstance(
worker_context.statistics.mr,
rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor,
)
assert worker_context.statistics.mr is mr.get_upstream()
else:
assert worker_context.statistics.mr is None

assert worker_context.br.device_mr is mr
# Can't say much about worker_context.br.memory_available, since it just returns an int
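The complementary path, where the current stack holds no ``RmmResourceAdaptor`` yet, is not covered by the test above; a hedged sketch of what ``rmpf_worker_setup`` is then expected to do, mirroring the test's setup:

import rmm.mr

import rapidsmpf.communicator.single
import rapidsmpf.config
import rapidsmpf.integrations.core
from rapidsmpf.rmm_resource_adaptor import RmmResourceAdaptor


class Worker:
    pass


# Start from a stack that contains no RmmResourceAdaptor.
rmm.mr.set_current_device_resource(rmm.mr.CudaMemoryResource())

options = rapidsmpf.config.Options()
comm = rapidsmpf.communicator.single.new_communicator(options=options)
worker_context = rapidsmpf.integrations.core.rmpf_worker_setup(
    Worker(), "single_", comm=comm, options=options
)

# The setup is expected to wrap the CUDA resource in a new adaptor and install
# it as the global device resource, so the buffer resource and the global agree.
current = rmm.mr.get_current_device_resource()
assert isinstance(current, RmmResourceAdaptor)
assert worker_context.br.device_mr is current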
12 changes: 11 additions & 1 deletion python/rapidsmpf/rapidsmpf/tests/test_dask.py
@@ -8,6 +8,8 @@
import dask.dataframe as dd
import pytest

import rmm.mr

import rapidsmpf.integrations.single
from rapidsmpf.communicator import COMMUNICATORS
from rapidsmpf.config import Options
@@ -121,14 +123,22 @@ def test_dask_cudf_integration(
@pytest.mark.parametrize("partition_count", [None, 3])
@pytest.mark.parametrize("sort", [True, False])
@pytest.mark.parametrize("cluster_kind", ["auto", "single"])
@pytest.mark.parametrize("preconfigure_mr", [False, True])
def test_dask_cudf_integration_single(
partition_count: int,
sort: bool, # noqa: FBT001
*,
sort: bool,
cluster_kind: Literal["distributed", "single", "auto"],
preconfigure_mr: bool,
device_mr: rmm.mr.CudaMemoryResource,
) -> None:
# Test single-worker cuDF integration with Dask-cuDF
pytest.importorskip("dask_cudf")

if preconfigure_mr:
mr = rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor(device_mr)
rmm.mr.set_current_device_resource(mr)

df = (
dask.datasets.timeseries(
freq="3600s",
Expand Down