Skip to content

Commit b690505

Browse files
committed
more advanced load-store forwarding pass
this should be able to eliminate some of the temporary storage introduced by translating shared memory
1 parent d061557 commit b690505

File tree

4 files changed

+271
-6
lines changed

4 files changed

+271
-6
lines changed

include/polygeist/Passes/Passes.h

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ std::unique_ptr<Pass> createParallelLowerPass();
2424
std::unique_ptr<Pass>
2525
createConvertPolygeistToLLVMPass(const LowerToLLVMOptions &options);
2626
std::unique_ptr<Pass> createConvertPolygeistToLLVMPass();
27+
std::unique_ptr<Pass> createTemporaryStorageEliminationPass();
2728

2829
} // namespace polygeist
2930
} // namespace mlir

include/polygeist/Passes/Passes.td

+4
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,10 @@ def RemoveTrivialUse : Pass<"trivialuse"> {
7878
let constructor = "mlir::polygeist::createRemoveTrivialUsePass()";
7979
}
8080

81+
def TemporaryStorageElimination : Pass<"temp-storage-elimination"> {
82+
let constructor = "mlir::polygeist::createTemporaryStorageEliminationPass()";
83+
}
84+
8185
def ConvertPolygeistToLLVM : Pass<"convert-polygeist-to-llvm", "mlir::ModuleOp"> {
8286
let summary = "Convert scalar and vector operations from the Standard to the "
8387
"LLVM dialect";

lib/polygeist/Passes/CMakeLists.txt

+7-6
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
add_mlir_dialect_library(MLIRPolygeistTransforms
22
AffineCFG.cpp
33
AffineReduction.cpp
4+
BarrierRemovalContinuation.cpp
45
CanonicalizeFor.cpp
6+
ConvertPolygeistToLLVM.cpp
7+
InnerSerialization.cpp
58
LoopRestructure.cpp
69
Mem2Reg.cpp
7-
ParallelLoopDistribute.cpp
8-
ParallelLICM.cpp
910
OpenMPOpt.cpp
10-
BarrierRemovalContinuation.cpp
11-
RaiseToAffine.cpp
11+
ParallelLICM.cpp
12+
ParallelLoopDistribute.cpp
1213
ParallelLower.cpp
14+
RaiseToAffine.cpp
15+
TemporaryStorageElimination.cpp
1316
TrivialUse.cpp
14-
ConvertPolygeistToLLVM.cpp
15-
InnerSerialization.cpp
1617

1718
ADDITIONAL_HEADER_DIRS
1819
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
//===- TemporaryStorageElimination.cpp - Shared memory-like elimination ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "PassDetails.h"
10+
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
11+
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
12+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
13+
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
14+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
15+
#include "mlir/IR/Operation.h"
16+
#include "mlir/IR/PatternMatch.h"
17+
#include "mlir/Pass/Pass.h"
18+
#include "mlir/Support/TypeID.h"
19+
#include "llvm/Support/Debug.h"
20+
21+
using namespace mlir;
22+
23+
#define DEBUG_TYPE "tmp-storage-elimination"
24+
#define DBGS() llvm::dbgs() << "[" << DEBUG_TYPE << "] "
25+
26+
namespace {
27+
Block *getCommonAncestorBlock(Operation *first, Operation *second) {
28+
Region *firstRegion = first->getParentRegion();
29+
Region *secondRegion = second->getParentRegion();
30+
if (firstRegion->isAncestor(secondRegion))
31+
return first->getBlock();
32+
if (secondRegion->isAncestor(firstRegion))
33+
return second->getBlock();
34+
35+
for (Region *region = firstRegion->getParentRegion(); region != nullptr;
36+
region = region->getParentRegion()) {
37+
if (region->isAncestor(secondRegion)) {
38+
if (!llvm::hasSingleElement(*region))
39+
return nullptr;
40+
return &region->getBlocks().front();
41+
}
42+
}
43+
return nullptr;
44+
}
45+
46+
AffineStoreOp findWriter(AffineLoadOp loadOp, Operation *root) {
47+
// Find the stores to the same memref.
48+
AffineStoreOp candidateStoreOp = nullptr;
49+
WalkResult result = root->walk([&](AffineStoreOp storeOp) {
50+
if (loadOp.getMemRef() != storeOp.getMemRef())
51+
return WalkResult::advance();
52+
if (candidateStoreOp)
53+
return WalkResult::interrupt();
54+
candidateStoreOp = storeOp;
55+
return WalkResult::advance();
56+
});
57+
58+
// If there's no or more than one writer, bail out.
59+
if (result.wasInterrupted() || !candidateStoreOp) {
60+
LLVM_DEBUG(DBGS() << "could not find the single writer\n");
61+
return AffineStoreOp();
62+
}
63+
64+
// Check that the store happens before the load.
65+
Block *commonParent = getCommonAncestorBlock(candidateStoreOp, loadOp);
66+
if (!commonParent) {
67+
LLVM_DEBUG(
68+
DBGS() << "could not find a common parent between load and store\n");
69+
return AffineStoreOp();
70+
}
71+
72+
if (!commonParent->findAncestorOpInBlock(*candidateStoreOp)
73+
->isBeforeInBlock(commonParent->findAncestorOpInBlock(*loadOp))) {
74+
LLVM_DEBUG(DBGS() << "the store does not precede the load\n");
75+
return AffineStoreOp();
76+
}
77+
78+
FlatAffineRelation loadRelation, storeRelation;
79+
if (failed(MemRefAccess(loadOp).getAccessRelation(loadRelation)) ||
80+
failed(MemRefAccess(candidateStoreOp).getAccessRelation(storeRelation))) {
81+
LLVM_DEBUG(DBGS() << "could not construct affine access relations\n");
82+
return AffineStoreOp();
83+
}
84+
if (!loadRelation.getRangeSet().isSubsetOf(storeRelation.getRangeSet())) {
85+
LLVM_DEBUG(
86+
DBGS()
87+
<< "the set of loaded values is not a subset of written values\n");
88+
return AffineStoreOp();
89+
}
90+
91+
return candidateStoreOp;
92+
}
93+
94+
AffineLoadOp findStoredValueLoad(AffineStoreOp storeOp) {
95+
return storeOp.getValueToStore().getDefiningOp<AffineLoadOp>();
96+
}
97+
98+
bool hasInterferringWrite(AffineLoadOp loadOp, AffineLoadOp originalLoadOp,
99+
Operation *root) {
100+
WalkResult result = root->walk([&](AffineStoreOp storeOp) {
101+
// TODO: don't assume no-alias.
102+
if (storeOp.getMemRef() != originalLoadOp.getMemRef())
103+
return WalkResult::advance();
104+
105+
// TODO: check if the store may happen before originalLoadOp and storeOp.
106+
// For now, conservatively assume it may.
107+
FlatAffineRelation loadRelation, storeRelation;
108+
if (failed(MemRefAccess(originalLoadOp).getAccessRelation(loadRelation)) ||
109+
failed(MemRefAccess(storeOp).getAccessRelation(storeRelation))) {
110+
LLVM_DEBUG(DBGS() << "could not construct affine access relations in "
111+
"interference analysis\n");
112+
return WalkResult::interrupt();
113+
}
114+
115+
if (!storeRelation.getRangeSet()
116+
.intersect(loadRelation.getRangeSet())
117+
.isEmpty()) {
118+
LLVM_DEBUG(DBGS() << "found interferring store: " << *storeOp << "\n");
119+
return WalkResult::interrupt();
120+
}
121+
122+
return WalkResult::advance();
123+
});
124+
125+
return result.wasInterrupted();
126+
}
127+
128+
AffineExpr tryExtractAffineExpr(const FlatAffineRelation &relation,
129+
unsigned rangeDim, MLIRContext *ctx) {
130+
std::unique_ptr<FlatAffineValueConstraints> clone = relation.clone();
131+
132+
clone->projectOut(relation.getNumDomainDims(), rangeDim);
133+
clone->projectOut(relation.getNumDomainDims() + 1,
134+
relation.getNumRangeDims() - rangeDim - 1);
135+
if (clone->getNumEqualities() != 1)
136+
return AffineExpr();
137+
138+
// TODO: support for local ids via mods.
139+
ArrayRef<int64_t> eqCoeffs = clone->getEquality(0);
140+
if (llvm::any_of(eqCoeffs.slice(relation.getNumDomainDims() + 1,
141+
relation.getNumLocalIds()),
142+
[](int64_t coeff) { return coeff != 0; })) {
143+
return AffineExpr();
144+
}
145+
146+
AffineExpr expr = getAffineConstantExpr(eqCoeffs.back(), ctx);
147+
for (unsigned i = 0, e = relation.getNumDomainDims(); i != e; ++i) {
148+
expr = expr +
149+
getAffineConstantExpr(eqCoeffs[i], ctx) * getAffineDimExpr(i, ctx);
150+
}
151+
for (unsigned i = 0, e = relation.getNumSymbolIds(); i != e; ++i) {
152+
expr = expr + getAffineConstantExpr(
153+
eqCoeffs[relation.getNumDomainDims() + 1 + i], ctx) *
154+
getAffineSymbolExpr(i, ctx);
155+
}
156+
return expr;
157+
}
158+
159+
AffineMap tryExtractAffineMap(const FlatAffineRelation &relation,
160+
MLIRContext *ctx) {
161+
SmallVector<AffineExpr> exprs;
162+
for (unsigned i = 0, e = relation.getNumRangeDims(); i != e; ++i) {
163+
exprs.push_back(tryExtractAffineExpr(relation, i, ctx));
164+
if (!exprs.back())
165+
return AffineMap();
166+
}
167+
return AffineMap::get(relation.getNumDomainDims(), relation.getNumSymbolIds(),
168+
exprs, ctx);
169+
}
170+
171+
void loadStoreForwarding(Operation *root) {
172+
root->walk([root](AffineLoadOp loadOp) {
173+
LLVM_DEBUG(DBGS() << "-----------------------------------------\n");
174+
LLVM_DEBUG(DBGS() << "considering " << *loadOp << "\n");
175+
AffineStoreOp storeOp = findWriter(loadOp, root);
176+
if (!storeOp)
177+
return;
178+
179+
AffineLoadOp originalLoadOp = findStoredValueLoad(storeOp);
180+
if (!originalLoadOp)
181+
return;
182+
183+
if (hasInterferringWrite(loadOp, originalLoadOp, root))
184+
return;
185+
186+
// Replace the load, need the index remapping.
187+
188+
// LLoops -> SMem.
189+
FlatAffineRelation loadRelation;
190+
// SLoops -> SMem.
191+
FlatAffineRelation storeRelation;
192+
// SLoops -> GMem.
193+
FlatAffineRelation originalLoadRelation;
194+
if (failed(MemRefAccess(loadOp).getAccessRelation(loadRelation)) ||
195+
failed(MemRefAccess(storeOp).getAccessRelation(storeRelation)) ||
196+
failed(MemRefAccess(originalLoadOp)
197+
.getAccessRelation(originalLoadRelation))) {
198+
LLVM_DEBUG(DBGS() << "could not construct affine access in remapping\n");
199+
return;
200+
}
201+
202+
// SMem -> SLoops.
203+
storeRelation.inverse();
204+
// LLoops -> SLoops.
205+
storeRelation.compose(loadRelation);
206+
// LLoops -> GMem
207+
originalLoadRelation.compose(storeRelation);
208+
209+
AffineMap accessMap =
210+
tryExtractAffineMap(originalLoadRelation, root->getContext());
211+
if (!accessMap) {
212+
LLVM_DEBUG(DBGS() << "could not remap the access\n");
213+
return;
214+
}
215+
216+
IRRewriter rewriter(root->getContext());
217+
rewriter.setInsertionPoint(loadOp);
218+
rewriter.replaceOpWithNewOp<AffineLoadOp>(
219+
loadOp, originalLoadOp.getMemRef(), accessMap, loadOp.getIndices());
220+
LLVM_DEBUG(DBGS() << "replaced\n");
221+
});
222+
}
223+
224+
void removeWriteOnlyAllocas(Operation *root) {
225+
SmallVector<Operation *> toErase;
226+
root->walk([&](memref::AllocaOp allocaOp) {
227+
auto isWrite = [](Operation *op) {
228+
return isa<AffineWriteOpInterface, memref::StoreOp>(op);
229+
};
230+
if (llvm::all_of(allocaOp.getResult().getUsers(), isWrite)) {
231+
llvm::append_range(toErase, allocaOp.getResult().getUsers());
232+
toErase.push_back(allocaOp);
233+
}
234+
});
235+
for (Operation *op : toErase)
236+
op->erase();
237+
}
238+
239+
struct TemporaryStorageEliminationPass
240+
: TemporaryStorageEliminationBase<TemporaryStorageEliminationPass> {
241+
void runOnOperation() override {
242+
loadStoreForwarding(getOperation());
243+
removeWriteOnlyAllocas(getOperation());
244+
}
245+
};
246+
247+
} // namespace
248+
249+
namespace mlir {
250+
namespace polygeist {
251+
void registerTemporaryStorageEliminationPass() {
252+
PassRegistration<TemporaryStorageEliminationPass> reg;
253+
}
254+
255+
std::unique_ptr<Pass> createTemporaryStorageEliminationPass() {
256+
return std::make_unique<TemporaryStorageEliminationPass>();
257+
}
258+
} // namespace polygeist
259+
} // namespace mlir

0 commit comments

Comments
 (0)