Extend instrumentSignature to print data #3078

Merged: 9 commits, Feb 20, 2025
Changes from 4 commits
3 changes: 1 addition & 2 deletions src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
@@ -260,8 +260,7 @@ void addPassesNNPA(mlir::OwningOpRef<mlir::ModuleOp> &module,
   else if (optStr == "-O3")
     optLevel = OptLevel::O3;
   // Lower ONNX to Krnl, ZHigh to ZLow.
-  addONNXToKrnlPasses(
-      pm, optLevel, /*enableCSE*/ true, instrumentSignatures, ONNXOpStats);
+  addONNXToKrnlPasses(pm, optLevel, /*enableCSE*/ true, ONNXOpStats);
 
   if (nnpaEmissionTarget >= EmitZLowIR)
     emissionTarget = EmitMLIR;
20 changes: 12 additions & 8 deletions src/Compiler/CompilerOptions.cpp
@@ -501,14 +501,18 @@ static llvm::cl::opt<std::string, true> parallelizeOpsOpt("parallelize-ops",
 
 static llvm::cl::opt<std::string, true> instrumentSignatureOpt(
     "instrument-signature",
-    llvm::cl::desc("Specify which high-level operations should print their"
-                   " input type(s) and shape(s)\n"
-                   "\"ALL\" or \"\" for all available operations.\n"
-                   "\"NONE\" for no instrument (default).\n"
-                   "\"ops1,ops2, ...\" for the multiple ops.\n"
-                   "e.g. \"onnx.MatMul,onnx.Add\" for MatMul and Add ops.\n"
-                   "Asterisk is also available.\n"
-                   "e.g. \"onnx.*\" for all onnx operations.\n"),
+    llvm::cl::desc(
+        "Specify which high-level operations should print their"
+        " input type(s) and shape(s)\n"
+        "\"ALL\" or \"\" for all available operations.\n"
+        "\"NONE\" for no instrument (default).\n"
+        "\"ops1,ops2, ...\" for the multiple ops.\n"
+        "e.g. \"onnx.MatMul,onnx.Add\" for MatMul and Add ops.\n"
+        "Asterisk is also available.\n"
+        "e.g. \"onnx.*\" for all onnx operations.\n"
+        "If this option is started with \"onnx_node_name\"\n"
+        "the attribute of \"onnx_node_name\", instead of the op name\n"

Collaborator:
missing the : in the pattern, as you are looking for this

std::string header = "onnx_node_name:";

Also you state that "the data values" will be printed. Is that the input data values, the output data values? Please specify.

Collaborator Author:
Fixed.

"will be used to match the op, and the data value will be printed.\n"),
llvm::cl::location(instrumentSignatures), llvm::cl::init("NONE"),
llvm::cl::cat(OnnxMlirOptions));

Expand Down
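
Based on the new description, typical invocations would look roughly like the following; the model file and the node name are made-up examples, not taken from this PR:

    onnx-mlir --instrument-signature="onnx.MatMul,onnx.Add" model.onnx
    onnx-mlir --instrument-signature="onnx_node_name:encoder/layer0/MatMul_1" model.onnx

The first form prints input types and shapes for the matching ops; the second matches ops by their onnx_node_name attribute and, per this PR, also prints the tensor data values.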
20 changes: 10 additions & 10 deletions src/Compiler/CompilerPasses.cpp
@@ -169,10 +169,16 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
   if (instrumentStage == onnx_mlir::InstrumentStages::Onnx)
     pm.addNestedPass<func::FuncOp>(
         onnx_mlir::createInstrumentPass(instrumentOps, instrumentActions));
+  // Print Signatures of each op at runtime if enabled. Should not run
+  // signature and instrument passes at the same time as time may include printf
+  // overheads.
+  if (instrumentSignatures != "NONE")
+    pm.addNestedPass<func::FuncOp>(
+        onnx_mlir::createInstrumentONNXSignaturePass(instrumentSignatures));
 }
 
 void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
-    std::string instrumentSignatureString, std::string ONNXOpsStatFormat) {
+    std::string ONNXOpsStatFormat) {
   if (enableCSE)
     // Eliminate common sub-expressions before lowering to Krnl.
     // TODO: enable this by default when we make sure it works flawlessly.
@@ -196,12 +202,6 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
     }
   }
 
-  // Print Signatures of each op at runtime if enabled. Should not run
-  // signature and instrument passes at the same time as time may include printf
-  // overheads.
-  if (instrumentSignatureString != "NONE")
-    pm.addNestedPass<func::FuncOp>(onnx_mlir::createInstrumentONNXSignaturePass(
-        instrumentSignatureString));
   pm.addPass(onnx_mlir::createLowerToKrnlPass(/*enableTiling*/ optLevel >= 3,
       /*enableSIMD*/ optLevel >= 3 && !disableSimdOption, enableParallel,
       /*enableFastMath*/ optLevel >= 3 && enableFastMathOption,
@@ -214,7 +214,7 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
 
 void addKrnlToAffinePasses(mlir::PassManager &pm) {
   pm.addNestedPass<func::FuncOp>(
-      onnx_mlir::krnl::createConvertKrnlToAffinePass(enableParallel));

Collaborator:
You lost some changes from dev main. Please add them back.

Collaborator Author:
Fixed.

+      onnx_mlir::krnl::createConvertKrnlToAffinePass());
 }
 
 void addKrnlToLLVMPasses(
@@ -325,8 +325,8 @@ void addPasses(mlir::OwningOpRef<ModuleOp> &module, mlir::PassManager &pm,
 
   if (emissionTarget >= EmitMLIR) {
     if (inputIRLevel <= ONNXLevel)
-      addONNXToKrnlPasses(pm, OptimizationLevel, /*enableCSE*/ true,
-          instrumentSignatures, ONNXOpStats);
+      addONNXToKrnlPasses(
+          pm, OptimizationLevel, /*enableCSE*/ true, ONNXOpStats);
     if (inputIRLevel <= MLIRLevel)
       addKrnlToAffinePasses(pm);
   }
2 changes: 1 addition & 1 deletion src/Compiler/CompilerPasses.hpp
@@ -23,7 +23,7 @@ void configurePasses();
 void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
     bool donotScrubDisposableElementsAttr = false);
 void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
-    std::string instrumentSignatureString, std::string ONNXOpsStatFilename);
+    std::string ONNXOpsStatFilename);
 void addKrnlToAffinePasses(mlir::PassManager &pm);
 void addKrnlToLLVMPasses(
     mlir::OpPassManager &pm, std::string outputNameNoExt, bool enableCSE);
11 changes: 9 additions & 2 deletions src/Conversion/ONNXToKrnl/Tensor/PrintSignature.cpp
@@ -50,15 +50,22 @@ struct ONNXPrintSignatureLowering
           op, msg + "(no tensors)\n%e", noneVal);
       return success();
     }
+    // Control how the tensor will be printed.
+    // Print only the shape.
+    std::string printControl = "%e";
+    if (printSignatureOp.getPrintData() == 1) {
+      // The data of the tensor will be printed.
+      printControl = "%d";
+    }
     Value lastVal = printVal.pop_back_val();
     // Print all but the last one.
     for (Value oper : printVal) {
-      create.krnl.printTensor(msg + ", %t%e", oper);
+      create.krnl.printTensor(msg + ", %t" + printControl, oper);
       msg = "%i";
     }
     // Print the last one with replace with new op.
     rewriter.replaceOpWithNewOp<KrnlPrintTensorOp>(
-        op, msg + ", %t\n%e", lastVal);
+        op, msg + ", %t\n" + printControl, lastVal);
     return success();
   }
 };
12 changes: 8 additions & 4 deletions src/Dialect/ONNX/AdditionalONNXOps.td
@@ -322,15 +322,19 @@ def ONNXNoneOp : ONNX_Op<"NoValue", [ConstantLike, Pure]> {
 //===----------------------------------------------------------------------===//
 // ONNX PrintSignatureOp.
 def ONNXPrintSignatureOp:ONNX_Op<"PrintSignature", []> {
-  let summary = "ONNX Op to print type signature of its input operands";
+  let summary = "ONNX Op to print type signature or data of its input operands";
   let description = [{
-  Print type signature of the op's input operands. This operation is introduced early
-  so as to preserve the name of the original ONNX op.
+  Print type signature of the op's input operands.
+  The parameter op_name specifies a string to be printed with the output; usually the op name and the onnx_node_name are used.
+  This operation is introduced early so as to preserve the name of the original ONNX op.
+  The argument print_data controls whether the signature or the data of the tensors is printed.
+  When print_data == 1, the data of the tensor will be printed. Otherwise, the shape of the tensor will be printed.
+  The argument input specifies all the tensors to be printed.
 
   This operation is not part of the standard and was added to assist onnx-mlir.
   }];
 
-  let arguments = (ins StrAttr:$op_name, Variadic<AnyTypeOf<[AnyTensor, NoneType]>>:$input);
+  let arguments = (ins StrAttr:$op_name, SI64Attr:$print_data, Variadic<AnyTypeOf<[AnyTensor, NoneType]>>:$input);
 }
 
 //===----------------------------------------------------------------------===//
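
For illustration only (not code from this PR), an instance of the extended op in generic MLIR form could look like the line below; the attribute values and tensor type are made up:

    "onnx.PrintSignature"(%0) {op_name = "onnx.MatMul, matmul_1, ", print_data = 1 : si64} : (tensor<2x4xf32>) -> ()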
25 changes: 23 additions & 2 deletions src/Dialect/ONNX/Transforms/InstrumentONNXSignaturePass.cpp
@@ -72,15 +72,34 @@ class InstrumentONNXSignaturePass
   void runOnOperation() override {
     onnx_mlir::EnableByRegexOption traceSpecificOpPattern(
         /*emptyIsNone*/ false);
+    // If the signaturePattern is started with "onnx_node_name: ",
+    // the node will be selected by the string attribute of onnx_node_name.
+    std::string header = "onnx_node_name:";
+    std::cout << signaturePattern << "\n";
+    bool useNodeName = false;
+    if (signaturePattern.rfind(header, 0) == 0) {

Collaborator:
Can you explain the behavior here? If there is one match of "onnx_node_name", do we then use the node name for all ops?

Collaborator Author:
When the option starts with onnx_node_name:, we use attribute("onnx_node_name"), instead of op->getName(), to check all the op.
I am not sure what you mean by node name. I modified CompilerOption.cpp for this.

Collaborator:
I believe the code is not doing what you think it is doing.

If the pattern passed is "onnx.Add, onnx_node_name:one_specific_op", then the rfind will return true and thus useNodeName is set to true. Then for every op in the walk, we will fetch the data from the onnx_node_name for all.

I think what you need to do is:

  1. use the old code to match with an op. if there is a match, then do the old print.
  2. else search the onnx_node_name attribute, if non null, prefix the attribute with "onnx_node_name:", search. If hit, then do the new print including the data.
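
A minimal sketch of that two-step matching order (the variable names and the printData flag are illustrative, not code from this PR):

    std::string opName = op->getName().getStringRef().str();
    bool matched = traceSpecificOpPattern.isEnabled(opName);
    bool printData = false;
    if (!matched) {
      // Fall back to the onnx_node_name attribute, prefixed so that only
      // patterns written as "onnx_node_name:..." can match it.
      if (auto nodeName = op->getAttrOfType<mlir::StringAttr>("onnx_node_name"))
        if (traceSpecificOpPattern.isEnabled(
                "onnx_node_name:" + nodeName.getValue().str())) {
          matched = true;
          printData = true; // Only the node-name path prints data values.
        }
    }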

Collaborator Author (@chentong319, Feb 16, 2025):
> Then for every op in the walk, we will fetch the data from the onnx_node_name for all.

Yes, that's what I want. For all the op, either from onnx_node_name or op->getName(). We could separate the pattern for onnx_node_name and getName, and match each op accordingly. I did not do that: just keep the pattern description simple.

+      signaturePattern.erase(0, header.length());
+      useNodeName = true;
+      std::cout << signaturePattern << "\n";
+    }
+
     traceSpecificOpPattern.setRegexString(signaturePattern);
     // Iterate on the operations nested in this function.
     getOperation().walk([&](mlir::Operation *op) {
       std::string opName = op->getName().getStringRef().str();
+      std::string matchString = opName;
+      if (useNodeName) {
+        StringAttr onnxNodeName =
+            op->getAttrOfType<mlir::StringAttr>("onnx_node_name");
+        if (onnxNodeName && !onnxNodeName.getValue().empty()) {
+          matchString = onnxNodeName.getValue().str();
+        }
+      }
       auto dialect = op->getDialect();
       if (isa<func::FuncDialect>(dialect) || isa<ONNXPrintSignatureOp>(op)) {
         // Always skip function dialects (such as function call/return), as well
         // as ONNX print signature ops.
-      } else if (traceSpecificOpPattern.isEnabled(opName)) {
+      } else if (traceSpecificOpPattern.isEnabled(matchString)) {
         // Add signature printing op.
         Location loc = op->getLoc();
         OpBuilder builder(op);
@@ -94,7 +113,9 @@ class InstrumentONNXSignaturePass
       // Since we may use the result of an operation, we must insert the
       // print operation after the operation.
       builder.setInsertionPointAfter(op);
-      builder.create<ONNXPrintSignatureOp>(loc, fullNameAttr, operAndRes);
+      // When one node is selected, print the details of the tensor.
+      builder.create<ONNXPrintSignatureOp>(
+          loc, fullNameAttr, useNodeName ? 1 : 0, operAndRes);
     }
   });
 }