98 changes: 98 additions & 0 deletions pylops_mpi/Distributed.py
@@ -0,0 +1,98 @@
from mpi4py import MPI
from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils
from pylops_mpi.utils._mpi import mpi_allreduce, mpi_allgather, mpi_send, mpi_recv, _prepare_allgather_inputs, _unroll_allgather_recv
from pylops_mpi.utils import deps

cupy_message = pylops_deps.cupy_import("the DistributedArray module")
nccl_message = deps.nccl_import("the DistributedArray module")

if nccl_message is None and cupy_message is None:
from pylops_mpi.utils._nccl import (
nccl_allgather, nccl_allreduce, nccl_send, nccl_recv
)


class DistributedMixIn:
r"""Distributed Mixin class

This class implements all methods associated with communication primitives
from MPI and NCCL. It is mainly responsible for identifying which communicator
to use and whether buffered or object-based MPI primitives should be used
(the former for NumPy arrays, or CuPy arrays when a CUDA-aware
MPI installation is available; the latter for CuPy arrays when a CUDA-aware
MPI installation is not available).
"""
def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
"""Allreduce operation
"""
if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op)
else:
return mpi_allreduce(self.base_comm, send_buf,
recv_buf, self.engine, op)

def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
"""Allreduce operation with subcommunicator
"""
if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op)
else:
return mpi_allreduce(self.sub_comm, send_buf,
recv_buf, self.engine, op)

def _allgather(self, send_buf, recv_buf=None):
"""Allgather operation
"""
if deps.nccl_enabled and self.base_comm_nccl:
if isinstance(send_buf, (tuple, list, int)):
return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf)
else:
send_shapes = self.base_comm.allgather(send_buf.shape)
(padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine="cupy")
raw_recv = nccl_allgather(self.base_comm_nccl, padded_send, recv_buf if recv_buf else padded_recv)
return _unroll_allgather_recv(raw_recv, padded_send.shape, send_shapes)
else:
if isinstance(send_buf, (tuple, list, int)):
return self.base_comm.allgather(send_buf)
return mpi_allgather(self.base_comm, send_buf, recv_buf, self.engine)

def _allgather_subcomm(self, send_buf, recv_buf=None):
"""Allgather operation with subcommunicator
"""
if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
if isinstance(send_buf, (tuple, list, int)):
return nccl_allgather(self.sub_comm, send_buf, recv_buf)
else:
send_shapes = self._allgather_subcomm(send_buf.shape)
(padded_send, padded_recv) = _prepare_allgather_inputs(send_buf, send_shapes, engine="cupy")
raw_recv = nccl_allgather(self.sub_comm, padded_send, recv_buf if recv_buf else padded_recv)
return _unroll_allgather_recv(raw_recv, padded_send.shape, send_shapes)
else:
return mpi_allgather(self.sub_comm, send_buf, recv_buf, self.engine)

def _send(self, send_buf, dest, count=None, tag=0):
"""Send operation
"""
if deps.nccl_enabled and self.base_comm_nccl:
if count is None:
count = send_buf.size
nccl_send(self.base_comm_nccl, send_buf, dest, count)
else:
mpi_send(self.base_comm,
send_buf, dest, count, tag=tag,
engine=self.engine)

def _recv(self, recv_buf=None, source=0, count=None, tag=0):
"""Receive operation
"""
if deps.nccl_enabled and self.base_comm_nccl:
if recv_buf is None:
raise ValueError("recv_buf must be supplied when using NCCL")
if count is None:
count = recv_buf.size
nccl_recv(self.base_comm_nccl, recv_buf, source, count)
return recv_buf
else:
return mpi_recv(self.base_comm,
recv_buf, source, count, tag=tag,
engine=self.engine)
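
A note on usage: the mixin assumes the inheriting class exposes `base_comm`, `base_comm_nccl`, `sub_comm`, and `engine` attributes, and the helpers imported from `pylops_mpi.utils._mpi` (added elsewhere in this PR, not shown above) are assumed to accept the arguments exactly as the methods above pass them. Below is a minimal CPU-only sketch of a class that uses the mixin; `ToyDistributed` is a hypothetical name for illustration only, not part of the PR.

```python
# Minimal sketch (not part of the PR) of a class using DistributedMixIn:
# the inheriting class supplies base_comm, base_comm_nccl, sub_comm and
# engine, and the mixin picks the communication primitive at call time.
# With no NCCL communicator set, every call takes the MPI branch.
import numpy as np
from mpi4py import MPI
from pylops_mpi.Distributed import DistributedMixIn


class ToyDistributed(DistributedMixIn):
    def __init__(self, base_comm=MPI.COMM_WORLD):
        self.base_comm = base_comm   # MPI communicator used by the mixin
        self.base_comm_nccl = None   # no NCCL communicator -> MPI path
        self.sub_comm = base_comm    # subcommunicator (here: the full one)
        self.engine = "numpy"        # selects the NumPy code paths


if __name__ == "__main__":
    toy = ToyDistributed()
    rank = toy.base_comm.Get_rank()
    # each rank contributes its rank value; the mixin dispatches to mpi_allreduce
    total = toy._allreduce(np.array([rank], dtype=np.float64))
    # gathering a tuple takes the object-based allgather branch
    shapes = toy._allgather((rank + 1,))
    print(rank, total, shapes)
```

Run with, e.g., `mpirun -n 2 python toy.py`.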
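The NCCL branches of `_allgather` and `_allgather_subcomm` work around the fact that NCCL's allgather requires equal-sized buffers on every rank: local arrays are padded to a common shape, gathered, and then unrolled back to their true shapes. `_prepare_allgather_inputs` and `_unroll_allgather_recv` are not shown in this diff, so the sketch below only illustrates the idea with plain NumPy and hypothetical helpers (`pad_to`, `unroll`), assuming zero-padding up to the elementwise maximum shape.

```python
# NumPy-only sketch of the pad -> allgather -> unroll idea used in _allgather.
# pad_to / unroll are hypothetical stand-ins for _prepare_allgather_inputs /
# _unroll_allgather_recv; the real helpers live in pylops_mpi.utils._mpi.
import numpy as np


def pad_to(arr, shape):
    """Zero-pad ``arr`` at the end of each axis up to ``shape``."""
    pad = [(0, t - s) for s, t in zip(arr.shape, shape)]
    return np.pad(arr, pad)


def unroll(flat_recv, padded_shape, send_shapes):
    """Split the gathered buffer per rank and crop each block to its true shape."""
    out, block = [], int(np.prod(padded_shape))
    for irank, shp in enumerate(send_shapes):
        chunk = flat_recv[irank * block:(irank + 1) * block].reshape(padded_shape)
        out.append(chunk[tuple(slice(0, s) for s in shp)])
    return out


# simulate three ranks owning arrays of different lengths
locals_ = [np.arange(2.), np.arange(3.), np.arange(1.)]
send_shapes = [a.shape for a in locals_]                   # what allgather(shape) returns
common = tuple(np.max(np.array(send_shapes), axis=0))      # padded, equal-size shape
gathered = np.concatenate([pad_to(a, common).ravel() for a in locals_])  # allgather stand-in
print(unroll(gathered, common, send_shapes))               # original per-rank arrays
```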
111 changes: 16 additions & 95 deletions pylops_mpi/DistributedArray.py
@@ -4,6 +4,7 @@

import numpy as np
from mpi4py import MPI
from pylops_mpi.Distributed import DistributedMixIn
from pylops.utils import DTypeLike, NDArray
from pylops.utils import deps as pylops_deps # avoid namespace crashes with pylops_mpi.utils
from pylops.utils._internal import _value_or_sized_to_tuple
@@ -14,7 +15,7 @@
nccl_message = deps.nccl_import("the DistributedArray module")

if nccl_message is None and cupy_message is None:
from pylops_mpi.utils._nccl import nccl_allgather, nccl_allreduce, nccl_asarray, nccl_bcast, nccl_split, nccl_send, nccl_recv, _prepare_nccl_allgather_inputs, _unroll_nccl_allgather_recv
from pylops_mpi.utils._nccl import nccl_asarray, nccl_bcast, nccl_split
from cupy.cuda.nccl import NcclCommunicator
else:
NcclCommunicator = Any
@@ -99,7 +100,7 @@ def subcomm_split(mask, comm: Optional[Union[MPI.Comm, NcclCommunicatorType]] =
return sub_comm


class DistributedArray:
class DistributedArray(DistributedMixIn):
r"""Distributed Numpy Arrays

Multidimensional NumPy-like distributed arrays.
@@ -472,92 +473,6 @@ def _check_mask(self, dist_array):
if not np.array_equal(self.mask, dist_array.mask):
raise ValueError("Mask of both the arrays must be same")

def _allreduce(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
"""Allreduce operation
"""
if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
return nccl_allreduce(self.base_comm_nccl, send_buf, recv_buf, op)
else:
if recv_buf is None:
return self.base_comm.allreduce(send_buf, op)
# For MIN and MAX which require recv_buf
self.base_comm.Allreduce(send_buf, recv_buf, op)
return recv_buf

def _allreduce_subcomm(self, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM):
"""Allreduce operation with subcommunicator
"""
if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
return nccl_allreduce(self.sub_comm, send_buf, recv_buf, op)
else:
if recv_buf is None:
return self.sub_comm.allreduce(send_buf, op)
# For MIN and MAX which require recv_buf
self.sub_comm.Allreduce(send_buf, recv_buf, op)
return recv_buf

def _allgather(self, send_buf, recv_buf=None):
"""Allgather operation
"""
if deps.nccl_enabled and self.base_comm_nccl:
if isinstance(send_buf, (tuple, list, int)):
return nccl_allgather(self.base_comm_nccl, send_buf, recv_buf)
else:
send_shapes = self.base_comm.allgather(send_buf.shape)
(padded_send, padded_recv) = _prepare_nccl_allgather_inputs(send_buf, send_shapes)
raw_recv = nccl_allgather(self.base_comm_nccl, padded_send, recv_buf if recv_buf else padded_recv)
return _unroll_nccl_allgather_recv(raw_recv, padded_send.shape, send_shapes)
else:
if recv_buf is None:
return self.base_comm.allgather(send_buf)
self.base_comm.Allgather(send_buf, recv_buf)
return recv_buf

def _allgather_subcomm(self, send_buf, recv_buf=None):
"""Allgather operation with subcommunicator
"""
if deps.nccl_enabled and getattr(self, "base_comm_nccl"):
if isinstance(send_buf, (tuple, list, int)):
return nccl_allgather(self.sub_comm, send_buf, recv_buf)
else:
send_shapes = self._allgather_subcomm(send_buf.shape)
(padded_send, padded_recv) = _prepare_nccl_allgather_inputs(send_buf, send_shapes)
raw_recv = nccl_allgather(self.sub_comm, padded_send, recv_buf if recv_buf else padded_recv)
return _unroll_nccl_allgather_recv(raw_recv, padded_send.shape, send_shapes)
else:
if recv_buf is None:
return self.sub_comm.allgather(send_buf)
self.sub_comm.Allgather(send_buf, recv_buf)

def _send(self, send_buf, dest, count=None, tag=None):
""" Send operation
"""
if deps.nccl_enabled and self.base_comm_nccl:
if count is None:
# assuming sending the whole array
count = send_buf.size
nccl_send(self.base_comm_nccl, send_buf, dest, count)
else:
self.base_comm.send(send_buf, dest, tag)

def _recv(self, recv_buf=None, source=0, count=None, tag=None):
""" Receive operation
"""
# NCCL must be called with recv_buf. Size cannot be inferred from
# other arguments and thus cannot be dynamically allocated
if deps.nccl_enabled and self.base_comm_nccl and recv_buf is not None:
if recv_buf is not None:
if count is None:
# assuming data will take a space of the whole buffer
count = recv_buf.size
nccl_recv(self.base_comm_nccl, recv_buf, source, count)
return recv_buf
else:
raise ValueError("Using recv with NCCL must also supply receiver buffer ")
else:
# MPI allows a receiver buffer to be optional and receives as a Python Object
return self.base_comm.recv(source=source, tag=tag)

def _nccl_local_shapes(self, masked: bool):
"""Get the the list of shapes of every GPU in the communicator
"""
@@ -694,26 +609,32 @@ def _compute_vector_norm(self, local_array: NDArray,
recv_buf = self._allreduce_subcomm(ncp.count_nonzero(local_array, axis=axis).astype(ncp.float64))
elif ord == ncp.inf:
# Calculate max followed by max reduction
# TODO (tharitt): currently CuPy + MPI does not work well with buffered communication, particularly
# CuPy + non-CUDA-aware MPI does not work well with buffered communication, particularly
# with MAX, MIN operator. Here we copy the array back to CPU, transfer, and copy them back to GPUs
send_buf = ncp.max(ncp.abs(local_array), axis=axis).astype(ncp.float64)
if self.engine == "cupy" and self.base_comm_nccl is None:
if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled:
# CuPy + non-CUDA-aware MPI: This will call non-buffered communication
# which returns a list of objects - these must be copied back to GPU memory.
recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MAX)
recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
else:
recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MAX)
recv_buf = ncp.squeeze(recv_buf, axis=axis)
# TODO (tharitt): In current implementation, there seems to be a semantic difference between Buffered MPI and NCCL
# the (1, size) is collapsed to (size, ) with buffered MPI while NCCL retains it.
# There may be a way to unify it - may be something to do with how we allocate the recv_buf.
if self.base_comm_nccl:
recv_buf = ncp.squeeze(recv_buf, axis=axis)
elif ord == -ncp.inf:
# Calculate min followed by min reduction
# TODO (tharitt): see the comment above in infinity norm
# See the comment above in +infinity norm
send_buf = ncp.min(ncp.abs(local_array), axis=axis).astype(ncp.float64)
if self.engine == "cupy" and self.base_comm_nccl is None:
if self.engine == "cupy" and self.base_comm_nccl is None and not deps.cuda_aware_mpi_enabled:
recv_buf = self._allreduce_subcomm(send_buf.get(), recv_buf.get(), op=MPI.MIN)
recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
else:
recv_buf = self._allreduce_subcomm(send_buf, recv_buf, op=MPI.MIN)
recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))

if self.base_comm_nccl:
recv_buf = ncp.asarray(ncp.squeeze(recv_buf, axis=axis))
else:
recv_buf = self._allreduce_subcomm(ncp.sum(ncp.abs(ncp.float_power(local_array, ord)), axis=axis))
recv_buf = ncp.power(recv_buf, 1.0 / ord)
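
For reference, the reduction pattern behind the ±infinity-norm branches above is a local max/min of |x| followed by a MAX/MIN allreduce over the subcommunicator; the CuPy and CUDA-aware-MPI branches only change where the buffers live, not the pattern. A minimal NumPy + mpi4py sketch (not the PR's code) of the MAX case:

```python
# Sketch of the infinity-norm reduction in _compute_vector_norm:
# local max of |x| on each rank, then a buffered MAX-allreduce.
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

local_array = np.arange(4, dtype=np.float64) + 10 * rank            # this rank's chunk
send_buf = np.array([np.abs(local_array).max()], dtype=np.float64)  # local |x| max
recv_buf = np.zeros_like(send_buf)

comm.Allreduce(send_buf, recv_buf, op=MPI.MAX)                       # global max
print(rank, "inf-norm:", recv_buf[0])
```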
16 changes: 10 additions & 6 deletions pylops_mpi/basicoperators/VStack.py
@@ -15,6 +15,7 @@
Partition,
StackedDistributedArray
)
from pylops_mpi.Distributed import DistributedMixIn
from pylops_mpi.utils.decorators import reshaped
from pylops_mpi.utils import deps

@@ -25,7 +26,7 @@
from pylops_mpi.utils._nccl import nccl_allreduce


class MPIVStack(MPILinearOperator):
class MPIVStack(DistributedMixIn, MPILinearOperator):
r"""MPI VStack Operator

Create a vertical stack of a set of linear operators using MPI. Each rank must
@@ -141,16 +142,19 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
@reshaped(forward=False, stacking=True)
def _rmatvec(self, x: DistributedArray) -> DistributedArray:
ncp = get_module(x.engine)
y = DistributedArray(global_shape=self.shape[1], base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, partition=Partition.BROADCAST,
# TODO: consider adding base_comm, base_comm_nccl, engine to the
# input parameters of _allreduce instead of relying on self
self.base_comm, self.base_comm_nccl, self.engine = \
x.base_comm, x.base_comm_nccl, x.engine
y = DistributedArray(global_shape=self.shape[1], base_comm=x.base_comm,
base_comm_nccl=x.base_comm_nccl,
partition=Partition.BROADCAST,
engine=x.engine, dtype=self.dtype)
y1 = []
for iop, oper in enumerate(self.ops):
y1.append(oper.rmatvec(x.local_array[self.nnops[iop]: self.nnops[iop + 1]]))
y1 = ncp.sum(ncp.vstack(y1), axis=0)
if deps.nccl_enabled and x.base_comm_nccl:
y[:] = nccl_allreduce(x.base_comm_nccl, y1, op=MPI.SUM)
else:
y[:] = self.base_comm.allreduce(y1, op=MPI.SUM)
y[:] = self._allreduce(y1, op=MPI.SUM)
return y
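
The change above only swaps how the sum over ranks is performed: for A = [A_1; ...; A_n] distributed one block per rank, the adjoint is A^H x = sum_i A_i^H x_i, so each rank computes its local adjoint product and the contributions are combined with a SUM allreduce, now routed through the mixin's `_allreduce`. A standalone NumPy + mpi4py sketch of that reduction (not the operator itself):

```python
# Sketch of the reduction behind MPIVStack._rmatvec: each rank applies the
# adjoint of its own block and the per-rank results are summed across ranks,
# which is what self._allreduce(y1, op=MPI.SUM) does through the mixin.
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
rng = np.random.default_rng(rank)

ncols = 5
A_local = rng.standard_normal((3, ncols))   # this rank's block of the stack
x_local = rng.standard_normal(3)            # this rank's piece of x

y_local = A_local.T @ x_local               # A_i^H x_i on this rank
y = comm.allreduce(y_local, op=MPI.SUM)     # sum of contributions over ranks
print(rank, y)
```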

