Skip to content

Commit

Permalink
Make dicts clone on write. (#964)
Browse files Browse the repository at this point in the history
* Update the runtime.

* Non-runtime changes and fixes.

* Revert slab, not necessary. More progress.

* Progress.

* Fix stuff.

* Make it clone-on-write.

* Finish the PR.

* Try to fix `src/arch/x86_64.rs`.

* Fix `x86_64.rs`.

* Fix suggestions.

* Add comment according to suggestions.

---------

Co-authored-by: Edgar <[email protected]>
Co-authored-by: Franco Giachetta <[email protected]>
  • Loading branch information
3 people authored Jan 10, 2025
1 parent f1e895f commit dd1ea3e
Show file tree
Hide file tree
Showing 16 changed files with 1,240 additions and 502 deletions.
272 changes: 156 additions & 116 deletions runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ use std::{
ffi::{c_int, c_void},
fs::File,
io::Write,
mem::ManuallyDrop,
mem::{forget, ManuallyDrop},
os::fd::FromRawFd,
ptr::{self, null, null_mut},
rc::Rc,
slice,
};
use std::{ops::Mul, vec::IntoIter};
Expand Down Expand Up @@ -166,31 +167,124 @@ pub struct FeltDict {
pub layout: Layout,
pub elements: *mut (),

pub dup_fn: Option<extern "C" fn(*mut c_void, *mut c_void)>,
pub drop_fn: Option<extern "C" fn(*mut c_void)>,

pub count: u64,
}

impl Clone for FeltDict {
fn clone(&self) -> Self {
let mut new_dict = FeltDict {
mappings: HashMap::with_capacity(self.mappings.len()),

layout: self.layout,
elements: if self.mappings.is_empty() {
null_mut()
} else {
unsafe {
alloc(Layout::from_size_align_unchecked(
self.layout.pad_to_align().size() * self.mappings.len(),
self.layout.align(),
))
.cast()
}
},

dup_fn: self.dup_fn,
drop_fn: self.drop_fn,

// TODO: Check if `0` is fine or otherwise we should copy the value from `old_dict` too.
count: 0,
};

for (&key, &old_index) in self.mappings.iter() {
let old_value_ptr = unsafe {
self.elements
.byte_add(self.layout.pad_to_align().size() * old_index)
};

let new_index = new_dict.mappings.len();
let new_value_ptr = unsafe {
new_dict
.elements
.byte_add(new_dict.layout.pad_to_align().size() * new_index)
};

new_dict.mappings.insert(key, new_index);
match self.dup_fn {
Some(dup_fn) => dup_fn(old_value_ptr.cast(), new_value_ptr.cast()),
None => unsafe {
ptr::copy_nonoverlapping::<u8>(
old_value_ptr.cast(),
new_value_ptr.cast(),
self.layout.size(),
)
},
}
}

new_dict
}
}

impl Drop for FeltDict {
fn drop(&mut self) {
// Free the entries manually.
if let Some(drop_fn) = self.drop_fn {
for (_, &index) in self.mappings.iter() {
let value_ptr = unsafe {
self.elements
.byte_add(self.layout.pad_to_align().size() * index)
};

drop_fn(value_ptr.cast());
}
}

// Free the value data.
if !self.elements.is_null() {
unsafe {
dealloc(
self.elements.cast(),
Layout::from_size_align_unchecked(
self.layout.pad_to_align().size() * self.mappings.capacity(),
self.layout.align(),
),
)
};
}
}
}

/// Allocate a new dictionary.
///
/// # Safety
///
/// This function is intended to be called from MLIR, deals with pointers, and is therefore
/// definitely unsafe to use manually.
#[no_mangle]
pub unsafe extern "C" fn cairo_native__dict_new(size: u64, align: u64) -> *mut FeltDict {
Box::into_raw(Box::new(FeltDict {
pub unsafe extern "C" fn cairo_native__dict_new(
size: u64,
align: u64,
dup_fn: Option<extern "C" fn(*mut c_void, *mut c_void)>,
drop_fn: Option<extern "C" fn(*mut c_void)>,
) -> *const FeltDict {
Rc::into_raw(Rc::new(FeltDict {
mappings: HashMap::default(),

layout: Layout::from_size_align_unchecked(size as usize, align as usize),
elements: null_mut(),

dup_fn,
drop_fn,

count: 0,
}))
}

/// Free a dictionary using an optional callback to drop each element.
///
/// The `drop_fn` callback is present when the value implements `Drop`.
///
/// # Safety
///
/// This function is intended to be called from MLIR, deals with pointers, and is therefore
Expand All @@ -199,88 +293,23 @@ pub unsafe extern "C" fn cairo_native__dict_new(size: u64, align: u64) -> *mut F
// pointer optimization. Check out
// https://doc.rust-lang.org/nomicon/ffi.html#the-nullable-pointer-optimization for more info.
#[no_mangle]
pub unsafe extern "C" fn cairo_native__dict_drop(
ptr: *mut FeltDict,
drop_fn: Option<extern "C" fn(*mut c_void)>,
) {
let dict = Box::from_raw(ptr);

// Free the entries manually.
if let Some(drop_fn) = drop_fn {
for (_, &index) in dict.mappings.iter() {
let value_ptr = dict
.elements
.byte_add(dict.layout.pad_to_align().size() * index);

drop_fn(value_ptr.cast());
}
}

// Free the value data.
if !dict.elements.is_null() {
dealloc(
dict.elements.cast(),
Layout::from_size_align_unchecked(
dict.layout.pad_to_align().size() * dict.mappings.capacity(),
dict.layout.align(),
),
);
}
pub unsafe extern "C" fn cairo_native__dict_drop(ptr: *const FeltDict) {
drop(Rc::from_raw(ptr));
}

/// Duplicate a dictionary using a provided callback to clone each element.
///
/// The `dup_fn` callback is present when the value is not `Copy`, but `Clone`. The first argument
/// is the original value while the second is the target pointer.
///
/// # Safety
///
/// This function is intended to be called from MLIR, deals with pointers, and is therefore
/// definitely unsafe to use manually.
#[no_mangle]
pub unsafe extern "C" fn cairo_native__dict_dup(
old_dict: &FeltDict,
dup_fn: Option<extern "C" fn(*mut c_void, *mut c_void)>,
) -> *mut FeltDict {
let mut new_dict = Box::new(FeltDict {
mappings: HashMap::with_capacity(old_dict.mappings.len()),

layout: old_dict.layout,
elements: if old_dict.mappings.is_empty() {
null_mut()
} else {
alloc(Layout::from_size_align_unchecked(
old_dict.layout.pad_to_align().size() * old_dict.mappings.len(),
old_dict.layout.align(),
))
.cast()
},

// TODO: Check if `0` is fine or otherwise we should copy the value from `old_dict` too.
count: 0,
});

for (new_index, (&key, &old_index)) in old_dict.mappings.iter().enumerate() {
let old_value_ptr = old_dict
.elements
.byte_add(old_dict.layout.pad_to_align().size() * old_index);

let new_value_ptr = new_dict
.elements
.byte_add(new_dict.layout.pad_to_align().size() * new_index);

new_dict.mappings.insert(key, new_index);
match dup_fn {
Some(dup_fn) => dup_fn(old_value_ptr.cast(), new_value_ptr.cast()),
None => ptr::copy_nonoverlapping::<u8>(
old_value_ptr.cast(),
new_value_ptr.cast(),
old_dict.layout.size(),
),
}
}
pub unsafe extern "C" fn cairo_native__dict_dup(dict_ptr: *const FeltDict) -> *const FeltDict {
let old_dict = Rc::from_raw(dict_ptr);
let new_dict = Rc::clone(&old_dict);

Box::into_raw(new_dict)
forget(old_dict);
Rc::into_raw(new_dict)
}

/// Return a pointer to the entry's value pointer for a given key, inserting a null pointer if not
Expand All @@ -295,45 +324,46 @@ pub unsafe extern "C" fn cairo_native__dict_dup(
/// definitely unsafe to use manually.
#[no_mangle]
pub unsafe extern "C" fn cairo_native__dict_get(
dict: &mut FeltDict,
dict: *const FeltDict,
key: &[u8; 32],
value_ptr: *mut *mut c_void,
) -> c_int {
let mut key = *key;
key[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).

let old_capacity = dict.mappings.capacity();
let index = dict.mappings.len();
let (index, is_present) = match dict.mappings.entry(key) {
Entry::Occupied(entry) => (*entry.get(), 1),
Entry::Vacant(entry) => {
entry.insert(index);
let mut dict_rc = Rc::from_raw(dict);
let dict = Rc::make_mut(&mut dict_rc);

// Reallocate `mem_data` to match the slab's capacity.
if old_capacity != dict.mappings.capacity() {
dict.elements = realloc(
dict.elements.cast(),
Layout::from_size_align_unchecked(
dict.layout.pad_to_align().size() * old_capacity,
dict.layout.align(),
),
dict.layout.pad_to_align().size() * dict.mappings.capacity(),
)
.cast();
}
let num_mappings = dict.mappings.len();
let has_capacity = num_mappings != dict.mappings.capacity();

(index, 0)
let (is_present, index) = match dict.mappings.entry(*key) {
Entry::Occupied(entry) => (true, *entry.get()),
Entry::Vacant(entry) => {
entry.insert(num_mappings);
(false, num_mappings)
}
};

value_ptr.write(
dict.elements
.byte_add(dict.layout.pad_to_align().size() * index)
.cast(),
);
// Maybe realloc (conditions: !has_capacity && !is_present).
if !has_capacity && !is_present {
dict.elements = realloc(
dict.elements.cast(),
Layout::from_size_align_unchecked(
dict.layout.pad_to_align().size() * dict.mappings.len(),
dict.layout.align(),
),
dict.layout.pad_to_align().size() * dict.mappings.capacity(),
)
.cast();
}

*value_ptr = dict
.elements
.byte_add(dict.layout.pad_to_align().size() * index)
.cast();

dict.count += 1;
forget(dict_rc);

is_present
is_present as c_int
}

/// Compute the total gas refund for the dictionary at squash time.
Expand All @@ -344,8 +374,12 @@ pub unsafe extern "C" fn cairo_native__dict_get(
/// definitely unsafe to use manually.
#[no_mangle]
pub unsafe extern "C" fn cairo_native__dict_gas_refund(ptr: *const FeltDict) -> u64 {
let dict = &*ptr;
(dict.count.saturating_sub(dict.mappings.len() as u64)) * *DICT_GAS_REFUND_PER_ACCESS
let dict = Rc::from_raw(ptr);
let amount =
(dict.count.saturating_sub(dict.mappings.len() as u64)) * *DICT_GAS_REFUND_PER_ACCESS;

forget(dict);
amount
}

/// Compute `ec_point_from_x_nz(x)` and store it.
Expand Down Expand Up @@ -865,21 +899,27 @@ mod tests {

#[test]
fn test_dict() {
let dict =
unsafe { cairo_native__dict_new(size_of::<u64>() as u64, align_of::<u64>() as u64) };
let dict = unsafe {
cairo_native__dict_new(
size_of::<u64>() as u64,
align_of::<u64>() as u64,
None,
None,
)
};

let key = Felt::ONE.to_bytes_le();
let mut ptr = null_mut::<u64>();

assert_eq!(
unsafe { cairo_native__dict_get(&mut *dict, &key, (&raw mut ptr).cast()) },
unsafe { cairo_native__dict_get(dict, &key, (&raw mut ptr).cast()) },
0,
);
assert!(!ptr.is_null());
unsafe { *ptr = 24 };

assert_eq!(
unsafe { cairo_native__dict_get(&mut *dict, &key, (&raw mut ptr).cast()) },
unsafe { cairo_native__dict_get(dict, &key, (&raw mut ptr).cast()) },
1,
);
assert!(!ptr.is_null());
Expand All @@ -889,17 +929,17 @@ mod tests {
let refund = unsafe { cairo_native__dict_gas_refund(dict) };
assert_eq!(refund, 4050);

let cloned_dict = unsafe { cairo_native__dict_dup(&*dict, None) };
unsafe { cairo_native__dict_drop(dict, None) };
let cloned_dict = unsafe { cairo_native__dict_dup(&*dict) };
unsafe { cairo_native__dict_drop(dict) };

assert_eq!(
unsafe { cairo_native__dict_get(&mut *cloned_dict, &key, (&raw mut ptr).cast()) },
unsafe { cairo_native__dict_get(cloned_dict, &key, (&raw mut ptr).cast()) },
1,
);
assert!(!ptr.is_null());
assert_eq!(unsafe { *ptr }, 42);

unsafe { cairo_native__dict_drop(cloned_dict, None) };
unsafe { cairo_native__dict_drop(cloned_dict) };
}

#[test]
Expand Down
Loading

0 comments on commit dd1ea3e

Please sign in to comment.