@@ -111,7 +111,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
111
111
if x .partition not in [Partition .BROADCAST , Partition .UNSAFE_BROADCAST ]:
112
112
raise ValueError (f"x should have partition={ Partition .BROADCAST } ,{ Partition .UNSAFE_BROADCAST } "
113
113
f"Got { x .partition } instead..." )
114
- y = DistributedArray (global_shape = self .shape [0 ], partition = Partition . BROADCAST ,
114
+ y = DistributedArray (global_shape = self .shape [0 ], partition = x . partition ,
115
115
engine = x .engine , dtype = self .dtype )
116
116
x = x .local_array .reshape (self .dims ).squeeze ()
117
117
x = x [self .islstart [self .rank ]:self .islend [self .rank ]]
@@ -133,7 +133,7 @@ def _rmatvec(self, x: NDArray) -> NDArray:
133
133
if x .partition not in [Partition .BROADCAST , Partition .UNSAFE_BROADCAST ]:
134
134
raise ValueError (f"x should have partition={ Partition .BROADCAST } ,{ Partition .UNSAFE_BROADCAST } "
135
135
f"Got { x .partition } instead..." )
136
- y = DistributedArray (global_shape = self .shape [1 ], partition = Partition . BROADCAST ,
136
+ y = DistributedArray (global_shape = self .shape [1 ], partition = x . partition ,
137
137
engine = x .engine , dtype = self .dtype )
138
138
x = x .local_array .reshape (self .dimsd ).squeeze ()
139
139
x = x [self .islstart [self .rank ]:self .islend [self .rank ]]
0 commit comments