Skip to content

Commit b489c9a

Browse files
ezhulenevGoogle-ML-Automation
authored andcommitted
[xla:codegen] Remove MlirKernelDefinition alias
PiperOrigin-RevId: 825318694
1 parent 288a88f commit b489c9a

13 files changed

+40
-37
lines changed

xla/backends/cpu/codegen/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,7 @@ cc_library(
632632
"//xla/backends/cpu:alignment",
633633
"//xla/codegen:hlo_fusion_spec",
634634
"//xla/codegen:ir_emission_utils",
635+
"//xla/codegen:kernel_definition",
635636
"//xla/codegen:kernel_spec",
636637
"//xla/codegen:mlir_kernel_source",
637638
"//xla/codegen/emitters:concatenate_kernel_emitter",

xla/backends/cpu/codegen/emitters/cpu_scatter_emitter.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,8 @@ IndexingMap GetScatterIndexingMap(
249249
{}, constraints);
250250
}
251251

252-
absl::StatusOr<MlirKernelDefinition> CpuScatterFusion::EmitKernelDefinition() {
252+
absl::StatusOr<CpuScatterFusion::KernelDefinition>
253+
CpuScatterFusion::EmitKernelDefinition() {
253254
mlir::OpBuilder builder(symbolic_expr_context_->GetMLIRContext());
254255
TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> mlir_module,
255256
CreateNamedMlirModuleOp(*fusion_, builder));
@@ -325,8 +326,8 @@ absl::StatusOr<MlirKernelDefinition> CpuScatterFusion::EmitKernelDefinition() {
325326
std::move(argument_buffers), std::move(result_buffers),
326327
std::move(invariant_arguments));
327328

328-
return MlirKernelDefinition(std::move(kernel_spec),
329-
MlirKernelSource(std::move(mlir_module)));
329+
return KernelDefinition(std::move(kernel_spec),
330+
MlirKernelSource(std::move(mlir_module)));
330331
}
331332

332333
absl::Status CpuScatterFusion::EmitEntryFunction(

xla/backends/cpu/codegen/fusion_emitter.cc

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ limitations under the License.
4242
#include "xla/codegen/emitters/loop_kernel_emitter.h"
4343
#include "xla/codegen/hlo_fusion_spec.h"
4444
#include "xla/codegen/ir_emission_utils.h"
45-
#include "xla/codegen/kernel_spec.h"
45+
#include "xla/codegen/kernel_definition.h"
4646
#include "xla/codegen/mlir_kernel_source.h"
4747
#include "xla/hlo/analysis/symbolic_expr.h"
4848
#include "xla/hlo/ir/hlo_instruction.h"
@@ -57,7 +57,6 @@ limitations under the License.
5757
#include "xla/runtime/work_tile_size.h"
5858
#include "xla/service/buffer_assignment.h"
5959
#include "xla/service/cpu/backend_config.pb.h"
60-
#include "xla/service/gpu/ir_emission_utils.h"
6160
#include "xla/shape.h"
6261
#include "xla/shape_util.h"
6362
#include "xla/tsl/platform/statusor.h"
@@ -208,7 +207,7 @@ static HloFusionSpec GetLoopFusionSpec(const HloFusionInstruction& fusion) {
208207
std::move(heroes));
209208
}
210209

211-
static absl::StatusOr<MlirKernelDefinition> EmitLoopFusionKernel(
210+
static absl::StatusOr<KernelDefinition<MlirKernelSource>> EmitLoopFusionKernel(
212211
SymbolicExprContext& context, const HloFusionInstruction& fusion,
213212
const BufferAssignment* buffer_assignment, absl::string_view name) {
214213
VLOG(2) << "Emitting loop fusion kernel: " << name;
@@ -230,9 +229,11 @@ static absl::StatusOr<MlirKernelDefinition> EmitLoopFusionKernel(
230229
return mlir_kernel_definition;
231230
}
232231

233-
static absl::StatusOr<MlirKernelDefinition> EmitConcatenateFusionKernel(
234-
SymbolicExprContext& context, const HloFusionInstruction& fusion,
235-
const BufferAssignment* buffer_assignment, absl::string_view name) {
232+
static absl::StatusOr<KernelDefinition<MlirKernelSource>>
233+
EmitConcatenateFusionKernel(SymbolicExprContext& context,
234+
const HloFusionInstruction& fusion,
235+
const BufferAssignment* buffer_assignment,
236+
absl::string_view name) {
236237
VLOG(2) << "Emitting concatenate fusion kernel: " << name;
237238
HloFusionSpec fusion_spec = GetLoopFusionSpec(fusion);
238239
auto work_dimensions = GetConcatenateEmitterWorkDims(fusion, fusion_spec);
@@ -252,9 +253,11 @@ static absl::StatusOr<MlirKernelDefinition> EmitConcatenateFusionKernel(
252253
return mlir_kernel_definition;
253254
}
254255

255-
static absl::StatusOr<MlirKernelDefinition> EmitDynamicUpdateSliceFusionKernel(
256-
SymbolicExprContext& context, const HloFusionInstruction& fusion,
257-
const BufferAssignment* buffer_assignment, absl::string_view name) {
256+
static absl::StatusOr<KernelDefinition<MlirKernelSource>>
257+
EmitDynamicUpdateSliceFusionKernel(SymbolicExprContext& context,
258+
const HloFusionInstruction& fusion,
259+
const BufferAssignment* buffer_assignment,
260+
absl::string_view name) {
258261
VLOG(2) << "Emitting dynamic update slice fusion kernel: " << name;
259262
HloFusionSpec fusion_spec = GetLoopFusionSpec(fusion);
260263
auto work_dimensions =
@@ -275,7 +278,7 @@ static absl::StatusOr<MlirKernelDefinition> EmitDynamicUpdateSliceFusionKernel(
275278
return mlir_kernel_definition;
276279
}
277280

278-
absl::StatusOr<MlirKernelDefinition> EmitFusionKernel(
281+
absl::StatusOr<KernelDefinition<MlirKernelSource>> EmitFusionKernel(
279282
SymbolicExprContext& context, const HloFusionInstruction& fusion,
280283
const BufferAssignment* buffer_assignment, bool use_unique_c_name) {
281284
if (fusion.fusion_kind() == HloFusionInstruction::FusionKind::kLoop) {

xla/backends/cpu/codegen/fusion_emitter.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License.
1818

1919
#include "absl/status/statusor.h"
2020
#include "xla/codegen/emitters/kernel_arguments.h"
21+
#include "xla/codegen/kernel_definition.h"
2122
#include "xla/codegen/mlir_kernel_source.h"
2223
#include "xla/hlo/analysis/symbolic_expr.h"
2324
#include "xla/hlo/ir/hlo_instructions.h"
@@ -27,7 +28,7 @@ namespace xla::cpu {
2728

2829
emitters::KernelArguments::BufferAlignment GetDefaultBufferAlignment();
2930

30-
absl::StatusOr<MlirKernelDefinition> EmitFusionKernel(
31+
absl::StatusOr<KernelDefinition<MlirKernelSource>> EmitFusionKernel(
3132
SymbolicExprContext& context, const HloFusionInstruction& fusion,
3233
const BufferAssignment* buffer_assignment, bool use_unique_c_name);
3334

xla/backends/cpu/codegen/tools/fusion_to_mlir.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ absl::Status Run(const std::string& filename) {
3939
module->entry_computation()->root_instruction());
4040
fusion->SetAndSanitizeName("main");
4141
TF_ASSIGN_OR_RETURN(
42-
MlirKernelDefinition kernel_definition,
42+
KernelDefinition kernel_definition,
4343
EmitFusionKernel(*symbolic_expr_context, *fusion, nullptr, false));
4444
llvm::outs() << kernel_definition.source().ToString();
4545
return absl::OkStatus();

xla/backends/cpu/testlib/kernel_runner_extension.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -206,9 +206,8 @@ NB_MODULE(_extension, kernel_runner_module) {
206206
[](SymbolicExprContext& symbolic_expr_context,
207207
const HloFusionInstruction& fusion,
208208
const BufferAssignment* buffer_assignment) {
209-
absl::StatusOr<MlirKernelDefinition> kernel_definition =
210-
EmitFusionKernel(symbolic_expr_context, fusion, buffer_assignment,
211-
false);
209+
auto kernel_definition = EmitFusionKernel(symbolic_expr_context, fusion,
210+
buffer_assignment, false);
212211
if (!kernel_definition.ok()) {
213212
throw std::runtime_error(kernel_definition.status().ToString());
214213
}
@@ -242,8 +241,8 @@ NB_MODULE(_extension, kernel_runner_module) {
242241
"KernelRunner")
243242
.def_static(
244243
"create",
245-
[](std::unique_ptr<MlirKernelDefinition,
246-
nb::deleter<MlirKernelDefinition>>
244+
[](std::unique_ptr<KernelDefinition<MlirKernelSource>,
245+
nb::deleter<KernelDefinition<MlirKernelSource>>>
247246
kernel_definition,
248247
std::unique_ptr<JitCompiler, nb::deleter<JitCompiler>>
249248
jit_compiler) {

xla/backends/cpu/testlib/mlir_kernel_emitter.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ MlirTestKernelEmitter::MlirTestKernelEmitter(absl::string_view mlir,
4747
}
4848
}
4949

50-
absl::StatusOr<MlirKernelDefinition>
50+
absl::StatusOr<MlirTestKernelEmitter::KernelDefinition>
5151
MlirTestKernelEmitter::EmitKernelDefinition() {
5252
std::unique_ptr<mlir::MLIRContext> context = FusionCompiler::CreateContext();
5353

@@ -71,6 +71,6 @@ MlirTestKernelEmitter::EmitKernelDefinition() {
7171
KernelSpec kernel_spec(kernel_name_, num_workgroups_,
7272
std::move(argument_buffers), std::move(result_buffers),
7373
/*invariant_arguments=*/{});
74-
return MlirKernelDefinition(std::move(kernel_spec), std::move(source));
74+
return KernelDefinition(std::move(kernel_spec), std::move(source));
7575
}
7676
} // namespace xla::cpu

xla/codegen/emitters/concatenate_kernel_emitter.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ ConcatenateFusionKernelEmitter::ConcatenateFusionKernelEmitter(
8989
entry_function_name_(entry_function_name),
9090
backend_kind_(backend_kind) {}
9191

92-
absl::StatusOr<MlirKernelDefinition>
92+
absl::StatusOr<ConcatenateFusionKernelEmitter::KernelDefinition>
9393
ConcatenateFusionKernelEmitter::EmitKernelDefinition() {
9494
mlir::OpBuilder builder(symbolic_expr_context_.GetMLIRContext());
9595
auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion_.name()));
@@ -121,8 +121,8 @@ ConcatenateFusionKernelEmitter::EmitKernelDefinition() {
121121
GetKernelSpec(entry_function_name_, fusion_,
122122
buffer_assignment_, work_dimensions_));
123123

124-
return MlirKernelDefinition(std::move(kernel_spec),
125-
MlirKernelSource(std::move(module)));
124+
return KernelDefinition(std::move(kernel_spec),
125+
MlirKernelSource(std::move(module)));
126126
}
127127

128128
const Shape& ConcatenateFusionKernelEmitter::GetIndexingShape(

xla/codegen/emitters/dynamic_update_slice_kernel_emitter.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ DynamicUpdateSliceKernelEmitter::DynamicUpdateSliceKernelEmitter(
9292
entry_function_name_(entry_function_name),
9393
backend_kind_(backend_kind) {}
9494

95-
absl::StatusOr<MlirKernelDefinition>
95+
absl::StatusOr<DynamicUpdateSliceKernelEmitter::KernelDefinition>
9696
DynamicUpdateSliceKernelEmitter::EmitKernelDefinition() {
9797
mlir::OpBuilder builder(symbolic_expr_context_.GetMLIRContext());
9898
auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion_.name()));
@@ -120,8 +120,8 @@ DynamicUpdateSliceKernelEmitter::EmitKernelDefinition() {
120120

121121
TF_ASSIGN_OR_RETURN(auto kernel_spec, GetKernelSpec());
122122

123-
return MlirKernelDefinition(std::move(kernel_spec),
124-
MlirKernelSource(std::move(module)));
123+
return KernelDefinition(std::move(kernel_spec),
124+
MlirKernelSource(std::move(module)));
125125
}
126126

127127
IndexingMap DynamicUpdateSliceKernelEmitter::ComputeWorkItemIdToInputIndexing(

xla/codegen/emitters/loop_kernel_emitter.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ LoopFusionKernelEmitter::LoopFusionKernelEmitter(
8484
entry_function_name_(entry_function_name),
8585
backend_kind_(backend_kind) {}
8686

87-
absl::StatusOr<MlirKernelDefinition>
87+
absl::StatusOr<LoopFusionKernelEmitter::KernelDefinition>
8888
LoopFusionKernelEmitter::EmitKernelDefinition() {
8989
mlir::OpBuilder builder(symbolic_expr_context_.GetMLIRContext());
9090
auto loc = mlir::NameLoc::get(builder.getStringAttr(fusion_.name()));
@@ -114,8 +114,8 @@ LoopFusionKernelEmitter::EmitKernelDefinition() {
114114
GetKernelSpec(entry_function_name_, fusion_,
115115
buffer_assignment_, work_dimensions_));
116116

117-
return MlirKernelDefinition(std::move(kernel_spec),
118-
MlirKernelSource(std::move(module)));
117+
return KernelDefinition(std::move(kernel_spec),
118+
MlirKernelSource(std::move(module)));
119119
}
120120

121121
IndexingMap LoopFusionKernelEmitter::ComputeWorkItemIdToOutputIndexing(

0 commit comments

Comments
 (0)