@@ -7049,15 +7049,23 @@ static SDValue LowerAsSplatVectorLoad(SDValue SrcOp, MVT VT, const SDLoc &dl,
7049
7049
return SDValue();
7050
7050
}
7051
7051
7052
- // Recurse to find a LoadSDNode source and the accumulated ByteOffest.
7053
- static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
7054
- if (ISD::isNON_EXTLoad(Elt.getNode())) {
7055
- auto *BaseLd = cast<LoadSDNode>(Elt);
7056
- if (!BaseLd->isSimple())
7057
- return false;
7058
- Ld = BaseLd;
7059
- ByteOffset = 0;
7060
- return true;
7052
+ template <typename T>
7053
+ static bool findEltLoadSrc(SDValue Elt, T *&Ld, int64_t &ByteOffset) {
7054
+ if constexpr (std::is_same_v<T, AtomicSDNode>) {
7055
+ if (auto *BaseLd = dyn_cast<AtomicSDNode>(Elt)) {
7056
+ Ld = BaseLd;
7057
+ ByteOffset = 0;
7058
+ return true;
7059
+ }
7060
+ } else if constexpr (std::is_same_v<T, LoadSDNode>) {
7061
+ if (ISD::isNON_EXTLoad(Elt.getNode())) {
7062
+ auto *BaseLd = cast<LoadSDNode>(Elt);
7063
+ if (!BaseLd->isSimple())
7064
+ return false;
7065
+ Ld = BaseLd;
7066
+ ByteOffset = 0;
7067
+ return true;
7068
+ }
7061
7069
}
7062
7070
7063
7071
switch (Elt.getOpcode()) {
@@ -7097,6 +7105,7 @@ static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
7097
7105
/// a build_vector or insert_subvector whose loaded operands are 'Elts'.
7098
7106
///
7099
7107
/// Example: <load i32 *a, load i32 *a+4, zero, undef> -> zextload a
7108
+ template <typename T>
7100
7109
static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
7101
7110
const SDLoc &DL, SelectionDAG &DAG,
7102
7111
const X86Subtarget &Subtarget,
@@ -7111,7 +7120,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
7111
7120
APInt ZeroMask = APInt::getZero(NumElems);
7112
7121
APInt UndefMask = APInt::getZero(NumElems);
7113
7122
7114
- SmallVector<LoadSDNode *, 8> Loads(NumElems, nullptr);
7123
+ SmallVector<T *, 8> Loads(NumElems, nullptr);
7115
7124
SmallVector<int64_t, 8> ByteOffsets(NumElems, 0);
7116
7125
7117
7126
// For each element in the initializer, see if we've found a load, zero or an
@@ -7161,7 +7170,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
7161
7170
EVT EltBaseVT = EltBase.getValueType();
7162
7171
assert(EltBaseVT.getSizeInBits() == EltBaseVT.getStoreSizeInBits() &&
7163
7172
"Register/Memory size mismatch");
7164
- LoadSDNode *LDBase = Loads[FirstLoadedElt];
7173
+ T *LDBase = Loads[FirstLoadedElt];
7165
7174
assert(LDBase && "Did not find base load for merging consecutive loads");
7166
7175
unsigned BaseSizeInBits = EltBaseVT.getStoreSizeInBits();
7167
7176
unsigned BaseSizeInBytes = BaseSizeInBits / 8;
@@ -7175,8 +7184,8 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
7175
7184
7176
7185
// Check to see if the element's load is consecutive to the base load
7177
7186
// or offset from a previous (already checked) load.
7178
- auto CheckConsecutiveLoad = [&](LoadSDNode *Base, int EltIdx) {
7179
- LoadSDNode *Ld = Loads[EltIdx];
7187
+ auto CheckConsecutiveLoad = [&](T *Base, int EltIdx) {
7188
+ T *Ld = Loads[EltIdx];
7180
7189
int64_t ByteOffset = ByteOffsets[EltIdx];
7181
7190
if (ByteOffset && (ByteOffset % BaseSizeInBytes) == 0) {
7182
7191
int64_t BaseIdx = EltIdx - (ByteOffset / BaseSizeInBytes);
@@ -7204,7 +7213,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
7204
7213
}
7205
7214
}
7206
7215
7207
- auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, LoadSDNode *LDBase) {
7216
+ auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, T *LDBase) {
7208
7217
auto MMOFlags = LDBase->getMemOperand()->getFlags();
7209
7218
assert(LDBase->isSimple() &&
7210
7219
"Cannot merge volatile or atomic loads.");
@@ -7274,7 +7283,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
7274
7283
EVT HalfVT =
7275
7284
EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(), HalfNumElems);
7276
7285
SDValue HalfLD =
7277
- EltsFromConsecutiveLoads(HalfVT, Elts.drop_back(HalfNumElems), DL,
7286
+ EltsFromConsecutiveLoads<T> (HalfVT, Elts.drop_back(HalfNumElems), DL,
7278
7287
DAG, Subtarget, IsAfterLegalize);
7279
7288
if (HalfLD)
7280
7289
return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT),
@@ -7351,7 +7360,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
7351
7360
EVT::getVectorVT(*DAG.getContext(), RepeatVT.getScalarType(),
7352
7361
VT.getSizeInBits() / ScalarSize);
7353
7362
if (TLI.isTypeLegal(BroadcastVT)) {
7354
- if (SDValue RepeatLoad = EltsFromConsecutiveLoads(
7363
+ if (SDValue RepeatLoad = EltsFromConsecutiveLoads<T> (
7355
7364
RepeatVT, RepeatedLoads, DL, DAG, Subtarget, IsAfterLegalize)) {
7356
7365
SDValue Broadcast = RepeatLoad;
7357
7366
if (RepeatSize > ScalarSize) {
@@ -7392,7 +7401,7 @@ static SDValue combineToConsecutiveLoads(EVT VT, SDValue Op, const SDLoc &DL,
7392
7401
return SDValue();
7393
7402
}
7394
7403
assert(Elts.size() == VT.getVectorNumElements());
7395
- return EltsFromConsecutiveLoads(VT, Elts, DL, DAG, Subtarget,
7404
+ return EltsFromConsecutiveLoads<LoadSDNode> (VT, Elts, DL, DAG, Subtarget,
7396
7405
IsAfterLegalize);
7397
7406
}
7398
7407
@@ -9247,8 +9256,12 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
9247
9256
{
9248
9257
SmallVector<SDValue, 64> Ops(Op->op_begin(), Op->op_begin() + NumElems);
9249
9258
if (SDValue LD =
9250
- EltsFromConsecutiveLoads(VT, Ops, dl, DAG, Subtarget, false))
9259
+ EltsFromConsecutiveLoads<LoadSDNode> (VT, Ops, dl, DAG, Subtarget, false)) {
9251
9260
return LD;
9261
+ } else if (SDValue LD =
9262
+ EltsFromConsecutiveLoads<AtomicSDNode>(VT, Ops, dl, DAG, Subtarget, false)) {
9263
+ return LD;
9264
+ }
9252
9265
}
9253
9266
9254
9267
// If this is a splat of pairs of 32-bit elements, we can use a narrower
@@ -57934,7 +57947,7 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
57934
57947
*FirstLd->getMemOperand(), &Fast) &&
57935
57948
Fast) {
57936
57949
if (SDValue Ld =
57937
- EltsFromConsecutiveLoads(VT, Ops, DL, DAG, Subtarget, false))
57950
+ EltsFromConsecutiveLoads<LoadSDNode> (VT, Ops, DL, DAG, Subtarget, false))
57938
57951
return Ld;
57939
57952
}
57940
57953
}
0 commit comments