[InstCombine] canonicalize sign bit checks #122962

Closed
jacobly0 wants to merge 4 commits

Conversation

jacobly0
Contributor

This is a generalization of #122520 to other instructions with constrained result ranges.

Not sure whether this generalization is desirable, but I thought I would try it out.

This allows more signed floor division implementations to be optimized to an
arithmetic shift when the divisor is a known power of two.

Proof for the implemented optimizations:
https://alive2.llvm.org/ce/z/j6C-Nz

Proof for the test cases:
https://alive2.llvm.org/ce/z/M_PBjw
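
For readers who want to see the floor division equivalence outside of IR, here is a minimal C++ sketch of the pattern exercised by the `floor_sdiv_using_srem_by_8` test. The function names are hypothetical and exist only for illustration; the sketch spot-checks a handful of inputs and is not a substitute for the Alive2 proofs above.

```cpp
#include <cassert>
#include <cstdint>

// Mirrors the test IR: d = sdiv x, 8; r = srem x, 8;
// i = icmp ugt r, INT32_MIN; f = add d, (sext i).
// Since r lies in [-7, 7], "r u> 0x80000000" is true exactly when r < 0.
int32_t floorDiv8ViaSrem(int32_t x) {
  int32_t d = x / 8; // truncating division, like sdiv
  int32_t r = x % 8; // remainder takes the sign of x, like srem
  bool isNeg = static_cast<uint32_t>(r) > 0x80000000u;
  return d + (isNeg ? -1 : 0); // sext i1 true is -1
}

// Arithmetic shift right by 3, written so that it does not depend on how
// the compiler shifts negative values.
int32_t ashr3(int32_t x) {
  return x >= 0 ? (x >> 3) : ~((~x) >> 3);
}

// Compares both forms on a few inputs, including the signed extremes.
void checkSamples() {
  const int32_t samples[] = {INT32_MIN, INT32_MIN + 1, -17, -16, -9, -8,
                             -7,        -1,            0,   1,   7,  8,
                             9,         16,            INT32_MAX};
  for (int32_t x : samples)
    assert(floorDiv8ViaSrem(x) == ashr3(x));
}
```

Calling `checkSamples()` from a test driver should pass silently if the two forms agree on the listed inputs.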
@llvmbot
Member

llvmbot commented Jan 14, 2025

@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-llvm-analysis

Author: Jacob Young (jacobly0)

Changes

This is a generalization of #122520 to other instructions with constrained result ranges.

Not sure whether this generalization is desirable, but I thought I would try it out.


Patch is 27.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122962.diff

10 Files Affected:

  • (modified) llvm/include/llvm/IR/ConstantRange.h (+6)
  • (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+3-7)
  • (modified) llvm/lib/IR/ConstantRange.cpp (+10)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+66-1)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+8)
  • (modified) llvm/test/Transforms/InstCombine/add.ll (+26)
  • (modified) llvm/test/Transforms/InstCombine/icmp-dom.ll (+3-3)
  • (added) llvm/test/Transforms/InstCombine/icmp-srem.ll (+468)
  • (modified) llvm/test/Transforms/InstCombine/smin-icmp.ll (+6-7)
  • (modified) llvm/unittests/IR/ConstantRangeTest.cpp (+12-1)
diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h
index d086c25390fd22..92e9341352c177 100644
--- a/llvm/include/llvm/IR/ConstantRange.h
+++ b/llvm/include/llvm/IR/ConstantRange.h
@@ -128,6 +128,12 @@ class [[nodiscard]] ConstantRange {
   /// NOTE: false does not mean that inverse predicate holds!
   bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const;
 
+  /// Does the predicate \p Pred or its inverse hold between ranges this and \p
+  /// Other? Returns `true` if the predicate always holds, `false` if the
+  /// inverse always holds, or `std::nullopt` otherwise.
+  std::optional<bool> icmpOrInverse(CmpInst::Predicate Pred,
+                                    const ConstantRange &Other) const;
+
   /// Return true iff CR1 ult CR2 is equivalent to CR1 slt CR2.
   /// Does not depend on strictness/direction of the predicate.
   static bool
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index d69747e30f884d..dae4f37908cf2c 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -3783,13 +3783,9 @@ static Value *simplifyICmpInst(CmpPredicate Pred, Value *LHS, Value *RHS,
   // If both operands have range metadata, use the metadata
   // to simplify the comparison.
   if (std::optional<ConstantRange> RhsCr = getRange(RHS, Q.IIQ))
-    if (std::optional<ConstantRange> LhsCr = getRange(LHS, Q.IIQ)) {
-      if (LhsCr->icmp(Pred, *RhsCr))
-        return ConstantInt::getTrue(ITy);
-
-      if (LhsCr->icmp(CmpInst::getInversePredicate(Pred), *RhsCr))
-        return ConstantInt::getFalse(ITy);
-    }
+    if (std::optional<ConstantRange> LhsCr = getRange(LHS, Q.IIQ))
+      if (auto Res = LhsCr->icmpOrInverse(Pred, *RhsCr))
+        return ConstantInt::getBool(ITy, *Res);
 
   // Compare of cast, for example (zext X) != 0 -> X != 0
   if (isa<CastInst>(LHS) && (isa<Constant>(RHS) || isa<CastInst>(RHS))) {
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 35664353989929..c5380ab2ffeda4 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -274,6 +274,16 @@ bool ConstantRange::icmp(CmpInst::Predicate Pred,
   }
 }
 
+std::optional<bool>
+ConstantRange::icmpOrInverse(CmpInst::Predicate Pred,
+                             const ConstantRange &Other) const {
+  if (icmp(Pred, Other))
+    return true;
+  if (icmp(CmpInst::getInversePredicate(Pred), Other))
+    return false;
+  return std::nullopt;
+}
+
 /// Exact mul nuw region for single element RHS.
 static ConstantRange makeExactMulNUWRegion(const APInt &V) {
   unsigned BitWidth = V.getBitWidth();
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 2e457257599493..f45868ba283a1a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -3133,7 +3133,9 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp,
 
   if (ICmpInst::isUnsigned(Pred) && Add->hasNoSignedWrap() &&
       C.isNonNegative() && (C - *C2).isNonNegative() &&
-      computeConstantRange(X, /*ForSigned=*/true).add(*C2).isAllNonNegative())
+      computeConstantRange(X, Add, /*ForSigned=*/true)
+          .add(*C2)
+          .isAllNonNegative())
     return new ICmpInst(ICmpInst::getSignedPredicate(Pred), X,
                         ConstantInt::get(Ty, C - *C2));
 
@@ -7025,6 +7027,66 @@ static Instruction *canonicalizeICmpBool(ICmpInst &I,
   }
 }
 
+// (icmp X, Y) --> (icmp slt/sgt X, 0/-1) iff Y is outside the signed range of X
+static ICmpInst *canonicalizeSignBitCheck(ICmpInst::Predicate Pred, Value *X,
+                                          const ConstantRange &XRange,
+                                          const ConstantRange &YRange) {
+  if (XRange.isSignWrappedSet())
+    return nullptr;
+  unsigned BitWidth = XRange.getBitWidth();
+  APInt SMin = APInt::getSignedMinValue(BitWidth);
+  APInt Zero = APInt::getZero(BitWidth);
+  auto NegResult =
+      XRange.intersectWith(ConstantRange(SMin, Zero), ConstantRange::Signed)
+          .icmpOrInverse(Pred, YRange);
+  if (!NegResult)
+    return nullptr;
+  auto PosResult =
+      XRange.intersectWith(ConstantRange(Zero, SMin), ConstantRange::Signed)
+          .icmpOrInverse(Pred, YRange);
+  if (!PosResult)
+    return nullptr;
+  assert(NegResult != PosResult &&
+         "Known result should have been simplified already.");
+  Type *Ty = X->getType();
+  if (*NegResult)
+    return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantInt::getNullValue(Ty));
+  return new ICmpInst(ICmpInst::ICMP_SGT, X, ConstantInt::getAllOnesValue(Ty));
+}
+
+// Try to fold an icmp using the constant ranges of its operands.
+Instruction *InstCombinerImpl::foldICmpUsingConstantRanges(ICmpInst &Cmp) {
+  Value *X = Cmp.getOperand(0);
+  if (!X->getType()->isIntOrIntVectorTy())
+    return nullptr;
+  Value *Y = Cmp.getOperand(1);
+  ICmpInst::Predicate Pred = Cmp.getPredicate();
+  ConstantRange XRange =
+      computeConstantRange(X, &Cmp, ICmpInst::isSigned(Pred));
+  if (XRange.isFullSet())
+    return nullptr; // early out if we don't have any information
+  ConstantRange YRange =
+      computeConstantRange(Y, &Cmp, ICmpInst::isSigned(Pred));
+  if (YRange.isFullSet())
+    return nullptr; // early out if we don't have any information
+  if (auto Res = XRange.icmpOrInverse(Pred, YRange))
+    return replaceInstUsesWith(Cmp, ConstantInt::getBool(Cmp.getType(), *Res));
+  if (ICmpInst::isUnsigned(Pred)) {
+    // Check if this icmp is actually a sign bit check.
+    const APInt *C;
+    bool IgnoreTrueIfSigned;
+    if (!match(Y, m_APInt(C)) ||
+        !isSignBitCheck(Pred, *C, IgnoreTrueIfSigned)) {
+      if (ICmpInst *Res = canonicalizeSignBitCheck(Pred, X, XRange, YRange))
+        return Res;
+      if (ICmpInst *Res = canonicalizeSignBitCheck(
+              ICmpInst::getSwappedPredicate(Pred), Y, YRange, XRange))
+        return Res;
+    }
+  }
+  return nullptr;
+}
+
 // Transform pattern like:
 //   (1 << Y) u<= X  or  ~(-1 << Y) u<  X  or  ((1 << Y)+(-1)) u<  X
 //   (1 << Y) u>  X  or  ~(-1 << Y) u>= X  or  ((1 << Y)+(-1)) u>= X
@@ -7397,6 +7459,9 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
   if (Instruction *Res = canonicalizeICmpPredicate(I))
     return Res;
 
+  if (Instruction *Res = foldICmpUsingConstantRanges(I))
+    return Res;
+
   if (Instruction *Res = foldICmpWithConstant(I))
     return Res;
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 83e1da98deeda0..9f60dcf59ae467 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -240,6 +240,13 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   convertOrOfShiftsToFunnelShift(Instruction &Or);
 
 private:
+  ConstantRange computeConstantRange(const Value *V, const Instruction *CtxI,
+                                     bool ForSigned, bool UseInstrInfo = true,
+                                     unsigned Depth = 0) {
+    return llvm::computeConstantRange(V, ForSigned, UseInstrInfo, &AC, CtxI,
+                                      &DT, Depth);
+  }
+
   bool annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI);
   bool isDesirableIntType(unsigned BitWidth) const;
   bool shouldChangeType(unsigned FromBitWidth, unsigned ToBitWidth) const;
@@ -668,6 +675,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Instruction *foldICmpWithCastOp(ICmpInst &ICmp);
   Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp);
 
+  Instruction *foldICmpUsingConstantRanges(ICmpInst &Cmp);
   Instruction *foldICmpUsingKnownBits(ICmpInst &Cmp);
   Instruction *foldICmpWithDominatingICmp(ICmpInst &Cmp);
   Instruction *foldICmpWithConstant(ICmpInst &Cmp);
diff --git a/llvm/test/Transforms/InstCombine/add.ll b/llvm/test/Transforms/InstCombine/add.ll
index 222f87fa3a5f18..495f99824652d6 100644
--- a/llvm/test/Transforms/InstCombine/add.ll
+++ b/llvm/test/Transforms/InstCombine/add.ll
@@ -3018,6 +3018,32 @@ define i32 @floor_sdiv_wrong_op(i32 %x, i32 %y) {
   ret i32 %r
 }
 
+define i32 @floor_sdiv_using_srem_by_8(i32 %x) {
+; CHECK-LABEL: @floor_sdiv_using_srem_by_8(
+; CHECK-NEXT:    [[F:%.*]] = ashr i32 [[X:%.*]], 3
+; CHECK-NEXT:    ret i32 [[F]]
+;
+  %d = sdiv i32 %x, 8
+  %r = srem i32 %x, 8
+  %i = icmp ugt i32 %r, -2147483648
+  %s = sext i1 %i to i32
+  %f = add i32 %d, %s
+  ret i32 %f
+}
+
+define i32 @floor_sdiv_using_srem_by_2(i32 %x) {
+; CHECK-LABEL: @floor_sdiv_using_srem_by_2(
+; CHECK-NEXT:    [[F:%.*]] = ashr i32 [[X:%.*]], 1
+; CHECK-NEXT:    ret i32 [[F]]
+;
+  %d = sdiv i32 %x, 2
+  %r = srem i32 %x, 2
+  %i = icmp ugt i32 %r, -2147483648
+  %s = sext i1 %i to i32
+  %f = add i32 %d, %s
+  ret i32 %f
+}
+
 ; (X s>> (BW - 1)) + (zext (X s> 0)) --> (X s>> (BW - 1)) | (zext (X != 0))
 
 define i8 @signum_i8_i8(i8 %x) {
diff --git a/llvm/test/Transforms/InstCombine/icmp-dom.ll b/llvm/test/Transforms/InstCombine/icmp-dom.ll
index 3cf3a7af77041c..66e9e514a9022a 100644
--- a/llvm/test/Transforms/InstCombine/icmp-dom.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-dom.ll
@@ -381,11 +381,11 @@ falselabel:
 
 define i8 @PR48900_alt(i8 %i, ptr %p) {
 ; CHECK-LABEL: @PR48900_alt(
-; CHECK-NEXT:    [[SMAX:%.*]] = call i8 @llvm.smax.i8(i8 [[I:%.*]], i8 -127)
-; CHECK-NEXT:    [[I4:%.*]] = icmp ugt i8 [[SMAX]], -128
+; CHECK-NEXT:    [[I4:%.*]] = icmp slt i8 [[I:%.*]], 0
 ; CHECK-NEXT:    br i1 [[I4]], label [[TRUELABEL:%.*]], label [[FALSELABEL:%.*]]
 ; CHECK:       truelabel:
-; CHECK-NEXT:    [[UMIN:%.*]] = call i8 @llvm.smin.i8(i8 [[SMAX]], i8 -126)
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt i8 [[I]], -126
+; CHECK-NEXT:    [[UMIN:%.*]] = select i1 [[TMP1]], i8 -127, i8 -126
 ; CHECK-NEXT:    ret i8 [[UMIN]]
 ; CHECK:       falselabel:
 ; CHECK-NEXT:    ret i8 0
diff --git a/llvm/test/Transforms/InstCombine/icmp-srem.ll b/llvm/test/Transforms/InstCombine/icmp-srem.ll
new file mode 100644
index 00000000000000..9ab92f15ae7d2d
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/icmp-srem.ll
@@ -0,0 +1,468 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define i1 @icmp_ugt_sremsmin_smin(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_sremsmin_smin(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i32 [[X]], -2147483648
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, -2147483648
+  %c = icmp ugt i32 %r, -2147483648
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_sremsmin_sminp1(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_sremsmin_sminp1(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], -2147483648
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i32 [[R]], -2147483647
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, -2147483648
+  %c = icmp ugt i32 %r, -2147483647
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_sremsmin_smaxm1(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_sremsmin_smaxm1(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], -2147483648
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i32 [[R]], 2147483646
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, -2147483648
+  %c = icmp ugt i32 %r, 2147483646
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_sremsmin_smax(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_sremsmin_smax(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i32 [[X]], -2147483648
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, -2147483648
+  %c = icmp ugt i32 %r, 2147483647
+  ret i1 %c
+}
+
+define i1 @icmp_ult_sremsmin_smin(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_sremsmin_smin(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], -2147483648
+; CHECK-NEXT:    [[C:%.*]] = icmp sgt i32 [[R]], -1
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, -2147483648
+  %c = icmp ult i32 %r, -2147483648
+  ret i1 %c
+}
+
+define i1 @icmp_ult_sremsmin_sminp1(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_sremsmin_sminp1(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], -2147483648
+; CHECK-NEXT:    [[C:%.*]] = icmp sgt i32 [[R]], -1
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, -2147483648
+  %c = icmp ult i32 %r, -2147483647
+  ret i1 %c
+}
+
+define i1 @icmp_ult_sremsmin_sminp2(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_sremsmin_sminp2(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], -2147483648
+; CHECK-NEXT:    [[C:%.*]] = icmp ult i32 [[R]], -2147483646
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, -2147483648
+  %c = icmp ult i32 %r, -2147483646
+  ret i1 %c
+}
+
+define i1 @icmp_ult_sremsmin_smax(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_sremsmin_smax(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], -2147483648
+; CHECK-NEXT:    [[C:%.*]] = icmp ult i32 [[R]], 2147483647
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, -2147483648
+  %c = icmp ult i32 %r, 2147483647
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_srem5_smin(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_srem5_smin(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 5
+; CHECK-NEXT:    [[C:%.*]] = icmp slt i32 [[R]], 0
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 5
+  %c = icmp ugt i32 %r, -2147483648
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_srem5_m5(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_srem5_m5(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 5
+; CHECK-NEXT:    [[C:%.*]] = icmp slt i32 [[R]], 0
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 5
+  %c = icmp ugt i32 %r, -5
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_srem5_m4(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_srem5_m4(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 5
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i32 [[R]], -4
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 5
+  %c = icmp ugt i32 %r, -4
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_srem5_3(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_srem5_3(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 5
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i32 [[R]], 3
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 5
+  %c = icmp ugt i32 %r, 3
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_srem5_4(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_srem5_4(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 5
+; CHECK-NEXT:    [[C:%.*]] = icmp slt i32 [[R]], 0
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 5
+  %c = icmp ugt i32 %r, 4
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_srem5_smaxm1(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_srem5_smaxm1(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 5
+; CHECK-NEXT:    [[C:%.*]] = icmp slt i32 [[R]], 0
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 5
+  %c = icmp ugt i32 %r, 2147483646
+  ret i1 %c
+}
+
+define i1 @icmp_ult_srem5_sminp1(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_srem5_sminp1(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 5
+; CHECK-NEXT:    [[C:%.*]] = icmp sgt i32 [[R]], -1
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 5
+  %c = icmp ult i32 %r, -2147483647
+  ret i1 %c
+}
+
+define i1 @icmp_ult_srem5_m4(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_srem5_m4(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 5
+; CHECK-NEXT:    [[C:%.*]] = icmp sgt i32 [[R]], -1
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 5
+  %c = icmp ult i32 %r, -4
+  ret i1 %c
+}
+
+define i1 @icmp_ult_srem5_m3(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_srem5_m3(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 5
+; CHECK-NEXT:    [[C:%.*]] = icmp ult i32 [[R]], -3
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 5
+  %c = icmp ult i32 %r, -3
+  ret i1 %c
+}
+
+define i1 @icmp_ult_srem5_4(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_srem5_4(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 5
+; CHECK-NEXT:    [[C:%.*]] = icmp ult i32 [[R]], 4
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 5
+  %c = icmp ult i32 %r, 4
+  ret i1 %c
+}
+
+define i1 @icmp_ult_srem5_5(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_srem5_5(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 5
+; CHECK-NEXT:    [[C:%.*]] = icmp sgt i32 [[R]], -1
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 5
+  %c = icmp ult i32 %r, 5
+  ret i1 %c
+}
+
+define i1 @icmp_ult_srem5_smax(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_srem5_smax(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 5
+; CHECK-NEXT:    [[C:%.*]] = icmp sgt i32 [[R]], -1
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 5
+  %c = icmp ult i32 %r, 2147483647
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_sremsmax_smin(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_sremsmax_smin(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 2147483647
+; CHECK-NEXT:    [[C:%.*]] = icmp slt i32 [[R]], 0
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 2147483647
+  %c = icmp ugt i32 %r, -2147483648
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_sremsmax_sminp1(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_sremsmax_sminp1(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 2147483647
+; CHECK-NEXT:    [[C:%.*]] = icmp slt i32 [[R]], 0
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 2147483647
+  %c = icmp ugt i32 %r, -2147483647
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_sremsmax_sminp2(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_sremsmax_sminp2(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 2147483647
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i32 [[R]], -2147483646
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 2147483647
+  %c = icmp ugt i32 %r, -2147483646
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_sremsmax_smaxm2(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_sremsmax_smaxm2(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 2147483647
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i32 [[R]], 2147483645
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 2147483647
+  %c = icmp ugt i32 %r, 2147483645
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_sremsmax_smaxm1(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_sremsmax_smaxm1(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 2147483647
+; CHECK-NEXT:    [[C:%.*]] = icmp slt i32 [[R]], 0
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 2147483647
+  %c = icmp ugt i32 %r, 2147483646
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_sremsmax_smax(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_sremsmax_smax(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 2147483647
+; CHECK-NEXT:    [[C:%.*]] = icmp slt i32 [[R]], 0
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 2147483647
+  %c = icmp ugt i32 %r, 2147483647
+  ret i1 %c
+}
+
+define i1 @icmp_ult_sremsmax_smin(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_sremsmax_smin(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 2147483647
+; CHECK-NEXT:    [[C:%.*]] = icmp sgt i32 [[R]], -1
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 2147483647
+  %c = icmp ult i32 %r, -2147483648
+  ret i1 %c
+}
+
+define i1 @icmp_ult_sremsmax_sminp1(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_sremsmax_sminp1(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 2147483647
+; CHECK-NEXT:    [[C:%.*]] = icmp sgt i32 [[R]], -1
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 2147483647
+  %c = icmp ult i32 %r, -2147483647
+  ret i1 %c
+}
+
+define i1 @icmp_ult_sremsmax_sminp2(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_sremsmax_sminp2(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 2147483647
+; CHECK-NEXT:    [[C:%.*]] = icmp sgt i32 [[R]], -1
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 2147483647
+  %c = icmp ult i32 %r, -2147483646
+  ret i1 %c
+}
+
+define i1 @icmp_ult_sremsmax...
[truncated]
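
The sign-split idea behind `canonicalizeSignBitCheck` can be modeled without any LLVM types. The sketch below is a brute-force illustration, not the patch's implementation: `Pred`, `SignCheck`, `holdsOrInverse`, and `classify` are hypothetical names, the range is enumerated value by value instead of using `ConstantRange`, and the result range of `srem i32 %x, 5` is assumed to be [-4, 4].

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

enum class Pred { UGT, ULT };
enum class SignCheck { None, IsNegative, IsNonNegative };

// Evaluates an unsigned comparison on the bit patterns of v and c.
static bool evalUnsigned(Pred p, int32_t v, int32_t c) {
  uint32_t uv = static_cast<uint32_t>(v), uc = static_cast<uint32_t>(c);
  return p == Pred::UGT ? uv > uc : uv < uc;
}

// Tri-state result in the spirit of icmpOrInverse: true if the predicate
// holds for every v in [lo, hi], false if it holds for none, nullopt if mixed.
std::optional<bool> holdsOrInverse(Pred p, int32_t lo, int32_t hi, int32_t c) {
  bool any = false, all = true;
  for (int64_t v = lo; v <= hi; ++v) {
    bool h = evalUnsigned(p, static_cast<int32_t>(v), c);
    any = any || h;
    all = all && h;
  }
  if (all)
    return true;
  if (!any)
    return false;
  return std::nullopt;
}

// Splits [lo, hi] into its negative and non-negative halves. If the
// comparison is decided on each half and the two answers differ, the
// comparison is equivalent to a sign check on the value.
SignCheck classify(Pred p, int32_t lo, int32_t hi, int32_t c) {
  assert(lo < 0 && hi >= 0 && "sketch only handles ranges straddling zero");
  std::optional<bool> neg = holdsOrInverse(p, lo, -1, c);
  std::optional<bool> pos = holdsOrInverse(p, 0, hi, c);
  if (!neg || !pos || *neg == *pos)
    return SignCheck::None;
  return *neg ? SignCheck::IsNegative : SignCheck::IsNonNegative;
}
```

Under these assumptions, `classify(Pred::UGT, -4, 4, 4)` yields `IsNegative`, which lines up with `icmp_ugt_srem5_4` becoming `icmp slt i32 [[R]], 0`, while `classify(Pred::UGT, -4, 4, -4)` yields `None`, matching the unchanged `icmp_ugt_srem5_m4` test.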

@llvmbot
Member

llvmbot commented Jan 14, 2025

@llvm/pr-subscribers-llvm-transforms

+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define i1 @icmp_ugt_sremsmin_smin(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_sremsmin_smin(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i32 [[X]], -2147483648
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, -2147483648
+  %c = icmp ugt i32 %r, -2147483648
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_sremsmin_sminp1(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_sremsmin_sminp1(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], -2147483648
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i32 [[R]], -2147483647
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, -2147483648
+  %c = icmp ugt i32 %r, -2147483647
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_sremsmin_smaxm1(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_sremsmin_smaxm1(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], -2147483648
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i32 [[R]], 2147483646
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, -2147483648
+  %c = icmp ugt i32 %r, 2147483646
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_sremsmin_smax(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_sremsmin_smax(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i32 [[X]], -2147483648
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, -2147483648
+  %c = icmp ugt i32 %r, 2147483647
+  ret i1 %c
+}
+
+define i1 @icmp_ult_sremsmin_smin(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_sremsmin_smin(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], -2147483648
+; CHECK-NEXT:    [[C:%.*]] = icmp sgt i32 [[R]], -1
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, -2147483648
+  %c = icmp ult i32 %r, -2147483648
+  ret i1 %c
+}
+
+define i1 @icmp_ult_sremsmin_sminp1(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_sremsmin_sminp1(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], -2147483648
+; CHECK-NEXT:    [[C:%.*]] = icmp sgt i32 [[R]], -1
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, -2147483648
+  %c = icmp ult i32 %r, -2147483647
+  ret i1 %c
+}
+
+define i1 @icmp_ult_sremsmin_sminp2(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_sremsmin_sminp2(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], -2147483648
+; CHECK-NEXT:    [[C:%.*]] = icmp ult i32 [[R]], -2147483646
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, -2147483648
+  %c = icmp ult i32 %r, -2147483646
+  ret i1 %c
+}
+
+define i1 @icmp_ult_sremsmin_smax(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_sremsmin_smax(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], -2147483648
+; CHECK-NEXT:    [[C:%.*]] = icmp ult i32 [[R]], 2147483647
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, -2147483648
+  %c = icmp ult i32 %r, 2147483647
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_srem5_smin(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_srem5_smin(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 5
+; CHECK-NEXT:    [[C:%.*]] = icmp slt i32 [[R]], 0
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 5
+  %c = icmp ugt i32 %r, -2147483648
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_srem5_m5(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_srem5_m5(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 5
+; CHECK-NEXT:    [[C:%.*]] = icmp slt i32 [[R]], 0
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 5
+  %c = icmp ugt i32 %r, -5
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_srem5_m4(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_srem5_m4(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 5
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i32 [[R]], -4
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 5
+  %c = icmp ugt i32 %r, -4
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_srem5_3(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_srem5_3(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 5
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i32 [[R]], 3
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 5
+  %c = icmp ugt i32 %r, 3
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_srem5_4(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_srem5_4(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 5
+; CHECK-NEXT:    [[C:%.*]] = icmp slt i32 [[R]], 0
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 5
+  %c = icmp ugt i32 %r, 4
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_srem5_smaxm1(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_srem5_smaxm1(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 5
+; CHECK-NEXT:    [[C:%.*]] = icmp slt i32 [[R]], 0
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 5
+  %c = icmp ugt i32 %r, 2147483646
+  ret i1 %c
+}
+
+define i1 @icmp_ult_srem5_sminp1(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_srem5_sminp1(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 5
+; CHECK-NEXT:    [[C:%.*]] = icmp sgt i32 [[R]], -1
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 5
+  %c = icmp ult i32 %r, -2147483647
+  ret i1 %c
+}
+
+define i1 @icmp_ult_srem5_m4(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_srem5_m4(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 5
+; CHECK-NEXT:    [[C:%.*]] = icmp sgt i32 [[R]], -1
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 5
+  %c = icmp ult i32 %r, -4
+  ret i1 %c
+}
+
+define i1 @icmp_ult_srem5_m3(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_srem5_m3(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 5
+; CHECK-NEXT:    [[C:%.*]] = icmp ult i32 [[R]], -3
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 5
+  %c = icmp ult i32 %r, -3
+  ret i1 %c
+}
+
+define i1 @icmp_ult_srem5_4(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_srem5_4(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 5
+; CHECK-NEXT:    [[C:%.*]] = icmp ult i32 [[R]], 4
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 5
+  %c = icmp ult i32 %r, 4
+  ret i1 %c
+}
+
+define i1 @icmp_ult_srem5_5(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_srem5_5(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 5
+; CHECK-NEXT:    [[C:%.*]] = icmp sgt i32 [[R]], -1
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 5
+  %c = icmp ult i32 %r, 5
+  ret i1 %c
+}
+
+define i1 @icmp_ult_srem5_smax(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_srem5_smax(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 5
+; CHECK-NEXT:    [[C:%.*]] = icmp sgt i32 [[R]], -1
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 5
+  %c = icmp ult i32 %r, 2147483647
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_sremsmax_smin(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_sremsmax_smin(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 2147483647
+; CHECK-NEXT:    [[C:%.*]] = icmp slt i32 [[R]], 0
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 2147483647
+  %c = icmp ugt i32 %r, -2147483648
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_sremsmax_sminp1(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_sremsmax_sminp1(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 2147483647
+; CHECK-NEXT:    [[C:%.*]] = icmp slt i32 [[R]], 0
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 2147483647
+  %c = icmp ugt i32 %r, -2147483647
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_sremsmax_sminp2(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_sremsmax_sminp2(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 2147483647
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i32 [[R]], -2147483646
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 2147483647
+  %c = icmp ugt i32 %r, -2147483646
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_sremsmax_smaxm2(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_sremsmax_smaxm2(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 2147483647
+; CHECK-NEXT:    [[C:%.*]] = icmp ugt i32 [[R]], 2147483645
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 2147483647
+  %c = icmp ugt i32 %r, 2147483645
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_sremsmax_smaxm1(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_sremsmax_smaxm1(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 2147483647
+; CHECK-NEXT:    [[C:%.*]] = icmp slt i32 [[R]], 0
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 2147483647
+  %c = icmp ugt i32 %r, 2147483646
+  ret i1 %c
+}
+
+define i1 @icmp_ugt_sremsmax_smax(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ugt_sremsmax_smax(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 2147483647
+; CHECK-NEXT:    [[C:%.*]] = icmp slt i32 [[R]], 0
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 2147483647
+  %c = icmp ugt i32 %r, 2147483647
+  ret i1 %c
+}
+
+define i1 @icmp_ult_sremsmax_smin(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_sremsmax_smin(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 2147483647
+; CHECK-NEXT:    [[C:%.*]] = icmp sgt i32 [[R]], -1
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 2147483647
+  %c = icmp ult i32 %r, -2147483648
+  ret i1 %c
+}
+
+define i1 @icmp_ult_sremsmax_sminp1(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_sremsmax_sminp1(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 2147483647
+; CHECK-NEXT:    [[C:%.*]] = icmp sgt i32 [[R]], -1
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 2147483647
+  %c = icmp ult i32 %r, -2147483647
+  ret i1 %c
+}
+
+define i1 @icmp_ult_sremsmax_sminp2(i32 %x) {
+; CHECK-LABEL: define i1 @icmp_ult_sremsmax_sminp2(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[R:%.*]] = srem i32 [[X]], 2147483647
+; CHECK-NEXT:    [[C:%.*]] = icmp sgt i32 [[R]], -1
+; CHECK-NEXT:    ret i1 [[C]]
+;
+  %r = srem i32 %x, 2147483647
+  %c = icmp ult i32 %r, -2147483646
+  ret i1 %c
+}
+
+define i1 @icmp_ult_sremsmax...
[truncated]

⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Discourse for more information.

@dtcxzyw dtcxzyw requested review from dtcxzyw and goldsteinn January 15, 2025 07:36
; CHECK-NEXT: call void @use_v2i1(<2 x i1> [[CMP5]])
; CHECK-NEXT: [[CMP6:%.*]] = icmp ult <2 x i32> [[COND]], splat (i32 11)
Contributor

These two cases don't look profitable.

Contributor Author

To be clear, this is icmp ult (smin %Y, 5), 10 -> icmp sgt %Y, -1. Is there a reason that removing an operation wouldn't be considered profitable here?
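
The equivalence claimed here can be checked exhaustively on a narrow type. The following is a minimal sketch, assuming i8 instead of the i32 used in the tests; the helper names `originalCompare`, `foldedCompare`, `countMismatches`, and `checkFold` are hypothetical and not part of the patch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical helper: icmp ult (smin Y, 5), 10 evaluated on i8.
bool originalCompare(int8_t y) {
  int8_t m = std::min<int8_t>(y, 5);    // smin %Y, 5
  return static_cast<uint8_t>(m) < 10;  // icmp ult ..., 10 (unsigned view)
}

// Hypothetical helper: the folded form, icmp sgt Y, -1.
bool foldedCompare(int8_t y) { return y > -1; }

// Counts the i8 inputs on which the two forms disagree.
int countMismatches() {
  int mismatches = 0;
  for (int v = -128; v <= 127; ++v) {
    int8_t y = static_cast<int8_t>(v);
    if (originalCompare(y) != foldedCompare(y))
      ++mismatches;
  }
  return mismatches;
}

// The fold is valid on i8 if no input disagrees.
void checkFold() { assert(countMismatches() == 0); }
```

Because smin caps the value at 5, the unsigned comparison against 10 can only fail for negative inputs, which is why the sign check alone suffices.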

if (*NegResult)
return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantInt::getNullValue(Ty));
return new ICmpInst(ICmpInst::ICMP_SGT, X, ConstantInt::getAllOnesValue(Ty));
}
Contributor

At least IMO, it would be more useful to maximally constrain the icmp than to just build a sign check. I.e., if you have icmp ugt X, 10 and can actually narrow it to icmp ugt X, 129, that's better than icmp slt X, 0.

Contributor Author

@jacobly0 jacobly0 Jan 15, 2025

I'm trying to use a form that other combines will recognize as a sign check, such as icmp slt X, 0 or icmp ugt X, smax. I'm also not clear which side is considered maximally constrained here. For example if -5 <= X <= 5 then icmp ugt X, -6 <=> icmp ugt X, smax <=> icmp slt X, 0 <=> icmp ugt X, 5. I'm choosing the signed comparison, which is exactly equivalent to comparing to smax, because sign checks can be cheaper and have special combines written for them.
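
The four-way equivalence in this comment can be sketched as a brute-force check. This is an illustration only, assuming i8 (so smax is 127) and the stated range -5 <= X <= 5; the function names `allFormsAgree` and `formsAgreeOnRange` are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// True when all four i8 spellings of the check agree for the value x.
bool allFormsAgree(int8_t x) {
  uint8_t ux = static_cast<uint8_t>(x);
  bool ugtM6 = ux > static_cast<uint8_t>(-6);  // icmp ugt X, -6
  bool ugtSMax = ux > 127;                     // icmp ugt X, smax
  bool sltZero = x < 0;                        // icmp slt X, 0
  bool ugt5 = ux > 5;                          // icmp ugt X, 5
  return ugtM6 == sltZero && ugtSMax == sltZero && ugt5 == sltZero;
}

// Checks every x in the assumed source range -5 <= x <= 5.
bool formsAgreeOnRange() {
  for (int v = -5; v <= 5; ++v)
    if (!allFormsAgree(static_cast<int8_t>(v)))
      return false;
  return true;
}
```

Outside the assumed range the forms diverge (for example at X = 6 or X = -6), so the rewrite is only sound when the range of X is actually known.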

Contributor Author

I do think it would be possible to update the downstream combines to handle a different canonical form. My main worry is losing out in the backend on various shifting tricks and simpler sign checks that don't require loading a potentially large constant and which are localized to only needing to check one bit for large integers. I'm sure this could be mitigated by moving the range checks and sign bit check preference into the backends that care, but it certainly does not appear to currently be the case.

Under the assumption that the issue is that more information could be gained by moving the comparison constant to one of the extremes, I'm still unclear on how to choose one side of the range as more beneficial. It seems like some information is lost on each side, by reducing the known range on either a true or false branch, which is only gained back by recomputing ranges. At that point, the choice of constant does not seem to affect the information you have, because the source range is always split into the same two ranges by any choice of constant.

If you want to limit it to only changing the constant of an existing icmp and not changing the signedness of the predicate, then that seems perfectly reasonable and potentially simple to implement. If you have a specific definition of canonical other than preferring strict sign bit checks, I could also try to see if I could make it work.

Contributor

I think it depends heavily on the use. I can see the case where, if you are branching on the cmp, maximally constraining it in one direction leaves us with the least useful region on the other side. But in other use cases, where only the true (or false) value matters, we would almost certainly want maximal constraint. I.e., (select (icmp ult X, 4), f(X), C) is almost always better than (select (icmp sge X, 0), f(X), C).

I guess I can see sign-compare as a general middle ground, although I think this fold is liable to cause regressions in cases where it relaxes the constraints on a cmp that dominates some important expressions.

Truthfully, I'm not sure what the correct approach here is. I guess in some ideal case I would want some cost function to decide which direction of the cmp is liable to benefit most from being maximally constrained, although that might be overkill, especially in InstCombine.
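
The trade-off described above can be made concrete with a small sketch. It assumes i8 and an assumed source range of -3 <= X <= 3 (for instance, the result of an srem by 4); the names `tightCompare`, `signCompare`, `countTrue`, and `agreeOnAssumedRange` are hypothetical. The two compares agree on that range, but a later analysis that looks only at the compare learns far less from the sign check.

```cpp
#include <cassert>
#include <cstdint>

bool tightCompare(int8_t x) { return static_cast<uint8_t>(x) < 4; }  // icmp ult X, 4
bool signCompare(int8_t x) { return x >= 0; }                        // icmp sge X, 0

// Number of i8 values for which pred holds, ignoring any range facts about X.
template <typename Pred> int countTrue(Pred pred) {
  int n = 0;
  for (int v = -128; v <= 127; ++v)
    if (pred(static_cast<int8_t>(v)))
      ++n;
  return n;
}

// Whether the two compares agree on the assumed source range [-3, 3].
bool agreeOnAssumedRange() {
  for (int v = -3; v <= 3; ++v) {
    int8_t x = static_cast<int8_t>(v);
    if (tightCompare(x) != signCompare(x))
      return false;
  }
  return true;
}
```

Taken on its own, the tight compare admits 4 values on the true side while the sign compare admits 128, which is the information a dominated use such as f(X) would lose unless the range is recomputed.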

Contributor Author

Yep, that totally makes sense now, and I don't really have a better solution to propose at the moment.

Do you have a suggestion for how to proceed with my original add (sdiv X, C), (sext i1 (icmp ugt (srem X, C), smin)) => ashr X, log2(C) use case in the short term, in a way that won't need to be reverted? Does my original PR, which only applies to srem, seem restricted enough to avoid major regressions (especially given the relative rareness of srem), or should I revert to something like my initial code, which only applied where the full original foldICmpSRemConstant is applicable, just slightly improved based on what I have learned from these generalizations?
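
The use case mentioned here, a signed floor division by a power of two, can be sketched in plain C++. This is an illustration rather than the patch's logic: the names `floorViaSRem`, `floorViaAShr`, and `agreeForPowerOfTwo` are hypothetical, and the sketch assumes that `>>` on a negative `int32_t` is an arithmetic shift (guaranteed from C++20, and the behavior of mainstream compilers before that).

```cpp
#include <cassert>
#include <cstdint>

// add (sdiv X, C), (sext i1 (icmp ugt (srem X, C), smin)) on i32.
int32_t floorViaSRem(int32_t x, int32_t c) {
  int32_t d = x / c;                                  // sdiv truncates toward zero
  int32_t r = x % c;                                  // srem takes the sign of x
  bool neg = static_cast<uint32_t>(r) > 0x80000000u;  // icmp ugt r, smin
  return d + (neg ? -1 : 0);                          // sext i1 yields 0 or -1
}

// ashr X, log2(C); assumes >> is an arithmetic shift for negative x.
int32_t floorViaAShr(int32_t x, unsigned log2c) { return x >> log2c; }

// Compares both forms for C = 2^log2c on boundary values and a small range.
bool agreeForPowerOfTwo(unsigned log2c) {
  const int32_t c = int32_t{1} << log2c;
  const int32_t edges[] = {INT32_MIN, INT32_MIN + 1, INT32_MAX - 1, INT32_MAX};
  for (int32_t x : edges)
    if (floorViaSRem(x, c) != floorViaAShr(x, log2c))
      return false;
  for (int32_t x = -1000; x <= 1000; ++x)
    if (floorViaSRem(x, c) != floorViaAShr(x, log2c))
      return false;
  return true;
}
```

For a positive divisor the remainder lies strictly between -C and C, so it can never equal smin, and the unsigned compare against smin is true exactly when the remainder is negative. That is the sign check the canonicalization tries to expose.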

Contributor

Are you talking about #122520? If so, I think that can go in without further changes.

This is a generalization of llvm#122520 to other instructions with
constrained result ranges.
@jacobly0
Contributor Author

Closing in favor of #122520 due to a concern about regressions caused by the canonicalization.

@jacobly0 jacobly0 closed this Jan 18, 2025
@jacobly0 jacobly0 deleted the canon-sign branch January 18, 2025 22:08