diff --git a/examples/BuddyNext/makefile b/examples/BuddyNext/makefile index ce30c81b2b..3e65c3450a 100644 --- a/examples/BuddyNext/makefile +++ b/examples/BuddyNext/makefile @@ -316,15 +316,15 @@ next-sgemm-run: next-transpose-lower: @${MLIR_OPT} ./next-transpose.mlir \ - -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | \ - ${MLIR_OPT} \ - -arith-expand \ + -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | \ + ${MLIR_OPT} \ + -arith-expand \ -eliminate-empty-tensors \ -empty-tensor-to-alloc-tensor \ - -one-shot-bufferize \ - -func-bufferize \ - -arith-bufferize \ - -o log.mlir + -one-shot-bufferize \ + -func-bufferize \ + -arith-bufferize \ + -o log.mlir next-transpose-run: @${MLIR_OPT} ./next-transpose.mlir \ @@ -334,7 +334,7 @@ next-transpose-run: -eliminate-empty-tensors \ -empty-tensor-to-alloc-tensor \ -one-shot-bufferize \ - -func-bufferize \ + -func-bufferize \ -arith-bufferize \ -convert-linalg-to-affine-loops \ -affine-loop-fusion \ @@ -353,21 +353,71 @@ next-transpose-run: -convert-math-to-libm \ -convert-func-to-llvm \ -reconcile-unrealized-casts | \ - ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + ${MLIR_CPU_RUNNER} -O3 -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} \ + -shared-libs=${MLIR_C_RUNNER_UTILS} + +next-transpose-vectorization-lower: + @${BUDDY_OPT} ./next-transpose.mlir \ + -transpose-vectorization \ + -arith-expand \ + -eliminate-empty-tensors \ + -empty-tensor-to-alloc-tensor \ + -one-shot-bufferize \ + -func-bufferize \ + -arith-bufferize \ + -o log.mlir + +next-transpose-vectorization-run: + @${BUDDY_OPT} ./log.mlir \ + -transpose-vectorization \ + -arith-expand \ + -eliminate-empty-tensors \ + -empty-tensor-to-alloc-tensor \ + -one-shot-bufferize \ + -func-bufferize \ + -arith-bufferize \ + -convert-linalg-to-affine-loops \ + -affine-loop-fusion \ + -lower-affine \ + -convert-scf-to-openmp \ + -convert-vector-to-scf \ + -expand-strided-metadata \ + -convert-vector-to-llvm \ + -memref-expand \ + -arith-expand \ + -arith-bufferize \ + -buffer-deallocation \ + -finalizing-bufferize \ + -convert-linalg-to-affine-loops \ + -convert-arith-to-llvm \ + -finalize-memref-to-llvm \ + -convert-scf-to-cf \ + -convert-openmp-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-math-to-libm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} -O3 -e main -entry-point-result=void \ -shared-libs=${MLIR_RUNNER_UTILS} \ -shared-libs=${MLIR_C_RUNNER_UTILS} next-transpose-vec-manual-run: - @${MLIR_OPT} ./next-transpose-vec-manual.mlir \ + @${BUDDY_OPT} ./next-transpose-vec-manual.mlir \ -convert-linalg-to-affine-loops \ -affine-loop-fusion \ -lower-affine \ - -convert-scf-to-openmp \ + -convert-scf-to-openmp \ -convert-vector-to-scf \ -expand-strided-metadata \ -convert-vector-to-llvm \ -memref-expand \ -arith-expand \ + -arith-bufferize \ + -buffer-deallocation \ + -finalizing-bufferize \ + -convert-linalg-to-affine-loops \ -convert-arith-to-llvm \ -finalize-memref-to-llvm \ -convert-scf-to-cf \ @@ -377,7 +427,7 @@ next-transpose-vec-manual-run: -convert-math-to-libm \ -convert-func-to-llvm \ -reconcile-unrealized-casts | \ - ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + ${MLIR_CPU_RUNNER} -O3 -e main -entry-point-result=void \ -shared-libs=${MLIR_RUNNER_UTILS} \ 
-shared-libs=${MLIR_C_RUNNER_UTILS} diff --git a/examples/BuddyNext/next-transpose-vec-manual.mlir b/examples/BuddyNext/next-transpose-vec-manual.mlir index ccf5c7b7e4..c29580ed47 100644 --- a/examples/BuddyNext/next-transpose-vec-manual.mlir +++ b/examples/BuddyNext/next-transpose-vec-manual.mlir @@ -26,25 +26,47 @@ // RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ // RUN: | FileCheck %s +#map = affine_map<(d0) -> (d0)> module { memref.global "private" constant @__constant_1x32x40x128xf32 : memref<1x32x40x128xf32> = dense<3.000000e+00> {alignment = 64 : i64} func.func private @rtclock() -> f64 func.func private @printMemrefF32(memref<*xf32>) - func.func @kernel(%arg0: memref<1x32x40x128xf32>) { - %0 = call @rtclock() : () -> f64 - %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x40x32x128xf32> - affine.for %arg1 = 0 to 1 { - affine.for %arg2 = 0 to 40 { - affine.for %arg3 = 0 to 32 { - affine.for %arg4 = 0 to 128 step 64 { - %3 = vector.load %arg0[%arg1, %arg3, %arg2, %arg4] : memref<1x32x40x128xf32>, vector<64xf32> - vector.store %3, %alloc[%arg1, %arg2, %arg3, %arg4] : memref<1x40x32x128xf32>, vector<64xf32> + func.func @kernel(%arg0: memref<1x32x40x128xf32>) -> (){ + %alloc = memref.alloc() : memref<1x40x32x128xf32> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c64 = arith.constant 64 : index + %cst = arith.constant 0.000000e+00 : f32 + %3 = vector.splat %cst : vector<64xf32> + %dim = memref.dim %arg0, %c0 : memref<1x32x40x128xf32> + %dim_0 = memref.dim %arg0, %c1 : memref<1x32x40x128xf32> + %dim_1 = memref.dim %arg0, %c2 : memref<1x32x40x128xf32> + %dim_2 = memref.dim %arg0, %c3 : memref<1x32x40x128xf32> + + %4 = arith.subi %dim_2, %c64 : index + %5 = arith.addi %4, %c1 : index + %t0 = call @rtclock() : () -> f64 + affine.for %arg1 = #map(%c0) to #map(%dim) { + affine.for %arg2 = #map(%c0) to #map(%dim_1) { + affine.for %arg3 = #map(%c0) to #map(%dim_0) { + %8 = scf.for %arg4 = %c0 to %5 step %c64 iter_args(%arg5 = %c0) -> (index) { + %12 = vector.load %arg0[%arg1, %arg3, %arg2, %arg4] : memref<1x32x40x128xf32>, vector<64xf32> + vector.store %12, %alloc[%arg1, %arg2, %arg3, %arg4] : memref<1x40x32x128xf32>, vector<64xf32> + %13 = arith.addi %arg4, %c64 : index + scf.yield %13 : index } + %9 = arith.subi %dim_2, %8 : index + %10 = vector.create_mask %9 : vector<64xi1> + %11 = vector.maskedload %arg0[%arg1, %arg3, %arg2, %8], %10, %3 : memref<1x32x40x128xf32>, vector<64xi1>, vector<64xf32> into vector<64xf32> + vector.maskedstore %alloc[%arg1, %arg2, %arg3, %8], %10, %11 : memref<1x40x32x128xf32>, vector<64xi1>, vector<64xf32> } } } - %1 = call @rtclock() : () -> f64 - %2 = arith.subf %1, %0 : f64 + %t1 = call @rtclock() : () -> f64 + %20 = arith.subf %t1, %t0 : f64 + %cast = memref.cast %alloc : memref<1x40x32x128xf32> to memref<*xf32> // All the elements of the MemRef are the same, @@ -56,13 +78,14 @@ module { // CHECK-SAME: [3{{(, 3)*}}], call @printMemrefF32(%cast) : (memref<*xf32>) -> () - vector.print %2 : f64 - return + + vector.print %20 : f64 + return } func.func @main() { %0 = memref.get_global @__constant_1x32x40x128xf32 : memref<1x32x40x128xf32> call @kernel(%0) : (memref<1x32x40x128xf32>) -> () + return } } - diff --git a/examples/BuddyNext/next-transpose.mlir b/examples/BuddyNext/next-transpose.mlir index 1b2bd93d62..ddb9af25d6 100644 --- a/examples/BuddyNext/next-transpose.mlir +++ b/examples/BuddyNext/next-transpose.mlir @@ -35,7 +35,7 @@ 
 func.func private @rtclock() -> f64
 func.func private @printMemrefF32(%ptr : tensor<*xf32>)
 
-func.func @kernel(%t0 : tensor<1x32x40x128xf32>) {
+func.func @kernel(%t0 : tensor<1x32x40x128xf32>) -> (){
   %t_start = call @rtclock() : () -> f64
 
   %idx = "tosa.const"() <{value = dense<[0, 2, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
@@ -43,7 +43,7 @@ func.func @kernel(%t0 : tensor<1x32x40x128xf32>) {
 
   %t_end = call @rtclock() : () -> f64
   %time = arith.subf %t_end, %t_start : f64
-  
+
   %tensor_unranked = tensor.cast %t1 : tensor<1x40x32x128xf32> to tensor<*xf32>
 
   // All the elements of the MemRef are the same,
@@ -56,6 +56,7 @@ func.func @kernel(%t0 : tensor<1x32x40x128xf32>) {
 
   // Print results.
   call @printMemrefF32(%tensor_unranked) : (tensor<*xf32>) -> ()
+
   // Print timings.
   vector.print %time : f64
 
@@ -65,6 +66,5 @@ func.func @kernel(%t0 : tensor<1x32x40x128xf32>) {
 func.func @main() {
   %c0 = arith.constant dense<3.0> : tensor<1x32x40x128xf32>
   call @kernel(%c0) : (tensor<1x32x40x128xf32>) -> ()
-
   return
 }
diff --git a/midend/lib/Conversion/ConvVectorization/CMakeLists.txt b/midend/lib/Conversion/ConvVectorization/CMakeLists.txt
index d4cc3ec987..9f81b1c3b3 100644
--- a/midend/lib/Conversion/ConvVectorization/CMakeLists.txt
+++ b/midend/lib/Conversion/ConvVectorization/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_library(CBConvVectorization
   GEMMPointwiseConv2DNhwcHwcf.cpp
   PoolingVectorization.cpp
   PoolingNhwcMaxVectorization.cpp
+  TransposeVectorization.cpp
 
   LINK_LIBS PUBLIC
   BuddyUtils
diff --git a/midend/lib/Conversion/ConvVectorization/TransposeVectorization.cpp b/midend/lib/Conversion/ConvVectorization/TransposeVectorization.cpp
new file mode 100644
index 0000000000..1666c32d41
--- /dev/null
+++ b/midend/lib/Conversion/ConvVectorization/TransposeVectorization.cpp
@@ -0,0 +1,250 @@
+//===--------TransposeVectorization.cpp-------------------------------===//
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the transpose vectorization.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/ValueRange.h"
+#include "llvm/ADT/ArrayRef.h"
+#include <mlir/Dialect/Affine/IR/AffineOps.h>
+#include <mlir/Dialect/Func/IR/FuncOps.h>
+#include <mlir/Dialect/SCF/IR/SCF.h>
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/BuiltinTypes.h>
+#include <mlir/Pass/Pass.h>
+#include <mlir/Transforms/DialectConversion.h>
+
+#include "Utils/Utils.h"
+
+using namespace mlir;
+using namespace vector;
+
+//===----------------------------------------------------------------------===//
+// Rewrite Pattern
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+class TransposeVectorizationPattern : public ConversionPattern {
+public:
+  explicit TransposeVectorizationPattern(MLIRContext *context,
+                                         int64_t stripParam)
+      : ConversionPattern(tosa::TransposeOp::getOperationName(), 1, context) {
+    strip = stripParam;
+  }
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Get perms.
+    Value perms = op->getOperand(1);
+    TensorType permsTy = dyn_cast<TensorType>(perms.getType());
+    if (!permsTy || !permsTy.getElementType().isInteger(32) ||
+        permsTy.getShape() != llvm::ArrayRef<int64_t>({4})) {
+      return failure();
+    }
+
+    // Check if perms is {0, 2, 1, 3}.
+    // DenseElementsAttr permsAttr =
+    //     perms.getDefiningOp<tosa::ConstOp>().getValue()
+    //         .cast<DenseElementsAttr>();
+    // SmallVector<int32_t> expectedPerms = {0, 2, 1, 3};
+    // if (!std::equal(permsAttr.getValues<int32_t>().begin(),
+    //                 permsAttr.getValues<int32_t>().end(),
+    //                 expectedPerms.begin())) {
+    //   return failure();
+    // }
+
+    auto loc = op->getLoc();
+    auto ctx = op->getContext();
+
+    // Get i1 as the element type for the mask vector.
+    IntegerType i1 = IntegerType::get(ctx, 1);
+    VectorType vectorMaskTy = mlir::VectorType::get({strip}, i1);
+
+    // Get input.
+    Value input = op->getOperand(0);
+
+    // Swap dims 1 and 2 of the input shape to build the output shape.
+    ShapedType inputTy = input.getType().cast<ShapedType>();
+    SmallVector<int64_t> inputNumVec;
+    for (auto in : inputTy.getShape())
+      inputNumVec.push_back(in);
+
+    auto tmpValue = inputNumVec[1];
+    inputNumVec[1] = inputNumVec[2];
+    inputNumVec[2] = tmpValue;
+    llvm::ArrayRef<int64_t> outputNum(inputNumVec);
+    // Get the element type of the input.
+    Type elementTy = input.getType().cast<ShapedType>().getElementType();
+    VectorType vectorTy = mlir::VectorType::get({strip}, elementTy);
+    // Bufferize the input tensor and allocate the transposed output buffer.
+    Value inputMem = rewriter.create<bufferization::ToMemrefOp>(
+        loc, MemRefType::get(inputTy.getShape(), elementTy), input);
+    Value alloc = rewriter.create<memref::AllocOp>(
+        loc, MemRefType::get(outputNum, elementTy));
+
+    // Get constants.
+    const Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    const Value c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    const Value c2 = rewriter.create<arith::ConstantIndexOp>(loc, 2);
+    const Value c3 = rewriter.create<arith::ConstantIndexOp>(loc, 3);
+    const Value vlStep = rewriter.create<arith::ConstantIndexOp>(loc, strip);
+    const Value zero =
+        buddy::insertZeroConstantOp(ctx, rewriter, loc, elementTy);
+
+    // Create pass through vector.
+    Value passThroughVec =
+        rewriter.create<vector::SplatOp>(loc, vectorTy, zero);
+
+    Value inputDim0 = rewriter.create<memref::DimOp>(loc, inputMem, c0);
+    Value inputDim1 = rewriter.create<memref::DimOp>(loc, inputMem, c1);
+    Value inputDim2 = rewriter.create<memref::DimOp>(loc, inputMem, c2);
+    Value inputDim3 = rewriter.create<memref::DimOp>(loc, inputMem, c3);
+
+    // Calculate the upper bound for vectorized processing:
+    // - Subtract `vlStep` so the vector body never reads past the end of the
+    //   innermost dimension.
+    // - Add 1 so the final iteration still runs when the workload length is
+    //   divisible by the vector size.
+    // For example, with inputDim3 = 128 and vlStep = 64 the bound is 65, the
+    // vector body covers [0, 64) and [64, 128), and the masked tail is empty;
+    // with inputDim3 = 100 the bound is 37, the body covers [0, 64), and the
+    // tail mask handles the remaining 36 elements.
+    Value upperBoundTmp =
+        rewriter.create<arith::SubIOp>(loc, inputDim3, vlStep);
+    Value upperBound = rewriter.create<arith::AddIOp>(loc, upperBoundTmp, c1);
+
+    SmallVector<Value, 3> lowerBounds(3, c0);
+    SmallVector<Value, 3> upperBounds{inputDim0, inputDim2, inputDim1};
+    SmallVector<int64_t, 3> steps(3, /*Value=*/1);
+    affine::buildAffineLoopNest(
+        rewriter, loc, lowerBounds, upperBounds, steps,
+        [&](OpBuilder &builder, Location loc, ValueRange ivs) {
+          // Create strip mining loop.
+          auto iterIdx = builder.create<scf::ForOp>(
+              loc, c0, upperBound, /*Step=*/vlStep, ValueRange{c0},
+              [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
+                  ValueRange itrArgs) {
+                Value inputVector = nestedBuilder.create<vector::LoadOp>(
+                    loc, vectorTy, inputMem,
+                    ValueRange{ivs[0], ivs[2], ivs[1], iv});
+                nestedBuilder.create<vector::StoreOp>(
+                    loc, inputVector, alloc,
+                    ValueRange{ivs[0], ivs[1], ivs[2], iv});
+                Value idx =
+                    nestedBuilder.create<arith::AddIOp>(loc, iv, vlStep);
+                nestedBuilder.create<scf::YieldOp>(loc, idx);
+              });
+          // Compute the tail size and process the remaining elements
+          // using masked vector operations.
+          Value idx = iterIdx.getResult(0);
+          Value tailSize =
+              builder.create<arith::SubIOp>(loc, inputDim3, idx);
+          Value tailMask = builder.create<vector::CreateMaskOp>(
+              loc, vectorMaskTy, tailSize);
+          // Masked load input.
+          Value maskedOutputVec = builder.create<vector::MaskedLoadOp>(
+              loc, vectorTy, inputMem,
+              ValueRange{ivs[0], ivs[2], ivs[1], idx}, tailMask,
+              passThroughVec);
+          // Masked store the result to output.
+          builder.create<vector::MaskedStoreOp>(
+              loc, alloc, ValueRange{ivs[0], ivs[1], ivs[2], idx}, tailMask,
+              maskedOutputVec);
+        });
+    Value output = rewriter.create<bufferization::ToTensorOp>(
+        loc,
+        input.getType().cast<ShapedType>().cloneWith(outputNum, elementTy),
+        alloc, /*restrict=*/true);
+
+    // Remove the original transpose operation.
+    rewriter.replaceOp(op, output);
+    return success();
+  }
+
+private:
+  int64_t strip;
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// TransposeVectorizationPass
+//===----------------------------------------------------------------------===//
+
+/// This is a partial lowering of tosa transpose operations to a mixture of
+/// Arith + Vector operations.
+namespace {
+class TransposeVectorizationPass
+    : public PassWrapper<TransposeVectorizationPass,
+                         OperationPass<ModuleOp>> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TransposeVectorizationPass)
+  StringRef getArgument() const final { return "transpose-vectorization"; }
+  StringRef getDescription() const final { return "Transpose Vectorization."; }
+  TransposeVectorizationPass() = default;
+  TransposeVectorizationPass(const TransposeVectorizationPass &) {}
+
+  void runOnOperation() override;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<affine::AffineDialect, arith::ArithDialect,
+                    bufferization::BufferizationDialect, func::FuncDialect,
+                    memref::MemRefDialect, scf::SCFDialect, VectorDialect>();
+  }
+
+  Option<int64_t> strip{*this, "vector-size",
+                        llvm::cl::desc("Specify vector type size."),
+                        llvm::cl::init(64)};
+};
+} // end anonymous namespace.
+
+void TransposeVectorizationPass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  ModuleOp module = getOperation();
+
+  ConversionTarget target(*context);
+  target.addLegalDialect<affine::AffineDialect, arith::ArithDialect,
+                         bufferization::BufferizationDialect,
+                         memref::MemRefDialect, scf::SCFDialect,
+                         VectorDialect>();
+  target.addLegalOp<ModuleOp>();
+  target.addLegalOp<func::FuncOp>();
+  target.addLegalOp<func::ReturnOp>();
+
+  RewritePatternSet patterns(context);
+  patterns.add<TransposeVectorizationPattern>(context, strip);
+
+  if (failed(applyPartialConversion(module, target, std::move(patterns))))
+    signalPassFailure();
+}
+
+namespace mlir {
+namespace buddy {
+void registerTransposeVectorizationPass() {
+  PassRegistration<TransposeVectorizationPass>();
+}
+} // namespace buddy
+} // namespace mlir
diff --git a/tools/buddy-opt/buddy-opt.cpp b/tools/buddy-opt/buddy-opt.cpp
index 61a0958c72..2dc099dfaf 100644
--- a/tools/buddy-opt/buddy-opt.cpp
+++ b/tools/buddy-opt/buddy-opt.cpp
@@ -86,6 +86,7 @@ void registerLegalizeShmemOutliningPass();
 void registerMatMulTransposeBVecPass();
 void registerConvertMemcpyToGPUPass();
 void registerLegalizeShmemOutliningPass();
+void registerTransposeVectorizationPass();
 } // namespace buddy
 } // namespace mlir
 
@@ -126,6 +127,7 @@ int main(int argc, char **argv) {
   mlir::buddy::registerLowerSchePass();
   mlir::buddy::registerFuncBufferizeDynamicOffsetPass();
   mlir::buddy::registerMatMulTransposeBVecPass();
+  mlir::buddy::registerTransposeVectorizationPass();
 
   // Register gpu passes
   mlir::buddy::registerConvertMemcpyToGPUPass();
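Usage sketch (assumptions: buddy-opt has been rebuilt with this patch, and the usual mlir-opt-style pass-option syntax applies): the new pass can be exercised through the make targets added above, or standalone via the `vector-size` option declared in TransposeVectorizationPass (default 64).

    cd examples/BuddyNext
    make next-transpose-vectorization-lower   # emit the vectorized, bufferized IR into log.mlir
    make next-transpose-vectorization-run     # lower and execute the kernel end to end
    # Standalone invocation with an explicit vector width:
    buddy-opt ./next-transpose.mlir -transpose-vectorization="vector-size=64" -o log.mlir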