diff --git a/lib/gc/Transforms/LowerToTileVector.cpp b/lib/gc/Transforms/LowerToTileVector.cpp
index d105eaeb8..b5f92dca9 100644
--- a/lib/gc/Transforms/LowerToTileVector.cpp
+++ b/lib/gc/Transforms/LowerToTileVector.cpp
@@ -34,8 +34,7 @@ namespace {
 
 #define SAFE_EXPAND(X) X
 #define LDBG(X) LLVM_DEBUG(DBGS() << SAFE_EXPAND(X) << "\n")
-#define SUPPORT_TENSOR_OP                                                     \
-  tensor::ExpandShapeOp, tensor::CollapseShapeOp, tensor::ConcatOp
+#define SUPPORT_TENSOR_OP tensor::ConcatOp
 
 template <typename T, typename U>
 struct decay_equiv : std::is_same<typename std::decay<T>::type, U>::type {};
@@ -559,6 +558,38 @@ struct TensorUnpackConvertVectorPass : public RewritePattern {
   }
 };
 
+struct TensorPackConvertVectorPass : public RewritePattern {
+
+  explicit TensorPackConvertVectorPass(MLIRContext *context)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+
+    auto tensorPackOp = dyn_cast<tensor::PackOp>(op);
+    if (!tensorPackOp)
+      return rewriter.notifyMatchFailure(op, "Not expected operations.");
+
+    auto resultTy = cast<ShapedType>(op->getResultTypes()[0]);
+    ArrayRef<int64_t> retShape = resultTy.getShape();
+
+    if (ShapedType::isDynamicShape(retShape))
+      return rewriter.notifyMatchFailure(
+          op, "Pack operation result is dynamic shape.");
+
+    SmallVector<int64_t> inputVectorSize(
+        retShape.take_front(tensorPackOp.getSourceRank()));
+    SmallVector<bool> targetVecDims(inputVectorSize.size(), false);
+
+    if (failed(linalg::vectorize(rewriter, op,
+                                 /*inputVectorSizes=*/inputVectorSize,
+                                 /*inputScalableVecDims=*/targetVecDims, false,
+                                 false)))
+      return rewriter.notifyMatchFailure(op, "Fail to vectorize.");
+
+    return success();
+  }
+};
+
 /// Some tensor operation lowering to vector.
 ///
 /// Currently support expand_shape, collapse_shape and concat_shape.
@@ -590,10 +621,8 @@ struct TensorOpConvertVectorPass : public RewritePattern {
 
 /// Patterns that lower to tile (virtual) vector.
 void populateLowerToTileVectorPatterns(RewritePatternSet &patterns) {
   patterns.add<OperationConvertTileVectorPass<linalg::LinalgOp>,
-               OperationConvertTileVectorPass<tensor::PadOp>>(
-      patterns.getContext());
-  patterns.add<TensorUnpackConvertVectorPass>(patterns.getContext());
-  patterns.add<TensorOpConvertVectorPass>(patterns.getContext());
+               TensorPackConvertVectorPass, TensorUnpackConvertVectorPass,
+               TensorOpConvertVectorPass>(patterns.getContext());
 }
 
 /// LowerToTileVectorPass is a pass that lowers operations to tile (virtual)
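Note: the hunk above only registers the new pattern; how the pattern set is
driven is outside this diff. A minimal sketch, assuming the pass uses MLIR's
standard greedy rewrite driver (the names `root` and `runLowerToTileVector`
are illustrative, not from this patch):

    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    static void runLowerToTileVector(mlir::Operation *root) {
      mlir::RewritePatternSet patterns(root->getContext());
      populateLowerToTileVectorPatterns(patterns);
      // Repeatedly applies the registered lowerings until a fixed point;
      // ops whose patterns bail out via notifyMatchFailure are left intact.
      (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
    }

The two trailing `false` arguments in the new `linalg::vectorize` call map
onto `vectorizeNDExtract` and `flatten1DDepthwiseConv` in current upstream
MLIR (an assumption based on the upstream signature, not spelled out here),
so the pattern requests fixed, non-scalable vector sizes with both of those
features disabled.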
diff --git a/test/mlir/test/gc/Transforms/vectorization.mlir b/test/mlir/test/gc/Transforms/vectorization.mlir
index 2f87dffc3..acb8fec09 100644
--- a/test/mlir/test/gc/Transforms/vectorization.mlir
+++ b/test/mlir/test/gc/Transforms/vectorization.mlir
@@ -108,16 +108,12 @@ func.func @add_tensor_test4(%arg0: tensor<12x2x56x56x32xf32>, %arg1: tensor<12x2
 }
 
 // CHECK-LABEL: @add_tensor_test5
-// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[CST_0:.*]] = arith.constant dense<1.000000e+00> : vector<1x8xf32>
+// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1x8xf32>
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[EMPTY0:.*]] tensor.empty() : tensor<1x8xf32>
 // CHECK: %[[WRITE0:.*]] = vector.transfer_write %{{.*}} {in_bounds = [true, true]} : vector<1x8xf32>, tensor<1x8xf32>
-// CHECK: %[[extracted_slice:.*]] = tensor.extract_slice %1[0, 0] [1, 8] [1, 1] : tensor<1x8xf32> to tensor<8xf32>
-// CHECK: %[[READ1:.*]] = vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<8xf32>, vector<8xf32>
-// CHECK: %[[SHAPECAST1:.*]] = vector.shape_cast %[[READ1]] : vector<8xf32> to vector<1x1x1x8xf32>
-// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<1x1x1x8xf32>
-// CHECK: %[[WRITE0:.*]] = vector.transfer_write %{{.*}} {in_bounds = [true, true, true, true]} : vector<1x1x1x8xf32>, tensor<1x1x1x8xf32>
+// CHECK: %[[EXTRACTSLICE:.*]] = tensor.extract_slice %1[0, 0] [1, 8] [1, 1] : tensor<1x8xf32> to tensor<8xf32>
+// CHECK: %[[expand:.*]] = tensor.expand_shape
 func.func @add_tensor_test5() -> tensor<1x1x1x8xf32> {
   %cst = arith.constant 1.000000e+00 : f32
   %init = tensor.empty() : tensor<1x8xf32>
@@ -128,12 +124,7 @@ func.func @add_tensor_test5() -> tensor<1x1x1x8xf32> {
 }
 // CHECK-LABEL: @tensor_collapse_shape_test6
-// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[READ1:.*]] = vector.transfer_read %{{.*}} {in_bounds = [true, true]} : tensor<2x3xf32>, vector<2x3xf32>
-// CHECK: %[[SHAPECAST1:.*]] = vector.shape_cast %[[READ1]] : vector<2x3xf32> to vector<6xf32>
-// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<6xf32>
-// CHECK: %[[WRITE0:.*]] = vector.transfer_write %{{.*}} {in_bounds = [true]} : vector<6xf32>, tensor<6xf32>
+// CHECK: tensor.collapse_shape
 func.func @tensor_collapse_shape_test6(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<2x3xf32> into tensor<6xf32>
   return %0 : tensor<6xf32>
 }
@@ -183,13 +174,32 @@ func.func @fc_relu_test8(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
 }
 
-// CHECK-LABEL: @test_pad_dynamic_shape_test9
+// CHECK-LABEL: @test_pack_dynamic_shape_test9
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM0:.*]] = tensor.dim {{.*}}, %[[C0]] : tensor<?x?xf32>
+// CHECK: %[[DIM1:.*]] = tensor.dim {{.*}}, %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[CREATEMASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM1]] : vector<4x16xi1>
+// CHECK: %[[MASKREAD:.*]] = vector.mask %[[CREATEMASK]] { vector.transfer_read {{.*}}, %[[CST]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x16xf32> } : vector<4x16xi1> -> vector<4x16xf32>
+// CHECK: %[[SHAPECAST:.*]] = vector.shape_cast %[[MASKREAD]] : vector<4x16xf32> to vector<1x4x1x16xf32>
+// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPECAST]], [0, 2, 1, 3] : vector<1x4x1x16xf32> to vector<1x1x4x16xf32>
+// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x4x16xf32>
+// CHECK: %[[WRITE0:.*]] = vector.transfer_write %[[TRANSPOSE]], %[[EMPTY]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<1x1x4x16xf32>, tensor<1x1x4x16xf32>
+func.func @test_pack_dynamic_shape_test9(%arg0: tensor<?x?xf32>, %arg1: tensor<1x1x4x16xf32>) -> tensor<1x1x4x16xf32> {
+  %cst = arith.constant 0.0 : f32
+  %pack = tensor.pack %arg0 padding_value(%cst : f32) outer_dims_perm = [0, 1]
+    inner_dims_pos = [0, 1] inner_tiles = [4, 16]
+    into %arg1 : tensor<?x?xf32> -> tensor<1x1x4x16xf32>
+  return %pack : tensor<1x1x4x16xf32>
+}
+
+// CHECK-LABEL: @test_pad_dynamic_shape_test10
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<4x16xf32>
 // CHECK: %[[READ0:.*]] = vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<4x16xf32>
 // CHECK: %[[WRITE0:.*]] = vector.transfer_write %{{.*}} {in_bounds = [true, true]} : vector<4x16xf32>, tensor<4x16xf32>
-func.func @test_pad_dynamic_shape_test9(%arg0: tensor<?x?xf32>) -> tensor<4x16xf32> {
+func.func @test_pad_dynamic_shape_test10(%arg0: tensor<?x?xf32>) -> tensor<4x16xf32> {
   %f0 = arith.constant 0.0 : f32
   %pad = tensor.pad %arg0 low[0, 0] high[1,1] {
     ^bb0(%arg3: index, %arg4: index):
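For review convenience, the shapes in @test_pack_dynamic_shape_test9 can be
traced through the new pattern by hand. A sketch: the variable names mirror
the new C++ above, and the size arithmetic follows upstream tensor.pack
vectorization (an assumption about linalg::vectorize internals, consistent
with the CHECK lines):

    // Source: tensor<?x?xf32>  => sourceRank = 2.
    // Result: tensor<1x1x4x16xf32>, inner_dims_pos = [0, 1],
    // inner_tiles = [4, 16].
    // retShape.take_front(2) == {1, 1}: the outer dims of the packed result.
    llvm::SmallVector<int64_t> inputVectorSize = {1, 1};
    llvm::SmallVector<bool> targetVecDims = {false, false}; // non-scalable
    // linalg::vectorize multiplies the inner tiles back into these sizes for
    // the source read: {1 * 4, 1 * 16} -> a vector<4x16xf32> transfer_read,
    // masked by vector.create_mask %dim0, %dim1 because both source dims are
    // dynamic. That is exactly the CHECK sequence above: masked read,
    // shape_cast to 1x4x1x16, transpose [0, 2, 1, 3], then an in-bounds
    // write into the static 1x1x4x16 destination.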
@@ -205,13 +215,7 @@ func.func @test_add_dynamic_shape_test11(%arg0: tensor<?x?xf32>, %arg1: tensor<4
   return %1 : tensor<4x16xf32>
 }
 
-func.func @test_collapse_dynamic_shape_test12(%arg0: tensor<?x?x?xf32>) -> tensor<?x?xf32> {
-  // expected-error @+1 {{Fail to vectorize.}}
-  %0 = tensor.collapse_shape %arg0 [[0], [1, 2]]: tensor<?x?x?xf32> into tensor<?x?xf32>
-  return %0 : tensor<?x?xf32>
-}
-
-func.func @test_expand_dynamic_shape_test13(%arg0 : tensor<2x?x5xf32>, %sz0: index) ->tensor<2x6x?x5xf32> {
+func.func @test_expand_dynamic_shape_test12(%arg0 : tensor<2x?x5xf32>, %sz0: index) ->tensor<2x6x?x5xf32> {
   // expected-error @+1 {{Fail to vectorize.}}
   %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [2, 6, %sz0, 5] : tensor<2x?x5xf32> into tensor<2x6x?x5xf32>
   return %0 : tensor<2x6x?x5xf32>
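To exercise the whole lowering end to end, the pass can be run over a module.
A hedged sketch only: the factory name createLowerToTileVectorPass is an
assumption for illustration; the real pass registration lives outside this
diff:

    mlir::PassManager pm(module->getContext());
    // LowerToTileVectorPass walks each function and applies the patterns
    // populated above, including the new tensor.pack lowering.
    pm.addNestedPass<mlir::func::FuncOp>(createLowerToTileVectorPass());
    if (mlir::failed(pm.run(module)))
      llvm::errs() << "lower-to-tile-vector failed\n";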