@@ -42,7 +42,7 @@ limitations under the License.
4242#include " xla/codegen/emitters/loop_kernel_emitter.h"
4343#include " xla/codegen/hlo_fusion_spec.h"
4444#include " xla/codegen/ir_emission_utils.h"
45- #include " xla/codegen/kernel_spec .h"
45+ #include " xla/codegen/kernel_definition .h"
4646#include " xla/codegen/mlir_kernel_source.h"
4747#include " xla/hlo/analysis/symbolic_expr.h"
4848#include " xla/hlo/ir/hlo_instruction.h"
@@ -57,7 +57,6 @@ limitations under the License.
5757#include " xla/runtime/work_tile_size.h"
5858#include " xla/service/buffer_assignment.h"
5959#include " xla/service/cpu/backend_config.pb.h"
60- #include " xla/service/gpu/ir_emission_utils.h"
6160#include " xla/shape.h"
6261#include " xla/shape_util.h"
6362#include " xla/tsl/platform/statusor.h"
@@ -208,7 +207,7 @@ static HloFusionSpec GetLoopFusionSpec(const HloFusionInstruction& fusion) {
208207 std::move (heroes));
209208}
210209
211- static absl::StatusOr<MlirKernelDefinition > EmitLoopFusionKernel (
210+ static absl::StatusOr<KernelDefinition<MlirKernelSource> > EmitLoopFusionKernel (
212211 SymbolicExprContext& context, const HloFusionInstruction& fusion,
213212 const BufferAssignment* buffer_assignment, absl::string_view name) {
214213 VLOG (2 ) << " Emitting loop fusion kernel: " << name;
@@ -230,9 +229,11 @@ static absl::StatusOr<MlirKernelDefinition> EmitLoopFusionKernel(
230229 return mlir_kernel_definition;
231230}
232231
233- static absl::StatusOr<MlirKernelDefinition> EmitConcatenateFusionKernel (
234- SymbolicExprContext& context, const HloFusionInstruction& fusion,
235- const BufferAssignment* buffer_assignment, absl::string_view name) {
232+ static absl::StatusOr<KernelDefinition<MlirKernelSource>>
233+ EmitConcatenateFusionKernel (SymbolicExprContext& context,
234+ const HloFusionInstruction& fusion,
235+ const BufferAssignment* buffer_assignment,
236+ absl::string_view name) {
236237 VLOG (2 ) << " Emitting concatenate fusion kernel: " << name;
237238 HloFusionSpec fusion_spec = GetLoopFusionSpec (fusion);
238239 auto work_dimensions = GetConcatenateEmitterWorkDims (fusion, fusion_spec);
@@ -252,9 +253,11 @@ static absl::StatusOr<MlirKernelDefinition> EmitConcatenateFusionKernel(
252253 return mlir_kernel_definition;
253254}
254255
255- static absl::StatusOr<MlirKernelDefinition> EmitDynamicUpdateSliceFusionKernel (
256- SymbolicExprContext& context, const HloFusionInstruction& fusion,
257- const BufferAssignment* buffer_assignment, absl::string_view name) {
256+ static absl::StatusOr<KernelDefinition<MlirKernelSource>>
257+ EmitDynamicUpdateSliceFusionKernel (SymbolicExprContext& context,
258+ const HloFusionInstruction& fusion,
259+ const BufferAssignment* buffer_assignment,
260+ absl::string_view name) {
258261 VLOG (2 ) << " Emitting dynamic update slice fusion kernel: " << name;
259262 HloFusionSpec fusion_spec = GetLoopFusionSpec (fusion);
260263 auto work_dimensions =
@@ -275,7 +278,7 @@ static absl::StatusOr<MlirKernelDefinition> EmitDynamicUpdateSliceFusionKernel(
275278 return mlir_kernel_definition;
276279}
277280
278- absl::StatusOr<MlirKernelDefinition > EmitFusionKernel (
281+ absl::StatusOr<KernelDefinition<MlirKernelSource> > EmitFusionKernel (
279282 SymbolicExprContext& context, const HloFusionInstruction& fusion,
280283 const BufferAssignment* buffer_assignment, bool use_unique_c_name) {
281284 if (fusion.fusion_kind () == HloFusionInstruction::FusionKind::kLoop ) {
0 commit comments