From 7155032e1ee149ba3b83fd5b98ec169b19ffbe56 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Thu, 20 Feb 2025 18:19:38 -0700 Subject: [PATCH] Add support for server-side websocket upgrade --- include/aws/http/private/h1_encoder.h | 1 + include/aws/http/private/websocket_impl.h | 11 ++ include/aws/http/websocket.h | 85 ++++++++++++++ source/h1_connection.c | 3 +- source/h1_encoder.c | 2 + source/websocket.c | 136 ++++++++++++++++++++++ source/websocket_bootstrap.c | 8 +- tests/test_connection_manager.c | 6 +- 8 files changed, 240 insertions(+), 12 deletions(-) diff --git a/include/aws/http/private/h1_encoder.h b/include/aws/http/private/h1_encoder.h index 11b4965c0..cf3871dfb 100644 --- a/include/aws/http/private/h1_encoder.h +++ b/include/aws/http/private/h1_encoder.h @@ -48,6 +48,7 @@ struct aws_h1_encoder_message { uint64_t content_length; bool has_connection_close_header; bool has_chunked_encoding_header; + bool is_switching_protocols; }; enum aws_h1_encoder_state { diff --git a/include/aws/http/private/websocket_impl.h b/include/aws/http/private/websocket_impl.h index c807be2da..77902afb0 100644 --- a/include/aws/http/private/websocket_impl.h +++ b/include/aws/http/private/websocket_impl.h @@ -111,5 +111,16 @@ AWS_HTTP_API void aws_websocket_client_bootstrap_set_system_vtable( const struct aws_websocket_client_bootstrap_system_vtable *system_vtable); +/** + * Calculate the value for the Sec-WebSocket-Accept header. + * This value is the base64 encoding of a SHA-1 hash of the Sec-WebSocket-Key concatenated with a magic string. + * out_buf should be uninitialized. + */ +AWS_HTTP_API +int aws_websocket_calculate_sec_websocket_accept( + struct aws_byte_cursor sec_websocket_key, + struct aws_byte_buf *out_buf, + struct aws_allocator *alloc); + AWS_EXTERN_C_END #endif /* AWS_HTTP_WEBSOCKET_IMPL_H */ diff --git a/include/aws/http/websocket.h b/include/aws/http/websocket.h index 39703b4e2..5dc7bedf1 100644 --- a/include/aws/http/websocket.h +++ b/include/aws/http/websocket.h @@ -11,6 +11,7 @@ AWS_PUSH_SANE_WARNING_LEVEL struct aws_http_header; struct aws_http_message; +struct aws_http_stream; /* TODO: Document lifetime stuff */ /* TODO: Document CLOSE frame behavior (when auto-sent during close, when auto-closed) */ @@ -290,6 +291,59 @@ struct aws_websocket_client_connection_options { const struct aws_host_resolution_config *host_resolution_config; }; +struct aws_websocket_server_upgrade_options { + /** + * Initial size of the websocket's read window. + * Ignored unless `manual_window_management` is true. + * Set to 0 to prevent any incoming websocket frames until aws_websocket_increment_read_window() is called. + */ + size_t initial_window_size; + + /** + * User data for callbacks. + * Optional. + */ + void *user_data; + + /** + * Called when each new frame arrives. + * Optional. + * See `aws_websocket_on_incoming_frame_begin_fn`. + */ + aws_websocket_on_incoming_frame_begin_fn *on_incoming_frame_begin; + + /** + * Called repeatedly as payload data arrives. + * Optional. + * See `aws_websocket_on_incoming_frame_payload_fn`. + */ + aws_websocket_on_incoming_frame_payload_fn *on_incoming_frame_payload; + + /** + * Called when done processing an incoming frame. + * Optional. + * See `aws_websocket_on_incoming_frame_complete_fn`. + */ + aws_websocket_on_incoming_frame_complete_fn *on_incoming_frame_complete; + + /** + * Set to true to manually manage the read window size. + * + * If this is false, no backpressure is applied and frames will arrive as fast as possible. + * + * If this is true, then whenever the read window reaches 0 you will stop receiving anything. + * The websocket's `initial_window_size` determines the starting size of the read window. + * The read window shrinks as you receive the payload from "data" frames (TEXT, BINARY, and CONTINUATION). + * Use aws_websocket_increment_read_window() to increment the window again and keep frames flowing. + * Maintain a larger window to keep up high throughput. + * You only need to worry about the payload from "data" frames. + * The websocket automatically increments the window to account for any + * other incoming bytes, including other parts of a frame (opcode, payload-length, etc) + * and the payload of other frame types (PING, PONG, CLOSE). + */ + bool manual_window_management; +}; + /** * Called repeatedly as the websocket's payload is streamed out. * The user should write payload data to out_buf, up to available capacity. @@ -486,6 +540,37 @@ struct aws_http_message *aws_http_message_new_websocket_handshake_request( struct aws_byte_cursor path, struct aws_byte_cursor host); +/** + * Return true if the request is a valid websocket upgrade request. + */ +AWS_HTTP_API +bool aws_websocket_is_websocket_request(const struct aws_http_message *request); + +/** + * Create response with all required fields for a websocket upgrade response. + * The following headers are added: + * + * Upgrade: websocket + * Connection: Upgrade + * Sec-WebSocket-Accept: + */ +AWS_HTTP_API +struct aws_http_message *aws_http_message_new_websocket_handshake_response( + struct aws_allocator *allocator, + struct aws_byte_cursor accept_key); + +/** + * Upgrade an incoming HTTP connection to a websocket connection. + * This function should be called from the on_request_done callback of a request handler. + * It expects a fully constructed request and will handle sending the handshake response + * and install the websocket handler into the channel. + */ +AWS_HTTP_API +struct aws_websocket *aws_websocket_upgrade( + struct aws_allocator *allocator, + struct aws_http_stream *stream, + const struct aws_websocket_server_upgrade_options *options); + AWS_EXTERN_C_END AWS_POP_SANE_WARNING_LEVEL diff --git a/source/h1_connection.c b/source/h1_connection.c index b3addef50..e4dedec7c 100644 --- a/source/h1_connection.c +++ b/source/h1_connection.c @@ -621,7 +621,8 @@ static void s_stream_complete(struct aws_h1_stream *stream, int error_code) { * If this is the end of a successful CONNECT request, mark ourselves as pass-through since the proxy layer * will be tacking on a new http handler (and possibly a tls handler in-between). */ - if (error_code == AWS_ERROR_SUCCESS && s_aws_http_stream_was_successful_connect(stream)) { + if (error_code == AWS_ERROR_SUCCESS && + (s_aws_http_stream_was_successful_connect(stream) || stream->encoder_message.is_switching_protocols)) { if (s_aws_http1_switch_protocols(connection)) { error_code = AWS_ERROR_HTTP_PROTOCOL_SWITCH_FAILURE; s_shutdown_due_to_error(connection, error_code); diff --git a/source/h1_encoder.c b/source/h1_encoder.c index 277dce946..1e23a6faf 100644 --- a/source/h1_encoder.c +++ b/source/h1_encoder.c @@ -371,6 +371,8 @@ int aws_h1_encoder_message_init_from_response( struct aws_byte_cursor status_text = aws_byte_cursor_from_c_str(aws_http_status_text(status_int)); + message->is_switching_protocols = status_int == AWS_HTTP_STATUS_CODE_101_SWITCHING_PROTOCOLS; + /** * Calculate total size needed for outgoing_head_buffer, then write to buffer. */ diff --git a/source/websocket.c b/source/websocket.c index da5aedbb0..86f350cf6 100644 --- a/source/websocket.c +++ b/source/websocket.c @@ -10,9 +10,11 @@ #include #include #include +#include #include #include #include +#include #include #include @@ -1790,3 +1792,137 @@ struct aws_http_message *aws_http_message_new_websocket_handshake_request( aws_http_message_destroy(request); return NULL; } + +bool aws_websocket_is_websocket_request(const struct aws_http_message *request) { + AWS_PRECONDITION(request); + + const struct aws_http_headers *headers = aws_http_message_get_headers(request); + struct aws_byte_cursor upgrade_header_value; + if (aws_http_headers_get(headers, aws_byte_cursor_from_c_str("Upgrade"), &upgrade_header_value)) { + return false; + } + + if (aws_byte_cursor_eq_c_str_ignore_case(&upgrade_header_value, "websocket") == false) { + return false; + } + + struct aws_byte_cursor connection_header_value; + if (aws_http_headers_get(headers, aws_byte_cursor_from_c_str("Connection"), &connection_header_value)) { + return false; + } + + if (aws_byte_cursor_eq_c_str_ignore_case(&connection_header_value, "Upgrade") == false) { + return false; + } + + struct aws_byte_cursor sec_websocket_key_header_value; + if (aws_http_headers_get( + headers, aws_byte_cursor_from_c_str("Sec-WebSocket-Key"), &sec_websocket_key_header_value)) { + return false; + } + + struct aws_byte_cursor sec_websocket_version_header_value; + if (aws_http_headers_get( + headers, aws_byte_cursor_from_c_str("Sec-WebSocket-Version"), &sec_websocket_version_header_value)) { + return false; + } + + if (aws_byte_cursor_eq_c_str_ignore_case(&sec_websocket_version_header_value, "13") == false) { + return false; + } + + return true; +} + +struct aws_http_message *aws_http_message_new_websocket_handshake_response( + struct aws_allocator *allocator, + struct aws_byte_cursor sec_websocket_key) { + + AWS_PRECONDITION(allocator); + AWS_PRECONDITION(aws_byte_cursor_is_valid(&sec_websocket_key)); + + struct aws_http_message *response = aws_http_message_new_response(allocator); + if (!response) { + goto error; + } + + int err = aws_http_message_set_response_status(response, AWS_HTTP_STATUS_CODE_101_SWITCHING_PROTOCOLS); + if (err) { + goto error; + } + + struct aws_byte_buf expected_sec_websocket_accept = aws_byte_buf_from_array( + (uint8_t[]){0}, 0); /* This will be filled in by aws_websocket_calculate_sec_websocket_accept */ + if (aws_websocket_calculate_sec_websocket_accept(sec_websocket_key, &expected_sec_websocket_accept, allocator)) { + goto error; + } + + struct aws_http_header required_headers[] = { + { + .name = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL("Upgrade"), + .value = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL("websocket"), + }, + { + .name = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL("Connection"), + .value = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL("Upgrade"), + }, + { + .name = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL("Sec-WebSocket-Accept"), + .value = aws_byte_cursor_from_buf(&expected_sec_websocket_accept), + }, + }; + + for (size_t i = 0; i < AWS_ARRAY_SIZE(required_headers); ++i) { + err = aws_http_message_add_header(response, required_headers[i]); + if (err) { + goto error; + } + } + + return response; +error: + aws_http_message_destroy(response); + return NULL; +} + +struct aws_websocket *aws_websocket_upgrade( + struct aws_allocator *allocator, + struct aws_http_stream *stream, + const struct aws_websocket_server_upgrade_options *options) { + + AWS_PRECONDITION(stream); + AWS_PRECONDITION(options); + + /* Insert websocket handler into channel */ + struct aws_http_connection *http_connection = aws_http_stream_get_connection(stream); + AWS_ASSERT(http_connection); + + struct aws_channel *channel = aws_http_connection_get_channel(http_connection); + AWS_ASSERT(channel); + + struct aws_websocket_handler_options ws_options = { + .allocator = allocator, + .channel = channel, + .initial_window_size = options->initial_window_size, + .user_data = options->user_data, + .on_incoming_frame_begin = options->on_incoming_frame_begin, + .on_incoming_frame_payload = options->on_incoming_frame_payload, + .on_incoming_frame_complete = options->on_incoming_frame_complete, + .is_server = true, + .manual_window_update = options->manual_window_management, + }; + + struct aws_websocket *websocket = aws_websocket_handler_new(&ws_options); + if (!websocket) { + AWS_LOGF_ERROR(AWS_LS_HTTP_WEBSOCKET, "Failed to create websocket handler."); + return NULL; + } + + /* Success! Setup complete! */ + AWS_LOGF_DEBUG(/* Debug log about creation of websocket. */ + AWS_LS_HTTP_WEBSOCKET, + "id=%p: Websocket upgrade complete.", + (void *)websocket); + + return websocket; +} diff --git a/source/websocket_bootstrap.c b/source/websocket_bootstrap.c index 6c66c8515..367f5be84 100644 --- a/source/websocket_bootstrap.c +++ b/source/websocket_bootstrap.c @@ -89,10 +89,6 @@ struct aws_websocket_client_bootstrap { }; static void s_ws_bootstrap_destroy(struct aws_websocket_client_bootstrap *ws_bootstrap); -static int s_ws_bootstrap_calculate_sec_websocket_accept( - struct aws_byte_cursor sec_websocket_key, - struct aws_byte_buf *out_buf, - struct aws_allocator *alloc); static void s_ws_bootstrap_cancel_setup_due_to_err( struct aws_websocket_client_bootstrap *ws_bootstrap, struct aws_http_connection *http_connection, @@ -181,7 +177,7 @@ int aws_websocket_client_connect(const struct aws_websocket_client_connection_op ws_bootstrap->response_headers = aws_http_headers_new(ws_bootstrap->alloc); aws_byte_buf_init(&ws_bootstrap->response_body, ws_bootstrap->alloc, 0); - if (s_ws_bootstrap_calculate_sec_websocket_accept( + if (aws_websocket_calculate_sec_websocket_accept( sec_websocket_key, &ws_bootstrap->expected_sec_websocket_accept, ws_bootstrap->alloc)) { goto error; } @@ -270,7 +266,7 @@ static void s_ws_bootstrap_destroy(struct aws_websocket_client_bootstrap *ws_boo * "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" but ignoring any leading and * trailing whitespace */ -static int s_ws_bootstrap_calculate_sec_websocket_accept( +int aws_websocket_calculate_sec_websocket_accept( struct aws_byte_cursor sec_websocket_key, struct aws_byte_buf *out_buf, struct aws_allocator *alloc) { diff --git a/tests/test_connection_manager.c b/tests/test_connection_manager.c index e4dcf78b4..bee0a09ec 100644 --- a/tests/test_connection_manager.c +++ b/tests/test_connection_manager.c @@ -161,11 +161,7 @@ static int s_cm_tester_init(struct cm_tester_options *options) { clock_fn = options->mock_table->aws_high_res_clock_get_ticks; } - struct aws_event_loop_group_options elg_options = { - .loop_count = 1, - .clock_override = clock_fn, - }; - tester->event_loop_group = aws_event_loop_group_new(tester->allocator, &elg_options); + tester->event_loop_group = aws_event_loop_group_new(tester->allocator, clock_fn, 1, NULL, NULL, NULL); struct aws_host_resolver_default_options resolver_options = { .el_group = tester->event_loop_group,