diff --git a/include/gc/Analysis/GlobalAnalysis.h b/include/gc/Analysis/GlobalAnalysis.h new file mode 100644 index 000000000..1e7d6beac --- /dev/null +++ b/include/gc/Analysis/GlobalAnalysis.h @@ -0,0 +1,165 @@ +//===- GlobalAnalysis.h - Graph Compiler analysis pass ----------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_GLOBALANALYSIS_H +#define MLIR_ANALYSIS_GLOBALANALYSIS_H + +#include + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/Debug.h" + +namespace mlir { +namespace gc { + +using namespace mlir; + +class TensorLayout { +public: + TensorLayout(ArrayRef outerAxis, ArrayRef innerAxis, + ArrayRef tileSizes) + : outerAxis(outerAxis), innerAxis(innerAxis), tileSizes(tileSizes) { + assert(innerAxis.size() == tileSizes.size()); + } + + static bool isPlainOuterAxis(ArrayRef outerAxis) { + for (int64_t i = 0; i < static_cast(outerAxis.size()); ++i) { + if (i != outerAxis[i]) + return false; + } + return true; + } + + bool isPlain() const { + if (isPlainOuterAxis(outerAxis)) + return tileSizes.empty() && innerAxis.empty(); + return false; + } + + bool isBlocking() const { return !tileSizes.empty() && !innerAxis.empty(); } + + static TensorLayout createPlainLayout(int64_t rank) { + SmallVector outerAxis(rank, 0); + std::iota(outerAxis.begin(), outerAxis.end(), 0); + return TensorLayout(outerAxis, SmallVector{}, + SmallVector{}); + } + + DenseMap> getPlainToPackedAxisMapping() { + DenseMap> axisMapping; + int64_t outerAxisSize = outerAxis.size(); + for (int64_t i = 0; i < outerAxisSize; ++i) { + axisMapping[outerAxis[i]].push_back(i); + } + for (int64_t i = 0; i < static_cast(innerAxis.size()); ++i) { + axisMapping[innerAxis[i]].push_back(outerAxisSize + i); + } + return axisMapping; + } + + int64_t getPlainAxis(int64_t idx) { + int64_t totalRank = outerAxis.size() + innerAxis.size(); + assert(idx >= 0 && idx < totalRank && "Provided plain axis out of bound"); + if (idx >= static_cast(outerAxis.size())) { + return innerAxis[idx - outerAxis.size()]; + } else { + return outerAxis[idx]; + } + } + + size_t getRank() const { return outerAxis.size(); } + + SmallVector getOuterAxis() const { return outerAxis; } + + SmallVector getInnerAxis() const { return innerAxis; } + + SmallVector getTileSizes() const { return tileSizes; } + + friend llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, + const TensorLayout &layout); + + bool operator==(const TensorLayout &other) const; + + bool operator!=(const TensorLayout &other) const; + +private: + SmallVector outerAxis; + SmallVector innerAxis; + SmallVector tileSizes; +}; + +class OperatorLayout { +public: + OperatorLayout() {} + + OperatorLayout(SmallVector inputLayouts, + SmallVector outputLayouts) { + supportedInputLayouts = inputLayouts; + supportedOutputLayouts = outputLayouts; + } + + SmallVector getSupportedInputLayouts() const { + return supportedInputLayouts; + } + + SmallVector getSupportedOutputLayouts() const { + return supportedOutputLayouts; + } + + TensorLayout getOutputLayout(int64_t idx) const { + assert(idx < static_cast(supportedOutputLayouts.size())); + return supportedOutputLayouts[idx]; + } + + bool 
isPlain() const { + for (const auto &layout : llvm::concat<const TensorLayout>( + supportedInputLayouts, supportedOutputLayouts)) { + if (!layout.isPlain()) + return false; + } + return true; + } + + friend llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, + const OperatorLayout &opLayout); + +private: + SmallVector<TensorLayout> supportedInputLayouts; + SmallVector<TensorLayout> supportedOutputLayouts; +}; + +class GlobalAnalysis { +public: + explicit GlobalAnalysis(Operation *root); + + FailureOr<OperatorLayout> getOpLayout(Operation *op) { + if (layoutCache.find(op) != layoutCache.end()) + return layoutCache[op]; + else + return failure(); + } + +private: + DenseMap<Operation *, OperatorLayout> layoutCache; +}; + +namespace utils { +bool isSupportedContractionNamedOp(const linalg::LinalgOp &linalgOp); + +bool isPackableOp(Operation *op); + +bool hasAllTensorSemantics(linalg::LinalgOp linalgOp); +} // namespace utils +} // namespace gc +} // namespace mlir + +#endif diff --git a/include/gc/Analysis/MatmulConfigAnalysis.h b/include/gc/Analysis/MatmulConfigAnalysis.h index 2b275f246..3507f6edc 100644 --- a/include/gc/Analysis/MatmulConfigAnalysis.h +++ b/include/gc/Analysis/MatmulConfigAnalysis.h @@ -128,6 +128,12 @@ getOprandDimType(linalg::LinalgOp &linalgOp) { SmallVector<DimType>{DimType::M, DimType::K, DimType::M, DimType::K}, SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N}, SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}}; + } else if (linalgx::isGenericPackedMatmulOp(linalgOp.getOperation(), + linalgx::PackingType::MM2D4D)) { + return SmallVector<SmallVector<DimType>>{ + SmallVector<DimType>{DimType::M, DimType::K}, + SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N}, + SmallVector<DimType>{DimType::M, DimType::N}}; } return failure(); } diff --git a/include/gc/Dialect/Linalgx/Utils.h b/include/gc/Dialect/Linalgx/Utils.h index 5bc83b449..d5281b60a 100644 --- a/include/gc/Dialect/Linalgx/Utils.h +++ b/include/gc/Dialect/Linalgx/Utils.h @@ -20,6 +20,7 @@ namespace linalgx { /// @brief enum of type of matmul packing enum class PackingType : int { MM4D = 0, // MKmk x NKkn + MM2D4D, // MK x NKkn VNNI_MM2D, // MK x NKknV VNNI_MM4D, // MKmk x NKknV VNNI_BRMM3D, // BMK x BKNV diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index 5151a0335..905c4f4eb 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -169,6 +169,40 @@ def MergeNestedForall : Pass<"merge-nested-forall"> { let dependentDialects = ["scf::SCFDialect"]; } +def PropagateLayoutOnNamedOps : Pass<"propagate-layout-on-named-ops"> { + let summary = "Insert and propagate tensor.pack to pack the computation of linalg named ops and tensor ops."; + let description = [{ + Insert tensor.pack ops and propagate them across linalg named ops and tensor ops. + }]; + let dependentDialects = [ + "mlir::tensor::TensorDialect", + "mlir::linalg::LinalgDialect", + "mlir::linalgx::LinalgxDialect" + ]; +} + +def PostProcessPackUnpack : Pass<"post-process-pack-unpack"> { + let summary = "Fold and simplify pack and unpack ops."; + let description = [{ + Fold and simplify pack and unpack ops. + }]; + let dependentDialects = [ + "mlir::tensor::TensorDialect", + "mlir::linalg::LinalgDialect" + ]; +} + +def LowerPackUnpack : Pass<"lower-pack-unpack"> { + let summary = "Lower pack and unpack ops."; + let description = [{ + Lower pack and unpack ops into transpose and shape-related ops.
+ }]; + let dependentDialects = [ + "mlir::tensor::TensorDialect", + "mlir::linalg::LinalgDialect" + ]; +} + def FoldTensorOperation : Pass<"fold-tensor-operation"> { let summary = "Fold some tensor operation"; let description = [{ @@ -179,6 +213,7 @@ def FoldTensorOperation : Pass<"fold-tensor-operation"> { ]; } + def LowerToTileVector : Pass<"lower-to-tile-vector", "func::FuncOp"> { let summary = "Lower tensor to tile (virtual) vector"; let description = [{ diff --git a/include/gc/Transforms/Transforms.h b/include/gc/Transforms/Transforms.h new file mode 100644 index 000000000..1b10cc64f --- /dev/null +++ b/include/gc/Transforms/Transforms.h @@ -0,0 +1,27 @@ +//===-- Transforms.h - transformation utilities -----------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef GC_TRANSFORMS_TRANSFORMS_H +#define GC_TRANSFORMS_TRANSFORMS_H + +#include "gc/Analysis/GlobalAnalysis.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" + +namespace mlir { +namespace gc { +LogicalResult packLinalgOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp, + const OperatorLayout &opLayout); + +LogicalResult namedOpLayoutPropagation(RewriterBase &rewriter, + linalg::LinalgOp linalgOp, + OperatorLayout opLayout); +} // namespace gc +} // namespace mlir + +#endif // GC_TRANSFORMS_TRANSFORMS_H diff --git a/lib/gc/Analysis/CMakeLists.txt b/lib/gc/Analysis/CMakeLists.txt index d7160f350..55a689d86 100644 --- a/lib/gc/Analysis/CMakeLists.txt +++ b/lib/gc/Analysis/CMakeLists.txt @@ -5,6 +5,7 @@ gc_set_mlir_link_components(MLIR_LINK_COMPONENTS gc_add_mlir_library(GcAnalysis TargetDescriptionAnalysis.cpp MatmulConfigAnalysis.cpp + GlobalAnalysis.cpp DEPENDS GraphCompilerPassIncGen diff --git a/lib/gc/Analysis/GlobalAnalysis.cpp b/lib/gc/Analysis/GlobalAnalysis.cpp new file mode 100644 index 000000000..44aba2e75 --- /dev/null +++ b/lib/gc/Analysis/GlobalAnalysis.cpp @@ -0,0 +1,651 @@ +//===-- GlobalAnalysis.cpp - Infer layout on packable ops -------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include + +#include "gc/Analysis/GlobalAnalysis.h" +#include "gc/Analysis/MatmulConfigAnalysis.h" +#include "llvm/ADT/SetOperations.h" + +namespace mlir { +namespace gc { + +#define DEBUG_TYPE "global-analysis" + +namespace utils { +// TODO(yifei): extend to batch matmuls, sync with deep tile matmul +bool isSupportedContractionNamedOp(const linalg::LinalgOp &linalgOp) { + return isa(linalgOp); +} + +bool isPackableOp(Operation *op) { + if (auto linalgOp = dyn_cast(op)) { + if (!mlir::linalg::isaContractionOpInterface(linalgOp) && + !mlir::linalg::isaConvolutionOpInterface(linalgOp) && + !isSupportedContractionNamedOp(linalgOp)) { + return true; + } + } else if (isa( + op)) + return true; + return false; +} + +bool hasAllTensorSemantics(linalg::LinalgOp linalgOp) { + SmallVector initOperands = llvm::to_vector(llvm::map_range( + linalgOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; })); + SmallVector inputOperands = linalgOp.getDpsInputOperands(); + return llvm::all_of(inputOperands, + [](OpOperand *opOperand) { + return mlir::isa( + opOperand->get().getType()); + }) && + llvm::all_of(initOperands, [](OpOperand *opOperand) { + return mlir::isa(opOperand->get().getType()); + }); +} +} // namespace utils + +llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, + const TensorLayout &layout) { + SmallVector outerAxis = layout.getOuterAxis(); + SmallVector innerAxis = layout.getInnerAxis(); + SmallVector tileSizes = layout.getTileSizes(); + ss << "["; + llvm::interleaveComma(outerAxis, ss); + if (!innerAxis.empty()) { + ss << "; "; + llvm::interleaveComma(innerAxis, ss); + } + ss << "]"; + if (!tileSizes.empty()) { + ss << "; {"; + llvm::interleaveComma(tileSizes, ss); + ss << "}"; + } + return ss; +} + +bool TensorLayout::operator==(const TensorLayout &other) const { + return (this->outerAxis == other.getOuterAxis()) && + (this->innerAxis == other.getInnerAxis()) && + (this->tileSizes == other.getTileSizes()); +} + +bool TensorLayout::operator!=(const TensorLayout &other) const { + return !(*this == other); +} + +llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, + const OperatorLayout &opLayout) { + if (!opLayout.getSupportedInputLayouts().empty()) { + ss << "Input layouts: "; + llvm::interleave(opLayout.getSupportedInputLayouts(), ss, "; "); + ss << ". "; + } + if (!opLayout.getSupportedOutputLayouts().empty()) { + ss << "Output layouts: "; + llvm::interleave(opLayout.getSupportedOutputLayouts(), ss, "; "); + ss << ". "; + } + return ss; +} + +// infer the relation between two indexing maps +// returns target dim -> base dim, means target is the same as base +// we don't allow duplication, e.g. 
two target dims mapping to one base dim +static FailureOr<DenseMap<int64_t, int64_t>> +inferIndexingMapRelation(AffineMap indexingMapBase, + AffineMap indexingMapTarget) { + // symbols are not allowed to occur + if (indexingMapBase.getNumSymbols() != 0 || + indexingMapTarget.getNumSymbols() != 0) + return failure(); + DenseMap<int64_t, int64_t> res; + ArrayRef<AffineExpr> resultsBase = indexingMapBase.getResults(); + ArrayRef<AffineExpr> resultsTarget = indexingMapTarget.getResults(); + for (size_t j = 0; j < resultsTarget.size(); ++j) { + for (size_t i = 0; i < resultsBase.size(); ++i) { + auto base = dyn_cast<AffineDimExpr>(resultsBase[i]); + auto target = dyn_cast<AffineDimExpr>(resultsTarget[j]); + if (base && target && base.getPosition() == target.getPosition()) { + // dim j is already mapped to some i + if (res.find(j) != res.end()) + return failure(); + res[j] = i; + } + } + if (res.find(j) == res.end()) + res[j] = -1; + } + DenseSet<int64_t> indexSet; + for (auto pair : res) { + if (indexSet.find(pair.second) != indexSet.end()) { + return failure(); + } + if (pair.second >= 0) { + indexSet.insert(pair.second); + } + } + return res; +} + +// given target --> base and max rank of base, return base --> target +static DenseMap<int64_t, int64_t> +getReversedIndexMap(const DenseMap<int64_t, int64_t> &indexMap, + size_t maxRank) { + DenseMap<int64_t, int64_t> res; + for (auto pair : indexMap) { + if (pair.second >= 0) { + res[pair.second] = pair.first; + } + } + for (size_t i = 0; i < maxRank; ++i) { + if (res.find(i) == res.end()) { + res[i] = -1; + } + } + return res; +} + +static FailureOr<TensorLayout> inferTargetLayout(const TensorLayout &layoutBase, + AffineMap indexingMapBase, + AffineMap indexingMapTarget) { + SmallVector<int64_t> baseOuterAxis = layoutBase.getOuterAxis(); + SmallVector<int64_t> baseInnerAxis = layoutBase.getInnerAxis(); + SmallVector<OpFoldResult> baseTileSizes = layoutBase.getTileSizes(); + SmallVector<int64_t> targetOuterAxis; + SmallVector<int64_t> targetInnerAxis; + SmallVector<OpFoldResult> targetTileSizes; + FailureOr<DenseMap<int64_t, int64_t>> indexMap = + inferIndexingMapRelation(indexingMapBase, indexingMapTarget); + if (failed(indexMap)) + return failure(); + DenseMap<int64_t, int64_t> reverseIndexMap = + getReversedIndexMap(*indexMap, layoutBase.getRank()); + for (int64_t oa : baseOuterAxis) { + if (reverseIndexMap[oa] >= 0) + targetOuterAxis.push_back(reverseIndexMap[oa]); + } + // fill in target axes that have no counterpart in the base + SmallVector<int64_t> newDimAxis; + for (const auto &pair : *indexMap) { + if (pair.second < 0) + newDimAxis.push_back(pair.first); + } + // TODO(yifei): reconsider the performance impact of pushing all new axes + // to the beginning of the outer permutation + targetOuterAxis.insert(targetOuterAxis.begin(), newDimAxis.begin(), + newDimAxis.end()); + for (auto &&[ia, ts] : llvm::zip(baseInnerAxis, baseTileSizes)) { + if (reverseIndexMap[ia] >= 0) { + targetInnerAxis.push_back(reverseIndexMap[ia]); + targetTileSizes.push_back(ts); + } + } + return TensorLayout(targetOuterAxis, targetInnerAxis, targetTileSizes); +} + +// TODO(yifei): enhance the logic for choosing the base input index +static size_t getBaseInputIdx(ArrayRef<TensorLayout> curInputLayouts) { + for (size_t i = 0; i < curInputLayouts.size(); ++i) { + if (!curInputLayouts[i].isPlain()) { + return i; + } + } + return 0; +} + +std::pair<SmallVector<int64_t>, SmallVector<int64_t>> +getPackingAxis(int64_t numRank, bool transposed) { + assert(numRank >= 2 && + "The rank of a matmul-semantic contraction op shall be at least 2."); + SmallVector<int64_t> outerAxisPerm(numRank); + SmallVector<int64_t> innerAxisPos(2); + std::iota(outerAxisPerm.begin(), outerAxisPerm.end(), 0); + innerAxisPos[0] = numRank - 2; + innerAxisPos[1] = numRank - 1; + if (transposed) { + std::swap(outerAxisPerm[numRank - 2], outerAxisPerm[numRank - 1]); + std::swap(innerAxisPos[0],
innerAxisPos[1]); + } + return std::make_pair(outerAxisPerm, innerAxisPos); +} + +// copied from mlir +static SmallVector +projectToInnerMostNonUnitDimsPos(ArrayRef dimsPos, + ArrayRef reassocIndices, + ArrayRef targetShape) { + SmallVector projectedDimsPos; + for (auto pos : dimsPos) { + // In the case all dims are unit, this will return the inner-most one. + int64_t projectedPos = reassocIndices[pos].back(); + for (auto i : llvm::reverse(reassocIndices[pos])) { + int64_t dim = targetShape[i]; + if (dim > 1 || ShapedType::isDynamic(dim)) { + projectedPos = i; + break; + } + } + projectedDimsPos.push_back(projectedPos); + } + return projectedDimsPos; +} + +// Check if all dims in dimsPos are divisible by the corresponding tile sizes. +static bool isDimsDivisibleByTileSizes(ArrayRef dimsPos, + ArrayRef shape, + ArrayRef tileSizes) { + return llvm::all_of(llvm::zip_equal(dimsPos, tileSizes), + [shape](std::tuple sizePair) { + int64_t dim = shape[std::get<0>(sizePair)]; + return !ShapedType::isDynamic(dim) && + (dim % std::get<1>(sizePair)) == 0; + }); +} + +// if forceBlocking is set to true, we will unconditionally convert +// input/weight/output to blocking layout; otherwise we follow the default +// heuristic logic +static SmallVector +queryMatmulLayout(IRRewriter &rewriter, linalg::LinalgOp matmulOp, + ArrayRef curInputLayouts, + bool forceBlocking = false) { + SmallVector ret; + // infer layout for linalg contraction named ops + auto ARank = matmulOp.getRank(matmulOp.getDpsInputOperand(0)); + auto BRank = matmulOp.getRank(matmulOp.getDpsInputOperand(1)); + auto CRank = matmulOp.getRank(matmulOp.getDpsInitOperand(0)); + auto elementType = getElementTypeOrSelf(matmulOp.getDpsInputs()[0].getType()); + auto AShape = matmulOp.getShape(matmulOp.getDpsInputOperand(0)); + auto BShape = matmulOp.getShape(matmulOp.getDpsInputOperand(1)); + int64_t M = AShape[0], K = AShape[1], N = BShape[1]; + bool ASideTransposed = + isa( + matmulOp); + bool BSideTransposed = + isa( + matmulOp); + // set outer&inner axis values + auto APackInfo = getPackingAxis(ARank, ASideTransposed); + auto BPackInfo = getPackingAxis(BRank, BSideTransposed); + auto CPackInfo = getPackingAxis(CRank, /*transposed*/ false); + // query the cost model for tile sizes + MatmulConfig cfg = MatmulConfigAnalysis(matmulOp.getOperation()).getConfig(); + uint32_t iim = cfg.innerMostMBlock, iin = cfg.innerMostNBlock, + iik = cfg.innerMostKBlock; + if (forceBlocking) { + TensorLayout ALayout(APackInfo.first, APackInfo.second, + SmallVector{rewriter.getIndexAttr(iim), + rewriter.getIndexAttr(iik)}); + TensorLayout BLayout(BPackInfo.first, BPackInfo.second, + SmallVector{rewriter.getIndexAttr(iik), + rewriter.getIndexAttr(iin)}); + TensorLayout CLayout(CPackInfo.first, CPackInfo.second, + SmallVector{rewriter.getIndexAttr(iim), + rewriter.getIndexAttr(iin)}); + ret.emplace_back(SmallVector{ALayout, BLayout}, + SmallVector{CLayout}); + return ret; + } + // TODO(yifei): add detailed check for constant A or B + bool constantA = false, constantB = true; + SmallVector ALayouts, BLayouts, CLayouts; + if (constantA || curInputLayouts[0].isBlocking() || (M % iim) || (K % iik) || + (elementType.isBF16() && + curInputLayouts[0] == TensorLayout({1, 0}, {}, {}))) { + ALayouts.emplace_back( + APackInfo.first, APackInfo.second, + SmallVector{rewriter.getIndexAttr(iim), + rewriter.getIndexAttr(iik)}); + } else { + ALayouts.emplace_back(APackInfo.first, SmallVector{}, + SmallVector{}); + } + if (constantB || curInputLayouts[1].isBlocking() || K % iik || N % 
iin || + elementType.isBF16()) { + BLayouts.emplace_back( + BPackInfo.first, BPackInfo.second, + SmallVector{rewriter.getIndexAttr(iik), + rewriter.getIndexAttr(iin)}); + } else { + BLayouts.emplace_back(BPackInfo.first, SmallVector{}, + SmallVector{}); + } + if (M == iim && M >= 32 && N % iin == 0) { + CLayouts.emplace_back(CPackInfo.first, SmallVector{}, + SmallVector{}); + } else if (M % iim || N % iin) { + CLayouts.emplace_back( + CPackInfo.first, CPackInfo.second, + SmallVector{rewriter.getIndexAttr(iim), + rewriter.getIndexAttr(iin)}); + } else { + if (BSideTransposed) { + CLayouts.emplace_back(CPackInfo.first, SmallVector{}, + SmallVector{}); + } else { + // push 2 possibilities + CLayouts.emplace_back(CPackInfo.first, SmallVector{}, + SmallVector{}); + CLayouts.emplace_back( + CPackInfo.first, CPackInfo.second, + SmallVector{rewriter.getIndexAttr(iim), + rewriter.getIndexAttr(iin)}); + // duplicate ALayouts and BLayouts + ALayouts.emplace_back(ALayouts[0]); + BLayouts.emplace_back(BLayouts[0]); + } + } + for (auto [ALayout, BLayout, CLayout] : + llvm::zip(ALayouts, BLayouts, CLayouts)) { + ret.emplace_back(SmallVector{ALayout, BLayout}, + SmallVector{CLayout}); + } + return ret; +} + +GlobalAnalysis::GlobalAnalysis(Operation *root) { + IRRewriter rewriter(root); + // stage 1: calculate the total number of layout combination + int64_t totalLayoutPossibilities = 1; + std::vector possibilities; + root->walk([&](Operation *op) { + if (auto linalgOp = dyn_cast(op)) { + if (mlir::gc::utils::isSupportedContractionNamedOp(linalgOp)) { + auto curInputs = linalgOp.getDpsInputOperands(); + SmallVector curInputLayouts; + for (auto input : curInputs) + curInputLayouts.push_back(TensorLayout::createPlainLayout( + linalgOp.getMatchingIndexingMap(input).getNumResults())); + auto suggestedLayouts = + queryMatmulLayout(rewriter, linalgOp, curInputLayouts); + possibilities.push_back(suggestedLayouts.size()); + totalLayoutPossibilities *= possibilities.back(); + } + } + return WalkResult::advance(); + }); + // define cost function + auto computePackingCost = + [&](Operation *op, ArrayRef curInputLayouts, + ArrayRef suggestedLayouts = {}) -> int64_t { + int64_t cost = 0; + assert(op->getOperands().size() >= curInputLayouts.size() && + "curInputLayouts size out of range."); + for (auto [index, curLayout] : llvm::enumerate(curInputLayouts)) { + TensorLayout suggestedLayout = + suggestedLayouts.empty() + ? 
TensorLayout::createPlainLayout(curLayout.getRank()) + : suggestedLayouts[index]; + if (curLayout != suggestedLayout) { + ArrayRef shape = + cast(op->getOperands()[index].getType()) + .getShape(); + int64_t inputSize = std::accumulate( + shape.begin(), shape.end(), (int64_t)1, std::multiplies()); + if (suggestedLayout.isBlocking()) + cost += inputSize * 0.9; + else + cost += inputSize; + } + } + return cost; + }; + std::vector curChoice(possibilities.size(), 0); + int64_t bestCost = std::numeric_limits::max(); + // stage 2: infer layout for each possibility + for (int64_t trialIdx = 0; trialIdx < totalLayoutPossibilities; ++trialIdx) { + // stage 2.1: get the current layout choice + int64_t tmpIdx = trialIdx; + for (size_t i = 0; i < possibilities.size(); i++) { + curChoice[i] = tmpIdx % possibilities[i]; + tmpIdx /= possibilities[i]; + } + LLVM_DEBUG(llvm::dbgs() << "Inferring with layout choice: ["); + LLVM_DEBUG(llvm::interleaveComma(curChoice, llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << "].\n"); + int64_t curMatmulIdx = 0; + int64_t curCost = 0; + // stage 2.2: infer the current temp layout for the whole graph + DenseMap tmpLayoutCache; + root->walk([&](Operation *op) { + LLVM_DEBUG(llvm::dbgs() + << "Try inferring layout for op: " << op->getName() << "\n"); + if (auto linalgOp = dyn_cast(op)) { + auto curInputs = linalgOp.getDpsInputOperands(); + auto curResults = linalgOp.getOperation()->getResults(); + // if any input/output is not tensor, skip it + if (!gc::utils::hasAllTensorSemantics(linalgOp)) { + LLVM_DEBUG( + llvm::dbgs() + << "Op " << linalgOp.getOperation()->getName() + << " contains non-tensor operand. Skip layout inference.\n"); + return WalkResult::skip(); + } + // get current op's input layouts + SmallVector curInputLayouts; + for (auto input : curInputs) { + auto parent = input->get().getDefiningOp(); + if (tmpLayoutCache.find(parent) != tmpLayoutCache.end()) { + // TODO(yifei): extend to cases with multiple outputs + curInputLayouts.push_back( + tmpLayoutCache[parent].getOutputLayout(0)); + } else { + curInputLayouts.push_back(TensorLayout::createPlainLayout( + linalgOp.getMatchingIndexingMap(input).getNumResults())); + } + } + // start infer current op's layout + if (mlir::gc::utils::isSupportedContractionNamedOp(linalgOp)) { + auto suggestedLayouts = + queryMatmulLayout(rewriter, linalgOp, curInputLayouts, false); + tmpLayoutCache[linalgOp] = + suggestedLayouts[curChoice[curMatmulIdx++]]; + curCost += computePackingCost( + op, curInputLayouts, + tmpLayoutCache[linalgOp].getSupportedInputLayouts()); + } else if (mlir::gc::utils::isPackableOp(op)) { + // infer layout for non-contraction/non-convolution linalg named ops + // and linalg generic ops + SmallVector inputLayouts, outputLayouts; + size_t baseIdx = getBaseInputIdx(curInputLayouts); + // infer layout for inputs + for (size_t i = 0; i < curInputs.size(); ++i) { + if (i != baseIdx) { + FailureOr inferredLayout = inferTargetLayout( + curInputLayouts[baseIdx], + linalgOp.getMatchingIndexingMap(curInputs[baseIdx]), + linalgOp.getMatchingIndexingMap(curInputs[i])); + if (failed(inferredLayout)) { + LLVM_DEBUG( + llvm::dbgs() + << "Op " << linalgOp.getOperation()->getName() + << "'s input " << i + << "'s layout cannot be inferred. 
Choose plain layout.\n"); + curCost += computePackingCost(op, curInputLayouts); + return WalkResult::skip(); + } + inputLayouts.push_back(*inferredLayout); + } else { + inputLayouts.push_back(curInputLayouts[baseIdx]); + } + } + // infer layout for output + FailureOr inferredOutputLayout = inferTargetLayout( + curInputLayouts[baseIdx], + linalgOp.getMatchingIndexingMap(curInputs[baseIdx]), + linalgOp.getIndexingMapMatchingResult(curResults[0])); + if (failed(inferredOutputLayout)) { + LLVM_DEBUG(llvm::dbgs() + << "Op " << linalgOp.getOperation()->getName() + << "'s output layout cannot be inferred. Choose plain " + "layout.\n"); + curCost += computePackingCost(op, curInputLayouts); + return WalkResult::skip(); + } + outputLayouts.push_back(*inferredOutputLayout); + OperatorLayout suggestedLayout(inputLayouts, outputLayouts); + tmpLayoutCache[linalgOp] = suggestedLayout; + curCost += computePackingCost(op, curInputLayouts, inputLayouts); + } + } else if (auto padOp = dyn_cast(op)) { + auto inputOperand = padOp.getSource(); + auto inputRank = + cast(inputOperand.getType()).getShape().size(); + auto parent = inputOperand.getDefiningOp(); + TensorLayout curInputLayout = + tmpLayoutCache.find(parent) != tmpLayoutCache.end() + ? tmpLayoutCache[parent].getOutputLayout(0) + : TensorLayout::createPlainLayout(inputRank); + SmallVector inputLayouts{curInputLayout}, + outputLayouts{curInputLayout}; + OperatorLayout suggestedLayout(inputLayouts, outputLayouts); + tmpLayoutCache[padOp] = suggestedLayout; + } else if (auto expandShapeOp = dyn_cast(op)) { + SmallVector reassocIndices = + expandShapeOp.getReassociationIndices(); + auto staticOutputShape = expandShapeOp.getStaticOutputShape(); + auto parent = expandShapeOp.getSrc().getDefiningOp(); + auto inputShape = expandShapeOp.getSrcType().getShape(); + TensorLayout curInputLayout = + tmpLayoutCache.find(parent) != tmpLayoutCache.end() + ? tmpLayoutCache[parent].getOutputLayout(0) + : TensorLayout::createPlainLayout(inputShape.size()); + SmallVector innerTileSizes; + auto tileSizes = getConstantIntValues(curInputLayout.getTileSizes()); + if (tileSizes) { + innerTileSizes = *tileSizes; + } else { + LLVM_DEBUG(llvm::dbgs() + << "ExpandShapeOp's layout cannot be penetrated. Skip.\n"); + curCost += + computePackingCost(op, SmallVector{curInputLayout}); + return WalkResult::skip(); + } + SmallVector innerPosPos = curInputLayout.getInnerAxis(); + SmallVector outerDimsPerm = curInputLayout.getOuterAxis(); + SmallVector projectedInnerDimsPos = + projectToInnerMostNonUnitDimsPos(innerPosPos, reassocIndices, + staticOutputShape); + + if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, + staticOutputShape, innerTileSizes)) { + LLVM_DEBUG(llvm::dbgs() + << "ExpandShapeOp's layout cannot be penetrated. 
Skip.\n"); + return WalkResult::skip(); + } + SmallVector newOuterDimsPerm; + for (auto outerPos : outerDimsPerm) { + newOuterDimsPerm.insert(newOuterDimsPerm.end(), + reassocIndices[outerPos].begin(), + reassocIndices[outerPos].end()); + } + TensorLayout outputLayout(newOuterDimsPerm, projectedInnerDimsPos, + curInputLayout.getTileSizes()); + SmallVector inputLayouts{curInputLayout}, + outputLayouts{outputLayout}; + OperatorLayout suggestedLayout(inputLayouts, outputLayouts); + tmpLayoutCache[expandShapeOp] = suggestedLayout; + } else if (auto collapseShapeOp = dyn_cast(op)) { + SmallVector reassocIndices = + collapseShapeOp.getReassociationIndices(); + auto parent = collapseShapeOp.getSrc().getDefiningOp(); + auto inputShape = collapseShapeOp.getSrcType().getShape(); + TensorLayout curInputLayout = + tmpLayoutCache.find(parent) != tmpLayoutCache.end() + ? tmpLayoutCache[parent].getOutputLayout(0) + : TensorLayout::createPlainLayout(inputShape.size()); + auto innerPos = curInputLayout.getInnerAxis(); + llvm::SetVector innerPosSet(innerPos.begin(), innerPos.end()); + for (auto [idx, indices] : llvm::enumerate(reassocIndices)) { + // For each reassociation, figure out which dimensions get packed if + // any. + llvm::SetVector collapseDimPos(indices.begin(), + indices.end()); + llvm::SetVector packedDims = + llvm::set_intersection(innerPosSet, collapseDimPos); + // only one of the collapsed indices can be packed + if (packedDims.size() > 1) { + LLVM_DEBUG( + llvm::dbgs() + << "CollapseShapeOp's layout cannot be penetrated. Skip.\n"); + return WalkResult::skip(); + } + // Only the inner-most expanded dimension should be packed. Otherwise, + // elements order will be affected after operation reordering. + if (!packedDims.empty() && packedDims[0] != indices.back()) { + LLVM_DEBUG( + llvm::dbgs() + << "CollapseShapeOp's layout cannot be penetrated. Skip.\n"); + return WalkResult::skip(); + } + } + // Project pack.inner_dims_pos to positions before shape expansion. + SmallVector projectedInnerDimsPos; + for (auto pos : innerPos) { + for (auto [idx, indices] : llvm::enumerate(reassocIndices)) { + if (llvm::any_of(indices, [&](int64_t collapseDim) { + return collapseDim == pos; + })) { + projectedInnerDimsPos.push_back(idx); + break; + } + } + } + assert(projectedInnerDimsPos.size() == innerPos.size() && + "Invalid dim pos projection"); + + // outerPerm shall be a permutation of reassocIndices + auto outerPerm = curInputLayout.getOuterAxis(); + SmallVector newOuterDimsPerm; + int64_t axisIdx = 0; + while (axisIdx < static_cast(outerPerm.size())) { + for (auto [idx, indices] : llvm::enumerate(reassocIndices)) { + if (llvm::any_of(indices, [&](int64_t collapseDim) { + return collapseDim == outerPerm[axisIdx]; + })) { + for (auto collapseDim : indices) { + if (collapseDim != outerPerm[axisIdx++]) { + LLVM_DEBUG(llvm::dbgs() << "CollapseShapeOp's layout cannot " + "be penetrated. 
Skip.\n"); + return WalkResult::skip(); + } + } + newOuterDimsPerm.push_back(idx); + break; + } + } + } + TensorLayout outputLayout(newOuterDimsPerm, projectedInnerDimsPos, + curInputLayout.getTileSizes()); + SmallVector inputLayouts{curInputLayout}, + outputLayouts{outputLayout}; + OperatorLayout suggestedLayout(inputLayouts, outputLayouts); + tmpLayoutCache[collapseShapeOp] = suggestedLayout; + } + if (tmpLayoutCache.find(op) != tmpLayoutCache.end()) { + LLVM_DEBUG(llvm::dbgs() << "Inferred layout of op: " << op->getName() + << " is: " << tmpLayoutCache[op] << "\n"); + } + return WalkResult::advance(); + }); + if (curCost < bestCost) { + bestCost = curCost; + layoutCache = tmpLayoutCache; + LLVM_DEBUG(llvm::dbgs() + << "Current cost " << curCost + << " is lower than the best cost; update best cost." + << "\n"); + } + } +} +} // namespace gc +} // namespace mlir diff --git a/lib/gc/Dialect/Linalgx/Utils.cpp b/lib/gc/Dialect/Linalgx/Utils.cpp index fe9096fe7..683038940 100644 --- a/lib/gc/Dialect/Linalgx/Utils.cpp +++ b/lib/gc/Dialect/Linalgx/Utils.cpp @@ -341,6 +341,12 @@ PackingAttr getPackingAttr(PackingType opType) { attr.nPacking = {PackingMap{{0}, {1}}, PackingMap{{3}, {3}}}; attr.kPacking = {PackingMap{{1}, {1}}, PackingMap{{3}, {2}}}; } break; + case PackingType::MM2D4D: { + attr.weightDims = 4; + attr.mPacking = {PackingMap{{0}, {0}}}; + attr.nPacking = {PackingMap{{0, 3}, {1}}}; + attr.kPacking = {PackingMap{{1}, {1, 2}}}; + } break; case PackingType::VNNI_MM2D: { attr.isVnni = true; attr.weightDims = 5; diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index 2d10ed88f..37c79b9b9 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -16,6 +16,9 @@ gc_add_mlir_library(GcPasses TileUsingInterfaceX.cpp IterativeTilingAndFusion.cpp VerifyTargetDescription.cpp + PropagateLayout.cpp + PostProcessPackUnpack.cpp + LowerPackUnpack.cpp DecomposeAggregatedOps.cpp DeepTileContractionOp.cpp TilingUtil.cpp diff --git a/lib/gc/Transforms/DeepTileContractionOp.cpp b/lib/gc/Transforms/DeepTileContractionOp.cpp index 21de7b778..8805c9dc5 100644 --- a/lib/gc/Transforms/DeepTileContractionOp.cpp +++ b/lib/gc/Transforms/DeepTileContractionOp.cpp @@ -952,7 +952,8 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern { return llvm::isa(linalgOp) || linalgx::isGenericPackedMatmulOp( linalgOp.getOperation(), linalgx::PackingType::VNNI_MM2D, - linalgx::PackingType::VNNI_MM4D, linalgx::PackingType::MM4D); + linalgx::PackingType::VNNI_MM4D, linalgx::PackingType::MM4D, + linalgx::PackingType::MM2D4D); } LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, diff --git a/lib/gc/Transforms/LowerPackUnpack.cpp b/lib/gc/Transforms/LowerPackUnpack.cpp new file mode 100644 index 000000000..efd96a0dd --- /dev/null +++ b/lib/gc/Transforms/LowerPackUnpack.cpp @@ -0,0 +1,84 @@ +//===-- LowerPackUnpack.cpp - Lower pack unpack into linalg ops -*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include +#include + +#include "gc/Transforms/Transforms.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "gc/Dialect/Linalgx/LinalgxDialect.h" +#include "gc/Dialect/Linalgx/LinalgxOps.h" +#include "gc/Transforms/Passes.h" +namespace mlir { +namespace gc { +#define GEN_PASS_DEF_LOWERPACKUNPACK +#include "gc/Transforms/Passes.h.inc" + +#define DEBUG_TYPE "lower-pack-unpack" + +using namespace mlir; + +// copied from tpp +// A wrapper pattern that calls linalg::lowerPack on tensor::PackOp. It lowers +// a tensor.pack op to tensor.pad + tensor.expand_shape + linalg.transpose ops. +struct LowerPackPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::PackOp op, + PatternRewriter &rewriter) const override { + FailureOr res = linalg::lowerPack(rewriter, op); + if (failed(res)) { + return rewriter.notifyMatchFailure( + op, "cannot lower to pad + expand + transpose"); + } + return success(); + } +}; + +// A wrapper pattern that calls linalg::lowerUnPack on tensor::UnPackOp. It +// lowers a tensor.unpack op to tensor.empty + linalg.transpose + +// tensor.collapse_shape + tensor.extract_slice ops. 
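To make the shape bookkeeping behind these two lowerings concrete, the following standalone sketch computes the packed shape that the pad + expand_shape + transpose sequence arrives at; the 130x260 shape, the 32x32 tiles, and the helper name are invented purely for illustration and are not part of this patch.

```cpp
// Sketch of the shape arithmetic behind pack lowering: pad each tiled dim up
// to a multiple of its tile size, split it into (outer, tile), permute the
// outer dims, and append the tile dims innermost.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

static std::vector<int64_t> packedShape(std::vector<int64_t> shape,
                                        const std::vector<int64_t> &innerDimsPos,
                                        const std::vector<int64_t> &innerTiles,
                                        const std::vector<int64_t> &outerDimsPerm) {
  // Step 1: pad tiled dims up to a multiple of the tile size.
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    int64_t &d = shape[innerDimsPos[i]];
    d = (d + innerTiles[i] - 1) / innerTiles[i] * innerTiles[i];
  }
  // Step 2: outer dims become the tiled (divided) sizes, optionally permuted.
  std::vector<int64_t> outer = shape;
  for (size_t i = 0; i < innerDimsPos.size(); ++i)
    outer[innerDimsPos[i]] /= innerTiles[i];
  std::vector<int64_t> result;
  if (outerDimsPerm.empty())
    result = outer;
  else
    for (int64_t p : outerDimsPerm)
      result.push_back(outer[p]);
  // Step 3: tile dims are appended innermost, in innerDimsPos order.
  result.insert(result.end(), innerTiles.begin(), innerTiles.end());
  return result;
}

int main() {
  // A 130x260 tensor packed with inner_dims_pos = [0, 1] and
  // inner_tiles = [32, 32]: pad to 160x288, then pack to 5x9x32x32.
  for (int64_t d : packedShape({130, 260}, {0, 1}, {32, 32}, {}))
    std::cout << d << " "; // prints: 5 9 32 32
  std::cout << "\n";
}
```

The unpack lowering runs the inverse sequence: transpose the tile dims back next to their outer counterparts, collapse each (outer, tile) pair, and slice away the padding.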
+struct LowerUnPackPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::UnPackOp op, + PatternRewriter &rewriter) const override { + if (failed(linalg::lowerUnPack(rewriter, op))) { + return rewriter.notifyMatchFailure( + op, "cannot lower to empty + transpose + reshape + extract_slice"); + } + return success(); + } +}; + +class LowerPackUnpack : public impl::LowerPackUnpackBase { +public: + using impl::LowerPackUnpackBase::LowerPackUnpackBase; + void runOnOperation() final; +}; + +void LowerPackUnpack::runOnOperation() { + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + patterns.add(ctx); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); +} + +} // namespace gc +} // namespace mlir diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index 6fdc445cf..443918a20 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -52,6 +52,8 @@ void populateFrontendPasses(mlir::OpPassManager &pm) { void populateTensorPasses(mlir::OpPassManager &pm) { // todo: padding propagation pass // todo: layout propagation pass + pm.addPass(createPropagateLayoutOnNamedOps()); + pm.addPass(createPostProcessPackUnpack()); // todo: tensor constant propagation pass // linalg.matmul lowering to (scf.loop + linalg.brgemm) pass pm.addNestedPass(createDeepTileContractionOp()); @@ -72,6 +74,8 @@ void populateTensorPasses(mlir::OpPassManager &pm) { pm.addPass(createFoldTensorOperation()); pm.addPass(createLoopInvariantCodeMotionPass()); pm.addPass(createControlFlowSinkPass()); + // TODO(yifei): remove lower pack here + pm.addPass(createLowerPackUnpack()); populateCleanUpPasses(pm); } diff --git a/lib/gc/Transforms/PostProcessPackUnpack.cpp b/lib/gc/Transforms/PostProcessPackUnpack.cpp new file mode 100644 index 000000000..6b5c2336c --- /dev/null +++ b/lib/gc/Transforms/PostProcessPackUnpack.cpp @@ -0,0 +1,167 @@ +//===-- PostProcessPackUnpack.cpp - Simplify pack unpack --------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "gc/Transforms/Passes.h" +#include "gc/Transforms/Transforms.h" + +namespace mlir { +namespace gc { +#define GEN_PASS_DEF_POSTPROCESSPACKUNPACK +#include "gc/Transforms/Passes.h.inc" + +using namespace mlir; + +// copied from tpp - lower tensor.pack operations that pack constants. +struct LowerConstantPacking : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::PackOp packOp, + PatternRewriter &rewriter) const override { + auto constOp = packOp.getSource().getDefiningOp(); + if (!constOp) + return failure(); + // Must be a dense constant. 
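What folding a pack over a dense constant buys is that the relayout happens once at compile time rather than on every execution. A rough standalone illustration of that relayout, with a 4x4 constant and 2x2 tiles chosen only for this example:

```cpp
// Illustration (not part of this patch): pre-packing a constant 4x4 row-major
// matrix into a 2x2x2x2 blocked layout, which is what folding a tensor.pack
// over a dense constant amounts to.
#include <array>
#include <cstdint>
#include <iostream>

int main() {
  constexpr int64_t n = 4, tile = 2;
  std::array<int64_t, n * n> plain{};
  for (int64_t i = 0; i < n * n; ++i)
    plain[i] = i; // plain[r][c] = r * n + c
  std::array<int64_t, n * n> packed{};
  // packed[ro][co][ri][ci] = plain[ro * tile + ri][co * tile + ci]
  for (int64_t ro = 0; ro < n / tile; ++ro)
    for (int64_t co = 0; co < n / tile; ++co)
      for (int64_t ri = 0; ri < tile; ++ri)
        for (int64_t ci = 0; ci < tile; ++ci)
          packed[((ro * (n / tile) + co) * tile + ri) * tile + ci] =
              plain[(ro * tile + ri) * n + (co * tile + ci)];
  for (int64_t v : packed)
    std::cout << v << " "; // 0 1 4 5 2 3 6 7 8 9 12 13 10 11 14 15
  std::cout << "\n";
}
```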
+ auto denseAttr = dyn_cast(constOp.getValue()); + if (!denseAttr) + return failure(); + + // Bail out if the pack is used as a writing operation i.e., the destination + // is not a tensor.empty. + if (!packOp.getDest().getDefiningOp()) + return rewriter.notifyMatchFailure(packOp, + "expects empty tensor destination"); + // Pack destination must have static shape. + if (!packOp.getDestType().hasStaticShape()) + return rewriter.notifyMatchFailure( + packOp, "expects destination with static shape"); + + // If it is a splat constant, skip and let tensor.pack folder to handle this + // case. + if (denseAttr.isSplat()) + return rewriter.notifyMatchFailure( + packOp, "skip pack - existing folder covers constant splats"); + + return linalg::lowerPack(rewriter, packOp); + } +}; + +static void populateConstantFoldPacking(RewritePatternSet &patterns) { + MLIRContext *ctx = patterns.getContext(); + patterns.add(ctx); + linalg::FillOp::getCanonicalizationPatterns(patterns, ctx); + tensor::PackOp::getCanonicalizationPatterns(patterns, ctx); + tensor::populateRewriteAsConstantPatterns( + patterns, [](OpOperand *) -> bool { return true; }); + linalg::populateConstantFoldLinalgOperations( + patterns, [](OpOperand *) -> bool { return true; }); +} + +struct EliminateDummyPack : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::PackOp packOp, + PatternRewriter &rewriter) const override { + if (packOp.getStaticInnerTiles().empty() && + packOp.getInnerTiles().empty()) { + auto outerPerm = packOp.getOuterDimsPerm(); + for (int64_t i = 0; i < static_cast(outerPerm.size()); ++i) { + if (outerPerm[i] != i) { + return rewriter.notifyMatchFailure(packOp, "Not dummy"); + } + } + auto source = packOp.getSource(); + rewriter.replaceAllOpUsesWith(packOp, source); + packOp->erase(); + return success(); + } else { + return rewriter.notifyMatchFailure(packOp, "Not dummy"); + } + } +}; + +struct EliminateDummyUnpack : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::UnPackOp unpackOp, + PatternRewriter &rewriter) const override { + if (unpackOp.getStaticInnerTiles().empty() && + unpackOp.getInnerTiles().empty()) { + auto outerPerm = unpackOp.getOuterDimsPerm(); + for (int64_t i = 0; i < static_cast(outerPerm.size()); ++i) { + if (outerPerm[i] != i) { + return rewriter.notifyMatchFailure(unpackOp, "Not dummy"); + } + } + auto source = unpackOp.getSource(); + rewriter.replaceAllOpUsesWith(unpackOp, source); + unpackOp->erase(); + return success(); + } else { + return rewriter.notifyMatchFailure(unpackOp, "Not dummy"); + } + } +}; + +static void populateEliminateDummyPackUnpack(RewritePatternSet &patterns) { + MLIRContext *ctx = patterns.getContext(); + patterns.add(ctx); +} + +class PostProcessPackUnpack + : public impl::PostProcessPackUnpackBase { +public: + using impl::PostProcessPackUnpackBase< + PostProcessPackUnpack>::PostProcessPackUnpackBase; + void runOnOperation() final; +}; + +static void populateSimplifyPacking(RewritePatternSet &patterns) { + MLIRContext *ctx = patterns.getContext(); + tensor::populateSimplifyPackAndUnpackPatterns(patterns); + tensor::populateFoldTensorEmptyPatterns(patterns); + tensor::PackOp::getCanonicalizationPatterns(patterns, ctx); + tensor::UnPackOp::getCanonicalizationPatterns(patterns, ctx); + tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx); + tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, ctx); + 
tensor::CastOp::getCanonicalizationPatterns(patterns, ctx); + tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx); + tensor::EmptyOp::getCanonicalizationPatterns(patterns, ctx); + tensor::PadOp::getCanonicalizationPatterns(patterns, ctx); + tensor::ParallelInsertSliceOp::getCanonicalizationPatterns(patterns, ctx); + scf::ForallOp::getCanonicalizationPatterns(patterns, ctx); + ctx->getLoadedDialect()->getCanonicalizationPatterns( + patterns); + tensor::populateReassociativeReshapeFoldingPatterns(patterns); +} + +void PostProcessPackUnpack::runOnOperation() { + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + // constant fold packing and transpose + populateConstantFoldPacking(patterns); + // simplify packing + populateSimplifyPacking(patterns); + populateEliminateDummyPackUnpack(patterns); + // simplify transpose inserted to perform packing + linalg::TransposeOp::getCanonicalizationPatterns(patterns, ctx); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); +} + +} // namespace gc +} // namespace mlir diff --git a/lib/gc/Transforms/PropagateLayout.cpp b/lib/gc/Transforms/PropagateLayout.cpp new file mode 100644 index 000000000..df534c387 --- /dev/null +++ b/lib/gc/Transforms/PropagateLayout.cpp @@ -0,0 +1,824 @@ +//===-- PropagateLayout.cpp - Propagate packing on named ops ----*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include + +#include "gc/Transforms/Transforms.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/DenseMap.h" + +#include "gc/Analysis/MatmulConfigAnalysis.h" +#include "gc/Dialect/Linalgx/LinalgxDialect.h" +#include "gc/Dialect/Linalgx/LinalgxOps.h" +#include "gc/Dialect/Linalgx/Utils.h" +#include "gc/Transforms/Passes.h" +namespace mlir { +namespace gc { +#define GEN_PASS_DEF_PROPAGATELAYOUTONNAMEDOPS +#include "gc/Transforms/Passes.h.inc" + +#define DEBUG_TYPE "named-op-layout-propagation" + +using namespace mlir; +using namespace mlir::arith; +using namespace mlir::tensor; + +// insert pack when innerPosDims is non-empty +// insert linalg.transpose otherwise +static Value insertLayoutPack(RewriterBase &rewriter, Location loc, Value input, + Value dest, ArrayRef innerDimsPos, + ArrayRef innerTiles, + ArrayRef outerDimsPerm) { + if (!innerDimsPos.empty()) { + auto zeroAttr = rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType())); + Value zero = rewriter.create(loc, zeroAttr); + // TODO(yifei): correct the padding value here + return rewriter.create(loc, input, dest, innerDimsPos, + innerTiles, zero, outerDimsPerm); + } + if (!TensorLayout::isPlainOuterAxis(outerDimsPerm)) + return rewriter.create(loc, input, dest, outerDimsPerm) + .getResults()[0]; + return input; +} + +// insert unpack when innerPosDims is non-empty +// insert linalg.transpose otherwise +static Value insertLayoutUnpack(RewriterBase &rewriter, Location loc, + Value input, Value dest, + ArrayRef 
innerDimsPos, + ArrayRef innerTiles, + ArrayRef outerDimsPerm) { + if (!innerDimsPos.empty()) { + return rewriter.create(loc, input, dest, innerDimsPos, + innerTiles, outerDimsPerm); + } + if (!TensorLayout::isPlainOuterAxis(outerDimsPerm)) { + // inverse the permutationVector + SmallVector permAxes(outerDimsPerm.size()); + for (auto [idx, axis] : llvm::enumerate(outerDimsPerm)) { + permAxes[axis] = idx; + } + return rewriter.create(loc, input, dest, permAxes) + .getResults()[0]; + } + return input; +} + +static SmallVector getPackedAxes(ArrayRef dimensions, + const TensorLayout &targetLayout) { + SmallVector result; + // permuting on outer axis + auto outerPerm = targetLayout.getOuterAxis(); + for (int64_t dim : dimensions) { + auto pos = std::find(outerPerm.begin(), outerPerm.end(), dim); + assert(pos != outerPerm.end() && "dimension must be within output perm."); + result.push_back(std::distance(outerPerm.begin(), pos)); + } + // inserting inner axis + auto innerPos = targetLayout.getInnerAxis(); + for (size_t i = 0; i < dimensions.size(); ++i) { + if (std::find(innerPos.begin(), innerPos.end(), dimensions[i]) != + innerPos.end()) { + result.push_back(i + targetLayout.getOuterAxis().size()); + } + } + return result; +} + +static SmallVector getPackedPermAxes(ArrayRef plainPermAxes, + TensorLayout inputLayout, + TensorLayout outputLayout) { + // dim(result, i) = dim(input, permutation[i]) + // input: permutation[i] --> output: i + // input: permutation[i] --> packed input: std::find(permutation[i]) - begin() + // output: i --> packed output: std::find(permutation[i]) - begin() + int64_t packedRank = + outputLayout.getInnerAxis().size() + outputLayout.getOuterAxis().size(); + SmallVector result(packedRank, 0); + SmallVector inputCount(inputLayout.getOuterAxis().size(), 0); + auto axisPlainToPacked = inputLayout.getPlainToPackedAxisMapping(); + for (int64_t i = 0; i < packedRank; ++i) { + // packedOutput[i] --> originalOutputAxis --> originalInputAxis + int64_t originalOutputAxis = outputLayout.getPlainAxis(i); + int64_t originalInputAxis = plainPermAxes[originalOutputAxis]; + SmallVector packedInputAxes = axisPlainToPacked[originalInputAxis]; + result[i] = packedInputAxes[inputCount[originalInputAxis]++]; + } + return result; +} + +static int64_t applyPermutationAndReindexReassoc( + SmallVector &reassocIndices, + ArrayRef permutation) { + if (!permutation.empty()) + applyPermutationToVector(reassocIndices, permutation); + int64_t nextPos = 0; + for (ReassociationIndices &indices : reassocIndices) { + for (auto &index : indices) { + index = nextPos; + nextPos += 1; + } + } + return nextPos; +} + +LogicalResult packLinalgOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp, + const OperatorLayout &opLayout) { + LLVM_DEBUG(llvm::dbgs() << "Try packing named op " + << linalgOp.getOperation()->getName() << ".\n"); + Location loc = linalgOp->getLoc(); + SmallVector packOps; + SmallVector unPackOps; + SmallVector inputsAndInits, results; + SmallVector initOperands = llvm::to_vector(llvm::map_range( + linalgOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; })); + SmallVector inputOperands = linalgOp.getDpsInputOperands(); + SmallVector inputLayouts = opLayout.getSupportedInputLayouts(); + SmallVector initLayouts = opLayout.getSupportedOutputLayouts(); + // check all inputs and inits are tensor, otherwise no need for layout + // propagation + if (!gc::utils::hasAllTensorSemantics(linalgOp)) { + LLVM_DEBUG(llvm::dbgs() << "All inputs and outputs of linalg op: " + << 
linalgOp.getOperation()->getName() + << " shall be tensor. Skip layout packing.\n"); + return failure(); + } + for (const auto &operandsList : {inputOperands, initOperands}) { + for (OpOperand *opOperand : operandsList) { + size_t pos = opOperand->getOperandNumber(); + Value operand = opOperand->get(); + TensorLayout targetLayout = pos >= inputLayouts.size() + ? initLayouts[pos - inputLayouts.size()] + : inputLayouts[pos]; + SmallVector outerPerm = targetLayout.getOuterAxis(); + SmallVector innerPos = targetLayout.getInnerAxis(); + SmallVector innerPackSizes = targetLayout.getTileSizes(); + Value dest = tensor::PackOp::createDestinationTensor( + rewriter, loc, operand, innerPackSizes, innerPos, outerPerm); + ShapedType operandType = cast(operand.getType()); + bool areConstantTiles = + llvm::all_of(innerPackSizes, [](OpFoldResult tile) { + return getConstantIntValue(tile).has_value(); + }); + if (areConstantTiles && operandType.hasStaticShape()) { + // TODO(yifei): use masked operation or choose the correct padding value + // to ensure computation correctness + packOps.push_back(insertLayoutPack( + rewriter, loc, operand, dest, innerPos, innerPackSizes, outerPerm)); + } else { + LLVM_DEBUG( + llvm::dbgs() + << "Packing of linalg op " << linalgOp.getOperation()->getName() + << " failed due to non-constant tile sizes or dynamic shape.\n"); + return failure(); + } + inputsAndInits.push_back(packOps.back()); + } + } + + // Step 3. Build the packed op + ValueRange inputs = + ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs()); + ValueRange inits = + ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits()); + // TODO: deal with generic + linalg::LinalgOp packedLinalgOp; + if (auto reduceOp = dyn_cast(&linalgOp)) { + SmallVector packedAxes = + getPackedAxes(reduceOp->getDimensions(), inputLayouts[0]); + packedLinalgOp = rewriter.create( + loc, inits.getTypes(), inputs, inits, packedAxes); + packedLinalgOp->getRegion(0).takeBody(linalgOp->getRegion(0)); + } else if (auto broadcastOp = dyn_cast(&linalgOp)) { + SmallVector packedAxes = + getPackedAxes(broadcastOp->getDimensions(), initLayouts[0]); + packedLinalgOp = rewriter.create(loc, inputs[0], + inits[0], packedAxes); + } else if (auto transposeOp = dyn_cast(&linalgOp)) { + SmallVector packedPermAxes = getPackedPermAxes( + transposeOp->getPermutation(), inputLayouts[0], initLayouts[0]); + packedLinalgOp = rewriter.create( + loc, inputs[0], inits[0], packedPermAxes); + } else if (isa(linalgOp) || isa(linalgOp) || + isa(linalgOp) || isa(linalgOp)) { + return failure(); + } else { + packedLinalgOp = mlir::clone( + rewriter, linalgOp, SmallVector{inputsAndInits.back().getType()}, + inputsAndInits); + } + + // Step 4. Unpack all the op results. + for (OpResult result : packedLinalgOp->getResults()) { + int64_t resultNum = result.getResultNumber(); + assert(resultNum < static_cast(initLayouts.size()) && + "Linalg op results num exceeds inits num."); + // Build the symmetrical UnPackOp to the existing PackOp. + unPackOps.push_back(insertLayoutUnpack( + rewriter, packedLinalgOp->getLoc(), result, + initOperands[resultNum]->get(), initLayouts[resultNum].getInnerAxis(), + initLayouts[resultNum].getTileSizes(), + initLayouts[resultNum].getOuterAxis())); + results.push_back(unPackOps.back()); + } + + // Step 5. Replace `linalgOp`. 
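As an aside on Step 3 above, the dimension remapping that getPackedAxes performs can be seen in isolation with plain C++; the MKk layout below (outer order [0, 1], inner axis [1], tile 32) is an assumed example, not taken from a real config.

```cpp
// Standalone rendition of the getPackedAxes remapping: a reduction over dim 1
// (K) of an MxK operand packed as M x K/32 x 32 ends up reducing over the
// packed axes {1, 2}: the outer position of K plus its tile axis.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

static std::vector<int64_t> packedAxes(const std::vector<int64_t> &dims,
                                       const std::vector<int64_t> &outerPerm,
                                       const std::vector<int64_t> &innerPos) {
  std::vector<int64_t> result;
  for (int64_t d : dims) // position of each reduced dim in the outer order
    result.push_back(std::find(outerPerm.begin(), outerPerm.end(), d) -
                     outerPerm.begin());
  for (size_t i = 0; i < dims.size(); ++i) // matching inner tile axes
    if (std::find(innerPos.begin(), innerPos.end(), dims[i]) != innerPos.end())
      result.push_back(i + outerPerm.size());
  return result;
}

int main() {
  for (int64_t a : packedAxes({1}, {0, 1}, {1}))
    std::cout << a << " "; // prints: 1 2
  std::cout << "\n";
}
```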
+ rewriter.replaceOp(linalgOp, results); + return success(); +} + +// check whether non-contraction packable ops are already packed or not +static bool checkPacked(Operation *op, const OperatorLayout &opLayout) { + // check whether rank match + if (auto linalgOp = dyn_cast(op)) { + assert(linalgOp.getDpsInits().size() == + opLayout.getSupportedOutputLayouts().size() && + linalgOp.getDpsInputs().size() == + opLayout.getSupportedInputLayouts().size()); + for (auto [index, layout] : + llvm::enumerate(opLayout.getSupportedInputLayouts())) { + // if dimension mismatch, then the op itself is already packed + if (layout.getOuterAxis().size() != + cast(linalgOp.getDpsInputs()[index].getType()) + .getShape() + .size()) + return true; + } + for (auto [index, layout] : + llvm::enumerate(opLayout.getSupportedOutputLayouts())) { + // if dimension mismatch, then the op itself is already packed + if (layout.getOuterAxis().size() != + cast(linalgOp.getDpsInits()[index].getType()) + .getShape() + .size()) + return true; + } + } else { + assert(op->getNumOperands() == 1 && op->getNumResults() == 1); + } + return false; +} + +using ControlPackNamedOpsFn = + std::function(Operation *)>; + +class PropagateLayoutOnNamedOps + : public impl::PropagateLayoutOnNamedOpsBase { +public: + using impl::PropagateLayoutOnNamedOpsBase< + PropagateLayoutOnNamedOps>::PropagateLayoutOnNamedOpsBase; + void runOnOperation() final; +}; + +template +static void packReshapeOp(T reshapeOp, IRRewriter &rewriter, + const OperatorLayout &opLayout) { + Location loc = reshapeOp->getLoc(); + TensorLayout inputLayout = opLayout.getSupportedInputLayouts()[0]; + TensorLayout outputLayout = opLayout.getSupportedOutputLayouts()[0]; + Value curSrc = reshapeOp.getSrc(); + Value curDst = reshapeOp.getResult(); + Value dest = tensor::PackOp::createDestinationTensor( + rewriter, loc, curSrc, inputLayout.getTileSizes(), + inputLayout.getInnerAxis(), inputLayout.getOuterAxis()); + Value packedSource = + insertLayoutPack(rewriter, loc, curSrc, dest, inputLayout.getInnerAxis(), + inputLayout.getTileSizes(), inputLayout.getOuterAxis()); + SmallVector newReassocIndices = + reshapeOp.getReassociationIndices(); + TensorLayout shorterSide = inputLayout.getRank() > outputLayout.getRank() + ? outputLayout + : inputLayout; + int64_t nextPos = applyPermutationAndReindexReassoc( + newReassocIndices, shorterSide.getOuterAxis()); + // Then add direct mapping for the inner tile dims. 
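For reference, the reassociation handling used a few lines above (applyPermutationAndReindexReassoc) boils down to permuting the groups by the outer permutation and then renumbering the flattened indices sequentially; the values in this minimal standalone sketch are invented for illustration.

```cpp
// Sketch of permute-then-reindex on reassociation groups.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<std::vector<int64_t>> reassoc = {{0}, {1, 2}};
  std::vector<int64_t> perm = {1, 0};
  // Permute the groups by the outer permutation.
  std::vector<std::vector<int64_t>> permuted;
  for (int64_t p : perm)
    permuted.push_back(reassoc[p]);
  // Renumber the flattened indices sequentially.
  int64_t next = 0;
  for (auto &group : permuted)
    for (auto &idx : group)
      idx = next++;
  for (const auto &group : permuted) {
    std::cout << "[ ";
    for (int64_t idx : group)
      std::cout << idx << " ";
    std::cout << "] ";
  }
  std::cout << "\n"; // prints: [ 0 1 ] [ 2 ]
}
```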
+ for (size_t i = 0; i < inputLayout.getInnerAxis().size(); ++i) { + newReassocIndices.push_back({nextPos}); + nextPos += 1; + } + RankedTensorType newReshapeType = tensor::PackOp::inferPackedType( + dyn_cast(curDst.getType()), + *getConstantIntValues(outputLayout.getTileSizes()), + outputLayout.getInnerAxis(), outputLayout.getOuterAxis()); + Value packedReshapeShape = + rewriter.create(loc, newReshapeType, packedSource, newReassocIndices); + Value unpackDest = tensor::UnPackOp::createDestinationTensor( + rewriter, loc, packedReshapeShape, outputLayout.getTileSizes(), + outputLayout.getInnerAxis(), outputLayout.getOuterAxis()); + Value newUnPackOp = insertLayoutUnpack( + rewriter, loc, packedReshapeShape, unpackDest, + outputLayout.getInnerAxis(), outputLayout.getTileSizes(), + outputLayout.getOuterAxis()); + rewriter.replaceOp(reshapeOp, newUnPackOp); +} + +LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph, + ControlPackNamedOpsFn controlFn) { + IRRewriter rewriter(ctx); + graph->walk([&](Operation *op) { + if (mlir::gc::utils::isPackableOp(op)) { + LLVM_DEBUG(llvm::dbgs() << "Op " << op->getName() << " visited.\n"); + if (failed(controlFn(op))) { + LLVM_DEBUG(llvm::dbgs() << "Op " << op->getName() + << " does not have layout information.\n"); + return WalkResult::skip(); + } + OperatorLayout opLayout = *controlFn(op); + LLVM_DEBUG(llvm::dbgs() + << "Packing op " << op->getName() << " into inferred layout:\n" + << opLayout << "\n"); + if (opLayout.isPlain()) { + LLVM_DEBUG(llvm::dbgs() << "Op " << op->getName() + << " has plain layout, skip packing.\n"); + return WalkResult::advance(); + } + if (checkPacked(op, opLayout)) { + LLVM_DEBUG(llvm::dbgs() << "Op " << op->getName() + << " is already packed, skip packing.\n"); + return WalkResult::advance(); + } + // pack op into ideal layout + LLVM_DEBUG(llvm::dbgs() + << "Packing op " << op->getName() << " into inferred layout:\n" + << opLayout << "\n"); + // insert pack + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); + if (auto linalgOp = dyn_cast(op)) { + if (failed(packLinalgOp(rewriter, linalgOp, opLayout))) { + return WalkResult::skip(); + } + } else if (auto expandShapeOp = dyn_cast(op)) { + packReshapeOp(expandShapeOp, rewriter, opLayout); + } else if (auto collapseShapeOp = dyn_cast(op)) { + packReshapeOp(collapseShapeOp, rewriter, + opLayout); + } else if (auto padOp = dyn_cast(op)) { + Location loc = padOp->getLoc(); + TensorLayout inputLayout = opLayout.getSupportedInputLayouts()[0]; + Value curSrc = padOp.getSource(); + Value curDest = padOp.getResult(); + SmallVector outerDimsPerm = inputLayout.getOuterAxis(); + SmallVector innerDimsPos = inputLayout.getInnerAxis(); + SmallVector tileSizes = inputLayout.getTileSizes(); + Value dest = tensor::PackOp::createDestinationTensor( + rewriter, loc, curSrc, tileSizes, innerDimsPos, outerDimsPerm); + Value packedSource = + insertLayoutPack(rewriter, loc, curSrc, dest, innerDimsPos, + tileSizes, outerDimsPerm); + // update lowPad and highPad + SmallVector lowPad = padOp.getMixedLowPad(); + SmallVector highPad = padOp.getMixedHighPad(); + applyPermutationToVector(lowPad, outerDimsPerm); + applyPermutationToVector(highPad, outerDimsPerm); + lowPad.append(innerDimsPos.size(), rewriter.getIndexAttr(0)); + highPad.append(innerDimsPos.size(), rewriter.getIndexAttr(0)); + auto packedPadOp = rewriter.create( + loc, /*result=*/Type(), packedSource, lowPad, highPad, + padOp.getConstantPaddingValue(), padOp.getNofold()); + Value newUnPackOp = 
+ insertLayoutUnpack(rewriter, loc, packedPadOp, curDest, + innerDimsPos, tileSizes, outerDimsPerm); + rewriter.replaceOp(padOp, newUnPackOp); + } + } + return WalkResult::advance(); + }); + return success(); +} + +template +static LogicalResult packVNNIMMT4D(RewriterBase &rewriter, OpTy mmt4dOp) { + auto elementType = getElementTypeOrSelf(mmt4dOp.getInputs()[0].getType()); + if (!elementType.isBF16() && !elementType.isInteger(8)) + return rewriter.notifyMatchFailure(mmt4dOp, "require bf16/int8 data type"); + Location loc = mmt4dOp.getLoc(); + // BNKnk --> BNKkn2k + auto weightShape = + cast(mmt4dOp.getInputs()[1].getType()).getShape(); + int64_t weightRank = weightShape.size(); + // pack innermost k axis + SmallVector innerPos{weightRank - 1}; + int64_t blockingFactor = elementType.isBF16() ? 2 : 4; + SmallVector tileSize{rewriter.getIndexAttr(blockingFactor)}; + // BNKnk --> BNKkn2k + int64_t batchDimSize = weightRank - 4; + SmallVector batchPerm(batchDimSize, 0); + std::iota(batchPerm.begin(), batchPerm.end(), 0); + SmallVector outerPerm{batchDimSize, batchDimSize + 1, + batchDimSize + 3, batchDimSize + 2}; + outerPerm.insert(outerPerm.begin(), batchPerm.begin(), batchPerm.end()); + OpOperand *RHSOperand = mmt4dOp.getDpsInputOperand(1); + Value dest = tensor::PackOp::createDestinationTensor( + rewriter, loc, RHSOperand->get(), tileSize, innerPos, outerPerm); + auto zeroAttr = rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType())); + Value zero = rewriter.create(loc, zeroAttr); + Value VNNIPack = rewriter.create( + loc, RHSOperand->get(), dest, innerPos, tileSize, zero, outerPerm); + // check whether VNNIPack causes padding + int64_t innermostKDim = weightShape[weightRank - 1]; + int64_t paddingSize = (innermostKDim % blockingFactor) + ? (blockingFactor - innermostKDim % blockingFactor) + : 0; + assert(!paddingSize && "Padding shall not be introduced by VNNI pack."); + SmallVector inputsValues{mmt4dOp.getInputs()[0], VNNIPack}; + FailureOr op = linalgx::makeGenericPackedMatmulOp( + rewriter, loc, linalgx::PackingType::VNNI_MM4D, inputsValues, + mmt4dOp.getDpsInits()); + if (failed(op)) + return failure(); + rewriter.replaceOp(mmt4dOp, *op); + return success(); +} + +// strictly check whether the packed matmul is BMKmk & BNKkn +static bool isMM4DMatmul(linalg::GenericOp matmulOp) { + return linalgx::isGenericPackedMatmulOp(matmulOp.getOperation(), + linalgx::PackingType::MM4D); +} + +/* +If possible, pack to Mm2DVnniOp or Mm4DVnniOp. +If not possible, pack to GenericOp. 
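+
+Illustration (shapes below are hypothetical, not taken from the pass): for a
+bf16 weight already packed as NKkn, e.g. tensor<8x16x32x32xbf16>, the VNNI
+pack splits the innermost k axis by a blocking factor of 2 (4 for int8),
+roughly
+  tensor.pack ... inner_dims_pos = [2] inner_tiles = [2]
+      : tensor<8x16x32x32xbf16> -> tensor<8x16x16x32x2xbf16>
+and the contraction is then rebuilt as a generic op with VNNI_MM4D semantics.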
+*/ +static LogicalResult packVNNIGeneric(RewriterBase &rewriter, + linalg::GenericOp matmulOp) { + if (matmulOp.getDpsInputs().size() != 2) + return rewriter.notifyMatchFailure(matmulOp, "require 2 inputs"); + + auto elementType = getElementTypeOrSelf(matmulOp.getInputs()[0].getType()); + if (!elementType.isBF16() && !elementType.isInteger(8)) + return rewriter.notifyMatchFailure(matmulOp, "require bf16/int8 data type"); + + if (matmulOp.hasDynamicShape()) + return rewriter.notifyMatchFailure(matmulOp, "require static shape"); + + if (matmulOp.hasPureBufferSemantics()) + return rewriter.notifyMatchFailure(matmulOp, "require tensor semantics"); + + if (!mlir::linalg::isaContractionOpInterface(matmulOp)) + return rewriter.notifyMatchFailure(matmulOp, "require matmul semantics"); + + // check whether generic op is packed as BMKmk & BNKkn + if (!isMM4DMatmul(matmulOp)) + return rewriter.notifyMatchFailure(matmulOp, + "require packed MM4D matmul semantics"); + + OpOperand &weight = matmulOp->getOpOperand(1); + // TODO(yifei): check ISA feasibility + Location loc = matmulOp.getLoc(); + int64_t blockingFactor = elementType.isBF16() ? 2 : 4; + SmallVector tileSize{rewriter.getIndexAttr(blockingFactor)}; + // BNKkn, get weight's rank + auto weightShape = + cast(matmulOp.getInputs()[1].getType()).getShape(); + int64_t weightRank = weightShape.size(); + auto innerPos = SmallVector{weightRank - 2}; + // pack weight + Value dest = tensor::PackOp::createDestinationTensor( + rewriter, loc, weight.get(), tileSize, innerPos, SmallVector{}); + auto zeroAttr = rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType())); + Value zero = rewriter.create(loc, zeroAttr); + Value VNNIPack = rewriter.create(loc, weight.get(), dest, + innerPos, tileSize, zero); + + SmallVector inputsValues{matmulOp.getInputs()[0], VNNIPack}; + // check whether VNNIPack causes padding, weightShape is BNKkn + int64_t innermostKDim = weightShape[weightRank - 2]; + int64_t paddingSize = (innermostKDim % blockingFactor) + ? 
                            (blockingFactor - innermostKDim % blockingFactor) : 0;
+  assert(!paddingSize && "Padding shall not be introduced by VNNI pack.");
+  FailureOr<linalg::GenericOp> op = linalgx::makeGenericPackedMatmulOp(
+      rewriter, loc, linalgx::PackingType::VNNI_MM4D, inputsValues,
+      matmulOp.getDpsInits());
+  if (failed(op))
+    return failure();
+  rewriter.replaceOp(matmulOp, *op);
+  return success();
+}
+
+template <typename OpTy> struct PackVNNI : public OpRewritePattern<OpTy> {
+  PackVNNI(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<OpTy>(context, benefit) {}
+
+  LogicalResult matchAndRewrite(OpTy linalgOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(packVNNIMMT4D(rewriter, linalgOp)))
+      return failure();
+    return success();
+  }
+};
+
+template <>
+struct PackVNNI<linalg::GenericOp>
+    : public OpRewritePattern<linalg::GenericOp> {
+  PackVNNI(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<linalg::GenericOp>(context, benefit) {}
+  LogicalResult matchAndRewrite(linalg::GenericOp matmulOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(packVNNIGeneric(rewriter, matmulOp)))
+      return failure();
+    return success();
+  }
+};
+
+static linalgx::PackingType revertToPackingType(linalg::GenericOp matmulOp) {
+  if (linalgx::isGenericPackedMatmulOp(matmulOp.getOperation(),
+                                       linalgx::PackingType::MM4D))
+    return linalgx::PackingType::MM2D4D;
+  if (linalgx::isGenericPackedMatmulOp(matmulOp.getOperation(),
+                                       linalgx::PackingType::VNNI_MM4D))
+    return linalgx::PackingType::VNNI_MM2D;
+  llvm_unreachable("Unexpected generic op encountered in matmul reversion stage.");
+}
+
+// A matmul has a plain activation if neither its input nor its output layout
+// is blocking; only the weight side is packed in that case.
+static bool isPlainActivationMatmul(const OperatorLayout &matmulLayout) {
+  auto inputLayout = matmulLayout.getSupportedInputLayouts()[0];
+  auto outputLayout = matmulLayout.getSupportedOutputLayouts()[0];
+  return !inputLayout.isBlocking() && !outputLayout.isBlocking();
+}
+
+static LogicalResult
+revertMatmulPacking(MLIRContext *ctx, mlir::Operation *graph,
+                    const std::vector<OperatorLayout> &matmulLayouts) {
+  IRRewriter rewriter(ctx);
+  uint64_t layoutIndex = 0;
+  auto result = graph->walk([&](Operation *op) {
+    if (auto matmulOp = dyn_cast<linalg::GenericOp>(op)) {
+      if (linalgx::isGenericPackedMatmulOp(matmulOp.getOperation(),
+                                           linalgx::PackingType::MM4D,
+                                           linalgx::PackingType::VNNI_MM4D)) {
+        if (isPlainActivationMatmul(matmulLayouts[layoutIndex])) {
+          linalgx::PackingType revertType = revertToPackingType(matmulOp);
+          OpBuilder::InsertionGuard guard(rewriter);
+          rewriter.setInsertionPoint(op);
+          // replace matmul 4D with unpack + matmul 2D + pack
+          auto packInputOp = matmulOp.getDpsInputOperand(0)
+                                 ->get()
+                                 .getDefiningOp<tensor::PackOp>();
+          auto packInitOp = matmulOp.getDpsInitOperand(0)
+                                ->get()
+                                .getDefiningOp<tensor::PackOp>();
+          if (!packInputOp || !packInitOp)
+            return WalkResult::skip();
+          if (!matmulOp.getResults()[0].hasOneUse())
+            return WalkResult::skip();
+          auto consumer = matmulOp.getResults()[0].getUses().begin();
+          auto unPackOp = dyn_cast<tensor::UnPackOp>(consumer->getOwner());
+          if (!unPackOp)
+            return WalkResult::skip();
+          Location loc = matmulOp.getLoc();
+          // unpack input
+          auto packInputInnerTiles = packInputOp.getMixedTiles();
+          auto packInputInnerDimsPos = packInputOp.getInnerDimsPos();
+          auto packInputOuterDimsPerm = packInputOp.getOuterDimsPerm();
+          llvm::SmallVector<int64_t> unpackInputInnerDimsPos(
+              packInputInnerDimsPos);
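+          // Illustrative example (values are hypothetical): if the producer
+          // pack used outer_dims_perm = [1, 0] and inner_dims_pos = [1], the
+          // unpack built below carries no outer permutation, so plain axis 1
+          // now lives at position 0 of the unpacked tensor and inner_dims_pos
+          // must be remapped from [1] to [0].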
+          // eliminate the transpose semantic in unpack
+          llvm::SmallDenseMap<int64_t, int64_t> axisMapping;
+          if (!packInputOuterDimsPerm.empty()) {
+            for (auto [index, axis] : llvm::enumerate(packInputOuterDimsPerm)) {
+              axisMapping[axis] = index;
+            }
+            for (size_t i = 0; i < unpackInputInnerDimsPos.size(); ++i) {
+              unpackInputInnerDimsPos[i] =
+                  axisMapping[unpackInputInnerDimsPos[i]];
+            }
+          }
+          Value unpackInputDest = tensor::UnPackOp::createDestinationTensor(
+              rewriter, loc, packInputOp, packInputInnerTiles,
+              unpackInputInnerDimsPos, ArrayRef<int64_t>{});
+          Value reUnpackInput =
+              insertLayoutUnpack(rewriter, loc, packInputOp, unpackInputDest,
+                                 unpackInputInnerDimsPos, packInputInnerTiles,
+                                 ArrayRef<int64_t>{});
+          // unpack init
+          auto packInitInnerTiles = packInitOp.getMixedTiles();
+          auto packInitInnerDimsPos = packInitOp.getInnerDimsPos();
+          auto packInitOuterDimsPerm = packInitOp.getOuterDimsPerm();
+          // assert packInitOuterDimsPerm is not permuted
+          if (!packInitOuterDimsPerm.empty()) {
+            for (auto [index, dim] : llvm::enumerate(packInitOuterDimsPerm)) {
+              if (static_cast<int64_t>(index) != dim)
+                assert(false && "Packed matmul's init pack shall not contain "
+                                "permutation semantics.");
+            }
+          }
+          Value unpackInitDest = tensor::UnPackOp::createDestinationTensor(
+              rewriter, loc, packInitOp, packInitInnerTiles,
+              packInitInnerDimsPos, packInitOuterDimsPerm);
+          Value reUnpackInit = insertLayoutUnpack(
+              rewriter, loc, packInitOp, unpackInitDest, packInitInnerDimsPos,
+              packInitInnerTiles, ArrayRef<int64_t>{});
+          // replace matmul 4D with matmul 2D
+          auto matmul2D = linalgx::makeGenericPackedMatmulOp(
+              rewriter, loc, revertType,
+              ValueRange{reUnpackInput, matmulOp.getDpsInputOperand(1)->get()},
+              ValueRange{reUnpackInit});
+          if (failed(matmul2D))
+            return WalkResult::interrupt();
+          // insert pack before unpack
+          auto unPackInnerTiles = unPackOp.getMixedTiles();
+          auto unPackInnerDimsPos = unPackOp.getInnerDimsPos();
+          auto unPackOuterDimsPerm = unPackOp.getOuterDimsPerm();
+          Value packDest = tensor::PackOp::createDestinationTensor(
+              rewriter, loc, (*matmul2D)->getResult(0), unPackInnerTiles,
+              unPackInnerDimsPos, unPackOuterDimsPerm);
+          auto zeroAttr =
+              rewriter.getZeroAttr(getElementTypeOrSelf(packDest.getType()));
+          Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
+          Value rePack = rewriter.create<tensor::PackOp>(
+              loc, (*matmul2D)->getResult(0), packDest, unPackInnerDimsPos,
+              unPackInnerTiles, zero, unPackOuterDimsPerm);
+          rewriter.replaceOp(op, rePack);
+        }
+        layoutIndex++;
+      }
+    } else if (auto matmulOp = dyn_cast<linalg::LinalgOp>(op)) {
+      if (mlir::gc::utils::isSupportedContractionNamedOp(matmulOp)) {
+        layoutIndex++;
+      }
+    }
+    return WalkResult::advance();
+  });
+  if (result.wasInterrupted() || result.wasSkipped())
+    return failure(); // reversion not performed as expected
+  if (layoutIndex != matmulLayouts.size())
+    return failure(); // layout index mismatch, reversion failed
+  return success();
+}
+
+/*
+Match patterns like broadcast + pack, uplift pack
+*/
+struct UpliftPackOverBroadcast : public OpRewritePattern<tensor::PackOp> {
+  UpliftPackOverBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::PackOp>(context, benefit) {}
+  LogicalResult matchAndRewrite(tensor::PackOp pack,
+                                PatternRewriter &rewriter) const override {
+    auto broadcastOp = pack.getSource().getDefiningOp<linalg::BroadcastOp>();
+    if (!broadcastOp || !broadcastOp.getResult()[0].hasOneUse()) {
+      return failure();
+    }
+    SmallVector<int64_t> innerTileSizes = pack.getStaticTiles();
+    SmallVector<int64_t> innerDimsPos(pack.getInnerDimsPos());
+    SmallVector<int64_t> outerDimsPerm(pack.getOuterDimsPerm());
+    int64_t rank =
+        cast<RankedTensorType>(pack.getSource().getType()).getShape().size();
+    if (outerDimsPerm.empty()) {
+      outerDimsPerm.resize(rank);
+      std::iota(outerDimsPerm.begin(), outerDimsPerm.end(), 0);
+    }
+    ArrayRef<int64_t> broadcastAxis = broadcastOp.getDimensions();
+    SmallVector<int64_t> newInnerDimsPos, newOuterDimsPerm, packedBroadcastAxis;
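+    // Sketch of the rewrite (shapes are hypothetical): for
+    //   %b = linalg.broadcast ins(%x : tensor<64xf32>)
+    //                         outs(... : tensor<128x64xf32>) dimensions = [0]
+    //   %p = tensor.pack %b inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+    //        : tensor<128x64xf32> -> tensor<4x4x32x16xf32>
+    // the pack is uplifted to operate on %x with inner_dims_pos = [0] and
+    // inner_tiles = [16] (tensor<64xf32> -> tensor<4x16xf32>), and the
+    // broadcast is re-created on the packed domain with dimensions = [0, 2].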
SmallVector newInnerTileSizes; + llvm::SmallDenseMap axisMapping; + int64_t axisCounter = 0; + for (int64_t axis = 0; axis < rank; ++axis) { + if (std::find(broadcastAxis.begin(), broadcastAxis.end(), axis) == + broadcastAxis.end()) { + // if the axis is not broadcasted, keep it + axisMapping[axis] = axisCounter++; + } + } + // update broadcast dims + for (auto [index, axis] : llvm::enumerate(outerDimsPerm)) { + if (std::find(broadcastAxis.begin(), broadcastAxis.end(), axis) != + broadcastAxis.end()) { + packedBroadcastAxis.push_back(index); + } + } + for (auto [index, axis] : llvm::enumerate(innerDimsPos)) { + if (std::find(broadcastAxis.begin(), broadcastAxis.end(), axis) != + broadcastAxis.end()) { + packedBroadcastAxis.push_back(index + rank); + } + } + // update packing axis + for (auto [index, axis] : llvm::enumerate(outerDimsPerm)) { + if (std::find(broadcastAxis.begin(), broadcastAxis.end(), axis) == + broadcastAxis.end()) { + newOuterDimsPerm.push_back(axisMapping[axis]); + } + } + for (auto [index, axis] : llvm::enumerate(innerDimsPos)) { + if (std::find(broadcastAxis.begin(), broadcastAxis.end(), axis) == + broadcastAxis.end()) { + newInnerDimsPos.push_back(axisMapping[axis]); + newInnerTileSizes.push_back( + rewriter.getIndexAttr(innerTileSizes[index])); + } + } + // replace ops + auto loc = broadcastOp.getLoc(); + auto dest = tensor::PackOp::createDestinationTensor( + rewriter, loc, broadcastOp.getDpsInputs()[0], newInnerTileSizes, + newInnerDimsPos, newOuterDimsPerm); + Value packedSource = + insertLayoutPack(rewriter, loc, broadcastOp.getDpsInputs()[0], dest, + newInnerDimsPos, newInnerTileSizes, newOuterDimsPerm); + auto newBroadcastOp = rewriter.create( + loc, packedSource, pack.getDest(), packedBroadcastAxis); + rewriter.replaceOp(pack, newBroadcastOp.getResults()); + return success(); + } +}; + +void PropagateLayoutOnNamedOps::runOnOperation() { + MLIRContext *ctx = &getContext(); + IRRewriter rewriter(ctx); + mlir::Operation *graph = getOperation(); + auto &layoutAnalysisResult = getAnalysis(); + + // pre-collect matmul layouts + std::vector matmulLayouts; + graph->walk([&](Operation *op) { + if (auto linalgOp = dyn_cast(op)) { + if (mlir::gc::utils::isSupportedContractionNamedOp(linalgOp)) { + matmulLayouts.push_back(*(layoutAnalysisResult.getOpLayout(op))); + } + } + return WalkResult::advance(); + }); + + // stage 1.1: pack matmul with `BlockPackMatmulPatterns` if any side of the + // matmul op requires packing + RewritePatternSet packMatmulPatterns(&getContext()); + mlir::linalg::ControlBlockPackMatmulFn packMatmulControlFn = + [&](linalg::LinalgOp op) -> mlir::linalg::BlockPackMatmulOptions { + mlir::linalg::BlockPackMatmulOptions options; + FailureOr matmulLayout = + layoutAnalysisResult.getOpLayout(op); + if (failed(matmulLayout)) + return options; // return default options to skip packing + TensorLayout inputLayout = matmulLayout->getSupportedInputLayouts()[0]; + TensorLayout weightLayout = matmulLayout->getSupportedInputLayouts()[1]; + TensorLayout outputLayout = matmulLayout->getSupportedOutputLayouts()[0]; + if (!inputLayout.isBlocking() && !weightLayout.isBlocking() && + !outputLayout.isBlocking()) + return options; // return default options to skip packing + // specify B side as be NKkn + options.rhsTransposeOuterBlocks = true; + options.rhsTransposeInnerBlocks = false; + // extract tile sizes + auto matmulCfg = MatmulConfigAnalysis(op.getOperation()).getConfig(); + OpFoldResult MBlock = rewriter.getIndexAttr(matmulCfg.innerMostMBlock), + KBlock = 
rewriter.getIndexAttr(matmulCfg.innerMostKBlock), + NBlock = rewriter.getIndexAttr(matmulCfg.innerMostNBlock); + options.blockFactors = SmallVector{ + *getConstantIntValue(MBlock), *getConstantIntValue(NBlock), + *getConstantIntValue(KBlock)}; + return options; + }; + linalg::populateBlockPackMatmulPatterns(packMatmulPatterns, + packMatmulControlFn); + if (failed( + applyPatternsAndFoldGreedily(graph, std::move(packMatmulPatterns)))) + return signalPassFailure(); + + // stage 1.2: pack VNNI + RewritePatternSet packVNNIPatterns(&getContext()); + packVNNIPatterns.add, PackVNNI, + PackVNNI>(ctx); + if (failed(applyPatternsAndFoldGreedily(graph, std::move(packVNNIPatterns)))) + return signalPassFailure(); + + // stage 1.3: revert packed matmul from blocking activation to plain + // activation + if (failed(revertMatmulPacking(ctx, graph, matmulLayouts))) + return signalPassFailure(); + + // stage 2: propagate layout on other named ops + ControlPackNamedOpsFn layoutControlFn = + [&](Operation *op) -> FailureOr { + return layoutAnalysisResult.getOpLayout(op); + }; + if (failed(namedOpLayoutPropagation(ctx, graph, layoutControlFn))) + return signalPassFailure(); + + // stage 3: uplift pack through broadcast + RewritePatternSet upliftPatterns(&getContext()); + upliftPatterns.add(ctx); + if (failed(applyPatternsAndFoldGreedily(graph, std::move(upliftPatterns)))) + return signalPassFailure(); +} + +} // namespace gc +} // namespace mlir diff --git a/test/mlir/test/gc/Transforms/pack-llama-mlp.mlir b/test/mlir/test/gc/Transforms/pack-llama-mlp.mlir new file mode 100644 index 000000000..c4e5ce89f --- /dev/null +++ b/test/mlir/test/gc/Transforms/pack-llama-mlp.mlir @@ -0,0 +1,122 @@ +// RUN: gc-opt %s --split-input-file --propagate-layout-on-named-ops --post-process-pack-unpack | FileCheck %s + +// ----- + +// CHECK-LABEL: @llama2_mlp +func.func @llama2_mlp(%arg0: tensor<1x32x4096xbf16>, %arg1: tensor<4096x4096xbf16>, %arg2: tensor<1x32x4096xbf16>, %arg3: tensor<1xf32>, %arg4: tensor<4096xbf16>, %arg5: tensor<11008x4096xbf16>, %arg6: tensor<11008x4096xbf16>, %arg7: tensor<4096x11008xbf16>, %arg8: tensor<1xf32>, %arg9: tensor<4096xbf16>) -> (tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) { + %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<1x32x4096xbf16> into tensor<32x4096xbf16> + %cst = arith.constant 0.000000e+00 : bf16 + %cst_0 = arith.constant 1.000000e+00 : bf16 + %0 = tensor.empty() : tensor<32x4096xbf16> + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %2 = linalg.matmul_transpose_b ins(%collapsed, %arg1 : tensor<32x4096xbf16>, tensor<4096x4096xbf16>) outs(%1 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %expanded = tensor.expand_shape %2 [[0, 1], [2]] output_shape [1, 32, 4096] : tensor<32x4096xbf16> into tensor<1x32x4096xbf16> + %3 = tensor.empty() : tensor<1x32x4096xbf16> + %4 = linalg.add ins(%arg2, %expanded : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%3 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %5 = tensor.empty() : tensor<1x32x4096xf32> + %6 = linalg.copy ins(%4 : tensor<1x32x4096xbf16>) outs(%5 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %cst_1 = arith.constant dense<2.000000e+00> : tensor<1x32x4096xf32> + %7 = tensor.empty() : tensor<1x32x4096xf32> + %8 = linalg.powf ins(%6, %cst_1 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%7 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %cst_2 = arith.constant 0.000000e+00 : f32 + %9 = tensor.empty() : tensor<1x32xf32> + %10 = linalg.fill 
ins(%cst_2 : f32) outs(%9 : tensor<1x32xf32>) -> tensor<1x32xf32> + %reduced = linalg.reduce ins(%8 : tensor<1x32x4096xf32>) outs(%10 : tensor<1x32xf32>) dimensions = [2] + (%in: f32, %init: f32) { + %72 = arith.addf %in, %init : f32 + linalg.yield %72 : f32 + } + %cst_3 = arith.constant dense<4.096000e+03> : tensor<1x32xf32> + %11 = tensor.empty() : tensor<1x32xf32> + %12 = linalg.div ins(%reduced, %cst_3 : tensor<1x32xf32>, tensor<1x32xf32>) outs(%11 : tensor<1x32xf32>) -> tensor<1x32xf32> + %expanded_4 = tensor.expand_shape %12 [[0], [1, 2]] output_shape [1, 32, 1] : tensor<1x32xf32> into tensor<1x32x1xf32> + %13 = tensor.empty() : tensor<1x32x1xf32> + %broadcasted = linalg.broadcast ins(%arg8 : tensor<1xf32>) outs(%13 : tensor<1x32x1xf32>) dimensions = [0, 1] + %14 = tensor.empty() : tensor<1x32x1xf32> + %15 = linalg.add ins(%expanded_4, %broadcasted : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%14 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32> + %cst_5 = arith.constant dense<-5.000000e-01> : tensor<1x32x1xf32> + %16 = tensor.empty() : tensor<1x32x1xf32> + %17 = linalg.powf ins(%15, %cst_5 : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%16 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32> + %collapsed_6 = tensor.collapse_shape %17 [[0], [1, 2]] : tensor<1x32x1xf32> into tensor<1x32xf32> + %18 = tensor.empty() : tensor<1x32x4096xf32> + %broadcasted_7 = linalg.broadcast ins(%collapsed_6 : tensor<1x32xf32>) outs(%18 : tensor<1x32x4096xf32>) dimensions = [2] + %19 = tensor.empty() : tensor<1x32x4096xf32> + %20 = linalg.mul ins(%6, %broadcasted_7 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%19 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %21 = tensor.empty() : tensor<1x32x4096xbf16> + %22 = linalg.copy ins(%20 : tensor<1x32x4096xf32>) outs(%21 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %23 = tensor.empty() : tensor<1x32x4096xbf16> + %broadcasted_8 = linalg.broadcast ins(%arg4 : tensor<4096xbf16>) outs(%23 : tensor<1x32x4096xbf16>) dimensions = [0, 1] + %24 = tensor.empty() : tensor<1x32x4096xbf16> + %25 = linalg.mul ins(%broadcasted_8, %22 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%24 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %collapsed_9 = tensor.collapse_shape %25 [[0, 1], [2]] : tensor<1x32x4096xbf16> into tensor<32x4096xbf16> + %cst_10 = arith.constant 0.000000e+00 : bf16 + %26 = tensor.empty() : tensor<32x11008xbf16> + %27 = linalg.fill ins(%cst_10 : bf16) outs(%26 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %28 = linalg.matmul_transpose_b ins(%collapsed_9, %arg5 : tensor<32x4096xbf16>, tensor<11008x4096xbf16>) outs(%27 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %expanded_11 = tensor.expand_shape %28 [[0, 1], [2]] output_shape [1, 32, 11008] : tensor<32x11008xbf16> into tensor<1x32x11008xbf16> + %29 = tensor.empty() : tensor<1x32x11008xbf16> + %30 = linalg.negf ins(%expanded_11 : tensor<1x32x11008xbf16>) outs(%29 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %31 = tensor.empty() : tensor<1x32x11008xbf16> + %32 = linalg.exp ins(%30 : tensor<1x32x11008xbf16>) outs(%31 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %33 = tensor.empty() : tensor<1x32x11008xbf16> + %34 = tensor.empty() : tensor<1x32x11008xbf16> + %35 = linalg.fill ins(%cst_0 : bf16) outs(%34 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %36 = linalg.add ins(%32, %35 : tensor<1x32x11008xbf16>, tensor<1x32x11008xbf16>) outs(%33 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %37 = tensor.empty() : tensor<1x32x11008xbf16> + %38 = 
linalg.reciprocal ins(%36 : tensor<1x32x11008xbf16>) outs(%37 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %39 = tensor.empty() : tensor<1x32x11008xbf16> + %40 = linalg.mul ins(%30, %expanded_11 : tensor<1x32x11008xbf16>, tensor<1x32x11008xbf16>) outs(%39 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %collapsed_12 = tensor.collapse_shape %25 [[0, 1], [2]] : tensor<1x32x4096xbf16> into tensor<32x4096xbf16> + %cst_13 = arith.constant 0.000000e+00 : bf16 + %41 = tensor.empty() : tensor<32x11008xbf16> + %42 = linalg.fill ins(%cst_13 : bf16) outs(%41 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %43 = linalg.matmul_transpose_b ins(%collapsed_12, %arg6 : tensor<32x4096xbf16>, tensor<11008x4096xbf16>) outs(%42 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %expanded_14 = tensor.expand_shape %43 [[0, 1], [2]] output_shape [1, 32, 11008] : tensor<32x11008xbf16> into tensor<1x32x11008xbf16> + %44 = tensor.empty() : tensor<1x32x11008xbf16> + %45 = linalg.mul ins(%40, %expanded_14 : tensor<1x32x11008xbf16>, tensor<1x32x11008xbf16>) outs(%44 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %collapsed_15 = tensor.collapse_shape %45 [[0, 1], [2]] : tensor<1x32x11008xbf16> into tensor<32x11008xbf16> + %cst_16 = arith.constant 0.000000e+00 : bf16 + %46 = tensor.empty() : tensor<32x4096xbf16> + %47 = linalg.fill ins(%cst_16 : bf16) outs(%46 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %48 = linalg.matmul_transpose_b ins(%collapsed_15, %arg7 : tensor<32x11008xbf16>, tensor<4096x11008xbf16>) outs(%47 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %expanded_17 = tensor.expand_shape %48 [[0, 1], [2]] output_shape [1, 32, 4096] : tensor<32x4096xbf16> into tensor<1x32x4096xbf16> + %49 = tensor.empty() : tensor<1x32x4096xbf16> + %50 = linalg.add ins(%4, %expanded_17 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%49 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %51 = tensor.empty() : tensor<1x32x4096xf32> + %52 = linalg.copy ins(%50 : tensor<1x32x4096xbf16>) outs(%51 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %cst_18 = arith.constant dense<2.000000e+00> : tensor<1x32x4096xf32> + %53 = tensor.empty() : tensor<1x32x4096xf32> + %54 = linalg.powf ins(%52, %cst_18 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%53 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %cst_19 = arith.constant 0.000000e+00 : f32 + %55 = tensor.empty() : tensor<1x32xf32> + %56 = linalg.fill ins(%cst_19 : f32) outs(%55 : tensor<1x32xf32>) -> tensor<1x32xf32> + %reduced_20 = linalg.reduce ins(%54 : tensor<1x32x4096xf32>) outs(%56 : tensor<1x32xf32>) dimensions = [2] + (%in: f32, %init: f32) { + %72 = arith.addf %in, %init : f32 + linalg.yield %72 : f32 + } + %cst_21 = arith.constant dense<4.096000e+03> : tensor<1x32xf32> + %57 = tensor.empty() : tensor<1x32xf32> + %58 = linalg.div ins(%reduced_20, %cst_21 : tensor<1x32xf32>, tensor<1x32xf32>) outs(%57 : tensor<1x32xf32>) -> tensor<1x32xf32> + %expanded_22 = tensor.expand_shape %58 [[0], [1, 2]] output_shape [1, 32, 1] : tensor<1x32xf32> into tensor<1x32x1xf32> + %59 = tensor.empty() : tensor<1x32x1xf32> + %broadcasted_23 = linalg.broadcast ins(%arg8 : tensor<1xf32>) outs(%59 : tensor<1x32x1xf32>) dimensions = [0, 1] + %60 = tensor.empty() : tensor<1x32x1xf32> + %61 = linalg.add ins(%expanded_22, %broadcasted_23 : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%60 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32> + %cst_24 = arith.constant dense<-5.000000e-01> : tensor<1x32x1xf32> + %62 = tensor.empty() : tensor<1x32x1xf32> 
+ %63 = linalg.powf ins(%61, %cst_24 : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%62 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32> + %collapsed_25 = tensor.collapse_shape %63 [[0], [1, 2]] : tensor<1x32x1xf32> into tensor<1x32xf32> + %64 = tensor.empty() : tensor<1x32x4096xf32> + %broadcasted_26 = linalg.broadcast ins(%collapsed_25 : tensor<1x32xf32>) outs(%64 : tensor<1x32x4096xf32>) dimensions = [2] + %65 = tensor.empty() : tensor<1x32x4096xf32> + %66 = linalg.mul ins(%52, %broadcasted_26 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%65 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %67 = tensor.empty() : tensor<1x32x4096xbf16> + %68 = linalg.copy ins(%66 : tensor<1x32x4096xf32>) outs(%67 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %69 = tensor.empty() : tensor<1x32x4096xbf16> + %broadcasted_27 = linalg.broadcast ins(%arg9 : tensor<4096xbf16>) outs(%69 : tensor<1x32x4096xbf16>) dimensions = [0, 1] + %70 = tensor.empty() : tensor<1x32x4096xbf16> + %71 = linalg.mul ins(%broadcasted_27, %68 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%70 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + return %71, %50 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16> +} +// CHECK-COUNT-8: tensor.pack diff --git a/test/mlir/test/gc/Transforms/pack-matmul.mlir b/test/mlir/test/gc/Transforms/pack-matmul.mlir new file mode 100644 index 000000000..de13e0a6a --- /dev/null +++ b/test/mlir/test/gc/Transforms/pack-matmul.mlir @@ -0,0 +1,58 @@ +// RUN: gc-opt %s --split-input-file --propagate-layout-on-named-ops --post-process-pack-unpack | FileCheck %s + +// ----- + +// CHECK-LABEL: @matmul_add_plain_activation_f32 +func.func @matmul_add_plain_activation_f32(%arg0: tensor<128x64xf32>, %arg1: tensor<64x64xf32>, %arg2: tensor<64xf32>) -> tensor<128x64xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<128x64xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128x64xf32>) -> tensor<128x64xf32> + %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x64xf32>, tensor<64x64xf32>) outs(%0 : tensor<128x64xf32>) -> tensor<128x64xf32> + %3 = tensor.empty() : tensor<128x64xf32> + %broadcasted = linalg.broadcast ins(%arg2 : tensor<64xf32>) outs(%3 : tensor<128x64xf32>) dimensions = [0] + %4 = tensor.empty() : tensor<128x64xf32> + %5 = linalg.add ins(%2, %broadcasted : tensor<128x64xf32>, tensor<128x64xf32>) outs(%4 : tensor<128x64xf32>) -> tensor<128x64xf32> + return %5 : tensor<128x64xf32> +} +// CHECK-COUNT-1: tensor.pack +// CHECK-COUNT-1: linalg.generic +// CHECK: linalg.add ins(%{{.*}}, %{{.*}} : tensor<{{.*}}x{{.*}}xf32>, tensor<{{.*}}x{{.*}}xf32>) outs(%{{.*}} : tensor<{{.*}}x{{.*}}xf32>) -> tensor<{{.*}}x{{.*}}xf32> +// CHECK-NOT: tensor.unpack + +// ----- + +// CHECK-LABEL: @matmul_add_blocking_activation_f32 +func.func @matmul_add_blocking_activation_f32(%arg0: tensor<128x511xf32>, %arg1: tensor<511x255xf32>, %arg2: tensor<255xf32>) -> tensor<128x255xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<128x255xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128x255xf32>) -> tensor<128x255xf32> + %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x511xf32>, tensor<511x255xf32>) outs(%0 : tensor<128x255xf32>) -> tensor<128x255xf32> + %3 = tensor.empty() : tensor<128x255xf32> + %broadcasted = linalg.broadcast ins(%arg2 : tensor<255xf32>) outs(%3 : tensor<128x255xf32>) dimensions = [0] + %4 = tensor.empty() : tensor<128x255xf32> + %5 = linalg.add ins(%2, %broadcasted : tensor<128x255xf32>, tensor<128x255xf32>) outs(%4 : 
tensor<128x255xf32>) -> tensor<128x255xf32> + return %5 : tensor<128x255xf32> +} +// CHECK-COUNT-2: tensor.pack +// CHECK-COUNT-1: linalg.generic +// CHECK: linalg.add ins(%{{.*}}, %{{.*}} : tensor<{{.*}}x{{.*}}x{{.*}}x{{.*}}xf32>, tensor<{{.*}}x{{.*}}x{{.*}}x{{.*}}xf32>) outs(%{{.*}} : tensor<{{.*}}x{{.*}}x{{.*}}x{{.*}}xf32>) -> tensor<{{.*}}x{{.*}}x{{.*}}x{{.*}}xf32> +// CHECK-COUNT-1: tensor.unpack + +// ----- + +// CHECK-LABEL: @matmul_add_plain_activation_bf16 +func.func @matmul_add_plain_activation_bf16(%arg0: tensor<128x64xbf16>, %arg1: tensor<64x64xbf16>, %arg2: tensor<64xbf16>) -> tensor<128x64xbf16> { + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<128x64xbf16> + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x64xbf16>) -> tensor<128x64xbf16> + %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x64xbf16>, tensor<64x64xbf16>) outs(%0 : tensor<128x64xbf16>) -> tensor<128x64xbf16> + %3 = tensor.empty() : tensor<128x64xbf16> + %broadcasted = linalg.broadcast ins(%arg2 : tensor<64xbf16>) outs(%3 : tensor<128x64xbf16>) dimensions = [0] + %4 = tensor.empty() : tensor<128x64xbf16> + %5 = linalg.add ins(%2, %broadcasted : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%4 : tensor<128x64xbf16>) -> tensor<128x64xbf16> + return %5 : tensor<128x64xbf16> +} +// CHECK-COUNT-2: tensor.pack +// CHECK-COUNT-1: linalg.generic +// CHECK: linalg.add ins(%{{.*}}, %{{.*}} : tensor<{{.*}}x{{.*}}xbf16>, tensor<{{.*}}x{{.*}}xbf16>) outs(%{{.*}} : tensor<{{.*}}x{{.*}}xbf16>) -> tensor<{{.*}}x{{.*}}xbf16> +// CHECK-NOT: tensor.unpack
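
// Note (expected behaviour, informal): with a plain activation layout the
// bf16 case keeps the activation and the trailing add in their original 2D
// form; only the weight side is packed (a blocking pack plus a VNNI pack,
// which accounts for the two tensor.pack ops checked above), so no
// tensor.unpack is needed on the matmul result.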