diff --git a/include/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h b/include/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h new file mode 100644 index 000000000..2dee27f6f --- /dev/null +++ b/include/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h @@ -0,0 +1,125 @@ +//===-- ConstantSubgraphAnalyser.h - Constant subgraph ----------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// This file implements constant subgraph analysis. In this file are: +/// 1. the lattice value class that represents operations with constant inputs +/// and outputs in the program, and +/// 2. a sparse constant subgraph analysis. +/// +///===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_DATAFLOW_CONSTANTSUBGRAPHANALYSER_H +#define MLIR_ANALYSIS_DATAFLOW_CONSTANTSUBGRAPHANALYSER_H + +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" + +namespace mlir { +namespace dataflow { + +//===----------------------------------------------------------------------===// +// IsConstantTensor +//===----------------------------------------------------------------------===// + +/// This lattice represents a boolean indicating if a value is constant. +class IsConstantTensor { +public: + /// Construct as uninitialized. + explicit IsConstantTensor() = default; + + /// Construct with a known state. + explicit IsConstantTensor(bool initialized, bool isConstantTensor) + : initialized(initialized), isConstantTensor(isConstantTensor) {} + + /// Get the state. Must be initialized before. + bool getIsConstantTensor() const { + assert(!isUninitialized()); + return isConstantTensor; + } + + /// Compare. + bool operator==(const IsConstantTensor &rhs) const { + return initialized == rhs.initialized && + isConstantTensor == rhs.isConstantTensor; + } + + void print(raw_ostream &os) const; + + /// Get uninitialized state. This happens when the + /// state hasn't been set during the analysis. + static IsConstantTensor getUninitialized() { return IsConstantTensor{}; } + + /// Whether the state is uninitialized. + bool isUninitialized() const { return !initialized; } + + /// Get unknown state. + static IsConstantTensor getUnknown() { + return IsConstantTensor{/*initialized=*/false, + /*isConstantTensor*/ false}; + } + + // Join two states. 
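+  // The joined state is initialized iff at least one input is; two
+  // initialized states meet to "constant" only when both inputs are constant.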
+ static IsConstantTensor join(const IsConstantTensor &lhs, + const IsConstantTensor &rhs) { + // if one is uninitialized, use another + if (lhs.isUninitialized()) + return rhs; + if (rhs.isUninitialized()) + return lhs; + + // both are initialized, intersect them + if (!lhs.isUninitialized() && !rhs.isUninitialized()) { + return IsConstantTensor(true, lhs.getIsConstantTensor() && + rhs.getIsConstantTensor()); + } + return getUninitialized(); + } + +private: + bool initialized = false; + bool isConstantTensor = false; +}; + +//===----------------------------------------------------------------------===// +// ConstantSubgraphAnalyser +//===----------------------------------------------------------------------===// + +class ConstantSubgraphAnalyser + : public SparseForwardDataFlowAnalysis> { +public: + using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis; + + LogicalResult visitOperation(Operation *op, + ArrayRef *> operands, + ArrayRef *> results) override; + + void setToEntryState(Lattice *lattice) override; +}; + +//===----------------------------------------------------------------------===// +// RunConstantSubgraphAnalyser +//===----------------------------------------------------------------------===// + +/// Runs constant subgraph analysis on the IR defined by `op`. +struct RunConstantSubgraphAnalyser { +public: + RunConstantSubgraphAnalyser(); + + void run(Operation *op); + + bool getIsConstantTensor(Value val); + +private: + /// Stores the result of the analysis. + DataFlowSolver solver; + + void getConstantSubgraph(DataFlowSolver &solver, Operation *topFunc); +}; +} // end namespace dataflow +} // end namespace mlir + +#endif // MLIR_ANALYSIS_DATAFLOW_CONSTANTSUBGRAPHANALYSER_H diff --git a/include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.td b/include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.td index 6e1eaceca..9f0046c3a 100644 --- a/include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.td +++ b/include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.td @@ -22,6 +22,8 @@ def OneDNNGraphDialect : Dialect { This dialect follows oneDNN Graph Specification. }]; let cppNamespace = "::mlir::onednn_graph"; + + let hasOperationAttrVerify = 1; } #endif // ONEDNNGRAPH_DIALECT diff --git a/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h new file mode 100644 index 000000000..8a7330eaa --- /dev/null +++ b/include/gc/ExecutionEngine/CPURuntime/ConstantCache.h @@ -0,0 +1,207 @@ +//===-- ConstantCache.h - Constant cache interfaces -------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef GC_EXECUTIONENGINE_CPURUNTIME_CONSTANT_CACHE_H +#define GC_EXECUTIONENGINE_CPURUNTIME_CONSTANT_CACHE_H +#include "mlir/ExecutionEngine/CRunnerUtils.h" +#include +#include +#include +#include +#include +#include +namespace mlir { +namespace gc { +/** + * The helper class to manage ref count manually for an object allocated with + * shared ptr. It holds an additional shared ptr reference to the object and + * contains an additional self-managed refcount. The refcount will be set to 1 + * when the object is initialized (see init()). When the refcount counts down to + * 0, the additional shared ptr is reset. 
+ */ +struct RefCountManaged { + RefCountManaged() = default; + RefCountManaged(const std::shared_ptr &vkeepAlive) { init(vkeepAlive); } + void init(const std::shared_ptr &vkeepAlive) { + keepAlive = vkeepAlive; + refCount.store(1); + } + + void ref() { ++refCount; } + void deref() { + auto newv = --refCount; + if (newv == 0) { + keepAlive = nullptr; + } + } + + // atomically check if refCount > 0. if so, ref() the object and return + // true. Otherwise (if refCount==0), return false + bool checkAliveAndRef() { + auto oldv = refCount.load(); + for (;;) { + if (oldv <= 0) { + return false; + } + if (refCount.compare_exchange_strong(oldv, oldv + 1)) { + return true; + } + // CAS failed, oldv has now the newest known value of refCount + } + } + + bool isAlive() const { return refCount > 0; } + void *getPtrUnsafe() const { return keepAlive.get(); } + +private: + std::shared_ptr keepAlive; + std::atomic refCount{0}; +}; + +/** + * The proxy for the constant cache of Graph API. It holds a shared ptr pointing + * to the cache item in the cache manager (keepAlive) to extend the lifetime by + * refcount, @see RefCountManaged. To access the memory buffer of the const + * cache, use acauire/release functions. They will ref/deref the ConstCacheProxy + * to make sure the cache is alive after calling acauire and before release. The + * cache manager of Graph API may evict the cache item by dereferenceing this + * RefCountManaged object. {acquire,release} functions will find out that the + * cache has been invalidated. Usually we expect JIT modules to hold shared ptr + * to ConstCacheProxy via CachedGraphTensor. If isLazy == true, the cache + * item's lifetime will be managed by the cache manager of Graph API and it is + * filled with data after the first execution of the computation. Otherwise, the + * cache item is always alive as long as the jit_module of the kernel is alive. + */ +struct ConstCacheProxy : RefCountManaged { + ConstCacheProxy(const std::shared_ptr &vkeepAlive, void *buffer, + size_t size, bool is_lazy) + : RefCountManaged(vkeepAlive), size(size), isLazy(is_lazy), + buffer(buffer) {} + ~ConstCacheProxy() = default; + + // get the buffer and increment the refcount. If the buffer is evicted, + // returns null + void *acquire(int32_t *inited) { + if (checkAliveAndRef()) { + *inited = *inited && initialized; + return buffer; + } + return nullptr; + } + // decrement the refcount + bool release() { + if (isAlive()) { + deref(); + initialized = 1; + return true; + } + return false; + } + + // return the buffer. Do not directly use the buffer because it may be already + // release! To access the buffer, always acquire() before using it. + void *getBufferUnsafe() const { return buffer; } + + size_t size; + // if the buffer is lazy-initialized. If false, it should be filled before + // computation + bool isLazy; + +private: + // raw pointer to the buffer + void *buffer; + // if the buffer has been initialized. calling release() will set this to 1 + int32_t initialized = 0; +}; + +struct CachedGraphTensor { + // Multiple tensors can reside in one common ConstCacheProxy `base`, with + // different offsets. 
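+  // For example, two 256-byte tensors may share one base: the first at
+  // offset 0 and the second at offset 256, so both share the base's
+  // refcount and lifetime.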
+ std::shared_ptr base; + size_t offset; + CachedGraphTensor(const std::shared_ptr &base, size_t offset) + : base{base}, offset{offset} { + // todo: fill in real values + ref.basePtr = (char *)base->getBufferUnsafe() + offset; + ref.data = ref.basePtr; + ref.offset = 0; + memset(ref.sizes, 0, sizeof(ref.sizes)); + memset(ref.strides, 0, sizeof(ref.strides)); + } + friend class JitModule; + +private: + StridedMemRefType ref; +}; + +inline std::shared_ptr createConstCacheProxy(size_t size) { + // simply allocate buffer and return + std::shared_ptr base = std::shared_ptr{ + std::aligned_alloc(64, size), [](void *p) { std::free(p); }}; + return std::make_shared(base, base.get(), size, true); +} + +inline static size_t divideAndCeil(size_t x, size_t y) { + return (x + y - 1) / y; +} + +// Manager +struct ConstGraphTensorCacheManager { + std::atomic_int64_t cachedTensorGlobalId = 0; + + std::unordered_map> cache; + + // singleton + static std::shared_ptr get() { + static std::shared_ptr c = + std::make_shared(); + return c; + } + + std::shared_ptr queryCacheTensor(int64_t key) { + auto itr = cache.find(key); + if (itr != cache.end()) { + return itr->second; + } + return nullptr; + } + + bool regCachedTensor(int64_t key, + const std::shared_ptr &base, + size_t offset) { + if (queryCacheTensor(key)) { + return false; + } + + cache[key] = std::make_shared(base, offset); + return true; + } + + // alloc and set the buf_base_ and offset_ attributes of cache + std::vector alloc(std::vector buffersSize) { + size_t totalSize = 0; + for (size_t size : buffersSize) { + totalSize += divideAndCeil(size, 64) * 64; + } + auto base = createConstCacheProxy(totalSize); + std::vector globalIds(buffersSize.size()); + size_t offset = 0; + for (size_t i = 0; i < buffersSize.size(); i++) { + bool regRes = regCachedTensor(cachedTensorGlobalId, base, offset); + assert(regRes && "Register constant tensor failed"); + globalIds[i] = cachedTensorGlobalId; + ++cachedTensorGlobalId; + offset += divideAndCeil(buffersSize[i], 64) * 64; + } + return globalIds; + } +}; + +} // namespace gc +} // namespace mlir + +#endif \ No newline at end of file diff --git a/include/gc/ExecutionEngine/Driver/Driver.h b/include/gc/ExecutionEngine/Driver/Driver.h index ee8630b53..d80fb4e51 100644 --- a/include/gc/ExecutionEngine/Driver/Driver.h +++ b/include/gc/ExecutionEngine/Driver/Driver.h @@ -9,6 +9,7 @@ #ifndef GC_EXECUTIONENGINE_DRIVER_DRIVER_H #define GC_EXECUTIONENGINE_DRIVER_DRIVER_H +#include "gc/ExecutionEngine/CPURuntime/ConstantCache.h" #include "mlir/ExecutionEngine/CRunnerUtils.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" #include @@ -37,30 +38,57 @@ class JitModule { static llvm::Expected> create(Operation *op, const DriverOptions &options = {}); - /// args should be an array of XXXMemrefType* - void call(GeneralMemrefPtr *args, std::size_t numArgs) { - // Silly code, MLIR execution engine requires pointers of real args as - // inputs - llvm::SmallVector realargs; - realargs.reserve(numArgs); - for (size_t i = 0; i < numArgs; i++) { - realargs.push_back(&args[i]); - } - compute(realargs.data()); - } - - /// directly call compute(). args should be an array of void*. args[i] should + // args should be an array of XXXMemrefType* + // numArgs: including input and output args. + void call(GeneralMemrefPtr *args, int32_t numArgs); + + /// directly call entry(). args should be an array of void*. args[i] should /// be a pointer to the real data. 
For passing memref, users need to 1) create /// a pointer to XXXMemrefType 2) store the pointer to pointer to /// XXXMemrefType in args[i] - void callRaw(void **args) { compute(args); } + void callRaw(void **args) { entry(args); } + + JitModule(std::unique_ptr engine, JitModuleFuncT entry); - JitModule(std::unique_ptr engine, JitModuleFuncT compute); + JitModule( + std::unique_ptr engine, JitModuleFuncT entry, + JitModuleFuncT fold, int32_t numOrigArgs, + // The code inside `engine` has the ownership of the buffer + llvm::ArrayRef entryArgs, + // The code inside `engine` has the ownership of the buffer + llvm::ArrayRef foldArgs, + std::vector> &&cachekeepAlive = {}); ~JitModule(); private: std::unique_ptr engine; - JitModuleFuncT compute; + JitModuleFuncT entry; + JitModuleFuncT fold; + int32_t numOrigArgs; // only input args + // The code inside `engine` has the ownership of the buffer + llvm::ArrayRef foldArgs; + // The code inside `engine` has the ownership of the buffer + llvm::ArrayRef entryArgs; + + // The bases of CachedGraphTensors. For example, tensor1 (size 256) and + // tensor2 (size 256) are in ConstCacheProxy base1, and tensor3 (size 256) in + // base2. Then cacheBases is {base1, base2}, cacheInfo is {{baseIdx=0, + // offset=0}, {baseIdx=0, offset=256}, {baseIdx=1, offset=0}}. + + // `keepAlive` has the ownership of the objects pointed by this vector + llvm::SmallVector cacheBases; + struct CacheBufferInfo { + // index in cacheBases + size_t baseIdx; + size_t offset; + }; + // the info for each folded cached buffer + llvm::SmallVector cacheInfo; + + // holding the pointers to StridedMemRefType of folded cache + llvm::SmallVector fastFoldBuffers; + // `keepAlive` holds the the ownership of the pointers + std::vector> keepAlive; }; } // namespace gc diff --git a/include/gc/Transforms/Passes.h b/include/gc/Transforms/Passes.h index a9b73d687..06a3ee83d 100644 --- a/include/gc/Transforms/Passes.h +++ b/include/gc/Transforms/Passes.h @@ -120,8 +120,12 @@ void populateGPUPipeline(mlir::OpPassManager &); #endif #define GEN_PASS_DECL +#define GEN_PASS_DECL_CONSTANTSUBGRAPHANALYSIS +#define GEN_PASS_DECL_CONSTANTTENSORFOLDING #include "gc/Transforms/Passes.h.inc" +std::unique_ptr createConstantTensorFoldingPass(); + #define GEN_PASS_REGISTRATION #include "gc/Transforms/Passes.h.inc" } // namespace gc diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index 5151a0335..9a968c3bd 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -169,6 +169,18 @@ def MergeNestedForall : Pass<"merge-nested-forall"> { let dependentDialects = ["scf::SCFDialect"]; } +def ConstantTensorFolding : Pass<"constant-tensor-folding"> { + let summary = "Constant Tensor Folding Transform"; + let description = [{ + This pass implements a constant tensor folding transform. 
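+    Operations marked constant by the constant-subgraph analysis are
+    outlined into a separate folding function that runs once at runtime;
+    the entry function is rewritten to consume the cached folded buffers.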
+ }]; + let constructor = "mlir::gc::createConstantTensorFoldingPass()"; + let dependentDialects = [ + "tensor::TensorDialect", + "linalg::LinalgDialect", + "LLVM::LLVMDialect"]; +} + def FoldTensorOperation : Pass<"fold-tensor-operation"> { let summary = "Fold some tensor operation"; let description = [{ diff --git a/lib/gc/Analysis/CMakeLists.txt b/lib/gc/Analysis/CMakeLists.txt index d7160f350..bddd18a52 100644 --- a/lib/gc/Analysis/CMakeLists.txt +++ b/lib/gc/Analysis/CMakeLists.txt @@ -4,6 +4,7 @@ gc_set_mlir_link_components(MLIR_LINK_COMPONENTS gc_add_mlir_library(GcAnalysis TargetDescriptionAnalysis.cpp + DataFlow/ConstantSubgraphAnalyser.cpp MatmulConfigAnalysis.cpp DEPENDS diff --git a/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp b/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp new file mode 100644 index 000000000..b3c6b51ba --- /dev/null +++ b/lib/gc/Analysis/DataFlow/ConstantSubgraphAnalyser.cpp @@ -0,0 +1,187 @@ +//===-- ConstantSubgraphAnalyser.cpp - Constant subgraph -------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include +#include + +#include "gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Passes.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "in-constant-subgraph" + +using namespace mlir; +using namespace mlir::dataflow; + +//===----------------------------------------------------------------------===// +// IsConstantTensor +//===----------------------------------------------------------------------===// + +void IsConstantTensor::print(raw_ostream &os) const { + if (isUninitialized()) { + os << ""; + return; + } + os << getIsConstantTensor(); +} + +//===----------------------------------------------------------------------===// +// ConstantSubgraphAnalyser +//===----------------------------------------------------------------------===// + +LogicalResult ConstantSubgraphAnalyser::visitOperation( + Operation *op, ArrayRef *> operands, + ArrayRef *> results) { + LLVM_DEBUG(llvm::dbgs() << "ConstantSubgraphAnalyser: Visiting operation:\n" + << *op << "\n"); + + bool in = true; + if (op->hasTrait()) { + LLVM_DEBUG(llvm::dbgs() << "Curr op is a Constant op\n"); + in = true; + } else if (operands.empty()) { // For example, tensor.empty() + LLVM_DEBUG(llvm::dbgs() << "Curr op has 0 operand, constant\n"); + in = true; + } else { + LLVM_DEBUG(llvm::dbgs() << "Curr op has " << operands.size() + << " operands, check if constant\n"); + for (auto *operandLattice : operands) { + auto operandState = operandLattice->getValue().getIsConstantTensor(); + LLVM_DEBUG(llvm::dbgs() << "Operand: " << operandLattice->getPoint() + << ", lattice 
value: " << operandState << "\n"); + if (!operandState) { + in = false; + break; + } + } + } + + // lattice in results should be in unintialized state. + if (!in) { + LLVM_DEBUG(llvm::dbgs() << "Curr op not in constant subgraph\n"); + for (auto lattice : results) { + propagateIfChanged(lattice, lattice->join(IsConstantTensor(true, false))); + } + } else { + LLVM_DEBUG(llvm::dbgs() << "Curr op in constant subgraph\n"); + for (auto lattice : results) { + propagateIfChanged(lattice, lattice->join(IsConstantTensor(true, true))); + } + } + return LogicalResult::success(); +} + +void ConstantSubgraphAnalyser::setToEntryState( + Lattice *lattice) { + if (auto blockArg = cast(lattice->getPoint())) { + auto parentOp = blockArg.getParentBlock()->getParentOp(); + auto parentOpAttr = parentOp->getAttrDictionary(); + + std::unordered_set constArgsIndexes; + std::optional compiletimeConstArgs = + parentOpAttr.getNamed("compiletime_const_args_index"); + if (compiletimeConstArgs.has_value()) { + for (auto id : + llvm::dyn_cast(compiletimeConstArgs->getValue())) { + constArgsIndexes.insert(llvm::cast(id).getInt()); + } + } + std::optional runtimeConstArgs = + parentOpAttr.getNamed("runtime_const_args_index"); + if (runtimeConstArgs.has_value()) { + for (auto id : llvm::dyn_cast(runtimeConstArgs->getValue())) { + constArgsIndexes.insert(llvm::cast(id).getInt()); + } + } + + if (constArgsIndexes.count(blockArg.getArgNumber())) { + LLVM_DEBUG(llvm::dbgs() << "Block argument: " << blockArg + << " is marked as constant\n"); + propagateIfChanged(lattice, lattice->join(IsConstantTensor(true, true))); + return; + } + propagateIfChanged(lattice, lattice->join(IsConstantTensor(true, false))); + } else { + propagateIfChanged(lattice, + lattice->join(IsConstantTensor::getUninitialized())); + } +} + +//===----------------------------------------------------------------------===// +// RunConstantSubgraphAnalyser +//===----------------------------------------------------------------------===// + +/// Get the operations whose inputs and outputs are all constant values. +/// These operations will be put into a seperate subgraph. +void RunConstantSubgraphAnalyser::getConstantSubgraph(DataFlowSolver &solver, + Operation *topFunc) { + OpBuilder builder(topFunc->getContext()); + SmallVector constantOperations; + + Block &block = topFunc->getRegions().front().getBlocks().front(); + for (Operation &op : llvm::make_early_inc_range(block)) { + // If all the result values of a op are const, we mark this op as const. 
+ bool resultsAllConstant = true; + if (op.getNumResults() == 0) + continue; + + for (Value res : op.getResults()) { + auto *lattice = solver.lookupState>(res); + if (!lattice || lattice->getValue().isUninitialized()) { + resultsAllConstant = false; + break; + } + const IsConstantTensor &latticeValue = lattice->getValue(); + if (!latticeValue.getIsConstantTensor()) { + resultsAllConstant = false; + break; + } + } + if (resultsAllConstant) { + op.setAttr("onednn_graph.in_const_subgraph", builder.getBoolAttr(true)); + constantOperations.push_back(&op); + } + } + + if (constantOperations.empty()) + return; +} + +RunConstantSubgraphAnalyser::RunConstantSubgraphAnalyser() { + solver.load(); + solver.load(); +} + +void RunConstantSubgraphAnalyser::run(Operation *op) { + if (failed(solver.initializeAndRun(op))) + return; + + getConstantSubgraph(solver, op); +} + +bool RunConstantSubgraphAnalyser::getIsConstantTensor(Value val) { + auto *lattice = solver.lookupState>(val); + const IsConstantTensor &latticeValue = lattice->getValue(); + return latticeValue.getIsConstantTensor(); +} \ No newline at end of file diff --git a/lib/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.cpp b/lib/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.cpp index b9f2e17a4..228845686 100644 --- a/lib/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.cpp +++ b/lib/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.cpp @@ -25,3 +25,9 @@ void OneDNNGraphDialect::initialize() { #include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp.inc" >(); } + +LogicalResult +OneDNNGraphDialect::verifyOperationAttribute(Operation *op, + NamedAttribute attr) { + return success(); +} diff --git a/lib/gc/ExecutionEngine/CMakeLists.txt b/lib/gc/ExecutionEngine/CMakeLists.txt index f13a27b1a..042c41415 100644 --- a/lib/gc/ExecutionEngine/CMakeLists.txt +++ b/lib/gc/ExecutionEngine/CMakeLists.txt @@ -2,4 +2,4 @@ add_subdirectory(CPURuntime) add_subdirectory(Driver) if(GC_ENABLE_IMEX) add_subdirectory(OpenCLRuntime) -endif() \ No newline at end of file +endif() diff --git a/lib/gc/ExecutionEngine/Driver/Driver.cpp b/lib/gc/ExecutionEngine/Driver/Driver.cpp index 16da521d0..42fa83b68 100644 --- a/lib/gc/ExecutionEngine/Driver/Driver.cpp +++ b/lib/gc/ExecutionEngine/Driver/Driver.cpp @@ -20,6 +20,11 @@ #include "llvm/Support/InitLLVM.h" #include "llvm/Support/TargetSelect.h" +#define DEBUG_TYPE "driver" + +#define likely(x) __builtin_expect(!!(x), 1) +#define unlikely(x) __builtin_expect(!!(x), 0) + namespace mlir { namespace gc { @@ -46,8 +51,8 @@ const DialectRegistry &initCompilerAndGetDialects() { return reg; } -static const char defaultComputeName[] = "_mlir_ciface_compute"; - +static const char defaultEntryName[] = "_mlir_ciface_entry"; +static const char defaultFoldName[] = "_mlir_ciface_runtime_fold"; llvm::Expected> JitModule::create(Operation *op, const DriverOptions &options) { if (options.runTransforms) { @@ -66,21 +71,213 @@ JitModule::create(Operation *op, const DriverOptions &options) { return exec.takeError(); } auto &engine = *exec; - JitModuleFuncT compute; - { - auto expectCompute = engine->lookupPacked(defaultComputeName); - if (!expectCompute) { - return expectCompute.takeError(); + + auto expectEntry = engine->lookupPacked(defaultEntryName); + if (!expectEntry) { + // entry function must exist + return expectEntry.takeError(); + } + JitModuleFuncT entry = *expectEntry; + + int32_t numOrigArgs = 0; + llvm::ArrayRef foldBufferIds; + JitModuleFuncT fold = nullptr; + llvm::ArrayRef entryArgs; + llvm::ArrayRef foldArgs; + do { + { + auto expectArgs = 
engine->lookup("__num_orig_args"); + if (!expectArgs) { // nothing to fold, It is OK. + llvm::consumeError(expectArgs.takeError()); + // break out of the scope, don't need to lookup other things + break; + } + numOrigArgs = *reinterpret_cast(*expectArgs); + } + + // If lookup("__num_orig_num_args") succeeds, then all the following should + // also succeed. + { + auto expectBufferIds = engine->lookup("__runtime_fold_buffer_ids"); + if (!expectBufferIds) { + llvm_unreachable("Symbol: __runtime_fold_buffer_ids not found"); + break; + } + auto raw = reinterpret_cast(*expectBufferIds); + foldBufferIds = + llvm::ArrayRef{raw + 1, static_cast(raw[0])}; + } + + // find "fold" func + { + auto expectFold = engine->lookupPacked(defaultFoldName); + if (!expectFold) { + llvm_unreachable("Symbol: runtime_fold not found"); + break; + } + fold = *expectFold; } - compute = *expectCompute; + + // find "foldArgs" + { + auto expectFold = engine->lookup("__fold_args"); + if (!expectFold) { + llvm_unreachable("Symbol: __fold_args not found"); + break; + } + auto raw = reinterpret_cast(*expectFold); + foldArgs = llvm::ArrayRef{raw + 1, static_cast(raw[0])}; + } + + // find "entryArgs" + { + auto expect = engine->lookup("__compute_args"); + if (!expect) { + llvm_unreachable("Symbol: __compute_args not found"); + break; + } + auto raw = reinterpret_cast(*expect); + entryArgs = llvm::ArrayRef{raw + 1, static_cast(raw[0])}; + } + } while (false); + + std::vector> foldInfo; + foldInfo.reserve(foldBufferIds.size()); + auto cacheManager = ConstGraphTensorCacheManager::get(); + for (auto bufId : foldBufferIds) { + auto ret = cacheManager->queryCacheTensor(bufId); + if (!ret) { + return llvm::make_error( + "Failed to query the folded cached tensor of id: " + + std::to_string(bufId), + llvm::inconvertibleErrorCode()); + } + foldInfo.emplace_back(std::move(ret)); } - return std::make_shared(std::move(engine), compute); + + return std::make_shared(std::move(engine), entry, fold, + numOrigArgs, entryArgs, foldArgs, + std::move(foldInfo)); } -JitModule::JitModule(std::unique_ptr engine, - JitModuleFuncT compute) - : engine{std::move(engine)}, compute{compute} {} +JitModule::JitModule( + std::unique_ptr engine, JitModuleFuncT entry, + JitModuleFuncT fold, int32_t numOrigArgs, + // The code inside `engine` has the ownership of the buffer + llvm::ArrayRef entryArgs, + // The code inside `engine` has the ownership of the buffer + llvm::ArrayRef foldArgs, + std::vector> &&cachekeepAlive) + : engine{std::move(engine)}, entry{entry}, fold{fold}, + numOrigArgs{numOrigArgs}, foldArgs{foldArgs}, entryArgs{entryArgs}, + keepAlive{std::move(cachekeepAlive)} { + for (const auto &cache : keepAlive) { + auto currentItr = + std::find(cacheBases.begin(), cacheBases.end(), cache->base.get()); + if (currentItr == cacheBases.end()) { + cacheBases.push_back(cache->base.get()); + currentItr = cacheBases.end() - 1; + } + cacheInfo.emplace_back(CacheBufferInfo{ + static_cast(currentItr - cacheBases.begin()), cache->offset}); + fastFoldBuffers.push_back(&cache->ref); + } +} JitModule::~JitModule() = default; +static void prepareCallArgs(llvm::SmallVector &realargs, + GeneralMemrefPtr *origargs, int32_t numArgs, + int32_t numOrigArgs, GeneralMemrefPtr *foldedCache, + llvm::ArrayRef realArgIdx) { + // inputs, including unfolded and folded + realargs.reserve(realArgIdx.size()); + for (auto argIdx : realArgIdx) { + if (argIdx < numOrigArgs) { + realargs.push_back(&origargs[argIdx]); + } else { + realargs.push_back(&foldedCache[argIdx - numOrigArgs]); + } 
+ } + // outputs + for (int i = numOrigArgs; i < numArgs; ++i) { + realargs.push_back(&origargs[i]); + } +} + +void JitModule::call(GeneralMemrefPtr *args, int32_t numArgs) { + if (unlikely(cacheInfo.empty())) { + // fast path, no folded cached buffers + // Silly code, MLIR execution engine requires pointers of real args as + // inputs + llvm::SmallVector realargs; + realargs.reserve(numArgs); + for (int i = 0; i < numArgs; i++) { + realargs.push_back(&args[i]); + } + entry(realargs.data()); + return; + } + + // stage 1, acquire the foldBasePtr + llvm::SmallVector foldBasePtr; + int32_t inited = 1; + for (auto b : cacheBases) { + auto ptr = b->acquire(&inited); + if (unlikely(!ptr)) { + ptr = std::aligned_alloc(/*alignment*/ 64, b->size); + inited = 0; + } + foldBasePtr.push_back((char *)ptr); + } + + // stage 2, run fold() if necessary + GeneralMemrefPtr *foldedCache; + // only used when !inited + std::vector slowFold; + std::vector> slowFoldObj; + if (likely(inited)) { + foldedCache = fastFoldBuffers.data(); + } else { + slowFold.reserve(cacheInfo.size()); + slowFoldObj.reserve(cacheInfo.size()); + for (auto &info : cacheInfo) { + slowFoldObj.emplace_back(); + auto &obj = slowFoldObj.back(); + obj.basePtr = foldBasePtr[info.baseIdx] + info.offset; + obj.data = obj.basePtr; + memset(obj.sizes, 0, sizeof(obj.sizes)); + memset(obj.strides, 0, sizeof(obj.strides)); + slowFold.push_back(&obj); + } + foldedCache = slowFold.data(); + llvm::SmallVector realargs; + prepareCallArgs(realargs, args, numArgs, numOrigArgs, foldedCache, + foldArgs); + LLVM_DEBUG(llvm::dbgs() + << "fold func args size: " << foldArgs.size() << '\n'); + fold(realargs.data()); + } + + // stage 3, call entry + { + llvm::SmallVector realargs; + prepareCallArgs(realargs, args, numArgs, numOrigArgs, foldedCache, + entryArgs); + LLVM_DEBUG(llvm::dbgs() + << "entry func args size: " << realargs.size() << '\n'); + entry(realargs.data()); + } + + // stage 4, cleanup + for (size_t i = 0; i < cacheBases.size(); i++) { + auto b = cacheBases[i]; + if (unlikely(!b->release())) { + // if the cached buffer is already free'd, foldBasePtr[i] is allocated via + // std::aligned_alloc by us, free it + std::free(foldBasePtr[i]); + } + } +} + } // namespace gc } // namespace mlir \ No newline at end of file diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index ca15f2f78..08d60e513 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -16,6 +16,7 @@ gc_add_mlir_library(GcPasses IterativeTilingAndFusion.cpp TilingUsingInterfaceX.cpp VerifyTargetDescription.cpp + ConstantTensorFolding.cpp DecomposeAggregatedOps.cpp DeepTileContractionOp.cpp TilingUtil.cpp @@ -36,6 +37,7 @@ gc_add_mlir_library(GcPasses ${MLIR_LINK_COMPONENTS} ${GC_ONEDNN_DIALECT_LIB_NAME} GcInterface + GcAnalysis MLIRMicrokernelTransforms ) diff --git a/lib/gc/Transforms/ConstantTensorFolding.cpp b/lib/gc/Transforms/ConstantTensorFolding.cpp new file mode 100644 index 000000000..44995101f --- /dev/null +++ b/lib/gc/Transforms/ConstantTensorFolding.cpp @@ -0,0 +1,855 @@ +//===-- ConstantTensorFolding.cpp - Constant Folding ------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This transformation pass performs a constant subgraph transform in MLIR. 
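+// The constant subgraph is outlined into a `runtime_fold` function, the
+// entry function is rewritten to take the folded buffers as extra trailing
+// arguments, and the argument/buffer mapping is recorded in LLVM globals
+// (__num_orig_args, __fold_args, __compute_args, __runtime_fold_buffer_ids)
+// that the JitModule driver reads at load time.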
+// +//===----------------------------------------------------------------------===// +#include +#include +#include + +#include "gc/Analysis/DataFlow/ConstantSubgraphAnalyser.h" +#include "mlir/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Debug.h" + +#include "gc/ExecutionEngine/CPURuntime/ConstantCache.h" + +#define DEBUG_TYPE "constant-tensor-folding" + +namespace mlir { +namespace gc { +#define GEN_PASS_DEF_CONSTANTTENSORFOLDING +#include "gc/Transforms/Passes.h.inc" +} // namespace gc + +using namespace mlir; + +namespace gc { + +struct ConstantTensorFolding + : public impl::ConstantTensorFoldingBase { + void runOnOperation() override; +}; + +bool isInConstantSubgraph(Operation *op) { + auto opNamespace = op->getDialect()->getNamespace(); + if (opNamespace == linalg::LinalgDialect::getDialectNamespace() || + opNamespace == tensor::TensorDialect::getDialectNamespace() || + opNamespace == + bufferization::BufferizationDialect::getDialectNamespace() || + opNamespace == arith::ArithDialect::getDialectNamespace()) { + if (op->getAttr("onednn_graph.in_const_subgraph")) { + return true; + } + } + return false; +} + +template int64_t getDataSize(T t) { + Type eleType = t.getElementType(); + unsigned bitWidth = eleType.getIntOrFloatBitWidth() / 8; // bytes + ArrayRef shape = t.getShape(); + int64_t size = bitWidth; + for (auto s : shape) + size *= s; + + return size; +} + +int64_t getValueSize(Value v) { + if (isa(v.getType())) { + auto t = dyn_cast(v.getType()); + return getDataSize(t); + } else { + auto t = dyn_cast(v.getType()); + return getDataSize(t); + } +} + +/// @brief op has only one operand, or operands of op are one same value, or +/// operands of op are one same value or from tensor.EmptyOp. 
+/// @param op +/// @return +bool singleOperand(Operation *op) { + if (op->getNumOperands() > 1) { + Value firstOperand = op->getOperand(0); + for (int64_t i = 1; i < op->getNumOperands(); ++i) { + Value operand = op->getOperand(i); + if (firstOperand == operand) + continue; + + auto parentOp = operand.getDefiningOp(); + if (parentOp && !isa(parentOp)) + return false; + } + } + return true; +} + +bool canMoveBefore(Operation *op) { + if (op->getDialect()->getNamespace() == + arith::ArithDialect::getDialectNamespace()) { + return true; + } + + if (op->getDialect()->getNamespace() != + linalg::LinalgDialect::getDialectNamespace()) { + return false; + } + + auto linalgOp = dyn_cast(op); + + SmallVector indexingMaps = linalgOp.getIndexingMapsArray(); + for (auto &affineMap : indexingMaps) { + if (!affineMap.isIdentity()) + return false; + } + + SmallVector iterTypes = linalgOp.getIteratorTypesArray(); + for (auto &iterType : iterTypes) { + if (iterType != utils::IteratorType::parallel) + return false; + } + + if (op->getNumOperands() > 1) { + // int64_t numInputs = linalgOp.getNumDpsInputs(); + int64_t numInits = linalgOp.getNumDpsInits(); + // definingOp of init should be tensor.empty() + for (int64_t i = 0; i < numInits; ++i) { + OpOperand *outOperand = linalgOp.getDpsInitOperand(i); + auto parentOp = outOperand->get().getDefiningOp(); + if (!isa(parentOp)) + return false; + } + } + + return true; +} + +void postponeBroadcast(Block &block) { + // auto bcOps = block.getOps(); + // for (linalg::BroadcastOp bcOp : bcOps) {} + SmallVector constBcOps; + for (Operation &op : block.getOperations()) { + if (isa(&op)) { + Operation *bcOp = &op; + if (isInConstantSubgraph(bcOp)) + constBcOps.push_back(bcOp); + } + } + + for (auto bcOp : constBcOps) { + // For topo v -> pack -> bc -> mul -> matmul, we transform + // it to v -> pack -> mul -> bc -> matmul, so that we can fold + // v -> pack -> mul. Note that we require the topo to be sequential + // and all the Values have exactly one user. + + // go upwards to BlockArg + SmallVector prevOps; + Operation *currOp = bcOp; + while (true) { + if (currOp->getNumOperands() != 1) + break; + + Value operand = currOp->getOperand(0); + if (isa(operand)) { + break; + } else { + currOp = operand.getDefiningOp(); + prevOps.push_back(currOp); + } + } + + // go downwards to the last constant op + SmallVector postOps; + currOp = bcOp; + while (true) { + if (currOp->getNumResults() != 1 || !currOp->hasOneUse()) + break; + + Value input = currOp->getResult(0); + currOp = *(input.getUsers().begin()); + Value output = currOp->getResult(0); + // NOTE: we require that input shape and output shape of curr op to be + // same. Operations from tensor dialect, like + // pack/unpack/concat/collapse_shape/expand_shape/reshape/pad, are not + // supported. So we simply restrict that currOp to be from arith or + // linalg. 
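+      // The op will later be re-created on the narrower, pre-broadcast
+      // shape, so it must be elementwise (identity indexing maps, all
+      // parallel iterators); see canMoveBefore().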
+ if (!isa(input.getType()) || + !isa(output.getType()) || + dyn_cast(input.getType()).getShape() != + dyn_cast(output.getType()).getShape() || + !canMoveBefore(currOp)) { + break; + } + if (!isInConstantSubgraph(currOp)) { + break; + } else { + postOps.push_back(currOp); + } + } + if (postOps.empty()) + continue; + + // move bcOp after the last constant op + SmallVector newPostOps; + Value operand = static_cast(bcOp->getOperand(0)); + ArrayRef shapeBeforeBc = + dyn_cast(operand.getType()).getShape(); + size_t postOpId = 0; + for (Operation *postOp : postOps) { + SmallVector newOperandTypes; + for (auto oriType : postOp->getOperandTypes()) { + TensorType tt = dyn_cast(oriType); + newOperandTypes.push_back( + tt.cloneWith(shapeBeforeBc, tt.getElementType())); + } + SmallVector newResultTypes; + for (auto oriType : postOp->getResultTypes()) { + TensorType tt = dyn_cast(oriType); + newResultTypes.push_back( + tt.cloneWith(shapeBeforeBc, tt.getElementType())); + } + auto *newPostOp = + Operation::create(postOp->getLoc(), postOp->getName(), newResultTypes, + postOp->getOperands(), + /*postOp->getAttrDictionary()*/ std::nullopt, + /*postOp->getPropertiesStorage()*/ nullptr, + postOp->getSuccessors(), postOp->getNumRegions()); + for (auto [oldRegion, newRegion] : + llvm::zip(postOp->getRegions(), newPostOp->getRegions())) { + newRegion.takeBody(oldRegion); + } + + if (postOpId == 0) { + // Only the first post op needs to replace its operand. Others only + // needs to call postOp->replaceAllUsesWith(newPostOp->getResults()). + newPostOp->getOperand(0).replaceAllUsesWith(operand); + } + ++postOpId; + + newPostOp->setAttr("onednn_graph.in_const_subgraph", + postOp->getAttr("onednn_graph.in_const_subgraph")); + if (postOp->getDialect()->getNamespace() == + linalg::LinalgDialect::getDialectNamespace()) { + newPostOp->setAttr("operandSegmentSizes", + postOp->getAttr("operandSegmentSizes")); + + OpBuilder builder(postOp->getContext()); + size_t indexingMapsSize = + dyn_cast(postOp).getIndexingMapsArray().size(); + unsigned rank = shapeBeforeBc.size(); + SmallVector indexingMaps( + indexingMapsSize, builder.getMultiDimIdentityMap(rank)); + auto indexingMapsAttr = builder.getAffineMapArrayAttr(indexingMaps); + newPostOp->setAttr("indexing_maps", indexingMapsAttr); + + SmallVector iterTypes = + dyn_cast(postOp).getIteratorTypesArray(); + iterTypes.resize(rank); + auto iterTypesAttr = + builder.getArrayAttr(llvm::to_vector(llvm::map_range( + iterTypes, [&](utils::IteratorType iter) -> mlir::Attribute { + return linalg::IteratorTypeAttr::get(builder.getContext(), + iter); + }))); + newPostOp->setAttr("iterator_types", iterTypesAttr); + } else { + // Ops from other dialects. + } + + // Modify the outputOperands of postOp. Here we simply assume that the + // value is from tensor.empty(). 
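+      // tensor.empty() carries no data, so it is safe to simply retype the
+      // init operands to the pre-broadcast shape.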
+ if (postOp->getNumOperands() > 0) { + for (size_t i = 1; i < postOp->getNumOperands(); ++i) { + auto outOperand = postOp->getOperand(i); + outOperand.setType(newOperandTypes.front()); + } + } + + block.getOperations().push_back(newPostOp); + newPostOp->moveAfter(postOp); + newPostOps.push_back(newPostOp); + postOp->replaceAllUsesWith(newPostOp->getResults()); + + operand = static_cast(newPostOp->getResult(0)); + } + + auto nextOp = *(newPostOps.back()->getUsers().begin()); + nextOp->getOperand(0).replaceAllUsesWith(bcOp->getResult(0)); + bcOp->moveAfter(newPostOps.back()); + bcOp->getOperand(0).replaceUsesWithIf(operand, [&](OpOperand &val) { + Operation *op = val.getOwner(); + return op == bcOp; + }); + + for (auto it = postOps.rbegin(); it != postOps.rend(); ++it) + (*it)->erase(); + } +} + +static void addGlobalI32(ModuleOp &module, Location loc, OpBuilder &builder, + StringRef name, int32_t value) { + OpBuilder::InsertionGuard insertGuard(builder); + builder.setInsertionPointToStart(module.getBody()); + + auto type = IntegerType::get(builder.getContext(), 32); + LLVM::GlobalOp global = builder.create( + loc, type, /*isConstant=*/true, LLVM::Linkage::External, name, + builder.getI32IntegerAttr(value), + /*alignment=*/0); + (void)global; +} + +static void addGlobalI64Array(ModuleOp &module, Location loc, + OpBuilder &builder, StringRef name, + ArrayRef array) { + OpBuilder::InsertionGuard insertGuard(builder); + builder.setInsertionPointToStart(module.getBody()); + + auto type = LLVM::LLVMArrayType::get( + IntegerType::get(builder.getContext(), 64), array.size()); + LLVM::GlobalOp global = builder.create( + loc, type, /*isConstant=*/true, LLVM::Linkage::External, name, + builder.getI64TensorAttr(array), + /*alignment=*/0); + (void)global; +} + +static void addGlobalI32Array(ModuleOp &module, Location loc, + OpBuilder &builder, StringRef name, + ArrayRef array) { + OpBuilder::InsertionGuard insertGuard(builder); + builder.setInsertionPointToStart(module.getBody()); + + auto type = LLVM::LLVMArrayType::get( + IntegerType::get(builder.getContext(), 32), array.size()); + LLVM::GlobalOp global = builder.create( + loc, type, /*isConstant=*/true, LLVM::Linkage::External, name, + builder.getI32TensorAttr(array), + /*alignment=*/0); + (void)global; +} + +std::unordered_set getConstArgsIndexes(Operation &topFunc, + bool compiletime) { + auto topFuncAttr = topFunc.getAttrDictionary(); + std::unordered_set constArgsIndexes; + std::string attrName = + compiletime ? 
"compiletime_const_args_index" : "runtime_const_args_index"; + std::optional constArgs = topFuncAttr.getNamed(attrName); + if (constArgs.has_value()) { + for (auto id : llvm::dyn_cast(constArgs->getValue())) { + constArgsIndexes.insert(llvm::cast(id).getInt()); + } + } + return constArgsIndexes; +} + +void getArithConstantOutputs(Block &block, SmallVector &outputTypes, + SmallVector &outputValues) { + for (Operation &op : block.getOperations()) { + if (isa(&op)) { + Operation *constOp = &op; + auto constTensor = constOp->getResults().front(); + if (!isa(constTensor.getType())) + continue; + + auto v = dyn_cast(constTensor); + SmallVector valuesOnTheWay = {v}; // the constant tensors + std::deque dq; + dq.push_back(v); + // For v -> pack1 -> pack2 -> matmul, we need the type of output of pack2 + while (!dq.empty()) { + v = dq.front(); + dq.pop_front(); + // if the children ops of v are not all constant, we end at v + if (std::any_of(v.getUsers().begin(), v.getUsers().end(), + [](Operation *child) { + return !isInConstantSubgraph(child); + })) { + if (valuesOnTheWay.size() == 1) { + continue; + } + if (std::find(outputValues.begin(), outputValues.end(), v) == + outputValues.end()) { + outputTypes.push_back(v.getType()); + outputValues.push_back(v); + } + continue; + } + + // the children ops of v are all constant, we push their results to + // queue + for (Operation *child : v.getUsers()) { + for (OpResult result : child->getResults()) { + auto r = dyn_cast(result); + dq.push_back(r); + valuesOnTheWay.push_back(r); + } + } + } + } + } +} + +static constexpr int DATA_SIZE_EXPANDING_THRESHOLD = 8; + +void getInputsAndOutputs(Block &block, + std::unordered_set &constArgsIndexes, + SmallVector &inputTypes, + SmallVector &inputValues, + SmallVector &outputTypes, + SmallVector &outputValues) { + Value v; + // Support complicated topology. + for (size_t id = 0; id < block.getNumArguments(); ++id) { + if (constArgsIndexes.count(id) == 1) { + // The constant ops are all single-input single-output. + bool simpleTopo = true; + auto arg = block.getArgument(id); + if (!isa(arg.getType()) && !isa(arg.getType())) { + continue; + } + inputTypes.push_back(arg.getType()); + v = dyn_cast(arg); + inputValues.push_back(v); + SmallVector valuesOnTheWay = {v}; // the constant tensors + std::deque dq; + dq.push_back(v); + // For v -> pack1 -> pack2 -> matmul, we need the type of output of pack2 + while (!dq.empty()) { + v = dq.front(); + dq.pop_front(); + // if the children ops of v are not all constant, we end at v + if (std::any_of(v.getUsers().begin(), v.getUsers().end(), + [](Operation *child) { + return !isInConstantSubgraph(child); + })) { + // skip case: memref v -> bufferization.to_tensor -> tensor t. 
+ if (valuesOnTheWay.size() == 2 && v.hasOneUse() && + isa(v.getDefiningOp())) { + inputTypes.pop_back(); + inputValues.pop_back(); + constArgsIndexes.erase(id); + continue; + } + if (std::find(outputValues.begin(), outputValues.end(), v) == + outputValues.end()) { + outputTypes.push_back(v.getType()); + outputValues.push_back(v); + } + continue; + } + if (!v.hasOneUse()) + simpleTopo = false; + + // the children ops of v are all constant, we push their results to + // queue + for (Operation *child : v.getUsers()) { + if (!singleOperand(child) || child->getResults().size() > 1) + simpleTopo = false; + + for (OpResult result : child->getResults()) { + auto r = dyn_cast(result); + dq.push_back(r); + valuesOnTheWay.push_back(r); + } + } + } + + // If data size of outputValue is too greater than size of inputValue, do + // not fold it. Compare data size changes during traverse to find the last + // op that satisfies this condition. + if (simpleTopo) { + int64_t initSize = getValueSize(valuesOnTheWay[0]); + if (initSize * DATA_SIZE_EXPANDING_THRESHOLD < + getValueSize(valuesOnTheWay.back())) { + size_t lastIdx = 0; + for (size_t i = 1; i < valuesOnTheWay.size(); ++i) { + int64_t size = getValueSize(valuesOnTheWay[i]); + if (initSize * DATA_SIZE_EXPANDING_THRESHOLD > size) { + lastIdx = i; + } + } + if (lastIdx == 0) { // no suitable value found + inputTypes.pop_back(); + outputTypes.pop_back(); + inputValues.pop_back(); + outputValues.pop_back(); + constArgsIndexes.erase(id); + } else { + outputTypes.back() = valuesOnTheWay[lastIdx].getType(); + outputValues.back() = valuesOnTheWay[lastIdx]; + } + } + } + } + } +} + +func::FuncOp buildFoldFunc(MLIRContext *context, OpBuilder &builder, + Operation *topOp, const std::string &name, + const SmallVector &constOps, + SmallVector &inputTypes, + SmallVector &inputValues, + SmallVector &outputTypes, + SmallVector &outputValues) { + FunctionType foldFuncType = + FunctionType::get(context, inputTypes, outputTypes); + func::FuncOp foldFunc = + builder.create(topOp->getLoc(), name, foldFuncType); + Block *foldBlock = foldFunc.addEntryBlock(); + // values of folded constant tensors in foldBlock + SmallVector outputValuesInFold; + IRMapping mapper; + for (Operation *op : constOps) { + foldBlock->getOperations().push_back(op->clone(mapper)); + } + // the order of outputValuesInFold is according to the order of corresponding + // inputValues + for (auto &v : outputValues) { + auto foldedV = mapper.lookupOrNull(v); + outputValuesInFold.push_back(foldedV); + v.replaceUsesWithIf(foldedV, [&](OpOperand &val) { + Operation *op = val.getOwner(); + return op->getBlock() == foldBlock; + }); + } + + // Allocate buffer for outputValuesInFold + std::vector buffersSize; + for (Value &tensor : outputValuesInFold) { + LLVM_DEBUG(llvm::dbgs() + << "Allocate buffer for tensor: " << tensor << "\n"); + buffersSize.push_back(getValueSize(tensor)); + } + auto cacheManager = ConstGraphTensorCacheManager::get(); + SmallVector globalIndexes; + for (auto id : cacheManager->alloc(buffersSize)) + globalIndexes.push_back(id); + + globalIndexes.insert(globalIndexes.begin(), globalIndexes.size()); + auto moduleOp = dyn_cast(topOp); + addGlobalI64Array(moduleOp, moduleOp.getLoc(), builder, + "__" + name + "_buffer_ids", globalIndexes); + + auto returnOp = + builder.create(topOp->getLoc(), outputValuesInFold); + foldBlock->getOperations().push_back(returnOp); + for (size_t i = 0; i < inputValues.size(); ++i) { + inputValues[i].replaceUsesWithIf(foldBlock->getArgument(i), + [&](OpOperand 
&val) { + Operation *op = val.getOwner(); + return op->getBlock() == foldBlock; + }); + } + + // the ranks of folded results. + SmallVector foldRanks; + // the shapes of folded results. + SmallVector foldShapes; + for (Value &tensor : outputValuesInFold) { + auto t = dyn_cast(tensor.getType()); + Type eleType = t.getElementType(); + int64_t bitWidth = eleType.getIntOrFloatBitWidth() / 8; // bytes + ArrayRef shape = t.getShape(); + foldRanks.push_back(shape.size()); + foldShapes.insert(foldShapes.end(), shape.begin(), shape.end()); + foldShapes.push_back(bitWidth); + } + addGlobalI32Array(moduleOp, moduleOp.getLoc(), builder, "__folded_ranks", + foldRanks); + addGlobalI64Array(moduleOp, moduleOp.getLoc(), builder, "__folded_shapes", + foldShapes); + + foldFunc.setVisibility(SymbolTable::Visibility::Public); + foldFunc->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(), + UnitAttr::get(context)); + + moduleOp.push_back(foldFunc); + SymbolTable symbolTable(moduleOp); + symbolTable.insert(foldFunc); + + return foldFunc; +} + +void modifyComputeFunc(MLIRContext *context, OpBuilder &builder, + Operation *topOp, Operation &func, Block &block, + std::unordered_set &constArgsIndexes, + SmallVector &outputTypes, + SmallVector &outputValues) { + // the indexes of args to the folding func, including to-fold tensors and + // folded results. + SmallVector foldArgs; + // the indexes of folded results. + SmallVector foldIds; + // the indexes of args to the computing func, including non-fold tensors and + // folded results. + SmallVector computeArgs; + + // modify the BlockArguments of block + size_t oriNumArgs = block.getNumArguments(); + // Add the folded args to the end of BlockArguments list + for (size_t id = 0; id < outputValues.size(); ++id) { + auto loc = block.getArgument(id).getLoc(); + BlockArgument foldArg = + block.insertArgument(oriNumArgs + id, outputTypes[id], loc); + outputValues[id].replaceUsesWithIf(foldArg, [&](OpOperand &val) { + Operation *op = val.getOwner(); + return op->getBlock() == █ + }); + foldIds.push_back(id + oriNumArgs); + } + // Erase the operations on constant args + for (size_t id = 0; id < oriNumArgs; ++id) { + if (constArgsIndexes.count(id) == 1) { + foldArgs.push_back(id); + std::deque dq; + SmallVector opsToErase; + std::unordered_set opsToEraseSet; + dq.push_back(block.getArgument(id)); + while (!dq.empty()) { + Value v = dq.front(); + dq.pop_front(); + for (Operation *op : v.getUsers()) { + for (auto res : op->getResults()) { + dq.push_back(res); + } + if (opsToEraseSet.count(op)) { + break; + } + opsToErase.push_back(op); + opsToEraseSet.insert(op); + } + } + for (auto it = opsToErase.rbegin(); it != opsToErase.rend(); ++it) { + (*it)->erase(); + } + } else { + computeArgs.push_back(id); + } + } + // Erase the constant args in BlockArguments list + llvm::BitVector argsToErase; + for (size_t id = 0; id < oriNumArgs; ++id) { + if (constArgsIndexes.count(id) == 1) { + argsToErase.push_back(true); + } else { + argsToErase.push_back(false); + } + } + for (size_t id = 0; id < outputValues.size(); ++id) { + argsToErase.push_back(false); + } + block.eraseArguments(argsToErase); + + // modify the compute func signature + func::FuncOp computeFunc = cast(func); + FunctionType computeFuncType = computeFunc.getFunctionType(); + computeFunc.setType(FunctionType::get(context, block.getArgumentTypes(), + computeFuncType.getResults())); + + auto moduleOp = dyn_cast(topOp); + for (auto id : foldIds) { + foldArgs.insert(foldArgs.end(), id); + } + foldArgs.insert(foldArgs.begin(), 
foldArgs.size()); + addGlobalI32Array(moduleOp, moduleOp.getLoc(), builder, "__fold_args", + foldArgs); + + for (auto id : foldIds) { + computeArgs.insert(computeArgs.end(), id); + } + computeArgs.insert(computeArgs.begin(), computeArgs.size()); + addGlobalI32Array(moduleOp, moduleOp.getLoc(), builder, "__compute_args", + computeArgs); + + addGlobalI32(moduleOp, moduleOp.getLoc(), builder, "__num_orig_args", + oriNumArgs); +} + +void canonicalizeAndClean(MLIRContext *context, Operation *topOp) { + // Delete dead operations by dialects' canonicalizer + RewritePatternSet owningPatterns(context); + for (auto *dialect : context->getLoadedDialects()) + dialect->getCanonicalizationPatterns(owningPatterns); + + ArrayRef disabledPatterns, enabledPatterns; + std::shared_ptr patterns = + std::make_shared( + std::move(owningPatterns), disabledPatterns, enabledPatterns); + GreedyRewriteConfig config; + LogicalResult converged = + applyPatternsAndFoldGreedily(topOp, *patterns, config); + (void)converged; + + // clean up the constant-related attrs on ops + topOp->walk([&](Operation *op) { + if (op->getAttr("onednn_graph.in_const_subgraph")) { + op->removeAttr("onednn_graph.in_const_subgraph"); + } + }); + topOp->walk([&](func::FuncOp op) { + if (op.getOperation()->getAttr("compiletime_const_args_index")) { + op.getOperation()->removeAttr("compiletime_const_args_index"); + } + if (op.getOperation()->getAttr("runtime_const_args_index")) { + op.getOperation()->removeAttr("runtime_const_args_index"); + } + }); +} + +// Operate on tensors. Create fold() and compute() on module. The +// folded weights and first-run flag is maintained by upper-level runtime. +void ConstantTensorFolding::runOnOperation() { + Operation *topOp = getOperation(); + MLIRContext *context = topOp->getContext(); + auto &topFunc = + topOp->getRegions().front().getBlocks().front().getOperations().front(); + + dataflow::RunConstantSubgraphAnalyser runAnalyser; + (void)runAnalyser.run(&topFunc); + + OpBuilder builder(context); + Region ®ion = topFunc.getRegions().front(); + Block &block = region.getBlocks().front(); + + std::unordered_set compiletimeConstArgsIndexes = + getConstArgsIndexes(topFunc, true); + std::unordered_set runtimeConstArgsIndexes = + getConstArgsIndexes(topFunc, false); + if (compiletimeConstArgsIndexes.empty() && runtimeConstArgsIndexes.empty()) { + return; + } + + postponeBroadcast(block); + + SmallVector constOps; + for (Operation &op : llvm::make_early_inc_range(block)) { + if (isInConstantSubgraph(&op)) { + constOps.push_back(&op); + } + } + + bool enableCompiletimeFolding = false; + if (enableCompiletimeFolding) { + // ===== build compile time folding function ===== + SmallVector compiletimeInputTypes; // types of constant tensors + // values of constant tensors in original block + SmallVector compiletimeInputValues; + SmallVector + compiletimeOutputTypes; // types of folded constant tensors + // values of folded constant tensors in original block + SmallVector compiletimeOutputValues; + getArithConstantOutputs(block, compiletimeOutputTypes, + compiletimeOutputValues); + getInputsAndOutputs(block, compiletimeConstArgsIndexes, + compiletimeInputTypes, compiletimeInputValues, + compiletimeOutputTypes, compiletimeOutputValues); + assert(compiletimeInputTypes.size() == compiletimeInputValues.size()); + assert(compiletimeOutputTypes.size() == compiletimeOutputValues.size()); + + if (!compiletimeOutputTypes.empty()) { + func::FuncOp compiletimeFoldFunc = + buildFoldFunc(context, builder, topOp, "compiletime_fold", 
constOps, + compiletimeInputTypes, compiletimeInputValues, + compiletimeOutputTypes, compiletimeOutputValues); + (void)compiletimeFoldFunc; + canonicalizeAndClean(context, compiletimeFoldFunc.getOperation()); + } + + // ===== build runtime folding function ===== + SmallVector runtimeInputTypes; // types of constant tensors + // values of constant tensors in original block + SmallVector runtimeInputValues; + SmallVector runtimeOutputTypes; // types of folded constant tensors + // values of folded constant tensors in original block + SmallVector runtimeOutputValues; + getInputsAndOutputs(block, runtimeConstArgsIndexes, runtimeInputTypes, + runtimeInputValues, runtimeOutputTypes, + runtimeOutputValues); + assert(runtimeInputTypes.size() == runtimeInputValues.size()); + assert(runtimeOutputTypes.size() == runtimeOutputValues.size()); + + if (!runtimeOutputTypes.empty()) { + func::FuncOp runtimeFoldFunc = buildFoldFunc( + context, builder, topOp, "runtime_fold", constOps, runtimeInputTypes, + runtimeInputValues, runtimeOutputTypes, runtimeOutputValues); + (void)runtimeFoldFunc; + canonicalizeAndClean(context, runtimeFoldFunc.getOperation()); + } + + // ===== build computing function ===== + std::unordered_set constArgsIndexes = compiletimeConstArgsIndexes; + constArgsIndexes.merge(runtimeConstArgsIndexes); + SmallVector outputTypes = compiletimeOutputTypes; + outputTypes.insert(outputTypes.end(), runtimeOutputTypes.begin(), + runtimeOutputTypes.end()); + SmallVector outputValues = compiletimeOutputValues; + outputValues.insert(outputValues.end(), runtimeOutputValues.begin(), + runtimeOutputValues.end()); + if (!outputTypes.empty()) { + modifyComputeFunc(context, builder, topOp, topFunc, block, + constArgsIndexes, outputTypes, outputValues); + } + } else { + std::unordered_set constArgsIndexes = compiletimeConstArgsIndexes; + constArgsIndexes.merge(runtimeConstArgsIndexes); + + // ===== build runtime folding function ===== + SmallVector inputTypes; // types of constant tensors + // values of constant tensors in original block + SmallVector inputValues; + SmallVector outputTypes; // types of folded constant tensors + // values of folded constant tensors in original block + SmallVector outputValues; + getArithConstantOutputs(block, outputTypes, outputValues); + getInputsAndOutputs(block, constArgsIndexes, inputTypes, inputValues, + outputTypes, outputValues); + assert(inputTypes.size() == inputValues.size()); + assert(outputTypes.size() == outputValues.size()); + + if (!outputTypes.empty()) { + func::FuncOp foldFunc = + buildFoldFunc(context, builder, topOp, "runtime_fold", constOps, + inputTypes, inputValues, outputTypes, outputValues); + (void)foldFunc; + canonicalizeAndClean(context, foldFunc.getOperation()); + + // ===== build computing function ===== + modifyComputeFunc(context, builder, topOp, topFunc, block, + constArgsIndexes, outputTypes, outputValues); + } + } + + canonicalizeAndClean(context, topOp); + topOp->dump(); +} + +std::unique_ptr createConstantTensorFoldingPass() { + return std::make_unique(); +} + +} // namespace gc +} // namespace mlir diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index 0c118fbda..40527f644 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -52,7 +52,7 @@ void populateFrontendPasses(mlir::OpPassManager &pm) { void populateTensorPasses(mlir::OpPassManager &pm) { // todo: padding propagation pass // todo: layout propagation pass - // todo: tensor constant propagation pass + 
pm.addPass(createConstantTensorFoldingPass()); // linalg.matmul lowering to (scf.loop + linalg.brgemm) pass pm.addNestedPass(createDeepTileContractionOp()); diff --git a/test/gc/Transforms/test_constant_tensor_folding-0.mlir b/test/gc/Transforms/test_constant_tensor_folding-0.mlir new file mode 100644 index 000000000..155e0875e --- /dev/null +++ b/test/gc/Transforms/test_constant_tensor_folding-0.mlir @@ -0,0 +1,87 @@ +// RUN: gc-opt --split-input-file -pass-pipeline="builtin.module(constant-tensor-folding)" %s | FileCheck %s + +// COM:A complete example of compile-time and runtime folding. + +// CHECK-LABEL: func.func @entry +#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6 floordiv 2, d5, d3)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)> +module { + // COM: A three-layer mlp. %arg0: input feature. %arg1, %arg2, %arg3: weight of #1, #2 and #3 linear. + func.func @entry(%arg0: tensor<64x32xbf16>, %arg2: tensor<32x256xbf16>, %arg3: tensor<256x1024xbf16>) + -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, compiletime_const_args_index = [1 : i32], runtime_const_args_index = [2 : i32]} { + %1 = tensor.empty() : tensor<2x1x32x32xbf16> + %packed_arg0 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<64x32xbf16> -> tensor<2x1x32x32xbf16> + + %arg1 = arith.constant dense<"0x99571CBE05BA1C3D926AFCBD782B34BE67A737BEBF181ABE3C4E253B5D32F73D1566963D9B6F313EB74C0FBD430B253AE8E2E23DB0CBC53C46C014BE2E0981BD4D313F3E833D37BEAB70E13D65B6CA3DB194983D1E60983D950B71BD1815FDBB32DF9A3DD106EBBDB4A8233E841EC3BDE8C7C13D3734073EFF067DBD070206BEF6AF633DB209843C2135C3BD4F85B83C43BD1CBE04841A3E3E78BD3DE9D0D0BCF660093ED074083E14D43E3ECDA735BE8C8C0E3E40C60FBE4F73C9BDB4358DBD263D323C64E61EBEE535D23D238F013C727EA73DBDBAA1BD79D53EBE392981BDC06B453D10E37D3D2D2B41BEE1FA6BBD410E513D05588BBD514AB0BB0624243E3D80993C8E6A113EE57CFD3D23FE37BE001573BD86AD143E7F052D3E97C07DBD19B4113D3E87F6BDB971E83DFEA12BBC5D51F9BD4F203A3ED454043E22775BBD2EE8313EB027D03D8FEFD7BD0E56B7BDBF963FBE5B64E93D9291FBBD027101BE573DFD3D0CD6EB3D809B863DA9E8263E9EF2A43D717AB73D3CF597BD9FB7243DC603003D61780E3E3992293D8B1B25BE6B0024BE806DCB3D5BAB91BD9A33AFBDD5BC3BBE6D920FBE0D90F53D4513383E2219A0BBE8B6FBBD341C42BD42F235BED91A1ABDC3AEB0BD5AC1383DE0EADC3D303D11BE850D263E8281163E5CB78A3D19EB34BE33150F3E84F8EE3D18FC823DB26CCBBD09AB06BED909FFBA605EFE3B9014B7BD1606DA3D75ACE13D0910753C33C6843DE9951CBECD220ABD0EF2BF3D14BB2E3C798718BD60A53A3E8B83E53D18663DBE4D07CABD37CE043EA6B18E3D3D0F303EE392073EC92A1ABED6900E3E72D3E73D8CEF803D1B4D3D3E997D283E210F923BC2D131BECEAF913DB981EFBDCBCCCCBA2B6711BE4E32FE3C5D5D33BD2F34313EB7EC48BC26CDFD3D07170B3E1CD816BE310DD2BD9E03023E1EA8F3BD8B99EEBBFC97433E047F8DBDDD6BA03DA3B2433E34D7C0BC7FDB89BA1980333EF3FC8D3DC05C203E9C7213BD8385403E2F971A3E4357CF3DB39BFBBC784FF8BC7DBD0C3E8301E23D77BF1ABB04F3243CFBA3B1BD5A46C6BD1745A8BDD6950ABD939CC5BDB4226EBCAC622EBD6748FBBDAFF9D53DF29D433E41991C3D4DD7353EE2EF8E3D21EF3B3DF679973D31DEFDBDF0AF303E8D34DFBB31B895BD6A633A3EACE125BEE94E95BDA58043BEC9F233BE915F03BD1B7C8F3DE1D367BDD7BBD63D6E990A3E23222F3D4B6CD73DB869C53D8697383E3A86853D973F2C3EFC3827BC4E87FA3DD5903BBE4BB8403E34A9A33D41C8843D4BC8FABD3CD5E8BD4946233D955052BDA5F841BC6C81AFBD5DD8883DB71A753CD0A1263D88690ABE35DAA73CA3557D3D8C09D23D5A27273DECEFDBBCD220023EE036ACBD6CD2443E8F630FBEBC43B73DF03AA4BDC709133E1B94E73D362CE4BCB15F33BE3139443E5FCF62BD0E3C1B3EE99DF93D9E1BB3BA70DB213E38EBDDBC47F10CBEF817293DAD3DEB3B7309
42BE535C87BD448D7B3B1C8094BD97962B3D5B0F3B3EA3F42A3E4ED46DBD6D72C33C687CC63DEA34C53D1CCC3EBEDCA640BE638ABCBD4B63AFBDA699063E92861E3E98219FBC8E0B233ED3ED573DC856B8BD13880F3EFA0763BD5A8C89BD194519BE89C6CF3D73A219BC5ECBD43D41EFA33D27D8493D756B1ABEC796C93D9A25133C6A5A363E13FB8DBD601755BD3935FABD14D6883D0EF2D33DB8E914BD527347397200433DE72A3F3B62C52F3ED164EF3CD8806FBD05528B3D89701EBE0A09C23DA19B103D05922EBE7A100E3E31C0503D8ED53BBE08463E3E5168013E55F3E53D782EC53DA8BBD93C1711223E05FDB2BDA740113EA27A20BD1685A23D7E35293E02BD8B3CC43F163E4AE6613DE4280F3EEEF20BBE965C1DBEFAAD233E75754E3D96C33BBCB6D7013E0D8E7ABD703C82BDEA0875BC6F57A6BCE83609BE8A8EB53DAB7D3C3E39A50ABEB878A33D9FCEA1BC124AD33C22C34A3DB5F338BE0307BF3C2F0881BD7E15E8BDBEE8C8BDBBFFA63C342F303E15B1CCBB2590153EEA05EF3DE778F2BCE9E1233ECEC244BDBF92D5BDECDEAE3C29750CBDD969FCBD7DC236BE571D1DBEC8FA7DBC243BAD3C38673D3ED15943BEFE4D913D5329273E18AB2EBE19AB5F3D30A62F3E94303CBE1421DABCBE6E133E355D073EEC76633DEB2AB83DA2BF16BC9A46C2BD4EB47EBC4C82343EC1D1E63D13D314BED232E3BD3E5CF1BDC78F9EBD6483233E7290293E514A163E255F0FBE1AEF7BBD5259173EF12524BEDF47793C886BE8BD57B408BE351980BD0FF71ABD24643ABEA79920BED2603A3EEB75393EC6D52B3E458B29BC22C45ABC02BB40BCED4BDEBCA6E9CABC11FB213EC4FB363E5AC2DCBDAD6B4F3CBB85B1BD8093343E487518BEDFA316BD7FFFAEBB9375963DF68A88BD6876013C9FA1C63D95CDB23C911721BE04B5F9BD1B7C8F3DE1D367BDD7BBD63D6E990A3E23222F3D4B6CD73DB869C53D8697383E3A86853D973F2C3EFC3827BC4E87FA3DD5903BBE4BB8403E34A9A33D41C8843D4BC8FABD3CD5E8BD4946233D955052BDA5F841BC6C81AFBD5DD8883DB71A753CD0A1263D88690ABE35DAA73CA3557D3D8C09D23D5A27273DECEFDBBCD220023EE036ACBD6CD2443E8F630FBEBC43B73DF03AA4BDC709133E1B94E73D362CE4BCB15F33BE3139443E5FCF62BD0E3C1B3EE99DF93D9E1BB3BA70DB213E38EBDDBC47F10CBEF817293DAD3997D283E210F923BC2D131BECEAF913DB981EFBDCBCCCCBA2B6711BE4E32FE3C5D5D33BD2F34313EB7EC48BC26CDFD3D07170B3E1CD816BE310DD2BD9E03023E1EA8F3BD8B99EEBBFC97433E047F8DBDDD6BA03DA3B2433E34D7C0BC7FDB89BA1980333EF3EB7EC48B383DE0E383DE0E383DE0E383DE0"> : tensor<32x32xbf16> + %2 = tensor.empty() : tensor<1x1x32x32xbf16> + %packed_arg1 = tensor.pack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %2 : tensor<32x32xbf16> -> tensor<1x1x32x32xbf16> + %3 = tensor.empty() : tensor<1x1x16x32x2xbf16> + %packed_packed_arg1 = tensor.pack %packed_arg1 inner_dims_pos = [2] inner_tiles = [2] into %3 : tensor<1x1x32x32xbf16> -> tensor<1x1x16x32x2xbf16> + + %4 = tensor.empty() : tensor<2x1x32x32xbf16> + %cst_0 = arith.constant 0.000000e+00 : bf16 + %5 = linalg.fill ins(%cst_0 : bf16) outs(%4 : tensor<2x1x32x32xbf16>) -> tensor<2x1x32x32xbf16> + %6 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%packed_arg0, %packed_packed_arg1 : tensor<2x1x32x32xbf16>, tensor<1x1x16x32x2xbf16>) outs(%5 : tensor<2x1x32x32xbf16>) { + ^bb0(%in: bf16, %in_0: bf16, %out: bf16): + %44 = arith.mulf %in, %in_0 : bf16 + %55 = arith.addf %out, %44 : bf16 + linalg.yield %55 : bf16 + } -> tensor<2x1x32x32xbf16> + + %7 = tensor.empty() : tensor<8x1x32x32xbf16> + %packed_arg2 = tensor.pack %arg2 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %7 : tensor<32x256xbf16> -> tensor<8x1x32x32xbf16> + %8 = tensor.empty() : tensor<8x1x16x32x2xbf16> + %packed_packed_arg2 = tensor.pack %packed_arg2 inner_dims_pos = [2] inner_tiles = [2] into %8 : tensor<8x1x32x32xbf16> -> tensor<8x1x16x32x2xbf16> + %9 = tensor.empty() : tensor<2x8x32x32xbf16> + %10 = linalg.fill 
ins(%cst_0 : bf16) outs(%9 : tensor<2x8x32x32xbf16>) -> tensor<2x8x32x32xbf16> + %11 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%6, %packed_packed_arg2 : tensor<2x1x32x32xbf16>, tensor<8x1x16x32x2xbf16>) outs(%10 : tensor<2x8x32x32xbf16>) { + ^bb0(%in: bf16, %in_0: bf16, %out: bf16): + %44 = arith.mulf %in, %in_0 : bf16 + %55 = arith.addf %out, %44 : bf16 + linalg.yield %55 : bf16 + } -> tensor<2x8x32x32xbf16> + + %12 = tensor.empty() : tensor<32x8x32x32xbf16> + %packed_arg3 = tensor.pack %arg3 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %12 : tensor<256x1024xbf16> -> tensor<32x8x32x32xbf16> + %13 = tensor.empty() : tensor<32x8x16x32x2xbf16> + %packed_packed_arg3 = tensor.pack %packed_arg3 inner_dims_pos = [2] inner_tiles = [2] into %13 : tensor<32x8x32x32xbf16> -> tensor<32x8x16x32x2xbf16> + + %14 = tensor.empty() : tensor<2x32x32x32xbf16> + %15 = linalg.fill ins(%cst_0 : bf16) outs(%14 : tensor<2x32x32x32xbf16>) -> tensor<2x32x32x32xbf16> + %16 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%11, %packed_packed_arg3 : tensor<2x8x32x32xbf16>, tensor<32x8x16x32x2xbf16>) outs(%15 : tensor<2x32x32x32xbf16>) { + ^bb0(%in: bf16, %in_0: bf16, %out: bf16): + %46 = arith.mulf %in, %in_0 : bf16 + %56 = arith.addf %out, %46 : bf16 + linalg.yield %56 : bf16 + } -> tensor<2x32x32x32xbf16> + + %17 = tensor.empty() : tensor<64x1024xbf16> + %unpack = tensor.unpack %16 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %17 : tensor<2x32x32x32xbf16> -> tensor<64x1024xbf16> + return %unpack : tensor<64x1024xbf16> + } +} + +// COM: If enable compile time folding, +// COM: 1 pack in entry for input feature, +// COM: 4 packs in compiletime_fold for 2 weights, +// COM: 2 packs in runtime_fold for 1 weights: +// COM: CHECK: tensor.pack +// COM: CHECK: func.func @compiletime_fold +// COM: CHECK: tensor.pack +// COM: CHECK: tensor.pack +// COM: CHECK: tensor.pack +// COM: CHECK: tensor.pack +// COM: CHECK: func.func @runtime_fold +// COM: CHECK: tensor.pack +// COM: CHECK: tensor.pack + +// COM: else, +// CHECK: tensor.pack +// CHECK: func.func @runtime_fold +// CHECK: tensor.pack +// CHECK: tensor.pack +// CHECK: tensor.pack +// CHECK: tensor.pack +// CHECK: tensor.pack +// CHECK: tensor.pack diff --git a/test/gc/Transforms/test_constant_tensor_folding-1.mlir b/test/gc/Transforms/test_constant_tensor_folding-1.mlir new file mode 100644 index 000000000..ca70f8d6a --- /dev/null +++ b/test/gc/Transforms/test_constant_tensor_folding-1.mlir @@ -0,0 +1,91 @@ +// RUN: gc-opt --split-input-file -pass-pipeline="builtin.module(constant-tensor-folding)" %s | FileCheck %s + +// COM: Test the 'postponeBroadcast' feature of constant tensor folding. + +// CHECK-LABEL: func.func @entry +#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d4, d6)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d2, d6 floordiv 2, d5, d3)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d4, d5)> +#map3 = affine_map<(d0, d1, d2, d3) -> (d1, d3)> +#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +module { + // COM: A two-layer mlp. arg0: input feature. + // COM: arg1: weight of #1 linear. arg2: bias of #1 linear. + // COM: arg3: weight of #2 linear. arg4: bias of #2 linear. 
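+  // COM: Expected effect of 'postponeBroadcast' (see the CHECKs and the expected
+  // COM: output at the end of this file): the pack/extf/mul/truncf chain on %arg2
+  // COM: is hoisted into runtime_fold so that only the small 8x32 packed bias is
+  // COM: cached, while the broadcast to 2x8x32x32 and the bias add stay in entry().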
+ func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<512x256xbf16>, %arg2: tensor<256xbf16>, %arg3: tensor<256x1024xbf16>, %arg4: tensor<1024xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, runtime_const_args_index = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} { + %1 = tensor.empty() : tensor<2x16x32x32xbf16> + %packed_arg0 = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<64x512xbf16> -> tensor<2x16x32x32xbf16> + %2 = tensor.empty() : tensor<8x16x32x32xbf16> + %packed_arg1 = tensor.pack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %2 : tensor<512x256xbf16> -> tensor<8x16x32x32xbf16> + %3 = tensor.empty() : tensor<8x16x16x32x2xbf16> + %packed_packed_arg1 = tensor.pack %packed_arg1 inner_dims_pos = [2] inner_tiles = [2] into %3 : tensor<8x16x32x32xbf16> -> tensor<8x16x16x32x2xbf16> + %4 = tensor.empty() : tensor<2x8x32x32xbf16> + %cst_0 = arith.constant 0.000000e+00 : bf16 + %5 = linalg.fill ins(%cst_0 : bf16) outs(%4 : tensor<2x8x32x32xbf16>) -> tensor<2x8x32x32xbf16> + %6 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%packed_arg0, %packed_packed_arg1 : tensor<2x16x32x32xbf16>, tensor<8x16x16x32x2xbf16>) outs(%5 : tensor<2x8x32x32xbf16>) { + ^bb0(%in: bf16, %in_0: bf16, %out: bf16): + %44 = arith.mulf %in, %in_0 : bf16 + %55 = arith.addf %out, %44 : bf16 + linalg.yield %55 : bf16 + } -> tensor<2x8x32x32xbf16> + + // COM: Operations on %arg2: {pack, broadcast, extf, mul, truncf, bias_add} in entry(). + %15 = tensor.empty() : tensor<8x32xbf16> + %packed_arg2 = tensor.pack %arg2 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %15 : tensor<256xbf16> -> tensor<8x32xbf16> + %bc_arg2_init = tensor.empty() : tensor<2x8x32x32xbf16> + %bc_arg2 = linalg.broadcast ins(%packed_arg2 : tensor<8x32xbf16>) outs(%bc_arg2_init : tensor<2x8x32x32xbf16>) dimensions = [0, 2] + %extf32 = arith.extf %bc_arg2 : tensor<2x8x32x32xbf16> to tensor<2x8x32x32xf32> + %cst_2 = arith.constant 2.000000e+00 : f32 + %extf32_mul2_init = tensor.empty() : tensor<2x8x32x32xf32> + %extf32_mul2 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%extf32 : tensor<2x8x32x32xf32>) outs(%extf32_mul2_init : tensor<2x8x32x32xf32>) { + ^bb0(%in: f32, %out: f32): + %8 = arith.mulf %in, %cst_2 : f32 + linalg.yield %8 : f32 + } -> tensor<2x8x32x32xf32> + %truncbf16 = arith.truncf %extf32_mul2 : tensor<2x8x32x32xf32> to tensor<2x8x32x32xbf16> + + %7 = linalg.generic {indexing_maps = [#map4, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%truncbf16 : tensor<2x8x32x32xbf16>) outs(%6 : tensor<2x8x32x32xbf16>) { + ^bb0(%in: bf16, %out: bf16): + %45 = arith.addf %in, %out : bf16 + linalg.yield %45 : bf16 + } -> tensor<2x8x32x32xbf16> + + %8 = tensor.empty() : tensor<32x8x32x32xbf16> + %packed_arg3 = tensor.pack %arg3 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %8 : tensor<256x1024xbf16> -> tensor<32x8x32x32xbf16> + %9 = tensor.empty() : tensor<32x8x16x32x2xbf16> + %packed_packed_arg3 = tensor.pack %packed_arg3 inner_dims_pos = [2] inner_tiles = [2] into %9 : tensor<32x8x32x32xbf16> -> tensor<32x8x16x32x2xbf16> + %10 = tensor.empty() : tensor<2x32x32x32xbf16> + %11 = linalg.fill ins(%cst_0 : bf16) outs(%10 : tensor<2x32x32x32xbf16>) -> tensor<2x32x32x32xbf16> + %12 = linalg.generic 
{indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel", "reduction"]} ins(%7, %packed_packed_arg3 : tensor<2x8x32x32xbf16>, tensor<32x8x16x32x2xbf16>) outs(%11 : tensor<2x32x32x32xbf16>) { + ^bb0(%in: bf16, %in_0: bf16, %out: bf16): + %46 = arith.mulf %in, %in_0 : bf16 + %56 = arith.addf %out, %46 : bf16 + linalg.yield %56 : bf16 + } -> tensor<2x32x32x32xbf16> + %16 = tensor.empty() : tensor<32x32xbf16> + %packed_arg4 = tensor.pack %arg4 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %16 : tensor<1024xbf16> -> tensor<32x32xbf16> + %13 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%packed_arg4 : tensor<32x32xbf16>) outs(%12 : tensor<2x32x32x32xbf16>) { + ^bb0(%in: bf16, %out: bf16): + %47 = arith.addf %in, %out : bf16 + linalg.yield %47 : bf16 + } -> tensor<2x32x32x32xbf16> + %14 = tensor.empty() : tensor<64x1024xbf16> + %unpack = tensor.unpack %13 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %14 : tensor<2x32x32x32xbf16> -> tensor<64x1024xbf16> + return %unpack : tensor<64x1024xbf16> + } +} + +// COM: After transform, operations on %arg2: {pack, extf, mul, truncf} in fold(), {broadcast, bias_add} in entry(). +// CHECK: linalg.broadcast +// CHECK: func.func @runtime_fold +// CHECK: arith.extf +// CHECK: arith.truncf + +// COM: expected output: +// COM: module { +// COM: llvm.mlir.global external constant @__num_orig_args(5 : i32) {addr_space = 0 : i32} : i32 +// COM: llvm.mlir.global external constant @__compute_args(dense<[5, 0, 5, 6, 7, 8]> : tensor<6xi32>) {addr_space = 0 : i32} : !llvm.array<6 x i32> +// COM: llvm.mlir.global external constant @__fold_args(dense<[8, 1, 2, 3, 4, 5, 6, 7, 8]> : tensor<9xi32>) {addr_space = 0 : i32} : !llvm.array<9 x i32> +// COM: llvm.mlir.global external constant @__fold_buffer_ids(dense<[4, 0, 1, 2, 3]> : tensor<5xi64>) {addr_space = 0 : i32} : !llvm.array<5 x i64> +// COM: func.func @entry(%arg0: tensor<64x512xbf16>, %arg1: tensor<8x16x16x32x2xbf16>, %arg2: tensor<8x32xbf16>, %arg3: tensor<32x8x16x32x2xbf16>, %arg4: tensor<32x32xbf16>) -> tensor<64x1024xbf16> attributes {llvm.emit_c_interface, runtime_const_args_index = [1 : i32, 2 : i32, 3 : i32, 4 : i32]} +// COM: func.func @fold(%arg0: tensor<512x256xbf16>, %arg1: tensor<256xbf16>, %arg2: tensor<256x1024xbf16>, %arg3: tensor<1024xbf16>) -> (tensor<8x16x16x32x2xbf16>, tensor<8x32xbf16>, tensor<32x8x16x32x2xbf16>, tensor<32x32xbf16>) attributes {llvm.emit_c_interface} diff --git a/test/mlir/unittests/ExecutionEngine/JitWrapper.cpp b/test/mlir/unittests/ExecutionEngine/JitWrapper.cpp index f7b93eaa6..032c9e0d7 100644 --- a/test/mlir/unittests/ExecutionEngine/JitWrapper.cpp +++ b/test/mlir/unittests/ExecutionEngine/JitWrapper.cpp @@ -25,8 +25,7 @@ using namespace mlir; static const char code1[] = R"mlir( module { -llvm.mlir.global constant @__num_orig_num_args(3 : i32) : i32 -func.func @compute(%a: tensor<128xf32>, %b: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface } { +func.func @entry(%a: tensor<128xf32>, %b: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface } { %out = tensor.empty() : tensor<128xf32> %2 = linalg.add ins(%a, %b : tensor<128xf32>,tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32> return %2 : tensor<128xf32> @@ -68,3 +67,94 @@ TEST(ExecutionEngine, JitWrapper) { ASSERT_EQ(bufC[{i}], 1.0f + i); } } + +// compute d = (a+a) + (b+b) + c, where a,b is marked 
constant
+// bufIdx: a=0, b=1, c=2, d=3, foldedA=4, foldedB=5
+static const char code2[] = R"mlir(
+module {
+func.func @entry(%a: tensor<128xf32>, %b: tensor<128xf32>, %c: tensor<128xf32>) -> tensor<128xf32> attributes { llvm.emit_c_interface, runtime_const_args_index = [0 : i32, 1 : i32] } {
+  %out = tensor.empty() : tensor<128xf32>
+  %ax2 = linalg.add ins(%a, %a : tensor<128xf32>, tensor<128xf32>) outs(%out : tensor<128xf32>) -> tensor<128xf32>
+  %out2 = tensor.empty() : tensor<128xf32>
+  %bx2 = linalg.add ins(%b, %b : tensor<128xf32>, tensor<128xf32>) outs(%out2 : tensor<128xf32>) -> tensor<128xf32>
+  %out3 = tensor.empty() : tensor<128xf32>
+  %ax2pbx2 = linalg.add ins(%ax2, %bx2 : tensor<128xf32>, tensor<128xf32>) outs(%out3 : tensor<128xf32>) -> tensor<128xf32>
+  %out4 = tensor.empty() : tensor<128xf32>
+  %d = linalg.add ins(%ax2pbx2, %c : tensor<128xf32>, tensor<128xf32>) outs(%out4 : tensor<128xf32>) -> tensor<128xf32>
+  return %d : tensor<128xf32>
+}
+}
+)mlir";
+
+TEST(ExecutionEngine, JitWrapperCached) {
+  MLIRContext ctx{gc::initCompilerAndGetDialects()};
+  std::unique_ptr<llvm::MemoryBuffer> ir_buffer =
+      llvm::MemoryBuffer::getMemBuffer(code2);
+  // Parse the input mlir.
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(ir_buffer), llvm::SMLoc());
+  mlir::OwningOpRef<ModuleOp> module =
+      mlir::parseSourceFile<ModuleOp>(sourceMgr, &ctx);
+
+  ASSERT_TRUE(module);
+  auto jited = gc::JitModule::create(module.get());
+  bool jit_success = static_cast<bool>(jited);
+  if (!jit_success) {
+    auto err = jited.takeError();
+    llvm::errs() << err;
+    llvm::consumeError(std::move(err));
+  }
+  ASSERT_TRUE(jit_success);
+
+  auto cacheManager = gc::ConstGraphTensorCacheManager::get();
+  auto ret = std::shared_ptr<float[]>(new float[128]);
+  auto proxy = std::make_shared<ConstCacheProxy>(ret, ret.get(),
+                                                 128 * sizeof(float), true);
+  // Cannot register with an already existing key.
+  ASSERT_FALSE(cacheManager->regCachedTensor(0, proxy, 0));
+
+  proxy = cacheManager->queryCacheTensor(0)->base;
+  auto data = (float *)proxy->getBufferUnsafe();
+
+  OwningMemRef<float, 1> bufA{
+      {128}, {128}, [](float &ptr, ArrayRef<int64_t>) { ptr = 1.0f; }};
+  OwningMemRef<float, 1> bufB{
+      {128}, {128}, [](float &ptr, ArrayRef<int64_t> idx) { ptr = idx[0]; }};
+  OwningMemRef<float, 1> bufC{
+      {128}, {128}, [](float &ptr, ArrayRef<int64_t> idx) {
+        ptr = -idx[0] * 3;
+      }};
+  OwningMemRef<float, 1> bufD{
+      {128}, {128}, [](float &ptr, ArrayRef<int64_t>) { ptr = 100.0f; }};
+  void *args[] = {&*bufA, &*bufB, &*bufC, &*bufD};
+
+  {
+    // First call: should run fold() and populate the cached buffer.
+    jited.get()->call(args, 4);
+
+    for (int i = 0; i < 128; i++) {
+      ASSERT_EQ(*(data + i), 2 * 1.0f + 2 * i);
+    }
+    for (int i = 0; i < 128; i++) {
+      ASSERT_EQ(bufD[{i}], 2 * 1.0f + 2 * i - 3 * i);
+    }
+  }
+
+  {
+    // Second call: the cache is hot, so fold() should not run.
+    jited.get()->call(args, 4);
+    for (int i = 0; i < 128; i++) {
+      ASSERT_EQ(bufD[{i}], 2 * 1.0f + 2 * i - 3 * i);
+    }
+  }
+
+  // Evict the cache entry.
+  proxy->deref();
+  {
+    // Third call: the cache was evicted, so fold() should run again.
+    jited.get()->call(args, 4);
+    for (int i = 0; i < 128; i++) {
+      ASSERT_EQ(bufD[{i}], 2 * 1.0f + 2 * i - 3 * i);
+    }
+  }
+}