Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[InstCombine] canonicalize sign bit checks #122962

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions llvm/include/llvm/IR/ConstantRange.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ class [[nodiscard]] ConstantRange {
/// NOTE: false does not mean that inverse predicate holds!
bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const;

/// Does the predicate \p Pred or its inverse hold between ranges this and \p
/// Other? Returns `true` if the predicate always holds, `false` if the
/// inverse always holds, or `std::nullopt` otherwise.
/// \p Pred is tested before its inverse, so if icmp() reports that both
/// hold, the result is `true`.
/// NOTE: unlike icmp(), a `false` result here does mean that the inverse
/// predicate holds.
std::optional<bool> icmpOrInverse(CmpInst::Predicate Pred,
const ConstantRange &Other) const;

/// Return true iff CR1 ult CR2 is equivalent to CR1 slt CR2.
/// Does not depend on strictness/direction of the predicate.
static bool
Expand Down
10 changes: 3 additions & 7 deletions llvm/lib/Analysis/InstructionSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3783,13 +3783,9 @@ static Value *simplifyICmpInst(CmpPredicate Pred, Value *LHS, Value *RHS,
// If both operands have range metadata, use the metadata
// to simplify the comparison.
if (std::optional<ConstantRange> RhsCr = getRange(RHS, Q.IIQ))
if (std::optional<ConstantRange> LhsCr = getRange(LHS, Q.IIQ)) {
if (LhsCr->icmp(Pred, *RhsCr))
return ConstantInt::getTrue(ITy);

if (LhsCr->icmp(CmpInst::getInversePredicate(Pred), *RhsCr))
return ConstantInt::getFalse(ITy);
}
if (std::optional<ConstantRange> LhsCr = getRange(LHS, Q.IIQ))
if (auto Res = LhsCr->icmpOrInverse(Pred, *RhsCr))
return ConstantInt::getBool(ITy, *Res);

// Compare of cast, for example (zext X) != 0 -> X != 0
if (isa<CastInst>(LHS) && (isa<Constant>(RHS) || isa<CastInst>(RHS))) {
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/IR/ConstantRange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,16 @@ bool ConstantRange::icmp(CmpInst::Predicate Pred,
}
}

/// Three-way range comparison: `true` when \p Pred is known to hold between
/// this range and \p Other, `false` when the inverse predicate is known to
/// hold, and `std::nullopt` when neither can be established.
std::optional<bool>
ConstantRange::icmpOrInverse(CmpInst::Predicate Pred,
                             const ConstantRange &Other) const {
  const bool PredHolds = icmp(Pred, Other);
  // The inverse is only queried when the predicate itself is not known to
  // hold; if neither direction is known there is nothing to report.
  if (!PredHolds && !icmp(CmpInst::getInversePredicate(Pred), Other))
    return std::nullopt;
  return PredHolds;
}

/// Exact mul nuw region for single element RHS.
static ConstantRange makeExactMulNUWRegion(const APInt &V) {
unsigned BitWidth = V.getBitWidth();
Expand Down
68 changes: 67 additions & 1 deletion llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3133,7 +3133,10 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp,

if (ICmpInst::isUnsigned(Pred) && Add->hasNoSignedWrap() &&
C.isNonNegative() && (C - *C2).isNonNegative() &&
computeConstantRange(X, /*ForSigned=*/true).add(*C2).isAllNonNegative())
computeConstantRange(X, /*ForSigned=*/true, /*UseInstrInfo=*/true, &AC,
Add, &DT)
.add(*C2)
.isAllNonNegative())
return new ICmpInst(ICmpInst::getSignedPredicate(Pred), X,
ConstantInt::get(Ty, C - *C2));

Expand Down Expand Up @@ -7025,6 +7028,66 @@ static Instruction *canonicalizeICmpBool(ICmpInst &I,
}
}

// (icmp X, Y) --> (icmp slt/sgt X, 0/-1) iff Y is outside the signed range of X
//
// XRange is split at zero into its negative part [SMin, 0) and its
// non-negative part [0, SMin). If Pred against YRange has a known result on
// each part, and the two results differ, the comparison only depends on the
// sign of X and is rewritten as a sign bit test.
//
// Precondition: the caller has already checked that Pred has no known result
// over the whole of XRange (see the assertion below).
//
// Returns a new, not yet inserted ICmpInst, or nullptr if the comparison is
// not a sign bit check in disguise.
static ICmpInst *canonicalizeSignBitCheck(ICmpInst::Predicate Pred, Value *X,
                                          const ConstantRange &XRange,
                                          const ConstantRange &YRange) {
  // Ranges that wrap around the signed min/max boundary are not handled.
  if (XRange.isSignWrappedSet())
    return nullptr;
  unsigned BitWidth = XRange.getBitWidth();
  APInt SMin = APInt::getSignedMinValue(BitWidth);
  APInt Zero = APInt::getZero(BitWidth);
  // Result of the comparison restricted to the negative values of X.
  auto NegResult =
      XRange.intersectWith(ConstantRange(SMin, Zero), ConstantRange::Signed)
          .icmpOrInverse(Pred, YRange);
  if (!NegResult)
    return nullptr;
  // Result of the comparison restricted to the non-negative values of X.
  auto PosResult =
      XRange.intersectWith(ConstantRange(Zero, SMin), ConstantRange::Signed)
          .icmpOrInverse(Pred, YRange);
  if (!PosResult)
    return nullptr;
  // Identical results on both halves would make the whole comparison a known
  // constant, which the caller is expected to have folded before calling.
  assert(NegResult != PosResult &&
         "Known result should have been simplified already.");
  Type *Ty = X->getType();
  // True exactly for negative X: X s< 0.
  if (*NegResult)
    return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantInt::getNullValue(Ty));
  // True exactly for non-negative X: X s> -1.
  return new ICmpInst(ICmpInst::ICMP_SGT, X, ConstantInt::getAllOnesValue(Ty));
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At least IMO, more useful than just building a sign bit check would be to maximally constrain the icmp. I.e., if you have icmp ugt X, 10 and can actually narrow it to icmp ugt X, 129, that's better than icmp slt X, 0.

Copy link
Contributor Author

@jacobly0 jacobly0 Jan 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm trying to use a form that other combines will recognize as a sign check, such as icmp slt X, 0 or icmp ugt X, smax. I'm also not clear which side is considered maximally constrained here. For example if -5 <= X <= 5 then icmp ugt X, -6 <=> icmp ugt X, smax <=> icmp slt X, 0 <=> icmp ugt X, 5. I'm choosing the signed comparison, which is exactly equivalent to comparing to smax, because sign checks can be cheaper and have special combines written for them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do think it would be possible to update the downstream combines to handle a different canonical form. My main worry is losing out in the backend on various shifting tricks and simpler sign checks that don't require loading a potentially large constant and which are localized to only needing to check one bit for large integers. I'm sure this could be mitigated by moving the range checks and sign bit check preference into the backends that care, but it certainly does not appear to currently be the case.

Under the assumption that the issue is that more information could be gained by moving the comparison constant to one of the extremes, I'm still unclear on how to choose one side of the range as more beneficial. It seems like some information is lost on each side, by reducing the known range on either a true or false branch, which is only gained back by recomputing ranges. At that point, the choice of constant does not seem to affect the information you have, because the source range is always split into the same two ranges by any choice of constant.

If you want to limit it to only changing the constant of an existing icmp and not changing the signedness of the predicate, then that seems perfectly reasonable and potentially simple to implement. If you have a specific definition of canonical other than preferring strict sign bit checks, I could also try to see if I could make it work.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it depends heavily on the use. I can see the case where if you are branching on the cmp then maximally constraining it in one direction leaves us with the least useful region on the other side. But in other use-cases where only the true (or false) value matters we would almost certainly want maximal constraint. I.e. (select (icmp ult X, 4), f(X), C) is almost always better than (select (icmp sge X, 0), f(X), C).

I guess I can see sign-compare is a general middle-ground. Although I think this fold is liable to cause regressions in cases that it relaxes the constraints on a cmp that dominates some important expressions.

Truthfully not sure what the correct approach here is. I guess in some ideal case I would think of some cost-function to decide which direction of the cmp is liable to benefit most from being maximally constrained, although that might be overkill, especially in InstCombine.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, that totally makes sense now, and I don't really have a better solution to propose at the moment.

Do you have a suggestion for how to proceed with my original add (sdiv X, C), sext i1 (icmp ugt (srem X, C), smin) => ashr X, log2(C) use case in the short term, and in a way that won't need to be reverted? Does my original PR, which only applies to srem, seem restricted enough to avoid major regressions (especially given the relative rareness of srem), or should I revert back to something like my initial code which only applied where the full original foldICmpSRemConstant is applicable, just slightly improved based on what I have learned from these generalizations?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you talking about #122520? If so, I think that can go in without further changes.


// Try to fold an icmp using the constant ranges of its operands.
//
// If the operand ranges decide the comparison (or its inverse), the icmp is
// replaced by the corresponding boolean constant. Otherwise, an unsigned
// comparison that is not already a sign bit check against a constant is
// handed to canonicalizeSignBitCheck, first as written and then with the
// operands swapped.
Instruction *InstCombinerImpl::foldICmpUsingConstantRanges(ICmpInst &Cmp) {
  Value *LHS = Cmp.getOperand(0);
  if (!LHS->getType()->isIntOrIntVectorTy())
    return nullptr;
  Value *RHS = Cmp.getOperand(1);
  const ICmpInst::Predicate Pred = Cmp.getPredicate();
  const bool ForSigned = ICmpInst::isSigned(Pred);

  // A full range carries no information, so give up as soon as either
  // operand turns out to be unconstrained.
  ConstantRange LHSRange = computeConstantRange(
      LHS, ForSigned, /*UseInstrInfo=*/true, &AC, &Cmp, &DT);
  if (LHSRange.isFullSet())
    return nullptr;
  ConstantRange RHSRange = computeConstantRange(
      RHS, ForSigned, /*UseInstrInfo=*/true, &AC, &Cmp, &DT);
  if (RHSRange.isFullSet())
    return nullptr;

  // The ranges alone decide the comparison.
  if (auto Known = LHSRange.icmpOrInverse(Pred, RHSRange))
    return replaceInstUsesWith(Cmp,
                               ConstantInt::getBool(Cmp.getType(), *Known));

  // Only unsigned comparisons are candidates for the sign bit rewrite.
  if (!ICmpInst::isUnsigned(Pred))
    return nullptr;

  // Leave comparisons alone that already are a sign bit check against a
  // constant.
  const APInt *RHSC;
  bool TrueIfSigned;
  if (match(RHS, m_APInt(RHSC)) && isSignBitCheck(Pred, *RHSC, TrueIfSigned))
    return nullptr;

  if (ICmpInst *SignCheck =
          canonicalizeSignBitCheck(Pred, LHS, LHSRange, RHSRange))
    return SignCheck;
  return canonicalizeSignBitCheck(ICmpInst::getSwappedPredicate(Pred), RHS,
                                  RHSRange, LHSRange);
}

// Transform pattern like:
// (1 << Y) u<= X or ~(-1 << Y) u< X or ((1 << Y)+(-1)) u< X
// (1 << Y) u> X or ~(-1 << Y) u>= X or ((1 << Y)+(-1)) u>= X
Expand Down Expand Up @@ -7397,6 +7460,9 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
if (Instruction *Res = canonicalizeICmpPredicate(I))
return Res;

if (Instruction *Res = foldICmpUsingConstantRanges(I))
return Res;

if (Instruction *Res = foldICmpWithConstant(I))
return Res;

Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
Instruction *foldICmpWithCastOp(ICmpInst &ICmp);
Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp);

Instruction *foldICmpUsingConstantRanges(ICmpInst &Cmp);
Instruction *foldICmpUsingKnownBits(ICmpInst &Cmp);
Instruction *foldICmpWithDominatingICmp(ICmpInst &Cmp);
Instruction *foldICmpWithConstant(ICmpInst &Cmp);
Expand Down
26 changes: 26 additions & 0 deletions llvm/test/Transforms/InstCombine/add.ll
Original file line number Diff line number Diff line change
Expand Up @@ -3018,6 +3018,32 @@ define i32 @floor_sdiv_wrong_op(i32 %x, i32 %y) {
ret i32 %r
}

; Floor division by 8 built from sdiv/srem: sdiv rounds toward zero, so -1
; (the sign-extended i1) is added whenever the remainder is negative. The
; unsigned compare against INT_MIN is true exactly for negative remainders
; here, because srem by 8 can never produce INT_MIN itself. The whole
; sequence is expected to become an arithmetic shift right by 3.
define i32 @floor_sdiv_using_srem_by_8(i32 %x) {
; CHECK-LABEL: @floor_sdiv_using_srem_by_8(
; CHECK-NEXT: [[F:%.*]] = ashr i32 [[X:%.*]], 3
; CHECK-NEXT: ret i32 [[F]]
;
%d = sdiv i32 %x, 8
%r = srem i32 %x, 8
%i = icmp ugt i32 %r, -2147483648
%s = sext i1 %i to i32
%f = add i32 %d, %s
ret i32 %f
}

; Floor division by 2 built from sdiv/srem: -1 (the sign-extended i1) is
; added to the truncating quotient whenever the remainder is negative. The
; unsigned compare against INT_MIN is true exactly for negative remainders
; here, because srem by 2 can never produce INT_MIN itself. The whole
; sequence is expected to become an arithmetic shift right by 1.
define i32 @floor_sdiv_using_srem_by_2(i32 %x) {
; CHECK-LABEL: @floor_sdiv_using_srem_by_2(
; CHECK-NEXT: [[F:%.*]] = ashr i32 [[X:%.*]], 1
; CHECK-NEXT: ret i32 [[F]]
;
%d = sdiv i32 %x, 2
%r = srem i32 %x, 2
%i = icmp ugt i32 %r, -2147483648
%s = sext i1 %i to i32
%f = add i32 %d, %s
ret i32 %f
}

; (X s>> (BW - 1)) + (zext (X s> 0)) --> (X s>> (BW - 1)) | (zext (X != 0))

define i8 @signum_i8_i8(i8 %x) {
Expand Down
6 changes: 3 additions & 3 deletions llvm/test/Transforms/InstCombine/icmp-dom.ll
Original file line number Diff line number Diff line change
Expand Up @@ -381,11 +381,11 @@ falselabel:

define i8 @PR48900_alt(i8 %i, ptr %p) {
; CHECK-LABEL: @PR48900_alt(
; CHECK-NEXT: [[SMAX:%.*]] = call i8 @llvm.smax.i8(i8 [[I:%.*]], i8 -127)
; CHECK-NEXT: [[I4:%.*]] = icmp ugt i8 [[SMAX]], -128
; CHECK-NEXT: [[I4:%.*]] = icmp slt i8 [[I:%.*]], 0
; CHECK-NEXT: br i1 [[I4]], label [[TRUELABEL:%.*]], label [[FALSELABEL:%.*]]
; CHECK: truelabel:
; CHECK-NEXT: [[UMIN:%.*]] = call i8 @llvm.smin.i8(i8 [[SMAX]], i8 -126)
; CHECK-NEXT: [[TMP1:%.*]] = icmp slt i8 [[I]], -126
; CHECK-NEXT: [[UMIN:%.*]] = select i1 [[TMP1]], i8 -127, i8 -126
; CHECK-NEXT: ret i8 [[UMIN]]
; CHECK: falselabel:
; CHECK-NEXT: ret i8 0
Expand Down
Loading
Loading