Add NCCL support to DistributedArray #130
Merged
Changes from all commits (8 commits)
f418a7d  tharittk  implements NCCL collective calls in DistributedArray.py
b0da374  tharittk  Fix dimensional slicing bug in NCCL asarray()
bf15ea0  tharittk  add utils.backend and deps and move some nccl-related function. Fix s…
f1238cb  tharittk  move nccl-related calls to backend.py to avoid direct import
6d0895e  tharittk  change protected import pattern (backend.py -> _nccl.py) and docs sty…
cdc5950  tharittk  limit CI test to only trigger pure MPI tests and not nccl/cupy
3dc41fe  tharittk  Change DistributedArray() to take base_comm and base_comm_nccl as sug…
9693d9b  tharittk  base_comm instantiation - suggested in PR
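The last two commits indicate that DistributedArray now accepts both an MPI communicator (base_comm) and an optional NCCL communicator (base_comm_nccl). A hypothetical usage sketch follows; only the base_comm/base_comm_nccl names come from this PR, while the import path and the remaining parameters are assumptions based on the existing API:

# Hypothetical sketch: everything except base_comm/base_comm_nccl is assumed
from mpi4py import MPI
from pylops_mpi import DistributedArray
from pylops_mpi.utils._nccl import initialize_nccl_comm

nccl_comm = initialize_nccl_comm()              # one MPI process per GPU
x = DistributedArray(global_shape=(100,),       # assumed existing parameter
                     base_comm=MPI.COMM_WORLD,
                     base_comm_nccl=nccl_comm)  # new: NCCL-backed collectives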
Modified file:
@@ -1,2 +1,5 @@
 # isort: skip_file
-from .dottest import *
+
+# currently dottest creates a circular dependency with DistributedArray.py
+# from .dottest import *
+from .deps import *
New file:
@@ -0,0 +1,288 @@
__all__ = [
    "initialize_nccl_comm",
    "nccl_split",
    "nccl_allgather",
    "nccl_allreduce",
    "nccl_bcast",
    "nccl_asarray"
]

from enum import IntEnum
from mpi4py import MPI
import os
import numpy as np
import cupy as cp
import cupy.cuda.nccl as nccl

cupy_to_nccl_dtype = {
    "float32": nccl.NCCL_FLOAT32,
    "float64": nccl.NCCL_FLOAT64,
    "int32": nccl.NCCL_INT32,
    "int64": nccl.NCCL_INT64,
    "uint8": nccl.NCCL_UINT8,
    "int8": nccl.NCCL_INT8,
    "uint32": nccl.NCCL_UINT32,
    "uint64": nccl.NCCL_UINT64,
}


class NcclOp(IntEnum):
    SUM = nccl.NCCL_SUM
    PROD = nccl.NCCL_PROD
    MAX = nccl.NCCL_MAX
    MIN = nccl.NCCL_MIN


def mpi_op_to_nccl(mpi_op) -> NcclOp:
    """ Map an MPI reduction operation to its NCCL equivalent

    Parameters
    ----------
    mpi_op : :obj:`MPI.Op`
        An MPI reduction operation (e.g., MPI.SUM, MPI.PROD, MPI.MAX, MPI.MIN).

    Returns
    -------
    NcclOp : :obj:`IntEnum`
        The corresponding NCCL reduction operation.
    """
    if mpi_op is MPI.SUM:
        return NcclOp.SUM
    elif mpi_op is MPI.PROD:
        return NcclOp.PROD
    elif mpi_op is MPI.MAX:
        return NcclOp.MAX
    elif mpi_op is MPI.MIN:
        return NcclOp.MIN
    else:
        raise ValueError(f"Unsupported MPI.Op for NCCL: {mpi_op}")


def initialize_nccl_comm() -> nccl.NcclCommunicator:
    """ Initialize the NCCL world communicator, one rank per GPU device

    Each GPU must be managed by exactly one MPI process, i.e. the number of
    MPI processes launched must equal the number of GPUs taking part in the
    communication.

    Returns
    -------
    nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
        The corresponding NCCL communicator
    """
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()
    device_id = int(
        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")
        or rank % cp.cuda.runtime.getDeviceCount()
    )
    cp.cuda.Device(device_id).use()

    if rank == 0:
        with cp.cuda.Device(device_id):
            nccl_id_bytes = nccl.get_unique_id()
    else:
        nccl_id_bytes = None
    nccl_id_bytes = comm.bcast(nccl_id_bytes, root=0)

    nccl_comm = nccl.NcclCommunicator(size, nccl_id_bytes, rank)
    return nccl_comm


def nccl_split(mask) -> nccl.NcclCommunicator:
    """ NCCL equivalent of MPI.Split()

    Split the world communicator into multiple NCCL subcommunicators
    according to the mask.

    Parameters
    ----------
    mask : :obj:`list`
        Mask defining subsets of ranks to consider when performing 'global'
        operations on the distributed array such as dot product or norm.

    Returns
    -------
    sub_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
        Subcommunicator according to mask
    """
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    sub_comm = comm.Split(color=mask[rank], key=rank)

    sub_rank = sub_comm.Get_rank()
    sub_size = sub_comm.Get_size()

    if sub_rank == 0:
        nccl_id_bytes = nccl.get_unique_id()
    else:
        nccl_id_bytes = None
    nccl_id_bytes = sub_comm.bcast(nccl_id_bytes, root=0)
    sub_comm = nccl.NcclCommunicator(sub_size, nccl_id_bytes, sub_rank)

    return sub_comm


def nccl_allgather(nccl_comm, send_buf, recv_buf=None) -> cp.ndarray:
    """ NCCL equivalent of MPI_Allgather. Gathers data from all GPUs
    and distributes the concatenated result to all participants.

    Parameters
    ----------
    nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
        The NCCL communicator over which data will be gathered.
    send_buf : :obj:`cupy.ndarray` or array-like
        The data buffer from the local GPU to be sent.
    recv_buf : :obj:`cupy.ndarray`, optional
        The buffer to receive data from all GPUs. If None, a new
        buffer will be allocated with the appropriate shape.

    Returns
    -------
    recv_buf : :obj:`cupy.ndarray`
        A buffer containing the gathered data from all GPUs.
    """
    send_buf = (
        send_buf if isinstance(send_buf, cp.ndarray) else cp.asarray(send_buf)
    )
    if recv_buf is None:
        recv_buf = cp.zeros(
            MPI.COMM_WORLD.Get_size() * send_buf.size,
            dtype=send_buf.dtype,
        )
    nccl_comm.allGather(
        send_buf.data.ptr,
        recv_buf.data.ptr,
        send_buf.size,
        cupy_to_nccl_dtype[str(send_buf.dtype)],
        cp.cuda.Stream.null.ptr,
    )
    return recv_buf


def nccl_allreduce(nccl_comm, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM) -> cp.ndarray:
    """ NCCL equivalent of MPI_Allreduce. Applies a reduction operation
    (e.g., sum, max) across all GPUs and distributes the result.

    Parameters
    ----------
    nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
        The NCCL communicator used for collective communication.
    send_buf : :obj:`cupy.ndarray` or array-like
        The data buffer from the local GPU to be reduced.
    recv_buf : :obj:`cupy.ndarray`, optional
        The buffer to store the result of the reduction. If None,
        a new buffer will be allocated with the appropriate shape.
    op : :obj:`mpi4py.MPI.Op`, optional
        The reduction operation to apply. Defaults to MPI.SUM.

    Returns
    -------
    recv_buf : :obj:`cupy.ndarray`
        A buffer containing the result of the reduction, broadcast
        to all GPUs.
    """
    send_buf = (
        send_buf if isinstance(send_buf, cp.ndarray) else cp.asarray(send_buf)
    )
    if recv_buf is None:
        recv_buf = cp.zeros(send_buf.size, dtype=send_buf.dtype)

    nccl_comm.allReduce(
        send_buf.data.ptr,
        recv_buf.data.ptr,
        send_buf.size,
        cupy_to_nccl_dtype[str(send_buf.dtype)],
        mpi_op_to_nccl(op),
        cp.cuda.Stream.null.ptr,
    )
    return recv_buf


def nccl_bcast(nccl_comm, local_array, index, value) -> None:
    """ NCCL equivalent of MPI_Bcast. Broadcasts a single value at the given index
    from the root GPU (rank 0) to all other GPUs.

    Parameters
    ----------
    nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
        The NCCL communicator used for collective communication.
    local_array : :obj:`cupy.ndarray`
        The local array on each GPU. The value at `index` will be broadcast.
    index : :obj:`int`
        The index in the array to be broadcast.
    value : :obj:`scalar`
        The value to broadcast (only used by the root GPU, rank 0).

    Returns
    -------
    None
    """
    if nccl_comm.rank_id() == 0:
        local_array[index] = value
    nccl_comm.bcast(
        local_array[index].data.ptr,
        local_array[index].size,
        cupy_to_nccl_dtype[str(local_array[index].dtype)],
        0,
        cp.cuda.Stream.null.ptr,
    )


def nccl_asarray(nccl_comm, local_array, local_shapes, axis) -> cp.ndarray:
    """Global view of the array

    Gather all local GPU arrays into a single global array via NCCL all-gather.

    Parameters
    ----------
    nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
        The NCCL communicator used for collective communication.
    local_array : :obj:`cupy.ndarray`
        The local array on the current GPU.
    local_shapes : :obj:`list`
        A list of shapes for each GPU local array (used to trim padding).
    axis : :obj:`int`
        The axis along which to concatenate the gathered arrays.

    Returns
    -------
    final_array : :obj:`cupy.ndarray`
        Global array gathered from all GPUs and concatenated along `axis`.

    Notes
    -----
    NCCL's allGather requires the sending buffer to have the same size on every
    device. Padding is therefore required when the array is not evenly partitioned
    across the ranks: each dimension of the sending buffer is padded to the maximum
    size of that dimension over all ranks.
    """
    sizes_each_dim = list(zip(*local_shapes))

    send_shape = tuple(map(max, sizes_each_dim))
    pad_size = [
        (0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, local_array.shape)
    ]

    send_buf = cp.pad(
        local_array, pad_size, mode="constant", constant_values=0
    )

    # NCCL recommends one MPI process per GPU, so the size of the receiving
    # buffer can be inferred from the number of local shapes
    ndev = len(local_shapes)
    recv_buf = cp.zeros(ndev * send_buf.size, dtype=send_buf.dtype)
    nccl_allgather(nccl_comm, send_buf, recv_buf)

    # extract an individual array from each device
    chunk_size = np.prod(send_shape)
    chunks = [
        recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev)
    ]

    # Remove padding from each array: the padded values may end up in the middle
    # of the flattened buffer, so each chunk is reshaped to the padded shape and
    # then sliced back to its original local shape
    for i in range(ndev):
        slicing = tuple(slice(0, end) for end in local_shapes[i])
        chunks[i] = chunks[i].reshape(send_shape)[slicing]
    # combine back into a single global array
    return cp.concatenate(chunks, axis=axis)
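For context, the helpers above compose as in the hedged sketch below (it requires CUDA-capable GPUs, CuPy built with NCCL support, and one MPI process per GPU, e.g. launched with mpirun -np <ngpus> python script.py; the import path pylops_mpi.utils._nccl is an assumption based on the commit messages):

# Sketch only: assumes this module is importable as pylops_mpi.utils._nccl
import cupy as cp
from mpi4py import MPI
from pylops_mpi.utils._nccl import (
    initialize_nccl_comm, nccl_allreduce, nccl_asarray,
)

nccl_comm = initialize_nccl_comm()            # pins one GPU per MPI rank
rank = MPI.COMM_WORLD.Get_rank()

local = cp.full(3, rank, dtype=cp.float32)    # different data on each GPU
summed = nccl_allreduce(nccl_comm, local)     # elementwise sum across all GPUs

# gather the (possibly unevenly partitioned) local arrays into one global array;
# the local shapes are exchanged over MPI since they are small host-side metadata
local_shapes = MPI.COMM_WORLD.allgather(local.shape)
global_arr = nccl_asarray(nccl_comm, local, local_shapes, axis=0)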
New file:
@@ -0,0 +1,45 @@
__all__ = [
    "nccl_enabled"
]

import os
from importlib import import_module, util
from typing import Optional


# return an error message if NCCL cannot be used, otherwise None
def nccl_import(message: Optional[str] = None) -> Optional[str]:
    nccl_test = (
        # detect if nccl is available and the user is expecting it to be used;
        # CuPy must be checked first, otherwise util.find_spec assumes it is
        # present, checks nccl immediately, and crashes
        util.find_spec("cupy") is not None
        and util.find_spec("cupy.cuda.nccl") is not None
        and int(os.getenv("NCCL_PYLOPS_MPI", 1)) == 1
    )
    if nccl_test:
        # try importing it
        try:
            import_module("cupy.cuda.nccl")  # noqa: F401

            # if successful, set the message to None
            nccl_message = None

        # if unable to import but the package is installed
        except (ImportError, ModuleNotFoundError) as e:
            nccl_message = (
                f"Failed to import cupy.cuda.nccl, falling back to pure MPI (error: {e}). "
                "Please ensure your CUDA NCCL environment is set up correctly; "
                "for more details visit 'https://docs.cupy.dev/en/stable/install.html'"
            )
            print(UserWarning(nccl_message))
    else:
        nccl_message = (
            "cupy.cuda.nccl package not installed or os.getenv('NCCL_PYLOPS_MPI') == 0. "
            f"In order to be able to use {message}, "
            "ensure 'os.getenv('NCCL_PYLOPS_MPI') == 1'; "
            "for more details on installing NCCL visit 'https://docs.cupy.dev/en/stable/install.html'"
        )

    return nccl_message


nccl_enabled: bool = (
    nccl_import() is None and int(os.getenv("NCCL_PYLOPS_MPI", 1)) == 1
)
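For reference, the guard pattern this flag is presumably meant to enable looks roughly as follows (a sketch only; the exact wiring inside DistributedArray.py is not shown in this excerpt and the module paths are assumptions based on the commit messages):

# Sketch only: module paths are assumed, not confirmed by this diff
from pylops_mpi.utils import deps

if deps.nccl_enabled:
    # cupy.cuda.nccl is only imported once the availability check has passed
    from pylops_mpi.utils._nccl import initialize_nccl_comm
    nccl_comm = initialize_nccl_comm()
else:
    nccl_comm = None  # fall back to pure-MPI communication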
Modified file:
@@ -1,9 +1,9 @@
 [tool:pytest]
 addopts = --verbose
-python_files = tests/*.py
+python_files = tests/*.py tests_nccl/*.py
 
 [flake8]
 ignore = E203, E501, W503, E402
 per-file-ignores =
-    __init__.py: F401, F403, F405
-max-line-length = 88
+    __init__.py: F401, F403, F405
+max-line-length = 88