diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp index 627edb680dfa1..58145c7e3c591 100644 --- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -1038,23 +1038,20 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) { // The compare predicates should match, and each compare should have a // constant operand. - // TODO: Relax the one-use constraints. Value *B0 = I.getOperand(0), *B1 = I.getOperand(1); Instruction *I0, *I1; Constant *C0, *C1; CmpInst::Predicate P0, P1; - if (!match(B0, m_OneUse(m_Cmp(P0, m_Instruction(I0), m_Constant(C0)))) || - !match(B1, m_OneUse(m_Cmp(P1, m_Instruction(I1), m_Constant(C1)))) || - P0 != P1) + if (!match(B0, m_Cmp(P0, m_Instruction(I0), m_Constant(C0))) || + !match(B1, m_Cmp(P1, m_Instruction(I1), m_Constant(C1))) || P0 != P1) return false; // The compare operands must be extracts of the same vector with constant // extract indexes. - // TODO: Relax the one-use constraints. 
Value *X; uint64_t Index0, Index1; - if (!match(I0, m_OneUse(m_ExtractElt(m_Value(X), m_ConstantInt(Index0)))) || - !match(I1, m_OneUse(m_ExtractElt(m_Specific(X), m_ConstantInt(Index1))))) + if (!match(I0, m_ExtractElt(m_Value(X), m_ConstantInt(Index0))) || + !match(I1, m_ExtractElt(m_Specific(X), m_ConstantInt(Index1)))) return false; auto *Ext0 = cast<ExtractElementInst>(I0); @@ -1073,14 +1070,16 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) { return false; TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; + InstructionCost Ext0Cost = + TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0), + Ext1Cost = + TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1); InstructionCost OldCost = - TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0); - OldCost += TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1); - OldCost += + Ext0Cost + Ext1Cost + TTI.getCmpSelInstrCost(CmpOpcode, I0->getType(), CmpInst::makeCmpResultType(I0->getType()), Pred) * - 2; - OldCost += TTI.getArithmeticInstrCost(I.getOpcode(), I.getType()); + 2 + + TTI.getArithmeticInstrCost(I.getOpcode(), I.getType()); // The proposed vector pattern is: // vcmp = cmp Pred X, VecC @@ -1096,6 +1095,8 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) { ShufMask); NewCost += TTI.getArithmeticInstrCost(I.getOpcode(), CmpTy); NewCost += TTI.getVectorInstrCost(*Ext0, CmpTy, CostKind, CheapIndex); + NewCost += Ext0->hasOneUse() ? 0 : Ext0Cost; + NewCost += Ext1->hasOneUse() ? 0 : Ext1Cost; // Aggressively form vector ops if the cost is equal because the transform // may enable further optimization. 
diff --git a/llvm/test/Transforms/VectorCombine/X86/extract-cmp-binop.ll b/llvm/test/Transforms/VectorCombine/X86/extract-cmp-binop.ll index 462bb13ae7d12..be5359f549ac9 100644 --- a/llvm/test/Transforms/VectorCombine/X86/extract-cmp-binop.ll +++ b/llvm/test/Transforms/VectorCombine/X86/extract-cmp-binop.ll @@ -92,6 +92,60 @@ define i1 @icmp_add_v8i32(<8 x i32> %a) { ret i1 %r } +declare void @use() + +define i1 @fcmp_and_v2f64_multiuse(<2 x double> %a) { +; SSE-LABEL: @fcmp_and_v2f64_multiuse( +; SSE-NEXT: [[E1:%.*]] = extractelement <2 x double> [[A:%.*]], i32 0 +; SSE-NEXT: call void @use(double [[E1]]) +; SSE-NEXT: [[E2:%.*]] = extractelement <2 x double> [[A]], i32 1 +; SSE-NEXT: [[CMP1:%.*]] = fcmp olt double [[E1]], 4.200000e+01 +; SSE-NEXT: [[CMP2:%.*]] = fcmp olt double [[E2]], -8.000000e+00 +; SSE-NEXT: [[R:%.*]] = and i1 [[CMP1]], [[CMP2]] +; SSE-NEXT: call void @use(i1 [[R]]) +; SSE-NEXT: ret i1 [[R]] +; +; AVX-LABEL: @fcmp_and_v2f64_multiuse( +; AVX-NEXT: [[E1:%.*]] = extractelement <2 x double> [[A:%.*]], i32 0 +; AVX-NEXT: call void @use(double [[E1]]) +; AVX-NEXT: [[TMP1:%.*]] = fcmp olt <2 x double> [[A]], <double 4.200000e+01, double -8.000000e+00> +; AVX-NEXT: [[SHIFT:%.*]] = shufflevector <2 x i1> [[TMP1]], <2 x i1> poison, <2 x i32> <i32 1, i32 poison> +; AVX-NEXT: [[TMP2:%.*]] = and <2 x i1> [[TMP1]], [[SHIFT]] +; AVX-NEXT: [[R:%.*]] = extractelement <2 x i1> [[TMP2]], i64 0 +; AVX-NEXT: call void @use(i1 [[R]]) +; AVX-NEXT: ret i1 [[R]] +; + %e1 = extractelement <2 x double> %a, i32 0 + call void @use(double %e1) + %e2 = extractelement <2 x double> %a, i32 1 + %cmp1 = fcmp olt double %e1, 42.0 + %cmp2 = fcmp olt double %e2, -8.0 + %r = and i1 %cmp1, %cmp2 + call void @use(i1 %r) + ret i1 %r +} + +define i1 @icmp_xor_v4i32_multiuse(<4 x i32> %a) { +; CHECK-LABEL: @icmp_xor_v4i32_multiuse( +; CHECK-NEXT: [[E2:%.*]] = extractelement <4 x i32> [[A:%.*]], i32 1 +; CHECK-NEXT: call void @use(i32 [[E2]]) +; CHECK-NEXT: [[TMP1:%.*]] = icmp sgt <4 x i32> [[A]], <i32 poison, i32 -8, i32 poison, i32 42> +; CHECK-NEXT: [[SHIFT:%.*]] = shufflevector <4 x i1> 
[[TMP1]], <4 x i1> poison, <4 x i32> <i32 poison, i32 3, i32 poison, i32 poison> +; CHECK-NEXT: [[TMP2:%.*]] = xor <4 x i1> [[TMP1]], [[SHIFT]] +; CHECK-NEXT: [[R:%.*]] = extractelement <4 x i1> [[TMP2]], i64 1 +; CHECK-NEXT: call void @use(i1 [[R]]) +; CHECK-NEXT: ret i1 [[R]] +; + %e1 = extractelement <4 x i32> %a, i32 3 + %e2 = extractelement <4 x i32> %a, i32 1 + call void @use(i32 %e2) + %cmp1 = icmp sgt i32 %e1, 42 + %cmp2 = icmp sgt i32 %e2, -8 + %r = xor i1 %cmp1, %cmp2 + call void @use(i1 %r) + ret i1 %r +} + ; Negative test - this could CSE/simplify. define i1 @same_extract_index(<4 x i32> %a) {