@@ -312,15 +312,15 @@ static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;
312312size_t divideAndCeil (size_t x, size_t y) { return (x + y - 1 ) / y; }
313313
314314// Manager
315- struct constGraphTensorCacheManager {
315+ struct ConstGraphTensorCacheManager {
316316 // dnnl_graph_compiler_context *ctx;
317317
318318 uint64_t cachedTensorGlobalId = 0 ;
319319
320320 // singleton
321- static std::shared_ptr<constGraphTensorCacheManager > get () {
322- static std::shared_ptr<constGraphTensorCacheManager > c =
323- std::make_shared<constGraphTensorCacheManager >();
321+ static std::shared_ptr<ConstGraphTensorCacheManager > get () {
322+ static std::shared_ptr<ConstGraphTensorCacheManager > c =
323+ std::make_shared<ConstGraphTensorCacheManager >();
324324 return c;
325325 }
326326
@@ -385,18 +385,7 @@ static void addGlobalI32Array(ModuleOp &module, Location loc,
385385 /* alignment=*/ 0 );
386386}
387387
388- // Operate on tensors. Create fold() and compute() on module. The
389- // folded weights and first-run flag is maintained by upper-level runtime.
390- void ConstantTensorFolding::runOnOperation () {
391- Operation *topOp = getOperation ();
392- MLIRContext *context = topOp->getContext ();
393- // A ModuleOp contains a single region, which contains a single block.
394- auto moduleOp = dyn_cast<ModuleOp>(topOp);
395- SymbolTable symbolTable (moduleOp);
396- auto &topFunc =
397- topOp->getRegions ().front ().getBlocks ().front ().getOperations ().front ();
398- OpBuilder builder (context);
399-
388+ std::unordered_set<int > getConstArgsIndexes (Operation &topFunc) {
400389 auto topFuncAttr = topFunc.getAttrDictionary ();
401390 std::optional<NamedAttribute> constArgs =
402391 topFuncAttr.getNamed (" onednn_graph.const_args" );
@@ -406,32 +395,16 @@ void ConstantTensorFolding::runOnOperation() {
406395 for (auto id : constArgsArray) {
407396 constArgsIndexes.insert (llvm::cast<IntegerAttr>(id).getInt ());
408397 }
409- } else {
410- return ;
411- }
412- if (constArgsIndexes.empty ()) {
413- return ;
414- }
415-
416- Region ®ion = topFunc.getRegions ().front ();
417- Block &block = region.getBlocks ().front ();
418-
419- postponeBroadcast (block);
420-
421- SmallVector<Operation *> constOps;
422- for (Operation &op : llvm::make_early_inc_range (block)) {
423- if (isInConstantSubgraph (&op)) {
424- constOps.push_back (&op);
425- }
426398 }
399+ return constArgsIndexes;
400+ }
427401
428- std::string funcName (" fold" );
429- SmallVector<Type> inputTypes; // types of constant weights
430- // values of constant weights in original block
431- SmallVector<Value> inputValues;
432- SmallVector<Type> outputTypes; // types of folded constant weights
433- // values of folded constant weights in original block
434- SmallVector<Value> outputValues;
402+ void getInputsAndOutputs (Block &block,
403+ std::unordered_set<int > &constArgsIndexes,
404+ SmallVector<Type> &inputTypes,
405+ SmallVector<Value> &inputValues,
406+ SmallVector<Type> &outputTypes,
407+ SmallVector<Value> &outputValues) {
435408 Value v;
436409 // Support complicated topology.
437410 for (size_t id = 0 ; id < block.getNumArguments (); ++id) {
@@ -512,11 +485,19 @@ void ConstantTensorFolding::runOnOperation() {
512485 }
513486 }
514487 }
488+ }
515489
490+ func::FuncOp buildFoldFunc (MLIRContext *context, OpBuilder &builder,
491+ Operation *topOp, SmallVector<Operation *> constOps,
492+ SmallVector<Type> &inputTypes,
493+ SmallVector<Value> &inputValues,
494+ SmallVector<Type> &outputTypes,
495+ SmallVector<Value> &outputValues) {
496+ std::string funcName (" fold" );
516497 FunctionType foldFuncType =
517498 FunctionType::get (context, inputTypes, outputTypes);
518499 func::FuncOp foldFunc =
519- builder.create <func::FuncOp>(topFunc. getLoc (), funcName, foldFuncType);
500+ builder.create <func::FuncOp>(topOp-> getLoc (), funcName, foldFuncType);
520501 Block *foldBlock = foldFunc.addEntryBlock ();
521502 // values of folded constant weights in foldBlock
522503 SmallVector<Value> outputValuesInFold;
@@ -535,39 +516,50 @@ void ConstantTensorFolding::runOnOperation() {
535516 });
536517 }
537518
538- auto returnOp =
539- builder.create <func::ReturnOp>(topOp->getLoc (), outputValuesInFold);
540- foldBlock->getOperations ().push_back (returnOp);
541- for (size_t i = 0 ; i < inputValues.size (); ++i) {
542- inputValues[i].replaceUsesWithIf (foldBlock->getArgument (i),
543- [&](OpOperand &val) {
544- Operation *op = val.getOwner ();
545- return op->getBlock () == foldBlock;
546- });
547- }
548-
549519 // Allocate buffer for outputValuesInFold
550520 std::vector<size_t > buffersSize;
551521 for (Value &tensor : outputValuesInFold) {
552522 llvm::dbgs () << " Allocate buffer for tensor: " << tensor << " \n " ;
553523 buffersSize.push_back (
554524 getTensorSize (dyn_cast<TensorType>(tensor.getType ())));
555525 }
556- auto manager = constGraphTensorCacheManager ::get ();
526+ auto manager = ConstGraphTensorCacheManager ::get ();
557527 SmallVector<int64_t > globalIndexes;
558528 for (auto id : manager->alloc (buffersSize)) {
559529 globalIndexes.push_back (id);
560530 }
561531 globalIndexes.insert (globalIndexes.begin (), globalIndexes.size ());
532+ auto moduleOp = dyn_cast<ModuleOp>(topOp);
562533 addGlobalI64Array (moduleOp, moduleOp.getLoc (), builder, " __fold_buffer_ids" ,
563534 globalIndexes);
564535
536+ auto returnOp =
537+ builder.create <func::ReturnOp>(topOp->getLoc (), outputValuesInFold);
538+ foldBlock->getOperations ().push_back (returnOp);
539+ for (size_t i = 0 ; i < inputValues.size (); ++i) {
540+ inputValues[i].replaceUsesWithIf (foldBlock->getArgument (i),
541+ [&](OpOperand &val) {
542+ Operation *op = val.getOwner ();
543+ return op->getBlock () == foldBlock;
544+ });
545+ }
546+
565547 foldFunc.setVisibility (SymbolTable::Visibility::Public);
566548 foldFunc->setAttr (LLVM::LLVMDialect::getEmitCWrapperAttrName (),
567549 UnitAttr::get (context));
550+
568551 moduleOp.push_back (foldFunc);
552+ SymbolTable symbolTable (moduleOp);
569553 symbolTable.insert (foldFunc);
570554
555+ return foldFunc;
556+ }
557+
558+ void modifyComputeFunc (MLIRContext *context, OpBuilder &builder,
559+ Operation *topOp, Operation &func, Block &block,
560+ std::unordered_set<int > &constArgsIndexes,
561+ SmallVector<Type> &outputTypes,
562+ SmallVector<Value> &outputValues) {
571563 // the indexes of args to the folding func.
572564 SmallVector<int32_t > foldArgs;
573565 // the indexes of folded args.
@@ -631,6 +623,13 @@ void ConstantTensorFolding::runOnOperation() {
631623 }
632624 block.eraseArguments (argsToErase);
633625
626+ // modify the compute func signature
627+ func::FuncOp computeFunc = cast<func::FuncOp>(func);
628+ FunctionType computeFuncType = computeFunc.getFunctionType ();
629+ computeFunc.setType (FunctionType::get (context, block.getArgumentTypes (),
630+ computeFuncType.getResults ()));
631+
632+ auto moduleOp = dyn_cast<ModuleOp>(topOp);
634633 for (auto id : foldIds) {
635634 foldArgs.insert (foldArgs.end (), id);
636635 }
@@ -647,13 +646,9 @@ void ConstantTensorFolding::runOnOperation() {
647646
648647 addGlobalI32 (moduleOp, moduleOp.getLoc (), builder, " __num_orig_num_args" ,
649648 oriNumArgs);
649+ }
650650
651- // modify the compute func signature
652- func::FuncOp computeFunc = cast<func::FuncOp>(topFunc);
653- FunctionType computeFuncType = computeFunc.getFunctionType ();
654- computeFunc.setType (FunctionType::get (context, block.getArgumentTypes (),
655- computeFuncType.getResults ()));
656-
651+ void canonicalizeAndClean (MLIRContext *context, Operation *topOp) {
657652 // Delete dead operations by dialects' canonicalizer
658653 RewritePatternSet owningPatterns (context);
659654 for (auto *dialect : context->getLoadedDialects ())
@@ -669,16 +664,55 @@ void ConstantTensorFolding::runOnOperation() {
669664 (void )converged;
670665
671666 // clean up the constant-related attrs on ops
672- for ( auto &op : block. getOperations () ) {
673- if (op. getAttr (" onednn_graph.in_const_subgraph" )) {
674- op. removeAttr (" onednn_graph.in_const_subgraph" );
667+ topOp-> walk ([&](Operation *op ) {
668+ if (op-> getAttr (" onednn_graph.in_const_subgraph" )) {
669+ op-> removeAttr (" onednn_graph.in_const_subgraph" );
675670 }
671+ });
672+ }
673+
674+ // Operate on tensors. Create fold() and compute() on module. The
675+ // folded weights and first-run flag is maintained by upper-level runtime.
676+ void ConstantTensorFolding::runOnOperation () {
677+ Operation *topOp = getOperation ();
678+ MLIRContext *context = topOp->getContext ();
679+ auto &topFunc =
680+ topOp->getRegions ().front ().getBlocks ().front ().getOperations ().front ();
681+ OpBuilder builder (context);
682+ Region ®ion = topFunc.getRegions ().front ();
683+ Block &block = region.getBlocks ().front ();
684+
685+ std::unordered_set<int > constArgsIndexes = getConstArgsIndexes (topFunc);
686+ if (constArgsIndexes.empty ()) {
687+ return ;
676688 }
677- for (auto &op : foldBlock->getOperations ()) {
678- if (op.getAttr (" onednn_graph.in_const_subgraph" )) {
679- op.removeAttr (" onednn_graph.in_const_subgraph" );
689+
690+ postponeBroadcast (block);
691+
692+ SmallVector<Operation *> constOps;
693+ for (Operation &op : llvm::make_early_inc_range (block)) {
694+ if (isInConstantSubgraph (&op)) {
695+ constOps.push_back (&op);
680696 }
681697 }
698+
699+ SmallVector<Type> inputTypes; // types of constant weights
700+ // values of constant weights in original block
701+ SmallVector<Value> inputValues;
702+ SmallVector<Type> outputTypes; // types of folded constant weights
703+ // values of folded constant weights in original block
704+ SmallVector<Value> outputValues;
705+ getInputsAndOutputs (block, constArgsIndexes, inputTypes, inputValues,
706+ outputTypes, outputValues);
707+
708+ func::FuncOp foldFunc =
709+ buildFoldFunc (context, builder, topOp, constOps, inputTypes, inputValues,
710+ outputTypes, outputValues);
711+
712+ modifyComputeFunc (context, builder, topOp, topFunc, block, constArgsIndexes,
713+ outputTypes, outputValues);
714+
715+ canonicalizeAndClean (context, topOp);
682716}
683717
684718std::unique_ptr<Pass> createConstantTensorFoldingPass () {
0 commit comments