diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp index 3599428c5ff41..5c0d1dd1c74b0 100644 --- a/llvm/lib/Analysis/CmpInstAnalysis.cpp +++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp @@ -168,6 +168,7 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred, std::optional llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) { + using namespace PatternMatch; if (auto *ICmp = dyn_cast(Cond)) { // Don't allow pointers. Splat vectors are fine. if (!ICmp->getOperand(0)->getType()->isIntOrIntVectorTy()) @@ -176,6 +177,19 @@ llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) { ICmp->getPredicate(), LookThruTrunc, AllowNonZeroC); } + Value *X; + if (Cond->getType()->isIntOrIntVectorTy(1) && + (match(Cond, m_Trunc(m_Value(X))) || + match(Cond, m_Not(m_Trunc(m_Value(X)))))) { + DecomposedBitTest Result; + Result.X = X; + unsigned BitWidth = X->getType()->getScalarSizeInBits(); + Result.Mask = APInt(BitWidth, 1); + Result.C = APInt::getZero(BitWidth); + Result.Pred = isa(Cond) ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ; + + return Result; + } return std::nullopt; } diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index d69747e30f884..1facf56937f24 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -4612,12 +4612,11 @@ static Value *simplifyCmpSelOfMaxMin(Value *CmpLHS, Value *CmpRHS, return nullptr; } -/// An alternative way to test if a bit is set or not uses sgt/slt instead of -/// eq/ne. -static Value *simplifySelectWithFakeICmpEq(Value *CmpLHS, Value *CmpRHS, - CmpPredicate Pred, Value *TrueVal, - Value *FalseVal) { - if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred)) +/// An alternative way to test if a bit is set or not. +/// uses e.g. sgt/slt or trunc instead of eq/ne. +static Value *simplifySelectWithBitTest(Value *CondVal, Value *TrueVal, + Value *FalseVal) { + if (auto Res = decomposeBitTest(CondVal)) return simplifySelectBitTest(TrueVal, FalseVal, Res->X, &Res->Mask, Res->Pred == ICmpInst::ICMP_EQ); @@ -4728,11 +4727,6 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, return FalseVal; } - // Check for other compares that behave like bit test. - if (Value *V = - simplifySelectWithFakeICmpEq(CmpLHS, CmpRHS, Pred, TrueVal, FalseVal)) - return V; - // If we have a scalar equality comparison, then we know the value in one of // the arms of the select. See if substituting this value into the arm and // simplifying the result yields the same value as the other arm. @@ -4984,6 +4978,9 @@ static Value *simplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, simplifySelectWithICmpCond(Cond, TrueVal, FalseVal, Q, MaxRecurse)) return V; + if (Value *V = simplifySelectWithBitTest(Cond, TrueVal, FalseVal)) + return V; + if (Value *V = simplifySelectWithFCmp(Cond, TrueVal, FalseVal, Q, MaxRecurse)) return V; diff --git a/llvm/test/Transforms/InstSimplify/select.ll b/llvm/test/Transforms/InstSimplify/select.ll index 40539b8ade388..1b5703a46cf68 100644 --- a/llvm/test/Transforms/InstSimplify/select.ll +++ b/llvm/test/Transforms/InstSimplify/select.ll @@ -1752,10 +1752,7 @@ define <4 x i32> @select_vector_cmp_with_bitcasts(<2 x i64> %x, <4 x i32> %y) { define i8 @bittest_trunc_or(i8 %x) { ; CHECK-LABEL: @bittest_trunc_or( -; CHECK-NEXT: [[TRUNC:%.*]] = trunc i8 [[X1:%.*]] to i1 -; CHECK-NEXT: [[OR:%.*]] = or i8 [[X1]], 1 -; CHECK-NEXT: [[X:%.*]] = select i1 [[TRUNC]], i8 [[OR]], i8 [[X1]] -; CHECK-NEXT: ret i8 [[X]] +; CHECK-NEXT: ret i8 [[X:%.*]] ; %trunc = trunc i8 %x to i1 %or = or i8 %x, 1 @@ -1765,11 +1762,8 @@ define i8 @bittest_trunc_or(i8 %x) { define i8 @bittest_trunc_not_or(i8 %x) { ; CHECK-LABEL: @bittest_trunc_not_or( -; CHECK-NEXT: [[TRUNC:%.*]] = trunc i8 [[X:%.*]] to i1 -; CHECK-NEXT: [[NOT:%.*]] = xor i1 [[TRUNC]], true -; CHECK-NEXT: [[OR:%.*]] = or i8 [[X]], 1 -; CHECK-NEXT: [[COND:%.*]] = select i1 [[NOT]], i8 [[OR]], i8 [[X]] -; CHECK-NEXT: ret i8 [[COND]] +; CHECK-NEXT: [[OR:%.*]] = or i8 [[X:%.*]], 1 +; CHECK-NEXT: ret i8 [[OR]] ; %trunc = trunc i8 %x to i1 %not = xor i1 %trunc, true