Skip to content

Commit e5d7b52

Browse files
authored
Merge pull request #119 from PyLops/bug-stackedarraynorm
bug: fixed StackedDistributedArray.norm to work with cupy arrays
2 parents e9bbecc + a36e19d commit e5d7b52

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

pylops_mpi/DistributedArray.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -818,21 +818,22 @@ def norm(self, ord: Optional[int] = None):
818818
ord : :obj:`int`, optional
819819
Order of the norm.
820820
"""
821-
norms = np.array([distarray.norm(ord) for distarray in self.distarrays])
821+
ncp = get_module(self.distarrays[0].engine)
822+
norms = ncp.array([distarray.norm(ord) for distarray in self.distarrays])
822823
ord = 2 if ord is None else ord
823824
if ord in ['fro', 'nuc']:
824825
raise ValueError(f"norm-{ord} not possible for vectors")
825826
elif ord == 0:
826827
# Count non-zero then sum reduction
827-
norm = np.sum(norms)
828-
elif ord == np.inf:
828+
norm = ncp.sum(norms)
829+
elif ord == ncp.inf:
829830
# Calculate max followed by max reduction
830-
norm = np.max(norms)
831-
elif ord == -np.inf:
831+
norm = ncp.max(norms)
832+
elif ord == -ncp.inf:
832833
# Calculate min followed by max reduction
833-
norm = np.min(norms)
834+
norm = ncp.min(norms)
834835
else:
835-
norm = np.power(np.sum(np.power(norms, ord)), 1. / ord)
836+
norm = ncp.power(ncp.sum(ncp.power(norms, ord)), 1. / ord)
836837
return norm
837838

838839
def conj(self):

0 commit comments

Comments
 (0)