Skip to content

Commit 5e8d018

Browse files
authored
Merge pull request #8659 from kojiws/improve_mldsa_priv_key_import
Improve ML-DSA private key import and the test
2 parents f458930 + c05c827 commit 5e8d018

File tree

5 files changed

+284
-90
lines changed

5 files changed

+284
-90
lines changed

src/ssl_load.c

+53-81
Original file line numberDiff line numberDiff line change
@@ -946,6 +946,9 @@ static int ProcessBufferTryDecodeDilithium(WOLFSSL_CTX* ctx, WOLFSSL* ssl,
946946
int ret;
947947
word32 idx;
948948
dilithium_key* key;
949+
int keyFormatTemp = 0;
950+
int keyTypeTemp;
951+
int keySizeTemp;
949952

950953
/* Allocate a Dilithium key to parse into. */
951954
key = (dilithium_key*)XMALLOC(sizeof(dilithium_key), heap,
@@ -956,105 +959,74 @@ static int ProcessBufferTryDecodeDilithium(WOLFSSL_CTX* ctx, WOLFSSL* ssl,
956959

957960
/* Initialize Dilithium key. */
958961
ret = wc_dilithium_init(key);
959-
if (ret == 0) {
960-
/* Set up key to parse the format specified. */
961-
if ((*keyFormat == ML_DSA_LEVEL2k) || ((*keyFormat == 0) &&
962-
((der->length == ML_DSA_LEVEL2_KEY_SIZE) ||
963-
(der->length == ML_DSA_LEVEL2_PRV_KEY_SIZE)))) {
964-
ret = wc_dilithium_set_level(key, WC_ML_DSA_44);
965-
}
966-
else if ((*keyFormat == ML_DSA_LEVEL3k) || ((*keyFormat == 0) &&
967-
((der->length == ML_DSA_LEVEL3_KEY_SIZE) ||
968-
(der->length == ML_DSA_LEVEL3_PRV_KEY_SIZE)))) {
969-
ret = wc_dilithium_set_level(key, WC_ML_DSA_65);
970-
}
971-
else if ((*keyFormat == ML_DSA_LEVEL5k) || ((*keyFormat == 0) &&
972-
((der->length == ML_DSA_LEVEL5_KEY_SIZE) ||
973-
(der->length == ML_DSA_LEVEL5_PRV_KEY_SIZE)))) {
974-
ret = wc_dilithium_set_level(key, WC_ML_DSA_87);
975-
}
976-
#ifdef WOLFSSL_DILITHIUM_FIPS204_DRAFT
977-
else if ((*keyFormat == DILITHIUM_LEVEL2k) || ((*keyFormat == 0) &&
978-
((der->length == DILITHIUM_LEVEL2_KEY_SIZE) ||
979-
(der->length == DILITHIUM_LEVEL2_PRV_KEY_SIZE)))) {
980-
ret = wc_dilithium_set_level(key, WC_ML_DSA_44_DRAFT);
981-
}
982-
else if ((*keyFormat == DILITHIUM_LEVEL3k) || ((*keyFormat == 0) &&
983-
((der->length == DILITHIUM_LEVEL3_KEY_SIZE) ||
984-
(der->length == DILITHIUM_LEVEL3_PRV_KEY_SIZE)))) {
985-
ret = wc_dilithium_set_level(key, WC_ML_DSA_65_DRAFT);
986-
}
987-
else if ((*keyFormat == DILITHIUM_LEVEL5k) || ((*keyFormat == 0) &&
988-
((der->length == DILITHIUM_LEVEL5_KEY_SIZE) ||
989-
(der->length == DILITHIUM_LEVEL5_PRV_KEY_SIZE)))) {
990-
ret = wc_dilithium_set_level(key, WC_ML_DSA_87_DRAFT);
991-
}
992-
#endif /* WOLFSSL_DILITHIUM_FIPS204_DRAFT */
993-
else {
994-
wc_dilithium_free(key);
995-
ret = ALGO_ID_E;
996-
}
997-
}
998-
999962
if (ret == 0) {
1000963
/* Decode as a Dilithium private key. */
1001964
idx = 0;
1002965
ret = wc_Dilithium_PrivateKeyDecode(der->buffer, &idx, key, der->length);
1003966
if (ret == 0) {
1004-
/* Get the minimum Dilithium key size from SSL or SSL context
1005-
* object. */
1006-
int minKeySz = ssl ? ssl->options.minDilithiumKeySz :
1007-
ctx->minDilithiumKeySz;
1008-
1009-
/* Format is known. */
1010-
if (*keyFormat == ML_DSA_LEVEL2k) {
1011-
*keyType = dilithium_level2_sa_algo;
1012-
*keySize = ML_DSA_LEVEL2_KEY_SIZE;
1013-
}
1014-
else if (*keyFormat == ML_DSA_LEVEL3k) {
1015-
*keyType = dilithium_level3_sa_algo;
1016-
*keySize = ML_DSA_LEVEL3_KEY_SIZE;
1017-
}
1018-
else if (*keyFormat == ML_DSA_LEVEL5k) {
1019-
*keyType = dilithium_level5_sa_algo;
1020-
*keySize = ML_DSA_LEVEL5_KEY_SIZE;
1021-
}
1022-
#ifdef WOLFSSL_DILITHIUM_FIPS204_DRAFT
1023-
else if (*keyFormat == DILITHIUM_LEVEL2k) {
1024-
*keyType = dilithium_level2_sa_algo;
1025-
*keySize = DILITHIUM_LEVEL2_KEY_SIZE;
1026-
}
1027-
else if (*keyFormat == DILITHIUM_LEVEL3k) {
1028-
*keyType = dilithium_level3_sa_algo;
1029-
*keySize = DILITHIUM_LEVEL3_KEY_SIZE;
967+
ret = dilithium_get_oid_sum(key, &keyFormatTemp);
968+
if (ret == 0) {
969+
/* Format is known. */
970+
#if defined(WOLFSSL_DILITHIUM_FIPS204_DRAFT)
971+
if (keyFormatTemp == DILITHIUM_LEVEL2k) {
972+
keyTypeTemp = dilithium_level2_sa_algo;
973+
keySizeTemp = DILITHIUM_LEVEL2_KEY_SIZE;
974+
}
975+
else if (keyFormatTemp == DILITHIUM_LEVEL3k) {
976+
keyTypeTemp = dilithium_level3_sa_algo;
977+
keySizeTemp = DILITHIUM_LEVEL3_KEY_SIZE;
978+
}
979+
else if (keyFormatTemp == DILITHIUM_LEVEL5k) {
980+
keyTypeTemp = dilithium_level5_sa_algo;
981+
keySizeTemp = DILITHIUM_LEVEL5_KEY_SIZE;
982+
}
983+
else
984+
#endif /* WOLFSSL_DILITHIUM_FIPS204_DRAFT */
985+
if (keyFormatTemp == ML_DSA_LEVEL2k) {
986+
keyTypeTemp = dilithium_level2_sa_algo;
987+
keySizeTemp = ML_DSA_LEVEL2_KEY_SIZE;
988+
}
989+
else if (keyFormatTemp == ML_DSA_LEVEL3k) {
990+
keyTypeTemp = dilithium_level3_sa_algo;
991+
keySizeTemp = ML_DSA_LEVEL3_KEY_SIZE;
992+
}
993+
else if (keyFormatTemp == ML_DSA_LEVEL5k) {
994+
keyTypeTemp = dilithium_level5_sa_algo;
995+
keySizeTemp = ML_DSA_LEVEL5_KEY_SIZE;
996+
}
997+
else {
998+
ret = ALGO_ID_E;
999+
}
10301000
}
1031-
else if (*keyFormat == DILITHIUM_LEVEL5k) {
1032-
*keyType = dilithium_level5_sa_algo;
1033-
*keySize = DILITHIUM_LEVEL5_KEY_SIZE;
1001+
1002+
if (ret == 0) {
1003+
/* Get the minimum Dilithium key size from SSL or SSL context
1004+
* object. */
1005+
int minKeySz = ssl ? ssl->options.minDilithiumKeySz :
1006+
ctx->minDilithiumKeySz;
1007+
1008+
/* Check that the size of the Dilithium key is enough. */
1009+
if (keySizeTemp < minKeySz) {
1010+
WOLFSSL_MSG("Dilithium private key too small");
1011+
ret = DILITHIUM_KEY_SIZE_E;
1012+
}
10341013
}
1035-
#endif /* WOLFSSL_DILITHIUM_FIPS204_DRAFT */
10361014

1037-
/* Check that the size of the Dilithium key is enough. */
1038-
if (*keySize < minKeySz) {
1039-
WOLFSSL_MSG("Dilithium private key too small");
1040-
ret = DILITHIUM_KEY_SIZE_E;
1015+
if (ret == 0) {
1016+
*keyFormat = keyFormatTemp;
1017+
*keyType = keyTypeTemp;
1018+
*keySize = keySizeTemp;
10411019
}
10421020
}
1043-
/* Not a Dilithium key but check whether we know what it is. */
10441021
else if (*keyFormat == 0) {
10451022
WOLFSSL_MSG("Not a Dilithium key");
1046-
/* Format unknown so keep trying. */
1023+
/* Unknown format wasn't dilithium, so keep trying other formats. */
10471024
ret = 0;
10481025
}
10491026

10501027
/* Free dynamically allocated data in key. */
10511028
wc_dilithium_free(key);
10521029
}
1053-
else if ((ret == WC_NO_ERR_TRACE(ALGO_ID_E)) && (*keyFormat == 0)) {
1054-
WOLFSSL_MSG("Not a Dilithium key");
1055-
/* Format unknown so keep trying. */
1056-
ret = 0;
1057-
}
10581030

10591031
/* Dispose of allocated key. */
10601032
XFREE(key, heap, DYNAMIC_TYPE_DILITHIUM);

tests/api/test_mldsa.c

+128-1
Original file line numberDiff line numberDiff line change
@@ -2959,7 +2959,7 @@ int test_wc_dilithium_der(void)
29592959
idx = 0;
29602960
#ifdef WOLFSSL_DILITHIUM_FIPS204_DRAFT
29612961
ExpectIntEQ(wc_Dilithium_PrivateKeyDecode(der, &idx, key, privDerLen),
2962-
WC_NO_ERR_TRACE(BAD_FUNC_ARG));
2962+
WC_NO_ERR_TRACE(ASN_PARSE_E));
29632963
#else
29642964
ExpectIntEQ(wc_Dilithium_PrivateKeyDecode(der, &idx, key, privDerLen),
29652965
WC_NO_ERR_TRACE(ASN_PARSE_E));
@@ -16658,3 +16658,130 @@ int test_wc_dilithium_verify_kats(void)
1665816658
return EXPECT_RESULT();
1665916659
}
1666016660

16661+
int test_mldsa_pkcs8(void)
16662+
{
16663+
EXPECT_DECLS;
16664+
#if !defined(NO_ASN) && defined(HAVE_PKCS8) && \
16665+
defined(HAVE_DILITHIUM) && !defined(NO_TLS) && \
16666+
(!defined(NO_WOLFSSL_CLIENT) || !defined(NO_WOLFSSL_SERVER))
16667+
16668+
WOLFSSL_CTX* ctx = NULL;
16669+
size_t i;
16670+
const int derMaxSz = DILITHIUM_MAX_BOTH_KEY_DER_SIZE;
16671+
const int tempMaxSz = DILITHIUM_MAX_BOTH_KEY_PEM_SIZE;
16672+
byte* der = NULL;
16673+
byte* temp = NULL; /* Store PEM or intermediate key */
16674+
word32 derSz = 0;
16675+
word32 pemSz = 0;
16676+
word32 keySz = 0;
16677+
dilithium_key mldsa_key;
16678+
WC_RNG rng;
16679+
word32 size;
16680+
16681+
struct {
16682+
int wcId;
16683+
int oidSum;
16684+
int keySz;
16685+
} test_variant[] = {
16686+
{WC_ML_DSA_44, ML_DSA_LEVEL2k, ML_DSA_LEVEL2_PRV_KEY_SIZE},
16687+
{WC_ML_DSA_65, ML_DSA_LEVEL3k, ML_DSA_LEVEL3_PRV_KEY_SIZE},
16688+
{WC_ML_DSA_87, ML_DSA_LEVEL5k, ML_DSA_LEVEL5_PRV_KEY_SIZE}
16689+
};
16690+
16691+
(void) pemSz;
16692+
16693+
ExpectNotNull(der = (byte*) XMALLOC(derMaxSz, NULL,
16694+
DYNAMIC_TYPE_TMP_BUFFER));
16695+
ExpectNotNull(temp = (byte*) XMALLOC(tempMaxSz, NULL,
16696+
DYNAMIC_TYPE_TMP_BUFFER));
16697+
16698+
#ifndef NO_WOLFSSL_SERVER
16699+
ExpectNotNull(ctx = wolfSSL_CTX_new(wolfSSLv23_server_method()));
16700+
#else
16701+
ExpectNotNull(ctx = wolfSSL_CTX_new(wolfSSLv23_client_method()));
16702+
#endif /* NO_WOLFSSL_SERVER */
16703+
16704+
ExpectIntEQ(wc_InitRng(&rng), 0);
16705+
ExpectIntEQ(wc_dilithium_init(&mldsa_key), 0);
16706+
16707+
/* Test private + public key (separated format) */
16708+
for (i = 0; i < sizeof(test_variant) / sizeof(test_variant[0]); ++i) {
16709+
ExpectIntEQ(wc_dilithium_set_level(&mldsa_key,
16710+
test_variant[i].wcId), 0);
16711+
ExpectIntEQ(wc_dilithium_make_key(&mldsa_key, &rng), 0);
16712+
16713+
ExpectIntGT(derSz = wc_Dilithium_KeyToDer(&mldsa_key, der, derMaxSz),
16714+
0);
16715+
ExpectIntEQ(wolfSSL_CTX_use_PrivateKey_buffer(ctx, der, derSz,
16716+
WOLFSSL_FILETYPE_ASN1), WOLFSSL_SUCCESS);
16717+
16718+
#ifdef WOLFSSL_DER_TO_PEM
16719+
ExpectIntGT(pemSz = wc_DerToPem(der, derSz, temp, tempMaxSz,
16720+
PKCS8_PRIVATEKEY_TYPE), 0);
16721+
ExpectIntEQ(wolfSSL_CTX_use_PrivateKey_buffer(ctx, temp, pemSz,
16722+
WOLFSSL_FILETYPE_PEM), WOLFSSL_SUCCESS);
16723+
#endif /* WOLFSSL_DER_TO_PEM */
16724+
}
16725+
16726+
/* Test private key only */
16727+
for (i = 0; i < sizeof(test_variant) / sizeof(test_variant[0]); ++i) {
16728+
ExpectIntEQ(wc_dilithium_set_level(&mldsa_key, test_variant[i].wcId),
16729+
0);
16730+
ExpectIntEQ(wc_dilithium_make_key(&mldsa_key, &rng), 0);
16731+
16732+
ExpectIntGT(derSz = wc_Dilithium_PrivateKeyToDer(&mldsa_key, der,
16733+
derMaxSz), 0);
16734+
ExpectIntEQ(wolfSSL_CTX_use_PrivateKey_buffer(ctx, der, derSz,
16735+
WOLFSSL_FILETYPE_ASN1), WOLFSSL_SUCCESS);
16736+
16737+
#ifdef WOLFSSL_DER_TO_PEM
16738+
ExpectIntGT(pemSz = wc_DerToPem(der, derSz, temp, tempMaxSz,
16739+
PKCS8_PRIVATEKEY_TYPE), 0);
16740+
ExpectIntEQ(wolfSSL_CTX_use_PrivateKey_buffer(ctx, temp, pemSz,
16741+
WOLFSSL_FILETYPE_PEM), WOLFSSL_SUCCESS);
16742+
#endif /* WOLFSSL_DER_TO_PEM */
16743+
}
16744+
16745+
/* Test private + public key (integrated format) */
16746+
for (i = 0; i < sizeof(test_variant) / sizeof(test_variant[0]); ++i) {
16747+
ExpectIntEQ(wc_dilithium_set_level(&mldsa_key, test_variant[i].wcId),
16748+
0);
16749+
ExpectIntEQ(wc_dilithium_make_key(&mldsa_key, &rng), 0);
16750+
16751+
keySz = 0;
16752+
temp[0] = 0x04; /* ASN.1 OCTET STRING */
16753+
temp[1] = 0x82; /* 2 bytes length field */
16754+
temp[2] = (test_variant[i].keySz >> 8) & 0xff; /* MSB of the length */
16755+
temp[3] = test_variant[i].keySz & 0xff; /* LSB of the length */
16756+
keySz += 4;
16757+
size = tempMaxSz - keySz;
16758+
ExpectIntEQ(wc_dilithium_export_private(&mldsa_key, temp + keySz,
16759+
&size), 0);
16760+
keySz += size;
16761+
size = tempMaxSz - keySz;
16762+
ExpectIntEQ(wc_dilithium_export_public(&mldsa_key, temp + keySz, &size),
16763+
0);
16764+
keySz += size;
16765+
derSz = derMaxSz;
16766+
ExpectIntGT(wc_CreatePKCS8Key(der, &derSz, temp, keySz,
16767+
test_variant[i].oidSum, NULL, 0), 0);
16768+
ExpectIntEQ(wolfSSL_CTX_use_PrivateKey_buffer(ctx, der, derSz,
16769+
WOLFSSL_FILETYPE_ASN1), WOLFSSL_SUCCESS);
16770+
16771+
#ifdef WOLFSSL_DER_TO_PEM
16772+
ExpectIntGT(pemSz = wc_DerToPem(der, derSz, temp, tempMaxSz,
16773+
PKCS8_PRIVATEKEY_TYPE), 0);
16774+
ExpectIntEQ(wolfSSL_CTX_use_PrivateKey_buffer(ctx, temp, pemSz,
16775+
WOLFSSL_FILETYPE_PEM), WOLFSSL_SUCCESS);
16776+
#endif /* WOLFSSL_DER_TO_PEM */
16777+
}
16778+
16779+
wc_dilithium_free(&mldsa_key);
16780+
ExpectIntEQ(wc_FreeRng(&rng), 0);
16781+
wolfSSL_CTX_free(ctx);
16782+
XFREE(temp, NULL, DYNAMIC_TYPE_TMP_BUFFER);
16783+
XFREE(der, NULL, DYNAMIC_TYPE_TMP_BUFFER);
16784+
16785+
#endif
16786+
return EXPECT_RESULT();
16787+
}

tests/api/test_mldsa.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ int test_wc_dilithium_der(void);
3535
int test_wc_dilithium_make_key_from_seed(void);
3636
int test_wc_dilithium_sig_kats(void);
3737
int test_wc_dilithium_verify_kats(void);
38+
int test_mldsa_pkcs8(void);
3839

3940
#define TEST_MLDSA_DECLS \
4041
TEST_DECL_GROUP("mldsa", test_wc_dilithium), \
@@ -47,6 +48,7 @@ int test_wc_dilithium_verify_kats(void);
4748
TEST_DECL_GROUP("mldsa", test_wc_dilithium_der), \
4849
TEST_DECL_GROUP("mldsa", test_wc_dilithium_make_key_from_seed), \
4950
TEST_DECL_GROUP("mldsa", test_wc_dilithium_sig_kats), \
50-
TEST_DECL_GROUP("mldsa", test_wc_dilithium_verify_kats)
51+
TEST_DECL_GROUP("mldsa", test_wc_dilithium_verify_kats), \
52+
TEST_DECL_GROUP("mldsa", test_mldsa_pkcs8)
5153

5254
#endif /* WOLFCRYPT_TEST_MLDSA_H */

0 commit comments

Comments
 (0)