From c27ba678a712a401e4a6db75ec0ef9e1ce9e1777 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Thu, 23 Jan 2025 15:21:28 -0800 Subject: [PATCH] Expose populateStablehloToLinalgConversionPatterns function (#2695) This function declaration has existed since the code was forked from IREE in https://github.com/openxla/stablehlo/pull/1817, but the implementation was kept private (static function within an anonymous namespace). I'm now trying to switch IREE from having its own implementation to using the upstream implementation from this project in https://github.com/iree-org/iree/pull/19792, and I would like to access these patterns directly, instead of through the `StablehloLegalizeToLinalgPass`. With the patterns I can run conversion including my own sets of additional patterns, while a pass runs in isolation. I'm also deleting the `populateLegalizeChloPatterns`, `populateLegalizeControlFlowPatterns`, and `populateLegalizeShapeComputationPatterns` declarations, which were not migrated from IREE and are also dangling without implementations. --- .../conversions/linalg/transforms/Rewriters.h | 19 +---- .../transforms/StablehloLegalizeToLinalg.cpp | 77 ++++++++++--------- 2 files changed, 42 insertions(+), 54 deletions(-) diff --git a/stablehlo/conversions/linalg/transforms/Rewriters.h b/stablehlo/conversions/linalg/transforms/Rewriters.h index ffee12b79c..9db1a021ea 100644 --- a/stablehlo/conversions/linalg/transforms/Rewriters.h +++ b/stablehlo/conversions/linalg/transforms/Rewriters.h @@ -22,28 +22,15 @@ limitations under the License. namespace mlir::stablehlo { //===----------------------------------------------------------------------===// -// General StableHLO/CHLO lowering patterns. +// General StableHLO lowering patterns. //===----------------------------------------------------------------------===// /// Populates the patterns that convert from StableHLO to Linalg on tensors. void populateStablehloToLinalgConversionPatterns(MLIRContext *context, TypeConverter &typeConverter, RewritePatternSet *patterns, - bool enablePrimitiveOps); - -/// Collection of rewrite patterns for lowering of CHLO ops to StableHLO and -/// Shape ops. -void populateLegalizeChloPatterns(MLIRContext *context, - RewritePatternSet *patterns); - -/// Collection of rewrite patterns for lowering of StableHLO ops to SCF control -/// flow ops. -void populateLegalizeControlFlowPatterns(MLIRContext *context, - RewritePatternSet *patterns); - -/// Collection of rewrite patterns for lowering of StableHLO dim operations. -void populateLegalizeShapeComputationPatterns(MLIRContext *context, - RewritePatternSet *patterns); + bool enablePrimitiveOps, + bool enableSparseOps); //===----------------------------------------------------------------------===// // Fine-grained patterns used by the implementation. diff --git a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp index 76fec51e9e..45f3cbbb1c 100644 --- a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp +++ b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp @@ -2600,11 +2600,45 @@ struct SetDimensionSizeConverter final } }; -static void populateConversionPatterns(MLIRContext *context, - TypeConverter &typeConverter, - RewritePatternSet *patterns, - bool enablePrimitiveOps, - bool enableSparseOps) { +struct StablehloLegalizeToLinalgPass + : impl::StablehloLegalizeToLinalgPassBase { + using StablehloLegalizeToLinalgPassBase::StablehloLegalizeToLinalgPassBase; + + LogicalResult initialize(MLIRContext *context) override { + target = std::make_shared(*context); + target->addLegalDialect< + bufferization::BufferizationDialect, arith::ArithDialect, + complex::ComplexDialect, linalg::LinalgDialect, math::MathDialect, + tensor::TensorDialect, sparse_tensor::SparseTensorDialect, + scf::SCFDialect, shape::ShapeDialect>(); + target->addLegalOp(); + + RewritePatternSet patterns_(context); + populateStablehloToLinalgConversionPatterns( + context, converter, &patterns_, enablePrimitiveOps, enableSparseOps); + patterns = std::move(patterns_); + + return success(); + } + + void runOnOperation() override { + if (failed(applyPartialConversion(getOperation(), *target, patterns))) { + return signalPassFailure(); + } + } + + private: + std::shared_ptr target; + FrozenRewritePatternSet patterns; + LinalgTypeConverter converter; +}; +} // namespace + +void populateStablehloToLinalgConversionPatterns(MLIRContext *context, + TypeConverter &typeConverter, + RewritePatternSet *patterns, + bool enablePrimitiveOps, + bool enableSparseOps) { // clang-format off patterns->add(typeConverter, context, enablePrimitiveOps); @@ -2670,37 +2704,4 @@ static void populateConversionPatterns(MLIRContext *context, linalg::populateEraseUnusedOperandsAndResultsPatterns(*patterns); } -struct StablehloLegalizeToLinalgPass - : impl::StablehloLegalizeToLinalgPassBase { - using StablehloLegalizeToLinalgPassBase::StablehloLegalizeToLinalgPassBase; - - LogicalResult initialize(MLIRContext *context) override { - target = std::make_shared(*context); - target->addLegalDialect< - bufferization::BufferizationDialect, arith::ArithDialect, - complex::ComplexDialect, linalg::LinalgDialect, math::MathDialect, - tensor::TensorDialect, sparse_tensor::SparseTensorDialect, - scf::SCFDialect, shape::ShapeDialect>(); - target->addLegalOp(); - - RewritePatternSet patterns_(context); - populateConversionPatterns(context, converter, &patterns_, - enablePrimitiveOps, enableSparseOps); - patterns = std::move(patterns_); - - return success(); - } - - void runOnOperation() override { - if (failed(applyPartialConversion(getOperation(), *target, patterns))) { - return signalPassFailure(); - } - } - - private: - std::shared_ptr target; - FrozenRewritePatternSet patterns; - LinalgTypeConverter converter; -}; -} // namespace } // namespace mlir::stablehlo