Skip to content

Commit

Permalink
[arcilator] Introduce integrated JIT for simulation execution (#6783)
Browse files Browse the repository at this point in the history
This PR adds a JIT runtime for arcilator, backed by MLIR's ExecutionEngine. This JIT allows executing `arc.sim` operations directly from the arcilator binary.
  • Loading branch information
Moxinilian authored Mar 18, 2024
1 parent 4b075de commit 83a8292
Show file tree
Hide file tree
Showing 13 changed files with 393 additions and 122 deletions.
16 changes: 16 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,12 @@ else()
include_directories(SYSTEM ${MLIR_INCLUDE_DIR})
include_directories(SYSTEM ${MLIR_TABLEGEN_OUTPUT_DIR})

# If building as part of a unified build, whether or not MLIR's execution engine
# is enabled must be fetched from its subdirectory scope.
# The call below reads the definition of MLIR_ENABLE_EXECUTION_ENGINE from the
# directory scope of ${MLIR_MAIN_SRC_DIR} and stores it in a variable of the
# same name in the current scope, so later `if(MLIR_ENABLE_EXECUTION_ENGINE)`
# checks in this file see MLIR's setting.
get_directory_property(MLIR_ENABLE_EXECUTION_ENGINE
DIRECTORY ${MLIR_MAIN_SRC_DIR}
DEFINITION MLIR_ENABLE_EXECUTION_ENGINE)

set(BACKEND_PACKAGE_STRING "${PACKAGE_STRING}")

set(CIRCT_GTEST_AVAILABLE 1)
Expand Down Expand Up @@ -587,6 +593,16 @@ if(CIRCT_SLANG_FRONTEND_ENABLED)
endif()
endif()

#-------------------------------------------------------------------------------
# Arcilator JIT
#-------------------------------------------------------------------------------

# The arcilator JIT is available exactly when MLIR's execution engine is.
# Normalize that to an explicit 0/1 value: default to disabled, then flip to
# enabled if the execution engine is on.
set(ARCILATOR_JIT_ENABLED 0)
if(MLIR_ENABLE_EXECUTION_ENGINE)
  set(ARCILATOR_JIT_ENABLED 1)
endif()

#-------------------------------------------------------------------------------
# Directory setup
#-------------------------------------------------------------------------------
Expand Down
26 changes: 26 additions & 0 deletions integration_test/arcilator/JIT/basic.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// RUN: arcilator %s --run --jit-entry=main | FileCheck %s
// REQUIRES: arcilator-jit

// CHECK: output = 5

hw.module @adder(in %a: i8, in %b: i8, out c: i8) {
%res = comb.add %a, %b : i8
hw.output %res : i8
}

func.func @main() {
%two = arith.constant 2 : i8
%three = arith.constant 3 : i8

arc.sim.instantiate @adder as %model {
arc.sim.set_input %model, "a" = %two : i8, !arc.sim.instance<@adder>
arc.sim.set_input %model, "b" = %three : i8, !arc.sim.instance<@adder>

arc.sim.step %model : !arc.sim.instance<@adder>

%res = arc.sim.get_port %model, "c" : i8, !arc.sim.instance<@adder>
arc.sim.emit "output", %res : i8
}

return
}
50 changes: 50 additions & 0 deletions integration_test/arcilator/JIT/counter.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// RUN: arcilator %s --run --jit-entry=main | FileCheck %s
// REQUIRES: arcilator-jit

// CHECK: counter_value = 0
// CHECK-NEXT: counter_value = 1
// CHECK-NEXT: counter_value = 2
// CHECK-NEXT: counter_value = 3
// CHECK-NEXT: counter_value = 4
// CHECK-NEXT: counter_value = 5
// CHECK-NEXT: counter_value = 6
// CHECK-NEXT: counter_value = 7
// CHECK-NEXT: counter_value = 8
// CHECK-NEXT: counter_value = 9
// CHECK-NEXT: counter_value = a

hw.module @counter(in %clk: i1, out o: i8) {
%seq_clk = seq.to_clock %clk

%reg = seq.compreg %added, %seq_clk : i8

%one = hw.constant 1 : i8
%added = comb.add %reg, %one : i8

hw.output %reg : i8
}

func.func @main() {
%zero = arith.constant 0 : i1
%one = arith.constant 1 : i1
%lb = arith.constant 0 : index
%ub = arith.constant 10 : index
%step = arith.constant 1 : index

arc.sim.instantiate @counter as %model {
%init_val = arc.sim.get_port %model, "o" : i8, !arc.sim.instance<@counter>
arc.sim.emit "counter_value", %init_val : i8

scf.for %i = %lb to %ub step %step {
arc.sim.set_input %model, "clk" = %one : i1, !arc.sim.instance<@counter>
arc.sim.step %model : !arc.sim.instance<@counter>
arc.sim.set_input %model, "clk" = %zero : i1, !arc.sim.instance<@counter>
arc.sim.step %model : !arc.sim.instance<@counter>

%counter_val = arc.sim.get_port %model, "o" : i8, !arc.sim.instance<@counter>
arc.sim.emit "counter_value", %counter_val : i8
}
}

return
}
8 changes: 8 additions & 0 deletions integration_test/arcilator/JIT/err-not-found.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// RUN: ! (arcilator %s --run --jit-entry=unknown 2> %t) && FileCheck --input-file=%t %s
// REQUIRES: arcilator-jit

// CHECK: entry point not found: 'unknown'

func.func @main() {
return
}
6 changes: 6 additions & 0 deletions integration_test/arcilator/JIT/err-not-func.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// RUN: ! (arcilator %s --run --jit-entry=foo 2> %t) && FileCheck --input-file=%t %s
// REQUIRES: arcilator-jit

// CHECK: entry point 'foo' was found but on an operation of type 'llvm.mlir.global' while an LLVM function was expected

llvm.mlir.global @foo(0 : i32) : i32
8 changes: 8 additions & 0 deletions integration_test/arcilator/JIT/err-wrong-func.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// RUN: ! (arcilator %s --run --jit-entry=main 2> %t) && FileCheck --input-file=%t %s
// REQUIRES: arcilator-jit

// CHECK: entry point 'main' must have no arguments

func.func @main(%a: i32) {
return
}
10 changes: 10 additions & 0 deletions integration_test/arcilator/JIT/print.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: arcilator %s --run | FileCheck %s
// REQUIRES: arcilator-jit

// CHECK: result = 4

func.func @entry() {
%four = arith.constant 4 : i32
arc.sim.emit "result", %four : i32
return
}
6 changes: 5 additions & 1 deletion integration_test/lit.cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
config.llvm_tools_dir
]
tools = [
'circt-opt', 'circt-translate', 'firtool', 'circt-rtl-sim.py',
'arcilator', 'circt-opt', 'circt-translate', 'firtool', 'circt-rtl-sim.py',
'equiv-rtl.sh', 'handshake-runner', 'hlstool', 'ibistool'
]

Expand Down Expand Up @@ -206,6 +206,10 @@
config.available_features.add('slang')
tools.append('circt-verilog')

# Add arcilator JIT if MLIR's execution engine is enabled.
if config.arcilator_jit_enabled:
config.available_features.add('arcilator-jit')

config.substitutions.append(('%driver', f'{config.driver}'))
llvm_config.add_tool_substitutions(tools, tool_dirs)

Expand Down
1 change: 1 addition & 0 deletions integration_test/lit.site.cfg.py.in
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ config.bindings_python_enabled = @CIRCT_BINDINGS_PYTHON_ENABLED@
config.bindings_tcl_enabled = @CIRCT_BINDINGS_TCL_ENABLED@
config.lec_enabled = "@CIRCT_LEC_ENABLED@"
config.slang_frontend_enabled = "@CIRCT_SLANG_FRONTEND_ENABLED@"
config.arcilator_jit_enabled = @ARCILATOR_JIT_ENABLED@
config.driver = "@CIRCT_SOURCE_DIR@/tools/circt-rtl-sim/driver.cpp"

# Support substitution of the tools_dir with user parameters. This is
Expand Down
37 changes: 31 additions & 6 deletions lib/Conversion/ArcToLLVM/LowerArcToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,8 @@ struct SimInstantiateOpLowering

ConversionPatternRewriter::InsertionGuard guard(rewriter);

// FIXME: like the rest of MLIR, this assumes sizeof(intptr_t) ==
// sizeof(size_t) on the target architecture.
Type convertedIndex = typeConverter->convertType(rewriter.getIndexType());

LLVM::LLVMFuncOp mallocFunc =
Expand Down Expand Up @@ -460,8 +462,9 @@ struct SimStepOpLowering : public ModelAwarePattern<arc::SimStepOp> {
}
};

/// Lowers SimEmitValueOp to a printf call. This pattern will mutate the global
/// module.
/// Lowers SimEmitValueOp to a printf call. The integer will be printed in its
/// entirety if it is of size up to size_t, and explicitly truncated otherwise.
/// This pattern will mutate the global module.
struct SimEmitValueOpLowering
: public OpConversionPattern<arc::SimEmitValueOp> {
using OpConversionPattern::OpConversionPattern;
Expand All @@ -475,27 +478,49 @@ struct SimEmitValueOpLowering

Location loc = op.getLoc();

Value toPrint = rewriter.create<LLVM::IntToPtrOp>(
loc, LLVM::LLVMPointerType::get(getContext()), adaptor.getValue());

ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
if (!moduleOp)
return failure();

// Cast the value to a size_t.
// FIXME: like the rest of MLIR, this assumes sizeof(intptr_t) ==
// sizeof(size_t) on the target architecture.
Value toPrint = adaptor.getValue();
DataLayout layout = DataLayout::closest(op);
llvm::TypeSize sizeOfSizeT =
layout.getTypeSizeInBits(rewriter.getIndexType());
assert(!sizeOfSizeT.isScalable() &&
sizeOfSizeT.getFixedValue() <= std::numeric_limits<unsigned>::max());
bool truncated = false;
if (valueType.getWidth() > sizeOfSizeT) {
toPrint = rewriter.create<LLVM::TruncOp>(
loc, IntegerType::get(getContext(), sizeOfSizeT.getFixedValue()),
toPrint);
truncated = true;
} else if (valueType.getWidth() < sizeOfSizeT)
toPrint = rewriter.create<LLVM::ZExtOp>(
loc, IntegerType::get(getContext(), sizeOfSizeT.getFixedValue()),
toPrint);

// Look up or create the printf function symbol.
auto printfFunc = LLVM::lookupOrCreateFn(
moduleOp, "printf", LLVM::LLVMPointerType::get(getContext()),
LLVM::LLVMVoidType::get(getContext()), true);

// Insert the format string if not already available.
SmallString<16> formatStrName{"_arc_sim_emit_"};
formatStrName.append(truncated ? "trunc_" : "full_");
formatStrName.append(adaptor.getValueName());
LLVM::GlobalOp formatStrGlobal;
if (!(formatStrGlobal =
moduleOp.lookupSymbol<LLVM::GlobalOp>(formatStrName))) {
ConversionPatternRewriter::InsertionGuard insertGuard(rewriter);

SmallString<16> formatStr = adaptor.getValueName();
formatStr.append(" = %0.8p\n");
formatStr.append(" = ");
if (truncated)
formatStr.append("(truncated) ");
formatStr.append("%zx\n");
SmallVector<char> formatStrVec{formatStr.begin(), formatStr.end()};
formatStrVec.push_back(0);

Expand Down
67 changes: 42 additions & 25 deletions test/Dialect/Arc/lower-sim.mlir
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
// RUN: arcilator %s --emit-mlir | FileCheck %s

hw.module @id(in %i: i8, in %j: i8, out o: i8) {
module attributes { dlti.dl_spec = #dlti.dl_spec<
#dlti.dl_entry<index, 16>
> } {
hw.module @id(in %i: i8, in %j: i8, out o: i8) {
hw.output %i : i8
}
}

// CHECK-DAG: llvm.mlir.global internal constant @[[format_str:.*]]("result = %zx\0A\00")
// CHECK-DAG: llvm.mlir.global internal constant @[[format_str2:.*]]("result2 = %zx\0A\00")
// CHECK-DAG: llvm.mlir.global internal constant @[[format_str_trunc:.*]]("result = (truncated) %zx\0A\00")

// CHECK-DAG: llvm.mlir.global internal constant @[[format_str:.*]]("result = %0.8p\0A\00")
// CHECK-DAG: llvm.mlir.global internal constant @[[format_str2:.*]]("result2 = %0.8p\0A\00")
// CHECK-LABEL: llvm.func @full
func.func @full() {
// CHECK-LABEL: llvm.func @full
func.func @full() {
%c = arith.constant 24 : i8

// CHECK-DAG: %[[c:.*]] = llvm.mlir.constant(24 : i8)
Expand All @@ -16,33 +21,45 @@ func.func @full() {
// CHECK-DAG: %[[state:.*]] = llvm.call @malloc(%[[size:.*]]) :
// CHECK: "llvm.intr.memset"(%[[state]], %[[zero]], %[[size]]) <{isVolatile = false}>
arc.sim.instantiate @id as %model {
// CHECK-NEXT: llvm.store %[[c]], %[[state]] : i8
arc.sim.set_input %model, "i" = %c : i8, !arc.sim.instance<@id>
// CHECK-NEXT: llvm.store %[[c]], %[[state]] : i8
arc.sim.set_input %model, "i" = %c : i8, !arc.sim.instance<@id>

// CHECK-NEXT: %[[j_ptr:.*]] = llvm.getelementptr %[[state]][1] : (!llvm.ptr) -> !llvm.ptr, i8
// CHECK-NEXT: llvm.store %[[c]], %[[j_ptr]] : i8
arc.sim.set_input %model, "j" = %c : i8, !arc.sim.instance<@id>
// CHECK-NEXT: %[[j_ptr:.*]] = llvm.getelementptr %[[state]][1] : (!llvm.ptr) -> !llvm.ptr, i8
// CHECK-NEXT: llvm.store %[[c]], %[[j_ptr]] : i8
arc.sim.set_input %model, "j" = %c : i8, !arc.sim.instance<@id>

// CHECK-NEXT: llvm.call @id_eval(%[[state]])
arc.sim.step %model : !arc.sim.instance<@id>
// CHECK-NEXT: llvm.call @id_eval(%[[state]])
arc.sim.step %model : !arc.sim.instance<@id>

// CHECK-NEXT: %[[o_ptr:.*]] = llvm.getelementptr %[[state]][2] : (!llvm.ptr) -> !llvm.ptr, i8
// CHECK-NEXT: %[[result:.*]] = llvm.load %[[o_ptr]] : !llvm.ptr -> i8
%result = arc.sim.get_port %model, "o" : i8, !arc.sim.instance<@id>
// CHECK-NEXT: %[[o_ptr:.*]] = llvm.getelementptr %[[state]][2] : (!llvm.ptr) -> !llvm.ptr, i8
// CHECK-NEXT: %[[result:.*]] = llvm.load %[[o_ptr]] : !llvm.ptr -> i8
%result = arc.sim.get_port %model, "o" : i8, !arc.sim.instance<@id>

// CHECK-DAG: %[[to_print:.*]] = llvm.inttoptr %[[result]] : i8 to !llvm.ptr
// CHECK-DAG: %[[format_str_ptr:.*]] = llvm.mlir.addressof @[[format_str]] : !llvm.ptr
// CHECK: llvm.call @printf(%[[format_str_ptr]], %[[to_print]])
arc.sim.emit "result", %result : i8
// CHECK-DAG: %[[to_print:.*]] = llvm.zext %[[result]] : i8 to i16
// CHECK-DAG: %[[format_str_ptr:.*]] = llvm.mlir.addressof @[[format_str]] : !llvm.ptr
// CHECK: llvm.call @printf(%[[format_str_ptr]], %[[to_print]])
arc.sim.emit "result", %result : i8

// CHECK-DAG: %[[format_str2_ptr:.*]] = llvm.mlir.addressof @[[format_str2]] : !llvm.ptr
// CHECK: llvm.call @printf(%[[format_str2_ptr]], %[[to_print]])
arc.sim.emit "result2", %result : i8
// CHECK-DAG: %[[format_str2_ptr:.*]] = llvm.mlir.addressof @[[format_str2]] : !llvm.ptr
// CHECK: llvm.call @printf(%[[format_str2_ptr]], %[[to_print]])
arc.sim.emit "result2", %result : i8

// CHECK: llvm.call @printf(%[[format_str_ptr]], %[[to_print]])
arc.sim.emit "result", %result : i8
// CHECK: llvm.call @printf(%[[format_str_ptr]], %[[to_print]])
arc.sim.emit "result", %result : i8
}
// CHECK: llvm.call @free(%[[state]])

return
}

// CHECK-LABEL: llvm.func @trunc
func.func @trunc() {
%v = arith.constant 0 : i32
// CHECK-DAG: %[[val_i32:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-DAG: %[[val_truncated:.*]] = llvm.trunc %[[val_i32]] : i32 to i16
// CHECK-DAG: %[[format_str_trunc_ptr:.*]] = llvm.mlir.addressof @[[format_str_trunc]] : !llvm.ptr
// CHECK-DAG: llvm.call @printf(%[[format_str_trunc_ptr]], %[[val_truncated]])
arc.sim.emit "result", %v : i32
return
}
}
11 changes: 10 additions & 1 deletion tools/arcilator/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
set(LLVM_LINK_COMPONENTS Support)
# Optional JIT support: only wired in when ARCILATOR_JIT_ENABLED is true.
if(ARCILATOR_JIT_ENABLED)
# Directory-scoped definition; sources compiled from this directory see
# ARCILATOR_ENABLE_JIT.
add_compile_definitions(ARCILATOR_ENABLE_JIT)
# Extra LLVM component, appended to LLVM_LINK_COMPONENTS below.
set(ARCILATOR_JIT_LLVM_COMPONENTS native)
# Extra library, appended to arcilator's target_link_libraries list below.
set(ARCILATOR_JIT_DEPS MLIRExecutionEngine)
endif()

set(LLVM_LINK_COMPONENTS Support ${ARCILATOR_JIT_LLVM_COMPONENTS})

add_circt_tool(arcilator arcilator.cpp)
target_link_libraries(arcilator
Expand All @@ -16,11 +22,14 @@ target_link_libraries(arcilator
CIRCTSupport
CIRCTTransforms
MLIRBuiltinToLLVMIRTranslation
MLIRDLTIDialect
MLIRFuncInlinerExtension
MLIRLLVMIRTransforms
MLIRLLVMToLLVMIRTranslation
MLIRParser
MLIRTargetLLVMIRExport

${ARCILATOR_JIT_DEPS}
)

llvm_update_compile_flags(arcilator)
Expand Down
Loading

7 comments on commit 83a8292

@Moxinilian
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am once again completely confused by the result of the Windows build...

@fzi-hielscher
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no explanation for this either but after a little shuffling of includes it appears to be some absurd interaction between #include "mlir/ExecutionEngine/ExecutionEngine.h" and #include "circt/InitAllPasses.h". Untangling this is a bit of a nightmare, so I can just speculate in the general direction of namespace pollution. Since we are not actually calling registerAllPasses() we can probably sidestep this by only including the passes we really need.

@Moxinilian
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, that's interesting, thanks for the help. Right now I am trying to not build the JIT infra on Windows, but I should probably try what you suggested.

@fzi-hielscher
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Moxinilian
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have dropped InitAllPasses from Arcilator for now, I have a CI build test running to see if it fixes the issue.

@Moxinilian
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately dropping it did not resolve the issue. What is your suggestion exactly @fzi-hielscher?

@fzi-hielscher
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would hope that this solves the problem: #6844

Please sign in to comment.