@@ -7071,14 +7071,23 @@ static SDValue LowerAsSplatVectorLoad(SDValue SrcOp, MVT VT, const SDLoc &dl,
7071
7071
}
7072
7072
7073
7073
// Recurse to find a LoadSDNode source and the accumulated ByteOffest.
7074
- static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
7075
- if (ISD::isNON_EXTLoad(Elt.getNode())) {
7076
- auto *BaseLd = cast<LoadSDNode>(Elt);
7077
- if (!BaseLd->isSimple())
7078
- return false;
7079
- Ld = BaseLd;
7080
- ByteOffset = 0;
7081
- return true;
7074
+ template <typename T>
7075
+ static bool findEltLoadSrc(SDValue Elt, T *&Ld, int64_t &ByteOffset) {
7076
+ if constexpr (std::is_same_v<T, AtomicSDNode>) {
7077
+ if (auto *BaseLd = dyn_cast<AtomicSDNode>(Elt)) {
7078
+ Ld = BaseLd;
7079
+ ByteOffset = 0;
7080
+ return true;
7081
+ }
7082
+ } else if constexpr (std::is_same_v<T, LoadSDNode>) {
7083
+ if (ISD::isNON_EXTLoad(Elt.getNode())) {
7084
+ auto *BaseLd = cast<LoadSDNode>(Elt);
7085
+ if (!BaseLd->isSimple())
7086
+ return false;
7087
+ Ld = BaseLd;
7088
+ ByteOffset = 0;
7089
+ return true;
7090
+ }
7082
7091
}
7083
7092
7084
7093
switch (Elt.getOpcode()) {
@@ -7118,6 +7127,7 @@ static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
7118
7127
/// a build_vector or insert_subvector whose loaded operands are 'Elts'.
7119
7128
///
7120
7129
/// Example: <load i32 *a, load i32 *a+4, zero, undef> -> zextload a
7130
+ template <typename T>
7121
7131
static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
7122
7132
const SDLoc &DL, SelectionDAG &DAG,
7123
7133
const X86Subtarget &Subtarget,
@@ -7132,7 +7142,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
7132
7142
APInt ZeroMask = APInt::getZero(NumElems);
7133
7143
APInt UndefMask = APInt::getZero(NumElems);
7134
7144
7135
- SmallVector<LoadSDNode *, 8> Loads(NumElems, nullptr);
7145
+ SmallVector<T *, 8> Loads(NumElems, nullptr);
7136
7146
SmallVector<int64_t, 8> ByteOffsets(NumElems, 0);
7137
7147
7138
7148
// For each element in the initializer, see if we've found a load, zero or an
@@ -7182,7 +7192,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
7182
7192
EVT EltBaseVT = EltBase.getValueType();
7183
7193
assert(EltBaseVT.getSizeInBits() == EltBaseVT.getStoreSizeInBits() &&
7184
7194
"Register/Memory size mismatch");
7185
- LoadSDNode *LDBase = Loads[FirstLoadedElt];
7195
+ T *LDBase = Loads[FirstLoadedElt];
7186
7196
assert(LDBase && "Did not find base load for merging consecutive loads");
7187
7197
unsigned BaseSizeInBits = EltBaseVT.getStoreSizeInBits();
7188
7198
unsigned BaseSizeInBytes = BaseSizeInBits / 8;
@@ -7196,8 +7206,8 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
7196
7206
7197
7207
// Check to see if the element's load is consecutive to the base load
7198
7208
// or offset from a previous (already checked) load.
7199
- auto CheckConsecutiveLoad = [&](LoadSDNode *Base, int EltIdx) {
7200
- LoadSDNode *Ld = Loads[EltIdx];
7209
+ auto CheckConsecutiveLoad = [&](T *Base, int EltIdx) {
7210
+ T *Ld = Loads[EltIdx];
7201
7211
int64_t ByteOffset = ByteOffsets[EltIdx];
7202
7212
if (ByteOffset && (ByteOffset % BaseSizeInBytes) == 0) {
7203
7213
int64_t BaseIdx = EltIdx - (ByteOffset / BaseSizeInBytes);
@@ -7225,7 +7235,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
7225
7235
}
7226
7236
}
7227
7237
7228
- auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, LoadSDNode *LDBase) {
7238
+ auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, T *LDBase) {
7229
7239
auto MMOFlags = LDBase->getMemOperand()->getFlags();
7230
7240
assert(LDBase->isSimple() &&
7231
7241
"Cannot merge volatile or atomic loads.");
@@ -7295,7 +7305,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
7295
7305
EVT HalfVT =
7296
7306
EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(), HalfNumElems);
7297
7307
SDValue HalfLD =
7298
- EltsFromConsecutiveLoads(HalfVT, Elts.drop_back(HalfNumElems), DL,
7308
+ EltsFromConsecutiveLoads<T> (HalfVT, Elts.drop_back(HalfNumElems), DL,
7299
7309
DAG, Subtarget, IsAfterLegalize);
7300
7310
if (HalfLD)
7301
7311
return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT),
@@ -7372,7 +7382,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
7372
7382
EVT::getVectorVT(*DAG.getContext(), RepeatVT.getScalarType(),
7373
7383
VT.getSizeInBits() / ScalarSize);
7374
7384
if (TLI.isTypeLegal(BroadcastVT)) {
7375
- if (SDValue RepeatLoad = EltsFromConsecutiveLoads(
7385
+ if (SDValue RepeatLoad = EltsFromConsecutiveLoads<T> (
7376
7386
RepeatVT, RepeatedLoads, DL, DAG, Subtarget, IsAfterLegalize)) {
7377
7387
SDValue Broadcast = RepeatLoad;
7378
7388
if (RepeatSize > ScalarSize) {
@@ -7413,7 +7423,7 @@ static SDValue combineToConsecutiveLoads(EVT VT, SDValue Op, const SDLoc &DL,
7413
7423
return SDValue();
7414
7424
}
7415
7425
assert(Elts.size() == VT.getVectorNumElements());
7416
- return EltsFromConsecutiveLoads(VT, Elts, DL, DAG, Subtarget,
7426
+ return EltsFromConsecutiveLoads<LoadSDNode> (VT, Elts, DL, DAG, Subtarget,
7417
7427
IsAfterLegalize);
7418
7428
}
7419
7429
@@ -9268,8 +9278,12 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
9268
9278
{
9269
9279
SmallVector<SDValue, 64> Ops(Op->op_begin(), Op->op_begin() + NumElems);
9270
9280
if (SDValue LD =
9271
- EltsFromConsecutiveLoads(VT, Ops, dl, DAG, Subtarget, false))
9281
+ EltsFromConsecutiveLoads<LoadSDNode> (VT, Ops, dl, DAG, Subtarget, false)) {
9272
9282
return LD;
9283
+ } else if (SDValue LD =
9284
+ EltsFromConsecutiveLoads<AtomicSDNode>(VT, Ops, dl, DAG, Subtarget, false)) {
9285
+ return LD;
9286
+ }
9273
9287
}
9274
9288
9275
9289
// If this is a splat of pairs of 32-bit elements, we can use a narrower
@@ -58007,7 +58021,7 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
58007
58021
*FirstLd->getMemOperand(), &Fast) &&
58008
58022
Fast) {
58009
58023
if (SDValue Ld =
58010
- EltsFromConsecutiveLoads(VT, Ops, DL, DAG, Subtarget, false))
58024
+ EltsFromConsecutiveLoads<LoadSDNode> (VT, Ops, DL, DAG, Subtarget, false))
58011
58025
return Ld;
58012
58026
}
58013
58027
}
0 commit comments