@@ -108,8 +108,9 @@ def __init__(
108
108
109
109
def _matvec (self , x : DistributedArray ) -> DistributedArray :
110
110
ncp = get_module (x .engine )
111
- if x .partition is not Partition .BROADCAST :
112
- raise ValueError (f"x should have partition={ Partition .BROADCAST } , { x .partition } != { Partition .BROADCAST } " )
111
+ if x .partition not in [Partition .BROADCAST , Partition .UNSAFE_BROADCAST ]:
112
+ raise ValueError (f"x should have partition={ Partition .BROADCAST } ,{ Partition .UNSAFE_BROADCAST } "
113
+ f"Got { x .partition } instead..." )
113
114
y = DistributedArray (global_shape = self .shape [0 ], partition = Partition .BROADCAST ,
114
115
engine = x .engine , dtype = self .dtype )
115
116
x = x .local_array .reshape (self .dims ).squeeze ()
@@ -129,8 +130,9 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
129
130
130
131
def _rmatvec (self , x : NDArray ) -> NDArray :
131
132
ncp = get_module (x .engine )
132
- if x .partition is not Partition .BROADCAST :
133
- raise ValueError (f"x should have partition={ Partition .BROADCAST } , { x .partition } != { Partition .BROADCAST } " )
133
+ if x .partition not in [Partition .BROADCAST , Partition .UNSAFE_BROADCAST ]:
134
+ raise ValueError (f"x should have partition={ Partition .BROADCAST } ,{ Partition .UNSAFE_BROADCAST } "
135
+ f"Got { x .partition } instead..." )
134
136
y = DistributedArray (global_shape = self .shape [1 ], partition = Partition .BROADCAST ,
135
137
engine = x .engine , dtype = self .dtype )
136
138
x = x .local_array .reshape (self .dimsd ).squeeze ()
0 commit comments