[Runtime] Constant cache manager and runtime pipeline #342

Open · wants to merge 94 commits into base: main

Commits (94)
13faa33
add cpuruntime dialect
May 14, 2024
161848e
format
May 14, 2024
447ef12
add dependency
May 14, 2024
a73dcc1
fix new MLIR
May 14, 2024
1cfede8
add
May 15, 2024
57ba92e
Merge remote-tracking branch 'origin/main' into yijie/cpuruntime
May 15, 2024
3d3308c
move codes from dnn-compiler
niuxiaog May 15, 2024
4d25de6
Merge branch 'yijie/cpuruntime' into yijie/pipeline
May 15, 2024
475faf8
update
May 15, 2024
4f112c0
Merge branch 'main' into xgniu/constant_weights_folding
niuxiaog May 15, 2024
0ac087d
fix
May 15, 2024
74b0d34
remove at exit
May 16, 2024
2cebba9
fix lint
May 16, 2024
d1b35a1
Merge branch 'yijie/cpuruntime' into yijie/pipeline
May 16, 2024
34d10ea
Add kmp_* wrapper for gomp environment
May 16, 2024
55c1043
Merge remote-tracking branch 'origin' into yijie/pipeline
May 16, 2024
e1490bb
Merge branch 'yijie/fake_omp' into yijie/pipeline
May 16, 2024
80a597f
fix
May 16, 2024
0b4332b
fix
May 16, 2024
c43f481
Merge branch 'main' into yijie/fake_omp
May 23, 2024
b1c79a2
add wraper
May 23, 2024
382171b
fix lint
May 23, 2024
ef75da8
Merge branch 'yijie/fake_omp' of https://github.com/intel/graph-compi…
May 23, 2024
f1fd0ae
fix
May 23, 2024
a773ea6
f
May 23, 2024
84933c2
fix
May 23, 2024
4cca4df
add reference
May 23, 2024
678cef9
enable const cache
May 24, 2024
c12156c
reduce size
May 24, 2024
d50a3e8
Merge branch 'main' into xgniu/constant_weights_folding
niuxiaog May 27, 2024
6219935
Add single operand check
niuxiaog May 27, 2024
5eb0ac0
Add cache manager
niuxiaog May 27, 2024
c3e186d
Use llvm global [need to cowork with yijie/mainfunc_wrapper]
niuxiaog May 28, 2024
e24b1df
rename
May 28, 2024
34064f3
Merge branch 'main' of https://github.com/intel/graph-compiler into y…
May 28, 2024
70c5e97
Merge branch 'main' of https://github.com/intel/graph-compiler into y…
May 28, 2024
1e06c98
fix license.py
May 28, 2024
3f656b7
Merge branch 'yijie/fake_omp' into yijie/pipeline
May 28, 2024
24cee01
Merge branch 'yijie/pipeline' into yijie/mainfunc_wrapper
May 28, 2024
7c32bc5
fix
May 28, 2024
4540fb6
fix lint
May 28, 2024
381677a
fix comments
May 28, 2024
8c50b67
Rename; Add llvm dependence
niuxiaog May 28, 2024
25f611e
Change dtype
niuxiaog May 28, 2024
4363915
Fix visibility and type
niuxiaog May 29, 2024
fdfc53e
Merge branch 'main' of https://github.com/intel/graph-compiler into y…
May 29, 2024
60042e1
Merge branch 'yijie/pipeline' into yijie/mainfunc_wrapper
May 29, 2024
b54b310
fix
May 29, 2024
9d04cd2
format
May 29, 2024
206c3f3
cleanup
May 30, 2024
824946b
Revert "cleanup"
May 30, 2024
bc9a7ad
refine options
May 30, 2024
3bd954c
Merge branch 'yijie/mainfunc_wrapper' into yijie/const_cache_jit
May 30, 2024
94f2813
Support cpmplex topo
niuxiaog May 30, 2024
0f67f75
Rename
niuxiaog Jun 3, 2024
d7663a5
Split into short functions
niuxiaog Jun 4, 2024
3f34e97
Add a test
niuxiaog Jun 5, 2024
22c3d76
Adapt to constant PropertyType
niuxiaog Jun 11, 2024
5c92931
Merge branch 'main' into xgniu/constant_weights_folding
niuxiaog Jul 24, 2024
9218762
Revert "Adapt to constant PropertyType"
niuxiaog Jul 24, 2024
4e447dd
Fix link
niuxiaog Jul 24, 2024
d4d81a6
Fold arith.constant
niuxiaog Jul 25, 2024
afec52a
Add compile_time_fold and runtime_fold.
niuxiaog Jul 25, 2024
9c4fd70
Fix license and tidy
niuxiaog Jul 26, 2024
fad5f92
Fix link
niuxiaog Jul 26, 2024
57f887d
Only enable runtime folding
niuxiaog Jul 29, 2024
1fc3b9f
Rename and polish
niuxiaog Jul 29, 2024
aaa4ed4
Merge branch 'main' into xgniu/constant_weights_folding
niuxiaog Jul 31, 2024
bfc12c7
Add accuracy tests on mlp
niuxiaog Aug 7, 2024
346965f
Merge branch 'main' into xgniu/constant_weights_folding
niuxiaog Aug 7, 2024
75fcaed
Merge branch 'main' into xgniu/constant_weights_folding
niuxiaog Aug 19, 2024
f9c2425
Support MemRef args
niuxiaog Aug 20, 2024
d8d2d79
Add to pipeline
niuxiaog Aug 20, 2024
fc739e5
Merge branch 'main' into xgniu/constant_weights_folding
niuxiaog Aug 26, 2024
22c4474
Forbid buffer_to_tensor case
niuxiaog Aug 26, 2024
968677d
Merge branch 'main' into xgniu/constant_weights_folding
niuxiaog Sep 2, 2024
1473a88
Merge branch 'main' into xgniu/constant_weights_folding
niuxiaog Sep 5, 2024
e20d059
Add shape info to global
niuxiaog Sep 6, 2024
99811f2
Merge branch 'xgniu/constant_weights_folding' into xgniu/folding_manager
niuxiaog Sep 11, 2024
3a47e28
Merge branch 'yijie/const_cache_jit' into xgniu/folding_manager
niuxiaog Sep 11, 2024
36fc758
Make things work
niuxiaog Sep 13, 2024
362ad2b
Merge branch 'xgniu/constant_weights_folding' into xgniu/folding_manager
niuxiaog Sep 13, 2024
8d08752
Unify attr name
niuxiaog Sep 13, 2024
ad24768
Merge branch 'main' into xgniu/constant_weights_folding
niuxiaog Sep 14, 2024
edbb708
Clean tests.
niuxiaog Sep 14, 2024
fa30e4a
Updates
niuxiaog Sep 14, 2024
b8b0dd2
Move manager
niuxiaog Sep 14, 2024
361bed6
Merge branch 'xgniu/constant_weights_folding' into xgniu/folding_manager
niuxiaog Sep 14, 2024
6a041dd
Use atomic
niuxiaog Sep 14, 2024
c876358
Fix
niuxiaog Sep 14, 2024
a255c7b
Merge branch 'main' into xgniu/constant_weights_folding
niuxiaog Sep 18, 2024
77e0f02
Merge into one pass
niuxiaog Sep 18, 2024
2df16c2
Skip case
niuxiaog Sep 18, 2024
d8aedad
Merge branch 'xgniu/constant_weights_folding' into xgniu/folding_manager
niuxiaog Sep 19, 2024
125 changes: 125 additions & 0 deletions include/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h
@@ -0,0 +1,125 @@
//===-- ConstantSubgraphAnalyser.h - Constant subgraph ----------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// This file declares constant subgraph analysis. It contains:
/// 1. the lattice value class that represents operations with constant inputs
///    and outputs in the program, and
/// 2. a sparse, forward constant subgraph analysis.
///
//===----------------------------------------------------------------------===//

#ifndef MLIR_ANALYSIS_DATAFLOW_CONSTANTSUBGRAPHANALYSER_H
#define MLIR_ANALYSIS_DATAFLOW_CONSTANTSUBGRAPHANALYSER_H

#include "mlir/Analysis/DataFlow/SparseAnalysis.h"

namespace mlir {
namespace dataflow {

//===----------------------------------------------------------------------===//
// IsConstantTensor
//===----------------------------------------------------------------------===//

/// This lattice represents a boolean indicating if a value is constant.
class IsConstantTensor {
public:
/// Construct as uninitialized.
explicit IsConstantTensor() = default;

/// Construct with a known state.
explicit IsConstantTensor(bool initialized, bool isConstantTensor)
: initialized(initialized), isConstantTensor(isConstantTensor) {}

/// Get the state. The lattice must have been initialized first.
bool getIsConstantTensor() const {
assert(!isUninitialized());
return isConstantTensor;
}

/// Compare.
bool operator==(const IsConstantTensor &rhs) const {
return initialized == rhs.initialized &&
isConstantTensor == rhs.isConstantTensor;
}

void print(raw_ostream &os) const;

/// Get uninitialized state. This happens when the
/// state hasn't been set during the analysis.
static IsConstantTensor getUninitialized() { return IsConstantTensor{}; }

/// Whether the state is uninitialized.
bool isUninitialized() const { return !initialized; }

/// Get the unknown state. Unlike the uninitialized state, this is a known
/// (initialized) result: the value is not known to be a constant tensor.
/// (Passing initialized=false here would make "unknown" indistinguishable
/// from "uninitialized" and confuse join().)
static IsConstantTensor getUnknown() {
  return IsConstantTensor{/*initialized=*/true,
                          /*isConstantTensor=*/false};
}

// Join two states.
static IsConstantTensor join(const IsConstantTensor &lhs,
                             const IsConstantTensor &rhs) {
  // If either side is uninitialized, take the other.
  if (lhs.isUninitialized())
    return rhs;
  if (rhs.isUninitialized())
    return lhs;
  // Both are initialized: a value is a constant tensor only if it is
  // constant on both sides.
  return IsConstantTensor(true, lhs.getIsConstantTensor() &&
                                    rhs.getIsConstantTensor());
}

private:
bool initialized = false;
bool isConstantTensor = false;
};

//===----------------------------------------------------------------------===//
// ConstantSubgraphAnalyser
//===----------------------------------------------------------------------===//

class ConstantSubgraphAnalyser
: public SparseForwardDataFlowAnalysis<Lattice<IsConstantTensor>> {
public:
using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;

LogicalResult visitOperation(Operation *op,
ArrayRef<const Lattice<IsConstantTensor> *> operands,
ArrayRef<Lattice<IsConstantTensor> *> results) override;

void setToEntryState(Lattice<IsConstantTensor> *lattice) override;
};

//===----------------------------------------------------------------------===//
// RunConstantSubgraphAnalyser
//===----------------------------------------------------------------------===//

/// Runs constant subgraph analysis on the IR defined by `op`.
struct RunConstantSubgraphAnalyser {
public:
RunConstantSubgraphAnalyser();

void run(Operation *op);

bool getIsConstantTensor(Value val);

private:
/// Stores the result of the analysis.
DataFlowSolver solver;

void getConstantSubgraph(DataFlowSolver &solver, Operation *topFunc);
};
} // end namespace dataflow
} // end namespace mlir

#endif // MLIR_ANALYSIS_DATAFLOW_CONSTANTSUBGRAPHANALYSER_H
2 changes: 2 additions & 0 deletions include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.td
@@ -22,6 +22,8 @@ def OneDNNGraphDialect : Dialect {
This dialect follows oneDNN Graph Specification.
}];
let cppNamespace = "::mlir::onednn_graph";

let hasOperationAttrVerify = 1;
}

#endif // ONEDNNGRAPH_DIALECT
207 changes: 207 additions & 0 deletions include/gc/ExecutionEngine/CPURuntime/ConstantCache.h
@@ -0,0 +1,207 @@
//===-- ConstantCache.h - Constant cache interfaces -------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef GC_EXECUTIONENGINE_CPURUNTIME_CONSTANT_CACHE_H
#define GC_EXECUTIONENGINE_CPURUNTIME_CONSTANT_CACHE_H
#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include <atomic>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <stdint.h>
#include <unordered_map>
namespace mlir {
namespace gc {
/**
* The helper class to manage ref count manually for an object allocated with
* shared ptr. It holds an additional shared ptr reference to the object and
* contains an additional self-managed refcount. The refcount will be set to 1
* when the object is initialized (see init()). When the refcount counts down to
* 0, the additional shared ptr is reset.
*/
struct RefCountManaged {
RefCountManaged() = default;
RefCountManaged(const std::shared_ptr<void> &vkeepAlive) { init(vkeepAlive); }
void init(const std::shared_ptr<void> &vkeepAlive) {
keepAlive = vkeepAlive;
refCount.store(1);
}

void ref() { ++refCount; }
void deref() {
auto newv = --refCount;
if (newv == 0) {
keepAlive = nullptr;
}
}

// atomically check if refCount > 0. if so, ref() the object and return
// true. Otherwise (if refCount==0), return false
bool checkAliveAndRef() {
auto oldv = refCount.load();
for (;;) {
if (oldv <= 0) {
return false;
}
if (refCount.compare_exchange_strong(oldv, oldv + 1)) {
return true;
}
// CAS failed, oldv has now the newest known value of refCount
}
}

bool isAlive() const { return refCount > 0; }
void *getPtrUnsafe() const { return keepAlive.get(); }

private:
std::shared_ptr<void> keepAlive;
std::atomic<int> refCount{0};
};

/**
 * The proxy for the constant cache of Graph API. It holds a shared ptr
 * pointing to the cache item in the cache manager (keepAlive) to extend the
 * lifetime by refcount, @see RefCountManaged. To access the memory buffer of
 * the const cache, use the acquire/release functions. They ref/deref the
 * ConstCacheProxy to make sure the cache stays alive between acquire and
 * release. The cache manager of Graph API may evict the cache item by
 * dereferencing this RefCountManaged object; the {acquire,release} functions
 * will then find out that the cache has been invalidated. Usually we expect
 * JIT modules to hold a shared ptr to ConstCacheProxy via CachedGraphTensor.
 * If isLazy == true, the cache item's lifetime is managed by the cache
 * manager of Graph API and it is filled with data after the first execution
 * of the computation. Otherwise, the cache item stays alive as long as the
 * jit_module of the kernel is alive.
 */
struct ConstCacheProxy : RefCountManaged {
ConstCacheProxy(const std::shared_ptr<void> &vkeepAlive, void *buffer,
size_t size, bool is_lazy)
: RefCountManaged(vkeepAlive), size(size), isLazy(is_lazy),
buffer(buffer) {}
~ConstCacheProxy() = default;

// get the buffer and increment the refcount. If the buffer is evicted,
// returns null
void *acquire(int32_t *inited) {
if (checkAliveAndRef()) {
*inited = *inited && initialized;
return buffer;
}
return nullptr;
}
// decrement the refcount
bool release() {
if (isAlive()) {
deref();
initialized = 1;
return true;
}
return false;
}

// Return the buffer. Do not use this pointer directly because the buffer may
// already have been released! Always acquire() before accessing it.
void *getBufferUnsafe() const { return buffer; }

size_t size;
// if the buffer is lazy-initialized. If false, it should be filled before
// computation
bool isLazy;

private:
// raw pointer to the buffer
void *buffer;
// if the buffer has been initialized. calling release() will set this to 1
int32_t initialized = 0;
};

struct CachedGraphTensor {
// Multiple tensors can reside in one common ConstCacheProxy `base`, with
// different offsets.
std::shared_ptr<ConstCacheProxy> base;
size_t offset;
CachedGraphTensor(const std::shared_ptr<ConstCacheProxy> &base, size_t offset)
: base{base}, offset{offset} {
// todo: fill in real values
ref.basePtr = (char *)base->getBufferUnsafe() + offset;
ref.data = ref.basePtr;
ref.offset = 0;
memset(ref.sizes, 0, sizeof(ref.sizes));
memset(ref.strides, 0, sizeof(ref.strides));
}
friend class JitModule;

private:
StridedMemRefType<char, 8> ref;
};

inline std::shared_ptr<ConstCacheProxy> createConstCacheProxy(size_t size) {
  // Simply allocate an owning buffer and wrap it. Note that
  // std::aligned_alloc requires `size` to be a multiple of the alignment,
  // so round it up to the next multiple of 64.
  size = (size + 63) / 64 * 64;
  std::shared_ptr<void> base = std::shared_ptr<void>{
      std::aligned_alloc(64, size), [](void *p) { std::free(p); }};
  return std::make_shared<ConstCacheProxy>(base, base.get(), size, true);
}

inline static size_t divideAndCeil(size_t x, size_t y) {
return (x + y - 1) / y;
}

// Manager
struct ConstGraphTensorCacheManager {
std::atomic_int64_t cachedTensorGlobalId = 0;

std::unordered_map<int64_t, std::shared_ptr<CachedGraphTensor>> cache;

// singleton
static std::shared_ptr<ConstGraphTensorCacheManager> get() {
static std::shared_ptr<ConstGraphTensorCacheManager> c =
std::make_shared<ConstGraphTensorCacheManager>();
return c;
}

std::shared_ptr<CachedGraphTensor> queryCacheTensor(int64_t key) {
auto itr = cache.find(key);
if (itr != cache.end()) {
return itr->second;
}
return nullptr;
}

bool regCachedTensor(int64_t key,
const std::shared_ptr<ConstCacheProxy> &base,
size_t offset) {
if (queryCacheTensor(key)) {
return false;
}

cache[key] = std::make_shared<CachedGraphTensor>(base, offset);
return true;
}

  // Allocate one shared buffer covering all requested sizes and register each
  // sub-buffer (base + offset) in the cache. Returns the global ids of the
  // newly cached tensors.
  std::vector<int64_t> alloc(const std::vector<size_t> &buffersSize) {
    size_t totalSize = 0;
    for (size_t size : buffersSize) {
      totalSize += divideAndCeil(size, 64) * 64;
    }
    auto base = createConstCacheProxy(totalSize);
    std::vector<int64_t> globalIds(buffersSize.size());
    size_t offset = 0;
    for (size_t i = 0; i < buffersSize.size(); i++) {
      bool regRes = regCachedTensor(cachedTensorGlobalId, base, offset);
      (void)regRes; // avoid unused-variable warning in NDEBUG builds
      assert(regRes && "Register constant tensor failed");
      globalIds[i] = cachedTensorGlobalId;
      ++cachedTensorGlobalId;
      offset += divideAndCeil(buffersSize[i], 64) * 64;
    }
    return globalIds;
  }
};

} // namespace gc
} // namespace mlir

#endif