From 70baf493181359aea42f0fdd994fdc2c2f763ab4 Mon Sep 17 00:00:00 2001 From: Max-astro <651535280@qq.com> Date: Tue, 26 Nov 2024 21:00:28 +0800 Subject: [PATCH] [MooreToCore] Support pows and powu op --- lib/Conversion/MooreToCore/MooreToCore.cpp | 87 +++++++++++++++++++++- test/Conversion/MooreToCore/basic.mlir | 22 ++++++ 2 files changed, 106 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/MooreToCore/MooreToCore.cpp b/lib/Conversion/MooreToCore/MooreToCore.cpp index 486d362bbf86..14a7274c321e 100644 --- a/lib/Conversion/MooreToCore/MooreToCore.cpp +++ b/lib/Conversion/MooreToCore/MooreToCore.cpp @@ -1188,6 +1188,84 @@ struct ShrOpConversion : public OpConversionPattern { } }; +struct PowUOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(PowUOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = typeConverter->convertType(op.getResult().getType()); + + Location loc = op.getLoc(); + auto intType = cast(op.getRhs().getType()); + + // transform a ** b into scf.for 0 to b step 1 { init *= a }, init = 1 + Type integerType = rewriter.getIntegerType(intType.getWidth()); + Value lowerBound = rewriter.create(loc, integerType, 0); + Value upperBound = + rewriter.create(loc, integerType, op.getRhs()); + Value step = rewriter.create(loc, integerType, 1); + + Value initVal = rewriter.create(loc, resultType, 1); + Value lhsVal = rewriter.create(loc, resultType, op.getLhs()); + + auto forOp = rewriter.create( + loc, lowerBound, upperBound, step, ValueRange(initVal), + [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { + Value loopVar = iterArgs.front(); + Value mul = rewriter.create(loc, lhsVal, loopVar); + rewriter.create(loc, ValueRange(mul)); + }); + + rewriter.replaceOp(op, forOp.getResult(0)); + + return success(); + } +}; + +struct PowSOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(PowSOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = typeConverter->convertType(op.getResult().getType()); + + Location loc = op.getLoc(); + auto intType = cast(op.getRhs().getType()); + // transform a ** b into scf.for 0 to b step 1 { init *= a }, init = 1 + Type integerType = rewriter.getIntegerType(intType.getWidth()); + Value lhsVal = rewriter.create(loc, resultType, op.getLhs()); + Value rhsVal = rewriter.create(loc, integerType, op.getRhs()); + Value constZero = rewriter.create(loc, integerType, 0); + Value constZeroResult = rewriter.create(loc, resultType, 0); + Value isNegative = rewriter.create(loc, ICmpPredicate::slt, + rhsVal, constZero); + + // if the exponent is negative, return 0 + lhsVal = + rewriter.create(loc, isNegative, constZeroResult, lhsVal); + Value upperBound = + rewriter.create(loc, isNegative, constZero, rhsVal); + + Value lowerBound = constZero; + Value step = rewriter.create(loc, integerType, 1); + Value initVal = rewriter.create(loc, resultType, 1); + + auto forOp = rewriter.create( + loc, lowerBound, upperBound, step, ValueRange(initVal), + [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { + auto loopVar = iterArgs.front(); + auto mul = rewriter.create(loc, lhsVal, loopVar); + rewriter.create(loc, ValueRange(mul)); + }); + + rewriter.replaceOp(op, forOp.getResult(0)); + + return success(); + } +}; + struct AShrOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -1430,9 +1508,9 @@ static void populateLegality(ConversionTarget &target, target.addLegalOp(); target.addDynamicallyLegalOp< - cf::CondBranchOp, cf::BranchOp, scf::IfOp, scf::YieldOp, func::CallOp, - func::ReturnOp, UnrealizedConversionCastOp, hw::OutputOp, hw::InstanceOp, - debug::ArrayOp, debug::StructOp, debug::VariableOp>( + cf::CondBranchOp, cf::BranchOp, scf::IfOp, scf::ForOp, scf::YieldOp, + func::CallOp, func::ReturnOp, UnrealizedConversionCastOp, hw::OutputOp, + hw::InstanceOp, debug::ArrayOp, debug::StructOp, debug::VariableOp>( [&](Operation *op) { return converter.isLegal(op); }); target.addDynamicallyLegalOp([&](func::FuncOp op) { @@ -1590,6 +1668,9 @@ static void populateOpConversion(RewritePatternSet &patterns, BinaryOpConversion, BinaryOpConversion, + // Patterns of power operations. + PowUOpConversion, PowSOpConversion, + // Patterns of relational operations. ICmpOpConversion, ICmpOpConversion, diff --git a/test/Conversion/MooreToCore/basic.mlir b/test/Conversion/MooreToCore/basic.mlir index 7b7b70eabb18..109a18a8d7d4 100644 --- a/test/Conversion/MooreToCore/basic.mlir +++ b/test/Conversion/MooreToCore/basic.mlir @@ -985,3 +985,25 @@ func.func @Conversions(%arg0: !moore.i16, %arg1: !moore.l16) { return } + +// CHECK-LABEL: func.func @PowUOp +func.func @PowUOp(%arg0: !moore.l32, %arg1: !moore.l32) { + // CHECK: %{{.*}} = scf.for %{{.*}} = %{{.*}} to %arg1 step %{{.*}} iter_args([[VAR:%.+]] = %{{.*}}) -> (i32) : i32 { + // CHECK: [[MUL:%.+]] = comb.mul %arg0, [[VAR]] : i32 + // CHECK: scf.yield [[MUL]] : i32 + %0 = moore.powu %arg0, %arg1 : l32 + return +} + +// CHECK-LABEL: func.func @PowSOp +func.func @PowSOp(%arg0: !moore.i32, %arg1: !moore.i32) { + // CHECK: [[COND:%.+]] = comb.icmp slt %arg1, %{{.*}} : i32 + // CHECK: [[BASE:%.+]] = comb.mux [[COND]], %{{.*}}, %arg0 : i32 + // CHECK: [[EXP:%.+]] = comb.mux [[COND]], %{{.*}}, %arg1 : i32 + + // CHECK: %{{.*}} = scf.for %{{.*}} = %{{.*}} to [[EXP]] step %{{.*}} iter_args([[VAR:%.+]] = %{{.*}}) -> (i32) : i32 { + // CHECK: [[MUL:%.+]] = comb.mul [[BASE]], [[VAR]] : i32 + // CHECK: scf.yield [[MUL]] : i32 + %0 = moore.pows %arg0, %arg1 : i32 + return +}