
Commit f9c2425: Support MemRef args
1 parent: 75fcaed

1 file changed

lib/gc/Transforms/ConstantTensorFolding.cpp

Lines changed: 20 additions & 11 deletions
@@ -17,6 +17,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -53,6 +54,8 @@ bool isInConstantSubgraph(Operation *op) {
   auto opNamespace = op->getDialect()->getNamespace();
   if (opNamespace == linalg::LinalgDialect::getDialectNamespace() ||
       opNamespace == tensor::TensorDialect::getDialectNamespace() ||
+      opNamespace ==
+          bufferization::BufferizationDialect::getDialectNamespace() ||
       opNamespace == arith::ArithDialect::getDialectNamespace()) {
     if (op->getAttr("onednn_graph.in_const_subgraph")) {
       return true;
@@ -61,7 +64,7 @@ bool isInConstantSubgraph(Operation *op) {
   return false;
 }
 
-int64_t getTensorSize(TensorType t) {
+template <typename T> int64_t getDataSize(T t) {
   Type eleType = t.getElementType();
   unsigned bitWidth = eleType.getIntOrFloatBitWidth() / 8; // bytes
   ArrayRef<int64_t> shape = t.getShape();
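Note on the change from getTensorSize to the templated getDataSize: TensorType and MemRefType both derive from mlir::ShapedType, so the same body (element type, bit width, shape) compiles for either instantiation. A minimal sketch of the equivalent computation written once against ShapedType (the helper name is ours, not from the commit):

int64_t getShapedDataSize(ShapedType t) {
  // Bytes per element; like the templated helper, this assumes an
  // int or float element type with a byte-aligned bit width.
  int64_t size = t.getElementTypeBitWidth() / 8;
  for (int64_t dim : t.getShape()) // static shape assumed
    size *= dim;
  return size;
}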
@@ -72,6 +75,16 @@ int64_t getTensorSize(TensorType t) {
   return size;
 }
 
+int64_t getValueSize(Value v) {
+  if (isa<TensorType>(v.getType())) {
+    auto t = dyn_cast<TensorType>(v.getType());
+    return getDataSize<TensorType>(t);
+  } else {
+    auto t = dyn_cast<MemRefType>(v.getType());
+    return getDataSize<MemRefType>(t);
+  }
+}
+
 /// @brief op has only one operand, or operands of op are one same value, or
 /// operands of op are one same value or from tensor.EmptyOp.
 /// @param op
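The new getValueSize dispatches on the value's type. As written, its else branch assumes every non-tensor value is a memref; for any other type, dyn_cast<MemRefType> yields a null type and the size queries inside getDataSize would assert. A defensively written variant might look like this sketch (hypothetical, not part of the commit):

int64_t getValueSizeChecked(Value v) {
  if (auto t = dyn_cast<TensorType>(v.getType()))
    return getDataSize<TensorType>(t);
  if (auto m = dyn_cast<MemRefType>(v.getType()))
    return getDataSize<MemRefType>(m);
  return 0; // neither tensor nor memref; caller treats 0 as "no data"
}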
@@ -465,7 +478,7 @@ void getInputsAndOutputs(Block &block,
     // The constant ops are all single-input single-output.
     bool simpleTopo = true;
     auto arg = block.getArgument(id);
-    if (!isa<TensorType>(arg.getType())) {
+    if (!isa<TensorType>(arg.getType()) && !isa<MemRefType>(arg.getType())) {
       continue;
     }
     inputTypes.push_back(arg.getType());
@@ -511,15 +524,12 @@ void getInputsAndOutputs(Block &block,
     // not fold it. Compare data size changes during traverse to find the last
     // op that satisfies this condition.
     if (simpleTopo) {
-      int64_t initSize =
-          getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[0].getType()));
-      if (!isa<TensorType>(outputTypes.back()) ||
-          initSize * DATA_SIZE_EXPANDING_THRESHOLD <
-              getTensorSize(dyn_cast<TensorType>(outputTypes.back()))) {
+      int64_t initSize = getValueSize(valuesOnTheWay[0]);
+      if (initSize * DATA_SIZE_EXPANDING_THRESHOLD <
+          getValueSize(valuesOnTheWay.back())) {
         size_t lastIdx = 0;
         for (size_t i = 1; i < valuesOnTheWay.size(); ++i) {
-          int64_t size = getTensorSize(
-              dyn_cast<TensorType>(valuesOnTheWay[i].getType()));
+          int64_t size = getValueSize(valuesOnTheWay[i]);
           if (initSize * DATA_SIZE_EXPANDING_THRESHOLD > size) {
             lastIdx = i;
           }
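The rewritten check keeps the original folding policy: if the chain's final value is more than DATA_SIZE_EXPANDING_THRESHOLD times larger than its first value, folding stops at the last intermediate value still under that bound. As a toy trace (all numbers assumed for illustration; the real threshold is a constant defined in the source): with a threshold of 8 and chain sizes {64, 256, 1024, 4096} bytes, the bound is 64 * 8 = 512. The final size 4096 exceeds 512, so the scan runs: i = 1 gives 256 < 512 so lastIdx = 1, while i = 2 (1024) and i = 3 (4096) leave it unchanged, so only the prefix of the chain up to index 1 is folded.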
@@ -574,8 +584,7 @@ func::FuncOp buildFoldFunc(MLIRContext *context, OpBuilder &builder,
   for (Value &tensor : outputValuesInFold) {
     LLVM_DEBUG(llvm::dbgs()
                << "Allocate buffer for tensor: " << tensor << "\n");
-    buffersSize.push_back(
-        getTensorSize(dyn_cast<TensorType>(tensor.getType())));
+    buffersSize.push_back(getValueSize(tensor));
   }
   auto manager = ConstGraphTensorCacheManager::get();
   SmallVector<int64_t> globalIndexes;
