From f0186fa421aab5d09328dbe3d726b3be046b6c3b Mon Sep 17 00:00:00 2001 From: daxpedda Date: Fri, 11 Apr 2025 21:08:12 +0200 Subject: [PATCH 1/3] Various fixes and improvements to hash2curve --- elliptic-curve/src/hash2curve/hash2field.rs | 14 ++- .../src/hash2curve/hash2field/expand_msg.rs | 4 +- .../hash2curve/hash2field/expand_msg/xmd.rs | 55 ++++++++---- .../hash2curve/hash2field/expand_msg/xof.rs | 88 +++++++++++++------ 4 files changed, 111 insertions(+), 50 deletions(-) diff --git a/elliptic-curve/src/hash2curve/hash2field.rs b/elliptic-curve/src/hash2curve/hash2field.rs index 946c8a39d..4ffcf8062 100644 --- a/elliptic-curve/src/hash2curve/hash2field.rs +++ b/elliptic-curve/src/hash2curve/hash2field.rs @@ -4,15 +4,20 @@ mod expand_msg; +use core::num::NonZeroUsize; + pub use expand_msg::{xmd::*, xof::*, *}; use crate::{Error, Result}; -use hybrid_array::{Array, ArraySize, typenum::Unsigned}; +use hybrid_array::{ + Array, ArraySize, + typenum::{NonZero, Unsigned}, +}; /// The trait for helping to convert to a field element. pub trait FromOkm { /// The number of bytes needed to convert to a field element. - type Length: ArraySize; + type Length: ArraySize + NonZero; /// Convert a byte sequence into a field element. fn from_okm(data: &Array) -> Self; @@ -37,7 +42,10 @@ where E: ExpandMsg<'a>, T: FromOkm + Default, { - let len_in_bytes = T::Length::to_usize().checked_mul(out.len()).ok_or(Error)?; + let len_in_bytes = T::Length::to_usize() + .checked_mul(out.len()) + .and_then(NonZeroUsize::new) + .ok_or(Error)?; let mut tmp = Array::::Length>::default(); let mut expander = E::expand_message(data, domain, len_in_bytes)?; for o in out.iter_mut() { diff --git a/elliptic-curve/src/hash2curve/hash2field/expand_msg.rs b/elliptic-curve/src/hash2curve/hash2field/expand_msg.rs index 510ce5b2f..444dbf72f 100644 --- a/elliptic-curve/src/hash2curve/hash2field/expand_msg.rs +++ b/elliptic-curve/src/hash2curve/hash2field/expand_msg.rs @@ -3,6 +3,8 @@ pub(super) mod xmd; pub(super) mod xof; +use core::num::NonZero; + use crate::{Error, Result}; use digest::{Digest, ExtendableOutput, Update, XofReader}; use hybrid_array::typenum::{IsLess, U256}; @@ -28,7 +30,7 @@ pub trait ExpandMsg<'a> { fn expand_message( msgs: &[&[u8]], dsts: &'a [&'a [u8]], - len_in_bytes: usize, + len_in_bytes: NonZero, ) -> Result; } diff --git a/elliptic-curve/src/hash2curve/hash2field/expand_msg/xmd.rs b/elliptic-curve/src/hash2curve/hash2field/expand_msg/xmd.rs index c1cb250b7..c981cad41 100644 --- a/elliptic-curve/src/hash2curve/hash2field/expand_msg/xmd.rs +++ b/elliptic-curve/src/hash2curve/hash2field/expand_msg/xmd.rs @@ -1,6 +1,6 @@ //! `expand_message_xmd` based on a hash function. -use core::marker::PhantomData; +use core::{marker::PhantomData, num::NonZero, ops::Mul}; use super::{Domain, ExpandMsg, Expander}; use crate::{Error, Result}; @@ -8,52 +8,64 @@ use digest::{ FixedOutput, HashMarker, array::{ Array, - typenum::{IsLess, IsLessOrEqual, U256, Unsigned}, + typenum::{IsGreaterOrEqual, IsLess, IsLessOrEqual, U2, U8, U256, Unsigned}, }, core_api::BlockSizeUser, }; -/// Placeholder type for implementing `expand_message_xmd` based on a hash function +/// Implements `expand_message_xof` via the [`ExpandMsg`] trait: +/// +/// +/// `K` is the target security level in bits: +/// +/// /// /// # Errors /// - `dst.is_empty()` -/// - `len_in_bytes == 0` /// - `len_in_bytes > u16::MAX` /// - `len_in_bytes > 255 * HashT::OutputSize` #[derive(Debug)] -pub struct ExpandMsgXmd(PhantomData) +pub struct ExpandMsgXmd(PhantomData<(HashT, K)>) where HashT: BlockSizeUser + Default + FixedOutput + HashMarker, HashT::OutputSize: IsLess, - HashT::OutputSize: IsLessOrEqual; + HashT::OutputSize: IsLessOrEqual, + HashT::OutputSize: Mul, + U2: Mul, + >::Output: IsGreaterOrEqual<>::Output>; -/// ExpandMsgXmd implements expand_message_xmd for the ExpandMsg trait -impl<'a, HashT> ExpandMsg<'a> for ExpandMsgXmd +impl<'a, HashT, K> ExpandMsg<'a> for ExpandMsgXmd where HashT: BlockSizeUser + Default + FixedOutput + HashMarker, - // If `len_in_bytes` is bigger then 256, length of the `DST` will depend on - // the output size of the hash, which is still not allowed to be bigger then 256: + // If DST is larger than 255 bytes, the length of the computed DST will depend on the output + // size of the hash, which is still not allowed to be larger than 256: // https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-13.html#section-5.4.1-6 HashT::OutputSize: IsLess, // Constraint set by `expand_message_xmd`: // https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-13.html#section-5.4.1-4 HashT::OutputSize: IsLessOrEqual, + // The number of bits output by `HashT` MUST be larger or equal to `2 * K`: + // https://www.rfc-editor.org/rfc/rfc9380.html#section-5.3.1-2.1 + HashT::OutputSize: Mul, + U2: Mul, + >::Output: IsGreaterOrEqual<>::Output>, { type Expander = ExpanderXmd<'a, HashT>; fn expand_message( msgs: &[&[u8]], dsts: &'a [&'a [u8]], - len_in_bytes: usize, + len_in_bytes: NonZero, ) -> Result { - if len_in_bytes == 0 { + let len_in_bytes_u16 = u16::try_from(len_in_bytes.get()).map_err(|_| Error)?; + + // `255 * ` can not exceed `u16::MAX` + if len_in_bytes_u16 > 255 * HashT::OutputSize::to_u16() { return Err(Error); } - let len_in_bytes_u16 = u16::try_from(len_in_bytes).map_err(|_| Error)?; - let b_in_bytes = HashT::OutputSize::to_usize(); - let ell = u8::try_from(len_in_bytes.div_ceil(b_in_bytes)).map_err(|_| Error)?; + let ell = u8::try_from(len_in_bytes.get().div_ceil(b_in_bytes)).map_err(|_| Error)?; let domain = Domain::xmd::(dsts)?; let mut b_0 = HashT::default(); @@ -157,7 +169,7 @@ mod test { use hex_literal::hex; use hybrid_array::{ ArraySize, - typenum::{U32, U128}, + typenum::{U8, U32, U128}, }; use sha2::Sha256; @@ -209,13 +221,18 @@ mod test { ) -> Result<()> where HashT: BlockSizeUser + Default + FixedOutput + HashMarker, - HashT::OutputSize: IsLess + IsLessOrEqual, + HashT::OutputSize: IsLess + IsLessOrEqual + Mul, + U2: Mul, + >::Output: IsGreaterOrEqual<>::Output>, { assert_message::(self.msg, domain, L::to_u16(), self.msg_prime); let dst = [dst]; - let mut expander = - ExpandMsgXmd::::expand_message(&[self.msg], &dst, L::to_usize())?; + let mut expander = ExpandMsgXmd::::expand_message( + &[self.msg], + &dst, + NonZero::new(L::to_usize()).ok_or(Error)?, + )?; let mut uniform_bytes = Array::::default(); expander.fill_bytes(&mut uniform_bytes); diff --git a/elliptic-curve/src/hash2curve/hash2field/expand_msg/xof.rs b/elliptic-curve/src/hash2curve/hash2field/expand_msg/xof.rs index 6a5c14621..b56351223 100644 --- a/elliptic-curve/src/hash2curve/hash2field/expand_msg/xof.rs +++ b/elliptic-curve/src/hash2curve/hash2field/expand_msg/xof.rs @@ -2,26 +2,45 @@ use super::{Domain, ExpandMsg, Expander}; use crate::{Error, Result}; -use core::fmt; -use digest::{ExtendableOutput, Update, XofReader}; -use hybrid_array::typenum::U32; - -/// Placeholder type for implementing `expand_message_xof` based on an extendable output function +use core::{ + fmt, + marker::PhantomData, + num::NonZero, + ops::{Div, Mul}, +}; +use digest::{ExtendableOutput, HashMarker, Update, XofReader}; +use hybrid_array::{ + ArraySize, + typenum::{IsLess, U2, U8, U256}, +}; + +/// Implements `expand_message_xof` via the [`ExpandMsg`] trait: +/// +/// +/// `K` is the target security level in bits: +/// +/// /// /// # Errors /// - `dst.is_empty()` -/// - `len_in_bytes == 0` /// - `len_in_bytes > u16::MAX` -pub struct ExpandMsgXof +pub struct ExpandMsgXof where - HashT: Default + ExtendableOutput + Update, + HashT: Default + ExtendableOutput + Update + HashMarker, + U2: Mul, + >::Output: Div, + HashSize: ArraySize + IsLess, { reader: ::Reader, + _k: PhantomData, } -impl fmt::Debug for ExpandMsgXof +impl fmt::Debug for ExpandMsgXof where - HashT: Default + ExtendableOutput + Update, + HashT: Default + ExtendableOutput + Update + HashMarker, + U2: Mul, + >::Output: Div, + HashSize: ArraySize + IsLess, ::Reader: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -31,25 +50,28 @@ where } } -/// ExpandMsgXof implements `expand_message_xof` for the [`ExpandMsg`] trait -impl<'a, HashT> ExpandMsg<'a> for ExpandMsgXof +type HashSize = <>::Output as Div>::Output; + +impl<'a, HashT, K> ExpandMsg<'a> for ExpandMsgXof where - HashT: Default + ExtendableOutput + Update, + HashT: Default + ExtendableOutput + Update + HashMarker, + // If DST is larger than 255 bytes, the length of the computed DST is calculated by + // `2 * k / 8`. + // https://www.rfc-editor.org/rfc/rfc9380.html#section-5.3.1-2.1 + U2: Mul, + >::Output: Div, + HashSize: ArraySize + IsLess, { type Expander = Self; fn expand_message( msgs: &[&[u8]], dsts: &'a [&'a [u8]], - len_in_bytes: usize, + len_in_bytes: NonZero, ) -> Result { - if len_in_bytes == 0 { - return Err(Error); - } - - let len_in_bytes = u16::try_from(len_in_bytes).map_err(|_| Error)?; + let len_in_bytes = u16::try_from(len_in_bytes.get()).map_err(|_| Error)?; - let domain = Domain::::xof::(dsts)?; + let domain = Domain::>::xof::(dsts)?; let mut reader = HashT::default(); for msg in msgs { @@ -60,13 +82,19 @@ where domain.update_hash(&mut reader); reader.update(&[domain.len()]); let reader = reader.finalize_xof(); - Ok(Self { reader }) + Ok(Self { + reader, + _k: PhantomData, + }) } } -impl Expander for ExpandMsgXof +impl Expander for ExpandMsgXof where - HashT: Default + ExtendableOutput + Update, + HashT: Default + ExtendableOutput + Update + HashMarker, + U2: Mul, + >::Output: Div, + HashSize: ArraySize + IsLess, { fn fill_bytes(&mut self, okm: &mut [u8]) { self.reader.read(okm); @@ -78,7 +106,10 @@ mod test { use super::*; use core::mem::size_of; use hex_literal::hex; - use hybrid_array::{Array, ArraySize, typenum::U128}; + use hybrid_array::{ + Array, ArraySize, + typenum::{U32, U128}, + }; use sha3::Shake128; fn assert_message(msg: &[u8], domain: &Domain<'_, U32>, len_in_bytes: u16, bytes: &[u8]) { @@ -110,13 +141,16 @@ mod test { #[allow(clippy::panic_in_result_fn)] fn assert(&self, dst: &'static [u8], domain: &Domain<'_, U32>) -> Result<()> where - HashT: Default + ExtendableOutput + Update, + HashT: Default + ExtendableOutput + Update + HashMarker, L: ArraySize, { assert_message(self.msg, domain, L::to_u16(), self.msg_prime); - let mut expander = - ExpandMsgXof::::expand_message(&[self.msg], &[dst], L::to_usize())?; + let mut expander = ExpandMsgXof::::expand_message( + &[self.msg], + &[dst], + NonZero::new(L::to_usize()).ok_or(Error)?, + )?; let mut uniform_bytes = Array::::default(); expander.fill_bytes(&mut uniform_bytes); From 76a991a37802377c83420042688b435b5cb85556 Mon Sep 17 00:00:00 2001 From: daxpedda Date: Fri, 11 Apr 2025 22:21:51 +0200 Subject: [PATCH 2/3] Add `K` to `GroupDigest` --- elliptic-curve/src/hash2curve/group_digest.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/elliptic-curve/src/hash2curve/group_digest.rs b/elliptic-curve/src/hash2curve/group_digest.rs index 2a663a11b..de65fab51 100644 --- a/elliptic-curve/src/hash2curve/group_digest.rs +++ b/elliptic-curve/src/hash2curve/group_digest.rs @@ -3,6 +3,7 @@ use super::{ExpandMsg, FromOkm, MapToCurve, hash_to_field}; use crate::{CurveArithmetic, ProjectivePoint, Result}; use group::cofactor::CofactorGroup; +use hybrid_array::typenum::Unsigned; /// Adds hashing arbitrary byte sequences to a valid group element pub trait GroupDigest: CurveArithmetic @@ -12,6 +13,11 @@ where /// The field element representation for a group value with multiple elements type FieldElement: FromOkm + MapToCurve> + Default + Copy; + /// The target security level in bits: + /// + /// + type K: Unsigned; + /// Computes the hash to curve routine. /// /// From : From 2f0bf9b9d21bd9ec069443b046edebf79e464bd8 Mon Sep 17 00:00:00 2001 From: daxpedda Date: Thu, 17 Apr 2025 16:04:41 +0200 Subject: [PATCH 3/3] Change `K` to represent bytes instead of bits --- elliptic-curve/src/hash2curve/group_digest.rs | 2 +- .../hash2curve/hash2field/expand_msg/xmd.rs | 23 +++++----- .../hash2curve/hash2field/expand_msg/xof.rs | 42 +++++++------------ 3 files changed, 26 insertions(+), 41 deletions(-) diff --git a/elliptic-curve/src/hash2curve/group_digest.rs b/elliptic-curve/src/hash2curve/group_digest.rs index de65fab51..8fa27c46c 100644 --- a/elliptic-curve/src/hash2curve/group_digest.rs +++ b/elliptic-curve/src/hash2curve/group_digest.rs @@ -13,7 +13,7 @@ where /// The field element representation for a group value with multiple elements type FieldElement: FromOkm + MapToCurve> + Default + Copy; - /// The target security level in bits: + /// The target security level in bytes: /// /// type K: Unsigned; diff --git a/elliptic-curve/src/hash2curve/hash2field/expand_msg/xmd.rs b/elliptic-curve/src/hash2curve/hash2field/expand_msg/xmd.rs index c981cad41..e98843e38 100644 --- a/elliptic-curve/src/hash2curve/hash2field/expand_msg/xmd.rs +++ b/elliptic-curve/src/hash2curve/hash2field/expand_msg/xmd.rs @@ -8,7 +8,7 @@ use digest::{ FixedOutput, HashMarker, array::{ Array, - typenum::{IsGreaterOrEqual, IsLess, IsLessOrEqual, U2, U8, U256, Unsigned}, + typenum::{IsGreaterOrEqual, IsLess, IsLessOrEqual, U2, U256, Unsigned}, }, core_api::BlockSizeUser, }; @@ -16,7 +16,7 @@ use digest::{ /// Implements `expand_message_xof` via the [`ExpandMsg`] trait: /// /// -/// `K` is the target security level in bits: +/// `K` is the target security level in bytes: /// /// /// @@ -30,9 +30,8 @@ where HashT: BlockSizeUser + Default + FixedOutput + HashMarker, HashT::OutputSize: IsLess, HashT::OutputSize: IsLessOrEqual, - HashT::OutputSize: Mul, - U2: Mul, - >::Output: IsGreaterOrEqual<>::Output>; + K: Mul, + HashT::OutputSize: IsGreaterOrEqual<>::Output>; impl<'a, HashT, K> ExpandMsg<'a> for ExpandMsgXmd where @@ -44,11 +43,10 @@ where // Constraint set by `expand_message_xmd`: // https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-13.html#section-5.4.1-4 HashT::OutputSize: IsLessOrEqual, - // The number of bits output by `HashT` MUST be larger or equal to `2 * K`: + // The number of bits output by `HashT` MUST be larger or equal to `K * 2`: // https://www.rfc-editor.org/rfc/rfc9380.html#section-5.3.1-2.1 - HashT::OutputSize: Mul, - U2: Mul, - >::Output: IsGreaterOrEqual<>::Output>, + K: Mul, + HashT::OutputSize: IsGreaterOrEqual<>::Output>, { type Expander = ExpanderXmd<'a, HashT>; @@ -169,7 +167,7 @@ mod test { use hex_literal::hex; use hybrid_array::{ ArraySize, - typenum::{U8, U32, U128}, + typenum::{U4, U8, U32, U128}, }; use sha2::Sha256; @@ -222,13 +220,12 @@ mod test { where HashT: BlockSizeUser + Default + FixedOutput + HashMarker, HashT::OutputSize: IsLess + IsLessOrEqual + Mul, - U2: Mul, - >::Output: IsGreaterOrEqual<>::Output>, + HashT::OutputSize: IsGreaterOrEqual<>::Output>, { assert_message::(self.msg, domain, L::to_u16(), self.msg_prime); let dst = [dst]; - let mut expander = ExpandMsgXmd::::expand_message( + let mut expander = ExpandMsgXmd::::expand_message( &[self.msg], &dst, NonZero::new(L::to_usize()).ok_or(Error)?, diff --git a/elliptic-curve/src/hash2curve/hash2field/expand_msg/xof.rs b/elliptic-curve/src/hash2curve/hash2field/expand_msg/xof.rs index b56351223..9d40ed2c1 100644 --- a/elliptic-curve/src/hash2curve/hash2field/expand_msg/xof.rs +++ b/elliptic-curve/src/hash2curve/hash2field/expand_msg/xof.rs @@ -2,22 +2,17 @@ use super::{Domain, ExpandMsg, Expander}; use crate::{Error, Result}; -use core::{ - fmt, - marker::PhantomData, - num::NonZero, - ops::{Div, Mul}, -}; +use core::{fmt, marker::PhantomData, num::NonZero, ops::Mul}; use digest::{ExtendableOutput, HashMarker, Update, XofReader}; use hybrid_array::{ ArraySize, - typenum::{IsLess, U2, U8, U256}, + typenum::{IsLess, U2, U256}, }; /// Implements `expand_message_xof` via the [`ExpandMsg`] trait: /// /// -/// `K` is the target security level in bits: +/// `K` is the target security level in bytes: /// /// /// @@ -27,9 +22,8 @@ use hybrid_array::{ pub struct ExpandMsgXof where HashT: Default + ExtendableOutput + Update + HashMarker, - U2: Mul, - >::Output: Div, - HashSize: ArraySize + IsLess, + K: Mul, + >::Output: ArraySize + IsLess, { reader: ::Reader, _k: PhantomData, @@ -38,9 +32,8 @@ where impl fmt::Debug for ExpandMsgXof where HashT: Default + ExtendableOutput + Update + HashMarker, - U2: Mul, - >::Output: Div, - HashSize: ArraySize + IsLess, + K: Mul, + >::Output: ArraySize + IsLess, ::Reader: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -50,17 +43,13 @@ where } } -type HashSize = <>::Output as Div>::Output; - impl<'a, HashT, K> ExpandMsg<'a> for ExpandMsgXof where HashT: Default + ExtendableOutput + Update + HashMarker, - // If DST is larger than 255 bytes, the length of the computed DST is calculated by - // `2 * k / 8`. + // If DST is larger than 255 bytes, the length of the computed DST is calculated by `K * 2`. // https://www.rfc-editor.org/rfc/rfc9380.html#section-5.3.1-2.1 - U2: Mul, - >::Output: Div, - HashSize: ArraySize + IsLess, + K: Mul, + >::Output: ArraySize + IsLess, { type Expander = Self; @@ -71,7 +60,7 @@ where ) -> Result { let len_in_bytes = u16::try_from(len_in_bytes.get()).map_err(|_| Error)?; - let domain = Domain::>::xof::(dsts)?; + let domain = Domain::<>::Output>::xof::(dsts)?; let mut reader = HashT::default(); for msg in msgs { @@ -92,9 +81,8 @@ where impl Expander for ExpandMsgXof where HashT: Default + ExtendableOutput + Update + HashMarker, - U2: Mul, - >::Output: Div, - HashSize: ArraySize + IsLess, + K: Mul, + >::Output: ArraySize + IsLess, { fn fill_bytes(&mut self, okm: &mut [u8]) { self.reader.read(okm); @@ -108,7 +96,7 @@ mod test { use hex_literal::hex; use hybrid_array::{ Array, ArraySize, - typenum::{U32, U128}, + typenum::{U16, U32, U128}, }; use sha3::Shake128; @@ -146,7 +134,7 @@ mod test { { assert_message(self.msg, domain, L::to_u16(), self.msg_prime); - let mut expander = ExpandMsgXof::::expand_message( + let mut expander = ExpandMsgXof::::expand_message( &[self.msg], &[dst], NonZero::new(L::to_usize()).ok_or(Error)?,