Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Transform] mlir: named op layout propagation #101

Open
wants to merge 26 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 165 additions & 0 deletions include/gc/Analysis/GlobalAnalysis.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
//===- GlobalAnalysis.h - Graph Compiler analysis pass ----------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_ANALYSIS_GLOBALANALYSIS_H
#define MLIR_ANALYSIS_GLOBALANALYSIS_H

#include <numeric>
#include <utility>

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace gc {

using namespace mlir;

/// Describes how a tensor is packed relative to its plain (unpacked) form:
///   - outerAxis: the plain axis each outer packed dim originates from,
///   - innerAxis: the plain axes that are additionally tiled into inner dims,
///   - tileSizes: one tile size per entry of innerAxis.
class TensorLayout {
public:
  TensorLayout(ArrayRef<int64_t> outerAxis, ArrayRef<int64_t> innerAxis,
               ArrayRef<OpFoldResult> tileSizes)
      : outerAxis(outerAxis), innerAxis(innerAxis), tileSizes(tileSizes) {
    // Every tiled inner axis must carry exactly one tile size.
    assert(innerAxis.size() == tileSizes.size() &&
           "innerAxis and tileSizes must have the same length");
  }

  /// Returns true iff `outerAxis` is the identity permutation.
  static bool isPlainOuterAxis(ArrayRef<int64_t> outerAxis) {
    for (int64_t i = 0; i < static_cast<int64_t>(outerAxis.size()); ++i) {
      if (i != outerAxis[i])
        return false;
    }
    return true;
  }

  /// A layout is plain when the outer axes form the identity permutation
  /// and no axis is tiled.
  bool isPlain() const {
    return isPlainOuterAxis(outerAxis) && tileSizes.empty() &&
           innerAxis.empty();
  }

  /// A layout is blocking when at least one axis is tiled.
  bool isBlocking() const { return !tileSizes.empty() && !innerAxis.empty(); }

  /// Creates the identity (plain) layout for a tensor of rank `rank`.
  static TensorLayout createPlainLayout(int64_t rank) {
    SmallVector<int64_t> outerAxis(rank);
    std::iota(outerAxis.begin(), outerAxis.end(), 0);
    return TensorLayout(outerAxis, SmallVector<int64_t>{},
                        SmallVector<OpFoldResult>{});
  }

  /// Maps each plain axis to the packed axes it expands into: its outer
  /// position(s) first, then — if tiled — its inner position(s), where inner
  /// positions are offset by the number of outer axes.
  DenseMap<int64_t, SmallVector<int64_t>> getPlainToPackedAxisMapping() const {
    DenseMap<int64_t, SmallVector<int64_t>> axisMapping;
    int64_t outerAxisSize = outerAxis.size();
    for (int64_t i = 0; i < outerAxisSize; ++i)
      axisMapping[outerAxis[i]].push_back(i);
    for (int64_t i = 0; i < static_cast<int64_t>(innerAxis.size()); ++i)
      axisMapping[innerAxis[i]].push_back(outerAxisSize + i);
    return axisMapping;
  }

  /// Returns the plain axis that the packed axis `idx` originates from.
  /// `idx` indexes the packed layout: outer axes first, then inner axes.
  int64_t getPlainAxis(int64_t idx) const {
    int64_t totalRank = outerAxis.size() + innerAxis.size();
    (void)totalRank; // only used by the assert below
    assert(idx >= 0 && idx < totalRank && "Provided packed axis out of bound");
    if (idx >= static_cast<int64_t>(outerAxis.size()))
      return innerAxis[idx - outerAxis.size()];
    return outerAxis[idx];
  }

  /// Number of outer axes (equals the plain rank; see createPlainLayout).
  size_t getRank() const { return outerAxis.size(); }

  // Accessors return const references to avoid copying the vectors;
  // callers that need an owned copy can still copy-construct from them.
  const SmallVector<int64_t> &getOuterAxis() const { return outerAxis; }

  const SmallVector<int64_t> &getInnerAxis() const { return innerAxis; }

  const SmallVector<OpFoldResult> &getTileSizes() const { return tileSizes; }

  friend llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
                                       const TensorLayout &layout);

  bool operator==(const TensorLayout &other) const;

  bool operator!=(const TensorLayout &other) const;

private:
  SmallVector<int64_t> outerAxis;
  SmallVector<int64_t> innerAxis;
  SmallVector<OpFoldResult> tileSizes;
};

/// Aggregates the layouts chosen for one operation: one TensorLayout per
/// input operand and one per output.
class OperatorLayout {
public:
  OperatorLayout() = default;

  // Move the vectors into the members instead of copy-assigning in the
  // constructor body.
  OperatorLayout(SmallVector<TensorLayout> inputLayouts,
                 SmallVector<TensorLayout> outputLayouts)
      : supportedInputLayouts(std::move(inputLayouts)),
        supportedOutputLayouts(std::move(outputLayouts)) {}

  // Accessors return const references to avoid copying the vectors.
  const SmallVector<TensorLayout> &getSupportedInputLayouts() const {
    return supportedInputLayouts;
  }

  const SmallVector<TensorLayout> &getSupportedOutputLayouts() const {
    return supportedOutputLayouts;
  }

  /// Returns the layout of output `idx`; `idx` must be in range.
  TensorLayout getOutputLayout(int64_t idx) const {
    assert(idx >= 0 && idx < static_cast<int64_t>(supportedOutputLayouts.size()) &&
           "Output index out of bound");
    return supportedOutputLayouts[idx];
  }

  /// True iff every input and output layout is plain.
  bool isPlain() const {
    for (const auto &layout : llvm::concat<const TensorLayout>(
             supportedInputLayouts, supportedOutputLayouts)) {
      if (!layout.isPlain())
        return false;
    }
    return true;
  }

  friend llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
                                       const OperatorLayout &opLayout);

private:
  SmallVector<TensorLayout> supportedInputLayouts;
  SmallVector<TensorLayout> supportedOutputLayouts;
};

/// Analysis over the IR rooted at `root` that records, per operation, the
/// layout it should use. The cache is populated by the constructor
/// (defined out of line — presumably in GlobalAnalysis.cpp; confirm there).
class GlobalAnalysis {
public:
  explicit GlobalAnalysis(Operation *root);

  /// Returns the cached layout for `op`, or failure() if the analysis did
  /// not record one. Single DenseMap lookup (the original did find + []).
  FailureOr<OperatorLayout> getOpLayout(Operation *op) const {
    auto iter = layoutCache.find(op);
    if (iter != layoutCache.end())
      return iter->second;
    return failure();
  }

private:
  // Per-operation layout decisions computed by the constructor.
  DenseMap<Operation *, OperatorLayout> layoutCache;
};

namespace utils {
// Returns true if `linalgOp` is a contraction-style named op supported by
// the layout propagation machinery. NOTE(review): the eligibility criteria
// live in the out-of-line definition — confirm there.
bool isSupportedContractionNamedOp(const linalg::LinalgOp &linalgOp);

// Returns true if `op` is eligible for packing / layout propagation.
// NOTE(review): exact criteria defined out of line — confirm there.
bool isPackableOp(Operation *op);

// Returns true if `linalgOp` operates purely on tensors (as opposed to
// memrefs). NOTE(review): definition not visible here — confirm semantics.
bool hasAllTensorSemantics(linalg::LinalgOp linalgOp);
} // namespace utils
} // namespace gc
} // namespace mlir

#endif
6 changes: 6 additions & 0 deletions include/gc/Analysis/MatmulConfigAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ getOprandDimType(linalg::LinalgOp &linalgOp) {
SmallVector<DimType>{DimType::M, DimType::K, DimType::M, DimType::K},
SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N},
SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}};
} else if (linalgx::isGenericPackedMatmulOp(linalgOp.getOperation(),
linalgx::PackingType::MM2D4D)) {
return SmallVector<SmallVector<DimType>>{
SmallVector<DimType>{DimType::M, DimType::K},
SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N},
SmallVector<DimType>{DimType::M, DimType::N}};
}
return failure();
}
Expand Down
1 change: 1 addition & 0 deletions include/gc/Dialect/Linalgx/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ namespace linalgx {
/// @brief enum of type of matmul packing
enum class PackingType : int {
MM4D = 0, // MKmk x NKkn
MM2D4D, // MK x NKkn
VNNI_MM2D, // MK x NKknV
VNNI_MM4D, // MKmk x NKknV
VNNI_BRMM3D, // BMK x BKNV
Expand Down
35 changes: 35 additions & 0 deletions include/gc/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,40 @@ def MergeNestedForall : Pass<"merge-nested-forall"> {
let dependentDialects = ["scf::SCFDialect"];
}

def PropagateLayoutOnNamedOps : Pass<"propagate-layout-on-named-ops"> {
  // Fixed typo: "propagte" -> "propagate" in the user-facing summary and
  // description (shown in pass --help output).
  let summary = "Insert and propagate tensor.pack to pack the computation of linalg named ops and tensor ops.";
  let description = [{
    Insert and propagate tensor.pack on linalg named ops and tensor ops.
  }];
  let dependentDialects = [
    "mlir::tensor::TensorDialect",
    "mlir::linalg::LinalgDialect",
    "mlir::linalgx::LinalgxDialect"
  ];
}

// Reconstructed: the def was interleaved with scraped review-comment text;
// this is the clean TableGen definition.
def PostProcessPackUnpack : Pass<"post-process-pack-unpack"> {
  let summary = "Fold and simplify pack and unpack ops.";
  let description = [{
    Fold and simplify pack and unpack ops.
  }];
  let dependentDialects = [
    "mlir::tensor::TensorDialect",
    "mlir::linalg::LinalgDialect"
  ];
}

// Lowers tensor.pack/tensor.unpack into transpose and shape-related ops
// (per the description below; lowering patterns defined in the pass impl).
def LowerPackUnpack : Pass<"lower-pack-unpack"> {
  let summary = "Lower pack and unpack ops.";
  let description = [{
    Lower pack and unpack into transpose and shape related ops.
  }];
  let dependentDialects = [
    "mlir::tensor::TensorDialect",
    "mlir::linalg::LinalgDialect"
  ];
}

def FoldTensorOperation : Pass<"fold-tensor-operation"> {
let summary = "Fold some tensor operation";
let description = [{
Expand All @@ -179,6 +213,7 @@ def FoldTensorOperation : Pass<"fold-tensor-operation"> {
];
}


def LowerToTileVector : Pass<"lower-to-tile-vector", "func::FuncOp"> {
let summary = "Lower tensor to tile (virtual) vector";
let description = [{
Expand Down
27 changes: 27 additions & 0 deletions include/gc/Transforms/Transforms.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//===-- Transforms.h - transformation utilities -----------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef GC_TRANSFORMS_TRANSFORMS_H
#define GC_TRANSFORMS_TRANSFORMS_H

#include "gc/Analysis/GlobalAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

namespace mlir {
namespace gc {
LogicalResult packLinalgOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
const OperatorLayout &opLayout);

LogicalResult namedOpLayoutPropagation(RewriterBase &rewriter,
linalg::LinalgOp linalgOp,
OperatorLayout opLayout);
} // namespace gc
} // namespace mlir

#endif // GC_TRANSFORMS_TRANSFORMS_H
1 change: 1 addition & 0 deletions lib/gc/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ gc_set_mlir_link_components(MLIR_LINK_COMPONENTS
gc_add_mlir_library(GcAnalysis
TargetDescriptionAnalysis.cpp
MatmulConfigAnalysis.cpp
GlobalAnalysis.cpp
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please give it a more descriptive name?


DEPENDS
GraphCompilerPassIncGen
Expand Down
Loading
Loading