@@ -5,6 +5,7 @@
 from enum import Enum

 from pylops.utils import DTypeLike, NDArray
+from pylops.utils.backend import get_module, get_array_module, get_module_name


 class Partition(Enum):
@@ -78,6 +79,8 @@ class DistributedArray:
         Axis along which distribution occurs. Defaults to ``0``.
     local_shapes : :obj:`list`, optional
         List of tuples representing local shapes at each rank.
+    engine : :obj:`str`, optional
+        Engine used to store the array (``numpy`` or ``cupy``). Defaults to ``numpy``.
     dtype : :obj:`str`, optional
         Type of elements in input array. Defaults to ``numpy.float64``.
     """
@@ -86,6 +89,7 @@ def __init__(self, global_shape: Union[Tuple, Integral],
                  base_comm: Optional[MPI.Comm] = MPI.COMM_WORLD,
                  partition: Partition = Partition.SCATTER, axis: int = 0,
                  local_shapes: Optional[List[Tuple]] = None,
+                 engine: Optional[str] = "numpy",
                  dtype: Optional[DTypeLike] = np.float64):
         if isinstance(global_shape, Integral):
             global_shape = (global_shape,)
@@ -103,7 +107,8 @@ def __init__(self, global_shape: Union[Tuple, Integral],
         self._check_local_shapes(local_shapes)
         self._local_shape = local_shapes[base_comm.rank] if local_shapes else local_split(global_shape, base_comm,
                                                                                           partition, axis)
-        self._local_array = np.empty(shape=self.local_shape, dtype=self.dtype)
+        self._engine = engine
+        self._local_array = get_module(engine).empty(shape=self.local_shape, dtype=self.dtype)

     def __getitem__(self, index):
         return self.local_array[index]
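
A minimal usage sketch, not part of the commit, of what the allocation change above relies on: pylops' `get_module` returns either the `numpy` or the `cupy` module for a given engine name, so `get_module(engine).empty(...)` places the local buffer on CPU or GPU accordingly. The shape below is illustrative.

    # Sketch: allocate a local buffer on the backend selected by `engine`.
    import numpy as np
    from pylops.utils.backend import get_module

    engine = "numpy"          # "cupy" would also work, if CuPy is installed
    ncp = get_module(engine)  # the numpy (or cupy) module
    local_array = ncp.empty(shape=(10, 5), dtype=np.float64)
    print(type(local_array))  # <class 'numpy.ndarray'> (or cupy.ndarray)
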
@@ -160,6 +165,16 @@ def local_shape(self):
         """
         return self._local_shape

+    @property
+    def engine(self):
+        """Engine of the Distributed array
+
+        Returns
+        -------
+        engine : :obj:`str`
+        """
+        return self._engine
+
     @property
     def local_array(self):
         """View of the Local Array
@@ -269,6 +284,7 @@ def to_dist(cls, x: NDArray,
             Axis of Distribution
         local_shapes : :obj:`list`, optional
            Local Shapes at each rank.
+
        Returns
        ----------
        dist_array : :obj:`DistributedArray`
@@ -279,6 +295,7 @@ def to_dist(cls, x: NDArray,
                                       partition=partition,
                                       axis=axis,
                                       local_shapes=local_shapes,
+                                      engine=get_module_name(get_array_module(x)),
                                       dtype=x.dtype)
         if partition == Partition.BROADCAST:
             dist_array[:] = x
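
A small sketch, again not part of the commit, of how `to_dist` infers the engine above: `get_array_module(x)` returns the array module `x` belongs to, and `get_module_name` maps that module to the string ``numpy`` or ``cupy``, so the distributed array keeps the backend of the input array.

    # Sketch: infer the engine name from an existing array.
    import numpy as np
    from pylops.utils.backend import get_array_module, get_module_name

    x = np.ones(100)
    print(get_module_name(get_array_module(x)))  # "numpy"
    # with a CuPy input (if CuPy is installed) the same call returns "cupy"
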
@@ -334,6 +351,7 @@ def __neg__(self):
                                partition=self.partition,
                                axis=self.axis,
                                local_shapes=self.local_shapes,
+                               engine=self.engine,
                                dtype=self.dtype)
         arr[:] = -self.local_array
         return arr
@@ -365,6 +383,7 @@ def add(self, dist_array):
                                     dtype=self.dtype,
                                     partition=self.partition,
                                     local_shapes=self.local_shapes,
+                                    engine=self.engine,
                                     axis=self.axis)
         SumArray[:] = self.local_array + dist_array.local_array
         return SumArray
@@ -387,6 +406,7 @@ def multiply(self, dist_array):
                                         dtype=self.dtype,
                                         partition=self.partition,
                                         local_shapes=self.local_shapes,
+                                        engine=self.engine,
                                         axis=self.axis)
         if isinstance(dist_array, DistributedArray):
             # multiply two DistributedArray
@@ -480,6 +500,7 @@ def conj(self):
                                 partition=self.partition,
                                 axis=self.axis,
                                 local_shapes=self.local_shapes,
+                                engine=self.engine,
                                 dtype=self.dtype)
         conj[:] = self.local_array.conj()
         return conj
@@ -492,6 +513,7 @@ def copy(self):
                                partition=self.partition,
                                axis=self.axis,
                                local_shapes=self.local_shapes,
+                               engine=self.engine,
                                dtype=self.dtype)
         arr[:] = self.local_array
         return arr
@@ -514,6 +536,7 @@ def ravel(self, order: Optional[str] = "C"):
         arr = DistributedArray(global_shape=np.prod(self.global_shape),
                                local_shapes=local_shapes,
                                partition=self.partition,
+                               engine=self.engine,
                                dtype=self.dtype)
         local_array = np.ravel(self.local_array, order=order)
         x = local_array.copy()
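
Taken together, a usage sketch of the new ``engine`` support, assuming `DistributedArray` is importable from `pylops_mpi` and the script is launched under MPI (e.g. `mpiexec -n 2 python example.py`); the shapes and the ``numpy`` engine are illustrative, and ``cupy`` requires CuPy to be installed.

    # Sketch: explicit engine selection and engine inference via to_dist.
    import numpy as np
    from pylops_mpi import DistributedArray

    # engine chosen explicitly at construction time
    arr = DistributedArray(global_shape=(100, 50), engine="numpy")
    print(arr.engine)             # "numpy"
    print(type(arr.local_array))  # the backend's array type for the local chunk

    # engine inferred from the input array by to_dist
    x = np.arange(100, dtype=np.float64)
    dist_x = DistributedArray.to_dist(x=x)
    print(dist_x.engine)          # "numpy", inherited from x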