diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td
index fc727abb3..5cfa3ec05 100644
--- a/include/gc/Transforms/Passes.td
+++ b/include/gc/Transforms/Passes.td
@@ -58,4 +58,36 @@ def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> {
   ];
 }
 
+def IterativeTilingAndFusion : Pass<"iterative-tiling-and-fusion",
+                                    "func::FuncOp"> {
+  let summary = "Iterative tiling and fusion for any tilable operation";
+  let description = [{
+    The pass tries to fuse any MLIR operation that can be tiled. Moreover, it aims to support:
+    1. Matmul fusion with element-wise/reduce/broadcast ops.
+    2. Pre-op and post-op fusion.
+    3. Multi-consumer and multi-producer support.
+    4. Multiple levels of nested loops and candidates.
+    5. Flexible options to control the boundary of the iterative process.
+    6. A cost model to determine whether to fuse.
+
+    The granularity of fusion is controlled by the `fusion-level` option, e.g.
+    * `0`: disable any fusion.
+    * `1`: [Default] enable both producer and consumer fusion, covering any tilable operation including tensor.pack/tensor.fill/linalg.reduce etc., but excluding branches forked by multiple uses.
+    * `2`: `LEVEL 1` + extend to any topology, including branches.
+  }];
+  let dependentDialects = ["func::FuncDialect", "linalg::LinalgDialect", "scf::SCFDialect",
+                           "tensor::TensorDialect"];
+
+  let options = [
+    Option<"fusionLevel", "fusion-level", "int64_t",
+           /*default=*/"1",
+           "Control the granularity of fusion.">,
+    Option<"useCostModel", "use-cost-model", "bool",
+           /*default=*/"false",
+           "Whether to enable the cost model to control iterative fusion.">,
+    ListOption<"defaultTileSize", "default-tile-size", "std::string",
+               "Set the default tile size for a certain type of op, e.g. matmul:{32,32}">,
+  ];
+}
+
 #endif // GC_DIALECT_GC_PASSES
diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt
index 1b4f2cb73..47a004e04 100644
--- a/lib/gc/Transforms/CMakeLists.txt
+++ b/lib/gc/Transforms/CMakeLists.txt
@@ -13,6 +13,8 @@ add_mlir_library(GCPasses
   OneDNNGraphToLinalg.cpp
   Pipeline.cpp
   TileNamed.cpp
+  IterativeTilingAndFusion.cpp
+  TilingUsingInterfaceX.cpp
 
   ADDITIONAL_HEADER_DIRS
     ${PROJECT_SOURCE_DIR}/include
diff --git a/lib/gc/Transforms/IterativeTilingAndFusion.cpp b/lib/gc/Transforms/IterativeTilingAndFusion.cpp
new file mode 100644
index 000000000..f2ccb7024
--- /dev/null
+++ b/lib/gc/Transforms/IterativeTilingAndFusion.cpp
@@ -0,0 +1,841 @@
+//===-- IterativeTilingAndFusion.cpp - Iterative Tiling+Fusion --*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "gc/Transforms/Passes.h" +#include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/Dialect/DLTI/Traits.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Interfaces/TilingInterface.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/RegionUtils.h" +#include +#include +#include + +#include "TilingUsingInterfaceX.h" + +namespace mlir { +namespace gc { +#define GEN_PASS_DEF_ITERATIVETILINGANDFUSION +#include "gc/Transforms/Passes.h.inc" + +static FailureOr +getClosestExtractSliceOfOperand(OpOperand &operand) { + if (auto iterArg = dyn_cast(operand.get())) { + if (auto loop = + dyn_cast(iterArg.getOwner()->getParentOp())) + return getClosestExtractSliceOfOperand(*loop.getTiedLoopInit(iterArg)); + } + + Operation *defineOp = operand.get().getDefiningOp(); + if (auto sliceOp = dyn_cast(defineOp)) + return sliceOp; + // For downstream cases + if (isa( + defineOp)) + return getClosestExtractSliceOfOperand(defineOp->getOpOperand(0)); + + return failure(); +} + +static FailureOr +getClosestInsertSliceOfResult(OpResult result) { + OffsetSizeAndStrideOpInterface sliceOp; + for (auto &useOfResult : result.getUses()) { + if (isa(useOfResult.getOwner()) || + isa(useOfResult.getOwner())) { + if (llvm::detail::isPresent(sliceOp)) + return failure(); + sliceOp = + dyn_cast(useOfResult.getOwner()); + } else if (auto yieldOp = dyn_cast(useOfResult.getOwner())) { + if (auto loop = dyn_cast(yieldOp->getParentOp())) + return getClosestInsertSliceOfResult( + loop->getResult(useOfResult.getOperandNumber())); + } + } + + if (!llvm::detail::isPresent(sliceOp)) + return failure(); + else { + return sliceOp; + } +} + +struct CandidateDefOrUse { + enum Type { def = 0, use }; + Operation *ownerOp; + Type type; + union { + OpOperand *operand; + OpResult result; + }; + + CandidateDefOrUse(OpResult resultOfDefOp) + : ownerOp(resultOfDefOp.getDefiningOp()), type(Type::def), + result(resultOfDefOp) {} + CandidateDefOrUse(OpOperand *operandOfUseOp) + : ownerOp(operandOfUseOp->getOwner()), type(Type::use), + operand(operandOfUseOp) {} + + bool isDef() const { return type == Type::def; } + bool isUse() const { return type == Type::use; } +}; + +using CandidateSliceFilter = std::function; + +using CandidateSliceComparer = std::function; + +static LogicalResult +noTilingOnReductionFilter(RewriterBase &rewriter, + OffsetSizeAndStrideOpInterface candidate, + CandidateDefOrUse defOrUse) { + linalg::LinalgOp linalgOp = dyn_cast(defOrUse.ownerOp); + if (!linalgOp) + return success(); + + AffineMap affMap = + defOrUse.isDef() ? 
linalgOp.getIndexingMapMatchingResult(defOrUse.result) + : linalgOp.getMatchingIndexingMap(defOrUse.operand); + + TilingInterface tilableOp = dyn_cast(defOrUse.ownerOp); + SmallVector iterDomain = tilableOp.getIterationDomain(rewriter); + SmallVector iterTypes = tilableOp.getLoopIteratorTypes(); + SmallVector tileSizes = candidate.getMixedSizes(); + // check reduction iteration is full on TileSizes + for (const auto &resultExpr : llvm::enumerate(affMap.getResults())) { + unsigned iterPosition = + cast(resultExpr.value()).getPosition(); + if (iterTypes[iterPosition] == utils::IteratorType::reduction) { + std::optional cstIterDomain = + getConstantIntValue(iterDomain[iterPosition].size); + FailureOr cstTileSizes = + ValueBoundsConstraintSet::computeConstantBound( + presburger::BoundType::UB, tileSizes[resultExpr.index()], nullptr, + true); + if (!cstIterDomain || failed(cstTileSizes) || + cstIterDomain != cstTileSizes) + return failure(); + } + } + return success(); +} + +static LogicalResult +exactTilingOnPackUnPackFilter(RewriterBase &rewriter, + OffsetSizeAndStrideOpInterface candidate, + CandidateDefOrUse defOrUse) { + if (!isa(defOrUse.ownerOp)) + return success(); + + SmallVector tileSizes = candidate.getMixedSizes(); + // collect target TileSizes and InnerTileSize to compare + SmallVector targetTileSizes, targetInnerTileSizes; + if (auto packOp = dyn_cast(defOrUse.ownerOp)) { + // tileSize comes from OpResult + if (defOrUse.isDef()) { + targetInnerTileSizes = packOp.getInnerTiles(); + targetTileSizes = llvm::to_vector( + ArrayRef(tileSizes).take_back(targetInnerTileSizes.size())); + } else { + // tileSize comes from OpOperand + targetTileSizes = llvm::to_vector(tileSizes); + DenseMap dimAndTileMapping = + packOp.getDimAndTileMapping(); + targetInnerTileSizes.resize(dimAndTileMapping.size()); + for (const auto &dimAndTile : dimAndTileMapping) { + targetInnerTileSizes[dimAndTile.first] = dimAndTile.second; + } + } + } else if (auto unPackOp = dyn_cast(defOrUse.ownerOp)) { + // tileSize comes from OpResult + if (defOrUse.isDef()) { + targetTileSizes = llvm::to_vector(tileSizes); + DenseMap dimAndTileMapping = + unPackOp.getDimAndTileMapping(); + targetInnerTileSizes.resize(dimAndTileMapping.size()); + for (const auto &dimAndTile : dimAndTileMapping) { + targetInnerTileSizes[dimAndTile.first] = dimAndTile.second; + } + } else { + // tileSize comes from OpOperand + targetInnerTileSizes = unPackOp.getInnerTiles(); + targetTileSizes = llvm::to_vector( + ArrayRef(tileSizes).take_back(targetInnerTileSizes.size())); + } + } + + // check tileSizes is full on or multiple of `inner_tile_size` + for (auto [tile, innerTile] : + llvm::zip_equal(targetTileSizes, targetInnerTileSizes)) { + if (isEqualConstantIntOrValue(tile, innerTile)) + continue; + FailureOr cstSize = ValueBoundsConstraintSet::computeConstantBound( + presburger::BoundType::UB, tile, + /*stopCondition=*/nullptr, /*closedUB=*/true); + std::optional cstInnerSize = getConstantIntValue(innerTile); + if (!failed(cstSize) && cstInnerSize) { + if (*cstSize % *cstInnerSize == 0) + continue; + } + return failure(); + } + return success(); +} + +static LogicalResult unTiledOpFilter(RewriterBase &rewriter, + OffsetSizeAndStrideOpInterface candidate, + CandidateDefOrUse defOrUse) { + // In general tiledOp would not have uses any more. 
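+  // Hence, reject a candidate whose owner op has no remaining uses: it has
+  // most likely been tiled already and should not be fused again.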
+ return failure(defOrUse.ownerOp->use_empty()); +} + +static LogicalResult +NonContractionOpFilter(RewriterBase &rewriter, + OffsetSizeAndStrideOpInterface candidate, + CandidateDefOrUse defOrUse) { + // Currently this pass focuses on fine-grained fusion, which does not expect + // two consecutive contraction ops. + return failure(isa(defOrUse.ownerOp)); +} + +static LogicalResult +SingleCandidateInBlockFilter(RewriterBase &rewriter, + OffsetSizeAndStrideOpInterface candidate, + CandidateDefOrUse defOrUse) { + Block *parent = candidate->getBlock(); + + // a. traverse all ops contained in parent Block. + for (auto &opInBlock : parent->getOperations()) { + // b. skip candidate slice + if (&opInBlock == candidate.getOperation()) + continue; + // c. check if all the other sliceOp not defined or used by the same owner + // with candidate slice. + if (auto otherCandidate = + dyn_cast(&opInBlock)) { + if (defOrUse.isDef()) { + SmallVector backwardSlice; + FailureOr realProducer = + scfX::getRealProducerOfExtractSliceOp(otherCandidate, + backwardSlice); + if (succeeded(realProducer) && + realProducer->getDefiningOp() == defOrUse.ownerOp) + return failure(); + } else { + SmallVector forwardSlice; + FailureOr> realConsumers = + scfX::getRealConsumersFromInsertSliceOp(otherCandidate, + forwardSlice); + if (succeeded(realConsumers) && + llvm::any_of(*realConsumers, [&defOrUse](OpOperand *use) { + return use->getOwner() == defOrUse.ownerOp; + })) + return failure(); + } + } + } + return success(); +} + +template struct CandidateSliceProcessPipeLine { + SmallVector candidateProcessFn; + CandidateSliceProcessPipeLine() { + append(static_cast(this)->getDefaultPipeLine()); + } + CandidateSliceProcessPipeLine(const T1 &newFn) + : CandidateSliceProcessPipeLine() { + append(newFn); + } + CandidateSliceProcessPipeLine(ArrayRef newFns) + : CandidateSliceProcessPipeLine() { + append(newFns); + } + + void append(const T1 &newFn) { candidateProcessFn.push_back(newFn); } + void append(ArrayRef newFns) { + llvm::append_range(candidateProcessFn, newFns); + } + + SmallVector getDefaultPipeLine() { return {}; } +}; + +struct CandidateSliceFilterPipeLine + : public CandidateSliceProcessPipeLine { + CandidateSliceFilterPipeLine() : CandidateSliceProcessPipeLine() {} + CandidateSliceFilterPipeLine(const CandidateSliceFilter &filter) + : CandidateSliceProcessPipeLine(filter) {} + CandidateSliceFilterPipeLine(const SmallVector &filters) + : CandidateSliceProcessPipeLine(filters) {} + + SmallVector getDefaultPipeLine() { + return SmallVector{ + unTiledOpFilter, NonContractionOpFilter, noTilingOnReductionFilter, + exactTilingOnPackUnPackFilter, SingleCandidateInBlockFilter}; + } + + LogicalResult filter(RewriterBase &rewriter, + OffsetSizeAndStrideOpInterface candidate, + CandidateDefOrUse defOrUse) const { + return success(llvm::all_of( + candidateProcessFn, + [&rewriter, &candidate, &defOrUse](const CandidateSliceFilter &filter) { + return succeeded(filter(rewriter, candidate, defOrUse)); + })); + } +}; + +static FailureOr +computeTileSizeProductOfCandidate(OffsetSizeAndStrideOpInterface candidate) { + SmallVector tileSizes = candidate.getMixedSizes(); + int64_t totalSize = 1; + for (auto &tile : tileSizes) { + FailureOr cstSize = ValueBoundsConstraintSet::computeConstantBound( + presburger::BoundType::UB, tile, + /*stopCondition=*/nullptr, /*closedUB=*/true); + if (failed(cstSize)) { + return failure(); + } + totalSize *= *cstSize; + }; + return totalSize; +} + +static int TilingSizeComparer(OffsetSizeAndStrideOpInterface 
candidateA, + OffsetSizeAndStrideOpInterface candidateB) { + FailureOr sizeProductA = + computeTileSizeProductOfCandidate(candidateA), + sizeProductB = + computeTileSizeProductOfCandidate(candidateB); + if (failed(sizeProductA) || failed(sizeProductB)) + return 0; + // deal with equality + if (*sizeProductA == *sizeProductB) + return 0; + + return *sizeProductA < *sizeProductB ? -1 : 1; +} + +struct CandidateSliceComparerPipeLine + : public CandidateSliceProcessPipeLine { + CandidateSliceComparerPipeLine() : CandidateSliceProcessPipeLine() {} + + SmallVector getDefaultPipeLine() { + return SmallVector{TilingSizeComparer}; + } + + bool compare(OffsetSizeAndStrideOpInterface candidateA, + OffsetSizeAndStrideOpInterface candidateB) const { + // deal with weak order + int cmpResult = -1; + llvm::any_of(candidateProcessFn, [&cmpResult, &candidateA, &candidateB]( + const CandidateSliceComparer &fn) { + cmpResult = fn(candidateA, candidateB); + return cmpResult != 0; + }); + return cmpResult == -1; + } +}; + +struct CandidateSliceOptions { + // Use for validity + CandidateSliceFilterPipeLine filterPipeLine; + // Use for performance + CandidateSliceComparerPipeLine comparerPipeLine; + + CandidateSliceOptions() = default; + + void addFilter(const CandidateSliceFilter &filter) { + filterPipeLine.append(filter); + } + void addFilter(ArrayRef filters) { + filterPipeLine.append(filters); + } + void addComparer(const CandidateSliceComparer &comparer) { + comparerPipeLine.append(comparer); + } + void addFilter(ArrayRef comparers) { + comparerPipeLine.append(comparers); + } +}; + +static FailureOr filterAndSelectCandidate( + RewriterBase &rewriter, + ArrayRef candidateSliceList, + const CandidateDefOrUse &defOrUse, const CandidateSliceOptions &options) { + SmallVector validCandidates = + llvm::to_vector(llvm::make_filter_range( + candidateSliceList, + [&rewriter, &options, + &defOrUse](const OffsetSizeAndStrideOpInterface &candidate) { + return succeeded( + options.filterPipeLine.filter(rewriter, candidate, defOrUse)); + })); + if (validCandidates.empty()) + return failure(); + + OffsetSizeAndStrideOpInterface bestCandidate = *llvm::min_element( + validCandidates, [&options](OffsetSizeAndStrideOpInterface &candidateA, + OffsetSizeAndStrideOpInterface &candidateB) { + return options.comparerPipeLine.compare(candidateA, candidateB); + }); + return bestCandidate; +} + +std::optional +tileAndFuseProducerOfOpOperand(RewriterBase &rewriter, OpOperand &operand, + const CandidateSliceOptions &options) { + // a. Find the closest sliceOp + FailureOr closestSliceOp = + getClosestExtractSliceOfOperand(operand); + if (failed(closestSliceOp)) + return std::nullopt; + + // b. Find the real producer and collect the sliceOp chain during backward + // stage, sorted from inner to outer. + SmallVector backwardSlice; + FailureOr realProducer = + scfX::getRealProducerOfExtractSliceOp(*closestSliceOp, backwardSlice); + if (failed(realProducer)) + return std::nullopt; + + // c. Check the producer of root source if is tilable. + Operation *producer = realProducer->getDefiningOp(); + if (!producer) + return std::nullopt; + + CandidateDefOrUse defOrUse{*realProducer}; + // d. 
Filter out invalid candidates and select best candidates + SmallVector ossBackwardSlice = + llvm::map_to_vector(backwardSlice, + [](tensor::ExtractSliceOp &extractSlice) { + return cast( + extractSlice.getOperation()); + }); + FailureOr bestCandidate = + filterAndSelectCandidate(rewriter, ossBackwardSlice, defOrUse, options); + if (failed(bestCandidate)) + return std::nullopt; + + // e. call tiling interface + return scfX::tileAndFuseProducerOfSlice(rewriter, *bestCandidate); +} + +std::optional> +tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result, + const CandidateSliceOptions &options) { + // a. Find the closest sliceOp + FailureOr closestSliceOp = + getClosestInsertSliceOfResult(result); + if (failed(closestSliceOp)) + return std::nullopt; + + // b. Find the real consumers and collect the sliceOp chain during forward + // stage, sorted from inner to outer. + SmallVector forwardSlice; + FailureOr> realConsumers = + scfX::getRealConsumersFromInsertSliceOp(*closestSliceOp, forwardSlice); + if (failed(realConsumers)) + return std::nullopt; + + SmallVector fusedResultList; + for (auto useOperand : *realConsumers) { + // c. Check the consumer of top level result if is tilable. + Operation *consumer = dyn_cast(useOperand->getOwner()); + if (!consumer) + continue; + + CandidateDefOrUse defOrUse{useOperand}; + // d. Filter out invalid candidates and select best candidates + FailureOr bestCandidate = + filterAndSelectCandidate(rewriter, forwardSlice, defOrUse, options); + if (failed(bestCandidate)) + continue; + + // e. call tiling interface + FailureOr fusedResult = + scfX::tileAndFuseConsumerOfSlice(rewriter, *bestCandidate); + + if (succeeded(fusedResult)) { + fusedResultList.push_back(*fusedResult); + auto whileProducerOutOfLoopBlock = + [&fusedResult](LoopLikeOpInterface loop) -> LogicalResult { + Block &body = loop->getRegion(0).front(); + return failure(fusedResult.value().tiledOps[0]->getBlock() == &body); + }; + SmallVector outerLoops = + scfX::getOuterNestLoopsWhile( + (*bestCandidate)->getParentOfType(), + whileProducerOutOfLoopBlock); + // g. Manually run cse on region which contains top-level loop of + // candidate slice in avoid of conflict with subsequent + // `tileAndFuseConsumerOfSlice` get nest loops between next candidate + // sliceOp and tiled producer. + (void)mlir::simplifyRegions(rewriter, + {*outerLoops.front()->getParentRegion()}); + } + } + if (fusedResultList.empty()) + return std::nullopt; + + return fusedResultList; +} + +/// Target at following general topology: +/// +/// producer1 producer2 +/// \ / +/// Op +/// / \ +/// consumer1 consumer2 +/// +/// where: +/// +/// Support iterative producer and consumer fusion in BFS fashion. 
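+///
+/// Starting from the given tiled op, each newly fused producer or consumer is
+/// appended to a work list, so its own producers and consumers are visited in
+/// turn until no further fusion candidate remains.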
+LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp( + RewriterBase &rewriter, Operation *tiledOp, + const CandidateSliceOptions &options) { + unsigned numTiledOps = 0; + std::deque tiledOpList = {tiledOp}; + while (!tiledOpList.empty()) { + tiledOp = tiledOpList.front(); + tiledOpList.pop_front(); + numTiledOps++; + // fuse producer + for (OpOperand &operand : tiledOp->getOpOperands()) { + if (std::optional fuseProducerResult = + tileAndFuseProducerOfOpOperand(rewriter, operand, options)) + tiledOpList.push_back(fuseProducerResult.value().tiledOps[0]); + } + // fuse consumer(s) + for (OpResult result : tiledOp->getResults()) { + if (std::optional> + fuseConsumerResults = + tileAndFuseConsumerOfOpResult(rewriter, result, options)) { + for (auto &fuseConsumerResult : *fuseConsumerResults) + tiledOpList.push_back(fuseConsumerResult.tiledOps[0]); + } + } + } + return success(numTiledOps > 1); +} + +/// What is self tiled op compared with other fused op? +/// E.g. +/// %1 = scf.for(){ +/// %2 = scf.for(){ +/// %3 = extract_slice +/// %4 = tiled_op(%3) +/// %5 = insert %4 +/// yield %5 +/// } +/// } +static LogicalResult isSelfTiledOp(Operation *targetOp) { + // 0. check tilable + if (!isa(targetOp)) + return failure(); + // 1. check parentOp + auto forOp = targetOp->getParentOfType(); + if (!forOp) + return failure(); + // 2. check single one tiling interface in loop body + auto walkResult = forOp->walk([&targetOp](TilingInterface op) { + // some special op maybe already deal with in template + if (isa(op)) + return WalkResult::skip(); + return op != targetOp ? WalkResult::interrupt() : WalkResult::advance(); + }); + if (walkResult.wasInterrupted()) + return failure(); + // 3. check whether has either extract or insert slice op + walkResult = forOp->walk( + [](tensor::ExtractSliceOp) { return WalkResult::interrupt(); }); + if (walkResult.wasInterrupted()) + return success(); + walkResult = forOp->walk( + [](tensor::InsertSliceOp) { return WalkResult::interrupt(); }); + return success(walkResult.wasInterrupted()); +} + +struct SystemDesc { + // get runtime OMP_NUM_THREADS + uint32_t getNumThreads() { + std::optional numThreads = layout.getDevicePropertyValue( + Builder(ctx).getStringAttr("CPU" /* device ID*/), + Builder(ctx).getStringAttr("num_threads")); + if (numThreads && isa(*numThreads)) { + return dyn_cast(*numThreads).getInt(); + } + return 1; + } + // get cache size by cacheLevel + size_t getCacheSize(uint8_t cacheLevel) { + if (cacheLevel == 1) { + std::optional cacheSize = layout.getDevicePropertyValue( + Builder(ctx).getStringAttr("CPU" /* device ID*/), + Builder(ctx).getStringAttr("L1_cache_size_in_bytes")); + if (cacheSize && isa(*cacheSize)) { + return dyn_cast(*cacheSize).getInt(); + } + } else if (cacheLevel == 2) { + std::optional cacheSize = layout.getDevicePropertyValue( + Builder(ctx).getStringAttr("CPU" /* device ID*/), + Builder(ctx).getStringAttr("L2_cache_size_in_bytes")); + if (cacheSize && isa(*cacheSize)) { + return dyn_cast(*cacheSize).getInt(); + } + } else if (cacheLevel == 3) { + std::optional cacheSize = layout.getDevicePropertyValue( + Builder(ctx).getStringAttr("CPU" /* device ID*/), + Builder(ctx).getStringAttr("L3_cache_size_in_bytes")); + if (cacheSize && isa(*cacheSize)) { + return dyn_cast(*cacheSize).getInt(); + } + } + return 0; + } + + // get the maximum vector length in bits + size_t getMaxVectorLength() { + std::optional maxVectorLength = layout.getDevicePropertyValue( + Builder(ctx).getStringAttr("CPU" /* device ID*/), + 
Builder(ctx).getStringAttr("max_vector_width")); + if (maxVectorLength && isa(*maxVectorLength)) { + return dyn_cast(*maxVectorLength).getInt(); + } + return 512; + } + + SystemDesc(ModuleOp m) : layout(m), ctx(m->getContext()) {} + +private: + DataLayout layout; + MLIRContext *ctx; +}; + +using OpTileSizeMap = std::unordered_map>; + +template +static bool defaultTilingOfType(RewriterBase &rewriter, Operation *op, + const OpTileSizeMap &tsMap) { + // a. Check + if (!isa(op) || !isa(op)) + return false; + auto tilingInterfaceOp = cast(op); + + scf::SCFTilingOptions options; + // b. Get default tiling size + SmallVector iteratorTypes = + tilingInterfaceOp.getLoopIteratorTypes(); + + SmallVector defaultTileSize; + + std::string opName = op->getName().getStringRef().str(); + // Erase dialect name, such as Linalg or Tensor. + opName.erase(0, opName.find(".") + 1); + + if (tsMap.count(opName)) { + SmallVector userDefaultTileSize = tsMap.find(opName)->second; + defaultTileSize = + getAsOpFoldResult(rewriter.getI64ArrayAttr(userDefaultTileSize)); + } else { + defaultTileSize.resize(iteratorTypes.size(), rewriter.getIndexAttr(0)); + for (auto &&[en, iterType] : llvm::enumerate(iteratorTypes)) { + // All outer non reduction loop should contribute parallelism. In another + // word, all reduction dimensions should not be tiled. + if (iterType == utils::IteratorType::parallel && + (en != iteratorTypes.size() - 1 || + llvm::count(iteratorTypes, utils::IteratorType::reduction))) + defaultTileSize[en] = rewriter.getIndexAttr(1); + } + } + // If the tile sizes are all zero, no tiling would happen. + if (llvm::all_of(defaultTileSize, isZeroIndex)) + return false; + + options.setTileSizes(defaultTileSize); + // c. Set loop type + options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); + // d. Use builtin tiling interface + FailureOr tilingResult = + scf::tileUsingSCF(rewriter, tilingInterfaceOp, options); + if (succeeded(tilingResult)) { + rewriter.replaceOp(op, tilingResult->replacements); + return true; + } + return false; +} + +void iterativeTilingAndFusionUntilExhaustion( + RewriterBase &rewriter, func::FuncOp &f, + const CandidateSliceOptions &sliceOptions, const OpTileSizeMap &tsMap) { + // Collect untiled and tiled ops respectively + llvm::SetVector selfTiledOp, unTiledOps; + + auto collectUnTiledOps = [&f, &unTiledOps]() -> bool { + // Reset + unTiledOps.clear(); + // Pre-order walk through funcOp + f->walk([&unTiledOps](Operation *op) { + if (isa(op)) + return WalkResult::skip(); + if (isa(op)) { + auto parentLoop = op->getParentOfType(); + if (!parentLoop.getOperation()) + unTiledOps.insert(op); + } + return WalkResult::advance(); + }); + return !unTiledOps.empty(); + }; + + auto collectSelfTiledOp = [&f, &selfTiledOp]() -> bool { + // Reset + selfTiledOp.clear(); + // Walk through funcOp + f->walk([&selfTiledOp](Operation *op) { + // Target at certain kind of tiled op, such as matmul/conv implemented + // by multiple level of nest loops and candidate slices for better + // utilization of parallelism and memory hierarchy. + if (succeeded(isSelfTiledOp(op))) { + selfTiledOp.insert(op); + } + }); + return !selfTiledOp.empty(); + }; + + // Iterative tiling and fusion until exhaustion. + while (collectUnTiledOps()) { + // If existing tiled op before tiling. + if (collectSelfTiledOp()) { + // Sort by topology + mlir::topologicalSort(selfTiledOp); + // Record if any fusion happens + bool changed = false; + // Iteratively fuse in forward and backward fashion. 
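+      // A failed fusion attempt is not fatal here: `changed` merely records
+      // whether at least one producer or consumer was fused in this round.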
+ llvm::for_each(selfTiledOp, [&rewriter, &sliceOptions, + &changed](Operation *tiledOp) { + changed |= succeeded(iterativelyFuseProducerAndConsumerOfTiledOp( + rewriter, tiledOp, sliceOptions)); + }); + if (changed) + (void)mlir::simplifyRegions(rewriter, {f.getRegion()}); + } else { + // Auto tiling with default tile size if no tiled op found. Follow tiling + // priority based on OpTy: `Contraction`->`Reduction`->`Elementwise`. + SmallVector> + priorityTilingPipeLine = { + defaultTilingOfType, + defaultTilingOfType, + defaultTilingOfType}; + if (llvm::all_of(priorityTilingPipeLine, + [&rewriter, &tsMap, &unTiledOps]( + function_ref + tilingFn) { + return !llvm::any_of( + unTiledOps, std::bind(tilingFn, std::ref(rewriter), + std::placeholders::_1, + std::cref(tsMap))); + })) { + // If no op can be tiled + return; + } + } + } +} + +static OpTileSizeMap defaultTileSizeParser(ArrayRef strArgs) { + OpTileSizeMap tsMap; + char warning[] = + "Please follow correct argument format: opType:{ts1,ts2,...}"; + for (auto str : strArgs) { + str.erase(llvm::remove_if(str, llvm::isSpace), str.end()); + size_t pos = str.find(":"); + if (pos == std::string::npos) + llvm_unreachable(warning); + + std::string opType = str.substr(0, pos); + std::string strTileSize = str.erase(0, pos + 1); + if (strTileSize.size() <= 2 || strTileSize.front() != '{' || + strTileSize.back() != '}') + llvm_unreachable(warning); + + strTileSize = strTileSize.substr(1, strTileSize.size() - 2); + SmallVector intTileSize; + while ((pos = strTileSize.find(",")) != std::string::npos) { + intTileSize.push_back(std::stoi(strTileSize.substr(0, pos))); + strTileSize.erase(0, pos + 1); + } + intTileSize.push_back(std::stoi(strTileSize)); + tsMap[opType] = intTileSize; + } + return tsMap; +} + +struct IterativeTilingAndFusion + : public impl::IterativeTilingAndFusionBase { + using IterativeTilingAndFusionBase::IterativeTilingAndFusionBase; + +public: + void runOnOperation() final { + auto &ctx = getContext(); + // Get funcOp + func::FuncOp func = getOperation(); + // Get system descriptor + SystemDesc sysDesc(func->getParentOfType()); + // Flexible options to control which candidate slice would be selected from + // the view of both validity and performance. + CandidateSliceOptions sliceOptions; + // Since most filters regarding to validity have already been built-in + // enabled. Users could focus on performance related filters, a.k.a. cost + // model. E.g. + if (useCostModel) { + // Customized filter by cost model. 
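+      // A minimal working-set heuristic (assuming that keeping the candidate
+      // tile resident in L2 cache is profitable): reject any candidate slice
+      // whose static tile-size product exceeds the L2 cache size.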
+ CandidateSliceFilter costModelFilter = + [&sysDesc](RewriterBase &rewriter, + OffsetSizeAndStrideOpInterface candidate, + CandidateDefOrUse defOrUse) -> LogicalResult { + // Get cache size + size_t l2CacheSize = sysDesc.getCacheSize(2); + FailureOr tileSizeProduct = + computeTileSizeProductOfCandidate(candidate); + return success(succeeded(tileSizeProduct) && + (*tileSizeProduct <= (int64_t)l2CacheSize)); + }; + sliceOptions.addFilter(costModelFilter); + } + OpTileSizeMap tsMap = defaultTileSizeParser(defaultTileSize); + // Get rewriter + IRRewriter rewriter(&ctx); + // Run iterative fusion + iterativeTilingAndFusionUntilExhaustion(rewriter, func, sliceOptions, + tsMap); + } +}; + +} // namespace gc +} // namespace mlir \ No newline at end of file diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index 29a143835..682671bc7 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -43,7 +43,8 @@ void populateTensorPasses(mlir::OpPassManager &pm) { // todo: layout propagation pass // todo: tensor constant propagation pass // todo: linalg.matmul lowering to (scf.loop + linalg.brgemm) pass - // todo: fine-grain fusion pass + // Fine-grain fusion pass + pm.addNestedPass(createIterativeTilingAndFusion()); // todo: lower linalg to arith/math on virtual vector pass // REMOVE this pass after the above passes are added. Currently we add this diff --git a/lib/gc/Transforms/TilingUsingInterfaceX.cpp b/lib/gc/Transforms/TilingUsingInterfaceX.cpp new file mode 100644 index 000000000..b0104ca5d --- /dev/null +++ b/lib/gc/Transforms/TilingUsingInterfaceX.cpp @@ -0,0 +1,997 @@ +//===-- TilingUsingInterfaceX.cpp - upstream eXtension ---------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/Utils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Interfaces/TilingInterface.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include + +#include "TilingUsingInterfaceX.h" + +#define DEBUG_TYPE "tile-using-interface-x" + +using namespace mlir; + +static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, + Operation *op, + ValueRange newDestArgs) { + Operation *clonedOp = rewriter.clone(*op); + if (newDestArgs.empty()) + return clonedOp; + if (auto destinationStyleOp = dyn_cast(clonedOp)) + destinationStyleOp.getDpsInitsMutable().assign(newDestArgs); + return clonedOp; +} + +static std::tuple> +getUntiledProducerFromSliceSource(OpOperand *source, + ArrayRef loops) { + std::optional destinationIterArg; + if (!loops.empty()) { + auto loopIt = loops.rbegin(); + while (auto iterArg = dyn_cast(source->get())) { + auto loop = *loopIt; + if (iterArg.getOwner()->getParentOp() != loop) + break; + source = loop.getTiedLoopInit(iterArg); + loopIt++; + } + if (loopIt == loops.rend()) + destinationIterArg = source; + } + return {dyn_cast(source->get()), destinationIterArg}; +} + +static std::optional +tileAndFuseProducerOfSliceImpl(RewriterBase &rewriter, + tensor::ExtractSliceOp candidateSliceOp, + MutableArrayRef loops) { + // 1. Get the producer of the source (potentially walking through + // `iter_args` of nested `scf.for`) + auto [fusableProducer, destinationInitArg] = + getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(), + loops); + if (!fusableProducer) + return std::nullopt; + unsigned resultNumber = fusableProducer.getResultNumber(); + + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(candidateSliceOp); + + // 2. Clone the fused producer + // 2a. Compute the destination operands to use for the cloned operation. + SmallVector origDestinationTensors, clonedOpDestinationTensors; + Operation *fusableProducerOp = fusableProducer.getOwner(); + if (isa(fusableProducerOp) && + failed(tensor::getOrCreateDestinations( + rewriter, fusableProducerOp->getLoc(), fusableProducerOp, + origDestinationTensors))) + return std::nullopt; + + clonedOpDestinationTensors = origDestinationTensors; + if (destinationInitArg && + isa(fusableProducerOp)) { + // 2b. If the producer is also destination style, then to maintain the + // destination passing style, update the destination of the producer to be + // the source of the slice. + clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource(); + } + // 2c. Clone the fused producer. + Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs( + rewriter, fusableProducerOp, clonedOpDestinationTensors); + // 2d. Update the source of the candidateSlice to be the cloned producer. 
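+  // In other words, the slice should now read from the clone rather than
+  // from the loop iter_arg.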
+ // Easier to just clone the slice with different source since replacements + // and DCE of cloned ops becomes easier + SmallVector candidateSliceOpOperands = + llvm::to_vector(candidateSliceOp->getOperands()); + candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber); + tensor::ExtractSliceOp clonedCandidateSliceOp = + mlir::clone(rewriter, candidateSliceOp, + candidateSliceOp->getResultTypes(), candidateSliceOpOperands); + + // 3. Generate the tiled implementation of the producer of the source + FailureOr tileAndFuseResult = + tensor::replaceExtractSliceWithTiledProducer( + rewriter, clonedCandidateSliceOp, + clonedProducerOp->getResult(resultNumber)); + if (failed(tileAndFuseResult)) + return std::nullopt; + // Note: Do not delete the candidateSliceOp, since its passed in from the + // caller. + rewriter.replaceAllUsesWith(candidateSliceOp, + tileAndFuseResult->tiledValues[0]); + rewriter.eraseOp(clonedCandidateSliceOp); + rewriter.eraseOp(clonedProducerOp); + + // 3. If the slice is for a destination operand, for example, + // + // ```mlir + // %0 = linalg.init + // %1 = linalg.fill .. outs(%0 : ) + // %2 = scf.for .. iter_args(%arg0 = %1) { + // %3 = scf.for .. iter_args(%arg1 = %arg0) { + // %4 = tensor.extract_slice %arg1 [..] + // .. = linalg.matmul .. outs(%4 : ) + // } + // } + // ``` + // + // the IR is currently + // + // ``` + // %0 = linalg.init + // %1 = linalg.fill + // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) { + // %3 = scf.for .. iter_args(%arg1 = %arg0) { + // %4 = tensor.extract_slice %arg1[..] + // %5 = linalg.fill .. outs(%4 : ) + // .. = linalg.matmul .. outs(%5 : ) + // } + // } + // ``` + // + // The untiled `linalg.fill` is still used as the `init_value` since it + // was originally a destination operand of the untiled `linalg.matmul`. + // When fusing an operand that is a destination operand, the iter_arg of + // the outer most loop should be changed to use the destination of the + // fused operation. With this the IR will be. + // + // ``` + // %0 = linalg.init + // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) { + // %2 = scf.for .. iter_args(%arg1 = %arg0) { + // %3 = tensor.extract_slice %arg1[..] + // %4 = linalg.fill .. outs(%3 : ) + // .. = linalg.matmul .. outs(%4 : ) + // } + // } + // ``` + if (destinationInitArg && + isa(fusableProducerOp) && !loops.empty()) { + loops.front() + ->getOpOperands()[destinationInitArg.value()->getOperandNumber()] + .set(origDestinationTensors[resultNumber]); + } + return scf::SCFFuseProducerOfSliceResult{fusableProducer, + tileAndFuseResult->tiledValues[0], + tileAndFuseResult->tiledOps}; +} + +/// Get the real producer of candidate ExtractSliceOp +/// +/// ``` +/// %0 = producer +/// %1 = scf.for(%arg1 = %0) +/// %2 = extract %arg1 +/// %3 = scf.for(%arg2 = %2) +/// %4 = extract %args2 +/// ... 
+/// ``` +/// +/// @param candidateSliceOp: %4 = extract %args2 +/// @param backwardSlice: in-out parameter populated by backward extractSliceOps +/// @return OpResult Producer : %0 = producer +FailureOr mlir::scfX::getRealProducerOfExtractSliceOp( + Operation *candidateSliceOp, + SmallVector &backwardSlice, unsigned curDepth, + unsigned maxDepth) { + if (!isa(candidateSliceOp)) + return failure(); + // control recursive time in avoid of stack overflow + if (curDepth > maxDepth) + return failure(); + + auto extractOp = cast(candidateSliceOp); + backwardSlice.push_back(extractOp); + Value rootSource = extractOp.getSourceMutable().get(); + + while (true) { + if (auto iterArg = dyn_cast(rootSource)) { + if (auto outerLoop = dyn_cast( + iterArg.getOwner()->getParentOp())) { + rootSource = outerLoop.getTiedLoopInit(iterArg)->get(); + continue; + } + return failure(); + } + if (auto sliceOp = rootSource.getDefiningOp()) { + // walk up loop to find larger candidate extractSliceOp + return getRealProducerOfExtractSliceOp(sliceOp, backwardSlice, + curDepth + 1); + } + break; + } + return dyn_cast(rootSource); +} + +/// Recursively find the outer nest loops of given loop(included) while the +/// predict function succeed, sorted from outer to inner. +/// +/// @param loop: target loop, note that this loop will be also included. I.e. +/// if no other nest loops were found, just return itself. +/// @param pred: predict function, the termination condition of recursive +/// process. +/// @return Outer Nest Loops: nest loops outside given target loop(included). +/// +/// E.g. +/// +/// ``` +/// %0 = scf.for() +/// %1 = scf.for() +/// %2 = scf.for() +/// ``` +/// +/// If `%2 = scf.for` is given without specific prediction function, this +/// function will return three nest loops: %0 + %1 + %2. +SmallVector mlir::scfX::getOuterNestLoopsWhile( + LoopLikeOpInterface loop, + const std::function &pred) { + SmallVector nestLoops = {loop}; + auto outerLoop = dyn_cast(loop->getParentOp()); + while (outerLoop && succeeded(pred(outerLoop))) { + nestLoops.push_back(outerLoop); + outerLoop = dyn_cast(outerLoop->getParentOp()); + } + // sorted from outer to inner + return {nestLoops.rbegin(), nestLoops.rend()}; +} + +/// Enhanced version of `tileAndFuseProducerOfSliceImpl`, which can deal with +/// multi-level `extractSliceOp`. E.g. +/// +/// ``` +/// %0 = untiled_producer +/// %1 = scf.for(%arg1 = %0) +/// %2 = extract %arg1 +/// %3 = scf.for(%arg2 = %2) +/// %4 = extract %args2 +/// %5 = tiled_consumer ins(%4) +/// ``` +std::optional +mlir::scfX::tileAndFuseProducerOfSlice(RewriterBase &rewriter, + Operation *candidateSliceOp) { + SmallVector backwardSlice; + if (failed(getRealProducerOfExtractSliceOp(candidateSliceOp, backwardSlice))) + return std::nullopt; + + std::optional fuseProducerResult; + // reverse from outer to inner + std::reverse(backwardSlice.begin(), backwardSlice.end()); + // multiple application of `tileAndFuseProducerOfSliceImpl` + for (auto &&[index, sliceOp] : llvm::enumerate(backwardSlice)) { + // get nest loops between next candidate sliceOp and tiled producer. 
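+    // The predicate below stops collecting outer loops as soon as the
+    // producer fused in the previous step already lives in the loop body, so
+    // each application only crosses the loops that still separate the
+    // producer from the current slice.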
+ auto whileProducerOutOfLoopBlock = + [&fuseProducerResult](LoopLikeOpInterface loop) -> LogicalResult { + if (fuseProducerResult) { + Block &body = loop->getRegion(0).front(); + if (fuseProducerResult->tiledAndFusedProducer.getDefiningOp() + ->getBlock() == &body) + return failure(); + } + return success(); + }; + SmallVector outerLoops = + getOuterNestLoopsWhile(sliceOp->getParentOfType(), + whileProducerOutOfLoopBlock); + fuseProducerResult = + tileAndFuseProducerOfSliceImpl(rewriter, sliceOp, outerLoops); + if (!fuseProducerResult) + return std::nullopt; + } + return fuseProducerResult; +} + +/// Get the real consumers from candidate InsertSliceOp. E.g +/// +/// ``` +/// %1 = scf.for +/// %2 = scf.for +/// %3 = scf.for +/// ... +/// %4 = insert +/// yield %4 +/// %5 = insert %3 +/// yield %5 +/// yield %2 +/// %6 = consumerOp ins(%1) +/// ``` +/// +/// @param candidateSliceOp: %4 = insert +/// @param forwardSlice: in-out parameter populated by forward insertSliceOps +/// @return OpOperand consumers: %6 = consumerOp ins(%1) +FailureOr> +mlir::scfX::getRealConsumersFromInsertSliceOp( + Operation *candidateSliceOp, + SmallVector &forwardSlice, + unsigned curDepth, unsigned maxDepth) { + if (!isa( + candidateSliceOp)) + return failure(); + // Control recursive time in avoid of stack overflow + if (curDepth > maxDepth) + return failure(); + + forwardSlice.push_back( + cast(candidateSliceOp)); + Value resultOfLoop; + if (auto sliceOp = + dyn_cast(candidateSliceOp)) { + Value destValue = sliceOp.getDest(); + auto iterArg = cast(destValue); + auto forallOp = dyn_cast(iterArg.getOwner()->getParentOp()); + if (!forallOp) + return failure(); + resultOfLoop = forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg)); + } else if (auto sliceOp = dyn_cast(candidateSliceOp)) { + Value resultValue = sliceOp.getResult(); + for (auto &useOperand : resultValue.getUses()) { + if (auto yieldOp = dyn_cast(useOperand.getOwner())) { + if (llvm::detail::isPresent(resultOfLoop)) + return failure(); + auto forOp = dyn_cast(yieldOp->getParentOp()); + if (!forOp) + return failure(); + resultOfLoop = forOp->getResult(useOperand.getOperandNumber()); + } + } + } + + if (!llvm::detail::isPresent(resultOfLoop)) + return failure(); + + bool traverseUpperLoop; + do { + traverseUpperLoop = false; + for (OpOperand &useOperand : resultOfLoop.getUses()) { + if (auto sliceOp = + dyn_cast(useOperand.getOwner())) { + return getRealConsumersFromInsertSliceOp(sliceOp, forwardSlice, + curDepth + 1); + } + if (auto yieldOp = dyn_cast(useOperand.getOwner())) { + // walk through outer loop + auto forOp = dyn_cast(yieldOp->getParentOp()); + if (!forOp) + return failure(); + resultOfLoop = forOp->getResult(useOperand.getOperandNumber()); + traverseUpperLoop = true; + break; + } + } + } while (traverseUpperLoop); + // Return all operands using result of top level loop. + return llvm::map_to_vector(resultOfLoop.getUses(), + [](OpOperand &u) -> OpOperand * { return &u; }); +} + +/// A utility function that checks whether the only use of the result of a +/// tensor.insert_slice op is in a scf.yield op. 
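+///
+/// E.g. the accepted pattern looks like:
+///
+/// ```
+/// %0 = scf.for ... iter_args(%arg = ...) {
+///   %1 = tensor.insert_slice ... into %arg[...]
+///   scf.yield %1
+/// }
+/// ```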
+static LogicalResult +checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) { + Value result = candidateSliceOp.getResult(); + Value::use_range uses = result.getUses(); + if (!llvm::hasSingleElement(uses)) { + LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n"); + return failure(); + } + OpOperand &operandUse = (*uses.begin()); + Operation *userOp = operandUse.getOwner(); + if (!isa(userOp)) { + LLVM_DEBUG(llvm::dbgs() + << "Expected scf.yield to be the only user, but got -> " + << (*userOp)); + return failure(); + } + if (result.getDefiningOp()->getBlock() != userOp->getBlock()) { + LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to " + "be in the same block\n"); + return failure(); + } + return success(); +} + +/// Fetches the FIRST OpOperand of the tilable user (and use) of the value `val` +/// within the same block, which implements `TilingInterface` and +/// `DestinationStyleOpInterface` and has non-empty user list. +/// Returns failure otherwise. +static FailureOr getConsumerFromUses(Value val, + Block *containingOpBlock) { + OpOperand *operand = nullptr; + for (auto &use : val.getUses()) { + Operation *user = use.getOwner(); + // Step 1. Check if the user is tilable. + if (!isa(user)) { + // TODO: We have to init result of consumer before scf.for, use + // DestinationStyleOpInterface to get result shape from init for + // now. Add support for other op such as op has + // InferTypeOpInterface. + continue; + } else { + // Step 2. Check if user stay in the same block. + if (containingOpBlock != user->getBlock()) + continue; + // Step 3. Check if user has succeeding user. Otherwise, it usually + // represents already tiled. + if (user->use_empty()) + continue; + operand = &use; + break; + } + } + if (!operand) + return failure(); + + return operand; +} + +/// Check if it is the ForOp that yield the result of inner loop +static LogicalResult isForOpYieldResultOfInnerLoop(LoopLikeOpInterface loop) { + if (auto forOp = dyn_cast(loop.getOperation())) { + for (auto &&[index, op] : + llvm::enumerate(forOp.getBody()->getOperations())) { + // If the orderIndex of inner loop is the last second one before the + // yieldOp of ForOp, the given loop must yield the result of inner loop. + if (isa(op)) { + return success((index + 2) == forOp.getBody()->getOperations().size()); + } + } + } + return failure(); +} + +/// Fetch the untiled consumer of a scf.for's result which is yielded by a +/// tensor.insert_slice. This function makes the following assumptions that +/// tensor.insert_slice has scf.yield as its only user. +static FailureOr +getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) { + if (failed(checkAssumptionForFusingConsumer(candidateSliceOp))) + return failure(); + Value sliceResult = candidateSliceOp.getResult(); + // Step 1. Fetch the corresponding output. + OpOperand &yieldOpOperand = (*sliceResult.getUses().begin()); + unsigned resultNumber = yieldOpOperand.getOperandNumber(); + // Step 2. Check containing op is scf.for. 
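+  // Then walk up to the outermost scf.for that directly yields the result of
+  // its inner loop, since the untiled consumer lives outside that top-level
+  // loop.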
+ Operation *containingOp = candidateSliceOp->getParentOp(); + auto forOp = dyn_cast(containingOp); + if (!forOp) + return failure(); + LoopLikeOpInterface topLevelForOp = + scfX::getOuterNestLoopsWhile(forOp, isForOpYieldResultOfInnerLoop) + .front(); + Value resultingValue = topLevelForOp->getResult(resultNumber); + + return getConsumerFromUses(resultingValue, topLevelForOp->getBlock()); +} + +/// Fetch the first untiled consumer of a scf.forall's result which is yielded +/// by a tensor.parallel_insert_slice. +static FailureOr +getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) { + // Step 1. Fetch the corresponding output + Value sliceDest = candidateSliceOp.getDest(); + auto iterArg = dyn_cast(sliceDest); + if (!iterArg) + return failure(); + Operation *containingOp = iterArg.getOwner()->getParentOp(); + if (containingOp != candidateSliceOp->getParentOp()->getParentOp()) + return failure(); + // Step 2. Check that the containing op is scf.forall. + auto forallOp = dyn_cast(containingOp); + if (!forallOp) + return failure(); + Value resultingValue = + forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg)); + + return getConsumerFromUses(resultingValue, containingOp->getBlock()); +} + +/// This utility currently checks whether the first userOp of loop is NOT before +/// the last defineOp of consumer. Currently we clone the loop op right before +/// a certain op in order to maintain a valid def-use chain. This utility thus +/// helps ensuring that no invalid IR is formed due to the same. E.g. +/// +/// ``` +/// %0 = scf.for() { +/// +/// } +/// ... +/// %1 = firstUserOfLoop(%0) +/// ... +/// %2 = lastDefOfConsumer +/// ... +/// %3 = consumerOp(%2) +/// ``` +/// +/// If the `firstUserOfLoop`is before `lastDefOfConsumer`, then it would be +/// invalid to clone the loop op right before the `firstUserOfLoop`: +/// +/// ``` +/// %0:2 = scf.for() { +/// %3 = tiledConsumerOp(%2) +/// } +/// %1 = firstUserOfLoop(%0) +/// ... +/// %2 = lastDefOfConsumer +/// ``` +/// +/// To address this issue, this utility would double-check there is no user of +/// `firstUserOfLoop` before `lastDefOfConsumer`. If so, move `firstUserOfLoop` +/// after `lastDefOfConsumer`. Then, it turns out valid as follow: +/// +/// ``` +/// %2 = lastDefOfConsumer +/// %0:2 = scf.for() { +/// %3 = tiledConsumerOp(%2) +/// } +/// %1 = firstUserOfLoop(%0) +/// ``` +/// +/// @param loopOp: loop operation +/// @param consumerOp: consumer operation +/// @param insertPointBefore: which operation we clone the looOp right before +static LogicalResult checkAssumptionForLoop(Operation *loopOp, + Operation *consumerOp, + Operation **insertPointBefore) { + Block *parentBlock = consumerOp->getBlock(); + // loopOp and consumerOp should stay in the same block. + if (loopOp->getBlock() != parentBlock) + return failure(); + + Operation *firstUserOfLoop = consumerOp, *lastDefOfConsumer = loopOp; + // Find the first user of loopOp + for (Operation *userOp : loopOp->getUsers()) { + if (userOp == consumerOp) + continue; + // `ParallelInsertSlice` located inside `InParallelOp` has no same parent + // block with any other types of operation. Thus, just redirecting to its + // parent `InParallelOp`. 
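+    // This way the `isBeforeInBlock` query below always compares two
+    // operations that live in the same block.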
+ if (isa(userOp)) + userOp = userOp->getParentOfType(); + + if (parentBlock != userOp->getBlock()) + return failure(); + + if (userOp->isBeforeInBlock(firstUserOfLoop)) + firstUserOfLoop = userOp; + } + // Find the last define of consumer + for (Value operand : consumerOp->getOperands()) { + // If the operand is `BlockArgument`, auto skip. + if (isa(operand)) + continue; + auto defineOp = operand.getDefiningOp(); + if (defineOp == loopOp) + continue; + if (!defineOp || parentBlock != defineOp->getBlock()) + return failure(); + if (lastDefOfConsumer->isBeforeInBlock(defineOp)) + lastDefOfConsumer = defineOp; + } + if (firstUserOfLoop->isBeforeInBlock(lastDefOfConsumer)) { + // Try to move if possible + if (llvm::all_of(firstUserOfLoop->getUsers(), + [&lastDefOfConsumer, &parentBlock](Operation *userOp) { + return userOp->getBlock() == parentBlock && + lastDefOfConsumer->isBeforeInBlock(userOp); + })) { + // Safely moving + firstUserOfLoop->moveAfter(lastDefOfConsumer); + } else { + return failure(); + } + } + // Set InsertPoint + *insertPointBefore = firstUserOfLoop; + return success(); +} + +/// A utility to fetch an untiled consumer of +/// tensor.insert_slice/tensor.parallel_insert_slice. +static FailureOr getUntiledConsumerFromSlice(Operation *sliceOp) { + if (auto insertSlice = dyn_cast(sliceOp)) { + return getUntiledConsumerFromSlice(insertSlice); + } else if (auto parallelInsertSlice = + dyn_cast(sliceOp)) { + return getUntiledConsumerFromSlice(parallelInsertSlice); + } else { + return failure(); + } +} + +/// After fusing consumer into scf.for we want to modify the scf.yield operation +/// to reflect the same by returning the values yielded by the tiled consumer. +static void +fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp, + TilingResult &tilingResult, + ArrayRef> &resultOffsets, + ArrayRef> &resultSizes, + ArrayRef bbArgs) { + scf::YieldOp oldTerminatorOp = + cast(newForOp.getBody()->getTerminator()); + unsigned totalOldResults = oldTerminatorOp->getNumResults(); + unsigned totalTiledResults = tilingResult.tiledOps[0]->getNumResults(); + SmallVector newYieldOperands; + newYieldOperands.reserve(totalOldResults + totalTiledResults); + for (auto oldResult : oldTerminatorOp.getResults()) { + newYieldOperands.push_back(oldResult); + } + rewriter.setInsertionPointAfter(oldTerminatorOp); + Location loc = newForOp.getLoc(); + for (auto [tiledResult, bbArg, resultOffset, resultSize] : + llvm::zip_equal(tilingResult.tiledOps[0]->getResults(), bbArgs, + resultOffsets, resultSizes)) { + SmallVector strides(resultOffset.size(), + rewriter.getIndexAttr(1)); + Value newInsertSliceOp = rewriter.create( + loc, tiledResult, bbArg, resultOffset, resultSize, strides); + newYieldOperands.push_back(newInsertSliceOp); + } + rewriter.create(loc, newYieldOperands); + rewriter.eraseOp(oldTerminatorOp); +} + +/// After fusing consumer into scf.forall we want to yield each of the resulting +/// values by the tiled consumer within scf.forall.in_parallel region. 
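+///
+/// E.g. (sketch; names are illustrative):
+///
+/// ```
+/// scf.forall ... {
+///   ...
+///   scf.forall.in_parallel {
+///     tensor.parallel_insert_slice %tiledResult into %bbArg[...]
+///   }
+/// }
+/// ```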
+static void +fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp, + SmallVector tiledResults, + ArrayRef> &resultOffsets, + ArrayRef> &resultSizes, + ArrayRef bbArgs) { + scf::InParallelOp newTerminatorOp = newForallOp.getTerminator(); + rewriter.setInsertionPointToStart(newTerminatorOp.getBody()); + Location firstYieldOpLoc = + (*(newTerminatorOp.getYieldingOps().begin())).getLoc(); + for (auto [tiledResult, bbArg, resultOffset, resultSize] : + llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) { + SmallVector strides(resultOffset.size(), + rewriter.getIndexAttr(1)); + rewriter.create( + firstYieldOpLoc, tiledResult, bbArg, resultOffset, resultSize, strides); + } +} + +/// Implementation of fusing consumer of a single slice by computing the +/// slice of the consumer in-place for scf loop. +/// As for `insertSlice`, it also supports nest outer loop structure without +/// any other slice inside. E.g. +/// +/// ``` +/// scf.for() +/// scf.for() +/// scf.for() +/// ... +/// insert_slice +/// yield +/// yield +/// yield +/// ``` +static FailureOr +tileAndFuseConsumerOfSliceImpl(RewriterBase &rewriter, + Operation *candidateSliceOp) { + if (!isa( + candidateSliceOp)) + return failure(); + + bool isInsertSliceOp = isa(candidateSliceOp); + + // 1. Get the consumer of scf.for for the result yielded by + // tensor.insert_slice/parallel_insert_slice. + FailureOr maybeConsumerOpOperand = + getUntiledConsumerFromSlice(candidateSliceOp); + if (failed(maybeConsumerOpOperand)) { + return rewriter.notifyMatchFailure(candidateSliceOp, + "could not fetch consumer to fuse"); + } + OpOperand *consumerOpOperand = *maybeConsumerOpOperand; + Operation *consumerOp = consumerOpOperand->getOwner(); + unsigned operandNumber = consumerOpOperand->getOperandNumber(); + unsigned resultNumber = 0; + if (auto producerResult = dyn_cast(consumerOpOperand->get())) { + resultNumber = producerResult.getResultNumber(); + } else { + return rewriter.notifyMatchFailure( + consumerOp, "consumer op's operand doesn't seem to be an OpResult"); + } + + Operation *oldLoopOp = nullptr; + SmallVector newOuts; + Block *oldLoopBody = nullptr; + unsigned initSize = 0; + unsigned rank = 1; + if (isInsertSliceOp) { + auto forOp = candidateSliceOp->getParentOfType(); + oldLoopOp = forOp; + initSize = forOp.getInits().size(); + } else { + auto forallOp = candidateSliceOp->getParentOfType(); + oldLoopOp = forallOp; + initSize = forallOp.getOutputs().size(); + rank = forallOp.getRank(); + } + + Operation *oldTopLevelLoop = oldLoopOp; + SmallVector oldNestedForOps, newNestedForOps; + if (isInsertSliceOp) { + oldNestedForOps = + scfX::getOuterNestLoopsWhile(cast(oldTopLevelLoop), + isForOpYieldResultOfInnerLoop); + oldTopLevelLoop = oldNestedForOps.front(); + } + + // 2.a Check assumption for loop and find suitable insertPoint that loop + // structure would be cloned right before. + Operation *insertPointBefore = nullptr; + if (failed(checkAssumptionForLoop(oldTopLevelLoop, consumerOp, + &insertPointBefore))) { + return rewriter.notifyMatchFailure( + oldTopLevelLoop, "containing loop op does not satisfy the assumption " + "and no suitable insertPoint is found"); + } + + OpBuilder::InsertionGuard g(rewriter); + + // 2.b Check consumer is not using scf loop's output as init. 
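+  // Otherwise the fused consumer would have to read its init through the very
+  // loop being rewritten, which cannot be expressed after fusion.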
+ auto dstOp = cast(consumerOp); + SmallVector dpsInits = + llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; }); + if (llvm::is_contained(dpsInits, oldTopLevelLoop->getResult(resultNumber))) { + return rewriter.notifyMatchFailure( + consumerOp, + "consumer op taking the result of scf.for as init is not supported"); + } + SmallVector newInitAppend = dpsInits; + + Location loc = oldLoopOp->getLoc(); + + // 3. Create new scf loop op. + rewriter.setInsertionPoint(insertPointBefore); + + // 3.a Create new outer scf loops if necessary + bool isNestedForOps = isInsertSliceOp && oldNestedForOps.size() > 1; + if (isNestedForOps) { + for (auto &&[index, loopOp] : + llvm::enumerate(MutableArrayRef(oldNestedForOps).drop_back())) { + auto forOp = cast(loopOp); + SmallVector newInits; + newInits = llvm::to_vector(forOp.getInits()); + newInits.append(newInitAppend.begin(), newInitAppend.end()); + auto newLoop = rewriter.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), newInits); + newInitAppend = llvm::map_to_vector( + newLoop.getRegionIterArgs().take_back(newInitAppend.size()), + [](BlockArgument bArg) -> Value { return bArg; }); + rewriter.mergeBlocks( + forOp.getBody(), newLoop.getBody(), + newLoop.getBody()->getArguments().take_front(initSize + 1)); + rewriter.replaceOp( + forOp, newLoop->getResults().take_front(forOp->getNumResults())); + newNestedForOps.push_back(newLoop); + rewriter.setInsertionPointAfter(oldNestedForOps[index + 1]); + } + } + + // 3.b Create new inner most scf loop + Operation *newLoopOp = nullptr; + Block *newLoopBody = nullptr; + if (isInsertSliceOp) { + auto forOp = cast(oldLoopOp); + llvm::append_range(newOuts, forOp.getInits()); + newOuts.append(newInitAppend); + oldLoopBody = forOp.getBody(); + auto newForOp = rewriter.create(loc, forOp.getLowerBound(), + forOp.getUpperBound(), + forOp.getStep(), newOuts); + newLoopOp = newForOp; + newLoopBody = newForOp.getBody(); + } else { + auto forallOp = cast(oldLoopOp); + llvm::append_range(newOuts, forallOp.getOutputs()); + newOuts.append(newInitAppend); + oldLoopBody = forallOp.getBody(); + auto newForallOp = rewriter.create( + loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(), + forallOp.getMixedStep(), newOuts, forallOp.getMapping()); + newLoopOp = newForallOp; + rewriter.eraseOp(newForallOp.getTerminator()); + newLoopBody = newForallOp.getBody(); + } + + // 4. Move the loop body to the new op. + unsigned oldNumArguments = oldLoopBody->getNumArguments(); + rewriter.mergeBlocks(oldLoopBody, newLoopBody, + newLoopBody->getArguments().take_front(oldNumArguments)); + + // 5. Set insertion point before terminator op of the loop and create a new + // tensor.insert_slice. In the scf.for case this is a clone of the + // candidateSliceOp whereas in the scf.forall case this is created from the + // operands of tensor.parallel_insert_slice. + tensor::InsertSliceOp clonedInsertSliceOp; + if (auto sliceOp = + dyn_cast(candidateSliceOp)) { + auto newForallOp = cast(newLoopOp); + rewriter.setInsertionPoint(newForallOp.getTerminator()); + clonedInsertSliceOp = rewriter.create( + loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(), + sliceOp.getMixedSizes(), sliceOp.getMixedStrides()); + } else { + rewriter.setInsertionPoint(candidateSliceOp); + clonedInsertSliceOp = + cast(rewriter.clone(*candidateSliceOp)); + } + + // 6.a. Clone consumer op. 
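+  // The clone's destination operands are redirected to the block arguments
+  // newly appended to the loop, so the tiled consumer writes into the
+  // loop-carried values instead of the original inits.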
+  auto newForOpBlockArgsForConsumerDest =
+      newLoopBody->getArguments().drop_front(oldNumArguments);
+  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
+      rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
+
+  // 6.b. Replace all uses of the loop result with the result of the cloned
+  // tensor.insert_slice.
+  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
+  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
+    operandToReplace.set(clonedInsertSliceOp.getResult());
+  });
+
+  // 7. Perform tiling of the cloned consumer and replace the operand at
+  // `operandNumber` with the source of the cloned tensor.insert_slice op.
+  auto ossSliceOp =
+      cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
+  FailureOr<TilingResult> tileAndFuseResult =
+      tensor::replaceInsertSliceWithTiledConsumer(
+          rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
+  if (failed(tileAndFuseResult)) {
+    return failure();
+  }
+  rewriter.replaceAllUsesWith(
+      tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
+      clonedInsertSliceOp.getSource());
+
+  // 8. Extract the offsets/sizes/strides required to create the
+  // tensor.insert_slice/parallel_insert_slice for each result of the
+  // consumer.
+  SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
+  SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
+
+  // 9. Check that all insert strides are 1.
+  if (llvm::any_of(strides, [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(
+        candidateSliceOp, "containingOp's result yield with stride");
+  }
+
+  // 10. Try to get the iteration domain position from the operand position.
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
+          rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
+          iterDomainSizes))) {
+    return rewriter.notifyMatchFailure(
+        clonedConsumerOp, "can't get iter domain position from input position");
+  }
+
+  // 11. Try to fetch the offset and size for all results of the cloned
+  // consumer. These are then used to form the corresponding
+  // tensor.insert_slice/parallel_insert_slice later.
+  unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
+  SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+      totalNumResultsOfConsumer);
+  SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
+  for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
+    if (failed(clonedConsumerOp.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultOffsets[idx], resultSizes[idx]))) {
+      return rewriter.notifyMatchFailure(
+          clonedConsumerOp,
+          "can't get result domain position from iter domain position");
+    }
+  }
+
+  auto arrayRefOffsets = ArrayRef<SmallVector<OpFoldResult>>(resultOffsets);
+  auto arrayRefSizes = ArrayRef<SmallVector<OpFoldResult>>(resultSizes);
+  if (isInsertSliceOp) {
+    auto newForOp = cast<scf::ForOp>(newLoopOp);
+    fixTerminatorSCFYield(
+        rewriter, newForOp, *tileAndFuseResult, arrayRefOffsets, arrayRefSizes,
+        newForOp.getBody()->getArguments().drop_front(1 + initSize));
+  } else {
+    auto newForallOp = cast<scf::ForallOp>(newLoopOp);
+    fixTerminatorSCFInParallel(
+        rewriter, newForallOp, tileAndFuseResult->tiledOps[0]->getResults(),
+        arrayRefOffsets, arrayRefSizes,
+        newForallOp.getBody()->getArguments().drop_front(rank + initSize));
+  }
+
+  // 12. Restore the outer loops from inner to outer.
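+  // Sketch of the intended effect (assuming two nested scf.for): each outer
+  // loop's scf.yield is rebuilt to additionally forward the inner loop's
+  // appended results, e.g.
+  //   scf.yield %inner#0  ->  scf.yield %inner#0, %inner#1
+  // where the trailing results carry the fused consumer's values outward.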
+  if (isNestedForOps) {
+    newNestedForOps.push_back(cast<scf::ForOp>(newLoopOp));
+    for (auto [outerLoop, innerLoop] :
+         llvm::zip_equal(MutableArrayRef(newNestedForOps).drop_back(),
+                         MutableArrayRef(newNestedForOps).drop_front())) {
+      auto forOp = cast<scf::ForOp>(outerLoop);
+      auto outerLoopYield =
+          cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+      SmallVector<Value> newYields =
+          llvm::to_vector(outerLoopYield.getOperands());
+      ValueRange additionalYields =
+          innerLoop->getResults().take_back(newInitAppend.size());
+      newYields.append(additionalYields.begin(), additionalYields.end());
+      rewriter.setInsertionPoint(outerLoopYield);
+      rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
+    }
+  }
+
+  // 13. Replace the results of the scf loop and the consumer op with the new
+  // loop's results.
+  for (auto &&[oldResult, newResult] :
+       llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
+    rewriter.replaceAllUsesWith(oldResult, newResult);
+  }
+
+  Operation *newTopLevelLoop =
+      isNestedForOps ? newNestedForOps.front() : newLoopOp;
+  for (auto &&[oldResult, newResult] :
+       llvm::zip(consumerOp->getResults(),
+                 newTopLevelLoop->getResults().drop_front(initSize))) {
+    rewriter.replaceAllUsesWith(oldResult, newResult);
+  }
+
+  // 14. Erase the old scf loop and the cloned consumer op.
+  rewriter.eraseOp(oldLoopOp);
+  rewriter.eraseOp(clonedConsumerOp);
+
+  // 15. Erase the cloned insertSliceOp and the unused extractSliceOp to
+  // avoid complex dominance analysis.
+  assert(clonedInsertSliceOp->hasOneUse());
+  auto unUsedExtractOp =
+      cast<tensor::ExtractSliceOp>((*clonedInsertSliceOp->getUsers().begin()));
+  rewriter.eraseOp(unUsedExtractOp);
+  rewriter.eraseOp(clonedInsertSliceOp);
+
+  return scf::SCFFuseConsumerOfSliceResult{
+      consumerOpOperand,
+      &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
+      tileAndFuseResult->tiledOps};
+}
+
+/// Fuse the real consumer of a single slice, even within complex nested
+/// loops, via repeated application of `tileAndFuseConsumerOfSliceImpl`.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scfX::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
+                                       Operation *candidateSliceOp) {
+  SmallVector<OffsetSizeAndStrideOpInterface> forwardSlice;
+  if (failed(getRealConsumersFromInsertSliceOp(candidateSliceOp, forwardSlice)))
+    return failure();
+
+  FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResult;
+  // Reverse from outer to inner.
+  std::reverse(forwardSlice.begin(), forwardSlice.end());
+  // Repeated application of `tileAndFuseConsumerOfSliceImpl`.
+  for (auto &sliceOp : forwardSlice) {
+    fuseConsumerResult = tileAndFuseConsumerOfSliceImpl(rewriter, sliceOp);
+    if (failed(fuseConsumerResult)) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "could not fuse consumer of sliceOp");
+    }
+  }
+  return fuseConsumerResult;
+}
\ No newline at end of file
diff --git a/lib/gc/Transforms/TilingUsingInterfaceX.h b/lib/gc/Transforms/TilingUsingInterfaceX.h
new file mode 100644
index 000000000..778021e94
--- /dev/null
+++ b/lib/gc/Transforms/TilingUsingInterfaceX.h
@@ -0,0 +1,41 @@
+//===-- TilingUsingInterfaceX.h - upstream eXtension -----------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEMPORARY_TILEUSINGINTERFACE_X_H
+#define TEMPORARY_TILEUSINGINTERFACE_X_H
+
+#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
+
+namespace mlir {
+namespace scfX {
+
+SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile(
+    LoopLikeOpInterface loop,
+    const std::function<LogicalResult(LoopLikeOpInterface)> &pred);
+
+FailureOr<OpResult> getRealProducerOfExtractSliceOp(
+    Operation *candidateSliceOp,
+    SmallVector<tensor::ExtractSliceOp> &backwardSlice, unsigned curDepth = 0,
+    unsigned maxDepth = 5);
+
+FailureOr<SmallVector<OpOperand *>> getRealConsumersFromInsertSliceOp(
+    Operation *candidateSliceOp,
+    SmallVector<OffsetSizeAndStrideOpInterface> &forwardSlice,
+    unsigned curDepth = 0, unsigned maxDepth = 5);
+
+// Extension for upstream `tileAndFuseProducerOfSlice`
+std::optional<scf::SCFFuseProducerOfSliceResult>
+tileAndFuseProducerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
+
+// Extension for upcoming upstream `tileAndFuseConsumerOfSlice`
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
+} // namespace scfX
+} // namespace mlir
+
+#endif
diff --git a/test/gc/Transform/iterative-tiling-and-fusion.mlir b/test/gc/Transform/iterative-tiling-and-fusion.mlir
new file mode 100644
index 000000000..50989867b
--- /dev/null
+++ b/test/gc/Transform/iterative-tiling-and-fusion.mlir
@@ -0,0 +1,253 @@
+// RUN: gc-opt --split-input-file -iterative-tiling-and-fusion %s --cse
+
+module attributes {
+  dlti.target_system_spec = #dlti.target_system_spec<
+    "CPU": #dlti.target_device_spec<
+      #dlti.dl_entry<"L1_cache_size_in_bytes", 49152 : i32>,
+      #dlti.dl_entry<"L2_cache_size_in_bytes", 2097152 : i32>,
+      #dlti.dl_entry<"L3_cache_size_in_bytes", 110100480 : i32>,
+      #dlti.dl_entry<"num_threads", 56 : i32>,
+      #dlti.dl_entry<"max_vector_width", 512 : i32>>
+  >} {
+  /// CHECK-LABEL: @fuse_mlp
+  func.func @fuse_mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> {
+    %c32 = arith.constant 32 : index
+    %c512 = arith.constant 512 : index
+    %c128 = arith.constant 128 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0.000000e+00 : bf16
+    /// CHECK: tensor.empty
+    %0 = tensor.empty() : tensor<128x256xbf16>
+    %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x256xbf16>) -> tensor<128x256xbf16>
+    /// CHECK: tensor.empty
+    %dest = tensor.empty() : tensor<512x256xbf16>
+    %unpack = tensor.unpack %arg1 inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %dest : tensor<32x8x16x32xbf16> -> tensor<512x256xbf16>
+    /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}, %{{.*}}) in (2, 2)
+    %2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %1) -> (tensor<128x256xbf16>) {
+      %5 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
+      %6 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg4)
+      %7 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
+      %8 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg4)
+      %extracted_slice = tensor.extract_slice %arg0[%5, 0] [64, 512] [1, 1] : tensor<128x512xbf16> to tensor<64x512xbf16>
+      %extracted_slice_0 = tensor.extract_slice %unpack[0, %6] [512, 128] [1, 1] : tensor<512x256xbf16> to tensor<512x128xbf16>
+      %extracted_slice_1 = tensor.extract_slice %arg5[%7, %8] [64, 128] [1, 1] : tensor<128x256xbf16> to tensor<64x128xbf16>
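+      /// Note: consumer fusion is expected to append two extra iter_args to
+      /// every level of this loop nest (for the fused linalg.add and
+      /// linalg.exp partial results), which is why the yields below are
+      /// matched against three tensor<64x128xbf16> values.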
+      /// CHECK: scf.for
+      /// CHECK: scf.for
+      /// CHECK: scf.for
+      %9 = scf.for %arg6 = %c0 to %c64 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<64x128xbf16>) {
+        %12 = scf.for %arg8 = %c0 to %c128 step %c128 iter_args(%arg9 = %arg7) -> (tensor<64x128xbf16>) {
+          %13 = scf.for %arg10 = %c0 to %c512 step %c512 iter_args(%arg11 = %arg9) -> (tensor<64x128xbf16>) {
+            %extracted_slice_2 = tensor.extract_slice %extracted_slice[%arg6, %arg10] [64, 512] [1, 1] : tensor<64x512xbf16> to tensor<64x512xbf16>
+            %extracted_slice_3 = tensor.extract_slice %extracted_slice_0[%arg10, %arg8] [512, 128] [1, 1] : tensor<512x128xbf16> to tensor<512x128xbf16>
+            %extracted_slice_4 = tensor.extract_slice %arg11[%arg6, %arg8] [64, 128] [1, 1] : tensor<64x128xbf16> to tensor<64x128xbf16>
+            /// CHECK: scf.for
+            /// CHECK: scf.for
+            /// CHECK: scf.for
+            %14 = scf.for %arg12 = %c0 to %c64 step %c32 iter_args(%arg13 = %extracted_slice_4) -> (tensor<64x128xbf16>) {
+              %15 = scf.for %arg14 = %c0 to %c128 step %c32 iter_args(%arg15 = %arg13) -> (tensor<64x128xbf16>) {
+                %16 = scf.for %arg16 = %c0 to %c512 step %c512 iter_args(%arg17 = %arg15) -> (tensor<64x128xbf16>) {
+                  %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg12, %arg16] [32, 512] [1, 1] : tensor<64x512xbf16> to tensor<32x512xbf16>
+                  %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[%arg16, %arg14] [512, 32] [1, 1] : tensor<512x128xbf16> to tensor<512x32xbf16>
+                  %extracted_slice_7 = tensor.extract_slice %arg17[%arg12, %arg14] [32, 32] [1, 1] : tensor<64x128xbf16> to tensor<32x32xbf16>
+                  /// CHECK: %[[UNPACK_OUT:.*]] = tensor.unpack
+                  /// CHECK: %[[FILL_OUT:.*]] = linalg.fill
+                  /// CHECK: %[[EXPAND_OUT_1:.*]] = tensor.expand_shape
+                  /// CHECK: %[[EXPAND_OUT_2:.*]] = tensor.expand_shape
+                  %expanded = tensor.expand_shape %extracted_slice_5 [[0, 1], [2]] output_shape [1, 32, 512] : tensor<32x512xbf16> into tensor<1x32x512xbf16>
+                  %expanded_8 = tensor.expand_shape %extracted_slice_6 [[0, 1], [2]] output_shape [1, 512, 32] : tensor<512x32xbf16> into tensor<1x512x32xbf16>
+                  /// CHECK: %[[MATMUL_OUT:.*]] = linalg.batch_reduce_matmul ins(%[[EXPAND_OUT_1]], %[[EXPAND_OUT_2]] :
+                  %17 = linalg.batch_reduce_matmul ins(%expanded, %expanded_8 : tensor<1x32x512xbf16>, tensor<1x512x32xbf16>) outs(%extracted_slice_7 : tensor<32x32xbf16>) -> tensor<32x32xbf16>
+                  /// CHECK: %[[BROADCAST_OUT:.*]] = linalg.broadcast
+                  /// CHECK: %[[ADD_OUT:.*]] = linalg.add ins(%[[MATMUL_OUT]], %[[BROADCAST_OUT]] :
+                  /// CHECK: %[[EXP_OUT:.*]] = linalg.exp ins(%[[ADD_OUT]] :
+                  %inserted_slice_9 = tensor.insert_slice %17 into %arg17[%arg12, %arg14] [32, 32] [1, 1] : tensor<32x32xbf16> into tensor<64x128xbf16>
+                  /// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}} : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>
+                  scf.yield %inserted_slice_9 : tensor<64x128xbf16>
+                }
+                /// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}} : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>
+                scf.yield %16 : tensor<64x128xbf16>
+              }
+              /// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}} : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>
+              scf.yield %15 : tensor<64x128xbf16>
+            }
+            /// CHECK: tensor.insert_slice
+            /// CHECK: tensor.insert_slice
+            /// CHECK: tensor.insert_slice
+            %inserted_slice = tensor.insert_slice %14 into %arg11[%arg6, %arg8] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<64x128xbf16>
+            /// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}} : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>
+            scf.yield %inserted_slice : tensor<64x128xbf16>
+          }
+          /// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}} : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>
+          scf.yield %13 : tensor<64x128xbf16>
+        }
+        /// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}} : tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>
+        scf.yield %12 : tensor<64x128xbf16>
+      }
+      %10 = affine.apply affine_map<(d0) -> (d0 * 64)>(%arg3)
+      %11 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg4)
+      scf.forall.in_parallel {
+        /// CHECK: tensor.parallel_insert_slice
+        /// CHECK: tensor.parallel_insert_slice
+        /// CHECK: tensor.parallel_insert_slice
+        tensor.parallel_insert_slice %9 into %arg5[%10, %11] [64, 128] [1, 1] : tensor<64x128xbf16> into tensor<128x256xbf16>
+      }
+    }
+    %broadcasted = linalg.broadcast ins(%arg2 : tensor<256xbf16>) outs(%0 : tensor<128x256xbf16>) dimensions = [0]
+    %3 = linalg.add ins(%2, %broadcasted : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%0 : tensor<128x256xbf16>) -> tensor<128x256xbf16>
+    %4 = linalg.exp ins(%3 : tensor<128x256xbf16>) outs(%0 : tensor<128x256xbf16>) -> tensor<128x256xbf16>
+    /// CHECK: return %[[FINAL_RESULT]]#2
+    return %4 : tensor<128x256xbf16>
+  }
+}
+
+// -----
+
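+/// This case checks that two consumers (linalg.add and linalg.mul) of the
+/// same tiled matmul result are both fused into the loop nest, so the final
+/// scf.forall is expected to carry one extra result per fused consumer.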
+#map = affine_map<(d0) -> (d0 * 128)>
+module attributes {
+  dlti.target_system_spec = #dlti.target_system_spec<
+    "CPU": #dlti.target_device_spec<
+      #dlti.dl_entry<"L1_cache_size_in_bytes", 49152 : i32>,
+      #dlti.dl_entry<"L2_cache_size_in_bytes", 2097152 : i32>,
+      #dlti.dl_entry<"L3_cache_size_in_bytes", 110100480 : i32>,
+      #dlti.dl_entry<"num_threads", 56 : i32>,
+      #dlti.dl_entry<"max_vector_width", 512 : i32>>
+  >} {
+  /// CHECK-LABEL: @fuse_multiple_consumer
+  func.func @fuse_multiple_consumer(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>, %arg3: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
+    %c0 = arith.constant 0 : index
+    %c64 = arith.constant 64 : index
+    %c128 = arith.constant 128 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %dest0 = tensor.empty() : tensor<256x256xf32>
+    %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}, %{{.*}}) in (2, 2)
+    %1 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %dest1) -> tensor<256x256xf32> {
+      %iv0 = affine.apply #map(%arg4)
+      %iv1 = affine.apply #map(%arg5)
+      %extracted_slice_1 = tensor.extract_slice %arg6[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
+      %extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
+      %extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
+      /// CHECK: scf.for
+      /// CHECK: scf.for
+      %2 = scf.for %arg7 = %c0 to %c128 step %c64 iter_args(%arg8 = %extracted_slice_1) -> (tensor<128x128xf32>) {
+        %3 = scf.for %arg9 = %c0 to %c128 step %c64 iter_args(%arg10 = %arg8) -> (tensor<128x128xf32>) {
+          %extracted_slice_4 = tensor.extract_slice %arg10[%arg7, %arg9] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
+          %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg7, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
+          %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg9] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
+          /// CHECK: %[[MATMUL_OUT:.*]] = linalg.matmul
+          %4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
+          /// CHECK: %[[MUL_OUT:.*]] = linalg.mul ins(%[[MATMUL_OUT]],
+          /// CHECK: %[[ADD_OUT:.*]] = linalg.add ins(%[[MATMUL_OUT]],
+          %insert_slice = tensor.insert_slice %4 into %arg10[%arg7, %arg9] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
+          /// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}} : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
+          scf.yield %insert_slice : tensor<128x128xf32>
+        }
+        /// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}} : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
+        scf.yield %3 : tensor<128x128xf32>
+      }
+      scf.forall.in_parallel {
+        /// CHECK: tensor.parallel_insert_slice
+        /// CHECK: tensor.parallel_insert_slice
+        /// CHECK: tensor.parallel_insert_slice
+        tensor.parallel_insert_slice %2 into %arg6[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
+      }
+    }
+    %5 = linalg.add ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    %6 = linalg.mul ins(%1, %arg3 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    /// CHECK: return %[[FINAL_RESULT]]#2, %[[FINAL_RESULT]]#1
+    return %5, %6 : tensor<256x256xf32>, tensor<256x256xf32>
+  }
+}
+
+// -----
+
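+/// This case checks post-op fusion across a reduction: the linalg.add is
+/// expected to be fused into the tiled loop nest, while the trailing
+/// linalg.reduce is fused at the scf.forall level, where its reduction
+/// dimension is still fully available.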
+#map = affine_map<(d0) -> (d0 * 128)>
+module attributes {
+  dlti.target_system_spec = #dlti.target_system_spec<
+    "CPU": #dlti.target_device_spec<
+      #dlti.dl_entry<"L1_cache_size_in_bytes", 49152 : i32>,
+      #dlti.dl_entry<"L2_cache_size_in_bytes", 2097152 : i32>,
+      #dlti.dl_entry<"L3_cache_size_in_bytes", 110100480 : i32>,
+      #dlti.dl_entry<"num_threads", 56 : i32>,
+      #dlti.dl_entry<"max_vector_width", 512 : i32>>
+  >} {
+  /// CHECK-LABEL: @fuse_reduce
+  func.func @fuse_reduce(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256xf32> {
+    %c0 = arith.constant 0 : index
+    %c64 = arith.constant 64 : index
+    %c128 = arith.constant 128 : index
+    %c256 = arith.constant 256 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %dest0 = tensor.empty() : tensor<256x256xf32>
+    %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}, %{{.*}}) in (2, 1)
+    %1 = scf.forall (%arg3, %arg4) in (2, 1) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
+      %iv0 = affine.apply #map(%arg3)
+      %iv1 = affine.apply #map(%arg4)
+      %extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 256] [1, 1] : tensor<256x256xf32> to tensor<128x256xf32>
+      %extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
+      %extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 256] [1, 1] : tensor<512x256xf32> to tensor<512x256xf32>
+      /// CHECK: %[[FOR_RESULT:.*]]:2 = scf.for
+      /// CHECK: scf.for
+      %2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x256xf32>) {
+        %3 = scf.for %arg8 = %c0 to %c256 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x256xf32>) {
+          %extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x256xf32> to tensor<64x64xf32>
+          %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
+          %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x256xf32> to tensor<512x64xf32>
+          /// CHECK: %[[MATMUL_OUT:.*]] = linalg.matmul
+          %4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
+          /// CHECK: %[[ADD_OUT:.*]] = linalg.add ins(%[[MATMUL_OUT]],
+          %insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x256xf32>
+          /// CHECK: scf.yield {{.*}}, {{.*}} : tensor<128x256xf32>, tensor<128x256xf32>
+          scf.yield %insert_slice : tensor<128x256xf32>
+        }
+        /// CHECK: scf.yield {{.*}}, {{.*}} : tensor<128x256xf32>, tensor<128x256xf32>
+        scf.yield %3 : tensor<128x256xf32>
+      }
+      /// CHECK: %[[REDUCE_OUT:.*]] = linalg.reduce { arith.addf } ins(%[[FOR_RESULT]]#1 :
+      scf.forall.in_parallel {
+        /// CHECK: tensor.parallel_insert_slice
+        /// CHECK: tensor.parallel_insert_slice
+        /// CHECK: tensor.parallel_insert_slice
+        tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 256] [1, 1] : tensor<128x256xf32> into tensor<256x256xf32>
+      }
+    }
+    %5 = linalg.add ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    %dest2 = tensor.empty() : tensor<256xf32>
+    %6 = linalg.reduce { arith.addf } ins(%5 : tensor<256x256xf32>) outs(%dest2 : tensor<256xf32>) dimensions = [1]
+    /// CHECK: return %[[FINAL_RESULT]]#2
+    return %6 : tensor<256xf32>
+  }
+}
+
+// -----
+
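+/// This case starts from untiled linalg ops: the pass is expected to tile
+/// the linalg.reduce by the default tile size and then fuse the producer
+/// linalg.add into the resulting scf.forall.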
+module attributes {
+  dlti.target_system_spec = #dlti.target_system_spec<
+    "CPU": #dlti.target_device_spec<
+      #dlti.dl_entry<"L1_cache_size_in_bytes", 49152 : i32>,
+      #dlti.dl_entry<"L2_cache_size_in_bytes", 2097152 : i32>,
+      #dlti.dl_entry<"L3_cache_size_in_bytes", 110100480 : i32>,
+      #dlti.dl_entry<"num_threads", 56 : i32>,
+      #dlti.dl_entry<"max_vector_width", 512 : i32>>
+  >} {
+  /// CHECK-LABEL: @fuse_with_default_tiling
+  func.func @fuse_with_default_tiling(%arg0: tensor<128x256x256xf32>, %arg1: tensor<128x256x256xf32>) -> tensor<128x256xf32> {
+    %dest0 = tensor.empty() : tensor<128x256x256xf32>
+    %0 = linalg.add ins(%arg0, %arg1 : tensor<128x256x256xf32>, tensor<128x256x256xf32>) outs(%dest0 : tensor<128x256x256xf32>) -> tensor<128x256x256xf32>
+    %dest1 = tensor.empty() : tensor<128x256xf32>
+    /// CHECK: %[[FINAL_RESULT:.*]] = scf.forall (%{{.*}}, %{{.*}}) in (128, 256)
+    /// CHECK: tensor.extract_slice {{.*}} [1, 256, 1] [1, 1, 1]
+    /// CHECK: tensor.extract_slice {{.*}} [1, 256, 1] [1, 1, 1]
+    /// CHECK: tensor.extract_slice {{.*}} [1, 256, 1] [1, 1, 1]
+    /// CHECK: %[[ADD_OUT:.*]] = linalg.add
+    /// CHECK: tensor.extract_slice {{.*}} [1, 1] [1, 1]
+    /// CHECK: %[[REDUCE_OUT:.*]] = linalg.reduce { arith.addf } ins(%[[ADD_OUT]] :
+    %1 = linalg.reduce { arith.addf } ins(%0 : tensor<128x256x256xf32>) outs(%dest1 : tensor<128x256xf32>) dimensions = [1]
+    /// CHECK: scf.forall.in_parallel
+    /// CHECK: tensor.parallel_insert_slice
+    /// CHECK: return %[[FINAL_RESULT]]
+    return %1 : tensor<128x256xf32>
+  }
+}