From 524f76f3f652fd60e7eff301078b7ef177e59004 Mon Sep 17 00:00:00 2001 From: Esteve Soler Arderiu Date: Thu, 10 Oct 2024 19:34:07 +0200 Subject: [PATCH] Also fix the runtime. --- runtime/src/lib.rs | 106 +++++++++++++++++++++++++++++---------------- 1 file changed, 69 insertions(+), 37 deletions(-) diff --git a/runtime/src/lib.rs b/runtime/src/lib.rs index 2110b0c58..cca5cfd06 100644 --- a/runtime/src/lib.rs +++ b/runtime/src/lib.rs @@ -44,7 +44,8 @@ pub unsafe extern "C" fn cairo_native__libfunc__debug__print( let mut items = Vec::with_capacity(len as usize); for i in 0..len as usize { - let data = *data.add(i); + let mut data = *data.add(i); + data[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). let value = Felt::from_bytes_le(&data); items.push(value); @@ -76,22 +77,24 @@ pub unsafe extern "C" fn cairo_native__libfunc__debug__print( /// definitely unsafe to use manually. #[no_mangle] pub unsafe extern "C" fn cairo_native__libfunc__pedersen( - dst: *mut u8, - lhs: *const u8, - rhs: *const u8, + dst: &mut [u8; 32], + lhs: &[u8; 32], + rhs: &[u8; 32], ) { // Extract arrays from the pointers. - let dst = slice::from_raw_parts_mut(dst, 32); - let lhs = slice::from_raw_parts(lhs, 32); - let rhs = slice::from_raw_parts(rhs, 32); + let mut lhs = *lhs; + let mut rhs = *rhs; + + lhs[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). + rhs[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). // Convert to FieldElement. - let lhs = Felt::from_bytes_le_slice(lhs); - let rhs = Felt::from_bytes_le_slice(rhs); + let lhs = Felt::from_bytes_le(&lhs); + let rhs = Felt::from_bytes_le(&rhs); // Compute pedersen hash and copy the result into `dst`. let res = starknet_types_core::hash::Pedersen::hash(&lhs, &rhs); - dst.copy_from_slice(&res.to_bytes_le()); + *dst = res.to_bytes_le(); } /// Compute `hades_permutation(op0, op1, op2)` and replace the operands with the results. @@ -108,29 +111,28 @@ pub unsafe extern "C" fn cairo_native__libfunc__pedersen( /// definitely unsafe to use manually. #[no_mangle] pub unsafe extern "C" fn cairo_native__libfunc__hades_permutation( - op0: *mut u8, - op1: *mut u8, - op2: *mut u8, + op0: &mut [u8; 32], + op1: &mut [u8; 32], + op2: &mut [u8; 32], ) { - // Extract arrays from the pointers. - let op0 = slice::from_raw_parts_mut(op0, 32); - let op1 = slice::from_raw_parts_mut(op1, 32); - let op2 = slice::from_raw_parts_mut(op2, 32); + op0[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). + op1[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). + op2[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). // Convert to FieldElement. let mut state = [ - Felt::from_bytes_le_slice(op0), - Felt::from_bytes_le_slice(op1), - Felt::from_bytes_le_slice(op2), + Felt::from_bytes_le(op0), + Felt::from_bytes_le(op1), + Felt::from_bytes_le(op2), ]; // Compute Poseidon permutation. starknet_types_core::hash::Poseidon::hades_permutation(&mut state); // Write back the results. - op0.copy_from_slice(&state[0].to_bytes_le()); - op1.copy_from_slice(&state[1].to_bytes_le()); - op2.copy_from_slice(&state[2].to_bytes_le()); + *op0 = state[0].to_bytes_le(); + *op1 = state[1].to_bytes_le(); + *op2 = state[2].to_bytes_le(); } /// Felt252 type used in cairo native runtime @@ -230,8 +232,11 @@ pub unsafe extern "C" fn cairo_native__dict_get( dict: &mut FeltDict, key: &[u8; 32], ) -> *mut c_void { + let mut key = *key; + key[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). + dict.count += 1; - dict.inner.entry(*key).or_insert(std::ptr::null_mut()) as *mut _ as *mut c_void + dict.inner.entry(key).or_insert(std::ptr::null_mut()) as *mut _ as *mut c_void } /// Compute the total gas refund for the dictionary at squash time. @@ -260,6 +265,7 @@ pub unsafe extern "C" fn cairo_native__dict_gas_refund(ptr: *const FeltDict) -> pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_point_from_x_nz( point_ptr: &mut [[u8; 32]; 2], ) -> bool { + point_ptr[0][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). let x = Felt::from_bytes_le(&point_ptr[0]); // https://github.com/starkware-libs/cairo/blob/aaad921bba52e729dc24ece07fab2edf09ccfa15/crates/cairo-lang-sierra-to-casm/src/invocations/ec.rs#L63 @@ -276,7 +282,7 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_point_from_x_nz( match AffinePoint::new(x, y) { Ok(point) => { - point_ptr.as_mut()[1].copy_from_slice(&point.y().to_bytes_le()); + point_ptr[1] = point.y().to_bytes_le(); true } Err(_) => false, @@ -297,13 +303,16 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_point_from_x_nz( pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_point_try_new_nz( point_ptr: &mut [[u8; 32]; 2], ) -> bool { + point_ptr[0][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). + point_ptr[1][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). + let x = Felt::from_bytes_le(&point_ptr[0]); let y = Felt::from_bytes_le(&point_ptr[1]); match AffinePoint::new(x, y) { Ok(point) => { - point_ptr[0].copy_from_slice(&point.x().to_bytes_le()); - point_ptr[1].copy_from_slice(&point.y().to_bytes_le()); + point_ptr[0] = point.x().to_bytes_le(); + point_ptr[1] = point.y().to_bytes_le(); true } Err(_) => false, @@ -333,10 +342,10 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_init(state_ptr: &mu // We already made sure its a valid point. let state = AffinePoint::new_unchecked(random_x, random_y); - state_ptr[0].copy_from_slice(&state.x().to_bytes_le()); - state_ptr[1].copy_from_slice(&state.y().to_bytes_le()); - state_ptr[2].copy_from_slice(&state.x().to_bytes_le()); - state_ptr[3].copy_from_slice(&state.y().to_bytes_le()); + state_ptr[0] = state.x().to_bytes_le(); + state_ptr[1] = state.y().to_bytes_le(); + state_ptr[2] = state_ptr[0]; + state_ptr[3] = state_ptr[1]; } /// Compute `ec_state_add(state, point)` and store the state back. @@ -354,6 +363,13 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_add( state_ptr: &mut [[u8; 32]; 4], point_ptr: &[[u8; 32]; 2], ) { + state_ptr[0][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). + state_ptr[1][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). + + let mut point_ptr = *point_ptr; + point_ptr[0][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). + point_ptr[1][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). + // We use unchecked methods because the inputs must already be valid points. let mut state = ProjectivePoint::from_affine_unchecked( Felt::from_bytes_le(&state_ptr[0]), @@ -367,8 +383,8 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_add( state += &point; let state = state.to_affine().unwrap(); - state_ptr[0].copy_from_slice(&state.x().to_bytes_le()); - state_ptr[1].copy_from_slice(&state.y().to_bytes_le()); + state_ptr[0] = state.x().to_bytes_le(); + state_ptr[1] = state.y().to_bytes_le(); } /// Compute `ec_state_add_mul(state, scalar, point)` and store the state back. @@ -387,6 +403,16 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_add_mul( scalar_ptr: &[u8; 32], point_ptr: &[[u8; 32]; 2], ) { + state_ptr[0][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). + state_ptr[1][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). + + let mut point_ptr = *point_ptr; + point_ptr[0][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). + point_ptr[1][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). + + let scalar_ptr = *scalar_ptr; + scalar_ptr[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). + // Here the points should already be checked as valid, so we can use unchecked. let mut state = ProjectivePoint::from_affine_unchecked( Felt::from_bytes_le(&state_ptr[0]), @@ -401,8 +427,8 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_add_mul( state += &point.mul(scalar); let state = state.to_affine().unwrap(); - state_ptr[0].copy_from_slice(&state.x().to_bytes_le()); - state_ptr[1].copy_from_slice(&state.y().to_bytes_le()); + state_ptr[0] = state.x().to_bytes_le(); + state_ptr[1] = state.y().to_bytes_le(); } /// Compute `ec_state_try_finalize_nz(state)` and store the result. @@ -420,6 +446,12 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_try_finalize_nz( point_ptr: &mut [[u8; 32]; 2], state_ptr: &[[u8; 32]; 4], ) -> bool { + let mut state_ptr = *state_ptr; + state_ptr[0][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). + state_ptr[1][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). + state_ptr[2][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). + state_ptr[3][31] &= 0x0F; // Filter out first 4 bits (they're outside an i252). + // We use unchecked methods because the inputs must already be valid points. let state = ProjectivePoint::from_affine_unchecked( Felt::from_bytes_le(&state_ptr[0]), @@ -436,8 +468,8 @@ pub unsafe extern "C" fn cairo_native__libfunc__ec__ec_state_try_finalize_nz( let point = &state - &random; let point = point.to_affine().unwrap(); - point_ptr[0].copy_from_slice(&point.x().to_bytes_le()); - point_ptr[1].copy_from_slice(&point.y().to_bytes_le()); + point_ptr[0] = point.x().to_bytes_le(); + point_ptr[1] = point.y().to_bytes_le(); true }