Add NCCL support to DistributedArray #130

Merged
merged 8 commits into from
May 29, 2025
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -43,4 +43,4 @@ jobs:
- name: Install pylops-mpi
run: pip install .
- name: Testing using pytest-mpi
run: mpiexec -n ${{ matrix.rank }} pytest --with-mpi
run: mpiexec -n ${{ matrix.rank }} pytest tests/ --with-mpi
4 changes: 4 additions & 0 deletions Makefile
@@ -36,6 +36,10 @@ lint:
tests:
mpiexec -n $(NUM_PROCESSES) pytest tests/ --with-mpi

# assuming NUM_PROCESSES <= number of GPUs available
tests_nccl:
mpiexec -n $(NUM_PROCESSES) pytest tests_nccl/ --with-mpi

doc:
cd docs && rm -rf source/api/generated && rm -rf source/gallery &&\
rm -rf source/tutorials && rm -rf build &&\
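The tests_nccl target above runs the GPU test suite with one MPI process per GPU. As an illustration of what such a test could look like (a hypothetical sketch, not taken from the tests_nccl/ files, which are not shown in this excerpt), using the helpers added in pylops_mpi/utils/_nccl.py:

# hypothetical sketch, not part of this diff
import pytest
import cupy as cp
from mpi4py import MPI

from pylops_mpi.utils._nccl import initialize_nccl_comm, nccl_allreduce


@pytest.mark.mpi(min_size=2)
def test_nccl_allreduce_sum():
    # one GPU per MPI rank, as assumed by initialize_nccl_comm
    nccl_comm = initialize_nccl_comm()
    local = cp.ones(4, dtype=cp.float32)
    # every rank contributes ones(4), so each entry of the result equals the number of ranks
    total = nccl_allreduce(nccl_comm, local, op=MPI.SUM)
    assert cp.allclose(total, MPI.COMM_WORLD.Get_size())

Such a test would be launched, for example, with make tests_nccl NUM_PROCESSES=2, so that each rank is pinned to its own GPU.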
207 changes: 155 additions & 52 deletions pylops_mpi/DistributedArray.py

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion pylops_mpi/utils/__init__.py
@@ -1,2 +1,5 @@
# isort: skip_file
from .dottest import *

# currently dottest creates a circular dependency with DistributedArray.py
# from .dottest import *
from .deps import *
288 changes: 288 additions & 0 deletions pylops_mpi/utils/_nccl.py
@@ -0,0 +1,288 @@
__all__ = [
"initialize_nccl_comm",
"nccl_split",
"nccl_allgather",
"nccl_allreduce",
"nccl_bcast",
"nccl_asarray"
]

from enum import IntEnum
from mpi4py import MPI
import os
import numpy as np
import cupy as cp
import cupy.cuda.nccl as nccl

cupy_to_nccl_dtype = {
"float32": nccl.NCCL_FLOAT32,
"float64": nccl.NCCL_FLOAT64,
"int32": nccl.NCCL_INT32,
"int64": nccl.NCCL_INT64,
"uint8": nccl.NCCL_UINT8,
"int8": nccl.NCCL_INT8,
"uint32": nccl.NCCL_UINT32,
"uint64": nccl.NCCL_UINT64,
}


class NcclOp(IntEnum):
SUM = nccl.NCCL_SUM
PROD = nccl.NCCL_PROD
MAX = nccl.NCCL_MAX
MIN = nccl.NCCL_MIN


def mpi_op_to_nccl(mpi_op) -> NcclOp:
""" Map MPI reduction operation to NCCL equivalent

Parameters
----------
mpi_op : :obj:`MPI.Op`
An MPI reduction operation (e.g., MPI.SUM, MPI.PROD, MPI.MAX, MPI.MIN).

Returns
-------
NcclOp : :obj:`IntEnum`
A corresponding NCCL reduction operation.
"""
if mpi_op is MPI.SUM:
return NcclOp.SUM
elif mpi_op is MPI.PROD:
return NcclOp.PROD
elif mpi_op is MPI.MAX:
return NcclOp.MAX
elif mpi_op is MPI.MIN:
return NcclOp.MIN
else:
raise ValueError(f"Unsupported MPI.Op for NCCL: {mpi_op}")


def initialize_nccl_comm() -> nccl.NcclCommunicator:
""" Initialize NCCL world communicator for every GPU device

Each GPU must be managed by exactly one MPI process,
i.e., the number of MPI processes launched must equal the
number of GPUs taking part in the communication.

Returns
-------
nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
A corresponding NCCL communicator
"""
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
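# one GPU per MPI rank: use Open MPI's local rank when available,
# otherwise fall back to round-robin over the visible CUDA devices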
device_id = int(
os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")
or rank % cp.cuda.runtime.getDeviceCount()
)
cp.cuda.Device(device_id).use()

if rank == 0:
with cp.cuda.Device(device_id):
nccl_id_bytes = nccl.get_unique_id()
else:
nccl_id_bytes = None
nccl_id_bytes = comm.bcast(nccl_id_bytes, root=0)

nccl_comm = nccl.NcclCommunicator(size, nccl_id_bytes, rank)
return nccl_comm


def nccl_split(mask) -> nccl.NcclCommunicator:
""" NCCL-equivalent of MPI.Split()

Splitting the communicator into multiple NCCL subcommunicators

Parameters
----------
mask : :obj:`list`
Mask defining subsets of ranks to consider when performing 'global'
operations on the distributed array such as dot product or norm.

Returns
-------
sub_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
Subcommunicator according to mask
"""
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
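# ranks sharing the same mask value are grouped into the same subcommunicator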
sub_comm = comm.Split(color=mask[rank], key=rank)

sub_rank = sub_comm.Get_rank()
sub_size = sub_comm.Get_size()

if sub_rank == 0:
nccl_id_bytes = nccl.get_unique_id()
else:
nccl_id_bytes = None
nccl_id_bytes = sub_comm.bcast(nccl_id_bytes, root=0)
sub_comm = nccl.NcclCommunicator(sub_size, nccl_id_bytes, sub_rank)

return sub_comm


def nccl_allgather(nccl_comm, send_buf, recv_buf=None) -> cp.ndarray:
""" NCCL equivalent of MPI_Allgather. Gathers data from all GPUs
and distributes the concatenated result to all participants.

Parameters
----------
nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
The NCCL communicator over which data will be gathered.
send_buf : :obj:`cupy.ndarray` or array-like
The data buffer from the local GPU to be sent.
recv_buf : :obj:`cupy.ndarray`, optional
The buffer to receive data from all GPUs. If None, a new
buffer will be allocated with the appropriate shape.

Returns
-------
recv_buf : :obj:`cupy.ndarray`
A buffer containing the gathered data from all GPUs.
"""
send_buf = (
send_buf if isinstance(send_buf, cp.ndarray) else cp.asarray(send_buf)
)
if recv_buf is None:
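# allocate assuming the communicator spans MPI.COMM_WORLD (one GPU per rank)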
recv_buf = cp.zeros(
MPI.COMM_WORLD.Get_size() * send_buf.size,
dtype=send_buf.dtype,
)
nccl_comm.allGather(
send_buf.data.ptr,
recv_buf.data.ptr,
send_buf.size,
cupy_to_nccl_dtype[str(send_buf.dtype)],
cp.cuda.Stream.null.ptr,
)
return recv_buf


def nccl_allreduce(nccl_comm, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM) -> cp.ndarray:
""" NCCL equivalent of MPI_Allreduce. Applies a reduction operation
(e.g., sum, max) across all GPUs and distributes the result.

Parameters
----------
nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
The NCCL communicator used for collective communication.
send_buf : :obj:`cupy.ndarray` or array-like
The data buffer from the local GPU to be reduced.
recv_buf : :obj:`cupy.ndarray`, optional
The buffer to store the result of the reduction. If None,
a new buffer will be allocated with the appropriate shape.
op : :obj:`mpi4py.MPI.Op`, optional
The reduction operation to apply. Defaults to MPI.SUM.

Returns
-------
recv_buf : :obj:`cupy.ndarray`
A buffer containing the result of the reduction, broadcasted
to all GPUs.
"""
send_buf = (
send_buf if isinstance(send_buf, cp.ndarray) else cp.asarray(send_buf)
)
if recv_buf is None:
recv_buf = cp.zeros(send_buf.size, dtype=send_buf.dtype)

nccl_comm.allReduce(
send_buf.data.ptr,
recv_buf.data.ptr,
send_buf.size,
cupy_to_nccl_dtype[str(send_buf.dtype)],
mpi_op_to_nccl(op),
cp.cuda.Stream.null.ptr,
)
return recv_buf


def nccl_bcast(nccl_comm, local_array, index, value) -> None:
""" NCCL equivalent of MPI_Bcast. Broadcasts a single value at the given index
from the root GPU (rank 0) to all other GPUs.

Parameters
----------
nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
The NCCL communicator used for collective communication.
local_array : :obj:`cupy.ndarray`
The local array on each GPU. The value at `index` will be broadcasted.
index : :obj:`int`
The index in the array to be broadcasted.
value : :obj:`scalar`
The value to broadcast (only used by the root GPU, rank 0).

Returns
-------
None
"""
if nccl_comm.rank_id() == 0:
local_array[index] = value
nccl_comm.bcast(
local_array[index].data.ptr,
local_array[index].size,
cupy_to_nccl_dtype[str(local_array[index].dtype)],
0,
cp.cuda.Stream.null.ptr,
)


def nccl_asarray(nccl_comm, local_array, local_shapes, axis) -> cp.ndarray:
"""Global view of the array

Gather all local GPU arrays into a single global array via NCCL all-gather.

Parameters
----------
nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
The NCCL communicator used for collective communication.
local_array : :obj:`cupy.ndarray`
The local array on the current GPU.
local_shapes : :obj:`list`
A list of shapes for each GPU local array (used to trim padding).
axis : :obj:`int`
The axis along which to concatenate the gathered arrays.

Returns
-------
final_array : :obj:`cupy.ndarray`
Global array gathered from all GPUs and concatenated along `axis`.

Notes
-----
NCCL's allGather requires the send buffer to have the same size on every device.
Therefore, padding is required when the array is not evenly partitioned across
all ranks. Each local array is padded so that every dimension of the send buffer
matches the maximum size of that dimension across all ranks.
"""
sizes_each_dim = list(zip(*local_shapes))

send_shape = tuple(map(max, sizes_each_dim))
pad_size = [
(0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, local_array.shape)
]

send_buf = cp.pad(
local_array, pad_size, mode="constant", constant_values=0
)

# NCCL recommends one MPI process per GPU, so the size of the receiving buffer can be inferred
ndev = len(local_shapes)
recv_buf = cp.zeros(ndev * send_buf.size, dtype=send_buf.dtype)
nccl_allgather(nccl_comm, send_buf, recv_buf)

# extract an individual array from each device
chunk_size = np.prod(send_shape)
chunks = [
recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev)
]

# Remove padding from each array: padded values may appear in the middle of the
# flattened chunk, so a reshape and per-dimension slicing are required
for i in range(ndev):
slicing = tuple(slice(0, end) for end in local_shapes[i])
chunks[i] = chunks[i].reshape(send_shape)[slicing]
# combine back to single global array
return cp.concatenate(chunks, axis=axis)
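Taken together, these helpers mirror the MPI collectives that DistributedArray relies on. A minimal usage sketch follows (illustrative only, not part of this diff; it assumes one MPI process per GPU and that metadata such as local shapes is still exchanged over MPI):

# illustrative sketch, not part of this diff
import cupy as cp
from mpi4py import MPI

from pylops_mpi.utils._nccl import (
    initialize_nccl_comm,
    nccl_allreduce,
    nccl_asarray,
)

nccl_comm = initialize_nccl_comm()  # binds this MPI rank to one GPU
rank = MPI.COMM_WORLD.Get_rank()

# each rank owns a differently sized local block
local = cp.full((rank + 1, 4), rank, dtype=cp.float32)

# reduction across GPUs: total number of elements in the distributed array
total_size = nccl_allreduce(nccl_comm, cp.asarray([local.size], dtype=cp.int64))

# gather the (padded) local blocks and rebuild the global array along axis 0;
# the shapes themselves are still exchanged through MPI
local_shapes = MPI.COMM_WORLD.allgather(local.shape)
global_arr = nccl_asarray(nccl_comm, local, local_shapes, axis=0)

How DistributedArray itself switches between these NCCL calls and the existing MPI ones is part of the large diff above, which is not rendered here.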
45 changes: 45 additions & 0 deletions pylops_mpi/utils/deps.py
@@ -0,0 +1,45 @@
__all__ = [
"nccl_enabled"
]

import os
from importlib import import_module, util
from typing import Optional


# return an error message if NCCL cannot be imported or is disabled, None otherwise
def nccl_import(message: Optional[str] = None) -> Optional[str]:
nccl_test = (
# detect whether nccl is available and the user expects it to be used;
# CuPy must be checked first, otherwise util.find_spec assumes it is present,
# immediately checks for nccl, and crashes
util.find_spec("cupy") is not None and util.find_spec("cupy.cuda.nccl") is not None and int(os.getenv("NCCL_PYLOPS_MPI", 1)) == 1
)
if nccl_test:
# try importing it
try:
import_module("cupy.cuda.nccl") # noqa: F401

# if successful, set the message to None
nccl_message = None
# if unable to import but the package is installed
except (ImportError, ModuleNotFoundError) as e:
nccl_message = (
f"Fail to import cupy.cuda.nccl, Falling back to pure MPI (error: {e})."
"Please ensure your CUDA NCCL environment is set up correctly "
"for more detials visit 'https://docs.cupy.dev/en/stable/install.html'"
)
print(UserWarning(nccl_message))
else:
nccl_message = (
"cupy.cuda.nccl package not installed or os.getenv('NCCL_PYLOPS_MPI') == 0. "
f"In order to be able to use {message} "
"ensure 'os.getenv('NCCL_PYLOPS_MPI') == 1'"
"for more details for installing NCCL visit 'https://docs.cupy.dev/en/stable/install.html'"
)

return nccl_message


nccl_enabled: bool = (
nccl_import() is None and int(os.getenv("NCCL_PYLOPS_MPI", 1)) == 1
)
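A sketch of how the nccl_enabled flag might be consumed elsewhere in the package (the variable name base_comm_nccl below is illustrative; the actual wiring lives in the DistributedArray diff that is not rendered above):

# illustrative sketch, not part of this diff
from pylops_mpi.utils import deps

if deps.nccl_enabled:
    # CuPy and NCCL are importable and NCCL_PYLOPS_MPI is not set to 0
    from pylops_mpi.utils._nccl import initialize_nccl_comm
    base_comm_nccl = initialize_nccl_comm()  # hypothetical name for the GPU communicator
else:
    base_comm_nccl = None  # fall back to plain MPI collectives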
6 changes: 3 additions & 3 deletions setup.cfg
@@ -1,9 +1,9 @@
[tool:pytest]
addopts = --verbose
python_files = tests/*.py
python_files = tests/*.py tests_nccl/*.py

[flake8]
ignore = E203, E501, W503, E402
per-file-ignores =
__init__.py: F401, F403, F405
max-line-length = 88
__init__.py: F401, F403, F405
max-line-length = 88
4 changes: 2 additions & 2 deletions tests/test_distributedarray.py
@@ -201,7 +201,7 @@ def test_distributed_maskeddot(par1, par2):
"""Test Distributed Dot product with masked array"""
# number of subcommunicators
if MPI.COMM_WORLD.Get_size() % 2 == 0:
nsub = 2
nsub = 2
elif MPI.COMM_WORLD.Get_size() % 3 == 0:
nsub = 3
else:
@@ -236,7 +236,7 @@ def test_distributed_maskednorm(par):
"""Test Distributed numpy.linalg.norm method with masked array"""
# number of subcommunicators
if MPI.COMM_WORLD.Get_size() % 2 == 0:
nsub = 2
nsub = 2
elif MPI.COMM_WORLD.Get_size() % 3 == 0:
nsub = 3
else: