diff --git a/xla/backends/gpu/codegen/triton/BUILD b/xla/backends/gpu/codegen/triton/BUILD index 20c51158f5ffb..6ed31c874f613 100644 --- a/xla/backends/gpu/codegen/triton/BUILD +++ b/xla/backends/gpu/codegen/triton/BUILD @@ -414,6 +414,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/codegen:emitter_loc_op_builder", "//xla/hlo/ir:hlo", + "//xla/hlo/translate/hlo_to_mhlo:attribute_importer", "//xla/service:algorithm_util", "//xla/service/llvm_ir:llvm_util", "//xla/tsl/platform:errors", diff --git a/xla/backends/gpu/codegen/triton/dot_algorithms.cc b/xla/backends/gpu/codegen/triton/dot_algorithms.cc index 97c767c817066..bd12198ffbef8 100644 --- a/xla/backends/gpu/codegen/triton/dot_algorithms.cc +++ b/xla/backends/gpu/codegen/triton/dot_algorithms.cc @@ -41,6 +41,7 @@ limitations under the License. #include "xla/codegen/emitter_loc_op_builder.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/translate/hlo_to_mhlo/attribute_importer.h" #include "xla/service/algorithm_util.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/tsl/platform/errors.h" @@ -70,12 +71,26 @@ struct PrecisionSpec { PrecisionConfig::Algorithm algorithm; // TODO(bchetioui): we hope to get rid of operand precisions eventually, they // are currently a (XLA-wide) bridge to work with ALG_UNSET. - PrecisionConfig::Precision lhs_operand_precision; - PrecisionConfig::Precision rhs_operand_precision; - // Encodes `tt.dot`'s `inputPrecision` attribute. - ttir::InputPrecision ttir_input_precision; + mlir::stablehlo::Precision lhs_operand_precision; + mlir::stablehlo::Precision rhs_operand_precision; + // Encodes `stablehlo.dot`'s `precision` attribute. + mlir::stablehlo::DotAlgorithmAttr stablehlo_dot_algorithm; }; +mlir::stablehlo::Precision XlaPrecisionToStableHloPrecision( + PrecisionConfig::Precision precision) { + switch (precision) { + case PrecisionConfig::DEFAULT: + return mlir::stablehlo::Precision::DEFAULT; + case PrecisionConfig::HIGH: + return mlir::stablehlo::Precision::HIGH; + case PrecisionConfig::HIGHEST: + return mlir::stablehlo::Precision::HIGHEST; + default: + LOG(FATAL) << "Unsupported precision: " << precision; + } +} + using AlgorithmEmitter = absl::StatusOr (*)(EmitterLocOpBuilder, const DotOperands&, const PrecisionSpec&); @@ -170,10 +185,48 @@ absl::StatusOr ScaledDot(EmitterLocOpBuilder b, rhs_dot_elem_type, true); } -Value IEEEDot(EmitterLocOpBuilder b, Value lhs, Value rhs, Value acc) { - return b.create(lhs, rhs, acc, - /*inputPrecision=*/ttir::InputPrecision::IEEE, - /*maxNumImpreciseAcc=*/0); +namespace { + +Value EmitStableHloDotAndAdd(EmitterLocOpBuilder b, Value lhs, Value rhs, + Value acc, PrecisionSpec precision_spec) { + auto lhs_type = mlir::cast(lhs.getType()); + auto rhs_type = mlir::cast(rhs.getType()); + + CHECK(lhs_type.getRank() <= 2 && rhs_type.getRank() <= 2) + << "Unsupported ranks. LHS rank: " << lhs_type.getRank() + << " RHS rank: " << rhs_type.getRank(); + + llvm::SmallVector array_attr{0}; + auto dot_dimension_numbers = mlir::stablehlo::DotDimensionNumbersAttr::get( + b.getContext(), /*lhsBatchingDimensions=*/{}, + /*rhsBatchingDimensions=*/{}, + /*lhsContractingDimensions=*/ + {lhs_type.getRank() - 1}, + /*rhsContractingDimensions=*/ + {0}); + + auto precision_config = mlir::stablehlo::PrecisionConfigAttr::get( + b.getContext(), {precision_spec.lhs_operand_precision, + precision_spec.rhs_operand_precision}); + + auto dot = b.create( + acc.getType(), lhs, rhs, dot_dimension_numbers, + /*precision_config=*/precision_config, + /*algorithm=*/precision_spec.stablehlo_dot_algorithm); + + auto add_result = + mlir::isa(dot.getResult().getType().getElementType()) + ? b.create(acc, dot) + : b.create(acc, dot); + return add_result->getResult(0); +} + +} // namespace + +Value IEEEDot(EmitterLocOpBuilder b, Value lhs, Value rhs, Value acc, + PrecisionSpec precision_spec) { + return EmitStableHloDotAndAdd(b, lhs, rhs, acc, + /*precision_spec=*/precision_spec); } // Leverages BF16 datatype for F32 matmul computation. It follows the guidance @@ -196,20 +249,25 @@ absl::StatusOr EmitBF16x9Matmul(EmitterLocOpBuilder b, Value result = triton::ZerosLike(b, dot_operands.accumulator); - result = IEEEDot(b, lhs_parts[kLow], rhs_parts[kLow], result); - result = IEEEDot(b, lhs_parts[kMid], rhs_parts[kLow], result); - result = IEEEDot(b, lhs_parts[kLow], rhs_parts[kMid], result); + result = IEEEDot(b, lhs_parts[kLow], rhs_parts[kLow], result, precision_spec); + result = IEEEDot(b, lhs_parts[kMid], rhs_parts[kLow], result, precision_spec); + result = IEEEDot(b, lhs_parts[kLow], rhs_parts[kMid], result, precision_spec); - result = IEEEDot(b, lhs_parts[kMid], rhs_parts[kMid], result); + result = IEEEDot(b, lhs_parts[kMid], rhs_parts[kMid], result, precision_spec); - result = IEEEDot(b, lhs_parts[kLow], rhs_parts[kHigh], result); - result = IEEEDot(b, lhs_parts[kHigh], rhs_parts[kLow], result); + result = + IEEEDot(b, lhs_parts[kLow], rhs_parts[kHigh], result, precision_spec); + result = + IEEEDot(b, lhs_parts[kHigh], rhs_parts[kLow], result, precision_spec); - result = IEEEDot(b, lhs_parts[kMid], rhs_parts[kHigh], result); - result = IEEEDot(b, lhs_parts[kHigh], rhs_parts[kMid], result); + result = + IEEEDot(b, lhs_parts[kMid], rhs_parts[kHigh], result, precision_spec); + result = + IEEEDot(b, lhs_parts[kHigh], rhs_parts[kMid], result, precision_spec); result = ZeroNaNs(b, result); - result = IEEEDot(b, lhs_parts[kHigh], rhs_parts[kHigh], result); + result = + IEEEDot(b, lhs_parts[kHigh], rhs_parts[kHigh], result, precision_spec); result = b.create(dot_operands.accumulator, result); return result; } @@ -234,16 +292,21 @@ absl::StatusOr EmitBF16x6Matmul(EmitterLocOpBuilder b, Value result = triton::ZerosLike(b, dot_operands.accumulator); - result = IEEEDot(b, lhs_parts[kMid], rhs_parts[kMid], result); + result = IEEEDot(b, lhs_parts[kMid], rhs_parts[kMid], result, precision_spec); - result = IEEEDot(b, lhs_parts[kLow], rhs_parts[kHigh], result); - result = IEEEDot(b, lhs_parts[kHigh], rhs_parts[kLow], result); + result = + IEEEDot(b, lhs_parts[kLow], rhs_parts[kHigh], result, precision_spec); + result = + IEEEDot(b, lhs_parts[kHigh], rhs_parts[kLow], result, precision_spec); - result = IEEEDot(b, lhs_parts[kMid], rhs_parts[kHigh], result); - result = IEEEDot(b, lhs_parts[kHigh], rhs_parts[kMid], result); + result = + IEEEDot(b, lhs_parts[kMid], rhs_parts[kHigh], result, precision_spec); + result = + IEEEDot(b, lhs_parts[kHigh], rhs_parts[kMid], result, precision_spec); result = ZeroNaNs(b, result); - result = IEEEDot(b, lhs_parts[kHigh], rhs_parts[kHigh], result); + result = + IEEEDot(b, lhs_parts[kHigh], rhs_parts[kHigh], result, precision_spec); result = b.create(dot_operands.accumulator, result); return result; } @@ -266,32 +329,18 @@ absl::StatusOr EmitBF16x3Matmul(EmitterLocOpBuilder b, std::vector rhs_bf16 = SplitF32(b, dot_operands.rhs, kNumParts); Value result = triton::ZerosLike(b, dot_operands.accumulator); - result = IEEEDot(b, lhs_bf16[kLow], rhs_bf16[kHigh], result); - result = IEEEDot(b, lhs_bf16[kHigh], rhs_bf16[kLow], result); + result = IEEEDot(b, lhs_bf16[kLow], rhs_bf16[kHigh], result, precision_spec); + result = IEEEDot(b, lhs_bf16[kHigh], rhs_bf16[kLow], result, precision_spec); result = ZeroNaNs(b, result); - result = IEEEDot(b, lhs_bf16[kHigh], rhs_bf16[kHigh], result); + result = IEEEDot(b, lhs_bf16[kHigh], rhs_bf16[kHigh], result, precision_spec); result = b.create(dot_operands.accumulator, result); return result; } -bool IsTf32Allowed(const HloDotInstruction& dot) { - auto precision_config = dot.precision_config(); - if (precision_config.algorithm() == PrecisionConfig::ALG_UNSET) { - return tsl::tensor_float_32_execution_enabled() && - precision_config.operand_precision(0) == PrecisionConfig::DEFAULT && - precision_config.operand_precision(1) == PrecisionConfig::DEFAULT; - } - return algorithm_util::HasTf32InputType(precision_config.algorithm()); -} - -ttir::InputPrecision InferDotPrecision(const HloDotInstruction& dot) { - if (dot.precision_config().algorithm() == - PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3) { - return ttir::InputPrecision::TF32x3; - } - - return IsTf32Allowed(dot) ? ttir::InputPrecision::TF32 - : ttir::InputPrecision::IEEE; +mlir::stablehlo::DotAlgorithmAttr InferDotPrecision( + const HloDotInstruction& dot, EmitterLocOpBuilder& builder) { + return stablehlo::ConvertDotAlgorithm(dot.precision_config().algorithm(), + &builder); } absl::StatusOr GetAlgUnsetAccumulatorType(EmitterLocOpBuilder b, @@ -333,18 +382,8 @@ absl::StatusOr EmitDotAlgUnset(EmitterLocOpBuilder b, Value rhs = dot_operands.rhs; Value acc = dot_operands.accumulator; - int max_num_imprecise_acc = 0; - if (ElementType(lhs).isFloat(8) || ElementType(rhs).isFloat(8)) { - // For fp8 dots, disable accumulator promotion to mimick cuBLAS. It may make - // sense to enable frequent accumulator promotion at higher matmul - // precisions set in the config. - max_num_imprecise_acc = std::numeric_limits::max(); - } - - return b.create( - lhs, rhs, acc, - /*inputPrecision=*/precision_spec.ttir_input_precision, - /*maxNumImpreciseAcc=*/max_num_imprecise_acc); + return EmitStableHloDotAndAdd(b, lhs, rhs, acc, + /*precision_spec=*/precision_spec); } absl::StatusOr EmitRegularDot(EmitterLocOpBuilder b, @@ -353,14 +392,6 @@ absl::StatusOr EmitRegularDot(EmitterLocOpBuilder b, Value lhs = dot_operands.lhs; Value rhs = dot_operands.rhs; - int max_num_imprecise_acc = 0; - if (ElementType(lhs).isFloat(8) || ElementType(rhs).isFloat(8)) { - // For fp8 dots, disable accumulator promotion to mimick cuBLAS. It may make - // sense to enable frequent accumulator promotion at higher matmul - // precisions set in the config. - max_num_imprecise_acc = std::numeric_limits::max(); - } - // Cast F32 inputs to BF16 if the algorithm is BF16_BF16_F32. // TODO(bchetioui): abstract this. if (precision_spec.algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32) { @@ -373,10 +404,9 @@ absl::StatusOr EmitRegularDot(EmitterLocOpBuilder b, } } - return b.create( - dot_operands.lhs, dot_operands.rhs, dot_operands.accumulator, - /*inputPrecision=*/precision_spec.ttir_input_precision, - /*maxNumImpreciseAcc=*/max_num_imprecise_acc); + return EmitStableHloDotAndAdd(b, dot_operands.lhs, dot_operands.rhs, + dot_operands.accumulator, + /*precision_spec=*/precision_spec); } // Returns an emitter for the given dot algorithm. Raises an @@ -489,9 +519,12 @@ absl::StatusOr EmitSingleTileDot(EmitterLocOpBuilder b, const HloDotInstruction& dot, DotOperands dot_operands) { PrecisionConfig::Algorithm algorithm = dot.precision_config().algorithm(); - PrecisionSpec precision_spec{ - algorithm, dot.precision_config().operand_precision(0), - dot.precision_config().operand_precision(1), InferDotPrecision(dot)}; + PrecisionSpec precision_spec{algorithm, + XlaPrecisionToStableHloPrecision( + dot.precision_config().operand_precision(0)), + XlaPrecisionToStableHloPrecision( + dot.precision_config().operand_precision(1)), + InferDotPrecision(dot, b)}; TF_ASSIGN_OR_RETURN(AlgorithmEmitter algorithm_emitter, GetAlgorithmEmitter(algorithm)); diff --git a/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc b/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc index d1f748027bba5..f820dbc51da27 100644 --- a/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc +++ b/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc @@ -3409,10 +3409,20 @@ ENTRY entry { "num_ctas":"1", "num_stages":"1"}}} })"; - // We expect that for loop instruction will be optimized away. - TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "fdot", R"( + + TF_ASSERT_OK_AND_ASSIGN(auto xtile_module_and_hlo_module, + CreateXTileIrAndFileCheck(this, kHloText, "fdot", + R"( +CHECK: stablehlo.dot_general +CHECK: arith.addf + )")); + + TF_ASSERT_OK(LowerXTileIrToTritonAndFileCheck( + this, xtile_module_and_hlo_module.first.get(), R"( CHECK: tt.dot {{.*}} -> tensor<16x16xf32> -)")); + )", + GetFusionInstruction(*xtile_module_and_hlo_module.second, "fdot"))); + EXPECT_TRUE(RunAndCompareNoHloPasses( kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6})); } @@ -3472,8 +3482,32 @@ ENTRY entry { const bool is_tma_allowed = GetParam(); std::string hlo_text = absl::Substitute(kHloTextTemplate, is_tma_allowed); - TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, hlo_text, "fdot", R"( -CHECK: xtile.entry_func @triton_fn(%[[ARG0:[A-Za-z0-9_]*]]: memref<32x123xf32> + + TF_ASSERT_OK_AND_ASSIGN(auto xtile_module_and_hlo_module, + CreateXTileIrAndFileCheck(this, hlo_text, "fdot", + R"( +CHECK: xtile.entry_func @xtile_dialect_fn(%[[ARG0:[A-Za-z0-9_]*]]: memref<32x123xf32> +CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: memref<123x512xf32> +CHECK-SAME: %[[ARG2:[A-Za-z0-9_]*]]: memref<32x512xf32> +CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +CHECK: {{.*}} = scf.for %{{.*}} = %[[C0]] to %[[C4]] step %[[C1]] +CHECK-SAME: iter_args({{.*}}) -> (tensor<16x64xf32>) { +CHECK-DAG: xtile.extract %[[ARG0]] +CHECK-DAG: xtile.extract %[[ARG1]] +CHECK-DAG: arith.subf {{.*}} : tensor<16x32xf32> +CHECK-DAG: math.absf {{.*}} : tensor<32x64xf32> +CHECK: stablehlo.dot_general {{.*}} (tensor<16x32xf32>, tensor<32x64xf32>) -> tensor<16x64xf32> +CHECK: arith.addf {{.*}} +CHECK: scf.yield {{.*}} : tensor<16x64xf32> +CHECK-COUNT-1: xtile.insert + + )")); + + TF_ASSERT_OK(LowerXTileIrToTritonAndFileCheck( + this, xtile_module_and_hlo_module.first.get(), R"( +CHECK: xtile.entry_func @xtile_dialect_fn(%[[ARG0:[A-Za-z0-9_]*]]: memref<32x123xf32> CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: memref<123x512xf32> CHECK-SAME: %[[ARG2:[A-Za-z0-9_]*]]: memref<32x512xf32> CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index @@ -3488,7 +3522,9 @@ CHECK-DAG: math.absf {{.*}} : tensor<32x64xf32> CHECK: tt.dot {{.*}} tensor<16x32xf32> * tensor<32x64xf32> -> tensor<16x64xf32> CHECK: scf.yield {{.*}} : tensor<16x64xf32> CHECK-COUNT-1: xtile.insert -)")); + )", + GetFusionInstruction(*xtile_module_and_hlo_module.second, "fdot"))); + EXPECT_TRUE(RunAndCompareNoHloPasses( hlo_text, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6})); } @@ -3610,7 +3646,21 @@ ENTRY entry { "num_stages":"1"}}} })"; - TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "fdot", R"( + TF_ASSERT_OK_AND_ASSIGN(auto xtile_module_and_hlo_module, + CreateXTileIrAndFileCheck(this, kHloText, "fdot", R"( + // Ensure that masking is applied only conditionally to both operands. + CHECK: %[[MASKED_OPERAND0:.*]] = scf.if + CHECK: %[[SELECT0:.*]] = arith.select + CHECK-NEXT: scf.yield %[[SELECT0]] + CHECK: %[[MASKED_OPERAND1:.*]] = scf.if + CHECK: %[[SELECT1:.*]] = arith.select + CHECK-NEXT: scf.yield %[[SELECT1]] + CHECK: stablehlo.dot_general %[[MASKED_OPERAND0]], %[[MASKED_OPERAND1]] + CHECK: arith.addf %{{.*}} + )")); + + TF_ASSERT_OK(LowerXTileIrToTritonAndFileCheck( + this, xtile_module_and_hlo_module.first.get(), R"( // Ensure that masking is applied only conditionally to both operands. CHECK: %[[MASKED_OPERAND0:.*]] = scf.if CHECK: %[[SELECT0:.*]] = arith.select @@ -3619,7 +3669,8 @@ ENTRY entry { CHECK: %[[SELECT1:.*]] = arith.select CHECK-NEXT: scf.yield %[[SELECT1]] CHECK: tt.dot %[[MASKED_OPERAND0]], %[[MASKED_OPERAND1]] -)")); + )", + GetFusionInstruction(*xtile_module_and_hlo_module.second, "fdot"))); EXPECT_TRUE(RunAndCompareNoHloPasses( kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6})); @@ -4157,13 +4208,26 @@ TEST_P(BasicDotAlgorithmEmitterTest, BasicAlgorithmIsEmittedCorrectly) { algorithm_util::GetDotAccumulatorType(algorithm)); const std::string kHloText = GetDotAlgorithmHlo(in_ty, out_ty, algorithm); - TF_EXPECT_OK(CreateTritonIrAndFileCheck( - this, kHloText, "dot", + TF_ASSERT_OK_AND_ASSIGN( + auto xtile_module_and_hlo_module, + CreateXTileIrAndFileCheck( + this, kHloText, "dot", + absl::Substitute( + R"( + CHECK: stablehlo.dot_general{{.*}} : (tensor<16x32x$0>, tensor<32x64x$0>) -> tensor<16x64x$1> + CHECK: arith.addf + )", + primitive_util::LowercasePrimitiveTypeName(in_ty), + primitive_util::LowercasePrimitiveTypeName(out_ty)))); + + TF_ASSERT_OK(LowerXTileIrToTritonAndFileCheck( + this, xtile_module_and_hlo_module.first.get(), absl::Substitute(R"( CHECK: tt.dot{{.*}} : tensor<16x32x$0> * tensor<32x64x$0> -> tensor<16x64x$1> )", primitive_util::LowercasePrimitiveTypeName(in_ty), - primitive_util::LowercasePrimitiveTypeName(out_ty)))); + primitive_util::LowercasePrimitiveTypeName(out_ty)), + GetFusionInstruction(*xtile_module_and_hlo_module.second, "dot"))); EXPECT_TRUE( RunAndCompareNoHloPasses(kHloText, ErrorSpecForDotAlgorithm(algorithm))); @@ -4220,14 +4284,27 @@ TEST_P(MultiDotAlgorithmEmitterTest, MultiDotAlgorithmIsEmittedCorrectly) { const std::string kHloText = GetDotAlgorithmHlo(in_ty, out_ty, algorithm); - TF_EXPECT_OK(CreateTritonIrAndFileCheck( - this, kHloText, "dot", + TF_ASSERT_OK_AND_ASSIGN( + auto xtile_module_and_hlo_module, + CreateXTileIrAndFileCheck( + this, kHloText, "dot", + absl::Substitute( + R"( + CHECK-COUNT-$2: stablehlo.dot_general{{.*}} : (tensor<16x32x$0>, tensor<32x64x$0>) -> tensor<16x64x$1> + )", + primitive_util::LowercasePrimitiveTypeName(in_ty), + primitive_util::LowercasePrimitiveTypeName(out_ty), + dot_count_for_algorithm))); + + TF_ASSERT_OK(LowerXTileIrToTritonAndFileCheck( + this, xtile_module_and_hlo_module.first.get(), absl::Substitute(R"( CHECK-COUNT-$2: tt.dot{{.*}}$3{{.*}} : tensor<16x32x$0> * tensor<32x64x$0> -> tensor<16x64x$1> )", primitive_util::LowercasePrimitiveTypeName(in_ty), primitive_util::LowercasePrimitiveTypeName(out_ty), - dot_count_for_algorithm, input_precision_string))); + dot_count_for_algorithm, input_precision_string), + GetFusionInstruction(*xtile_module_and_hlo_module.second, "dot"))); EXPECT_TRUE( RunAndCompareNoHloPasses(kHloText, ErrorSpecForDotAlgorithm(algorithm))); @@ -4260,14 +4337,32 @@ TEST_P(TF32DotAlgorithmEmitterTest, TF32AlgorithmsUseTF32InputPrecision) { algorithm == PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3 ? "tf32x3" : "tf32"; - TF_EXPECT_OK(CreateTritonIrAndFileCheck( - this, kHloText, "dot", + std::string num_primitive_operations_string = + algorithm == PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3 ? "3" : "1"; + + // TODO(basioli): maybe algorithm string? + TF_ASSERT_OK_AND_ASSIGN( + auto xtile_module_and_hlo_module, + CreateXTileIrAndFileCheck( + this, kHloText, "dot", + absl::Substitute( + R"( + CHECK: stablehlo.dot_general{{.*}}, contracting_dims = [1] x [0], {{.*}} algorithm = : (tensor<16x32x$0>, tensor<32x64x$0>) -> tensor<16x64x$1> + )", + primitive_util::LowercasePrimitiveTypeName(in_ty), + primitive_util::LowercasePrimitiveTypeName(out_ty), + num_primitive_operations_string))); + + TF_ASSERT_OK(LowerXTileIrToTritonAndFileCheck( + this, xtile_module_and_hlo_module.first.get(), absl::Substitute(R"( CHECK: tt.dot{{.*}} inputPrecision = $2 : tensor<16x32x$0> * tensor<32x64x$0> -> tensor<16x64x$1> )", primitive_util::LowercasePrimitiveTypeName(in_ty), primitive_util::LowercasePrimitiveTypeName(out_ty), - input_precision_string))); + input_precision_string), + GetFusionInstruction(*xtile_module_and_hlo_module.second, "dot"))); + // No need to `RunAndCompare` here, these algorithms are already covered by // other tests. } diff --git a/xla/backends/gpu/codegen/triton/fusion_emitter_shared_dialect_test.cc b/xla/backends/gpu/codegen/triton/fusion_emitter_shared_dialect_test.cc index 60454cc8d41ad..41050c7d0a472 100644 --- a/xla/backends/gpu/codegen/triton/fusion_emitter_shared_dialect_test.cc +++ b/xla/backends/gpu/codegen/triton/fusion_emitter_shared_dialect_test.cc @@ -254,6 +254,61 @@ CHECK: %[[RES:.*]] = stablehlo.reshape %[[ARG:.*]] : (tensor<16xi32>) -> tensor< )")); } +TEST_F(XTileDialectTest, HloDotIsLoweredToStableHloDot) { + constexpr absl::string_view kHloText = R"( +HloModule t + +flhs { + ROOT flhs.p0 = f32[150,160] parameter(0) +} + +frhs { + ROOT frhs.p0 = f32[160,31] parameter(0) +} + +dot_fusion { + fdot.p0 = f32[150,160] parameter(0) + fdot.p1 = f32[160,31] parameter(1) + fdot.lhs = f32[150,160] fusion(fdot.p0), kind=kCustom, calls=flhs, backend_config={ + "fusion_backend_config":{ + "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ + "output_tiles":[{"sizes":["32", "8"]}] + } + } + } + fdot.rhs = f32[160,31]{1,0} fusion(fdot.p1), kind=kCustom, calls=frhs, backend_config={ + "fusion_backend_config":{ + "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ + "output_tiles":[{"sizes":["32", "8"]}] + } + } + } + + ROOT dot = f32[150,31] dot(fdot.lhs, fdot.rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f32[150, 160] parameter(0) + p1 = f32[160, 31] parameter(1) + ROOT custom-call = f32[150,31] fusion(p0, p1), kind=kCustom, + calls=dot_fusion, + backend_config={"fusion_backend_config": {kind: "__triton"}} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloText)); + + BlockLevelParameters block_level_parameters; + block_level_parameters.output_tile_sizes = {{32, 8}}; + + TF_EXPECT_OK(CreateXTileIrAndFileCheck( + this, *module->GetComputationWithName("dot_fusion"), + block_level_parameters, + R"( +CHECK: %[[RES:.*]] = stablehlo.dot_general %[[ARG0:.*]], %[[ARG1:.*]], contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<32x8xf32>, tensor<8x8xf32>) -> tensor<32x8xf32> +CHECK: %[[ADD_RES:.*]] = arith.addf %[[ARG2:.*]], %[[RES]] : tensor<32x8xf32> +)")); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/xla/backends/gpu/codegen/triton/transforms/BUILD b/xla/backends/gpu/codegen/triton/transforms/BUILD index 94d389bd79bc5..1b5f43d53d713 100644 --- a/xla/backends/gpu/codegen/triton/transforms/BUILD +++ b/xla/backends/gpu/codegen/triton/transforms/BUILD @@ -60,9 +60,10 @@ cc_library( "//xla/codegen:emitter_loc_op_builder", "//xla/codegen/emitters/ir:xla", "//xla/codegen/xtile/ir:xtile", + "//xla/hlo/translate/mhlo_to_hlo:attribute_exporter", + "//xla/service:algorithm_util", "//xla/service/gpu:target_util", "//xla/service/llvm_ir:llvm_util", - "//xla/stream_executor:device_description", "//xla/stream_executor/gpu:collective_kernel_metadata", "//xla/stream_executor/gpu:tma_metadata", "@com_google_absl//absl/algorithm:container", @@ -96,5 +97,6 @@ cc_library( "@llvm-project//mlir:TransformUtils", "@stablehlo//:stablehlo_ops", "@triton//:TritonDialects", + "@tsl//tsl/platform:tensor_float_32_hdr_lib", ], ) diff --git a/xla/backends/gpu/codegen/triton/transforms/stablehlo_lower_to_triton.cc b/xla/backends/gpu/codegen/triton/transforms/stablehlo_lower_to_triton.cc index e78f16e3abbc0..ef8f5e701454f 100644 --- a/xla/backends/gpu/codegen/triton/transforms/stablehlo_lower_to_triton.cc +++ b/xla/backends/gpu/codegen/triton/transforms/stablehlo_lower_to_triton.cc @@ -14,14 +14,19 @@ limitations under the License. ==============================================================================*/ #include +#include +#include #include #include #include "absl/log/check.h" +#include "absl/log/log.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" @@ -37,6 +42,10 @@ limitations under the License. #include "stablehlo/dialect/StablehloOps.h" #include "xla/codegen/emitter_loc_op_builder.h" #include "xla/codegen/xtile/ir/xtile_ops.h" +#include "xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h" +#include "xla/service/algorithm_util.h" +#include "tsl/platform/tensor_float_32_utils.h" +#include "third_party/triton/include/triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" namespace mlir::triton::xla { @@ -325,6 +334,171 @@ class LowerReshape : public mlir::OpRewritePattern { } }; +namespace { + +LogicalResult PopulateOperandPrecision(PatternRewriter& rewriter, + stablehlo::DotGeneralOp op, + stablehlo::Precision& lhs_precision, + stablehlo::Precision& rhs_precision) { + auto precision_config = op.getPrecisionConfig(); + + if (!precision_config.has_value()) { + return rewriter.notifyMatchFailure(op->getLoc(), + "Dot op must have precision config."); + } + + if (precision_config.value().size() != 2) { + return rewriter.notifyMatchFailure( + op->getLoc(), + "Dot op must have exactly two precisions. One for lhs and one for " + "rhs."); + } + + auto lhs_precision_attr = + mlir::cast(precision_config.value()[0]); + auto rhs_precision_attr = + mlir::cast(precision_config.value()[1]); + + lhs_precision = lhs_precision_attr.getValue(); + rhs_precision = rhs_precision_attr.getValue(); + + return mlir::success(); +} + +::xla::PrecisionConfig::Precision StableHloPrecisionToXlaPrecision( + stablehlo::Precision precision) { + switch (precision) { + case stablehlo::Precision::DEFAULT: + return ::xla::PrecisionConfig::DEFAULT; + case stablehlo::Precision::HIGH: + return ::xla::PrecisionConfig::HIGH; + case stablehlo::Precision::HIGHEST: + return ::xla::PrecisionConfig::HIGHEST; + default: + LOG(FATAL) << "Unsupported precision"; + } +} + +bool IsTf32Allowed(const stablehlo::Precision lhs, + const stablehlo::Precision rhs, + const ::xla::PrecisionConfig::Algorithm algorithm) { + if (algorithm == ::xla::PrecisionConfig::ALG_UNSET) { + return tsl::tensor_float_32_execution_enabled() && + StableHloPrecisionToXlaPrecision(lhs) == + ::xla::PrecisionConfig::DEFAULT && + StableHloPrecisionToXlaPrecision(rhs) == + ::xla::PrecisionConfig::DEFAULT; + } + return ::xla::algorithm_util::HasTf32InputType(algorithm); +} + +LogicalResult GetTritonInputPrecision( + PatternRewriter& rewriter, stablehlo::DotGeneralOp op, + ttir::InputPrecision& triton_input_precision) { + stablehlo::Precision lhs_precision; + stablehlo::Precision rhs_precision; + + if (mlir::failed(PopulateOperandPrecision(rewriter, op, lhs_precision, + rhs_precision))) { + return mlir::failure(); + } + + auto dot_algorithm = op.getAlgorithm(); + + auto hlo_algorithm_or_status = + dot_algorithm.has_value() + ? ::xla::ConvertDotAlgorithm(dot_algorithm.value()) + : ::xla::PrecisionConfig::ALG_UNSET; + + if (!hlo_algorithm_or_status.ok()) { + return rewriter.notifyMatchFailure( + op->getLoc(), + "Dot op must have algorithm set to be converted to " + "triton dot."); + } + + auto hlo_algorithm = hlo_algorithm_or_status.value(); + + if (hlo_algorithm == ::xla::PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3) { + triton_input_precision = ttir::InputPrecision::TF32x3; + } else if (hlo_algorithm == + ::xla::PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 || + hlo_algorithm == + ::xla::PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6 || + hlo_algorithm == + ::xla::PrecisionConfig::ALG_DOT_BF16_BF16_F32_X9) { + triton_input_precision = ttir::InputPrecision::IEEE; + } else { + triton_input_precision = + IsTf32Allowed(lhs_precision, rhs_precision, hlo_algorithm) + ? ttir::InputPrecision::TF32 + : ttir::InputPrecision::IEEE; + ; + } + + return mlir::success(); +} + +} // namespace + +class LowerDotGeneral : public mlir::OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + private: + mlir::LogicalResult matchAndRewrite( + stablehlo::DotGeneralOp op, + mlir::PatternRewriter& rewriter) const override { + if (std::distance(op->getUsers().begin(), op->getUsers().end()) != 1) { + return rewriter.notifyMatchFailure( + op->getLoc(), + "Dot op must have exactly one user in order to be lowered to " + "triton."); + } + + mlir::Operation* add_op = dyn_cast(*op->getUsers().begin()); + if (!add_op) { + add_op = dyn_cast(*op->getUsers().begin()); + } + + if (!add_op) { + return rewriter.notifyMatchFailure( + op->getLoc(), + "Dot op must be consumed by an AddOp in order to be convertible to " + "triton dot."); + } + + int max_num_imprecise_acc = 0; + + if (op.getLhs().getType().getElementType().isFloat(8) || + op.getRhs().getType().getElementType().isFloat(8)) { + // For fp8 dots, disable accumulator promotion to mimick cuBLAS. It may + // make sense to enable frequent accumulator promotion at higher matmul + // precisions set in the config. + max_num_imprecise_acc = std::numeric_limits::max(); + } + + // Accumulator is the operand of add that is not the dot operation. + auto accumulator = add_op->getOperand(1) == op ? add_op->getOperand(0) + : add_op->getOperand(1); + + ttir::InputPrecision triton_input_precision; + if (mlir::failed( + GetTritonInputPrecision(rewriter, op, triton_input_precision))) { + return mlir::failure(); + } + + auto triton_dot_op = + ttir::DotOp::create(rewriter, op.getLoc(), op.getResult().getType(), + op.getLhs(), op.getRhs(), accumulator, + triton_input_precision, max_num_imprecise_acc); + + rewriter.replaceAllOpUsesWith(add_op, op.getResult()); + rewriter.replaceOp(op, triton_dot_op); + return mlir::success(); + } +}; + class StableHLOLowerToTritonPass : public impl::StableHLOLowerToTritonPassBase { public: @@ -332,7 +506,7 @@ class StableHLOLowerToTritonPass mlir::MLIRContext* mlir_context = &getContext(); mlir::RewritePatternSet patterns(mlir_context); patterns.add(mlir_context); + LowerReduce, LowerReshape, LowerDotGeneral>(mlir_context); if (mlir::failed( mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) { diff --git a/xla/backends/gpu/codegen/triton/transforms/tests/stable_hlo_to_triton_lowering.mlir b/xla/backends/gpu/codegen/triton/transforms/tests/stable_hlo_to_triton_lowering.mlir index eb939634ffac8..5debaa5454168 100644 --- a/xla/backends/gpu/codegen/triton/transforms/tests/stable_hlo_to_triton_lowering.mlir +++ b/xla/backends/gpu/codegen/triton/transforms/tests/stable_hlo_to_triton_lowering.mlir @@ -164,3 +164,31 @@ func.func @reshape_2d_to_0d_reduces(%arg0: tensor<1x1xf32>) -> tensor { // CHECK: return %[[TO_TENSOR]] return %0 : tensor } + +// CHECK: func @lower_dot_add_to_triton(%[[ARG0:.*]]: tensor<2x4xf32>, %[[ARG1:.*]]: tensor<4x8xf32>, %[[ARG2:.*]]: tensor<2x8xf32>) -> tensor<2x8xf32> +func.func @lower_dot_add_to_triton(%arg0: tensor<2x4xf32>, %arg1: tensor<4x8xf32>, %arg2: tensor<2x8xf32>) -> tensor<2x8xf32> { + // CHECK: %[[RES:.*]] = tt.dot %[[ARG0]], %[[ARG1]], %[[ARG2]], inputPrecision = tf32 : tensor<2x4xf32> * tensor<4x8xf32> -> tensor<2x8xf32> + // CHECK-NOT: arith.addf + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x4xf32>, tensor<4x8xf32>) -> tensor<2x8xf32> + %1 = arith.addf %0, %arg2 : tensor<2x8xf32> + // CHECK: return %[[RES]] : tensor<2x8xf32> + return %1 : tensor<2x8xf32> +} + +// CHECK: func @lower_dot_without_add_falls_back_to_stablehlo(%[[ARG0:.*]]: tensor<2x4xf32>, %[[ARG1:.*]]: tensor<4x8xf32>, %[[ARG2:.*]]: tensor<2x8xf32>) -> tensor<2x8xf32> +func.func @lower_dot_without_add_falls_back_to_stablehlo(%arg0: tensor<2x4xf32>, %arg1: tensor<4x8xf32>, %arg2: tensor<2x8xf32>) -> tensor<2x8xf32> { + // CHECK: %[[RES:.*]] = stablehlo.dot_general %[[ARG0]], %[[ARG1]], contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x4xf32>, tensor<4x8xf32>) -> tensor<2x8xf32> + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x4xf32>, tensor<4x8xf32>) -> tensor<2x8xf32> + // CHECK: return %[[RES]] : tensor<2x8xf32> + return %0 : tensor<2x8xf32> +} + +// CHECK: func @lower_dot_f8_no_ieee_has_max_num_imprecise_acc_set_to_max(%[[ARG0:.*]]: tensor<2x4xf8E4M3FN>, %[[ARG1:.*]]: tensor<4x8xf8E4M3FN>, %[[ARG2:.*]]: tensor<2x8xf8E4M3FN>) -> tensor<2x8xf8E4M3FN> +func.func @lower_dot_f8_no_ieee_has_max_num_imprecise_acc_set_to_max(%arg0: tensor<2x4xf8E4M3FN>, %arg1: tensor<4x8xf8E4M3FN>, %arg2: tensor<2x8xf8E4M3FN>) -> tensor<2x8xf8E4M3FN> { + // CHECK: %[[RES:.*]] = tt.dot %[[ARG0]], %[[ARG1]], %[[ARG2]], inputPrecision = tf32 {maxNumImpreciseAcc = 2147483647 : i32} : tensor<2x4xf8E4M3FN> * tensor<4x8xf8E4M3FN> -> tensor<2x8xf8E4M3FN> + // CHECK-NOT: arith.addf + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x4xf8E4M3FN>, tensor<4x8xf8E4M3FN>) -> tensor<2x8xf8E4M3FN> + %1 = arith.addf %0, %arg2 : tensor<2x8xf8E4M3FN> + // CHECK: return %[[RES]] : tensor<2x8xf8E4M3FN> + return %1 : tensor<2x8xf8E4M3FN> +}