Skip to content

Commit 713ea8f

Browse files
WillFroomGoogle-ML-Automation
authored andcommitted
[XLA:CPU][XTile] Create lowering for Iota.
PiperOrigin-RevId: 825568861
1 parent b3ea33e commit 713ea8f

File tree

3 files changed

+44
-4
lines changed

3 files changed

+44
-4
lines changed

xla/backends/cpu/codegen/tiled/transforms/passes.td

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,12 @@ def ShloToVectorPass : Pass<"xtile-cpu-shlo-to-vector", "mlir::ModuleOp"> {
4646
let constructor = "CreateShloToVectorPass()";
4747

4848
let dependentDialects = [
49+
"mlir::arith::ArithDialect",
50+
"mlir::memref::MemRefDialect",
51+
"mlir::scf::SCFDialect",
52+
"mlir::stablehlo::StablehloDialect",
4953
"mlir::tensor::TensorDialect",
5054
"mlir::vector::VectorDialect",
51-
"mlir::stablehlo::StablehloDialect",
52-
"mlir::scf::SCFDialect",
53-
"mlir::memref::MemRefDialect",
5455
];
5556
}
5657

xla/backends/cpu/codegen/tiled/transforms/shlo_to_vector.cc

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,35 @@ struct LowerReshape : mlir::OpRewritePattern<mlir::stablehlo::ReshapeOp> {
317317
}
318318
};
319319

320+
struct LowerIota : mlir::OpRewritePattern<mlir::stablehlo::IotaOp> {
321+
using OpRewritePattern::OpRewritePattern;
322+
323+
mlir::LogicalResult matchAndRewrite(
324+
mlir::stablehlo::IotaOp op,
325+
mlir::PatternRewriter& rewriter) const override {
326+
if (op.getType().getRank() != 1) {
327+
return rewriter.notifyMatchFailure(
328+
op, "iota op with rank != 1 is not supported");
329+
}
330+
331+
auto result_vector_type = GetVectorType(op.getType());
332+
auto element_type = result_vector_type.getElementType();
333+
int64_t iota_size = result_vector_type.getNumElements();
334+
335+
llvm::SmallVector<mlir::Attribute> iota_values(iota_size);
336+
for (int idx = 0; idx != iota_size; ++idx) {
337+
iota_values[idx] = rewriter.getIntegerAttr(element_type, idx);
338+
}
339+
340+
mlir::Value iota_const = mlir::arith::ConstantOp::create(
341+
rewriter, op->getLoc(),
342+
mlir::DenseElementsAttr::get(result_vector_type, iota_values));
343+
344+
rewriter.replaceOp(op, CastToTensor(rewriter, iota_const));
345+
return mlir::success();
346+
}
347+
};
348+
320349
class ShloToVectorPass : public impl::ShloToVectorPassBase<ShloToVectorPass> {
321350
public:
322351
using ShloToVectorPassBase::ShloToVectorPassBase;
@@ -325,7 +354,7 @@ class ShloToVectorPass : public impl::ShloToVectorPassBase<ShloToVectorPass> {
325354
mlir::MLIRContext* context = &getContext();
326355
mlir::RewritePatternSet patterns(context);
327356
patterns.add<LowerTranspose, LowerDotGeneral, LowerReduce,
328-
LowerBroadcastInDim, LowerReshape>(context);
357+
LowerBroadcastInDim, LowerReshape, LowerIota>(context);
329358
if (mlir::failed(
330359
mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) {
331360
signalPassFailure();

xla/backends/cpu/codegen/tiled/transforms/tests/shlo_to_vector.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,13 @@ func.func @reshape(%input : tensor<4xf32>) -> tensor<2x1x2xf32> {
185185
// CHECK-LABEL: @reshape
186186
// CHECK:vector.shape_cast {{.*}} : vector<4xf32> to vector<2x1x2xf32>
187187

188+
// -----
189+
190+
func.func @iota() -> tensor<4xi32> {
191+
%result = stablehlo.iota dim = 0 : tensor<4xi32>
192+
return %result : tensor<4xi32>
193+
}
194+
195+
// CHECK-LABEL: @iota
196+
// CHECK: arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
197+

0 commit comments

Comments
 (0)