diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp index dfab3c8202..e4332c4174 100644 --- a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp +++ b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp @@ -30,20 +30,13 @@ llvm::cl::opt nnpaEmissionTarget( clEnumVal(EmitZNONE, "Do not emit NNPA-related target (default)")), llvm::cl::init(EmitZNONE), llvm::cl::cat(OnnxMlirOptions)); -llvm::cl::opt nnpaClipToDLFloatRange("nnpa-clip-to-dlfloat-range", - llvm::cl::desc("Clip CPU tensors to dlfloat range before stickification to " - "avoid out-of-range. Only clip Softmax inputs at this " - "moment. Default is true. This option will be removed and " - "replaced by --nnpa-saturation in the future."), - llvm::cl::init(true), llvm::cl::cat(OnnxMlirOptions)); - -llvm::cl::opt nnpaEnableZHighToOnnx("enable-zhigh-to-onnx", +llvm::cl::opt nnpaDisableZHighToOnnx("disable-zhigh-to-onnx", llvm::cl::desc( - "Enabling this will convert a pattern `stick -> element-wise op -> " + "By default we convert a pattern `stick -> element-wise op -> " "unstick` back to an ONNX element-wise op. This conversion is called " "after applying all optimizations to remove stick/unstick at ZHigh " - "level. Default is true."), - llvm::cl::init(true), llvm::cl::cat(OnnxMlirOptions)); + "level. Use this option to disable this optimization."), + llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions)); llvm::cl::opt nnpaEnableZHighDecomposeStickUnstick( "enable-zhigh-decompose-stick-unstick", diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp index 8a0704b2b6..545f7e5a8a 100644 --- a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp +++ b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp @@ -67,8 +67,7 @@ typedef enum { extern llvm::cl::OptionCategory OnnxMlirOptions; extern llvm::cl::OptionCategory OnnxMlirCommonOptions; extern llvm::cl::opt nnpaEmissionTarget; -extern llvm::cl::opt nnpaClipToDLFloatRange; -extern llvm::cl::opt nnpaEnableZHighToOnnx; +extern llvm::cl::opt nnpaDisableZHighToOnnx; extern llvm::cl::opt nnpaEnableZHighDecomposeStickUnstick; extern llvm::cl::opt nnpaDisableCompilerStickUnstick; extern llvm::cl::opt nnpaEnableScalarBcastBinary; diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp index 75198c8c99..cb9e46a300 100644 --- a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp +++ b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp @@ -145,16 +145,6 @@ void addONNXToZHighPasses(mlir::PassManager &pm) { pm.addNestedPass(onnx_mlir::createShapeInferencePass()); pm.addPass(mlir::createCanonicalizerPass()); - // Clip zhigh.Stick inputs if required. This is to avoid out-of-range of - // dlfloat. Do constant propagation after clipping to remove ONNX ops used for - // clipping such as ONNXMax if applicable. - // This pass will be removed and replaced by nnpa-saturation in the future. - if (!nnpaEnableSaturation && nnpaClipToDLFloatRange) { - pm.addNestedPass( - onnx_mlir::zhigh::createZHighClipToDLFloatPass()); - pm.addNestedPass(onnx_mlir::createConstPropONNXToONNXPass()); - } - // One more call to ONNX shape inference/canonicalization/... to update shape // if possible. if (enableONNXHybridPass) { @@ -183,7 +173,7 @@ void addONNXToZHighPasses(mlir::PassManager &pm) { // sub, ...) 
that are of `stick -> light-weight op -> unstick`, it's better to // use CPU instead of NNPA to avoid stick/unstick. CPU is efficient to handle // these ops, e.g vectorize the computation. - if (nnpaEnableZHighToOnnx) + if (!nnpaDisableZHighToOnnx) pm.addNestedPass(onnx_mlir::createZHighToONNXPass()); // Constant propagation at ZHighIR: constant stickify. diff --git a/src/Accelerators/NNPA/NNPAAccelerator.cpp b/src/Accelerators/NNPA/NNPAAccelerator.cpp index 50ef2bf0ba..0fa76062b7 100644 --- a/src/Accelerators/NNPA/NNPAAccelerator.cpp +++ b/src/Accelerators/NNPA/NNPAAccelerator.cpp @@ -125,10 +125,6 @@ void NNPAAccelerator::registerPasses(int optLevel) const { return onnx_mlir::zhigh::createZHighLayoutPropagationPass(); }); - mlir::registerPass([]() -> std::unique_ptr { - return onnx_mlir::zhigh::createZHighClipToDLFloatPass(); - }); - mlir::registerPass([]() -> std::unique_ptr { return onnx_mlir::zhigh::createZHighDecomposeStickUnstickPass(); }); diff --git a/src/Accelerators/NNPA/Pass/NNPAPasses.hpp b/src/Accelerators/NNPA/Pass/NNPAPasses.hpp index c23fb7f158..1c8d6b7012 100644 --- a/src/Accelerators/NNPA/Pass/NNPAPasses.hpp +++ b/src/Accelerators/NNPA/Pass/NNPAPasses.hpp @@ -55,9 +55,6 @@ std::unique_ptr createZHighConstPropagationPass(); std::unique_ptr createZHighScrubDisposablePass( bool closeAfter = true); -/// Pass for clipping values to dlfloat before stickification at ZHighIR. -std::unique_ptr createZHighClipToDLFloatPass(); - /// Pass for decomposing stick/unstick at ZHighIR. std::unique_ptr createZHighDecomposeStickUnstickPass(); diff --git a/src/Accelerators/NNPA/Transform/ZHigh/CMakeLists.txt b/src/Accelerators/NNPA/Transform/ZHigh/CMakeLists.txt index 30378c0f9e..3cae309723 100644 --- a/src/Accelerators/NNPA/Transform/ZHigh/CMakeLists.txt +++ b/src/Accelerators/NNPA/Transform/ZHigh/CMakeLists.txt @@ -38,17 +38,6 @@ add_onnx_mlir_library(OMZHighLayoutPropagation ${NNPA_INCLUDE_PATH} ) -add_onnx_mlir_rewriter(ZHighClipToDLFloat) -add_onnx_mlir_library(OMZHighClipToDLFloat - ZHighClipToDLFloat.cpp - - LINK_LIBS PUBLIC - MLIRRewrite - MLIRTransformUtils - OMZHighOps - OMONNXOps - ) - add_onnx_mlir_rewriter(ZHighDecomposeStickUnstick) add_onnx_mlir_library(OMZHighDecomposeStickUnstick ZHighDecomposeStickUnstick.cpp diff --git a/src/Accelerators/NNPA/Transform/ZHigh/ZHighClipToDLFloat.cpp b/src/Accelerators/NNPA/Transform/ZHigh/ZHighClipToDLFloat.cpp deleted file mode 100644 index 9006c36669..0000000000 --- a/src/Accelerators/NNPA/Transform/ZHigh/ZHighClipToDLFloat.cpp +++ /dev/null @@ -1,170 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - */ - -//===---------- ZHighClipToDLFloat.cpp - ZHigh High Level Optimizer -------===// -// -// Copyright 2023- The IBM Research Authors. -// -// ============================================================================= -// -// This file implements a set of rewritten rules to clip CPU numerical values -// before passing to ZHighStick, which avoids data range violation error due to -// the dlfloat range. 
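For readers skimming this deletion: the pass clamped only the lower end of the dlfloat16 range, and only for zhigh.Stick ops feeding a Softmax (see the pattern below and the deleted lit test at the end of this diff). A minimal scalar sketch of the clamp it emitted, assuming the -8.57315738E+9 constant in that test is indeed DLF16_MIN:

#include <algorithm>

// Most negative dlfloat16 value, as materialized by the deleted pattern via
// an onnx.Max against a DLF16_MIN constant (value copied from the deleted
// lit test below; treat it as an approximation).
constexpr float kDLF16Min = -8.57315738e9f;

// Element-wise effect of the rewrite on the CPU tensor right before
// zhigh.Stick; positive overflow was never clipped by this pass.
inline float clipToDLFloatLowerBound(float x) { return std::max(x, kDLF16Min); }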
-// -//===----------------------------------------------------------------------===// - -#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp" -#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp" -#include "src/Accelerators/NNPA/Pass/NNPAPasses.hpp" -#include "src/Accelerators/NNPA/Support/NNPALimit.hpp" -#include "src/Dialect/ONNX/DialectBuilder.hpp" -#include "src/Dialect/ONNX/ONNXOps.hpp" -#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" -#include "src/Support/TypeUtilities.hpp" - -using namespace mlir; -using namespace onnx_mlir; -using namespace onnx_mlir::zhigh; - -namespace onnx_mlir { -namespace zhigh { - -namespace { - -/// Check if a value is from or transitively from a zTensor without value -/// modification. -bool valueFromZTensor(Value tensor) { - // Function arguments are always CPU tensors. - if (mlir::dyn_cast(tensor)) - return false; - - Operation *op = tensor.getDefiningOp(); - - // Base case: From a zTensor. - if (isa(op)) - return true; - - // Base case: ReluOp clipped the lowerbound to zero. - if (isa(op)) - return true; - - // Base case: Operations having no input, e.g., Constant, ConstantOfShape. - if (op->getOperands().size() == 0) - return false; - - // Recursion case: There are operations (e.g. transpose, reshape, etc.) that - // do not change the input precision. So we can consider that the input is - // already in the dlfloat range if it comes from zTensor. - - // Operations whose only the first input form the output. These ops may - // have additional inputs, but they are like attributes. - if (isa(op)) - return valueFromZTensor(op->getOperand(0)); - - // PadOp - if (auto padOp = mlir::dyn_cast(op)) { - Value padVal = padOp.getConstantValue(); - // Only support default constant value that is 0 at this moment. - if (isNoneValue(padVal)) - return valueFromZTensor(op->getOperand(0)); - } - - // For all remaining operations, do a conservative check. - return llvm::all_of( - op->getOperands(), [&](Value v) { return valueFromZTensor(v); }); -} - -class ZHighClipToDLFloatPattern : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite( - ZHighStickOp stickOp, PatternRewriter &rewriter) const override { - Operation *genericOp = stickOp.getOperation(); - Location loc = genericOp->getLoc(); - - Value input = stickOp.getIn(); - Value output = stickOp.getOut(); - Type inputElementType = getElementType(input.getType()); - - // Only clip if the input is in float > 16 bit. - auto floatType = mlir::dyn_cast(inputElementType); - if (!floatType) - return failure(); - if (floatType.getWidth() <= 16) - return failure(); - - // Only clip if the consummer is Softmax with which we have seen NaNs. - if (llvm::none_of(output.getUsers(), - [&](Operation *op) { return isa(op); })) - return failure(); - - // Do not clip if the input tensor is already in the dlfloat range. - // For example, the input was unstickified from a zTensor. - if (valueFromZTensor(input)) - return failure(); - - // Clip the input values if required since the values are potentially - // out-of-bound of dlfloat. 
- MultiDialectBuilder create(rewriter, loc); - DenseElementsAttr minAttr = DenseElementsAttr::get( - RankedTensorType::get({1}, inputElementType), DLF16_MIN); - Value minVal = create.onnx.constant(minAttr); - Value clippedVal = create.onnx.max({input, minVal}); - Value replacedVal = - rewriter.create(loc, stickOp.getOut().getType(), - clippedVal, stickOp.getLayoutAttr(), IntegerAttr()); - - rewriter.replaceOp(genericOp, replacedVal); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// ZHighClipToDLFloatPass -//===----------------------------------------------------------------------===// - -struct ZHighClipToDLFloatPass - : public PassWrapper> { - - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ZHighClipToDLFloatPass) - - StringRef getArgument() const override { return "zhigh-clip-to-dlfloat"; } - - StringRef getDescription() const override { - return "Clip stickification inputs at ZHighIR."; - } - - void runOnOperation() override { - auto function = getOperation(); - ConversionTarget target(getContext()); - RewritePatternSet patterns(&getContext()); - patterns.insert(&getContext()); - - GreedyRewriteConfig config; - config.useTopDownTraversal = true; - /// Only pre-existing ops (that were were on the worklist at the very - /// beginning) enqueued. All other ops are excluded. - config.strictMode = GreedyRewriteStrictness::ExistingOps; - - if (failed(applyPatternsAndFoldGreedily( - function, std::move(patterns), config))) - signalPassFailure(); - } -}; -} // anonymous namespace - -std::unique_ptr createZHighClipToDLFloatPass() { - return std::make_unique(); -} - -} // namespace zhigh -} // namespace onnx_mlir diff --git a/src/Compiler/CompilerPasses.cpp b/src/Compiler/CompilerPasses.cpp index 5561955304..efcb10d829 100644 --- a/src/Compiler/CompilerPasses.cpp +++ b/src/Compiler/CompilerPasses.cpp @@ -214,7 +214,7 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE, void addKrnlToAffinePasses(mlir::PassManager &pm) { pm.addNestedPass( - onnx_mlir::krnl::createConvertKrnlToAffinePass()); + onnx_mlir::krnl::createConvertKrnlToAffinePass(enableParallel)); } void addKrnlToLLVMPasses( diff --git a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp index 73609c2f14..c3b466930c 100644 --- a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp +++ b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp @@ -844,11 +844,21 @@ struct ConvertKrnlToAffinePass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertKrnlToAffinePass); + ConvertKrnlToAffinePass() = default; + ConvertKrnlToAffinePass(const ConvertKrnlToAffinePass &pass) + : PassWrapper>() {} + ConvertKrnlToAffinePass(bool parallelEnabled) { + this->parallelEnabled = parallelEnabled; + } + StringRef getArgument() const override { return "convert-krnl-to-affine"; } StringRef getDescription() const override { return "Lower Krnl dialect."; } void runOnOperation() final; + + Option parallelEnabled{*this, "parallel-enabled", + llvm::cl::desc("Enable parallelization"), llvm::cl::init(false)}; }; void ConvertKrnlToAffinePass::runOnOperation() { @@ -1008,7 +1018,7 @@ void ConvertKrnlToAffinePass::runOnOperation() { RewritePatternSet patterns(ctx); AffineTypeConverter typeConverter; - populateKrnlToAffineConversion(typeConverter, patterns, ctx); + populateKrnlToAffineConversion(typeConverter, patterns, ctx, parallelEnabled); // Create list for recording the pairs associated with // this 
function. @@ -1046,8 +1056,12 @@ std::unique_ptr createConvertKrnlToAffinePass() { return std::make_unique(); } +std::unique_ptr createConvertKrnlToAffinePass(bool parallelEnabled) { + return std::make_unique(parallelEnabled); +} + void populateKrnlToAffineConversion(TypeConverter &typeConverter, - RewritePatternSet &patterns, MLIRContext *ctx) { + RewritePatternSet &patterns, MLIRContext *ctx, bool parallelEnabled) { krnl::populateLoweringKrnlCopyFromBufferOpPattern( typeConverter, patterns, ctx); krnl::populateLoweringKrnlCopyToBufferOpPattern(typeConverter, patterns, ctx); @@ -1055,7 +1069,8 @@ void populateKrnlToAffineConversion(TypeConverter &typeConverter, krnl::populateLoweringKrnlStoreOpPattern(typeConverter, patterns, ctx); krnl::populateLoweringKrnlGetLinearOffsetIndexOpPattern( typeConverter, patterns, ctx); - krnl::populateLoweringKrnlMatmultOpPattern(typeConverter, patterns, ctx); + krnl::populateLoweringKrnlMatmultOpPattern( + typeConverter, patterns, ctx, parallelEnabled); krnl::populateLoweringKrnlMemsetOpPattern(typeConverter, patterns, ctx); krnl::populateLoweringKrnlPrefetchOpPattern(typeConverter, patterns, ctx); krnl::populateLoweringKrnlTerminatorOpPattern(typeConverter, patterns, ctx); diff --git a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.hpp b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.hpp index 2bc0fd3aae..c1c222a293 100644 --- a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.hpp +++ b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.hpp @@ -56,7 +56,8 @@ using UnrollAndJamList = llvm::SmallVector; using UnrollAndJamMap = std::map; void populateKrnlToAffineConversion(mlir::TypeConverter &typeConverter, - mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); + mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx, + bool enableParallel); void populateLoweringKrnlCopyFromBufferOpPattern( mlir::TypeConverter &typeConverter, mlir::RewritePatternSet &patterns, @@ -77,7 +78,8 @@ void populateLoweringKrnlGetLinearOffsetIndexOpPattern( mlir::MLIRContext *ctx); void populateLoweringKrnlMatmultOpPattern(mlir::TypeConverter &typeConverter, - mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); + mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx, + bool parallelEnabled); void populateLoweringKrnlMemsetOpPattern(mlir::TypeConverter &typeConverter, mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); diff --git a/src/Conversion/KrnlToAffine/KrnlMatmul.cpp b/src/Conversion/KrnlToAffine/KrnlMatmul.cpp index 8ab9ef7b1a..6b42177457 100644 --- a/src/Conversion/KrnlToAffine/KrnlMatmul.cpp +++ b/src/Conversion/KrnlToAffine/KrnlMatmul.cpp @@ -42,9 +42,12 @@ extern std::mutex unrollAndJamMutex; class KrnlMatmulLowering : public ConversionPattern { public: explicit KrnlMatmulLowering( - TypeConverter &typeConverter, MLIRContext *context) + TypeConverter &typeConverter, MLIRContext *context, bool parallelEnabled) : ConversionPattern( - typeConverter, KrnlMatMulOp::getOperationName(), 1, context) {} + typeConverter, KrnlMatMulOp::getOperationName(), 1, context) { + this->parallelEnabled = parallelEnabled; + } + bool parallelEnabled = false; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { @@ -221,22 +224,33 @@ class KrnlMatmulLowering : public ConversionPattern { if (simdize) { // SIMD code generator. if (matVectorProduct) { + // Alloc of temp outside of inner if/then/else. 
+ Value TmpSimdProd = allocForGenSimdMatVect(create.affineKMem, + elementType, iComputeTileSize, jComputeTileSize, kComputeTileSize, + vectorLen, fullUnrollAndJam); + Value TmpScalarProd = allocForGenScalar(create.affineKMem, elementType, + iTrip, jTrip, kTrip, /*unroll*/ false); // clang-format off create.affineKMem.ifThenElseIE(indexScope, allFullTiles, /* then full tiles */ [&](const AffineBuilderKrnlMem &createAffine) { - genSimdMatVect(createAffine, matmulOp, elementType, aStart, bStart, + genSimdMatVect(createAffine, matmulOp, TmpSimdProd, elementType, aStart, bStart, cStart, iComputeTileSize, jComputeTileSize, kComputeTileSize, vectorLen, fullUnrollAndJam); }, /* else has partial tiles */ [&](const AffineBuilderKrnlMem &createAffine) { - genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart, + genScalar(createAffine, matmulOp, TmpScalarProd, elementType, aStart, bStart, cStart, iTrip, jTrip, kTrip, /*unroll*/ false); }); // clang-format on } else { + Value TmpSimdC = allocForGenSimdMatMat(create.affineKMem, elementType, + iComputeTileSize, jComputeTileSize, kComputeTileSize, vectorLen, + fullUnrollAndJam); + Value TmpScalarC = allocForGenScalar(create.affineKMem, elementType, + iTrip, jPartialTrip, kTrip, /*unroll*/ false); // clang-format off create.affineKMem.ifThenElseIE(indexScope, allFullTiles, /* then full tiles */ [&](const AffineBuilderKrnlMem &createAffine) { - genSimdMatMat(createAffine, matmulOp, elementType, aStart, bStart, + genSimdMatMat(createAffine, matmulOp, TmpSimdC, elementType, aStart, bStart, cStart, iComputeTileSize, jComputeTileSize, kComputeTileSize, vectorLen, fullUnrollAndJam); }, @@ -246,33 +260,30 @@ class KrnlMatmulLowering : public ConversionPattern { // Test if SIMD dim (M) is full. createAffine.ifThenElseIE(indexScope, jFullTiles, /* full SIMD */ [&](const AffineBuilderKrnlMem &createAffine) { - genSimdMatMat(createAffine, matmulOp, elementType, aStart, bStart, + genSimdMatMat(createAffine, matmulOp, TmpSimdC, elementType, aStart, bStart, cStart, iTrip, jComputeTileSize, kTrip, vectorLen, /*unroll*/ false); }, /* else partial SIMD */ [&](const AffineBuilderKrnlMem &createAffine) { - // TODO: evaluate if get performance from partial SIMD - if (false && jPartialTrip.isLiteral() && jPartialTrip.getLiteral() >=2) { - // has a known trip count along the simd dimension of at least 2 - // elements, use simd again. - genSimdMatMat(createAffine, matmulOp, elementType, aStart, bStart, - cStart, iTrip, jPartialTrip, kTrip, vectorLen, /*unroll*/ false); - } else { - genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart, - iTrip, jPartialTrip, kTrip, /*unroll*/ false); - } + genScalar(createAffine, matmulOp, TmpScalarC, elementType, aStart, bStart, cStart, + iTrip, jPartialTrip, kTrip, /*unroll*/ false); }); }); // clang-format on } } else { // Scalar code generator. 
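The recurring change in this lowering is structural: each temporary accumulation buffer is now allocated once, before the ifThenElseIE, and passed into the generator functions rather than allocated inside them, so the allocation no longer sits inside an scf.if region and can later be hoisted out of the loop nest. A rough before/after sketch of that reshaping, with plain C++ stand-ins rather than the actual MLIR builder API:

#include <functional>

struct Buf {};                                      // stand-in for the temp MemRef
static Buf allocTmp() { return {}; }                // stand-in for alignedAlloc/alloca
static void genFullTile(Buf &) {}                   // then-branch code generator
static void genPartialTile(Buf &) {}                // else-branch code generator
// Emits both regions, as ifThenElseIE does.
static void ifThenElse(const std::function<void()> &thenFn,
    const std::function<void()> &elseFn) { thenFn(); elseFn(); }

void oldShape() {                                   // alloc buried in each branch
  ifThenElse([&] { Buf tmp = allocTmp(); genFullTile(tmp); },
      [&] { Buf tmp = allocTmp(); genPartialTile(tmp); });
}

void newShape() {                                   // alloc hoisted above the if
  Buf thenTmp = allocTmp();
  Buf elseTmp = allocTmp();
  ifThenElse([&] { genFullTile(thenTmp); }, [&] { genPartialTile(elseTmp); });
}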
+ Value TmpThenC = + allocForGenScalar(create.affineKMem, elementType, iComputeTileSize, + jComputeTileSize, kComputeTileSize, fullUnrollAndJam); + Value TmpElseC = allocForGenScalar( + create.affineKMem, elementType, iTrip, jTrip, kTrip, false); // clang-format off create.affineKMem.ifThenElseIE(indexScope, allFullTiles, /* then full */ [&](const AffineBuilderKrnlMem &createAffine) { - genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart, + genScalar(createAffine, matmulOp, TmpThenC, elementType, aStart, bStart, cStart, iComputeTileSize, jComputeTileSize, kComputeTileSize, fullUnrollAndJam); }, /* else partial */ [&](const AffineBuilderKrnlMem &createAffine) { - genScalar(createAffine, matmulOp, elementType, aStart, bStart, cStart, + genScalar(createAffine, matmulOp, TmpElseC, elementType, aStart, bStart, cStart, iTrip, jTrip, kTrip, false); }); // clang-format on @@ -282,21 +293,32 @@ class KrnlMatmulLowering : public ConversionPattern { } private: - void genScalar(const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op, - Type elementType, ArrayRef aStart, ArrayRef bStart, - ArrayRef cStart, IndexExpr I, IndexExpr J, IndexExpr K, + Value allocForGenScalar(const AffineBuilderKrnlMem &createAffine, + Type elementType, IndexExpr I, IndexExpr J, IndexExpr K, bool unrollJam) const { // Get operands. + MemRefBuilder createMemRef(createAffine); + int64_t unrollFactor = (unrollJam && J.isLiteral()) ? J.getLiteral() : 1; + // Have to privatize CTmpType by unroll factor (1 if none). + MemRefType CTmpType = MemRefType::get({unrollFactor}, elementType); + assert(BUFFER_ALIGN >= gDefaultAllocAlign); + // + if (parallelEnabled) + return createMemRef.alignedAlloc(CTmpType, BUFFER_ALIGN); + return createMemRef.alignedAlloca(CTmpType, BUFFER_ALIGN); + } + + void genScalar(const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op, + Value TmpC, Type elementType, ArrayRef aStart, + ArrayRef bStart, ArrayRef cStart, IndexExpr I, + IndexExpr J, IndexExpr K, bool unrollJam) const { + // Get operands. KrnlMatMulOpAdaptor operandAdaptor(op); MemRefBuilder createMemRef(createAffine); Value A(operandAdaptor.getA()), B(operandAdaptor.getB()), C(operandAdaptor.getC()); int64_t unrollFactor = (unrollJam && J.isLiteral()) ? J.getLiteral() : 1; - // Have to privatize CTmpType by unroll factor (1 if none). - MemRefType CTmpType = MemRefType::get({unrollFactor}, elementType); - assert(BUFFER_ALIGN >= gDefaultAllocAlign); - Value TmpC = createMemRef.alignedAlloc(CTmpType, BUFFER_ALIGN); // For i, j loops. LiteralIndexExpr zeroIE(0); @@ -342,11 +364,46 @@ class KrnlMatmulLowering : public ConversionPattern { } } + Value allocForGenSimdMatVect(const AffineBuilderKrnlMem &createAffine, + Type elementType, IndexExpr I, IndexExpr J, IndexExpr K, + IndexExpr vectorLen, bool unrollJam) const { + // can simdize only if I & K is compile time + assert(I.isLiteral() && K.isLiteral() && vectorLen.isLiteral() && + "can only simdize with compile time " + "blocking factor on simd axis"); + MultiDialectBuilder create(createAffine); + int64_t iLit(I.getLiteral()), VL(vectorLen.getLiteral()); + int64_t archVL = create.vec.getArchVectorLength(elementType); + + // Generate the vector type conversions. + assert(VL == archVL && "vector length and VL must be identical for now"); + VectorType vecType = VectorType::get({VL}, elementType); + int64_t iUnrollFactor = iLit; + assert(iUnrollFactor % VL == 0 && "i blocking should be a multiple of VL"); + + // Have to privatize CTmpType by unroll factor. 
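Regarding the "privatize CTmpType by unroll factor" comments in these helpers: the temporary gets one slot per jammed copy so that, after unroll-and-jam, each copy accumulates into its own element instead of serializing on a single accumulator. Roughly, in scalar terms (illustrative only; not the generated IR, and ignoring the vector element type used by the SIMD helpers):

#include <cassert>
#include <cstdint>

// U plays the role of the unroll-and-jam factor; TmpC is sized like CTmpType
// so the k-reduction keeps U independent partial sums.
void scalarTile(const float *A, const float *B, float *C, int64_t I, int64_t J,
    int64_t K, int64_t U) {
  assert(U <= 8 && J % U == 0 && "illustrative bounds only");
  for (int64_t i = 0; i < I; ++i)
    for (int64_t j = 0; j < J; j += U) {
      float TmpC[8] = {};               // one private slot per jammed copy
      for (int64_t k = 0; k < K; ++k)
        for (int64_t u = 0; u < U; ++u)
          TmpC[u] += A[i * K + k] * B[k * J + j + u];
      for (int64_t u = 0; u < U; ++u)
        C[i * J + j + u] += TmpC[u];
    }
}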
+ MemRefType CTmpType = MemRefType::get({iUnrollFactor}, vecType); + assert(BUFFER_ALIGN >= gDefaultAllocAlign && + "alignment of buffers cannot be smaller than the default alignment " + "(which is set for SIMD correctness"); + // Ok to use an alloca here because hoisting will take it out of the loop, + // as it is now generated before the scf.if which precluded the migration to + // outside the loops. + + // But at this time, if parallel is enabled, alloca would be stuck inside of + // the parallel loop, which is not great. TODO: migrate alloca from inside + // the parallel loop to the OMP parallel region before the loop. + // Grep for this pattern in all 3 instances of "parallelEnabled". + if (parallelEnabled) + return create.mem.alignedAlloc(CTmpType, BUFFER_ALIGN); + return create.mem.alignedAlloca(CTmpType, BUFFER_ALIGN); + } + // Initially, simdize with full K vector length. void genSimdMatVect(const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op, - Type elementType, ArrayRef aStart, ArrayRef bStart, - ArrayRef cStart, IndexExpr I, IndexExpr J, IndexExpr K, - IndexExpr vectorLen, bool unrollJam) const { + Value TmpProd, Type elementType, ArrayRef aStart, + ArrayRef bStart, ArrayRef cStart, IndexExpr I, + IndexExpr J, IndexExpr K, IndexExpr vectorLen, bool unrollJam) const { // can simdize only if I & K is compile time assert(I.isLiteral() && K.isLiteral() && vectorLen.isLiteral() && "can only simdize with compile time " @@ -367,12 +424,6 @@ class KrnlMatmulLowering : public ConversionPattern { int64_t iUnrollFactor = iLit; assert(iUnrollFactor % VL == 0 && "i blocking should be a multiple of VL"); - // Have to privatize CTmpType by unroll factor. - MemRefType CTmpType = MemRefType::get({iUnrollFactor}, vecType); - assert(BUFFER_ALIGN >= gDefaultAllocAlign && - "alignment of buffers cannot be smaller than the default alignment " - "(which is set for SIMD correctness"); - Value TmpProd = create.mem.alignedAlloc(CTmpType, BUFFER_ALIGN); // Init with zero. Value fZero = create.math.constant(elementType, 0); Value vFZero = create.vec.broadcast(vecType, fZero); @@ -427,11 +478,36 @@ class KrnlMatmulLowering : public ConversionPattern { } } + Value allocForGenSimdMatMat(const AffineBuilderKrnlMem &createAffine, + Type elementType, IndexExpr I, IndexExpr J, IndexExpr K, + IndexExpr vectorLen, bool unrollJam) const { + // can simdize only if K is compile time + MultiDialectBuilder create(createAffine); + + // Generate the vector type conversions. + int64_t VL = vectorLen.getLiteral(); + VectorType vecType = VectorType::get({VL}, elementType); + int64_t unrollFactor = (unrollJam && I.isLiteral()) ? I.getLiteral() : 1; + // Have to privatize CTmpType by unroll factor (1 if none). + MemRefType CTmpType = MemRefType::get({unrollFactor}, vecType); + assert(BUFFER_ALIGN >= gDefaultAllocAlign); + // Ok to use an alloca here because hoisting will take it out of the loop, + // as it is now generated before the scf.if which precluded the migration to + // outside the loops. + + // But at this time, if parallel is enabled, alloca would be stuck inside of + // the parallel loop, which is not great. TODO: migrate alloca from inside + // the parallel loop to the OMP parallel region before the loop. + if (parallelEnabled) + return create.mem.alignedAlloc(CTmpType, BUFFER_ALIGN); + return create.mem.alignedAlloca(CTmpType, BUFFER_ALIGN); + } + // Simdize along J / memory rows in B and C. 
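The three alloc* helpers above all end with the same decision, which is what the new parallelEnabled flag is really about: a stack allocation (alloca) is preferable in the sequential case because later hoisting moves it out of the loop nest, whereas with parallelization enabled the alloca would be stuck inside the (future) OMP parallel loop. Condensed into one hypothetical helper, written as it would read inside KrnlMatmul.cpp and reusing its BUFFER_ALIGN constant (the patch instead repeats these two lines in each helper):

// Not part of the patch: a factored-out view of the shared tail of
// allocForGenScalar / allocForGenSimdMatVect / allocForGenSimdMatMat.
static Value allocPrivateTmp(
    MemRefBuilder &createMemRef, MemRefType tmpType, bool parallelEnabled) {
  if (parallelEnabled)
    // Heap allocation: correct even if the enclosing loop becomes parallel.
    // The TODO in the patch is to eventually hoist an alloca into the OMP
    // parallel region instead of paying for a heap allocation here.
    return createMemRef.alignedAlloc(tmpType, BUFFER_ALIGN);
  // Stack allocation: later hoisting takes it out of the loops.
  return createMemRef.alignedAlloca(tmpType, BUFFER_ALIGN);
}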
void genSimdMatMat(const AffineBuilderKrnlMem &createAffine, KrnlMatMulOp op, - Type elementType, ArrayRef aStart, ArrayRef bStart, - ArrayRef cStart, IndexExpr I, IndexExpr J, IndexExpr K, - IndexExpr vectorLen, bool unrollJam) const { + Value TmpC, Type elementType, ArrayRef aStart, + ArrayRef bStart, ArrayRef cStart, IndexExpr I, + IndexExpr J, IndexExpr K, IndexExpr vectorLen, bool unrollJam) const { // can simdize only if K is compile time assert(J.isLiteral() && "can only simdize with compile time blocking factor on simd axis"); @@ -446,10 +522,6 @@ class KrnlMatmulLowering : public ConversionPattern { int64_t VL = vectorLen.getLiteral(); VectorType vecType = VectorType::get({VL}, elementType); int64_t unrollFactor = (unrollJam && I.isLiteral()) ? I.getLiteral() : 1; - // Have to privatize CTmpType by unroll factor (1 if none). - MemRefType CTmpType = MemRefType::get({unrollFactor}, vecType); - assert(BUFFER_ALIGN >= gDefaultAllocAlign); - Value TmpC = create.mem.alignedAlloc(CTmpType, BUFFER_ALIGN); // Iterates over the I indices (j are simd dim). Value iSaved, kSaved; @@ -547,8 +619,8 @@ class KrnlMatmulLowering : public ConversionPattern { }; // namespace krnl void populateLoweringKrnlMatmultOpPattern(TypeConverter &typeConverter, - RewritePatternSet &patterns, MLIRContext *ctx) { - patterns.insert(typeConverter, ctx); + RewritePatternSet &patterns, MLIRContext *ctx, bool parallelEnabled) { + patterns.insert(typeConverter, ctx, parallelEnabled); } } // namespace krnl diff --git a/src/Conversion/ONNXToKrnl/Math/Gemm.cpp b/src/Conversion/ONNXToKrnl/Math/Gemm.cpp index af0724c446..4f110bd6bd 100644 --- a/src/Conversion/ONNXToKrnl/Math/Gemm.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Gemm.cpp @@ -248,6 +248,8 @@ struct ONNXGemmOpLowering : public OpConversionPattern { [&](const KrnlBuilder &createKrnl, ValueRange i1_j1_indices) { Value i1(i1_j1_indices[0]), j1(i1_j1_indices[1]); // If parallel, will stay inside, otherwise will migrate out. + // Since they are not in an if structure, migration out is not an + // issue. Value aBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN); Value bBuff = create.mem.alignedAlloc(bTileType, BUFFER_ALIGN); Value rBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN); @@ -313,6 +315,8 @@ struct ONNXGemmOpLowering : public OpConversionPattern { [&](const KrnlBuilder &createKrnl, ValueRange j1_k1_indices) { Value j1(j1_k1_indices[0]), k1(j1_k1_indices[1]); // If parallel, it will stay inside, otherwise it will migrate out. + // Since allocs are not in an if structure, migration is not an + // issue. Value aBuff = create.mem.alignedAlloc(aTileType, BUFFER_ALIGN); Value bBuff = create.mem.alignedAlloc(bTileType, BUFFER_ALIGN); if (bTrans) diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp index f22cbf2595..070fe3d671 100644 --- a/src/Pass/Passes.hpp +++ b/src/Pass/Passes.hpp @@ -108,6 +108,7 @@ std::unique_ptr createElideConstGlobalValuePass(); namespace krnl { /// Pass for lowering frontend dialects to Krnl IR dialect. std::unique_ptr createConvertKrnlToAffinePass(); +std::unique_ptr createConvertKrnlToAffinePass(bool parallelEnabled); /// Pass for lowering Seq in Krnl dialect. 
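Usage of the new overload declared just above mirrors the CompilerPasses.cpp change earlier in this diff; a minimal sketch, with the func::FuncOp nesting spelled out and enableParallel standing for the existing flag that addKrnlToAffinePasses already references (the no-argument factory keeps the previous, non-parallel behavior):

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
#include "src/Pass/Passes.hpp"

void addKrnlToAffine(mlir::PassManager &pm, bool enableParallel) {
  // Threads the parallelization decision down to the KrnlMatMulOp lowering,
  // which uses it to pick heap vs. stack temporaries (see KrnlMatmul.cpp).
  pm.addNestedPass<mlir::func::FuncOp>(
      onnx_mlir::krnl::createConvertKrnlToAffinePass(enableParallel));
}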
std::unique_ptr createConvertSeqToMemrefPass(); diff --git a/test/mlir/accelerators/nnpa/transform/zhigh-clip-to-dlfloat-range.mlir b/test/mlir/accelerators/nnpa/transform/zhigh-clip-to-dlfloat-range.mlir deleted file mode 100644 index ee1e31f280..0000000000 --- a/test/mlir/accelerators/nnpa/transform/zhigh-clip-to-dlfloat-range.mlir +++ /dev/null @@ -1,96 +0,0 @@ -// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --zhigh-clip-to-dlfloat -split-input-file %s || FileCheck %s - -func.func @should_clip_stick(%arg0: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> { - %0 = "zhigh.Stick"(%arg0) {layout = "3DS"} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> - %1 = "zhigh.Softmax"(%0) {act_func = "ACT_NONE"} : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> - %2 = "zhigh.Unstick"(%1) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf32> - return %2 : tensor<3x4x5xf32> - -// CHECK-LABEL: func.func @should_clip_stick -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> { -// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<-8.57315738E+9> : tensor<1xf32> -// CHECK: [[VAR_1_:%.+]] = "onnx.Max"([[PARAM_0_]], [[VAR_0_]]) : (tensor<3x4x5xf32>, tensor<1xf32>) -> tensor<3x4x5xf32> -// CHECK: [[VAR_2_:%.+]] = "zhigh.Stick"([[VAR_1_]]) {layout = "3DS"} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> -// CHECK: [[VAR_3_:%.+]] = "zhigh.Softmax"([[VAR_2_]]) {act_func = "ACT_NONE"} : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> -// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf32> -// CHECK: return [[VAR_4_]] : tensor<3x4x5xf32> -// CHECK: } -} - -// ----- - -func.func @should_clip_transpose(%arg0: tensor<3x5x4xf32>) -> tensor<3x4x5xf32> { - %1 = "onnx.Transpose"(%arg0) { perm = [0, 2, 1]} : (tensor<3x5x4xf32>) -> tensor<3x4x5xf32> - %2 = "zhigh.Stick"(%1) {layout = "3DS"} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> - %3 = "zhigh.Softmax"(%2) {act_func = "ACT_NONE"} : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> - %4 = "zhigh.Unstick"(%3) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf32> - return %4 : tensor<3x4x5xf32> - -// CHECK-LABEL: func.func @should_clip_transpose -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x5x4xf32>) -> tensor<3x4x5xf32> { -// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<-8.57315738E+9> : tensor<1xf32> -// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Transpose"([[PARAM_0_]]) {perm = [0, 2, 1]} : (tensor<3x5x4xf32>) -> tensor<3x4x5xf32> -// CHECK: [[VAR_2_:%.+]] = "onnx.Max"([[VAR_1_]], [[VAR_0_]]) : (tensor<3x4x5xf32>, tensor<1xf32>) -> tensor<3x4x5xf32> -// CHECK: [[VAR_3_:%.+]] = "zhigh.Stick"([[VAR_2_]]) {layout = "3DS"} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> -// CHECK: [[VAR_4_:%.+]] = "zhigh.Softmax"([[VAR_3_]]) {act_func = "ACT_NONE"} : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> -// CHECK: [[VAR_5_:%.+]] = "zhigh.Unstick"([[VAR_4_]]) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf32> -// CHECK: return [[VAR_5_]] : tensor<3x4x5xf32> -// CHECK: } -} - -// ----- - -// Do not clip because the input comes from 
a zTensor via Unstick. -func.func @donot_clip_stick(%arg0: tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<3x4x5xf32> { - %0 = "zhigh.Unstick"(%arg0) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<3x4x5xf32> - %1 = "zhigh.Stick"(%0) {layout = "3DS"} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> - %2 = "zhigh.Softmax"(%1) {act_func = "ACT_NONE"} : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> - %3 = "zhigh.Unstick"(%2) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf32> - return %3 : tensor<3x4x5xf32> - -// CHECK-LABEL: donot_clip_stick -// CHECK: zhigh.Unstick -// CHECK: zhigh.Stick -// CHECK: zhigh.Softmax -// CHECK: zhigh.Unstick -} - -// ----- - -// Do not clip because transpose does not change the zTensor. -func.func @donot_clip_stick_transpose(%arg0: tensor<3x5x4xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<3x4x5xf32> { - %0 = "zhigh.Unstick"(%arg0) : (tensor<3x5x4xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<3x5x4xf32> - %1 = "onnx.Transpose"(%0) { perm = [0, 2, 1]} : (tensor<3x5x4xf32>) -> tensor<3x4x5xf32> - %2 = "zhigh.Stick"(%1) {layout = "3DS"} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> - %3 = "zhigh.Softmax"(%2) {act_func = "ACT_NONE"} : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> - %4 = "zhigh.Unstick"(%3) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf32> - return %4 : tensor<3x4x5xf32> - -// CHECK-LABEL: donot_clip_stick_transpose -// CHECK: zhigh.Unstick -// CHECK: onnx.Transpose. -// CHECK: zhigh.Stick -// CHECK: zhigh.Softmax -// CHECK: zhigh.Unstick -} - -// ----- - -// Do not clip because concat does not change the zTensor. -func.func @donot_clip_stick_concat(%arg0: tensor<3x2x5xf16, #zhigh.layout<{dataLayout = "3D"}>>, %arg1: tensor<3x2x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<3x4x5xf32> { - %0 = "zhigh.Unstick"(%arg0) : (tensor<3x2x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<3x2x5xf32> - %1 = "zhigh.Unstick"(%arg1) : (tensor<3x2x5xf16, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<3x2x5xf32> - %2 = "onnx.Concat"(%0, %1) { axis = 1 : si64} : (tensor<3x2x5xf32>, tensor<3x2x5xf32>) -> tensor<3x4x5xf32> - %3 = "zhigh.Stick"(%2) {layout = "3DS"} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> - %4 = "zhigh.Softmax"(%3) {act_func = "ACT_NONE"} : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>> - %5 = "zhigh.Unstick"(%4) : (tensor<3x4x5xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<3x4x5xf32> - return %5 : tensor<3x4x5xf32> - -// CHECK-LABEL: donot_clip_stick_concat -// CHECK: zhigh.Unstick -// CHECK: zhigh.Unstick -// CHECK: onnx.Concat. -// CHECK: zhigh.Stick -// CHECK: zhigh.Softmax -// CHECK: zhigh.Unstick -}