diff --git a/ipa-core/src/error.rs b/ipa-core/src/error.rs
index 0c2705f1f..6b99bae13 100644
--- a/ipa-core/src/error.rs
+++ b/ipa-core/src/error.rs
@@ -36,6 +36,8 @@ pub enum Error {
     MaliciousSecurityCheckFailed,
     #[error("malicious reveal failed")]
     MaliciousRevealFailed,
+    #[error("share values were inconsistent between helpers")]
+    InconsistentShares,
     #[error("problem during IO: {0}")]
     Io(#[from] std::io::Error),
     // TODO remove if this https://github.com/awslabs/shuttle/pull/109 gets approved
diff --git a/ipa-core/src/protocol/basics/mod.rs b/ipa-core/src/protocol/basics/mod.rs
index 92ebdbfc2..e61f711d7 100644
--- a/ipa-core/src/protocol/basics/mod.rs
+++ b/ipa-core/src/protocol/basics/mod.rs
@@ -7,6 +7,7 @@ mod reshare;
 mod reveal;
 mod share_known_value;
 pub mod sum_of_product;
+pub mod validate;

 pub use check_zero::check_zero;
 pub use if_else::if_else;
diff --git a/ipa-core/src/protocol/basics/validate.rs b/ipa-core/src/protocol/basics/validate.rs
new file mode 100644
index 000000000..80a42abc5
--- /dev/null
+++ b/ipa-core/src/protocol/basics/validate.rs
@@ -0,0 +1,344 @@
+#![allow(dead_code)] // Not wired in yet.
+
+use std::{
+    convert::Infallible,
+    marker::PhantomData,
+    pin::Pin,
+    task::{Context as TaskContext, Poll},
+};
+
+use futures::{
+    future::try_join,
+    stream::{Fuse, Stream, StreamExt},
+    Future, FutureExt,
+};
+use generic_array::GenericArray;
+use pin_project::pin_project;
+use sha2::{
+    digest::{typenum::Unsigned, FixedOutput, OutputSizeUser},
+    Digest, Sha256,
+};
+
+use crate::{
+    error::Error,
+    ff::Serializable,
+    helpers::{Direction, Message},
+    protocol::{context::Context, RecordId},
+    secret_sharing::{replicated::ReplicatedSecretSharing, SharedValue},
+    seq_join::assert_send,
+};
+
+type HashFunction = Sha256;
+type HashSize = <HashFunction as OutputSizeUser>::OutputSize;
+type HashOutputArray = [u8; <HashSize as Unsigned>::USIZE];
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+struct HashValue(GenericArray<u8, HashSize>);
+
+impl Serializable for HashValue {
+    type Size = HashSize;
+    type DeserializationError = Infallible;
+
+    fn serialize(&self, buf: &mut GenericArray<u8, Self::Size>) {
+        buf.copy_from_slice(self.0.as_slice());
+    }
+
+    fn deserialize(buf: &GenericArray<u8, Self::Size>) -> Result<Self, Self::DeserializationError> {
+        Ok(Self(buf.to_owned()))
+    }
+}
+
+impl Message for HashValue {}
+
+impl From<HashFunction> for HashValue {
+    fn from(value: HashFunction) -> Self {
+        // Ugh: The version of sha2 we currently use doesn't use the same GenericArray version as we do.
+        HashValue(GenericArray::from(<HashOutputArray>::from(
+            value.finalize_fixed(),
+        )))
+    }
+}
+
+/// The finalizing state for the validator.
+struct ReplicatedValidatorFinalization<'a> {
+    f: Pin<Box<(dyn Future<Output = Result<(), Error>> + Send + 'a)>>,
+}
+
+impl<'a> ReplicatedValidatorFinalization<'a> {
+    fn new<C: Context + 'a>(active: ReplicatedValidatorActive<'a, C>) -> Self {
+        let ReplicatedValidatorActive {
+            ctx,
+            left_hash,
+            right_hash,
+            ..
+        } = active;
+        let left_hash = HashValue::from(left_hash);
+        let right_hash = HashValue::from(right_hash);
+        let left_peer = ctx.role().peer(Direction::Left);
+        let right_peer = ctx.role().peer(Direction::Right);
+
+        let f = Box::pin(assert_send(async move {
+            try_join(
+                ctx.send_channel(left_peer)
+                    .send(RecordId::FIRST, left_hash.clone()),
+                ctx.send_channel(right_peer)
+                    .send(RecordId::FIRST, right_hash.clone()),
+            )
+            .await?;
+            let (left_recvd, right_recvd) = try_join(
+                ctx.recv_channel(left_peer).receive(RecordId::FIRST),
+                ctx.recv_channel(right_peer).receive(RecordId::FIRST),
+            )
+            .await?;
+            if left_hash == left_recvd && right_hash == right_recvd {
+                Ok(())
+            } else {
+                Err(Error::InconsistentShares)
+            }
+        }));
+        Self { f }
+    }
+
+    fn poll(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<(), Error>> {
+        self.f.poll_unpin(cx)
+    }
+}
+
+/// The active state for the validator.
+struct ReplicatedValidatorActive<'a, C: 'a> {
+    ctx: C,
+    left_hash: Sha256,
+    right_hash: Sha256,
+    _marker: PhantomData<&'a ()>,
+}
+
+impl<'a, C: Context + 'a> ReplicatedValidatorActive<'a, C> {
+    fn new(ctx: C) -> Self {
+        Self {
+            ctx,
+            left_hash: HashFunction::new(),
+            right_hash: HashFunction::new(),
+            _marker: PhantomData,
+        }
+    }
+
+    fn update<S, V>(&mut self, s: &S)
+    where
+        S: ReplicatedSecretSharing<V>,
+        V: SharedValue,
+    {
+        let mut buf = GenericArray::default(); // ::<u8, <V as Serializable>::Size>
+        s.left().serialize(&mut buf);
+        self.left_hash.update(buf.as_slice());
+        s.right().serialize(&mut buf);
+        self.right_hash.update(buf.as_slice());
+    }
+
+    fn finalize(self) -> ReplicatedValidatorFinalization<'a> {
+        ReplicatedValidatorFinalization::new(self)
+    }
+}
+
+enum ReplicatedValidatorState<'a, C: 'a> {
+    /// While the validator is waiting, it holds a context reference.
+    Pending(Option<Box<ReplicatedValidatorActive<'a, C>>>),
+    /// After the validator has taken all of its inputs, it holds a future.
+    Finalizing(ReplicatedValidatorFinalization<'a>),
+}
+
+impl<'a, C: Context + 'a> ReplicatedValidatorState<'a, C> {
+    /// # Panics
+    /// This panics if it is called after `finalize()`.
+    fn update<S, V>(&mut self, s: &S)
+    where
+        S: ReplicatedSecretSharing<V>,
+        V: SharedValue,
+    {
+        if let Self::Pending(Some(a)) = self {
+            a.update(s);
+        } else {
+            panic!();
+        }
+    }
+
+    fn poll(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<(), Error>> {
+        match self {
+            Self::Pending(ref mut active) => {
+                let mut f = active.take().unwrap().finalize();
+                let res = f.poll(cx);
+                *self = ReplicatedValidatorState::Finalizing(f);
+                res
+            }
+            Self::Finalizing(f) => f.poll(cx),
+        }
+    }
+}
+
+/// A `ReplicatedValidator` takes a stream of replicated shares of anything
+/// and produces a stream of the same values, without modifying them.
+/// The only thing it does is check that the values are consistent across
+/// all three helpers using the provided context.
+/// To do this, it sends a single message.
+///
+/// If validation passes, the stream is completely transparent.
+/// If validation fails, the stream will error before it closes.
+#[pin_project]
+struct ReplicatedValidator<'a, C: 'a, T: Stream, S, V> {
+    #[pin]
+    input: Fuse<T>,
+    state: ReplicatedValidatorState<'a, C>,
+    _marker: PhantomData<(S, V)>,
+}
+
+impl<'a, C: Context + 'a, T: Stream, S, V> ReplicatedValidator<'a, C, T, S, V> {
+    pub fn new(ctx: &C, s: T) -> Self {
+        Self {
+            input: s.fuse(),
+            state: ReplicatedValidatorState::Pending(Some(Box::new(
+                ReplicatedValidatorActive::new(ctx.set_total_records(1)),
+            ))),
+            _marker: PhantomData,
+        }
+    }
+}
+
+impl<'a, C, T, S, V> Stream for ReplicatedValidator<'a, C, T, S, V>
+where
+    C: Context + 'a,
+    T: Stream<Item = Result<S, Error>>,
+    S: ReplicatedSecretSharing<V>,
+    V: SharedValue,
+{
+    type Item = Result<S, Error>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
+        let this = self.project();
+        match this.input.poll_next(cx) {
+            Poll::Ready(Some(v)) => match v {
+                Ok(v) => {
+                    this.state.update(&v);
+                    Poll::Ready(Some(Ok(v)))
+                }
+                Err(e) => Poll::Ready(Some(Err(e))),
+            },
+            Poll::Ready(None) => match this.state.poll(cx) {
+                Poll::Pending => Poll::Pending,
+                Poll::Ready(Ok(())) => Poll::Ready(None),
+                Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
+            },
+            Poll::Pending => Poll::Pending,
+        }
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        self.input.size_hint()
+    }
+}
+
+#[cfg(all(test, unit_test))]
+mod test {
+    use std::iter::repeat_with;
+
+    use futures::stream::{iter as stream_iter, Stream, StreamExt, TryStreamExt};
+
+    use crate::{
+        error::Error,
+        ff::{Field, Fp31},
+        helpers::{Direction, Role},
+        protocol::{basics::validate::ReplicatedValidator, context::Context, RecordId},
+        rand::{thread_rng, Rng},
+        secret_sharing::{
+            replicated::{
+                semi_honest::AdditiveShare as SemiHonestReplicated, ReplicatedSecretSharing,
+            },
+            SharedValue,
+        },
+        test_fixture::{Reconstruct, Runner, TestWorld},
+    };
+
+    fn assert_stream<S: Stream<Item = Result<T, Error>>, T>(s: S) -> S {
+        s
+    }
+
+    /// Successfully validate some shares.
+    #[tokio::test]
+    pub async fn simple() {
+        let mut rng = thread_rng();
+        let world = TestWorld::default();
+
+        let input = repeat_with(|| rng.gen::<Fp31>())
+            .take(10)
+            .collect::<Vec<_>>();
+        let result = world
+            .semi_honest(input.into_iter(), |ctx, shares| async move {
+                let s = stream_iter(shares).map(Ok);
+                let vs = ReplicatedValidator::new(&ctx.narrow("validate"), s);
+                let sum = assert_stream(vs)
+                    .try_fold(Fp31::ZERO, |sum, value| async move {
+                        Ok(sum + value.left() - value.right())
+                    })
+                    .await?;
+                let ctx = ctx.set_total_records(1);
+                // This value should sum to zero now, so replicate the value.
+                // (We don't care here that this reveals our share to other helpers, it's just a test.)
+                ctx.send_channel(ctx.role().peer(Direction::Right))
+                    .send(RecordId::FIRST, sum)
+                    .await?;
+                let left = ctx
+                    .recv_channel(ctx.role().peer(Direction::Left))
+                    .receive(RecordId::FIRST)
+                    .await?;
+                Ok(SemiHonestReplicated::new(left, sum))
+            })
+            .await
+            .map(Result::<_, Error>::unwrap)
+            .reconstruct();
+
+        assert_eq!(Fp31::ZERO, result);
+    }
+
+    #[tokio::test]
+    pub async fn inconsistent() {
+        let mut rng = thread_rng();
+        let world = TestWorld::default();
+
+        let damage = |role| {
+            let mut tweak = role == Role::H3;
+            move |v: SemiHonestReplicated<Fp31>| -> SemiHonestReplicated<Fp31> {
+                if tweak {
+                    tweak = false;
+                    SemiHonestReplicated::new(v.left(), v.right() + Fp31::ONE)
+                } else {
+                    v
+                }
+            }
+        };
+
+        let input = repeat_with(|| rng.gen::<Fp31>())
+            .take(10)
+            .collect::<Vec<_>>();
+        let result = world
+            .semi_honest(input.into_iter(), |ctx, shares| async move {
+                let s = stream_iter(shares).map(damage(ctx.role())).map(Ok);
+                let vs = ReplicatedValidator::new(&ctx.narrow("validate"), s);
+                let sum = assert_stream(vs)
+                    .try_fold(Fp31::ZERO, |sum, value| async move {
+                        Ok(sum + value.left() - value.right())
+                    })
+                    .await?;
+                Ok(sum) // This will not be reached by 2/3 helpers.
+            })
+            .await;
+
+        // With just one error having been introduced, two of three helpers will error out.
+        assert!(matches!(
+            result[0].as_ref().unwrap_err(),
+            Error::InconsistentShares
+        ));
+        assert!(result[1].is_ok());
+        assert!(matches!(
+            result[2].as_ref().unwrap_err(),
+            Error::InconsistentShares
+        ));
+    }
+}
diff --git a/ipa-macros/src/derive_step/mod.rs b/ipa-macros/src/derive_step/mod.rs
index 9093916e5..f03060422 100644
--- a/ipa-macros/src/derive_step/mod.rs
+++ b/ipa-macros/src/derive_step/mod.rs
@@ -44,7 +44,7 @@ use syn::{parse_macro_input, DeriveInput};

 use crate::{
     parser::{group_by_modules, ipa_state_transition_map, StepMetaData},
-    tree::Node,
+    tree::{self, Node},
 };

 const MAX_DYNAMIC_STEPS: usize = 1024;
@@ -115,7 +115,7 @@ fn impl_as_ref(ident: &syn::Ident, data: &syn::DataEnum) -> Result
             Result>();
         let steps_array_ident = format_ident!("{}_DYNAMIC_STEP", ident_upper_case);
         const_arrays.extend(quote!(
@@ -272,9 +272,8 @@ fn get_meta_data_for(
         1 => {
             Ok(target_steps[0]
                 .iter()
-                .map(|s|
-                    // we want to retain the references to the parents, so we use `upgrade()`
-                    s.upgrade())
+                // we want to retain the references to the parents, so we use `upgrade()`
+                .map(tree::Node::upgrade)
                 .collect::<Vec<_>>())
         }
         _ => Err(syn::Error::new_spanned(
@@ -314,8 +313,7 @@ fn get_dynamic_step_count(variant: &syn::Variant) -> Result<usize, syn::Error> {
             dynamic_attr,
             format!(
                 "ipa_macros::step \"dynamic\" attribute expects a number of steps \
-                (<= {}) in parentheses: #[dynamic(...)].",
-                MAX_DYNAMIC_STEPS,
+                (<= {MAX_DYNAMIC_STEPS}) in parentheses: #[dynamic(...)].",
             ),
         )),
     }
diff --git a/ipa-macros/src/parser.rs b/ipa-macros/src/parser.rs
index 25dd32bbe..691341c64 100644
--- a/ipa-macros/src/parser.rs
+++ b/ipa-macros/src/parser.rs
@@ -71,7 +71,10 @@ pub(crate) fn read_steps_file(file_path: &str) -> Vec<String> {
     let mut file = std::fs::File::open(path).expect("Could not open the steps file");
     let mut contents = String::new();
     file.read_to_string(&mut contents).unwrap();
-    contents.lines().map(|s| s.to_owned()).collect::<Vec<_>>()
+    contents
+        .lines()
+        .map(std::borrow::ToOwned::to_owned)
+        .collect::<Vec<_>>()
 }

 /// Constructs a tree structure with nodes that contain the `Step` instances.
@@ -109,10 +112,15 @@ pub(crate) fn construct_tree(steps: Vec<StepMetaData>) -> Node<StepMetaData> {
 /// Split a single substep full path into the module path and the step's name.
 ///
 /// # Example
+/// ```ignore
 /// input = "ipa::protocol::modulus_conversion::convert_shares::Step::xor1"
 /// output = ("ipa::protocol::modulus_conversion::convert_shares::Step", "xor1")
+/// ```
 pub(crate) fn split_step_module_and_name(input: &str) -> (String, String) {
-    let mod_parts = input.split("::").map(|s| s.to_owned()).collect::<Vec<_>>();
+    let mod_parts = input
+        .split("::")
+        .map(std::borrow::ToOwned::to_owned)
+        .collect::<Vec<_>>();
     let (substep_name, path) = mod_parts.split_last().unwrap();
     (path.join("::"), substep_name.to_owned())
 }
@@ -123,8 +131,8 @@ pub(crate) fn split_step_module_and_name(input: &str) -> (String, String) {
 /// # Example
 /// Let say we have the following steps:
 ///
-/// - StepA::A1
-/// - StepC::C1/StepD::D1/StepA::A2
+/// - `StepA::A1`
+/// - `StepC::C1/StepD::D1/StepA::A2`
 ///
 /// If we generate code for each node while traversing, we will end up with the following:
 ///
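For reviewers unfamiliar with the check itself: `ReplicatedValidator` folds every left share into one running SHA-256 digest and every right share into another, exchanges the two digests with its neighbours, and fails with `Error::InconsistentShares` on any mismatch. The sketch below is not part of the diff; it is a minimal, standalone illustration of that digest-comparison idea using plain `sha2` calls, with a toy `Share` type, three in-process "helpers", and hard-coded share values that are assumptions for the example only.

```rust
// Standalone sketch of the digest-comparison idea behind `ReplicatedValidator`.
// Assumptions (not from the diff): three helpers simulated in one process, a toy
// `Share` type holding one byte per side, and no real networking or IPA types.
use sha2::{Digest, Sha256};

#[derive(Clone, Copy)]
struct Share {
    left: u8,
    right: u8,
}

/// Hash all the left values and all the right values a helper has seen,
/// mirroring the running `left_hash`/`right_hash` kept by the validator.
fn digests(shares: &[Share]) -> (Vec<u8>, Vec<u8>) {
    let (mut left, mut right) = (Sha256::new(), Sha256::new());
    for s in shares {
        left.update([s.left]);
        right.update([s.right]);
    }
    (left.finalize().to_vec(), right.finalize().to_vec())
}

fn main() {
    // Replicated sharing of the values 5 and 7 (mod 256): helper i holds (x_i, x_{i+1}).
    let helpers: [Vec<Share>; 3] = [
        vec![Share { left: 1, right: 2 }, Share { left: 3, right: 1 }],
        vec![Share { left: 2, right: 2 }, Share { left: 1, right: 3 }],
        vec![Share { left: 2, right: 1 }, Share { left: 3, right: 3 }],
    ];
    let d: Vec<_> = helpers.iter().map(|h| digests(h)).collect();

    // Helper i's right digest must match helper (i+1)'s left digest, because both
    // hashed the same underlying share values. A mismatch here is what the real
    // validator reports as `Error::InconsistentShares`.
    for i in 0..3 {
        let next = (i + 1) % 3;
        println!("helpers {i}/{next} consistent: {}", d[i].1 == d[next].0);
    }
}
```

This also explains the expectations in the `inconsistent` test above: damaging one share's right value on `H3` breaks exactly the two digest comparisons that `H3` and `H1` perform against each other, so those two helpers fail while `H2` still passes.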