Skip to content

Commit 32ab6ee

Browse files
committed
[InstCombine] canonicalize sign bit checks
This is a generalization of #122520 to other instructions with constrained result ranges.
1 parent 89fdfb6 commit 32ab6ee

File tree

9 files changed

+115
-69
lines changed

9 files changed

+115
-69
lines changed

llvm/include/llvm/IR/ConstantRange.h

+6
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,12 @@ class [[nodiscard]] ConstantRange {
128128
/// NOTE: false does not mean that inverse predicate holds!
129129
bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const;
130130

131+
/// Does the predicate \p Pred or its inverse hold between ranges this and \p
132+
/// Other? Returns `true` if the predicate always holds, `false` if the
133+
/// inverse always holds, or `std::nullopt` otherwise.
134+
std::optional<bool> icmpOrInverse(CmpInst::Predicate Pred,
135+
const ConstantRange &Other) const;
136+
131137
/// Return true iff CR1 ult CR2 is equivalent to CR1 slt CR2.
132138
/// Does not depend on strictness/direction of the predicate.
133139
static bool

llvm/lib/Analysis/InstructionSimplify.cpp

+3-7
Original file line numberDiff line numberDiff line change
@@ -3783,13 +3783,9 @@ static Value *simplifyICmpInst(CmpPredicate Pred, Value *LHS, Value *RHS,
37833783
// If both operands have range metadata, use the metadata
37843784
// to simplify the comparison.
37853785
if (std::optional<ConstantRange> RhsCr = getRange(RHS, Q.IIQ))
3786-
if (std::optional<ConstantRange> LhsCr = getRange(LHS, Q.IIQ)) {
3787-
if (LhsCr->icmp(Pred, *RhsCr))
3788-
return ConstantInt::getTrue(ITy);
3789-
3790-
if (LhsCr->icmp(CmpInst::getInversePredicate(Pred), *RhsCr))
3791-
return ConstantInt::getFalse(ITy);
3792-
}
3786+
if (std::optional<ConstantRange> LhsCr = getRange(LHS, Q.IIQ))
3787+
if (auto Res = LhsCr->icmpOrInverse(Pred, *RhsCr))
3788+
return ConstantInt::getBool(ITy, *Res);
37933789

37943790
// Compare of cast, for example (zext X) != 0 -> X != 0
37953791
if (isa<CastInst>(LHS) && (isa<Constant>(RHS) || isa<CastInst>(RHS))) {

llvm/lib/IR/ConstantRange.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,16 @@ bool ConstantRange::icmp(CmpInst::Predicate Pred,
274274
}
275275
}
276276

277+
std::optional<bool>
278+
ConstantRange::icmpOrInverse(CmpInst::Predicate Pred,
279+
const ConstantRange &Other) const {
280+
if (icmp(Pred, Other))
281+
return true;
282+
if (icmp(CmpInst::getInversePredicate(Pred), Other))
283+
return false;
284+
return std::nullopt;
285+
}
286+
277287
/// Exact mul nuw region for single element RHS.
278288
static ConstantRange makeExactMulNUWRegion(const APInt &V) {
279289
unsigned BitWidth = V.getBitWidth();

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

+68-33
Original file line numberDiff line numberDiff line change
@@ -2674,41 +2674,10 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp,
26742674
Instruction *InstCombinerImpl::foldICmpSRemConstant(ICmpInst &Cmp,
26752675
BinaryOperator *SRem,
26762676
const APInt &C) {
2677-
const ICmpInst::Predicate Pred = Cmp.getPredicate();
2678-
if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULT) {
2679-
// Canonicalize unsigned predicates to signed:
2680-
// (X s% DivisorC) u> C -> (X s% DivisorC) s< 0
2681-
// iff (C s< 0 ? ~C : C) u>= abs(DivisorC)-1
2682-
// (X s% DivisorC) u< C+1 -> (X s% DivisorC) s> -1
2683-
// iff (C+1 s< 0 ? ~C : C) u>= abs(DivisorC)-1
2684-
2685-
const APInt *DivisorC;
2686-
if (!match(SRem->getOperand(1), m_APInt(DivisorC)))
2687-
return nullptr;
2688-
2689-
APInt NormalizedC = C;
2690-
if (Pred == ICmpInst::ICMP_ULT) {
2691-
assert(!NormalizedC.isZero() &&
2692-
"ult X, 0 should have been simplified already.");
2693-
--NormalizedC;
2694-
}
2695-
if (C.isNegative())
2696-
NormalizedC.flipAllBits();
2697-
assert(!DivisorC->isZero() &&
2698-
"srem X, 0 should have been simplified already.");
2699-
if (!NormalizedC.uge(DivisorC->abs() - 1))
2700-
return nullptr;
2701-
2702-
Type *Ty = SRem->getType();
2703-
if (Pred == ICmpInst::ICMP_UGT)
2704-
return new ICmpInst(ICmpInst::ICMP_SLT, SRem,
2705-
ConstantInt::getNullValue(Ty));
2706-
return new ICmpInst(ICmpInst::ICMP_SGT, SRem,
2707-
ConstantInt::getAllOnesValue(Ty));
2708-
}
27092677
// Match an 'is positive' or 'is negative' comparison of remainder by a
27102678
// constant power-of-2 value:
27112679
// (X % pow2C) sgt/slt 0
2680+
const ICmpInst::Predicate Pred = Cmp.getPredicate();
27122681
if (Pred != ICmpInst::ICMP_SGT && Pred != ICmpInst::ICMP_SLT &&
27132682
Pred != ICmpInst::ICMP_EQ && Pred != ICmpInst::ICMP_NE)
27142683
return nullptr;
@@ -3164,7 +3133,10 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp,
31643133

31653134
if (ICmpInst::isUnsigned(Pred) && Add->hasNoSignedWrap() &&
31663135
C.isNonNegative() && (C - *C2).isNonNegative() &&
3167-
computeConstantRange(X, /*ForSigned=*/true).add(*C2).isAllNonNegative())
3136+
computeConstantRange(X, /*ForSigned=*/true, /*UseInstrInfo=*/true, &AC,
3137+
Add, &DT)
3138+
.add(*C2)
3139+
.isAllNonNegative())
31683140
return new ICmpInst(ICmpInst::getSignedPredicate(Pred), X,
31693141
ConstantInt::get(Ty, C - *C2));
31703142

@@ -7056,6 +7028,66 @@ static Instruction *canonicalizeICmpBool(ICmpInst &I,
70567028
}
70577029
}
70587030

7031+
// (icmp X, Y) --> (icmp slt/sgt X, 0/-1) iff Y is outside the signed range of X
7032+
static ICmpInst *canonicalizeSignBitCheck(ICmpInst::Predicate Pred, Value *X,
7033+
const ConstantRange &XRange,
7034+
const ConstantRange &YRange) {
7035+
if (XRange.isSignWrappedSet())
7036+
return nullptr;
7037+
unsigned BitWidth = XRange.getBitWidth();
7038+
APInt SMin = APInt::getSignedMinValue(BitWidth);
7039+
APInt Zero = APInt::getZero(BitWidth);
7040+
auto NegResult =
7041+
XRange.intersectWith(ConstantRange(SMin, Zero), ConstantRange::Signed)
7042+
.icmpOrInverse(Pred, YRange);
7043+
if (!NegResult)
7044+
return nullptr;
7045+
auto PosResult =
7046+
XRange.intersectWith(ConstantRange(Zero, SMin), ConstantRange::Signed)
7047+
.icmpOrInverse(Pred, YRange);
7048+
if (!PosResult)
7049+
return nullptr;
7050+
assert(NegResult != PosResult &&
7051+
"Known result should been simplified already.");
7052+
Type *Ty = X->getType();
7053+
if (*NegResult)
7054+
return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantInt::getNullValue(Ty));
7055+
return new ICmpInst(ICmpInst::ICMP_SGT, X, ConstantInt::getAllOnesValue(Ty));
7056+
}
7057+
7058+
// Try to fold an icmp using the constant ranges of its operands.
7059+
Instruction *InstCombinerImpl::foldICmpUsingConstantRanges(ICmpInst &Cmp) {
7060+
Value *X = Cmp.getOperand(0);
7061+
if (!X->getType()->isIntOrIntVectorTy())
7062+
return nullptr;
7063+
Value *Y = Cmp.getOperand(1);
7064+
ICmpInst::Predicate Pred = Cmp.getPredicate();
7065+
ConstantRange XRange = computeConstantRange(
7066+
X, ICmpInst::isSigned(Pred), /*UseInstrInfo=*/true, &AC, &Cmp, &DT);
7067+
if (XRange.isFullSet())
7068+
return nullptr; // early out if we don't have any information
7069+
ConstantRange YRange = computeConstantRange(
7070+
Y, ICmpInst::isSigned(Pred), /*UseInstrInfo=*/true, &AC, &Cmp, &DT);
7071+
if (YRange.isFullSet())
7072+
return nullptr; // early out if we don't have any information
7073+
if (auto Res = XRange.icmpOrInverse(Pred, YRange))
7074+
return replaceInstUsesWith(Cmp, ConstantInt::getBool(Cmp.getType(), *Res));
7075+
if (ICmpInst::isUnsigned(Pred)) {
7076+
// Check if this icmp is actually a sign bit check.
7077+
const APInt *C;
7078+
bool IgnoreTrueIfSigned;
7079+
if (!match(Y, m_APInt(C)) ||
7080+
!isSignBitCheck(Pred, *C, IgnoreTrueIfSigned)) {
7081+
if (ICmpInst *Res = canonicalizeSignBitCheck(Pred, X, XRange, YRange))
7082+
return Res;
7083+
if (ICmpInst *Res = canonicalizeSignBitCheck(
7084+
ICmpInst::getSwappedPredicate(Pred), Y, YRange, XRange))
7085+
return Res;
7086+
}
7087+
}
7088+
return nullptr;
7089+
}
7090+
70597091
// Transform pattern like:
70607092
// (1 << Y) u<= X or ~(-1 << Y) u< X or ((1 << Y)+(-1)) u< X
70617093
// (1 << Y) u> X or ~(-1 << Y) u>= X or ((1 << Y)+(-1)) u>= X
@@ -7428,6 +7460,9 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
74287460
if (Instruction *Res = canonicalizeICmpPredicate(I))
74297461
return Res;
74307462

7463+
if (Instruction *Res = foldICmpUsingConstantRanges(I))
7464+
return Res;
7465+
74317466
if (Instruction *Res = foldICmpWithConstant(I))
74327467
return Res;
74337468

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

+1
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
668668
Instruction *foldICmpWithCastOp(ICmpInst &ICmp);
669669
Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp);
670670

671+
Instruction *foldICmpUsingConstantRanges(ICmpInst &Cmp);
671672
Instruction *foldICmpUsingKnownBits(ICmpInst &Cmp);
672673
Instruction *foldICmpWithDominatingICmp(ICmpInst &Cmp);
673674
Instruction *foldICmpWithConstant(ICmpInst &Cmp);

llvm/test/Transforms/InstCombine/icmp-dom.ll

+3-3
Original file line numberDiff line numberDiff line change
@@ -381,11 +381,11 @@ falselabel:
381381

382382
define i8 @PR48900_alt(i8 %i, ptr %p) {
383383
; CHECK-LABEL: @PR48900_alt(
384-
; CHECK-NEXT: [[SMAX:%.*]] = call i8 @llvm.smax.i8(i8 [[I:%.*]], i8 -127)
385-
; CHECK-NEXT: [[I4:%.*]] = icmp ugt i8 [[SMAX]], -128
384+
; CHECK-NEXT: [[I4:%.*]] = icmp slt i8 [[I:%.*]], 0
386385
; CHECK-NEXT: br i1 [[I4]], label [[TRUELABEL:%.*]], label [[FALSELABEL:%.*]]
387386
; CHECK: truelabel:
388-
; CHECK-NEXT: [[UMIN:%.*]] = call i8 @llvm.smin.i8(i8 [[SMAX]], i8 -126)
387+
; CHECK-NEXT: [[TMP1:%.*]] = icmp slt i8 [[I]], -126
388+
; CHECK-NEXT: [[UMIN:%.*]] = select i1 [[TMP1]], i8 -127, i8 -126
389389
; CHECK-NEXT: ret i8 [[UMIN]]
390390
; CHECK: falselabel:
391391
; CHECK-NEXT: ret i8 0

llvm/test/Transforms/InstCombine/icmp-srem.ll

+6-18
Original file line numberDiff line numberDiff line change
@@ -398,9 +398,7 @@ define i1 @icmp_slt_srem_pos_range(i32 %x, i32 range(i32 999, -2147483648) %y) {
398398
define i1 @icmp_sle_srem_pos_range(i32 %x, i32 range(i32 999, -2147483648) %y) {
399399
; CHECK-LABEL: define i1 @icmp_sle_srem_pos_range(
400400
; CHECK-SAME: i32 [[X:%.*]], i32 range(i32 999, -2147483648) [[Y:%.*]]) {
401-
; CHECK-NEXT: [[R:%.*]] = srem i32 [[X]], 1000
402-
; CHECK-NEXT: [[C:%.*]] = icmp sle i32 [[R]], [[Y]]
403-
; CHECK-NEXT: ret i1 [[C]]
401+
; CHECK-NEXT: ret i1 true
404402
;
405403
%r = srem i32 %x, 1000
406404
%c = icmp sle i32 %r, %y
@@ -410,9 +408,7 @@ define i1 @icmp_sle_srem_pos_range(i32 %x, i32 range(i32 999, -2147483648) %y) {
410408
define i1 @icmp_sgt_srem_pos_range(i32 %x, i32 range(i32 999, -2147483648) %y) {
411409
; CHECK-LABEL: define i1 @icmp_sgt_srem_pos_range(
412410
; CHECK-SAME: i32 [[X:%.*]], i32 range(i32 999, -2147483648) [[Y:%.*]]) {
413-
; CHECK-NEXT: [[R:%.*]] = srem i32 [[X]], 1000
414-
; CHECK-NEXT: [[C:%.*]] = icmp sgt i32 [[R]], [[Y]]
415-
; CHECK-NEXT: ret i1 [[C]]
411+
; CHECK-NEXT: ret i1 false
416412
;
417413
%r = srem i32 %x, 1000
418414
%c = icmp sgt i32 %r, %y
@@ -434,9 +430,7 @@ define i1 @icmp_sge_srem_pos_range(i32 %x, i32 range(i32 999, -2147483648) %y) {
434430
define i1 @icmp_slt_srem_neg_range(i32 %x, i32 range(i32 -2147483648, -999) %y) {
435431
; CHECK-LABEL: define i1 @icmp_slt_srem_neg_range(
436432
; CHECK-SAME: i32 [[X:%.*]], i32 range(i32 -2147483648, -999) [[Y:%.*]]) {
437-
; CHECK-NEXT: [[R:%.*]] = srem i32 [[X]], 1000
438-
; CHECK-NEXT: [[C:%.*]] = icmp slt i32 [[R]], [[Y]]
439-
; CHECK-NEXT: ret i1 [[C]]
433+
; CHECK-NEXT: ret i1 false
440434
;
441435
%r = srem i32 %x, 1000
442436
%c = icmp slt i32 %r, %y
@@ -446,9 +440,7 @@ define i1 @icmp_slt_srem_neg_range(i32 %x, i32 range(i32 -2147483648, -999) %y)
446440
define i1 @icmp_sle_srem_neg_range(i32 %x, i32 range(i32 -2147483648, -999) %y) {
447441
; CHECK-LABEL: define i1 @icmp_sle_srem_neg_range(
448442
; CHECK-SAME: i32 [[X:%.*]], i32 range(i32 -2147483648, -999) [[Y:%.*]]) {
449-
; CHECK-NEXT: [[R:%.*]] = srem i32 [[X]], 1000
450-
; CHECK-NEXT: [[C:%.*]] = icmp sle i32 [[R]], [[Y]]
451-
; CHECK-NEXT: ret i1 [[C]]
443+
; CHECK-NEXT: ret i1 false
452444
;
453445
%r = srem i32 %x, 1000
454446
%c = icmp sle i32 %r, %y
@@ -458,9 +450,7 @@ define i1 @icmp_sle_srem_neg_range(i32 %x, i32 range(i32 -2147483648, -999) %y)
458450
define i1 @icmp_sgt_srem_neg_range(i32 %x, i32 range(i32 -2147483648, -999) %y) {
459451
; CHECK-LABEL: define i1 @icmp_sgt_srem_neg_range(
460452
; CHECK-SAME: i32 [[X:%.*]], i32 range(i32 -2147483648, -999) [[Y:%.*]]) {
461-
; CHECK-NEXT: [[R:%.*]] = srem i32 [[X]], 1000
462-
; CHECK-NEXT: [[C:%.*]] = icmp sgt i32 [[R]], [[Y]]
463-
; CHECK-NEXT: ret i1 [[C]]
453+
; CHECK-NEXT: ret i1 true
464454
;
465455
%r = srem i32 %x, 1000
466456
%c = icmp sgt i32 %r, %y
@@ -470,9 +460,7 @@ define i1 @icmp_sgt_srem_neg_range(i32 %x, i32 range(i32 -2147483648, -999) %y)
470460
define i1 @icmp_sge_srem_neg_range(i32 %x, i32 range(i32 -2147483648, -999) %y) {
471461
; CHECK-LABEL: define i1 @icmp_sge_srem_neg_range(
472462
; CHECK-SAME: i32 [[X:%.*]], i32 range(i32 -2147483648, -999) [[Y:%.*]]) {
473-
; CHECK-NEXT: [[R:%.*]] = srem i32 [[X]], 1000
474-
; CHECK-NEXT: [[C:%.*]] = icmp sge i32 [[R]], [[Y]]
475-
; CHECK-NEXT: ret i1 [[C]]
463+
; CHECK-NEXT: ret i1 true
476464
;
477465
%r = srem i32 %x, 1000
478466
%c = icmp sge i32 %r, %y

llvm/test/Transforms/InstCombine/smin-icmp.ll

+6-7
Original file line numberDiff line numberDiff line change
@@ -965,9 +965,9 @@ define void @eq_smin_v2i32_constant(<2 x i32> %y) {
965965
; CHECK-NEXT: call void @use_v2i1(<2 x i1> [[CMP4]])
966966
; CHECK-NEXT: [[CMP5:%.*]] = icmp ult <2 x i32> [[COND]], splat (i32 10)
967967
; CHECK-NEXT: call void @use_v2i1(<2 x i1> [[CMP5]])
968-
; CHECK-NEXT: [[CMP6:%.*]] = icmp ult <2 x i32> [[COND]], splat (i32 11)
968+
; CHECK-NEXT: [[CMP6:%.*]] = icmp sgt <2 x i32> [[Y]], splat (i32 -1)
969969
; CHECK-NEXT: call void @use_v2i1(<2 x i1> [[CMP6]])
970-
; CHECK-NEXT: [[CMP7:%.*]] = icmp ugt <2 x i32> [[COND]], splat (i32 10)
970+
; CHECK-NEXT: [[CMP7:%.*]] = icmp slt <2 x i32> [[Y]], zeroinitializer
971971
; CHECK-NEXT: call void @use_v2i1(<2 x i1> [[CMP7]])
972972
; CHECK-NEXT: [[CMP8:%.*]] = icmp ugt <2 x i32> [[COND]], splat (i32 9)
973973
; CHECK-NEXT: call void @use_v2i1(<2 x i1> [[CMP8]])
@@ -1004,18 +1004,17 @@ define void @eq_smin_v2i32_constant(<2 x i32> %y) {
10041004
; icmp pred smin(C1, Y), C2 where C1 < C2
10051005
define void @slt_smin_v2i32_constant(<2 x i32> %y) {
10061006
; CHECK-LABEL: @slt_smin_v2i32_constant(
1007-
; CHECK-NEXT: [[COND:%.*]] = call <2 x i32> @llvm.smin.v2i32(<2 x i32> [[Y:%.*]], <2 x i32> splat (i32 5))
10081007
; CHECK-NEXT: call void @use_v2i1(<2 x i1> splat (i1 true))
10091008
; CHECK-NEXT: call void @use_v2i1(<2 x i1> splat (i1 true))
10101009
; CHECK-NEXT: call void @use_v2i1(<2 x i1> zeroinitializer)
10111010
; CHECK-NEXT: call void @use_v2i1(<2 x i1> zeroinitializer)
1012-
; CHECK-NEXT: [[CMP5:%.*]] = icmp ult <2 x i32> [[COND]], splat (i32 10)
1011+
; CHECK-NEXT: [[CMP5:%.*]] = icmp sgt <2 x i32> [[Y:%.*]], splat (i32 -1)
10131012
; CHECK-NEXT: call void @use_v2i1(<2 x i1> [[CMP5]])
1014-
; CHECK-NEXT: [[CMP6:%.*]] = icmp ult <2 x i32> [[COND]], splat (i32 11)
1013+
; CHECK-NEXT: [[CMP6:%.*]] = icmp sgt <2 x i32> [[Y]], splat (i32 -1)
10151014
; CHECK-NEXT: call void @use_v2i1(<2 x i1> [[CMP6]])
1016-
; CHECK-NEXT: [[CMP7:%.*]] = icmp ugt <2 x i32> [[COND]], splat (i32 10)
1015+
; CHECK-NEXT: [[CMP7:%.*]] = icmp slt <2 x i32> [[Y]], zeroinitializer
10171016
; CHECK-NEXT: call void @use_v2i1(<2 x i1> [[CMP7]])
1018-
; CHECK-NEXT: [[CMP8:%.*]] = icmp ugt <2 x i32> [[COND]], splat (i32 9)
1017+
; CHECK-NEXT: [[CMP8:%.*]] = icmp slt <2 x i32> [[Y]], zeroinitializer
10191018
; CHECK-NEXT: call void @use_v2i1(<2 x i1> [[CMP8]])
10201019
; CHECK-NEXT: call void @use_v2i1(<2 x i1> zeroinitializer)
10211020
; CHECK-NEXT: call void @use_v2i1(<2 x i1> splat (i1 true))

llvm/unittests/IR/ConstantRangeTest.cpp

+12-1
Original file line numberDiff line numberDiff line change
@@ -1694,12 +1694,23 @@ void ICmpTestImpl(CmpInst::Predicate Pred) {
16941694
EnumerateTwoInterestingConstantRanges(
16951695
[&](const ConstantRange &CR1, const ConstantRange &CR2) {
16961696
bool Exhaustive = true;
1697+
bool ExhaustiveInverse = true;
16971698
ForeachNumInConstantRange(CR1, [&](const APInt &N1) {
16981699
ForeachNumInConstantRange(CR2, [&](const APInt &N2) {
1699-
Exhaustive &= ICmpInst::compare(N1, N2, Pred);
1700+
bool Res = ICmpInst::compare(N1, N2, Pred);
1701+
Exhaustive &= Res;
1702+
ExhaustiveInverse &= !Res;
17001703
});
17011704
});
1705+
1706+
std::optional<bool> ExhaustiveOrInverse;
1707+
if (Exhaustive) // Expect true if Exhaustive && ExhaustiveInverse.
1708+
ExhaustiveOrInverse = true;
1709+
else if (ExhaustiveInverse)
1710+
ExhaustiveOrInverse = false;
1711+
17021712
EXPECT_EQ(CR1.icmp(Pred, CR2), Exhaustive);
1713+
EXPECT_EQ(CR1.icmpOrInverse(Pred, CR2), ExhaustiveOrInverse);
17031714
});
17041715
}
17051716

0 commit comments

Comments
 (0)