Skip to content

Commit

Permalink
Expose populateStablehloToLinalgConversionPatterns function (#2695)
Browse files Browse the repository at this point in the history
This function declaration has existed since the code was forked from
IREE in #1817, but the
implementation was kept private (static function within an anonymous
namespace).

I'm now trying to switch IREE from having its own implementation to
using the upstream implementation from this project in
iree-org/iree#19792, and I would like to access
these patterns directly, instead of through the
`StablehloLegalizeToLinalgPass`. With the patterns I can run conversion
including my own sets of additional patterns, while a pass runs in
isolation.

I'm also deleting the `populateLegalizeChloPatterns`,
`populateLegalizeControlFlowPatterns`, and
`populateLegalizeShapeComputationPatterns` declarations, which were not
migrated from IREE and are also dangling without implementations.
  • Loading branch information
ScottTodd authored Jan 23, 2025
1 parent 01baa23 commit c27ba67
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 54 deletions.
19 changes: 3 additions & 16 deletions stablehlo/conversions/linalg/transforms/Rewriters.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,15 @@ limitations under the License.
namespace mlir::stablehlo {

//===----------------------------------------------------------------------===//
// General StableHLO/CHLO lowering patterns.
// General StableHLO lowering patterns.
//===----------------------------------------------------------------------===//

/// Populates the patterns that convert from StableHLO to Linalg on tensors.
void populateStablehloToLinalgConversionPatterns(MLIRContext *context,
TypeConverter &typeConverter,
RewritePatternSet *patterns,
bool enablePrimitiveOps);

/// Collection of rewrite patterns for lowering of CHLO ops to StableHLO and
/// Shape ops.
void populateLegalizeChloPatterns(MLIRContext *context,
RewritePatternSet *patterns);

/// Collection of rewrite patterns for lowering of StableHLO ops to SCF control
/// flow ops.
void populateLegalizeControlFlowPatterns(MLIRContext *context,
RewritePatternSet *patterns);

/// Collection of rewrite patterns for lowering of StableHLO dim operations.
void populateLegalizeShapeComputationPatterns(MLIRContext *context,
RewritePatternSet *patterns);
bool enablePrimitiveOps,
bool enableSparseOps);

//===----------------------------------------------------------------------===//
// Fine-grained patterns used by the implementation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2600,11 +2600,45 @@ struct SetDimensionSizeConverter final
}
};

static void populateConversionPatterns(MLIRContext *context,
TypeConverter &typeConverter,
RewritePatternSet *patterns,
bool enablePrimitiveOps,
bool enableSparseOps) {
struct StablehloLegalizeToLinalgPass
: impl::StablehloLegalizeToLinalgPassBase<StablehloLegalizeToLinalgPass> {
using StablehloLegalizeToLinalgPassBase::StablehloLegalizeToLinalgPassBase;

LogicalResult initialize(MLIRContext *context) override {
target = std::make_shared<ConversionTarget>(*context);
target->addLegalDialect<
bufferization::BufferizationDialect, arith::ArithDialect,
complex::ComplexDialect, linalg::LinalgDialect, math::MathDialect,
tensor::TensorDialect, sparse_tensor::SparseTensorDialect,
scf::SCFDialect, shape::ShapeDialect>();
target->addLegalOp<UnrealizedConversionCastOp>();

RewritePatternSet patterns_(context);
populateStablehloToLinalgConversionPatterns(
context, converter, &patterns_, enablePrimitiveOps, enableSparseOps);
patterns = std::move(patterns_);

return success();
}

void runOnOperation() override {
if (failed(applyPartialConversion(getOperation(), *target, patterns))) {
return signalPassFailure();
}
}

private:
std::shared_ptr<ConversionTarget> target;
FrozenRewritePatternSet patterns;
LinalgTypeConverter converter;
};
} // namespace

void populateStablehloToLinalgConversionPatterns(MLIRContext *context,
TypeConverter &typeConverter,
RewritePatternSet *patterns,
bool enablePrimitiveOps,
bool enableSparseOps) {
// clang-format off
patterns->add<ConcatenateConverter>(typeConverter, context,
enablePrimitiveOps);
Expand Down Expand Up @@ -2670,37 +2704,4 @@ static void populateConversionPatterns(MLIRContext *context,
linalg::populateEraseUnusedOperandsAndResultsPatterns(*patterns);
}

struct StablehloLegalizeToLinalgPass
: impl::StablehloLegalizeToLinalgPassBase<StablehloLegalizeToLinalgPass> {
using StablehloLegalizeToLinalgPassBase::StablehloLegalizeToLinalgPassBase;

LogicalResult initialize(MLIRContext *context) override {
target = std::make_shared<ConversionTarget>(*context);
target->addLegalDialect<
bufferization::BufferizationDialect, arith::ArithDialect,
complex::ComplexDialect, linalg::LinalgDialect, math::MathDialect,
tensor::TensorDialect, sparse_tensor::SparseTensorDialect,
scf::SCFDialect, shape::ShapeDialect>();
target->addLegalOp<UnrealizedConversionCastOp>();

RewritePatternSet patterns_(context);
populateConversionPatterns(context, converter, &patterns_,
enablePrimitiveOps, enableSparseOps);
patterns = std::move(patterns_);

return success();
}

void runOnOperation() override {
if (failed(applyPartialConversion(getOperation(), *target, patterns))) {
return signalPassFailure();
}
}

private:
std::shared_ptr<ConversionTarget> target;
FrozenRewritePatternSet patterns;
LinalgTypeConverter converter;
};
} // namespace
} // namespace mlir::stablehlo

0 comments on commit c27ba67

Please sign in to comment.