@@ -41664,9 +41664,11 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG,
41664
41664
// with zero extends, which should qualify for the optimization.
41665
41665
// Otherwise just fall back to the zero-extension check.
41666
41666
if (isa<ConstantSDNode>(N0.getOperand(1).getOperand(0)) &&
41667
- N0.getOperand(1).getConstantOperandVal(0) == 16 &&
41668
41667
isa<ConstantSDNode>(N1.getOperand(1).getOperand(0)) &&
41669
- N1.getOperand(1).getConstantOperandVal(0) == 16) {
41668
+ N0.getOperand(1).getConstantOperandVal(0) == 16 &&
41669
+ N1.getOperand(1).getConstantOperandVal(0) == 16 &&
41670
+ DAG.isSplatValue(N0.getOperand(1)) &&
41671
+ DAG.isSplatValue(N1.getOperand(1))) {
41670
41672
// Nullify mask to pass the following check
41671
41673
Mask17 = 0;
41672
41674
N0 = DAG.getNode(ISD::SRL, N0.getNode(), VT, N0.getOperand(0),
@@ -41675,8 +41677,20 @@ static SDValue combineMulToPMADDWD(SDNode *N, SelectionDAG &DAG,
41675
41677
N1.getOperand(1));
41676
41678
}
41677
41679
}
41678
- if (!DAG.MaskedValueIsZero(N1, Mask17) ||
41679
- !DAG.MaskedValueIsZero(N0, Mask17))
41680
+
41681
+ if (!!Mask17 && N0.getOpcode() == ISD::SRA) {
41682
+ if (isa<ConstantSDNode>(N0.getOperand(1).getOperand(0)) &&
41683
+ DAG.ComputeNumSignBits(N1) >= 17 &&
41684
+ DAG.isSplatValue(N0.getOperand(1)) &&
41685
+ N0.getOperand(1).getConstantOperandVal(0) == 16) {
41686
+ Mask17 = 0;
41687
+ N0 = DAG.getNode(ISD::SRL, N0.getNode(), VT, N0.getOperand(0),
41688
+ N0.getOperand(1));
41689
+ }
41690
+ }
41691
+
41692
+ if (!!Mask17 && (!DAG.MaskedValueIsZero(N1, Mask17) ||
41693
+ !DAG.MaskedValueIsZero(N0, Mask17)))
41680
41694
return SDValue();
41681
41695
41682
41696
// Use SplitOpsAndApply to handle AVX splitting.
0 commit comments