From 7b4b7c262bac5f28194ddd852097051cdbb973e5 Mon Sep 17 00:00:00 2001 From: Maxime Buyse Date: Wed, 22 Oct 2025 11:12:19 +0200 Subject: [PATCH 1/8] Improve proofs with refinement types. --- Cargo.lock | 47 ++++++++++--- Cargo.toml | 2 +- ...crux_ml_kem.Ind_cca.Incremental.Types.fsti | 3 +- .../Libcrux_ml_kem.Mlkem768.Incremental.fsti | 2 +- src/encoding/gf.rs | 10 ++- src/encoding/polynomial.rs | 69 ++++++++++++------- src/v1/chunked/send_ct.rs | 32 ++++----- src/v1/chunked/send_ct/serialize.rs | 44 ++---------- src/v1/chunked/send_ek.rs | 26 +++---- src/v1/chunked/send_ek/serialize.rs | 41 +---------- src/v1/chunked/states/serialize.rs | 13 ++++ src/v1/unchunked/send_ct.rs | 1 + src/v1/unchunked/send_ek.rs | 2 +- 13 files changed, 142 insertions(+), 150 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 11b1a8d..24bfa33 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -65,7 +65,7 @@ version = "0.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94950e87ea550d6d68f1993f3e7bebc8cb7235157bff84337d46195c3aa0b3f0" dependencies = [ - "hax-lib", + "hax-lib 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", "pastey", "rand 0.9.1", ] @@ -226,30 +226,61 @@ version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +[[package]] +name = "hax-lib" +version = "0.3.5" +dependencies = [ + "hax-lib-macros 0.3.5", + "num-bigint", + "num-traits", +] + [[package]] name = "hax-lib" version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "74d9ba66d1739c68e0219b2b2238b5c4145f491ebf181b9c6ab561a19352ae86" dependencies = [ - "hax-lib-macros", + "hax-lib-macros 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", "num-bigint", "num-traits", ] +[[package]] +name = "hax-lib-macros" +version = "0.3.5" +dependencies = [ + "hax-lib-macros-types 0.3.5", + "proc-macro-error2", + "proc-macro2", + 
"quote", + "syn", +] + [[package]] name = "hax-lib-macros" version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24ba777a231a58d1bce1d68313fa6b6afcc7966adef23d60f45b8a2b9b688bf1" dependencies = [ - "hax-lib-macros-types", + "hax-lib-macros-types 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", "proc-macro-error2", "proc-macro2", "quote", "syn", ] +[[package]] +name = "hax-lib-macros-types" +version = "0.3.5" +dependencies = [ + "proc-macro2", + "quote", + "serde", + "serde_json", + "uuid", +] + [[package]] name = "hax-lib-macros-types" version = "0.3.5" @@ -355,7 +386,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d3b41dcbc21a5fb7efbbb5af7405b2e79c4bfe443924e90b13afc0080318d31" dependencies = [ "core-models", - "hax-lib", + "hax-lib 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -374,7 +405,7 @@ version = "0.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d368d3e8d6a74e277178d54921eca112a1e6b7837d7d8bc555091acb5d817f5" dependencies = [ - "hax-lib", + "hax-lib 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", "libcrux-intrinsics", "libcrux-platform", "libcrux-secrets", @@ -396,7 +427,7 @@ version = "0.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "332737e629fe6ba7547f5c0f90559eac865d5dbecf98138ffae8f16ab8cbe33f" dependencies = [ - "hax-lib", + "hax-lib 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -416,7 +447,7 @@ version = "0.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "29d95de4257eafdfaf3bffecadb615219b0ca920c553722b3646d32dde76c797" dependencies = [ - "hax-lib", + "hax-lib 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", "libcrux-intrinsics", "libcrux-platform", ] @@ -875,7 +906,7 @@ dependencies = [ "curve25519-dalek", "displaydoc", "galois_field_2pm", - "hax-lib", + 
"hax-lib 0.3.5", "hkdf", "hmac", "libcrux-hkdf", diff --git a/Cargo.toml b/Cargo.toml index 112c7d6..510a0ec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ rust-version = "1.83.0" [dependencies] curve25519-dalek = { version = "4.1.3", features = ["rand_core"] } displaydoc = "0.2" -hax-lib = "0.3.5" +hax-lib = {path = "../hax/hax-lib"} hkdf = "0.12" libcrux-hkdf = "0.0.3" libcrux-hmac = "0.0.3" diff --git a/proofs/fstar/models/Libcrux_ml_kem.Ind_cca.Incremental.Types.fsti b/proofs/fstar/models/Libcrux_ml_kem.Ind_cca.Incremental.Types.fsti index 8c244e5..0b22293 100644 --- a/proofs/fstar/models/Libcrux_ml_kem.Ind_cca.Incremental.Types.fsti +++ b/proofs/fstar/models/Libcrux_ml_kem.Ind_cca.Incremental.Types.fsti @@ -137,7 +137,8 @@ val impl_4__deserialize type t_Ciphertext1 (v_LEN: usize) = { f_value:t_Array u8 v_LEN } /// The size of the ciphertext. -val impl_5__len: v_LEN: usize -> Prims.unit -> Prims.Pure usize Prims.l_True (fun _ -> Prims.l_True) +val impl_5__len: v_LEN: usize -> Prims.unit -> Prims.Pure usize Prims.l_True + (ensures fun res -> let res:usize = res in res =. v_LEN) /// The partial ciphertext c2 - second part. type t_Ciphertext2 (v_LEN: usize) = { f_value:t_Array u8 v_LEN } diff --git a/proofs/fstar/models/Libcrux_ml_kem.Mlkem768.Incremental.fsti b/proofs/fstar/models/Libcrux_ml_kem.Mlkem768.Incremental.fsti index 65121da..97051c7 100644 --- a/proofs/fstar/models/Libcrux_ml_kem.Mlkem768.Incremental.fsti +++ b/proofs/fstar/models/Libcrux_ml_kem.Mlkem768.Incremental.fsti @@ -15,7 +15,7 @@ val pk1_len: Prims.unit -> Prims.Pure usize Prims.l_True (ensures fun res -> let res:usize = res in res =. mk_usize 64) /// Get the size of the second public key in bytes. -val pk2_len: Prims.unit -> Prims.Pure usize Prims.l_True (fun _ -> Prims.l_True) +val pk2_len: Prims.unit -> Prims.Pure usize Prims.l_True (ensures fun res -> res =. mk_usize 1152) /// The size of a compressed key pair in bytes. 
let v_COMPRESSED_KEYPAIR_LEN: usize = Libcrux_ml_kem.Mlkem768.v_SECRET_KEY_SIZE diff --git a/src/encoding/gf.rs b/src/encoding/gf.rs index d4ecdfc..064e708 100644 --- a/src/encoding/gf.rs +++ b/src/encoding/gf.rs @@ -196,11 +196,15 @@ impl ops::Div<&GF16> for GF16 { } #[inline] -#[hax_lib::fstar::verification_status(lax)] // proving absence of overflow in loop condition is tricky +#[hax_lib::requires(into.len() <= usize::MAX - 2)] +#[hax_lib::ensures(|_| future(into).len() == into.len())] pub fn parallel_mult(a: GF16, into: &mut [GF16]) { - let mut i = 0; + let mut i: usize = 0; + #[cfg(hax)] + let l = into.len(); while i + 2 <= into.len() { - hax_lib::loop_decreases!(into.len() - i); + hax_lib::loop_decreases!(l - i); + hax_lib::loop_invariant!(into.len() == l && i <= l); (into[i].value, into[i + 1].value) = mul2_u16(a.value, into[i].value, into[i + 1].value); i += 2; } diff --git a/src/encoding/polynomial.rs b/src/encoding/polynomial.rs index 0989172..729ed28 100644 --- a/src/encoding/polynomial.rs +++ b/src/encoding/polynomial.rs @@ -71,6 +71,7 @@ pub const MAX_STORED_POLYNOMIAL_DEGREE_V1: usize = 35; pub const MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1: usize = 36; #[derive(Clone, PartialEq)] +#[hax_lib::attributes] pub(crate) struct Poly { // For Protocol V1 we interpolate at most 36 values, which produces a // degree 35 polynomial (with 36 coefficients). In an intermediate calculation @@ -78,6 +79,7 @@ pub(crate) struct Poly { // higher, thus we get the following constraint: // // coefficients.len() <= MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 + 1 + #[hax_lib::refine(coefficients.len() <= MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 + 1)] pub coefficients: Vec, } @@ -245,6 +247,7 @@ impl Poly { gf::parallel_mult(m, &mut self.coefficients); } + #[hax_lib::opaque] // zip fn compute_at(&self, x: GF16) -> GF16 { // Compute x^0 .. 
x^N let mut xs = Vec::with_capacity(self.coefficients.len()); @@ -265,6 +268,7 @@ impl Poly { } /// Internal function for lagrange_polynomial_from_complete_points. + #[hax_lib::opaque] // zip fn lagrange_sum(pts: &[Pt], polys: &[Poly]) -> Poly { let mut out = Poly::zero(pts.len()); for (pt, poly) in pts.iter().zip(polys.iter()) { @@ -279,6 +283,7 @@ impl Poly { /// range [0..pts.len()), return a polynomial that computes those points. #[hax_lib::requires(pts.len() == 0 || pts.len() == 1 || pts.len() == 3 || pts.len() == 5 || pts.len() == 30 || pts.len() == 34 || pts.len() == 36)] + #[hax_lib::opaque] // iterators fn from_complete_points(pts: &[Pt]) -> Result { for (i, pt) in pts.iter().enumerate() { if pt.x.value != i as u16 { @@ -314,7 +319,6 @@ impl Poly { Ok(Self::lagrange_sum(pts, &polys)) } - #[hax_lib::requires(self.coefficients.len() <= MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1)] pub fn serialize(&self) -> Vec { // For Protocol V1 the polynomials that get serialized will always have // coefficients.len() <= MAX_STORED_POLYNOMIAL_DEGREE_V1 + 1 @@ -336,6 +340,7 @@ impl Poly { for coeff in serialized.chunks_exact(2) { coefficients.push(GF16::new(u16::from_be_bytes(coeff.try_into().unwrap()))); } + hax_lib::assume!(coefficients.len() <= MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 + 1); Ok(Self { coefficients }) } } @@ -441,6 +446,7 @@ impl PolyConst { } fn to_poly(&self) -> Poly { + hax_lib::assume!(N <= MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 + 1); Poly { coefficients: self.coefficients.to_vec(), } @@ -495,11 +501,18 @@ const CHUNK_SIZE: usize = 32; // Number of polys or points that need to be tracked when using GF(2^16) with 2-byte elements pub const NUM_POLYS: usize = CHUNK_SIZE / 2; +#[derive(Clone)] +#[hax_lib::attributes] +pub struct Point { + #[hax_lib::refine(value.len() <= MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1)] + pub value: Vec, +} + #[cfg_attr(test, derive(Clone))] pub(crate) enum EncoderState { // For 32B chunks the outer vector has length 16. 
// Using MLKEM-768 the inner vector has length <= MAX_STORED_POLYNOMIAL_DEGREE_V1 + 1 - Points([Vec; NUM_POLYS]), + Points([Point; NUM_POLYS]), // For 32B chunks this vector has length 16. Polys([Poly; NUM_POLYS]), } @@ -517,12 +530,6 @@ impl PolyEncoder { &self.s } - #[hax_lib::requires(match self.s { - EncoderState::Points(points) => hax_lib::Prop::from(points.len() == 16).and(hax_lib::prop::forall(|pts: &Vec| - hax_lib::prop::implies(points.contains(pts), pts.len() <= MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1))), - EncoderState::Polys(polys) => hax_lib::Prop::from(polys.len() == 16).and(hax_lib::prop::forall(|poly: &Poly| - hax_lib::prop::implies(polys.contains(poly), poly.coefficients.len() <= MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1))) - })] pub fn into_pb(self) -> proto::pq_ratchet::PolynomialEncoder { let mut out = proto::pq_ratchet::PolynomialEncoder { idx: self.idx, @@ -535,7 +542,7 @@ impl PolyEncoder { #[allow(clippy::needless_range_loop)] for j in 0..points.len() { hax_lib::loop_invariant!(|j: usize| out.pts.len() == j); - let pts = &points[j]; + let pts = &points[j].value; let mut v = Vec::::with_capacity(2 * pts.len()); #[allow(clippy::needless_range_loop)] for i in 0..pts.len() { @@ -563,16 +570,12 @@ impl PolyEncoder { if pb.pts.len() != NUM_POLYS { return Err(PolynomialError::SerializationInvalid); } - let mut out = core::array::from_fn(|_| Vec::::new()); + let mut out = core::array::from_fn(|_| Point { + value: Vec::::new(), + }); #[allow(clippy::needless_range_loop)] for i in 0..NUM_POLYS { - hax_lib::loop_invariant!(|_: usize| hax_lib::prop::forall(|pts: &Vec| { - hax_lib::prop::implies( - out.contains(pts), - pts.len() <= MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1, - ) - })); let pts = &pb.pts[i]; if pts.len() % 2 != 0 { return Err(PolynomialError::SerializationInvalid); @@ -581,7 +584,8 @@ impl PolyEncoder { for pt in pts.chunks_exact(2) { v.push(GF16::new(u16::from_be_bytes(pt.try_into().unwrap()))); } - out[i] = v; + hax_lib::assume!(v.len() <= 
MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1); + out[i] = Point { value: v }; } EncoderState::Points(out) } else if pb.polys.len() == NUM_POLYS { @@ -597,11 +601,12 @@ impl PolyEncoder { } #[requires(poly < 16)] + #[hax_lib::opaque] // iterators fn point_at(&mut self, poly: usize, idx: usize) -> GF16 { if let EncoderState::Points(ref pts) = self.s { hax_lib::assume!(pts.len() == 16); - if idx < pts[poly].len() { - return pts[poly][idx]; + if idx < pts[poly].value.len() { + return pts[poly].value[idx]; } // If we reach here, we've come to the first point we want to // find that wasn't part of the original set of points. We @@ -611,6 +616,7 @@ impl PolyEncoder { let mut polys: [Poly; NUM_POLYS] = core::array::from_fn(|_| Poly::zero(1)); for i in 0..NUM_POLYS { let pt_vec = pts[i] + .value .iter() .enumerate() .map(|(x, y)| Pt { @@ -648,12 +654,16 @@ impl PolyEncoder { } else if msg.len() > (1 << 16) * NUM_POLYS { return Err(PolynomialError::MessageLengthTooLong.into()); } - let mut pts: [Vec; NUM_POLYS] = - core::array::from_fn(|_| Vec::::with_capacity(msg.len() / 2)); + let mut pts: [Point; NUM_POLYS] = core::array::from_fn(|_| Point { + value: Vec::::with_capacity(msg.len() / 2), + }); for (i, c) in msg.chunks_exact(2).enumerate() { hax_lib::loop_invariant!(|_: usize| pts.len() >= NUM_POLYS); let poly = i % pts.len(); - pts[poly].push(GF16::new(((c[0] as u16) << 8) + (c[1] as u16))); + hax_lib::assume!(pts[poly].value.len() < MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1); + pts[poly] + .value + .push(GF16::new(((c[0] as u16) << 8) + (c[1] as u16))); } Ok(Self { idx: 0, @@ -738,6 +748,7 @@ impl PolyDecoder { } } + #[hax_lib::ensures(|res| hax_lib::implies(len_bytes % 2 == 0, res.is_ok() && res.unwrap().pts_needed == len_bytes / 2))] fn new_with_poly_count(len_bytes: usize, _polys: usize) -> Result { if len_bytes % 2 != 0 { return Err(PolynomialError::MessageLengthEven.into()); @@ -769,7 +780,9 @@ impl PolyDecoder { out } + #[hax_lib::ensures(|res| hax_lib::implies(pb.pts.len() == 
16, res.is_ok() && res.unwrap().pts_needed == pb.pts_needed as usize))] pub fn from_pb(pb: proto::pq_ratchet::PolynomialDecoder) -> Result { + hax_lib::fstar!("admit ()"); // prove precondition with return in loop if pb.pts.len() != 16 { return Err(PolynomialError::SerializationInvalid); } @@ -795,14 +808,19 @@ impl PolyDecoder { #[hax_lib::attributes] impl Decoder for PolyDecoder { + #[hax_lib::ensures(|res| hax_lib::implies(len_bytes % 2 == 0, res.is_ok() && res.unwrap().pts_needed == len_bytes / 2))] fn new(len_bytes: usize) -> Result { Self::new_with_poly_count(len_bytes, 16) } - #[hax_lib::requires(self.pts.len() == 16)] + #[hax_lib::ensures(|_| future(self).pts_needed == self.pts_needed)] fn add_chunk(&mut self, chunk: &Chunk) { + #[cfg(hax)] + let initial_pts_needed = self.pts_needed; for i in 0usize..16 { - hax_lib::loop_invariant!(|_: usize| self.pts.len() == 16); + hax_lib::loop_invariant!( + |_: usize| self.pts.len() == 16 && self.pts_needed == initial_pts_needed + ); let total_idx = (chunk.index as usize) * 16 + i; let poly = total_idx % 16; let poly_idx = total_idx / 16; @@ -823,7 +841,7 @@ impl Decoder for PolyDecoder { } } - #[hax_lib::requires(self.pts_needed < MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1)] + #[hax_lib::requires(self.pts_needed < usize::MAX / 2)] fn decoded_message(&self) -> Option> { if self.is_complete { return None; @@ -844,6 +862,7 @@ impl Decoder for PolyDecoder { let mut polys: [Option; 16] = core::array::from_fn(|_| None); let mut out: Vec = Vec::with_capacity(self.pts_needed * 2); for i in 0..self.pts_needed { + hax_lib::loop_invariant!(out.len() == i * 2); let poly = i % 16; let poly_idx = i / 16; let pt = Pt { diff --git a/src/v1/chunked/send_ct.rs b/src/v1/chunked/send_ct.rs index a67a43a..f81d050 100644 --- a/src/v1/chunked/send_ct.rs +++ b/src/v1/chunked/send_ct.rs @@ -12,9 +12,11 @@ use crate::{Epoch, EpochSecret, Error}; use rand::{CryptoRng, Rng}; #[cfg_attr(test, derive(Clone))] +#[hax_lib::attributes] pub struct 
NoHeaderReceived { pub(super) uc: unchunked::NoHeaderReceived, // `receiving_hdr` only decodes messages of length `incremental_mlkem768::HEADER_SIZE + authenticator::Authenticator::MACSIZE` + #[hax_lib::refine(receiving_hdr.get_pts_needed() == (incremental_mlkem768::HEADER_SIZE + authenticator::Authenticator::MACSIZE) / 2)] pub(super) receiving_hdr: polynomial::PolyDecoder, } @@ -64,7 +66,6 @@ impl NoHeaderReceived { let decoder = polynomial::PolyDecoder::new( incremental_mlkem768::HEADER_SIZE + authenticator::Authenticator::MACSIZE, ); - hax_lib::assume!(decoder.is_ok()); NoHeaderReceived { uc: unchunked::NoHeaderReceived::new(auth_key), receiving_hdr: decoder.expect("should be able to decode header size"), @@ -83,15 +84,14 @@ impl NoHeaderReceived { mut receiving_hdr, } = self; receiving_hdr.add_chunk(chunk); - hax_lib::assume!( - receiving_hdr.get_pts_needed() <= polynomial::MAX_STORED_POLYNOMIAL_DEGREE_V1 - ); if let Some(mut hdr) = receiving_hdr.decoded_message() { let mac: authenticator::Mac = hdr.drain(incremental_mlkem768::HEADER_SIZE..).collect(); + // To remove this we can either: + // Add a model of `drain` in hax core lib and add the necessary pre/post to propagate this + // Switch to fixed length instead of Vec? 
hax_lib::assume!(hdr.len() == 64 && mac.len() == authenticator::Authenticator::MACSIZE); let receiving_ek = polynomial::PolyDecoder::new(incremental_mlkem768::ENCAPSULATION_KEY_SIZE); - hax_lib::assume!(receiving_ek.is_ok()); Ok(NoHeaderReceivedRecvChunk::Done(HeaderReceived { uc: uc.recv_header(epoch, hdr, &mac)?, receiving_ek: receiving_ek.expect("should be able to decode EncapsulationKey size"), @@ -123,7 +123,6 @@ impl HeaderReceived { let (uc, ct1, epoch_secret) = uc.send_ct1(rng); let encoder = polynomial::PolyEncoder::encode_bytes(&ct1); - hax_lib::assume!(encoder.is_ok()); let mut sending_ct1 = encoder.expect("should be able to send CTSIZE"); let chunk = sending_ct1.next_chunk(); ( @@ -151,11 +150,9 @@ pub enum Ct1SampledRecvChunk { Done(Ct2Sampled), } +#[hax_lib::requires(ct2.len() == 128 && mac.len() == authenticator::Authenticator::MACSIZE)] fn send_ct2_encoder(ct2: &[u8], mac: &[u8]) -> polynomial::PolyEncoder { - hax_lib::assume!( - [ct2, mac].concat().len() % 2 == 0 - && [ct2, mac].concat().len() <= (1 << 16) * crate::encoding::polynomial::NUM_POLYS - ); + hax_lib::assume!(polynomial::PolyEncoder::encode_bytes(&[ct2, mac].concat()).is_ok()); // needs model of concat polynomial::PolyEncoder::encode_bytes(&[ct2, mac].concat()).expect("should be able to send ct2") } @@ -175,18 +172,13 @@ impl Ct1Sampled { } = self; receiving_ek.add_chunk(chunk); hax_lib::assume!( - receiving_ek.get_pts_needed() <= polynomial::MAX_STORED_POLYNOMIAL_DEGREE_V1 + receiving_ek.pts_needed < polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 * 2 + 1 ); Ok(if let Some(decoded) = receiving_ek.decoded_message() { - hax_lib::assume!(decoded.len() == 1152); + hax_lib::assume!(decoded.len() == 1152); // Need to prove that receiving_ek.pts_needed is 576. 
This seems contradictory with the condition above let uc = uc.recv_ek(epoch, decoded)?; if ct1_ack { let (uc, ct2, mac) = uc.send_ct2(); - hax_lib::assume!( - [ct2.clone(), mac.clone()].concat().len() % 2 == 0 - && [ct2.clone(), mac.clone()].concat().len() - <= (1 << 16) * crate::encoding::polynomial::NUM_POLYS - ); Ct1SampledRecvChunk::Done(Ct2Sampled { uc, sending_ct2: send_ct2_encoder(&ct2, &mac), @@ -273,10 +265,10 @@ impl Ct1Acknowledged { } = self; receiving_ek.add_chunk(chunk); hax_lib::assume!( - receiving_ek.get_pts_needed() <= polynomial::MAX_STORED_POLYNOMIAL_DEGREE_V1 - ); + receiving_ek.get_pts_needed() < polynomial::MAX_STORED_POLYNOMIAL_DEGREE_V1 * 2 + 1 + ); // Could be done using a precondition or a refinement type Ok(if let Some(decoded) = receiving_ek.decoded_message() { - hax_lib::assume!(decoded.len() == 1152); + hax_lib::assume!(decoded.len() == 1152); // post-condition on `decoded_message`? hard because of returns let uc = uc.recv_ek(epoch, decoded)?; let (uc, ct2, mac) = uc.send_ct2(); Ct1AcknowledgedRecvChunk::Done(Ct2Sampled { diff --git a/src/v1/chunked/send_ct/serialize.rs b/src/v1/chunked/send_ct/serialize.rs index 45f2229..a8651cf 100644 --- a/src/v1/chunked/send_ct/serialize.rs +++ b/src/v1/chunked/send_ct/serialize.rs @@ -6,6 +6,7 @@ use crate::encoding::polynomial; use crate::proto::pq_ratchet as pqrpb; use crate::v1::unchunked; +#[hax_lib::attributes] impl NoHeaderReceived { pub fn into_pb(self) -> pqrpb::v1_state::chunked::NoHeaderReceived { pqrpb::v1_state::chunked::NoHeaderReceived { @@ -14,6 +15,10 @@ impl NoHeaderReceived { } } + #[hax_lib::requires(match pb.receiving_hdr { + Some(rhdr) => rhdr.pts_needed == ((incremental_mlkem768::HEADER_SIZE + authenticator::Authenticator::MACSIZE) / 2) as u32, + None => true} + )] pub fn from_pb(pb: pqrpb::v1_state::chunked::NoHeaderReceived) -> Result { Ok(Self { uc: unchunked::send_ct::NoHeaderReceived::from_pb(pb.uc.ok_or(Error::StateDecode)?)?, @@ -46,19 +51,6 @@ impl HeaderReceived { 
impl Ct1Sampled { pub fn into_pb(self) -> pqrpb::v1_state::chunked::Ct1Sampled { - hax_lib::assume!(match self.sending_ct1.get_encoder_state() { - polynomial::EncoderState::Points(points) => hax_lib::prop::forall( - |pts: &Vec| hax_lib::prop::implies( - points.contains(pts), - pts.len() <= polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 - ) - ), - polynomial::EncoderState::Polys(polys) => - hax_lib::prop::forall(|poly: &polynomial::Poly| hax_lib::prop::implies( - polys.contains(poly), - poly.coefficients.len() <= polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 - )), - }); pqrpb::v1_state::chunked::Ct1Sampled { uc: Some(self.uc.into_pb()), sending_ct1: Some(self.sending_ct1.into_pb()), @@ -83,19 +75,6 @@ impl Ct1Sampled { impl EkReceivedCt1Sampled { pub fn into_pb(self) -> pqrpb::v1_state::chunked::EkReceivedCt1Sampled { - hax_lib::assume!(match self.sending_ct1.get_encoder_state() { - polynomial::EncoderState::Points(points) => hax_lib::prop::forall( - |pts: &Vec| hax_lib::prop::implies( - points.contains(pts), - pts.len() <= polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 - ) - ), - polynomial::EncoderState::Polys(polys) => - hax_lib::prop::forall(|poly: &polynomial::Poly| hax_lib::prop::implies( - polys.contains(poly), - poly.coefficients.len() <= polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 - )), - }); pqrpb::v1_state::chunked::EkReceivedCt1Sampled { uc: Some(self.uc.into_pb()), sending_ct1: Some(self.sending_ct1.into_pb()), @@ -134,19 +113,6 @@ impl Ct1Acknowledged { impl Ct2Sampled { pub fn into_pb(self) -> pqrpb::v1_state::chunked::Ct2Sampled { - hax_lib::assume!(match self.sending_ct2.get_encoder_state() { - polynomial::EncoderState::Points(points) => hax_lib::prop::forall( - |pts: &Vec| hax_lib::prop::implies( - points.contains(pts), - pts.len() <= polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 - ) - ), - polynomial::EncoderState::Polys(polys) => - hax_lib::prop::forall(|poly: &polynomial::Poly| hax_lib::prop::implies( - polys.contains(poly), 
- poly.coefficients.len() <= polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 - )), - }); pqrpb::v1_state::chunked::Ct2Sampled { uc: Some(self.uc.into_pb()), sending_ct2: Some(self.sending_ct2.into_pb()), diff --git a/src/v1/chunked/send_ek.rs b/src/v1/chunked/send_ek.rs index 10e5694..059adb1 100644 --- a/src/v1/chunked/send_ek.rs +++ b/src/v1/chunked/send_ek.rs @@ -24,10 +24,12 @@ pub struct KeysSampled { } #[cfg_attr(test, derive(Clone))] +#[hax_lib::attributes] pub struct HeaderSent { uc: unchunked::EkSent, sending_ek: polynomial::PolyEncoder, // `receiving_ct1` only decodes messages of length `incremental_mlkem768::CIPHERTEXT1_SIZE` + #[hax_lib::refine(receiving_ct1.pts_needed == incremental_mlkem768::CIPHERTEXT1_SIZE / 2)] receiving_ct1: polynomial::PolyDecoder, } @@ -55,7 +57,7 @@ impl KeysUnsampled { let (uc, hdr, mac) = self.uc.send_header(rng); let to_send = [hdr, mac].concat(); let encoder = polynomial::PolyEncoder::encode_bytes(&to_send); - hax_lib::assume!(encoder.is_ok()); + hax_lib::assume!(encoder.is_ok()); // needs model of concat let mut sending_hdr = encoder.expect("should be able to encode header size"); let chunk = sending_hdr.next_chunk(); (KeysSampled { uc, sending_hdr }, chunk) @@ -80,14 +82,14 @@ impl KeysSampled { #[hax_lib::requires(epoch == self.uc.epoch)] pub fn recv_ct1_chunk(self, epoch: Epoch, chunk: &Chunk) -> HeaderSent { assert_eq!(epoch, self.uc.epoch); + //hax_lib::assert!(incremental_mlkem768::CIPHERTEXT1_SIZE % 2 == 0); let decoder = polynomial::PolyDecoder::new(incremental_mlkem768::CIPHERTEXT1_SIZE); - hax_lib::assume!(decoder.is_ok()); + //hax_lib::assert!(decoder.is_ok()); let mut receiving_ct1 = decoder.expect("should be able to decode header size"); receiving_ct1.add_chunk(chunk); let (uc, ek) = self.uc.send_ek(); - + hax_lib::assume!(ek.len() % 2 == 0); let encoder = polynomial::PolyEncoder::encode_bytes(&ek); - hax_lib::assume!(encoder.is_ok()); let sending_ek = encoder.expect("should be able to send ek"); HeaderSent 
{ uc, @@ -137,9 +139,9 @@ impl HeaderSent { receiving_ct1.add_chunk(chunk); hax_lib::assume!( receiving_ct1.get_pts_needed() <= polynomial::MAX_STORED_POLYNOMIAL_DEGREE_V1 - ); + ); // Seems contradictory with necessary length returned by `decoded_message` (960) if let Some(decoded) = receiving_ct1.decoded_message() { - hax_lib::assume!(decoded.len() == 960); + hax_lib::assume!(decoded.len() == incremental_mlkem768::CIPHERTEXT1_SIZE); let uc = uc.recv_ct1(epoch, decoded); HeaderSentRecvChunk::Done(Ct1Received { uc, sending_ek }) } else { @@ -167,10 +169,13 @@ impl Ct1Received { #[hax_lib::requires(epoch == self.uc.epoch)] pub fn recv_ct2_chunk(self, epoch: Epoch, chunk: &Chunk) -> EkSentCt1Received { assert_eq!(epoch, self.uc.epoch); + hax_lib::assert!( + (incremental_mlkem768::CIPHERTEXT2_SIZE + authenticator::Authenticator::MACSIZE) % 2 + == 0 + ); let decoder = polynomial::PolyDecoder::new( incremental_mlkem768::CIPHERTEXT2_SIZE + authenticator::Authenticator::MACSIZE, ); - hax_lib::assume!(decoder.is_ok()); let mut receiving_ct2 = decoder.expect("should be able to decode ct2+mac size"); receiving_ct2.add_chunk(chunk); EkSentCt1Received { @@ -203,9 +208,7 @@ impl EkSentCt1Received { mut receiving_ct2, } = self; receiving_ct2.add_chunk(chunk); - hax_lib::assume!( - receiving_ct2.get_pts_needed() <= polynomial::MAX_STORED_POLYNOMIAL_DEGREE_V1 - ); + hax_lib::assume!(receiving_ct2.pts_needed < usize::MAX / 2); if let Some(mut ct2) = receiving_ct2.decoded_message() { let mac: authenticator::Mac = ct2 .drain(incremental_mlkem768::CIPHERTEXT2_SIZE..) 
@@ -213,12 +216,11 @@ impl EkSentCt1Received { hax_lib::assume!( ct2.len() == incremental_mlkem768::CIPHERTEXT2_SIZE && mac.len() == authenticator::Authenticator::MACSIZE - ); + ); // Needs model of drain let (uc, sec) = uc.recv_ct2(ct2, mac)?; let decoder = polynomial::PolyDecoder::new( incremental_mlkem768::HEADER_SIZE + authenticator::Authenticator::MACSIZE, ); - hax_lib::assume!(decoder.is_ok()); Ok(EkSentCt1ReceivedRecvChunk::Done(( send_ct::NoHeaderReceived { uc, diff --git a/src/v1/chunked/send_ek/serialize.rs b/src/v1/chunked/send_ek/serialize.rs index af2734f..843511e 100644 --- a/src/v1/chunked/send_ek/serialize.rs +++ b/src/v1/chunked/send_ek/serialize.rs @@ -22,19 +22,6 @@ impl KeysUnsampled { impl KeysSampled { pub fn into_pb(self) -> pqrpb::v1_state::chunked::KeysSampled { - hax_lib::assume!(match self.sending_hdr.get_encoder_state() { - polynomial::EncoderState::Points(points) => hax_lib::prop::forall( - |pts: &Vec| hax_lib::prop::implies( - points.contains(pts), - pts.len() <= polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 - ) - ), - polynomial::EncoderState::Polys(polys) => - hax_lib::prop::forall(|poly: &polynomial::Poly| hax_lib::prop::implies( - polys.contains(poly), - poly.coefficients.len() <= polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 - )), - }); pqrpb::v1_state::chunked::KeysSampled { uc: Some(self.uc.into_pb()), sending_hdr: Some(self.sending_hdr.into_pb()), @@ -52,21 +39,9 @@ impl KeysSampled { } } +#[hax_lib::attributes] impl HeaderSent { pub fn into_pb(self) -> pqrpb::v1_state::chunked::HeaderSent { - hax_lib::assume!(match self.sending_ek.get_encoder_state() { - polynomial::EncoderState::Points(points) => hax_lib::prop::forall( - |pts: &Vec| hax_lib::prop::implies( - points.contains(pts), - pts.len() <= polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 - ) - ), - polynomial::EncoderState::Polys(polys) => - hax_lib::prop::forall(|poly: &polynomial::Poly| hax_lib::prop::implies( - polys.contains(poly), - poly.coefficients.len() 
<= polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 - )), - }); pqrpb::v1_state::chunked::HeaderSent { uc: Some(self.uc.into_pb()), sending_ek: Some(self.sending_ek.into_pb()), @@ -74,6 +49,7 @@ impl HeaderSent { } } + #[hax_lib::requires(match pb.receiving_ct1 {Some(d) => d.pts_needed as usize == incremental_mlkem768::CIPHERTEXT1_SIZE / 2, None => true })] pub fn from_pb(pb: pqrpb::v1_state::chunked::HeaderSent) -> Result { Ok(Self { uc: unchunked::send_ek::EkSent::from_pb(pb.uc.ok_or(Error::StateDecode)?)?, @@ -89,19 +65,6 @@ impl HeaderSent { impl Ct1Received { pub fn into_pb(self) -> pqrpb::v1_state::chunked::Ct1Received { - hax_lib::assume!(match self.sending_ek.get_encoder_state() { - polynomial::EncoderState::Points(points) => hax_lib::prop::forall( - |pts: &Vec| hax_lib::prop::implies( - points.contains(pts), - pts.len() <= polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 - ) - ), - polynomial::EncoderState::Polys(polys) => - hax_lib::prop::forall(|poly: &polynomial::Poly| hax_lib::prop::implies( - polys.contains(poly), - poly.coefficients.len() <= polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 - )), - }); pqrpb::v1_state::chunked::Ct1Received { uc: Some(self.uc.into_pb()), sending_ek: Some(self.sending_ek.into_pb()), diff --git a/src/v1/chunked/states/serialize.rs b/src/v1/chunked/states/serialize.rs index dd0ded5..f9c61e8 100644 --- a/src/v1/chunked/states/serialize.rs +++ b/src/v1/chunked/states/serialize.rs @@ -54,6 +54,11 @@ impl States { Self::KeysSampled(send_ek::KeysSampled::from_pb(pb)?) } Some(pqrpb::v1_state::InnerState::HeaderSent(pb)) => { + hax_lib::assume!(match &pb.receiving_ct1 { + Some(d) => + d.pts_needed as usize == crate::incremental_mlkem768::CIPHERTEXT1_SIZE / 2, + None => true, + }); Self::HeaderSent(send_ek::HeaderSent::from_pb(pb)?) 
} Some(pqrpb::v1_state::InnerState::Ct1Received(pb)) => { @@ -65,6 +70,14 @@ impl States { // send_ct Some(pqrpb::v1_state::InnerState::NoHeaderReceived(pb)) => { + hax_lib::assume!(match &pb.receiving_hdr { + Some(rhdr) => + rhdr.pts_needed + == ((crate::incremental_mlkem768::HEADER_SIZE + + crate::authenticator::Authenticator::MACSIZE) + / 2) as u32, + None => true, + }); Self::NoHeaderReceived(send_ct::NoHeaderReceived::from_pb(pb)?) } Some(pqrpb::v1_state::InnerState::HeaderReceived(pb)) => { diff --git a/src/v1/unchunked/send_ct.rs b/src/v1/unchunked/send_ct.rs index 470b18d..a568ef0 100644 --- a/src/v1/unchunked/send_ct.rs +++ b/src/v1/unchunked/send_ct.rs @@ -116,6 +116,7 @@ impl NoHeaderReceived { #[hax_lib::attributes] impl HeaderReceived { #[hax_lib::requires(self.hdr.len() == 64)] + #[hax_lib::ensures(|(_, ct1, _)| ct1.len() == 960)] pub fn send_ct1( self, rng: &mut R, diff --git a/src/v1/unchunked/send_ek.rs b/src/v1/unchunked/send_ek.rs index fa9b4a8..6816332 100644 --- a/src/v1/unchunked/send_ek.rs +++ b/src/v1/unchunked/send_ek.rs @@ -131,7 +131,7 @@ impl EkSent { #[hax_lib::attributes] impl EkSentCt1Received { - #[hax_lib::requires(ct2.len() == 128 && mac.len() == authenticator::Authenticator::MACSIZE)] + #[hax_lib::requires(ct2.len() == incremental_mlkem768::CIPHERTEXT2_SIZE && mac.len() == authenticator::Authenticator::MACSIZE)] pub fn recv_ct2( self, ct2: incremental_mlkem768::Ciphertext2, From 6cbfbe7632e8b69707887209291a64a67f47f69e Mon Sep 17 00:00:00 2001 From: Maxime Buyse Date: Wed, 22 Oct 2025 11:38:12 +0200 Subject: [PATCH 2/8] Use split_off instead of drain. 
--- src/v1/chunked/send_ct.rs | 10 +++++----- src/v1/chunked/send_ek.rs | 11 +++++------ 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/v1/chunked/send_ct.rs b/src/v1/chunked/send_ct.rs index f81d050..a9a6c61 100644 --- a/src/v1/chunked/send_ct.rs +++ b/src/v1/chunked/send_ct.rs @@ -85,11 +85,11 @@ impl NoHeaderReceived { } = self; receiving_hdr.add_chunk(chunk); if let Some(mut hdr) = receiving_hdr.decoded_message() { - let mac: authenticator::Mac = hdr.drain(incremental_mlkem768::HEADER_SIZE..).collect(); - // To remove this we can either: - // Add a model of `drain` in hax core lib and add the necessary pre/post to propagate this - // Switch to fixed length instead of Vec? - hax_lib::assume!(hdr.len() == 64 && mac.len() == authenticator::Authenticator::MACSIZE); + hax_lib::assume!( + hdr.len() + == incremental_mlkem768::HEADER_SIZE + authenticator::Authenticator::MACSIZE + ); + let mac: authenticator::Mac = hdr.split_off(incremental_mlkem768::HEADER_SIZE); let receiving_ek = polynomial::PolyDecoder::new(incremental_mlkem768::ENCAPSULATION_KEY_SIZE); Ok(NoHeaderReceivedRecvChunk::Done(HeaderReceived { diff --git a/src/v1/chunked/send_ek.rs b/src/v1/chunked/send_ek.rs index 059adb1..3fbb02b 100644 --- a/src/v1/chunked/send_ek.rs +++ b/src/v1/chunked/send_ek.rs @@ -210,13 +210,12 @@ impl EkSentCt1Received { receiving_ct2.add_chunk(chunk); hax_lib::assume!(receiving_ct2.pts_needed < usize::MAX / 2); if let Some(mut ct2) = receiving_ct2.decoded_message() { - let mac: authenticator::Mac = ct2 - .drain(incremental_mlkem768::CIPHERTEXT2_SIZE..) 
- .collect(); hax_lib::assume!( - ct2.len() == incremental_mlkem768::CIPHERTEXT2_SIZE - && mac.len() == authenticator::Authenticator::MACSIZE - ); // Needs model of drain + ct2.len() + == incremental_mlkem768::CIPHERTEXT2_SIZE + + authenticator::Authenticator::MACSIZE + ); + let mac: authenticator::Mac = ct2.split_off(incremental_mlkem768::CIPHERTEXT2_SIZE); let (uc, sec) = uc.recv_ct2(ct2, mac)?; let decoder = polynomial::PolyDecoder::new( incremental_mlkem768::HEADER_SIZE + authenticator::Authenticator::MACSIZE, From ae680fb0684f3036f1bf2eae4b56e562fe67db08 Mon Sep 17 00:00:00 2001 From: Maxime Buyse Date: Wed, 22 Oct 2025 14:26:24 +0200 Subject: [PATCH 3/8] Replace 'concat' for proofs. --- src/incremental_mlkem768.rs | 2 +- src/v1/chunked/send_ct.rs | 5 +++-- src/v1/chunked/send_ek.rs | 7 ++----- src/v1/unchunked/send_ek.rs | 2 ++ 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/incremental_mlkem768.rs b/src/incremental_mlkem768.rs index ba28e1b..8d7826b 100644 --- a/src/incremental_mlkem768.rs +++ b/src/incremental_mlkem768.rs @@ -30,7 +30,7 @@ pub fn ek_matches_header(ek: &EncapsulationKey, hdr: &Header) -> bool { } /// Generate a new keypair and associated header. 
-#[hax_lib::ensures(|result| result.hdr.len() == 64 && result.ek.len() == 1152 && result.dk.len() == 2400)] +#[hax_lib::ensures(|result| result.hdr.len() == HEADER_SIZE && result.ek.len() == ENCAPSULATION_KEY_SIZE && result.dk.len() == 2400)] pub fn generate(rng: &mut R) -> Keys { let mut randomness = [0u8; libcrux_ml_kem::KEY_GENERATION_SEED_SIZE]; rng.fill_bytes(&mut randomness); diff --git a/src/v1/chunked/send_ct.rs b/src/v1/chunked/send_ct.rs index a9a6c61..c1739df 100644 --- a/src/v1/chunked/send_ct.rs +++ b/src/v1/chunked/send_ct.rs @@ -152,8 +152,9 @@ pub enum Ct1SampledRecvChunk { #[hax_lib::requires(ct2.len() == 128 && mac.len() == authenticator::Authenticator::MACSIZE)] fn send_ct2_encoder(ct2: &[u8], mac: &[u8]) -> polynomial::PolyEncoder { - hax_lib::assume!(polynomial::PolyEncoder::encode_bytes(&[ct2, mac].concat()).is_ok()); // needs model of concat - polynomial::PolyEncoder::encode_bytes(&[ct2, mac].concat()).expect("should be able to send ct2") + let mut msg = ct2.to_vec(); + msg.extend_from_slice(mac); + polynomial::PolyEncoder::encode_bytes(&msg).expect("should be able to send ct2") } #[hax_lib::attributes] diff --git a/src/v1/chunked/send_ek.rs b/src/v1/chunked/send_ek.rs index 3fbb02b..6f751e4 100644 --- a/src/v1/chunked/send_ek.rs +++ b/src/v1/chunked/send_ek.rs @@ -54,10 +54,9 @@ impl KeysUnsampled { } pub fn send_hdr_chunk(self, rng: &mut R) -> (KeysSampled, Chunk) { - let (uc, hdr, mac) = self.uc.send_header(rng); - let to_send = [hdr, mac].concat(); + let (uc, mut to_send, mut mac) = self.uc.send_header(rng); + to_send.append(&mut mac); let encoder = polynomial::PolyEncoder::encode_bytes(&to_send); - hax_lib::assume!(encoder.is_ok()); // needs model of concat let mut sending_hdr = encoder.expect("should be able to encode header size"); let chunk = sending_hdr.next_chunk(); (KeysSampled { uc, sending_hdr }, chunk) @@ -82,9 +81,7 @@ impl KeysSampled { #[hax_lib::requires(epoch == self.uc.epoch)] pub fn recv_ct1_chunk(self, epoch: Epoch, 
chunk: &Chunk) -> HeaderSent { assert_eq!(epoch, self.uc.epoch); - //hax_lib::assert!(incremental_mlkem768::CIPHERTEXT1_SIZE % 2 == 0); let decoder = polynomial::PolyDecoder::new(incremental_mlkem768::CIPHERTEXT1_SIZE); - //hax_lib::assert!(decoder.is_ok()); let mut receiving_ct1 = decoder.expect("should be able to decode header size"); receiving_ct1.add_chunk(chunk); let (uc, ek) = self.uc.send_ek(); diff --git a/src/v1/unchunked/send_ek.rs b/src/v1/unchunked/send_ek.rs index 6816332..1606327 100644 --- a/src/v1/unchunked/send_ek.rs +++ b/src/v1/unchunked/send_ek.rs @@ -71,6 +71,7 @@ pub struct EkSentCt1Received { ct1: incremental_mlkem768::Ciphertext1, } +#[hax_lib::attributes] impl KeysUnsampled { pub fn new(auth_key: &[u8]) -> Self { Self { @@ -79,6 +80,7 @@ impl KeysUnsampled { } } + #[hax_lib::ensures(|(_, hdr, mac)| hdr.len() == incremental_mlkem768::HEADER_SIZE && mac.len() == authenticator::Authenticator::MACSIZE)] pub fn send_header( self, rng: &mut R, From dd217f0e7eacfaae762a6b0e5cba6d1596606b83 Mon Sep 17 00:00:00 2001 From: Maxime Buyse Date: Wed, 22 Oct 2025 15:00:00 +0200 Subject: [PATCH 4/8] Right assumptions on pts_needed. 
--- src/v1/chunked/send_ct.rs | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/v1/chunked/send_ct.rs b/src/v1/chunked/send_ct.rs index c1739df..8cc6c46 100644 --- a/src/v1/chunked/send_ct.rs +++ b/src/v1/chunked/send_ct.rs @@ -88,7 +88,7 @@ impl NoHeaderReceived { hax_lib::assume!( hdr.len() == incremental_mlkem768::HEADER_SIZE + authenticator::Authenticator::MACSIZE - ); + ); // post on `decoded_message` let mac: authenticator::Mac = hdr.split_off(incremental_mlkem768::HEADER_SIZE); let receiving_ek = polynomial::PolyDecoder::new(incremental_mlkem768::ENCAPSULATION_KEY_SIZE); @@ -172,11 +172,9 @@ impl Ct1Sampled { sending_ct1, } = self; receiving_ek.add_chunk(chunk); - hax_lib::assume!( - receiving_ek.pts_needed < polynomial::MAX_INTERMEDIATE_POLYNOMIAL_DEGREE_V1 * 2 + 1 - ); + hax_lib::assume!(receiving_ek.pts_needed == 576); Ok(if let Some(decoded) = receiving_ek.decoded_message() { - hax_lib::assume!(decoded.len() == 1152); // Need to prove that receiving_ek.pts_needed is 576. This seems contradictory with the condition above + hax_lib::assume!(decoded.len() == 1152); // Could come from a post condition on `decoded_message` (problem: return in loop) let uc = uc.recv_ek(epoch, decoded)?; if ct1_ack { let (uc, ct2, mac) = uc.send_ct2(); @@ -265,9 +263,7 @@ impl Ct1Acknowledged { mut receiving_ek, } = self; receiving_ek.add_chunk(chunk); - hax_lib::assume!( - receiving_ek.get_pts_needed() < polynomial::MAX_STORED_POLYNOMIAL_DEGREE_V1 * 2 + 1 - ); // Could be done using a precondition or a refinement type + hax_lib::assume!(receiving_ek.get_pts_needed() == 576); // Could be done using a precondition or a refinement type Ok(if let Some(decoded) = receiving_ek.decoded_message() { hax_lib::assume!(decoded.len() == 1152); // post-condition on `decoded_message`? 
hard because of returns let uc = uc.recv_ek(epoch, decoded)?; From c5002febda9d662c952794c4615a35d11e95f71c Mon Sep 17 00:00:00 2001 From: Maxime Buyse Date: Wed, 22 Oct 2025 15:57:41 +0200 Subject: [PATCH 5/8] Remove admit thanks to new model of while loops. --- src/chain.rs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/chain.rs b/src/chain.rs index d509323..b0d8c55 100644 --- a/src/chain.rs +++ b/src/chain.rs @@ -217,6 +217,7 @@ impl ChainEpochDirection { } #[hax_lib::requires(next.len() > 0 && *ctr < u32::MAX)] + #[hax_lib::ensures(|_| *future(ctr) == ctr + 1)] fn next_key_internal(next: &mut [u8], ctr: &mut u32) -> (u32, [u8; 32]) { assert!(!next.is_empty()); *ctr += 1; @@ -258,11 +259,20 @@ impl ChainEpochDirection { // them all. self.prev.clear(); } - hax_lib::fstar!("admit ()"); // potential overflows in condition and body of the loop while at > self.ctr + 1 { + hax_lib::loop_invariant!(self.ctr < u32::MAX); + hax_lib::loop_decreases!(u32::MAX - self.ctr); + hax_lib::assume!(self.next.len() > 0); let k = Self::next_key_internal(&mut self.next, &mut self.ctr); + hax_lib::assume!( + params.max_ooo_keys_or_default() < 390451572 && self.ctr <= u32::MAX - 390451572 + ); // Only add keys into our history if we're not going to immediately GC them. if self.ctr + params.max_ooo_keys_or_default() >= at { + hax_lib::assume!( + params.trim_size() < 119304647 + && self.prev.data.len() <= KeyHistory::KEY_SIZE * params.trim_size() + ); self.prev.add(k, params); } } @@ -270,6 +280,8 @@ impl ChainEpochDirection { // want to throw away. self.prev.gc(self.ctr, params); + hax_lib::assume!(self.next.len() > 0); + Ok(Self::next_key_internal(&mut self.next, &mut self.ctr) .1 .to_vec()) From c7b58ec521ba36fcfe39cd1b562c2fba439e1681 Mon Sep 17 00:00:00 2001 From: Maxime Buyse Date: Wed, 22 Oct 2025 15:59:51 +0200 Subject: [PATCH 6/8] Temporarily use hax from the right branch. 
--- .github/workflows/hax.yml | 2 +- Cargo.lock | 25 ++++++++++++++----------- Cargo.toml | 2 +- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/.github/workflows/hax.yml b/.github/workflows/hax.yml index 0ce89b4..d2a0197 100644 --- a/.github/workflows/hax.yml +++ b/.github/workflows/hax.yml @@ -14,7 +14,7 @@ jobs: uses: hacspec/hax-actions@main with: fstar: v2025.02.17 - hax_reference: hax-lib-v0.3.5 + hax_reference: tls-codec-panic-freedom - run: sudo apt-get install protobuf-compiler diff --git a/Cargo.lock b/Cargo.lock index 24bfa33..798c5d8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -229,8 +229,10 @@ checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" [[package]] name = "hax-lib" version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74d9ba66d1739c68e0219b2b2238b5c4145f491ebf181b9c6ab561a19352ae86" dependencies = [ - "hax-lib-macros 0.3.5", + "hax-lib-macros 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", "num-bigint", "num-traits", ] @@ -238,10 +240,9 @@ dependencies = [ [[package]] name = "hax-lib" version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74d9ba66d1739c68e0219b2b2238b5c4145f491ebf181b9c6ab561a19352ae86" +source = "git+https://github.com/cryspen/hax?branch=tls-codec-panic-freedom#19fe3e7b7293d1f67ab887290adb4088ce0857b3" dependencies = [ - "hax-lib-macros 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", + "hax-lib-macros 0.3.5 (git+https://github.com/cryspen/hax?branch=tls-codec-panic-freedom)", "num-bigint", "num-traits", ] @@ -249,8 +250,10 @@ dependencies = [ [[package]] name = "hax-lib-macros" version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24ba777a231a58d1bce1d68313fa6b6afcc7966adef23d60f45b8a2b9b688bf1" dependencies = [ - "hax-lib-macros-types 0.3.5", + "hax-lib-macros-types 0.3.5 
(registry+https://github.com/rust-lang/crates.io-index)", "proc-macro-error2", "proc-macro2", "quote", @@ -260,10 +263,9 @@ dependencies = [ [[package]] name = "hax-lib-macros" version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24ba777a231a58d1bce1d68313fa6b6afcc7966adef23d60f45b8a2b9b688bf1" +source = "git+https://github.com/cryspen/hax?branch=tls-codec-panic-freedom#19fe3e7b7293d1f67ab887290adb4088ce0857b3" dependencies = [ - "hax-lib-macros-types 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", + "hax-lib-macros-types 0.3.5 (git+https://github.com/cryspen/hax?branch=tls-codec-panic-freedom)", "proc-macro-error2", "proc-macro2", "quote", @@ -273,6 +275,8 @@ dependencies = [ [[package]] name = "hax-lib-macros-types" version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "867e19177d7425140b417cd27c2e05320e727ee682e98368f88b7194e80ad515" dependencies = [ "proc-macro2", "quote", @@ -284,8 +288,7 @@ dependencies = [ [[package]] name = "hax-lib-macros-types" version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "867e19177d7425140b417cd27c2e05320e727ee682e98368f88b7194e80ad515" +source = "git+https://github.com/cryspen/hax?branch=tls-codec-panic-freedom#19fe3e7b7293d1f67ab887290adb4088ce0857b3" dependencies = [ "proc-macro2", "quote", @@ -906,7 +909,7 @@ dependencies = [ "curve25519-dalek", "displaydoc", "galois_field_2pm", - "hax-lib 0.3.5", + "hax-lib 0.3.5 (git+https://github.com/cryspen/hax?branch=tls-codec-panic-freedom)", "hkdf", "hmac", "libcrux-hkdf", diff --git a/Cargo.toml b/Cargo.toml index 510a0ec..bf9a4d0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ rust-version = "1.83.0" [dependencies] curve25519-dalek = { version = "4.1.3", features = ["rand_core"] } displaydoc = "0.2" -hax-lib = {path = "../hax/hax-lib"} +hax-lib = {git = "https://github.com/cryspen/hax", branch = "tls-codec-panic-freedom"} hkdf = 
"0.12" libcrux-hkdf = "0.0.3" libcrux-hmac = "0.0.3" From a126d4a6bb81c2953533da7c664f9331c00c5412 Mon Sep 17 00:00:00 2001 From: Maxime Buyse Date: Tue, 28 Oct 2025 17:11:44 +0100 Subject: [PATCH 7/8] Replace hax assumes by real checks. --- src/encoding/polynomial.rs | 12 +++++-- src/v1/chunked/send_ct.rs | 14 ++++----- src/v1/chunked/send_ct/serialize.rs | 15 +++++++++ src/v1/chunked/states/serialize.rs | 49 +++++++++++++++++++++-------- 4 files changed, 67 insertions(+), 23 deletions(-) diff --git a/src/encoding/polynomial.rs b/src/encoding/polynomial.rs index 729ed28..9e9a83e 100644 --- a/src/encoding/polynomial.rs +++ b/src/encoding/polynomial.rs @@ -842,19 +842,27 @@ impl Decoder for PolyDecoder { } #[hax_lib::requires(self.pts_needed < usize::MAX / 2)] + #[hax_lib::ensures(|res| match res { + Some(v) => v.len() / 2 == self.pts_needed, + None => true + })] fn decoded_message(&self) -> Option> { if self.is_complete { return None; } let mut points_vecs = Vec::with_capacity(self.pts.len()); + let mut ret_none = false; for i in 0..(self.pts.len()) { let pts = &self.pts[i]; if pts.len() < self.necessary_points(i) { - return None; + ret_none = true; } else { points_vecs.push(&pts[..self.necessary_points(i)]); } } + if ret_none { + return None; + } // We may or may not need these vectors of points (only if we need // to do a lagrange_interpolate call). For now, we just create // them regardless. 
However, we could optimize to lazily create them @@ -862,7 +870,7 @@ impl Decoder for PolyDecoder { let mut polys: [Option; 16] = core::array::from_fn(|_| None); let mut out: Vec = Vec::with_capacity(self.pts_needed * 2); for i in 0..self.pts_needed { - hax_lib::loop_invariant!(out.len() == i * 2); + hax_lib::loop_invariant!(|i: usize| out.len() == i * 2); let poly = i % 16; let poly_idx = i / 16; let pt = Pt { diff --git a/src/v1/chunked/send_ct.rs b/src/v1/chunked/send_ct.rs index 8cc6c46..66f0040 100644 --- a/src/v1/chunked/send_ct.rs +++ b/src/v1/chunked/send_ct.rs @@ -21,17 +21,21 @@ pub struct NoHeaderReceived { } #[cfg_attr(test, derive(Clone))] +#[hax_lib::attributes] pub struct HeaderReceived { uc: unchunked::HeaderReceived, // `receiving_ek` only decodes messages of length `incremental_mlkem768::ENCAPSULATION_KEY_SIZE` + #[hax_lib::refine(receiving_ek.get_pts_needed() == incremental_mlkem768::ENCAPSULATION_KEY_SIZE / 2)] receiving_ek: polynomial::PolyDecoder, } #[cfg_attr(test, derive(Clone))] +#[hax_lib::attributes] pub struct Ct1Sampled { uc: unchunked::Ct1Sent, sending_ct1: polynomial::PolyEncoder, // `receiving_ek` only decodes messages of length `incremental_mlkem768::ENCAPSULATION_KEY_SIZE` + #[hax_lib::refine(receiving_ek.get_pts_needed() == incremental_mlkem768::ENCAPSULATION_KEY_SIZE / 2)] receiving_ek: polynomial::PolyDecoder, } @@ -42,9 +46,11 @@ pub struct EkReceivedCt1Sampled { } #[cfg_attr(test, derive(Clone))] +#[hax_lib::attributes] pub struct Ct1Acknowledged { uc: unchunked::Ct1Sent, // `receiving_ek` only decodes messages of length `incremental_mlkem768::ENCAPSULATION_KEY_SIZE` + #[hax_lib::refine(receiving_ek.get_pts_needed() == incremental_mlkem768::ENCAPSULATION_KEY_SIZE / 2)] receiving_ek: polynomial::PolyDecoder, } @@ -85,10 +91,6 @@ impl NoHeaderReceived { } = self; receiving_hdr.add_chunk(chunk); if let Some(mut hdr) = receiving_hdr.decoded_message() { - hax_lib::assume!( - hdr.len() - == incremental_mlkem768::HEADER_SIZE + 
authenticator::Authenticator::MACSIZE - ); // post on `decoded_message` let mac: authenticator::Mac = hdr.split_off(incremental_mlkem768::HEADER_SIZE); let receiving_ek = polynomial::PolyDecoder::new(incremental_mlkem768::ENCAPSULATION_KEY_SIZE); @@ -172,9 +174,7 @@ impl Ct1Sampled { sending_ct1, } = self; receiving_ek.add_chunk(chunk); - hax_lib::assume!(receiving_ek.pts_needed == 576); Ok(if let Some(decoded) = receiving_ek.decoded_message() { - hax_lib::assume!(decoded.len() == 1152); // Could come from a post condition on `decoded_message` (problem: return in loop) let uc = uc.recv_ek(epoch, decoded)?; if ct1_ack { let (uc, ct2, mac) = uc.send_ct2(); @@ -263,9 +263,7 @@ impl Ct1Acknowledged { mut receiving_ek, } = self; receiving_ek.add_chunk(chunk); - hax_lib::assume!(receiving_ek.get_pts_needed() == 576); // Could be done using a precondition or a refinement type Ok(if let Some(decoded) = receiving_ek.decoded_message() { - hax_lib::assume!(decoded.len() == 1152); // post-condition on `decoded_message`? 
hard because of returns let uc = uc.recv_ek(epoch, decoded)?; let (uc, ct2, mac) = uc.send_ct2(); Ct1AcknowledgedRecvChunk::Done(Ct2Sampled { diff --git a/src/v1/chunked/send_ct/serialize.rs b/src/v1/chunked/send_ct/serialize.rs index a8651cf..a88f357 100644 --- a/src/v1/chunked/send_ct/serialize.rs +++ b/src/v1/chunked/send_ct/serialize.rs @@ -39,6 +39,11 @@ impl HeaderReceived { } pub fn from_pb(pb: pqrpb::v1_state::chunked::HeaderReceived) -> Result { + if let Some(d) = &pb.receiving_ek { + if d.pts_needed as usize != crate::incremental_mlkem768::ENCAPSULATION_KEY_SIZE / 2 { + return Err(Error::MsgDecode); + } + } Ok(Self { uc: unchunked::send_ct::HeaderReceived::from_pb(pb.uc.ok_or(Error::StateDecode)?)?, receiving_ek: polynomial::PolyDecoder::from_pb( @@ -59,6 +64,11 @@ impl Ct1Sampled { } pub fn from_pb(pb: pqrpb::v1_state::chunked::Ct1Sampled) -> Result { + if let Some(d) = &pb.receiving_ek { + if d.pts_needed as usize != crate::incremental_mlkem768::ENCAPSULATION_KEY_SIZE / 2 { + return Err(Error::MsgDecode); + } + } Ok(Self { uc: unchunked::send_ct::Ct1Sent::from_pb(pb.uc.ok_or(Error::StateDecode)?)?, sending_ct1: polynomial::PolyEncoder::from_pb( @@ -101,6 +111,11 @@ impl Ct1Acknowledged { } pub fn from_pb(pb: pqrpb::v1_state::chunked::Ct1Acknowledged) -> Result { + if let Some(d) = &pb.receiving_ek { + if d.pts_needed as usize != crate::incremental_mlkem768::ENCAPSULATION_KEY_SIZE / 2 { + return Err(Error::MsgDecode); + } + } Ok(Self { uc: unchunked::send_ct::Ct1Sent::from_pb(pb.uc.ok_or(Error::StateDecode)?)?, receiving_ek: polynomial::PolyDecoder::from_pb( diff --git a/src/v1/chunked/states/serialize.rs b/src/v1/chunked/states/serialize.rs index f9c61e8..329dae9 100644 --- a/src/v1/chunked/states/serialize.rs +++ b/src/v1/chunked/states/serialize.rs @@ -54,11 +54,11 @@ impl States { Self::KeysSampled(send_ek::KeysSampled::from_pb(pb)?) 
} Some(pqrpb::v1_state::InnerState::HeaderSent(pb)) => { - hax_lib::assume!(match &pb.receiving_ct1 { - Some(d) => - d.pts_needed as usize == crate::incremental_mlkem768::CIPHERTEXT1_SIZE / 2, - None => true, - }); + if let Some(d) = &pb.receiving_ct1 { + if d.pts_needed as usize != crate::incremental_mlkem768::CIPHERTEXT1_SIZE / 2 { + return Err(Error::MsgDecode); + } + } Self::HeaderSent(send_ek::HeaderSent::from_pb(pb)?) } Some(pqrpb::v1_state::InnerState::Ct1Received(pb)) => { @@ -70,26 +70,49 @@ impl States { // send_ct Some(pqrpb::v1_state::InnerState::NoHeaderReceived(pb)) => { - hax_lib::assume!(match &pb.receiving_hdr { - Some(rhdr) => - rhdr.pts_needed - == ((crate::incremental_mlkem768::HEADER_SIZE - + crate::authenticator::Authenticator::MACSIZE) - / 2) as u32, - None => true, - }); + if let Some(rhdr) = &pb.receiving_hdr { + if rhdr.pts_needed + != ((crate::incremental_mlkem768::HEADER_SIZE + + crate::authenticator::Authenticator::MACSIZE) + / 2) as u32 + { + return Err(Error::MsgDecode); + } + } + Self::NoHeaderReceived(send_ct::NoHeaderReceived::from_pb(pb)?) } Some(pqrpb::v1_state::InnerState::HeaderReceived(pb)) => { + if let Some(d) = &pb.receiving_ek { + if d.pts_needed as usize + != crate::incremental_mlkem768::ENCAPSULATION_KEY_SIZE / 2 + { + return Err(Error::MsgDecode); + } + } Self::HeaderReceived(send_ct::HeaderReceived::from_pb(pb)?) } Some(pqrpb::v1_state::InnerState::Ct1Sampled(pb)) => { + if let Some(d) = &pb.receiving_ek { + if d.pts_needed as usize + != crate::incremental_mlkem768::ENCAPSULATION_KEY_SIZE / 2 + { + return Err(Error::MsgDecode); + } + } Self::Ct1Sampled(send_ct::Ct1Sampled::from_pb(pb)?) } Some(pqrpb::v1_state::InnerState::EkReceivedCt1Sampled(pb)) => { Self::EkReceivedCt1Sampled(send_ct::EkReceivedCt1Sampled::from_pb(pb)?) 
} Some(pqrpb::v1_state::InnerState::Ct1Acknowledged(pb)) => { + if let Some(d) = &pb.receiving_ek { + if d.pts_needed as usize + != crate::incremental_mlkem768::ENCAPSULATION_KEY_SIZE / 2 + { + return Err(Error::MsgDecode); + } + } Self::Ct1Acknowledged(send_ct::Ct1Acknowledged::from_pb(pb)?) } Some(pqrpb::v1_state::InnerState::Ct2Sampled(pb)) => { From 995a488924203b30ab4cdbd4103dc138205e8b70 Mon Sep 17 00:00:00 2001 From: Maxime Buyse Date: Wed, 29 Oct 2025 15:18:50 +0100 Subject: [PATCH 8/8] Refactor hax proofs to return errors when length is wrong instead of having assumptions. --- src/v1/chunked/send_ct/serialize.rs | 13 +++++++---- src/v1/chunked/send_ek.rs | 13 ++--------- src/v1/chunked/send_ek/serialize.rs | 14 ++++++++++- src/v1/chunked/states/serialize.rs | 36 ----------------------------- src/v1/unchunked/send_ek.rs | 2 ++ 5 files changed, 26 insertions(+), 52 deletions(-) diff --git a/src/v1/chunked/send_ct/serialize.rs b/src/v1/chunked/send_ct/serialize.rs index a88f357..1da63b4 100644 --- a/src/v1/chunked/send_ct/serialize.rs +++ b/src/v1/chunked/send_ct/serialize.rs @@ -15,11 +15,16 @@ impl NoHeaderReceived { } } - #[hax_lib::requires(match pb.receiving_hdr { - Some(rhdr) => rhdr.pts_needed == ((incremental_mlkem768::HEADER_SIZE + authenticator::Authenticator::MACSIZE) / 2) as u32, - None => true} - )] pub fn from_pb(pb: pqrpb::v1_state::chunked::NoHeaderReceived) -> Result { + if let Some(rhdr) = &pb.receiving_hdr { + if rhdr.pts_needed + != ((crate::incremental_mlkem768::HEADER_SIZE + + crate::authenticator::Authenticator::MACSIZE) + / 2) as u32 + { + return Err(Error::MsgDecode); + } + } Ok(Self { uc: unchunked::send_ct::NoHeaderReceived::from_pb(pb.uc.ok_or(Error::StateDecode)?)?, receiving_hdr: polynomial::PolyDecoder::from_pb( diff --git a/src/v1/chunked/send_ek.rs b/src/v1/chunked/send_ek.rs index 6f751e4..742c9d3 100644 --- a/src/v1/chunked/send_ek.rs +++ b/src/v1/chunked/send_ek.rs @@ -40,9 +40,11 @@ pub struct Ct1Received { }
#[cfg_attr(test, derive(Clone))] +#[hax_lib::attributes] pub struct EkSentCt1Received { uc: unchunked::EkSentCt1Received, // `receiving_ct2` only decodes messages of length `incremental_mlkem768::CIPHERTEXT2_SIZE + authenticator::Authenticator::MACSIZE` + #[hax_lib::refine(receiving_ct2.pts_needed == (incremental_mlkem768::CIPHERTEXT2_SIZE + authenticator::Authenticator::MACSIZE) / 2)] receiving_ct2: polynomial::PolyDecoder, } @@ -85,7 +87,6 @@ impl KeysSampled { let mut receiving_ct1 = decoder.expect("should be able to decode header size"); receiving_ct1.add_chunk(chunk); let (uc, ek) = self.uc.send_ek(); - hax_lib::assume!(ek.len() % 2 == 0); let encoder = polynomial::PolyEncoder::encode_bytes(&ek); let sending_ek = encoder.expect("should be able to send ek"); HeaderSent { @@ -134,11 +135,7 @@ impl HeaderSent { mut receiving_ct1, } = self; receiving_ct1.add_chunk(chunk); - hax_lib::assume!( - receiving_ct1.get_pts_needed() <= polynomial::MAX_STORED_POLYNOMIAL_DEGREE_V1 - ); // Seems contradictory with necessary length returned by `decoded_message` (960) if let Some(decoded) = receiving_ct1.decoded_message() { - hax_lib::assume!(decoded.len() == incremental_mlkem768::CIPHERTEXT1_SIZE); let uc = uc.recv_ct1(epoch, decoded); HeaderSentRecvChunk::Done(Ct1Received { uc, sending_ek }) } else { @@ -205,13 +202,7 @@ impl EkSentCt1Received { mut receiving_ct2, } = self; receiving_ct2.add_chunk(chunk); - hax_lib::assume!(receiving_ct2.pts_needed < usize::MAX / 2); if let Some(mut ct2) = receiving_ct2.decoded_message() { - hax_lib::assume!( - ct2.len() - == incremental_mlkem768::CIPHERTEXT2_SIZE - + authenticator::Authenticator::MACSIZE - ); let mac: authenticator::Mac = ct2.split_off(incremental_mlkem768::CIPHERTEXT2_SIZE); let (uc, sec) = uc.recv_ct2(ct2, mac)?; let decoder = polynomial::PolyDecoder::new( diff --git a/src/v1/chunked/send_ek/serialize.rs b/src/v1/chunked/send_ek/serialize.rs index 843511e..5e8f9dd 100644 --- a/src/v1/chunked/send_ek/serialize.rs +++ 
b/src/v1/chunked/send_ek/serialize.rs @@ -49,8 +49,12 @@ impl HeaderSent { } } - #[hax_lib::requires(match pb.receiving_ct1 {Some(d) => d.pts_needed as usize == incremental_mlkem768::CIPHERTEXT1_SIZE / 2, None => true })] pub fn from_pb(pb: pqrpb::v1_state::chunked::HeaderSent) -> Result { + if let Some(d) = &pb.receiving_ct1 { + if d.pts_needed as usize != crate::incremental_mlkem768::CIPHERTEXT1_SIZE / 2 { + return Err(Error::MsgDecode); + } + } Ok(Self { uc: unchunked::send_ek::EkSent::from_pb(pb.uc.ok_or(Error::StateDecode)?)?, sending_ek: polynomial::PolyEncoder::from_pb(pb.sending_ek.ok_or(Error::StateDecode)?) @@ -89,6 +93,14 @@ impl EkSentCt1Received { } pub fn from_pb(pb: pqrpb::v1_state::chunked::EkSentCt1Received) -> Result { + if let Some(d) = &pb.receiving_ct2 { + if d.pts_needed as usize + != (incremental_mlkem768::CIPHERTEXT2_SIZE + authenticator::Authenticator::MACSIZE) + / 2 + { + return Err(Error::MsgDecode); + } + } Ok(Self { uc: unchunked::send_ek::EkSentCt1Received::from_pb(pb.uc.ok_or(Error::StateDecode)?)?, receiving_ct2: polynomial::PolyDecoder::from_pb( diff --git a/src/v1/chunked/states/serialize.rs b/src/v1/chunked/states/serialize.rs index 329dae9..dd0ded5 100644 --- a/src/v1/chunked/states/serialize.rs +++ b/src/v1/chunked/states/serialize.rs @@ -54,11 +54,6 @@ impl States { Self::KeysSampled(send_ek::KeysSampled::from_pb(pb)?) } Some(pqrpb::v1_state::InnerState::HeaderSent(pb)) => { - if let Some(d) = &pb.receiving_ct1 { - if d.pts_needed as usize != crate::incremental_mlkem768::CIPHERTEXT1_SIZE / 2 { - return Err(Error::MsgDecode); - } - } Self::HeaderSent(send_ek::HeaderSent::from_pb(pb)?) 
} Some(pqrpb::v1_state::InnerState::Ct1Received(pb)) => { @@ -70,49 +65,18 @@ impl States { // send_ct Some(pqrpb::v1_state::InnerState::NoHeaderReceived(pb)) => { - if let Some(rhdr) = &pb.receiving_hdr { - if rhdr.pts_needed - != ((crate::incremental_mlkem768::HEADER_SIZE - + crate::authenticator::Authenticator::MACSIZE) - / 2) as u32 - { - return Err(Error::MsgDecode); - } - } - Self::NoHeaderReceived(send_ct::NoHeaderReceived::from_pb(pb)?) } Some(pqrpb::v1_state::InnerState::HeaderReceived(pb)) => { - if let Some(d) = &pb.receiving_ek { - if d.pts_needed as usize - != crate::incremental_mlkem768::ENCAPSULATION_KEY_SIZE / 2 - { - return Err(Error::MsgDecode); - } - } Self::HeaderReceived(send_ct::HeaderReceived::from_pb(pb)?) } Some(pqrpb::v1_state::InnerState::Ct1Sampled(pb)) => { - if let Some(d) = &pb.receiving_ek { - if d.pts_needed as usize - != crate::incremental_mlkem768::ENCAPSULATION_KEY_SIZE / 2 - { - return Err(Error::MsgDecode); - } - } Self::Ct1Sampled(send_ct::Ct1Sampled::from_pb(pb)?) } Some(pqrpb::v1_state::InnerState::EkReceivedCt1Sampled(pb)) => { Self::EkReceivedCt1Sampled(send_ct::EkReceivedCt1Sampled::from_pb(pb)?) } Some(pqrpb::v1_state::InnerState::Ct1Acknowledged(pb)) => { - if let Some(d) = &pb.receiving_ek { - if d.pts_needed as usize - != crate::incremental_mlkem768::ENCAPSULATION_KEY_SIZE / 2 - { - return Err(Error::MsgDecode); - } - } Self::Ct1Acknowledged(send_ct::Ct1Acknowledged::from_pb(pb)?) } Some(pqrpb::v1_state::InnerState::Ct2Sampled(pb)) => { diff --git a/src/v1/unchunked/send_ek.rs b/src/v1/unchunked/send_ek.rs index 1606327..cb3d9aa 100644 --- a/src/v1/unchunked/send_ek.rs +++ b/src/v1/unchunked/send_ek.rs @@ -100,7 +100,9 @@ impl KeysUnsampled { } } +#[hax_lib::attributes] impl HeaderSent { + #[hax_lib::ensures(|(_, ek)| ek.len() == 1152)] pub fn send_ek(self) -> (EkSent, incremental_mlkem768::EncapsulationKey) { ( EkSent {