Skip to content

Commit 1f23ff6

Browse files
committed
X86: allow combineMulToPMADDWD to emit PMULHW (MULHS)
If only high bits of single multiplication are used.
1 parent 5836324 commit 1f23ff6

File tree

1 file changed

+43
-6
lines changed

1 file changed

+43
-6
lines changed

lib/Target/X86/X86ISelLowering.cpp

+43-6
Original file line numberDiff line numberDiff line change
@@ -41658,6 +41658,19 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG,
4165841658
N1.getOperand(0).getScalarValueSizeInBits() <= 8))
4165941659
return SDValue();
4166041660

41661+
// Check if only high 16 bits of signed 16-bit multiplication are used
41662+
bool high_only = true;
41663+
41664+
for (auto *User : N->uses()) {
41665+
if (User->getOpcode() == ISD::SRL || User->getOpcode() == ISD::SRA) {
41666+
if (DAG.MaskedValueIsAllOnes(User->getOperand(1), {32, 16})) {
41667+
continue;
41668+
}
41669+
}
41670+
high_only = false;
41671+
break;
41672+
}
41673+
4166141674
APInt Mask17 = APInt::getHighBitsSet(32, 17);
4166241675
if (N0.getOpcode() == ISD::SRA && N1.getOpcode() == ISD::SRA) {
4166341676
// If both arguments are sign-extended, try to replace sign extends
@@ -41671,10 +41684,17 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG,
4167141684
DAG.isSplatValue(N1.getOperand(1))) {
4167241685
// Nullify mask to pass the following check
4167341686
Mask17 = 0;
41674-
N0 = DAG.getNode(ISD::SRL, N0.getNode(), VT, N0.getOperand(0),
41675-
N0.getOperand(1));
41676-
N1 = DAG.getNode(ISD::SRL, N1.getNode(), VT, N1.getOperand(0),
41677-
N1.getOperand(1));
41687+
41688+
if (high_only) {
41689+
// Bypass shifts to keep values in high bits
41690+
N0 = N0.getOperand(0);
41691+
N1 = N1.getOperand(0);
41692+
} else {
41693+
N0 = DAG.getNode(ISD::SRL, N0.getNode(), VT, N0.getOperand(0),
41694+
N0.getOperand(1));
41695+
N1 = DAG.getNode(ISD::SRL, N1.getNode(), VT, N1.getOperand(0),
41696+
N1.getOperand(1));
41697+
}
4167841698
}
4167941699
}
4168041700

@@ -41684,15 +41704,32 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG,
4168441704
DAG.isSplatValue(N0.getOperand(1)) &&
4168541705
N0.getOperand(1).getConstantOperandVal(0) == 16) {
4168641706
Mask17 = 0;
41687-
N0 = DAG.getNode(ISD::SRL, N0.getNode(), VT, N0.getOperand(0),
41688-
N0.getOperand(1));
41707+
41708+
if (high_only)
41709+
N0 = N0.getOperand(0);
41710+
else
41711+
N0 = DAG.getNode(ISD::SRL, N0.getNode(), VT, N0.getOperand(0),
41712+
N0.getOperand(1));
4168941713
}
4169041714
}
4169141715

4169241716
if (!!Mask17 && (!DAG.MaskedValueIsZero(N1, Mask17) ||
4169341717
!DAG.MaskedValueIsZero(N0, Mask17)))
4169441718
return SDValue();
4169541719

41720+
// Use PMULHW if applicable
41721+
if (high_only && !Mask17) {
41722+
auto MULHSBuilder = [=](SelectionDAG &DAG, const SDLoc &DL,
41723+
ArrayRef<SDValue> Ops) {
41724+
MVT RT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32);
41725+
MVT OpVT = Ops[0].getSimpleValueType();
41726+
return DAG.getBitcast(RT, DAG.getNode(ISD::MULHS, DL, OpVT, Ops));
41727+
};
41728+
return SplitOpsAndApply(DAG, Subtarget, SDLoc(N), VT,
41729+
{DAG.getBitcast(WVT, N0), DAG.getBitcast(WVT, N1)},
41730+
MULHSBuilder);
41731+
}
41732+
4169641733
// Use SplitOpsAndApply to handle AVX splitting.
4169741734
auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
4169841735
ArrayRef<SDValue> Ops) {

0 commit comments

Comments
 (0)