Skip to content

Commit

Permalink
Merge branch 'main' into enable-to-disable-v1
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexandreEichenberger authored Feb 13, 2025
2 parents 645c209 + 2bede1a commit eca7372
Show file tree
Hide file tree
Showing 14 changed files with 149 additions and 357 deletions.
15 changes: 4 additions & 11 deletions src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,13 @@ llvm::cl::opt<NNPAEmissionTargetType> nnpaEmissionTarget(
clEnumVal(EmitZNONE, "Do not emit NNPA-related target (default)")),
llvm::cl::init(EmitZNONE), llvm::cl::cat(OnnxMlirOptions));

// Boolean command-line flag `nnpa-clip-to-dlfloat-range`, initialized to true
// and registered in the OnnxMlirOptions category. Its own help text marks it
// as transitional: it is meant to be superseded by --nnpa-saturation.
llvm::cl::opt<bool> nnpaClipToDLFloatRange("nnpa-clip-to-dlfloat-range",
llvm::cl::desc("Clip CPU tensors to dlfloat range before stickification to "
"avoid out-of-range. Only clip Softmax inputs at this "
"moment. Default is true. This option will be removed and "
"replaced by --nnpa-saturation in the future."),
llvm::cl::init(true), llvm::cl::cat(OnnxMlirOptions));

llvm::cl::opt<bool> nnpaEnableZHighToOnnx("enable-zhigh-to-onnx",
llvm::cl::opt<bool> nnpaDisableZHighToOnnx("disable-zhigh-to-onnx",
llvm::cl::desc(
"Enabling this will convert a pattern `stick -> element-wise op -> "
"By default we convert a pattern `stick -> element-wise op -> "
"unstick` back to an ONNX element-wise op. This conversion is called "
"after applying all optimizations to remove stick/unstick at ZHigh "
"level. Default is true."),
llvm::cl::init(true), llvm::cl::cat(OnnxMlirOptions));
"level. Use this option to disable this optimization."),
llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions));

llvm::cl::opt<bool> nnpaEnableZHighDecomposeStickUnstick(
"enable-zhigh-decompose-stick-unstick",
Expand Down
3 changes: 1 addition & 2 deletions src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ typedef enum {
extern llvm::cl::OptionCategory OnnxMlirOptions;
extern llvm::cl::OptionCategory OnnxMlirCommonOptions;
extern llvm::cl::opt<onnx_mlir::NNPAEmissionTargetType> nnpaEmissionTarget;
extern llvm::cl::opt<bool> nnpaClipToDLFloatRange;
extern llvm::cl::opt<bool> nnpaEnableZHighToOnnx;
extern llvm::cl::opt<bool> nnpaDisableZHighToOnnx;
extern llvm::cl::opt<bool> nnpaEnableZHighDecomposeStickUnstick;
extern llvm::cl::opt<bool> nnpaDisableCompilerStickUnstick;
extern llvm::cl::opt<bool> nnpaEnableScalarBcastBinary;
Expand Down
12 changes: 1 addition & 11 deletions src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,16 +145,6 @@ void addONNXToZHighPasses(mlir::PassManager &pm) {
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
pm.addPass(mlir::createCanonicalizerPass());

// Clip zhigh.Stick inputs if required. This is to avoid out-of-range of
// dlfloat. Do constant propagation after clipping to remove ONNX ops used for
// clipping such as ONNXMax if applicable.
// This pass will be removed and replaced by nnpa-saturation in the future.
if (!nnpaEnableSaturation && nnpaClipToDLFloatRange) {
pm.addNestedPass<func::FuncOp>(
onnx_mlir::zhigh::createZHighClipToDLFloatPass());
pm.addNestedPass<func::FuncOp>(onnx_mlir::createConstPropONNXToONNXPass());
}

// One more call to ONNX shape inference/canonicalization/... to update shape
// if possible.
if (enableONNXHybridPass) {
Expand Down Expand Up @@ -183,7 +173,7 @@ void addONNXToZHighPasses(mlir::PassManager &pm) {
// sub, ...) that are of `stick -> light-weight op -> unstick`, it's better to
// use CPU instead of NNPA to avoid stick/unstick. CPU is efficient to handle
// these ops, e.g vectorize the computation.
if (nnpaEnableZHighToOnnx)
if (!nnpaDisableZHighToOnnx)
pm.addNestedPass<func::FuncOp>(onnx_mlir::createZHighToONNXPass());

// Constant propagation at ZHighIR: constant stickify.
Expand Down
4 changes: 0 additions & 4 deletions src/Accelerators/NNPA/NNPAAccelerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,6 @@ void NNPAAccelerator::registerPasses(int optLevel) const {
return onnx_mlir::zhigh::createZHighLayoutPropagationPass();
});

mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
return onnx_mlir::zhigh::createZHighClipToDLFloatPass();
});

mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
return onnx_mlir::zhigh::createZHighDecomposeStickUnstickPass();
});
Expand Down
3 changes: 0 additions & 3 deletions src/Accelerators/NNPA/Pass/NNPAPasses.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,6 @@ std::unique_ptr<mlir::Pass> createZHighConstPropagationPass();
std::unique_ptr<mlir::Pass> createZHighScrubDisposablePass(
bool closeAfter = true);

/// Pass for clipping values to dlfloat before stickification at ZHighIR.
std::unique_ptr<mlir::Pass> createZHighClipToDLFloatPass();

/// Pass for decomposing stick/unstick at ZHighIR.
std::unique_ptr<mlir::Pass> createZHighDecomposeStickUnstickPass();

Expand Down
11 changes: 0 additions & 11 deletions src/Accelerators/NNPA/Transform/ZHigh/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,6 @@ add_onnx_mlir_library(OMZHighLayoutPropagation
${NNPA_INCLUDE_PATH}
)

add_onnx_mlir_rewriter(ZHighClipToDLFloat)
add_onnx_mlir_library(OMZHighClipToDLFloat
ZHighClipToDLFloat.cpp

LINK_LIBS PUBLIC
MLIRRewrite
MLIRTransformUtils
OMZHighOps
OMONNXOps
)

add_onnx_mlir_rewriter(ZHighDecomposeStickUnstick)
add_onnx_mlir_library(OMZHighDecomposeStickUnstick
ZHighDecomposeStickUnstick.cpp
Expand Down
170 changes: 0 additions & 170 deletions src/Accelerators/NNPA/Transform/ZHigh/ZHighClipToDLFloat.cpp

This file was deleted.

2 changes: 1 addition & 1 deletion src/Compiler/CompilerPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,

void addKrnlToAffinePasses(mlir::PassManager &pm) {
pm.addNestedPass<func::FuncOp>(
onnx_mlir::krnl::createConvertKrnlToAffinePass());
onnx_mlir::krnl::createConvertKrnlToAffinePass(enableParallel));
}

void addKrnlToLLVMPasses(
Expand Down
21 changes: 18 additions & 3 deletions src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -844,11 +844,21 @@ struct ConvertKrnlToAffinePass
: public PassWrapper<ConvertKrnlToAffinePass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertKrnlToAffinePass);

ConvertKrnlToAffinePass() = default;
ConvertKrnlToAffinePass(const ConvertKrnlToAffinePass &pass)
: PassWrapper<ConvertKrnlToAffinePass, OperationPass<func::FuncOp>>() {}
ConvertKrnlToAffinePass(bool parallelEnabled) {
this->parallelEnabled = parallelEnabled;
}

// Returns the argument string that identifies this pass.
StringRef getArgument() const override { return "convert-krnl-to-affine"; }

// Returns the one-line human-readable description of this pass.
StringRef getDescription() const override { return "Lower Krnl dialect."; }

void runOnOperation() final;

// Boolean pass option exposed under the name "parallel-enabled"; initialized
// to false. runOnOperation forwards its value to
// populateKrnlToAffineConversion.
Option<bool> parallelEnabled{*this, "parallel-enabled",
llvm::cl::desc("Enable parallelization"), llvm::cl::init(false)};
};

void ConvertKrnlToAffinePass::runOnOperation() {
Expand Down Expand Up @@ -1008,7 +1018,7 @@ void ConvertKrnlToAffinePass::runOnOperation() {
RewritePatternSet patterns(ctx);
AffineTypeConverter typeConverter;

populateKrnlToAffineConversion(typeConverter, patterns, ctx);
populateKrnlToAffineConversion(typeConverter, patterns, ctx, parallelEnabled);

// Create list for recording the <loop, unroll factor> pairs associated with
// this function.
Expand Down Expand Up @@ -1046,16 +1056,21 @@ std::unique_ptr<Pass> createConvertKrnlToAffinePass() {
return std::make_unique<ConvertKrnlToAffinePass>();
}

/// Factory overload: builds a ConvertKrnlToAffinePass whose parallel option is
/// preset to `parallelEnabled` and hands ownership to the caller.
std::unique_ptr<Pass> createConvertKrnlToAffinePass(bool parallelEnabled) {
  std::unique_ptr<Pass> pass =
      std::make_unique<ConvertKrnlToAffinePass>(parallelEnabled);
  return pass;
}

void populateKrnlToAffineConversion(TypeConverter &typeConverter,
RewritePatternSet &patterns, MLIRContext *ctx) {
RewritePatternSet &patterns, MLIRContext *ctx, bool parallelEnabled) {
krnl::populateLoweringKrnlCopyFromBufferOpPattern(
typeConverter, patterns, ctx);
krnl::populateLoweringKrnlCopyToBufferOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlLoadOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlStoreOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlGetLinearOffsetIndexOpPattern(
typeConverter, patterns, ctx);
krnl::populateLoweringKrnlMatmultOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlMatmultOpPattern(
typeConverter, patterns, ctx, parallelEnabled);
krnl::populateLoweringKrnlMemsetOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlPrefetchOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlTerminatorOpPattern(typeConverter, patterns, ctx);
Expand Down
6 changes: 4 additions & 2 deletions src/Conversion/KrnlToAffine/ConvertKrnlToAffine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ using UnrollAndJamList = llvm::SmallVector<UnrollAndJamRecord, 4>;
using UnrollAndJamMap = std::map<mlir::Operation *, UnrollAndJamList *>;

void populateKrnlToAffineConversion(mlir::TypeConverter &typeConverter,
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx,
bool enableParallel);

void populateLoweringKrnlCopyFromBufferOpPattern(
mlir::TypeConverter &typeConverter, mlir::RewritePatternSet &patterns,
Expand All @@ -77,7 +78,8 @@ void populateLoweringKrnlGetLinearOffsetIndexOpPattern(
mlir::MLIRContext *ctx);

void populateLoweringKrnlMatmultOpPattern(mlir::TypeConverter &typeConverter,
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx,
bool parallelEnabled);

void populateLoweringKrnlMemsetOpPattern(mlir::TypeConverter &typeConverter,
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
Expand Down
Loading

0 comments on commit eca7372

Please sign in to comment.