33use digest:: {
44 Digest , HashMarker ,
55 array:: Array ,
6- consts:: { U8 , U32 } ,
6+ consts:: { U16 , U32 } ,
77} ;
88use subtle:: CtOption ;
99
@@ -31,10 +31,8 @@ impl RistrettoPoint {
3131 let digest = D :: digest ( data) ;
3232 fe_bytes[ 0 ..32 ] . copy_from_slice ( digest. as_slice ( ) ) ;
3333 fe_bytes[ 8 ..24 ] . copy_from_slice ( data) ;
34- fe_bytes[ 0 ] &= 0b11111110 ;
35- fe_bytes[ 31 ] &= 0b00111111 ;
36- let fe = FieldElement :: from_bytes ( & fe_bytes) ;
37- RistrettoPoint :: elligator_ristretto_flavor ( & fe)
34+ // Clear the appropriate bits to make this a reduced positive field elem, and map to curve
35+ RistrettoPoint :: map_pos_felem_to_curve ( fe_bytes)
3836 }
3937
4038 /// Decode 16 bytes of data from a RistrettoPoint, using the Lizard method. Returns `None` if
@@ -70,8 +68,8 @@ impl RistrettoPoint {
7068 }
7169
7270 /// Computes the at most 8 positive FieldElements f such that `self ==
73- /// RistrettoPoint::elligator_ristretto_flavor(f)`. Assumes self is even.
74- fn elligator_ristretto_flavor_inverse ( & self ) -> [ CtOption < FieldElement > ; 8 ] {
71+ /// RistrettoPoint::elligator_ristretto_flavor(f)`.
72+ fn elligator_ristretto_flavor_inverse ( & self ) -> [ CtOption < FieldElement > ; 16 ] {
7573 // Elligator2 computes a Point from a FieldElement in two steps: first
7674 // it computes a (s,t) on the Jacobi quartic and then computes the
7775 // corresponding even point on the Edwards curve.
@@ -93,11 +91,15 @@ impl RistrettoPoint {
9391
9492 let jcs = self . to_jacobi_quartic_ristretto ( ) ;
9593
96- // Compute the inverse of the every point and its dual
97- let invs = jcs. iter ( ) . flat_map ( |jc| [ jc. e_inv ( ) , jc. dual ( ) . e_inv ( ) ] ) ;
94+ // Compute the positive solution to e⁻¹ on every point and its dual
95+ let pos_invs = jcs
96+ . iter ( )
97+ . flat_map ( |jc| [ jc. e_inv_positive ( ) , jc. dual ( ) . e_inv_positive ( ) ] ) ;
98+ // Compute the other solutions to e⁻¹, ie the negatives of the above solutions
99+ let neg_invs = pos_invs. clone ( ) . map ( |mx| mx. map ( |x| -& x) ) ;
98100 // This cannot panic because jcs is guaranteed to be size 4, and the above iterator expands
99101 // it to size 8
100- Array :: < _ , U8 > :: from_iter ( invs ) . 0
102+ Array :: < _ , U16 > :: from_iter ( pos_invs . chain ( neg_invs ) ) . 0
101103 }
102104
103105 /// Find a point on the Jacobi quartic associated to each of the four
@@ -178,33 +180,43 @@ impl RistrettoPoint {
178180 ]
179181 }
180182
181- /// Interprets the given bytestring as a positive field element and computes the Ristretto
182- /// Elligator map. Note this clears the bottom bit and top two bits of `bytes`.
183+ /// Clears bits in `bytes` to make it a positive, reduced field element, then performs
184+ /// [`RistrettoPoint::map_to_curve`]). Specifically, clears the the bottom bit (bottom bit of
185+ /// `bytes[0]`) and second-to-top bit (second-to-top bit of `bytes[31]`). The first is to ensure
186+ /// we have a positive field element and the second is to ensure we are below the modulus
187+ /// (map-to-curve clears the topmost bit for us).
188+ pub fn map_pos_felem_to_curve ( mut bytes : [ u8 ; 32 ] ) -> RistrettoPoint {
189+ bytes[ 0 ] &= 0b11111110 ;
190+ bytes[ 31 ] &= 0b10111111 ;
191+
192+ RistrettoPoint :: map_to_curve ( bytes)
193+ }
194+
195+ /// Interprets the given bytestring as a field element and computes the Ristretto Elligator map.
196+ /// This is the MAP function in
197+ /// [RFC 9496](https://www.rfc-editor.org/rfc/rfc9496.html#section-4.3.4-4).
198+ /// Note this clears the top bit (`bytes[31] & 0x80`).
183199 ///
184200 /// # Warning
185201 ///
186202 /// This function does not produce cryptographically random-looking Ristretto points. Use
187203 /// [`Self::hash_from_bytes`] for that. DO NOT USE THIS FUNCTION unless you really know what
188204 /// you're doing.
189205 pub fn map_to_curve ( mut bytes : [ u8 ; 32 ] ) -> RistrettoPoint {
190- // We only have a meaningful inverse if we give Elligator a point in its domain, ie a
191- // positive (meaning low bit 0) field element. Mask off the top two bits to ensure it's less
192- // than the modulus, and the bottom bit for evenness.
193- bytes[ 0 ] &= 0b11111110 ;
194- bytes[ 31 ] &= 0b00111111 ;
206+ bytes[ 31 ] &= 0b01111111 ;
195207
196208 let fe = FieldElement :: from_bytes ( & bytes) ;
197209 RistrettoPoint :: elligator_ristretto_flavor ( & fe)
198210 }
199211
200212 /// Computes the possible bytestrings that could have produced this point via
201213 /// [`Self::map_to_curve`].
202- pub fn map_to_curve_inverse ( & self ) -> [ CtOption < [ u8 ; 32 ] > ; 8 ] {
214+ pub fn map_to_curve_inverse ( & self ) -> [ CtOption < [ u8 ; 32 ] > ; 16 ] {
203215 // Compute the inverses
204216 let fes = self . elligator_ristretto_flavor_inverse ( ) ;
205217 // Serialize the field elements
206218 let it = fes. map ( |fe| fe. map ( |f| f. to_bytes ( ) ) ) ;
207- Array :: < _ , U8 > :: from_iter ( it) . 0
219+ Array :: < _ , U16 > :: from_iter ( it) . 0
208220 }
209221}
210222
@@ -331,7 +343,7 @@ mod test {
331343 }
332344 }
333345
334- // Tests that map_to_curve_inverse ○ map_to_curve is the identity
346+ // Tests that map_to_curve ○ map_to_curve_inverse is the identity
335347 #[ test]
336348 fn map_to_curve_inverse ( ) {
337349 let mut rng = rand:: rng ( ) ;
@@ -342,9 +354,31 @@ mod test {
342354
343355 // Map to Ristretto and invert it
344356 let pt = RistrettoPoint :: map_to_curve ( input) ;
357+
358+ let inverses = pt. map_to_curve_inverse ( ) ;
359+
360+ // Assert at least one inverse exists, and all the inverses map to pt
361+ assert ! ( inverses. iter( ) . any( |i| bool :: from( i. is_some( ) ) ) ) ;
362+ for inv in inverses. into_iter ( ) . filter_map ( CtOption :: into_option) {
363+ assert_eq ! ( pt, RistrettoPoint :: map_to_curve( inv) ) ;
364+ }
365+ }
366+ }
367+
368+ // Tests that map_to_curve_inverse ○ map_pos_felem_to_curve is the identity
369+ #[ test]
370+ fn map_pos_felem_to_curve_inverse ( ) {
371+ let mut rng = rand:: rng ( ) ;
372+
373+ for _ in 0 ..100 {
374+ let mut input = [ 0u8 ; 32 ] ;
375+ rng. fill_bytes ( & mut input) ;
376+
377+ // Map to Ristretto and invert it
378+ let pt = RistrettoPoint :: map_pos_felem_to_curve ( input) ;
345379 let inverses = pt. map_to_curve_inverse ( ) ;
346380
347- // map_to_curve masks the bottom bit and top two bits of `input`
381+ // map_pos_felem_to_curve masks the bottom bit and top two bits of `input`
348382 let mut expected_inverse = input;
349383 expected_inverse[ 31 ] &= 0b00111111 ;
350384 expected_inverse[ 0 ] &= 0b11111110 ;
0 commit comments