From f95364df3e826f8bf8747bfc39e3c99c52f6c4c9 Mon Sep 17 00:00:00 2001 From: FloatingcloudKnight <1348185166@qq.com> Date: Wed, 1 Oct 2025 15:24:46 +0000 Subject: [PATCH] Add a pass that fuses matmul and transpose operations For the transpose-matrix multiplication-transpose pattern in Llama2, perform fusion and vectorization during dialect reduction. The acceleration ratio at the operator level is 1.84 before and after fusion, while it is not visible at the model level. --- examples/BuddyNext/makefile | 45 ++- .../BuddyNext/next-matmul-transpose2-vec.mlir | 21 +- .../BuddyNext/next-matmul-transpose2.mlir | 35 +- .../TransposeOptimization/CMakeLists.txt | 1 + .../TransposeFusionVectorization.cpp | 328 ++++++++++++++++++ tools/buddy-opt/buddy-opt.cpp | 2 + 6 files changed, 391 insertions(+), 41 deletions(-) create mode 100644 midend/lib/Conversion/TransposeOptimization/TransposeFusionVectorization.cpp diff --git a/examples/BuddyNext/makefile b/examples/BuddyNext/makefile index c4da8ece5f..36a279b6f9 100644 --- a/examples/BuddyNext/makefile +++ b/examples/BuddyNext/makefile @@ -1235,54 +1235,65 @@ next-compass-run: -shared-libs=${MLIR_RUNNER_UTILS} \ -shared-libs=${MLIR_C_RUNNER_UTILS} -tosa-matmul-transpose2-lower: - @${BUDDY_OPT} ./tosa-matmultranspose2.mlir \ +next-matmul-transpose2-lower: + @${BUDDY_OPT} ./next-matmul-transpose2.mlir \ -transpose-fusion-vectorization \ -o log.mlir -tosa-matmul-transpose2-run: - @${BUDDY_OPT} ./tosa-matmultranspose2.mlir \ +next-matmul-transpose2-run: + @${BUDDY_OPT} ./next-matmul-transpose2.mlir \ -pass-pipeline "builtin.module(transpose-fusion-vectorization, func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith))" | \ ${BUDDY_OPT} \ + -convert-elementwise-to-linalg \ + -arith-expand \ -eliminate-empty-tensors \ - -convert-tensor-to-linalg \ - -linalg-bufferize \ + -empty-tensor-to-alloc-tensor \ + -one-shot-bufferize="bufferize-function-boundaries" \ -convert-linalg-to-affine-loops \ - -lower-affine \ - -func-bufferize \ - -arith-bufferize \ - -tensor-bufferize \ - -buffer-deallocation \ - -finalizing-bufferize \ + -affine-loop-fusion \ + -affine-parallelize \ + -convert-scf-to-openmp \ -convert-vector-to-scf \ -expand-strided-metadata \ + -lower-affine \ + -cse \ -convert-vector-to-llvm \ + -memref-expand \ + -arith-expand \ -convert-arith-to-llvm \ -finalize-memref-to-llvm \ -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-openmp-to-llvm \ -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-math-to-libm \ -convert-func-to-llvm \ -reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} -tosa-matmul-transpose2-vec-run: - @${BUDDY_OPT} ./tosa-matmultranspose2-vec.mlir\ +next-matmul-transpose2-vec-run: + @${BUDDY_OPT} ./next-matmul-transpose2-vec.mlir \ -convert-linalg-to-affine-loops \ + -affine-loop-fusion \ -affine-parallelize \ - -lower-affine \ -convert-scf-to-openmp \ -convert-vector-to-scf \ -expand-strided-metadata \ + -lower-affine \ + -cse \ -convert-vector-to-llvm \ -memref-expand \ -arith-expand \ -convert-arith-to-llvm \ - -finalize-memref-to-llvm \ + -finalize-memref-to-llvm \ -convert-scf-to-cf \ + -convert-cf-to-llvm \ -convert-openmp-to-llvm \ + -convert-arith-to-llvm \ -convert-math-to-llvm \ - -convert-math-to-libm \ + -convert-math-to-libm \ -convert-func-to-llvm \ -reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ diff --git a/examples/BuddyNext/next-matmul-transpose2-vec.mlir b/examples/BuddyNext/next-matmul-transpose2-vec.mlir index 426c6a5193..0ffb989e47 100644 --- a/examples/BuddyNext/next-matmul-transpose2-vec.mlir +++ b/examples/BuddyNext/next-matmul-transpose2-vec.mlir @@ -2,15 +2,13 @@ // RUN: -convert-linalg-to-affine-loops \ // RUN: -lower-affine \ // RUN: -convert-vector-to-scf \ +// RUN: -expand-strided-metadata \ // RUN: -convert-scf-to-cf \ // RUN: -convert-cf-to-llvm \ // RUN: -convert-vector-to-llvm \ -// RUN: -convert-math-to-llvm \ -// RUN: -convert-math-to-libm \ +// RUN: -finalize-memref-to-llvm \ // RUN: -convert-arith-to-llvm \ // RUN: -convert-func-to-llvm \ -// RUN: -expand-strided-metadata \ -// RUN: -finalize-memref-to-llvm \ // RUN: -reconcile-unrealized-casts \ // RUN: | mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ @@ -28,21 +26,21 @@ func.func @test(%a : memref, %b : memref, %c : memref - %dim = memref.dim %a, %c0 : memref // + %dim = memref.dim %a, %c0 : memref // %dim_0 = memref.dim %a, %c1 : memref %dim_1 = memref.dim %a, %c2 : memref %dim_2 = memref.dim %b, %c2 : memref // Calculate the upper bound for vectorized processing // - Subtract `vl_step` is to avoid overflow at the vectorization tail. - // - Add 1 to ensure the final loop runs when the workload length + // - Add 1 to ensure the final loop runs when the workload length // is divisible by the vector size. %dim_2_upbound_tmp = arith.subi %dim_2, %vl_step : index %dim_2_upbound = arith.addi %dim_2_upbound_tmp, %c1 : index affine.for %arg3 = %c0 to %dim { affine.for %arg4 = %c0 to %dim_0 { - %iter_idx = scf.for %arg5 = %c0 to %dim_2_upbound + %iter_idx = scf.for %arg5 = %c0 to %dim_2_upbound step %vl_step iter_args(%iter_init = %c0) -> (index){ %0 = vector.load %c[%arg4, %arg3, %arg5] : memref, vector<32xf32> %iter_value = scf.for %arg6 = %c0 to %dim_1 step %c1 iter_args(%value_init = %0) -> (vector<32xf32>){ @@ -57,9 +55,9 @@ func.func @test(%a : memref, %b : memref, %c : memref %0 = vector.maskedload %c[%arg4, %arg3, %iter_idx], %mask, %v0 : memref, vector<32xi1>, vector<32xf32> into vector<32xf32> %iter_value = scf.for %arg6 = %c0 to %dim_1 step %c1 iter_args(%value_init = %0) -> (vector<32xf32>){ @@ -72,7 +70,7 @@ func.func @test(%a : memref, %b : memref, %c : memref, vector<32xi1>, vector<32xf32> } } - + %t_end = call @rtclock() : () -> f64 %time = arith.subf %t_end, %t_start : f64 // Print timings. @@ -111,8 +109,9 @@ func.func @main(){ %printed_m2 = memref.cast %m2 : memref to memref<*xf32> - // CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [40, 32, 128] strides = [4096, 128, 1] data = + // CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [40, 32, 128] strides = [4096, 128, 1] data = // CHECK-NEXT: [ + // CHECK: [ // CHECK: [240{{(, 240)*}}] call @printMemrefF32(%printed_m2) : (memref<*xf32>) -> () diff --git a/examples/BuddyNext/next-matmul-transpose2.mlir b/examples/BuddyNext/next-matmul-transpose2.mlir index b15150c76f..889ff8c53a 100644 --- a/examples/BuddyNext/next-matmul-transpose2.mlir +++ b/examples/BuddyNext/next-matmul-transpose2.mlir @@ -1,21 +1,32 @@ // RUN: buddy-opt %s \ // RUN: -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" \ // RUN: | buddy-opt \ +// RUN: -convert-elementwise-to-linalg \ +// RUN: -arith-expand \ // RUN: -eliminate-empty-tensors \ -// RUN: -convert-tensor-to-linalg \ +// RUN: -empty-tensor-to-alloc-tensor \ // RUN: -one-shot-bufferize="bufferize-function-boundaries" \ // RUN: -convert-linalg-to-affine-loops \ -// RUN: -lower-affine \ +// RUN: -affine-loop-fusion \ +// RUN: -affine-parallelize \ +// RUN: -convert-scf-to-openmp \ // RUN: -convert-vector-to-scf \ // RUN: -expand-strided-metadata \ -// RUN: -convert-vector-to-llvm \ -// RUN: -convert-arith-to-llvm \ -// RUN: -finalize-memref-to-llvm \ -// RUN: -convert-scf-to-cf \ -// RUN: -convert-cf-to-llvm \ -// RUN: -convert-arith-to-llvm \ -// RUN: -convert-func-to-llvm \ -// RUN: -reconcile-unrealized-casts \ +// RUN: -lower-affine \ +// RUN: -cse \ +// RUN: -convert-vector-to-llvm \ +// RUN: -memref-expand \ +// RUN: -arith-expand \ +// RUN: -convert-arith-to-llvm \ +// RUN: -finalize-memref-to-llvm \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-cf-to-llvm \ +// RUN: -convert-openmp-to-llvm \ +// RUN: -convert-arith-to-llvm \ +// RUN: -convert-math-to-llvm \ +// RUN: -convert-math-to-libm \ +// RUN: -convert-func-to-llvm \ +// RUN: -reconcile-unrealized-casts \ // RUN: | mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ // RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ @@ -32,7 +43,7 @@ func.func @test(%a : tensor<1x40x32x128xf32>, %b : tensor<32x40x40xf32>) -> (ten %3 = tosa.matmul %b, %2 : (tensor<32x40x40xf32>, tensor<32x40x128xf32>) -> tensor<32x40x128xf32> %4 = tosa.reshape %3 {new_shape = array} : (tensor<32x40x128xf32>) -> tensor<1x32x40x128xf32> %5 = "tosa.const"() <{value = dense<[0, 2, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> - %6 = tosa.transpose %4, %5 : (tensor<1x32x40x128xf32>, tensor<4xi32>) -> tensor<1x40x32x128xf32> + %6 = tosa.transpose %4, %5 : (tensor<1x32x40x128xf32>, tensor<4xi32>) -> tensor<1x40x32x128xf32> %t_end = call @rtclock() : () -> f64 %time = arith.subf %t_end, %t_start : f64 // Print timings. @@ -44,8 +55,6 @@ func.func @main(){ %v2 = arith.constant dense<2.0> : tensor<32x40x40xf32> %v3 = arith.constant dense<3.0> : tensor<1x40x32x128xf32> - // %m0 = tensor.cast %v2 : tensor<32x40x40xf32> to tensor - // %m1 = tensor.cast %v3 : tensor<40x32x128xf32> to tensor %m2 = call @test(%v3, %v2) : (tensor<1x40x32x128xf32>, tensor<32x40x40xf32>) -> (tensor<1x40x32x128xf32>) diff --git a/midend/lib/Conversion/TransposeOptimization/CMakeLists.txt b/midend/lib/Conversion/TransposeOptimization/CMakeLists.txt index 70d5ca7fca..6a0e4eaed6 100644 --- a/midend/lib/Conversion/TransposeOptimization/CMakeLists.txt +++ b/midend/lib/Conversion/TransposeOptimization/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_library(TransposeOptimization BuiltinTransposeVectorization.cpp + TransposeFusionVectorization.cpp LINK_LIBS PUBLIC BuddyUtils ) diff --git a/midend/lib/Conversion/TransposeOptimization/TransposeFusionVectorization.cpp b/midend/lib/Conversion/TransposeOptimization/TransposeFusionVectorization.cpp new file mode 100644 index 0000000000..a5ba29d4b4 --- /dev/null +++ b/midend/lib/Conversion/TransposeOptimization/TransposeFusionVectorization.cpp @@ -0,0 +1,328 @@ +//====- TransposeFusionVectorization.cpp-----------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the transpose-matmul-transpose fusion vectorization. +// +//===----------------------------------------------------------------------===// +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace vector; + +//===----------------------------------------------------------------------===// +// Rewriter Pattern +//===----------------------------------------------------------------------===// + +namespace { + +/// TransposeFusion vectorization pattern +class TransposeFusionVectorizationPattern : public ConversionPattern { +public: + explicit TransposeFusionVectorizationPattern(MLIRContext *context, + int64_t vecSizeParam) + : ConversionPattern(tosa::MatMulOp::getOperationName(), 1, context) { + vecSize = vecSizeParam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override { + mlir::Location loc = op->getLoc(); + mlir::MLIRContext *ctx = op->getContext(); + + Value A = op->getOperand(0); + Value B = op->getOperand(1); + Value C = op->getOpResult(0); + + tosa::ReshapeOp reshapeBOp = B.getDefiningOp(); + if (!reshapeBOp) { + return failure(); + } + tosa::TransposeOp transposeBOp = + reshapeBOp.getOperand().getDefiningOp(); + if (!transposeBOp) { + return failure(); + } + Value::user_iterator reshapeCUserIt = C.getUsers().begin(); + if (reshapeCUserIt == C.getUsers().end()) { + return failure(); + } + Operation *reshapeCOp = *reshapeCUserIt; + if (!isa(reshapeCOp)) { + return failure(); + } + + Value::user_iterator transposeCUserIt = + reshapeCOp->getOpResult(0).getUsers().begin(); + if (transposeCUserIt == reshapeCOp->getOpResult(0).getUsers().end()) { + return failure(); + } + Operation *transposeCOp = *transposeCUserIt; + if (!isa(transposeCOp)) { + return failure(); + } + + Value::user_iterator nextUserIt = + transposeCOp->getOpResult(0).getUsers().begin(); + if (nextUserIt == transposeCOp->getOpResult(0).getUsers().end()) { + return failure(); + } + Operation *nextUserOp = *nextUserIt; + + // Get i1 as the element type for mask vector. + IntegerType i1 = IntegerType::get(ctx, 1); + VectorType vectorMaskTy = mlir::VectorType::get({vecSize}, i1); + // Acquire the element type of input tensors. + ShapedType AType = cast(A.getType()); + Type elementType = AType.getElementType(); + VectorType vectorTy = mlir::VectorType::get({vecSize}, elementType); + + ShapedType newBType = + cast(transposeBOp.getOperand(0).getType()); + ShapedType newCType = + cast(transposeCOp->getOpResult(0).getType()); + + // Define constants. + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value c2 = rewriter.create(loc, 2); + Value c3 = rewriter.create(loc, 3); + Value vlStep = rewriter.create(loc, vecSize); + Value zero = rewriter.create( + loc, rewriter.getZeroAttr(elementType)); + const AffineExpr d0 = rewriter.getAffineDimExpr(0); + const AffineExpr zeroAffine = rewriter.getAffineConstantExpr(0); + + // Create pass through vector. + Value passThroughVec = rewriter.create(loc, vectorTy, zero); + Value newA = rewriter.create( + loc, MemRefType::get(AType.getShape(), elementType), A); + Value newB = rewriter.create( + loc, MemRefType::get(newBType.getShape(), elementType), + transposeBOp.getOperand(0)); + Value newC = rewriter.create( + loc, MemRefType::get(newCType.getShape(), elementType)); + + // Get dimensions of input tensors. + Value batch = rewriter.create(loc, newA, c0); + Value aRow = rewriter.create(loc, newA, c1); + Value aCol = rewriter.create(loc, newA, c2); + Value bCol = rewriter.create(loc, newB, c3); + + Value upperBoundTmp = rewriter.create(loc, bCol, vlStep); + Value upperBound = rewriter.create(loc, upperBoundTmp, c1); + + AffineMap map0 = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, + {rewriter.getAffineDimExpr(0)}, + rewriter.getContext()); + AffineMap map1 = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, + {rewriter.getAffineDimExpr(1)}, + rewriter.getContext()); + SmallVector lbMaps({map0, map1}); + SmallVector ubMaps({map0, map1}); + SmallVector lbArgs = {c0, c0}; + SmallVector ubArgs = {batch, aRow}; + SmallVector steps = {1, 1}; + + affine::AffineParallelOp parOp = rewriter.create( + loc, /*resultTypes=*/TypeRange{}, + /*reductions=*/ArrayRef{}, lbMaps, lbArgs, ubMaps, + ubArgs, steps); + + // Create the loop body for the parallel loop. + // Block *loopBody = new Block(); + // rewriter.setInsertionPointToStart(loopBody); + // TypeRange types = {rewriter.getIndexType(), rewriter.getIndexType()}; + // ArrayRef locs = {loc, loc}; + // loopBody->addArguments(types, locs); + Block &loopBody = parOp.getRegion().front(); + rewriter.setInsertionPointToStart(&loopBody); + Value ivs0 = loopBody.getArguments()[0]; + Value ivs1 = loopBody.getArguments()[1]; + + auto iterIdx = rewriter.create( + loc, ValueRange{c0}, rewriter.getDimIdentityMap(), + ValueRange{upperBound}, rewriter.getDimIdentityMap(), + /*Step=*/vecSize, ValueRange{c0}, + [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, + ValueRange itrArgs) { + auto iterVec = nestedBuilder.create( + nestedLoc, ValueRange{c0}, rewriter.getDimIdentityMap(), + ValueRange{aCol}, rewriter.getDimIdentityMap(), /*Step=*/1, + ValueRange{passThroughVec}, + [&](OpBuilder &nestedBuilder0, Location nestedLoc0, Value iv0, + ValueRange itrArgs0) { + Value aVal = nestedBuilder0.create( + nestedLoc0, elementType, newA, ValueRange{ivs0, ivs1, iv0}); + Value aVec = nestedBuilder0.create( + nestedLoc0, vectorTy, aVal); + Value bVec = nestedBuilder0.create( + nestedLoc0, vectorTy, newB, ValueRange{c0, iv0, ivs0, iv}); + // Compute the result vector either through integer + // multiplication and addition or fused multiply-add + // based on the element type. + Value tmpVec; + if (isa(elementType)) { + Value mulVec = nestedBuilder0.create( + nestedLoc0, aVec, bVec); + tmpVec = nestedBuilder0.create( + nestedLoc0, mulVec, itrArgs0[0]); + } else { + tmpVec = nestedBuilder0.create( + nestedLoc0, vectorTy, aVec, bVec, itrArgs0[0]); + } + nestedBuilder0.create(loc, tmpVec); + }); + nestedBuilder.create(nestedLoc, iterVec.getResult(0), + newC, + ValueRange{c0, ivs1, ivs0, iv}); + Value idx = + nestedBuilder.create(nestedLoc, iv, vlStep); + nestedBuilder.create(nestedLoc, idx); + }); + // Compute the tail size and Process the remaining elements + // using masked vector operations. + Value idx = iterIdx.getResult(0); + Value tailSize = rewriter.create(loc, bCol, idx); + Value tailMask = rewriter.create(loc, vectorMaskTy, tailSize); + auto iterVec = rewriter.create( + loc, ValueRange{c0}, rewriter.getDimIdentityMap(), ValueRange{aCol}, + rewriter.getDimIdentityMap(), /*Step=*/1, ValueRange{passThroughVec}, + [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, + ValueRange itrArgs) { + Value aVal = nestedBuilder.create( + nestedLoc, elementType, newA, ValueRange{ivs0, ivs1, iv}); + Value aVec = + nestedBuilder.create(nestedLoc, vectorTy, aVal); + Value bVec = nestedBuilder.create( + nestedLoc, vectorTy, newB, ValueRange{c0, iv, ivs0, idx}, + tailMask, passThroughVec); + + // Compute the result vector either through integer + // multiplication and addition or fused multiply-add + // based on the element type. + Value tmpVec; + if (isa(elementType)) { + Value mulVec = + nestedBuilder.create(nestedLoc, aVec, bVec); + tmpVec = nestedBuilder.create(nestedLoc, mulVec, + itrArgs[0]); + } else { + tmpVec = nestedBuilder.create( + nestedLoc, vectorTy, aVec, bVec, itrArgs[0]); + } + rewriter.create(nestedLoc, tmpVec); + }); + rewriter.create(loc, newC, ValueRange{c0, ivs1, ivs0, idx}, + tailMask, iterVec.getResult(0)); + + rewriter.setInsertionPointAfter(parOp); + Value output = rewriter.create( + loc, newCType, newC, /*restrict=*/true); + + rewriter.eraseOp(reshapeBOp); + rewriter.eraseOp(transposeBOp); + rewriter.eraseOp(op); + rewriter.eraseOp(reshapeCOp); + rewriter.replaceOp(transposeCOp, output); + return success(); + } + +private: + int64_t vecSize; +}; +} // namespace + +//===----------------------------------------------------------------------===// +// TransposeFusionVectorizationPass +//===----------------------------------------------------------------------===// + +/// This is a partial lowering linalg pooling operations to mixture of +/// Affine + Vector operations. +namespace { +class TransposeFusionVectorizationPass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TransposeFusionVectorizationPass) + StringRef getArgument() const final { + return "transpose-fusion-vectorization"; + } + StringRef getDescription() const final { + return "Transpose Fusion Vectorization."; + } + TransposeFusionVectorizationPass() = default; + TransposeFusionVectorizationPass(const TransposeFusionVectorizationPass &) {} + explicit TransposeFusionVectorizationPass(int64_t vecSizeParam) { + vecSize = vecSizeParam; + } + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + Option vecSize{*this, "vector-size", + llvm::cl::desc("Affine Vector size."), + llvm::cl::init(16)}; +}; +} // end anonymous namespace. + +void TransposeFusionVectorizationPass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + target.addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context, vecSize); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} + +namespace mlir { +namespace buddy { +void registerTransposeFusionVectorizationPass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/tools/buddy-opt/buddy-opt.cpp b/tools/buddy-opt/buddy-opt.cpp index b85356baa5..d809ee5776 100644 --- a/tools/buddy-opt/buddy-opt.cpp +++ b/tools/buddy-opt/buddy-opt.cpp @@ -71,6 +71,7 @@ void registerLowerRVVPass(); void registerMatMulOptimizePass(); void registerMatMulVectorizationPass(); void registerMatMulParallelVectorizationPass(); +void registerTransposeFusionVectorizationPass(); void registerTransposeOptimizationPass(); void registerConvOptimizePass(); void registerConvNhwcFhwcOptimizePass(); @@ -118,6 +119,7 @@ int main(int argc, char **argv) { mlir::buddy::registerBatchMatMulTransVecPass(); mlir::buddy::registerMatMulVectorizationPass(); mlir::buddy::registerMatMulParallelVectorizationPass(); + mlir::buddy::registerTransposeFusionVectorizationPass(); mlir::buddy::registerTransposeOptimizationPass(); mlir::buddy::registerConvOptimizePass(); mlir::buddy::registerConvNhwcFhwcOptimizePass();