Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xla/backends/cpu/codegen/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,7 @@ cc_library(
"//xla/backends/cpu:alignment",
"//xla/codegen:hlo_fusion_spec",
"//xla/codegen:ir_emission_utils",
"//xla/codegen:kernel_definition",
"//xla/codegen:kernel_spec",
"//xla/codegen:mlir_kernel_source",
"//xla/codegen/emitters:concatenate_kernel_emitter",
Expand Down
7 changes: 4 additions & 3 deletions xla/backends/cpu/codegen/emitters/cpu_scatter_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,8 @@ IndexingMap GetScatterIndexingMap(
{}, constraints);
}

absl::StatusOr<MlirKernelDefinition> CpuScatterFusion::EmitKernelDefinition() {
absl::StatusOr<CpuScatterFusion::KernelDefinition>
CpuScatterFusion::EmitKernelDefinition() {
mlir::OpBuilder builder(symbolic_expr_context_->GetMLIRContext());
TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> mlir_module,
CreateNamedMlirModuleOp(*fusion_, builder));
Expand Down Expand Up @@ -325,8 +326,8 @@ absl::StatusOr<MlirKernelDefinition> CpuScatterFusion::EmitKernelDefinition() {
std::move(argument_buffers), std::move(result_buffers),
std::move(invariant_arguments));

return MlirKernelDefinition(std::move(kernel_spec),
MlirKernelSource(std::move(mlir_module)));
return KernelDefinition(std::move(kernel_spec),
MlirKernelSource(std::move(mlir_module)));
}

absl::Status CpuScatterFusion::EmitEntryFunction(
Expand Down
23 changes: 13 additions & 10 deletions xla/backends/cpu/codegen/fusion_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ limitations under the License.
#include "xla/codegen/emitters/loop_kernel_emitter.h"
#include "xla/codegen/hlo_fusion_spec.h"
#include "xla/codegen/ir_emission_utils.h"
#include "xla/codegen/kernel_spec.h"
#include "xla/codegen/kernel_definition.h"
#include "xla/codegen/mlir_kernel_source.h"
#include "xla/hlo/analysis/symbolic_expr.h"
#include "xla/hlo/ir/hlo_instruction.h"
Expand All @@ -57,7 +57,6 @@ limitations under the License.
#include "xla/runtime/work_tile_size.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/cpu/backend_config.pb.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/platform/statusor.h"
Expand Down Expand Up @@ -208,7 +207,7 @@ static HloFusionSpec GetLoopFusionSpec(const HloFusionInstruction& fusion) {
std::move(heroes));
}

static absl::StatusOr<MlirKernelDefinition> EmitLoopFusionKernel(
static absl::StatusOr<KernelDefinition<MlirKernelSource>> EmitLoopFusionKernel(
SymbolicExprContext& context, const HloFusionInstruction& fusion,
const BufferAssignment* buffer_assignment, absl::string_view name) {
VLOG(2) << "Emitting loop fusion kernel: " << name;
Expand All @@ -230,9 +229,11 @@ static absl::StatusOr<MlirKernelDefinition> EmitLoopFusionKernel(
return mlir_kernel_definition;
}

static absl::StatusOr<MlirKernelDefinition> EmitConcatenateFusionKernel(
SymbolicExprContext& context, const HloFusionInstruction& fusion,
const BufferAssignment* buffer_assignment, absl::string_view name) {
static absl::StatusOr<KernelDefinition<MlirKernelSource>>
EmitConcatenateFusionKernel(SymbolicExprContext& context,
const HloFusionInstruction& fusion,
const BufferAssignment* buffer_assignment,
absl::string_view name) {
VLOG(2) << "Emitting concatenate fusion kernel: " << name;
HloFusionSpec fusion_spec = GetLoopFusionSpec(fusion);
auto work_dimensions = GetConcatenateEmitterWorkDims(fusion, fusion_spec);
Expand All @@ -252,9 +253,11 @@ static absl::StatusOr<MlirKernelDefinition> EmitConcatenateFusionKernel(
return mlir_kernel_definition;
}

static absl::StatusOr<MlirKernelDefinition> EmitDynamicUpdateSliceFusionKernel(
SymbolicExprContext& context, const HloFusionInstruction& fusion,
const BufferAssignment* buffer_assignment, absl::string_view name) {
static absl::StatusOr<KernelDefinition<MlirKernelSource>>
EmitDynamicUpdateSliceFusionKernel(SymbolicExprContext& context,
const HloFusionInstruction& fusion,
const BufferAssignment* buffer_assignment,
absl::string_view name) {
VLOG(2) << "Emitting dynamic update slice fusion kernel: " << name;
HloFusionSpec fusion_spec = GetLoopFusionSpec(fusion);
auto work_dimensions =
Expand All @@ -275,7 +278,7 @@ static absl::StatusOr<MlirKernelDefinition> EmitDynamicUpdateSliceFusionKernel(
return mlir_kernel_definition;
}

absl::StatusOr<MlirKernelDefinition> EmitFusionKernel(
absl::StatusOr<KernelDefinition<MlirKernelSource>> EmitFusionKernel(
SymbolicExprContext& context, const HloFusionInstruction& fusion,
const BufferAssignment* buffer_assignment, bool use_unique_c_name) {
if (fusion.fusion_kind() == HloFusionInstruction::FusionKind::kLoop) {
Expand Down
3 changes: 2 additions & 1 deletion xla/backends/cpu/codegen/fusion_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.

#include "absl/status/statusor.h"
#include "xla/codegen/emitters/kernel_arguments.h"
#include "xla/codegen/kernel_definition.h"
#include "xla/codegen/mlir_kernel_source.h"
#include "xla/hlo/analysis/symbolic_expr.h"
#include "xla/hlo/ir/hlo_instructions.h"
Expand All @@ -27,7 +28,7 @@ namespace xla::cpu {

emitters::KernelArguments::BufferAlignment GetDefaultBufferAlignment();

absl::StatusOr<MlirKernelDefinition> EmitFusionKernel(
absl::StatusOr<KernelDefinition<MlirKernelSource>> EmitFusionKernel(
SymbolicExprContext& context, const HloFusionInstruction& fusion,
const BufferAssignment* buffer_assignment, bool use_unique_c_name);

Expand Down
2 changes: 1 addition & 1 deletion xla/backends/cpu/codegen/tools/fusion_to_mlir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ absl::Status Run(const std::string& filename) {
module->entry_computation()->root_instruction());
fusion->SetAndSanitizeName("main");
TF_ASSIGN_OR_RETURN(
MlirKernelDefinition kernel_definition,
KernelDefinition kernel_definition,
EmitFusionKernel(*symbolic_expr_context, *fusion, nullptr, false));
llvm::outs() << kernel_definition.source().ToString();
return absl::OkStatus();
Expand Down
9 changes: 4 additions & 5 deletions xla/backends/cpu/testlib/kernel_runner_extension.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,8 @@ NB_MODULE(_extension, kernel_runner_module) {
[](SymbolicExprContext& symbolic_expr_context,
const HloFusionInstruction& fusion,
const BufferAssignment* buffer_assignment) {
absl::StatusOr<MlirKernelDefinition> kernel_definition =
EmitFusionKernel(symbolic_expr_context, fusion, buffer_assignment,
false);
auto kernel_definition = EmitFusionKernel(symbolic_expr_context, fusion,
buffer_assignment, false);
if (!kernel_definition.ok()) {
throw std::runtime_error(kernel_definition.status().ToString());
}
Expand Down Expand Up @@ -242,8 +241,8 @@ NB_MODULE(_extension, kernel_runner_module) {
"KernelRunner")
.def_static(
"create",
[](std::unique_ptr<MlirKernelDefinition,
nb::deleter<MlirKernelDefinition>>
[](std::unique_ptr<KernelDefinition<MlirKernelSource>,
nb::deleter<KernelDefinition<MlirKernelSource>>>
kernel_definition,
std::unique_ptr<JitCompiler, nb::deleter<JitCompiler>>
jit_compiler) {
Expand Down
4 changes: 2 additions & 2 deletions xla/backends/cpu/testlib/mlir_kernel_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ MlirTestKernelEmitter::MlirTestKernelEmitter(absl::string_view mlir,
}
}

absl::StatusOr<MlirKernelDefinition>
absl::StatusOr<MlirTestKernelEmitter::KernelDefinition>
MlirTestKernelEmitter::EmitKernelDefinition() {
std::unique_ptr<mlir::MLIRContext> context = FusionCompiler::CreateContext();

Expand All @@ -71,6 +71,6 @@ MlirTestKernelEmitter::EmitKernelDefinition() {
KernelSpec kernel_spec(kernel_name_, num_workgroups_,
std::move(argument_buffers), std::move(result_buffers),
/*invariant_arguments=*/{});
return MlirKernelDefinition(std::move(kernel_spec), std::move(source));
return KernelDefinition(std::move(kernel_spec), std::move(source));
}
} // namespace xla::cpu
6 changes: 3 additions & 3 deletions xla/codegen/emitters/concatenate_kernel_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ ConcatenateFusionKernelEmitter::ConcatenateFusionKernelEmitter(
entry_function_name_(entry_function_name),
backend_kind_(backend_kind) {}

absl::StatusOr<MlirKernelDefinition>
absl::StatusOr<ConcatenateFusionKernelEmitter::KernelDefinition>
ConcatenateFusionKernelEmitter::EmitKernelDefinition() {
mlir::OpBuilder builder(symbolic_expr_context_.GetMLIRContext());
auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion_.name()));
Expand Down Expand Up @@ -121,8 +121,8 @@ ConcatenateFusionKernelEmitter::EmitKernelDefinition() {
GetKernelSpec(entry_function_name_, fusion_,
buffer_assignment_, work_dimensions_));

return MlirKernelDefinition(std::move(kernel_spec),
MlirKernelSource(std::move(module)));
return KernelDefinition(std::move(kernel_spec),
MlirKernelSource(std::move(module)));
}

const Shape& ConcatenateFusionKernelEmitter::GetIndexingShape(
Expand Down
6 changes: 3 additions & 3 deletions xla/codegen/emitters/dynamic_update_slice_kernel_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ DynamicUpdateSliceKernelEmitter::DynamicUpdateSliceKernelEmitter(
entry_function_name_(entry_function_name),
backend_kind_(backend_kind) {}

absl::StatusOr<MlirKernelDefinition>
absl::StatusOr<DynamicUpdateSliceKernelEmitter::KernelDefinition>
DynamicUpdateSliceKernelEmitter::EmitKernelDefinition() {
mlir::OpBuilder builder(symbolic_expr_context_.GetMLIRContext());
auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion_.name()));
Expand Down Expand Up @@ -120,8 +120,8 @@ DynamicUpdateSliceKernelEmitter::EmitKernelDefinition() {

TF_ASSIGN_OR_RETURN(auto kernel_spec, GetKernelSpec());

return MlirKernelDefinition(std::move(kernel_spec),
MlirKernelSource(std::move(module)));
return KernelDefinition(std::move(kernel_spec),
MlirKernelSource(std::move(module)));
}

IndexingMap DynamicUpdateSliceKernelEmitter::ComputeWorkItemIdToInputIndexing(
Expand Down
6 changes: 3 additions & 3 deletions xla/codegen/emitters/loop_kernel_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ LoopFusionKernelEmitter::LoopFusionKernelEmitter(
entry_function_name_(entry_function_name),
backend_kind_(backend_kind) {}

absl::StatusOr<MlirKernelDefinition>
absl::StatusOr<LoopFusionKernelEmitter::KernelDefinition>
LoopFusionKernelEmitter::EmitKernelDefinition() {
mlir::OpBuilder builder(symbolic_expr_context_.GetMLIRContext());
auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion_.name()));
Expand Down Expand Up @@ -114,8 +114,8 @@ LoopFusionKernelEmitter::EmitKernelDefinition() {
GetKernelSpec(entry_function_name_, fusion_,
buffer_assignment_, work_dimensions_));

return MlirKernelDefinition(std::move(kernel_spec),
MlirKernelSource(std::move(module)));
return KernelDefinition(std::move(kernel_spec),
MlirKernelSource(std::move(module)));
}

IndexingMap LoopFusionKernelEmitter::ComputeWorkItemIdToOutputIndexing(
Expand Down
2 changes: 0 additions & 2 deletions xla/codegen/mlir_kernel_source.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,6 @@ class MlirKernelSource final : public KernelSource {
Storage storage_;
};

using MlirKernelDefinition = KernelDefinition<MlirKernelSource>; // NOLINT

} // namespace xla

#endif // XLA_CODEGEN_MLIR_KERNEL_SOURCE_H_
6 changes: 3 additions & 3 deletions xla/service/cpu/parallel_fusion_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ absl::StatusOr<KernelSpec> ParallelFusionEmitter::AddFusion(
// fixed but will require a rework of the ThunkEmitter.
auto compiler_instance = fusion_compiler_pool_->GetInstance();
TF_ASSIGN_OR_RETURN(
MlirKernelDefinition mlir_kernel_definition,
KernelDefinition mlir_kernel_definition,
EmitFusionKernel(*compiler_instance->symbolic_expr_context, *fusion,
buffer_assignment_, use_unique_c_name_));

Expand All @@ -181,8 +181,8 @@ absl::StatusOr<KernelSpec> ParallelFusionEmitter::AddFusion(
}

KernelSpec spec = mlir_kernel_definition.spec();
auto shared_source =
std::make_shared<MlirKernelDefinition>(std::move(mlir_kernel_definition));
auto shared_source = std::make_shared<KernelDefinition<MlirKernelSource>>(
std::move(mlir_kernel_definition));

thread_pool_.Schedule(absl::bind_front(&ParallelFusionEmitter::CompileFusion,
this, std::move(shared_source),
Expand Down
2 changes: 1 addition & 1 deletion xla/service/cpu/thunk_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,7 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitFusionKernelThunk(
auto kernel_emitter = std::make_unique<CpuScatterFusion>(
buffer_assignment_, fusion, &symbolic_expr_context_);

TF_ASSIGN_OR_RETURN(MlirKernelDefinition kernel_definition,
TF_ASSIGN_OR_RETURN(KernelDefinition kernel_definition,
kernel_emitter->EmitKernelDefinition());

auto kernel_spec = kernel_definition.spec();
Expand Down
Loading