diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index b4dde776822a1..dc4e6718907f2 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2347,6 +2347,9 @@ def VectorizeChildrenAndApplyPatternsOp :
     operation that is contained inside the vectorization target.
 
     This transformation supports the following attributes:
+    - `vectorize_mixed_precision`: a `UnitAttr` to activate the vectorization
+      of ops with mixed-precision types. This enables folding `arith` extension
+      ops (e.g. `arith.extf`) into `vector.contract` with mixed precision.
     - `vectorize_padding`: a `UnitAttr` to activate the vectorization of
       `tensor.pad` ops. Different pipelines may prefer to lower such ops to
       loops.
@@ -2367,6 +2370,7 @@ def VectorizeChildrenAndApplyPatternsOp :
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
+                   UnitAttr:$vectorize_mixed_precision,
                    UnitAttr:$vectorize_padding,
                    UnitAttr:$vectorize_nd_extract,
                    UnitAttr:$flatten_1d_depthwise_conv,
@@ -2380,6 +2384,7 @@ def VectorizeChildrenAndApplyPatternsOp :
 
   let builders = [
     OpBuilder<(ins "Value":$target,
+               CArg<"bool", "false">:$vectorizeMixedPrecision,
                CArg<"bool", "false">:$vectorizePadding,
                CArg<"bool", "false">:$vectorizeNDExtract,
                CArg<"bool", "false">:$flatten1DDepthwise)>
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5d5f9de465561..c8f256cf38c9d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3784,8 +3784,15 @@ LogicalResult TileUsingForallOp::verify() {
 void transform::VectorizeChildrenAndApplyPatternsOp::build(
     OpBuilder &builder, OperationState &result, Value target,
-    bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
+    bool vectorizeMixedPrecision, bool vectorizePadding, bool vectorizeExtract,
+    bool flatten1DDepthwiseConv) {
   result.addOperands(target);
+  if (vectorizeMixedPrecision) {
+    result.addAttribute(
+        VectorizeChildrenAndApplyPatternsOp::getVectorizeMixedPrecisionAttrName(
+            result.name),
+        builder.getUnitAttr());
+  }
   if (vectorizePadding) {
     result.addAttribute(
         VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
             result.name),
@@ -3876,6 +3883,10 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
 
   patterns.add<CopyVectorizationPattern>(ctx);
 
+  if (getVectorizeMixedPrecision()) {
+    vector::populateFoldArithExtensionPatterns(patterns);
+  }
+
   if (getVectorizePadding()) {
     linalg::populatePadOpVectorizationPatterns(patterns);
     // This creates an alternative path for lowering tensor.pad - by
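The new attribute only adds `vector::populateFoldArithExtensionPatterns` to the pattern set that already runs after vectorization. As a reading aid, here is a hand-written before/after sketch of the rewrite those patterns perform; the population function is real upstream API, but the IR below (names, shapes, and the bf16/f32 choice) is invented for illustration:

```mlir
// Illustration only: folding an arith.extf pair into vector.contract.
// Shapes and names are made up for this sketch.
func.func @fold_ext_into_contract(
    %a: vector<8x16xbf16>, %b: vector<16x32xbf16>,
    %acc: vector<8x32xf32>) -> vector<8x32xf32> {
  // Before folding: both operands are widened explicitly, so the whole
  // contraction runs in f32.
  %lhs = arith.extf %a : vector<8x16xbf16> to vector<8x16xf32>
  %rhs = arith.extf %b : vector<16x32xbf16> to vector<16x32xf32>
  %res = vector.contract {indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                                           affine_map<(m, n, k) -> (k, n)>,
                                           affine_map<(m, n, k) -> (m, n)>],
                          iterator_types = ["parallel", "parallel", "reduction"],
                          kind = #vector.kind<add>}
    %lhs, %rhs, %acc : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
  // After folding, the two arith.extf ops are gone and the contraction
  // consumes the narrow operands directly:
  //   %res = vector.contract ... %a, %b, %acc
  //            : vector<8x16xbf16>, vector<16x32xbf16> into vector<8x32xf32>
  return %res : vector<8x32xf32>
}
```

Keeping the operands narrow until the contraction gives later lowerings the chance to pick mixed-precision instructions instead of widening in separate vector ops.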
diff --git a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
index 0d59dbba8940d..96f89653d20ca 100644
--- a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
@@ -190,3 +190,92 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// Mixed precision vectorization tests.
+
+// CHECK-LABEL: func @mixed_precision_generic_as_contract
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract
+// CHECK: vector.transfer_write
+func.func @mixed_precision_generic_as_contract(%A: memref<8x16xbf16>, %B: memref<16x32xbf16>,
+                                               %C: memref<8x32xf32>) {
+  linalg.generic {
+    indexing_maps = [
+      affine_map<(m, n, k) -> (m, k)>,
+      affine_map<(m, n, k) -> (k, n)>,
+      affine_map<(m, n, k) -> (m, n)>
+    ],
+    iterator_types = ["parallel", "parallel", "reduction"]
+  }
+  ins(%A, %B : memref<8x16xbf16>, memref<16x32xbf16>)
+  outs(%C : memref<8x32xf32>) {
+  ^bb(%in: bf16, %in_0: bf16, %c: f32):
+    %a = arith.extf %in : bf16 to f32
+    %b = arith.extf %in_0 : bf16 to f32
+    %d = arith.mulf %a, %b : f32
+    %e = arith.addf %c, %d : f32
+    linalg.yield %e : f32
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @mixed_precision_matmul_as_contract
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract
+// CHECK: vector.transfer_write
+func.func @mixed_precision_matmul_as_contract(%A: tensor<24x12xbf16>,
+                                              %B: tensor<12x25xbf16>,
+                                              %C: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  %0 = linalg.contract
+      indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+                       affine_map<(m, n, k) -> (k, n)>,
+                       affine_map<(m, n, k) -> (m, n)>]
+      ins(%A, %B : tensor<24x12xbf16>, tensor<12x25xbf16>)
+      outs(%C : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @contraction_matmul
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract
+func.func @contraction_matmul(%A: memref<1584x1584xbf16>, %B: memref<1584x1584xbf16>, %C: memref<1584x1584xf32>) {
+  linalg.matmul ins(%A, %B : memref<1584x1584xbf16>, memref<1584x1584xbf16>)
+                outs(%C : memref<1584x1584xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
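The FileCheck directives in these tests only constrain the op sequence. Under `vectorize_mixed_precision`, the IR they accept for `@mixed_precision_matmul_as_contract` has roughly the following shape; this is a hand-written sketch matching the test's types, not verbatim output of the transform:

```mlir
// Sketch of the expected vectorized form: three transfer_reads, one
// mixed-precision contraction, one transfer_write, and no arith.extf.
func.func @expected_shape(%A: tensor<24x12xbf16>, %B: tensor<12x25xbf16>,
                          %C: tensor<24x25xf32>) -> tensor<24x25xf32> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.0 : bf16
  %pad_acc = arith.constant 0.0 : f32
  %lhs = vector.transfer_read %A[%c0, %c0], %pad {in_bounds = [true, true]}
           : tensor<24x12xbf16>, vector<24x12xbf16>
  %rhs = vector.transfer_read %B[%c0, %c0], %pad {in_bounds = [true, true]}
           : tensor<12x25xbf16>, vector<12x25xbf16>
  %acc = vector.transfer_read %C[%c0, %c0], %pad_acc {in_bounds = [true, true]}
           : tensor<24x25xf32>, vector<24x25xf32>
  %res = vector.contract {indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                                           affine_map<(m, n, k) -> (k, n)>,
                                           affine_map<(m, n, k) -> (m, n)>],
                          iterator_types = ["parallel", "parallel", "reduction"],
                          kind = #vector.kind<add>}
    %lhs, %rhs, %acc : vector<24x12xbf16>, vector<12x25xbf16> into vector<24x25xf32>
  %out = vector.transfer_write %res, %C[%c0, %c0] {in_bounds = [true, true]}
           : vector<24x25xf32>, tensor<24x25xf32>
  return %out : tensor<24x25xf32>
}
```

Without the flag, the same pipeline still vectorizes the ops but leaves the `arith.extf` pair in place, which is exactly what the `CHECK-NOT: arith.extf` lines guard against.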