Skip to content

Commit 0c71559

Browse files
[LVL][CSA] Legalize CSA vectorization
1 parent e741182 commit 0c71559

File tree

7 files changed

+72
-4
lines changed

7 files changed

+72
-4
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

+9
Original file line numberDiff line numberDiff line change
@@ -1828,6 +1828,10 @@ class TargetTransformInfo {
18281828
: EVLParamStrategy(EVLParamStrategy), OpStrategy(OpStrategy) {}
18291829
};
18301830

1831+
/// \returns true if the loop vectorizer should vectorize conditional
1832+
/// scalar assignments for the target.
1833+
bool enableCSAVectorization() const;
1834+
18311835
/// \returns How the target needs this vector-predicated operation to be
18321836
/// transformed.
18331837
VPLegalization getVPLegalizationStrategy(const VPIntrinsic &PI) const;
@@ -2266,6 +2270,7 @@ class TargetTransformInfo::Concept {
22662270
SmallVectorImpl<Use *> &OpsToSink) const = 0;
22672271

22682272
virtual bool isVectorShiftByScalarCheap(Type *Ty) const = 0;
2273+
virtual bool enableCSAVectorization() const = 0;
22692274
virtual VPLegalization
22702275
getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
22712276
virtual bool hasArmWideBranch(bool Thumb) const = 0;
@@ -3077,6 +3082,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
30773082
return Impl.isVectorShiftByScalarCheap(Ty);
30783083
}
30793084

3085+
bool enableCSAVectorization() const override {
3086+
return Impl.enableCSAVectorization();
3087+
}
3088+
30803089
VPLegalization
30813090
getVPLegalizationStrategy(const VPIntrinsic &PI) const override {
30823091
return Impl.getVPLegalizationStrategy(PI);

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

+2
Original file line numberDiff line numberDiff line change
@@ -1016,6 +1016,8 @@ class TargetTransformInfoImplBase {
10161016

10171017
bool isVectorShiftByScalarCheap(Type *Ty) const { return false; }
10181018

1019+
bool enableCSAVectorization() const { return false; }
1020+
10191021
TargetTransformInfo::VPLegalization
10201022
getVPLegalizationStrategy(const VPIntrinsic &PI) const {
10211023
return TargetTransformInfo::VPLegalization(

llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h

+18
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#define LLVM_TRANSFORMS_VECTORIZE_LOOPVECTORIZATIONLEGALITY_H
2828

2929
#include "llvm/ADT/MapVector.h"
30+
#include "llvm/Analysis/CSADescriptors.h"
3031
#include "llvm/Analysis/LoopAccessAnalysis.h"
3132
#include "llvm/Support/TypeSize.h"
3233
#include "llvm/Transforms/Utils/LoopUtils.h"
@@ -269,6 +270,10 @@ class LoopVectorizationLegality {
269270
/// induction descriptor.
270271
using InductionList = MapVector<PHINode *, InductionDescriptor>;
271272

273+
/// CSAList contains the CSA descriptors for all the CSAs that were found
274+
/// in the loop, rooted by their phis.
275+
using CSAList = MapVector<PHINode *, CSADescriptor>;
276+
272277
/// RecurrenceSet contains the phi nodes that are recurrences other than
273278
/// inductions and reductions.
274279
using RecurrenceSet = SmallPtrSet<const PHINode *, 8>;
@@ -321,6 +326,12 @@ class LoopVectorizationLegality {
321326
/// Returns True if V is a Phi node of an induction variable in this loop.
322327
bool isInductionPhi(const Value *V) const;
323328

329+
/// Returns the CSAs found in the loop.
330+
const CSAList &getCSAs() const { return CSAs; }
331+
332+
/// Returns true if Phi is the root of a CSA in the loop.
333+
bool isCSAPhi(PHINode *Phi) const { return CSAs.count(Phi) != 0; }
334+
324335
/// Returns a pointer to the induction descriptor, if \p Phi is an integer or
325336
/// floating point induction.
326337
const InductionDescriptor *getIntOrFpInductionDescriptor(PHINode *Phi) const;
@@ -550,6 +561,10 @@ class LoopVectorizationLegality {
550561
void addInductionPhi(PHINode *Phi, const InductionDescriptor &ID,
551562
SmallPtrSetImpl<Value *> &AllowedExit);
552563

564+
// Updates the vetorization state by adding \p Phi to the CSA list.
565+
void addCSAPhi(PHINode *Phi, const CSADescriptor &CSADesc,
566+
SmallPtrSetImpl<Value *> &AllowedExit);
567+
553568
/// The loop that we evaluate.
554569
Loop *TheLoop;
555570

@@ -594,6 +609,9 @@ class LoopVectorizationLegality {
594609
/// variables can be pointers.
595610
InductionList Inductions;
596611

612+
/// Holds the conditional scalar assignments
613+
CSAList CSAs;
614+
597615
/// Holds all the casts that participate in the update chain of the induction
598616
/// variables, and that have been proven to be redundant (possibly under a
599617
/// runtime guard). These casts can be ignored when creating the vectorized

llvm/lib/Analysis/TargetTransformInfo.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -1351,6 +1351,10 @@ bool TargetTransformInfo::preferEpilogueVectorization() const {
13511351
return TTIImpl->preferEpilogueVectorization();
13521352
}
13531353

1354+
bool TargetTransformInfo::enableCSAVectorization() const {
1355+
return TTIImpl->enableCSAVectorization();
1356+
}
1357+
13541358
TargetTransformInfo::VPLegalization
13551359
TargetTransformInfo::getVPLegalizationStrategy(const VPIntrinsic &VPI) const {
13561360
return TTIImpl->getVPLegalizationStrategy(VPI);

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -2361,6 +2361,11 @@ bool RISCVTTIImpl::isLegalMaskedExpandLoad(Type *DataTy, Align Alignment) {
23612361
return true;
23622362
}
23632363

2364+
bool RISCVTTIImpl::enableCSAVectorization() const {
2365+
return ST->hasVInstructions() &&
2366+
ST->getProcFamily() == RISCVSubtarget::SiFive7;
2367+
}
2368+
23642369
bool RISCVTTIImpl::isLegalMaskedCompressStore(Type *DataTy, Align Alignment) {
23652370
auto *VTy = dyn_cast<VectorType>(DataTy);
23662371
if (!VTy || VTy->isScalableTy())

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h

+4
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,10 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
306306
return TLI->isVScaleKnownToBeAPowerOfTwo();
307307
}
308308

309+
/// \returns true if the loop vectorizer should vectorize conditional
310+
/// scalar assignments for the target.
311+
bool enableCSAVectorization() const;
312+
309313
/// \returns How the target needs this vector-predicated operation to be
310314
/// transformed.
311315
TargetTransformInfo::VPLegalization

llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp

+30-4
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ static cl::opt<bool> EnableHistogramVectorization(
8383
"enable-histogram-loop-vectorization", cl::init(false), cl::Hidden,
8484
cl::desc("Enables autovectorization of some loops containing histograms"));
8585

86+
static cl::opt<bool>
87+
EnableCSA("enable-csa-vectorization", cl::init(false), cl::Hidden,
88+
cl::desc("Control whether CSA loop vectorization is enabled"));
89+
8690
/// Maximum vectorization interleave count.
8791
static const unsigned MaxInterleaveFactor = 16;
8892

@@ -750,6 +754,15 @@ bool LoopVectorizationLegality::setupOuterLoopInductions() {
750754
return llvm::all_of(Header->phis(), IsSupportedPhi);
751755
}
752756

757+
void LoopVectorizationLegality::addCSAPhi(
758+
PHINode *Phi, const CSADescriptor &CSADesc,
759+
SmallPtrSetImpl<Value *> &AllowedExit) {
760+
assert(CSADesc.isValid() && "Expected Valid CSADescriptor");
761+
LLVM_DEBUG(dbgs() << "LV: found legal CSA opportunity" << *Phi << "\n");
762+
AllowedExit.insert(Phi);
763+
CSAs.insert({Phi, CSADesc});
764+
}
765+
753766
/// Checks if a function is scalarizable according to the TLI, in
754767
/// the sense that it should be vectorized and then expanded in
755768
/// multiple scalar calls. This is represented in the
@@ -867,14 +880,23 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
867880
continue;
868881
}
869882

870-
// As a last resort, coerce the PHI to a AddRec expression
871-
// and re-try classifying it a an induction PHI.
883+
// Try to coerce the PHI to a AddRec expression and re-try classifying
884+
// it a an induction PHI.
872885
if (InductionDescriptor::isInductionPHI(Phi, TheLoop, PSE, ID, true) &&
873886
!IsDisallowedStridedPointerInduction(ID)) {
874887
addInductionPhi(Phi, ID, AllowedExit);
875888
continue;
876889
}
877890

891+
// Check if the PHI can be classified as a CSA PHI.
892+
if (EnableCSA || (TTI->enableCSAVectorization() &&
893+
EnableCSA.getNumOccurrences() == 0)) {
894+
if (auto CSADesc = CSADescriptor::isCSAPhi(Phi, TheLoop)) {
895+
addCSAPhi(Phi, CSADesc, AllowedExit);
896+
continue;
897+
}
898+
}
899+
878900
reportVectorizationFailure("Found an unidentified PHI",
879901
"value that could not be identified as "
880902
"reduction is used outside the loop",
@@ -1858,11 +1880,15 @@ bool LoopVectorizationLegality::canFoldTailByMasking() const {
18581880
for (const auto &Reduction : getReductionVars())
18591881
ReductionLiveOuts.insert(Reduction.second.getLoopExitInstr());
18601882

1883+
SmallPtrSet<const Value *, 8> CSALiveOuts;
1884+
for (const auto &CSA : getCSAs())
1885+
CSALiveOuts.insert(CSA.second.getAssignment());
1886+
18611887
// TODO: handle non-reduction outside users when tail is folded by masking.
18621888
for (auto *AE : AllowedExit) {
18631889
// Check that all users of allowed exit values are inside the loop or
1864-
// are the live-out of a reduction.
1865-
if (ReductionLiveOuts.count(AE))
1890+
// are the live-out of a reduction or a CSA
1891+
if (ReductionLiveOuts.count(AE) || CSALiveOuts.count(AE))
18661892
continue;
18671893
for (User *U : AE->users()) {
18681894
Instruction *UI = cast<Instruction>(U);

0 commit comments

Comments
 (0)