-
Notifications
You must be signed in to change notification settings - Fork 338
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Extend instrumentSignature to print data #3078
Changes from 4 commits
c2ed47b
86ac8d6
447ef2f
73c39b3
096d5b4
4589644
3fbd3ad
00cdb80
19f8476
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -169,10 +169,16 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU, | |
if (instrumentStage == onnx_mlir::InstrumentStages::Onnx) | ||
pm.addNestedPass<func::FuncOp>( | ||
onnx_mlir::createInstrumentPass(instrumentOps, instrumentActions)); | ||
// Print Signatures of each op at runtime if enabled. Should not run | ||
// signature and instrument passes at the same time as time may include printf | ||
// overheads. | ||
if (instrumentSignatures != "NONE") | ||
pm.addNestedPass<func::FuncOp>( | ||
onnx_mlir::createInstrumentONNXSignaturePass(instrumentSignatures)); | ||
} | ||
|
||
void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE, | ||
std::string instrumentSignatureString, std::string ONNXOpsStatFormat) { | ||
std::string ONNXOpsStatFormat) { | ||
if (enableCSE) | ||
// Eliminate common sub-expressions before lowering to Krnl. | ||
// TODO: enable this by default when we make sure it works flawlessly. | ||
|
@@ -196,12 +202,6 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE, | |
} | ||
} | ||
|
||
// Print Signatures of each op at runtime if enabled. Should not run | ||
// signature and instrument passes at the same time as time may include printf | ||
// overheads. | ||
if (instrumentSignatureString != "NONE") | ||
pm.addNestedPass<func::FuncOp>(onnx_mlir::createInstrumentONNXSignaturePass( | ||
instrumentSignatureString)); | ||
pm.addPass(onnx_mlir::createLowerToKrnlPass(/*enableTiling*/ optLevel >= 3, | ||
/*enableSIMD*/ optLevel >= 3 && !disableSimdOption, enableParallel, | ||
/*enableFastMath*/ optLevel >= 3 && enableFastMathOption, | ||
|
@@ -214,7 +214,7 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE, | |
|
||
void addKrnlToAffinePasses(mlir::PassManager &pm) { | ||
pm.addNestedPass<func::FuncOp>( | ||
onnx_mlir::krnl::createConvertKrnlToAffinePass(enableParallel)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You lost some changes from dev main. Please add them back. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed. |
||
onnx_mlir::krnl::createConvertKrnlToAffinePass()); | ||
} | ||
|
||
void addKrnlToLLVMPasses( | ||
|
@@ -325,8 +325,8 @@ void addPasses(mlir::OwningOpRef<ModuleOp> &module, mlir::PassManager &pm, | |
|
||
if (emissionTarget >= EmitMLIR) { | ||
if (inputIRLevel <= ONNXLevel) | ||
addONNXToKrnlPasses(pm, OptimizationLevel, /*enableCSE*/ true, | ||
instrumentSignatures, ONNXOpStats); | ||
addONNXToKrnlPasses( | ||
pm, OptimizationLevel, /*enableCSE*/ true, ONNXOpStats); | ||
if (inputIRLevel <= MLIRLevel) | ||
addKrnlToAffinePasses(pm); | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -72,15 +72,34 @@ class InstrumentONNXSignaturePass | |
void runOnOperation() override { | ||
onnx_mlir::EnableByRegexOption traceSpecificOpPattern( | ||
/*emptyIsNone*/ false); | ||
// If the signaturePattern is started with "onnx_node_name: ", | ||
// the node will be selected by the string attribute of onnx_node_name. | ||
std::string header = "onnx_node_name:"; | ||
std::cout << signaturePattern << "\n"; | ||
bool useNodeName = false; | ||
if (signaturePattern.rfind(header, 0) == 0) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you explain the behavior here? If there is one match of "onnx_node_name", then all we use the node name for all ops? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When the option starts with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe the code is not doing what you think it is doing. if the pattern passed is "onnx.Add, onnx_node_name:one_specific_op", then the I think what you need to do is:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Yes, that's what I want. For all the op, either from onnx_node_name or op->getName(). We could separate the pattern for onnx_node_name and getName, and match each op accordingly. I did not do that: just keep the pattern description simple. |
||
signaturePattern.erase(0, header.length()); | ||
useNodeName = true; | ||
std::cout << signaturePattern << "\n"; | ||
} | ||
|
||
traceSpecificOpPattern.setRegexString(signaturePattern); | ||
// Iterate on the operations nested in this function. | ||
getOperation().walk([&](mlir::Operation *op) { | ||
std::string opName = op->getName().getStringRef().str(); | ||
std::string matchString = opName; | ||
if (useNodeName) { | ||
StringAttr onnxNodeName = | ||
op->getAttrOfType<mlir::StringAttr>("onnx_node_name"); | ||
if (onnxNodeName && !onnxNodeName.getValue().empty()) { | ||
matchString = onnxNodeName.getValue().str(); | ||
} | ||
} | ||
auto dialect = op->getDialect(); | ||
if (isa<func::FuncDialect>(dialect) || isa<ONNXPrintSignatureOp>(op)) { | ||
// Always skip function dialects (such as function call/return), as well | ||
// as ONNX print signature ops. | ||
} else if (traceSpecificOpPattern.isEnabled(opName)) { | ||
} else if (traceSpecificOpPattern.isEnabled(matchString)) { | ||
// Add signature printing op. | ||
Location loc = op->getLoc(); | ||
OpBuilder builder(op); | ||
|
@@ -94,7 +113,9 @@ class InstrumentONNXSignaturePass | |
// Since we may use the result of an operation, we must insert the | ||
// print operation after the operation. | ||
builder.setInsertionPointAfter(op); | ||
builder.create<ONNXPrintSignatureOp>(loc, fullNameAttr, operAndRes); | ||
// When one node is selected, print the details of the tensor. | ||
builder.create<ONNXPrintSignatureOp>( | ||
loc, fullNameAttr, useNodeName ? 1 : 0, operAndRes); | ||
} | ||
}); | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
missing the
:
in the pattern, as you are looking for thisAlso you state that "the data values" will be printed. Is that the input data values, the output data values? Please specify.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.