Skip to content

Commit 0943a29

Browse files
[EVM] Unfold signextend(c, x) -> ashr(shl(x, 256 - (c + 1) * 8), 256 - (c + 1) * 8)
This transform unfolds to the canonical InstCombine form, rather than to sext(trunc(x, (c + 1) * 8), 256). Unfold signextend to LLVM instructions iff c is constant, so that LLVM can optimize it better. Fold it back just before ISel, since there we can do it across BB boundaries. Signed-off-by: Vladimir Radosavljevic <[email protected]>
1 parent 99eddfc commit 0943a29

File tree

5 files changed

+65
-8
lines changed

5 files changed

+65
-8
lines changed

llvm/lib/Target/EVM/EVMCodegenPrepare.cpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,14 @@
2020
#include "llvm/IR/IntrinsicInst.h"
2121
#include "llvm/IR/Intrinsics.h"
2222
#include "llvm/IR/IntrinsicsEVM.h"
23+
#include "llvm/IR/PatternMatch.h"
2324
#include "llvm/Pass.h"
25+
#include "llvm/Transforms/InstCombine/InstCombiner.h"
2426

2527
#include "EVM.h"
2628

2729
using namespace llvm;
30+
using namespace llvm::PatternMatch;
2831

2932
#define DEBUG_TYPE "evm-codegen-prepare"
3033

@@ -102,14 +105,47 @@ void EVMCodegenPrepare::processMemTransfer(MemTransferInst *M) {
102105
M->setCalledFunction(Intrinsic::getDeclaration(M->getModule(), IntrID));
103106
}
104107

108+
/// Try to fold ashr(shl(x, c), c) back into the EVM signextend intrinsic:
///
///   ashr(shl(x, c), c) -> signextend(((256 - c) / 8) - 1, x)
///
/// The fold is applied only when:
///   * the result type is the scalar i256 type,
///   * the shl has a single use and both shifts use the same constant c,
///   * c is less than 256 and divisible by 8.
///
/// \param I the candidate instruction; the caller passes an AShr.
/// \returns true if \p I (and its feeding shl) were replaced and erased,
///          false if nothing was changed.
static bool optimizeAShrInst(Instruction *I) {
  // isIntegerTy(256) also rejects non-integer (e.g. vector) types, on which
  // getIntegerBitWidth() would assert.
  auto *Ty = I->getType();
  if (!Ty->isIntegerTy(256))
    return false;
  constexpr unsigned BitWidth = 256;

  Value *X = nullptr;
  ConstantInt *ShiftAmt = nullptr;
  if (!match(I->getOperand(0),
             m_OneUse(m_Shl(m_Value(X), m_ConstantInt(ShiftAmt)))) ||
      !match(I->getOperand(1), m_Specific(ShiftAmt)))
    return false;

  // Inspect the amount as an APInt first: it is a 256-bit constant, so
  // getZExtValue() would assert if it did not fit in 64 bits, and an amount
  // of 256 or more would make the byte-index computation below wrap around.
  const APInt &Amt = ShiftAmt->getValue();
  if (Amt.uge(BitWidth) || Amt.urem(8) != 0)
    return false;

  IRBuilder<> Builder(I);
  unsigned ByteIdx = ((BitWidth - Amt.getZExtValue()) / 8) - 1;
  auto *B = ConstantInt::get(Ty, ByteIdx);
  auto *SignExtend =
      Builder.CreateIntrinsic(Ty, Intrinsic::evm_signextend, {B, X});
  SignExtend->takeName(I);
  I->replaceAllUsesWith(SignExtend);

  // Erase the ashr before the shl: the ashr is the shl's only user, and
  // erasing the shl while it still has a user would trigger an assertion.
  auto *ToRemove = cast<Instruction>(I->getOperand(0));
  I->eraseFromParent();
  ToRemove->eraseFromParent();
  return true;
}
138+
105139
bool EVMCodegenPrepare::runOnFunction(Function &F) {
106140
bool Changed = false;
107141
for (auto &BB : F) {
108-
for (auto &I : BB) {
142+
for (auto &I : make_early_inc_range(BB)) {
109143
if (auto *M = dyn_cast<MemTransferInst>(&I)) {
110144
processMemTransfer(M);
111145
Changed = true;
112146
}
147+
if (I.getOpcode() == Instruction::AShr)
148+
Changed |= optimizeAShrInst(&I);
113149
}
114150
}
115151

llvm/lib/Target/EVM/EVMTargetMachine.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,8 @@ bool EVMPassConfig::addPreISel() {
237237
}
238238

239239
void EVMPassConfig::addCodeGenPrepare() {
240-
addPass(createEVMCodegenPreparePass());
241240
TargetPassConfig::addCodeGenPrepare();
241+
addPass(createEVMCodegenPreparePass());
242242
}
243243

244244
bool EVMPassConfig::addInstSelector() {

llvm/lib/Target/EVM/EVMTargetTransformInfo.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,28 @@ using namespace llvm::PatternMatch;
2121

2222
static std::optional<Instruction *> instCombineSignExtend(InstCombiner &IC,
2323
IntrinsicInst &II) {
24+
unsigned BitWidth = II.getType()->getIntegerBitWidth();
25+
if (BitWidth != 256)
26+
return std::nullopt;
27+
28+
// Unfold signextend(c, x) ->
29+
// ashr(shl(x, 256 - (c + 1) * 8), 256 - (c + 1) * 8)
30+
// where c is a constant integer.
31+
ConstantInt *C = nullptr;
32+
if (match(II.getArgOperand(0), m_ConstantInt(C))) {
33+
const APInt &B = C->getValue();
34+
35+
// If the byte index is 31 or larger (i.e. (c + 1) * 8 >= 256), leave
36+
// constant folding to handle it.
37+
if (B.uge(APInt(BitWidth, (BitWidth / 8) - 1)))
38+
return std::nullopt;
39+
40+
unsigned ShiftAmt = BitWidth - ((B.getZExtValue() + 1) * 8);
41+
auto *Shl = IC.Builder.CreateShl(II.getArgOperand(1), ShiftAmt);
42+
auto *Ashr = IC.Builder.CreateAShr(Shl, ShiftAmt);
43+
return IC.replaceInstUsesWith(II, Ashr);
44+
}
45+
2446
// Fold signextend(b, signextend(b, x)) -> signextend(b, x)
2547
Value *B = nullptr, *X = nullptr;
2648
if (match(&II, m_Intrinsic<Intrinsic::evm_signextend>(

llvm/test/CodeGen/EVM/O3-pipeline.ll

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,8 @@ target triple = "evm"
5555
; CHECK-NEXT: Expand reduction intrinsics
5656
; CHECK-NEXT: Natural Loop Information
5757
; CHECK-NEXT: TLS Variable Hoist
58-
; CHECK-NEXT: Final transformations before code generation
59-
; CHECK-NEXT: Dominator Tree Construction
60-
; CHECK-NEXT: Natural Loop Information
6158
; CHECK-NEXT: CodeGen Prepare
59+
; CHECK-NEXT: Final transformations before code generation
6260
; CHECK-NEXT: Lower invoke and unwind, for unwindless code generators
6361
; CHECK-NEXT: Remove unreachable blocks from the CFG
6462
; CHECK-NEXT: CallGraph Construction

llvm/test/CodeGen/EVM/fold-signextend.ll

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ target triple = "evm"
77
define i256 @test_const(i256 %x) {
88
; CHECK-LABEL: define i256 @test_const(
99
; CHECK-SAME: i256 [[X:%.*]]) {
10-
; CHECK-NEXT: [[SIGNEXT1:%.*]] = call i256 @llvm.evm.signextend(i256 15, i256 [[X]])
10+
; CHECK-NEXT: [[TMP1:%.*]] = shl i256 [[X]], 128
11+
; CHECK-NEXT: [[SIGNEXT1:%.*]] = ashr exact i256 [[TMP1]], 128
1112
; CHECK-NEXT: ret i256 [[SIGNEXT1]]
1213
;
1314
%signext1 = call i256 @llvm.evm.signextend(i256 15, i256 %x)
@@ -18,8 +19,8 @@ define i256 @test_const(i256 %x) {
1819
define i256 @test_const_ne(i256 %x) {
1920
; CHECK-LABEL: define i256 @test_const_ne(
2021
; CHECK-SAME: i256 [[X:%.*]]) {
21-
; CHECK-NEXT: [[SIGNEXT1:%.*]] = call i256 @llvm.evm.signextend(i256 15, i256 [[X]])
22-
; CHECK-NEXT: [[SIGNEXT2:%.*]] = call i256 @llvm.evm.signextend(i256 10, i256 [[SIGNEXT1]])
22+
; CHECK-NEXT: [[TMP1:%.*]] = shl i256 [[X]], 168
23+
; CHECK-NEXT: [[SIGNEXT2:%.*]] = ashr exact i256 [[TMP1]], 168
2324
; CHECK-NEXT: ret i256 [[SIGNEXT2]]
2425
;
2526
%signext1 = call i256 @llvm.evm.signextend(i256 15, i256 %x)

0 commit comments

Comments
 (0)