From e8e955c4646113f243b99a876e3052ea3fc1e4d5 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Wed, 18 May 2022 19:05:59 +0200 Subject: [PATCH] more advanced load-store forwarding pass this should be able to eliminate some of the temporary storage introduced by translating shared memory --- include/polygeist/Passes/Passes.h | 1 + include/polygeist/Passes/Passes.td | 4 + lib/polygeist/Passes/CMakeLists.txt | 13 +- .../Passes/TemporaryStorageElimination.cpp | 259 ++++++++++++++++++ 4 files changed, 271 insertions(+), 6 deletions(-) create mode 100644 lib/polygeist/Passes/TemporaryStorageElimination.cpp diff --git a/include/polygeist/Passes/Passes.h b/include/polygeist/Passes/Passes.h index ba847ee251b8..dfd600db2978 100644 --- a/include/polygeist/Passes/Passes.h +++ b/include/polygeist/Passes/Passes.h @@ -24,6 +24,7 @@ std::unique_ptr createParallelLowerPass(); std::unique_ptr createConvertPolygeistToLLVMPass(const LowerToLLVMOptions &options); std::unique_ptr createConvertPolygeistToLLVMPass(); +std::unique_ptr createTemporaryStorageEliminationPass(); } // namespace polygeist } // namespace mlir diff --git a/include/polygeist/Passes/Passes.td b/include/polygeist/Passes/Passes.td index d2ab39bbd6e4..5087bc507251 100644 --- a/include/polygeist/Passes/Passes.td +++ b/include/polygeist/Passes/Passes.td @@ -78,6 +78,10 @@ def RemoveTrivialUse : Pass<"trivialuse"> { let constructor = "mlir::polygeist::createRemoveTrivialUsePass()"; } +def TemporaryStorageElimination : Pass<"temp-storage-elimination"> { + let constructor = "mlir::polygeist::createTemporaryStorageEliminationPass()"; +} + def ConvertPolygeistToLLVM : Pass<"convert-polygeist-to-llvm", "mlir::ModuleOp"> { let summary = "Convert scalar and vector operations from the Standard to the " "LLVM dialect"; diff --git a/lib/polygeist/Passes/CMakeLists.txt b/lib/polygeist/Passes/CMakeLists.txt index 495942aaa34a..e95e3fbecf85 100644 --- a/lib/polygeist/Passes/CMakeLists.txt +++ b/lib/polygeist/Passes/CMakeLists.txt @@ 
-1,18 +1,19 @@ add_mlir_dialect_library(MLIRPolygeistTransforms AffineCFG.cpp AffineReduction.cpp + BarrierRemovalContinuation.cpp CanonicalizeFor.cpp + ConvertPolygeistToLLVM.cpp + InnerSerialization.cpp LoopRestructure.cpp Mem2Reg.cpp - ParallelLoopDistribute.cpp - ParallelLICM.cpp OpenMPOpt.cpp - BarrierRemovalContinuation.cpp - RaiseToAffine.cpp + ParallelLICM.cpp + ParallelLoopDistribute.cpp ParallelLower.cpp + RaiseToAffine.cpp + TemporaryStorageElimination.cpp TrivialUse.cpp - ConvertPolygeistToLLVM.cpp - InnerSerialization.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine diff --git a/lib/polygeist/Passes/TemporaryStorageElimination.cpp b/lib/polygeist/Passes/TemporaryStorageElimination.cpp new file mode 100644 index 000000000000..9a1ba774a849 --- /dev/null +++ b/lib/polygeist/Passes/TemporaryStorageElimination.cpp @@ -0,0 +1,259 @@ +//===- TemporaryStorageElimination.cpp - Shared memory-like elimination ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetails.h"
+#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
+#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/TypeID.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "tmp-storage-elimination"
+#define DBGS() llvm::dbgs() << "[" << DEBUG_TYPE << "] "
+
+namespace {
+/// Returns a block under which both `first` and `second` are expected to be
+/// nested, or nullptr if none can be determined.
+///
+/// If the region holding one operation is an ancestor of the region holding
+/// the other, the block of the outer operation is returned. Otherwise the
+/// parent regions of `first` are visited from the inside out; the first one
+/// that is an ancestor of `second`'s region is accepted only if it consists of
+/// a single block, which is then returned.
+///
+/// Callers must still check that both operations have an ancestor in the
+/// returned block: in the first two cases the region may contain more blocks
+/// than the one returned.
+Block *getCommonAncestorBlock(Operation *first, Operation *second) {
+  Region *firstRegion = first->getParentRegion();
+  Region *secondRegion = second->getParentRegion();
+  if (firstRegion->isAncestor(secondRegion))
+    return first->getBlock();
+  if (secondRegion->isAncestor(firstRegion))
+    return second->getBlock();
+
+  for (Region *region = firstRegion->getParentRegion(); region != nullptr;
+       region = region->getParentRegion()) {
+    if (region->isAncestor(secondRegion)) {
+      if (!llvm::hasSingleElement(*region))
+        return nullptr;
+      return &region->getBlocks().front();
+    }
+  }
+  return nullptr;
+}
+
+/// Finds the unique affine store under `root` that produces the data read by
+/// `loadOp`.
+///
+/// Returns a null op unless all of the following hold:
+///   - exactly one `AffineStoreOp` under `root` writes to the memref read by
+///     `loadOp`;
+///   - the store and the load have ancestors in a common block and the
+///     ancestor of the store comes first in that block;
+///   - the set of elements read by the load is a subset of the set of elements
+///     written by the store.
+AffineStoreOp findWriter(AffineLoadOp loadOp, Operation *root) {
+  // Find the stores to the same memref.
+  AffineStoreOp candidateStoreOp = nullptr;
+  WalkResult result = root->walk([&](AffineStoreOp storeOp) {
+    if (loadOp.getMemRef() != storeOp.getMemRef())
+      return WalkResult::advance();
+    // A second store to the same memref makes the writer ambiguous.
+    if (candidateStoreOp)
+      return WalkResult::interrupt();
+    candidateStoreOp = storeOp;
+    return WalkResult::advance();
+  });
+
+  // If there's no or more than one writer, bail out.
+  if (result.wasInterrupted() || !candidateStoreOp) {
+    LLVM_DEBUG(DBGS() << "could not find the single writer\n");
+    return AffineStoreOp();
+  }
+
+  // Check that the store happens before the load.
+  Block *commonParent = getCommonAncestorBlock(candidateStoreOp, loadOp);
+  if (!commonParent) {
+    LLVM_DEBUG(
+        DBGS() << "could not find a common parent between load and store\n");
+    return AffineStoreOp();
+  }
+
+  // Either lookup yields nullptr when the operation is not nested under
+  // `commonParent` (possible when the enclosing region has several blocks), so
+  // the results must be checked before being dereferenced.
+  Operation *storeAncestor =
+      commonParent->findAncestorOpInBlock(*candidateStoreOp);
+  Operation *loadAncestor = commonParent->findAncestorOpInBlock(*loadOp);
+  if (!storeAncestor || !loadAncestor) {
+    LLVM_DEBUG(DBGS() << "load and store are not nested in the same block\n");
+    return AffineStoreOp();
+  }
+  if (!storeAncestor->isBeforeInBlock(loadAncestor)) {
+    LLVM_DEBUG(DBGS() << "the store does not precede the load\n");
+    return AffineStoreOp();
+  }
+
+  FlatAffineRelation loadRelation, storeRelation;
+  if (failed(MemRefAccess(loadOp).getAccessRelation(loadRelation)) ||
+      failed(MemRefAccess(candidateStoreOp).getAccessRelation(storeRelation))) {
+    LLVM_DEBUG(DBGS() << "could not construct affine access relations\n");
+    return AffineStoreOp();
+  }
+  if (!loadRelation.getRangeSet().isSubsetOf(storeRelation.getRangeSet())) {
+    LLVM_DEBUG(
+        DBGS()
+        << "the set of loaded values is not a subset of written values\n");
+    return AffineStoreOp();
+  }
+
+  return candidateStoreOp;
+}
+
+/// Returns the affine load that defines the value written by `storeOp`, or a
+/// null op if the stored value is not the result of an `AffineLoadOp`.
+AffineLoadOp findStoredValueLoad(AffineStoreOp storeOp) {
+  return storeOp.getValueToStore().getDefiningOp<AffineLoadOp>();
+}
+
+/// Returns true if some affine store under `root` may overwrite the elements
+/// read by `originalLoadOp`.
+///
+/// Only stores whose memref is the very same value as the one read by
+/// `originalLoadOp` are considered (no alias analysis), and the relative order
+/// of the store and the loads is not examined: any overlap between the written
+/// and the read element sets counts as interference. Failing to build an
+/// access relation is also reported as interference.
+///
+/// `loadOp` is currently unused.
+bool hasInterferringWrite(AffineLoadOp loadOp, AffineLoadOp originalLoadOp,
+                          Operation *root) {
+  WalkResult result = root->walk([&](AffineStoreOp storeOp) {
+    // TODO: don't assume no-alias.
+    if (storeOp.getMemRef() != originalLoadOp.getMemRef())
+      return WalkResult::advance();
+
+    // TODO: check if the store may happen before originalLoadOp and storeOp.
+    // For now, conservatively assume it may.
+    FlatAffineRelation loadRelation, storeRelation;
+    if (failed(MemRefAccess(originalLoadOp).getAccessRelation(loadRelation)) ||
+        failed(MemRefAccess(storeOp).getAccessRelation(storeRelation))) {
+      LLVM_DEBUG(DBGS() << "could not construct affine access relations in "
+                           "interference analysis\n");
+      return WalkResult::interrupt();
+    }
+
+    if (!storeRelation.getRangeSet()
+             .intersect(loadRelation.getRangeSet())
+             .isEmpty()) {
+      LLVM_DEBUG(DBGS() << "found interferring store: " << *storeOp << "\n");
+      return WalkResult::interrupt();
+    }
+
+    return WalkResult::advance();
+  });
+
+  return result.wasInterrupted();
+}
+
+/// Tries to express range dimension `rangeDim` of `relation` as an affine
+/// function of the domain dimensions and symbols.
+///
+/// All other range dimensions are projected out of a copy of the relation. The
+/// extraction succeeds only if exactly one equality remains, that equality
+/// does not involve local identifiers, and the coefficient of the range
+/// dimension in it is +1 or -1. Returns a null expression otherwise.
+AffineExpr tryExtractAffineExpr(const FlatAffineRelation &relation,
+                                unsigned rangeDim, MLIRContext *ctx) {
+  auto clone = relation.clone();
+
+  const unsigned numDomainDims = relation.getNumDomainDims();
+  const unsigned numSymbols = relation.getNumSymbolIds();
+  const unsigned numLocals = relation.getNumLocalIds();
+
+  // Drop the range dimensions before and after the one of interest.
+  clone->projectOut(numDomainDims, rangeDim);
+  clone->projectOut(numDomainDims + 1,
+                    relation.getNumRangeDims() - rangeDim - 1);
+  if (clone->getNumEqualities() != 1)
+    return AffineExpr();
+
+  // Columns of the remaining equality are expected to be
+  //   [domain dims | range dim | symbols | locals | constant].
+  const unsigned rangePos = numDomainDims;
+  const unsigned symbolOffset = numDomainDims + 1;
+  const unsigned localOffset = symbolOffset + numSymbols;
+
+  ArrayRef<int64_t> eqCoeffs = clone->getEquality(0);
+  if (eqCoeffs.size() != localOffset + numLocals + 1)
+    return AffineExpr();
+
+  // TODO: support for local ids via mods.
+  // The local columns follow the symbol columns; starting the slice right
+  // after the range dimension would inspect symbol coefficients instead.
+  if (llvm::any_of(eqCoeffs.slice(localOffset, numLocals),
+                   [](int64_t coeff) { return coeff != 0; })) {
+    return AffineExpr();
+  }
+
+  // The equality reads `rest + rangeCoeff * r == 0`. It can be solved for `r`
+  // over the integers only when rangeCoeff is +1 or -1, in which case
+  // `r == -rangeCoeff * rest`.
+  const int64_t rangeCoeff = eqCoeffs[rangePos];
+  if (rangeCoeff != 1 && rangeCoeff != -1)
+    return AffineExpr();
+  const int64_t sign = -rangeCoeff;
+
+  AffineExpr expr = getAffineConstantExpr(sign * eqCoeffs.back(), ctx);
+  for (unsigned i = 0; i != numDomainDims; ++i) {
+    expr = expr + getAffineConstantExpr(sign * eqCoeffs[i], ctx) *
+                      getAffineDimExpr(i, ctx);
+  }
+  for (unsigned i = 0; i != numSymbols; ++i) {
+    expr = expr +
+           getAffineConstantExpr(sign * eqCoeffs[symbolOffset + i], ctx) *
+               getAffineSymbolExpr(i, ctx);
+  }
+  return expr;
+}
+
+/// Tries to turn `relation` into an affine map from its domain dimensions and
+/// symbols to its range dimensions, one result per range dimension. Returns a
+/// null map as soon as one range dimension cannot be extracted.
+AffineMap tryExtractAffineMap(const FlatAffineRelation &relation,
+                              MLIRContext *ctx) {
+  SmallVector<AffineExpr> exprs;
+  for (unsigned i = 0, e = relation.getNumRangeDims(); i != e; ++i) {
+    exprs.push_back(tryExtractAffineExpr(relation, i, ctx));
+    if (!exprs.back())
+      return AffineMap();
+  }
+  return AffineMap::get(relation.getNumDomainDims(), relation.getNumSymbolIds(),
+                        exprs, ctx);
+}
+
+/// Rewrites affine loads from temporary storage into loads from the memory the
+/// temporary was filled from.
+///
+/// For every `AffineLoadOp` under `root` that has a single writer (see
+/// `findWriter`) whose stored value is itself produced by an affine load, and
+/// for which no interfering write exists (see `hasInterferringWrite`), the
+/// load is replaced by a load from the memref of that original load, with the
+/// access map obtained by composing the three access relations.
+void loadStoreForwarding(Operation *root) {
+  root->walk([root](AffineLoadOp loadOp) {
+    LLVM_DEBUG(DBGS() << "-----------------------------------------\n");
+    LLVM_DEBUG(DBGS() << "considering " << *loadOp << "\n");
+    AffineStoreOp storeOp = findWriter(loadOp, root);
+    if (!storeOp)
+      return;
+
+    AffineLoadOp originalLoadOp = findStoredValueLoad(storeOp);
+    if (!originalLoadOp)
+      return;
+
+    if (hasInterferringWrite(loadOp, originalLoadOp, root))
+      return;
+
+    // Replace the load, need the index remapping.
+
+    // LLoops -> SMem.
+    FlatAffineRelation loadRelation;
+    // SLoops -> SMem.
+    FlatAffineRelation storeRelation;
+    // SLoops -> GMem.
+    FlatAffineRelation originalLoadRelation;
+    if (failed(MemRefAccess(loadOp).getAccessRelation(loadRelation)) ||
+        failed(MemRefAccess(storeOp).getAccessRelation(storeRelation)) ||
+        failed(MemRefAccess(originalLoadOp)
+                   .getAccessRelation(originalLoadRelation))) {
+      LLVM_DEBUG(DBGS() << "could not construct affine access in remapping\n");
+      return;
+    }
+
+    // SMem -> SLoops.
+    storeRelation.inverse();
+    // LLoops -> SLoops.
+    storeRelation.compose(loadRelation);
+    // LLoops -> GMem
+    originalLoadRelation.compose(storeRelation);
+
+    AffineMap accessMap =
+        tryExtractAffineMap(originalLoadRelation, root->getContext());
+    if (!accessMap) {
+      LLVM_DEBUG(DBGS() << "could not remap the access\n");
+      return;
+    }
+
+    // NOTE(review): the extracted map is expressed over the domain of the
+    // composed relation, while the operands supplied below are the map
+    // operands of `loadOp`. This presumably requires them to coincide in
+    // number and order -- confirm against MemRefAccess::getAccessRelation.
+    IRRewriter rewriter(root->getContext());
+    rewriter.setInsertionPoint(loadOp);
+    rewriter.replaceOpWithNewOp<AffineLoadOp>(
+        loadOp, originalLoadOp.getMemRef(), accessMap, loadOp.getIndices());
+    LLVM_DEBUG(DBGS() << "replaced\n");
+  });
+}
+
+/// Erases every `memref.alloca` under `root` whose users are all write
+/// operations, together with those users.
+void removeWriteOnlyAllocas(Operation *root) {
+  SmallVector<Operation *> toErase;
+  root->walk([&](memref::AllocaOp allocaOp) {
+    // NOTE(review): the operation list of this `isa` was lost in the damaged
+    // patch text and has been reconstructed -- confirm against the original
+    // commit.
+    auto isWrite = [](Operation *op) {
+      return isa<AffineWriteOpInterface, memref::StoreOp>(op);
+    };
+    if (llvm::all_of(allocaOp.getResult().getUsers(), isWrite)) {
+      // Users are queued ahead of the alloca so that they are erased first.
+      llvm::append_range(toErase, allocaOp.getResult().getUsers());
+      toErase.push_back(allocaOp);
+    }
+  });
+  // Erasure is deferred until the walk is over.
+  for (Operation *op : toErase)
+    op->erase();
+}
+
+/// Pass that forwards loads through temporary storage and then removes the
+/// allocas that are left with writes only.
+struct TemporaryStorageEliminationPass
+    : TemporaryStorageEliminationBase<TemporaryStorageEliminationPass> {
+  void runOnOperation() override {
+    loadStoreForwarding(getOperation());
+    removeWriteOnlyAllocas(getOperation());
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace polygeist {
+/// Registers the temporary storage elimination pass.
+void registerTemporaryStorageEliminationPass() {
+  PassRegistration<TemporaryStorageEliminationPass> reg;
+}
+
+/// Creates an instance of the temporary storage elimination pass.
+std::unique_ptr<Pass> createTemporaryStorageEliminationPass() {
+  return std::make_unique<TemporaryStorageEliminationPass>();
+}
+} // namespace polygeist
+} // namespace mlir