 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -53,6 +54,8 @@ bool isInConstantSubgraph(Operation *op) {
   auto opNamespace = op->getDialect()->getNamespace();
   if (opNamespace == linalg::LinalgDialect::getDialectNamespace() ||
       opNamespace == tensor::TensorDialect::getDialectNamespace() ||
+      opNamespace ==
+          bufferization::BufferizationDialect::getDialectNamespace() ||
       opNamespace == arith::ArithDialect::getDialectNamespace()) {
     if (op->getAttr("onednn_graph.in_const_subgraph")) {
       return true;
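
Note: with the bufferization namespace admitted, the predicate is a dialect allow-list gated by the marker attribute. For reference, a standalone restatement (a sketch, not part of the patch; isInConstantSubgraphSketch is a hypothetical name):

// Sketch only: an op is in the constant subgraph iff its dialect is in the
// allow-list and it carries the onednn_graph.in_const_subgraph attribute.
bool isInConstantSubgraphSketch(Operation *op) {
  StringRef ns = op->getDialect()->getNamespace();
  bool allowed =
      ns == linalg::LinalgDialect::getDialectNamespace() ||
      ns == tensor::TensorDialect::getDialectNamespace() ||
      ns == bufferization::BufferizationDialect::getDialectNamespace() ||
      ns == arith::ArithDialect::getDialectNamespace();
  return allowed && op->hasAttr("onednn_graph.in_const_subgraph");
}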
@@ -61,7 +64,7 @@ bool isInConstantSubgraph(Operation *op) {
   return false;
 }
 
-int64_t getTensorSize(TensorType t) {
+template <typename T> int64_t getDataSize(T t) {
   Type eleType = t.getElementType();
   unsigned bitWidth = eleType.getIntOrFloatBitWidth() / 8; // bytes
   ArrayRef<int64_t> shape = t.getShape();
@@ -72,6 +75,16 @@ int64_t getTensorSize(TensorType t) {
   return size;
 }
 
+int64_t getValueSize(Value v) {
+  if (isa<TensorType>(v.getType())) {
+    auto t = dyn_cast<TensorType>(v.getType());
+    return getDataSize<TensorType>(t);
+  } else {
+    auto t = dyn_cast<MemRefType>(v.getType());
+    return getDataSize<MemRefType>(t);
+  }
+}
+
 /// @brief op has only one operand, or operands of op are one same value, or
 /// operands of op are one same value or from tensor.EmptyOp.
 /// @param op
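
Since TensorType and MemRefType both derive from ShapedType, the tensor/memref dispatch in getValueSize could plausibly be collapsed into a single non-template helper. A minimal sketch under the same assumptions as getDataSize above (static shapes, int/float element types); getValueSizeViaShapedType is a hypothetical name:

// Sketch only: relies on TensorType and MemRefType sharing the ShapedType
// interface; like getDataSize, it assumes every dimension is static and the
// element type has a known int/float bit width.
int64_t getValueSizeViaShapedType(Value v) {
  auto t = cast<ShapedType>(v.getType());
  int64_t size = t.getElementTypeBitWidth() / 8; // bytes per element
  for (int64_t dim : t.getShape())
    size *= dim;
  return size;
}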
@@ -465,7 +478,7 @@ void getInputsAndOutputs(Block &block,
     // The constant ops are all single-input single-output.
     bool simpleTopo = true;
     auto arg = block.getArgument(id);
-    if (!isa<TensorType>(arg.getType())) {
+    if (!isa<TensorType>(arg.getType()) && !isa<MemRefType>(arg.getType())) {
       continue;
     }
     inputTypes.push_back(arg.getType());
@@ -511,15 +524,12 @@ void getInputsAndOutputs(Block &block,
     // not fold it. Compare data size changes during traverse to find the last
     // op that satisfies this condition.
     if (simpleTopo) {
-      int64_t initSize =
-          getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[0].getType()));
-      if (!isa<TensorType>(outputTypes.back()) ||
-          initSize * DATA_SIZE_EXPANDING_THRESHOLD <
-              getTensorSize(dyn_cast<TensorType>(outputTypes.back()))) {
+      int64_t initSize = getValueSize(valuesOnTheWay[0]);
+      if (initSize * DATA_SIZE_EXPANDING_THRESHOLD <
+          getValueSize(valuesOnTheWay.back())) {
         size_t lastIdx = 0;
         for (size_t i = 1; i < valuesOnTheWay.size(); ++i) {
-          int64_t size = getTensorSize(
-              dyn_cast<TensorType>(valuesOnTheWay[i].getType()));
+          int64_t size = getValueSize(valuesOnTheWay[i]);
           if (initSize * DATA_SIZE_EXPANDING_THRESHOLD > size) {
             lastIdx = i;
           }
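
For intuition: the scan above records the last value on the constant chain whose size stays within initSize * DATA_SIZE_EXPANDING_THRESHOLD, so folding is cut off before the chain expands past that budget. A self-contained sketch of the same scan with made-up sizes and a made-up threshold of 8 (the real constant's value is not shown in this hunk):

// Standalone illustration of the cut-off scan; all numbers are invented.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t kThreshold = 8; // stand-in for DATA_SIZE_EXPANDING_THRESHOLD
  std::vector<int64_t> sizes = {64, 128, 256, 4096}; // value sizes in bytes
  int64_t initSize = sizes[0];
  size_t lastIdx = 0;
  for (size_t i = 1; i < sizes.size(); ++i)
    if (initSize * kThreshold > sizes[i]) // still within the expansion budget
      lastIdx = i;
  // Folding would stop after sizes[lastIdx]; here 4096 > 64 * 8, so lastIdx == 2.
  std::printf("lastIdx = %zu\n", lastIdx);
  return 0;
}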
@@ -574,8 +584,7 @@ func::FuncOp buildFoldFunc(MLIRContext *context, OpBuilder &builder,
   for (Value &tensor : outputValuesInFold) {
     LLVM_DEBUG(llvm::dbgs()
                << "Allocate buffer for tensor: " << tensor << "\n");
-    buffersSize.push_back(
-        getTensorSize(dyn_cast<TensorType>(tensor.getType())));
+    buffersSize.push_back(getValueSize(tensor));
   }
   auto manager = ConstGraphTensorCacheManager::get();
   SmallVector<int64_t> globalIndexes;