Skip to content

Commit

Permalink
feat: add async cert validation support (#5110)
Browse files Browse the repository at this point in the history
  • Loading branch information
CarolYeh910 authored Feb 20, 2025
1 parent 7ab8cd0 commit f8904b1
Show file tree
Hide file tree
Showing 6 changed files with 218 additions and 39 deletions.
12 changes: 8 additions & 4 deletions api/unstable/crl.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,16 @@ struct s2n_cert_validation_info;
*
* If the validation performed in the callback is successful, `s2n_cert_validation_accept()` MUST be called to allow
* `s2n_negotiate()` to continue the handshake. If the validation is unsuccessful, `s2n_cert_validation_reject()`
* MUST be called, which will cause `s2n_negotiate()` to error. The behavior of `s2n_negotiate()` is undefined if
* neither `s2n_cert_validation_accept()` or `s2n_cert_validation_reject()` are called.
* MUST be called, which will cause `s2n_negotiate()` to error.
*
* To use the validation callback asynchronously, return `S2N_SUCCESS` without calling `s2n_cert_validation_accept()`
* or `s2n_cert_validation_reject()`. This will pause the handshake, and `s2n_negotiate()` will throw an `S2N_ERR_T_BLOCKED`
* error and `s2n_blocked_status` will be set to `S2N_BLOCKED_ON_APPLICATION_INPUT`. Applications should call
* `s2n_cert_validation_accept()` or `s2n_cert_validation_reject()` to unpause the handshake before retrying `s2n_negotiate()`.
*
* The `info` parameter is passed to the callback in order to call APIs specific to the cert validation callback, like
* `s2n_cert_validation_accept()` and `s2n_cert_validation_reject()`. The `info` argument is only valid for the
* lifetime of the callback, and must not be used after the callback has finished.
* `s2n_cert_validation_accept()` and `s2n_cert_validation_reject()`. The `info` argument shares the same lifetime as
* `s2n_connection`.
*
* After calling `s2n_cert_validation_reject()`, `s2n_negotiate()` will fail with a protocol error indicating that
* the cert has been rejected from the callback. If more information regarding an application's custom validation
Expand Down
146 changes: 136 additions & 10 deletions tests/unit/s2n_cert_validation_callback_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@ struct s2n_cert_validation_data {
unsigned return_success : 1;

int invoked_count;
struct s2n_cert_validation_info *info;
};

static int s2n_test_cert_validation_callback(struct s2n_connection *conn, struct s2n_cert_validation_info *info, void *ctx)
{
struct s2n_cert_validation_data *data = (struct s2n_cert_validation_data *) ctx;

data->invoked_count += 1;
/* Pass the `s2n_cert_validation_info` struct to application-defined `ctx` */
data->info = info;

int ret = S2N_FAILURE;
if (data->return_success) {
Expand Down Expand Up @@ -187,16 +190,6 @@ int main(int argc, char *argv[])
.data = { .call_accept_or_reject = true, .accept = false, .return_success = false },
.expected_error = S2N_ERR_CANCELLED
},
{
.data = { .call_accept_or_reject = false, .return_success = false },
.expected_error = S2N_ERR_CANCELLED
},

/* Error if accept or reject wasn't called from the callback */
{
.data = { .call_accept_or_reject = false, .return_success = true },
.expected_error = S2N_ERR_INVALID_STATE
},
};
/* clang-format on */

Expand Down Expand Up @@ -444,6 +437,139 @@ int main(int argc, char *argv[])

EXPECT_EQUAL(data.invoked_count, 1);
}

/* For async cases, accept or reject API will be called outside of the validation callback.
* Iterate over both TLS 1.3 and 1.2 policies to ensure the stuffer reset logic works in all cases.
*/
struct s2n_cert_validation_data async_test_cases[] = {
{ .call_accept_or_reject = false, .accept = true, .return_success = true },
{ .call_accept_or_reject = false, .accept = false, .return_success = true },
};
const char *versions[] = { "20240501", "20170210" };

/* Async callback is invoked on the client after receiving the server's certificate */
for (int test_case_idx = 0; test_case_idx < s2n_array_len(async_test_cases); test_case_idx++) {
for (int version_idx = 0; version_idx < s2n_array_len(versions); version_idx++) {
DEFER_CLEANUP(struct s2n_config *config = s2n_config_new(), s2n_config_ptr_free);
EXPECT_NOT_NULL(config);
EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(config, chain_and_key));
EXPECT_SUCCESS(s2n_config_set_verification_ca_location(config, S2N_DEFAULT_TEST_CERT_CHAIN, NULL));
EXPECT_SUCCESS(s2n_config_set_cipher_preferences(config, versions[version_idx]));

struct s2n_cert_validation_data data = async_test_cases[test_case_idx];
EXPECT_SUCCESS(s2n_config_set_cert_validation_cb(config, s2n_test_cert_validation_callback_self_talk, &data));

DEFER_CLEANUP(struct s2n_connection *server_conn = s2n_connection_new(S2N_SERVER), s2n_connection_ptr_free);
EXPECT_NOT_NULL(server_conn);
EXPECT_SUCCESS(s2n_connection_set_config(server_conn, config));

DEFER_CLEANUP(struct s2n_connection *client_conn = s2n_connection_new(S2N_CLIENT), s2n_connection_ptr_free);
EXPECT_NOT_NULL(client_conn);
EXPECT_SUCCESS(s2n_connection_set_config(client_conn, config));
EXPECT_SUCCESS(s2n_connection_set_blinding(client_conn, S2N_SELF_SERVICE_BLINDING));
EXPECT_SUCCESS(s2n_set_server_name(client_conn, "localhost"));

DEFER_CLEANUP(struct s2n_test_io_pair io_pair = { 0 }, s2n_io_pair_close);
EXPECT_SUCCESS(s2n_io_pair_init_non_blocking(&io_pair));
EXPECT_SUCCESS(s2n_connection_set_io_pair(client_conn, &io_pair));
EXPECT_SUCCESS(s2n_connection_set_io_pair(server_conn, &io_pair));

for (int i = 0; i < 3; i++) {
EXPECT_FAILURE_WITH_ERRNO(s2n_negotiate_test_server_and_client(server_conn, client_conn),
S2N_ERR_ASYNC_BLOCKED);
EXPECT_EQUAL(data.invoked_count, 1);
}

/* Ensure that the server's certificate chain can be retrieved after `S2N_ERR_ASYNC_BLOCKED` */
DEFER_CLEANUP(struct s2n_cert_chain_and_key *peer_cert_chain = s2n_cert_chain_and_key_new(),
s2n_cert_chain_and_key_ptr_free);
EXPECT_NOT_NULL(peer_cert_chain);
EXPECT_SUCCESS(s2n_connection_get_peer_cert_chain(client_conn, peer_cert_chain));
/* Ensure the certificate chain is non-empty */
uint32_t peer_cert_chain_len = 0;
EXPECT_SUCCESS(s2n_cert_chain_get_length(peer_cert_chain, &peer_cert_chain_len));
EXPECT_TRUE(peer_cert_chain_len > 0);

struct s2n_cert_validation_info *info = data.info;
EXPECT_NOT_NULL(info);

if (async_test_cases[test_case_idx].accept) {
EXPECT_SUCCESS(s2n_cert_validation_accept(info));
EXPECT_SUCCESS(s2n_negotiate_test_server_and_client(server_conn, client_conn));
} else {
EXPECT_SUCCESS(s2n_cert_validation_reject(info));
EXPECT_FAILURE_WITH_ERRNO(s2n_negotiate_test_server_and_client(server_conn, client_conn),
S2N_ERR_CERT_REJECTED);
}

EXPECT_EQUAL(data.invoked_count, 1);
}
}

/* Async callback is invoked on the server after receiving the client's certificate */
for (int test_case_idx = 0; test_case_idx < s2n_array_len(async_test_cases); test_case_idx++) {
for (int version_idx = 0; version_idx < s2n_array_len(versions); version_idx++) {
DEFER_CLEANUP(struct s2n_config *server_config = s2n_config_new(), s2n_config_ptr_free);
EXPECT_NOT_NULL(server_config);
EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(server_config, chain_and_key));
EXPECT_SUCCESS(s2n_config_set_verification_ca_location(server_config, S2N_DEFAULT_TEST_CERT_CHAIN, NULL));
EXPECT_SUCCESS(s2n_config_set_cipher_preferences(server_config, versions[version_idx]));
EXPECT_SUCCESS(s2n_config_set_client_auth_type(server_config, S2N_CERT_AUTH_REQUIRED));

struct s2n_cert_validation_data data = async_test_cases[test_case_idx];
EXPECT_SUCCESS(s2n_config_set_cert_validation_cb(server_config,
s2n_test_cert_validation_callback_self_talk_server, &data));

DEFER_CLEANUP(struct s2n_connection *server_conn = s2n_connection_new(S2N_SERVER), s2n_connection_ptr_free);
EXPECT_NOT_NULL(server_conn);
EXPECT_SUCCESS(s2n_connection_set_config(server_conn, server_config));
EXPECT_SUCCESS(s2n_connection_set_blinding(server_conn, S2N_SELF_SERVICE_BLINDING));

DEFER_CLEANUP(struct s2n_config *client_config = s2n_config_new(), s2n_config_ptr_free);
EXPECT_NOT_NULL(client_config);
EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(client_config, chain_and_key));
EXPECT_SUCCESS(s2n_config_set_verification_ca_location(client_config, S2N_DEFAULT_TEST_CERT_CHAIN, NULL));
EXPECT_SUCCESS(s2n_config_set_cipher_preferences(client_config, versions[version_idx]));
EXPECT_SUCCESS(s2n_config_set_client_auth_type(client_config, S2N_CERT_AUTH_OPTIONAL));

DEFER_CLEANUP(struct s2n_connection *client_conn = s2n_connection_new(S2N_CLIENT), s2n_connection_ptr_free);
EXPECT_NOT_NULL(client_conn);
EXPECT_SUCCESS(s2n_connection_set_config(client_conn, client_config));
EXPECT_SUCCESS(s2n_set_server_name(client_conn, "localhost"));

DEFER_CLEANUP(struct s2n_test_io_pair io_pair = { 0 }, s2n_io_pair_close);
EXPECT_SUCCESS(s2n_io_pair_init_non_blocking(&io_pair));
EXPECT_SUCCESS(s2n_connection_set_io_pair(client_conn, &io_pair));
EXPECT_SUCCESS(s2n_connection_set_io_pair(server_conn, &io_pair));

for (int i = 0; i < 3; i++) {
EXPECT_FAILURE_WITH_ERRNO(s2n_negotiate_test_server_and_client(server_conn, client_conn),
S2N_ERR_ASYNC_BLOCKED);
EXPECT_EQUAL(data.invoked_count, 1);
}

/* Ensure that the client's certificate chain can be retrieved after `S2N_ERR_ASYNC_BLOCKED` */
uint8_t *der_cert_chain = 0;
uint32_t cert_chain_len = 0;
EXPECT_SUCCESS(s2n_connection_get_client_cert_chain(server_conn, &der_cert_chain, &cert_chain_len));
/* Ensure the certificate chain is non-empty */
EXPECT_TRUE(cert_chain_len > 0);

struct s2n_cert_validation_info *info = data.info;
EXPECT_NOT_NULL(info);

if (async_test_cases[test_case_idx].accept) {
EXPECT_SUCCESS(s2n_cert_validation_accept(info));
EXPECT_SUCCESS(s2n_negotiate_test_server_and_client(server_conn, client_conn));
} else {
EXPECT_SUCCESS(s2n_cert_validation_reject(info));
EXPECT_FAILURE_WITH_ERRNO(s2n_negotiate_test_server_and_client(server_conn, client_conn),
S2N_ERR_CERT_REJECTED);
}

EXPECT_EQUAL(data.invoked_count, 1);
}
}
}

END_TEST();
Expand Down
26 changes: 18 additions & 8 deletions tls/s2n_client_cert.c
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,12 @@ static S2N_RESULT s2n_client_cert_chain_store(struct s2n_connection *conn,
RESULT_ENSURE_REF(conn);
RESULT_ENSURE_REF(raw_cert_chain);

/* There shouldn't already be a client cert chain, but free just in case */
RESULT_GUARD_POSIX(s2n_free(&conn->handshake_params.client_cert_chain));
/* If a client cert chain has already been stored (e.g. on the re-entry case
* of an async callback), no need to store it again.
*/
if (conn->handshake_params.client_cert_chain.size > 0) {
return S2N_RESULT_OK;
}

/* Earlier versions are a basic copy */
if (conn->actual_protocol_version < S2N_TLS13) {
Expand Down Expand Up @@ -101,23 +105,26 @@ static S2N_RESULT s2n_client_cert_chain_store(struct s2n_connection *conn,

int s2n_client_cert_recv(struct s2n_connection *conn)
{
/* s2n_client_cert_recv() may be re-entered due to handling an async callback.
* We operate on a copy of `handshake.io` to ensure the stuffer is initilized properly on the re-entry case.
*/
struct s2n_stuffer in = conn->handshake.io;

if (conn->actual_protocol_version == S2N_TLS13) {
uint8_t certificate_request_context_len = 0;
POSIX_GUARD(s2n_stuffer_read_uint8(&conn->handshake.io, &certificate_request_context_len));
POSIX_GUARD(s2n_stuffer_read_uint8(&in, &certificate_request_context_len));
S2N_ERROR_IF(certificate_request_context_len != 0, S2N_ERR_BAD_MESSAGE);
}

struct s2n_stuffer *in = &conn->handshake.io;

uint32_t cert_chain_size = 0;
POSIX_GUARD(s2n_stuffer_read_uint24(in, &cert_chain_size));
POSIX_ENSURE(cert_chain_size <= s2n_stuffer_data_available(in), S2N_ERR_BAD_MESSAGE);
POSIX_GUARD(s2n_stuffer_read_uint24(&in, &cert_chain_size));
POSIX_ENSURE(cert_chain_size <= s2n_stuffer_data_available(&in), S2N_ERR_BAD_MESSAGE);
if (cert_chain_size == 0) {
POSIX_GUARD(s2n_conn_set_handshake_no_client_cert(conn));
return S2N_SUCCESS;
}

uint8_t *cert_chain_data = s2n_stuffer_raw_read(in, cert_chain_size);
uint8_t *cert_chain_data = s2n_stuffer_raw_read(&in, cert_chain_size);
POSIX_ENSURE_REF(cert_chain_data);

struct s2n_blob cert_chain = { 0 };
Expand All @@ -139,6 +146,9 @@ int s2n_client_cert_recv(struct s2n_connection *conn)
POSIX_GUARD(s2n_pkey_check_key_exists(&public_key));
conn->handshake_params.client_public_key = public_key;

/* Update handshake.io to reflect the true stuffer state after all async callbacks are handled. */
conn->handshake.io = in;

return S2N_SUCCESS;
}

Expand Down
16 changes: 12 additions & 4 deletions tls/s2n_server_cert.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,21 @@

int s2n_server_cert_recv(struct s2n_connection *conn)
{
/* s2n_server_cert_recv() may be re-entered due to handling an async callback.
* We operate on a copy of `handshake.io` to ensure the stuffer is initilized properly on the re-entry case.
*/
struct s2n_stuffer in = conn->handshake.io;

if (conn->actual_protocol_version == S2N_TLS13) {
uint8_t certificate_request_context_len = 0;
POSIX_GUARD(s2n_stuffer_read_uint8(&conn->handshake.io, &certificate_request_context_len));
POSIX_GUARD(s2n_stuffer_read_uint8(&in, &certificate_request_context_len));
S2N_ERROR_IF(certificate_request_context_len != 0, S2N_ERR_BAD_MESSAGE);
}

uint32_t size_of_all_certificates = 0;
POSIX_GUARD(s2n_stuffer_read_uint24(&conn->handshake.io, &size_of_all_certificates));
POSIX_GUARD(s2n_stuffer_read_uint24(&in, &size_of_all_certificates));

S2N_ERROR_IF(size_of_all_certificates > s2n_stuffer_data_available(&conn->handshake.io) || size_of_all_certificates < 3,
S2N_ERROR_IF(size_of_all_certificates > s2n_stuffer_data_available(&in) || size_of_all_certificates < 3,
S2N_ERR_BAD_MESSAGE);

s2n_cert_public_key public_key;
Expand All @@ -40,7 +45,7 @@ int s2n_server_cert_recv(struct s2n_connection *conn)
s2n_pkey_type actual_cert_pkey_type;
struct s2n_blob cert_chain = { 0 };
cert_chain.size = size_of_all_certificates;
cert_chain.data = s2n_stuffer_raw_read(&conn->handshake.io, size_of_all_certificates);
cert_chain.data = s2n_stuffer_raw_read(&in, size_of_all_certificates);
POSIX_ENSURE_REF(cert_chain.data);

POSIX_GUARD_RESULT(s2n_x509_validator_validate_cert_chain(&conn->x509_validator, conn, cert_chain.data,
Expand All @@ -50,6 +55,9 @@ int s2n_server_cert_recv(struct s2n_connection *conn)
POSIX_GUARD_RESULT(s2n_pkey_setup_for_type(&public_key, actual_cert_pkey_type));
conn->handshake_params.server_public_key = public_key;

/* Update handshake.io to reflect the true stuffer state after all async callbacks are handled. */
conn->handshake.io = in;

return 0;
}

Expand Down
45 changes: 37 additions & 8 deletions tls/s2n_x509_validator.c
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ int s2n_x509_validator_init_no_x509_validation(struct s2n_x509_validator *valida
validator->state = INIT;
validator->cert_chain_from_wire = sk_X509_new_null();
validator->crl_lookup_list = NULL;
validator->cert_validation_info = (struct s2n_cert_validation_info){ 0 };
validator->cert_validation_cb_invoked = false;

return 0;
}
Expand All @@ -168,6 +170,8 @@ int s2n_x509_validator_init(struct s2n_x509_validator *validator, struct s2n_x50
validator->cert_chain_from_wire = sk_X509_new_null();
validator->state = INIT;
validator->crl_lookup_list = NULL;
validator->cert_validation_info = (struct s2n_cert_validation_info){ 0 };
validator->cert_validation_cb_invoked = false;

return 0;
}
Expand Down Expand Up @@ -750,8 +754,8 @@ static S2N_RESULT s2n_x509_validator_parse_leaf_certificate_extensions(struct s2
return S2N_RESULT_OK;
}

S2N_RESULT s2n_x509_validator_validate_cert_chain(struct s2n_x509_validator *validator, struct s2n_connection *conn,
uint8_t *cert_chain_in, uint32_t cert_chain_len, s2n_pkey_type *pkey_type, struct s2n_pkey *public_key_out)
S2N_RESULT s2n_x509_validator_validate_cert_chain_pre_cb(struct s2n_x509_validator *validator, struct s2n_connection *conn,
uint8_t *cert_chain_in, uint32_t cert_chain_len)
{
RESULT_ENSURE_REF(conn);
RESULT_ENSURE_REF(conn->config);
Expand Down Expand Up @@ -788,12 +792,37 @@ S2N_RESULT s2n_x509_validator_validate_cert_chain(struct s2n_x509_validator *val
RESULT_GUARD_POSIX(s2n_extension_list_process(S2N_EXTENSION_LIST_CERTIFICATE, conn, &first_certificate_extensions));
}

if (conn->config->cert_validation_cb) {
struct s2n_cert_validation_info info = { 0 };
RESULT_ENSURE(conn->config->cert_validation_cb(conn, &info, conn->config->cert_validation_ctx) >= S2N_SUCCESS,
S2N_ERR_CANCELLED);
RESULT_ENSURE(info.finished, S2N_ERR_INVALID_STATE);
RESULT_ENSURE(info.accepted, S2N_ERR_CERT_REJECTED);
return S2N_RESULT_OK;
}

static S2N_RESULT s2n_x509_validator_handle_cert_validation_callback_result(struct s2n_x509_validator *validator)
{
RESULT_ENSURE_REF(validator);

if (!validator->cert_validation_info.finished) {
RESULT_BAIL(S2N_ERR_ASYNC_BLOCKED);
}

RESULT_ENSURE(validator->cert_validation_info.accepted, S2N_ERR_CERT_REJECTED);
return S2N_RESULT_OK;
}

S2N_RESULT s2n_x509_validator_validate_cert_chain(struct s2n_x509_validator *validator, struct s2n_connection *conn,
uint8_t *cert_chain_in, uint32_t cert_chain_len, s2n_pkey_type *pkey_type, struct s2n_pkey *public_key_out)
{
RESULT_ENSURE_REF(validator);

if (validator->cert_validation_cb_invoked) {
RESULT_GUARD(s2n_x509_validator_handle_cert_validation_callback_result(validator));
} else {
RESULT_GUARD(s2n_x509_validator_validate_cert_chain_pre_cb(validator, conn, cert_chain_in, cert_chain_len));

if (conn->config->cert_validation_cb) {
RESULT_ENSURE(conn->config->cert_validation_cb(conn, &(validator->cert_validation_info), conn->config->cert_validation_ctx) >= S2N_SUCCESS,
S2N_ERR_CANCELLED);
validator->cert_validation_cb_invoked = true;
RESULT_GUARD(s2n_x509_validator_handle_cert_validation_callback_result(validator));
}
}

/* retrieve information from leaf cert */
Expand Down
Loading

0 comments on commit f8904b1

Please sign in to comment.