@@ -312,15 +312,15 @@ static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;
312
312
// Returns ceil(x / y) for unsigned sizes, i.e. how many y-sized chunks are
// needed to cover x. Uses quotient plus a remainder bump rather than
// (x + y - 1) / y, so an x close to SIZE_MAX cannot wrap around and yield a
// too-small result. y must be non-zero (division by zero is undefined, as
// it was before).
size_t divideAndCeil(size_t x, size_t y) {
  return x / y + (x % y != 0 ? 1 : 0);
}
313
313
314
314
// Manager
315
- struct constGraphTensorCacheManager {
315
+ struct ConstGraphTensorCacheManager {
316
316
// dnnl_graph_compiler_context *ctx;
317
317
318
318
uint64_t cachedTensorGlobalId = 0 ;
319
319
320
320
// singleton
321
- static std::shared_ptr<constGraphTensorCacheManager > get () {
322
- static std::shared_ptr<constGraphTensorCacheManager > c =
323
- std::make_shared<constGraphTensorCacheManager >();
321
+ static std::shared_ptr<ConstGraphTensorCacheManager > get () {
322
+ static std::shared_ptr<ConstGraphTensorCacheManager > c =
323
+ std::make_shared<ConstGraphTensorCacheManager >();
324
324
return c;
325
325
}
326
326
@@ -385,18 +385,7 @@ static void addGlobalI32Array(ModuleOp &module, Location loc,
385
385
/* alignment=*/ 0 );
386
386
}
387
387
388
- // Operate on tensors. Create fold() and compute() on module. The
389
- // folded weights and first-run flag is maintained by upper-level runtime.
390
- void ConstantTensorFolding::runOnOperation () {
391
- Operation *topOp = getOperation ();
392
- MLIRContext *context = topOp->getContext ();
393
- // A ModuleOp contains a single region, which contains a single block.
394
- auto moduleOp = dyn_cast<ModuleOp>(topOp);
395
- SymbolTable symbolTable (moduleOp);
396
- auto &topFunc =
397
- topOp->getRegions ().front ().getBlocks ().front ().getOperations ().front ();
398
- OpBuilder builder (context);
399
-
388
+ std::unordered_set<int > getConstArgsIndexes (Operation &topFunc) {
400
389
auto topFuncAttr = topFunc.getAttrDictionary ();
401
390
std::optional<NamedAttribute> constArgs =
402
391
topFuncAttr.getNamed (" onednn_graph.const_args" );
@@ -406,32 +395,16 @@ void ConstantTensorFolding::runOnOperation() {
406
395
for (auto id : constArgsArray) {
407
396
constArgsIndexes.insert (llvm::cast<IntegerAttr>(id).getInt ());
408
397
}
409
- } else {
410
- return ;
411
- }
412
- if (constArgsIndexes.empty ()) {
413
- return ;
414
- }
415
-
416
- Region &region = topFunc.getRegions ().front ();
417
- Block &block = region.getBlocks ().front ();
418
-
419
- postponeBroadcast (block);
420
-
421
- SmallVector<Operation *> constOps;
422
- for (Operation &op : llvm::make_early_inc_range (block)) {
423
- if (isInConstantSubgraph (&op)) {
424
- constOps.push_back (&op);
425
- }
426
398
}
399
+ return constArgsIndexes;
400
+ }
427
401
428
- std::string funcName (" fold" );
429
- SmallVector<Type> inputTypes; // types of constant weights
430
- // values of constant weights in original block
431
- SmallVector<Value> inputValues;
432
- SmallVector<Type> outputTypes; // types of folded constant weights
433
- // values of folded constant weights in original block
434
- SmallVector<Value> outputValues;
402
+ void getInputsAndOutputs (Block &block,
403
+ std::unordered_set<int > &constArgsIndexes,
404
+ SmallVector<Type> &inputTypes,
405
+ SmallVector<Value> &inputValues,
406
+ SmallVector<Type> &outputTypes,
407
+ SmallVector<Value> &outputValues) {
435
408
Value v;
436
409
// Support complicated topology.
437
410
for (size_t id = 0 ; id < block.getNumArguments (); ++id) {
@@ -512,11 +485,19 @@ void ConstantTensorFolding::runOnOperation() {
512
485
}
513
486
}
514
487
}
488
+ }
515
489
490
+ func::FuncOp buildFoldFunc (MLIRContext *context, OpBuilder &builder,
491
+ Operation *topOp, SmallVector<Operation *> constOps,
492
+ SmallVector<Type> &inputTypes,
493
+ SmallVector<Value> &inputValues,
494
+ SmallVector<Type> &outputTypes,
495
+ SmallVector<Value> &outputValues) {
496
+ std::string funcName (" fold" );
516
497
FunctionType foldFuncType =
517
498
FunctionType::get (context, inputTypes, outputTypes);
518
499
func::FuncOp foldFunc =
519
- builder.create <func::FuncOp>(topFunc. getLoc (), funcName, foldFuncType);
500
+ builder.create <func::FuncOp>(topOp-> getLoc (), funcName, foldFuncType);
520
501
Block *foldBlock = foldFunc.addEntryBlock ();
521
502
// values of folded constant weights in foldBlock
522
503
SmallVector<Value> outputValuesInFold;
@@ -535,39 +516,50 @@ void ConstantTensorFolding::runOnOperation() {
535
516
});
536
517
}
537
518
538
- auto returnOp =
539
- builder.create <func::ReturnOp>(topOp->getLoc (), outputValuesInFold);
540
- foldBlock->getOperations ().push_back (returnOp);
541
- for (size_t i = 0 ; i < inputValues.size (); ++i) {
542
- inputValues[i].replaceUsesWithIf (foldBlock->getArgument (i),
543
- [&](OpOperand &val) {
544
- Operation *op = val.getOwner ();
545
- return op->getBlock () == foldBlock;
546
- });
547
- }
548
-
549
519
// Allocate buffer for outputValuesInFold
550
520
std::vector<size_t > buffersSize;
551
521
for (Value &tensor : outputValuesInFold) {
552
522
llvm::dbgs () << " Allocate buffer for tensor: " << tensor << " \n " ;
553
523
buffersSize.push_back (
554
524
getTensorSize (dyn_cast<TensorType>(tensor.getType ())));
555
525
}
556
- auto manager = constGraphTensorCacheManager ::get ();
526
+ auto manager = ConstGraphTensorCacheManager ::get ();
557
527
SmallVector<int64_t > globalIndexes;
558
528
for (auto id : manager->alloc (buffersSize)) {
559
529
globalIndexes.push_back (id);
560
530
}
561
531
globalIndexes.insert (globalIndexes.begin (), globalIndexes.size ());
532
+ auto moduleOp = dyn_cast<ModuleOp>(topOp);
562
533
addGlobalI64Array (moduleOp, moduleOp.getLoc (), builder, " __fold_buffer_ids" ,
563
534
globalIndexes);
564
535
536
+ auto returnOp =
537
+ builder.create <func::ReturnOp>(topOp->getLoc (), outputValuesInFold);
538
+ foldBlock->getOperations ().push_back (returnOp);
539
+ for (size_t i = 0 ; i < inputValues.size (); ++i) {
540
+ inputValues[i].replaceUsesWithIf (foldBlock->getArgument (i),
541
+ [&](OpOperand &val) {
542
+ Operation *op = val.getOwner ();
543
+ return op->getBlock () == foldBlock;
544
+ });
545
+ }
546
+
565
547
foldFunc.setVisibility (SymbolTable::Visibility::Public);
566
548
foldFunc->setAttr (LLVM::LLVMDialect::getEmitCWrapperAttrName (),
567
549
UnitAttr::get (context));
550
+
568
551
moduleOp.push_back (foldFunc);
552
+ SymbolTable symbolTable (moduleOp);
569
553
symbolTable.insert (foldFunc);
570
554
555
+ return foldFunc;
556
+ }
557
+
558
+ void modifyComputeFunc (MLIRContext *context, OpBuilder &builder,
559
+ Operation *topOp, Operation &func, Block &block,
560
+ std::unordered_set<int > &constArgsIndexes,
561
+ SmallVector<Type> &outputTypes,
562
+ SmallVector<Value> &outputValues) {
571
563
// the indexes of args to the folding func.
572
564
SmallVector<int32_t > foldArgs;
573
565
// the indexes of folded args.
@@ -631,6 +623,13 @@ void ConstantTensorFolding::runOnOperation() {
631
623
}
632
624
block.eraseArguments (argsToErase);
633
625
626
+ // modify the compute func signature
627
+ func::FuncOp computeFunc = cast<func::FuncOp>(func);
628
+ FunctionType computeFuncType = computeFunc.getFunctionType ();
629
+ computeFunc.setType (FunctionType::get (context, block.getArgumentTypes (),
630
+ computeFuncType.getResults ()));
631
+
632
+ auto moduleOp = dyn_cast<ModuleOp>(topOp);
634
633
for (auto id : foldIds) {
635
634
foldArgs.insert (foldArgs.end (), id);
636
635
}
@@ -647,13 +646,10 @@ void ConstantTensorFolding::runOnOperation() {
647
646
648
647
addGlobalI32 (moduleOp, moduleOp.getLoc (), builder, " __num_orig_num_args" ,
649
648
oriNumArgs);
649
+ }
650
650
651
- // modify the compute func signature
652
- func::FuncOp computeFunc = cast<func::FuncOp>(topFunc);
653
- FunctionType computeFuncType = computeFunc.getFunctionType ();
654
- computeFunc.setType (FunctionType::get (context, block.getArgumentTypes (),
655
- computeFuncType.getResults ()));
656
-
651
+ void canonicalizeAndClean (MLIRContext *context, Operation *topOp, Block &block,
652
+ func::FuncOp &foldFunc) {
657
653
// Delete dead operations by dialects' canonicalizer
658
654
RewritePatternSet owningPatterns (context);
659
655
for (auto *dialect : context->getLoadedDialects ())
@@ -674,13 +670,57 @@ void ConstantTensorFolding::runOnOperation() {
674
670
op.removeAttr (" onednn_graph.in_const_subgraph" );
675
671
}
676
672
}
677
- for (auto &op : foldBlock-> getOperations ()) {
673
+ for (auto &op : foldFunc. front (). getOperations ()) {
678
674
if (op.getAttr (" onednn_graph.in_const_subgraph" )) {
679
675
op.removeAttr (" onednn_graph.in_const_subgraph" );
680
676
}
681
677
}
682
678
}
683
679
680
+ // Operate on tensors. Create fold() and compute() on module. The
681
+ // folded weights and first-run flag are maintained by upper-level runtime.
682
+ void ConstantTensorFolding::runOnOperation () {
683
+ Operation *topOp = getOperation ();
684
+ MLIRContext *context = topOp->getContext ();
685
+ auto &topFunc =
686
+ topOp->getRegions ().front ().getBlocks ().front ().getOperations ().front ();
687
+ OpBuilder builder (context);
688
+ Region &region = topFunc.getRegions ().front ();
689
+ Block &block = region.getBlocks ().front ();
690
+
691
+ std::unordered_set<int > constArgsIndexes = getConstArgsIndexes (topFunc);
692
+ if (constArgsIndexes.empty ()) {
693
+ return ;
694
+ }
695
+
696
+ postponeBroadcast (block);
697
+
698
+ SmallVector<Operation *> constOps;
699
+ for (Operation &op : llvm::make_early_inc_range (block)) {
700
+ if (isInConstantSubgraph (&op)) {
701
+ constOps.push_back (&op);
702
+ }
703
+ }
704
+
705
+ SmallVector<Type> inputTypes; // types of constant weights
706
+ // values of constant weights in original block
707
+ SmallVector<Value> inputValues;
708
+ SmallVector<Type> outputTypes; // types of folded constant weights
709
+ // values of folded constant weights in original block
710
+ SmallVector<Value> outputValues;
711
+ getInputsAndOutputs (block, constArgsIndexes, inputTypes, inputValues,
712
+ outputTypes, outputValues);
713
+
714
+ func::FuncOp foldFunc =
715
+ buildFoldFunc (context, builder, topOp, constOps, inputTypes, inputValues,
716
+ outputTypes, outputValues);
717
+
718
+ modifyComputeFunc (context, builder, topOp, topFunc, block, constArgsIndexes,
719
+ outputTypes, outputValues);
720
+
721
+ canonicalizeAndClean (context, topOp, block, foldFunc);
722
+ }
723
+
684
724
std::unique_ptr<Pass> createConstantTensorFoldingPass () {
685
725
return std::make_unique<ConstantTensorFolding>();
686
726
}
0 commit comments