diff --git a/runtime/src/lib.rs b/runtime/src/lib.rs index 354fd4554..8f0464d08 100644 --- a/runtime/src/lib.rs +++ b/runtime/src/lib.rs @@ -20,9 +20,10 @@ use std::{ ffi::{c_int, c_void}, fs::File, io::Write, - mem::ManuallyDrop, + mem::{forget, ManuallyDrop}, os::fd::FromRawFd, ptr::{self, null, null_mut}, + rc::Rc, slice, }; use std::{ops::Mul, vec::IntoIter}; @@ -166,9 +167,96 @@ pub struct FeltDict { pub layout: Layout, pub elements: *mut (), + pub dup_fn: Option, + pub drop_fn: Option, + pub count: u64, } +impl Clone for FeltDict { + fn clone(&self) -> Self { + let mut new_dict = FeltDict { + mappings: HashMap::with_capacity(self.mappings.len()), + + layout: self.layout, + elements: if self.mappings.is_empty() { + null_mut() + } else { + unsafe { + alloc(Layout::from_size_align_unchecked( + self.layout.pad_to_align().size() * self.mappings.len(), + self.layout.align(), + )) + .cast() + } + }, + + dup_fn: self.dup_fn, + drop_fn: self.drop_fn, + + // TODO: Check if `0` is fine or otherwise we should copy the value from `old_dict` too. + count: 0, + }; + + for (&key, &old_index) in self.mappings.iter() { + let old_value_ptr = unsafe { + self.elements + .byte_add(self.layout.pad_to_align().size() * old_index) + }; + + let new_index = new_dict.mappings.len(); + let new_value_ptr = unsafe { + new_dict + .elements + .byte_add(new_dict.layout.pad_to_align().size() * new_index) + }; + + new_dict.mappings.insert(key, new_index); + match self.dup_fn { + Some(dup_fn) => dup_fn(old_value_ptr.cast(), new_value_ptr.cast()), + None => unsafe { + ptr::copy_nonoverlapping::( + old_value_ptr.cast(), + new_value_ptr.cast(), + self.layout.size(), + ) + }, + } + } + + new_dict + } +} + +impl Drop for FeltDict { + fn drop(&mut self) { + // Free the entries manually. + if let Some(drop_fn) = self.drop_fn { + for (_, &index) in self.mappings.iter() { + let value_ptr = unsafe { + self.elements + .byte_add(self.layout.pad_to_align().size() * index) + }; + + drop_fn(value_ptr.cast()); + } + } + + // Free the value data. + if !self.elements.is_null() { + unsafe { + dealloc( + self.elements.cast(), + Layout::from_size_align_unchecked( + self.layout.pad_to_align().size() * self.mappings.capacity(), + self.layout.align(), + ), + ) + }; + } + } +} + /// Allocate a new dictionary. /// /// # Safety @@ -176,21 +264,27 @@ pub struct FeltDict { /// This function is intended to be called from MLIR, deals with pointers, and is therefore /// definitely unsafe to use manually. #[no_mangle] -pub unsafe extern "C" fn cairo_native__dict_new(size: u64, align: u64) -> *mut FeltDict { - Box::into_raw(Box::new(FeltDict { +pub unsafe extern "C" fn cairo_native__dict_new( + size: u64, + align: u64, + dup_fn: Option, + drop_fn: Option, +) -> *const FeltDict { + Rc::into_raw(Rc::new(FeltDict { mappings: HashMap::default(), layout: Layout::from_size_align_unchecked(size as usize, align as usize), elements: null_mut(), + dup_fn, + drop_fn, + count: 0, })) } /// Free a dictionary using an optional callback to drop each element. /// -/// The `drop_fn` callback is present when the value implements `Drop`. -/// /// # Safety /// /// This function is intended to be called from MLIR, deals with pointers, and is therefore @@ -199,88 +293,23 @@ pub unsafe extern "C" fn cairo_native__dict_new(size: u64, align: u64) -> *mut F // pointer optimization. Check out // https://doc.rust-lang.org/nomicon/ffi.html#the-nullable-pointer-optimization for more info. #[no_mangle] -pub unsafe extern "C" fn cairo_native__dict_drop( - ptr: *mut FeltDict, - drop_fn: Option, -) { - let dict = Box::from_raw(ptr); - - // Free the entries manually. - if let Some(drop_fn) = drop_fn { - for (_, &index) in dict.mappings.iter() { - let value_ptr = dict - .elements - .byte_add(dict.layout.pad_to_align().size() * index); - - drop_fn(value_ptr.cast()); - } - } - - // Free the value data. - if !dict.elements.is_null() { - dealloc( - dict.elements.cast(), - Layout::from_size_align_unchecked( - dict.layout.pad_to_align().size() * dict.mappings.capacity(), - dict.layout.align(), - ), - ); - } +pub unsafe extern "C" fn cairo_native__dict_drop(ptr: *const FeltDict) { + drop(Rc::from_raw(ptr)); } /// Duplicate a dictionary using a provided callback to clone each element. /// -/// The `dup_fn` callback is present when the value is not `Copy`, but `Clone`. The first argument -/// is the original value while the second is the target pointer. -/// /// # Safety /// /// This function is intended to be called from MLIR, deals with pointers, and is therefore /// definitely unsafe to use manually. #[no_mangle] -pub unsafe extern "C" fn cairo_native__dict_dup( - old_dict: &FeltDict, - dup_fn: Option, -) -> *mut FeltDict { - let mut new_dict = Box::new(FeltDict { - mappings: HashMap::with_capacity(old_dict.mappings.len()), - - layout: old_dict.layout, - elements: if old_dict.mappings.is_empty() { - null_mut() - } else { - alloc(Layout::from_size_align_unchecked( - old_dict.layout.pad_to_align().size() * old_dict.mappings.len(), - old_dict.layout.align(), - )) - .cast() - }, - - // TODO: Check if `0` is fine or otherwise we should copy the value from `old_dict` too. - count: 0, - }); - - for (new_index, (&key, &old_index)) in old_dict.mappings.iter().enumerate() { - let old_value_ptr = old_dict - .elements - .byte_add(old_dict.layout.pad_to_align().size() * old_index); - - let new_value_ptr = new_dict - .elements - .byte_add(new_dict.layout.pad_to_align().size() * new_index); - - new_dict.mappings.insert(key, new_index); - match dup_fn { - Some(dup_fn) => dup_fn(old_value_ptr.cast(), new_value_ptr.cast()), - None => ptr::copy_nonoverlapping::( - old_value_ptr.cast(), - new_value_ptr.cast(), - old_dict.layout.size(), - ), - } - } +pub unsafe extern "C" fn cairo_native__dict_dup(dict_ptr: *const FeltDict) -> *const FeltDict { + let old_dict = Rc::from_raw(dict_ptr); + let new_dict = Rc::clone(&old_dict); - Box::into_raw(new_dict) + forget(old_dict); + Rc::into_raw(new_dict) } /// Return a pointer to the entry's value pointer for a given key, inserting a null pointer if not @@ -295,45 +324,46 @@ pub unsafe extern "C" fn cairo_native__dict_dup( /// definitely unsafe to use manually. #[no_mangle] pub unsafe extern "C" fn cairo_native__dict_get( - dict: &mut FeltDict, + dict: *const FeltDict, key: &[u8; 32], value_ptr: *mut *mut c_void, ) -> c_int { - let mut key = *key; - key[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). - - let old_capacity = dict.mappings.capacity(); - let index = dict.mappings.len(); - let (index, is_present) = match dict.mappings.entry(key) { - Entry::Occupied(entry) => (*entry.get(), 1), - Entry::Vacant(entry) => { - entry.insert(index); + let mut dict_rc = Rc::from_raw(dict); + let dict = Rc::make_mut(&mut dict_rc); - // Reallocate `mem_data` to match the slab's capacity. - if old_capacity != dict.mappings.capacity() { - dict.elements = realloc( - dict.elements.cast(), - Layout::from_size_align_unchecked( - dict.layout.pad_to_align().size() * old_capacity, - dict.layout.align(), - ), - dict.layout.pad_to_align().size() * dict.mappings.capacity(), - ) - .cast(); - } + let num_mappings = dict.mappings.len(); + let has_capacity = num_mappings != dict.mappings.capacity(); - (index, 0) + let (is_present, index) = match dict.mappings.entry(*key) { + Entry::Occupied(entry) => (true, *entry.get()), + Entry::Vacant(entry) => { + entry.insert(num_mappings); + (false, num_mappings) } }; - value_ptr.write( - dict.elements - .byte_add(dict.layout.pad_to_align().size() * index) - .cast(), - ); + // Maybe realloc (conditions: !has_capacity && !is_present). + if !has_capacity && !is_present { + dict.elements = realloc( + dict.elements.cast(), + Layout::from_size_align_unchecked( + dict.layout.pad_to_align().size() * dict.mappings.len(), + dict.layout.align(), + ), + dict.layout.pad_to_align().size() * dict.mappings.capacity(), + ) + .cast(); + } + + *value_ptr = dict + .elements + .byte_add(dict.layout.pad_to_align().size() * index) + .cast(); + dict.count += 1; + forget(dict_rc); - is_present + is_present as c_int } /// Compute the total gas refund for the dictionary at squash time. @@ -344,8 +374,12 @@ pub unsafe extern "C" fn cairo_native__dict_get( /// definitely unsafe to use manually. #[no_mangle] pub unsafe extern "C" fn cairo_native__dict_gas_refund(ptr: *const FeltDict) -> u64 { - let dict = &*ptr; - (dict.count.saturating_sub(dict.mappings.len() as u64)) * *DICT_GAS_REFUND_PER_ACCESS + let dict = Rc::from_raw(ptr); + let amount = + (dict.count.saturating_sub(dict.mappings.len() as u64)) * *DICT_GAS_REFUND_PER_ACCESS; + + forget(dict); + amount } /// Compute `ec_point_from_x_nz(x)` and store it. @@ -865,21 +899,27 @@ mod tests { #[test] fn test_dict() { - let dict = - unsafe { cairo_native__dict_new(size_of::() as u64, align_of::() as u64) }; + let dict = unsafe { + cairo_native__dict_new( + size_of::() as u64, + align_of::() as u64, + None, + None, + ) + }; let key = Felt::ONE.to_bytes_le(); let mut ptr = null_mut::(); assert_eq!( - unsafe { cairo_native__dict_get(&mut *dict, &key, (&raw mut ptr).cast()) }, + unsafe { cairo_native__dict_get(dict, &key, (&raw mut ptr).cast()) }, 0, ); assert!(!ptr.is_null()); unsafe { *ptr = 24 }; assert_eq!( - unsafe { cairo_native__dict_get(&mut *dict, &key, (&raw mut ptr).cast()) }, + unsafe { cairo_native__dict_get(dict, &key, (&raw mut ptr).cast()) }, 1, ); assert!(!ptr.is_null()); @@ -889,17 +929,17 @@ mod tests { let refund = unsafe { cairo_native__dict_gas_refund(dict) }; assert_eq!(refund, 4050); - let cloned_dict = unsafe { cairo_native__dict_dup(&*dict, None) }; - unsafe { cairo_native__dict_drop(dict, None) }; + let cloned_dict = unsafe { cairo_native__dict_dup(&*dict) }; + unsafe { cairo_native__dict_drop(dict) }; assert_eq!( - unsafe { cairo_native__dict_get(&mut *cloned_dict, &key, (&raw mut ptr).cast()) }, + unsafe { cairo_native__dict_get(cloned_dict, &key, (&raw mut ptr).cast()) }, 1, ); assert!(!ptr.is_null()); assert_eq!(unsafe { *ptr }, 42); - unsafe { cairo_native__dict_drop(cloned_dict, None) }; + unsafe { cairo_native__dict_drop(cloned_dict) }; } #[test] diff --git a/src/arch.rs b/src/arch.rs index ceaca4cdb..0316034e9 100644 --- a/src/arch.rs +++ b/src/arch.rs @@ -15,7 +15,10 @@ use cairo_lang_sierra::{ ids::ConcreteTypeId, program_registry::ProgramRegistry, }; -use std::ptr::{null, NonNull}; +use std::{ + ffi::c_void, + ptr::{null, NonNull}, +}; mod aarch64; mod x86_64; @@ -24,7 +27,17 @@ mod x86_64; pub trait AbiArgument { /// Serialize the argument into the buffer. This method should keep track of arch-dependent /// stuff like register vs stack allocation. - fn to_bytes(&self, buffer: &mut Vec) -> Result<()>; + fn to_bytes( + &self, + buffer: &mut Vec, + find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<()>; } /// A wrapper that implements `AbiArgument` for `Value`s. It contains all the required stuff to @@ -58,10 +71,21 @@ impl<'a> ValueWithInfoWrapper<'a> { } impl AbiArgument for ValueWithInfoWrapper<'_> { - fn to_bytes(&self, buffer: &mut Vec) -> Result<()> { + fn to_bytes( + &self, + buffer: &mut Vec, + find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<()> { match (self.value, self.info) { (value, CoreTypeConcrete::Box(info)) => { - let ptr = value.to_ptr(self.arena, self.registry, self.type_id)?; + let ptr = + value.to_ptr(self.arena, self.registry, self.type_id, find_dict_overrides)?; let layout = self.registry.get_type(&info.ty)?.layout(self.registry)?; let heap_ptr = unsafe { @@ -70,13 +94,18 @@ impl AbiArgument for ValueWithInfoWrapper<'_> { heap_ptr }; - heap_ptr.to_bytes(buffer)?; + heap_ptr.to_bytes(buffer, find_dict_overrides)?; } (value, CoreTypeConcrete::Nullable(info)) => { if matches!(value, Value::Null) { - null::<()>().to_bytes(buffer)?; + null::<()>().to_bytes(buffer, find_dict_overrides)?; } else { - let ptr = value.to_ptr(self.arena, self.registry, self.type_id)?; + let ptr = value.to_ptr( + self.arena, + self.registry, + self.type_id, + find_dict_overrides, + )?; let layout = self.registry.get_type(&info.ty)?.layout(self.registry)?; let heap_ptr = unsafe { @@ -85,51 +114,64 @@ impl AbiArgument for ValueWithInfoWrapper<'_> { heap_ptr }; - heap_ptr.to_bytes(buffer)?; + heap_ptr.to_bytes(buffer, find_dict_overrides)?; } } - (value, CoreTypeConcrete::NonZero(info) | CoreTypeConcrete::Snapshot(info)) => { - self.map(value, &info.ty)?.to_bytes(buffer)? - } + (value, CoreTypeConcrete::NonZero(info) | CoreTypeConcrete::Snapshot(info)) => self + .map(value, &info.ty)? + .to_bytes(buffer, find_dict_overrides)?, (Value::Array(_), CoreTypeConcrete::Array(_)) => { // TODO: Assert that `info.ty` matches all the values' types. - let abi_ptr = self.value.to_ptr(self.arena, self.registry, self.type_id)?; + let abi_ptr = self.value.to_ptr( + self.arena, + self.registry, + self.type_id, + find_dict_overrides, + )?; let abi = unsafe { abi_ptr.cast::>().as_ref() }; - abi.ptr.to_bytes(buffer)?; - abi.since.to_bytes(buffer)?; - abi.until.to_bytes(buffer)?; - abi.capacity.to_bytes(buffer)?; + abi.ptr.to_bytes(buffer, find_dict_overrides)?; + abi.since.to_bytes(buffer, find_dict_overrides)?; + abi.until.to_bytes(buffer, find_dict_overrides)?; + abi.capacity.to_bytes(buffer, find_dict_overrides)?; } (Value::BoundedInt { .. }, CoreTypeConcrete::BoundedInt(_)) => { native_panic!("todo: implement AbiArgument for Value::BoundedInt case") } - (Value::Bytes31(value), CoreTypeConcrete::Bytes31(_)) => value.to_bytes(buffer)?, + (Value::Bytes31(value), CoreTypeConcrete::Bytes31(_)) => { + value.to_bytes(buffer, find_dict_overrides)? + } (Value::EcPoint(x, y), CoreTypeConcrete::EcPoint(_)) => { - x.to_bytes(buffer)?; - y.to_bytes(buffer)?; + x.to_bytes(buffer, find_dict_overrides)?; + y.to_bytes(buffer, find_dict_overrides)?; } (Value::EcState(x, y, x0, y0), CoreTypeConcrete::EcState(_)) => { - x.to_bytes(buffer)?; - y.to_bytes(buffer)?; - x0.to_bytes(buffer)?; - y0.to_bytes(buffer)?; + x.to_bytes(buffer, find_dict_overrides)?; + y.to_bytes(buffer, find_dict_overrides)?; + x0.to_bytes(buffer, find_dict_overrides)?; + y0.to_bytes(buffer, find_dict_overrides)?; } (Value::Enum { tag, value, .. }, CoreTypeConcrete::Enum(info)) => { if self.info.is_memory_allocated(self.registry)? { - let abi_ptr = self.value.to_ptr(self.arena, self.registry, self.type_id)?; + let abi_ptr = self.value.to_ptr( + self.arena, + self.registry, + self.type_id, + find_dict_overrides, + )?; let abi_ptr = unsafe { *abi_ptr.cast::>().as_ref() }; - abi_ptr.as_ptr().to_bytes(buffer)?; + abi_ptr.as_ptr().to_bytes(buffer, find_dict_overrides)?; } else { match (info.variants.len().next_power_of_two().trailing_zeros() + 7) / 8 { 0 => {} - _ => (*tag as u64).to_bytes(buffer)?, + _ => (*tag as u64).to_bytes(buffer, find_dict_overrides)?, } - self.map(value, &info.variants[*tag])?.to_bytes(buffer)?; + self.map(value, &info.variants[*tag])? + .to_bytes(buffer, find_dict_overrides)?; } } ( @@ -141,7 +183,7 @@ impl AbiArgument for ValueWithInfoWrapper<'_> { | StarkNetTypeConcrete::StorageAddress(_) | StarkNetTypeConcrete::StorageBaseAddress(_), ), - ) => value.to_bytes(buffer)?, + ) => value.to_bytes(buffer, find_dict_overrides)?, (Value::Felt252Dict { .. }, CoreTypeConcrete::Felt252Dict(_)) => { #[cfg(not(feature = "with-runtime"))] native_panic!("enable the `with-runtime` feature to use felt252 dicts"); @@ -151,9 +193,9 @@ impl AbiArgument for ValueWithInfoWrapper<'_> { // TODO: Assert that `info.ty` matches all the values' types. self.value - .to_ptr(self.arena, self.registry, self.type_id)? + .to_ptr(self.arena, self.registry, self.type_id, find_dict_overrides)? .as_ptr() - .to_bytes(buffer)? + .to_bytes(buffer, find_dict_overrides)? } } ( @@ -168,27 +210,47 @@ impl AbiArgument for ValueWithInfoWrapper<'_> { Secp256PointTypeConcrete::R1(_), )), ) => { - x.to_bytes(buffer)?; - y.to_bytes(buffer)?; - is_infinity.to_bytes(buffer)?; - } - (Value::Sint128(value), CoreTypeConcrete::Sint128(_)) => value.to_bytes(buffer)?, - (Value::Sint16(value), CoreTypeConcrete::Sint16(_)) => value.to_bytes(buffer)?, - (Value::Sint32(value), CoreTypeConcrete::Sint32(_)) => value.to_bytes(buffer)?, - (Value::Sint64(value), CoreTypeConcrete::Sint64(_)) => value.to_bytes(buffer)?, - (Value::Sint8(value), CoreTypeConcrete::Sint8(_)) => value.to_bytes(buffer)?, + x.to_bytes(buffer, find_dict_overrides)?; + y.to_bytes(buffer, find_dict_overrides)?; + is_infinity.to_bytes(buffer, find_dict_overrides)?; + } + (Value::Sint128(value), CoreTypeConcrete::Sint128(_)) => { + value.to_bytes(buffer, find_dict_overrides)? + } + (Value::Sint16(value), CoreTypeConcrete::Sint16(_)) => { + value.to_bytes(buffer, find_dict_overrides)? + } + (Value::Sint32(value), CoreTypeConcrete::Sint32(_)) => { + value.to_bytes(buffer, find_dict_overrides)? + } + (Value::Sint64(value), CoreTypeConcrete::Sint64(_)) => { + value.to_bytes(buffer, find_dict_overrides)? + } + (Value::Sint8(value), CoreTypeConcrete::Sint8(_)) => { + value.to_bytes(buffer, find_dict_overrides)? + } (Value::Struct { fields, .. }, CoreTypeConcrete::Struct(info)) => { fields .iter() .zip(&info.members) .map(|(value, type_id)| self.map(value, type_id)) - .try_for_each(|wrapper| wrapper?.to_bytes(buffer))?; + .try_for_each(|wrapper| wrapper?.to_bytes(buffer, find_dict_overrides))?; + } + (Value::Uint128(value), CoreTypeConcrete::Uint128(_)) => { + value.to_bytes(buffer, find_dict_overrides)? + } + (Value::Uint16(value), CoreTypeConcrete::Uint16(_)) => { + value.to_bytes(buffer, find_dict_overrides)? + } + (Value::Uint32(value), CoreTypeConcrete::Uint32(_)) => { + value.to_bytes(buffer, find_dict_overrides)? + } + (Value::Uint64(value), CoreTypeConcrete::Uint64(_)) => { + value.to_bytes(buffer, find_dict_overrides)? + } + (Value::Uint8(value), CoreTypeConcrete::Uint8(_)) => { + value.to_bytes(buffer, find_dict_overrides)? } - (Value::Uint128(value), CoreTypeConcrete::Uint128(_)) => value.to_bytes(buffer)?, - (Value::Uint16(value), CoreTypeConcrete::Uint16(_)) => value.to_bytes(buffer)?, - (Value::Uint32(value), CoreTypeConcrete::Uint32(_)) => value.to_bytes(buffer)?, - (Value::Uint64(value), CoreTypeConcrete::Uint64(_)) => value.to_bytes(buffer)?, - (Value::Uint8(value), CoreTypeConcrete::Uint8(_)) => value.to_bytes(buffer)?, _ => native_panic!( "todo: abi argument unimplemented for ({:?}, {:?})", self.value, diff --git a/src/arch/aarch64.rs b/src/arch/aarch64.rs index ad699033c..2ab529944 100644 --- a/src/arch/aarch64.rs +++ b/src/arch/aarch64.rs @@ -11,15 +11,27 @@ use super::AbiArgument; use crate::{error::Error, starknet::U256, utils::get_integer_layout}; +use cairo_lang_sierra::ids::ConcreteTypeId; use num_traits::ToBytes; use starknet_types_core::felt::Felt; +use std::ffi::c_void; fn align_to(buffer: &mut Vec, align: usize) { buffer.resize(buffer.len().next_multiple_of(align), 0); } impl AbiArgument for bool { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { if buffer.len() < 64 { buffer.extend_from_slice(&(*self as u64).to_ne_bytes()); } else { @@ -31,7 +43,17 @@ impl AbiArgument for bool { } impl AbiArgument for u8 { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { if buffer.len() < 64 { buffer.extend_from_slice(&(*self as u64).to_ne_bytes()); } else { @@ -43,7 +65,17 @@ impl AbiArgument for u8 { } impl AbiArgument for i8 { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { if buffer.len() < 64 { buffer.extend_from_slice(&(*self as u64).to_ne_bytes()); } else { @@ -55,7 +87,17 @@ impl AbiArgument for i8 { } impl AbiArgument for u16 { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { if buffer.len() < 64 { buffer.extend_from_slice(&(*self as u64).to_ne_bytes()); } else { @@ -67,7 +109,17 @@ impl AbiArgument for u16 { } impl AbiArgument for i16 { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { if buffer.len() < 64 { buffer.extend_from_slice(&(*self as u64).to_ne_bytes()); } else { @@ -79,7 +131,17 @@ impl AbiArgument for i16 { } impl AbiArgument for u32 { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { if buffer.len() < 64 { buffer.extend_from_slice(&(*self as u64).to_ne_bytes()); } else { @@ -91,7 +153,17 @@ impl AbiArgument for u32 { } impl AbiArgument for i32 { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { if buffer.len() < 64 { buffer.extend_from_slice(&(*self as u64).to_ne_bytes()); } else { @@ -103,7 +175,17 @@ impl AbiArgument for i32 { } impl AbiArgument for u64 { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { if buffer.len() >= 64 { align_to(buffer, get_integer_layout(64).align()); } @@ -113,7 +195,17 @@ impl AbiArgument for u64 { } impl AbiArgument for i64 { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { if buffer.len() >= 64 { align_to(buffer, get_integer_layout(64).align()); } @@ -123,7 +215,17 @@ impl AbiArgument for i64 { } impl AbiArgument for u128 { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { if buffer.len() >= 56 { align_to(buffer, get_integer_layout(128).align()); } @@ -133,7 +235,17 @@ impl AbiArgument for u128 { } impl AbiArgument for i128 { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { if buffer.len() >= 56 { align_to(buffer, get_integer_layout(128).align()); } @@ -143,7 +255,17 @@ impl AbiArgument for i128 { } impl AbiArgument for Felt { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { if buffer.len() >= 56 { align_to(buffer, get_integer_layout(252).align()); } @@ -153,14 +275,34 @@ impl AbiArgument for Felt { } impl AbiArgument for U256 { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { - self.lo.to_bytes(buffer)?; - self.hi.to_bytes(buffer) + fn to_bytes( + &self, + buffer: &mut Vec, + find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { + self.lo.to_bytes(buffer, find_dict_overrides)?; + self.hi.to_bytes(buffer, find_dict_overrides) } } impl AbiArgument for [u8; 31] { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { // The `bytes31` type is treated as a 248-bit integer, therefore it follows the same // splitting rules as them. if buffer.len() >= 56 { @@ -173,14 +315,34 @@ impl AbiArgument for [u8; 31] { } impl AbiArgument for *const T { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { - ::to_bytes(&(*self as u64), buffer) + fn to_bytes( + &self, + buffer: &mut Vec, + find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { + ::to_bytes(&(*self as u64), buffer, find_dict_overrides) } } impl AbiArgument for *mut T { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { - ::to_bytes(&(*self as u64), buffer) + fn to_bytes( + &self, + buffer: &mut Vec, + find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { + ::to_bytes(&(*self as u64), buffer, find_dict_overrides) } } @@ -192,12 +354,12 @@ mod test { fn u8_to_bytes() { // Buffer initially empty let mut buffer = vec![]; - u8::MAX.to_bytes(&mut buffer).unwrap(); + u8::MAX.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!(buffer, [u8::MAX, 0, 0, 0, 0, 0, 0, 0]); // Buffer initially filled with 70 zeros (len > 64) let mut buffer = vec![0; 70]; - u8::MAX.to_bytes(&mut buffer).unwrap(); + u8::MAX.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!( buffer, [0; 70].into_iter().chain([u8::MAX]).collect::>() @@ -208,17 +370,17 @@ mod test { fn i8_to_bytes() { // Buffer initially empty let mut buffer = vec![]; - i8::MAX.to_bytes(&mut buffer).unwrap(); + i8::MAX.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!(buffer, [i8::MAX as u8, 0, 0, 0, 0, 0, 0, 0]); // Buffer initially empty with negative value let mut buffer = vec![]; - i8::MIN.to_bytes(&mut buffer).unwrap(); + i8::MIN.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!(buffer, [128, 255, 255, 255, 255, 255, 255, 255]); // Buffer initially filled with 70 zeros (len > 64) let mut buffer = vec![0; 70]; - i8::MAX.to_bytes(&mut buffer).unwrap(); + i8::MAX.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!( buffer, [0; 70] @@ -229,7 +391,7 @@ mod test { // Buffer initially filled with 70 zeros (len > 64) and negative value let mut buffer = vec![0; 70]; - i8::MIN.to_bytes(&mut buffer).unwrap(); + i8::MIN.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!(buffer, [0; 70].into_iter().chain([128]).collect::>()); } @@ -237,12 +399,12 @@ mod test { fn u16_to_bytes() { // Buffer initially empty let mut buffer = vec![]; - u16::MAX.to_bytes(&mut buffer).unwrap(); + u16::MAX.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!(buffer, vec![u8::MAX, u8::MAX, 0, 0, 0, 0, 0, 0]); // Buffer initially filled with 70 zeros (len > 64) let mut buffer = vec![0; 70]; - u16::MAX.to_bytes(&mut buffer).unwrap(); + u16::MAX.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!( buffer, [0; 70] @@ -256,17 +418,17 @@ mod test { fn i16_to_bytes() { // Buffer initially empty let mut buffer = vec![]; - i16::MAX.to_bytes(&mut buffer).unwrap(); + i16::MAX.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!(buffer, vec![u8::MAX, i8::MAX as u8, 0, 0, 0, 0, 0, 0]); // Buffer initially empty with negative value let mut buffer = vec![]; - i16::MIN.to_bytes(&mut buffer).unwrap(); + i16::MIN.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!(buffer, [0, 128, 255, 255, 255, 255, 255, 255]); // Buffer initially filled with 70 zeros (len > 64) let mut buffer = vec![0; 70]; - i16::MAX.to_bytes(&mut buffer).unwrap(); + i16::MAX.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!( buffer, [0; 70] @@ -277,7 +439,7 @@ mod test { // Buffer initially filled with 70 zeros (len > 64) and negative value let mut buffer = vec![0; 70]; - i16::MIN.to_bytes(&mut buffer).unwrap(); + i16::MIN.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!( buffer, [0; 70].into_iter().chain([0, 128]).collect::>() @@ -288,7 +450,7 @@ mod test { fn u32_to_bytes() { // Buffer initially empty let mut buffer = vec![]; - u32::MAX.to_bytes(&mut buffer).unwrap(); + u32::MAX.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!( buffer, vec![u8::MAX; 4] @@ -299,7 +461,7 @@ mod test { // Buffer initially filled with 70 zeros (len > 64) let mut buffer = vec![0; 70]; - u32::MAX.to_bytes(&mut buffer).unwrap(); + u32::MAX.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!( buffer, [0; 72] @@ -313,7 +475,7 @@ mod test { fn i32_to_bytes() { // Buffer initially empty let mut buffer = vec![]; - i32::MAX.to_bytes(&mut buffer).unwrap(); + i32::MAX.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!( buffer, vec![u8::MAX, u8::MAX, u8::MAX, i8::MAX as u8, 0, 0, 0, 0] @@ -321,12 +483,12 @@ mod test { // Buffer initially empty with negative value let mut buffer = vec![]; - i32::MIN.to_bytes(&mut buffer).unwrap(); + i32::MIN.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!(buffer, [0, 0, 0, 128, 255, 255, 255, 255]); // Buffer initially filled with 70 zeros (len > 64) let mut buffer = vec![0; 70]; - i32::MAX.to_bytes(&mut buffer).unwrap(); + i32::MAX.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!( buffer, [0; 72] @@ -337,7 +499,7 @@ mod test { // Buffer initially filled with 70 zeros (len > 64) and negative value let mut buffer = vec![0; 70]; - i32::MIN.to_bytes(&mut buffer).unwrap(); + i32::MIN.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!( buffer, [0; 72] @@ -351,12 +513,12 @@ mod test { fn u64_to_bytes() { // Buffer initially empty let mut buffer = vec![]; - u64::MAX.to_bytes(&mut buffer).unwrap(); + u64::MAX.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!(buffer, u64::MAX.to_ne_bytes().to_vec()); // Buffer initially filled with 70 zeros (len > 64) let mut buffer = vec![0; 70]; - u64::MAX.to_bytes(&mut buffer).unwrap(); + u64::MAX.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!( buffer, [0; 72] @@ -370,17 +532,17 @@ mod test { fn i64_to_bytes() { // Buffer initially empty let mut buffer = vec![]; - i64::MAX.to_bytes(&mut buffer).unwrap(); + i64::MAX.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!(buffer, i64::MAX.to_ne_bytes().to_vec()); // Buffer initially empty with negative value let mut buffer = vec![]; - i64::MIN.to_bytes(&mut buffer).unwrap(); + i64::MIN.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!(buffer, i64::MIN.to_ne_bytes().to_vec()); // Buffer initially filled with 70 zeros (len > 64) let mut buffer = vec![0; 70]; - i64::MAX.to_bytes(&mut buffer).unwrap(); + i64::MAX.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!( buffer, [0; 72] @@ -391,7 +553,7 @@ mod test { // Buffer initially filled with 70 zeros (len > 64) and negative value let mut buffer = vec![0; 70]; - i64::MIN.to_bytes(&mut buffer).unwrap(); + i64::MIN.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!( buffer, [0; 72] @@ -404,7 +566,7 @@ mod test { #[test] fn u128_stack_split() { let mut buffer = vec![0; 56]; - u128::MAX.to_bytes(&mut buffer).unwrap(); + u128::MAX.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!( buffer, [0; 64].into_iter().chain([0xFF; 16]).collect::>() @@ -417,7 +579,7 @@ mod test { let mut buffer = vec![0; 40]; Felt::from_hex("0x00ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") .unwrap() - .to_bytes(&mut buffer) + .to_bytes(&mut buffer, |_| unreachable!()) .unwrap(); assert_eq!( buffer, @@ -432,7 +594,7 @@ mod test { let mut buffer = vec![0; 48]; Felt::from_hex("0x00ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") .unwrap() - .to_bytes(&mut buffer) + .to_bytes(&mut buffer, |_| unreachable!()) .unwrap(); assert_eq!( buffer, @@ -447,7 +609,7 @@ mod test { let mut buffer = vec![0; 56]; Felt::from_hex("0x00ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") .unwrap() - .to_bytes(&mut buffer) + .to_bytes(&mut buffer, |_| unreachable!()) .unwrap(); assert_eq!( buffer, diff --git a/src/arch/x86_64.rs b/src/arch/x86_64.rs index 2eea9ec65..ec82c60bf 100644 --- a/src/arch/x86_64.rs +++ b/src/arch/x86_64.rs @@ -11,78 +11,180 @@ use super::AbiArgument; use crate::{error::Error, starknet::U256, utils::get_integer_layout}; +use cairo_lang_sierra::ids::ConcreteTypeId; use num_traits::ToBytes; use starknet_types_core::felt::Felt; +use std::ffi::c_void; fn align_to(buffer: &mut Vec, align: usize) { buffer.resize(buffer.len().next_multiple_of(align), 0); } impl AbiArgument for bool { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { buffer.extend_from_slice(&(*self as u64).to_ne_bytes()); Ok(()) } } impl AbiArgument for u8 { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { buffer.extend_from_slice(&(*self as u64).to_ne_bytes()); Ok(()) } } impl AbiArgument for i8 { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { buffer.extend_from_slice(&(*self as u64).to_ne_bytes()); Ok(()) } } impl AbiArgument for u16 { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { buffer.extend_from_slice(&(*self as u64).to_ne_bytes()); Ok(()) } } impl AbiArgument for i16 { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { buffer.extend_from_slice(&(*self as u64).to_ne_bytes()); Ok(()) } } impl AbiArgument for u32 { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { buffer.extend_from_slice(&(*self as u64).to_ne_bytes()); Ok(()) } } impl AbiArgument for i32 { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { buffer.extend_from_slice(&(*self as u64).to_ne_bytes()); Ok(()) } } impl AbiArgument for u64 { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { buffer.extend_from_slice(&self.to_ne_bytes()); Ok(()) } } impl AbiArgument for i64 { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { buffer.extend_from_slice(&self.to_ne_bytes()); Ok(()) } } impl AbiArgument for u128 { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { if buffer.len() >= 40 { align_to(buffer, get_integer_layout(128).align()); } @@ -93,7 +195,17 @@ impl AbiArgument for u128 { } impl AbiArgument for i128 { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { if buffer.len() >= 40 { align_to(buffer, get_integer_layout(128).align()); } @@ -104,7 +216,17 @@ impl AbiArgument for i128 { } impl AbiArgument for Felt { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { if buffer.len() >= 40 { align_to(buffer, get_integer_layout(252).align()); } @@ -115,14 +237,34 @@ impl AbiArgument for Felt { } impl AbiArgument for U256 { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { - self.lo.to_bytes(buffer)?; - self.hi.to_bytes(buffer) + fn to_bytes( + &self, + buffer: &mut Vec, + find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { + self.lo.to_bytes(buffer, find_dict_overrides)?; + self.hi.to_bytes(buffer, find_dict_overrides) } } impl AbiArgument for [u8; 31] { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { + fn to_bytes( + &self, + buffer: &mut Vec, + _find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { // The `bytes31` type is treated as a 248-bit integer, therefore it follows the same // splitting rules as them. @@ -137,14 +279,34 @@ impl AbiArgument for [u8; 31] { } impl AbiArgument for *const T { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { - ::to_bytes(&(*self as u64), buffer) + fn to_bytes( + &self, + buffer: &mut Vec, + find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { + ::to_bytes(&(*self as u64), buffer, find_dict_overrides) } } impl AbiArgument for *mut T { - fn to_bytes(&self, buffer: &mut Vec) -> Result<(), Error> { - ::to_bytes(&(*self as u64), buffer) + fn to_bytes( + &self, + buffer: &mut Vec, + find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), + ) -> Result<(), Error> { + ::to_bytes(&(*self as u64), buffer, find_dict_overrides) } } @@ -155,7 +317,7 @@ mod test { #[test] fn u128_stack_split() { let mut buffer = vec![0; 40]; - u128::MAX.to_bytes(&mut buffer).unwrap(); + u128::MAX.to_bytes(&mut buffer, |_| unreachable!()).unwrap(); assert_eq!( buffer, [0; 48].into_iter().chain([0xFF; 16]).collect::>() @@ -168,7 +330,7 @@ mod test { let mut buffer = vec![0; 24]; Felt::from_hex("0x00ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") .unwrap() - .to_bytes(&mut buffer) + .to_bytes(&mut buffer, |_| unreachable!()) .unwrap(); assert_eq!( buffer, @@ -183,7 +345,7 @@ mod test { let mut buffer = vec![0; 32]; Felt::from_hex("0x00ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") .unwrap() - .to_bytes(&mut buffer) + .to_bytes(&mut buffer, |_| unreachable!()) .unwrap(); assert_eq!( buffer, @@ -198,7 +360,7 @@ mod test { let mut buffer = vec![0; 40]; Felt::from_hex("0x00ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") .unwrap() - .to_bytes(&mut buffer) + .to_bytes(&mut buffer, |_| unreachable!()) .unwrap(); assert_eq!( buffer, diff --git a/src/bin/cairo-native-stress/main.rs b/src/bin/cairo-native-stress/main.rs index cf5471b1e..167ef1869 100644 --- a/src/bin/cairo-native-stress/main.rs +++ b/src/bin/cairo-native-stress/main.rs @@ -307,7 +307,12 @@ where } }; - let executor = AotNativeExecutor::new(shared_library, registry, metadata); + let executor = AotNativeExecutor::new( + shared_library, + registry, + metadata, + native_module.metadata().get().cloned().unwrap_or_default(), + ); let executor = Arc::new(executor); self.cache.insert(key, executor.clone()); diff --git a/src/cache/aot.rs b/src/cache/aot.rs index 28aaa6a91..f939b131f 100644 --- a/src/cache/aot.rs +++ b/src/cache/aot.rs @@ -1,7 +1,7 @@ use crate::error::{Error, Result}; use crate::{ - context::NativeContext, executor::AotNativeExecutor, metadata::gas::GasMetadata, - module::NativeModule, utils::SHARED_LIBRARY_EXT, OptLevel, + context::NativeContext, executor::AotNativeExecutor, module::NativeModule, + utils::SHARED_LIBRARY_EXT, OptLevel, }; use cairo_lang_sierra::program::Program; use libloading::Library; @@ -44,7 +44,7 @@ where let NativeModule { module, registry, - metadata, + mut metadata, } = self .context .compile(program, false, Some(Default::default()))?; @@ -64,10 +64,8 @@ where let executor = AotNativeExecutor::new( shared_library, registry, - metadata - .get::() - .cloned() - .ok_or(Error::MissingMetadata)?, + metadata.remove().ok_or(Error::MissingMetadata)?, + metadata.remove().unwrap_or_default(), ); let executor = Arc::new(executor); diff --git a/src/executor.rs b/src/executor.rs index 26f95d761..8aaf2d9fe 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -63,6 +63,7 @@ extern "C" { /// constructs the function call in place. /// /// To pass the arguments, they are stored in a arena. +#[allow(clippy::too_many_arguments)] fn invoke_dynamic( registry: &ProgramRegistry, function_ptr: *const c_void, @@ -71,6 +72,13 @@ fn invoke_dynamic( args: &[Value], gas: u64, mut syscall_handler: Option, + find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), ) -> Result { tracing::info!("Invoking function with signature: {function_signature:?}."); let arena = Bump::new(); @@ -116,7 +124,9 @@ fn invoke_dynamic( })?; let return_ptr = arena.alloc_layout(layout).cast::<()>(); - return_ptr.as_ptr().to_bytes(&mut invoke_data)?; + return_ptr + .as_ptr() + .to_bytes(&mut invoke_data, |_| unreachable!())?; Some(return_ptr) } else { @@ -164,19 +174,23 @@ fn invoke_dynamic( // Process gas requirements and syscall handler. match type_info { - CoreTypeConcrete::GasBuiltin(_) => gas.to_bytes(&mut invoke_data)?, + CoreTypeConcrete::GasBuiltin(_) => { + gas.to_bytes(&mut invoke_data, |_| unreachable!())? + } CoreTypeConcrete::StarkNet(StarkNetTypeConcrete::System(_)) => { let syscall_handler = syscall_handler .as_mut() .to_native_assert_error("syscall handler should be available")?; (syscall_handler as *mut StarknetSyscallHandlerCallbacks<_>) - .to_bytes(&mut invoke_data)?; + .to_bytes(&mut invoke_data, |_| unreachable!())?; } CoreTypeConcrete::BuiltinCosts(_) => { - builtin_costs.to_bytes(&mut invoke_data)?; + builtin_costs.to_bytes(&mut invoke_data, |_| unreachable!())?; + } + type_info if type_info.is_builtin() => { + 0u64.to_bytes(&mut invoke_data, |_| unreachable!())? } - type_info if type_info.is_builtin() => 0u64.to_bytes(&mut invoke_data)?, type_info => ValueWithInfoWrapper { value: iter .next() @@ -187,7 +201,7 @@ fn invoke_dynamic( arena: &arena, registry, } - .to_bytes(&mut invoke_data)?, + .to_bytes(&mut invoke_data, find_dict_overrides)?, } } diff --git a/src/executor/aot.rs b/src/executor/aot.rs index 14a53280f..257bb775a 100644 --- a/src/executor/aot.rs +++ b/src/executor/aot.rs @@ -1,7 +1,7 @@ use crate::{ error::Error, execution_result::{ContractExecutionResult, ExecutionResult}, - metadata::gas::GasMetadata, + metadata::{felt252_dict::Felt252DictOverrides, gas::GasMetadata}, module::NativeModule, starknet::{DummySyscallHandler, StarknetSyscallHandler}, utils::generate_function_name, @@ -10,7 +10,7 @@ use crate::{ }; use cairo_lang_sierra::{ extensions::core::{CoreLibfunc, CoreType}, - ids::FunctionId, + ids::{ConcreteTypeId, FunctionId}, program::FunctionSignature, program_registry::ProgramRegistry, }; @@ -18,7 +18,7 @@ use educe::Educe; use libc::c_void; use libloading::Library; use starknet_types_core::felt::Felt; -use std::io; +use std::{io, mem::transmute}; use tempfile::NamedTempFile; #[derive(Educe)] @@ -30,6 +30,7 @@ pub struct AotNativeExecutor { registry: ProgramRegistry, gas_metadata: GasMetadata, + dict_overrides: Felt252DictOverrides, } unsafe impl Send for AotNativeExecutor {} @@ -40,11 +41,13 @@ impl AotNativeExecutor { library: Library, registry: ProgramRegistry, gas_metadata: GasMetadata, + dict_overrides: Felt252DictOverrides, ) -> Self { Self { library, registry, gas_metadata, + dict_overrides, } } @@ -68,6 +71,7 @@ impl AotNativeExecutor { library: unsafe { Library::new(&library_path)? }, registry, gas_metadata: metadata.remove().ok_or(Error::MissingMetadata)?, + dict_overrides: metadata.remove().unwrap_or_default(), }) } @@ -101,6 +105,7 @@ impl AotNativeExecutor { args, available_gas, Option::::None, + self.build_find_dict_overrides(), ) } @@ -135,6 +140,7 @@ impl AotNativeExecutor { args, available_gas, Some(syscall_handler), + self.build_find_dict_overrides(), ) } @@ -174,6 +180,7 @@ impl AotNativeExecutor { }], available_gas, Some(syscall_handler), + self.build_find_dict_overrides(), )?) } @@ -203,6 +210,30 @@ impl AotNativeExecutor { fn extract_signature(&self, function_id: &FunctionId) -> Result<&FunctionSignature, Error> { Ok(&self.registry.get_function(function_id)?.signature) } + + fn build_find_dict_overrides( + &self, + ) -> impl '_ + + Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ) { + |type_id| { + ( + self.dict_overrides + .get_dup_fn(type_id) + .and_then(|symbol| self.find_symbol_ptr(symbol)) + .map(|ptr| unsafe { transmute(ptr as *const ()) }), + self.dict_overrides + .get_drop_fn(type_id) + .and_then(|symbol| self.find_symbol_ptr(symbol)) + .map(|ptr| unsafe { transmute(ptr as *const ()) }), + ) + } + } } #[cfg(test)] diff --git a/src/executor/contract.rs b/src/executor/contract.rs index 04ff164e7..f0a9c17a8 100644 --- a/src/executor/contract.rs +++ b/src/executor/contract.rs @@ -344,25 +344,29 @@ impl AotContractExecutor { Layout::from_size_align_unchecked(128 + builtins_size, 16) }); - return_ptr.as_ptr().to_bytes(&mut invoke_data)?; + return_ptr + .as_ptr() + .to_bytes(&mut invoke_data, |_| unreachable!())?; let mut syscall_handler = StarknetSyscallHandlerCallbacks::new(&mut syscall_handler); for b in &self.contract_info.entry_points_info[&function_id.id].builtins { match b { BuiltinType::Gas => { - gas.to_bytes(&mut invoke_data)?; + gas.to_bytes(&mut invoke_data, |_| unreachable!())?; } BuiltinType::BuiltinCosts => { // todo: check if valid - builtin_costs.as_ptr().to_bytes(&mut invoke_data)?; + builtin_costs + .as_ptr() + .to_bytes(&mut invoke_data, |_| unreachable!())?; } BuiltinType::System => { (&mut syscall_handler as *mut StarknetSyscallHandlerCallbacks<_>) - .to_bytes(&mut invoke_data)?; + .to_bytes(&mut invoke_data, |_| unreachable!())?; } _ => { - 0u64.to_bytes(&mut invoke_data)?; + 0u64.to_bytes(&mut invoke_data, |_| unreachable!())?; } } } @@ -390,15 +394,15 @@ impl AotContractExecutor { .try_into() .to_native_assert_error("number of arguments should fit into a u32")?; - ptr.to_bytes(&mut invoke_data)?; + ptr.to_bytes(&mut invoke_data, |_| unreachable!())?; if cfg!(target_arch = "aarch64") { - 0u32.to_bytes(&mut invoke_data)?; // start - len.to_bytes(&mut invoke_data)?; // end - len.to_bytes(&mut invoke_data)?; // cap + 0u32.to_bytes(&mut invoke_data, |_| unreachable!())?; // start + len.to_bytes(&mut invoke_data, |_| unreachable!())?; // end + len.to_bytes(&mut invoke_data, |_| unreachable!())?; // cap } else if cfg!(target_arch = "x86_64") { - (0u32 as u64).to_bytes(&mut invoke_data)?; // start - (len as u64).to_bytes(&mut invoke_data)?; // end - (len as u64).to_bytes(&mut invoke_data)?; // cap + (0u32 as u64).to_bytes(&mut invoke_data, |_| unreachable!())?; // start + (len as u64).to_bytes(&mut invoke_data, |_| unreachable!())?; // end + (len as u64).to_bytes(&mut invoke_data, |_| unreachable!())?; // cap } else { unreachable!("unsupported architecture"); } diff --git a/src/executor/jit.rs b/src/executor/jit.rs index ad707d994..6697e7f5e 100644 --- a/src/executor/jit.rs +++ b/src/executor/jit.rs @@ -1,7 +1,7 @@ use crate::{ error::Error, execution_result::{ContractExecutionResult, ExecutionResult}, - metadata::gas::GasMetadata, + metadata::{felt252_dict::Felt252DictOverrides, gas::GasMetadata}, module::NativeModule, starknet::{DummySyscallHandler, StarknetSyscallHandler}, utils::{create_engine, generate_function_name}, @@ -10,13 +10,14 @@ use crate::{ }; use cairo_lang_sierra::{ extensions::core::{CoreLibfunc, CoreType}, - ids::FunctionId, + ids::{ConcreteTypeId, FunctionId}, program::FunctionSignature, program_registry::ProgramRegistry, }; use libc::c_void; use melior::{ir::Module, ExecutionEngine}; use starknet_types_core::felt::Felt; +use std::mem::transmute; /// A MLIR JIT execution engine in the context of Cairo Native. pub struct JitNativeExecutor<'m> { @@ -26,6 +27,7 @@ pub struct JitNativeExecutor<'m> { registry: ProgramRegistry, gas_metadata: GasMetadata, + dict_overrides: Felt252DictOverrides, } unsafe impl Send for JitNativeExecutor<'_> {} @@ -48,17 +50,15 @@ impl<'m> JitNativeExecutor<'m> { let NativeModule { module, registry, - metadata, + mut metadata, } = native_module; Ok(Self { engine: create_engine(&module, &metadata, opt_level), module, registry, - gas_metadata: metadata - .get::() - .cloned() - .ok_or(Error::MissingMetadata)?, + gas_metadata: metadata.remove().ok_or(Error::MissingMetadata)?, + dict_overrides: metadata.remove().unwrap_or_default(), }) } @@ -93,6 +93,7 @@ impl<'m> JitNativeExecutor<'m> { args, available_gas, Option::::None, + self.build_find_dict_overrides(), ) } @@ -120,6 +121,7 @@ impl<'m> JitNativeExecutor<'m> { args, available_gas, Some(syscall_handler), + self.build_find_dict_overrides(), ) } @@ -151,6 +153,7 @@ impl<'m> JitNativeExecutor<'m> { }], available_gas, Some(syscall_handler), + self.build_find_dict_overrides(), )?) } @@ -178,4 +181,28 @@ impl<'m> JitNativeExecutor<'m> { .get_function(function_id) .map(|func| &func.signature)?) } + + fn build_find_dict_overrides( + &self, + ) -> impl '_ + + Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ) { + |type_id| { + ( + self.dict_overrides + .get_dup_fn(type_id) + .and_then(|symbol| self.find_symbol_ptr(symbol)) + .map(|ptr| unsafe { transmute(ptr as *const ()) }), + self.dict_overrides + .get_drop_fn(type_id) + .and_then(|symbol| self.find_symbol_ptr(symbol)) + .map(|ptr| unsafe { transmute(ptr as *const ()) }), + ) + } + } } diff --git a/src/libfuncs/felt252_dict.rs b/src/libfuncs/felt252_dict.rs index f0303c2a7..09e9967db 100644 --- a/src/libfuncs/felt252_dict.rs +++ b/src/libfuncs/felt252_dict.rs @@ -3,7 +3,9 @@ use super::LibfuncHelper; use crate::{ error::Result, - metadata::{runtime_bindings::RuntimeBindingsMeta, MetadataStorage}, + metadata::{ + felt252_dict::Felt252DictOverrides, runtime_bindings::RuntimeBindingsMeta, MetadataStorage, + }, native_panic, types::TypeBuilder, utils::BlockExt, @@ -17,6 +19,7 @@ use cairo_lang_sierra::{ program_registry::ProgramRegistry, }; use melior::{ + dialect::{llvm, ods}, ir::{Block, Location}, Context, }; @@ -52,20 +55,71 @@ pub fn build_new<'ctx, 'this>( ) -> Result<()> { let segment_arena = super::increment_builtin_counter(context, entry, location, entry.arg(0)?)?; - let runtime_bindings = metadata - .get_mut::() - .expect("Runtime library not available."); - let value_type_id = match registry.get_type(&info.signature.branch_signatures[0].vars[1].ty)? { CoreTypeConcrete::Felt252Dict(info) => &info.ty, _ => native_panic!("entered unreachable code"), }; + let (dup_fn, drop_fn) = { + let mut dict_overrides = metadata + .remove::() + .unwrap_or_default(); + + let dup_fn = match dict_overrides.build_dup_fn( + context, + helper, + registry, + metadata, + value_type_id, + )? { + Some(dup_fn) => Some( + entry.append_op_result( + ods::llvm::mlir_addressof( + context, + llvm::r#type::pointer(context, 0), + dup_fn, + location, + ) + .into(), + )?, + ), + None => None, + }; + let drop_fn = match dict_overrides.build_drop_fn( + context, + helper, + registry, + metadata, + value_type_id, + )? { + Some(drop_fn_symbol) => Some( + entry.append_op_result( + ods::llvm::mlir_addressof( + context, + llvm::r#type::pointer(context, 0), + drop_fn_symbol, + location, + ) + .into(), + )?, + ), + None => None, + }; + + metadata.insert(dict_overrides); + (dup_fn, drop_fn) + }; + + let runtime_bindings = metadata + .get_mut::() + .expect("Runtime library not available."); let dict_ptr = runtime_bindings.dict_new( context, helper, entry, location, + dup_fn, + drop_fn, registry.get_type(value_type_id)?.layout(registry)?, )?; diff --git a/src/metadata.rs b/src/metadata.rs index 6f5038516..9a05f369a 100644 --- a/src/metadata.rs +++ b/src/metadata.rs @@ -19,6 +19,7 @@ pub mod debug_utils; pub mod drop_overrides; pub mod dup_overrides; pub mod enum_snapshot_variants; +pub mod felt252_dict; pub mod gas; pub mod realloc_bindings; pub mod runtime_bindings; diff --git a/src/metadata/felt252_dict.rs b/src/metadata/felt252_dict.rs new file mode 100644 index 000000000..ac9109660 --- /dev/null +++ b/src/metadata/felt252_dict.rs @@ -0,0 +1,163 @@ +use super::{drop_overrides::DropOverridesMeta, dup_overrides::DupOverridesMeta, MetadataStorage}; +use crate::{ + error::{Error, Result}, + utils::{BlockExt, ProgramRegistryExt}, +}; +use cairo_lang_sierra::{ + extensions::core::{CoreLibfunc, CoreType}, + ids::ConcreteTypeId, + program_registry::ProgramRegistry, +}; +use melior::{ + dialect::llvm, + ir::{ + attribute::{FlatSymbolRefAttribute, StringAttribute, TypeAttribute}, + Attribute, Block, Identifier, Location, Module, Region, + }, + Context, +}; +use std::collections::{hash_map::Entry, HashMap}; + +#[derive(Clone, Debug, Default)] +pub struct Felt252DictOverrides { + dup_overrides: HashMap, + drop_overrides: HashMap, +} + +impl Felt252DictOverrides { + pub fn get_dup_fn(&self, type_id: &ConcreteTypeId) -> Option<&str> { + self.dup_overrides.get(type_id).map(String::as_str) + } + + pub fn get_drop_fn(&self, type_id: &ConcreteTypeId) -> Option<&str> { + self.drop_overrides.get(type_id).map(String::as_str) + } + + pub fn build_dup_fn<'ctx>( + &mut self, + context: &'ctx Context, + module: &Module<'ctx>, + registry: &ProgramRegistry, + metadata: &mut MetadataStorage, + type_id: &ConcreteTypeId, + ) -> Result>> { + let location = Location::unknown(context); + + let inner_ty = registry.build_type(context, module, metadata, type_id)?; + Ok(match metadata.get::() { + Some(dup_overrides_meta) if dup_overrides_meta.is_overriden(type_id) => { + let dup_fn_symbol = format!("dup${}$item", type_id.id); + let flat_symbol_ref = FlatSymbolRefAttribute::new(context, &dup_fn_symbol); + + if let Entry::Vacant(entry) = self.dup_overrides.entry(type_id.clone()) { + let dup_fn_symbol = entry.insert(dup_fn_symbol); + + let region = Region::new(); + let entry = region.append_block(Block::new(&[ + (llvm::r#type::pointer(context, 0), location), + (llvm::r#type::pointer(context, 0), location), + ])); + + let source_ptr = entry.arg(0)?; + let target_ptr = entry.arg(1)?; + + let value = entry.load(context, location, source_ptr, inner_ty)?; + let values = dup_overrides_meta + .invoke_override(context, &entry, location, type_id, value)?; + entry.store(context, location, source_ptr, values.0)?; + entry.store(context, location, target_ptr, values.1)?; + + entry.append_operation(llvm::r#return(None, location)); + + module.body().append_operation(llvm::func( + context, + StringAttribute::new(context, dup_fn_symbol), + TypeAttribute::new(llvm::r#type::function( + llvm::r#type::void(context), + &[ + llvm::r#type::pointer(context, 0), + llvm::r#type::pointer(context, 0), + ], + false, + )), + region, + &[ + ( + Identifier::new(context, "sym_visibility"), + StringAttribute::new(context, "public").into(), + ), + ( + Identifier::new(context, "linkage"), + Attribute::parse(context, "#llvm.linkage") + .ok_or(Error::ParseAttributeError)?, + ), + ], + location, + )); + } + + Some(flat_symbol_ref) + } + _ => None, + }) + } + + pub fn build_drop_fn<'ctx>( + &mut self, + context: &'ctx Context, + module: &Module<'ctx>, + registry: &ProgramRegistry, + metadata: &mut MetadataStorage, + type_id: &ConcreteTypeId, + ) -> Result>> { + let location = Location::unknown(context); + + let inner_ty = registry.build_type(context, module, metadata, type_id)?; + Ok(match metadata.get::() { + Some(drop_overrides_meta) if drop_overrides_meta.is_overriden(type_id) => { + let drop_fn_symbol = format!("drop${}$item", type_id.id); + let flat_symbol_ref = FlatSymbolRefAttribute::new(context, &drop_fn_symbol); + + if let Entry::Vacant(entry) = self.drop_overrides.entry(type_id.clone()) { + let drop_fn_symbol = entry.insert(drop_fn_symbol); + + let region = Region::new(); + let entry = region + .append_block(Block::new(&[(llvm::r#type::pointer(context, 0), location)])); + + let value = entry.load(context, location, entry.arg(0)?, inner_ty)?; + drop_overrides_meta + .invoke_override(context, &entry, location, type_id, value)?; + + entry.append_operation(llvm::r#return(None, location)); + + module.body().append_operation(llvm::func( + context, + StringAttribute::new(context, drop_fn_symbol), + TypeAttribute::new(llvm::r#type::function( + llvm::r#type::void(context), + &[llvm::r#type::pointer(context, 0)], + false, + )), + region, + &[ + ( + Identifier::new(context, "sym_visibility"), + StringAttribute::new(context, "public").into(), + ), + ( + Identifier::new(context, "llvm.linkage"), + Attribute::parse(context, "#llvm.linkage") + .ok_or(Error::ParseAttributeError)?, + ), + ], + location, + )); + } + + Some(flat_symbol_ref) + } + _ => None, + }) + } +} diff --git a/src/metadata/runtime_bindings.rs b/src/metadata/runtime_bindings.rs index 0df01a628..f7e2f2b29 100644 --- a/src/metadata/runtime_bindings.rs +++ b/src/metadata/runtime_bindings.rs @@ -539,6 +539,8 @@ impl RuntimeBindingsMeta { module: &Module, block: &'a Block<'c>, location: Location<'c>, + dup_fn: Option>, + drop_fn: Option>, layout: Layout, ) -> Result> where @@ -553,7 +555,12 @@ impl RuntimeBindingsMeta { TypeAttribute::new( FunctionType::new( context, - &[i64_ty, i64_ty], + &[ + i64_ty, + i64_ty, + llvm::r#type::pointer(context, 0), + llvm::r#type::pointer(context, 0), + ], &[llvm::r#type::pointer(context, 0)], ) .into(), @@ -577,10 +584,23 @@ impl RuntimeBindingsMeta { let size = block.const_int_from_type(context, location, layout.size(), i64_ty)?; let align = block.const_int_from_type(context, location, layout.align(), i64_ty)?; + let dup_fn = match dup_fn { + Some(x) => x, + None => { + block.append_op_result(llvm::zero(llvm::r#type::pointer(context, 0), location))? + } + }; + let drop_fn = match drop_fn { + Some(x) => x, + None => { + block.append_op_result(llvm::zero(llvm::r#type::pointer(context, 0), location))? + } + }; + block.append_op_result(func::call( context, FlatSymbolRefAttribute::new(context, "cairo_native__dict_new"), - &[size, align], + &[size, align, dup_fn, drop_fn], &[llvm::r#type::pointer(context, 0)], location, )) @@ -596,7 +616,6 @@ impl RuntimeBindingsMeta { module: &Module, block: &'a Block<'c>, ptr: Value<'c, 'a>, - drop_fn: Option>, location: Location<'c>, ) -> Result> where @@ -607,15 +626,7 @@ impl RuntimeBindingsMeta { context, StringAttribute::new(context, "cairo_native__dict_drop"), TypeAttribute::new( - FunctionType::new( - context, - &[ - llvm::r#type::pointer(context, 0), - llvm::r#type::pointer(context, 0), - ], - &[], - ) - .into(), + FunctionType::new(context, &[llvm::r#type::pointer(context, 0)], &[]).into(), ), Region::new(), &[ @@ -633,17 +644,10 @@ impl RuntimeBindingsMeta { )); } - let drop_fn = match drop_fn { - Some(x) => x, - None => { - block.append_op_result(llvm::zero(llvm::r#type::pointer(context, 0), location))? - } - }; - Ok(block.append_operation(func::call( context, FlatSymbolRefAttribute::new(context, "cairo_native__dict_drop"), - &[ptr, drop_fn], + &[ptr], &[], location, ))) @@ -659,7 +663,6 @@ impl RuntimeBindingsMeta { module: &Module, block: &'a Block<'c>, ptr: Value<'c, 'a>, - dup_fn: Option>, location: Location<'c>, ) -> Result> where @@ -672,10 +675,7 @@ impl RuntimeBindingsMeta { TypeAttribute::new( FunctionType::new( context, - &[ - llvm::r#type::pointer(context, 0), - llvm::r#type::pointer(context, 0), - ], + &[llvm::r#type::pointer(context, 0)], &[llvm::r#type::pointer(context, 0)], ) .into(), @@ -696,17 +696,10 @@ impl RuntimeBindingsMeta { )); } - let dup_fn = match dup_fn { - Some(x) => x, - None => { - block.append_op_result(llvm::zero(llvm::r#type::pointer(context, 0), location))? - } - }; - block.append_op_result(func::call( context, FlatSymbolRefAttribute::new(context, "cairo_native__dict_dup"), - &[ptr, dup_fn], + &[ptr], &[llvm::r#type::pointer(context, 0)], location, )) diff --git a/src/types/felt252_dict.rs b/src/types/felt252_dict.rs index 3164bcbf8..dfe672bb4 100644 --- a/src/types/felt252_dict.rs +++ b/src/types/felt252_dict.rs @@ -11,10 +11,9 @@ use crate::{ error::{Error, Result}, metadata::{ drop_overrides::DropOverridesMeta, dup_overrides::DupOverridesMeta, - realloc_bindings::ReallocBindingsMeta, runtime_bindings::RuntimeBindingsMeta, - MetadataStorage, + felt252_dict::Felt252DictOverrides, realloc_bindings::ReallocBindingsMeta, + runtime_bindings::RuntimeBindingsMeta, MetadataStorage, }, - types::TypeBuilder, utils::{BlockExt, ProgramRegistryExt}, }; use cairo_lang_sierra::{ @@ -25,11 +24,8 @@ use cairo_lang_sierra::{ program_registry::ProgramRegistry, }; use melior::{ - dialect::{func, llvm, ods}, - ir::{ - attribute::{FlatSymbolRefAttribute, StringAttribute, TypeAttribute}, - Attribute, Block, Identifier, Location, Module, Region, Type, - }, + dialect::{func, llvm}, + ir::{Block, Location, Module, Region, Type}, Context, }; @@ -88,84 +84,24 @@ fn build_dup<'ctx>( } let value_ty = registry.build_type(context, module, metadata, info.self_ty())?; - let inner_ty = registry.get_type(&info.ty)?; - let inner_ty = inner_ty.build(context, module, registry, metadata, &info.ty)?; - - let dup_fn = match metadata.get::() { - Some(dup_overrides_meta) if dup_overrides_meta.is_overriden(&info.ty) => { - let region = Region::new(); - let entry = region.append_block(Block::new(&[ - (llvm::r#type::pointer(context, 0), location), - (llvm::r#type::pointer(context, 0), location), - ])); - - let source_ptr = entry.arg(0)?; - let target_ptr = entry.arg(1)?; - - let value = entry.load(context, location, source_ptr, inner_ty)?; - let values = - dup_overrides_meta.invoke_override(context, &entry, location, &info.ty, value)?; - entry.store(context, location, source_ptr, values.0)?; - entry.store(context, location, target_ptr, values.1)?; - - entry.append_operation(llvm::r#return(None, location)); - - let dup_fn_symbol = format!("dup${}$item", info.self_ty().id); - module.body().append_operation(llvm::func( - context, - StringAttribute::new(context, &dup_fn_symbol), - TypeAttribute::new(llvm::r#type::function( - llvm::r#type::void(context), - &[ - llvm::r#type::pointer(context, 0), - llvm::r#type::pointer(context, 0), - ], - false, - )), - region, - &[ - ( - Identifier::new(context, "sym_visibility"), - StringAttribute::new(context, "public").into(), - ), - ( - Identifier::new(context, "linkage"), - Attribute::parse(context, "#llvm.linkage") - .ok_or(Error::ParseAttributeError)?, - ), - ], - location, - )); - Some(dup_fn_symbol) - } - _ => None, - }; + { + let mut dict_overrides = metadata + .remove::() + .unwrap_or_default(); + dict_overrides.build_dup_fn(context, module, registry, metadata, &info.ty)?; + metadata.insert(dict_overrides); + } let region = Region::new(); let entry = region.append_block(Block::new(&[(value_ty, location)])); - let dup_fn = match dup_fn { - Some(dup_fn) => Some( - entry.append_op_result( - ods::llvm::mlir_addressof( - context, - llvm::r#type::pointer(context, 0), - FlatSymbolRefAttribute::new(context, &dup_fn), - location, - ) - .into(), - )?, - ), - None => None, - }; - // The following unwrap is unreachable because the registration logic will always insert it. let value0 = entry.arg(0)?; let value1 = metadata .get_mut::() .ok_or(Error::MissingMetadata)? - .dict_dup(context, module, &entry, value0, dup_fn, location)?; + .dict_dup(context, module, &entry, value0, location)?; entry.append_operation(func::r#return(&[value0, value1], location)); Ok(region) @@ -184,71 +120,23 @@ fn build_drop<'ctx>( } let value_ty = registry.build_type(context, module, metadata, info.self_ty())?; - let inner_ty = registry.build_type(context, module, metadata, &info.ty)?; - - let drop_fn_symbol = match metadata.get::() { - Some(drop_overrides_meta) if drop_overrides_meta.is_overriden(&info.ty) => { - let region = Region::new(); - let entry = - region.append_block(Block::new(&[(llvm::r#type::pointer(context, 0), location)])); - - let value = entry.load(context, location, entry.arg(0)?, inner_ty)?; - drop_overrides_meta.invoke_override(context, &entry, location, &info.ty, value)?; - - entry.append_operation(llvm::r#return(None, location)); - - let drop_fn_symbol = format!("drop${}$item", info.self_ty().id); - module.body().append_operation(llvm::func( - context, - StringAttribute::new(context, &drop_fn_symbol), - TypeAttribute::new(llvm::r#type::function( - llvm::r#type::void(context), - &[llvm::r#type::pointer(context, 0)], - false, - )), - region, - &[ - ( - Identifier::new(context, "sym_visibility"), - StringAttribute::new(context, "public").into(), - ), - ( - Identifier::new(context, "llvm.linkage"), - Attribute::parse(context, "#llvm.linkage") - .ok_or(Error::ParseAttributeError)?, - ), - ], - location, - )); - Some(drop_fn_symbol) - } - _ => None, - }; + { + let mut dict_overrides = metadata + .remove::() + .unwrap_or_default(); + dict_overrides.build_drop_fn(context, module, registry, metadata, &info.ty)?; + metadata.insert(dict_overrides); + } let region = Region::new(); let entry = region.append_block(Block::new(&[(value_ty, location)])); - let drop_fn = match drop_fn_symbol { - Some(drop_fn_symbol) => Some( - entry.append_op_result( - ods::llvm::mlir_addressof( - context, - llvm::r#type::pointer(context, 0), - FlatSymbolRefAttribute::new(context, &drop_fn_symbol), - location, - ) - .into(), - )?, - ), - None => None, - }; - // The following unwrap is unreachable because the registration logic will always insert it. let runtime_bindings_meta = metadata .get_mut::() .ok_or(Error::MissingMetadata)?; - runtime_bindings_meta.dict_drop(context, module, &entry, entry.arg(0)?, drop_fn, location)?; + runtime_bindings_meta.dict_drop(context, module, &entry, entry.arg(0)?, location)?; entry.append_operation(func::r#return(&[], location)); Ok(region) diff --git a/src/values.rs b/src/values.rs index 042443173..b9b5b846d 100644 --- a/src/values.rs +++ b/src/values.rs @@ -25,15 +25,17 @@ use educe::Educe; use num_bigint::{BigInt, BigUint, Sign}; use num_traits::{Euclid, One}; use starknet_types_core::felt::Felt; -use std::{alloc::Layout, collections::HashMap, ptr::NonNull, slice}; -#[cfg(feature = "with-runtime")] -use { - cairo_native_runtime::FeltDict, - std::{ - alloc::{alloc, dealloc}, - ptr::null_mut, - }, +use std::{ + alloc::Layout, + collections::HashMap, + ffi::c_void, + mem::forget, + ptr::{null_mut, NonNull}, + rc::Rc, + slice, }; +#[cfg(feature = "with-runtime")] +use {cairo_native_runtime::FeltDict, std::alloc::alloc}; /// A Value is a value that can be passed to either the JIT engine or a compiled program as an argument or received as a result. /// @@ -164,6 +166,13 @@ impl Value { arena: &Bump, registry: &ProgramRegistry, type_id: &ConcreteTypeId, + find_dict_overrides: impl Copy + + Fn( + &ConcreteTypeId, + ) -> ( + Option, + Option, + ), ) -> Result, Error> { let ty = registry.get_type(type_id)?; @@ -240,7 +249,8 @@ impl Value { // Write the data. for (idx, elem) in data.iter().enumerate() { - let elem = elem.to_ptr(arena, registry, &info.ty)?; + let elem = + elem.to_ptr(arena, registry, &info.ty, find_dict_overrides)?; std::ptr::copy_nonoverlapping( elem.cast::().as_ptr(), @@ -299,7 +309,12 @@ impl Value { }; layout = Some(new_layout); - let member_ptr = member.to_ptr(arena, registry, member_type_id)?; + let member_ptr = member.to_ptr( + arena, + registry, + member_type_id, + find_dict_overrides, + )?; data.push(( member_layout, offset, @@ -345,7 +360,8 @@ impl Value { native_assert!(*tag < info.variants.len(), "Variant index out of range."); let payload_type_id = &info.variants[*tag]; - let payload = value.to_ptr(arena, registry, payload_type_id)?; + let payload = + value.to_ptr(arena, registry, payload_type_id, find_dict_overrides)?; let (layout, tag_layout, variant_layouts) = crate::types::r#enum::get_layout_for_variants( @@ -385,7 +401,11 @@ impl Value { let elem_ty = registry.get_type(&info.ty)?; let elem_layout = elem_ty.layout(registry)?.pad_to_align(); - let mut value_map = Box::new(FeltDict { + // We need `find_dict_overrides` to obtain the function pointers of the dup and drop + // implementations (if any) for the value type. This is required to be able to clone and drop + // the dictionary automatically when their reference count drops to zero. + let (dup_fn, drop_fn) = find_dict_overrides(&info.ty); + let mut value_map = FeltDict { mappings: HashMap::with_capacity(map.len()), layout: elem_layout, @@ -399,14 +419,18 @@ impl Value { .cast() }, + dup_fn, + drop_fn, + count: 0, - }); + }; // next key must be called before next_value for (key, value) in map.iter() { let key = key.to_bytes_le(); - let value = value.to_ptr(arena, registry, &info.ty)?; + let value = + value.to_ptr(arena, registry, &info.ty, find_dict_overrides)?; let index = value_map.mappings.len(); value_map.mappings.insert(key, index); @@ -421,7 +445,7 @@ impl Value { ); } - NonNull::new_unchecked(Box::into_raw(value_map)).cast() + NonNull::new_unchecked(Rc::into_raw(Rc::new(value_map)) as *mut ()).cast() } else { Err(Error::UnexpectedValue(format!( "expected value of type {:?} but got a felt dict", @@ -533,11 +557,11 @@ impl Value { let inner = registry.get_type(&info.ty)?; let inner_layout = inner.layout(registry)?; - let x_ptr = x.to_ptr(arena, registry, &info.ty)?; + let x_ptr = x.to_ptr(arena, registry, &info.ty, find_dict_overrides)?; let (struct_layout, y_offset) = inner_layout.extend(inner_layout)?; - let y_ptr = y.to_ptr(arena, registry, &info.ty)?; + let y_ptr = y.to_ptr(arena, registry, &info.ty, find_dict_overrides)?; let ptr = arena.alloc_layout(struct_layout.pad_to_align()).as_ptr(); @@ -806,11 +830,7 @@ impl Value { #[cfg(feature = "with-runtime")] CoreTypeConcrete::Felt252Dict(info) | CoreTypeConcrete::SquashedFelt252Dict(info) => { - let dict = &ptr - .cast::>() - .as_ref() - .cast::() - .as_ref(); + let dict = Rc::from_raw(ptr.cast::<*const FeltDict>().read()); let mut output_map = HashMap::with_capacity(dict.mappings.len()); for (&key, &index) in dict.mappings.iter() { @@ -818,6 +838,8 @@ impl Value { key[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). let key = Felt::from_bytes_le(&key); + // The dictionary items are not being dropped here. They'll be dropped along + // with the dictionary (if requested using `should_drop`). output_map.insert( key, Self::from_ptr( @@ -831,28 +853,15 @@ impl Value { .cast(), &info.ty, registry, - should_drop, + false, )?, ); } if should_drop { - let dict = Box::from_raw( - ptr.cast::>() - .as_ref() - .cast::() - .as_ptr(), - ); - - dealloc( - dict.elements.cast(), - Layout::from_size_align_unchecked( - dict.layout.pad_to_align().size() * dict.mappings.capacity(), - dict.layout.align(), - ), - ); - drop(dict); + } else { + forget(dict); } Self::Felt252Dict { @@ -1142,7 +1151,12 @@ mod test { assert_eq!( unsafe { *Value::Felt252(Felt::from(42)) - .to_ptr(&Bump::new(), ®istry, &program.type_declarations[0].id) + .to_ptr( + &Bump::new(), + ®istry, + &program.type_declarations[0].id, + |_| todo!(), + ) .unwrap() .cast::<[u32; 8]>() .as_ptr() @@ -1153,7 +1167,12 @@ mod test { assert_eq!( unsafe { *Value::Felt252(Felt::MAX) - .to_ptr(&Bump::new(), ®istry, &program.type_declarations[0].id) + .to_ptr( + &Bump::new(), + ®istry, + &program.type_declarations[0].id, + |_| todo!(), + ) .unwrap() .cast::<[u32; 8]>() .as_ptr() @@ -1165,7 +1184,12 @@ mod test { assert_eq!( unsafe { *Value::Felt252(Felt::MAX + Felt::ONE) - .to_ptr(&Bump::new(), ®istry, &program.type_declarations[0].id) + .to_ptr( + &Bump::new(), + ®istry, + &program.type_declarations[0].id, + |_| todo!(), + ) .unwrap() .cast::<[u32; 8]>() .as_ptr() @@ -1183,7 +1207,12 @@ mod test { assert_eq!( unsafe { *Value::Uint8(9) - .to_ptr(&Bump::new(), ®istry, &program.type_declarations[0].id) + .to_ptr( + &Bump::new(), + ®istry, + &program.type_declarations[0].id, + |_| todo!(), + ) .unwrap() .cast::() .as_ptr() @@ -1201,7 +1230,12 @@ mod test { assert_eq!( unsafe { *Value::Uint16(17) - .to_ptr(&Bump::new(), ®istry, &program.type_declarations[0].id) + .to_ptr( + &Bump::new(), + ®istry, + &program.type_declarations[0].id, + |_| todo!(), + ) .unwrap() .cast::() .as_ptr() @@ -1219,7 +1253,12 @@ mod test { assert_eq!( unsafe { *Value::Uint32(33) - .to_ptr(&Bump::new(), ®istry, &program.type_declarations[0].id) + .to_ptr( + &Bump::new(), + ®istry, + &program.type_declarations[0].id, + |_| todo!(), + ) .unwrap() .cast::() .as_ptr() @@ -1237,7 +1276,12 @@ mod test { assert_eq!( unsafe { *Value::Uint64(65) - .to_ptr(&Bump::new(), ®istry, &program.type_declarations[0].id) + .to_ptr( + &Bump::new(), + ®istry, + &program.type_declarations[0].id, + |_| todo!(), + ) .unwrap() .cast::() .as_ptr() @@ -1255,7 +1299,12 @@ mod test { assert_eq!( unsafe { *Value::Uint128(129) - .to_ptr(&Bump::new(), ®istry, &program.type_declarations[0].id) + .to_ptr( + &Bump::new(), + ®istry, + &program.type_declarations[0].id, + |_| todo!(), + ) .unwrap() .cast::() .as_ptr() @@ -1273,7 +1322,12 @@ mod test { assert_eq!( unsafe { *Value::Sint8(-9) - .to_ptr(&Bump::new(), ®istry, &program.type_declarations[0].id) + .to_ptr( + &Bump::new(), + ®istry, + &program.type_declarations[0].id, + |_| todo!(), + ) .unwrap() .cast::() .as_ptr() @@ -1291,7 +1345,12 @@ mod test { assert_eq!( unsafe { *Value::Sint16(-17) - .to_ptr(&Bump::new(), ®istry, &program.type_declarations[0].id) + .to_ptr( + &Bump::new(), + ®istry, + &program.type_declarations[0].id, + |_| todo!(), + ) .unwrap() .cast::() .as_ptr() @@ -1309,7 +1368,12 @@ mod test { assert_eq!( unsafe { *Value::Sint32(-33) - .to_ptr(&Bump::new(), ®istry, &program.type_declarations[0].id) + .to_ptr( + &Bump::new(), + ®istry, + &program.type_declarations[0].id, + |_| todo!(), + ) .unwrap() .cast::() .as_ptr() @@ -1327,7 +1391,12 @@ mod test { assert_eq!( unsafe { *Value::Sint64(-65) - .to_ptr(&Bump::new(), ®istry, &program.type_declarations[0].id) + .to_ptr( + &Bump::new(), + ®istry, + &program.type_declarations[0].id, + |_| todo!(), + ) .unwrap() .cast::() .as_ptr() @@ -1345,7 +1414,12 @@ mod test { assert_eq!( unsafe { *Value::Sint128(-129) - .to_ptr(&Bump::new(), ®istry, &program.type_declarations[0].id) + .to_ptr( + &Bump::new(), + ®istry, + &program.type_declarations[0].id, + |_| todo!(), + ) .unwrap() .cast::() .as_ptr() @@ -1365,7 +1439,12 @@ mod test { assert_eq!( unsafe { *Value::EcPoint(Felt::from(1234), Felt::from(4321)) - .to_ptr(&Bump::new(), ®istry, &program.type_declarations[0].id) + .to_ptr( + &Bump::new(), + ®istry, + &program.type_declarations[0].id, + |_| todo!(), + ) .unwrap() .cast::<[[u32; 8]; 2]>() .as_ptr() @@ -1390,7 +1469,12 @@ mod test { Felt::from(3333), Felt::from(4444), ) - .to_ptr(&Bump::new(), ®istry, &program.type_declarations[0].id) + .to_ptr( + &Bump::new(), + ®istry, + &program.type_declarations[0].id, + |_| todo!(), + ) .unwrap() .cast::<[[u32; 8]; 4]>() .as_ptr() @@ -1423,7 +1507,12 @@ mod test { value: Box::new(Value::Uint8(10)), debug_name: None, } - .to_ptr(&Bump::new(), ®istry, &program.type_declarations[1].id); + .to_ptr( + &Bump::new(), + ®istry, + &program.type_declarations[1].id, + |_| todo!(), + ); // Assertion to verify that the value returned by to_jit is not NULL assert!(result.is_ok()); @@ -1452,7 +1541,12 @@ mod test { upper: BigInt::from(510), }, } - .to_ptr(&Bump::new(), ®istry, &program.type_declarations[1].id) + .to_ptr( + &Bump::new(), + ®istry, + &program.type_declarations[1].id, + |_| todo!(), + ) .unwrap() .cast::<[u32; 8]>() .as_ptr() @@ -1482,7 +1576,12 @@ mod test { upper: BigInt::from(10), }, } - .to_ptr(&Bump::new(), ®istry, &program.type_declarations[1].id); + .to_ptr( + &Bump::new(), + ®istry, + &program.type_declarations[1].id, + |_| todo!(), + ); assert!(matches!( result, @@ -1511,7 +1610,12 @@ mod test { upper: BigInt::from(510), }, } - .to_ptr(&Bump::new(), ®istry, &program.type_declarations[1].id); + .to_ptr( + &Bump::new(), + ®istry, + &program.type_declarations[1].id, + |_| todo!(), + ); assert!(matches!( result, @@ -1540,7 +1644,12 @@ mod test { upper: BigInt::from(510), }, } - .to_ptr(&Bump::new(), ®istry, &program.type_declarations[1].id); + .to_ptr( + &Bump::new(), + ®istry, + &program.type_declarations[1].id, + |_| todo!(), + ); assert!(matches!( result, @@ -1569,7 +1678,12 @@ mod test { upper: BigInt::from(10), }, } - .to_ptr(&Bump::new(), ®istry, &program.type_declarations[1].id); + .to_ptr( + &Bump::new(), + ®istry, + &program.type_declarations[1].id, + |_| todo!(), + ); assert!(matches!( result, @@ -1596,7 +1710,12 @@ mod test { value: Box::new(Value::Uint8(10)), debug_name: None, } - .to_ptr(&Bump::new(), ®istry, &program.type_declarations[1].id) + .to_ptr( + &Bump::new(), + ®istry, + &program.type_declarations[1].id, + |_| todo!(), + ) .unwrap_err(); let error = result.to_string().clone(); @@ -1621,7 +1740,12 @@ mod test { value: Box::new(Value::Uint8(10)), debug_name: None, } - .to_ptr(&Bump::new(), ®istry, &program.type_declarations[1].id) + .to_ptr( + &Bump::new(), + ®istry, + &program.type_declarations[1].id, + |_| todo!(), + ) .unwrap_err(); let error = result.to_string().clone(); @@ -1656,7 +1780,12 @@ mod test { }), debug_name: None, } - .to_ptr(&Bump::new(), ®istry, &program.type_declarations[0].id) + .to_ptr( + &Bump::new(), + ®istry, + &program.type_declarations[0].id, + |_| todo!(), + ) .unwrap_err(); // Unwrapping the error // Matching the error result to verify the error type and message. @@ -1694,7 +1823,12 @@ mod test { fields: vec![Value::from(2u32)], debug_name: None, } - .to_ptr(&Bump::new(), ®istry, &program.type_declarations[0].id) + .to_ptr( + &Bump::new(), + ®istry, + &program.type_declarations[0].id, + |_| todo!(), + ) .unwrap_err(); // Unwrapping the error // Matching the error result to verify the error type and message.