diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8aa5f4c72..217b02a41 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -253,11 +253,11 @@ jobs: run: make runtime-ci && make check-llvm && make needs-cairo2 && make build-alexandria - name: Run tests and generate coverage partition ${{ matrix.partition }} - run: cargo llvm-cov nextest --verbose --all-features --workspace --lcov --output-path ${{ matrix.output }} --partition count:${{ matrix.partition }}/4 + run: cargo llvm-cov nextest --verbose --features=scarb --workspace --lcov --output-path ${{ matrix.output }} --partition count:${{ matrix.partition }}/4 - name: test and generate coverage corelib if: ${{ matrix.partition == '1' }} - run: cargo llvm-cov nextest --verbose --all-features --lcov --output-path lcov-test.info run --bin cairo-native-test -- corelib + run: cargo llvm-cov nextest --verbose --features=scarb --lcov --output-path lcov-test.info run --bin cairo-native-test -- corelib - name: save coverage data with corelib if: ${{ matrix.partition == '1' }} @@ -275,7 +275,6 @@ jobs: name: coverage-data-${{ matrix.partition }} path: ./${{ matrix.output }} - upload-coverage: name: Upload Coverage runs-on: ubuntu-latest diff --git a/Cargo.toml b/Cargo.toml index c5276b1df..c4b2e6a37 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,9 +47,10 @@ build-cli = [ "dep:colored", ] scarb = ["build-cli", "dep:scarb-ui", "dep:scarb-metadata"] +with-cheatcode = [] with-debug-utils = [] +with-mem-tracing = [] with-runtime = ["dep:cairo-native-runtime"] -with-cheatcode = [] # the aquamarine dep is only used in docs and cannot be detected as used by cargo udeps [package.metadata.cargo-udeps.ignore] diff --git a/Makefile b/Makefile index 885f06764..5d018b1f4 100644 --- a/Makefile +++ b/Makefile @@ -46,15 +46,15 @@ endif .PHONY: build build: check-llvm runtime - cargo build --release --all-features + cargo build --release --features=scarb .PHONY: build-natives build-native: check-llvm runtime - RUSTFLAGS="-C target-cpu=native" cargo build --release --all-features + RUSTFLAGS="-C target-cpu=native" cargo build --release --features=scarb .PHONY: build-dev build-dev: check-llvm - cargo build --profile optimized-dev --all-features + cargo build --profile optimized-dev --features=scarb .PHONY: check check: check-llvm @@ -63,7 +63,7 @@ check: check-llvm .PHONY: test test: check-llvm needs-cairo2 build-alexandria runtime-ci - cargo test --profile ci --all-features + cargo test --profile ci --features=scarb,with-cheatcode,with-debug-utils .PHONY: test-cairo test-cairo: check-llvm needs-cairo2 build-alexandria runtime-ci @@ -71,20 +71,20 @@ test-cairo: check-llvm needs-cairo2 build-alexandria runtime-ci .PHONY: proptest proptest: check-llvm needs-cairo2 runtime-ci - cargo test --profile ci --all-features proptest + cargo test --profile ci --features=scarb,with-cheatcode,with-debug-utils proptest .PHONY: test-cli test-ci: check-llvm needs-cairo2 build-alexandria runtime-ci - cargo test --profile ci --all-features + cargo test --profile ci --features=scarb,with-cheatcode,with-debug-utils .PHONY: proptest-cli proptest-ci: check-llvm needs-cairo2 runtime-ci - cargo test --profile ci --all-features proptest + cargo test --profile ci --features=scarb,with-cheatcode,with-debug-utils proptest .PHONY: coverage coverage: check-llvm needs-cairo2 build-alexandria runtime-ci - cargo llvm-cov --verbose --profile ci --all-features --workspace --lcov --output-path lcov.info - cargo llvm-cov --verbose --profile ci --all-features --lcov --output-path lcov-test.info run --bin cairo-native-test -- corelib + cargo llvm-cov --verbose --profile ci --features=scarb,with-cheatcode,with-debug-utils --workspace --lcov --output-path lcov.info + cargo llvm-cov --verbose --profile ci --features=scarb,with-cheatcode,with-debug-utils --lcov --output-path lcov-test.info run --bin cairo-native-test -- corelib .PHONY: doc doc: check-llvm @@ -100,7 +100,7 @@ bench: build needs-cairo2 runtime .PHONY: bench-ci bench-ci: check-llvm needs-cairo2 runtime - cargo criterion --all-features + cargo criterion --features=scarb,with-cheatcode,with-debug-utils .PHONY: stress-test stress-test: check-llvm @@ -116,7 +116,7 @@ stress-clean: .PHONY: install install: check-llvm - RUSTFLAGS="-C target-cpu=native" cargo install --all-features --locked --path . + RUSTFLAGS="-C target-cpu=native" cargo install --features=scarb,with-cheatcode --locked --path . .PHONY: clean clean: stress-clean diff --git a/examples/erc20.rs b/examples/erc20.rs index 594906a2f..76917c32e 100644 --- a/examples/erc20.rs +++ b/examples/erc20.rs @@ -268,10 +268,10 @@ impl StarknetSyscallHandler for SyscallHandler { fn sha256_process_block( &mut self, - _prev_state: &[u32; 8], - _current_block: &[u32; 16], + _state: &mut [u32; 8], + _block: &[u32; 16], _remaining_gas: &mut u128, - ) -> SyscallResult<[u32; 8]> { + ) -> SyscallResult<()> { unimplemented!() } } diff --git a/examples/starknet.rs b/examples/starknet.rs index 6ccbb0771..28f987041 100644 --- a/examples/starknet.rs +++ b/examples/starknet.rs @@ -399,10 +399,10 @@ impl StarknetSyscallHandler for SyscallHandler { fn sha256_process_block( &mut self, - _prev_state: &[u32; 8], - _current_block: &[u32; 16], + _state: &mut [u32; 8], + _block: &[u32; 16], _remaining_gas: &mut u128, - ) -> SyscallResult<[u32; 8]> { + ) -> SyscallResult<()> { unimplemented!() } } diff --git a/runtime/src/lib.rs b/runtime/src/lib.rs index 3e88b651a..2110b0c58 100644 --- a/runtime/src/lib.rs +++ b/runtime/src/lib.rs @@ -13,7 +13,7 @@ use starknet_types_core::{ felt::Felt, hash::StarkHash, }; -use std::{collections::HashMap, fs::File, io::Write, os::fd::FromRawFd, ptr::NonNull, slice}; +use std::{collections::HashMap, ffi::c_void, fs::File, io::Write, os::fd::FromRawFd, slice}; use std::{ops::Mul, vec::IntoIter}; lazy_static! { @@ -134,10 +134,12 @@ pub unsafe extern "C" fn cairo_native__libfunc__hades_permutation( } /// Felt252 type used in cairo native runtime -#[derive(Debug, Default)] +#[derive(Debug)] pub struct FeltDict { - pub inner: HashMap<[u8; 32], NonNull>, + pub inner: HashMap<[u8; 32], *mut c_void>, pub count: u64, + + pub free_fn: unsafe extern "C" fn(*mut c_void), } /// Allocate a new dictionary. @@ -147,8 +149,14 @@ pub struct FeltDict { /// This function is intended to be called from MLIR, deals with pointers, and is therefore /// definitely unsafe to use manually. #[no_mangle] -pub unsafe extern "C" fn cairo_native__dict_new() -> *mut FeltDict { - Box::into_raw(Box::::default()) +pub unsafe extern "C" fn cairo_native__dict_new( + free_fn: extern "C" fn(*mut c_void), +) -> *mut FeltDict { + Box::into_raw(Box::new(FeltDict { + inner: HashMap::default(), + count: 0, + free_fn, + })) } /// Free a dictionary using an optional callback to drop each element. @@ -157,23 +165,25 @@ pub unsafe extern "C" fn cairo_native__dict_new() -> *mut FeltDict { /// /// This function is intended to be called from MLIR, deals with pointers, and is therefore /// definitely unsafe to use manually. -// Note: Using `Option` is ffi-safe thanks to Option's null +// Note: Using `Option` is ffi-safe thanks to Option's null // pointer optimization. Check out // https://doc.rust-lang.org/nomicon/ffi.html#the-nullable-pointer-optimization for more info. #[no_mangle] pub unsafe extern "C" fn cairo_native__dict_drop( ptr: *mut FeltDict, - drop_fn: Option, + drop_fn: Option, ) { - let map = Box::from_raw(ptr); + let dict = Box::from_raw(ptr); // Free the entries manually. - for entry in map.inner.into_values() { - if let Some(drop_fn) = drop_fn { - drop_fn(entry.as_ptr()); - } + for entry in dict.inner.into_values() { + if !entry.is_null() { + if let Some(drop_fn) = drop_fn { + drop_fn(entry); + } - libc::free(entry.as_ptr()); + (dict.free_fn)(entry); + } } } @@ -186,23 +196,30 @@ pub unsafe extern "C" fn cairo_native__dict_drop( #[no_mangle] pub unsafe extern "C" fn cairo_native__dict_dup( ptr: *mut FeltDict, - dup_fn: extern "C" fn(*mut std::ffi::c_void) -> *mut std::ffi::c_void, + dup_fn: extern "C" fn(*mut c_void) -> *mut c_void, ) -> *mut FeltDict { let old_dict = &*ptr; - let mut new_dict = Box::::default(); + let mut new_dict = Box::new(FeltDict { + inner: HashMap::default(), + count: 0, + free_fn: old_dict.free_fn, + }); new_dict.inner.extend( old_dict .inner .iter() - .map(|(&k, &v)| (k, NonNull::new(dup_fn(v.as_ptr())).unwrap())), + .filter_map(|(&k, &v)| (!v.is_null()).then_some((k, dup_fn(v)))), ); Box::into_raw(new_dict) } -/// Return the value (reference) for a given key, or null if not present. Increment the access -/// count. +/// Return a pointer to the entry's value pointer for a given key, inserting a null pointer if not +/// present. Increment the access count. +/// +/// The null pointer will be either updated by `felt252_dict_entry_finalize` or removed (along with +/// everything else in the dict) by the entry's drop implementation. /// /// # Safety /// @@ -210,38 +227,11 @@ pub unsafe extern "C" fn cairo_native__dict_dup( /// definitely unsafe to use manually. #[no_mangle] pub unsafe extern "C" fn cairo_native__dict_get( - ptr: *mut FeltDict, + dict: &mut FeltDict, key: &[u8; 32], -) -> *mut std::ffi::c_void { - let dict: &mut FeltDict = &mut *ptr; +) -> *mut c_void { dict.count += 1; - - match dict.inner.get(key) { - Some(v) => v.as_ptr(), - None => std::ptr::null_mut(), - } -} - -/// Inserts the provided key value. Returning the old one or nullptr if there was none. -/// -/// # Safety -/// -/// This function is intended to be called from MLIR, deals with pointers, and is therefore -/// definitely unsafe to use manually. -#[no_mangle] -pub unsafe extern "C" fn cairo_native__dict_insert( - ptr: *mut FeltDict, - key: &[u8; 32], - value: NonNull, -) -> *mut std::ffi::c_void { - let dict = &mut *ptr; - let old_ptr = dict.inner.insert(*key, value); - - if let Some(v) = old_ptr { - v.as_ptr() - } else { - std::ptr::null_mut() - } + dict.inner.entry(*key).or_insert(std::ptr::null_mut()) as *mut _ as *mut c_void } /// Compute the total gas refund for the dictionary at squash time. @@ -268,9 +258,9 @@ pub unsafe extern "C" fn cairo_native__dict_gas_refund(ptr: *const FeltDict) -> /// definitely unsafe to use manually. #[no_mangle] pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_point_from_x_nz( - mut point_ptr: NonNull<[[u8; 32]; 2]>, + point_ptr: &mut [[u8; 32]; 2], ) -> bool { - let x = Felt::from_bytes_le(&point_ptr.as_ref()[0]); + let x = Felt::from_bytes_le(&point_ptr[0]); // https://github.com/starkware-libs/cairo/blob/aaad921bba52e729dc24ece07fab2edf09ccfa15/crates/cairo-lang-sierra-to-casm/src/invocations/ec.rs#L63 @@ -305,15 +295,15 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_point_from_x_nz( /// definitely unsafe to use manually. #[no_mangle] pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_point_try_new_nz( - mut point_ptr: NonNull<[[u8; 32]; 2]>, + point_ptr: &mut [[u8; 32]; 2], ) -> bool { - let x = Felt::from_bytes_le(&point_ptr.as_ref()[0]); - let y = Felt::from_bytes_le(&point_ptr.as_ref()[1]); + let x = Felt::from_bytes_le(&point_ptr[0]); + let y = Felt::from_bytes_le(&point_ptr[1]); match AffinePoint::new(x, y) { Ok(point) => { - point_ptr.as_mut()[0].copy_from_slice(&point.x().to_bytes_le()); - point_ptr.as_mut()[1].copy_from_slice(&point.y().to_bytes_le()); + point_ptr[0].copy_from_slice(&point.x().to_bytes_le()); + point_ptr[1].copy_from_slice(&point.y().to_bytes_le()); true } Err(_) => false, @@ -327,9 +317,7 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_point_try_new_nz( /// This function is intended to be called from MLIR, deals with pointers, and is therefore /// definitely unsafe to use manually. #[no_mangle] -pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_init( - mut state_ptr: NonNull<[[u8; 32]; 4]>, -) { +pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_init(state_ptr: &mut [[u8; 32]; 4]) { // https://github.com/starkware-libs/cairo/blob/aaad921bba52e729dc24ece07fab2edf09ccfa15/crates/cairo-lang-runner/src/casm_run/mod.rs#L1802 let mut rng = rand::thread_rng(); let (random_x, random_y) = loop { @@ -345,10 +333,10 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_init( // We already made sure its a valid point. let state = AffinePoint::new_unchecked(random_x, random_y); - state_ptr.as_mut()[0].copy_from_slice(&state.x().to_bytes_le()); - state_ptr.as_mut()[1].copy_from_slice(&state.y().to_bytes_le()); - state_ptr.as_mut()[2].copy_from_slice(&state.x().to_bytes_le()); - state_ptr.as_mut()[3].copy_from_slice(&state.y().to_bytes_le()); + state_ptr[0].copy_from_slice(&state.x().to_bytes_le()); + state_ptr[1].copy_from_slice(&state.y().to_bytes_le()); + state_ptr[2].copy_from_slice(&state.x().to_bytes_le()); + state_ptr[3].copy_from_slice(&state.y().to_bytes_le()); } /// Compute `ec_state_add(state, point)` and store the state back. @@ -363,24 +351,24 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_init( /// definitely unsafe to use manually. #[no_mangle] pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_add( - mut state_ptr: NonNull<[[u8; 32]; 4]>, - point_ptr: NonNull<[[u8; 32]; 2]>, + state_ptr: &mut [[u8; 32]; 4], + point_ptr: &[[u8; 32]; 2], ) { // We use unchecked methods because the inputs must already be valid points. let mut state = ProjectivePoint::from_affine_unchecked( - Felt::from_bytes_le(&state_ptr.as_ref()[0]), - Felt::from_bytes_le(&state_ptr.as_ref()[1]), + Felt::from_bytes_le(&state_ptr[0]), + Felt::from_bytes_le(&state_ptr[1]), ); let point = AffinePoint::new_unchecked( - Felt::from_bytes_le(&point_ptr.as_ref()[0]), - Felt::from_bytes_le(&point_ptr.as_ref()[1]), + Felt::from_bytes_le(&point_ptr[0]), + Felt::from_bytes_le(&point_ptr[1]), ); state += &point; let state = state.to_affine().unwrap(); - state_ptr.as_mut()[0].copy_from_slice(&state.x().to_bytes_le()); - state_ptr.as_mut()[1].copy_from_slice(&state.y().to_bytes_le()); + state_ptr[0].copy_from_slice(&state.x().to_bytes_le()); + state_ptr[1].copy_from_slice(&state.y().to_bytes_le()); } /// Compute `ec_state_add_mul(state, scalar, point)` and store the state back. @@ -395,26 +383,26 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_add( /// definitely unsafe to use manually. #[no_mangle] pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_add_mul( - mut state_ptr: NonNull<[[u8; 32]; 4]>, - scalar_ptr: NonNull<[u8; 32]>, - point_ptr: NonNull<[[u8; 32]; 2]>, + state_ptr: &mut [[u8; 32]; 4], + scalar_ptr: &[u8; 32], + point_ptr: &[[u8; 32]; 2], ) { // Here the points should already be checked as valid, so we can use unchecked. let mut state = ProjectivePoint::from_affine_unchecked( - Felt::from_bytes_le(&state_ptr.as_ref()[0]), - Felt::from_bytes_le(&state_ptr.as_ref()[1]), + Felt::from_bytes_le(&state_ptr[0]), + Felt::from_bytes_le(&state_ptr[1]), ); let point = ProjectivePoint::from_affine_unchecked( - Felt::from_bytes_le(&point_ptr.as_ref()[0]), - Felt::from_bytes_le(&point_ptr.as_ref()[1]), + Felt::from_bytes_le(&point_ptr[0]), + Felt::from_bytes_le(&point_ptr[1]), ); - let scalar = Felt::from_bytes_le(scalar_ptr.as_ref()); + let scalar = Felt::from_bytes_le(scalar_ptr); state += &point.mul(scalar); let state = state.to_affine().unwrap(); - state_ptr.as_mut()[0].copy_from_slice(&state.x().to_bytes_le()); - state_ptr.as_mut()[1].copy_from_slice(&state.y().to_bytes_le()); + state_ptr[0].copy_from_slice(&state.x().to_bytes_le()); + state_ptr[1].copy_from_slice(&state.y().to_bytes_le()); } /// Compute `ec_state_try_finalize_nz(state)` and store the result. @@ -429,17 +417,17 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_add_mul( /// definitely unsafe to use manually. #[no_mangle] pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_try_finalize_nz( - mut point_ptr: NonNull<[[u8; 32]; 2]>, - state_ptr: NonNull<[[u8; 32]; 4]>, + point_ptr: &mut [[u8; 32]; 2], + state_ptr: &[[u8; 32]; 4], ) -> bool { // We use unchecked methods because the inputs must already be valid points. let state = ProjectivePoint::from_affine_unchecked( - Felt::from_bytes_le(&state_ptr.as_ref()[0]), - Felt::from_bytes_le(&state_ptr.as_ref()[1]), + Felt::from_bytes_le(&state_ptr[0]), + Felt::from_bytes_le(&state_ptr[1]), ); let random = ProjectivePoint::from_affine_unchecked( - Felt::from_bytes_le(&state_ptr.as_ref()[2]), - Felt::from_bytes_le(&state_ptr.as_ref()[3]), + Felt::from_bytes_le(&state_ptr[2]), + Felt::from_bytes_le(&state_ptr[3]), ); if state.x() == random.x() && state.y() == random.y() { @@ -448,8 +436,8 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_try_finalize_nz( let point = &state - &random; let point = point.to_affine().unwrap(); - point_ptr.as_mut()[0].copy_from_slice(&point.x().to_bytes_le()); - point_ptr.as_mut()[1].copy_from_slice(&point.y().to_bytes_le()); + point_ptr[0].copy_from_slice(&point.x().to_bytes_le()); + point_ptr[1].copy_from_slice(&point.y().to_bytes_le()); true } diff --git a/src/arch.rs b/src/arch.rs index bbe58b6b0..12d76b865 100644 --- a/src/arch.rs +++ b/src/arch.rs @@ -2,6 +2,7 @@ use crate::{ error, starknet::{ArrayAbi, Secp256k1Point, Secp256r1Point}, types::TypeBuilder, + utils::libc_malloc, values::Value, }; use bumpalo::Bump; @@ -63,7 +64,7 @@ impl<'a> AbiArgument for JitValueWithInfoWrapper<'a> { let layout = self.registry.get_type(&info.ty)?.layout(self.registry)?; let heap_ptr = unsafe { - let heap_ptr = libc::malloc(layout.size()); + let heap_ptr = libc_malloc(layout.size()); libc::memcpy(heap_ptr, ptr.as_ptr().cast(), layout.size()); heap_ptr }; @@ -78,7 +79,7 @@ impl<'a> AbiArgument for JitValueWithInfoWrapper<'a> { let layout = self.registry.get_type(&info.ty)?.layout(self.registry)?; let heap_ptr = unsafe { - let heap_ptr = libc::malloc(layout.size()); + let heap_ptr = libc_malloc(layout.size()); libc::memcpy(heap_ptr, ptr.as_ptr().cast(), layout.size()); heap_ptr }; diff --git a/src/executor.rs b/src/executor.rs index 9373d5c2a..a2feb256a 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -10,7 +10,7 @@ use crate::{ execution_result::{BuiltinStats, ExecutionResult}, starknet::{handler::StarknetSyscallHandlerCallbacks, StarknetSyscallHandler}, types::TypeBuilder, - utils::RangeExt, + utils::{libc_free, RangeExt}, values::Value, }; use bumpalo::Bump; @@ -340,6 +340,9 @@ fn invoke_dynamic( debug_name: None, }); + #[cfg(feature = "with-mem-tracing")] + crate::utils::mem_tracing::report_stats(); + Ok(ExecutionResult { remaining_gas, return_value, @@ -378,7 +381,7 @@ fn parse_result( CoreTypeConcrete::Box(info) => unsafe { let ptr = return_ptr.unwrap_or(NonNull::new_unchecked(ret_registers[0] as *mut ())); let value = Value::from_ptr(ptr, &info.ty, registry)?; - libc::free(ptr.cast().as_ptr()); + libc_free(ptr.cast().as_ptr()); Ok(value) }, CoreTypeConcrete::EcPoint(_) | CoreTypeConcrete::EcState(_) => { @@ -496,7 +499,7 @@ fn parse_result( } else { let ptr = NonNull::new_unchecked(ptr); let value = Value::from_ptr(ptr, &info.ty, registry)?; - libc::free(ptr.as_ptr().cast()); + libc_free(ptr.as_ptr().cast()); Ok(value) } }, diff --git a/src/executor/contract.rs b/src/executor/contract.rs index 2b471e424..6b19a4008 100644 --- a/src/executor/contract.rs +++ b/src/executor/contract.rs @@ -40,7 +40,9 @@ use crate::{ module::NativeModule, starknet::{handler::StarknetSyscallHandlerCallbacks, StarknetSyscallHandler}, types::TypeBuilder, - utils::{decode_error_message, generate_function_name, get_integer_layout}, + utils::{ + decode_error_message, generate_function_name, get_integer_layout, libc_free, libc_malloc, + }, OptLevel, }; use bumpalo::Bump; @@ -272,7 +274,7 @@ impl AotContractExecutor { } let felt_layout = get_integer_layout(252).pad_to_align(); - let ptr: *mut () = unsafe { libc::malloc(felt_layout.size() * args.len()).cast() }; + let ptr: *mut () = unsafe { libc_malloc(felt_layout.size() * args.len()).cast() }; let len: u32 = args.len().try_into().unwrap(); ptr.to_bytes(&mut invoke_data)?; @@ -419,7 +421,7 @@ impl AotContractExecutor { } if !array_ptr.is_null() { - unsafe { libc::free(array_ptr.cast()) }; + unsafe { libc_free(array_ptr.cast()) }; } let mut error_msg = None; @@ -436,6 +438,9 @@ impl AotContractExecutor { error_msg = Some(str_error); } + #[cfg(feature = "with-mem-tracing")] + crate::utils::mem_tracing::report_stats(); + Ok(ContractExecutionResult { remaining_gas, failure_flag: tag != 0, diff --git a/src/libfuncs/array.rs b/src/libfuncs/array.rs index fe87547af..5d8e38edd 100644 --- a/src/libfuncs/array.rs +++ b/src/libfuncs/array.rs @@ -15,7 +15,7 @@ use crate::{ use cairo_lang_sierra::{ extensions::{ array::{ArrayConcreteLibfunc, ConcreteMultiPopLibfunc}, - core::{CoreLibfunc, CoreType}, + core::{CoreLibfunc, CoreType, CoreTypeConcrete}, lib_func::{SignatureAndTypeConcreteLibfunc, SignatureOnlyConcreteLibfunc}, ConcreteLibfunc, }, @@ -35,7 +35,6 @@ use melior::{ }, Context, }; -use std::ops::Deref; /// Select and call the correct libfunc builder function from the selector. pub fn build<'ctx, 'this>( @@ -380,6 +379,21 @@ pub fn build_len<'ctx, 'this>( let array_len = entry.append_op_result(arith::subi(array_end, array_start, location))?; + match metadata.get::() { + Some(drop_overrides_meta) + if drop_overrides_meta.is_overriden(&info.signature.param_signatures[0].ty) => + { + drop_overrides_meta.invoke_override( + context, + entry, + location, + &info.signature.param_signatures[0].ty, + entry.argument(0)?.into(), + )?; + } + _ => {} + } + entry.append_operation(helper.br(0, &[array_len], location)); Ok(()) } @@ -578,7 +592,18 @@ pub fn build_get<'ctx, 'this>( valid_block.append_operation(helper.br(0, &[range_check, target_ptr], location)); } + metadata + .get::() + .unwrap() + .invoke_override( + context, + error_block, + location, + &info.param_signatures()[1].ty, + value, + )?; error_block.append_operation(helper.br(1, &[range_check], location)); + Ok(()) } @@ -692,8 +717,105 @@ pub fn build_pop_front_consume<'ctx, 'this>( metadata: &mut MetadataStorage, info: &SignatureAndTypeConcreteLibfunc, ) -> Result<()> { - // Equivalent to `array_pop_front_consume` for our purposes. - build_pop_front(context, registry, entry, location, helper, metadata, info) + if metadata.get::().is_none() { + metadata.insert(ReallocBindingsMeta::new(context, helper)); + } + + let array_ty = registry.build_type( + context, + helper, + registry, + metadata, + &info.param_signatures()[0].ty, + )?; + + let elem_ty = registry.get_type(&info.ty)?; + let elem_layout = elem_ty.layout(registry)?; + + let ptr_ty = crate::ffi::get_struct_field_type_at(&array_ty, 0); + let len_ty = crate::ffi::get_struct_field_type_at(&array_ty, 1); + + let value = entry.argument(0)?.into(); + + let array_start = entry.extract_value(context, location, value, len_ty, 1)?; + let array_end = entry.extract_value(context, location, value, len_ty, 2)?; + + let is_empty = entry.append_op_result(arith::cmpi( + context, + CmpiPredicate::Eq, + array_start, + array_end, + location, + ))?; + + let valid_block = helper.append_block(Block::new(&[])); + let empty_block = helper.append_block(Block::new(&[])); + entry.append_operation(cf::cond_br( + context, + is_empty, + empty_block, + valid_block, + &[], + &[], + location, + )); + + { + let ptr = valid_block.extract_value(context, location, value, ptr_ty, 0)?; + + let elem_size = valid_block.const_int(context, location, elem_layout.size(), 64)?; + let elem_offset = valid_block.append_op_result(arith::extui( + array_start, + IntegerType::new(context, 64).into(), + location, + ))?; + let elem_offset = + valid_block.append_op_result(arith::muli(elem_offset, elem_size, location))?; + let ptr = valid_block.append_op_result(llvm::get_element_ptr_dynamic( + context, + ptr, + &[elem_offset], + IntegerType::new(context, 8).into(), + llvm::r#type::pointer(context, 0), + location, + ))?; + + let target_ptr = valid_block.append_op_result( + ods::llvm::mlir_zero(context, pointer(context, 0), location).into(), + )?; + let target_ptr = valid_block.append_op_result(ReallocBindingsMeta::realloc( + context, target_ptr, elem_size, location, + ))?; + assert_nonnull( + context, + valid_block, + location, + target_ptr, + "realloc returned nullptr", + )?; + + valid_block.memcpy(context, location, ptr, target_ptr, elem_size); + + let k1 = valid_block.const_int(context, location, 1, 32)?; + let new_start = valid_block.append_op_result(arith::addi(array_start, k1, location))?; + let value = valid_block.insert_value(context, location, value, new_start, 1)?; + + valid_block.append_operation(helper.br(0, &[value, target_ptr], location)); + } + + metadata + .get::() + .unwrap() + .invoke_override( + context, + empty_block, + location, + &info.param_signatures()[0].ty, + value, + )?; + empty_block.append_operation(helper.br(1, &[], location)); + + Ok(()) } /// Generate MLIR operations for the `array_snapshot_pop_front` libfunc. @@ -1071,60 +1193,46 @@ pub fn build_slice<'ctx, 'this>( metadata: &mut MetadataStorage, info: &SignatureAndTypeConcreteLibfunc, ) -> Result<()> { - if metadata.get::().is_none() { - metadata.insert(ReallocBindingsMeta::new(context, helper)); - } - let range_check = super::increment_builtin_counter(context, entry, location, entry.argument(0)?.into())?; - let array_ty = registry.build_type( - context, - helper, - registry, - metadata, - &info.param_signatures()[1].ty, - )?; - - let len_ty = crate::ffi::get_struct_field_type_at(&array_ty, 1); + let len_ty = IntegerType::new(context, 32).into(); let elem_ty = registry.get_type(&info.ty)?; let elem_layout = elem_ty.layout(registry)?; - let slice_since = entry.argument(2)?.into(); - let slice_length = entry.argument(3)?.into(); - - let slice_until = entry.append_op_result(arith::addi(slice_since, slice_length, location))?; - let array_start = entry.extract_value(context, location, entry.argument(1)?.into(), len_ty, 1)?; let array_end = entry.extract_value(context, location, entry.argument(1)?.into(), len_ty, 2)?; - let slice_since = entry.append_op_result(arith::addi(slice_since, array_start, location))?; - let slice_until = entry.append_op_result(arith::addi(slice_until, array_start, location))?; + let slice_start = entry.argument(2)?.into(); + let slice_len = entry.argument(3)?.into(); + let slice_end = entry.append_op_result(arith::addi(slice_start, slice_len, location))?; + let slice_start = entry.append_op_result(arith::addi(array_start, slice_start, location))?; + let slice_end = entry.append_op_result(arith::addi(array_start, slice_end, location))?; let lhs_bound = entry.append_op_result(arith::cmpi( context, CmpiPredicate::Uge, - slice_since, + slice_start, array_start, location, ))?; let rhs_bound = entry.append_op_result(arith::cmpi( context, CmpiPredicate::Ule, - slice_until, + slice_end, array_end, location, ))?; - let is_fully_contained = entry.append_op_result(arith::andi(lhs_bound, rhs_bound, location))?; + let is_valid = entry.append_op_result(arith::andi(lhs_bound, rhs_bound, location))?; let slice_block = helper.append_block(Block::new(&[])); let error_block = helper.append_block(Block::new(&[])); entry.append_operation(cf::cond_br( context, - is_fully_contained, + is_valid, slice_block, error_block, &[], @@ -1133,65 +1241,106 @@ pub fn build_slice<'ctx, 'this>( )); { - let elem_size = - slice_block.const_int(context, location, elem_layout.pad_to_align().size(), 64)?; - let dst_size = slice_block.append_op_result(arith::extui( - slice_length, - IntegerType::new(context, 64).into(), - location, - ))?; - let dst_size = slice_block.append_op_result(arith::muli(dst_size, elem_size, location))?; + let elem_ty = elem_ty.build(context, helper, registry, metadata, &info.ty)?; - let dst_ptr = slice_block.append_op_result( - ods::llvm::mlir_zero(context, pointer(context, 0), location).into(), - )?; - let dst_ptr = slice_block.append_op_result(ReallocBindingsMeta::realloc( - context, dst_ptr, dst_size, location, - ))?; + let value = entry.argument(1)?.into(); + let value = slice_block.insert_value(context, location, value, slice_start, 1)?; + let value = slice_block.insert_value(context, location, value, slice_end, 2)?; - // TODO: Find out if we need to clone stuff using the snapshot clone meta. - let src_offset = { - let slice_since = slice_block.append_op_result(arith::extui( - slice_since, + let elem_stride = + slice_block.const_int(context, location, elem_layout.pad_to_align().size(), 64)?; + let prepare = |value| { + let value = slice_block.append_op_result(arith::extui( + value, IntegerType::new(context, 64).into(), location, ))?; - - slice_block.append_op_result(arith::muli(slice_since, elem_size, location))? + slice_block.append_op_result(arith::muli(value, elem_stride, location)) }; - let src_ptr = slice_block.extract_value( + let ptr = slice_block.extract_value( context, location, entry.argument(1)?.into(), - pointer(context, 0), + llvm::r#type::pointer(context, 0), 0, )?; - let src_ptr = slice_block.append_op_result(llvm::get_element_ptr_dynamic( - context, - src_ptr, - &[src_offset], - IntegerType::new(context, 8).into(), - llvm::r#type::pointer(context, 0), - location, - ))?; + let make_region = |drop_overrides_meta: &DropOverridesMeta| { + let region = Region::new(); + let block = region.append_block(Block::new(&[( + IntegerType::new(context, 64).into(), + location, + )])); + + let value_ptr = block.append_op_result(llvm::get_element_ptr_dynamic( + context, + ptr, + &[block.argument(0)?.into()], + IntegerType::new(context, 8).into(), + llvm::r#type::pointer(context, 0), + location, + ))?; - slice_block.memcpy(context, location, src_ptr, dst_ptr, dst_size); + let value = block.load(context, location, value_ptr, elem_ty)?; + drop_overrides_meta.invoke_override(context, &block, location, &info.ty, value)?; - let k0 = slice_block.const_int_from_type(context, location, 0, len_ty)?; + block.append_operation(scf::r#yield(&[], location)); + Result::Ok(region) + }; + + let array_start = prepare(array_start)?; + let array_end = prepare(array_end)?; + let slice_start = prepare(slice_start)?; + let slice_end = prepare(slice_end)?; - let value = slice_block.append_op_result(llvm::undef(array_ty, location))?; - let value = slice_block.insert_values( + match metadata.get::() { + Some(drop_overrides_meta) if drop_overrides_meta.is_overriden(&info.ty) => { + slice_block.append_operation(scf::r#for( + array_start, + slice_start, + elem_stride, + make_region(drop_overrides_meta)?, + location, + )); + slice_block.append_operation(scf::r#for( + slice_end, + array_end, + elem_stride, + make_region(drop_overrides_meta)?, + location, + )); + } + _ => {} + }; + + slice_block.append_operation(helper.br(0, &[range_check, value], location)); + } + + { + registry.build_type( context, - location, - value, - &[dst_ptr, k0, slice_length, slice_length], + helper, + registry, + metadata, + &info.signature.param_signatures[1].ty, )?; - slice_block.append_operation(helper.br(0, &[range_check, value], location)); + // The following unwrap is unreachable because an array always has a drop implementation, + // which at this point is always inserted thanks to the `build_type()` just above. + metadata + .get::() + .unwrap() + .invoke_override( + context, + error_block, + location, + &info.signature.param_signatures[1].ty, + entry.argument(1)?.into(), + )?; + + error_block.append_operation(helper.br(1, &[range_check], location)); } - error_block.append_operation(helper.br(1, &[range_check], location)); Ok(()) } @@ -1206,7 +1355,6 @@ pub fn build_span_from_tuple<'ctx, 'this>( info: &SignatureAndTypeConcreteLibfunc, ) -> Result<()> { // tuple to array span (t,t,t) -> &[t,t,t] - if metadata.get::().is_none() { metadata.insert(ReallocBindingsMeta::new(context, helper)); } @@ -1322,48 +1470,59 @@ pub fn build_tuple_from_span<'ctx, 'this>( metadata: &mut MetadataStorage, info: &SignatureAndTypeConcreteLibfunc, ) -> Result<()> { - // Libfunc Signature: - // - // (Snapshot>) -> Box> + // Tasks: + // - Check if sizes match. + // - If they do not match, jump to branch [1]. + // - If they do match: + // - If start == 0 && capacity == len -> reuse the pointer + // - Otherwise, realloc + memcpy + free. if metadata.get::().is_none() { metadata.insert(ReallocBindingsMeta::new(context, helper)); } - // Get type information - - let elem_core_ty = registry.get_type(&info.signature.param_signatures[0].ty)?; - let elem_layout = elem_core_ty.layout(registry)?; - let elem_stride = elem_layout.pad_to_align().size(); - let tuple_len = registry - .get_type(&info.ty)? - .fields() - .expect("should be a struct (ergo, has fields)") - .len(); + let array_ty = registry.get_type(&info.signature.param_signatures[0].ty)?; + let (elem_id, elem_ty) = match array_ty { + CoreTypeConcrete::Array(info) => (&info.ty, registry.get_type(&info.ty)?), + CoreTypeConcrete::Snapshot(info) => match registry.get_type(&info.ty)? { + CoreTypeConcrete::Array(info) => (&info.ty, registry.get_type(&info.ty)?), + _ => unreachable!(), + }, + _ => unreachable!(), + }; + let elem_layout = elem_ty.layout(registry)?; - let u32_ty = IntegerType::new(context, 32).into(); + let array_start = entry.extract_value( + context, + location, + entry.argument(0)?.into(), + IntegerType::new(context, 32).into(), + 1, + )?; + let array_end = entry.extract_value( + context, + location, + entry.argument(0)?.into(), + IntegerType::new(context, 32).into(), + 2, + )?; - // Get array information + let array_len = entry.append_op_result(arith::subi(array_end, array_start, location))?; + let (tuple_len, tuple_len_val) = { + let fields = registry.get_type(&info.ty)?.fields().unwrap(); + assert!(fields.iter().all(|f| f.id == elem_id.id)); - let array_value = entry.argument(0)?.into(); - let array_ptr = entry.extract_value(context, location, array_value, pointer(context, 0), 0)?; - let array_start = entry.extract_value(context, location, array_value, u32_ty, 1)?; - let array_end = entry.extract_value(context, location, array_value, u32_ty, 2)?; - let array_capacity = entry.extract_value(context, location, array_value, u32_ty, 3)?; - - // Check if conversion is valid - // - // if array.end - array.start != tuple_len { - // return err; - // } + ( + entry.const_int(context, location, fields.len(), 32)?, + fields.len(), + ) + }; - let array_len = entry.append_op_result(arith::subi(array_end, array_start, location))?; - let tuple_len_value = entry.const_int(context, location, tuple_len, 32)?; - let array_len_matches = entry.append_op_result(arith::cmpi( + let len_matches = entry.append_op_result(arith::cmpi( context, CmpiPredicate::Eq, array_len, - tuple_len_value, + tuple_len, location, ))?; @@ -1371,7 +1530,7 @@ pub fn build_tuple_from_span<'ctx, 'this>( let block_err = helper.append_block(Block::new(&[])); entry.append_operation(cf::cond_br( context, - array_len_matches, + len_matches, block_ok, block_err, &[], @@ -1379,108 +1538,132 @@ pub fn build_tuple_from_span<'ctx, 'this>( location, )); - // Check if pointer can be passed through, that is - // if array.start == 0 && array.capacity == tuple_len - - let is_pointer_passthrough = { + { let k0 = block_ok.const_int(context, location, 0, 32)?; - let array_since_is_zero = block_ok.append_op_result(arith::cmpi( + let starts_at_zero = block_ok.append_op_result(arith::cmpi( context, CmpiPredicate::Eq, array_start, k0, location, ))?; - let array_cap_matches = block_ok.append_op_result(arith::cmpi( + + let array_cap = block_ok.extract_value( + context, + location, + entry.argument(0)?.into(), + IntegerType::new(context, 32).into(), + 3, + )?; + let capacity_matches = block_ok.append_op_result(arith::cmpi( context, CmpiPredicate::Eq, - array_capacity, - tuple_len_value, + array_cap, + tuple_len, location, ))?; - block_ok.append_op_result(arith::andi( - array_since_is_zero, - array_cap_matches, + let array_ptr = block_ok.extract_value( + context, location, - ))? - }; - - let box_ptr = block_ok.append_op_result(scf::r#if( - is_pointer_passthrough, - &[llvm::r#type::pointer(context, 0)], - { - let region = Region::new(); - let block = region.append_block(Block::new(&[])); - - // If can be passed through, just return the array ptr + entry.argument(0)?.into(), + llvm::r#type::pointer(context, 0), + 0, + )?; + let should_forward_pointer = + block_ok.append_op_result(arith::andi(starts_at_zero, capacity_matches, location))?; - block.append_operation(scf::r#yield(&[array_ptr], location)); + let block_clone = helper.append_block(Block::new(&[])); + let block_forward = helper.append_block(Block::new(&[])); + block_ok.append_operation(cf::cond_br( + context, + should_forward_pointer, + block_forward, + block_clone, + &[], + &[], + location, + )); - region - }, { - let region = Region::new(); - let block = region.append_block(Block::new(&[])); - - // Otherwise, alloc memory for the returned tuple and clone it - - let tuple_len_value = block.const_int(context, location, tuple_len, 64)?; - let elem_stride_value = block.const_int(context, location, elem_stride, 64)?; - let tuple_len_bytes = block.append_op_result(arith::muli( - tuple_len_value, - elem_stride_value, + let elem_stride = + block_clone.const_int(context, location, elem_layout.pad_to_align().size(), 64)?; + let tuple_len = block_clone.append_op_result(arith::extui( + tuple_len, + IntegerType::new(context, 64).into(), location, ))?; + let tuple_len = + block_clone.append_op_result(arith::muli(tuple_len, elem_stride, location))?; - let tuple_ptr = { - let null_ptr = block - .append_op_result(llvm::zero(llvm::r#type::pointer(context, 0), location))?; - let tuple_ptr = block.append_op_result(ReallocBindingsMeta::realloc( - context, - null_ptr, - tuple_len_bytes, - location, - ))?; - - assert_nonnull( - context, - block.deref(), - location, - tuple_ptr, - "realloc returned null", - )?; - - tuple_ptr - }; + let box_ptr = block_clone + .append_op_result(llvm::zero(llvm::r#type::pointer(context, 0), location))?; + let box_ptr = block_clone.append_op_result(ReallocBindingsMeta::realloc( + context, box_ptr, tuple_len, location, + ))?; - let array_start = block.append_op_result(arith::extui( + let elem_offset = block_clone.append_op_result(arith::extui( array_start, IntegerType::new(context, 64).into(), location, ))?; - let array_start_offset = - block.append_op_result(arith::muli(elem_stride_value, array_start, location))?; - - let src_ptr = block.append_op_result(llvm::get_element_ptr_dynamic( + let elem_offset = + block_clone.append_op_result(arith::muli(elem_offset, elem_stride, location))?; + let elem_ptr = block_clone.append_op_result(llvm::get_element_ptr_dynamic( context, array_ptr, - &[array_start_offset], + &[elem_offset], IntegerType::new(context, 8).into(), llvm::r#type::pointer(context, 0), location, ))?; - block.memcpy(context, location, src_ptr, tuple_ptr, tuple_len_bytes); - block.append_operation(scf::r#yield(&[tuple_ptr], location)); + block_clone.append_operation( + ods::llvm::intr_memcpy_inline( + context, + box_ptr, + elem_ptr, + IntegerAttribute::new( + IntegerType::new(context, 64).into(), + (tuple_len_val * elem_layout.pad_to_align().size()) as i64, + ), + IntegerAttribute::new(IntegerType::new(context, 1).into(), 0), + location, + ) + .into(), + ); - region - }, - location, - ))?; + block_clone.append_operation(ReallocBindingsMeta::free(context, array_ptr, location)); + block_clone.append_operation(helper.br(0, &[box_ptr], location)); + } + + block_forward.append_operation(helper.br(0, &[array_ptr], location)); + } - block_ok.append_operation(helper.br(0, &[box_ptr], location)); - block_err.append_operation(helper.br(1, &[], location)); + { + registry.build_type( + context, + helper, + registry, + metadata, + &info.signature.param_signatures[0].ty, + )?; + + // The following unwrap is unreachable because an array always has a drop implementation, + // which at this point is always inserted thanks to the `build_type()` just above. + metadata + .get::() + .unwrap() + .invoke_override( + context, + block_err, + location, + &info.signature.param_signatures[0].ty, + entry.argument(0)?.into(), + )?; + + block_err.append_operation(helper.br(1, &[], location)); + } Ok(()) } @@ -1751,18 +1934,18 @@ mod test { use box::BoxTrait; fn run_test() -> u32 { - let mut data: Array = ArrayTrait::new(); + let mut data: Array = ArrayTrait::new(); // Alloca (freed). data.append(1_u32); data.append(2_u32); data.append(3_u32); data.append(4_u32); - let sp = data.span(); + let sp = data.span(); // Alloca (leaked). let slice = sp.slice(1, 2); data.append(5_u32); data.append(5_u32); data.append(5_u32); data.append(5_u32); - data.append(5_u32); + data.append(5_u32); // Realloc (freed). data.append(5_u32); *slice.get(1).unwrap().unbox() } diff --git a/src/libfuncs/box.rs b/src/libfuncs/box.rs index f7840f772..9150162b5 100644 --- a/src/libfuncs/box.rs +++ b/src/libfuncs/box.rs @@ -12,10 +12,7 @@ use cairo_lang_sierra::{ extensions::{ boxing::BoxConcreteLibfunc, core::{CoreLibfunc, CoreType}, - lib_func::{ - BranchSignature, LibfuncSignature, SignatureAndTypeConcreteLibfunc, - SignatureOnlyConcreteLibfunc, - }, + lib_func::SignatureAndTypeConcreteLibfunc, }, program_registry::ProgramRegistry, }; @@ -152,37 +149,16 @@ pub fn build_unbox<'ctx, 'this>( } fn build_forward_snapshot<'ctx, 'this>( - context: &'ctx Context, - registry: &ProgramRegistry, + _context: &'ctx Context, + _registry: &ProgramRegistry, entry: &'this Block<'ctx>, location: Location<'ctx>, helper: &LibfuncHelper<'ctx, 'this>, - metadata: &mut MetadataStorage, - info: &SignatureAndTypeConcreteLibfunc, + _metadata: &mut MetadataStorage, + _info: &SignatureAndTypeConcreteLibfunc, ) -> Result<()> { - super::snapshot_take::build( - context, - registry, - entry, - location, - helper, - metadata, - &SignatureOnlyConcreteLibfunc { - signature: LibfuncSignature { - param_signatures: info.signature.param_signatures.clone(), - branch_signatures: info - .signature - .branch_signatures - .iter() - .map(|x| BranchSignature { - vars: x.vars.clone(), - ap_change: x.ap_change.clone(), - }) - .collect(), - fallthrough: info.signature.fallthrough, - }, - }, - ) + entry.append_operation(helper.br(0, &[entry.argument(0)?.into()], location)); + Ok(()) } #[cfg(test)] diff --git a/src/libfuncs/debug.rs b/src/libfuncs/debug.rs index ed279aba2..a22ebb9ae 100644 --- a/src/libfuncs/debug.rs +++ b/src/libfuncs/debug.rs @@ -12,8 +12,10 @@ use super::LibfuncHelper; use crate::{ error::Result, - metadata::{runtime_bindings::RuntimeBindingsMeta, MetadataStorage}, - utils::BlockExt, + metadata::{ + drop_overrides::DropOverridesMeta, runtime_bindings::RuntimeBindingsMeta, MetadataStorage, + }, + utils::{BlockExt, ProgramRegistryExt}, }; use cairo_lang_sierra::{ extensions::{ @@ -47,12 +49,12 @@ pub fn build<'ctx>( pub fn build_print<'ctx>( context: &'ctx Context, - _registry: &ProgramRegistry, + registry: &ProgramRegistry, entry: &Block<'ctx>, location: Location<'ctx>, helper: &LibfuncHelper<'ctx, '_>, metadata: &mut MetadataStorage, - _info: &SignatureOnlyConcreteLibfunc, + info: &SignatureOnlyConcreteLibfunc, ) -> Result<()> { let stdout_fd = entry.const_int(context, location, 1, 32)?; @@ -105,6 +107,19 @@ pub fn build_print<'ctx>( context, helper, entry, stdout_fd, values_ptr, values_len, location, )?; + let input_ty = &info.signature.param_signatures[0].ty; + registry.build_type(context, helper, registry, metadata, input_ty)?; + metadata + .get::() + .unwrap() + .invoke_override( + context, + entry, + location, + input_ty, + entry.argument(0)?.into(), + )?; + let k0 = entry.const_int(context, location, 0, 32)?; let return_code_is_ok = entry.append_op_result(arith::cmpi( context, diff --git a/src/libfuncs/felt252.rs b/src/libfuncs/felt252.rs index 348ec476e..351650ce4 100644 --- a/src/libfuncs/felt252.rs +++ b/src/libfuncs/felt252.rs @@ -377,15 +377,8 @@ pub mod test { lazy_static! { static ref FELT252_ADD: (String, Program) = load_cairo! { - use core::debug::PrintTrait; fn run_test(lhs: felt252, rhs: felt252) -> felt252 { - lhs.print(); - rhs.print(); - let result = lhs + rhs; - - result.print(); - - result + lhs + rhs } }; diff --git a/src/libfuncs/felt252_dict.rs b/src/libfuncs/felt252_dict.rs index ce4977079..4fdc609b3 100644 --- a/src/libfuncs/felt252_dict.rs +++ b/src/libfuncs/felt252_dict.rs @@ -56,8 +56,7 @@ pub fn build_new<'ctx, 'this>( .get_mut::() .expect("Runtime library not available."); - let op = runtime_bindings.dict_new(context, helper, entry, location)?; - let dict_ptr = op.result(0)?.into(); + let dict_ptr = runtime_bindings.dict_new(context, helper, entry, location)?; entry.append_operation(helper.br(0, &[segment_arena, dict_ptr], location)); Ok(()) diff --git a/src/libfuncs/felt252_dict_entry.rs b/src/libfuncs/felt252_dict_entry.rs index 1fb38cf0f..bc09d6c05 100644 --- a/src/libfuncs/felt252_dict_entry.rs +++ b/src/libfuncs/felt252_dict_entry.rs @@ -4,11 +4,12 @@ use super::LibfuncHelper; use crate::{ error::Result, metadata::{ + drop_overrides::DropOverridesMeta, dup_overrides::DupOverridesMeta, realloc_bindings::ReallocBindingsMeta, runtime_bindings::RuntimeBindingsMeta, MetadataStorage, }, types::TypeBuilder, - utils::{get_integer_layout, BlockExt, ProgramRegistryExt}, + utils::{BlockExt, ProgramRegistryExt}, }; use cairo_lang_sierra::{ extensions::{ @@ -20,11 +21,8 @@ use cairo_lang_sierra::{ program_registry::ProgramRegistry, }; use melior::{ - dialect::{cf, llvm}, - ir::{ - attribute::IntegerAttribute, operation::OperationBuilder, r#type::IntegerType, Block, - Identifier, Location, Value, ValueLike, - }, + dialect::{cf, llvm, ods}, + ir::{attribute::IntegerAttribute, r#type::IntegerType, Block, Location}, Context, }; @@ -57,10 +55,6 @@ pub fn build_get<'ctx, 'this>( metadata: &mut MetadataStorage, info: &SignatureAndTypeConcreteLibfunc, ) -> Result<()> { - if metadata.get::().is_none() { - metadata.insert(ReallocBindingsMeta::new(context, helper)); - } - let (key_ty, key_layout) = registry.build_type_with_layout( context, helper, @@ -68,7 +62,6 @@ pub fn build_get<'ctx, 'this>( metadata, &info.param_signatures()[1].ty, )?; - let entry_ty = registry.build_type( context, helper, @@ -76,169 +69,209 @@ pub fn build_get<'ctx, 'this>( metadata, &info.branch_signatures()[0].vars[0].ty, )?; - - let (value_ty, value_layout) = registry.build_type_with_layout( - context, - helper, - registry, - metadata, - &info.branch_signatures()[0].vars[1].ty, - )?; + let value_ty = registry.build_type(context, helper, registry, metadata, &info.ty)?; let dict_ptr = entry.argument(0)?.into(); - let key_value = entry.argument(1)?.into(); + let entry_key = entry.argument(1)?.into(); - let key_ptr = helper - .init_block() - .alloca1(context, location, key_ty, key_layout.align())?; + let entry_key_ptr = + helper + .init_block() + .alloca1(context, location, key_ty, key_layout.align())?; + entry.store(context, location, entry_key_ptr, entry_key)?; - entry.store(context, location, key_ptr, key_value)?; - - let runtime_bindings = metadata + // Double pointer. Avoid allocating an element on a dict getter. + let entry_value_ptr_ptr = metadata .get_mut::() - .expect("Runtime library not available."); - - let op = runtime_bindings.dict_get(context, helper, entry, dict_ptr, key_ptr, location)?; - let result_ptr: Value = op.result(0)?.into(); - - let null_ptr = entry.append_op_result( - OperationBuilder::new("llvm.mlir.zero", location) - .add_results(&[result_ptr.r#type()]) - .build()?, + .unwrap() + .dict_get(context, helper, entry, dict_ptr, entry_key_ptr, location)?; + let entry_value_ptr = entry.load( + context, + location, + entry_value_ptr_ptr, + llvm::r#type::pointer(context, 0), )?; - // need llvm instead of arith to compare pointers - let is_null_ptr = entry.append_op_result( - OperationBuilder::new("llvm.icmp", location) - .add_operands(&[result_ptr, null_ptr]) - .add_attributes(&[( - Identifier::new(context, "predicate"), - IntegerAttribute::new(IntegerType::new(context, 64).into(), 0).into(), - )]) - .add_results(&[IntegerType::new(context, 1).into()]) - .build()?, + let null_ptr = + entry.append_op_result(llvm::zero(llvm::r#type::pointer(context, 0), location))?; + let is_vacant = entry.append_op_result( + ods::llvm::icmp( + context, + IntegerType::new(context, 1).into(), + entry_value_ptr, + null_ptr, + IntegerAttribute::new(IntegerType::new(context, 64).into(), 0).into(), + location, + ) + .into(), )?; - let block_is_null = helper.append_block(Block::new(&[])); - let block_is_found = helper.append_block(Block::new(&[])); - let block_final = helper.append_block(Block::new(&[ - (llvm::r#type::pointer(context, 0), location), - (value_ty, location), - ])); - + let block_occupied = helper.append_block(Block::new(&[])); + let block_vacant = helper.append_block(Block::new(&[])); + let block_final = helper.append_block(Block::new(&[(value_ty, location)])); entry.append_operation(cf::cond_br( context, - is_null_ptr, - block_is_null, - block_is_found, + is_vacant, + block_vacant, + block_occupied, &[], &[], location, )); - // null block { - let alloc_size = block_is_null.const_int(context, location, value_layout.size(), 64)?; + let value = block_occupied.load(context, location, entry_value_ptr, value_ty)?; + let values = match metadata.get::() { + Some(dup_overrides_meta) if dup_overrides_meta.is_overriden(&info.ty) => { + dup_overrides_meta.invoke_override( + context, + block_occupied, + location, + &info.ty, + value, + )? + } + _ => (value, value), + }; - let value_ptr = block_is_null.append_op_result(ReallocBindingsMeta::realloc( - context, result_ptr, alloc_size, location, - ))?; + block_occupied.store(context, location, entry_value_ptr, values.0)?; + block_occupied.append_operation(cf::br(block_final, &[values.1], location)); + } - let default_value = registry + { + let value = registry .get_type(&info.branch_signatures()[0].vars[1].ty)? .build_default( context, registry, - block_is_null, + block_vacant, location, helper, metadata, &info.branch_signatures()[0].vars[1].ty, )?; - - block_is_null.append_operation(cf::br(block_final, &[value_ptr, default_value], location)); + block_vacant.append_operation(cf::br(block_final, &[value], location)); } - // found block - { - let loaded_val_ptr = block_is_found.load(context, location, result_ptr, value_ty)?; - block_is_found.append_operation(cf::br( - block_final, - &[result_ptr, loaded_val_ptr], - location, - )); - } - - // construct the struct - - let entry_value = block_final.append_op_result(llvm::undef(entry_ty, location))?; - - let value_ptr = block_final.argument(0)?.into(); - let value = block_final.argument(1)?.into(); - - let entry_value = block_final.insert_value(context, location, entry_value, key_value, 0)?; - - let entry_value = block_final.insert_value(context, location, entry_value, value_ptr, 1)?; - - let entry_value = block_final.insert_value(context, location, entry_value, dict_ptr, 2)?; - - block_final.append_operation(helper.br(0, &[entry_value, value], location)); + let entry = block_final.append_op_result(llvm::undef(entry_ty, location))?; + let entry = + block_final.insert_values(context, location, entry, &[dict_ptr, entry_value_ptr_ptr])?; + block_final.append_operation(helper.br(0, &[entry, block_final.argument(0)?.into()], location)); Ok(()) } pub fn build_finalize<'ctx, 'this>( context: &'ctx Context, - _registry: &ProgramRegistry, + registry: &ProgramRegistry, entry: &'this Block<'ctx>, location: Location<'ctx>, helper: &LibfuncHelper<'ctx, 'this>, metadata: &mut MetadataStorage, - _info: &SignatureAndTypeConcreteLibfunc, + info: &SignatureAndTypeConcreteLibfunc, ) -> Result<()> { - let key_ty = IntegerType::new(context, 252).into(); - let key_layout = get_integer_layout(252); + if metadata.get::().is_none() { + metadata.insert(ReallocBindingsMeta::new(context, helper)); + } - let entry_value = entry.argument(0)?.into(); - let new_value = entry.argument(1)?.into(); + let (value_ty, value_layout) = registry.build_type_with_layout( + context, + helper, + registry, + metadata, + &info.signature.param_signatures[1].ty, + )?; - let key_value = entry.extract_value(context, location, entry_value, key_ty, 0)?; + let dict_entry = entry.argument(0)?.into(); + let entry_value = entry.argument(1)?.into(); - let value_ptr = entry.extract_value( + let dict_ptr = entry.extract_value( context, location, - entry_value, + dict_entry, + llvm::r#type::pointer(context, 0), + 0, + )?; + let value_ptr_ptr = entry.extract_value( + context, + location, + dict_entry, llvm::r#type::pointer(context, 0), 1, )?; - let dict_ptr = entry.extract_value( + let value_ptr = entry.load( context, location, - entry_value, + value_ptr_ptr, llvm::r#type::pointer(context, 0), - 2, )?; - entry.store(context, location, value_ptr, new_value)?; + let null_ptr = + entry.append_op_result(llvm::zero(llvm::r#type::pointer(context, 0), location))?; + let is_vacant = entry.append_op_result( + ods::llvm::icmp( + context, + IntegerType::new(context, 1).into(), + value_ptr, + null_ptr, + IntegerAttribute::new(IntegerType::new(context, 64).into(), 0).into(), + location, + ) + .into(), + )?; + + let block_occupied = helper.append_block(Block::new(&[])); + let block_vacant = helper.append_block(Block::new(&[])); + let block_final = + helper.append_block(Block::new(&[(llvm::r#type::pointer(context, 0), location)])); + entry.append_operation(cf::cond_br( + context, + is_vacant, + block_vacant, + block_occupied, + &[], + &[], + location, + )); - let key_ptr = helper - .init_block() - .alloca1(context, location, key_ty, key_layout.align())?; + { + match metadata.get::() { + Some(drop_overrides_meta) + if drop_overrides_meta.is_overriden(&info.signature.param_signatures[1].ty) => + { + let value = block_occupied.load(context, location, value_ptr, value_ty)?; + drop_overrides_meta.invoke_override( + context, + block_occupied, + location, + &info.signature.param_signatures[1].ty, + value, + )?; + } + _ => {} + } - entry.store(context, location, key_ptr, key_value)?; + block_occupied.append_operation(cf::br(block_final, &[value_ptr], location)); + } - // call insert + { + let value_len = block_vacant.const_int(context, location, value_layout.size(), 64)?; + let value_ptr = block_vacant.append_op_result(ReallocBindingsMeta::realloc( + context, null_ptr, value_len, location, + ))?; - let runtime_bindings = metadata - .get_mut::() - .expect("Runtime library not available."); + block_vacant.store(context, location, value_ptr_ptr, value_ptr)?; + block_vacant.append_operation(cf::br(block_final, &[value_ptr], location)); + } - runtime_bindings.dict_insert( - context, helper, entry, dict_ptr, key_ptr, value_ptr, location, + block_final.store( + context, + location, + block_final.argument(0)?.into(), + entry_value, )?; - - entry.append_operation(helper.br(0, &[dict_ptr], location)); + block_final.append_operation(helper.br(0, &[dict_ptr], location)); Ok(()) } diff --git a/src/libfuncs/nullable.rs b/src/libfuncs/nullable.rs index b1fed17fc..974dcef29 100644 --- a/src/libfuncs/nullable.rs +++ b/src/libfuncs/nullable.rs @@ -7,10 +7,7 @@ use crate::{error::Result, metadata::MetadataStorage, utils::BlockExt}; use cairo_lang_sierra::{ extensions::{ core::{CoreLibfunc, CoreType}, - lib_func::{ - BranchSignature, LibfuncSignature, SignatureAndTypeConcreteLibfunc, - SignatureOnlyConcreteLibfunc, - }, + lib_func::{SignatureAndTypeConcreteLibfunc, SignatureOnlyConcreteLibfunc}, nullable::NullableConcreteLibfunc, }, program_registry::ProgramRegistry, @@ -132,37 +129,16 @@ fn build_match_nullable<'ctx, 'this>( } fn build_forward_snapshot<'ctx, 'this>( - context: &'ctx Context, - registry: &ProgramRegistry, + _context: &'ctx Context, + _registry: &ProgramRegistry, entry: &'this Block<'ctx>, location: Location<'ctx>, helper: &LibfuncHelper<'ctx, 'this>, - metadata: &mut MetadataStorage, - info: &SignatureAndTypeConcreteLibfunc, + _metadata: &mut MetadataStorage, + _info: &SignatureAndTypeConcreteLibfunc, ) -> Result<()> { - super::snapshot_take::build( - context, - registry, - entry, - location, - helper, - metadata, - &SignatureOnlyConcreteLibfunc { - signature: LibfuncSignature { - param_signatures: info.signature.param_signatures.clone(), - branch_signatures: info - .signature - .branch_signatures - .iter() - .map(|x| BranchSignature { - vars: x.vars.clone(), - ap_change: x.ap_change.clone(), - }) - .collect(), - fallthrough: info.signature.fallthrough, - }, - }, - ) + entry.append_operation(helper.br(0, &[entry.argument(0)?.into()], location)); + Ok(()) } #[cfg(test)] diff --git a/src/libfuncs/starknet.rs b/src/libfuncs/starknet.rs index 51bc798c1..c4ca0552d 100644 --- a/src/libfuncs/starknet.rs +++ b/src/libfuncs/starknet.rs @@ -4,7 +4,7 @@ use super::LibfuncHelper; use crate::{ error::Result, ffi::get_struct_field_type_at, - metadata::MetadataStorage, + metadata::{drop_overrides::DropOverridesMeta, MetadataStorage}, starknet::handler::StarknetSyscallHandlerCallbacks, utils::{get_integer_layout, BlockExt, ProgramRegistryExt, PRIME}, }; @@ -24,9 +24,7 @@ use melior::{ llvm::{self, LoadStoreOptions}, }, ir::{ - attribute::{ - DenseI32ArrayAttribute, DenseI64ArrayAttribute, IntegerAttribute, TypeAttribute, - }, + attribute::{DenseI32ArrayAttribute, DenseI64ArrayAttribute, TypeAttribute}, operation::OperationBuilder, r#type::IntegerType, Attribute, Block, Identifier, Location, Type, ValueLike, @@ -2994,7 +2992,7 @@ pub fn build_sha256_process_block_syscall<'ctx, 'this>( .into(); // Allocate space for the return value. - let (result_layout, (result_tag_ty, result_tag_layout), variant_tys) = + let (result_layout, (result_tag_ty, _), variant_tys) = crate::types::r#enum::get_type_for_variants( context, helper, @@ -3006,50 +3004,22 @@ pub fn build_sha256_process_block_syscall<'ctx, 'this>( ], )?; - let k1 = helper - .init_block() - .append_operation(arith::constant( + let result_ptr = helper.init_block().alloca1( + context, + location, + llvm::r#type::r#struct( context, - IntegerAttribute::new(IntegerType::new(context, 64).into(), 1).into(), - location, - )) - .result(0)? - .into(); - let result_ptr = helper - .init_block() - .append_operation( - OperationBuilder::new("llvm.alloca", location) - .add_attributes(&[ - ( - Identifier::new(context, "alignment"), - IntegerAttribute::new( - IntegerType::new(context, 64).into(), - result_layout.align().try_into()?, - ) - .into(), - ), - ( - Identifier::new(context, "elem_type"), - TypeAttribute::new(llvm::r#type::r#struct( - context, - &[ - result_tag_ty, - llvm::r#type::array( - IntegerType::new(context, 8).into(), - (result_layout.size() - 1).try_into()?, - ), - ], - false, - )) - .into(), - ), - ]) - .add_operands(&[k1]) - .add_results(&[llvm::r#type::pointer(context, 0)]) - .build()?, - ) - .result(0)? - .into(); + &[ + result_tag_ty, + llvm::r#type::array( + IntegerType::new(context, 8).into(), + (result_layout.size() - 1).try_into()?, + ), + ], + false, + ), + result_layout.align(), + )?; // Allocate space and write the current gas. let gas_builtin_ptr = helper.init_block().alloca1( @@ -3107,118 +3077,51 @@ pub fn build_sha256_process_block_syscall<'ctx, 'this>( .build()?, ); - let result = entry - .append_operation(llvm::load( - context, - result_ptr, - llvm::r#type::r#struct( - context, - &[ - result_tag_ty, - llvm::r#type::array( - IntegerType::new(context, 8).into(), - (result_layout.size() - 1).try_into()?, - ), - ], - false, - ), - location, - LoadStoreOptions::default(), - )) - .result(0)? - .into(); - let result_tag = entry - .append_operation(llvm::extract_value( + registry.build_type( + context, + helper, + registry, + metadata, + &info.signature.param_signatures[3].ty, + )?; + metadata + .get::() + .unwrap() + .invoke_override( context, - result, - DenseI64ArrayAttribute::new(context, &[0]), - IntegerType::new(context, 1).into(), + entry, location, - )) - .result(0)? - .into(); + &info.signature.param_signatures[3].ty, + sha256_current_block_ptr, + )?; + + let result_tag = entry.load(context, location, result_ptr, result_tag_ty)?; let payload_ok = { - let ptr = entry - .append_operation( - OperationBuilder::new("llvm.getelementptr", location) - .add_attributes(&[ - ( - Identifier::new(context, "rawConstantIndices"), - DenseI32ArrayAttribute::new( - context, - &[result_tag_layout.extend(variant_tys[0].1)?.1.try_into()?], - ) - .into(), - ), - ( - Identifier::new(context, "elem_type"), - TypeAttribute::new(IntegerType::new(context, 8).into()).into(), - ), - ]) - .add_operands(&[result_ptr]) - .add_results(&[llvm::r#type::pointer(context, 0)]) - .build()?, - ) - .result(0)? - .into(); - entry - .append_operation(llvm::load( - context, - ptr, - variant_tys[0].0, - location, - LoadStoreOptions::default(), - )) - .result(0)? - .into() + let value = entry.load( + context, + location, + result_ptr, + llvm::r#type::r#struct(context, &[result_tag_ty, variant_tys[0].0], false), + )?; + entry.extract_value(context, location, value, variant_tys[0].0, 1)? }; let payload_err = { - let ptr = entry - .append_operation( - OperationBuilder::new("llvm.getelementptr", location) - .add_attributes(&[ - ( - Identifier::new(context, "rawConstantIndices"), - DenseI32ArrayAttribute::new( - context, - &[result_tag_layout.extend(variant_tys[1].1)?.1.try_into()?], - ) - .into(), - ), - ( - Identifier::new(context, "elem_type"), - TypeAttribute::new(IntegerType::new(context, 8).into()).into(), - ), - ]) - .add_operands(&[result_ptr]) - .add_results(&[llvm::r#type::pointer(context, 0)]) - .build()?, - ) - .result(0)? - .into(); - entry - .append_operation(llvm::load( - context, - ptr, - variant_tys[1].0, - location, - LoadStoreOptions::default(), - )) - .result(0)? - .into() - }; - - let remaining_gas = entry - .append_operation(llvm::load( + let value = entry.load( context, - gas_builtin_ptr, - IntegerType::new(context, 128).into(), location, - LoadStoreOptions::default(), - )) - .result(0)? - .into(); + result_ptr, + llvm::r#type::r#struct(context, &[result_tag_ty, variant_tys[1].0], false), + )?; + entry.extract_value(context, location, value, variant_tys[1].0, 1)? + }; + + let remaining_gas = entry.load( + context, + location, + gas_builtin_ptr, + IntegerType::new(context, 128).into(), + )?; entry.append_operation(helper.cond_br( context, diff --git a/src/metadata/realloc_bindings.rs b/src/metadata/realloc_bindings.rs index 1bd5793f3..369bc709f 100644 --- a/src/metadata/realloc_bindings.rs +++ b/src/metadata/realloc_bindings.rs @@ -4,39 +4,34 @@ //! compilation context. use melior::{ - dialect::{func, llvm}, + dialect::llvm, ir::{ attribute::{FlatSymbolRefAttribute, StringAttribute, TypeAttribute}, - r#type::{FunctionType, IntegerType}, + operation::OperationBuilder, + r#type::IntegerType, Identifier, Location, Module, Operation, Region, Value, }, Context, }; -use std::marker::PhantomData; /// Memory allocation `realloc` metadata. #[derive(Debug)] -pub struct ReallocBindingsMeta { - phantom: PhantomData<()>, -} +pub struct ReallocBindingsMeta; impl ReallocBindingsMeta { /// Register the bindings to the `realloc` C function and return the metadata. pub fn new(context: &Context, module: &Module) -> Self { - module.body().append_operation(func::func( + module.body().append_operation(llvm::func( context, StringAttribute::new(context, "realloc"), - TypeAttribute::new( - FunctionType::new( - context, - &[ - llvm::r#type::pointer(context, 0), - IntegerType::new(context, 64).into(), - ], - &[llvm::r#type::pointer(context, 0)], - ) - .into(), - ), + TypeAttribute::new(llvm::r#type::function( + llvm::r#type::pointer(context, 0), + &[ + llvm::r#type::pointer(context, 0), + IntegerType::new(context, 64).into(), + ], + false, + )), Region::new(), &[( Identifier::new(context, "sym_visibility"), @@ -44,12 +39,14 @@ impl ReallocBindingsMeta { )], Location::unknown(context), )); - module.body().append_operation(func::func( + module.body().append_operation(llvm::func( context, StringAttribute::new(context, "free"), - TypeAttribute::new( - FunctionType::new(context, &[llvm::r#type::pointer(context, 0)], &[]).into(), - ), + TypeAttribute::new(llvm::r#type::function( + llvm::r#type::void(context), + &[llvm::r#type::pointer(context, 0)], + false, + )), Region::new(), &[( Identifier::new(context, "sym_visibility"), @@ -58,9 +55,7 @@ impl ReallocBindingsMeta { Location::unknown(context), )); - Self { - phantom: PhantomData, - } + Self } /// Calls the `realloc` function, returns a op with 1 result: an opaque pointer. @@ -70,13 +65,15 @@ impl ReallocBindingsMeta { len: Value<'c, 'a>, location: Location<'c>, ) -> Operation<'c> { - func::call( - context, - FlatSymbolRefAttribute::new(context, "realloc"), - &[ptr, len], - &[llvm::r#type::pointer(context, 0)], - location, - ) + OperationBuilder::new("llvm.call", location) + .add_attributes(&[( + Identifier::new(context, "callee"), + FlatSymbolRefAttribute::new(context, "realloc").into(), + )]) + .add_operands(&[ptr, len]) + .add_results(&[llvm::r#type::pointer(context, 0)]) + .build() + .unwrap() } /// Calls the `free` function. @@ -85,12 +82,13 @@ impl ReallocBindingsMeta { ptr: Value<'c, '_>, location: Location<'c>, ) -> Operation<'c> { - func::call( - context, - FlatSymbolRefAttribute::new(context, "free"), - &[ptr], - &[], - location, - ) + OperationBuilder::new("llvm.call", location) + .add_attributes(&[( + Identifier::new(context, "callee"), + FlatSymbolRefAttribute::new(context, "free").into(), + )]) + .add_operands(&[ptr]) + .build() + .unwrap() } } diff --git a/src/metadata/runtime_bindings.rs b/src/metadata/runtime_bindings.rs index 5e6ba6ace..78bf2eca0 100644 --- a/src/metadata/runtime_bindings.rs +++ b/src/metadata/runtime_bindings.rs @@ -5,7 +5,7 @@ use crate::{error::Result, utils::BlockExt}; use melior::{ - dialect::{func, llvm}, + dialect::{func, llvm, ods}, ir::{ attribute::{FlatSymbolRefAttribute, StringAttribute, TypeAttribute}, r#type::{FunctionType, IntegerType}, @@ -472,7 +472,7 @@ impl RuntimeBindingsMeta { module: &Module, block: &'a Block<'c>, location: Location<'c>, - ) -> Result> + ) -> Result> where 'c: 'a, { @@ -481,7 +481,12 @@ impl RuntimeBindingsMeta { context, StringAttribute::new(context, "cairo_native__dict_new"), TypeAttribute::new( - FunctionType::new(context, &[], &[llvm::r#type::pointer(context, 0)]).into(), + FunctionType::new( + context, + &[llvm::r#type::pointer(context, 0)], + &[llvm::r#type::pointer(context, 0)], + ) + .into(), ), Region::new(), &[( @@ -492,13 +497,23 @@ impl RuntimeBindingsMeta { )); } - Ok(block.append_operation(func::call( + let free_fn = block.append_op_result( + ods::llvm::mlir_addressof( + context, + llvm::r#type::pointer(context, 0), + FlatSymbolRefAttribute::new(context, "free"), + location, + ) + .into(), + )?; + + block.append_op_result(func::call( context, FlatSymbolRefAttribute::new(context, "cairo_native__dict_new"), - &[], + &[free_fn], &[llvm::r#type::pointer(context, 0)], location, - ))) + )) } /// Register if necessary, then invoke the `dict_alloc_new()` function. @@ -620,7 +635,7 @@ impl RuntimeBindingsMeta { dict_ptr: Value<'c, 'a>, // ptr to the dict key_ptr: Value<'c, 'a>, // key must be a ptr to Felt location: Location<'c>, - ) -> Result> + ) -> Result> where 'c: 'a, { @@ -648,13 +663,13 @@ impl RuntimeBindingsMeta { )); } - Ok(block.append_operation(func::call( + block.append_op_result(func::call( context, FlatSymbolRefAttribute::new(context, "cairo_native__dict_get"), &[dict_ptr, key_ptr], &[llvm::r#type::pointer(context, 0)], location, - ))) + )) } /// Register if necessary, then invoke the `dict_insert()` function. diff --git a/src/starknet.rs b/src/starknet.rs index 0661f305c..d1a7c6ddf 100644 --- a/src/starknet.rs +++ b/src/starknet.rs @@ -286,10 +286,10 @@ pub trait StarknetSyscallHandler { fn sha256_process_block( &mut self, - prev_state: &[u32; 8], - current_block: &[u32; 16], + state: &mut [u32; 8], + block: &[u32; 16], remaining_gas: &mut u128, - ) -> SyscallResult<[u32; 8]>; + ) -> SyscallResult<()>; #[cfg(feature = "with-cheatcode")] fn cheatcode(&mut self, _selector: Felt, _input: &[Felt]) -> Vec { @@ -485,10 +485,10 @@ impl StarknetSyscallHandler for DummySyscallHandler { fn sha256_process_block( &mut self, - _prev_state: &[u32; 8], - _current_block: &[u32; 16], + _state: &mut [u32; 8], + _block: &[u32; 16], _remaining_gas: &mut u128, - ) -> SyscallResult<[u32; 8]> { + ) -> SyscallResult<()> { unimplemented!() } } @@ -496,8 +496,10 @@ impl StarknetSyscallHandler for DummySyscallHandler { // TODO: Move to the correct place or remove if unused. pub(crate) mod handler { use super::*; + use crate::utils::{libc_free, libc_malloc}; use std::{ alloc::Layout, + ffi::c_void, fmt::Debug, mem::{size_of, ManuallyDrop, MaybeUninit}, ptr::{null_mut, NonNull}, @@ -762,8 +764,8 @@ pub(crate) mod handler { result_ptr: &mut SyscallResultAbi<*mut [u32; 8]>, ptr: &mut T, gas: &mut u128, - prev_state: &[u32; 8], - current_block: &[u32; 16], + state: *mut [u32; 8], + block: &[u32; 16], ), // testing syscalls #[cfg(feature = "with-cheatcode")] @@ -853,8 +855,7 @@ pub(crate) mod handler { capacity: 0, }, _ => { - let ptr = - libc::malloc(Layout::array::(data.len()).unwrap().size()) as *mut E; + let ptr = libc_malloc(Layout::array::(data.len()).unwrap().size()) as *mut E; let len: u32 = data.len().try_into().unwrap(); for (i, val) in data.iter().enumerate() { @@ -909,7 +910,8 @@ pub(crate) mod handler { selector: &Felt252Abi, input: &ArrayAbi, ) { - let input: Vec<_> = unsafe { + let selector = Felt::from_bytes_le(&selector.0); + let input_vec: Vec<_> = unsafe { let since_offset = input.since as usize; let until_offset = input.until as usize; debug_assert!(since_offset <= until_offset); @@ -919,10 +921,13 @@ pub(crate) mod handler { .iter() .map(|x| Felt::from_bytes_le(&x.0)) .collect(); - let selector = Felt::from_bytes_le(&selector.0); + + unsafe { + libc_free(input.ptr as *mut c_void); + } let result = ptr - .cheatcode(selector, &input) + .cheatcode(selector, &input_vec) .into_iter() .map(|x| Felt252Abi(x.to_bytes_le())) .collect::>(); @@ -942,20 +947,19 @@ pub(crate) mod handler { ok: ManuallyDrop::new(SyscallResultAbiOk { tag: 0u8, payload: unsafe { - let mut block_info_ptr = - NonNull::new( - libc::malloc(size_of::()) as *mut BlockInfoAbi - ) - .unwrap(); + let mut block_info_ptr = NonNull::new(libc_malloc( + size_of::(), + ) + as *mut BlockInfoAbi) + .unwrap(); block_info_ptr.as_mut().block_number = x.block_info.block_number; block_info_ptr.as_mut().block_timestamp = x.block_info.block_timestamp; block_info_ptr.as_mut().sequencer_address = Felt252Abi(x.block_info.sequencer_address.to_bytes_le()); - let mut tx_info_ptr = NonNull::new( - libc::malloc(size_of::()) as *mut TxInfoAbi, - ) - .unwrap(); + let mut tx_info_ptr = + NonNull::new(libc_malloc(size_of::()) as *mut TxInfoAbi) + .unwrap(); tx_info_ptr.as_mut().version = Felt252Abi(x.tx_info.version.to_bytes_le()); tx_info_ptr.as_mut().account_contract_address = @@ -975,7 +979,7 @@ pub(crate) mod handler { tx_info_ptr.as_mut().nonce = Felt252Abi(x.tx_info.nonce.to_bytes_le()); let mut execution_info_ptr = - NonNull::new(libc::malloc(size_of::()) + NonNull::new(libc_malloc(size_of::()) as *mut ExecutionInfoAbi) .unwrap(); execution_info_ptr.as_mut().block_info = block_info_ptr; @@ -1008,22 +1012,22 @@ pub(crate) mod handler { tag: 0u8, payload: unsafe { let mut execution_info_ptr = - NonNull::new(libc::malloc(size_of::()) + NonNull::new(libc_malloc(size_of::()) as *mut ExecutionInfoV2Abi) .unwrap(); - let mut block_info_ptr = - NonNull::new( - libc::malloc(size_of::()) as *mut BlockInfoAbi - ) - .unwrap(); + let mut block_info_ptr = NonNull::new(libc_malloc( + size_of::(), + ) + as *mut BlockInfoAbi) + .unwrap(); block_info_ptr.as_mut().block_number = x.block_info.block_number; block_info_ptr.as_mut().block_timestamp = x.block_info.block_timestamp; block_info_ptr.as_mut().sequencer_address = Felt252Abi(x.block_info.sequencer_address.to_bytes_le()); let mut tx_info_ptr = NonNull::new( - libc::malloc(size_of::()) as *mut TxInfoV2Abi, + libc_malloc(size_of::()) as *mut TxInfoV2Abi, ) .unwrap(); tx_info_ptr.as_mut().version = @@ -1113,7 +1117,7 @@ pub(crate) mod handler { data }); - let calldata: Vec<_> = unsafe { + let calldata_vec: Vec<_> = unsafe { let since_offset = calldata.since as usize; let until_offset = calldata.until as usize; debug_assert!(since_offset <= until_offset); @@ -1133,10 +1137,14 @@ pub(crate) mod handler { }) .collect(); + unsafe { + libc_free(calldata.ptr as *mut c_void); + } + let result = ptr.deploy( class_hash, contract_address_salt, - &calldata, + &calldata_vec, deploy_from_zero, gas, ); @@ -1199,7 +1207,7 @@ pub(crate) mod handler { data }); - let calldata: Vec<_> = unsafe { + let calldata_vec: Vec<_> = unsafe { let since_offset = calldata.since as usize; let until_offset = calldata.until as usize; debug_assert!(since_offset <= until_offset); @@ -1219,7 +1227,11 @@ pub(crate) mod handler { }) .collect(); - let result = ptr.library_call(class_hash, function_selector, &calldata, gas); + unsafe { + libc_free(calldata.ptr as *mut c_void); + } + + let result = ptr.library_call(class_hash, function_selector, &calldata_vec, gas); *result_ptr = match result { Ok(x) => { @@ -1255,7 +1267,7 @@ pub(crate) mod handler { data }); - let calldata: Vec<_> = unsafe { + let calldata_vec: Vec<_> = unsafe { let since_offset = calldata.since as usize; let until_offset = calldata.until as usize; debug_assert!(since_offset <= until_offset); @@ -1275,7 +1287,11 @@ pub(crate) mod handler { }) .collect(); - let result = ptr.call_contract(address, entry_point_selector, &calldata, gas); + unsafe { + libc_free(calldata.ptr as *mut c_void); + } + + let result = ptr.call_contract(address, entry_point_selector, &calldata_vec, gas); *result_ptr = match result { Ok(x) => { @@ -1355,7 +1371,7 @@ pub(crate) mod handler { keys: &ArrayAbi, data: &ArrayAbi, ) { - let keys: Vec<_> = unsafe { + let keys_vec: Vec<_> = unsafe { let since_offset = keys.since as usize; let until_offset = keys.until as usize; debug_assert!(since_offset <= until_offset); @@ -1375,7 +1391,11 @@ pub(crate) mod handler { }) .collect(); - let data: Vec<_> = unsafe { + unsafe { + libc_free(keys.ptr as *mut c_void); + } + + let data_vec: Vec<_> = unsafe { let since_offset = data.since as usize; let until_offset = data.until as usize; debug_assert!(since_offset <= until_offset); @@ -1395,7 +1415,11 @@ pub(crate) mod handler { }) .collect(); - let result = ptr.emit_event(&keys, &data, gas); + unsafe { + libc_free(data.ptr as *mut c_void); + } + + let result = ptr.emit_event(&keys_vec, &data_vec, gas); *result_ptr = match result { Ok(_) => SyscallResultAbi { @@ -1420,7 +1444,7 @@ pub(crate) mod handler { data.reverse(); data }); - let payload: Vec<_> = unsafe { + let payload_vec: Vec<_> = unsafe { let since_offset = payload.since as usize; let until_offset = payload.until as usize; debug_assert!(since_offset <= until_offset); @@ -1440,7 +1464,11 @@ pub(crate) mod handler { }) .collect(); - let result = ptr.send_message_to_l1(to_address, &payload, gas); + unsafe { + libc_free(payload.ptr as *mut c_void); + } + + let result = ptr.send_message_to_l1(to_address, &payload_vec, gas); *result_ptr = match result { Ok(_) => SyscallResultAbi { @@ -1459,7 +1487,7 @@ pub(crate) mod handler { gas: &mut u128, input: &ArrayAbi, ) { - let input = unsafe { + let input_vec = unsafe { let since_offset = input.since as usize; let until_offset = input.until as usize; debug_assert!(since_offset <= until_offset); @@ -1470,7 +1498,10 @@ pub(crate) mod handler { } }; - let result = ptr.keccak(input, gas); + let result = ptr.keccak(input_vec, gas); + unsafe { + libc_free(input.ptr as *mut c_void); + } *result_ptr = match result { Ok(x) => SyscallResultAbi { @@ -1716,26 +1747,16 @@ pub(crate) mod handler { result_ptr: &mut SyscallResultAbi<*mut [u32; 8]>, ptr: &mut T, gas: &mut u128, - prev_state: &[u32; 8], - current_block: &[u32; 16], + state: *mut [u32; 8], + block: &[u32; 16], ) { - let result = ptr.sha256_process_block(prev_state, current_block, gas); + let result = ptr.sha256_process_block(unsafe { &mut *state }, block, gas); *result_ptr = match result { Ok(x) => SyscallResultAbi { ok: ManuallyDrop::new(SyscallResultAbiOk { tag: 0u8, - payload: ManuallyDrop::new({ - unsafe { - let data = libc::malloc(std::mem::size_of_val(&x)).cast(); - std::ptr::copy_nonoverlapping::( - x.as_ptr().cast(), - data, - x.len(), - ); - data.cast() - } - }), + payload: ManuallyDrop::new(state), }), }, Err(e) => Self::wrap_error(&e), diff --git a/src/starknet_stub.rs b/src/starknet_stub.rs index c5ccdb665..bf25d22c1 100644 --- a/src/starknet_stub.rs +++ b/src/starknet_stub.rs @@ -634,19 +634,18 @@ impl StarknetSyscallHandler for &mut StubSyscallHandler { fn sha256_process_block( &mut self, - prev_state: &[u32; 8], - current_block: &[u32; 16], + state: &mut [u32; 8], + block: &[u32; 16], _remaining_gas: &mut u128, - ) -> SyscallResult<[u32; 8]> { + ) -> SyscallResult<()> { // reference impl // https://github.com/starkware-libs/cairo/blob/ba3f82b4a09972b6a24bf791e344cabce579bf69/crates/cairo-lang-runner/src/casm_run/mod.rs#L1292 - let mut state = *prev_state; let data_as_bytes = sha2::digest::generic_array::GenericArray::from_exact_iter( - current_block.iter().flat_map(|x| x.to_be_bytes()), + block.iter().flat_map(|x| x.to_be_bytes()), ) .unwrap(); - sha2::compress256(&mut state, &[data_as_bytes]); - Ok(state) + sha2::compress256(state, &[data_as_bytes]); + Ok(()) } } diff --git a/src/types/felt252_dict_entry.rs b/src/types/felt252_dict_entry.rs index 23fba74f1..ba0740cf3 100644 --- a/src/types/felt252_dict_entry.rs +++ b/src/types/felt252_dict_entry.rs @@ -4,11 +4,10 @@ //! //! It is represented as the following struct: //! -//! | Index | Type | Description | -//! | ----- | -------------- | -------------------------------- | -//! | 0 | `i252` | The entry key. | -//! | 1 | `!llvm.ptr` | Pointer to the entry value. | -//! | 2 | `!llvm.ptr` | Pointer to the dictionary (rust) | +//! | Index | Type | Description | +//! | ----- | -------------- | -------------------------------------------- | +//! | 0 | `!llvm.ptr` | Pointer to the dictionary (Rust). | +//! | 1 | `!llvm.ptr` | Pointer to the entry's value pointer (Rust). | //! use super::WithSelf; @@ -22,7 +21,7 @@ use cairo_lang_sierra::{ }; use melior::{ dialect::llvm, - ir::{r#type::IntegerType, Module, Type}, + ir::{Module, Type}, Context, }; @@ -36,12 +35,12 @@ pub fn build<'ctx>( _metadata: &mut MetadataStorage, _info: WithSelf, ) -> Result> { + // Note: This is neither droppable nor cloneable. Ok(llvm::r#type::r#struct( context, &[ - IntegerType::new(context, 252).into(), // entry key - llvm::r#type::pointer(context, 0), // value ptr - llvm::r#type::pointer(context, 0), // dict ptr + llvm::r#type::pointer(context, 0), // dict ptr + llvm::r#type::pointer(context, 0), // value ptr ], false, )) diff --git a/src/types/squashed_felt252_dict.rs b/src/types/squashed_felt252_dict.rs index 0ee19a818..a54af17b2 100644 --- a/src/types/squashed_felt252_dict.rs +++ b/src/types/squashed_felt252_dict.rs @@ -10,7 +10,6 @@ use cairo_lang_sierra::{ program_registry::ProgramRegistry, }; use melior::{ - dialect::llvm, ir::{Module, Type}, Context, }; @@ -20,10 +19,10 @@ use melior::{ /// Check out [the module](self) for more info. pub fn build<'ctx>( context: &'ctx Context, - _module: &Module<'ctx>, - _registry: &ProgramRegistry, - _metadata: &mut MetadataStorage, - _info: WithSelf, + module: &Module<'ctx>, + registry: &ProgramRegistry, + metadata: &mut MetadataStorage, + info: WithSelf, ) -> Result> { - Ok(llvm::r#type::pointer(context, 0)) + super::felt252_dict::build(context, module, registry, metadata, info) } diff --git a/src/types/starknet.rs b/src/types/starknet.rs index a8ec9d9aa..9299125c2 100644 --- a/src/types/starknet.rs +++ b/src/types/starknet.rs @@ -22,7 +22,14 @@ // TODO: Maybe the types used here can be i251 instead of i252. use super::WithSelf; -use crate::{error::Result, metadata::MetadataStorage}; +use crate::{ + error::Result, + metadata::{ + drop_overrides::DropOverridesMeta, dup_overrides::DupOverridesMeta, + realloc_bindings::ReallocBindingsMeta, MetadataStorage, + }, + utils::BlockExt, +}; use cairo_lang_sierra::{ extensions::{ core::{CoreLibfunc, CoreType}, @@ -32,8 +39,8 @@ use cairo_lang_sierra::{ program_registry::ProgramRegistry, }; use melior::{ - dialect::llvm, - ir::{r#type::IntegerType, Module, Type}, + dialect::{func, llvm, ods}, + ir::{attribute::IntegerAttribute, r#type::IntegerType, Block, Location, Module, Region, Type}, Context, }; @@ -188,11 +195,61 @@ pub fn build_secp256_point<'ctx>( pub fn build_sha256_state_handle<'ctx>( context: &'ctx Context, - _module: &Module<'ctx>, - _registry: &ProgramRegistry, - _metadata: &mut MetadataStorage, - _info: WithSelf, + module: &Module<'ctx>, + registry: &ProgramRegistry, + metadata: &mut MetadataStorage, + info: WithSelf, ) -> Result> { + let location = Location::unknown(context); + if metadata.get::().is_none() { + metadata.insert(ReallocBindingsMeta::new(context, module)); + } + + DupOverridesMeta::register_with(context, module, registry, metadata, info.self_ty(), |_| { + let region = Region::new(); + let block = + region.append_block(Block::new(&[(llvm::r#type::pointer(context, 0), location)])); + + let null_ptr = + block.append_op_result(llvm::zero(llvm::r#type::pointer(context, 0), location))?; + let k32 = block.const_int(context, location, 32, 64)?; + let new_ptr = block.append_op_result(ReallocBindingsMeta::realloc( + context, null_ptr, k32, location, + ))?; + + block.append_operation( + ods::llvm::intr_memcpy_inline( + context, + new_ptr, + block.argument(0)?.into(), + IntegerAttribute::new(IntegerType::new(context, 64).into(), 32), + IntegerAttribute::new(IntegerType::new(context, 1).into(), 0), + location, + ) + .into(), + ); + + block.append_operation(func::r#return( + &[block.argument(0)?.into(), new_ptr], + location, + )); + Ok(Some(region)) + })?; + DropOverridesMeta::register_with(context, module, registry, metadata, info.self_ty(), |_| { + let region = Region::new(); + let block = + region.append_block(Block::new(&[(llvm::r#type::pointer(context, 0), location)])); + + block.append_operation(ReallocBindingsMeta::free( + context, + block.argument(0)?.into(), + location, + )); + + block.append_operation(func::r#return(&[], location)); + Ok(Some(region)) + })?; + // A ptr to a heap (realloc) allocated [u32; 8] Ok(llvm::r#type::pointer(context, 0)) } diff --git a/src/utils.rs b/src/utils.rs index d55900d9b..993c91cd0 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -23,12 +23,12 @@ use std::{ fmt::{self, Display}, ops::Neg, path::Path, - ptr::NonNull, sync::Arc, }; use thiserror::Error; mod block_ext; +pub mod mem_tracing; mod program_registry_ext; mod range_ext; @@ -49,6 +49,15 @@ pub static HALF_PRIME: LazyLock = LazyLock::new(|| { .unwrap() }); +#[cfg(feature = "with-mem-tracing")] +#[allow(unused_imports)] +pub(crate) use self::mem_tracing::{ + _wrapped_free as libc_free, _wrapped_malloc as libc_malloc, _wrapped_realloc as libc_realloc, +}; +#[cfg(not(feature = "with-mem-tracing"))] +#[allow(unused_imports)] +pub(crate) use libc::{free as libc_free, malloc as libc_malloc, realloc as libc_realloc}; + /// Generate a function name. /// /// If the program includes function identifiers, return those. Otherwise return `f` followed by the @@ -227,6 +236,9 @@ pub fn create_engine( .unwrap() .register_impls(&engine); + #[cfg(feature = "with-mem-tracing")] + self::mem_tracing::register_bindings(&engine); + engine } @@ -329,16 +341,6 @@ pub fn register_runtime_symbols(engine: &ExecutionEngine) { as *mut (), ); - engine.register_symbol( - "cairo_native__dict_insert", - cairo_native_runtime::cairo_native__dict_insert - as *const fn( - *mut FeltDict, - &[u8; 32], - NonNull, - ) -> *mut std::ffi::c_void as *mut (), - ); - engine.register_symbol( "cairo_native__dict_gas_refund", cairo_native_runtime::cairo_native__dict_gas_refund as *const fn(*const FeltDict) -> u64 diff --git a/src/utils/mem_tracing.rs b/src/utils/mem_tracing.rs new file mode 100644 index 000000000..1a6a26b6a --- /dev/null +++ b/src/utils/mem_tracing.rs @@ -0,0 +1,131 @@ +#![cfg(feature = "with-mem-tracing")] + +use libc::{c_void, size_t}; +use melior::ExecutionEngine; +use std::cell::UnsafeCell; + +thread_local! { + static MEM_TRACING: UnsafeCell = const { UnsafeCell::new(MemTracing::new()) }; +} + +struct MemTracing { + finished: Vec, + pending: Vec, +} + +struct AllocTrace { + ptr: *mut c_void, + len: size_t, +} + +impl MemTracing { + pub const fn new() -> Self { + Self { + finished: Vec::new(), + pending: Vec::new(), + } + } + + pub fn push(&mut self, trace: AllocTrace) { + match self.pending.binary_search_by_key(&trace.ptr, |x| x.ptr) { + Ok(_) => unreachable!(), + Err(pos) => self.pending.insert(pos, trace), + } + } + + pub fn update(&mut self, ptr: *mut c_void, trace: AllocTrace) { + if let Ok(pos) = self.pending.binary_search_by_key(&ptr, |x| x.ptr) { + let trace = self.pending.remove(pos); + if trace.len == 0 { + self.finished.push(trace); + return; + } + }; + + self.push(trace); + } + + pub fn finish(&mut self, ptr: *mut c_void) { + if ptr.is_null() { + return; + } + + match self.pending.binary_search_by_key(&ptr, |x| x.ptr) { + Ok(pos) => { + let trace = self.pending.remove(pos); + self.finished.push(trace); + } + Err(_) => unreachable!(), + } + } +} + +impl AllocTrace { + pub fn new(ptr: *mut c_void, len: size_t) -> Self { + Self { ptr, len } + } +} + +pub(crate) fn register_bindings(engine: &ExecutionEngine) { + unsafe { + engine.register_symbol( + "malloc", + _wrapped_malloc as *const fn(size_t) -> *mut c_void as *mut (), + ); + engine.register_symbol( + "realloc", + _wrapped_realloc as *const fn(*mut c_void, size_t) -> *mut c_void as *mut (), + ); + engine.register_symbol("free", _wrapped_free as *const fn(*mut c_void) as *mut ()); + } +} + +pub fn report_stats() { + unsafe { + MEM_TRACING.with(|x| { + println!(); + println!("[MemTracing] Stats:"); + println!( + "[MemTracing] Freed allocations: {}", + (*x.get()).finished.len() + ); + println!("[MemTracing] Pending allocations:"); + for AllocTrace { ptr, len } in &(*x.get()).pending { + println!("[MemTracing] - {ptr:?} ({len} bytes)"); + } + + assert!((*x.get()).pending.is_empty()); + *x.get() = MemTracing::new(); + }); + } +} + +pub(crate) unsafe extern "C" fn _wrapped_malloc(len: size_t) -> *mut c_void { + let ptr = libc::malloc(len); + + println!("[MemTracing] Allocating ptr {ptr:?} with {len} bytes."); + MEM_TRACING.with(|x| (*x.get()).push(AllocTrace::new(ptr, len))); + + ptr +} + +pub(crate) unsafe extern "C" fn _wrapped_realloc(ptr: *mut c_void, len: size_t) -> *mut c_void { + let new_ptr = libc::realloc(ptr, len); + + println!("[MemTracing] Reallocating {ptr:?} into {new_ptr:?} with {len} bytes."); + MEM_TRACING.with(|x| (*x.get()).update(ptr, AllocTrace::new(new_ptr, len))); + + new_ptr +} + +pub(crate) unsafe extern "C" fn _wrapped_free(ptr: *mut c_void) { + if !ptr.is_null() { + // This print is placed before the actual call to log pointers before double free + // situations. + println!("[MemTracing] Freeing {ptr:?}."); + + libc::free(ptr); + + MEM_TRACING.with(|x| (*x.get()).finish(ptr)); + } +} diff --git a/src/values.rs b/src/values.rs index 3a0bc2f98..873b00cbb 100644 --- a/src/values.rs +++ b/src/values.rs @@ -1,12 +1,14 @@ //! # JIT params and return values de/serialization - +//! //! A Rusty interface to provide parameters to JIT calls. use crate::{ error::{CompilerError, Error}, starknet::{Secp256k1Point, Secp256r1Point}, types::TypeBuilder, - utils::{felt252_bigint, get_integer_layout, layout_repeat, RangeExt, PRIME}, + utils::{ + felt252_bigint, get_integer_layout, layout_repeat, libc_free, libc_malloc, RangeExt, PRIME, + }, }; use bumpalo::Bump; use cairo_lang_sierra::{ @@ -234,7 +236,10 @@ impl Value { let elem_ty = registry.get_type(&info.ty)?; let elem_layout = elem_ty.layout(registry)?.pad_to_align(); - let ptr: *mut () = libc::malloc(elem_layout.size() * data.len()).cast(); + let ptr: *mut () = match elem_layout.size() * data.len() { + 0 => std::ptr::null_mut(), + len => libc_malloc(len).cast(), + }; let len: u32 = data .len() .try_into() @@ -383,7 +388,11 @@ impl Value { let elem_ty = registry.get_type(&info.ty)?; let elem_layout = elem_ty.layout(registry)?.pad_to_align(); - let mut value_map = Box::::default(); + let mut value_map = Box::new(FeltDict { + inner: HashMap::default(), + count: 0, + free_fn: crate::utils::libc_free, + }); // next key must be called before next_value @@ -391,20 +400,14 @@ impl Value { let key = key.to_bytes_le(); let value = value.to_ptr(arena, registry, &info.ty)?; - let value_malloc_ptr = libc::malloc(elem_layout.size()); - + let value_malloc_ptr = libc_malloc(elem_layout.size()); std::ptr::copy_nonoverlapping( value.cast::().as_ptr(), value_malloc_ptr.cast(), elem_layout.size(), ); - value_map.inner.insert( - key, - NonNull::new(value_malloc_ptr) - .expect("allocation failure") - .cast(), - ); + value_map.inner.insert(key, value_malloc_ptr); } NonNull::new_unchecked(Box::into_raw(value_map)).cast() @@ -560,7 +563,7 @@ impl Value { } if !init_data_ptr.is_null() { - libc::free(init_data_ptr.cast()); + libc_free(init_data_ptr.cast()); } Self::Array(array_value) @@ -568,7 +571,7 @@ impl Value { CoreTypeConcrete::Box(info) => { let inner = *ptr.cast::>().as_ptr(); let value = Self::from_ptr(inner, &info.ty, registry)?; - libc::free(inner.as_ptr().cast()); + libc_free(inner.as_ptr().cast()); value } CoreTypeConcrete::EcPoint(_) => { @@ -613,7 +616,7 @@ impl Value { &info.ty, registry, )?; - libc::free(inner_ptr.cast()); + libc_free(inner_ptr.cast()); value } } @@ -693,9 +696,20 @@ impl Value { let mut output_map = HashMap::with_capacity(inner.len()); for (key, val_ptr) in inner.iter() { + if val_ptr.is_null() { + continue; + } + let key = Felt::from_bytes_le(key); - output_map.insert(key, Self::from_ptr(val_ptr.cast(), &info.ty, registry)?); - libc::free(val_ptr.as_ptr()); + output_map.insert( + key, + Self::from_ptr( + NonNull::new(*val_ptr).unwrap().cast(), + &info.ty, + registry, + )?, + ); + libc_free(*val_ptr); } Self::Felt252Dict { diff --git a/tests/tests/starknet/contracts/test_u256_order.cairo b/tests/tests/starknet/contracts/test_u256_order.cairo index 181d1242e..47a43d38c 100644 --- a/tests/tests/starknet/contracts/test_u256_order.cairo +++ b/tests/tests/starknet/contracts/test_u256_order.cairo @@ -15,12 +15,12 @@ mod Keccak { #[abi(embed_v0)] impl Keccak of super::IKeccak { fn cairo_keccak_test(self: @ContractState) -> felt252 { - let input : Array:: = array![1,2,4,5,6,6,7,2,3,4,4,5,5,6,7,7,2]; + let input: Array:: = array![1, 2, 4, 5, 6, 6, 7, 2, 3, 4, 4, 5, 5, 6, 7, 7, 2]; let output = starknet::syscalls::keccak_syscall(input.span()).unwrap(); - if output.low == 0x9293867273ef341e81577655f28aeade && output.high == 0xf70cba9bb86caa97b086fdfa3df602ed { - panic_with_felt252('arguments swapped'); - } + assert(output.low == 0xf70cba9bb86caa97b086fdfa3df602ed, 'invalid low value'); + assert(output.high == 0x9293867273ef341e81577655f28aeade, 'invalid high value'); + output.low.into() } } diff --git a/tests/tests/starknet/secp256.rs b/tests/tests/starknet/secp256.rs index a165b56f5..28a8f9ce6 100644 --- a/tests/tests/starknet/secp256.rs +++ b/tests/tests/starknet/secp256.rs @@ -250,10 +250,10 @@ impl StarknetSyscallHandler for &mut SyscallHandler { fn sha256_process_block( &mut self, - _prev_state: &[u32; 8], - _current_block: &[u32; 16], + _state: &mut [u32; 8], + _block: &[u32; 16], _remaining_gas: &mut u128, - ) -> SyscallResult<[u32; 8]> { + ) -> SyscallResult<()> { unimplemented!() } } diff --git a/tests/tests/starknet/syscalls.rs b/tests/tests/starknet/syscalls.rs index 8f88cd0eb..fc52ee5f0 100644 --- a/tests/tests/starknet/syscalls.rs +++ b/tests/tests/starknet/syscalls.rs @@ -496,11 +496,11 @@ impl StarknetSyscallHandler for SyscallHandler { fn sha256_process_block( &mut self, - prev_state: &[u32; 8], - _current_block: &[u32; 16], + _state: &mut [u32; 8], + _block: &[u32; 16], _remaining_gas: &mut u128, - ) -> SyscallResult<[u32; 8]> { - Ok(*prev_state) + ) -> SyscallResult<()> { + Ok(()) } } diff --git a/tests/tests/starknet/u256.rs b/tests/tests/starknet/u256.rs index 0557e000d..4ac77b790 100644 --- a/tests/tests/starknet/u256.rs +++ b/tests/tests/starknet/u256.rs @@ -43,6 +43,6 @@ fn u256_test() { ); assert_eq!( result.remaining_gas, - 340282366920938463463374607431768192315 + 340282366920938463463374607431768192415 ); }