diff --git a/stablehlo/conversions/linalg/transforms/Rewriters.h b/stablehlo/conversions/linalg/transforms/Rewriters.h index ffee12b79c..9db1a021ea 100644 --- a/stablehlo/conversions/linalg/transforms/Rewriters.h +++ b/stablehlo/conversions/linalg/transforms/Rewriters.h @@ -22,28 +22,15 @@ limitations under the License. namespace mlir::stablehlo { //===----------------------------------------------------------------------===// -// General StableHLO/CHLO lowering patterns. +// General StableHLO lowering patterns. //===----------------------------------------------------------------------===// /// Populates the patterns that convert from StableHLO to Linalg on tensors. void populateStablehloToLinalgConversionPatterns(MLIRContext *context, TypeConverter &typeConverter, RewritePatternSet *patterns, - bool enablePrimitiveOps); - -/// Collection of rewrite patterns for lowering of CHLO ops to StableHLO and -/// Shape ops. -void populateLegalizeChloPatterns(MLIRContext *context, - RewritePatternSet *patterns); - -/// Collection of rewrite patterns for lowering of StableHLO ops to SCF control -/// flow ops. -void populateLegalizeControlFlowPatterns(MLIRContext *context, - RewritePatternSet *patterns); - -/// Collection of rewrite patterns for lowering of StableHLO dim operations. -void populateLegalizeShapeComputationPatterns(MLIRContext *context, - RewritePatternSet *patterns); + bool enablePrimitiveOps, + bool enableSparseOps); //===----------------------------------------------------------------------===// // Fine-grained patterns used by the implementation. diff --git a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp index 76fec51e9e..45f3cbbb1c 100644 --- a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp +++ b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp @@ -2600,11 +2600,45 @@ struct SetDimensionSizeConverter final } }; -static void populateConversionPatterns(MLIRContext *context, - TypeConverter &typeConverter, - RewritePatternSet *patterns, - bool enablePrimitiveOps, - bool enableSparseOps) { +struct StablehloLegalizeToLinalgPass + : impl::StablehloLegalizeToLinalgPassBase { + using StablehloLegalizeToLinalgPassBase::StablehloLegalizeToLinalgPassBase; + + LogicalResult initialize(MLIRContext *context) override { + target = std::make_shared(*context); + target->addLegalDialect< + bufferization::BufferizationDialect, arith::ArithDialect, + complex::ComplexDialect, linalg::LinalgDialect, math::MathDialect, + tensor::TensorDialect, sparse_tensor::SparseTensorDialect, + scf::SCFDialect, shape::ShapeDialect>(); + target->addLegalOp(); + + RewritePatternSet patterns_(context); + populateStablehloToLinalgConversionPatterns( + context, converter, &patterns_, enablePrimitiveOps, enableSparseOps); + patterns = std::move(patterns_); + + return success(); + } + + void runOnOperation() override { + if (failed(applyPartialConversion(getOperation(), *target, patterns))) { + return signalPassFailure(); + } + } + + private: + std::shared_ptr target; + FrozenRewritePatternSet patterns; + LinalgTypeConverter converter; +}; +} // namespace + +void populateStablehloToLinalgConversionPatterns(MLIRContext *context, + TypeConverter &typeConverter, + RewritePatternSet *patterns, + bool enablePrimitiveOps, + bool enableSparseOps) { // clang-format off patterns->add(typeConverter, context, enablePrimitiveOps); @@ -2670,37 +2704,4 @@ static void populateConversionPatterns(MLIRContext *context, linalg::populateEraseUnusedOperandsAndResultsPatterns(*patterns); } -struct StablehloLegalizeToLinalgPass - : impl::StablehloLegalizeToLinalgPassBase { - using StablehloLegalizeToLinalgPassBase::StablehloLegalizeToLinalgPassBase; - - LogicalResult initialize(MLIRContext *context) override { - target = std::make_shared(*context); - target->addLegalDialect< - bufferization::BufferizationDialect, arith::ArithDialect, - complex::ComplexDialect, linalg::LinalgDialect, math::MathDialect, - tensor::TensorDialect, sparse_tensor::SparseTensorDialect, - scf::SCFDialect, shape::ShapeDialect>(); - target->addLegalOp(); - - RewritePatternSet patterns_(context); - populateConversionPatterns(context, converter, &patterns_, - enablePrimitiveOps, enableSparseOps); - patterns = std::move(patterns_); - - return success(); - } - - void runOnOperation() override { - if (failed(applyPartialConversion(getOperation(), *target, patterns))) { - return signalPassFailure(); - } - } - - private: - std::shared_ptr target; - FrozenRewritePatternSet patterns; - LinalgTypeConverter converter; -}; -} // namespace } // namespace mlir::stablehlo