[Transform][vector] lowering dynamic shape of tensor.pack to vector #351

Open

wants to merge 1 commit into base: main
41 changes: 35 additions & 6 deletions lib/gc/Transforms/LowerToTileVector.cpp
@@ -34,8 +34,7 @@ namespace {
#define SAFE_EXPAND(X) X
#define LDBG(X) LLVM_DEBUG(DBGS() << SAFE_EXPAND(X) << "\n")

#define SUPPORT_TENSOR_OP \
tensor::ExpandShapeOp, tensor::CollapseShapeOp, tensor::ConcatOp
#define SUPPORT_TENSOR_OP tensor::ConcatOp

template <typename T, typename U>
struct decay_equiv : std::is_same<typename std::decay<T>::type, U>::type {};
@@ -559,6 +558,38 @@ struct TensorUnpackConvertVectorPass : public RewritePattern {
}
};

/// Lower a tensor.pack with a static result shape (the source may be
/// dynamically shaped) to vector operations via upstream linalg vectorization.
struct TensorPackConvertVectorPass : public RewritePattern {

explicit TensorPackConvertVectorPass(MLIRContext *context)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {

auto tensorPackOp = dyn_cast<tensor::PackOp>(op);
if (!tensorPackOp)
return rewriter.notifyMatchFailure(op, "Not expected operations.");

auto resultTy = cast<ShapedType>(op->getResultTypes()[0]);
ArrayRef<int64_t> retShape = resultTy.getShape();

// linalg::vectorize needs static vector sizes, so reject dynamic results.
if (ShapedType::isDynamicShape(retShape))
return rewriter.notifyMatchFailure(
op, "Pack operation result has a dynamic shape.");

// The leading source-rank dims of the static result give the outer tile
// counts; use them as the input vector sizes, with no scalable dimensions.
SmallVector<int64_t> inputVectorSize(
retShape.take_front(tensorPackOp.getSourceRank()));
SmallVector<bool, 5> targetVecDims(inputVectorSize.size(), false);

if (failed(linalg::vectorize(rewriter, op,
/*inputVectorSizes=*/inputVectorSize,
/*inputScalableVecDims=*/targetVecDims,
/*vectorizeNDExtract=*/false,
/*flatten1DDepthwiseConv=*/false)))
return rewriter.notifyMatchFailure(op, "Fail to vectorize.");

return success();
}
};

/// Some tensor operation lowering to vector.
///
/// Currently supports concat.
@@ -590,10 +621,8 @@ struct TensorOpConvertVectorPass : public RewritePattern {
/// Patterns that lower to tile (virtual) vector.
void populateLowerToTileVectorPatterns(RewritePatternSet &patterns) {
patterns.add<OperationConvertTileVectorPass<linalg::LinalgOp>,
OperationConvertTileVectorPass<tensor::PackOp>>(
patterns.getContext());
patterns.add<TensorUnpackConvertVectorPass>(patterns.getContext());
patterns.add<TensorOpConvertVectorPass>(patterns.getContext());
TensorPackConvertVectorPass, TensorUnpackConvertVectorPass,
TensorOpConvertVectorPass>(patterns.getContext());
}

/// LowerToTileVectorPass is a pass that lowers operations to tile (virtual)
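For orientation, here is a minimal sketch of how a pattern set like this is typically driven from a pass using standard MLIR APIs. The pass class below is illustrative only (it is not the LowerToTileVectorPass defined in this repository) and assumes populateLowerToTileVectorPatterns is visible in the same translation unit:

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace {
// Illustrative driver, not the pass shipped in this repository.
struct LowerToTileVectorSketchPass
    : public mlir::PassWrapper<LowerToTileVectorSketchPass,
                               mlir::OperationPass<mlir::func::FuncOp>> {
  void runOnOperation() override {
    mlir::RewritePatternSet patterns(&getContext());
    // Registers TensorPackConvertVectorPass together with the other
    // tile-vector patterns shown above.
    populateLowerToTileVectorPatterns(patterns);
    // Apply the patterns greedily; an op whose pattern reports a match
    // failure is simply left in tensor form.
    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
                                                        std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

Greedy application on each func.func keeps the lowering local: a pack whose result shape is dynamic fails the match and passes through untouched by this pattern.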
48 changes: 26 additions & 22 deletions test/mlir/test/gc/Transforms/vectorization.mlir
@@ -108,16 +108,12 @@ func.func @add_tensor_test4(%arg0: tensor<12x2x56x56x32xf32>, %arg1: tensor<12x2
}

// CHECK-LABEL: @add_tensor_test5
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[CST_0:.*]] = arith.constant dense<1.000000e+00> : vector<1x8xf32>
// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1x8xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<1x8xf32>
// CHECK: %[[WRITE0:.*]] = vector.transfer_write %{{.*}} {in_bounds = [true, true]} : vector<1x8xf32>, tensor<1x8xf32>
// CHECK: %[[extracted_slice:.*]] = tensor.extract_slice %1[0, 0] [1, 8] [1, 1] : tensor<1x8xf32> to tensor<8xf32>
// CHECK: %[[READ1:.*]] = vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<8xf32>, vector<8xf32>
// CHECK: %[[SHAPECAST1:.*]] = vector.shape_cast %[[READ1]] : vector<8xf32> to vector<1x1x1x8xf32>
// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<1x1x1x8xf32>
// CHECK: %[[WRITE0:.*]] = vector.transfer_write %{{.*}} {in_bounds = [true, true, true, true]} : vector<1x1x1x8xf32>, tensor<1x1x1x8xf32>
// CHECK: %[[EXTRACTSLICE:.*]] = tensor.extract_slice %1[0, 0] [1, 8] [1, 1] : tensor<1x8xf32> to tensor<8xf32>
// CHECK: %[[expand:.*]] = tensor.expand_shape
func.func @add_tensor_test5() -> tensor<1x1x1x8xf32> {
%cst = arith.constant 1.000000e+00 : f32
%init = tensor.empty() : tensor<1x8xf32>
@@ -128,12 +124,7 @@ func.func @add_tensor_test5() -> tensor<1x1x1x8xf32> {
}

// CHECK-LABEL: @tensor_collapse_shape_test6
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[READ1:.*]] = vector.transfer_read %{{.*}} {in_bounds = [true, true]} : tensor<2x3xf32>, vector<2x3xf32>
// CHECK: %[[SHAPECAST1:.*]] = vector.shape_cast %[[READ1]] : vector<2x3xf32> to vector<6xf32>
// CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<6xf32>
// CHECK: %[[WRITE0:.*]] = vector.transfer_write %{{.*}} {in_bounds = [true]} : vector<6xf32>, tensor<6xf32>
// CHECK: tensor.collapse_shape
func.func @tensor_collapse_shape_test6(%arg0: tensor<2x3xf32>) -> tensor<6xf32> {
%0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<2x3xf32> into tensor<6xf32>
return %0 : tensor<6xf32>
@@ -183,13 +174,32 @@ func.func @fc_relu_test8(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>,
}


// CHECK-LABEL: @test_pad_dynamic_shape_test9
// CHECK-LABEL: @test_pack_dynamic_shape_test9
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM0:.*]] = tensor.dim {{.*}}, %[[C0]] : tensor<?x?xf32>
// CHECK: %[[DIM1:.*]] = tensor.dim {{.*}}, %[[C1]] : tensor<?x?xf32>
// CHECK: %[[CREATEMASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM1]] : vector<4x16xi1>
// CHECK: %[[MASKREAD:.*]] = vector.mask %[[CREATEMASK]] { vector.transfer_read {{.*}}, %[[CST]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x16xf32> } : vector<4x16xi1> -> vector<4x16xf32>
// CHECK: %[[SHAPECAST:.*]] = vector.shape_cast %[[MASKREAD]] : vector<4x16xf32> to vector<1x4x1x16xf32>
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPECAST]], [0, 2, 1, 3] : vector<1x4x1x16xf32> to vector<1x1x4x16xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x4x16xf32>
// CHECK: %[[WRITE0:.*]] = vector.transfer_write %[[TRANSPOSE]], %[[EMPTY]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<1x1x4x16xf32>, tensor<1x1x4x16xf32>
func.func @test_pack_dynamic_shape_test9(%arg0: tensor<?x?xf32>, %arg1: tensor<1x1x4x16xf32>) -> tensor<1x1x4x16xf32> {
%cst = arith.constant 0.0 : f32
%pack = tensor.pack %arg0 padding_value(%cst : f32) outer_dims_perm = [0, 1]
inner_dims_pos = [0, 1] inner_tiles = [4, 16]
into %arg1 : tensor<?x?xf32> -> tensor<1x1x4x16xf32>
return %pack : tensor<1x1x4x16xf32>
}
// CHECK-LABEL: @test_pad_dynamic_shape_test10
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<4x16xf32>
// CHECK: %[[READ0:.*]] = vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<4x16xf32>
// CHECK: %[[WRITE0:.*]] = vector.transfer_write %{{.*}} {in_bounds = [true, true]} : vector<4x16xf32>, tensor<4x16xf32>
func.func @test_pad_dynamic_shape_test9(%arg0: tensor<?x?xf32>) -> tensor<4x16xf32> {
func.func @test_pad_dynamic_shape_test10(%arg0: tensor<?x?xf32>) -> tensor<4x16xf32> {
%f0 = arith.constant 0.0 : f32
%pad = tensor.pad %arg0 low[0, 0] high[1,1] {
^bb0(%arg3: index, %arg4: index):
@@ -205,13 +215,7 @@ func.func @test_add_dynamic_shape_test11(%arg0: tensor<?x?xf32>, %arg1: tensor<4
return %1 : tensor<4x16xf32>
}

func.func @test_collapse_dynamic_shape_test12(%arg0: tensor<?x?x?xf32>) -> tensor<?x?xf32> {
// expected-error @+1 {{Fail to vectorize.}}
%0 = tensor.collapse_shape %arg0 [[0], [1, 2]]: tensor<?x?x?xf32> into tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}

func.func @test_expand_dynamic_shape_test13(%arg0 : tensor<2x?x5xf32>, %sz0: index) ->tensor<2x6x?x5xf32> {
func.func @test_expand_dynamic_shape_test12(%arg0 : tensor<2x?x5xf32>, %sz0: index) ->tensor<2x6x?x5xf32> {
// expected-error @+1 {{Fail to vectorize.}}
%0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [2, 6, %sz0, 5] : tensor<2x?x5xf32> into tensor<2x6x?x5xf32>
return %0 : tensor<2x6x?x5xf32>