@@ -178,10 +178,8 @@ std::optional<DenseElementsAttr>
 TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input,
                                               ArrayRef<int32_t> perms) {
   RankedTensorType oldType = llvm::cast<RankedTensorType>(input.getType());
-  RankedTensorType newType =
-      RankedTensorType::get(applyTOSAPermutation(oldType.getShape(), perms),
-                            oldType.getElementType());
-  size_t rank = oldType.getRank();
+  ArrayRef<int64_t> oldShape = oldType.getShape();
+  int64_t rank = oldType.getRank();
 
   // Asserted by TransposeOp verifier and TOSA disallowing tensor with dimension
   // 0. If not in place, something is very wrong.
@@ -190,65 +188,83 @@ TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input,
     return std::nullopt;
   }
 
-  if (input.isSplat())
+  auto newShape = applyTOSAPermutation(oldShape, perms);
+  RankedTensorType newType =
+      RankedTensorType::get(newShape, oldType.getElementType());
+
+  if (input.isSplat()) {
     return input.reshape(newType);
+  }
+
+  auto rawData = input.getRawData();
+  if (!rawData.data()) {
+    return std::nullopt;
+  }
 
   // The algorithm is approximately as follows:
-  // input: perms, input flat array, input tensor type
-  // (1/2) determine the strides of input/output if
-  // they were strided in row-major order. (3) adjust the strides for the
-  // input to be in the same order of indices as the output is written.
-  // (4) process dimension by dimension. example: perms 2, 0, 1; input
-  // 2x3x4; output 4x2x3 for i ... 4, j ... 2, k ... 3: output[i][j][k] =
-  // input[j][k][i] output[6i + 3j + k] = input[12j + 4k + i] and we adjust
-  // input strides to be as input[i + 12j + 4k] so we may process
-  // layer-by-layer.
-
-  // Step 1/2: Strides for input. We ignore output since row-major and can just
-  // push_back.
-
-  SmallVector<int64_t> originalInputStrides(rank);
-  originalInputStrides[rank - 1] = 1;
-  // index with int64_t to avoid overflow
-  for (int64_t i = rank - 2; i >= 0; i--)
-    originalInputStrides[i] =
-        originalInputStrides[i + 1] * oldType.getDimSize(i + 1);
-
-  // Step 3: Transpose strides of input to be same indexing (i, j, k, ...) as
-  // output which is done in row-major order.
-
-  SmallVector<int64_t> newInputStrides;
-  newInputStrides.reserve(rank);
-  for (int32_t v : perms)
-    newInputStrides.push_back(originalInputStrides[v]);
-
-  // Step 4: Write out the transposed "flat array" dimension by dimension.
-
-  auto inputArray = input.getValues<Attribute>();
-  SmallVector<std::pair<int64_t, int64_t>> boundsAndStrides;
-  for (size_t i = 0; i < rank; i++)
-    boundsAndStrides.push_back({newType.getDimSize(i), newInputStrides[i]});
-
-  SmallVector<Attribute> resultArray;
-  resultArray.reserve(inputArray.size());
-
-  std::function<void(int64_t,
-                     SmallVector<std::pair<int64_t, int64_t>>::const_iterator)>
-      processTransposeDim = [&](auto accumulatedIndex, auto it) {
-        if (it == boundsAndStrides.end()) {
-          resultArray.push_back(inputArray[accumulatedIndex]);
-          return;
-        }
-
-        for (int64_t i = 0; i < it->first; i++) {
-          int64_t j = accumulatedIndex + i * it->second;
-          processTransposeDim(j, it + 1);
-        }
-      };
-
-  processTransposeDim(0, boundsAndStrides.begin());
-
-  return DenseElementsAttr::get(newType, resultArray);
+  // 1. Determine the strides of both input and output tensors in row-major
+  //    order.
+  // 2. Iterate through the output tensor linearly.
+  // 3. For each output position, decompose the linear index into
+  //    multi-dimensional coordinates using output strides.
+  // 4. Use the permutation to map output coordinates to input coordinates and
+  //    calculate the source linear index.
+
+  // Example: perms [2, 0, 1]; input 2x3x4; output 4x2x3.
+  // For output linear index 11: decompose to output[1][1][2]
+  // using output strides [6, 3, 1]. Map to input coordinates using
+  // perms: dim 0 → 2, dim 1 → 0, dim 2 → 1, giving source position
+  // calculated as 1*inputStrides[2] + 1*inputStrides[0] + 2*inputStrides[1]
+  // = 1*1 + 1*12 + 2*4 = 21.
+
+  size_t elementSize = oldType.getElementTypeBitWidth() / 8;
+  int64_t numElements = oldType.getNumElements();
+
+  SmallVector<char> outputBuffer(numElements * elementSize);
+  const char *inputPtr = rawData.data();
+  char *outputPtr = outputBuffer.data();
+
+  auto calculateStrides = [](ArrayRef<int64_t> shape) -> SmallVector<int64_t> {
+    int64_t rank = shape.size();
+    SmallVector<int64_t> strides(rank);
+    strides[rank - 1] = 1;
+    for (int64_t i = rank - 2; i >= 0; --i) {
+      strides[i] = strides[i + 1] * shape[i + 1];
+    }
+    return strides;
+  };
+
+  // Calculate strides for both input and output tensors.
+  SmallVector<int64_t> inputStrides = calculateStrides(oldShape);
+  SmallVector<int64_t> outputStrides = calculateStrides(newShape);
+
+  auto mapCoordinates = [&](int64_t destLinearIndex) -> int64_t {
+    int64_t tempDestIndex = destLinearIndex;
+    int64_t sourceLinearIndex = 0;
+
+    // Decompose the linear destination index into multi-dimensional
+    // coordinates by dividing by the output strides.
+    // Simultaneously map these coordinates through the permutation
+    // to calculate the corresponding source linear index.
+    for (auto j : llvm::seq<int64_t>(rank)) {
+      int64_t destCoord = tempDestIndex / outputStrides[j];
+      tempDestIndex %= outputStrides[j];
+      sourceLinearIndex += destCoord * inputStrides[perms[j]];
+    }
+
+    return sourceLinearIndex;
+  };
+
+  for (auto destLinearIndex : llvm::seq<int64_t>(numElements)) {
+    int64_t sourceLinearIndex = mapCoordinates(destLinearIndex);
+
+    // Copy the element from source to destination using type-agnostic byte
+    // copying.
+    std::memcpy(outputPtr + destLinearIndex * elementSize,
+                inputPtr + sourceLinearIndex * elementSize, elementSize);
+  }
+
+  return DenseElementsAttr::getFromRawBuffer(newType, outputBuffer);
 }
 
 // The SetVector should only contain ConstOp, ReshapeOp, TransposeOp
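For anyone who wants to check the index arithmetic outside of MLIR, here is a minimal standalone sketch of the same destination-to-source mapping. It is an illustration, not the patch itself: the names rowMajorStrides and transposeFlatBuffer are hypothetical, and plain std::vector plus a fixed int32_t element type stand in for the ArrayRef/SmallVector and raw-byte handling above.

#include <cassert>
#include <cstdint>
#include <vector>

// Row-major strides for a shape, mirroring calculateStrides in the patch.
static std::vector<int64_t> rowMajorStrides(const std::vector<int64_t> &shape) {
  std::vector<int64_t> strides(shape.size());
  int64_t running = 1;
  for (int64_t i = static_cast<int64_t>(shape.size()) - 1; i >= 0; --i) {
    strides[i] = running;
    running *= shape[i];
  }
  return strides;
}

// Transpose a flat row-major buffer: output dim j is input dim perms[j].
static std::vector<int32_t>
transposeFlatBuffer(const std::vector<int32_t> &in,
                    const std::vector<int64_t> &shape,
                    const std::vector<int32_t> &perms) {
  std::vector<int64_t> outShape(shape.size());
  for (size_t j = 0; j < perms.size(); ++j)
    outShape[j] = shape[perms[j]];

  std::vector<int64_t> inStrides = rowMajorStrides(shape);
  std::vector<int64_t> outStrides = rowMajorStrides(outShape);

  std::vector<int32_t> out(in.size());
  for (int64_t dest = 0; dest < static_cast<int64_t>(in.size()); ++dest) {
    // Decompose dest into output coordinates and accumulate the source index,
    // exactly as mapCoordinates does in the patch.
    int64_t rem = dest, src = 0;
    for (size_t j = 0; j < outStrides.size(); ++j) {
      src += (rem / outStrides[j]) * inStrides[perms[j]];
      rem %= outStrides[j];
    }
    out[dest] = in[src];
  }
  return out;
}

int main() {
  // The worked example from the patch comment: perms [2, 0, 1], input 2x3x4.
  std::vector<int32_t> in(24);
  for (int32_t i = 0; i < 24; ++i)
    in[i] = i; // each element's value is its own linear index
  auto out = transposeFlatBuffer(in, {2, 3, 4}, {2, 0, 1});
  assert(out[11] == 21); // output linear index 11 reads input index 21
  return 0;
}

Running this confirms the arithmetic in the comment: out[11] == in[21]. It also makes the design trade-off visible: the deleted version materialized every element as an Attribute and recursed per dimension, while the new version walks the output linearly and copies raw bytes, which is why it bails out with std::nullopt when getRawData() is unavailable.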