diff --git a/.github/workflows/hax.yml b/.github/workflows/hax.yml index 0ce89b4..d2a0197 100644 --- a/.github/workflows/hax.yml +++ b/.github/workflows/hax.yml @@ -14,7 +14,7 @@ jobs: uses: hacspec/hax-actions@main with: fstar: v2025.02.17 - hax_reference: hax-lib-v0.3.5 + hax_reference: tls-codec-panic-freedom - run: sudo apt-get install protobuf-compiler diff --git a/Cargo.lock b/Cargo.lock index 11b1a8d..798c5d8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -65,7 +65,7 @@ version = "0.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94950e87ea550d6d68f1993f3e7bebc8cb7235157bff84337d46195c3aa0b3f0" dependencies = [ - "hax-lib", + "hax-lib 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", "pastey", "rand 0.9.1", ] @@ -232,7 +232,17 @@ version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "74d9ba66d1739c68e0219b2b2238b5c4145f491ebf181b9c6ab561a19352ae86" dependencies = [ - "hax-lib-macros", + "hax-lib-macros 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", + "num-bigint", + "num-traits", +] + +[[package]] +name = "hax-lib" +version = "0.3.5" +source = "git+https://github.com/cryspen/hax?branch=tls-codec-panic-freedom#19fe3e7b7293d1f67ab887290adb4088ce0857b3" +dependencies = [ + "hax-lib-macros 0.3.5 (git+https://github.com/cryspen/hax?branch=tls-codec-panic-freedom)", "num-bigint", "num-traits", ] @@ -243,7 +253,19 @@ version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24ba777a231a58d1bce1d68313fa6b6afcc7966adef23d60f45b8a2b9b688bf1" dependencies = [ - "hax-lib-macros-types", + "hax-lib-macros-types 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", + "proc-macro-error2", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "hax-lib-macros" +version = "0.3.5" +source = "git+https://github.com/cryspen/hax?branch=tls-codec-panic-freedom#19fe3e7b7293d1f67ab887290adb4088ce0857b3" +dependencies = [ + "hax-lib-macros-types 0.3.5 (git+https://github.com/cryspen/hax?branch=tls-codec-panic-freedom)", "proc-macro-error2", "proc-macro2", "quote", @@ -263,6 +285,18 @@ dependencies = [ "uuid", ] +[[package]] +name = "hax-lib-macros-types" +version = "0.3.5" +source = "git+https://github.com/cryspen/hax?branch=tls-codec-panic-freedom#19fe3e7b7293d1f67ab887290adb4088ce0857b3" +dependencies = [ + "proc-macro2", + "quote", + "serde", + "serde_json", + "uuid", +] + [[package]] name = "heck" version = "0.5.0" @@ -355,7 +389,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d3b41dcbc21a5fb7efbbb5af7405b2e79c4bfe443924e90b13afc0080318d31" dependencies = [ "core-models", - "hax-lib", + "hax-lib 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -374,7 +408,7 @@ version = "0.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d368d3e8d6a74e277178d54921eca112a1e6b7837d7d8bc555091acb5d817f5" dependencies = [ - "hax-lib", + "hax-lib 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", "libcrux-intrinsics", "libcrux-platform", "libcrux-secrets", @@ -396,7 +430,7 @@ version = "0.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "332737e629fe6ba7547f5c0f90559eac865d5dbecf98138ffae8f16ab8cbe33f" dependencies = [ - "hax-lib", + "hax-lib 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -416,7 +450,7 @@ version = "0.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "29d95de4257eafdfaf3bffecadb615219b0ca920c553722b3646d32dde76c797" dependencies = [ - "hax-lib", + "hax-lib 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", "libcrux-intrinsics", "libcrux-platform", ] @@ -875,7 +909,7 @@ dependencies = [ "curve25519-dalek", "displaydoc", "galois_field_2pm", - "hax-lib", + "hax-lib 0.3.5 (git+https://github.com/cryspen/hax?branch=tls-codec-panic-freedom)", "hkdf", "hmac", "libcrux-hkdf", diff --git a/Cargo.toml b/Cargo.toml index 112c7d6..bf9a4d0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ rust-version = "1.83.0" [dependencies] curve25519-dalek = { version = "4.1.3", features = ["rand_core"] } displaydoc = "0.2" -hax-lib = "0.3.5" +hax-lib = {git = "https://github.com/cryspen/hax", branch = "tls-codec-panic-freedom"} hkdf = "0.12" libcrux-hkdf = "0.0.3" libcrux-hmac = "0.0.3" diff --git a/proofs/fstar/models/Libcrux_ml_kem.Ind_cca.Incremental.Types.fsti b/proofs/fstar/models/Libcrux_ml_kem.Ind_cca.Incremental.Types.fsti index 8c244e5..0b22293 100644 --- a/proofs/fstar/models/Libcrux_ml_kem.Ind_cca.Incremental.Types.fsti +++ b/proofs/fstar/models/Libcrux_ml_kem.Ind_cca.Incremental.Types.fsti @@ -137,7 +137,8 @@ val impl_4__deserialize type t_Ciphertext1 (v_LEN: usize) = { f_value:t_Array u8 v_LEN } /// The size of the ciphertext. -val impl_5__len: v_LEN: usize -> Prims.unit -> Prims.Pure usize Prims.l_True (fun _ -> Prims.l_True) +val impl_5__len: v_LEN: usize -> Prims.unit -> Prims.Pure usize Prims.l_True + (ensures fun res -> let res:usize = res in res =. v_LEN) /// The partial ciphertext c2 - second part. type t_Ciphertext2 (v_LEN: usize) = { f_value:t_Array u8 v_LEN } diff --git a/proofs/fstar/models/Libcrux_ml_kem.Mlkem768.Incremental.fsti b/proofs/fstar/models/Libcrux_ml_kem.Mlkem768.Incremental.fsti index 65121da..97051c7 100644 --- a/proofs/fstar/models/Libcrux_ml_kem.Mlkem768.Incremental.fsti +++ b/proofs/fstar/models/Libcrux_ml_kem.Mlkem768.Incremental.fsti @@ -15,7 +15,7 @@ val pk1_len: Prims.unit -> Prims.Pure usize Prims.l_True (ensures fun res -> let res:usize = res in res =. mk_usize 64) /// Get the size of the second public key in bytes. -val pk2_len: Prims.unit -> Prims.Pure usize Prims.l_True (fun _ -> Prims.l_True) +val pk2_len: Prims.unit -> Prims.Pure usize Prims.l_True (ensures fun res -> res =. mk_usize 1152) /// The size of a compressed key pair in bytes. let v_COMPRESSED_KEYPAIR_LEN: usize = Libcrux_ml_kem.Mlkem768.v_SECRET_KEY_SIZE diff --git a/src/chain.rs b/src/chain.rs index d509323..b0d8c55 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -217,6 +217,7 @@ impl ChainEpochDirection { } #[hax_lib::requires(next.len() > 0 && *ctr < u32::MAX)] + #[hax_lib::ensures(|_| *future(ctr) == ctr + 1)] fn next_key_internal(next: &mut [u8], ctr: &mut u32) -> (u32, [u8; 32]) { assert!(!next.is_empty()); *ctr += 1; @@ -258,11 +259,20 @@ impl ChainEpochDirection { // them all. self.prev.clear(); } - hax_lib::fstar!("admit ()"); // potential overflows in condition and body of the loop while at > self.ctr + 1 { + hax_lib::loop_invariant!(self.ctr < u32::MAX); + hax_lib::loop_decreases!(u32::MAX - self.ctr); + hax_lib::assume!(self.next.len() > 0); let k = Self::next_key_internal(&mut self.next, &mut self.ctr); + hax_lib::assume!( + params.max_ooo_keys_or_default() < 390451572 && self.ctr <= u32::MAX - 390451572 + ); // Only add keys into our history if we're not going to immediately GC them. if self.ctr + params.max_ooo_keys_or_default() >= at { + hax_lib::assume!( + params.trim_size() < 119304647 + && self.prev.data.len() <= KeyHistory::KEY_SIZE * params.trim_size() + ); self.prev.add(k, params); } } @@ -270,6 +280,8 @@ impl ChainEpochDirection { // want to throw away. self.prev.gc(self.ctr, params); + hax_lib::assume!(self.next.len() > 0); + Ok(Self::next_key_internal(&mut self.next, &mut self.ctr) .1 .to_vec()) diff --git a/src/encoding/gf.rs b/src/encoding/gf.rs index d4ecdfc..064e708 100644 --- a/src/encoding/gf.rs +++ b/src/encoding/gf.rs @@ -196,11 +196,15 @@ impl ops::Div<&GF16> for GF16 { } #[inline] -#[hax_lib::fstar::verification_status(lax)] // proving absence of overflow in loop condition is tricky +#[hax_lib::requires(into.len() <= usize::MAX - 2)] +#[hax_lib::ensures(|_| future(into).len() == into.len())] pub fn parallel_mult(a: GF16, into: &mut [GF16]) { - let mut i = 0; + let mut i: usize = 0; + #[cfg(hax)] + let l = into.len(); while i + 2 <= into.len() { - hax_lib::loop_decreases!(into.len() - i); + hax_lib::loop_decreases!(l - i); + hax_lib::loop_invariant!(into.len() == l && i <= l); (into[i].value, into[i + 1].value) = mul2_u16(a.value, into[i].value, into[i + 1].value); i += 2; } diff --git a/src/encoding/polynomial.rs b/src/encoding/polynomial.rs index 0989172..9e9a83e 100644 --- a/src/encoding/polynomial.rs +++ b/src/encoding/polynomial.rs @@ -71,6 +71,7 @@ pub const MAX_STORED_POLYNOMIAL_DEGREE_V1: usize = 35; pub const MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1: usize = 36; #[derive(Clone, PartialEq)] +#[hax_lib::attributes] pub(crate) struct Poly { // For Protocol V1 we interpolate at most 36 values, which produces a // degree 35 polynomial (with 36 coefficients). In an intermediate calculation @@ -78,6 +79,7 @@ pub(crate) struct Poly { // higher, thus we get the following constraint: // // coefficients.len() <= MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 + 1 + #[hax_lib::refine(coefficients.len() <= MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 + 1)] pub coefficients: Vec, } @@ -245,6 +247,7 @@ impl Poly { gf::parallel_mult(m, &mut self.coefficients); } + #[hax_lib::opaque] // zip fn compute_at(&self, x: GF16) -> GF16 { // Compute x^0 .. x^N let mut xs = Vec::with_capacity(self.coefficients.len()); @@ -265,6 +268,7 @@ impl Poly { } /// Internal function for lagrange_polynomial_from_complete_points. + #[hax_lib::opaque] // zip fn lagrange_sum(pts: &[Pt], polys: &[Poly]) -> Poly { let mut out = Poly::zero(pts.len()); for (pt, poly) in pts.iter().zip(polys.iter()) { @@ -279,6 +283,7 @@ impl Poly { /// range [0..pts.len()), return a polynomial that computes those points. #[hax_lib::requires(pts.len() == 0 || pts.len() == 1 || pts.len() == 3 || pts.len() == 5 || pts.len() == 30 || pts.len() == 34 || pts.len() == 36)] + #[hax_lib::opaque] // iterators fn from_complete_points(pts: &[Pt]) -> Result { for (i, pt) in pts.iter().enumerate() { if pt.x.value != i as u16 { @@ -314,7 +319,6 @@ impl Poly { Ok(Self::lagrange_sum(pts, &polys)) } - #[hax_lib::requires(self.coefficients.len() <= MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1)] pub fn serialize(&self) -> Vec { // For Protocol V1 the polynomials that get serialized will always have // coefficients.len() <= MAX_STORED_POLYNOMIAL_DEGREE_V1 + 1 @@ -336,6 +340,7 @@ impl Poly { for coeff in serialized.chunks_exact(2) { coefficients.push(GF16::new(u16::from_be_bytes(coeff.try_into().unwrap()))); } + hax_lib::assume!(coefficients.len() <= MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 + 1); Ok(Self { coefficients }) } } @@ -441,6 +446,7 @@ impl PolyConst { } fn to_poly(&self) -> Poly { + hax_lib::assume!(N <= MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 + 1); Poly { coefficients: self.coefficients.to_vec(), } @@ -495,11 +501,18 @@ const CHUNK_SIZE: usize = 32; // Number of polys or points that need to be tracked when using GF(2^16) with 2-byte elements pub const NUM_POLYS: usize = CHUNK_SIZE / 2; +#[derive(Clone)] +#[hax_lib::attributes] +pub struct Point { + #[hax_lib::refine(value.len() <= MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1)] + pub value: Vec, +} + #[cfg_attr(test, derive(Clone))] pub(crate) enum EncoderState { // For 32B chunks the outer vector has length 16. // Using MLKEM-768 the inner vector has length <= MAX_STORED_POLYNOMIAL_DEGREE_V1 + 1 - Points([Vec; NUM_POLYS]), + Points([Point; NUM_POLYS]), // For 32B chunks this vector has length 16. Polys([Poly; NUM_POLYS]), } @@ -517,12 +530,6 @@ impl PolyEncoder { &self.s } - #[hax_lib::requires(match self.s { - EncoderState::Points(points) => hax_lib::Prop::from(points.len() == 16).and(hax_lib::prop::forall(|pts: &Vec| - hax_lib::prop::implies(points.contains(pts), pts.len() <= MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1))), - EncoderState::Polys(polys) => hax_lib::Prop::from(polys.len() == 16).and(hax_lib::prop::forall(|poly: &Poly| - hax_lib::prop::implies(polys.contains(poly), poly.coefficients.len() <= MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1))) - })] pub fn into_pb(self) -> proto::pq_ratchet::PolynomialEncoder { let mut out = proto::pq_ratchet::PolynomialEncoder { idx: self.idx, @@ -535,7 +542,7 @@ impl PolyEncoder { #[allow(clippy::needless_range_loop)] for j in 0..points.len() { hax_lib::loop_invariant!(|j: usize| out.pts.len() == j); - let pts = &points[j]; + let pts = &points[j].value; let mut v = Vec::::with_capacity(2 * pts.len()); #[allow(clippy::needless_range_loop)] for i in 0..pts.len() { @@ -563,16 +570,12 @@ impl PolyEncoder { if pb.pts.len() != NUM_POLYS { return Err(PolynomialError::SerializationInvalid); } - let mut out = core::array::from_fn(|_| Vec::::new()); + let mut out = core::array::from_fn(|_| Point { + value: Vec::::new(), + }); #[allow(clippy::needless_range_loop)] for i in 0..NUM_POLYS { - hax_lib::loop_invariant!(|_: usize| hax_lib::prop::forall(|pts: &Vec| { - hax_lib::prop::implies( - out.contains(pts), - pts.len() <= MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1, - ) - })); let pts = &pb.pts[i]; if pts.len() % 2 != 0 { return Err(PolynomialError::SerializationInvalid); @@ -581,7 +584,8 @@ impl PolyEncoder { for pt in pts.chunks_exact(2) { v.push(GF16::new(u16::from_be_bytes(pt.try_into().unwrap()))); } - out[i] = v; + hax_lib::assume!(v.len() <= MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1); + out[i] = Point { value: v }; } EncoderState::Points(out) } else if pb.polys.len() == NUM_POLYS { @@ -597,11 +601,12 @@ impl PolyEncoder { } #[requires(poly < 16)] + #[hax_lib::opaque] // iterators fn point_at(&mut self, poly: usize, idx: usize) -> GF16 { if let EncoderState::Points(ref pts) = self.s { hax_lib::assume!(pts.len() == 16); - if idx < pts[poly].len() { - return pts[poly][idx]; + if idx < pts[poly].value.len() { + return pts[poly].value[idx]; } // If we reach here, we've come to the first point we want to // find that wasn't part of the original set of points. We @@ -611,6 +616,7 @@ impl PolyEncoder { let mut polys: [Poly; NUM_POLYS] = core::array::from_fn(|_| Poly::zero(1)); for i in 0..NUM_POLYS { let pt_vec = pts[i] + .value .iter() .enumerate() .map(|(x, y)| Pt { @@ -648,12 +654,16 @@ impl PolyEncoder { } else if msg.len() > (1 << 16) * NUM_POLYS { return Err(PolynomialError::MessageLengthTooLong.into()); } - let mut pts: [Vec; NUM_POLYS] = - core::array::from_fn(|_| Vec::::with_capacity(msg.len() / 2)); + let mut pts: [Point; NUM_POLYS] = core::array::from_fn(|_| Point { + value: Vec::::with_capacity(msg.len() / 2), + }); for (i, c) in msg.chunks_exact(2).enumerate() { hax_lib::loop_invariant!(|_: usize| pts.len() >= NUM_POLYS); let poly = i % pts.len(); - pts[poly].push(GF16::new(((c[0] as u16) << 8) + (c[1] as u16))); + hax_lib::assume!(pts[poly].value.len() < MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1); + pts[poly] + .value + .push(GF16::new(((c[0] as u16) << 8) + (c[1] as u16))); } Ok(Self { idx: 0, @@ -738,6 +748,7 @@ impl PolyDecoder { } } + #[hax_lib::ensures(|res| hax_lib::implies(len_bytes % 2 == 0, res.is_ok() && res.unwrap().pts_needed == len_bytes / 2))] fn new_with_poly_count(len_bytes: usize, _polys: usize) -> Result { if len_bytes % 2 != 0 { return Err(PolynomialError::MessageLengthEven.into()); @@ -769,7 +780,9 @@ impl PolyDecoder { out } + #[hax_lib::ensures(|res| hax_lib::implies(pb.pts.len() == 16, res.is_ok() && res.unwrap().pts_needed == pb.pts_needed as usize))] pub fn from_pb(pb: proto::pq_ratchet::PolynomialDecoder) -> Result { + hax_lib::fstar!("admit ()"); // prove precondition with return in loop if pb.pts.len() != 16 { return Err(PolynomialError::SerializationInvalid); } @@ -795,14 +808,19 @@ impl PolyDecoder { #[hax_lib::attributes] impl Decoder for PolyDecoder { + #[hax_lib::ensures(|res| hax_lib::implies(len_bytes % 2 == 0, res.is_ok() && res.unwrap().pts_needed == len_bytes / 2))] fn new(len_bytes: usize) -> Result { Self::new_with_poly_count(len_bytes, 16) } - #[hax_lib::requires(self.pts.len() == 16)] + #[hax_lib::ensures(|_| future(self).pts_needed == self.pts_needed)] fn add_chunk(&mut self, chunk: &Chunk) { + #[cfg(hax)] + let initial_pts_needed = self.pts_needed; for i in 0usize..16 { - hax_lib::loop_invariant!(|_: usize| self.pts.len() == 16); + hax_lib::loop_invariant!( + |_: usize| self.pts.len() == 16 && self.pts_needed == initial_pts_needed + ); let total_idx = (chunk.index as usize) * 16 + i; let poly = total_idx % 16; let poly_idx = total_idx / 16; @@ -823,20 +841,28 @@ impl Decoder for PolyDecoder { } } - #[hax_lib::requires(self.pts_needed < MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1)] + #[hax_lib::requires(self.pts_needed < usize::MAX / 2)] + #[hax_lib::ensures(|res| match res { + Some(v) => v.len() / 2 == self.pts_needed, + None => true + })] fn decoded_message(&self) -> Option> { if self.is_complete { return None; } let mut points_vecs = Vec::with_capacity(self.pts.len()); + let mut ret_none = false; for i in 0..(self.pts.len()) { let pts = &self.pts[i]; if pts.len() < self.necessary_points(i) { - return None; + ret_none = true; } else { points_vecs.push(&pts[..self.necessary_points(i)]); } } + if ret_none { + return None; + } // We may or may not need these vectors of points (only if we need // to do a lagrange_interpolate call). For now, we just create // them regardless. However, we could optimize to lazily create them @@ -844,6 +870,7 @@ impl Decoder for PolyDecoder { let mut polys: [Option; 16] = core::array::from_fn(|_| None); let mut out: Vec = Vec::with_capacity(self.pts_needed * 2); for i in 0..self.pts_needed { + hax_lib::loop_invariant!(|i: usize| out.len() == i * 2); let poly = i % 16; let poly_idx = i / 16; let pt = Pt { diff --git a/src/incremental_mlkem768.rs b/src/incremental_mlkem768.rs index ba28e1b..8d7826b 100644 --- a/src/incremental_mlkem768.rs +++ b/src/incremental_mlkem768.rs @@ -30,7 +30,7 @@ pub fn ek_matches_header(ek: &EncapsulationKey, hdr: &Header) -> bool { } /// Generate a new keypair and associated header. -#[hax_lib::ensures(|result| result.hdr.len() == 64 && result.ek.len() == 1152 && result.dk.len() == 2400)] +#[hax_lib::ensures(|result| result.hdr.len() == HEADER_SIZE && result.ek.len() == ENCAPSULATION_KEY_SIZE && result.dk.len() == 2400)] pub fn generate(rng: &mut R) -> Keys { let mut randomness = [0u8; libcrux_ml_kem::KEY_GENERATION_SEED_SIZE]; rng.fill_bytes(&mut randomness); diff --git a/src/v1/chunked/send_ct.rs b/src/v1/chunked/send_ct.rs index a67a43a..66f0040 100644 --- a/src/v1/chunked/send_ct.rs +++ b/src/v1/chunked/send_ct.rs @@ -12,24 +12,30 @@ use crate::{Epoch, EpochSecret, Error}; use rand::{CryptoRng, Rng}; #[cfg_attr(test, derive(Clone))] +#[hax_lib::attributes] pub struct NoHeaderReceived { pub(super) uc: unchunked::NoHeaderReceived, // `receiving_hdr` only decodes messages of length `incremental_mlkem768::HEADER_SIZE + authenticator::Authenticator::MACSIZE` + #[hax_lib::refine(receiving_hdr.get_pts_needed() == (incremental_mlkem768::HEADER_SIZE + authenticator::Authenticator::MACSIZE) / 2)] pub(super) receiving_hdr: polynomial::PolyDecoder, } #[cfg_attr(test, derive(Clone))] +#[hax_lib::attributes] pub struct HeaderReceived { uc: unchunked::HeaderReceived, // `receiving_ek` only decodes messages of length `incremental_mlkem768::ENCAPSULATION_KEY_SIZE` + #[hax_lib::refine(receiving_ek.get_pts_needed() == incremental_mlkem768::ENCAPSULATION_KEY_SIZE / 2)] receiving_ek: polynomial::PolyDecoder, } #[cfg_attr(test, derive(Clone))] +#[hax_lib::attributes] pub struct Ct1Sampled { uc: unchunked::Ct1Sent, sending_ct1: polynomial::PolyEncoder, // `receiving_ek` only decodes messages of length `incremental_mlkem768::ENCAPSULATION_KEY_SIZE` + #[hax_lib::refine(receiving_ek.get_pts_needed() == incremental_mlkem768::ENCAPSULATION_KEY_SIZE / 2)] receiving_ek: polynomial::PolyDecoder, } @@ -40,9 +46,11 @@ pub struct EkReceivedCt1Sampled { } #[cfg_attr(test, derive(Clone))] +#[hax_lib::attributes] pub struct Ct1Acknowledged { uc: unchunked::Ct1Sent, // `receiving_ek` only decodes messages of length `incremental_mlkem768::ENCAPSULATION_KEY_SIZE` + #[hax_lib::refine(receiving_ek.get_pts_needed() == incremental_mlkem768::ENCAPSULATION_KEY_SIZE / 2)] receiving_ek: polynomial::PolyDecoder, } @@ -64,7 +72,6 @@ impl NoHeaderReceived { let decoder = polynomial::PolyDecoder::new( incremental_mlkem768::HEADER_SIZE + authenticator::Authenticator::MACSIZE, ); - hax_lib::assume!(decoder.is_ok()); NoHeaderReceived { uc: unchunked::NoHeaderReceived::new(auth_key), receiving_hdr: decoder.expect("should be able to decode header size"), @@ -83,15 +90,10 @@ impl NoHeaderReceived { mut receiving_hdr, } = self; receiving_hdr.add_chunk(chunk); - hax_lib::assume!( - receiving_hdr.get_pts_needed() <= polynomial::MAX_STORED_POLYNOMIAL_DEGREE_V1 - ); if let Some(mut hdr) = receiving_hdr.decoded_message() { - let mac: authenticator::Mac = hdr.drain(incremental_mlkem768::HEADER_SIZE..).collect(); - hax_lib::assume!(hdr.len() == 64 && mac.len() == authenticator::Authenticator::MACSIZE); + let mac: authenticator::Mac = hdr.split_off(incremental_mlkem768::HEADER_SIZE); let receiving_ek = polynomial::PolyDecoder::new(incremental_mlkem768::ENCAPSULATION_KEY_SIZE); - hax_lib::assume!(receiving_ek.is_ok()); Ok(NoHeaderReceivedRecvChunk::Done(HeaderReceived { uc: uc.recv_header(epoch, hdr, &mac)?, receiving_ek: receiving_ek.expect("should be able to decode EncapsulationKey size"), @@ -123,7 +125,6 @@ impl HeaderReceived { let (uc, ct1, epoch_secret) = uc.send_ct1(rng); let encoder = polynomial::PolyEncoder::encode_bytes(&ct1); - hax_lib::assume!(encoder.is_ok()); let mut sending_ct1 = encoder.expect("should be able to send CTSIZE"); let chunk = sending_ct1.next_chunk(); ( @@ -151,12 +152,11 @@ pub enum Ct1SampledRecvChunk { Done(Ct2Sampled), } +#[hax_lib::requires(ct2.len() == 128 && mac.len() == authenticator::Authenticator::MACSIZE)] fn send_ct2_encoder(ct2: &[u8], mac: &[u8]) -> polynomial::PolyEncoder { - hax_lib::assume!( - [ct2, mac].concat().len() % 2 == 0 - && [ct2, mac].concat().len() <= (1 << 16) * crate::encoding::polynomial::NUM_POLYS - ); - polynomial::PolyEncoder::encode_bytes(&[ct2, mac].concat()).expect("should be able to send ct2") + let mut msg = ct2.to_vec(); + msg.extend_from_slice(mac); + polynomial::PolyEncoder::encode_bytes(&msg).expect("should be able to send ct2") } #[hax_lib::attributes] @@ -174,19 +174,10 @@ impl Ct1Sampled { sending_ct1, } = self; receiving_ek.add_chunk(chunk); - hax_lib::assume!( - receiving_ek.get_pts_needed() <= polynomial::MAX_STORED_POLYNOMIAL_DEGREE_V1 - ); Ok(if let Some(decoded) = receiving_ek.decoded_message() { - hax_lib::assume!(decoded.len() == 1152); let uc = uc.recv_ek(epoch, decoded)?; if ct1_ack { let (uc, ct2, mac) = uc.send_ct2(); - hax_lib::assume!( - [ct2.clone(), mac.clone()].concat().len() % 2 == 0 - && [ct2.clone(), mac.clone()].concat().len() - <= (1 << 16) * crate::encoding::polynomial::NUM_POLYS - ); Ct1SampledRecvChunk::Done(Ct2Sampled { uc, sending_ct2: send_ct2_encoder(&ct2, &mac), @@ -272,11 +263,7 @@ impl Ct1Acknowledged { mut receiving_ek, } = self; receiving_ek.add_chunk(chunk); - hax_lib::assume!( - receiving_ek.get_pts_needed() <= polynomial::MAX_STORED_POLYNOMIAL_DEGREE_V1 - ); Ok(if let Some(decoded) = receiving_ek.decoded_message() { - hax_lib::assume!(decoded.len() == 1152); let uc = uc.recv_ek(epoch, decoded)?; let (uc, ct2, mac) = uc.send_ct2(); Ct1AcknowledgedRecvChunk::Done(Ct2Sampled { diff --git a/src/v1/chunked/send_ct/serialize.rs b/src/v1/chunked/send_ct/serialize.rs index 45f2229..1da63b4 100644 --- a/src/v1/chunked/send_ct/serialize.rs +++ b/src/v1/chunked/send_ct/serialize.rs @@ -6,6 +6,7 @@ use crate::encoding::polynomial; use crate::proto::pq_ratchet as pqrpb; use crate::v1::unchunked; +#[hax_lib::attributes] impl NoHeaderReceived { pub fn into_pb(self) -> pqrpb::v1_state::chunked::NoHeaderReceived { pqrpb::v1_state::chunked::NoHeaderReceived { @@ -15,6 +16,15 @@ impl NoHeaderReceived { } pub fn from_pb(pb: pqrpb::v1_state::chunked::NoHeaderReceived) -> Result { + if let Some(rhdr) = &pb.receiving_hdr { + if rhdr.pts_needed + != ((crate::incremental_mlkem768::HEADER_SIZE + + crate::authenticator::Authenticator::MACSIZE) + / 2) as u32 + { + return Err(Error::MsgDecode); + } + } Ok(Self { uc: unchunked::send_ct::NoHeaderReceived::from_pb(pb.uc.ok_or(Error::StateDecode)?)?, receiving_hdr: polynomial::PolyDecoder::from_pb( @@ -34,6 +44,11 @@ impl HeaderReceived { } pub fn from_pb(pb: pqrpb::v1_state::chunked::HeaderReceived) -> Result { + if let Some(d) = &pb.receiving_ek { + if d.pts_needed as usize != crate::incremental_mlkem768::ENCAPSULATION_KEY_SIZE / 2 { + return Err(Error::MsgDecode); + } + } Ok(Self { uc: unchunked::send_ct::HeaderReceived::from_pb(pb.uc.ok_or(Error::StateDecode)?)?, receiving_ek: polynomial::PolyDecoder::from_pb( @@ -46,19 +61,6 @@ impl HeaderReceived { impl Ct1Sampled { pub fn into_pb(self) -> pqrpb::v1_state::chunked::Ct1Sampled { - hax_lib::assume!(match self.sending_ct1.get_encoder_state() { - polynomial::EncoderState::Points(points) => hax_lib::prop::forall( - |pts: &Vec| hax_lib::prop::implies( - points.contains(pts), - pts.len() <= polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 - ) - ), - polynomial::EncoderState::Polys(polys) => - hax_lib::prop::forall(|poly: &polynomial::Poly| hax_lib::prop::implies( - polys.contains(poly), - poly.coefficients.len() <= polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 - )), - }); pqrpb::v1_state::chunked::Ct1Sampled { uc: Some(self.uc.into_pb()), sending_ct1: Some(self.sending_ct1.into_pb()), @@ -67,6 +69,11 @@ impl Ct1Sampled { } pub fn from_pb(pb: pqrpb::v1_state::chunked::Ct1Sampled) -> Result { + if let Some(d) = &pb.receiving_ek { + if d.pts_needed as usize != crate::incremental_mlkem768::ENCAPSULATION_KEY_SIZE / 2 { + return Err(Error::MsgDecode); + } + } Ok(Self { uc: unchunked::send_ct::Ct1Sent::from_pb(pb.uc.ok_or(Error::StateDecode)?)?, sending_ct1: polynomial::PolyEncoder::from_pb( @@ -83,19 +90,6 @@ impl Ct1Sampled { impl EkReceivedCt1Sampled { pub fn into_pb(self) -> pqrpb::v1_state::chunked::EkReceivedCt1Sampled { - hax_lib::assume!(match self.sending_ct1.get_encoder_state() { - polynomial::EncoderState::Points(points) => hax_lib::prop::forall( - |pts: &Vec| hax_lib::prop::implies( - points.contains(pts), - pts.len() <= polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 - ) - ), - polynomial::EncoderState::Polys(polys) => - hax_lib::prop::forall(|poly: &polynomial::Poly| hax_lib::prop::implies( - polys.contains(poly), - poly.coefficients.len() <= polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 - )), - }); pqrpb::v1_state::chunked::EkReceivedCt1Sampled { uc: Some(self.uc.into_pb()), sending_ct1: Some(self.sending_ct1.into_pb()), @@ -122,6 +116,11 @@ impl Ct1Acknowledged { } pub fn from_pb(pb: pqrpb::v1_state::chunked::Ct1Acknowledged) -> Result { + if let Some(d) = &pb.receiving_ek { + if d.pts_needed as usize != crate::incremental_mlkem768::ENCAPSULATION_KEY_SIZE / 2 { + return Err(Error::MsgDecode); + } + } Ok(Self { uc: unchunked::send_ct::Ct1Sent::from_pb(pb.uc.ok_or(Error::StateDecode)?)?, receiving_ek: polynomial::PolyDecoder::from_pb( @@ -134,19 +133,6 @@ impl Ct1Acknowledged { impl Ct2Sampled { pub fn into_pb(self) -> pqrpb::v1_state::chunked::Ct2Sampled { - hax_lib::assume!(match self.sending_ct2.get_encoder_state() { - polynomial::EncoderState::Points(points) => hax_lib::prop::forall( - |pts: &Vec| hax_lib::prop::implies( - points.contains(pts), - pts.len() <= polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 - ) - ), - polynomial::EncoderState::Polys(polys) => - hax_lib::prop::forall(|poly: &polynomial::Poly| hax_lib::prop::implies( - polys.contains(poly), - poly.coefficients.len() <= polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 - )), - }); pqrpb::v1_state::chunked::Ct2Sampled { uc: Some(self.uc.into_pb()), sending_ct2: Some(self.sending_ct2.into_pb()), diff --git a/src/v1/chunked/send_ek.rs b/src/v1/chunked/send_ek.rs index 10e5694..742c9d3 100644 --- a/src/v1/chunked/send_ek.rs +++ b/src/v1/chunked/send_ek.rs @@ -24,10 +24,12 @@ pub struct KeysSampled { } #[cfg_attr(test, derive(Clone))] +#[hax_lib::attributes] pub struct HeaderSent { uc: unchunked::EkSent, sending_ek: polynomial::PolyEncoder, // `receiving_ct1` only decodes messages of length `incremental_mlkem768::CIPHERTEXT1_SIZE` + #[hax_lib::refine(receiving_ct1.pts_needed == incremental_mlkem768::CIPHERTEXT1_SIZE / 2)] receiving_ct1: polynomial::PolyDecoder, } @@ -38,9 +40,11 @@ pub struct Ct1Received { } #[cfg_attr(test, derive(Clone))] +#[hax_lib::attributes] pub struct EkSentCt1Received { uc: unchunked::EkSentCt1Received, // `receiving_ct2` only decodes messages of length `incremental_mlkem768::CIPHERTEXT2_SIZE + authenticator::Authenticator::MACSIZE` + #[hax_lib::refine(receiving_ct2.pts_needed == (incremental_mlkem768::CIPHERTEXT2_SIZE + authenticator::Authenticator::MACSIZE) / 2)] receiving_ct2: polynomial::PolyDecoder, } @@ -52,10 +56,9 @@ impl KeysUnsampled { } pub fn send_hdr_chunk(self, rng: &mut R) -> (KeysSampled, Chunk) { - let (uc, hdr, mac) = self.uc.send_header(rng); - let to_send = [hdr, mac].concat(); + let (uc, mut to_send, mut mac) = self.uc.send_header(rng); + to_send.append(&mut mac); let encoder = polynomial::PolyEncoder::encode_bytes(&to_send); - hax_lib::assume!(encoder.is_ok()); let mut sending_hdr = encoder.expect("should be able to encode header size"); let chunk = sending_hdr.next_chunk(); (KeysSampled { uc, sending_hdr }, chunk) @@ -81,13 +84,10 @@ impl KeysSampled { pub fn recv_ct1_chunk(self, epoch: Epoch, chunk: &Chunk) -> HeaderSent { assert_eq!(epoch, self.uc.epoch); let decoder = polynomial::PolyDecoder::new(incremental_mlkem768::CIPHERTEXT1_SIZE); - hax_lib::assume!(decoder.is_ok()); let mut receiving_ct1 = decoder.expect("should be able to decode header size"); receiving_ct1.add_chunk(chunk); let (uc, ek) = self.uc.send_ek(); - let encoder = polynomial::PolyEncoder::encode_bytes(&ek); - hax_lib::assume!(encoder.is_ok()); let sending_ek = encoder.expect("should be able to send ek"); HeaderSent { uc, @@ -135,11 +135,7 @@ impl HeaderSent { mut receiving_ct1, } = self; receiving_ct1.add_chunk(chunk); - hax_lib::assume!( - receiving_ct1.get_pts_needed() <= polynomial::MAX_STORED_POLYNOMIAL_DEGREE_V1 - ); if let Some(decoded) = receiving_ct1.decoded_message() { - hax_lib::assume!(decoded.len() == 960); let uc = uc.recv_ct1(epoch, decoded); HeaderSentRecvChunk::Done(Ct1Received { uc, sending_ek }) } else { @@ -167,10 +163,13 @@ impl Ct1Received { #[hax_lib::requires(epoch == self.uc.epoch)] pub fn recv_ct2_chunk(self, epoch: Epoch, chunk: &Chunk) -> EkSentCt1Received { assert_eq!(epoch, self.uc.epoch); + hax_lib::assert!( + (incremental_mlkem768::CIPHERTEXT2_SIZE + authenticator::Authenticator::MACSIZE) % 2 + == 0 + ); let decoder = polynomial::PolyDecoder::new( incremental_mlkem768::CIPHERTEXT2_SIZE + authenticator::Authenticator::MACSIZE, ); - hax_lib::assume!(decoder.is_ok()); let mut receiving_ct2 = decoder.expect("should be able to decode ct2+mac size"); receiving_ct2.add_chunk(chunk); EkSentCt1Received { @@ -203,22 +202,12 @@ impl EkSentCt1Received { mut receiving_ct2, } = self; receiving_ct2.add_chunk(chunk); - hax_lib::assume!( - receiving_ct2.get_pts_needed() <= polynomial::MAX_STORED_POLYNOMIAL_DEGREE_V1 - ); if let Some(mut ct2) = receiving_ct2.decoded_message() { - let mac: authenticator::Mac = ct2 - .drain(incremental_mlkem768::CIPHERTEXT2_SIZE..) - .collect(); - hax_lib::assume!( - ct2.len() == incremental_mlkem768::CIPHERTEXT2_SIZE - && mac.len() == authenticator::Authenticator::MACSIZE - ); + let mac: authenticator::Mac = ct2.split_off(incremental_mlkem768::CIPHERTEXT2_SIZE); let (uc, sec) = uc.recv_ct2(ct2, mac)?; let decoder = polynomial::PolyDecoder::new( incremental_mlkem768::HEADER_SIZE + authenticator::Authenticator::MACSIZE, ); - hax_lib::assume!(decoder.is_ok()); Ok(EkSentCt1ReceivedRecvChunk::Done(( send_ct::NoHeaderReceived { uc, diff --git a/src/v1/chunked/send_ek/serialize.rs b/src/v1/chunked/send_ek/serialize.rs index af2734f..5e8f9dd 100644 --- a/src/v1/chunked/send_ek/serialize.rs +++ b/src/v1/chunked/send_ek/serialize.rs @@ -22,19 +22,6 @@ impl KeysUnsampled { impl KeysSampled { pub fn into_pb(self) -> pqrpb::v1_state::chunked::KeysSampled { - hax_lib::assume!(match self.sending_hdr.get_encoder_state() { - polynomial::EncoderState::Points(points) => hax_lib::prop::forall( - |pts: &Vec| hax_lib::prop::implies( - points.contains(pts), - pts.len() <= polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 - ) - ), - polynomial::EncoderState::Polys(polys) => - hax_lib::prop::forall(|poly: &polynomial::Poly| hax_lib::prop::implies( - polys.contains(poly), - poly.coefficients.len() <= polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 - )), - }); pqrpb::v1_state::chunked::KeysSampled { uc: Some(self.uc.into_pb()), sending_hdr: Some(self.sending_hdr.into_pb()), @@ -52,21 +39,9 @@ impl KeysSampled { } } +#[hax_lib::attributes] impl HeaderSent { pub fn into_pb(self) -> pqrpb::v1_state::chunked::HeaderSent { - hax_lib::assume!(match self.sending_ek.get_encoder_state() { - polynomial::EncoderState::Points(points) => hax_lib::prop::forall( - |pts: &Vec| hax_lib::prop::implies( - points.contains(pts), - pts.len() <= polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 - ) - ), - polynomial::EncoderState::Polys(polys) => - hax_lib::prop::forall(|poly: &polynomial::Poly| hax_lib::prop::implies( - polys.contains(poly), - poly.coefficients.len() <= polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 - )), - }); pqrpb::v1_state::chunked::HeaderSent { uc: Some(self.uc.into_pb()), sending_ek: Some(self.sending_ek.into_pb()), @@ -75,6 +50,11 @@ impl HeaderSent { } pub fn from_pb(pb: pqrpb::v1_state::chunked::HeaderSent) -> Result { + if let Some(d) = &pb.receiving_ct1 { + if d.pts_needed as usize != crate::incremental_mlkem768::CIPHERTEXT1_SIZE / 2 { + return Err(Error::MsgDecode); + } + } Ok(Self { uc: unchunked::send_ek::EkSent::from_pb(pb.uc.ok_or(Error::StateDecode)?)?, sending_ek: polynomial::PolyEncoder::from_pb(pb.sending_ek.ok_or(Error::StateDecode)?) @@ -89,19 +69,6 @@ impl HeaderSent { impl Ct1Received { pub fn into_pb(self) -> pqrpb::v1_state::chunked::Ct1Received { - hax_lib::assume!(match self.sending_ek.get_encoder_state() { - polynomial::EncoderState::Points(points) => hax_lib::prop::forall( - |pts: &Vec| hax_lib::prop::implies( - points.contains(pts), - pts.len() <= polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 - ) - ), - polynomial::EncoderState::Polys(polys) => - hax_lib::prop::forall(|poly: &polynomial::Poly| hax_lib::prop::implies( - polys.contains(poly), - poly.coefficients.len() <= polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 - )), - }); pqrpb::v1_state::chunked::Ct1Received { uc: Some(self.uc.into_pb()), sending_ek: Some(self.sending_ek.into_pb()), @@ -126,6 +93,14 @@ impl EkSentCt1Received { } pub fn from_pb(pb: pqrpb::v1_state::chunked::EkSentCt1Received) -> Result { + if let Some(d) = &pb.receiving_ct2 { + if d.pts_needed as usize + != (incremental_mlkem768::CIPHERTEXT2_SIZE + authenticator::Authenticator::MACSIZE) + / 2 + { + return Err(Error::MsgDecode); + } + } Ok(Self { uc: unchunked::send_ek::EkSentCt1Received::from_pb(pb.uc.ok_or(Error::StateDecode)?)?, receiving_ct2: polynomial::PolyDecoder::from_pb( diff --git a/src/v1/unchunked/send_ct.rs b/src/v1/unchunked/send_ct.rs index 470b18d..a568ef0 100644 --- a/src/v1/unchunked/send_ct.rs +++ b/src/v1/unchunked/send_ct.rs @@ -116,6 +116,7 @@ impl NoHeaderReceived { #[hax_lib::attributes] impl HeaderReceived { #[hax_lib::requires(self.hdr.len() == 64)] + #[hax_lib::ensures(|(_, ct1, _)| ct1.len() == 960)] pub fn send_ct1( self, rng: &mut R, diff --git a/src/v1/unchunked/send_ek.rs b/src/v1/unchunked/send_ek.rs index fa9b4a8..cb3d9aa 100644 --- a/src/v1/unchunked/send_ek.rs +++ b/src/v1/unchunked/send_ek.rs @@ -71,6 +71,7 @@ pub struct EkSentCt1Received { ct1: incremental_mlkem768::Ciphertext1, } +#[hax_lib::attributes] impl KeysUnsampled { pub fn new(auth_key: &[u8]) -> Self { Self { @@ -79,6 +80,7 @@ impl KeysUnsampled { } } + #[hax_lib::ensures(|(_, hdr, mac)| hdr.len() == incremental_mlkem768::HEADER_SIZE && mac.len() == authenticator::Authenticator::MACSIZE)] pub fn send_header( self, rng: &mut R, @@ -98,7 +100,9 @@ impl KeysUnsampled { } } +#[hax_lib::attributes] impl HeaderSent { + #[hax_lib::ensures(|(_, ek)| ek.len() == 1152)] pub fn send_ek(self) -> (EkSent, incremental_mlkem768::EncapsulationKey) { ( EkSent { @@ -131,7 +135,7 @@ impl EkSent { #[hax_lib::attributes] impl EkSentCt1Received { - #[hax_lib::requires(ct2.len() == 128 && mac.len() == authenticator::Authenticator::MACSIZE)] + #[hax_lib::requires(ct2.len() == incremental_mlkem768::CIPHERTEXT2_SIZE && mac.len() == authenticator::Authenticator::MACSIZE)] pub fn recv_ct2( self, ct2: incremental_mlkem768::Ciphertext2,