[mlir][tosa] Fix tosa-reduce-transposes to handle large constants better #148755

Merged Jul 18, 2025 · 1 commit
136 changes: 76 additions & 60 deletions mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
@@ -178,10 +178,8 @@ std::optional<DenseElementsAttr>
TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input,
ArrayRef<int32_t> perms) {
RankedTensorType oldType = llvm::cast<RankedTensorType>(input.getType());
RankedTensorType newType =
RankedTensorType::get(applyTOSAPermutation(oldType.getShape(), perms),
oldType.getElementType());
size_t rank = oldType.getRank();
ArrayRef<int64_t> oldShape = oldType.getShape();
int64_t rank = oldType.getRank();

// Asserted by the TransposeOp verifier and by TOSA disallowing tensors with
// dimension 0. If these invariants do not hold, something is very wrong.
@@ -190,65 +188,83 @@ TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input,
return std::nullopt;
}

if (input.isSplat())
auto newShape = applyTOSAPermutation(oldShape, perms);
RankedTensorType newType =
RankedTensorType::get(newShape, oldType.getElementType());

if (input.isSplat()) {
return input.reshape(newType);
}

auto rawData = input.getRawData();
if (!rawData.data()) {
return std::nullopt;
}

// The algorithm is approximately as follows:
//   input: perms, the input flat array, and the input tensor type
//   (1/2) Determine the strides of the input and output as if each were
//         laid out in row-major order.
//   (3) Reorder the input strides so that they are indexed in the same
//       order (i, j, k, ...) in which the output is written.
//   (4) Process the output dimension by dimension.
//
// Example: perms 2, 0, 1; input 2x3x4; output 4x2x3.
// For i in [0, 4), j in [0, 2), k in [0, 3):
//   output[i][j][k] = input[j][k][i]
//   output[6i + 3j + k] = input[12j + 4k + i]
// Reordering the input strides lets us index input[i + 12j + 4k], so the
// output can be written layer by layer.

// Step 1/2: Compute strides for the input. The output strides are not
// needed explicitly, since the output is written in row-major order and
// elements can simply be appended with push_back.

SmallVector<int64_t> originalInputStrides(rank);
originalInputStrides[rank - 1] = 1;
// index with int64_t to avoid overflow
for (int64_t i = rank - 2; i >= 0; i--)
originalInputStrides[i] =
originalInputStrides[i + 1] * oldType.getDimSize(i + 1);

// Step 3: Permute the input strides so that they use the same indexing
// (i, j, k, ...) as the output, which is written in row-major order.

SmallVector<int64_t> newInputStrides;
newInputStrides.reserve(rank);
for (int32_t v : perms)
newInputStrides.push_back(originalInputStrides[v]);

// Step 4: Write out the transposed "flat array" dimension by dimension.

auto inputArray = input.getValues<Attribute>();
SmallVector<std::pair<int64_t, int64_t>> boundsAndStrides;
for (size_t i = 0; i < rank; i++)
boundsAndStrides.push_back({newType.getDimSize(i), newInputStrides[i]});

SmallVector<Attribute> resultArray;
resultArray.reserve(inputArray.size());

std::function<void(int64_t,
SmallVector<std::pair<int64_t, int64_t>>::const_iterator)>
processTransposeDim = [&](auto accumulatedIndex, auto it) {
if (it == boundsAndStrides.end()) {
resultArray.push_back(inputArray[accumulatedIndex]);
return;
}

for (int64_t i = 0; i < it->first; i++) {
int64_t j = accumulatedIndex + i * it->second;
processTransposeDim(j, it + 1);
}
};

processTransposeDim(0, boundsAndStrides.begin());

return DenseElementsAttr::get(newType, resultArray);
// 1. Determine the strides of both the input and output tensors in
//    row-major order.
// 2. Iterate through the output tensor linearly.
// 3. For each output position, decompose the linear index into
//    multi-dimensional coordinates using the output strides.
// 4. Use the permutation to map the output coordinates to input
//    coordinates and compute the source linear index.

// Example: perms [2, 0, 1]; input 2x3x4; output 4x2x3.
// For output linear index 11: decompose to output[1][1][2] using the
// output strides [6, 3, 1]. Map to input coordinates using perms
// (dim 0 → 2, dim 1 → 0, dim 2 → 1), giving the source position
// 1*inputStrides[2] + 1*inputStrides[0] + 2*inputStrides[1]
// = 1*1 + 1*12 + 2*4 = 21.
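// Cross-check: decomposing source index 21 with the input strides
// [12, 4, 1] gives input[1][2][1], which matches the coordinate mapping
// output[i][j][k] = input[j][k][i], i.e. output[1][1][2] = input[1][2][1].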

size_t elementSize = oldType.getElementTypeBitWidth() / 8;
int64_t numElements = oldType.getNumElements();

SmallVector<char> outputBuffer(numElements * elementSize);
const char *inputPtr = rawData.data();
char *outputPtr = outputBuffer.data();

auto calculateStrides = [](ArrayRef<int64_t> shape) -> SmallVector<int64_t> {
int64_t rank = shape.size();
SmallVector<int64_t> strides(rank);
strides[rank - 1] = 1;
for (int64_t i = rank - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * shape[i + 1];
}
return strides;
};

// Calculate strides for both input and output tensors
SmallVector<int64_t> inputStrides = calculateStrides(oldShape);
SmallVector<int64_t> outputStrides = calculateStrides(newShape);
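// For the example above: inputStrides = [12, 4, 1] for the 2x3x4 input,
// outputStrides = [6, 3, 1] for the 4x2x3 output.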

auto mapCoordinates = [&](int64_t destLinearIndex) -> int64_t {
int64_t tempDestIndex = destLinearIndex;
int64_t sourceLinearIndex = 0;

// Decompose linear destination index into multi-dimensional
// coordinates dividing by output strides.
// Simultaneously map these coordinates through the permutation
// to calculate the corresponding source linear index.
for (auto j : llvm::seq<int64_t>(rank)) {
int64_t destCoord = tempDestIndex / outputStrides[j];
tempDestIndex %= outputStrides[j];
sourceLinearIndex += destCoord * inputStrides[perms[j]];
}

return sourceLinearIndex;
};

for (auto destLinearIndex : llvm::seq<int64_t>(numElements)) {
int64_t sourceLinearIndex = mapCoordinates(destLinearIndex);

// Copy the element from source to destination using type-agnostic byte
// copying.
std::memcpy(outputPtr + destLinearIndex * elementSize,
inputPtr + sourceLinearIndex * elementSize, elementSize);
}

return DenseElementsAttr::getFromRawBuffer(newType, outputBuffer);
}
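To experiment with the index mapping outside the pass, here is a minimal self-contained sketch of the same stride arithmetic. The function name `transposeFlat`, its test values, and the use of plain `std::vector` in place of the MLIR types are assumptions made for this illustration, not part of the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Transpose a row-major flat array `in` of shape `shape` by `perms`,
// where output dimension d corresponds to input dimension perms[d].
std::vector<int64_t> transposeFlat(const std::vector<int64_t> &in,
                                   const std::vector<int64_t> &shape,
                                   const std::vector<int32_t> &perms) {
  int64_t rank = shape.size();

  // Row-major strides of the input.
  std::vector<int64_t> inStrides(rank, 1);
  for (int64_t i = rank - 2; i >= 0; --i)
    inStrides[i] = inStrides[i + 1] * shape[i + 1];

  // Output shape and its row-major strides.
  std::vector<int64_t> outShape(rank), outStrides(rank, 1);
  for (int64_t d = 0; d < rank; ++d)
    outShape[d] = shape[perms[d]];
  for (int64_t i = rank - 2; i >= 0; --i)
    outStrides[i] = outStrides[i + 1] * outShape[i + 1];

  std::vector<int64_t> out(in.size());
  for (int64_t dst = 0; dst < static_cast<int64_t>(in.size()); ++dst) {
    // Decompose dst into output coordinates and accumulate the
    // corresponding input offset, as mapCoordinates does in the patch.
    int64_t tmp = dst, src = 0;
    for (int64_t d = 0; d < rank; ++d) {
      src += (tmp / outStrides[d]) * inStrides[perms[d]];
      tmp %= outStrides[d];
    }
    out[dst] = in[src];
  }
  return out;
}

int main() {
  // The 2x3x4 -> 4x2x3 example from the comments above: element 11 of
  // the output (output[1][1][2]) comes from flat input offset 21.
  std::vector<int64_t> in(24);
  for (int64_t i = 0; i < 24; ++i)
    in[i] = i;
  std::vector<int64_t> out = transposeFlat(in, {2, 3, 4}, {2, 0, 1});
  assert(out[11] == 21);
  return 0;
}
```

Compared with the previous implementation, which materialized every element as an Attribute and collected the results in a SmallVector, the patch's single memcpy per element over the raw byte buffer keeps the working set proportional to the raw constant size, which is presumably what lets the pass handle large constants better.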

// The SetVector should only contain ConstOp, ReshapeOp, TransposeOp