@@ -41658,6 +41658,19 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG,
41658
41658
N1.getOperand(0).getScalarValueSizeInBits() <= 8))
41659
41659
return SDValue();
41660
41660
41661
+ // Check if only high 16 bits of signed 16-bit multiplication are used
41662
+ bool high_only = true;
41663
+
41664
+ for (auto *User : N->uses()) {
41665
+ if (User->getOpcode() == ISD::SRL || User->getOpcode() == ISD::SRA) {
41666
+ if (DAG.MaskedValueIsAllOnes(User->getOperand(1), {32, 16})) {
41667
+ continue;
41668
+ }
41669
+ }
41670
+ high_only = false;
41671
+ break;
41672
+ }
41673
+
41661
41674
APInt Mask17 = APInt::getHighBitsSet(32, 17);
41662
41675
if (N0.getOpcode() == ISD::SRA && N1.getOpcode() == ISD::SRA) {
41663
41676
// If both arguments are sign-extended, try to replace sign extends
@@ -41671,10 +41684,17 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG,
41671
41684
DAG.isSplatValue(N1.getOperand(1))) {
41672
41685
// Nullify mask to pass the following check
41673
41686
Mask17 = 0;
41674
- N0 = DAG.getNode(ISD::SRL, N0.getNode(), VT, N0.getOperand(0),
41675
- N0.getOperand(1));
41676
- N1 = DAG.getNode(ISD::SRL, N1.getNode(), VT, N1.getOperand(0),
41677
- N1.getOperand(1));
41687
+
41688
+ if (high_only) {
41689
+ // Bypass shifts to keep values in high bits
41690
+ N0 = N0.getOperand(0);
41691
+ N1 = N1.getOperand(0);
41692
+ } else {
41693
+ N0 = DAG.getNode(ISD::SRL, N0.getNode(), VT, N0.getOperand(0),
41694
+ N0.getOperand(1));
41695
+ N1 = DAG.getNode(ISD::SRL, N1.getNode(), VT, N1.getOperand(0),
41696
+ N1.getOperand(1));
41697
+ }
41678
41698
}
41679
41699
}
41680
41700
@@ -41684,15 +41704,32 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG,
41684
41704
DAG.isSplatValue(N0.getOperand(1)) &&
41685
41705
N0.getOperand(1).getConstantOperandVal(0) == 16) {
41686
41706
Mask17 = 0;
41687
- N0 = DAG.getNode(ISD::SRL, N0.getNode(), VT, N0.getOperand(0),
41688
- N0.getOperand(1));
41707
+
41708
+ if (high_only)
41709
+ N0 = N0.getOperand(0);
41710
+ else
41711
+ N0 = DAG.getNode(ISD::SRL, N0.getNode(), VT, N0.getOperand(0),
41712
+ N0.getOperand(1));
41689
41713
}
41690
41714
}
41691
41715
41692
41716
if (!!Mask17 && (!DAG.MaskedValueIsZero(N1, Mask17) ||
41693
41717
!DAG.MaskedValueIsZero(N0, Mask17)))
41694
41718
return SDValue();
41695
41719
41720
+ // Use PMULHW if applicable
41721
+ if (high_only && !Mask17) {
41722
+ auto MULHSBuilder = [=](SelectionDAG &DAG, const SDLoc &DL,
41723
+ ArrayRef<SDValue> Ops) {
41724
+ MVT RT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32);
41725
+ MVT OpVT = Ops[0].getSimpleValueType();
41726
+ return DAG.getBitcast(RT, DAG.getNode(ISD::MULHS, DL, OpVT, Ops));
41727
+ };
41728
+ return SplitOpsAndApply(DAG, Subtarget, SDLoc(N), VT,
41729
+ {DAG.getBitcast(WVT, N0), DAG.getBitcast(WVT, N1)},
41730
+ MULHSBuilder);
41731
+ }
41732
+
41696
41733
// Use SplitOpsAndApply to handle AVX splitting.
41697
41734
auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
41698
41735
ArrayRef<SDValue> Ops) {
0 commit comments