diff --git a/.github/workflows/bench-hyperfine.yml b/.github/workflows/bench-hyperfine.yml index 02a123f3d..42f5efaff 100644 --- a/.github/workflows/bench-hyperfine.yml +++ b/.github/workflows/bench-hyperfine.yml @@ -58,15 +58,11 @@ jobs: - name: Install hyperfine uses: taiki-e/install-action@v2 with: - tool: hyperfine@1.16 + tool: hyperfine@1.18 - name: Install deps run: make deps - - name: Build project - run: make build - - name: Build runtime subproject - run: make runtime-ci - name: Run benchmarks - run: ./scripts/bench-hyperfine.sh programs/benches/*.cairo + run: make bench - name: Create markdown file run: bash .github/scripts/merge-benches.sh diff --git a/.github/workflows/starknet-blocks.yml b/.github/workflows/starknet-blocks.yml index 99dcf3160..99665cace 100644 --- a/.github/workflows/starknet-blocks.yml +++ b/.github/workflows/starknet-blocks.yml @@ -26,6 +26,8 @@ jobs: with: components: clippy - uses: Swatinem/rust-cache@v2 + with: + key: "ref-dc35685315f4df544d5d1cf006d3a2a25d8c2c9a" - name: Check and free hdd space left if: ${{ matrix.runner == 'native' }} @@ -62,7 +64,7 @@ jobs: run: make deps - name: Build Cairo Native project if: ${{ matrix.runner == 'native' }} - run: cargo b --release --all-features + run: cargo b --release - name: Build runtime if: ${{ matrix.runner == 'native' }} run: | @@ -74,7 +76,7 @@ jobs: uses: actions/checkout@v4 with: repository: lambdaclass/starknet-replay - ref: 08aa133b11e1e319036354c9acd8270483c844b5 + ref: dc35685315f4df544d5d1cf006d3a2a25d8c2c9a path: replay - name: Install Starknet Replay deps @@ -116,6 +118,22 @@ jobs: steps: - uses: actions/checkout@v4 + - name: Check and free hdd space left + if: ${{ matrix.runner == 'native' }} + run: | + sudo apt-get update + sudo apt-get remove -y '^llvm-.*' + sudo apt-get remove -y 'php.*' + sudo apt-get remove -y '^dotnet-.*' + sudo apt-get remove -y '^temurin-.*' + sudo apt-get remove -y azure-cli microsoft-edge-stable google-chrome-stable firefox mono-devel + sudo apt-get autoremove -y + sudo apt-get clean + echo "Removing large directories" + # deleting 15GB + sudo rm -rf /usr/share/dotnet/ + sudo rm -rf /usr/local/lib/android + - name: Fetch Native dumps uses: actions/download-artifact@v4 with: diff --git a/Cargo.lock b/Cargo.lock index 656944bdc..54aafac03 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -987,6 +987,7 @@ dependencies = [ "num-traits 0.2.19", "pretty_assertions_sorted", "proptest", + "rayon", "rstest", "scarb-metadata", "scarb-ui", diff --git a/Cargo.toml b/Cargo.toml index 86d28cac7..cd5988623 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -131,6 +131,7 @@ rstest = "0.23.0" test-case = "3.3" walkdir = "2.5.0" serde_json = { version = "1.0.128" } +rayon = "1.10.0" [build-dependencies] cc = "1.1.28" diff --git a/Makefile b/Makefile index 217091874..6c6c75318 100644 --- a/Makefile +++ b/Makefile @@ -97,6 +97,7 @@ doc-open: check-llvm .PHONY: bench bench: needs-cairo2 runtime cargo b --release --bin cairo-native-run + cargo b --release --bin cairo-native-compile ./scripts/bench-hyperfine.sh .PHONY: bench-ci diff --git a/programs/benches/factorial_2M.c b/programs/benches/factorial_2M.c index 7db0cfcc6..992f8648d 100644 --- a/programs/benches/factorial_2M.c +++ b/programs/benches/factorial_2M.c @@ -10,23 +10,28 @@ typedef struct factorial_return_values uint8_t discriminant; struct { void* ptr; - uint32_t len; + uint32_t start; + uint32_t end; uint32_t cap; } err; } result; } factorial_return_values_t; - static void run_bench(factorial_return_values_t*, uint64_t) __attribute__((weakref("_mlir_ciface_factorial_2M::factorial_2M::main(f1)"))); +extern uint64_t* cairo_native__set_costs_builtin(uint64_t*); int main() { factorial_return_values_t return_values; + uint64_t BuiltinCosts[7] = {1, 4050, 583, 4085, 491, 230, 604}; + + cairo_native__set_costs_builtin(&BuiltinCosts[0]); + run_bench(&return_values, 0); - assert(return_values.result.discriminant == 0); + assert((return_values.result.discriminant & 0x1) == 0); return 0; } diff --git a/programs/benches/fib_2M.c b/programs/benches/fib_2M.c index fecb87cb3..ffbec9785 100644 --- a/programs/benches/fib_2M.c +++ b/programs/benches/fib_2M.c @@ -10,23 +10,28 @@ typedef struct fib_return_values uint8_t discriminant; struct { void *ptr; - uint32_t len; + uint32_t start; + uint32_t end; uint32_t cap; } err; } result; } fib_return_values_t; - static void run_bench(fib_return_values_t *, uint64_t) __attribute__((weakref("_mlir_ciface_fib_2M::fib_2M::main(f1)"))); +extern uint64_t* cairo_native__set_costs_builtin(uint64_t*); int main() { + uint64_t BuiltinCosts[7] = {1, 4050, 583, 4085, 491, 230, 604}; + + cairo_native__set_costs_builtin(&BuiltinCosts[0]); + fib_return_values_t return_values; run_bench(&return_values, 0); - assert(return_values.result.discriminant == 0); + assert((return_values.result.discriminant & 0x1) == 0); return 0; } diff --git a/programs/benches/logistic_map.c b/programs/benches/logistic_map.c index 1294dcdbf..5d88dfb89 100644 --- a/programs/benches/logistic_map.c +++ b/programs/benches/logistic_map.c @@ -10,23 +10,28 @@ typedef struct map_return_values uint8_t discriminant; struct { void *ptr; - uint32_t len; + uint32_t start; + uint32_t end; uint32_t cap; } err; } result; } map_return_values_t; - static void run_bench(map_return_values_t *, uint64_t) __attribute__((weakref("_mlir_ciface_logistic_map::logistic_map::main(f2)"))); +extern uint64_t* cairo_native__set_costs_builtin(uint64_t*); int main() { + uint64_t BuiltinCosts[7] = {1, 4050, 583, 4085, 491, 230, 604}; + + cairo_native__set_costs_builtin(&BuiltinCosts[0]); + map_return_values_t return_values; run_bench(&return_values, 0); - assert(return_values.result.discriminant == 0); + assert((return_values.result.discriminant & 0x1) == 0); return 0; } diff --git a/runtime/src/lib.rs b/runtime/src/lib.rs index 5e2c75ae7..4ab674b7b 100644 --- a/runtime/src/lib.rs +++ b/runtime/src/lib.rs @@ -14,7 +14,8 @@ use starknet_types_core::{ hash::StarkHash, }; use std::{ - collections::HashMap, ffi::c_void, fs::File, io::Write, mem::ManuallyDrop, os::fd::FromRawFd, + cell::Cell, collections::HashMap, ffi::c_void, fs::File, io::Write, mem::ManuallyDrop, + os::fd::FromRawFd, ptr::null, }; use std::{ops::Mul, vec::IntoIter}; @@ -475,6 +476,32 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_try_finalize_nz( } } +thread_local! { + // We can use cell because a ptr is copy. + static BUILTIN_COSTS: Cell<*const u64> = const { + Cell::new(null()) + }; +} + +/// Store the gas builtin in the internal thread local. Returns the old pointer, to restore it after execution. +/// Not a runtime metadata method, it should be called before the program is executed. +#[no_mangle] +pub extern "C" fn cairo_native__set_costs_builtin(ptr: *const u64) -> *const u64 { + let old = BUILTIN_COSTS.get(); + BUILTIN_COSTS.set(ptr); + old +} + +/// Get the gas builtin from the internal thread local. +#[no_mangle] +pub extern "C" fn cairo_native__get_costs_builtin() -> *const u64 { + if BUILTIN_COSTS.get().is_null() { + // We shouldn't panic here, but we can print a big message. + eprintln!("BUILTIN_COSTS POINTER IS NULL!"); + } + BUILTIN_COSTS.get() +} + /// Utility methods for the print runtime function /// Formats the given felts as a debug string. diff --git a/scripts/bench-hyperfine.sh b/scripts/bench-hyperfine.sh index 180120c79..473c67265 100755 --- a/scripts/bench-hyperfine.sh +++ b/scripts/bench-hyperfine.sh @@ -1,7 +1,7 @@ #!/usr/bin/env bash # Configuration. -ROOT_DIR="$(dirname "$(dirname "${0%/*}")")" +ROOT_DIR="$(dirname "$(readlink -f "${0%/*}")")" MLIR_DIR="$MLIR_SYS_190_PREFIX" CAIRO_SRCS=$(find \ @@ -10,7 +10,7 @@ CAIRO_SRCS=$(find \ IFS=$'\n' read -rd '' -a CAIRO_SRCS <<<"$CAIRO_SRCS" CAIRO_RUN="$ROOT_DIR/cairo2/bin/cairo-run" -COMPILER_CLI="$ROOT_DIR/target/release/cairo-native-dump" +COMPILER_CLI="$ROOT_DIR/target/release/cairo-native-compile" JIT_CLI="$ROOT_DIR/target/release/cairo-native-run" OUTPUT_DIR="$ROOT_DIR/target/bench-outputs" @@ -42,51 +42,24 @@ run_bench() { base_name=$(basename $base_path) "$COMPILER_CLI" \ - "$base_path.cairo" \ - --output "$OUTPUT_DIR/$base_name.mlir" \ - >> /dev/stderr - - "$MLIR_DIR/bin/mlir-opt" \ - --canonicalize \ - --convert-scf-to-cf \ - --canonicalize \ - --cse \ - --expand-strided-metadata \ - --finalize-memref-to-llvm \ - --convert-func-to-llvm \ - --convert-index-to-llvm \ - --reconcile-unrealized-casts \ + -s "$base_path.cairo" \ "$OUTPUT_DIR/$base_name.mlir" \ - -o "$OUTPUT_DIR/$base_name.opt.mlir" \ - >> /dev/stderr - - "$MLIR_DIR/bin/mlir-translate" \ - --mlir-to-llvmir \ - "$OUTPUT_DIR/$base_name.opt.mlir" \ - -o "$OUTPUT_DIR/$base_name.ll" \ - >> /dev/stderr - - "$MLIR_DIR/bin/clang" \ - -O3 \ - -Wno-override-module \ - "$base_path.c" \ - "$OUTPUT_DIR/$base_name.ll" \ - -L "target/release" \ - -Wl,-rpath "$MLIR_DIR/lib" \ - -Wl,-rpath "target/release" \ - -o "$OUTPUT_DIR/$base_name" \ + "$OUTPUT_DIR/lib$base_name.so" \ >> /dev/stderr "$MLIR_DIR/bin/clang" \ -O3 \ -march=native \ -mtune=native \ + -fPIC \ -Wno-override-module \ "$base_path.c" \ - "$OUTPUT_DIR/$base_name.ll" \ - -L "target/release" \ + -L"$OUTPUT_DIR/" \ -Wl,-rpath "$MLIR_DIR/lib" \ - -Wl,-rpath "target/release" \ + -Wl,-rpath "$OUTPUT_DIR" \ + -Wl,--rpath-link "$OUTPUT_DIR" \ + -l"$base_name" \ + -lm \ -o "$OUTPUT_DIR/$base_name-march-native" \ >> /dev/stderr @@ -97,7 +70,6 @@ run_bench() { -n "Cairo-vm (Rust, Cairo 1)" "$CAIRO_RUN --available-gas 18446744073709551615 -s $base_path.cairo" \ -n "cairo-native (embedded AOT)" "$JIT_CLI --run-mode=aot -s $base_path.cairo --opt-level 3 --available-gas 18446744073709551615 " \ -n "cairo-native (embedded JIT using LLVM's ORC Engine)" "$JIT_CLI --run-mode=jit -s $base_path.cairo --opt-level 3 --available-gas 18446744073709551615 " \ - -n "cairo-native (standalone AOT)" "$OUTPUT_DIR/$base_name" \ -n "cairo-native (standalone AOT with -march=native)" "$OUTPUT_DIR/$base_name-march-native" \ >> /dev/stderr } diff --git a/scripts/diff-check.sh b/scripts/diff-check.sh index 5350934be..752394936 100755 --- a/scripts/diff-check.sh +++ b/scripts/diff-check.sh @@ -15,11 +15,9 @@ for vm_dump in state_dumps/vm/*/*.json; do continue fi - base=$(basename "$vm_dump") - if ! cmp -s \ - <(sed '/"reverted": /d' "$native_dump") \ - <(sed '/"reverted": /d' "$vm_dump") + <(sed '/"reverted": /d' "$native_dump" 2>/dev/null) \ + <(sed '/"reverted": /d' "$vm_dump" 2>/dev/null) then echo "NATIVE DIFFING IN TX: $native_dump" diffing=1 diff --git a/src/compiler.rs b/src/compiler.rs index f0efa38d4..e6d217516 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -446,7 +446,7 @@ fn compile_func( initial_state, |statement_idx, mut state| { if let Some(gas_metadata) = metadata.get::() { - let gas_cost = gas_metadata.get_gas_cost_for_statement(statement_idx); + let gas_cost = gas_metadata.get_gas_costs_for_statement(statement_idx); metadata.remove::(); metadata.insert(GasCost(gas_cost)); } diff --git a/src/error.rs b/src/error.rs index c58b2fa57..c093488c6 100644 --- a/src/error.rs +++ b/src/error.rs @@ -69,6 +69,12 @@ pub enum Error { #[error("integer conversion failed")] IntegerConversion, + #[error("missing BuiltinCosts global symbol, should never happen, this is a bug")] + MissingBuiltinCostsSymbol, + + #[error("selector not found in the AotContractExecutor mappings")] + SelectorNotFound, + #[error(transparent)] IoError(#[from] std::io::Error), diff --git a/src/executor.rs b/src/executor.rs index 10f8549c7..0553f620f 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -10,7 +10,7 @@ use crate::{ execution_result::{BuiltinStats, ExecutionResult}, starknet::{handler::StarknetSyscallHandlerCallbacks, StarknetSyscallHandler}, types::TypeBuilder, - utils::{libc_free, RangeExt}, + utils::{libc_free, BuiltinCosts, RangeExt}, values::Value, }; use bumpalo::Bump; @@ -69,6 +69,7 @@ extern "C" { fn invoke_dynamic( registry: &ProgramRegistry, function_ptr: *const c_void, + set_builtin_costs_fnptr: extern "C" fn(*const u64) -> *const u64, function_signature: &FunctionSignature, args: &[Value], gas: u128, @@ -141,6 +142,15 @@ fn invoke_dynamic( previous_syscall_handler }); + // Order matters, for the libfunc impl + let builtin_costs_stack: [u64; 7] = BuiltinCosts::default().into(); + // Note: the ptr into a slice is valid, it can be used with cast() + // Care should be taken if you dereference it and take the .as_ptr() of the slice, since when you + // deref it, it will be a copy on the stack, so you will get the ptr of the value in the stack. + let builtin_costs: *mut [u64; 7] = Box::into_raw(Box::new(builtin_costs_stack)); + // We may be inside a recursive contract, save the possible saved builtin costs to restore it after our call. + let old_builtincosts_ptr = set_builtin_costs_fnptr(builtin_costs.cast()); + // Generate argument list. let mut iter = args.iter(); for item in function_signature.param_types.iter().filter_map(|type_id| { @@ -166,6 +176,9 @@ fn invoke_dynamic( (syscall_handler as *mut StarknetSyscallHandlerCallbacks<_>) .to_bytes(&mut invoke_data)?; } + CoreTypeConcrete::BuiltinCosts(_) => { + builtin_costs.to_bytes(&mut invoke_data)?; + } type_info if type_info.is_builtin() => 0u64.to_bytes(&mut invoke_data)?, type_info => ValueWithInfoWrapper { value: iter.next().unwrap(), @@ -250,26 +263,38 @@ fn invoke_dynamic( } _ if type_info.is_builtin() => { if !type_info.is_zst(registry)? { - let value = match &mut return_ptr { - Some(return_ptr) => unsafe { *read_value::(return_ptr) }, - None => ret_registers[0], - } as usize; - - match type_info { - CoreTypeConcrete::Bitwise(_) => builtin_stats.bitwise = value, - CoreTypeConcrete::EcOp(_) => builtin_stats.ec_op = value, - CoreTypeConcrete::RangeCheck(_) => builtin_stats.range_check = value, - CoreTypeConcrete::Pedersen(_) => builtin_stats.pedersen = value, - CoreTypeConcrete::Poseidon(_) => builtin_stats.poseidon = value, - CoreTypeConcrete::SegmentArena(_) => builtin_stats.segment_arena = value, - CoreTypeConcrete::RangeCheck96(_) => builtin_stats.range_check_96 = value, - CoreTypeConcrete::Circuit(CircuitTypeConcrete::AddMod(_)) => { - builtin_stats.circuit_add = value + if let CoreTypeConcrete::BuiltinCosts(_) = type_info { + // todo: should we use this value? + let _value = match &mut return_ptr { + Some(return_ptr) => unsafe { *read_value::<*mut u64>(return_ptr) }, + None => ret_registers[0] as *mut u64, + }; + } else { + let value = match &mut return_ptr { + Some(return_ptr) => unsafe { *read_value::(return_ptr) }, + None => ret_registers[0], + } as usize; + + match type_info { + CoreTypeConcrete::Bitwise(_) => builtin_stats.bitwise = value, + CoreTypeConcrete::EcOp(_) => builtin_stats.ec_op = value, + CoreTypeConcrete::RangeCheck(_) => builtin_stats.range_check = value, + CoreTypeConcrete::Pedersen(_) => builtin_stats.pedersen = value, + CoreTypeConcrete::Poseidon(_) => builtin_stats.poseidon = value, + CoreTypeConcrete::SegmentArena(_) => { + builtin_stats.segment_arena = value + } + CoreTypeConcrete::RangeCheck96(_) => { + builtin_stats.range_check_96 = value + } + CoreTypeConcrete::Circuit(CircuitTypeConcrete::AddMod(_)) => { + builtin_stats.circuit_add = value + } + CoreTypeConcrete::Circuit(CircuitTypeConcrete::MulMod(_)) => { + builtin_stats.circuit_mul = value + } + _ => unreachable!("{type_id:?}"), } - CoreTypeConcrete::Circuit(CircuitTypeConcrete::MulMod(_)) => { - builtin_stats.circuit_mul = value - } - _ => unreachable!("{type_id:?}"), } } } @@ -299,6 +324,15 @@ fn invoke_dynamic( debug_name: None, }); + // Restore the old ptr and get back our builtincost box and free it. + let our_builtincosts_ptr = set_builtin_costs_fnptr(old_builtincosts_ptr); + + if !our_builtincosts_ptr.is_null() && old_builtincosts_ptr.is_aligned() { + unsafe { + let _ = Box::<[u64; 7]>::from_raw(our_builtincosts_ptr.cast_mut().cast()); + }; + } + #[cfg(feature = "with-mem-tracing")] crate::utils::mem_tracing::report_stats(); diff --git a/src/executor/aot.rs b/src/executor/aot.rs index cac2122e2..41b1565e2 100644 --- a/src/executor/aot.rs +++ b/src/executor/aot.rs @@ -55,13 +55,17 @@ impl AotNativeExecutor { mut metadata, } = module; - let library_path = NamedTempFile::new().unwrap().into_temp_path(); + let library_path = NamedTempFile::new() + .unwrap() + .into_temp_path() + .keep() + .unwrap(); let object_data = crate::module_to_object(&module, opt_level).unwrap(); crate::object_to_shared_lib(&object_data, &library_path).unwrap(); Self { - library: unsafe { Library::new(library_path).unwrap() }, + library: unsafe { Library::new(&library_path).unwrap() }, registry, gas_metadata: metadata.remove().unwrap(), } @@ -78,9 +82,21 @@ impl AotNativeExecutor { .get_initial_available_gas(function_id, gas) .map_err(crate::error::Error::GasMetadataError)?; + let set_costs_builtin: extern "C" fn(*const u64) -> *const u64 = unsafe { + std::mem::transmute( + self.library + .get:: *const u64>( + b"cairo_native__set_costs_builtin", + )? + .into_raw() + .into_raw(), + ) + }; + super::invoke_dynamic( &self.registry, self.find_function_ptr(function_id), + set_costs_builtin, self.extract_signature(function_id), args, available_gas, @@ -100,9 +116,21 @@ impl AotNativeExecutor { .get_initial_available_gas(function_id, gas) .map_err(crate::error::Error::GasMetadataError)?; + let set_costs_builtin: extern "C" fn(*const u64) -> *const u64 = unsafe { + std::mem::transmute( + self.library + .get:: *const u64>( + b"cairo_native__set_costs_builtin", + )? + .into_raw() + .into_raw(), + ) + }; + super::invoke_dynamic( &self.registry, self.find_function_ptr(function_id), + set_costs_builtin, self.extract_signature(function_id), args, available_gas, @@ -122,15 +150,26 @@ impl AotNativeExecutor { .get_initial_available_gas(function_id, gas) .map_err(crate::error::Error::GasMetadataError)?; + let set_costs_builtin: extern "C" fn(*const u64) -> *const u64 = unsafe { + std::mem::transmute( + self.library + .get:: *const u64>( + b"cairo_native__set_costs_builtin", + )? + .into_raw() + .into_raw(), + ) + }; + ContractExecutionResult::from_execution_result(super::invoke_dynamic( &self.registry, self.find_function_ptr(function_id), + set_costs_builtin, self.extract_signature(function_id), &[Value::Struct { fields: vec![Value::Array( args.iter().cloned().map(Value::Felt252).collect(), )], - // TODO: Populate `debug_name`. debug_name: None, }], available_gas, @@ -152,6 +191,15 @@ impl AotNativeExecutor { } } + pub fn find_symbol_ptr(&self, name: &str) -> Option<*mut c_void> { + unsafe { + self.library + .get::<*mut ()>(name.as_bytes()) + .ok() + .map(|x| x.into_raw().into_raw()) + } + } + fn extract_signature(&self, function_id: &FunctionId) -> &FunctionSignature { &self.registry.get_function(function_id).unwrap().signature } diff --git a/src/executor/contract.rs b/src/executor/contract.rs index 14c4130f2..a7a09ddbf 100644 --- a/src/executor/contract.rs +++ b/src/executor/contract.rs @@ -34,14 +34,16 @@ use crate::{ arch::AbiArgument, context::NativeContext, - error::Result, + error::{Error, Result}, execution_result::{BuiltinStats, ContractExecutionResult}, executor::invoke_trampoline, + metadata::gas::GasMetadata, module::NativeModule, starknet::{handler::StarknetSyscallHandlerCallbacks, StarknetSyscallHandler}, types::TypeBuilder, utils::{ decode_error_message, generate_function_name, get_integer_layout, libc_free, libc_malloc, + BuiltinCosts, }, OptLevel, }; @@ -54,13 +56,14 @@ use cairo_lang_sierra::{ ids::FunctionId, program::Program, }; +use cairo_lang_starknet_classes::contract_class::ContractEntryPoints; use educe::Educe; use libloading::Library; use serde::{Deserialize, Serialize}; use starknet_types_core::felt::Felt; use std::{ alloc::Layout, - collections::BTreeMap, + collections::{BTreeMap, HashSet}, ffi::c_void, path::{Path, PathBuf}, ptr::NonNull, @@ -76,16 +79,29 @@ pub struct AotContractExecutor { library: Arc, path: PathBuf, is_temp_path: bool, - entry_points_info: BTreeMap, + contract_info: NativeContractInfo, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct NativeContractInfo { + pub version: ContractInfoVersion, + pub entry_points_info: BTreeMap, + pub entry_point_selector_to_id: BTreeMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)] +pub enum ContractInfoVersion { + Version0, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)] -struct EntryPointInfo { - builtins: Vec, +pub struct EntryPointInfo { + pub builtins: Vec, + pub initial_cost: BTreeMap, // cost token type offset, cost } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)] -enum BuiltinType { +pub enum BuiltinType { Bitwise, EcOp, RangeCheck, @@ -97,6 +113,26 @@ enum BuiltinType { CircuitMul, Gas, System, + BuiltinCosts, +} + +impl BuiltinType { + pub fn size_in_bytes(&self) -> usize { + match self { + BuiltinType::Bitwise => 8, + BuiltinType::EcOp => 8, + BuiltinType::RangeCheck => 8, + BuiltinType::SegmentArena => 8, + BuiltinType::Poseidon => 8, + BuiltinType::Pedersen => 8, + BuiltinType::RangeCheck96 => 8, + BuiltinType::CircuitAdd => 8, + BuiltinType::CircuitMul => 8, + BuiltinType::Gas => 16, + BuiltinType::System => 8, + BuiltinType::BuiltinCosts => 8, + } + } } impl AotContractExecutor { @@ -105,19 +141,47 @@ impl AotContractExecutor { /// If not saved, the path is treated as /// a temporary file an deleted when dropped. /// If you loaded a ContractExecutor using [`load`] then it will not be treated as a temp file. - pub fn new(sierra_program: &Program, opt_level: OptLevel) -> Result { + pub fn new( + sierra_program: &Program, + entry_points: &ContractEntryPoints, + opt_level: OptLevel, + ) -> Result { let native_context = NativeContext::new(); let module = native_context.compile(sierra_program, true)?; let NativeModule { module, registry, - metadata: _, + metadata, } = module; + let initial_gas_costs = { + let gas_meta: &GasMetadata = metadata.get().unwrap(); + gas_meta.initial_required_gas_for_entry_points() + }; + let mut infos = BTreeMap::new(); + let mut entry_point_selector_to_id = BTreeMap::new(); + + let mut used_function_ids = HashSet::new(); + for entry in entry_points + .constructor + .iter() + .chain(entry_points.external.iter()) + .chain(entry_points.l1_handler.iter()) + { + entry_point_selector_to_id + .insert(Felt::from(&entry.selector), entry.function_idx as u64); + used_function_ids.insert(entry.function_idx as u64); + } + for x in &sierra_program.funcs { + // Avoid storing function info for methods that are not contract entry points. + if !used_function_ids.contains(&x.id.id) { + continue; + } + let mut builtins = Vec::new(); for p in &x.params { @@ -134,6 +198,9 @@ impl AotContractExecutor { CoreTypeConcrete::RangeCheck(_) => builtins.push(BuiltinType::RangeCheck), CoreTypeConcrete::Pedersen(_) => builtins.push(BuiltinType::Pedersen), CoreTypeConcrete::Poseidon(_) => builtins.push(BuiltinType::Poseidon), + CoreTypeConcrete::BuiltinCosts(_) => { + builtins.push(BuiltinType::BuiltinCosts) + } CoreTypeConcrete::SegmentArena(_) => { builtins.push(BuiltinType::SegmentArena) } @@ -157,7 +224,13 @@ impl AotContractExecutor { } } - infos.insert(x.id.id, EntryPointInfo { builtins }); + infos.insert( + x.id.id, + EntryPointInfo { + builtins, + initial_cost: initial_gas_costs.get(&x.id.id).cloned().unwrap_or_default(), + }, + ); } let library_path = NamedTempFile::new()? @@ -172,7 +245,11 @@ impl AotContractExecutor { library: Arc::new(unsafe { Library::new(&library_path)? }), path: library_path, is_temp_path: true, - entry_points_info: infos, + contract_info: NativeContractInfo { + version: ContractInfoVersion::Version0, + entry_points_info: infos, + entry_point_selector_to_id, + }, }) } @@ -181,9 +258,9 @@ impl AotContractExecutor { let to = to.as_ref(); std::fs::copy(&self.path, to)?; - let info = serde_json::to_string(&self.entry_points_info)?; + let contract_info = serde_json::to_string(&self.contract_info)?; let path = to.with_extension("json"); - std::fs::write(path, info)?; + std::fs::write(path, contract_info)?; self.path = to.to_path_buf(); self.is_temp_path = false; @@ -194,49 +271,99 @@ impl AotContractExecutor { /// Load the executor from an already compiled library with the additional info json file. pub fn load(library_path: &Path) -> Result { let info_str = std::fs::read_to_string(library_path.with_extension("json"))?; - let info: BTreeMap = serde_json::from_str(&info_str)?; + let contract_info: NativeContractInfo = serde_json::from_str(&info_str)?; Ok(Self { library: Arc::new(unsafe { Library::new(library_path)? }), path: library_path.to_path_buf(), is_temp_path: false, - entry_points_info: info, + contract_info, }) } /// Runs the given entry point. pub fn run( &self, - function_id: &FunctionId, + selector: Felt, args: &[Felt], gas: Option, + builtin_costs: Option, mut syscall_handler: impl StarknetSyscallHandler, ) -> Result { let arena = Bump::new(); let mut invoke_data = Vec::::new(); - let function_ptr = self.find_function_ptr(function_id, true)?; + let function_id = FunctionId { + id: *self + .contract_info + .entry_point_selector_to_id + .get(&selector) + .ok_or(Error::SelectorNotFound)?, + debug_name: None, + }; + let function_ptr = self.find_function_ptr(&function_id, true)?; + + let builtin_costs = builtin_costs.unwrap_or_default(); + let builtin_costs_stack: [u64; 7] = builtin_costs.into(); + // Note: the ptr into a slice is valid, it can be used with cast() + // Care should be taken if you dereference it and take the .as_ptr() of the slice, since when you + // deref it, it will be a copy on the stack, so you will get the ptr of the value in the stack. + let builtin_costs: *mut [u64; 7] = Box::into_raw(Box::new(builtin_costs_stack)); + let set_costs_builtin = unsafe { + self.library + .get:: *const u64>( + b"cairo_native__set_costs_builtin", + )? + }; + // We may be inside a recursive contract, save the possible saved builtin costs to restore it after our call. + let old_builtincosts_ptr = set_costs_builtin(builtin_costs.cast()); + + let initial_gas_cost = { + let mut cost = 0; + + for (offset, val) in self + .contract_info + .entry_points_info + .get(&function_id.id) + .unwrap() + .initial_cost + .iter() + { + let token_cost = builtin_costs_stack[*offset as usize] * val; + cost += token_cost; + } + cost as u128 + }; + let gas = gas + .unwrap_or(initial_gas_cost) + .saturating_sub(initial_gas_cost); // it can vary from contract to contract thats why we need to store/ load it. - // substract 2, which are the gas and syscall builtin - let num_builtins = self.entry_points_info[&function_id.id].builtins.len() - 2; + let builtins_size: usize = self.contract_info.entry_points_info[&function_id.id] + .builtins + .iter() + .map(|x| x.size_in_bytes()) + .sum(); // There is always a return ptr because contracts always return more than 1 thing (builtin counters, syscall, enum) let return_ptr = arena.alloc_layout(unsafe { - // 64 = size of enum + syscall + u128 from gas builtin + 8 bytes for each additional builtin counter - // align is 16 because of the u128 - Layout::from_size_align_unchecked(64 + 8 * num_builtins, 16) + // 64 = size of enum + builtin sizes + // align is 16 because of the u128 from gas + Layout::from_size_align_unchecked(128 + builtins_size, 16) }); return_ptr.as_ptr().to_bytes(&mut invoke_data)?; let mut syscall_handler = StarknetSyscallHandlerCallbacks::new(&mut syscall_handler); - for b in &self.entry_points_info[&function_id.id].builtins { + for b in &self.contract_info.entry_points_info[&function_id.id].builtins { match b { BuiltinType::Gas => { - let gas = gas.unwrap_or(0); gas.to_bytes(&mut invoke_data)?; } + BuiltinType::BuiltinCosts => { + // todo: check if valid + builtin_costs_stack.as_ptr().to_bytes(&mut invoke_data)?; + } BuiltinType::System => { (&mut syscall_handler as *mut StarknetSyscallHandlerCallbacks<_>) .to_bytes(&mut invoke_data)?; @@ -311,14 +438,17 @@ impl AotContractExecutor { let return_ptr = &mut return_ptr.cast(); - for b in &self.entry_points_info[&function_id.id].builtins { + for b in &self.contract_info.entry_points_info[&function_id.id].builtins { match b { BuiltinType::Gas => { remaining_gas = unsafe { *read_value::(return_ptr) }; } BuiltinType::System => { - let ptr = return_ptr.cast::<*mut ()>(); - *return_ptr = unsafe { NonNull::new_unchecked(ptr.as_ptr().add(1)).cast() }; + unsafe { read_value::<*mut ()>(return_ptr) }; + } + BuiltinType::BuiltinCosts => { + unsafe { read_value::<*mut ()>(return_ptr) }; + // ptr holds the builtin costs, but they dont change, so its of no use, but we read to advance the ptr. } x => { let value = unsafe { *read_value::(return_ptr) } as usize; @@ -335,6 +465,7 @@ impl AotContractExecutor { BuiltinType::CircuitMul => builtin_stats.circuit_mul = value, BuiltinType::Gas => {} BuiltinType::System => {} + BuiltinType::BuiltinCosts => {} } } } @@ -409,6 +540,15 @@ impl AotContractExecutor { error_msg = Some(str_error); } + // Restore the old ptr and get back our builtincost box and free it. + let our_builtincosts_ptr = set_costs_builtin(old_builtincosts_ptr); + + if !our_builtincosts_ptr.is_null() && old_builtincosts_ptr.is_aligned() { + unsafe { + let _ = Box::<[u64; 7]>::from_raw(our_builtincosts_ptr.cast_mut().cast()); + }; + } + #[cfg(feature = "with-mem-tracing")] crate::utils::mem_tracing::report_stats(); @@ -436,6 +576,15 @@ impl AotContractExecutor { .into_raw() }) } + + pub fn find_symbol_ptr(&self, name: &str) -> Option<*mut c_void> { + unsafe { + self.library + .get::<*mut ()>(name.as_bytes()) + .ok() + .map(|x| x.into_raw().into_raw()) + } + } } impl Drop for AotContractExecutor { @@ -450,13 +599,16 @@ impl Drop for AotContractExecutor { #[cfg(test)] mod tests { use super::*; - use crate::{starknet_stub::StubSyscallHandler, utils::test::load_starknet}; - use cairo_lang_sierra::program::Program; + use crate::{starknet_stub::StubSyscallHandler, utils::test::load_starknet_contract}; + use cairo_lang_starknet_classes::contract_class::ContractClass; + use rayon::iter::ParallelBridge; use rstest::*; + // todo add recursive contract test + #[fixture] - fn starknet_program() -> Program { - let (_, program) = load_starknet! { + fn starknet_program() -> ContractClass { + let (_, program) = load_starknet_contract! { #[starknet::interface] trait ISimpleStorage { fn get(self: @TContractState, x: felt252) -> (felt252, felt252); @@ -479,8 +631,40 @@ mod tests { } #[fixture] - fn starknet_program_empty() -> Program { - let (_, program) = load_starknet! { + fn starknet_program_factorial() -> ContractClass { + let (_, program) = load_starknet_contract! { + #[starknet::interface] + trait ISimpleStorage { + fn get(self: @TContractState, x: felt252) -> felt252; + } + + #[starknet::contract] + mod contract { + #[storage] + struct Storage {} + + #[abi(embed_v0)] + impl ISimpleStorageImpl of super::ISimpleStorage { + fn get(self: @ContractState, x: felt252) -> felt252 { + factorial(1, x) + } + } + + fn factorial(value: felt252, n: felt252) -> felt252 { + if (n == 1) { + value + } else { + factorial(value * n, n - 1) + } + } + } + }; + program + } + + #[fixture] + fn starknet_program_empty() -> ContractClass { + let (_, program) = load_starknet_contract! { #[starknet::interface] trait ISimpleStorage { fn call(self: @TContractState); @@ -501,24 +685,74 @@ mod tests { program } + #[rstest] + #[case(OptLevel::Default)] + fn test_contract_executor_parallel( + starknet_program: ContractClass, + #[case] optlevel: OptLevel, + ) { + use rayon::iter::ParallelIterator; + + let executor = Arc::new( + AotContractExecutor::new( + &starknet_program.extract_sierra_program().unwrap(), + &starknet_program.entry_points_by_type, + optlevel, + ) + .unwrap(), + ); + + // The last function in the program is the `get` wrapper function. + let selector = starknet_program + .entry_points_by_type + .external + .last() + .unwrap() + .selector + .clone(); + + (0..200).par_bridge().for_each(|n| { + let result = executor + .run( + Felt::from(&selector), + &[n.into()], + Some(u64::MAX as u128), + None, + &mut StubSyscallHandler::default(), + ) + .unwrap(); + + assert_eq!(result.return_values, vec![Felt::from(n), Felt::from(n * 2)]); + assert_eq!(result.remaining_gas, 18446744073709548175); + }); + } + #[rstest] #[case(OptLevel::None)] #[case(OptLevel::Default)] - fn test_contract_executor(starknet_program: Program, #[case] optlevel: OptLevel) { - let executor = AotContractExecutor::new(&starknet_program, optlevel).unwrap(); + fn test_contract_executor(starknet_program: ContractClass, #[case] optlevel: OptLevel) { + let executor = AotContractExecutor::new( + &starknet_program.extract_sierra_program().unwrap(), + &starknet_program.entry_points_by_type, + optlevel, + ) + .unwrap(); // The last function in the program is the `get` wrapper function. - let entrypoint_function_id = &starknet_program - .funcs + let selector = starknet_program + .entry_points_by_type + .external .last() - .expect("should have a function") - .id; + .unwrap() + .selector + .clone(); let result = executor .run( - entrypoint_function_id, + Felt::from(&selector), &[2.into()], Some(u64::MAX as u128), + None, &mut StubSyscallHandler::default(), ) .unwrap(); @@ -526,24 +760,72 @@ mod tests { assert_eq!(result.return_values, vec![Felt::from(2), Felt::from(4)]); } + #[rstest] + #[case(OptLevel::Aggressive)] + fn test_contract_executor_factorial( + starknet_program_factorial: ContractClass, + #[case] optlevel: OptLevel, + ) { + let executor = AotContractExecutor::new( + &starknet_program_factorial.extract_sierra_program().unwrap(), + &starknet_program_factorial.entry_points_by_type, + optlevel, + ) + .unwrap(); + + // The last function in the program is the `get` wrapper function. + let selector = starknet_program_factorial + .entry_points_by_type + .external + .last() + .unwrap() + .selector + .clone(); + + let result = executor + .run( + Felt::from(&selector), + &[10.into()], + Some(u64::MAX as u128), + None, + &mut StubSyscallHandler::default(), + ) + .unwrap(); + + assert_eq!(result.return_values, vec![Felt::from(3628800)]); + assert_eq!(result.remaining_gas, 18446744073709533805); + } + #[rstest] #[case(OptLevel::None)] #[case(OptLevel::Default)] - fn test_contract_executor_empty(starknet_program_empty: Program, #[case] optlevel: OptLevel) { - let executor = AotContractExecutor::new(&starknet_program_empty, optlevel).unwrap(); + fn test_contract_executor_empty( + starknet_program_empty: ContractClass, + #[case] optlevel: OptLevel, + ) { + let executor = AotContractExecutor::new( + &starknet_program_empty.extract_sierra_program().unwrap(), + &starknet_program_empty.entry_points_by_type, + optlevel, + ) + .unwrap(); // The last function in the program is the `get` wrapper function. - let entrypoint_function_id = &starknet_program_empty - .funcs + // The last function in the program is the `get` wrapper function. + let selector = starknet_program_empty + .entry_points_by_type + .external .last() - .expect("should have a function") - .id; + .unwrap() + .selector + .clone(); let result = executor .run( - entrypoint_function_id, + Felt::from(&selector), &[], Some(u64::MAX as u128), + None, &mut StubSyscallHandler::default(), ) .unwrap(); diff --git a/src/executor/jit.rs b/src/executor/jit.rs index 55bfafcbe..440c95caf 100644 --- a/src/executor/jit.rs +++ b/src/executor/jit.rs @@ -76,9 +76,13 @@ impl<'m> JitNativeExecutor<'m> { .get_initial_available_gas(function_id, gas) .map_err(crate::error::Error::GasMetadataError)?; + let set_builtin_costs_fnptr: extern "C" fn(*const u64) -> *const u64 = + unsafe { std::mem::transmute(self.engine.lookup("cairo_native__set_costs_builtin")) }; + super::invoke_dynamic( &self.registry, self.find_function_ptr(function_id), + set_builtin_costs_fnptr, self.extract_signature(function_id).unwrap(), args, available_gas, @@ -99,9 +103,13 @@ impl<'m> JitNativeExecutor<'m> { .get_initial_available_gas(function_id, gas) .map_err(crate::error::Error::GasMetadataError)?; + let set_builtin_costs_fnptr: extern "C" fn(*const u64) -> *const u64 = + unsafe { std::mem::transmute(self.engine.lookup("cairo_native__set_costs_builtin")) }; + super::invoke_dynamic( &self.registry, self.find_function_ptr(function_id), + set_builtin_costs_fnptr, self.extract_signature(function_id).unwrap(), args, available_gas, @@ -120,16 +128,19 @@ impl<'m> JitNativeExecutor<'m> { .gas_metadata .get_initial_available_gas(function_id, gas) .map_err(crate::error::Error::GasMetadataError)?; - // TODO: Check signature for contract interface. + + let set_builtin_costs_fnptr: extern "C" fn(*const u64) -> *const u64 = + unsafe { std::mem::transmute(self.engine.lookup("cairo_native__set_costs_builtin")) }; + ContractExecutionResult::from_execution_result(super::invoke_dynamic( &self.registry, self.find_function_ptr(function_id), + set_builtin_costs_fnptr, self.extract_signature(function_id).unwrap(), &[Value::Struct { fields: vec![Value::Array( args.iter().cloned().map(Value::Felt252).collect(), )], - // TODO: Populate `debug_name`. debug_name: None, }], available_gas, @@ -145,6 +156,16 @@ impl<'m> JitNativeExecutor<'m> { self.engine.lookup(&function_name) as *mut c_void } + pub fn find_symbol_ptr(&self, name: &str) -> Option<*mut c_void> { + let ptr = self.engine.lookup(name) as *mut c_void; + + if ptr.is_null() { + None + } else { + Some(ptr) + } + } + fn extract_signature(&self, function_id: &FunctionId) -> Option<&FunctionSignature> { self.program_registry() .get_function(function_id) diff --git a/src/ffi.rs b/src/ffi.rs index 0dbb8aa1d..4113bb8d7 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -281,6 +281,7 @@ pub fn object_to_shared_lib(object: &[u8], output_filename: &Path) -> Result<()> "-o".into(), Cow::from(output_path), "-lSystem".into(), + "-force_load".into(), // needed so `cairo_native__set_costs_builtin` is always available Cow::from(runtime_library_path), ]); @@ -290,7 +291,6 @@ pub fn object_to_shared_lib(object: &[u8], output_filename: &Path) -> Result<()> { let mut args: Vec> = vec![ "--hash-style=gnu".into(), - "--eh-frame-hdr".into(), "-shared".into(), "-L/lib/../lib64".into(), "-L/usr/lib/../lib64".into(), @@ -301,6 +301,7 @@ pub fn object_to_shared_lib(object: &[u8], output_filename: &Path) -> Result<()> Cow::from(output_path), "-lc".into(), Cow::from(file_path), + "--whole-archive".into(), // needed so `cairo_native__set_costs_builtin` is always available Cow::from(runtime_library_path), ]); diff --git a/src/libfuncs/gas.rs b/src/libfuncs/gas.rs index d1e86b5dc..ce50f673c 100644 --- a/src/libfuncs/gas.rs +++ b/src/libfuncs/gas.rs @@ -2,23 +2,22 @@ use super::LibfuncHelper; use crate::{ - error::Result, - metadata::{gas::GasCost, MetadataStorage}, - utils::{BlockExt, ProgramRegistryExt}, + error::{Error, Result}, + metadata::{gas::GasCost, runtime_bindings::RuntimeBindingsMeta, MetadataStorage}, + utils::{BlockExt, GepIndex}, }; use cairo_lang_sierra::{ extensions::{ core::{CoreLibfunc, CoreType}, - gas::GasConcreteLibfunc, + gas::{CostTokenType, GasConcreteLibfunc}, lib_func::SignatureOnlyConcreteLibfunc, - ConcreteLibfunc, }, program_registry::ProgramRegistry, }; use melior::{ dialect::{ arith::{self, CmpiPredicate}, - llvm, ods, + ods, }, ir::{r#type::IntegerType, Block, Location}, Context, @@ -51,7 +50,7 @@ pub fn build<'ctx, 'this>( } } -/// Generate MLIR operations for the `get_builtin_costs` libfunc. +/// Generate MLIR operations for the `get_available_gas` libfunc. pub fn build_get_available_gas<'ctx, 'this>( _context: &'ctx Context, _registry: &ProgramRegistry, @@ -76,29 +75,80 @@ pub fn build_withdraw_gas<'ctx, 'this>( entry: &'this Block<'ctx>, location: Location<'ctx>, helper: &LibfuncHelper<'ctx, 'this>, - metadata: &MetadataStorage, + metadata: &mut MetadataStorage, _info: &SignatureOnlyConcreteLibfunc, ) -> Result<()> { let range_check = super::increment_builtin_counter(context, entry, location, entry.argument(0)?.into())?; let current_gas = entry.argument(1)?.into(); - let cost = metadata.get::().and_then(|x| x.0); + let gas_cost = metadata + .get::() + .expect("builtin_withdraw_gas should always have a gas cost") + .clone(); let u128_type: melior::ir::Type = IntegerType::new(context, 128).into(); - let gas_cost_val = - entry.const_int_from_type(context, location, cost.unwrap_or(0), u128_type)?; + let u64_type: melior::ir::Type = IntegerType::new(context, 64).into(); + + let builtin_ptr = { + let runtime = metadata + .get_mut::() + .ok_or(Error::MissingMetadata)?; + runtime + .get_gas_builtin(context, helper, entry, location)? + .result(0)? + .into() + }; + + let mut total_gas_cost_value = entry.const_int_from_type(context, location, 0, u128_type)?; + + for (cost_count, token_type) in &gas_cost.0 { + if *cost_count == 0 { + continue; + } + + let builtin_costs_index = match token_type { + CostTokenType::Const => 0, + CostTokenType::Pedersen => 1, + CostTokenType::Bitwise => 2, + CostTokenType::EcOp => 3, + CostTokenType::Poseidon => 4, + CostTokenType::AddMod => 5, + CostTokenType::MulMod => 6, + _ => unreachable!(), + }; + + let cost_count_value = + entry.const_int_from_type(context, location, *cost_count, u128_type)?; + let builtin_costs_index_value = + entry.const_int_from_type(context, location, builtin_costs_index, u64_type)?; + + let builtin_cost_value_ptr = entry.gep( + context, + location, + builtin_ptr, + &[GepIndex::Value(builtin_costs_index_value)], + u64_type, + )?; + let cost_value = entry.load(context, location, builtin_cost_value_ptr, u64_type)?; + let cost_value_ext = + entry.append_op_result(arith::extui(cost_value, u128_type, location))?; + let gas_cost_value = + entry.append_op_result(arith::muli(cost_count_value, cost_value_ext, location))?; + total_gas_cost_value = + entry.append_op_result(arith::addi(total_gas_cost_value, gas_cost_value, location))?; + } let is_enough = entry.append_op_result(arith::cmpi( context, CmpiPredicate::Uge, current_gas, - gas_cost_val, + total_gas_cost_value, location, ))?; let resulting_gas = entry.append_op_result( - ods::llvm::intr_usub_sat(context, current_gas, gas_cost_val, location).into(), + ods::llvm::intr_usub_sat(context, current_gas, total_gas_cost_value, location).into(), )?; entry.append_operation(helper.cond_br( @@ -125,23 +175,64 @@ pub fn build_builtin_withdraw_gas<'ctx, 'this>( let range_check = super::increment_builtin_counter(context, entry, location, entry.argument(0)?.into())?; let current_gas = entry.argument(1)?.into(); + let builtin_ptr = entry.argument(2)?.into(); - let cost = metadata.get::().and_then(|x| x.0); + let gas_cost = metadata + .get::() + .expect("builtin_withdraw_gas should always have a gas cost"); let u128_type: melior::ir::Type = IntegerType::new(context, 128).into(); - let gas_cost_val = - entry.const_int_from_type(context, location, cost.unwrap_or(0), u128_type)?; + let u64_type: melior::ir::Type = IntegerType::new(context, 64).into(); + + let mut total_gas_cost_value = entry.const_int_from_type(context, location, 0, u128_type)?; + + for (cost_count, token_type) in &gas_cost.0 { + if *cost_count == 0 { + continue; + } + + let builtin_costs_index = match token_type { + CostTokenType::Const => 0, + CostTokenType::Pedersen => 1, + CostTokenType::Bitwise => 2, + CostTokenType::EcOp => 3, + CostTokenType::Poseidon => 4, + CostTokenType::AddMod => 5, + CostTokenType::MulMod => 6, + _ => unreachable!(), + }; + + let cost_count_value = + entry.const_int_from_type(context, location, *cost_count, u128_type)?; + let builtin_costs_index_value = + entry.const_int_from_type(context, location, builtin_costs_index, u64_type)?; + + let builtin_cost_value_ptr = entry.gep( + context, + location, + builtin_ptr, + &[GepIndex::Value(builtin_costs_index_value)], + u64_type, + )?; + let cost_value = entry.load(context, location, builtin_cost_value_ptr, u64_type)?; + let cost_value_ext = + entry.append_op_result(arith::extui(cost_value, u128_type, location))?; + let gas_cost_value = + entry.append_op_result(arith::muli(cost_count_value, cost_value_ext, location))?; + total_gas_cost_value = + entry.append_op_result(arith::addi(total_gas_cost_value, gas_cost_value, location))?; + } let is_enough = entry.append_op_result(arith::cmpi( context, CmpiPredicate::Uge, current_gas, - gas_cost_val, + total_gas_cost_value, location, ))?; let resulting_gas = entry.append_op_result( - ods::llvm::intr_usub_sat(context, current_gas, gas_cost_val, location).into(), + ods::llvm::intr_usub_sat(context, current_gas, total_gas_cost_value, location).into(), )?; entry.append_operation(helper.cond_br( @@ -158,25 +249,25 @@ pub fn build_builtin_withdraw_gas<'ctx, 'this>( /// Generate MLIR operations for the `get_builtin_costs` libfunc. pub fn build_get_builtin_costs<'ctx, 'this>( context: &'ctx Context, - registry: &ProgramRegistry, + _registry: &ProgramRegistry, entry: &'this Block<'ctx>, location: Location<'ctx>, helper: &LibfuncHelper<'ctx, 'this>, metadata: &mut MetadataStorage, - info: &SignatureOnlyConcreteLibfunc, + _info: &SignatureOnlyConcreteLibfunc, ) -> Result<()> { - let builtin_costs_ty = registry.build_type( - context, - helper, - registry, - metadata, - &info.branch_signatures()[0].vars[0].ty, - )?; - - // TODO: Implement libfunc. - let op0 = entry.append_op_result(llvm::undef(builtin_costs_ty, location))?; + // Get the ptr to the global, holding a ptr to the list. + let builtin_ptr = { + let runtime = metadata + .get_mut::() + .ok_or(Error::MissingMetadata)?; + runtime + .get_gas_builtin(context, helper, entry, location)? + .result(0)? + .into() + }; - entry.append_operation(helper.br(0, &[op0], location)); + entry.append_operation(helper.br(0, &[builtin_ptr], location)); Ok(()) } diff --git a/src/metadata/gas.rs b/src/metadata/gas.rs index 9378c5895..d368c0c1c 100644 --- a/src/metadata/gas.rs +++ b/src/metadata/gas.rs @@ -1,3 +1,5 @@ +use std::collections::BTreeMap; + use cairo_lang_runner::token_gas_cost; use cairo_lang_sierra::{ extensions::gas::CostTokenType, @@ -20,8 +22,9 @@ pub struct GasMetadata { pub gas_info: GasInfo, } -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct GasCost(pub Option); +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +// Cost, token type (index into builtin costs). +pub struct GasCost(pub Vec<(u128, CostTokenType)>); /// Configuration for metadata computation. #[derive(Debug, Clone)] @@ -102,16 +105,46 @@ impl GasMetadata { ) } - pub fn get_gas_cost_for_statement(&self, idx: StatementIdx) -> Option { - let mut cost = None; + pub fn initial_required_gas_for_entry_points(&self) -> BTreeMap> { + self.gas_info + .function_costs + .iter() + .map(|func| { + (func.0.id, { + let mut costs = BTreeMap::new(); + + for (token, val) in func.1.iter() { + let offset: u64 = match token { + CostTokenType::Const => 0, + CostTokenType::Pedersen => 1, + CostTokenType::Bitwise => 2, + CostTokenType::EcOp => 3, + CostTokenType::Poseidon => 4, + CostTokenType::AddMod => 5, + CostTokenType::MulMod => 6, + _ => unreachable!(), + }; + costs.insert(offset, *val as u64); + } + + costs + }) + }) + .collect() + } + + pub fn get_gas_costs_for_statement(&self, idx: StatementIdx) -> Vec<(u128, CostTokenType)> { + let mut costs = Vec::new(); for cost_type in CostTokenType::iter_casm_tokens() { - if let Some(amount) = + if let Some(cost_count) = self.get_gas_cost_for_statement_and_cost_token_type(idx, *cost_type) { - *cost.get_or_insert(0) += amount * token_gas_cost(*cost_type) as u128; + if cost_count > 0 { + costs.push((cost_count, *cost_type)); + } } } - cost + costs } pub fn get_gas_cost_for_statement_and_cost_token_type( diff --git a/src/metadata/runtime_bindings.rs b/src/metadata/runtime_bindings.rs index 3ddbf0441..225b671c3 100644 --- a/src/metadata/runtime_bindings.rs +++ b/src/metadata/runtime_bindings.rs @@ -31,6 +31,7 @@ enum RuntimeBinding { DictGasRefund, DictDrop, DictDup, + GetGasBuiltin, DebugPrint, #[cfg(feature = "with-cheatcode")] VtableCheatcode, @@ -862,6 +863,49 @@ impl RuntimeBindingsMeta { ))) } + // Register if necessary, then invoke the `set_gas_builtin()` function. + #[allow(clippy::too_many_arguments)] + pub fn get_gas_builtin<'c, 'a>( + &mut self, + context: &'c Context, + module: &Module, + block: &'a Block<'c>, + location: Location<'c>, + ) -> Result> + where + 'c: 'a, + { + if self.active_map.insert(RuntimeBinding::GetGasBuiltin) { + module.body().append_operation(func::func( + context, + StringAttribute::new(context, "cairo_native__get_costs_builtin"), + TypeAttribute::new( + FunctionType::new(context, &[], &[llvm::r#type::pointer(context, 0)]).into(), + ), + Region::new(), + &[ + ( + Identifier::new(context, "sym_visibility"), + StringAttribute::new(context, "private").into(), + ), + ( + Identifier::new(context, "llvm.linkage"), + Attribute::parse(context, "#llvm.linkage").unwrap(), + ), + ], + Location::unknown(context), + )); + } + + Ok(block.append_operation(func::call( + context, + FlatSymbolRefAttribute::new(context, "cairo_native__get_costs_builtin"), + &[], + &[llvm::r#type::pointer(context, 0)], + location, + ))) + } + /// Register if necessary, then invoke the `vtable_cheatcode()` runtime function. /// /// Calls the cheatcode syscall with the given arguments. diff --git a/src/types.rs b/src/types.rs index bf1dbb8bf..6d0d6314e 100644 --- a/src/types.rs +++ b/src/types.rs @@ -549,10 +549,12 @@ impl TypeBuilder for CoreTypeConcrete { | CoreTypeConcrete::Poseidon(_) | CoreTypeConcrete::RangeCheck96(_) | CoreTypeConcrete::SegmentArena(_) => false, + + // A ptr to a list of costs. + CoreTypeConcrete::BuiltinCosts(_) => false, + // Other builtins: - CoreTypeConcrete::BuiltinCosts(_) - | CoreTypeConcrete::Uint128MulGuarantee(_) - | CoreTypeConcrete::Coupon(_) => true, + CoreTypeConcrete::Uint128MulGuarantee(_) | CoreTypeConcrete::Coupon(_) => true, // Normal types: CoreTypeConcrete::Array(_) @@ -634,7 +636,7 @@ impl TypeBuilder for CoreTypeConcrete { CoreTypeConcrete::EcState(_) => layout_repeat(&get_integer_layout(252), 4)?.0, CoreTypeConcrete::Felt252(_) => get_integer_layout(252), CoreTypeConcrete::GasBuiltin(_) => get_integer_layout(128), - CoreTypeConcrete::BuiltinCosts(_) => Layout::new::<()>(), + CoreTypeConcrete::BuiltinCosts(_) => Layout::new::<*const ()>(), CoreTypeConcrete::Uint8(_) => get_integer_layout(8), CoreTypeConcrete::Uint16(_) => get_integer_layout(16), CoreTypeConcrete::Uint32(_) => get_integer_layout(32), diff --git a/src/types/builtin_costs.rs b/src/types/builtin_costs.rs index b7a202dce..7fc2e978c 100644 --- a/src/types/builtin_costs.rs +++ b/src/types/builtin_costs.rs @@ -1,4 +1,7 @@ //! # Builtin costs type +//! +//! A ptr to a list of u64, this list will not change at runtime in size and thus we only really need to store the pointer, +//! it can be allocated on the stack on rust side and passed. use super::WithSelf; use crate::{error::Result, metadata::MetadataStorage}; @@ -11,7 +14,7 @@ use cairo_lang_sierra::{ }; use melior::{ dialect::llvm, - ir::{r#type::IntegerType, Module, Type}, + ir::{Module, Type}, Context, }; @@ -25,5 +28,6 @@ pub fn build<'ctx>( _metadata: &mut MetadataStorage, _info: WithSelf, ) -> Result> { - Ok(llvm::r#type::array(IntegerType::new(context, 8).into(), 0)) + // A ptr to a list of u64 + Ok(llvm::r#type::pointer(context, 0)) } diff --git a/src/utils.rs b/src/utils.rs index 110ff2c93..a2cfc77a3 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -7,7 +7,9 @@ pub(crate) use self::{ }; use crate::{metadata::MetadataStorage, OptLevel}; use cairo_lang_compiler::CompilerConfig; +use cairo_lang_runner::token_gas_cost; use cairo_lang_sierra::{ + extensions::gas::CostTokenType, ids::FunctionId, program::{GenFunction, Program, StatementIdx}, }; @@ -17,6 +19,7 @@ use melior::{ Context, Error, ExecutionEngine, }; use num_bigint::{BigInt, BigUint, Sign}; +use serde::{Deserialize, Serialize}; use starknet_types_core::felt::Felt; use std::sync::LazyLock; use std::{ @@ -51,6 +54,47 @@ pub static HALF_PRIME: LazyLock = LazyLock::new(|| { .unwrap() }); +#[derive(Debug, Clone, Copy, Deserialize, Serialize)] +pub struct BuiltinCosts { + pub r#const: u64, + pub pedersen: u64, + pub bitwise: u64, + pub ecop: u64, + pub poseidon: u64, + pub add_mod: u64, + pub mul_mod: u64, +} + +impl From for [u64; 7] { + // Order matters, for the libfunc impl + // https://github.com/starkware-libs/sequencer/blob/1b7252f8a30244d39614d7666aa113b81291808e/crates/blockifier/src/execution/entry_point_execution.rs#L208 + fn from(value: BuiltinCosts) -> Self { + [ + value.r#const, + value.pedersen, + value.bitwise, + value.ecop, + value.poseidon, + value.add_mod, + value.mul_mod, + ] + } +} + +impl Default for BuiltinCosts { + fn default() -> Self { + Self { + r#const: token_gas_cost(CostTokenType::Const) as u64, + pedersen: token_gas_cost(CostTokenType::Pedersen) as u64, + bitwise: token_gas_cost(CostTokenType::Bitwise) as u64, + ecop: token_gas_cost(CostTokenType::EcOp) as u64, + poseidon: token_gas_cost(CostTokenType::Poseidon) as u64, + add_mod: token_gas_cost(CostTokenType::AddMod) as u64, + mul_mod: token_gas_cost(CostTokenType::MulMod) as u64, + } + } +} + #[cfg(feature = "with-mem-tracing")] #[allow(unused_imports)] pub(crate) use self::mem_tracing::{ @@ -349,6 +393,18 @@ pub fn register_runtime_symbols(engine: &ExecutionEngine) { as *mut (), ); + engine.register_symbol( + "cairo_native__set_costs_builtin", + cairo_native_runtime::cairo_native__set_costs_builtin as *const fn(*const u64) -> () + as *mut (), + ); + + engine.register_symbol( + "cairo_native__get_costs_builtin", + cairo_native_runtime::cairo_native__get_costs_builtin as *const fn() -> *const u64 + as *mut (), + ); + #[cfg(feature = "with-cheatcode")] { engine.register_symbol( @@ -475,7 +531,8 @@ pub mod test { program::Program, program::{FunctionSignature, GenFunction, StatementIdx}, }; - use cairo_lang_starknet::starknet_plugin_suite; + use cairo_lang_starknet::{compile::compile_contract_in_prepared_db, starknet_plugin_suite}; + use cairo_lang_starknet_classes::contract_class::ContractClass; use pretty_assertions_sorted::assert_eq; use std::io::Write; use std::{env::var, fmt::Formatter, fs, path::Path}; @@ -490,8 +547,14 @@ pub mod test { $crate::utils::test::load_starknet_str(stringify!($($program)+)) }; } + macro_rules! load_starknet_contract { + ( $( $program:tt )+ ) => { + $crate::utils::test::load_starknet_contract_str(stringify!($($program)+)) + }; + } pub(crate) use load_cairo; pub(crate) use load_starknet; + pub(crate) use load_starknet_contract; // Helper macros for faster testing. macro_rules! jit_struct { @@ -550,6 +613,49 @@ pub mod test { ) } + pub(crate) fn load_starknet_contract_str(program_str: &str) -> (String, ContractClass) { + compile_contract( + program_str, + RootDatabase::builder() + .with_plugin_suite(starknet_plugin_suite()) + .build() + .unwrap(), + ) + } + + pub(crate) fn compile_contract( + program_str: &str, + mut db: RootDatabase, + ) -> (String, ContractClass) { + let mut program_file = tempfile::Builder::new() + .prefix("test_") + .suffix(".cairo") + .tempfile() + .unwrap(); + fs::write(&mut program_file, program_str).unwrap(); + + init_dev_corelib( + &mut db, + Path::new(&var("CARGO_MANIFEST_DIR").unwrap()).join("corelib/src"), + ); + let main_crate_ids = setup_project(&mut db, program_file.path()).unwrap(); + let contract = compile_contract_in_prepared_db( + &db, + None, + main_crate_ids, + CompilerConfig { + diagnostics_reporter: DiagnosticsReporter::stderr(), + replace_ids: true, + ..Default::default() + }, + ) + .unwrap(); + + let module_name = program_file.path().with_extension(""); + let module_name = module_name.file_name().unwrap().to_str().unwrap(); + (module_name.to_string(), contract) + } + pub(crate) fn compile_program(program_str: &str, mut db: RootDatabase) -> (String, Program) { let mut program_file = tempfile::Builder::new() .prefix("test_") diff --git a/src/values.rs b/src/values.rs index 213360111..796933dec 100644 --- a/src/values.rs +++ b/src/values.rs @@ -1,6 +1,6 @@ -//! # JIT params and return values de/serialization +//! # Params and return values de/serialization //! -//! A Rusty interface to provide parameters to JIT calls. +//! A Rusty interface to provide parameters to cairo-native entry point calls. use crate::{ error::{CompilerError, Error}, diff --git a/tests/common.rs b/tests/common.rs index 649499276..d2231ed4b 100644 --- a/tests/common.rs +++ b/tests/common.rs @@ -31,7 +31,7 @@ use cairo_lang_starknet_classes::{ use cairo_native::{ context::NativeContext, execution_result::{ContractExecutionResult, ExecutionResult}, - executor::JitNativeExecutor, + executor::{AotContractExecutor, AotNativeExecutor, JitNativeExecutor}, starknet::{DummySyscallHandler, StarknetSyscallHandler}, utils::{find_entry_point_by_idx, HALF_PRIME, PRIME}, OptLevel, Value, @@ -48,7 +48,7 @@ use lambdaworks_math::{ }, unsigned_integer::element::UnsignedInteger, }; -use num_bigint::{BigInt, Sign}; +use num_bigint::{BigInt, BigUint, Sign}; use proptest::{strategy::Strategy, test_runner::TestCaseError}; use starknet_types_core::felt::Felt; use std::{collections::HashMap, env::var, fs, ops::Neg, path::Path}; @@ -391,7 +391,6 @@ pub fn run_vm_contract( .collect_vec() } -#[track_caller] pub fn compare_inputless_program(program_path: &str) { let program: (String, Program, SierraCasmRunner) = load_cairo_path(program_path); let program = &program; @@ -428,17 +427,33 @@ pub fn run_native_starknet_contract( let entry_point_fn = find_entry_point_by_idx(sierra_program, entry_point_function_idx).unwrap(); let entry_point_id = &entry_point_fn.id; - let native_executor = JitNativeExecutor::from_native_module(native_program, Default::default()); + let native_executor = AotNativeExecutor::from_native_module(native_program, Default::default()); native_executor .invoke_contract_dynamic(entry_point_id, args, u128::MAX.into(), handler) .expect("failed to execute the given contract") } +pub fn run_native_starknet_aot_contract( + contract: &ContractClass, + selector: &BigUint, + args: &[Felt], + handler: impl StarknetSyscallHandler, +) -> ContractExecutionResult { + let native_executor = AotContractExecutor::new( + &contract.extract_sierra_program().unwrap(), + &contract.entry_points_by_type, + Default::default(), + ) + .unwrap(); + native_executor + .run(Felt::from(selector), args, u128::MAX.into(), None, handler) + .expect("failed to execute the given contract") +} + /// Given the result of the cairo-vm and cairo-native of the same program, it compares /// the results automatically, triggering a proptest assert if there is a mismatch. /// /// Left of report of the assert is the cairo vm result, right side is cairo native -#[track_caller] pub fn compare_outputs( program: &Program, entry_point: &FunctionId, @@ -744,8 +759,12 @@ pub fn compare_outputs( .unwrap_or(false) }); assert_eq!( - vm_result.gas_counter.unwrap_or_else(|| Felt::from(0)), - Felt::from(native_result.remaining_gas.unwrap_or(0)), + vm_result + .gas_counter + .unwrap_or_else(|| Felt::from(0)) + .to_bigint(), + Felt::from(native_result.remaining_gas.unwrap_or(0)).to_bigint(), + "gas mismatch" ); let vm_result = match &vm_result.value { @@ -807,7 +826,11 @@ pub fn compare_outputs( }, }; - pretty_assertions_sorted::assert_eq!(native_result.return_value, vm_result); + pretty_assertions_sorted::assert_eq!( + native_result.return_value, + vm_result, + "return value mismatch" + ); Ok(()) } diff --git a/tests/tests/starknet/keccak.rs b/tests/tests/starknet/keccak.rs index 9422ef6fe..a2365ab8c 100644 --- a/tests/tests/starknet/keccak.rs +++ b/tests/tests/starknet/keccak.rs @@ -1,4 +1,4 @@ -use crate::common::run_native_starknet_contract; +use crate::common::{run_native_starknet_aot_contract, run_native_starknet_contract}; use cairo_lang_compiler::CompilerConfig; use cairo_lang_starknet::compile::compile_path; use cairo_native::starknet_stub::StubSyscallHandler; @@ -41,4 +41,15 @@ fn keccak_test() { 340282366920938463463374607431768143515 ); assert_eq!(result.return_values, vec![1.into()]); + + let result_aot_ct = run_native_starknet_aot_contract( + contract, + &entry_point.selector, + &[], + &mut StubSyscallHandler::default(), + ); + + assert!(!result_aot_ct.failure_flag); + assert_eq!(result_aot_ct.remaining_gas, result.remaining_gas); + assert_eq!(result_aot_ct.return_values, vec![1.into()]); }