
Commit a0ddebe

Split into short functions
1 parent 0f67f75 commit a0ddebe

File tree

1 file changed: +100, -60 lines


lib/gc/Transforms/ConstantTensorFolding.cpp (+100, -60)
@@ -312,15 +312,15 @@ static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;
 size_t divideAndCeil(size_t x, size_t y) { return (x + y - 1) / y; }
 
 // Manager
-struct constGraphTensorCacheManager {
+struct ConstGraphTensorCacheManager {
   // dnnl_graph_compiler_context *ctx;
 
   uint64_t cachedTensorGlobalId = 0;
 
   // singleton
-  static std::shared_ptr<constGraphTensorCacheManager> get() {
-    static std::shared_ptr<constGraphTensorCacheManager> c =
-        std::make_shared<constGraphTensorCacheManager>();
+  static std::shared_ptr<ConstGraphTensorCacheManager> get() {
+    static std::shared_ptr<ConstGraphTensorCacheManager> c =
+        std::make_shared<ConstGraphTensorCacheManager>();
     return c;
   }
 
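For reference, the renamed accessor keeps the usual function-local-static singleton idiom: the shared_ptr is constructed on the first call and every later call returns the same instance, with thread-safe initialization guaranteed since C++11. A minimal standalone sketch of the same pattern, with a hypothetical alloc() stub standing in for the real cache bookkeeping:

#include <cstdint>
#include <memory>
#include <vector>

struct CacheManager {
  uint64_t cachedTensorGlobalId = 0;

  // Function-local static: constructed exactly once, on first use.
  static std::shared_ptr<CacheManager> get() {
    static std::shared_ptr<CacheManager> c = std::make_shared<CacheManager>();
    return c;
  }

  // Hypothetical stand-in for the real allocator: hands out one fresh id
  // per requested buffer size, mirroring how alloc() is used in the pass.
  std::vector<uint64_t> alloc(const std::vector<size_t> &buffersSize) {
    std::vector<uint64_t> ids;
    for (size_t i = 0; i < buffersSize.size(); ++i)
      ids.push_back(cachedTensorGlobalId++);
    return ids;
  }
};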
@@ -385,18 +385,7 @@ static void addGlobalI32Array(ModuleOp &module, Location loc,
                            /*alignment=*/0);
 }
 
-// Operate on tensors. Create fold() and compute() on module. The
-// folded weights and first-run flag is maintained by upper-level runtime.
-void ConstantTensorFolding::runOnOperation() {
-  Operation *topOp = getOperation();
-  MLIRContext *context = topOp->getContext();
-  // A ModuleOp contains a single region, which contains a single block.
-  auto moduleOp = dyn_cast<ModuleOp>(topOp);
-  SymbolTable symbolTable(moduleOp);
-  auto &topFunc =
-      topOp->getRegions().front().getBlocks().front().getOperations().front();
-  OpBuilder builder(context);
-
+std::unordered_set<int> getConstArgsIndexes(Operation &topFunc) {
   auto topFuncAttr = topFunc.getAttrDictionary();
   std::optional<NamedAttribute> constArgs =
       topFuncAttr.getNamed("onednn_graph.const_args");
@@ -406,32 +395,16 @@ void ConstantTensorFolding::runOnOperation() {
     for (auto id : constArgsArray) {
       constArgsIndexes.insert(llvm::cast<IntegerAttr>(id).getInt());
     }
-  } else {
-    return;
-  }
-  if (constArgsIndexes.empty()) {
-    return;
-  }
-
-  Region &region = topFunc.getRegions().front();
-  Block &block = region.getBlocks().front();
-
-  postponeBroadcast(block);
-
-  SmallVector<Operation *> constOps;
-  for (Operation &op : llvm::make_early_inc_range(block)) {
-    if (isInConstantSubgraph(&op)) {
-      constOps.push_back(&op);
-    }
   }
+  return constArgsIndexes;
+}
 
-  std::string funcName("fold");
-  SmallVector<Type> inputTypes; // types of constant weights
-  // values of constant weights in original block
-  SmallVector<Value> inputValues;
-  SmallVector<Type> outputTypes; // types of folded constant weights
-  // values of folded constant weights in original block
-  SmallVector<Value> outputValues;
+void getInputsAndOutputs(Block &block,
+                         std::unordered_set<int> &constArgsIndexes,
+                         SmallVector<Type> &inputTypes,
+                         SmallVector<Value> &inputValues,
+                         SmallVector<Type> &outputTypes,
+                         SmallVector<Value> &outputValues) {
   Value v;
   // Support complicated topology.
   for (size_t id = 0; id < block.getNumArguments(); ++id) {
@@ -512,11 +485,19 @@ void ConstantTensorFolding::runOnOperation() {
       }
     }
   }
+}
 
+func::FuncOp buildFoldFunc(MLIRContext *context, OpBuilder &builder,
+                           Operation *topOp, SmallVector<Operation *> constOps,
+                           SmallVector<Type> &inputTypes,
+                           SmallVector<Value> &inputValues,
+                           SmallVector<Type> &outputTypes,
+                           SmallVector<Value> &outputValues) {
+  std::string funcName("fold");
   FunctionType foldFuncType =
       FunctionType::get(context, inputTypes, outputTypes);
   func::FuncOp foldFunc =
-      builder.create<func::FuncOp>(topFunc.getLoc(), funcName, foldFuncType);
+      builder.create<func::FuncOp>(topOp->getLoc(), funcName, foldFuncType);
   Block *foldBlock = foldFunc.addEntryBlock();
   // values of folded constant weights in foldBlock
   SmallVector<Value> outputValuesInFold;
@@ -535,39 +516,50 @@ void ConstantTensorFolding::runOnOperation() {
     });
   }
 
-  auto returnOp =
-      builder.create<func::ReturnOp>(topOp->getLoc(), outputValuesInFold);
-  foldBlock->getOperations().push_back(returnOp);
-  for (size_t i = 0; i < inputValues.size(); ++i) {
-    inputValues[i].replaceUsesWithIf(foldBlock->getArgument(i),
-                                     [&](OpOperand &val) {
-                                       Operation *op = val.getOwner();
-                                       return op->getBlock() == foldBlock;
-                                     });
-  }
-
   // Allocate buffer for outputValuesInFold
   std::vector<size_t> buffersSize;
   for (Value &tensor : outputValuesInFold) {
     llvm::dbgs() << "Allocate buffer for tensor: " << tensor << "\n";
     buffersSize.push_back(
         getTensorSize(dyn_cast<TensorType>(tensor.getType())));
   }
-  auto manager = constGraphTensorCacheManager::get();
+  auto manager = ConstGraphTensorCacheManager::get();
   SmallVector<int64_t> globalIndexes;
   for (auto id : manager->alloc(buffersSize)) {
     globalIndexes.push_back(id);
   }
   globalIndexes.insert(globalIndexes.begin(), globalIndexes.size());
+  auto moduleOp = dyn_cast<ModuleOp>(topOp);
   addGlobalI64Array(moduleOp, moduleOp.getLoc(), builder, "__fold_buffer_ids",
                     globalIndexes);
 
+  auto returnOp =
+      builder.create<func::ReturnOp>(topOp->getLoc(), outputValuesInFold);
+  foldBlock->getOperations().push_back(returnOp);
+  for (size_t i = 0; i < inputValues.size(); ++i) {
+    inputValues[i].replaceUsesWithIf(foldBlock->getArgument(i),
+                                     [&](OpOperand &val) {
+                                       Operation *op = val.getOwner();
+                                       return op->getBlock() == foldBlock;
+                                     });
+  }
+
   foldFunc.setVisibility(SymbolTable::Visibility::Public);
   foldFunc->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                     UnitAttr::get(context));
+
   moduleOp.push_back(foldFunc);
+  SymbolTable symbolTable(moduleOp);
   symbolTable.insert(foldFunc);
 
+  return foldFunc;
+}
+
+void modifyComputeFunc(MLIRContext *context, OpBuilder &builder,
+                       Operation *topOp, Operation &func, Block &block,
+                       std::unordered_set<int> &constArgsIndexes,
+                       SmallVector<Type> &outputTypes,
+                       SmallVector<Value> &outputValues) {
   // the indexes of args to the folding func.
   SmallVector<int32_t> foldArgs;
   // the indexes of folded args.
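Note the layout of the __fold_buffer_ids global built above: the call globalIndexes.insert(globalIndexes.begin(), globalIndexes.size()) prepends the element count, so the emitted i64 array reads [count, id0, id1, ...]. A hypothetical runtime-side reader, assuming only that layout:

#include <cstdint>
#include <vector>

// Hypothetical consumer of the __fold_buffer_ids array emitted by
// buildFoldFunc: element 0 holds the number of ids, elements 1..count
// hold the cached-buffer ids themselves.
std::vector<int64_t> readFoldBufferIds(const int64_t *foldBufferIds) {
  const int64_t count = foldBufferIds[0];
  return std::vector<int64_t>(foldBufferIds + 1, foldBufferIds + 1 + count);
}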
@@ -631,6 +623,13 @@ void ConstantTensorFolding::runOnOperation() {
   }
   block.eraseArguments(argsToErase);
 
+  // modify the compute func signature
+  func::FuncOp computeFunc = cast<func::FuncOp>(func);
+  FunctionType computeFuncType = computeFunc.getFunctionType();
+  computeFunc.setType(FunctionType::get(context, block.getArgumentTypes(),
+                                        computeFuncType.getResults()));
+
+  auto moduleOp = dyn_cast<ModuleOp>(topOp);
   for (auto id : foldIds) {
     foldArgs.insert(foldArgs.end(), id);
   }
@@ -647,13 +646,10 @@ void ConstantTensorFolding::runOnOperation() {
 
   addGlobalI32(moduleOp, moduleOp.getLoc(), builder, "__num_orig_num_args",
                oriNumArgs);
+}
 
-  // modify the compute func signature
-  func::FuncOp computeFunc = cast<func::FuncOp>(topFunc);
-  FunctionType computeFuncType = computeFunc.getFunctionType();
-  computeFunc.setType(FunctionType::get(context, block.getArgumentTypes(),
-                                        computeFuncType.getResults()));
-
+void canonicalizeAndClean(MLIRContext *context, Operation *topOp, Block &block,
+                          func::FuncOp &foldFunc) {
   // Delete dead operations by dialects' canonicalizer
   RewritePatternSet owningPatterns(context);
   for (auto *dialect : context->getLoadedDialects())
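canonicalizeAndClean relies on the dialects' own canonicalization patterns to delete the operations that became dead after the rewiring. The diff only shows the first lines of that idiom; a sketch of the usual complete form, assuming the standard greedy rewrite driver:

#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Sketch: collect every loaded dialect's canonicalization patterns and
// apply them greedily to one op, folding and erasing dead operations.
void runCanonicalizer(mlir::MLIRContext *context, mlir::Operation *topOp) {
  mlir::RewritePatternSet owningPatterns(context);
  for (auto *dialect : context->getLoadedDialects())
    dialect->getCanonicalizationPatterns(owningPatterns);
  (void)mlir::applyPatternsAndFoldGreedily(
      topOp, mlir::FrozenRewritePatternSet(std::move(owningPatterns)));
}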
@@ -674,13 +670,57 @@ void ConstantTensorFolding::runOnOperation() {
       op.removeAttr("onednn_graph.in_const_subgraph");
     }
   }
-  for (auto &op : foldBlock->getOperations()) {
+  for (auto &op : foldFunc.front().getOperations()) {
     if (op.getAttr("onednn_graph.in_const_subgraph")) {
       op.removeAttr("onednn_graph.in_const_subgraph");
     }
   }
 }
 
+// Operate on tensors. Create fold() and compute() on module. The
+// folded weights and first-run flag is maintained by upper-level runtime.
+void ConstantTensorFolding::runOnOperation() {
+  Operation *topOp = getOperation();
+  MLIRContext *context = topOp->getContext();
+  auto &topFunc =
+      topOp->getRegions().front().getBlocks().front().getOperations().front();
+  OpBuilder builder(context);
+  Region &region = topFunc.getRegions().front();
+  Block &block = region.getBlocks().front();
+
+  std::unordered_set<int> constArgsIndexes = getConstArgsIndexes(topFunc);
+  if (constArgsIndexes.empty()) {
+    return;
+  }
+
+  postponeBroadcast(block);
+
+  SmallVector<Operation *> constOps;
+  for (Operation &op : llvm::make_early_inc_range(block)) {
+    if (isInConstantSubgraph(&op)) {
+      constOps.push_back(&op);
+    }
+  }
+
+  SmallVector<Type> inputTypes; // types of constant weights
+  // values of constant weights in original block
+  SmallVector<Value> inputValues;
+  SmallVector<Type> outputTypes; // types of folded constant weights
+  // values of folded constant weights in original block
+  SmallVector<Value> outputValues;
+  getInputsAndOutputs(block, constArgsIndexes, inputTypes, inputValues,
+                      outputTypes, outputValues);
+
+  func::FuncOp foldFunc =
+      buildFoldFunc(context, builder, topOp, constOps, inputTypes, inputValues,
+                    outputTypes, outputValues);
+
+  modifyComputeFunc(context, builder, topOp, topFunc, block, constArgsIndexes,
+                    outputTypes, outputValues);
+
+  canonicalizeAndClean(context, topOp, block, foldFunc);
+}
+
 std::unique_ptr<Pass> createConstantTensorFoldingPass() {
   return std::make_unique<ConstantTensorFolding>();
 }
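The factory at the end remains the only public entry point. A minimal sketch of driving the pass over a module (standard PassManager boilerplate; runFolding is a hypothetical helper, only the factory name comes from the code above):

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Sketch: schedule the constant-tensor-folding pass on a ModuleOp and run it.
mlir::LogicalResult runFolding(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(createConstantTensorFoldingPass());
  return pm.run(module);
}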
