Commit f9f5d13

Merge pull request #130 from tharittk/distarray-nccl
Add NCCL support to DistributedArray
2 parents cdd9809 + 9693d9b commit f9f5d13

File tree

9 files changed: +910 −59 lines

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
@@ -43,4 +43,4 @@ jobs:
       - name: Install pylops-mpi
         run: pip install .
       - name: Testing using pytest-mpi
-        run: mpiexec -n ${{ matrix.rank }} pytest --with-mpi
+        run: mpiexec -n ${{ matrix.rank }} pytest tests/ --with-mpi
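For anyone reproducing this CI step locally, the equivalent command is along the lines of the following (the process count of 2 is illustrative only; CI substitutes ${{ matrix.rank }}):

    mpiexec -n 2 pytest tests/ --with-mpi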

Makefile

Lines changed: 4 additions & 0 deletions
@@ -36,6 +36,10 @@ lint:
 tests:
 	mpiexec -n $(NUM_PROCESSES) pytest tests/ --with-mpi
 
+# assuming NUM_PROCESSES <= number of GPUs available
+tests_nccl:
+	mpiexec -n $(NUM_PROCESSES) pytest tests_nccl/ --with-mpi
+
 doc:
 	cd docs && rm -rf source/api/generated && rm -rf source/gallery &&\
 	rm -rf source/tutorials && rm -rf build &&\
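The new target is invoked like the existing MPI test target, for example (the process count of 2 is illustrative and must not exceed the number of available GPUs):

    make tests_nccl NUM_PROCESSES=2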

pylops_mpi/DistributedArray.py

Lines changed: 155 additions & 52 deletions
Large diffs are not rendered by default.
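Since the DistributedArray.py diff is not rendered here, the following is only a rough, hypothetical sketch of how the NCCL path is expected to be exercised end to end; in particular, the base_comm_nccl keyword is an assumption inferred from the PR title and the helpers below, not a line taken from the rendered diff:

    # hypothetical sketch -- the base_comm_nccl keyword is assumed, not shown in the rendered diff
    import pylops_mpi
    from pylops_mpi.utils._nccl import initialize_nccl_comm

    nccl_comm = initialize_nccl_comm()          # one MPI process per GPU

    # distribute a vector across GPUs, communicating via NCCL instead of MPI
    x = pylops_mpi.DistributedArray(global_shape=128,
                                    base_comm_nccl=nccl_comm,  # assumed keyword
                                    engine="cupy")
    x[:] = 1.0
    print(x.dot(x))   # global dot product, reduced across GPUs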

pylops_mpi/utils/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -1,2 +1,5 @@
 # isort: skip_file
-from .dottest import *
+
+# currently dottest creates a circular dependency with DistributedArray.py
+# from .dottest import *
+from .deps import *

pylops_mpi/utils/_nccl.py

Lines changed: 288 additions & 0 deletions
@@ -0,0 +1,288 @@
+__all__ = [
+    "initialize_nccl_comm",
+    "nccl_split",
+    "nccl_allgather",
+    "nccl_allreduce",
+    "nccl_bcast",
+    "nccl_asarray"
+]
+
+from enum import IntEnum
+from mpi4py import MPI
+import os
+import numpy as np
+import cupy as cp
+import cupy.cuda.nccl as nccl
+
+cupy_to_nccl_dtype = {
+    "float32": nccl.NCCL_FLOAT32,
+    "float64": nccl.NCCL_FLOAT64,
+    "int32": nccl.NCCL_INT32,
+    "int64": nccl.NCCL_INT64,
+    "uint8": nccl.NCCL_UINT8,
+    "int8": nccl.NCCL_INT8,
+    "uint32": nccl.NCCL_UINT32,
+    "uint64": nccl.NCCL_UINT64,
+}
+
+
+class NcclOp(IntEnum):
+    SUM = nccl.NCCL_SUM
+    PROD = nccl.NCCL_PROD
+    MAX = nccl.NCCL_MAX
+    MIN = nccl.NCCL_MIN
+
+
+def mpi_op_to_nccl(mpi_op) -> NcclOp:
+    """ Map MPI reduction operation to NCCL equivalent
+
+    Parameters
+    ----------
+    mpi_op : :obj:`MPI.Op`
+        An MPI reduction operation (e.g., MPI.SUM, MPI.PROD, MPI.MAX, MPI.MIN).
+
+    Returns
+    -------
+    NcclOp : :obj:`IntEnum`
+        A corresponding NCCL reduction operation.
+    """
+    if mpi_op is MPI.SUM:
+        return NcclOp.SUM
+    elif mpi_op is MPI.PROD:
+        return NcclOp.PROD
+    elif mpi_op is MPI.MAX:
+        return NcclOp.MAX
+    elif mpi_op is MPI.MIN:
+        return NcclOp.MIN
+    else:
+        raise ValueError(f"Unsupported MPI.Op for NCCL: {mpi_op}")
+
+
+def initialize_nccl_comm() -> nccl.NcclCommunicator:
+    """ Initialize NCCL world communicator for every GPU device
+
+    Each GPU must be managed by exactly one MPI process,
+    i.e. the number of MPI processes launched must be equal to
+    the number of GPUs taking part in the communication.
+
+    Returns
+    -------
+    nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
+        A corresponding NCCL communicator
+    """
+    comm = MPI.COMM_WORLD
+    rank = comm.Get_rank()
+    size = comm.Get_size()
+    device_id = int(
+        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")
+        or rank % cp.cuda.runtime.getDeviceCount()
+    )
+    cp.cuda.Device(device_id).use()
+
+    if rank == 0:
+        with cp.cuda.Device(device_id):
+            nccl_id_bytes = nccl.get_unique_id()
+    else:
+        nccl_id_bytes = None
+    nccl_id_bytes = comm.bcast(nccl_id_bytes, root=0)
+
+    nccl_comm = nccl.NcclCommunicator(size, nccl_id_bytes, rank)
+    return nccl_comm
+
+
+def nccl_split(mask) -> nccl.NcclCommunicator:
+    """ NCCL equivalent of MPI.Split()
+
+    Split the communicator into multiple NCCL subcommunicators
+
+    Parameters
+    ----------
+    mask : :obj:`list`
+        Mask defining subsets of ranks to consider when performing 'global'
+        operations on the distributed array such as dot product or norm.
+
+    Returns
+    -------
+    sub_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
+        Subcommunicator according to mask
+    """
+    comm = MPI.COMM_WORLD
+    rank = comm.Get_rank()
+    sub_comm = comm.Split(color=mask[rank], key=rank)
+
+    sub_rank = sub_comm.Get_rank()
+    sub_size = sub_comm.Get_size()
+
+    if sub_rank == 0:
+        nccl_id_bytes = nccl.get_unique_id()
+    else:
+        nccl_id_bytes = None
+    nccl_id_bytes = sub_comm.bcast(nccl_id_bytes, root=0)
+    sub_comm = nccl.NcclCommunicator(sub_size, nccl_id_bytes, sub_rank)
+
+    return sub_comm
+
+
+def nccl_allgather(nccl_comm, send_buf, recv_buf=None) -> cp.ndarray:
+    """ NCCL equivalent of MPI_Allgather. Gathers data from all GPUs
+    and distributes the concatenated result to all participants.
+
+    Parameters
+    ----------
+    nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
+        The NCCL communicator over which data will be gathered.
+    send_buf : :obj:`cupy.ndarray` or array-like
+        The data buffer from the local GPU to be sent.
+    recv_buf : :obj:`cupy.ndarray`, optional
+        The buffer to receive data from all GPUs. If None, a new
+        buffer will be allocated with the appropriate shape.
+
+    Returns
+    -------
+    recv_buf : :obj:`cupy.ndarray`
+        A buffer containing the gathered data from all GPUs.
+    """
+    send_buf = (
+        send_buf if isinstance(send_buf, cp.ndarray) else cp.asarray(send_buf)
+    )
+    if recv_buf is None:
+        recv_buf = cp.zeros(
+            MPI.COMM_WORLD.Get_size() * send_buf.size,
+            dtype=send_buf.dtype,
+        )
+    nccl_comm.allGather(
+        send_buf.data.ptr,
+        recv_buf.data.ptr,
+        send_buf.size,
+        cupy_to_nccl_dtype[str(send_buf.dtype)],
+        cp.cuda.Stream.null.ptr,
+    )
+    return recv_buf
+
+
+def nccl_allreduce(nccl_comm, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM) -> cp.ndarray:
+    """ NCCL equivalent of MPI_Allreduce. Applies a reduction operation
+    (e.g., sum, max) across all GPUs and distributes the result.
+
+    Parameters
+    ----------
+    nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
+        The NCCL communicator used for collective communication.
+    send_buf : :obj:`cupy.ndarray` or array-like
+        The data buffer from the local GPU to be reduced.
+    recv_buf : :obj:`cupy.ndarray`, optional
+        The buffer to store the result of the reduction. If None,
+        a new buffer will be allocated with the appropriate shape.
+    op : :obj:`mpi4py.MPI.Op`, optional
+        The reduction operation to apply. Defaults to MPI.SUM.
+
+    Returns
+    -------
+    recv_buf : :obj:`cupy.ndarray`
+        A buffer containing the result of the reduction, broadcasted
+        to all GPUs.
+    """
+    send_buf = (
+        send_buf if isinstance(send_buf, cp.ndarray) else cp.asarray(send_buf)
+    )
+    if recv_buf is None:
+        recv_buf = cp.zeros(send_buf.size, dtype=send_buf.dtype)
+
+    nccl_comm.allReduce(
+        send_buf.data.ptr,
+        recv_buf.data.ptr,
+        send_buf.size,
+        cupy_to_nccl_dtype[str(send_buf.dtype)],
+        mpi_op_to_nccl(op),
+        cp.cuda.Stream.null.ptr,
+    )
+    return recv_buf
+
+
+def nccl_bcast(nccl_comm, local_array, index, value) -> None:
+    """ NCCL equivalent of MPI_Bcast. Broadcasts a single value at the given index
+    from the root GPU (rank 0) to all other GPUs.
+
+    Parameters
+    ----------
+    nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
+        The NCCL communicator used for collective communication.
+    local_array : :obj:`cupy.ndarray`
+        The local array on each GPU. The value at `index` will be broadcasted.
+    index : :obj:`int`
+        The index in the array to be broadcasted.
+    value : :obj:`scalar`
+        The value to broadcast (only used by the root GPU, rank 0).
+
+    Returns
+    -------
+    None
+    """
+    if nccl_comm.rank_id() == 0:
+        local_array[index] = value
+    nccl_comm.bcast(
+        local_array[index].data.ptr,
+        local_array[index].size,
+        cupy_to_nccl_dtype[str(local_array[index].dtype)],
+        0,
+        cp.cuda.Stream.null.ptr,
+    )
+
+
+def nccl_asarray(nccl_comm, local_array, local_shapes, axis) -> cp.ndarray:
+    """Global view of the array
+
+    Gather all local GPU arrays into a single global array via NCCL all-gather.
+
+    Parameters
+    ----------
+    nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
+        The NCCL communicator used for collective communication.
+    local_array : :obj:`cupy.ndarray`
+        The local array on the current GPU.
+    local_shapes : :obj:`list`
+        A list of shapes for each GPU local array (used to trim padding).
+    axis : :obj:`int`
+        The axis along which to concatenate the gathered arrays.
+
+    Returns
+    -------
+    final_array : :obj:`cupy.ndarray`
+        Global array gathered from all GPUs and concatenated along `axis`.
+
+    Notes
+    -----
+    NCCL's allGather requires the sending buffer to have the same size for every device.
+    Therefore, padding is required when the array is not evenly partitioned across
+    all the ranks. The padding is applied such that the sending buffer has the size of
+    each dimension corresponding to the max possible size of that dimension.
+    """
+    sizes_each_dim = list(zip(*local_shapes))
+
+    send_shape = tuple(map(max, sizes_each_dim))
+    pad_size = [
+        (0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, local_array.shape)
+    ]
+
+    send_buf = cp.pad(
+        local_array, pad_size, mode="constant", constant_values=0
+    )
+
+    # NCCL recommends using one MPI process per GPU, so the size of the receiving buffer can be inferred
+    ndev = len(local_shapes)
+    recv_buf = cp.zeros(ndev * send_buf.size, dtype=send_buf.dtype)
+    nccl_allgather(nccl_comm, send_buf, recv_buf)
+
+    # extract an individual array from each device
+    chunk_size = np.prod(send_shape)
+    chunks = [
+        recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev)
+    ]
+
+    # Remove padding from each array: the padded values may appear somewhere
+    # in the middle of the flat array, so a reshape and per-dimension slicing is required
+    for i in range(ndev):
+        slicing = tuple(slice(0, end) for end in local_shapes[i])
+        chunks[i] = chunks[i].reshape(send_shape)[slicing]
+    # combine back to single global array
+    return cp.concatenate(chunks, axis=axis)
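As a minimal usage sketch of these helpers (assuming one MPI process per GPU, e.g. launched with mpiexec -n 2 python example.py; the array sizes below are illustrative only):

    # sketch: exercise the NCCL helpers directly, one MPI process per GPU
    import cupy as cp
    from mpi4py import MPI
    from pylops_mpi.utils._nccl import (
        initialize_nccl_comm, nccl_allreduce, nccl_asarray,
    )

    nccl_comm = initialize_nccl_comm()   # selects a GPU and builds the world NCCL communicator
    rank = MPI.COMM_WORLD.Get_rank()

    # each rank holds a local chunk of a different length (uneven partition)
    local = cp.arange(3 + rank, dtype="float32")

    # sum of the local sums across all GPUs, returned on every GPU
    total = nccl_allreduce(nccl_comm, local.sum())

    # gather the unevenly-sized chunks into one global array (padding handled internally)
    local_shapes = MPI.COMM_WORLD.allgather(local.shape)
    global_arr = nccl_asarray(nccl_comm, local, local_shapes, axis=0)

    if rank == 0:
        print(total, global_arr.shape)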

pylops_mpi/utils/deps.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+__all__ = [
+    "nccl_enabled"
+]
+
+import os
+from importlib import import_module, util
+from typing import Optional
+
+
+# error message at import of available package
+def nccl_import(message: Optional[str] = None) -> str:
+    nccl_test = (
+        # detect if nccl is available and the user is expecting it to be used
+        # CuPy must be checked first, otherwise util.find_spec assumes it is present, checks nccl immediately and crashes
+        util.find_spec("cupy") is not None and util.find_spec("cupy.cuda.nccl") is not None and int(os.getenv("NCCL_PYLOPS_MPI", 1)) == 1
+    )
+    if nccl_test:
+        # try importing it
+        try:
+            import_module("cupy.cuda.nccl")  # noqa: F401
+
+            # if successful, set the message to None
+            nccl_message = None
+        # if unable to import but the package is installed
+        except (ImportError, ModuleNotFoundError) as e:
+            nccl_message = (
+                f"Failed to import cupy.cuda.nccl, falling back to pure MPI (error: {e}). "
+                "Please ensure your CUDA NCCL environment is set up correctly; "
+                "for more details visit 'https://docs.cupy.dev/en/stable/install.html'"
+            )
+            print(UserWarning(nccl_message))
+    else:
+        nccl_message = (
+            "cupy.cuda.nccl package not installed or os.getenv('NCCL_PYLOPS_MPI') == 0. "
+            f"In order to be able to use {message} "
+            "ensure 'os.getenv('NCCL_PYLOPS_MPI') == 1'; "
+            "for more details on installing NCCL visit 'https://docs.cupy.dev/en/stable/install.html'"
+        )
+
+    return nccl_message
+
+
+nccl_enabled: bool = (
+    True if (nccl_import() is None and int(os.getenv("NCCL_PYLOPS_MPI", 1)) == 1) else False
+)
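A minimal sketch of how this flag is meant to be consumed by downstream code (the conditional import below illustrates the pattern; it is not a line taken from the rendered DistributedArray.py diff):

    # sketch: guard NCCL-specific code paths on nccl_enabled
    from pylops_mpi.utils import deps

    if deps.nccl_enabled:
        # only touch the NCCL helpers when CuPy and NCCL are actually usable
        from pylops_mpi.utils._nccl import initialize_nccl_comm
        nccl_comm = initialize_nccl_comm()
    else:
        nccl_comm = None  # fall back to plain MPI communication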

setup.cfg

Lines changed: 3 additions & 3 deletions
@@ -1,9 +1,9 @@
 [tool:pytest]
 addopts = --verbose
-python_files = tests/*.py
+python_files = tests/*.py tests_nccl/*.py
 
 [flake8]
 ignore = E203, E501, W503, E402
 per-file-ignores =
-    __init__.py: F401, F403, F405
-max-line-length = 88
+    __init__.py: F401, F403, F405
+max-line-length = 88

tests/test_distributedarray.py

Lines changed: 2 additions & 2 deletions
@@ -201,7 +201,7 @@ def test_distributed_maskeddot(par1, par2):
     """Test Distributed Dot product with masked array"""
     # number of subcommunicators
     if MPI.COMM_WORLD.Get_size() % 2 == 0:
-        nsub = 2
+        nsub = 2
     elif MPI.COMM_WORLD.Get_size() % 3 == 0:
         nsub = 3
     else:
@@ -236,7 +236,7 @@ def test_distributed_maskednorm(par):
     """Test Distributed numpy.linalg.norm method with masked array"""
     # number of subcommunicators
     if MPI.COMM_WORLD.Get_size() % 2 == 0:
-        nsub = 2
+        nsub = 2
     elif MPI.COMM_WORLD.Get_size() % 3 == 0:
         nsub = 3
     else:
