Skip to content

Commit 03e3a77

Browse files
authored
Merge pull request #120 from PyLops/bug-unsafebroad
bug: allow unsafe_broadcast in VStack and Fredholm1
2 parents e5d7b52 + fedb902 commit 03e3a77

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

pylops_mpi/basicoperators/VStack.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,9 @@ def __init__(self, ops: Sequence[LinearOperator],
118118

119119
def _matvec(self, x: DistributedArray) -> DistributedArray:
120120
ncp = get_module(x.engine)
121-
if x.partition is not Partition.BROADCAST:
122-
raise ValueError(f"x should have partition={Partition.BROADCAST}, {x.partition} != {Partition.BROADCAST}")
121+
if x.partition not in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
122+
raise ValueError(f"x should have partition={Partition.BROADCAST},{Partition.UNSAFE_BROADCAST}"
123+
f"Got {x.partition} instead...")
123124
y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n,
124125
engine=x.engine, dtype=self.dtype)
125126
y1 = []

pylops_mpi/signalprocessing/Fredholm1.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,9 @@ def __init__(
108108

109109
def _matvec(self, x: DistributedArray) -> DistributedArray:
110110
ncp = get_module(x.engine)
111-
if x.partition is not Partition.BROADCAST:
112-
raise ValueError(f"x should have partition={Partition.BROADCAST}, {x.partition} != {Partition.BROADCAST}")
111+
if x.partition not in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
112+
raise ValueError(f"x should have partition={Partition.BROADCAST},{Partition.UNSAFE_BROADCAST}"
113+
f"Got {x.partition} instead...")
113114
y = DistributedArray(global_shape=self.shape[0], partition=Partition.BROADCAST,
114115
engine=x.engine, dtype=self.dtype)
115116
x = x.local_array.reshape(self.dims).squeeze()
@@ -129,8 +130,9 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
129130

130131
def _rmatvec(self, x: NDArray) -> NDArray:
131132
ncp = get_module(x.engine)
132-
if x.partition is not Partition.BROADCAST:
133-
raise ValueError(f"x should have partition={Partition.BROADCAST}, {x.partition} != {Partition.BROADCAST}")
133+
if x.partition not in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]:
134+
raise ValueError(f"x should have partition={Partition.BROADCAST},{Partition.UNSAFE_BROADCAST}"
135+
f"Got {x.partition} instead...")
134136
y = DistributedArray(global_shape=self.shape[1], partition=Partition.BROADCAST,
135137
engine=x.engine, dtype=self.dtype)
136138
x = x.local_array.reshape(self.dimsd).squeeze()

0 commit comments

Comments
 (0)