Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Experimental][Transform] Split Compute Intensive Op #154

Open
wants to merge 12 commits into
base: xurui/add_benchmark
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -104,5 +104,4 @@ target_compile_options(graph_compiler PRIVATE -fvisibility=hidden)
target_link_options(graph_compiler PRIVATE -Wl,--gc-sections)
target_link_libraries(graph_compiler PRIVATE ${GC_LIB_LINKED_LIBS})

add_subdirectory(unittests)
add_subdirectory(test)
2 changes: 1 addition & 1 deletion cmake/llvm-version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
763b96c86d81d51d0db430791a61fd1e8a406bce
891ec2af45c02718c65f539cb6dad1758f079e73
69 changes: 69 additions & 0 deletions include/gc/ExecutionEngine/Driver/Driver.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
//===-- Driver.h - The top-level MLIR compiler driver -----------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef GC_EXECUTIONENGINE_DRIVER_DRIVER_H
#define GC_EXECUTIONENGINE_DRIVER_DRIVER_H

#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include <memory>
#include <string_view>

namespace mlir {
class DialectRegistry;
namespace gc {

/// Returns the process-wide dialect registry used by the compiler driver.
/// The registry is built on first use (see the definition in Driver.cpp) and
/// the same instance is returned by reference on every call.
const DialectRegistry &initCompilerAndGetDialects();

// Type-erased pointer to a memref descriptor object (a XXXMemRefType).
using GeneralMemrefPtr = void *;
// Signature of the JIT-compiled entry point: it receives a single array of
// void* and returns nothing.
using JitModuleFuncT = void (*)(void **);

/// Knobs that control how JitModule::create() compiles a module.
struct DriverOptions {
  /// Optimization level handed to the LLVM JIT code generator.
  llvm::CodeGenOptLevel jitCodeGenOptLevel{llvm::CodeGenOptLevel::Aggressive};
  /// When true, the MLIR transformation passes are run before JIT
  /// compilation; when false the input is compiled as given.
  bool runTransforms{true};
  /// todo: target machine, etc.
};

class JitModule {
public:
static llvm::Expected<std::shared_ptr<JitModule>>
create(Operation *op, const DriverOptions &options = {});

/// args should be an array of XXXMemrefType*
void call(GeneralMemrefPtr *args, std::size_t numArgs) {
// Silly code, MLIR execution engine requires pointers of real args as
// inputs
llvm::SmallVector<void *, 32> realargs;
realargs.reserve(numArgs);
for (size_t i = 0; i < numArgs; i++) {
realargs.push_back(&args[i]);
}
compute(realargs.data());
}

/// directly call compute(). args should be an array of void*. args[i] should
/// be a pointer to the real data. For passing memref, users need to 1) create
/// a pointer to XXXMemrefType 2) store the pointer to pointer to
/// XXXMemrefType in args[i]
void callRaw(void **args) { compute(args); }

JitModule(std::unique_ptr<ExecutionEngine> engine, JitModuleFuncT compute);
~JitModule();

private:
std::unique_ptr<ExecutionEngine> engine;
JitModuleFuncT compute;
};

} // namespace gc
} // namespace mlir

#endif
12 changes: 12 additions & 0 deletions include/gc/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,16 @@ def GCCPUPipeline: Pass<"gc-cpu-pipeline"> {
"vector::VectorDialect"];
}

// Declares the pass registered under the command-line name
// "split-compute-intensive-patterns". Per its own description it splits the
// weights of matmul patterns into as many parts as the target machine has
// NUMA nodes.
// NOTE(review): no anchor op is given to Pass<>, and the implementation
// (SplitComputeIntensivePatterns.cpp) is not visible here -- confirm the op
// kind the pass is meant to run on.
def SplitComputeIntensivePatterns : Pass<"split-compute-intensive-patterns"> {
let summary = "Split matmul patterns";
let description = [{
Split matmul patterns' weights into several parts, number of which aligns
with the number of target machine's numa node.
}];
// Dialects declared as dependencies of this pass (linalg, tensor, arith);
// presumably the dialects whose ops the pass may create -- TODO confirm.
let dependentDialects = [
"mlir::linalg::LinalgDialect",
"mlir::tensor::TensorDialect",
"mlir::arith::ArithDialect"];
}

#endif // GC_DIALECT_GC_PASSES
4 changes: 3 additions & 1 deletion lib/gc/Dialect/CPURuntime/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRFuncDialect)

add_mlir_dialect_library(MLIRCPURuntimeDialect
CPURuntimeDialect.cpp
CPURuntimeOps.cpp
Expand All @@ -10,5 +12,5 @@ add_mlir_dialect_library(MLIRCPURuntimeDialect
MLIRCPURuntimePassesIncGen

LINK_LIBS PUBLIC
MLIRFuncDialect
${MLIR_LINK_COMPONENTS}
)
4 changes: 3 additions & 1 deletion lib/gc/Dialect/CPURuntime/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRFuncDialect)

add_mlir_dialect_library(MLIRCPURuntimeTransforms
CPURuntimeToLLVM.cpp

Expand All @@ -8,7 +10,7 @@ add_mlir_dialect_library(MLIRCPURuntimeTransforms
MLIRCPURuntimePassesIncGen

LINK_LIBS PUBLIC
MLIRFuncDialect
${MLIR_LINK_COMPONENTS}
MLIRCPURuntimeDialect
)

Expand Down
1 change: 1 addition & 0 deletions lib/gc/ExecutionEngine/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
add_subdirectory(CPURuntime)
add_subdirectory(Driver)
41 changes: 41 additions & 0 deletions lib/gc/ExecutionEngine/Driver/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Select the LLVM/MLIR libraries that the JIT driver library links against.
# Two configurations are supported, switched by GC_DEV_LINK_LLVM_DYLIB.
if(GC_DEV_LINK_LLVM_DYLIB)
# Dylib configuration: link the single "LLVM" and "MLIR" libraries plus the
# shared execution engine, and take the dialect / pass library lists from the
# project's own GC_* global properties.
# NOTE(review): presumably LLVM/MLIR here are the monolithic shared
# libraries -- confirm against the top-level build options.
set(LLVM_LINK_COMPONENTS
LLVM
)
get_property(dialect_libs GLOBAL PROPERTY GC_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY GC_PASS_LIBS)
set(MLIR_LINK_COMPONENTS
MLIR
MLIRExecutionEngineShared
)
else()
# Default configuration: list the individual LLVM components and MLIR
# libraries, and take the dialect / conversion library lists from the
# MLIR_* global properties.
set(LLVM_LINK_COMPONENTS
Core
Support
nativecodegen
native
)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
set(MLIR_LINK_COMPONENTS
MLIRBuiltinToLLVMIRTranslation
MLIRExecutionEngine
MLIRLLVMDialect
MLIRLLVMToLLVMIRTranslation
MLIRToLLVMIRTranslationRegistration
)
endif()

# GCJitWrapper: the library built from Driver.cpp. It publicly links the MLIR
# components and dialect/conversion libraries chosen above, plus GCPasses.
add_mlir_library(GCJitWrapper
Driver.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include

LINK_LIBS PUBLIC
${MLIR_LINK_COMPONENTS}
${dialect_libs}
${conversion_libs}
GCPasses
)

82 changes: 82 additions & 0 deletions lib/gc/ExecutionEngine/Driver/Driver.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
//===-- Driver.cpp - Top-level MLIR compiler driver -------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "gc/ExecutionEngine/Driver/Driver.h"
#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h"
#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
#include "gc/Transforms/Passes.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR/Dialect/All.h"
#include "string.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/TargetSelect.h"

namespace mlir {
namespace gc {

/// Performs the global pass / native-target registration and returns a
/// registry populated with every dialect and LLVM-IR translation the driver
/// uses.
static DialectRegistry initDialects() {
  // Pass registration: upstream MLIR, graph compiler, CPU runtime.
  mlir::registerAllPasses();
  mlir::gc::registerGraphCompilerPasses();
  mlir::cpuruntime::registerCPURuntimePasses();

  // Initialize the LLVM native target together with its asm printer and
  // asm parser.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  llvm::InitializeNativeTargetAsmParser();

  // Dialects, the CPURuntime-to-LLVM conversion interface, and the
  // translations to LLVM IR.
  mlir::DialectRegistry dialects;
  dialects.insert<mlir::cpuruntime::CPURuntimeDialect>();
  mlir::registerAllDialects(dialects);
  mlir::cpuruntime::registerConvertCPURuntimeToLLVMInterface(dialects);
  dialects.insert<mlir::onednn_graph::OneDNNGraphDialect>();
  mlir::registerAllToLLVMIRTranslations(dialects);
  return dialects;
}

/// Returns the shared dialect registry. The registry is a function-local
/// static, so initDialects() runs only the first time this is called.
const DialectRegistry &initCompilerAndGetDialects() {
  static DialectRegistry sharedRegistry = initDialects();
  return sharedRegistry;
}

/// Symbol name of the entry point that is looked up in the compiled module.
static const char defaultComputeName[] = "_mlir_ciface_compute";

/// Compiles `op` into a JitModule.
///
/// When `options.runTransforms` is set, the CPU pipeline is run on `op`
/// first. An execution engine is then created with the requested JIT
/// optimization level and the entry point named by `defaultComputeName` is
/// resolved.
///
/// \returns the new module, or an error if the pass pipeline fails, the
/// execution engine cannot be created, or the entry point is not found.
llvm::Expected<std::shared_ptr<JitModule>>
JitModule::create(Operation *op, const DriverOptions &options) {
  if (options.runTransforms) {
    mlir::PassManager passes{op->getContext()};
    populateCPUPipeline(passes);
    if (failed(passes.run(op)))
      return llvm::make_error<llvm::StringError>(
          "MLIR pass error", llvm::inconvertibleErrorCode());
  }

  ExecutionEngineOptions engineOptions;
  engineOptions.jitCodeGenOptLevel = options.jitCodeGenOptLevel;
  // No target machine is supplied: a null one is passed explicitly.
  auto maybeEngine =
      ExecutionEngine::create(op, engineOptions, /*tm=*/nullptr);
  if (!maybeEngine)
    return maybeEngine.takeError();

  auto maybeEntry = (*maybeEngine)->lookupPacked(defaultComputeName);
  if (!maybeEntry)
    return maybeEntry.takeError();

  // Hand ownership of the engine to the module along with the entry point.
  return std::make_shared<JitModule>(std::move(*maybeEngine), *maybeEntry);
}

/// Takes ownership of `engine` and stores `compute` as the entry point.
JitModule::JitModule(std::unique_ptr<ExecutionEngine> engine,
                     JitModuleFuncT compute)
    : engine(std::move(engine)), compute(compute) {}

/// Defaulted here, out of line, rather than in the header.
JitModule::~JitModule() = default;

} // namespace gc
} // namespace mlir
1 change: 1 addition & 0 deletions lib/gc/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_mlir_library(GCPasses
OneDNNGraphToLinalg.cpp
Pipeline.cpp
TileNamed.cpp
SplitComputeIntensivePatterns.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include
Expand Down
16 changes: 14 additions & 2 deletions lib/gc/Transforms/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
Expand All @@ -35,6 +38,7 @@ void populateFrontendPasses(mlir::PassManager &pm) {

// scf + arith + math + vector + tensor + linalg.brgemm + tensor.pack/unpack
void populateTensorPasses(mlir::PassManager &pm) {
// pm.addNestedPass<func::FuncOp>(createSplitComputeIntensivePatterns());
// todo: padding propagation pass
// todo: layout propagation pass
// todo: tensor constant propagation pass
Expand All @@ -49,8 +53,16 @@ void populateTensorPasses(mlir::PassManager &pm) {

// scf + arith + math + vector + tensor + linalg.brgemm
void populateVectorPasses(mlir::PassManager &pm) {
// todo: bf16 promotion pass, device dependent pass
// todo: bf16 cast elimilation pass, fast-math kind pass, designed to support
// Do promotion for math / arith ops
pm.addNestedPass<func::FuncOp>(math::createMathLegalizeToF32());
// sourceTypeStrs can be extended
arith::ArithEmulateUnsupportedFloatsOptions options;
options.sourceTypeStrs = {"bf16"};
options.targetTypeStr = "f32";
pm.addNestedPass<func::FuncOp>(
arith::createArithEmulateUnsupportedFloats(options));
// Bf16 cast elimination pass
pm.addNestedPass<func::FuncOp>(mlir::createCanonicalizerPass());
// oneDNN graph spec
pm.addNestedPass<func::FuncOp>(arith::createArithExpandOpsPass());
// todo: lower to physical vector pass, device dependent pass
Expand Down
Loading