Skip to content

Commit ba2a301

Browse files
committed
[SelectionDAG][X86] Split <2 x T> vector types for atomic load
Vector types of 2 elements that aren't widened are split so that they can be vectorized within SelectionDAG. This change utilizes the load vectorization infrastructure in order to regroup the split elements. This enables SelectionDAG to translate vectors of type bfloat and half. commit-id:3a045357
1 parent a96fdf1 commit ba2a301

File tree

8 files changed

+151
-35
lines changed

8 files changed

+151
-35
lines changed

Diff for: llvm/include/llvm/CodeGen/SelectionDAG.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -1835,7 +1835,7 @@ class SelectionDAG {
18351835
/// chain to the token factor. This ensures that the new memory node will have
18361836
/// the same relative memory dependency position as the old load. Returns the
18371837
/// new merged load chain.
1838-
SDValue makeEquivalentMemoryOrdering(LoadSDNode *OldLoad, SDValue NewMemOp);
1838+
SDValue makeEquivalentMemoryOrdering(MemSDNode *OldLoad, SDValue NewMemOp);
18391839

18401840
/// Topological-sort the AllNodes list and
18411841
/// assign a unique node id for each node in the DAG based on their
@@ -2261,6 +2261,8 @@ class SelectionDAG {
22612261
/// location that the 'Base' load is loading from.
22622262
bool areNonVolatileConsecutiveLoads(LoadSDNode *LD, LoadSDNode *Base,
22632263
unsigned Bytes, int Dist) const;
2264+
bool areNonVolatileConsecutiveLoads(AtomicSDNode *LD, AtomicSDNode *Base,
2265+
unsigned Bytes, int Dist) const;
22642266

22652267
/// Infer alignment of a load / store address. Return std::nullopt if it
22662268
/// cannot be inferred.

Diff for: llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

+1
Original file line numberDiff line numberDiff line change
@@ -946,6 +946,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
946946
void SplitVecRes_FPOp_MultiType(SDNode *N, SDValue &Lo, SDValue &Hi);
947947
void SplitVecRes_IS_FPCLASS(SDNode *N, SDValue &Lo, SDValue &Hi);
948948
void SplitVecRes_INSERT_VECTOR_ELT(SDNode *N, SDValue &Lo, SDValue &Hi);
949+
void SplitVecRes_ATOMIC_LOAD(AtomicSDNode *LD, SDValue &Lo, SDValue &Hi);
949950
void SplitVecRes_LOAD(LoadSDNode *LD, SDValue &Lo, SDValue &Hi);
950951
void SplitVecRes_VP_LOAD(VPLoadSDNode *LD, SDValue &Lo, SDValue &Hi);
951952
void SplitVecRes_VP_STRIDED_LOAD(VPStridedLoadSDNode *SLD, SDValue &Lo,

Diff for: llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

+35
Original file line numberDiff line numberDiff line change
@@ -1148,6 +1148,9 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
11481148
SplitVecRes_STEP_VECTOR(N, Lo, Hi);
11491149
break;
11501150
case ISD::SIGN_EXTEND_INREG: SplitVecRes_InregOp(N, Lo, Hi); break;
1151+
case ISD::ATOMIC_LOAD:
1152+
SplitVecRes_ATOMIC_LOAD(cast<AtomicSDNode>(N), Lo, Hi);
1153+
break;
11511154
case ISD::LOAD:
11521155
SplitVecRes_LOAD(cast<LoadSDNode>(N), Lo, Hi);
11531156
break;
@@ -1391,6 +1394,38 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
13911394
SetSplitVector(SDValue(N, ResNo), Lo, Hi);
13921395
}
13931396

1397+
void DAGTypeLegalizer::SplitVecRes_ATOMIC_LOAD(AtomicSDNode *LD, SDValue &Lo,
1398+
SDValue &Hi) {
1399+
EVT LoVT, HiVT;
1400+
SDLoc dl(LD);
1401+
std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(LD->getValueType(0));
1402+
1403+
SDValue Ch = LD->getChain();
1404+
SDValue Ptr = LD->getBasePtr();
1405+
EVT MemoryVT = LD->getMemoryVT();
1406+
1407+
EVT LoMemVT, HiMemVT;
1408+
std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT);
1409+
1410+
Lo = DAG.getAtomic(ISD::ATOMIC_LOAD, dl, LoMemVT, LoMemVT, Ch, Ptr,
1411+
LD->getMemOperand());
1412+
1413+
MachinePointerInfo MPI;
1414+
IncrementPointer(LD, LoMemVT, MPI, Ptr);
1415+
1416+
Hi = DAG.getAtomic(ISD::ATOMIC_LOAD, dl, HiMemVT, HiMemVT, Ch, Ptr,
1417+
LD->getMemOperand());
1418+
1419+
// Build a factor node to remember that this load is independent of the
1420+
// other one.
1421+
Ch = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Lo.getValue(1),
1422+
Hi.getValue(1));
1423+
1424+
// Legalize the chain result - switch anything that used the old chain to
1425+
// use the new one.
1426+
ReplaceValueWith(SDValue(LD, 1), Ch);
1427+
}
1428+
13941429
void DAGTypeLegalizer::IncrementPointer(MemSDNode *N, EVT MemVT,
13951430
MachinePointerInfo &MPI, SDValue &Ptr,
13961431
uint64_t *ScaledOffset) {

Diff for: llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

+22-2
Original file line numberDiff line numberDiff line change
@@ -12161,7 +12161,7 @@ SDValue SelectionDAG::makeEquivalentMemoryOrdering(SDValue OldChain,
1216112161
return TokenFactor;
1216212162
}
1216312163

12164-
SDValue SelectionDAG::makeEquivalentMemoryOrdering(LoadSDNode *OldLoad,
12164+
SDValue SelectionDAG::makeEquivalentMemoryOrdering(MemSDNode *OldLoad,
1216512165
SDValue NewMemOp) {
1216612166
assert(isa<MemSDNode>(NewMemOp.getNode()) && "Expected a memop node");
1216712167
SDValue OldChain = SDValue(OldLoad, 1);
@@ -12873,13 +12873,33 @@ std::pair<SDValue, SDValue> SelectionDAG::UnrollVectorOverflowOp(
1287312873
getBuildVector(NewOvVT, dl, OvScalars));
1287412874
}
1287512875

12876+
bool SelectionDAG::areNonVolatileConsecutiveLoads(AtomicSDNode *LD,
12877+
AtomicSDNode *Base,
12878+
unsigned Bytes,
12879+
int Dist) const {
12880+
if (LD->isVolatile() || Base->isVolatile())
12881+
return false;
12882+
if (LD->getChain() != Base->getChain())
12883+
return false;
12884+
EVT VT = LD->getMemoryVT();
12885+
if (VT.getSizeInBits() / 8 != Bytes)
12886+
return false;
12887+
12888+
auto BaseLocDecomp = BaseIndexOffset::match(Base, *this);
12889+
auto LocDecomp = BaseIndexOffset::match(LD, *this);
12890+
12891+
int64_t Offset = 0;
12892+
if (BaseLocDecomp.equalBaseIndex(LocDecomp, *this, Offset))
12893+
return (Dist * (int64_t)Bytes == Offset);
12894+
return false;
12895+
}
12896+
1287612897
bool SelectionDAG::areNonVolatileConsecutiveLoads(LoadSDNode *LD,
1287712898
LoadSDNode *Base,
1287812899
unsigned Bytes,
1287912900
int Dist) const {
1288012901
if (LD->isVolatile() || Base->isVolatile())
1288112902
return false;
12882-
// TODO: probably too restrictive for atomics, revisit
1288312903
if (!LD->isSimple())
1288412904
return false;
1288512905
if (LD->isIndexed() || Base->isIndexed())

Diff for: llvm/lib/CodeGen/SelectionDAG/SelectionDAGAddressAnalysis.cpp

+17-13
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,8 @@ bool BaseIndexOffset::contains(const SelectionDAG &DAG, int64_t BitSize,
194194
return false;
195195
}
196196

197-
/// Parses tree in Ptr for base, index, offset addresses.
198-
static BaseIndexOffset matchLSNode(const LSBaseSDNode *N,
197+
template <typename T>
198+
static BaseIndexOffset matchSDNode(const T *N,
199199
const SelectionDAG &DAG) {
200200
SDValue Ptr = N->getBasePtr();
201201

@@ -206,16 +206,18 @@ static BaseIndexOffset matchLSNode(const LSBaseSDNode *N,
206206
bool IsIndexSignExt = false;
207207

208208
// pre-inc/pre-dec ops are components of EA.
209-
if (N->getAddressingMode() == ISD::PRE_INC) {
210-
if (auto *C = dyn_cast<ConstantSDNode>(N->getOffset()))
211-
Offset += C->getSExtValue();
212-
else // If unknown, give up now.
213-
return BaseIndexOffset(SDValue(), SDValue(), 0, false);
214-
} else if (N->getAddressingMode() == ISD::PRE_DEC) {
215-
if (auto *C = dyn_cast<ConstantSDNode>(N->getOffset()))
216-
Offset -= C->getSExtValue();
217-
else // If unknown, give up now.
218-
return BaseIndexOffset(SDValue(), SDValue(), 0, false);
209+
if constexpr (std::is_same_v<T, LSBaseSDNode>) {
210+
if (N->getAddressingMode() == ISD::PRE_INC) {
211+
if (auto *C = dyn_cast<ConstantSDNode>(N->getOffset()))
212+
Offset += C->getSExtValue();
213+
else // If unknown, give up now.
214+
return BaseIndexOffset(SDValue(), SDValue(), 0, false);
215+
} else if (N->getAddressingMode() == ISD::PRE_DEC) {
216+
if (auto *C = dyn_cast<ConstantSDNode>(N->getOffset()))
217+
Offset -= C->getSExtValue();
218+
else // If unknown, give up now.
219+
return BaseIndexOffset(SDValue(), SDValue(), 0, false);
220+
}
219221
}
220222

221223
// Consume constant adds & ors with appropriate masking.
@@ -300,8 +302,10 @@ static BaseIndexOffset matchLSNode(const LSBaseSDNode *N,
300302

301303
BaseIndexOffset BaseIndexOffset::match(const SDNode *N,
302304
const SelectionDAG &DAG) {
305+
if (const auto *AN = dyn_cast<AtomicSDNode>(N))
306+
return matchSDNode(AN, DAG);
303307
if (const auto *LS0 = dyn_cast<LSBaseSDNode>(N))
304-
return matchLSNode(LS0, DAG);
308+
return matchSDNode(LS0, DAG);
305309
if (const auto *LN = dyn_cast<LifetimeSDNode>(N)) {
306310
if (LN->hasOffset())
307311
return BaseIndexOffset(LN->getOperand(1), SDValue(), LN->getOffset(),

Diff for: llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -5218,7 +5218,11 @@ void SelectionDAGBuilder::visitAtomicLoad(const LoadInst &I) {
52185218
L = DAG.getPtrExtOrTrunc(L, dl, VT);
52195219

52205220
setValue(&I, L);
5221-
DAG.setRoot(OutChain);
5221+
5222+
if (VT.isVector())
5223+
DAG.setRoot(InChain);
5224+
else
5225+
DAG.setRoot(OutChain);
52225226
}
52235227

52245228
void SelectionDAGBuilder::visitAtomicStore(const StoreInst &I) {

Diff for: llvm/lib/Target/X86/X86ISelLowering.cpp

+32-18
Original file line numberDiff line numberDiff line change
@@ -7050,14 +7050,23 @@ static SDValue LowerAsSplatVectorLoad(SDValue SrcOp, MVT VT, const SDLoc &dl,
70507050
}
70517051

70527052
// Recurse to find a LoadSDNode source and the accumulated ByteOffset.
7053-
static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
7054-
if (ISD::isNON_EXTLoad(Elt.getNode())) {
7055-
auto *BaseLd = cast<LoadSDNode>(Elt);
7056-
if (!BaseLd->isSimple())
7057-
return false;
7058-
Ld = BaseLd;
7059-
ByteOffset = 0;
7060-
return true;
7053+
template <typename T>
7054+
static bool findEltLoadSrc(SDValue Elt, T *&Ld, int64_t &ByteOffset) {
7055+
if constexpr (std::is_same_v<T, AtomicSDNode>) {
7056+
if (auto *BaseLd = dyn_cast<AtomicSDNode>(Elt)) {
7057+
Ld = BaseLd;
7058+
ByteOffset = 0;
7059+
return true;
7060+
}
7061+
} else if constexpr (std::is_same_v<T, LoadSDNode>) {
7062+
if (ISD::isNON_EXTLoad(Elt.getNode())) {
7063+
auto *BaseLd = cast<LoadSDNode>(Elt);
7064+
if (!BaseLd->isSimple())
7065+
return false;
7066+
Ld = BaseLd;
7067+
ByteOffset = 0;
7068+
return true;
7069+
}
70617070
}
70627071

70637072
switch (Elt.getOpcode()) {
@@ -7097,6 +7106,7 @@ static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
70977106
/// a build_vector or insert_subvector whose loaded operands are 'Elts'.
70987107
///
70997108
/// Example: <load i32 *a, load i32 *a+4, zero, undef> -> zextload a
7109+
template <typename T>
71007110
static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
71017111
const SDLoc &DL, SelectionDAG &DAG,
71027112
const X86Subtarget &Subtarget,
@@ -7111,7 +7121,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
71117121
APInt ZeroMask = APInt::getZero(NumElems);
71127122
APInt UndefMask = APInt::getZero(NumElems);
71137123

7114-
SmallVector<LoadSDNode*, 8> Loads(NumElems, nullptr);
7124+
SmallVector<T*, 8> Loads(NumElems, nullptr);
71157125
SmallVector<int64_t, 8> ByteOffsets(NumElems, 0);
71167126

71177127
// For each element in the initializer, see if we've found a load, zero or an
@@ -7161,7 +7171,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
71617171
EVT EltBaseVT = EltBase.getValueType();
71627172
assert(EltBaseVT.getSizeInBits() == EltBaseVT.getStoreSizeInBits() &&
71637173
"Register/Memory size mismatch");
7164-
LoadSDNode *LDBase = Loads[FirstLoadedElt];
7174+
T *LDBase = Loads[FirstLoadedElt];
71657175
assert(LDBase && "Did not find base load for merging consecutive loads");
71667176
unsigned BaseSizeInBits = EltBaseVT.getStoreSizeInBits();
71677177
unsigned BaseSizeInBytes = BaseSizeInBits / 8;
@@ -7175,8 +7185,8 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
71757185

71767186
// Check to see if the element's load is consecutive to the base load
71777187
// or offset from a previous (already checked) load.
7178-
auto CheckConsecutiveLoad = [&](LoadSDNode *Base, int EltIdx) {
7179-
LoadSDNode *Ld = Loads[EltIdx];
7188+
auto CheckConsecutiveLoad = [&](T *Base, int EltIdx) {
7189+
T *Ld = Loads[EltIdx];
71807190
int64_t ByteOffset = ByteOffsets[EltIdx];
71817191
if (ByteOffset && (ByteOffset % BaseSizeInBytes) == 0) {
71827192
int64_t BaseIdx = EltIdx - (ByteOffset / BaseSizeInBytes);
@@ -7204,7 +7214,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
72047214
}
72057215
}
72067216

7207-
auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, LoadSDNode *LDBase) {
7217+
auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, T *LDBase) {
72087218
auto MMOFlags = LDBase->getMemOperand()->getFlags();
72097219
assert(LDBase->isSimple() &&
72107220
"Cannot merge volatile or atomic loads.");
@@ -7274,7 +7284,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
72747284
EVT HalfVT =
72757285
EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(), HalfNumElems);
72767286
SDValue HalfLD =
7277-
EltsFromConsecutiveLoads(HalfVT, Elts.drop_back(HalfNumElems), DL,
7287+
EltsFromConsecutiveLoads<T>(HalfVT, Elts.drop_back(HalfNumElems), DL,
72787288
DAG, Subtarget, IsAfterLegalize);
72797289
if (HalfLD)
72807290
return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT),
@@ -7351,7 +7361,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
73517361
EVT::getVectorVT(*DAG.getContext(), RepeatVT.getScalarType(),
73527362
VT.getSizeInBits() / ScalarSize);
73537363
if (TLI.isTypeLegal(BroadcastVT)) {
7354-
if (SDValue RepeatLoad = EltsFromConsecutiveLoads(
7364+
if (SDValue RepeatLoad = EltsFromConsecutiveLoads<T>(
73557365
RepeatVT, RepeatedLoads, DL, DAG, Subtarget, IsAfterLegalize)) {
73567366
SDValue Broadcast = RepeatLoad;
73577367
if (RepeatSize > ScalarSize) {
@@ -7392,7 +7402,7 @@ static SDValue combineToConsecutiveLoads(EVT VT, SDValue Op, const SDLoc &DL,
73927402
return SDValue();
73937403
}
73947404
assert(Elts.size() == VT.getVectorNumElements());
7395-
return EltsFromConsecutiveLoads(VT, Elts, DL, DAG, Subtarget,
7405+
return EltsFromConsecutiveLoads<LoadSDNode>(VT, Elts, DL, DAG, Subtarget,
73967406
IsAfterLegalize);
73977407
}
73987408

@@ -9247,8 +9257,12 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
92479257
{
92489258
SmallVector<SDValue, 64> Ops(Op->op_begin(), Op->op_begin() + NumElems);
92499259
if (SDValue LD =
9250-
EltsFromConsecutiveLoads(VT, Ops, dl, DAG, Subtarget, false))
9260+
EltsFromConsecutiveLoads<LoadSDNode>(VT, Ops, dl, DAG, Subtarget, false)) {
92519261
return LD;
9262+
} else if (SDValue LD =
9263+
EltsFromConsecutiveLoads<AtomicSDNode>(VT, Ops, dl, DAG, Subtarget, false)) {
9264+
return LD;
9265+
}
92529266
}
92539267

92549268
// If this is a splat of pairs of 32-bit elements, we can use a narrower
@@ -57934,7 +57948,7 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
5793457948
*FirstLd->getMemOperand(), &Fast) &&
5793557949
Fast) {
5793657950
if (SDValue Ld =
57937-
EltsFromConsecutiveLoads(VT, Ops, DL, DAG, Subtarget, false))
57951+
EltsFromConsecutiveLoads<LoadSDNode>(VT, Ops, DL, DAG, Subtarget, false))
5793857952
return Ld;
5793957953
}
5794057954
}

Diff for: llvm/test/CodeGen/X86/atomic-load-store.ll

+36
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,24 @@ define <2 x float> @atomic_vec2_float_align(ptr %x) {
195195
ret <2 x float> %ret
196196
}
197197

198+
define <2 x half> @atomic_vec2_half(ptr %x) {
199+
; CHECK-LABEL: atomic_vec2_half:
200+
; CHECK: ## %bb.0:
201+
; CHECK-NEXT: movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
202+
; CHECK-NEXT: retq
203+
%ret = load atomic <2 x half>, ptr %x acquire, align 4
204+
ret <2 x half> %ret
205+
}
206+
207+
define <2 x bfloat> @atomic_vec2_bfloat(ptr %x) {
208+
; CHECK-LABEL: atomic_vec2_bfloat:
209+
; CHECK: ## %bb.0:
210+
; CHECK-NEXT: movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
211+
; CHECK-NEXT: retq
212+
%ret = load atomic <2 x bfloat>, ptr %x acquire, align 4
213+
ret <2 x bfloat> %ret
214+
}
215+
198216
define <1 x ptr> @atomic_vec1_ptr(ptr %x) nounwind {
199217
; CHECK3-LABEL: atomic_vec1_ptr:
200218
; CHECK3: ## %bb.0:
@@ -367,6 +385,24 @@ define <4 x i16> @atomic_vec4_i16(ptr %x) nounwind {
367385
ret <4 x i16> %ret
368386
}
369387

388+
define <4 x half> @atomic_vec4_half(ptr %x) nounwind {
389+
; CHECK-LABEL: atomic_vec4_half:
390+
; CHECK: ## %bb.0:
391+
; CHECK-NEXT: movsd {{.*#+}} xmm0 = mem[0],zero
392+
; CHECK-NEXT: retq
393+
%ret = load atomic <4 x half>, ptr %x acquire, align 8
394+
ret <4 x half> %ret
395+
}
396+
397+
define <4 x bfloat> @atomic_vec4_bfloat(ptr %x) nounwind {
398+
; CHECK-LABEL: atomic_vec4_bfloat:
399+
; CHECK: ## %bb.0:
400+
; CHECK-NEXT: movsd {{.*#+}} xmm0 = mem[0],zero
401+
; CHECK-NEXT: retq
402+
%ret = load atomic <4 x bfloat>, ptr %x acquire, align 8
403+
ret <4 x bfloat> %ret
404+
}
405+
370406
define <4 x float> @atomic_vec4_float_align(ptr %x) nounwind {
371407
; CHECK-LABEL: atomic_vec4_float_align:
372408
; CHECK: ## %bb.0:

0 commit comments

Comments
 (0)