diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp index 8ebbbc94eb6a2..db7a3c671dedc 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp @@ -178,10 +178,8 @@ std::optional<DenseElementsAttr> TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input, ArrayRef<int32_t> perms) { RankedTensorType oldType = llvm::cast<RankedTensorType>(input.getType()); - RankedTensorType newType = - RankedTensorType::get(applyTOSAPermutation(oldType.getShape(), perms), - oldType.getElementType()); - size_t rank = oldType.getRank(); + ArrayRef<int64_t> oldShape = oldType.getShape(); + int64_t rank = oldType.getRank(); // Asserted by TransposeOp verifier and TOSA disallowing tensor with dimension // 0. If not in place, something is very wrong. @@ -190,65 +188,83 @@ TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input, return std::nullopt; } - if (input.isSplat()) + auto newShape = applyTOSAPermutation(oldShape, perms); + RankedTensorType newType = + RankedTensorType::get(newShape, oldType.getElementType()); + + if (input.isSplat()) { return input.reshape(newType); + } + + auto rawData = input.getRawData(); + if (!rawData.data()) { + return std::nullopt; + } // The algorithm is approximately as follows: - // input: perms, input flat array, input tensor type - // (1/2) determine the strides of input/output if - // they were strided in row-major order. (3) adjust the strides for the - // input to be in the same order of indices as the output is written. - // (4) process dimension by dimension. example: perms 2, 0, 1; input - // 2x3x4; output 4x2x3 for i ... 4, j ... 2, k ... 3: output[i][j][k] = - // input[j][k][i] output[6i + 3j + k] = input[12j + 4k + i] and we adjust - // input strides to be as input[i + 12j + 4k] so we may process - // layer-by-layer. - - // Step 1/2: Strides for input. We ignore output since row-major and can just - // push_back.
- - SmallVector<int64_t> originalInputStrides(rank); - originalInputStrides[rank - 1] = 1; - // index with int64_t to avoid overflow - for (int64_t i = rank - 2; i >= 0; i--) - originalInputStrides[i] = - originalInputStrides[i + 1] * oldType.getDimSize(i + 1); - - // Step 3: Transpose strides of input to be same indexing (i, j, k, ...) as - // output which is done in row-major order. - - SmallVector<int64_t> newInputStrides; - newInputStrides.reserve(rank); - for (int32_t v : perms) - newInputStrides.push_back(originalInputStrides[v]); - - // Step 4: Write out the transposed "flat array" dimension by dimension. - - auto inputArray = input.getValues<Attribute>(); - SmallVector<std::pair<int64_t, int64_t>> boundsAndStrides; - for (size_t i = 0; i < rank; i++) - boundsAndStrides.push_back({newType.getDimSize(i), newInputStrides[i]}); - - SmallVector<Attribute> resultArray; - resultArray.reserve(inputArray.size()); - - std::function<void(int64_t, SmallVector<std::pair<int64_t, int64_t>>::const_iterator)> - processTransposeDim = [&](auto accumulatedIndex, auto it) { - if (it == boundsAndStrides.end()) { - resultArray.push_back(inputArray[accumulatedIndex]); - return; - } - - for (int64_t i = 0; i < it->first; i++) { - int64_t j = accumulatedIndex + i * it->second; - processTransposeDim(j, it + 1); - } - }; - - processTransposeDim(0, boundsAndStrides.begin()); - - return DenseElementsAttr::get(newType, resultArray); + // 1. Determine the strides of both input and output tensors in row-major + // order + // 2. Iterate through the output tensor linearly. + // 3. For each output position, decompose the linear index into + // multi-dimensional coordinates using output strides. + // 4. Use the permutation to map output coordinates to input coordinates and + // calculate the source linear index. + + // Example: perms [2, 0, 1]; input 2x3x4; output 4x2x3 + // for output linear index 11: decompose to output[1][1][2] + // using output strides [6,3,1].
Map to input coordinates using + // perms: dim 0→2, dim 1→0, dim 2→1, giving source position + // calculated as 1*inputStrides[2] + 1*inputStrides[0] + 2*inputStrides[1] + // = 1*1 + 1*12 + 2*4 = 21 + + size_t elementSize = oldType.getElementTypeBitWidth() / 8; + int64_t numElements = oldType.getNumElements(); + + SmallVector<char> outputBuffer(numElements * elementSize); + const char *inputPtr = rawData.data(); + char *outputPtr = outputBuffer.data(); + + auto calculateStrides = [](ArrayRef<int64_t> shape) -> SmallVector<int64_t> { + int64_t rank = shape.size(); + SmallVector<int64_t> strides(rank); + strides[rank - 1] = 1; + for (int64_t i = rank - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * shape[i + 1]; + } + return strides; + }; + + // Calculate strides for both input and output tensors + SmallVector<int64_t> inputStrides = calculateStrides(oldShape); + SmallVector<int64_t> outputStrides = calculateStrides(newShape); + + auto mapCoordinates = [&](int64_t destLinearIndex) -> int64_t { + int64_t tempDestIndex = destLinearIndex; + int64_t sourceLinearIndex = 0; + + // Decompose linear destination index into multi-dimensional + // coordinates dividing by output strides. + // Simultaneously map these coordinates through the permutation + // to calculate the corresponding source linear index. + for (auto j : llvm::seq(rank)) { + int64_t destCoord = tempDestIndex / outputStrides[j]; + tempDestIndex %= outputStrides[j]; + sourceLinearIndex += destCoord * inputStrides[perms[j]]; + } + + return sourceLinearIndex; + }; + + for (auto destLinearIndex : llvm::seq(numElements)) { + int64_t sourceLinearIndex = mapCoordinates(destLinearIndex); + + // Copy the element from source to destination using type-agnostic byte + // copying. + std::memcpy(outputPtr + destLinearIndex * elementSize, + inputPtr + sourceLinearIndex * elementSize, elementSize); + } + + return DenseElementsAttr::getFromRawBuffer(newType, outputBuffer); } // The SetVector should only contain ConstOp, ReshapeOp, TransposeOp