@@ -7071,14 +7071,23 @@ static SDValue LowerAsSplatVectorLoad(SDValue SrcOp, MVT VT, const SDLoc &dl,
70717071}
70727072
70737073// Recurse to find a LoadSDNode source and the accumulated ByteOffset.
7074- static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
7075- if (ISD::isNON_EXTLoad(Elt.getNode())) {
7076- auto *BaseLd = cast<LoadSDNode>(Elt);
7077- if (!BaseLd->isSimple())
7078- return false;
7079- Ld = BaseLd;
7080- ByteOffset = 0;
7081- return true;
7074+ template <typename T>
7075+ static bool findEltLoadSrc(SDValue Elt, T *&Ld, int64_t &ByteOffset) {
7076+ if constexpr (std::is_same_v<T, AtomicSDNode>) {
7077+ if (auto *BaseLd = dyn_cast<AtomicSDNode>(Elt)) {
7078+ Ld = BaseLd;
7079+ ByteOffset = 0;
7080+ return true;
7081+ }
7082+ } else if constexpr (std::is_same_v<T, LoadSDNode>) {
7083+ if (ISD::isNON_EXTLoad(Elt.getNode())) {
7084+ auto *BaseLd = cast<LoadSDNode>(Elt);
7085+ if (!BaseLd->isSimple())
7086+ return false;
7087+ Ld = BaseLd;
7088+ ByteOffset = 0;
7089+ return true;
7090+ }
70827091 }
70837092
70847093 switch (Elt.getOpcode()) {
@@ -7118,6 +7127,7 @@ static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
71187127/// a build_vector or insert_subvector whose loaded operands are 'Elts'.
71197128///
71207129/// Example: <load i32 *a, load i32 *a+4, zero, undef> -> zextload a
7130+ template <typename T>
71217131static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
71227132 const SDLoc &DL, SelectionDAG &DAG,
71237133 const X86Subtarget &Subtarget,
@@ -7132,7 +7142,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
71327142 APInt ZeroMask = APInt::getZero(NumElems);
71337143 APInt UndefMask = APInt::getZero(NumElems);
71347144
7135- SmallVector<LoadSDNode *, 8> Loads(NumElems, nullptr);
7145+ SmallVector<T *, 8> Loads(NumElems, nullptr);
71367146 SmallVector<int64_t, 8> ByteOffsets(NumElems, 0);
71377147
71387148 // For each element in the initializer, see if we've found a load, zero or an
@@ -7182,7 +7192,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
71827192 EVT EltBaseVT = EltBase.getValueType();
71837193 assert(EltBaseVT.getSizeInBits() == EltBaseVT.getStoreSizeInBits() &&
71847194 "Register/Memory size mismatch");
7185- LoadSDNode *LDBase = Loads[FirstLoadedElt];
7195+ T *LDBase = Loads[FirstLoadedElt];
71867196 assert(LDBase && "Did not find base load for merging consecutive loads");
71877197 unsigned BaseSizeInBits = EltBaseVT.getStoreSizeInBits();
71887198 unsigned BaseSizeInBytes = BaseSizeInBits / 8;
@@ -7196,8 +7206,8 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
71967206
71977207 // Check to see if the element's load is consecutive to the base load
71987208 // or offset from a previous (already checked) load.
7199- auto CheckConsecutiveLoad = [&](LoadSDNode *Base, int EltIdx) {
7200- LoadSDNode *Ld = Loads[EltIdx];
7209+ auto CheckConsecutiveLoad = [&](T *Base, int EltIdx) {
7210+ T *Ld = Loads[EltIdx];
72017211 int64_t ByteOffset = ByteOffsets[EltIdx];
72027212 if (ByteOffset && (ByteOffset % BaseSizeInBytes) == 0) {
72037213 int64_t BaseIdx = EltIdx - (ByteOffset / BaseSizeInBytes);
@@ -7225,7 +7235,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
72257235 }
72267236 }
72277237
7228- auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, LoadSDNode *LDBase) {
7238+ auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, T *LDBase) {
72297239 auto MMOFlags = LDBase->getMemOperand()->getFlags();
72307240 assert(LDBase->isSimple() &&
72317241 "Cannot merge volatile or atomic loads.");
@@ -7295,7 +7305,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
72957305 EVT HalfVT =
72967306 EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(), HalfNumElems);
72977307 SDValue HalfLD =
7298- EltsFromConsecutiveLoads(HalfVT, Elts.drop_back(HalfNumElems), DL,
7308+ EltsFromConsecutiveLoads<T> (HalfVT, Elts.drop_back(HalfNumElems), DL,
72997309 DAG, Subtarget, IsAfterLegalize);
73007310 if (HalfLD)
73017311 return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT),
@@ -7372,7 +7382,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
73727382 EVT::getVectorVT(*DAG.getContext(), RepeatVT.getScalarType(),
73737383 VT.getSizeInBits() / ScalarSize);
73747384 if (TLI.isTypeLegal(BroadcastVT)) {
7375- if (SDValue RepeatLoad = EltsFromConsecutiveLoads(
7385+ if (SDValue RepeatLoad = EltsFromConsecutiveLoads<T> (
73767386 RepeatVT, RepeatedLoads, DL, DAG, Subtarget, IsAfterLegalize)) {
73777387 SDValue Broadcast = RepeatLoad;
73787388 if (RepeatSize > ScalarSize) {
@@ -7413,7 +7423,7 @@ static SDValue combineToConsecutiveLoads(EVT VT, SDValue Op, const SDLoc &DL,
74137423 return SDValue();
74147424 }
74157425 assert(Elts.size() == VT.getVectorNumElements());
7416- return EltsFromConsecutiveLoads(VT, Elts, DL, DAG, Subtarget,
7426+ return EltsFromConsecutiveLoads<LoadSDNode> (VT, Elts, DL, DAG, Subtarget,
74177427 IsAfterLegalize);
74187428}
74197429
@@ -9268,8 +9278,12 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
92689278 {
92699279 SmallVector<SDValue, 64> Ops(Op->op_begin(), Op->op_begin() + NumElems);
92709280 if (SDValue LD =
9271- EltsFromConsecutiveLoads(VT, Ops, dl, DAG, Subtarget, false))
9281+ EltsFromConsecutiveLoads<LoadSDNode> (VT, Ops, dl, DAG, Subtarget, false)) {
92729282 return LD;
9283+ } else if (SDValue LD =
9284+ EltsFromConsecutiveLoads<AtomicSDNode>(VT, Ops, dl, DAG, Subtarget, false)) {
9285+ return LD;
9286+ }
92739287 }
92749288
92759289 // If this is a splat of pairs of 32-bit elements, we can use a narrower
@@ -58007,7 +58021,7 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
5800758021 *FirstLd->getMemOperand(), &Fast) &&
5800858022 Fast) {
5800958023 if (SDValue Ld =
58010- EltsFromConsecutiveLoads(VT, Ops, DL, DAG, Subtarget, false))
58024+ EltsFromConsecutiveLoads<LoadSDNode> (VT, Ops, DL, DAG, Subtarget, false))
5801158025 return Ld;
5801258026 }
5801358027 }
0 commit comments