diff --git a/docs/Dialects/onnx.md b/docs/Dialects/onnx.md index 43339eb658..1bc9a786fe 100644 --- a/docs/Dialects/onnx.md +++ b/docs/Dialects/onnx.md @@ -6588,10 +6588,16 @@ Effects: `MemoryEffects::Effect{}` ### `onnx.PrintSignature` (ONNXPrintSignatureOp) -_ONNX Op to print type signature of its input operands_ +_ONNX Op to print type signature or data of its input operands_ -Print type signature of the op's input operands. This operation is introduced early -so as to preserve the name of the original ONNX op. +Print type signature or data of the input operands of this op. +The parameter op_name specifies a string to be printed before the tensors. +and usually the op_name and onnx_node_name are used. +This operation is introduced early so as to preserve the name of the original ONNX op. +The argument print_data control whether the data of the tensors to be printed. +When print_data == 1, the data of the tensor will be printed. Otherwise, just shape. +The argument input specifies the tensor to be printed. They could be a list +of the inputs and outputs of an ONNX op. This operation is not part of the standard and was added to assist onnx-mlir. @@ -6600,6 +6606,7 @@ This operation is not part of the standard and was added to assist onnx-mlir. +
AttributeMLIR TypeDescription
op_name::mlir::StringAttrstring attribute
print_data::mlir::IntegerAttr64-bit signed integer attribute
#### Operands: diff --git a/docs/Dialects/zlow.md b/docs/Dialects/zlow.md index e97ece5313..90af91485c 100644 --- a/docs/Dialects/zlow.md +++ b/docs/Dialects/zlow.md @@ -850,7 +850,6 @@ Traits: `MemRefsNormalizable` | Operand | Description | | :-----: | ----------- | | `X` | memref of dlfloat16 type values -| `shape` | memref of 64-bit signless integer values | `Out` | memref of dlfloat16 type values ### `zlow.sigmoid` (::onnx_mlir::zlow::ZLowSigmoidOp) diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp index cb9e46a300..4d083366c6 100644 --- a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp +++ b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp @@ -260,8 +260,7 @@ void addPassesNNPA(mlir::OwningOpRef &module, else if (optStr == "-O3") optLevel = OptLevel::O3; // Lower ONNX to Krnl, ZHigh to ZLow. - addONNXToKrnlPasses( - pm, optLevel, /*enableCSE*/ true, instrumentSignatures, ONNXOpStats); + addONNXToKrnlPasses(pm, optLevel, /*enableCSE*/ true, ONNXOpStats); if (nnpaEmissionTarget >= EmitZLowIR) emissionTarget = EmitMLIR; diff --git a/src/Compiler/CompilerOptions.cpp b/src/Compiler/CompilerOptions.cpp index 4810a47a41..45d50047c3 100644 --- a/src/Compiler/CompilerOptions.cpp +++ b/src/Compiler/CompilerOptions.cpp @@ -72,6 +72,7 @@ std::string instrumentOps; // onnx-mlir only unsigned instrumentControlBits; // onnx-mlir only std::string parallelizeOps; // onnx-mlir only std::string instrumentSignatures; // onnx-mlir only +std::string instrumentOnnxNode; // onnx-mlir only std::string ONNXOpStats; // onnx-mlir only int onnxOpTransformThreshold; // onnx-mlir only bool onnxOpTransformReport; // onnx-mlir only @@ -501,17 +502,44 @@ static llvm::cl::opt parallelizeOpsOpt("parallelize-ops", static llvm::cl::opt instrumentSignatureOpt( "instrument-signature", - llvm::cl::desc("Specify which high-level operations should print their" - " input type(s) and shape(s)\n" - "\"ALL\" or \"\" for all available operations.\n" - "\"NONE\" for no instrument (default).\n" - "\"ops1,ops2, ...\" for the multiple ops.\n" - "e.g. \"onnx.MatMul,onnx.Add\" for MatMul and Add ops.\n" - "Asterisk is also available.\n" - "e.g. \"onnx.*\" for all onnx operations.\n"), + llvm::cl::desc( + "Specify which high-level operations should be selected for printing\n" + "the type and shape of their input/output tensors.\n" + "The ops are selected by their op name.\n" + "The instrument-signature defines the pattern to select the ops.\n" + "\"NONE\" for no instrument (default).\n" + "\"ALL\" or \"\" for all available operations.\n" + "Except for the special values, the regexp is used for matching.\n" + "\"ops1,ops2, ...\" for the multiple ops.\n" + "e.g. \"onnx.MatMul,onnx.Add\" for MatMul and Add op in onnx dialect.\n" + "Asterisk is also available.\n" + "e.g. \"onnx.*\" for all onnx operations.\n"), llvm::cl::location(instrumentSignatures), llvm::cl::init("NONE"), llvm::cl::cat(OnnxMlirOptions)); +static llvm::cl::opt instrumentONNXNodeOpt( + "instrument-onnx-node", + llvm::cl::desc( + "Specify which onnx operation node will be selected for \n" + "inserting a runtime call after the node to print the data of\n" + "their input/output tensors.\n" + "The ops are selected by their onnx node name, which is a string\n" + "attribute unique to each onnx node (most of time).\n" + "You can find them in the output of --EmitONNXIR\n" + "Other instrumentation in onnx-mlir is specified by op->getName(),\n" + "namely, the type of onnx operation, such Add, Matmul, and etc.\n" + "This option is able to pinpoint to a particular node.\n" + "The instrument-onnx-node defines the pattern to select.\n" + "\"NONE\" for no instrument (default).\n" + "Except for the special values, the regexp is used for matching.\n" + "\"/layer1/MatMul, onnx.Add_0, ...\" for the multiple nodes.\n" + "Asterisk is also available. For example:\n" + "\"onnx.Add_*\" for all AddOp. This feature allows you to specify\n" + "part of the target of onnx_node_name, as long as it is long enough\n" + "to be unique.\n"), + llvm::cl::location(instrumentOnnxNode), llvm::cl::init("NONE"), + llvm::cl::cat(OnnxMlirOptions)); + static llvm::cl::opt ONNXOpStatsOpt("onnx-op-stats", llvm::cl::desc( "Report the occurrence frequency of ONNX ops in JSON or TXT format:\n" diff --git a/src/Compiler/CompilerOptions.hpp b/src/Compiler/CompilerOptions.hpp index 045e3aeaaa..85c09b065d 100644 --- a/src/Compiler/CompilerOptions.hpp +++ b/src/Compiler/CompilerOptions.hpp @@ -118,6 +118,7 @@ extern std::string instrumentOps; // onnx-mlir only extern unsigned instrumentControlBits; // onnx-mlir only extern std::string parallelizeOps; // onnx-mlir only extern std::string instrumentSignatures; // onnx-mlir only +extern std::string instrumentOnnxNode; // onnx-mlir only extern std::string ONNXOpStats; // onnx-mlir only extern int onnxOpTransformThreshold; // onnx-mlir only extern bool onnxOpTransformReport; // onnx-mlir only diff --git a/src/Compiler/CompilerPasses.cpp b/src/Compiler/CompilerPasses.cpp index efcb10d829..02ecde0241 100644 --- a/src/Compiler/CompilerPasses.cpp +++ b/src/Compiler/CompilerPasses.cpp @@ -169,10 +169,16 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU, if (instrumentStage == onnx_mlir::InstrumentStages::Onnx) pm.addNestedPass( onnx_mlir::createInstrumentPass(instrumentOps, instrumentActions)); + // Print Signatures of each op at runtime if enabled. Should not run + // signature and instrument passes at the same time as time may include printf + // overheads. + if (instrumentSignatures != "NONE" || instrumentOnnxNode != "NONE") + pm.addNestedPass(onnx_mlir::createInstrumentONNXSignaturePass( + instrumentSignatures, instrumentOnnxNode)); } void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE, - std::string instrumentSignatureString, std::string ONNXOpsStatFormat) { + std::string ONNXOpsStatFormat) { if (enableCSE) // Eliminate common sub-expressions before lowering to Krnl. // TODO: enable this by default when we make sure it works flawlessly. @@ -196,12 +202,6 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE, } } - // Print Signatures of each op at runtime if enabled. Should not run - // signature and instrument passes at the same time as time may include printf - // overheads. - if (instrumentSignatureString != "NONE") - pm.addNestedPass(onnx_mlir::createInstrumentONNXSignaturePass( - instrumentSignatureString)); pm.addPass(onnx_mlir::createLowerToKrnlPass(/*enableTiling*/ optLevel >= 3, /*enableSIMD*/ optLevel >= 3 && !disableSimdOption, enableParallel, /*enableFastMath*/ optLevel >= 3 && enableFastMathOption, @@ -325,8 +325,8 @@ void addPasses(mlir::OwningOpRef &module, mlir::PassManager &pm, if (emissionTarget >= EmitMLIR) { if (inputIRLevel <= ONNXLevel) - addONNXToKrnlPasses(pm, OptimizationLevel, /*enableCSE*/ true, - instrumentSignatures, ONNXOpStats); + addONNXToKrnlPasses( + pm, OptimizationLevel, /*enableCSE*/ true, ONNXOpStats); if (inputIRLevel <= MLIRLevel) addKrnlToAffinePasses(pm); } diff --git a/src/Compiler/CompilerPasses.hpp b/src/Compiler/CompilerPasses.hpp index 9a6987cf19..3a564e352d 100644 --- a/src/Compiler/CompilerPasses.hpp +++ b/src/Compiler/CompilerPasses.hpp @@ -23,7 +23,7 @@ void configurePasses(); void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU, bool donotScrubDisposableElementsAttr = false); void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE, - std::string instrumentSignatureString, std::string ONNXOpsStatFilename); + std::string ONNXOpsStatFilename); void addKrnlToAffinePasses(mlir::PassManager &pm); void addKrnlToLLVMPasses( mlir::OpPassManager &pm, std::string outputNameNoExt, bool enableCSE); diff --git a/src/Conversion/ONNXToKrnl/Tensor/PrintSignature.cpp b/src/Conversion/ONNXToKrnl/Tensor/PrintSignature.cpp index df220ee303..4c73a86c85 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/PrintSignature.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/PrintSignature.cpp @@ -50,15 +50,26 @@ struct ONNXPrintSignatureLowering op, msg + "(no tensors)\n%e", noneVal); return success(); } + // Control how the tensor will be printed + // Print the only the shape. + std::string printControl = ", %t%e"; + if (printSignatureOp.getPrintData() == 1) { + // The data of tensor will be printed + printControl = "%t%d\n"; + msg += "\n"; + } Value lastVal = printVal.pop_back_val(); // Print all but the last one. for (Value oper : printVal) { - create.krnl.printTensor(msg + ", %t%e", oper); + create.krnl.printTensor(msg + printControl, oper); msg = "%i"; } // Print the last one with replace with new op. + if (printSignatureOp.getPrintData() == 0) { + printControl = ", %t\n%e"; + } rewriter.replaceOpWithNewOp( - op, msg + ", %t\n%e", lastVal); + op, msg + printControl, lastVal); return success(); } }; diff --git a/src/Dialect/ONNX/AdditionalONNXOps.td b/src/Dialect/ONNX/AdditionalONNXOps.td index a72af4d7c2..8f4bc47650 100644 --- a/src/Dialect/ONNX/AdditionalONNXOps.td +++ b/src/Dialect/ONNX/AdditionalONNXOps.td @@ -322,15 +322,21 @@ def ONNXNoneOp : ONNX_Op<"NoValue", [ConstantLike, Pure]> { //===----------------------------------------------------------------------===// // ONNX PrintSignatureOp. def ONNXPrintSignatureOp:ONNX_Op<"PrintSignature", []> { - let summary = "ONNX Op to print type signature of its input operands"; + let summary = "ONNX Op to print type signature or data of its input operands"; let description = [{ - Print type signature of the op's input operands. This operation is introduced early - so as to preserve the name of the original ONNX op. + Print type signature or data of the input operands of this op. + The parameter op_name specifies a string to be printed before the tensors. + and usually the op_name and onnx_node_name are used. + This operation is introduced early so as to preserve the name of the original ONNX op. + The argument print_data control whether the data of the tensors to be printed. + When print_data == 1, the data of the tensor will be printed. Otherwise, just shape. + The argument input specifies the tensor to be printed. They could be a list + of the inputs and outputs of an ONNX op. This operation is not part of the standard and was added to assist onnx-mlir. }]; - let arguments = (ins StrAttr:$op_name, Variadic>:$input); + let arguments = (ins StrAttr:$op_name, SI64Attr:$print_data, Variadic>:$input); } //===----------------------------------------------------------------------===// diff --git a/src/Dialect/ONNX/Transforms/InstrumentONNXSignaturePass.cpp b/src/Dialect/ONNX/Transforms/InstrumentONNXSignaturePass.cpp index ad88bceb42..73d186abd4 100644 --- a/src/Dialect/ONNX/Transforms/InstrumentONNXSignaturePass.cpp +++ b/src/Dialect/ONNX/Transforms/InstrumentONNXSignaturePass.cpp @@ -51,13 +51,15 @@ class InstrumentONNXSignaturePass : mlir::PassWrapper>() { signaturePattern = pass.signaturePattern; + nodeNamePattern = pass.nodeNamePattern; } - InstrumentONNXSignaturePass(const std::string pattern) { - signaturePattern = pattern; - } + InstrumentONNXSignaturePass( + const std::string opPattern, const std::string nodePattern) + : signaturePattern(opPattern), nodeNamePattern(nodePattern) {} private: std::string signaturePattern; + std::string nodeNamePattern; public: StringRef getArgument() const override { @@ -71,30 +73,54 @@ class InstrumentONNXSignaturePass void runOnOperation() override { onnx_mlir::EnableByRegexOption traceSpecificOpPattern( + + /*emptyIsNone*/ false); + onnx_mlir::EnableByRegexOption traceSpecificNodePattern( /*emptyIsNone*/ false); + traceSpecificOpPattern.setRegexString(signaturePattern); + traceSpecificNodePattern.setRegexString(nodeNamePattern); // Iterate on the operations nested in this function. getOperation().walk([&](mlir::Operation *op) { - std::string opName = op->getName().getStringRef().str(); auto dialect = op->getDialect(); + Location loc = op->getLoc(); + // Define a lambda function to check whether the node is selected by + // its op name or node name, and if yes, insert ONNXSignatureOp + auto checkAndInsert = [&](onnx_mlir::EnableByRegexOption &pattern, + std::string matchString, int detail) { + if (pattern.isEnabled(matchString)) { + // Add signature printing op. + OpBuilder builder(op); + std::string opName = op->getName().getStringRef().str(); + std::string nodeName = onnx_mlir::getNodeNameInPresenceOfOpt(op); + std::string fullName = opName + ", " + nodeName; + StringAttr fullNameAttr = builder.getStringAttr(fullName); + // Enqueue all input operands, and then the results. + llvm::SmallVector operAndRes(op->getOperands()); + for (Value res : op->getResults()) + operAndRes.emplace_back(res); + // Since we may use the result of an operation, we must insert the + // print operation after the operation. + builder.setInsertionPointAfter(op); + // When one node is selected, print the details of the tensor. + builder.create( + loc, fullNameAttr, detail, operAndRes); + } + }; + if (isa(dialect) || isa(op)) { // Always skip function dialects (such as function call/return), as well // as ONNX print signature ops. - } else if (traceSpecificOpPattern.isEnabled(opName)) { - // Add signature printing op. - Location loc = op->getLoc(); - OpBuilder builder(op); - std::string nodeName = onnx_mlir::getNodeNameInPresenceOfOpt(op); - std::string fullName = opName + ", " + nodeName; - StringAttr fullNameAttr = builder.getStringAttr(fullName); - // Enqueue all input operands, and then the results. - llvm::SmallVector operAndRes(op->getOperands()); - for (Value res : op->getResults()) - operAndRes.emplace_back(res); - // Since we may use the result of an operation, we must insert the - // print operation after the operation. - builder.setInsertionPointAfter(op); - builder.create(loc, fullNameAttr, operAndRes); + } else if (signaturePattern != "NONE") { + std::string opName = op->getName().getStringRef().str(); + checkAndInsert(traceSpecificOpPattern, opName, 0); + } else if (nodeNamePattern != "NONE") { + StringAttr onnxNodeName = + op->getAttrOfType("onnx_node_name"); + if (onnxNodeName && !onnxNodeName.getValue().empty()) { + std::string nodeNameString = onnxNodeName.getValue().str(); + checkAndInsert(traceSpecificNodePattern, nodeNameString, 1); + } } }); } @@ -105,6 +131,6 @@ class InstrumentONNXSignaturePass * Create an instrumentation pass. */ std::unique_ptr onnx_mlir::createInstrumentONNXSignaturePass( - const std::string pattern) { - return std::make_unique(pattern); + const std::string pattern, const std::string nodePattern) { + return std::make_unique(pattern, nodePattern); } diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp index 070fe3d671..c696845b59 100644 --- a/src/Pass/Passes.hpp +++ b/src/Pass/Passes.hpp @@ -63,7 +63,7 @@ std::unique_ptr createInstrumentCleanupPass(); /// Passes for instrumenting the ONNX ops to print their operand type /// signatures at runtime. std::unique_ptr createInstrumentONNXSignaturePass( - const std::string pattern); + const std::string opPattern, const std::string nodePattern); /// Pass for simplifying shape-related ONNX operations. std::unique_ptr createSimplifyShapeRelatedOpsPass(); @@ -129,4 +129,4 @@ std::unique_ptr createConvertKrnlToLLVMPass(bool verifyInputTensors, std::unique_ptr createConvertONNXToTOSAPass(); } // namespace onnx_mlir -#endif \ No newline at end of file +#endif diff --git a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp index 671dea1857..37d6d8b095 100644 --- a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp +++ b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp @@ -75,7 +75,7 @@ void registerOMPasses(int optLevel) { }); mlir::registerPass([]() -> std::unique_ptr { - return createInstrumentONNXSignaturePass("NONE"); + return createInstrumentONNXSignaturePass("NONE", "NONE"); }); mlir::registerPass([]() -> std::unique_ptr {