Commit b846d8c

[mlir][tosa] Fix tosa-reduce-transposes to handle large constants better (#148755)
This change addresses a performance issue in the **--tosa-reduce-transposes** implementation by working directly with the raw tensor data, eliminating the costly intermediate attributes that caused the bottleneck.
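For context: the old path materialized one `Attribute` object per element and rebuilt the result from a vector of those objects, while the new path permutes the underlying bytes directly. Below is a minimal sketch contrasting the two access patterns; the helper names and the identity-copy framing are illustrative, and only the `DenseElementsAttr` calls come from the patch itself:

```cpp
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Before-style copy: materializes one Attribute handle per element.
static DenseElementsAttr copyViaAttributes(DenseElementsAttr input) {
  SmallVector<Attribute> elements =
      llvm::to_vector(input.getValues<Attribute>());
  return DenseElementsAttr::get(input.getType(), elements);
}

// After-style copy: touches only the raw byte buffer, no per-element boxing.
static DenseElementsAttr copyViaRawBytes(DenseElementsAttr input) {
  ArrayRef<char> raw = input.getRawData();
  SmallVector<char> buffer(raw.begin(), raw.end());
  return DenseElementsAttr::getFromRawBuffer(input.getType(), buffer);
}

int main() {
  MLIRContext ctx;
  auto type = RankedTensorType::get({2, 3, 4}, IntegerType::get(&ctx, 32));
  SmallVector<int32_t> data(24);
  for (int32_t i = 0; i < 24; ++i)
    data[i] = i;
  auto attr = DenseElementsAttr::get(type, ArrayRef<int32_t>(data));
  (void)copyViaAttributes(attr); // allocates 24 Attribute handles
  (void)copyViaRawBytes(attr);   // copies 96 bytes
}
```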
1 parent 4dc6dfd commit b846d8c

File tree

1 file changed: +76 -60 lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp

Lines changed: 76 additions & 60 deletions
@@ -178,10 +178,8 @@ std::optional<DenseElementsAttr>
 TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input,
                                               ArrayRef<int32_t> perms) {
   RankedTensorType oldType = llvm::cast<RankedTensorType>(input.getType());
-  RankedTensorType newType =
-      RankedTensorType::get(applyTOSAPermutation(oldType.getShape(), perms),
-                            oldType.getElementType());
-  size_t rank = oldType.getRank();
+  ArrayRef<int64_t> oldShape = oldType.getShape();
+  int64_t rank = oldType.getRank();
 
   // Asserted by TransposeOp verifier and TOSA disallowing tensor with dimension
   // 0. If not in place, something is very wrong.
@@ -190,65 +188,83 @@ TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input,
     return std::nullopt;
   }
 
-  if (input.isSplat())
+  auto newShape = applyTOSAPermutation(oldShape, perms);
+  RankedTensorType newType =
+      RankedTensorType::get(newShape, oldType.getElementType());
+
+  if (input.isSplat()) {
     return input.reshape(newType);
+  }
+
+  auto rawData = input.getRawData();
+  if (!rawData.data()) {
+    return std::nullopt;
+  }
 
   // The algorithm is approximately as follows:
-  // input: perms, input flat array, input tensor type
-  // (1/2) determine the strides of input/output if
-  // they were strided in row-major order. (3) adjust the strides for the
-  // input to be in the same order of indices as the output is written.
-  // (4) process dimension by dimension. example: perms 2, 0, 1; input
-  // 2x3x4; output 4x2x3 for i ... 4, j ... 2, k ... 3: output[i][j][k] =
-  // input[j][k][i] output[6i + 3j + k] = input[12j + 4k + i] and we adjust
-  // input strides to be as input[i + 12j + 4k] so we may process
-  // layer-by-layer.
-
-  // Step 1/2: Strides for input. We ignore output since row-major and can just
-  // push_back.
-
-  SmallVector<int64_t> originalInputStrides(rank);
-  originalInputStrides[rank - 1] = 1;
-  // index with int64_t to avoid overflow
-  for (int64_t i = rank - 2; i >= 0; i--)
-    originalInputStrides[i] =
-        originalInputStrides[i + 1] * oldType.getDimSize(i + 1);
-
-  // Step 3: Transpose strides of input to be same indexing (i, j, k, ...) as
-  // output which is done in row-major order.
-
-  SmallVector<int64_t> newInputStrides;
-  newInputStrides.reserve(rank);
-  for (int32_t v : perms)
-    newInputStrides.push_back(originalInputStrides[v]);
-
-  // Step 4: Write out the transposed "flat array" dimension by dimension.
-
-  auto inputArray = input.getValues<Attribute>();
-  SmallVector<std::pair<int64_t, int64_t>> boundsAndStrides;
-  for (size_t i = 0; i < rank; i++)
-    boundsAndStrides.push_back({newType.getDimSize(i), newInputStrides[i]});
-
-  SmallVector<Attribute> resultArray;
-  resultArray.reserve(inputArray.size());
-
-  std::function<void(int64_t,
-                     SmallVector<std::pair<int64_t, int64_t>>::const_iterator)>
-      processTransposeDim = [&](auto accumulatedIndex, auto it) {
-        if (it == boundsAndStrides.end()) {
-          resultArray.push_back(inputArray[accumulatedIndex]);
-          return;
-        }
-
-        for (int64_t i = 0; i < it->first; i++) {
-          int64_t j = accumulatedIndex + i * it->second;
-          processTransposeDim(j, it + 1);
-        }
-      };
-
-  processTransposeDim(0, boundsAndStrides.begin());
-
-  return DenseElementsAttr::get(newType, resultArray);
+  // 1. Determine the strides of both input and output tensors in row-major
+  // order
+  // 2. Iterate through the output tensor linearly.
+  // 3. For each output position, decompose the linear index into
+  // multi-dimensional coordinates using output strides.
+  // 4. Use the permutation to map output coordinates to input coordinates and
+  // calculate the source linear index.
+
+  // Example: perms [2, 0, 1]; input 2x3x4; output 4x2x3
+  // for output linear index 11: decompose to output[1][1][2]
+  // using output strides [6,3,1]. Map to input coordinates using
+  // perms: dim 0→2, dim 1→0, dim 2→1, giving source position
+  // calculated as 1*inputStrides[2] + 1*inputStrides[0] + 2*inputStrides[1]
+  // = 1*1 + 1*12 + 2*4 = 21
+
+  size_t elementSize = oldType.getElementTypeBitWidth() / 8;
+  int64_t numElements = oldType.getNumElements();
+
+  SmallVector<char> outputBuffer(numElements * elementSize);
+  const char *inputPtr = rawData.data();
+  char *outputPtr = outputBuffer.data();
+
+  auto calculateStrides = [](ArrayRef<int64_t> shape) -> SmallVector<int64_t> {
+    int64_t rank = shape.size();
+    SmallVector<int64_t> strides(rank);
+    strides[rank - 1] = 1;
+    for (int64_t i = rank - 2; i >= 0; --i) {
+      strides[i] = strides[i + 1] * shape[i + 1];
+    }
+    return strides;
+  };
+
+  // Calculate strides for both input and output tensors
+  SmallVector<int64_t> inputStrides = calculateStrides(oldShape);
+  SmallVector<int64_t> outputStrides = calculateStrides(newShape);
+
+  auto mapCoordinates = [&](int64_t destLinearIndex) -> int64_t {
+    int64_t tempDestIndex = destLinearIndex;
+    int64_t sourceLinearIndex = 0;
+
+    // Decompose linear destination index into multi-dimensional
+    // coordinates dividing by output strides.
+    // Simultaneously map these coordinates through the permutation
+    // to calculate the corresponding source linear index.
+    for (auto j : llvm::seq<int64_t>(rank)) {
+      int64_t destCoord = tempDestIndex / outputStrides[j];
+      tempDestIndex %= outputStrides[j];
+      sourceLinearIndex += destCoord * inputStrides[perms[j]];
+    }
+
+    return sourceLinearIndex;
+  };
+
+  for (auto destLinearIndex : llvm::seq<int64_t>(numElements)) {
+    int64_t sourceLinearIndex = mapCoordinates(destLinearIndex);
+
+    // Copy the element from source to destination using type-agnostic byte
+    // copying.
+    std::memcpy(outputPtr + destLinearIndex * elementSize,
+                inputPtr + sourceLinearIndex * elementSize, elementSize);
+  }
+
+  return DenseElementsAttr::getFromRawBuffer(newType, outputBuffer);
 }
 
 // The SetVector should only contain ConstOp, ReshapeOp, TransposeOp
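As a sanity check on the index arithmetic, the example in the new comment can be reproduced outside MLIR. A minimal standalone sketch (plain C++, with the shapes and permutation hard-coded from the comment) confirming that output linear index 11 maps to input linear index 21:

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  using std::int64_t;
  // Shapes and permutation from the comment: input 2x3x4, perms [2, 0, 1].
  std::array<int64_t, 3> oldShape{2, 3, 4};
  std::array<int64_t, 3> perms{2, 0, 1};
  std::array<int64_t, 3> newShape{}, inputStrides{}, outputStrides{};
  for (int j = 0; j < 3; ++j)
    newShape[j] = oldShape[perms[j]]; // output dim j reads input dim perms[j]

  // Row-major strides, mirroring calculateStrides in the patch.
  inputStrides[2] = outputStrides[2] = 1;
  for (int i = 1; i >= 0; --i) {
    inputStrides[i] = inputStrides[i + 1] * oldShape[i + 1];   // {12, 4, 1}
    outputStrides[i] = outputStrides[i + 1] * newShape[i + 1]; // {6, 3, 1}
  }

  // Mirror of mapCoordinates: decompose the destination index digit by
  // digit and accumulate the permuted source index in the same pass.
  int64_t destLinearIndex = 11;
  int64_t tempDestIndex = destLinearIndex;
  int64_t sourceLinearIndex = 0;
  for (int j = 0; j < 3; ++j) {
    int64_t destCoord = tempDestIndex / outputStrides[j];
    tempDestIndex %= outputStrides[j];
    sourceLinearIndex += destCoord * inputStrides[perms[j]];
  }

  // output[1][1][2] -> 1*1 + 1*12 + 2*4 = 21, matching the comment.
  std::printf("dest %lld -> source %lld\n",
              static_cast<long long>(destLinearIndex),
              static_cast<long long>(sourceLinearIndex));
  assert(sourceLinearIndex == 21);
}
```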
