 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
 
-#include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp"
+// #include "gc/ExecutionEngine/CPURuntime/ConstantCache.hpp"
 
 namespace mlir {
 namespace gc {
@@ -300,12 +300,12 @@ static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8;
 // void *allocator(size_t size) { return std::aligned_alloc(64, size); }
 // void deallocator(void *ptr) { std::free(ptr); }
 
-std::shared_ptr<ConstCacheProxy> createConstCacheProxy(size_t size) {
-  // simply allocate buffer and return
-  std::shared_ptr<void> base = std::shared_ptr<void>{
-      std::aligned_alloc(64, size), [](void *p) { std::free(p); }};
-  return std::make_shared<ConstCacheProxy>(base, base.get(), size, true);
-}
+// std::shared_ptr<ConstCacheProxy> createConstCacheProxy(size_t size) {
+//   // simply allocate buffer and return
+//   std::shared_ptr<void> base = std::shared_ptr<void>{
+//       std::aligned_alloc(64, size), [](void *p) { std::free(p); }};
+//   return std::make_shared<ConstCacheProxy>(base, base.get(), size, true);
+// }
 
 size_t divideAndCeil(size_t x, size_t y) { return (x + y - 1) / y; }
 
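For reference, divideAndCeil is plain round-up integer division; the cache manager below uses it to pad every buffer to a whole number of 64-byte cache lines. A minimal standalone sketch (standard C++ only, names other than divideAndCeil are illustrative):

    #include <cstddef>
    #include <cstdio>

    // Round-up integer division, as in the pass.
    static size_t divideAndCeil(size_t x, size_t y) { return (x + y - 1) / y; }

    int main() {
      // A 100-byte tensor occupies two 64-byte cache lines, i.e. 128 bytes.
      std::printf("%zu\n", divideAndCeil(100, 64) * 64); // prints 128
      return 0;
    }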
@@ -329,12 +329,12 @@ struct constGraphTensorCacheManager {
       totalSize += divideAndCeil(buffersSize[i], 64) * 64;
     }
     llvm::dbgs() << "Alloc total size: " << totalSize << '\n';
-    auto base = createConstCacheProxy(totalSize);
+    // auto base = createConstCacheProxy(totalSize);
     std::vector<uint64_t> globalIds(buffersSize.size());
     size_t offset = 0;
     for (size_t i = 0; i < buffersSize.size(); i++) {
       llvm::dbgs() << "Alloc offset: " << offset << '\n';
-      regCachedTensor(cachedTensorGlobalId, base, offset);
+      // regCachedTensor(cachedTensorGlobalId, base, offset);
       globalIds[i] = cachedTensorGlobalId;
       ++cachedTensorGlobalId;
       offset += divideAndCeil(buffersSize[i], 64) * 64;
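An illustrative trace of the offset bookkeeping above, with hypothetical buffer sizes (the cache-proxy calls stay commented out, so only IDs and offsets are computed; the two loops of the pass are collapsed into one here):

    #include <cstdio>
    #include <vector>

    static size_t divideAndCeil(size_t x, size_t y) { return (x + y - 1) / y; }

    int main() {
      std::vector<size_t> buffersSize = {100, 200}; // hypothetical tensor sizes
      size_t offset = 0;
      for (size_t s : buffersSize) {
        std::printf("Alloc offset: %zu\n", offset); // prints 0, then 128
        offset += divideAndCeil(s, 64) * 64;        // pad to 64-byte lines
      }
      std::printf("Alloc total size: %zu\n", offset); // prints 384
      return 0;
    }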
@@ -431,11 +431,11 @@ void CST::runOnOperation() {
   // values of folded constant weights in original block
   SmallVector<Value> outputValues;
   Value v;
-  // TODO: solve complicated topology. Currently we only handle simple topology
-  // where one constant weight input will and only will produce one constant
-  // output and each constant weight only contributes to one constant output.
+  // Support complicated topologies, not just simple chains.
   for (size_t id = 0; id < block.getNumArguments(); ++id) {
     if (constArgsIndexes.count(id) == 1) {
+      // Whether the constant ops form a simple single-input, single-output chain.
+      bool simpleTopo = true;
       auto arg = block.getArgument(id);
       if (!isa<TensorType>(arg.getType())) {
         continue;
@@ -444,54 +444,72 @@ void CST::runOnOperation() {
       v = dyn_cast<Value>(arg);
       inputValues.push_back(v);
       SmallVector<Value> valuesOnTheWay = {v}; // the constant tensors
+      std::deque<Value> dq;
+      dq.push_back(v);
       // For v -> pack1 -> pack2 -> matmul, we need the type of output of pack2
-      while (!v.getUsers().empty()) {
-        // v.getUsers().size() should be 1
-        Operation *user = *(v.getUsers().begin());
-        // If user is not const or user has multiple operands, we reach the end
-        if (!isInConstantSubgraph(user) || !singleOperand(user)) {
-          outputTypes.push_back(v.getType());
-          outputValues.push_back(v);
-          break;
+      while (!dq.empty()) {
+        v = dq.front();
+        dq.pop_front();
+        // If the child ops of v are not all constant, the subgraph ends at v.
+        if (std::any_of(v.getUsers().begin(), v.getUsers().end(),
+                        [](Operation *child) {
+                          return !isInConstantSubgraph(child);
+                        })) {
+          if (std::find(outputValues.begin(), outputValues.end(), v) ==
+              outputValues.end()) {
+            outputTypes.push_back(v.getType());
+            outputValues.push_back(v);
+          }
+          continue;
+        }
+        if (!v.hasOneUse()) {
+          simpleTopo = false;
+        }
+        // The child ops of v are all constant; push their results onto the
+        // queue.
+        for (Operation *child : v.getUsers()) {
+          if (!singleOperand(child) || child->getResults().size() > 1) {
+            simpleTopo = false;
+          }
+          for (OpResult result : child->getResults()) {
+            auto r = dyn_cast<Value>(result);
+            dq.push_back(r);
+            valuesOnTheWay.push_back(r);
+          }
         }
-        // user should have only 1 output value
-        OpResult result = *(user->result_begin());
-        v = dyn_cast<Value>(result);
-        valuesOnTheWay.push_back(v);
       }
 
       // If the data size of outputValue is much greater than that of
       // inputValue, do not fold it. Compare data size changes during the
       // traversal to find the last op that satisfies this condition.
-      int64_t initSize =
-          getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[0].getType()));
-      if (!isa<TensorType>(outputTypes.back()) ||
-          initSize * DATA_SIZE_EXPANDING_THRESHOLD <
-              getTensorSize(dyn_cast<TensorType>(outputTypes.back()))) {
-        size_t lastIdx = 0;
-        for (size_t i = 1; i < valuesOnTheWay.size(); ++i) {
-          int64_t size =
-              getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[i].getType()));
-          if (initSize * DATA_SIZE_EXPANDING_THRESHOLD > size) {
-            lastIdx = i;
+      if (simpleTopo) {
+        int64_t initSize =
+            getTensorSize(dyn_cast<TensorType>(valuesOnTheWay[0].getType()));
+        if (!isa<TensorType>(outputTypes.back()) ||
+            initSize * DATA_SIZE_EXPANDING_THRESHOLD <
+                getTensorSize(dyn_cast<TensorType>(outputTypes.back()))) {
+          size_t lastIdx = 0;
+          for (size_t i = 1; i < valuesOnTheWay.size(); ++i) {
+            int64_t size = getTensorSize(
+                dyn_cast<TensorType>(valuesOnTheWay[i].getType()));
+            if (initSize * DATA_SIZE_EXPANDING_THRESHOLD > size) {
+              lastIdx = i;
+            }
+          }
+          if (lastIdx == 0) { // no suitable value found
+            inputTypes.pop_back();
+            outputTypes.pop_back();
+            inputValues.pop_back();
+            outputValues.pop_back();
+            constArgsIndexes.erase(id);
+          } else {
+            outputTypes.back() = valuesOnTheWay[lastIdx].getType();
+            outputValues.back() = valuesOnTheWay[lastIdx];
           }
-        }
-        if (lastIdx == 0) { // no suitable value found
-          inputTypes.pop_back();
-          outputTypes.pop_back();
-          inputValues.pop_back();
-          outputValues.pop_back();
-          constArgsIndexes.erase(id);
-        } else {
-          outputTypes.back() = valuesOnTheWay[lastIdx].getType();
-          outputValues.back() = valuesOnTheWay[lastIdx];
         }
       }
     }
   }
-  if (inputTypes.size() != outputTypes.size()) {
-    return;
-  }
 
   FunctionType foldFuncType =
       FunctionType::get(context, inputTypes, outputTypes);
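The new loop is a breadth-first traversal over SSA use chains: a value becomes an output of the constant subgraph as soon as any of its users is non-constant, and simpleTopo records whether the old single-chain assumption still holds so the size-threshold heuristic is only applied in that case. A toy model of the same control flow (plain C++; the integer graph and isConst are stand-ins for MLIR values and isInConstantSubgraph, not the MLIR API):

    #include <cstdio>
    #include <deque>
    #include <vector>

    int main() {
      // Value 0 feeds 1; value 1 feeds 2 and 3; value 3's user 4 is non-constant.
      std::vector<std::vector<int>> users = {{1}, {2, 3}, {}, {4}, {}};
      std::vector<bool> isConst = {true, true, true, true, false};

      std::deque<int> dq = {0};
      std::vector<int> outputs;
      while (!dq.empty()) {
        int v = dq.front();
        dq.pop_front();
        bool endsHere = false;
        for (int u : users[v])
          endsHere |= !isConst[u]; // any non-constant user ends the subgraph
        if (endsHere) {
          outputs.push_back(v);
          continue;
        }
        for (int u : users[v])
          dq.push_back(u); // all users constant: keep walking
      }
      for (int o : outputs)
        std::printf("output: %d\n", o); // prints 3 (value 2 is a dead leaf)
      return 0;
    }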
@@ -548,30 +566,34 @@ void CST::runOnOperation() {
   moduleOp.push_back(foldFunc);
   symbolTable.insert(foldFunc);
 
+  // the indexes of args to the folding func.
   SmallVector<int32_t> foldArgs;
+  // the indexes of folded args.
   SmallVector<int32_t> foldIds;
+  // the indexes of args to the computing func.
   SmallVector<int32_t> computeArgs;
 
   // modify the BlockArguments of block
   size_t oriNumArgs = block.getNumArguments();
-  size_t argIdx = 0;
+  // Add the folded args to the end of the BlockArguments list.
+  for (size_t id = 0; id < outputValues.size(); ++id) {
+    auto loc = block.getArgument(id).getLoc();
+    BlockArgument foldArg =
+        block.insertArgument(oriNumArgs + id, outputTypes[id], loc);
+    outputValues[id].replaceUsesWithIf(foldArg, [&](OpOperand &val) {
+      Operation *op = val.getOwner();
+      return op->getBlock() == &block;
+    });
+    foldIds.push_back(id + oriNumArgs);
+  }
+  // Erase the operations on the constant args.
   for (size_t id = 0; id < oriNumArgs; ++id) {
     if (constArgsIndexes.count(id) == 1) {
       foldArgs.push_back(id);
-      foldIds.push_back(argIdx + oriNumArgs);
-      computeArgs.push_back(argIdx + oriNumArgs);
-      auto loc = block.getArgument(id).getLoc();
-      BlockArgument foldArg =
-          block.insertArgument(id, outputTypes[argIdx], loc);
-      outputValues[argIdx].replaceUsesWithIf(foldArg, [&](OpOperand &val) {
-        Operation *op = val.getOwner();
-        return op->getBlock() == &block;
-      });
-
       std::deque<Value> dq;
       SmallVector<Operation *> opsToErase;
       std::unordered_set<Operation *> opsToEraseSet;
-      dq.push_back(block.getArgument(id + 1));
+      dq.push_back(block.getArgument(id));
       while (!dq.empty()) {
         Value v = dq.front();
         dq.pop_front();
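A worked example of the new argument bookkeeping, with hypothetical numbers: suppose oriNumArgs = 4 and constArgsIndexes = {1, 3}, so two outputs get folded. The loop above appends the folded args at stable indexes instead of interleaving them (the removed code inserted at id, which shifted every later argument index):

    original args:               a0  a1* a2  a3*         (* = constant)
    after insertArgument calls:  a0  a1* a2  a3* f0  f1  (f = folded results)
    foldIds = {4, 5}

After the batch eraseArguments call in the next hunk, the block keeps a0, a2, f0, f1.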
@@ -586,16 +608,26 @@ void CST::runOnOperation() {
           opsToEraseSet.insert(op);
         }
       }
-
       for (auto it = opsToErase.rbegin(); it != opsToErase.rend(); ++it) {
         (*it)->erase();
       }
-      block.eraseArgument(id + 1);
-      ++argIdx;
     } else {
       computeArgs.push_back(id);
     }
   }
+  // Erase the constant args from the BlockArguments list.
+  llvm::BitVector argsToErase;
+  for (size_t id = 0; id < oriNumArgs; ++id) {
+    if (constArgsIndexes.count(id) == 1) {
+      argsToErase.push_back(true);
+    } else {
+      argsToErase.push_back(false);
+    }
+  }
+  for (size_t id = 0; id < outputValues.size(); ++id) {
+    argsToErase.push_back(false);
+  }
+  block.eraseArguments(argsToErase);
 
   for (auto id : foldIds) {
     foldArgs.insert(foldArgs.end(), id);
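Batch erasure via the llvm::BitVector mask avoids the index-shift bug the removed id + 1 bookkeeping had to compensate for: erasing arguments one at a time renumbers everything behind the erased slot. A small stand-alone illustration of the mask-based approach (std::vector stands in for the block's argument list; values continue the hypothetical example above):

    #include <cstdio>
    #include <vector>

    int main() {
      // 4 original args plus 2 folded args appended at the end.
      std::vector<char> args = {'a', 'b', 'c', 'd', 'e', 'f'};
      // Mask computed up front: erase the constant args at indexes 1 and 3,
      // keep everything else, including the folded args.
      std::vector<bool> eraseMask = {false, true, false, true, false, false};
      std::vector<char> kept;
      for (size_t i = 0; i < args.size(); ++i)
        if (!eraseMask[i])
          kept.push_back(args[i]); // indexes never shift under our feet
      for (char c : kept)
        std::printf("%c ", c); // prints: a c e f
      return 0;
    }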
@@ -604,6 +636,9 @@ void CST::runOnOperation() {
   addGlobalI32Array(moduleOp, moduleOp.getLoc(), builder, "__fold_args",
                     foldArgs);
 
+  for (auto id : foldIds) {
+    computeArgs.insert(computeArgs.end(), id);
+  }
   computeArgs.insert(computeArgs.begin(), computeArgs.size());
   addGlobalI32Array(moduleOp, moduleOp.getLoc(), builder, "__compute_args",
                     computeArgs);
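Continuing the hypothetical example, the emitted globals would come out as __fold_args = {1, 3, 4, 5} (constant arg indexes followed by the folded ids; any length prefix for this array is added outside the hunks shown here) and __compute_args = {4, 0, 2, 4, 5}: the non-constant arg indexes {0, 2}, then the folded ids {4, 5} appended by the new loop, with the element count 4 prepended by the insert(begin(), size()) call.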