diff --git a/docs/Dialects/onnx.md b/docs/Dialects/onnx.md
index 43339eb658..1bc9a786fe 100644
--- a/docs/Dialects/onnx.md
+++ b/docs/Dialects/onnx.md
@@ -6588,10 +6588,16 @@ Effects: `MemoryEffects::Effect{}`
### `onnx.PrintSignature` (ONNXPrintSignatureOp)
-_ONNX Op to print type signature of its input operands_
+_ONNX Op to print type signature or data of its input operands_
-Print type signature of the op's input operands. This operation is introduced early
-so as to preserve the name of the original ONNX op.
+Print the type signature or data of the input operands of this op.
+The attribute op_name specifies a string to be printed before the tensors;
+it is typically the op name combined with the onnx_node_name.
+This operation is introduced early so as to preserve the name of the original ONNX op.
+The attribute print_data controls whether the data of the tensors is printed.
+When print_data == 1, the data of each tensor is printed; otherwise only its shape.
+The argument input specifies the tensors to be printed, typically the list of
+inputs and outputs of an ONNX op.
This operation is not part of the standard and was added to assist onnx-mlir.
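+
+As an illustrative sketch only (operand names and tensor shapes below are invented,
+not taken from an actual model), printing the data of two tensors could look like:
+
+```mlir
+"onnx.PrintSignature"(%x, %y) {op_name = "onnx.MatMul, /layer1/MatMul", print_data = 1 : si64}
+    : (tensor<2x4xf32>, tensor<4x8xf32>) -> ()
+```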
@@ -6600,6 +6606,7 @@ This operation is not part of the standard and was added to assist onnx-mlir.
Attribute | MLIR Type | Description |
op_name | ::mlir::StringAttr | string attribute |
+print_data | ::mlir::IntegerAttr | 64-bit signed integer attribute |
#### Operands:
diff --git a/docs/Dialects/zlow.md b/docs/Dialects/zlow.md
index e97ece5313..90af91485c 100644
--- a/docs/Dialects/zlow.md
+++ b/docs/Dialects/zlow.md
@@ -850,7 +850,6 @@ Traits: `MemRefsNormalizable`
| Operand | Description |
| :-----: | ----------- |
| `X` | memref of dlfloat16 type values
-| `shape` | memref of 64-bit signless integer values
| `Out` | memref of dlfloat16 type values
### `zlow.sigmoid` (::onnx_mlir::zlow::ZLowSigmoidOp)
diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
index cb9e46a300..4d083366c6 100644
--- a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
+++ b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
@@ -260,8 +260,7 @@ void addPassesNNPA(mlir::OwningOpRef<mlir::ModuleOp> &module,
else if (optStr == "-O3")
optLevel = OptLevel::O3;
// Lower ONNX to Krnl, ZHigh to ZLow.
- addONNXToKrnlPasses(
- pm, optLevel, /*enableCSE*/ true, instrumentSignatures, ONNXOpStats);
+ addONNXToKrnlPasses(pm, optLevel, /*enableCSE*/ true, ONNXOpStats);
if (nnpaEmissionTarget >= EmitZLowIR)
emissionTarget = EmitMLIR;
diff --git a/src/Compiler/CompilerOptions.cpp b/src/Compiler/CompilerOptions.cpp
index 4810a47a41..45d50047c3 100644
--- a/src/Compiler/CompilerOptions.cpp
+++ b/src/Compiler/CompilerOptions.cpp
@@ -72,6 +72,7 @@ std::string instrumentOps; // onnx-mlir only
unsigned instrumentControlBits; // onnx-mlir only
std::string parallelizeOps; // onnx-mlir only
std::string instrumentSignatures; // onnx-mlir only
+std::string instrumentOnnxNode; // onnx-mlir only
std::string ONNXOpStats; // onnx-mlir only
int onnxOpTransformThreshold; // onnx-mlir only
bool onnxOpTransformReport; // onnx-mlir only
@@ -501,17 +502,44 @@ static llvm::cl::opt<std::string, true> parallelizeOpsOpt("parallelize-ops",
static llvm::cl::opt<std::string, true> instrumentSignatureOpt(
"instrument-signature",
- llvm::cl::desc("Specify which high-level operations should print their"
- " input type(s) and shape(s)\n"
- "\"ALL\" or \"\" for all available operations.\n"
- "\"NONE\" for no instrument (default).\n"
- "\"ops1,ops2, ...\" for the multiple ops.\n"
- "e.g. \"onnx.MatMul,onnx.Add\" for MatMul and Add ops.\n"
- "Asterisk is also available.\n"
- "e.g. \"onnx.*\" for all onnx operations.\n"),
+    llvm::cl::desc(
+        "Specify which high-level operations are selected for printing the\n"
+        "type and shape of their input/output tensors.\n"
+        "Operations are selected by their op name, and the value of\n"
+        "instrument-signature is the pattern used to select them.\n"
+        "\"NONE\" for no instrumentation (default).\n"
+        "\"ALL\" or \"\" for all available operations.\n"
+        "Any other value is treated as a regular expression for matching.\n"
+        "\"op1,op2,...\" selects multiple ops,\n"
+        "e.g. \"onnx.MatMul,onnx.Add\" for the MatMul and Add ops of the onnx dialect.\n"
+        "Asterisks are also supported,\n"
+        "e.g. \"onnx.*\" for all onnx operations.\n"),
llvm::cl::location(instrumentSignatures), llvm::cl::init("NONE"),
llvm::cl::cat(OnnxMlirOptions));
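+
+// Illustrative example (hypothetical node and file names, for documentation only):
+//   onnx-mlir --instrument-onnx-node="/layer1/MatMul,onnx.Add_*" model.onnx
+// inserts runtime prints of the input/output tensor data of the matching nodes.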
+static llvm::cl::opt<std::string, true> instrumentONNXNodeOpt(
+    "instrument-onnx-node",
+    llvm::cl::desc(
+        "Specify which onnx operation nodes are selected for inserting a\n"
+        "runtime call after the node that prints the data of its\n"
+        "input/output tensors.\n"
+        "Nodes are selected by their onnx node name, a string attribute that\n"
+        "is (in most cases) unique to each onnx node.\n"
+        "The node names can be found in the output of --EmitONNXIR.\n"
+        "Other instrumentation in onnx-mlir selects ops by op->getName(),\n"
+        "namely the type of onnx operation, such as Add or MatMul;\n"
+        "this option instead pinpoints a particular node.\n"
+        "The value of instrument-onnx-node is the pattern used to select nodes.\n"
+        "\"NONE\" for no instrumentation (default).\n"
+        "Any other value is treated as a regular expression for matching.\n"
+        "\"/layer1/MatMul,onnx.Add_0,...\" selects multiple nodes.\n"
+        "Asterisks are also supported, e.g. \"onnx.Add_*\" for all Add nodes.\n"
+        "A partial onnx_node_name is sufficient, as long as it is long\n"
+        "enough to be unique.\n"),
+    llvm::cl::location(instrumentOnnxNode), llvm::cl::init("NONE"),
+    llvm::cl::cat(OnnxMlirOptions));
+
static llvm::cl::opt<std::string, true> ONNXOpStatsOpt("onnx-op-stats",
llvm::cl::desc(
"Report the occurrence frequency of ONNX ops in JSON or TXT format:\n"
diff --git a/src/Compiler/CompilerOptions.hpp b/src/Compiler/CompilerOptions.hpp
index 045e3aeaaa..85c09b065d 100644
--- a/src/Compiler/CompilerOptions.hpp
+++ b/src/Compiler/CompilerOptions.hpp
@@ -118,6 +118,7 @@ extern std::string instrumentOps; // onnx-mlir only
extern unsigned instrumentControlBits; // onnx-mlir only
extern std::string parallelizeOps; // onnx-mlir only
extern std::string instrumentSignatures; // onnx-mlir only
+extern std::string instrumentOnnxNode; // onnx-mlir only
extern std::string ONNXOpStats; // onnx-mlir only
extern int onnxOpTransformThreshold; // onnx-mlir only
extern bool onnxOpTransformReport; // onnx-mlir only
diff --git a/src/Compiler/CompilerPasses.cpp b/src/Compiler/CompilerPasses.cpp
index efcb10d829..02ecde0241 100644
--- a/src/Compiler/CompilerPasses.cpp
+++ b/src/Compiler/CompilerPasses.cpp
@@ -169,10 +169,16 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
if (instrumentStage == onnx_mlir::InstrumentStages::Onnx)
    pm.addNestedPass<func::FuncOp>(
onnx_mlir::createInstrumentPass(instrumentOps, instrumentActions));
+  // Print signatures or data of each selected op at runtime if enabled. Do not
+  // run the signature pass and the instrument pass at the same time, as the
+  // measured time would include the printf overhead.
+  if (instrumentSignatures != "NONE" || instrumentOnnxNode != "NONE")
+    pm.addNestedPass<func::FuncOp>(onnx_mlir::createInstrumentONNXSignaturePass(
+        instrumentSignatures, instrumentOnnxNode));
}
void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
- std::string instrumentSignatureString, std::string ONNXOpsStatFormat) {
+ std::string ONNXOpsStatFormat) {
if (enableCSE)
// Eliminate common sub-expressions before lowering to Krnl.
// TODO: enable this by default when we make sure it works flawlessly.
@@ -196,12 +202,6 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
}
}
- // Print Signatures of each op at runtime if enabled. Should not run
- // signature and instrument passes at the same time as time may include printf
- // overheads.
- if (instrumentSignatureString != "NONE")
-    pm.addNestedPass<func::FuncOp>(onnx_mlir::createInstrumentONNXSignaturePass(
- instrumentSignatureString));
pm.addPass(onnx_mlir::createLowerToKrnlPass(/*enableTiling*/ optLevel >= 3,
/*enableSIMD*/ optLevel >= 3 && !disableSimdOption, enableParallel,
/*enableFastMath*/ optLevel >= 3 && enableFastMathOption,
@@ -325,8 +325,8 @@ void addPasses(mlir::OwningOpRef<mlir::ModuleOp> &module, mlir::PassManager &pm,
if (emissionTarget >= EmitMLIR) {
if (inputIRLevel <= ONNXLevel)
- addONNXToKrnlPasses(pm, OptimizationLevel, /*enableCSE*/ true,
- instrumentSignatures, ONNXOpStats);
+ addONNXToKrnlPasses(
+ pm, OptimizationLevel, /*enableCSE*/ true, ONNXOpStats);
if (inputIRLevel <= MLIRLevel)
addKrnlToAffinePasses(pm);
}
diff --git a/src/Compiler/CompilerPasses.hpp b/src/Compiler/CompilerPasses.hpp
index 9a6987cf19..3a564e352d 100644
--- a/src/Compiler/CompilerPasses.hpp
+++ b/src/Compiler/CompilerPasses.hpp
@@ -23,7 +23,7 @@ void configurePasses();
void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
bool donotScrubDisposableElementsAttr = false);
void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
- std::string instrumentSignatureString, std::string ONNXOpsStatFilename);
+ std::string ONNXOpsStatFilename);
void addKrnlToAffinePasses(mlir::PassManager &pm);
void addKrnlToLLVMPasses(
mlir::OpPassManager &pm, std::string outputNameNoExt, bool enableCSE);
diff --git a/src/Conversion/ONNXToKrnl/Tensor/PrintSignature.cpp b/src/Conversion/ONNXToKrnl/Tensor/PrintSignature.cpp
index df220ee303..4c73a86c85 100644
--- a/src/Conversion/ONNXToKrnl/Tensor/PrintSignature.cpp
+++ b/src/Conversion/ONNXToKrnl/Tensor/PrintSignature.cpp
@@ -50,15 +50,26 @@ struct ONNXPrintSignatureLowering
op, msg + "(no tensors)\n%e", noneVal);
return success();
}
+    // Control how the tensors are printed.
+    // By default, print only the type signature (shape).
+    std::string printControl = ", %t%e";
+    if (printSignatureOp.getPrintData() == 1) {
+      // Print the data of the tensors as well.
+      printControl = "%t%d\n";
+      msg += "\n";
+    }
Value lastVal = printVal.pop_back_val();
// Print all but the last one.
for (Value oper : printVal) {
- create.krnl.printTensor(msg + ", %t%e", oper);
+ create.krnl.printTensor(msg + printControl, oper);
msg = "%i";
}
    // Print the last one, replacing the original op with the new print op.
+ if (printSignatureOp.getPrintData() == 0) {
+ printControl = ", %t\n%e";
+ }
    rewriter.replaceOpWithNewOp<KrnlPrintTensorOp>(
- op, msg + ", %t\n%e", lastVal);
+ op, msg + printControl, lastVal);
return success();
}
};
diff --git a/src/Dialect/ONNX/AdditionalONNXOps.td b/src/Dialect/ONNX/AdditionalONNXOps.td
index a72af4d7c2..8f4bc47650 100644
--- a/src/Dialect/ONNX/AdditionalONNXOps.td
+++ b/src/Dialect/ONNX/AdditionalONNXOps.td
@@ -322,15 +322,21 @@ def ONNXNoneOp : ONNX_Op<"NoValue", [ConstantLike, Pure]> {
//===----------------------------------------------------------------------===//
// ONNX PrintSignatureOp.
def ONNXPrintSignatureOp:ONNX_Op<"PrintSignature", []> {
- let summary = "ONNX Op to print type signature of its input operands";
+ let summary = "ONNX Op to print type signature or data of its input operands";
let description = [{
- Print type signature of the op's input operands. This operation is introduced early
- so as to preserve the name of the original ONNX op.
+  Print the type signature or data of the input operands of this op.
+  The attribute op_name specifies a string to be printed before the tensors;
+  it is typically the op name combined with the onnx_node_name.
+  This operation is introduced early so as to preserve the name of the original ONNX op.
+  The attribute print_data controls whether the data of the tensors is printed.
+  When print_data == 1, the data of each tensor is printed; otherwise only its shape.
+  The argument input specifies the tensors to be printed, typically the list of
+  inputs and outputs of an ONNX op.
This operation is not part of the standard and was added to assist onnx-mlir.
}];
-  let arguments = (ins StrAttr:$op_name, Variadic<AnyTypeOf<[AnyTensor, NoneType]>>:$input);
+  let arguments = (ins StrAttr:$op_name, SI64Attr:$print_data, Variadic<AnyTypeOf<[AnyTensor, NoneType]>>:$input);
}
//===----------------------------------------------------------------------===//
diff --git a/src/Dialect/ONNX/Transforms/InstrumentONNXSignaturePass.cpp b/src/Dialect/ONNX/Transforms/InstrumentONNXSignaturePass.cpp
index ad88bceb42..73d186abd4 100644
--- a/src/Dialect/ONNX/Transforms/InstrumentONNXSignaturePass.cpp
+++ b/src/Dialect/ONNX/Transforms/InstrumentONNXSignaturePass.cpp
@@ -51,13 +51,15 @@ class InstrumentONNXSignaturePass
      : mlir::PassWrapper<InstrumentONNXSignaturePass, OperationPass<func::FuncOp>>() {
signaturePattern = pass.signaturePattern;
+ nodeNamePattern = pass.nodeNamePattern;
}
- InstrumentONNXSignaturePass(const std::string pattern) {
- signaturePattern = pattern;
- }
+ InstrumentONNXSignaturePass(
+ const std::string opPattern, const std::string nodePattern)
+ : signaturePattern(opPattern), nodeNamePattern(nodePattern) {}
private:
std::string signaturePattern;
+ std::string nodeNamePattern;
public:
StringRef getArgument() const override {
@@ -71,30 +73,54 @@ class InstrumentONNXSignaturePass
void runOnOperation() override {
onnx_mlir::EnableByRegexOption traceSpecificOpPattern(
+ /*emptyIsNone*/ false);
+ onnx_mlir::EnableByRegexOption traceSpecificNodePattern(
/*emptyIsNone*/ false);
+
traceSpecificOpPattern.setRegexString(signaturePattern);
+ traceSpecificNodePattern.setRegexString(nodeNamePattern);
// Iterate on the operations nested in this function.
getOperation().walk([&](mlir::Operation *op) {
- std::string opName = op->getName().getStringRef().str();
auto dialect = op->getDialect();
+ Location loc = op->getLoc();
+      // Lambda that checks whether the node is selected by its op name or
+      // node name and, if so, inserts an ONNXPrintSignatureOp after it.
+ auto checkAndInsert = [&](onnx_mlir::EnableByRegexOption &pattern,
+ std::string matchString, int detail) {
+ if (pattern.isEnabled(matchString)) {
+ // Add signature printing op.
+ OpBuilder builder(op);
+ std::string opName = op->getName().getStringRef().str();
+ std::string nodeName = onnx_mlir::getNodeNameInPresenceOfOpt(op);
+ std::string fullName = opName + ", " + nodeName;
+ StringAttr fullNameAttr = builder.getStringAttr(fullName);
+ // Enqueue all input operands, and then the results.
+          llvm::SmallVector<Value, 4> operAndRes(op->getOperands());
+ for (Value res : op->getResults())
+ operAndRes.emplace_back(res);
+ // Since we may use the result of an operation, we must insert the
+ // print operation after the operation.
+ builder.setInsertionPointAfter(op);
+          // The detail flag is passed as print_data: 1 prints the tensor
+          // data, 0 prints only the type signature.
+          builder.create<ONNXPrintSignatureOp>(
+              loc, fullNameAttr, detail, operAndRes);
+ }
+ };
+
      if (isa<func::FuncDialect>(dialect) || isa<ONNXPrintSignatureOp>(op)) {
// Always skip function dialects (such as function call/return), as well
// as ONNX print signature ops.
- } else if (traceSpecificOpPattern.isEnabled(opName)) {
- // Add signature printing op.
- Location loc = op->getLoc();
- OpBuilder builder(op);
- std::string nodeName = onnx_mlir::getNodeNameInPresenceOfOpt(op);
- std::string fullName = opName + ", " + nodeName;
- StringAttr fullNameAttr = builder.getStringAttr(fullName);
- // Enqueue all input operands, and then the results.
-      llvm::SmallVector<Value, 4> operAndRes(op->getOperands());
- for (Value res : op->getResults())
- operAndRes.emplace_back(res);
- // Since we may use the result of an operation, we must insert the
- // print operation after the operation.
- builder.setInsertionPointAfter(op);
-      builder.create<ONNXPrintSignatureOp>(loc, fullNameAttr, operAndRes);
+ } else if (signaturePattern != "NONE") {
+ std::string opName = op->getName().getStringRef().str();
+ checkAndInsert(traceSpecificOpPattern, opName, 0);
+ } else if (nodeNamePattern != "NONE") {
+ StringAttr onnxNodeName =
+          op->getAttrOfType<StringAttr>("onnx_node_name");
+ if (onnxNodeName && !onnxNodeName.getValue().empty()) {
+ std::string nodeNameString = onnxNodeName.getValue().str();
+ checkAndInsert(traceSpecificNodePattern, nodeNameString, 1);
+ }
}
});
}
@@ -105,6 +131,6 @@ class InstrumentONNXSignaturePass
* Create an instrumentation pass.
*/
std::unique_ptr<mlir::Pass> onnx_mlir::createInstrumentONNXSignaturePass(
- const std::string pattern) {
-  return std::make_unique<InstrumentONNXSignaturePass>(pattern);
+ const std::string pattern, const std::string nodePattern) {
+  return std::make_unique<InstrumentONNXSignaturePass>(pattern, nodePattern);
}
diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp
index 070fe3d671..c696845b59 100644
--- a/src/Pass/Passes.hpp
+++ b/src/Pass/Passes.hpp
@@ -63,7 +63,7 @@ std::unique_ptr<mlir::Pass> createInstrumentCleanupPass();
/// Passes for instrumenting the ONNX ops to print their operand type
/// signatures at runtime.
std::unique_ptr<mlir::Pass> createInstrumentONNXSignaturePass(
- const std::string pattern);
+ const std::string opPattern, const std::string nodePattern);
/// Pass for simplifying shape-related ONNX operations.
std::unique_ptr<mlir::Pass> createSimplifyShapeRelatedOpsPass();
@@ -129,4 +129,4 @@ std::unique_ptr<mlir::Pass> createConvertKrnlToLLVMPass(bool verifyInputTensors,
std::unique_ptr<mlir::Pass> createConvertONNXToTOSAPass();
} // namespace onnx_mlir
-#endif
\ No newline at end of file
+#endif
diff --git a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp
index 671dea1857..37d6d8b095 100644
--- a/src/Tools/onnx-mlir-opt/RegisterPasses.cpp
+++ b/src/Tools/onnx-mlir-opt/RegisterPasses.cpp
@@ -75,7 +75,7 @@ void registerOMPasses(int optLevel) {
});
  mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
- return createInstrumentONNXSignaturePass("NONE");
+ return createInstrumentONNXSignaturePass("NONE", "NONE");
});
  mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {