Skip to content

Commit

Permalink
Fix felt252 and enum deserialization bugs. (#844)
Browse files Browse the repository at this point in the history
* Fix felt252 and enum deserialization bugs.

* Fix formatting.

* Also fix the runtime.

* Fix errors.
  • Loading branch information
azteca1998 authored Oct 14, 2024
1 parent f862ec3 commit 5de558f
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 53 deletions.
110 changes: 71 additions & 39 deletions runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use starknet_types_core::{
felt::Felt,
hash::StarkHash,
};
use std::{collections::HashMap, ffi::c_void, fs::File, io::Write, os::fd::FromRawFd, slice};
use std::{collections::HashMap, ffi::c_void, fs::File, io::Write, os::fd::FromRawFd};
use std::{ops::Mul, vec::IntoIter};

lazy_static! {
Expand Down Expand Up @@ -44,7 +44,8 @@ pub unsafe extern "C" fn cairo_native__libfunc__debug__print(
let mut items = Vec::with_capacity(len as usize);

for i in 0..len as usize {
let data = *data.add(i);
let mut data = *data.add(i);
data[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).

let value = Felt::from_bytes_le(&data);
items.push(value);
Expand Down Expand Up @@ -76,22 +77,24 @@ pub unsafe extern "C" fn cairo_native__libfunc__debug__print(
/// definitely unsafe to use manually.
#[no_mangle]
pub unsafe extern "C" fn cairo_native__libfunc__pedersen(
dst: *mut u8,
lhs: *const u8,
rhs: *const u8,
dst: &mut [u8; 32],
lhs: &[u8; 32],
rhs: &[u8; 32],
) {
// Extract arrays from the pointers.
let dst = slice::from_raw_parts_mut(dst, 32);
let lhs = slice::from_raw_parts(lhs, 32);
let rhs = slice::from_raw_parts(rhs, 32);
let mut lhs = *lhs;
let mut rhs = *rhs;

lhs[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).
rhs[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).

// Convert to FieldElement.
let lhs = Felt::from_bytes_le_slice(lhs);
let rhs = Felt::from_bytes_le_slice(rhs);
let lhs = Felt::from_bytes_le(&lhs);
let rhs = Felt::from_bytes_le(&rhs);

// Compute pedersen hash and copy the result into `dst`.
let res = starknet_types_core::hash::Pedersen::hash(&lhs, &rhs);
dst.copy_from_slice(&res.to_bytes_le());
*dst = res.to_bytes_le();
}

/// Compute `hades_permutation(op0, op1, op2)` and replace the operands with the results.
Expand All @@ -108,29 +111,28 @@ pub unsafe extern "C" fn cairo_native__libfunc__pedersen(
/// definitely unsafe to use manually.
#[no_mangle]
pub unsafe extern "C" fn cairo_native__libfunc__hades_permutation(
op0: *mut u8,
op1: *mut u8,
op2: *mut u8,
op0: &mut [u8; 32],
op1: &mut [u8; 32],
op2: &mut [u8; 32],
) {
// Extract arrays from the pointers.
let op0 = slice::from_raw_parts_mut(op0, 32);
let op1 = slice::from_raw_parts_mut(op1, 32);
let op2 = slice::from_raw_parts_mut(op2, 32);
op0[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).
op1[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).
op2[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).

// Convert to FieldElement.
let mut state = [
Felt::from_bytes_le_slice(op0),
Felt::from_bytes_le_slice(op1),
Felt::from_bytes_le_slice(op2),
Felt::from_bytes_le(op0),
Felt::from_bytes_le(op1),
Felt::from_bytes_le(op2),
];

// Compute Poseidon permutation.
starknet_types_core::hash::Poseidon::hades_permutation(&mut state);

// Write back the results.
op0.copy_from_slice(&state[0].to_bytes_le());
op1.copy_from_slice(&state[1].to_bytes_le());
op2.copy_from_slice(&state[2].to_bytes_le());
*op0 = state[0].to_bytes_le();
*op1 = state[1].to_bytes_le();
*op2 = state[2].to_bytes_le();
}

/// Felt252 type used in cairo native runtime
Expand Down Expand Up @@ -230,8 +232,11 @@ pub unsafe extern "C" fn cairo_native__dict_get(
dict: &mut FeltDict,
key: &[u8; 32],
) -> *mut c_void {
let mut key = *key;
key[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).

dict.count += 1;
dict.inner.entry(*key).or_insert(std::ptr::null_mut()) as *mut _ as *mut c_void
dict.inner.entry(key).or_insert(std::ptr::null_mut()) as *mut _ as *mut c_void
}

/// Compute the total gas refund for the dictionary at squash time.
Expand Down Expand Up @@ -260,6 +265,7 @@ pub unsafe extern "C" fn cairo_native__dict_gas_refund(ptr: *const FeltDict) ->
pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_point_from_x_nz(
point_ptr: &mut [[u8; 32]; 2],
) -> bool {
point_ptr[0][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).
let x = Felt::from_bytes_le(&point_ptr[0]);

// https://github.com/starkware-libs/cairo/blob/aaad921bba52e729dc24ece07fab2edf09ccfa15/crates/cairo-lang-sierra-to-casm/src/invocations/ec.rs#L63
Expand All @@ -276,7 +282,7 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_point_from_x_nz(

match AffinePoint::new(x, y) {
Ok(point) => {
point_ptr.as_mut()[1].copy_from_slice(&point.y().to_bytes_le());
point_ptr[1] = point.y().to_bytes_le();
true
}
Err(_) => false,
Expand All @@ -297,13 +303,16 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_point_from_x_nz(
pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_point_try_new_nz(
point_ptr: &mut [[u8; 32]; 2],
) -> bool {
point_ptr[0][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).
point_ptr[1][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).

let x = Felt::from_bytes_le(&point_ptr[0]);
let y = Felt::from_bytes_le(&point_ptr[1]);

match AffinePoint::new(x, y) {
Ok(point) => {
point_ptr[0].copy_from_slice(&point.x().to_bytes_le());
point_ptr[1].copy_from_slice(&point.y().to_bytes_le());
point_ptr[0] = point.x().to_bytes_le();
point_ptr[1] = point.y().to_bytes_le();
true
}
Err(_) => false,
Expand Down Expand Up @@ -333,10 +342,10 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_init(state_ptr: &mu
// We already made sure its a valid point.
let state = AffinePoint::new_unchecked(random_x, random_y);

state_ptr[0].copy_from_slice(&state.x().to_bytes_le());
state_ptr[1].copy_from_slice(&state.y().to_bytes_le());
state_ptr[2].copy_from_slice(&state.x().to_bytes_le());
state_ptr[3].copy_from_slice(&state.y().to_bytes_le());
state_ptr[0] = state.x().to_bytes_le();
state_ptr[1] = state.y().to_bytes_le();
state_ptr[2] = state_ptr[0];
state_ptr[3] = state_ptr[1];
}

/// Compute `ec_state_add(state, point)` and store the state back.
Expand All @@ -354,6 +363,13 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_add(
state_ptr: &mut [[u8; 32]; 4],
point_ptr: &[[u8; 32]; 2],
) {
state_ptr[0][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).
state_ptr[1][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).

let mut point_ptr = *point_ptr;
point_ptr[0][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).
point_ptr[1][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).

// We use unchecked methods because the inputs must already be valid points.
let mut state = ProjectivePoint::from_affine_unchecked(
Felt::from_bytes_le(&state_ptr[0]),
Expand All @@ -367,8 +383,8 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_add(
state += &point;
let state = state.to_affine().unwrap();

state_ptr[0].copy_from_slice(&state.x().to_bytes_le());
state_ptr[1].copy_from_slice(&state.y().to_bytes_le());
state_ptr[0] = state.x().to_bytes_le();
state_ptr[1] = state.y().to_bytes_le();
}

/// Compute `ec_state_add_mul(state, scalar, point)` and store the state back.
Expand All @@ -387,6 +403,16 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_add_mul(
scalar_ptr: &[u8; 32],
point_ptr: &[[u8; 32]; 2],
) {
state_ptr[0][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).
state_ptr[1][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).

let mut point_ptr = *point_ptr;
point_ptr[0][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).
point_ptr[1][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).

let mut scalar_ptr = *scalar_ptr;
scalar_ptr[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).

// Here the points should already be checked as valid, so we can use unchecked.
let mut state = ProjectivePoint::from_affine_unchecked(
Felt::from_bytes_le(&state_ptr[0]),
Expand All @@ -396,13 +422,13 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_add_mul(
Felt::from_bytes_le(&point_ptr[0]),
Felt::from_bytes_le(&point_ptr[1]),
);
let scalar = Felt::from_bytes_le(scalar_ptr);
let scalar = Felt::from_bytes_le(&scalar_ptr);

state += &point.mul(scalar);
let state = state.to_affine().unwrap();

state_ptr[0].copy_from_slice(&state.x().to_bytes_le());
state_ptr[1].copy_from_slice(&state.y().to_bytes_le());
state_ptr[0] = state.x().to_bytes_le();
state_ptr[1] = state.y().to_bytes_le();
}

/// Compute `ec_state_try_finalize_nz(state)` and store the result.
Expand All @@ -420,6 +446,12 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_try_finalize_nz(
point_ptr: &mut [[u8; 32]; 2],
state_ptr: &[[u8; 32]; 4],
) -> bool {
let mut state_ptr = *state_ptr;
state_ptr[0][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).
state_ptr[1][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).
state_ptr[2][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).
state_ptr[3][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).

// We use unchecked methods because the inputs must already be valid points.
let state = ProjectivePoint::from_affine_unchecked(
Felt::from_bytes_le(&state_ptr[0]),
Expand All @@ -436,8 +468,8 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_try_finalize_nz(
let point = &state - &random;
let point = point.to_affine().unwrap();

point_ptr[0].copy_from_slice(&point.x().to_bytes_le());
point_ptr[1].copy_from_slice(&point.y().to_bytes_le());
point_ptr[0] = point.x().to_bytes_le();
point_ptr[1] = point.y().to_bytes_le();

true
}
Expand Down
18 changes: 13 additions & 5 deletions src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,11 +391,13 @@ fn parse_result(
return Err(Error::ParseAttributeError);

#[cfg(target_arch = "aarch64")]
Ok(Value::Felt252(
starknet_types_core::felt::Felt::from_bytes_le(unsafe {
std::mem::transmute::<&[u64; 4], &[u8; 32]>(&ret_registers)
}),
))
Ok(Value::Felt252({
let data = unsafe {
std::mem::transmute::<&mut [u64; 4], &mut [u8; 32]>(&mut ret_registers)
};
data[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).
starknet_types_core::felt::Felt::from_bytes_le(data)
}))
}
},
CoreTypeConcrete::Bytes31(_) => match return_ptr {
Expand Down Expand Up @@ -509,6 +511,12 @@ fn parse_result(
}
};

// Filter out bits that are not part of the enum's tag.
let tag = tag
& 1usize
.wrapping_shl(info.variants.len().next_power_of_two().trailing_zeros())
.wrapping_sub(1);

(
tag,
Ok(unsafe {
Expand Down
5 changes: 4 additions & 1 deletion src/executor/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,8 @@ impl AotContractExecutor {
};

let tag = *unsafe { enum_ptr.cast::<u8>().as_ref() } as usize;
let tag = tag & 0x01; // Filter out bits that are not part of the enum's tag.

// layout of both enum variants, both are a array of felts
let value_layout = unsafe { Layout::from_size_align_unchecked(24, 8) };
let value_ptr = unsafe {
Expand Down Expand Up @@ -401,7 +403,8 @@ impl AotContractExecutor {
for i in 0..num_elems {
// safe to create a NonNull because if the array has elements, the data_ptr can't be null.
let cur_elem_ptr = NonNull::new(unsafe { data_ptr.byte_add(elem_stride * i) }).unwrap();
let data = unsafe { cur_elem_ptr.cast::<[u8; 32]>().as_ref() };
let data = unsafe { cur_elem_ptr.cast::<[u8; 32]>().as_mut() };
data[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).
let data = Felt::from_bytes_le_slice(data);

array_value.push(data);
Expand Down
34 changes: 26 additions & 8 deletions src/values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -575,12 +575,20 @@ impl Value {
value
}
CoreTypeConcrete::EcPoint(_) => {
let data = ptr.cast::<[[u8; 32]; 2]>().as_ref();
let data = ptr.cast::<[[u8; 32]; 2]>().as_mut();

data[0][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).
data[1][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).

Self::EcPoint(Felt::from_bytes_le(&data[0]), Felt::from_bytes_le(&data[1]))
}
CoreTypeConcrete::EcState(_) => {
let data = ptr.cast::<[[u8; 32]; 4]>().as_ref();
let data = ptr.cast::<[[u8; 32]; 4]>().as_mut();

data[0][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).
data[1][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).
data[2][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).
data[3][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).

Self::EcState(
Felt::from_bytes_le(&data[0]),
Expand All @@ -590,7 +598,8 @@ impl Value {
)
}
CoreTypeConcrete::Felt252(_) => {
let data = ptr.cast::<[u8; 32]>().as_ref();
let data = ptr.cast::<[u8; 32]>().as_mut();
data[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).
let data = Felt::from_bytes_le_slice(data);
Self::Felt252(data)
}
Expand Down Expand Up @@ -645,6 +654,12 @@ impl Value {
},
};

// Filter out bits that are not part of the enum's tag.
let tag_value = tag_value
& 1usize
.wrapping_shl(info.variants.len().next_power_of_two().trailing_zeros())
.wrapping_sub(1);

let payload_ty = registry.get_type(&info.variants[tag_value])?;
let payload_layout = payload_ty.layout(registry)?;

Expand Down Expand Up @@ -695,21 +710,23 @@ impl Value {
);

let mut output_map = HashMap::with_capacity(inner.len());
for (key, val_ptr) in inner.iter() {
for (mut key, val_ptr) in inner.into_iter() {
if val_ptr.is_null() {
continue;
}

let key = Felt::from_bytes_le(key);
key[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).

let key = Felt::from_bytes_le(&key);
output_map.insert(
key,
Self::from_ptr(
NonNull::new(*val_ptr).unwrap().cast(),
NonNull::new(val_ptr).unwrap().cast(),
&info.ty,
registry,
)?,
);
libc_free(*val_ptr);
libc_free(val_ptr);
}

Self::Felt252Dict {
Expand Down Expand Up @@ -737,7 +754,8 @@ impl Value {
| StarkNetTypeConcrete::StorageBaseAddress(_)
| StarkNetTypeConcrete::StorageAddress(_) => {
// felt values
let data = ptr.cast::<[u8; 32]>().as_ref();
let data = ptr.cast::<[u8; 32]>().as_mut();
data[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).
let data = Felt::from_bytes_le(data);
Self::Felt252(data)
}
Expand Down

0 comments on commit 5de558f

Please sign in to comment.