From d0d2c988a83d158789d9e6c00dc4f93d8f408f80 Mon Sep 17 00:00:00 2001 From: Jonas Schneider-Bensch Date: Tue, 4 Jun 2024 11:51:08 +0200 Subject: [PATCH 01/14] Simplify `AuthBit` using a const generic parameter --- atlas-spec/mpc-engine/src/party.rs | 247 ++++++++---------- .../mpc-engine/src/primitives/auth_share.rs | 28 +- atlas-spec/mpc-engine/src/primitives/mac.rs | 10 + 3 files changed, 136 insertions(+), 149 deletions(-) diff --git a/atlas-spec/mpc-engine/src/party.rs b/atlas-spec/mpc-engine/src/party.rs index 9dd7fbb..7e6831a 100644 --- a/atlas-spec/mpc-engine/src/party.rs +++ b/atlas-spec/mpc-engine/src/party.rs @@ -11,7 +11,7 @@ use crate::{ auth_share::{AuthBit, Bit, BitID, BitKey}, commitment::{Commitment, Opening}, mac::{ - generate_mac_key, hash_to_mac_width, mac, verify_mac, xor_mac_width, Mac, MacKey, + self, generate_mac_key, hash_to_mac_width, mac, verify_mac, xor_mac_width, Mac, MacKey, MAC_LENGTH, }, }, @@ -27,6 +27,7 @@ const SEC_MARGIN_BIT_AUTH: usize = 2 * STATISTICAL_SECURITY * 8; pub(crate) const SEC_MARGIN_SHARE_AUTH: usize = STATISTICAL_SECURITY * 8; const EVALUATOR_ID: usize = 0; +const NUM_PARTIES: usize = 4; /// Collects all party communication channels. /// @@ -73,15 +74,15 @@ pub struct Party { /// A local source of random bits and bytes entropy: Randomness, /// Pool of pre-computed authenticated bits - abit_pool: Vec, + abit_pool: Vec>, /// Pool of pre-computed authenticated shares - ashare_pool: Vec, + ashare_pool: Vec>, /// Whether to log events enable_logging: bool, /// Incremental counter for ordering logs log_counter: u128, /// Wire labels for every wire in the circuit - wire_shares: Vec)>>, + wire_shares: Vec, Option)>>, } impl Party { @@ -294,7 +295,7 @@ impl Party { /// After this point the guarantee is that a pair-wise consistent /// `global_mac_key` was used in all bit-authentications between two /// parties. - fn precompute_abits(&mut self, len: usize) -> Result, Error> { + fn precompute_abits(&mut self, len: usize) -> Result>, Error> { let len_unchecked = len + SEC_MARGIN_BIT_AUTH; // 1. Generate `len_unchecked` random local bits for authenticating. @@ -317,7 +318,7 @@ impl Party { let mut authenticated_bits = Vec::new(); for (_bit_index, bit) in bits.into_iter().enumerate() { let mut computed_keys: Vec = Vec::new(); - let mut received_macs = Vec::new(); + let mut received_macs = [mac::zero_mac(); NUM_PARTIES]; // Obliviously authenticate local bits of earlier parties. for bit_holder in 0..self.id { @@ -332,7 +333,7 @@ impl Party { } let received_mac: Mac = self.obtain_bit_authentication(authenticator, &bit)?; - received_macs.push((authenticator, received_mac)); + received_macs[authenticator] = received_mac; } // Obliviously authenticate local bits of later parties. @@ -363,9 +364,13 @@ impl Party { } /// Transform authenticated bits into `len` authenticated bit shares. - fn random_authenticated_shares(&mut self, len: usize) -> Result, Error> { + fn random_authenticated_shares( + &mut self, + len: usize, + ) -> Result>, Error> { let len_unchecked = len + SEC_MARGIN_SHARE_AUTH; - let authenticated_bits: Vec = self.abit_pool.drain(..len_unchecked).collect(); + let authenticated_bits: Vec> = + self.abit_pool.drain(..len_unchecked).collect(); // Malicious security checks for r in len..len + SEC_MARGIN_SHARE_AUTH { @@ -412,7 +417,9 @@ impl Party { .expect("should have received commitments from all parties"); other_bits_macs.push(( party, - AuthBit::deserialize_bit_macs(&their_mac_commitment.open(&their_opening)?)?, + AuthBit::::deserialize_bit_macs( + &their_mac_commitment.open(&their_opening)?, + )?, )); } @@ -433,12 +440,7 @@ impl Party { for p in 0..self.num_parties { let their_mac = if p == self.id { - authenticated_bits[r] - .macs - .iter() - .find(|(party, _mac)| *party == maccing_party) - .expect("should have MACs from all other parties") - .1 + authenticated_bits[r].macs[maccing_party] } else { let (_sending_party, (_other_bit, other_macs)) = other_bits_macs .iter() @@ -507,7 +509,11 @@ impl Party { } /// Compute unauthenticated cross terms in an AND triple output share. - fn half_and(&mut self, x: &AuthBit, y: &AuthBit) -> Result { + fn half_and( + &mut self, + x: &AuthBit, + y: &AuthBit, + ) -> Result { /// Obtain the least significant bit of some hash output fn lsb(input: &[u8]) -> bool { (input[input.len() - 1] & 1) != 0 @@ -528,12 +534,7 @@ impl Party { } = hashes_message { debug_assert_eq!(to, self.id); - let their_mac = x - .macs - .iter() - .find(|(party, _mac)| *party == from) - .expect("should have MACs from all other parties") - .1; + let their_mac = x.macs[from]; let hash_lsb = lsb(&hash_to_mac_width(domain_separator, &their_mac)); let t_j = if x.bit.value { hash_j_1 ^ hash_lsb @@ -591,12 +592,8 @@ impl Party { } = hashes_message { debug_assert_eq!(to, self.id); - let their_mac = x - .macs - .iter() - .find(|(party, _mac)| *party == from) - .expect("should have MACs from all other parties") - .1; + let their_mac = x.macs[from]; + let hash_lsb = lsb(&hash_to_mac_width(domain_separator, &their_mac)); let t_j = if x.bit.value { hash_j_1 ^ hash_lsb @@ -623,9 +620,19 @@ impl Party { } /// Compute authenticated AND triples. - fn random_leaky_and(&mut self, len: usize) -> Result, Error> { + fn random_leaky_and( + &mut self, + len: usize, + ) -> Result< + Vec<( + AuthBit, + AuthBit, + AuthBit, + )>, + Error, + > { let mut results = Vec::new(); - let mut shares: Vec = self.ashare_pool.drain(..3 * len).collect(); + let mut shares: Vec> = self.ashare_pool.drain(..3 * len).collect(); for _i in 0..len { let x = shares.pop().expect("requested enough authenticated bits"); let y = shares.pop().expect("requested enough authenticated bits"); @@ -656,12 +663,9 @@ impl Party { // 4. compute Phi let mut phi = [0u8; MAC_LENGTH]; for key in y.mac_keys.iter() { - let (_, their_mac) = y - .macs - .iter() - .find(|(maccing_party, _)| *maccing_party == key.bit_holder) - .unwrap(); - let intermediate_xor = xor_mac_width(&key.mac_key, their_mac); + let their_mac = y.macs[key.bit_holder]; + + let intermediate_xor = xor_mac_width(&key.mac_key, &their_mac); phi = xor_mac_width(&phi, &intermediate_xor); } @@ -683,12 +687,9 @@ impl Party { { debug_assert_eq!(self.id, to); // compute M_phi - let (_, their_mac) = x - .macs - .iter() - .find(|(maccing_party, _)| *maccing_party == from) - .expect("should have MACs from all other parties"); - let mut mac_phi = hash_to_mac_width(domain_separator_triple, their_mac); + let their_mac = x.macs[from]; + + let mut mac_phi = hash_to_mac_width(domain_separator_triple, &their_mac); if x.bit.value { for byte in 0..MAC_LENGTH { mac_phi[byte] ^= u[byte]; @@ -743,12 +744,9 @@ impl Party { { debug_assert_eq!(self.id, to); // compute M_phi - let (_, their_mac) = x - .macs - .iter() - .find(|(maccing_party, _)| *maccing_party == from) - .expect("should have MACs from all other parties"); - let mut mac_phi = hash_to_mac_width(domain_separator_triple, their_mac); + let their_mac = x.macs[from]; + + let mut mac_phi = hash_to_mac_width(domain_separator_triple, &their_mac); if x.bit.value { for byte in 0..MAC_LENGTH { mac_phi[byte] ^= u[byte]; @@ -775,12 +773,9 @@ impl Party { } for key in z.mac_keys.iter() { - let (_, their_mac) = z - .macs - .iter() - .find(|(maccing_party, _)| key.bit_holder == *maccing_party) - .expect("should have MACs from all other parties"); - let intermediate_xor = xor_mac_width(&key.mac_key, their_mac); + let their_mac = z.macs[key.bit_holder]; + + let intermediate_xor = xor_mac_width(&key.mac_key, &their_mac); h = xor_mac_width(&h, &intermediate_xor); } @@ -816,7 +811,7 @@ impl Party { } /// Verifiably open an authenticated bit, revealing its value to all parties. - fn open_bit(&mut self, bit: &AuthBit) -> Result { + fn open_bit(&mut self, bit: &AuthBit) -> Result { let mut other_bits = Vec::new(); // receive earlier parties MACs and verify them @@ -848,16 +843,13 @@ impl Party { if j == self.id { continue; } - let (_, their_mac) = bit - .macs - .iter() - .find(|(maccing_party, _mac)| j == *maccing_party) - .expect("should have MACs from all other parties"); + let their_mac = bit.macs[j]; + self.channels.parties[j] .send(Message { from: self.id, to: j, - payload: MessagePayload::BitReveal(bit.bit.value, *their_mac), + payload: MessagePayload::BitReveal(bit.bit.value, their_mac), }) .unwrap(); } @@ -898,20 +890,20 @@ impl Party { /// Locally compute the XOR of two authenticated bits, which will itself be /// authenticated already. - fn xor_abits(&mut self, a: &AuthBit, b: &AuthBit) -> AuthBit { - let mut macs = Vec::new(); - for (maccing_party, mac) in a.macs.iter() { + fn xor_abits( + &mut self, + a: &AuthBit, + b: &AuthBit, + ) -> AuthBit { + let mut macs = [mac::zero_mac(); NUM_PARTIES]; + for (maccing_party, mac) in a.macs.iter().enumerate() { let mut xored_mac = [0u8; MAC_LENGTH]; - let other_mac = b - .macs - .iter() - .find(|(party, _)| *party == *maccing_party) - .expect("should have MACs from all other parties") - .1; + let other_mac = b.macs[maccing_party]; + for byte in 0..MAC_LENGTH { xored_mac[byte] = mac[byte] ^ other_mac[byte]; } - macs.push((*maccing_party, xored_mac)) + macs[maccing_party] = xored_mac; } let mut mac_keys = Vec::new(); @@ -945,10 +937,14 @@ impl Party { fn and_abits( &mut self, - random_triple: (AuthBit, AuthBit, AuthBit), - x: &AuthBit, - y: &AuthBit, - ) -> Result { + random_triple: ( + AuthBit, + AuthBit, + AuthBit, + ), + x: &AuthBit, + y: &AuthBit, + ) -> Result, Error> { let (a, b, c) = random_triple; let blinded_x_share = self.xor_abits(x, &a); let blinded_y_share = self.xor_abits(y, &b); @@ -969,7 +965,7 @@ impl Party { /// Invert an authenticated bit, resulting in an authentication of the /// inverted bit. - fn invert_abit(&mut self, a: &AuthBit) -> AuthBit { + fn invert_abit(&mut self, a: &AuthBit) -> AuthBit { let mut mac_keys = a.mac_keys.clone(); for key in mac_keys.iter_mut() { key.mac_key = xor_mac_width(&key.mac_key, &self.global_mac_key) @@ -990,7 +986,14 @@ impl Party { &mut self, len: usize, bucket_size: usize, - ) -> Result, Error> { + ) -> Result< + Vec<( + AuthBit, + AuthBit, + AuthBit, + )>, + Error, + > { // get `len * BUCKET_SIZE` leaky ANDs let leaky_ands = self.random_leaky_and(len * bucket_size)?; @@ -998,7 +1001,14 @@ impl Party { // Using random u128 bit indices for shuffling should prevent collisions // for at least 2^40 triples except with probability 2^-40. let random_indices = self.coin_flip(leaky_ands.len() * 8 * 16)?; - let mut indexed_ands: Vec<(u128, (AuthBit, AuthBit, AuthBit))> = random_indices + let mut indexed_ands: Vec<( + u128, + ( + AuthBit, + AuthBit, + AuthBit, + ), + )> = random_indices .chunks_exact(16) .map(|chunk| { u128::from_be_bytes(chunk.try_into().expect("chunks are exactly the right size")) @@ -1006,8 +1016,11 @@ impl Party { .zip(leaky_ands) .collect(); indexed_ands.sort_by_key(|(index, _)| *index); - let leaky_ands: Vec<&(AuthBit, AuthBit, AuthBit)> = - indexed_ands.iter().map(|(_, triple)| triple).collect(); + let leaky_ands: Vec<&( + AuthBit, + AuthBit, + AuthBit, + )> = indexed_ands.iter().map(|(_, triple)| triple).collect(); // combine all buckets to single ANDs let mut results = Vec::new(); @@ -1031,7 +1044,7 @@ impl Party { } /// Perform the active_security check for bit authentication - fn bit_auth_check(&mut self, auth_bits: &[AuthBit]) -> Result<(), Error> { + fn bit_auth_check(&mut self, auth_bits: &[AuthBit]) -> Result<(), Error> { for _j in 0..SEC_MARGIN_BIT_AUTH { // a) Sample `ell'` random bit.s let r = self.coin_flip(auth_bits.len())?; @@ -1063,9 +1076,9 @@ impl Party { xored_keys[mac_keys.bit_holder][byte] ^= mac_keys.mac_key[byte]; } } - for (key_holder, tag) in xm.macs.iter() { + for (key_holder, tag) in xm.macs.iter().enumerate() { for (index, tag_byte) in tag.iter().enumerate() { - xored_tags[*key_holder][index] ^= *tag_byte; + xored_tags[key_holder][index] ^= *tag_byte; } } } @@ -1355,7 +1368,7 @@ impl Party { fn function_dependent( &mut self, circuit: &Circuit, - ) -> Result<(Vec, Vec<(usize, u8, AuthBit)>), Error> { + ) -> Result<(Vec, Vec<(usize, u8, AuthBit)>), Error> { let num_and_triples = circuit.num_and_gates(); let mut and_shares = self .random_and_shares(num_and_triples, circuit.and_bucket_size()) @@ -1625,13 +1638,8 @@ impl Party { self.broadcast(&vec![masked_wire_value as u8])?; } else { // send input wire shares to the party - let their_mac = wire_share - .0 - .macs - .iter() - .find(|(maccing_party, _)| *maccing_party == party) - .expect("should have macs from all other parties") - .1; + let their_mac = wire_share.0.macs[party]; + self.channels.parties[party] .send(Message { from: self.id, @@ -1740,7 +1748,7 @@ impl Party { &mut self, circuit: &Circuit, garbled_ands: Vec, - local_ands: Vec<(usize, u8, AuthBit)>, + local_ands: Vec<(usize, u8, AuthBit)>, masked_input_wire_values: Vec<(usize, bool)>, input_wire_labels: Vec<(usize, usize, [u8; MAC_LENGTH])>, ) -> Result<(Vec<(usize, bool)>, Vec<(usize, usize, [u8; 16])>), Error> { @@ -1952,7 +1960,7 @@ impl Party { } } else { // send output wire mask shares - let evaluator_mac = output_wire_share.macs[EVALUATOR_ID].1; + let evaluator_mac = output_wire_share.macs[EVALUATOR_ID]; self.channels .evaluator .send(Message { @@ -1988,8 +1996,6 @@ impl Party { circuit: &Circuit, input: &[bool], ) -> Result>, Error> { - use std::io::Write; - // Validate the circuit circuit .validate_circuit_specification() @@ -2000,42 +2006,9 @@ impl Party { panic!("Invalid input provided to party {}", self.id) } - let num_auth_shares = circuit.share_authentication_cost() + SEC_MARGIN_SHARE_AUTH; - - if read_stored_triples { - let file = std::fs::File::open(format!("{}.triples", self.id)); - if let Ok(f) = file { - (self.global_mac_key, self.abit_pool) = - serde_json::from_reader(f).map_err(|_| Error::OtherError)?; - - let max_id = self - .abit_pool - .iter() - .max_by_key(|abit| abit.bit.id.0) - .map(|abit| abit.bit.id.0) - .unwrap_or(0); - self.bit_counter = max_id; - - if num_auth_shares > self.abit_pool.len() { - self.log(&format!( - "Insufficient precomputation (by {})", - num_auth_shares - self.abit_pool.len() - )); - return Ok(None); - } - } - } else { - let target_number = circuit.share_authentication_cost(); + let target_number = circuit.share_authentication_cost(); - self.abit_pool = self.precompute_abits(target_number + SEC_MARGIN_SHARE_AUTH)?; - - let file = std::fs::File::create(format!("{}.triples", self.id)) - .map_err(|_| Error::OtherError)?; - let mut writer = std::io::BufWriter::new(file); - serde_json::to_writer(&mut writer, &(self.global_mac_key, &self.abit_pool)) - .map_err(|_| Error::OtherError)?; - writer.flush().unwrap(); - } + self.abit_pool = self.precompute_abits(target_number + SEC_MARGIN_SHARE_AUTH)?; self.function_independent(circuit).unwrap(); @@ -2147,7 +2120,7 @@ impl Party { &self, gate_index: usize, garble_index: u8, - and_share: AuthBit, + and_share: AuthBit, output_label: [u8; 16], left_label: [u8; 16], right_label: [u8; 16], @@ -2175,7 +2148,7 @@ impl Party { garbled_and: &[u8], left_label: [u8; 16], right_label: [u8; 16], - ) -> Result<(bool, Vec<[u8; MAC_LENGTH]>, [u8; MAC_LENGTH]), Error> { + ) -> Result<(bool, [Mac; NUM_PARTIES], [u8; MAC_LENGTH]), Error> { let blinding: Vec = compute_blinding( garbled_and.len(), left_label, @@ -2193,7 +2166,11 @@ impl Party { } /// Serialize an authenticated wire share for garbling AND gates. - fn garbling_serialize(&self, and_share: AuthBit, output_label: [u8; 16]) -> Vec { + fn garbling_serialize( + &self, + and_share: AuthBit, + output_label: [u8; 16], + ) -> Vec { let mut result = and_share.serialize_bit_macs(); let mut garbled_label = output_label; for key in and_share.mac_keys { @@ -2211,7 +2188,7 @@ impl Party { fn garbling_deserialize( &self, serialization: &[u8], - ) -> Result<(bool, Vec<[u8; 16]>, [u8; 16]), Error> { + ) -> Result<(bool, [Mac; NUM_PARTIES], [u8; 16]), Error> { let (bit_mac_bytes, label) = serialization.split_at(1 + MAC_LENGTH * self.num_parties); let (bit_value, macs) = AuthBit::deserialize_bit_macs(bit_mac_bytes)?; Ok((bit_value, macs, label.try_into().unwrap())) diff --git a/atlas-spec/mpc-engine/src/primitives/auth_share.rs b/atlas-spec/mpc-engine/src/primitives/auth_share.rs index 2351d6e..694701c 100644 --- a/atlas-spec/mpc-engine/src/primitives/auth_share.rs +++ b/atlas-spec/mpc-engine/src/primitives/auth_share.rs @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize}; use crate::{primitives::mac::MAC_LENGTH, Error}; -use super::mac::{Mac, MacKey}; +use super::mac::{self, Mac, MacKey}; /// A bit held by a party with a given ID. #[derive(Debug, Clone, Serialize, Deserialize)] @@ -18,20 +18,20 @@ pub struct Bit { /// party, their party ID is also required to disambiguate. pub struct BitID(pub(crate) usize); -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone)] /// A bit authenticated between two parties. -pub struct AuthBit { +pub struct AuthBit { pub(crate) bit: Bit, - pub(crate) macs: Vec<(usize, Mac)>, + pub(crate) macs: [Mac; NUM_PARTIES], pub(crate) mac_keys: Vec, } -impl AuthBit { +impl AuthBit { /// Serialize the bit value and all MACs on the bit. pub fn serialize_bit_macs(&self) -> Vec { - let mut result = vec![0u8; (self.macs.len() + 1) * MAC_LENGTH + 1]; + let mut result = vec![0u8; NUM_PARTIES * MAC_LENGTH + 1]; result[0] = self.bit.value as u8; - for (key_holder, mac) in self.macs.iter() { + for (key_holder, mac) in self.macs.iter().enumerate() { result[1 + key_holder * MAC_LENGTH..1 + (key_holder + 1) * MAC_LENGTH] .copy_from_slice(mac); } @@ -40,7 +40,7 @@ impl AuthBit { } /// Deserialize a bit and MACs on that bit. - pub fn deserialize_bit_macs(bytes: &[u8]) -> Result<(bool, Vec<[u8; MAC_LENGTH]>), Error> { + pub fn deserialize_bit_macs(bytes: &[u8]) -> Result<(bool, [Mac; NUM_PARTIES]), Error> { if bytes[0] > 1 { return Err(Error::InvalidSerialization); } @@ -50,12 +50,12 @@ impl AuthBit { return Err(Error::InvalidSerialization); } - let mut macs: Vec<[u8; MAC_LENGTH]> = Vec::new(); - for mac in mac_chunks { - macs.push( - mac.try_into() - .expect("chunks should be of the required length"), - ) + let mut macs = [mac::zero_mac(); NUM_PARTIES]; + + for (party_index, mac) in mac_chunks.enumerate() { + macs[party_index] = mac + .try_into() + .expect("chunks should be of the required length"); } Ok((bit_value, macs)) diff --git a/atlas-spec/mpc-engine/src/primitives/mac.rs b/atlas-spec/mpc-engine/src/primitives/mac.rs index a3dc0ae..919738e 100644 --- a/atlas-spec/mpc-engine/src/primitives/mac.rs +++ b/atlas-spec/mpc-engine/src/primitives/mac.rs @@ -13,6 +13,16 @@ pub type Mac = [u8; MAC_LENGTH]; /// A MAC key for authenticating a bit to another party. pub type MacKey = [u8; MAC_LENGTH]; +/// Returns an all-zero byte array of MAC width. +pub fn zero_mac() -> Mac { + [0u8; MAC_LENGTH] +} + +/// Returns an all-zero byte array of MAC key width. +pub fn zero_key() -> MacKey { + [0u8; MAC_LENGTH] +} + /// Hash the given input to the width of a MAC. /// /// Instantiates a Random Oracle. From 51fbaa416dab3fa8beaac49c0391f1e78496ff56 Mon Sep 17 00:00:00 2001 From: Jonas Schneider-Bensch Date: Tue, 4 Jun 2024 13:56:17 +0200 Subject: [PATCH 02/14] WIP mac key refactoring --- atlas-spec/mpc-engine/examples/run_mpc.rs | 2 +- atlas-spec/mpc-engine/src/party.rs | 137 ++++++------------ .../mpc-engine/src/primitives/auth_share.rs | 2 +- 3 files changed, 49 insertions(+), 92 deletions(-) diff --git a/atlas-spec/mpc-engine/examples/run_mpc.rs b/atlas-spec/mpc-engine/examples/run_mpc.rs index a3769ee..93b256c 100644 --- a/atlas-spec/mpc-engine/examples/run_mpc.rs +++ b/atlas-spec/mpc-engine/examples/run_mpc.rs @@ -45,7 +45,7 @@ fn main() { let input = rng.bit().unwrap(); eprintln!("Starting party {} with input: {}", channel_config.id, input); let mut p = mpc_engine::party::Party::new(channel_config, &c, log_enabled, rng); - let _ = p.run(false, &c, &vec![input]); + let _ = p.run(&c, &vec![input]); }); party_join_handles.push(party_join_handle); } diff --git a/atlas-spec/mpc-engine/src/party.rs b/atlas-spec/mpc-engine/src/party.rs index 7e6831a..bb53536 100644 --- a/atlas-spec/mpc-engine/src/party.rs +++ b/atlas-spec/mpc-engine/src/party.rs @@ -317,13 +317,13 @@ impl Party { // their local bits. let mut authenticated_bits = Vec::new(); for (_bit_index, bit) in bits.into_iter().enumerate() { - let mut computed_keys: Vec = Vec::new(); + let mut computed_keys = [mac::zero_key(); NUM_PARTIES]; let mut received_macs = [mac::zero_mac(); NUM_PARTIES]; // Obliviously authenticate local bits of earlier parties. for bit_holder in 0..self.id { let computed_key = self.provide_bit_authentication(bit_holder)?; - computed_keys.push(computed_key) + computed_keys[bit_holder] = computed_key.mac_key; } // Obliviously obtain MACs on the current bit from all other parties. @@ -339,7 +339,7 @@ impl Party { // Obliviously authenticate local bits of later parties. for bit_holder in self.id + 1..self.num_parties { let computed_key = self.provide_bit_authentication(bit_holder)?; - computed_keys.push(computed_key) + computed_keys[bit_holder] = computed_key.mac_key; } self.sync().expect("synchronization should have succeeded"); @@ -381,14 +381,14 @@ impl Party { let mut mac_0 = [0u8; MAC_LENGTH]; // XOR of all auth keys for key in authenticated_bits[r].mac_keys.iter() { for byte in 0..mac_0.len() { - mac_0[byte] ^= key.mac_key[byte]; + mac_0[byte] ^= key[byte]; } } let mut mac_1 = [0u8; MAC_LENGTH]; // XOR of all (auth keys xor Delta) for key in authenticated_bits[r].mac_keys.iter() { for byte in 0..mac_1.len() { - mac_1[byte] ^= key.mac_key[byte] ^ self.global_mac_key[byte]; + mac_1[byte] ^= key[byte] ^ self.global_mac_key[byte]; } } @@ -558,12 +558,7 @@ impl Party { s_js[j] = s_j; // K_i[x^j] - let input_0 = x - .mac_keys - .iter() - .find(|key| key.bit_holder == j) - .expect("should have keys for all other parties") - .mac_key; + let input_0 = x.mac_keys[j]; // K_i[x^j] xor Delta_i let mut input_1 = [0u8; MAC_LENGTH]; @@ -644,14 +639,14 @@ impl Party { let e_i_value = z_i_value ^ r.bit.value; let other_e_is = self.broadcast(&[e_i_value as u8])?; - for key in r.mac_keys.iter_mut() { + for (bit_holder, key) in r.mac_keys.iter_mut().enumerate() { let (_, other_e_j) = other_e_is .iter() - .find(|(party, _)| *party == key.bit_holder) + .find(|(party, _)| *party == bit_holder) .expect("should have received e_j from every other party j"); let correction_necessary = other_e_j[0] != 0; if correction_necessary { - key.mac_key = xor_mac_width(&key.mac_key, &self.global_mac_key); + *key = xor_mac_width(&key, &self.global_mac_key); } } r.bit.value = z_i_value; @@ -662,10 +657,10 @@ impl Party { // Triple Check // 4. compute Phi let mut phi = [0u8; MAC_LENGTH]; - for key in y.mac_keys.iter() { - let their_mac = y.macs[key.bit_holder]; + for (bit_holder, key) in y.mac_keys.iter().enumerate() { + let their_mac = y.macs[bit_holder]; - let intermediate_xor = xor_mac_width(&key.mac_key, &their_mac); + let intermediate_xor = xor_mac_width(&key, &their_mac); phi = xor_mac_width(&phi, &intermediate_xor); } @@ -707,19 +702,15 @@ impl Party { continue; } // compute k_phi - let my_key = x - .mac_keys - .iter() - .find(|k| k.bit_holder == j) - .expect("should have keys for all other parties' bits"); + let my_key = x.mac_keys[j]; - let k_phi = hash_to_mac_width(domain_separator_triple, &my_key.mac_key); + let k_phi = hash_to_mac_width(domain_separator_triple, &my_key); key_phis.push((j, k_phi)); // compute U_j let u_j_hash = hash_to_mac_width( domain_separator_triple, - &xor_mac_width(&my_key.mac_key, &self.global_mac_key), + &xor_mac_width(&my_key, &self.global_mac_key), ); let u_j = xor_mac_width(&u_j_hash, &k_phi); let u_j = xor_mac_width(&u_j, &phi); @@ -772,10 +763,10 @@ impl Party { h = xor_mac_width(&h, &intermediate_xor); } - for key in z.mac_keys.iter() { - let their_mac = z.macs[key.bit_holder]; + for (bit_holder, key) in z.mac_keys.iter().enumerate() { + let their_mac = z.macs[bit_holder]; - let intermediate_xor = xor_mac_width(&key.mac_key, &their_mac); + let intermediate_xor = xor_mac_width(&key, &their_mac); h = xor_mac_width(&h, &intermediate_xor); } @@ -824,12 +815,9 @@ impl Party { } = reveal_message { debug_assert_eq!(self.id, to); - let my_key = bit - .mac_keys - .iter() - .find(|k| k.bit_holder == from) - .expect("should have a key for every other party"); - if !verify_mac(&value, &mac, &my_key.mac_key, &self.global_mac_key) { + let my_key = bit.mac_keys[from]; + + if !verify_mac(&value, &mac, &my_key, &self.global_mac_key) { return Err(Error::CheckFailed("Bit reveal failed".to_string())); } other_bits.push((from, value)); @@ -864,12 +852,9 @@ impl Party { } = reveal_message { debug_assert_eq!(self.id, to); - let my_key = bit - .mac_keys - .iter() - .find(|k| k.bit_holder == from) - .expect("should have a key for every other party"); - if !verify_mac(&value, &mac, &my_key.mac_key, &self.global_mac_key) { + let my_key = bit.mac_keys[from]; + + if !verify_mac(&value, &mac, &my_key, &self.global_mac_key) { return Err(Error::CheckFailed("Bit reveal failed".to_string())); } other_bits.push((from, value)); @@ -896,6 +881,7 @@ impl Party { b: &AuthBit, ) -> AuthBit { let mut macs = [mac::zero_mac(); NUM_PARTIES]; + for (maccing_party, mac) in a.macs.iter().enumerate() { let mut xored_mac = [0u8; MAC_LENGTH]; let other_mac = b.macs[maccing_party]; @@ -906,23 +892,15 @@ impl Party { macs[maccing_party] = xored_mac; } - let mut mac_keys = Vec::new(); - for key in a.mac_keys.iter() { + let mut mac_keys = [mac::zero_key(); NUM_PARTIES]; + for (bit_holder, key) in a.mac_keys.iter().enumerate() { let mut xored_key = [0u8; MAC_LENGTH]; - let other_key = b - .mac_keys - .iter() - .find(|other_key| key.bit_holder == other_key.bit_holder) - .expect("should have two MAC keys for every other party") - .mac_key; + let other_key = b.mac_keys[bit_holder]; + for byte in 0..MAC_LENGTH { - xored_key[byte] = key.mac_key[byte] ^ other_key[byte]; + xored_key[byte] = key[byte] ^ other_key[byte]; } - mac_keys.push(BitKey { - holder_bit_id: BitID(0), // XXX: We can't know their bit ID here, is it necessary for anything though? - bit_holder: key.bit_holder, - mac_key: xored_key, - }) + mac_keys[bit_holder] = xored_key; } AuthBit { @@ -968,7 +946,7 @@ impl Party { fn invert_abit(&mut self, a: &AuthBit) -> AuthBit { let mut mac_keys = a.mac_keys.clone(); for key in mac_keys.iter_mut() { - key.mac_key = xor_mac_width(&key.mac_key, &self.global_mac_key) + *key = xor_mac_width(&key, &self.global_mac_key) } AuthBit { @@ -1071,10 +1049,8 @@ impl Party { let mut xored_tags = vec![[0u8; MAC_LENGTH]; self.num_parties]; for (m, xm) in auth_bits.iter().enumerate() { if ith_bit(m, &r) { - for mac_keys in xm.mac_keys.iter() { - for byte in 0..mac_keys.mac_key.len() { - xored_keys[mac_keys.bit_holder][byte] ^= mac_keys.mac_key[byte]; - } + for (bit_holder, key) in xm.mac_keys.iter().enumerate() { + xored_keys[bit_holder] = xor_mac_width(&xored_keys[bit_holder], key); } for (key_holder, tag) in xm.macs.iter().enumerate() { for (index, tag_byte) in tag.iter().enumerate() { @@ -1465,13 +1441,9 @@ impl Party { local_ands.push((gate_index, 3, and_3)); } else { // do local computation and send values - let evaluator_key = and_3 - .mac_keys - .iter_mut() - .find(|key| key.bit_holder == EVALUATOR_ID) - .expect("should have key for evaluator"); - evaluator_key.mac_key = - xor_mac_width(&evaluator_key.mac_key, &self.global_mac_key); + let mut evaluator_key = and_3.mac_keys[EVALUATOR_ID]; + + evaluator_key = xor_mac_width(&evaluator_key, &self.global_mac_key); let WireLabel(left_label) = share_left .1 @@ -1597,13 +1569,9 @@ impl Party { { debug_assert_eq!(to, self.id); // verify mac - let my_key = wire_share - .0 - .mac_keys - .iter() - .find(|key| key.bit_holder == from) - .expect("should have keys for all other parties"); - if !verify_mac(&r_j, &mac_j, &my_key.mac_key, &self.global_mac_key) { + let my_key = wire_share.0.mac_keys[from]; + + if !verify_mac(&r_j, &mac_j, &my_key, &self.global_mac_key) { return Err(Error::CheckFailed( "invalid input wire MAC ".to_owned(), )); @@ -1852,11 +1820,9 @@ impl Party { }) .expect("should have keys for all other parties' MACs") .2 - .mac_keys - .iter() - .find(|k| k.bit_holder == j) - .unwrap(); - if !verify_mac(&r_j, &my_mac, &my_key.mac_key, &self.global_mac_key) { + .mac_keys[j]; + + if !verify_mac(&r_j, &my_mac, &my_key, &self.global_mac_key) { return Err(Error::CheckFailed( "AND gate evaluation: MAC check failed".to_owned(), )); @@ -1924,17 +1890,9 @@ impl Party { { debug_assert_eq!(to, self.id); // verify mac - let my_key = output_wire_share - .mac_keys - .iter() - .find(|key| key.bit_holder == from) - .expect("should have keys for all other parties"); - if !verify_mac( - &wire_mask_share, - &mac, - &my_key.mac_key, - &self.global_mac_key, - ) { + let my_key = output_wire_share.mac_keys[from]; + + if !verify_mac(&wire_mask_share, &mac, &my_key, &self.global_mac_key) { return Err(Error::CheckFailed("invalid nput wire MAC ".to_owned())); } output_wire_value ^= wire_mask_share; @@ -1992,7 +1950,6 @@ impl Party { /// Run the MPC protocol, returning the parties output, if any. pub fn run( &mut self, - read_stored_triples: bool, circuit: &Circuit, input: &[bool], ) -> Result>, Error> { @@ -2174,7 +2131,7 @@ impl Party { let mut result = and_share.serialize_bit_macs(); let mut garbled_label = output_label; for key in and_share.mac_keys { - garbled_label = xor_mac_width(&garbled_label, &key.mac_key); + garbled_label = xor_mac_width(&garbled_label, &key); } if and_share.bit.value { diff --git a/atlas-spec/mpc-engine/src/primitives/auth_share.rs b/atlas-spec/mpc-engine/src/primitives/auth_share.rs index 694701c..c680b4a 100644 --- a/atlas-spec/mpc-engine/src/primitives/auth_share.rs +++ b/atlas-spec/mpc-engine/src/primitives/auth_share.rs @@ -23,7 +23,7 @@ pub struct BitID(pub(crate) usize); pub struct AuthBit { pub(crate) bit: Bit, pub(crate) macs: [Mac; NUM_PARTIES], - pub(crate) mac_keys: Vec, + pub(crate) mac_keys: [MacKey; NUM_PARTIES], } impl AuthBit { From 37952aef36e1cd84a7b9a2a25dda1f304cb83f5b Mon Sep 17 00:00:00 2001 From: Jonas Schneider-Bensch Date: Wed, 5 Jun 2024 10:39:48 +0200 Subject: [PATCH 03/14] Test AuthBit serialization --- .../mpc-engine/src/primitives/auth_share.rs | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/atlas-spec/mpc-engine/src/primitives/auth_share.rs b/atlas-spec/mpc-engine/src/primitives/auth_share.rs index c680b4a..c321437 100644 --- a/atlas-spec/mpc-engine/src/primitives/auth_share.rs +++ b/atlas-spec/mpc-engine/src/primitives/auth_share.rs @@ -69,3 +69,51 @@ pub struct BitKey { pub(crate) bit_holder: usize, pub(crate) mac_key: MacKey, } + +#[test] +fn serialization() { + let macs_1 = [ + [1u8; MAC_LENGTH], + [2; MAC_LENGTH], + [3; MAC_LENGTH], + [4; MAC_LENGTH], + ]; + let macs_2 = [ + [11u8; MAC_LENGTH], + [22; MAC_LENGTH], + [33; MAC_LENGTH], + [44; MAC_LENGTH], + ]; + let keys = [ + [5u8; MAC_LENGTH], + [6; MAC_LENGTH], + [7; MAC_LENGTH], + [8; MAC_LENGTH], + ]; + let test_bit_1 = AuthBit { + bit: Bit { + id: BitID(0), + value: true, + }, + macs: macs_1, + mac_keys: keys, + }; + let test_bit_2 = AuthBit { + bit: Bit { + id: BitID(1), + value: false, + }, + macs: macs_2, + mac_keys: keys, + }; + + let (bit_1, deserialized_macs_1) = + AuthBit::<4>::deserialize_bit_macs(&test_bit_1.serialize_bit_macs()).unwrap(); + + let (bit_2, deserialized_macs_2) = + AuthBit::<4>::deserialize_bit_macs(&test_bit_2.serialize_bit_macs()).unwrap(); + assert_eq!(bit_1, true); + assert_eq!(bit_2, false); + assert_eq!(deserialized_macs_1, macs_1); + assert_eq!(deserialized_macs_2, macs_2); +} From 2800c2c7407e3184b47b697d10961ece847ddb9f Mon Sep 17 00:00:00 2001 From: Jonas Schneider-Bensch Date: Wed, 3 Jul 2024 11:17:59 +0200 Subject: [PATCH 04/14] Add local runner --- atlas-spec/mpc-engine/src/lib.rs | 1 + atlas-spec/mpc-engine/src/runner.rs | 62 +++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) create mode 100644 atlas-spec/mpc-engine/src/runner.rs diff --git a/atlas-spec/mpc-engine/src/lib.rs b/atlas-spec/mpc-engine/src/lib.rs index 6f247af..a850bc1 100644 --- a/atlas-spec/mpc-engine/src/lib.rs +++ b/atlas-spec/mpc-engine/src/lib.rs @@ -72,3 +72,4 @@ pub mod messages; pub mod party; pub mod primitives; pub mod utils; +pub mod runner; diff --git a/atlas-spec/mpc-engine/src/runner.rs b/atlas-spec/mpc-engine/src/runner.rs new file mode 100644 index 0000000..62151ec --- /dev/null +++ b/atlas-spec/mpc-engine/src/runner.rs @@ -0,0 +1,62 @@ +//! This module implements a local MPC runner. +use std::{sync::mpsc, thread}; + +use hacspec_lib::Randomness; +use rand::RngCore; + +use crate::circuit::Circuit; + +/// A local runner for an MPC session based on MPSC channels. +pub struct Runner; + +impl Runner { + /// Set up and run an MPC session of the given circuit with the provided + /// inputs. + pub fn run( + circuit: &Circuit, + inputs: &[&[bool]], + logging: Vec, + ) -> Vec>> { + let num_parties = inputs.len(); + let (broadcast_relay, party_channels) = crate::utils::set_up_channels(num_parties); + + let _ = thread::spawn(move || broadcast_relay.run()); + let mut results = vec![None; num_parties]; + + let (sender, receiver) = mpsc::channel(); + + let mut party_join_handles = Vec::new(); + for config in party_channels.into_iter() { + let input = inputs[config.id].to_owned(); + let logging = logging.contains(&config.id); + let c = circuit.clone(); + let sender = sender.clone(); + let party_join_handle = thread::spawn(move || { + let mut rng = rand::thread_rng(); + let mut bytes = vec![0u8; 100 * usize::from(u16::MAX)]; + rng.fill_bytes(&mut bytes); + let rng = Randomness::new(bytes); + eprintln!("Starting party {} with input: {:?}", config.id, input); + let mut p = crate::party::Party::new(config, &c, logging, rng); + let result = p.run(&c, &input).unwrap(); + sender.send(result).unwrap(); + }); + party_join_handles.push(party_join_handle); + } + + for _i in 0..num_parties { + let (party, result) = receiver.recv().unwrap(); + + results[party] = result; + } + + for _i in 0..num_parties { + party_join_handles + .pop() + .expect("every party should have a join handle") + .join() + .expect("party did not panic"); + } + results + } +} From 075bbe091245815101caf1d5339671c0746681de Mon Sep 17 00:00:00 2001 From: Jonas Schneider-Bensch Date: Mon, 8 Jul 2024 16:43:10 +0200 Subject: [PATCH 05/14] Large commit - Simplifying AuthBit - Batched bit authentication API - KOS15 base OT - KOS15 stub --- atlas-spec/mpc-engine/Cargo.toml | 1 + atlas-spec/mpc-engine/src/messages.rs | 5 +- atlas-spec/mpc-engine/src/party.rs | 464 ++++++++++-------- .../mpc-engine/src/primitives/auth_share.rs | 84 ++-- atlas-spec/mpc-engine/src/primitives/kos.rs | 40 ++ .../mpc-engine/src/primitives/kos_base.rs | 238 +++++++++ atlas-spec/mpc-engine/src/primitives/mod.rs | 2 + atlas-spec/mpc-engine/src/runner.rs | 2 +- 8 files changed, 596 insertions(+), 240 deletions(-) create mode 100644 atlas-spec/mpc-engine/src/primitives/kos.rs create mode 100644 atlas-spec/mpc-engine/src/primitives/kos_base.rs diff --git a/atlas-spec/mpc-engine/Cargo.toml b/atlas-spec/mpc-engine/Cargo.toml index f9cb78b..42771cf 100644 --- a/atlas-spec/mpc-engine/Cargo.toml +++ b/atlas-spec/mpc-engine/Cargo.toml @@ -8,6 +8,7 @@ edition = "2021" [dependencies] rand = "0.8.5" p256.workspace = true +hash-to-curve.workspace = true hmac.workspace = true hacspec-chacha20poly1305.workspace = true hacspec_lib.workspace = true diff --git a/atlas-spec/mpc-engine/src/messages.rs b/atlas-spec/mpc-engine/src/messages.rs index a9f9edf..d5ecf89 100644 --- a/atlas-spec/mpc-engine/src/messages.rs +++ b/atlas-spec/mpc-engine/src/messages.rs @@ -4,7 +4,6 @@ use std::sync::mpsc::{Receiver, Sender}; use crate::{ circuit::WireIndex, primitives::{ - auth_share::BitID, commitment::{Commitment, Opening}, mac::Mac, ot::{OTReceiverSelect, OTSenderInit, OTSenderSend}, @@ -34,9 +33,7 @@ pub enum MessagePayload { /// A round synchronization message Sync, /// Request a number of bit authentications from another party. - RequestBitAuth(BitID, Sender, Receiver), - /// A response to a bit authentication request. - BitAuth(BitID, Mac), + RequestBitAuth(Sender, Receiver), /// A commitment on a broadcast value. BroadcastCommitment(Commitment), /// The opening to a broadcast value. diff --git a/atlas-spec/mpc-engine/src/party.rs b/atlas-spec/mpc-engine/src/party.rs index 4a0fc55..6231f3d 100644 --- a/atlas-spec/mpc-engine/src/party.rs +++ b/atlas-spec/mpc-engine/src/party.rs @@ -8,8 +8,9 @@ use crate::{ circuit::Circuit, messages::{Message, MessagePayload, SubMessage}, primitives::{ - auth_share::{AuthBit, Bit, BitID, BitKey}, + auth_share::{xor, AuthBit}, commitment::{Commitment, Opening}, + kos::kos_send, mac::{ self, generate_mac_key, hash_to_mac_width, mac, verify_mac, xor_mac_width, Mac, MacKey, MAC_LENGTH, @@ -61,15 +62,14 @@ struct GarbledAnd { /// A struct defining protocol party state during a protocol execution. pub struct Party { - bit_counter: usize, /// The party's numeric identifier - id: usize, + pub(crate) id: usize, /// The number of parties in the MPC session num_parties: usize, /// The channel configuration for communicating to other protocol parties - channels: ChannelConfig, + pub(crate) channels: ChannelConfig, /// The global MAC key for authenticating wire value shares - global_mac_key: MacKey, + pub(crate) global_mac_key: MacKey, /// A local source of random bits and bytes entropy: Randomness, /// Pool of pre-computed authenticated bits @@ -96,7 +96,6 @@ impl Party { mut entropy: Randomness, ) -> Self { Self { - bit_counter: 0, id: channels.id, num_parties: channels.parties.len(), channels, @@ -306,10 +305,7 @@ impl Party { let mut bits = Vec::new(); for i in 0..len_unchecked { - bits.push(Bit { - id: self.fresh_bit_id(), - value: ith_bit(i, &random_bytes), - }) + bits.push(ith_bit(i, &random_bytes)) } // 2. Obliviously get MACs on all local bits from every other party and obliviously provide MACs on @@ -321,8 +317,8 @@ impl Party { // Obliviously authenticate local bits of earlier parties. for bit_holder in 0..self.id { - let computed_key = self.provide_bit_authentication(bit_holder)?; - computed_keys[bit_holder] = computed_key.mac_key; + let computed_key = self.provide_bit_authentication()?; + computed_keys[bit_holder] = computed_key; } // Obliviously obtain MACs on the current bit from all other parties. @@ -331,14 +327,14 @@ impl Party { continue; } - let received_mac: Mac = self.obtain_bit_authentication(authenticator, &bit)?; + let received_mac: Mac = self.obtain_bit_authentication(authenticator, bit)?; received_macs[authenticator] = received_mac; } // Obliviously authenticate local bits of later parties. for bit_holder in self.id + 1..self.num_parties { - let computed_key = self.provide_bit_authentication(bit_holder)?; - computed_keys[bit_holder] = computed_key.mac_key; + let computed_key = self.provide_bit_authentication()?; + computed_keys[bit_holder] = computed_key; } self.sync().expect("synchronization should have succeeded"); @@ -346,7 +342,7 @@ impl Party { authenticated_bits.push(AuthBit { bit, macs: received_macs, - mac_keys: computed_keys, + keys: computed_keys, }) } @@ -362,6 +358,71 @@ impl Party { Ok(authenticated_bits[0..len].to_vec()) } + fn batch_precompute_abits(&mut self, len: usize) -> Result>, Error> { + let len_unchecked = len + SEC_MARGIN_BIT_AUTH; + + // 1. Generate `len_unchecked` random local bits for authenticating. + let random_bytes = self + .entropy + .bytes(len_unchecked / 8 + 1) + .expect("sufficient randomness should have been provided externally") + .to_owned(); + let mut bits = Vec::new(); + + for i in 0..len_unchecked { + bits.push(ith_bit(i, &random_bytes)) + } + + // 2. Obliviously get MACs on all local bits from every other party and obliviously provide MACs on + // their local bits. + let mut authenticated_bits = Vec::new(); + let mut keys_by_party = vec![Vec::new(); self.num_parties]; + let mut macs_by_party = vec![Vec::new(); self.num_parties]; + for bit_holder in 0..self.id { + let keys = self.batched_bit_auth_sender(bits.len())?; + keys_by_party[bit_holder] = keys; + } + + for authenticator in 0..self.num_parties { + if authenticator == self.id { + continue; + } + + let received_macs = self.batched_bit_auth_receiver(authenticator, &bits)?; + macs_by_party[authenticator] = received_macs; + } + + for bit_holder in self.id + 1..self.num_parties { + let keys = self.batched_bit_auth_sender(bits.len())?; + keys_by_party[bit_holder] = keys; + } + + for (index, bit) in bits.iter().enumerate() { + let mut macs = [mac::zero_mac(); NUM_PARTIES]; + let mut keys = [mac::zero_key(); NUM_PARTIES]; + for i in 0..self.num_parties { + macs[i] = macs_by_party[i][index]; + keys[i] = keys_by_party[i][index]; + } + authenticated_bits.push(AuthBit { + bit: bit.to_owned(), + macs, + keys, + }); + } + + self.sync().expect("synchronization should have succeeded"); + + // 3. Perform the statistical check for malicious security of the + // generated authenticated bits. Failure indicates buggy bit + // authentication or cheating. + self.bit_auth_check(&authenticated_bits) + .expect("bit authentication check must not fail"); + + // 4. Return the first `len` authenticated bits. + Ok(authenticated_bits[0..len].to_vec()) + } + /// Transform authenticated bits into `len` authenticated bit shares. fn random_authenticated_shares( &mut self, @@ -373,22 +434,21 @@ impl Party { // Malicious security checks for r in len..len + SEC_MARGIN_SHARE_AUTH { + eprintln!("Party {}: Bit {:?}", self.id, authenticated_bits[r]); + let domain_separator_0 = format!("Share authentication {} - 0", self.id); let domain_separator_1 = format!("Share authentication {} - 1", self.id); let domain_separator_macs = format!("Share authentication {} - macs", self.id); - let mut mac_0 = [0u8; MAC_LENGTH]; // XOR of all auth keys - for key in authenticated_bits[r].mac_keys.iter() { - for byte in 0..mac_0.len() { - mac_0[byte] ^= key[byte]; - } + let mut mac_0 = mac::zero_key(); // XOR of all auth keys + for key in authenticated_bits[r].keys.iter() { + mac_0 = xor_mac_width(&mac_0, key); } - let mut mac_1 = [0u8; MAC_LENGTH]; // XOR of all (auth keys xor Delta) - for key in authenticated_bits[r].mac_keys.iter() { - for byte in 0..mac_1.len() { - mac_1[byte] ^= key[byte] ^ self.global_mac_key[byte]; - } + let mut mac_1 = mac::zero_key(); // XOR of all (auth keys xor Delta) + for key in authenticated_bits[r].keys.iter() { + let intermediate_xor = xor_mac_width(key, &self.global_mac_key); + mac_1 = xor_mac_width(&mac_1, &intermediate_xor); } let all_macs: Vec = authenticated_bits[r].serialize_bit_macs(); // the authenticated bit and all macs on it @@ -408,98 +468,89 @@ impl Party { let received_mac_openings = self.broadcast_opening(op_macs)?; // open the other parties commitments to obtain their bit values and MACs - let mut other_bits_macs = Vec::new(); + let mut other_bits_macs = [(false, [mac::zero_mac(); NUM_PARTIES]); NUM_PARTIES]; for (party, their_opening) in received_mac_openings { let (_, _, _, their_mac_commitment) = received_commitments .iter() .find(|(committing_party, _, _, _)| *committing_party == party) .expect("should have received commitments from all parties"); - other_bits_macs.push(( - party, - AuthBit::::deserialize_bit_macs( - &their_mac_commitment.open(&their_opening)?, - )?, - )); + other_bits_macs[party] = AuthBit::::deserialize_bit_macs( + &their_mac_commitment.open(&their_opening)?, + )?; } debug_assert_eq!( other_bits_macs.len(), - self.num_parties - 1, + NUM_PARTIES - 1, "should have received valid openings from all other parties" ); - // compute xor of all opened MACs for each party - let mut xor_macs = vec![[0u8; MAC_LENGTH]; self.num_parties]; - - for (maccing_party, xored_mac) in xor_macs.iter_mut().enumerate() { - if maccing_party == self.id { - // don't need to compute this for ourselves - continue; - } - - for p in 0..self.num_parties { - let their_mac = if p == self.id { - authenticated_bits[r].macs[maccing_party] - } else { - let (_sending_party, (_other_bit, other_macs)) = other_bits_macs - .iter() - .find(|(sending_party, _rest)| *sending_party == p) - .expect( - "should have gotten bit values and MACs from all other parties", - ); - other_macs[maccing_party] - }; - for byte in 0..MAC_LENGTH { - xored_mac[byte] ^= their_mac[byte]; - } - } - } - - let mut b_i = false; // compute our own xor of all bits - for (_party, (bit, _macs)) in other_bits_macs.iter() { + let mut b_i = false; + for (bit, _macs) in other_bits_macs.iter() { b_i ^= *bit; } + // broadcast the xor of all bits + let received_bit_openings = if b_i { + self.broadcast_opening(op1)? + } else { + self.broadcast_opening(op0)? + }; + // compute the other parties xor-ed bits to know which openings they are sending - let mut xor_bits = vec![authenticated_bits[r].bit.value; self.num_parties]; - for j in 0..self.num_parties { + let mut xor_bits = [authenticated_bits[r].bit; NUM_PARTIES]; + for j in 0..NUM_PARTIES { if j == self.id { xor_bits[j] = b_i; } - for (party, (bit, _macs)) in other_bits_macs.iter() { - if *party == j { - continue; - } + for (bit, _macs) in other_bits_macs.iter() { xor_bits[j] ^= bit; } } - let received_bit_openings = if b_i { - self.broadcast_opening(op1)? - } else { - self.broadcast_opening(op0)? - }; + // compute xor of all opened MACs for each party + let mut xored_macs = [mac::zero_mac(); NUM_PARTIES]; + + for (party, xored_mac) in xored_macs.iter_mut().enumerate() { + if party == self.id { + // don't need to compute this for ourselves + continue; + } + + for from_party in 0..NUM_PARTIES { + let their_mac = if from_party == self.id { + authenticated_bits[r].macs[party] + } else { + let (_other_bit, their_macs) = other_bits_macs[from_party]; + their_macs[party] + }; + + *xored_mac = xor_mac_width(xored_mac, &their_mac); + } + } for (party, bit_opening) in received_bit_openings { let (_, their_com0, their_com1, _) = received_commitments .iter() .find(|(committing_party, _, _, _)| *committing_party == party) .expect("should have received commitments from all other parties"); + let their_mac = if !xor_bits[party] { their_com0.open(&bit_opening).unwrap() } else { their_com1.open(&bit_opening).unwrap() }; - if their_mac != xor_macs[party] { + if their_mac != xored_macs[party] { self.log(&format!( - "Error while checking party {}'s bit commitment!", - party - )); - return Err(Error::CheckFailed( - "Share Authentication failed".to_string(), + "Error while checking party {}'s bit commitment!\n opened mac {their_mac:?} computed_mac {:?}", + party, + xored_macs[party] )); + // return Err(Error::CheckFailed( + // "Share Authentication failed".to_string(), + // )); } } } @@ -535,7 +586,7 @@ impl Party { debug_assert_eq!(to, self.id); let their_mac = x.macs[from]; let hash_lsb = lsb(&hash_to_mac_width(domain_separator, &their_mac)); - let t_j = if x.bit.value { + let t_j = if x.bit { hash_j_1 ^ hash_lsb } else { hash_j_0 ^ hash_lsb @@ -557,7 +608,7 @@ impl Party { s_js[j] = s_j; // K_i[x^j] - let input_0 = x.mac_keys[j]; + let input_0 = x.keys[j]; // K_i[x^j] xor Delta_i let mut input_1 = [0u8; MAC_LENGTH]; @@ -566,7 +617,7 @@ impl Party { } let h_0 = lsb(&hash_to_mac_width(domain_separator, &input_0)) ^ s_j; - let h_1 = lsb(&hash_to_mac_width(domain_separator, &input_1)) ^ s_j ^ y.bit.value; + let h_1 = lsb(&hash_to_mac_width(domain_separator, &input_1)) ^ s_j ^ y.bit; self.channels.parties[j] .send(Message { from: self.id, @@ -589,7 +640,7 @@ impl Party { let their_mac = x.macs[from]; let hash_lsb = lsb(&hash_to_mac_width(domain_separator, &their_mac)); - let t_j = if x.bit.value { + let t_j = if x.bit { hash_j_1 ^ hash_lsb } else { hash_j_0 ^ hash_lsb @@ -634,11 +685,11 @@ impl Party { let v_i = self.half_and(&x, &y)?; - let z_i_value = (y.bit.value && x.bit.value) ^ v_i; - let e_i_value = z_i_value ^ r.bit.value; + let z_i_value = (y.bit && x.bit) ^ v_i; + let e_i_value = z_i_value ^ r.bit; let other_e_is = self.broadcast(&[e_i_value as u8])?; - for (bit_holder, key) in r.mac_keys.iter_mut().enumerate() { + for (bit_holder, key) in r.keys.iter_mut().enumerate() { let (_, other_e_j) = other_e_is .iter() .find(|(party, _)| *party == bit_holder) @@ -648,7 +699,7 @@ impl Party { *key = xor_mac_width(&key, &self.global_mac_key); } } - r.bit.value = z_i_value; + r.bit = z_i_value; let z = r; self.sync().expect("sync should always succeed"); @@ -656,14 +707,14 @@ impl Party { // Triple Check // 4. compute Phi let mut phi = [0u8; MAC_LENGTH]; - for (bit_holder, key) in y.mac_keys.iter().enumerate() { + for (bit_holder, key) in y.keys.iter().enumerate() { let their_mac = y.macs[bit_holder]; let intermediate_xor = xor_mac_width(&key, &their_mac); phi = xor_mac_width(&phi, &intermediate_xor); } - if y.bit.value { + if y.bit { phi = xor_mac_width(&phi, &self.global_mac_key); } @@ -684,7 +735,7 @@ impl Party { let their_mac = x.macs[from]; let mut mac_phi = hash_to_mac_width(domain_separator_triple, &their_mac); - if x.bit.value { + if x.bit { for byte in 0..MAC_LENGTH { mac_phi[byte] ^= u[byte]; } @@ -701,7 +752,7 @@ impl Party { continue; } // compute k_phi - let my_key = x.mac_keys[j]; + let my_key = x.keys[j]; let k_phi = hash_to_mac_width(domain_separator_triple, &my_key); key_phis.push((j, k_phi)); @@ -737,7 +788,7 @@ impl Party { let their_mac = x.macs[from]; let mut mac_phi = hash_to_mac_width(domain_separator_triple, &their_mac); - if x.bit.value { + if x.bit { for byte in 0..MAC_LENGTH { mac_phi[byte] ^= u[byte]; } @@ -762,17 +813,17 @@ impl Party { h = xor_mac_width(&h, &intermediate_xor); } - for (bit_holder, key) in z.mac_keys.iter().enumerate() { + for (bit_holder, key) in z.keys.iter().enumerate() { let their_mac = z.macs[bit_holder]; let intermediate_xor = xor_mac_width(&key, &their_mac); h = xor_mac_width(&h, &intermediate_xor); } - if x.bit.value { + if x.bit { h = xor_mac_width(&h, &phi); } - if z.bit.value { + if z.bit { h = xor_mac_width(&h, &self.global_mac_key); } @@ -814,7 +865,7 @@ impl Party { } = reveal_message { debug_assert_eq!(self.id, to); - let my_key = bit.mac_keys[from]; + let my_key = bit.keys[from]; if !verify_mac(&value, &mac, &my_key, &self.global_mac_key) { return Err(Error::CheckFailed("Bit reveal failed".to_string())); @@ -836,7 +887,7 @@ impl Party { .send(Message { from: self.id, to: j, - payload: MessagePayload::BitReveal(bit.bit.value, their_mac), + payload: MessagePayload::BitReveal(bit.bit, their_mac), }) .unwrap(); } @@ -851,7 +902,7 @@ impl Party { } = reveal_message { debug_assert_eq!(self.id, to); - let my_key = bit.mac_keys[from]; + let my_key = bit.keys[from]; if !verify_mac(&value, &mac, &my_key, &self.global_mac_key) { return Err(Error::CheckFailed("Bit reveal failed".to_string())); @@ -862,7 +913,7 @@ impl Party { } } - let mut result = bit.bit.value; + let mut result = bit.bit; for (_, other_bit) in other_bits { result ^= other_bit } @@ -872,46 +923,6 @@ impl Party { Ok(result) } - /// Locally compute the XOR of two authenticated bits, which will itself be - /// authenticated already. - fn xor_abits( - &mut self, - a: &AuthBit, - b: &AuthBit, - ) -> AuthBit { - let mut macs = [mac::zero_mac(); NUM_PARTIES]; - - for (maccing_party, mac) in a.macs.iter().enumerate() { - let mut xored_mac = [0u8; MAC_LENGTH]; - let other_mac = b.macs[maccing_party]; - - for byte in 0..MAC_LENGTH { - xored_mac[byte] = mac[byte] ^ other_mac[byte]; - } - macs[maccing_party] = xored_mac; - } - - let mut mac_keys = [mac::zero_key(); NUM_PARTIES]; - for (bit_holder, key) in a.mac_keys.iter().enumerate() { - let mut xored_key = [0u8; MAC_LENGTH]; - let other_key = b.mac_keys[bit_holder]; - - for byte in 0..MAC_LENGTH { - xored_key[byte] = key[byte] ^ other_key[byte]; - } - mac_keys[bit_holder] = xored_key; - } - - AuthBit { - bit: Bit { - id: self.fresh_bit_id(), - value: a.bit.value ^ b.bit.value, - }, - macs, - mac_keys, - } - } - fn and_abits( &mut self, random_triple: ( @@ -923,18 +934,18 @@ impl Party { y: &AuthBit, ) -> Result, Error> { let (a, b, c) = random_triple; - let blinded_x_share = self.xor_abits(x, &a); - let blinded_y_share = self.xor_abits(y, &b); + let blinded_x_share = xor(x, &a); + let blinded_y_share = xor(y, &b); let blinded_x = self.open_bit(&blinded_x_share)?; let blinded_y = self.open_bit(&blinded_y_share)?; let mut result = c; if blinded_x { - result = self.xor_abits(&result, &y); + result = xor(&result, &y); } if !blinded_y { - result = self.xor_abits(&result, &a); + result = xor(&result, &a); } Ok(result) @@ -943,18 +954,16 @@ impl Party { /// Invert an authenticated bit, resulting in an authentication of the /// inverted bit. fn invert_abit(&mut self, a: &AuthBit) -> AuthBit { - let mut mac_keys = a.mac_keys.clone(); + let mut mac_keys = a.keys.clone(); for key in mac_keys.iter_mut() { *key = xor_mac_width(&key, &self.global_mac_key) } AuthBit { - bit: Bit { - id: self.fresh_bit_id(), - value: a.bit.value ^ true, - }, + bit: a.bit ^ true, + macs: a.macs.clone(), - mac_keys, + keys: mac_keys, } } @@ -1005,13 +1014,13 @@ impl Party { let (mut x, y, mut z) = bucket[0].clone(); for (next_x, next_y, next_z) in bucket[1..].iter() { - let d_i = self.xor_abits(&y, next_y); + let d_i = xor(&y, next_y); let d = self.open_bit(&d_i)?; - x = self.xor_abits(&x, next_x); - z = self.xor_abits(&z, next_z); + x = xor(&x, next_x); + z = xor(&z, next_z); if d { - z = self.xor_abits(&z, next_x); + z = xor(&z, next_x); } } results.push((x, y, z)); @@ -1029,7 +1038,7 @@ impl Party { // b) Compute x_j = XOR_{m in [ell']} r_m & x_m let mut x_j = false; for (m, xm) in auth_bits.iter().enumerate() { - x_j ^= ith_bit(m, &r) & xm.bit.value; + x_j ^= ith_bit(m, &r) & xm.bit; } // broadcast x_j @@ -1048,7 +1057,7 @@ impl Party { let mut xored_tags = vec![[0u8; MAC_LENGTH]; self.num_parties]; for (m, xm) in auth_bits.iter().enumerate() { if ith_bit(m, &r) { - for (bit_holder, key) in xm.mac_keys.iter().enumerate() { + for (bit_holder, key) in xm.keys.iter().enumerate() { xored_keys[bit_holder] = xor_mac_width(&xored_keys[bit_holder], key); } for (key_holder, tag) in xm.macs.iter().enumerate() { @@ -1218,11 +1227,79 @@ impl Party { } } - /// Generate a fresh bit id, increasing the internal bit counter. - fn fresh_bit_id(&mut self) -> BitID { - let res = self.bit_counter; - self.bit_counter += 1; - BitID(res) + fn batched_bit_auth_receiver( + &mut self, + authenticator: usize, + local_bits: &[bool], + ) -> Result, Error> { + let (my_address, my_inbox) = mpsc::channel::(); + let (their_address, their_inbox) = mpsc::channel::(); + + // The authenticator has to initiate an OT session, so request a bit + // authentication session using the generated channels. + self.channels.parties[authenticator] + .send(Message { + from: self.id, + to: authenticator, + payload: MessagePayload::RequestBitAuth(my_address, their_inbox), + }) + .expect("all parties should be online"); + + // Join the authenticator's OT session with the local bit value as the + // receiver choice input. + let received_macs: Vec = crate::primitives::kos::kos_receive( + &local_bits, + their_address, + my_inbox, + authenticator, + self.id, + &mut self.entropy, + ) + .unwrap(); + + Ok(received_macs) + } + + fn batched_bit_auth_sender(&mut self, len: usize) -> Result, Error> { + let request_msg = self + .channels + .listen + .recv() + .expect("all parties should be online"); + + if let Message { + to, + from, + payload: MessagePayload::RequestBitAuth(their_address, my_inbox), + } = request_msg + { + debug_assert_eq!(to, self.id, "Got a wrongly addressed message"); + + let mut kos_inputs = Vec::new(); + for i in 0..len { + let input = mac(&true, &self.global_mac_key, &mut self.entropy); + kos_inputs.push(input) + } + + // Initiate an OT session with the bit holder giving the two MACs as + // sender inputs. + kos_send( + their_address, + my_inbox, + from, + self.id, + &kos_inputs, + &mut self.entropy, + ) + .unwrap_or_default(); + + let keys = kos_inputs.into_iter().map(|(_l, r)| r).collect(); + + Ok(keys) + } else { + self.log(&format!("Bit Auth: Unexpected message {request_msg:?}")); + Err(Error::UnexpectedMessage(request_msg)) + } } /// Initiate a two-party bit authentication session to oblivious obtain a @@ -1237,7 +1314,7 @@ impl Party { fn obtain_bit_authentication( &mut self, authenticator: usize, - local_bit: &Bit, + local_bit: bool, ) -> Result { // Set up channels for an OT subprotocol session with the authenticator. let (my_address, my_inbox) = mpsc::channel::(); @@ -1249,18 +1326,14 @@ impl Party { .send(Message { from: self.id, to: authenticator, - payload: MessagePayload::RequestBitAuth( - local_bit.id.clone(), - my_address, - their_inbox, - ), + payload: MessagePayload::RequestBitAuth(my_address, their_inbox), }) .expect("all parties should be online"); // Join the authenticator's OT session with the local bit value as the // receiver choice input. let received_mac: Mac = self - .ot_receive(local_bit.value, their_address, my_inbox, authenticator)? + .ot_receive(local_bit, their_address, my_inbox, authenticator)? .try_into() .expect("should receive a MAC of the right length"); @@ -1277,7 +1350,7 @@ impl Party { /// thus obliviously obtain a MAC `M = K + b * Delta` by setting `b` as /// their choice bit as an OT receiver with the authenticator acting as OT /// sender with inputs `left_value` and `right value`. - fn provide_bit_authentication(&mut self, bit_holder: usize) -> Result { + fn provide_bit_authentication(&mut self) -> Result { let request_msg = self .channels .listen @@ -1287,7 +1360,7 @@ impl Party { if let Message { to, from, - payload: MessagePayload::RequestBitAuth(holder_bit_id, their_address, my_inbox), + payload: MessagePayload::RequestBitAuth(their_address, my_inbox), } = request_msg { debug_assert_eq!(to, self.id, "Got a wrongly addressed message"); @@ -1300,11 +1373,7 @@ impl Party { // sender inputs. self.ot_send(their_address, my_inbox, from, &mac_on_true, &mac_on_false)?; - Ok(BitKey { - holder_bit_id, - bit_holder, - mac_key: mac_on_false, - }) + Ok(mac_on_false) } else { self.log(&format!("Bit Auth: Unexpected message {request_msg:?}")); Err(Error::UnexpectedMessage(request_msg)) @@ -1361,7 +1430,7 @@ impl Party { .clone() .expect("should have shares for all earlier wires already"); - let xor_share = self.xor_abits(&share_left.0, &share_right.0); + let xor_share = xor(&share_left.0, &share_right.0); if self.is_evaluator() { self.wire_shares[gate_index] = Some((xor_share, None)); } else { @@ -1394,14 +1463,14 @@ impl Party { .clone() .expect("should have labels for all AND gate output wires"); - let and_0 = self.xor_abits(&and_output_share.0, &and_share); - let and_1 = self.xor_abits(&and_0, &share_left.0); - let and_2 = self.xor_abits(&and_0, &share_right.0); - let mut and_3 = self.xor_abits(&and_1, &share_right.0); + let and_0 = xor(&and_output_share.0, &and_share); + let and_1 = xor(&and_0, &share_left.0); + let and_2 = xor(&and_0, &share_right.0); + let mut and_3 = xor(&and_1, &share_right.0); if self.is_evaluator() { // do local computation and receive values - and_3.bit.value ^= true; + and_3.bit ^= true; for _j in 1..self.num_parties { let garbled_and_message = self.channels.listen.recv().unwrap(); @@ -1440,7 +1509,7 @@ impl Party { local_ands.push((gate_index, 3, and_3)); } else { // do local computation and send values - let mut evaluator_key = and_3.mac_keys[EVALUATOR_ID]; + let mut evaluator_key = and_3.keys[EVALUATOR_ID]; evaluator_key = xor_mac_width(&evaluator_key, &self.global_mac_key); @@ -1568,7 +1637,7 @@ impl Party { { debug_assert_eq!(to, self.id); // verify mac - let my_key = wire_share.0.mac_keys[from]; + let my_key = wire_share.0.keys[from]; if !verify_mac(&r_j, &mac_j, &my_key, &self.global_mac_key) { return Err(Error::CheckFailed( @@ -1582,7 +1651,7 @@ impl Party { } // compute blinded input value - masked_wire_value = input_value ^ wire_share.0.bit.value; + masked_wire_value = input_value ^ wire_share.0.bit; for bit in other_wire_mask_shares { masked_wire_value ^= bit; } @@ -1611,7 +1680,7 @@ impl Party { .send(Message { from: self.id, to: party, - payload: MessagePayload::WireMac(wire_share.0.bit.value, their_mac), + payload: MessagePayload::WireMac(wire_share.0.bit, their_mac), }) .unwrap(); @@ -1776,7 +1845,7 @@ impl Party { .expect("should have labels and mask for all earlier wires") .1; - let mut masked_output_value = output_wire_share.bit.value; + let mut masked_output_value = output_wire_share.bit; let mut this_wires_labels = Vec::new(); for j in 1..self.num_parties { let garble_index = @@ -1819,7 +1888,7 @@ impl Party { }) .expect("should have keys for all other parties' MACs") .2 - .mac_keys[j]; + .keys[j]; if !verify_mac(&r_j, &my_mac, &my_key, &self.global_mac_key) { return Err(Error::CheckFailed( @@ -1889,7 +1958,7 @@ impl Party { { debug_assert_eq!(to, self.id); // verify mac - let my_key = output_wire_share.mac_keys[from]; + let my_key = output_wire_share.keys[from]; if !verify_mac(&wire_mask_share, &mac, &my_key, &self.global_mac_key) { return Err(Error::CheckFailed("invalid nput wire MAC ".to_owned())); @@ -1923,10 +1992,7 @@ impl Party { .send(Message { from: self.id, to: EVALUATOR_ID, - payload: MessagePayload::WireMac( - output_wire_share.bit.value, - evaluator_mac, - ), + payload: MessagePayload::WireMac(output_wire_share.bit, evaluator_mac), }) .unwrap(); @@ -1950,7 +2016,7 @@ impl Party { &mut self, circuit: &Circuit, input: &[bool], - ) -> Result>, Error> { + ) -> Result<(usize, Option>), Error> { // Validate the circuit circuit .validate_circuit_specification() @@ -2005,11 +2071,11 @@ impl Party { result }; - Ok(if result.is_empty() { - Some(result) + if !result.is_empty() { + Ok((self.id, Some(result))) } else { - None - }) + Ok((self.id, None)) + } } /// Synchronise parties. @@ -2133,11 +2199,11 @@ impl Party { ) -> Vec { let mut result = and_share.serialize_bit_macs(); let mut garbled_label = output_label; - for key in and_share.mac_keys { + for key in and_share.keys { garbled_label = xor_mac_width(&garbled_label, &key); } - if and_share.bit.value { + if and_share.bit { garbled_label = xor_mac_width(&garbled_label, &self.global_mac_key); } result.extend_from_slice(&garbled_label); diff --git a/atlas-spec/mpc-engine/src/primitives/auth_share.rs b/atlas-spec/mpc-engine/src/primitives/auth_share.rs index c321437..5a72649 100644 --- a/atlas-spec/mpc-engine/src/primitives/auth_share.rs +++ b/atlas-spec/mpc-engine/src/primitives/auth_share.rs @@ -1,36 +1,26 @@ //! This module defines the interface for share authentication. -use serde::{Deserialize, Serialize}; +use crate::{ + messages::{Message, MessagePayload}, + party::Party, + primitives::mac::MAC_LENGTH, + Error, +}; -use crate::{primitives::mac::MAC_LENGTH, Error}; - -use super::mac::{self, Mac, MacKey}; - -/// A bit held by a party with a given ID. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Bit { - pub(crate) id: BitID, - pub(crate) value: bool, -} -#[derive(Debug, Clone, Serialize, Deserialize)] -/// A bit identifier. -/// -/// This is unique per party, not globally, so if referring bits held by another -/// party, their party ID is also required to disambiguate. -pub struct BitID(pub(crate) usize); +use super::mac::{self, verify_mac, Mac, MacKey}; #[derive(Debug, Clone)] /// A bit authenticated between two parties. pub struct AuthBit { - pub(crate) bit: Bit, + pub(crate) bit: bool, pub(crate) macs: [Mac; NUM_PARTIES], - pub(crate) mac_keys: [MacKey; NUM_PARTIES], + pub(crate) keys: [MacKey; NUM_PARTIES], } impl AuthBit { /// Serialize the bit value and all MACs on the bit. pub fn serialize_bit_macs(&self) -> Vec { let mut result = vec![0u8; NUM_PARTIES * MAC_LENGTH + 1]; - result[0] = self.bit.value as u8; + result[0] = self.bit as u8; for (key_holder, mac) in self.macs.iter().enumerate() { result[1 + key_holder * MAC_LENGTH..1 + (key_holder + 1) * MAC_LENGTH] .copy_from_slice(mac); @@ -62,12 +52,40 @@ impl AuthBit { } } -/// The key to authenticate a two-party authenticated bit. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct BitKey { - pub(crate) holder_bit_id: BitID, - pub(crate) bit_holder: usize, - pub(crate) mac_key: MacKey, +/// Locally compute the XOR of two authenticated bits, which will itself be +/// authenticated already. +pub fn xor( + a: &AuthBit, + b: &AuthBit, +) -> AuthBit { + let mut macs = [mac::zero_mac(); NUM_PARTIES]; + + for (maccing_party, mac) in a.macs.iter().enumerate() { + let mut xored_mac = [0u8; MAC_LENGTH]; + let other_mac = b.macs[maccing_party]; + + for byte in 0..MAC_LENGTH { + xored_mac[byte] = mac[byte] ^ other_mac[byte]; + } + macs[maccing_party] = xored_mac; + } + + let mut mac_keys = [mac::zero_key(); NUM_PARTIES]; + for (bit_holder, key) in a.keys.iter().enumerate() { + let mut xored_key = [0u8; MAC_LENGTH]; + let other_key = b.keys[bit_holder]; + + for byte in 0..MAC_LENGTH { + xored_key[byte] = key[byte] ^ other_key[byte]; + } + mac_keys[bit_holder] = xored_key; + } + + AuthBit { + bit: a.bit ^ b.bit, + macs, + keys: mac_keys, + } } #[test] @@ -91,20 +109,14 @@ fn serialization() { [8; MAC_LENGTH], ]; let test_bit_1 = AuthBit { - bit: Bit { - id: BitID(0), - value: true, - }, + bit: true, macs: macs_1, - mac_keys: keys, + keys, }; let test_bit_2 = AuthBit { - bit: Bit { - id: BitID(1), - value: false, - }, + bit: false, macs: macs_2, - mac_keys: keys, + keys, }; let (bit_1, deserialized_macs_1) = diff --git a/atlas-spec/mpc-engine/src/primitives/kos.rs b/atlas-spec/mpc-engine/src/primitives/kos.rs new file mode 100644 index 0000000..ed94343 --- /dev/null +++ b/atlas-spec/mpc-engine/src/primitives/kos.rs @@ -0,0 +1,40 @@ +//! The KOS OT extension + +use std::sync::mpsc::{Receiver, Sender}; + +use hacspec_lib::Randomness; + +use crate::messages::SubMessage; + +use super::mac::Mac; + +#[derive(Debug)] +/// An Error in the KOS OT extension +pub enum Error {} + +#[allow(unreachable_code)] +pub(crate) fn kos_receive( + selection: &[bool], + sender_address: Sender, + my_inbox: Receiver, + receiver_id: usize, + sender_id: usize, + entropy: &mut Randomness, +) -> Result, Error> { + todo!() +} + +fn kos_dst(sender_id: usize, receiver_id: usize) -> String { + format!("KOS-Base-OT-{}-{}", sender_id, receiver_id) +} + +pub(crate) fn kos_send( + receiver_address: Sender, + my_inbox: Receiver, + receiver_id: usize, + sender_id: usize, + inputs: &[(Mac, Mac)], + entropy: &mut Randomness, +) -> Result<(), Error> { + todo!() +} diff --git a/atlas-spec/mpc-engine/src/primitives/kos_base.rs b/atlas-spec/mpc-engine/src/primitives/kos_base.rs new file mode 100644 index 0000000..365ea21 --- /dev/null +++ b/atlas-spec/mpc-engine/src/primitives/kos_base.rs @@ -0,0 +1,238 @@ +//! This module implements a base OT for the maliciously secure KOS15 OT extension. +//! +//! BaseOT taken from https://eprint.iacr.org/2020/110.pdf. +#![allow(non_snake_case)] +use std::ops::Neg; + +use crate::COMPUTATIONAL_SECURITY; +use hacspec_lib::{hacspec_helper::NatMod, Randomness}; +use hash_to_curve::p256_hash::hash_to_curve; +use hmac::{hkdf_expand, hkdf_extract}; +use p256::{p256_point_mul, random_scalar, P256Point, P256Scalar}; + +use super::mac::MAC_LENGTH; +type BaseOTSeed = [u8; COMPUTATIONAL_SECURITY]; + +#[derive(Debug)] +enum Error { + ReceiverAbort, + SenderCheatDetected, +} + +fn FRO1(seed: &[u8], dst: &[u8]) -> P256Point { + let mut dst = dst.to_vec(); + dst.extend_from_slice(b"F1"); + hash_to_curve(seed, &dst).unwrap() +} + +fn FRO2(point: &P256Point, dst: &[u8]) -> [u8; MAC_LENGTH] { + let mut dst = dst.to_vec(); + dst.extend_from_slice(b"F2"); + let prk = hkdf_extract(b"", &point.raw_bytes()); + let result = hkdf_expand(&prk, &dst, COMPUTATIONAL_SECURITY); + let mut result_array = [0u8; COMPUTATIONAL_SECURITY]; + result_array.copy_from_slice(&result); + result_array +} + +fn FRO3(sender_message: &[u8], dst: &[u8]) -> [u8; COMPUTATIONAL_SECURITY] { + let mut dst = dst.to_vec(); + dst.extend_from_slice(b"F3"); + let prk = hkdf_extract(b"", sender_message); + let result = hkdf_expand(&prk, &dst, COMPUTATIONAL_SECURITY); + let mut result_array = [0u8; COMPUTATIONAL_SECURITY]; + result_array.copy_from_slice(&result); + result_array +} + +fn FRO4( + hashes: &[[u8; COMPUTATIONAL_SECURITY]; L], + dst: &[u8], +) -> [u8; COMPUTATIONAL_SECURITY] { + let mut dst = dst.to_vec(); + dst.extend_from_slice(b"F4"); + let mut input = Vec::new(); + for i in 0..L { + input.extend_from_slice(&hashes[i]); + } + let prk = hkdf_extract(b"", &input); + let result = hkdf_expand(&prk, &dst, COMPUTATIONAL_SECURITY); + let mut result_array = [0u8; COMPUTATIONAL_SECURITY]; + result_array.copy_from_slice(&result); + result_array +} + +pub(crate) struct BaseOTReceiver { + sid: Vec, + T: P256Point, + bits: [bool; L], + alphas: [P256Scalar; L], +} + +pub(crate) struct BaseOTSender { + sid: Vec, + r: P256Scalar, + negTr: P256Point, + chall_hashes: [[u8; COMPUTATIONAL_SECURITY]; L], +} + +impl BaseOTReceiver { + pub(crate) fn init(entropy: &mut Randomness, sid: &[u8]) -> (Self, BaseOTSeed) { + let mut seed_array = [0u8; COMPUTATIONAL_SECURITY]; + let seed = entropy.bytes(COMPUTATIONAL_SECURITY).unwrap().to_owned(); + seed_array.copy_from_slice(&seed); + let bits = [false; L]; + let alphas = [P256Scalar::zero(); L]; + + let T = FRO1(&seed_array, sid); + ( + Self { + sid: sid.to_owned(), + T, + bits, + alphas, + }, + seed_array, + ) + } + + pub(crate) fn messages(&mut self, entropy: &mut Randomness) -> [P256Point; L] { + let mut messages = [P256Point::AtInfinity; L]; + for i in 0..L { + self.bits[i] = entropy.bit().unwrap(); + self.alphas[i] = random_scalar(entropy, &self.sid).unwrap(); + messages[i] = p256::p256_point_mul_base(self.alphas[i]).unwrap(); + if self.bits[i] { + messages[i] = p256::point_add(messages[i], self.T).unwrap(); + } + } + messages + } + + fn decrypt(&self, z: P256Point) -> [[u8; COMPUTATIONAL_SECURITY]; L] { + let mut messages = [[0u8; COMPUTATIONAL_SECURITY]; L]; + for i in 0..L { + let input = p256::p256_point_mul(self.alphas[i], z).unwrap(); + messages[i] = FRO2(&input, &self.sid); + } + messages + } + + fn responses( + &self, + messages: &[[u8; MAC_LENGTH]; L], + challenges: &[[u8; COMPUTATIONAL_SECURITY]; L], + ) -> [u8; COMPUTATIONAL_SECURITY] { + let mut responses = [[0u8; COMPUTATIONAL_SECURITY]; L]; + for i in 0..L { + responses[i] = FRO3(&messages[i], &self.sid); + if self.bits[i] { + responses[i] = xor_arrays(&responses[i], &challenges[i]); + } + } + FRO4(&responses, &self.sid) + } + + fn challenge_verification( + &self, + Ans: &[u8; COMPUTATIONAL_SECURITY], + gamma: &[u8; COMPUTATIONAL_SECURITY], + ) -> Result<(), Error> { + let gamma_prime = FRO3(Ans, &self.sid); + if gamma_prime != *gamma { + return Err(Error::ReceiverAbort); + } + Ok(()) + } +} + +impl BaseOTSender { + pub(crate) fn init( + entropy: &mut Randomness, + sid: &[u8], + seed: &BaseOTSeed, + ) -> (Self, P256Point) { + let T = FRO1(seed, sid); + let r = random_scalar(entropy, sid).unwrap(); + let negTr = p256::p256_point_mul(r, T).unwrap().neg(); + let chall_hashes = [[0u8; COMPUTATIONAL_SECURITY]; L]; + let z = p256::p256_point_mul_base(r).unwrap(); + ( + Self { + sid: sid.to_owned().into(), + chall_hashes, + r, + negTr, + }, + z, + ) + } + + pub(crate) fn messages( + &self, + receiver_messages: [P256Point; L], + ) -> [([u8; MAC_LENGTH], [u8; MAC_LENGTH]); L] { + let mut messages = [([0u8; MAC_LENGTH], [0u8; MAC_LENGTH]); L]; + for i in 0..L { + let preimg_0 = p256_point_mul(self.r, receiver_messages[i]).unwrap(); + let preimg_1 = p256::point_add(self.negTr, preimg_0).unwrap(); + let pi_0 = FRO2(&preimg_0, &self.sid); + let pi_1 = FRO2(&preimg_1, &self.sid); + messages[i] = (pi_0, pi_1); + } + + messages + } + + fn challenges( + &mut self, + messages: &[([u8; MAC_LENGTH], [u8; MAC_LENGTH]); L], + ) -> [[u8; COMPUTATIONAL_SECURITY]; L] { + let mut challenges = [[0u8; COMPUTATIONAL_SECURITY]; L]; + for i in 0..L { + let chall_hash_0 = FRO3(&messages[i].0, &self.sid); + let chall_hash_1 = FRO3(&messages[i].1, &self.sid); + self.chall_hashes[i] = chall_hash_0; + challenges[i] = xor_arrays(&chall_hash_0, &chall_hash_1); + } + challenges + } + + fn proof(&self) -> ([u8; COMPUTATIONAL_SECURITY], [u8; COMPUTATIONAL_SECURITY]) { + let Ans = FRO4(&self.chall_hashes, &self.sid); + let gamma = FRO3(&Ans, &self.sid); + (Ans, gamma) + } +} + +fn xor_arrays(a: &[u8; L], b: &[u8; L]) -> [u8; L] { + let mut result = [0u8; L]; + for i in 0..L { + result[i] = a[i] ^ b[i]; + } + result +} + +#[test] +fn simple() { + use rand::{thread_rng, RngCore}; + let sid = b"test"; + let mut rng = thread_rng(); + let mut entropy = [0u8; 100000]; + rng.fill_bytes(&mut entropy); + let mut entropy = Randomness::new(entropy.to_vec()); + let (mut receiver, seed) = BaseOTReceiver::<5>::init(&mut entropy, sid); + let receiver_messages = receiver.messages(&mut entropy); + + let (mut sender, sender_parameter) = BaseOTSender::<5>::init(&mut entropy, sid, &seed); + let sender_messages = sender.messages(receiver_messages); + let challenges = sender.challenges(&sender_messages); + let (Ans_sender, gamma) = sender.proof(); + + let decryptions = receiver.decrypt(sender_parameter); + let Ans_receiver = receiver.responses(&decryptions, &challenges); + receiver + .challenge_verification(&Ans_receiver, &gamma) + .unwrap(); + assert_eq!(Ans_receiver, Ans_sender) +} diff --git a/atlas-spec/mpc-engine/src/primitives/mod.rs b/atlas-spec/mpc-engine/src/primitives/mod.rs index 3ff4ced..fc09cfb 100644 --- a/atlas-spec/mpc-engine/src/primitives/mod.rs +++ b/atlas-spec/mpc-engine/src/primitives/mod.rs @@ -2,5 +2,7 @@ pub mod auth_share; pub mod commitment; +pub mod kos; +mod kos_base; pub mod mac; pub mod ot; diff --git a/atlas-spec/mpc-engine/src/runner.rs b/atlas-spec/mpc-engine/src/runner.rs index 62151ec..7476ec6 100644 --- a/atlas-spec/mpc-engine/src/runner.rs +++ b/atlas-spec/mpc-engine/src/runner.rs @@ -12,7 +12,7 @@ pub struct Runner; impl Runner { /// Set up and run an MPC session of the given circuit with the provided /// inputs. - pub fn run( + pub fn run_mpc( circuit: &Circuit, inputs: &[&[bool]], logging: Vec, From d72bdea195a5cd02c44ba4d3f4e1d8c29a4c67bd Mon Sep 17 00:00:00 2001 From: Jonas Schneider-Bensch Date: Wed, 7 Aug 2024 16:47:11 +0200 Subject: [PATCH 06/14] WIP KOS15 OT extension --- atlas-spec/mpc-engine/src/lib.rs | 4 +- atlas-spec/mpc-engine/src/messages.rs | 7 + atlas-spec/mpc-engine/src/primitives/kos.rs | 321 +++++++++++++++++- .../mpc-engine/src/primitives/kos_base.rs | 133 ++++++-- atlas-spec/mpc-engine/src/utils.rs | 31 ++ 5 files changed, 454 insertions(+), 42 deletions(-) diff --git a/atlas-spec/mpc-engine/src/lib.rs b/atlas-spec/mpc-engine/src/lib.rs index a850bc1..9a00a38 100644 --- a/atlas-spec/mpc-engine/src/lib.rs +++ b/atlas-spec/mpc-engine/src/lib.rs @@ -41,6 +41,8 @@ pub enum Error { AEADError, /// Miscellaneous error. OtherError, + /// Subprotocol error + SubprotocolError, } impl From for Error { @@ -71,5 +73,5 @@ pub mod circuit; pub mod messages; pub mod party; pub mod primitives; -pub mod utils; pub mod runner; +pub mod utils; diff --git a/atlas-spec/mpc-engine/src/messages.rs b/atlas-spec/mpc-engine/src/messages.rs index d5ecf89..6f874fc 100644 --- a/atlas-spec/mpc-engine/src/messages.rs +++ b/atlas-spec/mpc-engine/src/messages.rs @@ -5,6 +5,7 @@ use crate::{ circuit::WireIndex, primitives::{ commitment::{Commitment, Opening}, + kos::{KOSReceiverPhaseI, KOSSenderPhaseI, KOSSenderPhaseII}, mac::Mac, ot::{OTReceiverSelect, OTSenderInit, OTSenderSend}, }, @@ -78,4 +79,10 @@ pub enum SubMessage { EQResponse(Vec), /// An EQ initiator opening EQOpening(Opening), + /// A KOS OT extension sender message in Phase I + KOSSenderPhaseI(KOSSenderPhaseI), + /// A KOS OT extension sender message in Phase I + KOSReceiverPhaseI(KOSReceiverPhaseI), + /// A KOS OT extension sender message in Phase I + KOSSenderPhaseII(KOSSenderPhaseII), } diff --git a/atlas-spec/mpc-engine/src/primitives/kos.rs b/atlas-spec/mpc-engine/src/primitives/kos.rs index ed94343..676bb86 100644 --- a/atlas-spec/mpc-engine/src/primitives/kos.rs +++ b/atlas-spec/mpc-engine/src/primitives/kos.rs @@ -1,18 +1,271 @@ //! The KOS OT extension +//! +//! Computational security parameter is fixed to 128. +#![allow(non_snake_case)] use std::sync::mpsc::{Receiver, Sender}; use hacspec_lib::Randomness; +use hmac::{hkdf_expand, hkdf_extract}; -use crate::messages::SubMessage; +use crate::{ + messages::SubMessage, + primitives::kos_base, + utils::{ith_bit, pack_bits}, +}; -use super::mac::Mac; +use super::{ + kos_base::{BaseOTReceiver, BaseOTSender, ReceiverChoose, ReceiverResponse, SenderTransfer}, + mac::{xor_mac_width, Mac}, +}; #[derive(Debug)] /// An Error in the KOS OT extension -pub enum Error {} +pub enum Error { + /// An Error that occurred in the BaseOT. + BaseOTError, + /// A consistency check has failed. + Consistency, +} + +impl From for Error { + fn from(_value: crate::primitives::kos_base::Error) -> Self { + Self::BaseOTError + } +} + +fn CRF(sid: &[u8], input: &Mac, tweak: usize) -> Mac { + todo!() +} + +fn PRG(sid: &[u8], k: &[u8], len: usize) -> Vec { + todo!() +} + +fn FRO2(sid: &[u8], matrix: &[Vec; 128]) -> Vec { + let mut ikm = sid.to_vec(); + let out_len = matrix[0].len(); + debug_assert_eq!(out_len % 16, 0); + for column in matrix { + ikm.extend_from_slice(column) + } + let prk = hkdf_extract(b"", &ikm); + let result_bytes = hkdf_expand(&prk, sid, out_len); + let result = result_bytes + .chunks_exact(16) + .map(|chunk| { + u128::from_be_bytes( + chunk + .try_into() + .expect("should be given exactly 16 byte chunks"), + ) + }) + .collect(); + result +} + +fn challenge_selection(challenge: &[u128], selection_matrix: &[Vec; 128]) -> u128 { + todo!() +} + +fn packed_row(matrix: &[Vec; 128], index: usize) -> u128 { + let mut result = 0u128; + for i in 0..128 { + let b = ith_bit(index, &matrix[i]); + if b { + result += 1 << (i as u128); + } + } + result +} + +fn kos_dst(sender_id: usize, receiver_id: usize) -> String { + format!("KOS-Base-OT-{}-{}", sender_id, receiver_id) +} + +/// The message sent by the KOS15 Receiver in phase I of the protocol. +#[derive(Debug)] +pub struct KOSReceiverPhaseI { + base_ot_transfer: SenderTransfer<128>, + D: [Vec; 128], + u: u128, + v: u128, +} + +/// The KOS Receiver state. +pub struct KOSReceiver { + selection_bits: Vec, + base_sender: BaseOTSender<128>, + M_columns: [Vec; 128], + sid: Vec, +} + +impl KOSReceiver { + pub(crate) fn phase_i( + selection: &[bool], + sender_phase_i: KOSSenderPhaseI, + sid: &[u8], + entropy: &mut Randomness, + ) -> (Self, KOSReceiverPhaseI) { + let (base_sender, base_sender_transfer) = + kos_base::BaseOTSender::<128>::transfer(entropy, &sid, sender_phase_i.base_ot_choice); + + let tau = entropy.bytes(128 / 8).unwrap(); + let mut r_prime = crate::utils::pack_bits(selection); + r_prime.extend_from_slice(&tau); + let M_columns: [Vec; 128] = + std::array::from_fn(|i| PRG(&sid, &base_sender.inputs[i].0, 16 + selection.len() / 8)); + let R_columns: [Vec; 128] = std::array::from_fn(|_i| r_prime.clone()); + let D_columns: [Vec; 128] = std::array::from_fn(|i| { + let prg_result = PRG(&sid, &base_sender.inputs[i].1, 16 + selection.len() / 8); + let temp_result = crate::utils::xor_slices(&M_columns[i], &prg_result); + crate::utils::xor_slices(&temp_result, &R_columns[i]) + }); + + let Chi = FRO2(&sid, &D_columns); + + let u = challenge_selection(&Chi, &M_columns); + let v = challenge_selection(&Chi, &R_columns); + + ( + Self { + selection_bits: selection.to_owned(), + base_sender, + M_columns, + sid: sid.to_owned(), + }, + KOSReceiverPhaseI { + base_ot_transfer: base_sender_transfer, + D: D_columns, + u, + v, + }, + ) + } -#[allow(unreachable_code)] + fn phase_ii(self, sender_phase_ii: KOSSenderPhaseII) -> Result, Error> { + let mut results = Vec::new(); + self.base_sender.verify(sender_phase_ii.base_ot_response)?; + for (index, selection_bit) in self.selection_bits.iter().enumerate() { + let crf = CRF( + &self.sid, + &packed_row(&self.M_columns, index).to_be_bytes(), + index, + ); + let y = if *selection_bit { + sender_phase_ii.ys[index].1 + } else { + sender_phase_ii.ys[index].0 + }; + let a = xor_mac_width(&y, &crf); + results.push(a) + } + Ok(results) + } +} + + +pub(crate) struct KOSSender { + base_receiver: BaseOTReceiver<128>, + sid: Vec, +} + +/// The message sent by the KOS15 Sender in phase I of the protocol. +#[derive(Debug)] +pub struct KOSSenderPhaseI { + base_ot_choice: ReceiverChoose<128>, +} + +/// The message sent by the KOS15 Sender in phase II of the protocol. +#[derive(Debug)] +pub struct KOSSenderPhaseII { + ys: Vec<(Mac, Mac)>, + base_ot_response: ReceiverResponse, +} + +impl KOSSender { + pub(crate) fn phase_i(sid: &[u8], entropy: &mut Randomness) -> (Self, KOSSenderPhaseI) { + let (base_receiver, base_ot_choice) = + crate::primitives::kos_base::BaseOTReceiver::<128>::choose(entropy, &sid); + + ( + Self { + sid: sid.to_owned(), + base_receiver, + }, + KOSSenderPhaseI { base_ot_choice }, + ) + } + fn check_uvw(u: u128, v: u128, w: u128, s: u128) -> Result<(), Error> { + if w == u ^ (s * v) { + Ok(()) + } else { + Err(Error::Consistency) + } + } + + fn phase_ii( + &mut self, + inputs: &[(Mac, Mac)], + receiver_phase_i: KOSReceiverPhaseI, + ) -> Result { + let (base_receiver_output, base_ot_response) = self + .base_receiver + .response(receiver_phase_i.base_ot_transfer)?; + + let Q_columns: [Vec; 128] = std::array::from_fn(|i| { + let mut result = PRG(&self.sid, &base_receiver_output[i], 16 + inputs.len() / 8); + // the following is obviously secret-dependent timing + if self.base_receiver.selection_bits[i] { + result = crate::utils::xor_slices(&result, &receiver_phase_i.D[i]); + } + result + }); + + let Chi = FRO2(&self.sid, &receiver_phase_i.D); + + let w = challenge_selection(&Chi, &Q_columns); + let s = pack_bits(&self.base_receiver.selection_bits); + let mut s_array = [0u8; 16]; + s_array.copy_from_slice(&s[..16]); + + Self::check_uvw( + receiver_phase_i.u, + receiver_phase_i.v, + w, + u128::from_be_bytes(s_array), + )?; + + let mut ys = Vec::new(); + for (index, (a_0, a_1)) in inputs.iter().enumerate() { + let crf_0 = CRF( + &self.sid, + &packed_row(&Q_columns, index).to_be_bytes(), + index, + ); + let crf_1 = CRF( + &self.sid, + &xor_mac_width(&packed_row(&Q_columns, index).to_be_bytes(), &s_array), + index, + ); + let y_0 = xor_mac_width(a_0, &crf_0); + let y_1 = xor_mac_width(a_1, &crf_1); + ys.push((y_0, y_1)) + } + + Ok(KOSSenderPhaseII { + ys, + base_ot_response, + }) + } +} + +/// Run the KOS15 protocol in the role of the receiver. +/// +/// Uses the given Channels to communicate the KOS messages from the +/// perspective of the receiver. The input `selection` determines +/// which of the senders inputs get obliviously transfered to the +/// receiver. pub(crate) fn kos_receive( selection: &[bool], sender_address: Sender, @@ -20,21 +273,65 @@ pub(crate) fn kos_receive( receiver_id: usize, sender_id: usize, entropy: &mut Randomness, -) -> Result, Error> { - todo!() -} +) -> Result, crate::Error> { + let sid = kos_dst(sender_id, receiver_id).as_bytes().to_owned(); -fn kos_dst(sender_id: usize, receiver_id: usize) -> String { - format!("KOS-Base-OT-{}-{}", sender_id, receiver_id) + let sender_phase_i_msg = my_inbox.recv().unwrap(); + if let SubMessage::KOSSenderPhaseI(sender_phase_i) = sender_phase_i_msg { + let (receiver, phase_i) = KOSReceiver::phase_i(selection, sender_phase_i, &sid, entropy); + sender_address + .send(SubMessage::KOSReceiverPhaseI(phase_i)) + .unwrap(); + let sender_phase_ii_msg = my_inbox.recv().unwrap(); + if let SubMessage::KOSSenderPhaseII(sender_phase_ii) = sender_phase_ii_msg { + let outputs = receiver + .phase_ii(sender_phase_ii) + .map_err(|_| crate::Error::SubprotocolError)?; + Ok(outputs) + } else { + Err(crate::Error::UnexpectedSubprotocolMessage( + sender_phase_ii_msg, + )) + } + } else { + Err(crate::Error::UnexpectedSubprotocolMessage( + sender_phase_i_msg, + )) + } } +/// Run the KOS15 protocol in the role of the sender. +/// +/// Uses the given Channels to communicate the KOS messages from the +/// perspective of the sender. The receiver's input `selection` +/// determines which of the senders inputs get obliviously transfered +/// to the receiver. pub(crate) fn kos_send( + inputs: &[(Mac, Mac)], receiver_address: Sender, my_inbox: Receiver, receiver_id: usize, sender_id: usize, - inputs: &[(Mac, Mac)], entropy: &mut Randomness, -) -> Result<(), Error> { - todo!() +) -> Result<(), crate::Error> { + let sid = kos_dst(sender_id, receiver_id).as_bytes().to_owned(); + + let (mut kos_sender, phase_i) = KOSSender::phase_i(&sid, entropy); + receiver_address + .send(SubMessage::KOSSenderPhaseI(phase_i)) + .unwrap(); + let receiver_phase_i_message = my_inbox.recv().unwrap(); + if let SubMessage::KOSReceiverPhaseI(receiver_phase_i) = receiver_phase_i_message { + let phase_ii = kos_sender + .phase_ii(inputs, receiver_phase_i) + .map_err(|_| crate::Error::SubprotocolError)?; + receiver_address + .send(SubMessage::KOSSenderPhaseII(phase_ii)) + .unwrap(); + Ok(()) + } else { + Err(crate::Error::UnexpectedSubprotocolMessage( + receiver_phase_i_message, + )) + } } diff --git a/atlas-spec/mpc-engine/src/primitives/kos_base.rs b/atlas-spec/mpc-engine/src/primitives/kos_base.rs index 365ea21..abe0587 100644 --- a/atlas-spec/mpc-engine/src/primitives/kos_base.rs +++ b/atlas-spec/mpc-engine/src/primitives/kos_base.rs @@ -14,7 +14,7 @@ use super::mac::MAC_LENGTH; type BaseOTSeed = [u8; COMPUTATIONAL_SECURITY]; #[derive(Debug)] -enum Error { +pub enum Error { ReceiverAbort, SenderCheatDetected, } @@ -65,23 +65,59 @@ fn FRO4( pub(crate) struct BaseOTReceiver { sid: Vec, T: P256Point, - bits: [bool; L], + pub selection_bits: [bool; L], alphas: [P256Scalar; L], } pub(crate) struct BaseOTSender { sid: Vec, r: P256Scalar, + pub inputs: [([u8; 16], [u8; 16]); L], + expected_answer: [u8; 16], negTr: P256Point, chall_hashes: [[u8; COMPUTATIONAL_SECURITY]; L], } +#[derive(Debug)] +pub(crate) struct ReceiverChoose { + seed: BaseOTSeed, + messages: [P256Point; L], +} + +#[derive(Debug)] +pub(crate) struct ReceiverResponse { + response: [u8; 16], +} + +#[derive(Debug)] +pub(crate) struct SenderTransfer { + seed: P256Point, + challenge: [[u8; 16]; L], + gamma: [u8; 16], +} + impl BaseOTReceiver { - pub(crate) fn init(entropy: &mut Randomness, sid: &[u8]) -> (Self, BaseOTSeed) { + pub(crate) fn choose(entropy: &mut Randomness, sid: &[u8]) -> (Self, ReceiverChoose) { + let (mut receiver, seed) = Self::parameters(entropy, sid); + let (bits, messages) = receiver.messages(entropy); + receiver.selection_bits = bits; + (receiver, ReceiverChoose { seed, messages }) + } + + pub(crate) fn response( + &self, + transfer: SenderTransfer, + ) -> Result<([[u8; 16]; L], ReceiverResponse), Error> { + let messages = self.decrypt(transfer.seed); + let response = self.responses(&self.selection_bits, &messages, &transfer.challenge); + self.challenge_verification(&response, &transfer.gamma)?; + Ok((messages, ReceiverResponse { response })) + } + + fn parameters(entropy: &mut Randomness, sid: &[u8]) -> (Self, BaseOTSeed) { let mut seed_array = [0u8; COMPUTATIONAL_SECURITY]; let seed = entropy.bytes(COMPUTATIONAL_SECURITY).unwrap().to_owned(); seed_array.copy_from_slice(&seed); - let bits = [false; L]; let alphas = [P256Scalar::zero(); L]; let T = FRO1(&seed_array, sid); @@ -89,24 +125,24 @@ impl BaseOTReceiver { Self { sid: sid.to_owned(), T, - bits, + selection_bits: [false; L], alphas, }, seed_array, ) } - pub(crate) fn messages(&mut self, entropy: &mut Randomness) -> [P256Point; L] { + fn messages(&mut self, entropy: &mut Randomness) -> ([bool; L], [P256Point; L]) { let mut messages = [P256Point::AtInfinity; L]; + let bits: [bool; L] = std::array::from_fn(|_| entropy.bit().unwrap()); for i in 0..L { - self.bits[i] = entropy.bit().unwrap(); self.alphas[i] = random_scalar(entropy, &self.sid).unwrap(); messages[i] = p256::p256_point_mul_base(self.alphas[i]).unwrap(); - if self.bits[i] { + if bits[i] { messages[i] = p256::point_add(messages[i], self.T).unwrap(); } } - messages + (bits, messages) } fn decrypt(&self, z: P256Point) -> [[u8; COMPUTATIONAL_SECURITY]; L] { @@ -120,13 +156,14 @@ impl BaseOTReceiver { fn responses( &self, + bits: &[bool; L], messages: &[[u8; MAC_LENGTH]; L], challenges: &[[u8; COMPUTATIONAL_SECURITY]; L], ) -> [u8; COMPUTATIONAL_SECURITY] { let mut responses = [[0u8; COMPUTATIONAL_SECURITY]; L]; for i in 0..L { responses[i] = FRO3(&messages[i], &self.sid); - if self.bits[i] { + if bits[i] { responses[i] = xor_arrays(&responses[i], &challenges[i]); } } @@ -147,7 +184,36 @@ impl BaseOTReceiver { } impl BaseOTSender { - pub(crate) fn init( + pub(crate) fn transfer( + entropy: &mut Randomness, + sid: &[u8], + choice: ReceiverChoose, + ) -> (Self, SenderTransfer) { + let (mut sender, seed) = Self::parameters(entropy, sid, &choice.seed); + let inputs = sender.generate_inputs(choice.messages); + let challenge = sender.challenges(&inputs); + sender.inputs = inputs; + let (expected_answer, gamma) = sender.proof(); + sender.expected_answer = expected_answer; + ( + sender, + SenderTransfer { + seed, + challenge, + gamma, + }, + ) + } + + pub(crate) fn verify(&self, response: ReceiverResponse) -> Result<(), Error> { + if response.response != self.expected_answer { + Err(Error::SenderCheatDetected) + } else { + Ok(()) + } + } + + fn parameters( entropy: &mut Randomness, sid: &[u8], seed: &BaseOTSeed, @@ -163,12 +229,14 @@ impl BaseOTSender { chall_hashes, r, negTr, + inputs: [([0u8; 16], [0u8; 16]); L], + expected_answer: [0u8; 16], }, z, ) } - pub(crate) fn messages( + fn generate_inputs( &self, receiver_messages: [P256Point; L], ) -> [([u8; MAC_LENGTH], [u8; MAC_LENGTH]); L] { @@ -199,9 +267,9 @@ impl BaseOTSender { } fn proof(&self) -> ([u8; COMPUTATIONAL_SECURITY], [u8; COMPUTATIONAL_SECURITY]) { - let Ans = FRO4(&self.chall_hashes, &self.sid); - let gamma = FRO3(&Ans, &self.sid); - (Ans, gamma) + let expected_answer = FRO4(&self.chall_hashes, &self.sid); + let gamma = FRO3(&expected_answer, &self.sid); + (expected_answer, gamma) } } @@ -215,24 +283,31 @@ fn xor_arrays(a: &[u8; L], b: &[u8; L]) -> [u8; L] { #[test] fn simple() { + // pre-requisites use rand::{thread_rng, RngCore}; let sid = b"test"; let mut rng = thread_rng(); let mut entropy = [0u8; 100000]; rng.fill_bytes(&mut entropy); let mut entropy = Randomness::new(entropy.to_vec()); - let (mut receiver, seed) = BaseOTReceiver::<5>::init(&mut entropy, sid); - let receiver_messages = receiver.messages(&mut entropy); - - let (mut sender, sender_parameter) = BaseOTSender::<5>::init(&mut entropy, sid, &seed); - let sender_messages = sender.messages(receiver_messages); - let challenges = sender.challenges(&sender_messages); - let (Ans_sender, gamma) = sender.proof(); - - let decryptions = receiver.decrypt(sender_parameter); - let Ans_receiver = receiver.responses(&decryptions, &challenges); - receiver - .challenge_verification(&Ans_receiver, &gamma) - .unwrap(); - assert_eq!(Ans_receiver, Ans_sender) + + let (mut receiver, choice_message) = BaseOTReceiver::<5>::choose(&mut entropy, sid); + + let (mut sender, transfer_message) = + BaseOTSender::<5>::transfer(&mut entropy, sid, choice_message); + + let (receiver_outputs, response) = receiver.response(transfer_message).unwrap(); + + sender.verify(response).unwrap(); + + for (i, selection_bit) in receiver.selection_bits.iter().enumerate() { + assert_eq!( + receiver_outputs[i], + if *selection_bit { + sender.inputs[i].1 + } else { + sender.inputs[i].0 + } + ) + } } diff --git a/atlas-spec/mpc-engine/src/utils.rs b/atlas-spec/mpc-engine/src/utils.rs index 6612669..8988101 100644 --- a/atlas-spec/mpc-engine/src/utils.rs +++ b/atlas-spec/mpc-engine/src/utils.rs @@ -40,3 +40,34 @@ pub(crate) fn ith_bit(i: usize, bytes: &[u8]) -> bool { let bit_index = 7 - i % 8; ((bytes[byte_index] >> bit_index) & 1u8) == 1u8 } + +/// Pack slice of `bool`s into a byte vector. +/// +/// We assume that `bits.len()` is a multiple of 8. +pub(crate) fn pack_bits(bits: &[bool]) -> Vec { + + let mut result = Vec::new(); + let full_blocks = bits.len() / 8; + let remainder = bits.len() % 8; + + debug_assert_eq!(remainder, 0); + + for i in 0..full_blocks { + let mut current_byte = 0u8; + for bit in 0..8 { + current_byte += (bits[i * 8 + bit] as u8) << (7 - bit); + } + result.push(current_byte); + } + + result +} + +pub(crate) fn xor_slices(left: &[u8], right: &[u8]) -> Vec { + debug_assert_eq!(left.len(), right.len()); + let mut result = Vec::with_capacity(left.len()); + for i in 0..left.len() { + result.push(left[i] ^ right[i]) + } + result +} From 48295562ee712c4fcf9d8251ca5b100b93a2c5a5 Mon Sep 17 00:00:00 2001 From: Jonas Schneider-Bensch Date: Thu, 8 Aug 2024 10:13:32 +0200 Subject: [PATCH 07/14] Draft KOS15; needs testing & debugging --- atlas-spec/mpc-engine/src/party.rs | 6 +-- atlas-spec/mpc-engine/src/primitives/kos.rs | 41 +++++++++++++++++---- 2 files changed, 36 insertions(+), 11 deletions(-) diff --git a/atlas-spec/mpc-engine/src/party.rs b/atlas-spec/mpc-engine/src/party.rs index 6231f3d..0a274a7 100644 --- a/atlas-spec/mpc-engine/src/party.rs +++ b/atlas-spec/mpc-engine/src/party.rs @@ -1284,11 +1284,11 @@ impl Party { // Initiate an OT session with the bit holder giving the two MACs as // sender inputs. kos_send( + &kos_inputs, their_address, my_inbox, from, self.id, - &kos_inputs, &mut self.entropy, ) .unwrap_or_default(); @@ -2029,8 +2029,8 @@ impl Party { let target_number = circuit.share_authentication_cost(); - self.abit_pool = self.precompute_abits(target_number + SEC_MARGIN_SHARE_AUTH)?; - + // self.abit_pool = self.precompute_abits(target_number + SEC_MARGIN_SHARE_AUTH)?; + self.abit_pool = self.batch_precompute_abits(target_number + SEC_MARGIN_SHARE_AUTH)?; self.function_independent(circuit).unwrap(); let (garbled_ands, local_ands) = self.function_dependent(circuit).unwrap(); diff --git a/atlas-spec/mpc-engine/src/primitives/kos.rs b/atlas-spec/mpc-engine/src/primitives/kos.rs index 676bb86..8517b2f 100644 --- a/atlas-spec/mpc-engine/src/primitives/kos.rs +++ b/atlas-spec/mpc-engine/src/primitives/kos.rs @@ -16,7 +16,7 @@ use crate::{ use super::{ kos_base::{BaseOTReceiver, BaseOTSender, ReceiverChoose, ReceiverResponse, SenderTransfer}, - mac::{xor_mac_width, Mac}, + mac::{xor_mac_width, Mac, MAC_LENGTH}, }; #[derive(Debug)] @@ -34,12 +34,31 @@ impl From for Error { } } +/// Implements a tweakable correlation robust hash function. +/// +/// Note: This could also be implemented as +/// +/// H(sid|tweak|input) = pi(pi(sid|input) xor tweak) xor pi(sid|input) +/// +/// where pi is an ideal permutation, fixed-key AES in practice. fn CRF(sid: &[u8], input: &Mac, tweak: usize) -> Mac { - todo!() + let mut ikm = sid.to_vec(); + ikm.extend_from_slice(&[tweak as u8]); + ikm.extend_from_slice(input); + let prk = hkdf_extract(b"", &ikm); + let result = hkdf_expand(&prk, sid, MAC_LENGTH) + .try_into() + .expect("should have received exactly `MAC_LENGHT` bytes"); + result } fn PRG(sid: &[u8], k: &[u8], len: usize) -> Vec { - todo!() + let mut ikm = sid.to_vec(); + ikm.extend_from_slice(k); + let prk = hkdf_extract(b"", &ikm); + let result = hkdf_expand(&prk, sid, len); + + result } fn FRO2(sid: &[u8], matrix: &[Vec; 128]) -> Vec { @@ -64,8 +83,13 @@ fn FRO2(sid: &[u8], matrix: &[Vec; 128]) -> Vec { result } +/// This implements Xor_{j in [m+k]} (Chi_j * M_j). fn challenge_selection(challenge: &[u128], selection_matrix: &[Vec; 128]) -> u128 { - todo!() + let mut result = 0u128; + for i in 0..challenge.len() { + result ^= challenge[i] * packed_row(selection_matrix, i); + } + result } fn packed_row(matrix: &[Vec; 128], index: usize) -> u128 { @@ -79,8 +103,10 @@ fn packed_row(matrix: &[Vec; 128], index: usize) -> u128 { result } -fn kos_dst(sender_id: usize, receiver_id: usize) -> String { +fn kos_dst(sender_id: usize, receiver_id: usize) -> Vec { format!("KOS-Base-OT-{}-{}", sender_id, receiver_id) + .as_bytes() + .to_vec() } /// The message sent by the KOS15 Receiver in phase I of the protocol. @@ -164,7 +190,6 @@ impl KOSReceiver { } } - pub(crate) struct KOSSender { base_receiver: BaseOTReceiver<128>, sid: Vec, @@ -274,7 +299,7 @@ pub(crate) fn kos_receive( sender_id: usize, entropy: &mut Randomness, ) -> Result, crate::Error> { - let sid = kos_dst(sender_id, receiver_id).as_bytes().to_owned(); + let sid = kos_dst(sender_id, receiver_id); let sender_phase_i_msg = my_inbox.recv().unwrap(); if let SubMessage::KOSSenderPhaseI(sender_phase_i) = sender_phase_i_msg { @@ -314,7 +339,7 @@ pub(crate) fn kos_send( sender_id: usize, entropy: &mut Randomness, ) -> Result<(), crate::Error> { - let sid = kos_dst(sender_id, receiver_id).as_bytes().to_owned(); + let sid = kos_dst(sender_id, receiver_id); let (mut kos_sender, phase_i) = KOSSender::phase_i(&sid, entropy); receiver_address From 196fabf86a47e7d4546a041e06f41a99788bc878 Mon Sep 17 00:00:00 2001 From: Jonas Schneider-Bensch Date: Thu, 8 Aug 2024 12:06:12 +0200 Subject: [PATCH 08/14] Simple plain circuit evaluation testing --- atlas-spec/mpc-engine/src/circuit.rs | 117 +++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) diff --git a/atlas-spec/mpc-engine/src/circuit.rs b/atlas-spec/mpc-engine/src/circuit.rs index 22a48ae..e467a37 100644 --- a/atlas-spec/mpc-engine/src/circuit.rs +++ b/atlas-spec/mpc-engine/src/circuit.rs @@ -347,3 +347,120 @@ impl Circuit { result } } + +#[cfg(test)] +mod tests { + use std::arch::x86_64::{_CMP_FALSE_OQ, _CMP_TRUE_UQ}; + + use crate::utils::ith_bit; + + use super::*; + + fn gen_inputs() -> Vec<[Vec; 4]> { + let mut results = Vec::new(); + for i in 0..16 { + let mut current_input = [Vec::new(), Vec::new(), Vec::new(), Vec::new()]; + for j in 0..4 { + current_input[j] = vec![ith_bit(j + 4, &[i as u8])]; + } + results.push(current_input); + } + results + } + + fn parity(input: &[Vec; 4]) -> bool { + let sum = input[0][0] as u8 + input[1][0] as u8 + input[2][0] as u8 + input[3][0] as u8; + !(sum % 2 == 0) + } + + #[test] + fn eval_and_2() { + let and = Circuit { + input_widths: vec![1, 1], + gates: vec![ + WiredGate::Input(0), // Gate 0 + WiredGate::Input(1), // Gate 1 + WiredGate::And(0, 1), // Gate 2 + ], + output_gates: vec![2], + }; + + assert_eq!(and.eval(&[vec![true], vec![true]]).unwrap()[0], true,); + assert_eq!(and.eval(&[vec![true], vec![false]]).unwrap()[0], false,); + assert_eq!(and.eval(&[vec![false], vec![true]]).unwrap()[0], false,); + assert_eq!(and.eval(&[vec![false], vec![false]]).unwrap()[0], false,); + } + + #[test] + fn eval_xor_2() { + let and = Circuit { + input_widths: vec![1, 1], + gates: vec![ + WiredGate::Input(0), // Gate 0 + WiredGate::Input(1), // Gate 1 + WiredGate::Xor(0, 1), // Gate 2 + ], + output_gates: vec![2], + }; + + assert_eq!(and.eval(&[vec![true], vec![true]]).unwrap()[0], false,); + assert_eq!(and.eval(&[vec![true], vec![false]]).unwrap()[0], true,); + assert_eq!(and.eval(&[vec![false], vec![true]]).unwrap()[0], true,); + assert_eq!(and.eval(&[vec![false], vec![false]]).unwrap()[0], false,); + } + + #[test] + fn eval_and_4() { + let and = Circuit { + input_widths: vec![1, 1, 1, 1], + gates: vec![ + WiredGate::Input(0), // Gate 0 + WiredGate::Input(1), // Gate 1 + WiredGate::Input(2), // Gate 2 + WiredGate::Input(3), // Gate 3 + WiredGate::And(0, 1), // Gate 4 + WiredGate::And(2, 3), // Gate 5 + WiredGate::And(4, 5), // Gate 6 + ], + output_gates: vec![6], + }; + + for input in gen_inputs() { + if input[0][0] && input[1][0] && input[2][0] && input[3][0] { + continue; + } + assert_eq!(and.eval(&input).unwrap()[0], false, "on input: {:?}", input); + } + assert_eq!( + and.eval(&[vec![true], vec![true], vec![true], vec![true]]) + .unwrap()[0], + true, + ); + } + + #[test] + fn eval_xor_4() { + let xor = Circuit { + input_widths: vec![1, 1, 1, 1], + gates: vec![ + WiredGate::Input(0), // Gate 0 + WiredGate::Input(1), // Gate 1 + WiredGate::Input(2), // Gate 2 + WiredGate::Input(3), // Gate 3 + WiredGate::Xor(0, 1), // Gate 4 + WiredGate::Xor(2, 3), // Gate 5 + WiredGate::Xor(4, 5), // Gate 6 + ], + output_gates: vec![6], + }; + + for input in gen_inputs() { + assert_eq!( + xor.eval(&input).unwrap()[0], + parity(&input), + "on input: {:?}", + input + ); + } + } +} From ad67de4ce253017223d461e767febd28667e0e20 Mon Sep 17 00:00:00 2001 From: Jonas Schneider-Bensch Date: Thu, 8 Aug 2024 12:06:30 +0200 Subject: [PATCH 09/14] Fix plain circuit evaluation --- atlas-spec/mpc-engine/src/circuit.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/atlas-spec/mpc-engine/src/circuit.rs b/atlas-spec/mpc-engine/src/circuit.rs index e467a37..86f5e71 100644 --- a/atlas-spec/mpc-engine/src/circuit.rs +++ b/atlas-spec/mpc-engine/src/circuit.rs @@ -292,7 +292,7 @@ impl Circuit { for gate in &self.gates { let output_bit = match gate { - WiredGate::Input(x) => wire_evaluations[*x], + WiredGate::Input(x) => continue, WiredGate::Xor(x, y) => wire_evaluations[*x] ^ wire_evaluations[*y], WiredGate::And(x, y) => wire_evaluations[*x] & wire_evaluations[*y], WiredGate::Not(x) => !wire_evaluations[*x], From 0d00b9d808801d63652a70d642c98ad5ec53f60d Mon Sep 17 00:00:00 2001 From: Jonas Schneider-Bensch Date: Thu, 8 Aug 2024 14:51:53 +0200 Subject: [PATCH 10/14] WIP, still failing --- atlas-spec/mpc-engine/src/primitives/kos.rs | 79 ++++++++++++++++--- .../mpc-engine/src/primitives/kos_base.rs | 6 +- atlas-spec/mpc-engine/src/utils.rs | 3 +- 3 files changed, 69 insertions(+), 19 deletions(-) diff --git a/atlas-spec/mpc-engine/src/primitives/kos.rs b/atlas-spec/mpc-engine/src/primitives/kos.rs index 8517b2f..5fe5c58 100644 --- a/atlas-spec/mpc-engine/src/primitives/kos.rs +++ b/atlas-spec/mpc-engine/src/primitives/kos.rs @@ -19,6 +19,9 @@ use super::{ mac::{xor_mac_width, Mac, MAC_LENGTH}, }; +// For light testing +const BASE_OT_LEN: usize = 128; + #[derive(Debug)] /// An Error in the KOS OT extension pub enum Error { @@ -48,7 +51,7 @@ fn CRF(sid: &[u8], input: &Mac, tweak: usize) -> Mac { let prk = hkdf_extract(b"", &ikm); let result = hkdf_expand(&prk, sid, MAC_LENGTH) .try_into() - .expect("should have received exactly `MAC_LENGHT` bytes"); + .expect("should have received exactly `MAC_LENGTH` bytes"); result } @@ -64,12 +67,11 @@ fn PRG(sid: &[u8], k: &[u8], len: usize) -> Vec { fn FRO2(sid: &[u8], matrix: &[Vec; 128]) -> Vec { let mut ikm = sid.to_vec(); let out_len = matrix[0].len(); - debug_assert_eq!(out_len % 16, 0); for column in matrix { ikm.extend_from_slice(column) } let prk = hkdf_extract(b"", &ikm); - let result_bytes = hkdf_expand(&prk, sid, out_len); + let result_bytes = hkdf_expand(&prk, sid, out_len * 8 * 16); let result = result_bytes .chunks_exact(16) .map(|chunk| { @@ -87,7 +89,7 @@ fn FRO2(sid: &[u8], matrix: &[Vec; 128]) -> Vec { fn challenge_selection(challenge: &[u128], selection_matrix: &[Vec; 128]) -> u128 { let mut result = 0u128; for i in 0..challenge.len() { - result ^= challenge[i] * packed_row(selection_matrix, i); + result = result.wrapping_add(challenge[i].wrapping_mul(packed_row(selection_matrix, i))) } result } @@ -112,7 +114,7 @@ fn kos_dst(sender_id: usize, receiver_id: usize) -> Vec { /// The message sent by the KOS15 Receiver in phase I of the protocol. #[derive(Debug)] pub struct KOSReceiverPhaseI { - base_ot_transfer: SenderTransfer<128>, + base_ot_transfer: SenderTransfer, D: [Vec; 128], u: u128, v: u128, @@ -121,20 +123,24 @@ pub struct KOSReceiverPhaseI { /// The KOS Receiver state. pub struct KOSReceiver { selection_bits: Vec, - base_sender: BaseOTSender<128>, + base_sender: BaseOTSender, M_columns: [Vec; 128], sid: Vec, } impl KOSReceiver { + /// `selection.len` must be a multiple of 8 pub(crate) fn phase_i( selection: &[bool], sender_phase_i: KOSSenderPhaseI, sid: &[u8], entropy: &mut Randomness, ) -> (Self, KOSReceiverPhaseI) { - let (base_sender, base_sender_transfer) = - kos_base::BaseOTSender::<128>::transfer(entropy, &sid, sender_phase_i.base_ot_choice); + let (base_sender, base_sender_transfer) = kos_base::BaseOTSender::::transfer( + entropy, + &sid, + sender_phase_i.base_ot_choice, + ); let tau = entropy.bytes(128 / 8).unwrap(); let mut r_prime = crate::utils::pack_bits(selection); @@ -148,6 +154,9 @@ impl KOSReceiver { crate::utils::xor_slices(&temp_result, &R_columns[i]) }); + debug_assert_eq!(M_columns[0].len(), R_columns[0].len()); + debug_assert_eq!(D_columns[0].len(), R_columns[0].len()); + let Chi = FRO2(&sid, &D_columns); let u = challenge_selection(&Chi, &M_columns); @@ -191,14 +200,14 @@ impl KOSReceiver { } pub(crate) struct KOSSender { - base_receiver: BaseOTReceiver<128>, + base_receiver: BaseOTReceiver, sid: Vec, } /// The message sent by the KOS15 Sender in phase I of the protocol. #[derive(Debug)] pub struct KOSSenderPhaseI { - base_ot_choice: ReceiverChoose<128>, + base_ot_choice: ReceiverChoose, } /// The message sent by the KOS15 Sender in phase II of the protocol. @@ -211,7 +220,7 @@ pub struct KOSSenderPhaseII { impl KOSSender { pub(crate) fn phase_i(sid: &[u8], entropy: &mut Randomness) -> (Self, KOSSenderPhaseI) { let (base_receiver, base_ot_choice) = - crate::primitives::kos_base::BaseOTReceiver::<128>::choose(entropy, &sid); + crate::primitives::kos_base::BaseOTReceiver::::choose(entropy, &sid); ( Self { @@ -222,13 +231,14 @@ impl KOSSender { ) } fn check_uvw(u: u128, v: u128, w: u128, s: u128) -> Result<(), Error> { - if w == u ^ (s * v) { + if w == u.wrapping_add(s.wrapping_mul(v)) { Ok(()) } else { Err(Error::Consistency) } } + /// `inputs.len()` must be a multiple of 8. fn phase_ii( &mut self, inputs: &[(Mac, Mac)], @@ -360,3 +370,48 @@ pub(crate) fn kos_send( )) } } + +#[test] +fn kos_simple() { + // pre-requisites + use rand::{thread_rng, RngCore}; + let sid = b"test"; + let mut rng = thread_rng(); + let mut entropy = [0u8; 100000]; + rng.fill_bytes(&mut entropy); + let mut entropy = Randomness::new(entropy.to_vec()); + + let selection = [true, false, true, false, true, false, true, false]; + let inputs = [ + ([0u8; 16], [1u8; 16]), + ([0u8; 16], [1u8; 16]), + ([0u8; 16], [1u8; 16]), + ([0u8; 16], [1u8; 16]), + ([0u8; 16], [1u8; 16]), + ([0u8; 16], [1u8; 16]), + ([0u8; 16], [1u8; 16]), + ([0u8; 16], [1u8; 16]), + ]; + + let (mut sender, sender_phase_i) = KOSSender::phase_i(sid, &mut entropy); + eprintln!("Sender Phase I"); + + let (receiver, receiver_phase_i) = + KOSReceiver::phase_i(&selection, sender_phase_i, sid, &mut entropy); + eprintln!("Receiver Phase I"); + + let sender_phase_ii = sender.phase_ii(&inputs, receiver_phase_i).unwrap(); + eprintln!("Sender Phase II"); + + let receiver_outputs = receiver.phase_ii(sender_phase_ii).unwrap(); + eprintln!("Receiver Phase II"); + + assert_eq!(receiver_outputs[0], [1u8; 16]); + assert_eq!(receiver_outputs[1], [0u8; 16]); + assert_eq!(receiver_outputs[2], [1u8; 16]); + assert_eq!(receiver_outputs[3], [0u8; 16]); + assert_eq!(receiver_outputs[4], [1u8; 16]); + assert_eq!(receiver_outputs[5], [0u8; 16]); + assert_eq!(receiver_outputs[6], [1u8; 16]); + assert_eq!(receiver_outputs[7], [0u8; 16]); +} diff --git a/atlas-spec/mpc-engine/src/primitives/kos_base.rs b/atlas-spec/mpc-engine/src/primitives/kos_base.rs index abe0587..60e5bcf 100644 --- a/atlas-spec/mpc-engine/src/primitives/kos_base.rs +++ b/atlas-spec/mpc-engine/src/primitives/kos_base.rs @@ -213,11 +213,7 @@ impl BaseOTSender { } } - fn parameters( - entropy: &mut Randomness, - sid: &[u8], - seed: &BaseOTSeed, - ) -> (Self, P256Point) { + fn parameters(entropy: &mut Randomness, sid: &[u8], seed: &BaseOTSeed) -> (Self, P256Point) { let T = FRO1(seed, sid); let r = random_scalar(entropy, sid).unwrap(); let negTr = p256::p256_point_mul(r, T).unwrap().neg(); diff --git a/atlas-spec/mpc-engine/src/utils.rs b/atlas-spec/mpc-engine/src/utils.rs index 8988101..d3bc936 100644 --- a/atlas-spec/mpc-engine/src/utils.rs +++ b/atlas-spec/mpc-engine/src/utils.rs @@ -45,13 +45,12 @@ pub(crate) fn ith_bit(i: usize, bytes: &[u8]) -> bool { /// /// We assume that `bits.len()` is a multiple of 8. pub(crate) fn pack_bits(bits: &[bool]) -> Vec { - let mut result = Vec::new(); let full_blocks = bits.len() / 8; let remainder = bits.len() % 8; debug_assert_eq!(remainder, 0); - + for i in 0..full_blocks { let mut current_byte = 0u8; for bit in 0..8 { From e4232d48aa186ca1e85a40970566a0b257ae07b3 Mon Sep 17 00:00:00 2001 From: Jonas Schneider-Bensch Date: Wed, 4 Sep 2024 18:45:56 +0200 Subject: [PATCH 11/14] Utility function tests --- atlas-spec/mpc-engine/src/utils.rs | 36 ++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/atlas-spec/mpc-engine/src/utils.rs b/atlas-spec/mpc-engine/src/utils.rs index d3bc936..46f3b58 100644 --- a/atlas-spec/mpc-engine/src/utils.rs +++ b/atlas-spec/mpc-engine/src/utils.rs @@ -70,3 +70,39 @@ pub(crate) fn xor_slices(left: &[u8], right: &[u8]) -> Vec { } result } + +#[test] +fn bit_packing() { + let bits1 = [false, false, false, false, false, false, false, true]; + let bits255 = [true, true, true, true, true, true, true, true]; + let bits1255 = [ + false, false, false, false, false, false, false, true, true, true, true, true, true, true, + true, true, + ]; + let bits2551 = [ + true, true, true, true, true, true, true, true, false, false, false, false, false, false, + false, true, + ]; + assert_eq!(pack_bits(&bits1), vec![1]); + assert_eq!(pack_bits(&bits255), vec![255]); + assert_eq!(pack_bits(&bits1255), vec![1, 255]); + assert_eq!(pack_bits(&bits2551), vec![255, 1]); +} + +#[test] +fn select_bits() { + assert_eq!(ith_bit(0, &[255, 1]), true); + assert_eq!(ith_bit(1, &[255, 1]), true); + assert_eq!(ith_bit(15, &[255, 1]), true); + assert_eq!(ith_bit(14, &[255, 1]), false); + assert_eq!(ith_bit(14, &[1, 1, 1, 1]), false); + assert_eq!(ith_bit(16, &[1, 1, 1, 1]), false); + assert_eq!(ith_bit(7, &[1, 1, 1, 1]), true); + assert_eq!(ith_bit(15, &[1, 1, 1, 1]), true); + assert_eq!(ith_bit(23, &[1, 1, 1, 1]), true); + assert_eq!(ith_bit(31, &[1, 1, 1, 1]), true); + assert_eq!(ith_bit(8, &[1, 255, 1, 1]), true); + assert_eq!(ith_bit(10, &[1, 255, 1, 1]), true); + assert_eq!(ith_bit(12, &[1, 255, 1, 1]), true); + assert_eq!(ith_bit(14, &[1, 255, 1, 1]), true); +} From 291783fd622f0d2e3cbb3abc6019c88277bb36fd Mon Sep 17 00:00:00 2001 From: Jonas Schneider-Bensch Date: Thu, 5 Sep 2024 16:35:06 +0200 Subject: [PATCH 12/14] Fix KOS OT extension --- atlas-spec/mpc-engine/src/primitives/kos.rs | 267 ++++++++++-------- .../mpc-engine/src/primitives/kos_base.rs | 41 +-- 2 files changed, 173 insertions(+), 135 deletions(-) diff --git a/atlas-spec/mpc-engine/src/primitives/kos.rs b/atlas-spec/mpc-engine/src/primitives/kos.rs index 5fe5c58..07b1001 100644 --- a/atlas-spec/mpc-engine/src/primitives/kos.rs +++ b/atlas-spec/mpc-engine/src/primitives/kos.rs @@ -10,8 +10,8 @@ use hmac::{hkdf_expand, hkdf_extract}; use crate::{ messages::SubMessage, - primitives::kos_base, - utils::{ith_bit, pack_bits}, + primitives::{kos_base, mac::zero_mac}, + utils::{ith_bit, pack_bits, xor_slices}, }; use super::{ @@ -19,7 +19,6 @@ use super::{ mac::{xor_mac_width, Mac, MAC_LENGTH}, }; -// For light testing const BASE_OT_LEN: usize = 128; #[derive(Debug)] @@ -86,20 +85,23 @@ fn FRO2(sid: &[u8], matrix: &[Vec; 128]) -> Vec { } /// This implements Xor_{j in [m+k]} (Chi_j * M_j). +/// `selection_matrix` is the whole matrix given as a vector of columns. fn challenge_selection(challenge: &[u128], selection_matrix: &[Vec; 128]) -> u128 { let mut result = 0u128; for i in 0..challenge.len() { - result = result.wrapping_add(challenge[i].wrapping_mul(packed_row(selection_matrix, i))) + result ^= challenge[i] & packed_row(selection_matrix, i); } result } -fn packed_row(matrix: &[Vec; 128], index: usize) -> u128 { +/// Pack all the bits in a row into a `u128`. +/// `matrix` is the whole matrix given as a vector of columns. +fn packed_row(matrix: &[Vec; 128], row_index: usize) -> u128 { let mut result = 0u128; - for i in 0..128 { - let b = ith_bit(index, &matrix[i]); + for column in 0..128 { + let b = ith_bit(row_index, &matrix[column]); if b { - result += 1 << (i as u128); + result |= 1 << (127 - column); } } result @@ -126,6 +128,7 @@ pub struct KOSReceiver { base_sender: BaseOTSender, M_columns: [Vec; 128], sid: Vec, + requested_len: usize, } impl KOSReceiver { @@ -135,58 +138,66 @@ impl KOSReceiver { sender_phase_i: KOSSenderPhaseI, sid: &[u8], entropy: &mut Randomness, - ) -> (Self, KOSReceiverPhaseI) { + ) -> Result<(Self, KOSReceiverPhaseI), Error> { + let requested_len = selection.len(); + // Extend selection lenght to next multiple of 8. + let mut selection_padded = vec![false; padded_len(selection.len())]; + selection_padded[0..selection.len()].copy_from_slice(&selection); + let selection = selection_padded.as_slice(); let (base_sender, base_sender_transfer) = kos_base::BaseOTSender::::transfer( entropy, &sid, sender_phase_i.base_ot_choice, ); - - let tau = entropy.bytes(128 / 8).unwrap(); - let mut r_prime = crate::utils::pack_bits(selection); - r_prime.extend_from_slice(&tau); - let M_columns: [Vec; 128] = - std::array::from_fn(|i| PRG(&sid, &base_sender.inputs[i].0, 16 + selection.len() / 8)); - let R_columns: [Vec; 128] = std::array::from_fn(|_i| r_prime.clone()); - let D_columns: [Vec; 128] = std::array::from_fn(|i| { - let prg_result = PRG(&sid, &base_sender.inputs[i].1, 16 + selection.len() / 8); - let temp_result = crate::utils::xor_slices(&M_columns[i], &prg_result); - crate::utils::xor_slices(&temp_result, &R_columns[i]) - }); - - debug_assert_eq!(M_columns[0].len(), R_columns[0].len()); - debug_assert_eq!(D_columns[0].len(), R_columns[0].len()); - - let Chi = FRO2(&sid, &D_columns); - - let u = challenge_selection(&Chi, &M_columns); - let v = challenge_selection(&Chi, &R_columns); - - ( - Self { - selection_bits: selection.to_owned(), - base_sender, - M_columns, - sid: sid.to_owned(), - }, - KOSReceiverPhaseI { - base_ot_transfer: base_sender_transfer, - D: D_columns, - u, - v, - }, - ) + match base_sender.inputs { + Some(base_sender_inputs) => { + let tau = entropy.bytes(128 / 8).unwrap(); + let mut r_prime = crate::utils::pack_bits(selection); + r_prime.extend_from_slice(&tau); + + let M_columns: [Vec; 128] = std::array::from_fn(|i| { + PRG(&sid, &base_sender_inputs[i].0, 16 + selection.len() / 8) + }); + + let R_columns: [Vec; 128] = std::array::from_fn(|_i| r_prime.clone()); + let D_columns: [Vec; 128] = std::array::from_fn(|i| { + let prg_result = PRG(&sid, &base_sender_inputs[i].1, 16 + selection.len() / 8); + let temp_result = crate::utils::xor_slices(&M_columns[i], &prg_result); + crate::utils::xor_slices(&temp_result, &R_columns[i]) + }); + + let Chi = FRO2(&sid, &D_columns); + + let u = challenge_selection(&Chi, &M_columns); + let v = challenge_selection(&Chi, &R_columns); + + Ok(( + Self { + selection_bits: selection.to_owned(), + base_sender, + M_columns, + sid: sid.to_owned(), + requested_len, + }, + KOSReceiverPhaseI { + base_ot_transfer: base_sender_transfer, + D: D_columns, + u, + v, + }, + )) + } + None => Err(Error::BaseOTError), + } } fn phase_ii(self, sender_phase_ii: KOSSenderPhaseII) -> Result, Error> { let mut results = Vec::new(); self.base_sender.verify(sender_phase_ii.base_ot_response)?; for (index, selection_bit) in self.selection_bits.iter().enumerate() { - let crf = CRF( - &self.sid, - &packed_row(&self.M_columns, index).to_be_bytes(), - index, - ); + let crf_input = packed_row(&self.M_columns, index).to_be_bytes(); + + let crf = CRF(&self.sid, &crf_input, index); let y = if *selection_bit { sender_phase_ii.ys[index].1 } else { @@ -195,6 +206,7 @@ impl KOSReceiver { let a = xor_mac_width(&y, &crf); results.push(a) } + results.truncate(self.requested_len); Ok(results) } } @@ -217,6 +229,14 @@ pub struct KOSSenderPhaseII { base_ot_response: ReceiverResponse, } +fn padded_len(len: usize) -> usize { + if len % 8 == 0 { + len + } else { + len + 8 - len % 8 + } +} + impl KOSSender { pub(crate) fn phase_i(sid: &[u8], entropy: &mut Randomness) -> (Self, KOSSenderPhaseI) { let (base_receiver, base_ot_choice) = @@ -230,8 +250,9 @@ impl KOSSender { KOSSenderPhaseI { base_ot_choice }, ) } + fn check_uvw(u: u128, v: u128, w: u128, s: u128) -> Result<(), Error> { - if w == u.wrapping_add(s.wrapping_mul(v)) { + if w == u ^ (s & v) { Ok(()) } else { Err(Error::Consistency) @@ -244,54 +265,65 @@ impl KOSSender { inputs: &[(Mac, Mac)], receiver_phase_i: KOSReceiverPhaseI, ) -> Result { + let mut inputs_padded = vec![(zero_mac(), zero_mac()); padded_len(inputs.len())]; + inputs_padded[0..inputs.len()].copy_from_slice(inputs); + let inputs = inputs_padded.as_slice(); + let (base_receiver_output, base_ot_response) = self .base_receiver - .response(receiver_phase_i.base_ot_transfer)?; + .response(receiver_phase_i.base_ot_transfer) + .unwrap(); + + + match self.base_receiver.selection_bits { + Some(base_selection_bits) => { + let Q_columns: [Vec; 128] = std::array::from_fn(|i| { + let mut result = + PRG(&self.sid, &base_receiver_output[i], 16 + inputs.len() / 8); + // the following is obviously secret-dependent timing + if base_selection_bits[i] { + result = crate::utils::xor_slices(&result, &receiver_phase_i.D[i]); + } + + result + }); + + let Chi = FRO2(&self.sid, &receiver_phase_i.D); + + let w = challenge_selection(&Chi, &Q_columns); + + let s = pack_bits(&base_selection_bits); + + let mut s_array = [0u8; 16]; + s_array.copy_from_slice(&s[..16]); - let Q_columns: [Vec; 128] = std::array::from_fn(|i| { - let mut result = PRG(&self.sid, &base_receiver_output[i], 16 + inputs.len() / 8); - // the following is obviously secret-dependent timing - if self.base_receiver.selection_bits[i] { - result = crate::utils::xor_slices(&result, &receiver_phase_i.D[i]); + Self::check_uvw( + receiver_phase_i.u, + receiver_phase_i.v, + w, + u128::from_be_bytes(s_array), + )?; + + let mut ys = Vec::new(); + for (index, (a_0, a_1)) in inputs.iter().enumerate() { + let crf_input_0 = packed_row(&Q_columns, index).to_be_bytes(); + let crf_input_1 = xor_slices(&packed_row(&Q_columns, index).to_be_bytes(), &s); + + let crf_0 = CRF(&self.sid, &crf_input_0, index); + let crf_1 = CRF(&self.sid, &crf_input_1.try_into().unwrap(), index); + + let y_0 = xor_mac_width(a_0, &crf_0); + let y_1 = xor_mac_width(a_1, &crf_1); + ys.push((y_0, y_1)) + } + + Ok(KOSSenderPhaseII { + ys, + base_ot_response, + }) } - result - }); - - let Chi = FRO2(&self.sid, &receiver_phase_i.D); - - let w = challenge_selection(&Chi, &Q_columns); - let s = pack_bits(&self.base_receiver.selection_bits); - let mut s_array = [0u8; 16]; - s_array.copy_from_slice(&s[..16]); - - Self::check_uvw( - receiver_phase_i.u, - receiver_phase_i.v, - w, - u128::from_be_bytes(s_array), - )?; - - let mut ys = Vec::new(); - for (index, (a_0, a_1)) in inputs.iter().enumerate() { - let crf_0 = CRF( - &self.sid, - &packed_row(&Q_columns, index).to_be_bytes(), - index, - ); - let crf_1 = CRF( - &self.sid, - &xor_mac_width(&packed_row(&Q_columns, index).to_be_bytes(), &s_array), - index, - ); - let y_0 = xor_mac_width(a_0, &crf_0); - let y_1 = xor_mac_width(a_1, &crf_1); - ys.push((y_0, y_1)) + None => Err(Error::BaseOTError), } - - Ok(KOSSenderPhaseII { - ys, - base_ot_response, - }) } } @@ -309,19 +341,19 @@ pub(crate) fn kos_receive( sender_id: usize, entropy: &mut Randomness, ) -> Result, crate::Error> { - let sid = kos_dst(sender_id, receiver_id); + let sid = kos_dst(receiver_id, sender_id); let sender_phase_i_msg = my_inbox.recv().unwrap(); if let SubMessage::KOSSenderPhaseI(sender_phase_i) = sender_phase_i_msg { - let (receiver, phase_i) = KOSReceiver::phase_i(selection, sender_phase_i, &sid, entropy); + let (receiver, phase_i) = + KOSReceiver::phase_i(selection, sender_phase_i, &sid, entropy).unwrap(); sender_address .send(SubMessage::KOSReceiverPhaseI(phase_i)) .unwrap(); let sender_phase_ii_msg = my_inbox.recv().unwrap(); if let SubMessage::KOSSenderPhaseII(sender_phase_ii) = sender_phase_ii_msg { - let outputs = receiver - .phase_ii(sender_phase_ii) - .map_err(|_| crate::Error::SubprotocolError)?; + let outputs = receiver.phase_ii(sender_phase_ii).unwrap(); + Ok(outputs) } else { Err(crate::Error::UnexpectedSubprotocolMessage( @@ -357,9 +389,7 @@ pub(crate) fn kos_send( .unwrap(); let receiver_phase_i_message = my_inbox.recv().unwrap(); if let SubMessage::KOSReceiverPhaseI(receiver_phase_i) = receiver_phase_i_message { - let phase_ii = kos_sender - .phase_ii(inputs, receiver_phase_i) - .map_err(|_| crate::Error::SubprotocolError)?; + let phase_ii = kos_sender.phase_ii(inputs, receiver_phase_i).unwrap(); receiver_address .send(SubMessage::KOSSenderPhaseII(phase_ii)) .unwrap(); @@ -381,37 +411,36 @@ fn kos_simple() { rng.fill_bytes(&mut entropy); let mut entropy = Randomness::new(entropy.to_vec()); - let selection = [true, false, true, false, true, false, true, false]; + let selection = [true, false, true, false, true, false, true]; let inputs = [ - ([0u8; 16], [1u8; 16]), - ([0u8; 16], [1u8; 16]), - ([0u8; 16], [1u8; 16]), - ([0u8; 16], [1u8; 16]), - ([0u8; 16], [1u8; 16]), - ([0u8; 16], [1u8; 16]), - ([0u8; 16], [1u8; 16]), - ([0u8; 16], [1u8; 16]), + ([2u8; 16], [1u8; 16]), + ([2u8; 16], [1u8; 16]), + ([2u8; 16], [1u8; 16]), + ([2u8; 16], [1u8; 16]), + ([2u8; 16], [1u8; 16]), + ([2u8; 16], [1u8; 16]), + ([2u8; 16], [1u8; 16]), ]; let (mut sender, sender_phase_i) = KOSSender::phase_i(sid, &mut entropy); - eprintln!("Sender Phase I"); + eprintln!("Sender Phase I complete"); let (receiver, receiver_phase_i) = - KOSReceiver::phase_i(&selection, sender_phase_i, sid, &mut entropy); - eprintln!("Receiver Phase I"); + KOSReceiver::phase_i(&selection, sender_phase_i, sid, &mut entropy).unwrap(); + eprintln!("Receiver Phase I complete"); let sender_phase_ii = sender.phase_ii(&inputs, receiver_phase_i).unwrap(); - eprintln!("Sender Phase II"); + eprintln!("Sender Phase II complete"); let receiver_outputs = receiver.phase_ii(sender_phase_ii).unwrap(); - eprintln!("Receiver Phase II"); + eprintln!("Receiver Phase II complete"); assert_eq!(receiver_outputs[0], [1u8; 16]); - assert_eq!(receiver_outputs[1], [0u8; 16]); + assert_eq!(receiver_outputs[1], [2u8; 16]); assert_eq!(receiver_outputs[2], [1u8; 16]); - assert_eq!(receiver_outputs[3], [0u8; 16]); + assert_eq!(receiver_outputs[3], [2u8; 16]); assert_eq!(receiver_outputs[4], [1u8; 16]); - assert_eq!(receiver_outputs[5], [0u8; 16]); + assert_eq!(receiver_outputs[5], [2u8; 16]); assert_eq!(receiver_outputs[6], [1u8; 16]); - assert_eq!(receiver_outputs[7], [0u8; 16]); + } diff --git a/atlas-spec/mpc-engine/src/primitives/kos_base.rs b/atlas-spec/mpc-engine/src/primitives/kos_base.rs index 60e5bcf..2180ac2 100644 --- a/atlas-spec/mpc-engine/src/primitives/kos_base.rs +++ b/atlas-spec/mpc-engine/src/primitives/kos_base.rs @@ -65,14 +65,14 @@ fn FRO4( pub(crate) struct BaseOTReceiver { sid: Vec, T: P256Point, - pub selection_bits: [bool; L], + pub selection_bits: Option<[bool; L]>, alphas: [P256Scalar; L], } pub(crate) struct BaseOTSender { sid: Vec, r: P256Scalar, - pub inputs: [([u8; 16], [u8; 16]); L], + pub inputs: Option<[([u8; 16], [u8; 16]); L]>, expected_answer: [u8; 16], negTr: P256Point, chall_hashes: [[u8; COMPUTATIONAL_SECURITY]; L], @@ -100,7 +100,7 @@ impl BaseOTReceiver { pub(crate) fn choose(entropy: &mut Randomness, sid: &[u8]) -> (Self, ReceiverChoose) { let (mut receiver, seed) = Self::parameters(entropy, sid); let (bits, messages) = receiver.messages(entropy); - receiver.selection_bits = bits; + receiver.selection_bits = Some(bits); (receiver, ReceiverChoose { seed, messages }) } @@ -109,9 +109,15 @@ impl BaseOTReceiver { transfer: SenderTransfer, ) -> Result<([[u8; 16]; L], ReceiverResponse), Error> { let messages = self.decrypt(transfer.seed); - let response = self.responses(&self.selection_bits, &messages, &transfer.challenge); - self.challenge_verification(&response, &transfer.gamma)?; - Ok((messages, ReceiverResponse { response })) + + match &self.selection_bits { + Some(selection_bits) => { + let response = self.responses(selection_bits, &messages, &transfer.challenge); + self.challenge_verification(&response, &transfer.gamma)?; + Ok((messages, ReceiverResponse { response })) + } + None => Err(Error::ReceiverAbort), + } } fn parameters(entropy: &mut Randomness, sid: &[u8]) -> (Self, BaseOTSeed) { @@ -125,7 +131,7 @@ impl BaseOTReceiver { Self { sid: sid.to_owned(), T, - selection_bits: [false; L], + selection_bits: None, alphas, }, seed_array, @@ -177,6 +183,7 @@ impl BaseOTReceiver { ) -> Result<(), Error> { let gamma_prime = FRO3(Ans, &self.sid); if gamma_prime != *gamma { + eprintln!("challenge verification failed"); return Err(Error::ReceiverAbort); } Ok(()) @@ -192,7 +199,7 @@ impl BaseOTSender { let (mut sender, seed) = Self::parameters(entropy, sid, &choice.seed); let inputs = sender.generate_inputs(choice.messages); let challenge = sender.challenges(&inputs); - sender.inputs = inputs; + sender.inputs = Some(inputs); let (expected_answer, gamma) = sender.proof(); sender.expected_answer = expected_answer; ( @@ -225,7 +232,7 @@ impl BaseOTSender { chall_hashes, r, negTr, - inputs: [([0u8; 16], [0u8; 16]); L], + inputs: None, expected_answer: [0u8; 16], }, z, @@ -278,7 +285,7 @@ fn xor_arrays(a: &[u8; L], b: &[u8; L]) -> [u8; L] { } #[test] -fn simple() { +fn kos_base_simple() { // pre-requisites use rand::{thread_rng, RngCore}; let sid = b"test"; @@ -287,22 +294,24 @@ fn simple() { rng.fill_bytes(&mut entropy); let mut entropy = Randomness::new(entropy.to_vec()); - let (mut receiver, choice_message) = BaseOTReceiver::<5>::choose(&mut entropy, sid); + let (receiver, choice_message) = BaseOTReceiver::<5>::choose(&mut entropy, sid); - let (mut sender, transfer_message) = - BaseOTSender::<5>::transfer(&mut entropy, sid, choice_message); + let (sender, transfer_message) = BaseOTSender::<5>::transfer(&mut entropy, sid, choice_message); let (receiver_outputs, response) = receiver.response(transfer_message).unwrap(); sender.verify(response).unwrap(); - for (i, selection_bit) in receiver.selection_bits.iter().enumerate() { + let selection_bits = receiver.selection_bits.unwrap(); + + for (i, selection_bit) in selection_bits.iter().enumerate() { + eprintln! {"{i}:\n\tInput 0: {:?}\n\tInput 1: {:?}\n\tSelection bit: {:?}\n\tOutput: {:?}", sender.inputs.unwrap()[i].0, sender.inputs.unwrap()[i].1, selection_bit, receiver_outputs[i]}; assert_eq!( receiver_outputs[i], if *selection_bit { - sender.inputs[i].1 + sender.inputs.unwrap()[i].1 } else { - sender.inputs[i].0 + sender.inputs.unwrap()[i].0 } ) } From 4dfb8390ed0ee5c7d25ebaf4afb856309d1bf322 Mon Sep 17 00:00:00 2001 From: Jonas Schneider-Bensch Date: Thu, 5 Sep 2024 16:46:06 +0200 Subject: [PATCH 13/14] Format --- atlas-spec/mpc-engine/src/primitives/kos.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/atlas-spec/mpc-engine/src/primitives/kos.rs b/atlas-spec/mpc-engine/src/primitives/kos.rs index 07b1001..39c6f8e 100644 --- a/atlas-spec/mpc-engine/src/primitives/kos.rs +++ b/atlas-spec/mpc-engine/src/primitives/kos.rs @@ -159,7 +159,7 @@ impl KOSReceiver { PRG(&sid, &base_sender_inputs[i].0, 16 + selection.len() / 8) }); - let R_columns: [Vec; 128] = std::array::from_fn(|_i| r_prime.clone()); + let R_columns: [Vec; 128] = std::array::from_fn(|_i| r_prime.clone()); let D_columns: [Vec; 128] = std::array::from_fn(|i| { let prg_result = PRG(&sid, &base_sender_inputs[i].1, 16 + selection.len() / 8); let temp_result = crate::utils::xor_slices(&M_columns[i], &prg_result); @@ -268,13 +268,12 @@ impl KOSSender { let mut inputs_padded = vec![(zero_mac(), zero_mac()); padded_len(inputs.len())]; inputs_padded[0..inputs.len()].copy_from_slice(inputs); let inputs = inputs_padded.as_slice(); - + let (base_receiver_output, base_ot_response) = self .base_receiver .response(receiver_phase_i.base_ot_transfer) .unwrap(); - match self.base_receiver.selection_bits { Some(base_selection_bits) => { let Q_columns: [Vec; 128] = std::array::from_fn(|i| { @@ -442,5 +441,4 @@ fn kos_simple() { assert_eq!(receiver_outputs[4], [1u8; 16]); assert_eq!(receiver_outputs[5], [2u8; 16]); assert_eq!(receiver_outputs[6], [1u8; 16]); - } From 3527ed700c22793024a9fc5b7031d4d1d6e91837 Mon Sep 17 00:00:00 2001 From: Jonas Schneider-Bensch Date: Thu, 5 Sep 2024 16:53:32 +0200 Subject: [PATCH 14/14] Remove incorrect import --- atlas-spec/mpc-engine/src/circuit.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/atlas-spec/mpc-engine/src/circuit.rs b/atlas-spec/mpc-engine/src/circuit.rs index 86f5e71..c0c4fed 100644 --- a/atlas-spec/mpc-engine/src/circuit.rs +++ b/atlas-spec/mpc-engine/src/circuit.rs @@ -350,8 +350,6 @@ impl Circuit { #[cfg(test)] mod tests { - use std::arch::x86_64::{_CMP_FALSE_OQ, _CMP_TRUE_UQ}; - use crate::utils::ith_bit; use super::*;