Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend instrumentSignature to print data #3078

Merged
merged 9 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions docs/Dialects/onnx.md
Original file line number Diff line number Diff line change
Expand Up @@ -6588,10 +6588,16 @@ Effects: `MemoryEffects::Effect{}`

### `onnx.PrintSignature` (ONNXPrintSignatureOp)

_ONNX Op to print type signature of its input operands_
_ONNX Op to print type signature or data of its input operands_

Print type signature of the op's input operands. This operation is introduced early
so as to preserve the name of the original ONNX op.
Print type signature or data of the input operands of this op.
The parameter op_name specifies a string to be printed before the tensors.
and usually the op_name and onnx_node_name are used.
This operation is introduced early so as to preserve the name of the original ONNX op.
The argument print_data control whether the data of the tensors to be printed.
When print_data == 1, the data of the tensor will be printed. Otherwise, just shape.
The argument input specifies the tensor to be printed. They could be a list
of the inputs and outputs of an ONNX op.

This operation is not part of the standard and was added to assist onnx-mlir.

Expand All @@ -6600,6 +6606,7 @@ This operation is not part of the standard and was added to assist onnx-mlir.
<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>op_name</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
<tr><td><code>print_data</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
</table>

#### Operands:
Expand Down
1 change: 0 additions & 1 deletion docs/Dialects/zlow.md
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,6 @@ Traits: `MemRefsNormalizable`
| Operand | Description |
| :-----: | ----------- |
| `X` | memref of dlfloat16 type values
| `shape` | memref of 64-bit signless integer values
| `Out` | memref of dlfloat16 type values

### `zlow.sigmoid` (::onnx_mlir::zlow::ZLowSigmoidOp)
Expand Down
3 changes: 1 addition & 2 deletions src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,7 @@ void addPassesNNPA(mlir::OwningOpRef<mlir::ModuleOp> &module,
else if (optStr == "-O3")
optLevel = OptLevel::O3;
// Lower ONNX to Krnl, ZHigh to ZLow.
addONNXToKrnlPasses(
pm, optLevel, /*enableCSE*/ true, instrumentSignatures, ONNXOpStats);
addONNXToKrnlPasses(pm, optLevel, /*enableCSE*/ true, ONNXOpStats);

if (nnpaEmissionTarget >= EmitZLowIR)
emissionTarget = EmitMLIR;
Expand Down
44 changes: 36 additions & 8 deletions src/Compiler/CompilerOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ std::string instrumentOps; // onnx-mlir only
unsigned instrumentControlBits; // onnx-mlir only
std::string parallelizeOps; // onnx-mlir only
std::string instrumentSignatures; // onnx-mlir only
std::string instrumentOnnxNode; // onnx-mlir only
std::string ONNXOpStats; // onnx-mlir only
int onnxOpTransformThreshold; // onnx-mlir only
bool onnxOpTransformReport; // onnx-mlir only
Expand Down Expand Up @@ -501,17 +502,44 @@ static llvm::cl::opt<std::string, true> parallelizeOpsOpt("parallelize-ops",

static llvm::cl::opt<std::string, true> instrumentSignatureOpt(
"instrument-signature",
llvm::cl::desc("Specify which high-level operations should print their"
" input type(s) and shape(s)\n"
"\"ALL\" or \"\" for all available operations.\n"
"\"NONE\" for no instrument (default).\n"
"\"ops1,ops2, ...\" for the multiple ops.\n"
"e.g. \"onnx.MatMul,onnx.Add\" for MatMul and Add ops.\n"
"Asterisk is also available.\n"
"e.g. \"onnx.*\" for all onnx operations.\n"),
llvm::cl::desc(
"Specify which high-level operations should be selected for printing\n"
"the type and shape of their input/output tensors.\n"
"The ops are selected by their op name.\n"
"The instrument-signature defines the pattern to select the ops.\n"
"\"NONE\" for no instrument (default).\n"
"\"ALL\" or \"\" for all available operations.\n"
"Except for the special values, the regexp is used for matching.\n"
"\"ops1,ops2, ...\" for the multiple ops.\n"
"e.g. \"onnx.MatMul,onnx.Add\" for MatMul and Add op in onnx dialect.\n"
"Asterisk is also available.\n"
"e.g. \"onnx.*\" for all onnx operations.\n"),
llvm::cl::location(instrumentSignatures), llvm::cl::init("NONE"),
llvm::cl::cat(OnnxMlirOptions));

static llvm::cl::opt<std::string, true> instrumentONNXNodeOpt(
"instrument-onnx-node",
llvm::cl::desc(
"Specify which onnx operation node will be selected for \n"
"inserting a runtime call after the node to print the data of\n"
"their input/output tensors.\n"
"The ops are selected by their onnx node name, which is a string\n"
"attribute unique to each onnx node (most of time).\n"
"You can find them in the output of --EmitONNXIR\n"
"Other instrumentation in onnx-mlir is specified by op->getName(),\n"
"namely, the type of onnx operation, such Add, Matmul, and etc.\n"
"This option is able to pinpoint to a particular node.\n"
"The instrument-onnx-node defines the pattern to select.\n"
"\"NONE\" for no instrument (default).\n"
"Except for the special values, the regexp is used for matching.\n"
"\"/layer1/MatMul, onnx.Add_0, ...\" for the multiple nodes.\n"
"Asterisk is also available. For example:\n"
"\"onnx.Add_*\" for all AddOp. This feature allows you to specify\n"
"part of the target of onnx_node_name, as long as it is long enough\n"
"to be unique.\n"),
llvm::cl::location(instrumentOnnxNode), llvm::cl::init("NONE"),
llvm::cl::cat(OnnxMlirOptions));

static llvm::cl::opt<std::string, true> ONNXOpStatsOpt("onnx-op-stats",
llvm::cl::desc(
"Report the occurrence frequency of ONNX ops in JSON or TXT format:\n"
Expand Down
1 change: 1 addition & 0 deletions src/Compiler/CompilerOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ extern std::string instrumentOps; // onnx-mlir only
extern unsigned instrumentControlBits; // onnx-mlir only
extern std::string parallelizeOps; // onnx-mlir only
extern std::string instrumentSignatures; // onnx-mlir only
extern std::string instrumentOnnxNode; // onnx-mlir only
extern std::string ONNXOpStats; // onnx-mlir only
extern int onnxOpTransformThreshold; // onnx-mlir only
extern bool onnxOpTransformReport; // onnx-mlir only
Expand Down
18 changes: 9 additions & 9 deletions src/Compiler/CompilerPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,16 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
if (instrumentStage == onnx_mlir::InstrumentStages::Onnx)
pm.addNestedPass<func::FuncOp>(
onnx_mlir::createInstrumentPass(instrumentOps, instrumentActions));
// Print Signatures of each op at runtime if enabled. Should not run
// signature and instrument passes at the same time as time may include printf
// overheads.
if (instrumentSignatures != "NONE" || instrumentOnnxNode != "NONE")
pm.addNestedPass<func::FuncOp>(onnx_mlir::createInstrumentONNXSignaturePass(
instrumentSignatures, instrumentOnnxNode));
}

void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
std::string instrumentSignatureString, std::string ONNXOpsStatFormat) {
std::string ONNXOpsStatFormat) {
if (enableCSE)
// Eliminate common sub-expressions before lowering to Krnl.
// TODO: enable this by default when we make sure it works flawlessly.
Expand All @@ -196,12 +202,6 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
}
}

// Print Signatures of each op at runtime if enabled. Should not run
// signature and instrument passes at the same time as time may include printf
// overheads.
if (instrumentSignatureString != "NONE")
pm.addNestedPass<func::FuncOp>(onnx_mlir::createInstrumentONNXSignaturePass(
instrumentSignatureString));
pm.addPass(onnx_mlir::createLowerToKrnlPass(/*enableTiling*/ optLevel >= 3,
/*enableSIMD*/ optLevel >= 3 && !disableSimdOption, enableParallel,
/*enableFastMath*/ optLevel >= 3 && enableFastMathOption,
Expand Down Expand Up @@ -325,8 +325,8 @@ void addPasses(mlir::OwningOpRef<ModuleOp> &module, mlir::PassManager &pm,

if (emissionTarget >= EmitMLIR) {
if (inputIRLevel <= ONNXLevel)
addONNXToKrnlPasses(pm, OptimizationLevel, /*enableCSE*/ true,
instrumentSignatures, ONNXOpStats);
addONNXToKrnlPasses(
pm, OptimizationLevel, /*enableCSE*/ true, ONNXOpStats);
if (inputIRLevel <= MLIRLevel)
addKrnlToAffinePasses(pm);
}
Expand Down
2 changes: 1 addition & 1 deletion src/Compiler/CompilerPasses.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ void configurePasses();
void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
bool donotScrubDisposableElementsAttr = false);
void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
std::string instrumentSignatureString, std::string ONNXOpsStatFilename);
std::string ONNXOpsStatFilename);
void addKrnlToAffinePasses(mlir::PassManager &pm);
void addKrnlToLLVMPasses(
mlir::OpPassManager &pm, std::string outputNameNoExt, bool enableCSE);
Expand Down
11 changes: 9 additions & 2 deletions src/Conversion/ONNXToKrnl/Tensor/PrintSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,22 @@ struct ONNXPrintSignatureLowering
op, msg + "(no tensors)\n%e", noneVal);
return success();
}
// Control how the tensor will be printed
// Print the only the shape.
std::string printControl = "%e";
if (printSignatureOp.getPrintData() == 1) {
// The data of tensor will be printed
printControl = "%d%e";
}
Value lastVal = printVal.pop_back_val();
// Print all but the last one.
for (Value oper : printVal) {
create.krnl.printTensor(msg + ", %t%e", oper);
create.krnl.printTensor(msg + ", %t" + printControl, oper);
msg = "%i";
}
// Print the last one with replace with new op.
rewriter.replaceOpWithNewOp<KrnlPrintTensorOp>(
op, msg + ", %t\n%e", lastVal);
op, msg + ", %t\n" + printControl, lastVal);
return success();
}
};
Expand Down
14 changes: 10 additions & 4 deletions src/Dialect/ONNX/AdditionalONNXOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -322,15 +322,21 @@ def ONNXNoneOp : ONNX_Op<"NoValue", [ConstantLike, Pure]> {
//===----------------------------------------------------------------------===//
// ONNX PrintSignatureOp.
def ONNXPrintSignatureOp:ONNX_Op<"PrintSignature", []> {
let summary = "ONNX Op to print type signature of its input operands";
let summary = "ONNX Op to print type signature or data of its input operands";
let description = [{
Print type signature of the op's input operands. This operation is introduced early
so as to preserve the name of the original ONNX op.
Print type signature or data of the input operands of this op.
The parameter op_name specifies a string to be printed before the tensors.
and usually the op_name and onnx_node_name are used.
This operation is introduced early so as to preserve the name of the original ONNX op.
The argument print_data control whether the data of the tensors to be printed.
When print_data == 1, the data of the tensor will be printed. Otherwise, just shape.
The argument input specifies the tensor to be printed. They could be a list
of the inputs and outputs of an ONNX op.

This operation is not part of the standard and was added to assist onnx-mlir.
}];

let arguments = (ins StrAttr:$op_name, Variadic<AnyTypeOf<[AnyTensor, NoneType]>>:$input);
let arguments = (ins StrAttr:$op_name, SI64Attr:$print_data, Variadic<AnyTypeOf<[AnyTensor, NoneType]>>:$input);
}

//===----------------------------------------------------------------------===//
Expand Down
68 changes: 47 additions & 21 deletions src/Dialect/ONNX/Transforms/InstrumentONNXSignaturePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,15 @@ class InstrumentONNXSignaturePass
: mlir::PassWrapper<InstrumentONNXSignaturePass,
OperationPass<func::FuncOp>>() {
signaturePattern = pass.signaturePattern;
nodeNamePattern = pass.nodeNamePattern;
}
InstrumentONNXSignaturePass(const std::string pattern) {
signaturePattern = pattern;
}
InstrumentONNXSignaturePass(
const std::string opPattern, const std::string nodePattern)
: signaturePattern(opPattern), nodeNamePattern(nodePattern) {}

private:
std::string signaturePattern;
std::string nodeNamePattern;

public:
StringRef getArgument() const override {
Expand All @@ -71,30 +73,54 @@ class InstrumentONNXSignaturePass

void runOnOperation() override {
onnx_mlir::EnableByRegexOption traceSpecificOpPattern(

/*emptyIsNone*/ false);
onnx_mlir::EnableByRegexOption traceSpecificNodePattern(
/*emptyIsNone*/ false);

traceSpecificOpPattern.setRegexString(signaturePattern);
traceSpecificNodePattern.setRegexString(nodeNamePattern);
// Iterate on the operations nested in this function.
getOperation().walk([&](mlir::Operation *op) {
std::string opName = op->getName().getStringRef().str();
auto dialect = op->getDialect();
Location loc = op->getLoc();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did not check the logic in great detail, I assume that you used a small example mlir file and tried both the signature and the new functionality and that you were happy with the outputs.

// Define a lambda function to check whether the node is selected by
// its op name or node name, and if yes, insert ONNXSignatureOp
auto checkAndInsert = [&](onnx_mlir::EnableByRegexOption &pattern,
std::string matchString, int detail) {
if (pattern.isEnabled(matchString)) {
// Add signature printing op.
OpBuilder builder(op);
std::string opName = op->getName().getStringRef().str();
std::string nodeName = onnx_mlir::getNodeNameInPresenceOfOpt(op);
std::string fullName = opName + ", " + nodeName;
StringAttr fullNameAttr = builder.getStringAttr(fullName);
// Enqueue all input operands, and then the results.
llvm::SmallVector<Value, 6> operAndRes(op->getOperands());
for (Value res : op->getResults())
operAndRes.emplace_back(res);
// Since we may use the result of an operation, we must insert the
// print operation after the operation.
builder.setInsertionPointAfter(op);
// When one node is selected, print the details of the tensor.
builder.create<ONNXPrintSignatureOp>(
loc, fullNameAttr, detail, operAndRes);
}
};

if (isa<func::FuncDialect>(dialect) || isa<ONNXPrintSignatureOp>(op)) {
// Always skip function dialects (such as function call/return), as well
// as ONNX print signature ops.
} else if (traceSpecificOpPattern.isEnabled(opName)) {
// Add signature printing op.
Location loc = op->getLoc();
OpBuilder builder(op);
std::string nodeName = onnx_mlir::getNodeNameInPresenceOfOpt(op);
std::string fullName = opName + ", " + nodeName;
StringAttr fullNameAttr = builder.getStringAttr(fullName);
// Enqueue all input operands, and then the results.
llvm::SmallVector<Value, 6> operAndRes(op->getOperands());
for (Value res : op->getResults())
operAndRes.emplace_back(res);
// Since we may use the result of an operation, we must insert the
// print operation after the operation.
builder.setInsertionPointAfter(op);
builder.create<ONNXPrintSignatureOp>(loc, fullNameAttr, operAndRes);
} else if (signaturePattern != "NONE") {
std::string opName = op->getName().getStringRef().str();
checkAndInsert(traceSpecificOpPattern, opName, 0);
} else if (nodeNamePattern != "NONE") {
StringAttr onnxNodeName =
op->getAttrOfType<mlir::StringAttr>("onnx_node_name");
if (onnxNodeName && !onnxNodeName.getValue().empty()) {
std::string nodeNameString = onnxNodeName.getValue().str();
checkAndInsert(traceSpecificNodePattern, nodeNameString, 1);
}
}
});
}
Expand All @@ -105,6 +131,6 @@ class InstrumentONNXSignaturePass
* Create an instrumentation pass.
*/
std::unique_ptr<mlir::Pass> onnx_mlir::createInstrumentONNXSignaturePass(
const std::string pattern) {
return std::make_unique<InstrumentONNXSignaturePass>(pattern);
const std::string pattern, const std::string nodePattern) {
return std::make_unique<InstrumentONNXSignaturePass>(pattern, nodePattern);
}
4 changes: 2 additions & 2 deletions src/Pass/Passes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ std::unique_ptr<mlir::Pass> createInstrumentCleanupPass();
/// Passes for instrumenting the ONNX ops to print their operand type
/// signatures at runtime.
std::unique_ptr<mlir::Pass> createInstrumentONNXSignaturePass(
const std::string pattern);
const std::string opPattern, const std::string nodePattern);

/// Pass for simplifying shape-related ONNX operations.
std::unique_ptr<mlir::Pass> createSimplifyShapeRelatedOpsPass();
Expand Down Expand Up @@ -129,4 +129,4 @@ std::unique_ptr<mlir::Pass> createConvertKrnlToLLVMPass(bool verifyInputTensors,
std::unique_ptr<mlir::Pass> createConvertONNXToTOSAPass();

} // namespace onnx_mlir
#endif
#endif
Loading