Skip to content

Commit 00ff754

Browse files
committed
C: Public Key from Secret Key function
Refactor keygen to use a new function that derives t0 t1 tr pk from rho s1 s2 so that this function can also be called by a utility function pk_to_sk that generates the pk given the sk. We also include ct_memcmp for constant time comparison. Signed-off-by: Jake Massimo <[email protected]>
1 parent c6d7c93 commit 00ff754

File tree

14 files changed

+531
-42
lines changed

14 files changed

+531
-42
lines changed

mldsa/mldsa_native.S

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,7 @@
321321
#undef crypto_sign_keypair
322322
#undef crypto_sign_keypair_internal
323323
#undef crypto_sign_open
324+
#undef crypto_sign_pk_from_sk
324325
#undef crypto_sign_signature
325326
#undef crypto_sign_signature_extmu
326327
#undef crypto_sign_signature_internal

mldsa/mldsa_native.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@
318318
#undef crypto_sign_keypair
319319
#undef crypto_sign_keypair_internal
320320
#undef crypto_sign_open
321+
#undef crypto_sign_pk_from_sk
321322
#undef crypto_sign_signature
322323
#undef crypto_sign_signature_extmu
323324
#undef crypto_sign_signature_internal

mldsa/mldsa_native.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,25 @@ size_t MLD_API_NAMESPACE(prepare_domain_separation_prefix)(
615615
uint8_t prefix[MLD_DOMAIN_SEPARATION_MAX_BYTES], const uint8_t *ph,
616616
size_t phlen, const uint8_t *ctx, size_t ctxlen, int hashalg);
617617

618+
/*************************************************
619+
* Name: crypto_sign_pk_from_sk
620+
*
621+
* Description: Derives public key from secret key with validation.
622+
* Checks that t0 and tr stored in sk match recomputed values.
623+
*
624+
* Arguments:
625+
* - uint8_t pk[MLDSA_PUBLICKEYBYTES(MLD_CONFIG_API_PARAMETER_SET)]:
626+
* output public key
627+
* - const uint8_t sk[MLDSA_SECRETKEYBYTES(MLD_CONFIG_API_PARAMETER_SET)]:
628+
* input secret key
629+
*
630+
* Returns 0 on success, -1 if validation fails (invalid secret key)
631+
**************************************************/
632+
MLD_API_MUST_CHECK_RETURN_VALUE
633+
int MLD_API_NAMESPACE(pk_from_sk)(
634+
uint8_t pk[MLDSA_PUBLICKEYBYTES(MLD_CONFIG_API_PARAMETER_SET)],
635+
const uint8_t sk[MLDSA_SECRETKEYBYTES(MLD_CONFIG_API_PARAMETER_SET)]);
636+
618637
/****************************** SUPERCOP API *********************************/
619638

620639
#if !defined(MLD_CONFIG_API_NO_SUPERCOP)

mldsa/src/ct.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,48 @@ __contract__(
240240
#if !defined(__ASSEMBLER__)
241241
#include <string.h>
242242

243+
/*************************************************
244+
* Name: mld_ct_memcmp
245+
*
246+
* Description: Compare two arrays for equality in constant time.
247+
*
248+
* Arguments: const void *a: pointer to first byte array
249+
* const void *b: pointer to second byte array
250+
* size_t len: length of the byte arrays
251+
*
252+
* Returns 0 if the byte arrays are equal, a non-zero value otherwise
253+
**************************************************/
254+
static MLD_INLINE uint8_t mld_ct_memcmp(const void *a, const void *b,
255+
const size_t len)
256+
__contract__(
257+
requires(len <= UINT16_MAX)
258+
requires(memory_no_alias(a, len))
259+
requires(memory_no_alias(b, len))
260+
ensures((return_value == 0) == forall(i, 0, len, (((const uint8_t *)a)[i] == ((const uint8_t *)b)[i])))
261+
)
262+
{
263+
const uint8_t *a_bytes = (const uint8_t *)a;
264+
const uint8_t *b_bytes = (const uint8_t *)b;
265+
uint8_t r = 0, s = 0;
266+
unsigned i;
267+
268+
for (i = 0; i < len; i++)
269+
__loop__(
270+
invariant(i <= len)
271+
invariant((r == 0) == (forall(k, 0, i, (a_bytes[k] == b_bytes[k])))))
272+
{
273+
r |= a_bytes[i] ^ b_bytes[i];
274+
/* s is useless, but prevents the loop from being aborted once r=0xff. */
275+
s ^= a_bytes[i] ^ b_bytes[i];
276+
}
277+
278+
/*
279+
* XOR twice with s, separated by a value barrier, to prevent the compile
280+
* from dropping the s computation in the loop.
281+
*/
282+
return (uint8_t)((mld_value_barrier_u32((uint32_t)r) ^ s) ^ s);
283+
}
284+
243285
/*************************************************
244286
* Name: mld_zeroize
245287
*

mldsa/src/sign.c

Lines changed: 146 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <string.h>
2929

3030
#include "cbmc.h"
31+
#include "ct.h"
3132
#include "debug.h"
3233
#include "packing.h"
3334
#include "poly.h"
@@ -48,6 +49,8 @@
4849
#define mld_H MLD_ADD_PARAM_SET(mld_H)
4950
#define mld_attempt_signature_generation \
5051
MLD_ADD_PARAM_SET(mld_attempt_signature_generation)
52+
#define mld_compute_t0_t1_tr_from_sk_components \
53+
MLD_ADD_PARAM_SET(mld_compute_t0_t1_tr_from_sk_components)
5154
/* End of parameter set namespacing */
5255

5356

@@ -174,6 +177,85 @@ __contract__(
174177
#endif /* !MLD_CONFIG_SERIAL_FIPS202_ONLY */
175178
}
176179

180+
/*************************************************
181+
* Name: mld_compute_t0_t1_tr_from_sk_components
182+
*
183+
* Description: Computes t0, t1, tr, and pk from secret key components
184+
* rho, s1, s2. This is the shared computation used by
185+
* both keygen and generating the public key from the
186+
* secret key.
187+
*
188+
* Arguments: - mld_polyveck *t0: output t0
189+
* - mld_polyveck *t1: output t1
190+
* - uint8_t tr[MLDSA_TRBYTES]: output tr
191+
* - uint8_t pk[CRYPTO_PUBLICKEYBYTES]: output public key
192+
* - const uint8_t rho[MLDSA_SEEDBYTES]: input rho
193+
* - const mld_polyvecl *s1: input s1
194+
* - const mld_polyveck *s2: input s2
195+
**************************************************/
196+
static void mld_compute_t0_t1_tr_from_sk_components(
197+
mld_polyveck *t0, mld_polyveck *t1, uint8_t tr[MLDSA_TRBYTES],
198+
uint8_t pk[CRYPTO_PUBLICKEYBYTES], const uint8_t rho[MLDSA_SEEDBYTES],
199+
const mld_polyvecl *s1, const mld_polyveck *s2)
200+
__contract__(
201+
requires(memory_no_alias(t0, sizeof(mld_polyveck)))
202+
requires(memory_no_alias(t1, sizeof(mld_polyveck)))
203+
requires(memory_no_alias(tr, MLDSA_TRBYTES))
204+
requires(memory_no_alias(pk, CRYPTO_PUBLICKEYBYTES))
205+
requires(memory_no_alias(rho, MLDSA_SEEDBYTES))
206+
requires(memory_no_alias(s1, sizeof(mld_polyvecl)))
207+
requires(memory_no_alias(s2, sizeof(mld_polyveck)))
208+
requires(forall(l0, 0, MLDSA_L, array_bound(s1->vec[l0].coeffs, 0, MLDSA_N, MLD_POLYETA_UNPACK_LOWER_BOUND, MLDSA_ETA + 1)))
209+
requires(forall(k0, 0, MLDSA_K, array_bound(s2->vec[k0].coeffs, 0, MLDSA_N, MLD_POLYETA_UNPACK_LOWER_BOUND, MLDSA_ETA + 1)))
210+
assigns(memory_slice(t0, sizeof(mld_polyveck)))
211+
assigns(memory_slice(t1, sizeof(mld_polyveck)))
212+
assigns(memory_slice(tr, MLDSA_TRBYTES))
213+
assigns(memory_slice(pk, CRYPTO_PUBLICKEYBYTES))
214+
ensures(forall(k1, 0, MLDSA_K, array_bound(t0->vec[k1].coeffs, 0, MLDSA_N, -(1<<(MLDSA_D-1)) + 1, (1<<(MLDSA_D-1)) + 1)))
215+
ensures(forall(k2, 0, MLDSA_K, array_bound(t1->vec[k2].coeffs, 0, MLDSA_N, 0, 1 << 10)))
216+
)
217+
{
218+
mld_polyvecl mat[MLDSA_K], s1hat;
219+
mld_polyveck t;
220+
221+
/* Expand matrix */
222+
mld_polyvec_matrix_expand(mat, rho);
223+
224+
/* Matrix-vector multiplication */
225+
s1hat = *s1;
226+
mld_polyvecl_ntt(&s1hat);
227+
mld_polyvec_matrix_pointwise_montgomery(&t, mat, &s1hat);
228+
mld_polyveck_reduce(&t);
229+
mld_polyveck_invntt_tomont(&t);
230+
231+
/* Add error vector s2 */
232+
mld_polyveck_add(&t, s2);
233+
234+
/* Reference: The following reduction is not present in the reference
235+
* implementation. Omitting this reduction requires the output of
236+
* the invntt to be small enough such that the addition of s2 does
237+
* not result in absolute values >= MLDSA_Q. While our C, x86_64,
238+
* and AArch64 invntt implementations produce small enough
239+
* values for this to work out, it complicates the bounds
240+
* reasoning. We instead add an additional reduction, and can
241+
* consequently, relax the bounds requirements for the invntt.
242+
*/
243+
mld_polyveck_reduce(&t);
244+
245+
/* Decompose to get t1, t0 */
246+
mld_polyveck_caddq(&t);
247+
mld_polyveck_power2round(t1, t0, &t);
248+
249+
/* Pack public key and compute tr */
250+
mld_pack_pk(pk, rho, t1);
251+
mld_shake256(tr, MLDSA_TRBYTES, pk, CRYPTO_PUBLICKEYBYTES);
252+
253+
/* @[FIPS204, Section 3.6.3] Destruction of intermediate values. */
254+
mld_zeroize(mat, sizeof(mat));
255+
mld_zeroize(&s1hat, sizeof(s1hat));
256+
mld_zeroize(&t, sizeof(t));
257+
}
258+
177259
MLD_MUST_CHECK_RETURN_VALUE
178260
MLD_EXTERNAL_API
179261
int crypto_sign_keypair_internal(uint8_t pk[CRYPTO_PUBLICKEYBYTES],
@@ -184,9 +266,8 @@ int crypto_sign_keypair_internal(uint8_t pk[CRYPTO_PUBLICKEYBYTES],
184266
MLD_ALIGN uint8_t inbuf[MLDSA_SEEDBYTES + 2];
185267
MLD_ALIGN uint8_t tr[MLDSA_TRBYTES];
186268
const uint8_t *rho, *rhoprime, *key;
187-
mld_polyvecl mat[MLDSA_K];
188-
mld_polyvecl s1, s1hat;
189-
mld_polyveck s2, t2, t1, t0;
269+
mld_polyvecl s1;
270+
mld_polyveck s2, t1, t0;
190271

191272
/* Get randomness for rho, rhoprime and key */
192273
mld_memcpy(inbuf, seed, MLDSA_SEEDBYTES);
@@ -200,50 +281,23 @@ int crypto_sign_keypair_internal(uint8_t pk[CRYPTO_PUBLICKEYBYTES],
200281

201282
/* Constant time: rho is part of the public key and, hence, public. */
202283
MLD_CT_TESTING_DECLASSIFY(rho, MLDSA_SEEDBYTES);
203-
/* Expand matrix */
204-
mld_polyvec_matrix_expand(mat, rho);
205-
mld_sample_s1_s2(&s1, &s2, rhoprime);
206-
207-
/* Matrix-vector multiplication */
208-
s1hat = s1;
209-
mld_polyvecl_ntt(&s1hat);
210-
mld_polyvec_matrix_pointwise_montgomery(&t1, mat, &s1hat);
211-
mld_polyveck_reduce(&t1);
212-
mld_polyveck_invntt_tomont(&t1);
213-
214-
/* Add error vector s2 */
215-
mld_polyveck_add(&t1, &s2);
216284

217-
/* Reference: The following reduction is not present in the reference
218-
* implementation. Omitting this reduction requires the output of
219-
* the invntt to be small enough such that the addition of s2 does
220-
* not result in absolute values >= MLDSA_Q. While our C, x86_64,
221-
* and AArch64 invntt implementations produce small enough
222-
* values for this to work out, it complicates the bounds
223-
* reasoning. We instead add an additional reduction, and can
224-
* consequently, relax the bounds requirements for the invntt.
225-
*/
226-
mld_polyveck_reduce(&t1);
285+
/* Sample s1 and s2 */
286+
mld_sample_s1_s2(&s1, &s2, rhoprime);
227287

228-
/* Extract t1 and write public key */
229-
mld_polyveck_caddq(&t1);
230-
mld_polyveck_power2round(&t2, &t0, &t1);
231-
mld_pack_pk(pk, rho, &t2);
288+
/* Compute t0, t1, tr, and pk from rho, s1, s2 */
289+
mld_compute_t0_t1_tr_from_sk_components(&t0, &t1, tr, pk, rho, &s1, &s2);
232290

233-
/* Compute H(rho, t1) and write secret key */
234-
mld_shake256(tr, MLDSA_TRBYTES, pk, CRYPTO_PUBLICKEYBYTES);
291+
/* Pack secret key */
235292
mld_pack_sk(sk, rho, tr, key, &t0, &s1, &s2);
236293

237294
/* @[FIPS204, Section 3.6.3] Destruction of intermediate values. */
238295
mld_zeroize(seedbuf, sizeof(seedbuf));
239296
mld_zeroize(inbuf, sizeof(inbuf));
240297
mld_zeroize(tr, sizeof(tr));
241-
mld_zeroize(mat, sizeof(mat));
242298
mld_zeroize(&s1, sizeof(s1));
243-
mld_zeroize(&s1hat, sizeof(s1hat));
244299
mld_zeroize(&s2, sizeof(s2));
245300
mld_zeroize(&t1, sizeof(t1));
246-
mld_zeroize(&t2, sizeof(t2));
247301
mld_zeroize(&t0, sizeof(t0));
248302

249303
/* Constant time: pk is the public key, inherently public data */
@@ -1130,6 +1184,62 @@ size_t mld_prepare_domain_separation_prefix(
11301184
return 2 + ctxlen + MLD_PRE_HASH_OID_LEN + phlen;
11311185
}
11321186

1187+
MLD_EXTERNAL_API
1188+
int crypto_sign_pk_from_sk(uint8_t pk[CRYPTO_PUBLICKEYBYTES],
1189+
const uint8_t sk[CRYPTO_SECRETKEYBYTES])
1190+
{
1191+
MLD_ALIGN uint8_t rho[MLDSA_SEEDBYTES];
1192+
MLD_ALIGN uint8_t tr[MLDSA_TRBYTES];
1193+
MLD_ALIGN uint8_t tr_computed[MLDSA_TRBYTES];
1194+
MLD_ALIGN uint8_t key[MLDSA_SEEDBYTES];
1195+
mld_polyvecl s1;
1196+
mld_polyveck s2, t0, t0_computed, t1;
1197+
int res;
1198+
1199+
/* Unpack secret key */
1200+
mld_unpack_sk(rho, tr, key, &t0, &s1, &s2, sk);
1201+
1202+
/* Recompute t0, t1, tr, and pk from rho, s1, s2 */
1203+
mld_compute_t0_t1_tr_from_sk_components(&t0_computed, &t1, tr_computed, pk,
1204+
rho, &s1, &s2);
1205+
1206+
/* Declassify public key */
1207+
MLD_CT_TESTING_DECLASSIFY(pk, CRYPTO_PUBLICKEYBYTES);
1208+
1209+
/* Validate t0 using constant-time comparison */
1210+
res = mld_ct_memcmp(&t0, &t0_computed, sizeof(mld_polyveck));
1211+
/* Declassify comparison result */
1212+
MLD_CT_TESTING_DECLASSIFY(&res, sizeof(res));
1213+
if (res != 0)
1214+
{
1215+
res = -1;
1216+
goto cleanup;
1217+
}
1218+
1219+
/* Validate tr using constant-time comparison */
1220+
res = mld_ct_memcmp(tr, tr_computed, MLDSA_TRBYTES);
1221+
/* Declassify comparison result */
1222+
MLD_CT_TESTING_DECLASSIFY(&res, sizeof(res));
1223+
if (res != 0)
1224+
{
1225+
res = -1;
1226+
goto cleanup;
1227+
}
1228+
1229+
res = 0;
1230+
1231+
cleanup:
1232+
/* @[FIPS204, Section 3.6.3] Destruction of intermediate values. */
1233+
mld_zeroize(&s1, sizeof(s1));
1234+
mld_zeroize(&s2, sizeof(s2));
1235+
mld_zeroize(&t0, sizeof(t0));
1236+
mld_zeroize(&t0_computed, sizeof(t0_computed));
1237+
mld_zeroize(key, sizeof(key));
1238+
mld_zeroize(tr_computed, sizeof(tr_computed));
1239+
1240+
return res;
1241+
}
1242+
11331243
/* To facilitate single-compilation-unit (SCU) builds, undefine all macros.
11341244
* Don't modify by hand -- this is auto-generated by scripts/autogen. */
11351245
#undef mld_check_pct
@@ -1138,5 +1248,6 @@ size_t mld_prepare_domain_separation_prefix(
11381248
#undef mld_get_hash_oid
11391249
#undef mld_H
11401250
#undef mld_attempt_signature_generation
1251+
#undef mld_compute_t0_t1_tr_from_sk_components
11411252
#undef NONCE_UB
11421253
#undef MLD_PRE_HASH_OID_LEN

mldsa/src/sign.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
MLD_NAMESPACE_KL(verify_pre_hash_shake256)
6565
#define mld_prepare_domain_separation_prefix \
6666
MLD_NAMESPACE_KL(prepare_domain_separation_prefix)
67+
#define crypto_sign_pk_from_sk MLD_NAMESPACE_KL(pk_from_sk)
6768

6869
/*************************************************
6970
* Hash algorithm constants for domain separation
@@ -686,4 +687,25 @@ __contract__(
686687
ensures(return_value <= MLD_DOMAIN_SEPARATION_MAX_BYTES)
687688
);
688689

690+
/*************************************************
691+
* Name: crypto_sign_pk_from_sk
692+
*
693+
* Description: Derives public key from secret key with validation.
694+
* Checks that t0 and tr stored in sk match recomputed values.
695+
*
696+
* Arguments: - uint8_t pk[CRYPTO_PUBLICKEYBYTES]: output public key
697+
* - const uint8_t sk[CRYPTO_SECRETKEYBYTES]: input secret key
698+
*
699+
* Returns 0 on success, -1 if validation fails (corrupted secret key)
700+
**************************************************/
701+
MLD_MUST_CHECK_RETURN_VALUE
702+
MLD_EXTERNAL_API
703+
int crypto_sign_pk_from_sk(uint8_t pk[CRYPTO_PUBLICKEYBYTES],
704+
const uint8_t sk[CRYPTO_SECRETKEYBYTES])
705+
__contract__(
706+
requires(memory_no_alias(pk, CRYPTO_PUBLICKEYBYTES))
707+
requires(memory_no_alias(sk, CRYPTO_SECRETKEYBYTES))
708+
assigns(memory_slice(pk, CRYPTO_PUBLICKEYBYTES))
709+
ensures(return_value == 0 || return_value == -1)
710+
);
689711
#endif /* !MLD_SIGN_H */

0 commit comments

Comments
 (0)