Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xla/backends/gpu/codegen/triton/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ cc_library(
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:Support",
"@stablehlo//:stablehlo_ops",
"@triton//:TritonDialects",
"@tsl//tsl/platform:tensor_float_32_hdr_lib",
],
Expand Down
66 changes: 33 additions & 33 deletions xla/backends/gpu/codegen/triton/dot_algorithms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ limitations under the License.
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/backends/gpu/codegen/triton/emitter_helpers.h"
#include "xla/codegen/emitter_loc_op_builder.h"
#include "xla/hlo/ir/hlo_instruction.h"
Expand Down Expand Up @@ -71,8 +72,8 @@ struct PrecisionSpec {
// are currently a (XLA-wide) bridge to work with ALG_UNSET.
PrecisionConfig::Precision lhs_operand_precision;
PrecisionConfig::Precision rhs_operand_precision;
// Encodes `tt.dot`'s `inputPrecision` attribute.
ttir::InputPrecision ttir_input_precision;
// Encodes `stablehlo.dot`'s `precision` attribute.
mlir::stablehlo::Precision stablehlo_input_precision;
};

using AlgorithmEmitter = absl::StatusOr<Value> (*)(EmitterLocOpBuilder,
Expand Down Expand Up @@ -171,10 +172,27 @@ absl::StatusOr<Value> ScaledDot(EmitterLocOpBuilder b,
rhs_dot_elem_type, true);
}

namespace {

// Emits a `stablehlo.dot` of `lhs` and `rhs` whose result type is the type of
// `acc`, and returns the `stablehlo.add` of `acc` and that dot.
//
// `input_precision` is wrapped in a single-element array attribute and handed
// to the dot op as its precision configuration.
Value EmitStableHloDotAndAdd(EmitterLocOpBuilder b, Value lhs, Value rhs,
                             Value acc,
                             mlir::stablehlo::Precision input_precision) {
  auto* context = b.getContext();
  auto precision_attr =
      mlir::stablehlo::PrecisionAttr::get(context, input_precision);
  auto precision_config = mlir::ArrayAttr::get(context, {precision_attr});

  // The dot result is typed like the accumulator so that it can be added to
  // it directly.
  auto product = b.create<mlir::stablehlo::DotOp>(acc.getType(), lhs, rhs,
                                                  precision_config);
  return b.create<mlir::stablehlo::AddOp>(acc, product);
}

} // namespace

// Emits a dot of `lhs` and `rhs` combined with the accumulator `acc`.
//
// NOTE(review): two consecutive `return` statements are shown below: a
// `ttir::DotOp` with IEEE input precision, followed by a call to
// `EmitStableHloDotAndAdd` with `Precision::HIGHEST`. As written, only the
// first is reachable. This looks like a diff rendering with the removed and
// the added line shown together; presumably only the second one is meant to
// remain. TODO confirm against the actual file.
Value IEEEDot(EmitterLocOpBuilder b, Value lhs, Value rhs, Value acc) {
return b.create<ttir::DotOp>(lhs, rhs, acc,
/*inputPrecision=*/ttir::InputPrecision::IEEE,
/*maxNumImpreciseAcc=*/0);
return EmitStableHloDotAndAdd(
b, lhs, rhs, acc,
/*input_precision=*/mlir::stablehlo::Precision::HIGHEST);
}

// Leverages BF16 datatype for F32 matmul computation. It follows the guidance
Expand Down Expand Up @@ -285,14 +303,14 @@ bool IsTf32Allowed(const HloDotInstruction& dot) {
return algorithm_util::HasTf32InputType(precision_config.algorithm());
}

ttir::InputPrecision InferDotPrecision(const HloDotInstruction& dot) {
mlir::stablehlo::Precision InferDotPrecision(const HloDotInstruction& dot) {
if (dot.precision_config().algorithm() ==
PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3) {
return ttir::InputPrecision::TF32x3;
return mlir::stablehlo::Precision::HIGH;
}

return IsTf32Allowed(dot) ? ttir::InputPrecision::TF32
: ttir::InputPrecision::IEEE;
return IsTf32Allowed(dot) ? mlir::stablehlo::Precision::DEFAULT
: mlir::stablehlo::Precision::HIGHEST;
}

absl::StatusOr<Type> GetAlgUnsetAccumulatorType(EmitterLocOpBuilder b,
Expand Down Expand Up @@ -334,18 +352,9 @@ absl::StatusOr<Value> EmitDotAlgUnset(EmitterLocOpBuilder b,
Value rhs = dot_operands.rhs;
Value acc = dot_operands.accumulator;

int max_num_imprecise_acc = 0;
if (ElementType(lhs).isFloat(8) || ElementType(rhs).isFloat(8)) {
// For fp8 dots, disable accumulator promotion to mimic cuBLAS. It may make
// sense to enable frequent accumulator promotion at higher matmul
// precisions set in the config.
max_num_imprecise_acc = std::numeric_limits<int>::max();
}

return b.create<ttir::DotOp>(
lhs, rhs, acc,
/*inputPrecision=*/precision_spec.ttir_input_precision,
/*maxNumImpreciseAcc=*/max_num_imprecise_acc);
return EmitStableHloDotAndAdd(
b, lhs, rhs, acc,
/*input_precision=*/precision_spec.stablehlo_input_precision);
}

absl::StatusOr<Value> EmitRegularDot(EmitterLocOpBuilder b,
Expand All @@ -354,14 +363,6 @@ absl::StatusOr<Value> EmitRegularDot(EmitterLocOpBuilder b,
Value lhs = dot_operands.lhs;
Value rhs = dot_operands.rhs;

int max_num_imprecise_acc = 0;
if (ElementType(lhs).isFloat(8) || ElementType(rhs).isFloat(8)) {
// For fp8 dots, disable accumulator promotion to mimic cuBLAS. It may make
// sense to enable frequent accumulator promotion at higher matmul
// precisions set in the config.
max_num_imprecise_acc = std::numeric_limits<int>::max();
}

// Cast F32 inputs to BF16 if the algorithm is BF16_BF16_F32.
// TODO(bchetioui): abstract this.
if (precision_spec.algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32) {
Expand All @@ -374,10 +375,9 @@ absl::StatusOr<Value> EmitRegularDot(EmitterLocOpBuilder b,
}
}

return b.create<ttir::DotOp>(
dot_operands.lhs, dot_operands.rhs, dot_operands.accumulator,
/*inputPrecision=*/precision_spec.ttir_input_precision,
/*maxNumImpreciseAcc=*/max_num_imprecise_acc);
return EmitStableHloDotAndAdd(
b, dot_operands.lhs, dot_operands.rhs, dot_operands.accumulator,
/*input_precision=*/precision_spec.stablehlo_input_precision);
}

// Returns an emitter for the given dot algorithm. Raises an
Expand Down
130 changes: 113 additions & 17 deletions xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3204,10 +3204,20 @@ ENTRY entry {
"num_ctas":"1",
"num_stages":"1"}}}
})";
// We expect that for loop instruction will be optimized away.
TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "fdot", R"(

TF_ASSERT_OK_AND_ASSIGN(auto xtile_module_and_hlo_module,
CreateXTileIrAndFileCheck(this, kHloText, "fdot",
R"(
CHECK: stablehlo.dot
CHECK: stablehlo.add
)"));

TF_ASSERT_OK(LowerXTileIrToTritonAndFileCheck(
this, xtile_module_and_hlo_module.first.get(), R"(
CHECK: tt.dot {{.*}} -> tensor<16x16xf32>
)"));
)",
GetFusionInstruction(*xtile_module_and_hlo_module.second, "fdot")));

EXPECT_TRUE(RunAndCompareNoHloPasses(
kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6}));
}
Expand Down Expand Up @@ -3267,8 +3277,32 @@ ENTRY entry {

const bool is_tma_allowed = GetParam();
std::string hlo_text = absl::Substitute(kHloTextTemplate, is_tma_allowed);
TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, hlo_text, "fdot", R"(
CHECK: xtile.entry_func @triton_fn(%[[ARG0:[A-Za-z0-9_]*]]: memref<32x123xf32>

TF_ASSERT_OK_AND_ASSIGN(auto xtile_module_and_hlo_module,
CreateXTileIrAndFileCheck(this, hlo_text, "fdot",
R"(
CHECK: xtile.entry_func @xtile_dialect_fn(%[[ARG0:[A-Za-z0-9_]*]]: memref<32x123xf32>
CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: memref<123x512xf32>
CHECK-SAME: %[[ARG2:[A-Za-z0-9_]*]]: memref<32x512xf32>
CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
CHECK: {{.*}} = scf.for %{{.*}} = %[[C0]] to %[[C4]] step %[[C1]]
CHECK-SAME: iter_args({{.*}}) -> (tensor<16x64xf32>) {
CHECK-DAG: xtile.extract %[[ARG0]]
CHECK-DAG: xtile.extract %[[ARG1]]
CHECK-DAG: arith.subf {{.*}} : tensor<16x32xf32>
CHECK-DAG: math.absf {{.*}} : tensor<32x64xf32>
CHECK: stablehlo.dot {{.*}} (tensor<16x32xf32>, tensor<32x64xf32>) -> tensor<16x64xf32>
CHECK: stablehlo.add {{.*}}
CHECK: scf.yield {{.*}} : tensor<16x64xf32>
CHECK-COUNT-1: xtile.insert

)"));

TF_ASSERT_OK(LowerXTileIrToTritonAndFileCheck(
this, xtile_module_and_hlo_module.first.get(), R"(
CHECK: xtile.entry_func @xtile_dialect_fn(%[[ARG0:[A-Za-z0-9_]*]]: memref<32x123xf32>
CHECK-SAME: %[[ARG1:[A-Za-z0-9_]*]]: memref<123x512xf32>
CHECK-SAME: %[[ARG2:[A-Za-z0-9_]*]]: memref<32x512xf32>
CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
Expand All @@ -3283,7 +3317,9 @@ CHECK-DAG: math.absf {{.*}} : tensor<32x64xf32>
CHECK: tt.dot {{.*}} tensor<16x32xf32> * tensor<32x64xf32> -> tensor<16x64xf32>
CHECK: scf.yield {{.*}} : tensor<16x64xf32>
CHECK-COUNT-1: xtile.insert
)"));
)",
GetFusionInstruction(*xtile_module_and_hlo_module.second, "fdot")));

EXPECT_TRUE(RunAndCompareNoHloPasses(
hlo_text, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6}));
}
Expand Down Expand Up @@ -3405,7 +3441,21 @@ ENTRY entry {
"num_stages":"1"}}}
})";

TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "fdot", R"(
TF_ASSERT_OK_AND_ASSIGN(auto xtile_module_and_hlo_module,
CreateXTileIrAndFileCheck(this, kHloText, "fdot", R"(
// Ensure that masking is applied only conditionally to both operands.
CHECK: %[[MASKED_OPERAND0:.*]] = scf.if
CHECK: %[[SELECT0:.*]] = arith.select
CHECK-NEXT: scf.yield %[[SELECT0]]
CHECK: %[[MASKED_OPERAND1:.*]] = scf.if
CHECK: %[[SELECT1:.*]] = arith.select
CHECK-NEXT: scf.yield %[[SELECT1]]
CHECK: stablehlo.dot %[[MASKED_OPERAND0]], %[[MASKED_OPERAND1]]
CHECK: stablehlo.add %{{.*}}
)"));

TF_ASSERT_OK(LowerXTileIrToTritonAndFileCheck(
this, xtile_module_and_hlo_module.first.get(), R"(
// Ensure that masking is applied only conditionally to both operands.
CHECK: %[[MASKED_OPERAND0:.*]] = scf.if
CHECK: %[[SELECT0:.*]] = arith.select
Expand All @@ -3414,7 +3464,8 @@ ENTRY entry {
CHECK: %[[SELECT1:.*]] = arith.select
CHECK-NEXT: scf.yield %[[SELECT1]]
CHECK: tt.dot %[[MASKED_OPERAND0]], %[[MASKED_OPERAND1]]
)"));
)",
GetFusionInstruction(*xtile_module_and_hlo_module.second, "fdot")));

EXPECT_TRUE(RunAndCompareNoHloPasses(
kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6}));
Expand Down Expand Up @@ -3952,13 +4003,26 @@ TEST_P(BasicDotAlgorithmEmitterTest, BasicAlgorithmIsEmittedCorrectly) {
algorithm_util::GetDotAccumulatorType(algorithm));
const std::string kHloText = GetDotAlgorithmHlo(in_ty, out_ty, algorithm);

TF_EXPECT_OK(CreateTritonIrAndFileCheck(
this, kHloText, "dot",
TF_ASSERT_OK_AND_ASSIGN(
auto xtile_module_and_hlo_module,
CreateXTileIrAndFileCheck(
this, kHloText, "dot",
absl::Substitute(
R"(
CHECK: stablehlo.dot{{.*}} : (tensor<16x32x$0>, tensor<32x64x$0>) -> tensor<16x64x$1>
CHECK: stablehlo.add
)",
primitive_util::LowercasePrimitiveTypeName(in_ty),
primitive_util::LowercasePrimitiveTypeName(out_ty))));

TF_ASSERT_OK(LowerXTileIrToTritonAndFileCheck(
this, xtile_module_and_hlo_module.first.get(),
absl::Substitute(R"(
CHECK: tt.dot{{.*}} : tensor<16x32x$0> * tensor<32x64x$0> -> tensor<16x64x$1>
)",
primitive_util::LowercasePrimitiveTypeName(in_ty),
primitive_util::LowercasePrimitiveTypeName(out_ty))));
primitive_util::LowercasePrimitiveTypeName(out_ty)),
GetFusionInstruction(*xtile_module_and_hlo_module.second, "dot")));

EXPECT_TRUE(
RunAndCompareNoHloPasses(kHloText, ErrorSpecForDotAlgorithm(algorithm)));
Expand Down Expand Up @@ -3993,6 +4057,7 @@ TEST_P(MultiDotAlgorithmEmitterTest, MultiDotAlgorithmIsEmittedCorrectly) {
// Dummy value to ensure that the dot count is explicitly set.
int dot_count_for_algorithm = 0x1337;
std::string input_precision_string = "";
std::string stablehlo_input_precision_string = "";
switch (algorithm) {
case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3:
dot_count_for_algorithm = 3;
Expand All @@ -4006,6 +4071,7 @@ TEST_P(MultiDotAlgorithmEmitterTest, MultiDotAlgorithmIsEmittedCorrectly) {
case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3:
// Triton implements TF32x3 as a specific precision mode.
input_precision_string = "tf32x3";
stablehlo_input_precision_string = "HIGH";
dot_count_for_algorithm = 1;
break;
default:
Expand All @@ -4015,14 +4081,27 @@ TEST_P(MultiDotAlgorithmEmitterTest, MultiDotAlgorithmIsEmittedCorrectly) {

const std::string kHloText = GetDotAlgorithmHlo(in_ty, out_ty, algorithm);

TF_EXPECT_OK(CreateTritonIrAndFileCheck(
this, kHloText, "dot",
TF_ASSERT_OK_AND_ASSIGN(
auto xtile_module_and_hlo_module,
CreateXTileIrAndFileCheck(
this, kHloText, "dot",
absl::Substitute(
R"(
CHECK-COUNT-$2: stablehlo.dot{{.*}}$3{{.*}} : (tensor<16x32x$0>, tensor<32x64x$0>) -> tensor<16x64x$1>
)",
primitive_util::LowercasePrimitiveTypeName(in_ty),
primitive_util::LowercasePrimitiveTypeName(out_ty),
dot_count_for_algorithm, stablehlo_input_precision_string)));

TF_ASSERT_OK(LowerXTileIrToTritonAndFileCheck(
this, xtile_module_and_hlo_module.first.get(),
absl::Substitute(R"(
CHECK-COUNT-$2: tt.dot{{.*}}$3{{.*}} : tensor<16x32x$0> * tensor<32x64x$0> -> tensor<16x64x$1>
)",
primitive_util::LowercasePrimitiveTypeName(in_ty),
primitive_util::LowercasePrimitiveTypeName(out_ty),
dot_count_for_algorithm, input_precision_string)));
dot_count_for_algorithm, input_precision_string),
GetFusionInstruction(*xtile_module_and_hlo_module.second, "dot")));

EXPECT_TRUE(
RunAndCompareNoHloPasses(kHloText, ErrorSpecForDotAlgorithm(algorithm)));
Expand Down Expand Up @@ -4051,18 +4130,35 @@ TEST_P(TF32DotAlgorithmEmitterTest, TF32AlgorithmsUseTF32InputPrecision) {
algorithm_util::GetDotAccumulatorType(algorithm));
const std::string kHloText = GetDotAlgorithmHlo(in_ty, out_ty, algorithm);

std::string stablehlo_input_precision_string =
algorithm == PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3 ? "HIGH"
: "DEFAULT";

std::string input_precision_string =
algorithm == PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3 ? "tf32x3"
: "tf32";

TF_EXPECT_OK(CreateTritonIrAndFileCheck(
this, kHloText, "dot",
TF_ASSERT_OK_AND_ASSIGN(
auto xtile_module_and_hlo_module,
CreateXTileIrAndFileCheck(
this, kHloText, "dot",
absl::Substitute(R"(
CHECK: stablehlo.dot{{.*}}precision = [$2] : (tensor<16x32x$0>, tensor<32x64x$0>) -> tensor<16x64x$1>
)",
primitive_util::LowercasePrimitiveTypeName(in_ty),
primitive_util::LowercasePrimitiveTypeName(out_ty),
stablehlo_input_precision_string)));

TF_ASSERT_OK(LowerXTileIrToTritonAndFileCheck(
this, xtile_module_and_hlo_module.first.get(),
absl::Substitute(R"(
CHECK: tt.dot{{.*}} inputPrecision = $2 : tensor<16x32x$0> * tensor<32x64x$0> -> tensor<16x64x$1>
)",
primitive_util::LowercasePrimitiveTypeName(in_ty),
primitive_util::LowercasePrimitiveTypeName(out_ty),
input_precision_string)));
input_precision_string),
GetFusionInstruction(*xtile_module_and_hlo_module.second, "dot")));

// No need to `RunAndCompare` here, these algorithms are already covered by
// other tests.
}
Expand Down
Loading
Loading