
Commit 94f2813

Support complex topo

1 parent 4363915 commit 94f2813

File tree

3 files changed: +128 -88 lines changed

lib/gc/Transforms/CST.cpp (+98 -63)
@@ -30,7 +30,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"

-#include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp"
+// #include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp"

 namespace mlir {
 namespace gc {
@@ -300,12 +300,12 @@ static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;
 // void *allocator(size_t size) { return std::aligned_alloc(64, size); }
 // void deallocator(void *ptr) { std::free(ptr); }

-std::shared_ptr<ConstCacheProxy> createConstCacheProxy(size_t size) {
-  // simply allocate buffer and return
-  std::shared_ptr<void> base = std::shared_ptr<void>{
-      std::aligned_alloc(64, size), [](void *p) { std::free(p); }};
-  return std::make_shared<ConstCacheProxy>(base, base.get(), size, true);
-}
+// std::shared_ptr<ConstCacheProxy> createConstCacheProxy(size_t size) {
+//   // simply allocate buffer and return
+//   std::shared_ptr<void> base = std::shared_ptr<void>{
+//       std::aligned_alloc(64, size), [](void *p) { std::free(p); }};
+//   return std::make_shared<ConstCacheProxy>(base, base.get(), size, true);
+// }

 size_t divideAndCeil(size_t x, size_t y) { return (x + y - 1) / y; }
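Note: the helper disabled above wraps std::aligned_alloc in a std::shared_ptr with a custom deleter, since memory from aligned_alloc must be released with std::free, not delete. A minimal standalone sketch of that idiom (independent of ConstCacheProxy, whose header is no longer included):

#include <cstdlib>
#include <memory>

// Standalone sketch of the allocation idiom from the now-disabled helper:
// a 64-byte-aligned buffer owned by a shared_ptr with a free() deleter.
// std::aligned_alloc requires the size to be a multiple of the alignment,
// which is why callers round sizes up with divideAndCeil(size, 64) * 64.
std::shared_ptr<void> makeAlignedBuffer(std::size_t size) {
  return std::shared_ptr<void>{std::aligned_alloc(64, size),
                               [](void *p) { std::free(p); }};
}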

@@ -329,12 +329,12 @@ struct constGraphTensorCacheManager {
       totalSize += divideAndCeil(buffersSize[i], 64) * 64;
     }
     llvm::dbgs() << "Alloc total size: " << totalSize << '\n';
-    auto base = createConstCacheProxy(totalSize);
+    // auto base = createConstCacheProxy(totalSize);
     std::vector<uint64_t> globalIds(buffersSize.size());
     size_t offset = 0;
     for (size_t i = 0; i < buffersSize.size(); i++) {
       llvm::dbgs() << "Alloc offset: " << offset << '\n';
-      regCachedTensor(cachedTensorGlobalId, base, offset);
+      // regCachedTensor(cachedTensorGlobalId, base, offset);
       globalIds[i] = cachedTensorGlobalId;
       ++cachedTensorGlobalId;
       offset += divideAndCeil(buffersSize[i], 64) * 64;
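Note: each cached buffer is padded to a 64-byte boundary before being packed into the single shared allocation. A small self-check of the arithmetic, with hypothetical buffer sizes:

#include <cassert>
#include <cstddef>

std::size_t divideAndCeil(std::size_t x, std::size_t y) {
  return (x + y - 1) / y;
}

int main() {
  // Hypothetical sizes 100 and 200, each rounded up to a 64-byte boundary.
  assert(divideAndCeil(100, 64) * 64 == 128); // buffer 0 at offset 0
  assert(divideAndCeil(200, 64) * 64 == 256); // buffer 1 at offset 128
  // totalSize = 128 + 256 = 384 bytes for the shared allocation
  return 0;
}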
@@ -431,11 +431,11 @@ void CST::runOnOperation() {
   // values of folded constant weights in original block
   SmallVector<Value> outputValues;
   Value v;
-  // TODO: solve complicated topology. Currently we only handle simple topology
-  // where one constant weight input will and only will produce one constant
-  // output and each constant weight only contributes to one constant output.
+  // Support complicated topology.
   for (size_t id = 0; id < block.getNumArguments(); ++id) {
     if (constArgsIndexes.count(id) == 1) {
+      // The constant ops are all single-input single-output.
+      bool simpleTopo = true;
       auto arg = block.getArgument(id);
       if (!isa<TensorType>(arg.getType())) {
         continue;
@@ -444,54 +444,72 @@ void CST::runOnOperation() {
       v = dyn_cast<Value>(arg);
       inputValues.push_back(v);
       SmallVector<Value> valuesOnTheWay = {v}; // the constant tensors
+      std::deque<Value> dq;
+      dq.push_back(v);
       // For v -> pack1 -> pack2 -> matmul, we need the type of output of pack2
-      while (!v.getUsers().empty()) {
-        // v.getUsers().size() should be 1
-        Operation *user = *(v.getUsers().begin());
-        // If user is not const or user has multiple operand, we reach the end
-        if (!isInConstantSubgraph(user) || !singleOperand(user)) {
-          outputTypes.push_back(v.getType());
-          outputValues.push_back(v);
-          break;
+      while (!dq.empty()) {
+        v = dq.front();
+        dq.pop_front();
+        // if the children ops of v are not all constant, we end at v
+        if (std::any_of(v.getUsers().begin(), v.getUsers().end(),
+                        [](Operation *child) {
+                          return !isInConstantSubgraph(child);
+                        })) {
+          if (std::find(outputValues.begin(), outputValues.end(), v) ==
+              outputValues.end()) {
+            outputTypes.push_back(v.getType());
+            outputValues.push_back(v);
+          }
+          continue;
+        }
+        if (!v.hasOneUse()) {
+          simpleTopo = false;
+        }
+        // the children ops of v are all constant, we push their results to
+        // queue
+        for (Operation *child : v.getUsers()) {
+          if (!singleOperand(child) || child->getResults().size() > 1) {
+            simpleTopo = false;
+          }
+          for (OpResult result : child->getResults()) {
+            auto r = dyn_cast<Value>(result);
+            dq.push_back(r);
+            valuesOnTheWay.push_back(r);
+          }
         }
-        // user should has only 1 output value
-        OpResult result = *(user->result_begin());
-        v = dyn_cast<Value>(result);
-        valuesOnTheWay.push_back(v);
       }

       // If data size of outputValue is too greater than size of inputValue, do
       // not fold it. Compare data size changes during traverse to find the last
       // op that satisfies this condition.
-      int64_t initSize =
-          getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[0].getType()));
-      if (!isa<TensorType>(outputTypes.back()) ||
-          initSize * DATA_SIZE_EXPANDING_THRESHOLD <
-              getTensorSize(dyn_cast<TensorType>(outputTypes.back()))) {
-        size_t lastIdx = 0;
-        for (size_t i = 1; i < valuesOnTheWay.size(); ++i) {
-          int64_t size =
-              getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[i].getType()));
-          if (initSize * DATA_SIZE_EXPANDING_THRESHOLD > size) {
-            lastIdx = i;
+      if (simpleTopo) {
+        int64_t initSize =
+            getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[0].getType()));
+        if (!isa<TensorType>(outputTypes.back()) ||
+            initSize * DATA_SIZE_EXPANDING_THRESHOLD <
+                getTensorSize(dyn_cast<TensorType>(outputTypes.back()))) {
+          size_t lastIdx = 0;
+          for (size_t i = 1; i < valuesOnTheWay.size(); ++i) {
+            int64_t size = getTensorSize(
+                dyn_cast<TensorType>(valuesOnTheWay[i].getType()));
+            if (initSize * DATA_SIZE_EXPANDING_THRESHOLD > size) {
+              lastIdx = i;
+            }
+          }
+          if (lastIdx == 0) { // no suitable value found
+            inputTypes.pop_back();
+            outputTypes.pop_back();
+            inputValues.pop_back();
+            outputValues.pop_back();
+            constArgsIndexes.erase(id);
+          } else {
+            outputTypes.back() = valuesOnTheWay[lastIdx].getType();
+            outputValues.back() = valuesOnTheWay[lastIdx];
           }
-        }
-        if (lastIdx == 0) { // no suitable value found
-          inputTypes.pop_back();
-          outputTypes.pop_back();
-          inputValues.pop_back();
-          outputValues.pop_back();
-          constArgsIndexes.erase(id);
-        } else {
-          outputTypes.back() = valuesOnTheWay[lastIdx].getType();
-          outputValues.back() = valuesOnTheWay[lastIdx];
         }
       }
     }
   }
-  if (inputTypes.size() != outputTypes.size()) {
-    return;
-  }

   FunctionType foldFuncType =
       FunctionType::get(context, inputTypes, outputTypes);
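Note: the old loop walked a single-use chain and stopped at the first non-constant user; the replacement is a breadth-first search over all users, so one constant argument may reach several constant outputs. A stripped-down sketch of the traversal shape, with a hypothetical Node standing in for mlir::Value (frontier deduplication, i.e. the std::find in the pass, is omitted):

#include <deque>
#include <vector>

// Hypothetical stand-in for mlir::Value/Operation, to show the shape of
// the new traversal in isolation.
struct Node {
  bool isConstant = false;
  std::vector<Node *> users;
};

// BFS from a constant argument. A value whose users are not all constant
// is a frontier (output) value; otherwise all its users are pushed.
// simpleTopo stays true only for a pure single-use chain, mirroring the
// flag that gates the data-size heuristic above.
std::vector<Node *> collectConstantFrontier(Node *arg, bool &simpleTopo) {
  simpleTopo = true;
  std::vector<Node *> frontier;
  std::deque<Node *> dq = {arg};
  while (!dq.empty()) {
    Node *v = dq.front();
    dq.pop_front();
    bool allConstant = true;
    for (Node *u : v->users)
      allConstant = allConstant && u->isConstant;
    if (!allConstant) {
      frontier.push_back(v); // the constant subgraph ends at v
      continue;
    }
    if (v->users.size() != 1)
      simpleTopo = false; // fan-out (or a dead end) breaks the simple chain
    for (Node *u : v->users)
      dq.push_back(u);
  }
  return frontier;
}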
@@ -548,30 +566,34 @@ void CST::runOnOperation() {
   moduleOp.push_back(foldFunc);
   symbolTable.insert(foldFunc);

+  // the indexes of args to the folding func.
   SmallVector<int32_t> foldArgs;
+  // the indexes of folded args.
   SmallVector<int32_t> foldIds;
+  // the indexes of args to the computing func.
   SmallVector<int32_t> computeArgs;

   // modify the BlockArguments of block
   size_t oriNumArgs = block.getNumArguments();
-  size_t argIdx = 0;
+  // Add the folded args to the end of BlockArguments list
+  for (size_t id = 0; id < outputValues.size(); ++id) {
+    auto loc = block.getArgument(id).getLoc();
+    BlockArgument foldArg =
+        block.insertArgument(oriNumArgs + id, outputTypes[id], loc);
+    outputValues[id].replaceUsesWithIf(foldArg, [&](OpOperand &val) {
+      Operation *op = val.getOwner();
+      return op->getBlock() == &block;
+    });
+    foldIds.push_back(id + oriNumArgs);
+  }
+  // Erase the operations on constant args
   for (size_t id = 0; id < oriNumArgs; ++id) {
     if (constArgsIndexes.count(id) == 1) {
       foldArgs.push_back(id);
-      foldIds.push_back(argIdx + oriNumArgs);
-      computeArgs.push_back(argIdx + oriNumArgs);
-      auto loc = block.getArgument(id).getLoc();
-      BlockArgument foldArg =
-          block.insertArgument(id, outputTypes[argIdx], loc);
-      outputValues[argIdx].replaceUsesWithIf(foldArg, [&](OpOperand &val) {
-        Operation *op = val.getOwner();
-        return op->getBlock() == &block;
-      });
-
       std::deque<Value> dq;
       SmallVector<Operation *> opsToErase;
       std::unordered_set<Operation *> opsToEraseSet;
-      dq.push_back(block.getArgument(id + 1));
+      dq.push_back(block.getArgument(id));
       while (!dq.empty()) {
         Value v = dq.front();
         dq.pop_front();
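Note: the key change above is that folded arguments are now appended after the original ones (at index oriNumArgs + id) instead of being interleaved, so the original argument indexes stay stable; the constant args are then erased in one batch (next hunk) rather than one at a time with shifted indexes (the old id + 1 and argIdx bookkeeping). A condensed sketch of this append-then-mask pattern against the MLIR Block API (the helper and its parameter types are hypothetical):

#include <unordered_set>
#include "mlir/IR/Block.h"
#include "llvm/ADT/BitVector.h"

using namespace mlir;

// Sketch: append folded args after the original ones, then erase the
// constant args in one batch. Original indexes stay valid until the
// single eraseArguments() call, so no per-erase index shifting is needed.
static void appendThenErase(Block &block, ArrayRef<Type> outputTypes,
                            MutableArrayRef<Value> outputValues,
                            const std::unordered_set<size_t> &constArgIdx) {
  size_t oriNumArgs = block.getNumArguments();
  for (size_t i = 0; i < outputTypes.size(); ++i) {
    BlockArgument foldArg = block.insertArgument(
        oriNumArgs + i, outputTypes[i], block.getArgument(i).getLoc());
    // Redirect only in-block uses of the folded value to the new argument.
    outputValues[i].replaceUsesWithIf(foldArg, [&](OpOperand &use) {
      return use.getOwner()->getBlock() == &block;
    });
  }
  // A sized BitVector (default false) is equivalent to the push_back loop
  // in the diff: true marks an argument for deletion.
  llvm::BitVector argsToErase(block.getNumArguments(), false);
  for (size_t id = 0; id < oriNumArgs; ++id)
    if (constArgIdx.count(id))
      argsToErase[id] = true;
  block.eraseArguments(argsToErase);
}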
@@ -586,16 +608,26 @@ void CST::runOnOperation() {
           opsToEraseSet.insert(op);
         }
       }
-
      for (auto it = opsToErase.rbegin(); it != opsToErase.rend(); ++it) {
        (*it)->erase();
      }
-      block.eraseArgument(id + 1);
-      ++argIdx;
     } else {
       computeArgs.push_back(id);
     }
   }
+  // Erase the constant args in BlockArguments list
+  llvm::BitVector argsToErase;
+  for (size_t id = 0; id < oriNumArgs; ++id) {
+    if (constArgsIndexes.count(id) == 1) {
+      argsToErase.push_back(true);
+    } else {
+      argsToErase.push_back(false);
+    }
+  }
+  for (size_t id = 0; id < outputValues.size(); ++id) {
+    argsToErase.push_back(false);
+  }
+  block.eraseArguments(argsToErase);

   for (auto id : foldIds) {
     foldArgs.insert(foldArgs.end(), id);
@@ -604,6 +636,9 @@
   addGlobalI32Array(moduleOp, moduleOp.getLoc(), builder, "__fold_args",
                     foldArgs);

+  for (auto id : foldIds) {
+    computeArgs.insert(computeArgs.end(), id);
+  }
   computeArgs.insert(computeArgs.begin(), computeArgs.size());
   addGlobalI32Array(moduleOp, moduleOp.getLoc(), builder, "__compute_args",
                     computeArgs);
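Note: both globals use a length-prefixed layout: element 0 is the count of indexes that follow (computeArgs.insert(computeArgs.begin(), computeArgs.size()) prepends it). For the two-layer MLP test below, with five original args of which args 1-4 are constant, the folded args get ids 5-8, giving __fold_args = [8, 1, 2, 3, 4, 5, 6, 7, 8] and __compute_args = [5, 0, 5, 6, 7, 8]. A sketch of how a runtime could consume such an array; the decoder itself is hypothetical, only the layout comes from the pass:

#include <cstdint>
#include <vector>

// Hypothetical decoder for the length-prefixed __fold_args /
// __compute_args globals: data[0] is the count, data[1..count] are the
// argument indexes.
std::vector<int32_t> decodeArgIndexes(const int32_t *data) {
  return {data + 1, data + 1 + data[0]};
}

// e.g. int32_t computeArgs[] = {5, 0, 5, 6, 7, 8};
// decodeArgIndexes(computeArgs) yields {0, 5, 6, 7, 8}.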

test/gc/Transforms/test_constant_weights_folding-1.mlir (+21 -22)
@@ -19,32 +19,31 @@ module {

 // CHECK: cpuruntime.printf
 // CHECK: linalg.add
-// CHECK: linalg.add
 // CHECK: func.func @fold
 // CHECK: linalg.add
 // CHECK: linalg.add
+// CHECK: linalg.add

 // COM: expected output:
 // COM: module {
-// COM:   llvm.mlir.global constant @__num_orig_num_args(4 : i32) : i32
-// COM:   llvm.mlir.global constant @__fold_buffer_ids(dense<[2, 114514, 1919810]> : tensor<3 x i64>) : !llvm.array<3 x i64>
-// COM:   // a,b, foldedA,foldedB
-// COM:   llvm.mlir.global constant @__fold_args(dense<[4, 0, 1, 4, 5]> : tensor<5xi32>) : !llvm.array<5 x i32>
-// COM:   // foldedA, foldedB, c, d
-// COM:   llvm.mlir.global constant @__compute_args(dense<[4, 4, 5, 2, 3]> : tensor<5xi32>) : !llvm.array<5 x i32>
-// COM:   func.func @fold(%a: tensor<128xf32>, %b: tensor<128xf32>) -> (tensor<128xf32>, tensor<128xf32>) attributes { llvm.emit_c_interface } {
-// COM:     %c0 = arith.constant 0 : index
-// COM:     cpuruntime.printf "HI%zu\n" %c0 : index
-// COM:     %out = tensor.empty() : tensor<128xf32>
-// COM:     %2 = linalg.add ins(%a, %a : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32>
-// COM:     %out2 = tensor.empty() : tensor<128xf32>
-// COM:     %3 = linalg.add ins(%b, %b : tensor<128xf32>,tensor<128xf32>) outs(%out2 : tensor<128xf32>) -> tensor<128xf32>
-// COM:     return %2, %3 : tensor<128xf32>, tensor<128xf32>
-// COM:   }
-// COM:   func.func @compute(%ax2: tensor<128xf32>, %bx2: tensor<128xf32>, %c: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface } {
-// COM:     %out = tensor.empty() : tensor<128xf32>
-// COM:     %2 = linalg.add ins(%ax2, %bx2 : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32>
-// COM:     %d = linalg.add ins(%2, %c : tensor<128xf32>,tensor<128xf32>) outs(%2 : tensor<128xf32>) -> tensor<128xf32>
-// COM:     return %d : tensor<128xf32>
-// COM:   }
+// COM:   llvm.mlir.global external constant @__num_orig_num_args(3 : i32) {addr_space = 0 : i32} : i32
+// COM:   llvm.mlir.global external constant @__compute_args(dense<[2, 2, 3]> : tensor<3xi32>) {addr_space = 0 : i32} : !llvm.array<3 x i32>
+// COM:   llvm.mlir.global external constant @__fold_args(dense<[3, 0, 1, 3]> : tensor<4xi32>) {addr_space = 0 : i32} : !llvm.array<4 x i32>
+// COM:   llvm.mlir.global external constant @__fold_buffer_ids(dense<[1, 0]> : tensor<2xi64>) {addr_space = 0 : i32} : !llvm.array<2 x i64>
+// COM:   func.func @entry(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) -> tensor<128xf32> attributes {llvm.emit_c_interface, onednn_graph.const_args = [0 : i32, 1 : i32]} {
+// COM:     %c0 = arith.constant 0 : index
+// COM:     cpuruntime.printf "HI%zu\0A" %c0 : index
+// COM:     %0 = tensor.empty() : tensor<128xf32>
+// COM:     %1 = linalg.add ins(%arg1, %arg0 : tensor<128xf32>, tensor<128xf32>) outs(%0 : tensor<128xf32>) -> tensor<128xf32>
+// COM:     return %1 : tensor<128xf32>
+// COM:   }
+// COM:   func.func @fold(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) -> tensor<128xf32> attributes {llvm.emit_c_interface} {
+// COM:     %0 = tensor.empty() : tensor<128xf32>
+// COM:     %1 = linalg.add ins(%arg0, %arg0 : tensor<128xf32>, tensor<128xf32>) outs(%0 : tensor<128xf32>) -> tensor<128xf32>
+// COM:     %2 = tensor.empty() : tensor<128xf32>
+// COM:     %3 = linalg.add ins(%arg1, %arg1 : tensor<128xf32>, tensor<128xf32>) outs(%2 : tensor<128xf32>) -> tensor<128xf32>
+// COM:     %4 = tensor.empty() : tensor<128xf32>
+// COM:     %5 = linalg.add ins(%1, %3 : tensor<128xf32>, tensor<128xf32>) outs(%4 : tensor<128xf32>) -> tensor<128xf32>
+// COM:     return %5 : tensor<128xf32>
+// COM:   }
 // COM: }

test/gc/Transforms/test_constant_weights_folding.mlir (+9 -3)
@@ -9,7 +9,7 @@
 module {
 // COM: A two-layer mlp. arg0: input feature. arg1: weight of #1 linear. arg2: bias of #1 linear.
 // COM: arg3: weight of #2 linear. arg4: bias of #2 linear.
-func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<512x256xbf16>, %arg2: tensor<256xbf16>, %arg3: tensor<256x1024xbf16>, %arg4: tensor<1024xbf16>) -> tensor<64x1024xbf16> attributes {onednn_graph.const_args = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} {
+func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<512x256xbf16>, %arg2: tensor<256xbf16>, %arg3: tensor<256x1024xbf16>, %arg4: tensor<1024xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, onednn_graph.const_args = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} {
   %1 = tensor.empty() : tensor<2x16x32x32xbf16>
   %packed_arg0 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<64x512xbf16> -> tensor<2x16x32x32xbf16>
   %2 = tensor.empty() : tensor<8x16x32x32xbf16>
@@ -71,6 +71,12 @@ module {
 // CHECK: func.func @fold
 // CHECK: arith.extf
 // CHECK: arith.truncf
+
 // COM: expected output:
-// COM: func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<8x16x16x32x2xbf16>, %arg2: tensor<8x32xbf16>, %arg3: tensor<32x8x16x32x2xbf16>, %arg4: tensor<32x32xbf16>) -> tensor<64x1024xbf16>
-// COM: func.func @fold(%arg0: tensor<512x256xbf16>, %arg1: tensor<256xbf16>, %arg2: tensor<256x1024xbf16>, %arg3: tensor<1024xbf16>) -> (tensor<8x16x16x32x2xbf16>, tensor<8x32xbf16>, tensor<32x8x16x32x2xbf16>, tensor<32x32xbf16>)
+// COM: module {
+// COM:   llvm.mlir.global external constant @__num_orig_num_args(5 : i32) {addr_space = 0 : i32} : i32
+// COM:   llvm.mlir.global external constant @__compute_args(dense<[5, 0, 5, 6, 7, 8]> : tensor<6xi32>) {addr_space = 0 : i32} : !llvm.array<6 x i32>
+// COM:   llvm.mlir.global external constant @__fold_args(dense<[8, 1, 2, 3, 4, 5, 6, 7, 8]> : tensor<9xi32>) {addr_space = 0 : i32} : !llvm.array<9 x i32>
+// COM:   llvm.mlir.global external constant @__fold_buffer_ids(dense<[4, 0, 1, 2, 3]> : tensor<5xi64>) {addr_space = 0 : i32} : !llvm.array<5 x i64>
+// COM:   func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<8x16x16x32x2xbf16>, %arg2: tensor<8x32xbf16>, %arg3: tensor<32x8x16x32x2xbf16>, %arg4: tensor<32x32xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, onednn_graph.const_args = [1 : i32, 2 : i32, 3 : i32, 4 : i32]}
+// COM:   func.func @fold(%arg0: tensor<512x256xbf16>, %arg1: tensor<256xbf16>, %arg2: tensor<256x1024xbf16>, %arg3: tensor<1024xbf16>) -> (tensor<8x16x16x32x2xbf16>, tensor<8x32xbf16>, tensor<32x8x16x32x2xbf16>, tensor<32x32xbf16>) attributes {llvm.emit_c_interface}
