Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/rapidsmpf/rapidsmpf/buffer/resource.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ class BufferResource:
def memory_reserved(self, mem_type: MemoryType) -> int: ...
@property
def spill_manager(self) -> SpillManager: ...
@property
def device_mr(self) -> DeviceMemoryResource: ...

class LimitAvailableMemory:
def __init__(
Expand Down
5 changes: 5 additions & 0 deletions python/rapidsmpf/rapidsmpf/buffer/resource.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ cdef class BufferResource:
"""
return self._handle.get()

@property
def device_mr(self):
    """The RMM Memory Resource this BufferResource was initialized with."""
    # Returns the Python-level reference stored at construction time;
    # NOTE(review): presumably the same resource wrapped by the C++ handle
    # — confirm against __init__ in the full file.
    return self._mr

def memory_reserved(self, MemoryType mem_type):
"""
Get the current reserved memory of the specified memory type.
Expand Down
65 changes: 49 additions & 16 deletions python/rapidsmpf/rapidsmpf/integrations/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from rapidsmpf.statistics import Statistics

if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from collections.abc import Callable, Generator, Sequence

from rapidsmpf.communicator.communicator import Communicator

Expand Down Expand Up @@ -409,6 +409,32 @@ def spill_func(
return ctx.spill_collection.spill(amount, stream=DEFAULT_STREAM, device_mr=mr)


def walk_rmm_resource_stack(
    mr: rmm.mr.DeviceMemoryResource,
) -> Generator[rmm.mr.DeviceMemoryResource, None, None]:
    """
    Walk the RMM Memory Resources in the given RMM resource stack.

    Parameters
    ----------
    mr
        The RMM Memory Resource to walk.

    Yields
    ------
    RMM Memory Resource

    Notes
    -----
    The initial memory resource is yielded first. The ``upstream_mr``
    property of each resource is then followed repeatedly until no more
    upstream memory resources are found.
    """
    # Iterate rather than recurse so arbitrarily deep adaptor stacks cannot
    # hit Python's recursion limit. Using getattr with a default also stops
    # cleanly (without yielding None) if an adaptor reports no upstream.
    current = mr
    while current is not None:
        yield current
        current = getattr(current, "upstream_mr", None)


def rmpf_worker_setup(
worker: Any,
option_prefix: str,
Expand Down Expand Up @@ -440,23 +466,30 @@ def rmpf_worker_setup(
This function creates a new RMM memory pool, and
sets it as the current device resource.
"""
# Insert RMM resource adaptor on top of the current RMM resource stack.
mr = RmmResourceAdaptor(
upstream_mr=rmm.mr.get_current_device_resource(),
fallback_mr=(
# Use a managed memory resource if OOM protection is enabled.
rmm.mr.ManagedMemoryResource()
if options.get_or_default(
f"{option_prefix}oom_protection", default_value=False
)
else None
),
)
rmm.mr.set_current_device_resource(mr)
# Ensure that an RMM resource adaptor is present in the current RMM resource stack.
mr = rmm.mr.get_current_device_resource()

for child_mr in walk_rmm_resource_stack(mr):
if isinstance(child_mr, RmmResourceAdaptor):
resource_adaptor = child_mr
break
else:
resource_adaptor = mr = RmmResourceAdaptor(
upstream_mr=mr,
fallback_mr=(
# Use a managed memory resource if OOM protection is enabled.
rmm.mr.ManagedMemoryResource()
if options.get_or_default(
f"{option_prefix}oom_protection", default_value=False
)
else None
),
)
rmm.mr.set_current_device_resource(mr)

# Print statistics at worker shutdown.
if options.get_or_default(f"{option_prefix}statistics", default_value=False):
statistics = Statistics(enable=True, mr=mr)
statistics = Statistics(enable=True, mr=resource_adaptor)
else:
statistics = Statistics(enable=False)

Expand All @@ -478,7 +511,7 @@ def rmpf_worker_setup(
)
memory_available = {
MemoryType.DEVICE: LimitAvailableMemory(
mr, limit=int(total_memory * spill_device)
resource_adaptor, limit=int(total_memory * spill_device)
)
}
br = BufferResource(
Expand Down
2 changes: 2 additions & 0 deletions python/rapidsmpf/rapidsmpf/statistics.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ class Statistics:
mr: RmmResourceAdaptor | None = None,
) -> None: ...
@property
def mr(self) -> RmmResourceAdaptor | None: ...
@property
def enabled(self) -> bool: ...
def report(self) -> str: ...
def get_stat(self, name: str) -> dict[str, Number]: ...
Expand Down
7 changes: 7 additions & 0 deletions python/rapidsmpf/rapidsmpf/statistics.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,13 @@ cdef class Statistics:
"""
return deref(self._handle).enabled()

@property
def mr(self):
    """
    The RMM Memory Resource this Statistics was initialized with.
    """
    # May be None: the corresponding stub (statistics.pyi) declares the
    # constructor's ``mr`` parameter as ``RmmResourceAdaptor | None = None``.
    return self._mr

def report(self):
"""
Generates a report of statistics in a formatted string.
Expand Down
70 changes: 70 additions & 0 deletions python/rapidsmpf/rapidsmpf/tests/integrations/test_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import pytest

import rmm.mr

import rapidsmpf.communicator.single
import rapidsmpf.config
import rapidsmpf.integrations.core
import rapidsmpf.rmm_resource_adaptor


class Worker:
    """Minimal stand-in for the worker object passed to ``rmpf_worker_setup``."""


@pytest.mark.parametrize("case", ["cuda", "stats-cuda", "stats-pool-cuda"])
def test_rmpf_worker_setup_memory_resource(case: str) -> None:
    """
    Verify that ``rmpf_worker_setup`` reuses a preconfigured RmmResourceAdaptor.

    Each case installs a current device resource whose stack already contains
    an ``RmmResourceAdaptor``; worker setup must find and reuse it rather than
    wrapping the stack in a new adaptor.
    """
    # setup: build the resource stack for this case.
    if case == "cuda":
        mr = rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor(
            rmm.mr.CudaMemoryResource()
        )
    elif case == "stats-cuda":
        mr = rmm.mr.StatisticsResourceAdaptor(
            rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor(
                rmm.mr.CudaMemoryResource()
            )
        )
    elif case == "stats-pool-cuda":
        mr = rmm.mr.StatisticsResourceAdaptor(
            rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor(
                rmm.mr.PoolMemoryResource(rmm.mr.CudaMemoryResource())
            )
        )
    else:  # Guard against typos in the parametrize list leaving `mr` unbound.
        raise AssertionError(f"unknown case: {case!r}")
    rmm.mr.set_current_device_resource(mr)

    statistics = "stats" in case

    if statistics:
        options = rapidsmpf.config.Options({"single_statistics": "true"})
    else:
        options = rapidsmpf.config.Options()
    comm = rapidsmpf.communicator.single.new_communicator(options=options)

    # call
    worker = Worker()
    worker_context = rapidsmpf.integrations.core.rmpf_worker_setup(
        worker, "single_", comm=comm, options=options
    )

    # The current device resource is unchanged: no new adaptor was inserted.
    assert rmm.mr.get_current_device_resource() is mr

    assert worker_context.statistics.enabled is statistics

    if statistics:
        # The existing adaptor (upstream of the stats wrapper) was reused.
        assert isinstance(
            worker_context.statistics.mr,
            rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor,
        )
        assert worker_context.statistics.mr is mr.get_upstream()
    else:
        assert worker_context.statistics.mr is None

    assert worker_context.br.device_mr is mr
    # Can't say much about worker_context.br.memory_available, since it just
    # returns an int.
13 changes: 12 additions & 1 deletion python/rapidsmpf/rapidsmpf/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import dask.dataframe as dd
import pytest

import rmm.mr

import rapidsmpf.integrations.single
from rapidsmpf.communicator import COMMUNICATORS
from rapidsmpf.config import Options
Expand Down Expand Up @@ -117,14 +119,23 @@ def test_dask_cudf_integration(
@pytest.mark.parametrize("partition_count", [None, 3])
@pytest.mark.parametrize("sort", [True, False])
@pytest.mark.parametrize("cluster_kind", ["auto", "single"])
@pytest.mark.parametrize("preconfigure_mr", [False, True])
def test_dask_cudf_integration_single(
partition_count: int,
sort: bool, # noqa: FBT001
*,
sort: bool,
cluster_kind: Literal["distributed", "single", "auto"],
preconfigure_mr: bool,
) -> None:
# Test single-worker cuDF integration with Dask-cuDF
pytest.importorskip("dask_cudf")

if preconfigure_mr:
mr = rapidsmpf.rmm_resource_adaptor.RmmResourceAdaptor(
rmm.mr.CudaMemoryResource()
)
rmm.mr.set_current_device_resource(mr)

df = (
dask.datasets.timeseries(
freq="3600s",
Expand Down
Loading