Commit a14a025
[SelectionDAG][X86] Split via Concat <n x T> vector types for atomic load
Vector types that aren't widened are 'split' via CONCAT_VECTORS so that a single ATOMIC_LOAD is issued for the entire vector at once. This change uses the load vectorization infrastructure in SelectionDAG to group the vectors, which enables SelectionDAG to translate vectors with element types bfloat and half. commit-id:3a045357
1 parent a847ecf commit a14a025
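
For illustration only (not part of the patch): a minimal sketch of the DAG the new split path builds for an atomic load of <2 x half>, using the same SelectionDAG calls as SplitVecRes_ATOMIC_LOAD in the diff below; LD and dl stand in for the atomic node and its debug location.

// Sketch only: one wide ATOMIC_LOAD of v2i16, then per-element extract +
// bitcast, reassembled with CONCAT_VECTORS into the original v2f16 result.
SDValue Atomic = DAG.getAtomic(ISD::ATOMIC_LOAD, dl, MVT::v2i16, MVT::v2i16,
                               LD->getChain(), LD->getBasePtr(),
                               LD->getMemOperand());
SmallVector<SDValue, 2> Ops;
for (unsigned i = 0; i != 2; ++i) {
  SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i16, Atomic,
                            DAG.getVectorIdxConstant(i, dl));
  Ops.push_back(DAG.getBitcast(MVT::v1f16, Elt)); // <1 x half> pieces
}
SDValue Result = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v2f16, Ops);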

8 files changed (+147, -35 lines)

llvm/include/llvm/CodeGen/SelectionDAG.h (+3, -1)

@@ -1837,7 +1837,7 @@ class SelectionDAG {
   /// chain to the token factor. This ensures that the new memory node will have
   /// the same relative memory dependency position as the old load. Returns the
   /// new merged load chain.
-  SDValue makeEquivalentMemoryOrdering(LoadSDNode *OldLoad, SDValue NewMemOp);
+  SDValue makeEquivalentMemoryOrdering(MemSDNode *OldLoad, SDValue NewMemOp);
 
   /// Topological-sort the AllNodes list and a
   /// assign a unique node id for each node in the DAG based on their
@@ -2263,6 +2263,8 @@ class SelectionDAG {
   /// location that the 'Base' load is loading from.
   bool areNonVolatileConsecutiveLoads(LoadSDNode *LD, LoadSDNode *Base,
                                       unsigned Bytes, int Dist) const;
+  bool areNonVolatileConsecutiveLoads(AtomicSDNode *LD, AtomicSDNode *Base,
+                                      unsigned Bytes, int Dist) const;
 
   /// Infer alignment of a load / store address. Return std::nullopt if it
   /// cannot be inferred.
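
Not part of the diff: a hedged sketch of why the first parameter is relaxed to MemSDNode — an AtomicSDNode can now be handed to the helper directly when a wider merged load replaces it. OldAtomic, NewLd, and the i32 type below are illustrative.

// Illustrative only: splice the old atomic load's chain users onto a new,
// wider plain load (AtomicSDNode derives from MemSDNode, so this now compiles).
SDValue NewLd = DAG.getLoad(MVT::i32, DL, OldAtomic->getChain(),
                            OldAtomic->getBasePtr(),
                            OldAtomic->getPointerInfo());
SDValue NewChain = DAG.makeEquivalentMemoryOrdering(OldAtomic, NewLd);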

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+1)

@@ -948,6 +948,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   void SplitVecRes_FPOp_MultiType(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_IS_FPCLASS(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_INSERT_VECTOR_ELT(SDNode *N, SDValue &Lo, SDValue &Hi);
+  void SplitVecRes_ATOMIC_LOAD(AtomicSDNode *LD);
   void SplitVecRes_LOAD(LoadSDNode *LD, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_VP_LOAD(VPLoadSDNode *LD, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_VP_STRIDED_LOAD(VPStridedLoadSDNode *SLD, SDValue &Lo,

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+31)

@@ -1152,6 +1152,9 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
     SplitVecRes_STEP_VECTOR(N, Lo, Hi);
     break;
   case ISD::SIGN_EXTEND_INREG: SplitVecRes_InregOp(N, Lo, Hi); break;
+  case ISD::ATOMIC_LOAD:
+    SplitVecRes_ATOMIC_LOAD(cast<AtomicSDNode>(N));
+    break;
   case ISD::LOAD:
     SplitVecRes_LOAD(cast<LoadSDNode>(N), Lo, Hi);
     break;
@@ -1395,6 +1398,34 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
   SetSplitVector(SDValue(N, ResNo), Lo, Hi);
 }
 
+void DAGTypeLegalizer::SplitVecRes_ATOMIC_LOAD(AtomicSDNode *LD) {
+  SDLoc dl(LD);
+
+  EVT MemoryVT = LD->getMemoryVT();
+  unsigned NumElts = MemoryVT.getVectorMinNumElements();
+
+  EVT IntMemoryVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, NumElts);
+  EVT ElemVT = EVT::getVectorVT(*DAG.getContext(),
+                                MemoryVT.getVectorElementType(), 1);
+
+  // Create a single atomic to load all the elements at once.
+  SDValue Atomic = DAG.getAtomic(ISD::ATOMIC_LOAD, dl, IntMemoryVT, IntMemoryVT,
+                                 LD->getChain(), LD->getBasePtr(),
+                                 LD->getMemOperand());
+
+  // Instead of splitting, put all the elements back into a vector.
+  SmallVector<SDValue, 4> Ops;
+  for (unsigned i = 0; i < NumElts; ++i) {
+    SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i16, Atomic,
+                              DAG.getVectorIdxConstant(i, dl));
+    Elt = DAG.getBitcast(ElemVT, Elt);
+    Ops.push_back(Elt);
+  }
+  SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, dl, MemoryVT, Ops);
+
+  ReplaceValueWith(SDValue(LD, 0), Concat);
+}
+
 void DAGTypeLegalizer::IncrementPointer(MemSDNode *N, EVT MemVT,
                                         MachinePointerInfo &MPI, SDValue &Ptr,
                                         uint64_t *ScaledOffset) {

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+22, -2)

@@ -12167,7 +12167,7 @@ SDValue SelectionDAG::makeEquivalentMemoryOrdering(SDValue OldChain,
   return TokenFactor;
 }
 
-SDValue SelectionDAG::makeEquivalentMemoryOrdering(LoadSDNode *OldLoad,
+SDValue SelectionDAG::makeEquivalentMemoryOrdering(MemSDNode *OldLoad,
                                                    SDValue NewMemOp) {
   assert(isa<MemSDNode>(NewMemOp.getNode()) && "Expected a memop node");
   SDValue OldChain = SDValue(OldLoad, 1);
@@ -12879,13 +12879,33 @@ std::pair<SDValue, SDValue> SelectionDAG::UnrollVectorOverflowOp(
                         getBuildVector(NewOvVT, dl, OvScalars));
 }
 
+bool SelectionDAG::areNonVolatileConsecutiveLoads(AtomicSDNode *LD,
+                                                  AtomicSDNode *Base,
+                                                  unsigned Bytes,
+                                                  int Dist) const {
+  if (LD->isVolatile() || Base->isVolatile())
+    return false;
+  if (LD->getChain() != Base->getChain())
+    return false;
+  EVT VT = LD->getMemoryVT();
+  if (VT.getSizeInBits() / 8 != Bytes)
+    return false;
+
+  auto BaseLocDecomp = BaseIndexOffset::match(Base, *this);
+  auto LocDecomp = BaseIndexOffset::match(LD, *this);
+
+  int64_t Offset = 0;
+  if (BaseLocDecomp.equalBaseIndex(LocDecomp, *this, Offset))
+    return (Dist * (int64_t)Bytes == Offset);
+  return false;
+}
+
 bool SelectionDAG::areNonVolatileConsecutiveLoads(LoadSDNode *LD,
                                                   LoadSDNode *Base,
                                                   unsigned Bytes,
                                                   int Dist) const {
   if (LD->isVolatile() || Base->isVolatile())
     return false;
-  // TODO: probably too restrictive for atomics, revisit
   if (!LD->isSimple())
     return false;
   if (LD->isIndexed() || Base->isIndexed())
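
Not part of the diff: a small usage sketch of the new AtomicSDNode overload, assuming Elts holds two adjacent atomic i16 element loads; the names are hypothetical.

// Illustrative only: does Ld read the 2 bytes immediately after Base?
AtomicSDNode *Base = cast<AtomicSDNode>(Elts[0].getNode());
AtomicSDNode *Ld = cast<AtomicSDNode>(Elts[1].getNode());
bool Consecutive =
    DAG.areNonVolatileConsecutiveLoads(Ld, Base, /*Bytes=*/2, /*Dist=*/1);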

llvm/lib/CodeGen/SelectionDAG/SelectionDAGAddressAnalysis.cpp (+17, -13)

@@ -194,8 +194,8 @@ bool BaseIndexOffset::contains(const SelectionDAG &DAG, int64_t BitSize,
   return false;
 }
 
-/// Parses tree in Ptr for base, index, offset addresses.
-static BaseIndexOffset matchLSNode(const LSBaseSDNode *N,
+template <typename T>
+static BaseIndexOffset matchSDNode(const T *N,
                                    const SelectionDAG &DAG) {
   SDValue Ptr = N->getBasePtr();
 
@@ -206,16 +206,18 @@ static BaseIndexOffset matchLSNode(const LSBaseSDNode *N,
   bool IsIndexSignExt = false;
 
   // pre-inc/pre-dec ops are components of EA.
-  if (N->getAddressingMode() == ISD::PRE_INC) {
-    if (auto *C = dyn_cast<ConstantSDNode>(N->getOffset()))
-      Offset += C->getSExtValue();
-    else // If unknown, give up now.
-      return BaseIndexOffset(SDValue(), SDValue(), 0, false);
-  } else if (N->getAddressingMode() == ISD::PRE_DEC) {
-    if (auto *C = dyn_cast<ConstantSDNode>(N->getOffset()))
-      Offset -= C->getSExtValue();
-    else // If unknown, give up now.
-      return BaseIndexOffset(SDValue(), SDValue(), 0, false);
+  if constexpr (std::is_same_v<T, LSBaseSDNode>) {
+    if (N->getAddressingMode() == ISD::PRE_INC) {
+      if (auto *C = dyn_cast<ConstantSDNode>(N->getOffset()))
+        Offset += C->getSExtValue();
+      else // If unknown, give up now.
+        return BaseIndexOffset(SDValue(), SDValue(), 0, false);
+    } else if (N->getAddressingMode() == ISD::PRE_DEC) {
+      if (auto *C = dyn_cast<ConstantSDNode>(N->getOffset()))
+        Offset -= C->getSExtValue();
+      else // If unknown, give up now.
+        return BaseIndexOffset(SDValue(), SDValue(), 0, false);
+    }
   }
 
   // Consume constant adds & ors with appropriate masking.
@@ -300,8 +302,10 @@ static BaseIndexOffset matchLSNode(const LSBaseSDNode *N,
 
 BaseIndexOffset BaseIndexOffset::match(const SDNode *N,
                                        const SelectionDAG &DAG) {
+  if (const auto *AN = dyn_cast<AtomicSDNode>(N))
+    return matchSDNode(AN, DAG);
   if (const auto *LS0 = dyn_cast<LSBaseSDNode>(N))
-    return matchLSNode(LS0, DAG);
+    return matchSDNode(LS0, DAG);
   if (const auto *LN = dyn_cast<LifetimeSDNode>(N)) {
     if (LN->hasOffset())
       return BaseIndexOffset(LN->getOperand(1), SDValue(), LN->getOffset(),

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+5, -1)

@@ -5218,7 +5218,11 @@ void SelectionDAGBuilder::visitAtomicLoad(const LoadInst &I) {
     L = DAG.getPtrExtOrTrunc(L, dl, VT);
 
   setValue(&I, L);
-  DAG.setRoot(OutChain);
+
+  if (VT.isVector())
+    DAG.setRoot(InChain);
+  else
+    DAG.setRoot(OutChain);
 }
 
 void SelectionDAGBuilder::visitAtomicStore(const StoreInst &I) {

llvm/lib/Target/X86/X86ISelLowering.cpp (+32, -18)

@@ -7071,14 +7071,23 @@ static SDValue LowerAsSplatVectorLoad(SDValue SrcOp, MVT VT, const SDLoc &dl,
 }
 
 // Recurse to find a LoadSDNode source and the accumulated ByteOffest.
-static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
-  if (ISD::isNON_EXTLoad(Elt.getNode())) {
-    auto *BaseLd = cast<LoadSDNode>(Elt);
-    if (!BaseLd->isSimple())
-      return false;
-    Ld = BaseLd;
-    ByteOffset = 0;
-    return true;
+template <typename T>
+static bool findEltLoadSrc(SDValue Elt, T *&Ld, int64_t &ByteOffset) {
+  if constexpr (std::is_same_v<T, AtomicSDNode>) {
+    if (auto *BaseLd = dyn_cast<AtomicSDNode>(Elt)) {
+      Ld = BaseLd;
+      ByteOffset = 0;
+      return true;
+    }
+  } else if constexpr (std::is_same_v<T, LoadSDNode>) {
+    if (ISD::isNON_EXTLoad(Elt.getNode())) {
+      auto *BaseLd = cast<LoadSDNode>(Elt);
+      if (!BaseLd->isSimple())
+        return false;
+      Ld = BaseLd;
+      ByteOffset = 0;
+      return true;
+    }
   }
 
   switch (Elt.getOpcode()) {
@@ -7118,6 +7127,7 @@ static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
 /// a build_vector or insert_subvector whose loaded operands are 'Elts'.
 ///
 /// Example: <load i32 *a, load i32 *a+4, zero, undef> -> zextload a
+template <typename T>
 static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
                                         const SDLoc &DL, SelectionDAG &DAG,
                                         const X86Subtarget &Subtarget,
@@ -7132,7 +7142,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
   APInt ZeroMask = APInt::getZero(NumElems);
   APInt UndefMask = APInt::getZero(NumElems);
 
-  SmallVector<LoadSDNode*, 8> Loads(NumElems, nullptr);
+  SmallVector<T*, 8> Loads(NumElems, nullptr);
   SmallVector<int64_t, 8> ByteOffsets(NumElems, 0);
 
   // For each element in the initializer, see if we've found a load, zero or an
@@ -7182,7 +7192,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
   EVT EltBaseVT = EltBase.getValueType();
   assert(EltBaseVT.getSizeInBits() == EltBaseVT.getStoreSizeInBits() &&
          "Register/Memory size mismatch");
-  LoadSDNode *LDBase = Loads[FirstLoadedElt];
+  T *LDBase = Loads[FirstLoadedElt];
   assert(LDBase && "Did not find base load for merging consecutive loads");
   unsigned BaseSizeInBits = EltBaseVT.getStoreSizeInBits();
   unsigned BaseSizeInBytes = BaseSizeInBits / 8;
@@ -7196,8 +7206,8 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
 
   // Check to see if the element's load is consecutive to the base load
   // or offset from a previous (already checked) load.
-  auto CheckConsecutiveLoad = [&](LoadSDNode *Base, int EltIdx) {
-    LoadSDNode *Ld = Loads[EltIdx];
+  auto CheckConsecutiveLoad = [&](T *Base, int EltIdx) {
+    T *Ld = Loads[EltIdx];
     int64_t ByteOffset = ByteOffsets[EltIdx];
     if (ByteOffset && (ByteOffset % BaseSizeInBytes) == 0) {
       int64_t BaseIdx = EltIdx - (ByteOffset / BaseSizeInBytes);
@@ -7225,7 +7235,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
     }
   }
 
-  auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, LoadSDNode *LDBase) {
+  auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, T *LDBase) {
     auto MMOFlags = LDBase->getMemOperand()->getFlags();
     assert(LDBase->isSimple() &&
            "Cannot merge volatile or atomic loads.");
@@ -7295,7 +7305,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
     EVT HalfVT =
         EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(), HalfNumElems);
     SDValue HalfLD =
-        EltsFromConsecutiveLoads(HalfVT, Elts.drop_back(HalfNumElems), DL,
+        EltsFromConsecutiveLoads<T>(HalfVT, Elts.drop_back(HalfNumElems), DL,
                                  DAG, Subtarget, IsAfterLegalize);
     if (HalfLD)
       return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT),
@@ -7372,7 +7382,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
         EVT::getVectorVT(*DAG.getContext(), RepeatVT.getScalarType(),
                          VT.getSizeInBits() / ScalarSize);
     if (TLI.isTypeLegal(BroadcastVT)) {
-      if (SDValue RepeatLoad = EltsFromConsecutiveLoads(
+      if (SDValue RepeatLoad = EltsFromConsecutiveLoads<T>(
               RepeatVT, RepeatedLoads, DL, DAG, Subtarget, IsAfterLegalize)) {
         SDValue Broadcast = RepeatLoad;
         if (RepeatSize > ScalarSize) {
@@ -7413,7 +7423,7 @@ static SDValue combineToConsecutiveLoads(EVT VT, SDValue Op, const SDLoc &DL,
       return SDValue();
   }
   assert(Elts.size() == VT.getVectorNumElements());
-  return EltsFromConsecutiveLoads(VT, Elts, DL, DAG, Subtarget,
+  return EltsFromConsecutiveLoads<LoadSDNode>(VT, Elts, DL, DAG, Subtarget,
                                   IsAfterLegalize);
 }
 
@@ -9268,8 +9278,12 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
   {
     SmallVector<SDValue, 64> Ops(Op->op_begin(), Op->op_begin() + NumElems);
     if (SDValue LD =
-            EltsFromConsecutiveLoads(VT, Ops, dl, DAG, Subtarget, false))
+            EltsFromConsecutiveLoads<LoadSDNode>(VT, Ops, dl, DAG, Subtarget, false)) {
       return LD;
+    } else if (SDValue LD =
+                   EltsFromConsecutiveLoads<AtomicSDNode>(VT, Ops, dl, DAG, Subtarget, false)) {
+      return LD;
+    }
   }
 
   // If this is a splat of pairs of 32-bit elements, we can use a narrower
@@ -58007,7 +58021,7 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
                                     *FirstLd->getMemOperand(), &Fast) &&
         Fast) {
       if (SDValue Ld =
-              EltsFromConsecutiveLoads(VT, Ops, DL, DAG, Subtarget, false))
+              EltsFromConsecutiveLoads<LoadSDNode>(VT, Ops, DL, DAG, Subtarget, false))
         return Ld;
     }
   }
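
Not part of the diff: a hedged sketch of the templated findEltLoadSrc with T = AtomicSDNode, which EltsFromConsecutiveLoads<AtomicSDNode> now calls per element; Elt, EltLd, and ByteOffset are placeholders.

// Illustrative only: any atomic element is accepted as a load source and
// reported with a zero byte offset into it.
AtomicSDNode *EltLd = nullptr;
int64_t ByteOffset = 0;
if (findEltLoadSrc<AtomicSDNode>(Elt, EltLd, ByteOffset)) {
  // EltLd now points at the atomic node feeding this build_vector element.
}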

llvm/test/CodeGen/X86/atomic-load-store.ll (+36)

@@ -204,6 +204,24 @@ define <2 x float> @atomic_vec2_float_align(ptr %x) {
   ret <2 x float> %ret
 }
 
+define <2 x half> @atomic_vec2_half(ptr %x) {
+; CHECK-LABEL: atomic_vec2_half:
+; CHECK:       ## %bb.0:
+; CHECK-NEXT:    movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; CHECK-NEXT:    retq
+  %ret = load atomic <2 x half>, ptr %x acquire, align 4
+  ret <2 x half> %ret
+}
+
+define <2 x bfloat> @atomic_vec2_bfloat(ptr %x) {
+; CHECK-LABEL: atomic_vec2_bfloat:
+; CHECK:       ## %bb.0:
+; CHECK-NEXT:    movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; CHECK-NEXT:    retq
+  %ret = load atomic <2 x bfloat>, ptr %x acquire, align 4
+  ret <2 x bfloat> %ret
+}
+
 define <1 x ptr> @atomic_vec1_ptr(ptr %x) nounwind {
 ; CHECK3-LABEL: atomic_vec1_ptr:
 ; CHECK3:       ## %bb.0:
@@ -376,6 +394,24 @@ define <4 x i16> @atomic_vec4_i16(ptr %x) nounwind {
   ret <4 x i16> %ret
 }
 
+define <4 x half> @atomic_vec4_half(ptr %x) nounwind {
+; CHECK-LABEL: atomic_vec4_half:
+; CHECK:       ## %bb.0:
+; CHECK-NEXT:    movsd {{.*#+}} xmm0 = mem[0],zero
+; CHECK-NEXT:    retq
+  %ret = load atomic <4 x half>, ptr %x acquire, align 8
+  ret <4 x half> %ret
+}
+
+define <4 x bfloat> @atomic_vec4_bfloat(ptr %x) nounwind {
+; CHECK-LABEL: atomic_vec4_bfloat:
+; CHECK:       ## %bb.0:
+; CHECK-NEXT:    movsd {{.*#+}} xmm0 = mem[0],zero
+; CHECK-NEXT:    retq
+  %ret = load atomic <4 x bfloat>, ptr %x acquire, align 8
+  ret <4 x bfloat> %ret
+}
+
 define <4 x float> @atomic_vec4_float_align(ptr %x) nounwind {
 ; CHECK-LABEL: atomic_vec4_float_align:
 ; CHECK:       ## %bb.0: