File tree 1 file changed +8
-7
lines changed 1 file changed +8
-7
lines changed Original file line number Diff line number Diff line change @@ -818,21 +818,22 @@ def norm(self, ord: Optional[int] = None):
818
818
ord : :obj:`int`, optional
819
819
Order of the norm.
820
820
"""
821
- norms = np .array ([distarray .norm (ord ) for distarray in self .distarrays ])
821
+ ncp = get_module (self .distarrays [0 ].engine )
822
+ norms = ncp .array ([distarray .norm (ord ) for distarray in self .distarrays ])
822
823
ord = 2 if ord is None else ord
823
824
if ord in ['fro' , 'nuc' ]:
824
825
raise ValueError (f"norm-{ ord } not possible for vectors" )
825
826
elif ord == 0 :
826
827
# Count non-zero then sum reduction
827
- norm = np .sum (norms )
828
- elif ord == np .inf :
828
+ norm = ncp .sum (norms )
829
+ elif ord == ncp .inf :
829
830
# Calculate max followed by max reduction
830
- norm = np .max (norms )
831
- elif ord == - np .inf :
831
+ norm = ncp .max (norms )
832
+ elif ord == - ncp .inf :
832
833
# Calculate min followed by max reduction
833
- norm = np .min (norms )
834
+ norm = ncp .min (norms )
834
835
else :
835
- norm = np .power (np .sum (np .power (norms , ord )), 1. / ord )
836
+ norm = ncp .power (ncp .sum (ncp .power (norms , ord )), 1. / ord )
836
837
return norm
837
838
838
839
def conj (self ):
You can’t perform that action at this time.
0 commit comments