From 8abf1a1608a24a6a4f3aa806671a40dc192cce3b Mon Sep 17 00:00:00 2001 From: Will Froom Date: Wed, 29 Oct 2025 18:13:43 -0700 Subject: [PATCH] [XLA:CPU][XTile] Create lowering for Iota. PiperOrigin-RevId: 825789498 --- .../cpu/codegen/tiled/transforms/passes.td | 7 +++-- .../tiled/transforms/shlo_to_vector.cc | 31 ++++++++++++++++++- .../transforms/tests/shlo_to_vector.mlir | 10 ++++++ 3 files changed, 44 insertions(+), 4 deletions(-) diff --git a/xla/backends/cpu/codegen/tiled/transforms/passes.td b/xla/backends/cpu/codegen/tiled/transforms/passes.td index de620ccec8c6d..88e4df5ca6ea2 100644 --- a/xla/backends/cpu/codegen/tiled/transforms/passes.td +++ b/xla/backends/cpu/codegen/tiled/transforms/passes.td @@ -46,11 +46,12 @@ def ShloToVectorPass : Pass<"xtile-cpu-shlo-to-vector", "mlir::ModuleOp"> { let constructor = "CreateShloToVectorPass()"; let dependentDialects = [ + "mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::scf::SCFDialect", + "mlir::stablehlo::StablehloDialect", "mlir::tensor::TensorDialect", "mlir::vector::VectorDialect", - "mlir::stablehlo::StablehloDialect", - "mlir::scf::SCFDialect", - "mlir::memref::MemRefDialect", ]; } diff --git a/xla/backends/cpu/codegen/tiled/transforms/shlo_to_vector.cc b/xla/backends/cpu/codegen/tiled/transforms/shlo_to_vector.cc index 01e33977a006b..37b68b0cefc9f 100644 --- a/xla/backends/cpu/codegen/tiled/transforms/shlo_to_vector.cc +++ b/xla/backends/cpu/codegen/tiled/transforms/shlo_to_vector.cc @@ -317,6 +317,35 @@ struct LowerReshape : mlir::OpRewritePattern { } }; +struct LowerIota : mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::stablehlo::IotaOp op, + mlir::PatternRewriter& rewriter) const override { + if (op.getType().getRank() != 1) { + return rewriter.notifyMatchFailure( + op, "iota op with rank != 1 is not supported"); + } + + auto result_vector_type = GetVectorType(op.getType()); + auto element_type = result_vector_type.getElementType(); + int64_t iota_size = result_vector_type.getNumElements(); + + llvm::SmallVector iota_values(iota_size); + for (int idx = 0; idx != iota_size; ++idx) { + iota_values[idx] = rewriter.getIntegerAttr(element_type, idx); + } + + mlir::Value iota_const = mlir::arith::ConstantOp::create( + rewriter, op->getLoc(), + mlir::DenseElementsAttr::get(result_vector_type, iota_values)); + + rewriter.replaceOp(op, CastToTensor(rewriter, iota_const)); + return mlir::success(); + } +}; + class ShloToVectorPass : public impl::ShloToVectorPassBase { public: using ShloToVectorPassBase::ShloToVectorPassBase; @@ -325,7 +354,7 @@ class ShloToVectorPass : public impl::ShloToVectorPassBase { mlir::MLIRContext* context = &getContext(); mlir::RewritePatternSet patterns(context); patterns.add(context); + LowerBroadcastInDim, LowerReshape, LowerIota>(context); if (mlir::failed( mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); diff --git a/xla/backends/cpu/codegen/tiled/transforms/tests/shlo_to_vector.mlir b/xla/backends/cpu/codegen/tiled/transforms/tests/shlo_to_vector.mlir index f12b434fd0fe8..d36600e30b48f 100644 --- a/xla/backends/cpu/codegen/tiled/transforms/tests/shlo_to_vector.mlir +++ b/xla/backends/cpu/codegen/tiled/transforms/tests/shlo_to_vector.mlir @@ -185,3 +185,13 @@ func.func @reshape(%input : tensor<4xf32>) -> tensor<2x1x2xf32> { // CHECK-LABEL: @reshape // CHECK:vector.shape_cast {{.*}} : vector<4xf32> to vector<2x1x2xf32> +// ----- + +func.func @iota() -> tensor<4xi32> { + %result = stablehlo.iota dim = 0 : tensor<4xi32> + return %result : tensor<4xi32> +} + +// CHECK-LABEL: @iota +// CHECK: arith.constant dense<[0, 1, 2, 3]> : vector<4xi32> +